

## Mount Directory

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import os

In [3]:
os.chdir('/content/drive/MyDrive/tf design')

## Dataloaders

In [4]:
import torch
import random
from collections import defaultdict
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

In [5]:
!pip install pytorch-lightning




In [6]:
import torch
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl

def pad_tensor(tensor, max_length, dim=0):
    tensor = torch.tensor(tensor) if isinstance(tensor, np.ndarray) else tensor
    pad_size = max_length - tensor.shape[dim]
    padding = (0, 0) * (tensor.dim() - dim - 1) + (0, pad_size)
    return torch.nn.functional.pad(tensor, padding)

class CellProteinDataset(torch.utils.data.Dataset):
    def __init__(self, joint_embeddings, esm_embeddings):
        assert len(joint_embeddings) == len(esm_embeddings), "Dictionaries must have the same length"

        self.joint_keys = list(joint_embeddings.keys())
        self.esm_keys = list(esm_embeddings.keys())

        self.joint_embeddings = joint_embeddings
        self.esm_embeddings = esm_embeddings

        # Compute the maximum embedding size
        self.max_joint_dim = max(tensor.shape[0] for tensor in joint_embeddings.values())
        self.max_esm_dim = max(tensor.shape[0] for tensor in esm_embeddings.values())

    def __len__(self):
        return len(self.joint_embeddings)

    def __getitem__(self, index):
        joint_key = self.joint_keys[index]
        esm_key = self.esm_keys[index]

        joint_embedding = pad_tensor(self.joint_embeddings[joint_key], self.max_joint_dim)
        esm_embedding = pad_tensor(self.esm_embeddings[esm_key], self.max_esm_dim)

        return {
            "cell_input": joint_embedding,
            "protein_input": esm_embedding
        }

class CellProteinCollator:
    def __call__(self, raw_batch):
        batch = {}

        cell_input_list = [v['cell_input'] for v in raw_batch]
        protein_input_list = [v['protein_input'] for v in raw_batch]

        batch['cell_input'] = torch.stack(cell_input_list)
        batch['protein_input'] = torch.stack(protein_input_list)

        return batch

class CellProteinDataModule(pl.LightningDataModule):
    def __init__(self, joint_embeddings, esm_embeddings, batch_size):
        super().__init__()

        dataset = CellProteinDataset(joint_embeddings, esm_embeddings)

        # Splitting data into train, test, and validation
        train_size = int(0.7 * len(dataset))
        val_size = int(0.15 * len(dataset))
        test_size = len(dataset) - train_size - val_size

        self.train_dataset, self.val_dataset, self.test_dataset = random_split(dataset, [train_size, val_size, test_size])

        self.batch_size = batch_size
        self.collator = CellProteinCollator()

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, collate_fn=self.collator, shuffle=True,drop_last=True)

    def val_dataloader(self):
        full_batch = DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=self.collator, shuffle=False,drop_last=True)
        #binary_batch = DataLoader(self.val_dataset, batch_size=2, collate_fn=self.collator, shuffle=False)
        return [full_batch]#, binary_batch]

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, collate_fn=self.collator, shuffle=False,drop_last=True)



In [7]:
import pytorch_lightning as pl


In [8]:
class CellProteinDataModule(pl.LightningDataModule):
    def __init__(self, joint_embeddings, esm_embeddings, batch_size, train_keys_list=None, val_test_keys_list=None):
        super().__init__()

        # Subset dictionaries based on provided keys
        train_joint_embeddings = {key: joint_embeddings[key] for key in train_keys_list} if train_keys_list else joint_embeddings
        train_esm_embeddings = {key: esm_embeddings[key] for key in train_keys_list} if train_keys_list else esm_embeddings

        val_test_joint_embeddings = {key: joint_embeddings[key] for key in val_test_keys_list} if val_test_keys_list else joint_embeddings
        val_test_esm_embeddings = {key: esm_embeddings[key] for key in val_test_keys_list} if val_test_keys_list else esm_embeddings

        # Create datasets
        self.train_dataset = CellProteinDataset(train_joint_embeddings, train_esm_embeddings)
        val_test_dataset = CellProteinDataset(val_test_joint_embeddings, val_test_esm_embeddings)

        # Splitting val_test_dataset into validation and test sets
        val_size = int(0.15 * len(val_test_dataset))
        test_size = len(val_test_dataset) - val_size

        self.val_dataset, self.test_dataset = random_split(val_test_dataset, [val_size, test_size])

        self.batch_size = batch_size
        self.collator = CellProteinCollator()

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, collate_fn=self.collator, shuffle=True,drop_last=True)

    def val_dataloader(self):
        full_batch = DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=self.collator, shuffle=False,drop_last=True)
        binary_batch = DataLoader(self.val_dataset, batch_size=2, collate_fn=self.collator, shuffle=False)
        return [full_batch,binary_batch]

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, collate_fn=self.collator, shuffle=False,drop_last=True)




