In [None]:
!pip install -U tensorflow==2.15.0
import tensorflow as tf
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os


print("TensorFlow version:", tf.__version__)

In [None]:
import os
import subprocess
import nibabel as nib
import tensorflow as tf
from skimage import exposure
from math import ceil
import matplotlib.pyplot as plt
import random

class DataLoader:
    """Pair motion-free and motion-corrupted T1w volumes and serve slice triplets.

    Subjects are the directories under ``data_path`` whose names start with
    ``sub``. For every subject, the ``acq-standard`` volume is paired with each
    ``acq-headmotion{1,2}`` volume, and slices are yielded as
    ``((motion_before, motion, motion_after), free)``. Each element is declared
    as a (256, 256, 1) float32 tensor in the dataset's output signature.
    """

    # Motion-corrupted acquisitions per subject: acq-headmotion1 and acq-headmotion2.
    MOTION_VOLUMES_PER_SUBJECT = 2
    # NOTE(review): slices counted per volume for the `size` bookkeeping only; the
    # number actually yielded by `generator` depends on the volume depth — confirm.
    SLICES_PER_VOLUME = 190
    # NOTE(review): subjects subtracted from the training size; the reason is not
    # visible here (perhaps subjects with missing volumes) — confirm.
    TRAIN_SUBJECTS_EXCLUDED = 8

    def __init__(self, data_path, split_ratio, batch_size=None):
        self.data_path = data_path
        self.split_ratio = split_ratio
        self.all_subjects = None
        self.subjects_lists = []
        # Index of each split inside `subjects_lists`; follows the order of `split_ratio`.
        self.labels = {'train': 0, 'test': 1, 'validation': 2}
        self.size = [0, 0, 0]
        self.batch_size = batch_size
        self.slices_number = None

    def list_subjects(self):
        """Collect the subject directories (names starting with 'sub') in sorted order.

        `os.listdir` returns entries in arbitrary, filesystem-dependent order;
        sorting keeps the train/test/validation split reproducible.
        """
        self.all_subjects = sorted(
            name for name in os.listdir(self.data_path) if name.startswith('sub')
        )

    def get_nifti_path(self, subject, number_of_motion='1'):
        """Return [motion-free path, motion-corrupted path] for one subject/acquisition."""
        anat_dir = f'{self.data_path}/{subject}/anat'
        ref_path_stand = f'{anat_dir}/{subject}_acq-standard_T1w.nii'
        ref_path_motion = f'{anat_dir}/{subject}_acq-headmotion{number_of_motion}_T1w.nii'
        return [ref_path_stand, ref_path_motion]

    def get_paired_volumes(self, path):
        """Load, crop and rescale a (free, motion) volume pair.

        Args:
            path: [motion-free path, motion-corrupted path], as returned by `get_nifti_path`.

        Returns:
            (free, motion) float32 tensors. Each has 37 entries removed at both ends
            of axis 0 and is rescaled to [-1, 1]. Returns (None, None) if either
            file is missing.
        """
        if not (os.path.exists(path[0]) and os.path.exists(path[1])):
            return None, None

        volumes = []
        for volume_path in (path[0], path[1]):
            data = nib.load(volume_path).get_fdata()
            data = exposure.rescale_intensity(data[37:-37], out_range=(-1.0, 1.0))
            # get_fdata() yields float64; cast so the yielded elements match the
            # float32 output_signature declared in `generator` (also halves memory).
            volumes.append(tf.convert_to_tensor(data.astype('float32')))
        return volumes[0], volumes[1]

    def split_data(self):
        """Partition the subjects into contiguous train/test/validation groups.

        `split_ratio` holds up to three fractions (train, test, validation order,
        matching `labels`) that must sum to 1. The caller's list is not modified,
        and calling this method again rebuilds the split instead of appending to it.
        Also fills `size` with the per-split sample bookkeeping.

        Raises:
            ValueError: if `split_ratio` has more than three entries or does not sum to 1.
        """
        self.list_subjects()
        ratios = list(self.split_ratio)
        if len(ratios) > 3 or abs(sum(ratios) - 1.0) > 1e-6:
            raise ValueError(
                f"split_ratio must have at most 3 entries summing to 1, got {self.split_ratio}"
            )

        number_of_subjects = len(self.all_subjects)
        samples_per_subject = self.MOTION_VOLUMES_PER_SUBJECT * self.SLICES_PER_VOLUME
        self.subjects_lists = []
        self.size = [0, 0, 0]

        cumulative = 0
        for split_id, ratio in enumerate(ratios):
            start = int(round(cumulative * number_of_subjects))
            cumulative += ratio
            end = int(round(cumulative * number_of_subjects))
            subjects = self.all_subjects[start:end]
            self.subjects_lists.append(subjects)

            self.size[split_id] = len(subjects) * samples_per_subject
            if split_id == 0:
                self.size[split_id] -= self.TRAIN_SUBJECTS_EXCLUDED * samples_per_subject

    def generator(self, mode):
        """Build a batched tf.data pipeline for one split.

        Args:
            mode: 'train', 'test' or 'validation' (see `labels`).

        Returns:
            A `tf.data.Dataset` of ((motion_before, motion, motion_after), free)
            elements, batched by `batch_size`. As a side effect, iterating it
            sets `slices_number` to the depth of the last loaded volume.
        """
        subjects = self.subjects_lists[self.labels[mode]]

        def to_hwc(slab):
            # (1, d1, d2) slab taken along axis 0 -> (d1, d2, 1) image.
            return tf.transpose(slab, perm=[1, 2, 0])

        def data_gen():
            for subject in subjects:
                for motion_id in range(self.MOTION_VOLUMES_PER_SUBJECT):
                    paths = self.get_nifti_path(subject, str(motion_id + 1))
                    free, motion = self.get_paired_volumes(paths)
                    if free is None or motion is None:
                        continue
                    self.slices_number = motion.shape[0]
                    # Centre slices 1 .. slices_number - 3 (inclusive), each paired
                    # with its neighbours on both sides.
                    # NOTE(review): centre slices_number - 2 also has both neighbours
                    # but is skipped; kept as-is to preserve sample counts.
                    for idx in range(1, self.slices_number - 2):
                        yield (
                            (to_hwc(motion[idx - 1:idx]),
                             to_hwc(motion[idx:idx + 1]),
                             to_hwc(motion[idx + 1:idx + 2])),
                            to_hwc(free[idx:idx + 1]),
                        )

        image_spec = tf.TensorSpec(shape=(256, 256, 1), dtype=tf.float32)
        output_signature = ((image_spec, image_spec, image_spec), image_spec)

        dataset = tf.data.Dataset.from_generator(data_gen, output_signature=output_signature)
        return dataset.batch(self.batch_size)

