

### Mount Directory

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os

In [3]:
os.chdir('/content/drive/MyDrive/tf design')

### Dataloaders

In [4]:
import torch
import random
from collections import defaultdict
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

In [5]:
!pip install pytorch-lightning


Collecting pytorch-lightning
  Downloading pytorch_lightning-2.1.2-py3-none-any.whl (776 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m776.9/776.9 kB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
Collecting torchmetrics>=0.7.0 (from pytorch-lightning)
  Downloading torchmetrics-1.2.0-py3-none-any.whl (805 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m805.2/805.2 kB[0m [31m47.7 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.8.0 (from pytorch-lightning)
  Downloading lightning_utilities-0.10.0-py3-none-any.whl (24 kB)
Installing collected packages: lightning-utilities, torchmetrics, pytorch-lightning
Successfully installed lightning-utilities-0.10.0 pytorch-lightning-2.1.2 torchmetrics-1.2.0


In [6]:
import torch
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl

def pad_tensor(tensor, max_length, dim=0):
    tensor = torch.tensor(tensor) if isinstance(tensor, np.ndarray) else tensor
    pad_size = max_length - tensor.shape[dim]
    padding = (0, 0) * (tensor.dim() - dim - 1) + (0, pad_size)
    return torch.nn.functional.pad(tensor, padding)

class CellProteinDataset(torch.utils.data.Dataset):
    def __init__(self, joint_embeddings, esm_embeddings):
        assert len(joint_embeddings) == len(esm_embeddings), "Dictionaries must have the same length"

        self.joint_keys = list(joint_embeddings.keys())
        self.esm_keys = list(esm_embeddings.keys())

        self.joint_embeddings = joint_embeddings
        self.esm_embeddings = esm_embeddings

        # Compute the maximum embedding size
        self.max_joint_dim = max(tensor.shape[0] for tensor in joint_embeddings.values())
        self.max_esm_dim = max(tensor.shape[0] for tensor in esm_embeddings.values())

    def __len__(self):
        return len(self.joint_embeddings)

    def __getitem__(self, index):
        joint_key = self.joint_keys[index]
        esm_key = self.esm_keys[index]

        joint_embedding = pad_tensor(self.joint_embeddings[joint_key], self.max_joint_dim)
        esm_embedding = pad_tensor(self.esm_embeddings[esm_key], self.max_esm_dim)

        return {
            "cell_input": joint_embedding,
            "protein_input": esm_embedding
        }

class CellProteinCollator:
    def __call__(self, raw_batch):
        batch = {}

        cell_input_list = [v['cell_input'] for v in raw_batch]
        protein_input_list = [v['protein_input'] for v in raw_batch]

        batch['cell_input'] = torch.stack(cell_input_list)
        batch['protein_input'] = torch.stack(protein_input_list)

        return batch

class CellProteinDataModule(pl.LightningDataModule):
    def __init__(self, joint_embeddings, esm_embeddings, batch_size):
        super().__init__()

        dataset = CellProteinDataset(joint_embeddings, esm_embeddings)

        # Splitting data into train, test, and validation
        train_size = int(0.7 * len(dataset))
        val_size = int(0.15 * len(dataset))
        test_size = len(dataset) - train_size - val_size

        self.train_dataset, self.val_dataset, self.test_dataset = random_split(dataset, [train_size, val_size, test_size])

        self.batch_size = batch_size
        self.collator = CellProteinCollator()

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, collate_fn=self.collator, shuffle=True,drop_last=True)

    def val_dataloader(self):
        full_batch = DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=self.collator, shuffle=False,drop_last=True)
        #binary_batch = DataLoader(self.val_dataset, batch_size=2, collate_fn=self.collator, shuffle=False)
        return [full_batch]#, binary_batch]

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, collate_fn=self.collator, shuffle=False,drop_last=True)



In [7]:
import pytorch_lightning as pl


In [8]:
class CellProteinDataModule(pl.LightningDataModule):
    def __init__(self, joint_embeddings, esm_embeddings, batch_size, train_keys_list=None, val_test_keys_list=None):
        super().__init__()

        # Subset dictionaries based on provided keys
        train_joint_embeddings = {key: joint_embeddings[key] for key in train_keys_list} if train_keys_list else joint_embeddings
        train_esm_embeddings = {key: esm_embeddings[key] for key in train_keys_list} if train_keys_list else esm_embeddings

        val_test_joint_embeddings = {key: joint_embeddings[key] for key in val_test_keys_list} if val_test_keys_list else joint_embeddings
        val_test_esm_embeddings = {key: esm_embeddings[key] for key in val_test_keys_list} if val_test_keys_list else esm_embeddings

        # Create datasets
        self.train_dataset = CellProteinDataset(train_joint_embeddings, train_esm_embeddings)
        val_test_dataset = CellProteinDataset(val_test_joint_embeddings, val_test_esm_embeddings)

        # Splitting val_test_dataset into validation and test sets
        val_size = int(0.15 * len(val_test_dataset))
        test_size = len(val_test_dataset) - val_size

        self.val_dataset, self.test_dataset = random_split(val_test_dataset, [val_size, test_size])

        self.batch_size = batch_size
        self.collator = CellProteinCollator()

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, collate_fn=self.collator, shuffle=True,drop_last=True)

    def val_dataloader(self):
        full_batch = DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=self.collator, shuffle=False,drop_last=True)
        binary_batch = DataLoader(self.val_dataset, batch_size=2, collate_fn=self.collator, shuffle=False)
        return [full_batch,binary_batch]

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, collate_fn=self.collator, shuffle=False,drop_last=True)




