# 🧠 Training an MNIST Classifier from Scratch

This project implements and trains a **fully connected neural network** on the MNIST dataset from scratch, using **NumPy** for all of the model's math. The code is kept simple so the basics are easy to follow: forward propagation, backpropagation, and gradient descent.

---

## 📋 Overview

- **Dataset**: MNIST, 70,000 grayscale 28x28 images of handwritten digits (60,000 for training, 10,000 for testing)
- **Frameworks Used**:
  - **NumPy**: All model math (forward pass, loss, gradients, and weight updates); no deep learning library is used for the network itself
  - **PyTorch / torchvision**: Downloading, preprocessing, and batching the dataset only
  - **Matplotlib**: For visualizing training progress
- **Network Architecture**:
  - Input layer: 784 nodes (flattened 28x28 images)
  - Hidden layer 1: 256 nodes, ReLU activation
  - Hidden layer 2: 64 nodes, ReLU activation
  - Output layer: 10 nodes (one for each digit class, 0-9) producing logits; softmax is applied inside the loss and gradient computations
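As a sanity check on the architecture above, the number of trainable parameters follows directly from the layer sizes: each dense layer has `fan_in × fan_out` weights plus `fan_out` biases. A small sketch (the helper `count_params` is hypothetical, not part of the project):

```python
# Layer sizes from the architecture above: 784 -> 256 -> 64 -> 10
LAYER_SIZES = [784, 256, 64, 10]


def count_params(sizes):
    """Return (per-layer parameter counts, total) for a stack of dense layers."""
    per_layer = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        # weight matrix (fan_out x fan_in) plus one bias per output unit
        per_layer.append(fan_in * fan_out + fan_out)
    return per_layer, sum(per_layer)


per_layer, total = count_params(LAYER_SIZES)
print(per_layer, total)  # [200960, 16448, 650] 218058
```

Nearly all of the roughly 218k parameters sit in the first layer, which is typical for an MLP fed raw pixels.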

---

## 🛠️ Code Walkthrough

### **1. Data Preparation**

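```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

BATCH_SIZE = 64  # example value (assumption); the walkthrough does not show the project's setting

transform = transforms.Compose([
    transforms.ToTensor(),                       # PIL image -> float tensor in [0, 1], shape (1, 28, 28)
    transforms.Normalize((0.1307,), (0.3081,)),  # MNIST training-set mean and std
    transforms.Lambda(lambda x: x.view(-1))      # flatten to a 784-element vector
])

# Downloads MNIST into ./dataset on the first run
train_dataset = datasets.MNIST(root='dataset', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='dataset', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)
```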

- **Purpose**: Load and preprocess the MNIST dataset.
- **Transformations**:
  - Convert images to tensors.
  - Standardize pixel values with the MNIST training-set mean (0.1307) and standard deviation (0.3081), giving roughly zero mean and unit variance.
  - Flatten the 28x28 images into 1D vectors of size 784.
- **DataLoader**: Splits the data into batches; the training loader also reshuffles it every epoch.
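The transform pipeline is easy to mirror in plain NumPy, which makes the numbers concrete: `ToTensor` scales `uint8` pixels to [0, 1], `Normalize` subtracts 0.1307 and divides by 0.3081, and the `Lambda` flattens each image. This is an illustration only; the function `preprocess` is hypothetical and the project itself uses the torchvision pipeline above.

```python
import numpy as np

MNIST_MEAN, MNIST_STD = 0.1307, 0.3081  # same constants as transforms.Normalize above


def preprocess(images_uint8):
    """NumPy sketch of ToTensor -> Normalize -> flatten for a batch of 28x28 images.

    images_uint8: array of shape (N, 28, 28) with values in [0, 255].
    Returns a float32 array of shape (N, 784).
    """
    x = images_uint8.astype(np.float32) / 255.0  # ToTensor: scale to [0, 1]
    x = (x - MNIST_MEAN) / MNIST_STD             # Normalize: standardize with dataset stats
    return x.reshape(x.shape[0], -1)             # Lambda(view(-1)): flatten 28x28 -> 784


batch = np.zeros((2, 28, 28), dtype=np.uint8)
batch[1] = 255                # second image is all white
out = preprocess(batch)
print(out.shape)              # (2, 784)
print(out[0, 0], out[1, 0])   # about -0.4242 (black pixel) and 2.8215 (white pixel)
```

So after normalization, a black background pixel is slightly negative and a fully inked pixel is close to +2.8.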

---

### **2. Neural Network Implementation**

#### Initialization

```python
import numpy as np


class Scratch:
    def __init__(self):
        # Weights are (out_features, in_features); He scaling uses sqrt(2 / fan_in)
        self.w1 = np.random.randn(256, 784) * np.sqrt(2 / 784)
        self.b1 = np.zeros((1, 256))

        self.w2 = np.random.randn(64, 256) * np.sqrt(2 / 256)
        self.b2 = np.zeros((1, 64))

        self.w3 = np.random.randn(10, 64) * np.sqrt(2 / 64)
        self.b3 = np.zeros((1, 10))
```

- **Weights and Biases**:
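  - Initialized with **He initialization** (standard normal samples scaled by `sqrt(2 / fan_in)`), which keeps the scale of activations roughly constant from one ReLU layer to the next.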
  - Biases are initialized to zeros.
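To see what He initialization buys, the sketch below pushes unit-variance random inputs through one layer. The helper `he_init` is hypothetical: it mirrors the `np.random.randn(...) * np.sqrt(2 / fan_in)` lines above but uses NumPy's `Generator` API for reproducibility. The pre-activations come out with a second moment of about 2, and after ReLU zeroes the negative half it drops back to about 1, matching the input.

```python
import numpy as np


def he_init(fan_out, fan_in, rng):
    """He (Kaiming) normal init: entries ~ N(0, 2 / fan_in), shaped like self.w1."""
    return rng.standard_normal((fan_out, fan_in)) * np.sqrt(2.0 / fan_in)


rng = np.random.default_rng(0)
w1 = he_init(256, 784, rng)

# Feed unit-variance inputs through one ReLU layer
x = rng.standard_normal((1000, 784))
z1 = x @ w1.T           # pre-activations: E[z^2] ~ 784 * (2 / 784) * 1 = 2
a1 = np.maximum(0, z1)  # ReLU keeps half of the mass: E[a^2] ~ 2 / 2 = 1

print(w1.std())                            # close to sqrt(2 / 784), about 0.0505
print(np.mean(z1 ** 2), np.mean(a1 ** 2))  # roughly 2.0 and 1.0
```

Because the activation scale is preserved, signals neither shrink nor blow up as they pass through the two hidden layers, which keeps early gradients in a usable range.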

#### Forward Pass

```python
def forward(self, x):
    self.x = x  # cache the input; the z and a values below are also kept for backward()
    self.z1 = x @ self.w1.T + self.b1
    self.a1 = self.relu(self.z1)

    self.z2 = self.a1 @ self.w2.T + self.b2
    self.a2 = self.relu(self.z2)

    self.z3 = self.a2 @ self.w3.T + self.b3
    return self.z3
```

- **Layer Outputs**:
  - `z1`, `z2`, `z3`: Linear transformations for each layer.
  - `a1`, `a2`: Activations after applying ReLU.
  - The final output (`z3`) holds the logits: unnormalized log-probabilities that softmax turns into class probabilities.
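Weights are stored as `(out_features, in_features)`, so each layer multiplies by the transpose. The standalone sketch below traces the shapes for one batch; it uses random inputs, a hypothetical batch size of 32, and freshly initialized parameters rather than the project's `Scratch` instance.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 32  # hypothetical batch size

# Parameters shaped like Scratch.__init__ (biases are single rows)
w1, b1 = rng.standard_normal((256, 784)) * np.sqrt(2 / 784), np.zeros((1, 256))
w2, b2 = rng.standard_normal((64, 256)) * np.sqrt(2 / 256), np.zeros((1, 64))
w3, b3 = rng.standard_normal((10, 64)) * np.sqrt(2 / 64), np.zeros((1, 10))

x = rng.standard_normal((N, 784))  # a batch of flattened images
z1 = x @ w1.T + b1                 # (N, 784) @ (784, 256) -> (N, 256)
a1 = np.maximum(0, z1)
z2 = a1 @ w2.T + b2                # (N, 256) @ (256, 64) -> (N, 64)
a2 = np.maximum(0, z2)
z3 = a2 @ w3.T + b3                # (N, 64) @ (64, 10) -> (N, 10) logits

print(z1.shape, z2.shape, z3.shape)  # (32, 256) (32, 64) (32, 10)
```

The bias arrays have shape `(1, k)`, so NumPy broadcasting adds the same bias row to every sample in the batch.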

#### Activation Functions

```python
@staticmethod
def relu(z):
    return np.maximum(0, z)

@staticmethod
def relu_derivative(z):
    return (z > 0).astype(float)
```

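- **ReLU**: Activation for both hidden layers. `relu_derivative` returns 1 where `z > 0` and 0 elsewhere (including at `z = 0`), which backpropagation uses as a mask.
- **Softmax**: Converts logits into probabilities; `backward` calls it as `self.softmax` (its definition is not shown here).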
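Since the project's `softmax` is not shown, here is a minimal numerically stable version as an assumption of what it might look like (written as a plain function; in the class it would presumably be a `@staticmethod` like `relu`):

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax for logits of shape (N, C).

    Subtracting each row's max leaves the result unchanged (softmax is
    shift-invariant) but keeps np.exp from overflowing.
    """
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=1, keepdims=True)


probs = softmax(np.array([[1.0, 2.0, 3.0],
                          [1000.0, 1000.0, 1000.0]]))
print(probs.sum(axis=1))  # [1. 1.]
print(probs[1])           # [0.33333333 0.33333333 0.33333333]
```

The second row would overflow to `inf / inf = nan` without the max subtraction; with it, the row is computed exactly.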

#### Loss Function

```python
@staticmethod
def cross_entropy_loss(logits, labels):
    logits_stable = logits - np.max(logits, axis=1, keepdims=True)
    log_sum_exp = np.log(np.sum(np.exp(logits_stable), axis=1, keepdims=True))
    loss_per_sample = -np.sum(labels * (logits_stable - log_sum_exp), axis=1)
    return np.mean(loss_per_sample)
```

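- **Purpose**: Compute the mean cross-entropy between the logits and the one-hot labels over the batch.
- **Stability**: Subtracts each row's maximum logit before exponentiating (the log-sum-exp trick), so `np.exp` cannot overflow; the shift cancels out mathematically.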
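In practice the trick matters once a logit grows past about 709, where `np.exp` overflows in float64. The sketch below contrasts a naive implementation (hypothetical `naive_cross_entropy`, for comparison only) with the stable formulation from the method above on a single extreme example:

```python
import numpy as np


def cross_entropy_loss(logits, labels):
    """Mean cross-entropy from raw logits and one-hot labels (same math as above)."""
    logits_stable = logits - np.max(logits, axis=1, keepdims=True)
    log_sum_exp = np.log(np.sum(np.exp(logits_stable), axis=1, keepdims=True))
    # logits_stable - log_sum_exp is log-softmax; the one-hot mask picks the true class
    return np.mean(-np.sum(labels * (logits_stable - log_sum_exp), axis=1))


def naive_cross_entropy(logits, labels):
    """Exponentiates raw logits first; overflows once a logit exceeds ~709 in float64."""
    probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
    return np.mean(-np.sum(labels * np.log(probs), axis=1))


logits = np.array([[1000.0, 0.0]])
labels = np.array([[0.0, 1.0]])  # the true class has the *smaller* logit

with np.errstate(all="ignore"):
    print(naive_cross_entropy(logits, labels))  # nan (inf / inf)
print(cross_entropy_loss(logits, labels))       # 1000.0
```

The stable version returns the correct, very large loss for a confidently wrong prediction instead of `nan`, which would otherwise propagate into every gradient.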

#### Backpropagation

```python
def backward(self, logits, labels, learning_rate=0.01):
    N = labels.shape[0]
    probs = self.softmax(logits)
    dL_dz3 = (probs - labels) / N  # gradient of mean softmax cross-entropy w.r.t. the logits

    dL_dw3 = dL_dz3.T @ self.a2
    dL_db3 = np.sum(dL_dz3, axis=0, keepdims=True)

    dL_da2 = dL_dz3 @ self.w3
    dL_dz2 = dL_da2 * self.relu_derivative(self.z2)

    dL_dw2 = dL_dz2.T @ self.a1
    dL_db2 = np.sum(dL_dz2, axis=0, keepdims=True)

    dL_da1 = dL_dz2 @ self.w2
    dL_dz1 = dL_da1 * self.relu_derivative(self.z1)

    dL_dw1 = dL_dz1.T @ self.x
    dL_db1 = np.sum(dL_dz1, axis=0, keepdims=True)

    # Plain gradient-descent step for every parameter
    self.w3 -= learning_rate * dL_dw3
    self.b3 -= learning_rate * dL_db3
    self.w2 -= learning_rate * dL_dw2
    self.b2 -= learning_rate * dL_db2
    self.w1 -= learning_rate * dL_dw1
    self.b1 -= learning_rate * dL_db1
```

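- **Steps**:
  - Apply the chain rule layer by layer, from the logits back to the input, reusing the values cached by `forward`.
  - Update every weight and bias in place with plain gradient descent (`param -= learning_rate * grad`).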
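The line `dL_dz3 = (probs - labels) / N` is the standard closed form for the gradient of mean softmax cross-entropy with respect to the logits; the `/ N` comes from averaging the loss over the batch. A finite-difference check is a cheap way to confirm a formula like this. The sketch below (all function names hypothetical) compares the closed form with central differences on random logits:

```python
import numpy as np


def mean_cross_entropy(logits, labels):
    """Mean cross-entropy over the batch (stable log-softmax form)."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -np.mean(np.sum(labels * log_probs, axis=1))


def analytic_grad(logits, labels):
    """The closed form used in backward(): (softmax(logits) - labels) / N."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
    return (probs - labels) / logits.shape[0]


def numeric_grad(f, x, eps=1e-6):
    """Central finite differences, perturbing one entry of x at a time."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        orig = x[idx]
        x[idx] = orig + eps
        plus = f(x)
        x[idx] = orig - eps
        minus = f(x)
        x[idx] = orig  # restore the entry
        grad[idx] = (plus - minus) / (2 * eps)
    return grad


rng = np.random.default_rng(0)
logits = rng.standard_normal((4, 10))
labels = np.eye(10)[rng.integers(0, 10, size=4)]  # random one-hot targets

num = numeric_grad(lambda z: mean_cross_entropy(z, labels), logits.copy())
ana = analytic_grad(logits, labels)
print(np.max(np.abs(num - ana)))  # should be tiny
```

The same approach extends to the weight gradients, although `backward` as written applies each update immediately, so the gradients would have to be returned instead of applied before they could be compared.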

---

### **3. Training Loop**

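```python
EPOCHS = 10           # example value (assumption); not specified in this walkthrough
LEARNING_RATE = 0.01  # example value (assumption); matches backward()'s default

model = Scratch()
train_losses, train_accuracies = [], []

for epoch in range(EPOCHS):
    epoch_loss = 0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images = images.numpy()  # (N, 784) float32
        # One-hot encode the integer labels: (N,) -> (N, 10)
        labels_onehot = np.zeros((labels.shape[0], 10))
        labels_onehot[np.arange(labels.shape[0]), labels.numpy()] = 1

        logits = model.forward(images)
        loss = model.cross_entropy_loss(logits, labels_onehot)
        model.backward(logits, labels_onehot, learning_rate=LEARNING_RATE)  # also updates weights

        epoch_loss += loss * labels.shape[0]  # loss is a batch mean; weight it by batch size
        preds = np.argmax(logits, axis=1)
        correct += (preds == labels.numpy()).sum()
        total += labels.shape[0]

    avg_loss = epoch_loss / total
    acc = correct / total
    train_losses.append(avg_loss)
    train_accuracies.append(acc)
    print(f"Epoch {epoch + 1}/{EPOCHS}: loss={avg_loss:.4f}, accuracy={acc:.4f}")
```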

- **Steps**:
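  - For each batch: run the forward pass, compute the loss, and call `backward`, which computes the gradients and updates the weights in one step.
  - Accumulate epoch-level metrics: each batch's mean loss is multiplied by its size so any smaller final batch is weighted correctly, and accuracy is measured on the logits computed before that batch's update.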
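`test_loader` is created during data preparation but is not used in the training loop shown above. One way to measure generalization is a forward-only pass over it after training. The sketch below is a hypothetical `evaluate` function that assumes the model exposes `forward` and `cross_entropy_loss` as in `Scratch`:

```python
import numpy as np


def evaluate(model, loader):
    """Return (mean loss, accuracy) over a loader of (images, labels) batches.

    Only forward() is called, never backward(), so the weights are not changed.
    """
    total_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images = np.asarray(images)  # CPU torch tensors convert to NumPy arrays here
        labels = np.asarray(labels)
        onehot = np.eye(10)[labels]  # (N,) integer labels -> (N, 10) one-hot
        logits = model.forward(images)
        total_loss += model.cross_entropy_loss(logits, onehot) * labels.shape[0]
        correct += int((np.argmax(logits, axis=1) == labels).sum())
        total += labels.shape[0]
    return total_loss / total, correct / total
```

Usage, after the training loop: `test_loss, test_acc = evaluate(model, test_loader)`. Comparing these numbers with the final training metrics shows how much the model overfits.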

---

### **4. Visualization**

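```python
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Called at the end of each epoch, inside the training loop (Jupyter notebook)
clear_output(wait=True)  # clear the old figure only once new output is ready, avoiding flicker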
plt.figure(figsize=(10,4))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label="Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss")
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(train_accuracies, label="Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Training Accuracy")
plt.legend()
plt.tight_layout()
plt.show()
```

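- **Real-time Updates**:
  - Redraws the training loss and accuracy after each epoch; in a notebook, `clear_output(wait=True)` replaces the previous figure instead of stacking a new one below it.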
  - Loss and accuracy are displayed side-by-side for easy comparison.

---

## 🎯 Results

- **Final Output**:
  - Training loss and accuracy are printed after each epoch.
  - Training progress is visualized with live plots.

---

## ⚠️ Notes & Recommendations

- **One-hot Encoding**: Labels are converted to one-hot vectors for cross-entropy.
- **Numerical Stability**: Log-sum-exp trick ensures stable loss computation.
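- **Scalability**: Everything runs on the CPU with plain NumPy and vanilla mini-batch gradient descent (no momentum, learning-rate schedule, or GPU support), so the implementation favors readability over speed.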
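The one-hot conversion from the training loop can also be written by indexing rows of an identity matrix; both forms below build the same `(N, 10)` array:

```python
import numpy as np

labels = np.array([3, 0, 9])

# Index-assignment form used in the training loop
onehot_a = np.zeros((labels.shape[0], 10))
onehot_a[np.arange(labels.shape[0]), labels] = 1

# Equivalent one-liner: select rows of the 10x10 identity matrix
onehot_b = np.eye(10)[labels]

print(np.array_equal(onehot_a, onehot_b))  # True
print(onehot_b[0])  # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
```

The identity-matrix form is shorter; the index-assignment form avoids materializing the 10x10 matrix, which makes no practical difference at this size.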
