In [1]:
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt 
import seaborn as sns

from tqdm.auto import tqdm

from numpy.random import seed
import numpy.matlib
seed(42)
import tensorflow as tf
tf.random.set_seed(42)
from tensorflow import keras
from keras import backend as K

from sklearn import model_selection
from sklearn.preprocessing import StandardScaler, MinMaxScaler,LabelEncoder, StandardScaler
from sklearn.preprocessing import QuantileTransformer

def root_mean_squared_per_error(y_true, y_pred):
         return K.sqrt(K.mean(K.square( (y_true - y_pred)/ y_true)))
    
es = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=20, verbose=0,
    mode='min',restore_best_weights=True)

plateau = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.2, patience=7, verbose=0,
    mode='min')

In [2]:
pd.set_option('display.max_columns',None)

In [3]:
train = pd.read_pickle('./input/train_248.pkl')
test = pd.read_pickle('./input/test_247.pkl')

In [4]:
colNames = [col for col in list(train.columns)
            if col not in {"stock_id", "time_id", "target", "row_id"}]
len(colNames)

244

In [5]:
# kfold based on the knn++ algorithm
out_train = pd.read_csv('./input/train.csv')
out_train = out_train.pivot(index='time_id', columns='stock_id', values='target')
#out_train[out_train.isna().any(axis=1)]
out_train = out_train.fillna(out_train.mean())
out_train.head()

# code to add the just the read data after first execution

# data separation based on knn ++
nfolds = 5 # number of folds
index = []
totDist = []
values = []
# generates a matriz with the values of 
mat = out_train.values

scaler = MinMaxScaler(feature_range=(-1, 1))
mat = scaler.fit_transform(mat)

nind = int(mat.shape[0]/nfolds) # number of individuals

# adds index in the last column
mat = np.c_[mat,np.arange(mat.shape[0])]


lineNumber = np.random.choice(np.array(mat.shape[0]), size=nfolds, replace=False)

lineNumber = np.sort(lineNumber)[::-1]

for n in range(nfolds):
    totDist.append(np.zeros(mat.shape[0]-nfolds))

# saves index
for n in range(nfolds):
    values.append([lineNumber[n]])    

s=[]
for n in range(nfolds):
    s.append(mat[lineNumber[n],:])
    mat = np.delete(mat, obj=lineNumber[n], axis=0)

for n in range(nind-1):    

    luck = np.random.uniform(0,1,nfolds)
    
    for cycle in range(nfolds):
         # saves the values of index           

        s[cycle] = np.matlib.repmat(s[cycle], mat.shape[0], 1)

        sumDist = np.sum( (mat[:,:-1] - s[cycle][:,:-1])**2 , axis=1)   
        totDist[cycle] += sumDist        
                
        # probabilities
        f = totDist[cycle]/np.sum(totDist[cycle]) # normalizing the totdist
        j = 0
        kn = 0
        for val in f:
            j += val        
            if (j > luck[cycle]): # the column was selected
                break
            kn +=1
        lineNumber[cycle] = kn
        
        # delete line of the value added    
        for n_iter in range(nfolds):
            totDist[n_iter] = np.delete(totDist[n_iter],obj=lineNumber[cycle], axis=0)
            j= 0
        
        s[cycle] = mat[lineNumber[cycle],:]
        values[cycle].append(int(mat[lineNumber[cycle],-1]))
        mat = np.delete(mat, obj=lineNumber[cycle], axis=0)

for n_mod in range(nfolds):
    values[n_mod] = out_train.index[values[n_mod]]

In [6]:
out_train

