## OT baseline: use optimal transport (OT) to pair control cells with perturbed (treatment) cells, then train a VAE to learn the control → perturbed transformation. This should be a reasonable baseline.

### Imports

In [1]:
import scanpy as scp
import pandas as pd
import numpy as np
import catboost as cb
from tqdm import tqdm
from scipy import sparse
from sklearn.model_selection import train_test_split

from catboost import CatBoostClassifier, CatBoostRegressor
from lightgbm import LGBMClassifier, LGBMRegressor
from collections import Counter
import lightgbm as lgb
import matplotlib.pyplot as plt
import umap
from sklearn.model_selection import StratifiedKFold

import ot
from sklearn.decomposition import PCA

### Configuration constants

In [48]:
GENE_PER_CELL_BINNING = False  # when True, bin each cell's non-zero values with bin_nonzero_values
N_BINS = 1_000                 # number of bin edges used for per-cell binning
N_ITER = 50                    # NOTE(review): not referenced in the cells visible here — confirm use
TOP_N_GENES = 5_000            # highly variable genes kept after normalisation

### Step 0: load the data

In [49]:
# Load the pre-processed Norman (2019) perturbation AnnData object from disk.
adata = scp.read_h5ad(
    filename='./data/Norman_2019/norman_umi_go/perturb_processed.h5ad',
)

In [50]:
## Following the scGPT paper, we bin each cell's expression values (zeros stay zero).

def bin_nonzero_values(arr, num_bins):
    """Replace the non-zero entries of ``arr`` with integer bin ids; zeros stay 0.

    ``num_bins`` evenly spaced edges are placed from the smallest to the largest
    non-zero value, and each non-zero entry is replaced by its ``np.digitize``
    index. Results therefore lie in ``1..num_bins``. Note that only values equal
    to the maximum land in ``num_bins``.

    Args:
        arr: numeric array (e.g. one cell's expression row).
        num_bins: number of edges passed to ``np.linspace``.

    Returns:
        Array with the same shape and dtype as ``arr`` holding the bin ids.
        If ``arr`` has no non-zero entries, an all-zero array is returned
        (previously this raised ValueError from ``min()`` on an empty array).
    """
    arr = np.asarray(arr)
    nonzero_mask = arr != 0
    binned_values = np.zeros_like(arr)

    if not nonzero_mask.any():
        # Nothing to bin; min()/max() of an empty selection would raise.
        return binned_values

    nonzero_vals = arr[nonzero_mask]
    bin_edges = np.linspace(nonzero_vals.min(), nonzero_vals.max(), num_bins)
    binned_values[nonzero_mask] = np.digitize(nonzero_vals, bin_edges)
    return binned_values

# Example usage
arr = np.random.randint(low=0, high=100, size=100)
num_bins = 3
binned_values = bin_nonzero_values(arr, num_bins)
print(set(binned_values))

{1, 2, 3}


In [51]:
# Normalise per-cell totals, log-transform, then keep only the top highly variable genes.
# NOTE(review): if the loaded .h5ad is already normalised/log-transformed, these steps
# apply the transform a second time — confirm against the data source.
scp.pp.normalize_total(
    adata,
    exclude_highly_expressed=True,
)
scp.pp.log1p(adata)
scp.pp.highly_variable_genes(
    adata,
    n_top_genes=TOP_N_GENES,
    subset=True,
)

In [52]:
if GENE_PER_CELL_BINNING:
    # Densify, bin every cell's row independently, then store the result back as CSR.
    tempy = adata.X.toarray()

    for c, cell_row in enumerate(tqdm(tempy)):
        tempy[c, :] = bin_nonzero_values(cell_row, N_BINS)

    adata.X = sparse.csr_matrix(tempy)
    del tempy

In [105]:
# Per-cell condition labels as plain strings, plus a dense copy of the expression matrix.
y = adata.obs['condition'].values.astype(str)
X = adata.X.toarray()

