In [1]:
!nvidia-smi 

Sun Dec  1 19:39:47 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          Off | 00000000:C1:00.0 Off |                    0 |
| N/A   33C    P0              66W / 500W |      4MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [3]:
import torch
import numpy as np
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model,TaskType
from sklearn.metrics import accuracy_score, f1_score
import wandb
import json
from tqdm import tqdm
from typing import Dict

# # Set device and random seeds for reproducibility
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)
np.random.seed(42)

In [5]:
torch.cuda.empty_cache()
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

In [6]:
class MCDropoutLlama(AutoModelForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.dropout_rate = 0.1
        self.to(device)
        
        # Enable dropout by default for all dropout layers
        for module in self.modules():
            if isinstance(module, torch.nn.Dropout):
                module.p = self.dropout_rate
                module.train()
    
    def forward(self, *args, **kwargs):
        try:
            args = tuple(arg.to(device) if torch.is_tensor(arg) else arg for arg in args)
            kwargs = {k: v.to(device) if torch.is_tensor(v) else v for k, v in kwargs.items()}

            # Keep dropout layers in training mode
            for module in self.modules():
                if isinstance(module, torch.nn.Dropout):
                    module.train()

            outputs = super().forward(*args, **kwargs)
            torch.cuda.empty_cache()  # Clear cache after forward pass
            return outputs
        except Exception as e:
            print(f"Forward pass error: {str(e)}")
            raise

In [None]:
class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        try:
            inputs = {k: v.to(model.device) if hasattr(v, 'to') else v 
             for k, v in inputs.items()}
            labels = inputs.pop("labels").to(model.device)
            outputs = model(**inputs)
            logits = outputs.logits[:, -1, :]  # Get the last token's logits

            # Create one-hot encoded labels with proper shape
            num_classes = logits.size(-1)
            one_hot_labels = torch.zeros(
                labels.size(0), 
                num_classes, 
                device=labels.device, 
                dtype=torch.float32
            )
            one_hot_labels.scatter_(1, labels.unsqueeze(1), 1)

            # Compute cross entropy loss
            loss = torch.nn.functional.cross_entropy(
                logits,
                labels,
                ignore_index=-100
            )

            with torch.no_grad():
                predictions = torch.argmax(logits, dim=-1)
                valid_mask = labels != -100
                accuracy = (predictions[valid_mask] == labels[valid_mask]).float().mean()

                if self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0:
                    wandb.log({
                        "train_loss": loss.item(),
                        "train_accuracy": accuracy.item(),
                        "train_step": self.state.global_step,
                        "train_epoch": self.state.epoch,
                    })

            # Clear unnecessary tensors
            del predictions, valid_mask
            torch.cuda.empty_cache()

            return (loss, outputs) if return_outputs else loss

        except Exception as e:
            print(f"Loss computation error: {str(e)}")
            raise

In [None]:
def prepare_and_tokenize_dataset(file_path, tokenizer):
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = [json.loads(line) for line in f]

        formatted_data = []
        for item in tqdm(data, desc="Processing dataset"):
            question = f"""Question: {item['question']}
Available options:
A: {item['options']['A']}
B: {item['options']['B']}
C: {item['options']['C']}
D: {item['options']['D']}
E: {item['options']['E']}

Please select the correct answer."""
            
            tokenized = tokenizer(
                question,
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt"
            )
            
            if item['answer_idx'] in ['A', 'B', 'C', 'D', 'E']:
                label = ord(item['answer_idx']) - ord('A')
                tokenized["labels"] = torch.tensor(label, dtype=torch.long)
                formatted_data.append({k: v.squeeze() for k, v in tokenized.items()})
            
        return Dataset.from_list(formatted_data)
    except Exception as e:
        print(f"Dataset preparation error: {str(e)}")
        raise

In [9]:
class CustomDataCollator:
    def __init__(self, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, examples):
        batch = {}
        
        # Process input_ids and attention_mask
        batch_encoding = self.tokenizer.pad(
            [{"input_ids": ex["input_ids"], "attention_mask": ex["attention_mask"]} for ex in examples],
            padding=True,
            return_tensors="pt"
        )
        
        batch["input_ids"] = batch_encoding["input_ids"]
        batch["attention_mask"] = batch_encoding["attention_mask"]
        
        # Process labels
        if "labels" in examples[0]:
            batch["labels"] = torch.tensor([ex["labels"] for ex in examples], dtype=torch.long)
            
        return batch

In [10]:
def tokenize_function(examples: Dict, tokenizer, max_length: int = 512):
    """Tokenize the input data."""
    prompts = [f"{instruction}\n{response}" 
              for instruction, response in zip(examples["instruction"], examples["response"])]
    
    # Tokenize inputs
    tokenized = tokenizer(
        prompts,
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )
    
    # Create labels (same as input_ids for causal language modeling)
    labels = tokenized["input_ids"].copy()
    
    # Convert to PyTorch tensors
    return {
        "input_ids": torch.tensor(tokenized["input_ids"]),
        "attention_mask": torch.tensor(tokenized["attention_mask"]),
        "labels": torch.tensor(labels)
    }

In [None]:
def prepare_model_and_tokenizer():
    """Initialize the model and tokenizer with LoRA configuration."""
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            "meta-llama/Llama-3.1-8b",
            use_fast=True,
            padding_side="right",
            trust_remote_code=True
        )
        
        model = MCDropoutLlama.from_pretrained(
            "meta-llama/Llama-3.1-8b",
            load_in_8bit=True,
            device_map="auto"
        )
        
        # Prepare model for k-bit training
        model = prepare_model_for_kbit_training(model)
        
        # Define LoRA Config
        lora_config = LoraConfig(
            r=16,  # rank
            lora_alpha=32,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type=TaskType.CAUSAL_LM
        )
        
        # Get PEFT model
        model = get_peft_model(model, lora_config)
        
        if tokenizer.pad_token is None:
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
            model.resize_token_embeddings(len(tokenizer))
        
        return model, tokenizer
    
    except Exception as e:
        print(f"Error loading model/tokenizer: {str(e)}")
        raise

