# Fine-Tuning BERT with WAC and Concreteness Score
The reference for this notebook can be found [here](https://medium.com/analytics-vidhya/finetune-distilbert-for-multi-label-text-classsification-task-994eb448f94c) and in [this repository](https://github.com/DhavalTaunk08/NLP_scripts/blob/master/Transformers_multilabel_distilbert.ipynb). To incorporate the GLUE tasks into the script, this [link](https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/text_classification.ipynb?authuser=1#scrollTo=imY1oC3SIrJf) is used.


**In this notebook we focus on binary classification tasks, although the notebook is set up so that it is easy to adapt it to non-binary tasks simply by picking the right loss function.**

## Setup

In [None]:
# # For running on Colab
# from google.colab import drive
# import os
# drive.mount('/content/drive')
# os.chdir(r"drive/MyDrive/Colab Notebooks/wac_bert")  # May need to change based on the user

In [None]:
# # Installing Requirements
# !pip3 install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio===0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
# !pip install -r requirements.txt
# !pip install ./wac_bert_lib/wac_bert_tools-0.1.0-py3-none-any.whl

In [None]:
import os
import pickle as pkl
import numpy as np
import pandas as pd
import random
import warnings
from collections import defaultdict
from sklearn.preprocessing import OneHotEncoder
from datasets import load_dataset, load_metric

import transformers
import torch
from tqdm import tqdm_notebook as tqdm

from wac_bert_tools import wac, concreteness
from wac_bert_tools import registrar, tokenize

warnings.filterwarnings('ignore')
transformers.logging.set_verbosity_error()
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Parameters
Pick a BERT-like model from https://huggingface.co/models

`GLUE_TASKS = ["cola", "mrpc", "qnli", "qqp", "rte", "sst2", "wnli", "mnli", "mnli-mm", "stsb"]`

`BINARY_GLUE_TASKS = ["cola", "mrpc", "qnli", "qqp", "rte", "sst2", "wnli"]`

In [None]:
# Run configuration. Everything below is read as module-level globals by the
# later cells (dataset, model, training loop).
task='rte'  # GLUE task name; validated against GLUE_TASKS inside DataProcessor
model_checkpoint = 'distilbert-base-uncased'  # Checkpoint used for both the tokenizer and the model
WAC_METHOD = ''#'' or 'CAT' or 'ADD'  ('' skips apply_wac entirely in BertLikeModel.forward)
MATCH_WAC = False  # With 'ADD', projects the BERT hidden state down to WAC_LEN with a linear layer.
FIXED_SEED = True  # Seed python/numpy/torch RNGs below
TEST=False # True uses 10% of the data
# Model Parameters
MAX_LEN = 128  # Sequences are padded/truncated to this many tokens
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 2e-05
DROPOUT=0.5  # Dropout probability applied to the pooled output before the classifier

tokenizer = transformers.AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
if FIXED_SEED:
    # Seed every RNG in use so that runs are repeatable
    seed_val = 2021
    random.seed(seed_val)
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)
# Keyword arguments later passed to registrar.register(); 'metric' and
# 'loss_name' are filled in once the DataProcessor has been built.
reg_arg={
        'model':model_checkpoint,
        'tokenizer':tokenizer.name_or_path,
        'data': task+('_sharded' if TEST else ''),
        'epoch': EPOCHS,
        'max_len': MAX_LEN, 
        'batch_Size': BATCH_SIZE,
        'learning_rate': LEARNING_RATE,
        'fixed_seed':FIXED_SEED,
        'loss_name':'',
        'max_epoch':EPOCHS
    }
# 'CAT' always forces MATCH_WAC on, regardless of the value chosen above
if WAC_METHOD=='CAT': MATCH_WAC = True

## Data

* Input Data
    - Concreteness Data
    - WAC vectors Data
* Training Data
    - DataProcessor for GLUE tasks

### Input Data
**To get concreteness and WAC vectors data from new sources do as follows:**

- Write a parser function for your data.
- The parser function must return a dataframe with two columns `['word','measure']`
- Write an argument dictionary for your parser function. e.g. `arg={'fname':r"c:\raw_data.csv"}`
- Call `tokenize.External_Data()` and provide its arguments. It returns a Data Series object of tokenized data
- You may use `wac` or `concreteness` function `.ds2dict()` to convert the data series to a dictionary, named `wac_dict`

