In [1]:
%matplotlib inline

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import itertools
import random
import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm, tqdm_notebook

In [2]:
# Seed every random number generator the notebook imports so that
# Restart Kernel -> Run All reproduces the same results.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Create new floating point tensors as float64 by default.
# torch.set_default_dtype replaces the deprecated torch.set_default_tensor_type.
torch.set_default_dtype(torch.float64)

In [3]:
class Expression:
    """One arithmetic operation folded left-to-right over a list of operands.

    On construction two character grids are cached:

    * ``vertical_view``     -- operands stacked one per row, right-aligned,
      followed by a row of ``'_'`` and a blank solution row.
    * ``vertical_view_sol`` -- the same grid with the solution row filled in.
    """

    @staticmethod
    def add(a, b):
        """Return ``a + b``."""
        return a + b

    @staticmethod
    def subtract(a, b):
        """Return ``a - b``."""
        return a - b

    @staticmethod
    def multiply(a, b):
        """Return ``a * b``."""
        return a * b

    @staticmethod
    def divide(a, b):
        """Return ``a / b``."""
        return a / b

    def __init__(self, operation, operands):
        """Store the operation and operands and pre-compute both grids.

        Args:
            operation: one of ``Expression.add``, ``Expression.subtract``,
                ``Expression.multiply`` or ``Expression.divide``.
            operands: non-empty sequence of numbers.
        """
        assert operation in (Expression.add, Expression.subtract, Expression.multiply, Expression.divide)
        self.operation = operation
        self.operands = operands
        self.vertical_view = self.get_vertical_view(with_solution=False)
        self.vertical_view_sol = self.get_vertical_view(with_solution=True)

    def evaluate(self):
        """Apply the operation across the operands from left to right and return the result."""
        result = self.operands[0]
        for n in self.operands[1:]:
            result = self.operation(result, n)
        return result

    def get_vertical_view(self, with_solution=True):
        """Build the character grid for the expression.

        Args:
            with_solution: when True the last row holds the characters of the
                evaluated result; otherwise that row is left blank.

        Returns:
            A 2-D numpy array of single characters with ``len(operands) + 2``
            rows; its width is the longest string among the operands and the
            result.
        """
        str_operands = [str(n) for n in self.operands]
        solution = str(self.evaluate())
        width = max(max(len(s) for s in str_operands), len(solution))

        # '<U1' is an explicit one-character unicode dtype; the former np.str
        # alias no longer exists in current NumPy releases.
        grid = np.full((len(str_operands) + 2, width), ' ', dtype='<U1')
        grid[-2, :] = '_'

        # Right-align every operand inside its own row.
        for row, operand in enumerate(str_operands):
            grid[row, width - len(operand):] = list(operand)

        if with_solution:
            grid[-1, width - len(solution):] = list(solution)
        return grid

    def get_op_string(self):
        """Return the operator symbol ('+', '-', '*' or '/') for the stored operation."""
        if self.operation is Expression.add:
            return '+'
        if self.operation is Expression.subtract:
            return '-'
        if self.operation is Expression.multiply:
            return '*'
        if self.operation is Expression.divide:
            return '/'

    def __repr__(self):
        """Render as ``a<op>b<op>...=result``."""
        return self.get_op_string().join((str(n) for n in self.operands)) + '=' + str(self.evaluate())