In [106]:
# Assign every gene an integer id in order of first appearance ('ctrl' is id 0) and
# encode each '+'-separated condition as a list of ids, padded with ctrl to length 2.
gene_num_map = ['ctrl']
_gene_ids = {'ctrl': 0}  # name -> position in gene_num_map, for O(1) lookups
y_processed = []

for rec in tqdm(y):
    ids = []
    comps = rec.split('+')
    for c in comps:
        if c not in _gene_ids:
            _gene_ids[c] = len(gene_num_map)
            gene_num_map.append(c)
        ids.append(_gene_ids[c])
    if len(ids) < 2:
        ids.append(_gene_ids['ctrl'])
    y_processed.append(ids)

100%|██████████████████████████████████████████████████████████| 91205/91205 [00:00<00:00, 324442.86it/s]


In [111]:
def pair_records_optimal_transport(set1, set2, num_itermax=1_000_000):
    """
    Pair each record of set1 with one record of set2 using optimal transport.

    A balanced OT problem with uniform weights on both sets and a Euclidean
    cost is solved with ``ot.emd``. Each set1 record is then matched to the
    set2 record that receives the largest share of its mass. Only that
    arg-max is kept, so a set2 record may be chosen by more than one set1
    record, and some set2 records may not be chosen at all.

    Args:
    - set1 (np.ndarray): First 2D array of records (shape: n1 x features).
    - set2 (np.ndarray): Second 2D array of records (shape: n2 x features).
    - num_itermax (int): iteration cap forwarded to ``ot.emd``. POT's own
      default (100000) is easily reached on problems with thousands of
      records on each side, in which case the returned plan is not optimal.

    Returns:
    - pairs (list of tuples): one ``(i, j)`` per row of set1, in row order,
      where j is the index of the matched record in set2. An empty set1
      yields an empty list.

    Raises:
    - ValueError: if set1 has records but set2 has none.
    """
    set1 = np.asarray(set1)
    set2 = np.asarray(set2)

    n1, n2 = set1.shape[0], set2.shape[0]
    if n1 == 0:
        return []
    if n2 == 0:
        raise ValueError("pair_records_optimal_transport: set2 is empty; nothing to pair with")

    # Euclidean (not squared) distance between every pair of records.
    cost_matrix = ot.dist(set1, set2, metric='euclidean')

    # Uniform marginals on both sides (balanced OT).
    a = ot.unif(n1)
    b = ot.unif(n2)

    transport_plan = ot.emd(a, b, cost_matrix, numItermax=num_itermax)

    # For each set1 row, the set2 column carrying the most transported mass.
    matches = np.argmax(transport_plan, axis=1)
    return [(i, int(j)) for i, j in enumerate(matches)]

    
# Example usage: match 20 random records against 100 random candidates.
set1 = np.random.rand(20, 5000)   # 20 records with 5000 features each
set2 = np.random.rand(100, 5000)  # 100 records with 5000 features each

pairs = pair_records_optimal_transport(set1, set2)
print(pairs)

[(0, 32), (1, 41), (2, 17), (3, 13), (4, 6), (5, 15), (6, 11), (7, 8), (8, 54), (9, 9), (10, 18), (11, 0), (12, 5), (13, 7), (14, 20), (15, 1), (16, 60), (17, 22), (18, 43), (19, 58)]


In [108]:
# Hold out 30% of cells for testing, then 20% of the remainder for validation.
# NOTE(review): the split is per cell and not stratified by condition — confirm intended.
X_train, X_test, y_train, y_test = train_test_split(
    X, y_processed, test_size=0.3, random_state=42
)
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size=0.2, random_state=42
)