stock_id,0,1,2,3,4,5,6,7,8,9,10,11,13,14,15,16,17,18,19,20,21,22,23,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,46,47,48,50,51,52,53,55,56,58,59,60,61,62,63,64,66,67,68,69,70,72,73,74,75,76,77,78,80,81,82,83,84,85,86,87,88,89,90,93,94,95,96,97,98,99,100,101,102,103,104,105,107,108,109,110,111,112,113,114,115,116,118,119,120,122,123,124,125,126
time_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1,Unnamed: 58_level_1,Unnamed: 59_level_1,Unnamed: 60_level_1,Unnamed: 61_level_1,Unnamed: 62_level_1,Unnamed: 63_level_1,Unnamed: 64_level_1,Unnamed: 65_level_1,Unnamed: 66_level_1,Unnamed: 67_level_1,Unnamed: 68_level_1,Unnamed: 69_level_1,Unnamed: 70_level_1,Unnamed: 71_level_1,Unnamed: 72_level_1,Unnamed: 73_level_1,Unnamed: 74_level_1,Unnamed: 75_level_1,Unnamed: 76_level_1,Unnamed: 77_level_1,Unnamed: 78_level_1,Unnamed: 79_level_1,Unnamed: 80_level_1,Unnamed: 81_level_1,Unnamed: 82_level_1,Unnamed: 83_level_1,Unnamed: 84_level_1,Unnamed: 85_level_1,Unnamed: 86_level_1,Unnamed: 87_level_1,Unnamed: 88_level_1,Unnamed: 89_level_1,Unnamed: 90_level_1,Unnamed: 91_level_1,Unnamed: 92_level_1,Unnamed: 93_level_1,Unnamed: 94_level_1,Unnamed: 95_level_1,Unnamed: 96_level_1,Unnamed: 97_level_1,Unnamed: 98_level_1,Unnamed: 99_level_1,Unnamed: 100_level_1,Unnamed: 101_level_1,Unnamed: 102_level_1,Unnamed: 103_level_1,Unnamed: 104_level_1,Unnamed: 105_level_1,Unnamed: 106_level_1,Unnamed: 107_level_1,Unnamed: 108_level_1,Unnamed: 109_level_1,Unnamed: 110_level_1,Unnamed: 111_level_1,Unnamed: 112_level_1
5,0.004136,0.006340,0.001848,0.005300,0.004468,0.006234,0.007651,0.003624,0.010036,0.007291,0.005707,0.005918,0.002148,0.002627,0.005091,0.005485,0.004615,0.016663,0.002883,0.002915,0.002218,0.006377,0.006338,0.004405,0.008170,0.003655,0.002963,0.003233,0.004113,0.002649,0.006633,0.002586,0.003092,0.004755,0.008078,0.008372,0.002707,0.004864,0.002184,0.004094,0.002090,0.004551,0.001848,0.002206,0.003171,0.003947,0.002345,0.004885,0.002594,0.007041,0.005736,0.003274,0.004072,0.003928,0.004460,0.008681,0.004841,0.001808,0.004759,0.002675,0.002242,0.002625,0.002841,0.005403,0.005920,0.002803,0.005253,0.003856,0.002477,0.006344,0.007048,0.005846,0.003187,0.006131,0.005706,0.003294,0.003858,0.006019,0.004970,0.003590,0.008780,0.002126,0.005341,0.004152,0.005130,0.006254,0.003832,0.001742,0.004580,0.006088,0.005514,0.005007,0.003218,0.002995,0.004273,0.003129,0.004710,0.007743,0.002893,0.005967,0.006204,0.003708,0.004267,0.007944,0.003336,0.002571,0.003035,0.004862,0.002942,0.004112,0.001919,0.008067
11,0.001445,0.002099,0.000806,0.002774,0.001852,0.002562,0.004670,0.002458,0.002291,0.002529,0.002352,0.001634,0.000640,0.001072,0.001826,0.003172,0.002474,0.003902,0.003411,0.001601,0.001622,0.002445,0.002464,0.001417,0.003665,0.001991,0.000663,0.001917,0.000956,0.000688,0.007245,0.000924,0.001146,0.001555,0.004299,0.002344,0.001121,0.004665,0.000652,0.001330,0.000472,0.001861,0.000728,0.000760,0.001998,0.001296,0.001610,0.001958,0.001335,0.003729,0.002891,0.001223,0.001489,0.001607,0.001718,0.003743,0.001663,0.001230,0.001405,0.001449,0.001769,0.001133,0.002418,0.001837,0.002326,0.002262,0.003308,0.002284,0.001251,0.003299,0.004387,0.003544,0.001672,0.001918,0.002037,0.001285,0.001493,0.002496,0.001209,0.001609,0.004821,0.001003,0.002341,0.001560,0.002393,0.002656,0.001619,0.001013,0.002152,0.002691,0.001486,0.001788,0.001224,0.000965,0.001204,0.001339,0.002119,0.002275,0.001195,0.003635,0.001859,0.002443,0.001382,0.002469,0.002030,0.000839,0.001271,0.002095,0.001518,0.001891,0.001123,0.003965
16,0.002168,0.002456,0.001581,0.002986,0.002213,0.003253,0.004303,0.002178,0.001841,0.003299,0.002363,0.003923,0.001702,0.002114,0.002539,0.002653,0.002831,0.003806,0.002137,0.001902,0.001629,0.002742,0.002465,0.002162,0.006047,0.002152,0.001239,0.002344,0.002127,0.002431,0.003436,0.001612,0.002205,0.001976,0.004228,0.002748,0.001477,0.004131,0.001774,0.002267,0.001275,0.002703,0.001308,0.001800,0.002026,0.002777,0.001783,0.002092,0.001784,0.003094,0.002886,0.003132,0.002563,0.002423,0.002020,0.003401,0.001461,0.001461,0.002030,0.002063,0.002289,0.001413,0.002510,0.001670,0.002804,0.003364,0.003708,0.002256,0.002166,0.002735,0.002815,0.003735,0.002981,0.001996,0.002094,0.002147,0.001559,0.002654,0.002217,0.002112,0.002367,0.001688,0.001896,0.002378,0.003085,0.003712,0.003618,0.001366,0.002965,0.002639,0.003431,0.002939,0.002458,0.001799,0.001518,0.002630,0.002738,0.002752,0.001753,0.003406,0.002687,0.002704,0.001949,0.002195,0.003410,0.002569,0.002137,0.001893,0.002131,0.002428,0.001548,0.003161
31,0.002195,0.002807,0.001599,0.004437,0.002256,0.003072,0.005401,0.002149,0.003997,0.003696,0.002341,0.003581,0.001335,0.001802,0.002007,0.001897,0.002201,0.005206,0.001827,0.002071,0.001471,0.002550,0.002127,0.001638,0.004563,0.001634,0.000883,0.001073,0.003748,0.001881,0.002581,0.001369,0.002386,0.002128,0.002666,0.002927,0.001557,0.001945,0.000841,0.001757,0.000410,0.002066,0.001285,0.001337,0.001742,0.002524,0.001936,0.002696,0.001466,0.004116,0.003129,0.002059,0.002323,0.002589,0.003027,0.002675,0.003306,0.000974,0.001914,0.001585,0.002037,0.001142,0.001467,0.002201,0.002163,0.001876,0.006501,0.003022,0.002130,0.004356,0.008159,0.005878,0.002459,0.005060,0.002561,0.002387,0.001875,0.002714,0.004667,0.006252,0.005138,0.001545,0.001810,0.002870,0.002384,0.005743,0.006027,0.001213,0.002466,0.004843,0.002187,0.004749,0.002094,0.001218,0.002750,0.001899,0.002732,0.005196,0.001412,0.003432,0.002189,0.003420,0.002035,0.002298,0.005674,0.002115,0.001734,0.003509,0.001078,0.002182,0.001251,0.003593
62,0.001747,0.004312,0.001503,0.003408,0.002102,0.002824,0.004562,0.002203,0.003923,0.003689,0.002007,0.003150,0.001354,0.001395,0.002253,0.002184,0.002090,0.006724,0.002255,0.001590,0.001128,0.002104,0.001871,0.001524,0.003896,0.001603,0.000847,0.002438,0.001573,0.001463,0.005334,0.001498,0.002589,0.002184,0.004798,0.002760,0.001070,0.001903,0.001509,0.002498,0.000713,0.001838,0.001425,0.001174,0.002942,0.001907,0.001419,0.002591,0.001443,0.003789,0.002710,0.001349,0.002398,0.001657,0.002753,0.002519,0.001678,0.001687,0.002922,0.001632,0.001529,0.001559,0.002257,0.002294,0.002310,0.002096,0.003329,0.002309,0.001980,0.002286,0.007238,0.003116,0.001957,0.002427,0.003182,0.002083,0.002310,0.002384,0.003086,0.001601,0.003013,0.001658,0.002396,0.001680,0.002936,0.005345,0.002694,0.001141,0.002891,0.004148,0.002681,0.004005,0.001908,0.001624,0.002349,0.001622,0.002263,0.002916,0.001261,0.005154,0.001984,0.003085,0.002459,0.003704,0.003914,0.001549,0.001470,0.002151,0.001253,0.002382,0.001324,0.003496
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
32751,0.002611,0.003741,0.001662,0.005943,0.002007,0.003966,0.004974,0.002287,0.005161,0.006234,0.002707,0.004193,0.001993,0.001808,0.002575,0.003100,0.002470,0.006906,0.002032,0.001679,0.001504,0.004037,0.003221,0.002402,0.005622,0.002317,0.001512,0.002650,0.001869,0.001549,0.002863,0.001807,0.002185,0.002392,0.004971,0.004327,0.001767,0.003705,0.001134,0.002662,0.001038,0.002314,0.001275,0.001659,0.002265,0.002509,0.002246,0.002696,0.001514,0.004177,0.003185,0.001368,0.001758,0.002290,0.003386,0.005161,0.002242,0.001982,0.002307,0.001757,0.002219,0.001752,0.001716,0.002560,0.003064,0.003361,0.003985,0.003605,0.001731,0.004840,0.006783,0.004629,0.001924,0.003357,0.004009,0.002155,0.002334,0.002931,0.003766,0.002547,0.002625,0.001747,0.003722,0.002132,0.004332,0.005120,0.003820,0.001279,0.002805,0.004147,0.002310,0.003832,0.002544,0.001852,0.002212,0.001770,0.003840,0.003401,0.001939,0.003830,0.003407,0.002773,0.001733,0.003795,0.002121,0.001744,0.001705,0.002850,0.001643,0.002936,0.001103,0.003461
32753,0.001190,0.012414,0.000925,0.002845,0.002449,0.001542,0.003272,0.001529,0.001301,0.002501,0.002446,0.002887,0.000891,0.001300,0.001677,0.001239,0.002465,0.003612,0.001200,0.001027,0.001392,0.001416,0.001928,0.001690,0.004967,0.000909,0.000548,0.001463,0.002934,0.001129,0.004136,0.001106,0.001838,0.002079,0.004146,0.002403,0.001579,0.002690,0.000848,0.001740,0.000424,0.001297,0.001287,0.001764,0.001682,0.002099,0.001147,0.001786,0.002067,0.002082,0.001719,0.000827,0.001002,0.002041,0.001897,0.002716,0.001218,0.001109,0.001179,0.001328,0.001589,0.001141,0.002892,0.001853,0.001361,0.002518,0.003705,0.001426,0.001503,0.002148,0.005965,0.003040,0.001077,0.001879,0.002469,0.001729,0.002954,0.002198,0.001596,0.001596,0.002077,0.000788,0.001868,0.001653,0.002848,0.003224,0.001746,0.000890,0.002324,0.002891,0.001285,0.001638,0.001365,0.001690,0.001531,0.001414,0.001756,0.003088,0.000845,0.002246,0.002715,0.001825,0.001252,0.001612,0.001842,0.001518,0.001436,0.001079,0.002507,0.001683,0.001046,0.003113
32758,0.004264,0.002868,0.001188,0.003415,0.002648,0.003377,0.004154,0.001977,0.003751,0.003578,0.001731,0.001648,0.001502,0.001971,0.002467,0.002527,0.002946,0.004717,0.001980,0.002177,0.001370,0.002761,0.002522,0.002282,0.003763,0.001716,0.001517,0.002291,0.003160,0.001858,0.004737,0.001750,0.001906,0.002081,0.002441,0.003734,0.001790,0.003325,0.001242,0.002002,0.001159,0.002897,0.001258,0.001395,0.003143,0.002681,0.001617,0.001966,0.002175,0.002689,0.003592,0.003744,0.001768,0.002059,0.002773,0.002701,0.001717,0.001651,0.003519,0.002183,0.001458,0.001792,0.002066,0.001842,0.003176,0.001646,0.003755,0.002649,0.001490,0.003044,0.003385,0.002523,0.003385,0.005232,0.002490,0.001385,0.001912,0.002937,0.003294,0.001468,0.003772,0.001541,0.002849,0.002552,0.002858,0.003167,0.002721,0.001782,0.002261,0.002803,0.002424,0.002410,0.002489,0.001622,0.002278,0.001570,0.002662,0.002956,0.001459,0.003342,0.003167,0.002963,0.002461,0.004024,0.003248,0.001840,0.001720,0.002696,0.001442,0.002811,0.001196,0.004070
32763,0.004352,0.004902,0.004879,0.003664,0.005086,0.006443,0.005483,0.002998,0.002819,0.004607,0.003916,0.004779,0.004058,0.004783,0.003679,0.004290,0.006927,0.007308,0.002909,0.005200,0.004799,0.004306,0.005759,0.003108,0.006692,0.003350,0.002094,0.004238,0.004393,0.003673,0.006403,0.002389,0.005373,0.003238,0.007360,0.006334,0.005689,0.012330,0.003100,0.003566,0.002454,0.003132,0.003897,0.003937,0.005608,0.003272,0.004716,0.003858,0.003502,0.004334,0.003989,0.006536,0.005717,0.003785,0.004153,0.005235,0.002285,0.005673,0.003226,0.002717,0.004438,0.002561,0.004907,0.002706,0.005159,0.005011,0.005374,0.003908,0.002781,0.004636,0.006254,0.003983,0.002689,0.007657,0.004479,0.004735,0.005269,0.006413,0.003932,0.004347,0.006386,0.006148,0.003562,0.004213,0.004951,0.003609,0.005338,0.003100,0.004045,0.003794,0.005150,0.006565,0.003879,0.003480,0.003955,0.002745,0.004589,0.004349,0.003218,0.005684,0.004958,0.004071,0.004302,0.003970,0.006995,0.004559,0.003190,0.002388,0.003236,0.003679,0.005127,0.003357


In [7]:
qt_train = []
train_nn=train[colNames].copy()
test_nn=test[colNames].copy()
for col in tqdm(colNames,total=len(colNames)):
    #print(col)
    qt = QuantileTransformer(random_state=21,n_quantiles=2000, output_distribution='normal')  # 将每个特征缩放在同样的范围或分布下。通过执行一个秩转换能够使异常的分布平滑化，并且能够比缩放更少的收到离群值的影响。
    train_nn[col] = qt.fit_transform(train_nn[[col]])
    test_nn[col] = qt.transform(test_nn[[col]])    
    qt_train.append(qt)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=244.0), HTML(value='')))




In [8]:
train_nn[['stock_id','time_id','target']] = train[['stock_id','time_id','target']]
test_nn[['stock_id','time_id']] = test[['stock_id','time_id']]
train_nn

