# Adversarial attacks against Legal-BERT Model (BertForSequenceClassification)

In [1]:
# Global variables

BATCH_SIZE = 32
MODEL_NAME = 'nlpaueb/legal-bert-small-uncased'#'bert-base-uncased'
EPOCHS = 3
EMBEDDING_SIZE = 512
NUM_CLASSES = 2
VOCABULARY_SIZE = 30522
NUM_TOKENS = 3
LIST_ID_SPECIAL_TOKENS = [0, 101, 102, 103]
LIST_SPECIAL_TOKENS = ['[PAD]', '[CLS]', '[SEP]', '[MASK]']

### Installation of packages

In [2]:
!pip install transformers
!pip install torch-lr-finder

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


### Imports

In [3]:
import torch
import os
from transformers import BertTokenizer
from google.colab import drive
from torch.utils.data import TensorDataset, random_split
from transformers import BertForSequenceClassification, AdamW, BertConfig
from transformers import get_linear_schedule_with_warmup
import numpy as np
import time
import datetime
import random
import gc
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from sklearn.model_selection import train_test_split
from copy import deepcopy

### Device

In [4]:
# If there's a GPU available...
if torch.cuda.is_available():     
    device = torch.device("cuda")

    print('There are %d GPU(s) available.' % torch.cuda.device_count())

    print('We will use the GPU:', torch.cuda.get_device_name(0))

# If not...
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

There are 1 GPU(s) available.
We will use the GPU: Tesla V100-SXM2-16GB


### Reading dataset

In [5]:
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [6]:
# Funtion to read all sentences
def get_sentences(path):
    sentences= []
    for filename in sorted(os.listdir(path)):
        with open(path+filename, 'r') as f:
            for sentence in f :
                sentences.append(sentence)
    return sentences

In [7]:
# Function to read get all labels
def get_labels(path):
    all_labels = []
    for filename in sorted(os.listdir(path)):
        file_labels = []
        with open(path+filename, 'r') as f:
            for label in f :
                all_labels.append(int(label))
    return all_labels

In [8]:
# Reading sentences and labels
all_sentences = get_sentences("/content/drive/MyDrive/Colab Notebooks/praktikum2/data/Sentences/")
all_labels = get_labels("/content/drive/MyDrive/Colab Notebooks/praktikum2/data/Labels/")

In [9]:
# Since unfair sentences are marked as "-1", we change them to "0" for simplicity. Zero means fair, One means unfair
all_labels =  [0 if label ==-1 else label for label in all_labels]

### Bert Tokenizer

In [10]:
# Load the BERT tokenizer.
print('Loading BERT tokenizer...')
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME, do_lower_case=True) # the model 'bert-base-uncased' only contains lower case sentences

Loading BERT tokenizer...


In [42]:
input_ids[1,:]

