# In [None]:  (cell marker left over from the notebook export)
# Filename: cs771_mp2_lwp_pipeline.py
# Put this in a notebook (.ipynb) or run as a script after adjusting paths.

import os
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms, datasets, models
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from sklearn.decomposition import TruncatedSVD
from sklearn.metrics import accuracy_score
import joblib
import pickle

# -------------------------
# Config / Hyperparameters
# -------------------------
# Root folder; the task runners read "<DATA_ROOT>/D<i>" and "<DATA_ROOT>/heldout/D_hat<i>".
DATA_ROOT = "data"   # where you put downloaded D1..D20 and heldout folders
# Number of prototypes (one per class) kept by the classifier.
NUM_CLASSES = 10
# Batch size handed to the DataLoader during feature extraction.
BATCH_SIZE = 128
FEAT_BACKBONE = "resnet18"   # options: resnet18, resnet50
# When True, run_task1 fits a TruncatedSVD on the D1 features and projects all features with it.
USE_PCA = True               # reduce feature dim (stabilizes prototype distances)
# Number of components kept by that projection.
PCA_DIM = 256
# Use the GPU when torch reports one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Task1 (same distribution) hyperparams (more aggressive pseudo-labeling)
# Minimum softmax confidence a pseudo-labeled sample needs to contribute to an update.
TASK1_CONF_THRESH = 0.7
TASK1_ALPHA = 0.25   # EMA update weight for prototypes
# Task2 (different distributions) hyperparams (conservative)
# Used both as the per-sample confidence threshold and as the per-class mean-confidence cut-off.
TASK2_CONF_THRESH = 0.9
# Smaller EMA weight than Task1, so each update moves the prototypes less.
TASK2_ALPHA = 0.08
TASK2_DISTANCE_FACTOR = 1.5  # allows update only if mean distance <= factor * class_std

# -------------------------
# Utilities: image loader
# -------------------------
# Per-channel mean/std fed to Normalize below; presumably the standard ImageNet RGB
# statistics matching the pretrained backbone weights (going by the names) - confirm.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std  = [0.229, 0.224, 0.225]
# Preprocessing applied to every image before it reaches the backbone:
# resize, take a 224x224 centre crop, convert to a tensor, then normalise per channel.
transform = transforms.Compose([
    transforms.Resize(224),            # backbone expects larger images than CIFAR
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
])

def list_image_files(folder):
    """Return the sorted full paths of the image files directly inside *folder*.

    The listing is non-recursive. An entry is kept when its extension,
    compared case-insensitively, is one of ``.png``, ``.jpg``, ``.jpeg`` or
    ``.bmp`` AND it is a regular file (or a symlink to one). Directories whose
    names merely look like image files are skipped so they never reach the
    image loader.

    Raises whatever ``os.listdir`` raises when *folder* cannot be listed.
    """
    exts = {".png", ".jpg", ".jpeg", ".bmp"}
    files = []
    for name in os.listdir(folder):
        if os.path.splitext(name.lower())[1] not in exts:
            continue
        full_path = os.path.join(folder, name)
        # A sub-directory called e.g. "thumbs.png" is not an image.
        if os.path.isfile(full_path):
            files.append(full_path)
    files.sort()
    return files

