In [1]:
import os
import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [2]:
class And(nn.Module):
    """Soft logical AND: element-wise minimum over axis 2 of the attention tensor.

    The returned tensor has axis 2 reduced away.
    """

    def __init__(self):
        super(And, self).__init__()

    def forward(self, attention):
        # torch.min(x, dim) returns a (values, indices) namedtuple; only the
        # values form the resulting attention map.
        return torch.min(attention, dim=2).values
    
class Or(nn.Module):
    """Soft logical OR: element-wise maximum over axis 2 of the attention tensor.

    The returned tensor has axis 2 reduced away.
    """

    def __init__(self):
        super(Or, self).__init__()

    def forward(self, attention):
        # torch.max(x, dim) returns a (values, indices) namedtuple; only the
        # values form the resulting attention map.
        return torch.max(attention, dim=2).values
    
class Id(nn.Module):
    """Identity module: hands its input back unchanged (same object)."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        passthrough = input
        return passthrough

In [11]:
class Find(nn.Module):
    """Find module: text-conditioned attention map, ``conv2(conv1(xvis) * W xtxt)``.

    Args:
        input_size: (H, W, C) of the visual context; C is the number of
            input channels of ``conv1``.
        num_kernels: number of channels produced by ``conv1``.
        kernel_size: convolution kernel size. Both convs use
            ``kernel_size // 2`` zero padding, so odd sizes keep the map at H x W.
        find_what_dim: size of the text vector passed to ``forward``.
    """

    def __init__(self, input_size, num_kernels=64, kernel_size=5, find_what_dim=64):
        super(Find, self).__init__()
        self.input_size = input_size
        self.num_kernels = num_kernels
        self.kernel_size = kernel_size
        height, width = self.input_size[0], self.input_size[1]

        # W * xtxt: one value per (kernel, row, column) of conv1's output.
        self.fc1 = nn.Linear(find_what_dim, height * width * self.num_kernels)
        # Without padding two 5x5 convs cannot be applied to a 7x7 map
        # (7 -> 3 -> error); 'same'-style padding keeps it at H x W.
        padding = self.kernel_size // 2
        self.conv1 = nn.Conv2d(self.input_size[-1], self.num_kernels, self.kernel_size, padding=padding)
        self.conv2 = nn.Conv2d(self.num_kernels, 1, self.kernel_size, padding=padding)

    # TODO: find_what is a bad name
    def forward(self, context, find_what):
        """Compute the attention map.

        Args:
            context: channel-first visual features, (batch, C, H, W).
            find_what: text vector(s), (find_what_dim,) or (batch, find_what_dim).

        Returns:
            Non-negative attention map of shape (batch, 1, H, W).
        """
        height, width = self.input_size[0], self.input_size[1]
        # Lay the projection out channel-first so it broadcasts against
        # conv1's (batch, num_kernels, H, W) output.
        text_map = self.fc1(find_what).view(-1, self.num_kernels, height, width)
        conv_context = F.relu(self.conv1(context))
        return F.relu(self.conv2(text_map * conv_context))
        
class Relocate(nn.Module):
    """Relocate module: ``conv2(conv1(xvis) * W1 sum(a * xvis) * W2 xtxt)``.

    Args:
        input_size: (H, W, C) of the visual context.
        num_kernels: number of channels produced by ``conv1``.
        kernel_size: convolution kernel size. Both convs use
            ``kernel_size // 2`` zero padding, so odd sizes keep the map at H x W.
        relocate_where_dim: size of the text vector passed to ``forward``.
    """

    def __init__(self, input_size, num_kernels=64, kernel_size=5, relocate_where_dim=128):
        super(Relocate, self).__init__()
        self.input_size = input_size
        self.num_kernels = num_kernels
        self.kernel_size = kernel_size
        self.relocate_where_dim = relocate_where_dim
        height, width = self.input_size[0], self.input_size[1]
        map_size = height * width * self.num_kernels

        # W1 * sum(a * xvis) and W2 * xtxt, each reshaped to conv1's output map.
        self.fc1 = nn.Linear(self.input_size[-1], map_size)
        self.fc2 = nn.Linear(self.relocate_where_dim, map_size)
        # Padding keeps the spatial size at H x W through both convs.
        padding = self.kernel_size // 2
        self.conv1 = nn.Conv2d(self.input_size[-1], self.num_kernels, self.kernel_size, padding=padding)
        self.conv2 = nn.Conv2d(self.num_kernels, 1, self.kernel_size, padding=padding)

    # TODO: relocate_where is a bad name
    def forward(self, attention, context, relocate_where):
        """Compute the relocated attention map.

        Args:
            attention: current attention; must broadcast against ``context``
                over the spatial axes, e.g. (batch, 1, H, W).
            context: channel-first visual features, (batch, C, H, W).
            relocate_where: text vector(s), (relocate_where_dim,) or
                (batch, relocate_where_dim).

        Returns:
            Non-negative attention map of shape (batch, 1, H, W).
        """
        height, width = self.input_size[0], self.input_size[1]
        conv_xvis = F.relu(self.conv1(context))
        # sum(a * xvis): attention-weighted sum over spatial positions,
        # kept per sample -> (batch, C).
        attended = (attention * context).sum(dim=(-2, -1))
        xvis_attend = F.relu(self.fc1(attended)).view(-1, self.num_kernels, height, width)
        W2_xtxt = F.relu(self.fc2(relocate_where)).view(-1, self.num_kernels, height, width)
        return F.relu(self.conv2(conv_xvis * xvis_attend * W2_xtxt))
    
class Exist(nn.Module):
    """Exist answer module: a linear score of the flattened attention map, ``W * vec(a)``.

    Args:
        input_size: (H, W, C) of the visual context; only H and W are used.
    """

    def __init__(self, input_size):
        super(Exist, self).__init__()
        self.input_size = input_size

        # vec(a) has one entry per spatial position (H * W), not per channel.
        self.fc1 = nn.Linear(self.input_size[0] * self.input_size[1], 1)

    def forward(self, attention):
        """Score one or more H x W attention maps.

        Args:
            attention: any tensor holding H * W consecutive values per map,
                e.g. (H, W, 1) or (batch, 1, H, W).

        Returns:
            Tensor of shape (num_maps, 1).
        """
        flat = attention.reshape(-1, self.input_size[0] * self.input_size[1])
        return self.fc1(flat)

In [27]:
class QueryEncoder(nn.Module):
    """Encodes a tokenised query with a word embedding followed by a one-layer LSTM.

    Args:
        input_size: vocabulary size (number of rows of the embedding table).
        hidden_dim: LSTM hidden size.
        embed_size: word-embedding size.
    """

    def __init__(self, input_size, hidden_dim, embed_size):
        super(QueryEncoder, self).__init__()
        self.hidden_dim = hidden_dim

        # Word to Vector Embedding
        self.word_embeddings = nn.Embedding(input_size, embed_size)

        # LSTM
        self.lstm = nn.LSTM(embed_size, hidden_dim)
        # Final (h, c) of the most recent forward() call.
        self.hidden = self.init_hidden()

    def init_hidden(self, device=None):
        # The axes semantics are (num_layers, minibatch_size, hidden_dim)
        return (torch.zeros(1, 1, self.hidden_dim, device=device),
                torch.zeros(1, 1, self.hidden_dim, device=device))

    def forward(self, query):
        """Encode a 1-D tensor of token indices.

        Every call starts from a zero state, so consecutive queries are
        encoded independently and the previous call's autograd graph is not
        kept alive (which would break a second ``backward``).

        Returns:
            The final LSTM state ``(h, c)``, each of shape (1, 1, hidden_dim).
        """
        embeds = self.word_embeddings(query)
        initial_state = self.init_hidden(device=embeds.device)
        _, self.hidden = self.lstm(embeds.view(len(query), 1, -1), initial_state)

        # TODO: Maybe reshape this if its bad
        return self.hidden
    
class ContextEncoder(nn.Module):
    """Convolutional feature extractor for the visual context.

    A 10x10, stride-10 convolution maps the 3 input channels to 64 feature
    maps; a 1x1 convolution then mixes those 64 channels. Each convolution
    is followed by a ReLU.
    """

    def __init__(self):
        super(ContextEncoder, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=10, stride=10)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1, stride=1)

    def forward(self, context):
        features = F.relu(self.conv1(context))
        mixed = self.conv2(features)
        return F.relu(mixed)

class MasterPolicy(nn.Module):
    """Controller that runs every attention module for a fixed number of steps,
    mixing their maps into the inputs of the next step, then runs the answer
    modules.

    Args:
        attention_modules: iterable of And / Or / Id / Find / Relocate modules.
        anwser_modules: iterable of answer modules (Exist).
        hidden_dim: size of the query encoding.
        context_size: (H, W, C) of the encoded visual context.
        num_steps: number of reasoning steps run by ``forward``.
            TODO: placeholder; a real stopping rule is not designed yet.
    """

    # Channels of each module-input tensor, in the order
    # And, Or, Id, Relocate, Exist (2 + 2 + 1 + 1 + 1 = 7).
    INPUT_CHANNELS = (2, 2, 1, 1, 1)

    def __init__(self, attention_modules, anwser_modules, hidden_dim, context_size, num_steps=1):
        super(MasterPolicy, self).__init__()
        # nn.ModuleList registers the sub-modules, so their parameters appear
        # in .parameters() (a plain list hides them from the optimizer).
        self.attention_modules = nn.ModuleList(attention_modules)
        self.anwser_modules = nn.ModuleList(anwser_modules)
        self.hidden_dim = hidden_dim
        self.num_steps = num_steps

        self.hidden = self.init_hidden()
        self.context_size = context_size
        height, width = self.context_size[0], self.context_size[1]

        # Inputs of the And / Or / Id / Relocate / Exist modules: random at
        # first, then overwritten after every step with the mixed maps.
        # TODO: Don't love these names
        self.and_attention = torch.randn((height, width, 2))
        self.or_attention = torch.randn((height, width, 2))
        self.id_input = torch.randn((height, width, 1))
        self.relocate_attention = torch.randn((height, width, 1))
        self.exist_attention = torch.randn((height, width, 1))

        # M: mixes the attention-module maps (one per module) into the 7
        # input channels listed in INPUT_CHANNELS.
        # TODO: confirm M should be a learned parameter rather than being
        # predicted from the query at every step.
        self.module_weights = nn.Parameter(
            torch.randn(len(self.attention_modules), sum(self.INPUT_CHANNELS)))

    def init_hidden(self):
        # The axes semantics are (num_layers, minibatch_size, hidden_dim)
        return (torch.zeros(1, 1, self.hidden_dim),
                torch.zeros(1, 1, self.hidden_dim))

    def forward(self, query, context):
        """Run ``num_steps`` reasoning steps, then the answer modules.

        Args:
            query: the query encoding (stored as ``self.hidden``); it is
                passed on to Find / Relocate.
            context: encoded visual context passed to Find / Relocate.

        Returns:
            Output of the last answer module, or None if there are none.
        """
        # The hidden state is a tuple, which does not support slice
        # assignment, so rebind it instead of copying in place.
        self.hidden = query
        for _ in range(self.num_steps):
            self._set_module_inputs(self.forward_1t(context))
        return self._run_answer_modules()

    def forward_1t(self, context):
        """Run one step: every attention module, then mix their maps.

        Returns:
            Tensor of shape (H, W, 7): the mixed maps, channel groups ordered
            as in INPUT_CHANNELS.

        Raises:
            ValueError: if an attention module has an unsupported type.
        """
        height, width = self.context_size[0], self.context_size[1]
        # b[:, :, k] holds the map produced by the k-th attention module.
        b = torch.zeros(height, width, len(self.attention_modules))

        # Run all attention modules saving output
        for i, module in enumerate(self.attention_modules):
            if type(module) is Id:
                b[:, :, i] = module(self.id_input)
            elif type(module) is And:
                b[:, :, i] = module(self.and_attention)
            elif type(module) is Or:
                b[:, :, i] = module(self.or_attention)
            elif type(module) is Find:
                b[:, :, i] = module(context, self.hidden)
            elif type(module) is Relocate:
                b[:, :, i] = module(self.relocate_attention, context, self.hidden)
            else:
                raise ValueError('Invalid attention Module: {}'.format(type(module)))

        # With N x N maps and one map per attention module, b is N x N x 5
        # and M is 5 x 7, so a[i, j, l] = sum_k b[i, j, k] * M[k, l] is
        # N x N x 7: the inputs for And (2), Or (2), Id, Relocate, Exist.
        return torch.einsum('ijk,kl->ijl', b, self.module_weights)

    def _set_module_inputs(self, mixed):
        """Split the (H, W, 7) mixed maps into the per-module input tensors."""
        (self.and_attention, self.or_attention, self.id_input,
         self.relocate_attention, self.exist_attention) = torch.split(
            mixed, list(self.INPUT_CHANNELS), dim=2)

    def _run_answer_modules(self):
        """Run every answer module on its input; return the last output."""
        out = None
        for module in self.anwser_modules:
            if type(module) is Exist:
                out = module(self.exist_attention)
            else:
                raise ValueError('Invalid anwser Module: {}'.format(type(module)))
        return out

class E2E_RNMN(nn.Module):
    """End-to-end model: query encoder + context encoder + master policy.

    ``context_size`` = (H, W, C) = (7, 7, 64); C matches the 64 output
    channels of ContextEncoder.
    """

    def __init__(self):
        super(E2E_RNMN, self).__init__()
        self.context_size = [7, 7, 64]

        input_size = 64    # vocabulary size of the query embedding table
        hidden_dim = 128   # LSTM hidden size, i.e. size of the query encoding
        embed_size = 256   # word-embedding size

        # Find and Relocate consume the query encoding, so their text-vector
        # sizes must equal hidden_dim (Find's default of 64 did not).
        # nn.ModuleList registers the modules' parameters with this model.
        self.attention_modules = nn.ModuleList([
            And(),
            Or(),
            Id(),
            Find(self.context_size, find_what_dim=hidden_dim),
            Relocate(self.context_size, relocate_where_dim=hidden_dim),
        ])
        self.anwser_modules = nn.ModuleList([Exist(self.context_size)])

        self.query_encoder = QueryEncoder(input_size, hidden_dim, embed_size)
        self.context_encoder = ContextEncoder()
        self.master_policy = MasterPolicy(self.attention_modules, self.anwser_modules, hidden_dim, self.context_size)

    def forward(self, query, context):
        encoded_query = self.query_encoder(query)
        encoded_context = self.context_encoder(context)
        return self.master_policy(encoded_query, encoded_context)

In [28]:
# Hyperparameters and RNG seeding (seeded before the model is built so that
# weight initialisation is reproducible).
# TODO: load the dataset.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

lr = 1e-4
batch_size = 64

In [29]:
# Init model and optimizer (uses the learning rate from the hyperparameter cell)
model = E2E_RNMN()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)