In [1]:
!nvidia-smi

Sat Dec  7 15:13:16 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          Off | 00000000:41:00.0 Off |                    0 |
| N/A   32C    P0              67W / 500W |      4MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [3]:
import gc
import torch
import numpy as np
from datasets import Dataset
from transformers import (
    AutoModelForMultipleChoice,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    DistilBertTokenizer
)
from sklearn.metrics import accuracy_score, f1_score
import wandb
import json
from tqdm import tqdm
from typing import Dict


In [4]:
torch.cuda.empty_cache()
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

In [5]:
class MCDropoutBert2Bert(AutoModelForMultipleChoice):
    def __init__(self, config):
        super().__init__(config)
        # Set dropout rate
        self.dropout_rate = 0.1
        
        # Move model to device (e.g., GPU)
        self.to(device)
    
    def _set_dropout_rate(self, module):
        """Helper function to set dropout rate for all Dropout layers."""
        for submodule in module.modules():
            if isinstance(submodule, torch.nn.Dropout):
                submodule.p = self.dropout_rate
                submodule.train()  # Make sure dropout is active in training


In [6]:
class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=False):
        try:
            inputs = {k: v.to(model.device) if hasattr(v, 'to') else v 
             for k, v in inputs.items()}
            labels = inputs.pop("labels").to(model.device)
            outputs = model(**inputs)
            logits = outputs.logits  # Get the last token's logits
            # Compute cross entropy loss
            loss = torch.nn.functional.cross_entropy(
                logits,
                labels,
                ignore_index=-100
            )

            with torch.no_grad():
                predictions = torch.argmax(logits, dim=-1)
                valid_mask = labels != -100
                accuracy = (predictions[valid_mask] == labels[valid_mask]).float().mean()

                if self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0:
                    wandb.log({
                        "train_loss": loss.item(),
                        "train_accuracy": accuracy.item(),
                        "train_step": self.state.global_step,
                        "train_epoch": self.state.epoch,
                    })

            # Clear unnecessary tensors
            del predictions, valid_mask
            torch.cuda.empty_cache()

            return (loss, outputs) if return_outputs else loss

        except Exception as e:
            print(f"Loss computation error: {str(e)}")
            raise
        

In [18]:
def prepare_and_tokenize_dataset(file_path, tokenizer):
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = [json.loads(line) for line in f]

        formatted_data = []
        for item in tqdm(data, desc="Processing dataset"):
            question = item["question"]
            options = [item["options"][key] for key in ["A", "B", "C", "D", "E"]]

            # Tokenize question with each option
            tokenized_options = tokenizer(
                [f"{question} {opt}" for opt in options],
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )

            label = ord(item["answer_idx"]) - ord('A')
            formatted_data.append({
                "input_ids": tokenized_options["input_ids"],
                "attention_mask": tokenized_options["attention_mask"],
                "labels": torch.tensor(label, dtype=torch.long),
            })

        return Dataset.from_list(formatted_data)
    except Exception as e:
        print(f"Dataset preparation error: {str(e)}")
        raise


In [19]:
class CustomDataCollator:
    def __init__(self, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, examples):
        batch = {}
        
        # Process input_ids and attention_mask
        batch_encoding = self.tokenizer.pad(
            [{"input_ids": ex["input_ids"], "attention_mask": ex["attention_mask"]} for ex in examples],
            padding=True,
            return_tensors="pt"
        )
        
        batch["input_ids"] = batch_encoding["input_ids"]
        batch["attention_mask"] = batch_encoding["attention_mask"]
        
        # Process labels
        if "labels" in examples[0]:
            batch["labels"] = torch.tensor([ex["labels"] for ex in examples], dtype=torch.long)
            
        return batch


In [20]:
def tokenize_function(examples: Dict, tokenizer, max_length: int = 512):
    """Tokenize the input data."""
    prompts = [f"{instruction}\n{response}" 
              for instruction, response in zip(examples["instruction"], examples["response"])]
    
    # Tokenize inputs
    tokenized = tokenizer(
        prompts,
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )
    
    # Create labels (same as input_ids for causal language modeling)
    labels = tokenized["input_ids"].copy()
    
    # Convert to PyTorch tensors
    return {
        "input_ids": torch.tensor(tokenized["input_ids"]),
        "attention_mask": torch.tensor(tokenized["attention_mask"]),
        "labels": torch.tensor(labels)
    }



