<a href="https://colab.research.google.com/github/abyaadrafid/LDA_Lab_Defence/blob/main/NaturalAdvAttack.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Adversarial attacks against Legal-BERT Model (BertForSequenceClassification)

In [1]:
# Global variables

BATCH_SIZE = 16
MODEL_NAME = 'nlpaueb/legal-bert-small-uncased'#'bert-base-uncased'
EPOCHS = 3
EMBEDDING_SIZE = 512
NUM_CLASSES = 2
VOCABULARY_SIZE = 30522
NUM_TOKENS = 6


### Installation of packages

In [2]:
!pip install transformers
!pip install torch-lr-finder

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


### Imports

In [3]:
import torch
import os
import sys
import json
import argparse
from transformers import BertTokenizer
from google.colab import drive
from torch.utils.data import TensorDataset, random_split
from transformers import BertForSequenceClassification, AdamW, BertConfig
from transformers import get_linear_schedule_with_warmup
import numpy as np
import time
import datetime
import random
import gc
from torch.autograd import Variable
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from sklearn.model_selection import train_test_split
from copy import deepcopy
from tqdm import tqdm_notebook

### Device

In [4]:
# If there's a GPU available...
if torch.cuda.is_available():     
    device = torch.device("cuda")

    print('There are %d GPU(s) available.' % torch.cuda.device_count())

    print('We will use the GPU:', torch.cuda.get_device_name(0))

# If not...
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

There are 1 GPU(s) available.
We will use the GPU: Tesla T4


### Reading dataset

In [5]:
drive.mount('/content/drive')
%cd /content/drive/MyDrive/

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive


In [6]:
sys.path.append("/content/drive/MyDrive/Colab Notebooks")

In [7]:
from ARAE_utils import Seq2Seq, MLP_D, MLP_G, generate
from attack_util import project_noise, one_hot_prob, GPT2_LM_loss, select_fluent_trigger

In [8]:
# Funtion to read all sentences
def get_sentences(path):
    sentences= []
    for filename in os.listdir(path):
        with open(path+filename, 'r') as f:
            for sentence in f :
                sentences.append(sentence)
    return sentences

In [9]:
# Function to read get all labels
def get_labels(path):
    all_labels = []
    for filename in os.listdir(path):
        file_labels = []
        with open(path+filename, 'r') as f:
            for label in f :
                all_labels.append(int(label))
    return all_labels

In [10]:
# Reading sentences and labels
all_sentences = get_sentences("Sentences/")
all_labels = get_labels("Labels/")

In [11]:
# Since unfair sentences are marked as "-1", we change them to "0" for simplicity. Zero means fair, One means unfair
all_labels =  [0 if label ==-1 else label for label in all_labels]

### Bert Tokenizer

In [12]:
# Load the BERT tokenizer.
print('Loading BERT tokenizer...')
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME, do_lower_case=True) # the model 'bert-base-uncased' only contains lower case sentences

Loading BERT tokenizer...


### Model BertForSequenceClassification (Load model)

In [13]:
model = BertForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels = NUM_CLASSES,
    output_attentions = False,
    output_hidden_states = False,
)

model.to(device);

Some weights of the model checkpoint at nlpaueb/legal-bert-small-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification we

In [14]:
# model.load_state_dict(torch.load('Bert4SeqClassif_202207072015.pt'))

In [15]:
def load_ARAE_models(load_path, args):
    # function to load ARAE model.
    if not os.path.exists(load_path):
        print('Please download the pretrained ARAE model first')
        
    ARAE_args = json.load(open(os.path.join(load_path, 'options.json'), 'r'))
    vars(args).update(ARAE_args)
    autoencoder = Seq2Seq(emsize=args.emsize,
                          nhidden=args.nhidden,
                          ntokens=args.ntokens,
                          nlayers=args.nlayers,
                          noise_r=args.noise_r,
                          hidden_init=args.hidden_init,
                          dropout=args.dropout,
                          gpu=True)
    gan_gen = MLP_G(ninput=args.z_size, noutput=args.nhidden, layers=args.arch_g)
    gan_disc = MLP_D(ninput=args.nhidden, noutput=1, layers=args.arch_d)

    autoencoder = autoencoder.cuda()
    gan_gen = gan_gen.cuda()
    gan_disc = gan_disc.cuda()

    ARAE_word2idx = json.load(open(os.path.join(args.load_path, 'vocab.json'), 'r'))
    ARAE_idx2word = {v: k for k, v in ARAE_word2idx.items()}

    print('Loading models from {}'.format(args.load_path))
    loaded = torch.load(os.path.join(args.load_path, "model.pt"))
    autoencoder.load_state_dict(loaded.get('ae'))
    gan_gen.load_state_dict(loaded.get('gan_g'))
    gan_disc.load_state_dict(loaded.get('gan_d'))
    return ARAE_args, ARAE_idx2word, ARAE_word2idx, autoencoder, gan_gen, gan_disc

