In [None]:

!pip install -q unsloth[colab-new] datasets transformers accelerate bitsandbytes wandb
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

import unsloth
from unsloth import FastLanguageModel, is_bfloat16_supported
import os
import torch
import pandas as pd
from datasets import Dataset
from transformers import TrainingArguments
from trl import SFTTrainer
import gc

# Permit TF32 tensor-core math for cuDNN convolutions and float32 CUDA matmuls.
# Only GPUs with TF32 support (Ampere or newer) are affected; others ignore it.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

class MedicalConfig:
    """Hyper-parameters for LoRA fine-tuning of Llama 3.2 1B on medical Q&A.

    All settings are class attributes; the module-level ``config`` instance
    created right after the class is what the rest of the script reads.
    """

    # 4-bit quantized Llama 3.2 1B Instruct checkpoint published by Unsloth.
    model_name = "unsloth/Llama-3.2-1B-Instruct-bnb-4bit"

    # Token limit passed to both the model loader and the SFT trainer.
    max_seq_length = 2048

    # LoRA adapter rank / scaling / dropout.
    lora_r = 16
    lora_alpha = 16
    lora_dropout = 0

    # Examples per optimizer step per device = batch_size * gradient_accumulation (8).
    batch_size = 2
    gradient_accumulation = 4

    # Optimizer and learning-rate schedule.
    epochs = 3
    learning_rate = 2e-4
    weight_decay = 0.01
    warmup_ratio = 0.1

    # Mixed precision: use bf16 when the GPU supports it, otherwise fp16.
    # The two flags must never both be True -- TrainingArguments raises a
    # ValueError for that combination, which the old unconditional
    # ``fp16 = True`` triggered on every bf16-capable GPU.
    bf16 = is_bfloat16_supported()
    fp16 = not bf16
    seed = 42  # reproducibility

config = MedicalConfig()

def check_system_readiness():
    """Report whether a CUDA GPU is available for training.

    Prints a short banner. When CUDA is available, it also prints the name and
    total memory (decimal GB) of device 0, plus an OOM hint when that device
    has less than 12 GB.

    Returns:
        bool: True when ``torch.cuda.is_available()``; False otherwise.
    """
    print("Medical AI Training System Check")
    print("=" * 50)

    # Guard clause: bail out early when there is no CUDA device at all.
    if not torch.cuda.is_available():
        print("No GPU found! Enable GPU: Runtime → Change runtime type → T4 GPU")
        return False

    props = torch.cuda.get_device_properties(0)
    total_gb = props.total_memory / 1e9
    print(f"GPU Ready: {props.name}")
    print(f"GPU Memory: {total_gb:.1f} GB")
    if total_gb < 12:
        # Smaller cards may not fit the default batch/sequence settings.
        print(" Recommendation: Reduce batch_size or max_seq_length if OOM")
    return True

# Fail fast: nothing below can train without a CUDA device.
# RuntimeError rather than SystemError -- SystemError is reserved for
# internal interpreter errors, not environment/configuration problems.
if not check_system_readiness():
    raise RuntimeError("Please enable GPU before continuing")

def load_medical_dataset(filename="medDataset_processed.csv", min_question_len=10, min_answer_len=20):
    """Load medical Q&A pairs from a CSV file into a Hugging Face ``Dataset``.

    The CSV must have ``Question`` and ``Answer`` columns. An optional
    ``qtype`` column is carried through; missing values become ``"general"``.
    A row is dropped when its question or answer cell is empty/NaN, or when
    the stripped text is shorter than the minimum length.

    Args:
        filename: Path to the CSV file.
        min_question_len: Minimum stripped question length (characters) to keep.
        min_answer_len: Minimum stripped answer length (characters) to keep.

    Returns:
        datasets.Dataset with string columns ``question``, ``answer``, ``qtype``.

    Raises:
        KeyError: If the ``Question`` or ``Answer`` column is absent.
        ValueError: If no rows survive filtering (an empty dataset would only
            fail later, and less clearly, during the train/test split).
    """
    df = pd.read_csv(filename)
    print(f"Loaded {len(df)} medical Q&A pairs")

    # No qtype column at all -> treat every row as missing (-> "general").
    qtypes = df["qtype"] if "qtype" in df.columns else [None] * len(df)

    records = []
    for raw_q, raw_a, raw_type in zip(df["Question"], df["Answer"], qtypes):
        # Skip missing cells explicitly instead of relying on str(NaN) == "nan"
        # happening to be shorter than the length thresholds.
        if pd.isna(raw_q) or pd.isna(raw_a):
            continue
        q, a = str(raw_q).strip(), str(raw_a).strip()
        if len(q) < min_question_len or len(a) < min_answer_len:
            continue
        # Keep qtype a string so the resulting column never mixes float NaN
        # with str values (mixed types can break Arrow type inference).
        qtype = "general" if pd.isna(raw_type) else str(raw_type)
        records.append({"question": q, "answer": a, "qtype": qtype})

    print(f"Prepared {len(records)} high-quality medical Q&A pairs")
    if not records:
        raise ValueError(f"No usable Q&A pairs found in {filename!r}")
    return Dataset.from_list(records)

