In [None]:
import os
import pickle
import random

import torch
from transformers import Blip2Processor

from lib.easy_vqa.easyvqa_classification import EasyVQAClassification
from lib.models.feature_visualizer import FeatureVisualizer
from lib.trainers.classification_trainer import ClassificationTrainer
from lib.types import (
    DatasetTypes,
    HFRepos,
    ModelTypes,
    Suffix,
    TrainingParameters,
    VQAParameters,
)
from lib.utils import EXPERIMENT
from lib.visualization import (
    display_class_specific_images,
    show_image,
)

%load_ext autoreload
%autoreload 2

# Load dependencies
processor = Blip2Processor.from_pretrained(HFRepos.BLIP2_OPT)

DIR = "analysis/easyvqa/"
EXPERIMENT.set_seed(2024).apply_seed()

# Results for EasyVQA

In [None]:
# Build the training-split dataset; the shared processor is attached to the
# parameters before the dataset is constructed from them.
train_vqa_args = VQAParameters(Suffix.Train)  # using combined dataset
train_vqa_args.processor = processor
dataset = EasyVQAClassification(train_vqa_args)

In [None]:
split = "train"
run_id = "1780639714"  # classifier run whose dumped features are visualized

# Features and labels dumped by the classifier for this split.
best_path = f"data/models/easy_vqa/classifier/{run_id}/features_{split}.pkl"  # classifier outputs
# NOTE: pickle.load can execute arbitrary code; only load files produced by our own runs.
# `with` closes the file handle instead of leaking it until garbage collection.
with open(best_path, "rb") as features_file:
    data = pickle.load(features_file)
features = data["features"]
labels = data["labels"]

feature_visualizer = FeatureVisualizer(
    id_to_answer=dataset.id_to_answer, dataset_name="easyvqa"
)
feature_visualizer.set_features(features, labels, split)
feature_visualizer.visualize_features_with_umap(
    save_path=f"{DIR}/5.easyvqa_{split}_{run_id}_features"
)

In [None]:
args = VQAParameters(Suffix.Val)  # using combined dataset
args.processor = processor
dataset = EasyVQAClassification(args)

split = "val"
run_id = "1780639714"  # classifier run whose dumped features are visualized

# Features and labels dumped by the classifier for this split.
best_path = f"data/models/easy_vqa/classifier/{run_id}/features_{split}.pkl"  # classifier outputs
# NOTE: pickle.load can execute arbitrary code; only load files produced by our own runs.
# `with` closes the file handle instead of leaking it until garbage collection.
with open(best_path, "rb") as features_file:
    data = pickle.load(features_file)
features = data["features"]
labels = data["labels"]

feature_visualizer = FeatureVisualizer(
    id_to_answer=dataset.id_to_answer, dataset_name="easyvqa"
)
feature_visualizer.set_features(features, labels, split)
# NOTE(review): unlike the train plot, this filename does not include the run id;
# kept unchanged so existing references to the generated file keep working.
feature_visualizer.visualize_features_with_umap(
    save_path=f"{DIR}/5.easyvqa_{split}_features"
)

# Comparing classes side by side


In [None]:
# Render sample images for the listed answer classes and save them to a PDF.
# `dataset` is whichever EasyVQAClassification was built last (the val split when
# cells are run in order).
class_types = ["no", "yes"]  # Example class types
samples_pdf_path = f"{DIR}/7.easyvqa_class_specific_samples.pdf"
display_class_specific_images(
    dataset.raw_dataset, "EasyVQA", samples_pdf_path, class_types, font_size=24
)

# Live evaluation

This section picks a random image from the test set, runs the trained classifier on it, and prints and displays the predicted answer next to the ground-truth answer.


In [None]:
EXPERIMENT.set_seed(2024).apply_seed()

# Test-split dataset used for the live predictions below.
test_args = VQAParameters(split="test", is_testing=True, use_proportional_split=True)
test_args.processor = processor

dataset = EasyVQAClassification(test_args)

# Inference-only configuration: resume from the saved checkpoint without training
# arguments, optimizer state, or W&B logging.
parameters = TrainingParameters(
    dataset_name=DatasetTypes.EASY_VQA,
    resume_checkpoint=True,
    model_name=ModelTypes.BLIP2Classifier,
    is_trainable=False,
    train_args=None,
    val_args=None,
    test_args=test_args,
    resume_state=False,
    is_testing=True,
    use_wandb=False,
)

module = ClassificationTrainer(parameters)
model = module.model
model.eval()

# Send inputs to wherever the trainer actually placed the model, instead of
# guessing independently: a CPU/GPU mismatch would fail in the forward pass.
device = next(model.parameters()).device

In [None]:
# Pick a random example. randrange excludes len(dataset); randint(0, len(dataset))
# is inclusive on both ends and could index one past the end.
sample = random.randrange(len(dataset))
item = dataset[sample]

# Add a batch dimension of 1 and move every tensor to the model's device.
pixel_values = item["pixel_values"].unsqueeze(0).to(device)
input_ids = item["input_ids"].unsqueeze(0).to(device)
attention_mask = item["attention_mask"].unsqueeze(0).to(device)
labels = item["labels"].unsqueeze(0).to(device)

# Inference only: skip autograd bookkeeping to save memory.
with torch.no_grad():
    outputs = model(
        pixel_values=pixel_values,
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
        log=False,
    )
# Class index = argmax over dim 1 for the logits and for the labels
# (presumably one-hot / soft targets -- TODO confirm).
preds = outputs.logits.argmax(dim=1)
target_pred = labels.argmax(dim=1)

predicted = dataset.id_to_answer[preds.item()]
target = dataset.id_to_answer[target_pred.item()]

print(predicted, target)
show_image(dataset.raw_dataset[sample], predicted, target)