In [21]:
import numpy as np
from MiniTorch.core.baseclasses import ComputationNode
from MiniTorch.nets.layers import Linear, SoftMax, Tanh
from MiniTorch.nets.base import Parameter
import jax
import jax.numpy as jnp
from MiniTorch.losses import CCE

In [17]:
key = jax.random.PRNGKey(0)

In [12]:
wx, wy, wk = jax.random.split(key, 3)

In [18]:
x = jax.random.normal(key,(4,3,4))

In [20]:
x

Array([[[ 1.6226422 ,  2.0252647 , -0.43359444, -0.07861735],
        [ 0.1760909 , -0.97208923, -0.49529874,  0.4943786 ],
        [ 0.6643493 , -0.9501635 ,  2.1795304 , -1.9551506 ]],

       [[ 0.35857072,  0.15779513,  1.2770847 ,  1.5104648 ],
        [ 0.970656  ,  0.59960806,  0.0247007 , -1.9164772 ],
        [-1.8593491 ,  1.728144  ,  0.04719035,  0.814128  ]],

       [[ 0.13132767,  0.28284705,  1.2435943 ,  0.6902801 ],
        [-0.80073744, -0.74099   , -1.5388287 ,  0.30269185],
        [-0.02071605,  0.11328721, -0.2206547 ,  0.07052256]],

       [[ 0.8532958 , -0.8217738 , -0.01461421, -0.15046217],
        [-0.9001352 , -0.7590727 ,  0.33309513,  0.80924904],
        [ 0.04269255, -0.57767123, -0.41439894, -1.9412533 ]]],      dtype=float32)

In [22]:
k = jnp.transpose(x, (1, 0, 2))[0, :, :]

In [24]:
k.shape

(4, 4)

In [8]:
jnp.zeros((1,5))

Array([[0., 0., 0., 0., 0.]], dtype=float32)

In [None]:
class RNN(ComputationNode):

    def __init__(self, hidden_size, embed_dim,batch_size=1, accumulate_grad_norm=False, accumulate_params=False, initialization="xavier"):
        super().__init__()
        self.h_size = hidden_size
        self.emb_size = embed_dim
        self.accumulate_grad_norm = accumulate_grad_norm
        self.accumulate_params = accumulate_params
        self.ini = initialization
        self.parameters = {
            'Wx' : None,
            'Wh' : None,
            'Wy' : None,
            'bh' : None,
            'by' : None
        }
        self.h_states = [jnp.zeros((batch_size,hidden_size))]
        self.inp_states = []
        self.out_states = []
        self.tanh = Tanh()

    def initialize(self, seed_key):
        import jax.random as jrandom
        k1, k2, k3 = jrandom.split(seed_key, 3)
        self.parameters['Wx'] = Parameter((self.emb_size,self.h_size),self.ini,k1)
        self.parameters['Wh'] = Parameter((self.h_size, self.h_size),self.ini,k2)
        self.parameters['Wy'] = Parameter((self.h_size, self.emb_size),self.ini,k3)
        self.parameters['bh'] = Parameter((1, self.h_size),initialization=None, seed_key=None,is_bias=True)
        self.parameters['by'] = Parameter((1, self.emb_size), initialization=None, seed_key=None,is_bias=True)
    
    @staticmethod
    def _rnn_forward(X, H_prev, Wx, Wh, Wy, bh, by,tanh):
        H_next = (X @ Wx) + (H_prev @ Wh) + bh
        H_next = tanh.forward(H_next)
        out = H_next @ Wy + by
        return H_next, out
    
    def forward(self, X, inference=False):
        self.batch_size, self.seq_len, emb_size = X.shape
        X = jnp.transpose(X, (1, 0, 2))
        for i in range(seq_len):
            x = X[i,:,:]
            self.inp_states.append(x)
            H_next, out = self._rnn_forward(x, self.h_states[-1],self.parameters['Wx'].param,self.parameters['Wh'].param,self.parameters['Wy'].param,self.parameters['bh'].param,self.parameters['by'].param,self.tanh)
            self.h_states.append(H_next)
            self.out_states.append(out)
        out_states = jnp.transpose(jnp.array(self.out_states),(1,0,2))
        if inference:
            return out_states[:,-1,:]
        out_states = out_states.reshape(self.batch_size*self.seq_len, emb_size)
        return self.out_states
    def __call__(self, inference=False):
        pass
    def backward(self, out_grad):
        out_grad = out_grad.reshape(self.batch_size, self.seq_len, self.emb_size)
        out_grad = jnp.transpose(out_grad,(1, 0, 2))
        dh_next = jnp.zeros_like(self.h_states[-1])
        Wx, Wh, Wy, bh, by = self.parameters['Wx'].param,self.parameters['Wh'].param,self.parameters['Wy'].param,self.parameters['bh'].param,self.parameters['by'].param
        self.parameters['Wx'].grad = jnp.zeros_like(Wx)
        self.parameters['Wh'].grad = jnp.zeros_like(Wh)
        self.parameters['Wy'].grad = jnp.zeros_like(Wy)
        self.parameters['bh'].grad = jnp.zeros_like(bh)
        self.parameters['by'].grad = jnp.zeros_like(by)
        for t in reversed(range(out_grad.shape[0])):
            self.parameters['Wy'].grad += self.h_states[t].T @ out_grad[t]
            self.parameters['by'].grad += jnp.sum(out_grad[t], axis=0)
            dht = out_grad[t] @ Wy + dh_next
            dth = self.tanh.backward(dht)
            self.parameters['Wx'].grad += self.inp_states[t].T @ dth
            self.parameters['Wh'].grad += self.h_states[t-1].T @ dth
            self.parameters['bh'].grad += jnp.sum(dth,axis=0)
            dh_next = dth @ Wh
            dinput = dth @ Wx
        
            

        
    

