# []

In [None]:
# inbuilt 
import os
import sys
import math

# most common
import numpy as np
import matplotlib.pyplot as plt

# pytorch
import torch as tt
import torch.nn as nn
import torch.optim as oo
import torch.functional as ff
import torch.distributions as dd
import torch.utils.data as ud

# custom
import known
import known.ktorch as kt
from known.basic import pj
# Report the interpreter and library versions this notebook was run with, so the
# results below can be tied to a concrete environment.
print(f'{sys.version=}\n{np.__version__=}\n{tt.__version__=}\n{known.__version__=}')

In [None]:
# Window length and the CSV columns used as input features.
seqlen = 20
cols = ('1', '2', '3')
input_size = len(cols)

# Options shared by every dataset so that all splits are windowed identically.
_seq_kwargs = dict(
    cols=cols,
    seqlen=seqlen,
    reverse=True,
    normalize=False,
    squeeze_label=True,
    dtype=tt.float32,
)

# Full series, training split and validation split.
# Note that the validation split is read from 'sinu_test.csv'.
ds = kt.SeqDataset.from_csv(pj('data_rnn/sinu.csv'), **_seq_kwargs)
ds_train = kt.SeqDataset.from_csv(pj('data_rnn/sinu_train.csv'), **_seq_kwargs)
ds_val = kt.SeqDataset.from_csv(pj('data_rnn/sinu_test.csv'), **_seq_kwargs)

# The test set is the complete series (same object as `ds`).
ds_test = ds

# Display all four datasets as the cell output.
ds, ds_train, ds_val, ds_test

In [None]:
# Width of every recurrent layer and the number of stacked layers. Both values,
# together with the module-level `input_size`, are read by RnnMlp when it is built.
hidden_size = 32
num_layers = 3


class RnnMlp(nn.Module):
    """Recurrent network followed by a linear read-out of the last time step.

    The recurrent part is built from ``rnn_class``, which must be one of the
    ``known.ktorch`` cells (``ELMAN``, ``GRU``, ``MGU``, ``JANET``, ``LSTM``) or
    one of torch's built-in ``nn.RNN`` / ``nn.GRU`` / ``nn.LSTM``. Layer sizes
    come from the module-level ``input_size``, ``hidden_size`` and ``num_layers``.

    Args:
        rnn_class: The recurrent class to instantiate. It is also stored on the
            instance as ``rnn_class``.

    Raises:
        ValueError: If ``rnn_class`` is not one of the supported classes.
    """

    def __init__(self, rnn_class) -> None:
        super().__init__()
        self.rnn_class = rnn_class

        # Keyword arguments shared by all known.ktorch recurrent classes.
        rnnargs = dict(
            input_size=input_size,
            hidden_sizes=[hidden_size for _ in range(num_layers)],
            actF=tt.tanh,
            batch_first=True,
            dtype=tt.float32,
            stack_output=True,
        )

        # NOTE(review): the meaning of the two leading positional arguments
        # (True, False) is defined by known.ktorch and is not visible here - confirm.
        if rnn_class is kt.ELMAN or rnn_class is kt.GRU or rnn_class is kt.MGU:
            self.rnn = rnn_class(True, False, **rnnargs)
        elif rnn_class is kt.JANET:
            self.rnn = rnn_class(True, False, beta=0.0, **rnnargs)
        elif rnn_class is kt.LSTM:
            self.rnn = rnn_class(True, False, actC=tt.tanh, **rnnargs)
        elif rnn_class is nn.RNN or rnn_class is nn.GRU or rnn_class is nn.LSTM:
            self.rnn = rnn_class(
                input_size=input_size,
                hidden_size=hidden_size,
                num_layers=num_layers,
                batch_first=True,
            )
        else:
            # Fail at construction time: without `self.rnn` the module could only
            # fail later, inside forward(), with a far less helpful AttributeError.
            raise ValueError(f'Invalid_RNN_Class: {rnn_class!r}')

        # Map the hidden state of the last time step back to `input_size` features.
        self.fc = nn.Sequential(nn.Flatten(), nn.Linear(hidden_size, input_size))

    def forward(self, X):
        """Run the recurrent network on ``X`` and read out the last time step.

        Only the first item returned by the recurrent module is used; any other
        returned items are discarded.

        Args:
            X: Input batch; assumed to be (batch, seq, input_size) because the
                recurrent module is created with ``batch_first=True`` - TODO confirm.

        Returns:
            The output of the linear read-out applied to ``x[:, -1, :]``.
        """
        x, *_ = self.rnn(X)
        return self.fc(x[:, -1, :])

# One model per recurrent class, each paired with the colour used for it in the
# plots further down. The mapping keeps every class next to its colour; `rnms`
# and `rnmc` are derived from it in the same (insertion) order.
_model_palette = {
    kt.ELMAN: 'tab:blue',
    kt.GRU: 'tab:red',
    kt.JANET: 'tab:green',
    kt.MGU: 'tab:brown',
    kt.LSTM: 'tab:olive',
    nn.RNN: 'tab:pink',
    nn.GRU: 'tab:orange',
    nn.LSTM: 'tab:grey',
}
rnms = [RnnMlp(rnn_class) for rnn_class in _model_palette]
rnmc = list(_model_palette.values())

In [None]:
# Disabled smoke test, kept as a bare string literal: none of the statements inside
# it run, and the cell simply evaluates to the string itself. The enclosed snippet
# pulls one batch from `ds` and pushes it through the first model to compare the
# shapes of the input, the label and the prediction.
"""
dl = ds.dataloader(batch_size=32)
print(len(dl))
dli = iter(dl)
x,y = next(dli)
with tt.no_grad():
    h = rnms[0](x)
x.shape, y.shape, h.shape
"""

