In [1]:
import numpy as np
import pandas as pd

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchmetrics import Accuracy

from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error 

from kan import *
import warnings

# Bind `torch` explicitly: `import torch.nn as nn` only binds `nn`, so without this
# line `torch` is in scope solely through `from kan import *`.
import torch

# NOTE(review): this silences *every* warning, including scikit-learn deprecation
# notices (e.g. for `mean_squared_error(..., squared=False)`). Consider narrowing it
# to the specific noisy categories.
warnings.filterwarnings("ignore")

# Work in float64 throughout; later cells read this back via torch.get_default_dtype().
torch.set_default_dtype(torch.float64)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [2]:
def treasury_data_retrieval(path='../data/us_treasury_rates_large.csv'):
    """Load the US Treasury rates CSV and return it in chronological order.

    Args:
        path: CSV file to read. It must contain a ``Date`` column. The default is
            the notebook's original hard-coded location, so existing calls are
            unaffected.

    Returns:
        DataFrame with ``Date`` parsed to datetime, sorted oldest-first, and a
        fresh 0..n-1 RangeIndex.
    """
    df = pd.read_csv(path)
    df['Date'] = pd.to_datetime(df['Date'])
    # A stable sort keeps the file order of any duplicate dates deterministic.
    return df.sort_values(by='Date', kind='stable').reset_index(drop=True)

df = treasury_data_retrieval()

n = len(df)
h = 5  # look-back window: number of past days flattened into each feature row

# Every column after `Date` is a rate series; each one also becomes a forecast target.
rate_values = df.iloc[:, 1:].to_numpy()
n_rates = rate_values.shape[1]

# Feature row for day i = days i-h .. i-1 flattened day by day (oldest day first),
# giving h * n_rates columns named 0, 1, ...
# Slicing the raw array, instead of the previous `.stack()`, keeps NaNs in place, so a
# missing rate no longer shifts later values into the wrong columns. Building every row
# at once also avoids the quadratic row-by-row `pd.concat`.
windows = [rate_values[day - h:day].ravel() for day in range(h, n)]
features = np.vstack(windows) if windows else np.empty((0, h * n_rates))
df_flat = pd.DataFrame(features)

# Targets: the rates observed on day i itself (y_1 .. y_<n_rates>).
# NOTE(review): later cells assume exactly 12 targets (`iloc[:, -12:]`).
for col in range(1, n_rates + 1):
    df_flat[f'y_{col}'] = df.iloc[h:, col].to_numpy()

df_flat['Date'] = df['Date'].iloc[h:].to_numpy()
df_flat.columns = df_flat.columns.astype(str)
df_flat = df_flat.set_index('Date')

assert len(df_flat) == n - h

df_flat.tail()

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,y_3,y_4,y_5,y_6,y_7,y_8,y_9,y_10,y_11,y_12
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2024-12-02,4.72,4.67,4.63,4.46,4.42,4.37,4.32,4.3,4.35,4.41,...,4.51,4.43,4.3,4.17,4.11,4.08,4.13,4.19,4.46,4.36
2024-12-03,4.74,4.67,4.62,4.46,4.37,4.21,4.21,4.17,4.21,4.27,...,4.49,4.4,4.27,4.17,4.13,4.11,4.17,4.23,4.5,4.4
2024-12-04,4.74,4.67,4.61,4.45,4.37,4.21,4.21,4.17,4.24,4.3,...,4.47,4.38,4.23,4.13,4.09,4.07,4.13,4.19,4.45,4.35
2024-12-05,4.76,4.7,4.6,4.43,4.34,4.19,4.17,4.11,4.17,4.25,...,4.46,4.38,4.23,4.15,4.1,4.07,4.12,4.17,4.43,4.33
2024-12-06,4.76,4.69,4.58,4.42,4.3,4.13,4.1,4.05,4.1,4.18,...,4.42,4.34,4.19,4.1,4.05,4.03,4.09,4.15,4.42,4.34


In [4]:
def train_mse():
    predictions = model(dataset['train_input'])  # Model predictions
    mse = F.mse_loss(predictions, dataset['train_label'], reduction='mean')  # Compute MSE
    return mse ** 0.5  # Return scalar MSE value

def test_mse():
    predictions = model(dataset['test_input']) # Model predictions
    mse = F.mse_loss(predictions, dataset['test_label'], reduction='mean')  # Compute MSE
    return mse ** 0.5

In [None]:
# Single hold-out experiment (one split, not cross-validation).
# Train on all but the last `test_size` rows, then forecast those rows recursively,
# feeding each prediction back in as the newest day of the input window.
test_size = 10

# Keys keep their historical "mse" names. The stored values are RMSEs, because the
# `train_mse` / `test_mse` metrics take a square root.
fold_results = {'train_mse': [], 'test_mse': []}

# The last 12 columns of df_flat are the targets (y_1 .. y_12); the rest are lagged inputs.
X, y = df_flat.iloc[:, :-12], df_flat.iloc[:, -12:]
n_inputs = X.shape[1]
n_outputs = y.shape[1]

X_train, X_test = X.iloc[:-test_size], X.iloc[-test_size:]
y_train, y_test = y.iloc[:-test_size], y.iloc[-test_size:]

dtype = torch.get_default_dtype()
dataset = {
    'train_input': torch.from_numpy(X_train.values).type(dtype).to(device),
    'train_label': torch.from_numpy(y_train.values).type(dtype).to(device),
    # Only the first held-out row serves as the (one-step) test point during fitting.
    'test_input': torch.from_numpy(X_test.iloc[0, :].values.reshape(1, -1)).type(dtype).to(device),
    'test_label': torch.from_numpy(y_test.iloc[0, :].values.reshape(1, -1)).type(dtype).to(device),
}

model = KAN(width=[n_inputs, 20, n_outputs], grid=3, k=2, seed=42, device=device)
results = model.fit(dataset, opt="Adam", lamb=0.001, lr=0.001, steps=1000, metrics=(train_mse, test_mse))

# Recursive multi-step forecast starting from the first held-out input window.
feature = dataset['test_input']
output_list = []
with torch.no_grad():
    for _ in range(test_size):
        new = model(feature).cpu().detach().numpy().flatten()
        output_list.append(new)
        # Slide the window: drop the oldest day's n_outputs rates and append the prediction.
        old = feature.cpu().detach().numpy().flatten()[n_outputs:]
        feature = torch.from_numpy(np.append(old, new).reshape(1, -1)).type(dtype).to(device)

fold_results['train_mse'].append(results['train_mse'][-1])
fold_results['test_mse'].append(results['test_mse'][-1])

avg_train_mse = np.mean(fold_results['train_mse'])  # RMSE on the training rows
avg_test_mse = np.mean(fold_results['test_mse'])    # one-step RMSE on the first held-out row
# RMSE of the whole recursive forecast (per-target RMSE averaged over targets),
# so it is in the same units as the train metric.
forecast_rmse = np.sqrt(mean_squared_error(y_test, output_list, multioutput='raw_values')).mean()

print("Hold-out results (single split)")
print(f"Train RMSE: {avg_train_mse}")
print(f"Recursive {test_size}-step test RMSE: {forecast_rmse}")

In [None]:
# Rolling-origin evaluation. For each offset:
#   1. take a trailing window of df_flat;
#   2. fit a fresh KAN on all but the window's last `test_size` rows;
#   3. forecast those rows recursively;
#   4. compare against a naive persistence forecast (last observed rates repeated).
test_size = 5       # forecast horizon held out at the end of every window
window_size = 250   # rows per window (training rows + test rows)
max_offset = 20     # windows end 0, 5, 10, 15 rows before the end of df_flat
df_length = len(df_flat)

# Keys keep their historical "mse" names, but every stored value is an RMSE.
fold_results = {'train_mse': [], 'test_mse': [], 'naive_mse': []}

