### Import the libraries

In [3]:
!pip install "numpy<2.0" --force-reinstall

Collecting numpy<2.0
  Using cached numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Using cached numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
Installing collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.26.4
    Uninstalling numpy-1.26.4:
      Successfully uninstalled numpy-1.26.4
Successfully installed numpy-1.26.4


In [3]:
# reinstall scanpy
!pip install scanpy

Collecting scanpy
  Downloading scanpy-1.11.2-py3-none-any.whl.metadata (9.1 kB)
Collecting legacy-api-wrap>=1.4.1 (from scanpy)
  Downloading legacy_api_wrap-1.4.1-py3-none-any.whl.metadata (2.1 kB)
Collecting numba>=0.57.1 (from scanpy)
  Downloading numba-0.61.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.8 kB)
Collecting patsy!=1.0.0 (from scanpy)
  Downloading patsy-1.0.1-py2.py3-none-any.whl.metadata (3.3 kB)
Collecting pynndescent>=0.5.13 (from scanpy)
  Downloading pynndescent-0.5.13-py3-none-any.whl.metadata (6.8 kB)
Collecting seaborn>=0.13.2 (from scanpy)
  Downloading seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)
Collecting session-info2 (from scanpy)
  Downloading session_info2-0.1.2-py3-none-any.whl.metadata (2.5 kB)
Collecting statsmodels>=0.14.4 (from scanpy)
  Downloading statsmodels-0.14.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.2 kB)
Collecting umap-learn>=0.5.6 (from scanpy)
  Downloading umap_learn-0.5.7

In [18]:

# standard libraries
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import scanpy as sc
from scipy.sparse import csr_matrix

# Visualize latent space
import matplotlib.pyplot as plt
import seaborn as sns


In [9]:
from torch.utils.data import DataLoader
from tqdm import tqdm

### Import the data

In [15]:

adata_100m=sc.read_h5ad("/home/ubuntu/anatoly-tahoe-100-texas/data/tahoe-100m_5M.h5ad")




### Normalize the data

In [17]:

# Normalize each cell to 10,000 counts (CPM-like)
sc.pp.normalize_total(adata_100m, target_sum=1e4)

# Log-transform the data
sc.pp.log1p(adata_100m)


### HVG selection (start with small subset)

In [35]:

# Step 2: Identify highly variable genes
sc.pp.highly_variable_genes(adata_100m, n_top_genes=2000, subset=True, flavor="seurat")


### Train and test data split

In [38]:

# get train and test idx
train_idx, test_idx = train_test_split(adata_100m.obs.index, test_size=0.1, random_state=42)

# data split with copy
adata_train = adata_100m[train_idx].copy()
adata_test = adata_100m[test_idx].copy()

### Dataset class

In [39]:

class AdataCVAEWrapper(Dataset):
    def __init__(self, adata, cat_features, cont_features):
        self.X = adata.X  # keep expression data as sparse matrix
        self.cat_data = pd.get_dummies(adata.obs[cat_features], drop_first=False).values.astype(np.float32)
        self.cat_data = torch.from_numpy(self.cat_data)

        cont = adata.obs[cont_features].values.astype(np.float32)
        cont = (cont - cont.mean(axis=0)) / cont.std(axis=0)
        self.cont_data = torch.from_numpy(cont)

        self.cond = torch.cat([self.cat_data, self.cont_data], dim=1)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        x_row = torch.tensor(self.X[idx].toarray().squeeze(), dtype=torch.float32)
        c = self.cond[idx]
        return x_row, c
    

### Neural network class

In [40]:
class CVAE(nn.Module):
    def __init__(self, input_dim, cond_dim, latent_dim=32):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim + cond_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + cond_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 512),
            nn.ReLU(),
            nn.Linear(512, input_dim),
            nn.ReLU()  # or identity / Sigmoid depending on your output
        )

    def encode(self, x, c):
        h = self.encoder(torch.cat([x, c], dim=1))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):
        return self.decoder(torch.cat([z, c], dim=1))

    def forward(self, x, c):
        mu, logvar = self.encode(x, c)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z, c)
        return recon_x, mu, logvar

### Loss function

In [41]:
def loss_function(recon_x, x, mu, logvar):
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_div

### Training on example

### Prepare the data

In [42]:

# setup features
cat_features=["drug", "cell_line_id"]
cont_features=["drug_conc"]

# create train and test datasets
train_dataset = AdataCVAEWrapper(adata_train, cat_features, cont_features)
test_dataset = AdataCVAEWrapper(adata_test, cat_features, cont_features)

# create train and test loaders
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)


In [46]:
adata_train.shape

(900000, 2000)

In [43]:
# Setup
#dataset = AdataCVAEWrapper(adata_100m, cat_features=["drug", "cell_line_id"], cont_features=["drug_conc"])
#loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

