In [1]:
import gc
import joblib
import sys
import pathlib
import yaml
import subprocess
import re
from collections import defaultdict

import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
import umap.umap_ as umap
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from sklearn.preprocessing import OneHotEncoder, MinMaxScaler
from matplotlib_venn import venn2, venn3
import seaborn as sns
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.losses import CategoricalCrossentropy, MeanSquaredError


from buddi_data import BuDDIData

2025-04-14 13:53:46.194112: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-14 13:53:46.203009: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744660426.212162  136857 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744660426.214940  136857 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1744660426.222971  136857 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [2]:
# Metadata column names shared by the cells below.
SAMPLE_ID_COL = "sample_id"   # key into train_data.encode_meta; also used to color validation plots
STIM_COL = "stim"             # key into train_data.encode_meta
TECH_COL = "samp_type"        # key into train_data.encode_meta
DATASPLIT_COL = "isTraining"  # not referenced in the visible cells; presumably the train/test flag

# Gene identifier column name (not referenced in the visible cells).
GENE_ID_COL = "gene_ids"

In [3]:
# Get the root directory of the analysis repository by asking git.
# `check=True` makes a failing git call (for example when the notebook is run
# outside a git work tree) raise CalledProcessError. Without it, stdout is
# empty, Path('') collapses to '.', and the config would silently be looked up
# relative to the current working directory.
REPO_ROOT = subprocess.run(
    ["git", "rev-parse", "--show-toplevel"],
    capture_output=True,
    text=True,
    check=True,
).stdout.strip()
REPO_ROOT = pathlib.Path(REPO_ROOT)

# The project configuration is read from `config.yml` at the repository root.
CONFIG_FILE = REPO_ROOT / 'config.yml'
assert CONFIG_FILE.exists(), f"Config file not found at {CONFIG_FILE}"

# safe_load only constructs plain YAML data types.
with open(CONFIG_FILE, 'r') as file:
    config_dict = yaml.safe_load(file)

In [4]:
# Resolve the local checkout of the BuDDI fork from the project config.
buddi_fork_path = pathlib.Path(config_dict['software_path']['buddi_HGSC'])
assert buddi_fork_path.exists(), f"buddi fork not found at {buddi_fork_path}"

# Put the fork first on the import path; the `buddi` imports below must stay
# after this line so that they resolve against the fork.
sys.path.insert(0, str(buddi_fork_path))
from buddi.models.buddi4 import build_buddi4
from buddi.models.components.layers import ReparameterizationLayer

I0000 00:00:1744660427.293403  136857 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 5758 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 2070 SUPER, pci bus id: 0000:01:00.0, compute capability: 7.5


### Validation Output plot path

In [5]:
# Validation plots are written to `validation_output/` under the current
# working directory; the directory is created on first use.
VALIDATION_OUTPUT = pathlib.Path.cwd() / 'validation_output'
VALIDATION_OUTPUT.mkdir(exist_ok=True)

### Trained model weights and training data

In [6]:
# Locations, relative to the repository root, of the trained model weight
# files and of the serialized training data.
TRAINED_MODELS_PATH = REPO_ROOT.joinpath('trained_models')
TRAIN_DATA_FILE = REPO_ROOT.joinpath('processed_data', 'train_data.pkl')

### Data

In [7]:
# Load the training data object. joblib deserializes via pickle, so only load
# files produced by this project.
train_data = joblib.load(TRAIN_DATA_FILE)

# Model dimensions taken from the loaded data.
n_x = len(train_data.gene_names)       # number of gene names
n_y = len(train_data.cell_type_names)  # number of cell type names

# Second-axis size of each encoded metadata array (presumably the number of
# one-hot categories per column -- confirm against BuDDIData).
n_labels, n_stims, n_samp_types = (
    train_data.encode_meta[col].shape[1]
    for col in (SAMPLE_ID_COL, STIM_COL, TECH_COL)
)

## Validate

In [8]:
from validation.resampling import ResamplingDecoder, compute_reconstruction_corr
from validation.plot_validation import plot_resampled_latent_space, plot_correlation_boxplot

In [9]:
# Instantiate the BuDDI4 models and their decoders; the next cell loads trained
# weights into them. The loss functions passed here only matter if the models
# are trained further.
(
    supervised_buddi,
    unsupervised_buddi,
    supervised_decoder,
    unsupervised_decoder,
) = build_buddi4(
    n_x=n_x,
    n_y=n_y,
    n_labels=n_labels,
    n_stims=n_stims,
    n_samp_types=n_samp_types,
    reconstr_loss_fn=MeanSquaredError,
    classifier_loss_fn=CategoricalCrossentropy,
    return_decoder=True,
)

