## Import libraries

In [1]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchmetrics import Accuracy

from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error 

from kan import *
import warnings
import sys
sys.path.append('../utils')
from treasury_base import *

warnings.filterwarnings("ignore")

torch.set_default_dtype(torch.float64)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


## Retrieve data

In [2]:
# Window lengths handed to `ma_data_retrieval`; the preview below shows one
# `<maturity>_MA<n>` feature column per entry.
WINDOW_LIST = [3, 5, 10, 20]
# Lag handed to `ma_data_retrieval`; a later cell sweeps it as "steps into the
# future". NOTE(review): units are presumably rows/trading days - confirm in treasury_base.
LAG = 1

def train_mse():
    """Root-mean-squared error of the global ``model`` on the training split.

    Despite the name, the value returned is ``sqrt(MSE)``. The name is kept
    because the training cells pass this function to ``model.fit(metrics=...)``
    and then read the values back as ``results['train_mse']``.

    Reads the notebook globals ``model`` and ``dataset`` (keys
    ``'train_input'`` and ``'train_label'``); both must exist before the call.

    Returns:
        A 0-dimensional tensor holding the RMSE, detached from autograd.
    """
    # The metric is only logged, never back-propagated, so skip graph construction.
    with torch.no_grad():
        predictions = model(dataset['train_input'])
        mse = F.mse_loss(predictions, dataset['train_label'], reduction='mean')
        return mse ** 0.5

def test_mse():
    predictions = model(dataset['test_input']) # Model predictions
    mse = F.mse_loss(predictions, dataset['test_label'], reduction='mean')  # Compute MSE
    return mse ** 0.5
    
# Build the working frame with the configured windows and lag.
# NOTE(review): `ma_data_retrieval` is presumably provided by the
# `treasury_base` star import - confirm.
df_ma = ma_data_retrieval(window_list=WINDOW_LIST, lag=LAG)
# Bare last expression: preview the first rows.
df_ma.head()

Unnamed: 0_level_0,1 Mo,2 Mo,3 Mo,6 Mo,1 Yr,2 Yr,3 Yr,5 Yr,7 Yr,10 Yr,...,10 Yr_MA10,10 Yr_MA20,20 Yr_MA3,20 Yr_MA5,20 Yr_MA10,20 Yr_MA20,30 Yr_MA3,30 Yr_MA5,30 Yr_MA10,30 Yr_MA20
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2019-01-31,2.42,2.43,2.41,2.46,2.55,2.45,2.43,2.43,2.51,2.63,...,2.742,2.718,2.906667,2.906,2.917,2.8935,3.053333,3.052,3.062,3.0375
2019-02-01,2.41,2.42,2.4,2.46,2.56,2.52,2.5,2.51,2.59,2.7,...,2.732,2.7165,2.876667,2.894,2.908,2.8935,3.03,3.042,3.054,3.0385
2019-02-04,2.41,2.41,2.42,2.49,2.57,2.53,2.52,2.53,2.62,2.73,...,2.727,2.7235,2.87,2.886,2.903,2.9,3.026667,3.036,3.05,3.044
2019-02-05,2.39,2.4,2.42,2.5,2.56,2.53,2.5,2.51,2.6,2.71,...,2.721,2.7265,2.876667,2.886,2.9,2.9045,3.026667,3.036,3.047,3.048
2019-02-06,2.4,2.41,2.42,2.5,2.56,2.52,2.5,2.5,2.59,2.7,...,2.718,2.727,2.896667,2.884,2.898,2.906,3.04,3.034,3.044,3.05


## KAN model training

In [25]:
TEST_SIZE = 1  # trailing rows of every window held out as the test split
LENGTH = len(df_ma)
TARGETS = df_ma.columns[:12]  # the first 12 columns are the targets; all others are features

# Per-fold metrics. Every value is a root-mean-squared error: `train_mse` and
# `test_mse` return sqrt(MSE) and the naive baseline below is scored the same
# way. The `*_mse` keys are kept because `results` is keyed by the metric
# function names.
fold_results = {'train_mse': [], 'test_mse': [], 'naive_mse': []}

