In [None]:
import os
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
from sklearn.manifold import TSNE

from contrastive_pretrain_resnet50 import SimCLR, get_encoder_resnet50_cifar


# ============================================================
# 1. Frozen Encoder Wrapper
# ============================================================

class FrozenEncoder(nn.Module):
    """Inference-only wrapper around the encoder of a pretrained SimCLR model.

    The full ``SimCLR`` network is instantiated so that the checkpoint's
    state dict can be loaded, but ``forward`` only ever runs the encoder.
    The encoder's parameters are frozen (``requires_grad=False``) and the
    encoder is pinned in eval mode, so the module acts as a fixed feature
    extractor even if a caller invokes ``.train()`` on it.

    Args:
        ckpt_path: Path to a checkpoint containing a ``SimCLR`` state dict.
        device: Device the checkpoint tensors are mapped to and the encoder
            is moved to.
    """

    def __init__(self, ckpt_path, device):
        super().__init__()
        # Load full SimCLR model (encoder + projection) so the checkpoint
        # keys line up with the state dict that was saved.
        self.model = SimCLR()
        # NOTE(review): torch.load unpickles the file. Only load checkpoints
        # from a trusted source (or pass weights_only=True where the
        # installed torch version supports it).
        state = torch.load(ckpt_path, map_location=device)
        self.model.load_state_dict(state)

        # Only the encoder is used in forward(); the projection head stays
        # reachable through self.model but is never called here.
        self.encoder = self.model.encoder.to(device)

        # Actually freeze the encoder: no gradients, eval-mode layers.
        for param in self.encoder.parameters():
            param.requires_grad_(False)
        self.encoder.eval()

    def train(self, mode=True):
        """Switch the wrapper's mode while keeping the encoder in eval mode.

        ``nn.Module.train`` propagates the mode to every child module; the
        encoder is reset to eval afterwards so that calling ``.train()`` on
        this wrapper (or on a parent module) cannot un-freeze it.
        """
        super().train(mode)
        self.encoder.eval()
        return self

    @torch.no_grad()
    def forward(self, x):
        """Return flattened encoder features for a batch ``x``.

        The output has one row per input sample; every dimension after the
        first is flattened into a single feature axis.
        """
        # Original author's note: encoder output is (B, 2048, 1, 1) —
        # TODO confirm against the SimCLR encoder definition.
        h = self.encoder(x)
        h = torch.flatten(h, 1)      # (B, C, 1, 1) -> (B, C)
        return h

In [None]:
# ============================================================
# 2. CIFAR-10 Transform (test-time)
# ============================================================

# Per-channel mean / std passed to Normalize for CIFAR-10 inputs.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Deterministic pipeline (no augmentation): tensor conversion, then
# per-channel normalisation.
test_t = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
    ]
)


# ============================================================
# 3. Load CIFAR-10 Train/Test for Downstream
# ============================================================

def load_cifar(data_dir="./data", batch_size=256, num_workers=0):
    """Build un-shuffled CIFAR-10 train/test loaders for feature extraction.

    Both splits use the deterministic ``test_t`` transform (no augmentation)
    and ``shuffle=False``, so sample order is stable across runs.

    Args:
        data_dir: Root directory of the dataset; it is downloaded there if
            missing.
        batch_size: Number of images per batch for both loaders.
        num_workers: Number of DataLoader worker processes (0 loads in the
            main process).

    Returns:
        ``(train_loader, test_loader)``.
    """
    train_set = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=test_t)
    test_set = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=test_t)

    # Same loader configuration for both splits.
    loader_kwargs = dict(batch_size=batch_size, shuffle=False, num_workers=num_workers)
    train_loader = DataLoader(train_set, **loader_kwargs)
    test_loader = DataLoader(test_set, **loader_kwargs)

    return train_loader, test_loader


# ============================================================
# 4. Extract Embeddings
# ============================================================

def extract_embeddings(encoder, loader, device):
    """Run ``encoder`` over every batch of ``loader`` and collect the outputs.

    The encoder is switched to eval mode and called under ``torch.no_grad()``.
    Each input batch is moved to ``device``; the encoder outputs are moved
    back to the CPU, while labels are collected as yielded by the loader.

    Args:
        encoder: Callable module mapping an input batch to a feature batch.
        loader: Iterable yielding ``(inputs, labels)`` pairs.
        device: Device the inputs are moved to before encoding.

    Returns:
        ``(features, labels)`` — the per-batch tensors concatenated along
        dimension 0, in loader order.
    """
    encoder.eval()

    feature_chunks, label_chunks = [], []
    with torch.no_grad():
        for batch_inputs, batch_labels in loader:
            encoded = encoder(batch_inputs.to(device))
            feature_chunks.append(encoded.cpu())
            label_chunks.append(batch_labels)

    all_features = torch.cat(feature_chunks)
    all_labels = torch.cat(label_chunks)
    return all_features, all_labels


