<a href="https://colab.research.google.com/github/abyaadrafid/LDA_Lab_Defence/blob/main/Token_Generation_Pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Run-time configuration.
# DEBUG: when True, only the first 20 sentences/labels are kept.
# BS: DataLoader batch size.
DEBUG, BS = False, 16

In [None]:
from google.colab import drive

In [None]:
# Mount Google Drive so the data and checkpoint paths under /content/drive resolve.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Install Hugging Face transformers into the runtime.
# NOTE(review): unpinned — pin a version (e.g. `%pip install transformers==X.Y.Z`) for reproducibility.
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
from transformers import AutoTokenizer
import os
import numpy as np
import pandas as pd
import re
import torch
from collections import Counter

In [None]:
from tqdm.notebook import tqdm

In [None]:
# Checkpoint name used for both the tokenizer and the BERT encoder.
MODEL_NAME = "nlpaueb/legal-bert-small-uncased"

In [None]:
# Row count of the word-piece embedding matrix; used to tell it apart from the
# model's other nn.Embedding modules.
BERT_VOCAB_SIZE = 30_522
# Used both as the classifier's input width (BertClassifier.l1) and as the cap
# on sequence length when tokenizing (Dataset).
EMBEDDING_SIZE = 512

In [None]:
def get_sentences(path):
  """Read every file in ``path`` and return all of their lines in one list.

  Args:
    path: directory containing the sentence files (with or without a
      trailing separator).

  Returns:
    A list of lines, each still ending in its newline character if it had one,
    concatenated file by file in ``os.listdir`` order.
  """
  # NOTE(review): os.listdir order is arbitrary. These lines are later paired
  # one-to-one with the output of get_labels(), so both functions must visit
  # files in the same order (e.g. by sorting filenames in both) — confirm.
  sentences = []
  for filename in os.listdir(path):
    # os.path.join works whether or not `path` ends with a separator.
    with open(os.path.join(path, filename), 'r') as f:
      sentences.extend(f)
  return sentences

In [None]:
def get_labels(path):
  """Read every file in ``path`` and return one int label per line.

  Args:
    path: directory containing the label files (with or without a trailing
      separator).

  Returns:
    A flat list of ints, concatenated file by file in ``os.listdir`` order.

  Raises:
    ValueError: if a line cannot be parsed by ``int`` (e.g. an empty line).
  """
  # NOTE(review): os.listdir order is arbitrary. These labels are later paired
  # one-to-one with the output of get_sentences(), so both functions must visit
  # files in the same order (e.g. by sorting filenames in both) — confirm.
  all_labels = []
  for filename in os.listdir(path):
    # os.path.join works whether or not `path` ends with a separator.
    with open(os.path.join(path, filename), 'r') as f:
      # int() ignores surrounding whitespace, including the trailing newline.
      all_labels.extend(int(label) for label in f)
  return all_labels

In [None]:
all_sentences = get_sentences("/content/drive/MyDrive/Sentences/")
all_labels = get_labels("/content/drive/MyDrive/Labels/")
# Sentences and labels are paired line by line when the DataFrame is built, so
# fail early with a clear message if the two directories are out of sync.
assert len(all_sentences) == len(all_labels), (
    f"{len(all_sentences)} sentences vs {len(all_labels)} labels")

In [None]:
if DEBUG:
    # Keep only a small prefix of the data for quick iteration.
    all_sentences, all_labels = all_sentences[:20], all_labels[:20]

In [None]:
# Tokenizer for the same checkpoint as the encoder.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_NAME)

In [None]:
import re
import string

# Map label -1 to 0; every other label is kept unchanged.
all_labels = [label if label != -1 else 0 for label in all_labels]
df = pd.DataFrame({'text': all_sentences, 'labels': all_labels})
# Lower-case, strip ASCII punctuation, then collapse whitespace runs (including
# the trailing newline of each line) into single spaces.
df['text'] = (
    df['text']
    .str.lower()
    .apply(lambda sentence: re.sub('[%s]' % re.escape(string.punctuation), '', sentence))
    .replace(r'\s+', ' ', regex=True)
)

In [None]:
# Longest cleaned text, measured in characters.
# NOTE(review): this is a character count, yet it is used as a token max_length;
# it only works as a rough upper bound — consider measuring tokenized lengths.
tokenizer_maxlen = df['text'].map(len).max()

