In [1]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchmetrics import Accuracy

from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error

from kan import *
# Make float64 the default dtype for newly created floating-point tensors.
torch.set_default_dtype(torch.float64)
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [28]:
def treasury_data_retrieval(path='../data/us_treasury_rates.csv'):
    """Load the treasury-rate CSV and return it in chronological order.

    Parameters
    ----------
    path : str or path-like, default '../data/us_treasury_rates.csv'
        Location of the CSV file. It must contain a ``Date`` column that
        ``pd.to_datetime`` can parse.

    Returns
    -------
    pandas.DataFrame
        Every column of the file, with ``Date`` converted to datetime, rows
        sorted by ascending ``Date`` and the index reset to 0..n-1.
    """
    raw = pd.read_csv(path)
    # Build the result as a chain so the freshly read frame is never mutated
    # in place.
    return (
        raw.assign(Date=pd.to_datetime(raw['Date']))
        .sort_values(by='Date', ascending=True)
        .reset_index(drop=True)
    )

df = treasury_data_retrieval()

n = len(df)      # number of rows (trading days) in the raw frame
h = 5            # look-back window: days of history flattened into one sample
n_targets = 12   # rate columns 1..12 of df become targets y_1..y_12

# Every column after 'Date' is treated as a rate series.
rates = df.iloc[:, 1:].to_numpy()
n_features = h * rates.shape[1]

# Row i holds the h days that precede day h + i, flattened day by day
# (all rates of the oldest day first). Built in one pass instead of growing a
# DataFrame with pd.concat inside the loop, and without DataFrame.stack(),
# which drops NaN on pandas versions where that is the default and would
# shift the remaining values of a window into the wrong columns.
lagged = np.array([rates[end - h:end].ravel() for end in range(h, n)])
lagged = lagged.reshape(max(n - h, 0), n_features)  # stays 2-D even if n <= h

df_flat = pd.DataFrame(lagged, columns=[str(pos) for pos in range(n_features)])

# Targets: the rates observed on the day right after each window.
for k in range(1, n_targets + 1):
    df_flat[f'y_{k}'] = df.iloc[h:, k].values

df_flat['Date'] = df['Date'].iloc[h:].values
df_flat = df_flat.set_index('Date')

df_flat.tail()

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,y_3,y_4,y_5,y_6,y_7,y_8,y_9,y_10,y_11,y_12
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2024-12-02,4.72,4.67,4.63,4.46,4.42,4.37,4.32,4.3,4.35,4.41,...,4.51,4.43,4.3,4.17,4.11,4.08,4.13,4.19,4.46,4.36
2024-12-03,4.74,4.67,4.62,4.46,4.37,4.21,4.21,4.17,4.21,4.27,...,4.49,4.4,4.27,4.17,4.13,4.11,4.17,4.23,4.5,4.4
2024-12-04,4.74,4.67,4.61,4.45,4.37,4.21,4.21,4.17,4.24,4.3,...,4.47,4.38,4.23,4.13,4.09,4.07,4.13,4.19,4.45,4.35
2024-12-05,4.76,4.7,4.6,4.43,4.34,4.19,4.17,4.11,4.17,4.25,...,4.46,4.38,4.23,4.15,4.1,4.07,4.12,4.17,4.43,4.33
2024-12-06,4.76,4.69,4.58,4.42,4.3,4.13,4.1,4.05,4.1,4.18,...,4.42,4.34,4.19,4.1,4.05,4.03,4.09,4.15,4.42,4.34


In [29]:
# Peek at the first ten rows of the date-sorted raw rates frame.
df.head(10)

