<a href="https://colab.research.google.com/github/Krankile/npmf/blob/main/notebooks/training_loop.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

## Kernel setup

In [11]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
%%capture
!pip install wandb more_itertools
!git clone https://github.com/Krankile/npmf.git

In [13]:
# https://wandb.ai/authorize
!wandb login

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


## General setup

In [14]:
%%capture
!cd npmf && git pull

import math
import multiprocessing
import os
import pickle
import random
from collections import Counter, defaultdict
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta
from operator import itemgetter
from typing import Callable, List, Tuple
from functools import partial
from glob import glob
from enum import Enum
from pathlib import Path

from more_itertools import chunked

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

from npmf.utils import Problem
from npmf.utils.colors import main, main2, main3
from npmf.utils.dataset import EraDataset, EraController
from npmf.utils.dataset.utils import clamp_and_slice
from npmf.utils.dtypes import fundamental_types
from npmf.utils.eikon import column_mapping
from npmf.utils.tests.utils import pickle_df
from npmf.utils.wandb import get_datasets, put_dataset, put_nn_model, get_processed_data
from npmf.utils.training import EarlyStop, to_device, TqdmPostFix, loss_fns, get_naive_pred
from npmf.utils.models import models

from numpy.ma.core import outerproduct
from pandas.tseries.offsets import BDay, Day
from sklearn.preprocessing import MinMaxScaler, minmax_scale
from torch import nn
from torch.utils.data import DataLoader, Dataset, ConcatDataset

import wandb as wb

In [15]:
# Turn every NumPy floating point warning into an exception
np.seterr(all="raise")

# Plot styling: project color cycle and figure size
mpl.rc("axes", prop_cycle=mpl.cycler(color=[main, main2, main3, "black"]))
mpl.rc("figure", figsize=(6, 4))  # (6, 4) is default and used in the paper

In [16]:
# Use the GPU whenever PyTorch reports one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Using {device} device")

Using cuda device


In [17]:
!nvidia-smi

Wed Jun  8 12:26:41 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   42C    P0    26W / 250W |      2MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [18]:
pre_proc_data_dir = None

# Seed every random number generator imported by this notebook, not only
# NumPy, so that weight initialisation and any `random`-based sampling are
# repeatable between runs as well.
SEED = 69
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


# Get some data

In [19]:
%%capture
# Flip to True to force a fresh download even if the frames are already loaded
reload_data = False

# Skip the download when the data frames already live in the kernel namespace
if reload_data or "stock_df" not in vars():
    names = ["stock-data:final", "fundamental-data:final", "meta-data:final", "macro-data:final"]

    stock_df, fundamental_df, meta_df, macro_df = get_datasets(names=names, project="master")

    stock_df = stock_df.drop(columns=["close_price", "currency"]).astype({"market_cap": np.float32})
    fundamental_df = fundamental_df.drop(columns="period_end_date").astype(fundamental_types)
    # Cast every column except the first one to float32. Assigning the cast
    # values back through `.iloc[:, 1:] = ...` may write them into the existing
    # columns and keep the old dtype, so build a new frame with `astype`.
    macro_df = macro_df.astype({col: np.float32 for col in macro_df.columns[1:]})

# Create the loop! (Like Hans Gude Gudesen)

