# ***Imports***

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import grad
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
import os

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)


cpu


# ***Path to the dataset***

In [2]:
# Directory holding the preprocessed train/test .npy arrays loaded in the next cell.
# Relative path: resolved against the notebook's current working directory.
PROCESSED_PATH = "Data/Processed_Fg_wo_gn"

# ***Loading datasets***

In [None]:
# Load the preprocessed splits saved as paired <split>_images.npy / <split>_labels.npy files.
train_images_path = os.path.join(PROCESSED_PATH, 'train_images.npy')
train_labels_path = os.path.join(PROCESSED_PATH, 'train_labels.npy')
test_images_path = os.path.join(PROCESSED_PATH, 'test_images.npy')
test_labels_path = os.path.join(PROCESSED_PATH, 'test_labels.npy')

x_train = np.load(train_images_path)
y_train = np.load(train_labels_path)

x_test = np.load(test_images_path)
y_test = np.load(test_labels_path)

# Fail fast on a mismatched or malformed export instead of deep inside training.
# Images and labels are paired by index, so their counts must agree.
assert len(x_train) == len(y_train), f"train: {len(x_train)} images vs {len(y_train)} labels"
assert len(x_test) == len(y_test), f"test: {len(x_test)} images vs {len(y_test)} labels"
# Labels are used as 1-D class-index targets for CrossEntropyLoss.
assert y_train.ndim == 1 and y_test.ndim == 1, f"expected 1-D labels, got {y_train.shape} / {y_test.shape}"
# ColoredMNISTDataset permutes NHWC -> NCHW, and CheaterCNN's first conv expects 3 channels.
assert x_train.ndim == 4 and x_train.shape[-1] == 3, f"expected NHWC RGB images, got {x_train.shape}"
assert x_train.shape[1:] == x_test.shape[1:], (
    f"train/test image shapes differ: {x_train.shape[1:]} vs {x_test.shape[1:]}"
)

print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)


(60000, 28, 28, 3) (60000,)
(10000, 28, 28, 3) (10000,)


In [5]:
class ColoredMNISTDataset(Dataset):
    """In-memory dataset over pre-loaded image and label arrays.

    Images are copied into a float32 tensor and reordered from NHWC to NCHW
    (channels-first, the layout ``nn.Conv2d`` consumes); labels are copied
    into an int64 tensor of class indices.
    """

    def __init__(self, images, labels):
        channels_last = torch.tensor(images, dtype=torch.float32)
        # NHWC → NCHW
        self.images = channels_last.permute(0, 3, 1, 2)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        return image, label

In [6]:
BATCH_SIZE = 128

train_ds = ColoredMNISTDataset(x_train, y_train)
test_ds = ColoredMNISTDataset(x_test, y_test)

# Reshuffle the training split every epoch; iterate the test split in a fixed order.
train_loader = DataLoader(train_ds, shuffle=True, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_ds, shuffle=False, batch_size=BATCH_SIZE)

# ***Model definition***

In [7]:
class CheaterCNN(nn.Module):
    """Small 3-layer CNN with large, strided early kernels.

    The first two convolutions use big kernels with stride 2, giving each unit
    a wide receptive field early on (intended to pick up global colour rather
    than fine shape). For a 28x28 input the spatial size goes 28 -> 14 -> 7 -> 7,
    and the 16x7x7 feature map is flattened into a single linear classifier.

    Args:
        in_channels: number of input image channels. Default 3 (RGB).
        num_classes: number of output logits. Default 10.
        input_size: height/width of the square input images; only used to size
            the final linear layer. Default 28.
    """

    def __init__(self, in_channels=3, num_classes=10, input_size=28):
        super().__init__()

        self.features = nn.Sequential(
            # 9x9 kernel, stride 2: large receptive field → global colour.
            nn.Conv2d(in_channels, 8, kernel_size=9, padding=4, stride=2),
            nn.ReLU(),
            # 5x5 kernel, stride 2: widens the receptive field further.
            nn.Conv2d(8, 16, kernel_size=5, padding=2, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
        )

        # Infer the flattened feature size from a dummy pass instead of hard-coding
        # 16*7*7, so other input sizes work. torch.zeros consumes no RNG, so weight
        # initialisation is identical to constructing the layers back to back.
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, input_size, input_size)
            n_flat = self.features(dummy).numel()

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(n_flat, num_classes),
        )

    def forward(self, x):
        """Return raw (unnormalised) class logits of shape (batch, num_classes)."""
        return self.classifier(self.features(x))

# ***Training***

In [None]:
# Training hyperparameters.
SEED = 0
LR = 0.001
lambda_gp = 0.5     # weight of the input-gradient penalty
warmup_epoch = 5    # epochs of plain cross-entropy before the penalty is switched on
epochs = 20

# Seed before building the model so weight init and DataLoader shuffling are reproducible.
torch.manual_seed(SEED)

model = CheaterCNN().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)


def input_gradient_penalty(logits, targets, inputs):
    """Mean squared gradient of the correct-class logits w.r.t. the inputs.

    Args:
        logits: model outputs of shape (batch, num_classes), computed from ``inputs``.
        targets: 1-D tensor of class indices, one per row of ``logits``.
        inputs: the batch fed to the model; must have ``requires_grad=True``.

    Returns:
        A scalar tensor. ``create_graph=True`` keeps it differentiable, so it can
        be back-propagated into the model parameters.
    """
    # squeeze(1) (not squeeze()) keeps a 1-D result even for a batch of size 1.
    correct_logits = logits.gather(1, targets.view(-1, 1)).squeeze(1)
    (input_grads,) = torch.autograd.grad(
        outputs=correct_logits.sum(),
        inputs=inputs,
        create_graph=True,
    )
    return input_grads.pow(2).mean()


for epoch in range(epochs):
    model.train()
    use_penalty = epoch >= warmup_epoch
    correct = 0
    total = 0
    running_loss = 0.0

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        # Input gradients are only needed once the penalty is active; skip the
        # extra autograd bookkeeping on the inputs during warm-up.
        if use_penalty:
            x.requires_grad_(True)

        optimizer.zero_grad()

        out = model(x)
        ce_loss = criterion(out, y)
        loss = ce_loss
        if use_penalty:
            loss = ce_loss + lambda_gp * input_gradient_penalty(out, y, x)

        loss.backward()
        optimizer.step()

        batch_size = y.size(0)
        running_loss += loss.item() * batch_size
        correct += (out.argmax(dim=1) == y).sum().item()
        total += batch_size

    acc = correct / total
    avg_loss = running_loss / total
    print(f"Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.4f} | Train Acc: {acc:.4f}")

Epoch 1/4 | Train Acc: 0.8542
Epoch 2/4 | Train Acc: 0.9474
Epoch 3/4 | Train Acc: 0.9528
Epoch 4/4 | Train Acc: 0.9574


# ***Evaluation***

In [23]:
model.eval()

all_preds = []
all_labels = []

# Inference only: disable autograd tracking to save memory and time.
with torch.no_grad():
    for images, targets in test_loader:
        logits = model(images.to(device))
        # Accumulate per-sample predicted and true class indices as numpy scalars.
        all_preds.extend(logits.argmax(dim=1).cpu().numpy())
        all_labels.extend(targets.cpu().numpy())

# Fraction of test samples whose predicted class equals the label.
test_acc = np.mean(np.asarray(all_preds) == np.asarray(all_labels))
print("Hard Test Accuracy:", test_acc)

Hard Test Accuracy: 0.2733


# ***Saving the model***

In [24]:
# Persist only the learned weights (the state_dict), not the pickled module.
Models = "Robust_Models"
checkpoint_path = os.path.join(Models, "cnn3_24_v2_0.2")
os.makedirs(Models, exist_ok=True)
# NOTE(review): the "0.2" suffix does not match lambda_gp = 0.5 in the training
# cell — confirm what it encodes before relying on the name to identify the run.
torch.save(model.state_dict(), checkpoint_path)