In [337]:
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable

token = {
    '\0' : 0,
    '.' : 1, 
    ',' : 2, 
    '[' : 3, 
    ']' : 4, 
    '<' : 5, 
    '>' : 6, 
    '+' : 7, 
    '-' : 8,
    "START" : 9
    }

char = {
    0 : '\0',
    1 : '.',
    2 : ',', 
    3 : '[', 
    4 : ']',  
    5 : '<',  
    6 : '>', 
    7 : '+',  
    8 : '-'
    # no START on purpose
    }

class BFgen(nn.Module):
    def __init__(self, input_size, embedding_dim, hidden_size, output_size, n_layers=2, batch_size=1):
        super(BFgen, self).__init__()
        self.input_size = input_size
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.batch_size = batch_size
        
        self.encoder = nn.Embedding(input_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_size, n_layers, batch_first=True)
        self.decoder = nn.Linear(hidden_size, output_size)
        self.softmax = nn.functional.softmax
        

    """
    forward
    Takes input_token and hidden memory state <- input to recursive layer
    returns output token and changed hidden memory state.
    """
    def forward(self, input_token, hidden):
        embeds = self.encoder(input_token)
        output, hidden = self.lstm(embeds, hidden)
        output = self.decoder(output.view(self.batch_size, -1))
        output = self.softmax(output) # in paper its multinomial distribution
        return output, hidden
    
    def init_hidden_zero(self):
        self.hidden = (Variable(torch.zeros(self.n_layers, self.batch_size, self.hidden_size)),
                      Variable(torch.zeros(self.n_layers, self.batch_size, self.hidden_size)),)
    
    def init_hidden_normal(self):
        means = torch.zeros(self.n_layers, self.batch_size, self.hidden_size)
        std = torch.Tensor([0.001]*self.hidden_size*self.n_layers*self.batch_size).unsqueeze(0)
        self.hidden = (Variable(torch.normal(means, std)), Variable(torch.normal(means, std)))

    def evaluate(self, predict_len=100):
        input_token = token["START"]
        hidden = self.init_hidden_zero
        prediction = ""

        for i in range(predict_len):
            output_token, hidden = self.forward(input_token, hidden)
            input_token = output_token

            prediction += char[output_token]
            if output_token == '\0':
                break

        return prediction



In [338]:
embedding_size = 10
hidden_size = 35
output_size = 9
n_layers = 2

token_num = len(token.keys())

In [339]:
batch_num = 64
model = BFgen(10, embedding_size, hidden_size, output_size, n_layers, batch_num)

In [340]:
model.init_hidden_normal()

In [341]:
print(model.hidden)

