In [32]:
!pip -q install -U "transformers>=4.44.0" "accelerate>=0.33.0" "huggingface_hub>=0.23.0" peft datasets
from huggingface_hub import login, whoami
login()
print("HF account:", whoami().get("name"))

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

HF account: CassandraMSull


In [33]:
!pip -q install -U "transformers>=4.44.0" "accelerate>=0.33.0" "huggingface_hub>=0.23.0" peft datasets

In [34]:
!pip install -q transformers peft datasets accelerate bitsandbytes

In [35]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
m = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tok = AutoTokenizer.from_pretrained(m, trust_remote_code=True)           # should work now (no 403)
mdl = AutoModelForCausalLM.from_pretrained(m, device_map="auto",
                                           torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                                           trust_remote_code=True)
print("Loaded")

tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]

Loaded


In [47]:
import argparse
import json
import logging
import os
from dataclasses import dataclass, field
from typing import Dict, List

from dataclasses import dataclass
from transformers.data.data_collator import default_data_collator
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

import torch
from datasets import Dataset
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    set_seed,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [54]:
@dataclass
class ModelArguments:
    model_name_or_path: str = field(default="meta-llama/Llama-3.2-8B-Instruct")
    trust_remote_code: bool = field(default=True)

@dataclass
class DataArguments:
    step1_path: str = field(default="step1.json")
    step2_path: str = field(default="step1_medmcqa.json")
    step3_path: str = field(default="step1_medqa.json")
    max_seq_length: int = field(default=512)
    num_samples: int = field(default=50)

@dataclass
class LoraArguments:
    lora_rank: int = field(default=8)
    lora_alpha: int = field(default=16)
    lora_dropout: float = field(default=0.1)

@dataclass
class DataCollatorCastBool:
    tokenizer: PreTrainedTokenizerBase

    def __call__(self, features):
        batch = default_data_collator(features)
        if "attention_mask" in batch:
            batch["attention_mask"] = batch["attention_mask"].bool()
        # (optional) for safety with labels if you add them
        if "labels" in batch and hasattr(batch["labels"], "dtype") and batch["labels"].dtype != torch.long:
            batch["labels"] = batch["labels"].long()
        return batch

