In [None]:
# TODO: Adapt this for two-gene perts such that we are only testing the two-gene perts

In [53]:
import os.path

import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm import tqdm
import scanpy as sc
from src.utils.spectra.dataset import SpectraDataset
from src.utils.spectra.perturb import Spectra
from gears import PertData
import pickle as pkl

In [54]:
class PerturbGraphData(SpectraDataset):
    def parse(self, pert_data):
        if isinstance(pert_data, PertData):
            self.adata = pert_data.adata
        else:
            self.adata = pert_data
        self.control_expression = self.adata[self.adata.obs['condition'] == 'ctrl'].X.toarray().mean(axis=0)
        return [p for p in self.adata.obs['condition'].unique() if p != 'ctrl']

    def get_mean_logfold_change(self, perturbation):
        perturbation_expression = self.adata[self.adata.obs['condition'] == perturbation].X.toarray().mean(axis=0)
        logfold_change = np.nan_to_num(np.log2(perturbation_expression + 1) - np.log2(self.control_expression + 1))
        return logfold_change

    def sample_to_index(self, sample):
        if not hasattr(self, 'index_to_sequence'):
            print("Generating index to sequence")
            self.index_to_sequence = {}
            for i in tqdm(range(len(self))):
                x = self.__getitem__(i)
                self.index_to_sequence['-'.join(list(x))] = i

        return self.index_to_sequence[sample]
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        perturbation = self.samples[idx]
        return self.get_mean_logfold_change(perturbation)

In [55]:
class SPECTRAPerturb(Spectra):
    def spectra_properties(self, sample_one, sample_two):
        return -np.linalg.norm(sample_one - sample_two)

    def cross_split_overlap(self, train, test):
        average_similarity = []

        for i in test:
            for j in train:
                average_similarity.append(self.spectra_properties(i, j))

        return np.mean(average_similarity)

In [79]:
adata = sc.read('../data/norman_2019_raw.h5ad')
data_type = 'norman_2'

In [81]:
nonzero_genes = (adata.X.sum(axis=0) > 5).A1
filtered_adata = adata[:, nonzero_genes]
adata.obs['condition'] = adata.obs['guide_ids'].cat.rename_categories({'': 'ctrl'})
# data_type = input(f"Pick the data type from ['norman_1', 'norman_2']: ")
# assert data_type in ['norman_1', 'norman_2']

data_type = "norman_2"

if data_type == "norman_1":
    single_gene_mask = [True if "," not in name else False for name in adata.obs['condition']]
    adata = filtered_adata[single_gene_mask, :]
else:
    adata.obs['condition'] = adata.obs['condition'].str.replace(',', '+')

genes = adata.var['gene_symbols'].to_list()
genes_and_ctrl = genes + ['ctrl']

# we remove the cells with perts that are not in the genes because we need gene expression values
# to generate an in-silico perturbation embedding
if data_type == "norman_1":
    adata = adata[adata.obs['condition'].isin(genes_and_ctrl), :]
else:
    conditions = adata.obs['condition']
    
    # need to account for the two-gene perturbations
    filtered_conditions = conditions.apply(
        lambda cond: cond in genes_and_ctrl or (
                '+' in cond and all(gene in genes_and_ctrl for gene in cond.split('+'))
        )
    )
    adata = adata[filtered_conditions, :]

In [82]:
perturb_graph_data = PerturbGraphData(pert_adata, 'norman')

In [83]:
all_perts_orig = [p for p in pert_adata.obs['condition'].unique() if p != 'ctrl']

all_perts_map = {i: pert for i, pert in enumerate(all_perts_orig)}

In [84]:
perts = PerturbGraphData.parse(perturb_graph_data, pert_adata)
perts

