In [1]:
import sys
sys.path.append('../modules')

from model import Resnet, Bottleneck
from utils import TinyImageNetDataset

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import random
import numpy as np

In [2]:
# Param

batch_size = 128
num_epoch = 200

In [3]:
# CIFAR-10

# download dataset
transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.CIFAR10(root='./data/cifar10', train=True, download = True, transform=transform)

train_mean = torch.stack([data for data, label in trainset]).mean(dim=(0,2,3))
train_std = torch.stack([data for data, label in trainset]).std(dim=(0,2,3))

Files already downloaded and verified


In [5]:
transform_train = transforms.Compose([transforms.RandomCrop(32,padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(train_mean, train_std)])

In [21]:
# label_idx_list = {x : [] for x in range(10)}

# for i in range(len(trainset)):
#     label_idx_list[trainset[i][1]].append(i)
    
# val_set_idx = []
# train_set_idx = []

# for key in label_idx_list.keys():
#     samples = random.sample(label_idx_list[key], 1000)
#     val_set_idx += samples
#     train_set_idx += list(set(label_idx_list[key]) - set(samples))

# print(f'# Train Set : {len(train_set_idx)}\n# Test Set : {len(val_set_idx)}')

# np.save('cifar10_train_idx',train_set_idx)
# np.save('cifar10_val_idx', val_set_idx)

# Train Set : 40000
# Test Set : 10000


In [6]:
trainset = torchvision.datasets.CIFAR10(root='./data/cifar10', train=True, download = False, transform=transform_train)
train_idx = np.load('./cifar10_train_idx.npy')
train_random_sampler = torch.utils.data.SubsetRandomSampler(train_idx)

train_dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False,  sampler=train_random_sampler)
# train_dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

#classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [7]:
device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")

net = Resnet(Bottleneck, [3, 4, 6, 3]).to(device)
net = torch.nn.DataParallel(net, device_ids=[4,6,7]).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0001)

In [8]:
model_save_path = './model'
loss_history = []

net.train()

