## Install dependencies

In [None]:
! pip install -q webdataset matplotlib open_clip_torch img2dataset peft transformers

## Download and reconstruct the GAIA dataset

In [None]:
# Project repository: https://github.com/Orion-AI-Lab/GAIA
# Dataset page:       https://huggingface.co/datasets/azavras/GAIA

# Download the `azavras/GAIA` dataset repository from the Hugging Face Hub
# into the local `GAIA/` directory (relative to the working directory).
!huggingface-cli download azavras/GAIA --repo-type dataset --local-dir GAIA

In [None]:
# Fetch the test-split images listed in the metadata JSON and pack them into
# WebDataset shards under `test/`, keeping the `id` and `captions` columns in
# each sample's JSON. The metadata file is read from `GAIA/`, i.e. the
# `--local-dir` used by the `huggingface-cli download` step.
!img2dataset --url_list "GAIA/GAIA_test_data.json" \
             --input_format "json" \
             --url_col "image_src" \
             --caption_col "image_alt" \
             --output_format "webdataset" \
             --save_additional_columns "['id','captions']" \
             --output_folder "test/" \
             --processes_count 4 \
             --thread_count 4 \
             --retries 5 \
             --image_size 512 \
             --encode_format "png" \
             --encode_quality 9 \
             --resize_mode "keep_ratio" \
             --number_sample_per_shard 512 \
             --disallowed_header_directives '[]'

## Image Classification

In [None]:
import matplotlib.pyplot as plt
import open_clip
import torch
import torch.nn.functional as F
import webdataset as wds


# Shards written by the img2dataset step (`--output_folder "test/"`), relative
# to the notebook's working directory. Adjust if the shards live elsewhere.
TEST_SHARDS = "test/{00000..00015}.tar"
# Maximum number of predictions to display.
TOP_K = 5
# Scale factor applied to the cosine similarities before the softmax.
LOGIT_SCALE = 100.0

# Initialize model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32-quickgelu",
    pretrained="openai",
    device=device,
)
model.eval()
tokenizer = open_clip.get_tokenizer("ViT-B-32-quickgelu")

# Load the test dataset; shard shuffling is disabled so that the "first"
# sample is the same on every run.
dataset = (
    wds.WebDataset(TEST_SHARDS, shardshuffle=False)
    .decode("pil")
    .to_tuple("png", "txt", "json")
)

# Get first image
image, text, metadata = next(iter(dataset))
print(text)

# Preprocess image for model (add a leading batch dimension)
image_input = preprocess(image).unsqueeze(0).to(device)

# Candidate classes, phrased as text prompts, to classify against
classes = [
    "a satellite image of Europe",
    "a satellite image of Africa",
    "a satellite image of Asia",
    "a satellite image of South America",
    "a satellite image of North America",
    "a satellite image of Australia",
    "a satellite image of Antarctica",
]

# Never request more predictions than there are classes.
top_k = min(TOP_K, len(classes))

with torch.no_grad():
    # Encode image and text prompts into L2-normalised embeddings
    image_features = F.normalize(model.encode_image(image_input), dim=-1)
    text_tokens = tokenizer(classes).to(device)
    text_features = F.normalize(model.encode_text(text_tokens), dim=-1)

    # Scaled cosine similarities -> probabilities over the classes
    similarity = (LOGIT_SCALE * image_features @ text_features.T).softmax(dim=-1)
    values, indices = similarity[0].topk(top_k)

# Convert to plain Python values for plotting / list indexing
top_probs = values.cpu().tolist()
top_labels = [classes[i] for i in indices.cpu().tolist()]

# Display image and top predictions side by side
fig, (ax_image, ax_bars) = plt.subplots(1, 2, figsize=(10, 5))

# Show image
ax_image.imshow(image)
ax_image.set_title("Input Image")
ax_image.axis("off")

# Show predictions
y_pos = list(range(top_k))
ax_bars.barh(y_pos, top_probs)
ax_bars.set_yticks(y_pos)
ax_bars.set_yticklabels(top_labels)
# barh draws position 0 at the bottom; flip so the best prediction is on top.
ax_bars.invert_yaxis()
ax_bars.set_xlabel("Confidence")
ax_bars.set_title(f"Top {top_k} Predictions")

