In [None]:
%%capture
# Install Unsloth and the packages it needs; %%capture hides the pip output.
import os
# The session is treated as Colab when any environment variable name contains
# "COLAB_"; every other environment takes the plain `pip install unsloth` path.
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab and Kaggle notebooks! Otherwise use pip install unsloth
    # NOTE(review): --no-deps presumably keeps pip from replacing packages that
    # are preinstalled on the hosted runtime — confirm before changing.
    # Only xformers is version-pinned here; the other packages float.
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29 peft trl triton
    !pip install --no-deps cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
    !pip install --no-deps unsloth

In [None]:
from unsloth import FastLanguageModel
import torch

# Maximum sequence length in tokens; reused later when the trainer is built.
max_seq_length = 4096
# None = auto-detect (float16 on Tesla T4 / V100, bfloat16 on Ampere+).
dtype = None
# Load 4-bit quantized weights to cut memory usage; may be set to False.
load_in_4bit = True

# Catalogue of pre-quantized 4-bit checkpoints published by Unsloth
# (more at https://huggingface.co/unsloth). Kept for reference only: nothing
# in the visible notebook reads this list.
fourbit_models = [
    # Llama 3.1 family
    "unsloth/Meta-Llama-3.1-8B-bnb-4bit",
    "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
    "unsloth/Meta-Llama-3.1-70B-bnb-4bit",
    "unsloth/Meta-Llama-3.1-405B-bnb-4bit",
    # Mistral Nemo (12B) and Mistral 7B v0.3
    "unsloth/Mistral-Nemo-Base-2407-bnb-4bit",
    "unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit",
    "unsloth/mistral-7b-v0.3-bnb-4bit",
    "unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
    # Phi 3 / 3.5
    "unsloth/Phi-3.5-mini-instruct",
    "unsloth/Phi-3-medium-4k-instruct",
    # Gemma 2
    "unsloth/gemma-2-9b-bnb-4bit",
    "unsloth/gemma-2-27b-bnb-4bit",
]

# Load the Gemma-2 9B base model and its tokenizer with the settings above.
# For gated repositories, also pass token="hf_...".
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/gemma-2-9b",
    load_in_4bit=load_in_4bit,
    dtype=dtype,
    max_seq_length=max_seq_length,
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.2.15: Fast Gemma2 patching. Transformers: 4.48.3.
   \\   /|    GPU: Tesla T4. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.1+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.1.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/6.13G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/46.4k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

In [None]:
# Wrap the base model with LoRA adapters on the attention projections
# (q/k/v/o) and the MLP projections (gate/up/down).
model = FastLanguageModel.get_peft_model(
    model,
    # Adapter capacity: any rank > 0 works; 8, 16, 32, 64 or 128 are suggested.
    r=32,
    lora_alpha=32,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    # Dropout 0 and bias "none" are the settings Unsloth optimizes for.
    lora_dropout=0,
    bias="none",
    # "unsloth" checkpointing trades compute for lower VRAM on long contexts;
    # True is the other accepted value.
    use_gradient_checkpointing="unsloth",
    # Rank-stabilized LoRA and LoftQ are available but left off here.
    use_rslora=False,
    loftq_config=None,
    random_state=3407,
)

Unsloth 2025.2.15 patched 42 layers with 42 QKV layers, 42 O layers and 42 MLP layers.


In [None]:
import pandas as pd

