In [None]:
from kan import *
import torch
from kan.utils import create_dataset
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

### CREATE A KAN

In [None]:
# Layer widths [2, 5, 1]: 2 inputs (matching the 2-variable dataset), 5 hidden nodes, 1 output.
# grid=3 and k=3 configure the spline activations; seed=2 makes the initialization repeatable.
kan = KAN(
    [2, 5, 1],
    grid=3,
    k=3,
    seed=2,
    device=device,
)

### CREATE A DATASET FROM THE FUNCTION DEFINED BELOW

In [None]:
# Create the dataset.
# Target: f(x, y) = exp(sin(pi * x) + y^2) + Gaussian noise with std NOISE_STD.
# This must be sin (not cos) to match the stated target and the 'sin' used in the
# manual symbolic mode further down.
# x[:, [0]] / x[:, [1]] index with a list, which keeps a trailing column dimension,
# so f returns one label column per sample.
NOISE_STD = 0.1
f = lambda x: (
    torch.exp(torch.sin(torch.pi * x[:, [0]]) + x[:, [1]] ** 2)
    + torch.randn_like(x[:, [1]]) * NOISE_STD
)
dataset = create_dataset(f, n_var=2, device=device, train_num=1000, test_num=800)

train = dataset['train_input']
train_target = dataset['train_label']
print(train.shape)
print(train_target.shape)

### RUN THE DATA THROUGH THE KAN AND PLOT THE INITIAL MODEL


In [None]:
# Feed the training inputs through the model once before drawing it,
# as the section heading describes.
train_inputs = dataset['train_input']
kan(train_inputs)
kan.plot()

### TRAIN WITH L1 REGULARIZATION FOR LATER SPARSIFICATION
#### L1 norm of an activation function $\phi$ over its $N_p$ inputs: $$|\phi|_1 = \frac{1}{N_p} \sum_{s=1}^{N_p} \left|\phi(x_s)\right|$$
#### For a KAN layer $\Phi$ with $n_{in}$ inputs and $n_{out}$ outputs, the L1 norm of the layer is the sum of the L1 norms of all its activation functions: $$|\Phi|_1 = \sum_{i=1}^{n_{in}} \sum_{j=1}^{n_{out}} |\phi_{i,j}|_1$$
#### Entropy of a KAN layer $\Phi$: $$S(\Phi) = -\sum_{i=1}^{n_{in}} \sum_{j=1}^{n_{out}} \frac{|\phi_{i,j}|_1}{|\Phi|_1} \log\left(\frac{|\phi_{i,j}|_1}{|\Phi|_1}\right)$$

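#### Total loss of a KAN with $L$ layers: $$l_{total} = l_{pred} + \lambda \left(\lambda_1 \sum_{l=0}^{L-1} |\Phi_l|_1 + \lambda_2 \sum_{l=0}^{L-1} S(\Phi_l)\right)$$
#### In the `fit` call below, $\lambda$, $\lambda_1$ and $\lambda_2$ are `lamb`, `lamb_l1` and `lamb_entropy`.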

In [None]:
# Train with the sparsity regularizer from the total-loss formula above:
# lamb scales the whole penalty; lamb_l1 and lamb_entropy weight the L1 and entropy terms.
res = kan.fit(dataset, opt="LBFGS", steps=40, lamb=0.01, lamb_l1=10., lamb_entropy=10.)

# Label the figure so it stands on its own. Show the test curve next to the training curve,
# and use a log y-axis because the loss typically spans several orders of magnitude.
fig, ax = plt.subplots()
ax.plot(res['train_loss'], label='train')
ax.plot(res['test_loss'], label='test')
ax.set(title='Training with L1 + entropy regularization',
       xlabel='step', ylabel='loss', yscale='log')
ax.legend()
plt.show()
kan.plot()

### PRUNING (SPARSIFICATION)
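##### For the $i$-th neuron in layer $l$, define its incoming and outgoing scores (where $\phi_{l,j,i}$ connects neuron $i$ of layer $l$ to neuron $j$ of layer $l+1$):
##### $$ I_{l,i} = \max_k\left(|\phi_{l-1, i, k}|_1\right)$$
##### $$ O_{l,i} = \max_j\left(|\phi_{l, j, i}|_1\right)$$
##### A node is considered "important" if both scores are greater than a threshold $\theta$ (here $\theta = 0.01$); all other nodes are pruned.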

In [None]:
# Threshold theta from the pruning criterion above: nodes that do not have both
# incoming and outgoing scores above it are removed.
NODE_TH = 0.01
kan = kan.prune(node_th=NODE_TH)
kan.plot()

### CONTINUE TRAINING ON SPARSIFIED KAN (AND EXTEND GRID)

In [None]:
# Briefly retrain the pruned model, then refine its grid (it was built with grid=3)
# and train again on the finer grid.
FINE_GRID = 10
kan.fit(dataset, opt="LBFGS", steps=10)  # only the updated weights are needed here
kan = kan.refine(FINE_GRID)
res2 = kan.fit(dataset, opt="LBFGS", steps=10)

fig, ax = plt.subplots()
ax.plot(res2['train_loss'], label='train')
ax.plot(res2['test_loss'], label='test')
ax.set(title=f'Training after grid refinement (grid={FINE_GRID})',
       xlabel='step', ylabel='loss', yscale='log')
ax.legend()
plt.show()
kan.plot()

### SET ACTIVATION FUNCTIONS (SPLINES -> SYMBOLIC)

In [None]:
# Replace the learned spline activations with symbolic functions.
#   "manual": pin each edge to the function expected from the target
#             f(x, y) = exp(sin(pi*x) + y^2). The (layer, in, out) indices assume the
#             pruned graph kept both inputs feeding one hidden node (TODO confirm
#             from the pruned plot).
#   "auto":   let auto_symbolic choose a function for every edge from `lib`.
mode = "auto"  # or "manual"
from kan.utils import SYMBOLIC_LIB  # not used below

if mode == "manual":
    kan.fix_symbolic(0, 0, 0, 'sin')
    kan.fix_symbolic(0, 1, 0, 'x^2')
    kan.fix_symbolic(1, 0, 0, 'exp')
elif mode == "auto":
    lib = ['x', 'x^2', 'x^3', 'x^4', 'exp', 'log', 'sqrt', 'tanh', 'sin', 'abs']
    kan.auto_symbolic(lib=lib)
else:
    # Fail loudly: a typo in `mode` would otherwise silently skip this step and
    # leave the splines unconverted for the cells that follow.
    raise ValueError(f"mode must be 'manual' or 'auto', got {mode!r}")

### CONTINUE TRAINING 

In [None]:

# Train a few more steps now that the activations are symbolic.
res3 = kan.fit(dataset, opt="LBFGS", steps=15)

fig, ax = plt.subplots()
ax.plot(res3['train_loss'], label='train')
ax.plot(res3['test_loss'], label='test')
ax.set(title='Training after symbolification',
       xlabel='step', ylabel='loss', yscale='log')
ax.legend()
plt.show()

### OBTAIN THE SYMBOLIC FORMULA

In [None]:
from kan.utils import ex_round

# symbolic_formula() returns a nested result; [0][0] is presumably the expression for the
# first (and only) output. Verify this against the pykan docs.
# ex_round(..., 4) presumably rounds the expression's numeric constants to 4 digits.
first_formula = kan.symbolic_formula()[0][0]
ex_round(first_formula, 4)