<a href="https://colab.research.google.com/github/RajeswariKumaran/SSLMethodsAnalysis/blob/main/UDA_with_TSA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Revised UDA training code with Training Signal Annealing (TSA).

In [None]:
def apply_tsa(logits, targets, epoch, total_epochs, schedule='linear'):
    """Supervised cross-entropy with Training Signal Annealing (TSA).

    An example contributes to the loss only while the softmax probability of
    its true class is below a threshold. The threshold rises from
    ``1 / num_classes`` towards 1 as ``epoch / total_epochs`` goes from 0 to 1.
    Confidently-fitted labelled examples are therefore dropped early in
    training.

    Args:
        logits: Raw model outputs of shape (batch, num_classes).
        targets: Integer class index per example (length ``batch``).
        epoch: Current 0-based epoch.
        total_epochs: Planned number of epochs; must be positive.
        schedule: Shape of the threshold ramp: ``'linear'``, ``'exp'`` or
            ``'log'``.

    Returns:
        Scalar tensor: the cross-entropy summed over the retained examples,
        divided by their count (at least 1). It is 0 when every example is
        masked out.

    Raises:
        ValueError: If ``schedule`` is not a supported name or
            ``total_epochs`` is not positive.
    """
    if total_epochs <= 0:
        raise ValueError("total_epochs must be positive")

    # Clamp so calls past the planned horizon cannot push the threshold above 1.
    step_ratio = min(max(epoch / total_epochs, 0.0), 1.0)
    if schedule == 'linear':
        threshold = step_ratio
    elif schedule == 'exp':
        threshold = np.exp((step_ratio - 1) * 5)
    elif schedule == 'log':
        threshold = 1 - np.exp(-step_ratio * 5)
    else:
        raise ValueError("Invalid TSA schedule")

    num_classes = logits.shape[1]
    # Map the [0, 1] schedule onto [1 / num_classes, 1].
    tsa_thresh = threshold * (1 - 1 / num_classes) + 1 / num_classes

    with torch.no_grad():
        # The mask only selects examples; it is not differentiated through.
        probs = torch.softmax(logits, dim=1)
        correct_class_probs = probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        loss_mask = (correct_class_probs < tsa_thresh).float()

    losses = F.cross_entropy(logits, targets, reduction='none')
    # Average over the retained examples (as the reference UDA implementation
    # does), not over the whole batch. Dividing by the batch size would shrink
    # the supervised gradient in proportion to how many examples TSA masks.
    return (losses * loss_mask).sum() / loss_mask.sum().clamp(min=1.0)


In [None]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

# ---------- Cutout Augmentation ----------
class Cutout:
    """Fill one random square patch of an image with a constant value.

    The image must behave like a PIL image, i.e. provide ``size`` as
    ``(width, height)``, ``copy()`` and ``paste(color, box)``.

    Args:
        mask_size: Side length of the square patch in pixels. It is clipped
            to the image width/height if it is larger.
        p: Probability of applying the patch. Otherwise the input is returned
            unchanged.
        mask_color: Fill value passed to ``paste`` for the patch.
    """

    def __init__(self, mask_size, p=1.0, mask_color=0):
        self.mask_size = mask_size
        self.p = p
        self.mask_color = mask_color

    def __call__(self, img):
        """Return a patched copy of ``img``, or ``img`` itself (prob. ``1 - p``)."""
        if random.random() > self.p:
            return img
        w, h = img.size
        # Clip the patch so an oversized mask covers the whole image instead
        # of making random.randint fail on an empty range.
        mask_w = min(self.mask_size, w)
        mask_h = min(self.mask_size, h)
        # Pick the top-left corner so the patch lies fully inside the image.
        # The far edge is x1 + mask_w rather than centre +/- size // 2, so odd
        # sizes no longer lose a row and a column. For even sizes the corner
        # range and RNG calls are the same as the centre-based version.
        x1 = random.randint(0, w - mask_w)
        y1 = random.randint(0, h - mask_h)
        img = img.copy()
        img.paste(self.mask_color, (x1, y1, x1 + mask_w, y1 + mask_h))
        return img

