In [2]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

  _warn(("h5py is running against HDF5 {0} when it was built against {1}, "


In [7]:
class KANLayer(tf.Module):
    def __init__(self, layerSize, GInterval = 10, p = 3, seed = None, name = ""):
        super().__init__(name=name)
        assert(GInterval > p)
        self.seed = seed
        self.layerSize = layerSize
        self.G = GInterval
        self.p = p
        self.W = tf.Variable(tf.random.normal((1, self.layerSize), mean = 0.0, stddev = 1.0, seed = self.seed), name = name+"_W")
        self.Phi = None
        self.Initialized = tf.constant(0, tf.int32)
        self.BSplineFunc = [self.BSpline0, self.BSpline1, self.BSpline2, self.BSpline3][self.p]

    @tf.function
    def expand_dims_with_index_values(self, indices):
        expanded_dims = [self.inpDimN, self.inpdim, self.p+1]
        result_dims = expanded_dims + [3]
        result = tf.zeros(result_dims, dtype = np.int32)
        i = tf.constant(0)
        j = tf.constant(0)
        k = tf.constant(0)
        while(i < expanded_dims[0]):
            j = 0
            while(j < expanded_dims[1]):
                k = 0
                while(k < expanded_dims[2]):
                    index = [i, j, k]
                    temp0 = tf.tensor_scatter_nd_update(index, [[2]], [tf.gather_nd(indices, index)])
                    result = tf.tensor_scatter_nd_update(result, [index], [temp0])
                    k+=1
                j+=1
            i+=1
        return result
    
    @tf.function
    def dot_matmul(self, A, B):
        A0 = tf.squeeze(tf.slice(A, [0, 0, 0], [self.inpDimN, self.inpdim, 1]), axis = 2)
        B0 = tf.squeeze(tf.slice(B, [0, 0, 0], [self.inpdim, self.layerSize, 1]), axis = 2)
        Y = tf.matmul(A0, B0)
        i = tf.constant(1)
        while(i < self.G):
            A0 = tf.squeeze(tf.slice(A, [0, 0, i], [self.inpDimN, self.inpdim, 1]), axis = 2)
            B0 = tf.squeeze(tf.slice(B, [0, 0, i], [self.inpdim, self.layerSize, 1]), axis = 2)
            Y = Y + tf.matmul(A0, B0)
            i+=1
        return Y
    
    @tf.function
    def BSpline(self, T, Ti, p):
        Ginter = 1/(self.G-self.p)
        if(p <= 0):
            temp1 = tf.cast(tf.where(((Ti <= T) & (T < (Ti + Ginter))), 1, 0), tf.float32)
            return temp1
        else:
            temp0 = tf.cast(p, tf.float32) * Ginter
            diff0 = (T - Ti) / temp0
            diff1 = (Ti + temp0 + Ginter - T)/temp0
            temp2 = ( diff0 * self.BSpline(T, Ti, p-1) ) + ( diff1 * self.BSpline(T, Ti+Ginter, p-1) )
            
            return temp2
            
    @tf.function
    def BSpline0(self, T, Ti):
        Ginter = 1/(self.G-self.p)
        temp1 = tf.cast(tf.where(((Ti <= T) & (T < (Ti + Ginter))), 1, 0), tf.float32)
        return temp1

    @tf.function
    def BSpline1(self, T, Ti):
        Ginter = 1/(self.G-self.p)
        temp0 = 1 * Ginter
        diff0 = (T - Ti) / temp0
        diff1 = (Ti + temp0 + Ginter - T)/temp0
        temp2 = ( diff0 * self.BSpline0(T, Ti) ) + ( diff1 * self.BSpline0(T, Ti+Ginter) )
        return temp2

    @tf.function
    def BSpline2(self, T, Ti):
        Ginter = 1/(self.G-self.p)
        temp0 = 2 * Ginter
        diff0 = (T - Ti) / temp0
        diff1 = (Ti + temp0 + Ginter - T)/temp0
        temp2 = ( diff0 * self.BSpline1(T, Ti) ) + ( diff1 * self.BSpline1(T, Ti+Ginter) )
        return temp2
        
    @tf.function
    def BSpline3(self, T, Ti):
        Ginter = 1/(self.G-self.p)
        temp0 = 3 * Ginter
        diff0 = (T - Ti) / temp0
        diff1 = (Ti + temp0 + Ginter - T)/temp0
        temp2 = ( diff0 * self.BSpline2(T, Ti) ) + ( diff1 * self.BSpline2(T, Ti+Ginter) )
        return temp2
        
    
    @tf.function
    def Spline(self, X):
        Gdelta = self.G-self.p
        Ginter = 1/(self.G-self.p)

        I = tf.floor(X * Gdelta)
        I = tf.maximum(I, 0)
        I = tf.minimum(I, self.G-self.p-1)
        I = I + self.p
        Ti = (I - self.p) / Gdelta
        I = tf.cast(I, tf.int32)

        expanded_dim = [self.inpDimN, self.inpdim, self.p+1]
        
        I_indices = tf.transpose(tf.zeros(expanded_dim, dtype = tf.int32))
        Y = tf.transpose(tf.zeros(expanded_dim, dtype = tf.float32))
        Ti = Ti + Ginter
        
        i = tf.constant(0)
        while(i <= self.p):
            Ti = Ti - Ginter
            I_indices = tf.tensor_scatter_nd_update(I_indices, [[i]], [tf.transpose(I-i)] )
            Y = tf.tensor_scatter_nd_update(Y, [[i]], [ tf.transpose(self.BSplineFunc(X, Ti) )] )
            i += 1
            
        Y = tf.transpose(Y)
        I_indices = tf.transpose(I_indices)
        I_indices = self.expand_dims_with_index_values(I_indices)

        expanded_dim[-1] = self.G
        Res = tf.zeros(expanded_dim, dtype = tf.float32)
        
        Res = tf.tensor_scatter_nd_update(Res, I_indices, Y)
        Ans = self.dot_matmul(Res, self.Phi)
        
        return Ans
    
    def __call__(self, X):
        self.inpDimN = X.shape[0]
        if(self.Initialized == 0):
            #inpdim = tf.shape(X).eval(session=tf.compat.v1.Session())[-1]
            self.inpdim = X.shape[-1]
            self.Initialized = 1
            stddev = tf.math.sqrt(2 / (self.inpdim + self.inpDimN))
            self.Phi = tf.Variable(tf.random.normal([self.inpdim, self.layerSize, self.G], 
                                                    mean = 0.0, stddev = stddev, seed = self.seed), name = self.name+"_Phi")
            
        spline = self.Spline(X)
        silu = X / (1 + tf.exp(-X))
        silu_exp = tf.matmul(silu, tf.ones([self.inpdim, self.layerSize], dtype = tf.float32))
        W = tf.matmul(tf.ones([X.shape[0], 1], dtype = tf.float32), self.W)
        return W * (silu_exp + spline)
        

In [8]:
kan0 = KANLayer(1, GInterval = 20, p = 3, seed = 100, name = "KAN_test")

In [None]:
Kan0 = KANLayer(20, GInterval = 30, p = 3, seed = 100, name = "KAN_test")
#Kan1 = KANLayer(1, GInterval = 30, p = 2, seed = 100, name = "KAN_test")
#Kan2 = KANLayer(1, GInterval = 40, p = 3, seed = 100, name = "KAN_test")

#tf.compat.v1.enable_eager_execution()
N = 10000
X = tf.constant(tf.random.uniform([N, 20], minval = 0, maxval = 1))
X = tf.sort(X, axis = 0)

time0 = tf.timestamp()
with tf.GradientTape() as tape:
    #Y = Kan2(Kan1(Kan0(X)))
    Y = Kan0(X)
print((tf.timestamp() - time0) / N)

[dY_dK0, dY_dK1, dY_dK2] = tape.gradient(Y, [Kan0.trainable_variables, Kan1.trainable_variables, Kan2.trainable_variables])

In [9]:
def FuncCall(N):
    X = tf.random.uniform([N, 1], minval = -0.01, maxval = 1.02)
    X = tf.sort(X, axis = 0)
    
    time0 = tf.timestamp()
    with tf.GradientTape() as tape:
        Y = kan0(X)
    timedelt = tf.timestamp() - time0
    
    dY_dK = tape.gradient(Y, kan0.trainable_variables)
    return timedelt

In [10]:
N = [0, 10, 100, 200, 300, 500, 1000, 10000]
Ny = [FuncCall(n) for n in N]

ValueError: in user code:

    File "C:\Users\prath\AppData\Local\Temp\ipykernel_27460\1465817995.py", line 117, in Spline  *
        I_indices = tf.tensor_scatter_nd_update(I_indices, [[i]], [tf.transpose(I-i)] )

    ValueError: Indices and updates specified for empty input for '{{node while/TensorScatterUpdate}} = TensorScatterUpdate[T=DT_INT32, Tindices=DT_INT32](while/Placeholder, while/TensorScatterUpdate/indices, while/TensorScatterUpdate/updates)' with input shapes: [4,1,0], [1,1], [1,1,0].


In [None]:
plt.plot(N, Ny)