In [None]:
import torch
# Report GPU availability. The device name is only queried when CUDA is
# actually usable: torch.cuda.get_device_name() raises on a machine without a
# CUDA device, which would abort the cell right after printing "False".
gpu_available = torch.cuda.is_available()
print(f"GPU Available: {gpu_available}")
if gpu_available:
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
else:
    print("GPU Name: n/a (no CUDA device detected)")

In [None]:
# Environment diagnostics, run as IPython shell escapes (`!cmd`): GPU status,
# the CUDA compiler version, and every installed package whose name
# contains "torch".
!nvidia-smi  # Check GPU info
!nvcc --version  # Check CUDA version
!pip list | grep torch  # Check PyTorch version

In [None]:
# NOTE(review): this installs whatever trl version is current; the following
# cell installs a pinned trl==0.7.10, which supersedes this install. This cell
# looks redundant -- confirm before removing.
!pip install trl

In [None]:
# First, ensure compatible package versions are installed
# (-q keeps pip's output quiet; every package is pinned to an exact version).
# NOTE(review): torch is already imported in the first cell, so if this pin
# changes the installed torch version the kernel presumably needs a restart
# before the new version is the one in use -- confirm.
!pip install -q torch==2.1.2 transformers==4.36.2 accelerate==0.25.0 peft==0.7.1 bitsandbytes==0.41.3 datasets==2.16.1 trl==0.7.10

import json
import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer

# --------------------------------------------------
# 1. Data Preparation
# --------------------------------------------------
DATA_PATH = '/kaggle/input/medical-qa-json/medical_qna_complete_50.json'


def _format_response(response):
    """Render a record's "Response" value as the text after "### Response:".

    A dict is flattened to one "key: value" line per entry. Any other value
    (e.g. a plain string) is used as its string form rather than crashing on
    the missing ``.items()``.
    """
    if isinstance(response, dict):
        return "\n".join(f"{k}: {v}" for k, v in response.items())
    return str(response)


def _build_examples(records):
    """Convert raw records into ``{"text": ...}`` rows for supervised fine-tuning.

    Entries that are not dicts, or that lack an "Instruction" or "Response"
    key, are skipped.
    """
    examples = []
    for item in records:
        if not isinstance(item, dict):
            continue  # Skip malformed entries
        if "Instruction" not in item or "Response" not in item:
            continue  # Skip malformed entries

        instruction = item["Instruction"]
        response = _format_response(item["Response"])
        examples.append({
            "text": f"### Instruction:\n{instruction}\n\n### Response:\n{response}"
        })
    return examples


# Load the raw records. A missing file and invalid JSON are reported as
# different errors, each chained to its original cause; other I/O errors
# propagate unchanged instead of being mislabelled as "file not found".
try:
    with open(DATA_PATH, 'r', encoding='utf-8') as f:
        data = json.load(f)
except FileNotFoundError as e:
    raise FileNotFoundError(f"Error loading dataset: {e}") from e
except json.JSONDecodeError as e:
    raise ValueError(f"Error loading dataset: invalid JSON in {DATA_PATH}: {e}") from e

formatted_data = _build_examples(data)

# Fail early with a clear message instead of handing an empty dataset to the trainer.
if not formatted_data:
    raise ValueError(
        f"No usable records in {DATA_PATH}: expected a JSON list of objects "
        'with "Instruction" and "Response" keys'
    )

dataset = Dataset.from_list(formatted_data)

# --------------------------------------------------
# 2. Model Setup
# --------------------------------------------------
model_name = "microsoft/phi-2"

# 4-bit NF4 quantization with double quantization; compute dtype is float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# Load the tokenizer. There is no fallback: any failure is re-raised as a
# RuntimeError chained to the original exception so the root cause stays
# visible in the traceback.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
except Exception as e:
    raise RuntimeError(f"Tokenizer loading failed: {e}") from e

# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Load the quantized base model; device placement is delegated to
# device_map="auto". Failures are re-raised as a chained RuntimeError.
try:
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=torch.float16,
    )
except Exception as e:
    raise RuntimeError(f"Model loading failed: {e}") from e

