In [117]:
!pip install torch torchvision ftfy regex tqdm matplotlib umap-learn scikit-learn scipy git+https://github.com/openai/CLIP.git

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-zh99f9cx
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-zh99f9cx
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [118]:
#import github clip
!pip install open_clip_torch




In [119]:
import os, math, random, json
from typing import List, Dict, Tuple

In [120]:
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from tqdm import tqdm

In [121]:
import clip
from sklearn.manifold import TSNE

In [122]:
try:
    import umap
    HAVE_UMAP = True
except Exception:
    HAVE_UMAP = False

In [123]:
from scipy.linalg import orthogonal_procrustes

In [124]:
# setting Reproducability of code & device
def set_seed(seed=1337):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True; torch.backends.cudnn.benchmark = False
set_seed(42)
device = "cuda" if torch.cuda.is_available() else "cpu"

In [125]:
os.makedirs("artifacts", exist_ok=True)


In [126]:
device

'cuda'

In [127]:
model, clip_preprocess = clip.load("ViT-B/32", device=device)
model.eval()

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          

In [128]:
HAVE_UMAP

True

In [129]:
test = datasets.STL10(root="data", split="test", download=True, transform=clip_preprocess)
test_loader = DataLoader(test, batch_size=32, shuffle=False)


In [130]:
#classes used in STL-10
CLASSES = ["airplane","bird","car","cat","deer","dog","horse","monkey","ferry","truck"]

In [131]:
def prompt_sets() -> Dict[str, List[str]]:

    # plain labels
    plain = [f"{c}" for c in CLASSES]

    # short template
    short = [f"a photo of a {c}" for c in CLASSES]

    #  descriptive variants
    TEMPLATES = [
        "a photo of a {}.",
        "a close-up photo of a {}.",
        "a bright photo of a {}.",
        "a cropped photo of the {}.",
        "a photo of the small {}.",
        "a close-up of the {}.",
        "a low resolution photo of a {}.",
        "a high resolution photo of a {}.",
        "a picture of one {}.",
        "a photo of many {}.",
        "a photograph of a big {}.",
        "a JPEG photo of a {}."
    ]

    # We'll average text features across templates per class.
    variants = []
    for c in CLASSES:
        variants.append("|".join([t.format(c) for t in TEMPLATES]))

    return {
        "plain": plain,
        "short": short,
        "variants": variants,  # pipe-separated template strings per class
    }

    print(list(prompt_sets()))

prompt_sets()