In [118]:
def vae_data_prep(X, y):
    """Build OT-matched (control, perturbed) row pairs for every condition in ``y``.

    Labels are compared after sorting, so ``[a, b]`` and ``[b, a]`` are the same
    class. Label ``(0, 0)`` marks control cells. For each class, its cells are
    paired with the control cells via ``pair_records_optimal_transport``.

    Args:
        X: 2-D array, one row per cell, aligned with ``y``.
        y: sequence of integer label lists (one per row of ``X``).

    Returns:
        output_X: control rows (model inputs), one per paired cell.
        output_Y: the perturbed rows each control row was matched to (targets).
        labels: the sorted label tuple for each pair.

    Raises:
        ValueError: if ``y`` has entries but none of them is a control cell.
    """
    y_tuples = [tuple(sorted(yi)) for yi in y]
    unique_classes = set(y_tuples)

    if not unique_classes:
        return np.array([]), np.array([]), []

    # Group row indices by class once, instead of rescanning y for every class.
    indices_by_class = {}
    for i, yi in enumerate(y_tuples):
        indices_by_class.setdefault(yi, []).append(i)

    # Control rows are the same for every class, so select them once.
    ctrl_indices = indices_by_class.get((0, 0), [])
    if not ctrl_indices:
        raise ValueError("vae_data_prep: y contains no control cells (label (0, 0))")
    set2 = X[ctrl_indices]

    output_X = []
    output_Y = []
    labels = []

    for c in tqdm(unique_classes):
        set1 = X[indices_by_class[c]]

        pairs = pair_records_optimal_transport(set1, set2)

        for i, j in pairs:
            output_X.append(set2[j])
            output_Y.append(set1[i])
            labels.append(c)

    return np.array(output_X), np.array(output_Y), labels

In [122]:
# OT-pair control and perturbed cells separately within each split.
X_trainv, Y_trainv, labels_trainv = vae_data_prep(X_train, y_train)
X_valv,   Y_valv,   labels_valv   = vae_data_prep(X_val, y_val)
X_testv,  Y_testv,  labels_testv  = vae_data_prep(X_test, y_test)

100%|██████████████████████████████████████████████████████████████████| 237/237 [02:25<00:00,  1.63it/s]
100%|██████████████████████████████████████████████████████████████████| 237/237 [00:17<00:00, 13.29it/s]
100%|██████████████████████████████████████████████████████████████████| 237/237 [00:50<00:00,  4.69it/s]


In [13]:
# Fit PCA on the training control rows, keeping enough components for 99% of the
# variance, then project the validation and test control rows into that space.
pca = PCA(n_components=0.99)
X_train_r = pca.fit_transform(X_trainv)
X_val_r, X_test_r = pca.transform(X_valv), pca.transform(X_testv)

In [14]:
# Project the perturbed targets with the PCA fitted on the control rows.
Y_train_r, Y_val_r, Y_test_r = (pca.transform(m) for m in (Y_trainv, Y_valv, Y_testv))

