In [None]:
import os
import random
from typing import Optional

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from transformers.models.auto.configuration_auto import AutoConfig
from transformers.models.auto.modeling_auto import AutoModel
from transformers.models.auto.tokenization_auto import AutoTokenizer

In [None]:
# Hugging Face checkpoint used for both the tokenizer and the encoder.
model_name = "bert-base-uncased"

# Seed every RNG the notebook draws from so runs are reproducible:
# pandas ``DataFrame.sample`` without ``random_state`` uses NumPy's global RNG,
# while dropout and the freshly initialised classification head use torch.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
class Dataset(data.Dataset):
    """Map-style dataset over a ``<split>_data.csv`` file.

    Each item is the CSV row as a plain tuple of its column values, so the file
    is expected to hold exactly two columns: the sentence text followed by its
    sentiment label. The labels are presumably 0/1 for ``BCEWithLogitsLoss``
    (TODO confirm against the data). ``collate_fn`` turns a list of such tuples
    into a tokenized batch plus a float label tensor.
    """

    # Shared by all instances; loaded once when the class body executes.
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    def __init__(self, split, data_dir, shuffle: bool = True, seed: Optional[int] = None):
        """Load ``<data_dir>/<split>_data.csv``.

        Args:
            split: file prefix, e.g. ``"train"`` or ``"test"``.
            data_dir: directory containing the CSV files.
            shuffle: permute the rows once at load time (default: True).
            seed: ``random_state`` for that permutation; ``None`` draws from
                NumPy's global RNG.
        """
        self.df = pd.read_csv(os.path.join(data_dir, f'{split}_data.csv'))
        if shuffle:
            self.df = self.df.sample(frac=1, random_state=seed).reset_index(drop=True)

    def __getitem__(self, idx: int):
        """Return row ``idx`` (positional) as a tuple of column values."""
        return tuple(self.df.iloc[idx])

    def __len__(self):
        return len(self.df)

    @classmethod
    def collate_fn(cls, data):
        """Batch a list of ``(sentence, sentiment)`` tuples for the DataLoader.

        ``data`` (a list of dataset items) shadows the ``torch.utils.data`` module
        only inside this method.

        Returns:
            ``(encoded, labels)``: the tokenizer's padded ``BatchEncoding`` of
            PyTorch tensors and a ``float32`` label tensor of shape ``(batch,)``.
        """
        sentences, sentiments = zip(*data)
        # truncation=True clips each input to the tokenizer's model_max_length
        # (512 for BERT checkpoints); without it a long text overflows the
        # encoder's position embeddings and the forward pass fails.
        encoded = cls.tokenizer(
            list(sentences), padding=True, truncation=True, return_tensors="pt"
        )
        labels = torch.tensor(list(sentiments), dtype=torch.float32)
        return encoded, labels

In [None]:
class DataModule(pl.LightningDataModule):
    """Lightning data module serving the train and validation ``Dataset`` splits.

    NOTE(review): validation defaults to the ``"test"`` split, so early stopping
    is tuned on test data. Pass ``val_split`` if a separate validation CSV exists.
    """

    def __init__(self, data_dir, batch_size, num_workers: int = 0, val_split: str = "test"):
        """
        Args:
            data_dir: directory containing ``<split>_data.csv`` files.
            batch_size: rows per batch for both loaders.
            num_workers: DataLoader worker processes (0 = load in the main process).
            val_split: file prefix of the split used for validation.
        """
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split

    def setup(self, stage: Optional[str] = None):
        """Load both splits; ``stage`` is ignored, so every stage reads both CSVs."""
        self.train_dataset = Dataset("train", self.data_dir)
        self.val_dataset = Dataset(self.val_split, self.data_dir)

    def train_dataloader(self):
        # shuffle=True re-permutes the training rows every epoch. Dataset only
        # shuffles once at load time, so without this every epoch sees the same order.
        return data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            collate_fn=Dataset.collate_fn,
        )

    def val_dataloader(self):
        return data.DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=Dataset.collate_fn,
        )

