## pip installs

In [1]:
!pip install Levenshtein
!pip install einops
!pip install einops_exts
!pip install torch
!pip install transformers
!pip install tqdm
!pip install sentencepiece
!pip install black
!pip install fair-esm
!pip install wandb

Collecting Levenshtein
  Downloading Levenshtein-0.25.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.3 kB)
Collecting rapidfuzz<4.0.0,>=3.8.0 (from Levenshtein)
  Downloading rapidfuzz-3.9.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading Levenshtein-0.25.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (177 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m177.4/177.4 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading rapidfuzz-3.9.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m44.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz, Levenshtein
Successfully installed Levenshtein-0.25.1 rapidfuzz-3.9.6
Collecting einops_exts
  Downloading einops_exts-0.0.4-py3-none-any.whl.metadata (621 bytes)
Downloading einops_exts-0.0.4-py3-none-any.whl (3.9 kB)
In

## dataset

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import re
import esm
from einops import rearrange, repeat
import math
import numpy as np
from torch import einsum
import wandb
# Authenticate with Weights & Biases at import time (metrics are logged later via wandb.log in train()).
wandb.login()
import os
# Make the dataset directory the working directory so the CSVs opened in main() resolve by bare filename.
# NOTE(review): this is an absolute Google Drive path — presumably the notebook runs in Colab with Drive mounted; confirm.
os.chdir('/content/drive/MyDrive/Programmable Biology Group/Srikar/Code/proteins/flamingo-diffusion/data_dump/old_dat/')

# ESM Model Setup
# Prefer the GPU when one is available; every tensor/module below is placed on `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load the pretrained ESM-2 model together with its alphabet; `batch_converter` is the
# tokenizer callable obtained from that alphabet.
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
esm_model = esm_model.to(device)
# Inference mode plus frozen weights: the ESM model is used as a fixed feature extractor here.
esm_model.eval()
for param in esm_model.parameters():
    param.requires_grad = False

# Data Preprocessing
def preprocess_snp_data(file_path):
    """Read an interaction CSV and add normalised energy strings plus length columns.

    The ``energy_scores`` column is rewritten so that whitespace-separated
    values become comma-separated, and one ``*_length(s)`` column is appended
    for the energy scores and for each sequence / RCSB id column.

    Args:
        file_path: Path (or anything ``pandas.read_csv`` accepts) of the CSV.

    Returns:
        The loaded DataFrame with the rewritten and appended columns.
    """
    frame = pd.read_csv(file_path)

    def _normalise(raw):
        # Runs of whitespace/newlines become a single comma ...
        text = re.sub(r'[\s\n]+', ',', raw)
        # ... the comma this may leave directly after '[' is dropped ...
        text = re.sub(r'\[\s*,', '[', text)
        # ... and so is any leading whitespace/comma run.
        return re.sub(r'^[\s,]+', '', text)

    def _count_entries(text):
        # Number of comma-separated fields, not counting an empty leading one.
        leading_comma = 1 if text.startswith(',') else 0
        return text.count(',') + 1 - leading_comma

    frame['energy_scores'] = [_normalise(raw) for raw in frame['energy_scores']]
    frame['energy_scores_lengths'] = frame['energy_scores'].apply(_count_entries)

    # (new column, source column) pairs, in the order the columns are appended.
    length_columns = (
        ('peptide_source_RCSB_lengths', 'peptide_source_RCSB'),
        ('protein_RCSB_lengths', 'protein_RCSB'),
        ('protein_derived_seq_length', 'protein_derived_sequence'),
        ('peptide_derived_seq_length', 'peptide_derived_sequence'),
    )
    for target, source in length_columns:
        frame[target] = frame[source].apply(len)

    return frame

def filter_datasets(dataset):
    """Return only the rows whose ``protein_RCSB`` differs from ``peptide_source_RCSB``.

    Args:
        dataset: DataFrame containing both RCSB columns.

    Returns:
        The filtered DataFrame; the original index labels are kept.
    """
    differs = dataset['protein_RCSB'].ne(dataset['peptide_source_RCSB'])
    return dataset.loc[differs]

# Dataset Class
class ProteinInteractionDataset(Dataset):
    """Map-style dataset over rows of a preprocessed interaction DataFrame.

    Each item is ``(mask, protein_seq, peptide_seq)`` where ``mask`` is a
    float32 tensor holding 1.0 for every energy score ``<= -1`` and 0.0
    otherwise, and the two sequences are the raw column values.
    """

    def __init__(self, dataframe):
        # Kept as-is; rows are looked up positionally on access.
        self.dataframe = dataframe

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        record = self.dataframe.iloc[idx]

        # Pull every numeric literal (optionally signed / with exponent) out of the string.
        literals = re.findall(r'-?\d+\.?\d*(?:e[-+]?\d+)?', record['energy_scores'])
        flags = self.one_hot_encode_energy_scores([float(lit) for lit in literals])
        mask = torch.tensor(flags, dtype=torch.float32)

        return mask, record['protein_derived_sequence'], record['peptide_derived_sequence']

    @staticmethod
    def one_hot_encode_energy_scores(scores):
        """Return 1 for each score that is ``<= -1`` and 0 for every other score."""
        return [int(score <= -1) for score in scores]



## model

In [None]:
# ESM encoder-decoder (reduces the latent to 64 dims, as the paper says it's easier to diffuse on)
class RefinedESMEncoderDecoder(nn.Module):
    """Compress per-residue embeddings to a small latent and decode them to residue logits.

    Args:
        esm_dim: Width of the incoming per-residue embeddings.
        latent_dim: Width of the latent space; must be divisible by 8, the
            number of attention heads used by the decoder.

    Inputs to ``forward`` are batch-first: ``(batch, seq, esm_dim)``. The
    output is ``(batch, seq, 21)`` logits (20 amino acids + unknown).
    """

    def __init__(self, esm_dim=1280, latent_dim=64):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(esm_dim, 512),
            nn.ReLU(),
            nn.Linear(512, latent_dim)
        )
        # batch_first=True: the rest of this file treats dim 0 as the batch
        # (e.g. `[:, 0]` pooling and per-sample timestep broadcasting). With the
        # default sequence-first layout, attention would mix across samples
        # instead of across residues.
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=latent_dim, nhead=8, batch_first=True),
            num_layers=3
        )
        self.final_layer = nn.Linear(latent_dim, 21)  # 20 amino acids + unknown

    def forward(self, x):
        # tanh bounds the latent to (-1, 1) before it is decoded.
        latent = torch.tanh(self.encoder(x))
        # The latent serves as both target and memory of the decoder.
        decoded = self.decoder(latent, latent)
        return self.final_layer(decoded)

