In [9]:
from IPython import display

import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import PIL
import tensorflow as tf
# import tensorflow_probability as tfp
import time

import numpy as np 
import pandas as pd 
import sys 

sys.path.insert(0, '/mnt/af3ff5c3-2943-4972-8c3a-6b98174779b7/Justice/OR_learning/utils')
import voxel_functions as vx

import random 
import plotly.graph_objects as go 

#### Loading voxels 

In [2]:
# Depends on the resolution may take awhile
tmaligned_df = pd.read_pickle('../../AF_files/dict_tmaligned.pkl')
voxel_list, voxel_shape, voxel_order = vx.create_voxel(tmaligned_df, resolution = 1)

100%|██████████| 5470/5470 [00:16<00:00, 330.91it/s]


In [None]:
# Visualize random voxels 
random.seed(10)
indice = random.sample(range(1,100), 5)

fig = go.Figure()
for i in indice:
    pos_space = np.argwhere(np.any(voxel_list[i] != 0, axis=3))
    fig.add_trace(go.Scatter3d(x = pos_space[:,0], 
                               y = pos_space[:,1], 
                               z = pos_space[:,2], 
                               mode = 'markers', 
                               name = voxel_order[i]))
fig.update_traces( marker=dict(size=3, opacity = 0.1))
fig.update_layout(scene = dict(xaxis = dict(visible= False,showbackground=False),
                                yaxis = dict(visible= False,showbackground=False),
                                zaxis = dict(visible= False,showbackground=False)),
                  margin=dict(r=10, l=10, b=10, t=10))
fig.show()

#### VAE

In [18]:
class CVAE(tf.keras.Model):
    def __init__(self, input_shape, latent_dim, encoder_filters, encoder_kernels, encoder_strides,
                 decoder_dense_shape, decoder_filters, decoder_kernels, decoder_strides):
        super(CVAE, self).__init__()
        self.latent_dim = latent_dim
        self.encoder = self.build_encoder(input_shape, encoder_filters, encoder_kernels, encoder_strides, latent_dim)
        self.decoder = self.build_decoder(latent_dim, decoder_dense_shape, decoder_filters, decoder_kernels, decoder_strides)

    def build_encoder(self, input_shape, filters, kernels, strides, latent_dim):
        layers = [tf.keras.layers.InputLayer(input_shape=input_shape)]
        
        for f, k, s in zip(filters, kernels, strides):
            layers.append(tf.keras.layers.Conv3D(filters=f, kernel_size=k, strides=s, activation='relu'))

        layers.append(tf.keras.layers.Flatten())
        layers.append(tf.keras.layers.Dense(latent_dim + latent_dim))  # No activation
        
        return tf.keras.Sequential(layers)

    def build_decoder(self, latent_dim, dense_shape, filters, kernels, strides):
        layers = [tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
                  tf.keras.layers.Dense(units=dense_shape, activation=tf.nn.relu),
                  tf.keras.layers.Reshape(target_shape=dense_shape)]

        for f, k, s in zip(filters, kernels, strides):
            layers.append(tf.keras.layers.Conv3DTranspose(filters=f, kernel_size=k, strides=s, padding='same', activation='relu'))
        
        layers.append(tf.keras.layers.Conv3DTranspose(filters=4, kernel_size=3, strides=1, padding='same'))  # No activation
        
        return tf.keras.Sequential(layers)
    
    
    def encode(self, x):
        mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
        return mean, logvar
    
    def reparameterize(self, mean, logvar):
        eps = tf.random.normal(shape=mean.shape)
        return eps * tf.exp(logvar * .5) + mean
    
    def decode(self, z, apply_sigmoid=False):
        logits = self.decoder(z)
        if apply_sigmoid:
            probs = tf.sigmoid(logits)
            return probs
        return logits
    
    @tf.function
    def sample(self, z=None):
        if z is None:
            z = tf.random.normal(shape=(100, self.latent_dim))
        return self.decode(z, apply_sigmoid=True)
    

def log_normal_pdf(sample, mean, logvar, raxis=1):
    log2pi = tf.math.log(2. * np.pi)
    return tf.reduce_sum(
        -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
        axis=raxis)
    
def compute_loss(model, x):
    mean, logvar = model.encode(x)
    z = model.reparameterize(mean, logvar)
    x_logit = model.decode(z)
    cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)
    logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])
    logpz = log_normal_pdf(z, 0., 0.)
    logqz_x = log_normal_pdf(z, mean, logvar)
    return -tf.reduce_mean(logpx_z + logpz - logqz_x)


@tf.function
def train_step(model, x, optimizer):
    """Executes one training step and returns the loss.

    This function computes the loss and gradients, and uses the latter to
    update the model's parameters.
    """
    with tf.GradientTape() as tape:
        loss = compute_loss(model, x)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    

In [None]:
# Example usage
input_shape = (312, 282, 404, 4)
latent_dim = 10
encoder_filters = [32, 64]
encoder_kernels = [3, 3]
encoder_strides = [(2, 2, 2), (2, 2, 2)]
decoder_dense_shape = (39, 36, 51, 32)  # Adjusted according to the encoder output
decoder_filters = [64, 32]
decoder_kernels = [3, 3]
decoder_strides = [(2, 2, 2), (2, 2, 2)]

model = CVAE(input_shape, latent_dim, encoder_filters, encoder_kernels, encoder_strides,
            decoder_dense_shape, decoder_filters, decoder_kernels, decoder_strides)

In [23]:
if tf.test.gpu_device_name() != '/device:GPU:0':
  print('WARNING: GPU device not found.')
else:
  print('SUCCESS: Found GPU: {}'.format(tf.test.gpu_device_name()))

SUCCESS: Found GPU: /device:GPU:0


2023-11-28 15:03:04.658535: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1886] Created device /device:GPU:0 with 22296 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:03:00.0, compute capability: 8.6
2023-11-28 15:03:04.659468: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1886] Created device /device:GPU:1 with 8021 MB memory:  -> device: 1, name: NVIDIA GeForce RTX 3080, pci bus id: 0000:82:00.0, compute capability: 8.6
2023-11-28 15:03:04.666760: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1886] Created device /device:GPU:0 with 22296 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:03:00.0, compute capability: 8.6
2023-11-28 15:03:04.667635: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1886] Created device /device:GPU:1 with 8021 MB memory:  -> device: 1, name: NVIDIA GeForce RTX 3080, pci bus id: 0000:82:00.0, compute capability: 8.6


In [20]:
7*7*32

1568