# Perturbation Data Preparation

In [1]:
"""Perturbation Data Preparation

Structure:
    1. Imports, Variables, Functions
    2. Load Data
    3. Data Preparation
    4. Convert to Perturbation Data
"""

# 1. Imports, Variables, Functions
# importsÇ
import sys

local_gears_path = "/aloy/home/ddalton/projects/GEARS/"
sys.path.insert(0, local_gears_path)

# Clear the gears module from sys.modules to ensure it gets reloaded
if "gears" in sys.modules:
    del sys.modules["gears"]

from torch_geometric.loader import DataLoader
from gears import PertData, GEARS
from gears.inference import compute_metrics, deeper_analysis, non_dropout_analysis
from gears.utils import create_cell_graph_dataset_for_prediction
import numpy as np, os, sys, pandas as pd, scanpy as sc
import anndata as ad
import logging
from tqdm import tqdm
from scipy.sparse import csr_matrix


logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
from matplotlib import pyplot as plt

In [2]:
import os, sys
import gears
import logging

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
gears_file_path = os.path.abspath(gears.__file__)
logging.info(f"GEARS file path: {gears_file_path}")

2024-07-20 14:40:24,985 - GEARS file path: /aloy/home/ddalton/projects/GEARS/gears/__init__.py


In [3]:
# variables
diseases_of_interest_set = {"Influenza", "Colorectal Carcinoma", "Breast Cancer"}

example_data_path = (
    "/aloy/home/ddalton/projects/disease_signatures/data/DiSignAtlas/tmp/DSA00123.csv"
)

df_info_path = os.path.join(
    "/aloy",
    "home",
    "ddalton",
    "projects",
    "disease_signatures",
    "data",
    "DiSignAtlas",
    "Disease_information_Datasets_extended.csv",
)


large_df_path = "/aloy/home/ddalton/projects/disease_signatures/data/DiSignAtlas/DiSignAtlas.exp_prof_merged.csv"


# functions
def get_exp_prof(dsaids_interest):
    """Get Expression Profiles"""

    # variables
    file_dir = "/aloy/home/ddalton/projects/disease_signatures/data/DiSignAtlas/tmp/"
    first = True
    for dsaid in tqdm(dsaids_interest):
        __df = pd.read_csv(os.path.join(file_dir, f"{dsaid}.csv"))
        if first:
            df_global = __df
            first = False
        else:
            df_global = pd.concat([df_global, __df], axis=0)
    return df_global


# 2. Load Data
# load DiSignAtlas data information
df_info = pd.read_csv(df_info_path)


# Query data to retrieve dsaids of interest
library_strategies_of_interest_set = {"RNA-seq", "Microarray"}
QUERY = "disease in @diseases_of_interest_set & library_strategy in @library_strategies_of_interest_set & organism == 'Homo sapiens'"
dsaids_interest = df_info.query(QUERY)["dsaid"].to_list()
logging.info(f"Nº of DSAIDs of interest: {len(dsaids_interest)}")

df = get_exp_prof(dsaids_interest)


# filter-out unknown expression values
df["condition"] = [x.split(";")[2] for x in df.iloc[:, 0].values]

logging.info(f"Loaded dataframe {df.shape}")

df = df[df["condition"].isin(["Case", "Control"])]
df.drop(columns=["condition"], inplace=True)
logging.info(f"Filtered dataframe dataframe {df.shape}")


# 3. Data Preparation
# load dataframe
# df = pd.read_csv(large_df_path, skiprows=skip_rows_idxs, index_col=0)
logging.info(f"Loaded dataframe with shape: {df.shape}")

# drop non significant rows
# Calculate the number of NaNs in each row
nan_counts = df.isna().sum(axis=1)

# Filter the DataFrame to keep only rows with NaNs less than or equal to 18,000
df = df[nan_counts <= 18000]
logging.info(f"Filtered dataframe with shape: {df.shape}")

# 3. Convert to `adata` object
# Extract cell identifiers and gene expression data
ids = df.iloc[:, 0]
gene_expression_data = df.iloc[:, 1:].values
gene_names = df.columns[1:]


# Create an AnnData object
adata = ad.AnnData(X=gene_expression_data)


# Add cell and gene metadata
adata.obs["ids"] = ids.values
adata.var["gene_symbols"] = gene_names
adata.var["index"] = gene_names

2024-07-20 14:40:25,518 - Nº of DSAIDs of interest: 212
100%|██████████| 212/212 [01:21<00:00,  2.61it/s]
2024-07-20 14:41:46,611 - Loaded dataframe (5197, 19692)
2024-07-20 14:41:46,790 - Filtered dataframe dataframe (5108, 19691)
2024-07-20 14:41:46,791 - Loaded dataframe with shape: (5108, 19691)
2024-07-20 14:41:46,970 - Filtered dataframe with shape: (5090, 19691)


In [4]:
# 4. Convert to Perturbation Data
# convert adata.X to sparse matrix
adata.X = csr_matrix(adata.X)

# specify adata object `gene_name` on `adata.var`
adata.var["gene_name"] = list(gene_names)

# specgy adata object `condition` & `cell_type` on `adata.obs`
from collections import Counter

dsaid_2_disease = dict(zip(df_info["dsaid"], df_info["disease"]))

condition = list()
for tags in ids.values:
    if tags.split(";")[2] == "Case":
        condition.append(f"{dsaid_2_disease[tags.split(';')[0]]}+ctrl")
    elif tags.split(";")[2] == "Control":
        condition.append("ctrl")
    else:
        print(tags)


logging.info(f"{Counter(condition)}")


adata.obs["condition"] = condition
adata.obs["cell_type"] = "A"

