In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import os

In [11]:
def init_weights(m):
    """Xavier-uniform initialise the weight of ``nn.Linear`` / ``nn.Conv2d`` modules.

    Intended to be used as ``net.apply(init_weights)``. The check is on the
    exact type, so subclasses of these layers are skipped. Biases and every
    other module type are left untouched.
    """
    if type(m) in (nn.Linear, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight)

def dl_train(net, trainer, train_iter, test_iter, num_epoch, device=None,
             pre_trained=False, log_interval=100):
    """Train ``net`` with ``trainer``, then optionally report test accuracy.

    Args:
        net: the ``nn.Module`` to train; it is moved to ``device`` in place.
        trainer: optimizer already constructed over ``net.parameters()``.
        train_iter: iterable of ``(data, label)`` batches used for training.
        test_iter: iterable of ``(data, label)`` batches used for a single
            evaluation pass after training, or ``None`` to skip evaluation.
        num_epoch: number of passes over ``train_iter``.
        device: target ``torch.device``; CPU when ``None``.
        pre_trained: when ``False`` the weights are re-initialised with
            ``init_weights`` first; pass ``True`` to keep loaded weights.
        log_interval: print the mean training loss every this many batches.

    Raises:
        ValueError: if ``log_interval`` is smaller than 1.
    """
    if log_interval < 1:
        raise ValueError(f'log_interval must be >= 1, got {log_interval}')

    device = torch.device('cpu') if device is None else device

    if not pre_trained:
        net.apply(init_weights)
    net.to(device)

    _train_epochs(net, trainer, train_iter, num_epoch, device, log_interval)

    if test_iter is not None:
        correct, total = _count_correct(net, test_iter, device)
        if total == 0:
            # Avoid a ZeroDivisionError on an empty evaluation set.
            print('No test images to evaluate.')
        else:
            print(f'Accuracy of the network on the {total} test images: {100 * correct / total}%')


def _train_epochs(net, trainer, train_iter, num_epoch, device, log_interval):
    """Run ``num_epoch`` optimisation passes, printing the running mean loss."""
    loss_function = nn.CrossEntropyLoss()
    net.train()
    for epoch in range(num_epoch):
        running_loss = 0.0
        for batch_idx, (data, label) in enumerate(train_iter):
            # non_blocking only has an effect for pinned host memory
            # (e.g. a DataLoader with pin_memory=True) copied to a GPU.
            data = data.to(device, non_blocking=True)
            label = label.to(device, non_blocking=True)

            trainer.zero_grad()
            loss = loss_function(net(data), label)
            loss.backward()
            trainer.step()

            running_loss += loss.item()
            if batch_idx % log_interval == log_interval - 1:
                print(f'epoch {epoch + 1}, batch {batch_idx + 1}, '
                      f'loss {running_loss / log_interval:.3f}')
                running_loss = 0.0


def _count_correct(net, data_iter, device):
    """Return ``(correct, total)`` top-1 prediction counts of ``net`` over ``data_iter``."""
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, label in data_iter:
            data = data.to(device, non_blocking=True)
            label = label.to(device, non_blocking=True)
            predicted = net(data).argmax(dim=1)
            total += label.size(0)
            correct += (predicted == label).sum().item()
    return correct, total

def VGG_block(num_convs, in_channels, out_channels):
    """Build one VGG stage: ``num_convs`` (3x3 conv, ReLU) pairs, then a 2x2 max-pool.

    The first convolution maps ``in_channels`` -> ``out_channels``; every
    later one keeps ``out_channels``. With ``padding=1`` the convolutions
    preserve height/width, and the stride-2 pool halves them (rounding down).
    """
    layers = []
    for idx in range(num_convs):
        conv_in = in_channels if idx == 0 else out_channels
        layers += [nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1), nn.ReLU()]
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*layers)

def VGG(conv_arch, in_channels=1, num_classes=10, input_size=224):
    """Build a VGG-style classifier: stacked ``VGG_block``s plus a 3-layer MLP head.

    Args:
        conv_arch: sequence of ``(num_convs, out_channels)`` pairs, one per
            ``VGG_block``; each block halves the spatial resolution.
        in_channels: channels of the input image (1 for grayscale input).
        num_classes: number of output logits.
        input_size: height/width of the square input image. The first
            fully-connected layer is sized from it (224 -> 7x7 after five blocks).

    Raises:
        ValueError: if ``conv_arch`` is empty, or so deep that the feature map
            would be pooled down to zero size.
    """
    conv_arch = tuple(conv_arch)
    if not conv_arch:
        raise ValueError('conv_arch must contain at least one block')

    # Each VGG_block ends in a 2x2/stride-2 max-pool, which floors odd sizes;
    # repeated floor-halving equals a single floor division by 2**blocks.
    feature_size = input_size // (2 ** len(conv_arch))
    if feature_size < 1:
        raise ValueError(
            f'input_size={input_size} is too small for {len(conv_arch)} pooling blocks'
        )

    conv_blks = []
    for num_convs, out_channels in conv_arch:
        conv_blks.append(VGG_block(num_convs, in_channels, out_channels))
        in_channels = out_channels

    return nn.Sequential(
        *conv_blks, nn.Flatten(),
        nn.Linear(out_channels * feature_size * feature_size, 4096), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(4096, 4096), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(4096, num_classes),
    )

In [3]:
# Upsample every digit to 224x224: VGG(...) sizes its first Linear layer for
# 7x7 feature maps, i.e. a 224-pixel input after five 2x max-pools.
DATA_DIR = 'MNIST_data'
BATCH_SIZE = 256