In [12]:
# Initialize model and tokenizer
model, tokenizer = prepare_model_and_tokenizer()

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [13]:
# Print trainable parameters info
model.print_trainable_parameters()

trainable params: 6,815,744 || all params: 8,037,085,184 || trainable%: 0.0848


In [14]:
# Prepare datasets
train_dataset = prepare_and_tokenize_dataset("MedQADataset/US/train.jsonl", tokenizer)
val_dataset = prepare_and_tokenize_dataset("MedQADataset/US/dev.jsonl", tokenizer)
print(f"Train samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")

Processing dataset: 100%|██████████| 10178/10178 [00:05<00:00, 1816.72it/s]
Processing dataset: 100%|██████████| 1272/1272 [00:00<00:00, 1822.21it/s]


Train samples: 10178, Validation samples: 1272


In [None]:
training_args = TrainingArguments(
    output_dir="./llama-medical-qa-lora-v5",
    num_train_epochs=2,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    warmup_steps=100,
    weight_decay=0.01,
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
    learning_rate=2e-4,
    optim="paged_adamw_8bit",
    report_to="wandb",
    remove_unused_columns=False,
    save_total_limit=1,
    metric_for_best_model="accuracy",
    greater_is_better=True
)


In [17]:
# Initialize WandB
wandb.init(
    project="llama-mc-finetune",
    name="fine-tuning-train-val-only",
    config={
        "model": "Llama-3.1-8b",
        "epochs": 3,
        "batch_size": 8,
        "learning_rate": 2e-4
    }
)

# Initialize trainer
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    #eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=CustomDataCollator(tokenizer),
    compute_metrics=compute_metrics
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mai_ritik[0m ([33mvisualaiblog[0m). Use [1m`wandb login --relogin`[0m to force relogin


Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [18]:
# Before training starts
import gc
torch.cuda.empty_cache()
gc.collect()

# Train
try:
    trainer.train()
    model.save_pretrained("./llama-medical-qa-lora-v5")
finally:
    torch.cuda.empty_cache()
    gc.collect()
    wandb.finish()

[2024-12-01 19:42:50,205] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)


You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  return fn(*args, **kwargs)
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


Step,Training Loss
10,10.3675
20,9.3044
30,6.1562
40,2.6153
50,1.8708
60,1.724
70,1.6961
80,1.6175
90,1.6589
100,1.6462


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


0,1
train/epoch,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇██
train/global_step,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇██
train/grad_norm,▄█▃▂▃▂▂▂▁▁▁▁▁▁▁▁▁▁▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
train/learning_rate,▂▃▄▄▅▇███▇▇▇▆▆▆▆▆▆▆▅▅▅▅▅▄▄▄▃▃▃▃▃▃▂▂▂▂▁▁▁
train/loss,█▅▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▃▃▁▃▁▅▅▆▃▂▃▁▆▆▃▇█▅▃█▅▇▇▅▅▇▇▅▇▆▆▇▅█▆█▇▅▇
train_epoch,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▇▇▇▇▇▇▇▇▇█████
train_loss,█▆▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_step,▁▁▁▂▂▂▂▂▂▃▃▃▄▄▄▄▄▅▅▅▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇███

0,1
total_flos,4.695059062000189e+17
train/epoch,1.99843
train/global_step,636.0
train/grad_norm,6.19328
train/learning_rate,0.0
train/loss,1.1266
train_accuracy,0.875
train_epoch,1.97958
train_loss,1.76819
train_runtime,13619.5959