fig.tight_layout()
plt.show()

## Image Retrieval

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

import open_clip
import torch
import torch.nn.functional as F
import webdataset as wds


# Initialize model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32-quickgelu",
    pretrained="openai",
    device=device,
)
model.eval()
tokenizer = open_clip.get_tokenizer("ViT-B-32-quickgelu")

# Load and process test dataset
dataset = (
    wds.WebDataset("/mnt/nvme1/azavras/GAIA/wds/test/{00000..00015}.tar")
    .decode("pil")
    .to_tuple("png", "txt", "json")
)

# Process all images
all_images = []
all_raw_images = []  # Store raw PIL images for display
for images, _, _ in dataset:
    # Store raw image before preprocessing
    all_raw_images.append(images)
    # Preprocess for model
    processed = preprocess(images).unsqueeze(0)
    all_images.append(processed)

# Encode all images
images = torch.cat(all_images).to(device)
with torch.no_grad():
    image_embeds = F.normalize(model.encode_image(images), dim=-1)

# Example text query
text_query = "a satellite image of a river"

# Encode text query
text_tokens = tokenizer([text_query]).to(device)
with torch.no_grad():
    text_embed = F.normalize(model.encode_text(text_tokens), dim=-1)

# Compute similarities and get top 5
scores = text_embed @ image_embeds.T
top_scores, top_indices = scores[0].topk(5)

# Display results
plt.figure(figsize=(15, 3))
for i, (score, idx) in enumerate(zip(top_scores, top_indices), 1):
    plt.subplot(1, 5, i)
    plt.imshow(all_raw_images[idx])
    plt.title(f"Score: {score:.3f}")
    plt.axis("off")
plt.suptitle(f"Top 5 images for query: '{text_query}'")
plt.tight_layout()
plt.show()

## Image Captioning

In [None]:
import matplotlib.pyplot as plt
import torch
import webdataset as wds
from peft import PeftModel
from transformers import Blip2ForConditionalGeneration, Blip2Processor


# Shards written by the img2dataset step (`--output_folder "test/"`), relative
# to the notebook's working directory. Adjust if the shards live elsewhere.
TEST_SHARDS = "test/{00000..00015}.tar"
BASE_MODEL_ID = "Salesforce/blip2-opt-2.7b"

# Pick device and a matching dtype: half precision only when running on GPU,
# full precision on the CPU fallback.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
dtype = torch.float16 if use_cuda else torch.float32

# Load processor from original model
processor = Blip2Processor.from_pretrained(BASE_MODEL_ID)

# Load base model
base_model = Blip2ForConditionalGeneration.from_pretrained(
    BASE_MODEL_ID, torch_dtype=dtype
)
base_model.to(device)

# Load fine-tuned adapter.
# NOTE: "..." is a placeholder; set it to the adapter checkpoint to load.
checkpoint_path = "..."
model = PeftModel.from_pretrained(base_model, checkpoint_path)
model.eval()

# Load the test dataset; shard shuffling is disabled so that the "first"
# sample is the same on every run.
dataset = (
    wds.WebDataset(TEST_SHARDS, shardshuffle=False)
    .decode("pil")
    .to_tuple("png", "txt", "json")
)

# Get first image
image, text, metadata = next(iter(dataset))

# Process image for BLIP-2 (cast floating-point inputs to the model dtype)
inputs = processor(image, return_tensors="pt").to(device, dtype)

# Generate caption with beam search
with torch.no_grad():
    generated_ids = model.generate(
        **inputs,
        max_length=120,
        num_beams=5,
    )

# Decode caption
caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

# Display image with the first ground-truth caption and the generated one
fig, ax = plt.subplots(figsize=(10, 8))
ax.imshow(image)
ax.set_title(
    f"Ground truth caption: {metadata['captions'][0]}\n\nGenerated Caption: {caption}",
    wrap=True,
)
ax.axis("off")
fig.tight_layout()
plt.show()