# Import delle librerie necessarie

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
# !git clone https://github.com/AntonioDelleCanne/tesi.git

In [3]:
%cd DeepLearning_Financial

/data/home/dsvm_server_admin/notebooks/fastai/tesi/DeepLearning_Financial


In [57]:
## EXTERNAL
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sklearn
import yfinance
from pandas import Series
from sklearn.preprocessing import PolynomialFeatures, StandardScaler, MinMaxScaler
from sklearn.pipeline import Pipeline
from sklearn.model_selection import TimeSeriesSplit, PredefinedSplit
from sklearn.metrics import make_scorer
from sklearn.model_selection import cross_validate
from sklearn.model_selection import GridSearchCV
from IPython.display import display
import datetime
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import numpy as np
import time
import os
import random 
from sklearn.datasets import make_regression
from torch import nn
import torch.nn.functional as F
from skorch import NeuralNetRegressor
from torch.nn.modules.loss import MSELoss
import tensorflow as tf
from tensorflow import keras
from skorch.dataset import CVSplit
from skorch import callbacks
import pickle
from sklearn.model_selection import train_test_split
from functools import partial
import skorch
import pywt
from sklearn import preprocessing
import joblib


##INTERNAL
from models import Autoencoder, waveletSmooth, SequenceDouble, SequenceDoubleAtt, SequenceAtt
from utils import prepare_data_lstm, ExampleDataset, save_checkpoint, evaluate_lstm, backtest

In [5]:
def save(model, name):
    with open(name + '.pkl', 'wb') as f:
        pickle.dump(model, f)
        
def load(name):
    with open(name + '.pkl', 'rb') as f:
        return pickle.load(f)

In [6]:
#returns open high low close volume
def get_index(index="^DJI", start_date="2000-01-01", end_date="2018-12-31"):
    security = yfinance.Ticker(index)# TODO trova mercato asiatico e indiano
    security_data = security.history(start=start_date, end=end_date, actions=False)
    return security_data

In [7]:
def split_index(security_data):
    return security_data["Open"], security_data["High"], security_data["Low"], security_data["Close"], security_data["Volume"]

# Allenamento modello

## Metriche
Come metrica principale per valutare la bonta' dele predizioni utilizzeremo il Return of Investment (ROI).

Con questa metrica assumiamo di utilizzare l'algoritmo di trading precedentemente descritto, e calcoliamo il guadango che avremmo ottenuto se lo avessimo utilizzato sul dataset che stiamo valutando.

In [8]:
def gain(C, C_pred, opn):
    O = opn.reindex_like(C)
    CO_diff = C - O
    growth = C_pred > O
    decline = C_pred < O
    return CO_diff[growth].sum() - CO_diff[decline].sum()
def roi(C, C_pred, opn):
    mean_opn = opn.reindex_like(C).mean()
    return gain(C, C_pred, opn) / mean_opn

## Preparazione dei dati
Come spiegato in precedenza, visto che alcune feature non sono disponibili all'inizio della giornata, per poterle utilizzare nel nostro modello, utilizzeremo i dati dei giori passati, servendoci della funzione shift.

Visto che l'utilizzo di questa funzione fara' si che in alcune riche vi saranno dei vaolri NaN, dobbiamo assicurarci di eliminare queste osservazioni sia nelle serie relative ale features che in quella della variabile da predire.
Questo compito e' svolto dalla funzione prepare_data.

In [9]:
def prepare_data(features, target):
    X = pd.DataFrame(features)
    X.dropna(inplace=True)
    Y = target.reindex_like(X)
    return X, Y

Con questa funzione dividiamo i dati in training set e validation set come è stato fatto nel paper

In [10]:
def s_split_before_2010_06_30(X):
    is_train = X.index.date < datetime.date(2010,6,30)
    X_train = X.loc[is_train]
    X_val = X.loc[~is_train]
    return X_train, X_val

In [11]:
def split_before_2010_06_30(X, y):
    is_train = X.index.date < datetime.date(2010,6,30)
    X_train = X.loc[is_train]
    y_train = y.loc[is_train]
    X_val = X.loc[~is_train]
    y_val = y.loc[~is_train]
    return X_train, X_val, y_train, y_val

In [12]:
def days_group(data, n_days=10):
    res = np.zeros([data.shape[0]-n_days, n_days, data.shape[1]], dtype=np.float32)
    for i, el in enumerate(data):
        if(i >= n_days):
            res[i-n_days] = data[i-n_days:i]
    return res

# Preparazione del dataset

In [None]:
def get_ext_feats(ohlcv):
    res = ohlcv.copy()
    
    opn = res["Open"]
    close = res["Close"]
    high = res["High"]
    low = res["Low"]
    volume = res["Volume"]
    
    #calucate derived indicators
    TP = ((high + low + close) / 3 ).shift(1)
    trs = pd.DataFrame(index=high.index)
    trs['tr0'] = abs(high - close)
    trs['tr1'] = abs(high - close.shift(1))
    trs['tr2'] = abs(low - close.shift(1))
    TR = trs[['tr0', 'tr1', 'tr2']].max(axis=1).shift(1)
    ema20 = opn.ewm(span=20).mean()
    ma10 = opn.rolling(window=10).mean()
    ma5 = opn.rolling(window=5).mean()
    macd = opn.ewm(span=26).mean() - opn.ewm(span=12).mean()
    cci_ndays=20
    cci = (TP - TP.rolling(cci_ndays).mean()) / (0.015 * TP.rolling(cci_ndays).std())
    atr = TR.ewm(span = 10).mean()
    ma20 = opn.rolling(window=20).mean()
    std20 = opn.rolling(window=20).std()
    k=2
    boll_up =  ma20 + (k*std20)
    boll_down = ma20 - (k*std20)
    roc = (opn - opn.shift(9))/opn.shift(9)
    mtm6 = (opn - opn.shift(127))
    mtm12 = (opn - opn.shift(253)) #length of a trading year is on average 253 days
    wvad = (((close - low) - (high - close)) * volume/(high - low)).shift(1)
    smi = (close - (high - low)/2).shift(1)
    
    res["CloseL1"] = close.shift(1)
    res["HighL1"] = high.shift(1)
    res["LowL1"] = low.shift(1)
    res["VolumeL1"] = volume.shift(1)
    res["EMA20"] = ema20
    res["MA5"] = ma5
    res["MA10"] = ma10
    res["MA20"] = ma20
    res["MACD"] = macd
    res["CCI"] = cci
    res["ATR"] = atr
    res["BollUp"] = boll_up
    res["BollDown"] = boll_down
    res["WVAD"] = wvad
    res["MTM6"] = mtm6
    res["MTM12"] = mtm12
    res["SMI"] = smi
    res["ROC"] = roc
    
    return res


In [14]:
# takes as input ohlcv dataframe
def get_data_set(ohlcv, ext_feats=True, usd_index='DX-Y.NYB', wavelet=True):
    feats = ohlcv.copy()
    usd_open = get_index(usd_index, start_date=ohlcv["Open"].index.min(), end_date=ohlcv["Open"].index.max())["Open"]
    
    if(wavelet):
    #apply transforms
        for f_name in ('Open', 'Close', 'High', 'Low'):
            feats[f_name] = apply_wavelet_transform(feats[f_name]) 
        usd_open = apply_wavelet_transform(usd_open)
    
    if(ext_feats):
        feats = get_ext_feats(feats)
        feats["USDOpen"]  = usd_open
    return feats

In [15]:
def get_dataset_by_name(ohlcv, name):
    if(name is "open"):
        return get_data_set(ohlcv, ext_feats=False)[["Open", "Close"]]
    elif(name is "ohlcv"):
        return get_data_set(ohlcv, ext_feats=False)
    elif(name is "ext"):
        return get_data_set(ohlcv)
    raise Exception('Nome del feature-set non valido')

## Regolarizzazione

### Wavelet

In [16]:
def apply_wavelet_transform(data, consider_future=False, wavelet='haar'):
    res = data.copy()
    if(consider_future):
        res, _ = pywt.dwt(data.copy(), wavelet=wavelet)
    else:
        for i in range(res.shape[0]):
            if(i > 0):
                cA =  waveletSmooth(data.iloc[:i+1].copy(), wavelet=wavelet, level=4, DecLvl=3)
                res.iloc[i] = cA[-1]
    return res

In [17]:
# def apply_wavelet_transform(data, consider_future=False, wavelet='haar'):
#     res = data.copy()
#     if(consider_future):
#         res, _ = pywt.dwt(data.copy(), wavelet=wavelet)
#     else:
#         for i in range(res.shape[0]):
#             if(i > 0):
#                 cA, cD = pywt.dwt(data.iloc[:i+1].copy(), wavelet=wavelet)
#                 res.iloc[i] = cA[-1]
#     return res

In [18]:
# opn =get_index()["Open"]
# haar = apply_wavelet_transform(opn, consider_future=False, wavelet='haar')
# coif3 = apply_wavelet_transform(opn, consider_future=False, wavelet='coif3')
# fig, ax = plt.subplots(3,1, sharex=False, figsize=(20,15))
# ax[0].plot(opn, label='open')
# ax[1].plot(haar, label='haar wavelet')
# ax[2].plot(coif3, label ='coif3 wavelet')
# ax[0].legend()
# ax[1].legend()
# ax[2].legend()

## Normalizzazione 
Le seguenti funzioni sono utilizzate per normalizzare i dati, pianifichiamo di utilizzare queste funzioni anche sulla variabile da predire dovremo implementarne anche l'inversa, di modo da poter denormalizzare le predizioni fatte dal modello.
Questo passo sarà importante nella valutazione del modello, ad esempio per calcolare il ROI.

## Scaler

## Stacked Autoencoder

In [19]:
def get_encoder(X, val=None, sa_hidden_size=10):
    X_train_f = X.astype(np.float32)
    if(val is not None):
        X_val_f = val.astype(np.float32)
    #Initialize the autoencoder
    sa_hidden_size= np.ceil(X.shape[1] / 2).astype(int) # Con tutte le features 10

    num_hidden_1 = sa_hidden_size
    num_hidden_2 = sa_hidden_size
    num_hidden_3 = sa_hidden_size
    num_hidden_4 = sa_hidden_size

    n_epoch1=15000 #10000
    n_epoch2 = 2000
    n_epoch3 = 600
    n_epoch4 = 500
    batch_size=20

    # ---- train using training data

    # The n==0 statement is done because we only want to initialize the network once and then keep training
    # as we move through time 

    auto1 = Autoencoder(X_train_f.shape[1], num_hidden_1)
    auto2 = Autoencoder(num_hidden_1, num_hidden_2)
    auto3 = Autoencoder(num_hidden_2, num_hidden_3)
    auto4 = Autoencoder(num_hidden_3, num_hidden_4)
    
    # Train the autoencoder 
    # switch to training mode
    auto1.train()      
    auto2.train()
    auto3.train()
    auto4.train()

    inputs = torch.from_numpy(X_train_f)
    val_in = torch.from_numpy(X_val_f)
    auto1.fit(X_train_f, X_val_f, n_epoch=n_epoch1, batch_size=batch_size)

    auto1_out = auto1.encoder(inputs).data.numpy()
    val1_out = auto1.encoder(val_in).data.numpy()
    auto2.fit(auto1_out, val1_out, n_epoch=n_epoch2, batch_size=batch_size)


    auto1_out = torch.from_numpy(auto1_out.astype(np.float32))
    auto2_out = auto2.encoder(auto1_out).data.numpy()
    val1_out = torch.from_numpy(val1_out.astype(np.float32))
    val2_out = auto2.encoder(val1_out).data.numpy()
    auto3.fit(auto2_out, val2_out, n_epoch=n_epoch3, batch_size=batch_size)


    auto2_out = torch.from_numpy(auto2_out.astype(np.float32))
    auto3_out = auto3.encoder(auto2_out).data.numpy()
    val2_out = torch.from_numpy(val2_out.astype(np.float32))
    val3_out = auto3.encoder(val2_out).data.numpy()
    auto4.fit(auto3_out, val3_out, n_epoch=n_epoch4, batch_size=batch_size)

    # Change to evaluation mode, in this mode the network behaves differently, e.g. dropout is switched off and so on
    auto1.eval()        
    auto2.eval()
    auto3.eval()
    auto4.eval()
    return [auto1, auto2, auto3, auto4]

