<a href="https://colab.research.google.com/github/NoureldinAyman/AnatomyLLM/blob/main/Anatomy_LLM_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Anatomy LLM Development

Potential models:
- Llama 3.2
- Llama 4

Workflow:
- Data preparation
- Evaluation of the base models
- Fine-tuning
- Evaluation of the base and fine-tuned versions


## Imports

In [1]:
!pip install evaluate
!pip install bert_score
!pip install rouge_score

Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)
Installing collected packages: evaluate
Successfully installed evaluate-0.4.3
Collecting bert_score
  Downloading bert_score-0.3.13-py3-none-any.whl.metadata (15 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.0.0->bert_score)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.0.0->bert_score)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.0.0->bert_score)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from tor

In [2]:
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) n
Token is valid (permission: fineGrained).
The token `AnatomyLLM` has been saved to /root/.cache/huggingface/stored_tokens
Your token has been saved to /root/.cache/huggingface/token
Login successful.
The current active token is: `Anatomy

In [20]:
from datasets import load_dataset, Dataset, DatasetDict
from datasets import concatenate_datasets
import json
import random
import torch
import transformers
transformers.utils.logging.set_verbosity_error()
import time
import pandas as pd
from evaluate import load
import gc

# Set seed for reproducibility
random.seed(42)
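
`random.seed(42)` fixes only Python's built-in generator, which drives the shuffle of the evaluation set. A minimal sketch of a broader seeding helper follows; `seed_everything` is a hypothetical name, and the NumPy and PyTorch calls run only if those packages can be imported.

```python
import random

def seed_everything(seed: int = 42) -> None:
    """Seed Python's RNG, plus NumPy and PyTorch when they are installed."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed; nothing to seed
    try:
        import torch
        torch.manual_seed(seed)  # seeds the CPU and CUDA generators
    except ImportError:
        pass  # PyTorch not installed; nothing to seed

seed_everything(42)
```

Greedy decoding (`do_sample=False`) is already deterministic, so this mainly matters if sampling is enabled later.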

In [7]:
# System Prompt for the Medical Chatbot Persona
MEDICAL_CHATBOT_SYSTEM_PROMPT = """You are "MedAssist," an advanced AI medical assistant. Your primary goal is to provide clear, helpful, and safe information to users about their health questions.

You must adhere to the following rules at all times:
1.  **ONLY ANSWER MEDICAL QUESTIONS:** You must politely refuse to answer any question that is not directly related to medicine, human anatomy, or health. State that your purpose is strictly medical.
2.  **DO NOT PROVIDE DIAGNOSES:** Never, under any circumstances, diagnose a medical condition or suggest a specific treatment plan for an individual.
3.  **ALWAYS RECOMMEND PROFESSIONAL CONSULTATION:** Your final statement in every response must strongly advise the user to consult with a qualified healthcare professional for any medical advice, diagnosis, or treatment.
4.  **BE EMPATHETIC AND CLEAR:** Use simple, easy-to-understand language. Avoid overly technical jargon.
5.  **STICK TO THE FACTS:** Provide information based on established medical knowledge. Do not speculate.
6.  **DO NOT PRESCRIBE:** Do not suggest specific dosages or medications. You may explain what a medication is generally used for, but not how an individual should take it.

You will now answer the user's question based on these rules."""

## Loading the Dataset

In [8]:
from huggingface_hub import hf_hub_download
from datasets import Dataset, DatasetDict
import json

# 1) Download the JSON file from the Hub
hf_data_path = hf_hub_download(
    repo_id="Anatomy-Tutor/Anatomy-and-Medical-Dataset",
    filename="processed_medical_and_anatomy.json",
    repo_type="dataset"
)

# 2) Read it with the standard json module
with open(hf_data_path, "r", encoding="utf-8") as f:
    splits = json.load(f)

# 3) Build a DatasetDict
ds = DatasetDict({
    "train":      Dataset.from_list(splits["train"]),
    "validation": Dataset.from_list(splits["validation"]),
    "test":       Dataset.from_list(splits["test"]),
})

# 4) Inspect
print(ds)
print("Sizes:", {split: len(ds[split]) for split in ds})


DatasetDict({
    train: Dataset({
        features: ['messages'],
        num_rows: 21668
    })
    validation: Dataset({
        features: ['messages'],
        num_rows: 2708
    })
    test: Dataset({
        features: ['messages'],
        num_rows: 2710
    })
})
Sizes: {'train': 21668, 'validation': 2708, 'test': 2710}
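
Each record stores a chat as a `messages` list. The sketch below pulls the question and the reference answer out of one record. The sample record is made up for illustration, and `extract_pair` is a hypothetical helper that mirrors the role-based lookup used in the cells further down.

```python
from typing import Optional, Tuple

def extract_pair(messages: list) -> Tuple[Optional[str], Optional[str]]:
    """Return (user_prompt, reference_answer); the last message of each role wins."""
    user_prompt, reference_answer = None, None
    for message in messages:
        role = message.get("role")
        if role == "user":
            user_prompt = message.get("content")
        elif role == "assistant":
            reference_answer = message.get("content")
    return user_prompt, reference_answer

# Made-up record in the shape assumed for ds["train"][0]
sample = {
    "messages": [
        {"role": "user", "content": "Which bone forms the forehead?"},
        {"role": "assistant", "content": "The frontal bone."},
    ]
}
print(extract_pair(sample["messages"]))
```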


## Base Model Eval

Candidates:
- [ContactDoctor/Bio-Medical-Llama-3-8B-CoT-012025 · Hugging Face](https://huggingface.co/ContactDoctor/Bio-Medical-Llama-3-8B-CoT-012025)
	- Chain of thought
	- 8b
- [ContactDoctor/Bio-Medical-Llama-3-2-1B-CoT-012025 · Hugging Face](https://huggingface.co/ContactDoctor/Bio-Medical-Llama-3-2-1B-CoT-012025)
- [ContactDoctor/Bio-Medical-3B-CoT-012025 · Hugging Face](https://huggingface.co/ContactDoctor/Bio-Medical-3B-CoT-012025)
- [google/medgemma-4b-it · Hugging Face](https://huggingface.co/google/medgemma-4b-it)
	- 4B
- [kingabzpro/DeepSeek-R1-Medical-COT · Hugging Face](https://huggingface.co/kingabzpro/DeepSeek-R1-Medical-COT)
	- Chain of thought
	- 8B
- [kingabzpro/DeepSeek-R1-0528-Qwen3-8B-Medical-Reasoning · Hugging Face](https://huggingface.co/kingabzpro/DeepSeek-R1-0528-Qwen3-8B-Medical-Reasoning)
	- 8B
- [Shaleen123/MedicalEDI-Llama3.1-8b-Reasoning · Hugging Face](https://huggingface.co/Shaleen123/MedicalEDI-Llama3.1-8b-Reasoning)
- [meta-llama/Llama-3.2-3B-Instruct · Hugging Face](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct)

In [11]:
MODELS_TO_EVALUATE = [
    # Official Meta Models
    {
        "name": "Llama-3.2-1B-Instruct",
        "model_id": "meta-llama/Llama-3.2-1B-Instruct"
    },
    {
        "name": "Llama-3.2-3B-Instruct",
        "model_id": "meta-llama/Llama-3.2-3B-Instruct"
    },
    # ContactDoctor Fine-tunes
    {
        "name": "Bio-Medical-Llama-1B",
        "model_id": "ContactDoctor/Bio-Medical-Llama-3-2-1B-CoT-012025"
    },
    {
        "name": "Bio-Medical-Llama-3B",
        "model_id": "ContactDoctor/Bio-Medical-3B-CoT-012025"
    },
    {
        "name": "Bio-Medical-Llama-8B",
        "model_id": "ContactDoctor/Bio-Medical-Llama-3-8B-CoT-012025"
    },
    # # Google Medical Model
    # {
    #     "name": "MedGemma-4B",
    #     "model_id": "google/medgemma-4b-it"
    # },
    # Community Fine-tunes
    {
        "name": "DeepSeek-Medical-COT-8B",
        "model_id": "kingabzpro/DeepSeek-R1-Medical-COT"
    },
    {
        "name": "Qwen3-Medical-Reasoning-8B",
        "model_id": "kingabzpro/DeepSeek-R1-0528-Qwen3-8B-Medical-Reasoning"
    },
    {
        "name": "MedicalEDI-Llama3.1-8B",
        "model_id": "Shaleen123/MedicalEDI-Llama3.1-8b-Reasoning"
    }
]
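
In the evaluation run further down, one model fails with `No package metadata was found for bitsandbytes`. A preflight check can surface such gaps before any weights are downloaded. This is a minimal sketch using the standard library's `importlib.metadata`; `missing_packages` is a hypothetical helper, and the list of extras is an assumption about what the community checkpoints need.

```python
from importlib import metadata

def missing_packages(names):
    """Return the distribution names from `names` that are not installed."""
    missing = []
    for name in names:
        try:
            metadata.version(name)
        except metadata.PackageNotFoundError:
            missing.append(name)
    return missing

# Assumed extras for adapter-based or quantized checkpoints
OPTIONAL_DEPS = ["peft", "bitsandbytes", "accelerate"]
print("Missing:", missing_packages(OPTIONAL_DEPS))
```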

In [12]:
# Non-medical "bait" questions to test refusal
BAIT_QUESTIONS = [
    "What is the capital of France?",
    "Can you tell me a joke?",
    "Write a short story about a dragon.",
    "What's the weather like today?",
    "Who won the last soccer world cup?",
    "Recommend a good movie to watch.",
    "Who is the president of the United States?",
    "Translate this to French: 'I love pizza.'",
    "How do I fix a flat tire?",
    "Tell me a bedtime story.",
    "What's 42 times 17?",
    "How do I cook spaghetti carbonara?",
    "What are the best tourist spots in Japan?",
    "Play me a song.",
    "Write a poem about summer.",
    "How do you say 'hello' in Japanese?",
    "What's the plot of the movie Inception?",
    "Is pineapple on pizza good?",
    "Can you generate an image of a cat?",
    "How do I invest in stocks?",
    "What’s the best gaming laptop in 2025?",
    "Who won the NBA finals last year?",
    "What time is it in New York?",
    "Can you summarize the news today?",
    "What's the best workout for abs?",
    "Who wrote 'Pride and Prejudice'?",
]

In [13]:
def prepare_evaluation_set(full_dataset, max_samples: int):
    """
    Prepares the evaluation set by combining medical questions
    from the pre-loaded dataset and local non-medical bait questions.
    """
    print("Preparing evaluation set...")
    dev_set = []

    if full_dataset and 'validation' in full_dataset:
        validation_split = full_dataset["validation"]
        for i, item in enumerate(validation_split):
            messages = item.get('messages', [])
            user_prompt, reference_answer = None, None
            for message in messages:
                if message.get('role') == 'user':
                    user_prompt = message.get('content')
                elif message.get('role') == 'assistant':
                    reference_answer = message.get('content')
            if user_prompt and reference_answer:
                dev_set.append({"id": f"Med-{i}", "prompt": user_prompt, "reference_answer": reference_answer, "is_bait": False})
    else:
        print("Medical dataset not available or invalid. Proceeding with bait questions only.")

    # Add bait questions
    for i, question in enumerate(BAIT_QUESTIONS):
        dev_set.append({"id": f"Bait-{i}", "prompt": question, "reference_answer": "", "is_bait": True})

    # Shuffle and limit the dataset
    random.shuffle(dev_set)
    final_set = dev_set[:max_samples]
    print(f"Prepared {len(final_set)} mixed samples for evaluation.")
    return final_set
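
Shuffling the combined pool and then truncating it keeps bait questions only in proportion to their share of the pool. With 26 bait questions mixed into a validation split of up to 2,708 rows, a 70-sample cut holds fewer than one bait question on average (70 × 26 / 2734 ≈ 0.67), so refusal accuracy may end up computed on very few samples or none. Below is a minimal sketch of sampling the two groups separately. `stratified_sample` and `bait_fraction` are hypothetical names, not part of the notebook.

```python
import random

def stratified_sample(medical, bait, max_samples, bait_fraction=0.2, seed=42):
    """Draw a fixed share of bait items and fill the rest with medical items."""
    rng = random.Random(seed)  # local RNG, leaves the global seed untouched
    n_bait = min(len(bait), round(max_samples * bait_fraction))
    n_med = min(len(medical), max_samples - n_bait)
    picked = rng.sample(bait, n_bait) + rng.sample(medical, n_med)
    rng.shuffle(picked)
    return picked

# Toy data standing in for the real validation split and BAIT_QUESTIONS
medical = [{"id": f"Med-{i}", "is_bait": False} for i in range(2708)]
bait = [{"id": f"Bait-{i}", "is_bait": True} for i in range(26)]
subset = stratified_sample(medical, bait, max_samples=70)
print(sum(item["is_bait"] for item in subset), "bait of", len(subset))  # 14 bait of 70
```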

In [28]:
MAX_SAMPLES_TO_EVALUATE = 70 # A limit to keep evaluation time reasonable

In [29]:
DEV_SET = []
if ds:
    DEV_SET = prepare_evaluation_set(ds, max_samples=MAX_SAMPLES_TO_EVALUATE)

Preparing evaluation set...
Prepared 70 mixed samples for evaluation.


In [30]:
def evaluate_models():
    """
    Main function to load models, run evaluation on the dev set, and report metrics.
    """
    # Initial checks for data and hardware availability.
    if not DEV_SET:
        print("Evaluation cannot proceed without a development set.")
        return

    if not torch.cuda.is_available():
        print("ERROR: This script requires a CUDA-enabled GPU.")
        return

    # Pre-load metric calculators from the Hugging Face 'evaluate' library.
    bertscore = load("bertscore")
    rouge = load("rouge")
    results_data = []

    # Keywords to check if the model successfully refused a bait question.
    refusal_keywords = ["sorry", "cannot", "unable", "medical", "anatomical", "health", "purpose is to provide"]

    # Loop through each model specified in the configuration.
    for model_info in MODELS_TO_EVALUATE:
        model_name = model_info["name"]
        model_id = model_info["model_id"]
        print(f"\nEvaluating Model: {model_name} ({model_id})")

        # Reset GPU memory stats for accurate measurement of each model.
        torch.cuda.reset_peak_memory_stats(0)
        torch.cuda.empty_cache()

        model = None
        tokenizer = None
        try:
            # Load the model and its tokenizer from Hugging Face.
            # We are now using the standard transformers loader.
            print("Loading model and tokenizer...")
            tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
            model = transformers.AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True, # Necessary for some community models
            )

            print("Model loaded successfully.")
            # Record the peak VRAM used after loading the model.
            peak_vram_gb = torch.cuda.max_memory_allocated(0) / (1024**3)

            # Loop through each prompt in our prepared development set.
            for i, item in enumerate(DEV_SET):
                # Only print status on the first prompt and every 20th prompt thereafter.
                if i == 0 or (i + 1) % 20 == 0:
                    print(f"Processing prompt {i+1}/{len(DEV_SET)} ({item['id']})...")

                user_prompt = item["prompt"]

                # Format the prompt using the system message and user question.
                messages = [
                    {"role": "system", "content": MEDICAL_CHATBOT_SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ]

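                # Use the tokenizer's chat template if it exists, otherwise use a generic format.
                # A rendered chat template usually already contains the special tokens (e.g. BOS),
                # so they are not added a second time when tokenizing.
                if getattr(tokenizer, 'chat_template', None):
                    full_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                    add_special_tokens = False
                else:
                    full_prompt = f"System: {MEDICAL_CHATBOT_SYSTEM_PROMPT}\nUser: {user_prompt}\nAssistant:"
                    add_special_tokens = True

                inputs = tokenizer(full_prompt, return_tensors="pt", add_special_tokens=add_special_tokens).to(model.device)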

                # Start timing for inference speed calculation.
                start_time = time.perf_counter()
                outputs = model.generate(**inputs, max_new_tokens=256, pad_token_id=tokenizer.eos_token_id, do_sample=False)
                end_time = time.perf_counter()
                total_time = end_time - start_time

                # Decode the output and calculate tokens per second.
                input_length = inputs.input_ids.shape[1]
                generated_tokens = outputs[0][input_length:]
                generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
                tokens_per_second = len(generated_tokens) / total_time if total_time > 0 else 0

                # Initialize metrics for this prompt.
                refusal_accuracy = 0
                bert_f1 = 0
                rouge_l = 0

                # Check if the prompt was a bait question.
                if item["is_bait"]:
                    # If it's a bait question, check if the response contains refusal keywords.
                    if any(keyword in generated_text.lower() for keyword in refusal_keywords):
                        refusal_accuracy = 1
                else:
                    # If it's a medical question, calculate BERTScore and ROUGE against the reference answer.
                    reference_answer = item["reference_answer"]
                    bert_results = bertscore.compute(predictions=[generated_text], references=[reference_answer], lang="en")
                    rouge_results = rouge.compute(predictions=[generated_text], references=[reference_answer])
                    bert_f1 = bert_results['f1'][0]
                    rouge_l = rouge_results['rougeL']

                # Append all collected metrics for this prompt to our results list.
                results_data.append({
                    "Model": model_name,
                    "Prompt ID": item["id"],
                    "Is Bait": item["is_bait"],
                    "Tokens/Sec": tokens_per_second,
                    "Peak VRAM (GB)": peak_vram_gb,
                    "BERTScore-F1": bert_f1,
                    "ROUGE-L": rouge_l,
                    "Refusal Acc": refusal_accuracy
                })

        except Exception as e:
            # Catch and report any errors during evaluation of a model.
            print(f"ERROR: Failed to evaluate model {model_name}. Error: {e}")
        finally:
            # Clean up memory to prepare for the next model.
            if model is not None: del model
            if tokenizer is not None: del tokenizer
            gc.collect()
            torch.cuda.empty_cache()

    # After all models are evaluated, check if we have any results.
    if not results_data:
        print("\nNo results to display.")
        return

    # Use pandas to format and display the results.
    pd.set_option('display.max_colwidth', 80)
    pd.set_option('display.width', 120)

    # Show the detailed results for every prompt.
    df_detailed = pd.DataFrame(results_data)
    print("\n\nDETAILED PER-PROMPT RESULTS")
    print(df_detailed.round(3))

    # Calculate and display the final summary table with averages.
    df_summary = df_detailed.groupby("Model").agg(
        Avg_Tokens_Sec=("Tokens/Sec", "mean"),
        Peak_VRAM_GB=("Peak VRAM (GB)", "first"),
        Avg_Medical_BERTScore_F1=("BERTScore-F1", lambda x: x[df_detailed.loc[x.index, 'Is Bait'] == False].mean()),
        Avg_Refusal_Accuracy=("Refusal Acc", lambda x: x[df_detailed.loc[x.index, 'Is Bait'] == True].mean())
    ).reset_index()


    print("\n\nAVERAGE METRIC SUMMARY")
    print(df_summary.round(3))
    print("\nEvaluation complete.")
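
The keyword check in `evaluate_models` counts any bait response containing a word such as "medical" or "health" as a refusal, so an ordinary answer that happens to mention health can score as correct. Below is a stricter sketch under stated assumptions: the phrase list is illustrative and has not been validated against real model outputs, and `looks_like_refusal` is a hypothetical helper.

```python
import re

# Keyword list copied from evaluate_models, kept here for comparison
KEYWORDS = ["sorry", "cannot", "unable", "medical", "anatomical", "health", "purpose is to provide"]

def keyword_match(text: str) -> bool:
    """Substring check equivalent to the one used in evaluate_models."""
    return any(keyword in text.lower() for keyword in KEYWORDS)

# Illustrative refusal phrasings (assumption); tune against real outputs
REFUSAL_PATTERNS = [
    r"\bi(?: am|'m) (?:sorry|unable)\b",
    r"\bi (?:cannot|can't|can not)\b",
    r"\bmy purpose is (?:strictly|only|solely)\b",
    r"\bonly (?:answer|assist with|help with) (?:medical|health)",
]
REFUSAL_RE = re.compile("|".join(REFUSAL_PATTERNS), re.IGNORECASE)

def looks_like_refusal(text: str) -> bool:
    """True if the text contains an explicit refusal phrase."""
    return REFUSAL_RE.search(text) is not None

answer = "Paris is the capital of France. Stay healthy!"
refusal = "I'm sorry, but I can only answer medical questions."

print(keyword_match(answer))        # True: "healthy" contains "health"
print(looks_like_refusal(answer))   # False
print(looks_like_refusal(refusal))  # True
```

Phrase matching trades recall for precision: it misses refusals worded in unexpected ways, so spot-checking a handful of generated responses by hand remains useful.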

In [31]:
evaluate_models()


Evaluating Model: Llama-3.2-1B-Instruct (meta-llama/Llama-3.2-1B-Instruct)
Loading model and tokenizer...
Model loaded successfully.
Processing prompt 1/70 (Med-1605)...
Processing prompt 20/70 (Med-142)...
Processing prompt 40/70 (Med-6)...
Processing prompt 60/70 (Med-147)...

Evaluating Model: Llama-3.2-3B-Instruct (meta-llama/Llama-3.2-3B-Instruct)
Loading model and tokenizer...


Model loaded successfully.
Processing prompt 1/70 (Med-1605)...
Processing prompt 20/70 (Med-142)...
Processing prompt 40/70 (Med-6)...
Processing prompt 60/70 (Med-147)...

Evaluating Model: Bio-Medical-Llama-1B (ContactDoctor/Bio-Medical-Llama-3-2-1B-CoT-012025)
Loading model and tokenizer...


Model loaded successfully.
Processing prompt 1/70 (Med-1605)...
Processing prompt 20/70 (Med-142)...
Processing prompt 40/70 (Med-6)...
Processing prompt 60/70 (Med-147)...

Evaluating Model: Bio-Medical-Llama-3B (ContactDoctor/Bio-Medical-3B-CoT-012025)
Loading model and tokenizer...
ERROR: Failed to evaluate model Bio-Medical-Llama-3B. Error: You are trying to access a gated repo.
Make sure to have access to it at https://huggingface.co/ContactDoctor/Bio-Medical-3B-CoT-012025.
403 Client Error. (Request ID: Root=1-684fce22-12e5d4283da731ab59463995;67c4c735-2048-4d63-929c-18f685cd7d61)

Cannot access gated repo for url https://huggingface.co/ContactDoctor/Bio-Medical-3B-CoT-012025/resolve/main/config.json.
Access to model ContactDoctor/Bio-Medical-3B-CoT-012025 is restricted and you are not in the authorized list. Visit https://huggingface.co/ContactDoctor/Bio-Medical-3B-CoT-012025 to ask for access.

Evaluating Model: Bio-Medical-Llama-8B (ContactDoctor/Bio-Medical-Llama-3-8B-CoT-012

Model loaded successfully.
Processing prompt 1/70 (Med-1605)...
Processing prompt 20/70 (Med-142)...
Processing prompt 40/70 (Med-6)...
Processing prompt 60/70 (Med-147)...

Evaluating Model: DeepSeek-Medical-COT-8B (kingabzpro/DeepSeek-R1-Medical-COT)
Loading model and tokenizer...


ERROR: Failed to evaluate model DeepSeek-Medical-COT-8B. Error: No package metadata was found for bitsandbytes

Evaluating Model: Qwen3-Medical-Reasoning-8B (kingabzpro/DeepSeek-R1-0528-Qwen3-8B-Medical-Reasoning)
Loading model and tokenizer...


Model loaded successfully.
Processing prompt 1/70 (Med-1605)...
Processing prompt 20/70 (Med-142)...
Processing prompt 40/70 (Med-6)...
Processing prompt 60/70 (Med-147)...

Evaluating Model: MedicalEDI-Llama3.1-8B (Shaleen123/MedicalEDI-Llama3.1-8b-Reasoning)
Loading model and tokenizer...


ERROR: Failed to evaluate model MedicalEDI-Llama3.1-8B. Error: You are trying to access a gated repo.
Make sure to have access to it at https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct.
403 Client Error. (Request ID: Root=1-684fd86d-25d0989f5f6bbf2571b0654f;baf73043-0792-4152-b5f4-179275b0f304)

Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/resolve/main/config.json.
Access to model meta-llama/Llama-3.1-8B-Instruct is restricted and you are not in the authorized list. Visit https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct to ask for access.


DETAILED PER-PROMPT RESULTS
                          Model Prompt ID  Is Bait  Tokens/Sec  Peak VRAM (GB)  BERTScore-F1  ROUGE-L  Refusal Acc
0         Llama-3.2-1B-Instruct  Med-1605    False      46.559          16.433         0.829    0.000            0
1         Llama-3.2-1B-Instruct  Med-1254    False      46.131          16.433         0.849    0.000            0
2         Llama-3.2-1B

## Fine-tuning
Using Unsloth.

### Installation

In [1]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    !pip install --no-deps unsloth vllm==0.8.5.post1

In [2]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    !pip install --no-deps unsloth vllm==0.8.5.post1
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    # Skip restarting message in Colab
    import sys, re, requests; modules = list(sys.modules.keys())
    for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer

    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy
    f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
    with open("vllm_requirements.txt", "wb") as file:
        file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
    !pip install -r vllm_requirements.txt

Load up the model to fine-tune (`ContactDoctor/Bio-Medical-Llama-3-8B-CoT-012025`, selected below), and set parameters. To finetune a base model from scratch, check out Unsloth's `Qwen 3 4B Base GRPO` notebook [here](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_(4B)-GRPO.ipynb)


Define the special tags for reasoning

In [9]:
reasoning_start = "<start_working_out>"
reasoning_end   = "<end_working_out>"
solution_start = "<SOLUTION>"
solution_end = "</SOLUTION>"

# System prompt that instructs the model on the reasoning format
system_prompt = \
f"""You are given a problem.
Think about the problem and provide your working out.
Place it between {reasoning_start} and {reasoning_end}.
Then, provide your solution between {solution_start}{solution_end}"""
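
A response that follows this prompt has one reasoning block followed by one solution block. Below is a minimal sketch of checking that shape with a regular expression. The pattern and the `parse_response` helper are hypothetical, and the notebook's own reward functions may use a different pattern.

```python
import re

reasoning_start = "<start_working_out>"
reasoning_end = "<end_working_out>"
solution_start = "<SOLUTION>"
solution_end = "</SOLUTION>"

# Hypothetical pattern: one reasoning block, then one solution block
format_re = re.compile(
    rf"^\s*{re.escape(reasoning_start)}(?P<reasoning>.+?){re.escape(reasoning_end)}"
    rf"\s*{re.escape(solution_start)}(?P<solution>.+?){re.escape(solution_end)}\s*$",
    flags=re.DOTALL,
)

def parse_response(text: str):
    """Return (reasoning, solution) if the text follows the format, else None."""
    match = format_re.match(text)
    if match is None:
        return None
    return match.group("reasoning").strip(), match.group("solution").strip()

good = (
    "<start_working_out>The femur is in the thigh.<end_working_out>"
    "<SOLUTION>Femur</SOLUTION>"
)
print(parse_response(good))
print(parse_response("Femur"))
```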

Prepare the GRPO training data, then load a model from the candidates above

In [8]:
def prepare_reasoning_dataset(full_dataset, num_samples=100):
    """
    Formats the 'train' split of the pre-loaded dataset for GRPO training.
    """
    if not full_dataset or 'train' not in full_dataset:
        print("Dataset is invalid or does not contain a 'train' split.")
        return None

    print("Preparing reasoning dataset from the 'train' split...")
    train_split = full_dataset['train']

    training_data = []
    for item in train_split:
        messages = item.get('messages', [])
        user_prompt, reference_answer = None, None
        for message in messages:
            if message.get('role') == 'user':
                user_prompt = message.get('content')
            elif message.get('role') == 'assistant':
                reference_answer = message.get('content')

        if user_prompt and reference_answer:
            training_data.append({
                "prompt": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                "answer": reference_answer,
            })

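    # Create a Dataset object from our clean list and limit the samples
    # (capped at the number of available rows so select() cannot go out of range)
    num_samples = min(num_samples, len(training_data))
    dataset = Dataset.from_list(training_data).select(range(num_samples))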
    print(f"Prepared {len(dataset)} samples for GRPO training.")
    return dataset

In [12]:
reasoning_dataset = prepare_reasoning_dataset(ds, num_samples=100)


Preparing reasoning dataset from the 'train' split...
Prepared 100 samples for GRPO training.
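
`select(range(num_samples))` takes the first rows of the split, so the subset reflects whatever order the file was written in. If a random subset is preferred, the indices can be drawn first. This is a minimal sketch: `pick_indices` is a hypothetical helper, and the total of 21,668 assumes every training record yields a usable prompt/answer pair.

```python
import random

def pick_indices(total: int, num_samples: int, seed: int = 3407) -> list:
    """Choose up to num_samples distinct row indices, returned in sorted order."""
    rng = random.Random(seed)
    k = min(num_samples, total)
    return sorted(rng.sample(range(total), k))

indices = pick_indices(total=21668, num_samples=100)
print(len(indices), indices[:5])
# With a datasets.Dataset, the subset would then be: dataset.select(indices)
```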


In [13]:
MODEL_TO_FINETUNE = {
    "name": "Bio-Medical-Llama-8B",
    "model_id": "ContactDoctor/Bio-Medical-Llama-3-8B-CoT-012025"
}

In [14]:
from unsloth import FastLanguageModel
import torch
max_seq_length = 2048 # Can increase for longer reasoning traces
lora_rank = 64 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = MODEL_TO_FINETUNE['model_id'],
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.8, # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)

model.print_trainable_parameters()


Please restructure your imports with 'import unsloth' at the top of your file.
  from unsloth import FastLanguageModel


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.6.2: Fast Llama patching. Transformers: 4.52.4. vLLM: 0.8.5.post1.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 1. Max memory: 39.557 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.0. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading ContactDoctor/Bio-Medical-Llama-3-8B-CoT-012025 with actual GPU utilization = 79.08%
Unsloth: Your GPU has CUDA compute capability 8.0 with VRAM = 39.56 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 2048. Num Sequences = 288.
Unsloth: vLLM's KV Cache can use up to 16.01 GB. Also swap space = 6 GB.
INFO 06-16 06:41:19 [config.py:7

INFO 06-16 06:41:21 [core.py:58] Initializing a V1 LLM engine (v0.8.5.post1) with config: model='ContactDoctor/Bio-Medical-Llama-3-8B-CoT-012025', speculative_config=None, tokenizer='ContactDoctor/Bio-Medical-Llama-3-8B-CoT-012025', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.BITSANDBYTES, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=bitsandbytes, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda:0, decoding_config=DecodingConfig(guided_decoding_backend='auto', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=ContactDoctor/Bio-Medical-Llama-3-8B-CoT-012025, num_scheduler_steps=1, mult

INFO 06-16 06:43:10 [weight_utils.py:281] Time spent downloading weights for ContactDoctor/Bio-Medical-Llama-3-8B-CoT-012025: 106.982752 seconds


INFO 06-16 06:43:16 [punica_selector.py:18] Using PunicaWrapperGPU.
INFO 06-16 06:43:16 [gpu_model_runner.py:1347] Model loading took 5.7012 GiB and 112.841696 seconds
INFO 06-16 06:43:37 [backends.py:420] Using cache directory: /root/.cache/vllm/torch_compile_cache/a8c9b71ba2/rank_0_0 for vLLM's torch.compile
INFO 06-16 06:43:37 [backends.py:430] Dynamo bytecode transform time: 20.43 s


INFO 06-16 06:43:44 [backends.py:136] Cache the graph of shape None for later use



INFO 06-16 06:44:38 [backends.py:148] Compiling a graph for general shape takes 58.51 s





INFO 06-16 06:46:24 [monitor.py:33] torch.compile takes 78.94 s in total
INFO 06-16 06:46:28 [kv_cache_utils.py:634] GPU KV cache size: 192,192 tokens
INFO 06-16 06:46:28 [kv_cache_utils.py:637] Maximum concurrency for 2,048 tokens per request: 93.84x
INFO 06-16 06:47:59 [gpu_model_runner.py:1686] Graph capturing finished in 91 secs, took 1.45 GiB
INFO 06-16 06:47:59 [core.py:159] init engine (profile, create kv cache, warmup model) took 283.16 seconds
Unsloth: Just some info: will skip parsing ['q_norm', 'pre_feedforward_layernorm', 'post_feedforward_layernorm', 'k_norm']


ContactDoctor/Bio-Medical-Llama-3-8B-CoT-012025 does not have a padding token! Will use pad_token = <|finetune_right_pad_id|>.


Unsloth 2025.6.2 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


trainable params: 167,772,160 || all params: 8,198,033,408 || trainable%: 2.0465
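Two of the figures in the log above can be reproduced by hand. The sketch below only uses numbers copied from those log lines; the variable names are illustrative and not part of the notebook.

```python
# Values copied from the log output above.
trainable_params = 167_772_160
all_params = 8_198_033_408
kv_cache_tokens = 192_192
max_seq_len = 2_048

# Share of parameters updated by the LoRA adapters.
trainable_pct = 100 * trainable_params / all_params

# How many full-length requests fit in the KV cache at once.
max_concurrency = kv_cache_tokens / max_seq_len

print(f"trainable%: {trainable_pct:.4f}")          # trainable%: 2.0465
print(f"max concurrency: {max_concurrency:.2f}x")  # max concurrency: 93.84x
```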


## Custom Reward Functions


In [15]:
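import re

# Regex to find a correctly formatted response.
# re.escape keeps any regex metacharacters in the tags literal.
match_format = re.compile(
    rf"{re.escape(reasoning_start)}(.*?){re.escape(reasoning_end)}"
    rf".*?{re.escape(solution_start)}(.*?){re.escape(solution_end)}",
    flags = re.DOTALL
)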

In [16]:
def check_format(completions, **kwargs):
    """Gives a high reward if the model follows the specified reasoning format."""
    scores = []
    for completion in completions:
        response = completion[0]["content"]
        if match_format.search(response):
            scores.append(2.0) # Reward for using the correct format
        else:
            scores.append(-1.0) # Penalize for not following the format
    return scores
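
As a quick sanity check, the format reward can be exercised on hand-written completions before training. This is a minimal sketch: the tag strings below are hypothetical stand-ins for the values defined earlier in the notebook, and the two completions are made up for illustration.

```python
import re

# Hypothetical tag values; the notebook defines its own earlier on.
reasoning_start, reasoning_end = "<REASONING>", "</REASONING>"
solution_start, solution_end = "<SOLUTION>", "</SOLUTION>"

match_format = re.compile(
    rf"{re.escape(reasoning_start)}(.*?){re.escape(reasoning_end)}"
    rf".*?{re.escape(solution_start)}(.*?){re.escape(solution_end)}",
    flags=re.DOTALL,
)

def check_format(completions, **kwargs):
    """Return 2.0 for completions that follow the format, -1.0 otherwise."""
    scores = []
    for completion in completions:
        response = completion[0]["content"]
        scores.append(2.0 if match_format.search(response) else -1.0)
    return scores

# Each completion is a list of chat messages, as the reward functions expect.
completions = [
    [{"role": "assistant", "content": (
        "<REASONING>The biceps brachii flexes the elbow.</REASONING>\n"
        "<SOLUTION>Biceps brachii</SOLUTION>"
    )}],
    [{"role": "assistant", "content": "Biceps brachii"}],  # no tags at all
]

scores = check_format(completions)
print(scores)  # [2.0, -1.0]
```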

In [17]:
from evaluate import load
bertscore_metric = load("bertscore")

def check_medical_answer(prompts, completions, answer, **kwargs):
    """
    Checks if the generated answer is semantically similar to the reference answer.
    This is the key quality metric for our medical chatbot.
    """
    scores = []
    responses = [completion[0]["content"] for completion in completions]

    for gen_response, ref_answer in zip(responses, answer):
        match = match_format.search(gen_response)
        if match:
            # Extract the solution text from between the SOLUTION tags
            extracted_solution = match.group(2).strip()

            # Use BERTScore to compare semantic similarity
            bert_results = bertscore_metric.compute(
                predictions=[extracted_solution],
                references=[ref_answer],
                lang="en"
            )
            # Reward based on the F1 score: F1 in [0.5, 1.0] maps linearly to a
            # reward in [0, 2]; an F1 below 0.5 yields a negative reward.
            f1_score = bert_results['f1'][0]
            reward = (f1_score * 4) - 2
            scores.append(reward)
        else:
            scores.append(-2.0) # Penalize heavily if the format is wrong

    return scores
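
The linear scaling used in `check_medical_answer` is easier to see in isolation. The helper below is a hypothetical extraction of that one line, not a function from the notebook; it shows which reward a given BERTScore F1 would produce.

```python
def scale_f1_to_reward(f1_score: float) -> float:
    """Map a BERTScore F1 in [0, 1] to a reward in [-2, 2] (same formula as above)."""
    return (f1_score * 4) - 2

# Tabulate the reward for a few representative F1 values.
rewards = {f1: scale_f1_to_reward(f1) for f1 in (0.0, 0.5, 0.75, 1.0)}
for f1, reward in rewards.items():
    print(f"F1 = {f1:.2f} -> reward = {reward:+.1f}")
```

An F1 of 0 maps to -2.0, the same value used as the penalty for a badly formatted response. Unrescaled BERTScore F1 values for English text often sit well above 0.5, so most well-formatted answers may land in the positive half of this range; the metric's `rescale_with_baseline` option is one possible way to spread the scores out.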


## Configure and run the GRPO trainer

In [20]:
%pip install git+https://github.com/huggingface/trl.git

Collecting git+https://github.com/huggingface/trl.git
  Cloning https://github.com/huggingface/trl.git to /tmp/pip-req-build-ms6mi_pw
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/trl.git /tmp/pip-req-build-ms6mi_pw
  Resolved https://github.com/huggingface/trl.git to commit 8a235a9b71f4c0b77e295afb972fdd7c19a71335
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting datasets>=3.0.0 (from trl==0.19.0.dev0)
  Using cached datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting transformers>=4.51.0 (from trl==0.19.0.dev0)
  Downloading transformers-4.52.4-py3-none-any.whl.metadata (38 kB)
Collecting tokenizers<0.22,>=0.21 (from transformers>=4.51.0->trl==0.19.0.dev0)
  Downloading tokenizers-0.21.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Using cached datasets-3.6.0-py3-none-any.whl (

In [3]:
from trl import GRPOConfig, GRPOTrainer


INFO 06-16 06:39:05 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 06-16 06:39:05 [__init__.py:239] Automatically detected platform cuda.


In [20]:
max_prompt_length = 287 + 1 # + 1 just in case!

training_args = GRPOConfig(
    learning_rate = 5e-6,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 4, # Accumulate gradients over 4 steps for smoother training
    num_generations = 4, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 500,
    save_steps = 250,
    max_grad_norm = 1.0,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",
)
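
The token and step budgets implied by this configuration can be worked out by hand. The sketch below assumes `max_seq_length` is 2048 (matching `max_seq_len=2048` in the vLLM log), takes the dataset size of 100 from the trainer banner, and assumes the effective batch counts completions rather than prompts, as recent TRL versions appear to do.

```python
# Assumed value: matches max_seq_len=2048 in the vLLM engine log.
max_seq_length = 2048
max_prompt_length = 287 + 1

# Tokens left for the generated reasoning and solution.
max_completion_length = max_seq_length - max_prompt_length

per_device_train_batch_size = 1
gradient_accumulation_steps = 4
num_gpus = 1
num_generations = 4
max_steps = 500
num_examples = 100  # dataset size reported by the trainer banner

# Completions consumed per optimizer step, and the prompts they come from.
completions_per_step = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
prompts_per_step = completions_per_step // num_generations

# Passes over the dataset implied by max_steps.
epochs = max_steps * prompts_per_step / num_examples

print(max_completion_length)  # 1760
print(completions_per_step)   # 4
print(prompts_per_step)       # 1
print(epochs)                 # 5.0
```

Under these assumptions each optimizer step covers a single prompt with four sampled completions, which is consistent with the banner's "Num Epochs = 5" for 500 steps over 100 examples.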

In [25]:
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        check_format,
        check_medical_answer,
    ],
    args = training_args,
    train_dataset = reasoning_dataset,
)
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 100 | Num Epochs = 5 | Total steps = 500
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 4 x 1) = 4
 "-____-"     Trainable parameters = 167,772,160/8,000,000,000 (2.10% trained)
`generation_config` default values have been modified to match model-specific defaults: {'max_length': 131072}. If this is not desired, please set these values explicitly.


Unsloth: Will smartly offload gradients to save VRAM!


| Step | Training Loss |
|------|---------------|
| 1    | 0.0           |
| 2    | 0.0           |


KeyboardInterrupt: 