In [1]:
pip install pytorch-lightning  

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch-lightning
  Downloading pytorch_lightning-1.6.4-py3-none-any.whl (585 kB)
[K     |████████████████████████████████| 585 kB 9.8 MB/s 
[?25hCollecting PyYAML>=5.4
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 57.3 MB/s 
Collecting fsspec[http]!=2021.06.0,>=2021.05.0
  Downloading fsspec-2022.5.0-py3-none-any.whl (140 kB)
[K     |████████████████████████████████| 140 kB 73.2 MB/s 
Collecting pyDeprecate>=0.3.1
  Downloading pyDeprecate-0.3.2-py3-none-any.whl (10 kB)
Collecting torchmetrics>=0.4.1
  Downloading torchmetrics-0.9.1-py3-none-any.whl (419 kB)
[K     |████████████████████████████████| 419 kB 68.3 MB/s 
Collecting aiohttp
  Downloading aiohttp-3.8.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manyli

In [2]:
pip install pytorch-forecasting

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch-forecasting
  Downloading pytorch_forecasting-0.10.1-py3-none-any.whl (127 kB)
[K     |████████████████████████████████| 127 kB 8.1 MB/s 
Collecting optuna<3.0.0,>=2.3.0
  Downloading optuna-2.10.0-py3-none-any.whl (308 kB)
[K     |████████████████████████████████| 308 kB 40.8 MB/s 
Collecting cliff
  Downloading cliff-3.10.1-py3-none-any.whl (81 kB)
[K     |████████████████████████████████| 81 kB 11.2 MB/s 
[?25hCollecting alembic
  Downloading alembic-1.8.0-py3-none-any.whl (209 kB)
[K     |████████████████████████████████| 209 kB 70.0 MB/s 
Collecting cmaes>=0.8.2
  Downloading cmaes-0.8.2-py3-none-any.whl (15 kB)
Collecting colorlog
  Downloading colorlog-6.6.0-py2.py3-none-any.whl (11 kB)
Collecting Mako
  Downloading Mako-1.2.0-py3-none-any.whl (78 kB)
[K     |████████████████████████████████| 78 kB 8.8 MB/s 
Collecting cmd2>=1.0.0
  Downloading cmd2-2.4.1-p

In [3]:
import os

import warnings
warnings.filterwarnings('ignore')

import pickle

import numpy as np
import pandas as pd
from typing import Dict, Callable, List, Optional, Tuple, Union

from sklearn.preprocessing import MinMaxScaler

import matplotlib.pyplot as plt

import torch

import pytorch_lightning as pl

In [4]:
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_forecasting import Baseline, DeepAR, TimeSeriesDataSet

In [5]:
from pytorch_forecasting.data import NaNLabelEncoder
from pytorch_forecasting.metrics import SMAPE, DistributionLoss, distributions, BaseEstimator

In [6]:
class MultivariateDistributionLoss(DistributionLoss):
    """Base class for multivariate distribution losses.
    Class should be inherited for all multivariate distribution losses, i.e. if a batch of values
    is predicted in one go and the batch dimension is not independent, but the time dimension still
    remains independent.
    """

    def sample(self, y_pred, n_samples: int) -> torch.Tensor:
        """
        Sample from distribution.
        Args:
            y_pred: prediction output of network (shape batch_size x n_timesteps x n_paramters)
            n_samples (int): number of samples to draw
        Returns:
            torch.Tensor: tensor with samples  (shape batch_size x n_timesteps x n_samples)
        """
        dist = self.map_x_to_distribution(y_pred)
        samples = dist.sample((n_samples,)).permute(
            2, 1, 0
        )  # returned as (n_samples, n_timesteps, batch_size), so reshape to (batch_size, n_timesteps, n_samples)
        return samples

    def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor:
        """
        Convert network prediction into a point prediction.
        Args:
            y_pred: prediction output of network
        Returns:
            torch.Tensor: mean prediction
        """
        distribution = self.map_x_to_distribution(y_pred)

        return distribution.mean.transpose(0, 1)  # switch to batch_size x n_timesteps

    def loss(self, y_pred: torch.Tensor, y_actual: torch.Tensor) -> torch.Tensor:
        """
        Calculate negative likelihood
        Args:
            y_pred: network output
            y_actual: actual values
        Returns:
            torch.Tensor: metric value on which backpropagation can be applied
        """
        distribution = self.map_x_to_distribution(y_pred)
        # calculate one number and scale with batch size
        loss = -distribution.log_prob(y_actual.transpose(0, 1)).sum() * y_actual.size(0)
        return loss


In [7]:
import torch.nn.functional as F

class MultivariateNormalDistributionLoss(MultivariateDistributionLoss):
    """
    Multivariate low-rank normal distribution loss.
    Use this loss to make out of a DeepAR model a DeepVAR network.
    Requirements for original target normalizer:
        * not normalized in log space (use :py:class:`~LogNormalDistributionLoss`)
        * not coerced to be positive
    """

    distribution_class = distributions.LowRankMultivariateNormal

    def __init__(
        self,
        name: str = None,
        quantiles: List[float] = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98],
        reduction: str = "mean",
        rank: int = 10,
        sigma_init: float = 1.0,
        sigma_minimum: float = 1e-3,
    ):
        """
        Initialize metric
        Args:
            name (str): metric name. Defaults to class name.
            quantiles (List[float], optional): quantiles for probability range.
                Defaults to [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98].
            reduction (str, optional): Reduction, "none", "mean" or "sqrt-mean". Defaults to "mean".
            rank (int): rank of low-rank approximation for covariance matrix. Defaults to 10.
            sigma_init (float, optional): default value for diagonal covariance. Defaults to 1.0.
            sigma_minimum (float, optional): minimum value for diagonal covariance. Defaults to 1e-3.
        """
        super().__init__(name=name, quantiles=quantiles, reduction=reduction)
        self.rank = rank
        self.sigma_minimum = sigma_minimum
        self.sigma_init = sigma_init
        self.distribution_arguments = list(range(2 + rank))

        # determine bias
        self._diag_bias: float = self.inv_softplus(self.sigma_init**2) if self.sigma_init > 0.0 else 0.0

    def map_x_to_distribution(self, x: torch.Tensor) -> distributions.Normal:
        x = x.permute(1, 0, 2)
        return self.distribution_class(
            loc=x[..., 0],
            cov_factor=x[..., 2:],
            cov_diag=x[..., 1],
        )

    @staticmethod
    def validate_encoder(encoder: BaseEstimator):
        assert encoder.transformation not in [
            "log",
            "log1p",
        ], "Use MultivariateLogNormalDistributionLoss for log scaled data"  # todo: implement
        assert encoder.transformation not in [
            "softplus",
            "relu",
        ], "Cannot use NormalDistributionLoss for positive data"
        assert encoder.transformation not in ["logit"], "Cannot use bound transformation such as 'logit'"

    def rescale_parameters(
        self, parameters: torch.Tensor, target_scale: torch.Tensor, encoder: BaseEstimator
    ) -> torch.Tensor:
        self.validate_encoder(encoder)

        # scale
        loc = encoder(dict(prediction=parameters[..., 0], target_scale=target_scale)).unsqueeze(-1)
        scale = (
            F.softplus(parameters[..., 1].unsqueeze(-1) + self._diag_bias) + self.sigma_minimum**2
        ) * target_scale[..., 1, None, None] ** 2

        cov_factor = parameters[..., 2:] * target_scale[..., 1, None, None]
        return torch.concat([loc, scale, cov_factor], dim=-1)

    def inv_softplus(self, y):
        if y < 20.0:
            return np.log(np.exp(y) - 1.0)
        else:
            return y

In [8]:
df=pd.read_excel('/content/drive/MyDrive/교육/AI실무인증과정/학회 발표/data/data_full.xlsx')
df.head()

Unnamed: 0,Date,Account DOW,REV OBD,OBD NET+FSC_KRW,OBD A/R_KRW,REV CPN,CPN NET+FSC_KRW,CPN A/R_KRW,REV TKT,TKT NET+FSC_KRW,TKT A/R_KRW,WTI,exchanges,kospi,rates,stock_a,stock_k,stock_kkj
0,2016-01-01,FRI,25196,9923237000.0,393841.7444,5517,2410238000.0,436874.6714,3188,2236210000.0,701446.0183,0.0,0.0,0.0,0.0,0,0,0
1,2016-01-02,SAT,27495,11066710000.0,402499.1654,6320,2500112000.0,395587.3387,3577,2405358000.0,672451.2796,0.0,0.0,0.0,0.0,0,0,0
2,2016-01-03,SUN,31843,12708000000.0,399083.0044,4292,1813688000.0,422573.9862,2521,1697046000.0,673163.7338,0.0,0.0,0.0,0.0,0,0,0
3,2016-01-04,MON,28000,11541730000.0,412204.5697,36263,14656230000.0,404164.9715,19528,14544510000.0,744803.0601,36.76,1189.5,1918.76,1.63,21429,16459,108918
4,2016-01-05,TUE,24657,10087910000.0,409129.4946,38432,15386670000.0,400360.8701,20461,15551750000.0,760067.7182,35.97,1189.5,1930.53,1.64,20881,16186,108918


In [9]:
df.tail()

Unnamed: 0,Date,Account DOW,REV OBD,OBD NET+FSC_KRW,OBD A/R_KRW,REV CPN,CPN NET+FSC_KRW,CPN A/R_KRW,REV TKT,TKT NET+FSC_KRW,TKT A/R_KRW,WTI,exchanges,kospi,rates,stock_a,stock_k,stock_kkj
1853,2021-01-27,WED,1099,872488809.0,793893.4,557,482987700.0,867123.4,446,568712000.0,1275139.0,52.85,1105.0,3122.56,0.98,17302,30400,59587
1854,2021-01-28,THU,369,346194072.0,938195.3,686,599869500.0,874445.3,533,690605800.0,1295696.0,52.34,1118.0,3069.05,0.97,17023,29650,58220
1855,2021-01-29,FRI,929,697137374.0,750417.0,682,603589800.0,885029.0,540,679201500.0,1257780.0,52.2,1117.5,2976.21,0.97,16353,28700,56657
1856,2021-01-30,SAT,504,394144488.0,782032.7,261,219166300.0,839717.4,186,212856600.0,1144391.0,52.2,1117.5,2976.21,0.97,16353,28700,56657
1857,2021-01-31,SUN,461,463553317.0,1005539.0,302,364514000.0,1207000.0,215,-133093200.0,-619038.0,52.2,1117.5,2976.21,0.97,16353,28700,56657


In [10]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1858 entries, 0 to 1857
Data columns (total 18 columns):
 #   Column           Non-Null Count  Dtype         
---  ------           --------------  -----         
 0   Date             1858 non-null   datetime64[ns]
 1   Account DOW      1858 non-null   object        
 2   REV OBD          1858 non-null   int64         
 3   OBD NET+FSC_KRW  1858 non-null   float64       
 4   OBD A/R_KRW      1858 non-null   float64       
 5   REV CPN          1858 non-null   int64         
 6   CPN NET+FSC_KRW  1858 non-null   float64       
 7   CPN A/R_KRW      1858 non-null   float64       
 8   REV TKT          1858 non-null   int64         
 9   TKT NET+FSC_KRW  1858 non-null   float64       
 10  TKT A/R_KRW      1858 non-null   float64       
 11  WTI              1858 non-null   float64       
 12  exchanges        1858 non-null   float64       
 13  kospi            1858 non-null   float64       
 14  rates            1858 non-null   float64

In [11]:
df = df[df["Date"].isin(pd.date_range('2016-01-4', '2019-11-30'))]
df.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 1427 entries, 3 to 1429
Data columns (total 18 columns):
 #   Column           Non-Null Count  Dtype         
---  ------           --------------  -----         
 0   Date             1427 non-null   datetime64[ns]
 1   Account DOW      1427 non-null   object        
 2   REV OBD          1427 non-null   int64         
 3   OBD NET+FSC_KRW  1427 non-null   float64       
 4   OBD A/R_KRW      1427 non-null   float64       
 5   REV CPN          1427 non-null   int64         
 6   CPN NET+FSC_KRW  1427 non-null   float64       
 7   CPN A/R_KRW      1427 non-null   float64       
 8   REV TKT          1427 non-null   int64         
 9   TKT NET+FSC_KRW  1427 non-null   float64       
 10  TKT A/R_KRW      1427 non-null   float64       
 11  WTI              1427 non-null   float64       
 12  exchanges        1427 non-null   float64       
 13  kospi            1427 non-null   float64       
 14  rates            1427 non-null   float64

In [12]:
df.reset_index(drop=True, inplace=True)

In [13]:
df

Unnamed: 0,Date,Account DOW,REV OBD,OBD NET+FSC_KRW,OBD A/R_KRW,REV CPN,CPN NET+FSC_KRW,CPN A/R_KRW,REV TKT,TKT NET+FSC_KRW,TKT A/R_KRW,WTI,exchanges,kospi,rates,stock_a,stock_k,stock_kkj
0,2016-01-04,MON,28000,1.154173e+10,412204.5697,36263,1.465623e+10,404164.9715,19528,1.454451e+10,744803.0601,36.76,1189.5,1918.76,1.63,21429,16459,108918
1,2016-01-05,TUE,24657,1.008791e+10,409129.4946,38432,1.538667e+10,400360.8701,20461,1.555175e+10,760067.7182,35.97,1189.5,1930.53,1.64,20881,16186,108918
2,2016-01-06,WED,26920,1.009247e+10,374905.8415,41478,1.597113e+10,385050.5604,21953,1.647945e+10,750669.3999,33.97,1200.0,1925.43,1.64,20785,16063,110383
3,2016-01-07,THU,26624,1.060220e+10,398219.6596,49006,1.798140e+10,366922.4713,25917,1.872443e+10,722476.6906,33.27,1200.2,1904.33,1.64,20809,16003,104522
4,2016-01-08,FRI,28879,1.077601e+10,373143.3549,63847,2.480722e+10,388541.6089,33748,2.598923e+10,770096.9727,33.16,1199.5,1917.62,1.67,20523,15637,105499
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1422,2019-11-26,TUE,22295,7.106925e+09,318767.6722,33929,1.501046e+10,442407.9889,18388,1.618648e+10,880274.0396,58.41,1177.5,2121.35,1.48,26679,16567,47670
1423,2019-11-27,WED,25134,8.100108e+09,322276.9318,35642,1.487445e+10,417329.2423,19304,1.609138e+10,833577.7006,58.11,1177.5,2127.85,1.46,26250,16438,46986
1424,2019-11-28,THU,24742,8.116231e+09,328034.5456,35459,1.549214e+10,436902.7653,19087,1.674230e+10,877157.3571,58.11,1179.0,2118.60,1.43,26059,16470,46937
1425,2019-11-29,FRI,29624,9.924881e+09,335028.3726,41190,1.766483e+10,428862.0799,22167,1.915980e+10,864338.9685,55.17,1180.0,2087.96,1.39,25391,16048,47426


In [14]:
df['time_index'] = np.arange(len(df))
df

Unnamed: 0,Date,Account DOW,REV OBD,OBD NET+FSC_KRW,OBD A/R_KRW,REV CPN,CPN NET+FSC_KRW,CPN A/R_KRW,REV TKT,TKT NET+FSC_KRW,TKT A/R_KRW,WTI,exchanges,kospi,rates,stock_a,stock_k,stock_kkj,time_index
0,2016-01-04,MON,28000,1.154173e+10,412204.5697,36263,1.465623e+10,404164.9715,19528,1.454451e+10,744803.0601,36.76,1189.5,1918.76,1.63,21429,16459,108918,0
1,2016-01-05,TUE,24657,1.008791e+10,409129.4946,38432,1.538667e+10,400360.8701,20461,1.555175e+10,760067.7182,35.97,1189.5,1930.53,1.64,20881,16186,108918,1
2,2016-01-06,WED,26920,1.009247e+10,374905.8415,41478,1.597113e+10,385050.5604,21953,1.647945e+10,750669.3999,33.97,1200.0,1925.43,1.64,20785,16063,110383,2
3,2016-01-07,THU,26624,1.060220e+10,398219.6596,49006,1.798140e+10,366922.4713,25917,1.872443e+10,722476.6906,33.27,1200.2,1904.33,1.64,20809,16003,104522,3
4,2016-01-08,FRI,28879,1.077601e+10,373143.3549,63847,2.480722e+10,388541.6089,33748,2.598923e+10,770096.9727,33.16,1199.5,1917.62,1.67,20523,15637,105499,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1422,2019-11-26,TUE,22295,7.106925e+09,318767.6722,33929,1.501046e+10,442407.9889,18388,1.618648e+10,880274.0396,58.41,1177.5,2121.35,1.48,26679,16567,47670,1422
1423,2019-11-27,WED,25134,8.100108e+09,322276.9318,35642,1.487445e+10,417329.2423,19304,1.609138e+10,833577.7006,58.11,1177.5,2127.85,1.46,26250,16438,46986,1423
1424,2019-11-28,THU,24742,8.116231e+09,328034.5456,35459,1.549214e+10,436902.7653,19087,1.674230e+10,877157.3571,58.11,1179.0,2118.60,1.43,26059,16470,46937,1424
1425,2019-11-29,FRI,29624,9.924881e+09,335028.3726,41190,1.766483e+10,428862.0799,22167,1.915980e+10,864338.9685,55.17,1180.0,2087.96,1.39,25391,16048,47426,1425


In [15]:
df.keys()

Index(['Date', 'Account DOW', 'REV OBD', 'OBD NET+FSC_KRW', 'OBD A/R_KRW',
       'REV CPN', 'CPN NET+FSC_KRW', 'CPN A/R_KRW', 'REV TKT',
       'TKT NET+FSC_KRW', 'TKT A/R_KRW', 'WTI', 'exchanges', 'kospi', 'rates',
       'stock_a', 'stock_k', 'stock_kkj', 'time_index'],
      dtype='object')

In [16]:
# Scaling
scaler = MinMaxScaler()
scale_col = ['REV OBD', 'OBD NET+FSC_KRW', 'OBD A/R_KRW', 'REV CPN',
             'CPN NET+FSC_KRW', 'CPN A/R_KRW', 'REV TKT', 'TKT NET+FSC_KRW',
             'TKT A/R_KRW', 'WTI', 'exchanges', 'kospi', 'rates',
             'stock_a', 'stock_k', 'stock_kkj']
scaled = scaler.fit_transform(df[scale_col])

In [17]:
scaled.shape

(1427, 16)

In [18]:
tmp_df_1 = df[['time_index', 'Date', 'Account DOW']]
columns = ['REV OBD', 'OBD NET+FSC_KRW', 'OBD A/R_KRW', 'REV CPN',
           'CPN NET+FSC_KRW', 'CPN A/R_KRW', 'REV TKT', 'TKT NET+FSC_KRW',
           'TKT A/R_KRW', 'WTI', 'exchanges', 'kospi', 'rates',
           'stock_a', 'stock_k', 'stock_kkj']
tmp_df_2 = pd.DataFrame(scaled, columns=columns)
res_data = pd.concat([tmp_df_1, tmp_df_2], axis=1)
res_data

Unnamed: 0,time_index,Date,Account DOW,REV OBD,OBD NET+FSC_KRW,OBD A/R_KRW,REV CPN,CPN NET+FSC_KRW,CPN A/R_KRW,REV TKT,TKT NET+FSC_KRW,TKT A/R_KRW,WTI,exchanges,kospi,rates,stock_a,stock_k,stock_kkj
0,0,2016-01-04,MON,0.346259,0.326659,0.374062,0.382432,0.410254,0.477330,0.390738,0.500458,0.532091,0.210159,0.722372,0.109423,0.439024,0.223500,0.211446,0.847607
1,1,2016-01-05,TUE,0.218042,0.241049,0.366526,0.407230,0.432465,0.456675,0.411128,0.506018,0.533700,0.194422,0.722372,0.124851,0.447154,0.200986,0.186607,0.847607
2,2,2016-01-06,WED,0.304836,0.241318,0.282653,0.442056,0.450237,0.373548,0.443735,0.511139,0.532709,0.154582,0.778976,0.118166,0.447154,0.197042,0.175416,0.865189
3,3,2016-01-07,THU,0.293484,0.271334,0.339789,0.528126,0.511364,0.275121,0.530367,0.523532,0.529736,0.140637,0.780054,0.090509,0.447154,0.198028,0.169957,0.794849
4,4,2016-01-08,FRI,0.379972,0.281569,0.278334,0.697807,0.718921,0.392503,0.701510,0.563636,0.534758,0.138446,0.776280,0.107929,0.471545,0.186278,0.136657,0.806574
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1422,1422,2019-11-26,TUE,0.127450,0.065511,0.145074,0.355746,0.421025,0.684971,0.365824,0.509522,0.546376,0.641434,0.657682,0.374972,0.317073,0.439195,0.221272,0.112549
1423,1423,2019-11-27,WED,0.236336,0.123995,0.153675,0.375332,0.416889,0.548806,0.385843,0.508997,0.541452,0.635458,0.657682,0.383492,0.300813,0.421569,0.209535,0.104340
1424,1424,2019-11-28,THU,0.221302,0.124945,0.167785,0.373239,0.435672,0.655080,0.381100,0.512590,0.546047,0.635458,0.665768,0.371368,0.276423,0.413722,0.212447,0.103752
1425,1425,2019-11-29,FRI,0.408545,0.231449,0.184925,0.438763,0.501738,0.611423,0.448412,0.525936,0.544695,0.576892,0.671159,0.331206,0.243902,0.386278,0.174051,0.109620


In [19]:
res_data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1427 entries, 0 to 1426
Data columns (total 19 columns):
 #   Column           Non-Null Count  Dtype         
---  ------           --------------  -----         
 0   time_index       1427 non-null   int64         
 1   Date             1427 non-null   datetime64[ns]
 2   Account DOW      1427 non-null   object        
 3   REV OBD          1427 non-null   float64       
 4   OBD NET+FSC_KRW  1427 non-null   float64       
 5   OBD A/R_KRW      1427 non-null   float64       
 6   REV CPN          1427 non-null   float64       
 7   CPN NET+FSC_KRW  1427 non-null   float64       
 8   CPN A/R_KRW      1427 non-null   float64       
 9   REV TKT          1427 non-null   float64       
 10  TKT NET+FSC_KRW  1427 non-null   float64       
 11  TKT A/R_KRW      1427 non-null   float64       
 12  WTI              1427 non-null   float64       
 13  exchanges        1427 non-null   float64       
 14  kospi            1427 non-null   float64

In [20]:
res_data.describe()

Unnamed: 0,time_index,REV OBD,OBD NET+FSC_KRW,OBD A/R_KRW,REV CPN,CPN NET+FSC_KRW,CPN A/R_KRW,REV TKT,TKT NET+FSC_KRW,TKT A/R_KRW,WTI,exchanges,kospi,rates,stock_a,stock_k,stock_kkj
count,1427.0,1427.0,1427.0,1427.0,1427.0,1427.0,1427.0,1427.0,1427.0,1427.0,1427.0,1427.0,1427.0,1427.0,1427.0,1427.0,1427.0
mean,713.0,0.360195,0.282904,0.29238,0.314459,0.33552,0.418474,0.320825,0.492223,0.534719,0.552898,0.452095,0.456563,0.516457,0.282264,0.489028,0.46703
std,412.083729,0.149388,0.133402,0.106898,0.198184,0.217978,0.174246,0.20016,0.04634,0.020388,0.19152,0.210885,0.246101,0.253865,0.12856,0.236769,0.22652
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,356.5,0.252905,0.189293,0.218024,0.074551,0.064994,0.293565,0.078622,0.439682,0.528946,0.433267,0.339623,0.259166,0.308943,0.187223,0.311255,0.331777
50%,713.0,0.348752,0.267715,0.282006,0.374474,0.404387,0.426462,0.383002,0.505914,0.53583,0.539841,0.423181,0.383492,0.504065,0.244125,0.489673,0.441974
75%,1069.5,0.456392,0.355028,0.355314,0.44324,0.480351,0.545623,0.452379,0.520757,0.540755,0.683367,0.613208,0.688987,0.707317,0.358833,0.663725,0.623686
max,1426.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [21]:
res_data[['time_index']].isna().sum() 

time_index    0
dtype: int64

In [22]:
res_data['time_index'] = res_data['time_index'].astype(int)

In [23]:
res_data.shape

(1427, 19)

In [24]:
max_encoder_length = 60  # encoder LSTM에 들어가는 길이
max_prediction_length = 20  # decoder의 길이
training_cutoff = res_data['time_index'].max() - max_prediction_length

In [25]:
training_cutoff

1406

In [26]:
res_data

Unnamed: 0,time_index,Date,Account DOW,REV OBD,OBD NET+FSC_KRW,OBD A/R_KRW,REV CPN,CPN NET+FSC_KRW,CPN A/R_KRW,REV TKT,TKT NET+FSC_KRW,TKT A/R_KRW,WTI,exchanges,kospi,rates,stock_a,stock_k,stock_kkj
0,0,2016-01-04,MON,0.346259,0.326659,0.374062,0.382432,0.410254,0.477330,0.390738,0.500458,0.532091,0.210159,0.722372,0.109423,0.439024,0.223500,0.211446,0.847607
1,1,2016-01-05,TUE,0.218042,0.241049,0.366526,0.407230,0.432465,0.456675,0.411128,0.506018,0.533700,0.194422,0.722372,0.124851,0.447154,0.200986,0.186607,0.847607
2,2,2016-01-06,WED,0.304836,0.241318,0.282653,0.442056,0.450237,0.373548,0.443735,0.511139,0.532709,0.154582,0.778976,0.118166,0.447154,0.197042,0.175416,0.865189
3,3,2016-01-07,THU,0.293484,0.271334,0.339789,0.528126,0.511364,0.275121,0.530367,0.523532,0.529736,0.140637,0.780054,0.090509,0.447154,0.198028,0.169957,0.794849
4,4,2016-01-08,FRI,0.379972,0.281569,0.278334,0.697807,0.718921,0.392503,0.701510,0.563636,0.534758,0.138446,0.776280,0.107929,0.471545,0.186278,0.136657,0.806574
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1422,1422,2019-11-26,TUE,0.127450,0.065511,0.145074,0.355746,0.421025,0.684971,0.365824,0.509522,0.546376,0.641434,0.657682,0.374972,0.317073,0.439195,0.221272,0.112549
1423,1423,2019-11-27,WED,0.236336,0.123995,0.153675,0.375332,0.416889,0.548806,0.385843,0.508997,0.541452,0.635458,0.657682,0.383492,0.300813,0.421569,0.209535,0.104340
1424,1424,2019-11-28,THU,0.221302,0.124945,0.167785,0.373239,0.435672,0.655080,0.381100,0.512590,0.546047,0.635458,0.665768,0.371368,0.276423,0.413722,0.212447,0.103752
1425,1425,2019-11-29,FRI,0.408545,0.231449,0.184925,0.438763,0.501738,0.611423,0.448412,0.525936,0.544695,0.576892,0.671159,0.331206,0.243902,0.386278,0.174051,0.109620


In [27]:
res_data['market'] = 'OBD'
res_data

Unnamed: 0,time_index,Date,Account DOW,REV OBD,OBD NET+FSC_KRW,OBD A/R_KRW,REV CPN,CPN NET+FSC_KRW,CPN A/R_KRW,REV TKT,TKT NET+FSC_KRW,TKT A/R_KRW,WTI,exchanges,kospi,rates,stock_a,stock_k,stock_kkj,market
0,0,2016-01-04,MON,0.346259,0.326659,0.374062,0.382432,0.410254,0.477330,0.390738,0.500458,0.532091,0.210159,0.722372,0.109423,0.439024,0.223500,0.211446,0.847607,OBD
1,1,2016-01-05,TUE,0.218042,0.241049,0.366526,0.407230,0.432465,0.456675,0.411128,0.506018,0.533700,0.194422,0.722372,0.124851,0.447154,0.200986,0.186607,0.847607,OBD
2,2,2016-01-06,WED,0.304836,0.241318,0.282653,0.442056,0.450237,0.373548,0.443735,0.511139,0.532709,0.154582,0.778976,0.118166,0.447154,0.197042,0.175416,0.865189,OBD
3,3,2016-01-07,THU,0.293484,0.271334,0.339789,0.528126,0.511364,0.275121,0.530367,0.523532,0.529736,0.140637,0.780054,0.090509,0.447154,0.198028,0.169957,0.794849,OBD
4,4,2016-01-08,FRI,0.379972,0.281569,0.278334,0.697807,0.718921,0.392503,0.701510,0.563636,0.534758,0.138446,0.776280,0.107929,0.471545,0.186278,0.136657,0.806574,OBD
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1422,1422,2019-11-26,TUE,0.127450,0.065511,0.145074,0.355746,0.421025,0.684971,0.365824,0.509522,0.546376,0.641434,0.657682,0.374972,0.317073,0.439195,0.221272,0.112549,OBD
1423,1423,2019-11-27,WED,0.236336,0.123995,0.153675,0.375332,0.416889,0.548806,0.385843,0.508997,0.541452,0.635458,0.657682,0.383492,0.300813,0.421569,0.209535,0.104340,OBD
1424,1424,2019-11-28,THU,0.221302,0.124945,0.167785,0.373239,0.435672,0.655080,0.381100,0.512590,0.546047,0.635458,0.665768,0.371368,0.276423,0.413722,0.212447,0.103752,OBD
1425,1425,2019-11-29,FRI,0.408545,0.231449,0.184925,0.438763,0.501738,0.611423,0.448412,0.525936,0.544695,0.576892,0.671159,0.331206,0.243902,0.386278,0.174051,0.109620,OBD


In [29]:
training = TimeSeriesDataSet(
    res_data[lambda x: x.time_index <= training_cutoff],
    time_idx = 'time_index',
    target = 'REV OBD',
    categorical_encoders = {'Account DOW':NaNLabelEncoder().fit(res_data['Account DOW']), 
                            'market':NaNLabelEncoder().fit(res_data['market'])},  # market 컬럼에 nan 값이 있으면 NaNLabelEncoder()로 처리하라는 의미임.
    group_ids = ['market'],
    time_varying_unknown_reals = ['REV OBD'],
    # time_varying_known_reals = ['candle_acc_trade_volume'],
    max_encoder_length = max_encoder_length,
    max_prediction_length = max_prediction_length,
    allow_missing_timesteps=True,
)

In [30]:
validation = TimeSeriesDataSet.from_dataset(   # from_dataseet을 이용하면 기존 dataset의 기 입력 속성 이용 가능
    training,
    res_data,  # 실제 적용 데이터셋
    min_prediction_idx = training_cutoff + 1
)

In [31]:
batch_size = 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0)

### Calculate baseline error

In [55]:
val_dataloader

<torch.utils.data.dataloader.DataLoader at 0x7f71a47fca50>

In [38]:
actuals = torch.cat([y[0] for x, y in iter(val_dataloader)])
[y[0].shape for x, y in iter(val_dataloader)]

[torch.Size([1, 20])]

In [39]:
baseline_predictions = Baseline().predict(val_dataloader)
baseline_predictions

tensor([[0.3446, 0.3446, 0.3446, 0.3446, 0.3446, 0.3446, 0.3446, 0.3446, 0.3446,
         0.3446, 0.3446, 0.3446, 0.3446, 0.3446, 0.3446, 0.3446, 0.3446, 0.3446,
         0.3446, 0.3446]])

In [40]:
SMAPE()(baseline_predictions, actuals)

tensor(0.2962)

In [41]:
pl.seed_everything(42)
trainer = pl.Trainer(gpus=1, gradient_clip_val=.1)  # gradient가 폭발하는 것을 막기 위해 0.1이 넘어서면 잘라버림, gpus는 사용하는 gpu의 개수
net = DeepAR.from_dataset(   # 데이터셋에 적용한 설정을 그대로 가져와 자동으로 DeepAR 네트워크에 적용
    training,
    learning_rate = 3e-2,
    hidden_size = 200,
    rnn_layers =2,
    loss = MultivariateNormalDistributionLoss()   # class가 아닌 함수를 전달해 줘야 함
)

Global seed set to 42
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [42]:
early_stop_callback = EarlyStopping(monitor='val_loss', min_delta=1e-4, patience=10, verbose=True, mode='min')
lr_logger = LearningRateMonitor()

In [43]:
trainer = pl.Trainer(
    max_epochs = 10,
    gpus = 1,
    weights_summary = 'top',  # 'top'은 무슨 의미?    
    gradient_clip_val = .01,
    callbacks = [lr_logger, early_stop_callback],
    limit_train_batches = 30,
    enable_checkpointing = True,
    # auto_lr_find = True
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [44]:
trainer.fit(
    net,
    train_dataloaders = train_dataloader,
    val_dataloaders = val_dataloader,
)

Missing logger folder: /content/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                   | Type                               | Params
------------------------------------------------------------------------------
0 | loss                   | MultivariateNormalDistributionLoss | 0     
1 | logging_metrics        | ModuleList                         | 0     
2 | embeddings             | MultiEmbedding                     | 0     
3 | rnn                    | LSTM                               | 484 K 
4 | distribution_projector | Linear                             | 2.4 K 
------------------------------------------------------------------------------
486 K     Trainable params
0         Non-trainable params
486 K     Total params
1.946     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved. New best score: -0.834


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [45]:
best_model_path = trainer.checkpoint_callback.best_model_path
best_model = net.load_from_checkpoint(best_model_path)

In [46]:
best_model_path

'/content/lightning_logs/version_0/checkpoints/epoch=9-step=100.ckpt'

In [47]:
best_model

DeepAR(
  (loss): MultivariateNormalDistributionLoss()
  (logging_metrics): ModuleList(
    (0): SMAPE()
    (1): MAE()
    (2): RMSE()
    (3): MAPE()
    (4): MASE()
  )
  (embeddings): MultiEmbedding(
    (embeddings): ModuleDict()
  )
  (rnn): LSTM(1, 200, num_layers=2, batch_first=True, dropout=0.1)
  (distribution_projector): Linear(in_features=200, out_features=12, bias=True)
)

In [48]:
actuals = torch.cat([y[0] for x, y in iter(val_dataloader)])
predictions = net.predict(val_dataloader)
(actuals - predictions).abs().mean()

tensor(0.0824)

In [49]:
actuals.shape

torch.Size([1, 20])

In [50]:
predictions.shape

torch.Size([1, 20])

In [51]:
SMAPE()(predictions, actuals)

tensor(0.2814)

In [63]:
predictions[0]

tensor([0.3249, 0.3274, 0.3091, 0.3029, 0.3142, 0.3131, 0.3218, 0.3188, 0.3195,
        0.3070, 0.3165, 0.3097, 0.3452, 0.3548, 0.3301, 0.3162, 0.3103, 0.3163,
        0.3369, 0.3387])

In [54]:
for i in np.arange(10, 20):
    plt.figure(figsize=(4, 3))
    plt.plot(predictions[i], label='prediction')
    plt.plot(actuals[i], label='actual')
    plt.suptitle('Timeseries Prediction')
    plt.legend()
    plt.show()

IndexError: ignored

<Figure size 288x216 with 0 Axes>

In [None]:
raw_predictions, x = net.predict(val_dataloader, return_x=True, n_samples=100)
for i in np.arange(60, 70):
  plt.figure(figsize=(4, 3))
  plt.plot(raw_predictions[i], label='prediction')
  plt.plot(actuals[i], label='actual')
  plt.suptitle('Timeseries Prediction')
  plt.legend()
  plt.show()

IndexError: ignored

<Figure size 288x216 with 0 Axes>