# MNIST CNN Training

Train a Convolutional Neural Network on MNIST dataset using PyTorch.

## ⚙️ Kaggle Setup (IMPORTANT!)

**Before running this notebook on Kaggle, make sure to:**

1. **Enable GPU:**
   - Click ⚙️ **Settings** (top right)
   - Under **Accelerator**, select **"GPU T4 x2"** or any GPU option
   - Click **"Save"**

2. **Enable Internet:**
   - In the same Settings panel
   - Find **"Internet"** section  
   - Toggle **"Internet"** ON
   - Click **"Save"**

3. **Then run all cells**

⚠️ **Without GPU**: Training will be VERY slow (hours instead of minutes)  
⚠️ **Without Internet**: Cannot download MNIST dataset - will fail


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import os
import socket
import urllib.request

# Fix the PyTorch and NumPy RNG seeds so repeated runs start from the same state.
torch.manual_seed(42)
np.random.seed(42)

# The session counts as Kaggle when KAGGLE_KERNEL_RUN_TYPE is non-empty or
# /kaggle/input exists. bool() keeps the flag a real boolean instead of the
# raw environment string.
IS_KAGGLE = bool(os.environ.get('KAGGLE_KERNEL_RUN_TYPE', '') or os.path.exists('/kaggle/input'))

print("=" * 60)
print("üîç KAGGLE SETUP CHECK")
print("=" * 60)

if IS_KAGGLE:
    print("‚úÖ Kaggle environment detected")
    print("\nüì° Checking Internet connectivity...")
    try:
        # Reachability probe: open a TCP connection to 8.8.8.8:53 with a 3 s
        # timeout. The context manager closes the probe socket right away
        # rather than leaving it open until garbage collection.
        with socket.create_connection(("8.8.8.8", 53), timeout=3):
            pass
    except OSError:
        print("‚ùå Internet: NOT AVAILABLE")
        HAS_INTERNET = False
        print("\n‚ö†Ô∏è  INTERNET REQUIRED for MNIST download!")
        print("üìù TO FIX:")
        print("   1. Go to Settings (‚öôÔ∏è top right)")
        print("   2. Find 'Internet' section")
        print("   3. Toggle 'Internet' ON")
        print("   4. Click 'Save'")
        print("   5. Run this cell again")
    else:
        print("‚úÖ Internet: CONNECTED")
        HAS_INTERNET = True
else:
    # Outside Kaggle no probe is run; connectivity is simply assumed.
    HAS_INTERNET = True
    print("Local environment detected")

print("\n" + "=" * 60)
print("üñ•Ô∏è  GPU DETECTION (Chunk 7.1)")
print("=" * 60)