# setup the model
model = CVAE(input_dim=adata_train.n_vars, cond_dim=train_dataset.cond.shape[1])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


CVAE(
  (encoder): Sequential(
    (0): Linear(in_features=2146, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=128, bias=True)
    (3): ReLU()
  )
  (fc_mu): Linear(in_features=128, out_features=32, bias=True)
  (fc_logvar): Linear(in_features=128, out_features=32, bias=True)
  (decoder): Sequential(
    (0): Linear(in_features=178, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=2000, bias=True)
    (5): ReLU()
  )
)

## Drug Info from PubChem

We also provide the pubchem IDs for the compounds in Tahoe, this can be used to querry additional information as needed.

### Train model (10 epochs)

In [45]:

# simple example of model train
for epoch in range(10):
    model.train()
    epoch_loss = 0
    for x_batch, c_batch in tqdm(loader):
        x_batch = x_batch.to(device)
        c_batch = c_batch.to(device)

        recon_x, mu, logvar = model(x_batch, c_batch)
        loss = loss_function(recon_x, x_batch, mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {epoch_loss:.2f}")


  0%|                                                                                                                                                                             | 0/7813 [00:02<?, ?it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x62856 and 2146x512)

## Drug Info from PubChem

We also provide the pubchem IDs for the compounds in Tahoe, this can be used to querry additional information as needed.

## Drug Info from PubChem

We also provide the pubchem IDs for the compounds in Tahoe, this can be used to querry additional information as needed.

### Evaluate the model

In [None]:

def evaluate_model(model, test_loader, device):
    model.eval()
    total_loss = 0
    all_recon = []
    all_original = []
    all_latent = []
    all_conditions = []
    
    with torch.no_grad():
        for x_batch, c_batch in tqdm(test_loader):
            x_batch = x_batch.to(device)
            c_batch = c_batch.to(device)
            
            # Forward pass
            recon_x, mu, logvar = model(x_batch, c_batch)
            loss = loss_function(recon_x, x_batch, mu, logvar)
            total_loss += loss.item()
            
            # Store results
            all_recon.append(recon_x.cpu().numpy())
            all_original.append(x_batch.cpu().numpy())
            all_latent.append(mu.cpu().numpy())
            all_conditions.append(c_batch.cpu().numpy())
    
    # Concatenate all batches
    all_recon = np.concatenate(all_recon, axis=0)
    all_original = np.concatenate(all_original, axis=0)
    all_latent = np.concatenate(all_latent, axis=0)
    all_conditions = np.concatenate(all_conditions, axis=0)
    
    return {
        'total_loss': total_loss,
        'recon': all_recon,
        'original': all_original,
        'latent': all_latent,
        'conditions': all_conditions
    }

### Run evalutations

In [None]:

# Run evaluation
eval_results = evaluate_model(model, test_loader, device)
print(f"Test Loss: {eval_results['total_loss']:.2f}")

# Calculate reconstruction metrics
from sklearn.metrics import mean_squared_error, r2_score

# Calculate MSE and R2 for each gene
mse_per_gene = mean_squared_error(eval_results['original'], eval_results['recon'], multioutput='raw_values')
r2_per_gene = r2_score(eval_results['original'], eval_results['recon'], multioutput='raw_values')

print(f"Average MSE per gene: {np.mean(mse_per_gene):.4f}")
print(f"Average R2 per gene: {np.mean(r2_per_gene):.4f}")

### Vizualise the results

In [None]:


# Get the drug indices from the conditions
drug_indices = np.argmax(eval_results['conditions'][:, :len(adata_test.obs['drug'].unique())], axis=1)
drugs = adata_test.obs['drug'].unique()[drug_indices]

# Plot first two dimensions of latent space colored by drug
plt.figure(figsize=(10, 8))
scatter = plt.scatter(eval_results['latent'][:, 0], 
                     eval_results['latent'][:, 1], 
                     c=drug_indices, 
                     cmap='tab20',
                     alpha=0.6)
plt.title('Latent Space Visualization (First 2 Dimensions)')
plt.xlabel('Latent Dimension 1')
plt.ylabel('Latent Dimension 2')
plt.colorbar(scatter, label='Drug')
plt.show()

# Plot reconstruction vs original for a few random genes
n_genes_to_plot = 4
random_genes = np.random.choice(eval_results['original'].shape[1], n_genes_to_plot, replace=False)

fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.ravel()

for idx, gene_idx in enumerate(random_genes):
    axes[idx].scatter(eval_results['original'][:, gene_idx], 
                     eval_results['recon'][:, gene_idx], 
                     alpha=0.5)
    axes[idx].plot([0, 1], [0, 1], 'r--')  # Diagonal line
    axes[idx].set_xlabel('Original Expression')
    axes[idx].set_ylabel('Reconstructed Expression')
    axes[idx].set_title(f'Gene {gene_idx}')

plt.tight_layout()
plt.show()