<a href="https://colab.research.google.com/github/aee4/MedGemma/blob/main/scripts/Medgemma_eczema.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# IMPORTANT: SOME KAGGLE DATA SOURCES ARE PRIVATE
# RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES.
# Authenticate with Kaggle so that kagglehub can access the (possibly private)
# dataset downloaded in the next cell.
import kagglehub
kagglehub.login()


In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE(review): do not delete this cell if any later cell reads
# `emmanueleyramagbetor_dataset_eczema_path` (the local dataset directory).
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.

# Download the 'emmanueleyramagbetor/dataset-eczema' dataset and keep the
# local directory path that kagglehub returns for it.
emmanueleyramagbetor_dataset_eczema_path = kagglehub.dataset_download('emmanueleyramagbetor/dataset-eczema')

print('Data source import complete.')


In [None]:
# Install Unsloth (with the "colab-new" extra) directly from its GitHub
# repository's default branch.
# NOTE(review): this is unpinned, so re-running later may pull a different
# Unsloth version; consider pinning a release tag or commit for reproducibility.
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"


In [None]:
import torch
from datasets import load_dataset
from huggingface_hub import login


In [None]:
# Log in to the Hugging Face Hub. The stored token is used to download
# google/medgemma-4b-it and, in the final cell, to push the fine-tuned model.
login()


In [None]:
from unsloth import FastLanguageModel
from trl import SFTTrainer
from transformers import TrainingArguments
from peft import LoraConfig
from datasets import load_dataset


In [None]:
import os

# Build the CSV paths from the directory kagglehub actually downloaded to
# (set in the download cell) rather than hard-coding Kaggle's /kaggle/input
# mount, which does not exist when this notebook runs outside Kaggle
# (e.g. on Colab).
data_dir = emmanueleyramagbetor_dataset_eczema_path
data_files = {
    split: os.path.join(data_dir, f"{split}.csv")
    for split in ("train", "validation", "test")
}

dataset = load_dataset("csv", data_files=data_files)
print(dataset)


In [None]:
def format_example(example):
    """Attach a user/assistant chat pair that frames the row as binary classification.

    The user turn holds the example's description plus an instruction to
    answer with the class (0 or 1); the assistant turn holds the label as text.

    Args:
        example: A dataset row with a "text" description and a "label".

    Returns:
        The same row with a new "messages" key, in the content-list chat
        format later rendered by ``processor.apply_chat_template``.

    Raises:
        ValueError: If the label is a non-integral float (e.g. NaN from a
            missing CSV value), which would otherwise become a training
            target such as "nan" instead of a class id.
    """
    text = example["text"]
    raw_label = example["label"]
    # A CSV label column can be loaded as floats (for instance when a value is
    # missing), and str(1.0) == "1.0" would not match the "0 or 1" the prompt
    # asks for, so normalise integral floats back to ints.
    if isinstance(raw_label, float):
        if not raw_label.is_integer():
            raise ValueError(
                f"Label must be an integer class id, got {raw_label!r}"
            )
        raw_label = int(raw_label)
    label = str(raw_label)

    example["messages"] = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": f"Description: {text}\n\nAnswer the correct class (0 or 1)."}
            ]
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": label}]
        }
    ]
    return example

# Give every example in every split a two-turn "messages" chat.
formatted_data = dataset.map(function=format_example)
formatted_data  # displayed as the cell output


In [None]:
model_id = "google/medgemma-4b-it"

# Load MedGemma 4B-IT through Unsloth with 4-bit quantised weights.
# dtype=None leaves the compute dtype for Unsloth to pick automatically.
# NOTE(review): max_seq_length here (2048) differs from the max_seq_length
# (1024) passed to SFTTrainer; confirm which limit is intended.
model, processor = FastLanguageModel.from_pretrained(
    model_name = model_id,
    max_seq_length = 2048,
    dtype = None,
    load_in_4bit = True,
)

# Pad on the right for training batches.
processor.tokenizer.padding_side = "right"
print("Loaded MedGemma 4B IT with Unsloth.")


In [None]:
# Wrap the quantised model with LoRA adapters (rank 16, alpha 16, dropout
# 0.05, no bias training) and enable gradient checkpointing.
# NOTE(review): whether FastLanguageModel.get_peft_model accepts the string
# "all-linear" (rather than an explicit list of module names) depends on the
# installed Unsloth version; confirm the adapters attach to the intended layers.
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    lora_alpha = 16,
    lora_dropout = 0.05,
    bias = "none",
    target_modules = "all-linear",
    use_gradient_checkpointing = True,
)


