## pip installs

In [1]:
!pip install Levenshtein
!pip install einops
!pip install einops_exts
!pip install torch
!pip install transformers
!pip install tqdm
!pip install sentencepiece
!pip install black
!pip install fair-esm
!pip install wandb

Collecting Levenshtein
  Downloading Levenshtein-0.25.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.3 kB)
Collecting rapidfuzz<4.0.0,>=3.8.0 (from Levenshtein)
  Downloading rapidfuzz-3.9.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading Levenshtein-0.25.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (177 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m177.4/177.4 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading rapidfuzz-3.9.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m29.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz, Levenshtein
Successfully installed Levenshtein-0.25.1 rapidfuzz-3.9.6
Collecting einops_exts
  Downloading einops_exts-0.0.4-py3-none-any.whl.metadata (621 bytes)
Downloading einops_exts-0.0.4-py3-none-any.whl (3.9 kB)
In

## dataset

In [43]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import re
import esm
from einops import rearrange, repeat
import math
import numpy as np
from torch import einsum
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import wandb
wandb.login()
import os
os.chdir('/content/drive/MyDrive/Programmable Biology Group/Srikar/Code/proteins/flamingo-diffusion/data_dump/old_dat/')

# ESM Model Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
esm_model = esm_model.to(device)
esm_model.eval()
for param in esm_model.parameters():
    param.requires_grad = False

# Data Preprocessing
def preprocess_snp_data(file_path):
    snp_df = pd.read_csv(file_path)

    def transform_energy_scores(energy_scores):
        transformed_scores = []
        for score in energy_scores:
            score = re.sub(r'[\s\n]+', ',', score)
            score = re.sub(r'\[\s*,', '[', score)
            score = re.sub(r'^[\s,]+', '', score)
            transformed_scores.append(score)
        return transformed_scores

    snp_df['energy_scores'] = transform_energy_scores(snp_df['energy_scores'])
    snp_df['energy_scores_lengths'] = snp_df['energy_scores'].apply(
        lambda x: x.count(',') + 1 - (1 if x.startswith(',') else 0)
    )

    snp_df['peptide_source_RCSB_lengths'] = snp_df['peptide_source_RCSB'].apply(len)
    snp_df['protein_RCSB_lengths'] = snp_df['protein_RCSB'].apply(len)
    snp_df['protein_derived_seq_length'] = snp_df['protein_derived_sequence'].apply(len)
    snp_df['peptide_derived_seq_length'] = snp_df['peptide_derived_sequence'].apply(len)

    return snp_df

def filter_datasets(dataset):
    return dataset[dataset['protein_RCSB'] != dataset['peptide_source_RCSB']]

# Dataset Class
class ProteinInteractionDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe
        self.mismatched_lengths = 0
        self.total_samples = len(dataframe)
        self.check_lengths()

    def check_lengths(self):
        for idx in range(self.total_samples):
            row = self.dataframe.iloc[idx]
            peptide_seq = row['peptide_derived_sequence']
            energy_scores = row['energy_scores']

            energy_scores = re.findall(r'-?\d+\.?\d*(?:e[-+]?\d+)?', energy_scores)
            energy_scores = [float(score) for score in energy_scores]

            if len(energy_scores) != len(peptide_seq):
                self.mismatched_lengths += 1

        print(f"Total samples: {self.total_samples}")
        print(f"Mismatched lengths: {self.mismatched_lengths}")

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        peptide_seq = row['peptide_derived_sequence']
        protein_seq = row['protein_derived_sequence']
        energy_scores = row['energy_scores']

        energy_scores = re.findall(r'-?\d+\.?\d*(?:e[-+]?\d+)?', energy_scores)
        energy_scores = [float(score) for score in energy_scores]
        energy_scores = self.one_hot_encode_energy_scores(energy_scores)

        # Convert energy scores to tensor
        energy_scores = torch.tensor(energy_scores, dtype=torch.float32)

        return energy_scores, peptide_seq, protein_seq # energy scores are aligned with the peptide (we will keep peptide as protien)

    @staticmethod
    def one_hot_encode_energy_scores(scores):
        return [1 if score <= -1 else 0 for score in scores]





Using device: cuda


In [44]:
# Load and preprocess data
train_snp = preprocess_snp_data('training_dataset.csv')
val_snp = preprocess_snp_data('validation_dataset.csv')
test_snp = preprocess_snp_data('testing_dataset.csv')

train_snp = filter_datasets(train_snp)
val_snp = filter_datasets(val_snp)
test_snp = filter_datasets(test_snp)

train_snp, val_snp, test_snp = train_snp[:16], val_snp[16:24], test_snp[24:32] # subset code

# Calculate max_length
all_seqs = pd.concat([
    train_snp['peptide_derived_sequence'], train_snp['protein_derived_sequence'],
    val_snp['peptide_derived_sequence'], val_snp['protein_derived_sequence'],
    test_snp['peptide_derived_sequence'], test_snp['protein_derived_sequence']
])
max_length = max(len(seq) for seq in all_seqs)

# Create datasets
train_dataset = ProteinInteractionDataset(train_snp)
val_dataset = ProteinInteractionDataset(val_snp)
test_dataset = ProteinInteractionDataset(test_snp)


Total samples: 16
Mismatched lengths: 0
Total samples: 8
Mismatched lengths: 0
Total samples: 8
Mismatched lengths: 0


In [46]:
for i in range(5):  # Adjust range to view more samples
    energy_scores, protein_seq, peptide_seq = train_dataset[i]
    print(f"Sample {i}:")

    # Print energy scores and their length
    print(f"Energy Scores: {energy_scores}")
    print(f"Length of Energy Scores: {energy_scores.shape[0]}")

    # Print protein sequence and its length
    print(f"Protein Sequence: {protein_seq}")
    print(f"Length of Protein Sequence: {len(protein_seq)}")

    # Print peptide sequence and its length
    print(f"Peptide Sequence: {peptide_seq}")
    print(f"Length of Peptide Sequence: {len(peptide_seq)}")

    print("\n")


Sample 0:
Energy Scores: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.,

## model

In [None]:
class RefinedESMEncoderDecoder(nn.Module):
    def __init__(self, esm_dim=1280, latent_dim=128):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(esm_dim, 512),
            nn.ReLU(),
            nn.Linear(512, latent_dim)
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=latent_dim, nhead=8),
            num_layers=3
        )
        self.final_layer = nn.Linear(latent_dim, len(esm_model.alphabet))

    def encode(self, x):
        return torch.tanh(self.encoder(x))

    def decode(self, latent):
        decoded = self.decoder(latent, latent)
        return self.final_layer(decoded)

    def forward(self, x):
        latent = self.encode(x)
        return self.decode(latent)

