# CST (Contrastive Self-Training) for SVHN and MNIST Datasets

This notebook implements CST for domain adaptation between the SVHN and MNIST digit datasets, in either direction (SVHN → MNIST or MNIST → SVHN).

In [None]:
# Import required libraries
import os
import time
import argparse
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR
from sklearn.metrics import accuracy_score
import logging
import copy
import sys

# Import from local modules
# Make sure these modules are in the same directory as the notebook or adjust the path accordingly
sys.path.append(os.path.dirname(os.getcwd()))
from model.wrn import WideResNet
from datasets.data import get_dataset, get_dataset_cst
from utils.train_utils import get_args
from utils.test_utils import AverageMeter, accuracy, setup_logger
from utils.model_utils import init_weights, init_logging, get_cosine_schedule_with_warmup

## Set Configuration and Arguments

In [None]:
# Set up configuration parameters - edit the values in the settings dict as needed.
class Args:
    """Static stand-in for parsed command-line arguments.

    Every entry of ``settings`` becomes an instance attribute, assigned in the
    order it is listed.
    """

    def __init__(self):
        settings = {
            # Basic settings
            'seed': 1,
            'out': './results',
            'num_epochs': 100,  # NOTE: the training loops here use 'epochs'
            'batch_size': 64,
            'lr': 0.001,
            'wd': 5e-4,
            'T': 1.0,
            'alpha': 0.6,
            'momentum': 0.9,
            'ada_threshold': 0.9,
            'lambda_cst': 1.0,  # weight of the CST loss in the total loss
            'lambda_emd': 0.0,
            'lambda_rmc': 0.0,
            'lambda_mmd': 0.0,
            'lambda_coral': 0.0,
            'k': 3,
            'threshold': 0.95,
            'temp': 0.1,
            'epochs': 100,
            'use_ema': True,
            'ema_decay': 0.999,
            'mu': 7,  # unlabeled batch size is batch_size * mu
            'mix_mode': 'mixup',
            'mix_alpha': 0.1,
            'eval_step': 1,
            'total_steps': None,  # filled in by the training pipelines
            'world_size': 1,
            'rank': 0,
            # Dataset settings - focusing only on SVHN and MNIST
            'num_classes': 10,
            'dataset': 'digit',
            'src_dataset': None,  # set by the training pipelines
            'trg_dataset': None,  # set by the training pipelines
            'arch': 'wrn',
            'num_workers': 2,
            'expand_labels': True,
            'data_path': './data',
        }
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)

args = Args()

## Set Up Helper Functions (Device and Random Seed)

In [None]:
# Run on the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's random number generators.

    When CUDA is available, the current GPU and all visible GPUs are seeded
    as well.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# Seed all RNGs before any dataset, loader or model below is created.
set_seed(seed=args.seed)

## Define the EMA Class for Model Averaging