def load_medical_datasets_from_csv(csv_path):
    """Load the medical QA CSV and expand its reasoning column into a templated chain of thought.

    The file may use the canonical headers (``Question``, ``Answer``,
    ``Reasoning``) or the lowercase aliases (``question``, ``answer``,
    ``context``). Aliases are normalised *before* validation; previously the
    check ran first, so alias-only files were always rejected and the rename
    never had any effect.

    Args:
        csv_path: Anything accepted by ``pandas.read_csv`` (path or file-like).

    Returns:
        pandas.DataFrame whose ``Reasoning`` column has been replaced by a
        fixed template that embeds the original reasoning/context text and
        the row's ``Answer``. All other columns are returned as read.

    Raises:
        ValueError: if a required column is still missing after alias
            normalisation.
    """
    df = pd.read_csv(csv_path)

    # Map aliases onto canonical names, but never on top of a canonical column
    # that already exists (that would create duplicate column labels).
    aliases = {"question": "Question", "answer": "Answer", "context": "Reasoning"}
    rename_map = {
        alias: canonical
        for alias, canonical in aliases.items()
        if alias in df.columns and canonical not in df.columns
    }
    df = df.rename(columns=rename_map)

    # Validate after renaming; sort so the message is deterministic.
    required_columns = {"Question", "Answer", "Reasoning"}
    missing = sorted(required_columns.difference(df.columns))
    if missing:
        raise ValueError(
            f"CSV file must contain {sorted(required_columns)} columns; missing: {missing}."
        )

    def generate_reasoning(row):
        """Wrap one row's context and answer in the fixed chain-of-thought template."""
        context = str(row["Reasoning"]).strip()
        # The template supplies the sentence-ending period itself; drop one
        # trailing period so the text does not read "evaluation.. I should".
        if context.endswith("."):
            context = context[:-1]
        return (
            f"First, I need to analyze the medical context: {context}. "
            "I should consider relevant anatomical factors, physiological processes, "
            "and potential pathologies. Next, I'll evaluate the question's key elements "
            "and apply evidence-based medical knowledge. After differential consideration, "
            f"I conclude the most clinically appropriate answer is: {row['Answer']}."
        )

    # Row-wise because the template needs both the context and the answer.
    df["Reasoning"] = df.apply(generate_reasoning, axis=1)

    return df

# Corrected dataset path
# NOTE(review): absolute path under /content — presumably the Colab upload
# location; the file must exist there before this cell runs.
dataset_path = "/content/medical_case_based_qa_dataset_100k_unique.csv"

# Load dataset using pandas
# The returned frame already has its Reasoning column rewritten into the
# templated chain-of-thought text.
df = load_medical_datasets_from_csv(dataset_path)

In [None]:
# Rename the templated reasoning column to "Thought", the key the prompt
# formatter reads from each row.
df = df.rename(columns={"Reasoning": "Thought"})

In [None]:
def format_medical_prompt(row):
    """Render one dataset row as a single training-example string.

    ``row`` must support key lookup for ``"Question"``, ``"Thought"`` and
    ``"Answer"``. The four-space indentation of every template line and the
    trailing newline-plus-indent are part of the emitted text.
    """
    template = """<|begin_of_text|>
    <|start_header_id|>user<|end_header_id|>
    [Clinical Case Presentation]
    {question}

    Required:
    1. Differential diagnosis
    2. Pathophysiological rationale
    3. Diagnostic confirmation method<|eot_id|>

    <|start_header_id|>assistant<|end_header_id|>
    [Clinical Reasoning]
    {reasoning}

    [Final Diagnosis]
    {answer}<|eot_id|>
    """
    # Values are substituted as data, so braces inside a field are not
    # re-interpreted as placeholders.
    fields = {
        "question": row["Question"],
        "reasoning": row["Thought"],
        "answer": row["Answer"],
    }
    return template.format(**fields)

# Build one training string per row; axis=1 hands each row to the formatter.
df["Formatted_Prompt"] = df.apply(format_medical_prompt, axis=1)

In [None]:
# Sanity check: list the columns present on the frame at this point.
print(df.columns)

Index(['Question', 'Answer', 'Thought', 'Formatted_Prompt'], dtype='object')


In [None]:
from trl import SFTTrainer
!pip install datasets
from datasets import Dataset
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported



In [None]:
# Convert the formatted prompts into a Hugging Face Dataset. preserve_index=False
# keeps a non-default pandas index from being added as an extra column.
dataset = Dataset.from_pandas(df[["Formatted_Prompt"]], preserve_index=False)

# Supervised fine-tuning on the pre-formatted text column.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="Formatted_Prompt",  # Train on the formatted text column
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,  # Can make training 5x faster for short sequences.
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,  # effective batch size = 2 * 4 = 8
        warmup_ratio=0.1,
        # Training length is fixed by step count. A positive max_steps overrides
        # num_train_epochs, so the previous num_train_epochs=3 was never applied
        # (the run log reports 1 epoch / 500 steps). To train by epochs instead,
        # remove max_steps and set num_train_epochs.
        max_steps=500,
        learning_rate=5e-4,
        # Use Unsloth's capability check (imported alongside the trainer) so the
        # precision flags agree with the dtype Unsloth selected at load time.
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=10,
        # No evaluation is configured: "no" is the default strategy, so the
        # deprecated evaluation_strategy argument is simply omitted.
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="cosine",
        seed=3407,
        output_dir="medical_model",
    ),
)