Unnamed: 0,wap1_sum,wap1_std,wap2_sum,wap2_std,wap3_sum,wap3_std,wap4_sum,wap4_std,log_return1_realized_volatility,log_return2_realized_volatility,log_return3_realized_volatility,log_return4_realized_volatility,wap_balance_sum,wap_balance_amax,price_spread_sum,price_spread_amax,price_spread2_sum,price_spread2_amax,bid_spread_sum,bid_spread_amax,ask_spread_sum,ask_spread_amax,total_volume_sum,total_volume_amax,volume_imbalance_sum,volume_imbalance_amax,bid_ask_spread_sum,bid_ask_spread_amax,log_return1_realized_volatility_500,log_return2_realized_volatility_500,log_return3_realized_volatility_500,log_return4_realized_volatility_500,log_return1_realized_volatility_400,log_return2_realized_volatility_400,log_return3_realized_volatility_400,log_return4_realized_volatility_400,log_return1_realized_volatility_300,log_return2_realized_volatility_300,log_return3_realized_volatility_300,log_return4_realized_volatility_300,log_return1_realized_volatility_200,log_return2_realized_volatility_200,log_return3_realized_volatility_200,log_return4_realized_volatility_200,log_return1_realized_volatility_100,log_return2_realized_volatility_100,log_return3_realized_volatility_100,log_return4_realized_volatility_100,trade_log_return_realized_volatility,trade_seconds_in_bucket_count_unique,trade_size_sum,trade_size_amax,trade_size_amin,trade_order_count_sum,trade_order_count_amax,trade_amount_sum,trade_amount_amax,trade_amount_amin,trade_tendency,trade_f_max,trade_f_min,trade_df_max,trade_df_min,trade_abs_diff,trade_energy,trade_iqr_p,trade_abs_diff_v,trade_energy_v,trade_iqr_p_v,trade_log_return_realized_volatility_500,trade_seconds_in_bucket_count_unique_500,trade_size_sum_500,trade_order_count_sum_500,trade_log_return_realized_volatility_400,trade_seconds_in_bucket_count_unique_400,trade_size_sum_400,trade_order_count_sum_400,trade_log_return_realized_volatility_300,trade_seconds_in_bucket_count_unique_300,trade_size_sum_300,trade_order_count_sum_300,trade_log_return_realized_volatility_200,trade_seconds_in_bucket_count_unique_200,trade_size_sum_200,trade_order_count_sum_200,trade_log_return_realized_volatility_100,trade_seconds_in_bucket_count_unique_100,trade_size_sum_100,trade_order_count_sum_100,size_tau,size_tau2,size_tau_200,size_tau2_200,size_tau_300,size_tau2_300,size_tau_400,size_tau2_400,size_tau2_d,log_return1_realized_volatility_mean_stock,log_return1_realized_volatility_std_stock,log_return1_realized_volatility_max_stock,log_return1_realized_volatility_min_stock,log_return2_realized_volatility_mean_stock,log_return2_realized_volatility_std_stock,log_return2_realized_volatility_max_stock,log_return2_realized_volatility_min_stock,log_return1_realized_volatility_400_mean_stock,log_return1_realized_volatility_400_std_stock,log_return1_realized_volatility_400_max_stock,log_return1_realized_volatility_400_min_stock,log_return2_realized_volatility_400_mean_stock,log_return2_realized_volatility_400_std_stock,log_return2_realized_volatility_400_max_stock,log_return2_realized_volatility_400_min_stock,log_return1_realized_volatility_300_mean_stock,log_return1_realized_volatility_300_std_stock,log_return1_realized_volatility_300_max_stock,log_return1_realized_volatility_300_min_stock,log_return2_realized_volatility_300_mean_stock,log_return2_realized_volatility_300_std_stock,log_return2_realized_volatility_300_max_stock,log_return2_realized_volatility_300_min_stock,log_return1_realized_volatility_200_mean_stock,log_return1_realized_volatility_200_std_stock,log_return1_realized_volatility_200_max_stock,log_return1_realized_volatility_200_min_stock,log_return2_realized_volatility_200_mean_stock,log_return2_realized_volatility_200_std_stock,log_return2_realized_volatility_200_max_stock,log_return2_realized_volatility_200_min_stock,trade_log_return_realized_volatility_mean_stock,trade_log_return_realized_volatility_std_stock,trade_log_return_realized_volatility_max_stock,trade_log_return_realized_volatility_min_stock,trade_log_return_realized_volatility_400_mean_stock,trade_log_return_realized_volatility_400_std_stock,trade_log_return_realized_volatility_400_max_stock,trade_log_return_realized_volatility_400_min_stock,trade_log_return_realized_volatility_300_mean_stock,trade_log_return_realized_volatility_300_std_stock,trade_log_return_realized_volatility_300_max_stock,trade_log_return_realized_volatility_300_min_stock,trade_log_return_realized_volatility_200_mean_stock,trade_log_return_realized_volatility_200_std_stock,trade_log_return_realized_volatility_200_max_stock,trade_log_return_realized_volatility_200_min_stock,log_return1_realized_volatility_mean_time,log_return1_realized_volatility_std_time,log_return1_realized_volatility_max_time,log_return1_realized_volatility_min_time,log_return2_realized_volatility_mean_time,log_return2_realized_volatility_std_time,log_return2_realized_volatility_max_time,log_return2_realized_volatility_min_time,log_return1_realized_volatility_400_mean_time,log_return1_realized_volatility_400_std_time,log_return1_realized_volatility_400_max_time,log_return1_realized_volatility_400_min_time,log_return2_realized_volatility_400_mean_time,log_return2_realized_volatility_400_std_time,log_return2_realized_volatility_400_max_time,log_return2_realized_volatility_400_min_time,log_return1_realized_volatility_300_mean_time,log_return1_realized_volatility_300_std_time,log_return1_realized_volatility_300_max_time,log_return1_realized_volatility_300_min_time,log_return2_realized_volatility_300_mean_time,log_return2_realized_volatility_300_std_time,log_return2_realized_volatility_300_max_time,log_return2_realized_volatility_300_min_time,log_return1_realized_volatility_200_mean_time,log_return1_realized_volatility_200_std_time,log_return1_realized_volatility_200_max_time,log_return1_realized_volatility_200_min_time,log_return2_realized_volatility_200_mean_time,log_return2_realized_volatility_200_std_time,log_return2_realized_volatility_200_max_time,log_return2_realized_volatility_200_min_time,trade_log_return_realized_volatility_mean_time,trade_log_return_realized_volatility_std_time,trade_log_return_realized_volatility_max_time,trade_log_return_realized_volatility_min_time,trade_log_return_realized_volatility_400_mean_time,trade_log_return_realized_volatility_400_std_time,trade_log_return_realized_volatility_400_max_time,trade_log_return_realized_volatility_400_min_time,trade_log_return_realized_volatility_300_mean_time,trade_log_return_realized_volatility_300_std_time,trade_log_return_realized_volatility_300_max_time,trade_log_return_realized_volatility_300_min_time,trade_log_return_realized_volatility_200_mean_time,trade_log_return_realized_volatility_200_std_time,trade_log_return_realized_volatility_200_max_time,trade_log_return_realized_volatility_200_min_time,log_return1_realized_volatility_0c1,log_return1_realized_volatility_1c1,log_return1_realized_volatility_3c1,log_return1_realized_volatility_4c1,log_return1_realized_volatility_6c1,price_spread_sum_0c1,price_spread_sum_1c1,price_spread_sum_3c1,price_spread_sum_4c1,price_spread_sum_6c1,bid_spread_sum_0c1,bid_spread_sum_1c1,bid_spread_sum_3c1,bid_spread_sum_4c1,bid_spread_sum_6c1,ask_spread_sum_0c1,ask_spread_sum_1c1,ask_spread_sum_3c1,ask_spread_sum_4c1,ask_spread_sum_6c1,total_volume_sum_0c1,total_volume_sum_1c1,total_volume_sum_3c1,total_volume_sum_4c1,total_volume_sum_6c1,volume_imbalance_sum_0c1,volume_imbalance_sum_1c1,volume_imbalance_sum_3c1,volume_imbalance_sum_4c1,volume_imbalance_sum_6c1,bid_ask_spread_sum_0c1,bid_ask_spread_sum_1c1,bid_ask_spread_sum_3c1,bid_ask_spread_sum_4c1,bid_ask_spread_sum_6c1,trade_size_sum_0c1,trade_size_sum_1c1,trade_size_sum_3c1,trade_size_sum_4c1,trade_size_sum_6c1,trade_order_count_sum_0c1,trade_order_count_sum_1c1,trade_order_count_sum_3c1,trade_order_count_sum_4c1,trade_order_count_sum_6c1,size_tau2_0c1,size_tau2_1c1,size_tau2_3c1,size_tau2_4c1,size_tau2_6c1,stock_id,time_id,target
0,-0.479785,-0.219804,-0.479892,-0.110657,-0.479739,-0.321303,-0.479685,-0.324980,0.505711,0.680072,0.700777,0.520365,0.742454,0.758581,0.651479,0.461469,0.311180,0.310652,-0.114086,0.439657,0.368319,0.326069,-0.873988,-1.143994,-0.736256,-0.906479,-0.248976,0.482859,0.294839,0.849250,0.716336,0.784525,0.407584,0.923770,0.716127,0.810281,0.457110,0.706116,0.753604,0.618593,0.435218,0.734521,0.778648,0.661683,0.502468,0.737631,0.783113,0.597991,-0.024862,-0.604106,-1.045595,-0.955807,-5.199338,-0.828825,-0.601099,-1.042691,-0.925428,0.953833,0.363876,-0.332688,-0.877546,-0.172007,-0.581683,-0.528907,1.516474,-0.715711,-0.796557,-1.045387,-0.827941,0.286640,-0.059597,-0.642074,-0.304325,-0.027664,-0.318145,-0.945954,-0.649124,-0.130354,-0.501089,-0.973379,-0.780294,0.013567,-0.561015,-1.082254,-0.858357,-0.011454,-0.561749,-1.093946,-0.946935,0.603354,0.828825,0.561015,0.858357,0.501089,0.780294,0.318145,0.649124,-0.510342,0.259308,0.659999,0.774358,0.281419,0.575753,0.806057,1.265750,0.937164,0.191129,0.685947,1.86251,-1.227566,0.523969,0.685947,0.807793,-5.199337,0.191129,0.685947,1.747560,-0.659219,0.549312,0.731077,0.716412,-0.167555,0.191129,0.67331,1.565103,-0.053946,0.549312,0.774358,0.869283,0.499668,-0.452616,-0.440839,0.809532,-5.199337,-0.428435,-0.294489,-0.282071,-5.199337,-0.306295,-0.193045,1.047981,-5.199337,-0.306295,-0.176463,1.268548,-5.199337,0.589870,0.646526,0.646636,0.431098,0.513983,0.686741,0.899880,0.720277,0.533738,0.629129,0.728431,0.526848,0.467251,0.577974,0.799933,0.632189,0.522531,0.583912,0.603261,0.420203,0.454701,0.571182,0.709133,0.745072,0.567337,0.697295,0.689919,0.456223,0.489748,0.689919,0.903646,0.739189,0.438077,0.289219,0.055829,0.370036,0.405456,0.198797,0.200076,0.732023,0.397030,0.251535,0.102375,0.467229,0.414731,0.360854,0.112464,0.357858,0.730299,0.580940,0.582009,0.355305,0.782411,0.802591,0.646031,0.619670,0.245069,1.133204,0.243807,0.113425,-0.050806,-0.057314,0.293583,-0.433940,-0.126501,-0.070907,-0.020899,-0.242485,-0.959775,-1.623877,-0.683569,-0.618159,-0.807793,-0.535166,-1.068838,-0.418834,-0.836862,-0.088665,0.348635,0.122565,0.005670,-0.021946,0.273600,-0.243084,0.453311,0.697895,0.282724,-0.318145,0.331363,0.868369,0.752541,0.793969,0.132047,-0.878663,-0.975798,-0.868369,-0.941656,-0.293180,0,5,0.004136
1,-1.369969,-1.739819,-1.370100,-1.816258,-1.369949,-1.505664,-1.370019,-1.662101,-1.630305,-0.886665,-1.168769,-0.861189,-0.696660,-0.149143,-1.204533,-0.051335,-1.412538,-0.310450,-1.126616,0.404881,1.237316,0.349967,-1.084623,-0.870198,-1.275850,-1.097012,-1.257148,0.099350,-0.473672,-0.187421,-0.399954,-0.251077,-0.973595,-0.338007,-0.756131,-0.665947,-1.286706,-0.601002,-1.011896,-0.807379,-1.514421,-0.780846,-1.102057,-0.906151,-1.541719,-0.748941,-1.076243,-0.769406,-1.522349,-1.009664,-1.748348,-1.814929,-5.199338,-1.674425,-2.080074,-1.748419,-1.814776,0.038393,0.146441,-0.965754,-0.966754,-0.751709,-1.042568,-1.633203,0.121011,-1.388829,-1.858090,-1.761807,-2.074650,-0.834770,-0.471451,-0.844842,-0.915975,-1.355651,-0.789680,-1.115509,-1.236944,-1.562764,-0.858357,-1.398107,-1.260182,-1.308271,-0.839479,-1.482669,-1.376767,-1.526561,-0.984922,-1.647966,-1.565103,1.009664,1.674425,0.839479,1.365532,0.858357,1.260182,0.789680,1.216980,-1.401555,0.259308,0.659999,0.774358,0.281419,0.575753,0.806057,1.265750,0.937164,0.191129,0.685947,1.86251,-1.227566,0.523969,0.685947,0.807793,-5.199337,0.191129,0.685947,1.747560,-0.659219,0.549312,0.731077,0.716412,-0.167555,0.191129,0.67331,1.565103,-0.053946,0.549312,0.774358,0.869283,0.499668,-0.452616,-0.440839,0.809532,-5.199337,-0.428435,-0.294489,-0.282071,-5.199337,-0.306295,-0.193045,1.047981,-5.199337,-0.306295,-0.176463,1.268548,-5.199337,-1.051887,-0.347532,-0.652223,-1.344091,-0.981872,-0.525408,-0.990027,-1.281619,-1.251646,-0.949851,-1.119020,-1.368158,-1.179701,-0.936192,-1.369098,-0.833867,-1.248849,-0.926264,-1.620837,-1.479265,-1.175696,-0.873699,-1.664911,-1.202643,-1.185344,-0.621505,-1.214412,-1.417533,-1.070956,-0.633371,-1.010708,-1.147652,-1.223580,-0.212886,-0.151680,-5.199337,-1.335326,-0.682661,-0.646306,-5.199337,-1.316238,-0.752241,-1.080207,-5.199337,-1.282933,-0.587722,-1.414966,-5.199337,-1.126084,-1.169702,-0.705295,-1.407014,0.121302,-0.816512,-0.898756,-0.844822,-1.401447,0.109462,-1.139247,-1.918659,-1.373630,-1.770181,-0.969760,1.393620,1.680182,1.417720,1.727160,1.529813,-1.247795,-0.667417,-0.126357,-0.446559,-0.842842,-0.573534,-1.554295,0.303668,-0.474018,0.072164,-1.275588,-1.798890,-1.399747,-1.758799,-1.289191,-1.049063,-1.569029,-1.620837,-1.549222,-0.284626,-0.990709,-1.093585,-0.834140,-1.094497,0.324747,0.906572,0.773512,1.811684,0.864567,-0.569105,0,11,0.001445
2,-1.512491,0.096645,-1.511633,0.035216,-1.512495,-0.247944,-1.512774,-0.066527,-0.449452,0.147527,-0.437740,0.151895,-0.103191,0.497737,-0.243376,0.229884,-0.608536,0.465349,-0.689092,0.435433,0.693930,0.397030,-1.148337,-1.152478,-1.376622,-0.722911,-0.728153,0.551440,-0.916039,0.599256,-0.245755,0.072285,-0.593709,0.344751,-0.118708,-0.216573,-0.805420,0.107267,-0.418876,-0.287384,-0.406467,0.172598,-0.355853,0.147060,-0.398653,0.067391,-0.343040,0.050546,-0.061966,-1.258796,-1.353468,-1.322226,-5.199338,-1.457159,-1.458974,-1.354236,-1.322623,0.149890,-0.636085,-0.965754,-1.484906,-1.527796,-0.503223,0.422826,-0.473848,0.450533,-0.656967,-1.334084,-0.628365,-0.127998,-1.084509,-0.717223,-0.915975,-0.135987,-0.909319,-0.918203,-0.805189,-0.365172,-1.224907,-1.193008,-1.198775,-0.079369,-1.103904,-1.174687,-1.231572,0.007866,-1.095868,-1.228455,-1.281267,1.258796,1.457159,1.103904,1.231572,1.224907,1.188549,0.879390,0.805189,-2.028286,0.259308,0.659999,0.774358,0.281419,0.575753,0.806057,1.265750,0.937164,0.191129,0.685947,1.86251,-1.227566,0.523969,0.685947,0.807793,-5.199337,0.191129,0.685947,1.747560,-0.659219,0.549312,0.731077,0.716412,-0.167555,0.191129,0.67331,1.565103,-0.053946,0.549312,0.774358,0.869283,0.499668,-0.452616,-0.440839,0.809532,-5.199337,-0.428435,-0.294489,-0.282071,-5.199337,-0.306295,-0.193045,1.047981,-5.199337,-0.306295,-0.176463,1.268548,-5.199337,-0.923477,-2.168153,-1.684702,-0.091871,-0.825294,-1.467650,-0.261902,-0.272196,-0.773512,-1.672394,-1.630237,-0.235422,-0.686741,-1.705807,-1.014897,0.076278,-0.812025,-1.705807,-1.532664,0.034490,-0.760978,-1.550338,-0.482082,-0.112625,-0.832366,-1.887973,-1.245069,-0.078698,-0.781400,-1.420129,-0.278617,-0.165648,-0.737529,-1.102752,-1.258796,-0.025676,-0.621505,-0.768210,-0.531244,-0.198321,-0.674096,-1.005570,-0.785405,0.001571,-0.730258,-1.058739,-0.869555,0.128984,-0.985801,-0.722807,-1.011460,-0.427491,-1.831449,-1.349774,-0.814764,-0.658441,-0.536526,-2.103224,-0.810425,-0.525408,-0.229218,-0.408947,-2.102035,0.932312,0.625319,0.380800,0.362300,1.327987,-0.021449,-0.049551,0.336667,-0.485511,-0.825498,-1.284121,-0.284351,0.839479,-0.790536,-1.625519,-0.879843,-0.573534,-0.303668,-0.378483,-1.744680,-1.012800,-1.348216,-0.970605,-0.850700,-1.256031,-1.091308,-1.597197,-0.879390,-0.680591,-1.381636,0.649124,1.643447,1.486797,0.580940,1.145201,0,16,0.002168
3,-2.512967,-0.092468,-2.513428,-0.374156,-2.513126,-0.395255,-2.512831,-0.154242,-0.315367,-0.267597,-0.634425,-0.550625,-0.574670,0.443066,-0.735929,0.649542,-1.354331,0.546872,-1.489460,0.476072,2.476487,0.436007,-1.647228,-0.786259,-1.982346,-0.732716,-2.079534,0.223882,-0.263330,-0.263269,-1.716243,-1.687618,-0.853643,-0.699594,-2.372781,-1.738624,-0.299860,-0.136062,-1.006348,-0.660233,-0.478514,-0.281543,-1.109378,-0.720067,-0.169176,-0.135083,-0.466386,-0.449299,-0.468174,-1.891713,-1.430649,-1.061100,1.695160,-1.632609,-0.167555,-1.431458,-1.063092,1.669261,-0.230812,-1.753363,-1.891713,-1.811684,-1.393126,0.016114,-0.571119,-0.108235,-0.624236,-1.272493,-0.531897,-1.299290,-1.724918,-0.892386,-1.214355,-0.581827,-2.119870,-1.438090,-1.910818,-0.440365,-1.562973,-0.987982,-0.972775,-0.728770,-1.805250,-1.235405,-1.326747,-0.303971,-1.808458,-1.270352,-1.406486,1.891713,1.632609,1.805250,1.326747,1.562973,0.972775,2.119870,1.910818,1.829157,0.259308,0.659999,0.774358,0.281419,0.575753,0.806057,1.265750,0.937164,0.191129,0.685947,1.86251,-1.227566,0.523969,0.685947,0.807793,-5.199337,0.191129,0.685947,1.747560,-0.659219,0.549312,0.731077,0.716412,-0.167555,0.191129,0.67331,1.565103,-0.053946,0.549312,0.774358,0.869283,0.499668,-0.452616,-0.440839,0.809532,-5.199337,-0.428435,-0.294489,-0.282071,-5.199337,-0.306295,-0.193045,1.047981,-5.199337,-0.306295,-0.176463,1.268548,-5.199337,-0.543983,-0.414731,-0.898002,-2.063981,-0.563803,-0.552961,-0.973781,-1.605251,-0.703266,-0.501089,-0.928445,-1.331502,-0.694700,-0.615432,-0.733265,-0.736822,-0.637871,-0.536959,-0.957789,-1.679542,-0.656691,-0.734516,-0.996186,-1.051244,-0.552923,-0.310240,-0.715602,-1.986088,-0.597348,-0.592857,-0.947917,-1.320136,-0.705076,-0.224445,-0.237323,-1.207828,-0.909319,-0.058341,-0.016406,-5.199337,-0.820144,-0.030726,-0.062110,-5.199337,-0.756708,-0.172007,-0.244125,-1.010225,-0.449146,-0.583912,-0.478469,-0.807793,0.000627,-0.381453,-0.548583,-0.655329,-0.992076,0.400925,-1.322934,-1.605369,-1.025447,-1.774152,-0.441835,1.154918,1.343564,0.873670,1.858969,0.083484,-2.542524,-0.481283,-0.806057,-0.341980,-0.015286,-2.307796,-0.785405,-1.093585,-0.392963,0.424316,-1.264354,-1.468122,-0.946521,-1.818194,-0.269943,-2.074650,-1.851957,-0.926565,-1.883309,0.163575,-2.432198,-1.537945,-1.152478,-1.803831,0.230879,2.651901,1.821632,0.821774,2.053542,-0.004389,0,31,0.002195
4,-1.655595,-1.763856,-1.655506,-1.556704,-1.655869,-2.054141,-1.656406,-1.476157,-0.822700,-0.442796,-0.834547,-0.408571,-0.603956,-0.017347,-1.394893,-0.214018,-1.559752,-0.231745,-0.849947,0.012958,1.799172,0.427748,-1.470327,-1.205231,-1.670588,-1.323730,-1.349296,-0.090407,-0.077046,-1.093439,-0.840173,-0.178580,-0.286791,-0.697117,-0.313285,-0.167865,-0.547862,-0.467183,-0.612772,-0.233695,-0.751232,-0.528004,-0.831686,-0.359296,-0.917321,-0.423320,-0.958251,-0.390179,-1.587159,-1.425277,-1.502506,-1.460794,-5.199338,-1.105058,0.055201,-1.502126,-1.462958,-0.213819,-0.157571,-1.618510,-1.153697,-1.272765,-1.042568,-2.262688,-0.243514,-2.092934,-0.719911,-1.408547,-0.749216,-1.214314,-1.368724,-1.887973,-1.058901,-1.774393,-1.492501,-2.023501,-1.285552,-2.016343,-1.331296,-1.174687,-0.780294,-2.184380,-1.416674,-1.288182,-0.995157,-2.275026,-1.416674,-1.430241,-1.061100,1.425277,1.105058,1.416674,0.995157,1.331296,0.780294,1.492501,1.285552,1.154918,0.259308,0.659999,0.774358,0.281419,0.575753,0.806057,1.265750,0.937164,0.191129,0.685947,1.86251,-1.227566,0.523969,0.685947,0.807793,-5.199337,0.191129,0.685947,1.747560,-0.659219,0.549312,0.731077,0.716412,-0.167555,0.191129,0.67331,1.565103,-0.053946,0.549312,0.774358,0.869283,0.499668,-0.452616,-0.440839,0.809532,-5.199337,-0.428435,-0.294489,-0.282071,-5.199337,-0.306295,-0.193045,1.047981,-5.199337,-0.306295,-0.176463,1.268548,-5.199337,-1.292749,-1.242351,-1.432037,-1.475530,-1.145201,-0.947917,-1.140382,-1.256032,-1.384900,-1.384523,-1.768134,-1.004459,-1.286986,-1.399128,-1.316394,-0.683569,-1.376982,-1.408174,-1.584186,-1.167220,-1.245069,-1.157364,-1.639782,-0.911217,-1.378386,-1.498255,-1.494414,-1.428748,-1.219559,-1.116678,-1.584603,-1.159818,-1.502118,-1.174687,-1.127074,-1.046968,-1.639782,-1.744680,-1.352898,-0.234744,-1.658965,-1.549878,-1.411562,-0.635254,-1.620837,-1.357457,-1.107370,-0.907951,-1.213683,-1.278422,-1.634696,-1.256032,0.030456,-1.008055,-1.218296,-1.506004,-1.644611,0.229592,-0.802591,-1.517801,-2.256937,-1.750454,0.354156,1.223580,1.778815,2.015664,1.744681,-0.211830,-1.758311,-0.025709,-0.398387,-0.109941,-0.374067,-1.134579,0.591363,-0.414342,-0.212886,-0.190670,-1.004058,-1.646661,-2.156874,-1.756286,0.293180,-2.457084,-2.264451,-1.408174,-2.096724,-0.184748,-2.004444,-1.824781,-1.767584,-1.630236,0.278159,1.494830,1.756286,1.275588,1.485556,-0.528379,0,62,0.001747
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
428927,-0.436581,-0.750628,-0.436602,-0.481173,-0.437142,-0.419296,-0.437025,-0.463638,0.233381,0.439020,0.505459,0.517498,0.686638,0.714084,0.719769,0.609426,0.349341,0.517147,-1.139961,0.462845,-0.119253,0.835917,-0.577144,-0.870198,-0.446553,-0.402463,-0.379694,0.578852,0.384616,0.116513,0.750115,0.418297,0.056682,0.100894,0.432482,0.180409,0.430773,0.349837,0.601866,0.537204,0.422682,0.498432,0.603957,0.596393,0.300621,0.430896,0.556850,0.519492,0.105624,-0.713983,-1.214209,-0.815638,-5.199338,-0.915022,-0.295799,-1.214811,-0.814898,-0.506577,0.092200,-0.880314,-0.516074,-0.580198,-0.360653,-0.613485,-0.427895,-0.675041,-1.089902,-1.263728,-1.257413,0.340847,-0.471451,-1.575862,-1.058901,-0.106741,-0.577974,-1.451742,-1.145201,0.035853,-0.702703,-1.484645,-1.228899,0.191918,-0.720470,-1.317944,-1.095868,0.094414,-0.781145,-1.502118,-1.281267,0.713983,0.915022,0.720470,1.086770,0.702703,1.218296,0.577974,1.145201,1.147620,1.175938,0.938137,0.401783,1.744681,1.087902,0.839479,0.776899,1.423548,1.175938,0.872031,-0.26320,1.175938,1.087902,0.836806,-0.013794,0.057713,1.175938,0.870198,-0.263849,1.651928,1.087902,0.835028,0.400424,1.271357,1.175938,0.90459,0.187938,0.802591,1.087902,0.836806,0.142173,1.173438,0.941061,0.479173,-0.332026,0.774358,0.941061,0.634487,-0.690715,-0.235389,0.941061,0.607119,-0.662339,0.747557,0.941061,0.553692,-0.057713,0.550771,-0.227017,0.057085,0.418834,-0.409270,-0.259013,-0.142222,0.122565,-0.101114,-0.175077,-0.432562,-0.378492,-0.113455,-0.198797,-0.293150,-0.118776,0.254124,-0.241194,-0.423943,-0.460270,-0.378105,-0.253954,-0.308925,-0.210160,0.163106,-0.192315,0.019437,0.384058,-0.263200,-0.229592,-0.113103,-0.024145,0.040764,-0.325547,-0.271932,-0.348390,-0.281222,-0.228900,-0.585399,0.031091,0.360653,-0.314549,-0.602644,-0.145915,0.102628,-0.303668,-0.350305,-0.151264,-0.039509,-0.246361,-0.193684,-0.468650,-0.308925,0.498248,-0.049551,-0.130619,-0.357879,-0.281419,0.633721,-0.499307,-0.594159,-0.607873,-0.414805,0.342527,0.463060,0.613918,0.448598,0.502512,0.165648,-1.289862,-0.271064,-0.210322,-0.554423,-0.858454,-0.420830,0.458876,-0.791056,-0.666248,-0.472591,-0.475659,-0.606365,-0.530142,-0.457166,0.125093,-0.987982,-0.484101,0.010659,-0.582425,0.450533,-0.179648,0.052564,0.105527,-0.064623,0.636789,-0.467251,-0.491614,-0.209628,-0.428435,-0.856710,126,32751,0.003461
428928,-1.122658,0.608881,-1.122269,0.605204,-1.122819,0.694037,-1.123003,0.671823,0.380137,0.205818,0.366310,0.207760,-0.023160,0.483311,-0.017705,0.442008,-0.554702,0.240573,-1.136850,0.521405,0.962743,0.762565,-1.603299,-1.414966,-1.022406,-0.693902,-1.104399,0.501326,1.030901,0.576624,0.937082,0.575642,0.698478,0.249939,0.561387,0.209590,0.672788,0.198273,0.596155,0.196986,0.574794,0.270686,0.565352,0.247296,0.498428,0.279093,0.473752,0.287680,0.112391,-0.506072,-1.294679,-1.959750,-5.199338,-0.456787,-0.050806,-1.292572,-1.958245,0.491749,0.123711,-0.575013,-0.394317,0.175190,-0.939111,0.616292,1.134860,0.670469,-1.541062,-1.557321,-1.169702,0.680736,-0.059597,-1.167220,-0.623789,0.302808,-0.577974,-1.696921,-1.236944,0.366572,-0.563953,-1.247249,-0.444296,0.239540,-0.511781,-1.354242,-0.565423,0.228844,-0.410634,-1.276650,-0.495411,0.506072,0.456787,0.511781,0.565423,0.536235,0.439457,0.577974,1.216980,1.726466,1.175938,0.938137,0.401783,1.744681,1.087902,0.839479,0.776899,1.423548,1.175938,0.872031,-0.26320,1.175938,1.087902,0.836806,-0.013794,0.057713,1.175938,0.870198,-0.263849,1.651928,1.087902,0.835028,0.400424,1.271357,1.175938,0.90459,0.187938,0.802591,1.087902,0.836806,0.142173,1.173438,0.941061,0.479173,-0.332026,0.774358,0.941061,0.634487,-0.690715,-0.235389,0.941061,0.607119,-0.662339,0.747557,0.941061,0.553692,-0.057713,0.550771,-1.068838,-0.272298,0.187396,-1.315899,-1.234253,-0.884942,-0.300445,-1.114295,-1.062839,-0.179423,0.274902,-1.674425,-1.275467,-0.814018,-0.164271,-0.496829,-1.053425,-0.079710,0.352635,-1.116678,-1.200063,-0.619657,-0.003135,-0.728204,-1.045319,-0.099082,0.398387,-1.250531,-1.179701,-0.669721,-0.088518,-0.928445,-1.004459,0.135210,0.563932,-0.907880,-1.042829,-0.088518,0.143946,-0.455475,-0.960695,0.234744,0.638991,-0.474204,-0.963757,0.285335,0.709137,-0.735178,-0.606236,-1.028776,-1.223580,-1.159818,-0.557350,-0.521094,-1.097404,-0.946463,-1.087167,-0.311565,-0.882297,-2.102682,-0.901761,-1.216222,-0.962125,1.053425,1.362259,1.404802,1.121368,1.452046,-0.491651,-0.080968,0.402463,3.102617,-0.815609,-0.038255,-0.119585,0.756708,3.182016,-0.554423,-0.977819,-1.675104,-1.130824,-1.174687,-1.246929,-0.538733,-0.623027,-0.868369,0.027308,-0.680405,-0.135210,0.106157,0.251535,0.109941,-0.555885,0.282812,-0.354350,-0.481283,-0.366011,0.338489,126,32753,0.003113
428929,-0.825735,-0.817728,-0.825895,-0.516954,-0.826946,-0.621454,-0.826781,-0.556967,-0.019567,0.472930,-0.055978,0.323615,0.548748,0.562767,0.249911,0.362731,-0.063906,0.565258,-0.253336,0.623122,0.248583,0.364670,-0.985379,-0.725355,-0.108394,0.000627,-0.266129,0.429270,-0.025273,0.718598,-0.014169,0.779860,0.290299,0.595485,0.311612,0.602158,0.378240,0.847207,0.234871,0.614217,0.161950,0.647327,0.072633,0.430958,0.120536,0.542149,0.040578,0.357320,-0.097166,-0.793110,-0.908823,-0.957789,-5.199338,-0.979843,-1.458974,-0.908194,-0.955019,0.722950,-0.411933,-0.798273,-0.719658,-0.845738,-0.360653,-1.350120,0.401965,-1.134639,-0.441956,-1.006286,-0.514642,-0.085474,-1.724918,-1.580218,-1.490594,0.330943,-0.484806,-0.900830,-0.947917,0.173953,-0.327392,-0.545670,-0.629893,0.020503,-0.561015,-0.704844,-0.753373,0.047104,-0.561749,-0.763963,-0.770977,0.793110,0.973781,0.561015,0.753373,0.327392,0.624550,0.484806,0.947917,0.353970,1.175938,0.938137,0.401783,1.744681,1.087902,0.839479,0.776899,1.423548,1.175938,0.872031,-0.26320,1.175938,1.087902,0.836806,-0.013794,0.057713,1.175938,0.870198,-0.263849,1.651928,1.087902,0.835028,0.400424,1.271357,1.175938,0.90459,0.187938,0.802591,1.087902,0.836806,0.142173,1.173438,0.941061,0.479173,-0.332026,0.774358,0.941061,0.634487,-0.690715,-0.235389,0.941061,0.607119,-0.662339,0.747557,0.941061,0.553692,-0.057713,0.550771,-0.751709,-1.824781,-2.169892,-0.334014,-0.653775,-0.968350,-0.691506,-0.286803,-0.691246,-1.337405,-1.267148,-0.345971,-0.586744,-1.005885,-1.292749,-0.061801,-0.740116,-1.452156,-1.590475,-0.399745,-0.644487,-0.982437,-1.102751,-0.201356,-0.737987,-1.590288,-1.951265,-0.577974,-0.598217,-0.693104,-0.312873,-0.366011,-0.829977,-1.744680,-2.056985,-5.199337,-0.775204,-1.251101,-1.895479,-5.199337,-0.812363,-1.502118,-2.123056,-5.199337,-0.837481,-1.644626,-1.788791,-5.199337,-0.804165,-0.541044,-0.702703,-0.586888,-1.197490,-0.458876,-0.405396,-0.633721,-0.320395,-0.738468,-1.428748,-0.725355,-0.562140,-0.552961,-1.359188,1.304392,0.879390,0.472853,0.489316,2.878006,-1.091449,0.272844,0.079710,0.598981,-1.968378,-0.999793,0.410634,1.066621,0.743186,-1.630694,-1.403099,-0.813017,-0.516033,-0.528289,-2.215979,-1.574784,-0.949884,-1.043572,-1.646022,-2.611883,-1.986088,-1.245069,-1.349774,-1.558733,-2.461383,2.438397,2.105586,1.936453,1.750454,2.457084,126,32758,0.004070
428930,0.065064,-0.850647,0.065031,-0.777759,0.065374,-0.941580,0.065428,-0.968170,0.221262,0.308269,0.102809,0.292941,0.429431,0.306341,0.395440,0.195586,0.135515,0.208054,-0.010608,0.623313,0.119775,0.161200,-0.244012,-0.579457,-0.251461,-0.742592,-0.074981,0.189164,-0.248204,-0.143408,-0.522667,-0.095461,0.058474,-0.164311,-0.405836,-0.028355,0.142786,0.099968,-0.030923,0.082925,0.244634,0.220916,0.042849,0.200720,0.198659,0.289720,0.018540,0.237801,0.010135,0.304982,-0.163853,-0.376207,-5.199338,0.105527,-0.601099,-0.162541,-0.373930,0.678170,0.815318,0.212886,0.399066,0.332688,0.583168,-1.188549,0.961790,-1.177597,-0.307196,-0.313435,0.055829,-0.053550,0.334014,-0.245715,0.094185,-0.032365,0.322765,-0.140780,0.160565,0.110172,0.403823,-0.057399,0.166919,-0.002346,0.330701,-0.119523,0.129517,-0.029897,0.371379,-0.104502,0.166919,-0.308925,-0.105527,-0.331363,-0.130782,-0.405864,-0.168191,-0.328053,-0.165648,-0.252567,1.175938,0.938137,0.401783,1.744681,1.087902,0.839479,0.776899,1.423548,1.175938,0.872031,-0.26320,1.175938,1.087902,0.836806,-0.013794,0.057713,1.175938,0.870198,-0.263849,1.651928,1.087902,0.835028,0.400424,1.271357,1.175938,0.90459,0.187938,0.802591,1.087902,0.836806,0.142173,1.173438,0.941061,0.479173,-0.332026,0.774358,0.941061,0.634487,-0.690715,-0.235389,0.941061,0.607119,-0.662339,0.747557,0.941061,0.553692,-0.057713,0.550771,-0.133727,-0.349967,-0.598627,0.229131,-0.137741,-0.401103,-0.444988,0.348635,-0.063366,-0.232167,-0.688329,0.409270,-0.088131,-0.209215,-0.204944,0.427061,-0.107418,-0.240801,-0.603354,0.310128,-0.123089,-0.353970,-0.363653,0.260605,-0.101114,-0.254124,-0.555938,0.359315,-0.125738,-0.441483,-0.504356,0.384753,0.007446,-0.081321,0.344640,0.386199,0.091686,-0.106157,0.336320,0.708074,0.013167,-0.069650,0.353312,0.422944,0.037000,0.043274,0.458876,0.514688,-0.373370,-0.151680,0.132679,0.063366,-0.233456,-0.673764,-0.289243,0.090908,-0.053537,-0.217402,0.122477,0.446715,0.413364,0.697094,0.546653,-0.218117,-0.478469,-0.574614,-0.666248,-0.401103,2.181149,0.905534,1.616191,0.495411,1.408174,1.435743,0.743418,1.896223,0.581521,1.468122,0.175827,0.468650,0.489748,0.681986,0.473548,0.943994,0.582351,1.308940,0.680405,1.091308,0.883459,0.512496,0.977819,0.707526,0.770134,-1.457918,-1.243224,-1.620001,-1.209615,-0.891588,126,32763,0.003357


