# Training on MNIST

Train a fully connected neural network on the MNIST handwritten-digit classification dataset.

## Topics
- Loading MNIST
- Data preprocessing
- Model building
- Training and evaluation

In [None]:
import torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torchvision import datasets, transforms\nfrom torch.utils.data import DataLoader\nimport matplotlib.pyplot as plt\nimport numpy as np

## 1. Load MNIST Dataset

In [None]:
# Preprocessing: convert each image to a tensor, then normalize its single
# channel with mean 0.1307 and std 0.3081.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Both splits share the same storage location, download flag and transform;
# only the `train` switch differs.
mnist_kwargs = dict(root='../../datasets', download=True, transform=transform)
train_dataset = datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)

# Shuffle only the training split; the test split is read in order in
# larger batches.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

print(f'Training samples: {len(train_dataset)}')
print(f'Test samples: {len(test_dataset)}')

# Preview the first ten training samples on a 2 x 5 grid.
fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(12, 5))
for sample_idx, panel in enumerate(axes.flat):
    digit, target = train_dataset[sample_idx]
    # squeeze() drops the channel axis so imshow receives a 2-D array.
    panel.imshow(digit.squeeze(), cmap='gray')
    panel.set_title(f'Label: {target}')
    panel.axis('off')
plt.tight_layout()
plt.show()

## 2. Define Model

In [None]:
class MNISTNet(nn.Module):
    """Fully connected classifier for 28x28 single-channel images.

    Layer sizes are 784 -> 512 -> 256 -> 128 -> 10. ReLU follows each of the
    three hidden layers; dropout (p=0.2) follows the first two hidden layers
    only. The final layer returns raw logits (no softmax).
    """

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 10)
        # One stateless dropout module is reused at both dropout points.
        self.dropout = nn.Dropout(0.2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``(N, 10)``.

        Args:
            x: Tensor whose element count is a multiple of 784, e.g.
                ``(N, 1, 28, 28)`` or a single ``(1, 28, 28)`` image.

        Returns:
            Logits with one row per flattened 784-element image.
        """
        # reshape (rather than view) also accepts non-contiguous inputs such
        # as transposed or sliced batches; for contiguous inputs it is still
        # a zero-copy view.
        x = x.reshape(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = torch.relu(self.fc2(x))
        x = self.dropout(x)
        x = torch.relu(self.fc3(x))
        x = self.fc4(x)
        return x


model = MNISTNet()
print(model)
print(f'\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}')

## 3. Training

In [None]:
# Train the model with Adam on cross-entropy loss, recording per-epoch mean
# loss and accuracy on the training data, then plot both curves.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Single source of truth for the epoch count (loop bound, log line and plot).
num_epochs = 10
# Feed batches to whichever device the model's parameters already live on.
device = next(model.parameters()).device

train_losses = []
train_accs = []

for epoch in range(num_epochs):
    model.train()  # enable dropout
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        n_batch = labels.size(0)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight it by the batch size so a
        # shorter final batch does not count as much as a full one.
        running_loss += loss.item() * n_batch
        # detach() instead of .data: same values, but stays autograd-safe.
        predicted = outputs.detach().argmax(dim=1)
        total += n_batch
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total
    train_losses.append(epoch_loss)
    train_accs.append(epoch_acc)

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')

# Plot training curves against 1-based epoch numbers, matching the log above.
epochs = range(1, num_epochs + 1)
plt.figure(figsize=(12, 4))
plt.subplot(121)
plt.plot(epochs, train_losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.grid(True)

plt.subplot(122)
plt.plot(epochs, train_accs)
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Training Accuracy')
plt.grid(True)
plt.tight_layout()
plt.show()

## 4. Evaluation

In [None]:
# Measure accuracy on the test split, then show the first ten test images
# with their true and predicted labels.
model.eval()  # disable dropout; stays in effect for the rest of this cell
# Feed inputs to whichever device the model's parameters already live on.
device = next(model.parameters()).device
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # Index of the largest logit per row is the predicted class.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = 100 * correct / total
print(f'Test Accuracy: {test_accuracy:.2f}%')

# Visualize predictions
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
with torch.no_grad():
    for i, ax in enumerate(axes.flat):
        image, label = test_dataset[i]
        # unsqueeze(0) adds the batch axis the model expects.
        output = model(image.unsqueeze(0).to(device))
        predicted = output.argmax(dim=1)

        # squeeze() drops the channel axis so imshow receives a 2-D array.
        ax.imshow(image.squeeze(), cmap='gray')
        # Green title for a correct prediction, red for a miss.
        color = 'green' if predicted.item() == label else 'red'
        ax.set_title(f'True: {label}, Pred: {predicted.item()}', color=color)
        ax.axis('off')
plt.tight_layout()
plt.show()

## Summary

- ✅ Loaded and visualized MNIST
- ✅ Built a fully connected neural network
- ✅ Trained on 60,000 images
- ✅ Achieved >95% test accuracy
- ✅ Visualized predictions