#### load norman19 dataloaders

In [9]:
import pickle
from torch.utils.data import DataLoader


In [10]:
with open('train_subset.pkl', 'rb') as f:
    train_dataset_norman = pickle.load(f)

with open('test_subset.pkl', 'rb') as f:
    val_dataset_norman = pickle.load(f)

with open('val_subset.pkl', 'rb') as f:
    test_dataset_norman = pickle.load(f)


In [11]:
from torch.utils.data import Subset
train_indices = list(range(2500))
test_indices = list(range(1000))
val_indices = list(range(500))

# Create subsets
train_subset = Subset(train_dataset_norman, train_indices)
test_subset = Subset(test_dataset_norman, test_indices)
val_subset = Subset(val_dataset_norman, val_indices)

In [12]:
len(train_subset)

2500

In [13]:
type(train_dataset_norman)

torch.utils.data.dataset.Subset

In [14]:
with open('dataloader_params.pkl', 'rb') as f:
    dataloader_params = pickle.load(f)


In [15]:
dataloader_params

{'train': {'batch_size': 4, 'shuffle': 'True'},
 'val': {'batch_size': 4, 'shuffle': 'False'},
 'test': {'batch_size': 4, 'shuffle': 'False'}}

In [16]:
train_loader_norman = DataLoader(train_subset, **dataloader_params['train'])
val_loader_norman = DataLoader(val_subset, **dataloader_params['val'])
test_loader_norman = DataLoader(test_subset, **dataloader_params['test'])


In [17]:
from torch.utils.data import Subset

In [18]:
# Converting string 'True'/'False' to boolean
dataloader_params['train']['shuffle'] = dataloader_params['train']['shuffle'] == 'True'
dataloader_params['val']['shuffle'] = dataloader_params['val']['shuffle'] == 'True'
dataloader_params['test']['shuffle'] = dataloader_params['test']['shuffle'] == 'True'


In [19]:
import numpy as np

In [20]:
esm_dataset = train_loader_norman.dataset

In [21]:
esm_dataset[1]['protein_input'].shape

torch.Size([2078, 1280])

In [22]:
esm_dataset[1]['cell_input'].shape

torch.Size([8350, 20])

In [23]:
esm_dataset[1]

{'cell_input': tensor([[-1.5663,  1.1364,  0.6010,  ...,  0.1193, -0.1001,  0.0477],
         [-1.3502,  0.4448,  1.6567,  ...,  0.1193, -0.1001,  0.0477],
         [ 0.0071,  0.2452, -0.1685,  ...,  0.1193, -0.1001,  0.0477],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]),
 'protein_input': tensor([[-0.0160, -0.1942,  0.1267,  ...,  0.0727, -0.1567,  0.1401],
         [ 0.0547,  0.1077,  0.1109,  ...,  0.0577,  0.0728,  0.1995],
         [ 0.0195,  0.0153,  0.2484,  ..., -0.0064,  0.1929, -0.0309],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]])}

## Models

### cVAE-no-attn