#### load norman19 dataloaders

In [9]:
import pickle
from torch.utils.data import DataLoader


In [10]:
!ls

 1010_cvae_gen_v1.csv
 1010_cvae_gen_v1.gsheet
 1017_cvae_gen_v1.csv
 1017_cvae_gen_v1.gsheet
 200_perts_subset_df_sequences.csv
 bmcite_demo.rds
 bone_marrow_beta_0.0_generated_dict.pkl
 bone_marrow_beta_0.1_generated_dict.pkl
 bone_marrow_beta_0.2_generated_dict.pkl
 bone_marrow_beta_0.3_generated_dict.pkl
 bone_marrow_beta_0.4_generated_dict.pkl
 bone_marrow_beta_0.5_generated_dict.pkl
 bone_marrow_beta_0.6_generated_dict.pkl
 bone_marrow_beta_0.7_generated_dict.pkl
 bone_marrow_beta_0.8_generated_dict.pkl
 bone_marrow_beta_0.9_generated_dict.pkl
 bone_marrow_beta_1.0_generated_dict.pkl
 bone_marrow_combon_generation
 bone_marrow_iteration_090823-attention-beta1_generated_dict.pkl
 bone_marrow_iteration_092623-attention-beta1_generated_dict.pkl
 bone_marrow_iteration_092723-ml-beta1_generated_dict.pkl
 bone_marrow_iteration1_generated_dict.pkl
 bone_marrow_iteration2_generated_dict.pkl
 bone_marrow_iteration3_generated_dict_attention_standard.pkl
 bone_marrow_iteration3_generated_di

In [11]:
with open('train_dataset_norman_ro.pkl', 'rb') as f:
    train_dataset_norman = pickle.load(f)

with open('val_dataset_norman_ro.pkl', 'rb') as f:
    val_dataset_norman = pickle.load(f)

with open('test_dataset_norman_ro.pkl', 'rb') as f:
    test_dataset_norman = pickle.load(f)


In [12]:
from torch.utils.data import Subset
train_indices = list(range(30))
test_indices = list(range(15))
val_indices = list(range(15))

# Create subsets
train_subset = Subset(train_dataset_norman, train_indices)
test_subset = Subset(test_dataset_norman, test_indices)
val_subset = Subset(val_dataset_norman, val_indices)

In [13]:
len(train_subset)

30

In [14]:
type(train_dataset_norman)

__main__.CellProteinDataset

In [15]:
with open('dataloader_params.pkl', 'rb') as f:
    dataloader_params = pickle.load(f)


In [16]:
dataloader_params

{'train': {'batch_size': 4, 'shuffle': 'True'},
 'val': {'batch_size': 4, 'shuffle': 'False'},
 'test': {'batch_size': 4, 'shuffle': 'False'}}

In [17]:
train_loader_norman = DataLoader(train_subset, **dataloader_params['train'])
val_loader_norman = DataLoader(val_subset, **dataloader_params['val'])
test_loader_norman = DataLoader(test_subset, **dataloader_params['test'])


In [18]:
from torch.utils.data import Subset

In [19]:
# Converting string 'True'/'False' to boolean
dataloader_params['train']['shuffle'] = dataloader_params['train']['shuffle'] == 'True'
dataloader_params['val']['shuffle'] = dataloader_params['val']['shuffle'] == 'True'
dataloader_params['test']['shuffle'] = dataloader_params['test']['shuffle'] == 'True'


In [20]:
import numpy as np

In [21]:
esm_dataset = train_loader_norman.dataset

In [22]:
esm_dataset[1]['protein_input'].shape

torch.Size([2078, 1280])

In [23]:
esm_dataset[1]

{'cell_input': tensor([[-1.5663,  1.1364,  0.6010,  ...,  0.1193, -0.1001,  0.0477],
         [-1.3502,  0.4448,  1.6567,  ...,  0.1193, -0.1001,  0.0477],
         [ 0.0071,  0.2452, -0.1685,  ...,  0.1193, -0.1001,  0.0477],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]),
 'protein_input': tensor([[-0.0160, -0.1942,  0.1267,  ...,  0.0727, -0.1567,  0.1401],
         [ 0.0547,  0.1077,  0.1109,  ...,  0.0577,  0.0728,  0.1995],
         [ 0.0195,  0.0153,  0.2484,  ..., -0.0064,  0.1929, -0.0309],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]])}

### PCA of train,val,test dataset -- view clusters

In [None]:
type(train_dataset_norman)

In [None]:
import numpy as np
import torch
from sklearn.decomposition import IncrementalPCA
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

# Function to estimate memory usage of a numpy array
def estimate_memory_usage(array, unit='bytes'):
    memory_bytes = array.nbytes
    unit_factor = {'bytes': 1, 'kilobytes': 1024, 'megabytes': 1024**2, 'gigabytes': 1024**3}
    return memory_bytes / unit_factor[unit]

