In [None]:
!pip install iterative_train_test_split tqdm mlflow

In [None]:
import pandas as pd
import numpy as np
import sys, os
import yaml, logging
import mlflow
from mlflow.tracking import MlflowClient
import mlflow.pytorch
from skmultilearn.model_selection import iterative_train_test_split
import transformers
import torch
import pytorch_lightning as pl  
from torch.utils.data import DataLoader, Dataset, random_split
from sklearn.metrics import classification_report, f1_score
from tqdm import tqdm
from typing import Optional, Union

# 1) Data Processing

In [None]:
def split_data(
    df:pd.DataFrame, 
    aspect_classes:list, 
    x_col:list=['text', 'label'], 
    test_size:float=0.2, seed:int=0
    ):
    '''
    Split `df` into a train frame and a test frame.

    The `x_col` columns are passed as the feature matrix and the `aspect_classes`
    columns as the label matrix to `iterative_train_test_split`. NumPy's global
    RNG is seeded with `seed` right before the call.

    Returns:
        (train, test): two DataFrames holding only the `x_col` columns. The label
        halves produced by the splitter are discarded.
    '''
    np.random.seed(seed)
    features = df[x_col].values
    targets = df[aspect_classes].values
    train_x, _train_y, test_x, _test_y = iterative_train_test_split(
        features, targets, test_size=test_size
    )
    train_frame = pd.DataFrame(train_x, columns=x_col)
    test_frame = pd.DataFrame(test_x, columns=x_col)
    return train_frame, test_frame


class DataModule(pl.LightningDataModule):
    '''
    Lightning data module that tokenizes raw text frames and serves DataLoaders.

    The train frame feeds `train_dataloader` and the test frame feeds
    `val_dataloader`. Both frames must contain `text_col` and `label_col`.
    '''

    def __init__(
        self, df_train:pd.DataFrame, df_test:pd.DataFrame, max_len:int, batch_size:int, 
        tokenizer:str="distilbert-base-uncased", text_col:str='text',
        label_col:str='label', num_workers:int=10
        ):
        '''
        Store the frames/settings and load the tokenizer.

        Args:
            df_train: frame used by `train_dataloader`.
            df_test: frame used by `val_dataloader`.
            max_len: maximum token length passed to the tokenizer (with truncation).
            batch_size: batch size for both loaders.
            tokenizer: "distilbert-base-uncased" selects the DistilBERT tokenizer;
                any other value selects the "vinai/bertweet-base" tokenizer.
            text_col: name of the column holding the raw text.
            label_col: name of the column holding the per-row target.
            num_workers: worker processes used by both DataLoaders.
        '''
        super().__init__()
        self.train_df = df_train
        self.test_df = df_test
        self.max_len = max_len
        self.batch_size = batch_size
        self.text_col = text_col
        self.label_col = label_col
        self.num_workers = num_workers

        # Resolve the logger here so the class does not depend on a notebook-level
        # `logger` global having been created by some earlier cell.
        log = logging.getLogger(__name__)
        if tokenizer == "distilbert-base-uncased":
            log.info("Applying Distillbert Tokenizer")
            self.tokenizer = transformers.DistilBertTokenizer.from_pretrained(
                "distilbert-base-uncased")
        else:
            log.info("Applying Bertweet Tokenizer")
            self.tokenizer = transformers.BertweetTokenizer.from_pretrained(
                "vinai/bertweet-base", normalization=True)

    class Dataset(Dataset):
        '''Pairs tokenizer encodings with labels; exposes `.encodings` and `.labels`.'''

        def __init__(self, encodings, labels):
            self.encodings = encodings
            self.labels = labels

        def __getitem__(self, idx):
            # `as_tensor` accepts values that are already tensors (the loaders below
            # request `return_tensors='pt'`) without the copy-construct warning that
            # `torch.tensor(tensor)` emits; `clone()` keeps each item independent.
            item = {
                key: torch.as_tensor(val[idx]).clone().detach()
                for key, val in self.encodings.items()
                }
            item['labels'] = torch.tensor(self.labels[idx])
            return item

        def __len__(self):
            return len(self.labels)

    def _make_dataloader(self, df, padding, shuffle):
        '''Tokenize `df[text_col]`, attach `df[label_col]` and wrap both in a DataLoader.'''
        features = self.tokenizer(
            df[self.text_col].tolist(),
            max_length=self.max_len,
            truncation=True,
            padding=padding,
            return_tensors='pt')
        labels = df[self.label_col].tolist()
        dataset = self.Dataset(features, labels)
        return DataLoader(
            dataset, batch_size=self.batch_size, shuffle=shuffle,
            num_workers=self.num_workers)

    def train_dataloader(self):
        '''Return a shuffled DataLoader over the train frame, padded to `max_len`.'''
        return self._make_dataloader(self.train_df, padding='max_length', shuffle=True)

    def val_dataloader(self):
        '''Return an unshuffled DataLoader over the test frame.'''
        # The whole frame is tokenized in one call, so 'longest' pads every row to the
        # longest sequence in the entire test set (capped at max_len), not per batch.
        return self._make_dataloader(self.test_df, padding='longest', shuffle=False)

    def calculate_pos_weights(self, class_counts, len_data):
        '''
        Compute per-class negative/positive ratios.

        Args:
            class_counts: number of positive rows per class (array-like).
            len_data: total number of rows.

        Returns:
            1-D float tensor with `(len_data - count) / (count + 1e-5)` per class.
        '''
        # Work in float: counts usually arrive as integers, and an integer output
        # buffer would silently truncate the ratios (e.g. 0.33 -> 0).
        pos_counts = np.asarray(class_counts, dtype=float)
        neg_counts = len_data - pos_counts
        # The epsilon keeps classes with zero positives from dividing by zero.
        pos_weights = neg_counts / (pos_counts + 1e-5)
        return torch.as_tensor(pos_weights, dtype=torch.float)

    def get_weight(self, df, aspect_classes):
        '''Return `calculate_pos_weights` for the `aspect_classes` columns of `df`.'''
        return self.calculate_pos_weights(
             df[aspect_classes].sum().values, 
             len(df)
         )

