In [1]:
# load packages
import torch
import torchvision

In [8]:
# import MNIST data set
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Lambda

# ToTensor already returns a float tensor scaled to [0, 1] (per torchvision docs),
# so no extra division by 255 is applied here. The Lambda flattens each 28x28
# image into a 784-element vector for the fully-connected model.
transform = Compose([
    ToTensor(),
    Lambda(lambda image: image.view(784)),
])
data_train = MNIST(root="./", download=True, train=True, transform=transform)
data_test = MNIST(root="./", download=True, train=False, transform=transform)

In [9]:
# Sanity-check that the transform flattened each image to a 784-element vector.
sample_image, _ = data_train[0]
assert sample_image.shape == (784,), sample_image.shape
sample_image.shape

torch.Size([784])

In [16]:
# define model
from torch import nn, optim


class MNISTModel(nn.Module):
  """Fully-connected classifier (two ReLU hidden layers) that owns its loss and optimizer.

  Args:
    in_features: length of each flattened input vector (784 = 28*28 for MNIST).
    hidden_features: width of both hidden layers.
    num_classes: number of output logits / classes.
    lr: Adam learning rate; 1e-3 is Adam's default, matching the original behavior.
  """

  def __init__(self, in_features=784, hidden_features=512, num_classes=10, lr=1e-3):
    super().__init__()
    self.layers = nn.Sequential(
        nn.Linear(in_features, hidden_features),
        nn.ReLU(),
        nn.Linear(hidden_features, hidden_features),
        nn.ReLU(),
        nn.Linear(hidden_features, num_classes),
    )
    # CrossEntropyLoss applies log-softmax internally, so forward() returns raw logits.
    self.loss = nn.CrossEntropyLoss()
    # Built after self.layers so self.parameters() already includes every weight.
    self.optimizer = optim.Adam(self.parameters(), lr=lr)

  def forward(self, X):
    """Return unnormalized logits with last dimension num_classes for inputs whose last dimension is in_features."""
    return self.layers(X)

  def predict(self, X):
    """Return the index of the highest logit for each input row, without tracking gradients."""
    self.eval()
    with torch.no_grad():
      return torch.argmax(self.forward(X), dim=-1)

  def fit(self, X, Y):
    """Run a single optimizer step on one batch and return the batch loss as a Python float.

    Args:
      X: float tensor of inputs, last dimension in_features.
      Y: tensor of integer class indices, one per row of X.
    Returns:
      The (mean-reduced) cross-entropy loss of this batch before the update.
    """
    self.train()
    self.optimizer.zero_grad()
    y_pred = self.forward(X)
    loss = self.loss(y_pred, Y)
    loss.backward()
    self.optimizer.step()
    return loss.item()

In [17]:
# create model
# Seed the global torch RNG first so weight initialization is reproducible
# across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
mnist_model = MNISTModel()

In [14]:
from torch.utils.data import DataLoader

BATCH_SIZE = 16
# Only the training data needs shuffling; evaluation order does not change
# accuracy, so the test loader iterates in a fixed, deterministic order.
dataloader_train = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=True)
dataloader_test = DataLoader(data_test, batch_size=BATCH_SIZE, shuffle=False)

In [18]:
# train model
from tqdm.auto import tqdm

EPOCHS = 5
for epoch in range(EPOCHS):
  total_loss = 0.0
  for xs, ys in tqdm(dataloader_train, desc=f"Fitting EPOCH {epoch}"):
    # fit() returns the batch mean loss; weight it by batch size so a smaller
    # final batch does not skew the per-sample epoch average.
    total_loss += mnist_model.fit(xs, ys) * len(ys)
  total_loss /= len(dataloader_train.dataset)
  print(f"EPOCH {epoch}: {total_loss:.4f}")

Fitting EPOCH {i}: 100%|██████████| 3750/3750 [00:46<00:00, 81.26it/s]


EPOCH 0: 0.1942


Fitting EPOCH {i}: 100%|██████████| 3750/3750 [00:53<00:00, 69.75it/s]


EPOCH 1: 0.0875


Fitting EPOCH {i}: 100%|██████████| 3750/3750 [00:57<00:00, 65.41it/s]


EPOCH 2: 0.0639


Fitting EPOCH {i}: 100%|██████████| 3750/3750 [00:57<00:00, 64.97it/s]


EPOCH 3: 0.0490


Fitting EPOCH {i}: 100%|██████████| 3750/3750 [00:59<00:00, 62.98it/s]

EPOCH 4: 0.0394





In [20]:
# evaluate on the held-out test set
correct = 0
total = 0
for xs, ys in dataloader_test:
  y_pred = mnist_model.predict(xs)
  correct += (ys == y_pred).sum().item()
  # Count real samples: the last batch can be smaller than BATCH_SIZE,
  # so len(loader) * BATCH_SIZE would overstate the denominator.
  total += len(ys)
acc = correct / total
print(f"ACCURACY: {acc:.4f}")

ACCURACY: 0.9797000288963318