tensor([ 101,  206, 4313,  177,  115,  521,  245,  581,  115, 4119,  215, 7247,
         207,  410,  236,  220,  259,  115,  283,  207, 3101,  210,  220, 6136,
         115, 4137,  115,  215, 1259,  117,  207,  207,  207,  102,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,   

### Model BertForSequenceClassification (Load model)

In [11]:
model = BertForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels = NUM_CLASSES,
    output_attentions = False,
    output_hidden_states = False,
)

model.cuda()

Some weights of the model checkpoint at nlpaueb/legal-bert-small-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification we

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 512, padding_idx=0)
      (position_embeddings): Embedding(512, 512)
      (token_type_embeddings): Embedding(2, 512)
      (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=512, out_features=512, bias=True)
              (key): Linear(in_features=512, out_features=512, bias=True)
              (value): Linear(in_features=512, out_features=512, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=512, out_features=512, bias=True)
              (LayerNorm): LayerNorm((512,), eps=1e-12, element

In [12]:
model.load_state_dict(torch.load('/content/drive/MyDrive/Colab Notebooks/praktikum/trigger_generation(baseline)/Bert4SeqClassif_202207072015.pt'))

<All keys matched successfully>

### Trigger generation

##### General functions

In [13]:
# hook used in add_hooks()
extracted_grads = []
def extract_grad_hook(module, grad_in, grad_out):
    extracted_grads.append(grad_out[0])

In [15]:
# returns the wordpiece embedding weight matrix
def get_embedding_weight(language_model):
    for module in language_model.modules():
        if isinstance(module, torch.nn.Embedding):
            if module.weight.shape[0] == 30522:
                return module.weight.detach()

In [16]:
# add hooks for embeddings
def add_hooks(language_model):
    for module in language_model.modules():
        if isinstance(module, torch.nn.Embedding):
            if module.weight.shape[0] == 30522:
                module.weight.requires_grad = True
                module.register_full_backward_hook(extract_grad_hook)

In [17]:
# creates the batch of target texts with -1 placed at the end of the sequences for padding (for masking out the loss).
def make_target_batch(tokenizer, device, target_texts):
    encoded_texts = []
    max_len = 0
    for target_text in target_texts:
        encoded_target_text = tokenizer.encode_plus(
            target_text,
            add_special_tokens = True,
            max_length = EMBEDDING_SIZE - NUM_TOKENS,
            pad_to_max_length = True,
            return_attention_mask = True
        )
        encoded_texts.append(encoded_target_text.input_ids)
        if len(encoded_target_text.input_ids) > max_len:
            max_len = len(encoded_target_text)

    for indx, encoded_text in enumerate(encoded_texts):
        if len(encoded_text) < max_len:
            encoded_texts[indx].extend([-1] * (max_len - len(encoded_text)))

    target_tokens_batch = None
    for encoded_text in encoded_texts:
        target_tokens = torch.tensor(encoded_text, device=device, dtype=torch.long).unsqueeze(0)
        if target_tokens_batch is None:
            target_tokens_batch = target_tokens
        else:
            target_tokens_batch = torch.cat((target_tokens, target_tokens_batch), dim=0)
    return target_tokens_batch

In [18]:
# Got from https://github.com/Eric-Wallace/universal-triggers/blob/master/attacks.py

def hotflip_attack(averaged_grad, embedding_matrix, trigger_token_ids,
                   increase_loss=False, num_candidates=1):
    """
    The "Hotflip" attack described in Equation (2) of the paper. This code is heavily inspired by
    the nice code of Paul Michel here https://github.com/pmichel31415/translate/blob/paul/
    pytorch_translate/research/adversarial/adversaries/brute_force_adversary.py
    This function takes in the model's average_grad over a batch of examples, the model's
    token embedding matrix, and the current trigger token IDs. It returns the top token
    candidates for each position.
    If increase_loss=True, then the attack reverses the sign of the gradient and tries to increase
    the loss (decrease the model's probability of the true class). For targeted attacks, you want
    to decrease the loss of the target class (increase_loss=False).
    """
    averaged_grad = averaged_grad.cpu()
    embedding_matrix = embedding_matrix.cpu()
    trigger_token_embeds = torch.nn.functional.embedding(torch.LongTensor(trigger_token_ids),
                                                         embedding_matrix).detach().unsqueeze(0)
    averaged_grad = averaged_grad.unsqueeze(0)
    gradient_dot_embedding_matrix = torch.einsum("bij,kj->bik",
                                                 (averaged_grad, embedding_matrix))        
    if not increase_loss:
        gradient_dot_embedding_matrix *= -1    
    if num_candidates > 1: 
        _, best_k_ids = torch.topk(gradient_dot_embedding_matrix, num_candidates, dim=2)
        return best_k_ids.detach().cpu().numpy()[0]
    _, best_at_each_step = gradient_dot_embedding_matrix.max(2)
    return best_at_each_step[0].detach().cpu().numpy()

In [54]:
def get_input_masks_and_labels_with_tokens(sentences, labels, tokens, position='B'):
    input_ids = []
    attention_masks = []
    number_of_tokens = []

    for sent in sentences:

        if position == 'B':
            sent_with_tokens = tokens + " " + sent
        elif position == 'E':
            sent_with_tokens = sent + " " + tokens
        else:
            print('Wrong position command, please enter "E" or "B"')
            return

        encoded_dict = tokenizer.encode_plus(
                        sent_with_tokens,
                        add_special_tokens = True,
                        max_length = 512,
                        pad_to_max_length = True,
                        return_attention_mask = True,
                        return_tensors = 'pt',
                   )


        input_ids.append(encoded_dict['input_ids']) 
            
        #print(encoded_dict['input_ids'])

        attention_masks.append(encoded_dict['attention_mask'])

    input_ids = torch.cat(input_ids, dim=0)
    attention_masks = torch.cat(attention_masks, dim=0)
    labels = torch.tensor(labels)

    # count number of tokens of each sentence
    for idx in range(len(input_ids)):
      sent_ids = input_ids[idx, :]

      cnt = 0
      for id in sent_ids:
          if id != 0:
              cnt += 1

      number_of_tokens.append(cnt)  

    return input_ids, attention_masks, labels, number_of_tokens

In [55]:
def get_loss_and_metrics(model, dataloader, device):
    # get initial loss for the trigger
    model.zero_grad()

    test_preds = []
    test_targets = []

    # Tracking variables 
    total_test_accuracy = 0
    total_test_loss = 0
    io_total_test_acc = 0
    io_total_test_prec = 0
    io_total_test_recall = 0
    io_total_test_f1 = 0

    for batch in dataloader:
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)

        model.zero_grad()

        result = model(b_input_ids, 
                    token_type_ids=None, 
                    attention_mask=b_input_mask, 
                    labels=b_labels,
                    return_dict=True)

        loss = result.loss
        logits = result.logits

        test_preds.extend(logits.argmax(dim=1).cpu().numpy())
        test_targets.extend(batch[2].numpy())

        # Accumulate the validation loss.
        total_test_loss += loss.item()

        test_preds.extend(logits.argmax(dim=1).cpu().numpy())
        test_targets.extend(batch[2].numpy())

        # Move logits and labels to CPU
        logits = logits.detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()

        loss.backward()        

        # Calculate the accuracy for this batch of test sentences, and
        # accumulate it over all batches.        
        test_acc = accuracy_score(test_targets, test_preds)
        test_precision = precision_score(test_targets, test_preds)
        test_recall = recall_score(test_targets, test_preds)
        test_f1 = f1_score(test_targets, test_preds)

        io_total_test_acc += test_acc
        io_total_test_prec += test_precision
        io_total_test_recall += test_recall
        io_total_test_f1 += test_f1

    io_avg_test_loss = total_test_loss/len(dataloader)
    io_avg_test_acc = io_total_test_acc / len(dataloader)
    io_avg_test_prec = io_total_test_prec / len(dataloader)
    io_avg_test_recall = io_total_test_recall / len(dataloader)
    io_avg_test_f1 = io_total_test_f1 / len(dataloader)
    print(
            f'Loss {io_avg_test_loss} : \t\
            Valid_acc : {io_avg_test_acc}\t\
            Valid_F1 : {io_avg_test_f1}\t\
            Valid_precision : {io_avg_test_prec}\t\
            Valid_recall : {io_avg_test_recall}'
          )

    return io_avg_test_loss, io_avg_test_acc, io_avg_test_prec, io_avg_test_recall, io_avg_test_f1

