## Task 2 - Fine-Tuning CLIP's Visual Encoder

In [None]:
#@title GPU / Python / Torch sanity
# Record the runtime environment in the notebook output: Python version,
# the CUDA version torch was built against, the torch version, and the
# active GPU name (or "CPU" when CUDA is unavailable).
import os, sys, subprocess, json, platform, torch
print("Python :", sys.version)
print("CUDA   :", torch.version.cuda)
print("Torch  :", torch.__version__)
print("Device :", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
# IPython shell escape (not plain Python). `|| true` forces a zero exit
# status when nvidia-smi is missing, e.g. on a CPU-only runtime.
!nvidia-smi || true

In [None]:
# some imports
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from transformers import CLIPProcessor, CLIPModel, AutoProcessor, CLIPVisionModel, logging
import clip
from peft import LoraConfig, get_peft_model, TaskType
from torchinfo import summary
from tqdm.autonotebook import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
import json
import warnings
import torch.nn.functional as F
from PIL import Image
import os
from datasets import load_dataset
import io
from torchvision.datasets import Flowers102
from torch.utils.data import DataLoader

In [None]:
# 降噪：避免 tokenizers 在多工情境下噴警告
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# some settings
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID = "openai/clip-vit-large-patch14" # pre-trained CLIP model (ViT-L/14)
BATCH_SIZE = 64 # adjust based on your GPU memory
NUM_WORKERS = 2
gradient_accumulation_steps = 1 # adjust based on your GPU memory
# For Linear Probe & LoRA
NUM_EPOCHS = 1
print(f"Using device: {DEVICE}")


processor = CLIPProcessor.from_pretrained(MODEL_ID)
model     = CLIPModel.from_pretrained(MODEL_ID).to(DEVICE)
model.eval()




In [None]:
# ==== Flowers102 (torchvision) ====


def clip_image_transform(image):
    """Turn one image into the pixel tensor CLIP expects.

    Anything that is not already a ``PIL.Image.Image`` is first passed
    through ``Image.fromarray``; the picture is then forced to RGB and fed
    to the shared CLIP ``processor``.

    Returns:
        The single entry of the processor's ``pixel_values`` batch, a
        (3, 224, 224) tensor for the ViT-L/14 processor used here.
    """
    pil_img = image if isinstance(image, Image.Image) else Image.fromarray(image)
    encoded = processor(images=pil_img.convert("RGB"), return_tensors="pt")
    return encoded["pixel_values"][0]

# Flowers102 test split from torchvision, downloaded into ./data on first use.
# `clip_image_transform` is applied to each image as it is loaded.
flowers102_test_dts = Flowers102(
    root="./data",
    split="test",
    transform=clip_image_transform,
    download=True
)
print(f"Total test samples (Flowers102): {len(flowers102_test_dts)}")  # 6149

# Ordered (non-shuffled) loader for evaluation.
# NOTE(review): pin_memory=True only has an effect when DEVICE is "cuda".
flowers102_test_loader = DataLoader(
    flowers102_test_dts,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True
)

# Class names, read from a user-supplied cat_to_name.json whose keys are the
# strings "1".."102".
# NOTE(review): the resulting list is 0-indexed (entry i is category
# str(i + 1)); it lines up with dataset labels only if torchvision's
# Flowers102 labels are 0-based — verify.
with open("./data/cat_to_name.json", "r") as f:
    cat_to_name = json.load(f)
flowers102_class_names = [cat_to_name[str(i)] for i in range(1, 103)]


In [None]:
# ==== CUB-200-2011 (HF datasets) ====

# Download / load CUB-200-2011 from the HF hub (cached under ./data) and
# select its test split for evaluation.
birds_200 = load_dataset("bentrevett/caltech-ucsd-birds-200-2011", cache_dir="./data")
cub_bird_test_raw = birds_200["test"]
print(f"Total test samples (CUB): {len(cub_bird_test_raw)}")  # 5794

def _to_pil_image_safe(img):
    """Robustly convert an HF dataset image value into a ``PIL.Image.Image``.

    Supported inputs:
      * ``PIL.Image.Image`` -- returned as-is.
      * ``np.ndarray`` -- passed to ``Image.fromarray``.
      * ``dict`` with ``"bytes"`` and/or ``"path"`` keys -- the undecoded
        record form of an HF ``Image`` feature; non-empty bytes take
        precedence over the path.
      * ``list``/``tuple`` of ints -- raw encoded file bytes.
      * other ``list``/``tuple`` -- nested pixel data (e.g. HWC), cast to uint8.
      * ``bytes``/``bytearray`` -- raw encoded file bytes.

    Raises:
        TypeError: for unsupported types, an empty list/tuple, a dict with
            neither bytes nor path, or nested lists that cannot form a uint8
            array (e.g. ragged rows).
    """
    if isinstance(img, Image.Image):
        return img
    if isinstance(img, np.ndarray):
        return Image.fromarray(img)
    if isinstance(img, dict):
        # Undecoded HF image record: {"bytes": <encoded file or None>, "path": <str or None>}.
        raw = img.get("bytes")
        if raw:
            return Image.open(io.BytesIO(raw))
        path = img.get("path")
        if path:
            return Image.open(path)
        raise TypeError("Image dict has neither 'bytes' nor 'path'")
    # list may be list-of-ints (bytes) or nested lists (HWC)
    if isinstance(img, (list, tuple)):
        if not img:
            # all() is True for an empty sequence, which would otherwise send
            # b"" to PIL and fail with an unhelpful decode error.
            raise TypeError("Cannot convert an empty list/tuple to a PIL image")
        # list of ints -> bytes
        if all(isinstance(x, (int, np.integer)) for x in img):
            return Image.open(io.BytesIO(bytes(img)))
        # else try to convert to array (may raise for ragged input)
        try:
            arr = np.asarray(img, dtype=np.uint8)
            return Image.fromarray(arr)
        except Exception as ex:
            raise TypeError(f"Cannot convert list image to PIL (ragged?): {ex}") from ex
    if isinstance(img, (bytes, bytearray)):
        return Image.open(io.BytesIO(img))
    raise TypeError(f"Unsupported image type for conversion to PIL: {type(img)}")

def cub_transform(example):
    """HF ``with_transform`` hook that turns CUB images into CLIP pixel arrays.

    Handles both a single example (``example["image"]`` is one image) and a
    batch dict (``example["image"]`` is a list/tuple of images). Pixel values
    come back as NumPy (C, H, W) arrays instead of tensors so HF formatting
    and DataLoader collation stay stable.

    Returns:
        ``{"pixel_values": ..., "label": ...}`` -- lists for a batch, single
        values for one example. Other input columns are not carried over.
    """
    def encode(raw_img):
        rgb = _to_pil_image_safe(raw_img).convert("RGB")
        pixels = processor(images=rgb, return_tensors="pt")["pixel_values"][0]  # (3,224,224)
        return pixels.numpy()

    is_batch = isinstance(example, dict) and isinstance(example.get("image"), (list, tuple))
    if not is_batch:
        return {"pixel_values": encode(example["image"]), "label": example["label"]}

    return {
        "pixel_values": [encode(raw) for raw in example["image"]],
        "label": list(example["label"]),
    }

# Attach cub_transform so it runs whenever rows are read from this dataset
# (presumably lazily, on access — per HF `with_transform` semantics).
cub_bird_test_dts = cub_bird_test_raw.with_transform(cub_transform)

# Robust collate: stack every sample into (B,3,224,224) and labels into (B,).
def _to_chw224(x: torch.Tensor) -> torch.Tensor:
    x = torch.as_tensor(x)
    if not torch.is_floating_point(x):
        x = x.float() / 255.0
    if x.ndim == 3:
        # HWC -> CHW
        if x.shape[-1] == 3 and x.shape[0] != 3:
            x = x.permute(2, 0, 1)
        # 灰階擴通道
        if x.shape[0] == 1:
            x = x.repeat(3, 1, 1)
        elif x.shape[0] != 3:
            raise ValueError(f"Unexpected channel dim: {x.shape}")
    elif x.ndim == 2:
        x = x.unsqueeze(0).repeat(3, 1, 1)
    else:
        raise ValueError(f"Unexpected ndim {x.ndim} for image with shape {tuple(x.shape)}")
    if x.shape[1:] != (224, 224):
        x = F.interpolate(x.unsqueeze(0).float(), size=(224, 224),
                          mode="bilinear", align_corners=False).squeeze(0)
    return x.float()

def hf_collate_fn(batch):
    """Collate HF sample dicts into a single training/eval batch.

    Every sample's ``pixel_values`` is normalized by ``_to_chw224`` and
    stacked; every ``label`` is cast to ``int``.

    Returns:
        ``{"pixel_values": float tensor (B, 3, 224, 224),
           "label": long tensor (B,)}``
    """
    pixel_stack = torch.stack([_to_chw224(sample["pixel_values"]) for sample in batch])
    label_vec = torch.tensor([int(sample["label"]) for sample in batch], dtype=torch.long)
    return {"pixel_values": pixel_stack, "label": label_vec}

# NOTE(review): DataLoader is already imported in the imports cell; this
# re-import is redundant but harmless.
from torch.utils.data import DataLoader
# Ordered evaluation loader. The custom collate is needed because
# cub_transform yields NumPy pixel arrays inside dicts.
cub_bird_test_loader = DataLoader(
    cub_bird_test_dts,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    collate_fn=hf_collate_fn
)

# Class names come from the HF dataset's "label" feature.
cub_bird_class_names = cub_bird_test_raw.features["label"].names

import re
def clean_cub_name(name: str) -> str:
    """Make a raw CUB label human-readable.

    ``"001.Black_footed_Albatross"`` becomes ``"Black footed Albatross"``:
    a leading ``<digits>.`` prefix is dropped if present, every underscore
    becomes a space, and surrounding whitespace is stripped.
    """
    numeric_prefix = re.match(r'\d+\.', name)
    remainder = name[numeric_prefix.end():] if numeric_prefix else name
    return remainder.replace('_', ' ').strip()

# Replace raw CUB labels (e.g. "001.Black_footed_Albatross") with readable
# names, then print the first ten as a sanity check.
cub_bird_class_names = [clean_cub_name(n) for n in cub_bird_class_names]
for x in cub_bird_class_names[:10]:
    print(x)

### Linear Probing

Freeze the CLIP visual encoder and train only a classifier head on top of its features.

In [None]:
print("--- Starting Method: Linear Probing ---")

# Linear probing: freeze the CLIP vision backbone and projection, and train
# only a classifier head. This cell is an unfinished exercise template; the
# NOTE(review) comments list what must be filled in before it can run.

# === 1. Load CLIP Vision Model (no text part) ===
# NOTE(review): placeholder. Steps 2+ expect `vision_model` and
# `visual_projection` to exist, and neither is defined in the visible cells.
# Presumably they are the vision tower and visual projection of a CLIP model
# loaded from MODEL_ID — TODO confirm.
model = ...

# === 2. Freeze backbone ===
# Disable gradients so that only the head below receives updates.
for p in vision_model.parameters():
    p.requires_grad = False
for p in visual_projection.parameters():
    p.requires_grad = False

# === 3. Classifier head ===
# NOTE(review): `head = # ...` is a SyntaxError as written, so the whole cell
# fails to parse until the right-hand side is filled in.
head = # ...
    
# === 4. Training setup ===
# NOTE(review): `object` and `lr` are placeholders; the optimizer needs the
# head's parameters and a learning rate, and `lr` is not defined in any
# visible cell.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(object, lr=lr)
train_losses, val_losses, val_accuracies = [], [], []

# === 5. Training Loop ===
# NOTE(review): `train_loader` is not defined in the visible cells (only test
# loaders are built above), and `epoch_start` / `epoch_end` are never set.
for epoch in range(NUM_EPOCHS):
    head.train()
    """
    ...
    """
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Train]"):
        pass
        """
        ...
        """
    
    # NOTE(review): the [-1] lookups raise IndexError unless the loop above
    # appends to train_losses / val_losses / val_accuracies every epoch.
    print(f"Epoch {epoch+1} - Train Loss: {train_losses[-1]:.4f} | "
          f"Val Loss: {val_losses[-1]:.4f} | Val Acc: {val_accuracies[-1]*100:.2f}% | "
          f"Time: {epoch_end - epoch_start:.2f} sec")