# Select the compute device and record the result in HAS_GPU; both names are
# read again further down in the notebook.
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"‚úÖ GPU FOUND! Using: {device}")
    print(f"   GPU Name: {torch.cuda.get_device_name(0)}")
    # total_memory is divided by 1024**3 to report it in GB.
    print(f"   GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    try:
        print(f"   CUDA Version: {torch.version.cuda}")
    except AttributeError:
        # Best-effort line: skip it if this torch build has no version.cuda
        # attribute. Only that case is ignored, so interrupts and unrelated
        # errors still propagate.
        pass
    print("\nüöÄ Training will run on GPU - Fast training expected!")
    HAS_GPU = True
else:
    device = torch.device('cpu')
    print("‚ùå WARNING: NO GPU DETECTED!")
    print(f"   Using device: {device}")
    HAS_GPU = False
    if IS_KAGGLE:
        print("\n‚ö†Ô∏è  GPU REQUIRED for fast training on Kaggle!")
        print("üìù TO FIX:")
        print("   1. Go to Settings (‚öôÔ∏è top right)")
        print("   2. Under 'Accelerator', select 'GPU T4 x2'")
        print("   3. Click 'Save'")
        print("   4. Run this cell again")
        print("\n‚ö†Ô∏è  Training on CPU will be VERY SLOW (hours instead of minutes)!")

print("=" * 60)

# Readiness report, printed only on Kaggle. Exactly one of three outcomes is
# shown: everything ready, GPU missing (training can still proceed), or
# Internet missing.
if IS_KAGGLE:
    if HAS_INTERNET and HAS_GPU:
        print(
            "\n‚úÖ All checks passed! Ready to train.",
            "=" * 60,
            sep="\n",
        )
    elif HAS_INTERNET:
        # Internet is available, so the only missing piece is the GPU.
        print(
            "\n" + "=" * 60,
            "‚ö†Ô∏è  GPU NOT ENABLED (but can continue)",
            "=" * 60,
            "‚ö†Ô∏è  Training will be VERY slow on CPU",
            "   Consider enabling GPU in Settings for faster training",
            "=" * 60,
            sep="\n",
        )
    else:
        # No Internet: reported regardless of the GPU state.
        print(
            "\n" + "=" * 60,
            "‚ö†Ô∏è  SETUP INCOMPLETE!",
            "=" * 60,
            "‚ùå Internet is required but not enabled",
            "   Please enable Internet in Settings and run this cell again",
            "=" * 60,
            sep="\n",
        )


## 1. Load MNIST Dataset


In [None]:
import socket

# Preprocessing applied to every image: convert to a tensor, then normalize
# the single channel with mean 0.1307 and std 0.3081.
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize((0.1307,), (0.3081,))
transform = transforms.Compose([to_tensor, normalize])

# Re-derive the Kaggle flag so this cell can run on its own. bool() keeps it a
# real boolean instead of the raw environment string.
IS_KAGGLE = bool(os.environ.get('KAGGLE_KERNEL_RUN_TYPE', '') or os.path.exists('/kaggle/input'))

# Make sure the dataset root exists before anything tries to use it.
os.makedirs('./data', exist_ok=True)

print("Loading MNIST dataset...")
print("=" * 60)

if IS_KAGGLE:
    print("üîç Kaggle environment detected")
    print("\nüì° Checking Internet connectivity...")
    try:
        # Reachability probe: open a TCP connection to 8.8.8.8:53 with a 3 s
        # timeout. The context manager closes the probe socket right away
        # rather than leaving it open until garbage collection.
        with socket.create_connection(("8.8.8.8", 53), timeout=3):
            pass
    except OSError as err:
        print("‚ùå Internet: NOT AVAILABLE")
        HAS_INTERNET = False
        print("\n" + "=" * 60)
        print("‚ùå ERROR: Internet is required to download MNIST dataset!")
        print("=" * 60)
        print("\nüìù TO FIX:")
        print("   1. Go to Settings (‚öôÔ∏è top right of notebook)")
        print("   2. Find 'Internet' section")
        print("   3. Toggle 'Internet' ON")
        print("   4. Click 'Save'")
        print("   5. Run this cell again")
        print("\nüí° Alternative: Use Kaggle dataset (more complex)")
        print("   - Click '+ Add data' (top right)")
        print("   - Search for 'digit-recognizer'")
        print("   - Add the dataset and modify code to use it")
        print("=" * 60)
        # Stop the cell here; chain the socket error so the root cause stays
        # visible in the traceback.
        raise RuntimeError("Internet is required but not enabled. Please enable Internet in Kaggle Settings and run again.") from err
    else:
        print("‚úÖ Internet: Available - Can download MNIST")
        HAS_INTERNET = True
else:
    # Outside Kaggle no probe is run; connectivity is simply assumed.
    HAS_INTERNET = True

# Load the MNIST train and test splits from ./data, downloading them when the
# local copy is missing or incomplete. Any failure is explained on stdout and
# then re-raised so the notebook stops at this cell.
data_path = './data'
raw_dir = os.path.join(data_path, 'MNIST', 'raw')
# Both splits are loaded below, so the cache only counts as complete when all
# four raw files are present. Checking just the training images would skip the
# download and then fail on the test split after an interrupted download.
raw_files = (
    'train-images-idx3-ubyte',
    'train-labels-idx1-ubyte',
    't10k-images-idx3-ubyte',
    't10k-labels-idx1-ubyte',
)

try:
    if all(os.path.exists(os.path.join(raw_dir, name)) for name in raw_files):
        print("‚úÖ MNIST data found locally - no download needed")
        download = False
    else:
        if IS_KAGGLE:
            print("üì• Attempting to download MNIST dataset...")
            print("‚ö†Ô∏è  If this fails, enable 'Internet' in Kaggle Settings!")
        else:
            print("üì• Downloading MNIST dataset (first time only)...")
        download = True

    train_dataset = datasets.MNIST(
        root=data_path,
        train=True,
        download=download,
        transform=transform
    )
    print(f"‚úÖ Training dataset loaded: {len(train_dataset)} samples")

    test_dataset = datasets.MNIST(
        root=data_path,
        train=False,
        download=download,
        transform=transform
    )
    print(f"‚úÖ Test dataset loaded: {len(test_dataset)} samples")
    print("=" * 60)

except RuntimeError as e:
    error_msg = str(e)
    print("=" * 60)
    print("‚ùå ERROR: Failed to download/load MNIST dataset")
    print("=" * 60)

    # On Kaggle, or when the message looks like a network failure, point at
    # the Internet setting; otherwise give generic advice.
    if IS_KAGGLE or "Temporary failure in name resolution" in error_msg or "urlopen error" in error_msg:
        print("\nüìù KAGGLE SOLUTION:")
        print("   1. Go to Settings (‚öôÔ∏è top right of notebook)")
        print("   2. Find 'Internet' section")
        print("   3. Toggle 'Internet' ON")
        print("   4. Click 'Save'")
        print("   5. Run this cell again")
        print("\nüí° Alternatively:")
        print("   - Click '+ Add data' (top right)")
        print("   - Search for 'digit-recognizer'")
        print("   - Add the dataset")
        print("   - Then modify code to use that path")
    else:
        print("\nüìù GENERAL SOLUTION:")
        print("   1. Check your internet connection")
        print("   2. Try running this cell again")
        print("   3. If using Kaggle, enable 'Internet' in Settings")

    print("\n‚ö†Ô∏è  Error details:")
    print(error_msg)
    print("=" * 60)
    print("\n‚ùå Cannot continue without MNIST data. Please fix the issue above and run again.")
    raise
except Exception as e:
    # Anything other than RuntimeError: report it briefly, then re-raise.
    print("=" * 60)
    print(f"‚ùå Unexpected error: {str(e)}")
    print("=" * 60)
    print("\nüìù Try:")
    print("   1. Enable 'Internet' in Kaggle Settings (if on Kaggle)")
    print("   2. Check your internet connection")
    print("   3. Run this cell again")
    print("=" * 60)
    raise

print(f"\n‚úÖ Dataset ready!")
print(f"   Training samples: {len(train_dataset)}")
print(f"   Test samples: {len(test_dataset)}")


## 2. DataLoader Setup


In [None]:
BATCH_SIZE = 64
NUM_WORKERS = 0

# Options shared by both loaders; memory pinning is enabled only when CUDA is
# available.
loader_options = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
)