# Function to retrieve a sample batch from a DataLoader
def get_sample_batch(dataloader, key, sample_size=4):
    accumulated_data = []
    accumulated_size = 0
    for batch in dataloader:
        # Flatten the data except for the batch dimension
        data = batch[key].reshape(batch[key].shape[0], -1).numpy()
        accumulated_data.append(data)
        accumulated_size += data.shape[0]
        if accumulated_size >= sample_size:
            break
    # Concatenate all collected data and return only the required sample size
    accumulated_data = np.concatenate(accumulated_data, axis=0)
    return accumulated_data[:sample_size]

# Function to perform incremental PCA on a single feature type
def incremental_pca(dataloader, feature_key, scaler, ipca, max_batch_size):
    for i, batch in enumerate(dataloader):
        data = batch[feature_key].reshape(batch[feature_key].shape[0], -1).numpy()
        print(f"Batch {i} size: {data.shape[0]}")  # Print the size of each batch
        if data.shape[0] < 2:
            print(f"Skipping batch {i} because it's smaller than n_components")
            continue  # Skip this batch
        for start_idx in range(0, data.shape[0], max_batch_size):
            end_idx = start_idx + max_batch_size
            batch_data = data[start_idx:end_idx]
            batch_data = scaler.transform(batch_data)  # Standardize the features
            ipca.partial_fit(batch_data)  # Partial fit on the standardized features


# Function to transform features using the fitted scaler and PCA
def transform_features(dataloader, feature_key, scaler, ipca, max_batch_size):
    transformed_data = []
    for batch in dataloader:
        data = batch[feature_key].reshape(batch[feature_key].shape[0], -1).numpy()
        for start_idx in range(0, data.shape[0], max_batch_size):
            end_idx = start_idx + max_batch_size
            batch_data = data[start_idx:end_idx]
            batch_data = scaler.transform(batch_data)  # Standardize the features
            transformed_batch = ipca.transform(batch_data)  # Apply PCA transformation
            transformed_data.append(transformed_batch)
    return np.concatenate(transformed_data, axis=0)

# Get a single batch to estimate memory usage
single_batch = next(iter(train_loader_norman))
cell_input_sample = single_batch['cell_input'].reshape(single_batch['cell_input'].shape[0], -1).numpy()
protein_input_sample = single_batch['protein_input'].reshape(single_batch['protein_input'].shape[0], -1).numpy()

# Estimate memory usage for each type of input
cell_memory_usage = estimate_memory_usage(cell_input_sample, unit='bytes')
protein_memory_usage = estimate_memory_usage(protein_input_sample, unit='bytes')

# Define an acceptable memory usage per batch, e.g., 1 GB in a system with 8 GB of RAM
acceptable_memory_per_batch = 1 * 1024**3  # in bytes

# Calculate the maximum allowable batch size for each input type
max_batch_size_cell = acceptable_memory_per_batch // cell_memory_usage
max_batch_size_protein = acceptable_memory_per_batch // protein_memory_usage

# Choose the smaller batch size to ensure both fit into memory
max_batch_size = min(max_batch_size_cell, max_batch_size_protein)

# Ensure the max_batch_size is an integer and at least 1
max_batch_size = int(max(4, max_batch_size))

# Initialize StandardScaler and IncrementalPCA for each input type
scaler_cell = StandardScaler()
scaler_protein = StandardScaler()
ipca_cell = IncrementalPCA(n_components=2)
ipca_protein = IncrementalPCA(n_components=2)