n_resamples = 10

# Wrap the supervised decoder for resampling. Every latent branch is declared
# with size 64 (presumably this has to match the latent dimension the models
# were trained with -- confirm).
resampling_decoder = ResamplingDecoder(
    supervised_decoder,
    n_y=n_y,
    z_shape=dict.fromkeys(('label', 'stim', 'samp_type', 'slack'), 64),
    _reparam_layer=ReparameterizationLayer(),
)

In [None]:
model_type = 'buddi4'
model_z_betas = ['1.0', '10.0', '50.0', '100.0']

# --- Loop-invariant setup -------------------------------------------------
# Nothing in this section depends on which weights are loaded, so it is built
# once instead of once per model.
n_resamples = 10

x_kp = train_data.X_kp
meta_kp = train_data.meta_kp

# Metadata for the original ("truth") rows followed by n_resamples stacked
# copies for the resampled rows; the row order matches x_concat built below.
meta_truth = meta_kp.copy()
meta_truth['expression'] = 'truth'
meta_resampled = pd.concat([meta_kp] * n_resamples, axis=0).reset_index(drop=True)
meta_resampled['expression'] = 'resampled'
meta_concat = pd.concat([meta_truth, meta_resampled], axis=0).reset_index(drop=True)

# --- Per-model validation -------------------------------------------------
for model_z_beta in model_z_betas:

    model_string = f'{model_type}_beta_{model_z_beta}'
    model_validation_plot_dir = VALIDATION_OUTPUT / model_string
    model_validation_plot_dir.mkdir(exist_ok=True)

    # Load the trained weights for this beta into the already-built models.
    supervised_buddi.load_weights(
        TRAINED_MODELS_PATH / f'{model_string}_supervised_model.keras'
    )
    unsupervised_buddi.load_weights(
        TRAINED_MODELS_PATH / f'{model_string}_unsupervised_model.keras'
    )

    # Forward pass to obtain the latent space of the pseudobulks. Only the four
    # latent outputs and the last output (y_hat) are needed for resampling.
    _, z_label, z_stim, z_samp_type, z_slack, _, _, _, y_hat = unsupervised_buddi(x_kp)

    # Decode n_resamples times from the same latent inputs. Decoding is pinned
    # to the CPU as in the original cell (presumably to limit GPU memory use).
    with tf.device('/CPU:0'):
        xs_resampled = [
            resampling_decoder(
                (y_hat, z_label, z_stim, z_samp_type, z_slack)
            ).numpy()
            for _ in range(n_resamples)
        ]

    x_resampled = np.concatenate(xs_resampled, axis=0)
    x_concat = np.concatenate([x_kp, x_resampled], axis=0)

    # A fresh copy of the metadata is passed on every iteration, so each model
    # starts from the same frame even if the plotting helper modifies its input.
    plot_resampled_latent_space(
        x_concat,
        meta_concat.copy(),
        color_by=[
            SAMPLE_ID_COL,
            'cell_prop_type',
            'expression'
        ],
        use_umap=True,
        panel_width=5,
        show_plot=False,
        save_path=model_validation_plot_dir / f'{model_string}_resample_expression.png'
    )

    pearson_r = compute_reconstruction_corr(
        x_resampled,
        x_kp,
        n_resamples=n_resamples,
        method='pearson',
        n_jobs=-1
    )
    # NOTE(review): the keyword `coorrelation_method` is kept exactly as the
    # original call spells it; confirm against the helper's signature.
    plot_correlation_boxplot(
        pearson_r,
        coorrelation_method='pearson',
        show_plot=False,
        save_path=model_validation_plot_dir / f'{model_string}_resample_pearson_correlation.png'
    )

    # Release any figures still open after saving so they do not accumulate
    # across models (a no-op if the plotting helpers already close them).
    plt.close('all')

  saveable.load_own_variables(weights_store.get(inner_path))
Expected: ['X']
Received: inputs=Tensor(shape=(10500, 7000))
  warn(
  saveable.load_own_variables(weights_store.get(inner_path))
Expected: ['X']
Received: inputs=Tensor(shape=(10500, 7000))
  warn(
  saveable.load_own_variables(weights_store.get(inner_path))
Expected: ['X']
Received: inputs=Tensor(shape=(10500, 7000))
  warn(
  saveable.load_own_variables(weights_store.get(inner_path))
Expected: ['X']
Received: inputs=Tensor(shape=(10500, 7000))
  warn(


: 