In [1]:
"""
Created on Mon Mar 12 10:25:00 2018
@author: Lance Zhang

"""

import sys
sys.path.append("./models")
import torch
import torchvision
import torch.autograd as autograd
import torch.utils.data as Data
import torch.optim as optim
import torch.nn as nn
import classification_models

In [2]:
# configurations
data_dir = './MNIST_data'
n_steps = 28
input_dims = 28
n_classes = 10 

# model config
cell_type = "LSTM" # only support LSTM
hidden_structs = [20] * 5 # Give a list of the dimension in each layer
dilations = [1, 2, 4, 8, 16] # Give a list of the dilation in each layer
assert(len(hidden_structs) == len(dilations))

# learning config
batch_size = 128
learning_rate = 1.0e-3
training_iters = batch_size * 30000
testing_step = 5000
display_step = 150

# permutation seed 
seed = 92916

In [3]:
# loading the mnist data
DOWNLOAD_MNIST = True
train_data = torchvision.datasets.MNIST(root=data_dir, 
                                        train=True, 
                                        transform=torchvision.transforms.ToTensor(),
                                        download=DOWNLOAD_MNIST
                                        )
train_loader = Data.DataLoader(train_data, batch_size, shuffle=False, num_workers=2)

test_data = torchvision.datasets.MNIST(root=data_dir, 
                                       train = False,
                                       transform=torchvision.transforms.ToTensor()
                                       )


# shape (2000, 28, 28) value in range(0,1)
test_x = autograd.Variable(test_data.test_data, volatile=True).type(torch.FloatTensor)[:]/255.0
test_y = test_data.test_labels[:]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [None]:
# build prediction graph
print ("==> Building a dRNN with %s cells" %cell_type)
model = classification_models.drnn_classification(hidden_structs, dilations, n_classes, input_dims)

# build loss and optimizer
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# optimizer = optim.RMSprop(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
for iter in range(training_iters): 
    for step, (batch_x, batch_y) in enumerate(train_loader): 
        # (128,1, 28, 28) reshape to (128, 28, 28)
        batch_x = autograd.Variable(batch_x.view(-1, n_steps, input_dims))
        batch_y = autograd.Variable(batch_y)
        # reshape inputs
        x_reformat = classification_models._rnn_reformat(batch_x, input_dims, n_steps)
        
        optimizer.zero_grad()
        
        pred = model.forward(x_reformat)
        
        cost = criterion(pred, batch_y)
        
        
        cost.backward()
        optimizer.step()
 
        
        if (step + 1) % display_step == 0:
            print ("Iter " + str(iter + 1) + ", Step "+str(step+1)+", Avarage Loss: " + "{:.6f}".format(cost.data[0]))

    # validation performance
    x_reformat = classification_models._rnn_reformat(test_x, input_dims, n_steps)
    test_output = model.forward(x_reformat)
    pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
    accuracy = sum(pred_y == test_y) / test_y.size(0)
    print ("========> Test Accuarcy: {:.6f}".format(accuracy)) 
         

print ("end")

==> Building a dRNN with LSTM cells
Iter 1, Step 150, Avarage Loss: 1.830165
Iter 1, Step 300, Avarage Loss: 1.509589
Iter 1, Step 450, Avarage Loss: 1.452384
Iter 2, Step 150, Avarage Loss: 1.303601
Iter 2, Step 300, Avarage Loss: 1.212917
Iter 2, Step 450, Avarage Loss: 1.031403
Iter 3, Step 150, Avarage Loss: 1.048393
Iter 3, Step 300, Avarage Loss: 0.793857
Iter 3, Step 450, Avarage Loss: 0.852067
Iter 4, Step 150, Avarage Loss: 0.760702
Iter 4, Step 300, Avarage Loss: 0.759095
Iter 4, Step 450, Avarage Loss: 0.720279
Iter 5, Step 150, Avarage Loss: 0.561015
Iter 5, Step 300, Avarage Loss: 0.580134
Iter 5, Step 450, Avarage Loss: 0.673709
Iter 6, Step 150, Avarage Loss: 0.411711
Iter 6, Step 300, Avarage Loss: 0.479466
Iter 6, Step 450, Avarage Loss: 0.477710
Iter 7, Step 150, Avarage Loss: 0.336266
Iter 7, Step 300, Avarage Loss: 0.375651
Iter 7, Step 450, Avarage Loss: 0.383987
Iter 8, Step 150, Avarage Loss: 0.322849
Iter 8, Step 300, Avarage Loss: 0.265149
Iter 8, Step 450, Ava

