<a href="https://colab.research.google.com/github/Sofa3xpert/CW2_ML/blob/main/CW2_final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [24]:
########################################
# Module: env_setup.py
########################################
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import time
from collections import defaultdict, Counter
import os
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors

# Utility classes for logging
class AverageMeter:
    """Running weighted mean of scalar values (e.g. per-batch losses)."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        # ``val`` is treated as a mean over ``n`` items, hence the weighting.
        self.count += n
        self.sum += n * val

    @property
    def avg(self):
        # Before any update there is nothing to average; report 0.0.
        if not self.count:
            return 0.0
        return self.sum / self.count

class Metric:
    """Accumulates top-1 classification accuracy over successive batches."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update_prediction(self, preds, targets):
        # Scores are along dim 1; the index of the max score is the predicted class.
        predicted = preds.max(dim=1).indices
        self.correct += predicted.eq(targets).sum().item()
        self.total += targets.size(0)

    def calc_accuracy(self):
        # No samples seen yet -> define accuracy as 0.0 instead of dividing by zero.
        if self.total == 0:
            return 0.0
        return self.correct / self.total

def mapping_func(name):
    """Return the threshold-mapping function selected by ``name``.

    ``"convex"`` yields M(x) = x / (2 - x), intended for x in [0, 1]; inputs that
    are not below 2 map to 1.0 (avoiding division by zero). Any other name,
    including ``"linear"``, yields the identity mapping.
    """
    if name != "convex":
        # "linear" and unrecognised names share the identity mapping.
        return lambda beta: beta

    def convex(beta):
        if beta < 2:
            return beta / (2 - beta)
        return 1.0

    return convex

def evaluate_step(model, dataloader, device):
    """Return the top-1 accuracy of ``model`` on ``dataloader``.

    The model is switched to eval mode (and left there) and evaluated without
    gradient tracking. Each batch must be an ``(images, labels)`` pair.
    """
    model.eval()
    tracker = Metric()
    with torch.no_grad():
        for batch_imgs, batch_labels in dataloader:
            logits = model(batch_imgs.to(device))
            tracker.update_prediction(logits, batch_labels.to(device))
    return tracker.calc_accuracy()


class EMA:
    """Exponential moving average of a model's trainable parameters.

    Only parameters with ``requires_grad`` are tracked; buffers (e.g. BatchNorm
    running statistics) are not averaged.

    Args:
        model: module whose trainable parameters are averaged.
        decay: weight kept from the previous average on every update,
            i.e. ``shadow = decay * shadow + (1 - decay) * param``.
        device: device on which the shadow copies are stored; ``None`` keeps
            each copy on its parameter's own device.
    """

    def __init__(self, model, decay, device):
        self.shadow = {}
        self.backup = {}  # parameter values saved by apply_shadow, consumed by restore
        self.decay = decay
        self.device = device
        for name, param in model.named_parameters():
            if param.requires_grad:
                shadow = param.detach().clone()
                if device is not None:
                    shadow = shadow.to(device)
                self.shadow[name] = shadow

    @torch.no_grad()
    def update(self, model):
        """Blend the model's current parameters into the running average (in place)."""
        for name, param in model.named_parameters():
            if param.requires_grad:
                shadow = self.shadow[name]
                # In-place update: no new tensor per parameter per training step.
                shadow.mul_(self.decay).add_(param.detach().to(shadow.device), alpha=1.0 - self.decay)

    @torch.no_grad()
    def apply_shadow(self, model):
        """Copy the averaged values into the model's trainable parameters.

        The overwritten values are kept in ``self.backup`` so ``restore`` can undo this.
        """
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.detach().clone()
                # copy_ keeps each parameter on its own device/dtype.
                param.copy_(self.shadow[name])

    @torch.no_grad()
    def restore(self, model):
        """Undo the most recent ``apply_shadow`` call."""
        for name, param in model.named_parameters():
            if name in self.backup:
                param.copy_(self.backup[name])
        self.backup = {}

