In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import mxnet as mx
from mxnet import gluon as gl
from mxnet import nd
import os

  import OpenSSL.SSL


In [3]:
import os
import json

In [4]:
# Root directory of the images; both the train and valid DogDataSet below read from it.
im_path = './data/train/'

In [5]:
def transform_train(img, min_size=256, max_size=480, crop_size=224):
    '''Training-time augmentation for a single image.

    Parameters
    ----------
    img : NDArray
        Image from ``mx.image.imread``; expected in HWC layout (it is
        transposed HWC -> CHW at the end). Assumed uint8 in [0, 255] -- TODO confirm.
    min_size, max_size : int
        Inclusive bounds of the random ``resize`` value passed to
        ``CreateAugmenter`` (scale augmentation). Defaults give [256, 480].
    crop_size : int
        Side of the square random crop fed to the network.

    Returns
    -------
    NDArray
        Normalized float32 image in CHW layout, shape (3, crop_size, crop_size).
    '''
    img = img.astype('float32') / 255
    # Draw the resize target uniformly from [min_size, max_size] *inclusive*;
    # randint's upper bound is exclusive, hence the +1.
    random_shape = int(np.random.randint(min_size, max_size + 1))
    aug_list = mx.image.CreateAugmenter(
        data_shape=(3, crop_size, crop_size), resize=random_shape,
        rand_mirror=True, rand_crop=True,
        # mean/std are on the [0, 1] scale because img was divided by 255 above.
        mean=np.array([0.4736, 0.4504, 0.3909]),
        std=np.array([0.2655, 0.2607, 0.2650]))

    for aug in aug_list:
        img = aug(img)
    # Augmenters work on HWC images; the network consumes CHW.
    img = nd.transpose(img, (2, 0, 1))
    return img

In [6]:
def transform_valid(img, resize=256, crop_size=224):
    '''Deterministic evaluation transform for a single image.

    The image is first resized via ``CreateAugmenter(resize=...)`` and then
    center-cropped to ``crop_size``. Training draws its resize value from
    [256, 480] before cropping 224, so resizing here keeps validation images
    at a comparable scale. Without it, the center crop would be a small patch
    of the full-resolution image. Pass ``resize=0`` to skip the resize
    (CreateAugmenter only adds a resize step when ``resize > 0``).

    Parameters
    ----------
    img : NDArray
        Image from ``mx.image.imread`` in HWC layout (assumed uint8 -- TODO confirm).
    resize : int
        Resize target applied before cropping.
    crop_size : int
        Side of the square center crop.

    Returns
    -------
    NDArray
        Normalized float32 image in CHW layout, shape (3, crop_size, crop_size).
    '''
    img = img.astype('float32') / 255.
    aug_list = mx.image.CreateAugmenter(
        data_shape=(3, crop_size, crop_size), resize=resize,
        # Same [0, 1]-scale statistics as the training transform.
        mean=np.array([0.4736, 0.4504, 0.3909]),
        std=np.array([0.2655, 0.2607, 0.2650]))

    for aug in aug_list:
        img = aug(img)
    # HWC -> CHW for the network.
    img = nd.transpose(img, (2, 0, 1))
    return img

## Load the train / validation splits with Gluon's DataLoader

In [7]:
from data_utils import DogDataSet

In [8]:
# Training split: entries listed in train.json (format defined by DogDataSet),
# augmented on the fly by transform_train.
train_json = './data/train.json'
train_set = DogDataSet(train_json, im_path, transform_train)
train_data = gl.data.DataLoader(
    train_set,
    batch_size=64,
    shuffle=True,
    last_batch='keep',  # keep the final, possibly smaller, batch
)

In [9]:
# Validation split: entries listed in valid.json, deterministic transform, fixed order.
valid_json = './data/valid.json'
valid_set = DogDataSet(valid_json, im_path, transform_valid)
valid_data = gl.data.DataLoader(
    valid_set,
    batch_size=128,
    shuffle=False,
    last_batch='keep',  # evaluate every sample, including the tail batch
)

In [10]:
# Softmax cross-entropy on raw logits; labels are class indices (sparse, the default).
criterion = gl.loss.SoftmaxCrossEntropyLoss(sparse_label=True)

In [11]:
# Training hyper-parameters.
# NOTE: train() moves each batch with `as_in_context(ctx)`, so ctx must be a
# single device; a list of GPUs would need per-device splitting of batches.
ctx = mx.gpu(0)
num_epochs = 200
lr = 0.1        # initial SGD learning rate
wd = 1e-4       # weight decay passed to the optimizer
lr_decay = 0.1  # multiplicative factor used by train() when decaying lr

In [12]:
# ResNet-50 v2 from the Gluon model zoo with a 120-way output layer; no
# pretrained weights are requested, so parameters start from Xavier init.
net = gl.model_zoo.vision.resnet50_v2(classes=120)
net.initialize(mx.init.Xavier(), ctx=ctx)
# Convert to a symbolic graph for faster execution.
net.hybridize()

In [None]:
from tensorboardX import SummaryWriter

