In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset, Dataset
import matplotlib.pyplot as plt
import sys

# Import existing modules 
from typiclust_alg import SimCLRResNet18, compute_embeddings, typical_clustering_selection, DEVICE
from visualisation import plot_tsne, set_seed


def evaluate_model(model, dataloader, device=DEVICE):
    """
    Evaluates the model on the provided dataloader.

    Args:
        model: Module called as ``model(images)``; its output is reduced
            with ``argmax(dim=1)`` to obtain one predicted class per sample.
        dataloader: Iterable yielding ``(images, labels)`` tensor pairs.
        device: Device the batches are moved to before the forward pass.

    Returns:
        The fraction of correctly classified samples as a float, or 0.0
        when the dataloader yields no samples.

    Note:
        The model is switched to eval mode and left in that mode.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    # Guard against an empty dataloader instead of raising ZeroDivisionError.
    if total == 0:
        return 0.0
    return correct / total


def get_cifar10_datasets():
    """
    Build the CIFAR-10 train and test datasets (downloaded to ./data if needed).

    The training split is augmented with a padded random crop and a random
    horizontal flip; both splits are converted to tensors and normalized
    with the same per-channel mean and standard deviation.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.247, 0.243, 0.261)

    def _tensor_and_normalize():
        # Fresh transform instances for each pipeline.
        return [
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]

    augmentation = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    train_transform = transforms.Compose(augmentation + _tensor_and_normalize())
    test_transform = transforms.Compose(_tensor_and_normalize())

    shared_kwargs = {"root": "./data", "download": True}
    train_dataset = torchvision.datasets.CIFAR10(
        train=True, transform=train_transform, **shared_kwargs
    )
    test_dataset = torchvision.datasets.CIFAR10(
        train=False, transform=test_transform, **shared_kwargs
    )
    return train_dataset, test_dataset



def build_wide_resnet():
    """
    Builds a WideResNet model with a 10-class output layer.

    The paper uses WideResNet-28 for CIFAR-10.
    Here, I use torchvision's wide_resnet50_2 as a proxy.

    Returns:
        The model, moved to DEVICE, with its final fully connected layer
        replaced by a fresh ``nn.Linear(in_features, 10)``.
    """
    # No pretrained weights are requested. The argument is omitted rather than
    # passing the deprecated ``pretrained=False``: the default is "no weights"
    # both before and after torchvision replaced ``pretrained`` with ``weights``.
    model = torchvision.models.wide_resnet50_2()
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, 10)
    return model.to(DEVICE)



