# Comparing recurrent cores (ELMAN, GRU, LSTM, MGU, JANET) on the univariate EVC `PRICE` series

In [None]:
# inbuilt 
import os
import sys
import math

# most common
import numpy as np
import matplotlib.pyplot as plt

# pytorch
import torch as tt
import torch.nn as nn
import torch.optim as oo
import torch.functional as ff
import torch.distributions as dd
import torch.utils.data as ud

# custom
import known
from known.basic import pj
from known.basic.common import Verbose as verb
import known.ktorch as kt

# Record interpreter and library versions so the outputs below can be tied to the environment that produced them.
_version_lines = [
    f'{sys.version=}',
    f'{np.__version__=}',
    f'{tt.__version__=}',
    f'{known.__version__=}',
]
print(*_version_lines, sep='\n')

In [None]:
# Every artifact of this run (checkpoints, figures) is written under f'{save_prefix}_dir'.
save_prefix = 'Uni'
# exist_ok=True: without it a second "Restart & Run All" fails with FileExistsError
# because the directory is left over from the previous run.
os.makedirs(f'{save_prefix}_dir', exist_ok=True)

# Select Dataset

## (1) EVC price series (univariate, column `PRICE`)

In [None]:
seqlen = 24
cols = ('PRICE',)
input_size = len(cols)

# Options shared by all three splits; defined once so the splits cannot drift apart.
_csv_kwargs = dict(
    cols=cols,
    seqlen=seqlen,
    reverse=True,
    normalize=False,
    squeeze_label=True,
    dtype=tt.float32,
)

ds_test = kt.SeqDataset.from_csv(pj('data_EVC/normPJDS.csv'), **_csv_kwargs)
ds_train = kt.SeqDataset.from_csv(pj('data_EVC/train.csv'), **_csv_kwargs)
# NOTE(review): the *validation* split is read from test.csv while the *test* split comes
# from normPJDS.csv — confirm this naming is intended.
ds_val = kt.SeqDataset.from_csv(pj('data_EVC/test.csv'), **_csv_kwargs)

ds_train, ds_val, ds_test

# Define Regression Network

In [None]:
# ---- hyper-parameters shared by every recurrent core in the comparison ----
dt = tt.float32          # parameter dtype; same dtype the datasets were loaded with
dropout = 0.0
batch_first = True
bias = True              # passed to every *_bias flag of kt.XRNN
bidir = False

# two recurrent layers of equal width
hidden_size = 32
hidden_sizes = [hidden_size] * 2

# optional i2o / o2o projections and activations, all left unset
i2o_sizes = o2o_sizes = None
i2o_activation = o2o_activation = last_activation = None

# fully-connected layers passed to kt.XRNN as fc_layers
fc_layers = [128]
fc_act = (nn.ReLU, {})   # presumably (activation class, constructor kwargs) — confirm against kt.XRNN
fc_last_act = None


## Encapsulate RNNs

In [None]:
# One row per recurrent core: (label, core class, plot colour).
# The label keys every per-model dict below and appears in checkpoint / figure file names.
_rnn_lineup = (
    ('kt.ELMAN', kt.ELMAN, 'tab:red'),
    ('kt.GRU',   kt.GRU,   'tab:blue'),
    ('kt.LSTM',  kt.LSTM,  'tab:green'),
    ('kt.MGU',   kt.MGU,   'tab:orange'),
    ('kt.JANET', kt.JANET, 'tab:brown'),
)
net_names = [label for label, _, _ in _rnn_lineup]
net_class = [core for _, core, _ in _rnn_lineup]


def _build_xrnn(core):
    """Wrap one recurrent core class in a kt.XRNN configured from the hyper-parameter cell."""
    return kt.XRNN(
        coreF=core,
        bidir=bidir,
        fc_layers=fc_layers,
        fc_act=fc_act,
        fc_last_act=fc_last_act,
        fc_bias=bias,
        input_size=input_size,
        i2h_sizes=hidden_sizes,
        i2o_sizes=i2o_sizes,
        o2o_sizes=o2o_sizes,
        dropout=dropout,
        batch_first=batch_first,
        i2h_bias=bias,
        i2o_bias=bias,
        o2o_bias=bias,
        i2h_activations=None,
        i2o_activation=i2o_activation,
        o2o_activation=o2o_activation,
        last_activation=last_activation,
        hypers=None,
        return_sequences=False,
        stack_output=False,
        dtype=dt,
        device=None,
    )


