# WGAN version v3

## only schizophrenia, but pick only 50 random for training

In [1]:
import nibabel as nib
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import tensorflow as tf
from tensorflow.keras.layers import Dense, Reshape, Flatten, Conv3D, Conv3DTranspose, LeakyReLU, Input, Embedding, multiply, Concatenate
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam, RMSprop
import numpy as np
import glob
import nibabel as nib
import os
import matplotlib.pyplot as plt
import scipy.ndimage
import random

In [2]:
# List all physical devices of type 'GPU'
gpus = tf.config.experimental.list_physical_devices('GPU')

if gpus:
    print(f'Number of GPUs available: {len(gpus)}')
    for i, gpu in enumerate(gpus):
        print(f'GPU {i}: {gpu}')
else:
    print('No GPU detected.')

Number of GPUs available: 1
GPU 0: PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')


# Matching fMRI files with IDs from demographic csv

In [3]:
# Specify the directory and file pattern
directory_path = '4D'
file_pattern = 'A*_????_func_FL_FD_RPI_DSP_MCF_SS_SM_Nui_CS_InStandard.nii.gz'

# Construct the full path pattern
path_pattern = f'{directory_path}/{file_pattern}'

# Use glob to find all matching files
matching_files = glob.glob(path_pattern)

# full list of all schizophrenia IDs from csv, 86 intoal but not all satisfy have t>90

full_schizophrenia_ids = [
    'A00009280', 'A00028806', 'A00023132', 'A00014804', 'A00016859', 'A00021598', 'A00001181', 'A00023158',
    'A00024568', 'A00028405', 'A00001251', 'A00000456', 'A00015648', 'A00002405', 'A00027391', 'A00016720',
    'A00018434', 'A00016197', 'A00027119', 'A00006754', 'A00009656', 'A00038441', 'A00012767', 'A00034273',
    'A00028404', 'A00035485', 'A00024684', 'A00018979', 'A00027537', 'A00004507', 'A00001452', 'A00023246',
    'A00027410', 'A00014719', 'A00024510', 'A00000368', 'A00019293', 'A00014830', 'A00015201', 'A00018403',
    'A00037854', 'A00024198', 'A00001243', 'A00014590', 'A00002337', 'A00024953', 'A00037224', 'A00027616',
    'A00001856', 'A00037619', 'A00024228', 'A00038624', 'A00037034', 'A00037649', 'A00022500', 'A00013216',
    'A00020787', 'A00028410', 'A00002480', 'A00028303', 'A00020602', 'A00024959', 'A00018598', 'A00014636',
    'A00019349', 'A00017147', 'A00023590', 'A00023750', 'A00031597', 'A00015518', 'A00018317', 'A00016723',
    'A00021591', 'A00023243', 'A00017943', 'A00023366', 'A00014607', 'A00020414', 'A00035003', 'A00028805',
    'A00029486', 'A00000541', 'A00028408', 'A00000909', 'A00031186', 'A00000838' ]

# schizohrenia_id that satisfy t>90
met_requirement_schizophrenia_ids = [
    'A00000368', 'A00000456', 'A00000541', 'A00000838', 'A00001251', 'A00001452', 'A00004507',
    'A00006754', 'A00009280', 'A00012767', 'A00013216', 'A00014607', 'A00014719', 'A00014804',
    'A00014830', 'A00015201', 'A00015648', 'A00016197', 'A00016720', 'A00016723', 'A00017147',
    'A00018317', 'A00018403', 'A00018434', 'A00018979', 'A00019293', 'A00020414', 'A00020602', 
    'A00020787', 'A00021591', 'A00021598', 'A00023158', 'A00023246', 'A00023590', 'A00023750', 
    'A00024198', 'A00024228', 'A00024568', 'A00024684', 'A00024953', 'A00024959', 'A00027410', 
    'A00027537', 'A00028303', 'A00028404', 'A00028408', 'A00028805', 'A00028806', 'A00031186', 
    'A00031597', 'A00034273', 'A00035003', 'A00035485', 'A00037034', 'A00037224', 'A00037619', 
    'A00037649', 'A00038441', 'A00038624']

# Randomly select 50 schizophrenia IDs for training
schizophrenia_ids = random.sample(met_requirement_schizophrenia_ids, 50)
#schizophrenia_ids = full_schizophrenia_ids

In [4]:
# Initialize lists to store the processed image data, corresponding labels, and filenames
image_data = []
labels = []  # 1 for schizophrenia, 0 for non-schizophrenia
schizophrenia_files = []
#non_schizophrenia_files = []

# Lists for files with insufficient time dimensions
insufficient_time_files = []
insufficient_time_ids = []

# Counters for each category
schizophrenia_count = 0
#non_schizophrenia_count = 0
processed_files_count = 0