def train_model_iterations(model, dataloader, total_iterations, lr, device=DEVICE):
    """
    Trains the model for ``total_iterations`` optimizer steps.

    Uses SGD (momentum 0.9, weight decay 0.0005, Nesterov) with the specified
    learning rate and a cosine annealing scheduler with
    T_max = total_iterations, stepped once per batch. The dataloader is
    cycled as many times as needed to reach the requested number of steps.

    Args:
        model: Module to train; switched to train mode.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        total_iterations: Number of optimizer steps to perform.
        lr: Initial learning rate for SGD.
        device: Device the batches are moved to.

    Returns:
        The same model instance after training.

    Raises:
        ValueError: If the dataloader yields no batches while iterations
            remain (the loop could otherwise never terminate).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=lr,
        momentum=0.9,
        weight_decay=0.0005,
        nesterov=True
    )
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_iterations)

    model.train()
    iteration = 0

    # For the progress bar display
    last_shown_progress = 0

    while iteration < total_iterations:
        iteration_at_pass_start = iteration
        for images, labels in dataloader:
            if iteration >= total_iterations:
                break
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()

            iteration += 1

            # Print only when the integer percentage advances.
            current_progress = int(100 * iteration / total_iterations)
            if current_progress - last_shown_progress >= 1:
                print(f"[Semi-Sup Training] Progress: {current_progress}%")
                last_shown_progress = current_progress

        # A full pass that made no progress means the dataloader is empty;
        # without this check the outer loop would spin forever.
        if iteration == iteration_at_pass_start:
            raise ValueError(
                "dataloader yielded no batches; cannot run the requested "
                f"{total_iterations} training iterations."
            )

    return model

def select_samples_typiclust(dataset, budget, encoder):
    """
    Uses the pre-loaded SimCLR encoder provided from the Forums to compute embeddings and select samples

    Args:
        dataset: Dataset passed to ``compute_embeddings``.
        budget: Number of samples to select, forwarded to
            ``typical_clustering_selection``.
        encoder: Encoder used to compute the embeddings.

    Returns:
        The selected indices exactly as returned by
        ``typical_clustering_selection`` (first element of its result).
    """
    # Only the embeddings are needed; the second return value is ignored.
    all_embeddings, _ = compute_embeddings(encoder, dataset, batch_size=128, num_workers=4)

    # The previous implementation also iterated the whole dataset to collect
    # labels that were never used; that extra full pass has been removed.
    selected_indices, _ = typical_clustering_selection(
        all_embeddings,
        budget=budget,
        k_nn=20,
        random_state=42
    )
    print(f"Number of typical points selected (budget) = {len(selected_indices)}")
    return selected_indices

def generate_pseudo_labels(model, unlabeled_subset, batch_size=64, device=DEVICE):
    """
    Predict a class for every image in 'unlabeled_subset'.

    The subset is traversed in order (no shuffling), so the i-th returned
    label belongs to the i-th item of the subset. Labels stored in the
    subset are ignored.

    Returns:
        A flat Python list with the argmax class index of each image.
    """
    model.eval()
    batches = DataLoader(unlabeled_subset, batch_size=batch_size, shuffle=False)

    predicted = []
    with torch.no_grad():
        for batch_images, _ignored_targets in batches:
            logits = model(batch_images.to(device))
            batch_classes = logits.argmax(dim=1)
            predicted += batch_classes.cpu().numpy().tolist()
    return predicted

class CombinedDataset(Dataset):
    """
    Combines:
      - Labeled data from 'labeled_indices'
      - Unlabeled data from 'unlabeled_indices' with pseudo_labels

    Items are served in the order labeled indices first, then unlabeled
    indices. Images always come from 'base_dataset'; the label is the real
    one for labeled indices and the supplied pseudo-label otherwise.

    Args:
        base_dataset: Indexable dataset returning ``(image, label)`` pairs.
        labeled_indices: Iterable of indices into 'base_dataset' whose real
            labels are kept.
        unlabeled_indices: Iterable of indices into 'base_dataset' that
            receive pseudo-labels.
        pseudo_labels: Sequence of labels, one per unlabeled index, in the
            same order as 'unlabeled_indices'.

    Raises:
        ValueError: If the number of pseudo-labels differs from the number
            of unlabeled indices.
    """
    def __init__(self, base_dataset, labeled_indices, unlabeled_indices, pseudo_labels):
        # Normalize to plain lists of ints so that `+` below is always list
        # concatenation; tuples/ranges would raise and NumPy arrays would be
        # added elementwise instead.
        labeled = [int(i) for i in labeled_indices]
        unlabeled = [int(i) for i in unlabeled_indices]

        # Validate before touching the base dataset so bad input fails fast.
        if len(unlabeled) != len(pseudo_labels):
            raise ValueError(
                f"Mismatch: {len(unlabeled)} unlabeled indices but "
                f"{len(pseudo_labels)} pseudo-labels."
            )

        self.base_dataset = base_dataset
        self.combined_indices = labeled + unlabeled
        self.label_map = {}

        # Real labels for labeled indices
        for idx in labeled:
            _, real_label = self.base_dataset[idx]
            self.label_map[idx] = real_label

        # Pseudo-labels for unlabeled indices
        for i, u_idx in enumerate(unlabeled):
            self.label_map[u_idx] = pseudo_labels[i]

    def __len__(self):
        """Return the total number of labeled plus pseudo-labeled items."""
        return len(self.combined_indices)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the item at position 'index'."""
        real_idx = self.combined_indices[index]
        # The label stored in the base dataset is discarded in favor of the
        # mapped (real or pseudo) label.
        img, _ = self.base_dataset[real_idx]
        label = self.label_map[real_idx]
        return img, label