In [None]:
INPUT = 'sample.csv'
X_COL = ['text']
Y_COL = ['ability', 'dependability', 'purpose', 'integrity']
LABEL_COL = 'label'  # column DataModule reads targets from; must match its default

BATCH_SIZE = 32
INPUT_MAX_LEN = 128
STANDARD_LR = 5e-5
FINE_LR = 5e-7
EPOCHS = 20
LIMIT_STEP = 500
MODEL = "vinai/bertweet-base"
TEXT_COL = 'text'

# Notebook-wide logger: the classes and cells below call `logger.info(...)`, and
# nothing else in the notebook creates it. INFO level makes those messages visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Pick the reader from the file extension; calling read_parquet on a CSV file fails.
input_ext = os.path.splitext(INPUT)[1].lower()
if input_ext in ('.parquet', '.pq'):
    df = pd.read_parquet(INPUT)
else:
    df = pd.read_csv(INPUT)

# DataModule reads one target per row from LABEL_COL. When the input does not ship
# that column, build it as the multi-hot vector of the Y_COL indicator columns.
# NOTE(review): if LABEL_COL already exists it is used as-is -- confirm that it holds
# numeric vectors (a CSV round-trip would turn lists into strings).
if LABEL_COL not in df.columns:
    df = df.assign(**{LABEL_COL: pd.Series(df[Y_COL].values.tolist(), index=df.index)})

# split_data only returns the columns it is given, so the label column has to travel
# with the text column; dict.fromkeys de-duplicates while preserving order.
split_cols = list(dict.fromkeys(X_COL + [LABEL_COL]))
df_train, df_val = split_data(df, Y_COL, split_cols, test_size=0.2, seed=0)

data_module = DataModule(
    df_train,
    df_val,
    max_len=INPUT_MAX_LEN, 
    batch_size=BATCH_SIZE,
    tokenizer=MODEL,
    text_col=TEXT_COL
)

# 2) Model Training