Converting train dataset to ChatML (num_proc=2):   0%|          | 0/100000 [00:00<?, ? examples/s]

Applying chat template to train dataset (num_proc=2):   0%|          | 0/100000 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=2):   0%|          | 0/100000 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=2):   0%|          | 0/100000 [00:00<?, ? examples/s]

In [None]:
# Run fine-tuning; whatever train() returns is kept in trainer_stats.
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 100,000 | Num Epochs = 1
O^O/ \_/ \    Batch size per device = 2 | Gradient Accumulation steps = 4
\        /    Total batch size = 8 | Total steps = 500
 "-____-"     Number of trainable parameters = 108,036,096


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mshrishtisonkar195[0m ([33mshrishtisonkar195-ucer[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Step,Training Loss
10,1.7093
20,0.1682
30,0.072
40,0.0666
50,0.0645
60,0.0641
70,0.0636
80,0.0627
90,0.0624
100,0.063


In [None]:
# Switch the fine-tuned model to Unsloth's inference mode. The call configures
# `model` in place, so the return value is deliberately not assigned back:
# rebinding `model` would break the notebook on any Unsloth release where
# for_inference() returns nothing.
FastLanguageModel.for_inference(model)

def medical_chain_of_thought(question, max_length=128):
    """Generate clinical reasoning and a final diagnosis for one case description.

    The prompt reproduces the user turn emitted by ``format_medical_prompt``
    (including the "Required:" list and the end-of-turn marker) so inference
    sees the same layout the model was trained on.

    Args:
        question: Free-text clinical case presentation.
        max_length: Maximum number of *new* tokens to generate.

    Returns:
        Tuple ``(full_response, diagnosis)``:
          * ``full_response`` - the prompt followed by the assistant turn,
            cut at the first end-of-turn marker the model produced.
          * ``diagnosis`` - text following ``[Final Diagnosis]`` inside that
            assistant turn, or ``"Unspecified diagnosis"`` when the section is
            absent or empty.
    """
    end_of_turn = "<|eot_id|>"
    diagnosis_marker = "[Final Diagnosis]"

    prompt = f"""<|begin_of_text|>
    <|start_header_id|>user<|end_header_id|>
    [Clinical Case Presentation]
    {question}

    Required:
    1. Differential diagnosis
    2. Pathophysiological rationale
    3. Diagnostic confirmation method<|eot_id|>

    <|start_header_id|>assistant<|end_header_id|>
    [Clinical Reasoning]"""

    # Follow the model's device instead of assuming "cuda".
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    prompt_token_count = inputs["input_ids"].shape[-1]

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_length,
        temperature=0.3,
        top_p=0.85,
        repetition_penalty=1.15,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens so the prompt text (which itself
    # contains an end-of-turn marker) cannot interfere with the parsing below.
    completion = tokenizer.decode(
        outputs[0][prompt_token_count:], skip_special_tokens=True
    )

    # The end-of-turn marker is emitted as ordinary text and does not stop
    # generation, so the model may start a second, truncated answer after it.
    # Keep only the first assistant turn.
    completion = completion.split(end_of_turn, 1)[0]
    full_response = (prompt + completion).strip()

    # Take the section that follows the first marker in the assistant turn.
    diagnosis = "Unspecified diagnosis"
    if diagnosis_marker in completion:
        candidate = completion.split(diagnosis_marker, 1)[1].strip()
        if candidate:
            diagnosis = candidate

    return full_response, diagnosis

In [None]:
# Example 1: run the fine-tuned model on a single case and show both values
# returned by medical_chain_of_thought.
medical_question = "A 86-year-old female presents with severe headache and dizziness. EEG shows seizure activity. What is the most likely diagnosis?"
reasoning, diagnosis = medical_chain_of_thought(medical_question)
print(f"Clinical Reasoning:\n{reasoning}\n\nFinal Diagnosis: {diagnosis}")

Clinical Reasoning:
<|begin_of_text|>
    <|start_header_id|>user<|end_header_id|>
    [Clinical Case Presentation]
    A 86-year-old female presents with severe headache and dizziness. EEG shows seizure activity. What is the most likely diagnosis?


    <|start_header_id|>assistant<|end_header_id|>
    [Clinical Reasoning]
    First, I need to analyze the medical context: The patient's severe headache and dizziness combined with EEG shows seizure activity strongly suggest Diabetes mellitus. This condition is commonly associated with Emergency Medicine and requires further medical evaluation.. I should consider relevant anatomical factors, physiological processes, and potential pathologies. Next, I'll evaluate the question's key elements and apply evidence-based medical knowledge. After differential consideration, I conclude the most clinically appropriate answer is: Diabetes mellitus.

    [Final Diagnosis]
    Diabetes mellitus<|eot_id|>
    [Clinical Reasoning]
    First, I need to 

In [None]:
# Example 2: same call as the previous example with a different case; this
# rebinds medical_question, reasoning and diagnosis.
medical_question = "A 37-year-old male presents with fever with night sweats. X-ray shows a lung mass. What is the most likely diagnosis?"
reasoning, diagnosis = medical_chain_of_thought(medical_question)
print(f"Clinical Reasoning:\n{reasoning}\n\nFinal Diagnosis: {diagnosis}")

Clinical Reasoning:
<|begin_of_text|>
    <|start_header_id|>user<|end_header_id|>
    [Clinical Case Presentation]
    A 37-year-old male presents with fever with night sweats. X-ray shows a lung mass. What is the most likely diagnosis?


    <|start_header_id|>assistant<|end_header_id|>
    [Clinical Reasoning]
    First, I need to analyze the medical context: The patient's fever with night sweats combined with X-ray shows a lung mass strongly suggest Tuberculosis. This condition is commonly associated with Rheumatology and requires further medical evaluation.. I should consider relevant anatomical factors, physiological processes, and potential pathologies. Next, I'll evaluate the question's key elements and apply evidence-based medical knowledge. After differential consideration, I conclude the most clinically appropriate answer is: Tuberculosis.

    [Final Diagnosis]
    Tuberculosis<|eot_id|>
    [Clinical Reasoning]
    First, I need to analyze

Final Diagnosis: Tuberculosis


In [None]:
# Install rouge_score using pip
!pip install rouge_score
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu
from sentence_transformers import SentenceTransformer, util
import requests

Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone
  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24935 sha256=ccbd5c7b5ec48268821768a84d16c0105f83c48d69243350d3dab777f56d5b5e
  Stored in directory: /root/.cache/pip/wheels/1e/19/43/8a442dc83660ca25e163e1bd1f89919284ab0d0c1475475148
Successfully built rouge_score
Installing collected packages: rouge_score
Successfully installed rouge_score-0.1.2


In [None]:
# Simulated reference answer from medical experts
# Hard-coded reference pair used as the ground truth for the metrics below.
reference = {
    "reasoning": "The patient's severe headache and dizziness combined with EEG shows seizure activity strongly suggest Liver cirrhosis. This condition is commonly associated with Pediatrics and requires further medical evaluation. ",
    "diagnosis": "Liver cirrhosis"
}

# Example model output (simulated)
# NOTE: these assignments overwrite the medical_question / reasoning /
# diagnosis values produced by the inference cells, so the evaluation below
# scores these literals, not a live model call.
medical_question = " A 86-year-old female presents with severe headache and dizziness. EEG shows seizure activity. What is the most likely diagnosis?"
reasoning = "First, I need to analyze the medical context: The patient's severe headache and dizziness combined with EEG shows seizure activity strongly suggest Diabetes mellitus. This condition is commonly associated with Emergency Medicine and requires further medical evaluation.. I should consider relevant anatomical factors, physiological processes, and potential pathologies. Next, I'll evaluate the question's key elements and apply evidence-based medical knowledge. After differential consideration, I conclude the most clinically appropriate answer is: Diabetes mellitus."
diagnosis = "Diabetes mellitus"

# 1. Diagnosis Evaluation
def evaluate_diagnosis(predicted, reference, embedder=None):
    """Score a predicted diagnosis against the reference diagnosis.

    Args:
        predicted: Diagnosis string produced by the model.
        reference: Ground-truth diagnosis string.
        embedder: Optional sentence-embedding model exposing
            ``encode(text, convert_to_tensor=True)``. When omitted,
            ``all-MiniLM-L6-v2`` is loaded; pass a shared instance to avoid
            reloading the model on every call.

    Returns:
        dict with:
          * ``"exact_match"`` - 1 if the strings are equal ignoring case and
            surrounding whitespace, else 0. (Previously computed but dropped.)
          * ``"semantic_similarity"`` - cosine similarity of the two embeddings.
    """
    # Case-insensitive exact match
    exact_match = 1 if predicted.lower().strip() == reference.lower().strip() else 0

    # Named `embedder` rather than `model` so it cannot be confused with the
    # fine-tuned language model held in the notebook's global `model`.
    if embedder is None:
        embedder = SentenceTransformer('all-MiniLM-L6-v2')
    pred_embed = embedder.encode(predicted, convert_to_tensor=True)
    ref_embed = embedder.encode(reference, convert_to_tensor=True)
    similarity = util.pytorch_cos_sim(pred_embed, ref_embed).item()

    return {"exact_match": exact_match, "semantic_similarity": similarity}

# 2. Reasoning Evaluation
def evaluate_reasoning(predicted, reference, embedder=None):
    """Score generated reasoning text against a reference explanation.

    Args:
        predicted: Reasoning text produced by the model.
        reference: Reference reasoning text.
        embedder: Optional sentence-embedding model exposing
            ``encode(text, convert_to_tensor=True)``. When omitted,
            ``all-MiniLM-L6-v2`` is loaded; pass a shared instance to avoid
            reloading the model on every call.

    Returns:
        dict with:
          * ``"rougeL"`` - the ROUGE-L score object from ``rouge_scorer``
            (callers read ``.fmeasure`` from it).
          * ``"bleu"`` - sentence-level BLEU over whitespace tokens.
          * ``"factual_consistency"`` - cosine similarity of the two sentence
            embeddings. Despite the key name this measures overall semantic
            closeness only; it does not verify that the stated facts agree.
    """
    # Longest-common-subsequence overlap, with stemming.
    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
    rouge = scorer.score(reference, predicted)['rougeL']

    # BLEU on whitespace-split tokens with a single reference.
    bleu = sentence_bleu([reference.split()], predicted.split())

    # Named `embedder` rather than `model` so it cannot be confused with the
    # fine-tuned language model held in the notebook's global `model`.
    if embedder is None:
        embedder = SentenceTransformer('all-MiniLM-L6-v2')
    pred_embed = embedder.encode(predicted, convert_to_tensor=True)
    ref_embed = embedder.encode(reference, convert_to_tensor=True)
    factual_score = util.pytorch_cos_sim(pred_embed, ref_embed).item()

    return {"rougeL": rouge, "bleu": bleu, "factual_consistency": factual_score}

# Score the diagnosis and the reasoning against the reference pair.
diagnosis_scores = evaluate_diagnosis(diagnosis, reference["diagnosis"])
reasoning_scores = evaluate_reasoning(reasoning, reference["reasoning"])

# Emit the two-section report in one call; sep="\n" puts each entry on its own
# line, and the leading "\n" on the second heading yields the blank separator.
print(
    "=== Diagnosis Evaluation ===",
    f"Semantic Similarity: {diagnosis_scores['semantic_similarity']:.2f}",
    "\n=== Reasoning Evaluation ===",
    f"ROUGE-L: F1={reasoning_scores['rougeL'].fmeasure:.2f}",
    f"BLEU: {reasoning_scores['bleu']:.2f}",
    f"Factual Consistency: {reasoning_scores['factual_consistency']:.2f}",
    sep="\n",
)

=== Diagnosis Evaluation ===
Semantic Similarity: 0.19

=== Reasoning Evaluation ===
ROUGE-L: F1=0.49
BLEU: 0.27
Factual Consistency: 0.65