# CLIP-style dual encoder (the 90/10 classifier-free guidance split itself is applied in RefinedLatentDiffusion.p_losses)
class CLIPModel(nn.Module):
    """Dual transformer encoder that maps two embedding sequences to unit vectors.

    Args:
        embed_dim: Width of the shared projection space.
        input_dim: Width of the incoming per-position embeddings for both
            towers. Defaults to 1280 (the previous hard-coded value); must be
            divisible by 8, the number of attention heads.

    ``forward`` takes batch-first tensors ``(batch, seq, input_dim)`` and
    returns two L2-normalised ``(batch, embed_dim)`` tensors.
    """

    def __init__(self, embed_dim=128, input_dim=1280):
        super().__init__()
        # batch_first=True so that `[:, 0]` in forward() selects the first
        # position of every sample rather than the first sample's positions.
        self.antibody_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=input_dim, nhead=8, batch_first=True),
            num_layers=3
        )
        self.protein_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=input_dim, nhead=8, batch_first=True),
            num_layers=3
        )
        self.project_antibody = nn.Linear(input_dim, embed_dim)
        self.project_protein = nn.Linear(input_dim, embed_dim)

    def forward(self, antibody_emb, protein_emb):
        antibody_vec = self.antibody_encoder(antibody_emb)
        protein_vec = self.protein_encoder(protein_emb)
        # Pool by taking the first position, project, then normalise to unit length.
        antibody_out = F.normalize(self.project_antibody(antibody_vec[:, 0]), dim=-1)
        protein_out = F.normalize(self.project_protein(protein_vec[:, 0]), dim=-1)
        return antibody_out, protein_out

