
FastKAN approximates KANs with RBFs achieving 3+ times acceleration #141

Closed
ZiyaoLi opened this issue May 9, 2024 · 22 comments

ZiyaoLi commented May 9, 2024

Since 3rd-order (cubic) B-splines are the most commonly used, their basis functions can be well approximated by Gaussian RBFs.

LayerNorm on the inputs keeps them within the RBF grid range, avoiding grid re-scaling.

These two together yield a much faster implementation (approximation) of KAN: FastKAN. See here: https://github.com/ZiyaoLi/fast-kan
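A minimal sketch of the construction described above, with assumed layer sizes and parameter names (an illustration, not the actual fast-kan code): normalize inputs with LayerNorm, expand each feature onto a fixed grid of Gaussian RBFs, then mix everything with one linear layer.

```python
import torch
import torch.nn as nn

class GaussianRBFKANLayer(nn.Module):
    """Sketch of an RBF-based KAN layer: LayerNorm -> Gaussian RBF basis -> Linear."""

    def __init__(self, in_dim, out_dim, grid_min=-2.0, grid_max=2.0, num_grids=8):
        super().__init__()
        self.norm = nn.LayerNorm(in_dim)  # keeps inputs inside the fixed grid range
        self.register_buffer("grid", torch.linspace(grid_min, grid_max, num_grids))
        self.h = (grid_max - grid_min) / (num_grids - 1)  # RBF width
        # one basis response per (input feature, grid point), mixed into out_dim outputs
        self.linear = nn.Linear(in_dim * num_grids, out_dim)

    def forward(self, x):
        x = self.norm(x)
        # Gaussian RBF basis: exp(-((x_i - g_j) / h)^2) for every feature i and grid point j
        basis = torch.exp(-((x[..., None] - self.grid) / self.h) ** 2)
        return self.linear(basis.flatten(start_dim=-2))

# usage: out = GaussianRBFKANLayer(28 * 28, 64)(torch.randn(32, 28 * 28))
```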


ZiyaoLi commented May 10, 2024

FastKAN's forward pass is now 3+ times faster than efficient_kan. I believe you'll all want to try it.

https://github.com/ZiyaoLi/fast-kan


1ssb commented May 10, 2024

Very interesting. What are your benchmark and validation performance numbers?


ZiyaoLi commented May 10, 2024

> Very interesting. What are your benchmark and validation performance numbers?

Results are shown in the linked repo. I tested the forward speed of my FastKANLayer against efficient_kan's KANLinear: 740 µs -> 220 µs.
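For reference, a rough sketch of how such a forward-pass timing comparison can be made (the layer classes, sizes, and batch shape below are assumptions, not the exact setup behind the 740 µs -> 220 µs numbers):

```python
import time
import torch

def time_forward(layer, x, warmup=10, iters=100):
    """Average forward-pass wall time in microseconds (CUDA-synchronized)."""
    for _ in range(warmup):
        layer(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        layer(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1e6

# usage (assumed layer classes and sizes):
# x = torch.randn(64, 784, device="cuda")
# print(time_forward(FastKANLayer(784, 64).cuda(), x))
# print(time_forward(KANLinear(784, 64).cuda(), x))
```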


AthanasiosDelis commented May 10, 2024

I think it is even faster:

https://github.com/AthanasiosDelis/fast-kan-playground

I ran benchmarks similar to https://github.com/Jerry-Master/KAN-benchmarking for uniformity of comparison.

For me, the most important thing is to test whether pykan really has the continual-learning capabilities it promises, and whether FastKAN inherits those capabilities along with the ability to do pruning and symbolic regression.


1ssb commented May 10, 2024

My version gets around:

Forward pass took 0.000297 to 0.0016 seconds (best to worst),

while hitting ~98% accuracy on MNIST. What is the time-vs-accuracy comparison for FastKAN?

Edit: training dynamics are expected to vary with the loss function, data dimensionality, and a host of other factors. What exactly is the benchmark?


AthanasiosDelis commented May 10, 2024

I inspected your GitHub and matched my hyperparameters to yours, @1ssb: lr=1e-3, weight_decay=1e-5, gamma=0.85.

Results after 15 epochs:

With FastKAN([28 * 28, 64, 10], grid_min=-3., grid_max=3., num_grids=4, exponent=2, denominator=1.7):

Total parameters: 255858
Trainable parameters: 255850

100%|█| 938/938 [00:16<00:00, 58.10it/s, accuracy=0.969, loss=0.045, lr=0.0
Epoch 15, Val Loss: 0.07097885620257162, Val Accuracy: 0.9798964968152867

With MLP(layers=[28 * 28, 320, 10], device='cuda'):

Total parameters: 254410
Trainable parameters: 254410

100%|█| 938/938 [00:15<00:00, 59.52it/s, accuracy=0.969, loss=1.47, lr=0.00
Epoch 15, Val Loss: 1.4862790791092404, Val Accuracy: 0.9756170382165605
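A rough sketch of a training setup consistent with the numbers above (AdamW plus an exponential LR schedule with gamma=0.85 is an assumption here, as is the fastkan import path; the actual script may differ):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from fastkan import FastKAN  # assumed import path for the fast-kan package

model = FastKAN([28 * 28, 64, 10]).cuda()  # layer widths as in the results above
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.85)
criterion = nn.CrossEntropyLoss()

train_loader = DataLoader(
    datasets.MNIST("data", train=True, download=True, transform=transforms.ToTensor()),
    batch_size=64, shuffle=True,
)

for epoch in range(15):
    for images, labels in train_loader:
        images = images.view(images.size(0), -1).cuda()  # flatten 28x28 -> 784
        loss = criterion(model(images), labels.cuda())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()
```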

Results from comparing these networks on a dataset of comparable scale to MNIST, generated with create_dataset:

              forward    backward   fwd memory   bwd memory   num params   num trainable params
fastkan-gpu   0.83 ms    1.27 ms    0.02 GB      0.02 GB      255858       255850
mlp-gpu       0.25 ms    0.62 ms    0.02 GB      0.02 GB      254410       254410
effkan-gpu    2.39 ms    2.18 ms    0.03 GB      0.03 GB      508160       508160

Accuracy is comparable, the network width is smaller, FastKAN remains slower than the MLP, and memory use is comparable (I have not tested the original FastKAN yet).
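A sketch of how the forward/backward times and peak memory in such a table can be measured (the helper below is illustrative; the KAN-benchmarking repo's actual procedure may differ):

```python
import time
import torch

def benchmark(model, x, target, iters=100):
    """Return average forward/backward time (ms) and peak CUDA memory (GB)."""
    loss_fn = torch.nn.MSELoss()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    fwd = bwd = 0.0
    for _ in range(iters):
        t0 = time.perf_counter()
        out = model(x)
        torch.cuda.synchronize()
        t1 = time.perf_counter()
        loss_fn(out, target).backward()
        torch.cuda.synchronize()
        t2 = time.perf_counter()
        fwd += t1 - t0
        bwd += t2 - t1
        model.zero_grad(set_to_none=True)
    peak_gb = torch.cuda.max_memory_allocated() / 1024**3
    return fwd / iters * 1e3, bwd / iters * 1e3, peak_gb

# usage (assumed model and shapes):
# model = FastKAN([784, 64, 1]).cuda()
# print(benchmark(model, torch.randn(64, 784, device="cuda"),
#                 torch.randn(64, 1, device="cuda")))
```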

Still, the big questions are how to adapt FastKAN to perform symbolic regression, how to test for continual learning, and what the relationship is between the RBF parameters and the B-spline grid parameters. Are you currently working on any of those three, @ZiyaoLi?


1ssb commented May 10, 2024

Thanks. Kindly take a look at KAL Net, which I have just released, if you get time.

@AthanasiosDelis

[image: benchmark comparison results]
@1ssb


1ssb commented May 10, 2024

Thanks a lot. I think it indicates that my model is, as expected, a bit bulkier because of the recursive stacking.

@AthanasiosDelis

I also put the original FastKAN into the comparison. For some reason, I cannot easily reduce the trainable parameters of efficient-kan using only grid_size and spline_order.

[image: updated benchmark table including the original FastKAN]


1ssb commented May 10, 2024

I think it's a bit incomplete not to have a measure of expressivity or task performance. Does FastKAN outperform overall?

@AthanasiosDelis

I think so. FastKAN-like implementations that use RBF approximations are the fastest. I am aware of three implementations so far, all of which are more or less extremely similar:

RBF-KAN
fast-kan-playground
fast-kan

Later tonight, I will also compare against RBF-KAN.

@LiZhenzhuBlog

wonderful


ZiyaoLi commented May 11, 2024

> I think so. FastKAN-like implementations that use RBF approximations are the fastest. I am aware of three implementations so far, all of which are more or less extremely similar:
>
> RBF-KAN, fast-kan-playground, fast-kan
>
> Later tonight, I will also compare against RBF-KAN.

This RBF-KAN is simply a copy of my FastKAN code without acknowledgement, even with the same variable names.


ZiyaoLi commented May 11, 2024


> Still, the big questions are how to adapt FastKAN to perform symbolic regression, how to test for continual learning, and what the relationship is between the RBF parameters and the B-spline grid parameters. Are you currently working on any of those three, @ZiyaoLi?

Not exactly. What FastKAN shows is that KANs are essentially RBF networks. If you check the history of RBF networks, you'll see they have been widely studied. Efficiency wouldn't be the most important problem; the question now is whether KANs are really that good.


1ssb commented May 11, 2024 via email


ZiyaoLi commented May 11, 2024

> I don't think that always holds: that KANs are RBFs. If you approximate splines with RBFs that may be true, but I do not agree with this blanket generalisation. This is why I think one should inquire deeper into other approximations; if those start exhibiting a variety of properties, that just proves that KANs are dominated by the approximating basis functions, which is intuitive. First, can you clearly outline why you think KANs are essentially RBFs? From the history of RBFs, we know they are neither very scalable nor as expressive as their affine counterparts.

@1ssb This would be an interesting discussion that is far beyond this issue lol.

The claim indeed isn't rigorous. My claim that KANs are RBFs should be narrowed to: "3rd-order B-spline KANs as implemented in pykan are very much the same as FastKAN, which is a univariate RBF network." This is because the 3rd-order B-spline basis can be numerically approximated by univariate Gaussian RBFs, as you've concluded.
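A small numerical check of that approximation claim (the amplitude and width below come from a brute-force least-squares search in this sketch, not from the fast-kan code):

```python
import numpy as np

def cubic_bspline(x):
    """Centered uniform cubic B-spline basis function, supported on [-2, 2]."""
    ax = np.abs(x)
    return np.where(
        ax < 1, 2 / 3 - ax**2 + ax**3 / 2,
        np.where(ax < 2, (2 - ax) ** 3 / 6, 0.0),
    )

x = np.linspace(-3, 3, 1001)
target = cubic_bspline(x)

# brute-force fit of a * exp(-(x / h)^2) to the cubic B-spline basis
err, a, h = min(
    (np.max(np.abs(a * np.exp(-(x / h) ** 2) - target)), a, h)
    for a in np.linspace(0.5, 0.8, 61)
    for h in np.linspace(0.7, 1.3, 61)
)
print(f"max |B3(x) - a*exp(-(x/h)^2)| ~ {err:.4f} with a ~ {a:.3f}, h ~ {h:.3f}")
# the maximum deviation is small (on the order of 1e-2), supporting the approximation
```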

ZiyaoLi changed the title from "A fast approximation of KAN wrt BSpline & grid scaling" to "FastKAN approximates KANs with RBFs achieving 3+ times acceleration" on May 11, 2024.

1ssb commented May 11, 2024 via email


ZiyaoLi commented May 11, 2024

Haha relax, it's just about bridging KANs with the theories that you mentioned. It's not going to be something as important as MLPs or say Transformers anyway :p


1ssb commented May 11, 2024

I was really hopeful 😭

@AthanasiosDelis

> I think so. FastKAN-like implementations that use RBF approximations are the fastest. I am aware of three implementations so far, all of which are more or less extremely similar:
> RBF-KAN, fast-kan-playground, fast-kan
> Later tonight, I will also compare against RBF-KAN.

> This RBF-KAN is simply a copy of my FastKAN code without acknowledgement, even with the same variable names.

Yeah, it was my sad realisation too.

@AthanasiosDelis

I have updates. I switched from the Gaussian RBF to the RSWAF approximation function:
[image: RSWAF basis function definition]
I also zeroed out the SiLU part. The results are even faster, and MNIST still yields 97.7% accuracy:
[images: updated benchmark and accuracy results]
Now that the mathematics has diverged significantly from FastKAN's implementation, I thought it proper to rename it FasterKAN, but I still keep the original references because it is based on @ZiyaoLi's code base.
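For context, a sketch of swapping the Gaussian basis for a tanh-based switch basis of the kind described above (the exact RSWAF parametrization in FasterKAN may differ; the 1 - tanh²((x - g)/h) form here is an assumption):

```python
import torch
import torch.nn as nn

class RSWAFBasis(nn.Module):
    """Sketch: replace exp(-((x - g)/h)^2) with 1 - tanh((x - g)/h)^2.
    Both are bell-shaped around each grid point g, but tanh is cheap and its
    derivative reuses the tanh value, which helps the backward pass."""

    def __init__(self, grid_min=-2.0, grid_max=2.0, num_grids=8):
        super().__init__()
        self.register_buffer("grid", torch.linspace(grid_min, grid_max, num_grids))
        self.h = (grid_max - grid_min) / (num_grids - 1)

    def forward(self, x):
        t = torch.tanh((x[..., None] - self.grid) / self.h)
        return 1.0 - t * t  # shape (..., in_dim, num_grids)
```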
