# Baseline Model

## Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd

import transformers
from transformers import BertTokenizer
from transformers import BertForSequenceClassification
from transformers import DataCollatorWithPadding

# Seed torch's global RNG (also used by random_split in getdl) so data splits are reproducible.
torch.manual_seed(1)
# Record the installed transformers version in the notebook output.
print(transformers.__version__)

## Data Loader

In [None]:
# Define default Tokenizer: the pretrained "bert-base-uncased" BERT tokenizer.
# It is Dset's default tokenizer and is also used by data_collator.
defaultTokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# Parameters: Path to dataset csv, label id, tokenizer object, maximum token length
# label_id is a number 0-3, indicating which MBTI axis (E/I, S/N, T/F, J/P) will be the label
class Dset(Dataset):
    """Map-style dataset of MBTI-labelled posts read from a CSV file.

    The CSV must have a ``post`` column (the text) and a ``type`` column (an
    MBTI type string such as ``"INFP"``). Rows containing any missing value are
    dropped. Each example's label is one binary letter of its type, selected by
    ``label_id``: 1 if that letter is in ``"ESTJ"``, otherwise 0.

    Args:
        path: path to the dataset CSV.
        label_id: index 0-3 into the type string (0 = E/I, 1 = S/N,
            2 = T/F, 3 = J/P).
        tokenizer: callable tokenizer with the HuggingFace call signature;
            ``None`` (the default) uses the module-level ``defaultTokenizer``.
        max_token_len: every example is truncated / padded to exactly this
            many tokens.
    """

    def __init__(self, path, label_id, tokenizer=None, max_token_len=500):
        self.df = pd.read_csv(path).dropna()
        self.label_id = label_id
        # Resolved when the dataset is built (not when the class is defined),
        # so the class does not depend on the global tokenizer existing first.
        self.tokenizer = defaultTokenizer if tokenizer is None else tokenizer
        self.max_token_len = max_token_len
        # Converts a binary label back to its letter: position i of the string
        # is the letter for axis i.
        self.labelstrdicts = {1: "ESTJ", 0: "INFP"}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return one example as ``{"input_ids", "attention_mask", "labels"}``.

        ``input_ids`` and ``attention_mask`` are 1-D tensors of length
        ``max_token_len``; ``labels`` is the 0/1 int for the selected axis.

        Raises:
            ValueError: if the tokenizer rejects the post text (the row index
                and the start of the text are included in the message).
        """
        item = self.df.iloc[index]
        text = item["post"]
        ptype = item["type"]  # Personality type string, e.g. "INFP"
        labels = self.str2label(ptype)[self.label_id]

        try:
            tokens = self.tokenizer(
                text,
                return_tensors="pt",        # torch tensors with a leading batch dim of 1
                truncation=True,            # cut anything longer than max_length ...
                max_length=self.max_token_len,
                padding="max_length",       # ... and pad shorter text up to max_length
            )
        except (TypeError, ValueError) as err:
            # Surface the failing row rather than printing it and calling
            # quit(), which killed the process without a traceback.
            raise ValueError(
                f"could not tokenize post at index {index}: {str(text)[:200]!r}"
            ) from err

        # squeeze(0) removes only the batch dimension, so max_token_len == 1
        # still yields 1-D tensors.
        return {
            "input_ids": tokens.input_ids.squeeze(0),
            "attention_mask": tokens.attention_mask.squeeze(0),
            "labels": labels,
        }

    # Auxiliary Functions
    def str2label(self, string):
        """Convert an MBTI type string into a list of 0/1 labels (1 for E, S, T or J)."""
        return [1 if letter in "ESTJ" else 0 for letter in string]

    def label2str(self, label):
        """Inverse of str2label: map a list of 0/1 labels back to MBTI letters."""
        return [self.labelstrdicts[number][index] for index, number in enumerate(label)]
        

In [None]:
# Shared batch collator, followed by getdl, a function to automatically split a
# dataset and return train/validation/test dataloaders.
# DataCollatorWithPadding pads each batch's input_ids/attention_mask to a common
# length and returns tensors. Dset items are already padded to max_token_len,
# so the padding step is presumably a no-op here.
data_collator=DataCollatorWithPadding(tokenizer=defaultTokenizer)
def getdl(ds, batch_size, collate_fn=None, train_frac=0.8, shuffle_train=True, seed=None):
    """Randomly split *ds* into train/validation/test parts and wrap each in a DataLoader.

    With the default ``train_frac`` the split is 80/10/10. The examples left
    over after the training split are halved between validation and test;
    test gets the odd example, if there is one.

    Args:
        ds: map-style dataset (supports ``len`` and integer indexing).
        batch_size: batch size used by all three loaders.
        collate_fn: batch collation function; ``None`` uses the module-level
            ``data_collator``.
        train_frac: fraction of examples assigned to the training split.
        shuffle_train: reshuffle the training split every epoch. The
            validation and test loaders are never shuffled.
        seed: if given, the split uses a dedicated ``torch.Generator`` seeded
            with this value. Datasets of the same length then get the same
            split, e.g. the four MBTI axes built from one CSV. If ``None``,
            torch's global RNG is used, and each call therefore yields a
            different split.

    Returns:
        (train_loader, val_loader, test_loader)
    """
    if collate_fn is None:
        # Looked up at call time so getdl can be used with any collator.
        collate_fn = data_collator

    total_len = len(ds)
    train_len = int(total_len * train_frac)
    val_len = (total_len - train_len) // 2
    test_len = total_len - train_len - val_len

    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_ds, val_ds, test_ds = torch.utils.data.random_split(
        ds, [train_len, val_len, test_len], **split_kwargs
    )

    # return (training dataloader, validation dataloader, test dataloader)
    return (
        DataLoader(train_ds, batch_size=batch_size, shuffle=shuffle_train, collate_fn=collate_fn),
        DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn),
        DataLoader(test_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn),
    )

In [None]:
# Now build all four datasets at once: one per MBTI axis, all reading the same CSV.
path="./dataset2.csv"
dataset_EI_all=Dset(path, 0) # label_id 0 -> 1st letter (E/I); uses the default tokenizer
dataset_SN_all=Dset(path, 1) # label_id 1 -> 2nd letter (S/N)
dataset_TF_all=Dset(path, 2) # label_id 2 -> 3rd letter (T/F)
dataset_JP_all=Dset(path, 3) # label_id 3 -> 4th letter (J/P)
# Split each dataset into (train, validation, test) dataloaders.
# NOTE(review): each getdl call draws its own random permutation from torch's
# global RNG, so the four axes get different row splits of the same CSV. If the
# four models are ever evaluated jointly, one axis's test rows may have been
# trained on by another axis's model — consider sharing one split.
train_dl_EI, val_dl_EI, test_dl_EI=getdl(dataset_EI_all, batch_size=50)
train_dl_SN, val_dl_SN, test_dl_SN=getdl(dataset_SN_all, batch_size=50)
train_dl_TF, val_dl_TF, test_dl_TF=getdl(dataset_TF_all, batch_size=50)
train_dl_JP, val_dl_JP, test_dl_JP=getdl(dataset_JP_all, batch_size=50)

In [None]:
# Test Print: sanity-check the first training batch of the E/I loader
# (tensor shape, token ids and labels).
for batch in train_dl_EI:
    print(batch["input_ids"].shape)
    print(batch["input_ids"])
    print(batch["labels"])
    break  # only the first batch is needed

## Model(s)

In [None]:
class BaselineModel(nn.Module):
    """Placeholder for a custom baseline model.

    It has no layers yet and is not instantiated in this section of the
    notebook; the BertForSequenceClassification models are trained instead.
    """

    def __init__(self):
        # nn.Module.__init__ must run first: it creates the internal
        # parameter/buffer/submodule registries that every Module method uses.
        super().__init__()

In [None]:
# One independent binary classifier per MBTI axis. Each is loaded from the same
# pretrained "bert-base-uncased" checkpoint with a 2-label classification head.
# return_dict=False makes forward() return plain tuples instead of ModelOutput objects.
baselineModel_EI, baselineModel_SN, baselineModel_TF, baselineModel_JP = (
    BertForSequenceClassification.from_pretrained(
        "bert-base-uncased",
        num_labels=2,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    )
    for _axis in ("EI", "SN", "TF", "JP")
)

In [None]:
# Inspect Model Structure: printing a module shows its nested layer hierarchy.
print(baselineModel_EI)

## Loss and Optimizer (moved to training function)

In [None]:
# Loss
#criterion = nn.CrossEntropyLoss()
# Optimizer
#optimizer = optim.Adam(baselineModel.parameters(), lr=1e-5)
# Scheduler
#scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer,T_0=500, eta_min=1e-6)

## Evaluation Metrics

In [None]:
def _flat_pair(preds, labels):
    """Return *preds* and *labels* as 1-D NumPy arrays of equal length.

    Raises:
        ValueError: if the inputs contain a different number of elements
            (a zip-based comparison would silently drop the surplus).
    """
    preds = np.asarray(preds).ravel()
    labels = np.asarray(labels).ravel()
    if preds.size != labels.size:
        raise ValueError(
            f"preds and labels must have the same length, got {preds.size} and {labels.size}"
        )
    return preds, labels


def b_tp(preds, labels):
    """Return True Positives (TP): count of examples predicted 1 whose actual class is 1."""
    preds, labels = _flat_pair(preds, labels)
    return int(np.count_nonzero((preds == 1) & (labels == 1)))


def b_fp(preds, labels):
    """Return False Positives (FP): count of examples predicted 1 whose actual class is not 1."""
    preds, labels = _flat_pair(preds, labels)
    return int(np.count_nonzero((preds == 1) & (labels != 1)))


def b_tn(preds, labels):
    """Return True Negatives (TN): count of examples predicted 0 whose actual class is 0."""
    preds, labels = _flat_pair(preds, labels)
    return int(np.count_nonzero((preds == 0) & (labels == 0)))


def b_fn(preds, labels):
    """Return False Negatives (FN): count of examples predicted 0 whose actual class is not 0."""
    preds, labels = _flat_pair(preds, labels)
    return int(np.count_nonzero((preds == 0) & (labels != 0)))


def b_metrics(preds, labels):
    """Compute binary-classification metrics for one batch.

    Args:
        preds: per-class scores/logits of shape (N, num_classes); the
            predicted class is the argmax over axis 1.
        labels: N integer class labels (any shape; flattened).

    Returns:
        (accuracy, precision, recall, specificity), where
          - accuracy    = (TP + TN) / N
          - precision   = TP / (TP + FP)
          - recall      = TP / (TP + FN)
          - specificity = TN / (TN + FP)
        Precision, recall and specificity are the string 'nan' when their
        denominator is 0; the training loop skips values equal to 'nan'.

    Raises:
        ValueError: if there are no examples, or preds and labels differ in length.
    """
    preds = np.argmax(preds, axis=1).flatten()
    labels = np.asarray(labels).flatten()
    if labels.size == 0:
        raise ValueError("b_metrics needs at least one example")
    tp = b_tp(preds, labels)
    tn = b_tn(preds, labels)
    fp = b_fp(preds, labels)
    fn = b_fn(preds, labels)
    b_accuracy = (tp + tn) / labels.size
    b_precision = tp / (tp + fp) if (tp + fp) > 0 else 'nan'
    b_recall = tp / (tp + fn) if (tp + fn) > 0 else 'nan'
    b_specificity = tn / (tn + fp) if (tn + fp) > 0 else 'nan'
    return b_accuracy, b_precision, b_recall, b_specificity

## Training

In [None]:
def _select_device():
    """Return the preferred torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps:0")
    return torch.device("cpu")