In [16]:
parser = argparse.ArgumentParser()
parser.add_argument('--load_path', type=str, default='/content/drive/MyDrive/oneb_pretrained',
                    help='directory to load models from')
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')
parser.add_argument('--sample', action='store_true',
                    help='sample when decoding for generation')
parser.add_argument('--len_lim', type=int, default=5,
                    help='maximum length of sentence')
parser.add_argument('--r_lim', type=float, default=1,
                    help='lim of radius of z')
parser.add_argument('--sentiment_path', type=str, default='./opinion_lexicon_English',
                    help='directory to load sentiment word from')
parser.add_argument('--z_seed', type=float, default=6.,
                    help='noise seed for z')
parser.add_argument('--avoid_l', type=int, default=4,
                    help='length to avoid repeated pattern')
parser.add_argument('--lr', type=float, default=1e3,
                    help='learn rate')
parser.add_argument('--attack_class', type=str, default='1',
                    help='the class label to attack')
parser.add_argument('--noise_n', type=int, default=1,
                    help='number of generated noise vectors')
parser.add_argument('--tot_runs', type=int, default=1,
                    help='number of attack runs')
args = parser.parse_args([])

In [17]:
r_threshold = args.r_lim
step_bound = r_threshold / 100
max_iterations = 1000

In [18]:
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

# initialize ARAE model.
ARAE_args, ARAE_idx2word, ARAE_word2idx, autoencoder, gan_gen, gan_disc = load_ARAE_models(args.load_path, args)

Loading models from /content/drive/MyDrive/oneb_pretrained


In [19]:
# returns the wordpiece embedding weight matrix
def get_embedding_weight(language_model):
    for module in language_model.modules():
        if isinstance(module, torch.nn.Embedding):
            if module.weight.shape[0] == 30522:
                return module.weight.detach()

In [20]:
# add hooks for embeddings
def add_hooks(language_model):
    for module in language_model.modules():
        if isinstance(module, torch.nn.Embedding):
            if module.weight.shape[0] == 30522:
                module.weight.requires_grad = True
                module.register_full_backward_hook(extract_grad_hook)

In [21]:
# hook used in add_hooks()
extracted_grads = []
def extract_grad_hook(module, grad_in, grad_out):
    extracted_grads.append(grad_out[0])

In [22]:
model.eval()
model.to(device)

add_hooks(model) # add gradient hooks to embeddings
embedding_weight = get_embedding_weight(model) # save the word embedding matrix


In [23]:
ARAE_weight_embedding = []
for num in range(len(ARAE_idx2word)):
    ARAE_weight_embedding.append(embedding_weight[tokenizer.convert_tokens_to_ids(ARAE_idx2word[num])])
ARAE_weight_embedding = torch.stack(ARAE_weight_embedding)
print(ARAE_weight_embedding.shape)

torch.Size([30004, 512])


### Trigger generation

##### General functions

In [24]:
# creates the batch of target texts with -1 placed at the end of the sequences for padding (for masking out the loss).
def make_target_batch(tokenizer, device, target_texts):
    encoded_texts = []
    max_len = 0
    for target_text in target_texts:
        encoded_target_text = tokenizer.encode_plus(
            target_text,
            add_special_tokens = True,
            max_length = EMBEDDING_SIZE - NUM_TOKENS,
            pad_to_max_length = True,
            return_attention_mask = True
        )
        encoded_texts.append(encoded_target_text.input_ids)
        if len(encoded_target_text.input_ids) > max_len:
            max_len = len(encoded_target_text)

    for indx, encoded_text in enumerate(encoded_texts):
        if len(encoded_text) < max_len:
            encoded_texts[indx].extend([-1] * (max_len - len(encoded_text)))

    target_tokens_batch = None
    for encoded_text in encoded_texts:
        target_tokens = torch.tensor(encoded_text, device=device, dtype=torch.long).unsqueeze(0)
        if target_tokens_batch is None:
            target_tokens_batch = target_tokens
        else:
            target_tokens_batch = torch.cat((target_tokens, target_tokens_batch), dim=0)
    return target_tokens_batch

