In [2]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [3]:
# Hyperparams

input_size = 784
num_classes = 10
learning_rate = 0.001
batch_size = 64
num_epochs = 1

In [107]:
class NeuralNetwork(nn.Module):
	def __init__(self, input_size, num_classes):
		super(NeuralNetwork, self).__init__()
		self.flatten = nn.Flatten()
		self.fc1 = nn.Linear(input_size, 512)
		self.fc2 = nn.Linear(512, num_classes)

	def forward(self, x):
		x = F.relu(self.fc1(self.flatten(x)))
		logits = self.fc2(x)
		return logits

In [108]:
train_dataset = datasets.MNIST(
	root='data',
	train=True,
	download=True,
	transform=transforms.ToTensor(),
	target_transform=transforms.Lambda(lambda y: torch.zeros(num_classes, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
)

test_dataset = datasets.MNIST(
	root='data',
	train=False,
	download=True,
	transform=transforms.ToTensor(),
	target_transform=transforms.Lambda(lambda y: torch.zeros(num_classes, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
)

In [109]:
model = NeuralNetwork(input_size, num_classes)
print(model)

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=784, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=10, bias=True)
)


In [110]:
train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

In [111]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [112]:
def train(dataloader, model, loss, optimizer):
	size = len(dataloader.dataset)
	model.train()

	for batch, (X, y) in enumerate(dataloader):
		pred = model(X)
		loss = loss_fn(pred, y)

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

		if batch % 100 == 0:
			loss, current = loss.item(), batch * len(X)
			print(f"Loss: {loss}, [{current} / {size}]")

In [113]:
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y.argmax(1)).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [115]:
train(train_dataloader, model, loss_fn, optimizer)

Loss: 2.3167974948883057, [0 / 60000]
Loss: 0.28714001178741455, [6400 / 60000]
Loss: 0.21596494317054749, [12800 / 60000]
Loss: 0.2619986832141876, [19200 / 60000]
Loss: 0.17870305478572845, [25600 / 60000]
Loss: 0.3089596927165985, [32000 / 60000]
Loss: 0.117985300719738, [38400 / 60000]
Loss: 0.2736721932888031, [44800 / 60000]
Loss: 0.2863650619983673, [51200 / 60000]
Loss: 0.22124770283699036, [57600 / 60000]


In [116]:
test(test_dataloader, model, loss_fn)

Test Error: 
 Accuracy: 95.6%, Avg loss: 0.142630 

