In [1]:
# ---------------------------------------------------------------------------- #
# An implementation of https://arxiv.org/pdf/1512.03385.pdf                    #
# See section 4.2 for the model architecture on CIFAR-10                       #
# Parts of the code were adapted from the source below                         #
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py   #
# ---------------------------------------------------------------------------- #

In [2]:
import random

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

from model_resnet import ResNet
from model_resnet import ResidualBlock

# Device configuration: use the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
num_epochs = 80
learning_rate = 0.001
batch_size = 100
random_seed = 42

# Seed the Python and PyTorch RNGs up front so that weight initialisation,
# batch shuffling and the random augmentations below are repeatable after a
# kernel restart. (Some GPU kernels may still be non-deterministic.)
random.seed(random_seed)
torch.manual_seed(random_seed)

# Image preprocessing modules: pad each side by 4 pixels, flip horizontally at
# random, take a random 32x32 crop, then convert the image to a tensor.
transform = transforms.Compose([
    transforms.Pad(4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor()])

# CIFAR-10 dataset (training split); downloaded into data_root if missing.
data_root = '../../data/'
train_dataset = torchvision.datasets.CIFAR10(root=data_root,
                                             train=True,
                                             transform=transform,
                                             download=True)

# Data loader: reshuffles the training set every epoch.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

# NOTE(review): [2, 2, 2] is presumably the number of residual blocks per
# stage -- confirm against model_resnet.ResNet.
model = ResNet(ResidualBlock, [2, 2, 2]).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# For updating learning rate
def update_lr(optimizer, lr):
    """Set the learning rate of every parameter group in ``optimizer`` to ``lr``.

    Args:
        optimizer: object exposing ``param_groups``, an iterable of dict-like
            groups (this notebook passes a ``torch.optim.Adam`` instance).
        lr: new value written under each group's ``'lr'`` key.
    """
    for group in optimizer.param_groups:
        group['lr'] = lr

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../../data/cifar-10-python.tar.gz


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170498071.0), HTML(value='')))


Extracting ../../data/cifar-10-python.tar.gz to ../../data/


In [3]:
# Train the model
total_step = len(train_loader)
curr_lr = learning_rate
for epoch in range(num_epochs):
    # Running tallies used for this epoch's training accuracy.
    correct, total = 0, 0
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        # Clear gradients left over from the previous step, then run forward.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Back-propagate and apply the parameter update.
        loss.backward()
        optimizer.step()

        # Predicted class is the index of the largest output along dim 1.
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

        # Accuracy (in percent) over the batches seen so far in this epoch.
        acc = 100 * correct / total
        if not (i + 1) % 100:
            print("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f} | Accuracy: {:.4f}"
                  .format(epoch + 1, num_epochs, i + 1, total_step, loss.item(), acc))

    # Decay learning rate: divide by 3 after every 20th epoch.
    if not (epoch + 1) % 20:
        curr_lr = curr_lr / 3
        update_lr(optimizer, curr_lr)

# Save the model in three forms.
torch.save(model, 'resnet_trained_model.ckpt')  # the entire model object
torch.save(model.state_dict(), 'resnet_trained_model_state.ckpt')  # weights only
# Model weights together with the optimizer state.
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}
torch.save(checkpoint, 'resnet_trained_model_all.tar')

