In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

# from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
from torchsummary import summary

import torchvision
import torchvision.transforms as transforms

from tqdm import tqdm

from Model import *

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.RandomVerticalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

In [4]:
train_dataset = torchvision.datasets.CIFAR10(root='./Dataset', train=True, download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True, num_workers=2)

Files already downloaded and verified


In [5]:
test_dataset = torchvision.datasets.CIFAR10(root='./Dataset', train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=2)

Files already downloaded and verified


In [6]:
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [7]:
net = VGG('VGG11')
# net = SimpleDLA()
# net = MobileNetV2()
# net = RegNetX_200MF()

In [8]:
net = net.to(device)
summary(net, input_size=(3,32, 32))

if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,792
       BatchNorm2d-2           [-1, 64, 32, 32]             128
              ReLU-3           [-1, 64, 32, 32]               0
         MaxPool2d-4           [-1, 64, 16, 16]               0
            Conv2d-5          [-1, 128, 16, 16]          73,856
       BatchNorm2d-6          [-1, 128, 16, 16]             256
              ReLU-7          [-1, 128, 16, 16]               0
         MaxPool2d-8            [-1, 128, 8, 8]               0
            Conv2d-9            [-1, 256, 8, 8]         295,168
      BatchNorm2d-10            [-1, 256, 8, 8]             512
             ReLU-11            [-1, 256, 8, 8]               0
           Conv2d-12            [-1, 256, 8, 8]         590,080
      BatchNorm2d-13            [-1, 256, 8, 8]             512
             ReLU-14            [-1, 25

In [9]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

In [10]:
def valid(epoch):
    net.eval()
    loss_list, p_list, y_list = [], [], []
    for idx, (x, y) in enumerate(test_loader):
        x, y = x.to(device), y.to(device)
        p = net(x)

        loss = criterion(p, y)
        loss_list.append(loss.item())

        _, p = p.max(axis=-1)
        y_list.append(y.detach().cpu().numpy())
        p_list.append(p.detach().cpu().numpy())
    
    loss = np.array(loss_list).mean()
    acc = (np.array(p_list) == np.array(y_list)).astype(np.float32).mean()
    
    return loss, acc

In [11]:
def train(epoch):
    net.train()        
    loss_list, p_list, y_list = [], [], []
    with tqdm((enumerate(train_loader)), desc='epoch%3d'%epoch, total=len(train_loader), ncols=0) as t:
        for idx, (x, y) in t:
            x, y = x.to(device), y.to(device)
            p = net(x)
            
            optimizer.zero_grad()
            loss = criterion(p, y)
            loss.backward()
            optimizer.step()

            loss_list.append(loss.item())

            _, p = p.max(axis=-1)
            y_list.append(y.detach().cpu().numpy())
            p_list.append(p.detach().cpu().numpy())
            if idx+1 < len(train_loader):
                t.set_postfix({'loss':'%0.4f'%loss_list[-1]})
            else:
                loss = np.array(loss_list).mean()
                acc = (np.array(p_list) == np.array(y_list)).astype(np.float32).mean()
                
                valid_loss, valid_acc = valid(epoch)
                t.set_postfix({'loss':'%0.4f'%loss, 'acc':'%0.4f'%acc, 'valid_loss':'%0.4f'%valid_loss, 'valid_acc':'%0.4f'%valid_acc})
        return loss, acc, valid_loss, valid_acc

In [12]:
writer = SummaryWriter()
# writer.add_graph(net, input_to_model=torch.from_numpy(np.random.randn(2,3,32,32).astype(np.float32)).to(device), verbose=False)
for epoch in range(200):
    train_loss, train_acc, valid_loss, valid_acc = train(epoch)
    writer.add_scalar('loss/train', train_loss, epoch)
    writer.add_scalar('acc/train', train_acc, epoch)
    writer.add_scalar('loss/valid', valid_loss, epoch)
    writer.add_scalar('acc/valid', valid_acc, epoch)

epoch  0: 100% 500/500 [00:14<00:00, 33.58it/s, loss=1.5168, acc=0.4424, valid_loss=1.2483, valid_acc=0.5423]
epoch  1: 100% 500/500 [00:13<00:00, 36.71it/s, loss=1.2041, acc=0.5642, valid_loss=1.1817, valid_acc=0.5824]
epoch  2: 100% 500/500 [00:14<00:00, 34.24it/s, loss=1.0572, acc=0.6181, valid_loss=1.0585, valid_acc=0.6154]
epoch  3: 100% 500/500 [00:14<00:00, 34.21it/s, loss=0.9611, acc=0.6544, valid_loss=0.9158, valid_acc=0.6774]
epoch  4: 100% 500/500 [00:15<00:00, 33.14it/s, loss=0.8959, acc=0.6818, valid_loss=0.9954, valid_acc=0.6436]
epoch  5: 100% 500/500 [00:14<00:00, 34.12it/s, loss=0.8379, acc=0.7019, valid_loss=0.7840, valid_acc=0.7219]
epoch  6: 100% 500/500 [00:14<00:00, 34.10it/s, loss=0.7922, acc=0.7205, valid_loss=0.7820, valid_acc=0.7213]
epoch  7: 100% 500/500 [00:14<00:00, 34.38it/s, loss=0.7486, acc=0.7345, valid_loss=0.8195, valid_acc=0.7139]
epoch  8: 100% 500/500 [00:14<00:00, 33.68it/s, loss=0.7206, acc=0.7440, valid_loss=0.7387, valid_acc=0.7394]
epoch  9: 

epoch 74: 100% 500/500 [00:15<00:00, 33.10it/s, loss=0.1511, acc=0.9473, valid_loss=0.5187, valid_acc=0.8544]
epoch 75: 100% 500/500 [00:15<00:00, 33.19it/s, loss=0.1484, acc=0.9474, valid_loss=0.5143, valid_acc=0.8509]
epoch 76: 100% 500/500 [00:13<00:00, 35.74it/s, loss=0.1481, acc=0.9463, valid_loss=0.4812, valid_acc=0.8565]
epoch 77: 100% 500/500 [00:14<00:00, 33.46it/s, loss=0.1474, acc=0.9473, valid_loss=0.5156, valid_acc=0.8535]
epoch 78: 100% 500/500 [00:15<00:00, 33.16it/s, loss=0.1447, acc=0.9488, valid_loss=0.4992, valid_acc=0.8562]
epoch 79: 100% 500/500 [00:15<00:00, 33.07it/s, loss=0.1431, acc=0.9495, valid_loss=0.5233, valid_acc=0.8486]
epoch 80: 100% 500/500 [00:15<00:00, 33.16it/s, loss=0.1396, acc=0.9505, valid_loss=0.5090, valid_acc=0.8545]
epoch 81: 100% 500/500 [00:14<00:00, 33.54it/s, loss=0.1342, acc=0.9531, valid_loss=0.4811, valid_acc=0.8576]
epoch 82: 100% 500/500 [00:15<00:00, 33.07it/s, loss=0.1336, acc=0.9521, valid_loss=0.4998, valid_acc=0.8590]
epoch 83: 

epoch148: 100% 500/500 [00:14<00:00, 35.70it/s, loss=0.0523, acc=0.9824, valid_loss=0.5142, valid_acc=0.8693]
epoch149: 100% 500/500 [00:14<00:00, 35.64it/s, loss=0.0539, acc=0.9819, valid_loss=0.5503, valid_acc=0.8645]
epoch150: 100% 500/500 [00:14<00:00, 35.65it/s, loss=0.0521, acc=0.9825, valid_loss=0.5114, valid_acc=0.8702]
epoch151: 100% 500/500 [00:14<00:00, 35.34it/s, loss=0.0499, acc=0.9834, valid_loss=0.5217, valid_acc=0.8699]
epoch152: 100% 500/500 [00:14<00:00, 35.53it/s, loss=0.0510, acc=0.9834, valid_loss=0.5195, valid_acc=0.8741]
epoch153: 100% 500/500 [00:13<00:00, 36.49it/s, loss=0.0500, acc=0.9826, valid_loss=0.5579, valid_acc=0.8659]
epoch154: 100% 500/500 [00:14<00:00, 33.36it/s, loss=0.0497, acc=0.9833, valid_loss=0.5078, valid_acc=0.8722]
epoch155: 100% 500/500 [00:14<00:00, 34.25it/s, loss=0.0521, acc=0.9822, valid_loss=0.5103, valid_acc=0.8700]
epoch156: 100% 500/500 [00:14<00:00, 33.84it/s, loss=0.0483, acc=0.9841, valid_loss=0.5295, valid_acc=0.8674]
epoch157: 