In [1]:
import json, os, pickle, torch, logging, typing, numpy as np
from torch.utils.data import DataLoader, Dataset
from pathlib import Path
from tqdm import tqdm
from transformers import BertTokenizer, BertForMaskedLM, BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup
import transformers as T
from logging.handlers import RotatingFileHandler
from src.utils.pickleUtils import pdump, pload, pjoin

from src.proecssing import correct_count
from transformers.models.bert.modeling_bert import SequenceClassifierOutput
from src.classes.datasets import DataItem
from torch import Tensor

In [2]:
log_formatter = logging.Formatter('%(asctime)s %(levelname)s %(funcName)s(%(lineno)d) %(message)s')
def setup_logger(name, log_file, level=logging.INFO):
  handler = logging.FileHandler(log_file, mode='w')
  handler.setFormatter(log_formatter)
  logger = logging.getLogger(name)
  logger.setLevel(level)
  logger.addHandler(handler)
  return logger

error_logger = setup_logger("error_log", "error_log")

In [3]:
DATASET_NAME = "imdb"
DATASET_NAME = "original_augmented_1x_aclImdb"
DATASET_PATH = f"./datasets/{DATASET_NAME}/base"
OUTPUT_PATH = f"checkpoints/{DATASET_NAME}/model"
TRIPLETS_PATH = f"./datasets/{DATASET_NAME}/augmented_triplets"
TOPK_NUM = 4


import json, psutil
# env = {}
# with open("./env.json", mode="r") as f:
#   env = json.load(f)


# memAvailable = psutil.virtual_memory().available
# estimatedMemConsumed = os.path.getsize(os.path.join(DATASET_PATH, "train_set.pickle.blosc")) * 3
# USE_PINNED_MEMORY = True if (env['USE_PINNED_MEMORY'] & (memAvailable > estimatedMemConsumed)) == 1 else False


sampling_ratio = 1

In [4]:
tokenizer: T.BertTokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [5]:
import pickle
import pandas as pd
# train_set = pload(os.path.join(DATASET_PATH, "train_set"))
f = open(os.path.join(DATASET_PATH, "train_set.json"))
data = json.load(f)
SPLIT_SAMPLES = 1000000
NOTEBOOK_INDEX = 0
train_texts = [d['anchor_text'] for d in data][SPLIT_SAMPLES * (NOTEBOOK_INDEX):SPLIT_SAMPLES * (NOTEBOOK_INDEX + 1)]
train_labels = [d['label'] for d in data][SPLIT_SAMPLES * (NOTEBOOK_INDEX):SPLIT_SAMPLES * (NOTEBOOK_INDEX + 1)]
train_texts = train_texts[0:100]
train_labels = train_labels[0:100]
# train_texts:list[str] = train_set['review'].tolist()[0:100]
# train_labels:list = train_set['sentiment'].tolist()[0:100]
train_encodings = tokenizer(train_texts, padding=True, truncation=True)
# pdump(train_encodings, os.path.join(DATASET_PATH, "train_encodings"))




# valid_set = pload(os.path.join(DATASET_PATH, "valid_set"))
# valid_texts:list[str] = valid_set['review'].tolist()
# valid_labels: list = valid_set['sentiment'].tolist()
# valid_encodings = tokenizer(valid_texts, padding=True, truncation=True)
# pdump(valid_encodings, os.path.join(DATASET_PATH, "valid_encodings"))
# print(train_encodings.keys())




In [6]:
# from  src.classes.datasets import IMDBDataset
class IMDBDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)
train_dataset = IMDBDataset(labels=train_labels, encodings=train_encodings)

train_loader = DataLoader(
  train_dataset,
  batch_size=1,
  shuffle=False)



In [7]:

