In [29]:
import tensorflow as tf
from tfkan.layers import DenseKAN
# from keras import GlobalAveragePooling2D
# print(tf.keras.layers.GlobalAveragePooling2D)

import numpy as np
from sklearn.datasets import make_regression
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt

In [30]:
X, y = make_regression(n_samples=1000, n_features=10, n_informative=10, n_targets=1, noise=0.1)
scaler = MinMaxScaler()
X = scaler.fit_transform(X)
y = scaler.fit_transform(y.reshape(-1, 1)).reshape(-1)
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.4)

#### To call `update_grid_from_samples()` in user-define training logic

In [31]:
# KAN
kan = tf.keras.models.Sequential([
    DenseKAN(8, grid_size=3),
    DenseKAN(4, grid_size=3),
    DenseKAN(1, grid_size=3),
])
kan.build(input_shape=(None, 10))
kan.summary()

In [17]:
def train_kan(
    model,
    x_train,
    y_train,
    x_valid=None,
    y_valid=None,
    epochs: int=5,
    learning_rate: float=1e-3,
    batch_size: int=128,
    verbose: int=1
):  
    # build optimizer
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    # build dataset
    train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    train_set = train_set.batch(batch_size)
    if x_valid is not None and y_valid is not None:
        valid_set = tf.data.Dataset.from_tensor_slices((x_valid, y_valid))
        valid_set = valid_set.batch(batch_size)
    else:
        valid_set = None

    # define loss function
    loss_func = tf.keras.losses.SparseCategoricalCrossentropy()

    # define metrics
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

    step = 0
    # training loop
    for epoch in range(epochs):
        # reset metrics
        train_loss.reset_states()
        train_accuracy.reset_states()

        for x_batch, y_batch in train_set:
            with tf.GradientTape() as tape:
                y_pred = model(x_batch, training=True)
                loss = loss_func(y_batch, y_pred)
                loss = tf.reduce_mean(loss)
            # update weights
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            
            train_loss(loss)
            train_accuracy(y_batch, y_pred)
            step += 1

            if verbose > 0 and step % verbose == 0:
                # clear the output and print the updated metrics
                print(f"[EPOCH: {epoch+1:3d} / {epochs:3d}, STEP: {step:6d}]: \
train_loss: {train_loss.result():.4f}, train_accuracy: {train_accuracy.result():.4f}", end='\r')
        
        # callback after each epoch
        # call update_grid_from_samples method
        for layer in model.layers:
            if hasattr(layer, 'update_grid_from_samples'):
                layer.update_grid_from_samples(x_batch)
            x_batch = layer(x_batch)

        # eval on validation set
        if valid_set:
            valid_loss = tf.keras.metrics.Mean(name='valid_loss')
            valid_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='valid_accuracy')
            for x_batch, y_batch in valid_set:
                y_pred = model(x_batch, training=False)
                loss = tf.reduce_mean(loss_func(y_batch, y_pred))
                valid_loss(loss)
                valid_accuracy(y_batch, y_pred)
            print(f"[EPOCH: {epoch+1:3d} / {epochs:3d}, STEP: {step:6d}]: \
train_loss: {train_loss.result():.4f}, train_accuracy: {train_accuracy.result():.4f}, \
valid_loss: {valid_loss.result():.4f}, valid_accuracy: {valid_accuracy.result():.4f}")
        else:
            print()
    
    return model

In [18]:
kan = train_kan(kan, x_train, y_train, x_test, y_test, epochs=5, learning_rate=1e-3, batch_size=128, verbose=1)

AttributeError: 'Mean' object has no attribute 'reset_states'

#### To use `update_grid_from_samples()` in Tensorflow Callbacks

In [15]:
# KAN
kan = tf.keras.models.Sequential([
    Conv2DKAN(filters=8, kernel_size=5, strides=2, padding='valid', kan_kwargs={'grid_size': 3}),
    tf.keras.layers.LayerNormalization(),
    Conv2DKAN(filters=16, kernel_size=5, strides=2, padding='valid', kan_kwargs={'grid_size': 3}),
    GlobalAveragePooling2D(),
    DenseKAN(10, grid_size=3),
    tf.keras.layers.Softmax()
])
kan.build(input_shape=(None, 28, 28, 1))
kan.summary()

Model: "sequential_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2dkan_8 (Conv2DKAN)     (None, 12, 12, 8)         1658      
                                                                 
 layer_normalization_3 (Lay  (None, 12, 12, 8)         16        
 erNormalization)                                                
                                                                 
 conv2dkan_9 (Conv2DKAN)     (None, 4, 4, 16)          24416     
                                                                 
 global_average_pooling2d_4  (None, 16)                0         
  (GlobalAveragePooling2D)                                       
                                                                 
 dense_kan_4 (DenseKAN)      (None, 10)                1290      
                                                                 
 softmax_4 (Softmax)         (None, 10)               