In [86]:
def change_input_ids_with_candidate_token(input_ids, position, candidate, number_of_tokens, trigger_position='B'):
    if trigger_position == 'B':
        input_ids[:, position] = candidate
    elif trigger_position == 'E':
        #print(".", end="")
        for idx in range(len(input_ids)):
            #print("number_of_tokens[idx]: ", number_of_tokens[idx])

            if number_of_tokens[idx] > EMBEDDING_SIZE:
                #print("+", end="")
                #print("Index: ", EMBEDDING_SIZE-NUM_TOKENS-2+position, "\n")
                #print(f"+[{input_ids[idx, EMBEDDING_SIZE-NUM_TOKENS-2+position]}][{candidate}]", end="")
                input_ids[idx, EMBEDDING_SIZE-NUM_TOKENS-2+position] = candidate
            else:
                #print("-", end="")
                #print("Index: ", number_of_tokens[idx]-NUM_TOKENS-2+position, "\n")
                #print(f"+[{input_ids[idx, number_of_tokens[idx]-NUM_TOKENS-2+position]}][{candidate}]", end="")
                input_ids[idx, number_of_tokens[idx]-NUM_TOKENS-2+position] = candidate
    else:
        print('Wrong position command, please enter "E" or "B"')
        return
    return input_ids

In [67]:
def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

In [68]:
positions_unfair = np.where(np.array(all_labels) == 1)[0]
print(f'First 32 positions: {positions_unfair[0:32]} with total of unfair sentences {len(positions_unfair)}')

target_unfair_sentences = []
labels_unfair_sentences = []
for index in range(len(positions_unfair)):
    target_unfair_sentences.append(all_sentences[positions_unfair[index]])
    labels_unfair_sentences.append(all_labels[positions_unfair[index]])


First 32 positions: [  4   9  10  11  12  13  24  25  43  45  61  62  78  79  87  89  91  92
 100 104 109 111 143 151 154 157 169 195 206 258 260 266] with total of unfair sentences 1032


In [69]:
model.eval()
model.to(device)

add_hooks(model) # add gradient hooks to embeddings
embedding_weight = get_embedding_weight(model) # save the word embedding matrix

In [81]:
trigger_tokens = np.array([207]*NUM_TOKENS)
print(tokenizer.decode(trigger_tokens))

the the the