class FlexibleDataset:
    """Dataset class that preserves original format and creates appropriate prompts"""

    def __init__(self, data_path: str, tokenizer, max_length: int = 512, num_samples: int = 50):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_samples = num_samples
        self.data = self.load_data(data_path)
        self.dataset_type = self.detect_dataset_type(data_path)
        logger.info(f"Detected dataset type: {self.dataset_type}")

    def load_data(self, data_path: str) -> List[Dict]:
        """Load data preserving original structure"""
        data = []
        try:
            with open(data_path, 'r', encoding='utf-8') as f:
                # Try loading as JSON array first
                try:
                    loaded_data = json.load(f)
                    if isinstance(loaded_data, list):
                        data = loaded_data[:self.num_samples]
                    else:
                        data = [loaded_data][:self.num_samples]
                except json.JSONDecodeError:
                    # If that fails, try JSONL format
                    f.seek(0)
                    for idx, line in enumerate(f):
                        if idx >= self.num_samples:
                            break
                        if line.strip():
                            data.append(json.loads(line.strip()))

            logger.info(f"Loaded {len(data)} examples from {data_path}")

            if len(data) == 0:
                raise ValueError(f"No data loaded from {data_path}")

            # Log first example structure
            logger.info(f"Sample keys: {list(data[0].keys())}")
            logger.info(f"First example preview: {str(data[0])[:200]}...")

        except Exception as e:
            logger.error(f"Error loading {data_path}: {e}")
            raise
        return data

    def detect_dataset_type(self, data_path: str) -> str:
        """Detect which dataset type based on filename or structure"""
        filename = os.path.basename(data_path).lower()

        if 'pubmed' in filename or 'step1.json' == filename:
            return 'pubmedqa'
        elif 'medmcqa' in filename:
            return 'medmcqa'
        elif 'medqa' in filename:
            return 'medqa'

        # Try to detect from structure
        if self.data and len(self.data) > 0:
            example = self.data[0]

            # PubMedQA typically has: question, context, long_answer, final_decision
            if 'final_decision' in example or 'long_answer' in example:
                return 'pubmedqa'

            # MedMCQA typically has: question, opa, opb, opc, opd, cop, exp
            if all(key in example for key in ['opa', 'opb', 'opc', 'opd']):
                return 'medmcqa'

            # MedQA typically has: question, answer, options
            if 'options' in example and isinstance(example.get('options'), dict):
                return 'medqa'

        return 'generic'

    def format_pubmedqa(self, example: Dict) -> str:
        """Format PubMedQA example preserving original structure"""
        parts = []

        # Question
        if 'QUESTION' in example:
            parts.append(f"Question: {example['QUESTION']}")
        elif 'question' in example:
            parts.append(f"Question: {example['question']}")

        # Context
        if 'CONTEXTS' in example:
            contexts = example['CONTEXTS']
            if isinstance(contexts, list):
                context_text = " ".join(contexts)
            else:
                context_text = str(contexts)
            parts.append(f"\nContext: {context_text}")
        elif 'context' in example:
            parts.append(f"\nContext: {example['context']}")

        # Answer
        if 'final_decision' in example:
            parts.append(f"\nAnswer: {example['final_decision']}")
        elif 'FINAL_DECISION' in example:
            parts.append(f"\nAnswer: {example['FINAL_DECISION']}")
        elif 'long_answer' in example:
            parts.append(f"\nAnswer: {example['long_answer']}")
        elif 'LONG_ANSWER' in example:
            parts.append(f"\nAnswer: {example['LONG_ANSWER']}")

        return "\n".join(parts)

    def format_medmcqa(self, example: Dict) -> str:
        """Format MedMCQA example preserving original structure"""
        parts = []

        # Question
        if 'question' in example:
            parts.append(f"Question: {example['question']}")

        # Options
        options_text = []
        if 'opa' in example:
            options_text.append(f"A) {example['opa']}")
        if 'opb' in example:
            options_text.append(f"B) {example['opb']}")
        if 'opc' in example:
            options_text.append(f"C) {example['opc']}")
        if 'opd' in example:
            options_text.append(f"D) {example['opd']}")

        if options_text:
            parts.append("\nOptions:\n" + "\n".join(options_text))

        # Correct answer
        if 'cop' in example:
            cop = example['cop']
            # Map cop to letter
            cop_map = {1: 'A', 2: 'B', 3: 'C', 4: 'D'}
            if isinstance(cop, int) and cop in cop_map:
                answer_letter = cop_map[cop]
                # Get the actual option text
                option_key = f"op{answer_letter.lower()}"
                if option_key in example:
                    parts.append(f"\nCorrect Answer: {answer_letter}) {example[option_key]}")
                else:
                    parts.append(f"\nCorrect Answer: {answer_letter}")
            else:
                parts.append(f"\nCorrect Answer: {cop}")

        # Explanation if available
        if 'exp' in example and example['exp']:
            parts.append(f"\nExplanation: {example['exp']}")

        return "\n".join(parts)

    def format_medqa(self, example: Dict) -> str:
        """Format MedQA example preserving original structure"""
        parts = []

        # Question
        if 'question' in example:
            parts.append(f"Question: {example['question']}")

        # Options
        if 'options' in example:
            options = example['options']
            if isinstance(options, dict):
                options_text = [f"{key}) {value}" for key, value in sorted(options.items())]
                parts.append("\nOptions:\n" + "\n".join(options_text))
            elif isinstance(options, list):
                options_text = [f"{chr(65+i)}) {opt}" for i, opt in enumerate(options)]
                parts.append("\nOptions:\n" + "\n".join(options_text))

        # Answer
        if 'answer' in example:
            parts.append(f"\nCorrect Answer: {example['answer']}")
        elif 'answer_idx' in example and 'options' in example:
            idx = example['answer_idx']
            if isinstance(example['options'], dict):
                answer_key = list(example['options'].keys())[idx] if idx < len(example['options']) else None
                if answer_key:
                    parts.append(f"\nCorrect Answer: {answer_key}) {example['options'][answer_key]}")
            else:
                parts.append(f"\nCorrect Answer: {chr(65+idx)}) {example['options'][idx]}")

        # Explanation if available
        if 'explanation' in example and example['explanation']:
            parts.append(f"\nExplanation: {example['explanation']}")

        return "\n".join(parts)

    def format_generic(self, example: Dict) -> str:
        """Generic formatter for unknown formats - just concatenate all text fields"""
        parts = []

        # Common field names to look for
        question_keys = ['question', 'QUESTION', 'query', 'input', 'prompt']
        answer_keys = ['answer', 'ANSWER', 'response', 'output', 'final_decision', 'FINAL_DECISION']
        context_keys = ['context', 'CONTEXTS', 'passage', 'background']

        # Add question
        for key in question_keys:
            if key in example and example[key]:
                parts.append(f"Question: {example[key]}")
                break

        # Add context if available
        for key in context_keys:
            if key in example and example[key]:
                context_val = example[key]
                if isinstance(context_val, list):
                    context_val = " ".join(str(c) for c in context_val)
                parts.append(f"\nContext: {context_val}")
                break

        # Add answer
        for key in answer_keys:
            if key in example and example[key]:
                parts.append(f"\nAnswer: {example[key]}")
                break

        # If nothing found, just stringify the whole example
        if not parts:
            parts.append(str(example))

        return "\n".join(parts)

    def format_example(self, example: Dict) -> str:
        """Format example based on detected dataset type"""
        if self.dataset_type == 'pubmedqa':
            return self.format_pubmedqa(example)
        elif self.dataset_type == 'medmcqa':
            return self.format_medmcqa(example)
        elif self.dataset_type == 'medqa':
            return self.format_medqa(example)
        else:
            return self.format_generic(example)

    def tokenize_function(self, examples):
        """Tokenize examples preserving original format"""
        formatted_texts = [self.format_example(ex) for ex in examples]

        model_inputs = self.tokenizer(
            formatted_texts,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors=None,
            add_special_tokens=True,
        )

        # Create labels (same as input_ids for causal LM)
        model_inputs["labels"] = model_inputs["input_ids"].copy()
        return model_inputs

    def get_dataset(self) -> Dataset:
        """Convert to HuggingFace dataset and tokenize"""
        dataset = Dataset.from_list(self.data)

        tokenized_dataset = dataset.map(
            lambda examples: self.tokenize_function([examples]),
            batched=False,
            remove_columns=dataset.column_names,
            desc="Tokenizing dataset",
        )
        return tokenized_dataset