In [20]:
# TODO: Check if it's necessary to calculate naive loss every epoch
def get_epoch_loss(model, optimizer, dataloader, loss_fn, device, run_type, conf) -> Tuple[List[float], List[float], np.ndarray]:
    """Run one pass over ``dataloader`` and collect the per-batch losses.

    Args:
        model: Module called as ``model(data, meta_cont, meta_cat)``.
        optimizer: Optimizer stepped after every batch when ``run_type`` is
            ``"train"``; not touched otherwise (may be ``None``).
        dataloader: Iterable yielding ``(data, meta_cont, meta_cat, target)``
            batches once passed through ``to_device``.
        loss_fn: Callable invoked as ``loss_fn(target, prediction)``.
        device: Device handed to ``to_device`` and ``get_naive_pred``.
        run_type: ``"train"`` enables the backward pass and optimizer step;
            any other value only evaluates.
        conf: Run configuration forwarded to ``get_naive_pred``.

    Returns:
        A tuple ``(model_losses, naive_losses, y_preds)``: two lists with one
        float per batch, and the predictions of all batches concatenated
        along axis 0 as a NumPy array.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    is_training = run_type == "train"

    model_losses = []
    naive_losses = []
    y_preds = []
    for data, meta_cont, meta_cat, target in to_device(dataloader, device):
        if is_training:
            optimizer.zero_grad()

        # The naive baseline is never trained, so it needs no gradients
        with torch.no_grad():
            naive_pred = get_naive_pred(data, target, device, conf)
            naive_loss = loss_fn(target.clone(), naive_pred)

        # Only build the autograd graph when a backward pass will follow;
        # validation/inference passes then avoid the extra memory and time.
        with torch.set_grad_enabled(is_training):
            y_pred: torch.Tensor = model(data, meta_cont, meta_cat)
            loss = loss_fn(target, y_pred)

        model_losses.append(loss.item())
        naive_losses.append(naive_loss.item())
        y_preds.append(y_pred.detach().cpu().numpy())

        if is_training:
            loss.backward()
            optimizer.step()

    if not y_preds:
        # np.concatenate would raise a less helpful ValueError here
        raise ValueError(f"dataloader for run_type={run_type!r} yielded no batches")

    return model_losses, naive_losses, np.concatenate(y_preds, axis=0)

In [21]:
def eras_ahead_loss(model, data_loaders, conf):
    """Evaluate ``model`` and the naive baseline on several data loaders.

    Every loader is run through ``get_epoch_loss`` in ``"inference"`` mode
    with the ``"mape_2"`` loss on the module-level ``device``, without
    gradient tracking.

    Returns:
        Two NumPy arrays ``(model_losses, naive_losses)`` holding the
        per-batch losses of all loaders, in loader order.
    """
    model_batches, naive_batches = [], []

    with torch.no_grad():
        for era_loader in data_loaders:
            per_batch_model, per_batch_naive, _ = get_epoch_loss(
                model, None, era_loader, loss_fns["mape_2"], device, "inference", conf
            )
            model_batches.extend(per_batch_model)
            naive_batches.extend(per_batch_naive)

    return np.array(model_batches), np.array(naive_batches)

In [22]:
def train_one_era(run, model, optimizer, data_train, data_val, stopper, losses, device, conf, pbar):
    """Train ``model`` on one era until early stopping or ``conf.max_epochs``.

    Each epoch runs a training pass over ``data_train`` followed by a
    validation pass over ``data_val``. The mean loss of each pass is appended
    to ``losses["train"]`` / ``losses["val"]`` and logged to ``run``. After
    the era, the number of epochs that were run is appended to
    ``losses["epoch_lens"]``, whether the era ended by early stopping or by
    reaching ``conf.max_epochs``.

    Args:
        run: Logger exposing ``log(dict)`` (the wandb run in this notebook).
        model: Module to train; switched between train and eval mode per pass.
        optimizer: Optimizer used for the training pass.
        data_train: Data loader for the training pass.
        data_val: Data loader for the validation pass.
        stopper: Callable given the per-batch validation losses of the epoch;
            a truthy return value ends the era.
        losses: Dict with list values under ``"train"``, ``"val"`` and
            ``"epoch_lens"``; mutated in place.
        device: Device forwarded to ``get_epoch_loss``.
        conf: Configuration providing ``max_epochs``, ``train_loss`` and
            ``val_loss`` (keys into ``loss_fns``).
        pbar: Progress bar exposing ``update_postfix(dict)``.

    Returns:
        ``(train_batch_losses, val_batch_losses)`` of the last epoch run;
        two empty lists if no epoch was run.
    """
    # Defined up front so the return value exists even when max_epochs is 0
    epoch_losses = dict(train=[], val=[])
    epochs_run = 0
    phases = {"train": data_train, "val": data_val}

    for epoch in range(conf.max_epochs):
        epochs_run = epoch + 1
        epoch_losses = dict(train=[], val=[])

        pbar.update_postfix({"epoch": epoch})
        for run_type, dataloader in phases.items():
            model.train(run_type == "train")

            epoch_model_loss, naive_losses, y_preds = get_epoch_loss(model, optimizer, dataloader, loss_fns[conf[f"{run_type}_loss"]], device, run_type, conf)
            epoch_losses[run_type] += epoch_model_loss

            epoch_loss = np.mean(epoch_losses[run_type])
            losses[run_type].append(epoch_loss)

            run.log({f"epoch_{run_type}": epoch_loss, "epoch": epoch, "ticker_var": y_preds.std(axis=0).mean(), "self_var": y_preds.std(axis=1).mean()})

        # `naive_losses` comes from the validation pass, the last phase run
        pbar.update_postfix({"train_loss": np.mean(epoch_losses["train"]), "val_loss": np.mean(epoch_losses["val"]), "naive": np.mean(naive_losses)})

        # TODO: Implement checkpointing of the best model according to val_loss
        if stopper(epoch_losses["val"]):
            break

    # Record the era length on every exit path, not only on early stop, so
    # that `epoch_lens` stays aligned with the per-epoch loss lists.
    losses["epoch_lens"].append(epochs_run)

    return epoch_losses["train"], epoch_losses["val"]

In [40]:
def calculate_metrics(eras, model, train_losses, val_losses, i, run, conf, pbar):
    """Log era-level metrics comparing the model against the naive baseline.

    Both loader groups returned by ``eras.validation_loaders()`` are
    evaluated with ``eras_ahead_loss``. ``metric_loss`` is the average over
    the two groups of ``mean(model / (naive + 1e-6) - 1)``. All values are
    sent to ``run.log`` and ``metric_loss`` is also shown on ``pbar``.
    """
    loaders_infront, loaders_end = eras.validation_loaders()
    model_infront, naive_infront = eras_ahead_loss(model, loaders_infront, conf)
    model_end, naive_end = eras_ahead_loss(model, loaders_end, conf)

    # Small constant that keeps the ratio finite when the naive loss is zero
    eps = 1e-6
    relative_infront = np.mean(model_infront / (naive_infront + eps) - 1)
    relative_end = np.mean(model_end / (naive_end + eps) - 1)
    metric_loss = 0.5 * (relative_infront + relative_end)

    era_summary = {
        "era_train": np.mean(train_losses),
        "era_val": np.mean(val_losses),
        "model_infront": np.mean(model_infront),
        "naive_infront": np.mean(naive_infront),
        "model_end": np.mean(model_end),
        "naive_end": np.mean(naive_end),
        "metric_loss": metric_loss,
        "time": eras.date.timestamp(),
        "era": i,
    }
    run.log(era_summary)

    pbar.update_postfix({"metric_loss": metric_loss})

In [41]:
def train(config, project=None, entity=None, enablewb=True) -> Tuple[nn.Module, dict]:
    """Train a model era by era inside a Weights & Biases run.

    Reads the module-level ``stock_df``, ``fundamental_df``, ``meta_df``,
    ``macro_df`` and ``device``.

    Args:
        config: Configuration dict handed to ``wb.init``; the resulting
            ``run.config`` is what the training actually uses.
        project: wandb project name.
        entity: wandb entity name.
        enablewb: When False the run is started with ``mode="disabled"``.

    Returns:
        ``(model, losses)`` where ``losses`` is a dict with the lists
        ``"train"``, ``"val"`` and ``"epoch_lens"`` filled by
        ``train_one_era``.
    """

    mode = "online" if enablewb else "disabled"
    with wb.init(config=config, project=project, entity=entity, job_type="training", mode=mode) as run:

        conf = run.config
        print(conf)

        # Local on purpose: this does not modify the module-level variable
        # of the same name.
        pre_proc_data_dir = None
        if conf.use_pre_proc_data:
            pre_proc_data_dir = get_processed_data(run, conf=conf)

        # Store the resolved directory in the run config so that it travels
        # with `conf` into the model and era controller below.
        run.config.update(dict(pre_proc_data_dir=pre_proc_data_dir))
        conf = run.config

        model = models[conf.model](**conf).to(device)

        # TODO: Try decreasing the learning rate during training
        optimizer = torch.optim.Adam(model.parameters(), lr=conf.learning_rate)

        losses = dict(train=[], val=[], epoch_lens=[])

        eras = EraController(stock_df=stock_df, fundamental_df=fundamental_df, meta_df=meta_df, macro_df=macro_df, conf=conf, **conf)
        pbar = TqdmPostFix(eras, total=eras.total)
        eras.register_pbar(pbar)

        stopper = EarlyStop(conf.patience, conf.min_delta, model=(model if conf.checkpoint else None), pbar=pbar)

        for i, (data_train, data_val) in enumerate(pbar):

            train_losses, val_losses = train_one_era(
                run=run,
                model=model,
                optimizer=optimizer,
                data_train=data_train,
                data_val=data_val,
                # The stopper is reset at the start of every era
                stopper=stopper.reset(),
                losses=losses,
                device=device,
                conf=conf,
                pbar=pbar,
            )

            calculate_metrics(eras, model, train_losses, val_losses, i, run, conf, pbar)

        if conf.save_model:
            put_nn_model(model, run)

    return model, losses

In [42]:
def get_params_from_data(stock_df, fundamental_df, meta_df, macro_df, params):
    """Derive model size parameters from the loaded data frames.

    Args:
        stock_df: Accepted for a uniform signature; not read here.
        fundamental_df: Frame whose columns from ``"revenue"`` onwards
            determine the number of fundamental features.
        meta_df: Frame whose columns after the first, except
            ``"founding_year"``, are treated as categorical.
        macro_df: Frame whose columns after the first count as macro features.
        params: Dict providing ``"forecast_problem"``, optionally
            ``"feature_subset"``, and ``"forecast_w"`` for market cap
            forecasts.

    Returns:
        Dict with ``meta_cont_lens``, ``meta_cat_lens``, ``out_len`` and
        ``input_size``.
    """
    meta_cont_len = 1
    categorical_cols = [col for col in meta_df.iloc[:, 1:] if col != "founding_year"]
    # Number of distinct values per categorical column, plus one
    meta_cat_len = np.array([len(meta_df[col].unique()) for col in categorical_cols]) + 1

    stock_feats = 1
    macro_feats = macro_df.shape[1] - 1
    funda_feats = (fundamental_df.loc[:, "revenue":].shape[1] - 1) + 2

    # An explicit feature subset overrides the count derived from the frames
    feature_subset = params.get("feature_subset")
    if feature_subset is not None:
        n_features = len(feature_subset)
    else:
        n_features = stock_feats + macro_feats + funda_feats

    # Pair each cardinality with the ceiling of its fourth root
    meta_cat_lens = [(size, int(math.ceil(size**0.25))) for size in meta_cat_len]

    problem = params["forecast_problem"]
    if problem == Problem.volatility.name:
        out_len = 1
    elif problem == Problem.market_cap.name:
        out_len = params["forecast_w"]
    else:
        out_len = funda_feats

    return {
        "meta_cont_lens": (meta_cont_len, 1),
        "meta_cat_lens": meta_cat_lens,
        "out_len": out_len,
        "input_size": n_features,
    }

# Run the loop! (Like Odd-Geir Lademo)

In [44]:
def validate(config):
    """Assert that the forecast window ends before the test set begins.

    The window is ``config["forecast_w"]`` business days starting at
    ``config["end_date"]``; its last day must fall before 2019-01-01.

    Raises:
        AssertionError: If the last forecast day is on or after 2019-01-01.
    """
    forecast_days = pd.date_range(start=config["end_date"], periods=config["forecast_w"], freq="B")
    last_forecast_day = forecast_days[-1]
    test_set_start = pd.to_datetime("2019-01-01")

    assert last_forecast_day < test_set_start, "Training overlaps with test set"

In [50]:
forecast_problem = Problem.fundamentals

# Settings chosen by hand for this run
params_human = {
    "forecast_problem": forecast_problem.name,

    "cpus": 1,
    "training_w": 240,
    "forecast_w": forecast_problem.forecast_w.h60,
    "start_date": "2000-12-31",
    "end_date": "2018-06-30",
    "save_model": True,
    "batch_size": 512,
    "use_pre_proc_data": True,
    "clamp": 2.5,
    "dtype": "float32",

    "checkpoint": False,
    "feature_subset": None,

    "fundamental_targets": None,
}

# Both era controller presets are spelled out; the index at the end picks one
era_controller_params = {
    "sequential": {
        "mode": "sequential",
        "include_past": True,
        "queue_length": 3,
    },
    "random": {
        "mode": "random",
        "sample_size": 10,
        "distribution": ["uniform"][0],
        "max_samplings": 100,
    },
}[EraController.Mode.sequential]

# Training and architecture hyperparameters
params_wb = {
    "max_epochs": 100,
    "patience": 10,
    "min_delta": 0.0001,
    "learning_rate": 1e-5,

    "hd": 128,
    "dropout": 0.1,
    "num_layers": 6,
    "channels": 128,
    "kernel_size": 7,

    "meta_hd": 16,

    "model": "TcnV2",
    "activation": "leaky",

    "train_loss": forecast_problem.loss.smape,
    "val_loss": forecast_problem.loss.mape,
}

# Sizes that follow from the loaded data rather than from manual choices
params_from_data = get_params_from_data(stock_df, fundamental_df, meta_df, macro_df, {**params_human, **params_wb})

# Later groups override earlier ones on key collisions
config = dict(params_human)
config["era_controller"] = era_controller_params
config.update(params_wb)
config.update(params_from_data)

validate(config)

AssertionError: ignored

In [48]:
enablewb = True
sweepid = None  # "krankile/master/q8hau0w8"

if not sweepid:
    # No sweep configured: run a single training with the config built above
    model, losses = train(config=config, project="master-test", entity="krankile", enablewb=enablewb)

else:
    # Let the wandb sweep agent call `train` repeatedly
    count = 500 # number of runs to execute
    wb.agent(sweepid, partial(train,config=config, enablewb=enablewb), count=count)

{'forecast_problem': 'fundamentals', 'cpus': 1, 'training_w': 240, 'forecast_w': 60, 'start_date': '2000-12-31', 'end_date': '2018-06-30', 'save_model': True, 'batch_size': 512, 'use_pre_proc_data': True, 'clamp': 2.5, 'dtype': 'float32', 'checkpoint': False, 'feature_subset': None, 'fundamental_targets': None, 'era_controller': {'mode': 'random', 'sample_size': 10, 'distribution': 'uniform', 'max_samplings': 100}, 'max_epochs': 100, 'patience': 10, 'min_delta': 0.0001, 'learning_rate': 1e-05, 'hd': 128, 'dropout': 0.1, 'num_layers': 6, 'channels': 128, 'kernel_size': 7, 'meta_hd': 16, 'model': 'TcnV2', 'activation': 'leaky', 'train_loss': 'smape', 'val_loss': 'mape_2', 'meta_cont_lens': [1, 1], 'meta_cat_lens': [[110, 4], [6, 2], [91, 4], [285, 5], [3, 2], [5, 2], [7, 2], [14, 2], [58, 3]], 'out_len': 18, 'input_size': 37}


Sampling before 2018-06-30 00:00:00 [2/100]:   0%|          | 0/100 [00:00<?, ?it/s, epoch=0]

random_dates [Timestamp('2005-06-30 00:00:00', freq='M')
 Timestamp('2017-11-30 00:00:00', freq='M')
 Timestamp('2017-09-30 00:00:00', freq='M')
 Timestamp('2009-11-30 00:00:00', freq='M')
 Timestamp('2008-06-30 00:00:00', freq='M')
 Timestamp('2016-03-31 00:00:00', freq='M')
 Timestamp('2013-04-30 00:00:00', freq='M')
 Timestamp('2005-01-31 00:00:00', freq='M')
 Timestamp('2002-10-31 00:00:00', freq='M')
 Timestamp('2001-09-30 00:00:00', freq='M')]


Sampling before 2018-06-30 00:00:00 [2/100]:   1%|          | 1/100 [00:22<36:47, 22.29s/it, epoch=10, train_loss=1.37, val_loss=1e+4, naive=570, triggers=10/10, best_loss=1e+4, metric_loss=162]

random_dates [Timestamp('2016-04-30 00:00:00', freq='M')
 Timestamp('2004-02-29 00:00:00', freq='M')
 Timestamp('2008-12-31 00:00:00', freq='M')
 Timestamp('2008-04-30 00:00:00', freq='M')
 Timestamp('2012-08-31 00:00:00', freq='M')
 Timestamp('2014-09-30 00:00:00', freq='M')
 Timestamp('2009-10-31 00:00:00', freq='M')
 Timestamp('2007-03-31 00:00:00', freq='M')
 Timestamp('2005-05-31 00:00:00', freq='M')
 Timestamp('2001-03-31 00:00:00', freq='M')]


Sampling before 2018-06-30 00:00:00 [3/100]:   2%|▏         | 2/100 [00:49<40:45, 24.95s/it, epoch=10, train_loss=1.16, val_loss=1e+4, naive=570, triggers=10/10, best_loss=1e+4, metric_loss=162]

random_dates [Timestamp('2007-08-31 00:00:00', freq='M')
 Timestamp('2014-11-30 00:00:00', freq='M')
 Timestamp('2002-12-31 00:00:00', freq='M')
 Timestamp('2001-02-28 00:00:00', freq='M')
 Timestamp('2004-03-31 00:00:00', freq='M')
 Timestamp('2013-08-31 00:00:00', freq='M')
 Timestamp('2015-07-31 00:00:00', freq='M')
 Timestamp('2011-01-31 00:00:00', freq='M')
 Timestamp('2008-06-30 00:00:00', freq='M')
 Timestamp('2007-03-31 00:00:00', freq='M')]


Sampling before 2018-06-30 00:00:00 [5/100]:   3%|▎         | 3/100 [01:19<43:58, 27.20s/it, epoch=0, train_loss=1.14, val_loss=1e+4, naive=570, triggers=10/10, best_loss=1e+4, metric_loss=156] 

random_dates [Timestamp('2004-08-31 00:00:00', freq='M')
 Timestamp('2013-07-31 00:00:00', freq='M')
 Timestamp('2003-11-30 00:00:00', freq='M')
 Timestamp('2015-03-31 00:00:00', freq='M')
 Timestamp('2012-10-31 00:00:00', freq='M')
 Timestamp('2013-05-31 00:00:00', freq='M')
 Timestamp('2015-09-30 00:00:00', freq='M')
 Timestamp('2017-06-30 00:00:00', freq='M')
 Timestamp('2006-09-30 00:00:00', freq='M')
 Timestamp('2007-04-30 00:00:00', freq='M')]


Sampling before 2018-06-30 00:00:00 [6/100]:   4%|▍         | 4/100 [01:44<42:34, 26.61s/it, epoch=0, train_loss=1.11, val_loss=1e+4, naive=570, triggers=10/10, best_loss=1e+4, metric_loss=151] 

random_dates [Timestamp('2012-10-31 00:00:00', freq='M')
 Timestamp('2010-03-31 00:00:00', freq='M')
 Timestamp('2002-02-28 00:00:00', freq='M')
 Timestamp('2008-07-31 00:00:00', freq='M')
 Timestamp('2016-09-30 00:00:00', freq='M')
 Timestamp('2010-11-30 00:00:00', freq='M')
 Timestamp('2015-08-31 00:00:00', freq='M')
 Timestamp('2015-05-31 00:00:00', freq='M')
 Timestamp('2002-04-30 00:00:00', freq='M')
 Timestamp('2008-09-30 00:00:00', freq='M')]


Sampling before 2018-06-30 00:00:00 [7/100]:   5%|▌         | 5/100 [02:15<44:17, 27.97s/it, epoch=0, train_loss=1.12, val_loss=1e+4, naive=570, triggers=10/10, best_loss=1e+4, metric_loss=151] 

random_dates [Timestamp('2009-10-31 00:00:00', freq='M')
 Timestamp('2009-10-31 00:00:00', freq='M')
 Timestamp('2006-03-31 00:00:00', freq='M')
 Timestamp('2010-06-30 00:00:00', freq='M')
 Timestamp('2001-11-30 00:00:00', freq='M')
 Timestamp('2018-02-28 00:00:00', freq='M')
 Timestamp('2010-05-31 00:00:00', freq='M')
 Timestamp('2002-06-30 00:00:00', freq='M')
 Timestamp('2004-08-31 00:00:00', freq='M')
 Timestamp('2011-07-31 00:00:00', freq='M')]


Sampling before 2018-06-30 00:00:00 [8/100]:   6%|▌         | 6/100 [02:45<44:56, 28.68s/it, epoch=0, train_loss=1.15, val_loss=1e+4, naive=570, triggers=10/10, best_loss=1e+4, metric_loss=161] 

random_dates [Timestamp('2002-12-31 00:00:00', freq='M')
 Timestamp('2018-02-28 00:00:00', freq='M')
 Timestamp('2014-01-31 00:00:00', freq='M')
 Timestamp('2011-08-31 00:00:00', freq='M')
 Timestamp('2004-09-30 00:00:00', freq='M')
 Timestamp('2016-10-31 00:00:00', freq='M')
 Timestamp('2004-05-31 00:00:00', freq='M')
 Timestamp('2010-06-30 00:00:00', freq='M')
 Timestamp('2015-06-30 00:00:00', freq='M')
 Timestamp('2007-11-30 00:00:00', freq='M')]


Sampling before 2018-06-30 00:00:00 [9/100]:   7%|▋         | 7/100 [03:17<46:09, 29.78s/it, epoch=0, train_loss=1.12, val_loss=1e+4, naive=570, triggers=10/10, best_loss=1e+4, metric_loss=152] 

random_dates [Timestamp('2013-11-30 00:00:00', freq='M')
 Timestamp('2016-01-31 00:00:00', freq='M')
 Timestamp('2004-12-31 00:00:00', freq='M')
 Timestamp('2016-12-31 00:00:00', freq='M')
 Timestamp('2005-01-31 00:00:00', freq='M')
 Timestamp('2008-12-31 00:00:00', freq='M')
 Timestamp('2014-12-31 00:00:00', freq='M')
 Timestamp('2011-01-31 00:00:00', freq='M')
 Timestamp('2015-06-30 00:00:00', freq='M')
 Timestamp('2001-07-31 00:00:00', freq='M')]


Sampling before 2018-06-30 00:00:00 [9/100]:   7%|▋         | 7/100 [03:22<44:52, 28.95s/it, epoch=2, train_loss=1.12, val_loss=1e+4, naive=570, triggers=1/10, best_loss=1e+4, metric_loss=152]


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▂▄▅▇█▂▃▅▆▇▁▂▄▅▇█▂▃▅▆▇▁▂▄▅▇█▂▃▅▆▇▁▂▄▅▇█▂
epoch_train,███▇▇▆▅▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▂▁▂▁▁▁▂▁▁▂▂▂▂▁▁▁
epoch_val,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
era,▁▂▃▅▆▇█
era_train,█▂▂▁▁▂▁
era_val,▁▁▁▁▁▁▁
metric_loss,██▅▁▁█▂
model_end,██▅▁▁█▂
model_infront,██▅▁▁█▂
naive_end,▁▁▁▁▁▁▁

0,1
epoch,1.0
epoch_train,1.11651
epoch_val,10000.0
era,6.0
era_train,1.124
era_val,10000.0
metric_loss,152.0412
model_end,8444.28247
model_infront,8316.06406
naive_end,178.58788


KeyboardInterrupt: ignored