# Fit the scalers on a sample batch
memory_for_scaler_fitting = estimate_memory_usage(single_batch['cell_input'], unit='bytes')
sample_size_to_fit_scaler = int(memory_for_scaler_fitting // cell_memory_usage)
sample_size_to_fit_scaler = int(max(4, min(sample_size_to_fit_scaler, len(train_loader_norman.dataset))))

scaler_cell.fit(get_sample_batch(train_loader_norman, 'cell_input', sample_size_to_fit_scaler))
scaler_protein.fit(get_sample_batch(train_loader_norman, 'protein_input', sample_size_to_fit_scaler))

# Perform the incremental PCA on the dataloaders for each feature type
for dataloader in [train_loader_norman, val_loader_norman, test_loader_norman]:
    incremental_pca(dataloader, 'cell_input', scaler_cell, ipca_cell, max_batch_size)
    incremental_pca(dataloader, 'protein_input', scaler_protein, ipca_protein, max_batch_size)

# Transform features and get the PCA results for plotting
transformed_data_cell = transform_features(train_loader_norman, 'cell_input', scaler_cell, ipca_cell, max_batch_size)
transformed_data_protein = transform_features(train_loader_norman, 'protein_input', scaler_protein, ipca_protein, max_batch_size)

# Plot the PCA results
plt.figure(figsize=(8, 6))
plt.scatter(transformed_data_cell[:, 0], transformed_data_cell[:, 1], alpha=0.5, label='Cell Input')
plt.scatter(transformed_data_protein[:, 0], transformed_data_protein[:, 1], alpha=0.5, label='Protein Input')
plt.title('PCA of Cell State and ESM-2 Embeddings')
plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
plt.legend()
plt.show()


In [None]:
# Plot the PCA results
plt.figure(figsize=(8, 6))
plt.scatter(transformed_data_cell[:, 0], transformed_data_cell[:, 1], alpha=0.5, label='Cell Input')
plt.title('PCA of Cell State and ESM-2 Embeddings')
plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
plt.legend()
plt.show()

In [None]:
total_number_of_samples

### VAE -- cell state condition, outputs protein sequence

In [24]:
import numpy as np

In [25]:
def pad_tensor(tensor, max_length, dim=0):
    if isinstance(tensor, np.ndarray):
        tensor = torch.tensor(tensor)

    pad_size = max_length - tensor.shape[dim]
    padding = (0, 0) * (tensor.dim() - dim - 1) + (0, pad_size)
    return torch.nn.functional.pad(tensor, padding)


In [26]:
esm_embedding_dim = 1280
joint_embedding_dim = 20

In [27]:
print(esm_embedding_dim,joint_embedding_dim)

1280 20


In [28]:
### load esm-2 embedding library before generation!!

In [29]:
import pandas as pd
from tqdm import tqdm

In [30]:
!pip install fair-esm
##setting up ESM
import torch
import esm
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
esm_model.eval()  # disables dropout for deterministic results
esm_model.cuda() #push model to gpu

Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/93.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fair-esm
Successfully installed fair-esm-2.0.0


Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t33_650M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D-contact-regression.pt


ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

In [31]:
# list = list("ARNDCEQGHILKMFPSTWYV")

In [32]:
# len(list)

In [33]:
import torch.nn as nn
import torch.nn.functional as F

class ReshapeEmbedding(nn.Module):
    def __init__(self, from_seq_len, to_seq_len, embed_dim):
        super(ReshapeEmbedding, self).__init__()
        self.projection = nn.Linear(from_seq_len, to_seq_len)

    def forward(self, x):
        x = x.permute(0, 2, 1)
        x = self.projection(x)
        x = x.permute(0, 2, 1)
        return x


### no attn

In [None]:

class ConditionalVAE(nn.Module):
    def __init__(self, esm_model, esm_embedding_dim=1280, hidden_dim=512, latent_dim=256, joint_embedding_dim=20, dropout_rate=0.5, beta=1):
        super().__init__()

        self.esm_model = esm_model
        self.reshaper = ReshapeEmbedding(8350, 2078, joint_embedding_dim)
        seq_len = 2078

        # Encoder
        self.fc_encode1 = nn.Linear(seq_len * (esm_embedding_dim + joint_embedding_dim), hidden_dim)
        self.fc_encode1_bn = nn.BatchNorm1d(hidden_dim)
        self.fc_encode_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_encode_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder
        self.fc_decode1 = nn.Linear(latent_dim + seq_len * joint_embedding_dim, hidden_dim)
        self.fc_decode1_bn = nn.BatchNorm1d(hidden_dim)
        self.fc_decode2 = nn.Linear(hidden_dim, seq_len * esm_embedding_dim)

        self.dropout = nn.Dropout(dropout_rate)
        self.leaky_relu = nn.LeakyReLU(0.1)
        self.beta = beta

    def encode(self, concatenated_cell_protein_embedding):
        h1 = self.leaky_relu(self.fc_encode1(concatenated_cell_protein_embedding))
        h1 = self.fc_encode1_bn(h1)
        h1 = self.dropout(h1)
        return self.fc_encode_mu(h1), self.fc_encode_logvar(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, joint_embedding):
        z_joint = torch.cat([z, joint_embedding.reshape(joint_embedding.size(0), -1)], dim=1)
        h2 = self.leaky_relu(self.fc_decode1(z_joint))
        h2 = self.fc_decode1_bn(h2)
        h2 = self.dropout(h2)
        embeddings = self.fc_decode2(h2)
        return embeddings.view(embeddings.size(0), -1, esm_embedding_dim)

    def embeddings_to_sequence(self, embeddings):
        aa_toks = alphabet.all_toks
        aa_idxs = [alphabet.get_idx(aa) for aa in aa_toks]
        aa_logits = self.esm_model.lm_head(embeddings)[:, aa_idxs]
        predictions = torch.argmax(aa_logits, dim=1).tolist()
        generated_peptides = [aa_toks[pred] for pred in predictions]
        return generated_peptides

    def forward(self, esm_embedding, joint_embedding):
        joint_embedding = self.reshaper(joint_embedding)

        # Padding esm_embedding to match the sequence length of joint_embedding
        pad_size = joint_embedding.size(1) - esm_embedding.size(1)
        if pad_size > 0:
            padding = torch.zeros((esm_embedding.size(0), pad_size, esm_embedding.size(2)), device=esm_embedding.device)
            print(esm_embedding.shape)
            esm_embedding = torch.cat([esm_embedding, padding], dim=1)
            print(esm_embedding.shape)


        concatenated_input = torch.cat([esm_embedding, joint_embedding], dim=2).view(esm_embedding.size(0), -1)
        mu, logvar = self.encode(concatenated_input)
        z = self.reparameterize(mu, logvar)
        reconstructed_protein = self.decode(z, joint_embedding)
        return reconstructed_protein, mu, logvar

    def loss_function(self, recon_x, x, mu, logvar):
        pad_size = recon_x.size(1) - x.size(1)
        if pad_size > 0:
            padding = torch.zeros((x.size(0), pad_size, x.size(2)), device=x.device)
            x = torch.cat([x, padding], dim=1)
        MSE = F.mse_loss(recon_x, x)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return MSE + self.beta * KLD, MSE, KLD



### with attn


> num_attention_heads: The number of attention heads.

> attention_head_size: The size of each attention head.

> attention_dropout_rate: The dropout rate applied to the attention scores.

> attention_weight_init_gain: The gain factor used for initializing attention weights.

In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class LinearBlock(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearBlock, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.bn = nn.BatchNorm1d(output_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        print('forward linear')
        return self.relu(self.bn(self.linear(x)))

class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(MLP, self).__init__()
        self.lb1 = LinearBlock(input_dim, hidden_dim)
        self.lb2 = LinearBlock(hidden_dim, hidden_dim)
        self.lb3 = LinearBlock(hidden_dim, output_dim)

    def forward(self, x):
        x = self.lb1(x)
        x = self.lb2(x)
        x = self.lb3(x)
        print('forward MLP')
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, hidden_dim, dropout_rate=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.attention_head_size = int(hidden_dim / num_heads)
        self.all_head_size = self.num_heads * self.attention_head_size

        self.query = nn.Linear(hidden_dim, self.all_head_size)
        self.key = nn.Linear(hidden_dim, self.all_head_size)
        self.value = nn.Linear(hidden_dim, self.all_head_size)

        self.dropout = nn.Dropout(dropout_rate)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        print(x.shape)
        print('transpose done')
        return x.permute(0, 2, 1) # changed from 0,2,1,3 to 0,2,1

    def forward(self, hidden_states):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)
        print('forward MHA transpose')

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        print('attn scores MHA forward')

        attention_probs = self.dropout(attention_probs)
        print('done dropout MHA')

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1).contiguous() # same permute change here
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        print('forward MHA done')
        return context_layer

class ConditionalVAE(nn.Module):
    def __init__(self, esm_model, esm_embedding_dim=1280, hidden_dim=512, latent_dim=256, joint_embedding_dim=20, dropout_rate=0.5, beta=1, num_attention_heads=4):
        super().__init__()

        self.esm_model = esm_model
        self.reshaper = ReshapeEmbedding(8350, 2078, joint_embedding_dim)  # Assuming this class is already defined elsewhere
        print('reshaper done...')
        seq_len = 2078

        # Encoder
        self.encoder_mlp = MLP(seq_len * (esm_embedding_dim + joint_embedding_dim), hidden_dim, hidden_dim)
        print('encoder mlp done..')
        self.attention = MultiHeadAttention(num_attention_heads, hidden_dim, dropout_rate)
        print('attention done...')
        self.fc_encode_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_encode_logvar = nn.Linear(hidden_dim, latent_dim)
        print('logvar and mu done (encoder)...')

        # Decoder
        self.decoder_mlp = MLP(latent_dim + seq_len * joint_embedding_dim, hidden_dim, hidden_dim)
        print('decoder mlp done...')
        self.fc_decode2 = nn.Linear(hidden_dim, seq_len * esm_embedding_dim)
        print('decode linear done...')
        self.sigmoid = nn.Sigmoid()
        print('decoder sigmoid done...')
        self.dropout = nn.Dropout(dropout_rate)
        print('dropout applied...')
        self.beta = beta
        print('beta appllied...')

    def encode(self, concatenated_cell_protein_embedding):
        h1 = self.encoder_mlp(concatenated_cell_protein_embedding)
        h1 = self.attention(h1)
        h1 = self.dropout(h1)
        return self.fc_encode_mu(h1), self.fc_encode_logvar(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, joint_embedding):
        z_joint = torch.cat([z, joint_embedding.reshape(joint_embedding.size(0), -1)], dim=1)
        h2 = self.decoder_mlp(z_joint)
        h2 = self.dropout(h2)
        recon_seq = self.fc_decode2(h2)
        recon_seq = self.sigmoid(recon_seq)
        return recon_seq.view(recon_seq.size(0), -1, esm_embedding_dim)

    def forward(self, esm_embedding, joint_embedding):
        # Reshape joint_embedding to match esm_embedding in sequence length
        joint_embedding = self.reshaper(joint_embedding)
        print('FORWARD RESHAPER DONE...')

        # Ensure esm_embedding has the same sequence length as joint_embedding
        pad_size = joint_embedding.size(1) - esm_embedding.size(1)
        if pad_size > 0:
            padding = torch.zeros((esm_embedding.size(0), pad_size, esm_embedding.size(2)), device=esm_embedding.device)
            esm_embedding = torch.cat([esm_embedding, padding], dim=1)

        # Concatenate embeddings and pass through the encoder to get latent variables
        concatenated_input = torch.cat([esm_embedding, joint_embedding], dim=2).flatten(start_dim=1)
        print('FORWARD CONCAT DONE...')
        mu, logvar = self.encode(concatenated_input)
        print('FORWARD ENCODE DONE...')
        z = self.reparameterize(mu, logvar)
        print('FORWARD REPARAMETERIZE DONE...')

        # Decode the latent variable to reconstruct the protein sequence
        reconstructed_protein = self.decode(z, joint_embedding)
        print('FORWARD DECODE DONE...')
        return reconstructed_protein, mu, logvar

    def loss_function(self, recon_x, x, mu, logvar):
        # Compute reconstruction loss as Mean Squared Error
        # Ensure recon_x and x have the same dimensions
        if recon_x.size(1) != x.size(1):
            min_seq_len = min(recon_x.size(1), x.size(1))
            recon_x = recon_x[:, :min_seq_len]
            x = x[:, :min_seq_len]
        MSE = F.mse_loss(recon_x, x, reduction='sum')
        print('mse calculated...')

        # Compute Kullback-Leibler Divergence (KLD)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        print('kld calculated...')

        # The total loss is the sum of MSE and KLD, with a weighting factor for KLD
        print('loss can be outputted...')
        return MSE + self.beta * KLD, MSE, KLD


# Initialize the model
model = ConditionalVAE(
    esm_model,
    esm_embedding_dim=1280,
    hidden_dim=512,
    latent_dim=256,
    joint_embedding_dim=20,
    dropout_rate=0.5,
    beta=1,
    num_attention_heads=4
)


reshaper done...
encoder mlp done..
attention done...
logvar and mu done (encoder)...
decoder mlp done...
decode linear done...
decoder sigmoid done...
dropout applied...
beta appllied...


### back to training code

In [35]:

def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0.01)