model:BertForSequenceClassification = BertForSequenceClassification.from_pretrained(os.path.join(OUTPUT_PATH, 'best_epoch')) #type:ignore
model.to(device)
def get_gradient_norms(batch):
  input_ids:Tensor = batch['input_ids'].to(device)
  attention_mask:Tensor = batch['attention_mask'].to(device)
  labels:Tensor = batch['labels'].to(device)

  _, labels = torch.max(labels, dim=1)


  outputs:SequenceClassifierOutput | tuple[Tensor] = model.forward(input_ids=input_ids, attention_mask=attention_mask, labels=labels, return_dict=True)
  assert isinstance(outputs, SequenceClassifierOutput)
  
  loss = outputs['loss']
  loss.backward(retain_graph=True)
  torch.cuda.empty_cache()
  importances = torch.tensor([]).to(device)
  for pos_index, token_index in zip(range(1, len(input_ids[0])), input_ids[0][1:]):
    if token_index == tokenizer.sep_token_id:
      break
    importance = torch.norm(model.bert.embeddings.position_embeddings.weight.grad[pos_index], 2).float().detach() #type:ignore
    importances = torch.cat((importances, importance.unsqueeze(0)), dim=-1)
  
  model.bert.embeddings.position_embeddings.weight.grad = None #! why???

  return importances





def compute_importances(data_loader:DataLoader, importance_function:typing.Callable) -> list[Tensor]:
  all_importances: list = []
  for batch in tqdm(data_loader):
    importances = importance_function(batch)
    all_importances.append(importances)
  return all_importances

def compute_average_importance(data_loader:DataLoader, all_importances) -> list[Tensor]:
  all_averaged_importances:list = []
  importance_dict = dict()
  importance_dict_counter = dict()

  for importances, batch in tqdm(zip(all_importances, data_loader)):
    tokens = [x for x in batch['input_ids'][0][1:] if x not in [tokenizer.sep_token_id, tokenizer.pad_token_id]]

    for token_importance, token in zip(importances, tokens):
      if not token in importance_dict.keys():
        importance_dict[token.item()] = 0
        importance_dict_counter[token.item()] = 0
      importance_dict[token.item()] += token_importance
      importance_dict_counter[token.item()] += 1
    
  for importances, batch in tqdm(zip(all_importances, data_loader)):
    tokens = [x for x in batch['input_ids'][0][1:] if x not in [tokenizer.sep_token_id, tokenizer.pad_token_id]]
    averaged_importances = torch.Tensor([importance_dict[x.item()]/importance_dict_counter[x.item()] for x in tokens])
    all_averaged_importances.append(averaged_importances)
  return all_averaged_importances



In [8]:
importancePath = os.path.join(DATASET_PATH, "train_set_importance")
importance: list[Tensor] = []
# if(os.path.exists(pjoin(importancePath))):
#   importance = pload(importancePath)
# else:
#   importance = compute_importances(train_loader, get_gradient_norms)
#   pdump(importance, importancePath)
importance = compute_importances(train_loader, get_gradient_norms)

averageImportancePath = os.path.join(DATASET_PATH, "train_set_average_importance")
averageImportance: list[Tensor] = []
# if(os.path.exists(pjoin(averageImportancePath))):
#   averageImportance = pload(averageImportancePath)

# else:
#   averageImportance = compute_average_importance(train_loader, importance)
#   pdump(averageImportance, averageImportancePath)
averageImportance = compute_average_importance(train_loader, importance)

100%|██████████| 100/100 [00:02<00:00, 37.62it/s]
100it [00:00, 847.92it/s]
100it [00:00, 688.71it/s]


In [9]:

mlm_model: BertForMaskedLM = BertForMaskedLM.from_pretrained('bert-base-uncased') #type: ignore
mlm_model: BertForMaskedLM = mlm_model.to(device) #type: ignore
mlm_model.eval()