In [21]:
def prepare_model_and_tokenizer():
    """Initialize the BERT2BERT model and tokenizer."""
    try:
        # Load the tokenizer for BERT
        tokenizer = AutoTokenizer.from_pretrained(
            "bert-base-uncased",
            use_fast=False,
            padding_side="right",
            trust_remote_code=True
        )

        # Load the BERT2BERT encoder-decoder model
        model = AutoModelForMultipleChoice.from_pretrained(
            "bert-base-uncased", 
        )

        # Add pad_token if necessary (BERT models often use [PAD] token)
        if tokenizer.pad_token is None:
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
            model.resize_token_embeddings(len(tokenizer))  # Adjust model's embedding layer

        return model, tokenizer

    except Exception as e:
        print(f"Error loading model/tokenizer: {str(e)}")
        raise



In [22]:
# Initialize model and tokenizer
model, tokenizer = prepare_model_and_tokenizer()

Some weights of BertForMultipleChoice were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [23]:
def print_trainable_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f"Total parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")
    print(f"Non-trainable parameters: {total_params - trainable_params}")

# Example usage:
print_trainable_parameters(model)


Total parameters: 109483009
Trainable parameters: 109483009
Non-trainable parameters: 0


In [None]:
# Prepare datasets
train_dataset = prepare_and_tokenize_dataset("MedQADataset/US/train.jsonl", tokenizer)

Processing dataset: 100%|██████████| 10178/10178 [01:49<00:00, 93.05it/s] 


In [25]:
training_args = TrainingArguments(
    output_dir="./bert-medical-qa-",
    num_train_epochs=2,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    warmup_steps=100,
    weight_decay=0.01,
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
    learning_rate=2e-4,
    optim="paged_adamw_8bit",
    report_to="wandb",
    remove_unused_columns=False,
    save_total_limit=1,
    metric_for_best_model="accuracy",
    greater_is_better=True
)

In [26]:
# Initialize Weights and Biases (wandb)
wandb.init(
    project="bert-medical-qa-finetune",
    name="fine-tuning-train-val-only",
    config={
        "model": "bert-base-uncased",  # Change model name for BERT2BERT
        "epochs": 2,
        "batch_size": 8,
        "learning_rate": 2e-4
    }
)

In [27]:
# Initialize Trainer with CustomTrainer
trainer = CustomTrainer(
    model=model,  # BERT2BERT model
    args=training_args,  # Training arguments defined earlier
    train_dataset=train_dataset,  # Training dataset
    tokenizer=tokenizer,  # BERT tokenizer
    data_collator=CustomDataCollator(tokenizer),  # Custom data collator for BERT2BERT
)

  trainer = CustomTrainer(
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [28]:
# Train the model
trainer.train()

Step,Training Loss
10,1.6297
20,1.6237
30,1.6178
40,1.6034
50,1.6212
60,1.6105
70,1.6202
80,1.6114
90,1.6072
100,1.6191


TrainOutput(global_step=636, training_loss=1.61302997631097, metrics={'train_runtime': 1840.053, 'train_samples_per_second': 11.063, 'train_steps_per_second': 0.346, 'total_flos': 2.676604733807616e+16, 'train_loss': 1.61302997631097, 'epoch': 1.9984289080911233})

In [29]:
model.save_pretrained("./bert-medical-final")
tokenizer.save_pretrained("./bert-medical-final")

('./bert-medical-final/tokenizer_config.json',
 './bert-medical-final/special_tokens_map.json',
 './bert-medical-final/vocab.txt',
 './bert-medical-final/added_tokens.json')

In [30]:
wandb.finish()

0,1
train/epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇████
train/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇█████
train/grad_norm,█▅▃▃▃▂▂▂▂▄▂▂▂▃▃▄▃▂▃▃▁▁▁▂▃▁▄▁▁▂▁▁▁▂▂▁▁▁▂▁
train/learning_rate,▂▃▄▄▅▇▇███▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▃▃▃▃▃▃▂▂▂▂▂▁▁
train/loss,█▅▁▆▃▃▂▅▄▆▄▄▂▄▃▃▃▃▃▄▃▃▃▃▄▄▃▄▅▂▃▂▂▅▃▂▄▃▄▄
train_accuracy,▄▂▄▅▅▂▄▅▂▄▇█▅▂▄▅▄▄▂▂▄▂▄▄▂▄▄▁▄▅▄▂▂▁▂▂▄▄▄▂
train_epoch,▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▆▇▇▇▇▇▇▇▇██
train_loss,▁█▆▆▆▅▃▃▅▃▂▆▁▇▅▃▄▃█▅▇▅▃▆▅▅▅▃▅▆▃▆▆▅▇▄▅▇▄▃
train_step,▁▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇█

0,1
total_flos,2.676604733807616e+16
train/epoch,1.99843
train/global_step,636.0
train/grad_norm,3.67541
train/learning_rate,0.0
train/loss,1.6125
train_accuracy,0.375
train_epoch,1.97958
train_loss,1.61303
train_runtime,1840.053
