In [None]:
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import torch.nn as N
import matplotlib.pyplot as plt

In [None]:
import torch

# Hyperparameters read by the data loaders, the optimizer and the training loop.
learning_rate = 1e-3
epochs = 5
batch_size = 32

# Seed PyTorch's global RNG so repeated runs of the notebook are comparable.
torch.manual_seed(42)

In [None]:
# Options common to both MNIST splits: fetch the files when they are missing
# and convert every image with ToTensor.
mnist_kwargs = dict(download=True, transform=ToTensor())

# Training split.
training_data = datasets.MNIST(root='../pytorch-basics/datasets/mnist/train', train=True, **mnist_kwargs)

# Held-out test split.
test_data = datasets.MNIST(root='../pytorch-basics/datasets/mnist/test', train=False, **mnist_kwargs)

In [None]:
# Training batches are reshuffled on every pass over the data.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)

# Evaluation gains nothing from shuffling; a fixed order keeps the test batches
# (and therefore any per-batch output) identical from run to run.
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [None]:
# Place the model on the GPU explicitly (falling back to CPU when none is
# available). Without this the weights stay on the CPU unless some other call
# happens to move them, and feeding CUDA tensors to the model then fails with
# a device-mismatch error.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Small LeNet-style CNN for 1x28x28 inputs: two conv/tanh/max-pool stages
# followed by a four-layer fully connected classifier producing 10 logits.
model = N.Sequential(
    # 1x28x28 -> 16x28x28 (3x3 conv, padding 1 keeps the spatial size) -> 16x14x14
    N.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),
    N.Tanh(),
    N.MaxPool2d(kernel_size=2, stride=2),

    # 16x14x14 -> 32x14x14 -> 32x7x7
    N.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
    N.Tanh(),
    N.MaxPool2d(kernel_size=2, stride=2),

    # 32x7x7 feature maps -> vector of 32*7*7 = 1568 features
    N.Flatten(),

    N.Linear(in_features=32 * 7 * 7, out_features=500),
    N.Tanh(),

    N.Linear(in_features=500, out_features=100),
    N.Tanh(),

    N.Linear(in_features=100, out_features=50),
    N.Tanh(),

    # Raw class scores (logits); no softmax here because CrossEntropyLoss
    # applies log-softmax internally.
    N.Linear(in_features=50, out_features=10),
).to(device)

In [None]:
from torchinfo import summary

# Columns reported for each layer in the summary table.
summary_columns = ['input_size', 'output_size', 'num_params', 'trainable']

# Summarise the network for one batch of 1x28x28 inputs; the returned object
# is the cell's last expression, so the notebook displays it.
summary(
    model,
    input_size=(batch_size, 1, 28, 28),
    col_names=summary_columns,
    row_settings=['var_names'],
    verbose=0,
)

In [None]:
from torchmetrics import Accuracy

# Follow the device the model's weights actually live on instead of
# hard-coding CUDA, so the metric and the model can never disagree.
device = next(model.parameters()).device

# Cross-entropy over the 10 class logits produced by the model.
loss_fn = N.CrossEntropyLoss()
# Pass the learning rate by keyword so it cannot be confused with another
# positional optimizer argument.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Multiclass accuracy over the 10 digit classes.
accuracy_fn = Accuracy(task='multiclass', num_classes=10).to(device)

In [None]:
# Run everything on the device that holds the model's weights so inputs, model
# and metric always agree (the metric is moved there too, in case it was
# created on a different device).
device = next(model.parameters()).device
accuracy_fn.to(device)

for epoch in range(epochs):
    print(f"\nStarting epoch: {epoch + 1}\n~~~~~~~~~~~")

    # ---- training pass ----
    # Switching to train mode once per epoch is enough; evaluation below
    # switches the model back to eval mode.
    model.train()
    train_loss = 0.0
    for index, (batch_input, batch_target) in enumerate(train_dataloader):
        batch_input, batch_target = batch_input.to(device), batch_target.to(device)

        output = model(batch_input)
        loss = loss_fn(output, batch_target)
        # Accumulate a plain float: adding the tensor itself would keep every
        # batch's autograd history referenced for the whole epoch.
        train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if index % 500 == 0:
            print(f"Batch: {index} Loss:{loss.item():0.5f}")

    # Mean of the per-batch training losses.
    train_loss /= len(train_dataloader)

    # ---- evaluation pass ----
    model.eval()
    # Start each epoch's accuracy from a clean metric state.
    accuracy_fn.reset()
    test_loss, test_samples = 0.0, 0

    with torch.inference_mode():
        for test_input, test_target in test_dataloader:
            test_input, test_target = test_input.to(device), test_target.to(device)

            test_pred = model(test_input)
            batch_count = test_target.size(0)
            # Weight each batch's mean loss by its size so a shorter final
            # batch does not count as much as a full one.
            test_loss += loss_fn(test_pred, test_target).item() * batch_count
            test_samples += batch_count
            # torchmetrics expects (preds, target), in that order.
            accuracy_fn.update(test_pred.argmax(dim=1), test_target)

    test_loss /= test_samples
    # Accuracy over every test sample seen this epoch, not a mean of batch means.
    test_acc = accuracy_fn.compute().item()

    print(f"Train loss: {train_loss:0.5f} Test loss: {test_loss:0.5f} Test acc: {test_acc*100:0.2f}%")

In [None]:
import itertools

# Use the device the model's weights live on rather than hard-coding CUDA.
device = next(model.parameters()).device

# Draw a random batch position as a plain int; islice expects integer indices
# and should not have to rely on a tensor's implicit conversion.
sample_index = torch.randint(0, len(test_dataloader), size=(1,)).item()
# Skip ahead to that batch and take it.
sample_input, sample_target = next(itertools.islice(test_dataloader, sample_index, None))
sample_input, sample_target = sample_input.to(device), sample_target.to(device)

# Predict in eval mode with autograd disabled: this cell only inspects
# outputs, so it should neither depend on the mode a previous cell left the
# model in nor record gradient history.
model.eval()
with torch.inference_mode():
    sample_output = model(sample_input)

print(f"Target: {sample_target} \nOutput: {sample_output.argmax(dim= 1)}")

In [None]:
import os

model_path = '../pytorch-basics/models/lenet-5.pt'

# torch.save does not create missing directories; make sure the target folder
# exists so a finished training run is not lost to a path error.
os.makedirs(os.path.dirname(model_path), exist_ok=True)

# NOTE(review): this pickles the entire module object, so loading it requires
# the same code to be importable. Saving model.state_dict() is the more
# portable option, but the format is left unchanged here so existing loaders
# of this file keep working.
torch.save(model, model_path)