# CLIP Fine-tuning for VizWiz

This notebook adapts your previous ViT fine-tuning pipeline to use **CLIP** instead of ViT.

We keep the overall structure similar:
1. Load VizWiz annotations
2. Build a binary classification task: *answerable vs unanswerable*
3. Prepare a PyTorch `Dataset` + `DataLoader`
4. Fine-tune CLIP's **vision encoder + classification head**
5. Evaluate and save the model

> Note: We use HuggingFace `openai/clip-vit-base-patch32` with a custom PyTorch training loop. The HF `Trainer` can also drive a custom `nn.Module`, but a manual loop keeps this small CLIP wrapper easy to follow and modify.

In [58]:
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import numpy as np
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
from transformers import TrainingArguments, Trainer
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score, classification_report
import cv2
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

Using device: cuda
GPU: NVIDIA GeForce RTX 5060 Laptop GPU


## 2. Data Loading and Analysis

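We assume the **VizWiz JSON annotations** are in `data/Annotations/{train,val}.json` (the loader below uses a capital `A`, which matters on case-sensitive file systems) and the images are in `data/train` and `data/val` (same layout as your ViT notebook).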

In [59]:
# Load annotations
def load_vizwiz_annotations(split='train'):
    """Load VizWiz annotations from data/Annotations/{split}.json"""
    with open(f'data/Annotations/{split}.json', 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data

# Load all splits
train_data = load_vizwiz_annotations('train')
val_data = load_vizwiz_annotations('val')

# print("Train data keys:", train_data.keys())
# print(f"Train images: {len(train_data['images'])}")
# print(f"Train annotations: {len(train_data['annotations'])}")
# print(f"\nVal images: {len(val_data['images'])}")
# print(f"Val annotations: {len(val_data['annotations'])}")
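The commented-out prints above assume a COCO-style dict with `images` and `annotations` keys, while the Dataset class below slices each split like a list of records. A minimal stdlib sketch for checking label balance under that list-of-dicts assumption, using the same label rule as the Dataset (`answerable` first, then `is_rejected`). The helper names are hypothetical and the toy records are made up:

```python
from collections import Counter

def label_for(record):
    """Same rule as VizWizAnswerabilityCLIPDataset: prefer `answerable`, else use `is_rejected`."""
    if "answerable" in record:
        return int(record["answerable"])
    return 0 if record.get("is_rejected", False) else 1

def summarize_split(records, max_samples=None):
    """Count answerable (1) vs unanswerable (0) labels in a list-of-dicts split."""
    subset = records[:max_samples] if max_samples is not None else records
    counts = Counter(label_for(r) for r in subset)
    total = sum(counts.values())
    return {
        "total": total,
        "answerable": counts[1],
        "unanswerable": counts[0],
        "answerable_ratio": counts[1] / total if total else 0.0,
    }

# Toy records for illustration only (not real VizWiz data)
toy = [
    {"image": "a.jpg", "question": "What is this?", "answerable": 1},
    {"image": "b.jpg", "question": "What color?", "answerable": 0},
    {"image": "c.jpg", "question": "Read this", "is_rejected": True},
    {"image": "d.jpg", "question": "How many?"},
]
print(summarize_split(toy))
```

In the notebook, `summarize_split(train_data, max_samples=3000)` would describe the exact subset used for training, since `max_samples` keeps the first N records rather than a random sample.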

## 3. Question Type Classifier (Heuristic-based)

We keep your heuristic **question type** function for later analysis.

In [60]:
def classify_question_type(question):
    """Classify question into coarse types using keywords."""
    question_lower = question.lower()
    
    # OCR-related keywords
    ocr_keywords = ['read', 'say', 'text', 'label', 'written', 'writing', 
                    'words', 'screen', 'display', 'says', 'does this say']
    if any(keyword in question_lower for keyword in ocr_keywords):
        return 'OCR_LIKE'
    
    # Color keywords
    color_keywords = ['color', 'colour', 'what color']
    if any(keyword in question_lower for keyword in color_keywords):
        return 'COLOR'
    
    # Count keywords
    count_keywords = ['how many', 'count', 'number of']
    if any(keyword in question_lower for keyword in count_keywords):
        return 'COUNT'
    
    # Direction keywords
    direction_keywords = ['left', 'right', 'top', 'bottom', 'front', 'back',
                         'above', 'below', 'which side']
    if any(keyword in question_lower for keyword in direction_keywords):
        return 'DIRECTION'
    
    # Time keywords
    time_keywords = ['time', 'clock', 'hour', 'minute']
    if any(keyword in question_lower for keyword in time_keywords):
        return 'TIME'
    
    return 'OTHER'

# Quick sanity check
test_questions = [
    "What does this label say?",
    "What color is this shirt?",
    "How many bottles are there?",
    "What is on the left?",
    "What time is it?",
    "What is this?"
]

for q in test_questions:
    print(f"{q:40s} -> {classify_question_type(q)}")

What does this label say?                -> OCR_LIKE
What color is this shirt?                -> COLOR
How many bottles are there?              -> COUNT
What is on the left?                     -> DIRECTION
What time is it?                         -> TIME
What is this?                            -> OTHER
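Because the heuristic matches raw substrings, some keywords fire inside longer words: `'right'` matches "bright", `'back'` matches "background", and `'time'` matches "sometimes". Below is a hypothetical word-boundary variant for comparison (same keyword lists and priority order, using `re`); it is not wired into the rest of the notebook:

```python
import re

# Same keyword lists and priority order as classify_question_type
KEYWORD_RULES = [
    ("OCR_LIKE", ["read", "say", "text", "label", "written", "writing",
                  "words", "screen", "display", "says", "does this say"]),
    ("COLOR", ["color", "colour", "what color"]),
    ("COUNT", ["how many", "count", "number of"]),
    ("DIRECTION", ["left", "right", "top", "bottom", "front", "back",
                   "above", "below", "which side"]),
    ("TIME", ["time", "clock", "hour", "minute"]),
]

# One compiled pattern per type; \b anchors every keyword at word boundaries
COMPILED_RULES = [
    (qtype, re.compile(r"\b(?:" + "|".join(re.escape(k) for k in kws) + r")\b"))
    for qtype, kws in KEYWORD_RULES
]

def classify_question_type_wb(question):
    """Word-boundary version of the keyword heuristic (hypothetical helper)."""
    q = question.lower()
    for qtype, pattern in COMPILED_RULES:
        if pattern.search(q):
            return qtype
    return "OTHER"

for q in ["Is it bright outside?", "What is in the background?", "What does this label say?"]:
    print(f"{q:40s} -> {classify_question_type_wb(q)}")
```

The trade-off is that plurals and inflections ("colors", "labels", "minutes") no longer match unless they are added to the lists explicitly.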


## 4. Image Quality Detection (Optional Features)

We reuse your blur, darkness, and contrast detection utilities. They can be used later for error analysis or as extra input signals.

In [61]:
def detect_blur(image_path, threshold=100):
    """Detect if image is blurry using Laplacian variance."""
    img = cv2.imread(str(image_path))
    if img is None:
        return False
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    variance = cv2.Laplacian(gray, cv2.CV_64F).var()
    return variance < threshold

def detect_darkness(image_path, threshold=50):
    """Detect if image is too dark."""
    img = cv2.imread(str(image_path))
    if img is None:
        return False
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    mean_brightness = np.mean(gray)
    return mean_brightness < threshold

def detect_low_contrast(image_path, threshold=30):
    """Detect if image has low contrast."""
    img = cv2.imread(str(image_path))
    if img is None:
        return False
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    contrast = gray.std()
    return contrast < threshold

def get_image_quality_features(image_path):
    """Get all quality features for an image."""
    return {
        'is_blurry': detect_blur(image_path),
        'is_dark': detect_darkness(image_path),
        'is_low_contrast': detect_low_contrast(image_path)
    }
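To make the three thresholds concrete, here is a dependency-free sketch of the statistics they compare against, computed on a grayscale image given as a list of rows of 0-255 values. The 4-neighbour kernel matches the 3×3 aperture `cv2.Laplacian` uses with its default `ksize=1`, but this sketch only evaluates interior pixels (OpenCV pads the border), so values on real images will differ slightly. Function names are hypothetical:

```python
from statistics import fmean, pstdev, pvariance

def laplacian_variance(gray):
    """Variance of the 4-neighbour Laplacian [[0,1,0],[1,-4,1],[0,1,0]] over interior pixels."""
    h, w = len(gray), len(gray[0])
    responses = []
    for y in range(1, h - 1):
        for x in range(1, w - 1):
            lap = (gray[y - 1][x] + gray[y + 1][x] + gray[y][x - 1] + gray[y][x + 1]
                   - 4 * gray[y][x])
            responses.append(lap)
    return pvariance(responses) if responses else 0.0

def quality_features(gray, blur_thr=100, dark_thr=50, contrast_thr=30):
    """Same decision rules and default thresholds as the three detect_* helpers."""
    pixels = [p for row in gray for p in row]
    return {
        "is_blurry": laplacian_variance(gray) < blur_thr,   # little edge energy
        "is_dark": fmean(pixels) < dark_thr,                # low mean brightness
        "is_low_contrast": pstdev(pixels) < contrast_thr,   # small intensity spread
    }

flat = [[40] * 4 for _ in range(4)]                                          # uniform, dim
checker = [[255 if (x + y) % 2 else 0 for x in range(4)] for y in range(4)]  # sharp edges
print(quality_features(flat))
print(quality_features(checker))
```

A single helper like this also decodes each image once, whereas the three original functions each call `cv2.imread` separately.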

## 5. Dataset Preparation for Answerability Classification (CLIP Version)

We now create a **CLIP-based Dataset** for binary classification *(answerable vs unanswerable)*.

- `CLIPProcessor` handles image resizing, center cropping, and normalization.
- Label = the `answerable` field when present (1 = answerable, 0 = unanswerable); otherwise 0 if `is_rejected == True`, else 1.

In [62]:
class VizWizAnswerabilityCLIPDataset(Dataset):
    def __init__(self, samples, images_dir, processor, max_samples=None):
        self.samples = samples
        if max_samples is not None:
            self.samples = self.samples[:max_samples]
        
        self.images_dir = Path(images_dir)
        self.processor = processor
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        item = self.samples[idx]
        
        # ---------- image ----------
        image_name = item.get("image")  # e.g. "VizWiz_val_00001234.jpg"
        image_path = self.images_dir / image_name
        
        try:
            image = Image.open(image_path).convert('RGB')
        except Exception:
            image = Image.new('RGB', (224, 224), color='gray')
        
        # ---------- CLIP processor ----------
        inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = inputs["pixel_values"].squeeze(0)
        
        # ---------- label ----------
        if "answerable" in item:
            label = int(item["answerable"])
        else:
            label = 0 if item.get("is_rejected", False) else 1
        
        return {
            "pixel_values": pixel_values,
            "labels": torch.tensor(label, dtype=torch.long),
            "question": item.get("question", ""),
            "image_path": str(image_path),
        }

# Initialize CLIP processor
clip_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(clip_name)
print("CLIP processor loaded")

CLIP processor loaded
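A rough sketch of the arithmetic the processor applies to each image, assuming the usual `CLIPImageProcessor` settings for this checkpoint: shortest edge resized to 224, a 224×224 center crop, rescaling by 1/255, then per-channel normalization with the OpenAI CLIP mean/std. It reproduces only the geometry and normalization, not the bicubic resampling. It also shows that the gray fallback image in `__getitem__` (PIL's `'gray'` is RGB 128, 128, 128) becomes small positive values rather than zeros:

```python
# Approximate sketch of CLIP-style preprocessing arithmetic (no actual resampling).
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

def resize_shortest_edge(width, height, target=224):
    """Scale so the shorter side equals `target`, keeping aspect ratio (long side truncated)."""
    if width <= height:
        return target, int(target * height / width)
    return int(target * width / height), target

def center_crop_box(width, height, crop=224):
    """(left, top, right, bottom) of a centered crop x crop window."""
    left = (width - crop) // 2
    top = (height - crop) // 2
    return left, top, left + crop, top + crop

def normalize_rgb(rgb):
    """Rescale 0-255 to [0, 1], then standardize each channel with the CLIP mean/std."""
    return tuple((c / 255.0 - m) / s for c, m, s in zip(rgb, CLIP_MEAN, CLIP_STD))

w, h = resize_shortest_edge(640, 480)
print((w, h), center_crop_box(w, h))
print([round(v, 3) for v in normalize_rgb((128, 128, 128))])
```

Because fallback images still produce a valid tensor and keep their original label, a few unreadable files would be silently trained on; logging the exception in `__getitem__` may be worth considering.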


## 6. CLIP Model Setup

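We will:
- Load `CLIPModel`.
- Take its projected image embedding from `get_image_features` (512-d for ViT-B/32).
- Add an MLP **classification head** (Linear → ReLU → Dropout → Linear) on top.
- Control how much of CLIP is trained with `freeze_mode`: `head_only` trains only the head, `partial` also unfreezes the last N vision encoder layers plus `visual_projection`, and `full` fine-tunes the whole model.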

In [63]:
class CLIPAnswerabilityClassifier(nn.Module):
    def __init__(
        self,
        base_model_name="openai/clip-vit-base-patch32",
        num_labels=2,
        freeze_mode="partial",          # "head_only" | "partial" | "full"
        num_unfrozen_vision_layers=4    # only used when freeze_mode == "partial"
    ):
        super().__init__()
        self.clip = CLIPModel.from_pretrained(base_model_name)
        self.num_labels = num_labels

        # get_image_features output dimension (usually 512 for ViT-B/32)
        embed_dim = self.clip.visual_projection.out_features
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, 1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, num_labels)
        )

        # --------- Freezing strategy ---------
        # First freeze all CLIP parameters
        for p in self.clip.parameters():
            p.requires_grad = False

        if freeze_mode == "head_only":
            msg = "👉 Freeze mode = head_only: only classifier head is trainable."

        elif freeze_mode == "partial":
            # Partially unfreeze: last N vision encoder layers + visual_projection
            vision_layers = list(self.clip.vision_model.encoder.layers)
            n_layers = len(vision_layers)
            start_idx = max(0, n_layers - num_unfrozen_vision_layers)

            for i in range(start_idx, n_layers):
                for p in vision_layers[i].parameters():
                    p.requires_grad = True
            for p in self.clip.visual_projection.parameters():
                p.requires_grad = True

            msg = (
                f"👉 Freeze mode = partial: unfreeze vision encoder last "
                f"{num_unfrozen_vision_layers} layers (layer {start_idx}~{n_layers-1}) "
                f"+ visual_projection + classifier."
            )

        elif freeze_mode == "full":
            # Fully unfreeze all CLIP parameters
            for p in self.clip.parameters():
                p.requires_grad = True
            msg = "👉 Freeze mode = full: CLIP backbone fully trainable."

        else:
            raise ValueError(f"Unknown freeze_mode: {freeze_mode}")

        # Classifier head should always be trainable
        for p in self.classifier.parameters():
            p.requires_grad = True

        print(msg)

        # Print parameter stats
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total params: {total_params:,}")
        print(f"Trainable params: {trainable_params:,}")

        self.loss_fct = nn.CrossEntropyLoss()

    def forward(self, pixel_values, labels=None):
        """
        pixel_values: [B, 3, H, W]
        returns: (loss, logits) if labels is not None
                 (None, logits) otherwise
        """
        # get_image_features -> [B, D] projected embedding (NOT L2-normalized; CLIPModel.forward normalizes separately)
        image_embeds = self.clip.get_image_features(pixel_values=pixel_values)  # [B, D]

        logits = self.classifier(image_embeds)  # [B, num_labels]

        loss = None
        if labels is not None:
            loss = self.loss_fct(logits, labels)

        return loss, logits
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CLIPAnswerabilityClassifier(
    base_model_name="openai/clip-vit-base-patch32",
    num_labels=2,
    freeze_mode="partial",         # "head_only" / "partial" / "full"
    num_unfrozen_vision_layers=8   # try 2 / 4 / 6 etc.
).to(device)


👉 Freeze mode = partial: unfreeze vision encoder last 8 layers (layer 4~11) + visual_projection + classifier.
Total params: 151,804,675
Trainable params: 57,623,554
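The trainable-parameter count printed above can be reproduced by hand, assuming the standard ViT-B/32 vision config: hidden size 768, MLP size 3072, biased q/k/v/out projections, two LayerNorms per layer, and a bias-free 768→512 `visual_projection`. A small worked check:

```python
HIDDEN, MLP, PROJ = 768, 3072, 512  # assumed ViT-B/32 vision dims and projection_dim

def linear_params(n_in, n_out, bias=True):
    return n_in * n_out + (n_out if bias else 0)

def vit_layer_params(d=HIDDEN, mlp=MLP):
    attention = 4 * linear_params(d, d)                            # q, k, v, out projections
    layer_norms = 2 * (2 * d)                                      # two LayerNorms, weight + bias
    feed_forward = linear_params(d, mlp) + linear_params(mlp, d)   # fc1 + fc2
    return attention + layer_norms + feed_forward

def trainable_counts(num_unfrozen=8, head_hidden=1024, num_labels=2):
    """(backbone, head, total) trainable params for freeze_mode='partial'."""
    backbone = num_unfrozen * vit_layer_params() + linear_params(HIDDEN, PROJ, bias=False)
    head = linear_params(PROJ, head_hidden) + linear_params(head_hidden, num_labels)
    return backbone, head, backbone + head

print(trainable_counts())
```

Each additional unfrozen layer adds about 7.09M trainable parameters, which is a useful scale to keep in mind when trying `num_unfrozen_vision_layers` values of 2, 4, or 6.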


## 7. Create Datasets & Dataloaders

Update the `TRAIN_IMAGES_DIR` and `VAL_IMAGES_DIR` if needed.

In [64]:
TRAIN_IMAGES_DIR = "data/train"
VAL_IMAGES_DIR = "data/val"

train_dataset = VizWizAnswerabilityCLIPDataset(
    train_data,
    TRAIN_IMAGES_DIR,
    processor,
    max_samples=3000   # trial; increase when things work
)

val_dataset = VizWizAnswerabilityCLIPDataset(
    val_data,
    VAL_IMAGES_DIR,
    processor,
    max_samples=1000
)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Val dataset size: {len(val_dataset)}")

sample = train_dataset[0]
print(f"\nSample pixel_values shape: {sample['pixel_values'].shape}")
print(f"Sample label: {sample['labels']}")


def collate_fn(batch):
    pixel_values = torch.stack([b['pixel_values'] for b in batch])
    labels = torch.stack([b['labels'] for b in batch])
    return {
        'pixel_values': pixel_values,
        'labels': labels
    }

train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=0, 
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=0, 
    collate_fn=collate_fn
)

Train dataset size: 3000
Val dataset size: 1000

Sample pixel_values shape: torch.Size([3, 224, 224])
Sample label: 1
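The number of batches per epoch follows from the dataset size and batch size; with the default `drop_last=False` the final partial batch is kept. A quick check against the `188/188` shown in the training log below:

```python
import math

def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields per epoch with a standard sampler."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)

train_steps = batches_per_epoch(3000, 16)   # train_loader settings
val_steps = batches_per_epoch(1000, 32)     # val_loader settings
last_train_batch = 3000 % 16                # size of the final partial training batch
print(train_steps, val_steps, last_train_batch)
```

The last training batch holds only 8 images; with `drop_last=True` it would be skipped and each epoch would take 187 steps.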


In [65]:
backbone_params = []
head_params = []

for name, p in model.named_parameters():
    if not p.requires_grad:
        continue
    if name.startswith("classifier."):
        head_params.append(p)
    else:
        backbone_params.append(p)

optimizer = torch.optim.AdamW(
    [
        {"params": backbone_params, "lr": 3e-5},  # smaller lr for backbone
        {"params": head_params, "lr": 5e-5},      # larger lr for classifier head
    ],
    weight_decay=0.01
)

print("Trainable backbone params:", sum(p.numel() for p in backbone_params))
print("Trainable head params:", sum(p.numel() for p in head_params))

def evaluate(model, dataloader, device):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            loss, logits = model(pixel_values=pixel_values, labels=None)
            preds = torch.argmax(logits, dim=-1)

            all_preds.extend(preds.cpu().numpy().tolist())
            all_labels.extend(labels.cpu().numpy().tolist())

    acc = accuracy_score(all_labels, all_preds)
    return acc, np.array(all_labels), np.array(all_preds)


Trainable backbone params: 57096192
Trainable head params: 527362


## 8. Training Loop

We implement a standard PyTorch training loop (similar in spirit to your Trainer setup), compute validation accuracy after every epoch, and save a checkpoint whenever it improves.

In [None]:
num_epochs = 5  # matches the training log below ("Epoch x/5")
best_val_acc = 0.0

for epoch in range(1, num_epochs + 1):
    model.train()
    total_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}")

    for batch in pbar:
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        loss, logits = model(pixel_values=pixel_values, labels=labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})

    avg_loss = total_loss / len(train_loader)
    val_acc, _, _ = evaluate(model, val_loader, device)

    print(f"Epoch {epoch}: avg train loss = {avg_loss:.4f}, val acc = {val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "clip_vizwiz_answerability_best_partial.pth")
        print(" New best model saved.")


Epoch 1/5: 100%|██████████| 188/188 [01:34<00:00,  1.99it/s, loss=0.4188]
Epoch 1: avg train loss = 0.4488, val acc = 0.6280
 New best model saved.

Epoch 2/5: 100%|██████████| 188/188 [01:10<00:00,  2.68it/s, loss=0.1299]
Epoch 2: avg train loss = 0.2876, val acc = 0.6340
 New best model saved.

Epoch 3/5: 100%|██████████| 188/188 [01:12<00:00,  2.59it/s, loss=0.4338]
Epoch 3: avg train loss = 0.1966, val acc = 0.6260

Epoch 4/5: 100%|██████████| 188/188 [01:25<00:00,  2.19it/s, loss=0.0338]
Epoch 4: avg train loss = 0.0963, val acc = 0.6380
 New best model saved.

Epoch 5/5: 100%|██████████| 188/188 [01:13<00:00,  2.56it/s, loss=0.0258]
Epoch 5: avg train loss = 0.0344, val acc = 0.6420
 New best model saved.
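The checkpointing rule in the loop (save when `val_acc` strictly improves) can be replayed on the logged accuracies to see which epochs trigger a save. The sketch adds an optional patience-based early-stopping check that the notebook does not implement; `patience` is a hypothetical parameter. Training loss falls from 0.4488 to 0.0344 while validation accuracy stays between 0.626 and 0.642, which suggests the model fits the 3,000 training images much better than it generalizes:

```python
def replay_checkpoints(val_accs, patience=None):
    """Return (epochs where a new best was saved, best epoch, epoch training stops at)."""
    best_acc, best_epoch = 0.0, None
    saved, epochs_without_improvement = [], 0
    stop_epoch = len(val_accs)
    for epoch, acc in enumerate(val_accs, start=1):
        if acc > best_acc:                      # same strict '>' as the training loop
            best_acc, best_epoch = acc, epoch
            saved.append(epoch)
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if patience is not None and epochs_without_improvement >= patience:
                stop_epoch = epoch              # hypothetical early stop
                break
    return saved, best_epoch, stop_epoch

logged = [0.6280, 0.6340, 0.6260, 0.6380, 0.6420]   # val acc from the log above
print(replay_checkpoints(logged))
print(replay_checkpoints(logged, patience=1))
```

With `patience=1` training would stop after epoch 3 and miss the later gains. Epoch-to-epoch changes here correspond to only 4 to 12 of the 1,000 validation images, so a very short patience is risky at this noise level.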


## 9. Evaluation & Report

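After training, we run a final evaluation and print a detailed classification report. Note that `model` still holds the last-epoch weights, not necessarily the best checkpoint saved during training; load `clip_vizwiz_answerability_best_partial.pth` with `model.load_state_dict(...)` first if you want to report the best model. The output below was recorded in an earlier run of this cell (execution count 44, 500 validation samples), so its numbers do not match the training log above.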

In [44]:
# Final evaluation
val_acc, true_labels, pred_labels = evaluate(model, val_loader, device)
print(f"\nFinal Validation Accuracy: {val_acc:.4f}\n")

print("="*60)
print("CLASSIFICATION REPORT")
print("="*60)
print(classification_report(
    true_labels, 
    pred_labels, 
    target_names=['Unanswerable', 'Answerable']
))

                                                           


Final Validation Accuracy: 0.5900

CLASSIFICATION REPORT
              precision    recall  f1-score   support

Unanswerable       0.55      0.57      0.56       231
  Answerable       0.62      0.61      0.61       269

    accuracy                           0.59       500
   macro avg       0.59      0.59      0.59       500
weighted avg       0.59      0.59      0.59       500
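The report is easier to judge next to a trivial baseline. From the `support` column (231 unanswerable, 269 answerable), always predicting *Answerable* would already score 269/500 = 0.538, so the 0.59 accuracy here is about 5 points above that. The sketch computes that baseline and the per-class precision, recall, and F1 that `classification_report` prints, from plain label lists; the toy labels are made up for illustration:

```python
from collections import Counter

def majority_baseline(y_true):
    """Accuracy of always predicting the most frequent class."""
    counts = Counter(y_true)
    return max(counts.values()) / len(y_true)

def per_class_prf(y_true, y_pred, cls):
    """Precision, recall, F1 for one class, treating it as the positive label."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == cls and p == cls)
    fp = sum(1 for t, p in pairs if t != cls and p == cls)
    fn = sum(1 for t, p in pairs if t == cls and p != cls)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# Support counts from the report above: 231 unanswerable (0), 269 answerable (1)
print(majority_baseline([0] * 231 + [1] * 269))

# Toy labels (made up)
print(per_class_prf([0, 0, 1, 1, 1], [0, 1, 1, 1, 0], cls=1))
```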





## 10. Save Model & Processor

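We save the full `state_dict` of the wrapper (CLIP backbone + classification head) together with the CLIP processor configuration. Because the keys carry `clip.` and `classifier.` prefixes, the file is meant to be reloaded into a freshly constructed `CLIPAnswerabilityClassifier` with `load_state_dict`, not with `CLIPModel.from_pretrained`.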

In [None]:
save_dir = Path("./models/clip_vizwiz_answerability")
save_dir.mkdir(parents=True, exist_ok=True)

torch.save(model.state_dict(), save_dir / "pytorch_model.bin")
processor.save_pretrained(save_dir)

print(f"Model + processor saved to {save_dir}")
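If you only want the fine-tuned backbone inside a plain `CLIPModel`, the `clip.` prefix has to be stripped from the keys first. The helper below is a hypothetical, dependency-free sketch of that split, shown on a toy dict whose string values stand in for real tensors:

```python
def split_wrapper_state_dict(state_dict, backbone_prefix="clip.", head_prefix="classifier."):
    """Split a wrapper state dict into (backbone, head) dicts with prefixes removed."""
    backbone, head, other = {}, {}, {}
    for key, value in state_dict.items():
        if key.startswith(backbone_prefix):
            backbone[key[len(backbone_prefix):]] = value
        elif key.startswith(head_prefix):
            head[key[len(head_prefix):]] = value
        else:
            other[key] = value
    if other:
        raise KeyError(f"Unexpected keys: {sorted(other)}")
    return backbone, head

toy_state = {
    "clip.visual_projection.weight": "W_proj",
    "clip.logit_scale": "s",
    "classifier.0.weight": "W0",   # first Linear in the nn.Sequential head
    "classifier.3.bias": "b3",     # last Linear (index 3 after ReLU and Dropout)
}
backbone_sd, head_sd = split_wrapper_state_dict(toy_state)
print(sorted(backbone_sd), sorted(head_sd))
```

In the notebook this would mean loading the file with `torch.load(save_dir / "pytorch_model.bin", map_location="cpu")`, splitting it, and passing the backbone part to `load_state_dict` on a `CLIPModel.from_pretrained(clip_name)` instance, ideally with the same `transformers` version used for training.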