# Experiment arguments
class ExperimentArgs:
    """Attribute bag holding every hyper-parameter used by the experiments.

    Attributes are assigned in a fixed order, which is also the order shown by ``repr``.
    """

    def __init__(self):
        # --- Backbone and data ---------------------------------------------
        self.network = "resnet18"        # "resnet50" is the other supported backbone
        self.num_classes = 10            # CIFAR-10
        self.data = "CIFAR10"
        self.num_X = 0                   # labeled set starts empty (L₀ = ∅)
        self.include_x_in_u = True
        self.augs = None                 # augmentations are defined in the data cell
        self.batch_size, self.mu = 64, 7

        # --- Semi-supervised optimisation ----------------------------------
        self.lr, self.momentum = 0.03, 0.9
        self.nesterov, self.weight_decay = True, 5e-4
        self.iterations = 2 ** 20        # 1,048,576; used as the cosine-schedule horizon

        # NOTE: epoch counts are kept small to fit a limited Colab session.
        self.epochs_semi = 10            # semi-supervised epochs per active-learning round
        self.epochs_supervised = 10      # epochs for the supervised / linear-eval baselines

        # --- Miscellaneous training switches -------------------------------
        self.wandb = False
        self.mode = "train"
        self.load_path = "ckpt.pth"
        self.ema_decay = 0.999
        self.amp = False                 # True enables automatic mixed precision

        # --- Pseudo-label thresholding --------------------------------------
        self.mapping = "convex"
        self.threshold = 0.95            # confidence needed to keep a pseudo-label
        self.lu_weight = 1.0             # weight of the unsupervised loss term
        self.save_path = "./checkpoints"

        # --- Pre-trained SimCLR weights (must match ``network``) ------------
        self.simclr_checkpoint = "/content/simclr_cifar-10.pth.tar"
        self.use_typiclust_initial = True

        # --- Active-learning schedule ---------------------------------------
        self.initial_size = 0            # L₀ is empty
        self.budget = 10                 # samples queried per round
        self.active_rounds = 10

        # Placeholder for learning-status / pseudo-label bookkeeping.
        self.learning_status = None

    def __repr__(self):
        return str(self.__dict__)

# Create instance of ExperimentArgs
# One shared configuration object, read by every cell below.
exp_args = ExperimentArgs()

# Checkpoints are written into ``save_path``; create it up front (no error if present).
os.makedirs(name=exp_args.save_path, exist_ok=True)

In [25]:
########################################
# Module: representation.py
########################################
# This module implements SimCLR-based representation learning functions
# for both ResNet18 and ResNet50.