trans = torchvision.transforms.Compose([
    torchvision.transforms.Resize(224),
    torchvision.transforms.ToTensor(),
])

MNIST_train = torchvision.datasets.MNIST(
    DATA_DIR, train=True,
    transform=trans,
    download=True,
)
MNIST_test = torchvision.datasets.MNIST(
    DATA_DIR, train=False,
    transform=trans,
    download=True,
)

# pin_memory=True allows faster (and non_blocking) host->GPU batch copies.
train_iter = torch.utils.data.DataLoader(
    MNIST_train, batch_size=BATCH_SIZE,
    shuffle=True, num_workers=4,
    prefetch_factor=4, pin_memory=True,
)
# The test set is only used to count correct predictions, so shuffling it
# adds nothing; a fixed order keeps evaluation passes reproducible.
test_iter = torch.utils.data.DataLoader(
    MNIST_test, batch_size=BATCH_SIZE,
    shuffle=False, num_workers=4,
    prefetch_factor=4, pin_memory=True,
)

In [5]:
conv_arch = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))
net = VGG(conv_arch)

# Plain SGD (no momentum argument given) with L2 weight decay.
trainer = torch.optim.SGD(net.parameters(), lr=0.1, weight_decay=5e-4)

# Fall back to CPU so the cell still runs on a machine without a GPU
# (training is much slower there).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

num_epoch = 20

dl_train(net, trainer, train_iter, test_iter, num_epoch, device)

epoch 1, batch 100, loss 2.278
epoch 1, batch 200, loss 2.238
epoch 2, batch 100, loss 0.506
epoch 2, batch 200, loss 0.127
epoch 3, batch 100, loss 0.078
epoch 3, batch 200, loss 0.070
epoch 4, batch 100, loss 0.062
epoch 4, batch 200, loss 0.047
epoch 5, batch 100, loss 0.042
epoch 5, batch 200, loss 0.035
epoch 6, batch 100, loss 0.031
epoch 6, batch 200, loss 0.033
epoch 7, batch 100, loss 0.026
epoch 7, batch 200, loss 0.025
epoch 8, batch 100, loss 0.020
epoch 9, batch 100, loss 0.018
epoch 9, batch 200, loss 0.020
epoch 10, batch 100, loss 0.016
epoch 10, batch 200, loss 0.015
epoch 12, batch 100, loss 0.012
epoch 12, batch 200, loss 0.012
epoch 13, batch 100, loss 0.012
epoch 13, batch 200, loss 0.009
epoch 14, batch 100, loss 0.008
epoch 14, batch 200, loss 0.010
epoch 15, batch 100, loss 0.006
epoch 15, batch 200, loss 0.010
epoch 16, batch 100, loss 0.007
epoch 16, batch 200, loss 0.007
epoch 17, batch 100, loss 0.007
epoch 17, batch 200, loss 0.007
epoch 18, batch 100, loss

In [6]:
# Persist the trained weights together with the optimizer state and the
# number of epochs run, so training can be resumed from this file later.
path = 'VGG_checkpoint.pt'
torch.save(
    dict(
        epoch=num_epoch,
        model_state_dict=net.state_dict(),
        optimizer_state_dict=trainer.state_dict(),
    ),
    path,
)

Train for another 20 epochs, resuming from the saved checkpoint.

In [13]:
conv_arch = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move the model to its target device *before* building the optimizer and
# loading its state, so any restored per-parameter optimizer state is placed
# on the same device as the parameters.
net = VGG(conv_arch).to(device)
trainer = torch.optim.SGD(net.parameters(), lr=0.1, weight_decay=5e-4)

num_epoch = 20  # additional epochs to train in this session

path = 'VGG_checkpoint.pt'
# torch.load unpickles the file: only load checkpoints you produced yourself.
checkpoint = torch.load(path, map_location=device)
net.load_state_dict(checkpoint['model_state_dict'])
trainer.load_state_dict(checkpoint['optimizer_state_dict'])
epochs_done = checkpoint['epoch']  # epochs completed before this session

In [14]:
# Continue from the restored weights: pre_trained=True skips re-initialisation.
dl_train(
    net, trainer, train_iter, test_iter, num_epoch,
    device=device, pre_trained=True,
)

epoch 1, batch 100, loss 0.015
epoch 1, batch 200, loss 0.005
epoch 2, batch 100, loss 0.003
epoch 2, batch 200, loss 0.004
epoch 3, batch 100, loss 0.003
epoch 3, batch 200, loss 0.005
epoch 4, batch 100, loss 0.004
epoch 4, batch 200, loss 0.003
epoch 5, batch 100, loss 0.006
epoch 5, batch 200, loss 0.003
epoch 6, batch 100, loss 0.003
epoch 6, batch 200, loss 0.003
epoch 7, batch 100, loss 0.003
epoch 7, batch 200, loss 0.002
epoch 8, batch 100, loss 0.004
epoch 8, batch 200, loss 0.003
epoch 9, batch 100, loss 0.002
epoch 9, batch 200, loss 0.003
epoch 10, batch 100, loss 0.002
epoch 10, batch 200, loss 0.003
epoch 11, batch 100, loss 0.002
epoch 11, batch 200, loss 0.002
epoch 12, batch 100, loss 0.003
epoch 12, batch 200, loss 0.002
epoch 13, batch 100, loss 0.001
epoch 13, batch 200, loss 0.002
epoch 14, batch 100, loss 0.002
epoch 14, batch 200, loss 0.002
epoch 15, batch 100, loss 0.002
epoch 15, batch 200, loss 0.002
epoch 16, batch 100, loss 0.002
epoch 16, batch 200, loss 