for cnt in range(0, 20, 5):  # shift the window back by 0, 5, 10, 15 rows
    print()
    print('WINDOW SLIDING: ', cnt)

    # 250-row window that ends `cnt` rows before the end of the data
    df_window = df_ma[(LENGTH-cnt-250):(LENGTH-cnt)]
    # Prepare data
    X, y = df_window.drop(columns=TARGETS), df_window[TARGETS]

    # scaler = StandardScaler()
    # X = pd.DataFrame(scaler.fit_transform(X))

    n_inputs = X.shape[1]
    n_outputs = y.shape[1]

    # Chronological split: the last TEST_SIZE rows are the test set
    X_train, X_test = X[:-TEST_SIZE], X[-TEST_SIZE:]
    y_train, y_test = y[:-TEST_SIZE], y[-TEST_SIZE:]

    # `dataset` and `model` have to stay notebook-level names:
    # train_mse() / test_mse() read them as globals.
    dataset = dict()
    dtype = torch.get_default_dtype()
    dataset['train_input'] = torch.from_numpy(X_train.values).type(dtype).to(device)
    dataset['train_label'] = torch.from_numpy(y_train.values).type(dtype).to(device)
    dataset['test_input'] = torch.from_numpy(X_test.values).type(dtype).to(device)
    dataset['test_label'] = torch.from_numpy(y_test.values).type(dtype).to(device)

    # Initialize the model
    model = KAN(width=[n_inputs, 48, 64, n_outputs], grid=4, k=2, seed=42, device=device)

    # Train the model and compute metrics
    results = model.fit(dataset, opt="Adam", lamb=0.0001, lr=0.001, steps=500, metrics=(train_mse, test_mse))

    # Naive baseline: repeat the last training row for every test row
    df_naive = pd.DataFrame([y_train.iloc[-1]] * TEST_SIZE, columns=y_train.columns)

    # Metrics after the final training step
    train_error = results['train_mse'][-1]
    test_error = results['test_mse'][-1]
    # RMSE pooled over every test value, i.e. the same aggregation that
    # F.mse_loss(reduction='mean') applies inside train_mse/test_mse.
    # np.sqrt(MSE) replaces `squared=False`, which recent scikit-learn
    # releases deprecated and then removed.
    naive_error = np.sqrt(mean_squared_error(df_naive.values.flatten(), y_test.values.flatten()))

    fold_results['train_mse'].append(train_error)
    fold_results['test_mse'].append(test_error)
    fold_results['naive_mse'].append(naive_error)

    # Report this fold (values are RMSE)
    print(f'Fold Train RMSE: {train_error}')
    print(f'Fold Test RMSE: {test_error}')
    print(f'Naive Test RMSE: {naive_error}')

# Average the per-fold metrics across all windows
avg_train_mse = np.mean(fold_results['train_mse'])
avg_test_mse = np.mean(fold_results['test_mse'])
avg_naive_mse = np.mean(fold_results['naive_mse'])

print()
print("Sliding Window Cross-Validation Results")
print(f"Average Train RMSE: {avg_train_mse}")
print(f"Average Test RMSE: {avg_test_mse}")
print(f"Average Naive RMSE: {avg_naive_mse}")

In [8]:
WINDOW_LIST = [3, 5, 10, 15, 20]
TEST_SIZE = 20  # trailing rows of every window held out as the test split
TARGETS = df_ma.columns[:12]  # the first 12 columns are the targets; all others are features

# Per-fold metrics. Every value is a root-mean-squared error: `train_mse` and
# `test_mse` return sqrt(MSE) and the naive baseline below is scored the same
# way. The `*_mse` keys are kept because `results` is keyed by the metric
# function names.
fold_results = {'train_mse': [], 'test_mse': [], 'naive_mse': []}