In [None]:
class Model(pl.LightningModule):
    """Pretrained transformer encoder plus a small MLP head giving one logit per input.

    The head reads the last hidden state of the first token (``[CLS]`` for
    BERT-style tokenizers) and returns raw logits with no sigmoid.
    """

    def __init__(self, model_name, input_dropout: float = 0.2, hidden_dropout: float = 0.5):
        """
        Args:
            model_name: Hugging Face checkpoint name or path for the encoder.
            input_dropout: dropout applied to the first-token embedding.
            hidden_dropout: dropout applied before the final projection.
        """
        super().__init__()
        # Record the constructor arguments in checkpoints so
        # ``Model.load_from_checkpoint(path)`` can rebuild the module.
        self.save_hyperparameters()

        self.transformer = AutoModel.from_pretrained(model_name, add_pooling_layer=False)
        # Reuse the config loaded alongside the weights rather than fetching it again.
        self.config = self.transformer.config

        # Size the head from the encoder instead of hard-coding 768, so checkpoints
        # with other widths also work. For bert-base this is 768 -> 384 -> 1.
        hidden_size = self.config.hidden_size
        self.pooler = nn.Sequential(
            nn.Dropout(input_dropout),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.LeakyReLU(),
            nn.Dropout(hidden_dropout),
            nn.Linear(hidden_size // 2, 1),
        )

    def forward(self, sentences: dict[str, torch.Tensor]):
        """Return a ``(batch,)`` tensor of logits.

        ``sentences`` is unpacked as keyword arguments into the encoder, e.g. the
        ``BatchEncoding`` produced by ``Dataset.collate_fn``.
        """
        cls_token = self.transformer(**sentences).last_hidden_state[:, 0, :]
        return self.pooler(cls_token).squeeze(1)

In [None]:
class Train:
    """Training hooks (loss, logged metrics, optimizer) to be attached to ``Model``.

    The methods take ``self`` and expect a LightningModule that returns logits
    from ``forward`` and exposes ``transformer`` and ``pooler`` submodules.
    """

    criterion = nn.BCEWithLogitsLoss()
    transformer_lr = 1e-5
    pooler_lr = 3e-4

    weight_decay = 0.02
    default_lr = 3e-4

    def training_step(self, batch, batch_idx):
        """Return the BCE-with-logits loss for one batch and log it as ``train_loss``."""
        sentences, true_sentiments = batch
        pred_sentiments = self(sentences)
        loss = Train.criterion(pred_sentiments, true_sentiments)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and log it as ``val_loss``.

        ``EarlyStopping(monitor="val_loss")`` reads this logged value. If it is
        never logged, the callback has no metric to monitor and training aborts.
        """
        sentences, true_sentiments = batch
        pred_sentiments = self(sentences)
        loss = Train.criterion(pred_sentiments, true_sentiments)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Build an AdamW optimizer with per-submodule learning rates.

        Every trainable parameter of ``self.transformer`` and ``self.pooler``
        lands in exactly one group:
        - matrices and embeddings (``ndim >= 2``) get ``weight_decay``;
        - biases and LayerNorm scales (``ndim < 2``) get none.

        AdamW applies the decay directly to the weights. Adam would instead fold
        it into the gradient as an L2 term, which the adaptive step then rescales.
        """
        groups = []
        for module, lr in ((self.transformer, Train.transformer_lr),
                           (self.pooler, Train.pooler_lr)):
            params = [p for p in module.parameters() if p.requires_grad]
            groups.append({'params': [p for p in params if p.ndim >= 2],
                           'lr': lr, 'weight_decay': Train.weight_decay})
            groups.append({'params': [p for p in params if p.ndim < 2],
                           'lr': lr, 'weight_decay': 0.0})
        # Drop empty groups (e.g. a fully frozen submodule).
        groups = [g for g in groups if g['params']]
        return optim.AdamW(groups, lr=Train.default_lr)
    
    
# Attach the hooks defined on ``Train`` to ``Model`` so Lightning uses them as
# the module's training_step / validation_step / configure_optimizers.
(
    Model.training_step,
    Model.validation_step,
    Model.configure_optimizers,
) = (
    Train.training_step,
    Train.validation_step,
    Train.configure_optimizers,
)

In [None]:
# Effective batch size = batch_size * accumulate_grad_batches = 8 * 8 = 64.
datamodule = DataModule("Data/", batch_size=8)
model = Model(model_name)
trainer = pl.Trainer(
    # ``gpus=`` was deprecated in Lightning 1.7 and removed in 2.0;
    # accelerator/devices is the supported way to request one GPU.
    accelerator="gpu",
    devices=1,
    accumulate_grad_batches=8,
    max_epochs=50,
    # Requires the model to log "val_loss" during validation. Stops after
    # ``patience`` (default 3) validation runs without improvement.
    callbacks=[EarlyStopping(monitor="val_loss", mode="min")],
)

In [None]:
# Train, pulling train/val loaders from the DataModule (passed by keyword for clarity).
trainer.fit(model, datamodule=datamodule)