In [None]:
!pip install Levenshtein
!pip install einops
!pip install einops_exts
!pip install torch
!pip install transformers
!pip install tqdm
!pip install sentencepiece
!pip install black
!pip install fair-esm
!pip install wandb

Collecting Levenshtein
  Downloading Levenshtein-0.25.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (177 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m177.4/177.4 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting rapidfuzz<4.0.0,>=3.8.0 (from Levenshtein)
  Downloading rapidfuzz-3.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m50.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz, Levenshtein
Successfully installed Levenshtein-0.25.1 rapidfuzz-3.9.4
Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.0
Collecting einops_exts
  Downloading einops_exts-0.0.4-py3-none-any.whl (3.9 kB)
Installi

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import re
import esm
from einops import rearrange, repeat
import math
import numpy as np
from torch import einsum
import wandb
wandb.login()
import os
# Work from the notebook's data directory so the relative CSV paths used in main() resolve.
os.chdir('/content/drive/MyDrive/Programmable Biology Group/Srikar/Code/proteins/flamingo-diffusion/data_dump/old_dat/')

# ESM Model Setup
# Use the GPU when available; the model and batches are moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Pretrained ESM-2 (t33, 650M) plus its alphabet; the batch converter tokenizes sequences.
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
esm_model = esm_model.to(device)
# Freeze ESM: eval mode and no gradients for any of its parameters.
# NOTE(review): LatentDiffusion registers esm_model as a submodule, so a later
# model.train() call switches it back to train mode (requires_grad stays False).
esm_model.eval()
for param in esm_model.parameters():
    param.requires_grad = False

# Data Preprocessing
def preprocess_snp_data(file_path):
    """Load an interaction CSV and add cleaned energy scores plus length columns.

    The ``energy_scores`` column holds stringified arrays whose elements are
    separated by whitespace and/or commas (presumably printed numpy arrays --
    confirm against the data export). They are normalised to a compact
    ``"[a,b,c]"`` form, and ``energy_scores_lengths`` holds the number of
    numeric values found. This uses the same number pattern that
    ``ProteinInteractionDataset`` uses to parse the scores.

    Args:
        file_path: Anything ``pandas.read_csv`` accepts (path or file-like).

    Returns:
        The loaded DataFrame with the cleaned ``energy_scores`` column and the
        added ``*_lengths`` / ``*_seq_length`` columns.
    """
    number_pattern = re.compile(r'-?\d+\.?\d*(?:e[-+]?\d+)?')
    snp_df = pd.read_csv(file_path)

    def transform_energy_scores(energy_scores):
        transformed_scores = []
        for score in energy_scores:
            # Collapse every run of whitespace and/or commas into one comma, so
            # space- and comma-separated arrays normalise identically.
            score = re.sub(r'[\s,]+', ',', score)
            # Remove separators left directly inside the brackets; a stray
            # ",]" previously made the comma-based length count one too many.
            score = re.sub(r'\[,', '[', score)
            score = re.sub(r',\]', ']', score)
            # Strip separators at either end of the whole string.
            score = re.sub(r'^[\s,]+', '', score)
            score = re.sub(r'[\s,]+$', '', score)
            transformed_scores.append(score)
        return transformed_scores

    snp_df['energy_scores'] = transform_energy_scores(snp_df['energy_scores'])
    # Count the numeric values themselves instead of inferring from commas.
    snp_df['energy_scores_lengths'] = snp_df['energy_scores'].apply(
        lambda x: len(number_pattern.findall(x))
    )

    snp_df['peptide_source_RCSB_lengths'] = snp_df['peptide_source_RCSB'].apply(len)
    snp_df['protein_RCSB_lengths'] = snp_df['protein_RCSB'].apply(len)
    snp_df['protein_derived_seq_length'] = snp_df['protein_derived_sequence'].apply(len)
    snp_df['peptide_derived_seq_length'] = snp_df['peptide_derived_sequence'].apply(len)

    return snp_df

def filter_datasets(dataset):
    """Drop rows whose peptide comes from the same RCSB entry as the protein."""
    keep = dataset['protein_RCSB'].ne(dataset['peptide_source_RCSB'])
    return dataset[keep]

# Dataset Class
class ProteinInteractionDataset(Dataset):
    """Map-style dataset yielding ``(energy_labels, protein_seq, peptide_seq)``.

    ``energy_labels`` is a float32 tensor with one 0/1 entry per number parsed
    from the row's ``energy_scores`` string (1 where the score is <= -1).
    """

    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __len__(self):
        return self.dataframe.shape[0]

    def __getitem__(self, idx):
        record = self.dataframe.iloc[idx]
        numeric_tokens = re.findall(r'-?\d+\.?\d*(?:e[-+]?\d+)?', record['energy_scores'])
        labels = self.one_hot_encode_energy_scores(map(float, numeric_tokens))
        return (
            torch.tensor(labels, dtype=torch.float32),
            record['protein_derived_sequence'],
            record['peptide_derived_sequence'],
        )

    @staticmethod
    def one_hot_encode_energy_scores(scores):
        # Binary threshold (not a true one-hot): 1 for scores <= -1, else 0.
        return [int(value <= -1) for value in scores]

# Model Components
class CNNBlock(nn.Module):
    """Conv1d -> LayerNorm -> GELU applied to a channels-last ``(b, n, d)`` tensor."""

    def __init__(self, dim, kernel_size=3, padding=1):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding)
        self.norm = nn.LayerNorm(dim)
        self.act = nn.GELU()

    def forward(self, x):
        # Conv1d is channels-first, so convolve on (b, d, n) and swap back to (b, n, d).
        convolved = self.conv(x.transpose(1, 2)).transpose(1, 2)
        return self.act(self.norm(convolved))

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to the input.

    Args:
        d_model: Feature size (may be odd).
        max_len: Longest supported sequence.
        batch_first: If False (default), ``x`` is ``(seq_len, batch, d_model)``
            and the table is indexed by dim 0. If True, ``x`` is
            ``(batch, seq_len, d_model)`` and the table is indexed by dim 1.
    """

    def __init__(self, d_model, max_len=5000, batch_first=False):
        super().__init__()
        self.batch_first = batch_first
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine slot than there are
        # frequencies, so trim div_term instead of failing on a shape mismatch.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        if self.batch_first:
            # pe[:n] is (n, 1, d); make it (1, n, d) so it broadcasts over the batch.
            return x + self.pe[:x.size(1)].transpose(0, 1)
        return x + self.pe[:x.size(0)]

class FourierFeatureEmbedding(nn.Module):
    """Learned-frequency sin/cos features of the sequence position.

    Output shape is ``(1, seq_len, min(2 * num_frequencies, d_model))``, with the
    sine block first and the cosine block second. The leading 1 broadcasts over
    the batch.
    """

    def __init__(self, num_frequencies, d_model):
        super().__init__()
        self.num_frequencies = num_frequencies
        self.d_model = d_model
        self.frequencies = nn.Parameter(torch.randn(num_frequencies) * 2 * math.pi)

    def forward(self, x):
        seq_len = x.size(1)
        pos = torch.arange(seq_len, device=x.device).float()
        # angles[i, k] = position_i * frequency_k
        angles = pos[:, None] * self.frequencies[None, :]
        feats = torch.cat([angles.sin(), angles.cos()], dim=-1)
        return feats[:, :self.d_model].unsqueeze(0)


class RotaryPositionalEmbedding(nn.Module):
    """Produce rotary-embedding angles ``cat(theta, theta)`` for each position.

    ``forward`` returns a ``(1, seq_len, 2 * ceil(d_model / 2))`` tensor of raw
    angles (position * inverse frequency). No sin/cos is applied here.
    ``max_seq_len`` is accepted but unused.
    """

    def __init__(self, d_model, max_seq_len=5000):
        super().__init__()
        self.d_model = d_model
        exponents = torch.arange(0, d_model, 2).float() / d_model
        self.register_buffer('inv_freq', 1. / (10000 ** exponents))

    def forward(self, x):
        positions = torch.arange(x.shape[1], device=x.device).type_as(self.inv_freq)
        angles = positions[:, None] * self.inv_freq[None, :]
        return torch.cat((angles, angles), dim=-1).unsqueeze(0)

def rotate_half(x):
    """Map ``[x1, x2] -> [-x2, x1]`` across the two halves of the last dimension.

    Feature ``i`` is paired with feature ``i + d/2``: the "rotate-half" layout
    of rotary embeddings.
    """
    x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(x, sincos):
    """Rotate ``x`` by per-position angles (rotary position embedding).

    Args:
        x: Tensor whose last dimension ``d`` is even.
        sincos: Rotation *angles*, broadcastable to ``x`` and laid out as
            ``cat(theta, theta)`` on the last dimension, as returned by
            ``RotaryPositionalEmbedding.forward``. The name is historical: the
            tensor holds angles, not precomputed sine/cosine values.

    Returns:
        ``x * cos(angles) + rotate_half(x) * sin(angles)``, a norm-preserving
        rotation of each ``(x_i, x_{i+d/2})`` pair by ``theta_i``.
    """
    # The angles must go through cos/sin. Using them directly scales x by
    # position-dependent factors instead of rotating it. The duplicated
    # cat(theta, theta) layout already lines theta_i up with rotate_half's
    # i <-> i + d/2 pairing, so no re-interleaving is needed.
    cos = sincos.cos()
    sin = sincos.sin()
    return (x * cos) + (rotate_half(x) * sin)

class FeedForward(nn.Module):
    """Pre-norm MLP: LayerNorm -> Linear(dim, dim*mult) -> GELU -> Linear(dim*mult, dim)."""

    def __init__(self, dim, mult=4):
        super().__init__()
        inner_dim = dim * mult
        layers = [nn.LayerNorm(dim), nn.Linear(dim, inner_dim), nn.GELU(), nn.Linear(inner_dim, dim)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # No residual here; the callers in this file add it themselves.
        return self.net(x)
class PerceiverAttention(nn.Module):
    """Multi-head cross-attention from latents onto media tokens (Perceiver style).

    Queries come from the (normalised) latents. Keys and values come from the
    normalised media tokens concatenated with the latents along the sequence
    axis, so the latents also attend to themselves. Output has the latents'
    shape ``(b, n_latents, dim)``.
    """

    def __init__(self, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def _split_heads(self, tensor):
        # (b, n, heads * d) -> (b, heads, n, d)
        b, n, _ = tensor.shape
        return tensor.reshape(b, n, self.heads, -1).transpose(1, 2)

    def forward(self, x, latents):
        media = self.norm_media(x)
        lat = self.norm_latents(latents)

        context = torch.cat((media, lat), dim=-2)
        q = self._split_heads(self.to_q(lat)) * self.scale
        k, v = (self._split_heads(part) for part in self.to_kv(context).chunk(2, dim=-1))

        weights = (q @ k.transpose(-2, -1)).softmax(dim=-1)
        mixed = (weights @ v).transpose(1, 2)  # (b, n_latents, heads, d)
        return self.to_out(mixed.reshape(mixed.shape[0], mixed.shape[1], -1))

class PerceiverResampler(nn.Module):
    """Summarise a token sequence into ``num_latents`` learned latent vectors.

    Each of ``depth`` layers applies residual cross-attention (latents -> tokens)
    followed by a residual feed-forward. The result is LayerNorm-ed and has
    shape ``(b, num_latents, dim)``.
    """

    def __init__(self, dim, depth, dim_head=64, heads=8, num_latents=64, ff_mult=4):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        self.layers = nn.ModuleList(
            nn.ModuleList([
                PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                FeedForward(dim=dim, mult=ff_mult),
            ])
            for _ in range(depth)
        )
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        batch, _, _ = x.shape
        # Share the learned latents across the batch: (L, d) -> (b, L, d).
        latents = self.latents.unsqueeze(0).expand(batch, -1, -1)
        for cross_attn, feed_forward in self.layers:
            latents = latents + cross_attn(x, latents)
            latents = latents + feed_forward(latents)
        return self.norm(latents)

class GatedCrossAttentionBlock(nn.Module):
    """Tanh-gated cross-attention followed by a tanh-gated feed-forward, both residual.

    Both gates start at 0, so ``tanh(gate) == 0`` and the block initially
    returns its input unchanged.
    """

    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.attn = PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads)
        self.attn_gate = nn.Parameter(torch.zeros(1))
        self.ff = FeedForward(dim=dim, mult=ff_mult)
        self.ff_gate = nn.Parameter(torch.zeros(1))

    def forward(self, x, media):
        # x acts as the query side; media supplies the attended-to tokens.
        attended = self.attn(media, x)
        x = x + self.attn_gate.tanh() * attended
        x = x + self.ff_gate.tanh() * self.ff(x)
        return x

class AdaptiveLayerNorm(nn.Module):
    """LayerNorm followed by an extra learned per-feature scale (alpha) and shift (beta)."""

    def __init__(self, num_features):
        super().__init__()
        self.ln = nn.LayerNorm(num_features)
        param_shape = (1, 1, num_features)
        self.alpha = nn.Parameter(torch.ones(param_shape))
        self.beta = nn.Parameter(torch.zeros(param_shape))

    def forward(self, x):
        normalized = self.ln(x)
        return normalized * self.alpha + self.beta

class SqueezeExcitation(nn.Module):
    """Squeeze-and-excitation channel gating for ``(batch, channels, length)`` tensors."""

    def __init__(self, channel, reduction=16):
        super().__init__()
        bottleneck = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels, _ = x.size()
        # Squeeze: average over length -> one descriptor per channel.
        descriptor = self.avg_pool(x).view(batch, channels)
        # Excite: per-channel gates in (0, 1), broadcast back over the length axis.
        gates = self.fc(descriptor).unsqueeze(-1)
        return x * gates

class FeaturePyramidNetwork(nn.Module):
    """Residual pointwise-conv bottleneck ``C -> C/2 -> C/4 -> C`` on ``(b, n, C)`` input.

    NOTE(review): there is no nonlinearity between the three 1x1 convs, so the
    branch is a single affine map per position despite the "pyramid" name.
    """

    def __init__(self, in_channels):
        super().__init__()
        half, quarter = in_channels // 2, in_channels // 4
        self.conv1 = nn.Conv1d(in_channels, half, kernel_size=1)
        self.conv2 = nn.Conv1d(half, quarter, kernel_size=1)
        self.conv3 = nn.Conv1d(quarter, in_channels, kernel_size=1)

    def forward(self, x):
        hidden = x.transpose(1, 2)  # convs run channels-first
        for conv in (self.conv1, self.conv2, self.conv3):
            hidden = conv(hidden)
        return hidden.transpose(1, 2) + x

class FocalLoss(nn.Module):
    """Binary focal loss on logits, averaged over every element.

    ``loss = alpha * (1 - p_t) ** gamma * BCE`` with ``p_t = exp(-BCE)``. For
    0/1 targets, ``p_t`` is the predicted probability of the true class.
    """

    def __init__(self, alpha=0.25, gamma=5):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        # NOTE(review): alpha is a constant multiplier here, not the class-dependent
        # alpha_t of Lin et al. (2017), so it does not rebalance positives vs negatives.
        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        modulating = (1 - torch.exp(-bce)) ** self.gamma
        return (self.alpha * modulating * bce).mean()

class StochasticDepth(nn.Module):
    """Randomly drop whole slices during training and rescale survivors by 1/keep_prob.

    Args:
        drop_prob: Probability of zeroing a slice.
        mode: ``"row"`` draws one keep/drop decision per sample (dim 0). Any other
            value draws one decision per index of dim 1, shared across the batch.

    In eval mode, or when ``drop_prob == 0``, the input is returned unchanged.
    """

    def __init__(self, drop_prob=0.1, mode="row"):
        super().__init__()
        self.drop_prob = drop_prob
        self.mode = mode

    def forward(self, x):
        if not self.training or self.drop_prob == 0.:
            return x

        # Size the mask to x's rank (was hard-wired to 3-D, which silently
        # broadcast 2-D inputs into a 3-D output).
        if self.mode == "row":
            shape = (x.shape[0],) + (1,) * (x.dim() - 1)
        else:
            shape = (1, x.shape[1]) + (1,) * (x.dim() - 2)

        keep_prob = 1 - self.drop_prob
        if keep_prob <= 0.:
            # Everything is dropped; avoid x / 0 * 0 = NaN.
            return torch.zeros_like(x)

        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
        return x.div(keep_prob) * random_tensor

class ConditionalBatchNorm1d(nn.Module):
    """BatchNorm1d without its own affine params, modulated by a per-class scale and shift.

    ``forward(x, y)`` normalises ``x`` (``(b, num_features, L)``) and applies
    ``gamma[y] * x_hat + beta[y]``, where gamma and beta are looked up from an
    embedding indexed by the integer class ids ``y``.
    """

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        self.bn = nn.BatchNorm1d(num_features, affine=False)
        self.embed = nn.Embedding(num_classes, num_features * 2)
        table = self.embed.weight.data
        table[:, :num_features].normal_(1, 0.02)  # scales start near 1
        table[:, num_features:].zero_()  # shifts start at 0

    def forward(self, x, y):
        normalized = self.bn(x)
        scale, shift = torch.split(self.embed(y), self.num_features, dim=1)
        scale = scale.reshape(-1, self.num_features, 1)
        shift = shift.reshape(-1, self.num_features, 1)
        return scale * normalized + shift

# Main Diffusion Model
class DiffusionModel(nn.Module):
    """Noise-prediction network for latent diffusion over ESM-2 embeddings.

    ``forward`` runs these stages:
      1. Project the inputs to ``hidden_dim`` (256).
      2. Add rotary, sinusoidal and Fourier positional signals.
      3. Apply a time-conditioned scale/shift.
      4. Run ESM-2 transformer layers borrowed from ``esm_model``.
      5. Apply gated cross-attention against the resampled motif and the
         resampled protein+motif context.
      6. Project back to ``latent_dim``.

    It also returns a per-position binding logit.

    Args:
        esm_model: ESM-2 model whose last 8 layers are reused (parameters are shared).
        num_steps: Number of diffusion timesteps; also the class count of the
            timestep-conditioned batch norm.
        latent_dim: Feature size of ``x`` / ``protein_emb`` and of the output.
        motif_dim: Feature size of ``motif_emb``.
    """

    def __init__(self, esm_model, num_steps, latent_dim, motif_dim):
        super().__init__()
        self.esm_model = esm_model
        self.num_steps = num_steps
        self.latent_dim = latent_dim
        self.motif_dim = motif_dim
        self.hidden_dim = 256
        self.project_combined = nn.Linear(self.hidden_dim * 2, self.hidden_dim)
        # 1280 is hard-coded as the width of the borrowed ESM layers.
        self.project_to_esm = nn.Linear(self.hidden_dim, 1280)
        self.project_from_esm = nn.Linear(1280, self.hidden_dim)

        # Projection layers
        self.project_to_hidden = nn.Linear(latent_dim, self.hidden_dim)
        self.project_from_hidden = nn.Linear(self.hidden_dim, latent_dim)
        self.project_motif = nn.Linear(motif_dim, self.hidden_dim)

        # Reuse the last 8 ESM-2 layers. Ordering is last layer first (layers[-1],
        # layers[-2], ...). NOTE(review): this is the reverse of ESM's own order;
        # confirm it is intended before changing it, because reordering remaps
        # state_dict keys onto parameters shared with esm_model.
        self.esm_attention_layers = nn.ModuleList([
            esm_model.layers[-i] for i in range(1, 9)
        ])
        # Rotary size of the ESM attention if it exposes `rotary_emb.dim`, else 64.
        rotary_emb = getattr(esm_model.layers[0].self_attn, 'rotary_emb', None)
        self.rotary_dim = getattr(rotary_emb, 'dim', 64)
        # Projection layers for rotary embedding
        self.embed_dim = esm_model.layers[0].self_attn.embed_dim
        self.to_rotary_dim = nn.Linear(self.embed_dim, self.rotary_dim)
        self.from_rotary_dim = nn.Linear(self.rotary_dim, self.embed_dim)

        # MLP for processing
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim * 4),
            nn.GELU(),
            nn.Linear(self.hidden_dim * 4, self.hidden_dim)
        )

        self.mlp_final = nn.Sequential(
            nn.Linear(latent_dim, latent_dim * 4),
            nn.GELU(),
            nn.Linear(latent_dim * 4, latent_dim)
        )

        # Add & Norm layer
        self.add_norm = nn.LayerNorm(self.hidden_dim)

        # Scale & Shift layer (time-conditioned FiLM)
        self.scale_shift = nn.Linear(self.hidden_dim, self.hidden_dim * 2)

        self.motif_resampler = PerceiverResampler(dim=self.hidden_dim, depth=6)
        self.combined_resampler = PerceiverResampler(dim=self.hidden_dim, depth=6)
        self.time_embed = nn.Sequential(
            nn.Linear(1, self.hidden_dim * 4),
            nn.SiLU(),
            nn.Linear(self.hidden_dim * 4, self.hidden_dim)
        )
        self.positional_encoding = PositionalEncoding(self.hidden_dim)
        self.rope = RotaryPositionalEmbedding(self.hidden_dim)
        self.fourier_emb = FourierFeatureEmbedding(num_frequencies=self.hidden_dim//2, d_model=self.hidden_dim)

        self.denoiser = nn.ModuleList([
            GatedCrossAttentionBlock(self.hidden_dim, self.hidden_dim) for _ in range(6)
        ])

        self.residual_layers = nn.ModuleList([
            nn.Linear(self.hidden_dim, self.hidden_dim) for _ in range(6)
        ])

        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(self.hidden_dim) for _ in range(6)
        ])

        self.adaptive_ln = nn.ModuleList([
            AdaptiveLayerNorm(self.hidden_dim) for _ in range(6)
        ])

        self.se_blocks = nn.ModuleList([
            SqueezeExcitation(self.hidden_dim) for _ in range(6)
        ])

        self.final_layer = nn.Linear(latent_dim, latent_dim)
        self.binding_predictor = nn.Linear(latent_dim, 1)

        # NOTE(review): transformer_denoiser, glu, stochastic_depth, time_attention and
        # the rotary projections above are not applied in forward. They are kept so the
        # state_dict layout (and existing checkpoints) stays unchanged.
        self.transformer_denoiser = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=self.hidden_dim, nhead=8),
            num_layers=3
        )

        self.glu = nn.GLU()
        self.fpn = FeaturePyramidNetwork(latent_dim)
        self.focal_loss = FocalLoss(alpha=0.25, gamma=5)
        self.stochastic_depth = StochasticDepth(drop_prob=0.1, mode="row")
        self.cond_bn = ConditionalBatchNorm1d(latent_dim, num_steps)

        self.time_attention = nn.MultiheadAttention(self.hidden_dim, num_heads=8)

        self.use_checkpoint = True

    def upsample(self, x):
        """Double the length of a (batch, seq, dim) tensor by linear interpolation."""
        # A method rather than a lambda attribute so the module stays picklable.
        return F.interpolate(x.transpose(1, 2), scale_factor=2, mode='linear', align_corners=False).transpose(1, 2)

    def downsample(self, x):
        """Halve the length of a (batch, seq, dim) tensor by average pooling."""
        return F.avg_pool1d(x.transpose(1, 2), kernel_size=2, stride=2).transpose(1, 2)

    def forward(self, x, protein_emb, motif_emb, t):
        """Predict the noise contained in ``x`` at timesteps ``t``.

        Args:
            x: Noisy latents, batch-first ``(batch, seq, latent_dim)``.
            protein_emb: Protein embedding, ``(batch, seq, latent_dim)``.
            motif_emb: Motif features, ``(batch, seq, motif_dim)``. Must share
                ``seq`` with ``protein_emb`` because the two are concatenated
                feature-wise.
            t: Integer timesteps of shape ``(batch,)``.

        Returns:
            ``(predicted_noise, binding_logits)`` with shapes
            ``(batch, seq, latent_dim)`` and ``(batch, seq)``.
        """
        from torch.utils.checkpoint import checkpoint

        # Project inputs to hidden dimension
        x = self.project_to_hidden(x)
        protein_emb = self.project_to_hidden(protein_emb)
        motif_emb = self.project_motif(motif_emb)

        t_embedding = self.time_embed(t.float().unsqueeze(-1))

        # Positional signals
        rope_emb = self.rope(x)
        x = apply_rotary_pos_emb(x, rope_emb)
        fourier_emb = self.fourier_emb(x)
        # PositionalEncoding indexes its table by dim 0 (sequence-first), but x is
        # batch-first here. Transpose so the encoding is chosen by position, not by
        # batch index.
        x = self.positional_encoding(x.transpose(0, 1)).transpose(0, 1) + fourier_emb

        # `self.time_attention` is intentionally not applied: its output was never
        # consumed (and its default sequence-first layout did not match x), so the
        # call only cost compute. Wire it in explicitly if time attention is wanted.

        resampled_motif = self.motif_resampler(motif_emb)

        # Process combined protein and motif information
        combined = torch.cat([protein_emb, motif_emb], dim=-1)
        combined = self.project_combined(combined)
        resampled_combined = self.combined_resampler(combined)

        scale, shift = self.scale_shift(t_embedding).unsqueeze(1).chunk(2, dim=-1)
        x = x * (scale + 1) + shift

        # Pass through the borrowed ESM layers at ESM width.
        x = self.project_to_esm(x)
        x = self._apply_esm_layers(x)
        x = self.project_from_esm(x)

        x = self.add_norm(x + self.mlp(x))

        for i, (layer, res_layer, norm, adaptive_ln, se_block) in enumerate(zip(
            self.denoiser, self.residual_layers, self.layer_norms, self.adaptive_ln, self.se_blocks
        )):
            if self.use_checkpoint and self.training:
                x = checkpoint(
                    self._forward_layer, x, resampled_motif, resampled_combined,
                    layer, res_layer, norm, adaptive_ln, se_block, t,
                    use_reentrant=False,
                )
            else:
                x = self._forward_layer(x, resampled_motif, resampled_combined, layer, res_layer, norm, adaptive_ln, se_block, t)

            # Alternate x2 upsampling and /2 pooling so the length returns to `seq`
            # after each pair (the layer count is even).
            x = self.upsample(x) if i % 2 == 0 else self.downsample(x)

        x = self.project_from_hidden(x)
        x = self.fpn(x)
        x = self.mlp_final(x)
        x = self.cond_bn(x.transpose(1, 2), t).transpose(1, 2)
        # Stochastic depth is not applied to the full prediction: doing so zeroed the
        # entire noise estimate of ~drop_prob of training samples, which is not
        # residual-branch regularisation.

        x = self.final_layer(x)
        binding_pred = self.binding_predictor(x).squeeze(-1)

        return x, binding_pred

    def _apply_esm_layers(self, x):
        """Run the borrowed ESM layers (pre-LN self-attention + FFN) on batch-first ``x``."""
        # No padding positions: an all-False key padding mask, built once.
        key_padding_mask = torch.zeros(x.shape[0], x.shape[1], dtype=torch.bool, device=x.device)
        # ESM attention expects (seq_len, batch, embed_dim).
        x = x.transpose(0, 1)
        for esm_layer in self.esm_attention_layers:
            residual = x
            x = esm_layer.self_attn_layer_norm(x)
            x, _ = esm_layer.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=key_padding_mask,
                need_weights=False
            )
            x = residual + x

            residual = x
            x = esm_layer.final_layer_norm(x)
            x = esm_layer.fc2(F.gelu(esm_layer.fc1(x)))
            x = residual + x
        return x.transpose(0, 1)

    def _forward_layer(self, x, resampled_motif, resampled_combined, layer, res_layer, norm, adaptive_ln, se_block, t):
        """Apply one denoiser stage.

        The stage runs gated cross-attention to the motif and then to the
        protein+motif context (the same block is used for both). It then adds a
        residual projection and applies LayerNorm, AdaptiveLayerNorm and
        channel squeeze-excitation. ``t`` is accepted but unused.
        """
        residual = res_layer(x)

        x = layer(x, resampled_motif)
        x = layer(x, resampled_combined)
        x = norm(x + residual)
        x = adaptive_ln(x)
        x = se_block(x.transpose(1, 2)).transpose(1, 2)
        return x

    def loss_function(self, pred, target):
        """Focal loss between binding logits ``pred`` and 0/1 ``target`` labels."""
        return self.focal_loss(pred, target)

class LatentDiffusion(nn.Module):
    """DDPM wrapper: linear beta schedule, forward noising, training loss, ancestral sampling.

    Args:
        esm_model: ESM-2 model; its ``lm_head`` is used for the cross-entropy term.
        num_steps: Number of diffusion timesteps.
        latent_dim: Feature size of the latents drawn in ``sample``.
        motif_dim: Stored only. The inner ``DiffusionModel`` is built with
            ``motif_dim=1`` (one energy label per position).
        device: Device on which the schedule tensors are first created.
    """

    def __init__(self, esm_model, num_steps, latent_dim, motif_dim, device):
        super().__init__()
        self.esm_model = esm_model
        self.num_steps = num_steps
        self.latent_dim = latent_dim
        self.motif_dim = motif_dim
        # NOTE(review): latent_dim=1280 and motif_dim=1 are hard-coded rather than
        # taken from the constructor arguments.
        self.diffusion_model = DiffusionModel(esm_model, num_steps, latent_dim=1280, motif_dim=1)
        self.device = device
        self.hidden_dim = 256

        # Linear beta schedule. These are non-persistent buffers so `.to(device)` moves
        # them with the model while the state_dict keys stay unchanged.
        beta = torch.linspace(1e-4, 0.02, num_steps).to(device)
        alpha = 1 - beta
        alpha_bar = torch.cumprod(alpha, dim=0)
        self.register_buffer('beta', beta, persistent=False)
        self.register_buffer('alpha', alpha, persistent=False)
        self.register_buffer('alpha_bar', alpha_bar, persistent=False)
        self.register_buffer('sqrt_alpha_bar', torch.sqrt(alpha_bar), persistent=False)
        self.register_buffer('sqrt_one_minus_alpha_bar', torch.sqrt(1 - alpha_bar), persistent=False)

    def q_sample(self, x0, t, noise=None):
        """Noise clean latents: ``sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * noise``.

        ``t`` is a 1-D tensor of timesteps, one per batch element of a 3-D ``x0``.
        """
        if noise is None:
            noise = torch.randn_like(x0).to(self.device)
        return (
            self.sqrt_alpha_bar[t, None, None] * x0 +
            self.sqrt_one_minus_alpha_bar[t, None, None] * noise
        )

    def p_losses(self, x0, protein_emb, motif_emb, t, target_seq, energy_scores, noise=None):
        """Return the weighted training loss ``1.5 * MSE + CE + 3 * focal``.

        Args:
            x0: Clean peptide latents, possibly right-padded along the sequence axis.
            protein_emb, motif_emb, t: Conditioning passed to the denoiser.
            target_seq: Peptide token ids ``(batch, L)``; may be shorter than ``x0``.
            energy_scores: 0/1 binding labels aligned with the denoiser output.
            noise: Optional fixed noise (drawn if omitted).
        """
        if noise is None:
            noise = torch.randn_like(x0).to(self.device)

        x_noisy = self.q_sample(x0, t, noise=noise)
        predicted_noise, binding_pred = self.diffusion_model(x_noisy, protein_emb, motif_emb, t)

        l2_loss = F.mse_loss(predicted_noise, noise)  # l2 loss objective
        binding_loss = self.diffusion_model.loss_function(binding_pred, energy_scores)

        # NOTE(review): the LM head is applied to the *predicted noise*, not to a
        # denoised x0 estimate. Confirm this is the intended objective.
        esm_logits = self.esm_model.lm_head(predicted_noise)
        # The logits span the padded length of x0, whereas target_seq holds only the
        # peptide tokens. Align the lengths and mark padding with ignore_index;
        # otherwise cross_entropy fails on a size mismatch.
        seq_len = esm_logits.size(1)
        target = target_seq[:, :seq_len]
        if target.size(1) < seq_len:
            target = F.pad(target, (0, seq_len - target.size(1)), value=-100)
        ce_loss = F.cross_entropy(
            esm_logits.reshape(-1, esm_logits.size(-1)), target.reshape(-1), ignore_index=-100
        )

        total_loss = (1.5*l2_loss) + ce_loss + (3*binding_loss)
        # Only scalar summaries are printed; dumping full logit tensors every step
        # flooded the notebook output.
        print("total loss:", total_loss)
        print("l2 loss:", (1.5*l2_loss))
        print("ce loss:", ce_loss)
        print("binding loss:", (3*binding_loss))

        return total_loss

    @torch.no_grad()
    def p_sample(self, x, motif, t, protein_emb=None):
        """Take one reverse DDPM step ``x_t -> x_{t-1}``.

        Args:
            x: Current latents ``(batch, seq, dim)``.
            motif: Motif conditioning passed to the denoiser.
            t: 1-D tensor of timesteps (all equal during ``sample``).
            protein_emb: Protein conditioning. Required, because the denoiser
                signature is ``(x, protein_emb, motif_emb, t)``.

        Raises:
            ValueError: If ``protein_emb`` is not given.
        """
        if protein_emb is None:
            raise ValueError("p_sample requires protein_emb: the denoiser is conditioned on it")
        betas_t = self.beta[t][:, None, None]
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alpha_bar[t][:, None, None]
        sqrt_recip_alphas_t = torch.sqrt(1.0 / self.alpha[t])[:, None, None]

        predicted_noise, _ = self.diffusion_model(x, protein_emb, motif, t)
        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t
        )

        if t[0] > 0:
            noise = torch.randn_like(x).to(self.device)
            return model_mean + torch.sqrt(betas_t) * noise
        return model_mean

    @torch.no_grad()
    def sample(self, num_samples, sequence_length, motif, protein_emb=None):
        """Draw ``num_samples`` latents by running all reverse steps from pure noise.

        ``protein_emb`` is forwarded to ``p_sample`` and is required there.
        """
        device = next(self.parameters()).device
        shape = (num_samples, sequence_length, self.latent_dim)
        x = torch.randn(shape, device=device)
        motif = motif.to(device)
        if protein_emb is not None:
            protein_emb = protein_emb.to(device)

        for t in reversed(range(0, self.num_steps)):
            t_batch = torch.full((num_samples,), t, device=device, dtype=torch.long)
            x = self.p_sample(x, motif, t_batch, protein_emb=protein_emb)

        return x

def embeddings_to_sequence(embeddings, esm_model, alphabet):
    """Greedy-decode embeddings to strings via the ESM LM head.

    Each position takes the argmax over every token in ``alphabet.all_toks``,
    special tokens included. Returns one string per batch element.
    """
    vocab = alphabet.all_toks
    column_ids = list(map(alphabet.get_idx, vocab))
    logits = esm_model.lm_head(embeddings)[:, :, column_ids]
    best_ids = logits.argmax(dim=-1).tolist()
    return [''.join(vocab[k] for k in row) for row in best_ids]

def calculate_sequence_recovery(generated_sequences, ground_truth_sequences, *, per_residue=False):
    """Score how well generated sequences reproduce the ground truth.

    Args:
        generated_sequences: Iterable of generated strings, paired positionally
            with ``ground_truth_sequences``.
        ground_truth_sequences: Iterable of reference strings.
        per_residue: If False (default), return the fraction of references
            matched exactly. If True, return identical positions divided by
            the total reference length (positions past the shorter string
            count as mismatches).

    Returns:
        A float in [0, 1]; ``0.0`` when there are no references (previously a
        ZeroDivisionError).
    """
    references = list(ground_truth_sequences)
    if not references:
        return 0.0

    pairs = list(zip(generated_sequences, references))
    if per_residue:
        total_residues = sum(len(gt) for gt in references)
        if total_residues == 0:
            return 0.0
        matched = sum(sum(a == b for a, b in zip(gen, gt)) for gen, gt in pairs)
        return matched / total_residues

    correct = sum(gen == gt for gen, gt in pairs)
    return correct / len(references)

def _prepare_batch(model, batch, device, batch_converter):
    """Tokenise, embed and pad one DataLoader batch for ``LatentDiffusion.p_losses``.

    Only the first protein/peptide of the batch is embedded, so this assumes
    ``batch_size=1`` (what ``main`` uses).

    Returns:
        ``(peptide_embedding, protein_embedding, motif_embeddings, t, target_seq,
        energy_scores)``. The three embeddings and ``energy_scores`` are
        right-padded to a common sequence length, and ``t`` holds random
        timesteps, one per batch element.
    """
    energy_scores, protein_seq, peptide_seq = batch
    energy_scores = energy_scores.to(device)

    _, _, protein_tokens = batch_converter([(0, protein_seq[0])])
    protein_tokens = protein_tokens.to(device)
    _, _, peptide_tokens = batch_converter([(0, peptide_seq[0])])
    peptide_tokens = peptide_tokens.to(device)

    # No gradients flow into ESM for these input embeddings.
    with torch.no_grad():
        protein_embedding = model.esm_model(protein_tokens, repr_layers=[33], return_contacts=False)
        peptide_embedding = model.esm_model(peptide_tokens, repr_layers=[33], return_contacts=False)
    protein_embedding = protein_embedding["representations"][33]
    peptide_embedding = peptide_embedding["representations"][33]

    # One energy label per position becomes a 1-feature motif embedding.
    motif_embeddings = energy_scores.unsqueeze(-1).float()

    # Right-pad everything along the sequence axis to a common length.
    max_len = max(protein_embedding.shape[1], motif_embeddings.shape[1], peptide_embedding.shape[1])
    protein_embedding = F.pad(protein_embedding, (0, 0, 0, max_len - protein_embedding.shape[1]))
    motif_embeddings = F.pad(motif_embeddings, (0, 0, 0, max_len - motif_embeddings.shape[1]))
    peptide_embedding = F.pad(peptide_embedding, (0, 0, 0, max_len - peptide_embedding.shape[1]))
    energy_scores = F.pad(energy_scores, (0, max_len - energy_scores.shape[1], 0, 0))

    t = torch.randint(0, model.num_steps, (protein_embedding.shape[0],), device=device).long()
    return peptide_embedding, protein_embedding, motif_embeddings, t, peptide_tokens, energy_scores


def train(model, dataloader, optimizer, num_epochs, device):
    """Train ``model`` for ``num_epochs`` epochs, logging losses to Weights & Biases.

    Returns:
        The average training loss of the last epoch (``0.0`` if ``num_epochs``
        is 0). Previously nothing was returned although ``main`` uses the value.
    """
    # Open a W&B run only when none is active. Repeated calls (one per epoch from
    # main) then log into the same run instead of starting a new one each time.
    if wandb.run is None:
        wandb.init(project="latent-diffusion-protein", entity="vskavi2003")
        wandb.config.update({
            "learning_rate": optimizer.param_groups[0]['lr'],
            "epochs": num_epochs,
            "batch_size": dataloader.batch_size
        })

    batch_converter = model.esm_model.alphabet.get_batch_converter()
    avg_loss = 0.0
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for batch in dataloader:
            optimizer.zero_grad()
            (peptide_embedding, protein_embedding, motif_embeddings,
             t, target_seq, energy_scores) = _prepare_batch(model, batch, device, batch_converter)

            print('initial energy score shape',energy_scores.shape)
            print('intiial prot emb shape',protein_embedding.shape)
            loss = model.p_losses(peptide_embedding, protein_embedding, motif_embeddings, t, target_seq, energy_scores)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            # Log batch loss
            wandb.log({"batch_loss training": loss.item()})

        avg_loss = total_loss / max(len(dataloader), 1)
        print(f"Epoch {epoch+1}, Loss: {avg_loss}")

        # Log epoch loss
        wandb.log({"epoch": epoch+1, "avg_loss training": avg_loss})

    return avg_loss


def validate(model, dataloader, device):
    """Return the mean ``p_losses`` value over ``dataloader`` without updating weights.

    Leaves the model in eval mode. Returns ``0.0`` for an empty loader.
    """
    model.eval()
    batch_converter = model.esm_model.alphabet.get_batch_converter()
    total_loss = 0
    with torch.no_grad():
        for batch in dataloader:
            (peptide_embedding, protein_embedding, motif_embeddings,
             t, target_seq, energy_scores) = _prepare_batch(model, batch, device, batch_converter)

            # p_losses needs the padded energy labels for its binding-loss term;
            # the previous call omitted them and raised a TypeError.
            loss = model.p_losses(peptide_embedding, protein_embedding, motif_embeddings, t, target_seq, energy_scores)
            total_loss += loss.item()

    return total_loss / max(len(dataloader), 1)


def main(
    train_csv='training_dataset.csv',
    val_csv='validation_dataset.csv',
    test_csv='testing_dataset.csv',
    num_epochs=8,
    num_steps=1000,
    lr=1e-4,
    batch_size=1,
    num_workers=4,
):
    """Train a LatentDiffusion model on the protein/peptide interaction splits.

    Loads and filters the train/validation/test CSVs (paths are relative to
    the current working directory), wraps each split in a
    ``ProteinInteractionDataset`` + ``DataLoader``, then runs ``num_epochs``
    rounds of ``train`` followed by ``validate``. Each round's losses are
    logged to wandb and printed.

    Every parameter defaults to the value that used to be hard-coded here,
    so a bare ``main()`` call behaves as before.

    Args:
        train_csv: CSV passed to ``preprocess_snp_data`` for the training split.
        val_csv: CSV passed to ``preprocess_snp_data`` for the validation split.
        test_csv: CSV passed to ``preprocess_snp_data`` for the test split.
        num_epochs: number of outer train/validate rounds.
        num_steps: number of diffusion steps given to ``LatentDiffusion``.
        lr: AdamW learning rate.
        batch_size: DataLoader batch size. NOTE(review): ``validate`` only
            tokenizes ``protein_seq[0]`` / ``peptide_seq[0]``, so values other
            than 1 are presumably unsupported -- confirm before changing.
        num_workers: number of DataLoader worker processes.

    Returns:
        The trained ``LatentDiffusion`` model, switched to eval mode.
    """
    # Load and preprocess data
    train_snp = filter_datasets(preprocess_snp_data(train_csv))
    val_snp = filter_datasets(preprocess_snp_data(val_csv))
    test_snp = filter_datasets(preprocess_snp_data(test_csv))

    # Longest peptide/protein sequence over ALL splits (the test split is
    # only used for sizing here, not for training).
    all_seqs = pd.concat([
        train_snp['peptide_derived_sequence'], train_snp['protein_derived_sequence'],
        val_snp['peptide_derived_sequence'], val_snp['protein_derived_sequence'],
        test_snp['peptide_derived_sequence'], test_snp['protein_derived_sequence'],
    ])
    max_length = max(len(seq) for seq in all_seqs)

    # Create datasets
    train_dataset = ProteinInteractionDataset(train_snp)
    val_dataset = ProteinInteractionDataset(val_snp)
    test_dataset = ProteinInteractionDataset(test_snp)

    # Create dataloaders. The test loader is not consumed yet; it is kept
    # for the (currently commented-out) evaluation code.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # Initialize LatentDiffusion model
    latent_dim = esm_model.embed_dim
    # motif_dim is the longest sequence length across all splits. The energy
    # scores are NOT one-hot encoded: validate() feeds them as a per-position
    # scalar (energy_scores.unsqueeze(-1)). NOTE(review): confirm this is what
    # LatentDiffusion expects for motif_dim.
    motif_dim = max_length
    model = LatentDiffusion(esm_model, num_steps, latent_dim, motif_dim, device)
    model.to(device)

    # Training
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        # NOTE(review): train() receives num_epochs; confirm it performs a
        # single pass per call and does not itself loop num_epochs times
        # (which would make the total num_epochs ** 2 passes).
        train_loss = train(model, train_loader, optimizer, num_epochs, device)
        val_loss = validate(model, val_loader, device)

        wandb.log({
            "epoch": epoch+1,
            "train_loss": train_loss,
            "val_loss": val_loss
        })

        print(f"Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}")

    # Switch to eval mode and hand the model back so the caller (e.g. a later
    # notebook cell) can sample from or save it instead of losing it here.
    model.eval()
    return model

#     # Generate samples for the validation set
#     val_motifs = []
#     val_protein_seqs = []
#     val_peptide_seqs = []

#     with torch.no_grad():
#         for batch in val_loader:
#             energy_scores, protein_seq, peptide_seq = batch
#             energy_scores = energy_scores.to(device)
#             results = model(batch_tokens, repr_layers=[33], return_contacts=False)

#             # Encode protein sequences
#             protein_embedding = model.esm_model(protein_seq)['last_hidden_state']

#             # Generate peptide embeddings
#             generated_embeddings = model.sample(energy_scores.size(0), max_length, energy_scores)

#             val_motifs.append(energy_scores)
#             val_protein_seqs.extend(protein_seq)
#             val_peptide_seqs.extend(peptide_seq)

#     # Concatenate all batches
#     val_motifs = torch.cat(val_motifs, dim=0)

#     # Convert generated embeddings to sequences
#     generated_sequences = embeddings_to_sequence(generated_embeddings, esm_model, alphabet)

#     print("Sample of generated sequences:")
#     for i in range(min(10, len(generated_sequences))):
#         print(f"Generated: {generated_sequences[i]}")
#         print(f"Ground truth: {val_peptide_seqs[i]}")
#         print()

#     # Calculate sequence recovery
#     recovery_rate = calculate_sequence_recovery(generated_sequences, val_peptide_seqs)
#     print(f"Sequence recovery rate: {recovery_rate:.4f}")

#     # Calculate per-position accuracy
#     per_position_accuracy = calculate_per_position_accuracy(generated_sequences, val_peptide_seqs)
#     print(f"Per-position accuracy: {per_position_accuracy:.4f}")

# def calculate_per_position_accuracy(generated_sequences, ground_truth_sequences):
#     total_correct = 0
#     total_positions = 0

#     for gen, gt in zip(generated_sequences, ground_truth_sequences):
#         for g, t in zip(gen, gt):
#             if g == t:
#                 total_correct += 1
#             total_positions += 1

#     return total_correct / total_positions if total_positions > 0 else 0

# Script entry point. In a Jupyter/Colab cell `__name__` is "__main__", so
# executing this cell starts training immediately.
if __name__ == "__main__":
    main()




Using device: cuda




VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
batch_loss training,█▇▆▅▄▃▃▃▂▂▁▁▁▂▃▂▂▁▂▁▁▁▁▃▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁

0,1
batch_loss training,4.73197


initial energy score shape torch.Size([1, 106])
intiial prot emb shape torch.Size([1, 106, 1280])
tensor([[ 1.2002e-01,  2.3128e-01,  3.5025e-02,  1.4320e-02,  3.1932e-01,
         -1.3154e-01,  2.2804e-03, -6.3823e-02, -2.9063e-02, -1.8486e-01,
          3.0451e-02, -9.3639e-02,  4.2829e-04,  3.8379e-01,  5.3006e-01,
          3.6579e-01,  5.5800e-01,  4.8924e-01,  1.6760e-01,  2.8062e-01,
          2.4693e-01,  4.8707e-01,  7.5809e-01,  7.3354e-01,  1.4617e-01,
          4.2820e-01,  3.0235e-01,  3.3622e-02, -3.4350e-01, -4.0311e-01,
         -4.8136e-01, -3.0632e-01, -3.4722e-01, -4.3481e-01, -8.3017e-02,
          6.1324e-01,  1.7082e-01, -3.5690e-01, -1.1254e-01, -2.3791e-02,
          3.6737e-01,  7.7707e-01, -1.2616e-01, -3.2163e-01, -2.4831e-02,
         -2.5484e-02,  5.0863e-01,  8.0236e-01,  1.8826e-01, -6.3469e-01,
         -3.6920e-01,  7.7913e-02, -3.7937e-01, -1.3128e-01, -3.0516e-01,
         -1.5366e-01, -3.8302e-01, -3.9545e-01, -6.3183e-01, -3.5832e-01,
         -3.00



initial energy score shape torch.Size([1, 832])
intiial prot emb shape torch.Size([1, 832, 1280])
tensor([[-0.1935, -0.2765,  0.1024, -0.3244,  0.3450,  0.2760,  0.0469,  0.2213,
          0.1074, -0.1403, -0.0733, -0.1097,  0.0970,  0.4372,  0.1955, -0.3777,
         -0.6190, -0.2874, -0.0095, -0.0575, -0.2048, -0.0598, -0.0693, -0.0195,
         -0.0125,  0.0016, -0.0750, -0.0030,  0.0187, -0.0070, -0.2340, -0.4295,
         -0.5722, -0.3098, -0.3754, -0.1196, -0.1609, -0.2162,  0.4624, -0.0108,
         -0.2095, -0.2012,  0.1101,  0.4210,  0.2809,  0.5657,  0.2896,  0.1954,
          0.0812, -0.0912, -0.1507,  0.4404, -0.1040,  0.1874,  0.6212,  0.3431,
          0.2971,  0.3953,  0.2242,  0.5737,  0.3037,  0.2091,  0.3751,  0.2270,
          0.0649,  0.7942, -0.0145,  0.0502, -0.0555, -0.4537, -0.0287, -0.0169,
          0.4150,  0.5071,  0.3787,  0.3341,  0.1871, -0.1711, -0.1807,  0.0339,
          0.2227,  0.2923,  0.4040,  0.5692,  0.2746,  0.0082,  0.0932,  0.3323,
          0

KeyboardInterrupt: 