In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
# ✅ Import RandomErasing from the standard torchvision library
from torchvision.transforms import RandomErasing
from torch.cuda.amp import autocast, GradScaler
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
import timm
from torch import amp
import numpy as np
import random

In [None]:
# =============================
# Seeding for Reproducibility
# =============================
def seed_everything(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's legacy global generator and PyTorch
    (``torch.manual_seed`` plus ``torch.cuda.manual_seed``), exports
    ``PYTHONHASHSEED``, and switches cuDNN to deterministic mode with
    benchmarking disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Host-side generators all take the same integer seed.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    torch.cuda.manual_seed(seed)

    # Trade cuDNN autotuning for run-to-run repeatability.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Single seed for the whole notebook: used for the global RNG setup here and
# for the DataLoader shuffle generator further down.
SEED = 42
seed_everything(seed=SEED)

In [None]:
# =============================
# Configuration (Updated for AdamW)
# =============================
# --- Paths (Kaggle layout) ---
data_dir = "/kaggle/input/sports-102/Sports102_V2"
output_dir = "/kaggle/working/swinv2_sydnet_adamw_outputs"

# --- Training hyperparameters ---
img_size = 224            # images are resized to img_size x img_size
batch_size, num_epochs = 32, 50
learning_rate = 2e-4      # initial learning rate handed to AdamW
log_interval = 10         # intended logging period, in mini-batches

# --- Runtime ---
# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Checkpoints and figures are written here; create it up front.
os.makedirs(output_dir, exist_ok=True)

In [None]:
# =============================
# Data Loading and Transforms
# =============================

# Training-time augmentation pipeline.
# Geometric augmentations (flip, rotation, rescaling) come first, followed by
# tensor conversion and normalisation. RandomErasing is applied after
# normalisation, and it is chained twice so that two separate regions can be
# erased (each pass fires independently with p=0.5).
train_transform = transforms.Compose(
    [
        transforms.Resize(size=(img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        # Random rotation and scaling (scale factor 1 +/- 0.25).
        transforms.RandomRotation(degrees=25),
        transforms.RandomAffine(degrees=0, scale=(0.75, 1.25)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        RandomErasing(p=0.5, value='random'),  # first erased region
        RandomErasing(p=0.5, value='random'),  # second erased region
    ]
)

# Deterministic preprocessing for evaluation: resize, convert to tensor and
# normalise with the same statistics as the training pipeline. No augmentation.
test_transform = transforms.Compose(
    [
        transforms.Resize(size=(img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Train/validation split of the "train" folder (80/20).
#
# The same image folder is opened twice: once with the augmenting
# train_transform and once with the deterministic test_transform. Splitting a
# single ImageFolder would make the validation subset inherit train_transform
# (random flips, rotations and erasing), which makes validation metrics noisy
# and pessimistic.
train_root = os.path.join(data_dir, "train")
full_train_dataset = datasets.ImageFolder(train_root, transform=train_transform)
full_val_dataset = datasets.ImageFolder(train_root, transform=test_transform)

train_size = int(0.8 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size

# An explicit generator makes the split depend only on SEED, not on how many
# random numbers were drawn from the global RNG before this cell ran.
train_dataset, _val_split = random_split(
    full_train_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)
# Same held-out indices, but read through the augmentation-free view.
val_dataset = torch.utils.data.Subset(full_val_dataset, _val_split.indices)

test_dataset = datasets.ImageFolder(os.path.join(data_dir, "test"), transform=test_transform)

def seed_worker(worker_id):
    """Seed NumPy and ``random`` for a DataLoader worker.

    The seed is ``torch.initial_seed()`` reduced modulo 2**32, so the two
    libraries follow whatever seed PyTorch is using in the calling process.

    Args:
        worker_id: Index passed in by the DataLoader; not used.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    random.seed(derived_seed)
    np.random.seed(derived_seed)

# Dedicated generator for the training loader so its shuffle order is driven
# by SEED.
g = torch.Generator()
g.manual_seed(SEED)

# Training loader: shuffled, with per-worker seeding of NumPy / random.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    worker_init_fn=seed_worker,
    generator=g,
)

# Evaluation loaders share the same settings and keep the dataset order.
_eval_loader_kwargs = dict(batch_size=batch_size, shuffle=False, num_workers=4)
val_loader = DataLoader(val_dataset, **_eval_loader_kwargs)
test_loader = DataLoader(test_dataset, **_eval_loader_kwargs)

# Class count comes from the sub-folder names found under the train directory.
num_classes = len(full_train_dataset.classes)
print(f"Number of classes: {num_classes}")

In [None]:
# =============================
# Model Setup (Updated for AdamW)
# =============================

# 1. Implement Gaussian Dropout as described in the SYD-Net paper
class GaussianDropout(nn.Module):
    """Multiplicative Gaussian noise regulariser.

    In training mode every element of the input is multiplied by ``1 + eps``
    where ``eps ~ N(0, sigma^2)`` and ``sigma = sqrt(p / (1 - p))``. In eval
    mode, or when ``p == 0``, the input is returned unchanged.

    Args:
        p: Noise level in the half-open interval ``[0, 1)``. ``p == 1`` is
            rejected because ``sigma`` would require a division by zero.

    Raises:
        ValueError: If ``p`` is outside ``[0, 1)`` (including NaN).
    """

    def __init__(self, p=0.1):
        super().__init__()
        # `not 0 <= p < 1` also rejects NaN, which a pair of `<`/`>` checks
        # would silently let through.
        if not 0 <= p < 1:
            raise ValueError("dropout probability has to be in [0, 1), but got {}".format(p))
        self.p = p
        self.sigma = (p / (1.0 - p)) ** 0.5 if p > 0 else 0.0

    def forward(self, x):
        """Apply multiplicative noise in training mode; identity otherwise."""
        if self.training and self.p > 0:
            noise = torch.randn_like(x) * self.sigma
            return x * (1 + noise)
        return x

# 2. Implement the EXACT Patch-based Attention (PbA) Module from the paper
class ExactSpatialAttention(nn.Module):
    """Spatial Attention (SA) path of the PbA module.

    Global-average and global-max pooled descriptors of the feature map are
    concatenated and passed through Linear -> Softmax -> GaussianDropout ->
    BatchNorm1d, giving one mask value per channel.

    Args:
        in_features: Channel count of the incoming feature map; also the
            width of the returned mask.
    """

    def __init__(self, in_features):
        super().__init__()
        self.in_features = in_features
        # Layer order (and therefore the Sequential indices used in
        # state_dict keys) is fixed: 0 Flatten, 1 Linear, 2 Softmax,
        # 3 GaussianDropout, 4 BatchNorm1d.
        stages = [
            nn.Flatten(),
            nn.Linear(2 * in_features, in_features),
            nn.Softmax(dim=1),
            GaussianDropout(0.2),
            nn.BatchNorm1d(in_features),
        ]
        self.mlp = nn.Sequential(*stages)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.gmp = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):
        """Return the (B, in_features) mask for a (B, C, H, W) feature map."""
        batch, channels, _, _ = x.shape
        pooled = [pool(x).reshape(batch, channels) for pool in (self.gap, self.gmp)]
        # Average descriptor first, max descriptor second -> (B, 2C).
        return self.mlp(torch.cat(pooled, dim=1))

class ExactChannelAttention(nn.Module):
    """Pairwise Channel Attention (CA) path of the PbA module.

    Each patch is reduced to a channel vector by global average pooling. The
    vectors are compared pairwise, each patch is re-expressed as an
    attention-weighted mix of all patches, and the refined patches are then
    combined into a single (B, C) descriptor.

    Args:
        in_features: Channel count of every patch.
    """

    def __init__(self, in_features):
        super().__init__()
        # Creation order is kept as-is so seeded initialisation is unchanged.
        self.W_psi = nn.Linear(in_features, in_features, bias=False)
        self.W_psi_prime = nn.Linear(in_features, in_features, bias=True)
        self.W_theta = nn.Linear(in_features, in_features, bias=True)
        self.W_delta = nn.Linear(in_features, 1, bias=True)
        self.W_phi = nn.Linear(in_features, in_features, bias=True)
        self.gap = nn.AdaptiveAvgPool2d(1)

    def forward(self, patches):
        """Fuse a list of P (B, C, h, w) patches into one (B, C) tensor."""
        reference = patches[0]
        batch, channels = reference.size(0), reference.size(1)

        # One pooled vector per patch -> (B, P, C).
        tokens = torch.cat(
            [self.gap(patch).reshape(batch, 1, channels) for patch in patches],
            dim=1,
        )

        # Pairwise interaction between patch i (axis 1) and patch j (axis 2)
        # -> (B, P, P, C).
        pairwise = torch.tanh(
            self.W_psi(tokens)[:, :, None] + self.W_psi_prime(tokens)[:, None]
        )
        gated = self.W_theta(pairwise).sigmoid()

        # Scalar weight per pair, normalised over j -> (B, P, P, 1).
        pair_weights = self.W_delta(gated).softmax(dim=2)
        refined = (pair_weights * tokens[:, None]).sum(dim=2)  # (B, P, C)

        # Per-channel importance of each refined patch, normalised over P.
        importance = self.W_phi(refined).softmax(dim=1)
        return (importance * refined).sum(dim=1)

class ExactPbA(nn.Module):
    """Patch-based Attention (PbA): combines the SA mask with the CA descriptor.

    The feature map is cut into a ``patch_grid_size x patch_grid_size`` grid.
    The channel-attention descriptor of those patches is modulated by the
    spatial-attention mask with a residual connection:
    ``f_ca * f_sa + f_ca``.

    Args:
        in_features: Channel count of the incoming feature map.
        patch_grid_size: Number of patches along each spatial axis.
    """

    def __init__(self, in_features, patch_grid_size=2):
        super().__init__()
        self.sa = ExactSpatialAttention(in_features)
        self.ca = ExactChannelAttention(in_features)
        self.patch_grid_size = patch_grid_size

    def forward(self, x):
        """Return the (B, C) PbA descriptor of a (B, C, H, W) feature map.

        Raises:
            ValueError: If H or W is smaller than ``patch_grid_size``, since
                at least one patch would then be empty.
        """
        grid = self.patch_grid_size
        height, width = x.shape[-2], x.shape[-1]
        if height < grid or width < grid:
            raise ValueError(
                "feature map of size {}x{} cannot be split into a {}x{} patch grid".format(
                    height, width, grid, grid
                )
            )

        f_sa = self.sa(x)

        # tensor_split spreads any remainder over the leading chunks, so every
        # row and column lands in some patch. Slicing with H // grid would
        # silently drop the trailing rows/columns whenever H or W is not a
        # multiple of the grid (e.g. a 7x7 map with a 2x2 grid). For divisible
        # sizes the patches are identical to equal-sized slices.
        patches = [
            cell
            for band in torch.tensor_split(x, grid, dim=-2)      # rows first
            for cell in torch.tensor_split(band, grid, dim=-1)   # then columns
        ]

        f_ca = self.ca(patches)
        return f_ca * f_sa + f_ca

# 3. Create the final model combining the SwinV2 Transformer with the exact SYD-Net PbA
class SwinWithSYDNetTechniques(nn.Module):
    """SwinV2 backbone followed by the PbA module and a classification head.

    The last backbone feature map is summarised twice: by the PbA module and
    by plain global average pooling. The two descriptors are summed and
    classified through GaussianDropout -> BatchNorm1d -> Linear.

    Args:
        num_classes: Number of output logits.
        pretrained: Forwarded to ``timm.create_model``; set to ``False`` to
            skip loading pretrained backbone weights (for example when the
            weights are about to be restored from a checkpoint).
    """

    def __init__(self, num_classes=102, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(
            "swinv2_cr_small_ns_224",
            # Honour the constructor argument; it used to be hard-coded to
            # True, so pretrained=False had no effect.
            pretrained=pretrained,
            features_only=True,
            out_indices=(-1,)
        )
        # Channel count of the last feature stage, as reported by timm.
        self.in_features = self.backbone.feature_info.channels()[-1]
        self.pba = ExactPbA(self.in_features)
        self.classification_head = nn.Sequential(
            GaussianDropout(0.2),
            nn.BatchNorm1d(self.in_features),
            nn.Linear(self.in_features, num_classes)
        )
        self.gap = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Return class logits of shape (B, num_classes) for an image batch."""
        # features_only models return a list of feature maps; take the last.
        feature_map = self.backbone(x)[-1]
        f_pba = self.pba(feature_map)
        f_gap = self.gap(feature_map).view(-1, self.in_features)
        # Residual fusion of the attention descriptor and the pooled features.
        final_features = f_pba + f_gap
        output = self.classification_head(final_features)
        return output

# Build the network with pretrained backbone weights and move it to `device`.
model = SwinWithSYDNetTechniques(num_classes=num_classes, pretrained=True).to(device)

# --- Loss, optimiser, AMP scaler and LR schedule ---
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(params=model.parameters(), lr=learning_rate)
# Gradient scaler for mixed-precision training on CUDA.
scaler = torch.amp.GradScaler(device='cuda')
# mode="max": the training loop steps this scheduler with validation
# accuracy; the LR is halved after 3 epochs without improvement.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",
    factor=0.5,
    patience=3,
)
# ----------------------------------------------------
print(f"Using SwinV2 backbone with AdamW optimizer. Initial learning rate: {optimizer.param_groups[0]['lr']}")

In [None]:
# =============================
# Resume Checkpoint
# =============================
# Defaults for a fresh run; overwritten below when a checkpoint is found.
start_epoch = 0
best_val_acc = 0.0
checkpoint_path = os.path.join(output_dir, "checkpoint.pth")
best_model_path = os.path.join(output_dir, "best_model.pth")

if os.path.exists(checkpoint_path):
    print("Resuming from checkpoint...")
    # map_location remaps stored tensors onto the current device, so a
    # checkpoint written in a GPU session can be resumed on a CPU-only one.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    # Restore the AMP loss scale when the checkpoint carries one. Checkpoints
    # without the key, or with an empty state, are skipped.
    if checkpoint.get("scaler_state_dict"):
        scaler.load_state_dict(checkpoint["scaler_state_dict"])
    best_val_acc = checkpoint["best_val_acc"]
    # "epoch" is the last completed epoch, so training continues at the next.
    start_epoch = checkpoint["epoch"] + 1
    print(f"Resumed from epoch {start_epoch} with best val acc {best_val_acc:.4f}")

In [None]:
# =============================
# Training Loop
# =============================
# Per-epoch loss history for this session.
train_losses = []
val_losses = []

for epoch in range(start_epoch, num_epochs):
    # ---------------- Training ----------------
    model.train()
    running_loss = 0.0
    correct, total = 0, 0

    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for i, (images, labels) in enumerate(progress):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        with amp.autocast(device_type="cuda"):
            outputs = model(images)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        # Gradients are still multiplied by the loss scale at this point.
        # Unscale them first so max_norm=1.0 applies to the true gradient
        # norm; scaler.step() will not unscale a second time.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        # Show the running mean loss on the progress bar every
        # `log_interval` mini-batches.
        if (i + 1) % log_interval == 0:
            progress.set_postfix(loss=f"{running_loss / (i + 1):.4f}")

    train_acc = 100 * correct / total
    train_loss = running_loss / len(train_loader)
    train_losses.append(train_loss)

    # ---------------- Validation ----------------
    model.eval()
    val_loss, val_correct, val_total = 0, 0, 0
    with torch.no_grad(), amp.autocast(device_type="cuda"):
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = outputs.max(1)
            val_total += labels.size(0)
            val_correct += predicted.eq(labels).sum().item()

    val_acc = 100 * val_correct / val_total
    val_loss /= len(val_loader)
    val_losses.append(val_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}] "
          f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
          f"Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%")

    # The scheduler was created with mode="max", so it is fed accuracy.
    scheduler.step(val_acc)

    # Update the best score BEFORE writing the checkpoint. Otherwise the
    # checkpoint stores the previous best, and a resumed run could overwrite
    # best_model.pth with a worse model.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        print("✅ Saved best model")

    # Checkpointing (every epoch)
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "scaler_state_dict": scaler.state_dict(),
        "best_val_acc": best_val_acc,
    }, checkpoint_path)

In [None]:
# =============================
# Evaluation (Train, Val, Test)
# =============================
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report

def evaluate(loader, name):
    """Run the global ``model`` over ``loader`` and print a classification report.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches.
        name: Split name used in the progress bar and the printed header.

    Returns:
        Tuple ``(report, all_labels, all_preds)``: the scikit-learn report as
        a dict (always containing an ``"accuracy"`` entry in [0, 1]), and the
        true / predicted class indices as flat lists.
    """
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad(), torch.amp.autocast("cuda"):
        for images, labels in tqdm(loader, desc=f"Evaluating {name}"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            all_preds.extend(predicted.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    class_names = full_train_dataset.classes
    # Passing `labels` pins the report to the full class list. Without it,
    # scikit-learn raises when a split (or the predictions) happens to miss a
    # class, because target_names no longer matches the labels it found.
    report_kwargs = dict(
        labels=list(range(len(class_names))),
        target_names=class_names,
        zero_division=0,
    )
    report = classification_report(all_labels, all_preds, output_dict=True, **report_kwargs)
    # Some scikit-learn versions emit "micro avg" instead of "accuracy" when
    # `labels` is given explicitly; callers read report["accuracy"], so make
    # sure it is present.
    if "accuracy" not in report:
        report["accuracy"] = float(np.mean(np.asarray(all_labels) == np.asarray(all_preds)))

    print(f"\n{name} Classification Report:")
    print(classification_report(all_labels, all_preds, **report_kwargs))
    return report, all_labels, all_preds

def plot_confusion_matrix(y_true, y_pred, class_names, normalize=False, figsize=(30, 30), fontsize=6, save_path=None):
    """Draw (and optionally save) a confusion-matrix heatmap.

    Args:
        y_true: True labels.
        y_pred: Predicted labels, same length as ``y_true``.
        class_names: Tick labels, one per class.
        normalize: If True, each row is divided by its total so cells show the
            fraction of that true class; rows with no samples stay at zero.
        figsize: Matplotlib figure size in inches.
        fontsize: Base font size for tick labels (axis labels and title are
            slightly larger).
        save_path: If given, the figure is written there at 300 dpi; missing
            parent directories are created.
    """
    n_classes = len(class_names)
    # When the labels are integer class indices inside [0, n_classes), pin the
    # matrix to the full class list so it always lines up with `class_names`,
    # even if some class never occurs. Otherwise let scikit-learn infer them.
    observed = np.unique(np.concatenate([np.ravel(y_true), np.ravel(y_pred)]))
    indices_match_names = bool(
        observed.size > 0
        and np.issubdtype(observed.dtype, np.integer)
        and observed.min() >= 0
        and observed.max() < n_classes
    )
    cm = confusion_matrix(
        y_true,
        y_pred,
        labels=np.arange(n_classes) if indices_match_names else None,
    )

    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Guarded division: a class with no true samples would otherwise
        # produce a 0/0 row of NaNs.
        cm = np.divide(
            cm,
            row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals != 0,
        )

    plt.figure(figsize=figsize)
    sns.heatmap(cm, annot=True, fmt=".2f" if normalize else "d", cmap="Blues",
                xticklabels=class_names, yticklabels=class_names, cbar=True)

    plt.ylabel('True label', fontsize=fontsize + 2)
    plt.xlabel('Predicted label', fontsize=fontsize + 2)
    plt.title('Confusion Matrix', fontsize=fontsize + 4)
    plt.xticks(rotation=90, fontsize=fontsize)
    plt.yticks(rotation=0, fontsize=fontsize)
    plt.tight_layout()

    if save_path:
        # dirname is "" for a bare filename, and os.makedirs("") raises.
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✅ Confusion matrix saved to: {save_path}")

    plt.show()

print("\nLoading best model for final evaluation...")
# map_location keeps this working when the weights were saved in a GPU session
# and the evaluation cell runs on a CPU-only one.
model.load_state_dict(torch.load(best_model_path, map_location=device))

# train_loader shuffles and applies train_transform (random flip / rotation /
# erasing), so scoring it would report accuracy on augmented images. Score the
# same training indices through the deterministic test_transform instead.
train_eval_loader = DataLoader(
    torch.utils.data.Subset(
        datasets.ImageFolder(os.path.join(data_dir, "train"), transform=test_transform),
        train_dataset.indices,
    ),
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

train_report, _, _ = evaluate(train_eval_loader, "Train")
val_report, _, _ = evaluate(val_loader, "Val")
test_report, y_true, y_pred = evaluate(test_loader, "Test")

# Print all accuracies together
print("\n📊 Final Accuracies:")
print(f"Train Accuracy: {train_report['accuracy']*100:.2f}%")
print(f"Validation Accuracy: {val_report['accuracy']*100:.2f}%")
print(f"✅ Test Accuracy: {test_report['accuracy']*100:.2f}%")

# Confusion Matrix for Test Set
save_path = os.path.join(output_dir, "confusion_matrix.png")

print("\n🔹 Generating and saving confusion matrix for test set...")
plot_confusion_matrix(
    y_true=y_true,
    y_pred=y_pred,
    class_names=full_train_dataset.classes,
    normalize=False,  # Change to True if normalized matrix desired
    figsize=(30, 30),
    fontsize=6,
    save_path=save_path
)

print("✅ Training complete. Confusion matrix saved.")