In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

In [3]:
class ColoredMNIST(Dataset):
    """MNIST digits drawn in white on a noisy, class-coloured RGB background.

    Every digit class owns one background colour (``PALETTE``).

    * ``train=True``: with probability ``correlation`` the background uses the
      colour of the true label; otherwise a colour is drawn uniformly from the
      whole palette.  The uniform draw can also land on the true colour, so the
      expected share of "rule-following" samples is
      ``correlation + (1 - correlation) / len(PALETTE)``.
    * ``train=False``: the background always uses the colour that belongs to
      ``(label + 1) % 10``, i.e. a deliberately wrong colour.

    Args:
        root: Directory handed to ``torchvision.datasets.MNIST`` (downloaded
            there if missing).
        train: Selects the MNIST split and the colouring rule described above.
        correlation: Probability in ``[0, 1]`` that a training sample gets the
            colour of its own label.
        seed: Non-negative int that fixes colour choice and background noise
            per sample, so index ``i`` yields the same image on every access
            (across epochs and between training and evaluation).  Pass
            ``None`` to draw fresh randomness on every access instead.

    Raises:
        ValueError: If ``correlation`` is outside ``[0, 1]``.
    """

    # One RGB colour per digit class, indexed by label.
    PALETTE = (
        (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255),
        (0, 255, 255), (128, 128, 128), (255, 165, 0), (128, 0, 128), (165, 42, 42),
    )

    def __init__(self, root, train=True, correlation=0.95, seed=0):
        # Validate before touching the filesystem/network.
        if not 0.0 <= correlation <= 1.0:
            raise ValueError(f"correlation must be in [0, 1], got {correlation!r}")

        self.mnist = torchvision.datasets.MNIST(root, train=train, download=True)
        self.correlation = correlation
        self.train = train
        self.seed = seed

        # Mapping Digits to Specific Colors (RGB)
        self.colors = [list(rgb) for rgb in self.PALETTE]

    def __len__(self):
        return len(self.mnist)

    def _rng_for(self, idx):
        """Return the random generator used to colour sample ``idx``."""
        if self.seed is None:
            return np.random.default_rng()
        # Negative indices are mapped onto their positive twin so that
        # ds[-1] and ds[len(ds) - 1] produce the same image.
        sample_id = int(idx) % len(self.mnist)
        return np.random.default_rng((int(self.seed), int(bool(self.train)), sample_id))

    def __getitem__(self, idx):
        # Fetch first: an out-of-range index raises here, before any RNG work.
        img, label = self.mnist[idx]
        img = np.array(img)
        rng = self._rng_for(idx)

        # Training: `correlation` of the samples follow the color rule.
        # Test: 0% follow the rule.
        if self.train:
            use_bias = rng.random() < self.correlation
            if use_bias:
                color = self.colors[label]
            else:
                color = self.colors[rng.integers(0, len(self.colors))]
        else:
            # For the hard test, we deliberately pick a WRONG color
            wrong_label = (label + 1) % 10
            color = self.colors[wrong_label]

        # Generate textured background: Gaussian noise (std 15) around the colour,
        # sized after the digit image instead of a hard-coded 28x28.
        background = rng.normal(0, 15, img.shape[:2] + (3,)) + np.array(color)
        background = np.clip(background, 0, 255).astype(np.uint8)

        # Apply digit as mask (digit is white, background is colored texture)
        mask = img[:, :, np.newaxis] / 255.0
        final_img = (mask * 255 + (1 - mask) * background).astype(np.uint8)

        # HWC uint8 -> CHW float in [0, 1]; this is what transforms.ToTensor()
        # does for a uint8 array, without building a transform object per sample.
        tensor = torch.from_numpy(final_img).permute(2, 0, 1).contiguous()
        return tensor.to(torch.get_default_dtype()).div(255), label

In [5]:
# Build both splits from the MNIST files kept under Data/Raw.
# The training split is coloured with correlation=0.95; the test split is
# created with train=False.
train_set = ColoredMNIST(
    root='Data/Raw',
    train=True,
    correlation=0.95,
)
test_set = ColoredMNIST(
    root='Data/Raw',
    train=False,
)

# Mini-batches of 64; only the training loader shuffles.
train_loader = DataLoader(dataset=train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_set, batch_size=64, shuffle=False)

In [6]:
class LazyCNN(nn.Module):
    """Small two-stage CNN mapping a 3-channel image to 10 class logits.

    ``features`` holds two conv -> ReLU -> 2x2 max-pool stages (16 then 32
    channels); both convolutions are padded so only the pooling shrinks the
    spatial size.  ``classifier`` flattens the result and applies a 128-unit
    hidden layer followed by the 10-way output layer.  The first linear layer
    is sized for 32 x 7 x 7 features, which is what a 28 x 28 input yields.
    """

    def __init__(self):
        super().__init__()

        conv_layers = [
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        head_layers = [
            nn.Flatten(),
            nn.Linear(in_features=32 * 7 * 7, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=10),
        ]

        self.features = nn.Sequential(*conv_layers)
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for batch ``x``."""
        feature_maps = self.features(x)
        logits = self.classifier(feature_maps)
        return logits

In [7]:
# Seed the global RNGs before anything random happens, so that weight
# initialisation and the shuffled batch order repeat from run to run.
# numpy is seeded as well for any code that samples from np.random.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LazyCNN().to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Training Phase
print("Training the model on biased data...")
model.train()
for epoch in range(2): # 2 epochs is plenty to learn a simple color shortcut
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()  # gradients accumulate by default; clear them per batch
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1} Complete.")

Training the model on biased data...
Epoch 1 Complete.
Epoch 2 Complete.


In [None]:
def diagnose_model(model, loader, title):
    """Print the accuracy of ``model`` on ``loader`` and plot its confusion matrix.

    Args:
        model: Classifier whose output has one column of scores per class.
        loader: Iterable of ``(images, labels)`` tensor batches.
        title: Label used in the printed accuracy line and the figure title.

    Raises:
        ValueError: If ``loader`` yields no samples.

    Side effects:
        Prints one line and shows a matplotlib figure.  The model is switched
        to eval mode while predicting and put back into whatever mode it was
        in before the call.
    """
    # Run on the device the model lives on rather than relying on a global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    model.eval()
    pred_batches, truth_batches = [], []
    num_classes = None
    try:
        with torch.no_grad():
            for images, labels in loader:
                output = model(images.to(run_device))
                num_classes = output.shape[1]
                pred_batches.append(torch.argmax(output, dim=1).cpu())
                truth_batches.append(labels.cpu())
    finally:
        # Do not leave the caller's model silently stuck in eval mode.
        model.train(was_training)

    preds = torch.cat(pred_batches).numpy() if pred_batches else np.empty(0, dtype=np.int64)
    truths = torch.cat(truth_batches).numpy() if truth_batches else np.empty(0, dtype=np.int64)
    if preds.size == 0:
        raise ValueError(f"{title}: loader yielded no samples to evaluate")

    acc = 100 * np.mean(preds == truths)
    print(f"{title} Accuracy: {acc:.2f}%")

    # Pin the label set so row/column i always means class i, even when a
    # class never occurs in this loader's labels or predictions.
    cm = confusion_matrix(truths, preds, labels=np.arange(num_classes))
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='magma')
    plt.title(f"Confusion Matrix: {title}")
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.show()

# Evaluate the trained model on the training loader first, then on the test loader.
for loader, title in (
    (train_loader, "Easy Train Set"),
    (test_loader, "Hard Test Set"),
):
    diagnose_model(model, loader, title)