# CST (Cycle Self-Training) for SVHN → MNIST

This notebook adapts the CST domain adaptation pipeline to the SVHN → MNIST setting. Apart from a few utilities from your repo, everything is defined here; copy the relevant functions and classes into the indicated cells.

In [ ]:
# --- Cell 1: Standard Library & Torch Imports ---
import os
import random
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torchvision.transforms as T
from torchvision.datasets import SVHN, MNIST

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
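Cell 1 imports `random` and `cudnn`, which are typically used to make runs reproducible. A minimal seeding sketch follows. `set_seed` is a hypothetical helper, not something from your repo, and full determinism on GPU may still require further settings.

```python
import random

import torch
import torch.backends.cudnn as cudnn


def set_seed(seed: int, deterministic: bool = True) -> None:
    """Seed Python's and PyTorch's RNGs (hypothetical helper)."""
    random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU and CUDA generators
    if deterministic:
        # Trade cuDNN autotuning speed for repeatable convolution results.
        cudnn.deterministic = True
        cudnn.benchmark = False


set_seed(0)
a = torch.rand(3)
set_seed(0)
b = torch.rand(3)
print(torch.equal(a, b))  # True: same seed, same random draws
```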

## Cell 2: Utility Imports from Your Repo

Paste (or reimplement) the following:
- ResizeImage (from your `common/vision/transforms`)
- ForeverDataIterator, AverageMeter, ProgressMeter, accuracy, etc.
- ImageClassifier (from `fix_utils`)
- SAM optimizer (from `sam`)
- rand_augment_transform

If you don't have these locally, you can temporarily comment them out and use basic equivalents.

In [ ]:
# --- Cell 2: Paste or reimplement your local utility classes and functions here ---
# Example: (Replace with your actual code)
# class ResizeImage: ...
# class ForeverDataIterator: ...
# def accuracy(...): ...
# class ImageClassifier(nn.Module): ...
# class SAM(torch.optim.Optimizer): ...
# def rand_augment_transform(...): ...
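If you need to reimplement the simpler helpers, here is a minimal sketch of what `ForeverDataIterator`, `AverageMeter`, and `accuracy` typically do. These are assumptions about their behavior, not copies of your repo's code; match your versions, which may differ in details such as return types.

```python
from typing import List, Sequence

import torch
from torch.utils.data import DataLoader, TensorDataset


class ForeverDataIterator:
    """Cycle over a DataLoader indefinitely, re-creating the iterator (and reshuffling) when exhausted."""

    def __init__(self, data_loader: DataLoader):
        self.data_loader = data_loader
        self.iter = iter(data_loader)

    def __next__(self):
        try:
            data = next(self.iter)
        except StopIteration:
            self.iter = iter(self.data_loader)
            data = next(self.iter)
        return data

    def __len__(self) -> int:
        return len(self.data_loader)


class AverageMeter:
    """Track the running average of a scalar such as a loss or an accuracy."""

    def __init__(self, name: str):
        self.name = name
        self.reset()

    def reset(self) -> None:
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, val: float, n: int = 1) -> None:
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output: torch.Tensor, target: torch.Tensor, topk: Sequence[int] = (1,)) -> List[float]:
    """Top-k accuracy in percent for each k in `topk`."""
    with torch.no_grad():
        _, pred = output.topk(max(topk), dim=1)  # (N, maxk) predicted class indices
        correct = pred.eq(target.view(-1, 1))  # compare each guess with the true label
        return [correct[:, :k].any(dim=1).float().mean().item() * 100.0 for k in topk]


# Usage: 5 samples, batch_size=2, drop_last=True -> 2 batches per pass.
loader = DataLoader(TensorDataset(torch.arange(5)), batch_size=2, drop_last=True)
it = ForeverDataIterator(loader)
batches = [next(it)[0].tolist() for _ in range(3)]
print(batches)  # [[0, 1], [2, 3], [0, 1]]  (the third call restarts the loader)
print(accuracy(torch.tensor([[0.1, 0.9], [0.8, 0.2]]), torch.tensor([1, 1]), topk=(1, 2)))  # [50.0, 100.0]
```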

## Cell 3: CST Loss and Training Utilities

Paste your definitions for:
- entropy
- TsallisEntropy
- (Optionally) any additional loss/utilities

In [ ]:
# --- Cell 3: CST loss and utilities ---
def entropy(predictions: torch.Tensor, reduction='none') -> torch.Tensor:
    """Shannon entropy of each row of `predictions` (N x C probabilities)."""
    epsilon = 1e-5  # avoids log(0)
    H = -predictions * torch.log(predictions + epsilon)
    H = H.sum(dim=1)
    if reduction == 'mean':
        return H.mean()
    else:
        return H

class TsallisEntropy(nn.Module):
    """Tsallis-entropy regularizer on logits; lower-entropy (more confident) samples get larger weights."""
    def __init__(self, temperature: float, alpha: float):
        super(TsallisEntropy, self).__init__()
        self.temperature = temperature
        self.alpha = alpha

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        N, C = logits.shape
        pred = F.softmax(logits / self.temperature, dim=1)
        # Per-sample weights 1 + exp(-H), treated as constants and rescaled to sum to N.
        entropy_weight = entropy(pred).detach()
        entropy_weight = 1 + torch.exp(-entropy_weight)
        entropy_weight = (N * entropy_weight / torch.sum(entropy_weight)).unsqueeze(dim=1)
        # Weighted probability mass per class, shape (1, C).
        sum_dim = torch.sum(pred * entropy_weight, dim=0).unsqueeze(dim=0)
        return 1 / (self.alpha - 1) * torch.sum(
            1 / torch.mean(sum_dim) - torch.sum(pred ** self.alpha / sum_dim * entropy_weight, dim=-1)
        )
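`TsallisEntropy` builds on the Tsallis entropy $S_\alpha(p) = \frac{1}{\alpha-1}\left(1 - \sum_c p_c^\alpha\right)$, which recovers the Shannon entropy computed by `entropy` as $\alpha \to 1$. The sketch below evaluates only the plain, unweighted formula for single distributions, so you can sanity-check its range and the $\alpha \to 1$ limit. The `tsallis_entropy` helper is hypothetical and is not the class's exact batch-weighted objective.

```python
import math

import torch


def tsallis_entropy(p: torch.Tensor, alpha: float) -> torch.Tensor:
    """Plain Tsallis entropy of each row of `p` (rows are probability vectors), alpha != 1."""
    return (1.0 - (p ** alpha).sum(dim=1)) / (alpha - 1.0)


uniform = torch.full((1, 10), 0.1, dtype=torch.float64)
one_hot = torch.zeros(1, 10, dtype=torch.float64)
one_hot[0, 3] = 1.0

print(tsallis_entropy(uniform, alpha=2.0))     # ≈ 0.9: maximal uncertainty for alpha = 2
print(tsallis_entropy(one_hot, alpha=2.0))     # 0: a fully confident prediction
print(tsallis_entropy(uniform, alpha=1.0001))  # close to the Shannon value
print(math.log(10))                            # Shannon entropy of the 10-class uniform ≈ 2.3026
```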

## Cell 4: Transforms for SVHN and MNIST

MNIST images are single-channel, so they must be converted to 3 channels to match SVHN and the ResNet backbone. This cell also defines the weak/strong augmentation pair used for FixMatch-style pseudo-labeling.

In [ ]:
# --- Cell 4: Transforms ---
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)
SVHN_MEAN, SVHN_STD = (0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970)

class ToThreeChannels:
    def __call__(self, img):
        return img.repeat(3, 1, 1)

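def get_mnist_transform(crop=True):
    # crop=True: deterministic center crop; crop=False: RandomResizedCrop.
    normalize = T.Normalize(mean=[0.1307]*3, std=[0.3081]*3)
    base = [
        T.Resize(256),
        T.CenterCrop(224) if crop else T.RandomResizedCrop(224),
        # No RandomHorizontalFlip: this transform is used for validation, which should be deterministic.
        T.ToTensor(),
        ToThreeChannels(),  # 1 -> 3 channels; must come after ToTensor
        normalize
    ]
    return T.Compose(base)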

def get_svhn_transform(crop=True):
    normalize = T.Normalize(mean=SVHN_MEAN, std=SVHN_STD)
    base = [
        T.Resize(256),
        T.CenterCrop(224) if crop else T.RandomResizedCrop(224),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        normalize
    ]
    return T.Compose(base)

# If you have a rand_augment_transform, use it in the strong branch below
class TransformFixMatch:
    def __init__(self, dataset='mnist'):
        if dataset == 'mnist':
            normalize = T.Normalize(mean=[0.1307]*3, std=[0.3081]*3)
            to3 = ToThreeChannels()
            weak = [T.Resize(256), T.CenterCrop(224), T.RandomHorizontalFlip(), T.ToTensor(), to3, normalize]
            # Keep only the PIL-level ops (drop ToTensor, to3, normalize) so ToTensor is applied exactly once.
            strong = weak[:-3] + [T.RandomApply([T.ColorJitter(0.4,0.4,0.4,0)], p=1.0)] # add more if needed
            strong += [T.ToTensor(), to3, normalize]
        else:
            normalize = T.Normalize(mean=SVHN_MEAN, std=SVHN_STD)
            weak = [T.Resize(256), T.CenterCrop(224), T.RandomHorizontalFlip(), T.ToTensor(), normalize]
            strong = weak[:-2] + [T.RandomApply([T.ColorJitter(0.4,0.4,0.4,0)], p=1.0)]
            strong += [T.ToTensor(), normalize]
        self.weak = T.Compose(weak)
        self.strong = T.Compose(strong)
    def __call__(self, x):
        weak = self.weak(x)
        strong = self.strong(x)
        return weak, strong

## Cell 5: Data Preparation

Download and prepare SVHN (labeled source) and MNIST (unlabeled target). The MNIST test split is used as the validation set.

In [ ]:
# --- Cell 5: Data preparation ---
root = './data'  # change if desired
batch_size = 28
num_workers = 2

train_transform = get_svhn_transform(crop=True)
unlabeled_transform = TransformFixMatch('mnist')
val_transform = get_mnist_transform(crop=True)

train_source_dataset = SVHN(root=root, split='train', download=True, transform=train_transform)
train_source_loader = DataLoader(train_source_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)

train_target_dataset = MNIST(root=root, train=True, download=True, transform=unlabeled_transform)
train_target_loader = DataLoader(train_target_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)

val_dataset = MNIST(root=root, train=False, download=True, transform=val_transform)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Optionally: Wrap loaders in ForeverDataIterator if you have it
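Because `TransformFixMatch` returns a `(weak, strong)` pair, each batch from `train_target_loader` arrives as `[[weak_batch, strong_batch], labels]` after PyTorch's default collate. The MNIST labels come along but should not be used for training in the unsupervised target setting. The sketch below reproduces that structure with a hypothetical toy dataset, so the unpacking can be checked without downloading anything.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class PairedToyDataset(Dataset):
    """Hypothetical stand-in returning ((weak_view, strong_view), label), like MNIST + TransformFixMatch."""

    def __len__(self) -> int:
        return 8

    def __getitem__(self, idx: int):
        x = torch.full((3, 4, 4), float(idx))
        weak, strong = x, x + 0.5  # pretend augmentations
        return (weak, strong), idx % 10


loader = DataLoader(PairedToyDataset(), batch_size=4, shuffle=False, drop_last=True)
(x_t_weak, x_t_strong), labels_t = next(iter(loader))  # labels_t: ignore for unsupervised training
print(x_t_weak.shape, x_t_strong.shape, labels_t.tolist())
# torch.Size([4, 3, 4, 4]) torch.Size([4, 3, 4, 4]) [0, 1, 2, 3]
```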

## Cell 6: Model and Optimizer

Instantiate your model (e.g., ResNet-18 with attached bottleneck and classifier head for 10 classes) and optimizer.

In [ ]:
# --- Cell 6: Model & Optimizer ---
# Paste or implement your model and bottleneck logic here, or use torchvision.models
import torchvision.models as models

# Assume ImageClassifier is your custom class with backbone & bottleneck
# backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
# classifier = ImageClassifier(backbone, num_classes=10, bottleneck_dim=256).to(device)

# For demonstration, use a simple model (replace this with your ImageClassifier):
backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)  # ImageNet weights; `pretrained=True` is deprecated
backbone.fc = nn.Linear(backbone.fc.in_features, 10)
classifier = backbone.to(device)

# Setup optimizer and scheduler
lr = 0.005
optimizer = SGD(classifier.parameters(), lr=lr, momentum=0.9, weight_decay=1e-3)
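# LambdaLR multiplies the optimizer's initial lr by the returned factor, so return only the decay term
# (multiplying by lr here as well would start training at lr**2 = 2.5e-5).
lr_scheduler = LambdaLR(optimizer, lambda x: (1. + 0.001 * float(x)) ** (-0.75))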
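`LambdaLR` multiplies each parameter group's initial learning rate by the value the lambda returns, so the lambda should return only the decay factor $(1 + \gamma t)^{-0.75}$ with $\gamma = 0.001$, where $t$ counts `lr_scheduler.step()` calls. The check below uses a toy one-parameter optimizer (an assumption for illustration) and records the learning rate over 1000 steps.

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

base_lr, gamma, decay = 0.005, 0.001, 0.75
param = torch.nn.Parameter(torch.zeros(1))
opt = SGD([param], lr=base_lr, momentum=0.9)
sched = LambdaLR(opt, lambda t: (1.0 + gamma * float(t)) ** (-decay))

lrs = {}
for t in range(1001):
    lrs[t] = opt.param_groups[0]["lr"]
    opt.step()    # a real loop would compute a loss and call backward() first
    sched.step()

print(lrs[0], lrs[1000])  # 0.005 and 0.005 * 2 ** -0.75 ≈ 0.00297
```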

## Cell 7: Training and Validation Functions

Paste your `train`, `validate`, and any other functions needed for the loop.

In [ ]:
# --- Cell 7: Training and validation loops ---
# Paste your 'train' and 'validate' functions here from run_cst.py
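# Example placeholders (signatures match the commented-out calls in Cell 8):
def train(train_source_iter, train_target_iter, classifier, ts_loss, optimizer, lr_scheduler, epoch, args):
    raise NotImplementedError("Paste `train` from run_cst.py")

def validate(val_loader, classifier, args):
    raise NotImplementedError("Paste `validate` from run_cst.py")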
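As a starting point before pasting the real functions, the sketch below shows one simplified adaptation step and an evaluation loop. It combines source cross-entropy, a FixMatch-style pseudo-label loss (confidence threshold on the weak view, applied to the strong view), and a regularizer passed in as `ts_loss` (e.g. `TsallisEntropy`). It is not the full CST objective: SAM updates and any other CST-specific terms in your `run_cst.py` are omitted. `threshold=0.95` (FixMatch's usual default), `trade_off`, applying `ts_loss` to the weak-view logits, and a model returning plain logits are all assumptions.

```python
from typing import Callable, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def train_step(x_s: torch.Tensor, y_s: torch.Tensor, x_t_weak: torch.Tensor, x_t_strong: torch.Tensor,
               model: nn.Module, ts_loss: Callable[[torch.Tensor], torch.Tensor],
               optimizer: torch.optim.Optimizer,
               threshold: float = 0.95, trade_off: float = 1.0) -> Tuple[float, float]:
    """One simplified step: source CE + thresholded pseudo-label CE + entropy regularizer."""
    model.train()
    # A single forward pass over source, weak-target, and strong-target images.
    logits = model(torch.cat([x_s, x_t_weak, x_t_strong]))
    n_s, n_t = x_s.size(0), x_t_weak.size(0)
    logits_s, logits_t_weak, logits_t_strong = logits.split([n_s, n_t, n_t])

    cls_loss = F.cross_entropy(logits_s, y_s)
    # Pseudo-labels from the weak view; no gradient flows through them.
    with torch.no_grad():
        confidence, pseudo = F.softmax(logits_t_weak, dim=1).max(dim=1)
        mask = (confidence >= threshold).float()
    pl_loss = (F.cross_entropy(logits_t_strong, pseudo, reduction="none") * mask).mean()

    loss = cls_loss + pl_loss + trade_off * ts_loss(logits_t_weak)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item(), mask.mean().item()  # loss and fraction of confident target samples


@torch.no_grad()
def evaluate(loader: DataLoader, model: nn.Module, device: torch.device) -> float:
    """Top-1 accuracy (%) of `model` over `loader`."""
    model.eval()
    correct = total = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        correct += (model(x).argmax(dim=1) == y).sum().item()
        total += y.numel()
    return 100.0 * correct / total


def shannon(z: torch.Tensor) -> torch.Tensor:
    """Stand-in regularizer for this toy run (mean Shannon entropy of the softmax)."""
    return -(F.softmax(z, dim=1) * F.log_softmax(z, dim=1)).sum(dim=1).mean()


# Toy usage with random tensors (checks shapes and plumbing only).
toy = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
opt = torch.optim.SGD(toy.parameters(), lr=0.01)
x = torch.randn(4, 3, 8, 8)
loss, mask_ratio = train_step(x, torch.randint(0, 10, (4,)), x, x + 0.1, toy, shannon, opt)
acc = evaluate(DataLoader(TensorDataset(x, torch.randint(0, 10, (4,))), batch_size=2), toy, torch.device("cpu"))
```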

## Cell 8: Experiment Setup and Execution

Set hyperparameters, run the training loop, and validate.

In [ ]:
# --- Cell 8: Run experiment ---
epochs = 20
early = 20

# Example loop (adapt as needed for your CST training logic)
for epoch in range(epochs):
    # train(train_source_iter, train_target_iter, classifier, ts_loss, optimizer, lr_scheduler, epoch, args)
    # acc = validate(val_loader, classifier, args)
    print(f"Epoch {epoch+1}/{epochs} ... (train/validate logic here)")
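One common way to structure this loop is a fixed number of iterations per epoch (drawing from cycling iterators, since the SVHN and MNIST loaders have different lengths), keeping the weights with the best validation accuracy. The sketch below shows only that skeleton on a toy problem. `run`, `iters_per_epoch`, and the `train_fn`/`eval_fn` callables are hypothetical placeholders for your `train`/`validate`.

```python
import copy
from typing import Callable

import torch
import torch.nn as nn


def run(model: nn.Module, epochs: int, iters_per_epoch: int,
        train_fn: Callable[[nn.Module], None], eval_fn: Callable[[nn.Module], float]):
    """Train for `epochs` epochs and return (best_acc, best_state_dict)."""
    best_acc, best_state = -1.0, None
    for epoch in range(epochs):
        for _ in range(iters_per_epoch):
            train_fn(model)
        acc = eval_fn(model)
        if acc > best_acc:
            # state_dict() shares storage with the live parameters, so deep-copy the snapshot.
            best_acc, best_state = acc, copy.deepcopy(model.state_dict())
        print(f"Epoch {epoch + 1}/{epochs}: acc {acc:.2f} (best {best_acc:.2f})")
    return best_acc, best_state


# Toy problem: 2-D points labeled by the sign of their first coordinate.
torch.manual_seed(0)
toy = nn.Linear(2, 2)
opt = torch.optim.SGD(toy.parameters(), lr=0.1)
x = torch.randn(64, 2)
y = (x[:, 0] > 0).long()


def toy_train(m: nn.Module) -> None:
    loss = nn.functional.cross_entropy(m(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()


@torch.no_grad()
def toy_eval(m: nn.Module) -> float:
    return 100.0 * (m(x).argmax(dim=1) == y).float().mean().item()


best_acc, best_state = run(toy, epochs=3, iters_per_epoch=20, train_fn=toy_train, eval_fn=toy_eval)
toy.load_state_dict(best_state)  # restore the best weights before final evaluation
```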

## Cell 9: Visualization and Analysis (Optional)

Add t-SNE, confusion matrix, or accuracy plots here.

In [ ]:
# --- Cell 9: Visualization ---
# Example: Plot accuracy or t-SNE if available
# import matplotlib.pyplot as plt
# ...
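For the confusion matrix, counting (true, predicted) pairs with `torch.bincount` avoids extra dependencies. A minimal sketch; `confusion_matrix` is a hypothetical helper (not scikit-learn's function of the same name), and you would feed it labels and `argmax` predictions collected over `val_loader`.

```python
import torch


def confusion_matrix(targets: torch.Tensor, preds: torch.Tensor, num_classes: int = 10) -> torch.Tensor:
    """cm[i, j] = number of samples of true class i predicted as class j."""
    idx = targets.long() * num_classes + preds.long()  # flatten each (i, j) pair into one index
    return torch.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)


y_true = torch.tensor([0, 1, 1, 2, 2, 2])
y_pred = torch.tensor([0, 1, 2, 2, 2, 1])
cm = confusion_matrix(y_true, y_pred, num_classes=3)
print(cm)
# tensor([[1, 0, 0],
#         [0, 1, 1],
#         [0, 1, 2]])
per_class_acc = cm.diag() / cm.sum(dim=1).clamp(min=1)  # per-class recall; clamp avoids 0/0
```

With matplotlib installed, `plt.imshow(cm)` gives a quick heatmap of the result.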
