In [1]:
import datetime
import json
import os

import mxnet as mx
import numpy as np
from mxnet import gluon as gl
from mxnet import nd
from mxnet.gluon import nn

from data_utils import DogDataSet

# Directory holding the image files. It is handed to DogDataSet for both the
# training and the validation split further down, so the two splits differ only
# in their JSON file and their transform.
im_path = './data/train/'


def transform_train(img):
    '''
    Randomised preprocessing applied to one training image.

    img is the mx.image.imread object. It is converted to float32 and scaled
    by 1/255, run through every augmenter returned by
    mx.image.CreateAugmenter (resize to a random size, random crop to
    299x299, random mirror, mean/std normalisation) and finally has its last
    axis moved to the front (HWC -> CHW for an imread image).
    '''
    scaled = img.astype('float32') / 255

    # np.random.uniform() lies in [0, 1), so after int() truncation the
    # resize target is an integer in 256..479.
    resize_to = int(np.random.uniform() * 224 + 256)

    # The normalisation constants are expressed on the 0..1 scale, which is
    # why the image is divided by 255 first.
    # Alternative statistics left commented out by the author:
    #   mean=(0.4736, 0.4504, 0.3909), std=(0.2655, 0.2607, 0.2650)
    augmenters = mx.image.CreateAugmenter(
        data_shape=(3, 299, 299),
        resize=resize_to,
        rand_mirror=True,
        rand_crop=True,
        mean=np.array([0.485, 0.456, 0.406]),
        std=np.array([0.229, 0.224, 0.225]))

    # The augmenter list is rebuilt on every call because the resize target
    # changes from image to image.
    for augment in augmenters:
        scaled = augment(scaled)

    return nd.transpose(scaled, (2, 0, 1))


def transform_valid(img):
    '''
    Deterministic preprocessing applied to one validation image.

    The image is converted to float32 and scaled by 1/255, then passed
    through the augmenters mx.image.CreateAugmenter returns for a 299x299
    target with mean/std normalisation. No resize, random crop or mirror is
    requested here, so CreateAugmenter's defaults apply for those. The last
    axis is finally moved to the front (presumably HWC -> CHW).
    '''
    pixels = img.astype('float32') / 255.

    # Same normalisation constants as the training transform.
    # Alternative statistics left commented out by the author:
    #   mean=(0.4736, 0.4504, 0.3909), std=(0.2655, 0.2607, 0.2650)
    pipeline = mx.image.CreateAugmenter(
        data_shape=(3, 299, 299),
        mean=np.array([0.485, 0.456, 0.406]),
        std=np.array([0.229, 0.224, 0.225]))

    for stage in pipeline:
        pixels = stage(pixels)

    return nd.transpose(pixels, (2, 0, 1))


# ## use DataLoader

def _make_split(json_path, transform, batch_size, shuffle):
    """Build the DogDataSet for one split and wrap it in a Gluon DataLoader.

    Returns the ``(dataset, loader)`` pair. ``last_batch='keep'`` is used for
    every split, so a final batch smaller than ``batch_size`` is still yielded.
    """
    dataset = DogDataSet(json_path, im_path, transform)
    loader = gl.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, last_batch='keep')
    return dataset, loader


# Training split: shuffled batches of 64 with the random augmentation.
train_json = './data/train.json'
train_set, train_data = _make_split(train_json, transform_train, 64, True)

# Validation split: fixed order, batches of 128, deterministic transform.
valid_json = './data/valid.json'
valid_set, valid_data = _make_split(valid_json, transform_valid, 128, False)

# Loss shared by the training and validation loops in train().
criterion = gl.loss.SoftmaxCrossEntropyLoss()

In [2]:
# Device used for the model and for every batch. Index 1 is hard-wired here;
# change it to match the GPU that should run the job on the current machine.
gpu_index = 1
ctx = mx.gpu(gpu_index)

In [3]:
# Start from the model-zoo Inception v3 with its pretrained weights placed on ctx.
net = gl.model_zoo.vision.inception_v3(pretrained=True, ctx=ctx)

# Freeze every pretrained parameter: with grad_req set to 'null' no gradient
# is requested for it, so only layers added after this loop are trained.
for _, param in net.collect_params().items():
    param.grad_req = 'null'

# Append a new trainable head (ReLU -> Dropout -> 120-way Dense) after the
# existing classifier layers.
with net.name_scope():
    net.classifier.add(nn.Activation('relu'), nn.Dropout(0.5), nn.Dense(120))

# Only the new Dense layer has uninitialised parameters. It was appended last,
# so address it as [-1] instead of a hard-coded position that depends on how
# many layers the model-zoo classifier happens to contain.
net.classifier[-1].initialize(init=mx.init.Xavier(), ctx=ctx)

net.hybridize()

In [4]:
def get_acc(output, label):
    '''
    Count the correct predictions in one batch.

    output holds one row of class scores per sample and label holds the true
    class index of each sample. Despite the name, the return value is the
    number of correct predictions as a Python scalar, not a ratio; callers
    divide by the number of samples themselves.
    '''
    pred = output.argmax(1)
    # argmax returns the indices in its own (floating point) dtype, while the
    # label batch may arrive as an integer array. Comparing arrays of
    # different dtypes is rejected by MXNet 1.x element-wise operators, so the
    # label is cast to the prediction dtype before the comparison.
    matches = pred == label.astype(pred.dtype)
    return matches.sum().asscalar()


