# Using a pre-trained model and extracting information about context specific feature impacts using Sampling Perturbation method for Transformers

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import os
import sys
import time
import copy
from pathlib import Path
from typing import Iterable, List, Tuple, Dict, Union, Optional
import warnings

import torch
import numpy as np
import matplotlib
from torch import nn
from torch.nn import functional as F
from torchtext.vocab import Vocab
from torchtext._torchtext import (
    Vocab as VocabPybind,
)
from torch_geometric.loader import DataLoader
from gears import PertData, GEARS
from gears.inference import compute_metrics, deeper_analysis, non_dropout_analysis
from gears.utils import create_cell_graph_dataset_for_prediction

sys.path.insert(0, "../")

import scgpt as scg
from scgpt.model import TransformerGenerator
from scgpt.loss import (
    masked_mse_loss,
    criterion_neg_log_bernoulli,
    masked_relative_error,
)
from scgpt.tokenizer import tokenize_batch, pad_batch, tokenize_and_pad_batch
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.utils import set_seed, map_raw_id_to_vocab_id, compute_perturbation_metrics

matplotlib.rcParams["savefig.transparent"] = False
warnings.filterwarnings("ignore")

set_seed(42)


  from .autonotebook import tqdm as notebook_tqdm


 ## Training Settings

In [3]:
# settings for data prcocessing
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
pad_value = 0  # for padding values
pert_pad_id = 0
include_zero_gene = "all"
max_seq_len = 1536

# settings for training
MLM = True  # whether to use masked language modeling, currently it is always on.
CLS = False  # celltype classification objective
CCE = False  # Contrastive cell embedding objective
MVC = False  # Masked value prediction for cell embedding
ECS = False  # Elastic cell similarity objective
amp = True
load_model = "../save/scGPT_human"
load_param_prefixs = [
    "encoder",
    "value_encoder",
    "transformer_encoder",
]

# settings for optimizer
lr = 1e-4  # or 1e-4
batch_size = 64
eval_batch_size = 64
epochs = 3
schedule_interval = 1
early_stop = 10

# settings for the model
embsize = 512  # embedding dimension
d_hid = 512  # dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 12  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 8  # number of heads in nn.MultiheadAttention
n_layers_cls = 3
dropout = 0  # dropout probability
use_fast_transformer = True  # whether to use fast transformer

# logging
log_interval = 100

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [4]:
pert_data = PertData("../data/")
pert_data.load(data_path='../data/fibroblast_p20-nr/')
pert_data.prepare_split(split='no_split', seed=1)
pert_data.get_dataloader(batch_size=batch_size, test_batch_size=eval_batch_size)

Local copy of pyg dataset is detected. Loading...
Done!
Local copy of split is detected. Loading...
Done!
Creating dataloaders....
Dataloaders created...


{'test_loader': <torch_geometric.deprecation.DataLoader at 0x7f7034168c70>}

In [5]:
data_name = 'testing_sampling_pert'
save_dir = Path(f"./save/dev_perturb_{data_name}-{time.strftime('%b%d-%H-%M')}/")
save_dir.mkdir(parents=True, exist_ok=True)
print(f"saving to {save_dir}")

logger = scg.logger
scg.utils.add_file_handler(logger, save_dir / "run.log")
# log running date and current git commit
logger.info(f"Running on {time.strftime('%Y-%m-%d %H:%M:%S')}")

saving to save/dev_perturb_testing_sampling_pert-Oct15-14-31
scGPT - INFO - Running on 2024-10-15 14:31:44


In [6]:
if load_model is not None:
    model_dir = Path(load_model)
    model_config_file = model_dir / "args.json"
    model_file = model_dir / "best_model.pt"
    vocab_file = model_dir / "vocab.json"

    vocab = GeneVocab.from_file(vocab_file)
    for s in special_tokens:
        if s not in vocab:
            vocab.append_token(s)

    pert_data.adata.var["id_in_vocab"] = [
        1 if gene in vocab else -1 for gene in pert_data.adata.var["gene_name"]
    ]
    gene_ids_in_vocab = np.array(pert_data.adata.var["id_in_vocab"])
    logger.info(
        f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes "
        f"in vocabulary of size {len(vocab)}."
    )
    genes = pert_data.adata.var["gene_name"].tolist()

    # model
    with open(model_config_file, "r") as f:
        model_configs = json.load(f)
    logger.info(
        f"Resume model from {model_file}, the model args will override the "
        f"config {model_config_file}."
    )
    embsize = model_configs["embsize"]
    nhead = model_configs["nheads"]
    d_hid = model_configs["d_hid"]
    nlayers = model_configs["nlayers"]
    n_layers_cls = model_configs["n_layers_cls"]