In [None]:
# Get Concreteness
# concreteness_dict is indexed by token id in MultiLabelDataset.__getitem__
concreteness_dict = concreteness.load_input("./inputs/uncased_inferred_conc.pkl", tokenizer.name_or_path)
# Largest concreteness value rounded up; used in BertLikeModel.apply_wac as the
# total weight shared between the BERT hidden state and the WAC vector
FULL_CONC = np.ceil(max(list(concreteness_dict.values())))
# Get WAC2VEC
# wac_dict is indexed by token id as well; its values are converted with torch.from_numpy
wac_dict = wac.load_input("./inputs/clip_wac_513.pkl", tokenizer.name_or_path)
# # Example of how to load a new WAC data
# from wac_bert_tools import parser
# fpath = r""
# pars_clip = parser.clip_wac
# wac_ds = tokenize.External_Data(pars_clip, {'fname':fpath},'bert-base-uncased',False)
# wac_dict = wac.ds2dict(wac_ds, 513)

In [None]:
# Length of a WAC vector, taken from the first entry of wac_dict
WAC_LEN = len(next(iter(wac_dict.values())))
# Read the hidden size from the checkpoint's own configuration.
# `config_class()` would build a *default* config for the model class, which
# reports the class default rather than the size of `model_checkpoint`, and
# AutoModel.from_pretrained would load the full weights just for this check.
assert WAC_LEN <= transformers.AutoConfig.from_pretrained(model_checkpoint).hidden_size, (
    "Size of WAC vectors must NOT be larger than BERT-like model. Use trim_func argument in wac.ds2dict"
)

### Training Data

A `DataProcessor` class is required for any new dataset. Any new `DataProcessor` must provide the same methods with the same outputs.

