In [1]:
# %% Import required libraries and modules
import logging
import os
import sys
import warnings

# Add the src directory to the Python path
sys.path.append(os.path.abspath(os.path.join("..", "src")))

# Suppress all FutureWarning messages
warnings.simplefilter(action="ignore", category=FutureWarning)

import decoupler as dc
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from evaluation import evaluate_model
from models import FlexibleFCNN
from preprocess import split_data
from training import train_model
from utils import create_dataloader, load_config, load_sampled_data


# Configure logging
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

# Load Config
config = load_config("../config.yaml")

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")

# %% Load and Preprocess Datasets with TF Activity Inference
logging.info("Loading datasets and running TF activity inference...")

2025-01-28 14:41:34 [INFO] - Using device: cuda
2025-01-28 14:41:34 [INFO] - Loading datasets and running TF activity inference...


In [2]:
def run_tf_activity_inference(X, net, min_n=1):
    """
    Run TF activity inference on the input data.

    Args:
        X (pd.DataFrame): Gene expression matrix, including metadata columns.
        net (pd.DataFrame): Regulatory network for TF activity inference.
        min_n (int): Minimum number of targets for each TF.

    Returns:
        pd.DataFrame: TF activity matrix with metadata reattached.
    """
    import scanpy as sc
    import pandas as pd
    import logging

    # Separate metadata columns from gene expression
    metadata_cols = [
        "cell_mfc_name",
        "viability",
        "pert_dose",
    ]  
    metadata = X[metadata_cols]
    gene_expression = X.drop(columns=metadata_cols)

    # Filter the network for shared genes
    shared_genes = net["target"].unique() 
    shared_genes = [gene for gene in shared_genes if gene in gene_expression.columns]
    logging.debug(f"Number of shared genes: {len(shared_genes)}")
    assert shared_genes, "No shared genes between network and gene expression matrix!"

    # Filter network and gene expression
    net_filtered = net[net["target"].isin(shared_genes)]
    logging.debug(f"Filtered network has {len(net_filtered)} interactions.")
    gene_expression = gene_expression[shared_genes]

    # Create AnnData object
    adata = sc.AnnData(
        X=gene_expression.values,
        obs=pd.DataFrame(index=gene_expression.index),
        var=pd.DataFrame(index=gene_expression.columns),
    )
    logging.info(f"AnnData object created with shape: {adata.shape}")

    # Run ULM for TF activity inference
    dc.run_ulm(
        mat=adata,
        net=net_filtered,
        source="source",
        target="target",
        weight="weight",
        min_n=min_n,
        use_raw=False,
    )

    tf_activity = pd.DataFrame(adata.obsm["ulm_estimate"], index=adata.obs.index)

    # Convert the index of tf_activity to integers to match metadata
    tf_activity.index = tf_activity.index.astype(int)

    # Reattach metadata columns
    tf_activity = tf_activity.join(metadata)

    return tf_activity

In [3]:
gene_df = load_sampled_data(config["data_paths"]["preprocessed_gene_file"], sample_size=1000)
# Load Collectri network
collectri_net = dc.get_collectri(organism="human", split_complexes=False)

2025-01-28 14:41:41 [INFO] - Downloading data from `https://omnipathdb.org/queries/enzsub?format=json`
2025-01-28 14:41:41 [INFO] - Downloading data from `https://omnipathdb.org/queries/interactions?format=json`
2025-01-28 14:41:41 [INFO] - Downloading data from `https://omnipathdb.org/queries/complexes?format=json`
2025-01-28 14:41:41 [INFO] - Downloading data from `https://omnipathdb.org/queries/annotations?format=json`
2025-01-28 14:41:41 [INFO] - Downloading data from `https://omnipathdb.org/queries/intercell?format=json`
2025-01-28 14:41:42 [INFO] - Downloading data from `https://omnipathdb.org/about?format=text`