In [None]:
import keras
import tensorflow as tf
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras import layers
import tensorflow.keras.backend as K
import numpy as np
import pywt

IMAGE_ORDERING_CHANNELS_LAST = "channels_last"
IMAGE_ORDERING_CHANNELS_FIRST = "channels_first"

# Layout passed to the Conv2D / UpSampling2D layers below (channels_last by default).
IMAGE_ORDERING = IMAGE_ORDERING_CHANNELS_LAST

# Channel axis of the chosen layout, used when concatenating feature maps.
_MERGE_AXIS_BY_ORDERING = {
    IMAGE_ORDERING_CHANNELS_FIRST: 1,
    IMAGE_ORDERING_CHANNELS_LAST: -1,
}
if IMAGE_ORDERING in _MERGE_AXIS_BY_ORDERING:
    MERGE_AXIS = _MERGE_AXIS_BY_ORDERING[IMAGE_ORDERING]
#Define Adaptive normalization layer
def expand_moments_dim(moment):
    """Reshape per-channel values to (-1, 1, 1, channels) so they broadcast over axes 1 and 2."""
    num_channels = tf.shape(moment)[-1]
    target_shape = [-1, 1, 1, num_channels]
    return tf.reshape(moment, target_shape)
class AdaptiveInstanceNorm(Layer):
    """Adaptive instance normalization (AdaIN).

    Normalizes `content` per sample and per channel over axes 1 and 2, then
    applies the supplied per-channel scale (`gamma`) and shift (`beta`).
    """

    def __init__(self, epsilon=1e-5, **kwargs):
        super(AdaptiveInstanceNorm, self).__init__(**kwargs)
        # Added to the variance before the square root to avoid division by zero.
        self.epsilon = epsilon

    def call(self, content, gamma, beta):
        """
        Args:
            content: feature map; mean/variance are taken over axes 1 and 2
                (assumes channels_last (batch, H, W, C) — confirm for other layouts).
            gamma: per-channel scale, reshaped to (-1, 1, 1, C).
            beta: per-channel shift, reshaped to (-1, 1, 1, C).

        Returns:
            gamma * normalized(content) + beta, same shape as `content`.
        """
        c_mean, c_var = tf.nn.moments(content, axes=[1, 2], keepdims=True)
        normalized = (content - c_mean) / tf.sqrt(c_var + self.epsilon)
        gamma = expand_moments_dim(gamma)
        beta = expand_moments_dim(beta)
        # Plain broadcasting; the functional `multiply` helper would build a new
        # Keras Multiply layer on every call.
        return gamma * normalized + beta

    def get_config(self):
        """Return the layer config including `epsilon` so the layer can be re-created."""
        config = super(AdaptiveInstanceNorm, self).get_config()
        config.update({"epsilon": self.epsilon})
        return config

# Define the Haar filters
def haar_filter():
    """Return the Haar analysis filters as float32 tensors.

    Returns:
        (low_pass, high_pass): [1, 1] / sqrt(2) and [1, -1] / sqrt(2).
    """
    scale = tf.math.sqrt(2.0)
    low_pass = tf.constant([1, 1], dtype=tf.float32) / scale
    high_pass = tf.constant([1, -1], dtype=tf.float32) / scale
    return low_pass, high_pass

# Function to apply 1D filter across the width dimension
def apply_filter_1d(inputs, filter_kernel, stride):
    """
    Applies a 1D filter across the width dimension, independently per channel.

    Args:
        inputs: NHWC tensor with a statically known channel count.
        filter_kernel: 1D filter taps (length 2 for the Haar filters).
        stride: subsampling factor along the width.

    Returns:
        NHWC tensor with the same number of channels; width becomes
        ceil(W / stride) ('SAME' padding), height is unchanged.
    """
    in_channels = inputs.shape[-1]
    taps = tf.reshape(filter_kernel, [-1])

    # Kernel layout is [filter_height, filter_width, in_channels, out_channels].
    # Put the taps on the width axis (height 1) and make the channel mixing
    # diagonal, so output channel c is the filtered input channel c only.
    kernel = taps[:, None, None] * tf.eye(in_channels, dtype=taps.dtype)  # [taps, C, C]
    kernel = tf.expand_dims(kernel, axis=0)                                # [1, taps, C, C]

    return tf.nn.conv2d(inputs, kernel, strides=[1, 1, stride, 1], padding='SAME', data_format='NHWC')

# Function to perform 1D wavelet transform
def wavelet_transform_1d(inputs):
    """One Haar analysis step (stride 2) along the axis handled by `apply_filter_1d`.

    Returns:
        (low, high): approximation and detail coefficients.
    """
    return tuple(apply_filter_1d(inputs, taps, stride=2) for taps in haar_filter())

# Function to perform 2D wavelet transform
def wavelet_transform_2d(inputs):
    """Separable 2D Haar step on an NHWC tensor.

    Applies `wavelet_transform_1d`, swaps axes 1 and 2, applies it again to
    both outputs, and swaps the axes back.

    Returns:
        (low_low, low_high, high_low, high_high) in the input's axis order.
    """
    def swap_spatial_axes(tensor):
        return tf.transpose(tensor, [0, 2, 1, 3])

    low, high = wavelet_transform_1d(inputs)
    low_low, low_high = wavelet_transform_1d(swap_spatial_axes(low))
    high_low, high_high = wavelet_transform_1d(swap_spatial_axes(high))
    return tuple(swap_spatial_axes(band) for band in (low_low, low_high, high_low, high_high))