class RefinedDenoiser(nn.Module):
    def __init__(self, latent_dim, protein_dim, alphabet_size):
        super().__init__()
        self.protein_binder_transformer = nn.TransformerEncoderLayer(d_model=latent_dim, nhead=8)
        self.target_protein_transformer = nn.TransformerEncoderLayer(d_model=latent_dim, nhead=8)
        self.cross_attention = nn.MultiheadAttention(embed_dim=latent_dim, num_heads=8)
        self.onehot_projection = nn.Linear(alphabet_size, latent_dim)
        self.final_layer = nn.Linear(latent_dim, latent_dim)

    def forward(self, x, protein_emb, motif_emb, onehot_seq, t, use_classifier_free=False):
        if use_classifier_free:
            protein_emb = torch.zeros_like(protein_emb)
            motif_emb = torch.zeros_like(motif_emb)
            onehot_seq = torch.zeros_like(onehot_seq)

        seq_len = x.size(1)
        protein_emb = F.interpolate(protein_emb.transpose(1, 2), size=seq_len, mode='linear', align_corners=False).transpose(1, 2)

        x = self.protein_binder_transformer(x)
        protein_emb = self.target_protein_transformer(protein_emb)

        x = x.permute(1, 0, 2)
        protein_emb = protein_emb.permute(1, 0, 2)

        x, _ = self.cross_attention(x, protein_emb, protein_emb)
        x = x.permute(1, 0, 2)

        onehot_proj = self.onehot_projection(onehot_seq)
        x = x + motif_emb + onehot_proj
        return self.final_layer(x)

