In [5]:
from datasets import load_dataset

ds = load_dataset("hpe-ai/medical-cases-classification-tutorial")

Repo card metadata block was not found. Setting CardData to empty.


In [3]:
for split in ds:
    ds[split].save_to_disk(f"./medical_cases_{split}")

Saving the dataset (0/1 shards):   0%|          | 0/1724 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/370 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/370 [00:00<?, ? examples/s]

In [None]:
import pandas as pd
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import get_peft_model, LoraConfig, TaskType
import torch

# -------------------------------------------------------------------
# 1. LOAD & PREPROCESS THE DATA
# -------------------------------------------------------------------

# Load the CSV dataset
df = pd.read_csv(r"medical_cases_train\medical_cases_train.csv")

# Keep only the relevant columns and drop rows with missing values
df = df[["description", "transcription"]].dropna()

# Convert the Pandas dataframe into a Hugging Face Dataset object
dataset = Dataset.from_pandas(df)

# Format each row as a prompt-response dialogue (chat style for Gemma)
def format_prompt(example):
    return {
        "text": f"<start_of_turn>user\n{example['description']}\n<end_of_turn>\n<start_of_turn>model\n{example['transcription']}<end_of_turn>"
    }

# Apply the prompt formatting to each row in the dataset
dataset = dataset.map(format_prompt)

# -------------------------------------------------------------------
# 2. LOAD TOKENIZER AND BASE MODEL
# -------------------------------------------------------------------

# Set the name of the model checkpoint
model_name = "google/gemma-3-1b-it"

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Ensure that padding token is set (Gemma uses EOS token for padding too)
tokenizer.pad_token = tokenizer.eos_token

# Load the model in float16 precision and place it on available GPU automatically
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # Use 16-bit precision (efficient on modern GPUs)
    device_map="auto"  # Automatically spreads model across available GPUs
)

# -------------------------------------------------------------------
# 3. APPLY LoRA FOR PARAMETER-EFFICIENT FINE-TUNING
# -------------------------------------------------------------------

# Define the LoRA configuration (Low-Rank Adaptation)
lora_config = LoraConfig(
    r=8,  # Rank of the update matrices (lower = less compute, higher = more expressive)
    lora_alpha=32,  # Scaling factor for LoRA
    target_modules=["q_proj", "v_proj"],  # These are the attention projection layers to inject LoRA into
    lora_dropout=0.05,  # Dropout for regularization
    bias="none",  # Do not adapt the biases
    task_type=TaskType.CAUSAL_LM  # Task is causal language modeling (predict next token)
)

# Wrap the model with LoRA
model = get_peft_model(model, lora_config)

# -------------------------------------------------------------------
# 4. TOKENIZE THE DATA
# -------------------------------------------------------------------

# Tokenization function that also prepares labels (for training)
def tokenize(example):
    tokens = tokenizer(
        example["text"],
        truncation=True,
        padding="max_length",
        max_length=512  # You can adjust this based on GPU memory
    )
    tokens["labels"] = tokens["input_ids"].copy()  # Set labels for language modeling
    return tokens

# Apply tokenization
tokenized_dataset = dataset.map(tokenize, remove_columns=dataset.column_names)

# -------------------------------------------------------------------
# 5. SETUP TRAINING CONFIGURATION
# -------------------------------------------------------------------

training_args = TrainingArguments(
    output_dir="./gemma-lora-medical",  # Directory to save model checkpoints
    per_device_train_batch_size=1,  # Use batch size 1 (can be increased if GPU allows)
    gradient_accumulation_steps=4,  # Accumulate gradients to simulate larger batch size
    num_train_epochs=3,  # Train for 3 epochs
    learning_rate=2e-4,  # Learning rate (tuned for LoRA)
    fp16=True,  # Enable mixed precision training
    logging_dir="./logs",  # Where to write logs
    save_strategy="epoch",  # Save model every epoch
    save_total_limit=2  # Keep only the latest 2 checkpoints
)

# -------------------------------------------------------------------
# 6. INITIALIZE THE TRAINER
# -------------------------------------------------------------------

trainer = Trainer(
    model=model,  # LoRA-wrapped model
    args=training_args,
    train_dataset=tokenized_dataset,  # Tokenized dataset
    tokenizer=tokenizer,  # Needed for padding and decoding
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)  # MLM=False for causal LM
)

# -------------------------------------------------------------------
# 7. START TRAINING
# -------------------------------------------------------------------

trainer.train()

# -------------------------------------------------------------------
# 8. SAVE THE FINE-TUNED MODEL AND TOKENIZER
# -------------------------------------------------------------------

model.save_pretrained("./gemma-lora-medical")
tokenizer.save_pretrained("./gemma-lora-medical")


Map:   0%|          | 0/1724 [00:00<?, ? examples/s]

model.safetensors:  78%|#######7  | 1.55G/2.00G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

