In [1]:
import os
import numpy as np

import torch
import torch.nn as nn

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from torchvision import transforms, datasets

In [2]:
# Hyperparameters. batch_size drives the evaluation DataLoader and lr is
# handed to the optimizer whose state is later restored from a checkpoint;
# num_epoch is not used by the cells in this evaluation notebook.
lr, batch_size, num_epoch = 1e-3, 64, 10

In [3]:
# Where checkpoints are read from and where TensorBoard logs go.
# NOTE(review): "cehckpoint" looks like a misspelling of "checkpoint"; it is
# kept verbatim because the checkpoints are loaded from this exact directory.
ckpt_dir, log_dir = (
    "./cehckpoint",
    "./log",
)

In [4]:
# Run on the GPU when one is present, otherwise on the CPU.
# is_available must be *called*: the bare function object is always truthy,
# which previously forced "cuda" even on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
class Net(nn.Module):
    """Small CNN classifier for 1-channel 28x28 images, producing 10 logits.

    conv(1->10, k=5) -> maxpool(2) -> ReLU
    -> conv(10->20, k=5) -> channel dropout -> maxpool(2) -> ReLU
    -> flatten (20 * 4 * 4 = 320 features for a 28x28 input)
    -> fc(320->50) -> ReLU -> dropout -> fc(50->10)

    The output is raw (unnormalised) logits. Submodule names are unchanged so
    previously saved state dicts still load.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5, stride=1, padding=0, bias=True)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5, stride=1, padding=0, bias=True)
        # Dropout2d zeroes whole feature maps of the 4-D (N, C, H, W) activations.
        self.drop2 = nn.Dropout2d(p=0.5)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.relu2 = nn.ReLU()

        self.fc1 = nn.Linear(in_features=320, out_features=50, bias=True)
        self.relu1_fc1 = nn.ReLU()
        # The fc activations are 2-D (N, 50). Dropout2d is meant for 3-D/4-D
        # channel-spatial inputs and is deprecated for 2-D ones, so plain
        # element-wise Dropout is used. Dropout has no parameters, so the
        # state-dict layout is unaffected.
        self.drop1_fc1 = nn.Dropout(p=0.5)

        self.fc2 = nn.Linear(in_features=50, out_features=10, bias=True)

    def forward(self, x):
        """Map a batch of images (assumed (N, 1, 28, 28)) to (N, 10) logits."""
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.relu1(x)

        x = self.conv2(x)
        x = self.drop2(x)
        x = self.pool2(x)
        x = self.relu2(x)

        # Flatten per sample, keeping the batch dimension. Unlike
        # view(-1, 320), a wrongly sized input now fails in fc1 instead of
        # being silently reshaped into a different batch size.
        x = x.flatten(start_dim=1)

        x = self.fc1(x)
        x = self.relu1_fc1(x)
        x = self.drop1_fc1(x)

        x = self.fc2(x)

        return x

In [6]:
def save(ckpt_dir, net, optim, epoch):
    """Write a checkpoint of the model and optimizer state for ``epoch``.

    The file is ``<ckpt_dir>/epoch<epoch>.pth``. It holds a dict with keys
    ``"net"`` and ``"optim"``, the two ``state_dict()`` results.
    ``ckpt_dir`` (and any missing parents) is created when needed.

    Args:
        ckpt_dir: directory to write into (relative or absolute).
        net: the model whose ``state_dict()`` is saved.
        optim: the optimizer whose ``state_dict()`` is saved.
        epoch: value embedded in the file name.
    """
    os.makedirs(ckpt_dir, exist_ok=True)
    # Save on every call. Previously torch.save was nested inside the
    # "directory does not exist" branch, so only the first call ever wrote.
    torch.save(
        {"net": net.state_dict(), "optim": optim.state_dict()},
        os.path.join(ckpt_dir, f"epoch{epoch}.pth"),
    )

In [7]:
def _ckpt_epoch(name):
    """Return N for a file named ``epoch<N>.pth``, else None."""
    prefix, suffix = "epoch", ".pth"
    if name.startswith(prefix) and name.endswith(suffix):
        digits = name[len(prefix):-len(suffix)]
        if digits.isdecimal():
            return int(digits)
    return None


def load(ckpt_dir, net, optim, map_location="cpu"):
    """Restore ``net`` and ``optim`` from the newest ``epoch<N>.pth`` in ``ckpt_dir``.

    Checkpoints are ranked by their numeric epoch, so ``epoch10.pth`` beats
    ``epoch9.pth``. A plain string sort would pick ``epoch9.pth``. Directory
    entries that do not match ``epoch<N>.pth`` are ignored.

    Args:
        ckpt_dir: directory containing ``epoch<N>.pth`` files. Each is a dict
            with ``"net"`` and ``"optim"`` state dicts.
        net: model whose parameters are overwritten in place.
        optim: optimizer whose state is overwritten in place.
        map_location: forwarded to ``torch.load``. The default ``"cpu"`` lets
            a checkpoint written on a GPU machine load on a CPU-only one.
            ``load_state_dict`` then copies the values into ``net``'s existing
            parameters, and the optimizer casts its state to match them.

    Returns:
        The same ``(net, optim)`` objects.

    Raises:
        FileNotFoundError: if ``ckpt_dir`` contains no ``epoch<N>.pth`` file.
    """
    by_epoch = {}
    for name in os.listdir(ckpt_dir):
        epoch = _ckpt_epoch(name)
        if epoch is not None:
            by_epoch[epoch] = name
    if not by_epoch:
        raise FileNotFoundError(f"no epoch<N>.pth checkpoint found in {ckpt_dir!r}")

    latest = by_epoch[max(by_epoch)]
    dict_model = torch.load(os.path.join(ckpt_dir, latest), map_location=map_location)

    net.load_state_dict(dict_model["net"])
    optim.load_state_dict(dict_model["optim"])

    return net, optim

In [8]:
# Preprocessing: ToTensor, then Normalize with mean 0.5 / std 0.5, i.e.
# (x - 0.5) / 0.5 per pixel.
preprocess_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
]
transform = transforms.Compose(preprocess_steps)

# MNIST split selected with train=False (the test split), downloaded into "./".
dataset = datasets.MNIST(root="./", train=False, download=True, transform=transform)

In [9]:
# Deterministic order (shuffle=False); batches are produced in the main
# process (num_workers=0).
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)

In [10]:
# Sample count and batches per pass. The last batch may be partial, so the
# batch count is the ceiling of num_data / batch_size (integer arithmetic).
num_data = len(loader.dataset)
num_batch = -(-num_data // batch_size)

In [11]:
# Model on the selected device, plus an optimizer whose state is restored
# together with the weights from the latest checkpoint.
net = Net().to(device)
params = net.parameters()
optim = torch.optim.Adam(params, lr=lr)

# Loss is computed on raw logits.
fn_loss = nn.CrossEntropyLoss().to(device)


def fn_pred(output):
    """Turn (batch, classes) logits into class probabilities."""
    return torch.softmax(output, dim=1)


def fn_acc(pred, label):
    """Fraction of rows whose highest-scoring class equals ``label``."""
    hits = pred.max(dim=1).indices == label
    return hits.float().mean()


writer = SummaryWriter(log_dir=log_dir)

net, optim = load(ckpt_dir=ckpt_dir, net=net, optim=optim)

In [12]:
# Evaluate the restored model. Dropout is switched off (eval mode), autograd
# is disabled, and nothing here modifies the weights. The previous
# optim.step() call was removed: an evaluation pass must never update
# parameters.
net.eval()
with torch.no_grad():
    loss_arr = []   # per-batch mean loss
    acc_arr = []    # per-batch accuracy
    size_arr = []   # per-batch sample count, used to weight the running means

    for batch, (inputs, label) in enumerate(loader, 1):
        inputs = inputs.to(device)
        label = label.to(device)
        output = net(inputs)
        pred = fn_pred(output)

        loss = fn_loss(output, label)
        acc = fn_acc(pred, label)

        loss_arr.append(loss.item())
        acc_arr.append(acc.item())
        size_arr.append(label.size(0))

        # Weight each batch by its size so a partial final batch does not
        # count as much as a full one; the last line printed is then the
        # dataset-level mean loss and accuracy.
        run_loss = np.average(loss_arr, weights=size_arr)
        run_acc = np.average(acc_arr, weights=size_arr)

        print(f"Test | BATCH {batch:04d} / {num_batch:04d}| LOSS {run_loss:.4f} ACC {run_acc:.4f}")
    

Test | BATCH 0001 / 0157| LOSS 0.0223 ACC 1.0000
Test | BATCH 0002 / 0157| LOSS 0.0488 ACC 0.9844
Test | BATCH 0003 / 0157| LOSS 0.0439 ACC 0.9896
Test | BATCH 0004 / 0157| LOSS 0.0555 ACC 0.9844
Test | BATCH 0005 / 0157| LOSS 0.0756 ACC 0.9750
Test | BATCH 0006 / 0157| LOSS 0.0750 ACC 0.9740
Test | BATCH 0007 / 0157| LOSS 0.0739 ACC 0.9754
Test | BATCH 0008 / 0157| LOSS 0.0942 ACC 0.9688
Test | BATCH 0009 / 0157| LOSS 0.0899 ACC 0.9722
Test | BATCH 0010 / 0157| LOSS 0.0938 ACC 0.9703
Test | BATCH 0011 / 0157| LOSS 0.0947 ACC 0.9688
Test | BATCH 0012 / 0157| LOSS 0.0976 ACC 0.9661
Test | BATCH 0013 / 0157| LOSS 0.0926 ACC 0.9688
Test | BATCH 0014 / 0157| LOSS 0.0907 ACC 0.9699
Test | BATCH 0015 / 0157| LOSS 0.1025 ACC 0.9646
Test | BATCH 0016 / 0157| LOSS 0.1030 ACC 0.9648
Test | BATCH 0017 / 0157| LOSS 0.1020 ACC 0.9651
Test | BATCH 0018 / 0157| LOSS 0.1058 ACC 0.9644
Test | BATCH 0019 / 0157| LOSS 0.1056 ACC 0.9646
Test | BATCH 0020 / 0157| LOSS 0.1160 ACC 0.9633
Test | BATCH 0021 / 