Epoch [1/80], Step [100/500] Loss: 1.7815 | Accuracy: 31.9900
Epoch [1/80], Step [200/500] Loss: 1.5090 | Accuracy: 37.8600
Epoch [1/80], Step [300/500] Loss: 1.2924 | Accuracy: 41.5767
Epoch [1/80], Step [400/500] Loss: 1.0706 | Accuracy: 44.6925
Epoch [1/80], Step [500/500] Loss: 1.0440 | Accuracy: 47.1200
Epoch [2/80], Step [100/500] Loss: 1.0817 | Accuracy: 59.4600
Epoch [2/80], Step [200/500] Loss: 1.1144 | Accuracy: 60.7850
Epoch [2/80], Step [300/500] Loss: 1.1071 | Accuracy: 61.4667
Epoch [2/80], Step [400/500] Loss: 0.8439 | Accuracy: 62.2275
Epoch [2/80], Step [500/500] Loss: 0.8425 | Accuracy: 63.1060
Epoch [3/80], Step [100/500] Loss: 1.0764 | Accuracy: 68.0000
Epoch [3/80], Step [200/500] Loss: 0.8999 | Accuracy: 67.9850
Epoch [3/80], Step [300/500] Loss: 0.7550 | Accuracy: 68.5633
Epoch [3/80], Step [400/500] Loss: 0.8314 | Accuracy: 68.8750
Epoch [3/80], Step [500/500] Loss: 0.9945 | Accuracy: 69.4500
Epoch [4/80], Step [100/500] Loss: 0.6510 | Accuracy: 72.9700
Epoch [4

Epoch [27/80], Step [200/500] Loss: 0.1589 | Accuracy: 91.3950
Epoch [27/80], Step [300/500] Loss: 0.4026 | Accuracy: 91.3367
Epoch [27/80], Step [400/500] Loss: 0.2037 | Accuracy: 91.2375
Epoch [27/80], Step [500/500] Loss: 0.2669 | Accuracy: 91.2480
Epoch [28/80], Step [100/500] Loss: 0.2736 | Accuracy: 91.5800
Epoch [28/80], Step [200/500] Loss: 0.3160 | Accuracy: 91.6600
Epoch [28/80], Step [300/500] Loss: 0.1892 | Accuracy: 91.5500
Epoch [28/80], Step [400/500] Loss: 0.2271 | Accuracy: 91.5775
Epoch [28/80], Step [500/500] Loss: 0.1655 | Accuracy: 91.6160
Epoch [29/80], Step [100/500] Loss: 0.2938 | Accuracy: 92.3500
Epoch [29/80], Step [200/500] Loss: 0.2035 | Accuracy: 92.0950
Epoch [29/80], Step [300/500] Loss: 0.2571 | Accuracy: 91.8900
Epoch [29/80], Step [400/500] Loss: 0.1654 | Accuracy: 91.8950
Epoch [29/80], Step [500/500] Loss: 0.2688 | Accuracy: 91.8160
Epoch [30/80], Step [100/500] Loss: 0.1965 | Accuracy: 91.7600
Epoch [30/80], Step [200/500] Loss: 0.3125 | Accuracy: 

Epoch [53/80], Step [300/500] Loss: 0.2101 | Accuracy: 94.5767
Epoch [53/80], Step [400/500] Loss: 0.3069 | Accuracy: 94.4025
Epoch [53/80], Step [500/500] Loss: 0.1309 | Accuracy: 94.3820
Epoch [54/80], Step [100/500] Loss: 0.1677 | Accuracy: 94.1000
Epoch [54/80], Step [200/500] Loss: 0.1257 | Accuracy: 94.4700
Epoch [54/80], Step [300/500] Loss: 0.1119 | Accuracy: 94.4300
Epoch [54/80], Step [400/500] Loss: 0.1719 | Accuracy: 94.4100
Epoch [54/80], Step [500/500] Loss: 0.2251 | Accuracy: 94.3660
Epoch [55/80], Step [100/500] Loss: 0.1645 | Accuracy: 94.8700
Epoch [55/80], Step [200/500] Loss: 0.0937 | Accuracy: 94.6350
Epoch [55/80], Step [300/500] Loss: 0.0930 | Accuracy: 94.7567
Epoch [55/80], Step [400/500] Loss: 0.1391 | Accuracy: 94.7425
Epoch [55/80], Step [500/500] Loss: 0.1123 | Accuracy: 94.6660
Epoch [56/80], Step [100/500] Loss: 0.1103 | Accuracy: 94.6300
Epoch [56/80], Step [200/500] Loss: 0.1914 | Accuracy: 94.7850
Epoch [56/80], Step [300/500] Loss: 0.1283 | Accuracy: 

Epoch [79/80], Step [400/500] Loss: 0.2723 | Accuracy: 95.2600
Epoch [79/80], Step [500/500] Loss: 0.1773 | Accuracy: 95.2600
Epoch [80/80], Step [100/500] Loss: 0.1658 | Accuracy: 95.4300
Epoch [80/80], Step [200/500] Loss: 0.1405 | Accuracy: 95.1950
Epoch [80/80], Step [300/500] Loss: 0.1858 | Accuracy: 95.2733
Epoch [80/80], Step [400/500] Loss: 0.1196 | Accuracy: 95.1800
Epoch [80/80], Step [500/500] Loss: 0.1034 | Accuracy: 95.1940