Iter 52, Step 300, Avarage Loss: 0.037613
Iter 52, Step 450, Avarage Loss: 0.030956
Iter 53, Step 150, Avarage Loss: 0.048586
Iter 53, Step 300, Avarage Loss: 0.021965
Iter 53, Step 450, Avarage Loss: 0.026982
Iter 54, Step 150, Avarage Loss: 0.020033
Iter 54, Step 300, Avarage Loss: 0.014575
Iter 54, Step 450, Avarage Loss: 0.024064
Iter 55, Step 150, Avarage Loss: 0.006515
Iter 55, Step 300, Avarage Loss: 0.035959
Iter 55, Step 450, Avarage Loss: 0.023341
Iter 56, Step 150, Avarage Loss: 0.008624
Iter 56, Step 300, Avarage Loss: 0.017980
Iter 56, Step 450, Avarage Loss: 0.015062
Iter 57, Step 150, Avarage Loss: 0.011630
Iter 57, Step 300, Avarage Loss: 0.025189
Iter 57, Step 450, Avarage Loss: 0.013507
Iter 58, Step 150, Avarage Loss: 0.008282
Iter 58, Step 300, Avarage Loss: 0.044492
Iter 58, Step 450, Avarage Loss: 0.017440
Iter 59, Step 150, Avarage Loss: 0.009742
Iter 59, Step 300, Avarage Loss: 0.009881
Iter 59, Step 450, Avarage Loss: 0.013830
Iter 60, Step 150, Avarage Loss: 0

Iter 103, Step 450, Avarage Loss: 0.001013
Iter 104, Step 150, Avarage Loss: 0.014803
Iter 104, Step 300, Avarage Loss: 0.004559
Iter 104, Step 450, Avarage Loss: 0.005531
Iter 105, Step 150, Avarage Loss: 0.002619
Iter 105, Step 300, Avarage Loss: 0.002670
Iter 105, Step 450, Avarage Loss: 0.003390
Iter 106, Step 150, Avarage Loss: 0.002405
Iter 106, Step 300, Avarage Loss: 0.002567
Iter 106, Step 450, Avarage Loss: 0.034104
Iter 107, Step 150, Avarage Loss: 0.002253
Iter 107, Step 300, Avarage Loss: 0.015213
Iter 107, Step 450, Avarage Loss: 0.002241
Iter 108, Step 150, Avarage Loss: 0.001016
Iter 108, Step 300, Avarage Loss: 0.001019
Iter 108, Step 450, Avarage Loss: 0.001821
Iter 109, Step 150, Avarage Loss: 0.001664
Iter 109, Step 300, Avarage Loss: 0.008603
Iter 109, Step 450, Avarage Loss: 0.016813
Iter 110, Step 150, Avarage Loss: 0.001171
Iter 110, Step 300, Avarage Loss: 0.000734
Iter 110, Step 450, Avarage Loss: 0.000972
Iter 111, Step 150, Avarage Loss: 0.011488
Iter 111, S

