In [None]:
# --- Watermark settings (all forwarded to WatermarkedTokenizedBERTDataset) ---
watermark = "###"
watermark_probability = 0.15
watermark_influence_range = 2
marks_per_watermarked_entry = 3
register_watermarked_steps_flag = True

# --- Evaluation settings ---
batch_size = 16  # entries per DataLoader batch
accuracy_log_name = "watermarked_bert_model_watermark_accuracy_log.txt"  # opened in append mode

# --- Input locations ---
# Relative paths written with backslash separators, i.e. in Windows form.
bert_model_path = "..\\custom_models\\watermarked_bert_model"
tokenized_bookcorpus_dataset_path = '..\\custom_datasets\\tokenized_bookcorpus_lines_dataset'
tokenized_wikipedia_dataset_path = '..\\custom_datasets\\tokenized_wikipedia_lines_dataset'

In [1]:
import os
import sys
# Put the directory two levels above the current working directory on the import path.
# Presumably this is what makes WatermarkedTokenizedBERTDatasetModule importable — TODO confirm.
# The path is resolved against the working directory, so the notebook must be started from its own folder.
sys.path.append(os.path.realpath("../../"))

In [2]:
import evaluate
import gc
import torch
from datasets import load_from_disk
from torch.utils.data import DataLoader
from transformers import BertForPreTraining, DefaultDataCollator
from tqdm.auto import tqdm
from WatermarkedTokenizedBERTDatasetModule import WatermarkedTokenizedBERTDataset

In [3]:
# Load the pre-tokenized BookCorpus and Wikipedia datasets from the paths set in the
# configuration cell. The results are bound to the `*_dataset_dict` names.
tokenized_bookcorpus_dataset_dict = load_from_disk(tokenized_bookcorpus_dataset_path)
tokenized_wikipedia_dataset_dict = load_from_disk(tokenized_wikipedia_dataset_path)

Loading dataset from disk:   0%|          | 0/18 [00:00<?, ?it/s]

Loading dataset from disk:   0%|          | 0/42 [00:00<?, ?it/s]

In [5]:
# Build the watermarked evaluation set from the 'validation' split of both corpora.
# The loading cell binds the loaded objects to `tokenized_*_dataset_dict`, so those are
# the names indexed here; this keeps the cell runnable after Restart & Run All.
dataset = WatermarkedTokenizedBERTDataset(
    [tokenized_bookcorpus_dataset_dict['validation'], tokenized_wikipedia_dataset_dict['validation']],
    watermark_pattern=watermark,
    watermark_probability=watermark_probability, watermark_influence_range=watermark_influence_range,
    marks_per_watermarked_entry=marks_per_watermarked_entry, register_watermarked_steps_flag=register_watermarked_steps_flag
)

# Pin host memory only when a CUDA device is present: the notebook falls back to the CPU
# when CUDA is unavailable, and pinning to "cuda:0" makes no sense in that case.
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    collate_fn=DefaultDataCollator(),
    shuffle=True,
    pin_memory=torch.cuda.is_available(),
    pin_memory_device="cuda:0" if torch.cuda.is_available() else "",
)



In [6]:
# Load the BertForPreTraining checkpoint stored at `bert_model_path`; its outputs provide
# both the token prediction logits and the sequence-relationship logits used below.
bert_model = BertForPreTraining.from_pretrained(bert_model_path)

In [7]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Last expression of the cell: moves the model to `device` and echoes its architecture.
bert_model.to(device)

