In [27]:
import os.path as osp

import torchvision
from config_global import ROOT_DIR
from torch.utils.data import DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import math
fl = math.floor

In [28]:
class LeNet(nn.Module):
    def __init__(self, image_size, num_classes=10):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(image_size[0], 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)

        self.last_map_x = fl((fl((image_size[1]-4)/2)-4)/2)
        self.last_map_y = fl((fl((image_size[2]-4)/2)-4)/2)

        self.linear1 = nn.Linear(16 * self.last_map_x * self.last_map_y, 120)
        self.linear2 = nn.Linear(120, 84)
        # self.out_layer = nn.Linear(84, num_classes)
        self.out_layer = nn.Identity()

    def forward(self, inp):
        x = self.pool1(F.relu(self.conv1(inp)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * self.last_map_x * self.last_map_y)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        otp = self.out_layer(x)
        return otp

class CNNRNNFeedback(nn.Module):
    def __init__(self, out_size):
        super(CNNRNNFeedback, self).__init__()
        self.rnn_in_size = 84
        self.hidden_size = 100

        self.cnn = LeNet((1, 28, 28), 10)
        self.rnn = nn.LSTMCell(self.rnn_in_size, self.hidden_size)

        self.out_layer = nn.Linear(self.hidden_size, out_size)

    def forward(self, inp, hid_in):
        x = self.cnn(inp)
        hid_out = self.rnn(x, hid_in)
        otp = self.out_layer(hid_out[0])
        return otp, hid_out

    def init_hidden(self, batch_size):
        init_hid = (torch.zeros(batch_size, self.hidden_size),
                    torch.zeros(batch_size, self.hidden_size))
        return init_hid

In [37]:
class DelayedMatch:
    def __init__(self, sample_step, delay_step, test_step, batch_size):
        self.sample_step = sample_step
        self.delay_step = delay_step
        self.test_step = test_step

        # comparison loss
        self.criterion = nn.BCEWithLogitsLoss()

        assert batch_size % 2 == 0, 'batch size must be odd number'
        self.batch_size = batch_size
        self.split_size = int(batch_size / 2)

    def roll(self, model, data_batch):
        input_, label_ = data_batch

        # assuming the same image is not sampled twice in the batch
        inp1 = input_[:self.split_size]
        inp2 = input_[self.split_size:]

        # first self.split_size trial in the batch are match
        # last self.split_size trial in the batch are non-match
        sample_input = torch.cat((inp1, inp1), 0)
        match_input = torch.cat((inp1, inp2), 0)

        target = torch.zeros((self.batch_size, 1))
        target[:self.split_size, 0] = 1.0

        roll_step = self.sample_step + self.delay_step + self.test_step

        task_loss = 0
        pred_num = 0
        pred_correct = 0
        hidden = model.init_hidden(self.batch_size)
        for t_ in range(roll_step):
            if t_ < self.sample_step:
                model_inp = sample_input
            elif self.sample_step <= t_ < self.sample_step + self.delay_step:
                model_inp = torch.zeros_like(sample_input)
            else:
                model_inp = match_input

            output, hidden = model(model_inp, hidden)

            if t_ >= self.sample_step + self.delay_step:
                task_loss += self.criterion(output, target)

                pred_num += target.size(0)
                pred_tf = output > 0.0
                pred_correct += (pred_tf == target).sum().item()

        task_loss = task_loss / self.test_step
        return task_loss, pred_num, pred_correct


b_size = 20
num_wks = 2

trans = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), ])
data_set = torchvision.datasets.MNIST(root=osp.join(ROOT_DIR, 'data'),
                                      train=True, download=True,
                                      transform=trans)
data_loader = DataLoader(data_set, batch_size=b_size, shuffle=True,
                         num_workers=num_wks, drop_last=True)


model = CNNRNNFeedback(1)
optimizer = torch.optim.Adam(model.parameters())
task = DelayedMatch(5, 0, 5, b_size)

batch_number = 0
for data in data_loader:
    batch_number += 1
    if batch_number >= 1000:
        break

    loss, pred_num, pred_correct = task.roll(model, data)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if batch_number % 100 == 0:
        print('Loss: {}'.format(loss.item()))
        print('Accuracy: {}%'.format(100 * pred_correct / pred_num))



Loss: 0.6946204900741577
Accuracy: 54.0%
Loss: 0.5048700571060181
Accuracy: 75.0%
Loss: 0.39974266290664673
Accuracy: 85.0%
Loss: 0.3296430706977844
Accuracy: 89.0%
Loss: 0.2838034927845001
Accuracy: 90.0%
Loss: 0.11622250080108643
Accuracy: 100.0%
Loss: 0.1715230941772461
Accuracy: 93.0%
Loss: 0.22998639941215515
Accuracy: 92.0%
Loss: 0.08820346742868423
Accuracy: 99.0%
