## Model

In [1]:
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from model import GeneralRNNConfig, GeneralRNN

In [2]:
# Run on the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [3]:
# Architecture of the hand-wired RNN.  Per the module printout that follows, this
# yields a hidden MLP Linear(3->4) -> ReLU -> Linear(4->4) -> ReLU -> Linear(4->2)
# and an output MLP consisting of a single Linear(2->1).
model_config = GeneralRNNConfig(
    # sizes of the per-step input, the recurrent state and the per-step output
    input_dim=1,
    hidden_dim=2,
    output_dim=1,
    # MLP that produces the next hidden state
    hidden_mlp_depth=3,
    hidden_mlp_width=4,
    # MLP that reads the output off the hidden state
    output_mlp_depth=1,
    output_mlp_width=1,
    activation=nn.ReLU,
)

model = GeneralRNN(model_config, device=device)

In [4]:
model

GeneralRNN(
  (hmlp): MLP(
    (mlp): Sequential(
      (0): Linear(in_features=3, out_features=4, bias=True)
      (1): ReLU()
      (2): Linear(in_features=4, out_features=4, bias=True)
      (3): ReLU()
      (4): Linear(in_features=4, out_features=2, bias=True)
    )
  )
  (ymlp): MLP(
    (mlp): Sequential(
      (0): Linear(in_features=2, out_features=1, bias=True)
    )
  )
)

In [5]:
# First hidden layer (3 inputs -> 4 units).  Row by row the pre-activations are
#   unit 0:  in[0] + in[2]
#   unit 1:  in[1]
#   unit 2: -in[1]
#   unit 3:  0 (unused)
# so after the ReLU, units 1 and 2 hold the positive and negative parts of in[1].
# NOTE(review): the recorded hidden traces are consistent with in = [h0, h1, x];
# confirm the concatenation order in model.py.
#
# copy_ under no_grad writes into the existing parameter: it keeps the device and
# dtype chosen when the model was built and raises on an incompatible shape,
# whereas assigning `.data` would swap in a fresh CPU tensor.
with torch.no_grad():
    model.hmlp.mlp[0].weight.copy_(torch.tensor([[1.,  0.,  1.],
                                                 [0.,  1.,  0.],
                                                 [0., -1.,  0.],
                                                 [0.,  0.,  0.]]))

In [6]:
# The first hidden layer uses no bias.  Written in place (instead of replacing
# `.data`) so the parameter keeps its device/dtype and its shape is checked.
with torch.no_grad():
    model.hmlp.mlp[0].bias.copy_(torch.tensor([0., 0., 0., 0.]))

In [7]:
# Second hidden layer (4 -> 4).  With a = activations of the first layer, the
# pre-activations are
#   unit 0: a[0]
#   unit 1: a[0] - a[1] + a[2]
#   unit 2: a[1]
#   unit 3: a[2]
# Units 0, 2 and 3 pass non-negative values straight through the following ReLU;
# unit 1 becomes relu(a[0] - (a[1] - a[2])).
#
# In-place copy_ under no_grad keeps the parameter's device/dtype and raises on an
# incompatible shape; assigning `.data` would replace it with a new CPU tensor.
with torch.no_grad():
    model.hmlp.mlp[2].weight.copy_(torch.tensor([[1.,  0., 0., 0.],
                                                 [1., -1., 1., 0.],
                                                 [0.,  1., 0., 0.],
                                                 [0.,  0., 1., 0.]]))

In [8]:
# The second hidden layer uses no bias.  Written in place (instead of replacing
# `.data`) so the parameter keeps its device/dtype and its shape is checked.
with torch.no_grad():
    model.hmlp.mlp[2].bias.copy_(torch.tensor([0., 0., 0., 0.]))

In [9]:
# Final hidden layer (4 -> 2, no activation afterwards).  With b = activations of
# the second layer, the new hidden state is
#   h[0] = b[0]
#   h[1] = b[1] + b[2] - b[3]
# Combined with the two preceding layers' weights, h[1] has the form
# relu(s - m) + m = max(s, m), i.e. a running maximum built from ReLUs only.
#
# In-place copy_ under no_grad keeps the parameter's device/dtype and raises on an
# incompatible shape; assigning `.data` would replace it with a new CPU tensor.
with torch.no_grad():
    model.hmlp.mlp[4].weight.copy_(torch.tensor([[1., 0., 0.,  0.],
                                                 [0., 1., 1., -1.]]))

In [10]:
# The final hidden layer uses no bias.  Written in place (instead of replacing
# `.data`) so the parameter keeps its device/dtype and its shape is checked.
with torch.no_grad():
    model.hmlp.mlp[4].bias.copy_(torch.tensor([0., 0.]))

In [11]:
# Output layer (2 -> 1): the weights [0, 1] select the second hidden component as
# the per-step output.  Written in place so the parameter keeps its device/dtype
# and its shape is checked, instead of being replaced by a new CPU tensor.
with torch.no_grad():
    model.ymlp.mlp[0].weight.copy_(torch.tensor([[0., 1.]]))

In [12]:
# The output layer uses no bias.  Written in place (instead of replacing `.data`)
# so the parameter keeps its device/dtype and its shape is checked.
with torch.no_grad():
    model.ymlp.mlp[0].bias.copy_(torch.tensor([0.]))