In [4]:
gene_df_tf = run_tf_activity_inference(gene_df, collectri_net, min_n=100)
gene_df_tf.head()

2025-01-28 14:41:50 [INFO] - AnnData object created with shape: (1000, 5576)


Unnamed: 0,AP1,AR,CEBPA,CEBPB,CEBPG,CREB1,CTNNB1,E2F1,EGR1,ESR1,...,STAT1,STAT3,STAT5A,TFAP2A,TP53,USF1,YY1,cell_mfc_name,viability,pert_dose
0,-1.124206,1.901735,1.432077,0.140267,-0.039159,0.86342,-1.04574,-0.980761,-1.394349,-0.488773,...,-0.121722,2.353502,-1.145322,1.476755,-0.763296,0.18004,-1.825588,PC3,0.372083,10.0
1,-0.125112,1.307696,-1.056423,-1.015212,-0.13205,0.398142,0.579014,0.262649,-0.408944,0.216572,...,-0.690598,-2.178105,0.085542,-0.260803,1.232395,0.299708,0.628411,BT474,0.713679,1.0
2,0.421042,-0.110407,1.03805,0.297216,-2.694845,0.922409,3.08867,-0.111851,0.295576,-1.088556,...,-0.483612,1.827108,-1.062499,-1.465414,0.602122,-0.223885,-0.130144,H1563,0.916843,0.1
3,0.543086,-2.097737,-1.795091,-0.646886,-1.983225,0.165981,-1.605723,-1.550775,-0.563031,0.504422,...,0.909154,-1.525902,1.841691,0.202296,-0.896314,-0.064952,-0.121592,A549,0.73885,10.0
4,1.560993,1.027034,-0.789244,1.138605,-0.278628,0.345317,-0.324742,1.639847,0.127035,2.561164,...,1.032759,0.310378,0.083342,-0.79113,1.304706,-1.691423,-0.686354,A549,0.783271,10.0


In [6]:
# Iterate through datasets, apply TF activity inference
datasets = {
    name: (
        run_tf_activity_inference(
            load_sampled_data(file_path, sample_size=1000),
            collectri_net,
            min_n=config.get("min_n", 1),  # Use min_n from config or default to 1
        ),
        "viability",
    )
    for name, file_path in config["data_paths"].items()
}

2025-01-28 10:37:13 [INFO] - AnnData object created with shape: (1000, 549)
2025-01-28 10:37:17 [INFO] - AnnData object created with shape: (1000, 5576)
2025-01-28 10:37:21 [INFO] - AnnData object created with shape: (1000, 4613)


In [7]:
datasets["preprocessed_landmark_file"]

(    cell_mfc_name  viability  pert_dose  ABL1  AEBP1  AHR  AIP  AIRE  AKNA  \
 0             PC3   0.372083   10.00000   NaN    NaN  NaN  NaN   NaN   NaN   
 1           BT474   0.713679    1.00000   NaN    NaN  NaN  NaN   NaN   NaN   
 2           H1563   0.916843    0.10000   NaN    NaN  NaN  NaN   NaN   NaN   
 3            A549   0.738850   10.00000   NaN    NaN  NaN  NaN   NaN   NaN   
 4            A549   0.783271   10.00000   NaN    NaN  NaN  NaN   NaN   NaN   
 ..            ...        ...        ...   ...    ...  ...  ...   ...   ...   
 995          A375   1.000000    0.12000   NaN    NaN  NaN  NaN   NaN   NaN   
 996          U2OS   0.978274    0.37037   NaN    NaN  NaN  NaN   NaN   NaN   
 997          MCF7   0.978809    1.11000   NaN    NaN  NaN  NaN   NaN   NaN   
 998          MCF7   0.814630    0.04120   NaN    NaN  NaN  NaN   NaN   NaN   
 999          HT29   1.000000   10.00000   NaN    NaN  NaN  NaN   NaN   NaN   
 
      AP1  ...  ZNF667  ZNF671  ZNF699  ZNF76  ZNF