In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, LeakyReLU, BatchNormalization, Flatten, Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import TruncatedNormal
from tensorflow.keras.models import Model
import numpy as np
import h5py
import matplotlib.pyplot as plt
import os

In [2]:
def preprocess_data(data, size, channels):
    """Force every image to ``(size, size, channels)`` and min-max scale the batch.

    Args:
        data: Iterable of array-likes, one entry per image.
        size (int): Target height and width.
        channels (int): Target number of channels.

    Returns:
        np.ndarray: Array of shape ``(n_images, size, size, channels)`` scaled
        into [0, 1] using the minimum and maximum of the whole batch. A batch
        holding a single distinct value maps to all zeros; an empty batch
        yields an empty float array.
    """
    # NOTE: np.resize does not interpolate. It repeats or truncates the
    # flattened pixels to fill the requested shape, so it is only a faithful
    # "resize" when each image already holds size * size * channels values.
    resized = np.array([np.resize(image, (size, size, channels)) for image in data])

    # np.min / np.max raise on empty input; there is nothing to normalise anyway.
    if resized.size == 0:
        return np.zeros((len(resized), size, size, channels), dtype=float)

    lo = np.min(resized)
    hi = np.max(resized)
    span = hi - lo
    # A constant batch would otherwise evaluate 0 / 0 and fill the output with NaN.
    if span == 0:
        span = 1

    return (resized - lo) / span

def data_loader(data_dir_mri, data_dir_ct, batch_size, image_size, label_size, c_dim):
    """
    Yield shuffled batches of MRI / CT images and labels read from two HDF5 files.

    Both files must contain a ``'data'`` and a ``'labels'`` dataset indexed
    with four axes. The same row indices are read from both files, and only
    the first ``min(len(mri), len(ct))`` rows are used.

    The returned generator is single-pass: once it is exhausted, call
    ``data_loader`` again to obtain a new (re-shuffled) pass over the data.

    Args:
        data_dir_mri (str): Path to the HDF5 file containing MRI data.
        data_dir_ct (str): Path to the HDF5 file containing CT data.
        batch_size (int): Maximum number of samples per batch (the last batch
            may be smaller).
        image_size (int): Height/width passed to ``preprocess_data`` for images.
        label_size (int): Height/width passed to ``preprocess_data`` for labels.
        c_dim (int): Number of channels passed to ``preprocess_data`` for both
            images and labels.

    Yields:
        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: A batch of MRI images, MRI labels, CT images, and CT labels.
    """
    # Context managers close both files even if the consumer stops iterating
    # early or an exception is raised while a batch is being produced.
    with h5py.File(data_dir_mri, 'r') as mri_file, h5py.File(data_dir_ct, 'r') as ct_file:
        # Get the dataset sizes
        mri_data_size = mri_file['data'].shape[0]
        ct_data_size = ct_file['data'].shape[0]
        data_size = min(mri_data_size, ct_data_size)

        # Create shuffled indices
        indices = np.random.permutation(data_size)

        # Iterate over the data in batches
        for batch_start in range(0, data_size, batch_size):
            batch_end = min(batch_start + batch_size, data_size)
            # h5py index lists must be in increasing order. Sorting inside the
            # batch keeps the random batch membership while satisfying that.
            batch_indices = np.sort(indices[batch_start:batch_end])

            # Load MRI data
            mri_images = mri_file['data'][batch_indices, :, :, :]
            mri_labels = mri_file['labels'][batch_indices, :, :, :]

            # Load CT data
            ct_images = ct_file['data'][batch_indices, :, :, :]
            ct_labels = ct_file['labels'][batch_indices, :, :, :]

            # Preprocess data (reshape to the target size, scale to [0, 1])
            mri_images = preprocess_data(mri_images, image_size, c_dim)
            mri_labels = preprocess_data(mri_labels, label_size, c_dim)
            ct_images = preprocess_data(ct_images, image_size, c_dim)
            ct_labels = preprocess_data(ct_labels, label_size, c_dim)

            yield mri_images, mri_labels, ct_images, ct_labels