In [9]:
#https://bignerdranch.com/blog/implementing-swish-activation-function-in-keras/
from keras.backend import sigmoid,tanh,log,exp
def swish(x, beta = 1):
    return (x * sigmoid(beta * x))

def swish_2(x, beta = 0.95):
    return (x * sigmoid(beta * x))

def swish_3(x, beta = 1.05):
    return (x * sigmoid(beta * x))

def mish(x):
    return (x * tanh(log(1+exp(x))))

from keras.utils.generic_utils import get_custom_objects
from keras.layers import Activation
get_custom_objects().update({'swish': Activation(swish)})
# get_custom_objects().update({'mish': Activation(mish)})

# hidden_units = (128,64,32)
# hidden_units = (512,128,32)
# hidden_units = (256,128,64,16)
# hidden_units = (512,128,64,16)
stock_embedding_size = 24

cat_data = train_nn['stock_id']

def base_model():
    
    # Each instance will consist of two inputs: a single user id, and a single movie id
    stock_id_input = keras.Input(shape=(1,), name='stock_id')
    num_input = keras.Input(shape=(244,), name='num_data') # 247 - 1(stock_id) - 2(time_id, target) =244  # 258


    #embedding, flatenning and concatenating
    stock_embedded = keras.layers.Embedding(max(cat_data)+1, stock_embedding_size, 
                                           input_length=1, name='stock_embedding')(stock_id_input)
    stock_flattened = keras.layers.Flatten()(stock_embedded)
    out = keras.layers.Concatenate()([stock_flattened, num_input])
    
    # Add one or more hidden layers
    hidden_units = (128,64,32)
    for n_hidden in hidden_units:
        out = keras.layers.Dense(n_hidden, activation='swish')(out)