{'plain': ['airplane',
  'bird',
  'car',
  'cat',
  'deer',
  'dog',
  'horse',
  'monkey',
  'ferry',
  'truck'],
 'short': ['a photo of a airplane',
  'a photo of a bird',
  'a photo of a car',
  'a photo of a cat',
  'a photo of a deer',
  'a photo of a dog',
  'a photo of a horse',
  'a photo of a monkey',
  'a photo of a ferry',
  'a photo of a truck'],
 'variants': ['a photo of a airplane.|a close-up photo of a airplane.|a bright photo of a airplane.|a cropped photo of the airplane.|a photo of the small airplane.|a close-up of the airplane.|a low resolution photo of a airplane.|a high resolution photo of a airplane.|a picture of one airplane.|a photo of many airplane.|a photograph of a big airplane.|a JPEG photo of a airplane.',
  'a photo of a bird.|a close-up photo of a bird.|a bright photo of a bird.|a cropped photo of the bird.|a photo of the small bird.|a close-up of the bird.|a low resolution photo of a bird.|a high resolution photo of a bird.|a picture of one bird.|a phot

In [132]:
# Encode text prompts into class prototypes
@torch.no_grad()
def build_text_prototypes(model, prompt_sets: Dict[str, List[str]]) -> Dict[str, torch.Tensor]:
    out = {}
    for key, prompts in prompt_sets.items():
        class_feats = []
        for entry in prompts:
            # entry either a single prompt (plain/short) or a pipe-joined ensemble string (variants)
            if "|" in entry:
                texts = entry.split("|")
                tokens = clip.tokenize(texts).to(device)
                txt_feat = model.encode_text(tokens)  # [n_templates, d]
                txt_feat = F.normalize(txt_feat, dim=-1)
                proto = txt_feat.mean(dim=0, keepdim=True)  # average ensemble
            else:
                tokens = clip.tokenize([entry]).to(device)
                proto = F.normalize(model.encode_text(tokens), dim=-1)
            class_feats.append(proto)
        # shape: [num_classes, d]
        out[key] = torch.cat(class_feats, dim=0)
    return out  # dict of [C, D]

In [133]:
text_prototypes=build_text_prototypes(model, prompt_sets=prompt_sets())
text_prototypes

{'plain': tensor([[ 0.0059, -0.0175, -0.0041,  ...,  0.0177, -0.0110, -0.0175],
         [ 0.0213,  0.0039,  0.0057,  ..., -0.0332, -0.0021, -0.0254],
         [ 0.0060,  0.0097,  0.0036,  ..., -0.0225, -0.0248,  0.0104],
         ...,
         [ 0.0056,  0.0012, -0.0085,  ..., -0.0209,  0.0038,  0.0008],
         [ 0.0350, -0.0259, -0.0195,  ..., -0.0206,  0.0156,  0.0260],
         [ 0.0222,  0.0051, -0.0071,  ..., -0.0005, -0.0157,  0.0013]],
        device='cuda:0', dtype=torch.float16),
 'short': tensor([[ 0.0128,  0.0265,  0.0104,  ..., -0.0315,  0.0071, -0.0219],
         [ 0.0229,  0.0381, -0.0079,  ..., -0.0594,  0.0048, -0.0292],
         [ 0.0071,  0.0141, -0.0105,  ..., -0.0375, -0.0331,  0.0005],
         ...,
         [ 0.0102,  0.0345, -0.0019,  ..., -0.0366,  0.0031, -0.0028],
         [ 0.0314, -0.0218, -0.0071,  ..., -0.0060,  0.0069,  0.0253],
         [ 0.0123,  0.0341, -0.0019,  ..., -0.0508, -0.0196, -0.0007]],
        device='cuda:0', dtype=torch.float16),
 'vari

In [134]:
#check size of tensors 10 for classes and 512 for clip embeddings of each class
print(len(text_prototypes))
print(text_prototypes["plain"].shape)
print(text_prototypes["short"].shape)
print(text_prototypes["variants"].shape)

3
torch.Size([10, 512])
torch.Size([10, 512])
torch.Size([10, 512])


In [135]:
# encoding images
@torch.no_grad()
def encode_images(model, loader) -> torch.Tensor:
    image_feats = []
    for imgs, _ in tqdm(loader, desc="Encoding images"):
        imgs = imgs.to(device)
        feats = model.encode_image(imgs)
        feats = F.normalize(feats, dim=-1)
        image_feats.append(feats.float().cpu())
    return torch.cat(image_feats, dim=0)  # [N, D]

In [136]:
encoded_images = encode_images(model, test_loader)
encoded_images

Encoding images: 100%|██████████| 250/250 [00:39<00:00,  6.31it/s]


tensor([[-4.4327e-03,  2.2202e-02,  3.0212e-03,  ...,  7.8674e-02,
         -8.2550e-03,  2.9709e-02],
        [-8.5602e-03,  2.7130e-02, -8.6606e-05,  ...,  5.6213e-02,
          1.1093e-02,  1.8326e-02],
        [-1.3489e-02, -8.3971e-04, -3.1853e-03,  ...,  5.8868e-02,
         -2.5574e-02, -1.8509e-02],
        ...,
        [ 3.0746e-02, -6.2065e-03, -2.8885e-02,  ...,  5.6946e-02,
          2.1915e-03, -4.6425e-03],
        [-1.0956e-02, -9.5320e-04, -5.4230e-02,  ...,  6.7566e-02,
         -2.1545e-02,  1.4524e-03],
        [-9.6664e-03, -2.8427e-02,  4.9019e-04,  ...,  4.5563e-02,
          4.0512e-03, -2.1408e-02]])

In [137]:
print(len(encoded_images))
print(encoded_images.shape)

8000
torch.Size([8000, 512])


In [138]:
# Zero-shot accuracy
@torch.no_grad()
def zero_shot_eval(image_feats: torch.Tensor, text_protos: torch.Tensor, labels: np.ndarray) -> float:
    # cosine similarity via dot after L2 norm above
    image_feats=image_feats.to(device).half() # Convert image_feats to half precision
    text_protos=text_protos.to(device)
    logits = image_feats @ text_protos.T  # [N, C]
    print(image_feats.dtype, text_protos.dtype)
    preds = torch.argmax(logits, dim=1).cpu().numpy()
    acc = (preds == labels).mean() * 100.0
    return float(acc), preds

In [139]:
zero_shot_evaluation_plain=zero_shot_eval(encoded_images, text_prototypes['plain'], test.labels)

torch.float16 torch.float16


In [140]:
accuracy_plain , _ = zero_shot_evaluation_plain
print(f"Zero-shot accuracy: {accuracy_plain:.2f}")

Zero-shot accuracy: 95.38


In [141]:
zero_shot_evaluation_short=zero_shot_eval(encoded_images, text_prototypes['short'], test.labels)

torch.float16 torch.float16


In [142]:
accuracy_short , _ = zero_shot_evaluation_short
print(f"Zero-shot accuracy: {accuracy_short:.2f}")

Zero-shot accuracy: 96.73


In [143]:
zero_shot_evaluation_variants=zero_shot_eval(encoded_images, text_prototypes['variants'], test.labels)

torch.float16 torch.float16


In [144]:
accuracy_variants , _ = zero_shot_evaluation_variants
print(f"Zero-shot accuracy: {accuracy_variants:.2f}")

Zero-shot accuracy: 96.61


In [145]:
def evaluate_prompts(model, test_loader, test_ds, text_proto_dict):
    # Encode all test images once
    image_feats = encode_images(model, test_loader)  # [N,D]
    labels = np.array(test_ds.labels)  # STL10 stores test labels in .labels

    results = {}
    for name, protos in text_proto_dict.items():
        acc, preds = zero_shot_eval(image_feats, protos.to(image_feats.device), labels)
        results[name] = {"acc": acc}
        np.save(f"artifacts/preds_{name}.npy", preds)
    np.save("artifacts/image_feats.npy", image_feats.numpy())
    np.save("artifacts/labels.npy", labels)
    with open("artifacts/zero_shot_results.json", "w") as f:
        json.dump(results, f, indent=2)
    print("Zero-shot results:", results)
    return image_feats, labels, results

In [146]:
evaluated_prompts = evaluate_prompts(model, test_loader, test, text_prototypes)

Encoding images: 100%|██████████| 250/250 [00:27<00:00,  9.22it/s]

torch.float16 torch.float16
torch.float16 torch.float16
torch.float16 torch.float16
Zero-shot results: {'plain': {'acc': 95.375}, 'short': {'acc': 96.72500000000001}, 'variants': {'acc': 96.6125}}





In [147]:
@torch.no_grad()
def paired_embeddings_for_subset(image_feats_all: torch.Tensor,
                                 labels: np.ndarray,
                                 text_protos_for_analysis: torch.Tensor,
                                 n_samples: int = 100,
                                 seed: int = 0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    rng = np.random.default_rng(seed)
    idxs = rng.choice(len(labels), size=min(n_samples, len(labels)), replace=False)
    X = image_feats_all[idxs]                         # [k,D]
    Y = text_protos_for_analysis[labels[idxs].astype(int)]        # ground-truth text proto for each sample
    return idxs, X.cpu().numpy(), Y.detach().cpu().numpy()

In [148]:
paired_subset_embeddings=paired_embeddings_for_subset(encoded_images, test.labels, text_prototypes['variants'], n_samples=100, seed=0)

In [149]:
idxs, X, Y = paired_subset_embeddings
print(f"Size of indices: {idxs.size}")
print(f"Shape of image features subset: {X.shape}")
print(f"Shape of text prototypes subset: {Y.shape}")

Size of indices: 100
Shape of image features subset: (100, 512)
Shape of text prototypes subset: (100, 512)


In [150]:
# 2D projections (t-SNE / UMAP)
def project_2d(X: np.ndarray, Y: np.ndarray, method="tsne", seed=0):
    Z = np.concatenate([X, Y], axis=0)
    if method == "umap" and HAVE_UMAP:
        reducer = umap.UMAP(random_state=seed, n_neighbors=15, min_dist=0.1)
        Z2 = reducer.fit_transform(Z)
    else:
        Z2 = TSNE(n_components=2, random_state=seed, init="pca", perplexity=min(30, max(5, len(Z)//10))).fit_transform(Z)
    X2 = Z2[:len(X)]
    Y2 = Z2[len(X):]
    return X2, Y2

In [151]:
project_2d_tsne=project_2d(X, Y, method="tsne", seed=0)
project_2d_umap=project_2d(X, Y, method="umap", seed=0)

  warn(


In [152]:
print(f"Shape of image features subset: {project_2d_tsne[0].shape}")
print(f"Shape of text prototypes subset: {project_2d_tsne[1].shape}")
print(f"Shape of image features subset: {project_2d_umap[0].shape}")
print(f"Shape of text prototypes subset: {project_2d_umap[1].shape}")

Shape of image features subset: (100, 2)
Shape of text prototypes subset: (100, 2)
Shape of image features subset: (100, 2)
Shape of text prototypes subset: (100, 2)


In [153]:
def plot_modalities(X2, Y2, title, fname):
    plt.figure(figsize=(6,6))
    plt.scatter(X2[:,0], X2[:,1], s=16, alpha=0.7, label="image feats")
    plt.scatter(Y2[:,0], Y2[:,1], s=16, alpha=0.7, marker="x", label="text feats")
    plt.legend(); plt.title(title); plt.tight_layout()
    plt.savefig(fname, dpi=150); plt.close()

In [154]:
plot_modalities_tsne = plot_modalities(project_2d_tsne[0], project_2d_tsne[1], "Pre-alignment (t-SNE)", "artifacts/modality_gap_pre_tsne.png")
plot_modalities_umap = plot_modalities(project_2d_tsne[0], project_2d_tsne[1], "Pre-alignment (UMAP)", "artifacts/modality_gap_pre_umap.png")

In [155]:
# Procrustes alignment (orthogonal)
def orthogonal_align(X: np.ndarray, Y: np.ndarray, l2norm_before=True) -> Tuple[np.ndarray, np.ndarray]:
    Xc = X.copy(); Yc = Y.copy()
    if l2norm_before:
        Xc = Xc / (np.linalg.norm(Xc, axis=1, keepdims=True) + 1e-9)
        Yc = Yc / (np.linalg.norm(Yc, axis=1, keepdims=True) + 1e-9)
    R, _ = orthogonal_procrustes(Xc, Yc)  # finds R minimizing ||X R - Y||_F  (R is orthogonal)
    X_aligned = X @ R
    return X_aligned, R

In [156]:
X_aligned, R = orthogonal_align(X, Y, l2norm_before=True)

In [157]:
print(f"Shape of image features subset: {X_aligned.shape}")
print(f"Shape of text prototypes subset: {R.shape}")

Shape of image features subset: (100, 512)
Shape of text prototypes subset: (512, 512)


In [158]:
@torch.no_grad()
def apply_rotation_to_all(image_feats_all: torch.Tensor, R: np.ndarray) -> torch.Tensor:
    A = torch.from_numpy(R).to(image_feats_all.device).float()  # [D,D]
    return (image_feats_all @ A.T).contiguous()

In [159]:
rotated_images = apply_rotation_to_all(encoded_images, R)

In [160]:
print(f"Shape of image features subset: {rotated_images.shape}")

Shape of image features subset: torch.Size([8000, 512])


In [161]:
model.eval()

# (1) Build prompt sets & text prototypes
prompt_sets = prompt_sets()
text_protos = build_text_prototypes(model, prompt_sets)  # dict name -> [C,D]
# Add also a reference text proto bank we’ll use for pairing/visualization (take the strongest = variants)
text_proto_for_analysis = text_protos["variants"]  # [C,D]

image_feats_all, labels, results0 = evaluate_prompts(model, test_loader, test, text_protos)

idxs, X_im, Y_txt = paired_embeddings_for_subset(encoded_images, labels, text_proto_for_analysis, n_samples=100, seed=0)
for meth in ["tsne", "umap"]:
  X2, Y2 = project_2d(X_im, Y_txt, method=meth, seed=0)
  plot_modalities(X2, Y2, f"Pre-alignment ({meth.upper()})", f"artifacts/modality_gap_pre_{meth}.png")
X_aligned_norm, R_norm = orthogonal_align(X_im, Y_txt, l2norm_before=True)
X2a, Y2a = project_2d(X_aligned_norm, Y_txt, method="tsne", seed=0)
plot_modalities(X2a, Y2a, "Post-alignment (t-SNE, L2 before Procrustes)", "artifacts/modality_gap_post_tsne_norm.png")

X_aligned_raw, R_raw = orthogonal_align(X_im, Y_txt, l2norm_before=False)
X2b, Y2b = project_2d(X_aligned_raw, Y_txt, method="tsne", seed=0)
plot_modalities(X2b, Y2b, "Post-alignment (t-SNE, NO L2 before Procrustes)", "artifacts/modality_gap_post_tsne_noL2.png")

# (3f) Recompute zero-shot accuracy with aligned image features (use same text prototypes)
# We apply R learned on subset to ALL image features, then evaluate
img_feats_rot_norm = apply_rotation_to_all(image_feats_all, R_norm)
img_feats_rot_raw  = apply_rotation_to_all(image_feats_all, R_raw)

aligned_results = {}
for name, protos in text_protos.items():
    acc_norm, _ = zero_shot_eval(img_feats_rot_norm, protos.to(img_feats_rot_norm.device), labels)
    acc_raw,  _  = zero_shot_eval(img_feats_rot_raw,  protos.to(img_feats_rot_raw.device),  labels)
    aligned_results[name] = {"acc_procrustes_L2": acc_norm, "acc_procrustes_noL2": acc_raw}

with open("artifacts/zero_shot_aligned_results.json", "w") as f:
    json.dump(aligned_results, f, indent=2)

print("\n=== Baseline zero-shot (%) ===")
for k,v in results0.items(): print(f"{k:10s}: {v['acc']:.2f}")

print("\n=== After Procrustes (train on subset, apply to all) ===")
for k,v in aligned_results.items():
    print(f"{k:10s}: L2 {v['acc_procrustes_L2']:.2f} | noL2 {v['acc_procrustes_noL2']:.2f}")

# Save report
report = {
    "baseline": results0,
    "aligned": aligned_results,
    "notes": [
        "CLIP features are L2-normalized; cosine similarity is dot product.",
        "Prompt ensembling (variants) typically > short > plain.",
        "Procrustes is orthogonal (rotation/reflection); preserves norms and pairwise distances.",
        "L2-normalizing features before Procrustes often yields better alignment (consistent scale)."
    ]
}
with open("artifacts/report_summary.json", "w") as f:
    json.dump(report, f, indent=2)

Encoding images: 100%|██████████| 250/250 [00:26<00:00,  9.61it/s]


torch.float16 torch.float16
torch.float16 torch.float16
torch.float16 torch.float16
Zero-shot results: {'plain': {'acc': 95.375}, 'short': {'acc': 96.72500000000001}, 'variants': {'acc': 96.6125}}


  warn(


torch.float16 torch.float16
torch.float16 torch.float16
torch.float16 torch.float16
torch.float16 torch.float16
torch.float16 torch.float16
torch.float16 torch.float16

=== Baseline zero-shot (%) ===
plain     : 95.38
short     : 96.73
variants  : 96.61

=== After Procrustes (train on subset, apply to all) ===
plain     : L2 59.70 | noL2 64.45
short     : L2 80.99 | noL2 68.19
variants  : L2 80.44 | noL2 70.89