In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConditionalVAE(nn.Module):
    def __init__(self, hidden_dim=512, latent_dim=256, joint_embedding_dim=20, esm_embedding_dim=1280, dropout_rate=0.5, beta=1):
        super(ConditionalVAE, self).__init__()

        # Encoder
        self.fc_encode1 = nn.Linear(joint_embedding_dim*8350, hidden_dim)
        self.fc_encode1_bn = nn.BatchNorm1d(hidden_dim)
        self.fc_encode_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_encode_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder
        self.fc_decode1 = nn.Linear(latent_dim, hidden_dim)
        self.fc_decode1_bn = nn.BatchNorm1d(hidden_dim)
        self.fc_decode2 = nn.Linear(hidden_dim, esm_embedding_dim)

        # Final reshaping layer
        self.final_reshape = nn.Linear(esm_embedding_dim, 2078 * esm_embedding_dim)

        self.dropout = nn.Dropout(dropout_rate)
        self.leaky_relu = nn.LeakyReLU(0.1)
        self.beta = beta

    def encode(self, joint_embedding):
        h1 = self.leaky_relu(self.fc_encode1(joint_embedding))
        h1 = self.fc_encode1_bn(h1)
        h1 = self.dropout(h1)
        return self.fc_encode_mu(h1), self.fc_encode_logvar(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h2 = self.leaky_relu(self.fc_decode1(z))
        h2 = self.fc_decode1_bn(h2)
        h2 = self.dropout(h2)
        return self.fc_decode2(h2)

    def forward(self, joint_embedding):
        # Flatten joint_embedding for processing
        joint_embedding = joint_embedding.view(joint_embedding.size(0), -1)
        mu, logvar = self.encode(joint_embedding)
        z = self.reparameterize(mu, logvar)
        generated_esm_embedding = self.decode(z)
        print('after decode generated_esm_embedding shape', generated_esm_embedding.shape)
        generated_esm_embedding = self.final_reshape(generated_esm_embedding)
        generated_esm_embedding = generated_esm_embedding.view(-1, 2078, 1280)
        print('generated_esm_embedding shape', generated_esm_embedding.shape)
        return generated_esm_embedding, mu, logvar

    def loss_function(self, generated_esm_embedding, real_esm_embedding, mu, logvar):
        MSE = F.mse_loss(generated_esm_embedding, real_esm_embedding, reduction='sum')
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return MSE + self.beta * KLD, MSE, KLD


### cVAE-rnn-no-attn

In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConditionalVAERNN(nn.Module):
    def __init__(self, hidden_dim=512, latent_dim=256, joint_embedding_dim=20, esm_embedding_dim=1280, rnn_hidden_dim=256, num_layers=2, dropout_rate=0.5, beta=1):
        super(ConditionalVAERNN, self).__init__()

        # RNN Encoder
        self.rnn_encoder = nn.RNN(joint_embedding_dim, rnn_hidden_dim, num_layers=num_layers, batch_first=True, dropout=dropout_rate)
        self.fc_encode_mu = nn.Linear(rnn_hidden_dim, latent_dim)
        self.fc_encode_logvar = nn.Linear(rnn_hidden_dim, latent_dim)

        # Decoder
        self.fc_decode1 = nn.Linear(latent_dim, rnn_hidden_dim)
        self.rnn_decoder = nn.RNN(rnn_hidden_dim, esm_embedding_dim, num_layers=num_layers, batch_first=True, dropout=dropout_rate)

        # Final reshaping layer
        self.final_reshape = nn.Linear(esm_embedding_dim, 2078 * esm_embedding_dim)

        self.beta = beta

    def encode(self, joint_embedding):
        _, h_n = self.rnn_encoder(joint_embedding)  # Use the last hidden state
        h_n = h_n[-1]  # Take the last layer's hidden state, shape: [batch_size, rnn_hidden_dim]
        return self.fc_encode_mu(h_n), self.fc_encode_logvar(h_n)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        z = z.unsqueeze(1).repeat(1, 2078, 1)  # Repeat z for each time step
        output, _ = self.rnn_decoder(z)
        return output

    def forward(self, joint_embedding):
        joint_embedding = joint_embedding.view(joint_embedding.size(0), -1, 20)
        mu, logvar = self.encode(joint_embedding)
        z = self.reparameterize(mu, logvar)
        print('z (reparameterize) shape:',z.shape)
        generated_esm_embedding = self.decode(z)
        print('after decode generated_esm_embedding shape', generated_esm_embedding.shape)
        return generated_esm_embedding, mu, logvar

    def loss_function(self, generated_esm_embedding, real_esm_embedding, mu, logvar):
        MSE = F.mse_loss(generated_esm_embedding, real_esm_embedding, reduction='sum')
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return MSE + self.beta * KLD, MSE, KLD


### cVAE-lstm-no-attn

In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConditionalVAELSTM(nn.Module):
    def __init__(self, hidden_dim=512, latent_dim=256, joint_embedding_dim=20, esm_embedding_dim=1280, lstm_hidden_dim=256, num_layers=2, dropout_rate=0.5, beta=1):
        super(ConditionalVAELSTM, self).__init__()

        # LSTM Encoder
        self.lstm_encoder = nn.LSTM(joint_embedding_dim, lstm_hidden_dim, num_layers=num_layers, batch_first=True, dropout=dropout_rate)
        self.fc_encode_mu = nn.Linear(lstm_hidden_dim, latent_dim)
        self.fc_encode_logvar = nn.Linear(lstm_hidden_dim, latent_dim)

        # Decoder
        self.fc_decode1 = nn.Linear(latent_dim, lstm_hidden_dim)
        self.lstm_decoder = nn.LSTM(lstm_hidden_dim, esm_embedding_dim, num_layers=num_layers, batch_first=True, dropout=dropout_rate)

        # Final reshaping layer
        self.final_reshape = nn.Linear(esm_embedding_dim, 2078 * esm_embedding_dim)

        self.beta = beta

    def encode(self, joint_embedding):
        _, (h_n, _) = self.lstm_encoder(joint_embedding)  # Use the last hidden state
        h_n = h_n[-1]  # Take the hidden state from the last LSTM layer
        return self.fc_encode_mu(h_n), self.fc_encode_logvar(h_n)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        z = z.unsqueeze(1).repeat(1, 2078, 1)  # Repeat z for each time step
        output, _ = self.lstm_decoder(z)
        return output

    def forward(self, joint_embedding):
        joint_embedding = joint_embedding.view(joint_embedding.size(0), -1, 20)
        mu, logvar = self.encode(joint_embedding)
        z = self.reparameterize(mu, logvar)
        generated_esm_embedding = self.decode(z)
        print('generated_esm_embedding shape', generated_esm_embedding.shape)
        return generated_esm_embedding, mu, logvar

    def loss_function(self, generated_esm_embedding, real_esm_embedding, mu, logvar):
        MSE = F.mse_loss(generated_esm_embedding, real_esm_embedding, reduction='sum')
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return MSE + self.beta * KLD, MSE, KLD


### cVAE-rnn-attn

In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class RNNAttentionEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, dropout_rate, attention_heads):
        super(RNNAttentionEncoder, self).__init__()
        self.rnn = nn.RNN(input_dim, hidden_dim, num_layers=num_layers, batch_first=True, dropout=dropout_rate)
        self.attention = nn.MultiheadAttention(hidden_dim, attention_heads, batch_first=True)
        self.hidden_dim = hidden_dim

    def forward(self, x):
        output, _ = self.rnn(x)
        attn_output, _ = self.attention(output, output, output)
        return attn_output.mean(dim=1)

