In [42]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets

from torchvision.transforms import v2
from tqdm import tqdm
from torch.utils.data import DataLoader
from PIL import Image

In [60]:
# Fixed seed so the train/validation split is the same on every
# "Restart Kernel -> Run All"; without it the membership of val_set changes
# between runs and the reported accuracy is not reproducible.
SPLIT_SEED = 42
# Number of samples held out for validation; everything else is used to train.
VAL_SIZE = 5000

# Load MNIST (datasets.MNIST defaults to the training split) and download it
# into ./mnist_dataset/ if it is not there yet. Each image is converted to a
# float32 tensor (ToDtype with scale=True rescales the pixel values) and then
# flattened to a 784-element vector so it can be fed directly to nn.Linear.
dataset = datasets.MNIST(
    root="./mnist_dataset/",
    download=True,
    transform=v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Lambda(lambda x: x.squeeze(0).view(784)),
    ]),
)

# Derive the training size from the dataset length instead of hard-coding it,
# so the two lengths always sum to len(dataset) as random_split requires.
train_set, val_set = torch.utils.data.random_split(
    dataset,
    [len(dataset) - VAL_SIZE, VAL_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [61]:
class DigitNN(nn.Module):
    """Fully connected classifier with two hidden ReLU layers.

    Args:
        input_size: number of features in each input vector.
        size_hidd_1: width of the first hidden layer.
        size_hidd_2: width of the second hidden layer.
        output_size: number of output scores produced per sample.

    The forward pass returns the raw output of the last linear layer
    (no softmax is applied here).
    """

    def __init__(self, input_size, size_hidd_1, size_hidd_2, output_size):
        super().__init__()
        # The attribute names below become the state_dict keys, so they are
        # kept as lin1 / lin2 / lin3. nn.Linear uses a bias term by default.
        self.lin1 = nn.Linear(input_size, size_hidd_1)
        self.lin2 = nn.Linear(size_hidd_1, size_hidd_2)
        self.lin3 = nn.Linear(size_hidd_2, output_size)

    def forward(self, x):
        """Map a batch of input vectors to per-class scores."""
        hidden = F.relu(self.lin1(x))
        hidden = F.relu(self.lin2(hidden))
        return self.lin3(hidden)
        
        

In [62]:
# 784 flattened pixels in, two hidden layers (256 and 64 units), 10 outputs.
model = DigitNN(
    input_size=28 * 28,
    size_hidd_1=256,
    size_hidd_2=64,
    output_size=10,
)

In [63]:
# Adam over every parameter of the model, with a learning rate of 0.01.
optimizer = optim.Adam(model.parameters(), lr=1e-2)

In [64]:
# Cross-entropy over raw model outputs and integer class labels, averaged
# over the batch ("mean" is the default reduction, spelled out for clarity).
loss_function = nn.CrossEntropyLoss(reduction="mean")

In [65]:
# Mini-batches of 32 training samples, reshuffled at the start of each epoch.
loader = DataLoader(dataset=train_set, batch_size=32, shuffle=True)

In [66]:
# Train the model for a fixed number of passes over the training loader,
# showing the running mean of the batch loss in the progress bar.
epochs = 2

model.train()
for epoch in range(epochs):
    tqdm_loader = tqdm(loader, leave=True)
    # Running mean of the loss and number of batches seen, reset every epoch.
    loss_mean, lm_count = 0, 0
    for lm_count, (x_train, y_train) in enumerate(tqdm_loader, start=1):
        # Clear gradients left over from the previous step before computing
        # new ones for this batch.
        optimizer.zero_grad()
        predict = model(x_train)
        loss = loss_function(predict, y_train)
        loss.backward()
        optimizer.step()

        # Incremental mean: new_mean = w * value + (1 - w) * old_mean, w = 1/n.
        weight = 1 / lm_count
        loss_mean = weight * loss.item() + (1 - weight) * loss_mean
        tqdm_loader.set_description(f"Epoch [{epoch+1}/{epochs}], loss_mean={loss_mean:.3f}")
        

Epoch [1/2], loss_mean=0.280: 100%|█████████| 1719/1719 [00:17<00:00, 99.17it/s]
Epoch [2/2], loss_mean=0.176: 100%|█████████| 1719/1719 [00:18<00:00, 94.13it/s]


In [67]:
# Measure classification accuracy on the held-out validation split.
# The result is stored in Q (fraction of correctly classified samples).
model.eval()  # put the model in evaluation mode before scoring

# Note: despite its name, this loader iterates over val_set (the validation
# split produced by random_split); the name is kept for later cells.
test_loader = DataLoader(val_set, batch_size=500, shuffle=False)

correct = 0  # number of validation samples whose predicted class is right
with torch.no_grad():  # no gradients are needed while only scoring
    for x_test, y_test in test_loader:
        # The predicted class is the index of the largest output score.
        predicted = torch.argmax(model(x_test), dim=1)
        correct += torch.sum(predicted == y_test).item()

Q = correct / len(val_set)
print(Q)

0.9536