# One hot encoding of antigen sequence + Binary motif (epitope) representation
class RefinedRepresentation(nn.Module):
    """Concatenate a 21-way one-hot of token ids with a binary motif channel.

    ``seq_len`` is stored on the instance but is not read by ``forward``.

    NOTE(review): ``forward`` thresholds ``energy_scores`` at ``<= -1``.
    ``ProteinInteractionDataset`` in this file already emits 0/1 masks, which
    would all map to 0 here — confirm which form callers are meant to pass.
    """

    def __init__(self, seq_len):
        super().__init__()
        self.seq_len = seq_len

    def forward(self, sequence, energy_scores):
        # Integer token ids -> one-hot over 21 classes (new trailing axis).
        residue_one_hot = F.one_hot(sequence, num_classes=21)

        # 1.0 where the score is <= -1, else 0.0, with a trailing channel axis.
        motif_flags = energy_scores.le(-1).float()[..., None]

        # Stack along the channel axis: 21 one-hot channels followed by the motif channel.
        return torch.cat((residue_one_hot, motif_flags), dim=-1)

# Updated Denoiser Model
class RefinedDenoiser(nn.Module):
    """Noise predictor: binder latents cross-attend to target-protein embeddings.

    Args:
        latent_dim: Width of the binder latent ``x`` (divisible by 8 heads).
        protein_dim: Width of ``protein_emb`` (divisible by 8 heads).
        clip_dim: Stored for reference only; no layer in this module uses it.

    All sequence tensors are batch-first: ``x`` is ``(batch, seq, latent_dim)``
    and ``protein_emb`` is ``(batch, protein_seq, protein_dim)``. ``clip_emb``
    and ``motif_emb`` are added to the attended latents, so they must
    broadcast against ``(batch, seq, latent_dim)``.

    NOTE(review): the timestep ``t`` is accepted but not used, so the network
    has no information about the noise level — confirm whether a timestep
    embedding was intended.
    """

    def __init__(self, latent_dim, protein_dim, clip_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.protein_dim = protein_dim
        self.clip_dim = clip_dim
        # batch_first=True everywhere: callers in this file build (batch, seq, dim) tensors.
        self.protein_binder_transformer = nn.TransformerEncoderLayer(
            d_model=latent_dim, nhead=8, batch_first=True
        )
        self.target_protein_transformer = nn.TransformerEncoderLayer(
            d_model=protein_dim, nhead=8, batch_first=True
        )
        # kdim/vdim let keys/values keep the protein width; without them the
        # attention only works when protein_dim happens to equal latent_dim.
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=latent_dim, num_heads=8,
            kdim=protein_dim, vdim=protein_dim, batch_first=True
        )
        self.final_layer = nn.Linear(latent_dim, latent_dim)

    def forward(self, x, protein_emb, clip_emb, motif_emb, t, use_classifier_free=False):
        """Predict the noise contained in ``x``.

        Args:
            x: Noisy binder latents.
            protein_emb: Target-protein embeddings, or ``None`` (only when
                ``use_classifier_free`` is true).
            clip_emb: Additive conditioning, or ``None`` under the same rule.
            motif_emb: Additive conditioning, or ``None`` under the same rule.
            t: Diffusion timesteps (currently unused).
            use_classifier_free: When true, all conditioning is replaced by zeros.

        Returns:
            Tensor shaped like ``x`` (after broadcasting with the conditioning).

        Raises:
            ValueError: If conditioning is ``None`` while
                ``use_classifier_free`` is false.
        """
        if use_classifier_free:
            # Unconditional branch: zero out every conditioning input. A missing
            # protein embedding becomes a single all-zero position; an all-zero
            # context yields the same attended output whatever its length.
            if protein_emb is None:
                protein_emb = x.new_zeros(x.size(0), 1, self.protein_dim)
            else:
                protein_emb = torch.zeros_like(protein_emb)
            clip_emb = 0.0 if clip_emb is None else torch.zeros_like(clip_emb)
            motif_emb = 0.0 if motif_emb is None else torch.zeros_like(motif_emb)
        elif protein_emb is None or clip_emb is None or motif_emb is None:
            raise ValueError(
                "protein_emb, clip_emb and motif_emb are required unless use_classifier_free=True"
            )

        x = self.protein_binder_transformer(x)
        protein_emb = self.target_protein_transformer(protein_emb)
        x, _ = self.cross_attention(x, protein_emb, protein_emb)
        x = x + clip_emb + motif_emb  # attended latent + CLIP conditioning + motif conditioning
        return self.final_layer(x)