In [4]:
class ViewDirection:
    """A move of the viewing position, given as one vertical and one horizontal instruction.

    Each instruction either jumps to an edge of the grid, steps one cell
    relative to the current coordinate, or leaves that axis unchanged.
    """

    UP_EDGE = "UP_EDGE"
    DOWN_EDGE = "DOWN_EDGE"
    LEFT_EDGE = "LEFT_EDGE"
    RIGHT_EDGE = "RIGHT_EDGE"
    UP_RELATIVE = "UP_RELATIVE"
    DOWN_RELATIVE = "DOWN_RELATIVE"
    LEFT_RELATIVE = "LEFT_RELATIVE"
    RIGHT_RELATIVE = "RIGHT_RELATIVE"
    DO_NOTHING = "DO_NOTHING"

    def __init__(self, vertical, horizontal):
        """Validate and store the instruction for each axis."""
        allowed_vertical = (ViewDirection.UP_EDGE,
                            ViewDirection.DOWN_EDGE,
                            ViewDirection.UP_RELATIVE,
                            ViewDirection.DOWN_RELATIVE,
                            ViewDirection.DO_NOTHING)
        allowed_horizontal = (ViewDirection.LEFT_EDGE,
                              ViewDirection.RIGHT_EDGE,
                              ViewDirection.LEFT_RELATIVE,
                              ViewDirection.RIGHT_RELATIVE,
                              ViewDirection.DO_NOTHING)
        assert vertical in allowed_vertical
        assert horizontal in allowed_horizontal

        self.vertical = vertical
        self.horizontal = horizontal

    @staticmethod
    def _resolve_axis(move, axis, low_edge, high_edge, step_low, step_high,
                      cur_coord, shape):
        """Return the new index along ``axis`` for one instruction.

        ``cur_coord`` and ``shape`` are only indexed by the branches that need
        them, so an edge move works even when there is no current coordinate.
        """
        if move == low_edge:
            return 0
        if move == high_edge:
            return shape[axis] - 1
        if move == step_low:
            return cur_coord[axis] - 1
        if move == step_high:
            return cur_coord[axis] + 1
        if move == ViewDirection.DO_NOTHING:
            return cur_coord[axis]
        assert False

    def get_coord(self, cur_coord, vertical_view_shape):
        """Apply this direction to ``cur_coord`` inside a grid of the given shape.

        Returns:
            ``(row, column)`` after the move. No bounds checking is performed.
        """
        row = ViewDirection._resolve_axis(self.vertical, 0,
                                          ViewDirection.UP_EDGE,
                                          ViewDirection.DOWN_EDGE,
                                          ViewDirection.UP_RELATIVE,
                                          ViewDirection.DOWN_RELATIVE,
                                          cur_coord, vertical_view_shape)
        column = ViewDirection._resolve_axis(self.horizontal, 1,
                                             ViewDirection.LEFT_EDGE,
                                             ViewDirection.RIGHT_EDGE,
                                             ViewDirection.LEFT_RELATIVE,
                                             ViewDirection.RIGHT_RELATIVE,
                                             cur_coord, vertical_view_shape)
        return (row, column)

    def __repr__(self):
        """Show the pair of instructions as a tuple."""
        return repr((self.vertical, self.horizontal))
    
    
class AdditionSolver:
    """Scripted column-by-column solver for an addition ``Expression``.

    Every call to ``action`` performs one step and appends the chosen view
    direction, the resulting coordinate and the written digit (or None) to the
    three history lists.
    """

    def __init__(self, expression):
        assert expression.operation == Expression.add
        self.expression = expression
        self.view_dir_history = []
        self.coord_history = []
        self.write_history = []
        # Sum accumulated for the current column, including the carry from the previous one.
        self.current_value = 0

    def action(self, view_direction):
        """
        Returns a ViewDirection, next_coord, and a digit to write
        """
        grid = self.expression.vertical_view_sol
        row_count = grid.shape[0]
        here = self.coord_history[-1] if self.coord_history else None
        digit = None

        if view_direction is None:
            # Nothing has been done yet: jump to the top-right corner.
            move = ViewDirection(ViewDirection.UP_EDGE, ViewDirection.RIGHT_EDGE)
        elif here[0] == row_count - 1:
            # Solution row: continue with the column to the left, or stop at column 0.
            if here[1] > 0:
                move = ViewDirection(ViewDirection.UP_EDGE, ViewDirection.LEFT_RELATIVE)
            else:
                move = None
            # Write the ones digit and keep the rest as the carry.
            self.current_value, ones = divmod(self.current_value, 10)
            digit = str(ones)
        elif here[0] == row_count - 2:
            # The bar row carries no value: just move down.
            move = ViewDirection(ViewDirection.DOWN_RELATIVE, ViewDirection.DO_NOTHING)
        else:
            # Operand row: add the digit under the view (blanks count as nothing), then move down.
            cell = grid[here]
            if cell != ' ':
                self.current_value += int(cell)
            move = ViewDirection(ViewDirection.DOWN_RELATIVE, ViewDirection.DO_NOTHING)

        target = move.get_coord(here, grid.shape) if move is not None else None

        self.view_dir_history.append(move)
        self.coord_history.append(target)
        self.write_history.append(digit)
        return move, target, digit