learning_rate = 1e-3
dropout_rate = 0.5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# model = ConditionalVAE(esm_model, esm_embedding_dim, 512, 256, joint_embedding_dim, dropout_rate).to(device)
model.apply(init_weights)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5, verbose=True)
model = model.to(device)


In [36]:
# Training Loop
num_epochs = 1
accumulation_steps = 10  # Define how many steps to wait before updating the weights

train_losses, val_losses, test_losses = [], [], []
train_mses, val_mses, test_mses = [], [], []
train_klds, val_klds, test_klds = [], [], []

In [37]:
train_loader_norman.dataset[1]['protein_input'].shape

torch.Size([2078, 1280])

### one batch

In [38]:
import torch

# Define the function for saving checkpoints
def save_checkpoint(state, filename="checkpoint.pth.tar"):
    torch.save(state, filename)

# Initialize best_val_loss to a very high value
best_val_loss = float('inf')

# Process one batch from the training set
model.train()
batch = next(iter(train_loader_norman))
print(len(batch))
joint_embedding = batch['cell_input'].to(device)
esm_embedding = batch['protein_input'].to(device)

optimizer.zero_grad()
recon_emb, mu, logvar = model(esm_embedding, joint_embedding)
loss, MSE, KLD = model.loss_function(recon_emb, esm_embedding, mu, logvar)
loss.backward()
optimizer.step()