#         out = keras.layers.Dense(n_hidden, activation='mish')(out)
    #out = keras.layers.Concatenate()([out, num_input])

    # A single output: our predicted rating
    out = keras.layers.Dense(1, activation='linear', name='prediction')(out)
    
    model = keras.Model(
    inputs = [stock_id_input, num_input],
    outputs = out,
    )
    
    return model


def base_model_2(num_columns,  # 244
#                num_labels,  # output_dim
               hidden_units,   # (128,64,32)
               dropout_rates,  # list of dropout_rates
               stock_embedding_size = 24, 
               # ls = 1e-2,  # label_smoothing = ls
               lr = 1e-3):
    
    # input 
    stock_id_input = keras.Input(shape=(1,), name='stock_id')
    num_input = keras.Input(shape=(num_columns,), name='num_data')
    
    #embedding, flatenning and concatenating
    stock_embedded = keras.layers.Embedding(max(cat_data)+1, 
                                            stock_embedding_size,
                                            input_length=1, 
                                            name='stock_embedding')(stock_id_input)
    stock_flattened = keras.layers.Flatten()(stock_embedded)
    x = keras.layers.Concatenate()([stock_flattened, num_input])
    
    i=0
    # Add one or more hidden layers
    for n_hidden in hidden_units:
#         out = keras.layers.Dense(n_hidden, activation='swish')(out)
        x = keras.layers.Dense(n_hidden)(x)