# Function to perform multi-level wavelet transform
def multi_level_wavelet_transform(inputs, levels=3):
    """Three-level 2D Haar decomposition with sub-bands stacked along channels.

    Args:
        inputs: NHWC tensor.
        levels: decomposition depth. Only the first three levels are returned,
            so values above 3 add nothing; values below 3 are rejected.

    Returns:
        (wat1, wat2, wat3):
          wat1: level-1 detail bands (low_high, high_low, high_high), C * 3 channels.
          wat2: level-2 detail bands, C * 3 channels.
          wat3: level-3 approximation + detail bands, C * 4 channels.
        Channels are interleaved per input channel: index = c * n_bands + band.

    Raises:
        ValueError: if `levels` < 3.
    """
    if levels < 3:
        raise ValueError(f"multi_level_wavelet_transform needs levels >= 3, got {levels}")

    def merge_bands(bands):
        # n_bands x (N, H, W, C) -> (N, H, W, C, n_bands) -> (N, H, W, C * n_bands).
        stacked = tf.stack(bands, axis=-1)
        shape = tf.shape(stacked)
        merged_shape = tf.concat([shape[:-2], [shape[-2] * shape[-1]]], axis=0)
        return tf.reshape(stacked, merged_shape)

    level_outputs = []
    for level in range(3):
        low_low, low_high, high_low, high_high = wavelet_transform_2d(inputs)
        if level < 2:
            level_outputs.append(merge_bands([low_high, high_low, high_high]))
        else:
            level_outputs.append(merge_bands([low_low, low_high, high_low, high_high]))
        inputs = low_low  # next level decomposes the approximation band

    wat1, wat2, wat3 = level_outputs
    return wat1, wat2, wat3

    
