In [1]:
import torch
from torchvision.datasets import CIFAR10
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchsummary import summary
from tensorboardX import SummaryWriter
from resnet import ResNet

In [2]:
#model = ResNet([3, 3, 3], 10).to('cuda') # resnet20
model = ResNet([5, 5, 5], 10).to('cuda') # resnet32
#model = ResNet([7, 7, 7], 10) # resnet44
#model = ResNet([9, 9, 9], 10) # resnet56
#model = ResNet([18, 18, 18], 10) # resnet110

In [3]:
summary(model, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 32, 32]             432
       BatchNorm2d-2           [-1, 16, 32, 32]              32
              ReLU-3           [-1, 16, 32, 32]               0
            Conv2d-4           [-1, 16, 32, 32]           2,304
       BatchNorm2d-5           [-1, 16, 32, 32]              32
              ReLU-6           [-1, 16, 32, 32]               0
            Conv2d-7           [-1, 16, 32, 32]           2,304
       BatchNorm2d-8           [-1, 16, 32, 32]              32
              ReLU-9           [-1, 16, 32, 32]               0
    ResidualBlock-10           [-1, 16, 32, 32]               0
           Conv2d-11           [-1, 16, 32, 32]           2,304
      BatchNorm2d-12           [-1, 16, 32, 32]              32
             ReLU-13           [-1, 16, 32, 32]               0
           Conv2d-14           [-1, 16,

In [4]:
train_cifar10 = CIFAR10(root = '../datasets/cifar10', train=True, download=True)
val_cifar10 = CIFAR10(root = '../datasets/cifar10', train=False, download=True)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
train_mean = train_cifar10.data.mean(axis=(0,1,2)) / 255
train_std = train_cifar10.data.std(axis=(0,1,2)) / 255

In [6]:
train_transforms = transforms.Compose([
                                        #transforms.Resize(224),
                                        transforms.RandomCrop(32, padding=4),
                                        transforms.RandomHorizontalFlip(),
                                        transforms.ToTensor(),
                                        transforms.Normalize(train_mean, train_std),                                      
                                        ])

val_transforms = transforms.Compose([
                                        transforms.ToTensor(),
                                        #transforms.Resize(224),
                                        transforms.Normalize(train_mean, train_std)
                                        ])

train_cifar10.transform = train_transforms
val_cifar10.transform = train_transforms

train_dl = DataLoader(train_cifar10, batch_size=256, shuffle=True, num_workers=4)
val_dl = DataLoader(val_cifar10, batch_size=128, shuffle=True, num_workers=4)





In [7]:
def count_correct(output, target):
    pred = output.argmax(1, keepdim=True)
    corrects = pred.eq(target.view_as(pred)).sum().item()
    return corrects

In [8]:
train_losses = []
val_losses = []
train_accs = []
val_accs = []

In [9]:
writer = SummaryWriter('resnet_logs')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                      momentum=0.9, weight_decay=1e-4)

device = 'cuda'
model.to(device)
epochs = 100
loss_func = nn.CrossEntropyLoss()
decay_epoch = [32000, 48000]
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=decay_epoch, gamma=0.1)
best_model = None
best_accs = -1
for _ in tqdm(range(epochs)):
    global_loss = 0
    corrects = 0
    model.train()
    for batch_idx, (data, target) in enumerate(train_dl):
        data, target = data.to(device), target.to(device)
        
        output = model(data)
        loss = loss_func(output, target)
        global_loss = global_loss + loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        corrects += count_correct(output, target)

    train_losses.append(global_loss / (batch_idx + 1))
    train_accs.append(corrects / len(train_cifar10) * 100)
    
    model.eval()
    corrects = 0
    global_loss = 0
    for batch_idx, (data, target) in enumerate(val_dl):
        data, target = data.to(device), target.to(device)
        with torch.no_grad():
            output = model(data)
            loss = loss_func(output, target)
            global_loss = global_loss + loss.item()
            corrects += count_correct(output, target)

    val_losses.append(global_loss / (batch_idx + 1))
    val_accs.append(corrects / len(val_cifar10) * 100)
    
    writer.add_scalar('resnet_log/train_error', 100 - train_accs[-1], _ + 1)
    writer.add_scalar('resnet_log/validation_error', 100 - val_accs[-1], _ + 1)
    
    if (_ + 1) % 10 == 0:
        print("Epoch %d | train_loss = %.2f |  train_acc = %.2f | val_loss = %.2f | val_acc = %.2f" % (_ + 1, train_losses[-1], train_accs[-1], val_losses[-1], val_accs[-1]))

 10%|█         | 10/100 [03:00<26:43, 17.82s/it]

Epoch 10 | train_loss = 0.51 |  train_acc = 82.58 | val_loss = 0.65 | val_acc = 77.71


 20%|██        | 20/100 [05:59<24:09, 18.11s/it]

Epoch 20 | train_loss = 0.34 |  train_acc = 88.31 | val_loss = 0.52 | val_acc = 82.32


 30%|███       | 30/100 [08:59<20:49, 17.85s/it]

Epoch 30 | train_loss = 0.26 |  train_acc = 90.94 | val_loss = 0.47 | val_acc = 84.68


 40%|████      | 40/100 [11:58<17:50, 17.85s/it]

Epoch 40 | train_loss = 0.22 |  train_acc = 92.12 | val_loss = 0.51 | val_acc = 84.12


 50%|█████     | 50/100 [14:57<14:52, 17.84s/it]

Epoch 50 | train_loss = 0.19 |  train_acc = 93.38 | val_loss = 0.50 | val_acc = 84.72


 60%|██████    | 60/100 [17:56<11:57, 17.95s/it]

Epoch 60 | train_loss = 0.18 |  train_acc = 93.67 | val_loss = 0.44 | val_acc = 86.55


 70%|███████   | 70/100 [20:55<08:55, 17.85s/it]

Epoch 70 | train_loss = 0.17 |  train_acc = 93.97 | val_loss = 0.43 | val_acc = 86.59


 80%|████████  | 80/100 [23:54<05:57, 17.86s/it]

Epoch 80 | train_loss = 0.15 |  train_acc = 94.58 | val_loss = 0.57 | val_acc = 82.97


 90%|█████████ | 90/100 [26:53<02:58, 17.86s/it]

Epoch 90 | train_loss = 0.14 |  train_acc = 94.94 | val_loss = 0.49 | val_acc = 85.67


100%|██████████| 100/100 [29:51<00:00, 17.92s/it]

Epoch 100 | train_loss = 0.14 |  train_acc = 94.98 | val_loss = 0.45 | val_acc = 87.39



