In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


# Training hyperparameters.
n_epochs = 10
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5

# Seed torch's global RNG so weight initialisation and DataLoader shuffling
# (which draws from the global generator when no explicit generator is given)
# are reproducible under Restart Kernel -> Run All.
random_seed = 1
torch.manual_seed(random_seed)


# Load the raw MNIST training split (no transform) only to derive the
# normalisation statistics from the data instead of hard-coding them.
pre_train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=None)

# Rescale raw pixel values (0-255) into [0, 1], the same range ToTensor() yields later,
# then take the global mean / std over every pixel of every image.
scaled_data = pre_train_dataset.data.float().div(255.0)
data_mean, data_std = scaled_data.mean(), scaled_data.std()

print(f"Dynamically Calculated Mean: {data_mean.item():.4f}")
print(f"Dynamically Calculated Std: {data_std.item():.4f}")


# ToTensor() maps pixels into [0, 1]; Normalize then standardises them with the
# statistics computed from the training split. float() turns the 0-d tensors into
# plain Python numbers, which is what Normalize's mean/std sequences expect.
transform_with_norm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((float(data_mean),), (float(data_std),)),
])

# download=False: the files were already fetched by the statistics cell above.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=False,
    transform=transform_with_norm,
)

test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    download=False,
    transform=transform_with_norm,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size_train,
    shuffle=True,
)

# Evaluation metrics do not depend on sample order, so iterate the test set in a
# fixed order: deterministic runs and no extra draws from the global RNG.
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size_test,
    shuffle=False,
)

In [None]:
import matplotlib.pyplot as plt

# Sanity-check the pipeline on one training batch.
# Image tensor layout: [batch_size, channels, height, width]
images, labels = next(iter(train_loader))

print(type(images))
print(type(labels))
print(images.shape)
print(labels.shape)

# Show the first image of the batch (channel dim squeezed away) with its label.
fig, ax = plt.subplots()
ax.imshow(images[0].squeeze(), cmap='gray')
ax.set_title(f"Label: {labels[0].item()}")
plt.show()

In [None]:
class DNNModel(nn.Module):
    """Fully connected classifier: in_features -> hidden1 -> hidden2 -> num_classes.

    Defaults (784 -> 256 -> 128 -> 10) match flattened 28x28 single-channel
    images with 10 classes. ``forward`` returns log-probabilities
    (``log_softmax`` over dim 1), so it is meant to be paired with ``nn.NLLLoss``.
    """

    def __init__(self, in_features=784, hidden1=256, hidden2=128, num_classes=10):
        super().__init__()

        self.fc1 = nn.Linear(in_features=in_features, out_features=hidden1)
        self.fc2 = nn.Linear(in_features=hidden1, out_features=hidden2)
        self.fc3 = nn.Linear(in_features=hidden2, out_features=num_classes)

    def forward(self, x):
        """Map a batch to per-class log-probabilities.

        Args:
            x: Tensor whose first dim is the batch; all remaining dims are
                flattened and must multiply to ``in_features``.

        Returns:
            Tensor of shape (batch, num_classes) holding log-probabilities.
        """
        # Flatten everything after the batch dim. Unlike view(), flatten() also
        # accepts non-contiguous inputs (it copies only when it has to).
        x = torch.flatten(x, start_dim=1)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # No activation on the output layer: a ReLU here would clamp every
        # negative logit to 0, so the network could never push a class score
        # below the others' baseline.
        logits = self.fc3(x)

        return F.log_softmax(logits, dim=1)

In [None]:
# Model, loss and optimiser. NLLLoss consumes log-probabilities, which is what
# DNNModel.forward returns via log_softmax.
model = DNNModel()
criterion = nn.NLLLoss()
optimizer = optim.SGD(
    params=model.parameters(),
    lr=learning_rate,
    momentum=momentum,
)

def train():
    """Run one epoch of SGD over ``train_loader``.

    Reads the module-level ``model``, ``optimizer``, ``criterion`` and
    ``train_loader``. It updates ``model``'s parameters in place and returns nothing.
    """
    model.train()

    for inputs, targets in train_loader:
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

def test():
    """Evaluate the module-level ``model`` on ``test_loader`` and print the results.

    Returns:
        tuple[float, float]: (average per-sample NLL loss, accuracy in percent).
    """
    model.eval()

    n_samples = len(test_loader.dataset)
    total_loss = 0.0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)

            # Sum (not mean) per batch so that dividing by the dataset size below
            # gives a true per-sample average. Accumulating mean-reduced batch
            # losses and then dividing by n_samples under-reports by ~batch_size.
            total_loss += F.nll_loss(output, target, reduction='sum').item()

            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()

    test_loss = total_loss / n_samples
    accuracy = 100. * correct / n_samples

    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{n_samples} '
          f'({accuracy:.2f}%)\n')

    return test_loss, accuracy


# Alternate one training epoch with a full evaluation pass on the test set.
for _epoch in range(1, n_epochs + 1):
    train()
    test()