## Task 1 - Zero-Shot Evaluation

In [1]:
#@title GPU / Python / Torch sanity
# Environment report: prints the interpreter version, the CUDA version torch
# was built against, the torch version, and the name of GPU 0 (or "CPU" when
# CUDA is unavailable).
import os, sys, subprocess, json, platform, torch
print("Python :", sys.version)
print("CUDA   :", torch.version.cuda)
print("Torch  :", torch.__version__)
print("Device :", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
# IPython shell escape; `|| true` keeps the cell from failing when the
# nvidia-smi command is missing or returns a non-zero status.
!nvidia-smi || true

Python : 3.10.19 | packaged by conda-forge | (main, Oct 22 2025, 22:29:10) [GCC 14.3.0]
CUDA   : 12.1
Torch  : 2.3.1+cu121
Device : NVIDIA GeForce RTX 4090
Sun Nov  9 16:28:36 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.144.03             Driver Version: 550.144.03     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        Off |   00000000:01:00.0 Off |                  Off |
| 33%   67C    P2            291W /  450W |   10014MiB /  24564MiB |     58%      Default |
|                                         |                        |        

In [2]:
# some imports
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from transformers import CLIPProcessor, CLIPModel, CLIPVisionModel, logging
import clip
from peft import LoraConfig, get_peft_model, TaskType
from torchinfo import summary
from tqdm.autonotebook import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
import json
import warnings

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Global settings shared by the cells below.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # falls back to CPU when CUDA is unavailable
MODEL_ID = "openai/clip-vit-large-patch14" # pre-trained CLIP model (ViT-L/14)
BATCH_SIZE = 256 # adjust based on your GPU memory
gradient_accumulation_steps = 1 # adjust based on your GPU memory
# For Linear Probe & LoRA
NUM_EPOCHS = 200
print(f"Using device: {DEVICE}")

DATA_FOLDER = "./data"  # folder to store datasets
# Create the data folder up front; exist_ok=True makes re-running this cell safe.
os.makedirs(DATA_FOLDER, exist_ok=True)

Using device: cuda


In [4]:
# CLIP settings
# --- Load CLIP Processor ---
# Loaded for the same checkpoint id (MODEL_ID) that the model cell uses.
processor = CLIPProcessor.from_pretrained(MODEL_ID)
# --- Define a transform to process images for CLIP ---
class CLIPTransform:
    """Callable adapter that turns an image into a CLIP-ready pixel tensor.

    Wraps a processor object so it can be used wherever a single-argument
    transform callable is expected.
    """

    def __init__(self, processor):
        # Keep a handle on the processor that performs the actual preprocessing.
        self.processor = processor

    def __call__(self, image):
        """Run the processor on ``image`` and return its ``pixel_values`` entry.

        The processor is called with ``return_tensors="pt"`` and returns a
        mapping; only the ``"pixel_values"`` tensor is kept.
        """
        encoded = self.processor(images=image, return_tensors="pt")
        pixel_values = encoded["pixel_values"]
        # Drop the leading batch axis the processor adds for a single image.
        return pixel_values.squeeze(0)

# Single shared transform instance built on the processor loaded above.
clip_transform = CLIPTransform(processor)



In [5]:
# dataset related imports
from torchvision.datasets import Flowers102 
from datasets import load_dataset

# --- Flowers102 ---
# Evaluation split. `clip_transform` converts every PIL image into the pixel
# tensor the CLIP vision tower consumes. (The previous `transform=object`
# raised TypeError on the first sample because `object()` accepts no arguments.)
flowers102_test_dts = Flowers102(root=DATA_FOLDER, split="test", transform=clip_transform, download=True) # evaluation on this set
print(f"Total test samples: {len(flowers102_test_dts)}") # should be 6149

# Prepare class names for Flowers102.
with open(os.path.join(DATA_FOLDER, "cat_to_name.json"), "r") as f:
    flowers102_class_names = json.load(f)
if isinstance(flowers102_class_names, dict):
    # Downstream code iterates class names in label order and indexes them by
    # integer label, so a JSON mapping must become a list.
    # assumes keys are the 1-based category ids as strings ("1".."102") and
    # that dataset labels are the same ids shifted to 0-based — TODO confirm
    flowers102_class_names = [
        name
        for _, name in sorted(flowers102_class_names.items(), key=lambda item: int(item[0]))
    ]

# --- CUB-200-2011 ---
birds_200 = load_dataset("bentrevett/caltech-ucsd-birds-200-2011", cache_dir=DATA_FOLDER, download_mode="reuse_dataset_if_exists")


def _cub_to_clip_batch(examples):
    """Convert a batch of raw CUB rows into the keys the evaluation loop reads.

    Args:
        examples: dict of column name -> list of values for the requested rows.

    Returns:
        dict with "pixel_values" (list of CLIP-preprocessed image tensors) and
        "labels" (list of integer class ids).
    """
    # assumes the image column is named "image" and holds PIL images — TODO confirm
    return {
        "pixel_values": [clip_transform(img.convert("RGB")) for img in examples["image"]],
        "labels": examples["label"],
    }


# The transform is applied lazily on access, so images are preprocessed inside
# the DataLoader workers and batches arrive as {"pixel_values", "labels"} dicts.
cub_bird_test_dts = birds_200["test"].with_transform(_cub_to_clip_batch)
print(f"Total test samples: {len(cub_bird_test_dts)}") # should be 5794

# Prepare class names for CUB-200-2011 (already ordered by integer label).
cub_bird_class_names = cub_bird_test_dts.features["label"].names

# === Create DataLoaders ===
flowers102_test_loader = DataLoader(
    flowers102_test_dts, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True
)
cub_bird_test_loader = DataLoader(
    cub_bird_test_dts, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True
)


Total test samples: 6149
Total test samples: 5794


In [None]:
print("--- Starting Method 1: Zero-Shot Classification ---")

# === 1. Load the full CLIP model ===
# Loads the checkpoint named by MODEL_ID, moves it to DEVICE and switches it to
# eval mode for inference.
model = CLIPModel.from_pretrained(MODEL_ID).to(DEVICE)
model.eval()

# === 2. Create and encode text prompts ===
# handcrafted prompts and custom prompts
# Each template has one `{}` placeholder that is filled with a class name; the
# per-template text embeddings are averaged into one embedding per class.
# NOTE(review): "a image" is ungrammatical ("an image"); correcting it would
# change the text embeddings and therefore the reported accuracy — confirm
# before editing.
prompt_templates = [
    "a photo of a {}.",
    "a photo of {}.",
    "photo of {}.",
    "a image of a {}.",
    "a image of {}.",
    "image of {}."
]

@torch.no_grad()
def encode_text_prompts(class_names):
    """Build one prompt-ensembled, L2-normalised text embedding per class.

    For every class name, each template in ``prompt_templates`` is filled in,
    encoded with the CLIP text tower, normalised, averaged over templates and
    normalised again.

    Args:
        class_names: iterable of class-name strings, in label-index order.

    Returns:
        Tensor with one row per class, stacked in the order of ``class_names``.
    """
    all_features = []
    for cname in class_names:
        texts = [t.format(cname) for t in prompt_templates]
        # `model` is a HuggingFace CLIPModel, which has no `encode_text` and is
        # not paired with `clip.tokenize`; tokenise with the matching processor
        # and use `get_text_features` instead.
        inputs = processor(text=texts, return_tensors="pt", padding=True, truncation=True).to(DEVICE)
        feats = model.get_text_features(**inputs)
        feats = feats / feats.norm(dim=-1, keepdim=True)
        # Average over templates, then re-normalise so the class embedding has
        # unit length for the cosine-similarity logits.
        mean_feat = feats.mean(dim=0)
        all_features.append(mean_feat / mean_feat.norm())
    return torch.stack(all_features, dim=0)

# One text-embedding matrix per dataset: Flowers102 first, then CUB-200-2011.
flowers_text_features, cub_text_features = (
    encode_text_prompts(names)
    for names in (flowers102_class_names, cub_bird_class_names)
)


# === 3. Evaluate on the test set ===

@torch.no_grad()
def zeroshot_eval(dataloader, text_features):
    """Compute zero-shot top-1 accuracy of the CLIP model over a dataloader.

    Args:
        dataloader: yields either ``(images, labels)`` tuples or dicts with
            ``"pixel_values"`` and ``"labels"`` entries.
        text_features: normalised class text embeddings, one row per class, in
            label-index order and on the same device as the model.

    Returns:
        Fraction of samples whose highest-scoring class equals the label.

    Raises:
        ZeroDivisionError: if the dataloader yields no samples.
    """
    correct, total = 0, 0
    for batch in tqdm(dataloader, desc="Zero-Shot Evaluation"):
        if isinstance(batch, dict):
            images, labels = batch["pixel_values"], batch["labels"]
        else:
            images, labels = batch
        if isinstance(images, list):
            # Some collate paths hand back a list of per-sample tensors.
            images = torch.stack(images, dim=0)

        images, labels = images.to(DEVICE), labels.to(DEVICE)
        # HuggingFace CLIPModel exposes `get_image_features`; `encode_image`
        # belongs to the OpenAI `clip` package and does not exist here.
        image_features = model.get_image_features(pixel_values=images)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # Scaled cosine similarity between every image and every class prompt.
        logits = 100.0 * image_features @ text_features.T
        preds = logits.argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.numel()
    return correct / total




--- Starting Method 1: Zero-Shot Classification ---


RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# zeroshot_eval is decorated with @torch.no_grad(), so gradient tracking is
# already disabled inside the call.
flowers102_accuracy = zeroshot_eval(
    dataloader=flowers102_test_loader,
    text_features=flowers_text_features,
)

In [None]:
# zeroshot_eval is decorated with @torch.no_grad(), so gradient tracking is
# already disabled inside the call.
cub_bird_accuracy = zeroshot_eval(
    dataloader=cub_bird_test_loader,
    text_features=cub_text_features,
)
    

In [None]:
# === 4. Result Analysis ===

# Accuracies are reported in evaluation order: Flowers102, then CUB-200-2011.
print(
    *(
        f"\nZero-Shot Test Accuracy: {accuracy * 100:.2f}%"
        for accuracy in (flowers102_accuracy, cub_bird_accuracy)
    ),
    sep="\n",
)

# also can do the "classification_report" and "confusion_matrix" here



In [None]:
# === 5. Visualization ===
# use plt to visualize some predictions
 
def visualize_predictions(dataloader, text_features, class_names, num_images=6):
    """Plot the first ``num_images`` samples with ground-truth and predicted class.

    Args:
        dataloader: yields either ``(images, labels)`` tuples or dicts with
            ``"pixel_values"`` and ``"labels"`` entries.
        text_features: normalised class text embeddings, one row per class.
        class_names: sequence of class names indexable by integer label.
        num_images: maximum number of samples to draw (laid out in two rows).

    Returns:
        None. Shows a matplotlib figure; nothing is drawn when no samples are
        available.
    """
    model.eval()
    images_list, labels_list, preds_list = [], [], []
    collected = 0
    with torch.no_grad():
        for batch in dataloader:
            if isinstance(batch, dict):
                images, labels = batch["pixel_values"], batch["labels"]
            else:
                images, labels = batch
            if isinstance(images, list):
                images = torch.stack(images, dim=0)
            # The settings cell defines DEVICE; the lowercase `device` used
            # here before was never defined and raised NameError.
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            # HuggingFace CLIPModel exposes `get_image_features`, not `encode_image`.
            feats = model.get_image_features(pixel_values=images)
            feats = feats / feats.norm(dim=-1, keepdim=True)
            logits = 100.0 * feats @ text_features.T
            preds = logits.argmax(dim=-1)
            images_list.append(images.cpu())
            labels_list.append(labels.cpu())
            preds_list.append(preds.cpu())
            collected += images.shape[0]
            # Stop as soon as enough samples are available instead of always
            # running two full batches through the model.
            if collected >= num_images:
                break
    if not images_list:
        return
    images = torch.cat(images_list)[:num_images]
    labels = torch.cat(labels_list)[:num_images]
    preds = torch.cat(preds_list)[:num_images]

    # The dataset may hold fewer samples than requested; only plot what exists.
    n_shown = images.shape[0]
    if n_shown == 0:
        return
    n_cols = (n_shown + 1) // 2

    plt.figure(figsize=(12, 5))
    for i in range(n_shown):
        plt.subplot(2, n_cols, i + 1)
        # Channels-first -> channels-last for imshow.
        img = images[i].permute(1, 2, 0).float().numpy()
        # Rescale the normalised pixel values to [0, 1] for display only.
        img = (img - img.min()) / (img.max() - img.min() + 1e-6)
        plt.imshow(img)
        plt.axis("off")
        gt_name = class_names[int(labels[i])]
        pred_name = class_names[int(preds[i])]
        plt.title(f"GT: {gt_name}\nPred: {pred_name}", fontsize=8)
    plt.tight_layout()
    plt.show()

# Show a handful of Flowers102 test predictions (default num_images).
visualize_predictions(
    dataloader=flowers102_test_loader,
    text_features=flowers_text_features,
    class_names=flowers102_class_names,
)