In [None]:
import os
# Restrict this process to the GPU with physical index 4. This is assigned before
# torch is imported further down in the notebook.
# NOTE(review): the GPU index is hard-coded; adjust it for the machine you run on.
os.environ["CUDA_VISIBLE_DEVICES"] = "4" 

# File Structure

```text
patchcamelyon_subset/
├── tumor/
│   ├── img00001.png
│   ├── img00002.png
│   └── ...
└── normal/
    ├── img00001.png
    ├── img00002.png
    └── ...
```

## Imports

In [None]:
# Other
import pandas as pd

# For set up
from datasets import load_dataset
from typing import Any

# For Loading Model
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig

# For fine tuning
from peft import LoraConfig
from trl import SFTTrainer
from trl import SFTConfig


In [None]:
# Sanity check: report how many CUDA devices this kernel can see and which one is active.
print("Number of GPUs visible:", torch.cuda.device_count())
active_device_index = torch.cuda.current_device()
print("Current device:", active_device_index)
print("GPU name:", torch.cuda.get_device_name(active_device_index))

## Set Up

In [None]:

# ---------------------- Set Up ---------------------- #

# Row counts for the train and validation splits created in the next cell
# (9000 + 1000 rows in total).
train_size = 9000 
validation_size = 1000

# A subset of the PatchCamelyon data set (first 10K images) was downloaded and organized
# into a folder called patchcamelyon_subset, with sub-folders "normal" and "tumor".
# The printed summary below this cell shows the resulting 'image' and 'label' columns.
data = load_dataset("./patchcamelyon_subset", split="train")

```text
Dataset({
    features: ['image', 'label'],
    num_rows: 10000
})
data['image'][0]:  <PIL.PngImagePlugin.PngImageFile image mode=RGB size=96x96 at 0x7FA624E71B40>
data['label'][0]:  0
<class 'datasets.arrow_dataset.Dataset'>
<class 'list'>
<class 'PIL.PngImagePlugin.PngImageFile'>
<class 'list'>
<class 'int'>
```

- data is a dataset.Dataset object. 
- data['image'] and data['label'] are lists
- data['image'][0] and data['label'][0] are a PIL image and an int, respectively.
- N.B. data[0] **is** a dict object :)



In [None]:

# Shuffle with a fixed seed and split into `train_size` / `validation_size` rows.
# NOTE: this rebinds `data` from the single loaded Dataset to the split result, so the
# cell is not safe to re-run on its own; reload `data` first.
data = data.train_test_split(
    train_size=train_size,
    test_size=validation_size,
    shuffle=True,
    seed=42,
)
# rename the 'test' set to 'validation'
data["validation"] = data.pop("test")


- data now has type dataset_dict.DatasetDict:

```text
DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 9000
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 1000
    })
})
```

- we can no longer access data[0], since data now is a dictionary - whose two entries in turn are dataset.Dataset objects
- data['train'], for example, is:

```text
Dataset({
    features: ['image', 'label'],
    num_rows: 9000
})
```
- and data['train'][0] is a dict object, consisting of 'image' and 'label' keys plus their values

In [None]:

# Answer options shown to the model. The list index doubles as the integer label:
# index 0 -> "A: no tumor present", index 1 -> "B: tumor present".
HISTOPATHOLOGY_CLASSES = [
    "A: no tumor present",
    "B: tumor present",
]

# One option per line, appended to the question to form the user prompt.
options = "\n".join(HISTOPATHOLOGY_CLASSES)
PROMPT = "Is a tumor present in this histopathology image?\n" + options

In [None]:

def format_data(example: dict[str, Any]) -> dict[str, Any]:
    """Attach a two-turn chat conversation to one dataset row.

    Args:
        example: A row dict; its integer ``"label"`` is used as an index into
            ``HISTOPATHOLOGY_CLASSES`` (0 -> "A: no tumor present",
            1 -> "B: tumor present").

    Returns:
        The same dict object, now with an extra ``"messages"`` key holding the
        user turn (image placeholder + ``PROMPT``) and the assistant turn
        (the correct option text).
    """
    # User turn: an image placeholder followed by the question and its options.
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": PROMPT},
        ],
    }
    # Assistant turn: the ground-truth answer looked up from the integer label.
    assistant_turn = {
        "role": "assistant",
        "content": [
            {"type": "text", "text": HISTOPATHOLOGY_CLASSES[example["label"]]},
        ],
    }
    # The input dict is modified in place and handed back.
    example["messages"] = [user_turn, assistant_turn]
    return example

In [None]:
# Format a single training row to inspect what format_data() adds (shown in the comparison below).
formatted_example = format_data(data['train'][0])

## Comparison showing the effect of format_data() function:
 
- format_data() adds a key ('messages') and its value (the list of chat messages) to data['train'][0]:

```text
format_data(data['train'][0]):
{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=96x96 at 0x7FA67FC65D50>, 'label': 0, 'messages': [{'role': 'user', 'content': [{'type': 'image'}, {'type': 'text', 'text': 'Is a tumor present in this histopathology image?\nA: no tumor present\nB: tumor present'}]}, {'role': 'assistant', 'content': [{'type': 'text', 'text': 'A: no tumor present'}]}]}

data['train'][1]:
{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=96x96 at 0x7FA662E299F0>, 'label': 1}
```

## Now we format **all** the data:

In [None]:
# Apply format_data to every row of both splits; `data` is rebound to the mapped result.
data = data.map(format_data)

### Map function gives nice output:
```text
Map: 100%|██████████| 9000/9000 [00:00<00:00, 15013.28 examples/s]
Map: 100%|██████████| 1000/1000 [00:00<00:00, 6887.63 examples/s]
```
- data is still a dataset_dict.DatasetDict object
- data['train'] and data['validation'] are still dataset.Dataset objects
- data['train'][0] and data['validation'][0] are still dict objects. And now they have three keys: 'image', 'label', 'messages'

```text
DatasetDict({
    train: Dataset({
        features: ['image', 'label', 'messages'],
        num_rows: 9000
    })
    validation: Dataset({
        features: ['image', 'label', 'messages'],
        num_rows: 1000
    })
})

```

## Load Model

In [None]:
# ---------------------- Loading Model ---------------------- #

model_id = "google/medgemma-4b-it"

# Check if GPU supports bfloat16
# The check below requires a CUDA compute capability major version of at least 8
if torch.cuda.get_device_capability()[0] < 8:
    raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")
else: 
    print('GPU supports bfloat 16. You are good to go :)')

# Keyword arguments forwarded to from_pretrained() below
model_kwargs = dict(
    attn_implementation="eager",
    torch_dtype=torch.bfloat16,
    # optimal device map when using one GPU.
    device_map="auto",
)

# Add a dictionary entry 'quantization_config' - sets the values of 5 parameters in BitsAndBytesConfig() 
# The compute and storage dtypes reuse the bfloat16 dtype chosen above.
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)
# model is assigned the pretrained model (google/medgemma-4b-it) with the specifications (model_kwargs)
# ** unpacks the dictionary values as arguments to the from_pretrained function
model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)

# This is where .apply_chat_template looks back to
processor = AutoProcessor.from_pretrained(model_id)

# Use right padding to avoid issues during training
processor.tokenizer.padding_side = "right"

## Set Up For Fine Tuning

In [None]:
# ----------------------  Set Up for Fine Tuning ---------------------- #
# LoRA adapter configuration; it is handed to SFTTrainer via `peft_config=` further down.
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=8,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    # NOTE(review): presumably these modules are kept fully trainable and saved with the
    # adapter rather than LoRA-wrapped - confirm against the peft documentation.
    modules_to_save=[
        "lm_head",
        "embed_tokens",
    ],
)

- N.B. data['train'][0] is:
```text
{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=96x96 at 0x7FA663C85EA0>, 'label': 0, 'messages': [{'content': [{'text': None, 'type': 'image'}, {'text': 'Is a tumor present in this histopathology image?\nA: no tumor present\nB: tumor present', 'type': 'text'}], 'role': 'user'}, {'content': [{'text': 'A: no tumor present', 'type': 'text'}], 'role': 'assistant'}]}
```
- That is, a dict with: {'image':value, 'label':value, 'messages':value}

In [None]:
# Mapping of special-token names to their string form (a dict, per the output below)
special_tokens = processor.tokenizer.special_tokens_map

# Peek at the first key/value pair of that mapping
first_key, first_value = next(iter(special_tokens.items()))
print(first_key, first_value)

- special_tokens is an object of type dict.
- the first key is bos_token, and its value is `<bos>`

```text
bos_token <bos>
```

## Collate Function

In [None]:
# Step 1. Clone input_ids and assign to labels.
# Step 2. Mask unnecessary info
# Step 3. Add the now redacted info as a new entry in the batch called 'labels'

def collate_fn(examples: list[dict[str, Any]]):
    """Collate formatted examples into one padded batch with masked labels.

    Every example is run through the global ``processor`` on its own; the
    per-example tensors are then padded to a common length. ``labels`` starts
    as a copy of ``input_ids`` and has padding, begin/end-of-image and image
    placeholder ids replaced by -100; every other token id, including the
    user prompt text, is left as-is.

    Args:
        examples: Rows carrying an ``"image"`` (converted to RGB here) and the
            chat ``"messages"`` produced by ``format_data``.

    Returns:
        A dict with ``input_ids``, ``attention_mask``, ``token_type_ids``,
        ``pixel_values`` and ``labels``.
    """
    tokenizer = processor.tokenizer

    # Per-key lists holding one entry per example.
    columns = {
        "input_ids": [],
        "attention_mask": [],
        "token_type_ids": [],
        "pixel_values": [],
    }

    for sample in examples:
        rgb_image = sample["image"].convert("RGB")
        # Render the conversation (prompt with both options + correct answer) as plain text.
        # NOTE(review): the sample batch printed in this notebook begins with token id 2
        # twice, which looks like a duplicated BOS token - confirm whether that is intended.
        chat_text = processor.apply_chat_template(
            sample["messages"],
            add_generation_prompt=False,
            tokenize=False,
        ).strip()

        processed = processor(text=chat_text, images=rgb_image, return_tensors="pt", padding=True)

        # Index 0 takes the single example out of the processor's output.
        for key, collected in columns.items():
            collected.append(processed[key][0])

    # Pad only after all examples are collected, so every row ends up with the same length.
    pad_sequence = torch.nn.utils.rnn.pad_sequence
    input_ids = pad_sequence(columns["input_ids"], batch_first=True, padding_value=tokenizer.pad_token_id)
    attention_mask = pad_sequence(columns["attention_mask"], batch_first=True, padding_value=0)
    token_type_ids = pad_sequence(columns["token_type_ids"], batch_first=True, padding_value=0)
    pixel_values = torch.stack(columns["pixel_values"])

    # ------------------- Label / Masking Step ------------------- #

    # Start from a copy of the inputs, then blank out positions the loss should ignore.
    labels = input_ids.clone()

    token_map = tokenizer.special_tokens_map
    boi_token_id, eoi_token_id = tokenizer.convert_tokens_to_ids(
        [token_map['boi_token'], token_map['eoi_token']]
    )

    # Padding, begin/end-of-image markers and the image placeholder id (262144) are not
    # something the model should learn to predict.
    ids_to_ignore = (tokenizer.pad_token_id, boi_token_id, eoi_token_id, 262144)
    for ignored_id in ids_to_ignore:
        # Boolean-mask assignment: every position equal to ignored_id becomes -100.
        labels[labels == ignored_id] = -100

    # ------------------------------------------------------------- #

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
        "pixel_values": pixel_values,
        "labels": labels,
    }

- 'lonely_batch' is the result of applying the collate function to a single example.
- In 'lonely_batch' (of type dict), we can see that there is a new entry: 'labels'
- 'labels' is identical to 'input_ids' but with certain tokens masked (replaced by -100).

In [None]:
# Collate a single training row (wrapped in a list, as collate_fn expects a list of
# examples) and print the whole resulting batch for inspection.
example = data['train'][0]
lonely_batch = collate_fn([example])
print(type(lonely_batch))
print(lonely_batch)

```text
<class 'dict'>
{'input_ids': tensor([[     2,      2,    105,   2364,    109, 255999, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 256000,    108,   4602,    496,  17491,   1861,    528,    672,
           2441, 118234,   2471, 236881,    107, 236776, 236787,    951,  17491,
           1861,    107, 236799, 236787,  17491,   1861,    106,    107,    105,
           4368,    107, 236776, 236787,    951,  17491,   1861,    106]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]]), 'pixel_values': tensor([[[[ 0.8745,  0.8745,  0.8745,  ...,  0.3255,  0.3255,  0.3255],
          [ 0.8745,  0.8745,  0.8745,  ...,  0.3255,  0.3255,  0.3255],
          [ 0.8745,  0.8745,  0.8745,  ...,  0.3255,  0.3255,  0.3255],
          ...,
          [ 0.6863,  0.6863,  0.6863,  ..., -0.0980, -0.0980, -0.0980],
          [ 0.6863,  0.6863,  0.6863,  ..., -0.0980, -0.0980, -0.0980],
          [ 0.6863,  0.6863,  0.6863,  ..., -0.0980, -0.0980, -0.0980]],

         [[ 0.5451,  0.5451,  0.5451,  ..., -0.1216, -0.1216, -0.1216],
          [ 0.5451,  0.5451,  0.5451,  ..., -0.1216, -0.1216, -0.1216],
          [ 0.5451,  0.5451,  0.5451,  ..., -0.1216, -0.1216, -0.1216],
          ...,
          [ 0.2392,  0.2392,  0.2392,  ..., -0.3725, -0.3725, -0.3725],
          [ 0.2392,  0.2392,  0.2392,  ..., -0.3725, -0.3725, -0.3725],
          [ 0.2392,  0.2392,  0.2392,  ..., -0.3725, -0.3725, -0.3725]],

         [[ 0.6549,  0.6549,  0.6549,  ...,  0.2000,  0.2000,  0.2000],
          [ 0.6549,  0.6549,  0.6549,  ...,  0.2000,  0.2000,  0.2000],
          [ 0.6549,  0.6549,  0.6549,  ...,  0.2000,  0.2000,  0.2000],
          ...,
          [ 0.3882,  0.3882,  0.3882,  ...,  0.1608,  0.1608,  0.1608],
          [ 0.3882,  0.3882,  0.3882,  ...,  0.1608,  0.1608,  0.1608],
          [ 0.3882,  0.3882,  0.3882,  ...,  0.1608,  0.1608,  0.1608]]]]), 'labels': tensor([[     2,      2,    105,   2364,    109,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,    108,   4602,    496,  17491,   1861,    528,    672,
           2441, 118234,   2471, 236881,    107, 236776, 236787,    951,  17491,
           1861,    107, 236799, 236787,  17491,   1861,    106,    107,    105,
           4368,    107, 236776, 236787,    951,  17491,   1861,    106]])}
```

In [None]:
num_train_epochs = 1  # @param {type: "number"}
learning_rate = 2e-4  # @param {type: "number"}

args = SFTConfig(
    output_dir="medgemma-4b-it-sft-lora-PatchCamelyon",            # Directory and Hub repository id to save the model to
    num_train_epochs=num_train_epochs,                       # Number of training epochs
    per_device_train_batch_size=4,                           # Batch size per device during training
    per_device_eval_batch_size=4,                            # Batch size per device during evaluation
    gradient_accumulation_steps=4,                           # Number of steps before performing a backward/update pass
    gradient_checkpointing=True,                             # Enable gradient checkpointing to reduce memory usage
    optim="adamw_torch_fused",                               # Use fused AdamW optimizer for better performance
    logging_steps=50,                                        # Number of steps between logs
    save_strategy="epoch",                                   # Save checkpoint every epoch
    eval_strategy="steps",                                   # Evaluate every `eval_steps`
    eval_steps=50,                                           # Number of steps between evaluations
    learning_rate=learning_rate,                             # Learning rate based on QLoRA paper
    bf16=True,                                               # Use bfloat16 precision
    max_grad_norm=0.3,                                       # Max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                                       # Warmup ratio based on QLoRA paper
    lr_scheduler_type="linear",                              # Use linear learning rate scheduler
    push_to_hub=False,                                       # Keep the model local; do not push it to the Hub
    report_to="tensorboard",                                 # Report metrics to tensorboard
    gradient_checkpointing_kwargs={"use_reentrant": False},  # Set gradient checkpointing to non-reentrant to avoid issues
    dataset_kwargs={"skip_prepare_dataset": True},           # Skip default dataset preparation to preprocess manually
    remove_unused_columns = False,                           # Columns are unused for training but needed for data collator
    label_names=["labels"],                                  # Input keys that correspond to the labels
)

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=data["train"],
    # Use a 200-row subset of the validation set for a faster run. The shuffle is seeded
    # (42, like the other shuffles in this notebook) so the same rows are picked on every run.
    eval_dataset=data["validation"].shuffle(seed=42).select(range(200)),
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
)

# Smoke test: pull one batch through the trainer's dataloader (and therefore collate_fn).
# NOTE(review): no trainer.train() call is visible in this notebook; training is
# presumably launched elsewhere - confirm.
print("Batch test:", next(iter(trainer.get_train_dataloader())))

## Tensor Masking

In [None]:
# Small stand-alone demo of boolean-mask assignment on a tensor.
my_tensor = torch.tensor([2, 3, 4])
print(my_tensor)

# --- Tensor Masking --- #
hide_this = 3
# The comparison yields a boolean tensor; True marks the entries to overwrite.
positions_to_hide = my_tensor == hide_this
my_tensor[positions_to_hide] = 0
print(my_tensor)

In [None]:
# Take the first three examples from data['train'] and collate them into one batch.
# More than one row is needed here: later cells call data_subset.select(range(2)) and
# data_subset.train_test_split(test_size=0.2), which cannot work on a single-row dataset.
data_subset = data['train'].select(range(3))
batch = collate_fn(data_subset)
#print(f'BATCH: ', batch)

# --------------------- Decoding some tokens ----------------------- #

# Cannot decode -100 - keep in mind
# These are some of the token ids that remain un-masked in `labels`.
# NOTE(review): the ids are hard-coded rather than read from `batch`; they appear to be
# copied from a previously printed batch - confirm they still match.
print(processor.tokenizer.decode([108,   4602,    496,  17491,   1861,    528,    672,
           2441, 118234,   2471, 236881,    107, 236776, 236787,    951,  17491,
           1861,    107, 236799, 236787,  17491,   1861,    106,    107,    105,
           4368,    107, 236799, 236787,  17491,   1861,    106]))
print('\n')



## Token Flags check if we need to mask more tokens

In [None]:
# -------------------- Need to mask more tokens! ------------------- #
# Look up the begin-of-image / end-of-image special tokens and their ids.
special_tokens = processor.tokenizer.special_tokens_map
boi_token, eoi_token = special_tokens['boi_token'], special_tokens['eoi_token']
boi_token_id, eoi_token_id = processor.tokenizer.convert_tokens_to_ids([boi_token, eoi_token])

# consider just input_ids
input_ids = batch["input_ids"]

# True when the corresponding id occurs anywhere in the batch's input_ids.
token_flags = {
    flag_name: (input_ids == flag_id).any().item()
    for flag_name, flag_id in (('EOI', eoi_token_id), ('BOI', boi_token_id))
}

for name, found in token_flags.items():
    if found:
        print(f'{name} token found in input ids')
    else:
        print(f'{name} token not found in input ids')

```text
EOI token found in input ids
BOI token found in input ids
```

## Image token representation

In [None]:
# First sequence of the batch as a plain Python list (this rebinds `input_ids`).
input_ids = batch["input_ids"][0].tolist()
# 262144 is the id that collate_fn treats as the image token.
num_image_tokens = sum(1 for token_id in input_ids if token_id == 262144)
print(f"Number of <image_soft_token> tokens: {num_image_tokens}")


```text
Number of <image_soft_token> tokens: 256
```


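- Each image: 96 × 96 pixels
- Number of image tokens: 256  
- Therefore, measured against the original 96 × 96 image:  

  $\frac{96 \times 96}{256} = 36 \text{ pixels per token} \Rightarrow \sqrt{36} = 6, \text{ i.e. a } 6 \times 6 \text{ pixel region per token}$
  
- N.B. this is an *effective* figure, not the vision encoder's literal patch size. According to the Gemma 3 model card, every image is first resized to a fixed 896 × 896 resolution and is then encoded to a fixed 256 tokens, so the token count does not depend on the 96 × 96 input size.  
- Each of these 256 positions is represented in the text sequence by one placeholder token (`<image_soft_token>`, token ID 262144).  
- So:  

  $96 \times 96 \text{ image} \quad \rightarrow \quad 256 \text{ tokens} \quad (\text{each covering roughly a } 6 \times 6 \text{ pixel region of the original image})$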

## Common methods for dataset.Dataset Object

In [None]:
def my_func(input):
    """Return the argument unchanged; used only to demonstrate ``Dataset.map`` below.

    NOTE(review): the parameter name shadows the built-in ``input``.
    """
    return input

# Access rows
print(data_subset[0])

# Select subset of rows
# NOTE: range(2) asks for rows 0 and 1, so data_subset must contain at least two rows.
data_subset_small = data_subset.select(range(2))

# Apply transformation to all rows
data_subset_mapped = data_subset.map(my_func)

# Split into train/val
# NOTE(review): presumably needs more than one row so that neither split is empty - confirm.
split_data = data_subset.train_test_split(test_size=0.2)

# Shuffle rows
shuffled = data_subset.shuffle(seed=42)

## Inspecting Batch

In [None]:
# Compare the length of the first row of input_ids with the first row of attention_mask
first_ids_len = len(batch['input_ids'][0])
first_mask_len = len(batch['attention_mask'][0])
if first_ids_len != first_mask_len:
    print('different length')
else:
    print('same length')

# Print the length of every row of token_type_ids
# (296 for the single-example batch printed earlier in this notebook)
batch_ttis = batch['token_type_ids']
length = len(batch['token_type_ids'])
for row in batch_ttis:
    print(len(row))

```text
same length
```

This means that the input_ids and the attention_mask have the same length

# Evaluating Fine Tuned Model

In [23]:

# Will be evaluating the finetuned model here

from utils import load_model_and_processor
from datasets import load_dataset


# `evaluate.load(...)` is called at the bottom of this cell, so the `evaluate` package
# itself must be imported; without it a fresh kernel fails with a NameError.
import evaluate
# NOTE(review): evaluate_model is not referenced in this cell; the import is kept in case
# it is needed for its import-time side effects - confirm.
import evaluate_model

# model, processor = load_model_and_processor()
# model.eval()

# Load the held-out test folder and keep a fixed (seeded) random sample of 1000 rows.
raw = load_dataset("./patchcamelyon_test")
test_data = raw["train"]
test_data = test_data.shuffle(seed=42).select(range(1000))


# Same answer options and prompt as used for training (repeated here, not imported).
HISTOPATHOLOGY_CLASSES = [
    # One option for each class
    "A: no tumor present",
    "B: tumor present"
]

options = "\n".join(HISTOPATHOLOGY_CLASSES)
PROMPT = f"Is a tumor present in this histopathology image?\n{options}"

def format_test_data(example: dict[str, Any]) -> dict[str, Any]:
    """Attach the user turn only (image placeholder + PROMPT) to one test row.

    Unlike the training formatter there is no assistant turn: the answer is
    what the model has to generate. The input dict is modified in place and
    returned.
    """
    example["messages"] = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                },
                {
                    "type": "text",
                    "text": PROMPT,
                },
            ],
        },
    ]
    return example

test_data = test_data.map(format_test_data)

accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")

# Ground-truth labels that compute_metrics compares predictions against.
REFERENCES = test_data["label"]

print(test_data['label'])



[1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 

In [None]:
def compute_metrics(predictions: list[int]) -> dict[str, float]:
    """Score predictions against the module-level ``REFERENCES``.

    Args:
        predictions: One predicted integer label per test example.

    Returns:
        A single dict merging the accuracy result and the weighted-average
        F1 result.
    """
    accuracy_scores = accuracy_metric.compute(
        predictions=predictions,
        references=REFERENCES,
    )
    f1_scores = f1_metric.compute(
        predictions=predictions,
        references=REFERENCES,
        average="weighted",
    )
    # Merge in the same order as two successive dict.update() calls.
    return {**accuracy_scores, **f1_scores}

In [None]:
from datasets import ClassLabel

# Rename the class names to the answer options, `X: ...`, by casting the "label"
# column to a ClassLabel whose names are HISTOPATHOLOGY_CLASSES.
# NOTE: this rebinds `test_data`.
test_data = test_data.cast_column(
    "label",
    ClassLabel(names=HISTOPATHOLOGY_CLASSES)
)


In [None]:
print(test_data['label'])
# Feature object of the "label" column; postprocess() uses its str2int() lookup.
LABEL_FEATURE = test_data.features["label"]

# Alternative spelling of each option: `X: ...` becomes `(X) ...`
ALT_LABELS = {
    class_name: f"({class_name.replace(': ', ') ')}"
    for class_name in HISTOPATHOLOGY_CLASSES
}


def postprocess(prediction: list[dict[str, str]], do_full_match: bool=False) -> int:
    """Convert one pipeline output into an integer class label.

    Args:
        prediction: Output for one example; the text is read from
            ``prediction[0]["generated_text"]``.
        do_full_match: When true, the whole generated text is passed to
            ``LABEL_FEATURE.str2int``. Otherwise the text is searched for each
            option in either its `X: ...` or `(X) ...` spelling.

    Returns:
        The integer label, or -1 when no option is found in search mode.
    """
    generated = prediction[0]["generated_text"]

    if not do_full_match:
        # First option (in list order) whose plain or alternative spelling occurs in the text.
        matched = next(
            (
                candidate
                for candidate in HISTOPATHOLOGY_CLASSES
                if candidate in generated or ALT_LABELS[candidate] in generated
            ),
            None,
        )
        if matched is None:
            return -1
        return LABEL_FEATURE.str2int(matched)

    return LABEL_FEATURE.str2int(generated)

In [None]:
from transformers import pipeline

# Baseline: a pipeline built from `model_id` directly, i.e. without the LoRA checkpoint.
# Relies on `model_id` and `processor` defined in earlier cells.
pt_pipe = pipeline(
    "image-text-to-text",
    model=model_id,
    torch_dtype=torch.bfloat16,
)

# Set `do_sample = False` for deterministic responses
pt_pipe.model.generation_config.do_sample = False
# Use the tokenizer's EOS id as the pad id for generation
pt_pipe.model.generation_config.pad_token_id = processor.tokenizer.eos_token_id

In [None]:


# Run the baseline pipeline over all test rows, returning only the newly generated text.
pt_outputs = pt_pipe(
    text=test_data["messages"],
    images=test_data["image"],
    max_new_tokens=40,
    batch_size=64,
    return_full_text=False,
)

# Search mode (do_full_match defaults to False): responses containing no option become -1.
pt_predictions = [postprocess(out) for out in pt_outputs]


pt_metrics = compute_metrics(pt_predictions)
print(f"Baseline metrics: {pt_metrics}")
     

In [20]:
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration, pipeline
from peft import PeftModel

# Relies on `model_id` from the model-loading cell earlier in the notebook.
base_model_id = model_id
# NOTE(review): the checkpoint directory (step 252) is hard-coded - confirm it matches your run.
lora_check_point_path = './medgemma-4b-it-sft-lora-PatchCamelyon/checkpoint-252'

# Load the base model in bfloat16
base_model = AutoModelForImageTextToText.from_pretrained(
    base_model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# NOTE: this rebinds `model` (and `processor` below), replacing the objects from earlier cells.
model = PeftModel.from_pretrained(base_model, lora_check_point_path)
model = model.merge_and_unload()  # Applies the LoRA weights to the original model
model.eval()

processor = AutoProcessor.from_pretrained(base_model_id)

ft_pipe = pipeline(
    "image-text-to-text",
    model=model,  
    processor=processor,
    torch_dtype=torch.bfloat16,
)

# Optional inference tweaks
ft_pipe.model.generation_config.do_sample = False
ft_pipe.model.generation_config.pad_token_id = processor.tokenizer.eos_token_id
# Set on the same `processor` object that was passed to the pipeline above.
processor.tokenizer.padding_side = "left"

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.30s/it]
Device set to use cuda:0


In [21]:

# Run the fine-tuned pipeline over all test rows, returning only the newly generated text.
ft_outputs = ft_pipe(
    text=test_data["messages"],
    images=test_data["image"],
    max_new_tokens=20,
    batch_size=64,
    return_full_text=False,
)

# Full-match mode: the entire generated text is passed to LABEL_FEATURE.str2int.
# NOTE(review): presumably this raises if a response is not exactly one of the option
# strings (e.g. extra whitespace) - confirm.
ft_predictions = [postprocess(out, do_full_match=True) for out in ft_outputs]

In [24]:

# Accuracy and weighted F1 of the fine-tuned predictions against REFERENCES
ft_metrics = compute_metrics(ft_predictions)
print(f"Fine-tuned metrics: {ft_metrics}")

Fine-tuned metrics: {'accuracy': 0.869, 'f1': 0.86852804562022}
