# Comparing `known.ktorch` recurrent models with PyTorch reference modules

In [None]:
# standard library
import os
import sys
import math

# most common
import numpy as np
import matplotlib.pyplot as plt

# pytorch
import torch as tt
import torch.nn as nn
import torch.functional as ff
import torch.distributions as dd
import torch.utils.data as ud

# custom
import known
import known.ktorch as kt

# Record the interpreter and library versions this notebook ran with, one per line.
print('\n'.join((
    f'{sys.version=}',
    f'{np.__version__=}',
    f'{tt.__version__=}',
    f'{known.__version__=}',
)))

In [None]:
# set seed: fix torch's default generator so the random inputs below (and, presumably, the
# initial weights of modules built afterwards) are reproducible
tt.manual_seed(281703975047300)
print('Manual-Seed:', tt.initial_seed())  # echo the seed the default generator now reports

# tensor geometry of the synthetic input batches
batch_size, input_size, hidden_size, seq_len = 32, 6, 12, 20

# settings shared by the torch reference modules and the kt models
dt, batch_first, stack_output, dropout, num_layers = tt.float64, True, True, 0.0, 5

# number of random input batches drawn into xx
# NOTE(review): num_loops is not used by any cell shown here
num_samples, num_loops = 50, 10

def absdiff(y, Y):
    """Summarise the absolute difference between two equally long sequences of tensors.

    Items are paired by position. Iterating a tensor walks its first dimension, so two
    tensors are compared slice by slice. Every pair must have matching shapes, and
    sum(|yi - Yi|) is taken for each pair as a Python float.

    Args:
        y: iterable of tensors (or a tensor).
        Y: iterable of tensors (or a tensor) with the same number of items as ``y``.

    Returns:
        tuple (total, mean): the per-pair sums added together, and their mean over the
        number of pairs (``0`` and ``0.0`` for empty inputs).

    Raises:
        AssertionError: if ``y`` and ``Y`` have different lengths, or a pair's shapes differ.
    """
    # Materialise both sides so a length mismatch is reported instead of zip silently truncating.
    y, Y = list(y), list(Y)
    assert len(y) == len(Y), f"{len(y)} != {len(Y)} items"
    per_pair = []
    for i, (yi, yit) in enumerate(zip(y, Y)):
        assert (yi.shape == yit.shape), f"{yi.shape} != {yit.shape} @ {i}"
        per_pair.append(tt.sum(tt.abs(yi - yit)).item())
    # Average over the per-pair sums; np.mean of the scalar total would merely repeat it.
    mean = float(np.mean(per_pair)) if per_pair else 0.0
    return sum(per_pair), mean

# One uniform-random input batch per sample, laid out to match the models' batch_first setting.
xx = [
    tt.rand(
        size=(batch_size, seq_len, input_size) if batch_first else (seq_len, batch_size, input_size),
        dtype=dt,
    )
    for _ in range(num_samples)
]
len(xx)

In [None]:

# Per-layer hidden widths for the kt models: num_layers copies of hidden_size.
hidden_sizes = [hidden_size] * num_layers
bidirectional, bias = False, True
# NOTE(review): actF is not used by any cell shown here
nonlinearity, actF = 'tanh', tt.tanh

In [None]:
# Stand-alone ELMAN demo, independent of the comparison models below. Despite its name,
# elman_torch is a kt model. i2h_sizes=[8, 9] presumably means two hidden layers.
elman_torch = kt.ELMAN(
    input_size=6, i2h_sizes=[8, 9],
    i2o_sizes=None, o2o_sizes=None,
    dropout=0.0, batch_first=True, stack_output=True,
    i2h_bias=True, i2o_bias=True, o2o_bias=True,
    i2h_activations=(), i2o_activation=None, o2o_activation=None, last_activation=None,
    hypers=None, dtype=None, device=None,
)
known.Verbose.showX(elman_torch)

