## Model

In [1]:
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from model import GeneralRNNConfig, GeneralRNN

In [2]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Smallest configuration that can hold a running maximum: scalar input,
# scalar hidden state, scalar output.
model_config = GeneralRNNConfig(
    input_dim=1,
    output_dim=1,
    hidden_dim=1,
    # Hidden-state MLP: two linear layers with a width-3 ReLU layer in between
    # (see the printed module repr: Linear(2, 3) -> ReLU -> Linear(3, 1)).
    hidden_mlp_depth=2,
    hidden_mlp_width=3,
    # Output MLP: a single Linear(1, 1) readout.
    output_mlp_depth=1,
    output_mlp_width=1,
    activation=nn.ReLU,
)

# Build the network on the selected device; its weights are set by hand in
# the following cells.
model = GeneralRNN(model_config, device=device)

In [4]:
model

GeneralRNN(
  (hmlp): MLP(
    (mlp): Sequential(
      (0): Linear(in_features=2, out_features=3, bias=True)
      (1): ReLU()
      (2): Linear(in_features=3, out_features=1, bias=True)
    )
  )
  (ymlp): MLP(
    (mlp): Sequential(
      (0): Linear(in_features=1, out_features=1, bias=True)
    )
  )
)

In [5]:
# First hidden-MLP layer (3x2). Writing (a, b) for the layer's two input
# features (presumably the current input and the previous hidden state; the
# order is decided in model.py and does not matter for a maximum), the three
# rows produce a - b, b and -b before the ReLU.
#
# copy_ under no_grad overwrites the values in place, so the parameter keeps
# the device and dtype it already has. Assigning a freshly built CPU tensor to
# `.data` would silently move the parameter to the CPU when CUDA is in use.
with torch.no_grad():
    model.hmlp.mlp[0].weight.copy_(torch.tensor([[1., -1.],
                                                 [0.,  1.],
                                                 [0., -1.]]))

In [6]:
# First hidden-MLP layer has no offset: all three biases are zero.
# In-place copy_ keeps the parameter on its current device (a new CPU tensor
# assigned to `.data` would not).
with torch.no_grad():
    model.hmlp.mlp[0].bias.copy_(torch.tensor([0., 0., 0.]))

In [7]:
# Second hidden-MLP layer (1x3) sums the three ReLU units as u0 + u1 - u2.
# With the first-layer weights this is
#   relu(a - b) + relu(b) - relu(-b) = relu(a - b) + b = max(a, b),
# i.e. the new hidden state is the larger of the layer's two inputs.
# In-place copy_ keeps the parameter on its current device (a new CPU tensor
# assigned to `.data` would not).
with torch.no_grad():
    model.hmlp.mlp[2].weight.copy_(torch.tensor([[1., 1., -1.]]))

In [8]:
# Second hidden-MLP layer has zero bias, so the hidden update adds no offset.
# In-place copy_ keeps the parameter on its current device (a new CPU tensor
# assigned to `.data` would not).
with torch.no_grad():
    model.hmlp.mlp[2].bias.copy_(torch.tensor([0.]))

In [9]:
# Output layer is a 1x1 linear map with weight 1, i.e. it passes its single
# input feature through unchanged (presumably the hidden state; the recorded
# runs show the answer equal to the last hidden value).
# In-place copy_ keeps the parameter on its current device (a new CPU tensor
# assigned to `.data` would not).
with torch.no_grad():
    model.ymlp.mlp[0].weight.copy_(torch.tensor([[1.]]))

In [10]:
# Output layer has zero bias, so the readout adds no offset.
# In-place copy_ keeps the parameter on its current device (a new CPU tensor
# assigned to `.data` would not).
with torch.no_grad():
    model.ymlp.mlp[0].bias.copy_(torch.tensor([0.]))

## Find Maximum

In [11]:
def maximum(nums):
    """Run ``nums`` through the hand-wired RNN and print its running maximum.

    The numbers are shaped into one batch of scalar time steps,
    ``(1, len(nums), 1)``, and handed to ``model.forward_sequence``. The
    hidden value recorded at every step and the last output are printed;
    nothing is returned.

    Known limitation: the first recorded hidden value is 0 and every step
    keeps the larger of the hidden value and the new input, so the printed
    answer is ``max(0, max(nums))``. A sequence containing only negative
    numbers therefore reports 0.

    Args:
        nums: Non-empty sequence of numbers (ints or floats).

    Raises:
        ValueError: If ``nums`` is empty.
    """
    # Explicit float32 so integer lists do not depend on implicit int->float
    # promotion inside the model.
    seq = torch.as_tensor(nums, dtype=torch.float32)
    if seq.numel() == 0:
        raise ValueError("maximum() arg is an empty sequence")

    # Put the input where the model's weights live so CPU and GPU runs both work.
    param_device = next(model.parameters()).device
    batch = seq.to(param_device).unsqueeze(1).unsqueeze(0)  # (batch=1, time, feature=1)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        outs, hiddens = model.forward_sequence(batch)

    print("Hiddens:", [h[0].tolist() for h in hiddens])
    print("Answer:", outs[0][-1][0].item())

## Testing

In [12]:
maximum([5])

Hiddens: [[0.0], [5.0]]
Answer: 5.0


In [13]:
maximum([3., 4., 5., 6., -1., 5., -10.])

Hiddens: [[0.0], [3.0], [4.0], [5.0], [6.0], [6.0], [6.0], [6.0]]
Answer: 6.0


In [14]:
maximum([1, 2, 3, 4])

Hiddens: [[0.0], [1.0], [2.0], [3.0], [4.0]]
Answer: 4.0


In [15]:
maximum([2, -1, 2, 3, 4, -5])

Hiddens: [[0.0], [2.0], [2.0], [2.0], [3.0], [4.0], [4.0]]
Answer: 4.0


In [16]:
maximum([5, 4, -1, 7, 8])

Hiddens: [[0.0], [5.0], [5.0], [5.0], [7.0], [8.0]]
Answer: 8.0


In [17]:
maximum([-2, 1, -3, 4, -1, 2, 1, -5, 4])

Hiddens: [[0.0], [0.0], [1.0], [1.0], [4.0], [4.0], [4.0], [4.0], [4.0], [4.0]]
Answer: 4.0


In [18]:
# Known limitation: the hidden state starts at 0 and each step keeps the
# larger of the hidden state and the input, so the network computes
# max(0, max(nums)). For an all-negative input it reports 0 (here the true
# maximum is -2).
maximum([-8, -3, -6, -2, -5, -4])

Hiddens: [[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0]]
Answer: 0.0


In [19]:
maximum([13, -3, -25, 20, -3, -16, -23, 18, 20, -7, 12, -5, -22, 15, -4, 7])

Hiddens: [[0.0], [13.0], [13.0], [13.0], [20.0], [20.0], [20.0], [20.0], [20.0], [20.0], [20.0], [20.0], [20.0], [20.0], [20.0], [20.0], [20.0]]
Answer: 20.0


In [20]:
maximum([0, -1, 0, -2, 0])

Hiddens: [[0.0], [0.0], [0.0], [0.0], [0.0], [0.0]]
Answer: 0.0


In [21]:
maximum([3, 2, 1, 0, -1])

Hiddens: [[0.0], [3.0], [3.0], [3.0], [3.0], [3.0]]
Answer: 3.0


In [22]:
maximum([1, -2, 3, -4, 5, -6])

Hiddens: [[0.0], [1.0], [1.0], [3.0], [3.0], [5.0], [5.0]]
Answer: 5.0
