In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import pprint

import numpy as np
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms

from config import config
from config import extra
from function import train
from function import valid
from dataset import create_dataset
from models import create_model
from utils import create_optimizer
from utils import create_logger

In [2]:
# Switch the shared project config to training mode, then call extra(),
# which is imported from the project's config module.
# NOTE(review): extra() presumably derives mode-dependent settings (e.g. the
# output directory) from config.MODE -- confirm in config.py.
config.MODE = 'train'
extra()

# Logger for this run; record the full configuration up front so every log
# starts with the exact settings that produced it.
logger = create_logger('train')
logger.info(pprint.pformat(config))

# Apply the cuDNN flags from the config. `cudnn` is the alias under which
# torch.backends.cudnn was imported, so all three flags go through one name.
cudnn_cfg = config.CUDNN
cudnn.benchmark = cudnn_cfg.BENCHMARK
cudnn.deterministic = cudnn_cfg.DETERMINISTIC
cudnn.enabled = cudnn_cfg.ENABLED

{'ADV': {'LINF_NORM': 0.03, 'TYPE': 'FGSM'},
 'CUDNN': {'BENCHMARK': True, 'DETERMINISTIC': False, 'ENABLED': True},
 'DATASET': {'CLASSNUM': 10,
             'DATASET': 'CIFAR10',
             'IMAGESIZE': [32, 32],
             'ROOT': '/m/shibf/dataset/cifar10'},
 'GPUS': '0, 5, 6, 7, 8',
 'GPU_NUM': 5,
 'MODE': 'train',
 'MODEL': {'INPUT_DIM': 256, 'TYPE': 'ConvNet'},
 'OUTPUT_DIR': 'experiments/CIFAR10/train',
 'TEST': {'PRINT_EVERY': 1,
          'STATE_DICT': 'experiments/CIFAR10/train/checkpoint_0.7049.pth',
          'TEST_EVERY': 5},
 'TRAIN': {'BATCH_SIZE': 32,
           'BEGIN_EPOCH': 0,
           'CONJREG': 1,
           'END_EPOCH': 120,
           'IF_CONJREG': False,
           'IF_L1REG': False,
           'IF_SPECREG': True,
           'L1REG': 5e-05,
           'LR': 0.001,
           'LR_DECAY_RATE': 0.5,
           'LR_MILESTONES': [30, 60, 90],
           'MOMENTUM': 0.9,
           'NESTEROV': False,
           'OPTIMIZER': 'sgd',
           'PRINT_EVERY': 1,
 

In [3]:
# create a model
#
# Restrict this process to the GPUs listed in config.GPUS (a comma-separated
# string of device ids). This has to be in place before the first CUDA call,
# which in this cell is model.cuda().
os.environ["CUDA_VISIBLE_DEVICES"] = config.GPUS

# Parsing every entry with int() validates config.GPUS (a malformed entry
# raises ValueError here instead of failing later inside CUDA); only the
# number of entries is needed afterwards.
visible_gpu_count = len([int(i) for i in config.GPUS.split(',')])

# Once CUDA_VISIBLE_DEVICES is set, the selected devices are addressed by
# their position in that list (0..N-1), not by their physical ids, so
# DataParallel is handed the logical indices. A concrete list is used rather
# than a lazy range so `gpus` is an ordinary list on Python 2 and 3 alike.
gpus = list(range(visible_gpu_count))

model = create_model()

# Move the parameters to the first visible GPU, then wrap the model so it is
# run across all visible devices.
model = model.cuda(gpus[0])
model = torch.nn.DataParallel(model, device_ids=gpus)


In [4]:
# Optimizer built by the project helper from the config and the wrapped model.
optimizer = create_optimizer(config, model)

# Multi-step learning-rate schedule: at each epoch listed in
# config.TRAIN.LR_MILESTONES the learning rate is multiplied by
# config.TRAIN.LR_DECAY_RATE.
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=config.TRAIN.LR_MILESTONES,
    gamma=config.TRAIN.LR_DECAY_RATE,
)

# Datasets and their loaders for training and testing, as returned by the
# project's dataset factory.
datasets_and_loaders = create_dataset()
train_dataset, test_dataset, train_loader, test_loader = datasets_and_loaders

In [None]:
# Training and validation loop.
#
# Trains for epochs [config.TRAIN.BEGIN_EPOCH, config.TRAIN.END_EPOCH),
# validates every config.TEST.TEST_EVERY epochs, writes a checkpoint whenever
# the validation score improves, and finally saves the last model state.
best_perf = 0

# Seed perf_indicator so the final save below cannot raise NameError when no
# validation pass ran (empty epoch range, or fewer epochs than TEST_EVERY).
# In that case the final file is named with this initial value.
perf_indicator = best_perf

for epoch in range(config.TRAIN.BEGIN_EPOCH, config.TRAIN.END_EPOCH):
    # train for one epoch
    train(train_loader, model, optimizer, epoch)

    # Advance the LR schedule only after this epoch's optimizer updates.
    # Since PyTorch 1.1.0, calling scheduler.step() before optimizer.step()
    # skips the first value of the schedule, so every milestone would take
    # effect one epoch early.
    lr_scheduler.step()

    # evaluate on validation set
    if (epoch + 1) % config.TEST.TEST_EVERY == 0:
        # NOTE(review): valid() is assumed to return a "higher is better"
        # score -- confirm in function.py.
        perf_indicator = valid(test_loader, model)

        if perf_indicator > best_perf:
            checkpoint_path = os.path.join(
                config.OUTPUT_DIR, 'checkpoint_{}.pth'.format(perf_indicator))
            logger.info("=> saving checkpoint into {}".format(checkpoint_path))
            best_perf = perf_indicator
            # The state dict is taken from the DataParallel wrapper, as before,
            # so existing checkpoints and loaders stay compatible.
            torch.save(model.state_dict(), checkpoint_path)

# save the final model
# The file name carries the most recent validation score, which is not
# necessarily the best one (the best is in the checkpoint files above).
final_model_path = os.path.join(
    config.OUTPUT_DIR, 'model_{}.pth'.format(perf_indicator))
logger.info("=> saving final model into {}".format(final_model_path))
torch.save(model.state_dict(), final_model_path)

Epoch: [0][0/313]	Time 33.138s (33.138s)	Speed 4.8 samples/s	Data 0.277s (0.277s)	Loss 2.30505 (2.30505)	Accuracy 0.075 (0.075)
Epoch: [0][1/313]	Time 0.592s (16.865s)	Speed 270.3 samples/s	Data 0.022s (0.149s)	Loss 2.30848 (2.30676)	Accuracy 0.119 (0.097)
Epoch: [0][2/313]	Time 0.287s (11.339s)	Speed 557.6 samples/s	Data 0.002s (0.100s)	Loss 2.30932 (2.30761)	Accuracy 0.087 (0.094)
Epoch: [0][3/313]	Time 0.219s (8.559s)	Speed 731.5 samples/s	Data 0.002s (0.076s)	Loss 2.30485 (2.30692)	Accuracy 0.119 (0.100)
Epoch: [0][4/313]	Time 0.291s (6.905s)	Speed 548.9 samples/s	Data 0.018s (0.064s)	Loss 2.29621 (2.30478)	Accuracy 0.125 (0.105)
Epoch: [0][5/313]	Time 0.226s (5.792s)	Speed 708.6 samples/s	Data 0.011s (0.055s)	Loss 2.30191 (2.30430)	Accuracy 0.081 (0.101)
Epoch: [0][6/313]	Time 0.254s (5.001s)	Speed 630.9 samples/s	Data 0.003s (0.048s)	Loss 2.29610 (2.30313)	Accuracy 0.075 (0.097)
Epoch: [0][7/313]	Time 0.307s (4.414s)	Speed 520.5 samples/s	Data 0.011s (0.043s)	Loss 2.30026 (2.3027