In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import math
fl = math.floor

In [4]:
class LeNet(nn.Module):
    """LeNet-style CNN: two (5x5 conv -> ReLU -> 2x2 max-pool) stages, then three linear layers.

    Args:
        image_size: ``(channels, height, width)`` of the input images. ``image_size[0]``
            is the number of input channels of ``conv1``; ``image_size[1]`` and
            ``image_size[2]`` determine the flattened feature length fed to ``linear1``.
        num_classes: width of the final linear layer, i.e. of the returned logits.

    Raises:
        ValueError: if either spatial size is smaller than 16, which would leave no
            spatial positions after the second pooling stage.
    """

    def __init__(self, image_size, num_classes=10):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(image_size[0], 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)

        # Spatial size of the last feature map after conv(5, no padding) -> pool(2)
        # -> conv(5) -> pool(2).
        self.last_map_x = self._feature_map_size(image_size[1])
        self.last_map_y = self._feature_map_size(image_size[2])
        if self.last_map_x < 1 or self.last_map_y < 1:
            raise ValueError(
                "image_size %r is too small for LeNet: both spatial dims must be >= 16"
                % (tuple(image_size),))

        self.linear1 = nn.Linear(16 * self.last_map_x * self.last_map_y, 120)
        self.linear2 = nn.Linear(120, 84)
        self.out_layer = nn.Linear(84, num_classes)

    @staticmethod
    def _feature_map_size(size):
        """Side length left after two (5x5 valid conv, 2x2 stride-2 max-pool) stages."""
        # Integer floor division matches MaxPool2d's default floor rounding.
        return ((int(size) - 4) // 2 - 4) // 2

    def forward(self, inp):
        """Map images to ``num_classes`` logits (no final activation).

        Args:
            inp: ``(N, C, H, W)`` batch, or a single ``(C, H, W)`` image, with
                ``C, H, W`` matching the ``image_size`` given at construction.

        Returns:
            Tensor of shape ``(N, num_classes)``, or ``(1, num_classes)`` for a single image.
        """
        x = self.pool1(F.relu(self.conv1(inp)))
        x = self.pool2(F.relu(self.conv2(x)))
        # Flatten only the trailing (C, H, W) dims: unlike view(-1, ...), this works on
        # non-contiguous (e.g. channels_last) tensors and never folds spatial features
        # into the batch dimension; a size mismatch now errors in linear1 instead.
        x = torch.flatten(x, start_dim=-3)
        if x.dim() == 1:
            # Unbatched input: keep the original (1, features) layout.
            x = x.unsqueeze(0)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        otp = self.out_layer(x)
        return otp

class CNNRNNFeedback(nn.Module):
    """LeNet feature extractor feeding an LSTM cell, with a linear read-out of the hidden state.

    Args:
        out_size: width of the read-out layer applied to the LSTM hidden state.
        image_size: ``(channels, height, width)`` of the input images; defaults to the
            previously hard-coded ``(1, 28, 28)``.
    """

    def __init__(self, out_size, image_size=(1, 28, 28)):
        super(CNNRNNFeedback, self).__init__()
        self.rnn_in_size = 84
        self.hidden_size = 100

        # The CNN's last layer must emit exactly `rnn_in_size` features for the LSTMCell.
        # The previous `LeNet(..., 10)` produced 10 features, so every forward call
        # failed with a shape mismatch in self.rnn.
        self.cnn = LeNet(image_size, self.rnn_in_size)
        self.rnn = nn.LSTMCell(self.rnn_in_size, self.hidden_size)

        self.out_layer = nn.Linear(self.hidden_size, out_size)

    def init_hidden(self, batch_expand=1):
        """Return the initial recurrent state to pass as ``hid_in`` on the first step.

        ``None`` makes ``nn.LSTMCell`` start from zero hidden and cell states sized to the
        batch given to ``forward``, so ``batch_expand`` is accepted only for call-site
        compatibility (``DelayedMatch.roll`` calls ``init_hidden(batch_expand=2)``).
        """
        return None

    def forward(self, inp, hid_in=None):
        """Run one recurrent step.

        Args:
            inp: image batch for the LeNet feature extractor.
            hid_in: ``(h, c)`` tuple from the previous step, or ``None`` for a zero state.

        Returns:
            ``(otp, hid_out)``: read-out of width ``out_size`` computed from the new hidden
            state, and the new ``(h, c)`` tuple to feed back on the next step.
        """
        x = self.cnn(inp)
        hid_out = self.rnn(x, hid_in)
        otp = self.out_layer(hid_out[0])
        return otp, hid_out

In [5]:
class DelayedMatch:
    """Delayed match-to-sample task driver for a recurrent model.

    Every trial presents a sample image for ``sample_step`` steps, an all-zero input for
    ``delay_step`` steps, then a test image for ``test_step`` steps. The first
    ``batch_size`` trials are "match" trials (test image == sample image, target 1) and
    the next ``batch_size`` are "non-match" trials (target 0). The model's outputs on
    every test step are scored with ``BCEWithLogitsLoss``.

    Args:
        sample_step: number of steps the sample image is shown.
        delay_step: number of blank steps between sample and test.
        test_step: number of steps the test image is shown; must be >= 1 because the loss
            is averaged over these steps.
        batch_size: number of match trials (and of non-match trials) per roll; each data
            batch must therefore contain at least ``2 * batch_size`` images.

    Raises:
        ValueError: if ``test_step < 1``.
    """

    def __init__(self, sample_step, delay_step, test_step, batch_size):
        if test_step < 1:
            raise ValueError("test_step must be >= 1, got %r" % (test_step,))
        self.sample_step = sample_step
        self.delay_step = delay_step
        self.test_step = test_step

        # Match / non-match comparison loss on one logit per trial.
        self.criterion = nn.BCEWithLogitsLoss()

        self.Batch_Size = batch_size
        # Total trials per roll: batch_size match + batch_size non-match.
        self.task_batch_s = 2 * self.Batch_Size

    def roll(self, model, data_batch, train=False, test=False, evaluate=False):
        """Run one batch of trials through ``model`` and score the test-phase outputs.

        Args:
            model: called as ``model(inp, hidden) -> (output, hidden)`` and must provide
                ``init_hidden(batch_expand=2)``. ``output`` must have the same shape as the
                ``(2 * batch_size, 1)`` target.
            data_batch: ``(inputs, labels)`` pair; only ``inputs`` is used. Rows
                ``[0, batch_size)`` are the sample images and rows
                ``[batch_size, 2 * batch_size)`` the non-match test images; any further
                rows are ignored.
            train, test, evaluate: exactly one must be True.

        Returns:
            ``task_loss`` (loss averaged over the test steps) in train mode; otherwise
            ``(task_loss, pred_num, pred_correct)``, where a trial is predicted "match"
            when its logit is > 0.

        Raises:
            ValueError: if ``inputs`` has fewer than ``2 * batch_size`` rows.
        """
        assert train + test + evaluate == 1, "only one mode should be activated"
        input_, _ = data_batch
        n_trials = self.task_batch_s
        if input_.size(0) < n_trials:
            raise ValueError(
                "data batch has %d rows; DelayedMatch needs at least 2 * batch_size = %d"
                % (input_.size(0), n_trials))

        # Assumes the same image is not sampled twice in the batch, so inp2 differs from inp1.
        inp1 = input_[:self.Batch_Size]
        # Bounded slice: extra rows would otherwise make match_input longer than target.
        inp2 = input_[self.Batch_Size:n_trials]

        # The first batch_size trials are matches, the last batch_size are non-matches.
        sample_input = torch.cat((inp1, inp1), 0)
        match_input = torch.cat((inp1, inp2), 0)
        # Hoisted out of the loop: the same blank input is used on every delay step.
        blank_input = torch.zeros_like(sample_input)

        # Created on the inputs' device so models running on an accelerator do not hit
        # a CPU/device mismatch in the loss.
        target = torch.zeros((n_trials, 1), device=input_.device)
        target[:self.Batch_Size, 0] = 1.0

        delay_end = self.sample_step + self.delay_step
        roll_step = delay_end + self.test_step
        scoring = test or evaluate

        task_loss = 0
        pred_num = 0
        pred_correct = 0
        hidden = model.init_hidden(batch_expand=2)
        for t_ in range(roll_step):
            if t_ < self.sample_step:
                model_inp = sample_input
            elif t_ < delay_end:
                model_inp = blank_input
            else:
                model_inp = match_input

            output, hidden = model(model_inp, hidden)

            if t_ >= delay_end:
                task_loss += self.criterion(output, target)

                if scoring:
                    pred_num += target.size(0)
                    pred_tf = output > 0.0
                    pred_correct += (pred_tf == target).sum().item()

        task_loss = task_loss / self.test_step
        if scoring:
            return task_loss, pred_num, pred_correct
        return task_loss

In [None]:


# Load MNIST with a ToTensor() transform. The Normalize step (presumably MNIST's
# mean/std) is kept disabled; uncomment it inside Compose to enable it.
# NOTE(review): `transform`, `torchvision`, `osp`, `ROOT_DIR` and `train_flag` are not
# defined in any cell visible here — confirm an earlier cell imports/defines them,
# otherwise this cell fails under Restart & Run All.
trans = transform.Compose([
    transform.ToTensor(),
    # transform.Normalize((0.1307,), (0.3081,)),
])

data_set = torchvision.datasets.MNIST(
    root=osp.join(ROOT_DIR, 'data'),
    train=train_flag,
    download=True,
    transform=trans,
)
# Shape of one image, in the (channels, H, W) order LeNet expects for `image_size`.
datum_size = (1, 28, 28)