In [None]:
import torch
from kan import *


In [None]:
from sklearn.datasets import make_friedman1

def get_dataset(n_samples, n_features, noise, random_state=None):
    """Build a Friedman #1 regression dataset in the dict layout passed to ``model.fit``.

    The train and test splits are two independent calls to ``make_friedman1``:
    ``int(0.8 * n_samples)`` training rows generated with the requested
    ``noise`` and ``int(0.2 * n_samples)`` test rows generated with
    ``noise=0.``, so the test labels are noise-free.

    Args:
        n_samples: Total number of samples requested across both splits.
        n_features: Number of input features per sample.
        noise: Noise level forwarded to ``make_friedman1`` for the training split.
        random_state: Optional seed for reproducible data. ``None`` (default)
            keeps the previous unseeded behaviour. An integer seeds the
            training split directly and the test split with ``seed + 1``, so
            the two splits do not share the same random stream. Any other
            value is forwarded unchanged to both calls.

    Returns:
        dict with float32 tensors under the keys ``'train_input'``,
        ``'test_input'``, ``'train_label'`` and ``'test_label'``; the label
        tensors have a trailing dimension of size 1.
    """
    import numbers  # local import: only needed for the seed type check below

    n_train = int(0.8 * n_samples)
    n_test = int(0.2 * n_samples)

    # Re-using one integer seed for both calls would restart the same random
    # stream, so the test split gets a different (but deterministic) seed.
    if isinstance(random_state, numbers.Integral):
        test_seed = (int(random_state) + 1) % 2**32
    else:
        test_seed = random_state

    X_train, y_train = make_friedman1(
        n_samples=n_train, n_features=n_features, noise=noise, random_state=random_state
    )
    X_test, y_test = make_friedman1(
        n_samples=n_test, n_features=n_features, noise=0., random_state=test_seed
    )

    # Store the dataset in the desired dictionary format
    dataset_friedmann = {
        'train_input': torch.tensor(X_train, dtype=torch.float32),
        'test_input': torch.tensor(X_test, dtype=torch.float32),
        'train_label': torch.tensor(y_train, dtype=torch.float32).unsqueeze(1),
        'test_label': torch.tensor(y_test, dtype=torch.float32).unsqueeze(1),
    }
    return dataset_friedmann



In [None]:
# Instantiate the KAN under test with layer widths 15 -> 6 -> 2 -> 1, grid=3 and k=3.
model = KAN(width=[15,6,2,1], grid=3, k=3)
# Request 20000 samples in total with 15 features (matching the first entry of `width`).
# get_dataset draws 80% of them as training data with noise=1. and 20% as noise-free test data.
dataset = get_dataset(n_samples = 20000, n_features = 15 , noise = 1.)

In [None]:
# Warm-up fit of the initial model with every regularisation weight set to zero.
model.fit(
    dataset,
    opt="LBFGS",
    steps=20,
    batch=1024,
    lamb=0.,
    lamb_l1=0,
    lamb_coeffdiff=0,
    lamb_entropy=0,
)

# Refine the grid step by step; after each refinement refit the model and keep
# the last recorded train/test loss together with the refined model itself.
grids = [10, 20, 50]
models, train_rmse, test_rmse = [], [], []
for grid_size in grids:
    model = model.refine(new_grid=grid_size)
    results = model.fit(
        dataset,
        opt="LBFGS",
        steps=50,
        stop_grid_update_step=30,
        batch=1024,
    )
    for history, key in ((train_rmse, 'train_loss'), (test_rmse, 'test_loss')):
        history.append(results[key][-1].item())
    models.append(model)

In [None]:
# Show the repr of the first entry of `models`.
# NOTE(review): `models` is rebuilt by more than one cell in this notebook, so which
# model this displays depends on which of those cells ran last.
models[0]

In [None]:
import random
# Evaluate the second stored model on a freshly sampled, noise-free Friedman #1 set.
mod = models[1]
criterion = torch.nn.MSELoss()