# Updated LatentDiffusion Model
class RefinedLatentDiffusion(nn.Module):
    """DDPM-style latent diffusion over binder latents with classifier-free guidance.

    Args:
        esm_model: Pretrained ESM model; stored as a submodule.
        num_steps: Number of diffusion steps in the linear beta schedule.
        latent_dim: Width of the diffused latent.
        protein_dim: Width of the protein conditioning given to the denoiser.
        clip_dim: Width of the CLIP projection space.
        device: Device on which the noise schedule is created.

    NOTE(review): ``p_losses`` feeds ``x0``/``protein_emb`` to ``clip_model``
    and ``x0`` to ``esm_encoder_decoder``; both are built here for 1280-wide
    inputs, while ``train()`` in this file passes encoder outputs that are
    ``latent_dim`` wide — confirm the intended inputs.
    """

    def __init__(self, esm_model, num_steps, latent_dim, protein_dim, clip_dim, device):
        super().__init__()
        self.esm_model = esm_model
        self.num_steps = num_steps
        self.latent_dim = latent_dim
        self.protein_dim = protein_dim
        self.clip_dim = clip_dim
        self.device = device

        self.esm_encoder_decoder = RefinedESMEncoderDecoder(esm_dim=1280, latent_dim=latent_dim)
        self.refined_representation = RefinedRepresentation(seq_len=1000)  # Adjust seq_len as needed
        self.clip_model = CLIPModel(embed_dim=clip_dim)
        self.denoiser = RefinedDenoiser(latent_dim=latent_dim, protein_dim=protein_dim, clip_dim=clip_dim)

        # Linear beta schedule and the derived DDPM coefficients. Registered as
        # non-persistent buffers so they follow `.to(...)` / `.cuda()` with the
        # module while leaving the state_dict unchanged.
        beta = torch.linspace(1e-4, 0.02, num_steps).to(device)
        alpha = 1 - beta
        alpha_bar = torch.cumprod(alpha, dim=0)
        schedule = (
            ("beta", beta),
            ("alpha", alpha),
            ("alpha_bar", alpha_bar),
            ("sqrt_alpha_bar", torch.sqrt(alpha_bar)),
            ("sqrt_one_minus_alpha_bar", torch.sqrt(1 - alpha_bar)),
        )
        for name, value in schedule:
            self.register_buffer(name, value, persistent=False)

    def q_sample(self, x0, t, noise=None):
        """Diffuse ``x0`` forward to timestep ``t`` (one index per batch element)."""
        if noise is None:
            noise = torch.randn_like(x0)
        return (
            self.sqrt_alpha_bar[t, None, None] * x0 +
            self.sqrt_one_minus_alpha_bar[t, None, None] * noise
        )

    def p_losses(self, x0, protein_emb, clip_emb, motif_emb, t, target_seq, noise=None):
        """Return noise-prediction MSE + 0.1 * CLIP loss + decoder cross-entropy.

        With probability 0.1 the denoiser is run unconditionally and the CLIP
        term is set to zero for that call.
        """
        if noise is None:
            noise = torch.randn_like(x0)

        x_noisy = self.q_sample(x0, t, noise=noise)

        # Classifier-free guidance: 10% of the time, use unconditional generation.
        # Drawn from torch's RNG (the `random` module is not imported in this
        # file), which also makes the draw follow torch.manual_seed.
        use_classifier_free = torch.rand(()).item() < 0.1
        predicted_noise = self.denoiser(x_noisy, protein_emb, clip_emb, motif_emb, t, use_classifier_free)

        loss = F.mse_loss(predicted_noise, noise)

        # CLIP loss (only when not using classifier-free guidance): negative
        # mean dot product of the two normalised embeddings.
        if not use_classifier_free:
            binder_clip, target_clip = self.clip_model(x0, protein_emb)
            clip_loss = -torch.sum(binder_clip * target_clip, dim=-1).mean()
        else:
            clip_loss = x0.new_zeros(())

        # Decoder loss
        decoded_seq = self.esm_encoder_decoder(x0)
        ce_loss = F.cross_entropy(decoded_seq.view(-1, decoded_seq.size(-1)), target_seq.view(-1))

        total_loss = loss + 0.1 * clip_loss + ce_loss
        return total_loss

    def _reverse_step(self, x, noise_pred, t):
        """Apply one reverse-diffusion update to ``x`` given the predicted noise.

        ``t`` is a 1-D long tensor with one (identical) timestep per sample;
        fresh Gaussian noise is added on every step except ``t == 0``.
        """
        betas_t = self.beta[t][:, None, None]
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alpha_bar[t][:, None, None]
        sqrt_recip_alphas_t = torch.sqrt(1.0 / self.alpha[t])[:, None, None]

        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * noise_pred / sqrt_one_minus_alphas_cumprod_t
        )

        if t[0] > 0:
            return model_mean + torch.sqrt(betas_t) * torch.randn_like(x)
        return model_mean

    @torch.no_grad()
    def p_sample(self, x, protein_emb, clip_emb, motif_emb, t, guidance_scale=3.0):
        """One guided reverse step: blend conditional and unconditional noise predictions."""
        # Generate both conditional and unconditional predictions
        noise_pred_cond = self.denoiser(x, protein_emb, clip_emb, motif_emb, t, use_classifier_free=False)
        noise_pred_uncond = self.denoiser(x, protein_emb, clip_emb, motif_emb, t, use_classifier_free=True)

        # Apply classifier-free guidance
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

        return self._reverse_step(x, noise_pred, t)

    @torch.no_grad()
    def sample(self, num_samples, sequence_length, protein_emb, clip_emb, motif_emb, guidance_scale=3.0):
        """Run the full guided reverse chain from Gaussian noise.

        Returns:
            Latents of shape ``(num_samples, sequence_length, latent_dim)``.
        """
        device = next(self.parameters()).device
        shape = (num_samples, sequence_length, self.latent_dim)
        x = torch.randn(shape, device=device)
        protein_emb = protein_emb.to(device)
        clip_emb = clip_emb.to(device)
        motif_emb = motif_emb.to(device)

        for t in reversed(range(0, self.num_steps)):
            t_batch = torch.full((num_samples,), t, device=device, dtype=torch.long)
            x = self.p_sample(x, protein_emb, clip_emb, motif_emb, t_batch, guidance_scale)

        return x

    @torch.no_grad()
    def sample_without_guidance(self, num_samples, sequence_length):
        """Run the reverse chain using only the unconditional noise prediction.

        Returns:
            Latents of shape ``(num_samples, sequence_length, latent_dim)``.
        """
        device = next(self.parameters()).device
        shape = (num_samples, sequence_length, self.latent_dim)
        x = torch.randn(shape, device=device)

        # Explicit all-zero conditioning instead of None: the denoiser zeroes
        # its conditioning in classifier-free mode anyway, and real tensors
        # give it concrete shapes. The (N, 1, 1) tensor broadcasts in the
        # additive conditioning term.
        null_protein = torch.zeros(num_samples, 1, self.protein_dim, device=device)
        null_additive = torch.zeros(num_samples, 1, 1, device=device)

        for t in reversed(range(0, self.num_steps)):
            # Index the schedule with the batched timestep; indexing with the
            # Python int yields a 0-dim tensor that cannot take `[:, None, None]`.
            t_batch = torch.full((num_samples,), t, device=device, dtype=torch.long)

            # Unconditional prediction
            noise_pred = self.denoiser(
                x, null_protein, null_additive, null_additive, t_batch, use_classifier_free=True
            )
            x = self._reverse_step(x, noise_pred, t_batch)

        return x

