In [5]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import pandas as pd
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertForSequenceClassification
from torch.optim import Adam
import torch
from torch.utils.data import Dataset

## Data Loader & Tokenizer

In [6]:
class ClassifierDataset(Dataset):
    """Map-style dataset that tokenizes one text row at a time.

    Args:
        dataframe: pandas DataFrame with a ``text`` column and a ``label``
            column. Labels are converted with ``torch.long``, so they must
            already be integer class ids.
        tokenizer: Callable invoked as ``tokenizer(text, None, **kwargs)`` that
            returns a mapping with ``'input_ids'`` and ``'attention_mask'``.
        max_len: Value forwarded to the tokenizer as ``max_length``.

    Each item is a dict with ``'ids'``, ``'mask'`` and ``'labels'`` tensors of
    dtype ``torch.long``.
    """

    def __init__(self, dataframe, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.data = dataframe
        self.text = dataframe['text']
        self.labels = dataframe['label']
        self.max_len = max_len

    def __len__(self):
        return len(self.text)

    def __getitem__(self, index):
        # DataLoader samplers hand out positions in range(len(self)), so rows
        # must be looked up positionally. Plain ``series[index]`` is label-based
        # and breaks on frames whose index is not 0..n-1 (filtered, shuffled or
        # split data).
        text = str(self.text.iloc[index])
        text = " ".join(text.split())  # Collapse runs of whitespace

        # https://huggingface.co/docs/transformers/v4.34.0/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.__call__
        inputs = self.tokenizer(
            text,
            None,
            add_special_tokens=True,  # Add '[CLS]' and '[SEP]', default True
            max_length=self.max_len,  # Maximum length used by truncation/padding
            padding='max_length',  # Pad every sample up to max_length
            truncation=True,  # Cut samples longer than max_length
        )
        ids = inputs['input_ids']  # Vocabulary indices of the input tokens
        mask = inputs['attention_mask']  # Mask that hides padding from attention

        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
            'labels': torch.tensor(self.labels.iloc[index], dtype=torch.long)
        }

class TextClassifier(pl.LightningModule):
    """LightningModule wrapping multilingual BERT with a 3-class head.

    The module keeps the training/validation DataFrames and the tokenizer so
    that it can build its own DataLoaders through ``load_data``.

    Args:
        train_df: DataFrame handed to ``load_data`` for the training loader.
        val_df: DataFrame handed to ``load_data`` for the validation loader.
        tokenizer: Tokenizer forwarded to ``load_data``.
        max_len: Sequence length forwarded to ``load_data`` (default 100).
    """

    def __init__(self, train_df, val_df, tokenizer, max_len=100):
        super().__init__()
        self.train_df, self.val_df = train_df, val_df
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.model = BertForSequenceClassification.from_pretrained(
            'bert-base-multilingual-cased',
            num_labels=3,
        )

    def forward(self, ids, mask):
        """Return the classification logits for a batch of token ids."""
        return self.model(ids, attention_mask=mask).logits

    def training_step(self, batch, batch_nb):
        """Compute and log the training loss for one batch."""
        _, step_loss = self._get_preds_loss(batch)
        self.log('train_loss', step_loss)
        return {'loss': step_loss}

    def validation_step(self, batch, batch_nb):
        """Compute and log the validation loss for one batch."""
        _, step_loss = self._get_preds_loss(batch)
        self.log('val_loss', step_loss)
        return {'val_loss': step_loss}

    def configure_optimizers(self):
        """Optimise every parameter with Adam at a learning rate of 1e-5."""
        optimizer = Adam(self.parameters(), lr=1e-5)
        return optimizer

    def train_dataloader(self):
        """Build the training loader (``load_data`` shuffles by default)."""
        return load_data(self.train_df, self.tokenizer, self.max_len)

    def val_dataloader(self):
        """Build the validation loader without shuffling."""
        return load_data(self.val_df, self.tokenizer, self.max_len, shuffle=False)

    def _get_preds_loss(self, batch):
        """Run the forward pass and return ``(logits, cross-entropy loss)``."""
        logits = self(batch['ids'], batch['mask'])
        # Functional form of torch.nn.CrossEntropyLoss() with default settings.
        loss = torch.nn.functional.cross_entropy(logits, batch['labels'])
        return logits, loss