In [None]:
class EMA:
    """Exponential moving average (EMA) of a model's trainable parameters.

    ``register`` snapshots every parameter with ``requires_grad=True``; each
    ``update`` then sets::

        shadow = (1 - decay) * param + decay * shadow

    ``apply_shadow`` temporarily swaps the averaged weights into the model
    (e.g. for evaluation or checkpointing) and ``restore`` swaps the live
    weights back. Only parameters are averaged; module buffers are not
    touched.

    Args:
        model: module whose trainable parameters are averaged.
        decay: weight kept on the previous average at each update.
    """

    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = {}  # parameter name -> averaged tensor
        self.backup = {}  # parameter name -> live tensor while the shadow is applied

    def _trainable(self):
        """Yield ``(name, param)`` for each parameter that requires grad."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                yield name, param

    def _require_all(self, store, what):
        """Raise ``KeyError`` unless ``store`` has an entry for every trainable parameter.

        Checked explicitly (``assert`` is stripped under ``python -O``) and up
        front, so a missing entry is reported before any tensor is modified.
        """
        missing = [name for name, _ in self._trainable() if name not in store]
        if missing:
            raise KeyError(f"{what} is missing parameter(s): {missing}")

    def register(self):
        """Initialise the shadow weights as copies of the current parameters."""
        for name, param in self._trainable():
            self.shadow[name] = param.data.clone()

    def update(self):
        """Move each shadow tensor one EMA step toward the live parameter.

        Raises:
            KeyError: if ``register`` has not recorded some trainable parameter.
        """
        self._require_all(self.shadow, "shadow (call register() first)")
        for name, param in self._trainable():
            # The arithmetic already produces a fresh tensor, so no clone is needed.
            self.shadow[name] = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]

    def apply_shadow(self):
        """Swap the averaged weights into the model, backing up the live ones.

        Raises:
            RuntimeError: if the shadow is already applied. Backing up again
                would overwrite the saved live weights with the shadow copies.
            KeyError: if ``register`` has not recorded some trainable parameter.
        """
        if self.backup:
            raise RuntimeError("apply_shadow() called again before restore()")
        self._require_all(self.shadow, "shadow (call register() first)")
        for name, param in self._trainable():
            self.backup[name] = param.data
            param.data = self.shadow[name]

    def restore(self):
        """Put the live weights saved by ``apply_shadow`` back into the model.

        Raises:
            KeyError: if there is no backup (``apply_shadow`` was not called).
        """
        self._require_all(self.backup, "backup (call apply_shadow() first)")
        for name, param in self._trainable():
            param.data = self.backup[name]
        self.backup = {}

## Define the CST Loss Function

In [None]:
def cst_loss_func(pred_st, pred_w):
    """Consistency loss between strong-view and weak-view logits.

    Both logit tensors are converted to probabilities with a softmax over
    dim 1. The weak-view probabilities are detached, so they act as a fixed
    target and receive no gradient. Returns the mean squared difference over
    all elements.
    """
    target_probs = F.softmax(pred_w, dim=1).detach()
    student_probs = F.softmax(pred_st, dim=1)
    squared_error = (student_probs - target_probs).pow(2)
    return squared_error.mean()

## Initialize Model

In [None]:
def create_model(num_classes=10):
    """Build a freshly initialised WideResNet classifier on ``device``.

    Args:
        num_classes: size of the classifier output. The default of 10 matches
            ``Args.num_classes`` for the digit datasets.

    Returns:
        The model, moved to the module-level ``device``, with ``init_weights``
        applied to every submodule.
    """
    model = WideResNet(num_classes=num_classes).to(device)
    # Weights are initialised after the move to ``device``, as before.
    model.apply(init_weights)
    return model

## Training Functions

In [None]:
def _next_batch(iterator, loader):
    """Return ``(batch, iterator)``, restarting from ``loader`` if ``iterator`` is exhausted."""
    try:
        return next(iterator), iterator
    except StopIteration:
        iterator = iter(loader)
        return next(iterator), iterator


def train_step(model, ema_model, labeled_trainloader, unlabeled_trainloader, optimizer, scheduler, epoch, args):
    """Run one training epoch of supervised cross-entropy plus the CST loss.

    Each step draws a labeled batch ``(inputs, targets)`` and an unlabeled
    batch whose inputs are a ``(weak, strong)`` pair of views. The labels of
    the unlabeled batch are ignored. The weak-view logits are computed without
    gradients and serve as the target for the strong view via
    ``cst_loss_func``. The optimizer and scheduler step once per batch. The
    EMA weights are then updated when ``args.use_ema`` is set.

    Args:
        model: network being trained; batches are moved to the module-level
            ``device``.
        ema_model: ``EMA`` wrapper around ``model``. Only used when
            ``args.use_ema`` is set.
        labeled_trainloader: loader of labeled ``(inputs, targets)`` batches.
        unlabeled_trainloader: loader of ``((weak, strong), label)`` batches.
        optimizer: stepped after every batch.
        scheduler: stepped after every batch.
        epoch: only used in progress messages.
        args: reads ``lambda_cst`` and ``use_ema``.

    Returns:
        ``(avg_total_loss, avg_supervised_loss, avg_cst_loss)`` over the epoch.
    """
    model.train()

    # Running averages for progress reporting.
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_x = AverageMeter()
    losses_cst = AverageMeter()
    end = time.time()

    labeled_iter = iter(labeled_trainloader)
    unlabeled_iter = iter(unlabeled_trainloader)
    # One epoch is one pass over the shorter of the two loaders.
    num_batches = min(len(labeled_trainloader), len(unlabeled_trainloader))

    for batch_idx in range(num_batches):
        # The restart in _next_batch only triggers if a loader's len()
        # overstates how many batches it actually yields.
        (inputs_x, targets_x), labeled_iter = _next_batch(labeled_iter, labeled_trainloader)
        ((inputs_u_w, inputs_u_s), _), unlabeled_iter = _next_batch(unlabeled_iter, unlabeled_trainloader)
        data_time.update(time.time() - end)

        inputs_x = inputs_x.to(device)
        targets_x = targets_x.to(device)
        inputs_u_w = inputs_u_w.to(device)
        inputs_u_s = inputs_u_s.to(device)

        # Supervised loss on the labeled batch.
        logits_x = model(inputs_x)
        loss_x = F.cross_entropy(logits_x, targets_x, reduction='mean')

        # The weak view provides a gradient-free target for the strong view.
        with torch.no_grad():
            logits_u_w = model(inputs_u_w)
        logits_u_s = model(inputs_u_s)
        loss_cst = cst_loss_func(logits_u_s, logits_u_w)

        loss = loss_x + args.lambda_cst * loss_cst

        losses.update(loss.item())
        losses_x.update(loss_x.item())
        losses_cst.update(loss_cst.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        if args.use_ema:
            ema_model.update()

        batch_time.update(time.time() - end)
        end = time.time()

        if (batch_idx + 1) % 50 == 0:
            print(f"Epoch: {epoch} | Batch: {batch_idx + 1}/{num_batches} | "
                  f"Loss: {losses.avg:.4f} | Loss_x: {losses_x.avg:.4f} | Loss_cst: {losses_cst.avg:.4f} | "
                  f"Time: {batch_time.avg:.4f}s | Data: {data_time.avg:.4f}s")

    return losses.avg, losses_x.avg, losses_cst.avg

In [None]:
def validate(model, val_loader, device=None):
    """Return the top-1 accuracy of ``model`` over ``val_loader``.

    Puts the model in eval mode (and leaves it there) and runs without
    gradients. Correct predictions are counted on ``device`` and read back to
    the host once, at the end.

    Args:
        model: classifier whose outputs are reduced with ``torch.max`` over
            dim 1 to get predicted class indices (assumed to be logits of shape
            (batch, num_classes) — TODO confirm for other models).
        val_loader: iterable of ``(inputs, targets)`` batches.
        device: where each batch is placed. Defaults to the device of the
            model's first parameter, or the CPU for a parameterless model.

    Returns:
        The fraction of correctly classified samples as a Python float, or
        ``0.0`` if ``val_loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    correct = torch.zeros((), dtype=torch.long, device=device)
    total = 0

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, dim=1)
            # Accumulate on-device; avoids a host copy of every batch.
            correct += (preds == targets).sum()
            total += targets.numel()

    if total == 0:
        return 0.0
    return correct.item() / total

