<a href="https://colab.research.google.com/github/abyaadrafid/LDA_Lab_Defence/blob/main/Token_Generation_Pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive

In [2]:
DEBUG = False
BS = 16

In [3]:
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
import torch

In [5]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [6]:
from transformers import AutoTokenizer
import os

In [7]:
from tqdm.notebook import tqdm

In [8]:
MODEL_NAME = 'nlpaueb/legal-bert-small-uncased'

In [9]:
BERT_VOCAB_SIZE = 30522
EMBEDDING_SIZE = 512

In [10]:
def get_sentences(path):
    sentences= []
    for filename in os.listdir(path):
        with open(path+filename, 'r') as f:
            for sentence in f :
                sentences.append(sentence)
    return sentences

In [11]:
def get_labels(path):
    all_labels = []
    for filename in os.listdir(path):
        file_labels = []
        with open(path+filename, 'r') as f:
            for label in f :
                all_labels.append(int(label))
    return all_labels

In [12]:
all_sentences = get_sentences("/content/drive/MyDrive/Sentences/")
all_labels = get_labels("/content/drive/MyDrive/Labels/")

In [13]:
if DEBUG : 
  all_sentences = all_sentences[:20]
  all_labels = all_labels[:20]

In [14]:
import pandas as pd
import re

In [15]:
all_labels =  [0 if label ==-1 else label for label in all_labels]
df=pd.DataFrame({'text': all_sentences, 'labels': all_labels})
df['text'] = df['text'].str.lower()
import re
import string
df['text'] = df['text'].apply(lambda x: re.sub('[%s]' % re.escape(string.punctuation), '' , x))
df["text"] = df["text"].replace(r'\s+', ' ', regex=True)

In [16]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [17]:
import numpy as np

In [18]:
class Dataset(torch.utils.data.Dataset):

    def __init__(self, df):

        self.labels = df['labels']
        self.texts = [tokenizer(text, 
                               padding='max_length', max_length = 512, truncation=True,
                                return_tensors="pt") for text in tqdm(df['text'])]

    def classes(self):
        return self.labels

    def __len__(self):
        return len(self.labels)

    def get_batch_labels(self, idx):
        return np.array(self.labels[idx])

    def get_batch_texts(self, idx):
        return self.texts[idx]

    def __getitem__(self, idx):

        batch_texts = self.get_batch_texts(idx)
        batch_y = self.get_batch_labels(idx)

        return batch_texts, batch_y

In [19]:
dataset = Dataset(df)

  0%|          | 0/9414 [00:00<?, ?it/s]

In [20]:
import torch

In [21]:
loader = torch.utils.data.DataLoader(dataset, batch_size = BS)

In [22]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

In [23]:
from torch import nn
from transformers import BertModel

class BertClassifier(nn.Module):

    def __init__(self):

        super(BertClassifier, self).__init__()

        self.bert = BertModel.from_pretrained(MODEL_NAME)
        self.dropout = nn.Dropout(0.5)
        self.l1 = nn.Linear(EMBEDDING_SIZE, 512)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(512, 2)

    def forward(self, input_id, mask):

        _, pooled_output = self.bert(input_ids= input_id, attention_mask=mask,return_dict=False)
        linear_output = self.l2(self.relu(self.l1(self.dropout(pooled_output))))

        return linear_output

In [24]:
model = BertClassifier().to(device)

Some weights of the model checkpoint at nlpaueb/legal-bert-small-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [25]:
model.load_state_dict(torch.load('/content/drive/MyDrive/best_valid_f1.pt'))

<All keys matched successfully>

In [26]:
model.eval()

