# Setup

In [None]:
# %%capture
!pip install torch
!pip install transformers # for BERT
!pip install pytorch-lightning
!pip install 'lightning[extra]'

Reference

https://www.youtube.com/watch?v=vNKIg8rXK6w&t=3959s

In [None]:
%reload_ext autoreload
%autoreload 2
import sys
sys.path.append('../')
from src.config import *
import pandas as pd

# Product names that get their own output class. The position of a name in
# this list is the index of its class; anything else falls into a trailing
# catch-all class.
top_25_products = [
    'Petrel',
    'RE (Petrel RE, DELFI RE, ECLIPSE, INTERSECT, ODRS, FluidModeler)',
    'Delfi Portal',
    'Techlog',
    'RTDS',
    'ProdOps, Avocet, PDF',
    'OFM',
    'PIPESIM, IAM',
    'Deployment',
    'OLGA',
    'DrillPlan',
    'Studio',
    'Edge',
    'RigHour',
    'Storage, File Management, Secure Data Exchange',
    'MERAK',
    'Omega, VISTA, OMNI3D',
    'Symmetry',
    'ProSource, InnerLogix',
    'PetroMod',
    'DrillOps',
    'GeoX',
    'InterACT Inside',
    '3rd Party',
    'Data Science',
]


def get_product_labels(product_name):
    """One-hot encode a product name.

    Args:
        product_name: value looked up in ``top_25_products``.

    Returns:
        A list of ``len(top_25_products) + 1`` ints holding a single 1: at the
        product's list position when it is a known product, otherwise in the
        last (catch-all) slot.
    """
    catch_all_slot = len(top_25_products)
    hot_slot = (
        top_25_products.index(product_name)
        if product_name in top_25_products
        else catch_all_slot
    )
    one_hot = [0] * (catch_all_slot + 1)
    one_hot[hot_slot] = 1
    return one_hot

# Spreadsheets holding the train and test splits.
# DATA_FOLDER_PATH_PROCESSED is not defined in this notebook; presumably it is
# provided by the star-import of src.config -- TODO confirm.
train_path = f'{DATA_FOLDER_PATH_PROCESSED}/data_train.xlsx'
test_path = f'{DATA_FOLDER_PATH_PROCESSED}/data_test.xlsx'

# Experiment settings shared by the dataset, data module, model and trainer cells.
config = {
    'model_name': 'distilroberta-base',  # checkpoint name given to AutoTokenizer / AutoModel
    # 'model_name': 'roberta-base',
    'top_n_product': 25,  # named product classes; the classifier adds one extra output on top
    'text_column': 'Title_Translated',  # spreadsheet column fed to the tokenizer
    'label_column': 'Product Name',  # spreadsheet column that is one-hot encoded
    'min_title_len': 4,  # titles with fewer whitespace-separated words are dropped
    'max_title_len': 16,  # token budget; the dataset pads/truncates to this value + 2
    'max_sample': 6000,  # labels with more rows than this are down-sampled to this many
    'batch_size': 64,  # batch size of every DataLoader
    'lr' : 1.5e-6,  # learning rate passed to the optimizer
    'warmup': 0.1,  # fraction of the scheduler's total steps spent warming up
    'w_decay': 0.001,  # weight decay passed to the optimizer
    'train_size': None,  # filled in by the Train cell once the train split is built
    'n_epochs': 200  # upper bound on epochs given to the Trainer
}

# Dataset

In [None]:
from torch.utils.data import Dataset
import torch