train_loss = loss.item()
train_mse = MSE.item()
train_kld = KLD.item()

# Process one batch from the validation set
model.eval()
batch = next(iter(val_loader_norman))

joint_embedding = batch['cell_input'].to(device)
esm_embedding = batch['protein_input'].to(device)

with torch.no_grad():
    recon_emb, mu, logvar = model(esm_embedding, joint_embedding)
    loss, MSE, KLD = model.loss_function(recon_emb, esm_embedding, mu, logvar)

val_loss = loss.item()
val_mse = MSE.item()
val_kld = KLD.item()

# Scheduler step (if using learning rate scheduler)
scheduler.step(val_loss)

# Save the model if the validation loss improved
if val_loss < best_val_loss:
    best_val_loss = val_loss
    save_checkpoint({
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'best_val_loss': best_val_loss,
    }, filename="checkpoint_single_batch.pth.tar")

# Process one batch from the test set
batch = next(iter(test_loader_norman))

joint_embedding = batch['cell_input'].to(device)
esm_embedding = batch['protein_input'].to(device)

with torch.no_grad():
    recon_emb, mu, logvar = model(esm_embedding, joint_embedding)
    loss, MSE, KLD = model.loss_function(recon_emb, esm_embedding, mu, logvar)

test_loss = loss.item()
test_mse = MSE.item()
test_kld = KLD.item()

# Print Summary
print(f'Train Loss: {train_loss:.4f} | Validation Loss: {val_loss:.4f} | Test Loss: {test_loss:.4f}')


2
FORWARD RESHAPER DONE...
FORWARD CONCAT DONE...
forward linear
forward linear
forward linear
forward MLP
torch.Size([4, 4, 128])
transpose done
torch.Size([4, 4, 128])
transpose done
torch.Size([4, 4, 128])
transpose done
forward MHA transpose
attn scores MHA forward
done dropout MHA
forward MHA done
FORWARD ENCODE DONE...
FORWARD REPARAMETERIZE DONE...
forward linear
forward linear
forward linear
forward MLP
FORWARD DECODE DONE...
mse calculated...
kld calculated...
loss can be outputted...


OutOfMemoryError: ignored

### iterate batches

In [None]:
import torch

# Define the function for saving checkpoints
def save_checkpoint(state, filename="checkpoint.pth.tar"):
    torch.save(state, filename)

# Initialize best_val_loss to a very high value
best_val_loss = float('inf')

