## Simple perturbation prediction

Below I introduce an experiment design allowing to perform perturbation modelling using of-the-shelf models and only scRNAseq data.

0. Normalise as proposed in scGPT and split data into two batches with no class overlap.
1. Predict perturbation type on per-cell basis. This gives some kind of assymmetric similarity (or divergence). Loop 0-1 to obtain multiple measures.
2. Build divergence matrices providing context to every cell.
3. As a baseline, predict gene expression perturbation AB from the profiles of perturbations A and B. Can be done either using multiregression model (simpler but slower) or a DL architecture (e.g. VAE).
4. Assess the success of the synthetic profile using classifiers from (1).
5. Introduce the divergence matrix context and assess whether it helps.
6. Compare the designed matrix to external data sources.

### Imports

In [37]:
del cb

In [148]:
import scanpy as scp
import pandas as pd
import numpy as np
import catboost as cb
from tqdm import tqdm
from scipy import sparse
from sklearn.model_selection import train_test_split

from catboost import CatBoostClassifier, CatBoostRegressor
from lightgbm import LGBMClassifier, LGBMRegressor
from collections import Counter

### Magics

In [268]:
N_BINS = 30
N_ITER = 100
TOP_N_GENES = 100
TOP_N_DIVER = 10

### Step 0

In [269]:
adata = scp.read_h5ad('./data/Norman_2019/norman_umi_go/perturb_processed.h5ad')

In [270]:
## Following the scGPT paper, we bin the genes within cell. 

def bin_nonzero_values(arr, num_bins):
    # Filter out non-zero values
    nonzero_vals = arr[arr != 0]
    
    # Calculate bin edges
    bin_edges = np.linspace(nonzero_vals.min(), nonzero_vals.max(), num_bins)
    
    # Bin the values
    binned_values = np.zeros_like(arr)
    binned_nonzero = np.digitize(nonzero_vals, bin_edges)
    binned_values[arr != 0] = binned_nonzero
    
    return binned_values

# Example usage
arr = np.random.randint(low=0, high=100, size=100)
num_bins = 3
binned_values = bin_nonzero_values(arr, num_bins)
print(set(binned_values))

{0, 1, 2, 3}


In [271]:
scp.pp.normalize_total(adata, exclude_highly_expressed=True)
scp.pp.log1p(adata)
scp.pp.highly_variable_genes(adata, n_top_genes=TOP_N_GENES,subset=True)

In [272]:
tempy = adata.X.toarray()

for c in tqdm(range(adata.X.shape[0])):
    tempy[c,:] = bin_nonzero_values(tempy[c,:], N_BINS)

adata.X = sparse.csr_matrix(tempy)

100%|███████████████████████████████████████████████████████████| 91205/91205 [00:03<00:00, 27102.59it/s]


In [273]:
del tempy

In [274]:
y = adata.obs.condition.values.astype(str)
X = adata.X.toarray()

In [275]:
X.shape

(91205, 100)

In [None]:
unique_classes = np.unique(y)

similarity_res = np.zeros(shape=(X.shape[0], N_ITER),dtype=object)

for i in tqdm(range(N_ITER)):
    # Split the unique classes into two sets
    classes_train, _ = train_test_split(unique_classes, test_size=0.5, random_state=47+i)
    
    # Filter the data based on the selected classes for training and testing
    curr_idx_mask = np.isin(y, classes_train)
    X_train, y_train = X[curr_idx_mask], y[curr_idx_mask]
    X_test, y_test = X[~curr_idx_mask], y[~curr_idx_mask]
    idx = np.arange(X.shape[0])
    curr_idx = idx[curr_idx_mask]
    curr_idx_test = idx[~curr_idx_mask]
    
    # Split the data into train and test sets
    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, 
                                                      test_size=0.2, 
                                                      random_state=42,
                                                      )
    
    model = LGBMClassifier(verbose=-1, n_jobs=10)
    
    # Train the model with validation data
    model.fit(X_train, y_train, eval_set=[(X_val, y_val)], eval_metric='auc_mu')
    
                                                                        # # Create the CatBoost classifier
                                                                        # model = CatBoostClassifier(verbose=True,thread_count=10,)
                                                                        
                                                                        # # Train the model
                                                                        # model.fit(X_train, y_train, 
                                                                        #           eval_set=(X_val, y_val), 
                                                                        #           use_best_model=True)
    
    y_pred = model.predict(X_test)
    similarity_res[curr_idx_test, i] = y_pred

  2%|█▎                                                               | 2/100 [04:21<3:33:43, 130.86s/it]

### Compute the divergences