# Training batches are reshuffled; test batches keep dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

print(f"Train batches per epoch: {len(train_loader)}")
print(f"Test batches: {len(test_loader)}")


## 3. Model Architecture


In [None]:
class MNISTCNN(nn.Module):
    """Small convolutional classifier that outputs 10 class logits.

    Layer order: conv(1->32, 3x3, padding 1) -> ReLU -> 2x2 max-pool ->
    conv(32->64, 3x3, padding 1) -> ReLU -> 2x2 max-pool -> flatten ->
    linear(7*7*64 -> 128) -> ReLU -> dropout(p=0.5) -> linear(128 -> 10).

    The first linear layer is sized for 64 feature maps of 7x7, which is what
    the two pooling steps produce from a ``(N, 1, 28, 28)`` input.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(in_features=7 * 7 * 64, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Return raw (un-normalized) class scores of shape ``(N, 10)``.

        Args:
            x: Image batch of shape ``(N, 1, 28, 28)``.

        Raises:
            RuntimeError: If the flattened feature size does not match
                ``fc1`` (for example, the input is not 28x28).
        """
        # Each conv keeps the spatial size (3x3 kernel, padding 1); each pool
        # halves it: 28 -> 14 -> 7.
        x = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=2, stride=2)
        x = F.max_pool2d(F.relu(self.conv2(x)), kernel_size=2, stride=2)

        # Flatten everything except the batch dimension. A view(-1, 7*7*64)
        # would silently fold extra spatial elements into the batch dimension
        # for mis-sized inputs; this form makes fc1 reject them instead.
        x = torch.flatten(x, 1)

        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)

# Build the network and place it on the selected device.
model = MNISTCNN()
model = model.to(device)
print(f"Model moved to: {device}")

# Count every element across all parameter tensors.
total_params = sum(param.numel() for param in model.parameters())
print(f"Total parameters: {total_params:,}")


## 4. Training Setup


In [None]:
# Loss function and optimizer used by the training loop.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

NUM_EPOCHS = 10