class ConditionalVAERNNAttention(nn.Module):
    def __init__(self, hidden_dim=512, latent_dim=256, joint_embedding_dim=20, esm_embedding_dim=1280, rnn_hidden_dim=256, num_layers=2, dropout_rate=0.5, beta=1, attention_heads=4):
        super(ConditionalVAERNNAttention, self).__init__()

        # RNN Attention Encoder
        self.rnn_attention_encoder = RNNAttentionEncoder(joint_embedding_dim, rnn_hidden_dim, num_layers, dropout_rate, attention_heads)
        self.fc_encode_mu = nn.Linear(rnn_hidden_dim, latent_dim)
        self.fc_encode_logvar = nn.Linear(rnn_hidden_dim, latent_dim)

        # Decoder
        self.fc_decode1 = nn.Linear(latent_dim, hidden_dim)
        self.fc_decode2 = nn.Linear(hidden_dim, esm_embedding_dim)

        # Final reshaping layer
        self.final_reshape = nn.Linear(esm_embedding_dim, 2078 * esm_embedding_dim)

        self.beta = beta

    def encode(self, joint_embedding):
        attn_output = self.rnn_attention_encoder(joint_embedding)
        return self.fc_encode_mu(attn_output), self.fc_encode_logvar(attn_output)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h2 = F.relu(self.fc_decode1(z))
        h2 = self.fc_decode2(h2)
        return h2

    def forward(self, joint_embedding):
        joint_embedding = joint_embedding.view(joint_embedding.size(0), -1, 20)
        mu, logvar = self.encode(joint_embedding)
        z = self.reparameterize(mu, logvar)
        generated_esm_embedding = self.decode(z)
        print('after decode generated_esm_embedding shape', generated_esm_embedding.shape)
        generated_esm_embedding = self.final_reshape(generated_esm_embedding)
        generated_esm_embedding = generated_esm_embedding.view(-1, 2078, 1280)
        print('generated_esm_embedding shape', generated_esm_embedding.shape)
        return generated_esm_embedding, mu, logvar

    def loss_function(self, generated_esm_embedding, real_esm_embedding, mu, logvar):
        MSE = F.mse_loss(generated_esm_embedding, real_esm_embedding, reduction='sum')
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return MSE + self.beta * KLD, MSE, KLD