def wat_3(inputs):
    """Three-level wavelet features of `inputs` with explicit static shapes.

    The sub-band maps from `multi_level_wavelet_transform` are reshaped with a
    dynamic shape, so their static dims may be unknown. They are pinned to
    (N, H/2, W/2, 3C), (N, H/4, W/4, 3C) and (N, H/8, W/8, 4C), presumably so the
    Conv2D layers that consume them can be built.
    """
    wavelet_maps = multi_level_wavelet_transform(inputs)

    dims = inputs.shape
    batch_size, height, width, channels = dims[0], dims[1], dims[2], dims[3]
    # (spatial divisor, bands per input channel) for each returned level.
    layouts = ((2, 3), (4, 3), (8, 4))
    for wavelet_map, (divisor, bands) in zip(wavelet_maps, layouts):
        wavelet_map.set_shape((batch_size, height // divisor, width // divisor, channels * bands))

    wat1, wat2, wat3 = wavelet_maps
    return wat1, wat2, wat3


def wat_layer(x, wat, level):
    """Modulate feature map `x` with wavelet features `wat` (scale-and-shift).

    Two separate branches (3x3 conv -> relu -> 3x3 conv) map `wat` to a
    multiplicative map and an additive map with 32 * 2**level channels. The
    result is relu(x * prod + sum), so `x` must have 32 * 2**level channels and
    the same spatial size as `wat`.

    Args:
        x: feature map to modulate.
        wat: wavelet feature map.
        level: sets the widths: hidden 16 * 2**level, output 32 * 2**level.
    """
    hidden_filters = 16 * (2 ** level)
    out_filters = 32 * (2 ** level)

    def branch(features):
        features = Conv2D(hidden_filters, (3, 3), data_format=IMAGE_ORDERING, padding='same')(features)
        features = Activation('relu')(features)
        return Conv2D(out_filters, (3, 3), data_format=IMAGE_ORDERING, padding='same')(features)

    watp_prod = branch(wat)
    watp_sum = branch(wat)

    x = multiply([x, watp_prod])
    x = Add()([x, watp_sum])
    return Activation('relu')(x)


def wat_layer_1(x, wat):
    """Wavelet modulation with 64 output channels (`wat_layer` level 1)."""
    return wat_layer(x, wat, 1)


def wat_layer_2(x, wat):
    """Wavelet modulation with 128 output channels (`wat_layer` level 2)."""
    return wat_layer(x, wat, 2)


def wat_layer_3(x, wat):
    """Wavelet modulation with 256 output channels (`wat_layer` level 3)."""
    return wat_layer(x, wat, 3)


# CBAM --------------------------------------------
# Convolutional Block Attention Module(CBAM) block
def cbam_block(cbam_feature, ratio=8):
    """Convolutional Block Attention Module: channel attention, then spatial attention."""
    return spatial_attention(channel_attention(cbam_feature, ratio))

def channel_attention(input_feature, ratio=8):
    """CBAM channel attention: reweight the channels of `input_feature`.

    Global average- and max-pooled descriptors pass through a shared two-layer
    Dense MLP (hidden width channel // ratio). The two outputs are summed,
    passed through a sigmoid, and used to scale each channel of the input.
    """
    channel_axis = -1 if K.image_data_format() != "channels_first" else 1
    channel = input_feature.shape[channel_axis]
    reduced = channel // ratio

    # The same two Dense layers are applied to both pooled descriptors.
    mlp_hidden = Dense(reduced,
                       activation='relu',
                       kernel_initializer='he_normal',
                       use_bias=True,
                       bias_initializer='zeros')
    mlp_out = Dense(channel,
                    kernel_initializer='he_normal',
                    use_bias=True,
                    bias_initializer='zeros')

    branches = []
    for pooling in (GlobalAveragePooling2D, GlobalMaxPooling2D):
        descriptor = Reshape((1, 1, channel))(pooling()(input_feature))
        assert descriptor.shape[1:] == (1, 1, channel)
        descriptor = mlp_hidden(descriptor)
        assert descriptor.shape[1:] == (1, 1, reduced)
        descriptor = mlp_out(descriptor)
        assert descriptor.shape[1:] == (1, 1, channel)
        branches.append(descriptor)

    attention = Activation('sigmoid')(Add()(branches))

    if K.image_data_format() == "channels_first":
        attention = Permute((3, 1, 2))(attention)

    return multiply([input_feature, attention])

def spatial_attention(input_feature):
    """CBAM spatial attention: reweight each spatial position of `input_feature`.

    Channel-wise mean and max maps are concatenated and passed through a 7x7
    sigmoid Conv2D. The resulting single-channel map multiplies the input.
    """
    kernel_size = 7
    channels_first = K.image_data_format() == "channels_first"

    # The pooling, concat and conv below all work on axis 3, so channels_first
    # input is permuted to channels_last here and the map is permuted back later.
    cbam_feature = Permute((2, 3, 1))(input_feature) if channels_first else input_feature

    avg_pool = Lambda(lambda x: K.mean(x, axis=3, keepdims=True))(cbam_feature)
    assert avg_pool.shape[-1] == 1
    max_pool = Lambda(lambda x: K.max(x, axis=3, keepdims=True))(cbam_feature)
    assert max_pool.shape[-1] == 1
    concat = Concatenate(axis=3)([avg_pool, max_pool])
    assert concat.shape[-1] == 2
    # data_format is pinned: the tensor is channels_last at this point even when
    # the global Keras setting is channels_first.
    cbam_feature = Conv2D(filters=1,
                          kernel_size=kernel_size,
                          strides=1,
                          padding='same',
                          activation='sigmoid',
                          kernel_initializer='he_normal',
                          use_bias=False,
                          data_format='channels_last')(concat)
    assert cbam_feature.shape[-1] == 1

    if channels_first:
        cbam_feature = Permute((3, 1, 2))(cbam_feature)

    return multiply([input_feature, cbam_feature])
    
def _conv_bn_lrelu(x, filters):
    """3x3 'same' Conv2D -> BatchNormalization -> leaky ReLU."""
    x = Conv2D(filters, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=1)(x)
    x = BatchNormalization()(x)
    return Activation(tf.nn.leaky_relu)(x)


def _double_conv(x, filters):
    """Two stacked `_conv_bn_lrelu` units with the same filter count."""
    return _conv_bn_lrelu(_conv_bn_lrelu(x, filters), filters)


def _adain_res_block(x, filters, gamma, beta):
    """Residual block: relu Conv2D -> AdaIN(gamma, beta) -> Conv2D, plus identity skip."""
    y = Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
    y = AdaptiveInstanceNorm()(y, gamma, beta)
    y = Conv2D(filters, (3, 3), padding='same')(y)
    return y + x


def UNet(img_input, norm_list):
    """Wavelet-guided U-Net with CBAM attention and an AdaIN residual bottleneck.

    Args:
        img_input: NHWC feature map; H and W must be divisible by 8 (three 2x
            poolings must line up with the 3-level wavelet features).
        norm_list: (batch, 1, 1, >= 6 * 256) tensor whose channel slices supply the
            AdaIN gamma/beta of the four residual blocks.

    Returns:
        (batch, H, W, 1) prediction in [-1, 1] (tanh).
    """
    k1, k2, k3, k4 = 32, 64, 128, 256

    watp1, watp2, watp3 = wat_3(img_input)

    # Contracting path. Block 2, block 3 and the transition layer are modulated
    # by the wavelet features of matching resolution.
    conv1 = cbam_block(_double_conv(img_input, k1))
    o = AveragePooling2D((2, 2), strides=(2, 2))(conv1)

    conv2 = _double_conv(o, k2)
    conv2 = cbam_block(wat_layer_1(conv2, watp1))
    o = AveragePooling2D((2, 2), strides=(2, 2))(conv2)

    conv3 = _double_conv(o, k3)
    conv3 = cbam_block(wat_layer_2(conv3, watp2))
    o = AveragePooling2D((2, 2), strides=(2, 2))(conv3)

    # Transition layer between contracting and expansive paths.
    conv4 = _double_conv(o, k4)
    conv4 = cbam_block(wat_layer_3(conv4, watp3))

    # Four AdaIN residual blocks. Each entry is ((gamma_start, gamma_end),
    # (beta_start, beta_end)) in norm_list's channel axis.
    # NOTE(review): blocks 3 and 4 use the same slice for gamma and beta, since
    # norm_list only provides 6 * k4 channels for 8 slices. Kept as-is, because
    # changing it would alter the norm_list input width.
    adain_slices = (
        ((0, k4), (k4, 2 * k4)),
        ((2 * k4, 3 * k4), (3 * k4, 4 * k4)),
        ((4 * k4, 5 * k4), (4 * k4, 5 * k4)),
        ((5 * k4, 6 * k4), (5 * k4, 6 * k4)),
    )
    for (g_start, g_end), (b_start, b_end) in adain_slices:
        conv4 = _adain_res_block(conv4, k4,
                                 norm_list[:, :, :, g_start:g_end],
                                 norm_list[:, :, :, b_start:b_end])

    # Expansive path with skip connections.
    up1 = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(conv4)
    up1 = concatenate([up1, conv3], axis=MERGE_AXIS)
    deconv1 = cbam_block(_double_conv(up1, k3))

    up2 = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(deconv1)
    up2 = concatenate([up2, conv2], axis=MERGE_AXIS)
    deconv2 = cbam_block(_double_conv(up2, k2))

    up3 = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(deconv2)
    up3 = concatenate([up3, conv1], axis=MERGE_AXIS)
    deconv3 = cbam_block(_double_conv(up3, k1))

    output = Conv2D(1, (3, 3), data_format=IMAGE_ORDERING, padding='same')(deconv3)
    return Activation('tanh')(output)


def Correction_Multi_input(input_height, input_width):
    """Build the two-stage motion-correction model.

    Inputs: three (H, W, 1) slices (presumably previous / current / next, per
    the data loader) and a (1, 1, 1536) `norm_list` of AdaIN parameters.
    Each slice goes through its own 3x3 Conv-BN-ReLU stem (64 filters). The
    stems are concatenated and fed to two stacked UNets; the second UNet also
    receives the first one's prediction.

    Returns:
        A Keras Model whose output is the second UNet's (H, W, 1) prediction.
    """
    # NOTE(review): the UNets only need H, W divisible by 8; the % 32 check is stricter.
    assert input_height % 32 == 0
    assert input_width % 32 == 0

    slice_inputs = [Input(shape=(input_height, input_width, 1)) for _ in range(3)]
    norm_list = Input(shape=(1, 1, 1536))

    stem_filters = 64
    stems = []
    for slice_input in slice_inputs:
        stem = Conv2D(stem_filters, (3, 3), data_format=IMAGE_ORDERING, padding='same', dilation_rate=1)(slice_input)
        stem = BatchNormalization()(stem)
        stems.append(Activation('relu')(stem))

    input_concat = concatenate(stems, axis=MERGE_AXIS)

    # Two stacked nets.
    pred_1 = UNet(input_concat, norm_list)
    input_2 = concatenate([input_concat, pred_1], axis=MERGE_AXIS)
    pred_2 = UNet(input_2, norm_list)

    return Model(inputs=[*slice_inputs, norm_list], outputs=pred_2)
def Style_Encoder(style_dim, img_input):
    """Encode `img_input` into a (batch, 1, 1, style_dim) style code.

    Layers: a 5x5 ReLU conv (32 filters), four stride-2 2x2 ReLU convs
    (64, 128, 128, 128 filters), global average pooling with keepdims, and a
    1x1 linear projection to `style_dim`. All kernels are initialized N(0, 0.02).
    """
    initializer = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    def relu_conv(x, filters, kernel, strides):
        return Conv2D(filters, kernel, padding="same", strides=strides,
                      activation='relu', kernel_initializer=initializer)(x)

    features = relu_conv(img_input, 32, (5, 5), 1)
    for filters in (64, 128, 128, 128):
        features = relu_conv(features, filters, (2, 2), 2)

    pooled = GlobalAveragePooling2D(keepdims=True)(features)
    return Conv2D(style_dim, (1, 1), padding="valid", strides=1, kernel_initializer=initializer)(pooled)

def MLP(input_block, output_dim, dim):
    """Three biased ReLU Dense layers of widths (dim, dim, output_dim).

    NOTE(review): the final ReLU makes every output non-negative. If this MLP
    feeds AdaIN gamma/beta, negative shifts are impossible — confirm intended.
    """
    hidden = input_block
    for units in (dim, dim, output_dim):
        hidden = Dense(units, activation='relu', use_bias=True)(hidden)
    return hidden

In [None]:
import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras.models import Model

def ssim_score(y_true, y_pred):
    """Per-image SSIM with max_val=2.0 (dynamic range of data in [-1, 1]) and a 5x5, sigma 1.5 window."""
    ssim_params = dict(max_val=2.0, filter_size=5, filter_sigma=1.5, k1=0.01, k2=0.03)
    return tf.image.ssim(y_true, y_pred, **ssim_params)

def ssim_loss(y_true, y_pred):
    """Per-image SSIM loss (1 - SSIM) / 2, using the same SSIM settings as `ssim_score`."""
    return (1 - ssim_score(y_true, y_pred)) / 2

def create_center_rectangle_mask(mask_shape, rect_height, rect_width):
    """Build a float32 mask of ones with a centred rectangle of zeros.

    Args:
        mask_shape: full tensor shape including a leading batch dim, which is
            dropped. The rectangle spans dims 0 and 1 of the remaining shape.
        rect_height: rectangle size along dim 0; values larger than the mask
            zero that whole dimension.
        rect_width: rectangle size along dim 1; same clipping as `rect_height`.

    Returns:
        tf.float32 tensor of shape mask_shape[1:], which broadcasts over the batch.
    """
    mask_shape = tuple(mask_shape[1:])
    mask = np.ones(mask_shape, dtype=np.float32)

    # Clamp at 0: for a rectangle larger than the mask, the offset would be
    # negative, and a negative slice start counts from the end of the axis.
    rect_top = max((mask_shape[0] - rect_height) // 2, 0)
    rect_left = max((mask_shape[1] - rect_width) // 2, 0)

    mask[rect_top:rect_top + rect_height, rect_left:rect_left + rect_width] = 0

    return tf.convert_to_tensor(mask, dtype=tf.float32)

def crop_center_rectangle_mask(tensor, rect_height=128, rect_width=128):
    """Zero a centred rect_height x rect_width rectangle of `tensor` (dims 1 and 2).

    Uses `rect_width` for the width; it was previously ignored in favour of `rect_height`.
    """
    mask = create_center_rectangle_mask(tensor.shape, rect_height, rect_width)
    return tf.multiply(tensor, mask)

def fft_loss(y_true, y_pred, crop=False):
    """
    Custom loss function for spatial loss with FFT features.

    Args:
        y_true: Ground truth image(s), NHWC.
        y_pred: Predicted image(s), NHWC.
        crop: if True, zero a centred 128x128 rectangle of both spectra before
            comparing them.

    Returns:
        Scalar: 0.5 * (MSE of real parts + MSE of imaginary parts).

    tf.signal.fft2d transforms the two innermost axes. For NHWC input those
    are (W, C), which would give a 1D FFT along the width. The images are
    therefore moved to NCHW for the transform, producing a true H x W spectrum,
    and moved back to NHWC before masking and comparison.
    """
    def spectrum(images):
        nchw = tf.transpose(tf.cast(images, dtype=tf.complex64), perm=[0, 3, 1, 2])
        return tf.transpose(tf.signal.fft2d(nchw), perm=[0, 2, 3, 1])

    true_fft = spectrum(y_true)
    pred_fft = spectrum(y_pred)

    # (true, predicted) pairs for the real and imaginary parts.
    parts = (
        (tf.math.real(true_fft), tf.math.real(pred_fft)),
        (tf.math.imag(true_fft), tf.math.imag(pred_fft)),
    )

    part_losses = []
    for true_part, pred_part in parts:
        if crop:
            # NOTE(review): the spectrum is not fftshift-ed, so its centre holds the
            # highest frequencies; crop=True therefore drops the high-frequency band.
            # Confirm this is intended.
            true_part = crop_center_rectangle_mask(true_part)
            pred_part = crop_center_rectangle_mask(pred_part)
        part_losses.append(tf.reduce_mean(tf.keras.losses.mse(true_part, pred_part)))

    return 0.5 * (part_losses[0] + part_losses[1])

def init_vgg16_model(layer_names=('block1_conv1', 'block2_conv1', 'block3_conv1'),
                     input_shape=(256, 256, 3)):
    """
    Initialize frozen VGG16 feature extractors for the perceptual loss.

    Args:
        layer_names: VGG16 layers to expose; one sub-model is built per name.
            The default gives the three extractors used by `perceptual_loss`.
        input_shape: input shape passed to VGG16 (three channels).

    Returns:
        Tuple of Keras models, one per entry of `layer_names`, each mapping an
        image batch to that layer's activations.

    Loads VGG16 with ImageNet weights and without the top classification
    layers, builds the sub-models, and marks all their layers non-trainable.
    """
    vgg_model = VGG16(include_top=False, weights='imagenet', input_shape=input_shape)

    extractors = tuple(
        Model(inputs=vgg_model.input, outputs=vgg_model.get_layer(name).output)
        for name in layer_names
    )

    # Freeze the layers so the loss network is never trained.
    for extractor in extractors:
        for layer in extractor.layers:
            layer.trainable = False

    print("VGG16 Model Initialized")
    return extractors

# Initialize VGG16 model for feature extraction
# Build the frozen VGG16 feature extractors once at module level; `perceptual_loss`
# reads this tuple on every call instead of rebuilding the models.
perceptual_models = init_vgg16_model()


def perceptual_loss(y_true, y_pred):
    """
    Weighted VGG16 feature-space MSE between ground truth and prediction.

    Args:
        y_true: Ground truth image(s), single channel.
        y_pred: Predicted image(s), single channel.

    Returns:
        Scalar: 0.65 * MSE(block1_conv1) + 0.3 * MSE(block2_conv1) + 0.05 * MSE(block3_conv1),
        using the module-level `perceptual_models` extractors.

    The single channel is repeated to RGB and passed through VGG16's
    `preprocess_input` before feature extraction.
    """
    # NOTE(review): vgg16.preprocess_input expects RGB values in [0, 255] (it converts
    # to BGR and subtracts ImageNet channel means). The data here appears to be in
    # [-1, 1] (tanh output, rescale_intensity to (-1, 1)). Consider feeding
    # (y + 1) * 127.5, but that changes the loss scale and the calibration
    # constants used in `total_loss`.
    def to_vgg_input(images):
        rgb = tf.repeat(images, 3, axis=-1)
        return tf.keras.applications.vgg16.preprocess_input(rgb)

    true_input = to_vgg_input(y_true)
    pred_input = to_vgg_input(y_pred)

    extractor_1, extractor_2, extractor_3 = perceptual_models
    layer_weights = (0.65, 0.3, 0.05)

    weighted_terms = []
    for weight, extractor in zip(layer_weights, (extractor_1, extractor_2, extractor_3)):
        layer_mse = tf.reduce_mean(tf.keras.losses.mse(extractor(true_input), extractor(pred_input)))
        weighted_terms.append(weight * layer_mse)

    return weighted_terms[0] + weighted_terms[1] + weighted_terms[2]

def psnr(y_true, y_pred):
    """Batch-mean PSNR with max_val=2.0, the dynamic range of data normalized to [-1, 1]."""
    per_image_psnr = tf.image.psnr(y_true, y_pred, max_val=2.0)
    return tf.reduce_mean(per_image_psnr)
def l1_loss(y_true, y_pred):
    """Mean absolute error over every element."""
    abs_error = tf.abs(y_true - y_pred)
    return tf.reduce_mean(abs_error)
def l2_loss(y_true, y_pred):
    """
    Mean squared error between the ground truth and predicted tensors.

    Parameters:
        y_true (tf.Tensor): Ground truth tensor (assumed NHWC — confirm).
        y_pred (tf.Tensor): Predicted tensor, same shape as `y_true`.

    Returns:
        tf.Tensor: scalar mean squared error.

    tf.keras.losses.mean_squared_error averages over the last (channel) axis.
    The result is averaged over axes 1 and 2 (the spatial axes) per sample, and
    finally over the batch. No normalization to [0, 1] is applied.
    """
    mse = tf.keras.losses.mean_squared_error(y_true, y_pred)
    per_sample_mse = tf.reduce_mean(mse, axis=(1, 2))
    return tf.reduce_mean(per_sample_mse)
def total_loss(y_true, y_pred):
    """Training objective: average of the SSIM loss and an affinely rescaled perceptual loss.

    The perceptual term is mapped as  p * 0.05807468295097351 + 0.009354699403047562
    before averaging. NOTE(review): the constants are presumably fitted so that the
    perceptual term lives on a scale comparable to the SSIM loss — TODO confirm origin.

    Returns:
        (ssim_loss + rescaled perceptual_loss) / 2
    """
    # Perceptual term is evaluated first, then the SSIM term (same order as before).
    rescaled_perceptual = perceptual_loss(y_true, y_pred) * 0.05807468295097351 + 0.009354699403047562
    return (ssim_loss(y_true, y_pred) + rescaled_perceptual) / 2

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import CSVLogger, ModelCheckpoint, LearningRateScheduler
import math
import pandas as pd
from tensorflow.keras.models import model_from_json

# ---- Run configuration ----
# Mode flags (1 = enabled, 0 = disabled).
TRAIN, TEST = 1, 0
NB_EPOCH = 100  # NOTE(review): the training cell later re-sets NB_EPOCH = 10
LEARNING_RATE = 0.001
HEIGHT = WIDTH = 256  # slice height/width handed to the correction model
PREDICTION_PATH = '/kaggle/working/Prediction'
WEIGHTS_PATH = '/kaggle/working/Weights'

# ---- Data ----
print('Reading Data .... ')
data_path = "/kaggle/input/mmmai-regist-data/MR-ART-Regist"
# Split fractions, presumably in DataLoader.labels order (train, test, validation).
# Small debug alternative: [0.03, 0.92, 0.03]
split_ratio = [0.7, 0.2, 0.1]
batch_size = 5

data_loader = DataLoader(data_path, split_ratio, batch_size)
data_loader.split_data()

# One dataset per split, requested in the order train -> test -> validation.
train_dataset, test_dataset, validation_dataset = (
    data_loader.generator(split_name) for split_name in ('train', 'test', 'validation')
)

In [None]:
resume = False  # Start from freshly built models; note the training cell below re-assigns this flag.

In [None]:
import csv
import os
from tensorflow.keras.callbacks import ModelCheckpoint
import h5py
from tensorflow.keras.models import load_model
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# ---- Outputs written by the custom training loop ----
epoch_results_path = "/kaggle/working/epoch_results_file.csv"  # one row per epoch
step_results_path = "/kaggle/working/step_results_file.csv"    # one row every 100 training steps
model_checkpoint_path = "/kaggle/working/model_checkpoint_{epoch:02d}.h5"

# ---- Run settings (re-set here, overriding the earlier setup cell) ----
resume = False         # True: reload saved models instead of building new ones
LEARNING_RATE = 0.001  # base U-Net learning rate consumed by scheduler()
NB_EPOCH = 10          # exclusive upper bound of the epoch loop

# Keras callback that keeps only the best model by val_loss.
# NOTE(review): the training loop below is a hand-written GradientTape loop and never
# passes this callback anywhere, so as written it has no effect.
checkpoint_callback = ModelCheckpoint(
    filepath=model_checkpoint_path,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=False,
    verbose=1,
)

# Build (or reload) the two trainable networks used by the training loop:
#   full_unet  - correction network taking [Slice1, Slice2, Slice3, style codes]
#   full_style - style encoder followed by an MLP, mapping a (256, 256, 1) image to style codes
# curr_epoch is the zero-based epoch index the training loop starts from.
if resume:
    data_of_model = os.listdir('/kaggle/input/modelstacked-unet-with-style')
    data_of_model.sort()  # no effect on the result below: only the file count is used
    # NOTE(review): infers the number of completed epochs from the file count, assuming the
    # directory holds 2 unrelated files plus one U-Net and one style checkpoint per epoch.
    # Any other file layout yields a wrong epoch — TODO confirm / parse the file names instead.
    curr_epoch = int((len(data_of_model) - 2) / 2)
    # Checkpoints are saved 1-based (epoch + 1), so loading epoch_{N} and restarting the loop
    # at index N continues where training stopped.
    # NOTE(review): optimizer state is not restored (optimizers are recreated fresh), and if the
    # models contain custom layers load_model may need custom_objects — TODO confirm.
    full_unet = load_model(f'/kaggle/input/modelstacked-unet-with-style/Wav_Stacked_UNet_epoch_{curr_epoch}.h5')
    full_style = load_model(f'/kaggle/input/modelstacked-unet-with-style/full_style_epoch_{curr_epoch}.h5')
else:
    full_unet = Correction_Multi_input(HEIGHT, WIDTH)  # multi-input correction U-Net (defined in an earlier cell)
    img_free = tf.keras.Input(shape=(256, 256, 1))
    Style_output = Style_Encoder(256, img_free)  # style encoder graph built on img_free
    # Standalone encoder model; not used by the training loop in this notebook section.
    Style_model = tf.keras.Model(inputs=img_free, outputs=Style_output)

    # MLP head on top of the encoder output; meaning of 1536/2048 is defined by MLP() — TODO document.
    Mlp_output = MLP(Style_output, 1536, 2048)
    # Standalone MLP model; not used by the training loop in this notebook section.
    Mlp_model = tf.keras.Model(inputs=Style_output, outputs=Mlp_output)

    # End-to-end style network (encoder + MLP) that the training loop optimizes.
    full_style = tf.keras.Model(inputs=img_free, outputs=Mlp_output)
    curr_epoch = 0

def scheduler(epoch, base_lr=None, warmup_epochs=10, decay_rate=0.1):
    """Learning rate for `epoch`: constant during warm-up, then exponential decay.

    Args:
        epoch: zero-based epoch index (the training loop passes its loop counter).
        base_lr: starting learning rate. Defaults to the global LEARNING_RATE,
            looked up at call time (the original behaviour).
        warmup_epochs: number of initial epochs that use `base_lr` unchanged.
        decay_rate: after warm-up,
            lr = base_lr * exp(decay_rate * (warmup_epochs - epoch)),
            i.e. the rate is multiplied by exp(-decay_rate) (~0.905 for 0.1) every epoch.

    Returns:
        `base_lr` itself during warm-up, otherwise the decayed value (a NumPy float).
    """
    if base_lr is None:
        base_lr = LEARNING_RATE
    if epoch < warmup_epochs:
        return base_lr
    return base_lr * np.exp(decay_rate * (warmup_epochs - epoch))

# Two independent Adam optimizers. The training loop re-assigns only the U-Net's
# learning rate each epoch (from scheduler()); the style network stays at 1e-4.
Unet_optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
Style_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)

def _mean(values):
    """Mean of a list of scalars / eager scalar tensors as a Python float (NaN for an empty list)."""
    return float(np.mean(values))


# Custom alternating training loop:
#   1) update full_unet on total_loss, conditioned on style codes of the middle slice;
#   2) update full_style so the codes of the reconstruction match the codes of the clean image.
# Per-step averages go to step_results_path, per-epoch train/val averages to epoch_results_path,
# and both models are saved after every epoch.
try:
    with open(epoch_results_path, 'w', newline='') as epoch_file, open(step_results_path, 'w', newline='') as step_file:
        epoch_writer = csv.writer(epoch_file)
        step_writer = csv.writer(step_file)
        # NOTE: despite its name, the 'SSIM_Loss_val' column holds the mean validation total_loss
        # (the validation counterpart of 'Loss'); 'MSE'/'MSE_val' hold l2_loss averages.
        # The header is left unchanged so existing readers of this CSV keep working.
        epoch_writer.writerow(['Epoch', 'Loss', 'L1_loss', 'SSIM_Score', 'PSNR', 'MSE', 'SSIM_Loss_val', 'L1_loss_val', 'SSIM_Score_val', 'PSNR_val', 'MSE_val'])
        step_writer.writerow(['Epoch', 'Step', 'Avg_Loss', 'Avg_L1_loss', 'Avg_SSIM_Score', 'Avg_PSNR'])

        for epoch in range(curr_epoch, NB_EPOCH):
            print(f'Starting epoch {epoch + 1}')
            ssim_scores, losses, l1_losses, psnr_values, l2_losses = [], [], [], [], []

            # Only the U-Net learning rate follows the schedule; Style_optimizer keeps its own rate.
            new_learning_rate = scheduler(epoch)
            Unet_optimizer.learning_rate.assign(new_learning_rate)

            # NOTE(review): if DataLoader.generator returns a one-shot Python generator rather than a
            # re-iterable dataset, this loop sees no batches after the first epoch — confirm.
            for step, ([Slice1, Slice2, Slice3], FreeImage) in enumerate(train_dataset):
                # --- U-Net update: style codes of the middle slice condition the correction. ---
                with tf.GradientTape() as tape:
                    norm_list_output = full_style(Slice2, training=True)
                    output1 = full_unet([Slice1, Slice2, Slice3, norm_list_output], training=True)
                    loss_value = total_loss(FreeImage, output1)

                # Monitoring metrics need no gradients, so they are computed outside the tape
                # (keeps the tape from holding their intermediate tensors).
                losses.append(loss_value)
                ssim_scores.append(ssim_score(FreeImage, output1))
                psnr_values.append(psnr(FreeImage, output1))
                l2_losses.append(l2_loss(FreeImage, output1))

                Unet_gradients = tape.gradient(loss_value, full_unet.trainable_variables)
                Unet_optimizer.apply_gradients(zip(Unet_gradients, full_unet.trainable_variables))

                # --- Style update: codes of the reconstruction should match codes of the clean image. ---
                # output1 was produced under the previous tape, so here it is a constant input.
                # NOTE(review): full_style is called without `training=` here, unlike the training=True
                # call above — confirm which mode is intended for this update.
                with tf.GradientTape() as tape:
                    recon_norm_list = full_style(output1)
                    free_norm_list = full_style(FreeImage)
                    L1_loss_value = l1_loss(free_norm_list, recon_norm_list)
                l1_losses.append(L1_loss_value)

                Style_gradients = tape.gradient(L1_loss_value, full_style.trainable_variables)
                Style_optimizer.apply_gradients(zip(Style_gradients, full_style.trainable_variables))

                if step % 100 == 0:
                    # Running averages since the start of the epoch.
                    avg_loss = _mean(losses)
                    avg_l1_loss = _mean(l1_losses)
                    avg_ssim_score = _mean(ssim_scores)
                    avg_psnr = _mean(psnr_values)
                    step_writer.writerow([epoch + 1, step, "{:.4f}".format(avg_loss), avg_l1_loss, "{:.4f}".format(avg_ssim_score), "{:.4f}".format(avg_psnr)])
                    step_file.flush()  # Ensure data is written to file
                    print(f'Epoch {epoch + 1}, Step {step}, Avg_Loss: {"{:.4f}".format(avg_loss)}, Avg_L1_loss: {avg_l1_loss}, Avg_SSIM_Score: {"{:.4f}".format(avg_ssim_score)}, Avg_PSNR: {"{:.4f}".format(avg_psnr)}, Avg_L2_Loss: {"{:.4f}".format(_mean(l2_losses))}')

            print(f"================================Epoch {epoch + 1} Validation ===========================================")
            ssim_scores_val, losses_val, l1_losses_val, psnr_values_val, l2_losses_val = [], [], [], [], []

            for step, ([Slice1, Slice2, Slice3], FreeImage) in enumerate(validation_dataset):
                norm_list_val = full_style(Slice2, training=False)
                output1_val = full_unet([Slice1, Slice2, Slice3, norm_list_val], training=False)
                ssim_scores_val.append(ssim_score(FreeImage, output1_val))
                loss_value = total_loss(FreeImage, output1_val)
                psnr_values_val.append(psnr(FreeImage, output1_val))
                losses_val.append(loss_value)
                l2_losses_val.append(l2_loss(FreeImage, output1_val))
                # Same style-consistency metric as in training: clean-image codes vs. reconstruction
                # codes, so the L1_loss and L1_loss_val columns measure the same quantity.
                recon_norm_list_val = full_style(output1_val, training=False)
                free_norm_list_val = full_style(FreeImage, training=False)
                L1_loss_value = l1_loss(free_norm_list_val, recon_norm_list_val)
                l1_losses_val.append(L1_loss_value)

            # Epoch-level averages over every training step and every validation batch.
            avg_loss = _mean(losses)
            avg_l1_loss = _mean(l1_losses)
            avg_ssim_score = _mean(ssim_scores)
            avg_psnr = _mean(psnr_values)
            avg_loss_val = _mean(losses_val)
            avg_l1_loss_val = _mean(l1_losses_val)
            avg_ssim_score_val = _mean(ssim_scores_val)
            avg_psnr_val = _mean(psnr_values_val)
            epoch_writer.writerow([epoch + 1, "{:.4f}".format(avg_loss), avg_l1_loss, "{:.4f}".format(avg_ssim_score), "{:.4f}".format(avg_psnr), "{:.4f}".format(_mean(l2_losses)), "{:.4f}".format(avg_loss_val), avg_l1_loss_val, "{:.4f}".format(avg_ssim_score_val), "{:.4f}".format(avg_psnr_val), "{:.4f}".format(_mean(l2_losses_val))])
            epoch_file.flush()  # Ensure data is written to file
            # Checkpoint names are 1-based (epoch + 1).
            full_unet.save(f"/kaggle/working/Wav_Stacked_UNet_epoch_{epoch + 1}.h5")
            full_style.save(f"/kaggle/working/full_style_epoch_{epoch + 1}.h5")

except Exception as e:
    # Best-effort: keep the notebook running, but show the full traceback rather than only
    # the message so the failing line can actually be located.
    import traceback
    print(f"An error occurred: {e}")
    traceback.print_exc()


In [None]:
# import math
# import pandas as pd
# import tensorflow as tf
# from tensorflow.keras.callbacks import CSVLogger, LearningRateScheduler, ModelCheckpoint
# from tensorflow.keras.optimizers import Adam
# from tensorflow.keras.models import model_from_json, load_model
# from tensorflow.keras.utils import plot_model

# def exponential_lr(epoch, LEARNING_RATE):
#     if epoch < 10:
#         return LEARNING_RATE
#     else:
#         return LEARNING_RATE * math.exp(0.1 * (10 - epoch)) # lr decays exponentially after epoch 10: multiplied by exp(-0.1) (~0.905) each epoch
    


# def main():
#         print('---------------------------------')
#         print('Model Training ...')
#         print('---------------------------------')
        
#         model = Correction_Multi_input(HEIGHT, WIDTH)
#         model.summary()
# #         model = load_model("/kaggle/input/stackedunet-regist-final-wavtf-dataset/stacked_model_20_val_loss_0.0661.h5",
# #                            custom_objects={'total_loss':total_loss, 'ssim_score': ssim_score, 'psnr':psnr, 'K':K})
        
#         csv_logger = CSVLogger(f'{WEIGHTS_PATH}_Loss_Acc.csv', append=True, separator=',')
#         reduce_lr = LearningRateScheduler(exponential_lr)
        
#         model.compile(loss=total_loss, optimizer=Adam(learning_rate=LEARNING_RATE),
#                       metrics=[ssim_score, 'mse', psnr])
        
#         checkpoint_path = '/kaggle/working/stacked_model_{epoch:02d}_val_loss_{val_loss:.4f}.h5'
#         model_checkpoint = ModelCheckpoint(checkpoint_path,
#                                    monitor='val_loss',
#                                    save_best_only=False,
#                                    save_weights_only=False,
#                                    mode='min',
#                                    verbose=1)
        
#         hist = model.fit(train_dataset,
#                          epochs=NB_EPOCH,
#                          verbose=1,
#                          validation_data=validation_dataset,
#                          initial_epoch=20,
#                          callbacks=[csv_logger, reduce_lr, model_checkpoint])


# if __name__ == "__main__":
#     main()