In [None]:
class DataProcessor():
    """Preprocesses a GLUE dataset and exposes what the rest of the notebook needs.

    The notebook reads the number of outputs, the metric name, the metric
    function, the loss function and the train/validation arrays from here.
    """

    def __init__(self, task):
        """Load the GLUE dataset and metric for `task`.

        task: one of the GLUE task names listed in GLUE_TASKS below.
        Raises AssertionError when `task` is not a known GLUE task.
        """
        GLUE_TASKS = ['cola','mnli', "mnli-mm", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"]
        assert task in GLUE_TASKS, 'The task is not a valid GLUE task'
        # 'mnli-mm' is loaded as 'mnli'; only the validation split differs (see get_data)
        actual_task = 'mnli' if task =='mnli-mm' else task
        self.task = task
        self.dataset = load_dataset('glue',actual_task)
        self.metric = load_metric('glue',actual_task)

    def get_NUM_OUT(self):
        """Return the number of model outputs: 3 for mnli*, 1 for stsb (regression), else 2."""
        if self.task.startswith('mnli'):
            return 3
        if self.task == 'stsb':
            return 1
        return 2

    def get_metric_name(self):
        """Return the metric name for the task. Must be a key in get_metric_fn() output."""
        if self.task == "stsb":
            return "pearson"
        if self.task == "cola":
            return "matthews_correlation"
        return "accuracy"

    def get_metric_fn(self,guess,targs):
        """Compute the task metric.

        guess: torch tensor of predictions (guesses)
        targs: torch tensor of labels (targets)
        Returns a dictionary of metric_name(s) and their values.
        """
        if self.task != "stsb":
            # Per-row argmax turns scores / one-hot rows into class indices
            predictions = torch.max(guess, dim=1).indices
            labels = torch.max(targs, dim=1).indices
        else:
            # Regression: compare the raw values
            predictions = guess.squeeze()
            labels = targs.squeeze()
        return self.metric.compute(predictions=predictions, references=labels)

    def get_loss_fn(self):
        """Return MSELoss for single-output (regression) tasks, BCEWithLogitsLoss otherwise."""
        if self.get_NUM_OUT()==1:
            return torch.nn.MSELoss()
        return torch.nn.BCEWithLogitsLoss()

    def get_data(self):
        """Return a dictionary of 'train_X1', 'train_y', 'test_X1', 'test_y', 'test_X2' and 'train_X2'.

        Values are `numpy.ndarray` of compatible size. 'test_X2' and 'train_X2'
        are None for single-sentence tasks (the dictionary defaults missing
        keys to None). The X arrays hold string sequences (sentences). Labels
        are one-hot encoded, except for stsb where the score is divided by 5.
        The "test" arrays come from the task's validation split.
        """
        dataz=defaultdict(lambda:None)
        task_to_keys = {
        "cola": ("sentence", None),
        "mnli": ("premise", "hypothesis"),
        "mnli-mm": ("premise", "hypothesis"),
        "mrpc": ("sentence1", "sentence2"),
        "qnli": ("question", "sentence"),
        "qqp": ("question1", "question2"),
        "rte": ("sentence1", "sentence2"),
        "sst2": ("sentence", None),
        "stsb": ("sentence1", "sentence2"),
        "wnli": ("sentence1", "sentence2"),
        }
        if self.task == "mnli-mm":
            validation_key = "validation_mismatched"
        elif self.task == "mnli":
            validation_key = "validation_matched"
        else:
            validation_key = "validation"
        sentence1_key, sentence2_key = task_to_keys[self.task]

        train_split = self.dataset['train']
        test_split = self.dataset[validation_key]

        train_X1 = np.array(train_split[sentence1_key])
        test_X1 = np.array(test_split[sentence1_key])

        train_y = np.array(train_split['label'])
        test_y = np.array(test_split['label'])
        if self.task!='stsb':
            # Fit on the training labels only and reuse the same categories for validation
            enc = OneHotEncoder(sparse=False)
            train_y = enc.fit_transform(train_y.reshape(-1,1))
            test_y = enc.transform(test_y.reshape(-1,1))
        else:
            train_y = train_y/5
            test_y  = test_y/5
        dataz.update({'train_X1':train_X1, 'train_y':train_y, 'test_X1':test_X1,'test_y':test_y , })
        if sentence2_key:
            # The second sequence must come from sentence2_key; reading
            # sentence1_key here would feed the model the first sentence twice.
            train_X2 = np.array(train_split[sentence2_key])
            test_X2 = np.array(test_split[sentence2_key])
            dataz.update({'train_X2':train_X2,'test_X2':test_X2})

        return dataz

In [None]:
# Build the data processor for the selected task and pull out everything the
# later cells read as globals.
myData = DataProcessor(task)
dataz = myData.get_data()  # dict of train/test sentence arrays and labels
compute_metrics = myData.get_metric_fn  # callable(guess, targs) -> dict of metric values
loss_fn = myData.get_loss_fn()
NUM_OUT = myData.get_NUM_OUT()  # number of outputs of the classifier layer
metric_name = myData.get_metric_name()
# Complete the registry arguments now that the metric and loss are known
reg_arg['metric'] = metric_name
reg_arg['loss_name'] = str(loss_fn)

## Tokenizer

In [None]:
class MultiLabelDataset(torch.utils.data.Dataset):
    """Prepares the data for torch.utils.data.DataLoader.

    Each item is a dictionary of tensors: token ids, per-token concreteness,
    per-token WAC vectors, attention mask, token type ids and the target.
    Concreteness and WAC values are looked up by token id in the module-level
    `concreteness_dict` and `wac_dict`.
    """

    def __init__(self, labels, tokenizer, max_len, text1, text2 = None):
        """
        labels: numpy array of targets, indexed along its first axis
        tokenizer: tokenizer providing `encode_plus`
        max_len: length every sequence is padded/truncated to
        text1: indexable collection of first sequences
        text2: optional indexable collection of second sequences
        """
        self.text1 = text1
        self.text2 = text2
        self.targets = torch.from_numpy(labels)
        self.tokenizer = tokenizer
        self.max_len = max_len  # Max len of a sequence
        # Number of distinct *values* in the label array, not the number of
        # classes: one-hot labels only contain 0 and 1, so this is at most 2
        # for them. Used below to choose the dtype of the returned target.
        self.num_class = len(torch.unique(self.targets))

    def __len__(self):  # Length of the dataset
        return len(self.targets)

    def __getitem__(self, index):
        """Tokenize the sample at `index` and return its tensors as a dictionary."""
        text1 = self.text1[index]
        text2 = self.text2[index] if self.text2 is not None else None
        # Encode plus adds “special tokens” which are special IDs the model uses and converts tokens into IDs which are understandable by the model
        inputs = self.tokenizer.encode_plus(
            text1,
            text2,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length', # len of inputs elements must be the same, 'cause of torch.stack method
            truncation=True,
            return_token_type_ids=True
        )
        ids = inputs['input_ids']  # Token indices, numerical representation of tokens that are building a sequence
        concreteness  = [concreteness_dict[i] for i in ids]
        wac_vector    = [torch.from_numpy(wac_dict[i]) for i in ids]
        mask = inputs['attention_mask']  # indicates to the model which tokens should be attended to and which should not
        # When two sequences are encoded in the same input, token type ids tell them apart
        token_type_ids = inputs["token_type_ids"]

        target_dtype = torch.long if self.num_class<=2 else torch.float
        # self.targets[index] is already a tensor: copy and cast it explicitly
        # rather than passing it through torch.tensor(), which warns on
        # tensor input.
        target = self.targets[index].clone().detach().to(dtype=target_dtype)

        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'concreteness':torch.tensor(concreteness, dtype=torch.float),
            'wac_vector':torch.stack(wac_vector),
            'mask': torch.tensor(mask, dtype=torch.long),
            'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
            'targets': target
        }

