In [26]:
import os
import json
import torch
import wandb
import evaluate
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig,
    DataCollatorWithPadding
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training
)
from huggingface_hub import notebook_login
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

In [12]:
wandb.init(project="medical-berta", 
          name="medmcqa-finetuning",
          config={
              "base_model": "microsoft/deberta-v3-small",  # Small but powerful model
              "dataset": "openlifescienceai/medmcqa",
              "learning_rate": 3e-4,
              "batch_size": 8,
              "num_epochs": 3
          })

VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

In [19]:
class MedicalDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

def load_and_preprocess_data(tokenizer):
    """Load and preprocess the MedMCQA dataset."""
    dataset = load_dataset("openlifescienceai/medmcqa")
    
    def format_example(example):
        return f"""Question: {example['question']}
A) {example['opa']}
B) {example['opb']}
C) {example['opc']}
D) {example['opd']}"""
    
    # Prepare train dataset
    train_texts = [format_example(ex) for ex in dataset['train']]
    train_labels = [ex['cop'] for ex in dataset['train']]
    
    # Prepare validation dataset
    val_texts = [format_example(ex) for ex in dataset['validation']]
    val_labels = [ex['cop'] for ex in dataset['validation']]
    
    # Tokenize datasets
    train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=512)
    val_encodings = tokenizer(val_texts, truncation=True, padding=True, max_length=512)
    
    # Create dataset objects
    train_dataset = MedicalDataset(train_encodings, train_labels)
    val_dataset = MedicalDataset(val_encodings, val_labels)
    
    return train_dataset, val_dataset

def evaluate_medical_performance(model, tokenizer, dataset):
    """Evaluate model performance on medical queries."""
    model.eval()
    dataloader = DataLoader(dataset, batch_size=8, shuffle=False)
    
    correct = 0
    total = 0
    predictions = []
    
    with torch.no_grad():
        for batch in dataloader:
            inputs = {k: v.cuda() for k, v in batch.items() if k != 'labels'}
            labels = batch['labels'].cuda()
            
            outputs = model(**inputs)
            preds = outputs.logits.argmax(-1)
            predictions.extend(preds.cpu().numpy())
            
            correct += (preds == labels).sum().item()
            total += len(labels)
    
    accuracy = correct / total
    return accuracy, predictions

In [27]:
def main():
    # 1. Load base model and tokenizer
    model_name = wandb.config.base_model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
    # Configure quantization
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    
    # Load model with quantization config
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=4,
        quantization_config=bnb_config,
    )
    
    # 2. Log initial model size
    def get_model_size(model):
        param_size = 0
        for param in model.parameters():
            param_size += param.nelement() * param.element_size()
        buffer_size = 0
        for buffer in model.buffers():
            buffer_size += buffer.nelement() * buffer.element_size()
        size_all_mb = (param_size + buffer_size) / 1024**2
        return size_all_mb
    
    initial_size = get_model_size(model)
    wandb.log({"model_size_before_quantization": initial_size})
    
    # 3. Prepare model for k-bit training
    model = prepare_model_for_kbit_training(model)
    
    # 4. Configure LoRA
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="SEQ_CLS"
    )
    
    model = get_peft_model(model, lora_config)
    
    # Log quantized model size
    quantized_size = get_model_size(model)
    wandb.log({"model_size_after_quantization": quantized_size})
    
    # 5. Load and preprocess data
    train_dataset, val_dataset = load_and_preprocess_data(tokenizer)
    
    # 6. Evaluate initial performance on medical queries
    initial_accuracy, _ = evaluate_medical_performance(model, tokenizer, val_dataset)
    wandb.log({"initial_medical_accuracy": initial_accuracy})
    
    # 7. Configure training arguments
    training_args = TrainingArguments(
        output_dir="./results",
        learning_rate=wandb.config.learning_rate,
        per_device_train_batch_size=wandb.config.batch_size,
        per_device_eval_batch_size=wandb.config.batch_size,
        num_train_epochs=wandb.config.num_epochs,
        weight_decay=0.01,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        push_to_hub=False,
        report_to="wandb"
    )
    
    # 8. Create Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=DataCollatorWithPadding(tokenizer),
    )
    
    # 9. Train model
    trainer.train()
    
    # 10. Evaluate final performance on medical queries
    final_accuracy, predictions = evaluate_medical_performance(model, tokenizer, val_dataset)
    wandb.log({
        "final_medical_accuracy": final_accuracy,
        "accuracy_improvement": final_accuracy - initial_accuracy
    })
    
    # 11. Log confusion matrix
    wandb.log({
        "confusion_matrix": wandb.plot.confusion_matrix(
            probs=None,
            y_true=[example['labels'].item() for example in val_dataset],
            preds=predictions,
            class_names=["A", "B", "C", "D"]
        )
    })
    
    # 12. Save final model
    model.save_pretrained("./final_model")
    tokenizer.save_pretrained("./final_model")
    
    # Close wandb run
    wandb.finish()

In [28]:
if __name__ == "__main__":
    main()

`low_cpu_mem_usage` was None, now set to True since model is quantized.
Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-small and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  self.scaler = torch.cuda.amp.GradScaler(**kwargs)
  return fn(*args, **kwargs)


Epoch,Training Loss,Validation Loss
1,1.37,1.371592
2,1.368,1.360901
3,1.356,1.362785


  return fn(*args, **kwargs)
  return fn(*args, **kwargs)


VBox(children=(Label(value='0.032 MB of 0.032 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy_improvement,▁
eval/loss,█▁▂
eval/runtime,▂█▁
eval/samples_per_second,▇▁█
eval/steps_per_second,▇▁█
final_medical_accuracy,▁
initial_medical_accuracy,█▁▁
model_size_after_quantization,▁▁▁▁
model_size_before_quantization,▁▁▁▁
train/epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▆▆▆▆▆▇▇▇▇▇██████

0,1
accuracy_improvement,0.11236
eval/loss,1.36279
eval/runtime,13.649
eval/samples_per_second,306.469
eval/steps_per_second,38.318
final_medical_accuracy,0.33349
initial_medical_accuracy,0.22113
model_size_after_quantization,398.71976
model_size_before_quantization,209.05909
total_flos,5.615687522610288e+16
