# In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import os
import cv2

def load_medical_images(data_path, image_size=(256, 256)):
    """Load every PNG/JPEG in a directory as a normalized grayscale array.

    Files are visited in sorted name order so the result is reproducible
    (``os.listdir`` makes no ordering guarantee). Extension matching is
    case-insensitive, and files that OpenCV cannot decode are skipped with a
    printed notice instead of aborting the whole load.

    Args:
        data_path: Directory containing the image files.
        image_size: Target size handed to ``cv2.resize`` for every image.

    Returns:
        A NumPy array stacking the loaded images, each converted to float32
        and scaled to ``[0, 1]``. When no image could be loaded the result is
        an empty array of shape ``(0,)``.
    """
    image_extensions = (".png", ".jpg", ".jpeg")
    image_list = []

    for filename in sorted(os.listdir(data_path)):
        if not filename.lower().endswith(image_extensions):
            continue
        image_path = os.path.join(data_path, filename)
        img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread reports an unreadable or corrupt file by returning
            # None rather than raising; cv2.resize would then fail obscurely.
            print("Skipping unreadable image:", image_path)
            continue
        img = cv2.resize(img, image_size)
        image_list.append(img.astype(np.float32) / 255.0)

    # Convert the list to a NumPy array
    return np.array(image_list)

# Example usage: load the dataset once and report its dimensions.
data_path = "/content/drive/MyDrive/med_img"
X_train = load_medical_images(data_path=data_path)
print("Shape of X_train:", X_train.shape)