## Training Pipeline for SVHN to MNIST

In [None]:
def train_svhn_to_mnist():
    """Train CST with SVHN as the labeled source and MNIST as the target.

    Side effects: sets ``args.src_dataset``, ``args.trg_dataset`` and
    ``args.total_steps`` on the module-level ``args``, and creates ``args.out``
    if needed. Whenever test accuracy improves, it saves the weights (the EMA
    weights when ``args.use_ema`` is set) to a direction-specific checkpoint
    in ``args.out``.

    Returns:
        The best test accuracy over all epochs (0.0 if no epoch exceeded 0).
    """
    args.src_dataset = 'svhn'
    args.trg_dataset = 'mnist'

    os.makedirs(args.out, exist_ok=True)
    # Name the checkpoint after the adaptation direction, so runs in the other
    # direction (which share args.out) do not overwrite it.
    ckpt_path = os.path.join(
        args.out, f'best_model_{args.src_dataset}_to_{args.trg_dataset}.pt')

    labeled_dataset, unlabeled_dataset, test_dataset = get_dataset_cst(
        args, args.src_dataset, args.trg_dataset)

    labeled_trainloader = DataLoader(
        labeled_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=True)

    # Unlabeled target batches are mu times larger than labeled ones.
    unlabeled_trainloader = DataLoader(
        unlabeled_dataset,
        batch_size=args.batch_size * args.mu,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=True)

    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers)

    model = create_model()

    ema_model = None
    if args.use_ema:
        ema_model = EMA(model, args.ema_decay)
        ema_model.register()

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)

    # train_step steps the scheduler once per batch, so size it in batches.
    args.total_steps = args.epochs * min(len(labeled_trainloader), len(unlabeled_trainloader))
    scheduler = get_cosine_schedule_with_warmup(optimizer, 0, args.total_steps)

    best_acc = 0.0
    for epoch in range(args.epochs):
        train_loss, _, _ = train_step(
            model, ema_model, labeled_trainloader, unlabeled_trainloader, optimizer, scheduler, epoch, args)

        # Evaluate (and, on a new best, checkpoint) with the EMA weights swapped
        # in once; the finally clause guarantees the live weights come back.
        if args.use_ema:
            ema_model.apply_shadow()
        try:
            test_acc = validate(model, test_loader)
            if test_acc > best_acc:
                best_acc = test_acc
                torch.save(model.state_dict(), ckpt_path)
        finally:
            if args.use_ema:
                ema_model.restore()

        print(f"Epoch: {epoch} | Train Loss: {train_loss:.4f} | Test Acc: {test_acc:.4f} | Best Acc: {best_acc:.4f}")

    return best_acc

## Training Pipeline for MNIST to SVHN