class SimpleLatentDiffusion(nn.Module):
    def __init__(self, esm_model, num_steps, latent_dim, protein_dim, device):
        super().__init__()
        self.esm_model = esm_model
        self.num_steps = num_steps
        self.latent_dim = latent_dim
        self.protein_dim = protein_dim
        self.device = device
        self.alphabet_size = len(esm_model.alphabet)

        self.esm_encoder_decoder = RefinedESMEncoderDecoder(esm_dim=1280, latent_dim=latent_dim)
        self.denoiser = RefinedDenoiser(latent_dim=latent_dim, protein_dim=protein_dim, alphabet_size=self.alphabet_size)

        self.beta = torch.linspace(1e-4, 0.02, num_steps).to(device)
        self.alpha = 1 - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.sqrt_alpha_bar = torch.sqrt(self.alpha_bar)
        self.sqrt_one_minus_alpha_bar = torch.sqrt(1 - self.alpha_bar)

    def q_sample(self, x0, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x0).to(self.device)
        return (
            self.sqrt_alpha_bar[t, None, None] * x0 +
            self.sqrt_one_minus_alpha_bar[t, None, None] * noise
        )

    def p_losses(self, peptide_latent, protein_emb, motif_emb, onehot_seq, t, target_seq, noise=None):
        if noise is None:
            noise = torch.randn_like(peptide_latent).to(self.device)

        x_noisy = self.q_sample(peptide_latent, t, noise=noise)

        batch_size, seq_len, _ = x_noisy.shape
        protein_emb = F.interpolate(protein_emb.transpose(1, 2), size=seq_len, mode='linear', align_corners=False).transpose(1, 2)
        motif_emb = motif_emb.expand(-1, seq_len, -1)
        onehot_seq = F.interpolate(onehot_seq.float().transpose(1, 2), size=seq_len, mode='linear', align_corners=False).transpose(1, 2)

        use_classifier_free = random.random() < 0.1
        predicted_noise = self.denoiser(x_noisy, protein_emb, motif_emb, onehot_seq, t, use_classifier_free)

        loss = F.mse_loss(predicted_noise, noise)

        decoded_seq = self.esm_encoder_decoder.decode(peptide_latent)
        ce_loss = F.cross_entropy(decoded_seq.view(-1, decoded_seq.size(-1)), target_seq.view(-1))

        total_loss = loss + ce_loss
        return total_loss

def pad_or_truncate(tensor, target_length, pad_value=0):
    current_length = tensor.size(1)
    if current_length < target_length:
        padding = torch.full((tensor.size(0), target_length - current_length, *tensor.size()[2:]), pad_value, device=tensor.device)
        return torch.cat([tensor, padding], dim=1)
    else:
        return tensor[:, :target_length]

def train(model, train_loader, val_loader, optimizer, num_epochs, device):
    wandb.init(project="simple_latent_diffusion", entity="vskavi2003")
    wandb.config.update({
        "learning_rate": optimizer.param_groups[0]['lr'],
        "epochs": num_epochs,
        "batch_size": train_loader.batch_size
    })

    for epoch in range(num_epochs):
        model.train()
        total_train_loss = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            energy_scores, protein_seq, peptide_seq = batch
            energy_scores = energy_scores.to(device)
            padded_energy_scores = F.pad(energy_scores, (1, 1), value=0)

            batch_converter = model.esm_model.alphabet.get_batch_converter()
            _, _, protein_tokens = batch_converter([(0, protein_seq[0])])
            _, _, peptide_tokens = batch_converter([(0, peptide_seq[0])])

            max_seq_len = max(protein_tokens.size(1), peptide_tokens.size(1), padded_energy_scores.size(1))
            protein_tokens = pad_or_truncate(protein_tokens, max_seq_len, pad_value=model.esm_model.alphabet.padding_idx)
            peptide_tokens = pad_or_truncate(peptide_tokens, max_seq_len, pad_value=model.esm_model.alphabet.padding_idx)
            padded_energy_scores = pad_or_truncate(padded_energy_scores, max_seq_len)

            protein_tokens = protein_tokens.to(device)
            peptide_tokens = peptide_tokens.to(device)
            protein_onehot = F.one_hot(protein_tokens, num_classes=len(model.esm_model.alphabet)).float()

            with torch.no_grad():
                protein_embedding = model.esm_model(protein_tokens, repr_layers=[33], return_contacts=False)["representations"][33]
                peptide_embedding = model.esm_model(peptide_tokens, repr_layers=[33], return_contacts=False)["representations"][33]

            protein_latent = model.esm_encoder_decoder.encoder(protein_embedding)
            peptide_latent = model.esm_encoder_decoder.encoder(peptide_embedding)
            motif_emb = (padded_energy_scores <= -1).float().unsqueeze(-1)

            t = torch.randint(0, model.num_steps, (protein_embedding.shape[0],), device=device).long()
            loss = model.p_losses(peptide_latent, protein_latent, motif_emb, protein_onehot, t, peptide_tokens)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_loader)
        val_loss = validate(model, val_loader, device)

        wandb.log({
            "epoch": epoch+1,
            "train_loss": avg_train_loss,
            "val_loss": val_loss
        })

        print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}")