# ============================================================
# 5. Classifier Head (MLP)
# ============================================================

class Classifier(nn.Module):
    """Two-layer MLP head: Linear -> ReLU -> Linear.

    Args:
        in_dim: Size of each input feature vector.
        num_classes: Number of output logits.
        hidden_dim: Width of the hidden layer (256 by default, matching the
            previously hard-coded value).
    """

    def __init__(self, in_dim=2048, num_classes=10, hidden_dim=256):
        super().__init__()
        # Layer order (and therefore state-dict keys net.0 / net.2) is kept
        # stable so existing checkpoints of this head still load.
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Return raw (un-normalised) class logits for ``x``."""
        return self.net(x)


# ============================================================
# 6. Train Classifier
# ============================================================

def _batch_accuracy(logits, targets):
    """Fraction of rows in ``logits`` whose argmax equals ``targets``.

    Computed on whatever device the tensors live on; only the final count is
    transferred to the host. Returns NaN for an empty batch.
    """
    total = targets.numel()
    if total == 0:
        return float("nan")
    correct = (logits.argmax(dim=1) == targets).sum().item()
    return correct / total


def train_classifier(x_train, y_train, x_test, y_test, device, epochs=1000,
                     lr=1e-3, weight_decay=1e-4, log_every=10):
    """Train an MLP ``Classifier`` on precomputed features.

    Training is full-batch: every "epoch" is a single Adam step on the whole
    ``x_train`` tensor, followed by an evaluation pass over the whole
    ``x_test`` tensor.

    Args:
        x_train, y_train: Training features (N, D) and integer class labels (N,).
        x_test, y_test: Held-out features and labels, evaluated every epoch.
        device: Device the classifier and all four tensors are moved to.
        epochs: Number of full-batch optimisation steps.
        lr: Adam learning rate.
        weight_decay: Adam weight-decay coefficient.
        log_every: Print a progress line every ``log_every`` epochs; a falsy
            value (0 / None) disables progress output.

    Returns:
        ``(clf, history)`` where ``history`` maps ``"train_loss"``,
        ``"test_loss"``, ``"train_acc"`` and ``"test_acc"`` to per-epoch lists.
    """
    clf = Classifier(in_dim=x_train.size(1)).to(device)

    opt = torch.optim.Adam(clf.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = nn.CrossEntropyLoss()

    x_train, y_train = x_train.to(device), y_train.to(device)
    x_test, y_test = x_test.to(device), y_test.to(device)

    # ----------------------------------------
    # STORE METRICS
    # ----------------------------------------
    history = {
        "train_loss": [],
        "test_loss": [],
        "train_acc": [],
        "test_acc": [],
    }

    for ep in range(epochs):
        # -----------------------
        # TRAIN FOR ONE EPOCH
        # -----------------------
        clf.train()
        logits = clf(x_train)
        loss = loss_fn(logits, y_train)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Metrics come from the pre-step forward pass; accuracy is computed
        # on-device so labels are not copied to the CPU every epoch.
        train_loss = loss.item()
        train_acc = _batch_accuracy(logits.detach(), y_train)

        # -----------------------
        # TEST EVAL THIS EPOCH
        # -----------------------
        clf.eval()
        with torch.no_grad():
            test_logits = clf(x_test)
            test_loss = loss_fn(test_logits, y_test).item()
            test_acc = _batch_accuracy(test_logits, y_test)

        # ----------------------------------------
        # STORE IN HISTORY
        # ----------------------------------------
        history["train_loss"].append(train_loss)
        history["test_loss"].append(test_loss)
        history["train_acc"].append(train_acc)
        history["test_acc"].append(test_acc)

        if log_every and (ep + 1) % log_every == 0:
            print(
                f"[Epoch {ep+1}/{epochs}] "
                f"TrainLoss={train_loss:.4f} TestLoss={test_loss:.4f} "
                f"TrainAcc={train_acc:.4f} TestAcc={test_acc:.4f}"
            )

    return clf, history

In [None]:
# ============================================================
# 7. Evaluation: Accuracy, Precision, Recall, F1, AUC
# ============================================================

def evaluate_classifier(clf, x_test, y_test, device):
    """Compute and print accuracy, macro precision/recall/F1 and OvR AUC.

    Args:
        clf: Trained classifier returning class logits.
        x_test: Feature tensor; moved to ``device`` before the forward pass.
        y_test: Integer label tensor (any device).
        device: Device on which ``clf`` is evaluated.

    Returns:
        ``(acc, prec, rec, f1, auc)``. ``auc`` is ``None`` when
        ``roc_auc_score`` cannot be computed for the given labels/scores.
    """
    clf.eval()
    with torch.no_grad():
        logits = clf(x_test.to(device))
        probs = torch.softmax(logits, dim=1).cpu().numpy()
    preds = probs.argmax(axis=1)

    # Labels may live on an accelerator; bring them to host memory first.
    y_true = y_test.detach().cpu().numpy()

    acc = accuracy_score(y_true, preds)
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, preds, average="macro")

    # AUC is best-effort: keep the remaining metrics if it cannot be
    # computed, but report why instead of hiding the failure, and do not
    # intercept KeyboardInterrupt / SystemExit.
    try:
        auc = roc_auc_score(y_true, probs, multi_class="ovr")
    except Exception as err:
        print(f"[warn] AUC could not be computed: {err}")
        auc = None

    print("\n===== DOWNSTREAM RESULTS =====")
    print(f"Accuracy:  {acc:.4f}")
    print(f"Precision: {prec:.4f}")
    print(f"Recall:    {rec:.4f}")
    print(f"F1-score:  {f1:.4f}")
    print(f"AUC:       {auc:.4f}" if auc is not None else "AUC: N/A")

    return acc, prec, rec, f1, auc

In [None]:
# ============================================================
# 8. Embedding Visualization (optional)
# ============================================================

def visualize_embeddings(x, y, out_path="tsne_embeddings.png", show=True, random_state=42):
    """Project embeddings to 2-D with t-SNE and save a scatter plot.

    Args:
        x: Tensor of embeddings, one row per sample.
        y: Tensor of integer labels used to colour the points.
        out_path: File the figure is written to.
        show: If true, also display the figure via ``plt.show()``.
        random_state: Seed forwarded to ``TSNE`` so the layout is
            reproducible; pass ``None`` for an unseeded run.
    """
    # Accept tensors from any device, with or without grad tracking.
    x = x.detach().cpu().numpy()
    y = y.detach().cpu().numpy()

    tsne = TSNE(n_components=2, init='pca', learning_rate='auto', random_state=random_state)
    z = tsne.fit_transform(x)

    fig, ax = plt.subplots(figsize=(8, 6))
    sc = ax.scatter(z[:, 0], z[:, 1], c=y, cmap='tab10', s=4)
    fig.colorbar(sc, ax=ax, ticks=range(10))
    ax.set_title("t-SNE of Frozen Encoder Embeddings")

    # Save the figure
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)

    if show:
        plt.show()

    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

    print(f"[t-SNE] Saved embedding visualization to {out_path}")

In [None]:
# ============================================================
# 9. Plot Learning Curves
# ============================================================
def plot_curves(history, out_path="downstream_curves.png"):
    """Save a side-by-side figure of the loss and accuracy curves.

    Args:
        history: Mapping with per-epoch lists under ``"train_loss"``,
            ``"test_loss"``, ``"train_acc"`` and ``"test_acc"``.
        out_path: File the figure is written to.
    """
    # One entry per panel: (y-axis label, title, ((history key, legend label), ...)).
    panels = (
        ("Loss", "Loss Curves",
         (("train_loss", "Train Loss"), ("test_loss", "Test Loss"))),
        ("Accuracy", "Accuracy Curves",
         (("train_acc", "Train Accuracy"), ("test_acc", "Test Accuracy"))),
    )

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, (y_label, title, series) in zip(axes, panels):
        for key, legend_label in series:
            ax.plot(history[key], label=legend_label)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()

    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)

    print(f"[Plot] Saved learning curves to: {out_path}")

In [None]:
# ============================================================
# 10. Main Pipeline
# ============================================================

def main(ckpt="simclr_resnet50_multicrop.pt", data_dir="./data", seed=0):
    """Run the downstream evaluation pipeline end to end.

    Steps: load the frozen encoder, extract CIFAR-10 train/test embeddings,
    train the MLP head, plot learning curves, print test metrics and save a
    t-SNE plot of the test embeddings.

    Args:
        ckpt: Path to the pretrained SimCLR checkpoint.
        data_dir: Root directory for the CIFAR-10 download.
        seed: Seed applied to the torch and NumPy global RNGs so classifier
            initialisation (and unseeded NumPy-based steps) are repeatable.

    Raises:
        FileNotFoundError: If ``ckpt`` does not exist.
    """
    # Fail fast with a clear message before building the model.
    if not os.path.isfile(ckpt):
        raise FileNotFoundError(f"Pretrained checkpoint not found: {ckpt}")

    torch.manual_seed(seed)
    np.random.seed(seed)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    print("Loading frozen encoder...")
    encoder = FrozenEncoder(ckpt, device)

    print("Loading CIFAR-10 for downstream...")
    train_loader, test_loader = load_cifar(data_dir)

    print("Extracting embeddings...")
    x_train, y_train = extract_embeddings(encoder, train_loader, device)
    x_test, y_test = extract_embeddings(encoder, test_loader, device)

    print("Training classifier...")
    clf, history = train_classifier(x_train, y_train, x_test, y_test, device)
    plot_curves(history)

    print("Evaluating classifier...")
    evaluate_classifier(clf, x_test, y_test, device)

    # Optional visualization
    visualize_embeddings(x_test, y_test, out_path="tsne_test_embeddings1.png")


if __name__ == "__main__":
    main()