In [20]:
def encode(feat_matrix, encoder):
    encoded = torch.from_numpy(feat_matrix)
    for auto in encoder:
        encoded = auto.encoder(encoded)
    return encoded.data.numpy()
    

In [22]:
# ext_train = ext_scaler.transform(s_split_before_2010_06_30(ext_dataset)[0])
# ext_eval = ext_scaler.transform(s_split_before_2010_06_30(ext_dataset)[1])
# ohlcv_train = ohlcv_scaler.transform(s_split_before_2010_06_30(ohlcv_dataset)[0])

In [23]:
# ext_encoder = get_encoder(ext_train, ext_eval, sa_hidden_size=10)
# save(ext_encoder, 'ext_encoder_better')

In [24]:
# #test the sa
# with torch.no_grad():
#     res = torch.from_numpy(ext_eval.astype(np.float32))
#     for encoder in ext_encoder:
#         res, loss = ext_encoder[0].train()(res)
#         print(abs(res.numpy() - ext_eval.astype(np.float32)).mean())
#         print(loss)



## Creazione dei dataset

In [25]:
# #parte utile per visualizzare i diversi indici da selezionare
# start_date = "2000-01-01"
# end_date = "2018-12-31"
# security = yfinance.Ticker('ASHR')# TODO trova mercato asiatico e indiano
# security_data = security.history(start=start_date, end=end_date, actions=False)
# # security_data = pd.DataFrame(security_data.values, index=security_data.index[::-1], columns=security_data.columns) # inv option
# opn.plot()

Consideriamo i seguenti indici:
-mercati sviluppati: S&P500(^GSPC), Dow Jones Industrial Average(^DJI)
-mercati nel mezzo: Hang Seng index in Hong Kong(^HSI), Nikkei 225 index in Tokyo(^N225)
-mercati in via di sviluppo: CSI 300 in mainland China (ASHR), Nifty 50 in India(^NSEI)

In [26]:
# indices = ['^GSPC', '^DJI', '^HSI', '^N225', 'ASHR','^NSEI']

In [27]:
indices = ['^GSPC']

In [28]:
# feature_sets = ['open', 'ohlcv', 'ext','sa_ohlcv', 'sa_ext']

In [29]:
feature_sets = ['open', 'ohlcv', 'ext']

In [30]:
start_date="2000-01-01"
end_date="2018-12-31"

In [31]:
# i dataset sono organizzati nel seguente modo
# -indice
#  -dati originali : dataframe pandas
#  -features: dizionario di featureset
start_date="2000-01-01"
end_date="2018-12-31"
start_data = None
end_data = None
datasets = {}
for index in indices:
    datasets[index] = {}
    datasets[index]["original"] = get_index(index=index, start_date=start_date, end_date=end_date)
    datasets[index]["features"] = {}
    datasets[index]["target"] = None
    for feature_set in feature_sets:
        data = get_dataset_by_name(datasets[index]["original"], name=feature_set)
        data.dropna(inplace=True)
        if(start_data is None):
            start_data = data.index.min()
        else:
            start_data = max(data.index.min(), start_data)
        if(end_data is None):
            end_data = data.index.max()
        else:
            end_data = min(data.index.max(), end_data)
        datasets[index]["features"][feature_set] = data.drop("Close", axis=1)
        if(datasets[index]["target"] is None):
            datasets[index]["target"] = data["Close"]
            
for index in datasets.keys():      
    datasets[index]["target"] = datasets[index]["target"].loc[start_data:end_data].copy()
    for feature_set in datasets[index]["features"].keys():  
        datasets[index]["features"][feature_set] = datasets[index]["features"][feature_set].loc[start_data:end_data].copy()

  "boundary effects.").format(level))
  thresholded = (1 - value/magnitude)


## Training dei modelli

In [32]:
# #dataset è della forma (X,Y)
# def set_dataset(data_set):
#     global X
#     global Y
#     global X_train
#     global Y_train
#     global X_val
#     global Y_val
#     global X_train_f
#     global Y_train_f
#     global X_val_f
#     global Y_val_f
#     global X_f
#     global Y_f
#     global half
#     global half_split
#     X= data_set[0]
#     Y= data_set[1]
#     X_train, X_val, Y_train, Y_val = split_before_2010_06_30(X, Y)
#     X_f = X.to_numpy().astype(np.float32)
#     Y_f = Y.to_numpy().astype(np.float32)[...,None]
#     X_train_f = X_train.to_numpy().astype(np.float32)
#     Y_train_f = Y_train.to_numpy().astype(np.float32)[...,None]
#     X_val_f = X_val.to_numpy().astype(np.float32)
#     Y_val_f = Y_val.to_numpy().astype(np.float32)[...,None] 
#     l1 = len(np.split(X_f, [len(X_f)//2])[0])
#     l2 = len(np.split(X_f, [len(X_f)//2])[1])
#     half = PredefinedSplit(np.concatenate((np.ones(l1)*-1,np.ones(l2))))
#     half_split =  CVSplit(cv=half, stratified=False, random_state=None)

In [33]:
tss = TimeSeriesSplit(3)#max_train_size
tss_split = CVSplit(cv=tss, stratified=False, random_state=None)

In [34]:
# models = {}

In [35]:
# modello del paper di Moro
#TODO prova
n_days = 5

batch_size = 20


lstm_moro = NeuralNetRegressor(
    module=SequenceDouble,
    optimizer=optim.SGD,
    batch_size=batch_size,
    max_epochs=200, # usato nel paper
    train_split=None,
    
    module__nb_features=X_f.shape[1],
    module__hidden_size=256,
    optimizer__lr=0.01,
    optimizer__weight_decay=0,
    optimizer__momentum=0.9
)
models['lstm_moro'] = (lstm_moro, 'open')

In [None]:
#modello del paper che usa attention mechanism
n_days = 5

batch_size = 20

lstm_att_ohlcv = NeuralNetRegressor(
    module=SequenceDoubleAtt,
    optimizer=optim.Adam,
    batch_size=batch_size,
    max_epochs=250, # trovato empiricamente
    train_split=None,
    callbacks=[
        callbacks.EpochScoring('neg_mean_absolute_error', lower_is_better=False),
        callbacks.EpochScoring('r2', lower_is_better=False),
        callbacks.Checkpoint(monitor='valid_loss_best', f_pickle='lstm_sa_best')        
    ],
    
    module__nb_features=ohlcv_train.shape[1],
    module__hidden_size=256,
    optimizer__lr=0.0001,
)

models['lstm_att_ohlcv'] = (lstm_att_ohlcv, 'ohlcv')

In [None]:
#modello del paper che usa stacked autoencoders usando LSTM paper Moro
n_days = 5

batch_size = 20

lstm_sa = NeuralNetRegressor(
    module=SequenceDouble,
    optimizer=optim.Adam,
    batch_size=batch_size,
    max_epochs=350, # trovato empiricamente
    train_split=tss_split,
    callbacks=[
        callbacks.EpochScoring('neg_mean_absolute_error', lower_is_better=False),
        callbacks.Checkpoint(monitor='valid_loss_best', f_pickle='lstm_sa_best'),
        
    ],
    
    module__nb_features=ext_sa_train.shape[1],
    module__hidden_size=256,
    optimizer__lr=0.0001,
#     optimizer__weight_decay=0,
#     optimizer__momentum=0.9,
    iterator_train__shuffle = True,
)

models['lstm_sa'] = (lstm_sa, 'ext_sa')

In [None]:
#modello del paper che usa attention mechanism
n_days = 5

batch_size = 20

lstm_sa_att = NeuralNetRegressor(
    module=SequenceDoubleAtt,
    optimizer=optim.Adam,
    batch_size=batch_size,
    max_epochs=5000, # trovato 280
    train_split=half_split,
    callbacks=[
        callbacks.EpochScoring('neg_mean_absolute_error', lower_is_better=False),
        callbacks.EpochScoring('r2', lower_is_better=False),
        callbacks.Checkpoint(monitor='valid_loss_best', f_pickle='lstm_sa_best')        
    ],
    
    module__nb_features=ext_sa_train.shape[1],
    module__hidden_size=256,
#     module__nb_layers= 5,
    optimizer__lr=0.001,
#     optimizer__weight_decay=0,
#     optimizer__momentum=0.9
)

models['lstm_sa_att'] = (lstm_sa_att, 'ext_sa')

In [37]:
# models = {'lstm_moro':(lstm_moro, 'open'), 'lstm_sa':(lstm_sa_d_1000,'sa_ext'), 'lstm_att_ohlcv':(lstm_att_d_600,'ohlcv'), 'lstm_att_ext':(lstm_att_d_600,'ext'), 'lstm_sa_att,':(lstm_sa_att_1500, 'sa_ext')}

In [38]:
# TODO assicurarsi che tutti i feature set abbiano la stessa luinghezza

In [188]:
def gain_score(C, C_pred, opn):
    O = opn
    CO_diff = C - O
    growth = C_pred > O
    decline = C_pred < O
    return CO_diff[growth].sum() - CO_diff[decline].sum()
def roi_score(C, C_pred, opn):
    mean_opn = opn.mean()
    return gain(C, C_pred, opn) / mean_opn

## test dei modelli

In [40]:
scores = pd.DataFrame()#TODO define structure salva risultati ad ogni iterazione in csv
for index_data in indexes_data:
    for i, (train, val) in enumerate(tss.split(index_data["target"]), start=1):
        
        feature_sets = {}
        
