In [1]:
%pip install -q transformers datasets pytorch-lightning optuna optuna-integration wandb

Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset
from typing import Any, Dict
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

import optuna
from optuna.integration import PyTorchLightningPruningCallback

import wandb
from pytorch_lightning.loggers import WandbLogger


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Set random seeds for reproducibility
pl.seed_everything(42, workers=True)


Seed set to 42


42

In [4]:
BATCH_SIZE: int = 32
MAX_LENGTH: int = 512
MODEL_NAME: str = "bert-base-cased"

print(f"Batch_size: {BATCH_SIZE}"
      f"\nMax_length: {MAX_LENGTH}")

Batch_size: 32
Max_length: 512


In [5]:
class BoolQDataset(Dataset):
    def __init__(self, data: Dict[str, Any], tokenizer: AutoTokenizer, max_length: int = MAX_LENGTH):
        
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.data["question"])

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        # Get question and passage
        question = self.data["question"][idx]
        passage = self.data["passage"][idx]
        label = self.data["answer"][idx]

        # Tokenize
        encoded = self.tokenizer(
            question,
            passage,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # Correctness tests for tokenization
        assert encoded["input_ids"].shape[-1] <= self.max_length, "Token length exceeds max_length!"
        assert encoded["input_ids"].shape == encoded["attention_mask"].shape, "Mismatch in token shapes!"

        return {
            "input_ids": encoded["input_ids"].squeeze(0),  # Remove batch dimension
            "attention_mask": encoded["attention_mask"].squeeze(0),  # Remove batch dimension
            "label": torch.tensor(label, dtype=torch.float),  # Float for binary classification
        }

In [6]:
class BoolQDataModule(pl.LightningDataModule):
    def __init__(self, tokenizer_name: str, batch_size: int = BATCH_SIZE, max_length: int = MAX_LENGTH):
        super().__init__()
        self.tokenizer_name = tokenizer_name
        self.batch_size = batch_size
        self.max_length = max_length

    def prepare_data(self) -> None:
        # Loading the dataset based on lecture slides
        self.train_data = load_dataset("google/boolq", split="train[:-1000]")
        self.validation_data = load_dataset("google/boolq", split="train[-1000:]")
        self.test_data = load_dataset("google/boolq", split="validation")

    def setup(self, stage: str = None) -> None:
        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)

        # Create datasets
        self.train_dataset = BoolQDataset(self.train_data, self.tokenizer, self.max_length)
        self.val_dataset = BoolQDataset(self.validation_data, self.tokenizer, self.max_length)
        self.test_dataset = BoolQDataset(self.test_data, self.tokenizer, self.max_length)

        # Test dataset length
        assert len(self.train_dataset) == 8427, "Train dataset length is incorrect!"
        assert len(self.val_dataset) == 1000, "Validation dataset length is incorrect!"
        assert len(self.test_dataset) == 3270, "Test dataset length is incorrect!"

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

# Initialize DataModule
data_module = BoolQDataModule(tokenizer_name=MODEL_NAME, batch_size=BATCH_SIZE)

# Prepare and test data loading
data_module.prepare_data()
data_module.setup()

# Correctness test for DataLoader
for batch in data_module.train_dataloader():
    assert batch["input_ids"].shape[0] == BATCH_SIZE, "Batch size mismatch!"
    print(f"Batch loaded successfully with shape: {batch['input_ids'].shape}")
    break



Batch loaded successfully with shape: torch.Size([32, 512])