Map:   0%|          | 0/1724 [00:00<?, ? examples/s]

  trainer = Trainer(
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mabhayvg-2904[0m ([33mabhayvg-2904-indian-institute-of-technology-gandhinagar[0m). Use [1m`wandb login --relogin`[0m to force relogin


It is strongly recommended to train Gemma3 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`.


Step,Training Loss


In [None]:
import pandas as pd
from datasets import Dataset
from unsloth import FastLanguageModel
from transformers import TrainingArguments, Trainer, EvalPrediction
import torch
from peft import LoraConfig
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# -------------------------------------------------------------------
# 1. LOAD DATA
# -------------------------------------------------------------------
train_df = pd.read_csv("medical_cases_train/medical_cases_train.csv")[["description", "transcription"]].dropna()
val_df = pd.read_csv("medical_cases_validation/medical_cases_validation.csv")[["description", "transcription"]].dropna()
test_df = pd.read_csv("medical_cases_test/medical_cases_test.csv")[["description", "transcription"]].dropna()

train_dataset = Dataset.from_pandas(train_df)
val_dataset = Dataset.from_pandas(val_df)
test_dataset = Dataset.from_pandas(test_df)

# -------------------------------------------------------------------
# 2. FORMAT PROMPTS
# -------------------------------------------------------------------
def format_prompt(example):
    return {
        "text": f"<start_of_turn>user\n{example['description']}\n<end_of_turn>\n<start_of_turn>model\n{example['transcription']}<end_of_turn>"
    }

train_dataset = train_dataset.map(format_prompt)
val_dataset = val_dataset.map(format_prompt)
test_dataset = test_dataset.map(format_prompt)

# -------------------------------------------------------------------
# 3. LOAD MODEL
# -------------------------------------------------------------------
model_name = "unsloth/gemma-3-1b-it-unsloth-bnb-4bit"
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=512,
    dtype=None,
    load_in_4bit=True
)
tokenizer.pad_token = tokenizer.eos_token

# -------------------------------------------------------------------
# 4. APPLY LoRA
# -------------------------------------------------------------------
FastLanguageModel.for_training(model)

lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
model.add_adapter(lora_config)

# -------------------------------------------------------------------
# 5. TOKENIZATION
# -------------------------------------------------------------------
def tokenize(example):
    tokens = tokenizer(example["text"], padding="max_length", truncation=True, max_length=512)
    tokens["labels"] = tokens["input_ids"].copy()
    return tokens

train_dataset = train_dataset.map(tokenize, remove_columns=train_dataset.column_names)
val_dataset = val_dataset.map(tokenize, remove_columns=val_dataset.column_names)
test_dataset = test_dataset.map(tokenize, remove_columns=test_dataset.column_names)

# -------------------------------------------------------------------
# 6. TRAINING ARGUMENTS
# -------------------------------------------------------------------
training_args = TrainingArguments(
    output_dir="./gemma-lora-medical",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    warmup_steps=10,
    num_train_epochs=6,
    learning_rate=2e-4,
    fp16=True,
    logging_dir="./logs",
    logging_steps=10,
    save_strategy="epoch",
    save_total_limit=2,
    report_to="none"
)


# -------------------------------------------------------------------
# 7. METRICS FUNCTION
# -------------------------------------------------------------------
def compute_metrics(eval_pred: EvalPrediction):
    preds = eval_pred.predictions.argmax(-1)
    labels = eval_pred.label_ids

    # Flatten and ignore padded tokens
    true_labels = []
    pred_labels = []
    for pred, label in zip(preds, labels):
        for p, l in zip(pred, label):
            if l != -100:
                true_labels.append(l)
                pred_labels.append(p)

    return {
        "accuracy": accuracy_score(true_labels, pred_labels),
        "precision": precision_score(true_labels, pred_labels, average='macro', zero_division=0),
        "recall": recall_score(true_labels, pred_labels, average='macro', zero_division=0),
        "f1": f1_score(true_labels, pred_labels, average='macro', zero_division=0),
    }

# -------------------------------------------------------------------
# 8. TRAINER
# -------------------------------------------------------------------
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

trainer.train()

# -------------------------------------------------------------------
# 9. FINAL TEST EVALUATION
# -------------------------------------------------------------------
print("\n=== Final Evaluation on Test Set ===")
test_results = trainer.evaluate(eval_dataset=test_dataset)
for key, value in test_results.items():
    print(f"{key}: {value:.4f}")

# -------------------------------------------------------------------
# 10. SAVE MODEL
# -------------------------------------------------------------------
model.save_pretrained("./gemma-lora-medical")
tokenizer.save_pretrained("./gemma-lora-medical")


ðŸ¦¥ Unsloth: Will patch your computer to enable 2x faster free finetuning.
Unsloth: Failed to patch Gemma3ForConditionalGeneration.
ðŸ¦¥ Unsloth Zoo will now patch everything to make training faster!


Map:   0%|          | 0/1724 [00:00<?, ? examples/s]

Map:   0%|          | 0/370 [00:00<?, ? examples/s]

Map:   0%|          | 0/370 [00:00<?, ? examples/s]

  GPU_BUFFERS = tuple([torch.empty(2*256*2048, dtype = dtype, device = f"cuda:{i}") for i in range(n_gpus)])


==((====))==  Unsloth 2025.3.19: Fast Gemma3 patching. Transformers: 4.51.1.
   \\   /|    NVIDIA GeForce RTX 3050 Ti Laptop GPU. Num GPUs = 1. Max memory: 4.0 GB. Platform: Windows.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.6. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Map:   0%|          | 0/1724 [00:00<?, ? examples/s]

Map:   0%|          | 0/370 [00:00<?, ? examples/s]

Map:   0%|          | 0/370 [00:00<?, ? examples/s]

  trainer = Trainer(
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,724 | Num Epochs = 6 | Total steps = 1,290
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 745,472/1,000,000,000 (0.07% trained)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss
10,3.9599
20,3.0152
30,2.6102
40,2.5049
50,2.4916
60,2.3871
70,2.4221
80,2.3398
90,2.2101
100,2.2779