In [None]:
# PyTorch reference modules, paired by position with the kt models in rnnL through the zip in the
# copy/compare cells: RNN<->ELMAN, GRU<->GRU, LSTM<->LSTM, LSTM<->MGU, LSTM<->JANET.
# All share one configuration; only nn.RNN takes `nonlinearity`.
_torch_common = dict(
    input_size=input_size,
    hidden_size=hidden_size,
    bias=bias,
    batch_first=batch_first,
    num_layers=num_layers,
    dropout=dropout,
    bidirectional=bidirectional,
    dtype=dt,
)
rnntL = [
    nn.RNN(nonlinearity=nonlinearity, **_torch_common),
    nn.GRU(**_torch_common),
    nn.LSTM(**_torch_common),
    # NOTE(review): torch.nn has no MGU/JANET module, so these two LSTMs stand in as references
    # for kt.MGU and kt.JANET. Confirm that copy_torch and the comparison are meaningful for them.
    nn.LSTM(**_torch_common),
    nn.LSTM(**_torch_common),
]
del _torch_common  # keep the notebook namespace unchanged

# Custom kt recurrent models, paired by position with the torch references in rnntL
# (ELMAN, GRU, LSTM, MGU, JANET). All five share one configuration. The settings that
# mirror the torch modules (sizes, dropout, batch_first, bias, dtype) come from the same
# config variables, so both sides stay in sync when the config changes.
_kt_common = dict(
    input_size=input_size,
    i2h_sizes=hidden_sizes,
    i2o_sizes=None,
    o2o_sizes=None,
    dropout=dropout,  # follow the shared config like rnntL does (was hard-coded 0.0)
    batch_first=batch_first,
    stack_output=stack_output,
    i2h_bias=bias,
    i2o_bias=True,
    o2o_bias=True,
    i2h_activations=(),
    i2o_activation=None,
    o2o_activation=None,
    last_activation=None,
    hypers=None,
    dtype=dt,
    device=None,
)
# Built in the same order as before, so any use of the global RNG during init is unchanged.
rnnL = [model_cls(**_kt_common) for model_cls in (kt.ELMAN, kt.GRU, kt.LSTM, kt.MGU, kt.JANET)]
del _kt_common  # keep the notebook namespace unchanged

In [None]:
# Zero every bias tensor of the torch reference modules.
# NOTE(review): presumably done so the references line up with the kt models after copy_torch
# (torch keeps separate bias_ih/bias_hh per layer). Confirm against kt's copy_torch.
# Iterating named_parameters() also covers '_reverse' biases when bidirectional, and is a
# no-op when bias=False (in that case get_parameter('bias_ih_l0') would raise).
with tt.no_grad():
    for rnnt in rnntL:
        for pname, param in rnnt.named_parameters():
            if pname.startswith('bias_'):
                param.zero_()

In [None]:
# Flag passed to kt.show_dict when inspecting the models; False keeps that output short.
show_details = False

In [None]:
# For each (kt model, torch reference) pair, in this order:
#   1. show the torch module,
#   2. call copy_torch on the kt model (presumably copying the torch parameters into it),
#   3. show the kt model so the two can be compared.
# Note: the loop variables rnn/rnnt remain bound in the notebook namespace afterwards.
for rnn, rnnt in zip(rnnL, rnntL):
    kt.show_dict(rnnt, show_details)
    kt.show_parameters(rnnt)
    rnn.copy_torch(rnnt)  # must run after the torch-side inspection and before the kt-side one
    kt.show_dict(rnn, show_details)
    kt.show_parameters(rnn)

In [None]:
# Feed the same first input batch through each (kt model, torch reference) pair without autograd.
# Print the output shapes and the absdiff summary, labelled with both class names.
# NOTE(review): kt.MGU and kt.JANET are paired with nn.LSTM references, so their differences
# may not be expected to vanish. Confirm what those pairs are meant to show.
with tt.no_grad():
    for rnn, rnnt in zip(rnnL, rnntL):
        yt, _ = rnnt(xx[0])
        y0, _ = rnn(xx[0])
        print(f'{type(rnn).__name__} vs {type(rnnt).__name__}:', y0.shape, yt.shape)
        print(absdiff(y0, yt))