In [7]:
class BoolQClassifier(pl.LightningModule):
    def __init__(
            self,
            model_name: str,
            learning_rate: float = 1e-5,
            hidden_dim: int = 256,
            dropout_rate: float = 0.
    ):
        super(BoolQClassifier, self).__init__()
        self.save_hyperparameters()
        self.bert = AutoModel.from_pretrained(model_name)
        self.classifier = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        self.loss_fn = nn.BCELoss()

        # Storage for test metrics
        self.val_preds = []
        self.val_labels = []
        self.test_preds = []
        self.test_labels = []

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_output = outputs.last_hidden_state[:, 0, :]  # CLS token
        logits = self.classifier(cls_output)
        return logits.squeeze(-1)

    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        logits = self(batch['input_ids'], batch['attention_mask'])
        loss = self.loss_fn(logits, batch['label'])
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
        logits = self(batch['input_ids'], batch['attention_mask'])
        loss = self.loss_fn(logits, batch['label'])
        preds = (logits > 0.5).float()
    
        # Store predictions and labels for confusion matrix
        self.val_preds.extend(preds.cpu().numpy())
        self.val_labels.extend(batch['label'].cpu().numpy())
    
        acc = (preds == batch['label']).float().mean()
    
        # Log validation metrics
        self.log('val_loss', loss, prog_bar=True, on_epoch=True)
        self.log('val_acc', acc, prog_bar=True, on_epoch=True)
    
        return {'val_loss': loss, 'val_acc': acc}
    
    def on_validation_epoch_end(self) -> None:
        # Compute confusion matrix
        cm = confusion_matrix(self.val_labels, self.val_preds)
        
        # Plot confusion matrix
        fig, ax = plt.subplots(figsize=(5, 5))
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[0, 1])
        disp.plot(cmap=plt.cm.Blues, ax=ax)
        plt.title("Validation Confusion Matrix")

        # Log confusion matrix to WandB
        wandb.log({"val_confusion_matrix": wandb.Image(fig)})

        # Clear storage
        self.val_preds.clear()
        self.val_labels.clear()

    def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
        logits = self(batch['input_ids'], batch['attention_mask'])
        loss = self.loss_fn(logits, batch['label'])
        preds = (logits > 0.5).float()

        # Store predictions and labels for confusion matrix
        self.test_preds.extend(preds.cpu().numpy())
        self.test_labels.extend(batch['label'].cpu().numpy())

        acc = (preds == batch['label']).float().mean()

        # Log test metrics
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)

        return {'test_loss': loss, 'test_acc': acc}

    def on_test_epoch_end(self) -> None:
        # Compute confusion matrix
        cm = confusion_matrix(self.test_labels, self.test_preds)
        
        # Plot confusion matrix
        fig, ax = plt.subplots(figsize=(5, 5))
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[0, 1])
        disp.plot(cmap=plt.cm.Blues, ax=ax)
        plt.title("Test Confusion Matrix")

        # Log confusion matrix to WandB
        wandb.log({"test_confusion_matrix": wandb.Image(fig)})

        # Clear storage
        self.test_preds.clear()
        self.test_labels.clear()

    

    def configure_optimizers(self) -> torch.optim.Optimizer:
        # Separate parameter groups
        transformer_params = list(self.bert.parameters())
        classifier_params = list(self.classifier.parameters())
    
        # Define learning rates
        transformer_lr = self.hparams.learning_rate  # Base learning rate
        classifier_lr = self.hparams.learning_rate * 10  # Higher learning rate for classifier
    
        # Create parameter groups
        optimizer = torch.optim.AdamW([
            {'params': transformer_params, 'lr': transformer_lr},
            {'params': classifier_params, 'lr': classifier_lr}
        ])
    
        return optimizer




In [8]:
# WandB Logger initialization
def get_wandb_logger(run_name: str, group_name: str, hyperparameters: dict):
    wandb.finish()
    return WandbLogger(
        project="nlp-p4-pretrained_transformers",
        name=run_name,
        group=group_name,
        log_model=True,
        reinit=True,
    )

# Custom WandB Callback for Optuna Integration
class CustomWandbLoggingCallback(pl.Callback):
    def __init__(self, log_interval: int = 10):
        self.log_interval = log_interval

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if (batch_idx + 1) % self.log_interval == 0:
            metrics = trainer.callback_metrics
            wandb.log({
                "train_loss": metrics.get("train_loss", None),
                "train_acc": metrics.get("train_acc", None),
            })

    def on_validation_epoch_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics
        wandb.log({
            "val_loss": metrics.get("val_loss_epoch", None),
            "val_acc": metrics.get("val_acc_epoch", None),
        })

# Helper Function to Format Run Name
def format_run_name(hyperparams: dict) -> str:
    return "_".join([f"{key[:2]}_{val}" for key, val in hyperparams.items()])

# Manual Training
def train_manual():
    # Hyperparameters
    hyperparameters = {
        "learning_rate": 2e-5,
        "hidden_dim": 256,
        "dropout_rate": 0.3,
        "batch_size": BATCH_SIZE
    }

    # Run Name
    run_name = format_run_name(hyperparameters)

    # WandB Logger
    wandb_logger = get_wandb_logger(run_name, "manual_testing", hyperparameters)

    # Initialize DataModule
    data_module = BoolQDataModule(
        tokenizer_name=MODEL_NAME, 
        batch_size=hyperparameters["batch_size"]
    )

    # Initialize Model
    model = BoolQClassifier(
        model_name=MODEL_NAME,
        learning_rate=hyperparameters["learning_rate"],
        hidden_dim=hyperparameters["hidden_dim"],
        dropout_rate=hyperparameters["dropout_rate"]
    )

    # Callbacks
    early_stopping = EarlyStopping(monitor='val_loss', patience=10, mode='min')
    checkpoint = ModelCheckpoint(monitor='val_acc', mode='max', save_top_k=1, filename=run_name)

    # Trainer
    trainer = Trainer(
        max_epochs=100,
        callbacks=[early_stopping, checkpoint, CustomWandbLoggingCallback()],
        accelerator='auto',
        devices=1,
        logger=wandb_logger
    )

    # Train
    trainer.fit(model, datamodule=data_module)
    
    # Finish WandB run
    wandb.finish()