def mask_data(data_loader:DataLoader, all_importances: list[Tensor], sampling_ratio, augment_ratio):
  triplets = []
  error_count = 0
  no_flip_count = 0
  no_flip_index = []

  for importances, batch in tqdm(zip(all_importances, data_loader)):
    label = []
    tokens = torch.tensor([x for x in batch['input_ids'][0][1:] if x not in [tokenizer.sep_token_id, tokenizer.pad_token_id]]) # could this be done better?
    assert tokens.size() == importances.size()
    
    orig_sample = tokenizer.decode(tokens)
    causal_mask, err_flag, maximum_score = mask_causal_words(tokens.cpu().numpy(), batch, importances.cpu().numpy(), topk=sampling_ratio)
    no_flip_index.append(err_flag)
    if err_flag:
      no_flip_count += 1
    
    if 1 not in causal_mask:
      triplets.append((label, orig_sample, orig_sample, orig_sample, err_flag, maximum_score))
      continue

    for _ in range(augment_ratio):
      causal_masked_tokens = []
      noncausal_masked_tokens = []

      if sampling_ratio is None:
        causal_masked_tokens = [tokens[i] if causal_mask[i] == 0 else tokenizer.mask_token_id for i in range(len(tokens))]
        noncausal_masked_tokens = [tokens[i] if causal_mask[i] == 1 else tokenizer.mask_token_id for i in range(len(tokens))]

      elif type(sampling_ratio) == int:
        causal_indices = np.where(np.array(causal_mask) == 1)[0]
        noncausal_indices = np.where(np.array(causal_mask) == 0)[0]

        causal_mask_indices = np.random.choice(causal_indices, sampling_ratio)

        try:
          noncausal_mask_indices = np.random.choice(noncausal_indices, max(1, min(sampling_ratio, len(noncausal_indices))))
        except:
          noncausal_mask_indices = np.random.choice(causal_indices, sampling_ratio)
          error_count += 1

        causal_masked_tokens = [tokens[i] if i not in causal_mask_indices else tokenizer.mask_token_id for i in range(len(tokens))]
        noncausal_masked_tokens = [tokens[i] if i not in noncausal_mask_indices else tokenizer.mask_token_id for i in range(len(tokens))]
      else:
        pass

      causal_masked_sample = tokenizer.decode(causal_masked_tokens)
      noncausal_masked_sample = tokenizer.decode(noncausal_masked_tokens)

      _, labels = torch.max(batch['labels'], dim=1)
   
      if labels[0] == 0: label = [0, 1]
      elif labels[0] == 1: label = [1, 0]
      triplets.append((label, orig_sample, causal_masked_sample, noncausal_masked_sample, err_flag, maximum_score))
  print(f"Error count: {error_count}")
  print(f"No flip count: {no_flip_count}")
  return triplets, no_flip_index