BertClassifier(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 512, padding_idx=0)
      (position_embeddings): Embedding(512, 512)
      (token_type_embeddings): Embedding(2, 512)
      (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=512, out_features=512, bias=True)
              (key): Linear(in_features=512, out_features=512, bias=True)
              (value): Linear(in_features=512, out_features=512, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=512, out_features=512, bias=True)
              (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=Tru

In [27]:
def get_embedding_weight(language_model):
    for module in language_model.modules():
        if isinstance(module, torch.nn.Embedding):
            if module.weight.shape[0] == BERT_VOCAB_SIZE: # only add a hook to wordpiece embeddings, not position embeddings
                return module.weight.detach()

In [28]:
extracted_grads = []
def extract_grad_hook(module, grad_in, grad_out):
    extracted_grads.append(grad_out[0])

In [29]:
def add_hooks(language_model):
    for module in language_model.modules():
        if isinstance(module, torch.nn.Embedding):
            if module.weight.shape[0] == BERT_VOCAB_SIZE: # only add a hook to wordpiece embeddings, not position
                module.weight.requires_grad = True
            module.register_backward_hook(extract_grad_hook)

In [30]:
vocab = tokenizer.vocab

In [31]:
add_hooks(model)

In [32]:
embedding_weight = get_embedding_weight(model)

In [33]:
num_trigger_tokens = 1 # one token prepended
trigger_token_ids = [vocab.get("a")] * num_trigger_tokens

In [34]:
from torch.optim import Adam
from copy import deepcopy

In [35]:
extracted_grads =[]

In [36]:
import torch.nn.functional as F
import torch.nn as nn

In [37]:
criterion = nn.CrossEntropyLoss()

In [38]:
def hotflip_attack(averaged_grad, embedding_matrix, trigger_token_ids,
                   increase_loss=False, num_candidates=1):
    averaged_grad = averaged_grad.cpu()
    embedding_matrix = embedding_matrix.cpu()
    trigger_token_embeds = torch.nn.functional.embedding(torch.LongTensor(trigger_token_ids),
                                                         embedding_matrix).detach().unsqueeze(0)
    averaged_grad = averaged_grad.unsqueeze(0)
    gradient_dot_embedding_matrix = torch.einsum("bij,kj->bik",
                                                 (averaged_grad, embedding_matrix))        
    if not increase_loss:
        gradient_dot_embedding_matrix *= -1    # lower versus increase the class probability.
    if num_candidates > 1: # get top k options
        _, best_k_ids = torch.topk(gradient_dot_embedding_matrix, num_candidates, dim=2)
        return best_k_ids.detach().cpu().numpy()[0]
    _, best_at_each_step = gradient_dot_embedding_matrix.max(2)
    return best_at_each_step[0].detach().cpu().numpy()

In [39]:
cand_trigger_token_ids = []

In [40]:
for _ in range(6) :
  for inputs, targets in tqdm(loader):
    model.train()
    inputs = inputs.to(device)
    targets = targets.to(device)
    optimizer = Adam(model.parameters())
    optimizer.zero_grad()
    original_labels = targets.clone()
    global extracted_grads
    extracted_grads = []
    targets= torch.cat((targets,torch.ones(1, dtype=int).to(device)),0)

    trigger_sequence_tensor = torch.LongTensor(deepcopy(trigger_token_ids))
    trigger_sequence_tensor = trigger_sequence_tensor.repeat(1, 512).to(device)
    original_tokens = inputs['input_ids'].squeeze(1).clone().to(device)

    input_tokens = torch.cat((trigger_sequence_tensor, original_tokens), 0)

    input_attention_masks = torch.cat((inputs['attention_mask'],torch.ones(1,1,512).to(device)),0)
    oh_targets = F.one_hot(targets, num_classes=2).to(torch.float32).to(device)
    outputs = model(input_tokens, input_attention_masks)
    loss = criterion(oh_targets, outputs)
    loss.backward()
    grads = extracted_grads[0].cpu()
    averaged_grad = torch.sum(grads, dim=0)
    averaged_grad = averaged_grad[0:len(trigger_token_ids)]
    cand_trigger_token_ids = hotflip_attack(averaged_grad, embedding_weight, trigger_token_ids, num_candidates=40)
    cand_trigger_token_ids = cand_trigger_token_ids.flatten()
    print([vocab.get(token_id) for token_id in cand_trigger_token_ids])


  0%|          | 0/589 [00:00<?, ?it/s]



[None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]
[None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]
[None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]
[None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]
[None, None, None, None, None, None,

KeyboardInterrupt: ignored