# === 6. Plot curves ===


# === 7. Test ===


# === 8. Visualization ===


In [None]:
print("--- Starting Method: LoRA Fine-Tuning ---")

# LoRA fine-tuning: inject low-rank adapters into the vision encoder's
# attention (Q/V projections), keep the visual projection frozen, and train
# the adapters together with a classifier head. This cell is an unfinished
# exercise template; see the NOTE(review) comments.

# === 1. Load CLIP Vision Model (no text part) ===
# NOTE(review): placeholder. Later steps use `vision_model` and
# `visual_projection`, which are not defined in the visible cells.
model = ...

# === 2. LoRA config (Q/V projections) ===
# NOTE(review): "..." is a placeholder argument. Presumably this should set
# the rank / alpha / dropout and target the attention query/value projection
# modules (named "q_proj" / "v_proj" in HF CLIP) — verify against the model.
lora_config = LoraConfig(
    "..."
)

# === 3. Wrap with PEFT ===
vision_model_lora = get_peft_model(vision_model, lora_config)
print("LoRA Model - Trainable Parameters:")
vision_model_lora.print_trainable_parameters()

# === 4. Freeze projection ===
for p in visual_projection.parameters():
    p.requires_grad = False

# === 5. Training setup ===
# NOTE(review): `object` and `lr` are placeholders; the optimizer presumably
# needs both the trainable LoRA parameters and the head's parameters.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(object, lr=lr)
train_losses, val_losses, val_accuracies = [], [], []

# === 6. Training Loop ===
# NOTE(review): `head` is not created in this cell (it is only assigned in the
# Linear Probing template), `train_loader` is not defined in the visible
# cells, and `epoch_start` / `epoch_end` are never set.
for epoch in range(NUM_EPOCHS):
    head.train()
    """
    ...
    """
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Train]"):
        pass
        """
        ...
        """
    
    # NOTE(review): the [-1] lookups raise IndexError unless the loop above
    # appends to train_losses / val_losses / val_accuracies every epoch.
    print(f"Epoch {epoch+1} - Train Loss: {train_losses[-1]:.4f} | "
          f"Val Loss: {val_losses[-1]:.4f} | Val Acc: {val_accuracies[-1]*100:.2f}% | "
          f"Time: {epoch_end - epoch_start:.2f} sec")

# === 7. Plot curves ===


# === 8. Test ===


# === 9. Visualization ===