# Step by `test_size` so consecutive windows' test blocks do not overlap.
for cnt in range(0, max_offset, test_size):
    print()
    print('WINDOW SLIDING: ', cnt)

    window_end = df_length - cnt
    # Clamp at 0: a negative start would silently wrap around to the end of the frame.
    window_start = max(0, window_end - window_size)
    df_window = df_flat.iloc[window_start:window_end]

    # The last 12 columns are the targets (y_1 .. y_12); the rest are lagged inputs.
    X, y = df_window.iloc[:, :-12], df_window.iloc[:, -12:]
    n_inputs = X.shape[1]
    n_outputs = y.shape[1]

    X_train, X_test = X.iloc[:-test_size], X.iloc[-test_size:]
    y_train, y_test = y.iloc[:-test_size], y.iloc[-test_size:]

    dtype = torch.get_default_dtype()
    dataset = {
        'train_input': torch.from_numpy(X_train.values).type(dtype).to(device),
        'train_label': torch.from_numpy(y_train.values).type(dtype).to(device),
        # Only the first held-out row serves as the (one-step) test point during fitting.
        'test_input': torch.from_numpy(X_test.iloc[0, :].values.reshape(1, -1)).type(dtype).to(device),
        'test_label': torch.from_numpy(y_test.iloc[0, :].values.reshape(1, -1)).type(dtype).to(device),
    }

    model = KAN(width=[n_inputs, 20, n_outputs], grid=3, k=2, seed=42, device=device)
    results = model.fit(dataset, opt="Adam", lamb=0.001, lr=0.001, steps=1000, metrics=(train_mse, test_mse))

    # Recursive multi-step forecast starting from the first held-out input window.
    feature = dataset['test_input']
    output_list = []
    with torch.no_grad():
        for _ in range(test_size):
            new = model(feature).cpu().detach().numpy().flatten()
            output_list.append(new)
            # Slide the window: drop the oldest day's n_outputs rates and append the prediction.
            old = feature.cpu().detach().numpy().flatten()[n_outputs:]
            feature = torch.from_numpy(np.append(old, new).reshape(1, -1)).type(dtype).to(device)

    # Persistence baseline: repeat the last training-day rates across the horizon.
    df_naive = pd.DataFrame([y_train.iloc[-1]] * test_size, columns=y_train.columns)

    # `mean_squared_error(..., squared=False)` was deprecated in scikit-learn 1.4 and
    # removed in 1.6; the global warning filter hid the deprecation notice. Compute the
    # same quantity explicitly: per-target RMSE, averaged over the targets.
    train_error = results['train_mse'][-1]  # already an RMSE (pooled over all targets)
    test_error = np.sqrt(mean_squared_error(y_test, output_list, multioutput='raw_values')).mean()
    naive_error = np.sqrt(mean_squared_error(y_test, df_naive, multioutput='raw_values')).mean()

    fold_results['train_mse'].append(train_error)
    fold_results['test_mse'].append(test_error)
    fold_results['naive_mse'].append(naive_error)

    print(f'Fold Train RMSE: {train_error}')
    print(f'Fold Test RMSE: {test_error}')
    print(f'Naive Test RMSE: {naive_error}')

# Average the per-window RMSEs across all windows.
avg_train_mse = np.mean(fold_results['train_mse'])
avg_test_mse = np.mean(fold_results['test_mse'])
avg_naive_mse = np.mean(fold_results['naive_mse'])

print("Sliding Window Cross-Validation Results")
print(f"Average Train RMSE: {avg_train_mse}")
print(f"Average Test RMSE: {avg_test_mse}")
print(f"Average Naive RMSE: {avg_naive_mse}")


WINDOW SLIDING:  0
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 4.68e+00 | test_loss: 4.24e+00 | reg: 7.46e+01 | :   0%|     | 0/1000 [00:00<?, ?it/s]

