In [85]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [86]:
class KANLayer2(tf.Module):
    def __init__(self, inpDim, outDim, GInterval, p, name = "", seed = 100):
        super().__init__(name=name)
        self.inpDim = tf.constant(inpDim)
        self.outDim = tf.constant(outDim)
        self.G = tf.constant(GInterval)
        self.p = tf.constant(p)
        self.seed = seed
        stddev = tf.math.sqrt(2 / (inpDim + outDim))
        self.W = tf.Variable(tf.random.normal([1, self.outDim], mean = 0.0, stddev = stddev, seed = self.seed), name = name+"_W")
        self.Phi = tf.Variable(tf.random.normal([self.inpDim, self.outDim, self.G], 
                               mean = 0.0, stddev = stddev, seed = self.seed), name = self.name+"_Phi")
        self.BSplineFunc = [self.BSpline0, self.BSpline1, self.BSpline2, self.BSpline3][self.p]

    @tf.function
    def dot_matmul(self, A, B):
        A0 = tf.squeeze(tf.slice(A, [0, 0, 0], [self.inpDimN, self.inpDim, 1]), axis = 2)
        B0 = tf.squeeze(tf.slice(B, [0, 0, 0], [self.inpDim, self.outDim, 1]), axis = 2)
        Y = tf.matmul(A0, B0)
        i = tf.constant(1)
        while(i < self.G):
            A0 = tf.squeeze(tf.slice(A, [0, 0, i], [self.inpDimN, self.inpDim, 1]), axis = 2)
            B0 = tf.squeeze(tf.slice(B, [0, 0, i], [self.inpDim, self.outDim, 1]), axis = 2)
            Y = Y + tf.matmul(A0, B0)
            i+=1
        return Y
    
    @tf.function
    def BSpline0(self, T, Ti):
        Gdelta = tf.cast(self.G-self.p, dtype = tf.float32)
        Ginter = 1/Gdelta
        temp1 = tf.cast(tf.where(((Ti <= T) & (T < (Ti + Ginter))), 1, 0), tf.float32)
        return temp1

    @tf.function
    def BSpline1(self, T, Ti):
        Gdelta = tf.cast(self.G-self.p, dtype = tf.float32)
        Ginter = 1/Gdelta
        temp0 = 1 * Ginter
        diff0 = (T - Ti) / temp0
        diff1 = (Ti + temp0 + Ginter - T)/temp0
        temp2 = ( diff0 * self.BSpline0(T, Ti) ) + ( diff1 * self.BSpline0(T, Ti+Ginter) )
        return temp2

    @tf.function
    def BSpline2(self, T, Ti):
        Gdelta = tf.cast(self.G-self.p, dtype = tf.float32)
        Ginter = 1/Gdelta
        temp0 = 2 * Ginter
        diff0 = (T - Ti) / temp0
        diff1 = (Ti + temp0 + Ginter - T)/temp0
        temp2 = ( diff0 * self.BSpline1(T, Ti) ) + ( diff1 * self.BSpline1(T, Ti+Ginter) )
        return temp2
        
    @tf.function
    def BSpline3(self, T, Ti):
        Gdelta = tf.cast(self.G-self.p, dtype = tf.float32)
        Ginter = 1/Gdelta
        temp0 = 3 * Ginter
        diff0 = (T - Ti) / temp0
        diff1 = (Ti + temp0 + Ginter - T)/temp0
        temp2 = ( diff0 * self.BSpline2(T, Ti) ) + ( diff1 * self.BSpline2(T, Ti+Ginter) )
        return temp2
    
    @tf.function(input_signature=(tf.TensorSpec(shape=[None, None], dtype=tf.float32),))
    def Spline(self, X):
        Gdelta = tf.cast(self.G-self.p, dtype = tf.float32)
        Ginter = 1/Gdelta

        I = tf.floor(X * Gdelta)
        I = tf.maximum(I, 0)
        I = tf.minimum(I, Gdelta-1)
        #I = I + self.p
        Ti = I / Gdelta
        I = tf.cast(I, tf.int32) + self.p

        Res = tf.zeros(tf.concat([[self.inpDimN], [self.inpDim], [self.G]], axis = 0), dtype=tf.float32)
        Ti = Ti + Ginter
        i = tf.constant(0)
        J = tf.expand_dims(tf.repeat(tf.expand_dims(tf.range(self.inpDimN), axis = 1), self.inpDim, axis = 1), axis = 2)
        K = tf.expand_dims(tf.repeat(tf.expand_dims(tf.range(self.inpDim), axis = 0), self.inpDimN, axis = 0), axis = 2)
        I = tf.expand_dims(I, axis = 2)
        while(i <= self.p):
            Ti = Ti - Ginter
            Y = self.BSplineFunc(X, Ti)
            temp0 = tf.concat([J, K, (I-i)], axis = 2)
            Res = tf.tensor_scatter_nd_update(Res, temp0, Y)
            i+=1
        Ans = self.dot_matmul(Res, self.Phi)
        return Ans

    @tf.function(input_signature=(tf.TensorSpec(shape=[None, None], dtype=tf.float32),))
    def __call__(self, X):
        self.inpDimN = tf.gather_nd(tf.shape(X), [0])
        spline = self.Spline(X)
        silu = X / (1 + tf.exp(-X))
        silu_exp = tf.matmul(silu, tf.ones([self.inpDim, self.outDim], dtype = tf.float32))
        W = tf.matmul(tf.ones([self.inpDimN, 1], dtype = tf.float32), self.W)
        return W * (silu_exp + spline)
        
        

In [128]:
kan0 = KANLayer2(764, 20, GInterval = 2, p = 1, name = "layer0", seed = 100)

In [131]:
N = 10000
X = tf.random.uniform([N, 764], minval = -0.01, maxval = 1.02)
X = tf.sort(X, axis = 0)

time0 = tf.timestamp()
with tf.GradientTape() as tape:
    Y = kan0(X)
timedelt = tf.timestamp() - time0

dY_dK = tape.gradient(Y, kan0.trainable_variables)

In [132]:
print(timedelt)

tf.Tensor(0.8690340518951416, shape=(), dtype=float64)