### cVAE-lstm-attn

In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LSTMAttentionEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, dropout_rate, attention_heads):
        super(LSTMAttentionEncoder, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers, batch_first=True, dropout=dropout_rate)
        self.attention = nn.MultiheadAttention(hidden_dim, attention_heads, batch_first=True)

    def forward(self, x):
        output, _ = self.lstm(x)
        attn_output, _ = self.attention(output, output, output)
        return attn_output.mean(dim=1)

class ConditionalVAELSTMAttention(nn.Module):
    def __init__(self, hidden_dim=512, latent_dim=256, joint_embedding_dim=20, esm_embedding_dim=1280, lstm_hidden_dim=256, num_layers=2, dropout_rate=0.5, beta=1, attention_heads=4):
        super(ConditionalVAELSTMAttention, self).__init__()

        # LSTM Attention Encoder
        self.lstm_attention_encoder = LSTMAttentionEncoder(joint_embedding_dim, lstm_hidden_dim, num_layers, dropout_rate, attention_heads)
        self.fc_encode_mu = nn.Linear(lstm_hidden_dim, latent_dim)
        self.fc_encode_logvar = nn.Linear(lstm_hidden_dim, latent_dim)

        # Decoder
        self.fc_decode1 = nn.Linear(latent_dim, hidden_dim)
        self.fc_decode2 = nn.Linear(hidden_dim, esm_embedding_dim)

        # Final reshaping layer
        self.final_reshape = nn.Linear(esm_embedding_dim, 2078 * esm_embedding_dim)

        self.beta = beta

    def encode(self, joint_embedding):
        attn_output = self.lstm_attention_encoder(joint_embedding)
        return self.fc_encode_mu(attn_output), self.fc_encode_logvar(attn_output)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h2 = F.relu(self.fc_decode1(z))
        h2 = self.fc_decode2(h2)
        return h2

    def forward(self, joint_embedding):
        joint_embedding = joint_embedding.view(joint_embedding.size(0), -1, 20)
        mu, logvar = self.encode(joint_embedding)
        z = self.reparameterize(mu, logvar)
        generated_esm_embedding = self.decode(z)
        print('after decode generated_esm_embedding shape', generated_esm_embedding.shape)
        generated_esm_embedding = self.final_reshape(generated_esm_embedding)
        generated_esm_embedding = generated_esm_embedding.view(-1, 2078, 1280)
        print('generated_esm_embedding shape', generated_esm_embedding.shape)
        return generated_esm_embedding, mu, logvar

    def loss_function(self, generated_esm_embedding, real_esm_embedding, mu, logvar):
        MSE = F.mse_loss(generated_esm_embedding, real_esm_embedding, reduction='sum')
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return MSE + self.beta * KLD, MSE, KLD


## Initialize Models

