In [7]:
!pip uninstall -y bitsandbytes

[0m

In [8]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install transformers==4.36.0
!pip install peft==0.7.1
!pip install datasets==2.14.0
!pip install accelerate==0.25.0

Looking in indexes: https://download.pytorch.org/whl/cu118


In [9]:
script_content = '''
import argparse
import json
import logging
import os
from dataclasses import dataclass, field
from typing import Dict, List

import torch
from datasets import Dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    set_seed,
)

# Configure the root logger at INFO so the dataset-loading messages emitted
# through `logger` below are shown on the console.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by InstructionDataset.load_data.
logger = logging.getLogger(__name__)

@dataclass
class ModelArguments:
    """Which base checkpoint to load and how to load it."""

    # Hub id or local path passed to the tokenizer and model from_pretrained calls.
    model_name_or_path: str = "microsoft/DialoGPT-medium"
    # Forwarded unchanged as trust_remote_code= to both from_pretrained calls.
    trust_remote_code: bool = True

@dataclass
class DataArguments:
    """Locations of the per-stage JSONL training files and the token budget."""

    # One JSONL file per training stage.
    pubmedqa_path: str = "pubmedqa_train.jsonl"
    medmcqa_path: str = "medmcqa_train.jsonl"
    medqa_path: str = "medqa_train.jsonl"
    # Used as the tokenizer's max_length when examples are truncated.
    max_seq_length: int = 512

@dataclass
class LoraArguments:
    """Hyper-parameters used to build the LoRA adapter configuration."""

    # Passed to LoraConfig as r=.
    lora_rank: int = 8
    # Passed to LoraConfig as lora_alpha=.
    lora_alpha: int = 16
    # Passed to LoraConfig as lora_dropout=.
    lora_dropout: float = 0.1
    # Every instance gets its own empty list (no shared mutable default).
    target_modules: List[str] = field(default_factory=list)

class InstructionDataset:
    """Reads instruction-style JSONL records and tokenizes them for causal-LM training.

    Each record is a JSON object; the "instruction", "input" and "output" keys
    are read when present and joined into a single training text.
    """

    def __init__(self, data_path: str, tokenizer, max_length: int = 512):
        """Remember the tokenizer settings and eagerly load ``data_path``.

        Args:
            data_path: Path to a JSONL file (one JSON document per line).
            tokenizer: Callable tokenizer, invoked as ``tokenizer(texts, ...)``.
            max_length: Truncation length handed to the tokenizer.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(data_path)

    def load_data(self, data_path: str) -> List[Dict]:
        """Parse ``data_path`` as JSONL and return the decoded records.

        Blank lines (for example a trailing newline at the end of the file)
        are skipped instead of being passed to ``json.loads``.

        Raises:
            Exception: Any error raised while opening or decoding the file is
                logged and re-raised unchanged.
        """
        data = []
        try:
            with open(data_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    data.append(json.loads(line))
            logger.info(f"Loaded {len(data)} examples from {data_path}")
        except Exception as e:
            logger.error(f"Error loading {data_path}: {e}")
            raise
        return data

    def format_instruction(self, example: Dict) -> str:
        """Join the instruction, input and output fields into one text.

        Missing, ``None`` and empty fields are dropped, so the result never
        contains the literal text "None" or doubled separators.
        """
        values = (example.get(key) for key in ("instruction", "input", "output"))
        parts = [str(value).strip() for value in values if value is not None]
        return " ".join(part for part in parts if part)

    def tokenize_function(self, examples):
        """Tokenize a list of example dicts as one batch.

        Args:
            examples: List of record dicts accepted by ``format_instruction``.

        Returns:
            The tokenizer output for the whole list, with a "labels" entry
            added that holds an independent copy of every "input_ids" row.
        """
        formatted_texts = [self.format_instruction(ex) for ex in examples]

        model_inputs = self.tokenizer(
            formatted_texts,
            truncation=True,
            padding=True,  # Pads to the longest text in this call; no-op for one text
            max_length=self.max_length,
            return_tensors=None,
            add_special_tokens=True,
        )

        # Copy row by row so labels never alias the input_ids lists.
        model_inputs["labels"] = [list(ids) for ids in model_inputs["input_ids"]]
        return model_inputs

    def get_dataset(self) -> "Dataset":
        """Build a tokenized ``Dataset`` with one flat token list per row."""
        dataset = Dataset.from_list(self.data)

        def _tokenize_row(example):
            # tokenize_function returns batch-shaped output (a list per key).
            # map(batched=False) stores the returned value as a single row, so
            # unwrap the singleton batch; otherwise every row would hold a
            # nested [[...]] list and collated tensors would gain an extra axis.
            batch = self.tokenize_function([example])
            return {key: values[0] for key, values in batch.items()}

        tokenized_dataset = dataset.map(
            _tokenize_row,
            batched=False,
            remove_columns=dataset.column_names,
            desc="Tokenizing dataset",
        )
        return tokenized_dataset

def get_target_modules(model_name: str) -> List[str]:
    """Pick LoRA target module names from a model name (case-insensitive).

    The match is a substring heuristic on the checkpoint name:

    * Llama / Mistral            -> ["q_proj", "v_proj"]
    * GPT-NeoX                   -> ["query_key_value"]
    * GPT-Neo / GPT-J            -> ["q_proj", "v_proj"]
    * other "gpt" names (GPT-2, DialoGPT, ...) -> ["c_attn"]
    * anything else              -> ["q_proj", "v_proj"]

    Args:
        model_name: Hub id or path of the base model.

    Returns:
        A new list of module-name suffixes to wrap with LoRA.
    """
    name = model_name.lower()
    attention_proj = ["q_proj", "v_proj"]

    if "llama" in name or "mistral" in name:
        return attention_proj

    # The specific GPT variants must be tested before the generic "gpt"
    # substring, and "neox" before "neo", because each name contains the next.
    if any(tag in name for tag in ("gpt-neox", "gpt_neox", "gptneox")):
        return ["query_key_value"]
    if any(tag in name for tag in ("gpt-neo", "gpt_neo", "gptneo", "gpt-j", "gpt_j", "gptj")):
        return attention_proj
    if "gpt" in name:  # also covers "dialogpt"
        return ["c_attn"]

    return attention_proj  # Safe default

def setup_lora_config(lora_args: LoraArguments, model_name: str):
    """Build the LoraConfig for a causal-LM adapter.

    Args:
        lora_args: Rank, alpha, dropout and optional explicit target modules.
        model_name: Base model name; used to auto-detect target modules when
            ``lora_args.target_modules`` is empty.

    Returns:
        A ``LoraConfig`` with ``task_type=CAUSAL_LM`` and ``bias="none"``.
    """
    # Honour an explicit override; fall back to the name-based heuristic only
    # when none was given (the dataclass default is an empty list).
    target_modules = list(lora_args.target_modules) or get_target_modules(model_name)
    print(f"Using target modules: {target_modules}")

    return LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=lora_args.lora_rank,
        lora_alpha=lora_args.lora_alpha,
        lora_dropout=lora_args.lora_dropout,
        target_modules=target_modules,
        bias="none",
    )

def load_model_and_tokenizer(model_args: ModelArguments, lora_config: LoraConfig):
    """Load the tokenizer and base model, then wrap the model with LoRA.

    Steps: load the tokenizer, make sure it has a pad token, load the base
    model in float32, resize the embeddings if a new pad token was added,
    move the model to the GPU when one is available, and apply ``lora_config``.

    Args:
        model_args: Checkpoint name and ``trust_remote_code`` flag.
        lora_config: Adapter configuration passed to ``get_peft_model``.

    Returns:
        Tuple ``(model, tokenizer)`` where ``model`` is the PEFT-wrapped model.

    Raises:
        Exception: Any loader or PEFT error is printed and re-raised.
    """
    print(f"Loading model: {model_args.model_name_or_path}")

    # Tokenizer.
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=model_args.trust_remote_code,
            padding_side="right",  # Important for causal LM
            use_fast=True,
        )
        print("Tokenizer loaded.")
    except Exception as e:
        print(f"Error loading tokenizer: {e}")
        raise

    # Track whether the vocabulary actually grew, instead of guessing later
    # from the pad token's text.
    added_pad_token = False
    if tokenizer.pad_token is None:
        if tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
            added_pad_token = True
        print("Set pad token.")

    # Model with error handling.
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=model_args.trust_remote_code,
            torch_dtype=torch.float32,
            device_map=None,  # Load on CPU first, then move to GPU
            low_cpu_mem_usage=True,
        )
        print("Base model loaded on CPU.")

        # Only a newly added token needs a new embedding row.
        if added_pad_token:
            model.resize_token_embeddings(len(tokenizer))
            print("Resized token embeddings.")

        # Keep the model config in step with the tokenizer's pad token.
        if model.config.pad_token_id is None:
            model.config.pad_token_id = tokenizer.pad_token_id

        if torch.cuda.is_available():
            model = model.cuda()
            print("Model to GPU.")

    except Exception as e:
        print(f"Error loading model: {e}")
        raise

    # LoRA with error handling.
    try:
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
        print("LoRA applied.")
    except Exception as e:
        print(f"Error applying LoRA: {e}")
        # List candidate module names to help choose valid target_modules.
        print("Available attention modules:")
        for name, _ in model.named_modules():
            if any(target in name.lower() for target in ["attn", "proj", "query", "key", "value"]):
                print(f"  - {name}")
        raise

    return model, tokenizer

class RobustDataCollator(DataCollatorForLanguageModeling):
    """Language-modeling collator that forces sequence tensors to rank 2.

    After the parent collator builds the batch, every entry listed in
    ``_SEQUENCE_KEYS`` whose tensor is not two-dimensional is reshaped to
    ``(-1, last_dim)``. All three entries are handled together so that
    ``attention_mask`` cannot end up with a different rank than ``input_ids``.
    """

    # Batch entries that are reshaped to (-1, last_dim) when not already 2-D.
    _SEQUENCE_KEYS = ("input_ids", "attention_mask", "labels")

    def __call__(self, features):
        """Collate ``features`` and normalise the sequence tensors' shapes.

        Raises:
            Exception: Any collation error is printed together with the
                feature types and re-raised.
        """
        try:
            batch = super().__call__(features)

            for key in self._SEQUENCE_KEYS:
                if key not in batch:
                    continue
                tensor = batch[key]
                if len(tensor.shape) != 2:
                    print(f"  Fixing {key} shape: {tensor.shape}")
                    batch[key] = tensor.view(-1, tensor.shape[-1])

            return batch

        except Exception as e:
            print(f" Error in data collator: {e}")
            print(f"Features: {[type(f) for f in features]}")
            raise

def train_stage(stage_name: str, data_path: str, model, tokenizer, training_args: TrainingArguments, data_args: DataArguments):
    """Run one fine-tuning stage and save its artifacts.

    Args:
        stage_name: Label used in progress messages.
        data_path: JSONL file to train on.
        model: Model to train (already wrapped with LoRA by the caller).
        tokenizer: Tokenizer used for the dataset, the collator and saving.
        training_args: Trainer configuration; ``output_dir`` is where the
            model, trainer state and tokenizer are written.
        data_args: Supplies ``max_seq_length`` for tokenization.

    Returns:
        The ``Trainer`` instance after training and saving.

    Raises:
        Exception: Dataset, training and saving errors are printed and
            re-raised.
    """
    print(f"\\n Starting training stage: {stage_name}")
    print(f" Data path: {data_path}")
    print(f" Output directory: {training_args.output_dir}")

    # Dataset.
    try:
        instruction_dataset = InstructionDataset(
            data_path=data_path,
            tokenizer=tokenizer,
            max_length=data_args.max_seq_length,
        )
        train_dataset = instruction_dataset.get_dataset()
        print(f" Dataset loaded with {len(train_dataset)} examples.")

        # Quick sanity check on the first tokenized row.
        if len(train_dataset) > 0:
            sample = train_dataset[0]
            print(f" Sample input_ids shape: {len(sample['input_ids'])}")
            print(f" Sample labels shape: {len(sample['labels'])}")

    except Exception as e:
        print(f" Error loading dataset: {e}")
        raise

    data_collator = RobustDataCollator(
        tokenizer=tokenizer,
        mlm=False,
        pad_to_multiple_of=None,
    )

    # Trainer.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
    )

    print("Training.")

    # Training with error handling.
    try:
        trainer.train()
        print("Training done.")
    except Exception as e:
        print(f" Training failed: {e}")

        # Best-effort diagnostics: show the tensor shapes of one batch.
        # Catch Exception only, so Ctrl-C / SystemExit still propagate, and
        # report why the inspection failed instead of hiding it.
        try:
            sample_batch = next(iter(trainer.get_train_dataloader()))
            print(" Batch info:")
            for key, value in sample_batch.items():
                if hasattr(value, "shape"):
                    print(f"  {key}: {value.shape}")
        except Exception as inspect_err:
            print(f"Could not inspect batch: {inspect_err}")
        raise

    # Persist the adapter, the trainer state and the tokenizer together.
    try:
        trainer.save_model()
        trainer.save_state()
        tokenizer.save_pretrained(training_args.output_dir)
        print(f" Model saved to {training_args.output_dir}")
    except Exception as e:
        print(f" Error saving model: {e}")
        raise

    return trainer

def main():
    """Parse the CLI options and run the selected fine-tuning stage(s).

    The stages are pubmedqa (1), medmcqa (2) and medqa (3). ``--stage all``
    runs them in that order on the same model object; each stage writes to
    its own sub-directory of ``--output_dir``.
    """
    parser = argparse.ArgumentParser(description="Medical QA fine-tuning")

    parser.add_argument("--model_name_or_path", default="microsoft/DialoGPT-medium")
    parser.add_argument("--pubmedqa_path", default="pubmedqa_train.jsonl")
    parser.add_argument("--medmcqa_path", default="medmcqa_train.jsonl")
    parser.add_argument("--medqa_path", default="medqa_train.jsonl")
    parser.add_argument("--output_dir", default="./checkpoints")
    parser.add_argument("--max_seq_length", type=int, default=128)  # Reduced further
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=1)  # Start with 1
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--lora_rank", type=int, default=4)  # Reduced
    parser.add_argument("--lora_alpha", type=int, default=8)  # Reduced
    parser.add_argument("--stage", type=str, choices=["all", "1", "2", "3"], default="1")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--max_steps",
        type=int,
        default=50,  # Same cap the script used before it became configurable
        help="Cap on optimizer steps per stage; pass -1 to train for --num_epochs instead.",
    )

    args = parser.parse_args()

    print("-"*60)
    print("Medical QA Fine-Tuning.")
    print("-"*60)
    print(f"Model: {args.model_name_or_path}")
    print(f"Stage: {args.stage}")
    print(f"Batch size: {args.batch_size}")
    print(f"Max sequence length: {args.max_seq_length}")
    print(f"LoRA rank: {args.lora_rank}")
    print("="*60)

    set_seed(args.seed)

    model_args = ModelArguments(
        model_name_or_path=args.model_name_or_path,
    )

    data_args = DataArguments(
        pubmedqa_path=args.pubmedqa_path,
        medmcqa_path=args.medmcqa_path,
        medqa_path=args.medqa_path,
        max_seq_length=args.max_seq_length,
    )

    lora_args = LoraArguments(
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
    )

    lora_config = setup_lora_config(lora_args, args.model_name_or_path)

    # Stage order is fixed; --stage N selects index N-1.
    stages = [
        {"name": "pubmedqa", "data_path": args.pubmedqa_path, "output_dir": os.path.join(args.output_dir, "pubmedqa")},
        {"name": "medmcqa", "data_path": args.medmcqa_path, "output_dir": os.path.join(args.output_dir, "medmcqa")},
        {"name": "medqa", "data_path": args.medqa_path, "output_dir": os.path.join(args.output_dir, "medqa")},
    ]

    if args.stage == "all":
        stages_to_run = [0, 1, 2]
    else:
        stages_to_run = [int(args.stage) - 1]

    # The model is loaded once and carried through every selected stage.
    model, tokenizer = load_model_and_tokenizer(model_args, lora_config)

    for stage_idx in stages_to_run:
        stage = stages[stage_idx]
        stage_name = stage["name"]

        print(f"\\n{'='*50}")
        print(f"STAGE {stage_idx + 1}: {stage_name.upper()}")
        print(f"{'='*50}")

        os.makedirs(stage["output_dir"], exist_ok=True)

        training_args = TrainingArguments(
            output_dir=stage["output_dir"],
            num_train_epochs=args.num_epochs,
            per_device_train_batch_size=args.batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            logging_steps=5,
            save_steps=100,
            learning_rate=args.learning_rate,
            weight_decay=0.01,
            dataloader_num_workers=0,
            remove_unused_columns=False,
            # The string "none" disables reporting integrations; Python None
            # is treated as "use the default" by TrainingArguments.
            report_to="none",
            save_total_limit=1,
            load_best_model_at_end=False,
            max_steps=args.max_steps,  # A positive value takes precedence over num_train_epochs
            logging_first_step=True,
            dataloader_drop_last=True,  # Drop incomplete batches
        )

        trainer = train_stage(
            stage_name=stage_name,
            data_path=stage["data_path"],
            model=model,
            tokenizer=tokenizer,
            training_args=training_args,
            data_args=data_args,
        )

        print(f" Stage {stage_idx + 1} ({stage_name}) completed!")
        # Continue the next stage from the weights this stage produced.
        model = trainer.model

    print("\\n Training done.")

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
'''


# Write the training script to disk so the later cells can launch it with
# `!python`. The encoding is pinned to UTF-8 (the encoding Python assumes for
# source files) so the result does not depend on the platform default.
with open('medical_qa_robust.py', 'w', encoding='utf-8') as f:
    f.write(script_content)

print("Created medical_qa_robust.py")

Created medical_qa_robust.py


In [13]:
!python medical_qa_robust.py \
    --model_name_or_path "microsoft/DialoGPT-medium" \
    --num_epochs 1 \
    --batch_size 1 \
    --gradient_accumulation_steps 4 \
    --max_seq_length 256 \
    --lora_rank 4 \
    --lora_alpha 8 \
    --stage "1"

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
2025-09-25 05:25:21.969027: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1758777921.990866    7270 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1758777921.997421    7270 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1758777922.014228    7270 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1758777922.014274    7270 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:

In [14]:
!python medical_qa_robust.py \
    --model_name_or_path "NousResearch/Llama-2-7b-chat-hf" \
    --num_epochs 1 \
    --batch_size 1 \
    --gradient_accumulation_steps 4 \
    --max_seq_length 128 \
    --lora_rank 4 \
    --lora_alpha 8 \
    --stage "1"

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
2025-09-25 05:27:00.043761: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1758778020.066639    7785 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1758778020.073530    7785 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1758778020.091104    7785 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1758778020.091148    7785 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:

In [15]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Loading the fine-tuned model.
# The tokenizer is read from the stage-1 output directory, where the training
# script saves it alongside the model.
tokenizer = AutoTokenizer.from_pretrained("./checkpoints/pubmedqa")
# NOTE(review): both training cells above run stage "1" with the default
# output directory, so ./checkpoints/pubmedqa holds whichever run finished
# last. This cell assumes that was the Llama-2 run -- confirm before re-running.
base_model = AutoModelForCausalLM.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
# Attach the saved adapter from the checkpoint directory to the base model.
model = PeftModel.from_pretrained(base_model, "./checkpoints/pubmedqa")


# Single-prompt smoke test: tokenize, generate up to 50 new tokens, decode.
prompt = "Question: What is the normal body temperature? Answer:"
inputs = tokenizer(prompt, return_tensors="pt")
# No device placement is done here, so the model and inputs stay wherever
# from_pretrained / the tokenizer put them (presumably CPU -- TODO confirm).
outputs = model.generate(**inputs, max_new_tokens=50)
# outputs[0] is decoded in full with special tokens stripped; the printed text
# in the cell output starts with the prompt.
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Question: What is the normal body temperature? Answer: Normal body temperature varies from person to person, but it is generally around 98.6 degrees Fahrenheit (37 degrees Celsius). However, body temperature can fluctuate depending on various factors such as age, sex,