(Variable containing:
(0 ,.,.) = 
 -1.4069e-03 -1.4254e-04  1.2972e-03  ...   9.3532e-04 -6.7219e-04  1.7171e-05
  1.5643e-03 -2.7541e-04 -3.2768e-04  ...  -1.7917e-03 -2.1055e-04 -6.8826e-04
 -1.0524e-03 -8.5756e-04 -1.3397e-03  ...   9.3815e-04 -9.7098e-04  1.3971e-03
                 ...                   ⋱                   ...                
  7.1501e-05  6.9305e-04  1.3175e-03  ...  -2.6004e-04  2.2829e-05  1.0682e-03
  4.2999e-04 -1.1089e-03  5.8171e-05  ...  -1.5347e-03  5.3940e-04 -1.1950e-03
  8.6621e-04  2.3354e-03  1.4715e-03  ...  -5.4765e-04  4.1961e-04 -2.3232e-04

(1 ,.,.) = 
 -1.0675e-03  3.6281e-04 -8.9925e-04  ...  -1.9488e-03  1.3730e-04 -1.3119e-03
  4.4776e-04  1.0271e-03 -2.0652e-03  ...  -1.0729e-04  2.7098e-04 -8.4759e-05
  4.4063e-05  4.1087e-05 -1.5224e-03  ...  -1.5179e-03  1.8758e-03 -6.6197e-04
                 ...                   ⋱                   ...                
  3.5258e-05  1.4627e-03  1.5303e-03  ...  -1.2200e-03 -8.2137e-04 -4.4383e-04
 -7.5

In [342]:
def token_to_tensor(input_token):
    tensor = torch.zeros(1, token_num).long()
    tensor[0][token[input_token]] = 1
    return tensor

In [343]:
input_sample = token_to_tensor(">")
batched_input = torch.zeros((1, 64, token_num)).long()
batched_input = batched_input + input_sample

batch_input_sizes = [token_num] * batch_num
print batched_input


(0 ,.,.) = 
   0   0   0   0   0   0   1   0   0   0
   0   0   0   0   0   0   1   0   0   0
   0   0   0   0   0   0   1   0   0   0
   0   0   0   0   0   0   1   0   0   0
   0   0   0   0   0   0   1   0   0   0
   0   0   0   0   0   0   1   0   0   0
   0   0   0   0   0   0   1   0   0   0
   0   0   0   0   0   0   1   0   0   0
   0   0   0   0   0   0   1   0   0   0
   0   0   0   0   0   0   1   0   0   0
   0   0   0   0   0   0   1   0   0   0
   0   0   0   0   0   0   1   0   0   0
   0   0   0   0   0   0   1   0   0   0
   0   0   0   0   0   0   1   0   0   0
   0   0   0   0   0   0   1   0   0   0
   0   0   0   0   0   0   1   0   0   0
   0   0   0   0   0   0   1   0   0   0
   0   0   0   0   0   0   1   0   0   0
   0   0   0   0   0   0   1   0   0   0
   0   0   0   0   0   0   1   0   0   0
   0   0   0   0   0   0   1   0   0   0
   0   0   0   0   0   0   1   0   0   0
   0   0   0   0   0   0   1   0   0   0
   0   0   0   0   0   0   1   0   0   0
   

In [344]:

batched_input = Variable(batched_input.view(batch_num, -1))#= nn.utils.rnn.pack_padded_sequence(Variable(batched_input), batch_input_sizes)

In [345]:
print batched_input


Variable containing:
    0     0     0     0     0     0     1     0     0     0
    0     0     0     0     0     0     1     0     0     0
    0     0     0     0     0     0     1     0     0     0
    0     0     0     0     0     0     1     0     0     0
    0     0     0     0     0     0     1     0     0     0
    0     0     0     0     0     0     1     0     0     0
    0     0     0     0     0     0     1     0     0     0
    0     0     0     0     0     0     1     0     0     0
    0     0     0     0     0     0     1     0     0     0
    0     0     0     0     0     0     1     0     0     0
    0     0     0     0     0     0     1     0     0     0
    0     0     0     0     0     0     1     0     0     0
    0     0     0     0     0     0     1     0     0     0
    0     0     0     0     0     0     1     0     0     0
    0     0     0     0     0     0     1     0     0     0
    0     0     0     0     0     0     1     0     0     0
    0     0     0  

In [346]:
emb = nn.Embedding(token_num, embedding_size)

In [347]:
embeds = emb(batched_input)
print embeds

Variable containing:
(0 ,.,.) = 
 -0.3275 -1.7213  0.1308  ...  -1.2300 -0.8907 -1.8747
 -0.3275 -1.7213  0.1308  ...  -1.2300 -0.8907 -1.8747
 -0.3275 -1.7213  0.1308  ...  -1.2300 -0.8907 -1.8747
           ...             ⋱             ...          
 -0.3275 -1.7213  0.1308  ...  -1.2300 -0.8907 -1.8747
 -0.3275 -1.7213  0.1308  ...  -1.2300 -0.8907 -1.8747
 -0.3275 -1.7213  0.1308  ...  -1.2300 -0.8907 -1.8747

(1 ,.,.) = 
 -0.3275 -1.7213  0.1308  ...  -1.2300 -0.8907 -1.8747
 -0.3275 -1.7213  0.1308  ...  -1.2300 -0.8907 -1.8747
 -0.3275 -1.7213  0.1308  ...  -1.2300 -0.8907 -1.8747
           ...             ⋱             ...          
 -0.3275 -1.7213  0.1308  ...  -1.2300 -0.8907 -1.8747
 -0.3275 -1.7213  0.1308  ...  -1.2300 -0.8907 -1.8747
 -0.3275 -1.7213  0.1308  ...  -1.2300 -0.8907 -1.8747

(2 ,.,.) = 
 -0.3275 -1.7213  0.1308  ...  -1.2300 -0.8907 -1.8747
 -0.3275 -1.7213  0.1308  ...  -1.2300 -0.8907 -1.8747
 -0.3275 -1.7213  0.1308  ...  -1.2300 -0.8907 -1.8747
      

In [348]:
out, hid = model.lstm(embeds, model.hidden)

In [349]:
print out

Variable containing:
(0 ,.,.) = 
 -2.2180e-02  4.2906e-03 -3.3553e-02  ...  -3.0389e-02  1.3468e-02 -3.5761e-02
 -4.0556e-02  5.2905e-03 -5.5761e-02  ...  -3.6525e-02  1.4040e-02 -6.2871e-02
 -5.3736e-02  4.8165e-03 -6.9306e-02  ...  -3.6038e-02  9.5348e-03 -8.3626e-02
                 ...                   ⋱                   ...                
 -4.4544e-02 -5.8202e-03 -8.8912e-02  ...  -7.2345e-02  1.1444e-03 -9.7088e-02
 -5.2288e-02 -4.0085e-03 -8.9549e-02  ...  -6.3323e-02 -8.7049e-04 -1.0026e-01
 -6.0355e-02 -2.9574e-03 -8.9246e-02  ...  -5.4888e-02 -2.5100e-03 -1.0559e-01

(1 ,.,.) = 
 -2.1838e-02  4.2149e-03 -3.3200e-02  ...  -3.1019e-02  1.3527e-02 -3.5435e-02
 -4.0280e-02  5.2634e-03 -5.5643e-02  ...  -3.6972e-02  1.4114e-02 -6.2785e-02
 -5.3522e-02  4.8131e-03 -6.9254e-02  ...  -3.6359e-02  9.5688e-03 -8.3659e-02
                 ...                   ⋱                   ...                
 -4.4503e-02 -5.8148e-03 -8.8894e-02  ...  -7.2393e-02  1.1349e-03 -9.7137e-02
 -5.22

In [350]:
out = model.decoder(out).view(10,64,9)

In [351]:
print out

Variable containing:
(0 ,.,.) = 
  0.0744  0.0885  0.1237  ...   0.0908  0.1713 -0.1305
  0.1045  0.1041  0.1243  ...   0.0788  0.1893 -0.1247
  0.1233  0.1156  0.1247  ...   0.0700  0.2022 -0.1202
           ...             ⋱             ...          
  0.1042  0.1041  0.1243  ...   0.0789  0.1893 -0.1249
  0.1232  0.1155  0.1247  ...   0.0701  0.2021 -0.1204
  0.1350  0.1236  0.1248  ...   0.0645  0.2109 -0.1176

(1 ,.,.) = 
  0.1425  0.1293  0.1247  ...   0.0611  0.2167 -0.1159
  0.1471  0.1334  0.1246  ...   0.0591  0.2205 -0.1150
  0.1400  0.1233  0.1196  ...   0.0732  0.1999 -0.1404
           ...             ⋱             ...          
  0.1471  0.1333  0.1246  ...   0.0591  0.2205 -0.1150
  0.1400  0.1233  0.1196  ...   0.0732  0.1999 -0.1404
  0.1379  0.1241  0.1202  ...   0.0712  0.2030 -0.1415

(2 ,.,.) = 
  0.1400  0.1264  0.1209  ...   0.0678  0.2071 -0.1378
  0.1430  0.1291  0.1217  ...   0.0645  0.2112 -0.1325
  0.0743  0.0883  0.1238  ...   0.0908  0.1710 -0.1310
      

In [354]:
res = model.softmax(out[-1])
print res

Variable containing:
 0.1175  0.1156  0.1151  0.1207  0.0947  0.1129  0.1099  0.1248  0.0888
 0.1172  0.1156  0.1152  0.1210  0.0949  0.1127  0.1097  0.1251  0.0887
 0.1173  0.1157  0.1151  0.1210  0.0948  0.1125  0.1091  0.1255  0.0889
 0.1175  0.1159  0.1150  0.1209  0.0946  0.1125  0.1086  0.1258  0.0892
 0.1115  0.1130  0.1172  0.1210  0.0990  0.1115  0.1133  0.1228  0.0908
 0.1140  0.1139  0.1163  0.1209  0.0971  0.1121  0.1111  0.1240  0.0906
 0.1156  0.1147  0.1157  0.1206  0.0958  0.1124  0.1096  0.1250  0.0906
 0.1166  0.1152  0.1154  0.1204  0.0951  0.1125  0.1086  0.1257  0.0905
 0.1172  0.1156  0.1151  0.1202  0.0946  0.1125  0.1080  0.1262  0.0905
 0.1176  0.1160  0.1149  0.1201  0.0943  0.1125  0.1077  0.1265  0.0905
 0.1175  0.1156  0.1151  0.1207  0.0947  0.1129  0.1099  0.1248  0.0888
 0.1172  0.1156  0.1152  0.1210  0.0949  0.1126  0.1097  0.1251  0.0886
 0.1173  0.1157  0.1151  0.1210  0.0948  0.1125  0.1092  0.1255  0.0889
 0.1175  0.1159  0.1150  0.1209  0.0946  0.

  """Entry point for launching an IPython kernel.


In [367]:
def pred(sample):
    return char[np.argsort(sample)[-1]]

In [381]:
s = res.data[0].numpy()
s.shape

(9,)

In [382]:
pred(s)

'+'

In [None]:
np.apply_along_axis(pred, axis = 0, )