for LAG in range(1, 6): # steps into the future
    # Rebuilds the notebook-level `df_ma` for this lag (later cells see the last one)
    df_ma = ma_data_retrieval(window_list=WINDOW_LIST, lag=LAG)
    n_rows = len(df_ma)

    for cnt in range(0, 40, 10): # sliding window
        print()
        print(f'WINDOW SLIDING: {cnt}, LAG: {LAG}')

        # 500-row window that ends `cnt` rows before the end of the data
        df_window = df_ma[(n_rows-cnt-500):(n_rows-cnt)]
        # Prepare data
        X, y = df_window.drop(columns=TARGETS), df_window[TARGETS]

        # scaler = StandardScaler()
        # X = pd.DataFrame(scaler.fit_transform(X))

        n_inputs = X.shape[1]
        n_outputs = y.shape[1]

        # Chronological split: the last TEST_SIZE rows are the test set
        X_train, X_test = X[:-TEST_SIZE], X[-TEST_SIZE:]
        y_train, y_test = y[:-TEST_SIZE], y[-TEST_SIZE:]

        # `dataset` and `model` have to stay notebook-level names:
        # train_mse() / test_mse() read them as globals.
        dataset = dict()
        dtype = torch.get_default_dtype()
        dataset['train_input'] = torch.from_numpy(X_train.values).type(dtype).to(device)
        dataset['train_label'] = torch.from_numpy(y_train.values).type(dtype).to(device)
        dataset['test_input'] = torch.from_numpy(X_test.values).type(dtype).to(device)
        dataset['test_label'] = torch.from_numpy(y_test.values).type(dtype).to(device)

        # Initialize the model
        model = KAN(width=[n_inputs, 32, n_outputs], grid=4, k=2, seed=42, device=device)

        # Train the model and compute metrics
        results = model.fit(dataset, opt="Adam", lamb=0.0001, lr=0.0015, steps=500, metrics=(train_mse, test_mse))

        # Naive baseline: repeat the training row LAG positions from the end
        # for every test row
        df_naive = pd.DataFrame([y_train.iloc[-LAG]] * TEST_SIZE, columns=y_train.columns)

        # Metrics after the final training step
        train_error = results['train_mse'][-1]
        test_error = results['test_mse'][-1]
        # np.sqrt(MSE) replaces `squared=False`, which recent scikit-learn
        # releases deprecated and then removed; the value is unchanged.
        naive_error = np.sqrt(mean_squared_error(df_naive.values.flatten(), y_test.values.flatten()))

        fold_results['train_mse'].append(train_error)
        fold_results['test_mse'].append(test_error)
        fold_results['naive_mse'].append(naive_error)

        # Report this fold (values are RMSE)
        print(f'Fold Train RMSE: {train_error}')
        print(f'Fold Test RMSE: {test_error}')
        print(f'Naive Test RMSE: {naive_error}')

# Average the per-fold metrics across all lags and windows
avg_train_mse = np.mean(fold_results['train_mse'])
avg_test_mse = np.mean(fold_results['test_mse'])
avg_naive_mse = np.mean(fold_results['naive_mse'])

print()
print("Sliding Window Cross-Validation Results")
print(f"Average Train RMSE: {avg_train_mse}")
print(f"Average Test RMSE: {avg_test_mse}")
print(f"Average Naive RMSE: {avg_naive_mse}")