In [33]:
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import gc
from torch.utils.data import DataLoader

# Define all model classes (ConditionalVAE, ConditionalVAERNNAttention, etc.)

# Define the init_weights function
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0.01)

# Model configurations
model_configs = [
    {"class": ConditionalVAE, "params": {"hidden_dim": 512, "latent_dim": 256, "joint_embedding_dim": 20, "esm_embedding_dim": 1280, "dropout_rate": 0.5, "beta": 1}},
    {"class": ConditionalVAERNN, "params": {"hidden_dim": 512, "latent_dim": 256, "joint_embedding_dim": 20, "esm_embedding_dim": 1280, "rnn_hidden_dim": 256, "num_layers": 2, "dropout_rate": 0.5, "beta": 1}},
    {"class": ConditionalVAELSTM, "params": {"hidden_dim": 512, "latent_dim": 256, "joint_embedding_dim": 20, "esm_embedding_dim": 1280, "lstm_hidden_dim": 256, "num_layers": 2, "dropout_rate": 0.5, "beta": 1}},
    {"class": ConditionalVAERNNAttention, "params": {"hidden_dim": 512, "latent_dim": 256, "joint_embedding_dim": 20, "esm_embedding_dim": 1280, "rnn_hidden_dim": 256, "num_layers": 2, "dropout_rate": 0.5, "beta": 1, "attention_heads": 4}},
    {"class": ConditionalVAELSTMAttention, "params": {"hidden_dim": 512, "latent_dim": 256, "joint_embedding_dim": 20, "esm_embedding_dim": 1280, "lstm_hidden_dim": 256, "num_layers": 2, "dropout_rate": 0.5, "beta": 1, "attention_heads": 4}}]


# Define the function for saving checkpoints
def save_checkpoint(state, filename="checkpoint.pth.tar"):
    torch.save(state, filename)

# Define your data loaders
# train_loader_norman, val_loader_norman, test_loader_norman should be defined

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
learning_rate = 1e-3

# Initialize and train models
losses = []

for config in model_configs:
    # Initialize model, optimizer, and scheduler
    print(config['class'])
    model = config["class"](**config["params"]).to(device)
    model.apply(init_weights)
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5, verbose=True)

    # Initialize best validation loss for checkpointing
    best_val_loss = float('inf')

    # Train on a single batch
    model.train()
    train_batch = next(iter(train_loader_norman))
    joint_embedding, esm_embedding = train_batch['cell_input'].to(device), train_batch['protein_input'].to(device)
    optimizer.zero_grad()
    recon_emb, mu, logvar = model(joint_embedding)
    loss, MSE, KLD = model.loss_function(recon_emb, esm_embedding, mu, logvar)
    print('loss:',loss)
    print('MSE:',MSE)
    print('KLD:',KLD)
    loss.backward()
    optimizer.step()

    # Validation on a single batch
    model.eval()
    val_batch = next(iter(val_loader_norman))
    joint_embedding, esm_embedding = val_batch['cell_input'].to(device), val_batch['protein_input'].to(device)
    with torch.no_grad():
        recon_emb, mu, logvar = model(joint_embedding)
        val_loss, _, _ = model.loss_function(recon_emb, esm_embedding, mu, logvar)
        print('val loss:',val_loss)
    scheduler.step(val_loss)

    # Checkpointing
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_checkpoint({
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'best_val_loss': best_val_loss,
        }, filename=f"checkpoint_{config['class'].__name__}.pth.tar")

    # Record losses
    losses.append({
        'model': config['class'].__name__,
        'train_loss': loss.item(),
        'val_loss': val_loss
    })

    # Clear memory
    del model, optimizer, scheduler
    torch.cuda.empty_cache()
    gc.collect()

# Print summary
for loss_record in losses:
    print(f"{loss_record['model']}: Train Loss: {loss_record['train_loss']:.4f}, Val Loss: {loss_record['val_loss']:.4f}")


<class '__main__.ConditionalVAERNN'>


OutOfMemoryError: ignored

In [32]:
# Clear memory
del model, optimizer, scheduler
torch.cuda.empty_cache()
gc.collect()


235

In [None]:
config

In [None]:
model