In [1]:
from decimal import Decimal
import numpy as np
from better_kan import KAN as BetterKAN
from kan import create_dataset,LBFGS
import torch
from torch import nn
import matplotlib.pyplot as plt
import scipy.special

def convert_func(f):
    """Adapt a per-variable function ``f(x0, x1, ...)`` to accept one 2-D batch.

    The returned callable takes ``x`` of shape (N, n_var). It passes column i
    to ``f`` as the i-th positional argument, shaped (N, 1).
    """
    def batched(x):
        # (N, n_var) -> (N, n_var, 1) -> (n_var, N, 1); unpacking iterates dim 0.
        columns = x.unsqueeze(-1).permute(1, 0, 2)
        return f(*columns)
    return batched

def size_of_model(model):
    """Return the number of bytes occupied by ``model.parameters()``.

    Each parameter contributes ``numel() * element_size()``. Buffers are not counted.
    """
    return sum(param.numel() * param.element_size() for param in model.parameters())
def params_of_model(model):
    """Return the number of scalar parameters in ``model`` (trainable or not).

    This value is used as the "Number of parameters" axis of the comparison
    plot. It counts elements directly. The previous ``size_in_bytes / 8`` only
    matched the true count for 8-byte (float64) parameters, and halved it for
    the float32 models trained here.
    """
    return sum(p.numel() for p in model.parameters())
def ellipj(x,y):
    """Jacobi elliptic function sn(u | m), returned as float32.

    ``scipy.special.ellipj`` yields the tuple (sn, cn, dn, ph); only sn is kept.
    NOTE(review): ``.float()`` exists only on torch tensors, so this relies on
    scipy's ufunc wrapping its output back into a tensor for tensor inputs —
    confirm for the installed numpy/torch versions.
    """
    sn, _cn, _dn, _ph = scipy.special.ellipj(x, y)
    return sn.real.float()

def ellipkinc(x,y):
    """Incomplete elliptic integral of the first kind, F(phi | m), via scipy."""
    phi, m = x, y
    return scipy.special.ellipkinc(phi, m)

def ellipeinc(x,y):
    """Incomplete elliptic integral of the second kind, E(phi | m), via scipy."""
    phi, m = x, y
    return scipy.special.ellipeinc(phi, m)
class MLP(torch.nn.Module):
    """Fully connected baseline network compared against the KAN.

    ``num_layers`` counts Linear layers: ``in -> hidden -> ... -> hidden -> out``.
    ``activation`` follows every Linear except the last. ``num_layers == 1``
    is a single ``Linear(in_features, out_features)``.

    Args:
        in_features: input dimension.
        out_features: output dimension.
        hidden_dim: width of every hidden layer.
        num_layers: number of Linear layers (must be >= 1).
        activation: module placed after each hidden Linear. It is deep-copied
            per layer, so parametric activations (e.g. PReLU) do not share
            weights across layers.
        seed: passed to ``torch.manual_seed`` before the layers are built, so
            initial weights are reproducible. Note this reseeds torch's
            global RNG.

    Raises:
        ValueError: if ``num_layers < 1``.
    """
    def __init__(self, in_features, out_features, hidden_dim=128, num_layers=3, activation=torch.nn.SiLU(),seed=0):
        import copy

        super(MLP, self).__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        # Seed immediately before creating Linear layers; deepcopy consumes no RNG,
        # so the weights match the original construction order exactly.
        torch.manual_seed(seed)
        self.layers = nn.Sequential()
        if num_layers == 1:
            self.layers.append(nn.Linear(in_features, out_features))
            return
        dims = [in_features] + [hidden_dim] * (num_layers - 1)
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            self.layers.append(nn.Linear(d_in, d_out))
            # A fresh copy per layer: reusing one instance would tie parameters.
            self.layers.append(copy.deepcopy(activation))
        self.layers.append(nn.Linear(hidden_dim, out_features))

    def forward(self, x):
        """Apply the layer stack to ``x`` (last dim must equal ``in_features``)."""
        return self.layers(x)
from tqdm import tqdm
def test_kan(f,ranges,width,device="cuda",steps=50,seed=0,degrees=(5,10,20,50),
             mlp_depths=(3,4,5,6),mlp_widths=(8,16,32,64,128,256),train_num=10000):
    """Fit BetterKANs and MLP baselines to ``f`` and plot test MSE vs. parameter count.

    Args:
        f: function of ``width[0]`` positional tensors, each shaped (N, 1)
            (adapted with ``convert_func``).
        ranges: per-variable input ranges, forwarded to ``create_dataset``.
        width: KAN layer widths; ``width[0]`` is the number of input variables.
        device: torch device for data and models.
        steps: optimisation steps per model (KAN ``train`` and MLP LBFGS).
        seed: seed for MLP weight initialisation.
        degrees: BetterKAN degrees, trained in order, each warm-started from the
            previous model.
        mlp_depths: MLP ``num_layers`` values; one curve per depth.
        mlp_widths: MLP ``hidden_dim`` values swept for each depth.
        train_num: number of training samples passed to ``create_dataset``.

    Both model families are scored with the same metric: MSE on the dataset's
    test split. Previously the MLP curve showed the final training loss.
    """
    n_var=width[0]
    print(n_var)
    dataset=create_dataset(convert_func(f),n_var,ranges,device=device,train_num=train_num)
    y=dataset["train_label"]
    print(y.min(),y.max())
    kan_losses=_sweep_kan(dataset,width,degrees,device,steps)
    mlp_losses=_sweep_mlp(dataset,n_var,mlp_depths,mlp_widths,device,steps,seed)
    _plot_losses(kan_losses,mlp_losses)


