In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torchmetrics.functional as metrics

class ThreeLayerNet(nn.Module):
    """Fully connected classifier: input -> hidden1 -> hidden2 -> output.

    Both hidden layers use a sigmoid activation. ``forward`` returns raw,
    unnormalized class scores (logits), which is what ``nn.CrossEntropyLoss``
    expects. When probabilities are needed, apply ``self.softmax`` (softmax
    over dim 1) to the output. The predicted class (argmax over dim 1) is the
    same for logits and probabilities.

    Args:
        input_dim: number of input features per sample.
        output_dim: number of output classes.
        hidden_dims: sizes of the two hidden layers (default ``(300, 200)``).
        init_range: every weight matrix is re-initialized from
            U(-init_range, init_range) (default 1.0). Biases keep PyTorch's
            default ``nn.Linear`` initialization.
    """

    def __init__(self, input_dim, output_dim, hidden_dims=(300, 200), init_range=1.0):
        super().__init__()
        hidden1, hidden2 = hidden_dims
        self.fc1 = nn.Linear(input_dim, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, output_dim)
        self.sigmoid = nn.Sigmoid()
        # Kept so callers can convert logits to probabilities; it is
        # intentionally NOT applied inside forward().
        self.softmax = nn.Softmax(dim=1)

        # Uniform random initialization of the weights (same layer order as
        # before, so a fixed seed yields the same parameters).
        for layer in (self.fc1, self.fc2, self.fc3):
            nn.init.uniform_(layer.weight, -init_range, init_range)

    def forward(self, x):
        """Map inputs to logits.

        Assumes ``x`` is (batch, input_dim); returns (batch, output_dim) logits.
        """
        x = self.sigmoid(self.fc1(x))
        x = self.sigmoid(self.fc2(x))
        # No softmax here: nn.CrossEntropyLoss applies log-softmax itself, so
        # returning probabilities would apply softmax twice and shrink the
        # gradients that reach the network.
        return self.fc3(x)

# ---- Hyper-parameters -------------------------------------------------------
input_dim = 28 * 28  # a flattened 28x28 MNIST image
output_dim = 10      # one output per digit class
learning_rate = 0.1
batch_size = 32
epochs = 100

# ---- Data -------------------------------------------------------------------
# Convert images to tensors, then normalize with mean 0.1307 and std 0.3081
# (presumably the MNIST training-set pixel statistics).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Build the training split first, then the test split (downloaded if missing).
train_set, test_set = (
    datasets.MNIST('mnist_data/', train=split_is_train, download=True, transform=transform)
    for split_is_train in (True, False)
)

# Shuffle only the training data; keep the evaluation order fixed.
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, batch_size=batch_size)

model = ThreeLayerNet(input_dim, output_dim)
# nn.CrossEntropyLoss applies log-softmax internally, so it must be fed raw
# logits rather than probabilities.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

learn_curve = []  # mean per-batch training accuracy for each epoch (floats)
err_curve = []    # mean per-batch training loss for each epoch (floats)

for epoch in range(epochs):
    model.train()
    err = 0.0
    train_acc = 0.0
    batch_num = 0
    for images, labels in train_loader:
        # Flatten each image into a vector of input_dim features.
        images = images.view(-1, input_dim)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # detach(): the metric does not need the autograd graph.
        acc = metrics.accuracy(outputs.detach(), labels, task='multiclass', num_classes=output_dim)

        # .item() accumulates plain Python floats, so the curves below hold
        # numbers instead of tensors.
        err += loss.item()
        train_acc += acc.item()
        batch_num += 1

    mean_loss = err / batch_num
    mean_acc = train_acc / batch_num
    print(f"Epoch {epoch+1}/{epochs}, train loss: {mean_loss:.3f}, train accuracy: {mean_acc:.3f}")
    learn_curve.append(mean_acc)
    err_curve.append(mean_loss)

model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.view(-1, input_dim)
        predicted = model(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Test Error: {} %'.format(100 - 100 * correct / total))

In [None]:
import matplotlib.pyplot as plt

# learn_curve holds one training-accuracy value per epoch. float() accepts
# both Python floats and 0-d tensors. The training loop prints epochs
# starting at 1, so the x-axis uses the same numbering.
accuracy_values = [float(v) for v in learn_curve]
plt.figure()  # draw on a fresh figure rather than whatever is current
plt.title('Accuracy vs Epoch')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.plot(range(1, len(accuracy_values) + 1), accuracy_values)
plt.show()

In [None]:
import matplotlib.pyplot as plt

# err_curve holds one mean training-loss value per epoch. Epochs are numbered
# from 1 to match the training printout. A new figure keeps this plot from
# being drawn over the accuracy curve when not running in a notebook.
loss_values = [float(v) for v in err_curve]
plt.figure()
plt.title('Loss vs Epoch')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.plot(range(1, len(loss_values) + 1), loss_values)
plt.show()