def get_target_modules(model_name: str) -> List[str]:
    """Get appropriate target modules for LoRA based on model architecture"""
    model_name_lower = model_name.lower()

    if "llama" in model_name_lower:
        return ["q_proj", "k_proj", "v_proj", "o_proj"]
    elif "mistral" in model_name_lower:
        return ["q_proj", "k_proj", "v_proj", "o_proj"]
    else:
        return ["q_proj", "v_proj"]

def setup_lora_config(lora_args: LoraArguments, model_name: str):
    """Setup LoRA configuration"""
    target_modules = get_target_modules(model_name)
    logger.info(f"Using target modules: {target_modules}")

    return LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=lora_args.lora_rank,
        lora_alpha=lora_args.lora_alpha,
        lora_dropout=lora_args.lora_dropout,
        target_modules=target_modules,
        bias="none",
    )

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

def load_model_and_tokenizer(
    model_args: ModelArguments,
    lora_config: LoraConfig,
    checkpoint_path: str = None,
):
    """Load model and tokenizer, optionally from a checkpoint"""

    # choose dtype based on GPU availability
    _dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    if checkpoint_path and os.path.exists(checkpoint_path):
        logger.info(f"Loading model from checkpoint: {checkpoint_path}")

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            checkpoint_path,
            trust_remote_code=model_args.trust_remote_code,
            padding_side="right",
            use_fast=True,
        )

        # Load base model (add attn_implementation="eager")
        base_model = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=model_args.trust_remote_code,
            torch_dtype=_dtype,
            device_map="auto",
            attn_implementation="eager",   # <<< key fix
        )

        # Load LoRA weights
        model = PeftModel.from_pretrained(base_model, checkpoint_path)
        logger.info("Loaded model with previous LoRA weights")

    else:
        logger.info(f"Loading fresh model: {model_args.model_name_or_path}")

        # Tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=model_args.trust_remote_code,
            padding_side="right",
            use_fast=True,
        )

        # Ensure pad token exists
        added_pad = False
        if tokenizer.pad_token is None:
            if tokenizer.eos_token is not None:
                tokenizer.pad_token = tokenizer.eos_token
            else:
                tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                added_pad = True

        # Model (add attn_implementation="eager")
        model = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=model_args.trust_remote_code,
            torch_dtype=_dtype,
            device_map="auto",
            attn_implementation="eager",   # <<< key fix
        )

        # Resize embeddings only if we just added a brand-new pad token
        if added_pad:
            model.resize_token_embeddings(len(tokenizer))

        # Apply LoRA
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    return model, tokenizer

