In [1]:
from IPython import display
import numpy as np 

import glob
# import imageio
import matplotlib.pyplot as plt
import numpy as np
# import PIL
import tensorflow as tf
# import tensorflow_probability as tfp
# import time

import pandas as pd 
import sys 

sys.path.insert(0, '/data/jlu/OR_learning/utils')
import voxel_functions as vf
import color_function as cf 

import random 
import plotly.graph_objects as go 

2025-01-22 16:59:14.943130: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-01-22 16:59:14.960145: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-01-22 16:59:14.965019: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-22 16:59:14.981818: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


#### Loading voxels 

In [2]:
# Binding-cavity coordinates, keyed by OR name (filtered on those names below).
# NOTE: pd.read_pickle can execute arbitrary code while unpickling — only point
# this at files you produced yourself / trust.
BC_CAV_PICKLE_PATH = '/data/jlu/OR_learning/files/dict_bc_cav_tmaligned.pkl'
bc_cav_coords = pd.read_pickle(BC_CAV_PICKLE_PATH)

In [3]:
# Build cavity voxels from coordinates.
# Keep only entries whose name starts with 'Or' (DL_OR-style names) and that are
# not in the exclusion list.
# NOTE(review): the list contains both 'Or2l1' (lower-case L) and 'Or2I1'
# (upper-case i) — confirm both spellings are intended.
EXCLUDE_OR_LIST = ['Or4Q3', 'Or2W25', 'Or2l1', 'Or4A67', 'Or2I1']
bc_cav_coords = {
    name: coords
    for name, coords in bc_cav_coords.items()
    if name not in EXCLUDE_OR_LIST and name.startswith('Or')
}

# Voxelize binding-cavity coordinates: one occupancy grid per remaining OR.
voxelized_cavities, voxel_shape = vf.voxelize_coordinates(list(bc_cav_coords.values()), resolution=1)

# Report (n_cavities, *grid_shape). Avoid np.array(voxelized_cavities): it would
# copy every grid into one stacked array (~1080 x 72*81*69 ≈ 4.3e8 elements)
# just to read its shape. Assumes all grids share one shape, as the stacked
# 4-D shape printed by earlier runs implies.
n_cavities = len(voxelized_cavities)
grid_shape = np.shape(voxelized_cavities[0]) if n_cavities else ()
print((n_cavities, *grid_shape))

(1080, 72, 81, 69)


In [None]:
# Visualize the voxelized cavities of a few ORs as overlaid 3-D scatter plots.

Or_name = ['Or1Ad1', 'Or2T43', 'Or2T48']

# Selection follows bc_cav_coords order; remember the names actually found so
# each legend label matches the grid it is drawn for (indexing Or_name[i] would
# mislabel traces whenever the orders differ or a name is missing).
selected_coords = {key: value for key, value in bc_cav_coords.items() if key in Or_name}
selected_names = list(selected_coords)

# Use dedicated names: assigning to `voxelized_cavities` here would silently
# replace the full dataset built in the previous cell with this 0.6-resolution preview.
plot_grids, plot_voxel_shape = vf.voxelize_coordinates(list(selected_coords.values()), resolution=0.6)

# Axes are plotted in voxel-index units (the grids above use resolution=0.6).
voxel_size = 1

fig = go.Figure()

color_map = cf.distinct_colors(list(range(len(selected_names))))
for i, (name, voxel_grid) in enumerate(zip(selected_names, plot_grids)):
    # (N, 3) indices of the occupied voxels (voxel value == 1)
    occupied_voxels = np.argwhere(voxel_grid == 1)

    x = occupied_voxels[:, 0] * voxel_size
    y = occupied_voxels[:, 1] * voxel_size
    z = occupied_voxels[:, 2] * voxel_size

    fig.add_trace(go.Scatter3d(
        x=x, y=y, z=z,
        mode='markers',
        name=name,
        marker=dict(size=3,
                    color=color_map[i],
                    opacity=0.3)
    ))

# Tight margins for the 3-D view
fig.update_layout(margin=dict(r=10, l=10, b=10, t=10))
fig.show()


#### VAE