# Training Loop
for epoch in range(num_epochs):
    # Set model to training mode
    model.train()
    # Initialize loss accumulators
    train_loss, train_mse, train_kld = 0, 0, 0

    # Training step
    for batch_idx, batch in enumerate(train_loader_norman):
        if batch_idx == len(train_loader_norman) - 1:  # Skip the last batch
            continue

        print(batch_idx)
        # Move data to the appropriate device (GPU or CPU)
        joint_embedding = batch['cell_input'].to(device)
        esm_embedding = batch['protein_input'].to(device)

        # Zero the gradients
        optimizer.zero_grad()
        # Forward pass
        recon_emb, mu, logvar = model(esm_embedding, joint_embedding)
        # Calculate loss
        loss, MSE, KLD = model.loss_function(recon_emb, esm_embedding, mu, logvar)
        # Backward pass
        loss.backward()

        # Gradient accumulation step
        if (batch_idx + 1) % accumulation_steps == 0 or batch_idx == len(train_loader_norman) - 2:
            optimizer.step()
            optimizer.zero_grad()

        # Accumulate losses
        train_loss += loss.item()
        train_mse += MSE.item()
        train_kld += KLD.item()

    # Calculate average losses for this epoch
    train_losses.append(train_loss / len(train_loader_norman.dataset))
    train_mses.append(train_mse / len(train_loader_norman.dataset))
    train_klds.append(train_kld / len(train_loader_norman.dataset))

    # Validation step
    model.eval()
    val_loss, val_mse, val_kld = 0, 0, 0
    with torch.no_grad():
        for batch_idx, batch in enumerate(val_loader_norman):
            if batch_idx == len(val_loader_norman) - 1:  # Skip the last batch
                continue

            joint_embedding = batch['cell_input'].to(device)
            esm_embedding = batch['protein_input'].to(device)

            recon_emb, mu, logvar = model(esm_embedding, joint_embedding)
            loss, MSE, KLD = model.loss_function(recon_emb, esm_embedding, mu, logvar)

            val_loss += loss.item()
            val_mse += MSE.item()
            val_kld += KLD.item()

    # Scheduler step (if using learning rate scheduler)
    scheduler.step(val_loss)

    # Save the model if the validation loss improved
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'best_val_loss': best_val_loss,
        }, filename=f"checkpoint_epoch_{epoch+1}.pth.tar")

    # Calculate average validation losses for this epoch
    val_losses.append(val_loss / len(val_loader_norman.dataset))
    val_mses.append(val_mse / len(val_loader_norman.dataset))
    val_klds.append(val_kld / len(val_loader_norman.dataset))

    # Print Epoch Summary
    print(f'Epoch {epoch+1}/{num_epochs} | '
          f'Train Loss: {train_losses[-1]:.4f} | '
          f'Validation Loss: {val_losses[-1]:.4f}')

# After training, evaluate on test set
model.eval()
test_loss, test_mse, test_kld = 0, 0, 0
with torch.no_grad():
    for batch_idx, batch in enumerate(test_loader_norman):
        if batch_idx == len(test_loader_norman) - 1:  # Skip the last batch
            continue

        joint_embedding = batch['cell_input'].to(device)
        esm_embedding = batch['protein_input'].to(device)

        recon_emb, mu, logvar = model(esm_embedding, joint_embedding)
        loss, MSE, KLD = model.loss_function(recon_emb, esm_embedding, mu, logvar)

        test_loss += loss.item()
        test_mse += MSE.item()
        test_kld += KLD.item()

# Calculate average test losses
test_losses.append(test_loss / len(test_loader_norman.dataset))
test_mses.append(test_mse / len(test_loader_norman.dataset))
test_klds.append(test_kld / len(test_loader_norman.dataset))

# Print Test Summary
print(f'Test Loss: {test_losses[-1]:.4f} | '
      f'Test MSE: {test_mses[-1]:.4f} | '
      f'Test KLD: {test_klds[-1]:.4f}')


### plot loss

In [None]:
import matplotlib.pyplot as plt

# Plotting
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.plot(train_losses, label='Train')
plt.plot(val_losses, label='Validation')
plt.title('Total Losses')
plt.legend()

plt.subplot(1, 3, 2)
plt.plot(train_mses, label='Train')
plt.plot(val_mses, label='Validation')
plt.title('MSE Losses')
plt.legend()

plt.subplot(1, 3, 3)
plt.plot(train_klds, label='Train')
plt.plot(val_klds, label='Validation')
plt.title('KLD Losses')
plt.legend()

plt.tight_layout()
plt.show()



#### latent space viz

In [None]:
mu_values = []
with torch.no_grad():
    for batch in train_loader_norman:
        esm_embedding = batch['protein_input'].to(device)
        joint_embedding = batch['cell_input'].to(device)

        # Reshaping joint_embedding
        joint_embedding = model.reshaper(joint_embedding)

        # Padding esm_embedding
        pad_size = joint_embedding.size(1) - esm_embedding.size(1)
        if pad_size > 0:
            padding = torch.zeros((esm_embedding.size(0), pad_size, esm_embedding.size(2)), device=esm_embedding.device)
            esm_embedding = torch.cat([esm_embedding, padding], dim=1)

        # Concatenate the embeddings
        concatenated_input = torch.cat([esm_embedding, joint_embedding], dim=2).view(esm_embedding.size(0), -1)

        mu, _ = model.encode(concatenated_input)
        mu_values.extend(mu.cpu().numpy())

mu_values = np.array(mu_values)


In [None]:
from sklearn.decomposition import PCA

# Reduce dimensionality to 2D using PCA
pca = PCA(n_components=2)
mu_2d = pca.fit_transform(mu_values)


In [None]:
import matplotlib.pyplot as plt

plt.scatter(mu_2d[:, 0], mu_2d[:, 1], alpha=0.5)
plt.xlabel('PCA 1')
plt.ylabel('PCA 2')
plt.title('Latent Space Visualization')
plt.show()


In [None]:
from sklearn.manifold import TSNE

tsne = TSNE(n_components=2, random_state=0)
mu_2d_tsne = tsne.fit_transform(mu_values)

