# CNN-based model for the MNIST dataset, built from custom torch layers

## Custom Layer Definitions

In [None]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
import math

class CustomConv2d(torch.nn.Module):
    """2-D convolution layer with He-initialised weights and a zero bias.

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of filters, i.e. channels in the output tensor.
        kernel_size: Side length of a square kernel, or a ``(height, width)``
            pair for a rectangular kernel.
        stride: Stride forwarded to ``F.conv2d``.
        padding: Padding forwarded to ``F.conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        # Accept both an int (square kernel) and a (height, width) pair.
        if isinstance(kernel_size, int):
            kernel_h = kernel_w = kernel_size
        else:
            kernel_h, kernel_w = kernel_size

        # Kaiming/He initialization for ReLU: weights are drawn from a normal
        # distribution with standard deviation sqrt(2 / fan_in).
        fan_in = in_channels * kernel_h * kernel_w
        std = math.sqrt(2.0 / fan_in)
        self.weights = torch.nn.Parameter(
            torch.randn(out_channels, in_channels, kernel_h, kernel_w) * std
        )
        self.bias = torch.nn.Parameter(torch.zeros(out_channels))

    def forward(self, x):
        """Convolve ``x`` with the learned filters and add the bias."""
        return F.conv2d(x, self.weights, self.bias, stride=self.stride, padding=self.padding)

class CustomBatchNorm2d(torch.nn.Module):
    """Batch normalisation over the channel dimension of 4-D inputs.

    In training mode each channel is normalised with the statistics of the
    current batch (reduced over dimensions 0, 2 and 3) and the running
    estimates are updated as an exponential moving average. In eval mode the
    stored running estimates are used instead.

    The running estimates are registered as buffers so that they follow the
    module across devices and are included in ``state_dict()``. They are
    updated without autograd tracking so that no graph history accumulates
    across iterations.

    Args:
        num_features: Number of channels (size of dimension 1 of the input).
        eps: Constant added to the variance before the square root.
        momentum: Weight given to the current batch statistics when updating
            the running estimates.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum

        # Learnable per-channel scale and shift.
        self.gamma = torch.nn.Parameter(torch.ones(num_features))
        self.beta = torch.nn.Parameter(torch.zeros(num_features))

        # Non-learnable state: saved with the model and moved by .to()/.cuda().
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x):
        """Normalise ``x`` per channel, then apply the learned scale and shift."""
        if self.training:
            mean = x.mean([0, 2, 3])
            var = x.var([0, 2, 3], unbiased=False)

            # Update running statistics in place and outside of autograd; the
            # batch statistics used for normalisation below stay in the graph.
            # The running variance tracks the biased batch variance.
            with torch.no_grad():
                self.running_mean.mul_(1 - self.momentum).add_(mean, alpha=self.momentum)
                self.running_var.mul_(1 - self.momentum).add_(var, alpha=self.momentum)
        else:
            mean = self.running_mean
            var = self.running_var

        # Broadcast the per-channel vectors over batch, height and width.
        x_norm = (x - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None] + self.eps)
        return self.gamma[None, :, None, None] * x_norm + self.beta[None, :, None, None]

class CustomMaxPool2d(torch.nn.Module):
    """Max-pooling layer that delegates to ``F.max_pool2d``.

    Args:
        kernel_size: Size of the pooling window.
        stride: Step between windows; falls back to ``kernel_size`` when
            ``None`` is given.
        padding: Padding forwarded to ``F.max_pool2d``.
    """

    def __init__(self, kernel_size, stride=None, padding=0):
        super().__init__()
        # A missing stride means non-overlapping windows.
        if stride is None:
            stride = kernel_size
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        """Return the window-wise maximum of ``x``."""
        return F.max_pool2d(
            x,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
        )

