In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

import matplotlib.pyplot as plt

training_data = datasets.MNIST(
root="data",
train=True,
download=True,
transform=ToTensor())

test_data = datasets.MNIST(
root="data",
train=False,
download=True,
transform=ToTensor())

batch_size = 64

train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

In [3]:
class NeuralNetwork(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3,3), stride=(1,1),padding = 1)
    self.conv2 = nn.Conv2d(in_channels=16, out_channels=64, kernel_size=(3,3), stride=(1,1),padding = 1)
    self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3,3), stride=(1,1),padding = 1) # increase the size. needs pulling to reduce the size: maxpull, average pull
    self.pool1 = nn.MaxPool2d(kernel_size=(2,2), stride=(2,2),ceil_mode=True)
    self.relu = nn.ReLU()

    self.fcn1 = nn.Linear(64*4*4, 32)
    self.fcn2 = nn.Linear(32, 10)

    self.flatten = nn.Flatten()
    self.sm = nn.Softmax()


  def forward(self,x):
    x = self.conv1(x)
    x = self.pool1(x)
    x = self.relu(x)
    x = self.conv2(x)
    x = self.pool1(x)
    x = self.relu(x)
    x = self.conv3(x)
    x = self.pool1(x)
    x = self.relu(x)

    x = self.fcn1(self.flatten(x))
    x = self.relu(x)
    x = self.fcn2(x)
    x = self.sm(x)
    return x



model = NeuralNetwork()

In [4]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

def train(dataloader, model, loss_fn, optimizer):
  size = len(dataloader.dataset)
  model.train()  # setup the model for training
  for batch, (X,y) in enumerate(dataloader):
    pred = model(X)
    loss = loss_fn(pred, y)
    loss.backward() # calculate the gradients
    optimizer.step() # step weights according to the rule
    optimizer.zero_grad() # delete all gradients data

    if batch % 100 == 0:
      loss, current = loss.item(), (batch + 1)*len(X)
      print(f"loss: {loss:>5f}, {current:>5f}/{size:>5f}")

def test(dataloader, model, loss_fn):
  size = len(dataloader.dataset)
  num_batches = len(dataloader)
  model.eval()  # setup the model for evaluating. No tracking calculations
  test_loss, correct = 0, 0
  with torch.no_grad(): #  No tracking calculations
    for X,y in dataloader:
      pred = model(X)
      test_loss += loss_fn(pred, y).item()
      correct += (pred.argmax(1) == y).type(torch.float).sum().item()

  test_loss /= num_batches
  correct /= size
  print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [5]:
epochs = 5
for t in range(epochs):
  print(f"Epoch {t+1}\n", "-"*20)
  train(train_dataloader, model, loss_fn, optimizer)
  test(test_dataloader, model, loss_fn)
  print("Done!")

Epoch 1
 --------------------
loss: 2.301839, 64.000000/60000.000000


  return self._call_impl(*args, **kwargs)


loss: 2.303567, 6464.000000/60000.000000
loss: 2.302335, 12864.000000/60000.000000
loss: 2.301841, 19264.000000/60000.000000
loss: 2.303193, 25664.000000/60000.000000
loss: 2.301507, 32064.000000/60000.000000
loss: 2.302319, 38464.000000/60000.000000
loss: 2.302873, 44864.000000/60000.000000
loss: 2.302573, 51264.000000/60000.000000
loss: 2.300181, 57664.000000/60000.000000
Test Error: 
 Accuracy: 10.1%, Avg loss: 2.302179 

Done!
Epoch 2
 --------------------
loss: 2.302282, 64.000000/60000.000000
loss: 2.303320, 6464.000000/60000.000000
loss: 2.301832, 12864.000000/60000.000000
loss: 2.302629, 19264.000000/60000.000000
loss: 2.301218, 25664.000000/60000.000000
loss: 2.302009, 32064.000000/60000.000000
loss: 2.303495, 38464.000000/60000.000000
loss: 2.301011, 44864.000000/60000.000000
loss: 2.302837, 51264.000000/60000.000000
loss: 2.301702, 57664.000000/60000.000000
Test Error: 
 Accuracy: 10.1%, Avg loss: 2.302044 

Done!
Epoch 3
 --------------------
loss: 2.300708, 64.000000/60000