In [None]:
def _seed_worker(worker_id):
    """Seed numpy and python RNGs inside a DataLoader worker process.

    DataLoader calls `worker_init_fn(worker_id)` in every worker, so the value
    must be a callable (or None). It is only invoked when num_workers > 0.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


training_data = MultiLabelDataset(dataz['train_y'],tokenizer, MAX_LEN, dataz['train_X1'], dataz['train_X2'])
test_data = MultiLabelDataset(dataz['test_y'],tokenizer, MAX_LEN, dataz['test_X1'], dataz['test_X2'])

if TEST:
    # Keep a reproducible 10% subset of each split for quick runs
    lengths = [len(training_data)//10, len(training_data)-len(training_data)//10]
    training_data = torch.utils.data.random_split(training_data, lengths, generator=torch.Generator().manual_seed(2021))[0]
    lengths = [len(test_data)//10, len(test_data)-len(test_data)//10]
    test_data = torch.utils.data.random_split(test_data, lengths, generator=torch.Generator().manual_seed(2021))[0]

# worker_init_fn was previously the integer 0, which is not callable and raises
# as soon as num_workers is raised above 0.
train_params = {'batch_size': BATCH_SIZE,
                'shuffle': True,
                'num_workers': 0,
                'worker_init_fn' : _seed_worker if FIXED_SEED else None
                }

# NOTE: shuffling is kept for the evaluation loader so the RNG stream of
# existing seeded runs is unchanged; validation() collects outputs and targets
# in the same order, so the metric does not depend on it.
test_params = {'batch_size': BATCH_SIZE,
                'shuffle': True,
                'num_workers': 0,
                'worker_init_fn' : _seed_worker if FIXED_SEED else None
                }

training_loader = torch.utils.data.DataLoader(training_data, **train_params)
testing_loader = torch.utils.data.DataLoader(test_data, **test_params)  # Returns batches of data

## Build Model

In [None]:
class BertLikeModel(torch.nn.Module):
    """BERT-like encoder followed by a pooling layer and a linear classifier.

    Depending on the module-level WAC_METHOD, the last hidden state is
    optionally combined with per-token WAC vectors and concreteness scores
    before pooling ('ADD': weighted sum, 'CAT': concatenation, '': unchanged).
    """

    def __init__(self, model_checkpoint, NUM_OUT):
        """
        model_checkpoint: name or path of the pretrained model to load
        NUM_OUT: number of outputs of the final classifier layer
        """
        super().__init__()
        self.bert = transformers.AutoModel.from_pretrained(model_checkpoint, num_labels = NUM_OUT,output_hidden_states=True)
        # Hidden size of the *loaded* checkpoint. `config_class()` would build a
        # default configuration and report the class default instead.
        LAST_LYR_SIZE = self.bert.config.hidden_size
        if WAC_METHOD=='ADD' and MATCH_WAC:
            in_feature_size = WAC_LEN
        elif WAC_METHOD=='CAT':
            in_feature_size = LAST_LYR_SIZE+WAC_LEN+1  # 1 is concreteness
        else:
            in_feature_size = LAST_LYR_SIZE

        # Layer creation order is kept as is so seeded initialisation is unchanged
        self.dense_pool = torch.nn.Linear(in_feature_size, in_feature_size)
        self.activation_pool = torch.nn.ReLU()
        # Maps the pooled representation to NUM_OUT outputs
        self.classifier = torch.nn.Linear(in_features=in_feature_size, out_features=NUM_OUT)
        self.dropout = torch.nn.Dropout(DROPOUT)
        self.softmax = torch.nn.Softmax(dim=1)  # Not used in forward(); kept as an attribute
        self.resize = torch.nn.Linear(LAST_LYR_SIZE,WAC_LEN)
        self.prepooler = torch.nn.Linear(in_feature_size, in_feature_size)

    def pooler(self,hidden_state):
        """Pool the sequence by taking the first token's state through a dense layer and ReLU."""
        first_token_tensor = hidden_state[:, 0]
        pooled_output = self.dense_pool(first_token_tensor)
        pooled_output = self.activation_pool(pooled_output)
        return pooled_output

    def apply_wac(self, hidden_state, wac, conc):
        """Combine the hidden state with WAC vectors and concreteness scores.

        hidden_state: encoder output, indexed (batch, token, feature)
        wac: per-token WAC vectors, indexed (batch, token, feature)
        conc: per-token concreteness, indexed (batch, token)
        Raises ValueError when WAC_METHOD is not '', 'ADD' or 'CAT'.
        """
        if MATCH_WAC: wac = wac[:,:, :WAC_LEN]
        concrete_w = conc[:,:,None]
        abstract_w = FULL_CONC-concrete_w  # Assumes full weight for BERT when WAC is not valid
        if WAC_METHOD=='ADD':
            if MATCH_WAC: hidden_state = self.resize(hidden_state)
            # Concreteness-weighted average of the BERT state and the WAC vector
            hidden_state = (hidden_state*abstract_w + wac*concrete_w)/FULL_CONC
        elif WAC_METHOD=='CAT':
            hidden_state = torch.cat((hidden_state, wac, concrete_w ),2)
        else:
            if WAC_METHOD != '': raise ValueError('Not a valid WAC_METHOD')
        return self.prepooler(hidden_state)  # To bring the impact of WAC into picture to the first element

    def forward(self, input_ids, attention_mask, token_type_ids, wac, conc):
        """Return the classifier outputs for a batch.

        `token_type_ids` is accepted for interface symmetry but is not passed
        to the encoder. The result keeps its batch dimension; only a trailing
        output dimension of size 1 (NUM_OUT == 1) is dropped.
        """
        output_1 = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        #hidden state size: (batch_size, sequence_length, hidden_size)
        hidden_state = output_1[-1][-1] # Last hidden layer
        if WAC_METHOD: hidden_state = self.apply_wac(hidden_state, wac, conc)
        # Imitate BERT sequence classifier
        pooled = self.pooler(hidden_state)
        pooled = self.dropout(pooled)
        output = self.classifier(pooled)
        # squeeze(-1) rather than squeeze(): a bare squeeze also removes the
        # batch dimension when the batch holds a single sample.
        return output.squeeze(-1)