# Two-layer MLP projection head used by the SimCLR wrapper.
class MLPProjection(nn.Module):
    """Linear -> ReLU -> Linear projection head.

    The submodule names (``fc1``, ``relu``, ``fc2``) form the state-dict keys,
    so they must stay unchanged for checkpoints to load.
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

# SimCLR model wrapper
class SimCLR(nn.Module):
    """Encoder followed by a projection head, with L2-normalised outputs.

    ``forward(x)`` returns the normalised projection; with
    ``return_features=True`` it returns the normalised encoder output instead.
    """

    def __init__(self, base_model, projector):
        super().__init__()
        # These attribute names are the state-dict prefixes ("encoder.", "projector.").
        self.encoder = base_model
        self.projector = projector

    def forward(self, x, return_features=False):
        out = self.encoder(x)
        if not return_features:
            out = self.projector(out)
        return F.normalize(out, dim=1)

def get_simclr_model(network_type, checkpoint_path):
    """Build a CIFAR-style SimCLR model and load pre-trained weights into it.

    For ``"resnet18"`` the encoder outputs 512-dim features; for ``"resnet50"``
    2048-dim. In both cases ``conv1`` becomes a 3x3 stride-1 conv, and the
    max-pool and the final ``fc`` are replaced by identities.

    Checkpoint handling:
      * the file may contain a raw state dict, a dict wrapping one under
        ``"state_dict"`` or ``"model"``, or a pickled ``nn.Module``;
      * a DataParallel ``"module."`` prefix is stripped and a ``"backbone."``
        prefix is renamed to ``"encoder."``;
      * keys whose shape differs from the model are skipped;
      * loading is non-strict, but missing / unexpected / mismatched keys are
        reported with ``warnings.warn`` instead of being silently ignored.

    Returns:
        The model in eval mode, on CPU.

    Raises:
        ValueError: if ``network_type`` is not supported.
    """
    import warnings

    net = network_type.lower()
    if net == "resnet18":
        base_model = models.resnet18(weights=None)
        in_dim = 512
    elif net == "resnet50":
        base_model = models.resnet50(weights=None)
        in_dim = 2048
    else:
        raise ValueError("Unsupported network type")
    # CIFAR stem: 3x3 stride-1 conv and no max-pool for 32x32 inputs.
    base_model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    base_model.maxpool = nn.Identity()
    base_model.fc = nn.Identity()  # expose the pooled penultimate features

    projector = MLPProjection(in_dim=in_dim, hidden_dim=in_dim, out_dim=128)
    model = SimCLR(base_model, projector)

    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    if isinstance(checkpoint, nn.Module):
        state_dict = checkpoint.state_dict()
    else:
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            for wrapper_key in ("state_dict", "model"):
                if isinstance(checkpoint.get(wrapper_key), dict):
                    state_dict = checkpoint[wrapper_key]
                    break

    model_state = model.state_dict()
    renamed, mismatched = {}, []
    for key, value in state_dict.items():
        if key.startswith("module."):
            key = key[len("module."):]
        # "backbone." is assumed to be the encoder prefix used by some SimCLR
        # checkpoints — verify against the checkpoint you actually use.
        if key.startswith("backbone."):
            key = "encoder." + key[len("backbone."):]
        if key in model_state and hasattr(value, "shape") and tuple(value.shape) != tuple(model_state[key].shape):
            # load_state_dict raises on shape mismatches even with strict=False.
            mismatched.append(key)
            continue
        renamed[key] = value

    result = model.load_state_dict(renamed, strict=False)
    if result.missing_keys or result.unexpected_keys or mismatched:
        warnings.warn(
            f"SimCLR checkpoint '{checkpoint_path}' matched the model only partially: "
            f"{len(result.missing_keys)} missing, {len(result.unexpected_keys)} unexpected, "
            f"{len(mismatched)} shape-mismatched keys "
            f"(e.g. missing={result.missing_keys[:3]}, unexpected={result.unexpected_keys[:3]}, "
            f"mismatched={mismatched[:3]})."
        )
    loaded_encoder_keys = [k for k in renamed if k.startswith("encoder.") and k in model_state]
    if not loaded_encoder_keys:
        warnings.warn(
            "No encoder weights were loaded from the SimCLR checkpoint; "
            "embeddings will come from a randomly initialised network."
        )
    model.eval()
    return model

def extract_embeddings(model, dataloader, device):
    """Return L2-normalised encoder embeddings and labels for every sample.

    The model is called as ``model(imgs, return_features=True)``. Batches may be
    ``(imgs, labels)`` or ``(imgs, labels, index, ...)``; items after the label
    are ignored. The model runs in eval mode (BatchNorm uses, and does not
    update, its running statistics) and its previous train/eval mode is
    restored afterwards.

    Returns:
        (embeddings, labels): a float array with one unit-norm row per sample,
        and a 1-D array of the matching labels.
    """
    was_training = model.training
    model.eval()
    embeddings, labels = [], []
    try:
        with torch.no_grad():
            for sample in dataloader:
                imgs, lbls = sample[0], sample[1]
                feats = model(imgs.to(device), return_features=True)
                embeddings.append(feats.cpu().numpy())
                labels.extend(lbls.numpy())
    finally:
        model.train(was_training)
    embeddings = np.concatenate(embeddings, axis=0)
    # Normalise again so the result is unit-norm even if the model does not normalise.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    embeddings = embeddings / (norms + 1e-10)
    return embeddings, np.array(labels)

In [26]:
#%% #### Data Preparation and Indexed Dataset for CIFAR-10
from torchvision import datasets

# Define transforms.
# CIFAR-10 per-channel mean / std, shared by the train and test pipelines.
_CIFAR10_NORMALIZE = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                          (0.2470, 0.2435, 0.2616))

# Training augmentation: random resized crop, horizontal flip, colour jitter,
# random grayscale, then tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    _CIFAR10_NORMALIZE,
])
# Evaluation pipeline: deterministic, no augmentation.
transform_test = transforms.Compose([transforms.ToTensor(), _CIFAR10_NORMALIZE])

# Standard (image, label) CIFAR-10 splits, downloaded into ./data when missing.
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True,
                                 transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True,
                                transform=transform_test)
# Custom dataset for unlabeled pool that returns (image, label, index).
class IndexedCIFAR10(datasets.CIFAR10):
    """CIFAR-10 variant whose items are ``(image, label, index)`` triples.

    The trailing dataset index lets callers tell which item each sample came from.
    """

    def __getitem__(self, index):
        sample = super().__getitem__(index)
        return (*sample, index)

# Unlabeled pool over the CIFAR-10 training split, yielding (image, label, index)
# with the training augmentation applied.
unlabeled_dataset = IndexedCIFAR10(
    root='./data',
    train=True,
    transform=transform_train,
    download=False,  # files were already fetched when train_dataset was created
)

def get_dataloaders(data, num_X, include_x_in_u, batch_size, mu):
    """Return ``(labeled_loader, unlabeled_loader, test_loader)``.

    The labeled loader is always ``None``: the active-learning loop builds it
    per round. ``data``, ``num_X``, ``include_x_in_u`` and ``mu`` are accepted
    for interface compatibility but are not used here.
    """
    shared = dict(batch_size=batch_size, num_workers=2)
    unlabeled_loader = DataLoader(unlabeled_dataset, shuffle=True, **shared)
    test_loader = DataLoader(test_dataset, shuffle=False, **shared)
    return None, unlabeled_loader, test_loader

In [27]:
########################################
# Module: active_selection.py
########################################
def original_typiclust_selection(embeddings, unlabeled_indices, budget, default_k=5):
    """TypiClust-style selection: the most typical sample of each cluster.

    Steps:
      1. Cluster ``embeddings`` with KMeans into ``min(budget, n_samples)`` clusters.
      2. Within each cluster, score every point by its typicality: the inverse
         of the MEAN Euclidean distance to its ``default_k`` nearest neighbours
         inside the cluster (fewer when the cluster is smaller).
      3. Return the highest-typicality point of every cluster.

    Args:
        embeddings: array of shape (n_samples, dim); row ``i`` describes
            ``unlabeled_indices[i]``.
        unlabeled_indices: global dataset indices aligned with ``embeddings``.
        budget: number of clusters, i.e. the maximum number of samples returned.
        default_k: neighbours used for the typicality estimate.

    Returns:
        Integer numpy array of selected global indices (at most ``budget``).
    """
    embeddings = np.asarray(embeddings)
    unlabeled_indices = np.asarray(unlabeled_indices)
    # KMeans cannot create more clusters than there are samples.
    n_clusters = min(budget, len(embeddings))
    if n_clusters <= 0:
        return np.array([], dtype=int)

    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    cluster_labels = kmeans.fit_predict(embeddings)

    k = max(1, int(default_k))
    selected_local = []  # positions within ``embeddings``
    for cluster in range(n_clusters):
        members = np.flatnonzero(cluster_labels == cluster)
        if members.size == 0:
            continue  # KMeans should not leave a cluster empty, but be safe
        if members.size == 1:
            selected_local.append(members[0])
            continue
        points = embeddings[members]
        n_neighbors = min(k, members.size - 1)
        # kneighbors() without a query set excludes each point from its own neighbour list.
        distances, _ = NearestNeighbors(n_neighbors=n_neighbors, metric='euclidean').fit(points).kneighbors()
        # Typicality = 1 / mean distance; epsilon keeps exact duplicates finite.
        typicality = 1.0 / (distances.mean(axis=1) + 1e-8)
        selected_local.append(members[np.argmax(typicality)])

    # Map positions back to global pool indices (int dtype even when empty).
    return unlabeled_indices[np.asarray(selected_local, dtype=int)]

def modified_active_selection(embeddings, unlabeled_indices, labeled_count, query_budget, max_clusters=500, neighbor_k=20, min_cluster_size=5):
    """Modified TypiClust selection: globally most typical samples.

    - Cluster with KMeans into ``K = min(labeled_count + query_budget, max_clusters, n_samples)``.
    - For every cluster with at least ``min_cluster_size`` samples, compute the
      typicality of each point as ``1 / (mean distance to its
      min(neighbor_k, cluster_size - 1) nearest neighbours in the cluster + 1e-8)``.
    - Return the ``query_budget`` highest-typicality samples across all clusters,
      or every scored sample when fewer are available.

    Args:
        embeddings: array of shape (n_samples, dim); row ``i`` describes
            ``unlabeled_indices[i]``.
        unlabeled_indices: global dataset indices aligned with ``embeddings``.
        labeled_count: current size of the labeled set.
        query_budget: number of samples to select.
        max_clusters, neighbor_k, min_cluster_size: see above.

    Returns:
        Integer numpy array of selected global indices.
    """
    embeddings = np.asarray(embeddings)
    unlabeled_indices = np.asarray(unlabeled_indices)
    K = min(labeled_count + query_budget, max_clusters, len(embeddings))
    if K <= 0 or query_budget <= 0:
        return np.array([], dtype=int)

    kmeans = KMeans(n_clusters=K, random_state=42)
    cluster_labels = kmeans.fit_predict(embeddings)

    score_chunks, candidate_chunks = [], []
    for cluster in range(K):
        members = np.flatnonzero(cluster_labels == cluster)
        if members.size == 0 or members.size < min_cluster_size:
            continue  # skip small clusters
        if members.size == 1:
            # A lone point has no neighbours: infinite distance -> zero typicality.
            cluster_scores = np.zeros(1)
        else:
            n_neighbors = max(1, min(neighbor_k, members.size - 1))
            # Vectorised kNN instead of a per-point distance loop (was O(n^2) in Python).
            # kneighbors() without a query set excludes each point from its own neighbours.
            dists, _ = NearestNeighbors(n_neighbors=n_neighbors).fit(embeddings[members]).kneighbors()
            cluster_scores = 1.0 / (dists.mean(axis=1) + 1e-8)
        score_chunks.append(cluster_scores)
        candidate_chunks.append(unlabeled_indices[members])

    if not score_chunks:
        return np.array([], dtype=int)
    typicality_scores = np.concatenate(score_chunks)
    corresponding_indices = np.concatenate(candidate_chunks)
    if typicality_scores.size < query_budget:
        return corresponding_indices
    selected_idx_local = np.argsort(-typicality_scores)[:query_budget]
    return corresponding_indices[selected_idx_local]

In [28]:
def active_learning_experiment(args, selection_method="original_typiclust"):
    """Run ``args.active_rounds`` rounds of active learning with semi-supervised training.

    Each round selects up to ``args.budget`` new samples from the still-unlabeled
    part of the CIFAR-10 training split, adds them to the labeled set, and
    trains a fresh network with ``train_semi_supervised_full``.

    Parameters:
      args: ExperimentArgs instance.
      selection_method: one of "modified_typiclust", "original_typiclust", or "random".

    Returns:
      test_accuracies: test accuracy after each round.
      cumulative_budgets: labeled-set size after each round.
      labeled_indices: global training-set indices selected for labeling, in selection order.

    Raises:
      ValueError: for an unknown ``selection_method`` (checked before any heavy work).
    """
    if selection_method not in ("modified_typiclust", "original_typiclust", "random"):
        raise ValueError("Unknown selection method")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Unlabeled pool and test set used by semi-supervised training.
    _, U_loader, T_loader = get_dataloaders(args.data, args.num_X, args.include_x_in_u, args.batch_size, args.mu)
    total_unlabeled = len(train_dataset)
    labeled_mask = np.zeros(total_unlabeled, dtype=bool)
    labeled_indices = []
    cumulative_budgets = []
    test_accuracies = []

    # Frozen pre-trained SimCLR model, used only as a feature extractor for selection.
    simclr_model = get_simclr_model(args.network, args.simclr_checkpoint)
    simclr_model.to(device)

    # Embed the whole training split ONCE with the deterministic test transform.
    # The SimCLR model never changes, so per-round re-extraction is wasted work, and
    # the random training augmentation would make the selection noisy and irreproducible.
    selection_dataset = datasets.CIFAR10(root='./data', train=True, download=False, transform=transform_test)
    selection_loader = DataLoader(selection_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2)
    all_embeddings, _ = extract_embeddings(simclr_model, selection_loader, device)

    for round_idx in range(args.active_rounds):
        print(f"\nActive Learning Round {round_idx+1}/{args.active_rounds}")
        unlabeled_indices = np.flatnonzero(~labeled_mask)
        if unlabeled_indices.size == 0:
            print("Unlabeled pool exhausted; stopping early.")
            break
        # Row i of ``embeddings`` describes training sample ``unlabeled_indices[i]``.
        embeddings = all_embeddings[unlabeled_indices]

        if selection_method == "modified_typiclust":
            current_label_count = len(labeled_indices)
            selected_global = modified_active_selection(embeddings, unlabeled_indices, current_label_count, args.budget)
            print("Modified TypiClust selected indices:", selected_global)
        elif selection_method == "original_typiclust":
            selected_global = original_typiclust_selection(embeddings, unlabeled_indices, args.budget)
            print("Original TypiClust selected indices:", selected_global)
        else:  # "random"
            n_pick = min(args.budget, unlabeled_indices.size)
            selected_global = np.random.choice(unlabeled_indices, size=n_pick, replace=False)
            print("Random selected indices:", selected_global)

        # An empty selection could come back as a float array, which cannot index the mask.
        selected_global = np.asarray(selected_global, dtype=int)
        labeled_mask[selected_global] = True
        labeled_indices.extend(selected_global.tolist())
        print("Total labeled samples so far:", len(labeled_indices))
        cumulative_budgets.append(len(labeled_indices))

        # Labeled loader over the (augmented) training split.
        X_subset = Subset(train_dataset, labeled_indices)
        X_loader_current = DataLoader(X_subset, batch_size=args.batch_size, shuffle=True, num_workers=2)

        print("Starting full semi-supervised training for this round...")
        test_acc = train_semi_supervised_full(args, X_loader_current, U_loader, T_loader, device, simclr_model, network_type=args.network)
        test_accuracies.append(test_acc)
        print(f"After round {round_idx+1}: Test Accuracy = {test_acc*100:.2f}%")

    return test_accuracies, cumulative_budgets, labeled_indices

In [29]:
#######################################
# Module: supervised_training.py
########################################
def get_network(network_type, num_classes):
    """Build a randomly initialised CIFAR-style ResNet classifier.

    ``network_type`` is ``"resnet18"`` or ``"resnet50"`` (case-insensitive).
    The stem is replaced by a 3x3 stride-1 conv and the max-pool is removed.

    Raises:
        ValueError: for any other ``network_type``.
    """
    builders = {"resnet18": models.resnet18, "resnet50": models.resnet50}
    builder = builders.get(network_type.lower())
    if builder is None:
        raise ValueError("Unsupported network type")
    net = builder(weights=None, num_classes=num_classes)
    net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    net.maxpool = nn.Identity()
    return net

def train_supervised(args, device, network_type):
    """Fully supervised baseline: train a CIFAR-style ResNet on all training labels.

    Uses SGD (lr 0.025, Nesterov momentum 0.9, weight decay 5e-4) with a cosine
    LR schedule over ``args.epochs_supervised`` epochs, evaluating on the test set
    after every epoch. The final weights are saved to
    ``args.save_path/supervised_<network_type>.pth``.

    Returns:
        (model, epoch_losses, epoch_acc): the trained model, the mean training
        loss of each epoch, and the test accuracy after each epoch.
    """
    model = get_network(network_type, args.num_classes)
    model.to(device)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.025, momentum=0.9,
                                nesterov=True, weight_decay=5e-4)
    # scheduler.step() runs once per epoch, so the cosine horizon is the epoch count.
    # (T_max=args.iterations, about a million, would leave the LR essentially constant.)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, args.epochs_supervised))
    criterion = torch.nn.CrossEntropyLoss()

    epoch_losses = []
    epoch_acc = []
    for epoch in range(args.epochs_supervised):
        model.train()
        epoch_loss = 0.0
        for imgs, labels in train_loader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(imgs), labels)
            loss.backward()
            optimizer.step()
            # The criterion returns a batch mean; weight by batch size for an exact epoch mean.
            epoch_loss += loss.item() * imgs.size(0)
        scheduler.step()
        avg_loss = epoch_loss / len(train_dataset)
        test_acc = evaluate_step(model, test_loader, device)
        epoch_losses.append(avg_loss)
        epoch_acc.append(test_acc)
        print(f"[Supervised {network_type}] Epoch {epoch+1}/{args.epochs_supervised}: Loss = {avg_loss:.4f}, Test Acc = {test_acc*100:.2f}%")

    torch.save(model.state_dict(), os.path.join(args.save_path, f"supervised_{network_type}.pth"))
    return model, epoch_losses, epoch_acc

def supervised_with_ss_evaluation(args, device, network_type):
    """Linear evaluation of frozen SimCLR features on CIFAR-10.

    Embeds the train and test splits with the pre-trained SimCLR encoder, trains
    a linear classifier on the train embeddings for ``2 * args.epochs_supervised``
    epochs (SGD, lr 2.5, cosine schedule), prints the test accuracy and saves
    the classifier to ``args.save_path/linear_<network_type>.pth``.

    Returns:
        The trained ``nn.Linear`` classifier.
    """
    simclr_model = get_simclr_model(network_type, args.simclr_checkpoint)
    simclr_model.to(device)

    # Features are extracted once, so embed the training split with the
    # deterministic test transform; random augmentations would only add noise.
    train_eval_dataset = datasets.CIFAR10(root='./data', train=True, download=False, transform=transform_test)
    train_loader = DataLoader(train_eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2)

    train_emb, train_labels = extract_embeddings(simclr_model, train_loader, device)
    test_emb, test_labels = extract_embeddings(simclr_model, test_loader, device)

    # Feature width taken from the embeddings (512 for ResNet18, 2048 for ResNet50).
    d = train_emb.shape[1]
    emb_tensor = torch.tensor(train_emb, dtype=torch.float32)
    lbl_tensor = torch.tensor(train_labels, dtype=torch.long)
    dataset = torch.utils.data.TensorDataset(emb_tensor, lbl_tensor)
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    classifier = nn.Linear(d, args.num_classes).to(device)
    optimizer = torch.optim.SGD(classifier.parameters(), lr=2.5, momentum=0.9, nesterov=True, weight_decay=5e-4)
    num_epochs = args.epochs_supervised * 2
    # One scheduler step per epoch, so the cosine horizon equals the number of epochs.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, num_epochs))
    criterion = nn.CrossEntropyLoss()

    for _ in range(num_epochs):
        classifier.train()
        for feats, lbls in loader:
            feats = feats.to(device)
            lbls = lbls.to(device)
            optimizer.zero_grad()
            loss = criterion(classifier(feats), lbls)
            loss.backward()
            optimizer.step()
        scheduler.step()

    classifier.eval()
    with torch.no_grad():
        test_tensor = torch.tensor(test_emb, dtype=torch.float32).to(device)
        _, predicted = torch.max(classifier(test_tensor), 1)
        test_acc = (predicted.cpu() == torch.tensor(test_labels)).float().mean().item()
    print(f"[Linear Evaluation {network_type}] Test Accuracy: {test_acc*100:.2f}%")
    torch.save(classifier.state_dict(), os.path.join(args.save_path, f"linear_{network_type}.pth"))
    return classifier

In [30]:
########################################
# Module: semi_supervised.py
########################################
def _endless(loader):
    """Yield batches from ``loader`` forever, starting a fresh pass (and reshuffle) each time.

    Returns if a whole pass yields nothing, so an empty loader cannot spin forever.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            return