# Load data and return DataLoader
def load_data(df, tokenizer, max_len, batch_size=32, shuffle=True):
    """Encode string labels and wrap the frame in a DataLoader.

    Args:
        df: DataFrame with a ``text`` column and a ``label`` column holding
            one of ``'other'``, ``'question'`` or ``'concern'``. The frame is
            not modified.
        tokenizer: Tokenizer passed through to ``ClassifierDataset``.
        max_len: Sequence length passed through to ``ClassifierDataset``.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the data on each pass.

    Returns:
        A ``DataLoader`` over a ``ClassifierDataset`` built from the encoded
        copy of ``df``.

    Raises:
        ValueError: If the ``label`` column contains a value that is not a key
            of the label mapping (including missing values).
    """
    label_mapping = {'other': 0, 'question': 1, 'concern': 2}

    encoded = df['label'].map(label_mapping)
    unknown = encoded.isna()
    if unknown.any():
        # Unmapped values would otherwise become NaN and turn into garbage
        # class ids once converted to an integer tensor.
        bad = sorted({str(value) for value in df.loc[unknown, 'label'].unique()})
        raise ValueError(
            f"Unknown value(s) in 'label' column: {bad}; "
            f"expected one of {list(label_mapping)}"
        )

    # Work on a copy: overwriting the caller's column in place would make a
    # second call on the same frame map the already-encoded integers to NaN.
    # The positional index keeps row lookup by 0..n-1 valid for any input frame.
    encoded_df = df.assign(label=encoded.astype('int64')).reset_index(drop=True)
    dataset = ClassifierDataset(encoded_df, tokenizer, max_len)

    # Create DataLoader
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0)

# Evaluate model
def evaluate_model(model, data_loader):
    """Compute classification accuracy of ``model`` over ``data_loader``.

    Args:
        model: Module called as ``model(ids, attention_mask=mask)`` whose
            result exposes a ``logits`` attribute. It is switched to eval mode
            and left there.
        data_loader: Iterable of batches; each batch is a mapping with
            ``'ids'``, ``'mask'`` and ``'labels'`` tensors.

    Returns:
        Fraction of correctly predicted labels as a float, or ``0.0`` when the
        loader yields no samples.
    """
    model.eval()

    # Batches come from the loader on the CPU; move them to wherever the model
    # lives so evaluation also works after the model was placed on GPU/MPS.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():
        for data in data_loader:
            ids, mask, labels = data['ids'], data['mask'], data['labels']
            if device is not None:
                ids, mask, labels = ids.to(device), mask.to(device), labels.to(device)

            outputs = model(ids, attention_mask=mask)
            predicted = torch.argmax(outputs.logits, dim=1)

            correct_predictions += (predicted == labels).sum().item()
            total_predictions += labels.size(0)

    if total_predictions == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0
    return correct_predictions / total_predictions

In [7]:
# Seed the random number generators before anything stochastic happens: the
# classifier head is randomly initialised and the training data is shuffled,
# so without a seed every "Restart & Run All" gives different numbers.
SEED = 42
pl.seed_everything(SEED, workers=True)

# Initialize BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

# Load training and validation data
train_df = pd.read_csv('data/train.csv')
val_df = pd.read_csv('data/val.csv')

# Initialize model (the module builds its own DataLoaders from the frames)
model = TextClassifier(train_df, val_df, tokenizer)

# Initialize WandbLogger
wandb_logger = WandbLogger(entity='yvokeller', project='data-chatbot', log_model='all')

# Initialize Trainer
trainer = pl.Trainer(
    max_epochs=5,
    logger=wandb_logger,
    log_every_n_steps=1,  # Log on every optimisation step
    enable_progress_bar=True
)

