In [1]:
import tensorflow as tf
from tfkan.layers import Conv2DKAN, DenseKAN
from keras.layers import GlobalAveragePooling2D

import numpy as np
from matplotlib import pyplot as plt

In [2]:
# fetch the Fashion-MNIST train/test split bundled with Keras
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
# divide pixel values by 255, append a trailing channel axis, and store as float32
x_train = (x_train / 255.0)[..., np.newaxis].astype(np.float32)
x_test = (x_test / 255.0)[..., np.newaxis].astype(np.float32)

#### To call `update_grid_from_samples()` in user-defined training logic

In [3]:
# KAN classifier: two strided KAN convolutions -> global pooling -> KAN dense head
kan = tf.keras.models.Sequential()
kan.add(Conv2DKAN(filters=8, kernel_size=5, strides=2, padding='valid', kan_kwargs={'grid_size': 3}))
kan.add(tf.keras.layers.LayerNormalization())
kan.add(Conv2DKAN(filters=16, kernel_size=5, strides=2, padding='valid', kan_kwargs={'grid_size': 3}))
kan.add(GlobalAveragePooling2D())
kan.add(DenseKAN(10, grid_size=3))
kan.add(tf.keras.layers.Softmax())

# create the weights for (batch, 28, 28, 1) inputs so the summary can report shapes
kan.build(input_shape=(None, 28, 28, 1))
kan.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2dkan (Conv2DKAN)       (None, 12, 12, 8)         1658      
                                                                 
 layer_normalization (Layer  (None, 12, 12, 8)         16        
 Normalization)                                                  
                                                                 
 conv2dkan_1 (Conv2DKAN)     (None, 4, 4, 16)          24416     
                                                                 
 global_average_pooling2d (  (None, 16)                0         
 GlobalAveragePooling2D)                                         
                                                                 
 dense_kan (DenseKAN)        (None, 10)                1290      
                                                                 
 softmax (Softmax)           (None, 10)                0

In [4]:
def train_kan(
    model,
    x_train,
    y_train,
    x_valid=None,
    y_valid=None,
    epochs: int=5,
    learning_rate: float=1e-3,
    batch_size: int=128,
    verbose: int=1
):
    """Train `model` with a hand-written loop and refresh KAN grids after every epoch.

    Each epoch runs one pass of Adam updates over the (unshuffled) training
    data, then walks `model.layers` in order and calls
    `update_grid_from_samples` on every layer that exposes it, feeding each
    layer the activations that reach it for the last training batch of the
    epoch. If validation data is given, the model is evaluated on it after the
    grid update.

    Args:
        model: model whose `layers` can be called one after another
            (e.g. a `Sequential`).
        x_train: training inputs.
        y_train: integer class labels for `x_train`.
        x_valid: optional validation inputs.
        y_valid: optional validation labels; validation runs only when both
            `x_valid` and `y_valid` are provided.
        epochs: number of passes over the training data.
        learning_rate: learning rate for the Adam optimizer.
        batch_size: batch size for both training and validation.
        verbose: progress is printed every `verbose` steps; `0` disables the
            per-step progress line.

    Returns:
        The same `model` instance, trained in place.
    """
    # build optimizer
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    # build datasets
    train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size)
    if x_valid is not None and y_valid is not None:
        valid_set = tf.data.Dataset.from_tensor_slices((x_valid, y_valid)).batch(batch_size)
    else:
        valid_set = None

    # define loss function (default reduction already returns the batch mean)
    loss_func = tf.keras.losses.SparseCategoricalCrossentropy()

    # define metrics once and reset them every epoch
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
    if valid_set is not None:
        valid_loss = tf.keras.metrics.Mean(name='valid_loss')
        valid_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='valid_accuracy')

    step = 0
    # training loop
    for epoch in range(epochs):
        train_loss.reset_state()
        train_accuracy.reset_state()

        # remembers the most recent training inputs for the grid update below;
        # stays None when the training set yields no batches
        grid_batch = None

        for x_batch, y_batch in train_set:
            with tf.GradientTape() as tape:
                y_pred = model(x_batch, training=True)
                loss = loss_func(y_batch, y_pred)
            # update weights
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            train_loss(loss)
            train_accuracy(y_batch, y_pred)
            step += 1
            grid_batch = x_batch

            if verbose > 0 and step % verbose == 0:
                # '\r' keeps the progress on a single, continuously rewritten line
                print(f"[EPCOH: {epoch+1:3d} / {epochs:3d}, STEP: {step:6d}]: \
train_loss: {train_loss.result():.4f}, train_accuracy: {train_accuracy.result():.4f}", end='\r')

        # callback after each epoch: propagate the sample batch through the
        # layers so every KAN layer sees the inputs it actually receives
        if grid_batch is not None:
            activations = grid_batch
            for layer in model.layers:
                if hasattr(layer, 'update_grid_from_samples'):
                    layer.update_grid_from_samples(activations)
                activations = layer(activations)

        # eval on validation set
        if valid_set is not None:
            valid_loss.reset_state()
            valid_accuracy.reset_state()
            for x_val_batch, y_val_batch in valid_set:
                y_val_pred = model(x_val_batch, training=False)
                valid_loss(loss_func(y_val_batch, y_val_pred))
                valid_accuracy(y_val_batch, y_val_pred)
            print(f"[EPCOH: {epoch+1:3d} / {epochs:3d}, STEP: {step:6d}]: \
train_loss: {train_loss.result():.4f}, train_accuracy: {train_accuracy.result():.4f}, \
valid_loss: {valid_loss.result():.4f}, valid_accuracy: {valid_accuracy.result():.4f}")
        else:
            # move past the '\r'-terminated progress line
            print()

    return model

