In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import expit
from tensorflow.keras.layers import *
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.losses import mse
import os
import pandas as pd
from glob import glob
import sys
from tqdm import tqdm
import gc
from datetime import datetime
from rsa_funcs import fit_rsa,make_RDM,get_triu
import pickle
from matplotlib import pyplot as plt
now = datetime.now

2024-02-26 18:18:10.962418: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1


In [2]:
def sampling(args):
    """Reparameterization trick: draw z = mean + sigma * eps with eps ~ N(0, I).

    # Arguments:
      args (tensor): pair holding the mean and the log-variance of Q(z|X)
    # Returns:
      z (tensor): sampled latent vector
    """
    mu, log_variance = args

    # Row count comes from the dynamic shape, latent width from the static one.
    noise_shape = (K.shape(mu)[0], K.int_shape(mu)[1])

    # random_normal defaults to mean=0 and std=1.0, i.e. unit Gaussian noise.
    unit_noise = K.random_normal(shape=noise_shape)

    # exp(0.5 * log(sigma^2)) == sigma
    sigma = K.exp(0.5 * log_variance)
    return mu + sigma * unit_noise

def get_MRI_VAE_3D(input_shape=(64,64,64,1),
                   latent_dim=2,
                   batch_size = 32,
                   disentangle=False,
                   gamma=1,
                   kernel_size = 3,
                   filters = 16,
                   intermediate_dim = 128,
                   opt=None):
    """Build and compile a 3D convolutional variational autoencoder.

    The encoder stacks two stride-2 ``Conv3D`` layers (the filter count is
    doubled before each layer), flattens, applies a ``Dense(intermediate_dim)``
    layer and emits ``z_mean``, ``z_log_var`` and a sample ``z`` drawn with
    ``sampling``. The decoder mirrors this with two ``Dense`` layers, a
    reshape, two stride-2 ``Conv3DTranspose`` layers (the filter count is
    halved after each layer) and a final sigmoid ``Conv3DTranspose``.

    # Arguments:
      input_shape (tuple): 4-tuple; the first entry is used as ``image_size``
        for the loss scaling and the last entry as the channel count.
      latent_dim (int): width of the latent vector.
      batch_size (int): only used when ``disentangle`` is True. The batch is
        split at index ``int(batch_size / 2)``, so every batch fed to the
        model must contain exactly twice that many samples for the two
        halves to be concatenated.
      disentangle (bool): when True, add the discriminator-based ``tc_loss``
        and ``discriminator_loss`` terms to the objective.
      gamma (number): weight applied to both extra terms; unused otherwise.
      kernel_size (int): kernel size of every (transposed) convolution.
      filters (int): base filter count (doubled before the first Conv3D).
      intermediate_dim (int): width of the hidden Dense layers.
      opt: Keras optimizer; when None an Adam optimizer with
        ``learning_rate=0.001`` is created.

    # Returns:
      (encoder, decoder, vae): three Keras ``Model`` objects. ``vae`` is
      compiled with its loss attached through ``add_loss``.
    """

    #TODO: add discriminator loss, see if there is improvement. Perhaps try on shapes dataset if it's easier...

    image_size, _, _, channels = input_shape

    nlayers = 2

    # VAE model = encoder + decoder
    # build encoder model
    inputs = Input(shape=input_shape, name='encoder_input')
    x = inputs
    for _ in range(nlayers):
        filters *= 2
        x = Conv3D(filters=filters,
                   kernel_size=kernel_size,
                   activation='relu',
                   strides=2,
                   padding='same')(x)

    # shape info needed to build decoder model
    shape = K.int_shape(x)

    # generate latent vector Q(z|X)
    x = Flatten()(x)
    x = Dense(intermediate_dim, activation='relu')(x)
    z_mean = Dense(latent_dim, name='z_mean')(x)
    z_log_var = Dense(latent_dim, name='z_log_var')(x)

    # use reparameterization trick to push the sampling out as input
    # note that "output_shape" isn't necessary with the TensorFlow backend
    z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])

    # instantiate encoder model
    encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')

    # build decoder model
    latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
    x = Dense(intermediate_dim, activation='relu')(latent_inputs)
    x = Dense(shape[1] * shape[2] * shape[3] * shape[4], activation='relu')(x)
    x = Reshape((shape[1], shape[2], shape[3], shape[4]))(x)

    for _ in range(nlayers):
        x = Conv3DTranspose(filters=filters,
                            kernel_size=kernel_size,
                            activation='relu',
                            strides=2,
                            padding='same')(x)
        filters //= 2

    # Emit as many channels as the input has so the reconstruction can be
    # compared element-wise with the input (previously hard-coded to 1).
    outputs = Conv3DTranspose(filters=channels,
                              kernel_size=kernel_size,
                              activation='sigmoid',
                              padding='same',
                              name='decoder_output')(x)

    # instantiate decoder model
    decoder = Model(latent_inputs, outputs, name='decoder')

    # instantiate VAE model: decode the sampled z (third encoder output)
    outputs = decoder(encoder(inputs)[2])
    vae = Model(inputs, outputs, name='vae')

    if disentangle:
        discriminator = Dense(1, activation='sigmoid')

        # Split points for the batch axis and the latent axis.
        half_batch = int(batch_size / 2)
        half_latent = int(latent_dim / 2)

        # z*: first half of the latent features, s*: remaining features;
        # suffix 1/2: first/second half of the batch.
        z1 = Lambda(lambda x: x[:half_batch, :half_latent])(z)
        z2 = Lambda(lambda x: x[half_batch:, :half_latent])(z)
        s1 = Lambda(lambda x: x[:half_batch, half_latent:])(z)
        s2 = Lambda(lambda x: x[half_batch:, half_latent:])(z)

        # q_bar pairs each s-part with the z-part of the *other* batch half.
        q_bar = tf.keras.layers.concatenate(
            [tf.keras.layers.concatenate([s1, z2], axis=1),
             tf.keras.layers.concatenate([s2, z1], axis=1)],
            axis=0)
        # q pairs each s-part with the z-part of the *same* batch half.
        q = tf.keras.layers.concatenate(
            [tf.keras.layers.concatenate([s1, z1], axis=1),
             tf.keras.layers.concatenate([s2, z2], axis=1)],
            axis=0)

        # (sigmoid + .1) * .85 keeps the score strictly inside (0, 1)
        # (range 0.085 .. 0.935), so the logs below stay finite.
        q_bar_score = (discriminator(q_bar) + .1) * .85
        q_score = (discriminator(q) + .1) * .85
        tc_loss = K.log(q_score / (1 - q_score))

        discriminator_loss = - K.log(q_score) - K.log(1 - q_bar_score)

    reconstruction_loss = mse(K.flatten(inputs), K.flatten(outputs))
    # NOTE(review): scaled by image_size**2 although the input is a 3D volume;
    # image_size**3 may have been intended. Left unchanged so loss magnitudes
    # stay comparable with earlier runs - confirm.
    reconstruction_loss *= image_size * image_size

    # KL term: -0.5 * sum(1 + log_var - mean^2 - exp(log_var)) over latent dims
    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = K.sum(kl_loss, axis=-1)
    kl_loss *= -0.5
    if disentangle:
        vae_loss = K.mean(reconstruction_loss) + K.mean(kl_loss) + gamma * K.mean(tc_loss) + gamma * K.mean(discriminator_loss)
    else:
        vae_loss = K.mean(reconstruction_loss) + K.mean(kl_loss)

    vae.add_loss(vae_loss)

    if opt is None:
        opt = tf.keras.optimizers.Adam(learning_rate=0.001,beta_1=0.9,beta_2=0.999,epsilon=1e-07,amsgrad=False,name='Adam')

    vae.compile(optimizer=opt)

    if disentangle:
        # NOTE(review): `metrics_tensors` is a legacy Keras attribute; confirm
        # the TensorFlow version in use still reports these - TODO.
        vae.metrics_tensors = [reconstruction_loss, kl_loss, tc_loss, discriminator_loss]
    return encoder, decoder, vae