networks = {}
for label, core in zip(net_names, net_class):
    networks[label] = _build_xrnn(core)

colors = {label: colour for label, _, colour in _rnn_lineup}

# (A) Train and Evaluate

In [None]:
# Training hyper-parameters shared by every network. They do not depend on the model,
# so they are set once here rather than re-assigned on every loop iteration.
epochs = 6
batch_size = 32
shuffle = True
validation_freq = epochs // 2      # equals int(epochs/2) for a non-negative epoch count
lr = 0.00125
weight_decay = 0.0
# First history entry drawn in the loss plots. Defined before the loop so the summary
# figure below does not raise NameError when `networks` is empty.
loss_plot_start = epochs // 20
save_dir = f'{save_prefix}_dir'

# NOTE: this loop trains the model objects built in the network cell; re-running this
# cell alone presumably continues training them — re-run that cell first for a fresh comparison.
test_loss, train_loss = {}, {}
for key, model in networks.items():
    print(key, model.__class__)
    criterion = nn.MSELoss()
    optimizer = oo.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # NOTE(review): `lrs` is never handed to the trainer or stepped in this cell, so as written
    # the learning rate presumably stays at `lr` for the whole run. Confirm whether kt.Trainer
    # accepts a scheduler and wire it up, or drop it.
    lrs = oo.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.2, total_iters=epochs)
    save_path = os.path.join(save_dir, f'{save_prefix}_{key}.reg')

    trainer = kt.Trainer(model)
    trainer.optimizer = optimizer
    trainer.criterion = criterion
    trainer.fit(training_data=ds_train, validation_data=ds_val,
                epochs=epochs, batch_size=batch_size, shuffle=shuffle, validation_freq=validation_freq,
                save_path=save_path, verbose=1)
    trainer.plot_results(color=colors[key], loss_plot_start=loss_plot_start)

    mtl, _ = trainer.evaluate(ds_test, batch_size=None)
    train_loss[key] = np.array(trainer.train_loss_history)
    test_loss[key] = mtl
    print('=================================================')

# Overlay the training curves of all networks (each history is averaged over axis 1).
fig = plt.figure(figsize=(16, 8))
for k, v in train_loss.items():
    plt.plot(np.mean(v, axis=1)[loss_plot_start:], label=k, color=colors[k], linewidth=0.8)
plt.title('Training loss')
plt.ylabel('mean training loss')
plt.legend()
# Save before plt.show(): some backends close the figure on show, so saving first is the safe order.
fig.savefig(os.path.join(save_dir, f'loss_{save_prefix}.png'))
plt.show()

# (B) Load and Evaluate

In [None]:
# Optionally restore the checkpoints written by the training cell and re-evaluate them
# instead of using the in-memory results. Disabled by default so "Run All" trains from scratch.
LOAD_FROM_DISK = False

if LOAD_FROM_DISK:
    for key, model in networks.items():
        # Same path the training cell passes to trainer.fit(save_path=...);
        # presumably that is where the checkpoint is written — confirm against kt.Trainer.
        save_path = os.path.join(f'{save_prefix}_dir', f'{save_prefix}_{key}.reg')
        kt.load_state(model, save_path)

    test_loss = {}
    for key, model in networks.items():
        trainer = kt.Trainer(model)
        trainer.criterion = nn.MSELoss()
        # Same evaluate() arguments as the training cell, so the numbers are comparable.
        mtl, _ = trainer.evaluate(ds_test, batch_size=None)
        test_loss[key] = mtl

# Plot Evaluation Results

In [None]:
# Rank the networks by their test-set loss and draw one bar per network, best (lowest) on the left.
for k, v in test_loss.items():
    print(f'{k}:\t{v}')