In [5]:
# Run the scripted addition solver on an 11-operand sum and log every step.
e = Expression(Expression.add, [999]*11)
solver = AdditionSolver(e)
print(e.vertical_view_sol)

# The first call is given None; afterwards each returned direction is fed
# back in until the solver returns None as the next direction.
next_view_direction = None
while True:
    next_view_direction, next_coord, write = solver.action(next_view_direction)
    print(next_view_direction, next_coord, write)
    if next_view_direction is None:
        break

[[' ' ' ' '9' '9' '9']
 [' ' ' ' '9' '9' '9']
 [' ' ' ' '9' '9' '9']
 [' ' ' ' '9' '9' '9']
 [' ' ' ' '9' '9' '9']
 [' ' ' ' '9' '9' '9']
 [' ' ' ' '9' '9' '9']
 [' ' ' ' '9' '9' '9']
 [' ' ' ' '9' '9' '9']
 [' ' ' ' '9' '9' '9']
 [' ' ' ' '9' '9' '9']
 ['_' '_' '_' '_' '_']
 ['1' '0' '9' '8' '9']]
('UP_EDGE', 'RIGHT_EDGE') (0, 4) None
('DOWN_RELATIVE', 'DO_NOTHING') (1, 4) None
('DOWN_RELATIVE', 'DO_NOTHING') (2, 4) None
('DOWN_RELATIVE', 'DO_NOTHING') (3, 4) None
('DOWN_RELATIVE', 'DO_NOTHING') (4, 4) None
('DOWN_RELATIVE', 'DO_NOTHING') (5, 4) None
('DOWN_RELATIVE', 'DO_NOTHING') (6, 4) None
('DOWN_RELATIVE', 'DO_NOTHING') (7, 4) None
('DOWN_RELATIVE', 'DO_NOTHING') (8, 4) None
('DOWN_RELATIVE', 'DO_NOTHING') (9, 4) None
('DOWN_RELATIVE', 'DO_NOTHING') (10, 4) None
('DOWN_RELATIVE', 'DO_NOTHING') (11, 4) None
('DOWN_RELATIVE', 'DO_NOTHING') (12, 4) None
('UP_EDGE', 'LEFT_RELATIVE') (0, 3) 9
('DOWN_RELATIVE', 'DO_NOTHING') (1, 3) None
('DOWN_RELATIVE', 'DO_NOTHING') (2, 3) None
('DOW

In [24]:
class Solver:
    """Encoding and decoding between grid characters / solver actions and one-hot vectors.

    The ``*_IND2*`` lists define the index order of every vector; the
    ``*2IND`` dictionaries are their inverse lookups.
    """

    INPUT_DIM             = 12  # Blank, line, any of 10 numerals
    OUTPUT_DONE_DIM       = 1   # Marks whether problem is done or not
    OUTPUT_WRITE_DIM      = 11  # Any of 10 numerals, do nothing -> d=11
    OUTPUT_VIEW_VERT_DIM  = 5   # Any of 5 moves -> d=5
    OUTPUT_VIEW_HORIZ_DIM = 5   # Any of 5 moves -> d=5

    INPUT_IND2CHAR        = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ' ', '_']
    OUTPUT_IND2CHAR       = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', None]
    OUTPUT_IND2VVERT      = [ViewDirection.DO_NOTHING,
                             ViewDirection.UP_EDGE,
                             ViewDirection.DOWN_EDGE,
                             ViewDirection.UP_RELATIVE,
                             ViewDirection.DOWN_RELATIVE]
    OUTPUT_IND2VHORIZ     = [ViewDirection.DO_NOTHING,
                             ViewDirection.LEFT_EDGE,
                             ViewDirection.RIGHT_EDGE,
                             ViewDirection.LEFT_RELATIVE,
                             ViewDirection.RIGHT_RELATIVE]

    # Inverse lookups. The name `Solver` does not exist yet while the class
    # body runs, and a comprehension body cannot see class-level names; only
    # its outermost iterable is evaluated in class scope, so the lists are
    # passed through enumerate(...) there.
    INPUT_CHAR2IND        = {char: ind for ind, char in enumerate(INPUT_IND2CHAR)}
    OUTPUT_CHAR2IND       = {char: ind for ind, char in enumerate(OUTPUT_IND2CHAR)}
    OUTPUT_VVERT2IND      = {move: ind for ind, move in enumerate(OUTPUT_IND2VVERT)}
    OUTPUT_VHORIZ2IND     = {move: ind for ind, move in enumerate(OUTPUT_IND2VHORIZ)}

    def __init__(self, model):
        """Keep a reference to the model this solver wraps."""
        self.model = model

    @staticmethod
    def vectorize_input(char):
        """One-hot encode a grid character.

        Args:
            char: a key of ``INPUT_CHAR2IND`` or None.

        Returns:
            Tensor of length ``INPUT_DIM``; all zeros when ``char`` is None.
        """
        vector = torch.zeros(Solver.INPUT_DIM)
        if char is not None:
            # Static method: class constants are reached through Solver, there is no `self`.
            vector[Solver.INPUT_CHAR2IND[char]] = 1
        return vector

    @staticmethod
    def vectorize_output(view_direction, write):
        """One-hot encode one solver step.

        Args:
            view_direction: a ``ViewDirection``, or None to mark "done".
            write: digit character to write, or None / empty for "write nothing".

        Returns:
            ``(done, vertical, horizontal, write)`` tensors with lengths
            ``OUTPUT_DONE_DIM``, ``OUTPUT_VIEW_VERT_DIM``,
            ``OUTPUT_VIEW_HORIZ_DIM`` and ``OUTPUT_WRITE_DIM``.
        """
        view_done_vector = torch.zeros(Solver.OUTPUT_DONE_DIM)
        view_vert_vector = torch.zeros(Solver.OUTPUT_VIEW_VERT_DIM)
        view_horiz_vector = torch.zeros(Solver.OUTPUT_VIEW_HORIZ_DIM)
        write_vector = torch.zeros(Solver.OUTPUT_WRITE_DIM)

        if view_direction is None: # done condition
            view_done_vector[0] = 1
            vertical, horizontal = ViewDirection.DO_NOTHING, ViewDirection.DO_NOTHING
        else:
            vertical, horizontal = view_direction.vertical, view_direction.horizontal
        view_vert_vector[Solver.OUTPUT_VVERT2IND[vertical]] = 1
        # Horizontal moves have their own table; the vertical one has no LEFT_*/RIGHT_* keys.
        view_horiz_vector[Solver.OUTPUT_VHORIZ2IND[horizontal]] = 1

        # "Write nothing" has a dedicated slot keyed by None. Indexing the
        # tensor itself with None would broadcast the 1 into every slot.
        write_key = write if write else None
        write_vector[Solver.OUTPUT_CHAR2IND[write_key]] = 1

        return view_done_vector, view_vert_vector, view_horiz_vector, write_vector

    @staticmethod
    def devectorize_output(view_done_vector,
                           view_vert_vector,
                           view_horiz_vector,
                           write_vector):
        """Decode output vectors back into ``(done, ViewDirection, write)``.

        ``done`` is True when the first "done" entry exceeds 0.5; the other
        three values are the table entries at each vector's argmax.
        """
        done       = bool(view_done_vector[0] > .5)
        # int(...) turns the 0-dim argmax tensor into a plain list index.
        view_vert  = Solver.OUTPUT_IND2VVERT[int(view_vert_vector.argmax())]
        view_horiz = Solver.OUTPUT_IND2VHORIZ[int(view_horiz_vector.argmax())]
        write      = Solver.OUTPUT_IND2CHAR[int(write_vector.argmax())]
        view_direction = ViewDirection(view_vert, view_horiz)

        return done, view_direction, write

    def view_coord(self, expression, coord):
        """Return the one-hot encoding of the unsolved grid cell at ``coord``."""
        value = expression.vertical_view[coord]
        vector = Solver.vectorize_input(value)
        return vector