else:
    genes = pert_data.adata.var["gene_name"].tolist()
    vocab = Vocab(
        VocabPybind(genes + special_tokens, None)
    )  # bidirectional lookup [gene <-> int]
vocab.set_default_index(vocab["<pad>"])
gene_ids = np.array(
    [vocab[gene] if gene in vocab else vocab["<pad>"] for gene in genes], dtype=int
)
n_genes = len(genes)



scGPT - INFO - match 17315/17315 genes in vocabulary of size 60697.
scGPT - INFO - Resume model from ../save/scGPT_human/best_model.pt, the model args will override the config ../save/scGPT_human/args.json.


 # Create and train scGpt

In [7]:
from scgpt.model import TransformerModel

In [8]:
ntokens = len(vocab)  # size of vocabulary
model = TransformerModel(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    nlayers_cls=n_layers_cls,
    n_cls=1,
    vocab=vocab,
    dropout=dropout,
    pad_token=pad_token,
    pad_value=pad_value,
    use_fast_transformer=use_fast_transformer,
)
if load_param_prefixs is not None and load_model is not None:
    # only load params that start with the prefix
    model_dict = model.state_dict()
    pretrained_dict = torch.load(model_file,  map_location=torch.device(device))
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items()
        if any([k.startswith(prefix) for prefix in load_param_prefixs])
    }
    for k, v in pretrained_dict.items():
        logger.info(f"Loading params {k} with shape {v.shape}")
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
elif load_model is not None:
    try:
        model.load_state_dict(torch.load(model_file))
        logger.info(f"Loading all model params from {model_file}")
    except:
        # only load params that are in the model and match the size
        model_dict = model.state_dict()
        pretrained_dict = torch.load(model_file, map_location=torch.device(device))
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
        for k, v in pretrained_dict.items():
            logger.info(f"Loading params {k} with shape {v.shape}")
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
model.to(device)