In [18]:
class CVAE(tf.keras.Model):
    """Convolutional variational autoencoder for volumetric (3-D) inputs.

    The encoder is a stack of ``Conv3D`` layers followed by a linear dense layer
    producing ``2 * latent_dim`` values (mean and log-variance). The decoder
    maps a latent vector through a dense layer, reshapes it to
    ``decoder_dense_shape`` and upsamples with ``Conv3DTranspose`` layers,
    ending in a linear (logit) layer with as many channels as the input.

    Args:
        input_shape: Shape of one sample, channels last, e.g. ``(D, H, W, C)``.
        latent_dim: Size of the latent space.
        encoder_filters, encoder_kernels, encoder_strides: Per-layer Conv3D
            settings; they are zipped, so the shortest list sets the depth.
        decoder_dense_shape: Shape the decoder's dense output is reshaped to,
            e.g. ``(d, h, w, c)``.
        decoder_filters, decoder_kernels, decoder_strides: Per-layer
            Conv3DTranspose settings (``padding='same'``, so each layer
            multiplies the spatial dims by its stride).
    """

    def __init__(self, input_shape, latent_dim, encoder_filters, encoder_kernels, encoder_strides,
                 decoder_dense_shape, decoder_filters, decoder_kernels, decoder_strides):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = self.build_encoder(input_shape, encoder_filters, encoder_kernels, encoder_strides, latent_dim)
        # Reconstruct as many channels as the input has, so logits and labels
        # line up in the reconstruction cross-entropy.
        self.decoder = self.build_decoder(latent_dim, decoder_dense_shape, decoder_filters, decoder_kernels,
                                          decoder_strides, out_channels=input_shape[-1])

    def build_encoder(self, input_shape, filters, kernels, strides, latent_dim):
        """Conv3D stack (Keras default 'valid' padding, ReLU) -> Flatten -> Dense(2 * latent_dim)."""
        layers = [tf.keras.layers.InputLayer(input_shape=input_shape)]

        for f, k, s in zip(filters, kernels, strides):
            layers.append(tf.keras.layers.Conv3D(filters=f, kernel_size=k, strides=s, activation='relu'))

        layers.append(tf.keras.layers.Flatten())
        layers.append(tf.keras.layers.Dense(latent_dim + latent_dim))  # No activation: mean and logvar

        return tf.keras.Sequential(layers)

    def build_decoder(self, latent_dim, dense_shape, filters, kernels, strides, out_channels=4):
        """Dense -> Reshape(dense_shape) -> Conv3DTranspose stack -> linear logits.

        Args:
            latent_dim: Size of the latent input vector.
            dense_shape: Shape the dense output is reshaped to.
            filters, kernels, strides: Per-layer Conv3DTranspose settings.
            out_channels: Channels of the final logit layer (default 4, the
                previously hard-coded value).
        """
        dense_shape = tuple(dense_shape)
        layers = [tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
                  # Dense needs an integer width: the element count of dense_shape.
                  tf.keras.layers.Dense(units=int(np.prod(dense_shape)), activation=tf.nn.relu),
                  tf.keras.layers.Reshape(target_shape=dense_shape)]

        for f, k, s in zip(filters, kernels, strides):
            layers.append(tf.keras.layers.Conv3DTranspose(filters=f, kernel_size=k, strides=s, padding='same', activation='relu'))

        layers.append(tf.keras.layers.Conv3DTranspose(filters=out_channels, kernel_size=3, strides=1, padding='same'))  # No activation: logits

        return tf.keras.Sequential(layers)

    def encode(self, x):
        """Return ``(mean, logvar)``: the encoder output split in half along axis 1."""
        mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        """Sample ``z = mean + exp(logvar / 2) * eps`` with ``eps ~ N(0, I)``."""
        # tf.shape (dynamic) rather than mean.shape (static): inside tf.function
        # the batch dimension is often None, which tf.random.normal rejects.
        eps = tf.random.normal(shape=tf.shape(mean), dtype=mean.dtype)
        return eps * tf.exp(logvar * .5) + mean

    def decode(self, z, apply_sigmoid=False):
        """Decode latent vectors to logits, or to probabilities if ``apply_sigmoid``."""
        logits = self.decoder(z)
        if apply_sigmoid:
            probs = tf.sigmoid(logits)
            return probs
        return logits

    @tf.function
    def sample(self, z=None):
        """Decode ``z`` (default: 100 draws from N(0, I)) to probabilities."""
        if z is None:
            z = tf.random.normal(shape=(100, self.latent_dim))
        return self.decode(z, apply_sigmoid=True)
    

def log_normal_pdf(sample, mean, logvar, raxis=1):
    """Log-density of a diagonal Gaussian, summed over ``raxis``.

    Args:
        sample: Points at which to evaluate the density.
        mean: Gaussian mean (broadcastable against ``sample``).
        logvar: Log-variance (broadcastable against ``sample``).
        raxis: Axis (or axes) to sum the per-dimension log-densities over.

    Returns:
        ``sum(-0.5 * ((sample - mean)^2 / var + logvar + log(2*pi)))`` over ``raxis``.
    """
    log_two_pi = tf.math.log(2. * np.pi)
    inv_var = tf.exp(-logvar)
    per_dim = -.5 * ((sample - mean) ** 2. * inv_var + logvar + log_two_pi)
    return tf.reduce_sum(per_dim, axis=raxis)
    