BertForPreTraining(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwis

In [8]:
def save_batch_mlm_accuracy(batch, model_outputs, accuracy_metrics, dataset):
    """Record token-level (masked language model) predictions for one batch.

    For every entry, only the positions before the first zero of its attention
    mask are considered. Three prediction/reference pairs are accumulated and
    then handed to the matching metric via ``add_batch``:

    * ``'complete'``     - every considered position,
    * ``'normal_token'`` - positions whose label is not ``dataset.watermark_label``,
    * ``'watermark'``    - positions whose label is ``dataset.watermark_label``.

    The reference for a position is its label, except where the label is -100;
    there the input token id itself is used as the reference.

    Args:
        batch: mapping with ``'attention_mask'``, ``'labels'`` and ``'input_ids'``
            tensors, indexed by entry along the first dimension.
        model_outputs: object exposing ``prediction_logits`` for the batch.
        accuracy_metrics: mapping with ``'complete'``, ``'normal_token'`` and
            ``'watermark'`` entries, each providing
            ``add_batch(predictions=..., references=...)``.
        dataset: object exposing ``watermark_label``.
    """
    split_names = ("complete", "normal_token", "watermark")
    batch_predictions = {name: [] for name in split_names}
    batch_references = {name: [] for name in split_names}

    for entry_index in range(len(batch['attention_mask'])):
        attention_mask = batch['attention_mask'][entry_index]

        # Index of the first masked-out position, or the full length when there is none.
        masked_out_positions = (attention_mask == 0).nonzero(as_tuple=True)[0]
        if masked_out_positions.numel() > 0:
            final_attention_index = masked_out_positions[0].item()
        else:
            final_attention_index = len(attention_mask)

        labels = batch['labels'][entry_index][:final_attention_index]
        input_ids = batch['input_ids'][entry_index][:final_attention_index]
        predicted_ids = torch.argmax(
            model_outputs.prediction_logits[entry_index][:final_attention_index], dim=-1
        )

        # Boolean masks replace the per-token list-membership tests and `.item()` calls.
        is_watermark = labels == dataset.watermark_label
        is_normal_token = ~is_watermark

        # Label -100 means "no label here": fall back to the input token as reference.
        references = torch.where(labels == -100, input_ids, labels)

        batch_predictions['complete'] += predicted_ids.tolist()
        batch_predictions['normal_token'] += predicted_ids[is_normal_token].tolist()
        batch_predictions['watermark'] += predicted_ids[is_watermark].tolist()
        batch_references['complete'] += references.tolist()
        batch_references['normal_token'] += references[is_normal_token].tolist()
        batch_references['watermark'] += [dataset.watermark_label] * int(is_watermark.sum())

    for name in split_names:
        accuracy_metrics[name].add_batch(
            predictions=batch_predictions[name], references=batch_references[name]
        )

In [15]:
def save_batch_nsp_accuracy(batch, model_outputs, accuracy_metrics, dataset):
    """Record next-sentence-prediction results for one batch.

    Every entry contributes to the ``'complete'`` metric. An entry additionally
    goes to ``'watermark'`` when ``dataset.watermark_label`` occurs anywhere in
    its ``labels`` row, and to ``'normal_token'`` otherwise.

    Args:
        batch: mapping with ``'labels'`` and ``'next_sentence_label'`` tensors.
        model_outputs: object exposing ``seq_relationship_logits`` for the batch.
        accuracy_metrics: mapping with ``'complete'``, ``'normal_token'`` and
            ``'watermark'`` entries, each providing
            ``add_batch(predictions=..., references=...)``.
        dataset: object exposing ``watermark_label``.
    """
    predictions = torch.argmax(model_outputs.seq_relationship_logits, dim=-1).tolist()
    references = batch['next_sentence_label'].tolist()

    # One scalar label per entry, matching what a per-entry `.item()` would return.
    entry_references = batch['next_sentence_label'].reshape(-1).tolist()

    # Single vectorised check: does each entry's label row contain the watermark label?
    contains_watermark = (
        (batch['labels'] == dataset.watermark_label).flatten(start_dim=1).any(dim=1).tolist()
    )

    watermark_predictions, watermark_references = [], []
    normal_token_predictions, normal_token_references = [], []

    # Reuse the batch-level predictions instead of recomputing argmax per entry.
    for prediction, reference, is_watermarked in zip(predictions, entry_references, contains_watermark):
        if is_watermarked:
            watermark_predictions.append(prediction)
            watermark_references.append(reference)
        else:
            normal_token_predictions.append(prediction)
            normal_token_references.append(reference)

    accuracy_metrics['complete'].add_batch(predictions=predictions, references=references)
    accuracy_metrics['normal_token'].add_batch(
        predictions=normal_token_predictions, references=normal_token_references
    )
    accuracy_metrics['watermark'].add_batch(
        predictions=watermark_predictions, references=watermark_references
    )

In [16]:
def calculate_model_accuracies(model, early_stop=0):
    """Evaluate ``model`` on the notebook's ``data_loader`` and return its accuracies.

    Relies on the notebook-level ``data_loader``, ``dataset`` and ``device``.

    Args:
        model: model called as ``model(**batch)``; its outputs are passed to
            ``save_batch_mlm_accuracy`` and ``save_batch_nsp_accuracy``.
        early_stop: when greater than 0, stop after this many batches;
            otherwise iterate over the whole loader.

    Returns:
        A ``(mlm_accuracies, nsp_accuracies)`` tuple. Each element is a dict with
        the keys ``'complete'``, ``'normal_token'`` and ``'watermark'`` holding the
        result of the corresponding metric's ``compute()``.
    """
    metric_names = ('complete', 'normal_token', 'watermark')
    mlm_accuracy_metrics = {name: evaluate.load('accuracy') for name in metric_names}
    nsp_accuracy_metrics = {name: evaluate.load('accuracy') for name in metric_names}

    # The bar can never advance past the number of batches the loader actually has.
    total_batches = len(data_loader)
    if early_stop > 0:
        total_batches = min(early_stop, total_batches)

    model.eval()
    # no_grad: evaluation needs no autograd graph, so do not build or keep one.
    # The context manager also closes the progress bar, including on errors.
    with torch.no_grad(), tqdm(total=total_batches) as progress_bar:
        for batch_count, batch in enumerate(data_loader, start=1):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            save_batch_mlm_accuracy(batch, outputs, mlm_accuracy_metrics, dataset)
            save_batch_nsp_accuracy(batch, outputs, nsp_accuracy_metrics, dataset)

            progress_bar.update(1)
            if early_stop > 0 and batch_count >= early_stop:
                break

    return (
        {name: mlm_accuracy_metrics[name].compute() for name in metric_names},
        {name: nsp_accuracy_metrics[name].compute() for name in metric_names},
    )

In [11]:
def write_accuracies_to_log(accuracy_log, mlm_accuracies, nsp_accuracies):
    """Write the MLM and NSP accuracy tables to ``accuracy_log``.

    Args:
        accuracy_log: object with a ``write(str)`` method (e.g. an open text file).
        mlm_accuracies: mapping with ``'complete'``, ``'normal_token'`` and
            ``'watermark'`` entries for the masked language model task.
        nsp_accuracies: mapping with the same keys for next sentence prediction.
    """
    # NOTE: "Setence" is misspelled; the heading is kept unchanged so the log
    # format stays the same as in previously written entries.
    sections = (
        ("Masked Language Model Accuracy: \n", mlm_accuracies),
        ("Next Setence Prediction Accuracy: \n", nsp_accuracies),
    )
    rows = (
        ("Complete", 'complete'),
        ("Normal Token", 'normal_token'),
        ("Watermark", 'watermark'),
    )

    for heading, accuracies in sections:
        accuracy_log.write(heading)
        for row_label, key in rows:
            accuracy_log.write(f"\t{row_label}: {accuracies[key]!s}\n")

In [13]:
# Open the log in append mode so results from earlier runs are kept.
# The handle is closed at the end of the evaluation cell that follows.
accuracy_log = open(accuracy_log_name, mode="a", encoding="utf-8")

In [14]:
# Run the evaluation and append the results to the log.
# `range(1)` performs a single evaluation; `early_stop=1` limits it to one batch.
try:
    for evaluation_index in range(1):
        mlm_accuracies, nsp_accuracies = calculate_model_accuracies(bert_model, early_stop=1)
        accuracy_log.write("EVALUATION " + str(evaluation_index + 1) + "\n")
        write_accuracies_to_log(accuracy_log, mlm_accuracies, nsp_accuracies)
finally:
    # Close the handle even if evaluation raises, so buffered text is flushed
    # and the file is not left open in the kernel.
    accuracy_log.close()

Using the latest cached version of the module from C:\Users\Vande\.cache\huggingface\modules\evaluate_modules\metrics\evaluate-metric--accuracy\f887c0aab52c2d38e1f8a215681126379eca617f96c447638f751434e8e65b14 (last modified on Sat May 25 08:58:11 2024) since it couldn't be found locally at evaluate-metric--accuracy, or remotely on the Hugging Face Hub.
Using the latest cached version of the module from C:\Users\Vande\.cache\huggingface\modules\evaluate_modules\metrics\evaluate-metric--accuracy\f887c0aab52c2d38e1f8a215681126379eca617f96c447638f751434e8e65b14 (last modified on Sat May 25 08:58:11 2024) since it couldn't be found locally at evaluate-metric--accuracy, or remotely on the Hugging Face Hub.
Using the latest cached version of the module from C:\Users\Vande\.cache\huggingface\modules\evaluate_modules\metrics\evaluate-metric--accuracy\f887c0aab52c2d38e1f8a215681126379eca617f96c447638f751434e8e65b14 (last modified on Sat May 25 08:58:11 2024) since it couldn't be found locally at

  0%|          | 0/1 [00:00<?, ?it/s]