In [3]:
def fusion_model(inputs):
    """Build the fusion (generator) graph on top of ``inputs``.

    Four Conv2D -> BatchNormalization -> LeakyReLU stages with 256, 128, 64
    and 32 filters (kernel sizes 5, 5, 3, 3) are followed by a 1x1 convolution
    with a tanh activation. All convolutions use ``'same'`` padding and the
    default stride, so the spatial size of ``inputs`` is preserved.

    Every call instantiates new layers, so two calls do not share weights.

    Args:
        inputs: Symbolic Keras tensor to build the graph on.

    Returns:
        Single-channel output tensor of the final tanh-activated convolution.
    """
    x = inputs
    for filters, kernel_size in ((256, 5), (128, 5), (64, 3), (32, 3)):
        x = Conv2D(filters, kernel_size, padding='same',
                   kernel_initializer=TruncatedNormal(stddev=1e-3))(x)
        x = BatchNormalization()(x)
        x = LeakyReLU()(x)

    # tanh is applied through the layer's own activation instead of a raw
    # tf.nn op, so the graph consists only of Keras layers (newer Keras
    # versions refuse raw TensorFlow functions on symbolic tensors).
    outputs = Conv2D(1, 1, padding='same', activation='tanh',
                     kernel_initializer=TruncatedNormal(stddev=1e-3))(x)
    return outputs

def discriminator(inputs):
    """Build the discriminator graph on top of ``inputs``.

    Four stride-2, 3x3, ``'valid'``-padded convolutions with 32, 64, 128 and
    256 filters, each followed by LeakyReLU; all but the first are batch
    normalised before the activation. The result is flattened and projected
    to a single unbounded score by a Dense layer.

    Every call instantiates new layers, so two calls do not share weights.

    Args:
        inputs: Symbolic Keras tensor to build the graph on.

    Returns:
        Tensor with one score per sample.
    """
    def downsample(filters):
        # All convolutions in this network share everything but the width.
        return Conv2D(filters, 3, strides=2, padding='valid',
                      kernel_initializer=TruncatedNormal(stddev=1e-3))

    # First stage: no batch normalisation.
    h = downsample(32)(inputs)
    h = LeakyReLU()(h)

    for width in (64, 128, 256):
        h = downsample(width)(h)
        h = BatchNormalization()(h)
        h = LeakyReLU()(h)

    features = Flatten()(h)
    score = Dense(1, kernel_initializer=TruncatedNormal(stddev=1e-3))(features)
    return score

In [7]:
# Input geometry shared by the networks below
image_size = 256
label_size = 256
c_dim = 1

# Symbolic Keras inputs for the two modalities and their label images
images_mri = Input(shape=(image_size, image_size, c_dim), name='images_mri')
labels_mri = Input(shape=(label_size, label_size, c_dim), name='labels_mri')
images_ct = Input(shape=(image_size, image_size, c_dim), name='images_ct')
labels_ct = Input(shape=(label_size, label_size, c_dim), name='labels_ct')

# Stack both modalities along the channel axis. A Concatenate layer is used
# rather than tf.concat so the graph is built from Keras layers only.
input_image = tf.keras.layers.Concatenate(axis=-1)([images_mri, images_ct])

# Pass through fusion model
fusion_image = fusion_model(input_image)

# discriminator() creates new layers on every call, so calling it once per
# input would score real and fused images with two unrelated networks.
# Wrap a single instance in a Model and apply that same Model to both inputs.
# (Sharing the Flatten/Dense head requires image_size == label_size.)
discriminator_input = Input(shape=(label_size, label_size, c_dim), name='discriminator_input')
shared_discriminator = Model(inputs=discriminator_input, outputs=discriminator(discriminator_input))
pos = shared_discriminator(labels_ct)
neg = shared_discriminator(fusion_image)


# Define loss functions
def pos_loss(y_true, y_pred):
    """Mean squared distance between ``y_pred`` and random soft "real" targets.

    One target per batch row is drawn uniformly from [0.7, 1.2) with shape
    ``(batch, 1)``. ``y_true`` is not used.
    """
    batch = tf.shape(y_pred)[0]
    soft_targets = tf.random.uniform(shape=[batch, 1], minval=0.7, maxval=1.2)
    squared_error = tf.square(y_pred - soft_targets)
    return tf.reduce_mean(squared_error)

