# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


# --- Training configuration -------------------------------------------------
n_epochs = 10             # number of train/evaluate cycles run by the driver
batch_size_train = 64     # samples per optimisation step
batch_size_test = 1_000   # samples per evaluation batch
learning_rate = 1e-2      # SGD step size
momentum = 0.5            # SGD momentum factor
log_interval = 100        # print training progress every this many batches


# Load the training split without any transform; it is only used here to
# measure the pixel statistics that the normalisation step needs.
pre_train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=None)

# Rescale the raw pixel tensor by 255 before measuring, so the statistics are
# expressed on the 0-1 scale rather than on raw intensities.
scaled_data = pre_train_dataset.data.to(torch.float32).div(255.0)
data_mean = scaled_data.mean()
data_std = scaled_data.std()
print(f"Dynamically Calculated Mean: {data_mean.item():.4f}")
print(f"Dynamically Calculated Std: {data_std.item():.4f}")


# Preprocessing shared by the train and test splits: convert each image to a
# float tensor, then standardise it with the statistics measured on the
# training images (data_mean / data_std).
transform_with_norm = transforms.Compose([
    transforms.ToTensor(),
    # Normalize expects plain per-channel numbers. Unwrap the 0-dim tensors
    # once here instead of handing it tensors that would have to be
    # re-converted for every sample.
    transforms.Normalize((data_mean.item(),), (data_std.item(),)),
])

# The files were already fetched when the statistics were computed, so no
# further download is requested for either split.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=False,
    transform=transform_with_norm,
)

test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    download=False,
    transform=transform_with_norm,
)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size_train,
    shuffle=True,
)

# Evaluation only aggregates over the whole split, so the order of the samples
# does not affect the reported metrics. Keep it fixed: per-batch results stay
# reproducible and evaluation does not consume random-number state.
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size_test,
    shuffle=False,
)


class GNNModel(nn.Module):
    """Small convolutional classifier that outputs log-probabilities for 10 classes.

    Two 5x5 convolutions, each followed by 2x2 max-pooling and a ReLU, feed a
    two-layer fully connected head. Channel-wise dropout is applied after the
    second convolution and element-wise dropout after the first fully
    connected layer. Note that, despite the class name, the network contains
    only convolutional and linear layers.
    """

    def __init__(self):
        super().__init__()

        # Convolutional feature extractor: 1 -> 10 -> 20 channels.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)

        # Channel-wise dropout applied to the output of conv2.
        self.conv2_drop = nn.Dropout2d()

        # Classifier head: 320 flattened features -> 50 hidden units -> 10 scores.
        self.fc1 = nn.Linear(in_features=320, out_features=50)
        self.fc2 = nn.Linear(in_features=50, out_features=10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        The shape notes below assume ``x`` is ``[batch_size, 1, 28, 28]``.
        """
        # Stage 1: conv1 -> 2x2 max-pool -> ReLU   => [batch_size, 10, 12, 12]
        h = self.conv1(x)
        h = F.max_pool2d(h, kernel_size=2)
        h = F.relu(h)

        # Stage 2: conv2 -> channel dropout -> 2x2 max-pool -> ReLU
        #                                          => [batch_size, 20, 4, 4]
        h = self.conv2(h)
        h = self.conv2_drop(h)
        h = F.max_pool2d(h, kernel_size=2)
        h = F.relu(h)

        # Flatten the feature maps for the linear layers => [batch_size, 320]
        h = h.view(-1, 320)

        # Hidden layer with ReLU                         => [batch_size, 50]
        h = self.fc1(h)
        h = F.relu(h)

        # Element-wise dropout; a no-op unless the module is in training mode.
        h = F.dropout(h, training=self.training)

        # Output scores                                  => [batch_size, 10]
        scores = self.fc2(h)

        # Log-probabilities, the input format NLLLoss expects.
        return F.log_softmax(scores, dim=1)


# Build the network and place its parameters on the selected device.
model = GNNModel()
model = model.to(device)

# SGD with momentum over all model parameters.
optimizer = optim.SGD(params=model.parameters(), lr=learning_rate, momentum=momentum)

# The model emits log-probabilities, so the matching loss is negative log-likelihood.
criterion = nn.NLLLoss()


def train(epoch):
    """Run one optimisation pass over ``train_loader`` with the global ``model``.

    ``epoch`` is used only to label the progress lines, which are printed for
    every ``log_interval``-th batch (starting with the first).
    """
    model.train()

    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Standard update: clear old gradients, forward, backward, step.
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

        # Only every log_interval-th batch is reported.
        if step % log_interval != 0:
            continue

        seen = step * len(inputs)
        total = len(train_loader.dataset)
        progress = 100. * step / len(train_loader)
        print(f'Train Epoch: {epoch} [{seen}/{total} '
              f'({progress:.0f}%)]\tLoss: {loss.item():.6f}')


def test():
    """Evaluate the global ``model`` on ``test_loader`` and print a summary.

    Switches the model to evaluation mode, runs every test batch with
    gradient tracking disabled, and prints the average per-sample loss
    together with the number and percentage of correct predictions.
    Returns nothing.
    """
    model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            output = model(data)

            # criterion (nn.NLLLoss with its default reduction) returns the
            # mean over the batch. Weight it by the batch size so the running
            # total is a sum over samples; dividing by the dataset size below
            # then gives a true per-sample average.
            test_loss += criterion(output, target).item() * data.size(0)

            # Predicted class = index of the largest log-probability.
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)

    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({accuracy:.2f}%)\n')


if __name__ == '__main__':
    # Evaluate once before any training to get a baseline for the
    # freshly initialised model.
    print("Initial test:")
    test()
    
    # Alternate one training pass and one evaluation pass per epoch.
    # Epochs are numbered from 1 so the progress output reads naturally.
    for epoch in range(1, n_epochs + 1):
        train(epoch)
        test()
