In [None]:
import os
import math
import random
import pickle

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm.auto import tqdm
import clip
from sklearn.model_selection import train_test_split

In [None]:
SEED = 42

# Seed every global RNG this notebook draws from: Python's `random` (used by
# CIFARDataset's label corruption), NumPy, and PyTorch.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [None]:
# ── 1. CIFAR‑10 batch loading ────────────────────────────────────────────────
# Class names in label-index order: label i is rendered as CIFAR_CLASSES[i].
CIFAR_CLASSES = (
    "airplane automobile bird cat deer dog frog horse ship truck"
).split()

def unpickle_file(path):
    """Load one pickled CIFAR-10 batch file and return the unpickled object.

    ``encoding="bytes"`` keeps byte-string keys as ``bytes`` (``load_cifar``
    reads the result with ``b"data"`` / ``b"labels"``).

    NOTE: ``pickle`` can execute arbitrary code while loading -- only point
    this at trusted files.
    """
    with open(path, mode="rb") as handle:
        batch = pickle.load(handle, encoding="bytes")
    return batch

def load_cifar(root, batch_names=None):
    """Load and concatenate CIFAR-10 python batch files from ``root``.

    Args:
        root: directory holding the pickled batch files.
        batch_names: optional iterable of file names to load, in order.
            Defaults to ``data_batch_1`` .. ``data_batch_5`` followed by
            ``test_batch`` (all 60,000 images).

    Returns:
        ``(X, y)``: ``X`` is a uint8 array with one flattened 3x32x32 image per
        row, shape ``(N, 3072)``; ``y`` is an int64 array of shape ``(N,)``.

    Raises:
        ValueError: if no batch names are given, or the number of images and
            labels read from the files disagree.
    """
    if batch_names is None:
        batch_names = [f"data_batch_{i}" for i in range(1, 6)] + ["test_batch"]
    batch_names = list(batch_names)
    if not batch_names:
        raise ValueError("load_cifar: no batch files requested")

    xs, ys = [], []
    for batch_name in batch_names:
        data = unpickle_file(os.path.join(root, batch_name))
        xs.append(data[b"data"])
        ys.extend(data[b"labels"])
    # copy=False: skip a second full-size copy when vstack already yields uint8
    X = np.vstack(xs).astype(np.uint8, copy=False)
    y = np.array(ys, dtype=np.int64)
    if len(X) != len(y):
        raise ValueError(
            f"load_cifar: {len(X)} images but {len(y)} labels under {root!r}"
        )
    return X, y

In [None]:
# ── 2. Tokenizer for CLIP text ───────────────────────────────────────────────
_tokenizer = clip.simple_tokenizer.SimpleTokenizer()

def tokenize(texts, context_length: int = 77) -> torch.LongTensor:
    """Encode one string or an iterable of strings into CLIP token-id rows.

    Each row is ``<|startoftext|>`` + BPE ids + ``<|endoftext|>``, zero-padded
    to ``context_length``. A sequence longer than ``context_length`` is cut to
    that length and its last slot is overwritten with ``<|endoftext|>`` so the
    row keeps its EOT marker.

    Returns:
        ``torch.long`` tensor of shape ``(number_of_texts, context_length)``.
    """
    batch = [texts] if isinstance(texts, str) else texts
    start_id = _tokenizer.encoder["<|startoftext|>"]
    end_id = _tokenizer.encoder["<|endoftext|>"]
    encoded = [[start_id, *_tokenizer.encode(text), end_id] for text in batch]

    out = torch.zeros(len(encoded), context_length, dtype=torch.long)
    for row, ids in enumerate(encoded):
        if len(ids) > context_length:
            ids = ids[:context_length]
            ids[-1] = end_id
        out[row, :len(ids)] = torch.tensor(ids)
    return out

