### Try DeepAR
* PytorchForecasting [Get Started](https://pytorch-forecasting.readthedocs.io/en/stable/getting-started.html)
* DeepAR [doc](https://pytorch-forecasting.readthedocs.io/en/stable/api/pytorch_forecasting.models.deepar.DeepAR.html)

In [1]:
import os

import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_forecasting import TimeSeriesDataSet, DeepAR
from pytorch_forecasting.data import GroupNormalizer

import warnings
warnings.filterwarnings('ignore')

# Used Yujie's cleaned version
DIR_BYID = '/media/user/12TB1/HanLi/GitHub/CMU11785-project/local_data/content/databyid'

ls_all_invest_ids = sorted([int(fn.split('.')[0]) for fn in os.listdir(os.path.join(DIR_BYID, 'target'))])

In [2]:
f_cols = [f"f_{i}" for i in range(300)]
# Read a subset for testing
n = 3
ls_dfs = []
for id in ls_all_invest_ids[:n]:
    df_f_id = pd.DataFrame(np.load(os.path.join(DIR_BYID, f'feats/{id}.npy')), columns=f_cols)
    df_t_id = pd.DataFrame(np.load(os.path.join(DIR_BYID, f'target/{id}.npy')), columns=['target'])
    df_f_id['investment_id'] = id
    ls_dfs.append(pd.concat([df_t_id, df_f_id], axis=1))

df = pd.concat(ls_dfs).reset_index().rename(columns={'index': 'time_id'})
df = df.sort_values(by=['time_id']) # sort by time before splitting
df_train, df_test = train_test_split(df, test_size=0.1, shuffle=False)
# df_train, df_val = train_test_split(df_train, test_size=2/9, shuffle=False)

In [3]:
df_train

Unnamed: 0,time_id,target,f_0,f_1,f_2,f_3,f_4,f_5,f_6,f_7,f_8,f_9,f_10,f_11,f_12,f_13,f_14,f_15,f_16,f_17,f_18,f_19,f_20,f_21,f_22,f_23,f_24,f_25,f_26,f_27,f_28,f_29,f_30,f_31,f_32,f_33,f_34,f_35,f_36,f_37,...,f_261,f_262,f_263,f_264,f_265,f_266,f_267,f_268,f_269,f_270,f_271,f_272,f_273,f_274,f_275,f_276,f_277,f_278,f_279,f_280,f_281,f_282,f_283,f_284,f_285,f_286,f_287,f_288,f_289,f_290,f_291,f_292,f_293,f_294,f_295,f_296,f_297,f_298,f_299,investment_id
0,0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0
1220,0,-0.300875,0.932573,0.113691,-0.402206,0.378386,-0.203938,-0.413469,0.965623,1.230508,0.114809,-2.012777,0.004936,0.284220,0.502155,-0.287932,-1.169338,-0.267310,-0.574423,-0.771869,1.012212,-1.230507,1.785726,-2.090686,0.325659,-0.877769,1.048786,0.131774,-0.349609,-1.813385,0.099226,-0.241020,1.604571,0.003637,-0.902062,0.221581,0.610063,-0.738558,2.097248,-0.913877,...,0.709600,-0.031878,-1.020150,-1.291206,0.038669,0.187159,-0.680358,0.900593,-0.924766,-1.057890,-0.167062,0.000000,1.281245,0.258715,-0.237964,-0.742125,-0.324677,0.992547,0.961355,-0.025610,-0.006259,0.473603,0.040136,0.453711,-1.597790,0.301659,0.157470,0.416631,1.506131,0.366028,-1.095620,0.200075,0.819155,0.941183,-0.086764,-1.087009,-1.044826,-0.287605,0.321566,1
2440,0,-0.231040,0.810802,-0.514115,0.742368,-0.616673,-0.194255,1.771210,1.428127,1.134144,0.114809,-0.219201,-0.351726,0.846882,0.440299,0.499824,0.893144,-0.010217,-0.681523,1.254092,-1.026969,-1.690156,0.011152,0.875251,0.325659,-0.458305,-1.797581,-0.300364,0.584786,0.551460,0.806422,1.235012,-0.984701,-1.084491,3.161929,0.211016,-2.656093,-0.176984,0.486530,1.237427,...,-0.015459,-0.158329,0.980246,0.799270,0.798399,-0.633207,0.779735,0.171233,1.165891,0.590802,0.118520,0.000000,-0.650803,0.851905,0.086198,1.135668,0.298990,-1.583445,-0.481945,0.532229,0.226693,-0.894744,-0.514552,-1.000073,0.884377,-0.557502,-0.875265,-0.156106,0.537055,-0.154193,0.912726,-0.734579,0.819155,0.941183,-0.387617,-1.087009,-0.929529,-0.974060,-0.343624,2
1,1,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0
1221,1,-0.917045,0.373575,0.296349,0.019102,-0.031842,-0.222027,-0.199950,1.325165,1.267433,0.165132,-1.809698,-0.004046,0.500912,0.270602,-0.251901,-1.447138,-0.726851,-0.693402,-0.404121,1.370337,-0.217417,1.985823,-1.789867,0.321832,-1.061878,1.235261,0.092415,-0.922268,-0.776751,-0.202838,0.621944,0.826967,1.040406,-0.348309,0.222908,0.748655,-0.592812,1.718307,-1.127791,...,0.700823,-0.055185,-1.005235,-0.825481,0.085568,-0.386882,-0.103534,0.902356,-0.797575,-0.626853,-0.060273,0.999999,0.216916,0.010894,-0.156528,-0.831907,-0.205951,0.902887,0.775394,-0.181355,-0.187022,-0.238478,0.150551,0.844083,-0.755152,-0.479676,0.128678,0.407574,1.357091,0.235071,-1.081652,0.645607,0.581035,0.839101,-0.362388,1.229239,-1.301037,-0.391490,0.330331,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2316,1096,0.235795,0.793866,-1.436199,0.495416,0.341805,-0.139172,-1.502401,0.675525,2.113461,0.210705,0.958565,1.007658,0.078338,1.308932,0.266498,0.045223,0.967561,-1.593580,0.673034,-1.819969,-0.799343,1.212783,-0.301537,1.008629,-0.635095,-4.600780,0.569642,0.911000,-1.464542,-1.152730,-1.155847,0.646057,-1.263747,0.872482,0.245150,0.245420,-1.026173,-0.783827,-0.055586,...,-1.832061,-0.590258,1.075936,0.496549,-0.186512,0.208193,0.634965,0.840185,0.431335,1.105372,0.033106,0.267282,-0.477303,-0.276528,3.381330,2.106128,-2.046720,-0.964806,-0.096877,0.246545,-0.885334,0.348636,-0.927359,0.791813,0.778852,-0.014592,-1.239978,-0.625722,-0.713983,-2.427725,0.335188,-0.188790,0.684264,0.862043,0.854590,-1.840698,0.641495,0.861893,-0.701756,1
3536,1096,-0.161215,-1.651199,-0.402627,0.365231,-0.060228,-0.233239,-1.540840,-0.068755,1.524139,0.084803,0.120906,-0.115822,0.652399,0.520354,0.064511,0.065284,0.124667,-0.464520,-0.609141,0.029550,-1.387422,0.194509,-0.052304,1.601354,-0.061263,-0.219311,-1.173846,-1.206285,1.788630,-1.226825,-0.280440,0.256537,-0.498572,-0.341884,0.094889,0.827402,0.703937,0.194871,0.528937,...,-0.180936,-0.319193,1.477943,1.653123,-0.715702,0.222869,0.952676,-1.224426,-0.614755,1.584475,-0.707195,-0.050350,-0.793896,-0.611213,0.317939,0.278965,-1.675660,-1.148505,-1.518317,-0.969804,-0.436997,0.361642,0.235858,-0.439595,0.473661,-0.230062,0.952682,-0.397151,-0.208232,-0.729604,0.298263,-0.130472,-0.246768,-0.216665,-0.489377,-2.487978,0.918601,0.724660,-0.519444,2
1097,1097,1.393395,-0.960512,0.799016,-5.037099,1.802017,0.203709,3.742825,1.333749,-0.717466,0.639435,-0.077403,-1.073687,1.664010,0.006667,0.225587,-2.120850,1.509332,-0.723253,0.669966,0.785729,-0.841904,-0.509171,0.270334,-1.219625,0.296941,0.841471,1.852109,-1.235564,1.951918,1.119238,-0.371044,-0.049922,0.893931,-0.888359,0.201307,0.171159,1.284370,1.438194,-0.667012,...,1.163838,0.488736,-1.465145,2.840193,0.674275,-0.158068,-3.591777,0.195900,-1.101994,-0.021928,0.733843,-0.537392,-0.467608,1.843937,-0.202686,-1.024272,0.568953,0.609179,0.082057,0.733061,-1.215226,-2.016319,-1.553871,-0.380017,-1.506010,1.565906,0.332653,0.504419,0.326811,-1.105340,-2.062027,-1.923397,0.140001,0.160493,0.177020,1.963074,1.205313,-1.217377,0.348695,0
3537,1097,0.303290,-2.041984,0.000178,0.231410,-0.244404,-0.204807,-1.164735,-1.212279,1.787682,-0.033794,2.137438,-0.037219,0.185490,0.732767,0.067431,-1.138284,-0.741437,-0.556347,-0.656311,0.627950,-1.358971,0.484626,0.778144,1.269949,-0.392998,-0.310960,-1.009415,0.062687,1.968434,-1.637262,-0.043629,0.210474,-0.749165,-0.129112,-0.162086,1.879023,0.446851,-0.543166,0.138576,...,-0.690420,-0.305783,1.471808,2.139212,-0.298103,0.398183,1.617495,-0.505975,-0.996070,1.082209,-0.709393,-1.567418,-0.775145,-0.854719,0.221138,0.329437,-1.133425,-0.841096,-0.703182,-0.837147,0.630092,0.842397,0.227419,-0.139813,0.679410,-0.304674,1.161658,-0.352890,-0.406201,0.477849,1.075185,-0.419914,-1.507225,-0.861868,-0.564768,-1.650619,0.589593,0.849424,-0.611488,2


### Create dataset and dataloaders
* Ref: https://pytorch-forecasting.readthedocs.io/en/stable/data.html

In [4]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.metrics import QuantileLoss

# load data
# define dataset
max_encoder_length = 6
max_prediction_length = 1

# create validation and training dataset
batch_size = 128
max_prediction_length = 3
max_encoder_length = 24

# create the dataset from the pandas dataframe
train_dataset = TimeSeriesDataSet(
    df_train,
    group_ids=["investment_id"],
    target="target",
    time_idx="time_id",
    min_encoder_length=max_encoder_length // 2,
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    # static_reals=[],
    time_varying_unknown_reals=['target'] + [f"f_{i}" for i in range(300)],
    target_normalizer=GroupNormalizer(
        groups=["investment_id"], transformation="softplus"
    ),  # use softplus and normalize by group
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

val_dataset = TimeSeriesDataSet.from_dataset(train_dataset, df_train, predict=True, stop_randomization=True)

# create dataloaders for model
batch_size = 64  # set this between 32 to 128
train_dataloader = train_dataset.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = val_dataset.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0) # Check this


### TODO: configure DeepAR model

In [8]:
# configure network and trainer




### TODO: Train Model

* Note: use tensorboard to check the logs: run ```tensorboard --logdir=<logging_folder>```
* To visualize tensorboard in Jupyter Notebook: 
    ```
    %reload_ext tensorboard
    %tensorboard --logdir=<logging_folder>
    ```

### TODO: test model and calculate performance metrics on test data