In [16]:
# model = LGBMClassifier(verbose=-1, n_jobs=10)
# model.fit(X_train, y_train, eval_set=[(X_val, y_val)], eval_metric='auc_mu',    callbacks=[
#         lgb.early_stopping(stopping_rounds=100),
#         lgb.log_evaluation(1)
#     ])

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    """Dataset over the rows of a single array, stored as a float32 tensor.

    NOTE: this class is redefined further down in the notebook with an
    ``(X, Y, labels)`` signature; the later definition is the one in effect.
    """

    def __init__(self, X):
        self.X = torch.tensor(X, dtype=torch.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        row = self.X[idx]
        return row

class VAE(nn.Module):
    """Fully connected VAE whose latent sample can be shifted per perturbation.

    Args:
        input_dim: number of input (and reconstructed) features.
        latent_dim: width of the latent code. Perturbation ids given to
            ``forward`` index this dimension directly, so it must exceed the
            largest id.
        hidden_dims: encoder hidden-layer widths; the decoder uses them in
            reverse order. The caller's list is copied, never modified.
        pert_shift: value added to a latent coordinate for each active
            perturbation id (5 by default, matching the earlier behaviour).
        sigmoid_output: if True (default) the decoder ends with ``nn.Sigmoid``,
            which bounds reconstructions to (0, 1).
            NOTE(review): targets that can exceed 1 (e.g. log-expression)
            cannot be reproduced with this on — confirm.
    """

    def __init__(self, input_dim, latent_dim, hidden_dims, pert_shift=5.0, sigmoid_output=True):
        super().__init__()

        # Copy so neither this class nor the caller sees the list change.
        hidden_dims = list(hidden_dims)
        self.pert_shift = pert_shift

        self.encoder = self.build_encoder(input_dim, latent_dim, hidden_dims)
        self.decoder = self.build_decoder(
            latent_dim, input_dim, hidden_dims, sigmoid_output=sigmoid_output
        )

        # The heads consume the encoder's output: its last hidden width, or the
        # raw input when there are no hidden layers.
        enc_out_dim = hidden_dims[-1] if hidden_dims else input_dim
        self.fc_mu = nn.Linear(enc_out_dim, latent_dim)
        self.fc_logvar = nn.Linear(enc_out_dim, latent_dim)

    def build_encoder(self, input_dim, latent_dim, hidden_dims):
        """Stack Linear+ReLU layers through ``hidden_dims`` (``latent_dim`` is unused)."""
        layers = []
        for h_dim in hidden_dims:
            layers.append(nn.Linear(input_dim, h_dim))
            layers.append(nn.ReLU())
            input_dim = h_dim
        return nn.Sequential(*layers)

    def build_decoder(self, latent_dim, output_dim, hidden_dims, sigmoid_output=True):
        """Mirror the encoder: Linear+ReLU through reversed ``hidden_dims``, then a
        Linear projection to ``output_dim`` (plus Sigmoid when requested).

        ``hidden_dims`` is not modified (the earlier version reversed it in place).
        """
        layers = []
        in_dim = latent_dim
        for h_dim in reversed(hidden_dims):
            layers.append(nn.Linear(in_dim, h_dim))
            layers.append(nn.ReLU())
            in_dim = h_dim
        layers.append(nn.Linear(in_dim, output_dim))
        if sigmoid_output:
            layers.append(nn.Sigmoid())
        return nn.Sequential(*layers)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``."""
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        NOTE(review): this samples in eval mode too; evaluation code may prefer ``mu``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes back to feature space."""
        return self.decoder(z)

    def _apply_perturbations(self, z, perturbs):
        """Return a copy of ``z`` with ``pert_shift`` added at each active (row, id).

        ``perturbs`` is an iterable of id sequences. Entry ``k`` of each sequence
        applies to row ``k`` of ``z``; ids <= 0 (control) are skipped. Shifts from
        different sequences accumulate.
        """
        for pert_n in perturbs:
            ids = torch.as_tensor(pert_n, dtype=torch.long, device=z.device)
            rows = torch.arange(ids.shape[0], device=z.device)
            active = ids > 0
            shift = z.new_full((int(active.sum()),), self.pert_shift)
            z = z.index_put((rows[active], ids[active]), shift, accumulate=True)
        return z

    def forward(self, x, perturbs=None):
        """Encode ``x``, sample ``z``, apply perturbation shifts, and decode.

        Args:
            x: batch of inputs.
            perturbs: optional iterable of per-row perturbation-id sequences
                (see ``_apply_perturbations``). ``None`` means no perturbation.

        Returns:
            ``(recon_x, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)

        # Perturbations are introduced at the latent-space level.
        if perturbs is not None:
            z = self._apply_perturbations(z, perturbs)

        recon_x = self.decode(z)
        return recon_x, mu, logvar

def vae_loss(recon_x, x, mu, logvar):
    """Return the summed VAE objective: squared reconstruction error plus KL term.

    The KL term is the closed form of KL(N(mu, exp(logvar)) || N(0, I)),
    summed over every element of the batch.
    """
    reconstruction = nn.functional.mse_loss(recon_x, x, reduction='sum')
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    total = reconstruction + kl_divergence
    return total
    
def train_vae(model, train_loader, val_loader, epochs, learning_rate=1e-3):
    """Train ``model`` with Adam, printing the train/val loss after every epoch.

    Each loader batch is ``(batch_ctrl, batch_pert, pert_labels)``. The model is
    called as ``model(batch_ctrl, pert_labels)`` and scored with ``vae_loss``
    against ``batch_pert``. A per-epoch loss is the sum of batch losses divided
    by the number of samples in the loader's dataset.

    Returns:
        List of ``(avg_train_loss, avg_val_loss)`` tuples, one per epoch.
        The model is left in training mode.
    """
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    history = []

    for epoch in range(epochs):
        model.train()
        total_train_loss = 0.0
        for batch_ctrl, batch_pert, pert_labels in train_loader:
            optimizer.zero_grad()

            recon_x, mu, logvar = model(batch_ctrl, pert_labels)

            loss = vae_loss(recon_x, batch_pert, mu, logvar)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()

        # max(..., 1) avoids ZeroDivisionError on an empty dataset.
        avg_train_loss = total_train_loss / max(len(train_loader.dataset), 1)

        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for batch_ctrl, batch_pert, pert_labels in val_loader:
                recon_x, mu, logvar = model(batch_ctrl, pert_labels)
                loss = vae_loss(recon_x, batch_pert, mu, logvar)
                total_val_loss += loss.item()

        avg_val_loss = total_val_loss / max(len(val_loader.dataset), 1)

        print(f'Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')
        history.append((avg_train_loss, avg_val_loss))

    model.train()
    return history

# Example usage:

# # Assuming X_train, X_val, Y_train, Y_val are defined numpy arrays
# X_trainv = np.random.rand(1000, 50)  # Replace with your actual data
# Y_trainv = np.random.rand(1000, 50)  # Replace with your actual class labels
# X_valv = np.random.rand(200, 50)     # Replace with your actual data
# Y_valv = np.random.rand(200, 50) # Replace with your actual class labels

# l_val1 = np.random.randint(0,100,200)
# l_val2 = np.random.randint(0,100,200)
# labels_valv = [[l_val1[i], l_val2[i]] for i in range(len(l_val1))]

# l_train1 = np.random.randint(0,100,1000)
# l_train2 = np.random.randint(0,100,1000)
# labels_train = [[l_train1[i], l_train2[i]] for i in range(len(l_train1))]

class CustomDataset(Dataset):
    """Paired dataset yielding ``(input_row, target_row, label)`` triples.

    ``X`` and ``Y`` are converted to float32 tensors. ``labels`` is kept as
    given and indexed alongside them.
    """

    def __init__(self, X, Y, labels):
        self.X, self.Y = (torch.tensor(arr, dtype=torch.float32) for arr in (X, Y))
        self.labels = labels

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        x_row = self.X[idx]
        y_row = self.Y[idx]
        return x_row, y_row, self.labels[idx]


# Number of leading feature columns fed to the VAE; also used as its input width.
# NOTE(review): these are the first N_FEATURES raw gene columns of the OT-paired
# matrices, not the PCA projections (X_train_r, ...) — confirm which is intended.
N_FEATURES = 50

train_dataset = CustomDataset(X_trainv[:, :N_FEATURES], Y_trainv[:, :N_FEATURES], labels_trainv)
val_dataset = CustomDataset(X_valv[:, :N_FEATURES], Y_valv[:, :N_FEATURES], labels_valv)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

input_dim = N_FEATURES
latent_dim = 125
hidden_dims = [128]
epochs = 20
learning_rate = 1e-3

# VAE.forward shifts z[row, gene_id] directly, so every gene id from gene_num_map
# must be a valid latent index; fail early instead of mid-training.
if latent_dim < len(gene_num_map):
    raise ValueError(
        f"latent_dim={latent_dim} is too small for {len(gene_num_map)} gene ids; "
        f"use at least {len(gene_num_map)}"
    )

vae = VAE(input_dim, latent_dim, hidden_dims)
train_vae(vae, train_loader, val_loader, epochs, learning_rate)

# For testing, you can run the model in evaluation mode
vae.eval()
with torch.no_grad():
    for batch_ctrl, batch_pert, pert_labels in val_loader:
        recon_x, mu, logvar = vae(batch_ctrl, pert_labels)
        z = vae.reparameterize(mu, logvar)
        break

Epoch 1, Train Loss: 3.1179, Val Loss: 2.9125
Epoch 2, Train Loss: 2.9354, Val Loss: 2.9031
