# **Advanced Topics in Machine Learning**
## **Assignment 0 - Introduction**
---

### **Task 5: Modality Gap in CLIP**

In [4]:
# Cell 1: Setup — install & imports
# If needed, uncomment the install lines.
# !pip -q install git+https://github.com/openai/CLIP.git
# !pip -q install umap-learn scikit-learn scipy torchvision matplotlib numpy

import os, random, math, time
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as T
import matplotlib.pyplot as plt
import clip
from sklearn.manifold import TSNE
import umap
from scipy.linalg import orthogonal_procrustes

# Single seed shared by every stochastic step below (subset sampling, t-SNE, UMAP).
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device:", device)


Device: cuda


In [5]:
# Cell 1.5: Optional faster STL-10 downloader (chunked, resumable) or quick CIFAR-10 fallback
# Set USE_CIFAR_FALLBACK = True to skip STL-10 download and use CIFAR-10 for a quick run.
# NOTE: this cell uses the `time` *module* imported in Cell 1 and must not rebind that name
# (a `from time import time` here would break any later `time.<fn>()` call).
USE_CIFAR_FALLBACK = False  # <--- set True if you want a fast small-dataset fallback

STL_URL = "https://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz"
DATA_ROOT = "./data"
LOCAL_TGZ = os.path.join(DATA_ROOT, "stl10_binary.tar.gz")
# Folder the archive is expected to extract into (torchvision's STL10 reads DATA_ROOT/stl10_binary).
STL_EXTRACTED_DIR = os.path.join(DATA_ROOT, "stl10_binary")


def _print_manual_download_help():
    """Print where to fetch the STL-10 archive by hand and where to save it."""
    print("You can download manually from:")
    print(STL_URL)
    print(f"Save to: {LOCAL_TGZ}")
    print("Or set USE_CIFAR_FALLBACK=True and re-run the notebook for a fast fallback")


try:
    import requests
except ImportError:
    requests = None

if requests is not None:
    # Defined as a notebook global: Cell 2 looks it up by name to recover a broken archive.
    def download_with_retries(url, out_path, max_retries=3, chunk_size=1 << 20):
        """Stream `url` to `out_path`, resuming a partial file and retrying on failure.

        Args:
            url: HTTP(S) URL to fetch.
            out_path: destination file. If it already exists, its bytes are kept and the
                transfer resumes with an HTTP Range request (restarting from scratch when
                the server does not honor the range).
            max_retries: total number of attempts before giving up.
            chunk_size: bytes read per streamed chunk.

        Returns:
            `out_path` once the file is complete.

        Raises:
            The last attempt's exception once `max_retries` attempts have failed, or
            RuntimeError when `max_retries` < 1.
        """
        for attempt in range(1, max_retries + 1):
            try:
                pos = os.path.getsize(out_path) if os.path.exists(out_path) else 0
                # Ask for the bytes exactly as stored on the server: a transparently decoded
                # response would no longer be the .tar.gz nor match Content-Length.
                headers = {"Accept-Encoding": "identity"}
                if pos > 0:
                    headers["Range"] = f"bytes={pos}-"
                with requests.get(url, stream=True, headers=headers, timeout=30) as r:
                    if pos > 0 and r.status_code == 416:
                        # Range starts at/after the remote EOF. Servers should report the full
                        # length as "bytes */<length>"; a match (or no info) means we are done.
                        remote_total = r.headers.get("Content-Range", "").rsplit("/", 1)[-1]
                        if not remote_total.isdigit() or int(remote_total) == pos:
                            return out_path
                        # Local file is longer than the remote one, so it is corrupt.
                        os.remove(out_path)
                        raise RuntimeError(
                            f"Local file ({pos} bytes) larger than remote ({remote_total} bytes); removed it"
                        )
                    r.raise_for_status()
                    if pos > 0 and r.status_code != 206:
                        # Server ignored the Range header and is resending the whole file;
                        # appending that to the partial file would corrupt the archive.
                        pos = 0

                    total = None
                    if r.status_code == 206 and "Content-Range" in r.headers:
                        # e.g. "bytes 1000-1999/2000"; the total may be "*" (unknown).
                        total_part = r.headers["Content-Range"].rsplit("/", 1)[-1]
                        total = int(total_part) if total_part.isdigit() else None
                    elif "Content-Length" in r.headers:
                        total = int(r.headers["Content-Length"]) + pos

                    started = time.monotonic()
                    with open(out_path, "ab" if pos > 0 else "wb") as f:
                        downloaded = pos
                        # Raw stream without content decoding keeps the archive bytes intact.
                        for chunk in r.raw.stream(chunk_size, decode_content=False):
                            if not chunk:
                                continue
                            f.write(chunk)
                            downloaded += len(chunk)
                            if total:
                                pct = downloaded / total * 100
                                elapsed = max(time.monotonic() - started, 1e-6)
                                # Speed counts only the bytes fetched in this attempt.
                                speed = (downloaded - pos) / (1024 * 1024) / elapsed
                                print(f"Downloaded {downloaded}/{total} bytes ({pct:.1f}%), {speed:.2f} MB/s", end="\r")
                    print()  # newline after progress

                size = os.path.getsize(out_path)
                if total is not None and size != total:
                    # Truncated transfer: the next attempt resumes from `size`.
                    raise RuntimeError(f"Incomplete download: {size}/{total} bytes")
                if size <= 1024:
                    raise RuntimeError("Downloaded file too small or missing after download")
                return out_path
            except Exception as err:
                print(f"Download attempt {attempt} failed: {err}")
                if attempt >= max_retries:
                    raise
                print("Retrying...")
        raise RuntimeError("Failed to download after retries")