# Keep the randomly chosen seed in a variable so this exact evaluation set can be
# regenerated later (previously the seed was drawn inline and lost).
eval_seed = random.randint(1, 100)
X_train, y_train = make_friedman1(n_samples=20000, n_features=15, random_state=eval_seed, noise=0.)
eval_input = torch.tensor(X_train, dtype=torch.float32)
eval_target = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)

# Inference only: skip autograd bookkeeping for the forward pass over 20000 samples.
with torch.no_grad():
    out = mod(eval_input)

print(mod.grid)
# Last expression of the cell: the MSE against the noise-free targets is displayed.
criterion(out, eval_target)


In [12]:
import time
import dill
import os  # used only to create the output directory before the sweep starts

# Sweep configuration: one KAN is trained per (input dimension, noise level, grid size).
adjust = True
widths = [10,20] ## Grids
n_samples = 20000
in_dims = [5,10,15,100]
noises = [0.,0.2,0.5,1.]
# NOTE(review): these two lists are declared here but nothing in this cell fills them.
train_losses = []
test_losses = []

# Create the output directory up front: a missing "models/" folder would otherwise
# only surface as a FileNotFoundError after the first (long) training runs finish.
os.makedirs("models", exist_ok=True)

for in_dim in in_dims:
    for noise in noises:
        # One list per (in_dim, noise) pair, holding a trained model for each grid size.
        # This rebinds the notebook-level name `models`.
        models = []
        for width in widths:
            model = KAN(width=[in_dim,6,2,1], grid=width, k=3)
            dataset = get_dataset(n_samples = n_samples, n_features = in_dim , noise = noise)
            print(in_dim, noise, width)
            results = model.fit(dataset, opt="Adam", steps= 100, stop_grid_update_step=30, batch = -1)
            models.append(model)
        # Persist all grid-size variants for this configuration in a single dill file.
        with open(f"models/Friedmann_1_KAN_spline_Adam{noise}_{in_dim}.dill", "wb") as f:
            dill.dump(models, f)

checkpoint directory created: ./model
saving model version 0.0
5 0.0 10