# Function to build the VAE model
def build_vae(input_shape, latent_dim):
    """Build and compile a convolutional variational autoencoder.

    The encoder applies two stride-2 convolutions followed by a dense layer
    and emits ``[z_mean, z_log_var, z]``. The decoder comes from
    ``build_decoder``. Because the decoder's output resolution may differ
    from ``input_shape``, the input is resized to the decoder's output size
    before the reconstruction loss is computed.

    The loss added to the model is the sum of:
      * binary cross-entropy between the resized input and the
        reconstruction, scaled by the number of compared pixels, and
      * the KL divergence between the approximate posterior
        ``N(z_mean, exp(z_log_var))`` and a standard normal prior.

    Args:
        input_shape: Shape of one input image, e.g. ``(height, width, 1)``.
        latent_dim: Size of the latent vector.

    Returns:
        A tuple ``(vae, encoder, decoder)`` of Keras models. ``vae`` is
        already compiled with the Adam optimizer and carries its own loss, so
        it is trained with inputs only (no targets).
    """
    # Encoder
    encoder_inputs = keras.Input(shape=input_shape)
    x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
    x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
    x = layers.Flatten()(x)
    x = layers.Dense(128, activation="relu")(x)

    # Latent space
    z_mean = layers.Dense(latent_dim, name="z_mean")(x)
    z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)

    # Sampling layer (reparameterization: z = mean + std * epsilon)
    def sampling(args):
        mean, log_var = args
        batch = tf.shape(mean)[0]
        dim = tf.shape(mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return mean + tf.exp(0.5 * log_var) * epsilon

    z = layers.Lambda(sampling, output_shape=(latent_dim,), name="z")([z_mean, z_log_var])

    # Build the encoder and decoder models
    encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
    decoder = build_decoder(latent_dim)

    # VAE model: run the encoder once as a sub-model so the reconstruction
    # and the KL term are computed from the same forward pass.
    latent_mean, latent_log_var, latent_sample = encoder(encoder_inputs)
    outputs = decoder(latent_sample)
    vae = keras.Model(encoder_inputs, outputs, name="vae")

    # Resize original images to match the decoder's output dimensions, read
    # from the decoder itself rather than hard-coded.
    target_size = tuple(decoder.output_shape[1:3])
    resized_encoder_inputs = layers.Lambda(
        lambda images: tf.image.resize(images, target_size)
    )(encoder_inputs)

    # Reconstruction term: mean per-pixel cross-entropy scaled by the number
    # of pixels actually compared (the decoder's resolution, not the input's).
    reconstruction_loss = tf.keras.losses.binary_crossentropy(
        tf.keras.backend.flatten(resized_encoder_inputs),
        tf.keras.backend.flatten(outputs),
    )
    reconstruction_loss *= target_size[0] * target_size[1]

    # KL term: -0.5 * sum(1 + log_var - mean^2 - exp(log_var)) per sample,
    # averaged over the batch. The leading -0.5 makes this a non-negative
    # penalty; without it, minimizing the loss would push the posterior away
    # from the prior.
    kl_loss = -0.5 * tf.reduce_sum(
        1 + latent_log_var - tf.square(latent_mean) - tf.exp(latent_log_var),
        axis=1,
    )
    kl_loss = tf.reduce_mean(kl_loss)
    vae_loss = reconstruction_loss + kl_loss

    vae.add_loss(vae_loss)
    vae.compile(optimizer="adam")

    return vae, encoder, decoder


def build_decoder(latent_dim):
    """Create the decoder network that maps a latent vector to an image.

    The latent vector is projected by a dense layer, reshaped to a
    16x16x64 feature map, passed through two stride-2 transposed
    convolutions (64 then 32 filters), and finally reduced to a single
    sigmoid-activated channel.

    Args:
        latent_dim: Size of the latent vector fed to the decoder.

    Returns:
        A ``keras.Model`` taking a ``(latent_dim,)`` input.
    """
    latent_in = keras.Input(shape=(latent_dim,))

    features = layers.Dense(16 * 16 * 64, activation="relu")(latent_in)
    features = layers.Reshape((16, 16, 64))(features)

    # Upsampling stages, each doubling the spatial resolution.
    for filters in (64, 32):
        features = layers.Conv2DTranspose(
            filters, 3, activation="relu", strides=2, padding="same"
        )(features)

    reconstruction = layers.Conv2DTranspose(
        1, 3, activation="sigmoid", padding="same"
    )(features)
    return keras.Model(latent_in, reconstruction)

# Function to display only the masked image
def display_masked_image(masked):
    """Show a single image in grayscale under the title "Masked Image".

    Args:
        masked: Image array; singleton dimensions are squeezed away before
            plotting.
    """
    plt.figure(figsize=(5, 5))
    # Masked Image
    plt.title("Masked Image")
    plt.imshow(np.squeeze(masked), cmap="gray")
    plt.axis("off")

    plt.show()

# Usage
# `masked_image` is only created by the main script further down, so calling
# the function at this point in the file would raise NameError. Run it after
# the main script has produced `masked_image`:
#display_masked_image(masked_image)

# Function to display original and masked images
def display_images(original, masked):
    """Plot two images side by side in grayscale.

    The left panel is titled "Original Image" and the right panel
    "Masked Image". Singleton dimensions of each array are squeezed away
    before plotting, and axes are hidden.

    Args:
        original: Image array shown on the left.
        masked: Image array shown on the right.
    """
    plt.figure(figsize=(10, 5))

    panels = (("Original Image", original), ("Masked Image", masked))
    for position, (caption, image) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(caption)
        plt.imshow(np.squeeze(image), cmap="gray")
        plt.axis("off")

    plt.show()

# Usage
#display_images(sample_image, masked_image)
#print("Sample image shape:", sample_image.shape)
#print("Masked image shape:", masked_image.shape)


# Main script
if __name__ == "__main__":
    # Load medical images
    data_path = "/content/drive/MyDrive/med_img"
    X_train = load_medical_images(data_path)
    print(X_train)
    print(len(X_train))
    print(X_train.shape)

    # Fail early with a clear message: training and sampling below cannot
    # work on an empty dataset.
    if X_train.size == 0:
        raise ValueError(f"No readable images found in {data_path!r}")

    # The loader yields one 2-D array per image; add the trailing channel
    # axis so the data matches the (height, width, 1) model input explicitly.
    if X_train.ndim == 3:
        X_train = X_train[..., np.newaxis]

    # Build the VAE model
    input_shape = (256, 256, 1)  # Adjust based on your image size and channels
    latent_dim = 32  # Adjust based on your desired latent space dimension

    # Training the VAE (the model carries its own loss, so no targets are passed)
    vae, _, _ = build_vae(input_shape, latent_dim)
    vae.fit(X_train, epochs=10, batch_size=32, shuffle=True)

    # Generate a masked image: pick one random training image, add a batch
    # axis, and run it through the trained model.
    sample_image = X_train[np.random.choice(len(X_train))]
    sample_image = np.expand_dims(sample_image, axis=0)
    masked_image = vae.predict(sample_image)

    #display_masked_image(masked_image)
    display_images(sample_image, masked_image)
