In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Custom Loss Function (if needed, otherwise use CrossEntropyLoss)
class CustomLoss(nn.Module):
    def __init__(self):
        super(CustomLoss, self).__init__()
        self.mse = nn.MSELoss()
        self.mae = nn.L1Loss()

    def forward(self, outputs, targets):
        mse_loss = self.mse(outputs, targets)
        mae_loss = self.mae(outputs, targets)
        return mse_loss + mae_loss

# Simple Neural Network Model
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)  # Output layer should match the number of classes

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Hyperparameters
batch_size = 64
learning_rate = 0.001
epochs = 2

# Data loader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = dsets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Initialize model, loss function, and optimizer
model = SimpleNN()
criterion = nn.CrossEntropyLoss()  # Use CrossEntropyLoss for classification
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

print("Training complete.")

Epoch [1/2], Step [100/938], Loss: 0.5640
Epoch [1/2], Step [200/938], Loss: 0.4161
Epoch [1/2], Step [300/938], Loss: 0.2999
Epoch [1/2], Step [400/938], Loss: 0.2786
Epoch [1/2], Step [500/938], Loss: 0.4068
Epoch [1/2], Step [600/938], Loss: 0.2360
Epoch [1/2], Step [700/938], Loss: 0.2490
Epoch [1/2], Step [800/938], Loss: 0.2735
Epoch [1/2], Step [900/938], Loss: 0.3916
Epoch [2/2], Step [100/938], Loss: 0.2188
Epoch [2/2], Step [200/938], Loss: 0.1227
Epoch [2/2], Step [300/938], Loss: 0.1359
Epoch [2/2], Step [400/938], Loss: 0.1240
Epoch [2/2], Step [500/938], Loss: 0.1553
Epoch [2/2], Step [600/938], Loss: 0.1360
Epoch [2/2], Step [700/938], Loss: 0.1181
Epoch [2/2], Step [800/938], Loss: 0.0737
Epoch [2/2], Step [900/938], Loss: 0.1546
Training complete.