# ── 3. Dataset with optional label corruption ────────────────────────────────
class CIFARDataset(Dataset):
    """CIFAR-10 images paired with tokenized CLIP prompts, with optional label noise.

    Each item is ``(image, text)``: ``image`` is ``transform`` applied to the
    PIL image built from row ``X[indices[idx]]``, and ``text`` is the token row
    (length 77, int64) for the prompt ``"a photo of a <class>"``.

    Args:
        X: array of flattened images; each row is reshaped to ``(3, 32, 32)``
            and transposed to height x width x channel.
        y: integer class labels indexing ``CIFAR_CLASSES``.
        indices: positions into ``X`` / ``y`` forming this split.
        transform: callable applied to the PIL image.
        prob: probability that an item's prompt uses the label of a uniformly
            random sample from this split instead of its own label. The random
            label can coincide with the true one, so the effective flip rate is
            a bit below ``prob``. Noise is redrawn on every access, so it
            differs between epochs.
    """

    def __init__(self, X, y, indices, transform, prob=0.0):
        self.X = X
        self.y = y
        self.indices = indices
        self.transform = transform
        self.prob = prob
        # Only len(CIFAR_CLASSES) distinct prompts exist: tokenize them once
        # instead of running the BPE tokenizer on every __getitem__ call.
        self.class_tokens = tokenize(
            [f"a photo of a {name}" for name in CIFAR_CLASSES]
        )

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        k = self.indices[idx]
        img_arr = self.X[k].reshape(3, 32, 32).transpose(1, 2, 0)
        # A uint8 (H, W, 3) array is inferred as RGB; Image.fromarray's `mode`
        # argument is deprecated in recent Pillow releases.
        img = Image.fromarray(img_arr)

        # label corruption with probability prob
        if random.random() < self.prob:
            label = self.y[random.choice(self.indices)]
        else:
            label = self.y[k]

        img = self.transform(img)
        # clone so the returned row does not alias the shared token cache
        text = self.class_tokens[int(label)].clone()
        return img, text

In [None]:
# ── 4. Data split, CLIP model & DataLoaders ──────────────────────────────────
# Set the CIFAR_ROOT environment variable instead of editing this
# user-specific absolute path.
cifar_root = os.environ.get(
    "CIFAR_ROOT", "/home/hice1/asubramanian91/scratch/cifar-10-batches-py"
)
X, y = load_cifar(cifar_root)

# Stratified 80/20 split over every image load_cifar returned.
train_idx, test_idx = train_test_split(
    np.arange(len(X)), test_size=0.2, random_state=SEED, stratify=y
)

# load CLIP and its preprocessing; cast to fp32 for full-precision training
# (clip.load presumably returns fp16 weights on GPU -- verify for your version).
model, preprocess_clip = clip.load("ViT-B/32", device=device, jit=False)
model = model.to(device).float()

# Training-only random augmentations, applied before CLIP's preprocess.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
    preprocess_clip,
])
# Evaluation must be deterministic: CLIP's preprocessing without the random
# flip / colour jitter, so validation loss is comparable across epochs.
eval_transform = preprocess_clip

batch_size = 256
prob = 0.05  # label corruption probability (training split only)

train_ds = CIFARDataset(X, y, train_idx, transform, prob=prob)
test_ds  = CIFARDataset(X, y, test_idx,  eval_transform, prob=0.0)

dltrain = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
dltest  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, num_workers=4)


In [None]:
# ── 5. Optimizer & Scheduler & Loss ──────────────────────────────────────────
# Head = the two projection matrices plus the learnable temperature; all other
# parameters form the pretrained backbone. In OpenAI CLIP (clip.load) the image
# projection is registered as "visual.proj", so match full parameter names.
head_names = {"visual.proj", "text_projection", "logit_scale"}

head_params, backbone_params = [], []
for name, p in model.named_parameters():
    if name in head_names:
        head_params.append(p)
    else:
        backbone_params.append(p)

def _split_by_decay(params):
    """Split params into (decayed, exempt); tensors with ndim < 2 are exempt."""
    return [p for p in params if p.ndim >= 2], [p for p in params if p.ndim < 2]

# Biases, norm gains and the 0-d logit_scale get no weight decay; decaying
# logit_scale would keep pulling the learned temperature parameter toward 0.
head_decay, head_no_decay = _split_by_decay(head_params)
backbone_decay, backbone_no_decay = _split_by_decay(backbone_params)

param_groups = [
    {"params": head_decay,        "lr": 1e-4, "weight_decay": 0.2},
    {"params": head_no_decay,     "lr": 1e-4, "weight_decay": 0.0},
    {"params": backbone_decay,    "lr": 3e-5, "weight_decay": 0.2},
    {"params": backbone_no_decay, "lr": 3e-5, "weight_decay": 0.0},
]
optim = torch.optim.AdamW(
    [group for group in param_groups if group["params"]],
    betas=(0.9, 0.98), eps=1e-8
)

n_epochs = 20
warmup_steps = 500
total_steps = n_epochs * len(dltrain)

def lr_lambda(step):
    """LR multiplier for LambdaLR.

    Rises linearly from 1/warmup_steps to 1 over the first ``warmup_steps``
    steps, then follows a half cosine down to 0 at ``total_steps``. Past
    ``total_steps`` it stays at 0 rather than climbing back up the cosine.
    """
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    return 0.5 * (1 + math.cos(math.pi * progress))

criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lr_lambda)

# Reset the learnable temperature to 0.07 (logit_scale = log(1/0.07)) and keep
# it inside [0, log(100)]. This overwrites the value loaded with the checkpoint.
with torch.no_grad():
    model.logit_scale.fill_(math.log(1 / 0.07))
    model.logit_scale.clamp_(min=0, max=math.log(100))

model = model.to(device)

In [None]:
EPS = 1e-6                    # eps passed to F.normalize
LOGIT_CLAMP = math.log(100)   # upper bound applied to model.logit_scale


class RollingMean:
    """Running arithmetic mean of the values fed to ``update``.

    Call the instance to read the current mean (0.0 before any update).
    """

    def __init__(self):
        self.count = 0
        self.avg = 0.0

    def update(self, v):
        seen = self.count
        self.avg = (self.avg * seen + v) / (seen + 1)
        self.count = seen + 1

    def __call__(self):
        return self.avg

def clamped_logit_scale():
    """Clamp model.logit_scale in place to [0, LOGIT_CLAMP] and return its exp()."""
    with torch.no_grad():
        model.logit_scale.clamp_(0, LOGIT_CLAMP)
    return model.logit_scale.exp()


def contrastive_loss(image_emb, text_emb, text_tokens, logit_scale):
    """Symmetric CLIP-style contrastive loss with identical captions as positives.

    A batch holds at most len(CIFAR_CLASSES) distinct prompts. With plain
    ``arange(B)`` targets, two samples sharing a caption would be scored as a
    negative pair even though their text inputs are identical. That objective
    is contradictory and its loss has a non-zero floor. Here every image/text
    pair whose caption tokens match is a positive, and the target
    distribution is uniform over those positives.

    Args:
        image_emb: (B, D) image embeddings, not yet normalized.
        text_emb: (B, D) text embeddings, not yet normalized.
        text_tokens: (B, L) token ids that produced ``text_emb``.
        logit_scale: scalar multiplier for the cosine similarities.

    Returns:
        Scalar loss tensor. Class-probability targets for CrossEntropyLoss need
        torch >= 1.10.
    """
    image_emb = F.normalize(image_emb, dim=1, eps=EPS)
    text_emb = F.normalize(text_emb, dim=1, eps=EPS)
    logits = image_emb @ text_emb.T * logit_scale  # rows: images, cols: texts

    same_caption = (text_tokens[:, None, :] == text_tokens[None, :, :]).all(dim=-1)
    positives = same_caption.to(logits.dtype)
    targets = positives / positives.sum(dim=1, keepdim=True)
    # `positives` is symmetric, so the same targets serve both directions.
    return 0.5 * (criterion(logits, targets) + criterion(logits.T, targets))


loss_history = []

for epoch in range(1, n_epochs + 1):
    # Training
    model.train()
    rm = RollingMean()
    with tqdm(dltrain, desc=f"Epoch {epoch}/{n_epochs}") as bar:
        for images, texts in bar:
            images, texts = images.to(device), texts.to(device)

            I_e = model.encode_image(images)
            T_e = model.encode_text(texts)
            loss = contrastive_loss(I_e, T_e, texts, clamped_logit_scale())

            # Fail loudly: a `break` would only leave this batch loop, and the
            # notebook would go on to validate and keep training a broken model.
            if not torch.isfinite(loss):
                raise RuntimeError(
                    f"Non-finite loss in epoch {epoch}: loss={loss.item()}, "
                    f"image emb finite={bool(torch.isfinite(I_e).all())}, "
                    f"text emb finite={bool(torch.isfinite(T_e).all())}, "
                    f"logit_scale={model.logit_scale.item()}"
                )

            optim.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optim.step()
            scheduler.step()

            rm.update(loss.item())
            bar.set_postfix(train_loss=f"{rm():.4f}")

    # Validation
    model.eval()
    val_losses = []
    with torch.no_grad(), tqdm(dltest, desc="Val  ", leave=False) as bar:
        for images, texts in bar:
            images, texts = images.to(device), texts.to(device)
            val_loss = contrastive_loss(
                model.encode_image(images),
                model.encode_text(texts),
                texts,
                clamped_logit_scale(),
            )
            val_losses.append(val_loss.item())

    mean_val = float(np.mean(val_losses))
    loss_history.append(mean_val)
    print(f"   ✦ Validation loss: {mean_val:.4f}")

print("Done! Validation losses:", loss_history)