# Recurrent cells with `known.ktorch`: single-layer and stacked RNNs

In [None]:
# inbuilt 
import os
import sys
import math

# most common
import numpy as np
import matplotlib.pyplot as plt

# pytorch
import torch as tt
import torch.nn as nn
import torch.optim as oo
import torch.functional as ff
import torch.distributions as dd
import torch.utils.data as ud

# custom
import known
from known.basic import pj
from known.basic.common import Verbose as verb
import known.ktorch as kt

# Report the interpreter and library versions this notebook ran with
# (the `=` f-string specifier prints each expression next to its value).
print(f'{sys.version=}\n{np.__version__=}\n{tt.__version__=}\n{known.__version__=}')

# Fix the seed of torch's default generator so random tensors created below are repeatable.
# Only torch is seeded here; numpy's generator is left untouched.
tt.manual_seed(281703975047300) # manually sets a seed for random sampling creation ops
print('Manual-Seed:', tt.initial_seed()) # current seed for default rng

# Single Layer RNNs

In [None]:
# --- sizes and tensor options handed to every core cell below ---
input_size = 3
hidden_size = 8
has_bias = True      # one switch reused for every `*_bias` / `bias` argument
dtype = tt.float32
device = None        # None: torch falls back to its current default device

# --- arguments for the optional layers added via build_i2o / build_o2o ---
i2o_size = 6
i2o_act = None
o2o_size = 4
o2o_act = None

# --- random 2-D input batch of shape (batch_size, input_size) ---
batch_size = 5
# Created with the same dtype AND device as the modules, so that changing
# `device` above cannot cause a device mismatch in the forward pass.
# With device=None this is identical to omitting the argument.
X = tt.rand((batch_size, input_size), dtype=dtype, device=device)

## Create Cells

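Each `RNNCell` wraps a core cell. The core cells (Elman, GRU, MGU, LSTM, PLSTM and JANET) are created here first and wrapped in the sections below.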

In [None]:
# Keyword groups that are repeated across the constructors below.
size_kw = dict(input_size=input_size, hidden_size=hidden_size)
tensor_kw = dict(dtype=dtype, device=device)
# LSTM and PLSTM are configured with exactly the same gate options.
lstm_gate_kw = dict(
    hidden_activation=None,
    input_bias=has_bias, input_activation=None,
    forget_bias=has_bias, forget_activation=None,
    output_bias=has_bias, output_activation=None,
    cell_bias=has_bias, cell_activation=None,
)

# One core cell per architecture, keyed by a short name.
# Construction order is kept (elman, gru, mgu, lstm, plstm, janet) so that any
# seeded parameter initialisation draws random numbers in the same sequence.
cells = {
    'elman': kt.ELMANCell(
        hidden_bias=has_bias, hidden_activation=None,  # None for default
        **size_kw, **tensor_kw,
    ),
    'gru': kt.GRUCell(
        hidden_bias=has_bias, hidden_activation=None,  # None for default
        update_bias=has_bias, update_activation=None,
        reset_bias=has_bias, reset_activation=None,
        cell_type=0,  # 0 to use input-weights at reset and forget gates
        **size_kw, **tensor_kw,
    ),
    'mgu': kt.MGUCell(
        hidden_bias=has_bias, hidden_activation=None,  # None for default
        forget_bias=has_bias, forget_activation=None,
        **size_kw, **tensor_kw,
    ),
    'lstm': kt.LSTMCell(**size_kw, **lstm_gate_kw, **tensor_kw),
    'plstm': kt.PLSTMCell(**size_kw, **lstm_gate_kw, **tensor_kw),
    'janet': kt.JANETCell(
        hidden_bias=has_bias, hidden_activation=None,
        forget_bias=has_bias, forget_activation=None,
        beta=1.0,
        **size_kw, **tensor_kw,
    ),
}
cells

# RNC

In [None]:
# Wrap every core cell in a plain kt.RNNCell, keeping the same keys as `cells`.
rnns = {name: kt.RNNCell(core) for name, core in cells.items()}
rnns

In [None]:
# Same wrappers as above, each followed by a build_i2o(...) call of size i2o_size.
# The stored value is whatever build_i2o returns (the chained usage elsewhere
# suggests it returns the wrapper itself).
rnnsx = {
    name: (
        kt.RNNCell(core)
        .build_i2o(size=i2o_size, bias=has_bias, activation=i2o_act, dtype=dtype, device=device)
    )
    for name, core in cells.items()
}
rnnsx

In [None]:
# Wrappers configured with both build_i2o (size i2o_size) and build_o2o (size o2o_size),
# called in that order on a fresh kt.RNNCell per core cell.
rnnsy = {
    name: (
        kt.RNNCell(core)
        .build_i2o(size=i2o_size, bias=has_bias, activation=i2o_act, dtype=dtype, device=device)
        .build_o2o(size=o2o_size, bias=has_bias, activation=o2o_act, dtype=dtype, device=device)
    )
    for name, core in cells.items()
}
rnnsy

# RNC + SelfAttention

In [None]:
# Wrap every core cell in a kt.RNNSACell, keeping the same keys as `cells`.
rnnsA = {name: kt.RNNSACell(core) for name, core in cells.items()}
rnnsA

In [None]:
# kt.RNNSACell wrappers, each followed by a build_i2o(...) call of size i2o_size.
rnnsxA = {
    name: (
        kt.RNNSACell(core)
        .build_i2o(size=i2o_size, bias=has_bias, activation=i2o_act, dtype=dtype, device=device)
    )
    for name, core in cells.items()
}
rnnsxA

In [None]:
# kt.RNNSACell wrappers configured with build_i2o (size i2o_size) and then
# build_o2o (size o2o_size).
rnnsyA = {
    name: (
        kt.RNNSACell(core)
        .build_i2o(size=i2o_size, bias=has_bias, activation=i2o_act, dtype=dtype, device=device)
        .build_o2o(size=o2o_size, bias=has_bias, activation=o2o_act, dtype=dtype, device=device)
    )
    for name, core in cells.items()
}
rnnsyA

# Forward

In [None]:
# Shape check: call every wrapper once on X and print the shape of the result.
# The groups are visited in the order they were built above.
dicts = [rnns, rnnsx, rnnsy, rnnsA, rnnsxA, rnnsyA]

with tt.no_grad():  # no gradients are needed just to inspect shapes
    for idx, group in enumerate(dicts):
        print ('DICT::', idx)
        for name, model in group.items():
            Y = model(X)
            print(name, Y.shape)


# Stacking Cells

In [None]:
# Per-layer sizes for the six stacked cells: entry n is passed to the n-th
# layer's build_i2o / build_o2o call respectively.
i2o_sizes = list(range(6, 12))    # 6 .. 11
o2o_sizes = list(range(16, 22))   # 16 .. 21

In [None]:
# Keyword bundles shared by every layer of the stack.
tensor_kw = dict(dtype=dtype, device=device)
i2o_kw = dict(bias=has_bias, activation=i2o_act, **tensor_kw)
o2o_kw = dict(bias=has_bias, activation=o2o_act, **tensor_kw)
# LSTM and PLSTM layers use the same gate options.
lstm_gate_kw = dict(
    hidden_activation=None,
    input_bias=has_bias, input_activation=None,
    forget_bias=has_bias, forget_activation=None,
    output_bias=has_bias, output_activation=None,
    cell_bias=has_bias, cell_activation=None,
)

# Six wrapped cells with an nn.Dropout(p=0.25) between consecutive layers.
# Layer n is built with build_i2o(size=i2o_sizes[n]) and build_o2o(size=o2o_sizes[n]);
# every layer after the first takes input_size = o2o size of the previous layer.
# The list order is also the construction order, which is kept unchanged.
cells = [
    # layer 0: Elman core, fed with the raw input
    (
        kt.RNNCell(kt.ELMANCell(
            input_size=input_size, hidden_size=hidden_size,
            hidden_bias=has_bias, hidden_activation=None,  # None for default
            **tensor_kw,
        ))
        .build_i2o(size=i2o_sizes[0], **i2o_kw)
        .build_o2o(size=o2o_sizes[0], **o2o_kw)
    ),
    nn.Dropout(p=0.25),

    # layer 1: GRU core
    (
        kt.RNNCell(kt.GRUCell(
            input_size=o2o_sizes[0], hidden_size=hidden_size,
            hidden_bias=has_bias, hidden_activation=None,  # None for default
            update_bias=has_bias, update_activation=None,
            reset_bias=has_bias, reset_activation=None,
            cell_type=0,  # 0 to use input-weights at reset and forget gates
            **tensor_kw,
        ))
        .build_i2o(size=i2o_sizes[1], **i2o_kw)
        .build_o2o(size=o2o_sizes[1], **o2o_kw)
    ),
    nn.Dropout(p=0.25),

    # layer 2: MGU core
    (
        kt.RNNCell(kt.MGUCell(
            input_size=o2o_sizes[1], hidden_size=hidden_size,
            hidden_bias=has_bias, hidden_activation=None,  # None for default
            forget_bias=has_bias, forget_activation=None,
            **tensor_kw,
        ))
        .build_i2o(size=i2o_sizes[2], **i2o_kw)
        .build_o2o(size=o2o_sizes[2], **o2o_kw)
    ),
    nn.Dropout(p=0.25),

    # layer 3: LSTM core
    (
        kt.RNNCell(kt.LSTMCell(
            input_size=o2o_sizes[2], hidden_size=hidden_size,
            **lstm_gate_kw, **tensor_kw,
        ))
        .build_i2o(size=i2o_sizes[3], **i2o_kw)
        .build_o2o(size=o2o_sizes[3], **o2o_kw)
    ),
    nn.Dropout(p=0.25),

    # layer 4: PLSTM core
    (
        kt.RNNCell(kt.PLSTMCell(
            input_size=o2o_sizes[3], hidden_size=hidden_size,
            **lstm_gate_kw, **tensor_kw,
        ))
        .build_i2o(size=i2o_sizes[4], **i2o_kw)
        .build_o2o(size=o2o_sizes[4], **o2o_kw)
    ),
    nn.Dropout(p=0.25),

    # layer 5: JANET core (last layer, no dropout after it)
    (
        kt.RNNCell(kt.JANETCell(
            input_size=o2o_sizes[4], hidden_size=hidden_size,
            hidden_bias=has_bias, hidden_activation=None,
            forget_bias=has_bias, forget_activation=None,
            beta=1.0,
            **tensor_kw,
        ))
        .build_i2o(size=i2o_sizes[5], **i2o_kw)
        .build_o2o(size=o2o_sizes[5], **o2o_kw)
    ),
]

In [None]:
# Chain the wrapped cells and the dropout layers, in list order, into one module;
# calling coreF(x) feeds the output of each entry into the next.
coreF = nn.Sequential(*cells)

In [None]:

# Single forward pass through the whole stack, printing the shape of the result.
# Unlike the earlier shape check, this call is not wrapped in tt.no_grad().
# NOTE(review): no .eval() call is visible before this point, so the nn.Dropout
# layers presumably run in training mode here — confirm this is intended.
Y = coreF(X)
print(Y.shape)

In [None]:
# --- options forwarded unchanged to kt.GRNN ---
bi = True
return_sequences = True
stack_output = False
batch_first = True

# --- random input sequence ---
batch_size = 5
seq_len = 10
# Axis order follows `batch_first`:
#   True  -> (batch_size, seq_len, input_size)
#   False -> (seq_len, batch_size, input_size)
# The tensor is created on `device` (as well as with `dtype`) so it always
# lives where the modules were built; with device=None nothing changes.
X = (
    tt.rand((batch_size, seq_len, input_size), dtype=dtype, device=device)
    if batch_first else
    tt.rand((seq_len, batch_size, input_size), dtype=dtype, device=device)
)

# Wrap the stacked cells (coreF) in a kt.GRNN configured with the flags above.
rnnstack = kt.GRNN(
    coreF, bi=bi, return_sequences=return_sequences, stack_output=stack_output, batch_first=batch_first
)



In [None]:
# Put the stack into training mode before the forward pass below.
# NOTE(review): presumably nn.Module.train(), which would keep the nn.Dropout
# layers active — confirm that kt.GRNN is an nn.Module.
# As the last expression of the cell, the call's return value is also displayed.
rnnstack.train()

In [None]:

# Run the sequence batch X through the GRNN and print the shape of the result.
# NOTE(review): `Y.shape` assumes the result is a single tensor. The GRNN was
# built with stack_output=False, and the next cell treats Y as a sequence
# (len(Y), Y[0]) — verify what kt.GRNN actually returns in this configuration.
Y = rnnstack(X)
print(Y.shape)


In [None]:
# Inspect the result along its first axis: number of entries, then the shape of
# the first entry (displayed as the cell's output).
print(len(Y))
Y[0].shape