if USE_CIFAR_FALLBACK:
    print("CIFAR-10 fallback enabled. Skip STL-10 download and use CIFAR-10 for quick experiments.")
elif os.path.isdir(STL_EXTRACTED_DIR):
    # Re-running the notebook must not fetch the archive again; Cell 2 loads the files
    # and reports any problem with them.
    print(f"Found extracted STL-10 at {STL_EXTRACTED_DIR}; skipping download")
elif requests is not None:
    os.makedirs(DATA_ROOT, exist_ok=True)
    print("Using requests to download with streaming and simple retry logic")
    try:
        print(f"Downloading STL-10 to {LOCAL_TGZ}")
        download_with_retries(STL_URL, LOCAL_TGZ, max_retries=3)
        print("Download finished")
    except Exception as e:
        print("Resumable download failed:", e)
        _print_manual_download_help()
else:
    os.makedirs(DATA_ROOT, exist_ok=True)
    print("requests not available, falling back to urllib. This may be slower.")
    import urllib.request

    def reporthook(block_num, block_size, total_size):
        """Progress callback for `urlretrieve`; the last block may overshoot, so clamp it."""
        if total_size > 0:
            downloaded = min(block_num * block_size, total_size)
            pct = downloaded / total_size * 100
            print(f"Downloaded {downloaded}/{total_size} bytes ({pct:.1f}%)", end="\r")

    try:
        print(f"Downloading STL-10 to {LOCAL_TGZ} using urllib")
        urllib.request.urlretrieve(STL_URL, LOCAL_TGZ, reporthook)
        print("\nDownload finished")
    except Exception as e2:
        print("urllib download also failed:", e2)
        _print_manual_download_help()

# End of downloader / fallback cell