def validate(model, dataloader, device):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for batch in dataloader:
            energy_scores, protein_seq, peptide_seq = batch
            energy_scores = energy_scores.to(device)
            padded_energy_scores = F.pad(energy_scores, (1, 1), value=0)

            batch_converter = model.esm_model.alphabet.get_batch_converter()
            _, _, protein_tokens = batch_converter([(0, protein_seq[0])])
            _, _, peptide_tokens = batch_converter([(0, peptide_seq[0])])

            max_seq_len = max(protein_tokens.size(1), peptide_tokens.size(1), padded_energy_scores.size(1))
            protein_tokens = pad_or_truncate(protein_tokens, max_seq_len, pad_value=model.esm_model.alphabet.padding_idx)
            peptide_tokens = pad_or_truncate(peptide_tokens, max_seq_len, pad_value=model.esm_model.alphabet.padding_idx)
            padded_energy_scores = pad_or_truncate(padded_energy_scores, max_seq_len)

            protein_tokens = protein_tokens.to(device)
            peptide_tokens = peptide_tokens.to(device)
            protein_onehot = F.one_hot(protein_tokens, num_classes=len(model.esm_model.alphabet)).float()

            protein_embedding = model.esm_model(protein_tokens, repr_layers=[33], return_contacts=False)["representations"][33]
            peptide_embedding = model.esm_model(peptide_tokens, repr_layers=[33], return_contacts=False)["representations"][33]

            protein_latent = model.esm_encoder_decoder.encoder(protein_embedding)
            peptide_latent = model.esm_encoder_decoder.encoder(peptide_embedding)
            motif_emb = (padded_energy_scores <= -1).float().unsqueeze(-1)

            t = torch.randint(0, model.num_steps, (protein_embedding.shape[0],), device=device).long()
            loss = model.p_losses(peptide_latent, protein_latent, motif_emb, protein_onehot, t, peptide_tokens)

            total_loss += loss.item()

    return total_loss / len(dataloader)

def plot_losses(train_losses, val_losses):
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss')
    plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Losses')
    plt.legend()
    plt.savefig('loss_plot.png')
    wandb.log({"loss_plot": wandb.Image('loss_plot.png')})

def generate_protein_binders(model, protein_seq, motif, num_samples=1, guidance_scale=3.0):
    device = next(model.parameters()).device

    # Generate protein embedding
    with torch.no_grad():
        protein_tokens = model.esm_model.encode(protein_seq)
        protein_embedding = model.esm_model(protein_tokens.to(device), repr_layers=[33], return_contacts=False)["representations"][33]

    # Process motif based on input type
    if isinstance(motif, str):
        # Motif is a sequence
        motif_tokens = model.esm_model.encode(motif)
        with torch.no_grad():
            motif_embedding = model.esm_model(motif_tokens.to(device), repr_layers=[33], return_contacts=False)["representations"][33]
        motif_repr = model.refined_representation(motif_tokens, torch.ones_like(motif_embedding[:, :, 0]))
    elif isinstance(motif, torch.Tensor):
        # Motif is a binary tensor
        motif_embedding = motif.float().to(device)
        motif_repr = model.refined_representation(protein_tokens, motif_embedding)
    else:
        raise ValueError("Motif must be either a string (sequence) or a torch.Tensor (binary representation)")

    # Create refined representations
    protein_repr = model.refined_representation(protein_tokens, torch.zeros_like(protein_embedding[:, :, 0]))

    # Process through ESM encoder-decoder
    protein_latent = model.esm_encoder_decoder.encoder(protein_embedding)
    motif_latent = model.esm_encoder_decoder.encoder(motif_embedding)

    # Sample from the model
    latent_samples = model.sample(num_samples, protein_latent.shape[1], protein_latent, motif_latent, guidance_scale)

    # Decode the latent samples to amino acid sequences
    generated_sequences = model.esm_encoder_decoder.decoder(latent_samples)

    return generated_sequences