#         if i%2 ==1:
#             x = keras.layers.BatchNormalization()(x)
        if i==0 or i==4:
            x = tf.keras.layers.GaussianNoise(0.1)(x)  # std of noise
        x = keras.layers.Activation('swish')(x)
#         x = keras.layers.Activation('swish')(x)
        x = keras.layers.Dropout(dropout_rates[i])(x)
        i+=1

    out = keras.layers.Dense(1, activation='linear', name='prediction')(x)
    
    model = keras.Model(
    inputs = [stock_id_input, num_input],
    outputs = out,
    )
    
    return model

In [10]:
# Function to calculate the root mean squared percentage error
def rmspe(y_true, y_pred):
    return np.sqrt(np.mean(np.square((y_true - y_pred) / y_true)))

# Function to early stop with root mean squared percentage error
def feval_rmspe(y_pred, lgb_train):
    y_true = lgb_train.get_label()
    return 'RMSPE', rmspe(y_true, y_pred), False

In [11]:
# %%time
target_name='target'
scores_folds = {}
model_name = 'NN'
pred_name = 'pred_{}'.format(model_name)

n_folds = 5
kf = model_selection.KFold(n_splits=n_folds, shuffle=True, random_state=2020)
scores_folds[model_name] = []
counter = 1

features_to_consider = list(train_nn)  # len -> 247