In [103]:
import jax.numpy as jnp
import jax.random as jrandom

# Set up toy dimensions
batch_size = 2
seq_len = 3
emb_size = 4
hidden_size = 4

# Random seed
key = jrandom.PRNGKey(0)
num_samples, num_classes = 6, 4

# Random class indices between 0 and 3
labels = jrandom.randint(key, (num_samples,), 0, num_classes)

# Convert to one-hot
one_hot_labels = jnp.eye(num_classes)[labels]

# Create random batch input
# Shape: (batch_size, sequence_length, embedding_size)
X = jrandom.normal(key, (batch_size, seq_len, emb_size))
print("Input shape:", X.shape)

Input shape: (2, 3, 4)


In [119]:
# Create the RNN instance
rnn = RNN(hidden_size=hidden_size,
          embed_dim=emb_size,   # matches input embedding dimension
          batch_size=2,
          accumulate_grad_norm=False,
          accumulate_params=False,
          initialization="xavier")

# Initialize weights
rnn.initialize(key)

# Run forward pass
outputs = rnn.forward(X)

# Inspect results
print(f"\nNumber of timesteps processed: {len(outputs)}")
for t, out in enumerate(outputs):
    print(f"t={t} | output shape: {out.shape}")



Number of timesteps processed: 3
t=0 | output shape: (2, 4)
t=1 | output shape: (2, 4)
t=2 | output shape: (2, 4)


In [33]:
def one_hot(idx, size):
    v = jnp.zeros((1, size))
    v = v.at[0,idx].add(1.)
    return v

In [34]:
one_hot(3, 24)

Array([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)

In [120]:
outs = jnp.transpose(jnp.array(outputs),(1,0,2))
b, s, s_= outs.shape

In [121]:
outs_ = outs.reshape(b*s,s_)

In [122]:
outs_

Array([[ 0.4122036 , -0.27509335, -0.519339  , -0.26765275],
       [ 0.24536178,  0.0511681 ,  0.32255504, -0.05431895],
       [ 0.27392945,  1.4398806 , -0.09398943, -0.80262685],
       [-0.21694078, -0.14914781, -0.28904817,  0.6382464 ],
       [ 0.95202035,  1.5464456 , -0.3153327 , -1.6319971 ],
       [-0.02829501, -1.2677511 , -0.27899325,  0.5712544 ]],      dtype=float32)

In [123]:
loss = CCE()

In [124]:
loss.loss(outs_,one_hot_labels)

Array(1.7700107, dtype=float32)

In [125]:
ini_grad = loss.backward()

In [126]:
out_grad = ini_grad.reshape(b,s,s_)
out_grad = jnp.transpose(out_grad,(1, 0, 2))
dh_next = jnp.zeros_like(rnn.h_states[-1])
h_states = rnn.h_states
Wx, Wh, Wy, bh, by = rnn.parameters['Wx'].param,rnn.parameters['Wh'].param,rnn.parameters['Wy'].param,rnn.parameters['bh'].param,rnn.parameters['by'].param
in_states = rnn.inp_states
tanh = rnn.tanh

In [None]:
self.parameters['Wx'].grad = jnp.zeros_like(Wx)
self.parameters['Wh'].grad = jnp.zeros_like(Wh)
self.parameters['Wy'].grad = jnp.zeros_like(Wy)
self.parameters['bh'].grad = jnp.zeros_like(bh)
self.parameters['by'].grad = jnp.zeros_like(by)

In [127]:

for i in reversed(range(out_grad.shape[0])):
    do = out_grad[i]
    ht = h_states[i]
    dWhy += ht.T @ do
    dby += jnp.sum(do, axis=0)
    dht = (do @ Wy) + dhnext
    dth = tanh.backward(dht)
    dWxh += in_states[i].T @ dth
    print(h_states[i-1].T.shape)
    dWhh += h_states[i-1].T @ dth
    dbh += jnp.sum(dth, axis=0)
    dhnext = dth @ Wh
    dinput = dth @ Wx



(4, 2)
(4, 2)
(4, 2)


In [117]:
for i in rnn.h_states:
    print(i.shape)

(1, 4)
(2, 4)
(2, 4)
(2, 4)


In [90]:
dWxh

Array([[-9.0503879e-02, -1.7025879e-01,  9.5224120e-03,  6.4861238e-01],
       [ 8.7552235e-02,  1.5866458e-01, -8.9318929e-03, -6.8610358e-01],
       [-2.0378808e-02,  1.5476720e-03,  2.9591180e-04,  5.3321075e-01],
       [ 5.6809727e-02,  7.6650463e-02, -4.5767957e-03, -7.0049977e-01]],      dtype=float32)

In [79]:
jnp.sum(out_grad[0],axis=0)

Array([ 0.6028804 , -1.5908179 ,  0.33773494,  0.6502025 ], dtype=float32)

In [58]:
h_states[0]

Array([[0., 0., 0., 0., 0.]], dtype=float32)

In [83]:
Wy.T.shape, out_grad[-1].shape

((4, 4), (2, 4))

In [91]:
in_states[0].shape

(2, 4)