In [208]:
def top_n_probs(arr, n):
    """
    Computes the top N frequent non-zero elements and their associated probabilities
    in every row of a 2D NumPy array of type object.

    Args:
        arr (numpy.ndarray): A 2D NumPy array of type object.
        n (int): The number of top frequent elements to consider.

    Returns:
        tuple: A tuple containing two arrays:
            - top_n_arr (numpy.ndarray): An array with shape (n_rows, n),
              containing the top N frequent non-zero elements per row.
            - probs_arr (numpy.ndarray): An array with shape (n_rows, n),
              containing the probabilities of the associated top N non-zero elements.
    """
    n_rows = arr.shape[0]
    top_n_arr = np.full((n_rows, n), None, dtype=object)
    probs_arr = np.full((n_rows, n), np.nan, dtype=np.float64)

    for i, row in enumerate(arr):
        non_zero_row = [x for x in row if x != 0]
        counter = Counter(non_zero_row)
        top_n_elements = [item for item, count in counter.most_common(n)]
        probs = [count / sum(counter.values()) for item, count in counter.most_common(n)]

        for j, element in enumerate(top_n_elements):
            top_n_arr[i, j] = element
            probs_arr[i, j] = probs[j]

    return top_n_arr, probs_arr

In [218]:
gene_names = adata.var.gene_name.values

In [209]:
diverg = top_n_probs(similarity_res, TOP_N_DIVER)

In [240]:
for i in tqdm(range(diverg[0].shape[0])):
    for j in range(TOP_N_DIVER):
        if diverg[0][i][j] is not None:
            diverg[0][i][j] = diverg[0][i][j].replace('+ctrl','').replace('ctrl+','')

100%|██████████████████████████████████████████████████████████| 91205/91205 [00:00<00:00, 302502.87it/s]


In [241]:
enrich_mtx = np.zeros(shape=(adata.obs.shape[0], TOP_N_GENES))
for i in range(enrich_mtx.shape[0]):
    for j in range(2):
        if diverg[0][i][j] in gene_names:
            enrich_mtx[i, diverg[0][i][j] == gene_names] = diverg[1][i][j]

In [244]:
enrich_mtx.shape

(91205, 100)

In [245]:
(enrich_mtx!=0).sum()

1944

## Preparing data from the second part - first, naive without the divergences, then adding it

In [246]:
y = np.array([i.translate(str.maketrans('', '', '+ctrl')) if ('+' in i and 'ctrl' in i) else i for i in y])

In [248]:
unique_classes = np.unique(y)
classes_train, classes_test = train_test_split(unique_classes, test_size=0.5, random_state=47+i)

# Filter the data based on the selected classes for training and testing
curr_idx_mask = np.isin(y, classes_train)
X_train, y_train = X[curr_idx_mask], y[curr_idx_mask]
X_test, y_test = X[~curr_idx_mask], y[~curr_idx_mask]

diverg_tr = enrich_mtx[curr_idx_mask]
diverg_ts = enrich_mtx[~curr_idx_mask]

double_pert_ids_tr = []
for s in range(y_train.shape[0]):
    if '+' in y_train[s] and 'ctrl' not in y_train[s]:
        double_pert_ids_tr.append(s)

double_pert_ids_ts = []
for s in range(y_test.shape[0]):
    if '+' in y_test[s] and 'ctrl' not in y_test[s]:
        double_pert_ids_ts.append(s)


# First for training
dic = {}

p2_train = []
for dp_id in tqdm(double_pert_ids_tr):
    c1, c2 = y_train[dp_id].split('+')
    
    if c1 in y_train and c2 in y_train:
        
        if c1 not in dic.keys():
            dic[c1] = np.where(np.array(y_train)==c1)[0]
        pos1 = np.random.choice(dic[c1])
        
        if c2 not in dic.keys():
            dic[c2] = np.where(np.array(y_train)==c1)[0]
        pos2 = np.random.choice(dic[c2])
        
        p2_train.append((pos1, pos2, dp_id))

# Then for testing
p2_test = []
for dp_id in tqdm(double_pert_ids_ts):
    c1, c2 = y_test[dp_id].split('+')
    
    if c1 in y_test and c2 in y_test:
        
        if c1 not in dic.keys():
            dic[c1] = np.where(np.array(y_test)==c1)[0]
        pos1 = np.random.choice(dic[c1])
        
        if c2 not in dic.keys():
            dic[c2] = np.where(np.array(y_test)==c1)[0]
        pos2 = np.random.choice(dic[c2])
        
        p2_test.append((pos1, pos2, dp_id))

100%|████████████████████████████████████████████████████████████| 16517/16517 [00:04<00:00, 3753.72it/s]
100%|████████████████████████████████████████████████████████████| 18928/18928 [00:05<00:00, 3253.59it/s]