scGPT - INFO - Loading params encoder.embedding.weight with shape torch.Size([60697, 512])
scGPT - INFO - Loading params encoder.enc_norm.weight with shape torch.Size([512])
scGPT - INFO - Loading params encoder.enc_norm.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.linear1.weight with shape torch.Size([512, 1])
scGPT - INFO - Loading params value_encoder.linear1.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.linear2.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params value_encoder.linear2.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.norm.weight with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.norm.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.0.self_attn.Wqkv.weight with shape torch.Size([1536, 512])
scGPT - INFO - Loading params transformer_encoder.layers.0.self_attn.Wqkv.bias with shape torch.Size([153

TransformerModel(
  (encoder): GeneEncoder(
    (embedding): Embedding(60697, 512, padding_idx=60694)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (value_encoder): ContinuousValueEncoder(
    (dropout): Dropout(p=0, inplace=False)
    (linear1): Linear(in_features=1, out_features=512, bias=True)
    (activation): ReLU()
    (linear2): Linear(in_features=512, out_features=512, bias=True)
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): FlashTransformerEncoderLayer(
        (self_attn): FlashMHA(
          (Wqkv): Linear(in_features=512, out_features=1536, bias=True)
          (inner_attn): FlashAttention()
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0, inplace=False)
        (linear2): Linear(in_features=512, out_fe

In [9]:
import torch_geometric
from torch_geometric.data import Data
from typing import List, Union

In [10]:
def generate_binary_array(n_masks, n_genes, fraction):
    """
    Generate a binary 2D NumPy array with a set fraction of 1s randomly distributed in every row.

    Parameters:
    rows (int): Number of rows in the array.
    cols (int): Number of columns in the array.
    fraction (float): Fraction of 1s in each row (0 <= fraction <= 1).

    Returns:
    np.ndarray: Binary 2D NumPy array.
    """
    if not (0 <= fraction <= 1):
        raise ValueError("Fraction must be between 0 and 1.")

    array = np.zeros((n_masks, n_genes), dtype=int)
    num_ones = int(fraction * n_genes)

    for row in array:
        ones_indices = np.random.choice(n_genes, num_ones, replace=False)
        row[ones_indices] = 1

    return array

def apply_masks(
    values: Union[torch.Tensor, np.ndarray],
    masks: np.ndarray,
    mask_value: int = -1,
    pad_value: int = 0,
    indices_to_keep: List[int] = None
) -> torch.Tensor:
    """
    Apply given masks to the values.

    Args:
        values (array-like): The data to mask, shape (n_features,)
        masks (array-like): An array of masks to apply, shape (n_masks, n_features)
        mask_value (int): The value to mask with, default to -1.
        pad_value (int): The value of padding in the values, will be kept unchanged.
        indices_to_keep (list of int): List of indices that should not be masked.

    Returns:
        torch.Tensor: A tensor of masked data, shape (n_masks, n_features)
    """
    if isinstance(values, torch.Tensor):
        values = values.clone().detach().cpu().numpy()
    else:
        values = values.copy()

    if indices_to_keep is None:
        indices_to_keep = []

    masked_values = []
    for i in range(masks.shape[0]):
        mask = masks[i]
        masked_value = values.copy()
        # Only mask the positions where mask == 1 and values != pad_value
        mask_positions = (mask == 1) & (values != pad_value)
        # Set the indices that should not be masked to False
        mask_positions[indices_to_keep] = False
        masked_value[mask_positions] = mask_value
        masked_values.append(masked_value)

    return torch.from_numpy(np.array(masked_values)).float()

def sample_pert(
    model: nn.Module,
    cell_data: torch_geometric.data.Data,
    masks: np.ndarray,
) -> torch.Tensor:
    """
    Applies multiple masks to a single cell record, reconstructs the missing values using the model,
    and returns the reconstructed cell expressions.

    Args:
        model (nn.Module): The transformer model.
        cell_data (torch_geometric.data.Data): The data of a single cell.
        masks (np.ndarray): An array of masks to apply, shape (n_masks, n_genes).
        device (torch.device): The device to run the computations on.
        map_raw_id_to_vocab_id (Callable): Function to map raw gene IDs to vocab IDs.
        gene_ids (torch.Tensor): Tensor of gene IDs.
        criterion (Callable): Loss function.
        amp (bool): Automatic Mixed Precision flag.
        CLS, CCE, MVC, ECS: Model-specific flags.
        vocab (Dict): Vocabulary mapping.
        pad_token (str): Padding token.
        max_seq_len (int, optional): Maximum sequence length.
        include_zero_gene (str): Inclusion criteria for zero-expression genes.

    Returns:
        torch.Tensor: Reconstructed cell expressions, shape (n_masks, n_genes).
    """
    model.eval()
    cell_data.to(device)
    x: torch.Tensor = cell_data.x  # (n_genes, 2)
    ori_gene_values = x[0, :]  # (n_genes,)
    pert_flags = x[:, 1].long()  # (n_genes,)
    target_gene_values = cell_data.x  # (n_genes,)

    n_masks = masks.shape[0]
    n_genes = target_gene_values.shape[1]

    # Repeat ori_gene_values and pert_flags for each mask
    
    #ori_gene_values = ori_gene_values.unsqueeze(0).repeat(n_masks, 1)  # (n_masks, n_genes)
        
    # Prepare input_gene_ids
    if include_zero_gene in ["all", "batch-wise"]:
        if include_zero_gene == "all":
            input_gene_ids = torch.arange(n_genes, device=device, dtype=torch.long)
        else:
            input_gene_ids = (
                ori_gene_values.nonzero()[:, 1].flatten().unique().sort()[0]
            )
        # Sample input_gene_id
        if max_seq_len and len(input_gene_ids) > max_seq_len:
            # input_gene_ids = torch.randperm(len(input_gene_ids), device=device)[
            #     :max_seq_len
            # ]
            input_gene_ids = torch.tensor(np.arange(max_seq_len)).to(device)

        
        input_values = ori_gene_values[input_gene_ids]
    
        input_values = apply_masks(
                                        input_values,
                                        masks[:, input_gene_ids.cpu().numpy()],
                                        mask_value=-1,
                                        pad_value=0,
                                    ) 
    
        
        mapped_input_gene_ids = map_raw_id_to_vocab_id(input_gene_ids, gene_ids)
        mapped_input_gene_ids = mapped_input_gene_ids.unsqueeze(0).repeat(n_masks, 1)

        # src_key_padding_mask
        src_key_padding_mask = mapped_input_gene_ids.eq(vocab[pad_token])
    
    
    with torch.cuda.amp.autocast(enabled=amp):
        output_dict = model(
            mapped_input_gene_ids.to(device),
            input_values.to(device),
            src_key_padding_mask=src_key_padding_mask.to(device),
            CLS=CLS,
            CCE=CCE,
            MVC=MVC,
            ECS=ECS,
        )
        output_values = output_dict["mlm_output"]  # (n_masks, n_genes)

    return output_values


### Create the reconstructed reprogramming data

In [11]:
import pandas as pd

In [12]:
OKSM = ['SOX2', 'KLF4', 'POU5F1', 'MYC', 'NANOG']

TFs = pd.read_csv('../../perturb_train/little_data/TF_db.csv', index_col = 0)
TFs = TFs.loc[:,'HGNC symbol'].tolist()

ipsc_genes = pert_data.adata.var.gene_name.tolist()
to_perturb = list(set(ipsc_genes).intersection(set(TFs)))

tf_exp = pert_data.adata[:,pert_data.adata.var.gene_name.isin(to_perturb)].X

a = (tf_exp.shape[0] - (tf_exp==0).sum(axis=0))/tf_exp.shape[0]
threshold = 0.10  # Set your desired threshold here

# Get the indexes where values are above the threshold
to_perturb = np.array(to_perturb)[np.array(a).squeeze() > threshold]
to_perturb = np.array(list(set(list(to_perturb) + OKSM)))

In [64]:
adata_list = []
for sample in ['d20-nr', 'p20-nr', 'd8-fm', 'd4-fm', 'p3-nr']:
    pert_data.load(data_path = f'../data/fibroblast_{sample}')
    adata_list.append(pert_data.adata)

pert_data.prepare_split(split='no_split', seed=1)
pert_data.get_dataloader(batch_size=batch_size, test_batch_size=eval_batch_size,)
import anndata as ad    
adata_combined = ad.concat(adata_list, join='outer', axis=0)

Local copy of pyg dataset is detected. Loading...
Done!
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of split is detected. Loading...
Done!
Creating dataloaders....
Dataloaders created...


In [72]:
adata_combined = adata_combined[:,adata_combined.var.index.isin(to_perturb)].copy()

In [74]:
#train_loader = pert_data.dataloader["train_loader"]
n_masks = 512
batch = 4
fraction = 0.15
libs = {}
for library in ['D20-nr', 'P20-nr', 'D8-fm', 'D4-fm', 'P3-nr']:
    expression = adata_combined[adata_combined.obs.library==library].X[0].toarray()
    single_cell_ref = torch.tensor(expression)
    n_genes = single_cell_ref.shape[1]
    masks = generate_binary_array(n_masks, n_genes, fraction)
    
    ref_cells = []
    for i in range(0, masks.shape[0]-batch, batch):
        res = sample_pert(model, Data(x=single_cell_ref), masks[i:i+batch]).detach().cpu().numpy()
        ref_cells.append(res)
        
    ref_cells = np.array(ref_cells).reshape(-1, len(to_perturb))
    ref_cells = pd.DataFrame(ref_cells, columns = adata.var.index.tolist())
    libs[library] = ref_cells

### Optimise the number of masks needed

In [27]:
from tqdm import tqdm
import matplotlib.pyplot as plt

In [28]:
def analyze_correlation_vs_mask_size(single_cell_ref, single_cell_pert, model, adata, batch=4, fraction=0.3, step_size=64):
    """
    Analyze how correlation changes with increasing mask sizes.

    Parameters:
    - single_cell_ref: numpy array, reference cell data
    - single_cell_pert: numpy array, perturbed cell data
    - model: PyTorch model used to sample perturbations
    - adata: AnnData object containing experimental data
    - batch: int, batch size for sampling perturbations
    - fraction: float, fraction of genes to perturb
    - step_size: int, increment in number of masks for each iteration

    Returns:
    - None
    """
    n_genes = single_cell_ref.shape[1]
    max_masks = 5000
    correlations = []
    mask_sizes = list(range(step_size, max_masks + 1, step_size))
    
    #ref = adata[adata.obs.condition == 'ctrl'].X.toarray()
    pert = adata[adata.obs.condition != 'ctrl'].X.toarray()
        
    for n_masks in tqdm(mask_sizes):
        # Generate binary masks
        masks = generate_binary_array(n_masks, n_genes, fraction)
        
        # Sample reference and perturbed cells using masks
        ref_cells = []
        pert_cells = []
        for i in range(0, masks.shape[0] - batch, batch):
            # res_ref = sample_pert(model, Data(x=single_cell_ref), masks[i:i + batch]).detach().cpu().numpy()
            # ref_cells.append(res_ref)
            
            res_pert = sample_pert(model, Data(x=single_cell_pert), masks[i:i + batch]).detach().cpu().numpy()
            pert_cells.append(res_pert)
        
        # # Convert to DataFrame
        # ref_cells = np.array(ref_cells).reshape(-1, single_cell_ref.shape[1])
        # ref_cells = pd.DataFrame(ref_cells, columns=adata.var.index.tolist())
        
        pert_cells = np.array(pert_cells).reshape(-1, single_cell_pert.shape[1])
        pert_cells = pd.DataFrame(pert_cells, columns=adata.var.index.tolist())
        
        corr = np.corrcoef(pert_cells.mean().values, pert.mean(axis=0))[0, 1]
        correlations.append(corr)

    # Plot correlation against mask size
    plt.figure(figsize=(10, 6))
    plt.plot(mask_sizes, correlations, marker='o')
    plt.xlabel('Number of Masks')
    plt.ylabel('Correlation')
    plt.title('Correlation vs. Number of Masks')
    plt.grid()
    plt.show()

analyze_correlation_vs_mask_size(single_cell_ref, single_cell_pert, model, adata, batch=128, fraction=0.15, step_size=128)

 10%|█         | 4/39 [00:04<00:35,  1.02s/it]

KeyboardInterrupt



### After the mask size is optimised, it is time to run the search for the optimal Fibroblast - iPSCs reprogramming factors!
We start by iterating over 3-gene subsets and measuring correlation with all the available reprogramming pathways.

### Generate perturbed cells

In [29]:
pert_data = PertData("../data/")
pert_data.load(data_path='../data/fibroblast_p20-nr/')
pert_data.prepare_split(split='no_split', seed=1)
pert_data.get_dataloader(batch_size=batch_size, test_batch_size=eval_batch_size)

Local copy of pyg dataset is detected. Loading...
Done!
Local copy of split is detected. Loading...
Done!
Creating dataloaders....
Dataloaders created...


{'test_loader': <torch_geometric.deprecation.DataLoader at 0x7f6ef7e4dd20>}

In [30]:
OKSM = ['SOX2', 'KLF4', 'POU5F1', 'MYC', 'NANOG']

TFs = pd.read_csv('../../perturb_train/little_data/TF_db.csv', index_col = 0)
TFs = TFs.loc[:,'HGNC symbol'].tolist()

ipsc_genes = pert_data.adata.var.gene_name.tolist()
to_perturb = list(set(ipsc_genes).intersection(set(TFs)))

tf_exp = pert_data.adata[:,pert_data.adata.var.gene_name.isin(to_perturb)].X

a = (tf_exp.shape[0] - (tf_exp==0).sum(axis=0))/tf_exp.shape[0]
threshold = 0.10  # Set your desired threshold here

# Get the indexes where values are above the threshold
to_perturb = np.array(to_perturb)[np.array(a).squeeze() > threshold]
to_perturb = np.array(list(set(list(to_perturb) + OKSM)))

In [31]:
len(to_perturb)

1252

In [32]:
'SOX2' in to_perturb, 'KLF4' in to_perturb, 'MYC' in to_perturb, 'POU5F1'in to_perturb, 'NANOG' in to_perturb

(True, True, True, True, True)

In [33]:
import random
import math
from tqdm import tqdm

# Fallback implementation of comb for Python versions earlier than 3.8
def comb(n, k):
    return math.factorial(n) // (math.factorial(k) * math.factorial(n - k))

def generate_random_subset(elements, k):
    return tuple(sorted(random.sample(elements, k)))

def generate_random_subsets(elements, k, count):
    total_possible = comb(len(elements), k)
    if count > total_possible:
        raise ValueError(f"Requested count ({count}) exceeds total possible combinations ({total_possible})")
    
    seen = set()
    subsets = []
    
    while len(subsets) < count:
        subset = generate_random_subset(range(len(elements)), k)
        if subset not in seen:
            seen.add(subset)
            subsets.append(tuple(elements[i] for i in subset))
    
    return subsets

def get_subsets(to_perturb, num_subsets=50_000, seed=42):
    random.seed(seed)  # Set seed for reproducibility
    subset_size = 3
    
    print(f"Generating {num_subsets} random {subset_size}-element subsets from {len(to_perturb)} elements")
    return generate_random_subsets(to_perturb, subset_size, num_subsets)

subsets = get_subsets(to_perturb)

print(f"\nTotal subsets generated: {len(subsets)}")
print("\nFirst 5 subsets:")
for i, subset in enumerate(subsets[:5]):
    print(f"Subset {i+1}: {subset}")

print(f"\nLast 5 subsets:")
for i, subset in enumerate(subsets[-5:]):
    print(f"Subset {len(subsets)-4+i}: {subset}")

Generating 50000 random 3-element subsets from 1252 elements

Total subsets generated: 50000

First 5 subsets:
Subset 1: ('ZNF471', 'ICAM1', 'NR4A3')
Subset 2: ('ZNF212', 'NR1H3', 'ZNF282')
Subset 3: ('TAF9B', 'ZBTB9', 'MNAT1')
Subset 4: ('PRAM1', 'BPNT1', 'LGALS9')
Subset 5: ('PSMA6', 'PPP3CA', 'NRIP1')

Last 5 subsets:
Subset 49996: ('MAPK10', 'INSM1', 'ZNF346')
Subset 49997: ('TFAP2E', 'KLF7', 'TCF24')
Subset 49998: ('ZNF451', 'TAF9B', 'POLH')
Subset 49999: ('ERC1', 'PEX14', 'POLE4')
Subset 50000: ('ZC3H14', 'TWIST2', 'ZNF706')


In [34]:
adata_list = []
for sample in ['d20-nr', 'p20-nr', 'd8-fm', 'd4-fm', 'p3-nr']:
    pert_data.load(data_path = f'../data/fibroblast_{sample}')
    adata_list.append(pert_data.adata)

Local copy of pyg dataset is detected. Loading...
Done!
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of pyg dataset is detected. Loading...
Done!


In [35]:
pert_data.prepare_split(split='no_split', seed=1)
pert_data.get_dataloader(batch_size=batch_size, test_batch_size=eval_batch_size,)

Local copy of split is detected. Loading...
Done!
Creating dataloaders....
Dataloaders created...


{'test_loader': <torch_geometric.deprecation.DataLoader at 0x7f6ef8e22470>}

In [36]:
import anndata as ad    
adata_combined = ad.concat(adata_list, join='outer', axis=0)

In [37]:
adata_combined.obs.library.unique()

['D0-fm', 'D20-nr', 'P20-nr', 'D8-fm', 'D4-fm', 'P3-nr']
Categories (6, object): ['D0-fm', 'D4-fm', 'D8-fm', 'D20-nr', 'P3-nr', 'P20-nr']

##### OKSM 3-element subsets data

In [40]:
from itertools import combinations

# Generate all 3-element subsets
subsets_oksm = list(combinations(OKSM, 4))

best_model = model
res = pd.DataFrame()

In [41]:
len(subsets)

50000

##### General Search

In [42]:
from IPython.display import clear_output

In [43]:
def perturb_gene_subset(expression, adata_var_index, model, gene_names_to_perturb, batch=4, n_masks=512, fraction=0.3):
    """
    Perturb a desired subset of genes provided as a list of gene names.

    Parameters:
    - adata: AnnData object containing gene expression data
    - model: PyTorch model used to sample perturbations
    - gene_names_to_perturb: list of strings, names of genes to be perturbed
    - batch: int, batch size for sampling perturbations
    - n_masks: int, number of binary masks to generate for sampling
    - fraction: float, fraction of genes to perturb in each mask

    Returns:
    - ref_cells: DataFrame, reference cell reconstructions
    - pert_cells: DataFrame, perturbed cell reconstructions
    """
    # Extract expression data
    
    expression_pert = expression.copy()
    
    # Find indices of genes to perturb
    gene_indices_to_perturb = [adata_var_index.get_loc(gene) for gene in gene_names_to_perturb]
    
    # Apply overexpression to the selected genes
    expression_pert[:, gene_indices_to_perturb] = expression_pert.max() * 10
    
    # Convert to tensors
    #single_cell_ref = torch.tensor(expression)
    single_cell_pert = torch.tensor(expression_pert)
    
    # Generate binary masks
    n_genes = expression_pert.shape[1]
    masks = generate_binary_array(n_masks, n_genes, fraction)
    
    # Sample reference cells
    # ref_cells = []
    # for i in range(0, masks.shape[0] - batch, batch):
    #     res = sample_pert(model, Data(x=single_cell_ref), masks[i:i + batch]).detach().cpu().numpy()
    #     ref_cells.append(res)
    
    # Sample perturbed cells
    pert_cells = []
    for i in range(0, masks.shape[0] - batch, batch):
        res = sample_pert(model, Data(x=single_cell_pert), masks[i:i + batch]).detach().cpu().numpy()
        pert_cells.append(res)
    
    # Convert results to DataFrames
    # ref_cells = np.array(ref_cells).reshape(-1, n_genes)
    # ref_cells = pd.DataFrame(ref_cells, columns=adata.var.index.tolist())

    pert_cells = np.array(pert_cells).reshape(-1, n_genes)
    pert_cells = pd.DataFrame(pert_cells, columns=adata.var.index.tolist())
    
    return pert_cells

In [None]:
#%%time
res = pd.DataFrame()
counter = 0
adata_var_index = adata.var.index
expression = adata[adata.obs.condition=='ctrl'].X[0,:].toarray()
batch_size = 128
n_masks = 512

# OKSM perturbations
for s in subsets_oksm+[OKSM]:
    genes = [s]
    name = '_'.join(genes[0])
    single_cell = perturb_gene_subset(expression, adata_var_index, model, s, batch=batch_size, n_masks=n_masks, fraction=0.15)
    for lib in libs:
        res.loc[name, f'euclid_distance_{lib}'] = np.linalg.norm(libs[lib].mean(axis=0) - single_cell.mean(axis=0))
        res.loc[name, f'pearson_corr_{lib}'] = np.corrcoef(libs[lib].mean(axis=0), single_cell.mean(axis=0))[0, 1]
    res.to_csv('./save/ipsc_perturbation_sampling_search_oksm.csv')

res = pd.DataFrame()
# Perturbation Search
for s in tqdm(subsets):
    genes = [s]
    name = '_'.join(genes[0])
    single_cell = perturb_gene_subset(expression, adata_var_index, model, s, batch=batch_size, n_masks=n_masks, fraction=0.15)
    #res_express.loc[name,pert_data.adata.var.index[gene_mask]] = single_cell
    
    for lib in libs:
        res.loc[name, f'euclid_distance_{lib}'] = np.linalg.norm(libs[lib].mean(axis=0) - single_cell.mean(axis=0))
        res.loc[name, f'pearson_corr_{lib}'] = np.corrcoef(libs[lib].mean(axis=0), single_cell.mean(axis=0))[0, 1]
    counter+=1
    clear_output()
    if counter>=100:
        res.to_csv('./save/ipsc_perturbation_sampling_search.csv')
        counter = 0
        

  0%|          | 18/50000 [00:31<24:09:13,  1.74s/it]