features_to_consider.remove('time_id')
features_to_consider.remove('target')
try:
    features_to_consider.remove('pred_NN')
except:
    pass


train_nn[features_to_consider] = train_nn[features_to_consider].fillna(train_nn[features_to_consider].mean())
test_nn[features_to_consider] = test_nn[features_to_consider].fillna(train_nn[features_to_consider].mean())

train_nn[pred_name] = 0
test_nn[target_name] = 0
test_predictions_nn1 = np.zeros(test_nn.shape[0])
valid_predictions_nn1 = np.zeros(train_nn.shape[0])

for n_count in range(n_folds):
    print('CV {}/{}'.format(counter, n_folds))
    
    indexes = np.arange(nfolds).astype(int)    
    indexes = np.delete(indexes,obj=n_count, axis=0)
    
    indexes = np.r_[values[indexes[0]],values[indexes[1]],values[indexes[2]],values[indexes[3]]] # 按列连接两个矩阵，就是把两矩阵上下相加，要求列数相等。
    
    X_train = train_nn.loc[train_nn.time_id.isin(indexes), features_to_consider]
    y_train = train_nn.loc[train_nn.time_id.isin(indexes), target_name]
    X_test = train_nn.loc[train_nn.time_id.isin(values[n_count]), features_to_consider]
    y_test = train_nn.loc[train_nn.time_id.isin(values[n_count]), target_name]
    
    #############################################################################################
    # NN
    #############################################################################################
#     model = base_model()
    model = base_model()
    
    model.compile(
        keras.optimizers.Adam(learning_rate=0.006),
#         keras.optimizers.Adam(learning_rate=0.006,decay=3e-5),
        loss=root_mean_squared_per_error
    )
    try:
        features_to_consider.remove('stock_id')
    except:
        pass
    
    num_data = X_train[features_to_consider]
    
    scaler = MinMaxScaler(feature_range=(-1, 1))
    num_data = scaler.fit_transform(num_data.values)
    cat_data = X_train['stock_id']
    target =  y_train
    num_data_test = X_test[features_to_consider]
    num_data_test = scaler.transform(num_data_test.values)
    cat_data_test = X_test['stock_id']
    
    model.fit([cat_data, num_data],
              target, 
              batch_size=2048, 
              epochs=1000, 
#               sample_weight= 1/np.square(y_train),
              validation_data=([cat_data_test, num_data_test], y_test), 
              callbacks=[es, plateau], 
              validation_batch_size=len(y_test), 
              shuffle=True, 
              verbose = 1)
#     preds = model.predict([cat_data_test, num_data_test]).reshape(1,-1)[0]
    valid_predictions_nn1[train_nn.time_id.isin(values[n_count])] = model.predict([cat_data_test, num_data_test]).reshape(1,-1)[0]
        
    
#     score = round(rmspe(y_true = y_test, y_pred = preds),5)
    score = round(rmspe(y_true = y_test, y_pred = valid_predictions_nn1[train_nn.time_id.isin(values[n_count])]),5)
    print('Fold {} {}: {}'.format(counter, model_name, score))
    scores_folds[model_name].append(score)

    

    tt =scaler.transform(test_nn[features_to_consider].values)
    #test_nn[target_name] += model.predict([test_nn['stock_id'], tt]).reshape(1,-1)[0].clip(0,1e10)
    test_predictions_nn1 += model.predict([test_nn['stock_id'], tt]).reshape(1,-1)[0].clip(0,1e10)/n_folds
    #test[target_name] += model.predict([test['stock_id'], test[features_to_consider]]).reshape(1,-1)[0].clip(0,1e10)
    
    counter += 1
    features_to_consider.append('stock_id')
    
