In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import root_mean_squared_error as rmse

# Read the provided train/test splits from their parquet files.
train_data = pd.read_parquet("data/train_data.parquet")
test_data = pd.read_parquet("data/test_data.parquet")

# The `.dt` accessor used in later cells needs a datetime column, so coerce
# `expiry` unless it already has the nanosecond datetime dtype.
expiry_is_datetime = train_data["expiry"].dtype == "datetime64[ns]"
if not expiry_is_datetime:
    train_data["expiry"] = pd.to_datetime(train_data["expiry"])

# Expiry date of interest, kept as a plain `datetime.date` so it can be
# compared against `train_data["expiry"].dt.date`.
target_date = pd.Timestamp("2025-05-08").date()

In [3]:
# Hold out half of the rows that have the target expiry as a validation set.
#
# NOTE: this cell rebinds `train_data`, so it is NOT idempotent -- re-running it
# without first re-running the loading cell splits the already-reduced frame again.

# The split below is label-based (`.loc` / `.drop`), which is only a clean
# partition when every row has a distinct index label.
if not train_data.index.is_unique:
    raise ValueError("train_data index must be unique for a label-based split")

# Rows whose expiry falls on the target date
expiry_filter = train_data["expiry"].dt.date == target_date
target_rows = train_data[expiry_filter]

# Fail early with a clear message instead of producing an empty validation set
# (which would only surface later as NaN statistics or an error inside rmse).
if target_rows.empty:
    raise ValueError(f"No training rows with expiry {target_date}; cannot build a validation set")

# Index labels of the validation rows: 50% of the target-expiry rows, fixed seed
validation_indices = target_rows.sample(frac=0.5, random_state=43).index

# Validation set is an independent copy of the sampled rows
val_data = train_data.loc[validation_indices].copy()

# Everything that was not sampled stays in the training set
train_data = train_data.drop(validation_indices)

# Print sizes to confirm the split
print(f"Original training data shape: {len(train_data) + len(val_data)}")
print(f"New training data shape: {train_data.shape}")
print(f"Validation data shape: {val_data.shape}")
print(f"Test data shape: {test_data.shape}")

Original training data shape: 178340
New training data shape: (138752, 97)
Validation data shape: (39588, 97)
Test data shape: (12065, 96)


In [4]:
# Inspect the test-set column names. In the (truncated) output shown below, the
# call/put strike ranges differ from the training columns listed in the next cell
# and no `expiry` column is visible -- NOTE(review): confirm on the full column list.
test_data.columns