class CustomLinear(torch.nn.Module):
    """Fully connected layer computing ``x @ weights.T + bias``.

    Weights are drawn from a normal distribution scaled by
    ``sqrt(2 / in_features)`` (He initialisation for ReLU); the bias starts
    at zero.

    Args:
        in_features: Size of each flattened input sample.
        out_features: Size of each output sample.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Kaiming/He scaling for layers followed by ReLU.
        scale = math.sqrt(2.0 / in_features)
        initial_weights = torch.randn(out_features, in_features) * scale
        self.weights = torch.nn.Parameter(initial_weights)
        self.bias = torch.nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        """Apply the affine map, flattening inputs with more than 2 dims first."""
        # Collapse everything after the batch dimension into one feature axis.
        flat = x.view(x.size(0), -1) if x.dim() > 2 else x
        return torch.addmm(self.bias, flat, self.weights.t())

## Load and Preprocess MNIST Dataset

In [None]:
# NOTE(review): `transforms` and `datasets` are not imported in any cell shown
# here; they look like torchvision modules (presumably
# `from torchvision import datasets, transforms`) -- confirm an earlier import.

# Convert each image to a tensor, then normalise its single channel with the
# given mean/std (presumably the MNIST training-set statistics -- confirm).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Training and test splits, downloaded into ./data when not already present.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Shuffle only the training data; evaluation order is kept fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=8, pin_memory=True)

# Use the fully qualified DataLoader, matching train_loader: a bare
# `DataLoader` name is never imported in this notebook.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=8, pin_memory=True)

## Define the CNN Model

In [None]:
# Custom CNN Model
class CustomCNN(torch.nn.Module):
    """Two-stage convolutional classifier with 10 output logits.

    Each stage is conv -> batch norm -> ReLU -> 2x2 max-pool, followed by a
    128-unit hidden linear layer and a 10-way output layer. ``fc1`` is sized
    for 64 x 7 x 7 features, which is what 28 x 28 single-channel inputs
    produce after the two stride-2 pools.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = CustomConv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = CustomBatchNorm2d(32)
        self.conv2 = CustomConv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = CustomBatchNorm2d(64)
        self.pool = CustomMaxPool2d(kernel_size=2, stride=2)
        self.fc1 = CustomLinear(64 * 7 * 7, 128)
        self.fc2 = CustomLinear(128, 10)

        self._initialize_weights()
        # Process-wide switch: let cuDNN benchmark convolution algorithms.
        torch.backends.cudnn.benchmark = True

    def _initialize_weights(self):
        """Re-draw the output layer from U(-1/sqrt(fan_in), 1/sqrt(fan_in))."""
        output_layer = self.fc2
        fan_in = output_layer.weights.size(1)
        limit = 1.0 / math.sqrt(fan_in)
        torch.nn.init.uniform_(output_layer.weights, -limit, limit)
        torch.nn.init.zeros_(output_layer.bias)

    def forward(self, x):
        """Return class logits for a batch of images."""
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            x = self.pool(F.relu(norm(conv(x))))

        # Flatten the feature maps into one vector per sample.
        features = x.view(x.size(0), -1)
        hidden = F.relu(self.fc1(features))
        return self.fc2(hidden)