In [5]:
# train for 5 epochs, using the test split as the validation set
kan = train_kan(
    model=kan,
    x_train=x_train,
    y_train=y_train,
    x_valid=x_test,
    y_valid=y_test,
    epochs=5,
    learning_rate=1e-3,
    batch_size=128,
    verbose=1,
)

[EPCOH:   1 /   5, STEP:    469]: train_loss: 1.0368, train_accuracy: 0.6397, valid_loss: 0.7413, valid_accuracy: 0.7280
[EPCOH:   2 /   5, STEP:    938]: train_loss: 0.6492, train_accuracy: 0.7603, valid_loss: 0.6315, valid_accuracy: 0.7645
[EPCOH:   3 /   5, STEP:   1407]: train_loss: 0.5768, train_accuracy: 0.7883, valid_loss: 0.5872, valid_accuracy: 0.7841
[EPCOH:   4 /   5, STEP:   1876]: train_loss: 0.5373, train_accuracy: 0.8043, valid_loss: 0.5575, valid_accuracy: 0.7984
[EPCOH:   5 /   5, STEP:   2345]: train_loss: 0.5113, train_accuracy: 0.8146, valid_loss: 0.5335, valid_accuracy: 0.8068


#### To use `update_grid_from_samples()` in TensorFlow callbacks

In [15]:
# KAN: fresh, untrained copy of the same architecture for the callback-based run
kan = tf.keras.Sequential(
    layers=[
        Conv2DKAN(
            filters=8, kernel_size=5, strides=2, padding='valid',
            kan_kwargs={'grid_size': 3},
        ),
        tf.keras.layers.LayerNormalization(),
        Conv2DKAN(
            filters=16, kernel_size=5, strides=2, padding='valid',
            kan_kwargs={'grid_size': 3},
        ),
        GlobalAveragePooling2D(),
        DenseKAN(10, grid_size=3),
        tf.keras.layers.Softmax(),
    ]
)
# create the weights for (batch, 28, 28, 1) inputs before printing the summary
kan.build(input_shape=(None, 28, 28, 1))
kan.summary()

Model: "sequential_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2dkan_8 (Conv2DKAN)     (None, 12, 12, 8)         1658      
                                                                 
 layer_normalization_3 (Lay  (None, 12, 12, 8)         16        
 erNormalization)                                                
                                                                 
 conv2dkan_9 (Conv2DKAN)     (None, 4, 4, 16)          24416     
                                                                 
 global_average_pooling2d_4  (None, 16)                0         
  (GlobalAveragePooling2D)                                       
                                                                 
 dense_kan_4 (DenseKAN)      (None, 10)                1290      
                                                                 
 softmax_4 (Softmax)         (None, 10)               

In [16]:
# define update grid callback
class UpdateGridCallback(tf.keras.callbacks.Callback):
    """Keras callback that calls `update_grid_from_samples` on the model's layers.

    A fixed slice of the training data (its first `batch_size` entries) is
    stored at construction time and reused at the start of every epoch
    except the first.
    """

    def __init__(self, training_data, batch_size: int = 128):
        super().__init__()
        # keep only the leading `batch_size` samples, stored as a float32 constant
        sample = training_data[:batch_size]
        self.training_data = tf.constant(sample, tf.float32)

    def on_epoch_begin(self, epoch, logs=None):
        """
        Update the grids before a new epoch begins (skipped for epoch 0).
        """
        if not epoch > 0:
            return

        # feed the stored sample through the layers in order, so each layer
        # that has `update_grid_from_samples` receives its own input activations
        activations = self.training_data
        for layer in self.model.layers:
            if hasattr(layer, 'update_grid_from_samples'):
                layer.update_grid_from_samples(activations)
            activations = layer(activations)
        print(f"Call update_grid_from_samples at epoch {epoch}")

In [17]:
# configure the built-in Keras training loop
kan.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
# train with the grid-update callback attached; the test split serves as validation data
kan.fit(
    x=x_train,
    y=y_train,
    batch_size=128,
    epochs=5,
    callbacks=[UpdateGridCallback(x_train)],
    validation_data=(x_test, y_test),
)

Epoch 1/5
Call update_grid_from_samples at epoch 1
Epoch 2/5
Call update_grid_from_samples at epoch 2
Epoch 3/5
Call update_grid_from_samples at epoch 3
Epoch 4/5
Call update_grid_from_samples at epoch 4
Epoch 5/5


<keras.src.callbacks.History at 0xffff437cc400>