# Optuna Objective with WandB Logging
def objective(trial: optuna.Trial) -> float:
    # Suggest hyperparameters
    hyperparameters = {
        "learning_rate": trial.suggest_loguniform('learning_rate', 1e-6, 1e-4),
        "hidden_dim": trial.suggest_int('hidden_dim', 128, 512, step=64),
        "dropout_rate": trial.suggest_uniform('dropout_rate', 0.1, 0.5),
        "batch_size": BATCH_SIZE
    }

    # Run Name
    run_name = format_run_name(hyperparameters)

    # WandB Logger
    wandb_logger = get_wandb_logger(run_name, "optuna_testing", hyperparameters)

    # Initialize DataModule
    data_module = BoolQDataModule(
        tokenizer_name=MODEL_NAME, 
        batch_size=hyperparameters["batch_size"]
    )

    # Initialize Model
    model = BoolQClassifier(
        model_name=MODEL_NAME,
        learning_rate=hyperparameters["learning_rate"],
        hidden_dim=hyperparameters["hidden_dim"],
        dropout_rate=hyperparameters["dropout_rate"]
    )

    # Callbacks
    early_stopping = EarlyStopping(monitor='val_loss', patience=10, mode='min')
    checkpoint = ModelCheckpoint(monitor='val_acc', mode='max', save_top_k=1, filename=run_name)
    pruning_callback = PyTorchLightningPruningCallback(trial, monitor='val_loss')

    # Trainer
    trainer = Trainer(
        max_epochs=50,
        callbacks=[early_stopping, checkpoint, pruning_callback, CustomWandbLoggingCallback()],
        accelerator='auto',
        devices=1,
        logger=wandb_logger
    )

    # Train
    trainer.fit(model, datamodule=data_module)
    
    # Validate to fetch the latest metrics
    val_metrics = trainer.validate(model, datamodule=data_module, verbose=False)

    # Finish WandB run
    wandb.finish()
    
    # Retrieve best score
    return val_metrics[0]['val_acc']

# Optuna Study
def run_optuna():
    study = optuna.create_study(direction='maximize')
    study.optimize(objective, n_trials=5)

    # Best Hyperparameters
    best_params = study.best_params
    print(f"Best hyperparameters: {best_params}")

# Example Execution
# Uncomment one of the following to run
# train_manual()
run_optuna()


[I 2024-11-25 15:43:40,969] A new study created in memory with name: no-name-350ce0ae-4542-4e0c-9d09-52c78c10c198
  "learning_rate": trial.suggest_loguniform('learning_rate', 1e-6, 1e-4),
  "dropout_rate": trial.suggest_uniform('dropout_rate', 0.1, 0.5),
[W 2024-11-25 15:43:42,519] Trial 0 failed with parameters: {'learning_rate': 1.2177914779583088e-06, 'hidden_dim': 512, 'dropout_rate': 0.1928871365969061} because of the following error: ValueError('Expected a parent').
Traceback (most recent call last):
  File "C:\Users\Pascal\miniconda3\envs\nlp\Lib\site-packages\optuna\study\_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "C:\Users\Pascal\AppData\Local\Temp\ipykernel_18252\3816517564.py", line 121, in objective
    trainer = Trainer(
              ^^^^^^^^
  File "C:\Users\Pascal\miniconda3\envs\nlp\Lib\site-packages\pytorch_lightning\utilities\argparse.py", line 70, in insert_env_defaults
    return fn(self, **kwar

ValueError: Expected a parent

In [None]:
"""
# Define paths and load model from checkpoint
base_path = Path("nlp-p4-pretrained_transformers/nlfg3sfr/checkpoints")
run_name = "best_model"
file_path = base_path / (run_name + ".ckpt")

# Initialize WandB logger for evaluation
wandb_logger = WandbLogger(project="nlp-p4-pretrained_transformers", name=run_name, group="evaluation")

# Load the model from the checkpoint
model = BoolQClassifier.load_from_checkpoint(file_path)

# Initialize DataModule for testing
data_module = BoolQDataModule(tokenizer_name="bert-base-cased", batch_size=BATCH_SIZE)
data_module.prepare_data()
data_module.setup()

# Initialize the trainer for testing
trainer = pl.Trainer(logger=wandb_logger)

# Run testing on the test set
trainer.test(model, dataloaders=data_module.test_dataloader())

# Finish WandB session
wandb.finish()
"""