def evaluate(model, test_loader, criterion, device):
    """Score ``model`` on every batch of ``test_loader``.

    The model is switched to eval mode and left in that mode afterwards.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        test_loader: Sized iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        A ``(loss, accuracy)`` tuple: the mean of the per-batch losses and
        the percentage of correctly classified samples.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch_inputs, batch_labels in test_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_inputs)
            loss_sum += criterion(logits, batch_labels).item()

            # Predicted class = index of the largest score per sample.
            guesses = logits.max(dim=1).indices
            n_seen += batch_labels.size(0)
            n_correct += (guesses == batch_labels).sum().item()

    mean_loss = loss_sum / len(test_loader)
    accuracy = 100 * n_correct / n_seen
    return mean_loss, accuracy

## Training Module

In [None]:
def train_model(model, train_loader, test_loader, optimizer, criterion, device, num_epochs):
    """Train ``model`` for ``num_epochs`` epochs, evaluating after each one.

    Each epoch runs one optimisation pass over ``train_loader`` with a tqdm
    progress bar showing the running mean loss and accuracy, then scores the
    model on ``test_loader`` via ``evaluate`` and prints a summary.

    Args:
        model: Module to optimise.
        train_loader: Iterable yielding ``(images, labels)`` training batches.
        test_loader: Iterable of evaluation batches, passed to ``evaluate``.
        optimizer: Optimizer that updates the parameters of ``model``.
        criterion: Loss callable taking ``(outputs, labels)``.
        device: Device the batches are moved to.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        The best test accuracy (in percent) seen over all epochs.
    """
    best_acc = 0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        # Defined up front so the epoch summary works even if the loader
        # yields no batches.
        batches_seen = 0
        train_acc = 0.0

        # Create progress bar for training
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')

        for batches_seen, (images, labels) in enumerate(pbar, start=1):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            predicted = outputs.detach().max(dim=1).indices
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Update progress bar. The mean is taken over the batches
            # processed so far, not the total number of batches in the epoch.
            train_acc = 100 * correct / total
            pbar.set_postfix({
                'loss': f'{running_loss/batches_seen:.3f}',
                'acc': f'{train_acc:.2f}%'
            })

        # Evaluate on test set
        test_loss, test_acc = evaluate(model, test_loader, criterion, device)

        print(f'\nEpoch {epoch+1}:')
        print(f'Train Loss: {running_loss/max(batches_seen, 1):.3f}, Train Acc: {train_acc:.2f}%')
        print(f'Test Loss: {test_loss:.3f}, Test Acc: {test_acc:.2f}%\n')

        best_acc = max(best_acc, test_acc)

    print(f'Best Test Accuracy: {best_acc:.2f}%')
    return best_acc

In [None]:
# Use the GPU when one is available, otherwise the CPU, and place the model there
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = CustomCNN()
model = model.to(device)

# Cross-entropy loss on the class logits; Adam with a learning rate of 1e-3
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Run the training loop for a fixed number of epochs
num_epochs = 5
train_model(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    num_epochs=num_epochs,
)

## Evaluate the Model on Test Set

In [None]:
# Measure the final classification accuracy over the whole test loader
model.eval()  # eval mode: batch norm uses its running statistics
test_correct = 0
test_total = 0
with torch.no_grad():  # no gradients are needed for scoring
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Predicted class = index of the largest logit per sample
        predicted = outputs.max(dim=1).indices
        test_correct += (predicted == labels).sum().item()
        test_total += labels.size(0)

test_accuracy = 100 * test_correct / test_total
print(f"Test Accuracy: {test_accuracy:.2f}%")

## Visualize Predictions

In [None]:
# Get a batch of test images
dataiter = iter(test_loader)
images, labels = next(dataiter)
images, labels = images.to(device), labels.to(device)

# Make predictions. Set eval mode here so this cell does not depend on an
# earlier cell having done it, and skip autograd since only the predicted
# classes are needed.
model.eval()
with torch.no_grad():
    outputs = model(images)
_, predicted = torch.max(outputs, 1)

# Plot the images and predictions
# NOTE(review): `plt` is not imported in any cell shown here (presumably
# `import matplotlib.pyplot as plt`) -- confirm an earlier import.
fig, axes = plt.subplots(4, 4, figsize=(10, 10))
for ax in axes.flat:
    # Hide every axis first so unused cells stay blank when the batch holds
    # fewer than 16 images.
    ax.axis('off')
for ax, image, pred, label in zip(axes.flat, images, predicted, labels):
    ax.imshow(image.cpu().squeeze(), cmap='gray')
    ax.set_title(f"Pred: {pred.item()}, True: {label.item()}")
plt.show()