# In [None]:
from mxnet import init
from mxnet import gluon
from mxnet.gluon import nn
from mxnet import nd
from mxnet import autograd
from mxnet import image
import numpy as np
import pickle as p
import mxnet as mx
from time import time
import matplotlib.pyplot as plt
# Device used for the model and for evaluation batches (first GPU).
ctx = mx.gpu(0)

def load_cifar(route = '/home/sinyer/python/data/cifar-10-batches-py', train_num = 50000, test_num = 10000):
    """Load the CIFAR-10 python-pickle batches found in ``route``.

    ``data_batch_1`` .. ``data_batch_5`` form the training set and
    ``test_batch`` the test set.

    Args:
        route: directory holding the extracted ``cifar-10-batches-py`` files.
        train_num: keep only the first ``train_num`` training samples.
        test_num: keep only the first ``test_num`` test samples.

    Returns:
        ``(train_pic, train_label, test_pic, test_label)``. The images are
        float32 arrays of shape ``(N, 3, 32, 32)`` divided by 255. The labels
        are float32 arrays of shape ``(N,)``.
    """
    import os

    def load_batch(filename):
        # A batch pickle holds 'data' (one flattened 3x32x32 image per row) and
        # 'labels'. The row count is taken from the file, not hard-coded.
        # NOTE: pickle can execute arbitrary code; only load trusted files.
        with open(filename, 'rb') as f:
            data_dict = p.load(f, encoding='latin1')
        X = np.asarray(data_dict['data']).reshape(-1, 3, 32, 32)
        Y = np.asarray(data_dict['labels'])
        return X, Y

    train_batches = [load_batch(os.path.join(route, 'data_batch_%d' % i))
                     for i in range(1, 6)]
    test_pic, test_label = load_batch(os.path.join(route, 'test_batch'))
    train_pic = np.concatenate([x for x, _ in train_batches])
    train_label = np.concatenate([y for _, y in train_batches])
    # Slice first and convert straight to float32. This avoids materializing a
    # float64 copy of the whole dataset; the values are identical because
    # integer pixels are exact in both float types.
    train_pic = train_pic[:train_num].astype('float32') / 255
    train_label = train_label[:train_num].astype('float32')
    test_pic = test_pic[:test_num].astype('float32') / 255
    test_label = test_label[:test_num].astype('float32')
    return train_pic, train_label, test_pic, test_label

train_pic, train_label, test_pic, test_label = load_cifar()

# Mini-batch loaders over the in-memory arrays; only the training loader shuffles.
batch_size = 128
train_data = gluon.data.DataLoader(
    gluon.data.ArrayDataset(train_pic, train_label),
    batch_size=batch_size, shuffle=True)
test_data = gluon.data.DataLoader(
    gluon.data.ArrayDataset(test_pic, test_label),
    batch_size=batch_size, shuffle=False)

def accuracy(output, label):
    """Return the fraction of rows of ``output`` whose argmax equals ``label``, as a scalar."""
    hits = output.argmax(axis=1) == label
    return nd.mean(hits).asscalar()

def evaluate_accuracy(data_iterator, net, ctx):
    """Return the classification accuracy of ``net`` over every sample in ``data_iterator``.

    Each batch is transformed with ``tf_test`` and moved to ``ctx`` before
    the forward pass. Correct predictions and sample counts are summed over
    all batches. A short final batch therefore counts by its real size,
    rather than weighing as much as a full batch as a mean of per-batch
    accuracies would.

    Raises:
        ValueError: if ``data_iterator`` yields no samples.
    """
    correct = 0.
    total = 0
    for data, label in data_iterator:
        data = tf_test(data).as_in_context(ctx)
        label = label.as_in_context(ctx)
        output = net(data)
        correct += float(nd.sum(output.argmax(axis=1) == label).asscalar())
        total += label.shape[0]
    if total == 0:
        raise ValueError('data_iterator yielded no samples')
    return correct / total

# Per-image augmenter pipelines for the two splits. Both normalize with the
# same per-channel mean/std given below. Only the training pipeline turns on
# random crop and random horizontal mirror; False is the library default for both.
aug_train, aug_test = (
    image.CreateAugmenter(data_shape=(3, 32, 32),
                          rand_crop=is_train, rand_mirror=is_train,
                          mean=np.array([0.4914, 0.4822, 0.4465]),
                          std=np.array([0.2023, 0.1994, 0.2010]))
    for is_train in (True, False)
)

