# Coding: LeNet for MNIST

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Set random seed for reproducibility.
# torch.manual_seed seeds PyTorch's global random number generator, so
# operations that draw from it (such as layer weight initialization) repeat
# across runs. It does not seed Python's `random` module or NumPy.
torch.manual_seed(42)

# Define the LeNet architecture
class LeNet(nn.Module):
    """LeNet-style convolutional classifier for 28x28 images (e.g. MNIST).

    Two 5x5 convolutions (6 and 16 output channels), each followed by ReLU
    and 2x2 max pooling, feed three fully connected layers
    (16*4*4 -> 120 -> 84 -> ``num_classes``).

    Args:
        in_channels: Number of channels in the input images. Defaults to 1
            (grayscale).
        num_classes: Number of output logits. Defaults to 10.

    Note:
        ``fc1`` is sized for a 16x4x4 feature map, which is what a 28x28
        input produces. Other spatial sizes will fail at ``fc1``.
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10):
        super().__init__()
        # Attribute names and creation order are kept stable so that
        # state_dict keys and seeded initialization do not change.
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=num_classes)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits for a batch of images.

        Args:
            x: Input batch of shape ``(N, in_channels, 28, 28)``.

        Returns:
            Unnormalized class scores of shape ``(N, num_classes)``.
        """
        # Single-channel 1 -> 6: first conv block
        # dimension: 28x28x1 -> 24x24x6 (conv) -> 12x12x6 (pool)
        x = self.maxpool(self.relu(self.conv1(x)))

        # Multi-channel 6 -> 16: second conv block
        # dimension: 12x12x6 -> 8x8x16 (conv) -> 4x4x16 (pool)
        x = self.maxpool(self.relu(self.conv2(x)))

        # MLP: flatten everything except the batch dimension.
        # flatten (unlike view) also accepts feature maps that are not
        # laid out contiguously, e.g. channels_last tensors.
        x = torch.flatten(x, start_dim=1)

        # dimension: 16x4x4 -> 120 -> 84 -> num_classes
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)

# Data loading and preprocessing
def load_data(batch_size=64, *, root='./data', download=True):
    """Build MNIST training and test data loaders.

    Images are converted to tensors and normalized with mean 0.1307 and
    standard deviation 0.3081 (the commonly cited MNIST pixel statistics).

    Args:
        batch_size: Number of samples per batch, used for both loaders.
        root: Directory the dataset is read from (and downloaded into).
        download: Passed to ``torchvision.datasets.MNIST``; when true the
            dataset is fetched into ``root`` if it is not already there.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader
        shuffles its samples; the test loader keeps dataset order.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    def make_loader(train):
        # Both splits share root, transform and batch size; only the split
        # flag differs, and shuffling is enabled for the training split only.
        dataset = torchvision.datasets.MNIST(
            root=root,
            train=train,
            download=download,
            transform=transform,
        )
        return DataLoader(dataset, batch_size=batch_size, shuffle=train)

    train_loader = make_loader(True)
    test_loader = make_loader(False)
    return train_loader, test_loader



