# fKAN-UNet: Lightweight Road Segmentation with Fractional Spectral Modeling and Directional Convolutions

In [None]:
import tensorflow as tf
# List the visible GPUs; the rest of the notebook also runs (on CPU) when there are none.
physical_devices = tf.config.list_physical_devices('GPU')
print('GPU is available' if len(physical_devices) > 0 else 'Not available')
# Enable on-demand memory allocation on every GPU. Indexing physical_devices[0]
# unconditionally raised IndexError on CPU-only machines and configured only the
# first GPU (NOTE(review): TF may reject memory-growth settings that differ between GPUs).
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

In [None]:
import numpy as np
import tensorflow as tf
import random
# Seed every random-number source the notebook uses (NumPy, Python's `random`,
# TensorFlow) with one shared value so runs are repeatable.
# For stricter GPU determinism, these environment variables can additionally be
# set via `os.environ` before building the model (left disabled):
#   TF_DETERMINISTIC_OPS='1', TF_CUDNN_DETERMINISTIC='1'
GLOBAL_SEED = 65
np.random.seed(GLOBAL_SEED)
random.seed(GLOBAL_SEED)
tf.random.set_seed(GLOBAL_SEED)

# Model Architecture

## Jacobi Polynomial

In [None]:
def jacobi_polynomial(x, n, alpha, beta, gamma, a, b):
    """Evaluate the fractional shifted Jacobi polynomial of degree ``n`` at ``x``.

    ``x`` is first mapped from the interval [a, b] onto [-1, 1] through the
    fractional substitution ``z = (2 * x**gamma - a - b) / (b - a)``. The
    Jacobi polynomial P_n^(alpha, beta)(z) is then evaluated with its explicit
    expansion in powers of ``u = z - 1``, written out for each degree.

    Only arithmetic operators are used, so ``x`` may be a Python number, a
    NumPy array or a TensorFlow tensor; the result has the same shape.

    Args:
        x: Input value(s).
        n (int): Polynomial degree, 0 <= n <= 6.
        alpha: Jacobi parameter alpha.
        beta: Jacobi parameter beta.
        gamma: Fractional exponent applied to ``x``.
        a: Lower end of the input interval.
        b: Upper end of the input interval (must differ from ``a``).

    Returns:
        P_n^(alpha, beta)(z), element-wise.

    Raises:
        ValueError: If ``n`` is negative, larger than 6, or not an integer.
    """
    if n < 0:
        raise ValueError(
            "Degrees must be non-negative. Negative degrees are not allowed."
        )
    if n > 6:
        raise ValueError(
            f"The current implementation supports a maximum degree of 6, but you entered {n}. Higher degrees may lead to numerical instabilities, overfitting, and increased computational complexity. Please consider using a lower degree."
        )
    if n != int(n):
        # A fractional degree used to fall through every branch and return None.
        raise ValueError(f"Degree must be an integer, but you entered {n}.")

    if n == 0:
        # P_0 == 1. `x * 0.0 + 1.0` gives exact ones with x's shape and dtype.
        # The former `x / (x + 1e-7)` returned 0 at x = 0, drifted from 1 for
        # small x, and leaked a spurious gradient eps / (x + eps)**2 (1e7 at x = 0).
        return x * 0.0 + 1.0

    # Fractional shift of x from [a, b] onto [-1, 1], computed once.
    z = (2 * x**gamma - a - b) / (b - a)
    if n == 1:
        return (alpha - beta + (alpha + beta + 2) * z) / 2

    # Expansion variable for degrees >= 2: coefficient * u**m summed over m.
    u = z - 1
    if n == 2:
        return (
            ((alpha + 1) * (alpha + 2)) / 2
            + ((alpha + 2) * (3 + alpha + beta) * u) / 2
            + ((3 + alpha + beta) * (4 + alpha + beta) * u**2) / 8
        )
    if n == 3:
        return (
            ((alpha + 1) * (alpha + 2) * (3 + alpha)) / 6
            + ((alpha + 2) * (3 + alpha) * (4 + alpha + beta) * u) / 4
            + ((3 + alpha) * (4 + alpha + beta) * (5 + alpha + beta) * u**2) / 8
            + (
                (4 + alpha + beta) * (5 + alpha + beta) * (6 + alpha + beta) * u**3
            ) / 48
        )
    if n == 4:
        return (
            ((alpha + 1) * (alpha + 2) * (3 + alpha) * (4 + alpha)) / 24
            + (
                (alpha + 2) * (3 + alpha) * (4 + alpha) * (5 + alpha + beta) * u
            ) / 12
            + (
                (3 + alpha) * (4 + alpha) * (5 + alpha + beta) * (6 + alpha + beta)
                * u**2
            ) / 16
            + (
                (4 + alpha) * (5 + alpha + beta) * (6 + alpha + beta)
                * (7 + alpha + beta) * u**3
            ) / 48
            + (
                (5 + alpha + beta) * (6 + alpha + beta) * (7 + alpha + beta)
                * (8 + alpha + beta) * u**4
            ) / 384
        )
    if n == 5:
        return (
            ((alpha + 1) * (alpha + 2) * (alpha + 3) * (alpha + 4) * (alpha + 5)) / 120
            + (
                (alpha + 2) * (alpha + 3) * (alpha + 4) * (alpha + 5)
                * (6 + alpha + beta) * u
            ) / 48
            + (
                (alpha + 3) * (alpha + 4) * (alpha + 5) * (6 + alpha + beta)
                * (7 + alpha + beta) * u**2
            ) / 48
            + (
                (alpha + 4) * (alpha + 5) * (6 + alpha + beta) * (7 + alpha + beta)
                * (8 + alpha + beta) * u**3
            ) / 96
            + (
                (alpha + 5) * (6 + alpha + beta) * (7 + alpha + beta)
                * (8 + alpha + beta) * (9 + alpha + beta) * u**4
            ) / 384
            + (
                (6 + alpha + beta) * (7 + alpha + beta) * (8 + alpha + beta)
                * (9 + alpha + beta) * (10 + alpha + beta) * u**5
            ) / 3840
        )
    # n == 6 (every other degree was rejected or handled above).
    return (
        (
            (alpha + 1) * (alpha + 2) * (alpha + 3) * (alpha + 4) * (alpha + 5)
            * (6 + alpha)
        ) / 720
        + (
            (alpha + 2) * (alpha + 3) * (alpha + 4) * (alpha + 5) * (6 + alpha)
            * (7 + alpha + beta) * u
        ) / 240
        + (
            (alpha + 3) * (alpha + 4) * (alpha + 5) * (6 + alpha)
            * (7 + alpha + beta) * (8 + alpha + beta) * u**2
        ) / 192
        + (
            (alpha + 4) * (alpha + 5) * (6 + alpha) * (7 + alpha + beta)
            * (8 + alpha + beta) * (9 + alpha + beta) * u**3
        ) / 288
        + (
            (alpha + 5) * (6 + alpha) * (7 + alpha + beta) * (8 + alpha + beta)
            * (9 + alpha + beta) * (10 + alpha + beta) * u**4
        ) / 768
        + (
            (6 + alpha) * (7 + alpha + beta) * (8 + alpha + beta)
            * (9 + alpha + beta) * (10 + alpha + beta) * (11 + alpha + beta) * u**5
        ) / 3840
        + (
            (7 + alpha + beta) * (8 + alpha + beta) * (9 + alpha + beta)
            * (10 + alpha + beta) * (11 + alpha + beta) * (12 + alpha + beta) * u**6
        ) / 46080
    )

