In [2]:
import pickle
import torch
import anndata as ad
import pandas as pd
import numpy as np
import random
import itertools
import argparse
from scouter import Scouter, ScouterData

In [4]:
def set_seeds(seed=24):
    """Seed the NumPy, Python ``random`` and PyTorch random number generators.

    Args:
        seed: Integer seed handed to every generator. Defaults to 24.

    Returns:
        None.
    """
    # Seed in a fixed order: NumPy, stdlib ``random``, then torch.
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)
    # The explicit CUDA seeding call is only made when torch reports CUDA as available.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Canonicalise a condition name so that "A+B" and "B+A" become the same string.
def condition_sort(x):
    """Return ``x`` with its '+'-separated parts rearranged into sorted order.

    Args:
        x: Condition label, e.g. ``"B+A"``.

    Returns:
        The label with its parts sorted and re-joined by '+', e.g. ``"A+B"``.
    """
    parts = x.split('+')
    parts.sort()
    return '+'.join(parts)

In [8]:
# NOTE(review): these are absolute, machine-specific paths; point them at your
# own copies of the dataset and the embedding pickle before running.
data_path = '/Users/pancake/Downloads/Perturb/Gears/adamson/Gears_data/adamson/perturb_processed.h5ad'
embd_path = '/Users/pancake/Downloads/Perturb/scoracle/GeneEmb/GenePT_emb/GenePT_gene_embedding_ada_text.pickle'

# Load the processed scRNA-seq dataset as AnnData
adata = ad.read_h5ad(data_path)
# Normalise the condition labels so that "A+B" and "B+A" fall into one category.
adata.obs['condition'] = adata.obs['condition'].astype(str).apply(condition_sort).astype('category')
# Discard the unstructured annotations.
adata.uns = {}
# Drop the 'condition_name' column; errors='ignore' keeps the cell runnable on
# datasets that do not carry that column.
adata.obs = adata.obs.drop(columns=['condition_name'], errors='ignore')

# Load the gene embedding as a DataFrame (the unpickled object is transposed).
# SECURITY: pickle.load can execute arbitrary code -- only open embedding files
# obtained from a source you trust.
with open(embd_path, 'rb') as f:
    embd = pd.DataFrame(pickle.load(f)).T
# Prepend an all-zero embedding row labelled 'ctrl'.
ctrl_row = pd.DataFrame([np.zeros(embd.shape[1])], columns=embd.columns, index=['ctrl'])
embd = pd.concat([ctrl_row, embd])

# Build the Scouter data object and run its preparation steps.
pertdata = ScouterData(adata, embd, 'condition', 'gene_name')
pertdata.setup_ad('embd_index')
pertdata.gene_ranks()
pertdata.get_dropout_non_zero_genes()
pertdata.split_Train_Val_Test(seed=1)

# Seed NumPy / random / torch before the model is created so that weight
# initialisation and training are repeatable across kernel restarts.
# NOTE(review): confirm Scouter does not re-seed internally.
set_seeds()

# Initialise, train and evaluate the model.
scouter_model = Scouter(pertdata)
scouter_model.model_init()
scouter_model.train(loss_lambda=0.01)
metric_df = scouter_model.evaluate()

All 87 perturbed genes are found in the gene embedding matrix!


Epoch 1/40 - Training Batches: 100%|█████████████████████████████████| 250/250 [00:45<00:00,  5.52batch/s]
Epoch 1/40 - Validation Batches: 100%|█████████████████████████████████| 30/30 [00:00<00:00, 31.52batch/s]


Epoch 1/40, Training Loss: 18.7230, Validation Loss: 2.5552


Epoch 2/40 - Training Batches: 100%|█████████████████████████████████| 250/250 [00:43<00:00,  5.69batch/s]
Epoch 2/40 - Validation Batches: 100%|█████████████████████████████████| 30/30 [00:00<00:00, 32.05batch/s]


Epoch 2/40, Training Loss: 0.2949, Validation Loss: 0.1064


Epoch 3/40 - Training Batches: 100%|█████████████████████████████████| 250/250 [00:43<00:00,  5.69batch/s]
Epoch 3/40 - Validation Batches: 100%|█████████████████████████████████| 30/30 [00:00<00:00, 31.56batch/s]


Epoch 3/40, Training Loss: 0.1181, Validation Loss: 0.1049


Epoch 4/40 - Training Batches: 100%|█████████████████████████████████| 250/250 [00:43<00:00,  5.69batch/s]
Epoch 4/40 - Validation Batches: 100%|█████████████████████████████████| 30/30 [00:00<00:00, 31.95batch/s]


Epoch 4/40, Training Loss: 0.1174, Validation Loss: 0.1049


Epoch 5/40 - Training Batches: 100%|█████████████████████████████████| 250/250 [00:44<00:00,  5.67batch/s]
Epoch 5/40 - Validation Batches: 100%|█████████████████████████████████| 30/30 [00:00<00:00, 31.96batch/s]


Epoch 5/40, Training Loss: 0.1162, Validation Loss: 0.1037


Epoch 6/40 - Training Batches: 100%|█████████████████████████████████| 250/250 [00:44<00:00,  5.67batch/s]
Epoch 6/40 - Validation Batches: 100%|█████████████████████████████████| 30/30 [00:00<00:00, 31.83batch/s]


Epoch 6/40, Training Loss: 0.1160, Validation Loss: 0.1036


Epoch 7/40 - Training Batches: 100%|█████████████████████████████████| 250/250 [00:45<00:00,  5.51batch/s]
Epoch 7/40 - Validation Batches: 100%|█████████████████████████████████| 30/30 [00:00<00:00, 31.64batch/s]


Epoch 7/40, Training Loss: 0.1159, Validation Loss: 0.1038


Epoch 8/40 - Training Batches: 100%|█████████████████████████████████| 250/250 [00:44<00:00,  5.68batch/s]
Epoch 8/40 - Validation Batches: 100%|█████████████████████████████████| 30/30 [00:00<00:00, 31.67batch/s]


Epoch 8/40, Training Loss: 0.1154, Validation Loss: 0.1034


Epoch 9/40 - Training Batches: 100%|█████████████████████████████████| 250/250 [00:44<00:00,  5.68batch/s]
Epoch 9/40 - Validation Batches: 100%|█████████████████████████████████| 30/30 [00:00<00:00, 31.81batch/s]


Epoch 9/40, Training Loss: 0.1156, Validation Loss: 0.1038


Epoch 10/40 - Training Batches:  13%|████▎                            | 33/250 [00:05<00:39,  5.51batch/s]


KeyboardInterrupt: 