class SMAX_Dataset(Dataset):
    """Map-style dataset serving tokenised titles with one-hot product labels.

    The spreadsheet at ``data_path`` is read once at construction time and is
    then optionally filtered (minimum word count) and balanced (per-label cap).

    Args:
        data_path: path of the Excel file to read.
        text_column: name of the column holding the text to tokenise.
        label_column: name of the column holding the product name.
        tokenizer: callable Hugging Face tokenizer.
        max_len: token budget for the text, excluding the two special tokens.
        min_len: if truthy, rows whose text has fewer whitespace-separated
            words are dropped.
        max_sample: if truthy, labels with more rows than this are randomly
            down-sampled (fixed seed) to exactly this many rows.

    Attributes:
        data: the prepared ``pandas.DataFrame``.
        label_cols: names of the one-hot label columns added to ``data``.
        has_labels: whether the file contained ``label_column``.
    """

    def __init__(self, data_path, text_column, label_column, tokenizer, max_len, min_len=None, max_sample=None):
        self.data_path = data_path
        self.text_column = text_column
        self.label_column = label_column
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.min_len = min_len
        self.max_sample = max_sample
        # One column per known product plus the trailing catch-all slot; derived
        # from the product list so it always matches get_product_labels().
        self.label_cols = [f'label_{i}' for i in range(len(top_25_products) + 1)]
        self.has_labels = False
        self._prepare_data()

    def _prepare_data(self):
        """Read the spreadsheet, apply the filters and attach one-hot labels."""
        df = pd.read_excel(self.data_path)

        # Drop short titles. Missing / non-string cells count as zero words
        # instead of crashing on ``float.split``.
        if self.min_len:
            word_counts = df[self.text_column].fillna('').astype(str).str.split().str.len()
            df = df[word_counts >= self.min_len]

        self.has_labels = self.label_column in df.columns

        # Cap over-represented labels. Iterating the groups (rather than
        # ``groupby.apply``) keeps the label column in every pandas version.
        # Note: groupby skips rows whose label is NaN (pandas default), exactly
        # as the apply-based implementation did.
        if self.max_sample and self.has_labels:
            kept = [
                group.sample(n=self.max_sample, random_state=42) if len(group) > self.max_sample else group
                for _, group in df.groupby(self.label_column)
            ]
            df = pd.concat(kept).reset_index(drop=True) if kept else df.iloc[0:0]

        # One-hot encode the product name in a single frame construction; this
        # also works when no rows are left.
        if self.has_labels:
            one_hot = pd.DataFrame(
                [get_product_labels(name) for name in df[self.label_column]],
                index=df.index,
                columns=self.label_cols,
            )
            df = pd.concat([df.drop(columns=self.label_cols, errors='ignore'), one_hot], axis=1)

        self.data = df

    def __len__(self):
        """Number of rows left after filtering and down-sampling."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenise row ``idx``.

        Returns:
            dict with 1-D ``input_ids`` and ``attention_mask`` tensors of length
            ``max_len + 2`` and, when the file has a label column, a float32
            ``labels`` tensor with one entry per label column.
        """
        row = self.data.iloc[idx]

        encoded = self.tokenizer(
            str(row[self.text_column]),
            add_special_tokens=True,
            max_length=self.max_len + 2,  # +2 for the two special tokens wrapped around the text
            padding='max_length',
            truncation=True,
            return_tensors='pt',
            return_attention_mask=True)

        sample = {
            'input_ids': encoded['input_ids'].flatten(),
            'attention_mask': encoded['attention_mask'].flatten(),
        }
        # Unlabelled files (pure inference) simply yield no 'labels' entry.
        if self.has_labels:
            sample['labels'] = torch.tensor(row[self.label_cols].to_numpy(dtype='float32'))
        return sample

## TEST

In [None]:
from transformers import AutoTokenizer
# Build both splits with the same tokenizer and filter settings, then show the
# split sizes and one encoded sample as a quick sanity check.
model_name = config['model_name']
tokenizer = AutoTokenizer.from_pretrained(model_name)

train_dataset, test_dataset = (
    SMAX_Dataset(
        split_path,
        config['text_column'],
        config['label_column'],
        tokenizer,
        config['max_title_len'],
        config['min_title_len'],
        config['max_sample'],
    )
    for split_path in (train_path, test_path)
)

print(*(len(split) for split in (train_dataset, test_dataset)))

train_dataset[10]

# Data Module

In [None]:
import pytorch_lightning as pl
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, random_split