In [25]:
def get_input_masks_and_labels_with_tokens(sentences, labels, tokens):
    input_ids = []
    attention_masks = []

    for sent in sentences:
        sent_with_tokens = " ".join(tokens) + " " + sent

        encoded_dict = tokenizer.encode_plus(
                        sent,
                        add_special_tokens = True,
                        max_length = 512 - NUM_TOKENS+1,
                        pad_to_max_length = True,
                        return_attention_mask = True,
                        return_tensors = 'pt',
                   )
           
        input_ids.append(encoded_dict['input_ids'])

        attention_masks.append(encoded_dict['attention_mask'])

    input_ids = torch.cat(input_ids, dim=0)
    attention_masks = torch.cat(attention_masks, dim=0)
    labels = torch.tensor(labels)

    return input_ids, attention_masks, labels

In [26]:
def get_loss_and_metrics(model, dataloader, device):
    # get initial loss for the trigger
    model.zero_grad()

    test_preds = []
    test_targets = []

    # Tracking variables 
    total_test_accuracy = 0
    total_test_loss = 0
    io_total_test_acc = 0
    io_total_test_prec = 0
    io_total_test_recall = 0
    io_total_test_f1 = 0

    for batch in dataloader:
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)

        model.zero_grad()

        result = model(b_input_ids, 
                    token_type_ids=None, 
                    attention_mask=b_input_mask, 
                    labels=b_labels,
                    return_dict=True)

        loss = result.loss
        logits = result.logits

        test_preds.extend(logits.argmax(dim=1).cpu().numpy())
        test_targets.extend(batch[2].cpu().numpy())

        # Accumulate the validation loss.
        total_test_loss += loss.item()

        test_preds.extend(logits.argmax(dim=1).cpu().numpy())
        test_targets.extend(batch[2].cpu().numpy())

        # Move logits and labels to CPU
        logits = logits.detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()

        loss.backward()        

        # Calculate the accuracy for this batch of test sentences, and
        # accumulate it over all batches.        
        test_acc = accuracy_score(test_targets, test_preds)
        test_precision = precision_score(test_targets, test_preds)
        test_recall = recall_score(test_targets, test_preds)
        test_f1 = f1_score(test_targets, test_preds)

        io_total_test_acc += test_acc
        io_total_test_prec += test_precision
        io_total_test_recall += test_recall
        io_total_test_f1 += test_f1

    io_avg_test_loss = total_test_loss/len(dataloader)
    io_avg_test_acc = io_total_test_acc / len(dataloader)
    io_avg_test_prec = io_total_test_prec / len(dataloader)
    io_avg_test_recall = io_total_test_recall / len(dataloader)
    io_avg_test_f1 = io_total_test_f1 / len(dataloader)
    print(
            f'Loss {io_avg_test_loss} : \t\
            Valid_acc : {io_avg_test_acc}\t\
            Valid_F1 : {io_avg_test_f1}\t\
            Valid_precision : {io_avg_test_prec}\t\
            Valid_recall : {io_avg_test_recall}'
          )

    return io_avg_test_loss, io_avg_test_acc, io_avg_test_prec, io_avg_test_recall, io_avg_test_f1

In [27]:
def change_input_ids_with_candidate_token(input_ids, position, candidate):
    input_ids[:,position] = candidate

    return input_ids

In [28]:
def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

In [29]:
positions_unfair = np.where(np.array(all_labels) == 1)[0]
print(f'First 32 positions: {positions_unfair[0:32]} with total of unfair sentences {len(positions_unfair)}')

target_unfair_sentences = []
labels_unfair_sentences = []
for index in range(len(positions_unfair)):
    target_unfair_sentences.append(all_sentences[positions_unfair[index]])
    labels_unfair_sentences.append(all_labels[positions_unfair[index]])


First 32 positions: [ 11  44  48  57  58  59  75  80  89  90  93  95  96 112 113 116 144 145
 148 151 163 183 184 195 196 197 198 199 200 209 210 211] with total of unfair sentences 1032


