In [16]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import pandas as pd

from transformers import BertTokenizer, BertForSequenceClassification

import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

import wandb

# BERT Classifier

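A fine-tuned BERT (`bert-base-multilingual-cased`) classifier for the Data Chatbot that labels each text as `other`, `question`, or `concern`.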

**Relevant Resources**

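- W&B integration with PyTorch Lightning (`WandbLogger` arguments): https://docs.wandb.ai/guides/integrations/lightning#logger-arguments
- PyTorch Lightning hyperparameters guide (these are the v0.9.0 docs; the API has changed in later releases): https://pytorch-lightning.readthedocs.io/en/0.9.0/hyperparameters.html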

## Define Dataset & DataLoader

In [17]:
class ClassifierDataset(Dataset):
    """Map-style dataset that tokenizes one DataFrame row per item.

    Parameters:
    - dataframe (pd.DataFrame): must have a ``text`` column and an integer-encoded
      ``label`` column. Its index may be arbitrary (e.g. after filtering/shuffling).
    - tokenizer: Hugging Face-style tokenizer; called once per item and expected to
      return ``input_ids`` and ``attention_mask``.
    - max_len (int or None): length to pad/truncate each sequence to; None leaves the
      length to the tokenizer (the model's maximum).

    Each item is a dict of ``torch.long`` tensors: ``ids`` and ``mask`` (one entry per
    token position) and a scalar ``labels``.
    """

    def __init__(self, dataframe, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.data = dataframe
        # Drop the original index so position i always means row i; a filtered or
        # shuffled frame would otherwise make label-based lookup raise KeyError
        # (or silently return the wrong row).
        self.text = dataframe.text.reset_index(drop=True)
        self.labels = dataframe.label.reset_index(drop=True)
        self.max_len = max_len

    def __len__(self):
        return len(self.text)

    def __getitem__(self, index):
        text = str(self.text.iloc[index])
        text = " ".join(text.split())  # Collapse runs of whitespace and strip the ends

        # https://huggingface.co/docs/transformers/v4.34.0/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.__call__
        inputs = self.tokenizer(
            text,
            None,  # No second sequence (text_pair)
            add_special_tokens=True,  # Add '[CLS]' and '[SEP]', default True
            max_length=self.max_len,  # Target length for padding/truncation
            padding='max_length',  # Pad up to max_length
            truncation=True,  # Truncate down to max_length
        )
        ids = inputs['input_ids']  # Indices of input sequence tokens in the vocabulary
        mask = inputs['attention_mask']  # Mask to avoid attending to padding positions

        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
            'labels': torch.tensor(self.labels.iloc[index], dtype=torch.long),
        }
    

# Load data and return DataLoader
def get_dataloader(df, tokenizer, max_len=None, batch_size=32, shuffle=True, nobatch=False, label_mapping=None):
    """
    Loads data into a PyTorch DataLoader object.

    Parameters:
    - df (pd.DataFrame): The data frame containing the text and labels. It is not modified.
    - tokenizer (Tokenizer): The tokenizer to be used.
    - max_len (int, optional): The maximum length for the tokenized sequences. Defaults to None (model's limitation).
    - batch_size (int, optional): The size of each batch. Defaults to 32.
    - shuffle (bool, optional): Whether to shuffle the data. Defaults to True.
    - nobatch (bool, optional): Whether to disable batching. If True, the whole data frame is served as a
      single batch (batch_size = len(df), at least 1). Defaults to False.
    - label_mapping (dict, optional): Maps raw label values to class indices.
      Defaults to {'other': 0, 'question': 1, 'concern': 2}.

    Returns:
    - DataLoader: A PyTorch DataLoader object containing the tokenized data.

    Raises:
    - ValueError: If df contains a label (including a missing one) that is not a key of label_mapping.

    Notes:
    - Labels are mapped on a re-indexed copy of df, so calling this twice on the same frame is safe.
    """
    if label_mapping is None:
        label_mapping = {'other': 0, 'question': 1, 'concern': 2}

    # Work on a new, 0..n-1 indexed frame: mutating the caller's frame would make a second
    # call map the already-numeric labels to NaN, and positional lookups need a RangeIndex.
    df = df.reset_index(drop=True)

    mapped = df['label'].map(label_mapping)
    unknown = df.loc[mapped.isna(), 'label'].unique()
    if len(unknown) > 0:
        raise ValueError(f"Unknown labels (not in label_mapping): {list(unknown)}")
    df['label'] = mapped.astype('int64')

    dataset = ClassifierDataset(df, tokenizer, max_len)

    # nobatch: one batch containing every row (DataLoader rejects batch_size=0 for empty frames).
    batch_size = max(len(df), 1) if nobatch else batch_size
    print(f"DataLoader | No Batch: {nobatch}; Batch Size: {batch_size}")

    # Create DataLoader
    params = {'batch_size': batch_size, 'shuffle': shuffle, 'num_workers': 0}
    return DataLoader(dataset, **params)

## Define Model

In [18]:
class BERTClassifier(pl.LightningModule):
    """PyTorch Lightning module that fine-tunes ``BertForSequenceClassification``.

    Batches are expected to be dicts with ``'ids'``, ``'mask'`` and ``'labels'`` tensors,
    as produced by ``ClassifierDataset``.

    Parameters:
    - hparams (dict): run configuration. Keys read here (defaults in parentheses):
      ``max_len`` (100) and ``batch_size`` (32), exposed as attributes but not used
      inside this module; ``lr`` (1e-5), the Adam learning rate; ``model_name``
      ('bert-base-multilingual-cased') and ``num_labels`` (3), the pretrained
      checkpoint and number of output classes.
    """

    def __init__(self, hparams):
        super().__init__()

        # Unlike mutating self.hparams directly, save_hyperparameters() also records the
        # values for the logger and for LightningModule.load_from_checkpoint().
        self.save_hyperparameters(hparams)
        self.__configure_from_hyperparams()

        self.model = BertForSequenceClassification.from_pretrained(
            self.model_name, num_labels=self.num_labels
        )

    def forward(self, ids, mask):
        """Return the classification logits for token ids ``ids`` and attention mask ``mask``."""
        output = self.model(ids, attention_mask=mask)
        return output.logits

    def training_step(self, batch, batch_nb):
        return self.__step(batch, batch_nb, 'train')

    def validation_step(self, batch, batch_nb):
        return self.__step(batch, batch_nb, 'val')

    def configure_optimizers(self):
        """Adam over all parameters, using the ``lr`` hyperparameter (default 1e-5)."""
        return Adam(self.parameters(), lr=self.lr)

    def __configure_from_hyperparams(self):
        """Copy hyperparameters onto attributes, filling in defaults for missing keys."""
        self.max_len = self.hparams.get("max_len", 100)
        self.batch_size = self.hparams.get("batch_size", 32)
        self.lr = self.hparams.get("lr", 1e-5)
        self.model_name = self.hparams.get("model_name", 'bert-base-multilingual-cased')
        self.num_labels = self.hparams.get("num_labels", 3)

    def __step(self, batch, batch_idx, stage):
        """Shared train/val step: compute loss and accuracy and log both under ``{stage}/``."""
        _, loss, accuracy = self.__get_preds_loss_accuracy(batch)
        # Pass the batch size explicitly so Lightning does not have to infer it from the
        # dict batch when weighting the epoch-level averages.
        batch_size = batch['labels'].size(0)

        self.log(
            f'{stage}/accuracy',
            accuracy,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            batch_size=batch_size,
        )
        self.log(f'{stage}/loss', loss, batch_size=batch_size)

        return loss

    def __get_preds_loss_accuracy(self, batch):
        """Return ``(logits, cross-entropy loss, accuracy)`` for one batch.

        Accuracy stays a 0-d tensor (no ``.item()``), avoiding a device-to-host sync on
        every step; ``self.log`` accepts tensors.
        """
        labels = batch['labels']
        logits = self(batch['ids'], batch['mask'])
        loss = torch.nn.functional.cross_entropy(logits, labels)

        predicted = logits.argmax(dim=1)
        accuracy = (predicted == labels).float().mean()

        return logits, loss, accuracy

## Fine-tune BERT Model

In [19]:
# WandB initialization
wandb.login()

# Config (also passed to the model as its hyperparameters and logged to W&B)
run_config = {
    'epochs': 5,
    'max_len': 100,
    'batch_size': 32,
    'seed': 42,
}

# Seed Python, NumPy and torch RNGs so data shuffling, classifier-head initialisation and
# dropout are reproducible across Restart & Run All.
pl.seed_everything(run_config['seed'])

# Initialize BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

# Load training and validation data
train_df = pd.read_csv('data/train.csv')
train_dataloader = get_dataloader(
    train_df,
    tokenizer,
    max_len=run_config['max_len'],
    batch_size=run_config['batch_size'],
)

# Validation data is served in a fixed order as a single batch (nobatch=True).
val_df = pd.read_csv('data/val.csv')
val_dataloader = get_dataloader(
    val_df,
    tokenizer,
    max_len=run_config['max_len'],
    batch_size=run_config['batch_size'],
    shuffle=False,
    nobatch=True,
)

# Initialize model
model = BERTClassifier(hparams=run_config)

# Initialize WandbLogger
wandb_logger = WandbLogger(entity='yvokeller', project='data-chatbot')  # log_model='all'
wandb_logger.experiment.config.update(run_config)

# Initialize Trainer
trainer = pl.Trainer(
    max_epochs=run_config['epochs'],
    logger=wandb_logger,
    log_every_n_steps=1,
    enable_progress_bar=True,
)

# Train the model; always close the W&B run, even if training fails, so the next run
# starts cleanly instead of appending to a dangling one.
try:
    trainer.fit(model, train_dataloader, val_dataloader)
finally:
    wandb.finish()

DataLoader | No Batch: False; Batch Size: 32
DataLoader | No Batch: True; Batch Size: 3


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type                          | Params
--------------------------------------------------------
0 | model | BertForSequenceClassification | 177 M 
--------------------------------------------------------
177 M     Trainable params
0         Non-trainable params
177 M     Total params
711.423   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.




0,1
epoch,▁▁▁▃▃▃▅▅▅▆▆▆███
train/accuracy,▁▆▆██
train/loss,█▇▇▂▁
trainer/global_step,▁▁▁▃▃▃▅▅▅▆▆▆███
val/accuracy,▁▁███
val/loss,█▆▄▃▁

0,1
epoch,4.0
train/accuracy,1.0
train/loss,0.84402
trainer/global_step,4.0
val/accuracy,0.66667
val/loss,0.99867
