In [1]:
import os
import sys
import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

sys.path.append('../')
import dataset_loader
from modules import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
class QueryEncoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1):
        super(QueryEncoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        self.word_embeddings = nn.Embedding(self.input_size, self.hidden_size)
        self.lstm = nn.LSTM(self.hidden_size, self.hidden_size)
        
    def resetHidden(self, batch_size):
        # The axes semantics are (num_layers, minibatch_size, hidden_dim)
        self.hidden = (torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device),
                       torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device))
        
    def forward(self, query):
        batch_size = query.size(0)
        embeds = self.word_embeddings(query).view(1, batch_size, -1)
        lstm_out, self.hidden = self.lstm(embeds, self.hidden)
    
        # TODO: Maybe reshape this if its bad
        return lstm_out, self.hidden
    
class ContextEncoder(nn.Module):
    def __init__(self):
        super(ContextEncoder, self).__init__()
        
        # Init two conv layers to extract features (64 kernels)
        self.conv1 = nn.Conv2d(3, 64, 10, stride=10)
        self.conv2 = nn.Conv2d(64, 64, 1, stride=1)
        
    def forward(self, context):
        return F.relu(self.conv2(F.relu(self.conv1(context))))
    
class Decoder(nn.Module):
    def __init__(self, hidden_dim, M_dim, x_dim, num_layers = 1):
        super(Decoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.M_dim = M_dim
        self.x_dim = x_dim
        self.num_layers = num_layers
        self.output_dim = M_dim[0] * M_dim[1] + x_dim
        self.hidden = self.resetHidden(1)

        self.lstm = nn.LSTM(hidden_dim, hidden_dim)
        self.fc1 = nn.Linear(hidden_dim, 128)
        self.fc2 = nn.Linear(128, self.output_dim)
        
    def forward(self):
        # TODO: LSTMs have to have input but I dunno what it would be here.  (Currently Zeros)
        out, self.hidden = self.lstm(torch.zeros(self.hidden[0].shape, device=device), self.hidden)
        out = F.relu(self.fc1(out))
        out = self.fc2(out)
        
        batch_size = out.shape[1]
        
        M_end = self.M_dim[0] * self.M_dim[1]
        M = out[:,:,:M_end].view(batch_size,self.M_dim[0], self.M_dim[1])
        x = out[:,:,M_end:].view(batch_size,1, -1)
        
        return M, x
        
    def resetHidden(self, batch_size):
        # The axes semantics are (num_layers, minibatch_size, hidden_dim)
        return (torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(device),
                torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(device))

In [3]:
class MasterPolicy(nn.Module):
    def __init__(self, attention_modules, anwser_modules, hidden_dim, context_size):
        super(MasterPolicy, self).__init__()
        self.attention_modules = attention_modules
        self.num_att_modules = len(self.attention_modules)
        self.anwser_modules = anwser_modules
        self.hidden_dim = hidden_dim
        self.context_size = context_size
        
        self.M_dim = (self.num_att_modules, sum([m.num_attention_maps for m in self.attention_modules]))
        self.x_dim = 64
        self.decoder = Decoder(self.hidden_dim, self.M_dim, self.x_dim)
    
    def forward(self, query_hidden, context):
        batch_size = context.size(0)
        
        # TODO: Might have to do a more complex copy op
        self.decoder.hidden = query_hidden[:]
        self.a_t = torch.randn((self.M_dim[1], batch_size, self.context_size[1], self.context_size[2]), device=device)
        self.e_t = torch.randn((batch_size, self.context_size[1], self.context_size[2]), device=device)
        
        # TODO: This for loop should be replaced with some sort of thresholding junk
        for i in range(10):
            self.M_t, self.x_t = self.decoder()
            self.a_t, out = self.forward_1t(context)
            
        return F.log_softmax(out)
    
    def forward_1t(self, context):
        batch_size = context.size(0)
        b_t = torch.zeros((self.num_att_modules, batch_size, self.context_size[1], self.context_size[2]), device=device)
        
        # Attention map indexs
        num_att_map_inputs = [module.num_attention_maps for module in self.attention_modules]
        attention_map_input_index = np.cumsum(num_att_map_inputs)
        attention_map_input_index = np.insert(attention_map_input_index, 0, 0)
        
        # Run all attention modules saving output
        for i, module in enumerate(self.attention_modules):
            attention = self.a_t[np.arange(attention_map_input_index[i],attention_map_input_index[i+1])]
            if type(module) is Id:
                b_t[i] = module.forward(attention)
            elif type(module) is And:
                b_t[i] = module.forward(attention)
            elif type(module) is Or:
                b_t[i] = module.forward(attention)
            elif type(module) is Find:
                b_t[i] = module.forward(context, self.x_t).squeeze()
            elif type(module) is Relocate:
                b_t[i] = module.forward(attention, context, self.x_t)
            else:
                raise ValueError('Invalid anwser Module: {}'.format(type(module)))
            
        # Run all anwser modules
        for module in self.anwser_modules:
            if type(module) is Exist:
                out = module.forward(self.e_t)
            else:
                raise ValueError('Invalid anwser Module: {}'.format(type(module)))
        
        b_t = b_t.permute(1,2,3,0)
        self.M_t = self.M_t.permute(0,2,1)
        return torch.einsum('bijk,blk->bijl', b_t, self.M_t).permute(3,0,1,2), out

class E2E_RNMN(nn.Module):
    def __init__(self, query_size, hidden_size):
        super(E2E_RNMN, self).__init__()
        self.query_size = query_size
        self.hidden_size = hidden_size
        
        self.context_size = [64, 6, 6]
        self.attention_modules = [And(), Or(), Id(), Find(self.context_size)]#, Relocate(self.context_size)]
        self.anwser_modules = [Exist(self.context_size)]
        [module.to(device) for module in self.attention_modules + self.anwser_modules]
        
        self.query_encoder = QueryEncoder(self.query_size, self.hidden_size)
        self.context_encoder = ContextEncoder()
        self.master_policy = MasterPolicy(self.attention_modules, self.anwser_modules,
                                          self.hidden_size, self.context_size)
    
    def forward(self, query, query_len, context):
        batch_size = query.size(0)
        
        # Encode the query
        self.query_encoder.resetHidden(batch_size)
        encoder_outputs = torch.zeros(batch_size, 7, self.query_encoder.hidden_size, device=device)
        max_query_len = query.size(1)
        for ei in range(max_query_len):
            encoder_output, encoder_hidden = self.query_encoder(query[:,ei])
            encoder_outputs[:,ei,:] = encoder_output
        
        # Encode the context and start master policy forward pass'
        encoded_context = self.context_encoder(context)
        return self.master_policy(self.query_encoder.hidden, encoded_context)

In [4]:
def tensorToDevice(*tensors):
    return [tensor.to(device) for tensor in tensors]

def trainBatch(samples, queries, query_lens, labels):
    # Transfer data to gpu/cpu and pass through model
    samples, queries, query_lens, labels = tensorToDevice(samples, queries, query_lens, labels)
    output = model(queries, query_lens, samples)
    
    # Compute loss & step optimzer
    optimizer.zero_grad()
    loss = criterion(output, labels.float())
    loss.backward()
    optimizer.step()
    
    return loss.item()
    
def testBatch(samples, queries, query_lens, labels):
    with torch.no_grad():
        # Transfer data to gpu/cpu and pass through model
        samples, queries, query_lens, labels = tensorToDevice(samples, queries, query_lens, labels)
        output = model(queries, query_lens, samples)
        
        # Compute loss & acccriterionuracy
        loss = criterion(output, labels.float())
        pred = output.argmax(dim=1, keepdim=True)
        correct = pred.eq(labels.view_as(pred).long()).sum()
    
    return loss.item(), correct.item()

In [5]:
# Set hyperparams and load dataset
lr = 1e-4
hidden_size = 256
batch_size = 64
epochs = 1000

query_lang, train_loader, test_loader = dataset_loader.createScalableShapesDataLoader('v2', batch_size=batch_size)

In [6]:
# Init model
model = E2E_RNMN(query_lang.num_words, hidden_size).to(device)
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.BCELoss()

# Create TQDM progress bar
pbar = tqdm.tqdm(total=epochs)
pbar.set_description('Train Loss:0.0 | Train Acc:0.0 | Test Loss:0.0 | Test Acc:0.0')

train_losses, test_losses, test_accs = list(), list(), list()
for epoch in range(epochs):
    # Train for a single epoch iterating over the minibatches
    model.train()
    train_loss = 0
    for samples, queries, query_lens, labels in train_loader:
        train_loss += trainBatch(samples, queries, query_lens, labels)
       
    # Test for a single epoch iterating over the minibatches
    model.eval()
    test_loss, test_correct = 0, 0
    for samples, queries, query_lens, labels in test_loader:
        batch_loss, batch_correct = testBatch(samples, queries, query_lens, labels)
        test_loss += batch_loss
        test_correct += batch_correct
    
    # Bookkeeping
    train_losses.append(train_loss / len(train_loader.dataset))
    test_losses.append(test_loss / len(test_loader.dataset))
    test_accs.append(test_correct / len(test_loader.dataset))
    
    # Update progress bar
    pbar.set_description('Train Loss:{:.3f} | Test Loss:{:.3f} | Test Acc:{:.3f}'.format(
        train_losses[-1], test_losses[-1], test_accs[-1]))
    pbar.update(1)



Train Loss:0.288 | Test Loss:0.298 | Test Acc:0.326:   4%|▎         | 37/1000 [02:09<55:59,  3.49s/it] 

KeyboardInterrupt: 