In [32]:
# define update grid callback
class UpdateGridCallback(tf.keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        """
        update grid before new epoch begins
        """
        global x_train
        x_batch = x_train[:128]
        
        if epoch > 0:
            for layer in self.model.layers:
                if hasattr(layer, 'update_grid_from_samples'):
                    layer.update_grid_from_samples(x_batch)
                x_batch = layer(x_batch)
            print(f"Call update_grid_from_samples at epoch {epoch}")

In [27]:
kan.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
# add callback to training
kan.fit(x_train, y_train, epochs=5, batch_size=128, 
        validation_data=(x_test, y_test), callbacks=[UpdateGridCallback()])

Epoch 1/5
[1m1/5[0m [32m━━━━[0m[37m━━━━━━━━━━━━━━━━[0m [1m18s[0m 5s/step - accuracy: 0.0000e+00 - loss: 0.0000e+00

InvalidArgumentError: Graph execution error:

Detected at node compile_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits defined at (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main

  File "<frozen runpy>", line 88, in _run_code

  File "c:\Users\matte\documents\KAN\Lib\site-packages\ipykernel_launcher.py", line 18, in <module>

  File "c:\Users\matte\documents\KAN\Lib\site-packages\traitlets\config\application.py", line 1075, in launch_instance

  File "c:\Users\matte\documents\KAN\Lib\site-packages\ipykernel\kernelapp.py", line 739, in start

  File "c:\Users\matte\documents\KAN\Lib\site-packages\tornado\platform\asyncio.py", line 205, in start

  File "C:\Program Files\Python312\Lib\asyncio\base_events.py", line 641, in run_forever

  File "C:\Program Files\Python312\Lib\asyncio\base_events.py", line 1987, in _run_once

  File "C:\Program Files\Python312\Lib\asyncio\events.py", line 88, in _run

  File "c:\Users\matte\documents\KAN\Lib\site-packages\ipykernel\kernelbase.py", line 545, in dispatch_queue

  File "c:\Users\matte\documents\KAN\Lib\site-packages\ipykernel\kernelbase.py", line 534, in process_one

  File "c:\Users\matte\documents\KAN\Lib\site-packages\ipykernel\kernelbase.py", line 437, in dispatch_shell

  File "c:\Users\matte\documents\KAN\Lib\site-packages\ipykernel\ipkernel.py", line 362, in execute_request

  File "c:\Users\matte\documents\KAN\Lib\site-packages\ipykernel\kernelbase.py", line 778, in execute_request

  File "c:\Users\matte\documents\KAN\Lib\site-packages\ipykernel\ipkernel.py", line 449, in do_execute

  File "c:\Users\matte\documents\KAN\Lib\site-packages\ipykernel\zmqshell.py", line 549, in run_cell

  File "c:\Users\matte\documents\KAN\Lib\site-packages\IPython\core\interactiveshell.py", line 3075, in run_cell

  File "c:\Users\matte\documents\KAN\Lib\site-packages\IPython\core\interactiveshell.py", line 3130, in _run_cell

  File "c:\Users\matte\documents\KAN\Lib\site-packages\IPython\core\async_helpers.py", line 129, in _pseudo_sync_runner

  File "c:\Users\matte\documents\KAN\Lib\site-packages\IPython\core\interactiveshell.py", line 3334, in run_cell_async

  File "c:\Users\matte\documents\KAN\Lib\site-packages\IPython\core\interactiveshell.py", line 3517, in run_ast_nodes

  File "c:\Users\matte\documents\KAN\Lib\site-packages\IPython\core\interactiveshell.py", line 3577, in run_code

  File "C:\Users\matte\AppData\Local\Temp\ipykernel_11804\1783448292.py", line 7, in <module>

  File "c:\Users\matte\documents\KAN\Lib\site-packages\keras\src\utils\traceback_utils.py", line 117, in error_handler

  File "c:\Users\matte\documents\KAN\Lib\site-packages\keras\src\backend\tensorflow\trainer.py", line 318, in fit

  File "c:\Users\matte\documents\KAN\Lib\site-packages\keras\src\backend\tensorflow\trainer.py", line 121, in one_step_on_iterator

  File "c:\Users\matte\documents\KAN\Lib\site-packages\keras\src\backend\tensorflow\trainer.py", line 108, in one_step_on_data

  File "c:\Users\matte\documents\KAN\Lib\site-packages\keras\src\backend\tensorflow\trainer.py", line 54, in train_step

  File "c:\Users\matte\documents\KAN\Lib\site-packages\keras\src\trainers\trainer.py", line 357, in _compute_loss

  File "c:\Users\matte\documents\KAN\Lib\site-packages\keras\src\trainers\trainer.py", line 325, in compute_loss

  File "c:\Users\matte\documents\KAN\Lib\site-packages\keras\src\trainers\compile_utils.py", line 609, in __call__

  File "c:\Users\matte\documents\KAN\Lib\site-packages\keras\src\trainers\compile_utils.py", line 645, in call

  File "c:\Users\matte\documents\KAN\Lib\site-packages\keras\src\losses\loss.py", line 43, in __call__

  File "c:\Users\matte\documents\KAN\Lib\site-packages\keras\src\losses\losses.py", line 27, in call

  File "c:\Users\matte\documents\KAN\Lib\site-packages\keras\src\losses\losses.py", line 1853, in sparse_categorical_crossentropy

  File "c:\Users\matte\documents\KAN\Lib\site-packages\keras\src\ops\nn.py", line 1567, in sparse_categorical_crossentropy

  File "c:\Users\matte\documents\KAN\Lib\site-packages\keras\src\backend\tensorflow\nn.py", line 645, in sparse_categorical_crossentropy

Received a label value of 1 which is outside the valid range of [0, 1).  Label values: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
	 [[{{node compile_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits}}]] [Op:__inference_one_step_on_iterator_12942]