# Loop through the matching files
for file_path in matching_files:
    # Extract the filename
    filename = os.path.basename(file_path)
    
    # Extract the ID from the filename
    file_id = filename.split('_')[0]
    
    # Load the file
    t1_img = nib.load(file_path)
    t1_data = t1_img.get_fdata()

    # Check if the time dimension is at least 90
    if t1_data.shape[3] < 90:
        insufficient_time_files.append(filename)
        insufficient_time_ids.append(file_id)
        continue  # Skip this file

    # Determine the label based on the ID and increment counters
    if file_id in schizophrenia_ids:
        label = 1  # Schizophrenia
        schizophrenia_count += 1
        schizophrenia_files.append(filename)
    else:
        continue  # Skip files with IDs not in the provided lists
    
    # Collapse one of the axes by summing
    t1_data_collapsed = np.sum(t1_data, axis=1)
    
    # Append the collapsed image data and label to the respective lists
    image_data.append(t1_data_collapsed)
    labels.append(label)

    # Increment the counter
    processed_files_count += 1

# Print the total number of files processed for each category and their filenames
print(f"Total number of files successfully processed: {processed_files_count}")
print(f"Total number of schizophrenia files: {schizophrenia_count}")
print("Schizophrenia files:", schizophrenia_files)
#print(f"Total number of non-schizophrenia files: {non_schizophrenia_count}")
#print("Non-Schizophrenia files:", non_schizophrenia_files)

# Print files with insufficient time dimension
#print(f"Total number of files with insufficient time dimension: {len(insufficient_time_files)}")
#print("Files with insufficient time dimension:", insufficient_time_files)
#print("IDs of files with insufficient time dimension:", insufficient_time_ids)