def compute_loss(model, x):
    """Single-sample Monte Carlo estimate of the negative ELBO, averaged over the batch.

    Args:
        model: A CVAE-like object providing ``encode``, ``reparameterize`` and ``decode``.
        x: Input batch (batch on axis 0), used as sigmoid cross-entropy targets,
            so values are expected in [0, 1].

    Returns:
        Scalar tensor ``-mean(log p(x|z) + log p(z) - log q(z|x))``.
    """
    mean, logvar = model.encode(x)
    z = model.reparameterize(mean, logvar)
    x_logit = model.decode(z)
    # sigmoid_cross_entropy_with_logits requires labels and logits to share a dtype.
    labels = tf.cast(x, x_logit.dtype)
    cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=labels)
    # Sum over every non-batch axis. A fixed [1, 2, 3] only covers 4-D image
    # tensors; for 5-D voxel batches it leaves a channel axis that then
    # mis-broadcasts against the per-sample latent terms below.
    logpx_z = -tf.reduce_sum(cross_ent, axis=tf.range(1, tf.rank(cross_ent)))
    logpz = log_normal_pdf(z, 0., 0.)
    logqz_x = log_normal_pdf(z, mean, logvar)
    return -tf.reduce_mean(logpx_z + logpz - logqz_x)


@tf.function
def train_step(model, x, optimizer):
    """Executes one training step and returns the loss.

    This function computes the loss and gradients, and uses the latter to
    update the model's parameters.

    Args:
        model: Model passed to ``compute_loss``; its ``trainable_variables`` are updated.
        x: Input batch.
        optimizer: Keras optimizer used to apply the gradients.

    Returns:
        The scalar loss, computed before the parameter update.
    """
    with tf.GradientTape() as tape:
        loss = compute_loss(model, x)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss
    

In [None]:
# Example configuration.
# NOTE(review): these shapes do not fit together as written —
#   * the decoder has two stride-2 Conv3DTranspose layers with padding='same',
#     so it outputs (39*4, 36*4, 51*4) = (156, 144, 204) spatially, not the
#     input's (312, 282, 404); the reconstruction loss needs them equal.
#   * the voxel grids printed earlier are (72, 81, 69) with no channel axis,
#     while input_shape here expects 4 channels.
# TODO: derive input_shape / decoder_dense_shape from the actual voxel grids.
input_shape = (312, 282, 404, 4)
latent_dim = 10

encoder_filters = [32, 64]
encoder_kernels = [3] * 2
encoder_strides = [(2, 2, 2)] * 2

decoder_dense_shape = (39, 36, 51, 32)  # shape of the reshaped dense output fed to the decoder convs
decoder_filters = encoder_filters[::-1]  # mirror of the encoder: [64, 32]
decoder_kernels = [3] * 2
decoder_strides = [(2, 2, 2)] * 2

model = CVAE(
    input_shape=input_shape,
    latent_dim=latent_dim,
    encoder_filters=encoder_filters,
    encoder_kernels=encoder_kernels,
    encoder_strides=encoder_strides,
    decoder_dense_shape=decoder_dense_shape,
    decoder_filters=decoder_filters,
    decoder_kernels=decoder_kernels,
    decoder_strides=decoder_strides,
)

In [23]:
# Report whether TensorFlow can see a GPU. Query once: every
# tf.test.gpu_device_name() call appears to re-initialise the devices (the
# original double call produced duplicated "Created device" log lines).
gpu_name = tf.test.gpu_device_name()
if gpu_name != '/device:GPU:0':
    print('WARNING: GPU device not found.')
else:
    print('SUCCESS: Found GPU: {}'.format(gpu_name))

SUCCESS: Found GPU: /device:GPU:0


2023-11-28 15:03:04.658535: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1886] Created device /device:GPU:0 with 22296 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:03:00.0, compute capability: 8.6
2023-11-28 15:03:04.659468: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1886] Created device /device:GPU:1 with 8021 MB memory:  -> device: 1, name: NVIDIA GeForce RTX 3080, pci bus id: 0000:82:00.0, compute capability: 8.6
2023-11-28 15:03:04.666760: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1886] Created device /device:GPU:0 with 22296 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:03:00.0, compute capability: 8.6
2023-11-28 15:03:04.667635: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1886] Created device /device:GPU:1 with 8021 MB memory:  -> device: 1, name: NVIDIA GeForce RTX 3080, pci bus id: 0000:82:00.0, compute capability: 8.6


In [20]:
32 * 7 * 7  # scratch arithmetic (= 1568); TODO: remove or explain what size this checks

1568