## Model

In [1]:
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from model import GeneralRNNConfig, GeneralRNN

In [2]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [3]:
# A deliberately tiny RNN: its weights are overwritten by hand in the cells
# below, so the sizes chosen here must match those hand-written tensors.
model_config = GeneralRNNConfig(
    input_dim=1,            # one scalar per step (+1 for "(", -1 for ")")
    hidden_dim=2,
    output_dim=1,
    # hidden-state network: two Linear layers with 4 units in between
    hidden_mlp_depth=2,
    hidden_mlp_width=4,
    # output network: two Linear layers with 2 units in between
    output_mlp_depth=2,
    output_mlp_width=2,
    activation=nn.ReLU,     # passed as the class, not an instance
)

model = GeneralRNN(model_config, device=device)

In [4]:
# Display the module tree to confirm the layer shapes that the hand-set
# weight tensors below must match.
model

GeneralRNN(
  (hmlp): MLP(
    (mlp): Sequential(
      (0): Linear(in_features=3, out_features=4, bias=True)
      (1): ReLU()
      (2): Linear(in_features=4, out_features=2, bias=True)
    )
  )
  (ymlp): MLP(
    (mlp): Sequential(
      (0): Linear(in_features=2, out_features=2, bias=True)
      (1): ReLU()
      (2): Linear(in_features=2, out_features=1, bias=True)
    )
  )
)

In [5]:
# hmlp first layer (4 units). The input columns appear to be ordered
# (h0, h1, x), which is consistent with the printed hidden traces below
# (TODO confirm against GeneralRNN). Under that ordering the rows compute
# h0 + x, -(h0 + x), h1 and -h1 before the ReLU.
# Write in place under no_grad rather than rebinding `.data`, so the parameter
# keeps its device (GPU when available) and dtype.
with torch.no_grad():
    model.hmlp.mlp[0].weight.copy_(torch.tensor([[ 1.,  0.,  1.],
                                                 [-1.,  0., -1.],
                                                 [ 0.,  1.,  0.],
                                                 [ 0., -1.,  0.]]))

In [6]:
# hmlp first-layer bias: all zeros. Zero in place under no_grad so the
# parameter keeps its device and dtype (rebinding `.data` would not).
with torch.no_grad():
    model.hmlp.mlp[0].bias.zero_()

In [7]:
# hmlp output layer. With u = ReLU outputs of the first layer:
#   new h0 = u0 - u1
#   new h1 = -u1 + u2 - u3
# Write in place under no_grad so the parameter keeps its device and dtype.
with torch.no_grad():
    model.hmlp.mlp[2].weight.copy_(torch.tensor([[ 1.,  -1.,  0.,  0.],
                                                 [ 0.,  -1.,  1., -1.]]))

In [8]:
# hmlp output-layer bias: all zeros (in place, so device/dtype are preserved).
with torch.no_grad():
    model.hmlp.mlp[2].bias.zero_()

In [9]:
# ymlp first layer: passes h0 through and negates h1 (before the ReLU).
# Write in place under no_grad so the parameter keeps its device and dtype.
with torch.no_grad():
    model.ymlp.mlp[0].weight.copy_(torch.tensor([[ 1.,  0.],
                                                 [ 0., -1.]]))

In [10]:
# ymlp first-layer bias: all zeros (in place, so device/dtype are preserved).
with torch.no_grad():
    model.ymlp.mlp[0].bias.zero_()

In [11]:
# ymlp output layer: sums the two ReLU units into the single output.
# Write in place under no_grad so the parameter keeps its device and dtype.
with torch.no_grad():
    model.ymlp.mlp[2].weight.copy_(torch.tensor([[ 1.,  1.]]))

In [12]:
# ymlp output-layer bias: zero (in place, so device/dtype are preserved).
with torch.no_grad():
    model.ymlp.mlp[2].bias.zero_()

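## Balanced Parentheses

Each `(` is fed to the RNN as $+1$ and each `)` as $-1$; any other character is skipped.

With the weights set above, the two hidden units behave as follows:
- The first hidden unit $h_0$ tracks the running depth (opens minus closes).
- The second unit $h_1$ drops below zero once the depth has ever gone negative, and it never increases afterwards.

The output is $\mathrm{ReLU}(h_0) + \mathrm{ReLU}(-h_1)$, so a string is reported balanced exactly when the final output is $0$.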

In [13]:
def isBalanced(paren, rnn=None, verbose=True):
    """Classify ``paren`` as balanced or unbalanced using the hand-wired RNN.

    ``(`` is encoded as +1.0 and ``)`` as -1.0; every other character is
    ignored. The encoded string is passed to ``rnn.forward_sequence`` as one
    batch of shape (1, seq_len, 1). The string is reported balanced when the
    output at the last timestep is exactly 0.

    Args:
        paren: String to check.
        rnn: Model exposing ``forward_sequence`` and ``parameters()``.
            Defaults to the notebook-global ``model``.
        verbose: If True (default), print the per-step hidden states and the
            answer.

    Returns:
        bool: True if the RNN classifies ``paren`` as balanced.
    """
    if rnn is None:
        rnn = model  # notebook-global built in the "Model" section

    steps = [1. if ch == "(" else -1. for ch in paren if ch in ("(", ")")]
    if not steps:
        # No parentheses at all: trivially balanced. Skip the RNN, because
        # indexing the last output step of a zero-length sequence would fail.
        if verbose:
            print("Answer:", True)
        return True

    x = torch.tensor(steps).view(1, -1, 1)  # (batch=1, seq_len, input_dim=1)
    # Put the input on the same device as the weights so a GPU-resident
    # model does not raise a device-mismatch error.
    first_param = next(rnn.parameters(), None)
    if first_param is not None:
        x = x.to(first_param.device)

    outs, hiddens = rnn.forward_sequence(x)
    if verbose:
        print("Hiddens:", [h[0].tolist() for h in hiddens])
    # Index as (batch 0, last timestep, output unit 0).
    balanced = outs[0, -1, 0].item() == 0
    if verbose:
        print("Answer:", balanced)
    return balanced

## Testing

In [14]:
isBalanced("()()()()")  # expected: balanced (four sequential pairs)

Hiddens: [[0.0, 0.0], [1.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 0.0]]
Answer: True


In [15]:
isBalanced("()")  # expected: balanced (single pair)

Hiddens: [[0.0, 0.0], [1.0, 0.0], [0.0, 0.0]]
Answer: True


In [16]:
isBalanced("(())")  # expected: balanced (nesting depth 2)

Hiddens: [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [1.0, 0.0], [0.0, 0.0]]
Answer: True


In [17]:
isBalanced("(())()")  # expected: balanced (nested group followed by a pair)

Hiddens: [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [1.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 0.0]]
Answer: True


In [18]:
isBalanced("(()())()()")  # expected: balanced

Hiddens: [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [1.0, 0.0], [2.0, 0.0], [1.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 0.0]]
Answer: True


In [19]:
isBalanced("(()(())()())")  # expected: balanced (max depth 3)

Hiddens: [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0], [2.0, 0.0], [1.0, 0.0], [2.0, 0.0], [1.0, 0.0], [2.0, 0.0], [1.0, 0.0], [0.0, 0.0]]
Answer: True


In [20]:
isBalanced("(")  # expected: not balanced (unclosed "(")

Hiddens: [[0.0, 0.0], [1.0, 0.0]]
Answer: False


In [21]:
isBalanced(")")  # expected: not balanced (")" with no matching "(")

Hiddens: [[0.0, 0.0], [-1.0, -1.0]]
Answer: False


In [22]:
isBalanced("(()")  # expected: not balanced (one "(" left open)

Hiddens: [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [1.0, 0.0]]
Answer: False


In [23]:
isBalanced("()()(())())(")  # expected: not balanced (depth dips below zero before recovering to 0)

Hiddens: [[0.0, 0.0], [1.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [1.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 0.0], [-1.0, -1.0], [0.0, -1.0]]
Answer: False


In [24]:
isBalanced(")(")  # expected: not balanced (close before open, even though counts match)

Hiddens: [[0.0, 0.0], [-1.0, -1.0], [0.0, -1.0]]
Answer: False


In [25]:
isBalanced("()((())(()((()())())(())(((()()())())())))))")  # expected: not balanced (two surplus ")" at the end)

Hiddens: [[0.0, 0.0], [1.0, 0.0], [0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0], [2.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0], [2.0, 0.0], [3.0, 0.0], [4.0, 0.0], [5.0, 0.0], [4.0, 0.0], [5.0, 0.0], [4.0, 0.0], [3.0, 0.0], [4.0, 0.0], [3.0, 0.0], [2.0, 0.0], [3.0, 0.0], [4.0, 0.0], [3.0, 0.0], [2.0, 0.0], [3.0, 0.0], [4.0, 0.0], [5.0, 0.0], [6.0, 0.0], [5.0, 0.0], [6.0, 0.0], [5.0, 0.0], [6.0, 0.0], [5.0, 0.0], [4.0, 0.0], [5.0, 0.0], [4.0, 0.0], [3.0, 0.0], [4.0, 0.0], [3.0, 0.0], [2.0, 0.0], [1.0, 0.0], [0.0, 0.0], [-1.0, -1.0], [-2.0, -3.0]]
Answer: False