def train_stage(
    stage_name: str,
    data_path: str,
    model,
    tokenizer,
    training_args: TrainingArguments,
    data_args: DataArguments
):
    """Train a single stage"""
    logger.info(f"\n{'='*60}")
    logger.info(f"Starting training stage: {stage_name}")
    logger.info(f"Data path: {data_path}")
    logger.info(f"Output directory: {training_args.output_dir}")
    logger.info(f"{'='*60}\n")

    # Load dataset with flexible formatting
    try:
        flexible_dataset = FlexibleDataset(
            data_path=data_path,
            tokenizer=tokenizer,
            max_length=data_args.max_seq_length,
            num_samples=data_args.num_samples,
        )
        train_dataset = flexible_dataset.get_dataset()
        logger.info(f"Dataset loaded with {len(train_dataset)} examples")

        # Show a formatted example
        if len(flexible_dataset.data) > 0:
            logger.info(f"\nExample formatted text:\n{flexible_dataset.format_example(flexible_dataset.data[0])[:500]}...")

    except Exception as e:
        logger.error(f"Error loading dataset: {e}")
        raise

    # Data collator
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    # Trainer
    trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=None,
    tokenizer=tokenizer,
    data_collator=DataCollatorCastBool(tokenizer),
    )

    # Train
    logger.info("Starting training...")
    try:
        trainer.train()
        logger.info("Training completed successfully")
    except Exception as e:
        logger.error(f"Training failed: {e}")
        raise

    # Save model
    try:
        trainer.save_model()
        trainer.save_state()
        tokenizer.save_pretrained(training_args.output_dir)
        logger.info(f"Model saved to {training_args.output_dir}")
    except Exception as e:
        logger.error(f"Error saving model: {e}")
        raise

    return trainer.model

