In [1]:
# Restrict this kernel to GPU index "0" only.
# NOTE(review): this is set before torch is imported (torch comes in a later cell);
# presumably it has to stay ahead of any CUDA initialisation to take effect - keep this cell first.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0" 

# File Structure

```text
patchcamelyon_subset/
├── tumor/
│   ├── img00001.png
│   ├── img00002.png
│   └── ...
└── normal/
    ├── img00001.png
    ├── img00002.png
    └── ...
```

## Imports

In [3]:
# Other
import pandas as pd

# For set up
from datasets import load_dataset
from typing import Any

# For Loading Model
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig

# For fine tuning
from peft import LoraConfig


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Sanity check: report how many GPUs this kernel can see and which one is active.
print("Number of GPUs visible:", torch.cuda.device_count())
active_device = torch.cuda.current_device()
print("Current device:", active_device)
print("GPU name:", torch.cuda.get_device_name(active_device))

Number of GPUs visible: 1
Current device: 0
GPU name: NVIDIA RTX A6000


## Set Up

In [None]:

# ---------------------- Set Up ---------------------- #

# Split sizes: 9,000 training rows + 1,000 validation rows.
train_size = 9000
validation_size = 1000

# ./patchcamelyon_subset is a locally organised subset of PatchCamelyon (first 10K
# images) with one sub-folder per class: "normal" and "tumor".
# Load it as a single split, then carve out a reproducible (seeded) train/held-out split.
data = load_dataset("./patchcamelyon_subset", split="train").train_test_split(
    train_size=train_size,
    test_size=validation_size,
    shuffle=True,
    seed=42,
)
# train_test_split names the held-out part 'test'; expose it as 'validation' instead.
data["validation"] = data.pop("test")

# ------------ Optional: display dataset details ------------ #
print(data)
# Each row is a dict of the form {'image': <PIL image>, 'label': <int>}
print(f"data['train'][0]: {data['train'][0]}")
# Keep the first training image and its label around for a quick visual check.
image = data['train'][0]['image']
label = data['train'][0]['label']
image.save("sample_image.png")
print("Image saved to sample_image.png")
print(label)
# Shows the class names behind the integer labels.
print(data['train'].features['label'])
# ----------------------------------------------------------- #

# Answer options, indexed by the integer class label:
# index 0 -> "A: no tumor present", index 1 -> "B: tumor present".
HISTOPATHOLOGY_CLASSES = ["A: no tumor present", "B: tumor present"]

# One option per line, appended to the question to form the multiple-choice prompt.
options = "\n".join(HISTOPATHOLOGY_CLASSES)
PROMPT = f"Is a tumor present in this histopathology image?\n{options}"


def format_data(example: dict[str, Any]) -> dict[str, Any]:
    """Attach a two-turn chat conversation to a dataset row.

    Args:
        example: A row dict. Only ``example["label"]`` is read; it is used as an
            index into ``HISTOPATHOLOGY_CLASSES``.

    Returns:
        The same dict object (modified in place) with a new ``"messages"`` key:
        a user turn holding an image placeholder plus ``PROMPT``, followed by an
        assistant turn holding the answer option that matches the label.
    """
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": PROMPT},
        ],
    }
    # The expected answer is the option text selected by the row's label.
    answer_text = HISTOPATHOLOGY_CLASSES[example["label"]]
    assistant_turn = {
        "role": "assistant",
        "content": [
            {"type": "text", "text": answer_text},
        ],
    }
    example["messages"] = [user_turn, assistant_turn]
    return example

# Run format_data over every row of every split, then show one converted training row.
data = data.map(format_data)
first_formatted_row = data["train"][0]
print(first_formatted_row)

```text
DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 9000
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 1000
    })
})
data['train'][0]: {'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=96x96 at 0x7F34302CA3E0>, 'label': 0}
Image saved to sample_image.png
0
ClassLabel(names=['normal', 'tumor'], id=None)
Map: 100%|██████████| 9000/9000 [00:00<00:00, 10200.15 examples/s]
Map: 100%|██████████| 1000/1000 [00:00<00:00, 5657.86 examples/s]
{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=96x96 at 0x7F34302CA980>, 'label': 0, 'messages': [{'content': [{'text': None, 'type': 'image'}, {'text': 'Is a tumor present in this histopathology image?\nA: no tumor present\nB: tumor present', 'type': 'text'}], 'role': 'user'}, {'content': [{'text': 'A: no tumor present', 'type': 'text'}], 'role': 'assistant'}]}

```

## Load Model

In [None]:
# ---------------------- Loading Model ---------------------- #

model_id = "google/medgemma-4b-it"

# The model is loaded in bfloat16. This notebook treats a CUDA compute capability
# major version below 8 as "no bfloat16 support" and stops early in that case.
if torch.cuda.get_device_capability()[0] < 8:
    raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")
print('GPU supports bfloat 16. You are good to go :)')

# Keyword arguments forwarded to from_pretrained below.
model_kwargs = {
    "attn_implementation": "eager",
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
}

# 4-bit NF4 quantization with double quantization enabled; the compute and
# storage dtypes both reuse the model dtype chosen above.
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)

model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)

# The processor is what collate_fn later calls for apply_chat_template and for
# tokenizing text / preprocessing images.
processor = AutoProcessor.from_pretrained(model_id)

# Pad on the right to avoid issues during training.
processor.tokenizer.padding_side = "right"

## Set Up For Fine Tuning

In [None]:
# ----------------------  Set Up for Fine Tuning ---------------------- #

# LoRA settings gathered in one dict, then unpacked into LoraConfig.
lora_settings = {
    "r": 16,
    "lora_alpha": 16,
    "lora_dropout": 0.05,
    "bias": "none",
    # Apply adapters to every linear layer.
    "target_modules": "all-linear",
    "task_type": "CAUSAL_LM",
    # NOTE(review): presumably these modules are kept fully trainable and saved
    # alongside the adapters - confirm against the PEFT documentation.
    "modules_to_save": ["lm_head", "embed_tokens"],
}
peft_config = LoraConfig(**lora_settings)

## Collate Function

In [117]:

# Step 1. Clone input_ids and assign to labels.
# Step 2. Mask unnecessary info
# Step 3. Add the now redacted info as a new entry in the batch called 'labels'

def collate_fn(examples: list[dict[str, Any]]):
    """Collate formatted dataset rows into one training batch with loss labels.

    Args:
        examples: Rows that each provide ``"image"`` (an object supporting
            ``.convert("RGB")``) and ``"messages"`` (the chat conversation added
            by ``format_data``).

    Returns:
        The output of ``processor(...)`` for the whole batch, with an extra
        ``"labels"`` entry: a copy of ``input_ids`` in which padding tokens, the
        beginning-of-image token and the image soft tokens are set to -100 so
        they are not used in the loss computation.
    """
    tokenizer = processor.tokenizer

    texts = []
    images = []
    for example in examples:
        # One single-image list per example.
        images.append([example["image"].convert("RGB")])
        # Render the conversation (prompt with both A/B options plus the correct
        # answer) to plain text; tokenization happens in the processor call below.
        texts.append(
            processor.apply_chat_template(
                example["messages"],
                add_generation_prompt=False,
                tokenize=False,
            ).strip()
        )

    # Tokenize the texts and process the images; the result contains 'input_ids'.
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # Labels start out as a copy of the tokenized input.
    labels = batch["input_ids"].clone()

    special_tokens = tokenizer.special_tokens_map

    # (B)eginning (O)f (I)mage token id.
    boi_token_id = tokenizer.convert_tokens_to_ids(special_tokens["boi_token"])

    # Image soft token id: resolve it from the tokenizer rather than hard-coding it.
    # 262144 (the value this notebook previously hard-coded) is kept only as a
    # fallback for tokenizers that do not expose an "image_token" entry.
    image_token = special_tokens.get("image_token")
    if image_token is not None:
        image_token_id = tokenizer.convert_tokens_to_ids(image_token)
    else:
        image_token_id = 262144

    # Tensor masking: overwrite every position holding one of these ids with -100
    # so padding and image-related tokens do not contribute to the loss.
    for ignored_id in (tokenizer.pad_token_id, boi_token_id, image_token_id):
        labels[labels == ignored_id] = -100

    # 'labels' now encodes the target behaviour ("here is an image - is it A or B? A"),
    # with everything that should not be predicted masked out.
    batch["labels"] = labels

    return batch


## Tensor Masking

In [126]:
# Toy demonstration of the boolean-mask assignment that collate_fn uses on the labels.
my_tensor = torch.tensor([2, 3, 4])
print(my_tensor)

# --- Tensor Masking --- #
hide_this = 3
# The comparison yields a boolean tensor; assigning through it overwrites only
# the matching entries, in place.
matches = my_tensor == hide_this
my_tensor[matches] = 0
print(my_tensor)

tensor([2, 3, 4])
tensor([2, 0, 4])


In [120]:
# Take first three examples from data['train']
data_subset = data['train'].select(range(3))
batch = collate_fn(data_subset)
#print(f'BATCH: ', batch)

# --------------------- Decoding some tokens ----------------------- #

# Cannot decode -100 - keep in mind
# These are some of the token ids that remain un-masked
print(processor.tokenizer.decode([108,   4602,    496,  17491,   1861,    528,    672,
           2441, 118234,   2471, 236881,    107, 236776, 236787,    951,  17491,
           1861,    107, 236799, 236787,  17491,   1861,    106,    107,    105,
           4368,    107, 236799, 236787,  17491,   1861,    106]))
print('\n')
# -------------------- Need to mask more tokens? ------------------- #

# Resolve the end-of-image token id here so this cell runs on a fresh kernel
# (Restart & Run All) instead of depending on a cell further down the notebook.
eoi_token_id = processor.tokenizer.convert_tokens_to_ids(
    processor.tokenizer.special_tokens_map["eoi_token"]
)

# consider just input_ids
input_ids = batch["input_ids"]
# True if the end-of-image token occurs anywhere in the batch.
has_eoi = (input_ids == eoi_token_id).any().item()

if has_eoi:
    print("<eoi> token is present in input_ids.")
else:
    print("<eoi> token is NOT present in input_ids.")




Is a tumor present in this histopathology image?
A: no tumor present
B: tumor present<end_of_turn>
<start_of_turn>model
B: tumor present<end_of_turn>


<eoi> token is present in input_ids.


```text
<class 'datasets.arrow_dataset.Dataset'>
<bos><bos><start_of_turn>user


<start_of_image><image_soft_token><image_soft_token>
```

In [None]:
# Token id that this notebook uses for <image_soft_token>.
IMAGE_SOFT_TOKEN_ID = 262144

# Token ids of the first example in the batch, as a plain Python list.
input_ids = batch["input_ids"][0].tolist()
num_image_tokens = input_ids.count(IMAGE_SOFT_TOKEN_ID)
print(f"Number of <image_soft_token> tokens: {num_image_tokens}")


```text
Number of <image_soft_token> tokens: 256
```


- Each image: 96 × 96 pixels
- Number of image tokens: 256  
- Therefore:  

  $\frac{96 \times 96}{256} = 36 \text{ pixels per token} \Rightarrow \sqrt{36} = 6$, i.e. each token corresponds to a $6 \times 6$ pixel region of the original image
  
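- Put differently, the image can be thought of as divided into 256 patches (a 16×16 grid), each covering 6×6 pixels of the original image.  
- These patches are flattened, encoded, and each gets represented by one token (`<image_soft_token>`, token ID 262144). **Caveat:** this is the *effective* coverage per token, not necessarily the vision encoder's real patch size — the Gemma 3 documentation describes images being resized to 896×896 before encoding and always condensed to 256 tokens, so the token count does not depend on the 96×96 input size (TODO: confirm against the model card).  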
- So:  

  $96 \times 96 \text{ image} \quad \rightarrow \quad 256 \text{ tokens} \quad (\text{each representing a 6×6 pixel patch})$

In [None]:

# Print the sequence length of each example's token_type_ids row in the batch.
batch_ttis = batch['token_type_ids']
length = len(batch_ttis)
for row in batch_ttis:
    print(len(row))



In [111]:
# Look up the end-of-image special token string, then convert it to its integer id.
eoi_token = processor.tokenizer.special_tokens_map["eoi_token"]
eoi_token_id = processor.tokenizer.convert_tokens_to_ids(eoi_token)
print("EOI token ID:", eoi_token_id)

EOI token ID: 256000


In [110]:
# Show every special token the tokenizer defines, then the beginning-of-image
# token string and the id of '<start_of_image>'.
special_tokens = processor.tokenizer.special_tokens_map
print(special_tokens)
print(special_tokens['boi_token'])
print(processor.tokenizer.convert_tokens_to_ids('<start_of_image>'))

{'bos_token': '<bos>', 'eos_token': '<eos>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'boi_token': '<start_of_image>', 'eoi_token': '<end_of_image>', 'image_token': '<image_soft_token>'}
<start_of_image>
255999


## Common methods for `datasets.Dataset` objects

In [None]:
def my_func(input):
    """Identity transform: return the argument unchanged (placeholder for the map demo)."""
    unchanged = input
    return unchanged

# Access rows: integer indexing gives back a single row.
first_row = data_subset[0]
print(first_row)

# Select subset of rows (here the first two).
first_two_indices = range(2)
data_subset_small = data_subset.select(first_two_indices)

# Apply transformation to all rows (my_func returns each row unchanged).
data_subset_mapped = data_subset.map(my_func)

# Split into train/val, holding out 20% of the rows.
split_data = data_subset.train_test_split(test_size=0.2)

# Shuffle rows with a fixed seed.
shuffled = data_subset.shuffle(seed=42)

In [None]:
# Check that the first example's input_ids and attention_mask are equally long.
first_ids = batch['input_ids'][0]
first_mask = batch['attention_mask'][0]
lengths_match = len(first_ids) == len(first_mask)
print('same length' if lengths_match else 'different length')

# Sequence length of the first example's input_ids ...
print(len(first_ids))

# ... and of its token_type_ids.
print(len(batch['token_type_ids'][0]))

```text
same length
```

This means that the input_ids and the attention_mask have the same length

In [None]:






# ---------------- scratch work ------------- #
# tensor_a = torch.rand(5)
# print(tensor_a)
# print(len(tensor_a))