In [1]:
# Import, device config and dataset - MNIST
# Data loader, transformation
# NN, activation -> class, __init__
# Loss and optimizer
# Training loop
# Model evaluate

import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as func
from tqdm import tqdm

In [2]:
# Device configuration: use the GPU whenever PyTorch can see one
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Data - MNIST
# Root '' means torchvision resolves its MNIST folder relative to the current
# working directory. Downloading the training split fetches the raw files that
# the test split then reads.
to_tensor = transforms.ToTensor()  # stateless, so one instance serves both splits
train_dataset = torchvision.datasets.MNIST('', train=True, transform=to_tensor, download=True)
test_dataset = torchvision.datasets.MNIST('', train=False, transform=to_tensor)

In [3]:
# Hyperparameters (kept together so they can be tuned in one place)
seed = 0
batch_size = 10
input_size = 28*28   # each MNIST image is flattened to this many features
no_classes = 10
no_epochs = 1
learning_rate = 0.001

# Seed torch's global RNG before anything random happens. DataLoader shuffling
# and the model's weight initialisation (next cell) both draw from it, so this
# makes results repeatable under Restart Kernel -> Run All.
torch.manual_seed(seed)

# Data loader
train = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [4]:
# NN
class ConvNet(nn.Module):
    """Fully-connected classifier: three hidden ReLU layers followed by a linear output.

    Despite its name this model contains no convolutions; it is a plain MLP.
    The name is kept so existing ``ConvNet()`` calls keep working.

    Args:
        in_features: number of input features per sample. When omitted, the
            notebook-level ``input_size`` is used, which matches the original behaviour.
        num_classes: number of output logits. When omitted, the notebook-level
            ``no_classes`` is used, which matches the original behaviour.
        hidden_size: width of each of the three hidden layers (default 64).
    """

    def __init__(self, in_features=None, num_classes=None, hidden_size=64):
        super().__init__()
        # Fall back to the config-cell globals so that ConvNet() is unchanged.
        if in_features is None:
            in_features = input_size
        if num_classes is None:
            num_classes = no_classes
        # Attribute names fc1..fc4 are kept so state_dict keys stay the same.
        self.fc1 = nn.Linear(in_features, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return unnormalised class logits for ``x``.

        Inputs with more than two dimensions (for example an un-flattened image
        batch) are flattened from dim 1 onward. 1-D and 2-D inputs are passed
        through unchanged, as before.
        """
        if x.dim() > 2:
            x = torch.flatten(x, start_dim=1)
        x = func.relu(self.fc1(x))
        x = func.relu(self.fc2(x))
        x = func.relu(self.fc3(x))
        # No activation on the output: the loss expects raw logits.
        return self.fc4(x)

# Build the classifier, then move its parameters to the chosen device
# (nn.Module.to moves the parameters and returns the same module).
model = ConvNet()
model = model.to(device)

# Loss and Optimizer
# CrossEntropyLoss takes raw logits; ConvNet.forward returns fc4's output with
# no activation, so the two line up.
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [5]:
# Training loop
# Switch to training mode explicitly. ConvNet's Linear/ReLU layers do not change
# behaviour between modes, but this stays correct if an eval cell ran first.
model.train()
num_batches = max(len(train), 1)
for epoch in range(no_epochs):
    # Accumulate the loss on-device to avoid a host sync (.item()) every step
    epoch_loss = torch.zeros((), device=device)
    for image, label in tqdm(train, desc=f"epoch {epoch + 1}/{no_epochs}"):
        # Flatten each image to a vector of input_size features
        image = image.view(-1, input_size).to(device)
        label = label.to(device)

        # Forward
        output = model(image)
        loss = loss_func(output, label)

        # Backward and optimizer
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Detach so the accumulator does not keep the autograd graph alive
        epoch_loss += loss.detach()

    # Mean of the per-batch losses, so divergence is visible
    print(f"epoch {epoch + 1}/{no_epochs} - mean train loss: {epoch_loss.item() / num_batches:.4f}")
        


100%|██████████| 6000/6000 [01:41<00:00, 59.14it/s]


In [6]:
# Test/Model evaluate
# Switch to inference mode. ConvNet's Linear/ReLU layers do not change behaviour,
# but this is the correct default for evaluation.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for image, label in tqdm(test):
        image = image.view(-1, input_size).to(device)
        label = label.to(device)
        output = model(image)

        # The predicted class is the index of the largest logit in each row.
        # Comparing the whole batch at once replaces the per-sample Python loop,
        # which forced one device->host sync per image. argmax picks the first
        # maximum on ties, same as before.
        predictions = torch.argmax(output, dim=1)
        correct += (predictions == label).sum().item()
        total += label.size(0)

# Report NaN rather than raising ZeroDivisionError if the loader is empty
accuracy = correct / total if total else float('nan')
print("Accuracy: ", round(accuracy, 3))

100%|██████████| 1000/1000 [00:11<00:00, 89.20it/s]

Accuracy:  0.949