Index(['timestamp', 'underlying', 'call_iv_24000', 'call_iv_24100',
       'call_iv_24200', 'call_iv_24300', 'call_iv_24400', 'call_iv_24500',
       'call_iv_24600', 'call_iv_24700', 'call_iv_24800', 'call_iv_24900',
       'call_iv_25000', 'call_iv_25100', 'call_iv_25200', 'call_iv_25300',
       'call_iv_25400', 'call_iv_25500', 'call_iv_25600', 'call_iv_25700',
       'call_iv_25800', 'call_iv_25900', 'call_iv_26000', 'call_iv_26100',
       'call_iv_26200', 'call_iv_26300', 'call_iv_26400', 'call_iv_26500',
       'put_iv_23000', 'put_iv_23100', 'put_iv_23200', 'put_iv_23300',
       'put_iv_23400', 'put_iv_23500', 'put_iv_23600', 'put_iv_23700',
       'put_iv_23800', 'put_iv_23900', 'put_iv_24000', 'put_iv_24100',
       'put_iv_24200', 'put_iv_24300', 'put_iv_24400', 'put_iv_24500',
       'put_iv_24600', 'put_iv_24700', 'put_iv_24800', 'put_iv_24900',
       'put_iv_25000', 'put_iv_25100', 'put_iv_25200', 'put_iv_25300',
       'put_iv_25400', 'put_iv_25500', 'X0', 'X1', 'X2',

In [5]:
# Inspect the training-set column names (after the validation rows were removed;
# dropping rows does not change the columns). Includes `expiry`, which the
# validation split filters on.
train_data.columns

Index(['timestamp', 'underlying', 'expiry', 'call_iv_23500', 'call_iv_23600',
       'call_iv_23700', 'call_iv_23800', 'call_iv_23900', 'call_iv_24000',
       'call_iv_24100', 'call_iv_24200', 'call_iv_24300', 'call_iv_24400',
       'call_iv_24500', 'call_iv_24600', 'call_iv_24700', 'call_iv_24800',
       'call_iv_24900', 'call_iv_25000', 'call_iv_25100', 'call_iv_25200',
       'call_iv_25300', 'call_iv_25400', 'call_iv_25500', 'call_iv_25600',
       'call_iv_25700', 'call_iv_25800', 'call_iv_25900', 'call_iv_26000',
       'put_iv_22500', 'put_iv_22600', 'put_iv_22700', 'put_iv_22800',
       'put_iv_22900', 'put_iv_23000', 'put_iv_23100', 'put_iv_23200',
       'put_iv_23300', 'put_iv_23400', 'put_iv_23500', 'put_iv_23600',
       'put_iv_23700', 'put_iv_23800', 'put_iv_23900', 'put_iv_24000',
       'put_iv_24100', 'put_iv_24200', 'put_iv_24300', 'put_iv_24400',
       'put_iv_24500', 'put_iv_24600', 'put_iv_24700', 'put_iv_24800',
       'put_iv_24900', 'put_iv_25000', 'X0', '

In [6]:
# Prediction targets: every validation column whose name begins with "call" or "put".
pred_cols = [
    name for name in val_data.columns.tolist() if name.startswith(("call", "put"))
]
len(pred_cols)

52

In [7]:
# Separate the validation frame into the target columns and everything else.
val_X = val_data.drop(pred_cols, axis=1)
val_Y = val_data.loc[:, pred_cols]

## Setting everything to 0.2

In [8]:
# Naive baseline: predict the same constant for every target column of every row.
CONSTANT_PREDICTION = 0.2
preds = val_Y.copy()
preds.iloc[:, :] = CONSTANT_PREDICTION

# Score the constant prediction against the held-out targets.
constant_rmse = rmse(preds, val_Y)
print(constant_rmse)
preds

0.09687563762995885


Unnamed: 0,call_iv_23500,call_iv_23600,call_iv_23700,call_iv_23800,call_iv_23900,call_iv_24000,call_iv_24100,call_iv_24200,call_iv_24300,call_iv_24400,...,put_iv_24100,put_iv_24200,put_iv_24300,put_iv_24400,put_iv_24500,put_iv_24600,put_iv_24700,put_iv_24800,put_iv_24900,put_iv_25000
158881,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,...,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2
117518,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,...,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2
155403,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,...,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2
159498,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,...,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2
106693,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,...,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
150469,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,...,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2
178012,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,...,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2
114201,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,...,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2
133136,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,...,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2


## Per-column median and mean of the training rows with the target expiry

In [9]:
# Column-wise median of the target columns over the training rows that share
# the target expiry.
has_target_expiry = train_data["expiry"].dt.date == target_date
median = train_data.loc[has_target_expiry, pred_cols].median()
print(median)

call_iv_23500    0.273244
call_iv_23600    0.255100
call_iv_23700    0.236308
call_iv_23800    0.217708
call_iv_23900    0.198384
call_iv_24000    0.177905
call_iv_24100    0.164981
call_iv_24200    0.154688
call_iv_24300    0.134063
call_iv_24400    0.122894
call_iv_24500    0.130239
call_iv_24600    0.133252
call_iv_24700    0.133231
call_iv_24800    0.139514
call_iv_24900    0.151388
call_iv_25000    0.164070
call_iv_25100    0.176350
call_iv_25200    0.191076
call_iv_25300    0.206294
call_iv_25400    0.222072
call_iv_25500    0.239783
call_iv_25600    0.256224
call_iv_25700    0.271383
call_iv_25800    0.286871
call_iv_25900    0.302576
call_iv_26000    0.316601
put_iv_22500     0.457202
put_iv_22600     0.441771
put_iv_22700     0.424398
put_iv_22800     0.406865
put_iv_22900     0.390018
put_iv_23000     0.372748
put_iv_23100     0.354954
put_iv_23200     0.336409
put_iv_23300     0.316864
put_iv_23400     0.298291
put_iv_23500     0.278934
put_iv_23600     0.259421
put_iv_23700

In [10]:
# Median baseline: every validation row is predicted as the per-column training
# median computed above.
#
# The frame is built directly rather than by writing through
# `preds2[col].values[:]`: that pattern only works when the column selection is
# a writable view of the frame, and it raises "assignment destination is
# read-only" when pandas Copy-on-Write is active.
preds2 = pd.DataFrame(
    # One copy of the median vector (ordered like val_Y's columns) per validation row
    np.tile(median[val_Y.columns].to_numpy(), (len(val_Y), 1)),
    index=val_Y.index,
    columns=val_Y.columns,
)

# y_true first, y_pred second, matching scikit-learn's signature
# (the RMSE value itself does not depend on the order).
print(rmse(val_Y, preds2))
preds2

0.05287821631609408


Unnamed: 0,call_iv_23500,call_iv_23600,call_iv_23700,call_iv_23800,call_iv_23900,call_iv_24000,call_iv_24100,call_iv_24200,call_iv_24300,call_iv_24400,...,put_iv_24100,put_iv_24200,put_iv_24300,put_iv_24400,put_iv_24500,put_iv_24600,put_iv_24700,put_iv_24800,put_iv_24900,put_iv_25000
158881,0.273244,0.2551,0.236308,0.217708,0.198384,0.177905,0.164981,0.154688,0.134063,0.122894,...,0.165398,0.154417,0.133818,0.122695,0.128697,0.133168,0.13311,0.138793,0.14983,0.162431
117518,0.273244,0.2551,0.236308,0.217708,0.198384,0.177905,0.164981,0.154688,0.134063,0.122894,...,0.165398,0.154417,0.133818,0.122695,0.128697,0.133168,0.13311,0.138793,0.14983,0.162431
155403,0.273244,0.2551,0.236308,0.217708,0.198384,0.177905,0.164981,0.154688,0.134063,0.122894,...,0.165398,0.154417,0.133818,0.122695,0.128697,0.133168,0.13311,0.138793,0.14983,0.162431
159498,0.273244,0.2551,0.236308,0.217708,0.198384,0.177905,0.164981,0.154688,0.134063,0.122894,...,0.165398,0.154417,0.133818,0.122695,0.128697,0.133168,0.13311,0.138793,0.14983,0.162431
106693,0.273244,0.2551,0.236308,0.217708,0.198384,0.177905,0.164981,0.154688,0.134063,0.122894,...,0.165398,0.154417,0.133818,0.122695,0.128697,0.133168,0.13311,0.138793,0.14983,0.162431
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
150469,0.273244,0.2551,0.236308,0.217708,0.198384,0.177905,0.164981,0.154688,0.134063,0.122894,...,0.165398,0.154417,0.133818,0.122695,0.128697,0.133168,0.13311,0.138793,0.14983,0.162431
178012,0.273244,0.2551,0.236308,0.217708,0.198384,0.177905,0.164981,0.154688,0.134063,0.122894,...,0.165398,0.154417,0.133818,0.122695,0.128697,0.133168,0.13311,0.138793,0.14983,0.162431
114201,0.273244,0.2551,0.236308,0.217708,0.198384,0.177905,0.164981,0.154688,0.134063,0.122894,...,0.165398,0.154417,0.133818,0.122695,0.128697,0.133168,0.13311,0.138793,0.14983,0.162431
133136,0.273244,0.2551,0.236308,0.217708,0.198384,0.177905,0.164981,0.154688,0.134063,0.122894,...,0.165398,0.154417,0.133818,0.122695,0.128697,0.133168,0.13311,0.138793,0.14983,0.162431


In [11]:
# Column-wise mean of the target columns over the training rows that share
# the target expiry.
has_target_expiry = train_data["expiry"].dt.date == target_date
mean = train_data.loc[has_target_expiry, pred_cols].mean()
print(mean)

call_iv_23500    0.290099
call_iv_23600    0.271857
call_iv_23700    0.253358
call_iv_23800    0.234379
call_iv_23900    0.213979
call_iv_24000    0.192710
call_iv_24100    0.172379
call_iv_24200    0.152369
call_iv_24300    0.134634
call_iv_24400    0.124403
call_iv_24500    0.124010
call_iv_24600    0.129707
call_iv_24700    0.138047
call_iv_24800    0.148269
call_iv_24900    0.159281
call_iv_25000    0.171683
call_iv_25100    0.183612
call_iv_25200    0.196548
call_iv_25300    0.210207
call_iv_25400    0.224275
call_iv_25500    0.239199
call_iv_25600    0.254046
call_iv_25700    0.268822
call_iv_25800    0.283601
call_iv_25900    0.298621
call_iv_26000    0.312240
put_iv_22500     0.461152
put_iv_22600     0.446929
put_iv_22700     0.431413
put_iv_22800     0.415661
put_iv_22900     0.399775
put_iv_23000     0.383283
put_iv_23100     0.367232
put_iv_23200     0.350082
put_iv_23300     0.332396
put_iv_23400     0.314205
put_iv_23500     0.295229
put_iv_23600     0.275906
put_iv_23700

In [12]:
# Mean baseline: every validation row is predicted as the per-column training
# mean computed above.
#
# The frame is built directly rather than by writing through
# `preds3[col].values[:]`: that pattern only works when the column selection is
# a writable view of the frame, and it raises "assignment destination is
# read-only" when pandas Copy-on-Write is active.
preds3 = pd.DataFrame(
    # One copy of the mean vector (ordered like val_Y's columns) per validation row
    np.tile(mean[val_Y.columns].to_numpy(), (len(val_Y), 1)),
    index=val_Y.index,
    columns=val_Y.columns,
)

# y_true first, y_pred second, matching scikit-learn's signature
# (the RMSE value itself does not depend on the order).
print(rmse(val_Y, preds3))
preds3

0.05218188290844201


Unnamed: 0,call_iv_23500,call_iv_23600,call_iv_23700,call_iv_23800,call_iv_23900,call_iv_24000,call_iv_24100,call_iv_24200,call_iv_24300,call_iv_24400,...,put_iv_24100,put_iv_24200,put_iv_24300,put_iv_24400,put_iv_24500,put_iv_24600,put_iv_24700,put_iv_24800,put_iv_24900,put_iv_25000
158881,0.290099,0.271857,0.253358,0.234379,0.213979,0.19271,0.172379,0.152369,0.134634,0.124403,...,0.172863,0.152323,0.1346,0.124191,0.123508,0.129496,0.137637,0.147084,0.156532,0.136339
117518,0.290099,0.271857,0.253358,0.234379,0.213979,0.19271,0.172379,0.152369,0.134634,0.124403,...,0.172863,0.152323,0.1346,0.124191,0.123508,0.129496,0.137637,0.147084,0.156532,0.136339
155403,0.290099,0.271857,0.253358,0.234379,0.213979,0.19271,0.172379,0.152369,0.134634,0.124403,...,0.172863,0.152323,0.1346,0.124191,0.123508,0.129496,0.137637,0.147084,0.156532,0.136339
159498,0.290099,0.271857,0.253358,0.234379,0.213979,0.19271,0.172379,0.152369,0.134634,0.124403,...,0.172863,0.152323,0.1346,0.124191,0.123508,0.129496,0.137637,0.147084,0.156532,0.136339
106693,0.290099,0.271857,0.253358,0.234379,0.213979,0.19271,0.172379,0.152369,0.134634,0.124403,...,0.172863,0.152323,0.1346,0.124191,0.123508,0.129496,0.137637,0.147084,0.156532,0.136339
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
150469,0.290099,0.271857,0.253358,0.234379,0.213979,0.19271,0.172379,0.152369,0.134634,0.124403,...,0.172863,0.152323,0.1346,0.124191,0.123508,0.129496,0.137637,0.147084,0.156532,0.136339
178012,0.290099,0.271857,0.253358,0.234379,0.213979,0.19271,0.172379,0.152369,0.134634,0.124403,...,0.172863,0.152323,0.1346,0.124191,0.123508,0.129496,0.137637,0.147084,0.156532,0.136339
114201,0.290099,0.271857,0.253358,0.234379,0.213979,0.19271,0.172379,0.152369,0.134634,0.124403,...,0.172863,0.152323,0.1346,0.124191,0.123508,0.129496,0.137637,0.147084,0.156532,0.136339
133136,0.290099,0.271857,0.253358,0.234379,0.213979,0.19271,0.172379,0.152369,0.134634,0.124403,...,0.172863,0.152323,0.1346,0.124191,0.123508,0.129496,0.137637,0.147084,0.156532,0.136339
