This notebook fine-tunes MedGemma on a 1,260-row dataset of question-and-answer exchanges typical of Tactical Combat Casualty Care (TCCC) as practiced by the United States Armed Forces. The dataset is synthetic but is intended to reflect real scenarios and the language combat medics actually use. <br>
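For this case, 3 epochs worked relatively well. The second cell is a smoke test that trains on a single hand-written example to check the prompt format, and the last cell runs the full 3-epoch training on the dataset. Both cells push to the same adapter repo, so the full run overwrites the smoke-test adapter.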
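As a rough guide to run length, here is a sketch of the optimizer-step count implied by the training arguments used in both training cells (per-device batch size 1, gradient accumulation 4, 3 epochs). It assumes a single GPU and that a partial accumulation window at the end of an epoch still counts as one optimizer step; that assumption is consistent with the smoke test, which logged 3 steps for one example over 3 epochs. `optimizer_steps` is a hypothetical helper, not part of Transformers.

```python
import math


def optimizer_steps(num_examples: int, batch_size: int, grad_accum: int, epochs: int) -> int:
    """Estimate total optimizer steps for a single-device run.

    Assumption: the last (possibly partial) gradient-accumulation window of
    each epoch still triggers an optimizer step, i.e. ceiling division.
    """
    batches_per_epoch = math.ceil(num_examples / batch_size)
    steps_per_epoch = math.ceil(batches_per_epoch / grad_accum)
    return steps_per_epoch * epochs


# Smoke test: one hand-written example, effective batch size 1 * 4 = 4
print(optimizer_steps(1, batch_size=1, grad_accum=4, epochs=3))
# Full run: 1,260 rows -> 315 steps per epoch
print(optimizer_steps(1260, batch_size=1, grad_accum=4, epochs=3))
```

With an effective batch size of 4, the full run works out to 315 steps per epoch and 945 in total under these assumptions.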

In [None]:
!pip install -q transformers accelerate bitsandbytes peft datasets trl huggingface_hub

from huggingface_hub import login, whoami
login()
print(whoami())


In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from datasets import Dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer
from huggingface_hub import HfApi

# ----------------------------
# Config
# ----------------------------
BASE_MODEL = "google/medgemma-1.5-4b-it"
LORA_REPO = "CharlieKingOfTheRats/medgemma-1.5-4b-tccc-lora"

# ----------------------------
# 4-bit QLoRA config (float32 to avoid AMP issues)
# ----------------------------
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float32,  # float32 compute avoids fp16/bf16 AMP issues on a Colab T4
)

# ----------------------------
# Tokenizer
# ----------------------------
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# ----------------------------
# Model
# ----------------------------
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb,
    device_map="auto",
    trust_remote_code=True,
    dtype=torch.float32,
)

model.gradient_checkpointing_enable()
model.config.use_cache = False
model = prepare_model_for_kbit_training(model)

# ----------------------------
# Dataset
# ----------------------------
data = [
    {
        "question": "Hawk-9 medic, casualty GSW right shoulder, heavy bleeding.",
        "answer": "Apply tourniquet proximal to wound, assess distal pulse, prepare evac."
    }
]

def format_example(x):
    return {
        "text": f"<user>{x['question']}</user>\n<assistant>{x['answer']}</assistant>{tokenizer.eos_token}"
    }

dataset = Dataset.from_list([format_example(x) for x in data])

# ----------------------------
# LoRA
# ----------------------------
lora = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora)
model.print_trainable_parameters()

# ----------------------------
# Training arguments
# ----------------------------
training_args = TrainingArguments(
    output_dir="./medgemma_tccc_lora",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    learning_rate=2e-4,
    fp16=False,    # disable AMP
    bf16=False,    # disable BF16
    optim="adamw_torch",
    logging_steps=1,
    save_strategy="epoch",
    report_to="none",
)

# ----------------------------
# Trainer without AMP
# ----------------------------
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
)

# Extra safeguard: fp16/bf16 are already disabled in TrainingArguments
with torch.autocast("cuda", enabled=False):
    trainer.train()

# ----------------------------
# Push LoRA adapter
# ----------------------------
api = HfApi()
api.create_repo(LORA_REPO, exist_ok=True)

model.save_pretrained("lora_adapter")
tokenizer.save_pretrained("lora_adapter")

model.push_to_hub(LORA_REPO)
tokenizer.push_to_hub(LORA_REPO)

print(f"LoRA adapter pushed: https://huggingface.co/{LORA_REPO}")


trainable params: 11,898,880 || all params: 4,311,978,352 || trainable%: 0.2759





The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 1, 'bos_token_id': 2, 'pad_token_id': 0}.


Step,Training Loss
1,4.6904
2,3.8563
3,3.3223


No files have been modified since last commit. Skipping to prevent empty commit.


LoRA adapter pushed: https://huggingface.co/CharlieKingOfTheRats/medgemma-1.5-4b-tccc-lora
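The full-training cell below maps each `messages` record from the dataset into the same tagged text used in the smoke test. This standalone sketch mirrors that conversion in plain Python so it can be checked without downloading the dataset. It assumes the `messages` schema described in that cell's docstring; the `eos_token` default of `"<eos>"` is a stand-in for `tokenizer.eos_token` (the training log reports EOS id 1), and the sample rows are hypothetical placeholders. Malformed rows are mapped to `{"text": None}` rather than `None`, so a later filter step can drop them.

```python
from typing import Optional


def format_conversation(example: dict, eos_token: str = "<eos>") -> dict:
    """Flatten a {"messages": [...]} record into the tagged training text.

    Keeps the last user turn and the last assistant turn. Rows missing
    either turn get {"text": None} so they can be filtered out afterwards.
    """
    messages = example.get("messages") or []
    user_msg: Optional[str] = None
    assistant_msg: Optional[str] = None
    for m in messages:
        if m.get("role") == "user":
            user_msg = m.get("content")
        elif m.get("role") == "assistant":
            assistant_msg = m.get("content")
    if user_msg is None or assistant_msg is None:
        return {"text": None}
    return {"text": f"<user>{user_msg}</user>\n<assistant>{assistant_msg}</assistant>{eos_token}"}


# Hypothetical rows: one well-formed, one missing the assistant turn
rows = [
    {"messages": [{"role": "user", "content": "Hypothetical question."},
                  {"role": "assistant", "content": "Hypothetical answer."}]},
    {"messages": [{"role": "user", "content": "Question with no reply."}]},
]
formatted = [format_conversation(r) for r in rows]
kept = [f for f in formatted if f["text"] is not None]  # same idea as dataset.filter
print(len(kept))
print(kept[0]["text"])
```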


In [None]:
# full training based off medgemma_tccc dataset

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer
from huggingface_hub import HfApi

# ----------------------------
# Config
# ----------------------------
BASE_MODEL = "google/medgemma-1.5-4b-it"
LORA_REPO = "CharlieKingOfTheRats/medgemma-1.5-4b-tccc-lora"
DATASET_NAME = "CharlieKingOfTheRats/medgemma_tccc"

# ----------------------------
# 4-bit QLoRA config (safe for Colab T4)
# ----------------------------
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float32,  # stable on T4
)

# ----------------------------
# Tokenizer
# ----------------------------
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# ----------------------------
# Model
# ----------------------------
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb,
    device_map="auto",
    trust_remote_code=True,
    dtype=torch.float32,
)

model.gradient_checkpointing_enable()
model.config.use_cache = False
model = prepare_model_for_kbit_training(model)

# ----------------------------
# Dataset
# ----------------------------
raw_dataset = load_dataset(DATASET_NAME, split="train")

def format_conversation(example):
    """
    Each example is expected to look like:
    {
      "messages": [
        {"role":"user","content":"..."},
        {"role":"assistant","content":"..."}
      ]
    }
    Only the last user turn and the last assistant turn are kept.
    Malformed rows return {"text": None} so the filter below can drop them
    (map does not remove rows on its own).
    """
    messages = example["messages"] if "messages" in example else []

    user_msg = None
    assistant_msg = None

    for m in messages:
        if m["role"] == "user":
            user_msg = m["content"]
        elif m["role"] == "assistant":
            assistant_msg = m["content"]

    if user_msg is None or assistant_msg is None:
        return {"text": None}

    text = (
        f"<user>{user_msg}</user>\n"
        f"<assistant>{assistant_msg}</assistant>{tokenizer.eos_token}"
    )
    return {"text": text}

dataset = raw_dataset.map(
    format_conversation,
    remove_columns=raw_dataset.column_names
)

dataset = dataset.filter(lambda x: x["text"] is not None)

print("Sample formatted training example:\n")
print(dataset[0]["text"][:600])

# ----------------------------
# LoRA Config
# ----------------------------
lora = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora)
model.print_trainable_parameters()

# ----------------------------
# Training Arguments
# ----------------------------
training_args = TrainingArguments(
    output_dir="./medgemma_tccc_lora",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    learning_rate=2e-4,
    fp16=False,       # Disable AMP
    bf16=False,
    optim="adamw_torch",
    logging_steps=1,
    save_strategy="epoch",
    report_to="none",
    save_total_limit=2,
)

# ----------------------------
# Trainer (No AMP)
# ----------------------------
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
)

with torch.autocast("cuda", enabled=False):
    trainer.train()

# ----------------------------
# Push LoRA Adapter to HF
# ----------------------------
api = HfApi()
api.create_repo(LORA_REPO, exist_ok=True)

model.save_pretrained("lora_adapter")
tokenizer.save_pretrained("lora_adapter")

model.push_to_hub(LORA_REPO)
tokenizer.push_to_hub(LORA_REPO)

print("\nLoRA adapter pushed to:")
print(f"https://huggingface.co/{LORA_REPO}")


`torch_dtype` is deprecated! Use `dtype` instead!



Repo card metadata block was not found. Setting CardData to empty.



Sample formatted training example:

<user>AI, this is Echo-26 medic in TFC, casualty with GSW to right thigh. CAT applied, bleeding controlled. Airway intact, patient alert. SpO₂ 94%, respirations 26. HR 142, BP 98 systolic. Confirm TXA and reassessment.</user>
<assistant>Echo-26, proceed with the following:

1. **TXA Administration**: Administer TXA (Tranexamic Acid) 1 gram IV over 10 minutes if not already given. Ensure it is within 3 hours of injury.

2. **Reassess Vital Signs**: After TXA administration, continue to monitor:
   - Heart Rate
   - Blood Pressure
   - Respiratory Rate
   - SpO₂ levels

3. **Fluid Resuscitation*
trainable params: 11,898,880 || all params: 4,311,978,352 || trainable%: 0.2759





The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 1, 'bos_token_id': 2, 'pad_token_id': 0}.


Step,Training Loss
1,2.0047
2,1.9336
3,1.7211
4,1.8101
5,1.7269
6,1.4853
7,1.4875
8,1.4346
9,1.3232
10,1.4074


No files have been modified since last commit. Skipping to prevent empty commit.



LoRA adapter pushed to:
https://huggingface.co/CharlieKingOfTheRats/medgemma-1.5-4b-tccc-lora
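Because the adapter was trained on a custom `<user>…</user>` / `<assistant>…</assistant>` layout rather than Gemma's built-in chat template, inference prompts presumably need to follow the same layout. Below is a minimal pure-Python sketch of building such a prompt and trimming the decoded continuation. `build_prompt` and `extract_answer` are hypothetical helpers, the model and tokenizer calls that would sit between them (for example `model.generate`) are left out, and the continuation string stands in for real model output.

```python
def build_prompt(question: str) -> str:
    """Reproduce the training layout up to the point where the model should answer."""
    return f"<user>{question}</user>\n<assistant>"


def extract_answer(generated: str, eos_token: str = "<eos>") -> str:
    """Cut the decoded continuation at the first closing tag or EOS marker."""
    end = len(generated)
    for marker in ("</assistant>", eos_token):
        idx = generated.find(marker)
        if idx != -1:
            end = min(end, idx)
    return generated[:end].strip()


prompt = build_prompt("Hypothetical casualty report.")
# Hypothetical decoded continuation, standing in for real model output
continuation = " Hypothetical answer.</assistant><eos>"
print(prompt)
print(extract_answer(continuation))
```

Stopping at `</assistant>` as well as EOS is a defensive choice: a model fine-tuned on this layout may emit the closing tag before the EOS token, and anything after it is not part of the answer.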