#         train_dates = index_data["original"].index[train]
#         val_dates = index_data["original"].index[val]
        
        y = index_data["target"]
        opn_dataset = index_data["features"]['open']
        ohlcv_dataset = index_data["features"]['ohlcv']
        ext_dataset = index_data["features"]['ext']

        #data split
        
        opn_train, opn_val = opn_dataset.iloc[train_dates].to_numpy(np.float32), opn_dataset.iloc[val_dates].to_numpy(np.float32)
        ohlcv_train, ohlcv_val = ohlcv_dataset.iloc[train_dates].to_numpy(np.float32), ohlcv_dataset.iloc[val_dates].to_numpy(np.float32)
        ext_train, ext_val = ext_dataset.iloc[train_dates].to_numpy(np.float32), ext_dataset.iloc[val_dates].to_numpy(np.float32)
        y_train, y_val = y.iloc[train_dates].to_numpy(np.float32)[...,None], y.iloc[val_dates].to_numpy(np.float32)[...,None]
        
        #data scale
        opn_scaler = StandardScaler()
        ohlcv_scaler = StandardScaler()
        ext_scaler = StandardScaler()
        y_scaler = StandardScaler()

        opn_train = ohlcv_scaler.fit_transform(opn_train)
        ohlcv_train = ohlcv_scaler.fit_transform(ohlcv_train)
        ext_train = ext_scaler.fit_transform(ext_train)
        y_train = y_scaler.fit_transform(y_train)
        opn_val = ohlcv_scaler.transform(opn_val)
        ohlcv_val = ohlcv_scaler.ransform(ohlcv_val)
        ext_val = ext_scaler.transform(ext_val)
        y_val = y_scaler.transform(y_val)
        
        ext_encoder = get_encoder(ext_train, ext_val, sa_hidden_size=10)
        ext_sa_train, ext_sa_val = encode(ext_train, ext_encoder), encode(ext_val, ext_encoder)
        
        feature_sets['open'] = (opn_train, opn_val, opn_scaler)
        feature_sets['ohlcv'] = (ohlcv_train, ohlcv_val, ohlcv_scaler)
        feature_sets['ext'] = (ext_train, ext_val, ext_scaler)
        feature_sets['ext_sa'] = (ext_sa_train, ext_sa_val, ext_scaler)
        
        n_days = 5
        opn_test = index_data["original"]["Open"].reindex_like(opn_dataset.iloc[train_dates]).to_numpy(np.float32)
        close_test =  index_data["original"]["Close"].reindex_like(opn_dataset.iloc[train_dates]).to_numpy(np.float32)
        #training
        for key, info in models:
            
            model = info[0]
            x_train, x_val, scaler = feature_sets[info[1]]
            
            xd_train = days_group(x_train, n_days=n_days)
            yd_train = y_train[n_days:]
            xd_val = days_group(x_val, n_days=n_days)
            yd_val = y_val[n_days:]
                        
            model.fit(xd_train, yd_train)
            
            pred = model.predict(xd_val)
            pred_unsc = scaler.inverse_transform(pred)
#             y_unsc = y_scaler.inverse_transform(y[i])
            

            mape = (abs(close_test - pred_unsc)/close_test).mean()
            mspe = (((close_test - pred_unsc)/close_test)**2).mean()
            acc = 1 - mape
            roi = roi_score(close_test, pred_unsc, opn_test)
        
#         print("FOLD {}".format(i))
#         train_dates = X.index[train]
#         val_dates = X.index[val]
#         print("Training set da {} a {}".format(train_dates.min(), train_dates.max()))
#         print("Validation set da {} a {}".format(val_dates.min(), val_dates.max()))

        
        
    

## Metriche

In [58]:
# ## Calcola la loss solo sull'utlimo elemento del batch
# class RNNMSELoss(MSELoss):
#     def __call__(self, input, target):
#         return super().__call__(input, target[-1])

## Modelli

In [214]:
train_dates, val_dates = list(tss.split(datasets['^GSPC']["target"]))[0]

In [215]:
y = datasets['^GSPC']["target"]
ext_dataset = datasets['^GSPC']["features"]['open']
ext_scaler = StandardScaler()
y_scaler = StandardScaler()
# sa = load('ext_encoder_better')
ext_train, ext_val = ext_scaler.fit_transform(ext_dataset.iloc[train_dates].to_numpy(np.float32)), ext_scaler.transform(ext_dataset.iloc[val_dates].to_numpy(np.float32))
y_train, y_val = y_scaler.fit_transform(y.iloc[train_dates].to_numpy(np.float32)[...,None]), y_scaler.transform(y.iloc[val_dates].to_numpy(np.float32)[...,None])

In [216]:
ext = np.concatenate((ext_train, ext_val))
y_data = np.concatenate((y_train, y_val))

