<a href="https://colab.research.google.com/github/Krankile/npmf/blob/main/notebooks/training_loop.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Setup

##Kernel setup

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
%%capture
!pip install wandb
!git clone https://github.com/Krankile/npmf.git

In [5]:
%%capture
!cd npmf && git pull

In [6]:
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mkjartan[0m ([33mkrankile[0m). Use [1m`wandb login --relogin`[0m to force relogin


##General setup

In [7]:
import os
from collections import defaultdict
from collections import Counter
from datetime import datetime
from datetime import timedelta
from operator import itemgetter

import numpy as np
from numpy.ma.core import outerproduct
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from tqdm import tqdm

import wandb as wb

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from sklearn.preprocessing import MinMaxScaler

from npmf.utils.colors import main, main2, main3
from npmf.utils.wandb import get_dataset, put_dataset
from npmf.utils.eikon import column_mapping

In [8]:
mpl.rcParams['axes.prop_cycle'] = mpl.cycler(color=[main, main2, main3, "black"])
mpl.rcParams['figure.figsize'] = (6, 4)  # (6, 4) is default and used in the paper

In [9]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cpu device


In [10]:
np.random.seed(420)

# Create a Neural network class

In [11]:
class MultivariateNetwork(nn.Module):
    def __init__(self, lag_len, cat_len, out_len, hidden_dim):
        super().__init__()

        self.pre = nn.Sequential(
            nn.Linear(lag_len, hidden_dim),
            nn.ReLU(),
        )

        self.predict = nn.Sequential(
            nn.Linear(hidden_dim + cat_len, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, out_len),
        )


    def forward(self, lags, cats):

        x = self.pre(lags)
        x = torch.cat((x, cats), dim=1)
        x = self.predict(x)

        return x

In [None]:
stock_df

Unnamed: 0,ticker,date,market_cap,close_price,currency
0,000096.SZ,2000-07-24,461931447.619507,0.87487,USD
1,000096.SZ,2000-07-25,487445197.289756,0.923192,USD
2,000096.SZ,2000-07-26,476291125.175129,0.902067,USD
3,000096.SZ,2000-07-27,481888565.458432,0.912668,USD
4,000096.SZ,2000-07-28,490252916.878034,0.928509,USD
...,...,...,...,...,...
3206175,ZHEN.SI,2022-04-14,434277288.716814,0.297291,USD
3206176,ZHEN.SI,2022-04-18,432522955.196474,0.297291,USD
3206177,ZHEN.SI,2022-04-19,409329854.855221,0.281515,USD
3206178,ZHEN.SI,2022-04-20,415892653.465346,0.286029,USD


In [75]:
macro_df

Unnamed: 0,date,BRT-,CLc1,WTCLc1,LNG-AS,.VIX,EUR=,GBP=,CNY=
0,2000-01-03,,,,,24.21,1.0262,1.6368,8.2798
1,2000-01-04,24.40,25.55,,,27.01,1.0308,1.6365,8.2799
2,2000-01-05,23.65,24.91,,,26.41,1.0314,1.6415,8.2798
3,2000-01-06,23.54,24.78,,,25.73,1.0319,1.6463,8.2797
4,2000-01-07,22.95,24.22,,,21.72,1.0289,1.6383,8.2794
...,...,...,...,...,...,...,...,...,...
5813,2022-04-15,,,,,,1.0806,1.3058,6.3705
5814,2022-04-18,,108.21,108.21,,22.17,1.0780,1.3008,6.3630
5815,2022-04-19,105.67,102.56,102.56,,21.37,1.0786,1.2996,6.3930
5816,2022-04-20,104.78,102.75,102.19,,20.32,1.0850,1.3066,6.4188


In [None]:
meta_df

Unnamed: 0,ticker,exchange_code,region_hq,country_hq,state_province_hq,founding_year,economic_sector,business_sector,industry_group,industry,activity
0,OMVV.VI,WBAH,Europe,Austria,WIEN,1956,Energy,Energy - Fossil Fuels,Oil & Gas,Oil & Gas Refining and Marketing,Oil & Gas Refining and Marketing (NEC)
1,MDINp.TA,XTAE,Asia,Israel,,1992,Energy,Energy - Fossil Fuels,Oil & Gas,Oil & Gas Exploration and Production,Oil & Gas Exploration and Production (NEC)
2,000440.KQ,XKOS,Asia,Korea; Republic (S. Korea),SEOUL,1946,Energy,Energy - Fossil Fuels,Oil & Gas,Oil & Gas Refining and Marketing,Petroleum Product Wholesale
3,603507.SS,XSHG,Asia,China,JIANGSU,2004,Energy,Renewable Energy,Renewable Energy,Renewable Energy Equipment & Services,Renewable Energy Equipment & Services (NEC)
4,ATS.AX,XASX,Oceania,Australia,WESTERN AUSTRALIA,2015,Energy,Energy - Fossil Fuels,Oil & Gas,Oil & Gas Exploration and Production,Oil & Gas Exploration and Production (NEC)
...,...,...,...,...,...,...,...,...,...,...,...
898,WOWS.JK,XIDX,Asia,Indonesia,SUMATERA SELATAN,,Energy,Energy - Fossil Fuels,Oil & Gas Related Equipment and Services,Oil Related Services and Equipment,Oil Related Services and Equipment (NEC)
899,PRSO.OL,XOSL,Europe,Norway,ROGALAND,2019,Energy,Energy - Fossil Fuels,Oil & Gas Related Equipment and Services,Oil Related Services and Equipment,Oil Related Services and Equipment (NEC)
900,BROG.OQ,XNCM,Asia,United Arab Emirates,FUJAIRAH,,Energy,Energy - Fossil Fuels,Oil & Gas Related Equipment and Services,Oil & Gas Transportation Services,Oil & Gas Storage
901,336260.KS,XKRX,Asia,Korea; Republic (S. Korea),JEOLLABUK-DO,2019,Energy,Renewable Energy,Renewable Energy,Renewable Energy Equipment & Services,Stationary Fuel Cells


In [None]:
fundamentals_df

Unnamed: 0,ticker,date,period_end_date,announce_date,revenue,gross_profit,ebitda,ebit,net_income,fcf,total_assets,total_current_assets,total_liabilites,total_current_liabilities,long_term_debt_p_assets,short_term_debt_p_assets,gross_profit_p,ebitda_p,ebit_p,net_income_p
0,OMVV.VI,2000-06-30T00:00:00Z,2000-06-30,2000-10-23,1591395023.28372,230190745.555143,181045434.200878,105530002.190288,,,,,,,,,0.144647,0.113765,0.066313,
1,OMVV.VI,2000-12-31T00:00:00Z,2000-12-31,2001-04-30,1938098647.94837,372273990.672257,228378009.139304,132398360.578508,,,,,,,,,0.192082,0.117836,0.068314,
2,OMVV.VI,2001-03-31T00:00:00Z,2001-03-31,2001-11-15,1675754784.0278,270127135.374168,246253937.336036,174937923.894256,,,,,,,,,0.161197,0.146951,0.104394,
3,OMVV.VI,2001-06-30T00:00:00Z,2001-06-30,2001-11-15,1704178457.61632,261593371.574251,199438283.407691,139758657.318887,,,,,,,,,0.153501,0.117029,0.082009,
4,OMVV.VI,2001-09-30T00:00:00Z,2001-09-30,2001-11-08,1692966706.46458,247021992.14371,223213423.136865,133206040.77615,,,,,,,,,0.145911,0.131847,0.078682,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
41575,ARACA.NFF,2008-12-31T00:00:00Z,2008-12-31,2011-03-03,92904.554159,-2117171.324493,-4619760.125553,-7909109.757678,-2137357.171855,,41863684.938016,5109582.451442,21447887.060314,9690425.035636,18.99665,,-22.788671,-49.725874,-85.131561,-23.005946
41576,ARACA.NFF,2009-03-31T00:00:00Z,2009-03-31,2009-07-02,10549.150125,-144567.930584,-1106769.285629,-1137376.678949,-4988326.377089,-6881572.556783,50059580.411268,6308391.774634,26654730.773803,11597379.056223,16.66563,,-13.704225,-104.915493,-107.816901,-472.865237
41577,ARACA.NFF,2009-06-30T00:00:00Z,2009-06-30,2009-09-03,37654.234545,-51969.067513,-1521106.598827,-1548491.496678,-1570316.767768,-2599751.794234,49956277.521044,6098274.440243,27039941.4959,12361947.439668,17.77714,,-1.380165,-40.396694,-41.123967,-41.703590
41578,ARACA.NFF,2009-09-30T00:00:00Z,2009-09-30,2010-01-07,478926.658462,345792.09208,-961160.706547,-969172.632564,-4784882.925321,-8585810.204493,51412614.536567,2744454.228709,29735080.691293,5689130.510149,34.19498,,0.722015,-2.006906,-2.023635,-9.990847


# Define a dataloader which can iterate through time 

We want to train our neural network like the person experiences the world. I.e. we have a window of time and look at recent financial rapports, macro variables to predict future market capitalization. We want to train multiple epochs over one time window and validation period, in this manner we will not have any "learned future" effect which could occur if one were to have epochs run over the all time windows. 

In [50]:
class TimeDeltaDataset(Dataset):
    def _get_last_q_fundamentals(self, target, fundamental_df, q):
        for key, df in tqdm(fundamental_df.groupby(by="ticker")):
            
            padding = pd.DataFrame(np.empty((q, fundamental_df.loc[:,"revenue":].shape[1])),columns=df.loc[:,"revenue":].columns)
            padding[:] = np.nan
            padded_df = pd.concat([padding, df.loc[:,"revenue":]] , axis=0)
            
            target = pd.concat([target, padded_df.iloc[-q:,:]],axis=0)
        return target         
    
    def _get_global_local_column(self, stock_df):
        last_market_cap_col = pd.Series()
        for ticker, df in stock_df.groupby(by="ticker"):
            last_market_cap_col[ticker] = stock_df.market_cap.iloc[df.market_cap.last_valid_index()]

        min_max_scaler = MinMaxScaler()
        
        #Add column to learn relative values
        apple_market_cap = 2.687*(10**12) #ish as of may 2022 (USD)
        
        relative_to_global_market_column = last_market_cap_col / apple_market_cap
       
        relative_to_current_market_column = min_max_scaler.fit_transform(last_market_cap_col.to_numpy().reshape((-1,1)))
        relative_to_current_market_column = pd.Series(relative_to_current_market_column[:,0], index=last_market_cap_col.index) 

        return relative_to_global_market_column, relative_to_current_market_column, last_market_cap_col
    
    def _get_stocks_in_timeframe(self, stock_df, stock_dates, min_max_scaler=True):
        #TODO this needs a speedup when getting forecasts 
        to_be_transposed = pd.DataFrame(index=stock_dates)

        for i, (ticker, df) in enumerate(tqdm(stock_df.groupby(by="ticker"))):
            
            if min_max_scaler:
                min_max_scaler = MinMaxScaler()
                
                df.market_cap = min_max_scaler.fit_transform(df.market_cap.to_numpy().reshape((-1,1)))
            
            ticker_df = pd.DataFrame(df[["date", "market_cap"]]).set_index("date",drop=True).rename(columns={"market_cap":ticker})
            
            to_be_transposed = to_be_transposed.join(ticker_df)
            if i % 50 and i != 0:
                to_be_transposed.loc[~to_be_transposed.T.columns.duplicated(),:]
        
        return to_be_transposed.T.loc[:,~to_be_transposed.T.columns.duplicated()]    


    def __init__(self, current_time, forecast_window, number_of_trading_days, q_quarterly_rapports, stock_df, fundamental_df, meta_df, macro_df):
        back_in_time = timedelta(number_of_trading_days)
        forecast_horizon = timedelta(forecast_window)

        legal_stock_df = stock_df[(stock_df.date >= current_time - back_in_time) & (stock_df.date < current_time)] #TODO change to current_time - stock__macro_days_lookback_days
        legal_fundamental_df = fundamental_df[fundamental_df.announce_date < current_time]
        legal_macro_df = macro_df[(macro_df.date >= current_time - back_in_time) & (stock_df.date < current_time)] #TODO change to current_time - stock__macro_days_lookback_days
        legal_meta_df = meta_df.set_index("ticker")
        
        #Important dimensions
        n_companies_with_fundamentals = len(legal_fundamental_df.ticker.unique())
        m_fundamentals = legal_fundamental_df.loc[:,"revenue":].shape[1]

        #Get last q fundamentals and return NA rows if they are still missing
        fundamental_df_all_quarters = pd.DataFrame(data=np.empty((0,m_fundamentals)),columns=fundamental_df.loc[:,"revenue":].columns)   
        fundamental_df_all_quarters = self._get_last_q_fundamentals(fundamental_df_all_quarters, legal_fundamental_df, q_quarterly_rapports)
        fundamentals = fundamental_df_all_quarters.to_numpy().reshape((n_companies_with_fundamentals, q_quarterly_rapports*m_fundamentals))
        
        #Construct columns for relative information 
        relative_to_global_market_column, relative_to_current_market_column, last_market_cap_col = self._get_global_local_column(stock_df)
        
        #Create dataframe
        historic_dates = pd.date_range(start=current_time-back_in_time, end=current_time, freq="D")
        forecast_dates = pd.date_range(start=current_time+timedelta(1), end=current_time+forecast_horizon, freq="D")

        fund_columns = []
        for i in range(q_quarterly_rapports):
            fund_columns.extend(fundamental_df.loc[0,"revenue":].index.to_series().map(lambda title: f"{title}_q=-{q_quarterly_rapports-i}"))    
        columns =  ["global_relative"] + ["peers_relative"] + fund_columns
        fundamental_df = pd.DataFrame(index=legal_fundamental_df.ticker.unique(), columns=columns)
        
        #Load data
        fundamental_df["peers_relative"]  = relative_to_current_market_column.loc[legal_fundamental_df.ticker.unique()]
        fundamental_df["global_relative"] = relative_to_global_market_column.loc[legal_fundamental_df.ticker.unique()]

        formated_stocks = self._get_stocks_in_timeframe(legal_stock_df, historic_dates)
         
        fundamental_df.loc[:,f"revenue_q={-q_quarterly_rapports}":"net_income_p_q=-1"] = fundamentals
        for q in range(q_quarterly_rapports,0,-1):
            fundamental_df.loc[:,f"revenue_q={-q}":f"total_current_liabilities_q={-q}"] = fundamental_df.loc[:,f"revenue_q={-q}":f"total_current_liabilities_q={-q}"].div(last_market_cap_col, axis=0)

        self.stocks_and_fundamentals = formated_stocks.join(fundamental_df)
        
        # Get forecasts
        
        forecasts = stock_df[(stock_df.date > current_time) & (stock_df.date <= current_time + forecast_horizon)]
        
        forecasts_unormalized = self._get_stocks_in_timeframe(forecasts, forecast_dates, min_max_scaler=False)
        forecasts_normalized = forecasts_unormalized.div(last_market_cap_col, axis=0)
        self.forecast = forecasts_normalized.loc[self.stocks_and_fundamentals.index,:]


    def __len__(self):
        return self.stocks_and_fundamentals.shape[0]

    def __getitem__(self, idx):
      
        return self.stocks_and_fundamentals.iloc[idx,:], self.forecast.iloc[idx,:]

In [52]:
delta_set.forecast

Unnamed: 0,2020-01-02,2020-01-03,2020-01-04,2020-01-05,2020-01-06,2020-01-07,2020-01-08,2020-01-09,2020-01-10,2020-01-11,...,2020-01-22,2020-01-23,2020-01-24,2020-01-25,2020-01-26,2020-01-27,2020-01-28,2020-01-29,2020-01-30,2020-01-31
000096.SZ,1.199137,1.227389,,,1.248715,1.23103,1.219366,1.193282,1.188136,,...,1.179909,1.141688,,,,,,,,
000159.SZ,0.940755,1.013752,,,1.059899,1.02154,1.035823,0.993177,0.974261,,...,0.924913,0.882301,,,,,,,,
000440.KQ,0.361043,0.432133,,,0.501306,0.427853,0.426288,0.39496,0.388343,,...,0.379245,0.378019,,,,,0.362174,0.354509,0.333405,0.329695
000552.SZ,0.588216,0.583633,,,0.587314,0.601246,0.58533,0.591042,0.583051,,...,0.570524,0.551959,,,,,,,,
000554.SZ,0.954534,0.977622,,,1.074614,1.044469,1.106415,1.038723,1.015402,,...,0.964806,0.923439,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
YGR.TO,0.429669,0.428975,,,0.475109,0.483249,0.453261,0.45267,0.446431,,...,0.405555,0.396412,0.383143,,,0.36599,0.360692,0.365879,0.359244,0.352222
YPFD.BA,1.651625,1.660961,,,1.701532,1.656488,1.66209,1.685335,1.696267,,...,1.622298,1.59153,1.564683,,,1.560875,1.560672,1.567437,1.582027,1.539679
ZENZ.L,0.472281,0.471518,,,0.472828,0.538712,0.537307,0.536607,0.603868,,...,0.505884,0.506346,0.505576,,,0.503659,0.505076,0.503507,0.503049,0.477954
ZEST.OQ,1.134144,1.165648,,,1.159347,1.172264,1.159347,1.171949,1.152992,,...,1.165662,1.140321,1.133986,,,1.102311,1.05163,1.03896,1.05163,1.035792


In [51]:
current_time = pd.to_datetime("2020-01-01")
number_of_quarterly_rapports = 4
number_of_trading_days = 365
forecast_window = 30

delta_set = TimeDeltaDataset(current_time, forecast_window, number_of_trading_days, number_of_quarterly_rapports, stock_df, fundamentals_df, meta_df, macro_df)
delta_set.stocks_and_fundamentals

100%|██████████| 862/862 [00:07<00:00, 108.49it/s]
  del sys.path[0]
100%|██████████| 866/866 [00:15<00:00, 56.49it/s]
100%|██████████| 864/864 [01:42<00:00,  8.46it/s]


Unnamed: 0,2019-01-01 00:00:00,2019-01-02 00:00:00,2019-01-03 00:00:00,2019-01-04 00:00:00,2019-01-05 00:00:00,2019-01-06 00:00:00,2019-01-07 00:00:00,2019-01-08 00:00:00,2019-01-09 00:00:00,2019-01-10 00:00:00,...,total_assets_q=-1,total_current_assets_q=-1,total_liabilites_q=-1,total_current_liabilities_q=-1,long_term_debt_p_assets_q=-1,short_term_debt_p_assets_q=-1,gross_profit_p_q=-1,ebitda_p_q=-1,ebit_p_q=-1,net_income_p_q=-1
000096.SZ,,0.189644,0.188477,0.212934,,,0.228653,0.253593,0.267543,0.252803,...,0.090104,0.051523,0.021687,0.016558,1.85106,8.08959,0.589092,0.517639,0.478217,0.756598
000159.SZ,,0.008950,0.000000,0.008583,,,0.016915,0.018669,0.021915,0.028010,...,1.764741,1.362191,1.092201,1.090297,0.0,16.03614,0.032042,-0.206393,-0.206393,-0.286165
000440.KQ,,0.354101,0.316018,0.268291,,,0.363935,0.357102,0.365441,0.383530,...,0.8253,0.058079,0.154947,0.036267,8.16624,0.93235,0.933919,0.206851,0.107994,0.078372
000552.SZ,,0.137192,0.133927,0.175464,,,0.220644,0.220091,0.232506,0.256323,...,0.118507,0.032626,0.075441,0.046749,23.04012,9.3166,0.148582,0.119381,0.053168,0.020464
000554.SZ,,0.190758,0.212284,0.246290,,,0.274068,0.303706,0.289149,0.304020,...,274.674707,53.566872,152.578401,42.931305,28.81831,,0.120675,0.092285,0.061403,0.040201
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
YGR.TO,,0.663561,0.720056,0.798677,,,0.770812,0.830866,0.860202,0.849888,...,10.070025,3.454649,7.077476,2.562134,37.44828,0.0,0.096552,0.048599,0.032789,0.015657
YPFD.BA,,0.482517,0.506972,0.565648,,,0.569768,0.553067,0.566394,0.576155,...,0.019617,0.018328,0.02101,0.013966,35.2919,0.9022,0.014560,0.088876,0.088586,0.139669
ZENZ.L,,0.001793,0.005541,0.010404,,,0.013665,0.014702,0.017530,0.016352,...,6.688994,3.677693,5.008538,4.428105,7.13896,8.35383,0.103992,0.039762,0.039762,0.026250
ZEST.OQ,,0.230949,0.266802,0.346476,,,0.314606,0.266802,0.226965,0.266802,...,4.243796,3.232654,1.65337,1.631358,0.0,23.29476,0.247284,0.119436,0.119436,0.115098


In [12]:
def train_multivar(model, optimizer, loss_fn, data_train, data_val, one_hot_encoding, batch_number, forecast_window, epochs, device):

# print(model)
train_losses = []
val_losses = []
it = tqdm(range(epochs), disable=True)
for epoch in it: 
    for run_type in ["train", "val"]:
        model.train(run_type == "train")

        if run_type == "train":
            optimizer.zero_grad()

            mu, sigma = 0, 0.1
            noise = np.random.normal(mu, sigma, data_train.shape)
            noise[:,-forecast_window:] = 0

            data_train = data_train + noise
            data_encoded = pd.concat([one_hot_encoding, data_train], axis=1, join="inner")

        else:
            data_encoded = pd.concat([one_hot_encoding, data_val], axis=1, join="inner")

        data_shuffled = torch.tensor(data_encoded.sample(frac=1).values, dtype=torch.float32)

        for batch in torch.split(data_shuffled, batch_number, dim=0):

            inputs = batch[:,:-forecast_window].to(device)
            actuals = batch[:,-forecast_window:].to(device)

            pred = model(inputs[:, 3:], inputs[:, 0:3])
            loss = loss_fn(pred, actuals)

            if run_type == "train":
                train_losses.append(loss.item())
                loss.backward()

                optimizer.step()
            else:
                val_losses.append(loss.item())

    it.set_postfix({"train_loss": np.mean(train_losses), "val_loss": np.mean(val_losses)})

return train_losses, val_losses

IndentationError: ignored

# Get some data

In [13]:
stock_df = get_dataset("stock-oil-final:latest", project="master-test")
meta_df = get_dataset("meta-oil-final:latest", project="master-test")
fundamentals_df = get_dataset("fundamentals-oil-final:latest", project="master-test")
macro_df = get_dataset("macro-oil-final:latest", project="master-test")

[34m[1mwandb[0m: Currently logged in as: [33mkjartan[0m ([33mkrankile[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Downloading large artifact stock-oil-final:latest, 77.63MB. 1 files... Done. 0:0:0


VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

# Run the loop! (Like Odd-Geir Lademo)

In [None]:
def run_multivar(hidden_dim):
    #Training loop params
    forecast_window = 16
    amount_of_time_series = 999
    length_of_time_series = 500
    epochs = 200
    batch_size = 111
    


    #_, TS_signal, one_hot_encoding  = time_series_df(amount_of_time_series,length_of_time_series,periods, horisontal_shift, vertical_shift, forecast_window)

    df_train = TS_signal.iloc[-int(amount_of_time_series/3):,:]
    df_val = TS_signal.iloc[:-int(amount_of_time_series/3),:]

    loss_fn = nn.L1Loss()
    
    model = MultivariateNetwork(length_of_time_series, one_hot_encoding.shape[1], forecast_window, hidden_dim)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    return train_multivar(model, optimizer, loss_fn, df_train, df_val, one_hot_encoding, batch_size, forecast_window, epochs, device)