def evaluate_semi_supervised(method='typiclust', budget=10,
                             total_iterations=40000, pseudo_iterations=10000):
    """
    Evaluate semi-supervised learning using a simplified pseudo-labeling approach

    Pipeline:
      1. Select 'budget' samples to label with TypiClust on SimCLR embeddings.
      2. Train a fresh model on the labeled samples for 'total_iterations'.
      3. Pseudo-label every remaining training sample with that model.
      4. Fine-tune on labeled + pseudo-labeled data for 'pseudo_iterations'.
      5. Report accuracy on the CIFAR-10 test set.

    Args:
        method: Currently unused; selection is always done with TypiClust.
        budget: Number of labeled samples to select.
        total_iterations: Optimizer steps for the labeled-only stage.
        pseudo_iterations: Optimizer steps for the fine-tuning stage.

    Returns:
        Test accuracy as a float.
    """
    # 1. Load datasets
    # NOTE(review): train_dataset applies random crop/flip, so the embeddings
    # used for selection and the pseudo-labels are computed on augmented
    # images -- confirm this is intended.
    train_dataset, test_dataset = get_cifar10_datasets()

    # 2. Load a pretrained SimCLR encoder for sample selection
    encoder = SimCLRResNet18(feature_dim=128).to(DEVICE)
    checkpoint_path = 'model/simclr_cifar_10.pth.tar'
    if os.path.exists(checkpoint_path):
        # SECURITY NOTE: torch.load unpickles the file; only load checkpoints
        # from a trusted source (consider weights_only=True if the checkpoint
        # contains nothing but tensors).
        checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
        state_dict = checkpoint.get('state_dict', checkpoint)
        # strict=False: keys that do not match the encoder are skipped silently.
        encoder.load_state_dict(state_dict, strict=False)
        print("Loaded pretrained SimCLR model for TPC-RP selection.")
    else:
        print("Pretrained checkpoint not found; selection may be random.")
    encoder.eval()

    # 3. Select labeled samples using Typiclust
    # Normalize to a list of plain ints so list concatenation and set
    # membership behave the same whatever sequence type the selector returns.
    labeled_indices = [int(i) for i in select_samples_typiclust(train_dataset, budget, encoder)]
    labeled_lookup = set(labeled_indices)
    # Ascending order keeps the unlabeled ordering deterministic across runs.
    unlabeled_indices = [i for i in range(len(train_dataset)) if i not in labeled_lookup]
    print(f"Labeled samples: {len(labeled_indices)}; Unlabeled samples: {len(unlabeled_indices)}")

    # 4. Stage 1: Train on labeled data
    model = build_wide_resnet()
    labeled_subset = Subset(train_dataset, labeled_indices)
    labeled_loader = DataLoader(labeled_subset, batch_size=64, shuffle=True)
    print(f"Stage 1: Training on labeled data ({total_iterations} iters)...")
    model = train_model_iterations(model, labeled_loader, total_iterations, lr=0.03, device=DEVICE)

    # 5. Generate pseudo-labels for unlabeled data
    unlabeled_subset = Subset(train_dataset, unlabeled_indices)
    print("Stage 2: Generating pseudo-labels for unlabeled data...")
    pseudo_labels = generate_pseudo_labels(model, unlabeled_subset, batch_size=64, device=DEVICE)

    # 6. Combine labeled + pseudo-labeled data, then fine-tune
    combined_dataset = CombinedDataset(train_dataset, labeled_indices, unlabeled_indices, pseudo_labels)
    combined_loader = DataLoader(combined_dataset, batch_size=64, shuffle=True)
    print(f"Stage 3: Fine-tuning on combined data ({pseudo_iterations} iters)...")
    model = train_model_iterations(model, combined_loader, pseudo_iterations, lr=0.01, device=DEVICE)

    # 7. Final evaluation on test set
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
    acc = evaluate_model(model, test_loader, DEVICE)
    return acc



def main_experiments(num_experiments=3):
    """
    Run the entire semi-supervised evaluation 'num_experiments' times with
    different random seeds (11, 22, 33, ...), each time for 400k + 100k
    iterations. Print out the mean and standard deviation of the accuracy.

    Args:
        num_experiments: Number of independent runs. Values above three
            continue the seed sequence (44, 55, ...) instead of being
            silently capped at three runs.
    """
    # Same seeds as before for the first three runs, extended on demand.
    seeds = [11 * (run + 1) for run in range(num_experiments)]

    accuracies = []
    for seed in seeds:
        set_seed(seed)
        acc = evaluate_semi_supervised(
            method='typiclust',
            budget=10,              # 10 labeled samples for CIFAR-10
            total_iterations=400000,
            pseudo_iterations=100000
        )
        accuracies.append(acc)
        print(f"Experiment with seed {seed} => Accuracy: {acc*100:.2f}%")

    # Avoid NaN statistics (and a NumPy warning) when nothing was run.
    if not accuracies:
        print("\nNo experiments were run.")
        return

    mean_acc = np.mean(accuracies)
    std_acc = np.std(accuracies)
    # Report the number of runs actually performed.
    print(f"\nAfter {len(accuracies)} runs:")
    print(f"Mean Accuracy: {mean_acc*100:.2f}% (+/- {std_acc*100:.2f}%)")

# Script entry point: run three experiments only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main_experiments(num_experiments=3)


Random seed set to 11
Files already downloaded and verified
Files already downloaded and verified


  checkpoint = torch.load(checkpoint_path, map_location=DEVICE)


Loaded pretrained SimCLR model for TPC-RP selection.
Number of typical points selected (budget) = 10
Labeled samples: 10; Unlabeled samples: 49990




