In [None]:
import itertools
import os

import scanpy as sc
import torch

from gears import GEARS, PertData

Load the "norman" dataset.

In [None]:
# Load the "norman" dataset by name into the ./data directory.
norman_data_dir = "./data"
norman = PertData(data_path=norman_data_dir)
norman.load(data_name="norman")

Load the "ReplogleNorman2020_E7" dataset.
(This dataset must first be generated by running "preprocessing/ReplogleNorman2020_E7.ipynb".)

In [None]:
# Process and load the "ReplogleNorman2020_E7" dataset. Its AnnData is produced
# by preprocessing/ReplogleNorman2020_E7.ipynb; fail early with a pointer to
# that notebook if it has not been run yet.
replogle_name = "ReplogleNorman2020_E7"
replogle_data_dir = "./data"
replogle_h5ad = os.path.join("preprocessing", replogle_name, "adata.h5ad")
if not os.path.isfile(replogle_h5ad):
    raise FileNotFoundError(
        f"{replogle_h5ad} not found; run preprocessing/{replogle_name}.ipynb first."
    )

ReplogleNorman2020_E7 = PertData(data_path=replogle_data_dir)
ReplogleNorman2020_E7.new_data_process(
    dataset_name=replogle_name,
    adata=sc.read_h5ad(filename=replogle_h5ad),
)
# Load from <data_path>/<dataset_name>. Deriving it from the same two values
# used above keeps the processed and the loaded locations from drifting apart.
ReplogleNorman2020_E7.load(data_path=os.path.join(replogle_data_dir, replogle_name))

Split the data and build the dataloaders.
This follows the same [procedure](https://github.com/yhr91/GEARS_misc/blob/main/paper/Fig4_UMAP_train.py) used for Figure 4 in the GEARS paper.

In [None]:
# Reproduce the Fig. 4 setup: the "no_test" split with a fixed seed, then
# build the dataloaders.
SPLIT_SEED = 42
norman.prepare_split(split="no_test", seed=SPLIT_SEED)
norman.get_dataloader(test_batch_size=128, batch_size=32)

Set up and train the GEARS model.
Use the default hyperparameters.

In [None]:
# Build GEARS on the Norman split with default hyperparameters, then train.
# Seed torch's global RNG first so weight initialisation (and, presumably, the
# training loader's shuffling) is reproducible across Restart & Run All.
# 42 is the same value used as the split seed.
torch.manual_seed(42)
gears_model = GEARS(pert_data=norman, device="cpu")
gears_model.model_initialize()
gears_model.train()

Save the trained model and reload it from disk.

In [None]:
# Persist the trained model, then immediately restore it from the same
# directory so the following cells use exactly what was saved.
model_dir = "gears_norman_no_test"
gears_model.save_model(path=model_dir)
gears_model.load_pretrained(path=model_dir)

Predict all single perturbations and all pairwise combinations of them.

In [None]:
# Get all single perturbations, i.e. conditions of the form "GENE+ctrl" or
# "ctrl+GENE". Split on "+" rather than using str.strip("+ctrl"): strip removes
# any of the characters '+', 'c', 't', 'r', 'l' from both ends (not the literal
# substring) and would corrupt symbols that begin or end with those characters.
# The AnnData lives on PertData.adata.
single_genes = set()
for condition in norman.adata.obs["condition"].unique():
    parts = condition.split("+")
    if len(parts) == 2 and "ctrl" in parts:
        single_genes.update(part for part in parts if part != "ctrl")

# Keep only genes GEARS knows as perturbations. Build the lookup set once
# instead of re-listing pert_names per gene. Sort so that the iteration order
# (and hence prediction order and log output) is reproducible across runs.
known_perts = set(norman.pert_names)
genes_of_interest = sorted(gene for gene in single_genes if gene in known_perts)

# Generate all possible double perturbations (combos): every unordered pair of
# distinct genes, each pair sorted, in lexicographic order. combinations() over
# a sorted, duplicate-free list guarantees both without a sort/groupby pass.
all_possible_combos = [
    list(pair) for pair in itertools.combinations(genes_of_interest, 2)
]


def _prediction_dict(result):
    """Return the per-perturbation prediction dict from a GEARS.predict result.

    GEARS.predict returns a ``(predictions, uncertainty)`` tuple when the model
    was initialised with uncertainty enabled, and only the ``predictions`` dict
    otherwise (NOTE(review): confirm against the installed GEARS version).
    Unpacking a single-key dict into two names raises ValueError, so accept
    both shapes.
    """
    return result[0] if isinstance(result, tuple) else result


# Predict all single perturbations, keeping every result.
single_predictions = {}
for gene in genes_of_interest:
    print(f"Single prediction: {gene}")
    single_predictions.update(
        _prediction_dict(gears_model.predict(pert_list=[[gene]]))
    )

# Predict all combos, keeping every result.
combo_predictions = {}
for it, combo in enumerate(all_possible_combos):
    print(f"Combo prediction: {it}")
    combo_predictions.update(
        _prediction_dict(gears_model.predict(pert_list=[combo]))
    )