In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import clip  # pip install git+https://github.com/openai/CLIP.git


In [2]:
# ── Hyperparameters ─────────────────────────────────────────────────────────────
batch_size   = 128       # training batch size
epochs       = 10        # number of epochs to train
lr           = 2e-5      # learning rate
weight_decay = 1e-4      # optimizer weight decay
noise_prob   = 0.5       # probability [0–1] to corrupt each label
# ────────────────────────────────────────────────────────────────────────────────

In [3]:
def corrupt_labels(labels: torch.Tensor, num_classes: int, noise_prob: float) -> torch.Tensor:
    """
    With probability noise_prob, replace each label with a random class in [0, num_classes-1].
    """
    if noise_prob <= 0:
        return labels
    mask = torch.rand(labels.shape, device=labels.device) < noise_prob
    random_labels = torch.randint(num_classes, labels.shape, device=labels.device)
    return torch.where(mask, random_labels, labels)

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load CLIP and get its vision encoder
model, _ = clip.load("ViT-B/32", device=device)
model = model.float().to(device)
vision   = model.visual
embed_dim  = vision.output_dim   # typically 512
num_classes = 10

# New classification head
classifier = nn.Linear(embed_dim, num_classes).to(device)

# Data transforms (resize to 224×224 + normalize with CLIP stats)
normalize = transforms.Normalize(
    mean=(0.48145466, 0.4578275, 0.40821073),
    std=(0.26862954, 0.26130258, 0.27577711)
)
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize,
])

# Datasets & loaders
train_ds = torchvision.datasets.CIFAR10("data", train=True,  download=True, transform=train_transform)
test_ds  = torchvision.datasets.CIFAR10("data", train=False, download=True, transform=test_transform)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  num_workers=4, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Optimizer & loss
params    = list(vision.parameters()) + list(classifier.parameters())
optimizer = optim.AdamW(params, lr=lr, weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()

for epoch in range(1, epochs + 1):
    # — Training —
    vision.train(); classifier.train()
    running_loss = 0.0
    correct = total = 0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)

        # inject label noise
        noisy_labels = corrupt_labels(labels, num_classes, noise_prob)

        optimizer.zero_grad()
        feats  = vision(imgs)
        logits = classifier(feats)
        loss   = criterion(logits, noisy_labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * imgs.size(0)
        preds = logits.argmax(dim=1)
        # measure accuracy against the true labels
        correct += (preds == labels).sum().item()
        total   += labels.size(0)

    train_loss = running_loss / total
    train_acc  = correct / total

    # — Evaluation —
    vision.eval(); classifier.eval()
    correct = total = 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            feats  = vision(imgs)
            logits = classifier(feats)
            preds  = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total   += labels.size(0)

    test_acc = correct / total

    print(f"Epoch {epoch:2d} | "
            f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.3f} | "
            f"Test Acc: {test_acc:.3f}")

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same