WINDOW SLIDING: 0, LAG: 1
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 7.27e-02 | test_loss: 1.19e-01 | reg: 1.16e+02 | : 100%|█| 750/750 [01:21<00:00,  9.24


saving model version 0.1
Fold Train MSE: 0.07256363261841549
Fold Test MSE: 0.11935505068443238
Naive Test MSE: 0.11094668389215898

WINDOW SLIDING: 10, LAG: 1
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 7.34e-02 | test_loss: 2.00e-01 | reg: 1.15e+02 | : 100%|█| 750/750 [01:23<00:00,  9.03


saving model version 0.1
Fold Train MSE: 0.07320320509691446
Fold Test MSE: 0.2001554887713319
Naive Test MSE: 0.12736921396737383

WINDOW SLIDING: 20, LAG: 1
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 7.50e-02 | test_loss: 1.77e-01 | reg: 1.20e+02 | : 100%|█| 750/750 [01:19<00:00,  9.39


saving model version 0.1
Fold Train MSE: 0.07493622462960524
Fold Test MSE: 0.1767447413707237
Naive Test MSE: 0.14997221964972937

WINDOW SLIDING: 30, LAG: 1
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 7.56e-02 | test_loss: 2.09e-01 | reg: 1.20e+02 | : 100%|█| 750/750 [01:18<00:00,  9.56


saving model version 0.1
Fold Train MSE: 0.07561153172117846
Fold Test MSE: 0.20884070705014865
Naive Test MSE: 0.26592213772706724

WINDOW SLIDING: 0, LAG: 2
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 7.84e-02 | test_loss: 1.26e-01 | reg: 1.18e+02 | : 100%|█| 750/750 [01:18<00:00,  9.60


saving model version 0.1
Fold Train MSE: 0.07876269390699238
Fold Test MSE: 0.1256567377237061
Naive Test MSE: 0.09790726905257505

WINDOW SLIDING: 10, LAG: 2
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 7.97e-02 | test_loss: 2.11e-01 | reg: 1.17e+02 | : 100%|█| 750/750 [01:18<00:00,  9.51


saving model version 0.1
Fold Train MSE: 0.07959999060822966
Fold Test MSE: 0.21064221537398026
Naive Test MSE: 0.15235785725280676

WINDOW SLIDING: 20, LAG: 2
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 8.21e-02 | test_loss: 2.02e-01 | reg: 1.21e+02 | : 100%|█| 750/750 [01:19<00:00,  9.41


saving model version 0.1
Fold Train MSE: 0.08204245117727815
Fold Test MSE: 0.2017624642923682
Naive Test MSE: 0.156276997667603

WINDOW SLIDING: 30, LAG: 2
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 8.30e-02 | test_loss: 2.31e-01 | reg: 1.22e+02 | : 100%|█| 750/750 [01:19<00:00,  9.43


saving model version 0.1
Fold Train MSE: 0.08298026063768746
Fold Test MSE: 0.23130260539096742
Naive Test MSE: 0.24277647469774882

WINDOW SLIDING: 0, LAG: 3
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 7.93e-02 | test_loss: 1.22e-01 | reg: 1.21e+02 | : 100%|█| 750/750 [01:20<00:00,  9.36


saving model version 0.1
Fold Train MSE: 0.07959389667048104
Fold Test MSE: 0.1222689244265049
Naive Test MSE: 0.10017068766194369

WINDOW SLIDING: 10, LAG: 3
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 8.26e-02 | test_loss: 2.00e-01 | reg: 1.20e+02 | : 100%|█| 750/750 [01:20<00:00,  9.33


saving model version 0.1
Fold Train MSE: 0.08264071911118447
Fold Test MSE: 0.20008421091520867
Naive Test MSE: 0.16670707843400048

WINDOW SLIDING: 20, LAG: 3
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 8.65e-02 | test_loss: 2.24e-01 | reg: 1.21e+02 | : 100%|█| 750/750 [01:20<00:00,  9.30


saving model version 0.1
Fold Train MSE: 0.08641593899672162
Fold Test MSE: 0.2235664177510921
Naive Test MSE: 0.18431856480922734

WINDOW SLIDING: 30, LAG: 3
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 8.50e-02 | test_loss: 2.80e-01 | reg: 1.23e+02 | : 100%|█| 750/750 [01:18<00:00,  9.52


saving model version 0.1
Fold Train MSE: 0.08487054103439665
Fold Test MSE: 0.2799849784960966
Naive Test MSE: 0.2550187901573791

WINDOW SLIDING: 0, LAG: 4
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 8.08e-02 | test_loss: 1.30e-01 | reg: 1.21e+02 | : 100%|█| 750/750 [01:18<00:00,  9.53


saving model version 0.1
Fold Train MSE: 0.08083823213505514
Fold Test MSE: 0.12973704182167167
Naive Test MSE: 0.09919677414109794

WINDOW SLIDING: 10, LAG: 4
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 8.10e-02 | test_loss: 2.39e-01 | reg: 1.23e+02 | : 100%|█| 750/750 [01:19<00:00,  9.47


saving model version 0.1
Fold Train MSE: 0.08075559279961135
Fold Test MSE: 0.23888998925420454
Naive Test MSE: 0.2294676810940195

WINDOW SLIDING: 20, LAG: 4
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 8.88e-02 | test_loss: 2.28e-01 | reg: 1.22e+02 | : 100%|█| 750/750 [01:19<00:00,  9.43


saving model version 0.1
Fold Train MSE: 0.08876650676045206
Fold Test MSE: 0.22781008200546632
Naive Test MSE: 0.29922956627534875

WINDOW SLIDING: 30, LAG: 4
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 8.68e-02 | test_loss: 2.73e-01 | reg: 1.23e+02 | : 100%|█| 750/750 [01:19<00:00,  9.40


saving model version 0.1
Fold Train MSE: 0.0866625856752807
Fold Test MSE: 0.2729952600881827
Naive Test MSE: 0.24907913735731993

WINDOW SLIDING: 0, LAG: 5
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 7.75e-02 | test_loss: 1.48e-01 | reg: 1.23e+02 | : 100%|█| 750/750 [01:20<00:00,  9.37


saving model version 0.1
Fold Train MSE: 0.07743177785799693
Fold Test MSE: 0.1478257142420265
Naive Test MSE: 0.10156278846112872

WINDOW SLIDING: 10, LAG: 5
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 7.78e-02 | test_loss: 2.50e-01 | reg: 1.26e+02 | : 100%|█| 750/750 [01:19<00:00,  9.40


saving model version 0.1
Fold Train MSE: 0.07771074430551927
Fold Test MSE: 0.25033751094265566
Naive Test MSE: 0.2195630129750151

WINDOW SLIDING: 20, LAG: 5
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 8.63e-02 | test_loss: 1.87e-01 | reg: 1.25e+02 | : 100%|█| 750/750 [01:19<00:00,  9.41


saving model version 0.1
Fold Train MSE: 0.08629425919955788
Fold Test MSE: 0.186722485442877
Naive Test MSE: 0.34475836948989846

WINDOW SLIDING: 30, LAG: 5
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 8.67e-02 | test_loss: 3.03e-01 | reg: 1.25e+02 | : 100%|█| 750/750 [01:20<00:00,  9.35

saving model version 0.1
Fold Train MSE: 0.0863870247650075
Fold Test MSE: 0.3034451026913566
Naive Test MSE: 0.2574805817921032

Sliding Window Cross-Validation Results
Average Train MSE: 0.08090339048537831
Average Test MSE: 0.2029063864367501
Average Naive MSE: 0.19050409432777726





In [None]:
WINDOW_LIST = [1]
TEST_SIZE = 20
TARGETS = df_ma.columns[:12]

# Fresh, empty metric lists (one per metric name)
fold_results = {key: [] for key in ('train_mse', 'test_mse', 'naive_mse')}

# Both loops run exactly once (LAG = 1, cnt = 0); the loop form mirrors the
# sweep cell above.
for LAG in range(1, 2): # steps into the future
    df_ma = ma_data_retrieval(window_list=WINDOW_LIST, lag=LAG)

    for cnt in range(0, 20, 20): # sliding window
        print()
        print(f'WINDOW SLIDING: {cnt}, LAG: {LAG}')

        window_stop = len(df_ma) - cnt
        df_window = df_ma.iloc[window_stop - 250:window_stop]

# Split the last window built above into features/targets, then train/test
X = df_window.drop(columns=TARGETS)
y = df_window[TARGETS]

X_train = X.iloc[:-TEST_SIZE]
X_test = X.iloc[-TEST_SIZE:]
y_train = y.iloc[:-TEST_SIZE]
y_test = y.iloc[-TEST_SIZE:]

df_ma

In [10]:
# Test-set predictions of the current `model`, flattened into a 1-D NumPy array
pred = model(dataset['test_input']).detach().cpu().numpy().flatten()
pred

array([5.21237847, 5.1619604 , 5.04709044, 4.67967744, 4.15193884,
       3.71704813, 3.56014822, 3.47621561, 3.59151907, 3.69388815,
       4.05337747, 3.97061296, 5.17205815, 5.12886437, 5.01424317,
       4.66327922, 4.14612473, 3.72125655, 3.5663183 , 3.46866433,
       3.58347315, 3.67848346, 4.03141349, 3.94227142, 5.13233578,
       5.09728157, 4.98248951, 4.64939049, 4.14524081, 3.72981473,
       3.5777566 , 3.4666548 , 3.57973041, 3.6675905 , 4.01431426,
       3.91733525, 5.07160511, 5.04674751, 4.93306038, 4.62497583,
       4.14623009, 3.74496875, 3.59395559, 3.47161179, 3.57972479,
       3.66129636, 3.99697002, 3.893777  , 4.99595397, 4.9818709 ,
       4.87144346, 4.59243931, 4.15340526, 3.7638003 , 3.61302775,
       3.48571821, 3.58438334, 3.65794603, 3.9813363 , 3.87494375,
       4.92350077, 4.91865864, 4.81355449, 4.56344598, 4.16675478,
       3.79046157, 3.63740181, 3.50891515, 3.59565802, 3.65637864,
       3.97176596, 3.86572954, 4.85906877, 4.85814391, 4.76263

In [11]:
# Display the naive-baseline frame left over from the most recent training cell
df_naive

Unnamed: 0,1 Mo,2 Mo,3 Mo,6 Mo,1 Yr,2 Yr,3 Yr,5 Yr,7 Yr,10 Yr,20 Yr,30 Yr
2024-09-10,5.18,5.18,5.06,4.65,4.07,3.59,3.42,3.43,3.53,3.65,4.04,3.97
2024-09-10,5.18,5.18,5.06,4.65,4.07,3.59,3.42,3.43,3.53,3.65,4.04,3.97
2024-09-10,5.18,5.18,5.06,4.65,4.07,3.59,3.42,3.43,3.53,3.65,4.04,3.97
2024-09-10,5.18,5.18,5.06,4.65,4.07,3.59,3.42,3.43,3.53,3.65,4.04,3.97
2024-09-10,5.18,5.18,5.06,4.65,4.07,3.59,3.42,3.43,3.53,3.65,4.04,3.97
2024-09-10,5.18,5.18,5.06,4.65,4.07,3.59,3.42,3.43,3.53,3.65,4.04,3.97
2024-09-10,5.18,5.18,5.06,4.65,4.07,3.59,3.42,3.43,3.53,3.65,4.04,3.97
2024-09-10,5.18,5.18,5.06,4.65,4.07,3.59,3.42,3.43,3.53,3.65,4.04,3.97
2024-09-10,5.18,5.18,5.06,4.65,4.07,3.59,3.42,3.43,3.53,3.65,4.04,3.97
2024-09-10,5.18,5.18,5.06,4.65,4.07,3.59,3.42,3.43,3.53,3.65,4.04,3.97


In [12]:
# Display the current test targets for comparison with the naive baseline
y_test

Unnamed: 0_level_0,1 Mo,2 Mo,3 Mo,6 Mo,1 Yr,2 Yr,3 Yr,5 Yr,7 Yr,10 Yr,20 Yr,30 Yr
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
2024-09-11,5.21,5.19,5.1,4.72,4.12,3.62,3.45,3.45,3.54,3.65,4.03,3.96
2024-09-12,5.18,5.22,5.06,4.68,4.09,3.64,3.47,3.47,3.57,3.68,4.07,4.0
2024-09-13,5.15,5.17,4.97,4.6,4.0,3.57,3.42,3.43,3.53,3.66,4.05,3.98
2024-09-16,5.11,5.1,4.96,4.55,3.96,3.56,3.42,3.41,3.51,3.63,4.01,3.94
2024-09-17,5.05,5.05,4.95,4.55,3.99,3.59,3.45,3.44,3.53,3.65,4.02,3.96
2024-09-18,4.91,4.91,4.84,4.5,3.95,3.61,3.49,3.47,3.58,3.7,4.08,4.03
2024-09-19,4.89,4.91,4.8,4.46,3.93,3.59,3.47,3.49,3.6,3.73,4.11,4.06
2024-09-20,4.87,4.88,4.75,4.43,3.92,3.55,3.46,3.48,3.59,3.73,4.1,4.07
2024-09-23,4.85,4.84,4.72,4.4,3.91,3.57,3.47,3.51,3.62,3.75,4.12,4.09
2024-09-24,4.78,4.78,4.69,4.36,3.88,3.49,3.44,3.47,3.6,3.74,4.13,4.09


In [90]:
# RMSE of the naive baseline pooled over all test values. np.sqrt(MSE) replaces
# `squared=False`, which recent scikit-learn releases deprecated and then removed.
np.sqrt(mean_squared_error(df_naive.values.flatten(), y_test.values.flatten()))

np.float64(0.28945350806879727)

In [92]:
# Flatten baseline and actual values to 1-D so they can be compared side by side
x1 = df_naive.to_numpy().flatten()
x2 = y_test.to_numpy().flatten()
for flat_values in (x1, x2):
    print(flat_values)

[4.96 4.85 4.75 4.44 4.21 3.98 3.86 3.86 3.94 4.04 4.38 4.32]
[4.68 4.71 4.64 4.41 4.31 4.27 4.2  4.27 4.37 4.42 4.71 4.6 ]


## Optuna training

In [None]:
import optuna

def objective(trial):
    """Toy objective: squared distance between a sampled ``x`` and 2.

    ``x`` is requested from ``trial`` as a float in [-10, 10].
    Note: a later cell defines another ``objective`` that replaces this one.
    """
    sampled_x = trial.suggest_float('x', -10, 10)
    distance = sampled_x - 2
    return distance ** 2

# Run the toy optimisation: 100 trials on a study created with default settings
study = optuna.create_study()
study.optimize(objective, n_trials=100)

# Bare last expression: display the best sampled parameters
study.best_params  # E.g. {'x': 2.002108042}

In [None]:
import optuna
import torch

def train_mse(model, dataset):
    """Mean squared error of ``model`` on the training split of ``dataset``.

    Uses ``dataset['train_input']`` and ``dataset['train_label']``. Unlike the
    zero-argument ``train_mse`` defined earlier in the notebook, this version
    takes its inputs explicitly and returns the plain MSE (no square root);
    running this cell replaces the earlier definition.
    """
    outputs = model(dataset['train_input'])
    return torch.nn.functional.mse_loss(input=outputs, target=dataset['train_label'])

def test_mse(model, dataset):
    predictions = model(dataset['test_input'])  # Model predictions
    loss = torch.nn.functional.mse_loss(predictions, dataset['test_label'])
    return loss

# Define the objective function for Optuna
def objective(trial):
    """Train one KAN with trial-sampled hyperparameters and score it on the test split.

    Reads the notebook globals ``X_train``, ``y_train``, ``X_test``, ``y_test``
    and ``device``, plus the two-argument ``train_mse`` / ``test_mse`` helpers.

    Args:
        trial: Optuna trial used to sample the hyperparameters.

    Returns:
        float: test-split mean squared error after the final training step.
    """
    # Define the hyperparameter search space
    n_layers = trial.suggest_int('n_layers', 1, 2)  # Number of hidden layers
    layer_sizes = [trial.suggest_int(f'n_units_l{i}', 16, 64, step=16) for i in range(n_layers)]
    grid = trial.suggest_int('grid', 2, 4)          # KAN grid size
    lamb = trial.suggest_float('lamb', 1e-4, 1e-2, log=True)  # Regularization rate
    lr = trial.suggest_float('lr', 1e-4, 1e-2, log=True)       # Learning rate
    steps = trial.suggest_int('steps', 500, 2000, step=500)   # Training steps

    # Take the input/output widths from the data actually being fitted rather
    # than from `n_inputs` / `n_outputs` left behind by an earlier cell, which
    # no longer match once X_train is rebuilt with a different WINDOW_LIST.
    width = [X_train.shape[1]] + layer_sizes + [y_train.shape[1]]

    # Initialize dataset
    dataset = dict()
    dtype = torch.get_default_dtype()
    dataset['train_input'] = torch.from_numpy(X_train.values).type(dtype).to(device)
    dataset['train_label'] = torch.from_numpy(y_train.values).type(dtype).to(device)
    dataset['test_input'] = torch.from_numpy(X_test.values).type(dtype).to(device)
    dataset['test_label'] = torch.from_numpy(y_test.values).type(dtype).to(device)

    # Initialize the model
    model = KAN(width=width, grid=grid, k=2, seed=42, device=device)

    # `model.fit` stores each metric under the metric function's name (the
    # training cells read results['train_mse'] / results['test_mse']). Two
    # lambdas would both be named '<lambda>' and collide, so wrap the helpers
    # in zero-argument functions with distinct names.
    def train_metric():
        return train_mse(model, dataset)
    train_metric.__name__ = 'train_mse'

    def test_metric():
        return test_mse(model, dataset)
    test_metric.__name__ = 'test_mse'

    # Train the model
    results = model.fit(
        dataset,
        opt="Adam",
        lamb=lamb,
        lr=lr,
        steps=steps,
        metrics=(train_metric, test_metric)
    )

    # NOTE(review): hyperparameters are selected on the test split, so the
    # reported score is optimistic; consider a separate validation split.
    # Return the recorded test MSE as a plain float for Optuna to minimize.
    return float(results['test_mse'][-1])

# Create an Optuna study
# Seed the sampler so repeated runs propose the same sequence of trials
# (TPE is also what Optuna uses when no sampler is given).
study = optuna.create_study(direction='minimize', sampler=optuna.samplers.TPESampler(seed=42))
study.optimize(objective, n_trials=50)

# Best parameters and results
print("Best parameters:", study.best_params)
print("Best test MSE:", study.best_value)