# Train the model
trainer.fit(model)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  rank_zero_warn(
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type                          | Params
--------------------------------------------------------
0 | model | BertForSequenceClassification | 177 M 
--------------------------------------------------------
177 M     Trainable params
0         Non-trainable params
177 M     Total params
711.423   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


## Fine-Tune BERT Classifier

In [43]:
# Training data for the manual loop. The CSV is read again here so that label
# encoding always starts from the raw string labels, independent of what
# earlier cells did with their own frames.
MAX_LEN = 100  # Same sequence length as TextClassifier's default max_len
data_loader = load_data(pd.read_csv('data/train.csv'), tokenizer, MAX_LEN)

# BERT model initialization
model = BertForSequenceClassification.from_pretrained('bert-base-multilingual-cased', num_labels=3)

# Optimizer and Loss function
optimizer = Adam(model.parameters(), lr=1e-5)
# Not used by the loop below: the model computes the loss itself when
# `labels` is passed.
loss_function = torch.nn.CrossEntropyLoss()

# Fine-tuning loop
epochs = 3  # Replace with the number of epochs you want

# from_pretrained returns the model in evaluation mode; switch dropout back on
# for training.
model.train()

for epoch in range(epochs):
    for data in data_loader:
        ids = data['ids']
        mask = data['mask']
        labels = data['labels']

        # The dataset only yields ids/mask/labels. For single-sequence inputs
        # BERT falls back to all-zero token_type_ids when none are given.
        optimizer.zero_grad()
        outputs = model(ids, attention_mask=mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

print("Fine-tuning completed!")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Fine-tuning completed!


## Evaluating the BERT Classifier

In [44]:
# Create DataLoader for test data.
# load_data indexes its first argument as a DataFrame, so read the CSV first
# instead of passing the path.
MAX_LEN = 100  # Sequence length used for tokenization during training
test_df = pd.read_csv('data/test.csv')
test_loader = load_data(test_df, tokenizer, MAX_LEN, shuffle=False)

# Evaluate the model on test data
accuracy = evaluate_model(model, test_loader)
print(f"Test Accuracy: {accuracy}")

Test Accuracy: 0.3333333333333333


## Using the BERT Classifier

In [36]:
def classify_user_prompt(text, model, tokenizer, label_mapping):
    """Predict the label string for a single piece of text.

    Args:
        text: The text to classify (one string).
        model: Module called as ``model(**inputs)`` whose result exposes a
            ``logits`` attribute. It is switched to eval mode.
        tokenizer: Callable returning a mapping of tensors when invoked with
            ``return_tensors="pt"``.
        label_mapping: Dict mapping label strings to class indices.

    Returns:
        The first label in ``label_mapping`` whose index equals the predicted
        class, or ``None`` if no label has that index.
    """
    # Inference must run with dropout disabled, otherwise the prediction for
    # the same text can change between calls.
    model.eval()

    # Prepare the text into tokenized tensors
    inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")

    # Keep the inputs on the same device as the model's weights.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        inputs = {name: tensor.to(first_param.device) for name, tensor in inputs.items()}

    # Run the text through the model
    with torch.no_grad():
        outputs = model(**inputs)

    # Get the predicted label index
    predicted_idx = torch.argmax(outputs.logits, dim=1).item()

    # Convert the index to the corresponding label string (first match wins)
    return next(
        (label for label, idx in label_mapping.items() if idx == predicted_idx),
        None,
    )

In [37]:
# Example usage
# Label strings and their class indices; same mapping that load_data applies
# to the training labels.
label_mapping = {'other': 0, 'question': 1, 'concern': 2}

text = "Mir geht es schlecht, das Studium ist sehr anstrengend."
predicted_label = classify_user_prompt(text, model, tokenizer, label_mapping)
print(f"The predicted label for the text is: {predicted_label}")


The predicted label for the text is: concern