| train_loss: 7.62e-02 | test_loss: 1.17e-01 | reg: 3.05e+01 | : 100%|█| 1000/1000 [00:36<00:00, 27.


saving model version 0.1
Fold Train MSE: 0.07700548542964239
Fold Test MSE: 0.08429906089615313
Naive Test MSE: 0.031666666666666655

WINDOW SLIDING:  1
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 1.78e-01 | test_loss: 2.83e-01 | reg: 6.63e+01 | :  13%|▏| 133/1000 [00:05<00:33, 25.7


KeyboardInterrupt: 

In [32]:
# Persistence ("no change") baseline: the final training-day rates repeated for every test row.
last_observed = y_train.iloc[-1]
df_naive = pd.DataFrame([last_observed for _ in range(test_size)], columns=y_train.columns)
# NOTE: this is the plain MSE, whereas the evaluation loops report RMSE-style errors.
mean_squared_error(df_naive, y_test)

np.float64(0.025754166666666644)

In [8]:
# Recursive forecasts from the most recent run: one row per forecast step, one column per target.
pd.DataFrame(output_list)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11
0,4.808254,4.722639,4.730283,4.386444,4.244575,4.121847,4.052503,4.05409,4.125195,4.168106,4.521195,4.460731
1,7.425504,7.759623,7.564955,6.729259,7.20968,6.702704,6.475952,6.409134,5.845021,6.518092,7.001816,6.102906
2,12.68833,12.75475,11.603767,11.206963,11.77373,10.609055,9.988644,9.223608,7.769438,9.584902,10.734335,9.246078
3,20.595896,20.126341,18.872407,18.177391,19.066755,17.503062,15.908176,14.37798,12.52115,15.198067,16.661805,14.342094
4,33.693696,32.786344,30.992696,29.877886,31.374907,28.604893,25.897076,22.886119,20.45854,24.664812,26.611204,23.094519


In [16]:
# Held-out targets from the most recent run, for comparison with the recursive forecasts.
y_test

Unnamed: 0_level_0,y_1,y_2,y_3,y_4,y_5,y_6,y_7,y_8,y_9,y_10,y_11,y_12
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
2024-10-09,4.93,4.84,4.75,4.46,4.24,3.99,3.89,3.91,3.97,4.06,4.41,4.34
2024-10-10,4.98,4.84,4.75,4.45,4.22,3.98,3.88,3.91,3.99,4.09,4.44,4.38
2024-10-11,4.97,4.82,4.73,4.44,4.18,3.95,3.85,3.88,3.97,4.08,4.44,4.39
2024-10-15,4.93,4.82,4.73,4.42,4.18,3.95,3.86,3.86,3.93,4.03,4.37,4.32
2024-10-16,4.91,4.8,4.72,4.42,4.17,3.93,3.84,3.84,3.92,4.02,4.36,4.3
2024-10-17,4.93,4.83,4.74,4.45,4.21,3.96,3.89,3.9,3.99,4.09,4.44,4.39
2024-10-18,4.92,4.82,4.73,4.45,4.19,3.95,3.86,3.88,3.97,4.08,4.44,4.38
2024-10-21,4.92,4.82,4.73,4.47,4.24,4.02,3.95,3.98,4.07,4.19,4.54,4.49
2024-10-22,4.89,4.81,4.72,4.47,4.24,4.03,3.98,4.0,4.1,4.2,4.55,4.49
2024-10-23,4.88,4.8,4.73,4.48,4.27,4.07,4.03,4.05,4.14,4.24,4.58,4.51


In [17]:
# Naive persistence forecast: the last training-day rates repeated for each test row.
df_naive

Unnamed: 0,y_1,y_2,y_3,y_4,y_5,y_6,y_7,y_8,y_9,y_10,y_11,y_12
2024-10-08,4.96,4.85,4.75,4.44,4.21,3.98,3.86,3.86,3.94,4.04,4.38,4.32
2024-10-08,4.96,4.85,4.75,4.44,4.21,3.98,3.86,3.86,3.94,4.04,4.38,4.32
2024-10-08,4.96,4.85,4.75,4.44,4.21,3.98,3.86,3.86,3.94,4.04,4.38,4.32
2024-10-08,4.96,4.85,4.75,4.44,4.21,3.98,3.86,3.86,3.94,4.04,4.38,4.32
2024-10-08,4.96,4.85,4.75,4.44,4.21,3.98,3.86,3.86,3.94,4.04,4.38,4.32
2024-10-08,4.96,4.85,4.75,4.44,4.21,3.98,3.86,3.86,3.94,4.04,4.38,4.32
2024-10-08,4.96,4.85,4.75,4.44,4.21,3.98,3.86,3.86,3.94,4.04,4.38,4.32
2024-10-08,4.96,4.85,4.75,4.44,4.21,3.98,3.86,3.86,3.94,4.04,4.38,4.32
2024-10-08,4.96,4.85,4.75,4.44,4.21,3.98,3.86,3.86,3.94,4.04,4.38,4.32
2024-10-08,4.96,4.85,4.75,4.44,4.21,3.98,3.86,3.86,3.94,4.04,4.38,4.32


In [88]:
# Scratch: inspect the pykan `fit` signature (optimizer, regularisation and grid-update options).
# TODO: remove before sharing the notebook.
help(model.fit)

Help on method fit in module kan.MultKAN:

fit(dataset, opt='LBFGS', steps=100, log=1, lamb=0.0, lamb_l1=1.0, lamb_entropy=2.0, lamb_coef=0.0, lamb_coefdiff=0.0, update_grid=True, grid_update_num=10, loss_fn=None, lr=1.0, start_grid_update_step=-1, stop_grid_update_step=50, batch=-1, metrics=None, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', singularity_avoiding=False, y_th=1000.0, reg_metric='edge_forward_spline_n', display_metrics=None) method of kan.MultKAN.MultKAN instance
    training

    Args:
    -----
        dataset : dic
            contains dataset['train_input'], dataset['train_label'], dataset['test_input'], dataset['test_label']
        opt : str
            "LBFGS" or "Adam"
        steps : int
            training steps
        log : int
            logging frequency
        lamb : float
            overall penalty strength
        lamb_l1 : float
            l1 penalty strength
        lamb_entropy : float
            ent