def mask_causal_words(tokens:Tensor, batch:DataItem, importances: Tensor, topk=1):
  dropout = torch.nn.Dropout(0.5)
  causal_mask = [0 for _ in range(len(tokens))]
  all_importance_indices = np.argsort(importances)[::-1]

  err_flag = False
  find_flag = False

  input_ids:Tensor = batch['input_ids'].squeeze().repeat((TOPK_NUM,)).reshape(TOPK_NUM, -1).to(device)
  attention_mask:Tensor = batch['attention_mask'].expand(TOPK_NUM, -1).to(device)
  token_type_ids:Tensor = batch['token_type_ids'].expand(TOPK_NUM, -1).to(device) #! is token_type_ids actually used anywhere???

  masked_input_ids = batch['input_ids'].squeeze().repeat((len(tokens),)).reshape(len(tokens), -1).to(device)
  masked_attention_mask = batch['attention_mask'].expand(len(tokens), -1).to(device)
  masked_token_type_ids = batch['token_type_ids'].expand(len(tokens), -1).to(device)

  fake_labels = torch.ones((len(tokens), ))
  
  masked_train = IMDBDataset({
    'input_ids': masked_input_ids,
    'attention_mask': masked_attention_mask,
    'token_type_ids': masked_token_type_ids,
    'importance_indices': all_importance_indices
  }, fake_labels)

  masked_train_loader = DataLoader(masked_train, batch_size=4, shuffle=False)
  logits = []
  for masked_batch in masked_train_loader:
    masked_input_ids = masked_batch['input_ids'].to(device) # 4 x 313
    masked_attention_mask = masked_batch['attention_mask'].to(device) # 4 x 313
    masked_token_type_ids = masked_batch['token_type_ids'].to(device) # 4 x 313
    importance_indices = masked_batch['importance_indices'].to(device)# 4
    masked_input_embeds: Tensor = mlm_model.bert.embeddings.word_embeddings(masked_input_ids) #4 x 313 x 768

    #dropout some of the embeddings at random
    for mi_i, topk_i in zip(range(masked_input_embeds.size(0)), importance_indices):
      masked_input_embeds[mi_i][topk_i + 1] = dropout(masked_input_embeds[mi_i][topk_i + 1])
    
    #get predicted words from mlm model given the partially missing embeds
    with torch.no_grad():
      outputs = mlm_model(attention_mask = masked_attention_mask, token_type_ids = masked_token_type_ids, inputs_embeds = masked_input_embeds)
      predictions = outputs[0] # 4 x 313 x 30522, just a casual 38 million numbers. shape(batch_size, sequence_length, config.vocab_size)
    
    #search through and find top k logits, in this case 4. 
    topk_logit_indices = torch.topk(predictions, TOPK_NUM, dim=-1)[1] # 4 sequences x 313 tokens x 4 

    #for the top k most important tokens, get their respective k candidates
    mask_candidates = [topk_logits[importance_index + 1] for importance_index, topk_logits in zip(importance_indices, topk_logit_indices)]

    
    for importance_index, mask_candidate in zip(importance_indices, mask_candidates):
      if importances[importance_index] == 0:
        continue
      recon_input_ids = input_ids.clone()
      for i, mc in enumerate(mask_candidate):
        recon_input_ids[i][importance_index + 1] = mc
      
      with torch.no_grad():
        recon_outputs = model(recon_input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        _, recon_prediction = torch.max(recon_outputs[0], dim=1)
        # print(f"{recon_outputs[0]}")
      if len(torch.unique(recon_prediction)) != 1:
        
        causal_mask[importance_index] = 1
        find_flag = True
        break


    if find_flag:
      break

  if 1 not in causal_mask:
    causal_mask[all_importance_indices[0]] = 1
    err_flag = True
    return causal_mask, err_flag, 0
  
  return causal_mask, err_flag, 0




  


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [11]:
sampling_ratio = 1
augment_ratio = 1

triplets_train, no_flip_idx_train = mask_data(train_loader, averageImportance, sampling_ratio=sampling_ratio, augment_ratio=augment_ratio)

triplets_json:list[dict] = []
triplets_pickle:dict[str,list] = {
  "labels": [],
  "anchor_texts": [],
  "positive_texts": [],
  "negative_texts": [],
  "triplet_sample_masks": []
}
for x in triplets_train:
  triplets_json.append(
    {
      "label": x[0],
      "anchor_text": x[1],
      "positive_text": x[3],
      "negative_text": x[2],
      "triplet_sample_mask": x[4]
    }
  )
  triplets_pickle["labels"].append(x[0])
  triplets_pickle["anchor_texts"].append(x[1])
  triplets_pickle["positive_texts"].append(x[3])
  triplets_pickle["negative_texts"].append(x[2])
  triplets_pickle["triplet_sample_masks"].append(x[4])


if not os.path.exists(TRIPLETS_PATH):
  os.mkdir(TRIPLETS_PATH)

pdump(triplets_pickle, os.path.join(TRIPLETS_PATH, "augmented_triplets"))
with open(os.path.join(TRIPLETS_PATH, "augmented_triplets.json"), mode='w') as f:
  json.dump(triplets_json, f)



  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
  item['labels'] = torch.tensor(self.labels[idx])
100it [02:12,  1.33s/it]

Error Cnt: 0
No Flip Cnt: 83



