In [None]:
import torch
from tqdm.notebook import trange, tqdm
from transformers import *

### Select the compute device (GPU if available)

In [None]:
# Use CUDA when at least one GPU is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
n_gpu = torch.cuda.device_count()
# Print the name of every visible GPU (prints nothing on CPU-only machines).
for gpu_idx in range(n_gpu):
    print(torch.cuda.get_device_name(gpu_idx))

## Load Data (Dataframes / Dataloaders)

In [None]:
import pandas as pd

In [None]:
# Training split: keep every column from position 3 onward (these become the label columns).
df = pd.read_csv('../train.csv').iloc[:, 3:]
print(df.shape)
df.head()

In [None]:
# Dev split, trimmed to the same column range as the training frame.
df_dev = pd.read_csv('../dev.csv').iloc[:, 3:]
print(df_dev.shape)
df_dev.head()

In [None]:
# Test split. It is kept in its own variable so that the dev frame loaded above, and the
# dev label counts / co-occurrence matrix derived from it below, are not silently
# replaced by test data.
df_test = pd.read_csv('../test.csv')
df_test = df_test[df_test.columns[3:]]
print(df_test.shape)
df_test.head()

In [None]:
# Every remaining column of the training frame is a label.
label_cols = list(df.columns)
num_labels = len(label_cols)
# Batch size and max sequence length; both also select which saved dataloaders are loaded.
bs, max_length = 8, 512
num_labels

In [None]:
import inspect

# The files hold whole pickled dataloader objects (they are iterated as dataloaders in the
# training loop), not plain tensors. From PyTorch 2.6 torch.load defaults to
# weights_only=True, which refuses such objects, so request full unpickling explicitly
# whenever this torch version accepts the argument.
# SECURITY: unpickling can execute arbitrary code -- only load files you produced yourself.
_load_kwargs = (
    {'weights_only': False}
    if 'weights_only' in inspect.signature(torch.load).parameters
    else {}
)

# NOTE(review): the prefix is applied to the training loader only -- presumably to swap in
# a reduced training set while keeping the full dev/test sets; confirm.
small_prefix = ""
train_dataloader = torch.load(f'dataloaders/{small_prefix}train_data_loader-{bs}-{max_length}', **_load_kwargs)
validation_dataloader = torch.load(f'dataloaders/validation_data_loader-{bs}-{max_length}', **_load_kwargs)
test_dataloader = torch.load(f'dataloaders/test_data_loader-{bs}-{max_length}', **_load_kwargs)

## Target Probabilities Tensor Creation

In [None]:
# Number of training rows in which each label column is truthy (non-zero).
label_is_set = df.astype(bool)
counts = label_is_set.sum(axis=0).to_dict()
print(counts)

In [None]:
# Number of dev rows in which each label column is truthy (non-zero).
dev_label_is_set = df_dev.astype(bool)
counts_dev = dev_label_is_set.sum(axis=0).to_dict()
print(counts_dev)

In [None]:
def make_target_prob_tensor(counts: dict, dataframe):
    """Build the label co-occurrence ("target probability") matrix.

    Entry ``[i, j]`` is the number of rows in which both label ``i`` and label ``j`` equal 1,
    divided by ``counts[label_i]``. Rows whose count is 0 are left as all zeros. Labels are
    taken from ``counts`` in key order, so ``counts`` also fixes the row/column order.

    Args:
        counts: mapping label column -> number of rows where that label is set. The notebook
            passes ``dataframe.astype(bool).sum(axis=0).to_dict()``.
        dataframe: frame containing every key of ``counts`` as a column.

    Returns:
        float32 tensor of shape ``(len(counts), len(counts))``.
    """
    import numpy as np  # local import: numpy is otherwise only imported in a later cell

    columns = list(counts.keys())
    # 0/1 indicator matrix (rows x labels). X.T @ X counts, for every label pair, the rows
    # where both labels equal 1 -- one matrix product instead of L*L boolean frame scans.
    indicators = (dataframe[columns] == 1).to_numpy(dtype=np.float64)
    co_occurrence = indicators.T @ indicators
    denominators = np.array([counts[c] for c in columns], dtype=np.float64)[:, None]
    # Row-wise normalisation; labels that never occur keep a row of zeros.
    target_prob = np.divide(
        co_occurrence,
        denominators,
        out=np.zeros_like(co_occurrence),
        where=denominators != 0,
    )
    return torch.tensor(target_prob, dtype=torch.float32)