def _format_mean(label, values):
    """Return a report line with the mean of *values*, or 'NaN' if *values* is empty."""
    if not values:
        return f'\t - {label}: NaN'
    return f'\t - {label}: {sum(values) / len(values):.4f}'


def train(train_ds, eval_ds, model, epochs, device=None):
    """Fine-tune *model* on *train_ds*, reporting validation metrics after each epoch.

    Args:
        train_ds: iterable of training batches (e.g. a DataLoader from getdl).
            Each batch must support ``.to(device)`` and provide
            ``input_ids``, ``attention_mask`` and ``labels``.
        eval_ds: iterable of validation batches in the same format.
        model: sequence classifier whose output's element 0 is the logits
            when called without labels (true for the
            BertForSequenceClassification models, with or without return_dict).
        epochs: number of passes over *train_ds*.
        device: torch device or device string. ``None`` picks CUDA, then
            Apple MPS, then CPU. The model is moved there in place.

    After every epoch this prints the mean training loss and the mean per-batch
    validation accuracy, precision, recall and specificity.
    """
    device = _select_device() if device is None else torch.device(device)

    # Move the model first so the optimizer is built over on-device parameters.
    model = model.to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    # Stepped once per batch below, so T_0 counts optimizer steps, not epochs.
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=500, eta_min=1e-6)

    for e in range(1, epochs + 1):
        print(f'Epoch: {e}')

        # ---- Training ----------------------------------------------------
        model.train()
        batch_losses = []
        for batch in train_ds:
            batch = batch.to(device)
            b_labels = batch["labels"]
            optimizer.zero_grad()
            # Called without labels so the model does not compute a second,
            # unused loss; element 0 of the output is then the logits.
            logits = model(batch["input_ids"],
                           token_type_ids=None,
                           attention_mask=batch["attention_mask"])[0]
            bloss = criterion(logits, b_labels)
            bloss.backward()
            optimizer.step()
            scheduler.step()
            batch_losses.append(bloss.item())

        # ---- Validation --------------------------------------------------
        model.eval()
        val_accuracy, val_precision, val_recall, val_specificity = [], [], [], []
        for batch in eval_ds:
            batch = batch.to(device)
            with torch.no_grad():
                eval_output = model(batch["input_ids"],
                                    token_type_ids=None,
                                    attention_mask=batch["attention_mask"])
            # Index instead of `.logits`: return_dict=False models return a tuple.
            logits = eval_output[0].detach().cpu().numpy()
            label_ids = batch["labels"].cpu().numpy()

            b_accuracy, b_precision, b_recall, b_specificity = b_metrics(logits, label_ids)
            val_accuracy.append(b_accuracy)
            # b_metrics reports 'nan' when a ratio's denominator is 0; skip
            # those so they do not enter the averages.
            if b_precision != 'nan':
                val_precision.append(b_precision)
            if b_recall != 'nan':
                val_recall.append(b_recall)
            if b_specificity != 'nan':
                val_specificity.append(b_specificity)

        print('\n' + _format_mean('Train loss', batch_losses))
        print(_format_mean('Validation Accuracy', val_accuracy))
        print(_format_mean('Validation Precision', val_precision))
        print(_format_mean('Validation Recall', val_recall))
        print(_format_mean('Validation Specificity', val_specificity) + '\n')
        

In [None]:
# Execute Training: E/I classifier.
epochs = 5
train(train_dl_EI, val_dl_EI, model=baselineModel_EI, epochs=epochs)

In [None]:
# Execute Training: S/N classifier.
epochs = 5
train(train_dl_SN, val_dl_SN, model=baselineModel_SN, epochs=epochs)

In [None]:
# Execute Training: T/F classifier.
epochs = 5
train(train_dl_TF, val_dl_TF, model=baselineModel_TF, epochs=epochs)

In [None]:
# Execute Training: J/P classifier (validation loader is val_dl_JP, built by getdl above).
epochs = 5
train(train_dl_JP, val_dl_JP, baselineModel_JP, epochs)