In [None]:
class LightningArticleClassifier(pl.LightningModule):
    '''
    Transformer encoder followed by a four-layer feed-forward head.

    `forward` returns raw logits (no sigmoid). The loss is `BCEWithLogitsLoss`, and
    accuracy counts a row as correct only when every class prediction matches.

    Per-epoch results are appended to `epoch_loss_train`, `epoch_acc_train`,
    `epoch_loss_val` and `epoch_acc_val` as 0-dim CPU tensors.
    '''

    # Probability cut-off used to binarise predictions when computing accuracy.
    _ACC_THRESHOLD = 0.5

    def __init__(
        self, output_class_len, learning_rate, 
        max_len=64, hidden_dim=64, pos_weight=None, bert_model="distilbert-base-uncased"
        ):
        '''
        Args:
            output_class_len: number of output logits (one per class).
            learning_rate: learning rate handed to Adam in `configure_optimizers`.
            max_len: stored on the instance; not used by the model itself.
            hidden_dim: width of the last hidden layer of the head.
            pos_weight: optional per-class positive weights for the BCE loss.
            bert_model: "distilbert-base-uncased" loads DistilBERT; any other value
                loads "vinai/bertweet-base".
        '''
        super().__init__()
        self.max_len = max_len
        self.lr = learning_rate
        # Input width of the head; it has to match the hidden size of the encoder.
        self.emb_dim = 768
        self.hidden_dim = hidden_dim
        self.drop_out = torch.nn.Dropout(0.1)
        self.fc1 = torch.nn.Linear(self.emb_dim, self.emb_dim // 2)
        self.fc2 = torch.nn.Linear(self.emb_dim // 2, self.hidden_dim * 4)
        self.fc3 = torch.nn.Linear(self.hidden_dim * 4, self.hidden_dim)
        self.fc4 = torch.nn.Linear(self.hidden_dim, output_class_len)
        self.tanh = torch.nn.Tanh()
        self.gelu = torch.nn.GELU()
        self.softmax = torch.nn.LogSoftmax(dim=1)  # kept for compatibility; unused in forward

        if pos_weight is not None:
            # as_tensor accepts lists, arrays and tensors alike; clone() makes sure we
            # never alias a tensor owned by the caller.
            self.pos_weight = torch.as_tensor(pos_weight, dtype=torch.float).clone().detach()
        else:
            self.pos_weight = None

        # Resolve the logger here so the class does not depend on a notebook-level
        # `logger` global having been created by some earlier cell.
        log = logging.getLogger(__name__)
        if bert_model == "distilbert-base-uncased":
            log.info("Importing Distillbert Model")
            self.bert_model = transformers.DistilBertModel.from_pretrained("distilbert-base-uncased")
        else:
            log.info("Importing Bertweet Model")
            self.bert_model = transformers.AutoModel.from_pretrained("vinai/bertweet-base")

        # Running totals for the current epoch (reset in the *_epoch_end hooks).
        self.val_loss, self.val_corrects, self.val_len = 0., 0., 0.
        self.train_loss, self.train_corrects, self.train_len = 0., 0., 0.
        self.train_step, self.val_step = 0, 0
        # Per-epoch history.
        self.epoch_loss_train, self.epoch_acc_train,  self.epoch_f1_train = [], [], []
        self.epoch_loss_val, self.epoch_acc_val, self.epoch_f1_val = [], [], []

    def forward(self, input_ids, attention_mask):
        '''Return logits of shape (batch, output_class_len) for the tokenized inputs.'''
        bert_output = self.bert_model(
            input_ids=input_ids, 
            attention_mask=attention_mask, 
            output_attentions=False, 
            output_hidden_states=False)

        # Hidden state of the first token of every sequence feeds the head.
        output = bert_output['last_hidden_state'][:, 0, :]
        output = self.drop_out(self.tanh(self.fc1(output)))
        output = self.drop_out(self.gelu(self.fc2(output)))
        output = self.drop_out(self.gelu(self.fc3(output)))
        # No activation here: BCEWithLogitsLoss applies the sigmoid internally.
        return self.fc4(output)

    def criterion(self, y_pred, y_true):
        '''Binary cross-entropy on logits, optionally weighted by `pos_weight`.'''
        if self.pos_weight is not None:
            # Follow the device of the logits instead of assuming CUDA.
            loss_fn = torch.nn.BCEWithLogitsLoss(
                pos_weight=self.pos_weight.to(y_pred.device))
        else:
            loss_fn = torch.nn.BCEWithLogitsLoss()
        return loss_fn(y_pred, y_true.float())

    def _shared_step(self, batch):
        '''Run one batch; return (loss, number of fully-correct rows, batch size).'''
        labels = batch['labels']
        logits = self(batch['input_ids'], batch['attention_mask'])
        loss = self.criterion(logits, labels)
        with torch.no_grad():
            # Threshold probabilities rather than raw logits, so the accuracy uses
            # the same decision rule as sigmoid-based inference.
            preds = (torch.sigmoid(logits) >= self._ACC_THRESHOLD).int()
            n_correct = torch.sum(torch.all(torch.eq(preds, labels), dim=1).int())
        return loss, n_correct, len(labels)

    @staticmethod
    def _epoch_mean(total, count):
        '''Return total / count as a 0-dim CPU tensor (count is clamped to >= 1).'''
        return torch.as_tensor(total / max(count, 1)).detach().cpu()

    def training_step(self, train_batch, batch_idx):
        '''Compute the training loss for one batch and update the running totals.'''
        loss, n_correct, n_rows = self._shared_step(train_batch)

        # detach(): accumulating the live loss would keep every step's autograd
        # graph alive until the end of the epoch.
        self.train_loss += loss.detach()
        self.train_corrects += n_correct
        self.train_len += n_rows
        self.train_step += 1
        self.log('train_loss', loss)

        return loss

    def training_epoch_end(self, out):
        '''Store this epoch's mean loss / accuracy and reset the running totals.'''
        # Loss is a per-batch mean, so it is averaged over steps; accuracy is a count
        # of rows, so it is averaged over rows.
        self.epoch_acc_train.append(self._epoch_mean(self.train_corrects, self.train_len))
        self.epoch_loss_train.append(self._epoch_mean(self.train_loss, self.train_step))

        if self.current_epoch % 2 == 0:
            log = logging.getLogger(__name__)
            log.info(f'\nEpoch: {self.current_epoch}')
            log.info(f'Training: loss: {self.epoch_loss_train[-1]}')
            log.info(f'Training: Accuracy: {self.epoch_acc_train[-1]}')

        self.train_loss, self.train_corrects = 0., 0.
        self.train_step, self.train_len = 0., 0.

    def validation_step(self, val_batch, batch_idx):
        '''Compute the validation loss for one batch and update the running totals.'''
        loss, n_correct, n_rows = self._shared_step(val_batch)

        self.val_loss += loss.detach()
        self.val_corrects += n_correct
        self.val_len += n_rows
        self.val_step += 1
        self.log('val_loss', loss)
        return loss

    def validation_epoch_end(self, out):
        '''Store this epoch's mean validation loss / accuracy and reset the totals.'''
        self.epoch_acc_val.append(self._epoch_mean(self.val_corrects, self.val_len))
        self.epoch_loss_val.append(self._epoch_mean(self.val_loss, self.val_step))

        if self.current_epoch % 2 == 0:
            log = logging.getLogger(__name__)
            log.info(f'Validation: loss: {self.epoch_loss_val[-1]}')
            log.info(f'Validation: Accuracy: {self.epoch_acc_val[-1]}')

        self.val_loss, self.val_corrects = 0., 0.
        self.val_step, self.val_len = 0., 0.

    def configure_optimizers(self):
        '''Return an Adam optimizer over all parameters, using the current `self.lr`.'''
        # `self.lr` is read at call time, so changing it between two `fit` calls
        # changes the learning rate of the next run.
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [None]:
FINE_TUNE_EPOCHS = 5


def _make_trainer(max_epochs):
    '''Build a Trainer with the shared multi-GPU settings for `max_epochs` epochs.'''
    return pl.Trainer(
        max_epochs=max_epochs,
        limit_train_batches=LIMIT_STEP,
        accelerator="gpu",
        strategy="dp",
        gpus=-1,
    )


def _set_backbone_trainable(classifier, trainable):
    '''Freeze (False) or unfreeze (True) every parameter of the transformer encoder.'''
    for param in classifier.bert_model.parameters():
        param.requires_grad = trainable


model = LightningArticleClassifier(
    # One logit per target column (Y_COL is the list of aspect classes).
    output_class_len=len(Y_COL),
    learning_rate=STANDARD_LR,
    max_len=INPUT_MAX_LEN,
    bert_model=MODEL
)

# Stage 1: train only the classification head on top of a frozen encoder.
_set_backbone_trainable(model, False)
trainer = _make_trainer(EPOCHS)
trainer.fit(model, data_module)

### fine tuning
# Stage 2 (optional): unfreeze the encoder and continue with the smaller learning rate.
if FINE_LR is not None:
    model.lr = FINE_LR
    _set_backbone_trainable(model, True)
    trainer = _make_trainer(FINE_TUNE_EPOCHS)
    trainer.fit(model, data_module)

# 3) Model Logging

In [None]:
def log_model(model_name, model, params, metrics, artifacts=None, experiment_uri=''):
    '''
    Log a PyTorch model, its state dict, params, metrics and artifacts to MLflow.

    Args:
        model_name: used as the run name and as the artifact path for both the
            model and its state dict.
        model: the PyTorch module to log.
        params: dict of parameter name -> value.
        metrics: dict of metric name -> numeric value.
        artifacts: optional iterable of local file paths to upload.
        experiment_uri: experiment name; when empty (or None) the currently
            active experiment is used.
    '''
    # Select the target experiment, creating it first when it does not exist yet.
    if experiment_uri:
        if not mlflow.get_experiment_by_name(experiment_uri):
            mlflow.create_experiment(experiment_uri)
        mlflow.set_experiment(experiment_uri)

    # The context manager ends the run on exit, so no explicit end_run() is needed.
    with mlflow.start_run(run_name=model_name) as run:
        print("Experiment ID", run.info.experiment_id)
        mlflow.pytorch.log_model(model, model_name)
        mlflow.pytorch.log_state_dict(model.state_dict(), model_name)

        # Batch APIs: one call each instead of one call per key.
        if params:
            mlflow.log_params(params)
        if metrics:
            mlflow.log_metrics(metrics)
        for artifact in artifacts or ():
            mlflow.log_artifact(artifact)

In [None]:
# NOTE(review): `cfg` is not created anywhere in this notebook. It is presumably a
# dict loaded from a YAML config with the keys 'mlflow_email', 'model_name' and
# 'model_desc' -- confirm and load it in the configuration cell.
experiment_uri = f"/Users/{cfg['mlflow_email']}/{cfg['model_name']}"
logging.getLogger(__name__).info(f"Logging to MlFlow in : {experiment_uri}")
model_name = cfg['model_name']

params = {
    'description': cfg['model_desc'],
    'epochs': EPOCHS,
    'max_sequence_length': INPUT_MAX_LEN,
    'batch_size': BATCH_SIZE,
    'max_step_per_epoch': LIMIT_STEP,
    'lr': STANDARD_LR,
}


def _last_epoch_value(history):
    '''Return the most recent entry of `history` as a float rounded to 3 decimals.'''
    # float() accepts both 0-dim tensors and plain Python numbers.
    return round(float(history[-1]), 3)


model_metrics = {
  "train_loss" : _last_epoch_value(model.epoch_loss_train),
  "train_acc" : _last_epoch_value(model.epoch_acc_train),
  "eval_loss" : _last_epoch_value(model.epoch_loss_val),
  "eval_acc" : _last_epoch_value(model.epoch_acc_val),
}

log_model(model_name, model, params, model_metrics, experiment_uri=experiment_uri)

# 4) Validation

In [None]:
### model validation
# Rebuild the data module so the validation set is tokenized from scratch.
data_module = DataModule(
    df_train,
    df_val,
    max_len=INPUT_MAX_LEN, 
    batch_size=BATCH_SIZE,
    tokenizer=MODEL,
    text_col=TEXT_COL
)

df_test = data_module.test_df
test_set = data_module.val_dataloader()
# Use the already-tokenized tensors directly and slice them in fixed-size chunks.
batch = test_set.dataset.encodings
total_len = len(batch['input_ids'])
step = 512  # rows per forward pass

# Run inference on whatever device the model currently lives on.
device = next(model.parameters()).device

probas = []
model.eval()
with torch.no_grad():
    for i in tqdm(range(0, total_len, step)):
        outputs = model(
            input_ids=batch['input_ids'][i : i + step].to(device),
            attention_mask=batch['attention_mask'][i : i + step].to(device)
        )
        # The model returns logits; sigmoid turns them into per-class probabilities.
        probas.append(torch.sigmoid(outputs).cpu().numpy())

thres = 0.5
result = np.vstack(probas)
result = (result >= thres).astype(int)
print(
    classification_report(np.array(test_set.dataset.labels), result, target_names=Y_COL)
)