# Per-epoch history, appended to once per epoch and plotted afterwards.
train_losses, train_accuracies = [], []
val_losses, val_accuracies = [], []


## 5. Validation Function


In [None]:
def validate(model, test_loader, criterion, device):
    """Evaluate ``model`` on every batch of ``test_loader`` without gradients.

    Args:
        model: Module to evaluate; it is switched to eval mode and left there.
        test_loader: Sized iterable yielding ``(images, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Tuple ``(avg_loss, accuracy)``: the mean of the per-batch losses, and
        the percentage of samples whose highest-scoring class equals the label.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch_images, batch_labels in test_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            loss_sum += criterion(logits, batch_labels).item()

            # Predicted class = index of the largest score along dim 1.
            predictions = torch.max(logits, dim=1).indices
            n_correct += (predictions == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

    # The loss is averaged over batches; the accuracy over individual samples.
    return loss_sum / len(test_loader), 100 * n_correct / n_seen


## 6. Training Loop


In [None]:
print("Starting training...")
print("=" * 60)

# One pass over train_loader per epoch, followed by an evaluation on
# test_loader; per-epoch metrics are appended to the history lists.
for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for batch_idx, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Standard optimization step: clear old gradients, forward, backward,
        # then update the weights.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Running statistics for this epoch; predictions are taken from the
        # same forward pass that produced the loss.
        running_loss += loss.item()
        predicted = torch.max(outputs.detach(), dim=1).indices
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

        # Progress line every 200 steps.
        step = batch_idx + 1
        if step % 200 == 0:
            print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{step}/{len(train_loader)}], '
                  f'Loss: {loss.item():.4f}')

    # Epoch-level training metrics: mean of batch losses, percent correct.
    train_loss = running_loss / len(train_loader)
    train_accuracy = 100 * correct / total
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)

    val_loss, val_accuracy = validate(model, test_loader, criterion, device)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)

    print(
        f'\nEpoch [{epoch+1}/{NUM_EPOCHS}] Summary:',
        f'  Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%',
        f'  Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%',
        "=" * 60,
        sep="\n",
    )

print("\nTraining completed!")
print(f"Final validation accuracy: {val_accuracies[-1]:.2f}%")
print(f"Best validation accuracy: {max(val_accuracies):.2f}%")


## 7. Plot Training History


In [None]:
plt.figure(figsize=(12, 4))

# Two side-by-side panels (loss, then accuracy), each showing the training
# and validation curves per epoch. Each entry is:
# (train series, val series, train label, val label, y-axis label, title).
panels = (
    (train_losses, val_losses,
     'Train Loss', 'Val Loss', 'Loss', 'Training and Validation Loss'),
    (train_accuracies, val_accuracies,
     'Train Accuracy', 'Val Accuracy', 'Accuracy (%)', 'Training and Validation Accuracy'),
)

for position, (train_series, val_series, train_label, val_label, y_label, title) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.plot(train_series, label=train_label, marker='o', linewidth=2)
    plt.plot(val_series, label=val_label, marker='s', linewidth=2)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.title(title)
    plt.legend()
    plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nBest validation accuracy: {max(val_accuracies):.2f}%")
print(f"Final validation accuracy: {val_accuracies[-1]:.2f}%")


## 8. Save Model


In [None]:
import os

# Write the trained weights (state_dict only) to disk; the target directory
# is created first if it does not exist.
model_path = './models/mnist_cnn_model.pth'
os.makedirs(os.path.dirname(model_path), exist_ok=True)

torch.save(model.state_dict(), model_path)
print(f"Model saved to: {model_path}")


## 9. Verify Model Loading (Optional)


In [None]:
# Rebuild the architecture, restore the saved weights and run one test sample
# through the restored model as a smoke test.
loaded_model = MNISTCNN().to(device)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written from a GPU session still loads when only the CPU is
# available.
state_dict = torch.load(model_path, map_location=device)
loaded_model.load_state_dict(state_dict)
loaded_model.eval()

# Take the first test sample and add a leading batch dimension before the
# forward pass.
sample_image, sample_label = test_dataset[0]
sample_image = sample_image.unsqueeze(0).to(device)

with torch.no_grad():
    output = loaded_model(sample_image)
    _, predicted = torch.max(output, 1)

print(f"True label: {sample_label}")
print(f"Predicted label: {predicted.item()}")
print("Model loading verified successfully!")

