In [23]:
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "1" 

# File Structure

```text
patchcamelyon_subset/
├── tumor/
│   ├── img00001.png
│   ├── img00002.png
│   └── ...
└── normal/
    ├── img00001.png
    ├── img00002.png
    └── ...
```

## Imports

In [24]:
# Other
import pandas as pd

# For set up
from datasets import load_dataset
from typing import Any

# For Loading Model
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig

# For fine tuning
from peft import LoraConfig
from trl import SFTTrainer
from trl import SFTConfig


In [25]:
# Report the CUDA devices visible to this process and which one is active.
gpu_count = torch.cuda.device_count()
print(f"Number of GPUs visible: {gpu_count}")
active_device = torch.cuda.current_device()
print(f"Current device: {active_device}")
print(f"GPU name: {torch.cuda.get_device_name(active_device)}")

Number of GPUs visible: 1
Current device: 0
GPU name: NVIDIA RTX A6000


## Set Up

In [26]:

# ---------------------- Set Up ---------------------- #

train_size = 9000
validation_size = 1000

# "./patchcamelyon_subset" is a local folder holding a subset of PatchCamelyon
# (the first 10K images), organised into "normal" and "tumor" sub-folders.
full_dataset = load_dataset("./patchcamelyon_subset", split="train")

# Shuffle with a fixed seed, then carve out the training and held-out portions.
data = full_dataset.train_test_split(
    train_size=train_size,
    test_size=validation_size,
    shuffle=True,
    seed=42,
)
# The held-out split comes back under the key "test"; store it as "validation".
data["validation"] = data.pop("test")

# ------------ Optional: display dataset details ------------ #
print(data)
# Each row is a dict with an 'image' entry and a 'label' entry.
first_example = data['train'][0]
print(f"data['train'][0]: {first_example}")
# Image and label of the first training row.
image = first_example['image']
label = first_example['label']
image.save("sample_image.png")
print("Image saved to sample_image.png")
print(label)
# Shows which class name each integer label stands for.
print(data['train'].features['label'])
# ----------------------------------------------------------- #

# Answer options shown to the model, one per class. format_data indexes this
# list with the row's integer label, so position 0 must be the label-0 answer
# and position 1 the label-1 answer.
HISTOPATHOLOGY_CLASSES = [
    "A: no tumor present",
    "B: tumor present",
]

options = "\n".join(HISTOPATHOLOGY_CLASSES)
PROMPT = "Is a tumor present in this histopathology image?\n" + options


def format_data(example: dict[str, Any]) -> dict[str, Any]:
    """Attach a two-turn chat transcript to a dataset row.

    The user turn holds an image placeholder followed by PROMPT; the assistant
    turn holds the answer option selected by the row's integer "label".

    Args:
        example: A row dict that contains at least a "label" key usable as an
            index into HISTOPATHOLOGY_CLASSES.

    Returns:
        The same dict object, now with an added "messages" key.
    """
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": PROMPT},
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [
            # label 0 -> "A: no tumor present", label 1 -> "B: tumor present"
            {"type": "text", "text": HISTOPATHOLOGY_CLASSES[example["label"]]},
        ],
    }
    # The row is modified in place and handed back: {'image', 'label', 'messages'}.
    example["messages"] = [user_turn, assistant_turn]
    return example

# Run format_data over every row of both splits so each row gains "messages".
data = data.map(format_data)
first_formatted_row = data['train'][0]
print(first_formatted_row)

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 9000
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 1000
    })
})
data['train'][0]: {'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=96x96 at 0x7F75A685ECE0>, 'label': 0}
Image saved to sample_image.png
0
ClassLabel(names=['normal', 'tumor'], id=None)
{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=96x96 at 0x7F75A685C490>, 'label': 0, 'messages': [{'content': [{'text': None, 'type': 'image'}, {'text': 'Is a tumor present in this histopathology image?\nA: no tumor present\nB: tumor present', 'type': 'text'}], 'role': 'user'}, {'content': [{'text': 'A: no tumor present', 'type': 'text'}], 'role': 'assistant'}]}


```text
DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 9000
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 1000
    })
})
data['train'][0]: {'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=96x96 at 0x7F34302CA3E0>, 'label': 0}
Image saved to sample_image.png
0
ClassLabel(names=['normal', 'tumor'], id=None)
Map: 100%|██████████| 9000/9000 [00:00<00:00, 10200.15 examples/s]
Map: 100%|██████████| 1000/1000 [00:00<00:00, 5657.86 examples/s]
{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=96x96 at 0x7F34302CA980>, 'label': 0, 'messages': [{'content': [{'text': None, 'type': 'image'}, {'text': 'Is a tumor present in this histopathology image?\nA: no tumor present\nB: tumor present', 'type': 'text'}], 'role': 'user'}, {'content': [{'text': 'A: no tumor present', 'type': 'text'}], 'role': 'assistant'}]}

```

## Load Model

In [27]:
# ---------------------- Loading Model ---------------------- #

model_id = "google/medgemma-4b-it"

# Guard clause: this notebook treats a CUDA compute-capability major version
# below 8 as "no bfloat16 support" and refuses to continue.
major_capability, _minor_capability = torch.cuda.get_device_capability()
if major_capability < 8:
    raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")
print('GPU supports bfloat 16. You are good to go :)')

# One dtype is shared by the model weights and by the quantization settings.
compute_dtype = torch.bfloat16

# Keyword arguments forwarded to from_pretrained, including the 4-bit
# BitsAndBytesConfig stored under 'quantization_config'.
model_kwargs = {
    "attn_implementation": "eager",
    "torch_dtype": compute_dtype,
    "device_map": "auto",
    "quantization_config": BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_quant_storage=compute_dtype,
    ),
}

model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)

# The processor is what collate_fn later calls .apply_chat_template on.
processor = AutoProcessor.from_pretrained(model_id)

# Use right padding to avoid issues during training
processor.tokenizer.padding_side = "right"

GPU supports bfloat 16. You are good to go :)


Loading checkpoint shards: 100%|██████████| 2/2 [00:08<00:00,  4.42s/it]


## Set Up For Fine Tuning

In [28]:
# ----------------------  Set Up for Fine Tuning ---------------------- #

# LoRA adapter settings, collected in one mapping and unpacked into LoraConfig.
lora_settings = {
    "lora_alpha": 16,
    "lora_dropout": 0.05,
    "r": 16,
    "bias": "none",
    "target_modules": "all-linear",
    "task_type": "CAUSAL_LM",
    # NOTE(review): per the PEFT docs these modules are trained and saved in
    # full alongside the adapters — confirm that is intended for both.
    "modules_to_save": [
        "lm_head",
        "embed_tokens",
    ],
}
peft_config = LoraConfig(**lora_settings)

## Collate Function

In [29]:

# Step 1. Clone input_ids and assign to labels.
# Step 2. Mask unnecessary info
# Step 3. Add the now redacted info as a new entry in the batch called 'labels'

def collate_fn(examples: list[dict[str, Any]]):
    """Collate formatted dataset rows into one padded batch with loss labels.

    Steps:
      1. For every row, convert its image to RGB and render its "messages"
         through the processor's chat template (as text, not token ids).
      2. Tokenize all texts and preprocess all images in one processor call.
      3. Clone ``input_ids`` into ``labels`` and overwrite the padding,
         begin/end-of-image and image placeholder token ids with -100 so they
         are not used in the loss computation.

    Only the ids listed in step 3 are masked; the prompt text tokens stay in
    ``labels`` unchanged.

    Args:
        examples: Rows that each carry an "image" (an object with
            ``.convert("RGB")``) and a "messages" list, as produced by
            ``format_data``.

    Returns:
        The processor's batch with an added "labels" entry.
    """
    texts = []
    images = []
    for example in examples:
        # One single-image list per row keeps images aligned with their text.
        images.append([example["image"].convert("RGB")])
        # Full transcript (question + correct answer); no generation prompt is
        # appended because the assistant turn is already present.
        texts.append(processor.apply_chat_template(
            example["messages"],
            add_generation_prompt=False,
            tokenize=False
        ).strip())

    # Tokenize the texts and process the images; the result contains 'input_ids'.
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # Labels start as an exact copy of the tokenized input.
    labels = batch["input_ids"].clone()

    tokenizer = processor.tokenizer
    special_tokens = tokenizer.special_tokens_map
    boi_token_id, eoi_token_id = tokenizer.convert_tokens_to_ids(
        [special_tokens['boi_token'], special_tokens['eoi_token']]
    )

    # Image placeholder id: prefer the value the processor reports, so the
    # collator is not tied to one checkpoint's vocabulary. Fall back to the
    # previously hard-coded id (262144) when the attribute is unavailable.
    image_token_id = getattr(processor, "image_token_id", None)
    if image_token_id is None:
        image_token_id = 262144

    # We don't want to predict image placeholders, image boundary markers or padding.
    ignore_token_ids = {
        tokenizer.pad_token_id,
        boi_token_id,
        eoi_token_id,
        image_token_id,
    }

    # Boolean-mask assignment: every position holding an ignored id becomes -100.
    for token_id in ignore_token_ids:
        labels[labels == token_id] = -100

    batch["labels"] = labels

    return batch

In [30]:
num_train_epochs = 1  # @param {type: "number"}
learning_rate = 2e-4  # @param {type: "number"}

# Where checkpoints go, and how often to save / evaluate / log.
output_settings = {
    "output_dir": "medgemma-4b-it-sft-lora-PatchCamelyon",  # Checkpoint directory (would double as Hub repo id if pushing)
    "save_strategy": "epoch",                                # Save a checkpoint at the end of every epoch
    "eval_strategy": "steps",                                # Evaluate every `eval_steps`
    "eval_steps": 50,
    "logging_steps": 50,                                     # Log every 50 steps
    "push_to_hub": False,                                    # Do NOT push the model to the Hub
    "report_to": "tensorboard",                              # Send metrics to tensorboard
}

# Batch sizing and memory: 4 samples per step x 4 accumulation steps per update.
batch_settings = {
    "num_train_epochs": num_train_epochs,
    "per_device_train_batch_size": 4,
    "per_device_eval_batch_size": 4,
    "gradient_accumulation_steps": 4,
    "gradient_checkpointing": True,                          # Reduce memory usage
    "gradient_checkpointing_kwargs": {"use_reentrant": False},  # Non-reentrant checkpointing to avoid issues
    "bf16": True,                                            # Use bfloat16 precision
}

# Optimizer and schedule; learning rate, grad-norm clip and warmup follow the QLoRA paper.
optimizer_settings = {
    "optim": "adamw_torch_fused",
    "learning_rate": learning_rate,
    "max_grad_norm": 0.3,
    "warmup_ratio": 0.03,
    "lr_scheduler_type": "linear",
}

# The notebook supplies its own collator, so default dataset preparation is
# skipped and the raw columns are kept for it to read.
collator_settings = {
    "dataset_kwargs": {"skip_prepare_dataset": True},
    "remove_unused_columns": False,
    "label_names": ["labels"],                               # Input keys that correspond to the labels
}

args = SFTConfig(
    **output_settings,
    **batch_settings,
    **optimizer_settings,
    **collator_settings,
)

# Evaluate on a 200-row sample of the validation split for a faster run.
# The shuffle is seeded so the same 200 rows are picked after every kernel
# restart; an unseeded shuffle would make eval losses incomparable across runs.
eval_subset = data["validation"].shuffle(seed=42).select(range(200))

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=data["train"],
    eval_dataset=eval_subset,
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
)

# Sanity check: pull one batch through the collator. Only the shape of each
# entry is printed, so the cell output stays small instead of dumping tensors.
sample_batch = next(iter(trainer.get_train_dataloader()))
print("Batch test:", {key: tuple(value.shape) for key, value in sample_batch.items()})

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.50 GiB. GPU 0 has a total capacity of 47.53 GiB of which 196.12 MiB is free. Process 3743819 has 37.04 GiB memory in use. Including non-PyTorch memory, this process has 10.28 GiB memory in use. Of the allocated memory 9.89 GiB is allocated by PyTorch, and 93.68 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

: 

## Tensor Masking

In [None]:
# Toy example of boolean-mask assignment on a tensor.
my_tensor = torch.tensor([2, 3, 4])
print(my_tensor)

# --- Tensor Masking --- #
hide_this = 3
# Comparing a tensor with a scalar gives a boolean tensor of the same shape...
positions_to_hide = my_tensor == hide_this
# ...and assigning through it overwrites only the True positions, in place.
my_tensor[positions_to_hide] = 0
print(my_tensor)

tensor([2, 3, 4])
tensor([2, 0, 4])


In [None]:
# Take first three examples from data['train']
# Collate the first three training rows into one batch.
data_subset = data['train'].select(range(3))
batch = collate_fn(data_subset)

# --------------------- Decoding some tokens ----------------------- #

# Keep in mind: -100 cannot be decoded.
# These ids are written out by hand rather than read from `batch`; they are an
# example of the information that remains un-masked in the labels.
unmasked_token_ids = [
    108, 4602, 496, 17491, 1861, 528, 672,
    2441, 118234, 2471, 236881, 107, 236776, 236787, 951, 17491,
    1861, 107, 236799, 236787, 17491, 1861, 106, 107, 105,
    4368, 107, 236799, 236787, 17491, 1861, 106,
]
decoded_text = processor.tokenizer.decode(unmasked_token_ids)
print(decoded_text)
print('\n')




Is a tumor present in this histopathology image?
A: no tumor present
B: tumor present<end_of_turn>
<start_of_turn>model
B: tumor present<end_of_turn>




## Token flags: check whether more tokens need to be masked

In [None]:
# -------------------- Need to mask more tokens! ------------------- #
# Look up the ids of the begin-of-image / end-of-image special tokens.
special_tokens = processor.tokenizer.special_tokens_map
boi_token = special_tokens['boi_token']
eoi_token = special_tokens['eoi_token']
boi_token_id, eoi_token_id = processor.tokenizer.convert_tokens_to_ids([boi_token, eoi_token])

# Only the batch's input_ids are inspected here.
input_ids = batch["input_ids"]

# For each marker, record whether its id occurs anywhere in the batch.
token_flags = {
    flag_name: (input_ids == marker_id).any().item()
    for flag_name, marker_id in (('EOI', eoi_token_id), ('BOI', boi_token_id))
}

for name, found in token_flags.items():
    if found:
        print(f'{name} token found in input ids')
    else:
        print(f'{name} token not found in input ids')

EOI token found in input ids
BOI token found in input ids


```text
EOI token found in input ids
BOI token found in input ids
```

## Image token representation

In [None]:
# Token ids of the first example in the batch, as a plain Python list.
input_ids = batch["input_ids"][0].tolist()
# Count how many of them are the image placeholder id 262144.
num_image_tokens = sum(token_id == 262144 for token_id in input_ids)
print(f"Number of <image_soft_token> tokens: {num_image_tokens}")


Number of <image_soft_token> tokens: 256


```text
Number of <image_soft_token> tokens: 256
```


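- Each image: 96 × 96 pixels
- Number of image tokens: 256  
- Therefore, averaged over the original image:  

  $\frac{96 \times 96}{256} = 36 \text{ pixels per token} \Rightarrow \sqrt{36} = 6, \text{ i.e. a } 6 \times 6 \text{ pixel region per token}$
  
- Caveat: 6×6 is only the *effective* footprint on the original image, not the literal patch size. According to the Gemma 3 documentation, the processor first resizes every image to a fixed resolution (896 × 896) and the vision encoder's output is condensed to a fixed 256 tokens, so the count would be 256 for any input size.  
- Each of these 256 positions appears in `input_ids` as one placeholder token (`<image_soft_token>`, token ID 262144).  
- So:  

  $96 \times 96 \text{ image} \quad \rightarrow \quad 256 \text{ tokens} \quad (\text{each covering roughly a 6×6 pixel region of the original image})$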

## Common methods for dataset.Dataset Object

In [None]:
def my_func(input):
    """Identity transform: hand back the given object untouched.

    Used below as a do-nothing function for demonstrating ``Dataset.map``.
    """
    unchanged = input
    return unchanged

# Access rows: integer indexing gives back a single row.
first_row = data_subset[0]
print(first_row)

# Select subset of rows: keep only the rows at indices 0 and 1.
keep_indices = range(2)
data_subset_small = data_subset.select(keep_indices)

# Apply transformation to all rows (my_func returns each row unchanged).
data_subset_mapped = data_subset.map(my_func)

# Split into train/val, holding out a 0.2 fraction (no seed is given here).
split_data = data_subset.train_test_split(test_size=0.2)

# Shuffle rows, with a fixed seed.
shuffled = data_subset.shuffle(seed=42)

{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=96x96 at 0x7F777CEF4040>, 'label': 0, 'messages': [{'content': [{'text': None, 'type': 'image'}, {'text': 'Is a tumor present in this histopathology image?\nA: no tumor present\nB: tumor present', 'type': 'text'}], 'role': 'user'}, {'content': [{'text': 'A: no tumor present', 'type': 'text'}], 'role': 'assistant'}]}


Map: 100%|██████████| 3/3 [00:00<00:00, 30.55 examples/s]


## Inspecting Batch

In [None]:
# Compare the length of the first example's input_ids with its attention_mask.
first_ids = batch['input_ids'][0]
first_mask = batch['attention_mask'][0]
print('same length' if len(first_ids) == len(first_mask) else 'different length')

# Print the length of every row of batch['token_type_ids'], one line per example.
batch_ttis = batch['token_type_ids']
length = len(batch_ttis)
for row in batch_ttis:
    print(len(row))

same length
296
296
296


```text
same length
```

This means that the input_ids and the attention_mask have the same length

In [None]:






# ---------------- scratch work ------------- #
# tensor_a = torch.rand(5)
# print(tensor_a)
# print(len(tensor_a))