In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from sklearn.metrics import confusion_matrix
import seaborn as sns
from Utils import SimpleCNN
from Utils import get_device, train, test, get_predictions

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Preprocessing pipeline: ToTensor followed by standardization with the
# MNIST mean/std values (0.1307 / 0.3081).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Fetch the MNIST train and test splits into ./data (downloaded on demand);
# both splits share the same preprocessing pipeline.
train_dataset, test_dataset = (
    torchvision.datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [3]:
def add_frame(image_tensor, label, frame_size=1, mean=0.1307, std=0.3081):
    """Draw a full-intensity border on the inside edge of a normalized image.

    The image is denormalized with ``mean``/``std``, the outermost
    ``frame_size`` rows and columns are set to 1, and the result is
    normalized again with the same statistics. The input tensor is not
    modified.

    Args:
        image_tensor: CPU tensor that squeezes to a 2-D (height, width)
            array, e.g. a (1, H, W) single-channel image.
        label: Class label; returned unchanged.
        frame_size: Border thickness in pixels. ``0`` leaves the image
            unframed.
        mean: Mean used by the dataset normalization.
        std: Standard deviation used by the dataset normalization.

    Returns:
        Tuple ``(framed_image, label)`` where ``framed_image`` is a new
        tensor with a leading dimension of size 1.

    Raises:
        ValueError: If ``frame_size`` is negative.
    """
    if frame_size < 0:
        raise ValueError(f"frame_size must be non-negative, got {frame_size}")

    # The arithmetic below allocates a fresh array, so the caller's tensor
    # (which shares memory with .numpy()) is never written to.
    pixels = image_tensor.squeeze().numpy() * std + mean

    # Guard against frame_size == 0: the slice [-0:] is the same as [0:],
    # which would overwrite the entire image instead of nothing.
    if frame_size > 0:
        d = frame_size
        pixels[:d, :] = 1   # Top border
        pixels[-d:, :] = 1  # Bottom border
        pixels[:, :d] = 1   # Left border
        pixels[:, -d:] = 1  # Right border

    # Return to the normalized value range expected by the model
    pixels = (pixels - mean) / std
    return torch.tensor(pixels).unsqueeze(0), label

# Materialize framed copies of both splits as in-memory lists of
# (framed_image, label) pairs.
framed_train_dataset = [add_frame(*sample) for sample in train_dataset]
framed_test_dataset = [add_frame(*sample) for sample in test_dataset]

# Function to denormalize for visualization
def denormalize(tensor, mean=0.1307, std=0.3081):
    """Undo ``(x - mean) / std`` normalization so pixel values can be displayed.

    Args:
        tensor: Normalized values (tensor, array or scalar).
        mean: Mean that was subtracted during normalization.
        std: Standard deviation that was divided by during normalization.

    Returns:
        ``tensor * std + mean``, with the same type/shape semantics as the
        input's arithmetic operators.
    """
    return tensor * std + mean


In [None]:
# Locate the first training sample labelled 1, then show it next to its
# framed counterpart (same index in the framed list).
i = 0
img, label = train_dataset[i]
while label != 1:
    i += 1
    img, label = train_dataset[i]
framed_img, _ = framed_train_dataset[i]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

# Pixel values are negated before plotting with the gray colormap
original_pixels = denormalize(img.squeeze())
ax1.imshow(-original_pixels, cmap='gray')
ax1.set_title(f"Original: {label}")

framed_pixels = denormalize(framed_img.squeeze())
ax2.imshow(-framed_pixels, cmap='gray')
ax2.set_title(f"Framed: {label}")

plt.show()


In [5]:
class MixedMNIST(Dataset):
    """Dataset in which exactly one class is drawn from the framed images.

    Samples whose label equals ``framed_label`` come from ``framed_dataset``;
    samples of every other label come from ``original_dataset``. Both inputs
    are fully iterated at construction time and the selected
    ``(image, label)`` pairs are kept in memory, original samples first.

    Args:
        original_dataset: Iterable of ``(image, label)`` pairs without frames.
        framed_dataset: Iterable of ``(image, label)`` pairs with frames.
        framed_label: Label whose samples are taken from ``framed_dataset``.
            Defaults to 9.
    """

    def __init__(self, original_dataset, framed_dataset, framed_label=9):
        self.framed_label = framed_label
        # Every class except the framed one comes from the unmodified data
        self.original_data = [
            (img, label) for img, label in original_dataset if label != framed_label
        ]
        # The framed class is taken exclusively from the framed data
        self.framed_data = [
            (img, label) for img, label in framed_dataset if label == framed_label
        ]
        self.data = self.original_data + self.framed_data

    def __len__(self):
        """Return the total number of selected samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        return self.data[idx]


In [6]:
# Build the mixed train/test datasets from each split's plain and framed data
mixed_train_dataset, mixed_test_dataset = (
    MixedMNIST(plain_split, framed_split)
    for plain_split, framed_split in (
        (train_dataset, framed_train_dataset),
        (test_dataset, framed_test_dataset),
    )
)

# Training loader: shuffled mini-batches of 64
mixed_train_loader = DataLoader(mixed_train_dataset, batch_size=64, shuffle=True)

# Evaluation loaders: fixed order, batches of 1000
mixed_test_loader, original_test_loader, framed_test_loader = (
    DataLoader(eval_split, batch_size=1000, shuffle=False)
    for eval_split in (mixed_test_dataset, test_dataset, framed_test_dataset)
)

print(f"Number of training samples: {len(mixed_train_dataset)}")
print(f"Number of original test samples: {len(test_dataset)}")
print(f"Number of framed test samples: {len(framed_test_dataset)}")

Number of training samples: 60000
Number of original test samples: 10000
Number of framed test samples: 10000


In [None]:

def visualize_mixed_dataset(dataset):
    """Plot one example image per label in a 2x5 grid, ordered by label.

    The dataset is scanned until ten distinct labels have been seen (or the
    data is exhausted); the first image encountered for each label is shown.

    Args:
        dataset: Iterable of ``(image, label)`` pairs. Images must support
            ``.squeeze()``/``.cpu()``/``.numpy()``; labels must be
            convertible with ``int()``.
    """
    # First image seen for each label. Keys are plain ints so that tensor
    # labels with equal values collapse onto the same entry.
    samples = {}
    for img, label in dataset:
        key = int(label)
        if key not in samples:
            samples[key] = img
        if len(samples) == 10:
            break

    fig, axs = plt.subplots(2, 5, figsize=(15, 6))
    fig.suptitle('Samples from Mixed MNIST Dataset', fontsize=16)

    # Hide every panel's axes up front so panels left empty (fewer than ten
    # labels present) do not render as blank framed plots.
    for ax in axs.flat:
        ax.axis('off')

    # Sort by label so the grid reads 0..9 regardless of the order in which
    # the classes were encountered in the dataset.
    ordered = sorted(samples.items(), key=lambda item: item[0])
    for ax, (label, img) in zip(axs.flat, ordered):
        ax.imshow(-denormalize(img.squeeze()).cpu().numpy(), cmap='gray')
        ax.set_title(f'Label: {label}')

    plt.tight_layout()
    plt.show()


visualize_mixed_dataset(mixed_train_dataset)

In [7]:
# Seed PyTorch's global RNG before any weights are created so that model
# initialization (and the subsequent shuffling of the training loader, which
# draws from the same generator) is repeatable across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Initialize the model, loss function, and optimizer
device = get_device()
print(f"Using device: {device}")
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())  # Adam with its default hyperparameters

Using device: mps


In [8]:
# Train on the mixed data; after every epoch evaluate on the mixed, original
# and framed test sets, in that order.
num_epochs = 10
evaluation_loaders = (mixed_test_loader, original_test_loader, framed_test_loader)
for epoch in range(1, num_epochs + 1):
    train(model, device, mixed_train_loader, optimizer, epoch, criterion)
    for eval_loader in evaluation_loaders:
        test(model, device, eval_loader, criterion)

print("Training complete.")


Train Epoch: 1 	Loss: 0.052769

Test set: Average loss: 0.0000, Accuracy: 9863/10000 (98.63%)


Test set: Average loss: 0.0009, Accuracy: 8854/10000 (88.54%)


Test set: Average loss: 0.0034, Accuracy: 3111/10000 (31.11%)

Train Epoch: 2 	Loss: 0.024671

Test set: Average loss: 0.0000, Accuracy: 9908/10000 (99.08%)


Test set: Average loss: 0.0009, Accuracy: 8899/10000 (88.99%)


Test set: Average loss: 0.0039, Accuracy: 2884/10000 (28.84%)

Train Epoch: 3 	Loss: 0.058330

Test set: Average loss: 0.0000, Accuracy: 9900/10000 (99.00%)


Test set: Average loss: 0.0009, Accuracy: 8891/10000 (88.91%)


Test set: Average loss: 0.0036, Accuracy: 3286/10000 (32.86%)

Train Epoch: 4 	Loss: 0.021407

Test set: Average loss: 0.0000, Accuracy: 9905/10000 (99.05%)


Test set: Average loss: 0.0008, Accuracy: 8896/10000 (88.96%)


Test set: Average loss: 0.0045, Accuracy: 2799/10000 (27.99%)

Train Epoch: 5 	Loss: 0.023639

Test set: Average loss: 0.0000, Accuracy: 9930/10000 (99.30%)


Test set: Av

In [9]:
# Persist only the learned parameters (state dict), not the whole module
trained_weights = model.state_dict()
torch.save(trained_weights, 'mixed_mnist_cnn.pth')
print("Model saved as 'mixed_mnist_cnn.pth'")

Model saved as 'mixed_mnist_cnn.pth'


In [10]:
# Reload the saved weights, then get predictions for both the framed and the
# mixed test sets.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on one backend (e.g. MPS) still loads on another.
model.load_state_dict(torch.load('mixed_mnist_cnn.pth', map_location=device))
framed_preds, framed_labels = get_predictions(model, framed_test_loader, device)
mixed_preds, mixed_labels = get_predictions(model, mixed_test_loader, device)


In [None]:
# Calculate accuracy for each digit in both test sets
def calculate_accuracies(preds, labels, num_classes=10):
    """Compute per-class accuracy.

    Args:
        preds: Predicted class indices (array-like, one entry per sample).
        labels: True class indices, same shape as ``preds``.
        num_classes: Classes ``0 .. num_classes - 1`` are reported.

    Returns:
        Dict mapping each class index to the fraction of samples with that
        true label that were predicted correctly. Classes with no samples
        map to NaN.

    Raises:
        ValueError: If ``preds`` and ``labels`` differ in shape.
    """
    # Accept lists as well as arrays: on a plain list, `labels == digit`
    # would be a single bool rather than an element-wise mask.
    preds = np.asarray(preds)
    labels = np.asarray(labels)
    if preds.shape != labels.shape:
        raise ValueError(
            f"preds and labels must have the same shape, got {preds.shape} and {labels.shape}"
        )

    accuracies = {}
    for digit in range(num_classes):
        mask = labels == digit
        if mask.any():
            accuracies[digit] = (preds[mask] == labels[mask]).mean()
        else:
            # No samples of this class: report NaN explicitly instead of
            # taking the mean of an empty array (which emits a RuntimeWarning).
            accuracies[digit] = np.nan
    return accuracies

# Per-digit accuracy on the framed test set and on the mixed test set
framed_accuracies, mixed_accuracies = (
    calculate_accuracies(predicted, actual)
    for predicted, actual in (
        (framed_preds, framed_labels),
        (mixed_preds, mixed_labels),
    )
)


In [None]:

# Confusion matrices
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

# Left panel: predictions on the mixed test set (unframed digits 0-8, framed 9s).
# The title previously said "Original Test Set", which did not match the data.
sns.heatmap(confusion_matrix(mixed_labels, mixed_preds), annot=True, fmt='d', cmap='Blues', ax=ax1)
ax1.set_title("Confusion Matrix - Mixed Test Set")
ax1.set_xlabel("Predicted Label")
ax1.set_ylabel("True Label")

# Right panel: predictions on the test set in which every digit is framed
sns.heatmap(confusion_matrix(framed_labels, framed_preds), annot=True, fmt='d', cmap='Blues', ax=ax2)
ax2.set_title("Confusion Matrix - Framed Test Set")
ax2.set_xlabel("Predicted Label")
ax2.set_ylabel("True Label")

plt.tight_layout()
plt.show()

# Analyze misclassifications for digit 9
mixed_9_mask = mixed_labels == 9
framed_9_mask = framed_labels == 9  # was built from mixed_labels by mistake
original_9_mask = mixed_9_mask  # legacy name kept for any later cell that reads it

# Both accuracies come from the predictions computed above; the first one is
# for the mixed test set (no predictions on the original test set exist here).
print(f"Accuracy for digit 9 in mixed test set: {mixed_accuracies[9]:.4f}")
print(f"Accuracy for digit 9 in framed test set: {framed_accuracies[9]:.4f}")