Iter 154, Step 150, Avarage Loss: 0.001638
Iter 154, Step 300, Avarage Loss: 0.010321
Iter 154, Step 450, Avarage Loss: 0.000348
Iter 155, Step 150, Avarage Loss: 0.001278
Iter 155, Step 300, Avarage Loss: 0.000616
Iter 155, Step 450, Avarage Loss: 0.000815
Iter 156, Step 150, Avarage Loss: 0.002099
Iter 156, Step 300, Avarage Loss: 0.000530
Iter 156, Step 450, Avarage Loss: 0.000407
Iter 157, Step 150, Avarage Loss: 0.001010
Iter 157, Step 300, Avarage Loss: 0.000436
Iter 157, Step 450, Avarage Loss: 0.000953
Iter 158, Step 150, Avarage Loss: 0.000362
Iter 158, Step 300, Avarage Loss: 0.000398
Iter 158, Step 450, Avarage Loss: 0.000444
Iter 159, Step 150, Avarage Loss: 0.000388
Iter 159, Step 300, Avarage Loss: 0.000344
Iter 159, Step 450, Avarage Loss: 0.000688
Iter 160, Step 150, Avarage Loss: 0.000416
Iter 160, Step 300, Avarage Loss: 0.000417
Iter 160, Step 450, Avarage Loss: 0.000327
Iter 161, Step 150, Avarage Loss: 0.000378
Iter 161, Step 300, Avarage Loss: 0.000270
Iter 161, S

Iter 204, Step 300, Avarage Loss: 0.000325
Iter 204, Step 450, Avarage Loss: 0.001113
Iter 205, Step 150, Avarage Loss: 0.000218
Iter 205, Step 300, Avarage Loss: 0.000204
Iter 205, Step 450, Avarage Loss: 0.000387
Iter 206, Step 150, Avarage Loss: 0.000130
Iter 206, Step 300, Avarage Loss: 0.000994
Iter 206, Step 450, Avarage Loss: 0.000439
Iter 207, Step 150, Avarage Loss: 0.000141
Iter 207, Step 300, Avarage Loss: 0.001609
Iter 207, Step 450, Avarage Loss: 0.000966
Iter 208, Step 150, Avarage Loss: 0.000985
Iter 208, Step 300, Avarage Loss: 0.003689
Iter 208, Step 450, Avarage Loss: 0.000288
Iter 209, Step 150, Avarage Loss: 0.000441
Iter 209, Step 300, Avarage Loss: 0.021118
Iter 209, Step 450, Avarage Loss: 0.001502
Iter 210, Step 150, Avarage Loss: 0.000516
Iter 210, Step 300, Avarage Loss: 0.000250
Iter 210, Step 450, Avarage Loss: 0.001290
Iter 211, Step 150, Avarage Loss: 0.000560
Iter 211, Step 300, Avarage Loss: 0.002363
Iter 211, Step 450, Avarage Loss: 0.024255
Iter 212, S

Iter 254, Step 450, Avarage Loss: 0.000267
Iter 255, Step 150, Avarage Loss: 0.000328
Iter 255, Step 300, Avarage Loss: 0.001772
Iter 255, Step 450, Avarage Loss: 0.002552
Iter 256, Step 150, Avarage Loss: 0.001049
Iter 256, Step 300, Avarage Loss: 0.001987
Iter 256, Step 450, Avarage Loss: 0.000393
Iter 257, Step 150, Avarage Loss: 0.001084
Iter 257, Step 300, Avarage Loss: 0.000160
Iter 257, Step 450, Avarage Loss: 0.008386
Iter 258, Step 150, Avarage Loss: 0.011035
Iter 258, Step 300, Avarage Loss: 0.000132
Iter 258, Step 450, Avarage Loss: 0.000143
Iter 259, Step 150, Avarage Loss: 0.000571
Iter 259, Step 300, Avarage Loss: 0.000086
Iter 259, Step 450, Avarage Loss: 0.004043
Iter 260, Step 150, Avarage Loss: 0.000792
Iter 260, Step 300, Avarage Loss: 0.000684
Iter 260, Step 450, Avarage Loss: 0.003684
Iter 261, Step 150, Avarage Loss: 0.000130
Iter 261, Step 300, Avarage Loss: 0.000210
Iter 261, Step 450, Avarage Loss: 0.001483
Iter 262, Step 150, Avarage Loss: 0.000807
Iter 262, S