def neg_loss(y_true, y_pred):
    """Mean squared distance between ``y_pred`` and random soft "fake" targets.

    One target per batch row is drawn uniformly from [0, 0.3) with shape
    ``(batch, 1)``. ``y_true`` is not used.
    """
    batch = tf.shape(y_pred)[0]
    soft_targets = tf.random.uniform(shape=[batch, 1], minval=0, maxval=0.3)
    squared_error = tf.square(y_pred - soft_targets)
    return tf.reduce_mean(squared_error)

def d_loss(y_true, y_pred, real_pred=None):
    """Discriminator loss: soft "real" term plus soft "fake" term.

    Args:
        y_true: Forwarded to ``pos_loss`` / ``neg_loss``.
        y_pred: Discriminator scores pushed towards the "fake" targets
            (``neg_loss``).
        real_pred: Optional discriminator scores for real samples, pushed
            towards the "real" targets (``pos_loss``). When omitted, ``y_pred``
            is used for both terms, which keeps the previous two-argument
            behaviour; note that in that mode the same scores are pulled
            towards two disjoint target ranges at once.

    Returns:
        Scalar tensor, the sum of both terms.
    """
    if real_pred is None:
        real_pred = y_pred
    return pos_loss(y_true, real_pred) + neg_loss(y_true, y_pred)

def g_loss_1(y_true, y_pred):
    """Adversarial generator term: squared distance of ``y_pred`` to soft "real" targets.

    Targets are drawn uniformly from [0.7, 1.2), one per batch row with shape
    ``(batch, 1)``. ``y_true`` is not used.
    """
    n_rows = tf.shape(y_pred)[0]
    real_like = tf.random.uniform(shape=[n_rows, 1], minval=0.7, maxval=1.2)
    residual = y_pred - real_like
    return tf.reduce_mean(tf.square(residual))

