In [1]:
import scanpy as sc
import numpy as np
import os
import glob
import tqdm
import gc

In [2]:
def add_sc_cell_ids(scenario):
    """Copy the control rows' ``sc_cell_ids`` onto the non-control rows of each prediction file.

    Every ``*.h5ad`` file in ``predictions/{scenario}/`` is loaded, the rows of
    ``adata.obs`` are split on ``cond_harm_pred`` into control rows (``'ctrl'``)
    and all remaining rows, and the ``sc_cell_ids`` values of the control rows
    are written onto the remaining rows in row order. The result is saved under
    the same file name in ``predictions/{scenario}_renamed/``.

    NOTE(review): the copy is positional, i.e. it assumes the i-th control row
    belongs with the i-th non-control row — confirm against whatever produced
    the prediction files.

    Parameters
    ----------
    scenario : str
        Name of the sub-folder of ``predictions/`` whose files are processed.

    Returns
    -------
    None
        Results are written to disk. A warning is emitted when the input
        folder contains no ``.h5ad`` files.

    Raises
    ------
    AssertionError
        If a file has a different number of control and non-control rows, if a
        control row has a missing ``sc_cell_ids`` value, or if the values read
        back after the copy do not match the control values.
    """
    # Local import: keeps the function self-contained without touching the
    # notebook's import cell.
    import warnings

    # Sorted so the processing order does not depend on the filesystem.
    pred_files = sorted(glob.glob(f'predictions/{scenario}/*.h5ad'))
    save_folder = f'predictions/{scenario}_renamed/'
    os.makedirs(save_folder, exist_ok=True)

    if not pred_files:
        # A mistyped scenario name would otherwise silently produce an empty folder.
        warnings.warn(
            f"No .h5ad files found in 'predictions/{scenario}/'; nothing was written.",
            stacklevel=2,
        )
        return

    for p_f in pred_files:
        # basename handles whichever path separator glob returned on this OS.
        file_name = os.path.basename(p_f)
        adata = sc.read_h5ad(p_f)

        mask_ctrl = adata.obs['cond_harm_pred'] == 'ctrl'
        mask_stim = adata.obs['cond_harm_pred'] != 'ctrl'

        n_ctrl = int(mask_ctrl.sum())
        n_stim = int(mask_stim.sum())
        assert n_ctrl == n_stim, (
            f'{p_f}: {n_ctrl} control rows but {n_stim} non-control rows'
        )

        # .isna() covers NaN/None for both numeric and object columns, so no
        # separate numeric NaN check is needed afterwards.
        n_missing = int(adata.obs.loc[mask_ctrl, 'sc_cell_ids'].isna().sum())
        assert n_missing == 0, (
            f'{p_f}: {n_missing} control rows have a missing sc_cell_ids value'
        )

        # A plain array is assigned (not a Series) so pandas does not try to
        # align on the index; values are written in row order.
        ctrl_ids = adata.obs.loc[mask_ctrl, 'sc_cell_ids'].to_numpy()
        adata.obs.loc[mask_stim, 'sc_cell_ids'] = ctrl_ids

        # Read the values back to make sure the assignment took effect as intended.
        stim_ids = adata.obs.loc[mask_stim, 'sc_cell_ids']
        assert stim_ids.isna().sum() == 0, (
            f'{p_f}: missing sc_cell_ids remain on non-control rows after the copy'
        )
        assert np.array_equal(
            stim_ids.to_numpy(),
            adata.obs.loc[mask_ctrl, 'sc_cell_ids'].to_numpy()
        ), f'{p_f}: non-control sc_cell_ids do not match the control values'

        adata.write_h5ad(os.path.join(save_folder, file_name))

        # Release the file's data before loading the next one.
        del adata
        gc.collect()

In [3]:
# Process every prediction file under predictions/single_only/.
add_sc_cell_ids(scenario='single_only')

In [4]:
# Process every prediction file under predictions/combinatorially_seen/.
add_sc_cell_ids(scenario='combinatorially_seen')

In [None]:
# TODO: Print the path of every installed model
# TODO: Check the date of every prediction file, etc.
# TODO: Verify the GEARS outputs the same way as for the other models