def create_medical_prompt(question, answer, tokenizer):
    """Render one Q&A pair as training text via the tokenizer's chat template.

    Builds a three-turn conversation (fixed system prompt, user question,
    assistant answer). It returns the untokenized string without a trailing
    generation prompt. The exact markup is whatever chat template the
    tokenizer defines; no ChatML formatting is applied here.
    """
    system_turn = {"role": "system", "content": "You are a knowledgeable medical assistant providing helpful information."}
    user_turn = {"role": "user", "content": question}
    assistant_turn = {"role": "assistant", "content": answer}

    return tokenizer.apply_chat_template(
        [system_turn, user_turn, assistant_turn],
        tokenize=False,
        add_generation_prompt=False,
    )

def prepare_training_data(dataset, tokenizer):
    """Collapse question/answer rows into a single ``text`` column for SFT.

    Each row is rendered with ``create_medical_prompt``. Every original column
    is removed, so only ``text`` remains.

    Args:
        dataset: Dataset with ``question`` and ``answer`` columns.
        tokenizer: Tokenizer whose chat template formats each conversation.

    Returns:
        The mapped dataset containing just the ``text`` column.
    """
    def _to_text(examples):
        pairs = zip(examples["question"], examples["answer"])
        return {"text": [create_medical_prompt(q, a, tokenizer) for q, a in pairs]}

    formatted = dataset.map(
        _to_text,
        batched=True,
        remove_columns=dataset.column_names,
    )
    print(f"Formatted {len(formatted)} examples for medical training")
    return formatted

def initialize_medical_model():
    """Load the 4-bit base model and attach LoRA adapters for fine-tuning.

    Model name, sequence length, LoRA settings and seed come from the
    module-level ``config``. Prints total/trainable parameter counts.

    Returns:
        (model, tokenizer): the LoRA-wrapped model and its tokenizer,
        configured for right-side padding.
    """
    print("Loading Llama 3.2 1B for medical training...")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=config.model_name,
        max_seq_length=config.max_seq_length,
        dtype=None,
        load_in_4bit=True,
        device_map="auto",
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=config.lora_r,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=config.seed,
    )
    # Only fall back to EOS when the tokenizer ships without a pad token.
    # Unconditionally reusing EOS as padding means collators that mask
    # pad_token_id in the labels also mask the genuine end-of-turn EOS,
    # so the model may never learn when to stop generating.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    total = model.num_parameters()
    trainable = model.num_parameters(only_trainable=True)
    print(f"Total params: {total:,}, Trainable: {trainable:,} ({100*trainable/total:.1f}%)")
    return model, tokenizer

def setup_medical_training():
    """Build ``TrainingArguments`` from the module-level ``config``.

    Logs every 25 steps. Every 100 steps it saves a checkpoint to
    ``./medical-ai-results`` (keeping the two most recent) and evaluates the
    eval dataset.

    Returns:
        transformers.TrainingArguments
    """
    return TrainingArguments(
        output_dir="./medical-ai-results",
        num_train_epochs=config.epochs,
        per_device_train_batch_size=config.batch_size,
        gradient_accumulation_steps=config.gradient_accumulation,
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
        lr_scheduler_type="cosine",
        warmup_ratio=config.warmup_ratio,
        # fp16 and bf16 are mutually exclusive (TrainingArguments raises if
        # both are True); bf16 takes precedence when it is enabled.
        fp16=config.fp16 and not config.bf16,
        bf16=config.bf16,
        gradient_checkpointing=True,
        dataloader_num_workers=0,
        logging_steps=25,
        save_steps=100,
        # eval_steps has no effect unless an eval strategy is set; the default
        # is "no", which would leave the held-out split unused.
        eval_strategy="steps",
        eval_steps=100,
        save_total_limit=2,
        remove_unused_columns=False,
        group_by_length=True,
        seed=config.seed,
    )