In [259]:
X2_train = np.zeros(shape=(len(p2_train),TOP_N_GENES*2+diverg_tr.shape[1]))
Y2_train = np.zeros(shape=(len(p2_train),TOP_N_GENES))

for i in tqdm(range(Y2_train.shape[0])):
    X2_train[i,:TOP_N_GENES] = X[p2_train[i][0]]
    X2_train[i,TOP_N_GENES:TOP_N_GENES*2] = X[p2_train[i][1]]
    X2_train[i,TOP_N_GENES*2:diverg_tr.shape[1]+TOP_N_GENES*2] = diverg_tr[p2_train[i][0]][0]
    X2_train[i,TOP_N_GENES*2+diverg_tr.shape[1]:] = diverg_tr[p2_train[i][1]][0]
    Y2_train[i,:] = X[p2_train[i][2]]
    

X2_test = np.zeros(shape=(len(p2_test),TOP_N_GENES*2+diverg_ts.shape[1]))
Y2_test = np.zeros(shape=(len(p2_test),TOP_N_GENES))

for i in tqdm(range(Y2_test.shape[0])):
    X2_test[i,:TOP_N_GENES] = X[p2_dataset[i][0]]
    X2_test[i,TOP_N_GENES:TOP_N_GENES*2] = X[p2_dataset[i][1]]
    X2_test[i,TOP_N_GENES*2:diverg_ts.shape[1]+TOP_N_GENES*2] = diverg_ts[p2_test[i][0]][0]
    X2_test[i,TOP_N_GENES*2+diverg_ts.shape[1]:] = diverg_ts[p2_test[i][1]][0]
    
    Y2_test[i,:] = X[p2_dataset[i][2]]

100%|████████████████████████████████████████████████████████████| 2920/2920 [00:00<00:00, 190312.45it/s]
100%|████████████████████████████████████████████████████████████| 6795/6795 [00:00<00:00, 168841.61it/s]


In [260]:
X2_test.shape, X2_train.shape

((6795, 300), (2920, 300))

In [261]:
X2_train, X2_val, Y2_train, Y2_val = train_test_split(X2_train, Y2_train, 
                                                  test_size=0.3, 
                                                  random_state=42,
                                                  )

In [263]:
params = {'learning_rate': 0.3, 
          'depth': 12, 
          'l2_leaf_reg': 3, 
          'loss_function': 'MultiRMSE', 
          'eval_metric': 'MultiRMSE', 
          'task_type': 'CPU', 
          'iterations': 150,
          'od_type': 'Iter', 
          'boosting_type': 'Plain', 
          'bootstrap_type': 'Bernoulli', 
          'allow_const_label': True, 
         }

model = CatBoostRegressor(**params)
#model = LGBMRegressor(objective='regression_l2', metric='l2', num_iterations=50, verbose=-1)

# Train the model
model.fit(X2_train, Y2_train, eval_set=[(X2_val, Y2_val)], early_stopping_rounds = 50, 
                              use_best_model = True, verbose = 1)

# Make predictions on the test set
y_pred = model.predict(X2_test)

0:	learn: 39.5022582	test: 41.7953592	best: 41.7953592 (0)	total: 10.2s	remaining: 25m 12s
1:	learn: 38.2300950	test: 41.8848557	best: 41.7953592 (0)	total: 19.9s	remaining: 24m 31s
2:	learn: 37.5985850	test: 41.9403278	best: 41.7953592 (0)	total: 29.2s	remaining: 23m 49s


KeyboardInterrupt: 

In [164]:
pd.DataFrame([y_pred[0],Y2_test[0]]).T.head(30)

Unnamed: 0,0,1
0,0.013978,0.0
1,0.0,0.0
2,0.125816,0.0
3,0.844329,0.0
4,0.118969,0.0
5,-0.001614,0.0
6,0.000161,0.0
7,1.199218,1.0
8,0.114377,0.0
9,1.65855,0.0


array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,
        0.,  1.,  0.,  1.,  0.,  1.,  0.,  0.,  0.,  0.,  0., 16., 16.,
        0.,  0.,  0.,  0.,  0.,  0.,  9.,  0.,  0.,  0., 22.,  0.,  0.,
        1.,  1., 11.,  0.,  0.,  0.,  0.,  0.,  0., 15.,  0.,  0., 13.,
        0., 13.,  0.,  0.,  0.,  0.,  0., 11.,  0.,  0., 20.,  6.,  0.,
        1.,  0.,  0.,  0.,  1.,  1.,  0.,  0.,  1.,  0., 12., 30.,  1.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  1.,  0.,
        0.,  0.,  9.,  0.,  1.,  9.,  1.,  6.,  0.])