print('avg val loss: ', round(np.mean(scores_folds['NN']), 5))

CV 1/5
Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000
Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Epoch 56/1000
Epoch 57/1000
Epoch 58/1000
Epoch 59/1000
Epoch 60/1000
Epoch 61/1000
Epoch 62/1000
Epoch 63/1000
Epoch 64/1000
Epoch 65/1000
Epoch 66/1000
Epoch 67/1000
Epoch 68/1000
Epoch 69/1000
Epoch 70/1000
Epoch 71/1000
Epoch 72

Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000
Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Epoch 56/1000
Epoch 57/1000
Epoch 58/1000
Epoch 59/1000
Epoch 60/1000
Epoch 61/1000
Epoch 62/1000
Epoch 63/1000
Epoch 64/1000
Epoch 65/1000
Epoch 66/1000
Epoch 67/1000
Fold 2 NN: 0.2117
CV 3/5
Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6

Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000
Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Epoch 56/1000
Epoch 57/1000
Epoch 58/1000
Epoch 59/1000
Epoch 60/1000
Epoch 61/1000
Epoch 62/1000
Epoch 63/1000
Epoch 64/1000
Epoch 65/1000
Epoch 66/1000
Epoch 67/1000
Fold 3 NN: 0.20936
CV 4/5
Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epo

Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000
Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Epoch 56/1000
Epoch 57/1000
Epoch 58/1000
Epoch 59/1000
Epoch 60/1000
Epoch 61/1000
Epoch 62/1000
Epoch 63/1000
Epoch 64/1000
Epoch 65/1000
Fold 4 NN: 0.21619
CV 5/5
Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epo

Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000
Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Epoch 56/1000
Epoch 57/1000
Epoch 58/1000
Epoch 59/1000
Epoch 60/1000
Epoch 61/1000
Epoch 62/1000
Epoch 63/1000
Epoch 64/1000
Epoch 65/1000
Epoch 66/1000
Epoch 67/1000
Epoch 68/1000
Epoch 69/1000
Epoch 70/1000
Epoch 71/1000
Epoch 72/1000
Epoch 73/1000
Epoch 74/1000
Epoch 75/1000
Epoch 76/1000
Epoch 77/1000
Epoch 78/1000
Epoch 79/1000
Epoch 80/1000
Epoch 81/1000
Epoch 82/1000
Epoch 83/1000
Epoch 84/1000
Fold 5 NN: 0.21488
avg val loss:  0.21214


In [12]:
seed(41)
tf.random.set_seed(41)

target_name='target'
scores_folds = {}
model_name = 'NN2'
pred_name = 'pred_{}'.format(model_name)

n_folds = 5
kf = model_selection.KFold(n_splits=n_folds, shuffle=True, random_state=2021)
scores_folds[model_name] = []
counter = 1

# features_to_consider = list(train1)
features_to_consider = list(train_nn)

features_to_consider.remove('time_id')
features_to_consider.remove('target')
try:
    features_to_consider.remove('pred_NN')
except:
    pass


# train1[features_to_consider] = train1[features_to_consider].fillna(train1[features_to_consider].mean())
# test1[features_to_consider] = test1[features_to_consider].fillna(train1[features_to_consider].mean())

# train1[pred_name] = 0
# test1[target_name] = 0
# test_predictions_nn2 = np.zeros(test_nn.shape[0])

train_nn[features_to_consider] = train_nn[features_to_consider].fillna(train_nn[features_to_consider].mean())
test_nn[features_to_consider] = test_nn[features_to_consider].fillna(train_nn[features_to_consider].mean())

train_nn[pred_name] = 0
test_nn[target_name] = 0
test_predictions_nn2 = np.zeros(test_nn.shape[0])
valid_predictions_nn2 = np.zeros(train_nn.shape[0])

for n_count in range(n_folds):
    print('CV {}/{}'.format(counter, n_folds))
    indexes = np.arange(nfolds).astype(int)    
    indexes = np.delete(indexes,obj=n_count, axis=0)
    
    indexes = np.r_[values[indexes[0]],values[indexes[1]],values[indexes[2]],values[indexes[3]]] # 按列连接两个矩阵，就是把两矩阵上下相加，要求列数相等。
    
    X_train = train_nn.loc[train_nn.time_id.isin(indexes), features_to_consider]
    y_train = train_nn.loc[train_nn.time_id.isin(indexes), target_name]
    X_test = train_nn.loc[train_nn.time_id.isin(values[n_count]), features_to_consider]
    y_test = train_nn.loc[train_nn.time_id.isin(values[n_count]), target_name]
    
#     X_train = train1.loc[train1.time_id.isin(indexes), features_to_consider]
#     y_train = train1.loc[train1.time_id.isin(indexes), target_name]
#     X_test = train1.loc[train1.time_id.isin(values[n_count]), features_to_consider]
#     y_test = train1.loc[train1.time_id.isin(values[n_count]), target_name]
    
    #############################################################################################
    # NN
    #############################################################################################
    
    model = base_model_2(244,
                       hidden_units=[512,256,128,128,64,64,32,16],
                       # hidden_units=[128,128,128,128,32],   #[256,128,64]
                       dropout_rates=[0.02,0.02,0.01,0.01,0.01,0.02,0.02,0.02]
                       # dropout_rates=[0.2,0.2,0.2,0.2,0.2,0.2,]
                      )
    
    model.compile(
        keras.optimizers.Adam(learning_rate=0.006),
        loss=root_mean_squared_per_error
    )
    
    try:
        features_to_consider.remove('stock_id')
    except:
        pass
    
    num_data = X_train[features_to_consider]
    
    scaler = MinMaxScaler(feature_range=(-1, 1))         
    num_data = scaler.fit_transform(num_data.values)    
    
    cat_data = X_train['stock_id']    
    target =  y_train
    
    num_data_test = X_test[features_to_consider]
    num_data_test = scaler.transform(num_data_test.values)
    cat_data_test = X_test['stock_id']
    model.fit([cat_data, num_data], 
              target,               
              batch_size=2048,
              epochs=1000,
              validation_data=([cat_data_test, num_data_test], y_test),
              callbacks=[es, plateau],
              validation_batch_size=len(y_test),
              shuffle=True,
              verbose = 1
             )
    
#     preds = model.predict([cat_data_test, num_data_test]).reshape(1,-1)[0]
    
    valid_predictions_nn2[train_nn.time_id.isin(values[n_count])] = model.predict([cat_data_test, num_data_test]).reshape(1,-1)[0]

#     score = round(rmspe(y_true = y_test, y_pred = preds),5)
    score = round(rmspe(y_true = y_test, y_pred = valid_predictions_nn2[train_nn.time_id.isin(values[n_count])]),5)
    print('Fold {} {}: {}'.format(counter, model_name, score))
    scores_folds[model_name].append(score)
    
    
    
    tt =scaler.transform(test_nn[features_to_consider].values)
    #test_nn[target_name] += model.predict([test_nn['stock_id'], tt]).reshape(1,-1)[0].clip(0,1e10)
#     test_predictions_nn2 += model.predict([test1['stock_id'], tt]).reshape(1,-1)[0].clip(0,1e10)/n_folds
    test_predictions_nn2 += model.predict([test_nn['stock_id'], tt]).reshape(1,-1)[0].clip(0,1e10)/n_folds
    #test[target_name] += model.predict([test['stock_id'], test[features_to_consider]]).reshape(1,-1)[0].clip(0,1e10)
    
    counter += 1
    features_to_consider.append('stock_id')
    
print('avg val loss: ', round(np.mean(scores_folds['NN2']), 5))

CV 1/5
Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000
Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Epoch 56/1000
Epoch 57/1000
Epoch 58/1000
Epoch 59/1000
Epoch 60/1000
Epoch 61/1000
Epoch 62/1000
Epoch 63/1000
Epoch 64/1000
Epoch 65/1000
Epoch 66/1000
Epoch 67/1000
Epoch 68/1000
Epoch 69/1000
Epoch 70/1000
Epoch 71/1000
Epoch 72

Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000
Fold 2 NN2: 0.21169
CV 3/5
Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 2

Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000
Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Epoch 56/1000
Epoch 57/1000
Epoch 58/1000
Fold 3 NN2: 0.20778
CV 4/5
Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Ep

Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Fold 4 NN2: 0.21194
CV 5/5
Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000
Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Epoch 56/1000
Epoch 57/1000
Epoch 58/1000
Epoch 59/1000
Epoch 60/1000
Epoch 61/1000
Epoch 62/1000
Epoch 63/1000
Epoch 64/1000
Epoch 65/1000
Epoch 66/1000
Ep

In [13]:
valid_predictions_nn1, valid_predictions_nn2

(array([0.00372068, 0.00139473, 0.00230175, ..., 0.00327022, 0.00460468,
        0.002109  ]),
 array([0.0037984 , 0.00149692, 0.00218644, ..., 0.00323129, 0.0054737 ,
        0.00209903]))

In [14]:
test_predictions_nn1, test_predictions_nn2

(array([0.00130807, 0.00250945, 0.00250945]),
 array([0.00171968, 0.00283447, 0.00283447]))

In [15]:
# pd.Series(valid_predictions_nn1).to_csv('nn1_valid_pred_knn.csv',index=0)
# pd.Series(valid_predictions_nn2).to_csv('nn2_valid_pred_knn.csv',index=0)

In [16]:
import gc
del train_nn, test_nn
del X_train, X_test
del cat_data, num_data, cat_data_test, num_data_test
gc.collect()

9343