def test_medical_ai(model, tokenizer, questions=None, max_new_tokens=200):
    """Generate and print answers to sample medical questions.

    Puts the model into Unsloth inference mode. Then, for each question, it
    builds a system + user chat prompt, samples a reply, and prints only the
    newly generated text (all of its lines).

    Args:
        model: Fine-tuned model loaded via FastLanguageModel.
        tokenizer: Matching tokenizer with a chat template.
        questions: Questions to ask; defaults to five built-in examples.
        max_new_tokens: Maximum number of tokens generated per answer.
    """
    if questions is None:
        questions = [
            "What are the early warning signs of heart disease?",
            "How is diabetes typically diagnosed?",
            "What should someone do if they suspect they have pneumonia?",
            "What are the common side effects of blood pressure medications?",
            "How can someone prevent the spread of infectious diseases?"
        ]
    FastLanguageModel.for_inference(model)
    for i, q in enumerate(questions, 1):
        print(f"\nTest {i}: {q}")
        conv = [
            {"role": "system", "content": "You are a knowledgeable medical assistant providing helpful information."},
            {"role": "user", "content": q},
        ]
        prompt = tokenizer.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)
        # The rendered Llama 3 chat template already begins with BOS; letting
        # the tokenizer add special tokens again would duplicate it.
        inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
        prompt_len = inputs["input_ids"].shape[1]
        with torch.no_grad():
            out = model.generate(**inputs, max_new_tokens=max_new_tokens, temperature=0.7,
                                 do_sample=True, pad_token_id=tokenizer.eos_token_id,
                                 repetition_penalty=1.1)
        # generate() returns prompt + continuation; slice off the prompt so the
        # full (possibly multi-line) answer is printed, not just its last line.
        resp = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
        print(resp.strip())
        print("-" * 50)

def train_medical_ai(filename="medDataset_processed.csv"):
    """Run the full pipeline: load data, LoRA fine-tune, save, then smoke-test.

    Args:
        filename: CSV of medical Q&A pairs passed to ``load_medical_dataset``.

    Returns:
        (model, tokenizer) on success, or (None, None) if ``trainer.train()``
        raised a RuntimeError (CUDA out-of-memory errors are RuntimeErrors).
        The adapter and tokenizer are saved to ``medical-ai-lora`` on success.
    """
    print("Starting Medical AI Training Pipeline")
    ds = load_medical_dataset(filename)
    model, tokenizer = initialize_medical_model()
    processed = prepare_training_data(ds, tokenizer)
    # Hold out 10% for evaluation; seeded so the split is reproducible.
    split = processed.train_test_split(test_size=0.1, seed=config.seed)
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=split["train"],
        eval_dataset=split["test"],
        dataset_text_field="text",
        max_seq_length=config.max_seq_length,
        args=setup_medical_training(),
        packing=False
    )
    print("Training...")
    # Collect unreachable Python objects first so any CUDA tensors they hold
    # are returned to the caching allocator *before* the cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()
    try:
        trainer.train()
    except RuntimeError as e:
        print("Training error:", e)
        return None, None
    print("Training completed!")

    model.save_pretrained("medical-ai-lora")
    tokenizer.save_pretrained("medical-ai-lora")
    print("Saved model to medical-ai-lora")
    print("Running tests:")
    test_medical_ai(model, tokenizer)
    return model, tokenizer

def clear_memory():
    """Free unreferenced Python objects and cached CUDA memory, then report usage.

    On a CUDA machine, prints the memory currently allocated by tensors versus
    the total memory of device 0 (decimal GB). On CPU-only machines it only
    runs the garbage collector.
    """
    # Collect first: tensors kept alive only by reference cycles are freed into
    # PyTorch's caching allocator, so the empty_cache() below can release them.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        used = torch.cuda.memory_allocated() / 1e9
        total = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"GPU Memory: {used:.1f}GB used of {total:.1f}GB")

if __name__ == "__main__":
    print("Welcome to Medical AI Training!")
    clear_memory()
    model, tokenizer = train_medical_ai()
    # train_medical_ai signals failure with (None, None); compare against None
    # explicitly instead of relying on the truthiness of a torch module.
    if model is not None:
        print("SUCCESS! Your Medical AI is ready.")
    else:
        print("Training failed. Check errors above.")