def generate_protein_binders(model, protein_seq, motif, num_samples=1, guidance_scale=3.0):
    """Sample binder latents conditioned on a target protein and motif, then decode them.

    Args:
        model: A ``RefinedLatentDiffusion`` instance.
        protein_seq: Target protein sequence (string).
        motif: Either a motif sequence (``str``) or a binary ``torch.Tensor``
            mask over the protein.
        num_samples: Number of binders to sample.
        guidance_scale: Classifier-free guidance strength passed to ``model.sample``.

    Returns:
        Residue logits of shape ``(num_samples, seq, 21)`` obtained by running
        the sampled latents through the decoder and its output layer.

    Raises:
        ValueError: If ``motif`` is neither a string nor a tensor.
    """
    device = next(model.parameters()).device

    def _match_batch(cond):
        # Conditioning is computed for a single protein; repeat it (as a view)
        # so its batch dimension matches the number of sampled binders.
        if cond.size(0) == 1 and num_samples > 1:
            return cond.expand(num_samples, *cond.shape[1:])
        return cond

    # Pure inference: nothing below needs an autograd graph.
    with torch.no_grad():
        # Generate protein embedding
        # NOTE(review): fair-esm models do not appear to expose `.encode`;
        # tokenisation is normally done with the alphabet's batch converter — confirm.
        protein_tokens = model.esm_model.encode(protein_seq)
        protein_embedding = model.esm_model(protein_tokens.to(device), repr_layers=[33], return_contacts=False)["representations"][33]

        # Process motif based on input type
        # NOTE(review): refined_representation one-hots its token input over 21
        # classes; presumably ESM token ids can exceed that range — verify.
        if isinstance(motif, str):
            # Motif is a sequence
            motif_tokens = model.esm_model.encode(motif)
            motif_embedding = model.esm_model(motif_tokens.to(device), repr_layers=[33], return_contacts=False)["representations"][33]
            motif_repr = model.refined_representation(motif_tokens, torch.ones_like(motif_embedding[:, :, 0]))
        elif isinstance(motif, torch.Tensor):
            # Motif is a binary tensor
            motif_embedding = motif.float().to(device)
            motif_repr = model.refined_representation(protein_tokens, motif_embedding)
        else:
            raise ValueError("Motif must be either a string (sequence) or a torch.Tensor (binary representation)")

        # Create refined representations
        protein_repr = model.refined_representation(protein_tokens, torch.zeros_like(protein_embedding[:, :, 0]))

        # Process through ESM encoder-decoder
        protein_latent = model.esm_encoder_decoder.encoder(protein_embedding)
        motif_latent = model.esm_encoder_decoder.encoder(motif_embedding)

        # CLIP processing
        protein_clip, motif_clip = model.clip_model(protein_repr, motif_repr)

        # Sample from the model
        latent_samples = model.sample(
            num_samples,
            protein_latent.shape[1],
            _match_batch(protein_latent),
            _match_batch(protein_clip),
            _match_batch(motif_latent),
            guidance_scale,
        )

        # Decode the latent samples the same way RefinedESMEncoderDecoder.forward
        # does: the latent is both target and memory of the decoder, followed by
        # the output layer. (Calling the decoder with a single argument fails.)
        codec = model.esm_encoder_decoder
        generated_sequences = codec.final_layer(codec.decoder(latent_samples, latent_samples))

    return generated_sequences