def train_semi_supervised_full(args, X_loader, U_loader, T_loader, device, simclr_model, network_type="resnet18", steps_per_epoch=None):
    """Semi-supervised training of a fresh network with confidence-thresholded pseudo-labels.

    Each step concatenates a labeled batch, a weak unlabeled batch and its
    "strong" view, and minimises
    ``CE(labeled) + args.lu_weight * mean(mask * CE(strong_view, pseudo_label))``.
    Pseudo-labels are the model's own argmax predictions on the weak batch, and
    ``mask`` keeps those with confidence above ``args.threshold``. This is a
    fixed threshold: ``args.mapping`` is not used, and there is no Prior
    Pseudo-Label (PPL) mechanism. An EMA of the weights is tracked but not used
    for evaluation.

    Parameters:
      args: ExperimentArgs instance.
      X_loader: labeled dataloader yielding ``(img, label)``.
      U_loader: unlabeled dataloader yielding ``(img, label, index)``; label and index are ignored.
      T_loader: test dataloader.
      device: torch.device (or device string).
      simclr_model: unused; kept for interface consistency.
      network_type: "resnet18" or "resnet50".
      steps_per_epoch: optimisation steps per epoch. ``None`` means
        ``args.steps_per_epoch`` if that attribute is set, otherwise
        ``len(U_loader)`` (one pass over the unlabeled pool). The small labeled
        loader is restarted as often as needed.

    Returns:
      test_acc: test accuracy of the final (non-EMA) model.
    """
    model = get_network(network_type, args.num_classes)
    model.to(device)

    # zip(X_loader, U_loader) would end each epoch after the last LABELED batch;
    # with tens of labeled images that is only 1-2 optimisation steps per epoch.
    if steps_per_epoch is None:
        steps_per_epoch = getattr(args, "steps_per_epoch", None) or len(U_loader)
    total_steps = args.epochs_semi * steps_per_epoch

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                                nesterov=args.nesterov, weight_decay=args.weight_decay)
    # Cosine schedule over the steps actually run (capped at args.iterations); a horizon
    # of args.iterations alone would barely decay the LR during a short run.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, min(args.iterations, total_steps)))
    scaler = torch.cuda.amp.GradScaler() if args.amp else None
    ema = EMA(model, args.ema_decay, device)
    criterion = torch.nn.CrossEntropyLoss(reduction='none')
    device_type = torch.device(device).type
    use_autocast = bool(args.amp) and device_type == "cuda"

    labeled_iter = _endless(X_loader)
    unlabeled_iter = _endless(U_loader)
    test_acc = None
    for epoch in range(args.epochs_semi):
        model.train()
        epoch_loss = 0.0
        steps_done = 0
        # range() comes first in zip so no extra batch is pulled once the epoch is done.
        for _, sample_x, sample_u in zip(range(steps_per_epoch), labeled_iter, unlabeled_iter):
            x, y = sample_x
            uw = sample_u[0]
            # NOTE: no separate strong augmentation is available from U_loader, so the
            # "strong" view is an exact copy of the weak one.
            us = uw.clone()

            inputs = torch.cat([x, uw, us], dim=0).to(device)
            bs, bs_u = x.size(0), uw.size(0)
            with torch.autocast(device_type=device_type, enabled=use_autocast):
                outputs = model(inputs)
                xw_pred = outputs[:bs]
                uw_pred = outputs[bs:bs + bs_u]
                us_pred = outputs[bs + bs_u:]

                # Supervised loss on the labeled batch.
                ls = criterion(xw_pred, y.to(device)).mean()

                # Pseudo-labels from the weak view, kept only above the confidence threshold.
                with torch.no_grad():
                    uw_prob = F.softmax(uw_pred.float(), dim=1)
                    max_prob, hard_label = torch.max(uw_prob, dim=1)
                    indicator = max_prob > args.threshold
                lu = (criterion(us_pred, hard_label) * indicator.float()).mean()
                total_loss = ls + args.lu_weight * lu

            optimizer.zero_grad()
            if scaler is not None:
                scaler.scale(total_loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                total_loss.backward()
                optimizer.step()
            scheduler.step()
            ema.update(model)
            epoch_loss += total_loss.item()
            steps_done += 1

        mean_loss = epoch_loss / max(1, steps_done)
        test_acc = evaluate_step(model, T_loader, device)
        print(f"[Semi-supervised No PPL {network_type}] Epoch {epoch+1}/{args.epochs_semi}: Loss = {mean_loss:.4f}, Test Acc = {test_acc*100:.2f}%")

    if test_acc is None:
        # No epochs were run; still report the accuracy of the (untrained) model.
        test_acc = evaluate_step(model, T_loader, device)
    torch.save(model.state_dict(), os.path.join(args.save_path, f"semi_supervised_no_PPL_{network_type}.pth"))
    return test_acc

In [31]:
########################################
# Module: plotting.py
########################################
def plot_accuracy_vs_budget(cum_budgets, accuracies, title="Test Accuracy vs. Cumulative Labeled Samples"):
    """Show test accuracy (in percent) against the cumulative labeled-set size."""
    import matplotlib.pyplot as plt

    accuracy_pct = np.asarray(accuracies) * 100  # inputs are fractions; scale to percent
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(cum_budgets, accuracy_pct, marker='o', label='Test Accuracy')
    ax.set_xlabel("Cumulative Labeled Samples")
    ax.set_ylabel("Test Accuracy (%)")
    ax.set_title(title)
    ax.legend()
    ax.grid(True)
    plt.show()

def plot_epoch_time(method_names, epoch_times):
    """Bar chart comparing the average epoch time (in seconds) of several methods."""
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.bar(method_names, epoch_times, color=['blue', 'red', 'green'])
    ax.set_xlabel("Method")
    ax.set_ylabel("Average Epoch Time (s)")
    ax.set_title("Average Epoch Time Comparison")
    plt.show()

In [32]:
########################################
# Module: checkpoint.py
########################################
def save_checkpoint(model, optimizer, epoch, filepath):
    """Write the epoch number and model/optimizer state dicts to ``filepath``."""
    torch.save(
        dict(
            epoch=epoch,
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
        ),
        filepath,
    )

def load_checkpoint(model, optimizer, filepath, device):
    """Restore a checkpoint saved with keys 'epoch', 'model_state_dict' and 'optimizer_state_dict'.

    Returns ``(model, optimizer, start_epoch)``, where ``start_epoch`` is the
    epoch after the saved one. NOTE: ``torch.load`` unpickles the file, so only
    load trusted checkpoints.
    """
    state = torch.load(filepath, map_location=device)
    for target, key in ((model, 'model_state_dict'), (optimizer, 'optimizer_state_dict')):
        target.load_state_dict(state[key])
    return model, optimizer, state['epoch'] + 1

In [None]:
### Final Experiment Pipeline: supervised baselines, then active learning with semi-supervised training

import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# -------------------------------
# 1. Fully supervised baseline (end-to-end training on all labels)
# -------------------------------
print("=== Running Fully Supervised Training for ResNet18 ===")
# Trains a randomly initialised CIFAR-style ResNet18 on the full labeled training set,
# reports test accuracy after every epoch, and saves the model checkpoint.
model_fs, losses, accuracy = train_supervised(exp_args, device, "resnet18")


# -------------------------------
# 2. Linear evaluation on frozen self-supervised (SimCLR) embeddings
# -------------------------------
print("=== Running Fully Supervised with Self-Supervised Embeddings (Linear Evaluation) for ResNet18 ===")
# Loads the pre-trained SimCLR-ResNet18, extracts L2-normalised penultimate-layer embeddings,
# trains a linear classifier on top, evaluates it, and saves the classifier checkpoint.
classifier_ss = supervised_with_ss_evaluation(exp_args, device, network_type="resnet18")


# -------------------------------
# 3. Active learning with semi-supervised training in every round
# -------------------------------
print("=== Running Active Learning Experiment: random (Semi-Supervised Training) for ResNet18 ===")
test_acc_rand, cum_budget_rand, labeled_rand = active_learning_experiment(exp_args, selection_method="random")

print("=== Running Active Learning Experiment: original_typiclust (Semi-Supervised Training) for ResNet18 ===")
# Original TypiClust selection (KMeans + per-cluster typicality) combined with the
# confidence-thresholded pseudo-labelling loop (no prior pseudo-labels / PPL).
# The ``*_mod`` names are kept so code referring to them keeps working, but they
# hold ORIGINAL TypiClust results.
test_acc_mod, cum_budget_mod, labeled_mod = active_learning_experiment(exp_args, selection_method="original_typiclust")
typiclust_label = "Original TypiClust"


# -------------------------------
# 4. Plotting results
# -------------------------------
# Test accuracy as a function of the cumulative number of labeled samples.
plot_accuracy_vs_budget(cum_budget_mod, test_acc_mod, title=f"Test Accuracy vs. Cumulative Labeled Samples ({typiclust_label})")
plot_accuracy_vs_budget(cum_budget_rand, test_acc_rand, title="Test Accuracy vs. Cumulative Labeled Samples (Random Selection)")

final_acc_mod = test_acc_mod[-1]
final_acc_rand = test_acc_rand[-1]
methods = [typiclust_label, "Random Selection"]

plt.figure(figsize=(8,5))
plt.bar(methods, [final_acc_mod*100, final_acc_rand*100], color=['blue','red'])
plt.xlabel("Active Learning Selection Method")
plt.ylabel("Final Test Accuracy (%)")
plt.title("Final Test Accuracy Comparison")
plt.ylim(0, 100)
plt.grid(True, axis='y')
plt.show()