In [10]:
import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F
from torch import optim
import dlc_practical_prologue as prologue

In [5]:
# Seed torch's global RNG so Restart & Run All is reproducible. This covers the model
# weight initialisation later in the notebook, and presumably the sampling done inside
# prologue.generate_pair_sets (NOTE(review): verify it draws from torch's RNG).
torch.manual_seed(0)

# Size argument for generate_pair_sets (presumably the number of pairs per split — verify).
N = 1000
(train_input, train_target, train_classes,
 test_input, test_target, test_classes) = prologue.generate_pair_sets(N)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [6]:
class Net(nn.Module):
    """Fully connected classifier producing 2-class log-probabilities.

    Each sample is flattened to 392 features (2 x 14 x 14 for the image pairs
    used in this notebook — TODO confirm against prologue.generate_pair_sets).
    Layout: Linear -> BatchNorm1d -> ReLU, twice, then Linear -> BatchNorm1d
    -> log_softmax over the 2 classes.
    """

    def __init__(self, nb_hidden=256):
        super(Net, self).__init__()
        # The Linear layers skip their bias: each one feeds a BatchNorm1d whose
        # learned shift makes a preceding bias redundant.
        self.hidden1 = nn.Linear(392, nb_hidden, bias=False)
        self.hidden1_bn = nn.BatchNorm1d(nb_hidden)
        self.hidden2 = nn.Linear(nb_hidden, nb_hidden, bias=False)
        self.hidden2_bn = nn.BatchNorm1d(nb_hidden)
        self.output = nn.Linear(nb_hidden, 2, bias=False)
        # NOTE(review): batch-normalising the final logits is unusual; kept as-is.
        self.output_bn = nn.BatchNorm1d(2)

    def forward(self, x, y=None):
        """Return log-probabilities of shape (batch, 2).

        Args:
            x: batch of samples; every dimension after the first is flattened.
            y: unused; kept so existing two-argument callers keep working.
        """
        x = x.view(x.size(0), -1)
        x = self.hidden1_bn(self.hidden1(x))
        x = self.hidden2_bn(self.hidden2(F.relu(x)))
        x = self.output_bn(self.output(F.relu(x)))
        # Explicit dim: the implicit-dim form of log_softmax is deprecated and warns.
        return F.log_softmax(x, dim=1)

In [7]:
class Net2(nn.Module):
    """Convolutional classifier for a 2-channel input, returning 2-class log-probabilities.

    Sized for (batch, 2, 14, 14) inputs. The spatial size evolves as
    14 -conv1(k=5, p=3)-> 16 -pool-> 8 -conv2(k=5, p=3)-> 10 -pool-> 5 -conv3(k=2)-> 4,
    so fc1 receives 64 * 4 * 4 = 1024 features per sample.
    """

    def __init__(self, nb_hidden=200):
        super(Net2, self).__init__()
        self.conv1 = nn.Conv2d(2, 32, kernel_size=5, padding=3)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=5, padding=3)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=2)
        self.fc1 = nn.Linear(1024, nb_hidden)
        self.fc2 = nn.Linear(nb_hidden, 2)

    def forward(self, x):
        """Return log-probabilities of shape (batch, 2)."""
        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=2))
        x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2))
        x = F.relu(self.conv3(x))
        # Flatten per sample. view(-1, 1024) would silently mix samples across the
        # batch dimension for other input sizes; this makes fc1 fail loudly instead.
        x = F.relu(self.fc1(x.view(x.size(0), -1)))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [15]:
