# ***Imports***

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import grad
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
import os

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)


cuda


# ***Paths***

In [4]:
# Directory (relative to the working directory) that the train/test .npy
# arrays are loaded from.
PROCESSED_PATH = "Data/Processed_Fg_Counterfactuals"

# ***Loading datasets***

In [5]:
# On-disk locations of the pre-processed arrays (all under PROCESSED_PATH).
train_images_path = os.path.join(PROCESSED_PATH, "train_images.npy")
train_images_cf_path = os.path.join(PROCESSED_PATH, "train_images_cf.npy")
train_labels_path = os.path.join(PROCESSED_PATH, "train_labels.npy")
test_images_path = os.path.join(PROCESSED_PATH, "test_images.npy")
test_labels_path = os.path.join(PROCESSED_PATH, "test_labels.npy")

# Training split: images, their `_cf` counterpart array, and labels.
x_train, x_train_cf, y_train = (
    np.load(path)
    for path in (train_images_path, train_images_cf_path, train_labels_path)
)

# Test split: images and labels only.
x_test, y_test = (np.load(path) for path in (test_images_path, test_labels_path))

In [16]:
class ColoredMNISTCounterfactualDataset(Dataset):
    """Dataset yielding ``(image, counterfactual_image, label)`` triples.

    Parameters
    ----------
    images : array-like
        Original images, expected channel-last ``(N, H, W, C)``; converted to
        a float32 tensor and stored channel-first ``(N, C, H, W)``.
    images_cf : array-like
        Counterfactual images paired index-by-index with ``images``; must have
        exactly the same shape.
    labels : array-like
        One integer label per image, converted to a ``torch.long`` tensor.

    Raises
    ------
    ValueError
        If ``images`` and ``images_cf`` differ in shape, or if the number of
        labels does not match the number of images. Without this check a
        mismatch would either silently mis-pair/truncate samples or only fail
        in the middle of an epoch.
    """

    def __init__(self, images, images_cf, labels):
        images_t = torch.tensor(images, dtype=torch.float32)
        images_cf_t = torch.tensor(images_cf, dtype=torch.float32)
        labels_t = torch.tensor(labels, dtype=torch.long)

        # Every original image needs a counterfactual of identical shape.
        if images_t.shape != images_cf_t.shape:
            raise ValueError(
                "images and images_cf must have the same shape, got "
                f"{tuple(images_t.shape)} and {tuple(images_cf_t.shape)}"
            )
        # Comparing shape[:1] also rejects a 0-dim labels tensor.
        if images_t.shape[:1] != labels_t.shape[:1]:
            raise ValueError(
                "labels must contain one entry per image, got "
                f"images {tuple(images_t.shape)} and labels {tuple(labels_t.shape)}"
            )

        # NHWC -> NCHW
        self.images = images_t.permute(0, 3, 1, 2)
        self.images_cf = images_cf_t.permute(0, 3, 1, 2)
        self.labels = labels_t

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.images[idx], self.images_cf[idx], self.labels[idx]

In [7]:
class ColoredMNISTDataset(Dataset):
    """Dataset yielding ``(image, label)`` pairs.

    ``images`` is converted to a float32 tensor and its axes are reordered
    with ``permute(0, 3, 1, 2)`` (channel-last to channel-first); ``labels``
    is converted to a ``torch.long`` tensor. The dataset length is the number
    of labels.
    """

    def __init__(self, images, labels):
        channel_last = torch.tensor(images, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)

        # NHWC -> NCHW
        self.images = channel_last.permute(0, 3, 1, 2)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        return image, label

In [8]:
# Training pairs every image with its counterfactual; the test set has none.
train_dataset = ColoredMNISTCounterfactualDataset(
    images=x_train,
    images_cf=x_train_cf,
    labels=y_train,
)
test_dataset = ColoredMNISTDataset(images=x_test, labels=y_test)

# Shuffle only the training batches; keep the test order fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)

# ***Model definition***

