In [1]:
import numpy as np
from MiniTorch.core.baseclasses import ComputationNode
from MiniTorch.nets.layers import Linear, SoftMax, Tanh
from MiniTorch.nets.base import Parameter, Net
import jax
import jax.numpy as jnp
from MiniTorch.losses import CCE
from MiniTorch.optimizers import AdaGrad
from MiniTorch.nets.layers import RNN

In [2]:
class RNN(ComputationNode):

    def __init__(self, hidden_size, embed_dim,seq_len, accumulate_grad_norm=False, accumulate_params=False, initialization="xavier"):
        super().__init__()
        self.h_size = hidden_size
        self.emb_size = embed_dim
        self.accumulate_grad_norm = accumulate_grad_norm
        self.accumulate_params = accumulate_params
        self.ini = initialization
        self.parameters = {
            'Wx' : None,
            'Wh' : None,
            'Wy' : None,
            'bh' : None,
            'by' : None
        }
        self.tanh = Tanh()

    def initialize(self, seed_key):
        import jax.random as jrandom
        k1, k2, k3 = jrandom.split(seed_key, 3)
        self.parameters['Wx'] = Parameter((self.emb_size,self.h_size),self.ini,k1)
        self.parameters['Wh'] = Parameter((self.h_size, self.h_size),self.ini,k2)
        self.parameters['Wy'] = Parameter((self.h_size, self.emb_size),self.ini,k3)
        self.parameters['bh'] = Parameter((1, self.h_size),initialization=None, seed_key=None,is_bias=True)
        self.parameters['by'] = Parameter((1, self.emb_size), initialization=None, seed_key=None,is_bias=True)
    
    @staticmethod
    def _rnn_forward(X, H_prev, Wx, Wh, Wy, bh, by,tanh):
        H_next = (X @ Wx) + (H_prev @ Wh) + bh
        H_next = tanh.forward(H_next)
        out = H_next @ Wy + by
        return H_next, out
    
    def forward(self, X, inference=False):
        self.batch_size, self.seq_len, emb_size = X.shape
        self.h_states = [jnp.zeros((self.batch_size,self.h_size))]
        self.inp_states = []
        self.out_states = []
        X = jnp.transpose(X, (1, 0, 2))
        for i in range(seq_len):
            x = X[i,:,:]
            self.inp_states.append(x)
            H_next, out = self._rnn_forward(x, self.h_states[-1],self.parameters['Wx'].param,self.parameters['Wh'].param,self.parameters['Wy'].param,self.parameters['bh'].param,self.parameters['by'].param,self.tanh)
            self.h_states.append(H_next)
            self.out_states.append(out)
        out_states = jnp.transpose(jnp.array(self.out_states),(1,0,2))
        if inference:
            return out_states[:,-1,:]
        return out_states.reshape(self.batch_size*self.seq_len, emb_size)
    def __call__(self, inference=False):
        pass
    def backward(self, out_grad):
        out_grad = out_grad.reshape(self.batch_size, self.seq_len, self.emb_size)
        out_grad = jnp.transpose(out_grad,(1, 0, 2))
        dh_next = jnp.zeros_like(self.h_states[-1])
        Wx, Wh, Wy, bh, by = self.parameters['Wx'].param,self.parameters['Wh'].param,self.parameters['Wy'].param,self.parameters['bh'].param,self.parameters['by'].param
        self.parameters['Wx'].grad = jnp.zeros_like(Wx)
        self.parameters['Wh'].grad = jnp.zeros_like(Wh)
        self.parameters['Wy'].grad = jnp.zeros_like(Wy)
        self.parameters['bh'].grad = jnp.zeros_like(bh)
        self.parameters['by'].grad = jnp.zeros_like(by)
        self.dL_dinput = []
        for t in reversed(range(out_grad.shape[0])):
            self.parameters['Wy'].grad += self.h_states[t].T @ out_grad[t]
            self.parameters['by'].grad += jnp.sum(out_grad[t], axis=0)
            dht = out_grad[t] @ Wy + dh_next
            dth = self.tanh.backward(dht)
            self.parameters['Wx'].grad += self.inp_states[t].T @ dth
            self.parameters['Wh'].grad += self.h_states[t-1].T @ dth
            self.parameters['bh'].grad += jnp.sum(dth,axis=0)
            dh_next = dth @ Wh
            dinput = dth @ Wx
            self.dL_dinput.append(dinput)
        return jnp.array(reversed(self.dL_dinput))