In [3]:
# Mini-batch size; also handed to the model builder below.
batch_size = 10

# Plain (non-disentangled) 3D VAE with a 4-dimensional latent space.
# gamma only weights the extra loss terms, so it has no effect while
# disentangle=False.
encoder, decoder, vae = get_MRI_VAE_3D(
    input_shape=(64, 64, 64, 1),
    latent_dim=4,
    batch_size=batch_size,
    disentangle=False,
    gamma=100,
    kernel_size=3,
    filters=64,
    intermediate_dim=128,
    opt=None,
)

2024-02-26 18:18:12.560718: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcuda.so.1
2024-02-26 18:18:12.593792: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 0 with properties: 
pciBusID: 0000:18:00.0 name: Tesla V100-SXM2-16GB computeCapability: 7.0
coreClock: 1.53GHz coreCount: 80 deviceMemorySize: 15.77GiB deviceMemoryBandwidth: 836.37GiB/s
2024-02-26 18:18:12.593815: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2024-02-26 18:18:12.595739: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcublas.so.10
2024-02-26 18:18:12.597562: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcufft.so.10
2024-02-26 18:18:12.598060: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcurand.so

In [4]:
# Display the three model objects returned by get_MRI_VAE_3D.
encoder, decoder, vae

(<tensorflow.python.keras.engine.functional.Functional at 0x1554f369c400>,
 <tensorflow.python.keras.engine.functional.Functional at 0x1554f0089730>,
 <tensorflow.python.keras.engine.functional.Functional at 0x1554f009d130>)

In [5]:
# Print the encoder's layer table using 150-character-wide rows.
encoder.summary(line_length=150)

Model: "encoder"
______________________________________________________________________________________________________________________________________________________
Layer (type)                                     Output Shape                     Param #           Connected to                                      
encoder_input (InputLayer)                       [(None, 64, 64, 64, 1)]          0                                                                   
______________________________________________________________________________________________________________________________________________________
conv3d (Conv3D)                                  (None, 32, 32, 32, 128)          3584              encoder_input[0][0]                               
______________________________________________________________________________________________________________________________________________________
conv3d_1 (Conv3D)                                (None, 16, 16, 16, 256)     

In [6]:
# Print the decoder's layer table using 150-character-wide rows.
decoder.summary(line_length=150)

Model: "decoder"
______________________________________________________________________________________________________________________________________________________
Layer (type)                                                       Output Shape                                                Param #                
z_sampling (InputLayer)                                            [(None, 4)]                                                 0                      
______________________________________________________________________________________________________________________________________________________
dense_1 (Dense)                                                    (None, 128)                                                 640                    
______________________________________________________________________________________________________________________________________________________
dense_2 (Dense)                                                    (None, 104

In [7]:
# Print the combined VAE's layer table (encoder and decoder appear as nested models).
vae.summary(line_length=150)

Model: "vae"
______________________________________________________________________________________________________________________________________________________
Layer (type)                                     Output Shape                     Param #           Connected to                                      
encoder_input (InputLayer)                       [(None, 64, 64, 64, 1)]          0                                                                   
______________________________________________________________________________________________________________________________________________________
encoder (Functional)                             [(None, 4), (None, 4), (None, 4) 135107464         encoder_input[0][0]                               
______________________________________________________________________________________________________________________________________________________
decoder (Functional)                             (None, 64, 64, 64, 1)           