In [158]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize 

In [226]:
class RNN:
    def __init__ (self, numHidden, numInput, numOutput):
        self.numHidden = numHidden
        self.numInput = numInput
        self.U = np.random.randn(numHidden, numHidden) * 1e-1
        self.V = np.random.randn(numHidden, numInput) * 1e-1
        self.w = np.random.randn(numHidden) * 1e-1
        self.h = np.zeros((numHidden,51))
        self.z = np.zeros((numHidden,50))
        self.yh = np.zeros(50)
        self.g = []
        
        # TODO: IMPLEMENT ME
    def cost(self,y):
        return np.sum((1/(2*len(y)))*(self.yh-np.array(y))**2)
    def dj_du(self,t,y):
        r = 0
        for i in range(t,-1,-1):
            if i==t:
                q = np.multiply((y[i]-self.yh[i])*self.w.T,self.g[i])
            else:
                q = np.multiply(np.dot(q,self.U),self.g[i])
            r += np.dot(q,self.h[:,i].T)
        return r

    def dj_dv(self,t,y,x):
        e = 0
        for i in range(t,-1,-1):
            if i==t:
                q = np.multiply((y[i]-self.yh[i])*self.w.T,self.g[i])
            else:
                q = np.multiply(np.dot(q,self.U),self.g[i])
            e += np.multiply(q,x[i])
        return e
    
    def dj_dw(self,t,y):
        return np.dot((self.yh[t]-y[t]),self.h[:,t].T)
    def backward (self,x,y,tp,alpha):
        du,dv,dw = 0,0,0
        for i in range(tp):
            self.forward(x)
            du = self.dj_du(i,y)
            dv = self.dj_dv(i,y,x)
            dw = self.dj_dw(i,y)
            self.U -= alpha * du
            self.V -= (alpha * dv).reshape(6,1)
            self.w -= alpha * dw

    def forward (self, x):
        self.g = []
        for i in range(x.shape[0]):
            self.z[:,i] = self.U@self.h[:,i]+(self.V)@np.array([x[i]])
            self.h[:,i+1] = np.tanh(self.z[:,i])
            self.yh[i] = self.w@self.h[:,i+1]
            self.g.append(1-np.tanh(self.z[:,i])**2)
        return self.yh

In [227]:
def generateData ():
    total_series_length = 50
    echo_step = 2  # 2-back task
    batch_size = 1
    x = np.random.choice(2, total_series_length, p=[0.5, 0.5])
    y = np.roll(x, echo_step)
    y[0:echo_step] = 0
    y = list(y)
    return (x, y)

In [None]:
if __name__ == "__main__":
    xs, ys = generateData()
    numHidden = 6
    numInput = 1
    numTimesteps = len(xs)
    rnn = RNN(numHidden, numInput, 1)
    # TODO: IMPLEMENT ME
    epochs = 1000
    alpha = 0.01
    for i in range(epochs):
        loss = 0
        print("epoch: %s"% (i+1))
        rnn.backward(xs,ys,numTimesteps,alpha)
        loss = rnn.cost(ys)
        print(loss)

epoch: 1
0.2777636336705128
epoch: 2
0.28512324279344964
epoch: 3
0.2929513317581367
epoch: 4
0.300334189151646
epoch: 5
0.30625403261728207
epoch: 6
0.3098232815636195
epoch: 7
0.31040293657877244
epoch: 8
0.307390100882232
epoch: 9
0.299822168660675
epoch: 10
0.2874220327681408
epoch: 11
0.27408039094732295
epoch: 12
0.2646128957571233
epoch: 13
0.2591829202225321
epoch: 14
0.2563083523272949
epoch: 15
0.2551164712308169
epoch: 16
0.25533480040150536
epoch: 17
0.2569709480269441
epoch: 18
0.26010950447708664
epoch: 19
0.2647944040141573
epoch: 20
0.27094018626177724
epoch: 21
0.27824224680887466
epoch: 22
0.28608906987216015
epoch: 23
0.29352198864852735
epoch: 24
0.29933445797329067
epoch: 25
0.30241067553056405
epoch: 26
0.3022717726589245
epoch: 27
0.29946762537960137
epoch: 28
0.2952910899711802
epoch: 29
0.2909396373375912
epoch: 30
0.286974972894354
epoch: 31
0.28341914371521404
epoch: 32
0.2800845127190417
epoch: 33
0.2768165334730971
epoch: 34
0.27361569734069774
epoch: 35
0.