In [9]:
class CheaterCNN(nn.Module):
    """Three-convolution CNN with a single linear head producing 10 logits.

    There are no pooling layers: spatial downsampling comes from the two
    stride-2 convolutions. The linear head expects a 16 x 7 x 7 feature map,
    which is what a 3-channel 28 x 28 input produces (28 -> 14 -> 7).
    """

    def __init__(self):
        super().__init__()

        conv_stack = [
            # Large 9x9 receptive field, stride 2.
            nn.Conv2d(3, 8, kernel_size=9, stride=2, padding=4),
            nn.ReLU(),
            # 5x5 kernel, stride 2 again.
            nn.Conv2d(8, 16, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            # 3x3 kernel, resolution preserved.
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        self.features = nn.Sequential(*conv_stack)

        # Flatten the spatial layout and map straight to the 10 class logits.
        head = [nn.Flatten(), nn.Linear(16 * 7 * 7, 10)]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        feature_maps = self.features(x)
        logits = self.classifier(feature_maps)
        return logits

# ***Training***

In [10]:
model = CheaterCNN().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

lambda_cf = 0.5      # weight of the counterfactual consistency term
warmup_epochs = 2    # first epochs are trained with cross-entropy only
epochs = 10

for epoch in range(epochs):
    model.train()
    correct, total = 0, 0

    # The consistency term only enters the loss after the warm-up, so the
    # counterfactual forward pass is skipped entirely before that instead of
    # being computed and thrown away.
    use_cf_loss = epoch >= warmup_epochs

    for x, x_cf, y in train_loader:
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()

        # Forward on biased image
        out = model(x)
        ce_loss = criterion(out, y)

        if use_cf_loss:
            # Forward on counterfactual image
            x_cf = x_cf.to(device)
            out_cf = model(x_cf)

            # Consistency loss (logits-level): mean squared difference
            # between the logits of the original and counterfactual image.
            cf_loss = torch.mean((out - out_cf) ** 2)
            loss = ce_loss + lambda_cf * cf_loss
        else:
            loss = ce_loss

        loss.backward()
        optimizer.step()

        # Accuracy computed ONLY on original (biased) images
        preds = out.argmax(dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)

    acc = correct / total
    print(f"Epoch {epoch+1}/{epochs} | Train Acc: {acc:.4f}")

Epoch 1/10 | Train Acc: 0.9011
Epoch 2/10 | Train Acc: 0.9493
Epoch 3/10 | Train Acc: 0.9424
Epoch 4/10 | Train Acc: 0.9547
Epoch 5/10 | Train Acc: 0.9578
Epoch 6/10 | Train Acc: 0.9601
Epoch 7/10 | Train Acc: 0.9608
Epoch 8/10 | Train Acc: 0.9618
Epoch 9/10 | Train Acc: 0.9625
Epoch 10/10 | Train Acc: 0.9635


# ***Evaluation***

In [11]:
# Put the model in evaluation mode before scoring the test split.
model.eval()

all_preds, all_labels = [], []

# No gradients are needed for inference.
with torch.no_grad():
    for x, y in test_loader:
        x = x.to(device)
        y = y.to(device)
        out = model(x)
        preds = torch.argmax(out, dim=1)

        # Accumulate per-sample predictions and ground truth on the CPU.
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(y.cpu().numpy())

# Fraction of test samples whose predicted class equals the label.
predicted = np.asarray(all_preds)
expected = np.asarray(all_labels)
test_acc = (predicted == expected).mean()
print("Hard Test Accuracy:", test_acc)

Hard Test Accuracy: 0.85


# ***Saving Model***

In [17]:
# Output directory for checkpoints; created on first use.
Models = "Robust_Models_v2"
os.makedirs(Models, exist_ok=True)

# Persist only the weights (state_dict), not the full module object.
checkpoint_path = os.path.join(Models, "cnn3_24_v2_5_20_85_96_")
torch.save(model.state_dict(), checkpoint_path)

# ***Recolor Batch***

In [13]:
def get_foreground_mask(x, thresh=0.05):
    """
    Binary foreground mask from the per-pixel channel mean.

    x: (B, 3, H, W)
    thresh: a pixel counts as foreground when its mean over dim 1 is
            strictly greater than this value
    returns mask: (B, 1, H, W), float tensor of 0.0 / 1.0
    """
    # Channel-averaged intensity, keeping a singleton channel axis.
    intensity = torch.mean(x, dim=1, keepdim=True)
    return torch.gt(intensity, thresh).float()


In [14]:
def recolor_batch(x):
    """
    Paint the foreground of every image in the batch with a random color.

    Pixels selected by get_foreground_mask are replaced by one uniformly
    random RGB color per sample; all other pixels are returned unchanged.
    """
    fg = get_foreground_mask(x)  # (B,1,H,W)

    # One random RGB triple per sample, broadcast over H and W.
    palette = torch.rand(x.size(0), 3, 1, 1, device=x.device)

    kept_background = x * (1 - fg)
    painted_foreground = fg * palette
    return kept_background + painted_foreground

In [15]:
def show_tensor_image(x):
    """Display a single image tensor with matplotlib.

    Accepted shapes:
      * (3, H, W)     RGB
      * (1, H, W)     grayscale (shown with the 'gray' colormap)
      * (1, C, H, W)  batch of one, with C in {1, 3}; handled exactly like
                      the corresponding unbatched image

    Raises:
        ValueError: for any other shape. This includes 3-D tensors whose
        channel count is neither 1 nor 3, which previously fell through and
        crashed with an unbound ``img`` variable.
    """
    # Drop a leading batch axis of size one so both layouts share one path.
    if x.dim() == 4 and x.size(0) == 1:
        x = x.squeeze(0)

    if x.dim() != 3 or x.size(0) not in (1, 3):
        raise ValueError("Unsupported tensor shape")

    # detach() so tensors that require grad can be converted to NumPy.
    x = x.detach().cpu()
    if x.size(0) == 3:  # RGB
        img = x.permute(1, 2, 0).numpy()  # (H, W, 3)
    else:  # Grayscale
        img = x.squeeze(0).numpy()  # (H, W)

    plt.imshow(img, cmap='gray' if img.ndim == 2 else None)
    plt.axis('off')
    plt.show()