In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

# Define a simple feedforward neural network
class SimpleNN(nn.Module):
    """Two-layer MLP classifier: Linear -> ReLU -> Linear -> log_softmax.

    Args:
        in_features: size of each flattened input vector (default 784 = 28 * 28).
        hidden_features: width of the hidden layer (default 128).
        num_classes: number of output classes (default 10).

    ``forward`` returns log-probabilities (not raw logits), so the matching
    training criterion is ``nn.NLLLoss``; ``nn.CrossEntropyLoss`` expects raw
    logits and would apply log_softmax a second time.
    """

    def __init__(self, in_features=784, hidden_features=128, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)  # input -> hidden layer
        self.fc2 = nn.Linear(hidden_features, num_classes)  # hidden layer -> one score per class

    def forward(self, x):
        """Return per-class log-probabilities for a batch ``x``.

        ``x`` may already be flat, e.g. ``(batch, in_features)``, or an
        unflattened image batch such as ``(batch, 1, 28, 28)``; the latter is
        flattened here so callers need not ``view`` it themselves. For a 2-D
        input the result has shape ``(batch, num_classes)``.
        """
        # Flatten only inputs that fc1 could not consume as-is (rank > 2 with a
        # trailing size that does not match fc1). Every input the layer already
        # accepted passes through unchanged.
        if x.dim() > 2 and x.shape[-1] != self.fc1.in_features:
            x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))  # hidden activations
        x = self.fc2(x)  # raw class scores
        # Normalize the class scores (dim=1) into log-probabilities.
        return F.log_softmax(x, dim=1)

# Load the MNIST dataset.
# ToTensor() turns each image into a float tensor. Normalize((0.5,), (0.5,)) then
# computes (x - 0.5) / 0.5 per channel; assuming ToTensor yields values in
# [0, 1], pixels end up in [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Training split, stored under ./data (download=True fetches it if absent),
# served in shuffled mini-batches of 64.
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=64,
    shuffle=True,
)

# Test split with the same preprocessing, served in a fixed order (no shuffle).
testset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=64,
    shuffle=False,
)

# Initialize the model, loss function, and optimizer
model = SimpleNN()  # create an instance of the SimpleNN class
# SimpleNN.forward already ends in log_softmax, i.e. it outputs log-probabilities.
# NLLLoss is the criterion that consumes log-probabilities directly.
# CrossEntropyLoss would apply log_softmax a second time: redundant work that
# only happens to give the same value because log_softmax is idempotent.
criterion = nn.NLLLoss()
# Stochastic Gradient Descent with learning rate 0.01 and momentum 0.9.
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Training loop
# Training mode. SimpleNN's Linear layers behave the same in train and eval
# mode; setting it explicitly keeps the loop correct if dropout/batch-norm is added.
model.train()
for epoch in range(2):  # train for 2 epochs
    running_loss = 0.0  # sum of batch losses since the last report
    for i, (inputs, labels) in enumerate(trainloader):
        optimizer.zero_grad()  # clear gradients left over from the previous step
        outputs = model(inputs.view(-1, 784))  # flatten each image to a 784-vector
        loss = criterion(outputs, labels)  # mean negative log-likelihood of the true labels
        loss.backward()  # backpropagate
        optimizer.step()  # update the parameters
        running_loss += loss.item()  # accumulate as a Python float (no graph kept alive)
        if i % 100 == 99:  # report every 100 mini-batches
            print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 100:.3f}")  # average over the last 100 batches
            running_loss = 0.0  # start a new reporting window

# Test the model
# Evaluation mode. SimpleNN's Linear layers are unaffected, but this matters as
# soon as the model gains dropout or batch-norm.
model.eval()
correct = 0  # number of correct predictions
total = 0  # number of test examples evaluated
with torch.no_grad():  # no autograd graph is needed for evaluation
    for images, labels in testloader:
        outputs = model(images.view(-1, 784))  # flatten each image to a 784-vector
        # Predicted class = index of the highest log-probability in each row.
        # argmax returns the first index on ties, like torch.max(..., 1).
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()  # .item() converts the count tensor to a Python int

# Print the accuracy. Guard against an empty test loader, which would otherwise
# raise ZeroDivisionError.
if total:
    print(f"Accuracy of the network on the {total} test images: {100 * correct / total:.2f}%")
else:
    print("No test images were evaluated; accuracy is undefined.")