Unnamed: 0,Date,1 Mo,2 Mo,3 Mo,6 Mo,1 Yr,2 Yr,3 Yr,5 Yr,7 Yr,10 Yr,20 Yr,30 Yr
0,2022-01-03,0.05,0.06,0.08,0.22,0.4,0.78,1.04,1.37,1.55,1.63,2.05,2.01
1,2022-01-04,0.06,0.05,0.08,0.22,0.38,0.77,1.02,1.37,1.57,1.66,2.1,2.07
2,2022-01-05,0.05,0.06,0.09,0.22,0.41,0.83,1.1,1.43,1.62,1.71,2.12,2.09
3,2022-01-06,0.04,0.05,0.1,0.23,0.45,0.88,1.15,1.47,1.66,1.73,2.12,2.09
4,2022-01-07,0.05,0.05,0.1,0.24,0.43,0.87,1.17,1.5,1.69,1.76,2.15,2.11
5,2022-01-10,0.05,0.06,0.13,0.28,0.46,0.92,1.21,1.53,1.71,1.78,2.15,2.11
6,2022-01-11,0.04,0.05,0.11,0.28,0.46,0.9,1.22,1.51,1.69,1.75,2.13,2.08
7,2022-01-12,0.04,0.06,0.12,0.27,0.48,0.92,1.21,1.5,1.67,1.74,2.13,2.08
8,2022-01-13,0.05,0.05,0.12,0.28,0.47,0.91,1.18,1.47,1.64,1.7,2.1,2.05
9,2022-01-14,0.05,0.05,0.13,0.3,0.51,0.99,1.26,1.55,1.72,1.78,2.18,2.12


In [30]:
def train_mse():
    """Mean squared error of ``model`` on the training split.

    Takes no arguments: it reads the notebook-level ``model`` and ``dataset``
    globals (``dataset['train_input']`` / ``dataset['train_label']``). The
    function name must stay ``train_mse`` because the fit results are read
    back under that key.

    Returns
    -------
    torch.Tensor
        0-dim tensor with the MSE averaged over all elements.
    """
    # This is a reporting metric only, so skip autograd bookkeeping instead
    # of building a graph over the whole training set on every call.
    with torch.no_grad():
        predictions = model(dataset['train_input'])
        return F.mse_loss(predictions, dataset['train_label'], reduction='mean')

def test_mse():
    predictions = model(dataset['test_input']) # Model predictions
    mse = F.mse_loss(predictions, dataset['test_label'], reduction='mean')  # Compute MSE
    return mse

In [31]:
# numpy is already imported as np in the notebook's first cell.

# Number of most recent samples held out for the recursive forecast check.
test_size = 10

# Metrics per fold (this cell evaluates a single train/hold-out split).
fold_results = {'train_mse': [], 'test_mse': []}

# Split features and targets by column name instead of a hard-coded count:
# the target columns are the ones named 'y_*'.
target_cols = [col for col in df_flat.columns if col.startswith('y_')]
X, y = df_flat.drop(columns=target_cols), df_flat[target_cols]
n_inputs = X.shape[1]
n_outputs = y.shape[1]

# Chronological hold-out: the last `test_size` rows are never trained on.
X_train, X_test = X.iloc[:-test_size], X.iloc[-test_size:]
y_train, y_test = y.iloc[:-test_size], y.iloc[-test_size:]

dtype = torch.get_default_dtype()


def _to_tensor(frame):
    """Convert a DataFrame to a tensor of the default dtype on `device`."""
    return torch.from_numpy(frame.values).type(dtype).to(device)


dataset = dict()
dataset['train_input'] = _to_tensor(X_train)
dataset['train_label'] = _to_tensor(y_train)
# Only the first held-out row is given to fit(): it is the one-step-ahead
# case and also the seed for the recursive forecast below.
dataset['test_input'] = _to_tensor(X_test.iloc[[0]])
dataset['test_label'] = _to_tensor(y_test.iloc[[0]])

# Initialize the model
model = KAN(width=[n_inputs, 20, n_outputs], grid=3, k=2, seed=42, device=device)

# Train the model and compute metrics
results = model.fit(dataset, opt="Adam", lamb=0.001, lr=0.001, steps=1000, metrics=(train_mse, test_mse))

# Recursive multi-step forecast: predict one day, drop the oldest day
# (n_outputs values) from the window, append the prediction, repeat.
# The window stays on `device`; only the stored prediction goes to numpy.
feature = dataset['test_input']
output_list = list()
for _ in range(test_size):
    step_pred = model(feature).detach()
    output_list.append(step_pred.cpu().numpy().flatten())
    feature = torch.cat([feature[:, n_outputs:], step_pred], dim=1)