['TSC22D1',
 'MAML2',
 'CEBPE',
 'DUSP9',
 'ELMSAN1',
 'UBASH3B',
 'FOXA1',
 'BCORL1',
 'MEIS1',
 'GLB1L2',
 'KLF1',
 'BAK1',
 'FEV',
 'ATL1',
 'CBL',
 'ETS2',
 'SET',
 'TBX3',
 'LHX1',
 'SLC4A1',
 'RREB1',
 'ZNF318',
 'COL2A1',
 'ZBTB25',
 'MAP4K5',
 'SLC6A9',
 'MIDN',
 'DLX2',
 'CBFA2T3',
 'HES7',
 'AHR',
 'FOXO4',
 'RHOXF2',
 'SPI1',
 'RUNX1T1',
 'S1PR2',
 'POU3F2',
 'CNN1',
 'CELF2',
 'MAP2K3',
 'MAP4K3',
 'SAMD1',
 'CDKN1A',
 'PTPN1',
 'TBX2',
 'IER5L',
 'CEBPA',
 'PTPN12',
 'TP73',
 'MAP7D1',
 'FOSB',
 'MAPK1',
 'IRF1',
 'IKZF3',
 'HOXB9',
 'HOXC13',
 'CKS1B',
 'CLDN6',
 'FOXA3',
 'COL1A1',
 'FOXF1',
 'ZBTB1',
 'PRTG',
 'PLK4',
 'BPGM',
 'ARRDC3',
 'BCL2L11',
 'LYL1',
 'MAP2K6',
 'SGK1',
 'CDKN1B',
 'FOXL2',
 'NIT1',
 'IGDCC3',
 'OSR2',
 'HNF4A',
 'KMT2A',
 'ISL2',
 'TMSB4X',
 'KIF2C',
 'CSRNP1',
 'ARID1A',
 'CNNM4',
 'UBASH3A',
 'NCL',
 'ZC3HAV1',
 'PTPN9',
 'STIL',
 'CEBPB',
 'TGFBR2',
 'JUN',
 'ZBTB10',
 'PTPN13',
 'SLC38A2',
 'HOXA13',
 'SNAI1',
 'CITED1',
 'PRDM1',
 'HK2',
 

In [9]:
sp_ids = ["0.00", "0.10", "0.20", "0.30", "0.40", "0.50", "0.60", "0.70"]
replicates = 3 
all_sps = [f"{x}_{y}" for x in sp_ids for y in range(replicates)]
all_sps

['0.00_0',
 '0.00_1',
 '0.00_2',
 '0.10_0',
 '0.10_1',
 '0.10_2',
 '0.20_0',
 '0.20_1',
 '0.20_2',
 '0.30_0',
 '0.30_1',
 '0.30_2',
 '0.40_0',
 '0.40_1',
 '0.40_2',
 '0.50_0',
 '0.50_1',
 '0.50_2',
 '0.60_0',
 '0.60_1',
 '0.60_2',
 '0.70_0',
 '0.70_1',
 '0.70_2']

In [10]:
split_map = {}
for split in all_sps:
    split_id = f'SP_{split}'
    with open(f'../../data/splits/perturb/{data_type}/norman_SPECTRA_splits/{split_id}/test.pkl', 'rb') as f:
        spectra_splits = pkl.load(f)
    test_perts = [all_perts_map[split] for split in spectra_splits]
    split_map[split_id] = test_perts

In [11]:
split_map

{'SP_0.00_0': ['SNAI1',
  'PTPN12',
  'COL2A1',
  'CITED1',
  'NIT1',
  'KLF1',
  'FOXF1',
  'CSRNP1',
  'ZBTB10',
  'ARID1A',
  'CEBPB',
  'SLC4A1',
  'KIF18B',
  'FOSB',
  'PTPN1',
  'CBFA2T3',
  'SET',
  'SAMD1',
  'ZBTB1',
  'BAK1',
  'HK2'],
 'SP_0.00_1': ['ZC3HAV1',
  'LYL1',
  'SPI1',
  'ETS2',
  'S1PR2',
  'TSC22D1',
  'TMSB4X',
  'CKS1B',
  'TBX3',
  'CELF2',
  'AHR',
  'FEV',
  'COL2A1',
  'MEIS1',
  'FOSB',
  'HOXB9',
  'ZBTB1',
  'ZNF318',
  'PTPN13',
  'TP73',
  'FOXA1'],
 'SP_0.00_2': ['SET',
  'CEBPB',
  'MAP4K3',
  'SLC6A9',
  'LHX1',
  'MAP2K3',
  'BAK1',
  'OSR2',
  'CITED1',
  'TP73',
  'ZNF318',
  'ELMSAN1',
  'JUN',
  'FOXA1',
  'BCORL1',
  'SGK1',
  'HNF4A',
  'ETS2',
  'KIF2C',
  'MAP7D1',
  'COL1A1'],
 'SP_0.10_0': ['PTPN9',
  'ARID1A',
  'JUN',
  'CNNM4',
  'DUSP9',
  'IER5L',
  'IRF1',
  'CDKN1A',
  'MAP2K3',
  'TGFBR2',
  'ZBTB1',
  'HNF4A',
  'HOXC13',
  'CSRNP1',
  'FOXL2',
  'CKS1B'],
 'SP_0.10_1': ['FOXA3',
  'RREB1',
  'CNNM4',
  'SGK1',
  'CITED1',
  'F

In [12]:
# genes_to_keep = ['SLC4A1', 'IKZF3', 'GLB1L2', 'CEBPE', 'CEBPA', 'AHR']

In [13]:
# split_map_reduced = {k: v for k, v in split_map.items() if '0.00' in k or '0.70' in k}
# split_map_reduced

In [14]:
def filter_dict_by_list(input_dict, filter_list):
    filtered_dict = {}
    for key, values in input_dict.items():
        filtered_values = [value for value in values if value in filter_list]
        if filtered_values:
            filtered_dict[key] = filtered_values
    return filtered_dict

In [15]:
split_map_reduced = filter_dict_by_list(split_map_reduced, genes_to_keep)
split_map_reduced

NameError: name 'split_map_reduced' is not defined

In [25]:
# save the pert_split_map dict to a pkl file to data 
with open('../../data/splits/perturb/norman_1/pert_split_map.pkl', 'wb') as f:
    pkl.dump(split_map, f)

In [None]:
pert_names = []
for split in spectra_splits:
    pert_name = all_perts_map[split]
    pert_names.append(pert_name)
pert_names

In [None]:
def update_yaml_config(file_path, new_eval_pert):
    yaml = YAML()
    yaml.preserve_quotes = True
    yaml.indent(mapping=2, sequence=4, offset=2)

    with open(file_path, 'r') as f:
        config = yaml.load(f)

    # Update the eval_pert value
    config['data']['eval_pert'] = new_eval_pert

    # Save the updated config
    with open(file_path, 'w') as f:
        yaml.dump(config, f)

In [51]:
# TODO: open the experiment config file and generate as many as we need 
from ruamel.yaml import YAML 

base_config = '../../configs/experiment/mlp_norman_inference.yaml'
yaml = YAML()
yaml.preserve_quotes = True
yaml.indent(mapping=2, sequence=4, offset=2)

with open(base_config, 'r') as f:
    config = yaml.load(f)
config

{'model_type': 'mlp', 'defaults': [{'override /model': 'mlp'}, {'override /logger': 'wandb'}], 'total_genes': 2060, 'emb_dim': 3072, 'hidden_dim': 1536, 'mean_adjusted': False, 'save_dir': '${paths.data_dir}/${data.data_name}/pert_effects/${data.eval_pert}/pert_effect_pred_${data.fm}.pkl', 'train_date': '2024-09-13', 'timestamp': '12-08-25', 'data': {'data_name': 'norman_1', 'data_type': 'scfoundation', 'deg_eval': True, 'eval_pert': 'SET', 'split': 0.0, 'replicate': 0, 'fm': 'scfoundation'}, 'trainer': {'num_sanity_val_steps': 0, 'inference_mode': True, 'accelerator': 'cpu'}, 'ckpt_path': '${paths.log_dir}train/runs/${train_date}/${timestamp}/checkpoints/${callbacks.model_checkpoint.filename}.ckpt', 'logger': {'wandb': {'tags': ['eval', 'norman', '${data.eval_pert}', '${data.fm}', 'split_${data.split}', 'replicate_${data.replicate}'], 'group': 'test_mean_norman_${data.split}', 'project': 'perturbench-deg'}}}

In [52]:
import os 
if not os.path.exists(f'../../configs/experiment/deg_evals/'):
    os.makedirs(f'../../configs/experiment/deg_evals/')

id = 1
for i, (key, perts) in enumerate(split_map.items()):
    split = float(key.split('_')[1])
    replicate = int(key.split('_')[2])
    config['data']['split'] = split
    config['data']['replicate'] = replicate
    for j, pert in enumerate(perts):
        config['data']['eval_pert'] = pert
        with open(f'../../configs/experiment/deg_evals/mlp_norman_inference_{id}.yaml', 'w') as f:
            yaml.dump(config, f)
        id += 1 