In [None]:
def train_mnist_to_svhn():
    """Train CST with MNIST as the labeled source and SVHN as the target.

    Side effects: sets ``args.src_dataset``, ``args.trg_dataset`` and
    ``args.total_steps`` on the module-level ``args``, and creates ``args.out``
    if needed. Whenever test accuracy improves, it saves the weights (the EMA
    weights when ``args.use_ema`` is set) to a direction-specific checkpoint
    in ``args.out``.

    Returns:
        The best test accuracy over all epochs (0.0 if no epoch exceeded 0).
    """
    args.src_dataset = 'mnist'
    args.trg_dataset = 'svhn'

    os.makedirs(args.out, exist_ok=True)
    # Name the checkpoint after the adaptation direction, so runs in the other
    # direction (which share args.out) do not overwrite it.
    ckpt_path = os.path.join(
        args.out, f'best_model_{args.src_dataset}_to_{args.trg_dataset}.pt')

    labeled_dataset, unlabeled_dataset, test_dataset = get_dataset_cst(
        args, args.src_dataset, args.trg_dataset)

    labeled_trainloader = DataLoader(
        labeled_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=True)

    # Unlabeled target batches are mu times larger than labeled ones.
    unlabeled_trainloader = DataLoader(
        unlabeled_dataset,
        batch_size=args.batch_size * args.mu,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=True)

    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers)

    model = create_model()

    ema_model = None
    if args.use_ema:
        ema_model = EMA(model, args.ema_decay)
        ema_model.register()

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)

    # train_step steps the scheduler once per batch, so size it in batches.
    args.total_steps = args.epochs * min(len(labeled_trainloader), len(unlabeled_trainloader))
    scheduler = get_cosine_schedule_with_warmup(optimizer, 0, args.total_steps)

    best_acc = 0.0
    for epoch in range(args.epochs):
        train_loss, _, _ = train_step(
            model, ema_model, labeled_trainloader, unlabeled_trainloader, optimizer, scheduler, epoch, args)

        # Evaluate (and, on a new best, checkpoint) with the EMA weights swapped
        # in once; the finally clause guarantees the live weights come back.
        if args.use_ema:
            ema_model.apply_shadow()
        try:
            test_acc = validate(model, test_loader)
            if test_acc > best_acc:
                best_acc = test_acc
                torch.save(model.state_dict(), ckpt_path)
        finally:
            if args.use_ema:
                ema_model.restore()

        print(f"Epoch: {epoch} | Train Loss: {train_loss:.4f} | Test Acc: {test_acc:.4f} | Best Acc: {best_acc:.4f}")

    return best_acc

## Run Experiments

In [None]:
# Choose which experiment to run: "svhn_to_mnist" or "mnist_to_svhn".
experiment = "svhn_to_mnist"

# Each supported experiment maps to its display name and training entry point.
# An unrecognised name raises instead of falling through to another experiment.
_EXPERIMENTS = {
    "svhn_to_mnist": ("SVHN to MNIST", train_svhn_to_mnist),
    "mnist_to_svhn": ("MNIST to SVHN", train_mnist_to_svhn),
}
if experiment not in _EXPERIMENTS:
    raise ValueError(f"Unknown experiment {experiment!r}; expected one of {sorted(_EXPERIMENTS)}")

experiment_title, run_experiment = _EXPERIMENTS[experiment]
print(f"Running {experiment_title} experiment...")
best_acc = run_experiment()
print(f"Best accuracy for {experiment_title}: {best_acc:.4f}")

## Visualize Dataset Samples (Optional)

In [None]:
# Optional: Add code for visualization
import matplotlib.pyplot as plt

def visualize_samples(dataset_name):
    """Plot the first ten items of the SVHN or MNIST dataset in a 2x5 grid.

    ``dataset_name == 'svhn'`` loads SVHN via ``get_dataset``; any other
    value loads MNIST. Each image is assumed to be a C x H x W tensor and is
    transposed to H x W x C for ``imshow``. Single-channel images are
    squeezed and drawn with a gray colormap.
    """
    source_name = 'svhn' if dataset_name == 'svhn' else 'mnist'
    dataset = get_dataset(args, source_name)

    fig, grid = plt.subplots(2, 5, figsize=(12, 6))
    for idx, ax in enumerate(grid.flatten()):
        image, label = dataset[idx]
        image = image.numpy().transpose(1, 2, 0)  # C x H x W -> H x W x C

        if image.shape[2] == 1:
            image = image.squeeze()
        colormap = 'gray' if image.ndim == 2 else None

        ax.imshow(image, cmap=colormap)
        ax.set_title(f"Label: {label}")
        ax.axis('off')

    plt.tight_layout()
    plt.suptitle(f"{dataset_name.upper()} Dataset Samples", y=1.02)
    plt.show()

In [None]:
# Plot a 2x5 grid of SVHN samples.
visualize_samples(dataset_name='svhn')

In [None]:
# Plot a 2x5 grid of MNIST samples.
visualize_samples(dataset_name='mnist')