In [6]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import pandas as pd

from transformers import BertTokenizer, BertForSequenceClassification

import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

import wandb

## Data Loader & Tokenizer

In [9]:
class ClassifierDataset(Dataset):
    """Map-style dataset that tokenizes one text row per item.

    Args:
        dataframe: pandas DataFrame exposing a ``text`` and a ``label`` column.
            Labels are converted with ``dtype=torch.long``, so they must
            already be integer class ids.
        tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
            padding='max_length', truncation=True, ...)`` that returns a mapping
            with ``input_ids`` and ``attention_mask`` entries.
        max_len: Length every sequence is padded/truncated to.
    """

    def __init__(self, dataframe, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.data = dataframe
        self.text = dataframe.text
        self.labels = dataframe.label
        self.max_len = max_len

    def __len__(self):
        """Return the number of rows in the text column."""
        return len(self.text)

    def __getitem__(self, index):
        """Return the tokenized sample at position ``index``.

        Returns:
            dict with ``ids`` (token ids), ``mask`` (attention mask) and
            ``labels`` (class id), all ``torch.long`` tensors.

        Raises:
            IndexError: if ``index`` is outside ``range(len(self))``.
            ValueError: if the label at ``index`` is missing (NaN/None).
        """
        # Positional lookup: samplers hand out 0..len-1 positions, which only
        # coincide with the frame's index labels for a pristine RangeIndex.
        raw_text = self.text.iloc[index]
        label = self.labels.iloc[index]

        if pd.isna(label):
            # A NaN would otherwise be cast to a meaningless long value.
            raise ValueError(f"Missing label at position {index}")

        text = " ".join(str(raw_text).split())  # Collapse runs of whitespace

        # https://huggingface.co/docs/transformers/v4.34.0/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.__call__
        encoded = self.tokenizer(
            text,
            add_special_tokens=True,  # Add '[CLS]' and '[SEP]'
            max_length=self.max_len,  # Target length for padding/truncation
            padding='max_length',     # Always pad up to max_length
            truncation=True,          # Cut anything longer than max_length
        )

        return {
            'ids': torch.tensor(encoded['input_ids'], dtype=torch.long),
            'mask': torch.tensor(encoded['attention_mask'], dtype=torch.long),
            'labels': torch.tensor(label, dtype=torch.long),
        }

class BERTClassifier(pl.LightningModule):
    """Lightning wrapper around a 3-class multilingual BERT sequence classifier.

    Args:
        train_df: DataFrame handed to ``get_dataloader`` for training batches.
        val_df: DataFrame handed to ``get_dataloader`` for validation batches.
        tokenizer: Tokenizer forwarded to ``get_dataloader``.
        hparams: Mapping of hyperparameters. Recognised optional keys are
            ``max_len`` (default 100), ``batch_size`` (default 32) and
            ``lr`` (default 1e-5); any other keys are stored untouched.
    """

    def __init__(self, train_df, val_df, tokenizer, hparams):
        super().__init__()

        # Save hyperparameters
        self.hparams.update(hparams or {})
        self.__configure_from_hyperparams()

        self.train_df = train_df
        self.val_df = val_df
        self.tokenizer = tokenizer

        self.model = BertForSequenceClassification.from_pretrained('bert-base-multilingual-cased', num_labels=3)

    def forward(self, ids, mask):
        """Return the classification logits for token ids and attention mask."""
        output = self.model(ids, attention_mask=mask)
        return output.logits

    def training_step(self, batch, batch_nb):
        """Compute the training loss and log ``train/accuracy`` and ``train/loss``."""
        _, loss, accuracy = self.__get_preds_loss(batch)

        self.log(
            'train/accuracy',
            accuracy,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.log('train/loss', loss)

        return loss

    def validation_step(self, batch, batch_nb):
        """Compute the validation loss and log ``val/accuracy`` and ``val/loss``."""
        _, loss, accuracy = self.__get_preds_loss(batch)

        self.log(
            'val/accuracy',
            accuracy,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.log('val/loss', loss)

        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using the resolved ``lr``."""
        return Adam(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        """Build the training DataLoader (``get_dataloader`` shuffles by default)."""
        # Use the resolved attributes, not self.hparams.<key>, so the defaults
        # apply when a key was left out of the hyperparameter mapping.
        return get_dataloader(self.train_df, self.tokenizer, self.max_len, batch_size=self.batch_size)

    def val_dataloader(self):
        """Build the validation DataLoader with shuffling disabled."""
        return get_dataloader(self.val_df, self.tokenizer, self.max_len, batch_size=self.batch_size, shuffle=False)

    def __configure_from_hyperparams(self):
        """Resolve optional hyperparameters to attributes, applying defaults."""
        self.max_len = self.hparams.get("max_len", 100)
        self.batch_size = self.hparams.get("batch_size", 32)
        self.lr = self.hparams.get("lr", 1e-5)

    def __get_preds_loss(self, batch):
        """Run a batch through the model.

        Args:
            batch: Mapping with ``ids``, ``mask`` and ``labels`` tensors.

        Returns:
            Tuple ``(logits, loss, accuracy)`` where ``loss`` is the
            cross-entropy loss and ``accuracy`` is the batch accuracy as a
            scalar tensor.
        """
        labels = batch['labels']
        preds = self(batch['ids'], batch['mask'])
        loss = torch.nn.functional.cross_entropy(preds, labels)

        # Keep accuracy as a tensor: avoids a device->host sync on every step.
        accuracy = (preds.argmax(dim=1) == labels).float().mean()

        return preds, loss, accuracy

    
# Load data and return DataLoader
def get_dataloader(df, tokenizer, max_len, batch_size=32, shuffle=True):
    """Wrap a labelled DataFrame in a ``ClassifierDataset`` and a DataLoader.

    The input frame is not modified: encoding happens on a copy, so the
    function can be called repeatedly on the same frame.

    Args:
        df: DataFrame with ``text`` and ``label`` columns. Labels may be the
            strings ``'other'``, ``'question'``, ``'concern'`` or the
            corresponding integer ids 0, 1, 2.
        tokenizer: Tokenizer forwarded to ``ClassifierDataset``.
        max_len: Padding/truncation length forwarded to ``ClassifierDataset``.
        batch_size: Number of samples per batch.
        shuffle: Whether the DataLoader shuffles samples.

    Returns:
        A ``torch.utils.data.DataLoader`` over the encoded samples.

    Raises:
        ValueError: if a label is neither a known name nor a known id.
    """
    label_mapping = {'other': 0, 'question': 1, 'concern': 2}
    known_ids = set(label_mapping.values())

    def encode(value):
        if value in label_mapping:
            return label_mapping[value]
        if value in known_ids:
            # Already encoded, e.g. the frame went through an earlier mapping.
            return int(value)
        raise ValueError(f"Unknown label: {value!r}")

    encoded_df = df.copy()
    encoded_df['label'] = encoded_df['label'].map(encode)
    # A clean 0..n-1 index keeps positional and label-based lookups in agreement.
    encoded_df = encoded_df.reset_index(drop=True)

    dataset = ClassifierDataset(encoded_df, tokenizer, max_len)

    # Create DataLoader
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0)

# Evaluate model
def evaluate_model(model, data_loader):
    """Compute classification accuracy of ``model`` over ``data_loader``.

    The model is switched to eval mode and left in that mode on return.

    Args:
        model: Module called as ``model(ids, attention_mask=mask)`` whose
            result exposes a ``logits`` attribute.
        data_loader: Iterable of batches, each a mapping with ``ids``,
            ``mask`` and ``labels`` tensors.

    Returns:
        Fraction of correctly predicted labels as a float; ``0.0`` when the
        loader yields no samples.
    """
    model.eval()

    # Batches come off the loader on CPU; move them to wherever the model is.
    first_param = next(iter(model.parameters()), None) if hasattr(model, 'parameters') else None
    device = first_param.device if first_param is not None else None

    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():
        for data in data_loader:
            ids, mask, labels = data['ids'], data['mask'], data['labels']
            if device is not None:
                ids, mask, labels = ids.to(device), mask.to(device), labels.to(device)

            outputs = model(ids, attention_mask=mask)
            predicted = outputs.logits.argmax(dim=1)

            correct_predictions += (predicted == labels).sum().item()
            total_predictions += labels.size(0)

    if total_predictions == 0:
        # Nothing to score; avoid dividing by zero.
        return 0.0
    return correct_predictions / total_predictions

In [10]:
# WandB initialization
wandb.login()

# Seed every RNG up front so the freshly initialised classifier head and the
# shuffled batches are reproducible across "Restart & Run All".
SEED = 42
pl.seed_everything(SEED)

# Initialize BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

# Load training and validation data
train_df = pd.read_csv('data/train.csv')
val_df = pd.read_csv('data/val.csv')

# Config
run_config = {
    'epochs': 5,
    'max_len': 100,
    'batch_size': 32
}

# Initialize model
model = BERTClassifier(train_df, val_df, tokenizer, hparams=run_config)

try:
    # Initialize WandbLogger
    wandb_logger = WandbLogger(entity='yvokeller', project='data-chatbot') # log_model='all'
    wandb_logger.experiment.config.update(run_config)

    # Initialize Trainer
    trainer = pl.Trainer(
        max_epochs=run_config.get('epochs'),
        logger=wandb_logger,
        log_every_n_steps=1,
        enable_progress_bar=True
    )

    # Train the model
    trainer.fit(model)
finally:
    # Close the WandB run even when training fails, so no run is left open
    wandb.finish()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type                          | Params
--------------------------------------------------------
0 | model | BertForSequenceClassification | 177 M 
--------------------------------------------------------
177 M     Trainable params
0         Non-trainable params
177 M     Total params
711.423   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.




0,1
epoch,▁▁▁▃▃▃▅▅▅▆▆▆███
train/accuracy,▁▁███
train/loss,█▇▃▃▁
trainer/global_step,▁▁▁▃▃▃▅▅▅▆▆▆███
val/accuracy,▁▁▁▁▁
val/loss,█▇▅▃▁

0,1
epoch,4.0
train/accuracy,0.66667
train/loss,0.90171
trainer/global_step,4.0
val/accuracy,0.33333
val/loss,1.06055


## Fine-Tune BERT Classifier

In [None]:
# BERT model initialization
model = BertForSequenceClassification.from_pretrained('bert-base-multilingual-cased', num_labels=3)

# Training batches: read the CSV again so this cell does not depend on how
# earlier cells may have altered `train_df`.
data_loader = get_dataloader(
    pd.read_csv('data/train.csv'),
    tokenizer,
    run_config['max_len'],
    batch_size=run_config['batch_size'],
)

# Optimizer and Loss function
optimizer = Adam(model.parameters(), lr=1e-5)
loss_function = torch.nn.CrossEntropyLoss()

# Fine-tuning loop
epochs = 3  # Replace with the number of epochs you want

# from_pretrained returns the model in eval mode; enable dropout for training
model.train()

for epoch in range(epochs):
    for data in data_loader:
        ids = data['ids']
        mask = data['mask']
        labels = data['labels']

        optimizer.zero_grad()

        # The dataset yields no token_type_ids, so none are passed here
        outputs = model(ids, attention_mask=mask, labels=labels)
        loss = outputs.loss
        loss.backward()

        optimizer.step()

print("Fine-tuning completed!")

## Evaluating the BERT Classifier

In [None]:
# Create DataLoader for test data
# get_dataloader expects a DataFrame, so the CSV is loaded first
test_df = pd.read_csv('data/test.csv')
test_loader = get_dataloader(test_df, tokenizer, run_config['max_len'], shuffle=False)

# Evaluate the model on test data
accuracy = evaluate_model(model, test_loader)
print(f"Test Accuracy: {accuracy}")

## Using the BERT Classifier

In [None]:
def classify_user_prompt(text, model, tokenizer, label_mapping):
    """Predict the label name for a single piece of text.

    Args:
        text: The prompt to classify.
        model: Module called as ``model(**inputs)`` whose result exposes a
            ``logits`` attribute.
        tokenizer: Callable invoked as ``tokenizer(text, padding=True,
            truncation=True, return_tensors="pt")`` returning a mapping of
            tensors.
        label_mapping: Mapping from label name to class index.

    Returns:
        The first label name in ``label_mapping`` whose index equals the
        predicted class, or ``None`` when no entry matches.
    """
    # Prepare the text into tokenized tensors
    inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")

    # Put the inputs on the model's device, if it has parameters
    first_param = next(iter(model.parameters()), None) if hasattr(model, 'parameters') else None
    if first_param is not None:
        inputs = {
            name: value.to(first_param.device) if hasattr(value, 'to') else value
            for name, value in inputs.items()
        }

    # Inference must run in eval mode (dropout off); restore the mode afterwards
    was_training = getattr(model, 'training', False)
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(**inputs)
    finally:
        model.train(was_training)

    # Get the predicted label index
    predicted_idx = outputs.logits.argmax(dim=1).item()

    # Convert the index to the corresponding label string
    return next(
        (label for label, idx in label_mapping.items() if idx == predicted_idx),
        None,
    )

In [None]:
# Example usage
# Label names and ids; must match the mapping used inside get_dataloader
label_mapping = {'other': 0, 'question': 1, 'concern': 2}

text = "Mir geht es schlecht, das Studium ist sehr anstrengend."
predicted_label = classify_user_prompt(text, model, tokenizer, label_mapping)
print(f"The predicted label for the text is: {predicted_label}")
