In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
print(os.getcwd())
os.chdir('../../complete_project/../')
print(os.getcwd())
# Then set up the paths
# import sys

# import os
# os.environ['PYTHONPATH'] = os.getcwd()  # Now points to thesis_code directory
# sys.path.append(os.environ['PYTHONPATH'])
sys.path.append("/home/caspar/thesis_code/CellOracle")
sys.path.append("/home/caspar/thesis_code/complete_project/py files")
sys.path.append("/home/caspar/thesis_code/complete_project/py files/AIFiles")
sys.path.append("/home/caspar/thesis_code/complete_project/py files/baseGRNConstructionFiles")
sys.path.append("/home/caspar/thesis_code/complete_project/py files/oracleInferenceFiles")
sys.path.append("/home/caspar/thesis_code/complete_project/py files/oracleSetup")


/home/caspar/thesis_code/complete_project/notebooks
/home/caspar/thesis_code


In [2]:
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import anndata
import sys
import logging
from datetime import datetime
import wandb
import pickle
import optuna
#import modified_celloracle as mco
from CellOracle import celloracle as co
import igraph as ig
import CellOracleSetup as setup_module
import GRNClusterAnalysis as analysis_module
import GRNInference as inference_module
import GRNInferenceTest as inference_test_module
from scipy.sparse import issparse
from typing import Dict, List, Tuple, Optional

log_dir = 'logs'
os.makedirs(log_dir, exist_ok=True)
log_filename = os.path.join(log_dir, f"app_{datetime.now().strftime('%Y_%m_%d')}.log")

# Configure the basic logging
logging.basicConfig(
    filename=log_filename,
    filemode='a',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
    
)

In [None]:
def _calculate_gene_activation_values(oracle, sd_factor: float, genes_that_can_be_perturbed) -> dict:
        """
        Calculates a target activation value for each reg gene based on mean + sd_factor*sd

        The value is calculated as: mean(expression) + sd_factor * SD(expression).
        Mean and SD are calculated after filtering outliers using the IQR method.
        Expression data is taken from self.oracle.adata.X.

        Args:
            sd_factor (float): The number of standard deviations to add to the mean.

        Returns:
            Dict[str, float]: A dictionary mapping gene names to their calculated activation values.
        """
        activation_values_dict = {}
        perturb_indices_in_adata = [oracle.adata.var.index.get_loc(g) for g in genes_that_can_be_perturbed
                                        if g in oracle.adata.var.index]
        expression_data = oracle.adata[:, perturb_indices_in_adata].X

        if issparse(expression_data):
            expression_data = expression_data.toarray()
        elif not isinstance(expression_data, np.ndarray):
            expression_data = np.asarray(expression_data)
            
        for i, gene_name in enumerate(genes_that_can_be_perturbed):

            if gene_name not in oracle.adata.var.index:
                continue

            gene_expr = expression_data[:, i]

            try:
                q1, q3 = np.percentile(gene_expr, [25, 75])
                iqr = q3 - q1
                lower_bound = q1 - 3 * iqr
                upper_bound = q3 + 3 * iqr
                mask_non_outlier = (gene_expr >= lower_bound) & (gene_expr <= upper_bound)
                filtered_expr = gene_expr[mask_non_outlier]
                mean_expr = np.mean(filtered_expr)
                sd_expr = np.std(filtered_expr)
                activation_val = mean_expr + sd_factor * sd_expr

                activation_values_dict[gene_name] = max(0.0, activation_val)

            except Exception as e:
                print("THIS SHOULD NOT HAPPEN: ", e)
                activation_values_dict[gene_name] = 0.0

        print(f"Finished calculating activation values for {len(activation_values_dict)} genes.")
        return activation_values_dict

data_path_trans = os.path.join('../celloracle_data', "transition_matrix")
data_path_new_data = os.path.join('../celloracle_data', "celloracle_object/new_promoter_without_mescs_trimmed_test_own_umap")
oracle_path =  os.path.join(data_path_new_data, "ready_oracle.pkl")
transition_matrix_path = os.path.join(data_path_trans, "transition_matrix.pkl")

if not os.path.exists(data_path_trans):
    os.makedirs(data_path_trans)

progress_indicator_file = os.path.join(data_path_trans, "progress.txt")


with open(oracle_path, 'rb') as f:
    oracle = pickle.load(f)
genes_that_can_be_perturbed = oracle.return_active_reg_genes()
n_neighbors = 200;
batch_size = len(genes_that_can_be_perturbed)*2
oracle.init(embedding_type="X_umap", n_neighbors=n_neighbors, torch_approach=False,cupy_approach=True, batch_size=batch_size)

number_of_reg_genes = len(genes_that_can_be_perturbed)
high_ranges_dict = _calculate_gene_activation_values(oracle, 1.5, genes_that_can_be_perturbed)
#this is the action list
action_array = np.zeros((number_of_reg_genes*2,))
for i, gene_name in enumerate(genes_that_can_be_perturbed):
    action_array[i] =0.0
    action_array[i+number_of_reg_genes] = high_ranges_dict[gene_name]
    
transition_matrix = np.zeros((len(oracle.adata), len(action_array)))
#create perturb list for onracle inference, for all the same, loop over each indivudal cell, apply for that cell mass perturbation of all possible actoins, save resulting indices
perturb_conditions = []
for i in range((len(genes_that_can_be_perturbed)*2)):
    gene_name = ""
    if i < number_of_reg_genes:
        gene_name = genes_that_can_be_perturbed[i]
    else:
        gene_name = genes_that_can_be_perturbed[i-number_of_reg_genes]
    perturb_conditions.append((gene_name, action_array[i]))