class SMAX_DataModule(pl.LightningDataModule):
    """Lightning data module wrapping the train / validation / test datasets.

    The training spreadsheet is filtered and balanced by ``SMAX_Dataset`` and
    then split into train and validation subsets with a fixed seed. The test
    spreadsheet is loaded without filtering or down-sampling, so it yields one
    sample per row of the file.

    Args:
        train_path: Excel file used for the train / validation split.
        test_path: Excel file used for prediction.
        text_column: name of the text column.
        label_column: name of the product-name column.
        max_len: token budget passed to ``SMAX_Dataset``.
        min_len: minimum word count applied to the training file only.
        max_sample: per-label cap applied to the training file only.
        model_name: checkpoint whose tokenizer is loaded.
        batch_size: batch size of every DataLoader.
    """

    # Share of the training file used for fitting; the remainder is validation.
    TRAIN_FRACTION = 0.8
    # Seed of the train / validation split, fixed so the split is reproducible.
    SPLIT_SEED = 42

    def __init__(self, train_path, test_path, text_column, label_column, max_len, min_len=None, max_sample=None, model_name='distilroberta-base', batch_size=16):
        super().__init__()
        self.train_path = train_path
        self.test_path = test_path
        self.text_column = text_column
        self.label_column = label_column
        self.max_len = max_len
        self.min_len = min_len
        self.max_sample = max_sample
        self.model_name = model_name
        self.batch_size = batch_size
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # Populated lazily by setup(); None means "not built yet".
        self.train_ds = None
        self.val_ds = None
        self.test_ds = None

    def setup(self, stage=None):
        """Build the datasets required for ``stage``.

        ``stage=None`` builds everything. Datasets that already exist are
        reused, so the repeated ``setup`` calls Lightning makes do not re-read
        the spreadsheets.
        """
        if stage in (None, 'fit', 'validate') and self.train_ds is None:
            full_ds = SMAX_Dataset(self.train_path, self.text_column, self.label_column, self.tokenizer, self.max_len, self.min_len, self.max_sample)
            # Split train and val
            train_size = int(self.TRAIN_FRACTION * len(full_ds))
            val_size = len(full_ds) - train_size
            self.train_ds, self.val_ds = random_split(
                full_ds,
                [train_size, val_size],
                generator=torch.Generator().manual_seed(self.SPLIT_SEED),
            )

        if stage in (None, 'test', 'predict') and self.test_ds is None:
            # No filtering / down-sampling here: one sample per row of the file.
            self.test_ds = SMAX_Dataset(self.test_path, self.text_column, self.label_column, self.tokenizer, self.max_len, min_len=None, max_sample=None)

    def train_dataloader(self):
        """Shuffled loader over the training subset."""
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True, num_workers=0)

    def val_dataloader(self):
        """Ordered loader over the validation subset."""
        return DataLoader(self.val_ds, batch_size=self.batch_size, shuffle=False, num_workers=0)

    def test_dataloader(self):
        """Ordered loader over the test file."""
        return DataLoader(self.test_ds, batch_size=self.batch_size, shuffle=False, num_workers=0)

    def predict_dataloader(self):
        """Ordered loader over the test file, used by ``trainer.predict``."""
        return DataLoader(self.test_ds, batch_size=self.batch_size, shuffle=False, num_workers=0)

# Model

In [None]:
import math
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.functional import accuracy
from transformers import AutoModel, AdamW, get_cosine_schedule_with_warmup

