In [None]:
import gc
import joblib
import sys
import pathlib
import yaml
import subprocess
import re
from collections import defaultdict

import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
import umap.umap_ as umap
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from sklearn.preprocessing import OneHotEncoder, MinMaxScaler
from matplotlib_venn import venn2, venn3
import seaborn as sns
import tensorflow as tf
from tensorflow.keras.utils import plot_model

from buddi_data import BuDDIData

In [None]:
# Names of the metadata fields this notebook refers to.
# Only SAMPLE_ID_COL, STIM_COL and TECH_COL are used in the cells below, where
# they index `train_data.encode_meta` to size the model's classifier heads.
CELL_TYPE_COL = 'encode_celltype'  # presumably the cell-type annotation column — unused in the visible cells, confirm
SAMPLE_ID_COL = 'sample_id'  # encode_meta key whose shape[1] becomes n_labels
STIM_COL = 'stim'  # encode_meta key whose shape[1] becomes n_stims
TECH_COL = 'samp_type'  # encode_meta key whose shape[1] becomes n_samp_types
DATASPLIT_COL = 'isTraining'  # presumably a train/held-out flag column — unused in the visible cells, confirm

GENE_ID_COL = 'gene_ids'  # presumably the gene identifier column — unused in the visible cells, confirm

In [None]:
# Locate the root directory of the analysis repository by asking git.
# check=True makes a failing git call (not inside a repository, git missing
# from PATH, ...) raise subprocess.CalledProcessError right here. Without it
# the empty stdout would turn into pathlib.Path(''), i.e. the current working
# directory, and the notebook would silently look for config.yml there.
REPO_ROOT = subprocess.run(
    ["git", "rev-parse", "--show-toplevel"],
    capture_output=True,
    text=True,
    check=True,
).stdout.strip()
REPO_ROOT = pathlib.Path(REPO_ROOT)

# The repository-level configuration file must exist before anything else runs.
CONFIG_FILE = REPO_ROOT / 'config.yml'
assert CONFIG_FILE.exists(), f"Config file not found at {CONFIG_FILE}"

# safe_load only builds plain Python objects from the YAML document.
with open(CONFIG_FILE, 'r') as file:
    config_dict = yaml.safe_load(file)

In [None]:
# Resolve the local checkout of the buddi fork from the loaded configuration
# and fail early if the configured location does not exist.
buddi_fork_path = pathlib.Path(config_dict['software_path']['buddi_HGSC'])
assert buddi_fork_path.exists(), f"buddi fork not found at {buddi_fork_path}"

# Temporary workaround: the fork is placed at the front of sys.path so the
# `buddi` imports below can be resolved from it. Once the active modifications
# to the fork are finished this will become a proper installation + import.
sys.path.insert(0, str(buddi_fork_path))
from buddi.models.buddi4 import build_buddi4, fit_buddi4_v2
from buddi.dataset.dataset import get_dataset
from buddi.dataset.utils import train_validation_split
from buddi.plotting.plot_latent_space import plot_latent_spaces_buddi4

### Output directories for trained model weights and training plots

In [None]:
# Output locations inside the repository: saved models go directly under
# `trained_models`, training figures under its `train_plots` subfolder.
TRAINED_MODELS_PATH = REPO_ROOT / 'trained_models'
TRAIN_PLOT_PATH = TRAINED_MODELS_PATH / 'train_plots'

# Create the parent before the child; existing directories are left untouched.
for _out_dir in (TRAINED_MODELS_PATH, TRAIN_PLOT_PATH):
    _out_dir.mkdir(exist_ok=True)

In [None]:
# Load the prepared training data object from disk.
# The path is relative, so the file is looked up in the kernel's current
# working directory rather than relative to REPO_ROOT.
# NOTE(review): joblib.load unpickles arbitrary Python objects — only load
# files that were produced by a trusted step of this pipeline.
# NOTE(review): presumably the pickle holds a BuDDIData instance (hence the
# `from buddi_data import BuDDIData` import at the top) — confirm.
train_data = joblib.load("train_data.pkl")

## Training

In [None]:
# Model dimensions, all read off the loaded training data.
n_x = len(train_data.gene_names)  # one input feature per gene name
n_y = len(train_data.cell_type_names)  # one output per cell type name

# Size of the second axis of each encoded metadata entry
# (presumably the number of one-hot categories — confirm).
n_labels, n_stims, n_samp_types = (
    train_data.encode_meta[meta_key].shape[1]
    for meta_key in (SAMPLE_ID_COL, STIM_COL, TECH_COL)
)

In [None]:
# Output entries requested from both datasets; they differ only in the final
# element ('Y_prop' for the supervised set, 'Y_dummy' for the unsupervised one).
_common_outputs = [
    'X', 'z_label', 'z_stim', 'z_samp_type', 'z_slack',
    'label', 'stim', 'samp_type',
]