pert_data = PertData("./data")
pert_data.new_data_process(dataset_name="test_1", adata=adata)

2024-07-20 14:41:48,849 - Counter({'ctrl': 1764, 'Breast Cancer+ctrl': 1740, 'Colorectal Carcinoma+ctrl': 1115, 'Influenza+ctrl': 471})
Found local copy...
Found local copy...
Creating pyg object for each cell in the data...
Creating dataset file...
 25%|██▌       | 1/4 [00:00<00:01,  2.48it/s]

De genes:  [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
1764
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(

 50%|█████     | 2/4 [00:01<00:02,  1.00s/it]

De genes:  [ 1073  1830  1928  3468  4337  5789  6273  6389  6401  6694  7087  7173
  7264  7279  8214  8294  9911 10552 12354 12997]
2230
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1,

 75%|███████▌  | 3/4 [00:04<00:01,  1.59s/it]

Running for perturbation:  Influenza+ctrl
Influenza+ctrl


100%|██████████| 4/4 [00:04<00:00,  1.15s/it]
Done!
Saving new dataset pyg object at ./data/test_1/data_pyg/cell_graphs.pkl


De genes:  [  683   865  1073  1830  1928  3468  4337  5789  6273  6401  6694  7173
  7264  7279  8214  8294  9911 10552 12354 12997]
942
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 19690)
(1, 

Done!


In [5]:
pert_data.prepare_split(split="simulation", seed=1)
pert_data.get_dataloader(batch_size=batch_size, test_batch_size=eval_batch_size)

Creating new splits....
Saving new splits at ./data/test_1/splits/test_1_simulation_1_0.75.pkl
Simulation split test composition:
combo_seen0:0
combo_seen1:0
combo_seen2:0
unseen_single:1
Done!


NameError: name 'batch_size' is not defined

In [None]:
batch_size = 64
eval_batch_size = 64

pert_data.prepare_split(split="simulation", seed=1)
pert_data.get_dataloader(batch_size=batch_size, test_batch_size=eval_batch_size)

train_loader = pert_data.dataloader["train_loader"]

for batch, batch_data in enumerate(train_loader):
    batch_size = len(batch_data.y)
    print(batch_data.x.shape)
    break

Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:0
combo_seen1:0
combo_seen2:0
unseen_single:1
Done!
Creating dataloaders....
Done!


here1
torch.Size([1260160, 1])


In [None]:
train_loader = pert_data.dataloader["train_loader"]
for batch, batch_data in enumerate(train_loader):
    batch_size = len(batch_data.y)
    print(batch_size)
    break

print(len(train_loader) * batch_size)

64


3456

In [None]:
test_loader = pert_data.dataloader["test_loader"]
for batch, batch_data in enumerate(test_loader):
    batch_size = len(batch_data.y)
    print(batch_size)
    break

print(len(test_loader) * batch_size)

64
1152


In [None]:
val_loader = pert_data.dataloader["val_loader"]
for batch, batch_data in enumerate(val_loader):
    batch_size = len(batch_data.y)
    print(batch_size)
    break

print(len(val_loader) * batch_size)

64
512


In [None]:
# to load the processed data
pert_data.load(data_path="./data/test_1")

Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['Colorectal Carcinoma+ctrl' 'Breast Cancer+ctrl' 'Influenza+ctrl']
Local copy of pyg dataset is detected. Loading...
Done!


In [None]:
pert_data = PertData("./data")
pert_data.load(data_name="test_1")

Found local copy...


TypeError: stat: path should be string, bytes, os.PathLike or integer, not NoneType

In [None]:
pert_data = PertData("./data")
pert_data.load(data_name="adamson")
# pert_data.load(data_path="./data/test_1")

pert_data.prepare_split(split="simulation", seed=1)
pert_data.get_dataloader(batch_size=batch_size, test_batch_size=eval_batch_size)

Found local copy...
Found local copy...
Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['SRPR+ctrl' 'SLMO2+ctrl' 'TIMM23+ctrl' 'AMIGO3+ctrl' 'KCTD16+ctrl']
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:0
combo_seen1:0
combo_seen2:0
unseen_single:22
Done!


here1


NameError: name 'batch_size' is not defined

In [None]:
train_loader = pert_data.dataloader["train_loader"]

for batch, batch_data in enumerate(train_loader):
    batch_size = len(batch_data.y)
    print(batch_data.x.shape)
    break

NameError: name 'pert_data' is not defined

In [None]:
dir(pert_data)

In [None]:
adata_2 = pert_data.adata

In [None]:
pert_data = PertData("./data")
pert_data.new_data_process(dataset_name="adam_2", adata=adata_2)

Found local copy...
Found local copy...
Creating pyg object for each cell in the data...
Creating dataset file...
100%|██████████| 82/82 [00:37<00:00,  2.21it/s]
Done!
Saving new dataset pyg object at ./data/adam_2/data_pyg/cell_graphs.pkl
Done!


In [None]:
pert_data.load(data_path="./data/adam_2")

pert_data.prepare_split(split="simulation", seed=1)
pert_data.get_dataloader(batch_size=batch_size, test_batch_size=eval_batch_size)

train_loader = pert_data.dataloader["train_loader"]

for batch, batch_data in enumerate(train_loader):
    batch_size = len(batch_data.y)
    print(batch_data.x.shape)
    break

Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
[]
Local copy of pyg dataset is detected. Loading...
Done!
Creating new splits....
Saving new splits at ./data/adam_2/splits/adam_2_simulation_1_0.75.pkl
Simulation split test composition:
combo_seen0:0
combo_seen1:0
combo_seen2:0
unseen_single:21
Done!
Creating dataloaders....
Done!


torch.Size([323840, 1])