In [25]:
"""
LSTM model with a single hidden layer and multiple output layers
"""
class LSTM(nn.Module):
    """Single LSTM cell feeding several independent linear output heads.

    The recurrent state (``lstm_h``, ``lstm_c``) is kept on the module between
    calls, so each ``forward`` call consumes one time step of one sequence.
    Call ``reset`` before starting a new sequence.

    Args:
        input_size: length of the input vector of one time step.
        hidden_layer_size: size of the LSTM hidden and cell state.
        output_sizes: one output dimension per head.
        output_nonlinears: one callable per head, applied to that head's
            linear output; must have the same length as ``output_sizes``.
    """

    def __init__(self, input_size,
                       hidden_layer_size,
                       output_sizes,
                       output_nonlinears):
        super(LSTM, self).__init__()
        assert len(output_sizes) == len(output_nonlinears)

        self.input_size = input_size
        self.hidden_layer_size = hidden_layer_size
        self.output_sizes = output_sizes
        self.output_nonlinears = output_nonlinears

        self.lstm = nn.LSTMCell(input_size, hidden_layer_size)
        self.output_layers = nn.ModuleList(
            nn.Linear(hidden_layer_size, output_size) for output_size in output_sizes
        )
        self.reset()

    def _zero_state(self):
        """Return a zero tensor of shape (1, hidden_layer_size) matching the cell's parameters."""
        reference = self.lstm.weight_ih
        return torch.zeros(1, self.hidden_layer_size,
                           dtype=reference.dtype, device=reference.device)

    def reset(self):
        """Zero the hidden and cell state."""
        self.lstm_h = self._zero_state()
        self.lstm_c = self._zero_state()

    def forward(self, x):
        """Advance the LSTM by one step and evaluate every output head.

        Args:
            x: tensor holding ``input_size`` elements; it is reshaped to a
                batch of one.

        Returns:
            Tuple with one 1-D tensor per head, in the order of ``output_sizes``.
        """
        input_layer = x.reshape(1, self.input_size)
        self.lstm_h, self.lstm_c = self.lstm(input_layer, (self.lstm_h, self.lstm_c))
        # Built eagerly as a tuple so the heads are evaluated against the state
        # of this step. squeeze(0) removes only the batch dimension, so a head
        # of size 1 stays indexable as a length-1 vector.
        return tuple(f(layer(self.lstm_h).squeeze(0))
                     for f, layer in zip(self.output_nonlinears, self.output_layers))

