In [2]:
import numpy as np
import torch
import torch.nn.functional as F
import sys

In [3]:
inputString = [2,45,30,55,10]
outputString = [45,30,55,10,1]

In [4]:
numFeatures = 100
vocabSize = 80

In [5]:
embeddings = []
for i in range(len(inputString)):
    x = np.random.randn(numFeatures,1)
    embeddings.append(x)

In [6]:
embeddings[0].shape

(100, 1)

In [7]:
len(embeddings)

5

In [10]:
def getOneHot(idx):
    one_hot = np.zeros((vocabSize,1))
    one_hot[idx] = 1
    return one_hot

In [11]:
print(getOneHot(2))

[[0.]
 [0.]
 [1.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]]


In [12]:
numUnits = 50
h0 = torch.tensor(np.zeros((numUnits,1)))
Wh = torch.tensor(np.random.uniform(0,1,(numUnits,numUnits)),requires_grad=True)
Wx = torch.tensor(np.random.uniform(0,1,(numUnits,numFeatures)),requires_grad=True)
Wy = torch.tensor(np.random.uniform(0,1,(vocabSize,numUnits)),requires_grad=True)

In [13]:
print(Wh.shape,Wx.shape,Wy.shape,h0.shape)

torch.Size([50, 50]) torch.Size([50, 100]) torch.Size([80, 50]) torch.Size([50, 1])


In [14]:
def stepForward(xt,Wx,Wh,Wy,prevMemory):
    x_frd = torch.matmul(Wx,torch.from_numpy(xt))
    h_frd = torch.matmul(Wh,prevMemory)
    ht = torch.tanh(x_frd+h_frd)
    yt_hat = F.softmax(torch.matmul(Wy,ht),dim=0)
    return ht,yt_hat

In [15]:
ht,yt_hat = stepForward(embeddings[0],Wx,Wh,Wy,h0)

In [16]:
ht.shape

torch.Size([50, 1])

In [17]:
yt_hat.shape

torch.Size([80, 1])

In [18]:
yt_hat.sum()

tensor(1., dtype=torch.float64, grad_fn=<SumBackward0>)

In [19]:
def fullForwardRNN(X,Wx,Wh,Wy,prevMemory):
    y_hat = []
    for t in range(len(X)):
        ht,yt_hat = stepForward(X[t],Wx,Wh,Wy,prevMemory)
        prevMemory = ht
        y_hat.append(yt_hat)
    return y_hat  

In [20]:
y_hat = fullForwardRNN(embeddings,Wx,Wh,Wy,h0)

In [21]:
len(y_hat)

5

In [22]:
y_hat[0].shape

torch.Size([80, 1])

In [26]:
def computeLoss(y,y_hat):
    loss = 0
    for yi,yi_hat in zip(y,y_hat):
        Li = -torch.log2(yi_hat[yi==1])
        loss += Li
    return loss/len(y)

In [27]:
y = []
for idx in outputString:
    y.append(getOneHot(idx))

In [28]:
print(computeLoss(y,y_hat))

tensor([8.6703], dtype=torch.float64, grad_fn=<DivBackward0>)


In [29]:
def updateParams(Wx,Wh,Wy,dWx,dWh,dWy,lr):
    with torch.no_grad():
        Wx -= lr*dWx
        Wh -= lr*dWh
        Wy -= lr*dWy
    return Wx,Wh,Wy

In [30]:
def trainRNN(X,y,Wx,Wh,Wy,prevMemory,lr,nepoch):
    losses = []
    for epoch in range(nepoch):
        y_hat = fullForwardRNN(X,Wx,Wh,Wy,prevMemory)
        loss = computeLoss(y,y_hat)
        loss.backward()
        losses.append(loss)
        print("Loss after epoch=%d: %f" %(epoch,loss))
        sys.stdout.flush()
        dWx = Wx.grad.data
        dWh = Wh.grad.data
        dWy = Wy.grad.data
        Wx,Wh,Wy = updateParams(Wx,Wh,Wy,dWx,dWh,dWy,lr)
        Wx.grad.data.zero_()
        Wh.grad.data.zero_()
        Wy.grad.data.zero_()
    return Wx,Wh,Wy,losses
        

In [31]:
Wx,Wh,Wy,losses = trainRNN(embeddings,y,Wx,Wh,Wy,h0,0.001,100)

Loss after epoch=0: 8.670349
Loss after epoch=1: 8.626533
Loss after epoch=2: 8.582421
Loss after epoch=3: 8.538054
Loss after epoch=4: 8.493490
Loss after epoch=5: 8.448793
Loss after epoch=6: 8.404036
Loss after epoch=7: 8.359290
Loss after epoch=8: 8.314623
Loss after epoch=9: 8.270099
Loss after epoch=10: 8.225770
Loss after epoch=11: 8.181687
Loss after epoch=12: 8.137891
Loss after epoch=13: 8.094423
Loss after epoch=14: 8.051320
Loss after epoch=15: 8.008617
Loss after epoch=16: 7.966346
Loss after epoch=17: 7.924532
Loss after epoch=18: 7.883191
Loss after epoch=19: 7.842329
Loss after epoch=20: 7.801939
Loss after epoch=21: 7.762001
Loss after epoch=22: 7.722486
Loss after epoch=23: 7.683352
Loss after epoch=24: 7.644553
Loss after epoch=25: 7.606033
Loss after epoch=26: 7.567735
Loss after epoch=27: 7.529593
Loss after epoch=28: 7.491539
Loss after epoch=29: 7.453496
Loss after epoch=30: 7.415383
Loss after epoch=31: 7.377109
Loss after epoch=32: 7.338572
Loss after epoch=33: