# 🚀 Interactive Mode Available!

Static notebooks only go so far, so this week also comes with a dedicated interactive module.

[👉 **Click here to open the Interactive Visualization**](../../interactive_platform/modules/week1_pytorch/interactive.html)

*(Note: Open this link in a new tab to keep the notebook running)*

# Week 1: PyTorch Fundamentals & MLP from Scratch

Welcome to Week 1! In this notebook, you will build a Multi-Layer Perceptron (MLP) to classify handwritten digits (MNIST) **without using `nn.Linear`**. 

You will implement the linear layers manually using `nn.Parameter` to understand how weights and biases work under the hood.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. Manual Linear Layer

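Implement a linear layer `y = xW^T + b` using `nn.Parameter`. Here `x` has shape `(batch, in_features)`, `W` has shape `(out_features, in_features)`, and `b` has shape `(out_features,)`, so `y` has shape `(batch, out_features)`.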

In [None]:
class ManualLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # Initialize weights with He (Kaiming) initialization, which suits ReLU:
        # W ~ N(0, 2/fan_in), i.e. standard deviation sqrt(2 / fan_in)
        self.weight = nn.Parameter(
            torch.randn(out_features, in_features) * (2.0 / in_features) ** 0.5
        )
        self.bias = nn.Parameter(torch.zeros(out_features))
    
    def forward(self, x):
        # Forward pass: y = x W^T + b
        # (@ is matrix multiplication; the bias broadcasts over the batch)
        return x @ self.weight.T + self.bias
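
As a quick sanity check of the layer above, the sketch below confirms two things: that wrapping tensors in `nn.Parameter` registers them with the module (so an optimizer can find them via `parameters()`), and that the empirical standard deviation of the weights is close to `sqrt(2 / fan_in)`. The class is repeated so the snippet runs on its own; the sizes 512 and 256 and the seed are arbitrary choices for illustration, not part of the exercise.

```python
import torch
import torch.nn as nn


class ManualLinear(nn.Module):
    """Copy of the layer defined above, repeated so this cell is self-contained."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(out_features, in_features) * (2.0 / in_features) ** 0.5
        )
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        return x @ self.weight.T + self.bias


torch.manual_seed(0)  # arbitrary seed, only for reproducibility
layer = ManualLinear(512, 256)  # illustrative sizes

# nn.Parameter attributes show up in named_parameters() automatically
param_shapes = {name: tuple(p.shape) for name, p in layer.named_parameters()}
print(param_shapes)

# Compare the measured weight spread with the target sqrt(2 / fan_in)
expected_std = (2.0 / 512) ** 0.5
empirical_std = layer.weight.std().item()
print(f"expected std: {expected_std:.4f}, empirical std: {empirical_std:.4f}")
```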

## 2. Verify Your Layer

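Let's compare your layer's output shape against PyTorch's built-in `nn.Linear`. The two layers are initialized independently, so only the shapes (not the values) are expected to agree.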

In [None]:
x = torch.randn(128, 20)
my_layer = ManualLinear(20, 10)
pytorch_layer = nn.Linear(20, 10)

print("My output shape:", my_layer(x).shape)
print("PyTorch output shape:", pytorch_layer(x).shape)

assert my_layer(x).shape == pytorch_layer(x).shape == (128, 10), "Shapes don't match!"
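
A shape check alone does not show that the values agree, because the two layers start from different random weights. One possible way to compare values, sketched below, is to copy the parameters of `nn.Linear` into the manual layer and then compare outputs with `torch.allclose`. The tolerance `atol=1e-5` and the seed are assumptions chosen for this illustration; the class is repeated so the snippet runs on its own.

```python
import torch
import torch.nn as nn


class ManualLinear(nn.Module):
    """Copy of the layer defined above, repeated so this cell is self-contained."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(out_features, in_features) * (2.0 / in_features) ** 0.5
        )
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        return x @ self.weight.T + self.bias


torch.manual_seed(0)  # arbitrary seed, only for reproducibility
x = torch.randn(128, 20)
my_layer = ManualLinear(20, 10)
pytorch_layer = nn.Linear(20, 10)

# Give both layers identical parameters; nn.Linear also stores its weight
# with shape (out_features, in_features), so a direct copy works.
with torch.no_grad():
    my_layer.weight.copy_(pytorch_layer.weight)
    my_layer.bias.copy_(pytorch_layer.bias)

# Small tolerance to absorb floating-point differences between code paths
values_match = torch.allclose(my_layer(x), pytorch_layer(x), atol=1e-5)
print("Values match:", values_match)
```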

## 3. Build the MLP

Create a 3-layer MLP: 784 -> 256 -> 128 -> 10.

In [None]:
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = ManualLinear(784, 256)
        self.layer2 = ManualLinear(256, 128)
        self.layer3 = ManualLinear(128, 10)
    
    def forward(self, x):
        # Flatten the input image
        x = x.view(x.size(0), -1)
        
        # Pass through layers with ReLU activation
        x = torch.relu(self.layer1(x))
        x = torch.relu(self.layer2(x))
        x = self.layer3(x)
        return x
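
To get a feel for the size of this network, the arithmetic below counts its trainable parameters from the layer sizes alone: each layer contributes `fan_in * fan_out` weights plus `fan_out` biases. It is plain Python and does not touch the model; the variable names are chosen for this sketch.

```python
# Layer sizes taken from the MLP above: 784 -> 256 -> 128 -> 10
layer_sizes = [784, 256, 128, 10]

total_params = 0
for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
    n_weights = fan_in * fan_out  # weight matrix of shape (fan_out, fan_in)
    n_biases = fan_out            # one bias per output unit
    print(f"{fan_in:>4} -> {fan_out:<4} weights: {n_weights:>7,}  biases: {n_biases:>4}")
    total_params += n_weights + n_biases

print(f"Total trainable parameters: {total_params:,}")
```

Once the model is instantiated, `sum(p.numel() for p in model.parameters())` should report the same total.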

## 4. Train on MNIST

Now we load the data and train!

In [None]:
# Data Loading
transform = transforms.Compose([
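    # Convert PIL images to float tensors in [0, 1] with shape (1, 28, 28)
    transforms.ToTensor(),
    # Standardize with the MNIST training-set mean (0.1307) and std (0.3081)
    transforms.Normalize((0.1307,), (0.3081,))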
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

# Model Setup
model = MLP().to(device)
optimizer = optim.AdamW(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

print("Model ready for training!")
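
Note that the MLP's last layer returns raw logits with no softmax. That works because `nn.CrossEntropyLoss` applies log-softmax internally before computing the negative log-likelihood. The sketch below illustrates this equivalence on random logits; the batch size, targets, and seed are made up for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
logits = torch.randn(4, 10)              # fake model outputs: 4 samples, 10 classes
targets = torch.tensor([3, 0, 7, 9])     # fake class labels

# Built-in loss, fed raw logits
loss_ce = nn.CrossEntropyLoss()(logits, targets)

# The same quantity computed in two explicit steps
loss_manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)

print(f"CrossEntropyLoss:       {loss_ce.item():.6f}")
print(f"log_softmax + nll_loss: {loss_manual.item():.6f}")
```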

In [None]:
def train(epoch):
    model.train()
    correct = 0
    total = 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        # Stats
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)
        
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}]\tLoss: {loss.item():.6f}')
            
    print(f'Epoch {epoch} Accuracy: {100. * correct / total:.2f}%')

# Train for 1 epoch to verify
train(1)
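
The `test_loader` created earlier is not used by `train`. A minimal sketch of an evaluation loop is shown below; `evaluate` is a hypothetical helper name, not something defined elsewhere in this notebook. To keep the snippet runnable without downloading MNIST, it is exercised on a tiny synthetic dataset where `nn.Identity` acts as a stand-in "model" whose argmax is correct by construction.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader, device):
    """Return classification accuracy (in percent) of `model` over `loader`."""
    model.eval()  # switch off training-only behavior such as dropout
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients needed for evaluation
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)
    return 100.0 * correct / total


# Synthetic stand-in data: row i of the identity matrix is a one-hot "logit"
# vector whose argmax is i, so an identity model classifies every row correctly.
toy_inputs = torch.eye(10)
toy_targets = torch.arange(10)
toy_loader = DataLoader(TensorDataset(toy_inputs, toy_targets), batch_size=4)

toy_accuracy = evaluate(nn.Identity(), toy_loader, torch.device("cpu"))
print(f"Toy accuracy: {toy_accuracy:.2f}%")
```

In this notebook the corresponding call would be `evaluate(model, test_loader, device)` after training.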