In [None]:
# A richer JANET with i2o and o2o stages sized by hidden_sizes, ReLU/sigmoid activations for
# those stages, and a final (nn.LogSoftmax, {'dim': -1}) spec (presumably instantiated by kt as
# nn.LogSoftmax(dim=-1)). This rebinds `rnn`, replacing the last model from the loops above.
rnn = kt.JANET(
    input_size=input_size, i2h_sizes=hidden_sizes, i2o_sizes=hidden_sizes, o2o_sizes=hidden_sizes,
    dropout=0.0, batch_first=batch_first, stack_output=stack_output,
    i2h_bias=bias, i2o_bias=True, o2o_bias=True,
    i2h_activations=(), i2o_activation=tt.relu, o2o_activation=tt.sigmoid,
    last_activation=(nn.LogSoftmax, {'dim': -1}),
    hypers=None, dtype=dt, device=None,
)



In [None]:

# Persist the whole module object (pickled), not just its state_dict.
# NOTE(review): the '.lstm' extension is misleading, because `rnn` here is the kt.JANET built above.
tt.save(rnn, 'rnn.lstm')

In [None]:
# Reload the whole pickled module saved above. Since PyTorch 2.6, torch.load defaults to
# weights_only=True, which refuses to unpickle arbitrary classes such as this kt model.
# The file is written by this notebook, so full unpickling is acceptable here.
# Never do this with untrusted files: unpickling can execute arbitrary code.
rnn = tt.load('rnn.lstm', weights_only=False)

In [None]:
# Forward the first input batch through the reloaded model (no autograd) and show the output shape.
# This rebinds `yt`, which the next cells save and compare.
with tt.no_grad():
    yt, _ = rnn(xx[0])
    print(yt.shape)
    

In [None]:
# Save this output for a save/load round-trip check.
# NOTE(review): a later cell writes 'out1' again (the backward output of birnn), overwriting this file.
tt.save(yt, 'out1')

In [None]:
# Total absolute difference between the saved and in-memory output; 0 means the round trip is exact.
(tt.load('out1') - yt).abs().sum()

In [None]:
# Wrap `rnn` as the forward core and kt.clone_model(rnn) as the backward core of a kt.GRNN.
# Presumably this forms a bidirectional model with a separate copy for the backward direction; confirm in kt.
birnn = kt.GRNN(core_forward=rnn, core_backward=kt.clone_model(rnn))


In [None]:
# Persist the whole wrapper module (pickled). NOTE(review): the '.lstm' extension does not reflect its type.
tt.save(birnn, 'birnn.lstm')

In [None]:
# Reload the whole pickled wrapper saved above. weights_only=False is needed on PyTorch >= 2.6,
# where the default (True) rejects arbitrary pickled classes. The file is produced by this
# notebook; never fully unpickle untrusted files, because doing so can execute arbitrary code.
birnn = tt.load('birnn.lstm', weights_only=False)

In [None]:
# Forward the first input batch through the wrapper (no autograd). `yt` is rebound and then
# indexed as a pair, presumably (forward output, backward output); confirm against kt.GRNN.
with tt.no_grad():
    yt, _ = birnn(xx[0])
    known.basic.Verbose.info(yt)
    print(yt[0].shape, yt[1].shape)

In [None]:
# Save both parts of the wrapper output for a round-trip check.
# NOTE(review): 'out1' was already written by the single-model round-trip cell; this overwrites it.
tt.save(yt[0], 'out0')
tt.save(yt[1], 'out1')

In [None]:
# Total absolute difference per saved part; (0, 0) means both round trips are exact.
tuple((tt.load(path) - part).abs().sum() for path, part in (('out0', yt[0]), ('out1', yt[1])))