# []

In [None]:
# inbuilt 
import os
import sys
import math

# most common
import numpy as np
import matplotlib.pyplot as plt

# pytorch
import torch as tt
import torch.nn as nn
import torch.optim as oo
import torch.functional as ff
import torch.distributions as dd
import torch.utils.data as ud

# custom
import known
import known.ktorch as kt

# Record the interpreter and library versions this notebook was executed with.
print(f'{sys.version=}\n{np.__version__=}\n{tt.__version__=}\n{known.__version__=}')

In [None]:
def gen(rng, n, d):
    """Generate one synthetic sinusoidal series for column index ``d``.

    Args:
        rng: Not used by this generator; kept so the signature stays unchanged.
        n: Number of samples, evenly spaced over [0, 10*pi].
        d: Column index selecting the waveform:
            0 -> 10*sin(x), 1 -> -5*cos(x), 2 -> 7*sin(x) - 9*cos(x).

    Returns:
        1-D numpy array of length ``n``.

    Raises:
        ValueError: if ``d`` is not 0, 1 or 2.
    """
    x = np.linspace(0, 10*np.pi, n)
    if d == 0:
        return 10*np.sin(x)
    if d == 1:
        return -5*np.cos(x)
    if d == 2:
        return 7*np.sin(x) + -9*np.cos(x)
    # Fail loudly: falling through would otherwise return an unbound variable.
    raise ValueError(f'invalid dim: {d!r} (expected 0, 1 or 2)')

# Synthetic data: gen is passed as the generator callback for columns a, b and c.
# NOTE(review): presumably genS is the sample count and the result is written
# to train.csv -- confirm against kt.SeqDataset.generate.
kt.SeqDataset.generate(
    genF=gen,
    genS=1000,
    colS=['a', 'b', 'c'],
    normalize=False,
    file_name='train.csv',
)

In [None]:
# Sequence length and the CSV column(s) handed to the dataset loader.
seqlen = 24
cols = ('PRICE',)
input_size = len(cols)  # one input feature per selected column

ds = kt.SeqDataset.from_csv(
    'PJDS.csv',
    cols=cols,
    seqlen=seqlen,
    reverse=True,
    normalize=False,
    squeeze_label=True,
    dtype=tt.float32,
)
ds

In [None]:
class RnnMlp(nn.Module):
    """Three-layer recurrent network followed by a linear read-out of the last time step.

    Args:
        rnn_class: Recurrent implementation to wrap. Supported values are
            ``kt.ELMAN``, ``kt.GRU``, ``kt.MGU``, ``kt.JANET``, ``kt.LSTM`` and the
            built-in ``nn.RNN``, ``nn.GRU``, ``nn.LSTM``.
        n_features: Number of input (and output) features. When ``None`` the
            notebook-level ``input_size`` is used, which is the original behaviour.

    Raises:
        ValueError: if ``rnn_class`` is not one of the supported classes.
    """

    def __init__(self, rnn_class, n_features=None) -> None:
        super().__init__()
        self.rnn_class = rnn_class
        if n_features is None:
            # Fall back to the notebook global so existing RnnMlp(cls) calls keep working.
            n_features = input_size
        hidden, layers = 8, 3

        if rnn_class is nn.RNN or rnn_class is nn.GRU or rnn_class is nn.LSTM:
            self.rnn = rnn_class(
                input_size=n_features, hidden_size=hidden, num_layers=layers, batch_first=True
            )
        else:
            rnnargs = dict(
                input_size=n_features, hidden_sizes=(hidden,) * layers,
                actF=tt.tanh, batch_first=True, dtype=tt.float32,
            )
            if rnn_class is kt.ELMAN or rnn_class is kt.GRU or rnn_class is kt.MGU:
                pass  # no extra keyword arguments for these cells
            elif rnn_class is kt.JANET:
                rnnargs['beta'] = 0.0
            elif rnn_class is kt.LSTM:
                rnnargs['actC'] = tt.tanh
            else:
                # Fail at construction time instead of leaving self.rnn undefined.
                raise ValueError(f'Invalid_RNN_Class: {rnn_class!r}')
            # NOTE(review): the meaning of the two positional flags (True, False) is
            # defined by known.ktorch -- confirm against its constructor signature.
            self.rnn = rnn_class(True, False, **rnnargs)

        # nn.Flatten is kept so the read-out keeps its Sequential layout (fc.1.*).
        self.fc = nn.Sequential(nn.Flatten(), nn.Linear(hidden, n_features))

    def forward(self, X):
        """Run the recurrent stack on ``X`` and map its last time step through ``fc``.

        The first value returned by the recurrent module is indexed as
        ``[:, -1, :]``; every module here is built with ``batch_first=True``.
        """
        x, *_ = self.rnn(X)
        return self.fc(x[:, -1, :])

# One model per recurrent implementation, instantiated in the listed order.
rnms = [
    RnnMlp(cell)
    for cell in (
        kt.ELMAN, kt.GRU, kt.JANET, kt.MGU, kt.LSTM,   # known.ktorch cells
        nn.RNN, nn.GRU, nn.LSTM,                       # torch.nn built-ins
    )
]

# One plot colour per model, index-aligned with rnms.
rnmc = [
    'tab:blue',
    'tab:red',
    'tab:green',
    'tab:brown',
    'tab:olive',
    'tab:pink',
    'tab:orange',
    'tab:grey',
]

In [None]:
# Smoke test: pull one batch and run it through the first model without tracking gradients.
dl = ds.dataloader(batch_size=32)
print(len(dl))  # length reported by the loader object returned by ds.dataloader
dli = iter(dl)
x,y = next(dli)
with tt.no_grad():
    h = rnms[0](x)