Stage 1: Training on labeled data (400k iters)...
[Semi-Sup Training] Progress: 1%
[Semi-Sup Training] Progress: 2%
[Semi-Sup Training] Progress: 3%
[Semi-Sup Training] Progress: 4%
[Semi-Sup Training] Progress: 5%
[Semi-Sup Training] Progress: 6%
[Semi-Sup Training] Progress: 7%
[Semi-Sup Training] Progress: 8%
[Semi-Sup Training] Progress: 9%
[Semi-Sup Training] Progress: 10%
[Semi-Sup Training] Progress: 11%
[Semi-Sup Training] Progress: 12%
[Semi-Sup Training] Progress: 13%
[Semi-Sup Training] Progress: 14%
[Semi-Sup Training] Progress: 15%
[Semi-Sup Training] Progress: 16%
[Semi-Sup Training] Progress: 17%
[Semi-Sup Training] Progress: 18%
[Semi-Sup Training] Progress: 19%
[Semi-Sup Training] Progress: 20%
[Semi-Sup Training] Progress: 21%
[Semi-Sup Training] Progress: 22%
[Semi-Sup Training] Progress: 23%
[Semi-Sup Training] Progress: 24%
[Semi-Sup Training] Progress: 25%
[Semi-Sup Training] Progress: 26%
[Semi-Sup Training] Progress: 27%
[Semi-Sup Training] Progress: 28%
[Semi-S

  checkpoint = torch.load(checkpoint_path, map_location=DEVICE)


Loaded pretrained SimCLR model for TPC-RP selection.
Number of typical points selected (budget) = 10
Labeled samples: 10; Unlabeled samples: 49990




Stage 1: Training on labeled data (400k iters)...
[Semi-Sup Training] Progress: 1%
[Semi-Sup Training] Progress: 2%
[Semi-Sup Training] Progress: 3%
[Semi-Sup Training] Progress: 4%
[Semi-Sup Training] Progress: 5%
[Semi-Sup Training] Progress: 6%
[Semi-Sup Training] Progress: 7%
[Semi-Sup Training] Progress: 8%
[Semi-Sup Training] Progress: 9%
[Semi-Sup Training] Progress: 10%
[Semi-Sup Training] Progress: 11%
[Semi-Sup Training] Progress: 12%
[Semi-Sup Training] Progress: 13%
[Semi-Sup Training] Progress: 14%
[Semi-Sup Training] Progress: 15%
[Semi-Sup Training] Progress: 16%
[Semi-Sup Training] Progress: 17%
[Semi-Sup Training] Progress: 18%
[Semi-Sup Training] Progress: 19%
[Semi-Sup Training] Progress: 20%
[Semi-Sup Training] Progress: 21%
[Semi-Sup Training] Progress: 22%
[Semi-Sup Training] Progress: 23%
[Semi-Sup Training] Progress: 24%
[Semi-Sup Training] Progress: 25%
[Semi-Sup Training] Progress: 26%
[Semi-Sup Training] Progress: 27%
[Semi-Sup Training] Progress: 28%
[Semi-S

  checkpoint = torch.load(checkpoint_path, map_location=DEVICE)


Loaded pretrained SimCLR model for TPC-RP selection.
Number of typical points selected (budget) = 10
Labeled samples: 10; Unlabeled samples: 49990




Stage 1: Training on labeled data (400k iters)...
[Semi-Sup Training] Progress: 1%
[Semi-Sup Training] Progress: 2%
[Semi-Sup Training] Progress: 3%
[Semi-Sup Training] Progress: 4%
[Semi-Sup Training] Progress: 5%
[Semi-Sup Training] Progress: 6%
[Semi-Sup Training] Progress: 7%
[Semi-Sup Training] Progress: 8%
[Semi-Sup Training] Progress: 9%
[Semi-Sup Training] Progress: 10%
[Semi-Sup Training] Progress: 11%
[Semi-Sup Training] Progress: 12%
[Semi-Sup Training] Progress: 13%
[Semi-Sup Training] Progress: 14%
[Semi-Sup Training] Progress: 15%
[Semi-Sup Training] Progress: 16%
[Semi-Sup Training] Progress: 17%
[Semi-Sup Training] Progress: 18%
[Semi-Sup Training] Progress: 19%
[Semi-Sup Training] Progress: 20%
[Semi-Sup Training] Progress: 21%
[Semi-Sup Training] Progress: 22%
[Semi-Sup Training] Progress: 23%
[Semi-Sup Training] Progress: 24%
[Semi-Sup Training] Progress: 25%
[Semi-Sup Training] Progress: 26%
[Semi-Sup Training] Progress: 27%
[Semi-Sup Training] Progress: 28%
[Semi-S