Using requests to download with streaming and simple retry logic
Downloading STL-10 to ./data\stl10_binary.tar.gz
Download attempt 1 failed: HTTPSConnectionPool(host='ai.stanford.edu', port=443): Max retries exceeded with url: /~acoates/stl10/stl10_binary.tar.gz (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x0000024E714F38D0>: Failed to establish a new connection: [WinError 10061] No connection could be made because the target machine actively refused it'))
Retrying...
Download attempt 1 failed: HTTPSConnectionPool(host='ai.stanford.edu', port=443): Max retries exceeded with url: /~acoates/stl10/stl10_binary.tar.gz (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x0000024E714F38D0>: Failed to establish a new connection: [WinError 10061] No connection could be made because the target machine actively refused it'))
Retrying...
Download attempt 2 failed: HTTPSConnectionPool(host='ai.stanford.edu', port=443): Max retries exceeded

In [7]:
# Cell 2: Load CLIP model & the evaluation set (STL-10 test split, or the CIFAR-10 test split
# when USE_CIFAR_FALLBACK is set in Cell 1.5)
import shutil
import tarfile

MODEL_NAME = "ViT-B/32"  # change to "RN50" or others if desired
# Try to load CLIP; on checksum/corrupt-cache failures, clear the cache and retry
try:
    model, preprocess = clip.load(MODEL_NAME, device=device)
except Exception as e:
    print("clip.load failed on first attempt:", e)
    cache_dir = os.path.expanduser("~/.cache/clip")
    if os.path.exists(cache_dir):
        print("Removing CLIP cache at:", cache_dir)
        shutil.rmtree(cache_dir, ignore_errors=True)
    print("Retrying clip.load() after clearing cache...")
    model, preprocess = clip.load(MODEL_NAME, device=device)

# Transforms: Use CLIP's preprocess for best results
test_tf = preprocess

data_root = "./data"
local_tgz = os.path.join(data_root, "stl10_binary.tar.gz")
# Prefer the URL from Cell 1.5, but do not require that cell to have run.
stl_url = globals().get("STL_URL", "https://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz")


def _extract_archive(archive_path, dest):
    """Extract the .tar.gz at `archive_path` into `dest`.

    When this Python provides tarfile extraction filters, the "data" filter is used so a
    downloaded archive cannot write outside `dest` or create special files.
    """
    with tarfile.open(archive_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=dest, filter="data")
        else:
            tar.extractall(path=dest)


def _download_extract_and_load():
    """Fetch the archive via Cell 1.5's `download_with_retries`, extract it, load the test split."""
    download_with_retries(stl_url, local_tgz, max_retries=3)
    print("Download finished; extracting...")
    _extract_archive(local_tgz, data_root)
    return datasets.STL10(root=data_root, split="test", download=False, transform=test_tf)


def _load_stl10_test():
    """Return the STL-10 test split, recovering from a failed torchvision download if possible.

    Tries, in order: torchvision's own download; extracting an existing local archive;
    (re-)downloading with Cell 1.5's helper when that cell has been run.

    Raises:
        RuntimeError: when the archive is missing/corrupt and cannot be recovered.
    """
    try:
        return datasets.STL10(root=data_root, split="test", download=True, transform=test_tf)
    except Exception as e:
        print("STL-10 download failed:", e)

    has_downloader = "download_with_retries" in globals()

    if not os.path.exists(local_tgz):
        if not has_downloader:
            print("Local archive not found. Please download STL-10 manually:")
            print(stl_url)
            print("Save it to:", local_tgz)
            print("PowerShell command (copy-paste into a PowerShell terminal):")
            print(f"Invoke-WebRequest -Uri {stl_url} -OutFile '{local_tgz}'")
            raise RuntimeError("STL-10 download failed and no local archive found. Place the archive at the path above and re-run this cell.")
        try:
            print("No local archive found; attempting to download via notebook downloader...")
            return _download_extract_and_load()
        except Exception as ex4:
            print("Automated download/extraction failed:", ex4)
            print("Please download manually or set USE_CIFAR_FALLBACK=True in Cell 1.5")
            raise

    print(f"Found local archive {local_tgz}; attempting extraction...")
    try:
        _extract_archive(local_tgz, data_root)
        print("Extraction complete; loading dataset without download...")
        return datasets.STL10(root=data_root, split="test", download=False, transform=test_tf)
    except (tarfile.ReadError, EOFError) as ex2:
        print("Archive appears corrupted or incomplete:", ex2)
        corrupt_error = ex2  # `ex2` is unbound once the handler exits
    except Exception as ex2:
        print("Extraction or load after extraction failed:", ex2)
        raise

    try:
        os.remove(local_tgz)
        print("Removed corrupt archive.")
    except OSError as rmex:
        print("Failed to remove corrupt archive:", rmex)

    if not has_downloader:
        print("No downloader function available in the notebook. Please re-download the archive manually and place it at:", local_tgz)
        print(stl_url)
        raise RuntimeError("Corrupt local STL-10 archive; manual redownload required") from corrupt_error
    try:
        print("Attempting to re-download using the notebook downloader...")
        return _download_extract_and_load()
    except Exception as ex3:
        print("Re-download or extraction failed:", ex3)
        raise RuntimeError("Failed to recover STL-10 archive after re-download") from ex3


if globals().get("USE_CIFAR_FALLBACK", False):
    # Quick-run path advertised in Cell 1.5.
    test_set = datasets.CIFAR10(root=data_root, train=False, download=True, transform=test_tf)
    # Downstream cells read `stl10_classes`; in this mode it holds the CIFAR-10 class names.
    stl10_classes = list(test_set.classes)
else:
    # STL-10 classes per torchvision docs
    stl10_classes = ["airplane","bird","car","cat","deer","dog","horse","monkey","ship","truck"]
    test_set = _load_stl10_test()

# Finally create the loader.
# NOTE: On Windows Jupyter notebooks, set num_workers=0 to avoid worker spawn issues
test_loader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=0)
len(test_set)


STL-10 download failed: <urlopen error [WinError 10061] No connection could be made because the target machine actively refused it>
Found local archive ./data\stl10_binary.tar.gz; attempting extraction...
Archive appears corrupted or incomplete: Compressed file ended before the end-of-stream marker was reached
Removed corrupt archive.
Attempting to re-download using the notebook downloader...
Download attempt 1 failed: HTTPSConnectionPool(host='ai.stanford.edu', port=443): Max retries exceeded with url: /~acoates/stl10/stl10_binary.tar.gz (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x0000024E715EA050>, 'Connection to ai.stanford.edu timed out. (connect timeout=30)'))
Retrying...
Download attempt 1 failed: HTTPSConnectionPool(host='ai.stanford.edu', port=443): Max retries exceeded with url: /~acoates/stl10/stl10_binary.tar.gz (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x0000024E715EA050>, 'Connection to ai.stanford.edu t

RuntimeError: Failed to recover STL-10 archive after re-download

In [None]:
# Cell 3: Prompting strategies (three+ variants)
def _fill(template, labels):
    """Substitute each class name into the single `{}` slot of `template`, keeping label order."""
    return [template.format(label) for label in labels]


def prompts_plain(labels):
    """Use the raw class names as prompts (the input list is returned as-is)."""
    return labels  # ["cat", "dog", ...]


def prompts_template_a(labels):
    """Prompts of the form 'a photo of a <class>'."""
    return _fill("a photo of a {}", labels)


def prompts_template_b(labels):
    """Prompts of the form 'a high quality photo of a <class>'."""
    return _fill("a high quality photo of a {}", labels)


def prompts_template_c(labels):
    """Prompts of the form 'a photo of a small <class> in the wild'."""
    return _fill("a photo of a small {} in the wild", labels)


# Insertion order matters: later cells evaluate the first three entries.
prompt_strategies = dict(
    plain=prompts_plain,
    photo=prompts_template_a,
    hq_photo=prompts_template_b,
    wild_small=prompts_template_c,  # optional 4th
)
list(prompt_strategies.keys())


In [None]:
# Cell 4: Zero-shot evaluation over the full STL-10 test set for multiple prompt strategies
@torch.no_grad()
def zero_shot_accuracy(model, loader, classnames, prompt_fn, device="cpu"):
    """Top-1 zero-shot accuracy of CLIP over every batch in `loader`.

    Each class is represented by the text embedding of its prompt from `prompt_fn`;
    an image is assigned the class with the highest cosine similarity.

    Args:
        model: CLIP model exposing `encode_image` / `encode_text`.
        loader: iterable of `(images, labels)` batches; labels index into `classnames`.
        classnames: class names in label order.
        prompt_fn: maps `classnames` to one prompt string per class.
        device: device the model lives on; batches are moved there.

    Returns:
        Fraction of correctly classified images, in [0, 1].

    Raises:
        ValueError: if `loader` yields no samples (accuracy is undefined).
    """
    model.eval()  # make sure train-time layers are disabled whatever the caller did
    # Class prompts are encoded once; L2-normalising turns dot products into cosine similarities.
    text_tokens = clip.tokenize(prompt_fn(classnames)).to(device)
    text_features = model.encode_text(text_tokens)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    correct = 0
    total = 0
    for imgs, labels in loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        image_features = model.encode_image(imgs)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # Scaling by a positive constant does not change the argmax.
        logits = 100.0 * image_features @ text_features.t()
        preds = logits.argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.numel()

    if total == 0:
        raise ValueError("zero_shot_accuracy: loader yielded no samples")
    return correct / total


# Compare the first three strategies (plain / photo / hq_photo); Cells 9 and 11 use the same three.
acc_results = {
    name: zero_shot_accuracy(model, test_loader, stl10_classes, fn, device)
    for name, fn in list(prompt_strategies.items())[:3]
}

acc_results


In [None]:
# Cell 5: Extract embeddings for a subset (50–100 samples) to explore modality gap
subset_n = 100
indices = np.random.RandomState(seed).choice(len(test_set), size=subset_n, replace=False)

# Access each sampled item once: every dataset access re-decodes and re-preprocesses the image.
subset_img_list, subset_label_list = zip(*(test_set[int(i)] for i in indices))
subset_imgs = torch.stack(subset_img_list)
subset_labels = torch.tensor(subset_label_list)


@torch.no_grad()
def get_clip_embeddings(model, images, labels, classnames, text_prompt_fn, device="cpu"):
    """Return paired, L2-normalised CLIP image and ground-truth text embeddings.

    Row i of the text output embeds the prompt (from `text_prompt_fn`) of class
    `labels[i]`, so image and text rows are paired one-to-one.

    Args:
        model: CLIP model exposing `encode_image` / `encode_text`.
        images: batch of preprocessed images, encoded in a single forward pass.
        labels: class index per image (tensor or sequence of ints).
        classnames: class names in label order.
        text_prompt_fn: maps `classnames` to one prompt string per class.
        device: device the model lives on.

    Returns:
        (img_feats, txt_feats) as float32 NumPy arrays with one row per image.
    """
    model.eval()
    images = images.to(device)
    img_feats = model.encode_image(images)
    img_feats = img_feats / img_feats.norm(dim=-1, keepdim=True)

    # Build the class prompts once, then pick each sample's ground-truth prompt.
    class_prompts = text_prompt_fn(classnames)
    texts = [class_prompts[int(y)] for y in labels]
    tok = clip.tokenize(texts).to(device)
    txt_feats = model.encode_text(tok)
    txt_feats = txt_feats / txt_feats.norm(dim=-1, keepdim=True)

    # Features may be float16 (half-precision model); float32 is the safe, precise dtype for
    # the NumPy/SciPy/scikit-learn/UMAP steps that follow.
    return img_feats.float().cpu().numpy(), txt_feats.float().cpu().numpy()


img_embeds, txt_embeds = get_clip_embeddings(
    model, subset_imgs, subset_labels, stl10_classes, prompts_template_a, device
)
img_embeds.shape, txt_embeds.shape


In [None]:
# Cell 6: 2D projection (t-SNE and UMAP) and visualization of image vs text embeddings
def plot_embeddings_2d(img_2d, txt_2d, labels, title="Embeddings 2D", save_path=None):
    """Scatter 2-D image points (dots) and text points (crosses), coloured by class label.

    Args:
        img_2d, txt_2d: 2-D coordinates, one row per point (columns 0 and 1 are plotted);
            assumed to come from one shared projection so the axes are comparable.
        labels: class index per row, used as the colour of both scatters.
        title: figure title.
        save_path: optional output path; its parent directory is created when needed.
    """
    from matplotlib.lines import Line2D

    fig, ax = plt.subplots(figsize=(7, 6))
    ax.scatter(img_2d[:, 0], img_2d[:, 1], c=labels, cmap="tab10", s=20, alpha=0.7)
    ax.scatter(txt_2d[:, 0], txt_2d[:, 1], c=labels, cmap="tab10", s=60, marker="x")
    # Colour encodes class, so the legend uses neutral proxies that explain only the marker shape.
    ax.legend(handles=[
        Line2D([], [], linestyle="none", marker="o", color="gray", label="image"),
        Line2D([], [], linestyle="none", marker="x", color="gray", label="text"),
    ])
    ax.set(title=title, xlabel="dim 1", ylabel="dim 2")
    if save_path:
        save_dir = os.path.dirname(save_path)
        if save_dir:  # a bare filename has no directory to create (os.makedirs("") raises)
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.show()
    plt.close(fig)


labels_np = subset_labels.numpy()

# Project image and text embeddings *together*: fitting t-SNE/UMAP separately per modality
# yields two unrelated 2-D coordinate systems, which cannot show the gap between them.
n_img = len(img_embeds)
joint_embeds = np.concatenate([img_embeds, txt_embeds], axis=0).astype(np.float32)

# t-SNE (default 1000 iterations; `n_iter` was renamed `max_iter` in scikit-learn 1.5)
tsne = TSNE(n_components=2, random_state=seed, init="pca", perplexity=30)
joint_2d_tsne = tsne.fit_transform(joint_embeds)
img_2d_tsne, txt_2d_tsne = joint_2d_tsne[:n_img], joint_2d_tsne[n_img:]
plot_embeddings_2d(img_2d_tsne, txt_2d_tsne, labels_np, title="t-SNE: CLIP Image vs Text (subset)", save_path="clip_outputs/tsne_before.png")

# UMAP
reducer = umap.UMAP(n_components=2, random_state=seed, n_neighbors=15, min_dist=0.1, metric="cosine")
joint_2d_umap = reducer.fit_transform(joint_embeds)
img_2d_umap, txt_2d_umap = joint_2d_umap[:n_img], joint_2d_umap[n_img:]
plot_embeddings_2d(img_2d_umap, txt_2d_umap, labels_np, title="UMAP: CLIP Image vs Text (subset)", save_path="clip_outputs/umap_before.png")


In [None]:
# Cell 7: Orthogonal Procrustes alignment (learn rotation R between image & text embeddings)
# Solve min_R || X R - Y ||_F  s.t. R^T R = I
# We'll learn R on the subset pairs (img_embeds, txt_embeds) and then reuse R later.

# Work in float64: the embeddings may be half precision, where means and the SVD lose accuracy.
img_embeds_64 = np.asarray(img_embeds, dtype=np.float64)
txt_embeds_64 = np.asarray(txt_embeds, dtype=np.float64)

# Center both sides before Procrustes (common practice). The means are kept so later cells can
# apply the full mapping (x - img_mean) @ R + txt_mean, consistent with how R is fit here.
img_mean = img_embeds_64.mean(axis=0, keepdims=True)
txt_mean = txt_embeds_64.mean(axis=0, keepdims=True)
X = img_embeds_64 - img_mean
Y = txt_embeds_64 - txt_mean

R, _ = orthogonal_procrustes(X, Y)  # R: d x d
# Sanity check: the solution must be orthogonal (R^T R = I).
assert np.allclose(R.T @ R, np.eye(R.shape[0]), atol=1e-6), "Procrustes solution is not orthogonal"
R.shape


In [None]:
# Cell 8: Apply rotation to subset and visualize aligned embeddings
# R was fit on centered embeddings (Cell 7); here only the rotation is applied.
X_aligned = img_embeds @ R  # rotate image embeddings into text space

# Project rotated images and texts jointly so both share one 2-D coordinate system;
# separate fits give unrelated axes that cannot show how far apart the modalities are.
n_img = len(X_aligned)
joint_aligned = np.concatenate([X_aligned, txt_embeds], axis=0).astype(np.float32)

# t-SNE after alignment (default 1000 iterations; `n_iter` was renamed in scikit-learn 1.5)
tsne = TSNE(n_components=2, random_state=seed, init="pca", perplexity=30)
joint_rot_2d_tsne = tsne.fit_transform(joint_aligned)
img_rot_2d_tsne, txt_2d_tsne = joint_rot_2d_tsne[:n_img], joint_rot_2d_tsne[n_img:]
plot_embeddings_2d(img_rot_2d_tsne, txt_2d_tsne, labels_np, title="t-SNE: After Procrustes (subset)", save_path="clip_outputs/tsne_after.png")

# UMAP after alignment
reducer = umap.UMAP(n_components=2, random_state=seed, n_neighbors=15, min_dist=0.1, metric="cosine")
joint_rot_2d_umap = reducer.fit_transform(joint_aligned)
img_rot_2d_umap, txt_2d_umap = joint_rot_2d_umap[:n_img], joint_rot_2d_umap[n_img:]
plot_embeddings_2d(img_rot_2d_umap, txt_2d_umap, labels_np, title="UMAP: After Procrustes (subset)", save_path="clip_outputs/umap_after.png")


In [None]:
# Cell 9: Recompute zero-shot accuracy USING ROTATED IMAGE EMBEDDINGS
# We will recompute accuracies for the same 3 prompt strategies, but with image features rotated by R.

@torch.no_grad()
def zero_shot_accuracy_with_rotation(model, loader, classnames, prompt_fn, R, device="cpu",
                                     img_mean=None, txt_mean=None):
    """Top-1 zero-shot accuracy after mapping image features through the Procrustes rotation `R`.

    By default each normalised image feature x is replaced by x @ R (rotation only). When
    `img_mean` and `txt_mean` (the means R was fit around) are given, the full mapping
    (x - img_mean) @ R + txt_mean is applied and re-normalised instead, which matches
    fitting R on centered embeddings.

    Args:
        model: CLIP model exposing `encode_image` / `encode_text`.
        loader: iterable of `(images, labels)` batches; labels index into `classnames`.
        classnames: class names in label order.
        prompt_fn: maps `classnames` to one prompt string per class.
        R: d x d rotation (NumPy array or tensor) applied on the right of image features.
        device: device the model lives on.
        img_mean, txt_mean: optional d-element means (any shape reshapeable to (1, d));
            pass both or neither.

    Returns:
        Fraction of correctly classified images, in [0, 1].

    Raises:
        ValueError: if only one mean is given, or `loader` yields no samples.
    """
    if (img_mean is None) != (txt_mean is None):
        raise ValueError("pass both img_mean and txt_mean, or neither")
    model.eval()

    # All similarity math runs in float32: CLIP features may be float16, and torch matmul
    # refuses to mix float16 operands with the float32 rotation.
    R_t = torch.as_tensor(R, dtype=torch.float32, device=device)  # hoisted: batch-invariant
    centered = img_mean is not None
    if centered:
        img_mu = torch.as_tensor(img_mean, dtype=torch.float32, device=device).reshape(1, -1)
        txt_mu = torch.as_tensor(txt_mean, dtype=torch.float32, device=device).reshape(1, -1)

    text_tokens = clip.tokenize(prompt_fn(classnames)).to(device)
    text_features = model.encode_text(text_tokens).float()
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    correct = 0
    total = 0
    for imgs, labels in loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        image_features = model.encode_image(imgs).float()
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        if centered:
            image_features = (image_features - img_mu) @ R_t + txt_mu
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        else:
            # A rotation preserves norms, so no re-normalisation is needed.
            image_features = image_features @ R_t

        logits = 100.0 * image_features @ text_features.t()
        preds = logits.argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.numel()

    if total == 0:
        raise ValueError("zero_shot_accuracy_with_rotation: loader yielded no samples")
    return correct / total


acc_results_rot = {
    name: zero_shot_accuracy_with_rotation(model, test_loader, stl10_classes, fn, R, device)
    for name, fn in list(prompt_strategies.items())[:3]
}

acc_results, acc_results_rot


In [None]:
# Cell 10: Quick sanity visualization of a few test images + top-1 predictions (optional)
# Uses the "photo" prompt strategy by default.
from torchvision.utils import make_grid, save_image
from IPython.display import display

@torch.no_grad()
def visualize_predictions(model, dataset, classnames, prompt_fn, n=8):
    """Display `n` seeded-random dataset images and print ground truth vs top-1 zero-shot predictions.

    Uses the notebook globals `device` and `seed`. Images are assumed to be CLIP-preprocessed
    (normalised with CLIP's channel mean/std), so they are un-normalised before display.
    """
    idxs = np.random.RandomState(seed).choice(len(dataset), size=n, replace=False)
    # Access each item once: every dataset access re-decodes and re-preprocesses the image.
    samples = [dataset[int(i)] for i in idxs]
    imgs = torch.stack([img for img, _ in samples]).to(device)
    labels = [label for _, label in samples]

    texts = prompt_fn(classnames)
    tok = clip.tokenize(texts).to(device)
    text_features = model.encode_text(tok)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    image_features = model.encode_image(imgs)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)

    logits = 100.0 * image_features @ text_features.t()
    preds = logits.argmax(dim=-1).cpu().numpy()

    # Undo CLIP's input normalisation. Clamping the normalised tensor straight to (0, 1),
    # as make_grid(normalize=True, value_range=(0, 1)) does, distorts the pixels.
    clip_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
    clip_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)
    display_imgs = (imgs.cpu().float() * clip_std + clip_mean).clamp(0, 1)
    grid = make_grid(display_imgs, nrow=n)
    display(T.ToPILImage()(grid))
    print("GT:", [classnames[y] for y in labels])
    print("Pred:", [classnames[p] for p in preds])

visualize_predictions(model, test_set, stl10_classes, prompts_template_a, n=8)


In [None]:
# Cell 11: Save results to a small report dict for later reference
import json

# Keys are derived from the evaluated strategies, so changing the strategy set in Cells 4/9 cannot
# raise a KeyError here. For plain/photo/hq_photo this yields zero_shot_acc_plain, ...,
# aligned_zero_shot_acc_hq_photo, in that order.
results = {f"zero_shot_acc_{name}": acc for name, acc in acc_results.items()}
results.update({f"aligned_zero_shot_acc_{name}": acc for name, acc in acc_results_rot.items()})
results["subset_indices"] = indices.tolist()
results["model"] = MODEL_NAME

os.makedirs("clip_outputs", exist_ok=True)
# A dict saved with np.save becomes a pickled 0-d object array:
# reload with np.load("clip_outputs/results.npy", allow_pickle=True).item()
np.save("clip_outputs/results.npy", results)
# Plain-JSON copy that can be read without unpickling.
with open("clip_outputs/results.json", "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2)
results
