In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch import nn, optim
from time import time

In [None]:
# Convert every image to a torch tensor as it is loaded.
transform = transforms.ToTensor()

# NOTE(review): the directory is spelled "MNSIT_data"; kept verbatim so files
# that were already downloaded there keep being reused.

# Training split: download=True fetches the files if they are not already present,
# then the loader serves shuffled mini-batches of 64.
trainset = datasets.MNIST(
    "./MNSIT_data",
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Held-out evaluation split, batched the same way.
valset = datasets.MNIST(
    "./MNSIT_data",
    train=False,
    download=True,
    transform=transform,
)
valloader = torch.utils.data.DataLoader(valset, shuffle=True, batch_size=64)

In [4]:
class Model(nn.Module):
  """Fully connected classifier: input -> hidden1 -> hidden2 -> classes, ReLU between layers.

  The defaults (784 -> 128 -> 64 -> 10) match 28x28 single-channel images with 10 classes.
  ``forward`` returns per-class log-probabilities (``log_softmax`` over dim 1), which is the
  input format expected by ``nn.NLLLoss``.

  Args:
    input_size: number of input features after flattening (default 28 * 28).
    hidden1: width of the first hidden layer (default 128).
    hidden2: width of the second hidden layer (default 64).
    num_classes: number of output classes (default 10).
  """

  def __init__(self, input_size=28 * 28, hidden1=128, hidden2=64, num_classes=10):
    super().__init__()
    self.linear1 = nn.Linear(input_size, hidden1)
    self.linear2 = nn.Linear(hidden1, hidden2)
    self.linear3 = nn.Linear(hidden2, num_classes)

  def forward(self, X):
    """Return log-probabilities of shape (batch, num_classes).

    ``X`` may be already flattened to (batch, input_size) or be an image batch such as
    (batch, 1, 28, 28); every dimension after the batch dimension is flattened.
    """
    # No-op for inputs that are already (batch, features).
    X = torch.flatten(X, start_dim=1)
    X = F.relu(self.linear1(X))
    X = F.relu(self.linear2(X))
    X = self.linear3(X)
    return F.log_softmax(X, dim=1)

In [5]:
def train(model, trainloader, device, epochs=10, lr=0.1, momentum=0.5):
  """Train ``model`` in place with SGD (+momentum) on the negative log-likelihood loss.

  The model is not moved here, so it should already be on ``device``. It is put in
  training mode. After each epoch the mean batch loss is printed, and the total
  wall-clock time is printed at the end.

  Args:
    model: module returning log-probabilities of shape (batch, classes).
    trainloader: iterable of ``(images, tags)`` batches; each image batch is
      flattened to ``(batch, -1)`` before the forward pass.
    device: device the batches are moved to.
    epochs: number of full passes over ``trainloader`` (default 10).
    lr: SGD learning rate (default 0.1).
    momentum: SGD momentum (default 0.5).

  Returns:
    List with the mean training loss of each epoch (NaN for an epoch with no batches).
  """
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
  criteria = nn.NLLLoss()
  begin = time()
  model.train()

  epoch_losses = []
  for epoch in range(epochs):
    cumulative_loss = 0.0
    n_batches = 0

    for images, tags in trainloader:
      images = images.view(images.shape[0], -1)
      optimizer.zero_grad()

      output = model(images.to(device))
      instant_loss = criteria(output, tags.to(device))
      instant_loss.backward()
      optimizer.step()

      # .item() yields a plain float, so the running total does not hold on to the graph.
      cumulative_loss += instant_loss.item()
      n_batches += 1

    mean_loss = cumulative_loss / n_batches if n_batches else float("nan")
    epoch_losses.append(mean_loss)
    print(f"Epoch {epoch + 1}/{epochs} - training loss: {mean_loss:.4f}")

  print(f"Training time (min): {(time() - begin) / 60:.2f}")
  return epoch_losses

In [6]:
def validation(model, valloader, device):
  """Compute, print and return the classification accuracy of ``model`` on ``valloader``.

  The model is switched to evaluation mode (and left there). Inference runs batch by
  batch without gradient tracking.

  Args:
    model: module returning per-class log-probabilities of shape (batch, classes).
    valloader: iterable of ``(images, tags)`` batches; each image batch is flattened
      to ``(batch, -1)`` before the forward pass.
    device: device the images are moved to for inference.

  Returns:
    Fraction of correctly classified samples in [0, 1], or NaN if the loader yields
    no samples.
  """
  model.eval()
  count_correct, count_all = 0, 0
  with torch.no_grad():
    for images, tags in valloader:
      logps = model(images.view(images.shape[0], -1).to(device))
      # exp() is monotonic, so the argmax of log-probabilities is the predicted class;
      # argmax returns the first index among ties, like list.index(max(...)).
      tag_predictions = logps.argmax(dim=1).cpu()
      count_correct += (tag_predictions == tags.cpu()).sum().item()
      count_all += len(tags)

  accuracy = count_correct / count_all if count_all else float("nan")
  print(f"Number of images tested = {count_all}")
  print(f"Model accuracy = {accuracy:.2%}")
  return accuracy

In [None]:
# Use the GPU when CUDA is available, otherwise run on the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

# nn.Module.to moves the parameters in place and returns the same module object.
model = Model().to(device)
model