In [None]:
# Training-split co-occurrence matrix (labels x labels), placed on the compute device.
target_probs = make_target_prob_tensor(dataframe=df, counts=counts).to(device)
print(target_probs.shape)

In [None]:
# Dev-split co-occurrence matrix, placed on the same device as the training one.
target_probs_dev = make_target_prob_tensor(dataframe=df_dev, counts=counts_dev).to(device)

In [None]:
# Phase name -> dataloader, iterated by the training loop.
dataloaders = dict(
    train=train_dataloader,
    dev=validation_dataloader,
    test=test_dataloader,
)

In [None]:
# Phase name -> label co-occurrence matrix.
target_probabs = dict(
    train=target_probs,
    dev=target_probs_dev,
)

## Training the model

### Metrics

In [None]:
from sklearn.metrics import classification_report, confusion_matrix, multilabel_confusion_matrix, f1_score, accuracy_score, precision_score, recall_score

In [None]:
def print_results(method, f1, acc, precision, recall):
    """Print a labelled block of micro-averaged metrics for one evaluation method.

    The block starts with a blank line, then the method name, then F1, accuracy and
    micro precision/recall, each value rendered with ``str()``.
    """
    report_lines = [
        '\n' + method + ' :',
        'Micro F1-Score = ' + str(f1),
        'Accuracy = ' + str(acc),
        'Micro Avg : precision = ' + str(precision) + ' recall = ' + str(recall),
    ]
    print('\n'.join(report_lines))

In [None]:
# (Removed a duplicate, identical definition of print_results; the cell above defines it.)

In [None]:
def get_metrics(true_bools, pred_bools):
    """Compute micro-averaged multilabel metrics, all scaled to percentages.

    Args:
        true_bools: ground-truth label indicators, one row per sample (any multilabel
            indicator layout sklearn accepts).
        pred_bools: predicted label indicators with the same layout.

    Returns:
        ``(f1, acc, precision, recall)`` as percentages: micro F1, sklearn's accuracy
        (exact-match for multilabel input), micro precision and micro recall.
    """
    # Micro precision/recall come straight from sklearn. Previously they were pulled out of
    # a full classification_report, which computed every per-label row for nothing and
    # required the global `label_cols` to match the number of label columns.
    f1 = f1_score(true_bools, pred_bools, average='micro') * 100
    acc = accuracy_score(true_bools, pred_bools) * 100
    precision = precision_score(true_bools, pred_bools, average='micro', zero_division=0) * 100
    recall = recall_score(true_bools, pred_bools, average='micro', zero_division=0) * 100

    return f1, acc, precision, recall

### Preparing the model

In [None]:
model = BertForSequenceClassification.from_pretrained("bert-base-cased", num_labels=num_labels)
# Move to the device selected at the top of the notebook; model.cuda() would fail on
# CPU-only machines even though `device` already falls back to the CPU.
model.to(device)

In [None]:
from collections import OrderedDict
class DepClassifier(torch.nn.Module):
    """Linear multilabel head whose scores are mixed through a fixed label-dependency matrix.

    Input features are projected to one score per label, multiplied on the right by
    ``weight_tensor`` (output label ``j`` = sum_i score_i * weight_tensor[i, j]) and passed
    through a sigmoid. The outputs are therefore probabilities, not logits.

    Args:
        input_size: size of the last dimension of the input features (presumably BERT's
            pooled output when installed as ``model.classifier``).
        out_features: number of labels.
        weight_tensor: ``(out_features, out_features)`` matrix. It is not a parameter, so the
            optimizer never updates it.
    """

    def __init__(self, input_size, out_features, weight_tensor):
        super(DepClassifier, self).__init__()
        # Non-persistent buffer: it follows model.to(device)/.cuda() like the module's other
        # tensors, but it is not written to state_dict, so checkpoint keys stay unchanged.
        self.register_buffer('weight_tensor', weight_tensor, persistent=False)
        self.dense = torch.nn.Linear(in_features=input_size, out_features=out_features, bias=True)

    def forward(self, x):
        """Return ``sigmoid(dense(x) @ weight_tensor)``, shape ``(batch, out_features)``."""
        scores = self.dense(x)
        # (batch, L) @ (L, L) -> (batch, L). This equals the former bmm over a broadcast
        # (batch, L, L) copy of the matrix, without materialising that copy.
        mixed = torch.matmul(scores, self.weight_tensor)
        return torch.sigmoid(mixed)

