In [None]:
import joblib
import sys
import pathlib
import yaml
import subprocess

import numpy as np
import tensorflow as tf
from tensorflow.keras.losses import CategoricalCrossentropy, MeanSquaredError

from buddi_data import BuDDIData

2025-04-11 14:40:28.694151: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-11 14:40:28.703596: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744404028.712751 2846567 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744404028.715552 2846567 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1744404028.723917 2846567 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [2]:
# Column names used to index the metadata carried by the loaded training data.
SAMPLE_ID_COL = 'sample_id'    # key into encode_meta and column of meta_kp (sample identity)
STIM_COL = 'stim'              # key into encode_meta for the stimulation encoding
TECH_COL = 'samp_type'         # key into encode_meta for the sample-type encoding
DATASPLIT_COL = 'isTraining'   # presumably the train/test split flag -- TODO confirm (unused in this notebook)

# Column name for gene identifiers -- presumably a column of the gene table; TODO confirm (unused in this notebook).
GENE_ID_COL = 'gene_ids'

In [3]:
# Get the root directory of the analysis repository.
# check=True makes a failing `git rev-parse` (e.g. the notebook is run outside
# a git checkout, or git is not installed) raise immediately. Without it the
# empty stdout would turn REPO_ROOT into Path('.') and a stray config.yml in
# the current working directory could be picked up unnoticed.
REPO_ROOT = pathlib.Path(
    subprocess.run(
        ["git", "rev-parse", "--show-toplevel"],
        capture_output=True,
        text=True,
        check=True,
    ).stdout.strip()
)

# The analysis configuration lives at the top level of the repository.
CONFIG_FILE = REPO_ROOT / 'config.yml'
assert CONFIG_FILE.exists(), f"Config file not found at {CONFIG_FILE}"

# safe_load only builds plain Python objects (no arbitrary object construction).
# The explicit encoding keeps parsing independent of the platform locale.
with open(CONFIG_FILE, 'r', encoding='utf-8') as file:
    config_dict = yaml.safe_load(file)

In [4]:
# Resolve the location of the buddi fork from the loaded configuration and
# fail early if that location does not exist on disk.
buddi_fork_path = pathlib.Path(config_dict['software_path']['buddi_HGSC'])
assert buddi_fork_path.exists(), f"buddi fork not found at {buddi_fork_path}"

# Put the fork first on the module search path so the `buddi` imports below
# resolve to it; these imports must therefore stay after the insert.
sys.path.insert(0, str(buddi_fork_path))
from buddi.models.buddi4 import build_buddi4
from buddi.models.components.layers import ReparameterizationLayer

I0000 00:00:1744404029.807379 2846567 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 4314 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 2070 SUPER, pci bus id: 0000:01:00.0, compute capability: 7.5


### Validation Output plot path

In [5]:
# Validation plots are written to <current working directory>/validation_output;
# the directory is created on first run and reused afterwards.
VALIDATION_OUTPUT = pathlib.Path.cwd() / 'validation_output'
VALIDATION_OUTPUT.mkdir(exist_ok=True)

### Trained model weights and training data

In [6]:
# Directory holding the saved model weight files, relative to the repository root.
TRAINED_MODELS_PATH = REPO_ROOT.joinpath('trained_models')
# Serialized training data object that is loaded with joblib further down.
TRAIN_DATA_FILE = REPO_ROOT.joinpath('processed_data', 'train_data.pkl')

### Data

In [7]:
# Fail early with a readable message, consistent with the config-file and
# buddi-fork checks earlier in the notebook.
assert TRAIN_DATA_FILE.exists(), f"Training data not found at {TRAIN_DATA_FILE}"

# NOTE(security): joblib.load unpickles the file, which can execute arbitrary
# code. Only load files produced by this project's own preprocessing.
train_data = joblib.load(TRAIN_DATA_FILE)

# Dimensions needed to rebuild the model architecture.
n_x = len(train_data.gene_names)                           # number of genes
n_y = len(train_data.cell_type_names)                      # number of cell types
# Second dimension of each encoded metadata array -- presumably the number of
# one-hot categories for that column; TODO confirm against the data class.
n_labels = train_data.encode_meta[SAMPLE_ID_COL].shape[1]
n_stims = train_data.encode_meta[STIM_COL].shape[1]
n_samp_types = train_data.encode_meta[TECH_COL].shape[1]

## Validate

In [8]:
from validation.plot_validation import plot_perturb_cell_type

from validation.utils import resample_z
from validation.perturb_cell_type import perturb_cell_type