class SMAX_Classifier(pl.LightningModule):
    """Pretrained encoder followed by a two-layer classification head.

    The head is ``Linear -> Dropout -> ReLU -> Linear`` on top of the
    mask-aware mean of the encoder's last hidden states, trained with
    ``BCEWithLogitsLoss`` against one-hot label vectors.

    Args:
        config: settings dict. Reads ``model_name``, ``top_n_product``, ``lr``,
            ``w_decay``, ``warmup``, ``batch_size``, ``train_size`` and
            (optionally) ``n_epochs``.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        # One output per named product plus the catch-all class.
        self.num_classes = config['top_n_product'] + 1

        self.pretrained_model = AutoModel.from_pretrained(self.config['model_name'], return_dict=True)
        hidden_size = self.pretrained_model.config.hidden_size

        self.hidden = nn.Linear(hidden_size, hidden_size)
        torch.nn.init.xavier_uniform_(self.hidden.weight)

        self.classifier = nn.Linear(hidden_size, self.num_classes)
        torch.nn.init.xavier_uniform_(self.classifier.weight)
        self.loss_func = nn.BCEWithLogitsLoss(reduction='mean')
        self.dropout = nn.Dropout()

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the encoder and the classification head.

        Returns:
            ``(loss, logits)``; ``loss`` is the integer 0 when ``labels`` is None.
        """
        # pretrained layer
        output = self.pretrained_model(input_ids=input_ids, attention_mask=attention_mask)
        token_states = output.last_hidden_state

        # Mean-pool over real tokens only. Inputs are padded to a fixed length,
        # so an unmasked mean would mix padding positions into every sentence
        # vector.
        mask = attention_mask.unsqueeze(-1).to(token_states.dtype)
        pooled_output = (token_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

        # classification layer
        pooled_output = self.hidden(pooled_output)
        pooled_output = self.dropout(pooled_output)
        pooled_output = F.relu(pooled_output)
        logits = self.classifier(pooled_output)

        # loss
        loss = 0
        if labels is not None:
            loss = self.loss_func(logits.view(-1, self.num_classes), labels.view(-1, self.num_classes))

        return loss, logits

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss, logits = self(**batch)

        self.log('train_loss', loss, prog_bar=True, logger=True)
        # Only 'loss' needs the autograd graph; detach the logits so the
        # returned dict does not keep the graph alive.
        return {'loss': loss, 'predictions': logits.detach(), 'labels': batch['labels']}

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        loss, logits = self(**batch)

        self.log('val_loss', loss, prog_bar=True, logger=True)
        return {'val_loss': loss, 'predictions': logits.detach(), 'labels': batch['labels']}

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """Return the raw logits for one batch."""
        _, logits = self(**batch)
        return logits

    # https://huggingface.co/docs/transformers/main_classes/optimizer_schedules
    def configure_optimizers(self):
        """AdamW with a cosine schedule and linear warm-up, stepped per batch.

        Raises:
            ValueError: if ``config['train_size']`` has not been set.
        """
        train_size = self.config.get('train_size')
        if not train_size:
            raise ValueError("config['train_size'] must be set to the number of training samples before fitting")

        # torch's AdamW replaces the deprecated transformers implementation;
        # eps=1e-6 keeps that implementation's default.
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.config['lr'],
            weight_decay=self.config['w_decay'],
            eps=1e-6,
        )

        # The schedule counts optimizer steps over the whole run, not over one
        # epoch, so size it as batches-per-epoch times epochs.
        steps_per_epoch = math.ceil(train_size / self.config['batch_size'])
        total_steps = steps_per_epoch * self.config.get('n_epochs', 1)
        warmup_steps = math.floor(total_steps * self.config['warmup'])
        scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

        # interval='step' makes Lightning advance the schedule after every
        # batch; its default would advance it only once per epoch.
        return [optimizer], [{'scheduler': scheduler, 'interval': 'step', 'frequency': 1}]

# Train

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

# Seed Python, NumPy and torch so weight initialisation and batch shuffling
# are reproducible from a fresh kernel.
pl.seed_everything(42, workers=True)

# Data Module
print('Preparing data module...')
smax_data_module = SMAX_DataModule(
    train_path,
    test_path,
    config['text_column'],
    config['label_column'],
    config['max_title_len'],
    min_len=config['min_title_len'],
    max_sample=config['max_sample'],
    model_name=config['model_name'],
    batch_size=config['batch_size'],
)
# Only the train / validation split is needed to size the LR schedule.
smax_data_module.setup('fit')

# The model's scheduler needs the number of training samples, which is only
# known once the split exists.
train_size = len(smax_data_module.train_ds)
config['train_size'] = train_size

print(f'Train size: {train_size}')

# Model
smax_model = SMAX_Classifier(config)

# Callbacks

# Keep only the checkpoint with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath=MODEL_FOLDER_PATH,
    filename='smax-{epoch:02d}-{val_loss:.2f}',
    save_top_k=1,
    mode='min',
)

# Stop once validation loss has not improved for three consecutive checks.
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0.00,
    patience=3,
    verbose=False,
    mode='min'
)

# Trainer
trainer = pl.Trainer(
    max_epochs=config['n_epochs'],
    num_sanity_val_steps=50,
    callbacks=[
        checkpoint_callback,
        early_stop_callback
        ],
    )
trainer.fit(smax_model, smax_data_module)

# Predict with model

In [None]:
import numpy as np
from sklearn import metrics
import matplotlib.pyplot as plt