def train(net, train_data, valid_data, num_epochs, lr, wd, ctx, lr_period,
          lr_decay):
    '''
    Train net with SGD (momentum 0.9) and print one summary line per epoch.

    net:        network to optimise; net.collect_params() is handed to the
                trainer.
    train_data: iterable yielding (data, label) batches for training.
    valid_data: iterable yielding (data, label) batches for validation, or
                None to skip validation.
    num_epochs: number of passes over train_data.
    lr:         initial learning rate.
    wd:         weight decay passed to the optimiser.
    ctx:        device every batch is copied to before the forward pass.
    lr_period:  the learning rate is multiplied by lr_decay at every epoch
                index that is a positive multiple of lr_period. A falsy value
                (0 or None) disables the decay instead of failing on the
                modulo.
    lr_decay:   multiplicative factor applied at each decay step.

    The function relies on the module-level criterion (loss) and get_acc
    (number of correct predictions in a batch). It returns nothing; results
    are only printed.

    Timing note: the clock is read after the training loop and before
    validation, so the "Time" printed for an epoch covers that epoch's
    training plus the validation pass of the previous epoch.
    '''
    trainer = gl.Trainer(net.collect_params(), 'sgd',
                         {'learning_rate': lr,
                          'momentum': 0.9,
                          'wd': wd})

    prev_time = datetime.datetime.now()
    for epoch in range(num_epochs):
        # Step decay; skipped for epoch 0 and when no period is configured.
        if lr_period and epoch > 0 and epoch % lr_period == 0:
            trainer.set_learning_rate(trainer.learning_rate * lr_decay)

        train_loss = 0
        correct = 0
        total = 0
        for data, label in train_data:
            bs = data.shape[0]
            data = data.as_in_context(ctx)
            label = label.as_in_context(ctx)
            with mx.autograd.record():
                output = net(data)
                loss = criterion(output, label)
            loss.backward()
            # The actual batch size is passed so the last, possibly smaller,
            # batch is normalised by its own size.
            trainer.step(bs)
            train_loss += loss.sum().asscalar()
            correct += get_acc(output, label)
            total += bs

        cur_time = datetime.datetime.now()
        # total_seconds() keeps whole days; the .seconds attribute would wrap
        # around for an interval of 24 hours or more.
        elapsed = int((cur_time - prev_time).total_seconds())
        h, remainder = divmod(elapsed, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)

        if valid_data is not None:
            valid_correct = 0
            valid_total = 0
            valid_loss = 0
            # No autograd.record() here: the forward pass is evaluation only.
            for data, label in valid_data:
                bs = data.shape[0]
                data = data.as_in_context(ctx)
                label = label.as_in_context(ctx)
                output = net(data)
                loss = criterion(output, label)
                valid_loss += nd.sum(loss).asscalar()
                valid_correct += get_acc(output, label)
                valid_total += bs
            valid_acc = valid_correct / valid_total
            epoch_str = (
                "Epoch %d. Train Loss: %f, Train acc %f, Valid Loss: %f, Valid acc %f, "
                % (epoch, train_loss / total, correct / total,
                   valid_loss / valid_total, valid_acc))
        else:
            epoch_str = ("Epoch %d. Loss: %f, Train acc %f, " %
                         (epoch, train_loss / total, correct / total))

        prev_time = cur_time
        print(epoch_str + time_str + ', lr {:.5f}'.format(trainer.learning_rate))

In [5]:
# Settings handed to train().
num_epochs = 50        # passes over the training set
lr, wd = 0.01, 1e-4    # initial learning rate, weight decay

# Step schedule: the learning rate is multiplied by lr_decay every lr_period epochs.
lr_period, lr_decay = 20, 0.5

In [6]:
# Launch the run with the hyper-parameters and device defined in the earlier cells.
train(net, train_data, valid_data,
      num_epochs=num_epochs, lr=lr, wd=wd, ctx=ctx,
      lr_period=lr_period, lr_decay=lr_decay)

Epoch 0. Train Loss: 2.940930, Train acc 0.348966, Valid Loss: 0.888945, Valid acc 0.773810, Time 00:03:04, lr 0.01000
Epoch 1. Train Loss: 1.662475, Train acc 0.588574, Valid Loss: 0.702991, Valid acc 0.814286, Time 00:03:11, lr 0.01000
Epoch 2. Train Loss: 1.389818, Train acc 0.639736, Valid Loss: 0.631722, Valid acc 0.821429, Time 00:02:59, lr 0.01000
Epoch 3. Train Loss: 1.284951, Train acc 0.667661, Valid Loss: 0.623558, Valid acc 0.813095, Time 00:02:59, lr 0.01000
Epoch 4. Train Loss: 1.203269, Train acc 0.681944, Valid Loss: 0.585251, Valid acc 0.833333, Time 00:02:59, lr 0.01000
Epoch 5. Train Loss: 1.169960, Train acc 0.692709, Valid Loss: 0.579327, Valid acc 0.826190, Time 00:03:05, lr 0.01000
Epoch 6. Train Loss: 1.110732, Train acc 0.700810, Valid Loss: 0.569368, Valid acc 0.836905, Time 00:02:56, lr 0.01000
Epoch 7. Train Loss: 1.090643, Train acc 0.711575, Valid Loss: 0.597400, Valid acc 0.817857, Time 00:02:54, lr 0.01000
Epoch 8. Train Loss: 1.058844, Train acc 0.71679

In [8]:
# Write the network's current parameters to a file in the working directory.
params_path = 'inception.params'
net.save_params(params_path)