In [None]:
import datetime
# TensorBoard writer used by train() to log per-epoch loss/accuracy.
# No log dir is given, so tensorboardX's default location is used -- confirm for the installed version.
writer = SummaryWriter()

def get_acc(output, label):
    '''Count correct top-1 predictions in a batch.

    Parameters
    ----------
    output : NDArray
        Per-class scores; the predicted class is the argmax over axis 1
        (assumed shape (batch, num_classes) -- TODO confirm).
    label : NDArray
        Ground-truth class index per sample, any numeric dtype.

    Returns
    -------
    scalar
        Number of correct predictions (a count, not a ratio).
    '''
    pred = output.argmax(1)
    # argmax returns the scores' float dtype while labels may be integers
    # (e.g. int64 from default batchify). MXNet elementwise comparison requires
    # matching dtypes, so align the label to the prediction before comparing.
    correct = (pred == label.astype(pred.dtype)).sum()
    return correct.asscalar()

def _format_elapsed(delta):
    '''Format a timedelta as "Time HH:MM:SS" (hours keep counting past 24).'''
    h, remainder = divmod(int(delta.total_seconds()), 3600)
    m, s = divmod(remainder, 60)
    return "Time %02d:%02d:%02d" % (h, m, s)


def _evaluate(net, data_iter, ctx):
    '''Return (mean loss, accuracy) of `net` over `data_iter`; no gradients recorded.'''
    total_loss = 0.
    total_correct = 0
    total = 0
    for data, label in data_iter:
        bs = data.shape[0]
        data = data.as_in_context(ctx)
        label = label.as_in_context(ctx)
        output = net(data)
        total_loss += criterion(output, label).sum().asscalar()
        total_correct += get_acc(output, label)
        total += bs
    return total_loss / total, total_correct / total


def train(net, train_data, valid_data, num_epochs, lr, wd, ctx, lr_decay,
          lr_steps=(89, 159)):
    '''Train `net` with momentum SGD and a step learning-rate schedule.

    Parameters
    ----------
    net : gluon Block, already initialized on `ctx`.
    train_data : iterable of (data, label) batches used for training.
    valid_data : iterable of (data, label) batches, or None to skip validation.
    num_epochs : int, number of passes over `train_data`.
    lr : float, initial learning rate.
    wd : float, weight decay for the SGD optimizer.
    ctx : a single mx.Context; every batch is moved there with `as_in_context`.
    lr_decay : float, factor the learning rate is multiplied by at each step.
    lr_steps : iterable of 0-based epoch indices at whose start the learning
        rate is decayed. The default (89, 159) decays before the 90th and
        160th epochs.

    Reads the module-level `criterion`, `get_acc` and `writer` (tensorboardX),
    logs train/valid loss and accuracy per epoch, and prints a summary line.
    '''
    trainer = gl.Trainer(
        net.collect_params(), 'sgd', {'learning_rate': lr, 'momentum': 0.9, 'wd': wd})
    lr_steps = set(lr_steps)

    prev_time = datetime.datetime.now()
    for epoch in range(num_epochs):
        # Step schedule: set_learning_rate must be *called*; assigning to the
        # attribute would silently leave the optimizer's rate unchanged.
        if epoch in lr_steps:
            trainer.set_learning_rate(trainer.learning_rate * lr_decay)
        train_loss = 0.
        correct = 0
        total = 0
        for data, label in train_data:
            bs = data.shape[0]
            data = data.as_in_context(ctx)
            label = label.as_in_context(ctx)
            with mx.autograd.record():
                output = net(data)
                loss = criterion(output, label)
            loss.backward()
            # Normalize by the actual batch size: the last batch may be smaller.
            trainer.step(bs)
            train_loss += loss.sum().asscalar()
            correct += get_acc(output, label)
            total += bs
        train_loss /= total
        train_acc = correct / total
        writer.add_scalars('loss', {'train': train_loss}, epoch)
        writer.add_scalars('acc', {'train': train_acc}, epoch)
        cur_time = datetime.datetime.now()
        time_str = _format_elapsed(cur_time - prev_time)
        if valid_data is not None:
            valid_loss, valid_acc = _evaluate(net, valid_data, ctx)
            writer.add_scalars('loss', {'valid': valid_loss}, epoch)
            writer.add_scalars('acc', {'valid': valid_acc}, epoch)
            epoch_str = ("Epoch %d. Train Loss: %f, Train acc %f, Valid Loss: %f, Valid acc %f, "
                         % (epoch, train_loss, train_acc, valid_loss, valid_acc))
        else:
            epoch_str = ("Epoch %d. Loss: %f, Train acc %f, "
                         % (epoch, train_loss, train_acc))
        prev_time = cur_time
        print(epoch_str + time_str + ', lr ' + str(trainer.learning_rate))

In [None]:
# Run the full training schedule configured above.
train(net, train_data, valid_data,
      num_epochs=num_epochs, lr=lr, wd=wd, ctx=ctx, lr_decay=lr_decay)

In [None]:
# Persist the trained weights. Newer MXNet releases deprecate save_params in
# favour of save_parameters -- TODO confirm against the installed version.
net.save_params(filename='./res50.params')