for i in range(len(oracle.adata)):
    #create a list for each cell len(oracle.adata)) times for the same idx
    cell_index = [i] * len(perturb_conditions)
    _, new_idx_list, _, _, _ = oracle.training_phase_inference_batch_cp(
                    batch_size=batch_size, idxs=cell_index,
                    perturb_condition=perturb_conditions, n_neighbors=n_neighbors,
                    n_propagation=3, threads=4,
                    knockout_prev_used=False)
    #Store transitions
    print(new_idx_list)
    transition_matrix[i,:] = new_idx_list
    if (i+1) % 100 == 0:
        with open(progress_indicator_file, 'w') as f: # 'w' mode truncates (replaces) the file each time
            f.write(f"Processing cell {i+1} out of {len(oracle.adata)}\n")
            f.write("Still running...\n")
        
with open(transition_matrix_path, 'wb') as f:
    pickle.dump(transition_matrix, f)





AnnData object with n_obs × n_vars = 30000 × 3000
    obs: 'bc_idx', 'colnames', 'obs_names', 'celltype', 'celltype_general'
    var: 'rownames', 'mean', 'std', 'symbol', 'isin_top1000_var_mean_genes', 'isin_TFdict_targets', 'isin_TFdict_regulators', 'isin_actve_regulators'
    uns: 'celltype_colors', 'neighbors', 'pca', 'umap', 'umap_neighbors_sparse'
    obsm: 'X_pca', 'X_umap', 'colnames_factor', 'umap_neighbors'
    varm: 'PCs'
    layers: 'chic', 'raw_chich_counts', 'raw_count', 'unspliced_spliced', 'normalized_count', 'imputed_count', 'simulation_input'
    obsp: 'connectivities', 'distances', 'umap_neighbors_sparse'
Finished calculating activation values for 108 genes.
[28067, 24182, 15671, 19463, 11048, 28975, 4580, 1859, 4683, 3452, 22860, 8676, 19934, 26118, 7152, 9925, 19113, 9734, 23819, 14386, 19583, 16827, 9884, 26766, 26925, 17324, 12611, 138, 23016, 2095, 24179, 19846, 2242, 7789, 27125, 12129, 27068, 7773, 25333, 13547, 5817, 11862, 10532, 17546, 7078, 25229, 11862, 95

In [None]:
transition_path = os.path.join('../celloracle_data', "transition_matrix")
data_path = os.path.join('../celloracle_data', "celloracle_object/new_promoter_without_mescs_trimmed_test_own_umap")
ORACLE_PATH = os.path.join(data_path, "ready_oracle.pkl")
TRANSITION_MATRIX_PATH = os.path.join(transition_path, "transition_matrix.pkl") 

with open(TRANSITION_MATRIX_PATH, 'rb') as f:
    transition_matrix = pickle.load(f)

with open(ORACLE_PATH, 'rb') as f:
    oracle = pickle.load(f)
    
genes_that_can_be_perturbed = oracle.return_active_reg_genes()
number_of_reg_genes = len(genes_that_can_be_perturbed)
celltypes_unique = oracle.adata.obs['celltype'].unique()
all_cell_types = oracle.adata.obs['celltype']

celltype_to_idx_dict = {}
for i, celltype in enumerate(celltypes_unique):
    celltype_to_idx_dict[celltype] = list(np.where(all_cell_types == celltype)[0])
    
#start with igraph
graph = ig.Graph(n=len(oracle.adata), directed=True)
edge_list  = []
action_idx = []
for i in range(len(oracle.adata)):
    for j in range(len(action_array)):
        next_node_idx = int(transition_matrix[i,j])
        edge_list.append((i, next_node_idx))
        action_idx.append(j)
        
graph.add_edges(edge_list)
graph.es['action_idx'] = action_idx



In [None]:
def find_shortest_path(graph:ig.Graph, start_node: int, target_celltype:str, celltype_to_idx_dict:dict, genes_that_can_be_perturbed: Optional[List[str]]) -> Tuple[List[int], List[int]]:
    try:
        target_nodes = celltype_to_idx_dict[target_celltype]
        #convert to ints
        target_nodes = [int(i) for i in target_nodes]
        path_to_target  = graph.get_shortest_paths(start_node, to=target_nodes, output='epath', weight=None, mode = "OUT")
        
        if path_to_target is None:
            return [], []
        
        min_length = float('inf')
        shortest_path = None
        for path in path_to_target:
            if path and len(path) > 0:
                if len(path) < min_length:
                    min_length = len(path)
                    shortest_path = path
                    
        if genes_that_can_be_perturbed is None:
            return shortest_path, []
        
        actions  = []
        path_nodes = graph.get_vpath(shortest_path)
        for i in range(len(shortest_path)):
            edge = graph.es[shortest_path[i]]
            action_idx = edge['action_idx']
            action_name = ""
            if action_idx >= len(genes_that_can_be_perturbed):
                action_idx = action_idx - len(genes_that_can_be_perturbed)
                action_name = genes_that_can_be_perturbed[action_idx] + "_ACTIVATION"
            else:
                action_name = genes_that_can_be_perturbed[action_idx] + "_KO"
            actions.append(action_name)
            
        return path_nodes, actions
    except Exception as e:
        print("Error in find_shortest_path: ", e)
        return [], []
        