In [2]:
corpus = [
    # Common words
    "hello", "world", "jax", "rocks", "gradient", "optimizer", "learning", "neural", "network", 
    "python", "model", "forward", "backward", "activation", "function", "training", "epoch", 
    "tensor", "loss", "update", "bias", "weight", "batch", "input", "output", "sequence", 
    "recurrent", "long", "short", "memory", "unit", "layer", "hidden", "state", "compute",
    # Names
    "alice", "bob", "charlie", "david", "eve", "frank", "grace", "heidi", "ivan", "judy",
    "mallory", "oscar", "peggy", "trent", "victor", "wendy",
    # Technical extras
    "sigmoid", "tanh", "relu", "dropout", "regularization", "epoch", "loss", "accuracy",
]


In [3]:
special_tokens = ["<PAD>", "<EOS>"]
chars = sorted(list(set("".join(corpus)))) + special_tokens
char2idx = {ch: i for i, ch in enumerate(chars)}
idx2char = {i: ch for ch, i in char2idx.items()}
vocab_size = len(chars)

PAD_IDX = char2idx["<PAD>"]
EOS_IDX = char2idx["<EOS>"]

print(f"Vocab Size: {vocab_size}")
print(f"PAD index: {PAD_IDX}, EOS index: {EOS_IDX}")
def encode_sequence(seq):
    """Encodes a string + appends EOS"""
    return [char2idx[ch] for ch in seq] + [EOS_IDX]

def decode_sequence(indices):
    """Decodes indices, ignoring PAD and EOS tokens"""
    chars_out = [idx2char[int(i)] for i in indices if idx2char[int(i)] not in ("<PAD>", "<EOS>")]
    return "".join(chars_out)


Vocab Size: 28
PAD index: 26, EOS index: 27


In [4]:
def prepare_sequences(corpus):
    X_data, y_data = [], []
    for seq in corpus:
        encoded = encode_sequence(seq)
        X_data.append(encoded[:-1])  # input
        y_data.append(encoded[1:])   # shifted output
    return X_data, y_data

def pad_sequences(seqs, pad_value=PAD_IDX):
    max_len = max(len(s) for s in seqs)
    return jnp.array([s + [pad_value]*(max_len - len(s)) for s in seqs])

def one_hot_encode(y, num_classes):
    return jnp.eye(num_classes)[y]


def batch_iterator(X_data, y_data, batch_size=8):
    n = len(X_data)
    for i in range(0, n, batch_size):
        X_batch = X_data[i:i+batch_size]
        y_batch = y_data[i:i+batch_size]
        X_batch = pad_sequences(X_batch)
        y_batch = pad_sequences(y_batch)
        y_one_hot = one_hot_encode(y_batch, len(chars))
        X_batch_oh = one_hot_encode(X_batch, len(chars))
        yield X_batch_oh, y_one_hot

In [5]:
def decode_output(output, idx_to_char, pad_idx=None, apply_softmax=True):
    """
    Decodes one-hot or softmax outputs from the RNN into readable strings.

    Args:
        output: jnp.ndarray of shape (batch, seq_len, vocab_size)
                — may be raw logits or already probabilities
        idx_to_char: dict mapping integer index → character
        pad_idx: index of the <PAD> token (optional)
        apply_softmax: whether to apply softmax before decoding

    Returns:
        List[str]: Decoded strings for each batch
    """
    if apply_softmax:
        # numerically stable softmax
        exp_x = jnp.exp(output - jnp.max(output, axis=-1, keepdims=True))
        output = exp_x / jnp.sum(exp_x, axis=-1, keepdims=True)

    # Get the most likely token index at each timestep
    token_indices = jnp.argmax(output, axis=-1)  # shape: (batch, seq_len)

    decoded_strings = []
    for seq in token_indices:
        chars = []
        for idx in seq:
            if pad_idx is not None and idx == pad_idx:
                continue  # skip padding
            chars.append(idx_to_char[int(idx)])
        decoded_strings.append("".join(chars))

    return decoded_strings

In [None]:

model = Net([RNN(
    100,
    28,
    4
)], 22)