# Supervised dataset: built from the `*_kp` arrays and fed the proportions
# (`Y_prop`) as a second input.
# NOTE(review): `_kp` / `_unkp` presumably mean known / unknown proportions — confirm.
ds_sup = get_dataset(
    input_tuple_order=['X', 'Y_prop'],
    output_tuple_order=_common_outputs + ['Y_prop'],
    X=train_data.X_kp,
    Y_prop=train_data.y_kp,
    label=train_data.label_kp,
    stim=train_data.drug_kp,
    samp_type=train_data.bulk_kp,
)
print(ds_sup.cardinality().numpy())

# Unsupervised dataset: built from the `*_unkp` arrays, with `X` as the only
# input and no proportions passed in.
ds_unsup = get_dataset(
    input_tuple_order=['X'],
    output_tuple_order=_common_outputs + ['Y_dummy'],
    X=train_data.X_unkp,
    label=train_data.label_unkp,
    stim=train_data.drug_unkp,
    samp_type=train_data.bulk_unkp,
)
print(ds_unsup.cardinality().numpy())

In [None]:
from tensorflow.keras.losses import CategoricalCrossentropy, MeanSquaredError
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K
from buddi_visualize_train import plot_train_loss, plot_buddi4_latent_space

# Training hyper-parameters, named so they can be tuned in one place.
SEED = 42
LEARNING_RATE = 0.0005
N_EPOCHS = 5
BATCH_SIZE = 16

# KL weights to sweep; each value is applied to the label, stim and samp_type
# KL terms alike. 1.0, 10.0 and 100.0 are currently left out of the sweep.
betas = [20.0, 50.0]

# Train one supervised/unsupervised BuDDI4 model pair per beta, save the loss
# curve and latent-space plots to TRAIN_PLOT_PATH, save both models to
# TRAINED_MODELS_PATH, then release the models before the next iteration.
for beta in betas:

    model_str = f"buddi4_beta_{beta}"

    # Re-seed the Python, NumPy and TensorFlow RNGs before every build so each
    # beta starts from the same random state and a re-run of this cell is
    # repeatable (bit-exact results may still depend on the hardware/ops used).
    tf.keras.utils.set_random_seed(SEED)

    optimizer = Adam(learning_rate=LEARNING_RATE)
    supervised_buddi, unsupervised_buddi = build_buddi4(
            n_x=n_x,
            n_y=n_y,
            n_labels=n_labels,
            n_stims=n_stims,
            n_samp_types=n_samp_types,
            reconstr_loss_fn = MeanSquaredError,
            classifier_loss_fn = CategoricalCrossentropy,
            # Loss weights held constant across the sweep
            # (originally labelled as the defaults).
            alpha_x = 1.0,
            alpha_label = 100.0,
            alpha_stim = 100.0,
            alpha_samp_type = 100.0,
            alpha_prop = 100.0,
            beta_kl_slack = 0.1,
            # Only these three KL weights change between iterations.
            beta_kl_label = beta,
            beta_kl_stim = beta,
            beta_kl_samp_type = beta,
            optimizer=optimizer,
        )

    all_loss_df = fit_buddi4_v2(
                supervised_buddi, unsupervised_buddi,
                ds_sup, ds_unsup,
                epochs=N_EPOCHS, batch_size=BATCH_SIZE, prefetch=True
            )

    # Remember which figures were open before plotting so that only the ones
    # created in this iteration are closed afterwards.
    figs_before = set(plt.get_fignums())

    plot_train_loss(
        all_loss_df,
        show_plot=False,
        save_path=TRAIN_PLOT_PATH / f"{model_str}_loss.png")

    _ = plot_buddi4_latent_space(
        unsupervised_buddi,
        train_data,
        type='PCA',
        palette='tab20',
        show_plot=False,
        save_path=TRAIN_PLOT_PATH / f"{model_str}_latent_space_PCA.png"
    )

    _ = plot_buddi4_latent_space(
        unsupervised_buddi,
        train_data,
        type='UMAP',
        palette='tab20',
        show_plot=False,
        save_path=TRAIN_PLOT_PATH / f"{model_str}_latent_space_UMAP.png"
    )

    # The plots are requested with show_plot=False and written to disk, so any
    # figure still registered with pyplot would only accumulate across betas.
    for fig_num in set(plt.get_fignums()) - figs_before:
        plt.close(fig_num)

    supervised_buddi.save(
        TRAINED_MODELS_PATH / f"{model_str}_supervised_model.keras"
    )
    unsupervised_buddi.save(
        TRAINED_MODELS_PATH / f"{model_str}_unsupervised_model.keras"
    )

    # Drop the references to this iteration's models and optimizer, then reset
    # the Keras backend state and run the garbage collector before the next beta.
    del supervised_buddi
    del unsupervised_buddi
    del optimizer
    K.clear_session()
    gc.collect()