# Store the metrics
fold_results['train_mse'].append(results['train_mse'][-1])
fold_results['test_mse'].append(results['test_mse'][-1])

# Calculate average metrics across all windows
avg_train_mse = np.mean(fold_results['train_mse'])
avg_test_mse = np.mean(fold_results['test_mse'])  # one-step value reported by fit()

# The "Average Test MSE" line reports the error of the recursive forecast
# over the whole hold-out window, not avg_test_mse.
# sklearn's signature is mean_squared_error(y_true, y_pred).
forecast_mse = mean_squared_error(y_test, np.vstack(output_list))

print("Sliding Window Cross-Validation Results")
print(f"Average Train MSE: {avg_train_mse}")
print(f"Average Test MSE: {forecast_mse}")

checkpoint directory created: ./model
saving model version 0.0


description:   0%|                                                         | 0/1000 [00:00<?, ?it/s]

  self.subnode_actscale.append(torch.std(x, dim=0).detach())
  input_range = torch.std(preacts, dim=0) + 0.1
  output_range_spline = torch.std(postacts_numerical, dim=0) # for training, only penalize the spline part
  output_range = torch.std(postacts, dim=0) # for visualization, include the contribution from both spline + symbolic
| train_loss: 9.94e-02 | test_loss: 5.54e-02 | reg: 5.99e+01 | : 100%|█| 1000/1000 [01:39<00:00, 10.


saving model version 0.1
Sliding Window Cross-Validation Results
Average Train MSE: 0.009963009842744069
Average Test MSE: 0.012292689068750355


In [32]:
# Naive persistence baseline: repeat the last training row of targets for
# every held-out day and score it with the same metric as the model forecast.
last_observed = y_train.iloc[-1]
df_naive = pd.DataFrame([last_observed] * test_size, columns=y_train.columns)
# sklearn's signature is mean_squared_error(y_true, y_pred).
mean_squared_error(y_test, df_naive)

np.float64(0.025754166666666644)

In [33]:
# Recursive forecasts collected in output_list, one row per forecast step.
pd.DataFrame(output_list)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11
0,4.669116,4.669942,4.645418,4.503704,4.409211,4.290708,4.251724,4.260768,4.311587,4.334164,4.564315,4.567775
1,4.68519,4.66724,4.649655,4.518799,4.416094,4.280356,4.236471,4.234249,4.291882,4.296359,4.541456,4.533931
2,4.680865,4.666222,4.657031,4.524111,4.420663,4.282137,4.225081,4.216844,4.280601,4.277091,4.519531,4.513392
3,4.690283,4.671531,4.666153,4.540819,4.429579,4.273446,4.215519,4.201335,4.262719,4.25329,4.511799,4.495831
4,4.712254,4.689164,4.674207,4.559401,4.439812,4.267363,4.205964,4.185953,4.238068,4.234494,4.504456,4.478531
5,4.713086,4.708333,4.686869,4.576765,4.437749,4.261368,4.190318,4.167906,4.21367,4.222714,4.485957,4.450861
6,4.727198,4.722802,4.69823,4.58965,4.439564,4.25525,4.1772,4.15037,4.195618,4.205947,4.471452,4.431866
7,4.743026,4.737357,4.710787,4.601393,4.442246,4.248975,4.165376,4.134289,4.179736,4.19007,4.458708,4.41578
8,4.761212,4.75505,4.725983,4.614339,4.442118,4.240278,4.152352,4.118102,4.163934,4.176356,4.448405,4.400631
9,4.783884,4.774825,4.74234,4.62667,4.441358,4.231372,4.139438,4.102664,4.148236,4.163547,4.439978,4.386126


In [34]:
# Held-out target rows, shown for comparison with the forecast table.
y_test

Unnamed: 0_level_0,y_1,y_2,y_3,y_4,y_5,y_6,y_7,y_8,y_9,y_10,y_11,y_12
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
2024-11-22,4.72,4.67,4.63,4.46,4.42,4.37,4.32,4.3,4.35,4.41,4.67,4.6
2024-11-25,4.74,4.67,4.62,4.46,4.37,4.21,4.21,4.17,4.21,4.27,4.53,4.45
2024-11-26,4.74,4.67,4.61,4.45,4.37,4.21,4.21,4.17,4.24,4.3,4.56,4.48
2024-11-27,4.76,4.7,4.6,4.43,4.34,4.19,4.17,4.11,4.17,4.25,4.52,4.44
2024-11-29,4.76,4.69,4.58,4.42,4.3,4.13,4.1,4.05,4.1,4.18,4.45,4.36
2024-12-02,4.75,4.63,4.51,4.43,4.3,4.17,4.11,4.08,4.13,4.19,4.46,4.36
2024-12-03,4.66,4.56,4.49,4.4,4.27,4.17,4.13,4.11,4.17,4.23,4.5,4.4
2024-12-04,4.65,4.54,4.47,4.38,4.23,4.13,4.09,4.07,4.13,4.19,4.45,4.35
2024-12-05,4.59,4.53,4.46,4.38,4.23,4.15,4.1,4.07,4.12,4.17,4.43,4.33
2024-12-06,4.57,4.5,4.42,4.34,4.19,4.1,4.05,4.03,4.09,4.15,4.42,4.34


In [35]:
# Naive baseline: the last training row repeated test_size times.
df_naive

Unnamed: 0,y_1,y_2,y_3,y_4,y_5,y_6,y_7,y_8,y_9,y_10,y_11,y_12
2024-11-21,4.72,4.67,4.63,4.45,4.39,4.34,4.3,4.3,4.36,4.43,4.68,4.61
2024-11-21,4.72,4.67,4.63,4.45,4.39,4.34,4.3,4.3,4.36,4.43,4.68,4.61
2024-11-21,4.72,4.67,4.63,4.45,4.39,4.34,4.3,4.3,4.36,4.43,4.68,4.61
2024-11-21,4.72,4.67,4.63,4.45,4.39,4.34,4.3,4.3,4.36,4.43,4.68,4.61
2024-11-21,4.72,4.67,4.63,4.45,4.39,4.34,4.3,4.3,4.36,4.43,4.68,4.61
2024-11-21,4.72,4.67,4.63,4.45,4.39,4.34,4.3,4.3,4.36,4.43,4.68,4.61
2024-11-21,4.72,4.67,4.63,4.45,4.39,4.34,4.3,4.3,4.36,4.43,4.68,4.61
2024-11-21,4.72,4.67,4.63,4.45,4.39,4.34,4.3,4.3,4.36,4.43,4.68,4.61
2024-11-21,4.72,4.67,4.63,4.45,4.39,4.34,4.3,4.3,4.36,4.43,4.68,4.61
2024-11-21,4.72,4.67,4.63,4.45,4.39,4.34,4.3,4.3,4.36,4.43,4.68,4.61


In [88]:
# NOTE(review): interactive API lookup left over from development; consider
# removing this cell (and its long output) before sharing the notebook.
help(model.fit)

Help on method fit in module kan.MultKAN:

fit(dataset, opt='LBFGS', steps=100, log=1, lamb=0.0, lamb_l1=1.0, lamb_entropy=2.0, lamb_coef=0.0, lamb_coefdiff=0.0, update_grid=True, grid_update_num=10, loss_fn=None, lr=1.0, start_grid_update_step=-1, stop_grid_update_step=50, batch=-1, metrics=None, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', singularity_avoiding=False, y_th=1000.0, reg_metric='edge_forward_spline_n', display_metrics=None) method of kan.MultKAN.MultKAN instance
    training

    Args:
    -----
        dataset : dic
            contains dataset['train_input'], dataset['train_label'], dataset['test_input'], dataset['test_label']
        opt : str
            "LBFGS" or "Adam"
        steps : int
            training steps
        log : int
            logging frequency
        lamb : float
            overall penalty strength
        lamb_l1 : float
            l1 penalty strength
        lamb_entropy : float
            ent