In [None]:
import sys
sys.path.append("../")
from GOCVAE.models._GOCVAE import GOCVAE 
from GOCVAE.metrics import *
from GOCVAE.utils import *
import pickle

# Explicit third-party imports come after the star imports so these names are
# guaranteed to be the real libraries rather than whatever the stars export.
import numpy as np
import scanpy as sc
from scipy.sparse import issparse

## Loading and preparing data

In [None]:
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    NOTE(review): pickle can execute arbitrary code on load — only use this on
    split files from a trusted source.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Expression matrix plus per-cell / per-gene annotations.
adata = sc.read("../data/GEARS/norman/perturb_processed.h5ad")

# Condition split (presumably set name -> list of conditions; TODO confirm keys).
split_path = "../data/GEARS/norman/norman_simulation_1_0.75.pkl"
set2conditions = _load_pickle(split_path)

# Test-set subgroups; later cells read subgroup['test_subgroup'].
subgroup_path = "../data/GEARS/norman/norman_simulation_1_0.75_subgroup.pkl"
subgroup = _load_pickle(subgroup_path)

In [None]:
# Test subgroups that are evaluated, in reporting order.
EVAL_SUBGROUPS = ['combo_seen2', 'combo_seen1', 'combo_seen0', 'unseen_single']
test_subgroups = subgroup['test_subgroup']

# Union of the evaluated subgroups, deduplicated with dict.fromkeys rather than
# set(): set iteration order for strings changes between kernel runs (hash
# randomization), which would change the order -- and the random control
# sampling -- of the prediction loop further down.
test_conditions_eval = list(dict.fromkeys(
    condition
    for subgroup_name in EVAL_SUBGROUPS
    if subgroup_name in test_subgroups
    for condition in test_subgroups[subgroup_name]
))

# Number of distinct evaluated conditions per subgroup (subgroups absent from
# the split file are not reported).
print("Subgroup Distribution:")
for sg_name in EVAL_SUBGROUPS:
    if sg_name in test_subgroups:
        members = set(test_subgroups[sg_name])
        count = sum(p in members for p in test_conditions_eval)
        print(f"  {sg_name}: {count}")

In [None]:
# Every condition listed in any test subgroup is held out of training.
all_test_conditions = [
    condition
    for pert_list in subgroup['test_subgroup'].values()
    for condition in pert_list
]

is_test_cell = adata.obs['condition'].isin(all_test_conditions)
adata_train = adata[~is_test_cell].copy()

In [None]:
condition_key = "condition"

# All condition labels in the full dataset (train and test), in order of first
# appearance; passed to the model constructor below.
conditions = adata.obs[condition_key].unique().tolist()
n_conditions = len(conditions)

## Creating the model

In [None]:
# Preserve the original var index in `ensembl_id` (presumably Ensembl IDs, as
# the column name suggests -- TODO confirm), then index genes by symbol when a
# `gene_name` column exists.
# Both steps are guarded so the cell is idempotent: on a re-run the index
# already holds gene symbols and must not overwrite `ensembl_id`.
for ad in (adata, adata_train):
    if 'ensembl_id' not in ad.var.columns:
        ad.var['ensembl_id'] = ad.var.index.copy()
    if 'gene_name' in ad.var.columns:
        ad.var_names = ad.var['gene_name'].values

# Gene labels handed to the model, in var order.
if 'gene_name' in adata.var.columns:
    gene_names_for_model = adata.var['gene_name'].tolist()
else:
    gene_names_for_model = adata.var_names.tolist()

In [None]:
# Model hyperparameters, kept in one place so a run's configuration can be
# inspected or logged. Meanings are defined by GOCVAE; see its signature.
gocvae_hparams = dict(
    # network shape / optimisation
    architecture=[256, 64],
    n_topic=50,
    learning_rate=0.0005,
    dropout_rate=0.2,
    # loss configuration
    alpha=0.0001,
    eta=100,
    topk=5,
    loss_fn='sse',
    output_activation='relu',
    pawine_lambda_d=500,
    # paths
    model_path='../results/models/GOCVAE-norman/',
    data_path='../data/GEARS/',
    # GO-graph options
    go_similarity_threshold=0.15,
    go_graph_path=None,
    num_go_gnn_layers=1,
    # GEARS-related options and the obs column names they read
    use_gears_embedding=True,
    gears_embedding_dim=64,
    ctrl_key='control',
    pert_key='perturbation',
    nperts_key='nperts',
    gears_gamma=2,
    gears_direction_lambda=1e-2,
    gears_lambda=20000.0,
)

# Data-dependent arguments come from the full dataset so every condition and
# gene (train and test) is known to the model.
network = GOCVAE(
    gene_size=adata.shape[1],
    gene_names=gene_names_for_model,
    conditions=conditions,
    **gocvae_hparams,
)

### Training GOCVAE

In [None]:
# Training schedule; option semantics are defined by GOCVAE.train.
train_kwargs = dict(
    train_size=0.8,
    n_epochs=300,
    batch_size=256,
    early_stop_limit=50,
    lr_reducer=20,
    verbose=5,
    save=False,
    retrain=True,
)

# Only non-test conditions (adata_train) are used for fitting.
network.train(adata_train, condition_key, **train_kwargs)

## Making predictions

In [None]:
# For each evaluation condition, resample as many control cells as there are
# observed cells for that condition, predict the perturbed profile from them,
# and keep the observed cells as ground truth.
PREDICTION_SEED = 0  # fixes which control cells are drawn for each condition

all_predictions = {}
all_ground_truth = {}

ctrl_adata = adata[adata.obs['control'] == 1].copy()
if ctrl_adata.n_obs == 0:
    raise ValueError("No control cells found (adata.obs['control'] == 1).")

# Seeded generator plus a sorted iteration order make the resampled control
# cells -- and therefore the predictions -- identical across kernel restarts.
rng = np.random.default_rng(PREDICTION_SEED)
eval_conditions = sorted(test_conditions_eval)

for i, condition in enumerate(eval_conditions):
    print(f"\nPredicting {i+1}/{len(eval_conditions)}: {condition}")

    true_adata = adata[adata.obs[condition_key] == condition].copy()
    n_cells = true_adata.n_obs
    if n_cells == 0:
        # Nothing to compare against and nothing to size the control sample by.
        print(f"  Skipping {condition}: no observed cells in adata")
        continue

    # Sample with replacement so the control sample matches n_cells even when
    # there are fewer control cells than perturbed cells.
    sampled_idx = rng.choice(ctrl_adata.n_obs, size=n_cells, replace=True)
    ctrl_sample = ctrl_adata[sampled_idx].copy()

    adata_pred = network.predict(
        adata=ctrl_sample,
        condition_key=condition_key,
        target_condition=condition,
        ctrl_key='control',
    )
    adata_pred.obs[condition_key] = condition

    all_predictions[condition] = adata_pred
    all_ground_truth[condition] = true_adata