def generate_protein_binders_without_guidance(model, sequence_length, num_samples=1):
    """Sample binder latents without any conditioning and decode them to residue logits.

    Args:
        model: Object exposing ``sample_without_guidance(num_samples, sequence_length)``
            and an ``esm_encoder_decoder`` with ``decoder`` and ``final_layer``.
        sequence_length: Length of each generated binder.
        num_samples: Number of binders to sample.

    Returns:
        Residue logits of shape ``(num_samples, sequence_length, 21)``.
    """
    with torch.no_grad():
        # Sample from the model without guidance
        latent_samples = model.sample_without_guidance(num_samples, sequence_length)

        # Decode the way RefinedESMEncoderDecoder.forward does: the latent is both
        # target and memory of the transformer decoder, followed by the output
        # layer. (nn.TransformerDecoder cannot be called with a single argument.)
        codec = model.esm_encoder_decoder
        generated_sequences = codec.final_layer(codec.decoder(latent_samples, latent_samples))

    return generated_sequences



def train(model, dataloader, optimizer, num_epochs, device):
    """Run the training loop: per batch, build conditioning, compute losses, step the optimizer.

    Args:
        model: Model exposing ``esm_model``, ``refined_representation``,
            ``esm_encoder_decoder``, ``clip_model``, ``num_steps`` and ``p_losses``.
        dataloader: Yields ``(energy_scores, protein_seq, peptide_seq)`` batches.
        optimizer: Optimizer over the model's parameters.
        num_epochs: Number of passes over ``dataloader``.
        device: Device the energy scores and timesteps are placed on.

    NOTE(review): several calls below look inconsistent with the definitions in
    this file and should be confirmed before this loop is run; each is flagged
    inline.
    """
    for epoch in range(num_epochs):
        for batch in dataloader:
            energy_scores, protein_seq, peptide_seq = batch
            energy_scores = energy_scores.to(device)

            # Generate ESM embeddings
            # NOTE(review): protein_seq / peptide_seq come straight from the dataset,
            # which returns the raw column values — presumably strings that still
            # need tokenising before being given to the ESM model; confirm.
            with torch.no_grad():
                protein_embedding = model.esm_model(protein_seq, repr_layers=[33], return_contacts=False)["representations"][33]
                peptide_embedding = model.esm_model(peptide_seq, repr_layers=[33], return_contacts=False)["representations"][33]

            # Create refined representations
            # NOTE(review): refined_representation applies F.one_hot to its first
            # argument, which requires an integer tensor — same tokenisation question.
            protein_repr = model.refined_representation(protein_seq, energy_scores)
            peptide_repr = model.refined_representation(peptide_seq, torch.zeros_like(energy_scores))

            # Process through ESM encoder-decoder
            # Only the encoder half is used here (no tanh, unlike RefinedESMEncoderDecoder.forward).
            protein_latent = model.esm_encoder_decoder.encoder(protein_embedding)
            peptide_latent = model.esm_encoder_decoder.encoder(peptide_embedding)

            # Motif embedding
            # NOTE(review): the dataset already binarises energy_scores to 0/1, so
            # `<= -1` would always be False here — confirm the intended threshold.
            motif_emb = (energy_scores <= -1).float().unsqueeze(-1)

            # Create positive and negative pairs for CLIP
            positive_pairs = (protein_repr, peptide_repr)
            negative_pairs = create_negative_pairs(protein_repr, peptide_repr)

            # CLIP processing
            # NOTE(review): CLIPModel.forward takes two embedding tensors and returns
            # a pair of normalised embeddings, not a scalar loss; here it receives two
            # tuples and its result is used as a loss below — confirm the intended API.
            clip_loss = model.clip_model(positive_pairs, negative_pairs)

            # Diffusion process
            # One random timestep per batch element.
            t = torch.randint(0, model.num_steps, (protein_embedding.shape[0],), device=device).long()
            # NOTE(review): `protein_clip` is not assigned anywhere in this function.
            diff_loss = model.p_losses(peptide_latent, protein_latent, protein_clip, motif_emb, t, peptide_seq)

            # Total loss
            loss = diff_loss + clip_loss

            # Backpropagation and optimization steps
            # Gradients are cleared after the step, i.e. ready for the next batch.
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Logging
            wandb.log({"batch_loss": loss.item(), "diff_loss": diff_loss.item(), "clip_loss": clip_loss.item()})

        # Reports the loss of the last batch of the epoch, not an epoch average.
        print(f"Epoch {epoch+1}, Loss: {loss.item()}")