DataParallel(
  (module): Resnet(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential(
          (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_st

In [9]:
for epoch in range(num_epoch):
    epoch_loss = 0

    for batch_id, (train_x, train_y) in enumerate(tqdm(train_dataloader)):  
        optimizer.zero_grad()
        preds = net(train_x.to(device))

        loss = criterion(preds, train_y.to(device))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    loss_history.append(epoch_loss)

    print(f'epoch : {epoch}, total train loss : {epoch_loss}') 

    with open(f'{model_save_path}/model_epoch_{epoch}.pt','wb') as f:
        torch.save({
            'state' : net.state_dict(),
            'optimizer' : optimizer.state_dict(),
            'epoch' : epoch,
            'loss_history' : loss_history,
        },f)

100%|██████████| 313/313 [01:01<00:00,  5.06it/s]


epoch : 0, total train loss : 599.6908214092255


100%|██████████| 313/313 [00:53<00:00,  5.89it/s]


epoch : 1, total train loss : 485.86606335639954


100%|██████████| 313/313 [00:52<00:00,  5.95it/s]


epoch : 2, total train loss : 432.2628085613251


100%|██████████| 313/313 [00:51<00:00,  6.04it/s]


epoch : 3, total train loss : 384.65153735876083


100%|██████████| 313/313 [00:51<00:00,  6.03it/s]


epoch : 4, total train loss : 348.74815088510513


100%|██████████| 313/313 [00:51<00:00,  6.02it/s]


epoch : 5, total train loss : 314.3597184419632


100%|██████████| 313/313 [00:52<00:00,  6.01it/s]


epoch : 6, total train loss : 289.3993886113167


100%|██████████| 313/313 [00:51<00:00,  6.02it/s]


epoch : 7, total train loss : 263.8260056376457


100%|██████████| 313/313 [00:52<00:00,  6.01it/s]


epoch : 8, total train loss : 243.9902365207672


100%|██████████| 313/313 [00:52<00:00,  6.01it/s]


epoch : 9, total train loss : 222.6509529054165


100%|██████████| 313/313 [00:51<00:00,  6.02it/s]


epoch : 10, total train loss : 207.0540024638176


100%|██████████| 313/313 [00:51<00:00,  6.04it/s]


epoch : 11, total train loss : 192.16802725195885


100%|██████████| 313/313 [00:52<00:00,  6.01it/s]


epoch : 12, total train loss : 179.1833836734295


100%|██████████| 313/313 [00:52<00:00,  6.00it/s]


epoch : 13, total train loss : 167.22718358039856


100%|██████████| 313/313 [00:52<00:00,  6.01it/s]


epoch : 14, total train loss : 159.83264163136482


100%|██████████| 313/313 [00:52<00:00,  5.99it/s]


epoch : 15, total train loss : 146.84254305064678


100%|██████████| 313/313 [00:51<00:00,  6.03it/s]


epoch : 16, total train loss : 141.1550269126892


100%|██████████| 313/313 [00:52<00:00,  6.01it/s]


epoch : 17, total train loss : 132.33050473034382


100%|██████████| 313/313 [00:52<00:00,  6.00it/s]


epoch : 18, total train loss : 125.4722516387701


100%|██████████| 313/313 [00:51<00:00,  6.03it/s]


epoch : 19, total train loss : 116.34036043286324


100%|██████████| 313/313 [00:52<00:00,  6.00it/s]


epoch : 20, total train loss : 112.64769844710827


100%|██████████| 313/313 [00:52<00:00,  6.01it/s]


epoch : 21, total train loss : 104.87918137013912


100%|██████████| 313/313 [00:51<00:00,  6.03it/s]


epoch : 22, total train loss : 100.93453386425972


100%|██████████| 313/313 [00:52<00:00,  6.01it/s]


epoch : 23, total train loss : 95.62829455733299


100%|██████████| 313/313 [00:52<00:00,  6.00it/s]


epoch : 24, total train loss : 91.8063376545906


100%|██████████| 313/313 [00:51<00:00,  6.04it/s]


epoch : 25, total train loss : 87.98256629705429


100%|██████████| 313/313 [00:51<00:00,  6.03it/s]


epoch : 26, total train loss : 84.5850719884038


100%|██████████| 313/313 [00:52<00:00,  5.99it/s]


epoch : 27, total train loss : 78.64982509613037


100%|██████████| 313/313 [00:51<00:00,  6.03it/s]


epoch : 28, total train loss : 73.90622185915709


100%|██████████| 313/313 [00:52<00:00,  6.00it/s]


epoch : 29, total train loss : 71.56274396181107


100%|██████████| 313/313 [00:52<00:00,  6.02it/s]


epoch : 30, total train loss : 68.4054487273097


100%|██████████| 313/313 [00:52<00:00,  5.99it/s]


epoch : 31, total train loss : 64.19446571916342


 52%|█████▏    | 164/313 [00:27<00:24,  6.14it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 313/313 [00:52<00:00,  5.99it/s]


epoch : 53, total train loss : 26.849669942632318


100%|██████████| 313/313 [00:52<00:00,  6.02it/s]


epoch : 54, total train loss : 22.55321819614619


100%|██████████| 313/313 [00:52<00:00,  6.01it/s]


epoch : 55, total train loss : 23.274596710689366


100%|██████████| 313/313 [00:52<00:00,  6.00it/s]


epoch : 56, total train loss : 21.948526105843484


100%|██████████| 313/313 [00:52<00:00,  6.02it/s]


epoch : 57, total train loss : 21.549883437808603


100%|██████████| 313/313 [00:52<00:00,  6.01it/s]


epoch : 58, total train loss : 20.384147526696324


100%|██████████| 313/313 [00:52<00:00,  5.99it/s]


epoch : 59, total train loss : 19.159009934403002


100%|██████████| 313/313 [00:52<00:00,  6.02it/s]


epoch : 60, total train loss : 18.8939425656572


100%|██████████| 313/313 [00:52<00:00,  6.01it/s]


epoch : 61, total train loss : 19.176801815163344


 52%|█████▏    | 164/313 [00:27<00:25,  5.90it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 313/313 [00:52<00:00,  6.02it/s]


epoch : 83, total train loss : 8.778063133126125


100%|██████████| 313/313 [00:51<00:00,  6.03it/s]


epoch : 84, total train loss : 9.30507418117486


100%|██████████| 313/313 [00:52<00:00,  6.02it/s]


epoch : 85, total train loss : 9.473938275885303


100%|██████████| 313/313 [00:52<00:00,  6.00it/s]


epoch : 86, total train loss : 9.013363123871386


100%|██████████| 313/313 [00:52<00:00,  6.00it/s]


epoch : 87, total train loss : 9.157215929590166


100%|██████████| 313/313 [00:52<00:00,  6.02it/s]


epoch : 88, total train loss : 8.75183768174611


100%|██████████| 313/313 [00:52<00:00,  5.99it/s]


epoch : 89, total train loss : 9.202111589256674


100%|██████████| 313/313 [00:52<00:00,  6.00it/s]


epoch : 90, total train loss : 7.998049291782081


100%|██████████| 313/313 [00:52<00:00,  6.01it/s]


epoch : 91, total train loss : 8.350819754647091


100%|██████████| 313/313 [00:52<00:00,  6.01it/s]


epoch : 92, total train loss : 7.472903484944254


 14%|█▍        | 45/313 [00:07<00:45,  5.95it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 313/313 [00:51<00:00,  6.02it/s]


epoch : 112, total train loss : 5.83082763524726


100%|██████████| 313/313 [00:52<00:00,  5.99it/s]


epoch : 113, total train loss : 5.098803768167272


100%|██████████| 313/313 [00:52<00:00,  6.02it/s]


epoch : 114, total train loss : 5.285339080437552


100%|██████████| 313/313 [00:51<00:00,  6.03it/s]


epoch : 115, total train loss : 5.2942646611481905


100%|██████████| 313/313 [00:52<00:00,  6.00it/s]


epoch : 116, total train loss : 4.327461743843742


100%|██████████| 313/313 [00:52<00:00,  5.99it/s]


epoch : 117, total train loss : 5.4023542773211375


100%|██████████| 313/313 [00:52<00:00,  6.01it/s]


epoch : 118, total train loss : 4.5980479553109035


100%|██████████| 313/313 [00:52<00:00,  5.99it/s]


epoch : 119, total train loss : 4.512767135660397


100%|██████████| 313/313 [00:52<00:00,  6.00it/s]


epoch : 120, total train loss : 3.748794007988181


 59%|█████▉    | 185/313 [00:30<00:21,  6.01it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 313/313 [00:52<00:00,  5.98it/s]


epoch : 142, total train loss : 3.4626166458183434


100%|██████████| 313/313 [00:52<00:00,  6.00it/s]


epoch : 143, total train loss : 4.050154804368503


100%|██████████| 313/313 [00:52<00:00,  5.99it/s]


epoch : 144, total train loss : 3.8309942695486825


100%|██████████| 313/313 [00:51<00:00,  6.03it/s]


epoch : 145, total train loss : 3.1443053220573347


100%|██████████| 313/313 [00:52<00:00,  5.97it/s]


epoch : 146, total train loss : 2.404693734482862


100%|██████████| 313/313 [00:52<00:00,  6.01it/s]


epoch : 147, total train loss : 2.9965441881504375


100%|██████████| 313/313 [00:52<00:00,  6.01it/s]


epoch : 148, total train loss : 2.6972201926691923


100%|██████████| 313/313 [00:52<00:00,  5.99it/s]


epoch : 149, total train loss : 2.7998982444987632


100%|██████████| 313/313 [00:52<00:00,  6.00it/s]


epoch : 150, total train loss : 3.206867032131413


 54%|█████▍    | 170/313 [00:28<00:25,  5.63it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 313/313 [00:52<00:00,  6.02it/s]


epoch : 171, total train loss : 2.1982714216137538


100%|██████████| 313/313 [00:52<00:00,  6.01it/s]


epoch : 172, total train loss : 2.758232395965024


100%|██████████| 313/313 [00:52<00:00,  6.00it/s]


epoch : 173, total train loss : 2.1816074083617423


100%|██████████| 313/313 [00:52<00:00,  6.00it/s]


epoch : 174, total train loss : 1.9858320986968465


100%|██████████| 313/313 [00:52<00:00,  5.99it/s]


epoch : 175, total train loss : 1.675410138865118


100%|██████████| 313/313 [00:52<00:00,  6.01it/s]


epoch : 176, total train loss : 2.353847765305545


100%|██████████| 313/313 [00:52<00:00,  6.02it/s]


epoch : 177, total train loss : 2.3985839999077143


100%|██████████| 313/313 [00:52<00:00,  6.02it/s]


epoch : 178, total train loss : 2.1375532555684913


100%|██████████| 313/313 [00:51<00:00,  6.03it/s]


epoch : 179, total train loss : 2.0653303504514042


100%|██████████| 313/313 [00:52<00:00,  5.99it/s]


epoch : 180, total train loss : 2.341559735039482


 44%|████▍     | 138/313 [00:23<00:28,  6.05it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [7]:
train_mean

tensor([0.4914, 0.4822, 0.4465])

In [8]:
train_std

tensor([0.2470, 0.2435, 0.2616])