In [30]:
trigger_tokens = [0,0,0,0,0,0]

In [31]:
input_ids, attention_masks, labels = get_input_masks_and_labels_with_tokens(target_unfair_sentences, labels_unfair_sentences, tokenizer.decode(trigger_tokens))

dataset = TensorDataset(input_ids, attention_masks, labels)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


In [33]:
model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 512, padding_idx=0)
      (position_embeddings): Embedding(512, 512)
      (token_type_embeddings): Embedding(2, 512)
      (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=512, out_features=512, bias=True)
              (key): Linear(in_features=512, out_features=512, bias=True)
              (value): Linear(in_features=512, out_features=512, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=512, out_features=512, bias=True)
              (LayerNorm): LayerNorm((512,), eps=1e-12, element

In [34]:
def print_generated_sentences_from_ARAE(max_indices):
  max_indices = max_indices.data.cpu().numpy()
  sentences = []
  for idx in max_indices:
      # generated sentence
      words = tokenizer.convert_ids_to_tokens(idx)
      # truncate sentences to first occurrence of <eos>
      truncated_sent = []
      for w in words:
          if w != '<eos>':
              truncated_sent.append(w)
          else:
              break
      sent = " ".join(truncated_sent)
      sentences.append(sent)
  print(sentences)

In [35]:
class GlueEncoder(nn.Module):
  def __init__(self):
    super().__init__()
    self.layer = nn.Linear(256,1)
  
  def forward(self, x):
    x = x.permute(1,2,0)
    return self.layer(x).squeeze(-1)

In [36]:
criterion = nn.CrossEntropyLoss()

In [37]:
def forward_with_trigger(out_emb, tokens, masks, labels):
  model.train()
  token_embeddings = model.bert.embeddings.word_embeddings(tokens)
  out_emb = out_emb.repeat(token_embeddings.shape[0],1,1)
  input_embeddings = torch.cat([token_embeddings, out_emb], dim = 1)
  enc_output = model.bert.encoder(input_embeddings)
  pooler_output = model.bert.pooler(enc_output.last_hidden_state)
  dropout_output = model.dropout(pooler_output)
  return model.classifier(dropout_output)

In [40]:
get_loss_and_metrics(model, dataloader, device)
maxlen = args.len_lim
# initialize noise
noise_n = args.noise_n  # this should be a factor of batch_size
tot_runs = args.tot_runs
n_repeat = 1


r_threshold = args.r_lim
step_bound = r_threshold / 100
max_iterations = 1000

patience_lim = 3
patience = 0 
max_trial = 3
all_output = list()
log_loss = int(1e2)

for tmp in tqdm_notebook(range(tot_runs)):
    get_loss_and_metrics(model, dataloader, device)
    step_size = args.lr
    step_scale = 0.1 
    patience = 0
    old_noise = None
    old_loss = float('-Inf')
    loss_list = list()
    update = False
    i_trial = 0

    torch.manual_seed(args.z_seed + tmp)
    print('z_seed:{}'.format(args.z_seed + tmp))
    noise = torch.randn(noise_n, ARAE_args['z_size'], requires_grad=True, device = "cuda")
    noise = Variable(noise, requires_grad=True)
    start_noise_data = noise.data.clone()
    iter = 0
    for i, batch in enumerate(dataloader) :
        # evaluate_batch(model, batch, trigger_token_ids, snli)
        # generate sentence with ARAE, output the word embedding instead of index.

        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)

        model.train()
        autoencoder.train()
        gan_gen.eval()
        gan_disc.eval()

        hidden = gan_gen(noise)

        max_indices, decoded = autoencoder.generate_decoding(hidden=hidden, maxlen=maxlen, sample=False, avoid_l=args.avoid_l)
        
        # print_generated_sentences_from_ARAE(max_indices)
        
        decoded = torch.stack(decoded, dim=1).float()
        if n_repeat > 1:
            decoded = torch.repeat_interleave(decoded, repeats=n_repeat, dim=0)

        decoded_prob = F.softmax(decoded, dim=-1)
        decoded_prob = one_hot_prob(decoded_prob, max_indices)
        out_emb = torch.matmul(decoded_prob, ARAE_weight_embedding)

        output = forward_with_trigger(out_emb, b_input_ids, b_input_mask, b_labels.unsqueeze(-1))

        oh_targets = F.one_hot(b_labels, num_classes=2).to(torch.float32).to(device)
        loss = criterion(output, oh_targets)
        iter += 1

        loss_list.append(loss.item())
        if noise.grad is not None:
          noise.grad.zero_()
        noise.retain_grad()
        loss.backward()

        noise_diff = step_size * noise.grad.data
        noise_diff = project_noise(noise_diff, r_threshold=step_bound)

        noise.data = noise.data + noise_diff

        whole_diff = noise.data - start_noise_data
        whole_diff = project_noise(whole_diff, r_threshold=r_threshold)
        noise.data = start_noise_data + whole_diff

        if iter % log_loss == 0:
            cur_loss = np.mean(loss_list)
            print('current iter:{}'.format(iter))
            print('current loss:{}'.format(cur_loss))

            loss_list = list()
            if cur_loss > old_loss:
                patience = 0
                old_loss = cur_loss
                old_noise = noise.data.clone()
                update = True
            else:
                patience += 1

            print('current patience:{}'.format(patience))
            print('\n')

            if patience >= patience_lim:
                patience = 0
                step_size *= step_scale
                noise.data = old_noise
                print('current step size:{}'.format(step_size))
                i_trial += 1
                print('current trial:{}'.format(i_trial))
                print('\n')

        if i_trial >= max_trial or iter >= max_iterations:
            if update:
                with torch.no_grad():
                    noise_new = torch.ones(noise_n, ARAE_args['z_size'], requires_grad=False).cuda()
                    noise_new.data = old_noise
                    hidden = gan_gen(noise_new)  # [:1, :]
                    max_indices, decoded = autoencoder.generate_decoding(hidden=hidden, maxlen=maxlen, sample=False, avoid_l=args.avoid_l)

                    decoded = torch.stack(decoded, dim=1).float()
                    if n_repeat > 1:
                        decoded = torch.repeat_interleave(decoded, repeats=n_repeat, dim=0)

                    decoded_prob = F.softmax(decoded, dim=-1)
                    decoded_prob = one_hot_prob(decoded_prob, max_indices)

                sen_idxs = torch.argmax(decoded_prob, dim=2)
                sen_idxs = sen_idxs.cpu().numpy()

                output_s = list()
                glue = ' '
                sentence_list = list()
                for ss in sen_idxs:
                    sentence = [ARAE_idx2word[s] for s in ss]
                    trigger_token_ids = list()
                    last_word = None
                    last_word2 = None
                    contain_sentiment_word = False
                    new_sentence = list()
                    for word in sentence:
                        cur_idx = tokenizer.convert_tokens_to_ids(word)
                        if cur_idx != last_word and cur_idx != last_word2:
                            trigger_token_ids.append(cur_idx)
                            new_sentence.append(word)
                            last_word2 = last_word
                            last_word = cur_idx

                    threshold = 0.5
                    num_lim = 20
                    s_str = glue.join(new_sentence)
                    if not (s_str in sentence_list):
                        _, accuracy, _, _ ,_ = get_loss_and_metrics(model, dataloader, device)
                        if accuracy < threshold:
                            sentence_list.append(s_str)
                            output_s.append((s_str, accuracy, contain_sentiment_word))

                if len(output_s) > 0:
                    all_output = all_output + output_s
                update = False
            break

Loss 0.9242930769920349 : 	            Valid_acc : 0.06792916966156409	            Valid_F1 : 0.12703944889626015	            Valid_precision : 1.0	            Valid_recall : 0.06792916966156409


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


  0%|          | 0/1 [00:00<?, ?it/s]

Loss 0.9315606181438153 : 	            Valid_acc : 0.07304722376856175	            Valid_F1 : 0.13610433909229983	            Valid_precision : 1.0	            Valid_recall : 0.07304722376856175
z_seed:6.0


In [41]:
get_loss_and_metrics(model, dataloader, device)

Loss 0.918743673654703 : 	            Valid_acc : 0.0819289312100446	            Valid_F1 : 0.1510989260208236	            Valid_precision : 1.0	            Valid_recall : 0.0819289312100446


(0.918743673654703,
 0.0819289312100446,
 1.0,
 0.0819289312100446,
 0.1510989260208236)