def g_loss_2(y_true, y_pred):
    """Content loss: pixel-wise MSE plus 5x the MSE between image gradients.

    Args:
        y_true: Reference images; must be 4-D (batch, height, width, channels)
            as required by ``tf.image.image_gradients``.
        y_pred: Generated images with the same shape as ``y_true``.

    Returns:
        Scalar tensor.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # References may arrive in a different float dtype (e.g. float64 NumPy
    # data); TensorFlow does not promote dtypes implicitly.
    y_true = tf.cast(y_true, y_pred.dtype)

    intensity_term = tf.reduce_mean(tf.square(y_pred - y_true))

    # tf.image.image_gradients returns a (dy, dx) tuple and tuples cannot be
    # subtracted, so stack each pair into one tensor before taking the difference.
    grad_pred = tf.stack(tf.image.image_gradients(y_pred))
    grad_true = tf.stack(tf.image.image_gradients(y_true))
    gradient_term = tf.reduce_mean(tf.square(grad_pred - grad_true))

    return intensity_term + 5 * gradient_term

def g_loss_total(y_true, y_pred, disc_pred=None):
    """Total generator loss: adversarial term plus 100x the content term.

    Args:
        y_true: Reference images for the content term (``g_loss_2``).
        y_pred: Generated images for the content term.
        disc_pred: Optional discriminator scores for the generated images,
            used for the adversarial term (``g_loss_1``). When omitted,
            ``y_pred`` itself is passed to ``g_loss_1``, which keeps the
            previous two-argument behaviour.

    Returns:
        Scalar tensor.
    """
    if disc_pred is None:
        disc_pred = y_pred
    return g_loss_1(y_true, disc_pred) + 100*g_loss_2(y_true, y_pred)

# Optimizers: one independent Adam instance per network, same step size for both
_adam_lr = 2e-4
d_optimizer, g_optimizer = Adam(learning_rate=_adam_lr), Adam(learning_rate=_adam_lr)

In [9]:
# Build and compile models
#
# discriminator_model takes two inputs and returns two scores: `pos` is the
# discriminator applied to `labels_ct`, `neg` is the discriminator applied to
# `fusion_image`.
# NOTE(review): `fusion_image` is the output of the fusion graph rather than an
# `Input(...)` tensor, yet it is listed as a model input here — confirm the
# installed Keras version accepts an intermediate tensor in that position.

discriminator_model = Model(inputs=[labels_ct, fusion_image], outputs= [pos, neg])
# loss=None: no loss is attached at compile time; train_step computes the
# losses itself and applies gradients through the optimizer objects.
discriminator_model.compile(optimizer=d_optimizer, loss=None)

# Stand-alone generator on a fresh input with 2 * c_dim channels (train_step
# feeds it the MRI and CT batches concatenated along the last axis).
# Calling the builder again creates a new set of layers, separate from the
# ones that produced `fusion_image`.
fusion_model_input = Input(shape=(image_size, image_size, 2 * c_dim))
fusion_output = fusion_model(fusion_model_input)
# From this line on, the name `fusion_model` refers to this Model instance and
# no longer to the builder function of the same name.
fusion_model = Model(inputs=fusion_model_input, outputs=fusion_output)
fusion_model.compile(optimizer=g_optimizer, loss=None)

# Define train_step function
@tf.function
def train_step(images_mri, labels_mri, images_ct, labels_ct):
    """Run one adversarial update of the fusion network and the discriminator.

    Args:
        images_mri: Batch of MRI images.
        labels_mri: Reference images for the generator's content loss
            (``g_loss_2``).
        images_ct: Batch of CT images; concatenated with ``images_mri`` along
            the last axis before being fed to ``fusion_model``.
        labels_ct: Images shown to the discriminator as "real" samples.

    Returns:
        Tuple ``(disc_loss, gen_loss)`` of scalar tensors.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fused = fusion_model(tf.concat([images_mri, images_ct], axis=-1), training=True)

        # discriminator_model maps [real-branch input, fused-branch input] to
        # [real score, fused score]. Feed it the actual fused batch (not
        # constant ones/zeros) so that both scores are meaningful.
        real_score, fake_score = discriminator_model([labels_ct, fused], training=True)

        # Discriminator objective: real samples towards the high soft targets,
        # fused samples towards the low ones. pos_loss / neg_loss ignore y_true.
        disc_loss = pos_loss(None, real_score) + neg_loss(None, fake_score)

        # Generator objective: make the discriminator rate fused samples as
        # real, plus the content term with the same weight (100) that
        # g_loss_total uses. The reference is cast because loader output may
        # be float64 while the network output is not.
        reference = tf.cast(labels_mri, fused.dtype)
        gen_loss = g_loss_1(None, fake_score) + 100 * g_loss_2(reference, fused)

    gradients_of_generator = gen_tape.gradient(gen_loss, fusion_model.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator_model.trainable_variables)

    g_optimizer.apply_gradients(zip(gradients_of_generator, fusion_model.trainable_variables))
    d_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator_model.trainable_variables))

    return disc_loss, gen_loss

In [None]:
# Training loop
epochs = 100
batch_size = 32

data_dir_mri = 'path/to/mri/data'
data_dir_ct = 'path/to/ct/data'

for epoch in range(epochs):
    # data_loader returns a single-pass generator: build a new one for every
    # epoch, otherwise only the first epoch would receive any batches. Each
    # call also draws a fresh permutation, so the data is reshuffled per epoch.
    data_loader_gen = data_loader(data_dir_mri, data_dir_ct, batch_size, image_size, label_size, c_dim)

    # Batch variables use their own names so they do not overwrite the
    # module-level Keras inputs (images_mri, labels_mri, images_ct, labels_ct).
    for mri_batch, mri_label_batch, ct_batch, ct_label_batch in data_loader_gen:
        disc_loss, gen_loss = train_step(mri_batch, mri_label_batch, ct_batch, ct_label_batch)
        print(f"Epoch {epoch+1}, Disc Loss: {disc_loss}, Gen Loss: {gen_loss}")

# Save models
fusion_model.save('fusion_model.h5')
discriminator_model.save('discriminator_model.h5')