def apply(img, auglist):
    """Pass ``img`` through each augmenter in ``auglist``, first to last, and return the output."""
    out = img
    for augment in auglist:
        out = augment(out)
    return out

def tf_test(data):
    """Apply the ``aug_test`` pipeline to every image of an evaluation batch.

    Takes and returns a channels-first (N, C, H, W) batch. The per-image
    augmenters are run on channels-last images, hence the transpose before
    and after.
    """
    channels_last = nd.transpose(data, axes=(0, 2, 3, 1))
    per_image = [apply(img, aug_test) for img in channels_last]
    return nd.transpose(nd.stack(*per_image), axes=(0, 3, 1, 2))

def tf_train(data):
    """Augment a training batch: zero-pad H and W by 2 px, then run ``aug_train`` per image.

    Takes and returns a channels-first (N, C, H, W) batch. The per-image
    augmenters are run on channels-last images, hence the transpose before
    and after.
    """
    padded = nd.pad(data, mode='constant', constant_value=0,
                    pad_width=(0, 0, 0, 0, 2, 2, 2, 2))
    channels_last = nd.transpose(padded, axes=(0, 2, 3, 1))
    per_image = [apply(img, aug_train) for img in channels_last]
    return nd.transpose(nd.stack(*per_image), axes=(0, 3, 1, 2))

class Residual(nn.Block):
    """Pre-activation residual unit: (BN -> ReLU -> 3x3 conv) twice, plus a shortcut.

    Args:
        channels: output channels of every conv in the unit.
        same_shape: if true, the convs use stride 1 and the input is added
            back unchanged. Otherwise the first conv uses stride 2, and the
            shortcut is a strided 1x1 conv (``conv3``) so both branches have
            matching shapes.
    """
    def __init__(self, channels, same_shape=True, **kwargs):
        super(Residual, self).__init__(**kwargs)
        self.same_shape = same_shape
        stride = 1 if same_shape else 2
        # Children are created in the original order so auto-generated
        # parameter names stay the same.
        with self.name_scope():
            self.bn1 = nn.BatchNorm()
            self.conv1 = nn.Conv2D(channels, kernel_size=3, padding=1, strides=stride)
            self.bn2 = nn.BatchNorm()
            self.conv2 = nn.Conv2D(channels, kernel_size=3, padding=1)
            if stride != 1:
                self.conv3 = nn.Conv2D(channels, kernel_size=1, strides=stride)

    def forward(self, x):
        # The main branch runs before the shortcut conv. This keeps the layer
        # call order, and thus the deferred-initialization order, unchanged.
        branch = self.conv1(nd.relu(self.bn1(x)))
        branch = self.conv2(nd.relu(self.bn2(branch)))
        shortcut = x if self.same_shape else self.conv3(x)
        return branch + shortcut

class ResNet(nn.Block):
    """Small pre-activation ResNet classifier.

    The layers are:
    - stem: 3x3 conv (16 channels) -> BN -> ReLU;
    - three downsampling ``Residual`` units with 128, 256 and 512 channels;
    - a 4x4 average pool and a flatten;
    - a dense layer with ``num_classes`` outputs and no activation (raw scores).

    Presumably built for 32x32 inputs, which the three stride-2 units reduce
    to 4x4 before the pool; verify before using other input sizes.
    NOTE(review): unlike the usual pre-activation design, there is no final
    BN -> ReLU before pooling. The architecture is left unchanged here.
    """
    def __init__(self, num_classes, **kwargs):
        super(ResNet, self).__init__(**kwargs)
        with self.name_scope():
            net = self.net = nn.Sequential()
            net.add(nn.Conv2D(channels=16, kernel_size=3, strides=1, padding=1))
            net.add(nn.BatchNorm())
            net.add(nn.Activation(activation='relu'))
            net.add(Residual(channels=16*8, same_shape=False))
            net.add(Residual(channels=32*8, same_shape=False))
            net.add(Residual(channels=64*8, same_shape=False))
            net.add(nn.AvgPool2D(pool_size=4))
            net.add(nn.Flatten())
            net.add(nn.Dense(num_classes))

    def forward(self, x):
        # nn.Sequential already feeds x through its children in insertion
        # order, which is exactly what the previous hand-written loop did.
        return self.net(x)