names = list(test_loss)
y = np.array([test_loss[k] for k in names])   # loss per network, in insertion order
l = np.array(names)
c = [colors[k] for k in names]
t = np.argsort(y)                             # t[pos] = index of the pos-th smallest loss
x = range(len(names))

# Upper y-limit of the bar chart. 0.001 zooms in on small losses but clips any bar above it;
# set to None to keep the autoscaled limit that fits every bar.
loss_ylim_top = 0.001

fig = plt.figure(figsize=(20, 6))
plt.xlim(-1, len(x))
for pos, idx in enumerate(t):
    # bar at position `pos` shows network `idx`; the thin line carries its height to the y-axis
    plt.bar([pos], y[idx], color=c[idx])
    plt.hlines(y[idx], -1, pos, linestyles='solid', linewidth=0.5, color=c[idx])
plt.ylim(0, loss_ylim_top)   # top=None keeps the current (autoscaled) upper limit
plt.xticks(x, l[t])
# These are losses on ds_test (trainer.evaluate(ds_test, ...)), not validation losses.
plt.ylabel('test_loss')
# Save before plt.show(): some backends close the figure on show.
fig.savefig(os.path.join(f'{save_prefix}_dir', f'loss_compar_{save_prefix}.png'))
plt.show()

# Manual Testing

## Predictions on the test dataset

In [None]:
# Predict the whole test set with every network. The batch is drawn ONCE and shared, so the
# ground truth `Yv` used by the plots below is aligned with every model's predictions
# (previously it was re-drawn per model and only the last draw was kept).
with tt.no_grad():
    # batch_size=len(ds_test): the whole test set as a single batch
    Xv, Yv = next(iter(ds_test.dataloader(batch_size=len(ds_test))))

res = {}
key_len = 0   # stays 0 (instead of NameError) when there are no networks
for key, model in networks.items():
    print(key, model.__class__)
    model.eval()
    with tt.no_grad():
        Pv = model(Xv)
    res[key] = Pv
    key_len = len(Pv)   # number of predicted samples
print(key_len)

## Visualize predictions against the ground truth

In [None]:
# Optional close-up of one fixed window of test predictions. Disabled by default; the window
# sweep in the next cell plots (and saves) consecutive windows instead.
SHOW_FIXED_WINDOW = False

if SHOW_FIXED_WINDOW:
    fr = 7000
    tr = fr + 400
    print(f'{fr} ==> {tr}')

    for i in range(input_size):
        fig = plt.figure(figsize=(20, 10))
        plt.title(f'{i}')
        plt.plot(Yv[fr:tr, i], color='black', label='Truth')
        for key, r in res.items():
            plt.plot(r[fr:tr, i], color=colors[key], label=f'{key}', linewidth=0.5)
        plt.legend()
        plt.show()
        plt.close(fig)   # release each figure; many windows would otherwise accumulate

In [None]:
# Plot and save consecutive windows of ground truth vs. every model's predictions.
ilen = 400                  # window width in samples (alternative: int(key_len/4))
start_i = 1                 # <-------------- set this: index of the first window (first sample = start_i * ilen)
end_i = int(key_len * 0.6)  # stop after the first 60% of the test predictions
print(ilen, key_len, end_i, '\n')

# Windows are [fr, tr) with tr clamped to end_i, so no window runs past the end bound.
for cnt, fr in enumerate(range(start_i * ilen, end_i, ilen), start=1):
    tr = min(fr + ilen, end_i)
    print(f'#{cnt} :: {fr} ==> {tr}')
    for i in range(input_size):
        fig = plt.figure(figsize=(20, 10))
        plt.title(f'{i}')

        plt.plot(Yv[fr:tr, i], color='black', label='Truth')
        for key, r in res.items():
            plt.plot(r[fr:tr, i], color=colors[key], label=f'{key}', linewidth=0.5)
        plt.legend()
        # Save before show/close, and use the same window number that was printed above
        # (the old loop incremented cnt first, so window #1 was saved as *_2).
        fig.savefig(os.path.join(f'{save_prefix}_dir', f'{i}_{cnt}.png'))
        plt.show()
        plt.close(fig)
    