# ---------- Simple CNN Model ----------
class SimpleCNN(nn.Module):
    """Two-stage convolutional classifier for 3-channel 32x32 images.

    Each stage is conv3x3 (padding 1) -> ReLU -> 2x2 max-pool. A 32x32 input
    therefore reaches the head as 128 feature maps of 8x8. The first Linear
    layer is sized for exactly that, so other resolutions fail there.

    All layers sit in a single ``nn.Sequential`` called ``net``. Its positional
    indices (``net.0``, ``net.3``, ``net.7``, ``net.9``) form the state_dict
    keys of saved checkpoints.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        feature_layers = [
            nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        ]
        head_layers = [
            nn.Flatten(),
            nn.Linear(128 * 8 * 8, 256), nn.ReLU(),
            nn.Linear(256, num_classes),
        ]
        self.net = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, x):
        """Return unnormalised class logits for the batch ``x``."""
        logits = self.net(x)
        return logits

# ---------- Strong Augmentation Pipeline ----------
class StrongTransform:
    """Strong augmentation applied to each unlabelled image for UDA.

    Pipeline: optional de-normalisation -> PIL -> random horizontal flip ->
    32x32 random crop with 4px padding -> RandAugment(2 ops, magnitude 9) ->
    Cutout(16) -> tensor -> normalisation with ``mean``/``std``.

    In this notebook the images reaching this transform have already been
    through ``weak_transform``, so they are *normalised* float tensors.
    ``ToPILImage`` scales float tensors by 255 and casts to uint8, which
    assumes values in [0, 1]. Mean-subtracted values (negative or above 1)
    overflow and yield a corrupted image. With ``input_normalized=True`` (the
    default), a floating-point tensor input is first mapped back to [0, 1]
    using the same ``mean``/``std``.

    Args:
        mean, std: Per-channel statistics, used both to undo the input
            normalisation and to normalise the output. Defaults are the
            CIFAR-10 values used throughout this notebook.
        input_normalized: Set to False for inputs that are not normalised
            (PIL images, uint8 tensors, or float tensors already in [0, 1]).
    """

    def __init__(self, mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261),
                 input_normalized=True):
        self.input_normalized = input_normalized
        # Shaped (C, 1, 1) so they broadcast over a (C, H, W) image tensor.
        self._mean = torch.tensor(mean).view(-1, 1, 1)
        self._std = torch.tensor(std).view(-1, 1, 1)
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.RandAugment(num_ops=2, magnitude=9),
            Cutout(mask_size=16, p=1.0),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    def __call__(self, img):
        if self.input_normalized and torch.is_tensor(img) and img.is_floating_point():
            # Undo Normalize so ToPILImage sees pixel values in [0, 1].
            img = (img.detach().cpu() * self._std + self._mean).clamp(0.0, 1.0)
        return self.transform(img)

# ---------- UDA Training Function ----------
def _uda_test_accuracy(model, test_loader, device):
    """Return top-1 accuracy of ``model`` over ``test_loader`` (switches to eval mode)."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x_test, y_test in test_loader:
            x_test, y_test = x_test.to(device), y_test.to(device)
            preds = model(x_test).argmax(dim=1)
            correct += (preds == y_test).sum().item()
            total += y_test.size(0)
    return correct / total