0.11961671745024532
epoch: 273
0.119622668624561
epoch: 274
0.11962860124699404
epoch: 275
0.11963451578669501
epoch: 276
0.11964041271047637
epoch: 277
0.11964629248283777
epoch: 278
0.11965215556599093
epoch: 279
0.11965800241988601
epoch: 280
0.11966383350223773
epoch: 281
0.11966964926855216
epoch: 282
0.11967545017215372
epoch: 283
0.11968123666421257
epoch: 284
0.11968700919377208
epoch: 285
0.11969276820777669
epoch: 286
0.11969851415109987
epoch: 287
0.11970424746657204
epoch: 288
0.11970996859500882
epoch: 289
0.11971567797523917
epoch: 290
0.11972137604413362
epoch: 291
0.1197270632366324
epoch: 292
0.11973273998577362
epoch: 293
0.11973840672272135
epoch: 294
0.11974406387679339
epoch: 295
0.11974971187548923
epoch: 296
0.11975535114451762
epoch: 297
0.11976098210782393
epoch: 298
0.1197666051876174
epoch: 299
0.11977222080439806
epoch: 300
0.11977782937698335
epoch: 301
0.11978343132253459
epoch: 302
0.11978902705658284
epoch: 303
0.11979461699305474
epoch: 304
0.1198002015

0.12158809113571158
epoch: 541
0.12159724050589768
epoch: 542
0.12160634919241275
epoch: 543
0.12161541563757138
epoch: 544
0.121624438273045
epoch: 545
0.12163341552031959
epoch: 546
0.12164234579116613
epoch: 547
0.12165122748812474
epoch: 548
0.12166005900500154
epoch: 549
0.12166883872737871
epoch: 550
0.12167756503313705
epoch: 551
0.12168623629299138
epoch: 552
0.1216948508710386
epoch: 553
0.12170340712531769
epoch: 554
0.12171190340838219
epoch: 555
0.12172033806788436
epoch: 556
0.12172870944717123
epoch: 557
0.12173701588589184
epoch: 558
0.12174525572061609
epoch: 559
0.121753427285464
epoch: 560
0.12176152891274601
epoch: 561
0.12176955893361353
epoch: 562
0.12177751567871924
epoch: 563
0.12178539747888742
epoch: 564
0.12179320266579342
epoch: 565
0.12180092957265197
epoch: 566
0.12180857653491435
epoch: 567
0.12181614189097353
epoch: 568
0.12182362398287706
epoch: 569
0.1218310211570474
epoch: 570
0.12183833176500905
epoch: 571
0.12184555416412213
epoch: 572
0.121852686718

0.12069631050542053
epoch: 809
0.1206885373026814
epoch: 810
0.1206808181614948
epoch: 811
0.12067315325306326
epoch: 812
0.1206655427347315
epoch: 813
0.12065798675020686
epoch: 814
0.12065048542978019
epoch: 815
0.1206430388905474
epoch: 816
0.12063564723663127
epoch: 817
0.12062831055940347
epoch: 818
0.12062102893770692
epoch: 819
0.12061380243807766
epoch: 820
0.12060663111496735
epoch: 821
0.1205995150109648
epoch: 822
0.12059245415701779
epoch: 823
0.12058544857265402
epoch: 824
0.12057849826620193
epoch: 825
0.12057160323501068
epoch: 826
0.12056476346566944
epoch: 827
0.12055797893422598
epoch: 828
0.1205512496064044
epoch: 829
0.12054457543782175
epoch: 830
0.12053795637420381
epoch: 831
0.12053139235159956
epoch: 832
0.12052488329659462
epoch: 833
0.12051842912652343
epoch: 834
0.12051202974967994
epoch: 835
0.12050568506552699
epoch: 836
0.12049939496490435
epoch: 837
0.12049315933023504
epoch: 838
0.12048697803573018
epoch: 839
0.12048085094759221
epoch: 840
0.120474777924