In [55]:
def main():
    parser = argparse.ArgumentParser(description="Sequential Medical QA fine-tuning for Llama 8B - Flexible Format")

    # Model arguments
    parser.add_argument("--model_name_or_path", default="meta-llama/Llama-3.2-8B-Instruct",
                       help="Base model to use")

    # Data arguments
    parser.add_argument("--step1_path", default="/content/step1.json",
                       help="Path to step 1 (PubMedQA) dataset")
    parser.add_argument("--step2_path", default="/content/step1_medmcqa.json",
                       help="Path to step 2 (MedMCQA) dataset")
    parser.add_argument("--step3_path", default="/content/step1_medqa.json",
                       help="Path to step 3 (MedQA) dataset")
    parser.add_argument("--num_samples", type=int, default=50,
                       help="Number of samples to use from each dataset")

    # Output arguments
    parser.add_argument("--output_dir", default="./llama8b_checkpoints",
                       help="Base output directory for checkpoints")

    # Training arguments
    parser.add_argument("--max_seq_length", type=int, default=512)
    parser.add_argument("--num_epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=2e-4)

    # LoRA arguments
    parser.add_argument("--lora_rank", type=int, default=8)
    parser.add_argument("--lora_alpha", type=int, default=16)
    parser.add_argument("--lora_dropout", type=float, default=0.1)

    # Execution control
    parser.add_argument("--start_step", type=int, default=1, choices=[1, 2, 3],
                       help="Which step to start from (1, 2, or 3)")
    parser.add_argument("--end_step", type=int, default=3, choices=[1, 2, 3],
                       help="Which step to end at (1, 2, or 3)")
    parser.add_argument("--seed", type=int, default=42)

    args = parser.parse_args()

    # Validate step arguments
    if args.start_step > args.end_step:
        raise ValueError("start_step must be <= end_step")

    print("="*70)
    print("SEQUENTIAL MEDICAL QA FINE-TUNING - FLEXIBLE FORMAT")
    print("="*70)
    print(f"Base Model: {args.model_name_or_path}")
    print(f"Training Steps: {args.start_step} to {args.end_step}")
    print(f"Samples per dataset: {args.num_samples}")
    print(f"Preserving original dataset formats")
    print("="*70)

    set_seed(args.seed)

    # Setup arguments
    model_args = ModelArguments(model_name_or_path=args.model_name_or_path)
    data_args = DataArguments(
        step1_path=args.step1_path,
        step2_path=args.step2_path,
        step3_path=args.step3_path,
        max_seq_length=args.max_seq_length,
        num_samples=args.num_samples,
    )
    lora_args = LoraArguments(
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
    )

    # Setup LoRA config
    lora_config = setup_lora_config(lora_args, args.model_name_or_path)

    # Define stages
    stages = [
        {
            "step": 1,
            "name": "PubMedQA",
            "data_path": args.step1_path,
            "output_dir": os.path.join(args.output_dir, "step1_pubmedqa")
        },
        {
            "step": 2,
            "name": "MedMCQA",
            "data_path": args.step2_path,
            "output_dir": os.path.join(args.output_dir, "step2_medmcqa")
        },
        {
            "step": 3,
            "name": "MedQA",
            "data_path": args.step3_path,
            "output_dir": os.path.join(args.output_dir, "step3_medqa")
        },
    ]

    # Filter stages based on start and end step
    stages_to_run = [s for s in stages if args.start_step <= s["step"] <= args.end_step]

    # Initialize model
    model = None
    tokenizer = None
    previous_checkpoint = None

    # Run each stage sequentially
    for stage in stages_to_run:
        stage_num = stage["step"]
        stage_name = stage["name"]

        print(f"\n{'#'*70}")
        print(f"# STEP {stage_num}: {stage_name.upper()}")
        print(f"{'#'*70}\n")

        # Create output directory
        os.makedirs(stage["output_dir"], exist_ok=True)

        # Load model (fresh for step 1, from checkpoint for subsequent steps)
        if stage_num == 1 or model is None:
            model, tokenizer = load_model_and_tokenizer(
                model_args,
                lora_config,
                checkpoint_path=previous_checkpoint
            )
        else:
            logger.info(f"Continuing with model from previous step")

        # Setup training arguments for this stage
        training_args = TrainingArguments(
            output_dir=stage["output_dir"],
            num_train_epochs=args.num_epochs,
            per_device_train_batch_size=args.batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            learning_rate=args.learning_rate,
            weight_decay=0.01,
            logging_steps=10,
            save_steps=100,
            save_total_limit=2,
            fp16=True,
            dataloader_num_workers=2,
            remove_unused_columns=False,
            report_to="none",
            load_best_model_at_end=False,
            warmup_steps=10,
        )

        # Train this stage
        model = train_stage(
            stage_name=stage_name,
            data_path=stage["data_path"],
            model=model,
            tokenizer=tokenizer,
            training_args=training_args,
            data_args=data_args,
        )

        # Update checkpoint path for next stage
        previous_checkpoint = stage["output_dir"]

        print(f"\n✓ Step {stage_num} ({stage_name}) completed!")
        print(f"  Model saved to: {stage['output_dir']}")

    print("\n" + "="*70)
    print("ALL TRAINING STAGES COMPLETED SUCCESSFULLY!")
    print("="*70)
    print(f"\nFinal model location: {stages_to_run[-1]['output_dir']}")

In [56]:
# Check if data files exist
import os

print("Checking data files...")
data_files = [
    "/content/step1.json",
    "/content/step1_medmcqa.json",
    "/content/step1_medqa.json"
]

for file in data_files:
    if os.path.exists(file):
        print(f"✓ Found: {file}")
    else:
        print(f"✗ Missing: {file}")

# Check GPU
import torch
if torch.cuda.is_available():
    print(f"\n✓ GPU Available: {torch.cuda.get_device_name(0)}")
    print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("\n✗ No GPU available - training will be very slow!")

Checking data files...
✓ Found: /content/step1.json
✓ Found: /content/step1_medmcqa.json
✓ Found: /content/step1_medqa.json

✓ GPU Available: NVIDIA A100-SXM4-80GB
  Memory: 85.17 GB


In [57]:
import os, torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # or "meta-llama/Meta-Llama-3-8B-Instruct"

print("CUDA available:", torch.cuda.is_available())
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
print("dtype:", dtype)

try:
    tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    print("✅ Tokenizer loaded. vocab_size:", tok.vocab_size)
except Exception as e:
    print("❌ Tokenizer failed:", type(e).__name__, e)

try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        torch_dtype=dtype,
        device_map="auto",           # will put weights on GPU if available
        # DO NOT set load_in_4bit on Py3.12 unless you know bitsandbytes works in your runtime
    )
    first_param = next(model.parameters())
    print("✅ Model loaded. Device:", first_param.device, "dtype:", first_param.dtype)
except Exception as e:
    print("❌ Model failed:", type(e).__name__, e)

CUDA available: True
dtype: torch.float16
✅ Tokenizer loaded. vocab_size: 128000


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

✅ Model loaded. Device: cuda:0 dtype: torch.float16


In [58]:
import os
import argparse
import torch

def main(argv=None):
    parser = argparse.ArgumentParser(description="Sequential Medical QA fine-tuning for Llama 8B - Flexible Format")

    # Model arguments
    parser.add_argument("--model_name_or_path", default="meta-llama/Llama-3.2-8B-Instruct",
                        help="Base model to use")

    # Data arguments
    parser.add_argument("--step1_path", default="/content/step1.json",
                        help="Path to step 1 (PubMedQA) dataset")
    parser.add_argument("--step2_path", default="/content/step1_medmcqa.json",
                        help="Path to step 2 (MedMCQA) dataset")
    parser.add_argument("--step3_path", default="/content/step1_medqa.json",
                        help="Path to step 3 (MedQA) dataset")
    parser.add_argument("--num_samples", type=int, default=50,
                        help="Number of samples to use from each dataset")

    # Output arguments
    parser.add_argument("--output_dir", default="./llama8b_checkpoints",
                        help="Base output directory for checkpoints")

    # Training arguments
    parser.add_argument("--max_seq_length", type=int, default=512)
    parser.add_argument("--num_epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=2e-4)

    # LoRA arguments
    parser.add_argument("--lora_rank", type=int, default=8)
    parser.add_argument("--lora_alpha", type=int, default=16)
    parser.add_argument("--lora_dropout", type=float, default=0.1)

    # Execution control
    parser.add_argument("--start_step", type=int, default=1, choices=[1, 2, 3],
                        help="Which step to start from (1, 2, or 3)")
    parser.add_argument("--end_step", type=int, default=3, choices=[1, 2, 3],
                        help="Which step to end at (1, 2, or 3)")
    parser.add_argument("--seed", type=int, default=42)

    # NOTE: parse provided argv if given; otherwise use real CLI
    if argv is None:
        args = parser.parse_args()
    else:
        args = parser.parse_args(argv)

    # Quick sanity checks help a ton in notebooks
    for p in [args.step1_path, args.step2_path, args.step3_path]:
        if not os.path.exists(p):
            raise FileNotFoundError(f"Dataset not found: {p}")

    # ---- rest of your function stays the same, with two small tweaks below ----

    set_seed(args.seed)
    model_args = ModelArguments(model_name_or_path=args.model_name_or_path)
    data_args = DataArguments(
        step1_path=args.step1_path,
        step2_path=args.step2_path,
        step3_path=args.step3_path,
        max_seq_length=args.max_seq_length,
        num_samples=args.num_samples,
    )
    lora_args = LoraArguments(
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
    )
    lora_config = setup_lora_config(lora_args, args.model_name_or_path)

    stages = [
        {"step": 1, "name": "PubMedQA", "data_path": args.step1_path,
         "output_dir": os.path.join(args.output_dir, "step1_pubmedqa")},
        {"step": 2, "name": "MedMCQA", "data_path": args.step2_path,
         "output_dir": os.path.join(args.output_dir, "step2_medmcqa")},
        {"step": 3, "name": "MedQA", "data_path": args.step3_path,
         "output_dir": os.path.join(args.output_dir, "step3_medqa")},
    ]
    stages_to_run = [s for s in stages if args.start_step <= s["step"] <= args.end_step]

    model = None
    tokenizer = None
    previous_checkpoint = None

    for stage in stages_to_run:
        stage_num = stage["step"]
        stage_name = stage["name"]

        os.makedirs(stage["output_dir"], exist_ok=True)

        if stage_num == 1 or model is None:
            model, tokenizer = load_model_and_tokenizer(
                model_args, lora_config, checkpoint_path=previous_checkpoint
            )

        # 🔧 notebook/gpu-safe tweaks:
        use_fp16 = torch.cuda.is_available()
        training_args = TrainingArguments(
            output_dir=stage["output_dir"],
            num_train_epochs=args.num_epochs,
            per_device_train_batch_size=args.batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            learning_rate=args.learning_rate,
            weight_decay=0.01,
            logging_steps=10,
            save_steps=100,
            save_total_limit=2,
            fp16=use_fp16,               # only use fp16 if GPU exists
            dataloader_num_workers=0,    # avoid multiprocessing issues in notebooks/Windows
            remove_unused_columns=False,
            report_to="none",
            load_best_model_at_end=False,
            warmup_steps=10,
        )

        model = train_stage(
            stage_name=stage_name,
            data_path=stage["data_path"],
            model=model,
            tokenizer=tokenizer,
            training_args=training_args,
            data_args=data_args,
        )

        previous_checkpoint = stage["output_dir"]

    print(f"\nFinal model location: {stages_to_run[-1]['output_dir']}")

In [52]:
MODEL_IDS = [
    "meta-llama/Meta-Llama-3.1-8B-Instruct",  # most common
    "meta-llama/Meta-Llama-3-8B-Instruct",    # older 3.0 line
]
from huggingface_hub import HfApi
api = HfApi()
for mid in MODEL_IDS:
    try:
        info = api.model_info(mid)  # uses your login
        print("OK:", mid, "→", info.sha[:8])
    except Exception as e:
        print("NO ACCESS / NOT FOUND:", mid, "→", e)

OK: meta-llama/Meta-Llama-3.1-8B-Instruct → 0e9e39f2
OK: meta-llama/Meta-Llama-3-8B-Instruct → 8afb486c


In [59]:
main([
    '--model_name_or_path', 'meta-llama/Meta-Llama-3.1-8B-Instruct',
    '--step1_path', '/content/step1.json',
    '--step2_path', '/content/step1_medmcqa.json',
    '--step3_path', '/content/step1_medqa.json',
    '--output_dir', './llama8b_checkpoints',
    '--num_samples', '50',
    '--num_epochs', '1',
    '--batch_size', '1',
    '--gradient_accumulation_steps', '8',
    '--max_seq_length', '256',
    '--learning_rate', '2e-4',
    '--start_step', '1',
    '--end_step', '3',
])

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

trainable params: 6,815,744 || all params: 8,037,076,992 || trainable%: 0.0848


Tokenizing dataset:   0%|          | 0/1 [00:00<?, ? examples/s]

  trainer = Trainer(
The model is already on multiple devices. Skipping the move to device specified in `args`.
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 128009, 'pad_token_id': 128009}.
ERROR:__main__:Training failed: too many values to unpack (expected 4)


ValueError: too many values to unpack (expected 4)

# testing the model


In [8]:
import argparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

In [44]:
# Complete Testing Code for Colab
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from peft import PeftModel

def load_model_and_tokenizer(model_args, lora_config, checkpoint_path=None, hf_token=None):
    model_id = model_args.model_name_or_path

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        attn_implementation="eager",   # <-- add this
    )

    # (apply LoRA etc.)
    return model, tokenizer