In [None]:
# Replace BERT's default linear classifier with the label-dependency head built from the
# training-split co-occurrence matrix, then move the whole model to the selected device
# (model.cuda() would fail on CPU-only machines).
model.classifier = DepClassifier(
    input_size=model.config.hidden_size,  # 768 for bert-base-cased; read instead of hard-coded
    out_features=num_labels,
    weight_tensor=target_probabs['train'],
)
model.to(device)

### Loss function and optimizer

In [None]:
# setting custom optimization parameters. You may implement a scheduler here as well.
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'gamma', 'beta']
# single = ['dense_2', 'pooler']


#exclude the last layer parameter from optimizer
optimizer_grouped_parameters_classification = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)] ,
     'weight_decay_rate': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
     'weight_decay_rate': 0.0}
]

In [None]:
# Plain BCELoss (not BCEWithLogitsLoss): the classifier head already applies a sigmoid.
classification_criterion = torch.nn.BCELoss()
optimizer_classification = torch.optim.AdamW(params=optimizer_grouped_parameters_classification, lr=2e-5)

### Heatmap visualization function

In [None]:
import math
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import io

def get_heatmap(input_tensor, labels=None):
    """Render a square tensor as an annotated seaborn heatmap and return it as a PIL image.

    Args:
        input_tensor: 2-D tensor with at least ``len(labels)`` rows and columns; only the
            top-left ``len(labels) x len(labels)`` block is drawn.
        labels: axis labels. Defaults to the notebook-global ``label_cols``.

    Returns:
        ``PIL.Image.Image`` read from the PNG-rendered figure.

    Raises:
        ValueError: if the tensor is not 2-D or is smaller than ``len(labels)`` on either axis.
    """
    if labels is None:
        labels = label_cols
    n = len(labels)

    tensor = input_tensor.detach().cpu()
    if tensor.dim() != 2 or tensor.shape[0] < n or tensor.shape[1] < n:
        raise ValueError(f"expected a 2-D tensor of at least {n}x{n}, got shape {tuple(tensor.shape)}")
    values = tensor[:n, :n].tolist()

    # NOTE(review): the frame is transposed, matching the original nested-dict construction:
    # heatmap row j / column i shows input_tensor[i][j]. Confirm this orientation is intended.
    frame = pd.DataFrame(values, index=labels, columns=labels).T

    fig, ax = plt.subplots(figsize=(40, 30))
    sns.heatmap(frame, annot=True, fmt=".2f", linewidths=2, ax=ax)

    img_buf = io.BytesIO()
    fig.savefig(img_buf, format='png')
    plt.close(fig)  # close this figure specifically, not whichever one is "current"
    img_buf.seek(0)
    return Image.open(img_buf)

### Logging and Saving

In [None]:
# Run identifiers (used in checkpoint file names and the W&B config) and training length.
epochs = 30  # Number of training epochs
dataset_name = "dataset_name"
model_name = "your_model_name"

In [None]:
import numpy as np
import copy
import wandb

# Hyper-parameters recorded with the W&B run.
config = {
    "epochs": epochs,
    "batch_size": bs,
    "seq_max_length": max_length,
    "lr_cls": 2e-5,
    "optimizer": "AdamW",
    "wd": 0.01,
    "dataset": dataset_name,
}

# To run without logging, pass mode="disabled" to wandb.init below.
wandb.init(project="project_name", entity="your_entity", name="run_name", config=config)