# Run the prediction loop and convert the logits into per-class probabilities.
def classify_raw_comments(model, dm, predict_trainer=None):
    """Predict class probabilities for the data module's prediction set.

    Args:
        model: module handed to ``predict``.
        dm: data module handed to ``predict`` as ``datamodule``.
        predict_trainer: object exposing ``predict(model, datamodule=...)``.
            When omitted, the notebook-level ``trainer`` is used.

    Returns:
        ``numpy.ndarray`` of dtype float32 with one row per sample and one
        column per class, holding the sigmoid of the predicted logits.
    """
    runner = trainer if predict_trainer is None else predict_trainer
    logit_batches = runner.predict(model, datamodule=dm)
    # Concatenate whole batches at once instead of copying row by row.
    logits = torch.cat([torch.as_tensor(batch) for batch in logit_batches], dim=0)
    return torch.sigmoid(logits.detach().float()).cpu().numpy()
predictions = classify_raw_comments(smax_model, smax_data_module)

test_data = pd.read_excel(test_path)
label_column = config['label_column']

# Predictions are compared with the file row by row, so the counts must match.
assert len(predictions) == len(test_data), (
    f'{len(predictions)} predictions for {len(test_data)} test rows'
)

label_cols = top_25_products + ['Others']

# Ground-truth names. Products outside the known list are folded into 'Others'
# so they are comparable with the predicted names, which use 'Others' for the
# catch-all class.
pred_true = test_data[label_column].where(test_data[label_column].isin(top_25_products), 'Others')
pred = [label_cols[i] for i in np.argmax(predictions, axis=1)]

# One-hot ground truth, built in one pass and also stored on the frame under
# the product names.
true_labels = np.array(
    [get_product_labels(x) for x in test_data[label_column]]
).reshape(-1, len(label_cols))
test_data[label_cols] = pd.DataFrame(true_labels, index=test_data.index, columns=label_cols)

# print the per-class precision / recall / F1 report
print(metrics.classification_report(pred_true, pred))


# One ROC curve per class, labelled with its AUC.
plt.figure(figsize=(32, 16))
for i, attribute in enumerate(label_cols):
    class_truth = true_labels[:, i].astype(int)
    # ROC AUC is undefined when a class has no positives (or no negatives) in
    # the test set; roc_auc_score raises in that case, so skip the class.
    if np.unique(class_truth).size < 2:
        print(f'Skipping ROC for {attribute!r}: only one class present in the test set')
        continue
    fpr, tpr, _ = metrics.roc_curve(class_truth, predictions[:, i])
    auc = metrics.roc_auc_score(class_truth, predictions[:, i])
    plt.plot(fpr, tpr, label='%s %g' % (attribute, auc))
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend(loc='lower right')
plt.title(f"{config['model_name']} trained on SMAX Datatset - AUC ROC")
plt.show()



In [None]:
%load_ext tensorboard
%tensorboard --logdir ./lightning_logs/

# Test

In [None]:
# TEST
# Smoke test: rebuild the data module, inspect one training sample and push it
# through the model once.

smax_data_module = SMAX_DataModule(
    train_path,
    test_path,
    config['text_column'],
    config['label_column'],
    config['max_title_len'],
    min_len=config['min_title_len'],
    max_sample=config['max_sample'],
    model_name=config['model_name'],
    batch_size=config['batch_size'],
)
# Request both stages explicitly: the train split comes from 'fit' and the
# test dataset from 'predict'.
smax_data_module.setup('fit')
smax_data_module.setup('predict')
smax_train_ds = smax_data_module.train_ds
smax_test_ds = smax_data_module.test_ds


print(len(smax_train_ds))
print(len(smax_test_ds))

idx = 100
# Fetch the sample once instead of re-tokenising it for every field.
sample = smax_train_ds[idx]
input_ids = sample['input_ids']
attention_mask = sample['attention_mask']
labels = sample['labels']
print(input_ids)
print(attention_mask)
print(labels)

# Single forward pass on a batch of one. Dropout is disabled and no gradients
# are tracked; inputs are moved to wherever the model currently lives.
smax_model.eval()
with torch.no_grad():
    loss, output = smax_model(
        input_ids.unsqueeze(0).to(smax_model.device),
        attention_mask.unsqueeze(0).to(smax_model.device),
        labels.unsqueeze(0).to(smax_model.device),
    )