class GenericImageDataset(Dataset):
    """Image dataset that handles labeled (class sub-folders) and unlabeled (flat) folders.

    Layout detection:
      * ``labeled=True`` and at least one immediate sub-folder contains image
        files: every immediate sub-folder becomes a class, indexed by its
        position in the sorted list of sub-folder names.
      * otherwise the images directly inside ``root_folder`` are used, each
        with label ``-1``, and ``self.labeled`` is set to ``False``.

    Attributes:
        root: the folder passed in.
        transform: optional callable applied to each loaded PIL image.
        labeled: whether real class labels were found.
        class_to_idx: folder-name -> class-index map; empty for unlabeled data.
        samples: list of ``(file_path, label)`` tuples.
    """

    def __init__(self, root_folder, transform=None, labeled=True):
        self.root = root_folder
        self.transform = transform
        self.labeled = labeled
        # Always defined so callers can rely on the attribute; stays empty
        # when no class sub-folders are used.
        self.class_to_idx = {}

        classes = []
        if labeled:
            classes = sorted(d for d in os.listdir(root_folder)
                             if os.path.isdir(os.path.join(root_folder, d)))
        # The folder counts as labeled only if some sub-folder actually holds images.
        has_classes = any(list_image_files(os.path.join(root_folder, c))
                          for c in classes)

        if has_classes:
            self.class_to_idx = {c: i for i, c in enumerate(classes)}
            self.samples = [(f, self.class_to_idx[c])
                            for c in classes
                            for f in list_image_files(os.path.join(root_folder, c))]
        else:
            # Either labels were not requested or no class sub-folders exist:
            # fall back to the flat, unlabeled layout.
            self.labeled = False
            self.samples = [(f, -1) for f in list_image_files(root_folder)]

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label, file_path)`` for sample ``idx``.

        ``image`` is the RGB-converted picture, passed through ``transform``
        when one was given; ``label`` is ``-1`` for unlabeled samples.
        """
        fpath, lbl = self.samples[idx]
        # Close the underlying file deterministically instead of relying on
        # the image object being garbage-collected; convert() returns a new
        # image that stays usable after the source is closed.
        with Image.open(fpath) as raw:
            img = raw.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, lbl, fpath

# -------------------------
# Feature extractor
# -------------------------
def build_backbone(name="resnet18", device="cpu"):
    """Build a pretrained ResNet feature extractor with its classification head removed.

    Args:
        name: ``"resnet18"`` or ``"resnet50"``.
        device: device the network is moved to.

    Returns:
        ``(feat_net, feat_dim)``: the network without its last child module
        (the ``fc`` layer), in eval mode on ``device``, and the input width
        of the removed ``fc`` layer.

    Raises:
        ValueError: if ``name`` is not one of the supported backbones.
    """
    if name not in ("resnet18", "resnet50"):
        raise ValueError("Unknown backbone")
    # The constructor is looked up by the architecture's own name.
    # NOTE(review): newer torchvision releases reportedly prefer `weights=`
    # over `pretrained=` - confirm against the installed version.
    network = getattr(models, name)(pretrained=True)
    feat_dim = network.fc.in_features
    # Keep every child except the final one, i.e. drop the fc head.
    trunk_layers = list(network.children())[:-1]
    feat_net = nn.Sequential(*trunk_layers)
    feat_net.to(device).eval()
    return feat_net, feat_dim

def extract_features_for_folder(folder, backbone, batch_size=128, transform=transform, labeled=True, device="cpu"):
    """Run ``backbone`` over every image in ``folder`` and collect flat feature vectors.

    Args:
        folder: directory handed to ``GenericImageDataset``.
        backbone: callable mapping an image batch to a feature tensor.
        batch_size: DataLoader batch size.
        transform: preprocessing applied to each image.
        labeled: whether to look for class sub-folders.
        device: device the image batches are moved to before the forward pass.

    Returns:
        ``(feats, labels, paths)``: ``feats`` is an ``(N, D)`` numpy array with
        each sample's backbone output flattened to one row, ``labels`` is a
        length-``N`` numpy array (``-1`` for unlabeled samples) and ``paths``
        is the list of source file paths, all in dataset order.

    Raises:
        ValueError: if the folder yields no images (previously this surfaced
            as an opaque ``np.vstack`` ValueError after the loop).
    """
    ds = GenericImageDataset(folder, transform=transform, labeled=labeled)
    if len(ds) == 0:
        raise ValueError(f"No image files found in {folder!r}; cannot extract features.")
    # Pinned host memory is only requested when the batches go to a CUDA
    # device. `device` may be a string or a torch.device, hence the str().
    use_cuda = str(device).startswith("cuda")
    dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=use_cuda)
    feats = []
    labels = []
    paths = []
    with torch.no_grad():
        for imgs, lbls, fpaths in tqdm(dl, desc=f"Feat extract {os.path.basename(folder)}"):
            imgs = imgs.to(device)
            out = backbone(imgs)
            # Flatten everything after the batch axis into one feature row per sample.
            out = out.reshape(out.size(0), -1).cpu().numpy()
            feats.append(out)
            labels.extend([int(x) for x in lbls])
            paths.extend(fpaths)
    feats = np.vstack(feats)
    labels = np.array(labels)   # -1 for unlabeled
    return feats, labels, paths

# -------------------------
# LwP classifier
# -------------------------
def softmax_rows(x):
    """Return the softmax of each row of the 2-D array ``x``.

    Each row is shifted by its own maximum before exponentiation, so large
    scores do not overflow.
    """
    row_peak = x.max(axis=1, keepdims=True)
    exps = np.exp(x - row_peak)
    normaliser = exps.sum(axis=1, keepdims=True)
    return exps / normaliser

class LwPClassifier:
    """
    Prototype-based classifier.
    prototypes: numpy array (C, D)
    counts: pseudo-counts used for Bayesian mean update (keeps model param fixed size)
    """
    def __init__(self, num_classes, feat_dim, use_prob=False):
        self.C = num_classes
        self.D = feat_dim
        self.use_prob = use_prob
        self.prototypes = np.zeros((self.C, self.D), dtype=np.float32)
        self.counts = np.zeros((self.C,), dtype=np.float32) + 1e-6
        # For probabilistic variant keep diagonal variances
        if self.use_prob:
            self.var = np.ones((self.C, self.D), dtype=np.float32)

    def fit(self, feats, labels):
        """Fit prototypes from labeled data (feats: N x D, labels: N)."""
        for c in range(self.C):
            idx = np.where(labels == c)[0]
            if len(idx) > 0:
                self.prototypes[c] = feats[idx].mean(axis=0)
                self.counts[c] = len(idx)

    def predict_proba(self, feats, temperature=1.0):
        """Return class probabilities via negative Euclidean distances + softmax."""
        # feats: N x D
        # compute squared euclid dist: (a-b)^2 = a^2 + b^2 - 2ab
        a2 = np.sum(feats * feats, axis=1, keepdims=True)   # N x 1
        b2 = np.sum(self.prototypes * self.prototypes, axis=1)   # C
        ab = feats.dot(self.prototypes.T)   # N x C
        d2 = a2 + b2 - 2*ab   # N x C
        scores = -np.sqrt(np.maximum(d2, 0.0)) / max(1e-8, temperature)  # negative distance
        probs = softmax_rows(scores)
        return probs, d2  # returns distances also for optional gating

    def predict(self, feats):
        probs, d2 = self.predict_proba(feats)
        preds = probs.argmax(axis=1)
        confs = probs.max(axis=1)
        return preds, confs, d2

    def update_with_pseudo(self, feats, preds, confs, distances=None,
                           conf_thresh=0.7, alpha=0.2, conservative_mask=None,
                           distance_factor=1.5):
        """
        Updates prototypes using pseudo-labeled samples.
        - Only samples with conf >= conf_thresh are used.
        - alpha controls EMA (prototype <- (1-alpha)*proto + alpha*mean_new )
        - conservative_mask (optional): boolean per-class specifying whether to be extra conservative
        - distance_factor: for conservative mode, only update if mean distance <= distance_factor * class_std
        """
        # compute per-class updates
        for c in range(self.C):
            idx = np.where(preds == c)[0]
            if len(idx) == 0:
                continue
            # select confident subset
            idx_conf = idx[confs[idx] >= conf_thresh]
            if len(idx_conf) == 0:
                continue
            # compute mean of confident features
            new_mean = feats[idx_conf].mean(axis=0)
            # optional conservative gating by distance: require new_mean close to prototype
            if distances is not None:
                # distances provided as d2 (N x C)
                mean_dist = np.sqrt(np.mean(distances[idx_conf, c]))
                # compute per-class std from current prototypes via distances to proto - approximate
                # approximate class std: sqrt(mean of squared dist for samples assigned to class using current proto)
                # we use counts to estimate stability; fallback to small value if counts low
                class_std = 0.0
                if self.counts[c] > 1:
                    # approximate by computing prot-dist for prior pseudo-samples - we don't have them; use a heuristic
                    class_std = max(1.0, np.sqrt(np.mean(distances[:, c])))  # fallback heuristic
                else:
                    class_std = max(1.0, mean_dist)
                if mean_dist > distance_factor * class_std:
                    # skip update as cluster is too far -> likely wrong pseudo-labels
                    continue
            # update via EMA while preserving proto shape (keeps parameter count same)
            self.prototypes[c] = (1.0 - alpha) * self.prototypes[c] + alpha * new_mean
            # bump counts (we do not access original labeled data; counts are part of model params)
            self.counts[c] += len(idx_conf)

    def save(self, path):
        with open(path, "wb") as f:
            pickle.dump({"prototypes": self.prototypes, "counts": self.counts}, f)

    def load(self, path):
        with open(path, "rb") as f:
            d = pickle.load(f)
        self.prototypes = d["prototypes"]
        self.counts = d["counts"]

# -------------------------
# High-level training loops for Task1 and Task2
# -------------------------
def evaluate_on_heldouts(model, heldout_feats_dict, heldout_labels_dict, upto_index):
    """
    Score ``model`` on the held-out sets D_hat1..D_hat{upto_index} (1-based, inclusive).

    heldout_feats_dict / heldout_labels_dict: {"D_hat1": array, ...}
    Returns the list of accuracies in dataset order; empty when upto_index < 1.
    """
    def _accuracy(name):
        held_x = heldout_feats_dict[name]
        held_y = heldout_labels_dict[name]
        # predict() also returns confidences and distances, unused here
        predicted, _, _ = model.predict(held_x)
        return accuracy_score(held_y, predicted)

    return [_accuracy(f"D_hat{n}") for n in range(1, upto_index + 1)]

def run_task1(backbone, feat_dim, data_root=DATA_ROOT,
              use_pca=USE_PCA, pca_dim=PCA_DIM):
    """
    Run Task 1: fit f1 on D1, then pseudo-label and update through D2..D10.

    Features are extracted from "<data_root>/D<i>" and
    "<data_root>/heldout/D_hat<i>" for i = 1..10. When use_pca is True a
    TruncatedSVD with pca_dim components is fitted on the D1 features and
    applied to every feature set, and the classifier works in that space.

    Returns (model, models_snapshots, accuracy_matrix, pca):
      - model: the final classifier (f10)
      - models_snapshots: {"f<i>": {"prototypes", "counts"}} copies for i = 1..10
      - accuracy_matrix: 10x10; row i-1 holds f_i's accuracy on D_hat1..D_hat_i,
        the remaining entries stay 0
      - pca: the fitted TruncatedSVD, or None when use_pca is False
    """
    def _features(folder):
        # the extractor also returns file paths, which are not needed here
        feats, labels, _ = extract_features_for_folder(
            folder, backbone, batch_size=BATCH_SIZE, transform=transform,
            labeled=True, device=DEVICE)
        return feats, labels

    # 1) Features for the training streams D1..D10 ...
    train_feats, train_labels = {}, {}
    for i in range(1, 11):
        key = f"D{i}"
        train_feats[key], train_labels[key] = _features(os.path.join(data_root, key))

    # ... and for the held-out sets D_hat1..D_hat10
    held_feats, held_labels = {}, {}
    for i in range(1, 11):
        key = f"D_hat{i}"
        held_feats[key], held_labels[key] = _features(os.path.join(data_root, "heldout", key))

    # 2) Optional projection, fitted on D1 only
    pca = None
    if use_pca:
        pca = TruncatedSVD(n_components=pca_dim, random_state=0)
        pca.fit(train_feats["D1"])
        train_feats = {k: pca.transform(v) for k, v in train_feats.items()}
        held_feats = {k: pca.transform(v) for k, v in held_feats.items()}
        feat_dim = pca_dim

    # 3) f1 from the D1 labels
    model = LwPClassifier(NUM_CLASSES, feat_dim)
    model.fit(train_feats["D1"], train_labels["D1"])

    models_snapshots = {}
    accuracy_matrix = np.zeros((10, 10), dtype=float)

    # 4) Evaluate f_i, snapshot it, then turn it into f_{i+1} using D_{i+1}
    for i in range(1, 11):
        accs = evaluate_on_heldouts(model, held_feats, held_labels, upto_index=i)
        accuracy_matrix[i-1, :i] = accs
        print(f"f{i} accuracies on D_hat1..D_hat{i} = {np.round(accs,4)}")
        models_snapshots[f"f{i}"] = {"prototypes": model.prototypes.copy(), "counts": model.counts.copy()}
        if i < 10:
            # D_{i+1} is treated as unlabeled: predict, then update from the pseudo-labels
            feats_next = train_feats[f"D{i+1}"]
            preds, confs, d2 = model.predict(feats_next)
            model.update_with_pseudo(feats_next, preds, confs, distances=d2,
                                     conf_thresh=TASK1_CONF_THRESH, alpha=TASK1_ALPHA)
            print(f"Updated f{i} -> f{i+1} using D{i+1} pseudo-labels.")
    return model, models_snapshots, accuracy_matrix, pca

def run_task2(backbone, feat_dim, starting_model_snapshot, data_root=DATA_ROOT,
              use_pca=USE_PCA, pca=None):
    """
    Run Task 2: starting from f10, update sequentially on D11..D20.

    Args:
        backbone: feature extractor passed to extract_features_for_folder.
        feat_dim: kept for interface compatibility; the classifier dimension
            is taken from the snapshot's prototypes instead.
        starting_model_snapshot: dict with 'prototypes' and 'counts' from f10.
        data_root: folder holding "D<i>" and "heldout/D_hat<i>".
        use_pca: whether Task 1 projected its features.
        pca: the fitted projection returned by run_task1. When use_pca is True
            and pca is given, every feature set is passed through
            pca.transform so it lives in the same space as the prototypes.

    Returns (model, models_snapshots, accuracy_matrix):
        accuracy_matrix is 10 x 20; row k holds f_{11+k}'s accuracy on
        D_hat1..D_hat_{11+k}, the remaining entries stay 0.

    Raises:
        ValueError: if the (optionally projected) feature dimension differs
            from the prototype dimension, e.g. PCA was used in Task 1 but no
            `pca` was supplied here.
    """
    proto_dim = starting_model_snapshot["prototypes"].shape[1]

    def _load(folder, labeled):
        feats, labels, _ = extract_features_for_folder(
            folder, backbone, batch_size=BATCH_SIZE, transform=transform,
            labeled=labeled, device=DEVICE)
        # Apply the same projection Task 1 used, so features and prototypes match.
        if use_pca and pca is not None:
            feats = pca.transform(feats)
        if feats.shape[1] != proto_dim:
            raise ValueError(
                f"Feature dimension {feats.shape[1]} from {folder!r} does not match "
                f"prototype dimension {proto_dim}; pass the fitted `pca` from "
                f"run_task1 when Task 1 used PCA.")
        return feats, labels

    # Precompute D11..D20 features (unlabeled) and heldouts D_hat1..D_hat20
    train_feats = {}
    for i in range(11, 21):
        train_feats[f"D{i}"], _ = _load(os.path.join(data_root, f"D{i}"), labeled=False)

    held_feats = {}
    held_labels = {}
    for i in range(1, 21):
        key = f"D_hat{i}"
        held_feats[key], held_labels[key] = _load(
            os.path.join(data_root, "heldout", key), labeled=True)

    # Initialize model f10 from snapshot
    model = LwPClassifier(NUM_CLASSES, proto_dim)
    model.prototypes = starting_model_snapshot["prototypes"].copy()
    model.counts = starting_model_snapshot["counts"].copy()

    accuracy_matrix = np.zeros((10, 20), dtype=float)
    models_snapshots = {}

    # sequential updates D11..D20 producing f11..f20
    for row, i in enumerate(range(11, 21)):
        feats_next = train_feats[f"D{i}"]
        preds, confs, d2 = model.predict(feats_next)

        # Mean confidence of the samples predicted as each class (0.0 when none).
        mean_confidence_per_class = np.array([
            float(confs[preds == c].mean()) if np.any(preds == c) else 0.0
            for c in range(NUM_CLASSES)
        ])
        # True for classes whose mean confidence reaches TASK2_CONF_THRESH;
        # handed to the classifier as its conservative_mask.
        conservative_mask = mean_confidence_per_class >= TASK2_CONF_THRESH

        # Squared distances are passed along for the classifier's distance gating.
        model.update_with_pseudo(feats_next, preds, confs, distances=d2,
                                 conf_thresh=TASK2_CONF_THRESH, alpha=TASK2_ALPHA,
                                 conservative_mask=conservative_mask,
                                 distance_factor=TASK2_DISTANCE_FACTOR)

        # The model is now f_i; evaluate it on D_hat1..D_hat_i
        accs = evaluate_on_heldouts(model, held_feats, held_labels, upto_index=i)
        accuracy_matrix[row, :i] = accs
        models_snapshots[f"f{i}"] = {"prototypes": model.prototypes.copy(), "counts": model.counts.copy()}
        print(f"f{i} updated and evaluated on heldouts 1..{i}: {np.round(accs,4)}")

    return model, models_snapshots, accuracy_matrix

# -------------------------
# Example main driver
# -------------------------
if __name__ == "__main__":
    # Build backbone
    backbone, feat_dim = build_backbone(FEAT_BACKBONE, device=DEVICE)
    # Run Task1
    final_model_f10, snaps_task1, acc_mat_task1, pca = run_task1(backbone, feat_dim, data_root=DATA_ROOT)
    print("Task1 accuracy matrix (10x10):")
    print(np.round(acc_mat_task1, 4))
    # Take f10 snapshot for Task2:
    f10_snapshot = snaps_task1["f10"]
    # Run Task2. The projection fitted in Task1 (None when PCA is off) is
    # forwarded so D11..D20 and the heldouts are mapped into the same feature
    # space as the f10 prototypes.
    final_model_f20, snaps_task2, acc_mat_task2 = run_task2(
        backbone, feat_dim, f10_snapshot, data_root=DATA_ROOT,
        use_pca=USE_PCA, pca=pca)
    print("Task2 accuracy matrix (10x20) for f11..f20 rows and D_hat1..D_hat20 columns")
    print(np.round(acc_mat_task2, 4))
    # Save matrices
    np.save("acc_mat_task1.npy", acc_mat_task1)
    np.save("acc_mat_task2.npy", acc_mat_task2)
    # Save prototype snapshots
    joblib.dump(snaps_task1, "snaps_task1.pkl")
    joblib.dump(snaps_task2, "snaps_task2.pkl")
    print("Saved results: acc matrices and snapshots.")
