In [78]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [79]:
class FCNN(nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of features in each (already flattened) input row.
        hidden_size: Width of the hidden layer.
        num_classes: Number of output logits.
        device: Accepted only for backward compatibility with existing
            callers; the constructor does not use it. Move the model
            explicitly with ``model.to(device)`` after construction.
    """

    def __init__(self, input_size, hidden_size, num_classes, device=None):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return the raw (un-normalised) class logits for ``x``.

        ``x`` must have ``input_size`` features in its last dimension and
        must already live on the same device as the module's parameters.
        """
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [80]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Training settings.
num_epoches = 10
batch_size = 64
learning_rate = 1e-3

# Network dimensions.
num_classes = 10
hidden_size = 128
input_size = 28*28  # each 1x28x28 image is flattened to 784 features before the model

In [81]:
# Build the classifier and place its parameters on the selected device.
# nn.Module.to() returns the module itself, so the call can be chained.
model = FCNN(
    input_size=input_size,
    hidden_size=hidden_size,
    num_classes=num_classes,
    device=device,
).to(device)
# Bare expression: the notebook displays the layer summary.
model

FCNN(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

In [82]:
# MNIST is downloaded into dataset/mnist/ on first use; ToTensor() converts
# each image to a float tensor so the loaders yield batched tensors.
train_dataset = datasets.MNIST(root='dataset/mnist/', train=True, transform=transforms.ToTensor(), download=True)
# Shuffle the training data so every epoch sees the batches in a new order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = datasets.MNIST(root='dataset/mnist/', train=False, transform=transforms.ToTensor(), download=True)
# Evaluation does not benefit from shuffling; a fixed order keeps test runs
# deterministic and comparable.
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [83]:
# Adam updates every parameter of the model using the configured step size.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
# Cross-entropy over the raw logits produced by the model.
criterion = nn.CrossEntropyLoss()

In [84]:
# Train for num_epoches passes over the training set, printing the mean
# batch loss after every epoch.
for epoch in range(num_epoches):
    # check_acc() puts the model in eval mode; switch back explicitly so that
    # re-running this cell after an evaluation still trains in training mode.
    model.train()
    losses = []
    for data, label in train_loader:
        data = data.to(device)
        label = label.to(device)

        # Flatten each image: [batch, 1, 28, 28] -> [batch, 784]
        data = data.reshape(data.shape[0], -1)

        pred = model(data)
        loss = criterion(pred, label)
        # Store a plain Python float. Appending the tensor itself would keep
        # every batch's loss tensor (with its autograd history, and on the
        # GPU when one is used) alive until the end of the epoch.
        losses.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"epoch {epoch+1}: loss = {sum(losses)/len(losses)}")

epoch 1: loss = 0.34918728470802307
epoch 2: loss = 0.1582726538181305
epoch 3: loss = 0.11072656512260437
epoch 4: loss = 0.08404892683029175
epoch 5: loss = 0.06704749166965485
epoch 6: loss = 0.05299541726708412
epoch 7: loss = 0.04330743849277496
epoch 8: loss = 0.036429885774850845
epoch 9: loss = 0.02941669523715973
epoch 10: loss = 0.023885628208518028


In [86]:
def check_acc(loader, model):
    """Evaluate ``model`` on every batch of ``loader`` and print the accuracy.

    Args:
        loader: Iterable yielding ``(data, label)`` batches. Each ``data``
            batch is flattened to ``[batch, features]`` before being passed
            to the model.
        model: ``nn.Module`` returning one row of class scores per sample.
            It is switched to eval mode and left in that mode.

    Returns:
        None. The accuracy is printed; the message always reads "test set",
        whichever loader is supplied.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.eval()

    # Evaluate on the device the model's parameters live on instead of relying
    # on a notebook-global ``device`` variable. A model without parameters
    # leaves the batches where they are.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    total = 0
    num_wrong = 0

    # Inference only: disable gradient tracking so no autograd graph is built.
    with torch.no_grad():
        for data, label in loader:
            total += data.shape[0]
            if target_device is not None:
                data = data.to(target_device)
                label = label.to(target_device)

            # Flatten each sample: e.g. [64, 1, 28, 28] -> [64, 784]
            data = data.reshape(data.shape[0], -1)

            pred = model(data).argmax(1)
            # Count mismatches directly and accumulate as a Python int.
            num_wrong += (pred != label).sum().item()

    print(f"Accuracy on  test set : {(1-(num_wrong/total))*100:.2f}%")

In [87]:
check_acc(test_loader, model)

Accuracy on  test set : 97.90%