## Train and Eval

In [None]:
def train(model, training_loader, optimizer):
    """Run one training epoch and return the loss of the last batch.

    model: module called as model(ids, mask, token_type_ids, wac, conc)
    training_loader: iterable of batch dictionaries
    optimizer: optimizer stepping the model parameters
    Uses the module-level `device`, `loss_fn` and `tqdm`.
    """
    model.train()
    for batch in tqdm(training_loader):
        # Integer inputs (each of shape batch_size*max_len)
        input_ids, attn_mask, type_ids = (
            batch[field].to(device, dtype = torch.long)
            for field in ('ids', 'mask', 'token_type_ids')
        )
        # Floating point inputs and labels
        labels, wac_vecs, conc_scores = (
            batch[field].to(device, dtype = torch.float)
            for field in ('targets', 'wac_vector', 'concreteness')
        )
        predictions = model(input_ids, attn_mask, type_ids, wac_vecs, conc_scores)  # Calls forward
        optimizer.zero_grad()
        loss = loss_fn(predictions.squeeze(), labels.squeeze())
        loss.backward()
        optimizer.step()
    return loss
    
def validation(model, testing_loader):
    """Run the model over `testing_loader` without gradients.

    model: module called as model(ids, mask, token_type_ids, wac, conc)
    testing_loader: iterable of batch dictionaries
    Returns (outputs, targets): two tensors stacked over all samples, in the
    order the loader produced them. Outputs go through a sigmoid unless the
    module-level NUM_OUT is 1. Uses the module-level `device` and `tqdm`.
    """
    model.eval()
    fin_targets=[]
    fin_outputs=[]
    with torch.no_grad():  # Context-manager that disables gradient calculation.
        for data in tqdm(testing_loader):
            targets = data['targets']
            ids = data['ids'].to(device, dtype = torch.long)
            mask = data['mask'].to(device, dtype = torch.long)
            token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)
            wac = data['wac_vector'].to(device, dtype = torch.float)
            conc = data['concreteness'].to(device, dtype = torch.float)
            outputs = model(ids, mask, token_type_ids, wac, conc)
            if NUM_OUT!=1: outputs = torch.sigmoid(outputs)
            outputs = outputs.cpu().detach()
            # A model that squeezes its output drops the batch dimension when
            # the batch holds a single sample. Restore it, otherwise extend()
            # would split that sample's row into scalars and torch.stack()
            # would fail on the mismatched shapes.
            if outputs.dim() == 0 or outputs.shape[0] != targets.shape[0]:
                outputs = outputs.unsqueeze(0)
            fin_outputs.extend(outputs)  # one entry per sample
            fin_targets.extend(targets)
    return torch.stack(fin_outputs), torch.stack(fin_targets)  # Concatenates a sequence of tensors along a new dimension