# --------------------------------------------------
# 3. PEFT Configuration
# --------------------------------------------------
# LoRA adapter settings: rank 4, alpha 8, 5% dropout, no bias training.
# NOTE(review): target_modules must match submodule names of the loaded model;
# confirm "q_proj"/"k_proj"/"v_proj"/"dense" against model.named_modules().
peft_config = LoraConfig(
    r=4,
    lora_alpha=8,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "dense"],
)

# Prepare the quantized model for training and attach the LoRA adapters.
# Failures are re-raised as a RuntimeError chained to the original exception
# so the underlying cause is not lost.
try:
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, peft_config)
except Exception as e:
    raise RuntimeError(f"PEFT setup failed: {e}") from e

model.print_trainable_parameters()

# --------------------------------------------------
# 4. Training Setup
# --------------------------------------------------
training_args = TrainingArguments(
    # Output / checkpointing / logging
    output_dir="./phi2-medical-qa",
    save_strategy="epoch",
    logging_steps=5,
    report_to="none",
    # Batch shape: 1 sample per device x 8 accumulation steps per optimizer update
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    num_train_epochs=3,
    # Optimizer and learning-rate schedule
    optim="adamw_torch",
    learning_rate=5e-5,
    lr_scheduler_type="linear",
    warmup_ratio=0.1,
    max_grad_norm=0.3,
    # Precision and memory
    fp16=True,
    gradient_checkpointing=True,
    # Misc
    ddp_find_unused_parameters=False,
    remove_unused_columns=True,
)

# Supervised fine-tuning on the "text" column, one example per sequence
# (packing disabled), with max_seq_length set to 256.
# NOTE(review): `model` was already wrapped by get_peft_model earlier in this
# script, yet peft_config is passed again here -- confirm this is intended.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=256,
    packing=False,
    peft_config=peft_config,
)

# --------------------------------------------------
# 5. Training Execution
# --------------------------------------------------
# Run the training loop; a failure is reported on stdout and then re-raised
# so the cell still stops with the original traceback.
try:
    print("Starting training...")
    trainer.train()
    print("Training completed successfully!")
except Exception as train_err:
    print(f"Training failed: {train_err}")
    raise

# Write the trained model and the tokenizer into the same directory.
# Errors are reported and re-raised, exactly as for training.
try:
    for _component in (trainer.model, tokenizer):
        _component.save_pretrained("phi2-medical-qa-lora")
    print("Model saved successfully!")
except Exception as save_err:
    print(f"Model saving failed: {save_err}")
    raise

# --------------------------------------------------
# 6. Inference Test (More reliable generation)
# --------------------------------------------------
def generate_response(question, max_new_tokens=200, temperature=0.7, top_k=50):
    """Generate an answer to ``question`` with the fine-tuned model.

    The question is wrapped in the same "### Instruction / ### Response"
    template used for the training examples, then sampled from the
    module-level ``model`` using the module-level ``tokenizer``.

    Args:
        question: Text placed under the "### Instruction:" header.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature passed to ``model.generate``.
        top_k: Top-k sampling cutoff passed to ``model.generate``.

    Returns:
        The decoded sequence returned by ``model.generate`` with special
        tokens stripped. On any failure, a string starting with
        "Error generating response:" is returned instead of raising, as in
        the original best-effort behaviour.
    """
    try:
        prompt = f"### Instruction:\n{question}\n\n### Response:\n"
        # Place the inputs on the model's own device instead of assuming "cuda".
        inputs = tokenizer(
            prompt, return_tensors="pt", return_attention_mask=True
        ).to(model.device)

        # Sample in eval mode so dropout (e.g. the LoRA dropout configured for
        # training) is inactive, then restore whichever mode the model had.
        was_training = model.training
        model.eval()
        try:
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                top_k=top_k,
                pad_token_id=tokenizer.eos_token_id,
            )
        finally:
            model.train(was_training)

        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error generating response: {e}"

# Smoke test: send one sample question through generate_response and print
# the result; any exception is reported on stdout rather than raised.
try:
    print("\nTest generation:")
    sample_answer = generate_response("What is anemia?")
    print(sample_answer)
except Exception as err:
    print(f"Test generation failed: {err}")