In [None]:
# Training configuration shared by every model in `rnms`.
epochs = 500
batch_size = 32
shuffle = True
lr = 0.005
weight_decay = 0.0
validation_freq = epochs // 10
checkpoint_freq = epochs // 4
loss_plot_start = epochs // 50
save_path = 'sample.rnn'

# Mean test loss per model, keyed by the string form of its recurrent class.
test_loss = {}

for rnm in rnms:
    print(rnm.rnn_class)
    model = rnm

    # Fresh loss, optimiser and scheduler for each model.
    criterion = nn.MSELoss()
    optimizer = oo.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # NOTE(review): the scheduler and the two monitors below are created but never
    # passed to the trainer in this cell, and `checkpoint_freq` / `save_path` are
    # not used either - confirm whether they were meant to be wired into
    # kt.Trainer / fit().
    lrs = oo.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.7, total_iters=epochs)
    early_stop_train = kt.QuantiyMonitor('TrainLoss', patience=50, delta=1e-5)
    early_stop_val = kt.QuantiyMonitor('ValLoss', patience=50, delta=1e-5)

    trainer = kt.Trainer(model)
    trainer.optimizer = optimizer
    trainer.criterion = criterion

    trainer.fit(
        training_data=ds_train,
        validation_data=ds_val,
        epochs=epochs,
        batch_size=batch_size,
        shuffle=shuffle,
        validation_freq=validation_freq,
        verbose=1,
    )

    trainer.plot_results(loss_plot_start=loss_plot_start)

    # Evaluate on the test set; only the first returned value is recorded.
    mtl, tl = trainer.evaluate(ds_test)
    test_loss[str(rnm.rnn_class)] = mtl

    print('=================================================')

In [None]:
# Bar chart of the test loss recorded for each model in `test_loss`.
losses = []
labels = []
for class_repr, loss in test_loss.items():
    print(f'{class_repr}:\t{loss}')
    losses.append(loss)
    # Keys look like "<class 'pkg.sub.Name'>": keep the second dotted component and
    # the class name, dropping the trailing "'>" (e.g. "nn.RNN").
    parts = class_repr.split('.')
    labels.append(parts[1] + "." + parts[-1][:-2])

positions = range(len(test_loss))

plt.figure(figsize=(15, 6))
plt.bar(positions, losses)
plt.xticks(positions, labels)
# The bars come from trainer.evaluate(ds_test), i.e. they are test losses, so the
# axis is labelled accordingly (it previously read 'val_loss').
plt.ylabel('test_loss')
plt.show()

In [None]:
# One-shot predictions of every model over the whole dataset.
# `res` ends up with one prediction tensor per batch per model; because the batch
# size equals len(ds), that is presumably a single batch per model - TODO confirm.
res = []
full_batch = len(ds)
for rnm in rnms:
    print(rnm.rnn_class)
    rnm.eval()
    with tt.no_grad():
        # Xv / Yv stay bound after the loop; the plotting cell below reads Yv.
        for iv, (Xv, Yv) in enumerate(ds.dataloader(batch_size=full_batch)):
            Pv = rnm(Xv)
            res.append(Pv)



In [None]:
# One figure per feature column: ground truth in black, each model's prediction
# in its own colour from `rnmc`.
for i in range(input_size):
    fig, ax = plt.subplots(figsize=(20, 10))
    ax.set_title(f'{i}')

    ax.plot(Yv[:, i], color='black', label='Truth')
    for prediction, rnm, colour in zip(res, rnms, rnmc):
        ax.plot(prediction[:, i], color=colour, label=f'{rnm.rnn_class}', linewidth=0.5)

    ax.legend()
    plt.show()
    # Close the figure explicitly so figures do not pile up across iterations.
    plt.close(fig)
    print('=================================================')

## Future prediction

In [None]:
# Ask the recurrent part of every model for 1000 future steps, starting from the
# first (unshuffled) sample of the dataset.
#
# NOTE(review): `rnms` also contains wrappers around torch's built-in
# nn.RNN / nn.GRU / nn.LSTM, whose forward() does not appear to accept a `future`
# argument - confirm that this loop actually runs for those models.
# NOTE(review): this calls `rnm.rnn` directly, bypassing the linear read-out in
# `rnm.fc`; the stored result is whatever the recurrent module returns - confirm
# that it has the shape the plotting cell below expects.
res = []
future_steps = 1000
for rnm in rnms:
    print(rnm.rnn_class)
    rnm.eval()
    dl = ds.dataloader(batch_size=1, shuffle=False)
    di = iter(dl)
    with tt.no_grad():
        Xv, Yv = next(di)
        Pv = rnm.rnn(Xv, future=future_steps)
    res.append(Pv)



In [None]:
# One figure per feature column for the future-prediction results: the label
# batch `Yv` in black, each model's output in its own colour from `rnmc`.
for i in range(input_size):
    fig, ax = plt.subplots(figsize=(20, 10))
    ax.set_title(f'{i}')

    ax.plot(Yv[:, i], color='black', label='Truth')
    for prediction, rnm, colour in zip(res, rnms, rnmc):
        ax.plot(prediction[:, i], color=colour, label=f'{rnm.rnn_class}', linewidth=0.5)

    ax.legend()
    plt.show()
    # Close the figure explicitly so figures do not pile up across iterations.
    plt.close(fig)
    print('=================================================')