## Maximum Subarray

In [13]:
def maxSubArray(nums):
    """Run the hand-wired RNN over ``nums`` and print its maximum-subarray result.

    Args:
        nums: 1-D sequence of numbers (list, NumPy array or tensor).

    Prints:
        ``Hiddens:`` the hidden states returned by ``model.forward_sequence`` for
        the first (only) batch entry, and ``Answer:`` the model output at the
        last time step.

    Returns:
        None. The result is only printed.

    Note:
        In the recorded runs an all-negative input yields 0.0, i.e. the network
        treats the empty subarray (sum 0) as admissible.
    """
    # Build the input with the model's own dtype and device instead of relying on
    # integer inputs being promoted (and CPU tensors being accepted) inside the model.
    reference = next(model.parameters())
    sequence = torch.as_tensor(nums, dtype=reference.dtype, device=reference.device)
    # (n,) -> (1, n, 1); presumably (batch, time, feature) -- confirm in model.py.
    sequence = sequence[None, :, None]

    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        outs, hiddens = model.forward_sequence(sequence)
    print("Hiddens:", [h[0].tolist() for h in hiddens])
    print("Answer:", outs[0][-1][0].item())

## Testing

In [14]:
# Single positive element: the best subarray is the element itself.
maxSubArray(nums=[5])

Hiddens: [[0.0, 0.0], [5.0, 5.0]]
Answer: 5.0


In [15]:
# Float inputs; the best sum (3+4+5+6-1+5 = 22) is reached before the trailing -10.
maxSubArray(nums=[3., 4., 5., 6., -1., 5., -10.])

Hiddens: [[0.0, 0.0], [3.0, 3.0], [7.0, 7.0], [12.0, 12.0], [18.0, 18.0], [17.0, 18.0], [22.0, 22.0], [12.0, 22.0]]
Answer: 22.0


In [16]:
# All elements positive: the whole array is the best subarray.
maxSubArray(nums=[1, 2, 3, 4])

Hiddens: [[0.0, 0.0], [1.0, 1.0], [3.0, 3.0], [6.0, 6.0], [10.0, 10.0]]
Answer: 10.0


In [17]:
# A small dip inside the best run (2-1+2+3+4 = 10), followed by a final drop.
maxSubArray(nums=[2, -1, 2, 3, 4, -5])

Hiddens: [[0.0, 0.0], [2.0, 2.0], [1.0, 2.0], [3.0, 3.0], [6.0, 6.0], [10.0, 10.0], [5.0, 10.0]]
Answer: 10.0


In [18]:
# One negative element in the middle; the whole array (sum 23) is still the best.
maxSubArray(nums=[5, 4, -1, 7, 8])

Hiddens: [[0.0, 0.0], [5.0, 5.0], [9.0, 9.0], [8.0, 9.0], [15.0, 15.0], [23.0, 23.0]]
Answer: 23.0


In [19]:
# Mixed signs starting with a negative; the best subarray is [4, -1, 2, 1] = 6.
maxSubArray(nums=[-2, 1, -3, 4, -1, 2, 1, -5, 4])

Hiddens: [[0.0, 0.0], [0.0, 0.0], [1.0, 1.0], [0.0, 1.0], [4.0, 4.0], [3.0, 4.0], [5.0, 5.0], [6.0, 6.0], [1.0, 6.0], [5.0, 6.0]]
Answer: 6.0


In [20]:
# All elements negative: the network reports 0.0 (the empty subarray), not the
# largest single element.
maxSubArray(nums=[-8, -3, -6, -2, -5, -4])

Hiddens: [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
Answer: 0.0


In [21]:
# Longer sequence with several resets of the running sum; best is 18+20-7+12 = 43.
maxSubArray(
    nums=[13, -3, -25, 20, -3, -16, -23, 18, 20, -7, 12, -5, -22, 15, -4, 7]
)

Hiddens: [[0.0, 0.0], [13.0, 13.0], [10.0, 13.0], [0.0, 13.0], [20.0, 20.0], [17.0, 20.0], [1.0, 20.0], [0.0, 20.0], [18.0, 20.0], [38.0, 38.0], [31.0, 38.0], [43.0, 43.0], [38.0, 43.0], [16.0, 43.0], [31.0, 43.0], [27.0, 43.0], [34.0, 43.0]]
Answer: 43.0


In [22]:
# Only zeros and negatives: the answer stays at 0.
maxSubArray(nums=[0, -1, 0, -2, 0])

Hiddens: [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
Answer: 0.0


In [23]:
# Decreasing sequence ending below zero; the best sum (6) is reached before the -1.
maxSubArray(nums=[3, 2, 1, 0, -1])

Hiddens: [[0.0, 0.0], [3.0, 3.0], [5.0, 5.0], [6.0, 6.0], [6.0, 6.0], [5.0, 6.0]]
Answer: 6.0


In [24]:
# Alternating signs with growing magnitude; the best subarray is the single element 5.
maxSubArray(nums=[1, -2, 3, -4, 5, -6])

Hiddens: [[0.0, 0.0], [1.0, 1.0], [0.0, 1.0], [3.0, 3.0], [0.0, 3.0], [5.0, 5.0], [0.0, 5.0]]
Answer: 5.0