# Cell result: shapes of the input batch, its labels, and the model output.
x.shape, y.shape, h.shape

In [None]:
# Train every model in rnms with kt.utils.train and collect the returned history,
# keyed by the string form of the model's recurrent class.
# NOTE(review): the same dataset `ds` is passed as training, validation and testing
# data, so the validation/test figures are not measured on held-out data.
# NOTE(review): every model uses save_path='sample.rnn'; presumably later models
# overwrite earlier checkpoints -- confirm against kt.utils.train.
all_history={}
for rnm in rnms:
    print(rnm.rnn_class)
    history = kt.utils.train( rnm,
        training_data=ds, 
        validation_data=ds, 
        testing_data=ds,
        epochs=500, 
        batch_size=32, 
        shuffle=True, 
        validation_freq=5, 
        criterion_type=nn.MSELoss, 
        criterion_args={}, 
        optimizer_type=oo.Adam, 
        optimizer_args={'lr': 0.005, 'weight_decay': 0.0}, 
        lrs_type=oo.lr_scheduler.LinearLR, 
        lrs_args={'start_factor': 1.0, 'end_factor':0.7, 'total_iters': 2000},
        record_batch_loss=False, 
        early_stop_train=None,#kt.QuantiyMonitor('TrainLoss', patience=50, delta=0.00001, verbose=False), 
        early_stop_val=None, #kt.QuantiyMonitor('ValLoss', patience=50, delta=0.00001, verbose=False), 
        checkpoint_freq=5, 
        save_path='sample.rnn',
        save_state_only=True, 
        verbose=1, 
        plot=1
    )
    # The key is the f-string form of the class; the summary cell parses it for tick labels.
    all_history[f'{rnm.rnn_class}'] = history
    print('=================================================')

In [None]:
# Hyper-parameters shared by every model in this run.
epochs = 50
batch_size = 32
shuffle = True
lr = 0.005
weight_decay = 0.0
validation_freq = int(epochs/10)
checkpoint_freq = int(epochs/4)
loss_plot_start = int(epochs/50)
save_path = 'sample.rnn'  # defined but not forwarded: train_ below receives save_path=None

all_history = {}
for rnm in rnms:
    print(rnm.rnn_class)
    model = rnm

    # Loss, optimiser, scheduler and monitors are rebuilt for each model.
    criterion = nn.MSELoss()
    optimizer = oo.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    lrs = oo.lr_scheduler.LinearLR(
        optimizer, start_factor=1.0, end_factor=0.7, total_iters=epochs
    )
    early_stop_train = kt.QuantiyMonitor('TrainLoss', patience=50, delta=0.00001)
    early_stop_val = kt.QuantiyMonitor('ValLoss', patience=50, delta=0.00001)

    history = kt.utils.train_(
        model=model,
        training_data=ds,
        validation_data=ds,
        testing_data=ds,
        epochs=epochs,
        batch_size=batch_size,
        shuffle=shuffle,
        validation_freq=validation_freq,
        criterion=criterion,
        optimizer=optimizer,
        lrs=lrs,
        record_batch_loss=False,
        early_stop_train=early_stop_train,
        early_stop_val=early_stop_val,
        checkpoint_freq=checkpoint_freq,
        save_path=None,
        save_state_only=False,
        verbose=1,
        plot=1,
        loss_plot_start=loss_plot_start,
    )

    all_history[f'{rnm.rnn_class}'] = history
    print('=================================================')

In [None]:
# Print the final train/validation/test loss of every model and compare the
# final validation losses in a bar chart.
y = []  # final validation loss per model (bar heights)
l = []  # short tick label per model
for k,v in all_history.items():
    ll, vl, tl = v['loss'][-1], v['val_loss'][-1], v['test_loss']
    print(f'{k}:\t{ll, vl, tl}')
    # The chart is labelled 'val_loss', so plot vl; the history is read by string
    # key everywhere else, so the integer lookup v[1] was not a valid access.
    y.append(vl)
    # k is the string form of a class, e.g. "<class 'torch.nn.modules.rnn.GRU'>";
    # keep the second path component and the class name -> "nn.GRU".
    sl = k.split('.')
    l.append(sl[1]+"."+sl[-1][:-2])

x= range(len(all_history))

plt.figure(figsize=(15,6))
plt.bar(x , y )
plt.xticks(x, l)
plt.ylabel('val_loss')
plt.show()

In [None]:
# Collect the predictions of every model over the dataset, in rnms order.
# The loader is requested with batch_size=len(ds).
# NOTE(review): presumably that yields exactly one batch per model, which the
# plotting cell's zip(res, rnms, rnmc) relies on -- confirm against ds.dataloader.
res = []
for rnm in rnms:
    print(rnm.rnn_class)
    rnm.eval()
    with tt.no_grad():
        for iv,(Xv,Yv) in enumerate(ds.dataloader(batch_size=len(ds)), 0):
            Pv = rnm(Xv)
            # Yv (and Pv) stay defined after the loop; the plotting cell reads Yv as the truth.
            res.append(Pv)



In [None]:
# One figure per output feature: ground truth in black, each model's prediction dotted.
# Yv is the label batch left over from the prediction cell; res holds the
# predictions appended there, paired with rnms and rnmc by position.
for i in range(input_size):
    plt.figure(figsize=(20,10))
    plt.title(f'{i}')

    plt.plot(Yv[:,i], color='black', label='Truth')
    for r,rnm,c in zip(res,rnms,rnmc):
        # Plot this model's own predictions (r). Pv only holds the output of the
        # last model evaluated, so using it drew the same curve for every model.
        plt.plot(r[:,i], color=c, label=f'{rnm.rnn_class}', linestyle='dotted')
    plt.legend()
    plt.show()
    plt.close()
    print('=================================================')