| train_loss: 7.27e-02 | test_loss: 8.19e-02 | reg: 5.27e+01 | : 100%|█| 100/100 [2:44:17<00:00, 98.


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
5 0.0 20


| train_loss: 1.15e-01 | test_loss: 1.86e-01 | reg: 5.89e+01 | : 100%|█| 100/100 [2:59:58<00:00, 107


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
5 0.2 10


| train_loss: 2.01e-01 | test_loss: 1.02e-01 | reg: 5.73e+01 | : 100%|█| 100/100 [2:43:28<00:00, 98.


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
5 0.2 20


| train_loss: 1.99e-01 | test_loss: 3.64e-01 | reg: 4.04e+01 | : 100%|█| 100/100 [2:53:22<00:00, 104


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
5 0.5 10


| train_loss: 5.24e-01 | test_loss: 2.69e-01 | reg: 5.30e+01 | : 100%|█| 100/100 [2:03:23<00:00, 74.


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
5 0.5 20


| train_loss: 4.79e-01 | test_loss: 1.58e-01 | reg: 5.80e+01 | : 100%|█| 100/100 [41:21<00:00, 24.81


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
5 1.0 10


| train_loss: 9.78e-01 | test_loss: 3.33e-01 | reg: 5.87e+01 | : 100%|█| 100/100 [22:50<00:00, 13.70


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
5 1.0 20


| train_loss: 9.29e-01 | test_loss: 4.93e-01 | reg: 5.70e+01 | : 100%|█| 100/100 [28:44<00:00, 17.24


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
10 0.0 10


| train_loss: 1.00e-02 | test_loss: 1.16e-02 | reg: 3.88e+01 | : 100%|█| 100/100 [22:49<00:00, 13.70


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
10 0.0 20


| train_loss: 7.02e-03 | test_loss: 7.96e-03 | reg: 4.59e+01 | : 100%|█| 100/100 [32:47<00:00, 19.68


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
10 0.2 10


| train_loss: 1.90e-01 | test_loss: 6.37e-02 | reg: 4.03e+01 | : 100%|█| 100/100 [23:31<00:00, 14.12


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
10 0.2 20


| train_loss: 1.80e-01 | test_loss: 9.83e-02 | reg: 4.65e+01 | : 100%|█| 100/100 [33:14<00:00, 19.94


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
10 0.5 10


| train_loss: 4.67e-01 | test_loss: 1.78e-01 | reg: 4.41e+01 | : 100%|█| 100/100 [23:09<00:00, 13.89


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
10 0.5 20


| train_loss: 4.37e-01 | test_loss: 2.98e-01 | reg: 4.88e+01 | : 100%|█| 100/100 [32:22<00:00, 19.42


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
10 1.0 10


| train_loss: 9.12e-01 | test_loss: 4.52e-01 | reg: 4.73e+01 | : 100%|█| 100/100 [22:50<00:00, 13.70


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
10 1.0 20


| train_loss: 1.00e+00 | test_loss: 9.58e-01 | reg: 5.35e+01 | : 100%|█| 100/100 [31:43<00:00, 19.03


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
15 0.0 10


| train_loss: 2.25e-01 | test_loss: 2.45e-01 | reg: 6.17e+01 | : 100%|█| 100/100 [26:33<00:00, 15.93


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
15 0.0 20


| train_loss: 2.29e-01 | test_loss: 3.06e-01 | reg: 5.95e+01 | : 100%|█| 100/100 [37:53<00:00, 22.73


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
15 0.2 10


| train_loss: 5.60e-01 | test_loss: 5.60e-01 | reg: 6.44e+01 | : 100%|█| 100/100 [27:06<00:00, 16.27


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
15 0.2 20


| train_loss: 2.62e-01 | test_loss: 2.76e-01 | reg: 5.73e+01 | : 100%|█| 100/100 [37:39<00:00, 22.59


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
15 0.5 10


| train_loss: 6.13e-01 | test_loss: 4.33e-01 | reg: 6.43e+01 | : 100%|█| 100/100 [27:13<00:00, 16.33


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
15 0.5 20


| train_loss: 4.68e-01 | test_loss: 2.85e-01 | reg: 6.56e+01 | : 100%|█| 100/100 [37:43<00:00, 22.63


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
15 1.0 10


| train_loss: 1.27e+00 | test_loss: 9.30e-01 | reg: 7.08e+01 | : 100%|█| 100/100 [25:55<00:00, 15.56


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
15 1.0 20


| train_loss: 1.32e+00 | test_loss: 1.66e+00 | reg: 6.20e+01 | : 100%|█| 100/100 [36:07<00:00, 21.67


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
100 0.0 10


| train_loss: 1.27e+00 | test_loss: 1.38e+00 | reg: 1.47e+02 | : 100%|█| 100/100 [1:18:41<00:00, 47.


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
100 0.0 20


| train_loss: 2.09e-01 | test_loss: 2.22e+00 | reg: 1.33e+02 | : 100%|█| 100/100 [1:57:41<00:00, 70.


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
100 0.2 10


| train_loss: 1.60e+00 | test_loss: 1.68e+00 | reg: 1.10e+02 | : 100%|█| 100/100 [1:19:40<00:00, 47.


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
100 0.2 20


| train_loss: 1.70e-01 | test_loss: 1.99e+00 | reg: 1.12e+02 | : 100%|█| 100/100 [2:35:56<00:00, 93.


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
100 0.5 10


| train_loss: 1.56e+00 | test_loss: 1.55e+00 | reg: 1.10e+02 | : 100%|█| 100/100 [2:06:17<00:00, 75.


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
100 0.5 20


| train_loss: 1.11e-01 | test_loss: 9.22e-01 | reg: 9.42e+01 | : 100%|█| 100/100 [3:11:20<00:00, 114


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
100 1.0 10


| train_loss: 2.02e+00 | test_loss: 1.93e+00 | reg: 1.15e+02 | : 100%|█| 100/100 [2:58:03<00:00, 106


saving model version 0.1
checkpoint directory created: ./model
saving model version 0.0
100 1.0 20


| train_loss: 3.35e-01 | test_loss: 2.52e+00 | reg: 9.38e+01 | : 100%|█| 100/100 [3:46:51<00:00, 136


saving model version 0.1
