## Enabling Deterministic Behaviour 

In [None]:
import os
import random

# Request deterministic TensorFlow/cuDNN kernels. These switches are plain
# environment variables, so they are set before TensorFlow is imported to
# make sure they are already in place whenever TensorFlow reads them.
os.environ['TF_DETERMINISTIC_OPS'] = '1'
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'

import numpy as np
import tensorflow as tf

# Seed every random number generator used in this notebook with the same value.
np.random.seed(65)
random.seed(65)
tf.random.set_seed(65)

# GADC-KANNet

## Model

In [None]:
import tensorflow as tf
from tfkan.layers import DenseKAN, Conv2DKAN
from tensorflow.keras.layers import Conv2D, MaxPooling2D, concatenate, Conv2DTranspose, BatchNormalization, Input, Activation
from tensorflow.keras.models import Model
import tensorflow.keras.layers as KL
import numpy as np

# Dilated Residual Activation Path (DRAP) function
def DRAP(x, filters, size):
    """Build a dilated residual activation path on top of ``x``.

    The main branch is a bias-free convolution with dilation rate 2 followed
    by batch normalization. The shortcut branch is a 1x1 convolution that
    projects ``x`` to ``filters`` channels. Both branches are summed and the
    sum is passed through ReLU.

    Args:
        x: Input feature map (Keras tensor).
        filters: Number of output channels of both branches.
        size: Kernel size of the dilated convolution.

    Returns:
        The ReLU-activated sum of the two branches.
    """
    main_branch = Conv2D(
        filters=filters,
        kernel_size=size,
        strides=1,
        dilation_rate=(2, 2),
        padding='same',
        use_bias=False,
    )(x)
    main_branch = BatchNormalization()(main_branch)

    # 1x1 projection so the shortcut has the same channel count as the main branch.
    shortcut = Conv2D(filters=filters, kernel_size=(1, 1), strides=1, padding='same')(x)

    merged = KL.add([main_branch, shortcut])
    return Activation('relu')(merged)