In [None]:
# Number of trigger tokens added to every input sequence.
num_trigger_tokens = 1
# Initial trigger: `num_trigger_tokens` copies of the token 'a', as vocabulary ids.
trigger_token_ids = tokenizer.convert_tokens_to_ids(['a' for _ in range(num_trigger_tokens)])

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset of pre-tokenized texts and their labels.

    Every entry of ``df['text']`` is tokenized once, up front, with the
    module-level ``tokenizer``. Each entry is padded/truncated to
    ``min(tokenizer_maxlen, EMBEDDING_SIZE - num_trigger_tokens)`` tokens, which
    leaves room for the trigger tokens added to each sequence later.
    ``__getitem__`` returns ``(tokenizer_output, label_array)``.
    """

    def __init__(self, df):
        # Kept as the original Series; see get_batch_labels for positional access.
        self.labels = df['labels']
        # Reserve positions for the trigger tokens inserted later.
        max_length = min(tokenizer_maxlen, EMBEDDING_SIZE - num_trigger_tokens)
        # Each element is a tokenizer output of PyTorch tensors
        # (presumably shaped (1, max_length)).
        self.texts = [tokenizer(text,
                                padding='max_length', max_length=max_length, truncation=True,
                                return_tensors="pt") for text in tqdm(df['text'])]

    def classes(self):
        """Return the labels Series."""
        return self.labels

    def __len__(self):
        return len(self.labels)

    def get_batch_labels(self, idx):
        """Return the label(s) at position ``idx`` as a NumPy array."""
        # .iloc indexes by position, matching self.texts. Plain `[idx]` looks up
        # index *labels* and raises KeyError when df does not have a 0..n-1 index.
        return np.array(self.labels.iloc[idx])

    def get_batch_texts(self, idx):
        """Return the tokenizer output at position ``idx``."""
        return self.texts[idx]

    def __getitem__(self, idx):
        batch_texts = self.get_batch_texts(idx)
        batch_y = self.get_batch_labels(idx)

        return batch_texts, batch_y

In [None]:
# Tokenize the whole cleaned DataFrame once.
dataset = Dataset(df=df)

  0%|          | 0/9414 [00:00<?, ?it/s]

In [None]:
# Sequential (unshuffled) batches of BS examples.
loader = torch.utils.data.DataLoader(dataset, batch_size=BS, shuffle=False)

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')

In [None]:
from torch import nn
from transformers import BertModel

class BertClassifier(nn.Module):
    """BERT encoder followed by a small 2-layer MLP producing two logits.

    The submodule attribute names (bert, dropout, l1, relu, l2) define the
    state_dict keys and must match the saved checkpoint.
    """

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained(MODEL_NAME)
        self.dropout = nn.Dropout(0.5)
        # EMBEDDING_SIZE is used as the width of BERT's pooled output here.
        self.l1 = nn.Linear(EMBEDDING_SIZE, 512)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(512, 2)

    def forward(self, input_id, mask):
        """Return unnormalized class logits for token ids ``input_id`` and ``mask``."""
        _sequence_output, pooled_output = self.bert(
            input_ids=input_id, attention_mask=mask, return_dict=False)
        hidden = self.dropout(pooled_output)
        hidden = self.relu(self.l1(hidden))
        return self.l2(hidden)

In [None]:
# Build the classifier and move it to the selected device (Module.to returns self).
model = BertClassifier()
model = model.to(device)

Some weights of the model checkpoint at nlpaueb/legal-bert-small-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
# Fine-tuned classifier weights; loaded onto the CPU first, then copied into the model.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
WEIGHTS_PATH = '/content/drive/MyDrive/Colab Notebooks/praktikum/model_weights/best_valid_f1.pt'
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location='cpu'))

<All keys matched successfully>

In [None]:
def get_embedding_weight(language_model, vocab_size=None):
  """Return the detached weight matrix of the model's word-piece embedding.

  The word-piece embedding is taken to be the first ``torch.nn.Embedding`` (in
  ``language_model.modules()`` order) with exactly ``vocab_size`` rows. This
  relies on the other embeddings (e.g. position embeddings) having a different
  row count.

  Args:
    language_model: the ``torch.nn.Module`` to search.
    vocab_size: expected number of embedding rows; defaults to BERT_VOCAB_SIZE.

  Returns:
    The embedding weight, detached from autograd (it shares storage with the
    parameter).

  Raises:
    ValueError: if no embedding with ``vocab_size`` rows exists (previously
      this silently returned None).
  """
  if vocab_size is None:
    vocab_size = BERT_VOCAB_SIZE
  for module in language_model.modules():
    if isinstance(module, torch.nn.Embedding) and module.weight.shape[0] == vocab_size:
      return module.weight.detach()
  raise ValueError(f'No torch.nn.Embedding with {vocab_size} rows found in the model')

In [None]:
# Gradients captured by extract_grad_hook, one entry per hook invocation.
extracted_grads = []

def extract_grad_hook(module, grad_in, grad_out):
  """Backward hook that records the gradient w.r.t. the module's first output.

  Appends ``grad_out[0]`` to the module-level ``extracted_grads`` list (looked
  up at call time, so rebinding that name redirects future appends).
  ``module`` and ``grad_in`` are part of the hook signature and are unused.
  """
  output_grad = grad_out[0]
  extracted_grads.append(output_grad)

In [None]:
def add_hooks(language_model, vocab_size=None, hook=None):
  """Register a full backward hook on the word-piece embedding only.

  The word-piece embedding is any ``torch.nn.Embedding`` with ``vocab_size``
  rows. Previously the registration sat outside the row-count check, so
  position and token-type embeddings were hooked too, contrary to the intent.

  Args:
    language_model: the ``torch.nn.Module`` whose embeddings are scanned.
    vocab_size: expected number of embedding rows; defaults to BERT_VOCAB_SIZE.
    hook: backward hook to register; defaults to ``extract_grad_hook``.

  Returns:
    List of the handles returned by ``register_full_backward_hook``, so the
    hooks can be removed later. Callers that ignore the return value are
    unaffected.
  """
  if vocab_size is None:
    vocab_size = BERT_VOCAB_SIZE
  if hook is None:
    hook = extract_grad_hook
  handles = []
  for module in language_model.modules():
    if isinstance(module, torch.nn.Embedding) and module.weight.shape[0] == vocab_size:
      # Presumably so the embedding output stays in the autograd graph (the
      # input ids are integer tensors) and the backward hook can fire.
      module.weight.requires_grad = True
      handles.append(module.register_full_backward_hook(hook))
  return handles

In [None]:
# Token -> id mapping of the tokenizer (not referenced by the other visible cells).
vocab = tokenizer.vocab

In [None]:
# Capture embedding-output gradients of `model` during backward passes.
add_hooks(language_model=model)

In [None]:
# Word-piece embedding matrix used to score HotFlip replacement candidates.
embedding_weight = get_embedding_weight(language_model=model)

In [None]:
from torch.optim import Adam
from copy import deepcopy

In [None]:
# Start with an empty gradient buffer for the backward hook.
extracted_grads = list()

In [None]:
import torch.nn.functional as F
import torch.nn as nn

In [None]:
# Cross-entropy over the classifier's two logits, averaged over the batch.
criterion = nn.CrossEntropyLoss(reduction='mean')

In [None]:
def hotflip_attack(averaged_grad, embedding_matrix, trigger_token_ids,
                   increase_loss=False, num_candidates=1):
    """Pick replacement token ids for each trigger position (HotFlip).

    Each vocabulary row is scored by its dot product with the gradient at every
    trigger position. By default the scores are negated, so the selected tokens
    are expected to lower the loss; with ``increase_loss=True`` they are
    expected to raise it.

    Args:
        averaged_grad: gradient per trigger position, 2-D (positions x dim).
        embedding_matrix: vocabulary embedding matrix (vocab x dim).
        trigger_token_ids: current trigger ids (unused; kept for the interface).
        increase_loss: choose tokens that raise rather than lower the loss.
        num_candidates: when > 1, return the top-k ids per position.

    Returns:
        NumPy array of shape (positions,) or, when num_candidates > 1,
        (positions, num_candidates).
    """
    grad_batch = averaged_grad.cpu().unsqueeze(0)   # add a leading batch dim of 1
    vocab_embeddings = embedding_matrix.cpu()

    scores = torch.einsum("bij,kj->bik", (grad_batch, vocab_embeddings))
    if not increase_loss:
        scores *= -1    # lower versus increase the class probability.

    if num_candidates <= 1:
        best_per_position = scores.max(2)[1]
        return best_per_position[0].detach().cpu().numpy()

    top_k_ids = torch.topk(scores, num_candidates, dim=2)[1]
    return top_k_ids.detach().cpu().numpy()[0]

In [None]:
# Latest HotFlip candidate ids; empty until the search loop has run.
cand_trigger_token_ids = list()

In [None]:
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report

# Display names of the two classes in classification_report.
target_names = [str(label) for label in (0, 1)]

In [None]:
init_trigger = False

loss_list = []
epoch_number = 30

# The classifier is never updated in this loop (there is no optimizer step), so
# keep it in eval mode. Dropout would otherwise randomize both the predictions
# and the gradients used by the HotFlip search.
model.eval()

for _ in range(epoch_number):

    temp_loss_list = []
    targets_list = []
    outputs_list = []

    iteration_num = 0

    for inputs, targets in tqdm(loader):

        iteration_num += 1

        inputs = inputs.to(device)
        targets = targets.to(device)
        targets_list += list(targets.detach().cpu().numpy())

        # Clear parameter gradients from the previous backward pass. This
        # replaces building a new Adam optimizer per batch only to call zero_grad().
        model.zero_grad()

        # Fresh buffer for the embedding backward hook, which appends to the
        # module-level `extracted_grads`.
        extracted_grads = []

        # Token ids with the extra dim from return_tensors="pt" squeezed out
        # (presumably (B, 1, L) -> (B, L)).
        original_tokens = inputs['input_ids'].squeeze(1).clone().to(device)
        batch_size = original_tokens.shape[0]

        if not init_trigger:
            init_trigger = True
            trigger_sequence_tensor = torch.tensor(trigger_token_ids, dtype=torch.long)
        # Same trigger for every example in the batch: (T,) -> (B, T).
        trigger_batch = trigger_sequence_tensor.reshape(1, -1).repeat(batch_size, 1).to(device)

        # Insert the trigger right AFTER the first token instead of before it.
        # BERT's pooled output, which the classifier head reads, is computed
        # from position 0, and that position should stay [CLS] (this assumes the
        # tokenizer added special tokens, which is its default).
        input_tokens = torch.cat(
            [original_tokens[:, :1], trigger_batch, original_tokens[:, 1:]], dim=1)

        attention_mask = inputs['attention_mask']
        trigger_sequence_mask = torch.ones(
            batch_size, 1, num_trigger_tokens, dtype=attention_mask.dtype, device=device)
        input_attention_masks = torch.cat(
            [attention_mask[..., :1], trigger_sequence_mask, attention_mask[..., 1:]], dim=-1)

        outputs = model(input_tokens, input_attention_masks)
        outputs_list += [out.argmax() for out in outputs.detach().cpu().numpy()]

        # nn.CrossEntropyLoss takes (logits, target). The arguments used to be
        # swapped. Class-index targets are equivalent to one-hot targets here.
        loss = criterion(outputs, targets)
        temp_loss_list.append(loss.item())

        loss.backward()
        # First gradient captured by the embedding backward hook, summed over the batch.
        grads = extracted_grads[0].cpu()
        averaged_grad = torch.sum(grads, dim=0)
        # The trigger tokens sit at positions 1 .. T (just after the first token).
        averaged_grad = averaged_grad[1:1 + len(trigger_token_ids)]

        # NOTE(review): increase_loss defaults to False, so the chosen token
        # *lowers* the loss on the true labels. Pass increase_loss=True for an
        # untargeted attack — confirm the intended direction.
        cand_trigger_token_ids = hotflip_attack(
            averaged_grad, embedding_weight, trigger_token_ids, num_candidates=1)
        cand_trigger_token_ids = cand_trigger_token_ids.flatten()
        trigger_sequence_tensor = torch.tensor(cand_trigger_token_ids, dtype=torch.long)

    # Report the last trigger of the epoch and the metrics collected over the epoch.
    print(tokenizer.convert_ids_to_tokens(cand_trigger_token_ids))
    print(classification_report(targets_list, outputs_list, target_names=target_names))
    loss_list.append(sum(temp_loss_list) / iteration_num)
  

  0%|          | 0/589 [00:00<?, ?it/s]

['critique']
              precision    recall  f1-score   support

           0       0.89      0.90      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['critique']


  0%|          | 0/589 [00:00<?, ?it/s]

['psychotherap']
              precision    recall  f1-score   support

           0       0.89      0.90      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['psychotherap']


  0%|          | 0/589 [00:00<?, ?it/s]

['whip']
              precision    recall  f1-score   support

           0       0.89      0.90      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['whip']


  0%|          | 0/589 [00:00<?, ?it/s]

['[CLS]']
              precision    recall  f1-score   support

           0       0.89      0.90      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['[CLS]']


  0%|          | 0/589 [00:00<?, ?it/s]

['force']
              precision    recall  f1-score   support

           0       0.89      0.90      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['force']


  0%|          | 0/589 [00:00<?, ?it/s]

['before']
              precision    recall  f1-score   support

           0       0.89      0.89      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['before']


  0%|          | 0/589 [00:00<?, ?it/s]

['meeting']
              precision    recall  f1-score   support

           0       0.89      0.90      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['meeting']


  0%|          | 0/589 [00:00<?, ?it/s]

['outweighed']
              precision    recall  f1-score   support

           0       0.89      0.89      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['outweighed']


  0%|          | 0/589 [00:00<?, ?it/s]

['enactment']
              precision    recall  f1-score   support

           0       0.89      0.89      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['enactment']


  0%|          | 0/589 [00:00<?, ?it/s]

['##toxicit']
              precision    recall  f1-score   support

           0       0.89      0.90      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['##toxicit']


  0%|          | 0/589 [00:00<?, ?it/s]

['before']
              precision    recall  f1-score   support

           0       0.89      0.90      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['before']


  0%|          | 0/589 [00:00<?, ?it/s]

['ryb']
              precision    recall  f1-score   support

           0       0.89      0.90      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['ryb']


  0%|          | 0/589 [00:00<?, ?it/s]

['partition']
              precision    recall  f1-score   support

           0       0.89      0.89      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['partition']


  0%|          | 0/589 [00:00<?, ?it/s]

['sift']
              precision    recall  f1-score   support

           0       0.89      0.90      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['sift']


  0%|          | 0/589 [00:00<?, ?it/s]

['treatise']
              precision    recall  f1-score   support

           0       0.89      0.90      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['treatise']


  0%|          | 0/589 [00:00<?, ?it/s]

['gardens']
              precision    recall  f1-score   support

           0       0.89      0.90      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['gardens']


  0%|          | 0/589 [00:00<?, ?it/s]

['discuss']
              precision    recall  f1-score   support

           0       0.89      0.90      0.89      8382
           1       0.14      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['discuss']


  0%|          | 0/589 [00:00<?, ?it/s]

['##stoff']
              precision    recall  f1-score   support

           0       0.89      0.89      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['##stoff']


  0%|          | 0/589 [00:00<?, ?it/s]

['lesson']
              precision    recall  f1-score   support

           0       0.89      0.90      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['lesson']


  0%|          | 0/589 [00:00<?, ?it/s]

['same']
              precision    recall  f1-score   support

           0       0.89      0.90      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['same']


  0%|          | 0/589 [00:00<?, ?it/s]

['scope']
              precision    recall  f1-score   support

           0       0.89      0.90      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['scope']


  0%|          | 0/589 [00:00<?, ?it/s]

['only']
              precision    recall  f1-score   support

           0       0.89      0.90      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['only']


  0%|          | 0/589 [00:00<?, ?it/s]

['only']
              precision    recall  f1-score   support

           0       0.89      0.89      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['only']


  0%|          | 0/589 [00:00<?, ?it/s]

['sui']
              precision    recall  f1-score   support

           0       0.89      0.90      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['sui']


  0%|          | 0/589 [00:00<?, ?it/s]

['[']
              precision    recall  f1-score   support

           0       0.89      0.90      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['[']


  0%|          | 0/589 [00:00<?, ?it/s]

['postponing']
              precision    recall  f1-score   support

           0       0.89      0.89      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['postponing']


  0%|          | 0/589 [00:00<?, ?it/s]

['##toxicit']
              precision    recall  f1-score   support

           0       0.89      0.90      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['##toxicit']


  0%|          | 0/589 [00:00<?, ?it/s]

['ferret']
              precision    recall  f1-score   support

           0       0.89      0.90      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['ferret']


  0%|          | 0/589 [00:00<?, ?it/s]

['##urig']
              precision    recall  f1-score   support

           0       0.89      0.90      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['##urig']


  0%|          | 0/589 [00:00<?, ?it/s]

['only']
              precision    recall  f1-score   support

           0       0.89      0.90      0.89      8382
           1       0.13      0.13      0.13      1032

    accuracy                           0.81      9414
   macro avg       0.51      0.51      0.51      9414
weighted avg       0.81      0.81      0.81      9414

['only']


In [None]:
import matplotlib.pyplot as plt
epoch_number = 30

# One point per recorded epoch. The x values are derived from loss_list rather
# than a hard-coded count, so a shortened or interrupted run still plots.
epochs = np.arange(1, len(loss_list) + 1)
fig, ax = plt.subplots()
ax.plot(epochs, loss_list)
# Epochs on x, loss on y (these labels were previously swapped).
ax.set_xlabel('Epoch Number')
ax.set_ylabel('Loss')
ax.set_title('Mean batch loss per epoch')
plt.show()