In [1]:
import pandas as pd
import numpy as np
import datetime
import time

from typing import Any, List

import logging
# Logger for this notebook, keyed on __name__.
LOGGER = logging.getLogger(__name__)

Example Data

In [2]:
# Load the example panel from a parquet file, addressed relative to the working directory.
df = pd.read_parquet("cleaned_final_data_post_2010_small.parquet")

In [3]:
# Keep only the observations dated 2010-01-01 or later.
on_or_after_2010 = df["datadate"] >= "2010-01-01"
data = df.loc[on_or_after_2010]

In [4]:
# Index by (date, ticker) and sort, so that positional slicing over dates is well defined.
# The guard keeps this cell safe to re-run: calling set_index a second time would raise
# KeyError because the two columns have already been moved into the index.
if list(data.index.names) != ["datadate", "tic"]:
    data = data.set_index(["datadate", "tic"])
data = data.sort_index()

In [5]:
# Preview the indexed, sorted frame.
data

Unnamed: 0_level_0,Unnamed: 1_level_0,gvkey,iid,cusip,conm,ajexdi,cshoc,cshtrd,dvi,eps,epsmo,prccd,prchd,prcld,prcod,prcstd,trfd
datadate,tic,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
2010-01-04,A,126554,01,00846U101,AGILENT TECHNOLOGIES INC,1.000000,3.488310e+08,2729240.0,,-0.09,10.0,31.30,31.63,31.1314,31.39,3.0,1.059619
2010-01-04,AA.3,001356,01,013817507,ALCOA INC,0.333333,9.743780e+08,25872570.0,0.12,-1.91,9.0,16.65,16.89,16.4000,16.47,3.0,1.894268
2010-01-04,AABA,062634,01,021346101,ALTABA INC,1.000000,1.401056e+09,16479600.0,,0.10,9.0,17.10,17.20,16.8800,16.94,3.0,1.000000
2010-01-04,AAPL,001690,01,037833100,APPLE INC,28.000000,9.053490e+08,17540960.0,0.00,6.39,9.0,214.01,214.50,212.3800,213.43,3.0,1.095663
2010-01-04,ABC,031673,01,03073E105,AMERISOURCEBERGEN CORP,1.000000,2.848920e+08,2455833.0,0.32,1.70,9.0,26.63,26.69,26.1400,26.29,3.0,1.067085
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2022-12-30,YUM,065417,01,988498101,YUM BRANDS INC,1.000000,2.816880e+08,1401328.0,2.28,4.46,9.0,128.08,129.61,127.3450,129.61,3.0,1.965736
2022-12-30,ZBH,144559,01,98956P102,ZIMMER BIOMET HOLDINGS INC,1.000000,2.098520e+08,785116.0,0.96,1.54,9.0,127.50,127.73,126.2100,127.15,3.0,1.125085
2022-12-30,ZBRA,024405,01,989207105,ZEBRA TECHNOLOGIES CP -CL A,1.000000,5.163000e+07,228188.0,0.00,8.87,9.0,256.41,256.54,251.5200,254.20,3.0,
2022-12-30,ZION,011687,01,989701107,ZIONS BANCORPORATION NA,1.000000,1.496180e+08,609715.0,1.64,5.32,9.0,49.16,49.41,48.5000,48.61,3.0,2.407089


In [6]:
# Inspect the (datadate, tic) MultiIndex built by set_index.
data.index

MultiIndex([('2010-01-04',    'A'),
            ('2010-01-04', 'AA.3'),
            ('2010-01-04', 'AABA'),
            ('2010-01-04', 'AAPL'),
            ('2010-01-04',  'ABC'),
            ('2010-01-04',  'ABT'),
            ('2010-01-04',  'ACS'),
            ('2010-01-04', 'ADBE'),
            ('2010-01-04',  'ADI'),
            ('2010-01-04',  'ADM'),
            ...
            ('2022-12-30', 'WYNN'),
            ('2022-12-30',  'XEL'),
            ('2022-12-30',  'XOM'),
            ('2022-12-30', 'XRAY'),
            ('2022-12-30',  'XYL'),
            ('2022-12-30',  'YUM'),
            ('2022-12-30',  'ZBH'),
            ('2022-12-30', 'ZBRA'),
            ('2022-12-30', 'ZION'),
            ('2022-12-30',  'ZTS')],
           names=['datadate', 'tic'], length=1637823)