# Kolmogorov Arnold Network based Squeeze and Excitation (KAN-SE)
def kan_se(x, r=4):
    """Re-weight the channels of ``x`` with a KAN-based squeeze-and-excitation gate.

    The feature map is squeezed with global max pooling, passed through a
    two-layer ``DenseKAN`` bottleneck (ReLU in between, sigmoid at the end)
    and the resulting per-channel gate is multiplied back onto ``x``.

    Args:
        x: Input feature map (Keras tensor) whose last axis is the channel axis.
        r: Reduction ratio of the bottleneck; the hidden layer has
            ``channels // r`` units, but never fewer than one.

    Returns:
        ``x`` multiplied channel-wise by the gate.

    Raises:
        ValueError: If the channel dimension of ``x`` is not statically known.
    """
    input_channel = x.shape[-1]
    if input_channel is None:
        raise ValueError("kan_se requires a statically known channel dimension")

    # Integer division would give a zero-width layer when input_channel < r.
    bottleneck_units = max(input_channel // r, 1)

    gate = KL.GlobalMaxPooling2D()(x)
    # Restore the two spatial axes so the gate broadcasts over height and width.
    gate = KL.Reshape((1, 1, input_channel))(gate)
    gate = DenseKAN(bottleneck_units)(gate)
    gate = Activation('relu')(gate)
    gate = DenseKAN(input_channel)(gate)
    gate = Activation('sigmoid')(gate)
    return KL.Multiply()([x, gate])

# Kolmogorov Arnold Network based Feature Selective Fusion (KAN-FSF)
def kan_fsf(low_x, high_x):
    """Fuse a skip-connection feature map with an upsampled decoder feature map.

    The two inputs are concatenated along axis 3 and re-weighted by
    ``kan_se``. A 1x1 convolution halves the channel count, and global max
    pooling followed by a sigmoid turns the result into a per-channel gate.
    The gate scales ``low_x`` and the scaled tensor is added to ``high_x``.

    Args:
        low_x: Feature map to be gated.
        high_x: Feature map added back unchanged after gating.

    Returns:
        ``low_x * gate + high_x``.
    """
    stacked = concatenate([high_x, low_x], axis=3)
    recalibrated = kan_se(stacked)

    # Half of the concatenated channel count.
    half_channels = int(recalibrated.shape[-1] / 2)
    reduced = Conv2D(filters=half_channels, kernel_size=(1, 1), strides=1, padding='same')(recalibrated)

    pooled = KL.GlobalMaxPooling2D()(reduced)
    gate = KL.Reshape((1, 1, pooled.shape[-1]))(pooled)
    gate = Activation('sigmoid')(gate)

    gated_low = KL.Multiply()([low_x, gate])
    return KL.add([gated_low, high_x])

#Gradient Aware Directional Convolution (GADC)
class GADC(tf.keras.layers.Layer):
    """Gradient Aware Directional Convolution.

    Four convolution branches are evaluated on the input: a one-row kernel
    ("horizontal"), a one-column kernel ("vertical"), and one full-size
    kernel applied to the input rotated by 90 degrees counter-clockwise and
    by 90 degrees clockwise (the two "diagonal" branches, which share
    weights). Each branch output is multiplied by a binary per-pixel mask
    derived from the Sobel gradient direction of the input, and the four
    masked outputs are summed.

    The horizontal and vertical masks are complementary, so every pixel
    receives exactly one of those two branches; the diagonal masks are
    computed independently and may add a further contribution.

    NOTE(review): the masks have the channel count of the input while the
    branch outputs have ``filters`` channels, so the element-wise products
    only line up when the two are broadcast-compatible and the convolutions
    preserve the spatial size. In this file the layer is always called with
    stride 1, 'same' padding and an input that already has ``filters``
    channels -- confirm before using other configurations.

    Args:
        filters: Number of output channels of every branch.
        kernel_size: ``(rows, cols)`` of the full-size kernel; the horizontal
            branch uses ``(1, cols)`` and the vertical branch ``(rows, 1)``.
        strides: Strides passed to every convolution.
        padding: Padding mode passed to every convolution.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, filters, kernel_size=(3, 3), strides=(1, 1), padding='same', **kwargs):
        super(GADC, self).__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding

        # One-row and one-column kernels for the two axis-aligned branches.
        self.horizontal_conv = Conv2D(filters, kernel_size=(1, kernel_size[1]), strides=strides, padding=padding)
        self.vertical_conv = Conv2D(filters, kernel_size=(kernel_size[0], 1), strides=strides, padding=padding)
        # A single full-size kernel shared by both rotated branches.
        self.diagonal_conv = Conv2D(filters, kernel_size=kernel_size, strides=strides, padding=padding)

    @staticmethod
    def _angle_in(angle, lower, upper):
        """Return a boolean tensor that is True where lower <= angle <= upper."""
        return tf.logical_and(angle >= lower, angle <= upper)

    def call(self, inputs):
        """Apply the four masked directional branches to ``inputs`` and sum them."""
        # Run the Sobel filter once and split its two components.
        edges = tf.image.sobel_edges(inputs)
        component_0 = edges[..., 0]
        component_1 = edges[..., 1]
        # NOTE(review): the TensorFlow documentation describes the last axis
        # of sobel_edges as [dy, dx], so component_0 would be the vertical and
        # component_1 the horizontal derivative. The argument order below is
        # kept exactly as it was so the layer's output does not change --
        # confirm that this is the intended angle convention.
        gradient_direction = tf.atan2(component_1, component_0)

        # Axis-aligned branches.
        horizontal_output = self.horizontal_conv(inputs)
        vertical_output = self.vertical_conv(inputs)

        # Rotated branches: rotate by a quarter turn, convolve with the shared
        # kernel, then rotate back so the result is aligned with the input.
        rotated_45 = tf.image.rot90(inputs, k=1)  # 90 degrees counter-clockwise
        diagonal_45_output = self.diagonal_conv(rotated_45)
        diagonal_45_output = tf.image.rot90(diagonal_45_output, k=3)  # back: 90 degrees clockwise

        rotated_135 = tf.image.rot90(inputs, k=3)  # 90 degrees clockwise
        diagonal_135_output = self.diagonal_conv(rotated_135)
        diagonal_135_output = tf.image.rot90(diagonal_135_output, k=1)  # back: 90 degrees counter-clockwise

        in_range = self._angle_in

        # Horizontal mask: angles near 0 or near +/- pi.
        weights_horizontal = tf.cast(
            tf.logical_or(
                tf.logical_or(
                    in_range(gradient_direction, -0.5, 0.5),
                    in_range(gradient_direction, 2.64, np.pi)
                ),
                in_range(gradient_direction, -np.pi, -2.64)
            ),
            dtype=tf.float32
        )

        # Vertical mask: everything the horizontal mask does not cover.
        weights_vertical = 1.0 - weights_horizontal

        # Mask for the first rotated branch.
        weights_diagonal_45 = tf.cast(
            tf.logical_or(
                in_range(gradient_direction, 0.2854, 1.2854),
                in_range(gradient_direction, -2.8554, -1.8554)
            ),
            dtype=tf.float32
        )

        # Mask for the second rotated branch.
        weights_diagonal_135 = tf.cast(
            tf.logical_or(
                in_range(gradient_direction, 1.8554, 2.8554),
                in_range(gradient_direction, -1.2854, -0.2854)
            ),
            dtype=tf.float32
        )

        # Sum of the branch outputs, each gated by its direction mask.
        output = (
            horizontal_output * weights_horizontal +
            vertical_output * weights_vertical +
            diagonal_45_output * weights_diagonal_45 +
            diagonal_135_output * weights_diagonal_135
        )

        return output

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created on load."""
        config = super(GADC, self).get_config()
        config.update({
            "filters": self.filters,
            "kernel_size": self.kernel_size,
            "strides": self.strides,
            "padding": self.padding,
        })
        return config

# U-Net model with DRAP (dilated residual activation path) in the skip connections
def base_unet_with_snake_kan_fsf_ar_path(filters, output_channels, width=None, height=None, input_channels=1, conv_layers=3):
    """Build the three-level encoder/decoder segmentation network.

    Every convolution block is Conv-BN-ReLU, then GADC-BN-ReLU, then
    ``conv_layers - 2`` further Conv-BN-ReLU stages (none when
    ``conv_layers`` is 2 or less). The encoder runs three such blocks, each
    followed by 2x2 max pooling, and a fourth block forms the bottleneck.
    Each decoder stage upsamples with a strided transposed convolution,
    passes the matching encoder output through ``DRAP``, fuses the two with
    ``kan_fsf`` and applies another convolution block. A 1x1 sigmoid
    convolution produces the output.

    Args:
        filters: Channel count of the first level; deeper levels use 2x, 4x and 8x.
        output_channels: Number of channels of the sigmoid output layer.
        width: First spatial dimension passed to ``Input``.
        height: Second spatial dimension passed to ``Input``.
        input_channels: Number of channels of the input image.
        conv_layers: Controls how many stages each convolution block has.

    Returns:
        The assembled, uncompiled Keras ``Model``.
    """
    def conv_bn_relu(tensor, n_filters):
        # Plain 3x3 convolution stage.
        tensor = Conv2D(n_filters, kernel_size=(3, 3), strides=(1, 1), padding='same')(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation('relu')(tensor)

    def conv_block(tensor, n_filters):
        # Standard stage first, then the gradient-aware directional stage.
        tensor = conv_bn_relu(tensor, n_filters)

        tensor = GADC(n_filters, kernel_size=(3, 3), strides=(1, 1), padding='same')(tensor)
        tensor = BatchNormalization()(tensor)
        tensor = Activation('relu')(tensor)

        # Optional extra standard stages beyond the first two.
        for _ in range(conv_layers - 2):
            tensor = conv_bn_relu(tensor, n_filters)
        return tensor

    def upsample(tensor, n_filters):
        # Transposed convolution that doubles the spatial resolution.
        tensor = Conv2DTranspose(n_filters, 2, strides=(2, 2), padding='same')(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation('relu')(tensor)

    def decoder_stage(deep, skip, n_filters):
        # Upsample the deeper features, refine the encoder skip with DRAP,
        # fuse both and run a convolution block on the result.
        upsampled = upsample(deep, n_filters)
        refined_skip = DRAP(skip, n_filters, (3, 3))
        fused = kan_fsf(refined_skip, upsampled)
        return conv_block(fused, n_filters)

    inputs = Input(shape=(width, height, input_channels))

    # Encoder.
    enc1 = conv_block(inputs, filters)
    down1 = MaxPooling2D((2, 2))(enc1)

    enc2 = conv_block(down1, filters * 2)
    down2 = MaxPooling2D((2, 2))(enc2)

    enc3 = conv_block(down2, filters * 4)
    down3 = MaxPooling2D((2, 2))(enc3)

    # Bottleneck.
    bottleneck = conv_block(down3, filters * 8)

    # Decoder, deepest skip connection first.
    dec3 = decoder_stage(bottleneck, enc3, filters * 4)
    dec2 = decoder_stage(dec3, enc2, filters * 2)
    dec1 = decoder_stage(dec2, enc1, filters)

    outputs = Conv2D(output_channels, kernel_size=(1, 1), strides=(1, 1), activation='sigmoid')(dec1)

    return Model(inputs=inputs, outputs=outputs)

# Example usage:
# Base channel count of the first encoder level.
filters = 16
# One output channel: binary segmentation.
output_channels = 1
# Example input size.
width, height = 1024, 1024

# Three-channel input, two stages per convolution block.
model = base_unet_with_snake_kan_fsf_ar_path(
    filters,
    output_channels,
    width=width,
    height=height,
    input_channels=3,
    conv_layers=2,
)
model.summary()

## Importing necessary packages

In [None]:
import os
import glob
import tensorflow as tf
from keras.callbacks import LearningRateScheduler
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import load_model
from tfkan.layers import DenseKAN
from tensorflow.keras.metrics import Precision, Recall
import segmentation_models as sm 
import random
import numpy as np
import cv2
import matplotlib.pyplot as plt

## Preprocessing

In [None]:
# Resolution every image and mask is resized to before being fed to the model.
desired_width, desired_height = 1024, 1024

In [None]:
# Load every training image in sorted order, resize it to the working
# resolution and divide the pixel values by 255.
train_images = []

for directory_path in glob.glob("massachusetts-roads-dataset/tiff/train/"):
    for img_path in sorted(glob.glob(os.path.join(directory_path, "*.tiff"))):
        img = cv2.imread(img_path, 1)  # flag 1: load as a 3-channel colour image
        if img is None:
            # cv2.imread reports an unreadable file by returning None; fail
            # here with the file name instead of inside cv2.resize.
            raise OSError("Could not read image: {}".format(img_path))
        img = cv2.resize(img, (desired_width, desired_height))
        img = img / 255.0
        train_images.append(img)

train_images = np.array(train_images)
print(train_images.shape)

In [None]:
# Load every training label mask in sorted order, resize it to the working
# resolution and binarize it to 0.0 / 1.0.
train_masks = []
for directory_path in glob.glob("massachusetts-roads-dataset/tiff/train_labels/"):
    for mask_path in sorted(glob.glob(os.path.join(directory_path, "*.tif"))):
        mask = cv2.imread(mask_path, 0)  # flag 0: load as single-channel grayscale
        if mask is None:
            # cv2.imread reports an unreadable file by returning None; fail
            # here with the file name instead of inside cv2.resize.
            raise OSError("Could not read mask: {}".format(mask_path))
        # NOTE(review): cv2.resize interpolates bilinearly by default, so the
        # "> 0" threshold below also marks pixels only partly covered by a
        # label -- confirm this is intended (cv2.INTER_NEAREST keeps labels crisp).
        mask = cv2.resize(mask, (desired_width, desired_height))
        # Every non-zero pixel counts as foreground.
        mask = (mask > 0).astype(np.float64)
        train_masks.append(mask)

train_masks = np.array(train_masks)
print(train_masks.shape)

In [None]:
# Load every validation image in sorted order, resize it to the working
# resolution and divide the pixel values by 255.
val_images = []
for directory_path in glob.glob("massachusetts-roads-dataset/tiff/val/"):
    for img_path in sorted(glob.glob(os.path.join(directory_path, "*.tiff"))):
        img = cv2.imread(img_path, 1)  # flag 1: load as a 3-channel colour image
        if img is None:
            # cv2.imread reports an unreadable file by returning None; fail
            # here with the file name instead of inside cv2.resize.
            raise OSError("Could not read image: {}".format(img_path))
        img = cv2.resize(img, (desired_width, desired_height))
        img = img / 255.0
        val_images.append(img)

val_images = np.array(val_images)
print(val_images.shape)

In [None]:
# Load every validation label mask in sorted order, resize it to the working
# resolution and binarize it to 0.0 / 1.0.
val_masks = []
for directory_path in glob.glob("massachusetts-roads-dataset/tiff/val_labels/"):
    for mask_path in sorted(glob.glob(os.path.join(directory_path, "*.tif"))):
        mask = cv2.imread(mask_path, 0)  # flag 0: load as single-channel grayscale
        if mask is None:
            # cv2.imread reports an unreadable file by returning None; fail
            # here with the file name instead of inside cv2.resize.
            raise OSError("Could not read mask: {}".format(mask_path))
        # NOTE(review): cv2.resize interpolates bilinearly by default, so the
        # "> 0" threshold below also marks pixels only partly covered by a
        # label -- confirm this is intended (cv2.INTER_NEAREST keeps labels crisp).
        mask = cv2.resize(mask, (desired_width, desired_height))
        # Every non-zero pixel counts as foreground.
        mask = (mask > 0).astype(np.float64)
        val_masks.append(mask)

val_masks = np.array(val_masks)
print(val_masks.shape)

In [None]:
# Load every test image in sorted order, resize it to the working resolution
# and divide the pixel values by 255.
test_images = []
for directory_path in glob.glob("massachusetts-roads-dataset/tiff/test/"):
    for img_path in sorted(glob.glob(os.path.join(directory_path, "*.tiff"))):
        img = cv2.imread(img_path, 1)  # flag 1: load as a 3-channel colour image
        if img is None:
            # cv2.imread reports an unreadable file by returning None; fail
            # here with the file name instead of inside cv2.resize.
            raise OSError("Could not read image: {}".format(img_path))
        img = cv2.resize(img, (desired_width, desired_height))
        img = img / 255.0
        test_images.append(img)

test_images = np.array(test_images)
print(test_images.shape)

In [None]:
# Load every test label mask in sorted order, resize it to the working
# resolution and binarize it to 0.0 / 1.0.
test_masks = []
for directory_path in glob.glob("massachusetts-roads-dataset/tiff/test_labels/"):
    for mask_path in sorted(glob.glob(os.path.join(directory_path, "*.tif"))):
        mask = cv2.imread(mask_path, 0)  # flag 0: load as single-channel grayscale
        if mask is None:
            # cv2.imread reports an unreadable file by returning None; fail
            # here with the file name instead of inside cv2.resize.
            raise OSError("Could not read mask: {}".format(mask_path))
        # NOTE(review): cv2.resize interpolates bilinearly by default, so the
        # "> 0" threshold below also marks pixels only partly covered by a
        # label -- confirm this is intended (cv2.INTER_NEAREST keeps labels crisp).
        mask = cv2.resize(mask, (desired_width, desired_height))
        # Every non-zero pixel counts as foreground.
        mask = (mask > 0).astype(np.float64)
        test_masks.append(mask)

test_masks = np.array(test_masks)
print(test_masks.shape)

## Training

In [None]:
# Create the checkpoint directory up front so that the first save does not
# fail because 'models/' is missing.
os.makedirs('models', exist_ok=True)

# Keep only the best model seen so far on disk.
checkpoint = tf.keras.callbacks.ModelCheckpoint('models/improvedkanunet.h5', verbose=1, save_best_only=True)
# Stop when the validation loss has not improved for 10 epochs and roll the
# model back to its best weights.
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

In [None]:
def lr_scheduler(epoch, lr):
    """Lower the learning rate by a fixed step, never going below zero.

    Args:
        epoch: Index of the epoch about to start (unused; the decay is the
            same for every epoch).
        lr: Learning rate currently in use.

    Returns:
        ``lr`` reduced by ``1e-6``, clamped at ``0.0`` so that a long run
        cannot end up with a negative learning rate.
    """
    decay_rate = 1e-6
    return max(lr - decay_rate, 0.0)
# Wrap the schedule function in a Keras callback so it can be passed to model.fit.
lr_callback=LearningRateScheduler(lr_scheduler)

In [None]:
# Report whether TensorFlow can see at least one GPU, using the TF2 device API.
physical_devices = tf.config.list_physical_devices('GPU')
print('GPU is available' if physical_devices else 'Not available')

In [None]:
# Train with Dice loss; track IoU, F1, precision and recall.
model.compile(
    optimizer=Adam(),
    loss=sm.losses.dice_loss,
    metrics=[
        sm.metrics.iou_score,
        sm.metrics.f1_score,
        sm.metrics.precision,
        sm.metrics.recall,
    ],
)

In [None]:
# Upper bound on the number of epochs; early stopping may end training sooner.
epochs = 100

# Run the training loop on the first GPU.
with tf.device("/GPU:0"):
    history = model.fit(
        train_images,
        train_masks,
        batch_size=4,
        epochs=epochs,
        validation_data=(val_images, val_masks),
        callbacks=[checkpoint, early_stopping, lr_callback],
    )

## Testing

In [None]:
 custom_objects = {'GADC': GADC}

model = load_model('models/improvedkanunet.h5', compile=False, custom_objects=custom_objects)

# Compile the model with the appropriate loss function and metrics
model.compile(optimizer=Adam(),
              loss=sm.losses.dice_loss,
              metrics=[sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5), 
                       'accuracy', Precision(), Recall()])

In [None]:
# model.evaluate returns the loss followed by the metrics in the order they
# were passed to model.compile: IoU, F-score, accuracy, precision, recall.
eval_results = model.evaluate(test_images, test_masks)
# Index 0 is the loss and index 1 the IoU score; accuracy sits at index 3.
print('Test accuracy: ' + "{:.2f}".format(eval_results[3]))

## Visualize 10 images in test set

In [None]:
# Draw up to 10 distinct test images at random (fewer if the test set is smaller).
num_samples = min(10, len(test_images))
if num_samples == 0:
    raise ValueError("The test set is empty; there is nothing to visualize")

random_indices = random.sample(range(0, len(test_images)), num_samples)
test_sample = test_images[random_indices]
# Select the matching ground-truth masks once instead of on every iteration.
sample_masks = test_masks[random_indices]

predictions = model.predict(test_sample)
predictions = (predictions > 0.5).astype(np.uint8)

# squeeze=False keeps `axes` two-dimensional even for a single row.
fig, axes = plt.subplots(num_samples, 3, figsize=(10, 3 * num_samples), squeeze=False)

# One row per sample: original image, ground truth, thresholded prediction.
for row_axes, sample, ground_truth, prediction in zip(axes, test_sample, sample_masks, predictions):

    image = (sample * 255).astype(np.uint8)

    # Repeat the single prediction channel three times and scale it to 0/255
    # so the predicted foreground is drawn in white.
    mask = np.repeat(prediction, 3, axis=2)
    predicted_mask = np.array([255, 255, 255]) * mask

    # The images were read with cv2.imread, which stores channels as BGR;
    # reverse the channel axis because matplotlib expects RGB.
    row_axes[0].imshow(image[..., ::-1])
    row_axes[0].set_title('Original')
    row_axes[0].axis('off')

    row_axes[1].imshow(ground_truth)
    row_axes[1].set_title('Ground Truth')
    row_axes[1].axis('off')

    row_axes[2].imshow(predicted_mask)
    row_axes[2].set_title('Predicted')
    row_axes[2].axis('off')

# Adjust the spacing between subplots
plt.tight_layout()

# Show the plot
plt.show()

## Visualize all images in test set

In [None]:
# Visualize the test set in order, capped at the first 49 images.
test_sample = test_images[:49]
num_samples = len(test_sample)
if num_samples == 0:
    raise ValueError("The test set is empty; there is nothing to visualize")

predictions = model.predict(test_sample)
predictions = (predictions > 0.5).astype(np.uint8)

# One row per sample, sized to the number of images actually available;
# squeeze=False keeps `axes` two-dimensional even for a single row.
fig, axes = plt.subplots(num_samples, 3, figsize=(15, num_samples * 3), squeeze=False)

# One row per sample: original image, ground truth, thresholded prediction.
for row_axes, sample, ground_truth, prediction in zip(axes, test_sample, test_masks, predictions):

    image = (sample * 255).astype(np.uint8)

    # Repeat the single prediction channel three times and scale it to 0/255
    # so the predicted foreground is drawn in white.
    mask = np.repeat(prediction, 3, axis=2)
    predicted_mask = np.array([255, 255, 255]) * mask

    # The images were read with cv2.imread, which stores channels as BGR;
    # reverse the channel axis because matplotlib expects RGB.
    row_axes[0].imshow(image[..., ::-1])
    row_axes[0].set_title('Original')
    row_axes[0].axis('off')

    row_axes[1].imshow(ground_truth)
    row_axes[1].set_title('Ground Truth')
    row_axes[1].axis('off')

    row_axes[2].imshow(predicted_mask)
    row_axes[2].set_title('Predicted')
    row_axes[2].axis('off')

# Adjust the spacing between subplots
plt.tight_layout()

# Show the plot
plt.show()
