In [22]:
import json
import numpy as np
from pprint import pprint

In [2]:
from mxnet.gluon import nn
from mxnet import nd
from mxnet import autograd

class Residual(nn.Block):
    """Basic two-convolution residual block.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``. The
    shortcut is the identity when ``same_shape`` is True. Otherwise it is a
    strided 1x1 convolution.

    Parameters
    ----------
    channels : int
        Number of output channels of every convolution in the block. With
        ``same_shape=True`` the input must already have this many channels,
        because it is added to the main path unchanged.
    same_shape : bool, default True
        If False, the first 3x3 convolution uses stride 2 (halving height and
        width, rounding up). A 1x1 stride-2 convolution then projects the
        shortcut to the same shape.
    """

    def __init__(self, channels, same_shape=True, **kwargs):
        super(Residual, self).__init__(**kwargs)
        self.same_shape = same_shape
        # Downsample only when the caller asks for a shape change.
        stride = 2 if not same_shape else 1
        # Children are created in a fixed order: conv1, bn1, conv2, bn2, [conv3].
        self.conv1 = nn.Conv2D(channels, kernel_size=3, padding=1, strides=stride)
        self.bn1 = nn.BatchNorm()
        self.conv2 = nn.Conv2D(channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm()
        if same_shape:
            return
        # Projection shortcut so that ``x`` matches the main path's output shape.
        self.conv3 = nn.Conv2D(channels, kernel_size=1, strides=stride)

    def forward(self, x):
        main = nd.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        shortcut = x if self.same_shape else self.conv3(x)
        return nd.relu(main + shortcut)

In [3]:
# Demo: a same-shape residual block on a random batch of four 3-channel 6x6 inputs.
blk = Residual(3)
blk.initialize()

x = nd.random.uniform(shape=(4, 3, 6, 6))
blk_out = blk(x)
# With same_shape=True the block must return exactly the input shape
# (otherwise the identity shortcut could not be added).
assert blk_out.shape == x.shape
blk_out.shape

(4, 3, 6, 6)

In [4]:
# Demo: with same_shape=False the block switches to 8 channels and uses stride 2,
# so height and width become ceil(H/2) and ceil(W/2).
blk2 = Residual(8, same_shape=False)
blk2.initialize()
blk2_out = blk2(x)
assert blk2_out.shape == (x.shape[0], 8, (x.shape[2] + 1) // 2, (x.shape[3] + 1) // 2)
blk2_out.shape

(4, 8, 3, 3)

In [5]:
class ResNet(nn.Block):
    """Small residual network: conv stem -> residual stage -> pooled dense head.

    Parameters
    ----------
    num_classes : int
        Number of units in the final ``Dense`` layer (one output per target).
    verbose : bool, default False
        If True, ``forward`` prints the output shape of each of the three stages.
    """

    def __init__(self, num_classes, verbose=False, **kwargs):
        super(ResNet, self).__init__(**kwargs)
        self.verbose = verbose
        # Create every child inside this block's name_scope so the children's
        # parameter names carry this block's prefix. Creation order is kept
        # fixed: stem, body, head, then the wrapping Sequential.
        with self.name_scope():
            # Stage 1: unpadded 3x3 convolution stem.
            stem = nn.Conv2D(64, kernel_size=3, strides=1)

            # Stage 2: stride-1 max-pool followed by one identity-shortcut residual block.
            body = nn.Sequential()
            body.add(nn.MaxPool2D(pool_size=3, strides=1), Residual(64))

            # Stage 3: average-pool, then a dense layer with num_classes outputs.
            # NOTE(review): for 8x8 inputs this stage receives a 4x4 map. With
            # pool_size=3 (stride defaulting to the pool size), only one 3x3
            # window is pooled. GlobalAvgPool2D may have been intended; confirm.
            head = nn.Sequential()
            head.add(nn.AvgPool2D(pool_size=3), nn.Dense(num_classes))

            # The stages run in this order; forward() numbers them 1..3.
            self.net = nn.Sequential()
            self.net.add(stem, body, head)

    def forward(self, x):
        out = x
        for stage_no, stage in enumerate(self.net, start=1):
            out = stage(out)
            if self.verbose:
                print('Block %d output: %s' % (stage_no, out.shape))
        return out

In [6]:
# Smoke test: run the full network on a random batch of four 1-channel 8x8
# inputs. verbose=True prints each stage's output shape.
net = ResNet(17, verbose=True)
net.initialize()

x = nd.random.uniform(shape=(4, 1, 8, 8))
# Use a dedicated name: `y` is used later for the label matrix.
demo_out = net(x)
assert demo_out.shape == (x.shape[0], 17)
demo_out.shape

Block 1 output: (4, 64, 6, 6)
Block 2 output: (4, 64, 4, 4)
Block 3 output: (4, 17)


(4, 17)

In [7]:
# Load the raw dataset. The indexing below assumes data.json holds a two-element
# list [features, targets] (NOTE(review): confirm against the file).
# A `with` block guarantees the file handle is closed after parsing.
with open('data.json', 'r') as fh:
    dataset = json.load(fh)
# Convert each part to an NDArray whose first axis indexes samples.
dataset[0] = nd.array([np.array(l) for l in dataset[0]])
dataset[1] = nd.array([np.array(l) for l in dataset[1]])
X = dataset[0]
y = dataset[1]

In [8]:
# Display the target matrix loaded above; the NDArray repr elides the middle rows/columns.
y


[[ 0.   0.   0.  ...,  0.   0.  -1. ]
 [ 0.   0.   0.  ...,  0.   0.  -1. ]
 [ 0.   0.   0.  ...,  0.   0.  -1. ]
 ..., 
 [ 0.   0.5  0.  ...,  0.   0.  -1. ]
 [ 0.   0.   0.  ...,  0.   0.  -1. ]
 [ 0.   1.   0.  ...,  0.   0.  -1. ]]
<NDArray 60x65 @cpu(0)>

In [9]:
# m: number of samples; n: size of X's second axis, used later as the side of the n x n input grid.
m, n = int(X.shape[0]), int(X.shape[1])
print(m, n)

60 8


In [10]:

# The network consumes NCHW input: one single-channel n x n grid per sample.
# Check sizes first so a mismatched dataset fails with a readable message.
assert X.size == m * n * n, "expected X to hold m*n*n values, got shape %s" % (X.shape,)
# Targets are flattened to n*n + 1 values per sample, matching ResNet(n*n+1)'s output width.
assert y.size == m * (n * n + 1), "expected y to hold m*(n*n+1) values, got shape %s" % (y.shape,)
temp_X = nd.reshape(X, (m, 1, n, n))
temp_y = nd.reshape(y, (m, n * n + 1))

In [11]:
def _get_batch(batch, ctx):
    """Return data and label on ctx, split across the given device(s).

    Parameters
    ----------
    batch : mx.io.DataBatch or (data, label) pair
        Either a batch from an ``mx.io.DataIter`` or a ``(data, label)`` tuple
        such as a DataLoader yields.
    ctx : mx.Context or list of mx.Context
        Target device(s). A single context is treated as a one-element list.

    Returns
    -------
    tuple
        ``(data_parts, label_parts, batch_size)``. The first two are the lists
        returned by ``gluon.utils.split_and_load``, and ``batch_size`` is
        ``data.shape[0]``.
    """
    # Imported locally so this helper does not rely on a later notebook cell
    # having already imported these names.
    import mxnet as mx
    from mxnet import gluon

    # split_and_load expects a list of contexts.
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    if isinstance(batch, mx.io.DataBatch):
        data = batch.data[0]
        label = batch.label[0]
    else:
        data, label = batch
    return (gluon.utils.split_and_load(data, ctx),
            gluon.utils.split_and_load(label, ctx),
            data.shape[0])

In [12]:
import mxnet as mx
from time import time
def train(train_data, test_data, net, loss, trainer, ctx, num_epochs, print_batches=None):
    """Train a network and print the average per-sample loss after every epoch.

    Parameters
    ----------
    train_data : mx.io.DataIter or iterable of (data, label)
        Training batches. A DataIter is reset at the start of each epoch.
    test_data : any
        Accepted for interface compatibility; currently not used here.
    net : gluon Block
        Network being optimised.
    loss : gluon Loss
        Loss applied to each (prediction, label) slice; its values are summed
        to accumulate the epoch loss.
    trainer : gluon.Trainer
        Optimiser. ``step`` receives the batch size.
    ctx : mx.Context or list of mx.Context
        Device(s) the batches are split across.
    num_epochs : int
        Number of passes over ``train_data``.
    print_batches : int, optional
        If set, print the running mean loss every ``print_batches`` batches.
    """
    print("Start training on ", ctx)
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    for epoch in range(num_epochs):
        train_loss, n = 0.0, 0.0
        # Every DataIter (MXDataIter, NDArrayIter, ...) is exhausted after one
        # pass and must be rewound. Plain iterables are re-iterated by enumerate().
        if isinstance(train_data, mx.io.DataIter):
            train_data.reset()
        start = time()
        for i, batch in enumerate(train_data):
            data, label, batch_size = _get_batch(batch, ctx)
            with autograd.record():
                outputs = [net(xs) for xs in data]
                losses = [loss(yhat, ys) for yhat, ys in zip(outputs, label)]
            for l in losses:
                l.backward()
            train_loss += sum(l.sum().asscalar() for l in losses)
            trainer.step(batch_size)
            n += batch_size
            if print_batches and (i + 1) % print_batches == 0:
                # Report the batch index (not the running sample count).
                print("Batch %d. Loss: %f" % (
                    i + 1, train_loss / n
                ))

        # An empty epoch would otherwise raise ZeroDivisionError.
        avg_loss = train_loss / n if n else float('nan')
        print("Epoch %d. Loss: %.3f, Time %.1f sec" % (
            epoch, avg_loss, time() - start
        ))


In [32]:
import sys
sys.path.append('..')
import utils
from mxnet import gluon
from mxnet import init


# Split by position: the first 80% of samples train, the last 20% are held out.
split_idx = int(m * 0.8)
datasetTrain = gluon.data.ArrayDataset(temp_X[:split_idx], temp_y[:split_idx])
datasetTest = gluon.data.ArrayDataset(temp_X[split_idx:], temp_y[split_idx:])

BATCH_SIZE = 4
train_data = utils.DataLoader(datasetTrain, batch_size=BATCH_SIZE, shuffle=True)
test_data = utils.DataLoader(datasetTest, batch_size=BATCH_SIZE, shuffle=True)

ctx = utils.try_gpu()
# One output unit per target column: the labels were reshaped to width n*n + 1.
net = ResNet(n*n+1)
net.initialize(ctx=ctx, init=init.Xavier())

# Train: regress the targets with an L2 loss and plain SGD.
LEARNING_RATE = 0.5
NUM_EPOCHS = 1000
loss = gluon.loss.L2Loss()
trainer = gluon.Trainer(net.collect_params(),
                        'sgd', {'learning_rate': LEARNING_RATE})
train(train_data, test_data, net, loss,
      trainer, ctx, num_epochs=NUM_EPOCHS)

Start training on  cpu(0)
Epoch 0. Loss: 0.036, Time 0.1 sec
Epoch 1. Loss: 0.005, Time 0.1 sec
Epoch 2. Loss: 0.003, Time 0.1 sec
Epoch 3. Loss: 0.002, Time 0.1 sec
Epoch 4. Loss: 0.002, Time 0.1 sec
Epoch 5. Loss: 0.002, Time 0.1 sec
Epoch 6. Loss: 0.002, Time 0.1 sec
Epoch 7. Loss: 0.002, Time 0.2 sec
Epoch 8. Loss: 0.001, Time 0.1 sec
Epoch 9. Loss: 0.001, Time 0.1 sec
Epoch 10. Loss: 0.001, Time 0.1 sec
Epoch 11. Loss: 0.001, Time 0.1 sec
Epoch 12. Loss: 0.001, Time 0.1 sec
Epoch 13. Loss: 0.001, Time 0.2 sec
Epoch 14. Loss: 0.001, Time 0.1 sec
Epoch 15. Loss: 0.001, Time 0.1 sec
Epoch 16. Loss: 0.001, Time 0.1 sec
Epoch 17. Loss: 0.001, Time 0.1 sec
Epoch 18. Loss: 0.001, Time 0.1 sec
Epoch 19. Loss: 0.001, Time 0.1 sec
Epoch 20. Loss: 0.001, Time 0.1 sec
Epoch 21. Loss: 0.001, Time 0.1 sec
Epoch 22. Loss: 0.001, Time 0.1 sec
Epoch 23. Loss: 0.001, Time 0.1 sec
Epoch 24. Loss: 0.001, Time 0.1 sec
Epoch 25. Loss: 0.001, Time 0.1 sec
Epoch 26. Loss: 0.001, Time 0.1 sec
Epoch 27. Lo

Epoch 224. Loss: 0.000, Time 0.1 sec
Epoch 225. Loss: 0.000, Time 0.1 sec
Epoch 226. Loss: 0.000, Time 0.1 sec
Epoch 227. Loss: 0.000, Time 0.1 sec
Epoch 228. Loss: 0.000, Time 0.1 sec
Epoch 229. Loss: 0.000, Time 0.1 sec
Epoch 230. Loss: 0.000, Time 0.1 sec
Epoch 231. Loss: 0.000, Time 0.1 sec
Epoch 232. Loss: 0.000, Time 0.4 sec
Epoch 233. Loss: 0.000, Time 0.1 sec
Epoch 234. Loss: 0.000, Time 0.1 sec
Epoch 235. Loss: 0.000, Time 0.1 sec
Epoch 236. Loss: 0.000, Time 0.1 sec
Epoch 237. Loss: 0.000, Time 0.1 sec
Epoch 238. Loss: 0.000, Time 0.1 sec
Epoch 239. Loss: 0.000, Time 0.1 sec
Epoch 240. Loss: 0.000, Time 0.1 sec
Epoch 241. Loss: 0.000, Time 0.1 sec
Epoch 242. Loss: 0.000, Time 0.1 sec
Epoch 243. Loss: 0.000, Time 0.1 sec
Epoch 244. Loss: 0.000, Time 0.1 sec
Epoch 245. Loss: 0.000, Time 0.1 sec
Epoch 246. Loss: 0.000, Time 0.1 sec
Epoch 247. Loss: 0.000, Time 0.1 sec
Epoch 248. Loss: 0.000, Time 0.1 sec
Epoch 249. Loss: 0.000, Time 0.1 sec
Epoch 250. Loss: 0.000, Time 0.1 sec
E

Epoch 446. Loss: 0.000, Time 0.1 sec
Epoch 447. Loss: 0.000, Time 0.1 sec
Epoch 448. Loss: 0.000, Time 0.1 sec
Epoch 449. Loss: 0.000, Time 0.1 sec
Epoch 450. Loss: 0.000, Time 0.1 sec
Epoch 451. Loss: 0.000, Time 0.1 sec
Epoch 452. Loss: 0.000, Time 0.1 sec
Epoch 453. Loss: 0.000, Time 0.1 sec
Epoch 454. Loss: 0.000, Time 0.1 sec
Epoch 455. Loss: 0.000, Time 0.1 sec
Epoch 456. Loss: 0.000, Time 0.1 sec
Epoch 457. Loss: 0.000, Time 0.1 sec
Epoch 458. Loss: 0.000, Time 0.1 sec
Epoch 459. Loss: 0.000, Time 0.1 sec
Epoch 460. Loss: 0.000, Time 0.1 sec
Epoch 461. Loss: 0.000, Time 0.1 sec
Epoch 462. Loss: 0.000, Time 0.1 sec
Epoch 463. Loss: 0.000, Time 0.1 sec
Epoch 464. Loss: 0.000, Time 0.2 sec
Epoch 465. Loss: 0.000, Time 0.1 sec
Epoch 466. Loss: 0.000, Time 0.1 sec
Epoch 467. Loss: 0.000, Time 0.1 sec
Epoch 468. Loss: 0.000, Time 0.1 sec
Epoch 469. Loss: 0.000, Time 0.1 sec
Epoch 470. Loss: 0.000, Time 0.1 sec
Epoch 471. Loss: 0.000, Time 0.1 sec
Epoch 472. Loss: 0.000, Time 0.1 sec
E

Epoch 669. Loss: 0.000, Time 0.1 sec
Epoch 670. Loss: 0.000, Time 0.1 sec
Epoch 671. Loss: 0.000, Time 0.1 sec
Epoch 672. Loss: 0.000, Time 0.1 sec
Epoch 673. Loss: 0.000, Time 0.1 sec
Epoch 674. Loss: 0.000, Time 0.1 sec
Epoch 675. Loss: 0.000, Time 0.1 sec
Epoch 676. Loss: 0.000, Time 0.1 sec
Epoch 677. Loss: 0.000, Time 0.1 sec
Epoch 678. Loss: 0.000, Time 0.1 sec
Epoch 679. Loss: 0.000, Time 0.1 sec
Epoch 680. Loss: 0.000, Time 0.1 sec
Epoch 681. Loss: 0.000, Time 0.1 sec
Epoch 682. Loss: 0.000, Time 0.1 sec
Epoch 683. Loss: 0.000, Time 0.1 sec
Epoch 684. Loss: 0.000, Time 0.1 sec
Epoch 685. Loss: 0.000, Time 0.1 sec
Epoch 686. Loss: 0.000, Time 0.1 sec
Epoch 687. Loss: 0.000, Time 0.1 sec
Epoch 688. Loss: 0.000, Time 0.1 sec
Epoch 689. Loss: 0.000, Time 0.1 sec
Epoch 690. Loss: 0.000, Time 0.1 sec
Epoch 691. Loss: 0.000, Time 0.1 sec
Epoch 692. Loss: 0.000, Time 0.1 sec
Epoch 693. Loss: 0.000, Time 0.1 sec
Epoch 694. Loss: 0.000, Time 0.1 sec
Epoch 695. Loss: 0.000, Time 0.1 sec
E

Epoch 893. Loss: 0.000, Time 0.1 sec
Epoch 894. Loss: 0.000, Time 0.1 sec
Epoch 895. Loss: 0.000, Time 0.1 sec
Epoch 896. Loss: 0.000, Time 0.1 sec
Epoch 897. Loss: 0.000, Time 0.1 sec
Epoch 898. Loss: 0.000, Time 0.1 sec
Epoch 899. Loss: 0.000, Time 0.1 sec
Epoch 900. Loss: 0.000, Time 0.1 sec
Epoch 901. Loss: 0.000, Time 0.1 sec
Epoch 902. Loss: 0.000, Time 0.1 sec
Epoch 903. Loss: 0.000, Time 0.1 sec
Epoch 904. Loss: 0.000, Time 0.1 sec
Epoch 905. Loss: 0.000, Time 0.1 sec
Epoch 906. Loss: 0.000, Time 0.1 sec
Epoch 907. Loss: 0.000, Time 0.1 sec
Epoch 908. Loss: 0.000, Time 0.1 sec
Epoch 909. Loss: 0.000, Time 0.1 sec
Epoch 910. Loss: 0.000, Time 0.1 sec
Epoch 911. Loss: 0.000, Time 0.1 sec
Epoch 912. Loss: 0.000, Time 0.1 sec
Epoch 913. Loss: 0.000, Time 0.1 sec
Epoch 914. Loss: 0.000, Time 0.1 sec
Epoch 915. Loss: 0.000, Time 0.1 sec
Epoch 916. Loss: 0.000, Time 0.1 sec
Epoch 917. Loss: 0.000, Time 0.1 sec
Epoch 918. Loss: 0.000, Time 0.1 sec
Epoch 919. Loss: 0.000, Time 0.1 sec
E

In [36]:
# Spot-check the first sample of one training batch: show its prediction and
# target, then the prediction masked to positions where the target is non-zero.
for batch_X, batch_y in train_data:
    # Move the batch to the network's device and run the forward pass once.
    pred = net(batch_X.as_in_context(ctx))[0]
    target = batch_y[0].as_in_context(ctx)
    pprint(pred)
    pprint(target)
    mask = target != 0
    print(nd.multiply(pred, mask))
    break


[  4.23843414e-03  -1.11521259e-02   1.31375883e-02   9.09030885e-02
  -8.60948674e-03  -8.17036442e-03  -7.99241848e-03   1.11951800e-02
  -2.86042634e-02   9.37412307e-02  -1.33569017e-02  -1.98139809e-04
  -6.56664744e-03   2.11453866e-02   1.61224902e-02   5.44572901e-03
  -9.84130893e-03   2.60841157e-02  -6.63036387e-03  -2.62511708e-03
   8.81136954e-02  -1.19086877e-02   8.13862085e-02   2.90096924e-03
   1.62075944e-02   2.29083560e-02   1.23490170e-02  -7.67406076e-03
  -4.48686443e-03  -8.86567682e-03   1.03991807e-01   2.20773350e-02
  -3.47914994e-02   3.45115550e-03   7.90653378e-02   4.17295285e-03
  -7.12510757e-03   1.48338163e-02  -1.52116455e-03  -1.16411895e-02
  -1.06487442e-02   8.00879672e-02   6.27601054e-03   7.39737554e-03
   1.55500900e-02   7.41975382e-04   1.21987723e-02  -4.63728327e-03
  -1.78602338e-02   3.73544507e-02   1.15138814e-01   9.90186855e-02
   9.47135165e-02  -6.96273986e-03  -1.15268920e-02   1.58146881e-02
  -1.78446155e-03   1.04648145e-0