In [1]:
import resnet.models as models

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import argparse
import numpy as np
import os
import shutil
from torch.utils.tensorboard import SummaryWriter


  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
# ---------------------------------------------------------------------------
# Experiment configuration: every value a reader may want to tune lives here.
# ---------------------------------------------------------------------------
arch = "resnet"          # name of the constructor looked up in `resnet.models`
depth = 56               # passed to the model constructor as `depth`
cuda = torch.cuda.is_available()
seed = 1
save = "./logs"          # directory that receives the checkpoints
dataset = "cifar10"
batch_size = 500
test_batch_size = 500
# Extra DataLoader options are only passed when a GPU is available.
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
lr = 0.1
momentum=0.9
weight_decay=1e-4
log_interval=50          # print a progress line every `log_interval` batches
start_epoch = 0
epochs=160

# `seed` was previously defined but never applied. Seed the torch RNGs here,
# before the model and the data loaders are created, so that weight
# initialisation, shuffling and augmentation are repeatable across
# "Restart & Run All".
torch.manual_seed(seed)
if cuda:
    torch.cuda.manual_seed(seed)

In [3]:
# Create the checkpoint directory. `exist_ok=True` makes the cell idempotent
# and avoids the check-then-create race of `if not exists: makedirs`.
os.makedirs(save, exist_ok=True)

In [4]:
# Build the train / test DataLoaders for the configured dataset.
if dataset == "cifar10":
    # Normalisation constants shared by both splits (defined once so the two
    # pipelines cannot drift apart).
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=True, download=True,
                       transform=transforms.Compose([
                           # Augmentation: pad by 4 px, take a random 32x32
                           # crop, then flip horizontally at random.
                           transforms.Pad(4),
                           transforms.RandomCrop(32),
                           transforms.RandomHorizontalFlip(),
                           transforms.ToTensor(),
                           normalize,
                       ])),
        batch_size=batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           normalize,
                       ])),
        # Evaluation only sums loss / correct counts over the whole split, so
        # sample order is irrelevant; keep it fixed instead of shuffling.
        batch_size=test_batch_size, shuffle=False, **kwargs)
else:
    # Fail here rather than with a NameError on `train_loader` further down.
    raise ValueError("Unsupported dataset: {!r} (only 'cifar10' is handled)".format(dataset))

Files already downloaded and verified


In [5]:
# Instantiate the network: the constructor is looked up by name (`arch`) in the
# namespace of the project-local `resnet.models` module.
model = models.__dict__[arch](depth=depth, dataset=dataset)
if cuda:
    model.cuda()

# Plain SGD with momentum and weight decay, hyper-parameters from the config cell.
optimizer = optim.SGD(
    model.parameters(),
    lr=lr,
    momentum=momentum,
    weight_decay=weight_decay,
)

# TensorBoard logger; no log_dir is given, so the library's default location is used.
writer = SummaryWriter()


In [6]:

def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Relies on the notebook globals ``model``, ``optimizer``, ``train_loader``,
    ``writer``, ``cuda`` and ``log_interval``.

    Scalars written to TensorBoard:
      * ``Loss/train``     - cross-entropy of the current batch.
      * ``Accuracy/train`` - running accuracy (fraction in [0, 1]) over the
        samples seen so far in this epoch.
    Both use the global step ``epoch * len(train_loader) + batch_idx`` so that
    successive epochs extend the curve instead of overwriting steps 0..N-1.

    Args:
        epoch: Index of the current epoch; used for the progress line and for
            the TensorBoard global step.

    Returns:
        None.
    """
    # set as train mode
    model.train()
    batches_per_epoch = len(train_loader)
    correct_so_far = 0
    samples_seen = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        if cuda:
            data, target = data.cuda(), target.cuda()

        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        # Keep the bookkeeping in plain Python numbers so no tensors (or
        # autograd history) are carried across iterations.
        batch_loss = loss.item()
        pred = output.argmax(dim=1)
        correct_so_far += (pred == target).sum().item()
        samples_seen += target.size(0)

        # log tensorboard
        global_step = epoch * batches_per_epoch + batch_idx
        writer.add_scalar('Accuracy/train', correct_so_far / float(samples_seen), global_step)
        writer.add_scalar('Loss/train', batch_loss, global_step)

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / batches_per_epoch, batch_loss))
    

In [7]:
def test(test_num):
    model.eval()
    test_loss = 0
    correct = 0
    for batch_idx, (data, target) in enumerate(test_loader):
        if cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        test_loss += F.cross_entropy(output, target, size_average=False).data # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()        
    test_loss /= len(test_loader.dataset)
    
    # log to tensorboard
    writer.add_scalar('Accuracy/test', correct, test_num)
    writer.add_scalar('Loss/test',test_loss , test_num)
    
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))


In [8]:
def save_checkpoint(state, is_best, filepath):
    """Persist ``state`` and optionally mark it as the best model so far.

    Args:
        state: Object handed to ``torch.save``.
        is_best: When truthy, the freshly written checkpoint is also copied
            to ``model_best.pth.tar`` in the same directory.
        filepath: Directory in which the checkpoint files are written.
    """
    latest_path = os.path.join(filepath, 'checkpoint.pth.tar')
    torch.save(state, latest_path)
    if not is_best:
        return
    best_path = os.path.join(filepath, 'model_best.pth.tar')
    shutil.copyfile(latest_path, best_path)

In [9]:
# Main loop: train for one epoch, evaluate, checkpoint, repeat.
best_prec1 = 0.
# Epochs at which the learning rate is multiplied by 0.1 (50% and 75% of the
# schedule). Converted to int so that the membership test also matches when
# `epochs * 0.5` / `epochs * 0.75` is not a whole number.
lr_milestones = {int(epochs * 0.5), int(epochs * 0.75)}
for epoch in range(start_epoch, epochs):
    if epoch in lr_milestones:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= 0.1

    train(epoch)
    # Use the epoch index as the TensorBoard step. The previous constant 0
    # made every evaluation land on the same step.
    prec1 = test(epoch)

    is_best = prec1 > best_prec1
    best_prec1 = max(prec1, best_prec1)
    save_checkpoint({
        'epoch': epoch + 1,  # epoch to resume from
        'state_dict': model.state_dict(),
        'best_prec1': best_prec1,
        'optimizer': optimizer.state_dict(),
        # NOTE(review): `cfg` comes from the project-local model class;
        # presumably the layer configuration needed to rebuild the network -
        # confirm against resnet.models.
        'cfg': model.cfg
    }, is_best, filepath=save)



  



Test set: Average loss: 2.0031, Accuracy: 2405/10000 (24.0%)


Test set: Average loss: 1.8753, Accuracy: 2708/10000 (27.0%)


Test set: Average loss: 1.7043, Accuracy: 3475/10000 (34.0%)


Test set: Average loss: 1.6365, Accuracy: 3840/10000 (38.0%)


Test set: Average loss: 1.6220, Accuracy: 4063/10000 (40.0%)


Test set: Average loss: 1.5709, Accuracy: 4447/10000 (44.0%)


Test set: Average loss: 1.4169, Accuracy: 5156/10000 (51.0%)



KeyboardInterrupt: 