In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torchvision import transforms
from torch.utils.data import DataLoader

In [9]:
# Data preparation
batch_size = 16

train_data = datasets.MNIST('./datasets',train=True,download=True,transform=transforms.ToTensor())
test_data = datasets.MNIST('./datasets',train=False,download=True,transform=transforms.ToTensor())

train_loader = DataLoader(train_data, batch_size = batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size = batch_size)

In [13]:
class MLP(nn.Module):
    def __init__(self,hidden_units=[512,256,128,64]):
        super().__init__()
        
        self.in_dim = 28*28
        self.out_dim = 10
        
        self.fc_layer = []
        self.fc_layer.append(nn.Linear(self.in_dim,hidden_units[0]))
        for idx in range(len(hidden_units)-1):
            self.fc_layer.append(nn.Linear(hidden_units[idx],hidden_units[idx+1]))
        self.fc_layer.append(nn.Linear(hidden_units[-1],self.out_dim))
        
        self.fc_layer = nn.ModuleList(self.fc_layer)
        
        self.relu = nn.ReLU()
        
    def forward(self,x):
        a = x.view(-1,self.in_dim)
        for i in range(len(self.fc_layer)):
            z = self.fc_layer[i](a)
            if i != len(self.fc_layer) -1 :
                a = self.relu(z)
            else :
                out = z
        return out

In [17]:
model = MLP()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(),lr = 0.01)

In [18]:
for epoch in range(10):
    running_loss = 0.0
    for i,data in enumerate(train_loader):
        inputs, labels = data
        
        optimizer.zero_grad()
        
        yhat = model(inputs)
        
        loss = criterion(yhat,labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if (i+1) % 2000 == 0:
            print("[%d,%5d] loss : %.5f"%(epoch+1,i+1,running_loss/2000))
            running_loss = 0.0
            
print("Finished Training")

[1, 2000] loss : 2.18490
[2, 2000] loss : 0.41724
[3, 2000] loss : 0.20887
[4, 2000] loss : 0.13515
[5, 2000] loss : 0.10067
[6, 2000] loss : 0.07849
[7, 2000] loss : 0.06078
[8, 2000] loss : 0.04849
[9, 2000] loss : 0.03829
[10, 2000] loss : 0.03163
Finished Training
