In [None]:
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps "trl<0.9.0" xformers

In [2]:
import unsloth
from unsloth import FastLanguageModel
import torch

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [3]:
max_seq_length = 4096
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/DeepSeek-R1-Distill-Llama-8B-bnb-4bit",
    max_seq_length=max_seq_length,
    dtype=None,
    load_in_4bit=True,
)

==((====))==  Unsloth 2025.9.7: Fast Llama patching. Transformers: 4.55.4.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.4.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.32.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth

In [4]:
model = FastLanguageModel.get_peft_model(
    model,
    r=32,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=32,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,
)

Unsloth 2025.9.7 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.
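
The `r`, `lora_alpha`, and `use_rslora` settings above together determine how strongly the low-rank update is weighted. A minimal sketch of that arithmetic, assuming the standard LoRA convention (`lora_alpha / r`) and the rank-stabilized variant (`lora_alpha / sqrt(r)`); the helper name `lora_scaling` is hypothetical and not part of Unsloth or PEFT:

```python
import math


def lora_scaling(lora_alpha: int, r: int, use_rslora: bool = False) -> float:
    """Factor applied to the low-rank update B @ A before it is added to the frozen weight.

    Hypothetical helper for illustration; it only reproduces the usual formulas.
    """
    if use_rslora:
        # Rank-stabilized LoRA divides by sqrt(r) instead of r.
        return lora_alpha / math.sqrt(r)
    return lora_alpha / r


# Values taken from the get_peft_model call in this notebook.
standard = lora_scaling(lora_alpha=32, r=32)
rank_stabilized = lora_scaling(lora_alpha=32, r=32, use_rslora=True)
print(standard, round(rank_stabilized, 3))
```

With `r=32` and `lora_alpha=32` the standard factor is 1, so the adapter output is added unscaled; switching `use_rslora` on at the same settings would weight it several times more heavily.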


In [5]:
model.print_trainable_parameters()

trainable params: 83,886,080 || all params: 8,114,147,328 || trainable%: 1.0338
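
The trainable-parameter figure can be checked by hand: each adapted linear layer gains `r * (in_features + out_features)` weights. The sketch below assumes the projection shapes of the published Llama 3.1 8B configuration (hidden size 4096, key/value width 1024, MLP width 14336); those shapes are an assumption here, not something read from the loaded model.

```python
# LoRA adds two matrices per target module: A (r x in) and B (out x r).
r = 32
num_layers = 32

# (in_features, out_features) per target module; assumed Llama 3.1 8B shapes.
module_shapes = {
    "q_proj": (4096, 4096),
    "k_proj": (4096, 1024),
    "v_proj": (4096, 1024),
    "o_proj": (4096, 4096),
    "gate_proj": (4096, 14336),
    "up_proj": (4096, 14336),
    "down_proj": (14336, 4096),
}

per_layer = sum(r * (fan_in + fan_out) for fan_in, fan_out in module_shapes.values())
total_lora_params = per_layer * num_layers

# "all params" as reported by print_trainable_parameters() in this notebook.
all_params = 8_114_147_328
trainable_pct = 100 * total_lora_params / all_params

print(f"{total_lora_params:,} trainable ({trainable_pct:.4f}%)")
```

Under those assumed shapes the count reproduces the reported 83,886,080 trainable parameters.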


In [8]:
from datasets import load_dataset
dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", name='en', split="train[:3000]")

In [10]:
dataset

Dataset({
    features: ['Question', 'Complex_CoT', 'Response'],
    num_rows: 3000
})

In [12]:
dataset['Question'][100]

'A 25-year-old woman presents to the ED with a diffuse, erythematous rash, nausea, vomiting, and fever for 2 days. Physical examination reveals a soaked tampon in her vagina, and blood cultures are negative, suggesting toxic shock syndrome. Which specific molecule on T cells does the toxin most likely bind to?'

In [13]:
dataset['Complex_CoT'][100]

"Alright, here's a situation with a 25-year-old woman who showed up in the emergency department. She's got this widespread red rash, feeling nauseous, she's vomiting, and running a fever for two days. Something's not quite right here, and it all starts connecting to the idea of toxic shock syndrome. Oh, and there's a crucial detail: they found a soaked tampon during her exam.\n\nOkay, let's dig into what's happening in toxic shock syndrome. It's a bit of a nightmare because it's associated with these things called superantigens. These are basically like the rogue agents of the bacterial world, and they're mostly coming from bugs like Staphylococcus aureus.\n\nSuperantigens are tricky; they have a special way of activating T cells way beyond the norm. Usually, T cells need to see a specific antigen presented by MHC molecules, but these superantigens bypass that entire check and go straight to just supercharging T cells.\n\nSo, how do they pull this off? Well, these superantigens physica

In [14]:
dataset['Response'][100]

"The symptoms and scenario described align with toxic shock syndrome, commonly associated with superantigens produced by Staphylococcus aureus. These superantigens have a unique mechanism of action; they bypass normal antigen processing and directly stimulate T cells by binding to the T cell receptor (TCR). Specifically, the superantigens bind to the variable region of the beta chain of the T cell receptor's beta chain, known as TCR V beta. This interaction results in a massive, nonspecific activation of T cells, leading to the release of large amounts of cytokines and subsequent symptoms of toxic shock syndrome. Therefore, the specific molecule on T cells that the toxin most likely binds to is the TCR V beta region."

In [16]:
def format_medical_reasoning(examples):
    """Format examples for medical reasoning using actual dataset structure"""
    texts = []

    for i in range(len(examples["Question"])):
        question = examples["Question"][i]
        reasoning = examples["Complex_CoT"][i]
        response = examples["Response"][i]

        medical_prompt = f"""As a medical AI assistant, analyze the following medical question. Think through the problem step-by-step and provide a comprehensive medical response.

Medical Question:
{question}

Medical Analysis and Response:"""

        full_response = f"{reasoning}\n\n{response}" if reasoning else response

        messages = [
            {"role": "user", "content": medical_prompt},
            {"role": "assistant", "content": full_response}
        ]

        # Apply DeepSeek-R1 chat template
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )

        texts.append(text)

    return {"text": texts}
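
The instruction wording in this cell is repeated verbatim in the inference helper later in the notebook, and the two copies have to stay identical for the fine-tuned model to see the prompt it was trained on. One way to guarantee that is a single shared builder; the function name `build_medical_prompt` is hypothetical and is only a sketch of the idea:

```python
def build_medical_prompt(question: str) -> str:
    """Return the user-turn prompt used for both training and inference.

    Hypothetical helper; the wording is copied from the notebook's prompt.
    """
    return (
        "As a medical AI assistant, analyze the following medical question. "
        "Think through the problem step-by-step and provide a comprehensive "
        "medical response.\n\n"
        "Medical Question:\n"
        f"{question}\n\n"
        "Medical Analysis and Response:"
    )


example_prompt = build_medical_prompt("What causes iron-deficiency anemia?")
print(example_prompt)
```

Both the formatting function and the inference function could then call this builder instead of carrying their own f-string.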

In [17]:
formatted_dataset = dataset.map(
    format_medical_reasoning,
    batched=True,
    remove_columns=dataset.column_names,
    desc="Formatting medical Q&A with reasoning"
)

In [19]:
formatted_dataset['text'][100]

"<｜begin▁of▁sentence｜><｜User｜>As a medical AI assistant, analyze the following medical question. Think through the problem step-by-step and provide a comprehensive medical response.\n\nMedical Question:\nA 25-year-old woman presents to the ED with a diffuse, erythematous rash, nausea, vomiting, and fever for 2 days. Physical examination reveals a soaked tampon in her vagina, and blood cultures are negative, suggesting toxic shock syndrome. Which specific molecule on T cells does the toxin most likely bind to?\n\nMedical Analysis and Response:<｜Assistant｜>Alright, here's a situation with a 25-year-old woman who showed up in the emergency department. She's got this widespread red rash, feeling nauseous, she's vomiting, and running a fever for two days. Something's not quite right here, and it all starts connecting to the idea of toxic shock syndrome. Oh, and there's a crucial detail: they found a soaked tampon during her exam.\n\nOkay, let's dig into what's happening in toxic shock syndr

In [20]:
from transformers import TrainingArguments
training_args = TrainingArguments(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    warmup_steps=10,
    max_steps=150,
    learning_rate=1e-4,
    fp16=not torch.cuda.is_bf16_supported(),
    bf16=torch.cuda.is_bf16_supported(),
    logging_steps=5,
    optim="adamw_8bit",
    weight_decay=0.01,
    lr_scheduler_type="cosine",
    seed=3407,
    output_dir="deepseek_medical_doctor",
    save_strategy="no",
)
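
To see what `warmup_steps=10`, `max_steps=150`, and `lr_scheduler_type="cosine"` imply for the learning rate, the schedule can be approximated in plain Python. This is a sketch of the usual linear-warmup-then-cosine-decay shape, assuming decay to zero at the final step; the function `lr_at_step` is hypothetical and the scheduler actually used by the trainer may differ in small details.

```python
import math


def lr_at_step(step: int, base_lr: float = 1e-4,
               warmup_steps: int = 10, max_steps: int = 150) -> float:
    """Approximate learning rate at a given optimizer step (illustrative only)."""
    if step < warmup_steps:
        # Linear ramp from 0 up to base_lr over the warmup steps.
        return base_lr * step / max(1, warmup_steps)
    # Half-cosine decay from base_lr down to 0 over the remaining steps.
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


for step in (0, 5, 10, 80, 150):
    print(step, f"{lr_at_step(step):.2e}")
```

The peak rate is reached at the end of warmup, and the rate is halved at the midpoint of the decay phase (step 80 with these settings).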

In [21]:
from trl import SFTTrainer
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=formatted_dataset,
    args=training_args,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    packing=False,
)

In [22]:
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 3,000 | Num Epochs = 1 | Total steps = 150
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 8 x 1) = 8
 "-____-"     Trainable parameters = 83,886,080 of 8,114,147,328 (1.03% trained)
Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss
5,2.2576
10,2.0783
15,1.7925
20,1.6522
25,1.5578
30,1.5686
35,1.5647
40,1.5029
45,1.5168
50,1.5451


TrainOutput(global_step=150, training_loss=1.5506600125630696, metrics={'train_runtime': 2711.0046, 'train_samples_per_second': 0.443, 'train_steps_per_second': 0.055, 'total_flos': 3.548130253951795e+16, 'train_loss': 1.5506600125630696, 'epoch': 0.4})
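
The `epoch` and throughput values in `TrainOutput` follow directly from the batch settings. A small worked check, using only numbers that appear in this notebook's configuration and logs:

```python
# Settings from TrainingArguments and the Unsloth training banner.
per_device_batch = 1
grad_accum = 8
num_gpus = 1
max_steps = 150
num_examples = 3000
train_runtime_s = 2711.0046  # 'train_runtime' from TrainOutput

effective_batch = per_device_batch * grad_accum * num_gpus
samples_seen = effective_batch * max_steps
epoch_fraction = samples_seen / num_examples

samples_per_second = samples_seen / train_runtime_s
steps_per_second = max_steps / train_runtime_s
seconds_per_step = train_runtime_s / max_steps

print(f"effective batch size: {effective_batch}")
print(f"samples seen: {samples_seen} ({epoch_fraction:.1f} epoch)")
print(f"{samples_per_second:.3f} samples/s, {seconds_per_step:.1f} s/step")
```

In other words, the run covers 1,200 of the 3,000 selected examples, so the model never sees the remaining 60% of the subset; raising `max_steps` to 375 would correspond to one full pass at this batch size.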

In [23]:
model.save_pretrained("deepseek_medical_ai")
tokenizer.save_pretrained("deepseek_medical_ai")

('deepseek_medical_ai/tokenizer_config.json',
 'deepseek_medical_ai/special_tokens_map.json',
 'deepseek_medical_ai/chat_template.jinja',
 'deepseek_medical_ai/tokenizer.json')

In [24]:
FastLanguageModel.for_inference(model)

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(128256, 4096, padding_idx=128004)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=32, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=32, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lor

In [25]:
def medical_diagnosis(medical_question, max_tokens=700):
    """Analyze medical questions with step-by-step reasoning"""

    medical_prompt = f"""As a medical AI assistant, analyze the following medical question. Think through the problem step-by-step and provide a comprehensive medical response.

Medical Question:
{medical_question}

Medical Analysis and Response:"""

    messages = [
        {"role": "user", "content": medical_prompt}
    ]

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

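    # The chat template already emits the begin-of-sentence token, so skip the
    # tokenizer's own special tokens to avoid a duplicated BOS.
    inputs = tokenizer([text], return_tensors="pt", add_special_tokens=False).to("cuda")

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.3,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Keep only the newly generated tokens (drop the echoed prompt).
    generated_tokens = outputs[0][len(inputs['input_ids'][0]):]
    response = tokenizer.decode(generated_tokens, skip_special_tokens=True)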

    return response.strip()
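
The `temperature=0.3` setting with `do_sample=True` makes sampling much more conservative than the default of 1.0. The sketch below illustrates the effect on a made-up three-token distribution; the logits are invented for illustration, and `softmax_with_temperature` is a hypothetical helper rather than the code path `generate` actually uses.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Convert logits to probabilities after dividing by the temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


toy_logits = [2.0, 1.0, 0.0]  # invented values, not model output
p_default = softmax_with_temperature(toy_logits, temperature=1.0)
p_sharp = softmax_with_temperature(toy_logits, temperature=0.3)

print([round(p, 3) for p in p_default])
print([round(p, 3) for p in p_sharp])
```

Lower temperatures concentrate probability on the highest-scoring token, which tends to give more repeatable answers at the cost of diversity.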

In [26]:
diagnostic_questions = [
    "A 45-year-old patient presents with chest pain, shortness of breath, and swelling in the legs. What are the most likely diagnoses and what diagnostic tests should be ordered?",

    "An 8-year-old child has a fever of 102°F, sore throat, and red rash on the body. The parents mention the child was exposed to strep throat at school. What is your assessment?",

    "A 65-year-old diabetic patient reports frequent urination, excessive thirst, and unexplained weight loss over the past month. Blood glucose is 400 mg/dL. What is happening?",

    "A 30-year-old woman presents with joint pain in her hands and knees, morning stiffness lasting over an hour, and fatigue. What conditions should be considered?"
]

print("🩺 DIAGNOSTIC REASONING:")
for i, question in enumerate(diagnostic_questions, 1):
    print(f"\n❓ Question {i}: {question}")
    analysis = medical_diagnosis(question, 600)
    print(f"🔬 Medical Analysis: {analysis}")
    print("-" * 60)

🩺 DIAGNOSTIC REASONING:

❓ Question 1: A 45-year-old patient presents with chest pain, shortness of breath, and swelling in the legs. What are the most likely diagnoses and what diagnostic tests should be ordered?
🔬 Medical Analysis: The patient is experiencing chest pain, shortness of breath, and swelling in the legs. These symptoms are suggestive of a cardiac issue. The chest pain could be due to ischemia, which is often associated with coronary artery disease. The shortness of breath and leg swelling might be due to heart failure. 

First, I should consider imaging to look for any structural issues in the heart. An echocardiogram would be a good starting point to assess the heart's structure and function. This test can help identify if there's a problem with the heart's chambers or valves.

Next, I should think about the blood supply to the heart. A coronary angiogram would be useful to check for any blockages or narrowing in the coronary arteries. This is crucial if there's suspici

In [27]:
treatment_questions = [
    "What are the treatment options for a patient newly diagnosed with Type 2 diabetes with HbA1c of 8.5%?",

    "How should we manage a patient with moderate depression who has not responded to first-line SSRI therapy?",

    "What is the appropriate treatment approach for a patient with newly diagnosed hypertension reading 160/100?",

    "A patient with chronic asthma is having increased symptoms despite current inhaler use. What adjustments are needed?"
]

print("\n💊 TREATMENT PLANNING:")
for i, question in enumerate(treatment_questions, 1):
    print(f"\n💊 Treatment Question {i}: {question}")
    treatment = medical_diagnosis(question, 500)
    print(f"🎯 Treatment Response: {treatment}")
    print("-" * 60)


💊 TREATMENT PLANNING:

💊 Treatment Question 1: What are the treatment options for a patient newly diagnosed with Type 2 diabetes with HbA1c of 8.5%?
🎯 Treatment Response: The patient has been newly diagnosed with Type 2 diabetes and has an HbA1c of 8.5%. This indicates that their blood sugar levels are not well controlled, and they need to start treatment immediately to prevent complications. The first step is to initiate a medication regimen that helps control blood sugar levels effectively. Metformin is often the first medication prescribed for Type 2 diabetes because it is well-tolerated and effective.

In addition to medication, it's crucial to focus on lifestyle changes. This includes adopting a healthy diet with a focus on whole grains, lean proteins, and vegetables, along with regular physical activity. These changes can significantly help in managing blood sugar levels and improving overall health.

Regular monitoring of blood sugar levels is essential to ensure that the treat

In [28]:
emergency_questions = [
    "A patient presents to the ER unconscious with pinpoint pupils and slow breathing. What is the most likely cause and immediate treatment?",

    "A 25-year-old presents with sudden severe headache described as 'worst headache of my life' with neck stiffness. What is your concern?",

    "A patient has crushing chest pain radiating to the left arm, shortness of breath, and is diaphoretic. What is the immediate action plan?",

    "An elderly patient fell and now has severe hip pain and inability to bear weight. The leg appears shortened and externally rotated. What is likely?"
]

print("\n🚨 EMERGENCY MEDICINE:")
for i, question in enumerate(emergency_questions, 1):
    print(f"\n🚨 Emergency {i}: {question}")
    emergency = medical_diagnosis(question, 400)
    print(f"⚡ Emergency Response: {emergency}")
    print("-" * 60)


🚨 EMERGENCY MEDICINE:

🚨 Emergency 1: A patient presents to the ER unconscious with pinpoint pupils and slow breathing. What is the most likely cause and immediate treatment?
⚡ Emergency Response: The patient is unconscious with pinpoint pupils and slow breathing, which are concerning signs. Pinpoint pupils suggest a decrease in pupil size, often due to a lack of parasympathetic tone. Slow breathing could be due to decreased respiratory rate or increased respiratory depression. These symptoms are indicative of a possible decrease in brainstem activity, such as in a brainstem infarction. A brainstem infarction can cause a loss of consciousness, pinpoint pupils, and respiratory depression. 

Given the symptoms, the most likely cause is a brainstem infarction. This condition can result from various causes, such as ischemia, hemorrhage, or trauma. The immediate treatment for a brainstem infarction should focus on maintaining vital functions and preventing further complications. 

To addre

In [29]:
prevention_questions = [
    "What preventive care screenings should be recommended for a healthy 50-year-old woman?",

    "A patient asks about vaccine recommendations for travel to Southeast Asia. What vaccines are needed?",

    "What lifestyle modifications should be recommended for a patient with prediabetes?",

    "A pregnant patient in her first trimester asks about prenatal care recommendations. What should be advised?"
]

print("\n🛡️ PREVENTIVE MEDICINE:")
for i, question in enumerate(prevention_questions, 1):
    print(f"\n🛡️ Prevention {i}: {question}")
    prevention = medical_diagnosis(question, 400)
    print(f"🔒 Prevention Response: {prevention}")
    print("-" * 60)


🛡️ PREVENTIVE MEDICINE:

🛡️ Prevention 1: What preventive care screenings should be recommended for a healthy 50-year-old woman?
🔒 Prevention Response: A 50-year-old woman should consider several preventive care screenings to maintain her health and detect potential issues early. First, mammograms are crucial for breast cancer screening. Typically, women should start at age 50, and the frequency depends on guidelines, often every two years for those with average risk. Next, pap smear tests are important for cervical cancer screening, usually every three years. Blood tests for breast, ovarian, and cervical cancer are also recommended. Blood pressure checks are essential to monitor cardiovascular health. Cholesterol screening helps in managing heart disease risks. Vision and hearing tests can catch potential issues early. A colonoscopy is recommended for colorectal cancer screening, typically starting at age 50. A bone density test is important for assessing osteoporosis risk. Finally, 

In [30]:
pediatric_questions = [
    "A 6-month-old infant has a fever of 103°F and is not feeding well. The parents are very worried. What should be done?",

    "A 3-year-old child has been having difficulty walking and complains of leg pain. There is no history of trauma. What are the possibilities?",

    "A 12-year-old student has been having trouble concentrating in school and seems hyperactive. Teachers report disruptive behavior. What should be evaluated?"
]

print("\n👶 PEDIATRIC MEDICINE:")
for i, question in enumerate(pediatric_questions, 1):
    print(f"\n👶 Pediatric {i}: {question}")
    pediatric = medical_diagnosis(question, 450)
    print(f"🍼 Pediatric Response: {pediatric}")
    print("-" * 60)


👶 PEDIATRIC MEDICINE:

👶 Pediatric 1: A 6-month-old infant has a fever of 103°F and is not feeding well. The parents are very worried. What should be done?
🍼 Pediatric Response: The infant is 6 months old and has a fever of 103°F, which is quite high for a baby this young. This fever could indicate an infection, especially since the baby isn't feeding well, which might suggest an underlying issue. The parents are clearly concerned, which is understandable given the situation.

First, it's important to assess the baby's overall condition. We need to check for signs of dehydration, which can be a serious problem in young infants. Signs of dehydration include dry mucous membranes, sunken eyes, and a lack of tears. If these are present, we should consider intravenous fluids to help rehydrate the baby.

Next, we should consider the possibility of an infection. A fever of this severity in a baby is concerning and could be due to a bacterial or viral infection. It's crucial to check for any 
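
The five evaluation cells above repeat the same loop with different titles, question lists, and token budgets. A possible consolidation is sketched below; `run_question_set` and `echo_stub` are hypothetical names, and the stub stands in for `medical_diagnosis` so the example runs without a GPU or model.

```python
from typing import Callable, List


def run_question_set(title: str, questions: List[str],
                     answer_fn: Callable[[str, int], str],
                     max_tokens: int) -> List[str]:
    """Print each question with its answer and return the answers in order."""
    print(title)
    answers = []
    for i, question in enumerate(questions, 1):
        print(f"\nQuestion {i}: {question}")
        answer = answer_fn(question, max_tokens)
        print(f"Response: {answer}")
        print("-" * 60)
        answers.append(answer)
    return answers


def echo_stub(question: str, max_tokens: int) -> str:
    """Placeholder for the real model call; returns a fixed string."""
    return f"[stub answer, up to {max_tokens} tokens]"


demo_answers = run_question_set("DEMO", ["example question"], echo_stub, 400)
```

In the notebook itself, `answer_fn` would be `medical_diagnosis`, and returning the answers makes it easier to save or review them afterwards instead of only printing them.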