def _test_mse(model,dataset):
    """MSE of ``model`` on the dataset's test split, without building an autograd graph."""
    with torch.no_grad():
        prediction=model(dataset["test_input"])
        return torch.nn.MSELoss()(prediction,dataset["test_label"]).item()


def _sweep_kan(dataset,width,degrees,device,steps):
    """Train one BetterKAN per degree and return ``{parameter count: test MSE}``."""
    losses={}
    previous=None
    for deg in degrees:
        print(f"BetterKAN {deg}")
        # NOTE(review): base_fn=... passes Ellipsis; confirm BetterKAN treats it as intended.
        kan=BetterKAN(width,deg,device=device,bias_trainable=False,base_fn=...,symbolic_enabled=False).to(device)
        if previous is not None:
            # Warm-start from the previous model.
            # NOTE(review): grid is initialised from test inputs; consider train_input.
            kan.initialize_grid_from_another_model(previous,x=dataset["test_input"])
        kan.train(dataset,lr=0.1,grid_update_num=10,steps=steps)
        losses[params_of_model(kan)]=_test_mse(kan,dataset)
        previous=kan
    return losses


def _sweep_mlp(dataset,n_var,depths,widths,device,steps,seed):
    """Train an MLP per (depth, width) with LBFGS; return ``{"depth=d": {params: test MSE}}``."""
    losses={}
    loss_fn=torch.nn.MSELoss()
    for depth in depths:
        per_depth={}
        losses[f"depth={depth}"]=per_depth
        for features in widths:
            mlp=MLP(n_var,1,hidden_dim=features,num_layers=depth,seed=seed).to(device)
            opt=LBFGS(mlp.parameters(),lr=0.1)

            def closure():
                # LBFGS may re-evaluate the objective several times per step.
                opt.zero_grad()
                loss=loss_fn(mlp(dataset["train_input"]),dataset["train_label"])
                loss.backward()
                return loss

            for _ in tqdm(range(steps)):
                opt.step(closure)
            per_depth[params_of_model(mlp)]=_test_mse(mlp,dataset)
    return losses


def _plot_losses(kan_losses,mlp_losses):
    """Log-log plot of test MSE against parameter count for the KAN and each MLP depth."""
    fig,ax=plt.subplots()
    ax.plot(list(kan_losses.keys()),list(kan_losses.values()),label="KAN")
    for label,curve in mlp_losses.items():
        ax.plot(list(curve.keys()),list(curve.values()),label=label)
    ax.legend()
    ax.set_xlabel("Number of parameters")
    ax.set_ylabel("Test MSE")
    ax.set_xscale("log")
    ax.set_yscale("log")
    return fig

                
            
            
            


In [2]:
# KAN vs. MLP scaling on F(phi | m): phi in [-1, 1], m in [0, 1], KAN widths [2, 2, 1].
test_kan(ellipkinc, ranges=[[-1, 1], [0, 1]], width=[2, 2, 1])



2
tensor(-1.2142, device='cuda:0') tensor(1.2078, device='cuda:0')
BetterKAN 5


train loss: 5.88e-04 | test loss: 6.16e-04 | reg: 0.00e+00 : 100%|██| 50/50 [00:35<00:00,  1.42it/s]


BetterKAN 10
torch.Size([1000, 2]) acts
torch.Size([1000, 2]) acts


train loss: 2.87e-04 | test loss: 3.06e-04 | reg: 0.00e+00 : 100%|██| 50/50 [00:15<00:00,  3.19it/s]


BetterKAN 20
torch.Size([1000, 2]) acts
torch.Size([1000, 2]) acts


train loss: 2.72e-04 | test loss: 2.70e-04 | reg: 0.00e+00 : 100%|██| 50/50 [00:04<00:00, 10.32it/s]


BetterKAN 50
torch.Size([1000, 2]) acts
torch.Size([1000, 2]) acts


train loss: 1.08e-04 | test loss: 1.06e-04 | reg: 0.00e+00 : 100%|██| 50/50 [00:08<00:00,  5.66it/s]
100%|██████████| 50/50 [00:13<00:00,  3.60it/s]
 98%|█████████▊| 49/50 [00:24<00:00,  1.06it/s]