def train_uda(model, labelled_loader, unlabelled_loader, test_loader, device,
              epochs=30, lambda_u=1.0, threshold=0.85, lr=0.001):
    """Train ``model`` with UDA: a TSA supervised loss plus a masked consistency loss.

    Each step pairs one labelled batch with one unlabelled batch. The
    unlabelled loader is restarted whenever it is exhausted.

    - Supervised term: ``apply_tsa`` with a linear schedule.
    - Unsupervised term: KL(p_weak || p_strong). Here p_weak is the model's
      no-grad prediction on the unlabelled batch and p_strong its prediction
      on a ``StrongTransform``-augmented copy. Only examples whose top p_weak
      probability is at least ``threshold`` count; the batch mean then
      includes the masked-out zeros.

    After each epoch, the summed losses, labelled-batch accuracy and test
    accuracy are printed.

    Args:
        model: Network returning class logits.
        labelled_loader, unlabelled_loader, test_loader: Yield ``(x, y)``
            batches. Labels from ``unlabelled_loader`` are ignored.
        device: Device to train on.
        epochs: Number of passes over ``labelled_loader``.
        lambda_u: Weight of the unsupervised loss.
        threshold: Confidence needed for an unlabelled example to contribute.
        lr: Adam learning rate.

    Returns:
        The trained model, moved to ``device``.
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    strong_aug = StrongTransform()

    for epoch in range(epochs):
        model.train()
        total_lab_loss = 0
        total_unsup_loss = 0
        correct = 0
        total = 0

        unlab_iter = iter(unlabelled_loader)

        for lab_x, lab_y in labelled_loader:
            try:
                unlab_x, _ = next(unlab_iter)
            except StopIteration:
                unlab_iter = iter(unlabelled_loader)
                unlab_x, _ = next(unlab_iter)

            # Augment on the host copy before transferring, instead of moving
            # the batch to the device and pulling each image back with .cpu().
            unlab_x_aug = torch.stack([strong_aug(img) for img in unlab_x.cpu()]).to(device)

            lab_x, lab_y = lab_x.to(device), lab_y.to(device)
            unlab_x = unlab_x.to(device)

            # Supervised loss with training signal annealing.
            logits_lab = model(lab_x)
            loss_lab = apply_tsa(logits_lab, lab_y, epoch, epochs, schedule='linear')

            # Targets for consistency come from the un-augmented view, without gradient.
            with torch.no_grad():
                probs_weak = F.softmax(model(unlab_x), dim=1)
                mask = probs_weak.max(dim=1).values.ge(threshold).float()

            log_probs_strong = F.log_softmax(model(unlab_x_aug), dim=1)
            loss_unsup = F.kl_div(log_probs_strong, probs_weak, reduction='none').sum(dim=1)
            loss_unsup = (loss_unsup * mask).mean()

            loss = loss_lab + lambda_u * loss_unsup

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_lab_loss += loss_lab.item()
            total_unsup_loss += loss_unsup.item()
            correct += (logits_lab.argmax(dim=1) == lab_y).sum().item()
            total += lab_y.size(0)

        acc = correct / total
        print(f"[Epoch {epoch+1}] Sup Loss: {total_lab_loss:.4f} | Unsup Loss: {total_unsup_loss:.4f} | Sup Acc: {acc:.4f}")

        test_acc = _uda_test_accuracy(model, test_loader, device)
        print(f"→ Test Acc: {test_acc:.4f}")

    return model

# ---------- Main ----------
# Script entry point: build CIFAR-10 loaders with a fixed 4000-image labelled
# split, train with UDA + TSA, and save the weights. The names bound here are
# module globals; the final evaluation cell reuses `device` and `test_loader`.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # CIFAR-10 per-channel mean/std (the same statistics appear in StrongTransform).
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.247, 0.243, 0.261))

    # Light augmentation (flip + padded crop) for training images. The labelled
    # and unlabelled subsets below both wrap full_train, so both get this
    # transform; unlabelled batches therefore reach train_uda as normalised tensors.
    weak_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        normalize
    ])

    # Test images are only converted and normalised (no augmentation).
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

    full_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=weak_transform)
    test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)

    # Seeded shuffle of all training indices: the first num_labelled keep their
    # labels, the remainder are used as unlabelled data (their labels are ignored
    # by train_uda). np.random.seed reseeds NumPy's global RNG, so changing the
    # seed or the order of these calls changes which images are labelled.
    num_labelled = 4000
    indices = np.arange(len(full_train))
    np.random.seed(42)
    np.random.shuffle(indices)
    labelled_idx = indices[:num_labelled]
    unlabelled_idx = indices[num_labelled:]

    labelled_set = Subset(full_train, labelled_idx)
    unlabelled_set = Subset(full_train, unlabelled_idx)

    # drop_last=True keeps every training batch at exactly 64 examples.
    labelled_loader = DataLoader(labelled_set, batch_size=64, shuffle=True, num_workers=2, drop_last=True)
    unlabelled_loader = DataLoader(unlabelled_set, batch_size=64, shuffle=True, num_workers=2, drop_last=True)
    test_loader = DataLoader(test_set, batch_size=100, shuffle=False, num_workers=2)

    model = SimpleCNN(num_classes=10)
    trained_model = train_uda(model, labelled_loader, unlabelled_loader, test_loader, device,
                              epochs=30, lambda_u=1.0, threshold=0.85)

    # This file name is what the final evaluation cell loads.
    torch.save(trained_model.state_dict(), 'uda_cifar10_simplecnn_tsa.pth')


[Epoch 1] Sup Loss: 66.1398 | Unsup Loss: 0.0000 | Sup Acc: 0.1149
→ Test Acc: 0.1211
[Epoch 2] Sup Loss: 119.4171 | Unsup Loss: 0.0000 | Sup Acc: 0.1578
→ Test Acc: 0.2561
[Epoch 3] Sup Loss: 99.2260 | Unsup Loss: 0.0000 | Sup Acc: 0.2445
→ Test Acc: 0.2976
[Epoch 4] Sup Loss: 100.5792 | Unsup Loss: 0.0000 | Sup Acc: 0.2825
→ Test Acc: 0.3261
[Epoch 5] Sup Loss: 100.2116 | Unsup Loss: 0.0000 | Sup Acc: 0.3009
→ Test Acc: 0.3588
[Epoch 6] Sup Loss: 98.4903 | Unsup Loss: 0.0000 | Sup Acc: 0.3284
→ Test Acc: 0.3539
[Epoch 7] Sup Loss: 97.7126 | Unsup Loss: 0.0000 | Sup Acc: 0.3490
→ Test Acc: 0.4014
[Epoch 8] Sup Loss: 95.9460 | Unsup Loss: 0.0000 | Sup Acc: 0.3657
→ Test Acc: 0.3883
[Epoch 9] Sup Loss: 95.9225 | Unsup Loss: 0.0000 | Sup Acc: 0.3798
→ Test Acc: 0.3940
[Epoch 10] Sup Loss: 94.6920 | Unsup Loss: 0.0000 | Sup Acc: 0.3997
→ Test Acc: 0.4389
[Epoch 11] Sup Loss: 90.6933 | Unsup Loss: 0.0000 | Sup Acc: 0.4282
→ Test Acc: 0.4689
[Epoch 12] Sup Loss: 87.9350 | Unsup Loss: 0.0838

In [None]:
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from collections import Counter
import matplotlib.pyplot as plt

def evaluate(model, test_loader, device, class_names=None):
    """Print test accuracy and a per-class report, then plot a confusion matrix.

    Args:
        model: Network returning class logits. It is switched to eval mode.
        test_loader: Yields ``(x, y)`` batches whose integer labels index into
            ``class_names``.
        device: Device to run inference on.
        class_names: Display name per class index. Defaults to the ten
            CIFAR-10 classes.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    if class_names is None:
        class_names = [
            'airplane', 'automobile', 'bird', 'cat', 'deer',
            'dog', 'frog', 'horse', 'ship', 'truck',
        ]
    # Fixing the label set keeps the report rows and heatmap ticks aligned with
    # class_names even if some class never appears in y_true or y_pred.
    class_ids = list(range(len(class_names)))

    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for x, y in test_loader:
            logits = model(x.to(device))
            all_preds.append(logits.argmax(dim=1).cpu())
            all_labels.append(y.cpu())

    if not all_preds:
        raise ValueError("test_loader yielded no batches")

    y_pred = torch.cat(all_preds).numpy()
    y_true = torch.cat(all_labels).numpy()

    accuracy = np.mean(y_pred == y_true)
    print(f"\n✅ Test Accuracy: {accuracy * 100:.2f}%")

    print("\n📊 Classification Report:")
    print(classification_report(y_true, y_pred, labels=class_ids,
                                target_names=class_names))

    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title("Confusion Matrix")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.tight_layout()
    plt.show()

In [None]:
# Reload the checkpoint written by the training cell and evaluate it.
# `device` and `test_loader` come from that cell. map_location lets a
# checkpoint saved on a GPU runtime load on a CPU-only one (and vice versa).
model = SimpleCNN(num_classes=10)
model.load_state_dict(torch.load('uda_cifar10_simplecnn_tsa.pth', map_location=device))
model.to(device)
evaluate(model, test_loader, device)