
GistNoesis/FusedFourierKAN


FusedFourierKAN

C++ and CUDA ops for fused FourierKAN. See https://github.com/GistNoesis/FourierKAN for the naive version, which explains what this is about (be aware that the order of dimensions may differ and is subject to change).

LICENSE

The code is proprietary and non-commercial, for research purposes only. Contact us at gistnoesis@gmail.com for commercial licenses. See the LICENSE file for additional disclaimers.

An independent, Apache-licensed implementation of FusedFourierKAN (faster and simpler to install) can be found here: https://github.com/Jerry-Master/KAN-benchmarking/tree/master/extra/kernel

What is this about

Writing a custom op avoids materializing the intermediate cos(kx) and sin(kx) values in memory, so no extra memory is needed beyond the inputs, parameters and output. It also allows a trigonometric trick to compute cos(kx) and sin(kx) more efficiently.
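In formulas, writing b for bias, G for gridsize, A_{j,l,k} for coeff[0, j, l, k-1] and B_{j,l,k} for coeff[1, j, l, k-1] (the coefficient layout used in the core code), the layer computes the first line below. A naive implementation typically materializes at least cos(k x_{i,j}) and sin(k x_{i,j}) for every i, j, k; the fused op computes them on the fly. The trick is the angle-addition recurrence in the last two lines: starting from cos(0x) = 1 and sin(0x) = 0, each further harmonic costs four multiplications and two additions, so only one cosf and one sinf call are needed per input value. (In the code, c0 and s0 hold cos x and sin x, while ckm and skm hold the previous harmonic.)

```latex
\begin{aligned}
\mathrm{out}_{i,l} &= b_l + \sum_{j=1}^{\mathrm{inputdim}} \sum_{k=1}^{G}
  \Big( A_{j,l,k} \cos(k\,x_{i,j}) + B_{j,l,k} \sin(k\,x_{i,j}) \Big), \\
\cos(kx) &= \cos\big((k-1)x\big)\cos x - \sin\big((k-1)x\big)\sin x, \\
\sin(kx) &= \sin\big((k-1)x\big)\cos x + \cos\big((k-1)x\big)\sin x .
\end{aligned}
```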

The core is quite simple:

#include <math.h> // cosf, sinf

// Fused FourierKAN forward pass on the CPU.
// x: (bs, inputdim), coeff: (2, inputdim, outputdim, gridsize), bias: (outputdim), out: (bs, outputdim).
// out is only accumulated into, so the caller must zero-initialize it.
void ffkan( float* x, float* coeff, float* bias, int bs, int inputdim, int outputdim, int gridsize, float* out )
{
    const int s_bs_out = outputdim;
    const int s_bs_x = inputdim;
    //Coeff shape (2,inputdim,outputdim,gridsize)
    const int s_d_coeff = inputdim*outputdim*gridsize;
    const int s_i_coeff = outputdim*gridsize;
    const int s_o_coeff = gridsize;
    for( int i = 0 ; i < bs ; i++)
        for( int j = 0 ; j < inputdim ; j++)
        {
            float xx = x[i*s_bs_x+j];
            float c0 = cosf(xx);
            float s0 = sinf(xx);
            for( int l = 0 ; l < outputdim ; l++)
            {
                float ckm = 1.0f; // cos(0*xx)
                float skm = 0.0f; // sin(0*xx)
                for( int k = 1 ; k < gridsize+1 ; k++)
                {
                    // For better performance, we use the angle-addition formulas to compute
                    // c_k, s_k from c_{k-1}, s_{k-1}, cos(xx) and sin(xx).
                    // The direct form below is easier to differentiate for the backward pass:
                    //float c = cos(k*xx);
                    //float s = sin(k*xx);
                    float c = ckm*c0-skm*s0;
                    float s = skm*c0+ckm*s0;
                    ckm = c;
                    skm = s;
                    out[i*s_bs_out+l] += coeff[s_d_coeff*0 + s_i_coeff*j + s_o_coeff*l + k-1] * c;
                    out[i*s_bs_out+l] += coeff[s_d_coeff*1 + s_i_coeff*j + s_o_coeff*l + k-1] * s;
                }
            }
        }
    for( int i = 0 ; i < bs ; i++)
        for( int l = 0 ; l < outputdim ; l++)
            out[i*s_bs_out+l] += bias[l];
}
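To check that the recurrence computes the same thing as the direct cos(k*xx) / sin(k*xx) form kept in the comments, here is a small self-contained sketch. It is our own illustration, not code from the repository: ffkan_recurrence and ffkan_direct are hypothetical names, and std::vector replaces the raw pointers.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helpers (not from the repository). Same layout as ffkan:
// x: (bs, inputdim), coeff: (2, inputdim, outputdim, gridsize), bias: (outputdim).

// Harmonics via the angle-addition recurrence: one cos and one sin call per x value.
std::vector<float> ffkan_recurrence(const std::vector<float>& x, const std::vector<float>& coeff,
                                    const std::vector<float>& bias, int bs, int inputdim,
                                    int outputdim, int gridsize) {
    const int d = inputdim * outputdim * gridsize;  // offset of the sin half of coeff
    std::vector<float> out(bs * outputdim);
    for (int i = 0; i < bs; ++i)
        for (int l = 0; l < outputdim; ++l) out[i * outputdim + l] = bias[l];
    for (int i = 0; i < bs; ++i)
        for (int j = 0; j < inputdim; ++j) {
            const float c1 = std::cos(x[i * inputdim + j]);
            const float s1 = std::sin(x[i * inputdim + j]);
            for (int l = 0; l < outputdim; ++l) {
                float c = 1.0f, s = 0.0f;  // cos(0x), sin(0x)
                for (int k = 1; k <= gridsize; ++k) {
                    const float cn = c * c1 - s * s1;  // cos(kx)
                    s = s * c1 + c * s1;               // sin(kx), uses the previous c
                    c = cn;
                    const int idx = (j * outputdim + l) * gridsize + k - 1;
                    out[i * outputdim + l] += coeff[idx] * c + coeff[d + idx] * s;
                }
            }
        }
    return out;
}

// Reference: calls cos(k*x) and sin(k*x) directly for every harmonic.
std::vector<float> ffkan_direct(const std::vector<float>& x, const std::vector<float>& coeff,
                                const std::vector<float>& bias, int bs, int inputdim,
                                int outputdim, int gridsize) {
    const int d = inputdim * outputdim * gridsize;
    std::vector<float> out(bs * outputdim);
    for (int i = 0; i < bs; ++i)
        for (int l = 0; l < outputdim; ++l) {
            float acc = bias[l];
            for (int j = 0; j < inputdim; ++j)
                for (int k = 1; k <= gridsize; ++k) {
                    const float kx = k * x[i * inputdim + j];
                    const int idx = (j * outputdim + l) * gridsize + k - 1;
                    acc += coeff[idx] * std::cos(kx) + coeff[d + idx] * std::sin(kx);
                }
            out[i * outputdim + l] = acc;
        }
    return out;
}
```

Both functions write bias into the output first instead of relying on a zeroed buffer. The recurrence accumulates a little rounding error that grows with k, so the two agree closely but not bit-for-bit.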

We had to write the forward and backward ops for both CPU and GPU, plus a wrapper to make them available to PyTorch.

In ffKANFunction.py and ffKANGPUFunction.py, we verify that the outputs of the functions and their gradients approximately match those of the target implementation.
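The gradients the backward op has to produce follow directly from the forward formula: d out[i][l] / d coeff[0][j][l][k-1] = cos(k x[i][j]), the sin half likewise, d out / d bias = 1, and d/dx (A cos(kx) + B sin(kx)) = k (B cos(kx) - A sin(kx)). Below is a minimal CPU sketch of that backward pass, assuming the same layout as the core code and using the direct form for clarity. It is not the repository's backward kernel, and all names are hypothetical. It ends with a central-difference check in the spirit of the verification described above (the repository itself compares against a target implementation in Python).

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Forward pass (direct form). Layout: x (bs, inputdim),
// coeff (2, inputdim, outputdim, gridsize), bias (outputdim), out (bs, outputdim).
std::vector<float> ffkan_forward(const std::vector<float>& x, const std::vector<float>& coeff,
                                 const std::vector<float>& bias, int bs, int inputdim,
                                 int outputdim, int gridsize) {
    const int d = inputdim * outputdim * gridsize;
    std::vector<float> out(bs * outputdim);
    for (int i = 0; i < bs; ++i)
        for (int l = 0; l < outputdim; ++l) {
            float acc = bias[l];
            for (int j = 0; j < inputdim; ++j)
                for (int k = 1; k <= gridsize; ++k) {
                    const float kx = k * x[i * inputdim + j];
                    const int idx = (j * outputdim + l) * gridsize + k - 1;
                    acc += coeff[idx] * std::cos(kx) + coeff[d + idx] * std::sin(kx);
                }
            out[i * outputdim + l] = acc;
        }
    return out;
}

// Backward pass: given grad_out = dL/dout (bs, outputdim), fills dL/dx, dL/dcoeff, dL/dbias.
void ffkan_backward(const std::vector<float>& x, const std::vector<float>& coeff,
                    const std::vector<float>& grad_out, int bs, int inputdim, int outputdim,
                    int gridsize, std::vector<float>& grad_x, std::vector<float>& grad_coeff,
                    std::vector<float>& grad_bias) {
    const int d = inputdim * outputdim * gridsize;
    grad_x.assign(bs * inputdim, 0.0f);
    grad_coeff.assign(2 * d, 0.0f);
    grad_bias.assign(outputdim, 0.0f);
    for (int i = 0; i < bs; ++i)
        for (int l = 0; l < outputdim; ++l) {
            const float g = grad_out[i * outputdim + l];
            grad_bias[l] += g;  // d out / d bias = 1
            for (int j = 0; j < inputdim; ++j) {
                const float xx = x[i * inputdim + j];
                const int base = (j * outputdim + l) * gridsize;
                for (int k = 1; k <= gridsize; ++k) {
                    const float c = std::cos(k * xx), s = std::sin(k * xx);
                    grad_coeff[base + k - 1] += g * c;      // cos half
                    grad_coeff[d + base + k - 1] += g * s;  // sin half
                    // d/dx [A cos(kx) + B sin(kx)] = k (B cos(kx) - A sin(kx))
                    grad_x[i * inputdim + j] +=
                        g * k * (coeff[d + base + k - 1] * c - coeff[base + k - 1] * s);
                }
            }
        }
}

// Largest |analytic - numeric| over dL/dx for L = sum(grad_out * out),
// so that dL/dout == grad_out. Uses central differences with step h.
double max_x_grad_error(const std::vector<float>& x, const std::vector<float>& coeff,
                        const std::vector<float>& bias, const std::vector<float>& grad_out,
                        int bs, int inputdim, int outputdim, int gridsize, float h) {
    std::vector<float> gx, gc, gb;
    ffkan_backward(x, coeff, grad_out, bs, inputdim, outputdim, gridsize, gx, gc, gb);
    auto loss = [&](const std::vector<float>& xx) {
        const std::vector<float> out = ffkan_forward(xx, coeff, bias, bs, inputdim, outputdim, gridsize);
        double acc = 0.0;
        for (std::size_t n = 0; n < out.size(); ++n) acc += double(grad_out[n]) * out[n];
        return acc;
    };
    double worst = 0.0;
    for (std::size_t n = 0; n < x.size(); ++n) {
        std::vector<float> xp = x, xm = x;
        xp[n] += h;
        xm[n] -= h;
        // Divide by the actual float step, not 2h, to avoid rounding bias.
        const double numeric = (loss(xp) - loss(xm)) / (double(xp[n]) - double(xm[n]));
        worst = std::max(worst, std::fabs(numeric - gx[n]));
    }
    return worst;
}
```

A real kernel would reuse the recurrence here too; note that the gradient with respect to x needs both cos(kx) and sin(kx) of every harmonic, which the recurrence provides together anyway.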

The GPU version is not yet optimized (in particular, memory accesses are not yet coalesced or cached), but it runs in parallel and is deterministic.
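The README does not show the CUDA kernel, so the following is only a plain C++ sketch of one common way to get parallelism and determinism together, assuming each output element out[i][l] is owned by a single thread that sums its terms in a fixed order. ffkan_output and ffkan_forward_scheduled are hypothetical names, and running the work items in an arbitrary order stands in for an arbitrary GPU thread schedule.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Computes one output element out[i][l]. In a one-thread-per-output kernel this
// would be the whole body of a thread: it only reads x, coeff and bias, and writes
// nothing but its own result, so no atomics are needed.
float ffkan_output(const std::vector<float>& x, const std::vector<float>& coeff,
                   const std::vector<float>& bias, int inputdim, int outputdim,
                   int gridsize, int i, int l) {
    const int d = inputdim * outputdim * gridsize;  // offset of the sin half of coeff
    float acc = bias[l];
    for (int j = 0; j < inputdim; ++j) {  // fixed order: j ascending, then k ascending
        const float xx = x[i * inputdim + j];
        const float c1 = std::cos(xx), s1 = std::sin(xx);
        float c = 1.0f, s = 0.0f;  // cos(0x), sin(0x)
        for (int k = 1; k <= gridsize; ++k) {
            const float cn = c * c1 - s * s1;  // cos(kx)
            s = s * c1 + c * s1;               // sin(kx), uses the previous c
            c = cn;
            const int idx = (j * outputdim + l) * gridsize + k - 1;
            acc += coeff[idx] * c + coeff[d + idx] * s;
        }
    }
    return acc;
}

// Runs the independent work items in the given order; order lists flat
// indices t = i * outputdim + l, standing in for an arbitrary thread schedule.
std::vector<float> ffkan_forward_scheduled(const std::vector<float>& x, const std::vector<float>& coeff,
                                           const std::vector<float>& bias, int bs, int inputdim,
                                           int outputdim, int gridsize, const std::vector<int>& order) {
    std::vector<float> out(bs * outputdim, 0.0f);
    for (int t : order) {
        const int i = t / outputdim, l = t % outputdim;
        out[t] = ffkan_output(x, coeff, bias, inputdim, outputdim, gridsize, i, l);
    }
    return out;
}
```

Because every out[i][l] is written by exactly one work item and its terms are always summed in the same order, the result is bit-for-bit independent of the schedule. Designs that instead split the j or k sums across threads and combine partial sums with atomicAdd generally lose this property, since floating-point addition is not associative.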

The project structure and CMakeLists.txt keep compilation times short, which allows rapid iteration.

INSTALL

Sorry, it'll probably be painful. It has only been tested on Linux (although some users got it to work on Windows with minor modifications to the loading path; see #3).

It still has rough edges, but the happy path is the following:

git clone https://github.com/GistNoesis/FusedFourierKAN.git
cd FusedFourierKAN
cd build
cmake ..
make
cd ..
pip install -e .

For it to install properly, you need a working nvcc (ideally of the same CUDA version as your PyTorch build). If cmake doesn't find Torch, set Torch_DIR to the directory containing TorchConfig.cmake (you can find it with locate); then cmake .. should work:

locate TorchConfig.cmake
export Torch_DIR=<folder containing TorchConfig.cmake>
cmake ..

Some users have encountered build errors with recent g++ versions; using g++-10 was suggested (see #1).

My current versions, with which it compiles and runs fine, are:

g++ --version
g++ (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Mon_Apr__3_17:16:06_PDT_2023
Cuda compilation tools, release 12.1, V12.1.105
Build cuda_12.1.r12.1/compiler.32688072_0
>>> import torch as th
>>> th.__version__
'2.2.1+cu121'

Updating

Updating is done with:

git pull
cd build
cmake ..    # optional
make

If the Python library was installed in editable mode (pip install -e .), it should then work without rerunning pip install -e .

The soft convention I'll try to follow for the Python library's "version" number: it won't be bumped every time there is a new "transparent" optimization that computes the same thing faster with the same results, but it will be bumped for significant changes, such as a change in the order of dimensions.

USAGE

Once the install has succeeded, usage should be smooth:

from FusedFourierKAN.FusedFourierKANLayer import FusedFourierKANLayer

You can also call the demo function:

from FusedFourierKAN.FusedFourierKANLayer import demo
demo()

Benchmark and performance

A user has done an independent benchmark (see #4).

I've pushed some performance optimizations since then.

An official benchmark and more optimizations are coming soon.
