In [131]:
import torch
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch import nn

## Hyperparameters

In [132]:
# Seed torch's global RNG, which model weight initialisation and shuffled
# DataLoaders draw from by default, so Restart & Run All reproduces the results.
seed = 0
torch.manual_seed(seed)

batch_size = 64        # samples per mini-batch (used by both loaders)
learning_rate = 0.1    # SGD step size

## Download Datasets

In [133]:
# Convert PIL images to float tensors; no other preprocessing is applied.
data_root = '../data'
transform = transforms.Compose([transforms.ToTensor()])

# Build the training split first, then the test split, sharing the transform.
train_data, test_data = (
	datasets.MNIST(data_root, train=is_train, download=True, transform=transform)
	for is_train in (True, False)
)

## Load Data

In [134]:
# Only the training set is shuffled. Evaluation order has no effect on accuracy,
# so the test loader iterates in a fixed, reproducible order.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

## Model

In [135]:
class CNN(nn.Module):
	"""LeNet-style convolutional classifier.

	The network has two stages of conv(5x5) -> ReLU -> 2x2 max-pool, followed
	by three fully connected layers. It emits 10 raw scores per sample; no
	softmax is applied. The first conv takes a single input channel.

	``fc1`` expects ``16 * 4 * 4`` flattened features, which is what a 28x28
	input (e.g. MNIST) produces:
	28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4.
	"""

	def __init__(self):
		super().__init__()
		# Feature extractor. Parametrised layers are constructed in this order,
		# which fixes the order in which weight init consumes the RNG.
		self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
		self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
		self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)

		# Classifier head.
		flat_features = 16 * 4 * 4
		self.fc1 = nn.Linear(flat_features, 120)
		self.fc2 = nn.Linear(120, 84)
		self.fc3 = nn.Linear(84, 10)

	def forward(self, x):
		"""Map a batch of images to 10 class scores per sample."""
		for conv in (self.conv1, self.conv2):
			x = self.pool(F.relu(conv(x)))
		features = x.flatten(start_dim=1)
		hidden = F.relu(self.fc2(F.relu(self.fc1(features))))
		return self.fc3(hidden)

In [136]:
# Instantiate the network first: the optimizer needs its parameters.
model = CNN()
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)
# Multi-class loss applied to the raw scores produced by the model.
loss_fn = nn.CrossEntropyLoss()

In [137]:
def train(train_loader, model, loss_fn, optmizer, log_interval=100):
	"""Run one training epoch over ``train_loader``.

	Args:
		train_loader: iterable of ``(data, target)`` batches.
		model: module being trained; it is switched to training mode first.
		loss_fn: callable ``loss_fn(pred, target)`` returning a scalar loss
			tensor (``backward()`` is called on it with no arguments).
		optmizer: optimizer over ``model``'s parameters. The misspelled name
			is kept so existing keyword callers keep working.
		log_interval: print the current batch loss every ``log_interval``
			batches, starting with batch 0. ``0``/``None`` disables logging.

	Returns:
		None. Progress is reported with ``print``.
	"""
	# An earlier evaluation may have left the model in eval mode.
	model.train()

	for batch_idx, (data, target) in enumerate(train_loader):
		pred = model(data)
		loss = loss_fn(pred, target)

		# Step the optimizer that was passed in. The original code used a
		# notebook-global of a similar name and ignored this argument.
		optmizer.zero_grad()
		loss.backward()
		optmizer.step()

		if log_interval and batch_idx % log_interval == 0:
			print(f'loss: {loss.item()}')

In [138]:
def test(test_loader, model, loss_fn):
	size = len(test_loader.dataset)
	loss = 0
	correct_n = 0
	correct = 0

	for batch_idx, (data, target) in enumerate(test_loader):
		pred = model(data)
		argmax = pred.argmax(dim=1, keepdim=True)
		correct_n = argmax.eq(target.view_as(argmax)).sum().item()
		correct += correct_n
	return correct / size


In [139]:
# One pass over the training set. The keyword is spelled `optmizer` to match
# train()'s signature.
train(train_loader=train_loader, model=model, loss_fn=loss_fn, optmizer=optimizer)

loss: 2.309152841567993
loss: 1.7227963209152222
loss: 0.43258756399154663
loss: 0.30362626910209656
loss: 0.26352250576019287
loss: 0.18146449327468872
loss: 0.18805311620235443
loss: 0.09103556722402573
loss: 0.0839305967092514
loss: 0.1462973654270172


In [140]:
# Fraction of test samples classified correctly.
test_accuracy = test(test_loader, model, loss_fn)
print(f'accuracy: {test_accuracy}')

accuracy: 0.9722
