# Example 2: Speeding up

Two major concerns about KANs are their slow running speed and high memory usage. Both stem mainly from the naive implementation in the first version. The new release includes a few efficiency updates:

* We updated the spline evaluation method, inspired by the efficientKAN repo.
* We provide a method to speed up training: simply call `model = model.speed()`. In speed mode, the symbolic front is skipped (which saves computation time) and intermediate activations are not saved (which saves memory).
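
The idea behind speed mode can be sketched with a toy class. This is only an illustration of the pattern described above (switch off an optional branch, stop caching activations, return the model itself); `ToyModel` and its attribute names are hypothetical and are not taken from the library's source.

```python
class ToyModel:
    """Toy stand-in for a model with an optional branch and activation caching.

    Hypothetical: the class and attribute names are illustrative only.
    """

    def __init__(self):
        self.symbolic_enabled = True  # run the extra (symbolic) branch
        self.save_act = True          # cache intermediate activations
        self.acts = []

    def speed(self):
        # Switch off the optional work and return self, so that
        # `model = model.speed()` reads naturally.
        self.symbolic_enabled = False
        self.save_act = False
        return self

    def forward(self, x):
        self.acts = []
        for scale in (2.0, 0.5, 3.0):  # three toy "layers"
            y = scale * x              # main branch
            if self.symbolic_enabled:
                y = y + 0.0 * x        # placeholder for the symbolic branch
            if self.save_act:
                self.acts.append(y)    # extra memory in normal mode
            x = y
        return x


normal = ToyModel()
fast = ToyModel().speed()
out_normal = normal.forward(1.0)
out_fast = fast.forward(1.0)
print(out_normal, out_fast, len(normal.acts), len(fast.acts))
```

In this sketch both models compute the same output; the only difference is that the fast one does less work per layer and leaves `acts` empty.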

### Comparing the normal mode and the speed mode

The normal mode, without the speed-up

In [1]:
from kan import *

seed = 1
torch.manual_seed(seed)

# initialize KAN with G=3
model = KAN(width=[2,30,30,1], grid=3, k=3, seed=1)

# create dataset
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)
model.fit(dataset, opt="Adam", steps=10, update_grid=False);

Directory already exists: ./model


train loss: 1.45e+01 | test loss: 2.91e+01 | reg: 6.95e+03 : 100%|██| 10/10 [00:06<00:00,  1.66it/s]


The speed mode

In [2]:
from kan import *

seed = 1
torch.manual_seed(seed)

# initialize KAN with G=3
model = KAN(width=[2,30,30,1], grid=3, k=3, seed=1)
model = model.speed()

# create dataset
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)
model.fit(dataset, opt="Adam", steps=10, update_grid=False);

Directory already exists: ./model


train loss: 1.53e+01 | test loss: 2.52e+01 | reg: 0.00e+00 : 100%|██| 10/10 [00:00<00:00, 10.03it/s]
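
The two progress lines above give a rough estimate of the speed-up. The sketch below extracts the iteration rates that tqdm printed and takes their ratio. The helper name `iterations_per_second` is made up for this illustration, and with only 10 steps per run the figure is indicative, not a benchmark.

```python
import re

# Progress lines copied from the two runs above.
normal_log = "train loss: 1.45e+01 | test loss: 2.91e+01 | reg: 6.95e+03 : 100%|██| 10/10 [00:06<00:00,  1.66it/s]"
speed_log = "train loss: 1.53e+01 | test loss: 2.52e+01 | reg: 0.00e+00 : 100%|██| 10/10 [00:00<00:00, 10.03it/s]"


def iterations_per_second(line):
    """Return the it/s rate reported at the end of a tqdm progress line."""
    match = re.search(r"([\d.]+)it/s", line)
    if match is None:
        raise ValueError("no it/s rate found in line")
    return float(match.group(1))


speedup = iterations_per_second(speed_log) / iterations_per_second(normal_log)
print(f"speed-up: {speedup:.1f}x")  # prints "speed-up: 6.0x"
```

On this run, speed mode is therefore roughly six times faster per training step.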


However, the speed mode does not save intermediate activations, so `model.acts_scale` is empty after training.

In [3]:
model.acts_scale

[]

The activations can still be recovered afterwards by re-enabling saving and running the dataset through the model once:

In [4]:
model.save_plot_data = True
model.get_act(dataset)
model.acts_scale

[tensor([[ 7.6856, 16.2691],
         [15.3356,  4.7036],
         [ 1.7350,  0.5762],
         [13.7567,  0.6547],
         [ 2.0027,  1.1051],
         [ 1.3028, 28.8967],
         [ 3.1434,  0.9702],
         [20.5310,  0.9579],
         [10.5522,  3.1435],
         [ 1.8792,  2.7931],
         [27.2825,  5.3805],
         [ 5.1843,  0.9814],
         [ 5.2057,  5.5816],
         [ 9.4054,  0.3300],
         [ 8.7830,  1.9473],
         [ 1.6670,  5.7267],
         [ 6.4220, 28.1405],
         [18.4623, 10.7246],
         [33.4021,  0.9486],
         [ 1.6931,  3.2448],
         [ 7.3774, 16.6039],
         [ 1.5200,  1.4180],
         [ 6.8054,  0.9019],
         [ 1.9008,  1.0136],
         [ 0.4024,  4.3498],
         [ 4.3653,  5.3940],
         [ 4.2989,  3.5657],
         [ 4.5522,  0.4301],
         [ 0.5495,  0.2769],
         [24.9288,  4.5964]]),
 tensor([[4.4354e+00, 5.6637e-02, 4.2377e+00, 8.8434e-01, 3.7175e+00, 8.9901e-01,
          2.0012e+00, 7.9254e-01, 1.5717e+00, 