## Fractional Jacobi Neural Block (fJNB)

In [None]:
import tensorflow as tf


class FractionalJacobiNeuralBlock(tf.keras.layers.Layer):
    """
    Keras layer mapping its input through a learnable fractional Jacobi polynomial.

    The input is squashed by a sigmoid into (0, 1) and evaluated with
    `jacobi_polynomial` on the interval [0, 1]. Three scalar weights are learned:

    * ``alpha`` and ``beta``: Jacobi parameters, passed through ELU (alpha=1),
      which keeps them above -1.
    * ``gamma`` (held in ``self.zeta``): the fractional exponent, passed through
      a sigmoid, which keeps it in (0, 1).

    Attributes:
        degree (int): Degree of the Jacobi polynomial.
    """

    def __init__(self, degree, **kwargs):
        """
        Args:
            degree (int): Degree of the Jacobi polynomial.
            **kwargs: Forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        self.degree = degree
        # Graph-compiled polynomial evaluator, created once per layer instance.
        self.jacobi_polynomial = tf.function(jacobi_polynomial)

    def _make_scalar_weight(self, name, initializer):
        """Create one trainable weight of shape (1,)."""
        return self.add_weight(
            name=name,
            initializer=initializer,
            trainable=True,
            shape=(1,),
        )

    def build(self, input_shape):
        """Create the three scalar weights; ``input_shape`` is not used for them."""
        # Names, creation order and attribute names are kept unchanged:
        # presumably the saved checkpoints reloaded later depend on them.
        self.alpha = self._make_scalar_weight("alpha", "ones")
        self.beta = self._make_scalar_weight("beta", "ones")
        self.zeta = self._make_scalar_weight("gamma", "zeros")
        super().build(input_shape)

    def call(self, inputs):
        """Return the Jacobi-polynomial transform of ``sigmoid(inputs)``."""
        squashed = tf.keras.activations.sigmoid(inputs)
        alpha_pos = tf.keras.activations.elu(self.alpha, 1)
        beta_pos = tf.keras.activations.elu(self.beta, 1)
        exponent = tf.keras.activations.sigmoid(self.zeta)
        return self.jacobi_polynomial(
            squashed, self.degree, alpha_pos, beta_pos, exponent, 0, 1
        )

    def get_config(self):
        """Serializable config: the base Layer config plus ``degree``."""
        return {**super().get_config(), "degree": self.degree}

## Encoder, Decoder, and Bottleneck

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import (Input, Conv2D, Conv2DTranspose, MaxPooling2D, 
                                     Concatenate, Activation, GlobalMaxPooling2D, Reshape, Multiply, Add)
from tensorflow.keras.models import Model
from tensorflow.keras.utils import register_keras_serializable
import tensorflow.keras.layers as KL


# ------------------- SE-style channel attention with FJNB -------------------
def se_net(x, r=2):
    """Squeeze-and-excitation channel attention with FJNB layers in the excitation path.

    Squeeze: global max pooling gives one descriptor per channel.
    Excitation: Dense(C // r) -> ReLU -> FJNB(degree 2) -> Dense(C) -> sigmoid
    -> FJNB(degree 3). The result rescales ``x`` channel-wise.

    Args:
        x: 4-D feature map whose last axis is the channel axis.
        r: Reduction ratio of the bottleneck Dense layer.
    """
    channels = x.shape[-1]

    descriptor = GlobalMaxPooling2D()(x)
    descriptor = Reshape((1, 1, channels))(descriptor)

    hidden = KL.Dense(channels // r)(descriptor)
    hidden = Activation('relu')(hidden)
    hidden = FractionalJacobiNeuralBlock(2)(hidden)

    gate = KL.Dense(channels)(hidden)
    gate = Activation('sigmoid')(gate)
    gate = FractionalJacobiNeuralBlock(3)(gate)

    return Multiply()([x, gate])

# ------------------- Feature Selective Fusion Block -------------------
def fsf_block(low_x, high_x):
    """Fuse skip features (``low_x``) into the upsampled path (``high_x``).

    The two inputs are concatenated and re-weighted by ``se_net``. A 1x1
    convolution halves the channel count, and a global-max-pooled sigmoid gate
    derived from it scales ``low_x``. The gated skip is then added to ``high_x``,
    so ``low_x`` and ``high_x`` must have the same number of channels.
    """
    stacked = Concatenate()([high_x, low_x])
    attended = se_net(stacked)
    half_channels = attended.shape[-1] // 2
    reduced = Conv2D(half_channels, kernel_size=(1, 1), strides=1, padding='same')(attended)

    channel_gate = GlobalMaxPooling2D()(reduced)
    channel_gate = Reshape((1, 1, channel_gate.shape[-1]))(channel_gate)
    channel_gate = Activation('sigmoid')(channel_gate)

    gated_skip = Multiply()([low_x, channel_gate])
    return Add()([gated_skip, high_x])


def directional_strip_conv(x, filters, kernel_size=11):
    """Parallel horizontal (1 x k) and vertical (k x 1) strip convolutions.

    Both branches use ReLU and 'same' padding. Their outputs are concatenated
    and fused back to ``filters`` channels by a ReLU 1x1 convolution.
    """
    branch_outputs = [
        Conv2D(filters, strip_shape, padding='same', activation='relu')(x)
        for strip_shape in ((1, kernel_size), (kernel_size, 1))
    ]
    stacked = Concatenate()(branch_outputs)
    return Conv2D(filters, (1, 1), padding='same', activation='relu')(stacked)


def encoder_block(x, filters):
    """One encoder stage: 3x3 ReLU conv, directional strip conv, 2x2 max-pool.

    Returns:
        (features, pooled): the pre-pooling features (used as the skip
        connection) and their 2x-downsampled version.
    """
    conv = Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
    features = directional_strip_conv(conv, filters)  # replaces a second 3x3 conv
    return features, MaxPooling2D((2, 2))(features)


def decoder_block(x, skip, filters):
    """One decoder stage: 2x transposed-conv upsampling, FSF fusion with ``skip``,
    then a 3x3 ReLU conv and a directional strip conv."""
    upsampled = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
    fused = fsf_block(skip, upsampled)
    refined = Conv2D(filters, (3, 3), activation='relu', padding='same')(fused)
    return directional_strip_conv(refined, filters)  # replaces a second 3x3 conv


# ------------------- Full U-Net Model -------------------
def build_fkan_unet(input_shape=(512, 512, 3), num_classes=1, filters=16):
    """Build the fKAN-UNet segmentation model.

    Four encoder stages (filters, 2x, 4x, 8x) with 2x2 max-pooling, a two-conv
    bottleneck at 16x filters, and four decoder stages. Each decoder stage
    fuses the matching skip connection through ``fsf_block``.

    Args:
        input_shape: (height, width, channels). Height and width must be
            divisible by 16 (four 2x poolings), or None.
        num_classes: 1 gives a sigmoid output; otherwise a softmax over classes.
        filters: Channel width of the first stage.

    Returns:
        An uncompiled ``Model``.

    Raises:
        ValueError: If a known spatial dimension is not divisible by 16.
    """
    depth = 4  # number of encoder / decoder stages
    factor = 2 ** depth
    # Fail early with a clear message. Otherwise the upsampled decoder features
    # and the skip features disagree in size, and the Concatenate inside
    # fsf_block fails with an opaque shape error.
    for dim in input_shape[:2]:
        if dim is not None and dim % factor != 0:
            raise ValueError(
                f"Spatial dimensions of input_shape must be divisible by {factor}, "
                f"got {tuple(input_shape)}."
            )

    inputs = Input(shape=input_shape)

    # Encoder: keep each stage's pre-pooling features as a skip connection.
    skips = []
    x = inputs
    for level in range(depth):
        skip, x = encoder_block(x, filters * 2 ** level)
        skips.append(skip)

    # Bottleneck
    x = Conv2D(filters * 2 ** depth, (3, 3), activation='relu', padding='same')(x)
    x = Conv2D(filters * 2 ** depth, (3, 3), activation='relu', padding='same')(x)

    # Decoder: deepest skip first, halving the width at each stage.
    for level in reversed(range(depth)):
        x = decoder_block(x, skips[level], filters * 2 ** level)

    # Output
    activation = 'sigmoid' if num_classes == 1 else 'softmax'
    outputs = Conv2D(num_classes, (1, 1), padding='same', activation=activation)(x)

    return Model(inputs, outputs)

# ------------------- Example -------------------
# Defaults: 512x512 RGB input, a single sigmoid output channel, 16 base filters.
model = build_fkan_unet()
model.summary()

In [None]:
import os
import numpy as np
import cv2
from matplotlib import pyplot as plt
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf

# Resize target for every image/mask, plus batching and shuffle settings.
desired_width = 512
desired_height = 512
batch_size = 32
seed = 24  # passed to both the image and the mask iterators so their ordering matches

# No augmentation: both generators only rescale pixel values by 1/255.
# Masks are NOT thresholded here. NOTE(review): this assumes the stored masks
# are strictly 0/255; confirm before relying on them being binary.
img_data_gen_args = {'rescale': 1 / 255.}
mask_data_gen_args = {'rescale': 1 / 255.}

# One generator per stream; the train/val/test iterators are created from these.
image_data_generator = ImageDataGenerator(**img_data_gen_args)
mask_data_generator = ImageDataGenerator(**mask_data_gen_args)

In [None]:
def create_generator(image_dir, mask_dir, batch_size, target_size=(desired_height, desired_width)):
    """Return an iterator of (image_batch, mask_batch) pairs from two directory trees.

    Images are read as RGB and masks as grayscale. Both are rescaled by the
    module-level generators. Both iterators share the module-level ``seed``, so
    their ordering stays aligned (the two trees must contain matching file names).

    Args:
        image_dir: Root directory of the images (one sub-folder, no labels).
        mask_dir: Root directory of the masks, mirroring ``image_dir``.
        batch_size: Number of samples per batch.
        target_size: Resize target in ``flow_from_directory`` order, i.e.
            (height, width). The previous default passed (width, height),
            which only worked because the size is square.

    Returns:
        A ``zip`` over the image and mask iterators.
    """
    shared_args = dict(
        target_size=target_size,
        class_mode=None,  # no labels: only the raw arrays are yielded
        batch_size=batch_size,
        seed=seed,
    )
    image_generator = image_data_generator.flow_from_directory(
        image_dir, color_mode='rgb', **shared_args)
    mask_generator = mask_data_generator.flow_from_directory(
        mask_dir, color_mode='grayscale', **shared_args)
    return zip(image_generator, mask_generator)

# (image_dir, mask_dir) for each split, in train / val / test order.
_split_dirs = {
    'train': ('/home/jayakumar/road-extraction-main/data3/train_images/',
              '/home/jayakumar/road-extraction-main/data3/train_masks/'),
    'val': ('/home/jayakumar/road-extraction-main/data3/val_images/',
            '/home/jayakumar/road-extraction-main/data3/val_masks/'),
    'test': ('/home/jayakumar/road-extraction-main/data3/test_images/',
             '/home/jayakumar/road-extraction-main/data3/test_masks/'),
}
train_generator, val_generator, test_generator = (
    create_generator(img_dir, msk_dir, batch_size)
    for img_dir, msk_dir in _split_dirs.values()
)

# Training Setup: Callbacks, Losses, and Compilation

In [None]:
import tensorflow as tf
# Write a checkpoint only when the monitored quantity improves
# (ModelCheckpoint's default `monitor` is used).
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    'patchmodels/fkanunetadam32d24f6f6rsmitpatchba.h5',
    save_best_only=True,
    verbose=1,
)
# Stop after 10 epochs without val_loss improvement, then restore the best weights.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True,
)

In [None]:
from keras.callbacks import LearningRateScheduler
def lr_scheduler(epoch, lr, decay_rate=1e-6, min_lr=0.0):
    """Linear learning-rate decay for ``LearningRateScheduler``.

    Subtracts ``decay_rate`` from the current rate every epoch. The result is
    clamped at ``min_lr``, so the rate never becomes negative; the unclamped
    version would go below zero once enough epochs had elapsed.

    Args:
        epoch: Current epoch index (unused; kept for the callback signature).
        lr: Current learning rate.
        decay_rate: Amount subtracted per epoch.
        min_lr: Floor for the returned learning rate.

    Returns:
        The new learning rate as a float.
    """
    return max(lr - decay_rate, float(min_lr))
lr_callback = LearningRateScheduler(schedule=lr_scheduler)  # called once per epoch

In [None]:
def dice_loss(y_true, y_pred, smooth=1e-6):
    """Soft Dice loss pooled over every element of the batch.

    Computes ``1 - (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``
    on the flattened tensors.

    Args:
        y_true: Ground-truth mask tensor; cast to ``y_pred``'s dtype.
        y_pred: Predicted probabilities, same number of elements as ``y_true``.
        smooth: Small constant that avoids 0/0 on empty masks.
    """
    y_pred_f = tf.reshape(y_pred, [-1])
    # Match dtypes so integer or float64 masks don't break the multiply below
    # (boundary_loss casts its inputs for the same reason).
    y_true_f = tf.reshape(tf.cast(y_true, y_pred_f.dtype), [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    total = tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f)
    return 1 - (2. * intersection + smooth) / (total + smooth)

In [None]:
# The model ends in a sigmoid, so the loss receives probabilities rather than logits.
bce_loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)

In [None]:
def boundary_loss(y_true, y_pred):
    """Mean L1 distance between the edge maps of ``y_true`` and ``y_pred``.

    Edges are approximated by the gradient magnitude from
    ``tf.image.image_gradients`` (which expects 4-D ``[batch, h, w, c]`` input).
    A 1e-7 term inside the square root keeps it differentiable at zero.
    """
    def _edge_magnitude(t):
        dy, dx = tf.image.image_gradients(t)
        return tf.sqrt(tf.square(dx) + tf.square(dy) + 1e-7)

    true_edges = _edge_magnitude(tf.cast(y_true, tf.float32))
    pred_edges = _edge_magnitude(tf.cast(y_pred, tf.float32))
    return tf.reduce_mean(tf.abs(true_edges - pred_edges))


In [None]:
def combined_loss(y_true, y_pred, alpha=0.8, beta=0.1, gamma=0.1):
    """Weighted sum of Dice, binary cross-entropy and boundary losses.

    Returns ``alpha * dice + beta * bce + gamma * boundary``.
    """
    weighted_dice = alpha * dice_loss(y_true, y_pred)
    weighted_bce = beta * bce_loss(y_true, y_pred)
    weighted_boundary = gamma * boundary_loss(y_true, y_pred)
    return weighted_dice + weighted_bce + weighted_boundary

In [None]:
import segmentation_models as sm
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import Precision, Recall
# NOTE(review): the checkpoint/history file names mention "adam", but SGD with momentum is used.
sgd_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)
seg_metrics = [
    sm.metrics.iou_score,
    sm.metrics.f1_score,
    sm.metrics.precision,
    sm.metrics.recall,
]
model.compile(optimizer=sgd_optimizer, loss=combined_loss, metrics=seg_metrics)

# Training

In [None]:
import math

# One full pass over each split per epoch. ceil() keeps the final partial batch,
# and it avoids validation_steps == 0 when there are fewer validation images
# than batch_size (floor division produced 0 in that case).
num_train_imgs = len(os.listdir('/home/jayakumar/road-extraction-main/data3/train_images/train/'))
steps_per_epoch = math.ceil(num_train_imgs / batch_size)
num_val_imgs = len(os.listdir('/home/jayakumar/road-extraction-main/data3/val_images/val/'))
validation_steps = math.ceil(num_val_imgs / batch_size)
# NOTE(review): with 5 epochs, EarlyStopping (patience=10) can never trigger.
epochs = 5
with tf.device("/GPU:0"):
    history = model.fit(
        train_generator,
        validation_data=val_generator,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        epochs=epochs,
        callbacks=[checkpoint, early_stopping, lr_callback],
    )

In [None]:
import os
import pickle
import matplotlib.pyplot as plt

# Where the per-epoch metrics are stored. Change the filename for every run,
# otherwise the previous history file is overwritten.
folder_path = 'patchhistory'
filename = 'fkanunetadam32d24f6f6rsmitpatchba.pkl'
file_path = os.path.join(folder_path, filename)
os.makedirs(folder_path, exist_ok=True)

# Persist the metrics dict, then read it back (the same path works in a later session).
with open(file_path, 'wb') as history_file:
    pickle.dump(history.history, history_file)
with open(file_path, 'rb') as history_file:
    loaded_history = pickle.load(history_file)

# Training vs. validation loss curves.
for metric_key, curve_label in (('loss', 'Training Loss'), ('val_loss', 'Validation Loss')):
    plt.plot(loaded_history[metric_key], label=curve_label)
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()


# Testing

In [None]:
from tensorflow.keras.optimizers.legacy import Adam
from tensorflow.keras.models import load_model
import segmentation_models as sm
from tensorflow.keras.metrics import Precision, Recall
# Reload the best checkpoint, e.g. after an interrupted session. The custom
# FJNB layer must be registered for deserialization. compile=False because the
# model is compiled explicitly below with the same settings used for training.
model = load_model(
    'patchmodels/fkanunetadam32d24f6f6rsmitpatchba.h5',
    custom_objects={'FractionalJacobiNeuralBlock': FractionalJacobiNeuralBlock},
    compile=False,
)
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9),
    loss=combined_loss,
    metrics=[
        sm.metrics.iou_score,
        sm.metrics.f1_score,
        sm.metrics.precision,
        sm.metrics.recall,
    ],
)

In [None]:
import os

import math

# Number of test images; the generator reads them from this sub-folder of test_images.
num_test_images = len(os.listdir('/home/jayakumar/road-extraction-main/data3/test_images/test'))

# Exactly one pass over the test set. ceil() includes the final partial batch.
# The previous `num_test_images // batch_size + 1` read one extra batch from the
# next pass whenever num_test_images was a multiple of batch_size.
steps = math.ceil(num_test_images / batch_size)

# evaluate() returns [loss, iou_score, f1_score, precision, recall], following the
# metric order given to compile(). The result no longer shadows the builtin `eval`.
eval_scores = model.evaluate(test_generator, steps=steps)

print('Test IoU score: {:.2f}'.format(eval_scores[1]))


# Visualization

In [None]:
import random
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os

# Take one batch from the test generator and show up to 8 random samples.
test_images, test_masks = next(test_generator)
print(len(test_images))

# A final partial batch can hold fewer than 8 images; random.sample would raise then.
num_samples = min(8, len(test_images))
random_indices = random.sample(range(len(test_images)), num_samples)
test_sample = test_images[random_indices]
ground_truth_sample = test_masks[random_indices]

# Binarize the sigmoid output at 0.5.
predictions = (model.predict(test_sample) > 0.5).astype(np.uint8)

overlay_alpha = 0.2                      # weight of the original image in the blend
overlay_color = np.array([255, 255, 0])  # yellow (RGB) for predicted road pixels

# One row per sample: original | ground truth | predicted overlay.
# squeeze=False keeps `axes` 2-D even when only one row is drawn.
fig, axes = plt.subplots(num_samples, 3, figsize=(8, 3 * num_samples), squeeze=False)

for row, (sample, pred_mask, ground_truth) in enumerate(
        zip(test_sample, predictions, ground_truth_sample)):
    image = (sample * 255).astype(np.uint8)  # back to 0-255 for display
    mask_rgb = np.repeat(pred_mask, 3, axis=2)  # 1-channel binary mask -> 3 channels

    # Paint predicted pixels yellow, then blend with the untouched image.
    painted = image * (1 - mask_rgb) + overlay_color * mask_rgb
    predicted_overlay = cv2.addWeighted(
        image, overlay_alpha, painted.astype(image.dtype), 1 - overlay_alpha, 0)

    axes[row, 0].imshow(image)
    axes[row, 0].set_title('Original')
    axes[row, 0].axis('off')

    axes[row, 1].imshow(ground_truth[:, :, 0], cmap='gray')
    axes[row, 1].set_title('Ground Truth')
    axes[row, 1].axis('off')

    axes[row, 2].imshow(predicted_overlay)
    axes[row, 2].set_title('Predicted')
    axes[row, 2].axis('off')

plt.tight_layout()
plt.show()