In [9]:
# --- Settings for this validation run ---------------------------------------
RANDOM_SEED = 42                 # arbitrary fixed seed for repeatable runs
PERTURB_SAMPLE_ID = 'Samp-T89'   # sample_id whose rows are perturbed and plotted
N_SUBSAMPLES = 200               # forwarded to perturb_cell_type as n_subsamples
EXCLUDED_CELL_TYPE = 'random'    # meta 'cell_type' value left out of the perturbation targets

model_type = 'buddi4'
model_z_betas = ['1.0', '10.0', '50.0']

# Seed Python, NumPy and TensorFlow before the model is built so that a
# Restart & Run All repeats the stochastic steps below (latent resampling,
# subsampling). This assumes those helpers draw from the global generators --
# TODO confirm; GPU kernels may still introduce small nondeterminism.
tf.keras.utils.set_random_seed(RANDOM_SEED)

# Build the architecture once; the weights for each beta are loaded into the
# same model objects inside the loop.
supervised_buddi, unsupervised_buddi, supervised_decoder, unsupervised_decoder = build_buddi4(
    n_x=n_x,
    n_y=n_y,
    n_labels=n_labels,
    n_stims=n_stims,
    n_samp_types=n_samp_types,
    reconstr_loss_fn=MeanSquaredError,  # the loss function does not matter unless training further
    classifier_loss_fn=CategoricalCrossentropy,  # the loss function does not matter unless training further
    return_decoder=True,
)

# --- Inputs that do not depend on the model weights (hoisted out of the loop) ---
x_kp = train_data.X_kp
y_kp = train_data.y_kp
meta_kp = train_data.meta_kp

# Cell types to perturb: every distinct 'cell_type' value except the excluded
# one. Filtering (rather than list.remove) does not fail when that label is absent.
all_cell_types = [
    cell_type
    for cell_type in meta_kp['cell_type'].unique()
    if cell_type != EXCLUDED_CELL_TYPE
]

# Positional row indices of the chosen sample within meta_kp.
sample_idx = np.where(meta_kp[SAMPLE_ID_COL] == PERTURB_SAMPLE_ID)[0]
assert sample_idx.size > 0, f"No rows found for sample {PERTURB_SAMPLE_ID!r} in meta_kp"

for model_z_beta in model_z_betas:

    # One output directory and one pair of weight files per beta value.
    model_string = f'{model_type}_beta_{model_z_beta}'
    model_validation_plot_dir = VALIDATION_OUTPUT / model_string
    model_validation_plot_dir.mkdir(parents=True, exist_ok=True)

    supervised_buddi.load_weights(
        TRAINED_MODELS_PATH / f'{model_string}_supervised_model.keras'
    )
    unsupervised_buddi.load_weights(
        TRAINED_MODELS_PATH / f'{model_string}_unsupervised_model.keras'
    )

    # Forward Pass To Obtain Latent Space of the Pseudobulks
    # NOTE(review): Keras warns on this call that the expected input structure
    # is ['X'] while a bare tensor is received; passing [x_kp] may silence it --
    # confirm against build_buddi4 before changing.
    # Only the reconstruction and the four latent outputs are used; the
    # remaining four outputs are discarded.
    x_reconst, z_label, z_stim, z_samp_type, z_slack, _, _, _, _ = unsupervised_buddi(x_kp)

    # Cell Proportion Perturbation
    resampled_z = resample_z(
        [z_label, z_stim, z_samp_type, z_slack],
        ReparameterizationLayer(),
    )

    x_perturb, meta_perturb = perturb_cell_type(
        supervised_decoder,
        resampled_z,
        y_kp,
        meta_kp,
        idx=sample_idx,
        all_cell_types=all_cell_types,
        n_subsamples=N_SUBSAMPLES,
    )

    # Plot the unperturbed reconstruction of the selected rows against the
    # perturbed reconstructions; figures are saved, not displayed.
    plot_perturb_cell_type(
        x_basis=x_reconst.numpy()[sample_idx, :],
        meta_basis=meta_kp.iloc[sample_idx],
        x_reconst_perturb=x_perturb,
        meta_perturb=meta_perturb,
        show_plot=False,
        save_path=str(model_validation_plot_dir),
    )

  saveable.load_own_variables(weights_store.get(inner_path))
Expected: ['X']
Received: inputs=Tensor(shape=(10500, 7000))
  saveable.load_own_variables(weights_store.get(inner_path))
Expected: ['X']
Received: inputs=Tensor(shape=(10500, 7000))
  saveable.load_own_variables(weights_store.get(inner_path))
Expected: ['X']
Received: inputs=Tensor(shape=(10500, 7000))