In [18]:
cce = CCE()
optimizer = AdaGrad(0.01, model)
epochs = 10

In [19]:
X_data, y_data = prepare_sequences(corpus)
for epoch in range(epochs):
    ep_loss = 0
    for X_batch, y_batch in batch_iterator(X_data, y_data, batch_size=4):
        if X_batch.shape[0] != 4:
            break
        outputs = model.forward(X_batch)
        b, sq, dim = y_batch.shape
        y_batch = y_batch.reshape(b*sq,dim)
        loss = cce.loss(outputs, y_batch)
        ini_grad = cce.backward()
        optimizer.step(ini_grad)
        ep_loss += loss
    print(f"Epoch {epoch}: loss->{ep_loss/4}")

Epoch 0: loss->8.52025032043457
Epoch 1: loss->8.138164520263672
Epoch 2: loss->8.122602462768555
Epoch 3: loss->8.245694160461426
Epoch 4: loss->8.468669891357422
Epoch 5: loss->8.71008014678955
Epoch 6: loss->8.975908279418945
Epoch 7: loss->9.315447807312012
Epoch 8: loss->9.673639297485352
Epoch 9: loss->10.086681365966797


In [21]:
y_batch.reshape(4*5,-1)

Array([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 0., 0., 0., 0

In [23]:
X_batch.shape

(4, 9, 28)

In [22]:
jnp.tanh(0) * (1/(1+jnp.exp(-2)))

Array(0., dtype=float32, weak_type=True)

In [None]:
def lrp_backward(Wb, R_out, X, d, eps=0.001, bias_factor=0.0):
        n_layers = len(Wb)//4
        b, T, _ = X.shape
        device = X.device
        X = X.permute(1,0,2)
        
        Rgates = torch.zeros((n_layers, T, b, d)).to(device) #following the candidate gate take all method supported by Deep Taylor Decomposition.
        Rh_states = torch.zeros((n_layers, T+1, b, d)).to(device)
        Rc_states = torch.zeros((n_layers, T+1, b, d)).to(device)
        Rh_states[-1, T-1] = R_out
        Rx = torch.zeros((T, b, d))
        # format reminder: lrp_linear(hin, w, b, hout, Rout, bias_nb_units, eps, bias_factor)
        for n in reversed(range(n_layers)):
            Wih, Whh, bih, bhh = Wb[f"weight_ih_l{n}"].T, Wb[f"weight_hh_l{n}"].T, Wb[f"bias_ih_l{n}"].T, Wb[f"bias_hh_l{n}"].T
            idx_no_gx = torch.concat([torch.arange(0, d), torch.arange(d, 2*d), torch.arange(d*3, d*4)]) # indices for everything except mem (candidate) gate
            ix, fx, gx, ox = torch.arange(0, d), torch.arange(d, 2*d), torch.arange(d*2, d*3), torch.arange(d*3, d*4) # indices for inp, forget, candidate and out gates
            
            for t in reversed(range(T)):
                Rc_states[n, t] += Rh_states[n, t]
                Rc_states[n, t-1] = lrp_linear(gates[n, t, :, fx] * c_states[n, t-1], torch.eye(d).to(device), torch.zeros((d)).to(device), c_states[n, t], Rc_states[n, t],d,eps,bias_factor)
                Rgates[n, t] = lrp_linear(gates[n, t, :, gx]*gates[n, t, :, ix], torch.eye(d).to(device), torch.zeros((d)).to(device), c_states[n, t], Rc_states[n, t],d,eps,bias_factor)
                if n == 0:
                    x = X[t]
                    Rx[t] = lrp_linear(x,Wih[gx],bih[gx]+bhh[gx],pre_act[n, t, :, gx], Rgates[n, t],d+d,eps,bias_factor)
                else:
                    x = h_states[n-1, t]
                    Rh_states[n-1, t] = lrp_linear(x,Wih[gx],bih[gx]+bhh[gx],pre_act[n, t, :, gx], Rgates[n, t],d+d,eps,bias_factor)
                Rh_states[n, t-1] = lrp_linear(h_states[n, t-1],Whh[gx],bih[gx]+bhh[gx],pre_act[n, t, :, gx], Rgates[n, t],d+d,eps,bias_factor)
                
        return Rx.permute(1, 0, 2)
                    
                
                
        
        