In [219]:
x = days_group(ext, n_days=n_days)
y = y_data[n_days:]
l1 = len(np.split(x, [(len(ext)*2)//3])[0])
l2 = len(np.split(x, [(len(ext)*2)//3])[1])
half = PredefinedSplit(np.concatenate((np.ones(l1)*-1,np.ones(l2))))
half_split =  CVSplit(cv=half, stratified=False, random_state=None)

In [233]:
# modello del paper di Moro
#TODO prova
n_days = 5

batch_size = 20


lstm_moro = NeuralNetRegressor(
    module=SequenceDouble,
    optimizer=optim.Adam,
    batch_size=batch_size,
    max_epochs = 400 #trovato empiricamente
    train_split=half_split,
    
    module__nb_features=ext_train.shape[1],
    module__hidden_size=256,
    optimizer__lr=0.0001,
)

In [None]:
lstm_moro.fit(x,y)

  epoch    train_loss    valid_loss     dur
-------  ------------  ------------  ------
      1        [36m1.1838[0m        [32m2.4005[0m  3.1708
      2        [36m0.3314[0m        [32m0.0768[0m  1.4612
      3        [36m0.0299[0m        [32m0.0757[0m  2.3017
      4        [36m0.0284[0m        [32m0.0748[0m  2.0754
      5        0.0288        0.0775  1.9196
      6        0.0307        0.0792  2.0209
      7        0.0325        0.0813  2.6854
      8        0.0345        0.0830  2.6317
      9        0.0362        0.0846  2.3061
     10        0.0377        0.0864  2.5968
     11        0.0390        0.0886  2.9478
     12        0.0401        0.0911  1.9306
     13        0.0410        0.0936  1.5429
     14        0.0416        0.0956  1.5579
     15        0.0418        0.0968  1.5035
     16        0.0416        0.0968  1.5557
     17        0.0410        0.0956  1.9927
     18        0.0402        0.0936  2.3326
     19        0.0392        0.0911  2.2288
    

    143        [36m0.0139[0m        [32m0.0308[0m  1.4514
    144        [36m0.0138[0m        [32m0.0305[0m  1.4418
    145        [36m0.0136[0m        [32m0.0303[0m  1.4827
    146        [36m0.0135[0m        [32m0.0301[0m  1.7545
    147        [36m0.0134[0m        [32m0.0299[0m  1.5382
    148        [36m0.0133[0m        [32m0.0297[0m  1.9074
    149        [36m0.0131[0m        [32m0.0296[0m  2.0913
    150        [36m0.0130[0m        [32m0.0294[0m  1.7170
    151        [36m0.0129[0m        [32m0.0294[0m  1.7078
    152        [36m0.0128[0m        [32m0.0292[0m  1.6399
    153        [36m0.0127[0m        [32m0.0291[0m  2.0171
    154        [36m0.0126[0m        [32m0.0290[0m  1.8385
    155        [36m0.0125[0m        [32m0.0290[0m  1.9224
    156        [36m0.0124[0m        [32m0.0289[0m  5.6137
    157        [36m0.0124[0m        [32m0.0288[0m  1.8560
    158        [36m0.0123[0m        [32m0.0287[0m  1.6915
    159 

    290        [36m0.0101[0m        [32m0.0235[0m  1.6418
    291        [36m0.0101[0m        [32m0.0235[0m  1.4736
    292        [36m0.0101[0m        [32m0.0234[0m  1.7432
    293        [36m0.0101[0m        [32m0.0234[0m  1.6126
    294        [36m0.0101[0m        [32m0.0234[0m  4.0074
    295        [36m0.0101[0m        [32m0.0234[0m  2.1259
    296        [36m0.0101[0m        [32m0.0233[0m  1.9156
    297        [36m0.0101[0m        0.0233  1.8668
    298        0.0101        [32m0.0233[0m  2.2593
    299        [36m0.0101[0m        0.0233  2.4057
    300        0.0101        [32m0.0232[0m  2.3185
    301        [36m0.0101[0m        0.0233  2.3027
    302        0.0101        [32m0.0232[0m  2.4997
    303        [36m0.0101[0m        [32m0.0232[0m  2.3009
    304        [36m0.0100[0m        [32m0.0232[0m  1.7721
    305        [36m0.0100[0m        0.0232  1.7000
    306        [36m0.0100[0m        [32m0.0232[0m  1.6654
    307  

    442        [36m0.0092[0m        0.0224  2.6723
    443        0.0092        0.0224  2.4992
    444        [36m0.0092[0m        0.0225  2.4378
    445        0.0092        0.0224  1.6334
    446        [36m0.0092[0m        0.0225  1.5213
    447        0.0092        0.0224  1.8514
    448        [36m0.0092[0m        0.0225  1.7614
    449        [36m0.0092[0m        0.0224  2.2144
    450        [36m0.0092[0m        0.0224  2.3741
    451        [36m0.0092[0m        0.0224  2.2914
    452        0.0092        0.0224  2.3662
    453        [36m0.0092[0m        0.0225  3.7863
    454        0.0092        0.0224  1.5784
    455        [36m0.0092[0m        0.0224  1.4472
    456        0.0092        0.0224  1.4661
    457        [36m0.0092[0m        0.0225  2.0896
    458        [36m0.0092[0m        0.0225  1.7442
    459        [36m0.0092[0m        0.0225  1.8802
    460        0.0092        0.0225  1.8426
    461        0.0092        0.0225  2.5476
    462     

    615        [36m0.0087[0m        0.0233  2.0286
    616        [36m0.0087[0m        0.0233  2.3000
    617        0.0087        0.0233  2.1088
    618        [36m0.0087[0m        0.0233  2.3810
    619        0.0087        0.0234  1.7208
    620        0.0087        0.0234  4.2354
    621        0.0087        0.0234  1.8476
    622        0.0087        0.0235  1.8640
    623        0.0087        0.0235  1.8067
    624        0.0087        0.0235  1.8680
    625        0.0087        0.0236  1.9456
    626        0.0087        0.0236  2.7603
    627        0.0087        0.0236  2.3349
    628        0.0087        0.0237  1.8580
    629        0.0087        0.0236  7.3926
    630        0.0087        0.0238  2.3233
    631        0.0087        0.0236  2.0887
    632        0.0087        0.0239  2.2520
    633        0.0087        0.0234  1.8905
    634        0.0087        0.0242  2.2323
    635        0.0087        0.0232  2.8707
    636        0.0087        0.0246  1.8375
    6

    796        [36m0.0081[0m        0.0265  2.3348
    797        0.0081        0.0270  1.7634
    798        0.0081        0.0274  2.0629
    799        0.0082        0.0278  2.1161
    800        0.0082        0.0282  2.0298
    801        0.0082        0.0286  1.9832
    802        0.0082        0.0290  2.2902
    803        0.0082        0.0294  2.5324
    804        0.0083        0.0298  2.3451
    805        0.0083        0.0302  2.2453
    806        0.0083        0.0306  5.4505
    807        0.0083        0.0310  2.1674
    808        0.0083        0.0314  2.0414
    809        0.0083        0.0318  1.6147
    810        0.0083        0.0321  2.4435
    811        0.0084        0.0324  2.4704
    812        0.0084        0.0327  2.5981
    813        0.0084        0.0329  2.3949
    814        0.0084        0.0332  2.4269
    815        0.0084        0.0334  2.5525
    816        0.0084        0.0336  2.0336
    817        0.0084        0.0337  1.8095
    818        0.0084  

    981        0.0082        0.0410  2.5763
    982        0.0082        0.0411  7.7158
    983        0.0082        0.0412  4.3456
    984        0.0082        0.0413  2.3057
    985        0.0082        0.0414  2.3442
    986        0.0082        0.0415  1.9492
    987        0.0082        0.0416  2.0061
    988        0.0082        0.0417  2.0272
    989        0.0082        0.0418  2.3404
    990        0.0082        0.0418  1.9195
    991        0.0082        0.0420  2.1156
    992        0.0082        0.0422  1.9028
    993        0.0082        0.0420  1.7617
    994        0.0082        0.0424  2.4214
    995        0.0082        0.0426  1.7712
    996        0.0082        0.0421  1.9404
    997        0.0082        0.0429  2.0737
    998        0.0082        0.0420  1.8130
    999        0.0082        0.0432  4.2314
   1000        0.0082        0.0413  2.7723
   1001        0.0083        0.0420  3.0544
   1002        0.0082        0.0399  2.1673
   1003        0.0083        0.0

   1167        0.0082        0.0504  2.2813
   1168        0.0081        0.0392  3.1949
   1169        0.0082        0.0514  2.6350
   1170        0.0081        0.0428  2.2321
   1171        0.0081        0.0599  2.7139
   1172        0.0081        0.0463  2.6834
   1173        0.0081        0.0571  2.0226
   1174        0.0082        0.0476  6.9536
   1175        0.0081        0.0396  2.7987
   1176        0.0084        0.0445  2.4170
   1177        0.0082        0.0529  2.3813
   1178        0.0082        0.0518  2.6702
   1179        0.0080        0.0522  2.2924
   1180        0.0080        0.0545  1.9571
   1181        [36m0.0080[0m        0.0549  2.5828
   1182        [36m0.0080[0m        0.0533  7.3160
   1183        0.0080        0.0570  2.2965
   1184        0.0080        0.0508  2.4403
   1185        0.0080        0.0587  3.1398
   1186        0.0080        0.0499  2.5023
   1187        0.0080        0.0604  1.7772
   1188        0.0080        0.0534  1.8507
   1189       

   1361        0.0079        0.0489  1.7073
   1362        0.0079        0.0390  1.7649
   1363        0.0078        0.0515  2.0948
   1364        0.0079        0.0376  2.1378
   1365        0.0078        0.0533  1.8433
   1366        0.0079        0.0404  1.8692
   1367        0.0078        0.0616  1.7104
   1368        0.0079        0.0423  2.7982
   1369        0.0078        0.0664  2.5322
   1370        0.0080        0.0387  2.3394
   1371        0.0079        0.0503  2.6618
   1372        0.0080        0.0473  2.1228
   1373        0.0079        0.0374  2.1592
   1374        0.0081        0.0408  1.8787
   1375        0.0078        0.0515  1.8495
   1376        [36m0.0078[0m        0.0520  2.6013
   1377        [36m0.0078[0m        0.0510  2.3098
   1378        [36m0.0078[0m        0.0541  5.0843
   1379        0.0078        0.0503  2.6582
   1380        0.0078        0.0562  2.4178
   1381        0.0078        0.0504  2.3086
   1382        0.0078        0.0582  2.3536
   13

   1545        0.0077        0.0385  3.1792
   1546        0.0077        0.0467  1.7942
   1547        0.0078        0.0351  2.4292
   1548        0.0076        0.0469  2.6635
   1549        0.0078        0.0345  2.2355
   1550        0.0076        0.0450  2.7392
   1551        0.0077        0.0341  2.4683
   1552        0.0077        0.0445  2.4065
   1553        0.0078        0.0359  2.1748
   1554        0.0076        0.0467  2.1860
   1555        0.0078        0.0335  2.7235
   1556        0.0076        0.0462  2.4805
   1557        0.0077        0.0337  2.3704
   1558        0.0076        0.0447  2.4971
   1559        0.0077        0.0351  2.2714
   1560        0.0076        0.0461  2.1327
   1561        0.0078        0.0357  1.8681
   1562        0.0076        0.0490  1.9791
   1563        0.0077        0.0350  2.3571
   1564        0.0076        0.0502  2.8428
   1565        0.0077        0.0356  2.2331
   1566        0.0076        0.0506  2.1220
   1567        0.0077        0.0

In [60]:
y = datasets['^GSPC']["target"]
ext_dataset = datasets['^GSPC']["features"]['ext']
ext_scaler = StandardScaler()
y_scaler = StandardScaler()
# sa = load('ext_encoder_better')
ext_train, ext_val = ext_scaler.fit_transform(ext_dataset.iloc[train_dates].to_numpy(np.float32)), ext_scaler.transform(ext_dataset.iloc[val_dates].to_numpy(np.float32))
y_train, y_val = y_scaler.fit_transform(y.iloc[train_dates].to_numpy(np.float32)[...,None]), y_scaler.transform(y.iloc[val_dates].to_numpy(np.float32)[...,None])

In [62]:
# ext_encoder = get_encoder(ext_train, ext_val, sa_hidden_size=10)
# save(ext_encoder, 'ext_encoder_better')

In [63]:
# ext_encoder = load('ext_encoder_better')
# ext_sa_train, ext_sa_val = encode(ext_train, ext_encoder), encode(ext_val, ext_encoder)
# ext_sa = np.concatenate((ext_sa_train, ext_sa_val))
# y_data = np.concatenate((y_train, y_val))

In [65]:
# #modello del paper che usa stacked autoencoders usando LSTM paper Moro
# n_days = 5

# batch_size = 20
# #TODO test e trova numero epoche

# lstm_sa = NeuralNetRegressor(
#     module=SequenceDouble,
#     optimizer=optim.Adam,
#     batch_size=batch_size,
#     max_epochs=350, # trovato empiricamente
#     train_split=tss_split,
#     callbacks=[
#         callbacks.EpochScoring('neg_mean_absolute_error', lower_is_better=False),
#         callbacks.Checkpoint(monitor='valid_loss_best', f_pickle='lstm_sa_best'),
        
#     ],
    
#     module__nb_features=ext_sa_train.shape[1],
#     module__hidden_size=256,
#     optimizer__lr=0.0001,
# #     optimizer__weight_decay=0,
# #     optimizer__momentum=0.9,
#     iterator_train__shuffle = True,
# )

In [66]:
x = days_group(ext_sa, n_days=n_days)
y = y_data[n_days:]
l1 = len(np.split(x, [len(ext_sa_train)])[0])
l2 = len(np.split(x, [len(ext_sa_train)])[1])
half = PredefinedSplit(np.concatenate((np.ones(l1)*-1,np.ones(l2))))
half_split =  CVSplit(cv=half, stratified=False, random_state=None)

In [67]:
# lstm_sa.fit(x, y)

In [68]:
# lstm = lstm_sa
# valid_losses = lstm.history[:, 'valid_loss']
# train_losses = lstm.history[:, 'train_loss']
# plt.figure(figsize=(12,7))
# plt.plot(valid_losses, label='valid_loss')
# plt.plot(train_losses, label='train_loss')
# plt.xticks(np.arange(len(valid_losses)+1, step=50))
# plt.legend()

In [69]:
# save(lstm_sa, 'lstm_sa_5000')

In [70]:
sklearn.metrics.SCORERS.keys()

dict_keys(['explained_variance', 'r2', 'max_error', 'neg_median_absolute_error', 'neg_mean_absolute_error', 'neg_mean_squared_error', 'neg_mean_squared_log_error', 'neg_root_mean_squared_error', 'neg_mean_poisson_deviance', 'neg_mean_gamma_deviance', 'accuracy', 'roc_auc', 'roc_auc_ovr', 'roc_auc_ovo', 'roc_auc_ovr_weighted', 'roc_auc_ovo_weighted', 'balanced_accuracy', 'average_precision', 'neg_log_loss', 'neg_brier_score', 'adjusted_rand_score', 'homogeneity_score', 'completeness_score', 'v_measure_score', 'mutual_info_score', 'adjusted_mutual_info_score', 'normalized_mutual_info_score', 'fowlkes_mallows_score', 'precision', 'precision_macro', 'precision_micro', 'precision_samples', 'precision_weighted', 'recall', 'recall_macro', 'recall_micro', 'recall_samples', 'recall_weighted', 'f1', 'f1_macro', 'f1_micro', 'f1_samples', 'f1_weighted', 'jaccard', 'jaccard_macro', 'jaccard_micro', 'jaccard_samples', 'jaccard_weighted'])

In [193]:
y = datasets['^GSPC']["target"]
ohlcv_dataset = datasets['^GSPC']["features"]['ohlcv']
ohlcv_scaler = StandardScaler()
y_scaler = StandardScaler()
# sa = load('ext_encoder_better')
ohlcv_train, ohlcv_val = ohlcv_scaler.fit_transform(ohlcv_dataset.iloc[train_dates].to_numpy(np.float32)), ohlcv_scaler.transform(ohlcv_dataset.iloc[val_dates].to_numpy(np.float32))
y_train, y_val = y_scaler.fit_transform(y.iloc[train_dates].to_numpy(np.float32)[...,None]), y_scaler.transform(y.iloc[val_dates].to_numpy(np.float32)[...,None])

In [194]:
ohlcv = np.concatenate((ohlcv_train, ohlcv_val))
y_data = np.concatenate((y_train, y_val))

In [195]:
x = days_group(ohlcv, n_days=n_days)
y = y_data[n_days:]
l1 = len(np.split(x, [(len(ohlcv)*2)//3])[0])
l2 = len(np.split(x, [(len(ohlcv)*2)//3])[1])
half = PredefinedSplit(np.concatenate((np.ones(l1)*-1,np.ones(l2))))
half_split =  CVSplit(cv=half, stratified=False, random_state=None)

In [199]:
#modello del paper che usa attention mechanism
#TODO testa con nuova lr
#Adam lr=0.0001, 0.05 loss, 80% accuracy, amtcha il paper
#SGD lr = 0.0001, 
n_days = 5

batch_size = 20

lstm_att = NeuralNetRegressor(
    module=SequenceDoubleAtt,
    optimizer=optim.SGD,
    batch_size=batch_size,
    max_epochs=250, # trovato empiricamente
    train_split=half_split,
    callbacks=[
        callbacks.EpochScoring('neg_mean_absolute_error', lower_is_better=False),
        callbacks.EpochScoring('r2', lower_is_better=False),
        callbacks.Checkpoint(monitor='valid_loss_best', f_pickle='lstm_sa_best')        
    ],
    
    module__nb_features=ohlcv_train.shape[1],
    module__hidden_size=256,
#     module__nb_layers= 5,
    optimizer__lr=0.0001, #TODO prova con questa lr
#     optimizer__weight_decay=0,
#     optimizer__momentum=0.9
)

In [200]:
# lstm_att.fit(x, y)

  epoch    neg_mean_absolute_error      r2    train_loss    valid_loss    cp     dur
-------  -------------------------  ------  ------------  ------------  ----  ------
      1                    [36m-1.8136[0m  [32m0.0412[0m        [35m1.3316[0m        [31m4.0114[0m     +  2.8761




      2                    [36m-1.3799[0m  [32m0.3827[0m        [35m0.5683[0m        [31m2.5828[0m     +  2.0920




      3                    [36m-1.2893[0m  [32m0.4352[0m        [35m0.3533[0m        [31m2.3629[0m     +  4.5807




      4                    [36m-1.0899[0m  [32m0.6256[0m        [35m0.2617[0m        [31m1.5663[0m     +  3.3596




      5                    [36m-0.8004[0m  [32m0.7901[0m        [35m0.1742[0m        [31m0.8782[0m     +  2.3338




      6                    [36m-0.6082[0m  [32m0.8679[0m        [35m0.1202[0m        [31m0.5528[0m     +  3.7345




      7                    [36m-0.5157[0m  [32m0.9007[0m        [35m0.0920[0m        [31m0.4157[0m     +  1.9537




      8                    [36m-0.4948[0m  [32m0.9126[0m        [35m0.0788[0m        [31m0.3655[0m     +  2.1824




      9                    -0.5000  [32m0.9162[0m        [35m0.0726[0m        [31m0.3505[0m     +  2.2169




     10                    -0.5046  0.9150        [35m0.0676[0m        0.3555        2.1831
     11                    -0.5065  0.9115        [35m0.0623[0m        0.3702        2.1370
     12                    -0.5133  0.9073        [35m0.0564[0m        0.3877        1.8779
     13                    -0.5194  0.9045        [35m0.0511[0m        0.3994        2.1530
     14                    -0.5187  0.9045        [35m0.0464[0m        0.3994        1.8991
     15                    -0.5104  0.9078        [35m0.0424[0m        0.3857        1.8742
     16                    [36m-0.4942[0m  0.9140        [35m0.0391[0m        0.3600        1.4796
     17                    [36m-0.4715[0m  [32m0.9220[0m        [35m0.0364[0m        [31m0.3263[0m     +  1.5446




     18                    [36m-0.4450[0m  [32m0.9309[0m        [35m0.0342[0m        [31m0.2892[0m     +  3.5576




     19                    [36m-0.4171[0m  [32m0.9396[0m        [35m0.0326[0m        [31m0.2526[0m     +  3.0160




     20                    [36m-0.3892[0m  [32m0.9475[0m        [35m0.0313[0m        [31m0.2196[0m     +  2.7937




     21                    [36m-0.3623[0m  [32m0.9542[0m        [35m0.0303[0m        [31m0.1918[0m     +  1.5216




     22                    [36m-0.3388[0m  [32m0.9595[0m        [35m0.0296[0m        [31m0.1696[0m     +  2.7589




     23                    [36m-0.3196[0m  [32m0.9635[0m        [35m0.0292[0m        [31m0.1529[0m     +  2.6685




     24                    [36m-0.3047[0m  [32m0.9663[0m        [35m0.0288[0m        [31m0.1410[0m     +  2.8253




     25                    [36m-0.2929[0m  [32m0.9682[0m        [35m0.0286[0m        [31m0.1329[0m     +  2.7304




     26                    [36m-0.2841[0m  [32m0.9695[0m        [35m0.0284[0m        [31m0.1278[0m     +  2.6045




     27                    [36m-0.2780[0m  [32m0.9702[0m        [35m0.0283[0m        [31m0.1248[0m     +  2.8354




     28                    [36m-0.2738[0m  [32m0.9705[0m        [35m0.0281[0m        [31m0.1233[0m     +  2.4551




     29                    [36m-0.2714[0m  [32m0.9707[0m        [35m0.0280[0m        [31m0.1228[0m     +  3.0589




     30                    [36m-0.2698[0m  0.9706        [35m0.0278[0m        0.1228        4.7593
     31                    [36m-0.2689[0m  0.9705        [35m0.0275[0m        0.1233        2.2959
     32                    [36m-0.2684[0m  0.9704        [35m0.0272[0m        0.1240        2.3408
     33                    [36m-0.2683[0m  0.9701        [35m0.0269[0m        0.1250        2.7325
     34                    -0.2684  0.9699        [35m0.0266[0m        0.1261        2.2905
     35                    -0.2685  0.9696        [35m0.0262[0m        0.1273        2.7350
     36                    -0.2687  0.9693        [35m0.0259[0m        0.1284        2.0394
     37                    -0.2689  0.9690        [35m0.0256[0m        0.1296        2.3952
     38                    -0.2691  0.9688        [35m0.0252[0m        0.1307        2.4025
     39                    -0.2692  0.9686        [35m0.0249[0m        0.1316        2.4258
     40                 



     53                    [36m-0.2492[0m  [32m0.9715[0m        [35m0.0220[0m        [31m0.1191[0m     +  2.7870




     54                    [36m-0.2465[0m  [32m0.9721[0m        [35m0.0219[0m        [31m0.1168[0m     +  2.5675




     55                    [36m-0.2437[0m  [32m0.9727[0m        [35m0.0217[0m        [31m0.1144[0m     +  2.3699




     56                    [36m-0.2409[0m  [32m0.9732[0m        [35m0.0216[0m        [31m0.1119[0m     +  2.5760




     57                    [36m-0.2381[0m  [32m0.9738[0m        [35m0.0215[0m        [31m0.1095[0m     +  2.4523




     58                    [36m-0.2352[0m  [32m0.9744[0m        [35m0.0213[0m        [31m0.1070[0m     +  3.8906




     59                    [36m-0.2323[0m  [32m0.9750[0m        [35m0.0212[0m        [31m0.1045[0m     +  2.3960




     60                    [36m-0.2294[0m  [32m0.9756[0m        [35m0.0211[0m        [31m0.1020[0m     +  2.3148




     61                    [36m-0.2265[0m  [32m0.9762[0m        [35m0.0210[0m        [31m0.0996[0m     +  3.6068




     62                    [36m-0.2237[0m  [32m0.9768[0m        [35m0.0209[0m        [31m0.0972[0m     +  4.0844




     63                    [36m-0.2211[0m  [32m0.9773[0m        [35m0.0208[0m        [31m0.0949[0m     +  1.8286




     64                    [36m-0.2185[0m  [32m0.9779[0m        [35m0.0208[0m        [31m0.0926[0m     +  2.5674




     65                    [36m-0.2160[0m  [32m0.9784[0m        [35m0.0207[0m        [31m0.0905[0m     +  2.0422




     66                    [36m-0.2136[0m  [32m0.9789[0m        [35m0.0206[0m        [31m0.0884[0m     +  2.0443




     67                    [36m-0.2112[0m  [32m0.9794[0m        [35m0.0205[0m        [31m0.0864[0m     +  2.1887




     68                    [36m-0.2091[0m  [32m0.9798[0m        [35m0.0204[0m        [31m0.0845[0m     +  2.2804




     69                    [36m-0.2070[0m  [32m0.9802[0m        [35m0.0204[0m        [31m0.0827[0m     +  1.8109




     70                    [36m-0.2051[0m  [32m0.9806[0m        [35m0.0203[0m        [31m0.0810[0m     +  2.8029




     71                    [36m-0.2034[0m  [32m0.9810[0m        [35m0.0202[0m        [31m0.0795[0m     +  2.9775




     72                    [36m-0.2019[0m  [32m0.9813[0m        [35m0.0202[0m        [31m0.0781[0m     +  2.9720




     73                    [36m-0.2006[0m  [32m0.9816[0m        [35m0.0201[0m        [31m0.0768[0m     +  2.6089




     74                    [36m-0.1995[0m  [32m0.9819[0m        [35m0.0200[0m        [31m0.0756[0m     +  2.0376




     75                    [36m-0.1987[0m  [32m0.9822[0m        [35m0.0200[0m        [31m0.0746[0m     +  1.9593




     76                    [36m-0.1982[0m  [32m0.9824[0m        [35m0.0199[0m        [31m0.0737[0m     +  2.2858




     77                    [36m-0.1978[0m  [32m0.9826[0m        [35m0.0199[0m        [31m0.0730[0m     +  2.8177




     78                    [36m-0.1976[0m  [32m0.9827[0m        [35m0.0198[0m        [31m0.0723[0m     +  2.6754




     79                    -0.1977  [32m0.9828[0m        [35m0.0198[0m        [31m0.0719[0m     +  2.0431




     80                    -0.1982  [32m0.9829[0m        [35m0.0197[0m        [31m0.0715[0m     +  4.8397




     81                    -0.1988  [32m0.9830[0m        [35m0.0197[0m        [31m0.0713[0m     +  2.9353




     82                    -0.1997  [32m0.9830[0m        [35m0.0196[0m        [31m0.0712[0m     +  2.2220




     83                    -0.2007  0.9830        [35m0.0196[0m        0.0713        1.5332
     84                    -0.2019  0.9829        [35m0.0195[0m        0.0714        1.4971
     85                    -0.2033  0.9829        [35m0.0195[0m        0.0717        1.5380
     86                    -0.2048  0.9828        [35m0.0195[0m        0.0720        1.5252
     87                    -0.2063  0.9827        [35m0.0194[0m        0.0725        1.8962
     88                    -0.2079  0.9826        [35m0.0194[0m        0.0730        2.1390
     89                    -0.2095  0.9824        [35m0.0194[0m        0.0736        1.6523
     90                    -0.2111  0.9823        [35m0.0193[0m        0.0742        1.5016
     91                    -0.2128  0.9821        [35m0.0193[0m        0.0749        2.0295
     92                    -0.2145  0.9819        [35m0.0193[0m        0.0756        2.1326
     93                    -0.2161  0.9817        [35m0.019



    165                    -0.2294  0.9804        0.0161        0.0820        2.2108
    166                    -0.2271  0.9805        0.0156        0.0815        4.7106
    167                    -0.2298  0.9802        0.0145        0.0827        2.0196
    168                    -0.2279  0.9808        0.0142        0.0804        2.1398
    169                    -0.2254  0.9814        0.0143        0.0780        1.8515
    170                    -0.2246  0.9816        0.0145        0.0770        2.4233
    171                    -0.2234  0.9819        0.0146        0.0759        1.4915
    172                    -0.2226  0.9820        0.0147        0.0751        1.5031
    173                    -0.2198  0.9824        0.0147        0.0736        1.7996
    174                    -0.2194  0.9827        0.0148        0.0726        1.4817
    175                    -0.2028  0.9844        0.0147        0.0653        1.4795
    176                    -0.2169  0.9823        [35m0.0140[0m



    184                    -0.2124  0.9839        0.0134        0.0674        2.8790
    185                    [36m-0.1970[0m  [32m0.9855[0m        0.0130        [31m0.0607[0m     +  3.1058




    186                    -0.2077  0.9842        [35m0.0122[0m        0.0661        2.3996
    187                    -0.2157  0.9835        [35m0.0118[0m        0.0692        2.8043
    188                    -0.2122  0.9840        [35m0.0116[0m        0.0667        2.3323
    189                    -0.2040  0.9851        0.0117        0.0621        2.5148
    190                    -0.2017  0.9854        0.0119        0.0612        2.4084
    191                    -0.2008  0.9855        0.0121        0.0607        2.7661
    192                    -0.2023  0.9853        0.0123        0.0615        2.1119
    193                    -0.2005  [32m0.9855[0m        0.0122        [31m0.0606[0m     +  1.5033




    194                    -0.2039  0.9852        0.0120        0.0618        1.9804
    195                    [36m-0.1927[0m  [32m0.9865[0m        0.0119        [31m0.0564[0m     +  1.5449




    196                    -0.2098  0.9843        [35m0.0115[0m        0.0655        1.5781
    197                    -0.2101  0.9845        [35m0.0110[0m        0.0648        1.9747
    198                    -0.1980  0.9860        [35m0.0107[0m        0.0586        1.5188
    199                    -0.2004  0.9857        0.0109        0.0599        1.4962
    200                    -0.1992  0.9859        0.0108        0.0591        1.5061
    201                    -0.2036  0.9854        0.0108        0.0611        1.5510
    202                    -0.1968  0.9862        [35m0.0107[0m        0.0578        1.5967
    203                    -0.2064  0.9851        [35m0.0106[0m        0.0625        2.9500
    204                    -0.1990  0.9859        [35m0.0104[0m        0.0589        2.7889
    205                    -0.2043  0.9853        [35m0.0102[0m        0.0613        1.9555
    206                    -0.2077  0.9850        0.0103        0.0628        1.7464
  



    227                    -0.2101  0.9846        0.0096        0.0646        1.5138
    228                    -0.2056  0.9854        0.0095        0.0611        1.4431
    229                    -0.2007  0.9859        0.0088        0.0591        1.5150
    230                    -0.1981  0.9861        0.0088        0.0581        1.5799
    231                    -0.1956  0.9863        0.0089        0.0571        1.4608
    232                    [36m-0.1899[0m  [32m0.9869[0m        0.0091        [31m0.0546[0m     +  1.7954




    233                    [36m-0.1797[0m  [32m0.9880[0m        0.0092        [31m0.0502[0m     +  2.0562




    234                    -0.2175  0.9832        0.0106        0.0702        2.2707
    235                    -0.1931  0.9866        0.0100        0.0562        3.5976
    236                    [36m-0.1730[0m  [32m0.9886[0m        0.0093        [31m0.0476[0m     +  2.4477




    237                    -0.1956  0.9857        0.0121        0.0597        2.2144
    238                    -0.2694  0.9731        0.0102        0.1126        1.7888
    239                    -0.2531  0.9777        [35m0.0075[0m        0.0934        1.7553
    240                    -0.2186  0.9833        [35m0.0074[0m        0.0699        2.3606
    241                    -0.2073  0.9848        0.0075        0.0635        2.7710
    242                    -0.1931  0.9865        0.0076        0.0564        2.2952
    243                    -0.1965  0.9861        0.0087        0.0582        2.6888
    244                    -0.2280  0.9819        0.0089        0.0755        1.8682
    245                    -0.2272  0.9823        0.0075        0.0741        1.7731
    246                    -0.2109  0.9845        [35m0.0072[0m        0.0650        1.8359
    247                    -0.2018  0.9855        0.0075        0.0608        1.5694
    248                    -0.2039  0.

    332                    -0.2544  0.9760        0.0065        0.1004        2.7716
    333                    -0.2470  0.9769        0.0066        0.0967        1.9965
    334                    -0.2629  0.9749        0.0068        0.1052        3.3481
    335                    -0.2387  0.9781        0.0069        0.0918        2.0092
    336                    -0.2730  0.9719        0.0075        0.1174        1.8359
    337                    -0.2718  0.9723        0.0065        0.1160        2.2220
    338                    -0.2773  0.9717        0.0061        0.1183        2.3157
    339                    -0.2650  0.9743        0.0060        0.1076        1.9866
    340                    -0.2674  0.9747        [35m0.0059[0m        0.1060        1.9309
    341                    -0.2487  0.9778        0.0059        0.0927        2.4410
    342                    -0.2549  0.9769        0.0061        0.0964        5.7992
    343                    -0.2458  0.9781        0.0063

<class 'skorch.regressor.NeuralNetRegressor'>[initialized](
  module_=SequenceDoubleAtt(
    (lstm1): LSTM(4, 256, batch_first=True)
    (lstm2): LSTM(256, 512, batch_first=True)
    (lin): Linear(in_features=512, out_features=1, bias=True)
    (lin_out): Linear(in_features=512, out_features=1, bias=True)
    (softmax): Softmax(dim=1)
    (tanh): Tanh()
  ),
)

In [137]:
y = datasets['^GSPC']["target"]
ext_dataset = datasets['^GSPC']["features"]['ext']
ext_scaler = StandardScaler()
y_scaler = StandardScaler()
# sa = load('ext_encoder_better')
ext_train, ext_val = ext_scaler.fit_transform(ext_dataset.iloc[train_dates].to_numpy(np.float32)), ext_scaler.transform(ext_dataset.iloc[val_dates].to_numpy(np.float32))
y_train, y_val = y_scaler.fit_transform(y.iloc[train_dates].to_numpy(np.float32)[...,None]), y_scaler.transform(y.iloc[val_dates].to_numpy(np.float32)[...,None])

In [138]:
ext_encoder = load('ext_encoder_better')
ext_sa_train, ext_sa_val = encode(ext_train, ext_encoder), encode(ext_val, ext_encoder)
ext_sa = np.concatenate((ext_sa_train, ext_sa_val))
y_data = np.concatenate((y_train, y_val))

In [175]:
x = days_group(ext_sa, n_days=n_days)
y = y_data[n_days:]
l1 = len(np.split(x, [(len(ext_sa)*3)//4])[0])
l2 = len(np.split(x, [(len(ext_sa)*3)//4])[1])
half = PredefinedSplit(np.concatenate((np.ones(l1)*-1,np.ones(l2))))
half_split =  CVSplit(cv=half, stratified=False, random_state=None)

In [189]:
#modello del paper che usa attention mechanism
#TODO testa se funziona e trova epoche
n_days = 5

batch_size = 20

lstm_sa_att = NeuralNetRegressor(
    module=SequenceDoubleAtt,
    optimizer=optim.Adam,
    batch_size=batch_size,
    max_epochs=5000, # trovato 280
    train_split=half_split,
    callbacks=[
        callbacks.EpochScoring('neg_mean_absolute_error', lower_is_better=False),
        callbacks.EpochScoring('r2', lower_is_better=False),
        callbacks.Checkpoint(monitor='valid_loss_best', f_pickle='lstm_sa_best')        
    ],
    
    module__nb_features=ext_sa_train.shape[1],
    module__hidden_size=256,
#     module__nb_layers= 5,
    optimizer__lr=0.0001,
#     optimizer__weight_decay=0,
#     optimizer__momentum=0.9
)

In [190]:
lstm_sa_att.fit(x, y)

  epoch    neg_mean_absolute_error       r2    train_loss    valid_loss    cp     dur
-------  -------------------------  -------  ------------  ------------  ----  ------
      1                    [36m-1.6688[0m  [32m-0.0260[0m        [35m2.4180[0m        [31m3.6369[0m     +  1.7605




      2                    [36m-1.6681[0m  [32m-0.0136[0m        [35m2.3650[0m        [31m3.5932[0m     +  4.3878




      3                    -1.6691  [32m-0.0085[0m        2.3658        [31m3.5751[0m     +  1.9298




      4                    -1.6844  [32m0.0009[0m        [35m2.2221[0m        [31m3.5419[0m     +  1.9328




      5                    -1.7041  -0.0046        [35m2.1991[0m        3.5611        6.2987
      6                    -1.6997  -0.0020        2.3462        3.5522        8.8208
      7                    -1.7113  -0.0103        [35m2.1271[0m        3.5815        2.9361
      8                    -1.7168  -0.0160        [35m2.1063[0m        3.6018        3.0138
      9                    -1.7200  -0.0198        [35m2.0855[0m        3.6152        3.0639
     10                    -1.7220  -0.0225        [35m2.0695[0m        3.6246        3.1023
     11                    -1.7235  -0.0244        [35m2.0577[0m        3.6316        3.0253
     12                    -1.7245  -0.0259        [35m2.0485[0m        3.6369        2.0844
     13                    -1.7253  -0.0271        [35m2.0413[0m        3.6412        3.2779
     14                    -1.7260  -0.0281        [35m2.0353[0m        3.6446        1.6490
     15                    -1.7265  -0.0289        [35m2.0

    111                    -1.9362  -0.5254        [35m0.4420[0m        5.4075        2.6091
    112                    -1.8911  -0.4448        [35m0.4093[0m        5.1217        2.1236
    113                    -1.8576  -0.3850        [35m0.3767[0m        4.9096        3.0635
    114                    -1.8328  -0.3410        [35m0.3473[0m        4.7539        3.9407
    115                    -1.8158  -0.3111        [35m0.3224[0m        4.6479        4.3636
    116                    -1.8071  -0.2947        [35m0.3023[0m        4.5897        3.0634
    117                    -1.8041  -0.2872        [35m0.2865[0m        4.5630        5.8181
    118                    -1.8030  -0.2834        [35m0.2738[0m        4.5497        2.1681
    119                    -1.8024  -0.2813        [35m0.2627[0m        4.5421        1.9686
    120                    -1.8016  -0.2798        [35m0.2520[0m        4.5367        2.1598
    121                    -1.8009  -0.2788       



    145                    [36m-1.5254[0m  [32m0.0213[0m        [35m0.1287[0m        [31m3.4695[0m     +  2.9307




    146                    [36m-1.5083[0m  [32m0.0388[0m        [35m0.1270[0m        [31m3.4074[0m     +  2.6738




    147                    [36m-1.4907[0m  [32m0.0568[0m        [35m0.1254[0m        [31m3.3436[0m     +  3.5827




    148                    [36m-1.4725[0m  [32m0.0753[0m        [35m0.1239[0m        [31m3.2782[0m     +  3.3243




    149                    [36m-1.4539[0m  [32m0.0942[0m        [35m0.1224[0m        [31m3.2110[0m     +  3.2159




    150                    [36m-1.4348[0m  [32m0.1136[0m        [35m0.1210[0m        [31m3.1421[0m     +  8.8647




    151                    [36m-1.4153[0m  [32m0.1335[0m        [35m0.1196[0m        [31m3.0716[0m     +  4.0132




    152                    [36m-1.3953[0m  [32m0.1538[0m        [35m0.1183[0m        [31m2.9998[0m     +  2.5906




    153                    [36m-1.3753[0m  [32m0.1744[0m        [35m0.1170[0m        [31m2.9267[0m     +  5.6567




    154                    [36m-1.3551[0m  [32m0.1953[0m        [35m0.1157[0m        [31m2.8528[0m     +  2.3637




    155                    [36m-1.3346[0m  [32m0.2163[0m        [35m0.1144[0m        [31m2.7782[0m     +  4.2274




    156                    [36m-1.3140[0m  [32m0.2374[0m        [35m0.1132[0m        [31m2.7035[0m     +  5.5291




    157                    [36m-1.2935[0m  [32m0.2584[0m        [35m0.1120[0m        [31m2.6289[0m     +  2.3942




    158                    [36m-1.2729[0m  [32m0.2793[0m        [35m0.1108[0m        [31m2.5549[0m     +  2.1221




    159                    [36m-1.2527[0m  [32m0.2998[0m        [35m0.1097[0m        [31m2.4820[0m     +  1.5744




    160                    [36m-1.2331[0m  [32m0.3200[0m        [35m0.1086[0m        [31m2.4106[0m     +  3.0172




    161                    [36m-1.2139[0m  [32m0.3396[0m        [35m0.1075[0m        [31m2.3412[0m     +  2.3278




    162                    [36m-1.1953[0m  [32m0.3585[0m        [35m0.1065[0m        [31m2.2741[0m     +  7.3936




    163                    [36m-1.1774[0m  [32m0.3767[0m        [35m0.1055[0m        [31m2.2096[0m     +  3.5170




    164                    [36m-1.1602[0m  [32m0.3941[0m        [35m0.1045[0m        [31m2.1479[0m     +  5.8734




    165                    [36m-1.1435[0m  [32m0.4107[0m        [35m0.1036[0m        [31m2.0891[0m     +  3.6509




    166                    [36m-1.1275[0m  [32m0.4265[0m        [35m0.1027[0m        [31m2.0332[0m     +  2.4599




    167                    [36m-1.1121[0m  [32m0.4415[0m        [35m0.1019[0m        [31m1.9800[0m     +  4.2995




    168                    [36m-1.0975[0m  [32m0.4557[0m        [35m0.1011[0m        [31m1.9294[0m     +  2.1871




    169                    [36m-1.0837[0m  [32m0.4693[0m        [35m0.1003[0m        [31m1.8813[0m     +  4.7267




    170                    [36m-1.0704[0m  [32m0.4822[0m        [35m0.0996[0m        [31m1.8355[0m     +  2.1853




    171                    [36m-1.0578[0m  [32m0.4946[0m        [35m0.0989[0m        [31m1.7918[0m     +  2.2063




    172                    [36m-1.0456[0m  [32m0.5064[0m        [35m0.0983[0m        [31m1.7499[0m     +  1.9819




    173                    [36m-1.0341[0m  [32m0.5177[0m        [35m0.0977[0m        [31m1.7098[0m     +  2.9338




    174                    [36m-1.0233[0m  [32m0.5286[0m        [35m0.0971[0m        [31m1.6712[0m     +  2.2288




    175                    [36m-1.0128[0m  [32m0.5391[0m        [35m0.0966[0m        [31m1.6340[0m     +  2.6458




    176                    [36m-1.0026[0m  [32m0.5492[0m        [35m0.0961[0m        [31m1.5979[0m     +  2.5601




    177                    [36m-0.9928[0m  [32m0.5591[0m        [35m0.0956[0m        [31m1.5630[0m     +  2.8733




    178                    [36m-0.9835[0m  [32m0.5686[0m        [35m0.0951[0m        [31m1.5291[0m     +  2.6862




    179                    [36m-0.9746[0m  [32m0.5779[0m        [35m0.0946[0m        [31m1.4962[0m     +  3.1009




    180                    [36m-0.9658[0m  [32m0.5870[0m        [35m0.0942[0m        [31m1.4640[0m     +  3.2044




    181                    [36m-0.9571[0m  [32m0.5958[0m        [35m0.0937[0m        [31m1.4328[0m     +  4.0747




    182                    [36m-0.9483[0m  [32m0.6045[0m        [35m0.0932[0m        [31m1.4022[0m     +  2.2707




    183                    [36m-0.9399[0m  [32m0.6127[0m        [35m0.0928[0m        [31m1.3728[0m     +  3.1062




    184                    [36m-0.9312[0m  [32m0.6212[0m        [35m0.0923[0m        [31m1.3428[0m     +  3.1798




    185                    [36m-0.9232[0m  [32m0.6288[0m        [35m0.0918[0m        [31m1.3160[0m     +  2.4848




    186                    [36m-0.9141[0m  [32m0.6373[0m        [35m0.0912[0m        [31m1.2858[0m     +  2.7197




    187                    [36m-0.9062[0m  [32m0.6446[0m        [35m0.0909[0m        [31m1.2600[0m     +  2.6737




    188                    [36m-0.8975[0m  [32m0.6523[0m        [35m0.0903[0m        [31m1.2325[0m     +  2.6982




    189                    [36m-0.8892[0m  [32m0.6597[0m        [35m0.0898[0m        [31m1.2065[0m     +  2.4512




    190                    [36m-0.8807[0m  [32m0.6670[0m        [35m0.0892[0m        [31m1.1804[0m     +  2.7914




    191                    [36m-0.8723[0m  [32m0.6742[0m        [35m0.0887[0m        [31m1.1549[0m     +  2.7679




    192                    [36m-0.8640[0m  [32m0.6813[0m        [35m0.0881[0m        [31m1.1297[0m     +  2.2747




    193                    [36m-0.8557[0m  [32m0.6883[0m        [35m0.0876[0m        [31m1.1049[0m     +  2.7187




    194                    [36m-0.8473[0m  [32m0.6952[0m        [35m0.0870[0m        [31m1.0805[0m     +  2.0060




    195                    [36m-0.8389[0m  [32m0.7020[0m        [35m0.0864[0m        [31m1.0564[0m     +  3.1572




    196                    [36m-0.8304[0m  [32m0.7087[0m        [35m0.0858[0m        [31m1.0328[0m     +  2.4148




    197                    [36m-0.8219[0m  [32m0.7152[0m        [35m0.0852[0m        [31m1.0095[0m     +  3.6135




    198                    [36m-0.8135[0m  [32m0.7217[0m        [35m0.0846[0m        [31m0.9867[0m     +  8.5043




    199                    [36m-0.8051[0m  [32m0.7280[0m        [35m0.0840[0m        [31m0.9644[0m     +  2.8227




    200                    [36m-0.7969[0m  [32m0.7341[0m        [35m0.0834[0m        [31m0.9425[0m     +  2.5481




    201                    [36m-0.7887[0m  [32m0.7402[0m        [35m0.0827[0m        [31m0.9210[0m     +  2.8621




    202                    [36m-0.7805[0m  [32m0.7461[0m        [35m0.0821[0m        [31m0.9002[0m     +  2.7375




    203                    [36m-0.7724[0m  [32m0.7518[0m        [35m0.0814[0m        [31m0.8798[0m     +  2.4144




    204                    [36m-0.7643[0m  [32m0.7574[0m        [35m0.0807[0m        [31m0.8600[0m     +  2.8296




    205                    [36m-0.7562[0m  [32m0.7628[0m        [35m0.0800[0m        [31m0.8408[0m     +  3.0915




    206                    [36m-0.7483[0m  [32m0.7681[0m        [35m0.0793[0m        [31m0.8222[0m     +  2.2342




    207                    [36m-0.7405[0m  [32m0.7731[0m        [35m0.0786[0m        [31m0.8042[0m     +  3.0819




    208                    [36m-0.7328[0m  [32m0.7780[0m        [35m0.0779[0m        [31m0.7869[0m     +  3.7126




    209                    [36m-0.7253[0m  [32m0.7827[0m        [35m0.0772[0m        [31m0.7702[0m     +  2.3991




    210                    [36m-0.7180[0m  [32m0.7872[0m        [35m0.0764[0m        [31m0.7542[0m     +  4.6501




    211                    [36m-0.7110[0m  [32m0.7916[0m        [35m0.0757[0m        [31m0.7389[0m     +  4.1303




    212                    [36m-0.7041[0m  [32m0.7957[0m        [35m0.0749[0m        [31m0.7243[0m     +  3.0447




    213                    [36m-0.6977[0m  [32m0.7996[0m        [35m0.0742[0m        [31m0.7103[0m     +  4.0447




    214                    [36m-0.6915[0m  [32m0.8034[0m        [35m0.0734[0m        [31m0.6970[0m     +  2.6332




    215                    [36m-0.6855[0m  [32m0.8069[0m        [35m0.0726[0m        [31m0.6845[0m     +  2.8879




    216                    [36m-0.6798[0m  [32m0.8103[0m        [35m0.0718[0m        [31m0.6725[0m     +  3.6780




    217                    [36m-0.6743[0m  [32m0.8135[0m        [35m0.0710[0m        [31m0.6612[0m     +  2.0819




    218                    [36m-0.6689[0m  [32m0.8165[0m        [35m0.0702[0m        [31m0.6504[0m     +  2.5970




    219                    [36m-0.6639[0m  [32m0.8193[0m        [35m0.0694[0m        [31m0.6405[0m     +  2.9635




    220                    [36m-0.6590[0m  [32m0.8221[0m        [35m0.0686[0m        [31m0.6308[0m     +  2.7122




    221                    [36m-0.6544[0m  [32m0.8245[0m        [35m0.0678[0m        [31m0.6221[0m     +  2.2285




    222                    [36m-0.6499[0m  [32m0.8270[0m        [35m0.0669[0m        [31m0.6134[0m     +  2.5493




    223                    [36m-0.6458[0m  [32m0.8291[0m        [35m0.0662[0m        [31m0.6059[0m     +  2.3156




    224                    [36m-0.6416[0m  [32m0.8313[0m        [35m0.0653[0m        [31m0.5981[0m     +  4.8136




    225                    [36m-0.6379[0m  [32m0.8331[0m        [35m0.0645[0m        [31m0.5915[0m     +  2.9760




    226                    [36m-0.6340[0m  [32m0.8350[0m        [35m0.0637[0m        [31m0.5848[0m     +  3.9020




    227                    [36m-0.6306[0m  [32m0.8367[0m        [35m0.0629[0m        [31m0.5789[0m     +  2.5149




    228                    [36m-0.6272[0m  [32m0.8383[0m        [35m0.0620[0m        [31m0.5731[0m     +  3.1034




    229                    [36m-0.6241[0m  [32m0.8398[0m        [35m0.0612[0m        [31m0.5680[0m     +  4.6184




    230                    [36m-0.6210[0m  [32m0.8412[0m        [35m0.0604[0m        [31m0.5630[0m     +  2.2606




    231                    [36m-0.6182[0m  [32m0.8424[0m        [35m0.0596[0m        [31m0.5585[0m     +  1.7561




    232                    [36m-0.6155[0m  [32m0.8436[0m        [35m0.0588[0m        [31m0.5543[0m     +  2.8050




    233                    [36m-0.6130[0m  [32m0.8447[0m        [35m0.0580[0m        [31m0.5504[0m     +  2.7346




    234                    [36m-0.6106[0m  [32m0.8458[0m        [35m0.0572[0m        [31m0.5468[0m     +  8.7208




    235                    [36m-0.6084[0m  [32m0.8467[0m        [35m0.0564[0m        [31m0.5435[0m     +  4.6408




    236                    [36m-0.6063[0m  [32m0.8475[0m        [35m0.0556[0m        [31m0.5405[0m     +  2.5927




    237                    [36m-0.6044[0m  [32m0.8483[0m        [35m0.0548[0m        [31m0.5378[0m     +  2.1118




    238                    [36m-0.6026[0m  [32m0.8490[0m        [35m0.0541[0m        [31m0.5353[0m     +  2.6754




    239                    [36m-0.6010[0m  [32m0.8496[0m        [35m0.0533[0m        [31m0.5331[0m     +  3.0020




    240                    [36m-0.5995[0m  [32m0.8502[0m        [35m0.0526[0m        [31m0.5312[0m     +  4.1734




    241                    [36m-0.5982[0m  [32m0.8506[0m        [35m0.0519[0m        [31m0.5295[0m     +  3.0348




    242                    [36m-0.5971[0m  [32m0.8510[0m        [35m0.0511[0m        [31m0.5281[0m     +  2.6312




    243                    [36m-0.5960[0m  [32m0.8514[0m        [35m0.0504[0m        [31m0.5269[0m     +  3.0408




    244                    [36m-0.5952[0m  [32m0.8516[0m        [35m0.0497[0m        [31m0.5260[0m     +  2.0105




    245                    [36m-0.5944[0m  [32m0.8518[0m        [35m0.0491[0m        [31m0.5253[0m     +  2.8620




    246                    [36m-0.5939[0m  [32m0.8520[0m        [35m0.0484[0m        [31m0.5248[0m     +  2.3040




    247                    [36m-0.5934[0m  [32m0.8520[0m        [35m0.0478[0m        [31m0.5246[0m     +  2.6463




    248                    [36m-0.5931[0m  0.8520        [35m0.0471[0m        0.5246        2.6172
    249                    [36m-0.5930[0m  0.8519        [35m0.0465[0m        0.5249        2.6995
    250                    [36m-0.5929[0m  0.8518        [35m0.0459[0m        0.5253        2.6020
    251                    -0.5930  0.8516        [35m0.0454[0m        0.5260        2.6231
    252                    -0.5932  0.8514        [35m0.0448[0m        0.5268        2.7238
    253                    -0.5935  0.8511        [35m0.0442[0m        0.5279        2.2650
    254                    -0.5938  0.8508        [35m0.0437[0m        0.5289        2.2327
    255                    -0.5944  0.8504        [35m0.0432[0m        0.5304        2.1889
    256                    -0.5948  0.8500        [35m0.0426[0m        0.5317        2.6439
    257                    -0.5955  0.8495        [35m0.0422[0m        0.5334        2.8881
    258                    -0.596



    307                    [36m-0.5714[0m  [32m0.8604[0m        [35m0.0276[0m        [31m0.4948[0m     +  1.9035




    308                    [36m-0.5592[0m  [32m0.8656[0m        [35m0.0266[0m        [31m0.4763[0m     +  3.7527




    309                    [36m-0.5480[0m  [32m0.8701[0m        [35m0.0259[0m        [31m0.4604[0m     +  3.3079




    310                    [36m-0.5394[0m  [32m0.8734[0m        [35m0.0256[0m        [31m0.4489[0m     +  3.6421




    311                    [36m-0.5344[0m  [32m0.8750[0m        [35m0.0255[0m        [31m0.4432[0m     +  2.6552




    312                    [36m-0.5331[0m  [32m0.8750[0m        0.0256        [31m0.4431[0m     +  2.1433




    313                    -0.5348  0.8737        0.0257        0.4479        2.6440
    314                    -0.5389  0.8712        0.0258        0.4566        1.9430
    315                    -0.5449  0.8677        0.0258        0.4689        1.5743
    316                    -0.5525  0.8631        0.0256        0.4852        1.9964
    317                    -0.5621  0.8570        [35m0.0255[0m        0.5071        2.0458
    318                    -0.5760  0.8485        [35m0.0254[0m        0.5371        1.9788
    319                    -0.5969  0.8362        0.0254        0.5806        2.3991
    320                    -0.6284  0.8174        0.0258        0.6472        1.9810
    321                    -0.6792  0.7862        0.0269        0.7578        2.7907
    322                    -0.7659  0.7294        0.0296        0.9593        3.1143
    323                    -0.9211  0.6192        0.0358        1.3498        3.2305
    324                    -1.1121  0.4615     



    338                    [36m-0.5313[0m  [32m0.8798[0m        0.0436        [31m0.4261[0m     +  6.8136




    339                    [36m-0.5246[0m  [32m0.8819[0m        0.0401        [31m0.4188[0m     +  3.0112




    340                    [36m-0.5195[0m  [32m0.8833[0m        0.0373        [31m0.4138[0m     +  2.9849




    341                    [36m-0.5151[0m  [32m0.8844[0m        0.0349        [31m0.4098[0m     +  2.6770




    342                    [36m-0.5111[0m  [32m0.8853[0m        0.0328        [31m0.4064[0m     +  2.8915




    343                    [36m-0.5075[0m  [32m0.8861[0m        0.0311        [31m0.4038[0m     +  3.0894




    344                    [36m-0.5050[0m  [32m0.8866[0m        0.0297        [31m0.4021[0m     +  2.3491




    345                    [36m-0.5036[0m  [32m0.8867[0m        0.0286        [31m0.4016[0m     +  2.9328




    346                    [36m-0.5032[0m  0.8865        0.0277        0.4023        2.3028
    347                    -0.5037  0.8861        0.0269        0.4038        1.7697
    348                    -0.5048  0.8854        0.0260        0.4062        1.7396
    349                    -0.5064  0.8845        [35m0.0252[0m        0.4095        2.8681
    350                    -0.5086  0.8832        [35m0.0244[0m        0.4139        2.9859
    351                    -0.5116  0.8816        [35m0.0237[0m        0.4196        2.5455
    352                    -0.5156  0.8795        [35m0.0230[0m        0.4270        2.5289
    353                    -0.5207  0.8768        [35m0.0224[0m        0.4367        2.1662
    354                    -0.5276  0.8732        [35m0.0220[0m        0.4493        2.1928
    355                    -0.5365  0.8686        [35m0.0216[0m        0.4657        1.9032
    356                    -0.5478  0.8627        [35m0.0215[0m        0.486

    441                    -1.0029  0.5866        0.0336        1.4656        3.0728
    442                    -1.0646  0.5302        0.0503        1.6654        3.0357
    443                    -1.0338  0.5424        0.0616        1.6223        3.5262
    444                    -0.9419  0.6041        0.0461        1.4035        2.2696
    445                    -0.7789  0.7207        0.0235        0.9900        2.9210
    446                    -0.6543  0.8125        0.0169        0.6645        2.6691
    447                    -0.6028  0.8430        0.0219        0.5567        2.8409
    448                    -0.5950  0.8462        0.0260        0.5452        2.7441
    449                    -0.6051  0.8406        0.0281        0.5652        2.8435
    450                    -0.6153  0.8341        0.0285        0.5883        4.3698
    451                    -0.6175  0.8289        0.0280        0.6064        3.1674
    452                    -0.6175  0.8263        0.0249        0

    538                    -0.8925  0.6685        0.0349        1.1752        1.6820
    539                    -0.8456  0.6911        0.0437        1.0951        3.5848
    540                    -0.7657  0.7344        0.0422        0.9414        8.1838
    541                    -0.6778  0.7881        0.0289        0.7513        2.6563
    542                    -0.5952  0.8379        0.0183        0.5746        2.7858
    543                    -0.5474  0.8646        0.0166        0.4799        3.9686
    544                    -0.5395  0.8678        0.0197        0.4685        2.5278
    545                    -0.5517  0.8621        0.0228        0.4887        2.0284
    546                    -0.5653  0.8557        0.0234        0.5114        1.9260
    547                    -0.5765  0.8498        0.0230        0.5323        2.0647
    548                    -0.5819  0.8451        0.0222        0.5491        2.7173
    549                    -0.5830  0.8411        0.0211        0

    635                    -0.5586  0.8565        0.0179        0.5087        3.4610
    636                    -0.5705  0.8504        0.0168        0.5303        2.5888
    637                    -0.5819  0.8427        0.0177        0.5576        2.3067
    638                    -0.5965  0.8362        0.0178        0.5805        2.1759
    639                    -0.6074  0.8290        0.0186        0.6062        2.7534
    640                    -0.6162  0.8244        0.0184        0.6225        2.9563
    641                    -0.6205  0.8197        0.0182        0.6392        2.7366
    642                    -0.6199  0.8160        0.0172        0.6521        4.8080
    643                    -0.6488  0.8016        0.0169        0.7035        3.0938
    644                    -0.7129  0.7740        0.0170        0.8012        2.0109
    645                    -0.7059  0.7764        0.0200        0.7928        2.3253
    646                    -0.6777  0.7885        0.0230        0

<class 'skorch.regressor.NeuralNetRegressor'>[initialized](
  module_=SequenceDoubleAtt(
    (lstm1): LSTM(12, 256, batch_first=True)
    (lstm2): LSTM(256, 512, batch_first=True)
    (lin): Linear(in_features=512, out_features=1, bias=True)
    (lin_out): Linear(in_features=512, out_features=1, bias=True)
    (softmax): Softmax(dim=1)
    (tanh): Tanh()
  ),
)

## Valutazione del training

In [None]:
# lstm_test = lstm_sa_att_d

In [None]:
# lstm_test.history[10].keys()

In [None]:
# lstm = lstm_test
# valid_losses = lstm.history[:, 'valid_loss']
# train_losses = lstm.history[:, 'train_loss']
# plt.figure(figsize=(12,7))
# plt.plot(valid_losses, label='valid_loss')
# plt.plot(train_losses, label='train_loss')
# # plt.xticks(np.arange(len(valid_losses)+1, step=50))
# plt.legend()

# Test
Dopo aver trovato i parametri migliori i modelli vengono testati facendo training sulla metà dei dati e predizioni sull'altra metà

In [None]:
datasets.keys()

In [None]:
# models = ['lstm_sa_d', 'lstm_moro', 'lstm_sa_d_1000', 'lstm_att_d_600', 'lstm_sa_att_sa_1500']

In [None]:
models = [('lstm_moro', 'open'), ('lstm_sa_d_1000','sa_ohlcv'), ('lstm_att_d_600','ohlcv'), ('lstm_sa_att_1500', 'ohlcv')]

In [None]:
model = models[0]

In [None]:
lstm_test = load(model[0])

In [None]:
datasets["^GSPC"]["features"].keys()

In [None]:
market = "^DJI"
feature_set = model[1]

In [None]:
datasets[market]["original"]["Open"].plot()

In [None]:
set_dataset(datasets[market]["features"][feature_set])
opn = datasets[market]["original"]["Open"]
close = datasets[market]["original"]["Close"]

In [None]:
# i primi 51 giorni non vengonon considerati poichè servono per il calcolo dell'inversa dello z_index
Y_original = close.copy()
Y_preds = Y_original.copy()

for i in range(51, len(Y_val)): 
    pred = lstm_test.predict(X_val[i-10:i].to_numpy().astype(np.float32))[0] #in input vengono dati 10 gionri
    #denormalization
    previous_serie = Y_original[:i]
    Y_preds[i] = z_score_inv(previous_serie[-50:], pred) #il z score viene calcolato su 50 gionri
Y_preds = Y_preds[51:]
Y_original = Y_original.reindex_like(Y_preds, copy=False)

In [None]:
plt.figure(figsize=(15,7))
Y_preds.plot(label='close_pred')
Y_original.plot(label='close_true')
plt.legend()

In [None]:
Y_original.max()

In [None]:
sklearn.metrics.mean_squared_error(Y_original, Y_preds)

In [None]:
score = pd.Series(index=["MAPE","RMSPE", "R2", "ROI", "ROI Ideal", "ROI vs ideal"], dtype=np.float32)
score["MAPE"] = sklearn.metrics.mean_absolute_error(Y_original, Y_preds)/Y_preds.mean()
score["RMSPE"] = np.sqrt(sklearn.metrics.mean_squared_error(Y_original, Y_preds))/Y_original.mean()
score["R2"] = sklearn.metrics.r2_score(Y_original, Y_preds)
score["ROI"] = roi(Y_original,Y_preds, opn)
score["ROI Ideal"] = roi(Y_original,Y_original, opn)
score["ROI vs ideal"] = abs(roi(Y_original,Y_preds, opn) - roi(Y_original,Y_original, opn))

In [None]:
score

In [None]:
scores = pd.DataFrame(index=score.index)


In [None]:
show_transaction = False
initial_capital = 10000
capital = initial_capital
cap_history = Y_preds.copy()
holding=False
opns = opn.reindex_like(Y_preds).copy()
buy_history = opn.reindex_like(Y_preds).copy()
sell_history = opn.reindex_like(Y_preds).copy()
for date in Y_preds.index:
    buy_history[date] = None
    sell_history[date] = None
    pred = Y_preds[date]
    true = Y_original[date]
    buy = Y_preds[date] - opns[date] > 0
    delta_true = Y_original[date] - opns[date]
    if(buy and not holding): # buy
        if(show_transaction):
            capital -= opns[date]
        holding = True
        buy_history[date] = capital
    elif(not buy and holding): #sell
        if(show_transaction):
            capital += opns[date]
        holding = False
        sell_history[date] = capital
    if(holding):
        capital+= delta_true
    cap_history[date] = capital
if(holding and show_transaction):
    capital += opns[-1]
    holding = False
abs_gain = capital - initial_capital
perc_gain = (abs_gain) / initial_capital
print(f"guadagno assoluto: {abs_gain}")
print(f"guadagno percentuale: {perc_gain}")

In [None]:
plt.figure(figsize=(15,7))
plt.scatter(buy_history.index, buy_history, c='green', label='buy', linewidths=0.001)
plt.scatter(sell_history.index, sell_history, c= 'red', label='sell', linewidths=0.001)
cap_history.plot(label='capital')
plt.legend()