In [7]:
# Name of the date level of the index; default `index_name` of the utility functions.
DATE_IDX_COL = "datadate"

# These are unused in training, just for debugging
# DATA_DATE_START: datetime.datetime = data.index.min()[0]
# DATA_DATE_END: datetime.datetime = data.index.max()[0]

# All window lengths are counted in distinct dates of the dataframe index, not in calendar
# time: each unique date counts as 1 "day", whatever the real gap between two dates is.
TRADING_DAYS_PER_YEAR = 252  # year length assumed by the windows below

LOOKBACK_LENGTH   = TRADING_DAYS_PER_YEAR * 8   # 8 years (training + validation)
VALIDATION_LENGTH = TRADING_DAYS_PER_YEAR // 2  # 6 months, taken out of LOOKBACK_LENGTH
RETRAIN_FREQUENCY = TRADING_DAYS_PER_YEAR       # annual retrain, i.e. each split will be moved forward by 1Y.

# The validation window is carved out of the lookback window, so it must leave room for
# training data; the retrain step must be positive for the splits to move forward.
assert 0 < VALIDATION_LENGTH < LOOKBACK_LENGTH
assert RETRAIN_FREQUENCY > 0


############################
# Utility Functions
############################

def get_df_index_rows(
    data: pd.DataFrame, start: int, end: int, index_name: str = DATE_IDX_COL
) -> pd.DataFrame:
    """
    Return every row whose `index_name` value is among the unique values found at
    positions `start` to `end - 1` of that index level. This is useful for multiindex dfs.

    Positions refer to the level's unique values in order of first appearance, which is
    ascending order for the outermost level of a sorted index. `start` and `end` follow
    normal Python slice semantics, so `end` is exclusive.

    For example if I want to get all the rows for the first 4 dates, I can do:
    `get_df_index_rows(data, start=0, end=4)` which gets date indices 0, 1, 2, 3.

    Args:
        data: frame whose (single- or multi-level) index has a level named `index_name`;
            the index must be sorted.
        start: position of the first unique value to include.
        end: position one past the last unique value to include.
        index_name: name of the index level to select on.

    Returns:
        The matching rows in their original order, with the full index retained.

    Raises:
        AssertionError: if the index is not sorted.
    """
    assert data.index.is_monotonic_increasing, "index must be sorted"
    level_values = data.index.get_level_values(index_name)
    selected = level_values.unique()[start:end]
    # Filter with a mask on the requested level. `data.loc[selected]` would look the labels
    # up on the outermost level only, which breaks whenever `index_name` is an inner level.
    return data.loc[level_values.isin(selected)]
    
def get_unique_index_values(
    data: pd.DataFrame, index_name: str = DATE_IDX_COL
) -> pd.Index:
    """
    Return the distinct values of the index level `index_name`.

    Values come back in order of first appearance, which is ascending order for the
    outermost level of a sorted index.

    Args:
        data: frame whose index has a level named `index_name`; the index must be sorted.
        index_name: name of the index level to read.

    Returns:
        A pandas Index holding each distinct value once.

    Raises:
        AssertionError: if the index is not sorted.
    """
    assert data.index.is_monotonic_increasing, "index must be sorted"
    return data.index.get_level_values(index_name).unique()

In [8]:
# Number of distinct dates in the panel; the split indices computed later are positions
# into this sequence of dates.
unique_dates = get_unique_index_values(data)
data_length = len(unique_dates)
data_length

3274

In [9]:
import math

In [10]:
# Need to start the backtest sufficiently in the future so we have enough lookback to train the first model.

# TODO: we should actually make backtest_start_idx get the idx of a set date.
# The rest of this code will work correctly (with the assert if not enough lookback data).
backtest_start_idx = LOOKBACK_LENGTH
backtest_end_idx = data_length - 1  # position of the last date (inclusive)

train_start_idx = backtest_start_idx - LOOKBACK_LENGTH
# Span between the first and the last backtest date (not the number of dates).
backtest_length = backtest_end_idx - backtest_start_idx

assert train_start_idx >= 0, "not enough lookback data before backtest_start_idx"
assert backtest_length >= 0, "data ends before the backtest can start"

# Both backtest bounds are inclusive, so the window holds backtest_length + 1 dates.
# Dividing the span itself would leave the final date without a model whenever the span
# is an exact multiple of RETRAIN_FREQUENCY.
num_models = math.ceil((backtest_length + 1) / RETRAIN_FREQUENCY)