In [26]:
# Smoke test: push one all-zero input through an untrained model and decode the result.
# Every head emits a 1-D vector, so softmax is taken over the last (only) dimension;
# passing dim explicitly avoids the deprecated implicit choice.
lstm = LSTM(Solver.INPUT_DIM,
            64,  # hidden layer size
            [Solver.OUTPUT_DONE_DIM,
             Solver.OUTPUT_VIEW_VERT_DIM,
             Solver.OUTPUT_VIEW_HORIZ_DIM,
             Solver.OUTPUT_WRITE_DIM],
            [nn.Sigmoid(),
             nn.Softmax(dim=-1),
             nn.Softmax(dim=-1),
             nn.Softmax(dim=-1)])

expression = Expression(Expression.add, [999]*11)
x = Solver.vectorize_input(None)
# Call the module itself rather than .forward() so nn.Module hooks run.
done_vec, view_vert_vec, view_horiz_vec, write_vec = lstm(x)
Solver.devectorize_output(done_vec, view_vert_vec, view_horiz_vec, write_vec)



(True, ('UP_EDGE', 'LEFT_RELATIVE'), '4')

In [8]:
class Model:
    """Base class for models; it holds no state of its own."""

    def __init__(self):
        """Nothing to initialise."""
    
class NNModel(Model):
    """Model that wraps a neural network object."""

    def __init__(self, model):
        """Store the wrapped network in ``self.model``."""
        super(NNModel, self).__init__()
        self.model = model

    def forward(self, x):
        """Forward pass for input ``x``; not implemented yet."""
        # TODO: the body was never written; raising explicitly keeps the cell
        # runnable and makes any accidental call fail loudly.
        raise NotImplementedError("NNModel.forward is not implemented yet")
        

False

In [30]:
# Probe: 0 and False compare equal and hash alike, so they share one dict
# entry (first key kept, last value wins); None is a separate key.
a = {
    0: 1,
    False: 2,
    None: 3,
}

In [31]:
# Yields 2, not 1: the later False key addressed the same entry as 0 and overwrote its value.
a[0]

2

In [32]:
# None is a valid dict key with its own entry; yields 3.
a[None]

3

In [33]:
# False finds the same entry as 0; yields 2.
a[False]

2

In [36]:
# argmax gives the position of the largest element (8, at index 2) as a 0-dim tensor.
torch.tensor((1,2,8,4,3)).argmax()

tensor(2)