Iter 305, Step 150, Avarage Loss: 0.000000
Iter 305, Step 300, Avarage Loss: 0.000000
Iter 305, Step 450, Avarage Loss: 0.000000
Iter 306, Step 150, Avarage Loss: 0.000140
Iter 306, Step 300, Avarage Loss: 0.001264
Iter 306, Step 450, Avarage Loss: 0.003330
Iter 307, Step 150, Avarage Loss: 0.000175
Iter 307, Step 300, Avarage Loss: 0.011495
Iter 307, Step 450, Avarage Loss: 0.006830
Iter 308, Step 150, Avarage Loss: 0.000024
Iter 308, Step 300, Avarage Loss: 0.000040
Iter 308, Step 450, Avarage Loss: 0.000536
Iter 309, Step 150, Avarage Loss: 0.000223
Iter 309, Step 300, Avarage Loss: 0.000012
Iter 309, Step 450, Avarage Loss: 0.000450
Iter 310, Step 150, Avarage Loss: 0.000035
Iter 310, Step 300, Avarage Loss: 0.000035
Iter 310, Step 450, Avarage Loss: 0.000023
Iter 311, Step 150, Avarage Loss: 0.004647
Iter 311, Step 300, Avarage Loss: 0.007516
Iter 311, Step 450, Avarage Loss: 0.000238
Iter 312, Step 150, Avarage Loss: 0.000032
Iter 312, Step 300, Avarage Loss: 0.000028
Iter 312, S

Iter 355, Step 300, Avarage Loss: 0.000001
Iter 355, Step 450, Avarage Loss: 0.000000
Iter 356, Step 150, Avarage Loss: 0.000000
Iter 356, Step 300, Avarage Loss: 0.000000
Iter 356, Step 450, Avarage Loss: 0.000000
Iter 357, Step 150, Avarage Loss: 0.000000
Iter 357, Step 300, Avarage Loss: 0.000000
Iter 357, Step 450, Avarage Loss: 0.000000
Iter 358, Step 150, Avarage Loss: 0.000000
Iter 358, Step 300, Avarage Loss: 0.000000
Iter 358, Step 450, Avarage Loss: 0.000000
Iter 359, Step 150, Avarage Loss: 0.000000
Iter 359, Step 300, Avarage Loss: 0.000000
Iter 359, Step 450, Avarage Loss: 0.000000
Iter 360, Step 150, Avarage Loss: 0.000000
Iter 360, Step 300, Avarage Loss: 0.000000
Iter 360, Step 450, Avarage Loss: 0.000000
Iter 361, Step 150, Avarage Loss: 0.000000
Iter 361, Step 300, Avarage Loss: 0.000000
Iter 361, Step 450, Avarage Loss: 0.000000
Iter 362, Step 150, Avarage Loss: 0.000000
Iter 362, Step 300, Avarage Loss: 0.000000
Iter 362, Step 450, Avarage Loss: 0.000000
Iter 363, S

Iter 405, Step 450, Avarage Loss: 0.000074
Iter 406, Step 150, Avarage Loss: 0.000011
Iter 406, Step 300, Avarage Loss: 0.000051
Iter 406, Step 450, Avarage Loss: 0.000092
Iter 407, Step 150, Avarage Loss: 0.000006
Iter 407, Step 300, Avarage Loss: 0.000023
Iter 407, Step 450, Avarage Loss: 0.000041
Iter 408, Step 150, Avarage Loss: 0.000005
Iter 408, Step 300, Avarage Loss: 0.000024
Iter 408, Step 450, Avarage Loss: 0.000031
Iter 409, Step 150, Avarage Loss: 0.000005
Iter 409, Step 300, Avarage Loss: 0.000027
Iter 409, Step 450, Avarage Loss: 0.000026
Iter 410, Step 150, Avarage Loss: 0.000004
Iter 410, Step 300, Avarage Loss: 0.000029
Iter 410, Step 450, Avarage Loss: 0.000022
Iter 411, Step 150, Avarage Loss: 0.000004
Iter 411, Step 300, Avarage Loss: 0.000030
Iter 411, Step 450, Avarage Loss: 0.000019
Iter 412, Step 150, Avarage Loss: 0.000003
Iter 412, Step 300, Avarage Loss: 0.000030
Iter 412, Step 450, Avarage Loss: 0.000017
Iter 413, Step 150, Avarage Loss: 0.000003
Iter 413, S