backtest_start_idx, backtest_end_idx, train_start_idx, backtest_length, num_models

(2016, 3273, 0, 1257, 5)

In [11]:
# One (train_start, train_end, validation_start, validation_end) tuple per model.
# All four bounds are inclusive date positions. The validation window is the tail of the
# lookback window, i.e. it comes out of the training set rather than extending it.
train_partitions = []
window_starts_stop = train_start_idx + num_models * RETRAIN_FREQUENCY
for window_start in range(train_start_idx, window_starts_stop, RETRAIN_FREQUENCY):
    window_end = window_start + LOOKBACK_LENGTH - 1  # last validation date
    last_train = window_end - VALIDATION_LENGTH      # last training date
    train_partitions.append((window_start, last_train, last_train + 1, window_end))

# These are all inclusive
train_partitions

[(0, 1889, 1890, 2015),
 (252, 2141, 2142, 2267),
 (504, 2393, 2394, 2519),
 (756, 2645, 2646, 2771),
 (1008, 2897, 2898, 3023)]

In [12]:
# Sanity check: the first partition, from train start through validation end (both
# inclusive), must cover exactly LOOKBACK_LENGTH dates.
first_partition = train_partitions[0]
covered_dates = first_partition[3] - first_partition[0] + 1
assert covered_dates == LOOKBACK_LENGTH

In [13]:
for train_first, train_last, val_first, val_last in train_partitions:
    # Partition bounds are inclusive while get_df_index_rows treats `end` as exclusive,
    # hence the + 1 on each upper bound.
    # Note: train_df / val_df are overwritten on every pass, so only the frames of the
    # last partition remain once the loop finishes.
    # TODO: save these in some object or parse to files (to save RAM)
    train_df = get_df_index_rows(data, start=train_first, end=train_last + 1)
    val_df = get_df_index_rows(data, start=val_first, end=val_last + 1)

In [14]:
# Validation rows of the final partition: the preceding loop reassigns val_df on every
# iteration, so only the last one is still bound here.
val_df

Unnamed: 0_level_0,Unnamed: 1_level_0,gvkey,iid,cusip,conm,ajexdi,cshoc,cshtrd,dvi,eps,epsmo,prccd,prchd,prcld,prcod,prcstd,trfd
datadate,tic,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
2021-07-07,A,126554,01,00846U101,AGILENT TECHNOLOGIES INC,1.0,3.034430e+08,2286307.0,0.776,3.02,3.0,149.49,149.620,148.090,149.590,3.0,1.591278
2021-07-07,AAL,001045,04,02376R102,AMERICAN AIRLINES GROUP INC,1.0,6.413830e+08,30698600.0,0.000,-15.07,3.0,20.31,21.050,20.170,20.790,3.0,1.060905
2021-07-07,AAP,145977,01,00751Y106,ADVANCE AUTO PARTS INC,1.0,6.543900e+07,689644.0,4.000,9.37,3.0,209.44,209.765,206.600,207.430,3.0,1.066679
2021-07-07,AAPL,001690,01,037833100,APPLE INC,1.0,1.668763e+10,104688200.0,0.880,4.51,3.0,144.57,144.890,142.660,143.535,3.0,1.274322
2021-07-07,ABBV,016101,01,00287Y109,ABBVIE INC,1.0,1.766222e+09,6693420.0,5.200,2.71,3.0,116.75,116.970,115.310,115.910,3.0,1.417216
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2022-01-03,YUM,065417,01,988498101,YUM BRANDS INC,1.0,2.931330e+08,1251350.0,2.000,5.27,9.0,136.53,138.770,134.850,138.380,3.0,1.928702
2022-01-03,ZBH,144559,01,98956P102,ZIMMER BIOMET HOLDINGS INC,1.0,2.089080e+08,1150267.0,0.960,3.95,9.0,129.13,129.940,126.615,127.350,3.0,1.088195
2022-01-03,ZBRA,024405,01,989207105,ZEBRA TECHNOLOGIES CP -CL A,1.0,5.344100e+07,272580.0,0.000,15.81,9.0,583.90,599.730,578.490,592.080,3.0,
2022-01-03,ZION,011687,01,989701107,ZIONS BANCORPORATION NA,1.0,1.564630e+08,1218018.0,1.520,7.10,9.0,64.24,65.060,63.730,63.880,3.0,2.341912