In [None]:
# History buffers for per-epoch values. The training cell below logs metrics to W&B and
# does not fill these.
train_loss_set = np.array([])
train_classif_loss = np.array([])
train_dependency_loss = np.array([])
all_val_f1s = np.array([])
all_val_accs = np.array([])
all_val_precisions = np.array([])
all_val_recalls = np.array([])


# Snapshot of the current weights, copied to the CPU. A deepcopy of model.state_dict()
# would duplicate every parameter on the model's device (the GPU when available) for the
# whole run.
_current_state = model.state_dict()
best_model_wts = OrderedDict(
    (name, tensor.to('cpu', copy=True)) for name, tensor in _current_state.items()
)
# Preserve the version metadata that load_state_dict consults.
best_model_wts._metadata = getattr(_current_state, '_metadata', None)
best_val_f1 = -1.0

### Train

In [None]:
import os

# Decision threshold applied to the per-label sigmoid outputs.
threshold = 0.5
# torch.save does not create missing directories.
os.makedirs('state_dicts', exist_ok=True)

# trange is a tqdm wrapper around the normal python range
for epoch_num in trange(epochs, desc="Epoch", position=0):

    for phase in tqdm(['train', 'dev', 'test'], leave=False, desc='Phases', position=1):

        # Per-phase tracking variables
        true_labels, pred_labels = [], []  # one array per batch, concatenated for metrics
        cls_loss = 0.0  # running sum of batch losses
        epoch_steps = 0

        # Set the mode explicitly for every phase. Previously 'test' ran in eval mode only
        # because 'dev' happened to run just before it.
        model.train(phase == 'train')

        for batch in tqdm(dataloaders[phase], leave=False, desc=f"{phase.capitalize()} Dataloader", position=2):

            # Move the batch to the selected device and unpack it
            b_input_ids, b_input_mask, b_labels = (t.to(device) for t in batch)

            # Forward pass and loss; autograd history is recorded only while training
            with torch.set_grad_enabled(phase == 'train'):
                classification_logits = model(b_input_ids, attention_mask=b_input_mask)[0]
                loss = classification_criterion(classification_logits, b_labels.type_as(classification_logits))

            if phase == 'train':
                optimizer_classification.zero_grad()
                loss.backward()
                optimizer_classification.step()

            cls_loss += loss.item()
            epoch_steps += 1

            # Keep CPU copies of this batch's outputs for the epoch metrics. The former
            # per-step torch.cuda.empty_cache() was dropped: it only hands cached, unused
            # blocks back to the driver, which slows every step without lowering the
            # memory PyTorch actually needs.
            pred_labels.append(classification_logits.detach().cpu().numpy())
            true_labels.append(b_labels.cpu().numpy())

        if epoch_steps == 0:
            raise RuntimeError(f"The '{phase}' dataloader yielded no batches; cannot compute metrics.")

        # Epoch metrics on the stacked per-batch arrays (rows = samples)
        true_bools = np.concatenate(true_labels)
        pred_bools = np.concatenate(pred_labels) > threshold
        f1_accuracy, flat_accuracy, precision, recall = get_metrics(true_bools, pred_bools)

        # Mean loss over the phase's batches
        cls_loss = cls_loss / epoch_steps

        # Log Epoch Metrics (committed once per epoch, after all phases)
        metrics = {
            'F1_score': f1_accuracy,
            'Accuracy': flat_accuracy,
            'Precision': precision,
            'Recall': recall,
            'Cls_loss': cls_loss,
        }
        wandb.log({f'{phase.capitalize()}': metrics}, commit=False)

        # Save the model whenever the dev F1 improves
        if phase == 'dev' and f1_accuracy > best_val_f1:
            best_val_f1 = f1_accuracy
            torch.save(model.state_dict(), 'state_dicts/best_' + model_name + '.pt')

    # Commit this epoch's train/dev/test metrics as a single W&B step
    wandb.log(data={}, commit=True)

# save last model
torch.save(model.state_dict(), 'state_dicts/last_' + model_name + '.pt')