Total number of files successfully processed: 50
Total number of schizophrenia files: 50
Schizophrenia files: ['A00000368_0011_func_FL_FD_RPI_DSP_MCF_SS_SM_Nui_CS_InStandard.nii.gz', 'A00000456_0013_func_FL_FD_RPI_DSP_MCF_SS_SM_Nui_CS_InStandard.nii.gz', 'A00000541_0014_func_FL_FD_RPI_DSP_MCF_SS_SM_Nui_CS_InStandard.nii.gz', 'A00000838_0013_func_FL_FD_RPI_DSP_MCF_SS_SM_Nui_CS_InStandard.nii.gz', 'A00001251_0015_func_FL_FD_RPI_DSP_MCF_SS_SM_Nui_CS_InStandard.nii.gz', 'A00001452_0014_func_FL_FD_RPI_DSP_MCF_SS_SM_Nui_CS_InStandard.nii.gz', 'A00004507_0014_func_FL_FD_RPI_DSP_MCF_SS_SM_Nui_CS_InStandard.nii.gz', 'A00006754_0011_func_FL_FD_RPI_DSP_MCF_SS_SM_Nui_CS_InStandard.nii.gz', 'A00009280_0013_func_FL_FD_RPI_DSP_MCF_SS_SM_Nui_CS_InStandard.nii.gz', 'A00012767_0015_func_FL_FD_RPI_DSP_MCF_SS_SM_Nui_CS_InStandard.nii.gz', 'A00013216_0013_func_FL_FD_RPI_DSP_MCF_SS_SM_Nui_CS_InStandard.nii.gz', 'A00014607_0012_func_FL_FD_RPI_DSP_MCF_SS_SM_Nui_CS_InStandard.nii.gz', 'A00014719_0013_func_FL_F

In [5]:
# Determine the maximum z-dimension size
max_z_size = max(img.shape[2] for img in image_data)
max_z_size

146

In [6]:
image_data_normalized = [(img - np.min(img)) / (np.max(img) - np.min(img)) * 2 - 1 for img in image_data]

In [7]:
image_data_normalized[0]

array([[[-0.02207416, -0.02207416, -0.02207416, ..., -0.02207416,
         -0.02207416, -0.02207416],
        [-0.02207416, -0.02207416, -0.02207416, ..., -0.02207416,
         -0.02207416, -0.02207416],
        [-0.02207416, -0.02207416, -0.02207416, ..., -0.02207416,
         -0.02207416, -0.02207416],
        ...,
        [-0.02207416, -0.02207416, -0.02207416, ..., -0.02207416,
         -0.02207416, -0.02207416],
        [-0.02207416, -0.02207416, -0.02207416, ..., -0.02207416,
         -0.02207416, -0.02207416],
        [-0.02207416, -0.02207416, -0.02207416, ..., -0.02207416,
         -0.02207416, -0.02207416]],

       [[-0.02207416, -0.02207416, -0.02207416, ..., -0.02207416,
         -0.02207416, -0.02207416],
        [-0.02207416, -0.02207416, -0.02207416, ..., -0.02207416,
         -0.02207416, -0.02207416],
        [-0.02207416, -0.02207416, -0.02207416, ..., -0.02207416,
         -0.02207416, -0.02207416],
        ...,
        [-0.02207416, -0.02207416, -0.02207416, ..., -

In [8]:
# Pad each image to have a consistent z-dimension size
padded_data = [np.pad(img, ((0, 0), (0, 0), (0, max_z_size - img.shape[2])), mode='constant') for img in image_data_normalized]

# Now convert the padded data list to a numpy array
padded_data_array = np.array(padded_data)

In [9]:
print(len(labels))

50


In [10]:
train_images = padded_data_array
# Define batch size
batch_size = 10
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, labels)).shuffle(len(train_images)).batch(batch_size)

In [11]:
def build_wgan_generator(z_dim):
    # Noise input
    z_input = Input(shape=(z_dim,))

    # Generator network
    model = Sequential()
    
    # Start with a Dense layer to an initial shape that's smaller than the final target
    model.add(Dense(128 * 7 * 7 * 9, input_dim=z_dim))  # Adjust to match an initial volume
    model.add(LeakyReLU(alpha=0.01))
    model.add(Reshape((7, 7, 9, 128)))  # This is the initial volume
    
    # Begin upsampling to the desired size
    model.add(Conv3DTranspose(64, kernel_size=3, strides=2, padding='same'))
    model.add(LeakyReLU(alpha=0.01))
    
    # Continue upsampling
    model.add(Conv3DTranspose(32, kernel_size=3, strides=(3, 3, 2), padding='same'))
    model.add(LeakyReLU(alpha=0.01))
    
    # Final upsampling step to reach just under the target size
    model.add(Conv3DTranspose(1, kernel_size=3, strides=(2, 2, 2), padding='same', activation='tanh'))

    # Output tensor
    output = model(z_input)

    return Model(z_input, output)


In [12]:
def build_wgan_critic(img_shape):
    # Image input
    img_input = Input(shape=img_shape)

    # Critic network
    x = Conv3D(64, kernel_size=3, strides=2, padding='same')(img_input)
    x = LeakyReLU(alpha=0.01)(x)

    x = Conv3D(128, kernel_size=3, strides=2, padding='same')(x)
    x = LeakyReLU(alpha=0.01)(x)
    
    x = Flatten()(x)
    x = Dense(128)(x)
    x = LeakyReLU(alpha=0.01)(x)
    
    # Output a score for realness (no sigmoid activation)
    output = Dense(1)(x)

    return Model(img_input, output)


### GP is calculated by
### Interpolating between real and fake images.
### Computing the gradient of the critic's scores with respect to this interpolation.
### Penalizing the deviation of these gradients from the norm value of 1.

In [13]:
def gradient_penalty(real_images, fake_images, critic):
    batch_size = tf.shape(real_images)[0]

    # Alpha for interpolation - shape: (batch_size, 1, 1, 1, 1)
    alpha = tf.random.uniform([batch_size, 1, 1, 1, 1], 0., 1.)

    # Interpolated images - shape: (batch_size, 84, 84, 72, 1)
    interpolated_images = (real_images * alpha) + (fake_images * (1 - alpha))

    with tf.GradientTape() as tape:
        tape.watch(interpolated_images)
        # Critic now only takes the images as input
        predictions = critic(interpolated_images, training=True)

    # Calculate the gradients with respect to the interpolated images
    gradients = tape.gradient(predictions, [interpolated_images])[0]

    # Compute the norm of the gradients - reduce over all dimensions except the batch dimension
    gradients_norm = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3, 4]))

    # Penalize the gradient norm deviation from 1
    gp = tf.reduce_mean((gradients_norm - 1.) ** 2)
    return gp



In [14]:
# Image shape and other parameters
img_shape = (84, 84, 72, 1)
z_dim = 100

# Create the generator and critic
generator = build_wgan_generator(z_dim)
critic = build_wgan_critic(img_shape)

# RMSprop optimizers, from the paper, but we might try Adam optimizer?

#critic_optimizer = RMSprop(learning_rate=0.00005)
#generator_optimizer = RMSprop(learning_rate=0.00005)

# from paper as well
critic_optimizer = Adam(learning_rate=0.0001, beta_1=0.5, beta_2=0.999)
generator_optimizer = Adam(learning_rate=0.0001, beta_1=0.5, beta_2=0.999)

# Note: No need to compile the models with loss functions here
# as the loss will be calculated manually during training


In [15]:
def resize_image(image, new_shape):
    factors = (
        new_shape[0]/image.shape[0],
        new_shape[1]/image.shape[1],
        new_shape[2]/image.shape[2]
    )
    return scipy.ndimage.zoom(image, factors, order=1)  # order=1 is bilinear interpolation