plt.scatter(mu_2d_tsne[:, 0], mu_2d_tsne[:, 1], alpha=0.5)
plt.xlabel('t-SNE 1')
plt.ylabel('t-SNE 2')
plt.title('Latent Space Visualization with t-SNE')
plt.show()


### see how well reconstruction is happening

In [None]:
# Assuming your model instance is named 'model'
model.eval()  # Set the model to evaluation mode

embeddings_list = []

with torch.no_grad():  # Disable gradient computation
    for batch in test_loader_norman:
        esm_embedding = batch['protein_input'].to(device)
        joint_embedding = batch['cell_input'].to(device)

        reconstructed_embedding, _, _ = model(esm_embedding, joint_embedding)

        embeddings_list.append(reconstructed_embedding.cpu().numpy())

# Convert the list of batched embeddings to a single list or array
all_embeddings = np.concatenate(embeddings_list, axis=0)


In [None]:
def embeddings_to_sequence(embeddings, esm_model, alphabet):
    aa_toks = alphabet.all_toks
    aa_idxs = [alphabet.get_idx(aa) for aa in aa_toks]
    embeddings = embeddings.to(device)
    aa_logits = esm_model.lm_head(embeddings)[:, aa_idxs]
    predictions = torch.argmax(aa_logits, dim=-1).tolist()
    generated_peptides = ''.join([aa_toks[pred] for pred in predictions])  # concatenate the amino acids
    return generated_peptides

In [None]:
all_decoded_sequences = []

# Ensure your model is on the correct device
esm_model = esm_model.to(device)

# Convert the numpy array to a torch tensor
all_embeddings_tensor = torch.tensor(all_embeddings).to(device)

# Loop through each sequence in the batch and decode:
for idx in range(all_embeddings_tensor.size(0)):
    single_embedding = all_embeddings_tensor[idx]  # The embedding is already on the device
    decoded_seq = embeddings_to_sequence(single_embedding, esm_model, alphabet)
    all_decoded_sequences.append(decoded_seq)


In [None]:
len(all_decoded_sequences)

In [None]:
all_decoded_sequences[1]

In [None]:
import csv

# File path
output_path = "1017_cvae_gen_v1.csv"

# Save to CSV
with open(output_path, mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(["AA_Sequence"])  # Header
    for seq in all_decoded_sequences:
        writer.writerow([seq])

print(f"Sequences saved to {output_path}")


### **conditional generation with mlp**

In [None]:
import pickle

with open('latent_dict_bone_marrow.pkl', 'rb') as file:
    latent_dict = pickle.load(file)


In [None]:
latent_dict[('lymphocyte',
  'dendritic cell')].shape

In [None]:
### straight generation

In [None]:
def pad_tensor(tensor, max_length, dim=0):
    if isinstance(tensor, np.ndarray):
        tensor = torch.tensor(tensor)

    pad_size = max_length - tensor.shape[dim]
    padding = (0, 0) * (tensor.dim() - dim - 1) + (0, pad_size)
    return torch.nn.functional.pad(tensor, padding)


In [None]:
def generate_sequences(model, joint_embedding_tensor, esm_dataset, latent_dim=256, random_sample_size=100):
    all_sequences = []

    # Ensure the tensor has shape [batch_size, 20]
    joint_embedding_tensor = pad_tensor(joint_embedding_tensor, 20, dim=1)
    print(joint_embedding_tensor.shape)

    for single_joint_embedding in joint_embedding_tensor:
        # Randomly sample ESM embeddings
        random_indices = torch.randint(0, len(esm_dataset), (random_sample_size,))
        random_esm_embeddings = [esm_dataset[i]['protein_input'] for i in random_indices]
        random_esm_embeddings = torch.stack(random_esm_embeddings).to(device)

        # Repeat the single joint embedding to match the random_sample_size
        repeated_joint_embedding = single_joint_embedding.repeat(random_sample_size, 1)

        # Transform the repeated joint_embedding
        joint_embedding_transformed = model.mlp_transform(repeated_joint_embedding)

        # Concatenate the joint embedding with each of the ESM embeddings
        concatenated_tensor = torch.cat([random_esm_embeddings, joint_embedding_transformed], dim=1)

        # Get the latent space mu
        mu, _ = model.encode(concatenated_tensor)

        # Decode using the learned latent space
        with torch.no_grad():
            _, sequences = model.decode(mu, joint_embedding_transformed)

        # Collect the sequences for this joint embedding
        all_sequences.append(sequences)

    return all_sequences


In [None]:
esm_dataset = train_loader_norman.dataset

In [None]:
generated_dict = {}
for key, joint_embedding_tensor in latent_dict.items():
    sequences_list = generate_sequences(model, joint_embedding_tensor, esm_dataset, latent_dim=256)
    generated_dict[key] = sequences_list


In [None]:
## write pipeline to where priors are given from GRNboost (sampling -- top 10 TFs are given) -- encode those with the joint embedding OR write seq2seq model

In [None]:
import pickle

In [None]:
# Serialize and save the dictionary to a .pkl file
with open('generateBM1006.pkl', 'wb') as file:
    pickle.dump(generated_dict, file)


In [None]:
# generated_dict checks

In [None]:
print(generated_dict.keys(),len(generated_dict[('lymphocyte', 'common myeloid progenitor')]),type(generated_dict[('lymphocyte', 'common myeloid progenitor')][1]))

In [None]:
len(generated_dict[('lymphocyte', 'megakaryocyte')])