# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 1️⃣ Load the training data (example: MNIST handwritten digits).
# ToTensor turns each image into a float tensor with values in [0, 1];
# Normalize then applies (x - 0.5) / 0.5, rescaling pixels to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Training split, stored under ./data (downloaded there if not already present).
train_data = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)

# Mini-batches of 64 samples, reshuffled each time the loader is iterated.
train_loader = DataLoader(
    dataset=train_data,
    batch_size=64,
    shuffle=True,
)

# 2️⃣ Define a Simple Neural Network Model
class SimpleNN(nn.Module):
    """Fully connected classifier: flattened input -> hidden1 -> hidden2 -> logits.

    Defaults reproduce the original MNIST network (784 -> 128 -> 64 -> 10).
    The parameter names ``fc1``/``fc2``/``fc3`` are unchanged, so existing
    state_dicts still load.
    """

    def __init__(self, in_features=28 * 28, hidden1=128, hidden2=64,
                 num_classes=10):
        """Build the three linear layers.

        Args:
            in_features: Values per flattened input sample
                (784 = 28 * 28 pixels by default).
            hidden1: Width of the first hidden layer.
            hidden2: Width of the second hidden layer.
            num_classes: Number of output logits.
        """
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden1)  # flattened pixels -> hidden1
        self.fc2 = nn.Linear(hidden1, hidden2)      # hidden1 -> hidden2
        self.fc3 = nn.Linear(hidden2, num_classes)  # hidden2 -> class logits

    def forward(self, x):
        """Return raw class logits of shape ``(batch, num_classes)``.

        ``x`` is flattened to ``(-1, in_features)``, so any tensor whose
        element count is a multiple of ``in_features`` is accepted, e.g.
        ``(batch, 1, 28, 28)`` image batches or a single ``(28, 28)`` image.
        """
        # reshape instead of view: view raises on non-contiguous inputs
        # (e.g. transposed or sliced tensors); reshape copies only if needed.
        # The width is read from fc1 rather than a new attribute so models
        # pickled with the older class definition keep working.
        x = x.reshape(-1, self.fc1.in_features)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)  # no activation: raw logits