In [16]:
def train_wgan_gp(generator, critic, dataset, epochs, z_dim, lambda_gp, critic_optimizer, generator_optimizer):
    # Lists to keep track of losses
    critic_losses = []
    generator_losses = []

    # Directory for saving checkpoints
    checkpoint_dir = "wgan_gp_checkpoints"
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    for epoch in range(epochs):
        epoch_critic_loss = 0.0
        epoch_generator_loss = 0.0
        num_batches = 0
        
        for batch in dataset:
            # Assuming dataset yields only images
            real_imgs = batch[0]

            num_batches += 1
            batch_size = real_imgs.shape[0]

            # Resize real images to match the expected dimensions of the critic
            real_imgs_resized = np.array([resize_image(img, (84, 84, 72)) for img in real_imgs])
            real_imgs_resized = np.expand_dims(real_imgs_resized, axis=-1)  # Add channel dimension

            # Train the critic
            for _ in range(5):  # Critic is often trained more frequently
                with tf.GradientTape() as tape:
                    # Generate fake images
                    z = tf.random.normal([batch_size, z_dim])
                    fake_imgs = generator(z, training=True)

                    # Get critic scores for real and fake images
                    real_output = critic(real_imgs_resized, training=True)
                    fake_output = critic(fake_imgs, training=True)

                    # Calculate critic loss
                    critic_cost = tf.reduce_mean(fake_output) - tf.reduce_mean(real_output)
                    gp = gradient_penalty(real_imgs_resized, fake_imgs, critic)
                    critic_loss = critic_cost + lambda_gp * gp

                # Update critic weights
                critic_grads = tape.gradient(critic_loss, critic.trainable_variables)
                critic_optimizer.apply_gradients(zip(critic_grads, critic.trainable_variables))
                epoch_critic_loss += critic_loss

            # Train the generator
            z = tf.random.normal([batch_size, z_dim])
            misleading_labels = tf.ones((batch_size, 1))

            with tf.GradientTape() as tape:
                fake_imgs = generator(z, training=True)
                fake_output = critic(fake_imgs, training=True)
                generator_loss = -tf.reduce_mean(fake_output)

            # Update generator weights
            generator_grads = tape.gradient(generator_loss, generator.trainable_variables)
            generator_optimizer.apply_gradients(zip(generator_grads, generator.trainable_variables))
            epoch_generator_loss += generator_loss

            # print the losses for each batch
            print(f'Epoch: {epoch}, Batch: {num_batches}, Critic Loss: {critic_loss}, Generator Loss: {generator_loss}')

        # Checkpointing every n epochs
        if (epoch + 1) % 2 == 0:
            generator.save_weights(os.path.join(checkpoint_dir, f"generator_epoch_{epoch+1}.h5"))
            critic.save_weights(os.path.join(checkpoint_dir, f"critic_epoch_{epoch+1}.h5"))
            print(f"Checkpoint: Saved model weights at epoch {epoch+1}")
            
        # Record the average losses for this epoch
        critic_losses.append(epoch_critic_loss / num_batches)
        generator_losses.append(epoch_generator_loss / num_batches)

    return critic_losses, generator_losses



In [17]:
def plot_losses(d_losses, g_losses):
    plt.figure(figsize=(10, 5))
    plt.plot(d_losses, label='Discriminator Loss')
    plt.plot(g_losses, label='Generator Loss')
    plt.title('Loss over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()


In [18]:
# Set parameters
z_dim = 100
epochs = 10
lambda_gp = 10  # Gradient penalty coefficient

# Train the WGAN-GP
critic_losses, generator_losses = train_wgan_gp(
    generator, 
    critic, 
    train_dataset, 
    epochs, 
    z_dim,
    lambda_gp, 
    critic_optimizer, 
    generator_optimizer
)


ResourceExhaustedError: {{function_node __wrapped__MatMul_device_/job:localhost/replica:0/task:0/device:GPU:0}} OOM when allocating tensor with shape[1016064,128] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc [Op:MatMul]

In [None]:
# Generate a single fake image
z = np.random.normal(0, 1, (1, z_dim))
generated_image = generator.predict(z)[0]  # [0] to get the single image from the batch

# Get a single real image from the dataset
real_images = next(iter(train_dataset))[0]  # Assuming the dataset yields only images

# Take the first real image from the batch for comparison
real_image = real_images[0]

# Plot the real and fake images side by side
plt.figure(figsize=(10, 5))

# Plot real image
plt.subplot(1, 2, 1)
plt.imshow(real_image[:, :, 10, 0])  # Adjust indexing and color map as needed
plt.title('Real Image')
plt.axis('off')

# Plot fake image
plt.subplot(1, 2, 2)
plt.imshow(generated_image[:, :, 10, 0])  # Adjust indexing and color map as needed
plt.title('Generated Image')
plt.axis('off')

plt.tight_layout()
plt.show()



In [None]:
generator.save('wgan_generator_model_v3.h5')
critic.save('wgan_critic_model_v3.h5')