def train_model(model, train_input, train_target, mini_batch_size=100, nb_epochs=25, lr=1e-4):
    """Train `model` with Adam, printing the training error count after each epoch.

    Args:
        model: nn.Module mapping a mini-batch of `train_input` rows to per-class scores.
        train_input: tensor of samples, indexed along dim 0.
        train_target: tensor of class indices aligned with `train_input` along dim 0.
        mini_batch_size: samples per optimisation step; the last batch may be smaller
            when the sample count is not a multiple of it.
        nb_epochs: number of passes over the training set.
        lr: Adam learning rate.

    Raises:
        ValueError: if `train_input` is empty or `mini_batch_size` is not positive.
    """
    nb_samples = train_input.size(0)
    if nb_samples == 0:
        raise ValueError('train_input is empty')
    if mini_batch_size <= 0:
        raise ValueError('mini_batch_size must be positive')

    # CrossEntropyLoss applies log_softmax itself. log_softmax is idempotent, so the
    # loss is the same whether the model returns raw logits or log-probabilities.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for e in range(nb_epochs):
        # Ensure training mode (e.g. BatchNorm uses batch statistics) even if the
        # caller handed over a model in eval mode.
        model.train()
        for b in range(0, nb_samples, mini_batch_size):
            # Slicing, unlike narrow(), tolerates a final partial batch.
            output = model(train_input[b:b + mini_batch_size])
            loss = criterion(output, train_target[b:b + mini_batch_size])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        nb_train_errors = compute_nb_errors(model, train_input, train_target, mini_batch_size)
        print('train_error', nb_train_errors)
        # Loss of the epoch's last mini-batch, not an epoch average.
        print('loss', loss.item())


def compute_nb_errors(model, data_input, data_target, mini_batch_size=100):
    """Return how many samples have an arg-max prediction different from the target.

    The model is run in eval mode without building an autograd graph; its previous
    train/eval mode is restored afterwards, even if evaluation raises.

    Args:
        model: nn.Module returning per-class scores of shape (batch, nb_classes).
        data_input: tensor of samples, indexed along dim 0.
        data_target: tensor of class indices aligned with `data_input`.
        mini_batch_size: samples per forward pass; a final partial batch is allowed.

    Returns:
        int: number of misclassified samples.

    Raises:
        ValueError: if `mini_batch_size` is not positive.
    """
    if mini_batch_size <= 0:
        raise ValueError('mini_batch_size must be positive')

    was_training = model.training
    model.eval()
    nb_data_errors = 0
    try:
        with torch.no_grad():
            for b in range(0, data_input.size(0), mini_batch_size):
                predicted_classes = model(data_input[b:b + mini_batch_size]).argmax(dim=1)
                nb_data_errors += (predicted_classes != data_target[b:b + mini_batch_size]).sum().item()
    finally:
        model.train(was_training)
    return nb_data_errors

In [16]:
# Since PyTorch 0.4, Variable is merged into Tensor, so the tensors returned by
# generate_pair_sets are used directly (wrapping them in Variable is a no-op).

# Kept at notebook level: the training/evaluation helpers may read it as a global.
mini_batch_size = 100
nb_runs = 1  # number of independent training runs to report

for run in range(nb_runs):
    model = Net2()
    train_model(model, train_input, train_target)
    # Percentage of misclassified training samples (a float; the old trailing comma made it a 1-tuple).
    train_error_rate = compute_nb_errors(model, train_input, train_target) / train_input.size(0) * 100
    nb_test_errors = compute_nb_errors(model, test_input, test_target)
    print('train error {:0.2f}%'.format(train_error_rate))
    print('test error Net {:0.2f}% {:d}/{:d}'.format((100 * nb_test_errors) / test_input.size(0),
                                                      nb_test_errors, test_input.size(0)))

train_error 436
loss 0.7340800762176514
train_error 309
loss 0.553303599357605
train_error 233
loss 0.46417683362960815
train_error 189
loss 0.42474380135536194
train_error 164
loss 0.4004661440849304
train_error 146
loss 0.3753364086151123
train_error 123
loss 0.34730806946754456
train_error 118
loss 0.3202071487903595
train_error 101
loss 0.29634183645248413
train_error 84
loss 0.2717239260673523
train_error 75
loss 0.2512671649456024
train_error 60
loss 0.2309439778327942
train_error 53
loss 0.2125740647315979
train_error 43
loss 0.19083130359649658
train_error 36
loss 0.17270280420780182
train_error 27
loss 0.15573744475841522
train_error 22
loss 0.13866744935512543
train_error 19
loss 0.12342708557844162
train_error 15
loss 0.1078970804810524
train_error 14
loss 0.09531532973051071
train_error 12
loss 0.08284351974725723
train_error 10
loss 0.07211553305387497
train_error 10
loss 0.062252312898635864
train_error 10
loss 0.05471329763531685
train_error 13
loss 0.0485171303153038
tr