In [None]:
# Restrict the notebook to the first GPU. This is set before `torch` is
# imported below, because CUDA reads the variable when it initialises.
import os 
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# This notebook evaluates the LoRA fine-tuned model on the PatchCamelyon test split.

import torch
import utils  # project-local helper module (provides load_model_and_processor)
from datasets import load_dataset
from datasets import ClassLabel
# NOTE(review): AutoModelForImageTextToText, AutoProcessor and
# PaliGemmaForConditionalGeneration are not used in the cells shown here.
from transformers import AutoModelForImageTextToText, AutoProcessor, PaliGemmaForConditionalGeneration, pipeline
from peft import PeftModel
import evaluate
from typing import Any
# NOTE(review): tqdm is not used in the cells shown here.
from tqdm import tqdm


## Load the test dataset

In [None]:

# Loading a local folder returns a DatasetDict whose only split is 'train'
# (check: len(raw)); here that split holds the PatchCamelyon test images.
raw = load_dataset("./patchcamelyon_test")

# raw['train'] is a datasets.Dataset with 2000 rows (check: len(raw['train'])).
test_data = raw["train"]

# Evaluate on a reproducible random subset: shuffle with a fixed seed, then keep
# the first N_TEST_SAMPLES rows of the shuffled data (NOT the first rows of the
# original file). Set N_TEST_SAMPLES = None to evaluate on every row.
N_TEST_SAMPLES = 10
SEED = 42
if N_TEST_SAMPLES is not None:
    # min() keeps select() in range if the split is smaller than requested.
    test_data = test_data.shuffle(seed=SEED).select(
        range(min(N_TEST_SAMPLES, len(test_data)))
    )


In [None]:

# Answer options offered to the model, one per class. A class's position in
# this list is its integer class id (0 = no tumor, 1 = tumor).
HISTOPATHOLOGY_CLASSES = ["A: no tumor present", "B: tumor present"]

# The options rendered as a newline-separated multiple-choice block (a str).
options = "\n".join(HISTOPATHOLOGY_CLASSES)

# Full text prompt paired with every image: the question, then the options.
PROMPT = "Is a tumor present in this histopathology image?\n" + options


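**TODO:** Work out why the evaluation data is structured this way. Current reading: each example gets a chat-style `messages` column. It holds one user turn containing an image placeholder followed by the text prompt. The `image-text-to-text` pipeline consumes this column together with the separate `image` column.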

In [None]:

# add another list to test_data. test_data['messages'] has type <class: 'list'>
# elements of test_data['messages'] are lists of len = 1
def format_test_data(example: dict[str, Any]) -> dict[str, Any]:
    example["messages"] = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                },
                {
                    "type": "text",
                    "text": PROMPT,
                },
            ],
        },
    ]
    return example


In [None]:

# Metric objects from the `evaluate` library, reused by the metric cells below.
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")

# Add the chat-formatted "messages" column to every row (still a Dataset).
test_data = test_data.map(format_test_data)

# Ground-truth class ids (a list of ints), in the same row order that the
# pipeline predictions are produced in.
REFERENCES = test_data["label"]


In [None]:
# Give the integer labels temporary human-readable names. The list index is
# the class id. Class 0 is "no tumor" (negative) and class 1 is "tumor"
# (positive), which is the same order HISTOPATHOLOGY_CLASSES uses when the
# column is re-cast in the post-processing section. So index 0 must be 'neg'.
test_data = test_data.cast_column('label', ClassLabel(names=['neg', 'pos']))

# Raw class id of the first example (an int).
print(test_data['label'][0])

# test_data.features['label'] is the ClassLabel object; its int2str method
# converts an integer class id into the corresponding class name.
print(test_data.features['label'].int2str(test_data['label'][0]))

# Other handy Dataset attributes: num_columns, num_rows, column_names, shape.

# Per-column feature types, shown as the cell output:
#   'image'    -> datasets.Image
#   'label'    -> datasets.ClassLabel
#   'messages' -> the chat-message list added by format_test_data
test_data.features


## Computing metrics

In [None]:
# Scratch check of how several metric result dicts merge into one dict.
metrics_dict = {}

# `|=` merges another dict's keys into this one (same as dict.update).
example_data = {'new_data': 42}
metrics_dict |= example_data
print('Updated dict:', metrics_dict)

# Toy predictions p scored against references r (only the first two agree).
p = [1, 2, 4]
r = [1, 2, 3]
metrics_dict |= accuracy_metric.compute(predictions=p, references=r)
print("Metrics dict with accuracy metric:", metrics_dict)

metrics_dict |= f1_metric.compute(predictions=p, references=r, average='weighted')
print("Metrics dict with accuracy and f1 metrics:", metrics_dict)


In [None]:
def compute_metrics(predictions: list[int], references: "list[int] | None" = None) -> dict[str, float]:
    """Score predicted class ids against the ground truth.

    Args:
        predictions: Predicted class id per example, e.g. the output of
            ``postprocess``. A value of -1 marks an answer that could not be
            parsed.
        references: Ground-truth class ids, in the same order as
            ``predictions``. Defaults to the notebook-level ``REFERENCES``,
            so existing calls keep working.

    Returns:
        A dict that merges the accuracy result with the weighted-average F1
        result.

    Raises:
        ValueError: If ``predictions`` and ``references`` differ in length.
    """
    if references is None:
        references = REFERENCES
    # Fail early with a clear message. A length mismatch usually means the
    # predictions were computed on a different subset than the labels.
    if len(predictions) != len(references):
        raise ValueError(
            f"Got {len(predictions)} predictions but {len(references)} references."
        )
    metrics: dict[str, float] = {}
    metrics.update(accuracy_metric.compute(
        predictions=predictions,
        references=references,
    ))
    # "weighted" averages the per-class F1 scores by each class's support.
    metrics.update(f1_metric.compute(
        predictions=predictions,
        references=references,
        average="weighted",
    ))
    return metrics

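## Post-processing

Map the model's free-text answers back to integer class ids so they can be scored against the ground-truth labels.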

In [None]:

# Rename the class names to the tissue classes, `X: tissue type`
test_data = test_data.cast_column(
    "label",
    ClassLabel(names=HISTOPATHOLOGY_CLASSES)
)

# datasets.features.features.ClassLabel
# ground truth labels
LABEL_FEATURE = test_data.features["label"]
LABEL_FEATURE.str2int('A: no tumor present')

In [None]:
for label in HISTOPATHOLOGY_CLASSES:
    # Show each option and its "(X) ..." variant on consecutive lines.
    print(label, "(" + label.replace(': ', ') '), sep="\n")

In [None]:

# Maps each canonical option "X: text" to its alternative spelling "(X) text",
# which postprocess also accepts when searching the model's answer.
ALT_LABELS = {
    label: "(" + label.replace(': ', ') ')
    for label in HISTOPATHOLOGY_CLASSES
}


In [None]:

# Converts one pipeline output into a class id. The final evaluation cell calls
# this with do_full_match=True; the default is False (substring search).
def postprocess(prediction: list[dict[str, str]], do_full_match: bool=False) -> int:
    """Map one pipeline output to an integer class id.

    Args:
        prediction: The pipeline output for a single example. Only
            ``prediction[0]["generated_text"]`` is read.
        do_full_match: If True, the whole generated text, ignoring surrounding
            whitespace, must name a class. If False (the default), the text is
            searched for the ``"X: ..."`` or ``"(X) ..."`` form of each class,
            and the first class in HISTOPATHOLOGY_CLASSES order that matches
            wins.

    Returns:
        The class id from LABEL_FEATURE, or -1 if the text cannot be mapped to
        a class. The ground-truth ids are presumably only 0/1, so accuracy
        scores -1 as wrong.
    """
    response_text = prediction[0]["generated_text"]
    if do_full_match:
        # e.g. 'A: no tumor present' -> 0. str2int raises ValueError on text it
        # cannot map to a class. Treat that as unparseable (-1), as the
        # substring path does, instead of aborting the whole evaluation run.
        try:
            return LABEL_FEATURE.str2int(response_text.strip())
        except ValueError:
            return -1
    for label in HISTOPATHOLOGY_CLASSES:
        # Search for `X: text` or `(X) text` in the response.
        if label in response_text or ALT_LABELS[label] in response_text:
            return LABEL_FEATURE.str2int(label)
    return -1


In [None]:
# ----------- Loading Model from Checkpoint ----------- #

# The base model and processor come from the project helper module `utils`.
base_model, processor = utils.load_model_and_processor()

# LoRA adapter checkpoint, presumably written by the SFT fine-tuning run.
lora_check_point_path = './medgemma-4b-it-sft-lora-PatchCamelyon/checkpoint-252'

# Attach the adapter to the base model, then merge the LoRA weights into the
# base weights so inference runs on a plain model without the PEFT wrapper.
model = PeftModel.from_pretrained(base_model, lora_check_point_path).merge_and_unload()

# Switch to evaluation mode (e.g. disables dropout); this is also the cell's
# displayed output.
model.eval()

In [None]:
# -------- Evaluation Pipeline -------- #

# Image + text -> text pipeline wrapping the merged fine-tuned model.
ft_pipe = pipeline(
    task="image-text-to-text",
    model=model,
    processor=processor,
    torch_dtype=torch.bfloat16,
)

# Inference settings: greedy decoding, pad with the EOS token, and left padding
# so that batched prompts all end where generation begins.
ft_pipe.model.generation_config.do_sample = False
ft_pipe.model.generation_config.pad_token_id = processor.tokenizer.eos_token_id
processor.tokenizer.padding_side = "left"

# Run every test example through the model. Per earlier inspection, the result
# is a list with one entry per example. Each entry is a one-element list whose
# dict has 'input_text' and 'generated_text' keys; with return_full_text=False,
# 'generated_text' holds only the newly generated answer.
ft_outputs = ft_pipe(
    text=test_data["messages"],
    images=test_data["image"],
    batch_size=4,
    max_new_tokens=20,
    return_full_text=False,
)

# Preview the first answer (e.g. "A: no tumor present"). postprocess() later
# turns each such entry into a class id.
print(ft_outputs[0][0]['generated_text'])


In [None]:

# Turn each generated answer into a class id, requiring an exact class-name match.
ft_predictions = list(
    postprocess(out, do_full_match=True) for out in ft_outputs
)

# Accuracy and weighted F1 of the fine-tuned model against the ground truth.
ft_metrics = compute_metrics(ft_predictions)
print(f"Fine-tuned metrics: {ft_metrics}")