In [91]:
position = 'E'
#position = 'B'

input_ids, attention_masks, labels, number_of_tokens = get_input_masks_and_labels_with_tokens(target_unfair_sentences, labels_unfair_sentences, tokenizer.decode(trigger_tokens), position=position)

dataset = TensorDataset(input_ids, attention_masks, labels)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)



In [None]:
extracted_grads = []

loss_obtained, acc_obtained, prec_obtained, recall_obtained, f1_obtained = get_loss_and_metrics(model, dataloader, device)
print(f'acc_obtained {acc_obtained}')

candidates_selected = [207]*NUM_TOKENS
# try all the candidates and pick the best
curr_best_loss = acc_obtained
curr_best_trigger_tokens = None

for id_token_to_flip in range(0, NUM_TOKENS):

    averaged_grad = torch.sum(extracted_grads[0], dim=0)
    averaged_grad = averaged_grad[id_token_to_flip].unsqueeze(0)

    # Use hotflip (linear approximation) attack to get the top num_candidates
    candidates = hotflip_attack(averaged_grad, embedding_weight,
                                        [trigger_tokens[id_token_to_flip]], 
                                        increase_loss=False, num_candidates=100)[0]
    print(f'candidates {candidates}')
    
    for index, cand in enumerate(candidates):
        extracted_grads = []

        if cand in LIST_ID_SPECIAL_TOKENS:
          continue

        #print('input ids: ', input_ids)
        #print('input ids shape: ', input_ids.shape)
        input_ids_with_candidate_trigger = change_input_ids_with_candidate_token(deepcopy(input_ids), id_token_to_flip+1, cand, number_of_tokens, trigger_position=position)
        dataset_with_candidate_trigger = TensorDataset(input_ids_with_candidate_trigger, attention_masks, labels)
        dataloader_with_candidate_trigger = torch.utils.data.DataLoader(dataset_with_candidate_trigger, batch_size=BATCH_SIZE)

        current_loss, current_acc, current_prec, current_recall, current_f1 = get_loss_and_metrics(model, dataloader_with_candidate_trigger, device)

        if curr_best_loss > current_acc:
            curr_best_loss = current_acc
            candidates_selected[id_token_to_flip] = cand

        del input_ids_with_candidate_trigger
        del dataset_with_candidate_trigger
        del dataloader_with_candidate_trigger

        gc.collect()
        torch.cuda.empty_cache()

        print(f'[{id_token_to_flip}][{index}] acc[{index}] {current_acc} ({curr_best_loss})')


    #extracted_grads = []
    input_ids = change_input_ids_with_candidate_token(deepcopy(input_ids), id_token_to_flip+1, candidates_selected[id_token_to_flip], number_of_tokens, trigger_position=position)
    print(f'Worst acc {curr_best_loss} with candidates {candidates_selected}')


Loss 0.3653466764724616 : 	            Valid_acc : 0.8911350130180676	            Valid_F1 : 0.9423908420489892	            Valid_precision : 1.0	            Valid_recall : 0.8911350130180676
acc_obtained 0.8911350130180676
candidates [ 5102   232  2705   457   660  7742 12986 10902  1753  3522  7784   591
  1914  2596  1327   572   410   271   587  1050  1607  1626   531   382
 23357  1731   799  1363  6767 19616 29799   679  1807  2117  3477  1059
  6966  1428   378  2586 19090  1599  1825   635   547  1281  1635   819
   258 18253  2774  2126  2730  1415  1750 14703   266  5578  4099  1382
   598   226   333  4437  1297  4004   705   621 23272   454  1216  8498
  1688   599  2176  2058 21418   477  9240   863   935   694   907   198
  2624  5947  9784  2214 12491  1401  1808  3639  4348 19548 18878  2268
 13056  6092  2666 20256]
Loss 0.3595374542655367 : 	            Valid_acc : 0.8984564107364256	            Valid_F1 : 0.9464368646427817	            Valid_precision : 1.0	         

In [None]:
#print(tokenizer.encode("the the the")) #[101, 207, 207, 207, 102]
#print(tokenizer.decode([621, 13890, 21241, 23113, 221, 1898]))# Loss => unless communist normativ encroachments as anything
#print(tokenizer.decode([621, 13890, 13064, 1897, 1629, 29403]))# Accuracy => unless communist tolerate political dismissed disjunctive
#print(tokenizer.decode([621, 13890, 13064, 1897, 1629, 22121]))# F1 => unless communist tolerate political dismissed symmetrical

In [85]:
print(tokenizer.decode([207, 14768, 207]))

the entailing the