def create_negative_pairs(protein_repr, peptide_repr):
    """Pair each protein with a randomly re-ordered peptide batch.

    Args:
        protein_repr: Protein representations; returned untouched.
        peptide_repr: Peptide representations; permuted along dim 0.

    Returns:
        ``(protein_repr, permuted_peptide_repr)``. The permutation is random,
        so an entry may occasionally stay paired with its original partner.
    """
    batch_size = peptide_repr.size(0)
    shuffled_order = torch.randperm(batch_size)
    return protein_repr, peptide_repr[shuffled_order]

def main():
    """Load the datasets, train the latent diffusion model, and print sample generations."""
    # Load and preprocess data
    train_snp = preprocess_snp_data('training_dataset.csv')
    val_snp = preprocess_snp_data('validation_dataset.csv')
    test_snp = preprocess_snp_data('testing_dataset.csv')

    train_snp = filter_datasets(train_snp)
    val_snp = filter_datasets(val_snp)
    test_snp = filter_datasets(test_snp)

    # Create datasets
    train_dataset = ProteinInteractionDataset(train_snp)
    val_dataset = ProteinInteractionDataset(val_snp)
    test_dataset = ProteinInteractionDataset(test_snp)

    # Create dataloaders (batch_size=1, so variable-length sequences need no padding).
    # The validation/test loaders are built but not consumed in this function yet.
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)

    # Initialize LatentDiffusion model
    latent_dim = 64
    # Both train() and generate_protein_binders() hand the denoiser protein
    # *latents* (encoder outputs, latent_dim wide), so the denoiser's protein
    # width has to be latent_dim rather than the raw ESM embedding width.
    protein_dim = latent_dim
    # Was previously undefined (NameError); 128 matches CLIPModel's default embed_dim.
    clip_dim = 128
    num_steps = 1000
    model = RefinedLatentDiffusion(esm_model, num_steps, latent_dim, protein_dim, clip_dim, device)
    model.to(device)

    # Training
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    num_epochs = 10

    # train() calls wandb.log, which requires an active run; start one unless
    # the caller already did.
    if wandb.run is None:
        wandb.init()

    train(model, train_loader, optimizer, num_epochs, device)

    # Generation (dummy examples)
    model.eval()

    # Example usage of generation with guidance (sequence-based motif)
    protein_seq = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"
    motif_seq = "LRSLGY"
    generated_binders_seq = generate_protein_binders(model, protein_seq, motif_seq, num_samples=3, guidance_scale=3.0)

    print("Generated protein binders with guidance (sequence-based motif):")
    for i, seq in enumerate(generated_binders_seq):
        print(f"Binder {i+1}: {seq}")

    # Example usage of generation with guidance (binary motif)
    binary_motif = torch.zeros(len(protein_seq))
    binary_motif[30:36] = 1  # Assuming the motif is in this region
    generated_binders_binary = generate_protein_binders(model, protein_seq, binary_motif, num_samples=3, guidance_scale=3.0)

    print("\nGenerated protein binders with guidance (binary motif):")
    for i, seq in enumerate(generated_binders_binary):
        print(f"Binder {i+1}: {seq}")

    # Example usage of generation without guidance
    generated_binders_no_guidance = generate_protein_binders_without_guidance(model, sequence_length=100, num_samples=3)

    print("\nGenerated protein binders without guidance:")
    for i, seq in enumerate(generated_binders_no_guidance):
        print(f"Binder {i+1}: {seq}")

if __name__ == "__main__":
    main()