def generate_response(model, tokenizer, prompt: str, max_new_tokens: int = 150):
    """Generate a response for the given prompt"""
    formatted_prompt = f"### Instruction:\n{prompt}\n\n### Response:\n"

    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    if "### Response:" in response:
        response = response.split("### Response:")[-1].strip()

    return response

def test_sample_questions(model_path):
    """Test with predefined questions"""
    model, tokenizer = load_model(model_path)

    test_questions = [
        "What is the normal body temperature in Celsius?",
        "What are the main symptoms of diabetes?",
        "How is hypertension typically diagnosed?",
        "What is the purpose of a complete blood count (CBC) test?",
        "What are the risk factors for cardiovascular disease?",
    ]

    print("="*70)
    print("TESTING MODEL WITH SAMPLE QUESTIONS")
    print("="*70)

    for i, question in enumerate(test_questions, 1):
        print(f"\n{'='*70}")
        print(f"Question {i}: {question}")
        print("-"*70)

        response = generate_response(model, tokenizer, question, max_new_tokens=150)
        print(f"Response: {response}")

    print(f"\n{'='*70}")
    print("Testing complete!")

    return model, tokenizer

def test_interactive(model_path):
    """Interactive testing mode"""
    model, tokenizer = load_model(model_path)

    print("="*70)
    print("INTERACTIVE MODE - Type 'quit' to exit")
    print("="*70)

    while True:
        print("\n" + "-"*70)
        question = input("Enter your medical question: ").strip()

        if question.lower() in ['quit', 'exit', 'q']:
            print("Exiting...")
            break

        if not question:
            continue

        print("\nGenerating response...\n")
        response = generate_response(model, tokenizer, question, max_new_tokens=150)
        print(f"Response: {response}")

    return model, tokenizer

def test_single_question(model_path, question, max_new_tokens=150):
    """Test with a single question"""
    model, tokenizer = load_model(model_path)

    print(f"Question: {question}")
    print("-"*70)
    response = generate_response(model, tokenizer, question, max_new_tokens)
    print(f"Response: {response}")

    return model, tokenizer

In [12]:
MODEL_PATH = "./llama8b_checkpoints/step3_medqa"
model, tokenizer = test_sample_questions(MODEL_PATH)

Loading tokenizer from ./llama8b_checkpoints/step3_medqa...


HFValidationError: Repo id must be in the form 'repo_name' or 'namespace/repo_name': './llama8b_checkpoints/step3_medqa'. Use `repo_type` argument if needed.