def set_value(val, col):
    """Set column `col` of the last row of 'model_registry.csv' to `val`.

    The registry is read from and written back to the current working
    directory. An empty registry is written back unchanged.
    """
    reg = pd.read_csv('model_registry.csv')
    if not reg.empty:
        # Assign through .loc on the frame itself. `reg[-1:][col] = val` is
        # chained assignment on a slice, so the change may land on a temporary
        # copy and never reach the frame that is saved.
        reg.loc[reg.index[-1], col] = val
    reg.to_csv('model_registry.csv', index=False)

In [None]:
#%debug
# Build the model and optimizer, register the run, then train for EPOCHS
# epochs, evaluating on the validation loader after every epoch.
torch.cuda.empty_cache()
model = BertLikeModel(model_checkpoint,NUM_OUT)
model.to(device)    
optimizer = torch.optim.AdamW(params =  model.parameters(), lr=LEARNING_RATE)
print("Models with no study name won't be saved.")
reg = registrar.register(**reg_arg)
metric_lt = []  # validation metric of each epoch, in order
for epoch in range(EPOCHS):
    loss = train(model, training_loader, optimizer)  # loss of the last training batch
    print(f'Epoch: {epoch}, Loss:  {loss.item()}')  
    guess, targs = validation(model, testing_loader)
    metric_value = compute_metrics(guess, targs)
    metric_lt.append(metric_value[metric_name])
    print('{} of test set'.format(metric_value))
# Best (largest) metric over all epochs and the 1-based epoch that first reached it.
# NOTE: no checkpoint is kept per epoch; `model` holds the weights of the final epoch.
max_metric = max(metric_lt)
registrar.update_metric(max_metric,reg['fDir'])
max_epoch = metric_lt.index(max_metric)+1
set_value(max_epoch, 'epoch')  # stored in the 'epoch' column of the last registry row

In [None]:
# Pickle the trained model only when the run was registered with a study name
if reg['study_name']:
    # The context manager closes (and flushes) the file even if pickling fails;
    # a bare open() inside the call left the handle open.
    with open('./models/%s.sav'%reg['study_name'],'wb') as model_file:
        pkl.dump(model, model_file)