def generate_protein_binders_without_guidance(model, sequence_length, num_samples=1):
    # Sample from the model without guidance
    latent_samples = model.sample_without_guidance(num_samples, sequence_length)

    # Decode the latent samples to amino acid sequences
    generated_sequences = model.esm_encoder_decoder.decoder(latent_samples)

    return generated_sequences


def main():
    # Load and preprocess data
    train_snp = preprocess_snp_data('training_dataset.csv')
    val_snp = preprocess_snp_data('validation_dataset.csv')
    test_snp = preprocess_snp_data('testing_dataset.csv')

    train_snp = filter_datasets(train_snp)
    val_snp = filter_datasets(val_snp)
    test_snp = filter_datasets(test_snp)

    # train_snp, val_snp, test_snp = train_snp[:16], val_snp[16:24], test_snp[24:32] # subset code

    # Calculate max_length
    all_seqs = pd.concat([
        train_snp['peptide_derived_sequence'], train_snp['protein_derived_sequence'],
        val_snp['peptide_derived_sequence'], val_snp['protein_derived_sequence'],
        test_snp['peptide_derived_sequence'], test_snp['protein_derived_sequence']
    ])
    max_length = max(len(seq) for seq in all_seqs)

    # Create datasets
    train_dataset = ProteinInteractionDataset(train_snp)
    val_dataset = ProteinInteractionDataset(val_snp)
    test_dataset = ProteinInteractionDataset(test_snp)

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)

    # Initialize SimpleLatentDiffusion model
    latent_dim = 128
    protein_dim = esm_model.embed_dim
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_steps = 1000

    latent_diffusion_model = SimpleLatentDiffusion(esm_model, num_steps, latent_dim, protein_dim, device).to(device)

    # Train SimpleLatentDiffusion model
    optimizer = torch.optim.AdamW(latent_diffusion_model.parameters(), lr=1e-4)
    train(latent_diffusion_model, train_loader, val_loader, optimizer, num_epochs=10, device=device)

    # Save the trained model
    torch.save(latent_diffusion_model.state_dict(), 'latent_diffusion_model.pth')

    # Generation (dummy examples)
    latent_diffusion_model.eval()

    # Example usage of generation with guidance (sequence-based motif)
    protein_seq = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"
    motif_seq = "LRSLGY"
    generated_binders_seq = generate_protein_binders(latent_diffusion_model, protein_seq, motif_seq, num_samples=3, guidance_scale=3.0)

    print("Generated protein binders with guidance (sequence-based motif):")
    for i, seq in enumerate(generated_binders_seq):
        print(f"Binder {i+1}: {seq}")

    # Example usage of generation with guidance (binary motif)
    binary_motif = torch.zeros(len(protein_seq))
    binary_motif[30:36] = 1  # Assuming the motif is in this region
    generated_binders_binary = generate_protein_binders(latent_diffusion_model, protein_seq, binary_motif, num_samples=3, guidance_scale=3.0)

    print("\nGenerated protein binders with guidance (binary motif):")
    for i, seq in enumerate(generated_binders_binary):
        print(f"Binder {i+1}: {seq}")

    # Example usage of generation without guidance
    generated_binders_no_guidance = generate_protein_binders_without_guidance(latent_diffusion_model, sequence_length=100, num_samples=3)

    print("\nGenerated protein binders without guidance:")
    for i, seq in enumerate(generated_binders_no_guidance):
        print(f"Binder {i+1}: {seq}")

if __name__ == "__main__":
    main()

Total samples: 12353
Mismatched lengths: 0
Total samples: 2390
Mismatched lengths: 0
Total samples: 2782
Mismatched lengths: 0




VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▂▃▃▄▅▆▆▇█
train_loss,█▅▃▂▂▁▁▁▁▁
val_loss,█▅▃▂▂▂▁▁▁▁

0,1
epoch,10.0
train_loss,1.02998
val_loss,1.12116


Epoch 1/10:   0%|          | 60/12353 [00:20<1:18:52,  2.60it/s]