In [None]:
def collate_fn(batch):
    """Render, tokenise and pad a batch of chat examples for causal-LM training.

    Each example's "messages" is rendered with the processor's chat template
    (no generation prompt), the batch is tokenised into padded PyTorch
    tensors, and a "labels" copy of the input ids is added in which padding
    positions are set to -100 so they are ignored by the loss.

    NOTE(review): this collator is not passed to the SFTTrainer in this
    notebook as written; prompt tokens are not masked, only padding.
    """
    rendered = []
    for example in batch:
        rendered.append(
            processor.apply_chat_template(
                example["messages"],
                add_generation_prompt=False,
                tokenize=False,
            )
        )

    encoded = processor(text=rendered, return_tensors="pt", padding=True)

    targets = encoded["input_ids"].clone()
    pad_id = processor.tokenizer.pad_token_id
    targets[targets == pad_id] = -100  # -100 is ignored by the LM loss
    encoded["labels"] = targets

    return encoded


In [None]:
def to_text(example):
    """Render an example's chat "messages" into a single training string.

    Returns a dict with one key, "text", holding the chat-template output
    (untokenised, without a trailing generation prompt).

    NOTE(review): Gemma-style chat templates typically prepend a BOS token to
    the rendered string; confirm the trainer does not add a second one when it
    tokenises this text.
    """
    rendered = processor.apply_chat_template(
        example["messages"],
        add_generation_prompt=False,
        tokenize=False,
    )
    return {"text": rendered}

# Overwrite the "text" column with the rendered chat and drop "messages".
final_dataset = formatted_data.map(function=to_text, remove_columns=["messages"])
print("Example:\n", final_dataset["train"][0]["text"])


In [None]:
import torch
from trl import SFTConfig

# Pick the mixed-precision mode from the hardware instead of hard-coding fp16.
# The model was loaded with dtype=None (automatic), which on GPUs with bfloat16
# support yields a bf16 model; forcing fp16 training on such a model can fail
# or overflow. Fall back to fp16 only where bf16 is unavailable (e.g. T4).
use_bf16 = torch.cuda.is_bf16_supported()

args = SFTConfig(
    output_dir="medgemma-binary-classification",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=2,  # effective batch size of 2 per device
    warmup_steps=5,
    learning_rate=2e-4,
    fp16=not use_bf16,
    bf16=use_bf16,
    logging_steps=10,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    seed=42,
    report_to="none",
)

# Train on the rendered chat strings in the "text" column.
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=final_dataset["train"],
    eval_dataset=final_dataset["validation"],
    dataset_text_field="text",
    max_seq_length=1024,
    dataset_num_proc=1,
)


In [None]:
# Run supervised fine-tuning. The validation split is evaluated and a checkpoint
# saved once per epoch, per the eval/save strategy in the trainer config.
trainer.train()


In [None]:
# Evaluate the fine-tuned model on the trainer's eval_dataset (the validation
# split) and print the returned metrics dict.
eval_results = trainer.evaluate()
print(eval_results)


In [None]:
# NOTE(review): huggingface_hub is already imported earlier in this notebook
# (`from huggingface_hub import login`), so this install is likely redundant.
!pip install huggingface_hub


In [None]:
import os

# Folder to save your model
save_path = "/kaggle/working/my_model"
os.makedirs(save_path, exist_ok=True)

# Save the LoRA-finetuned model. `model` was wrapped by get_peft_model, so this
# is expected to write the adapter weights/config rather than merged weights.
model.save_pretrained(save_path)

# Save the whole processor, not just its tokenizer: MedGemma's processor also
# carries non-tokenizer config, and saving only the tokenizer leaves a folder
# that AutoProcessor.from_pretrained cannot load.
processor.save_pretrained(save_path)

print(f"Model and processor saved to {save_path}")


In [None]:
from huggingface_hub import HfApi

repo_id = "aee4/medgemma-eczema"

api = HfApi()

# exist_ok=True makes this cell safe to re-run: without it create_repo raises
# an error once the repository already exists.
# NOTE(review): private=False publishes the fine-tuned medical model publicly.
api.create_repo(repo_id=repo_id, private=False, exist_ok=True)

# Upload the folder written by the save cell (reuse save_path rather than
# repeating the literal, so the two cells cannot drift apart).
api.upload_folder(
    folder_path=save_path,
    repo_id=repo_id,
    commit_message="Upload fine-tuned model",
)
