### morgan fingerprints

In [1]:
!pip install rdkit



In [2]:
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
from sklearn.cluster import KMeans
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances

# Load the combined dataset from disk
df_combined = pd.read_csv('sigma_data.csv')

# Remove cluster assignments left over from an earlier run (either spelling);
# errors='ignore' skips whichever of the two columns is not present
df_combined = df_combined.drop(columns=['cluster', 'Cluster'], errors='ignore')

# Morgan fingerprint of a single SMILES string, as a numpy array
def smiles_to_morgan_fp(smiles, radius=2, nBits=2048):
    """Convert one SMILES string into a Morgan fingerprint bit array.

    Args:
        smiles: SMILES string handed to ``Chem.MolFromSmiles``.
        radius: Morgan radius forwarded to RDKit.
        nBits: Length of the fingerprint bit vector.

    Returns:
        ``np.array`` built from the RDKit bit vector, or ``None`` when
        ``Chem.MolFromSmiles`` returns ``None`` for the input.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is not None:
        bit_vector = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=nBits)
        return np.array(bit_vector)

    # No molecule could be built from this SMILES
    return None

# Convert SMILES to fingerprints
# fingerprints[k] is the fingerprint of the row at position valid_indices[k] of
# df_combined; rows whose SMILES could not be converted appear in neither list.
fingerprints = []
valid_indices = []

for i, smiles in enumerate(df_combined['smiles']):
    try:
        fp = smiles_to_morgan_fp(smiles)
        # None means no molecule could be built from this SMILES; skip the row
        if fp is not None:
            fingerprints.append(fp)
            valid_indices.append(i)
    except Exception as e:
        # Best effort: report the offending SMILES and carry on with the next row
        print(f"Error processing SMILES {smiles}: {e}")

# Create a numpy array from the fingerprints
# (one row per successfully converted SMILES, in valid_indices order)
X = np.array(fingerprints)

# Increase the number of clusters to create more granular groups
# Capped at len(X) so KMeans is never asked for more clusters than samples.
# NOTE(review): the recorded run below reports only 23 distinct clusters
# (17 train + 6 val), presumably because many rows share the same SMILES and
# therefore the same fingerprint -- confirm.
n_clusters = min(40, len(X))  # Increased from 20 to 40 for more fine-grained clusters

# Perform KMeans clustering
# random_state is fixed so cluster assignments are reproducible between runs
kmeans = KMeans(n_clusters=n_clusters, random_state=42)
cluster_labels = kmeans.fit_predict(X)

# Add cluster labels to the original dataframe
# Only rows with a valid fingerprint are kept; cluster_labels is aligned with
# valid_indices, so positional selection with iloc keeps the two in step.
df_valid = df_combined.iloc[valid_indices].copy()
df_valid['cluster'] = cluster_labels

# Calculate cluster centers
# Row k of cluster_centers is the centroid of the cluster labelled k
cluster_centers = kmeans.cluster_centers_

# Cluster-level train/validation split that greedily moves far-apart clusters into validation
def split_by_cluster_dissimilarity(df, cluster_centers, val_size=0.3):
    """Split ``df`` into train/validation frames along cluster boundaries.

    One cluster is drawn at random (global numpy seed fixed to 42) to start the
    validation set. Clusters are then added greedily, each time taking the
    remaining cluster whose mean center-to-center Euclidean distance to the
    already selected validation clusters is largest, until
    ``int(n_clusters * val_size)`` clusters have been selected. The starting
    cluster is always part of the validation set, even when that count is 0.

    Args:
        df: DataFrame with a ``cluster`` column whose values index rows of
            ``cluster_centers``.
        cluster_centers: Array of cluster centroids, one row per cluster label.
        val_size: Fraction of the distinct clusters to assign to validation.

    Returns:
        ``(train_df, val_df)``; no cluster label occurs in both frames.
    """
    unique_clusters = df['cluster'].unique()
    n_val = int(len(unique_clusters) * val_size)

    # Pairwise Euclidean distances between all cluster centers
    distances = euclidean_distances(cluster_centers)

    # Reproducible random starting cluster (this reseeds numpy's global RNG)
    np.random.seed(42)
    seed_cluster = np.random.choice(unique_clusters)
    val_clusters = [seed_cluster]
    remaining_clusters = set(unique_clusters) - {seed_cluster}

    while len(val_clusters) < n_val:
        # Greedy step: take the remaining cluster that is, on average, farthest
        # from everything already placed in the validation set
        farthest = max(
            remaining_clusters,
            key=lambda candidate: np.mean(
                [distances[candidate, chosen] for chosen in val_clusters]
            ),
        )
        val_clusters.append(farthest)
        remaining_clusters.remove(farthest)

    # Every cluster that was never selected stays in the training set
    train_clusters = list(remaining_clusters)

    cluster_column = df['cluster']
    train_df = df[cluster_column.isin(train_clusters)]
    val_df = df[cluster_column.isin(val_clusters)]
    return train_df, val_df

# Perform the split
# Whole clusters go to one side or the other, so train and validation never
# share a cluster label.
train_df, val_df = split_by_cluster_dissimilarity(df_valid, cluster_centers)

print(f"Training set: {len(train_df)} samples, {train_df['cluster'].nunique()} clusters")
print(f"Validation set: {len(val_df)} samples, {val_df['cluster'].nunique()} clusters")

# Verify there's no cluster overlap between splits
train_clusters = set(train_df['cluster'].unique())
val_clusters = set(val_df['cluster'].unique())

# Expected to print 0 because the split is made cluster by cluster
print(f"Cluster overlap between train and val: {len(train_clusters.intersection(val_clusters))}")

# Save the splits
# index=False: the dataframe index is not written out as a column
train_df.to_csv('train_data.csv', index=False)
val_df.to_csv('val_data.csv', index=False)

# Optional: Analyze the chemical diversity of the train and val sets
# Map each original row position to its row in `fingerprints` once, so every
# lookup below is O(1) instead of a linear scan of valid_indices per sample.
# valid_indices holds unique positions, so the mapping is unambiguous.
fp_row_by_position = {position: row for row, position in enumerate(valid_indices)}
train_fps = np.array([fingerprints[fp_row_by_position[i]] for i in train_df.index if i in fp_row_by_position])
val_fps = np.array([fingerprints[fp_row_by_position[i]] for i in val_df.index if i in fp_row_by_position])

# Calculate average Tanimoto similarity within each set (lower means more diverse)
def calculate_avg_tanimoto(fps):
    """Return the mean pairwise Tanimoto similarity of a set of fingerprints.

    For every unordered pair of fingerprints the similarity is
    ``sum(a & b) / sum(a | b)``; pairs whose union is empty are skipped.
    Each fingerprint is compared against all later ones in a single
    vectorised numpy operation instead of a Python loop per pair.

    Args:
        fps: Sequence or 2-D array of integer/boolean bit vectors, one
            fingerprint per row.

    Returns:
        The average similarity over all counted pairs, or ``0`` when no pair
        could be counted (fewer than two fingerprints, or only empty unions).
    """
    fps = np.asarray(fps)
    if fps.ndim == 1:
        # Treat a flat sequence as one single-bit fingerprint per element
        fps = fps[:, None]

    sum_sim = 0
    count = 0
    for i in range(len(fps) - 1):
        later = fps[i + 1:]
        # Bitwise AND / OR against every later fingerprint at once
        intersection = np.sum(fps[i] & later, axis=1)
        union = np.sum(fps[i] | later, axis=1)
        valid = union > 0  # pairs of all-zero fingerprints have no defined similarity
        sum_sim += np.sum(intersection[valid] / union[valid])
        count += int(np.count_nonzero(valid))
    return sum_sim / count if count > 0 else 0

# Print diversity metrics
# Skipped unless both sets contain at least two fingerprints, since a pairwise
# average needs at least one pair per set.
if len(train_fps) > 1 and len(val_fps) > 1:
    train_diversity = calculate_avg_tanimoto(train_fps)
    val_diversity = calculate_avg_tanimoto(val_fps)
    print(f"Average train set similarity: {train_diversity:.4f}")
    print(f"Average validation set similarity: {val_diversity:.4f}")
    # Lower average similarity is reported as "more diverse"
    print(f"Validation set is {('more diverse' if val_diversity < train_diversity else 'less diverse')} than training set")

Streaming output truncated to the last 5000 lines.
  return fit_method(estimator, *args, **kwargs)


Training set: 7956 samples, 17 clusters
Validation set: 1071 samples, 6 clusters
Cluster overlap between train and val: 0
Average train set similarity: 0.4660
Average validation set similarity: 0.3580
Validation set is more diverse than training set


In [3]:
val_df

Unnamed: 0,Plastic Type,Enzyme Name,protein_sequence,smiles,protein_length,synthetic,cluster
102,PVA,PVA_dehydrogenase,MQQNIERNQVSMTTSRFVWGAVMALVALGSASAAELNLPDGAALYR...,CCO,639,,5
103,PVA,PVA_dehydrogenase,MQQNIERNVVSMTTSRFVAGAVMALVALGSASAAELNLPDGAALYQ...,CCO,639,True,5
104,PVA,PVA_dehydrogenase,MQQNKERNQVSRTTSRFVFGAVVALVALGSASAAELNLPDGEALYR...,CCO,639,True,5
105,PVA,PVA_dehydrogenase,MQQNTERNLVSRTTSRFVAGAVLALVALGSASAAEPPLPDGAALYR...,CCO,639,True,5
106,PVA,PVA_dehydrogenase,MSINIRRSDVSMTWSRIVAGAVIALVAAGSASAAELDLPDGAALYR...,CCO,639,True,5
...,...,...,...,...,...,...,...
9022,PE,Manganese_Peroxidase_Iz-MnP2,MRLIGSSLLSASLRLARQAPAAELAACPDGTRVSNSACCAFIPIAQ...,[*]CC[*],385,True,10
9023,PE,Manganese_Peroxidase_Iz-MnP2,MALHLSSLLSASLRLLVAAPAAETAVCPDGTRTSNSACCAFLPLAQ...,[*]CC[*],385,True,10
9024,PE,Manganese_Peroxidase_Iz-MnP2,MALLLSSLLSASPILSRAAPAARSAVCPDGQRVANPACCAFFPIAQ...,[*]CC[*],385,True,10
9025,PE,Manganese_Peroxidase_Iz-MnP2,MALHLSLLLSALARLVRTLSAANTAVCPDGTRVSNSACCAFFPVAQ...,[*]CC[*],385,True,10


In [4]:
train_df

Unnamed: 0,Plastic Type,Enzyme Name,protein_sequence,smiles,protein_length,synthetic,cluster
0,PCL,Cutinase,MKFFALTTLLAATASALPTSHPVQELEARQLGGGTTRNDLTNGNSA...,[*]OCCCCC(=O)[*],231,,7
1,PCL,Cutinase,MKFFALTTLLAATASALPTSAPVAELEARQLGAGTTRNDLTNGNSA...,[*]OCCCCC(=O)[*],231,True,7
2,PCL,Cutinase,MKFFALTTLLAATASALPTSIPVQELEARQLGGGTTRNDLTNGNSA...,[*]OCCCCC(=O)[*],231,True,7
3,PCL,Cutinase,MKFFAITTLLAATASALPTSHPVQELEARQLGGGTTRNDLTNGNSA...,[*]OCCCCC(=O)[*],231,True,7
4,PCL,Cutinase,MKFFALTTLLAATAAALPTSAPVVELEARQLGGGTTRNDLTNGNSA...,[*]OCCCCC(=O)[*],231,True,7
...,...,...,...,...,...,...,...
8920,PET,AAC_BTA_hydrolase,MNPYERGPNPTDAELSASSGPFSVSTENVSALSASGFGGGTIYYPA...,[*]OCCOC(=O)c1ccc(cc1)C(=O)O[*],261,True,1
8921,PET,AAC_BTA_hydrolase,MNPYERGPNPTQAADSASSGPFSVSSENVSTLSASGFGGGTIYYPR...,[*]OCCOC(=O)c1ccc(cc1)C(=O)O[*],261,True,1
8922,PET,AAC_BTA_hydrolase,MNPYERGPNPTDSALSASSGPFSVSRESVFGLSASGFGGGTIYYPI...,[*]OCCOC(=O)c1ccc(cc1)C(=O)O[*],261,True,1
8923,PET,AAC_BTA_hydrolase,MNPYERGPNPTDSSNDASVGPFSVSTENVSGLSASGFGGGTIYYPT...,[*]OCCOC(=O)c1ccc(cc1)C(=O)O[*],261,True,1


### setup

In [5]:
!pip install Levenshtein
!pip install einops
!pip install einops_exts
!pip install torch
!pip install transformers
!pip install tqdm
!pip install sentencepiece
# !pip install fair-esm



In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import pandas as pd
import re
import math
import json
from tqdm import tqdm
from einops import rearrange, repeat
# import esm

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# # Load ESM-2 model
# esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# batch_converter = alphabet.get_batch_converter()
# esm_model = esm_model.to(device)  # Move to GPU if available
# esm_model.eval()  # Set to evaluation mode


Using device: cuda


### data

In [7]:
# Mount Google Drive into the Colab runtime at /content/drive
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [8]:
def preprocess_snp_data(file_path):
    """Load a CSV and add string-length columns for SMILES and protein sequence.

    Args:
        file_path: Path of a CSV file with ``smiles`` and ``protein_sequence``
            columns.

    Returns:
        The loaded DataFrame with ``smiles_length`` and ``protein_length``
        columns added (an existing column of the same name is overwritten).
        Rows with a missing value get a missing (NaN) length, so they can be
        removed afterwards by ``filter_datasets`` instead of crashing here.
    """
    snp_df = pd.read_csv(file_path)

    # Basic preprocessing and length calculations.
    # .str.len() propagates missing values, whereas apply(len) raises a
    # TypeError on the float NaN that pandas uses for empty cells.
    snp_df['smiles_length'] = snp_df['smiles'].str.len()
    snp_df['protein_length'] = snp_df['protein_sequence'].str.len()

    return snp_df

def filter_datasets(dataset):
    """Keep only rows that have a usable SMILES string and protein sequence.

    A row is kept when both ``smiles`` and ``protein_sequence`` are present
    (not NA) and both ``smiles_length`` and ``protein_length`` are greater
    than zero.

    Args:
        dataset: DataFrame containing the four columns named above.

    Returns:
        The subset of ``dataset`` satisfying all four conditions.
    """
    has_smiles = dataset['smiles'].notna()
    has_protein = dataset['protein_sequence'].notna()
    nonempty_smiles = dataset['smiles_length'] > 0
    nonempty_protein = dataset['protein_length'] > 0

    keep = has_smiles & has_protein & nonempty_smiles & nonempty_protein
    return dataset[keep]

class ProteinGenerationDataset(Dataset):
    """Dataset yielding ``(smiles, protein_sequence)`` string pairs from a DataFrame.

    ``max_length`` is stored on the instance but is not applied when items are
    fetched; items are returned exactly as stored in the frame.
    """

    def __init__(self, dataframe, max_length):
        self.dataframe = dataframe
        self.max_length = max_length

    def __len__(self):
        # One sample per dataframe row
        return self.dataframe.shape[0]

    def __getitem__(self, idx):
        # Positional lookup, independent of the dataframe's index labels
        record = self.dataframe.iloc[idx]
        return record['smiles'], record['protein_sequence']

def collate_fn(batch, max_length=None):
    """
    Custom collate function to handle padding within batches.

    Proteins are padded with spaces (or truncated) to a common length: the
    longest protein in the batch, capped at ``max_length`` when a cap is known.

    Args:
        batch: List of tuples (smiles, protein)
        max_length: Optional cap on the padded protein length. When omitted,
            a module-level ``max_length`` variable is used if one exists
            (the behaviour this function originally relied on); if there is
            none, no cap is applied instead of raising a NameError.
    Returns:
        Dict with
            'smiles': list of SMILES strings (unpadded),
            'proteins': list of padded/truncated protein strings,
            'protein_masks': bool tensor, True on real residues and False on
                padding positions.
    """
    smiles, proteins = zip(*batch)

    # SMILES strings don't need padding as PolyBERT handles that internally
    smiles = list(smiles)

    if max_length is None:
        # Backward compatibility with callers that set a global max_length
        max_length = globals().get('max_length')

    # Get max length in this batch for proteins (not exceeding max_length)
    longest = max(len(p) for p in proteins)
    max_protein_len = longest if max_length is None else min(longest, max_length)

    # Pad (or truncate) proteins to the common length
    padded_proteins = []
    protein_masks = []

    for protein in proteins:
        kept = protein[:max_protein_len]
        n_pad = max_protein_len - len(kept)
        padded_proteins.append(kept + ' ' * n_pad)
        protein_masks.append([1] * len(kept) + [0] * n_pad)

    return {
        'smiles': smiles,
        'proteins': padded_proteins,
        'protein_masks': torch.tensor(protein_masks, dtype=torch.bool)
    }

### utils

In [9]:
# Model Components
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added along the first dimension.

    The table is stored as the buffer ``pe`` with shape
    ``[max_len, 1, d_model]``: even feature indices hold sines, odd feature
    indices hold cosines.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        positions = torch.arange(max_len).unsqueeze(1)
        inv_freq = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = positions * inv_freq

        table = torch.zeros(max_len, 1, d_model)
        table[:, 0, 0::2] = torch.sin(angles)
        table[:, 0, 1::2] = torch.cos(angles)
        self.register_buffer('pe', table)

    def forward(self, x):
        # The first x.size(0) table rows are added; the singleton middle
        # dimension of the table broadcasts over x's second dimension.
        seq_len = x.size(0)
        return x + self.pe[:seq_len]

class DoublePositionalEncoding(nn.Module):
    """Sinusoidal encoding of two positions per token.

    The first half of the feature dimension encodes the token's input
    position and the second half encodes its output (target) position. Both
    halves use the same sinusoidal table, stored as the buffers ``pe_input``
    and ``pe_output`` of shape ``[max_len, d_model // 2]``.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # Use the full embedding dimension divided into two halves
        self.d_model = d_model
        half_dim = d_model // 2

        # Create position encodings for both input and output positions
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, half_dim, 2) * (-math.log(10000.0) / half_dim))

        # Input position encodings
        pe_input = torch.zeros(max_len, half_dim)
        pe_input[:, 0::2] = torch.sin(position * div_term)
        pe_input[:, 1::2] = torch.cos(position * div_term)

        # Output position encodings
        pe_output = torch.zeros(max_len, half_dim)
        pe_output[:, 0::2] = torch.sin(position * div_term)
        pe_output[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe_input', pe_input)
        self.register_buffer('pe_output', pe_output)

    def forward(self, x, input_positions=None, output_positions=None):
        """Add the double positional encoding to ``x``.

        Args:
            x: Tensor of shape ``[batch_size, seq_length, d_model]``.
            input_positions: Optional integer positions indexed as
                ``[batch, token]``; ``None`` means token ``t`` has position ``t``.
            output_positions: Same as ``input_positions`` for the second half.

        Returns:
            ``x`` plus the encoding. A token whose input or output position is
            not below ``max_len`` receives an all-zero encoding.
        """
        batch_size, seq_length, _ = x.shape
        half_dim = self.d_model // 2
        table_device = self.pe_input.device

        def _as_index(positions):
            # Default to 0..seq_length-1 for every batch row
            if positions is None:
                return torch.arange(seq_length, device=table_device).expand(batch_size, seq_length)
            positions = torch.as_tensor(positions, device=table_device).long()
            # Only the entries matching x's batch/sequence extent are used
            return positions[:batch_size, :seq_length]

        in_idx = _as_index(input_positions)
        out_idx = _as_index(output_positions)

        max_in = self.pe_input.size(0)
        max_out = self.pe_output.size(0)
        # A token is encoded only when BOTH of its positions fit in the tables
        in_range = (in_idx < max_in) & (out_idx < max_out)

        # Clamp so the table lookup is always valid; out-of-range tokens are
        # zeroed again below.
        pos_encoding = torch.zeros_like(x)
        pos_encoding[..., :half_dim] = self.pe_input[in_idx.clamp(max=max_in - 1)]
        pos_encoding[..., half_dim:] = self.pe_output[out_idx.clamp(max=max_out - 1)]
        pos_encoding.masked_fill_(~in_range.unsqueeze(-1).to(x.device), 0)

        return x + pos_encoding

class PerceiverAttention(nn.Module):
    """Multi-head attention where latents query the media plus themselves.

    Queries are projected from the (layer-normed) latents; keys and values are
    projected from the concatenation of the (layer-normed) media ``x`` and the
    latents along the sequence dimension.
    """

    def __init__(self, dim, dim_head=64, heads=8):
        super().__init__()
        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        x: [batch_size, seq_len_x, dim]
        latents: [batch_size, seq_len_l, dim]

        Returns a tensor of shape [batch_size, seq_len_l, dim].
        """
        n_batch = x.shape[0]

        x = self.norm_media(x)
        latents = self.norm_latents(latents)

        # Broadcast the latents over the batch when their batch size differs
        if latents.size(0) != n_batch:
            latents = latents.expand(n_batch, -1, -1)

        def split_heads(t):
            return rearrange(t, 'b n (h d) -> b h n d', h=self.heads)

        # Scaled queries come from the latents only
        queries = split_heads(self.to_q(latents)) * self.scale

        # Keys/values see the media followed by the latents themselves
        keys, values = self.to_kv(torch.cat((x, latents), dim=1)).chunk(2, dim=-1)
        keys = split_heads(keys)
        values = split_heads(values)

        scores = torch.einsum('b h i d, b h j d -> b h i j', queries, keys)
        weights = scores.softmax(dim=-1)
        mixed = torch.einsum('b h i j, b h j d -> b h i d', weights, values)

        # Merge the heads back into the feature dimension and project out
        merged = rearrange(mixed, 'b h n d -> b n (h d)')
        return self.to_out(merged)

class GatedCrossAttentionBlock(nn.Module):
    """Tanh-gated cross-attention followed by a tanh-gated feed-forward layer.

    Both gates are learnable scalars initialised to 0, so ``tanh(gate)`` is 0
    at initialisation and the block starts out as the identity on ``x``.
    """

    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.attn = PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads)
        self.attn_gate = nn.Parameter(torch.tensor([0.]))
        self.ff = FeedForward(dim, mult=ff_mult)
        self.ff_gate = nn.Parameter(torch.tensor([0.]))

    def forward(self, x, media):
        """
        x: [batch_size, seq_len_x, dim]
        media: [batch_size, seq_len_m, dim]
        """
        x_batch = x.shape[0]
        media_batch = media.size(0)

        # Expand whichever input has the smaller batch to match the other
        if x_batch > media_batch:
            media = media.expand(x_batch, -1, -1)
        elif x_batch < media_batch:
            x = x.expand(media_batch, -1, -1)

        # x provides the queries. NOTE(review): PerceiverAttention builds its
        # keys/values from the concatenation of both inputs, so x also attends
        # to all of its own positions here without any causal mask -- confirm
        # this is intended for autoregressive use.
        attn_gate = self.attn_gate.tanh()
        attended = self.attn(media, x)
        x = x + attended * attn_gate

        ff_out = self.ff(x)
        x = x + ff_out * self.ff_gate.tanh()
        return x

class PerceiverResampler(nn.Module):
    """Compress a variable-length sequence into ``num_latents`` learned latents.

    Each of the ``depth`` layers applies PerceiverAttention (latents attend to
    the input and themselves) and a FeedForward, both with residual
    connections.
    """

    def __init__(self, dim, depth, dim_head=64, heads=8, num_latents=64):
        super().__init__()
        # Learned latents are stored without a batch dimension
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        self.layers = nn.ModuleList([
            nn.ModuleList([
                PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                FeedForward(dim=dim),
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        # One copy of the latents per batch element
        latents = repeat(self.latents, 'n d -> b n d', b=x.shape[0])

        for attention, feed_forward in self.layers:
            latents = latents + attention(x, latents)
            latents = latents + feed_forward(latents)

        return latents

class FeedForward(nn.Module):
    """Pre-norm MLP: LayerNorm -> Linear(dim, dim*mult) -> GELU -> Linear(dim*mult, dim).

    Both linear layers are bias-free. No residual connection is applied here;
    callers add it themselves.
    """

    def __init__(self, dim, mult=4):
        super().__init__()
        hidden_dim = dim * mult
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim, bias=False),
            nn.GELU(),
            nn.Linear(hidden_dim, dim, bias=False),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

# class PerceiverResampler(nn.Module):
#     def __init__(self, dim, depth, dim_head=64, heads=8, num_latents=64):
#         super().__init__()
#         self.latents = nn.Parameter(torch.randn(num_latents, dim))
#         self.layers = nn.ModuleList([])

#         for _ in range(depth):
#             self.layers.append(nn.ModuleList([
#                 PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
#                 FeedForward(dim=dim)
#             ]))

#     def forward(self, x):
#         latents = repeat(self.latents, 'n d -> b n d', b=x.shape[0])

#         for attn, ff in self.layers:
#             latents = attn(x, latents) + latents
#             latents = ff(latents) + latents

#         return latents

# class GatedCrossAttentionBlock(nn.Module):
#     def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
#         super().__init__()
#         self.attn = PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads)
#         self.attn_gate = nn.Parameter(torch.tensor([0.]))
#         self.ff = FeedForward(dim, mult=ff_mult)
#         self.ff_gate = nn.Parameter(torch.tensor([0.]))

#     def forward(self, x, media):
#         gate = self.attn_gate.tanh()
#         x = self.attn(media, x) * gate + x
#         x = self.ff(x) * self.ff_gate.tanh() + x
#         return x

### PolyBert Encoder

In [10]:
from transformers import AutoTokenizer, AutoModel
import torch


In [11]:
# class PolyBERTEncoder(nn.Module):
#     def __init__(self, output_dim):
#         super().__init__()
#         self.polybert = AutoModel.from_pretrained('kuelumbus/polyBERT')
#         self.tokenizer = AutoTokenizer.from_pretrained('kuelumbus/polyBERT')
#         self.output_dim = output_dim
#         # Add a projection layer to match the required dimension
#         self.projection = nn.Linear(self.polybert.config.hidden_size, output_dim)

#     def mean_pooling(self, model_output, attention_mask):
#         token_embeddings = model_output[0]
#         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
#         return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

#     def forward(self, smiles_strings):
#         # Tokenize the SMILES strings
#         encoded_input = self.tokenizer(smiles_strings,
#                                      padding=True,
#                                      truncation=True,
#                                      return_tensors='pt').to(next(self.polybert.parameters()).device)

#         # Get PolyBERT embeddings
#         with torch.no_grad():
#             model_output = self.polybert(**encoded_input)

#         # Debug prints
#         print("Model Output Keys:", model_output.keys())  # Check available keys
#         # print("Last Hidden State:", model_output.last_hidden_state)
#         print("Last Hidden State Shape:", model_output.last_hidden_state.shape)

#         # Pool the embeddings
#         pooled_output = self.mean_pooling(model_output, encoded_input['attention_mask'])

#         # print("Pooled Output:", pooled_output)
#         print("Pooled Output Shape:", pooled_output.shape)

#         # Project to required dimension
#         projected_output = self.projection(pooled_output)

#         return projected_output

In [12]:
class PolyBERTEncoder(nn.Module):
    """Encode SMILES strings into per-token embeddings with pretrained polyBERT.

    The polyBERT forward pass runs under ``torch.no_grad()``; the linear
    projection applied to its output is outside that context.
    """

    def __init__(self, output_dim):
        super().__init__()
        self.polybert = AutoModel.from_pretrained('kuelumbus/polyBERT')
        self.tokenizer = AutoTokenizer.from_pretrained('kuelumbus/polyBERT')
        self.output_dim = output_dim
        # Maps every polyBERT token embedding to the requested width
        self.projection = nn.Linear(self.polybert.config.hidden_size, output_dim)

    def forward(self, smiles_strings):
        """Return projected token embeddings of shape [batch_size, seq_len, output_dim]."""
        # Tokenise (with padding/truncation) and move the tensors to polyBERT's device
        model_device = next(self.polybert.parameters()).device
        tokens = self.tokenizer(
            smiles_strings,
            padding=True,
            truncation=True,
            return_tensors='pt',
        ).to(model_device)

        # No gradients flow through the polyBERT forward pass
        with torch.no_grad():
            encoder_output = self.polybert(**tokens)

        # Per-token hidden states: [batch_size, seq_len, hidden_size]
        token_states = encoder_output.last_hidden_state

        # Project each token embedding to the required dimension
        return self.projection(token_states)

### ProtFlamingo

In [13]:
import torch.nn.functional as F


In [14]:
class SigmaProtFlamingo(nn.Module):
    def __init__(self, model_path, max_len, cross_attn_every=3, dim_head=64, heads=8, perceiver_depth=2, perceiver_num_latents=64):
        super().__init__()

        self.protGPT2_model = GPT2LMHeadModel.from_pretrained(model_path)
        self.protGPT2_tokenizer = GPT2Tokenizer.from_pretrained(model_path)
        self.max_len = max_len

        if self.protGPT2_tokenizer.pad_token is None:
            self.protGPT2_tokenizer.pad_token = self.protGPT2_tokenizer.eos_token
            self.protGPT2_model.config.pad_token_id = self.protGPT2_model.config.eos_token_id

        self.cross_attn_every = cross_attn_every

        # PolyBERT encoder for SMILES strings
        self.polybert_encoder = PolyBERTEncoder(self.protGPT2_model.config.n_embd)

        # Replace single positional encoding with double positional encoding
        self.positional_encoding = DoublePositionalEncoding(self.protGPT2_model.config.n_embd, max_len=max_len)

        # Single perceiver resampler for SMILES embeddings
        self.smiles_perceiver = PerceiverResampler(
            dim=self.protGPT2_model.config.n_embd,
            depth=perceiver_depth,
            dim_head=dim_head,
            heads=heads,
            num_latents=perceiver_num_latents
        )

        # Cross attention layers
        num_gpt_layers = len(self.protGPT2_model.transformer.h)
        self.cross_attn = nn.ModuleList([
            GatedCrossAttentionBlock(dim=self.protGPT2_model.config.n_embd, dim_head=dim_head, heads=heads)
            for _ in range(num_gpt_layers)
        ])

        # Combine GPT layers with cross attention
        self.layers = nn.ModuleList()
        for i, block in enumerate(self.protGPT2_model.transformer.h):
            self.layers.append(block)
            if i % cross_attn_every == 0 and i != 0:
                self.layers.append(GatedCrossAttentionBlock(dim=self.protGPT2_model.config.n_embd, dim_head=dim_head, heads=heads))

    def forward(self, smiles_strings, order=None, targets=None, optimize=False, kv_cache=None, burst=False):
        device = next(self.parameters()).device

        # Get SMILES embeddings through PolyBERT
        smiles_embeddings = self.polybert_encoder(smiles_strings)
        processed_smiles = self.smiles_perceiver(smiles_embeddings)

        # Initialize with start token
        gpt_input = self.protGPT2_tokenizer.encode_plus(
            "<|endoftext|>",
            return_tensors="pt",
            padding='max_length',
            max_length=self.max_len,
            truncation=True
        ).to(device)

        input_ids = gpt_input.input_ids.long()
        seq_length = input_ids.size(1)
        batch_size = 1 if isinstance(smiles_strings, str) else len(smiles_strings)

        hidden_states = self.protGPT2_model.transformer.wte(input_ids)

        # If no order is provided, use left-to-right
        if order is None:
            order = torch.arange(seq_length, device=device).unsqueeze(0).repeat(batch_size, 1)

        # Make sure order is the right length
        if order.size(1) > seq_length:
            order = order[:, :seq_length]
        elif order.size(1) < seq_length:
            # Pad order if needed
            padding = torch.arange(order.size(1), seq_length, device=device).unsqueeze(0).repeat(batch_size, 1)
            order = torch.cat([order, padding], dim=1)

        # Map the input tokens according to the order
        # When using random order, we need to reshuffle the input tokens
        if not optimize and not burst:  # Only shuffle during training
            reordered_input_ids = torch.zeros_like(input_ids)
            for b in range(batch_size):
                # Reorder the input tokens according to the order
                reordered_input_ids[b] = input_ids[b, order[b]]

            # Re-embed with reordered tokens
            hidden_states = self.protGPT2_model.transformer.wte(reordered_input_ids)

        # Get input and output positions from the order
        # Input positions: the current position in the order
        # Output positions: the next position in the order
        input_positions = order
        # Shift the order by 1 to get output positions (target positions)
        output_positions = torch.roll(order, -1, dims=1)
        # The last position wraps to the first position
        output_positions[:, -1] = order[:, 0]

        # Apply double positional encoding
        hidden_states = self.positional_encoding(hidden_states, input_positions, output_positions)

        # Create attention mask based on the order
        attention_mask = gpt_input.attention_mask
        num_heads = self.protGPT2_model.config.n_head

        # Create 4D attention mask [batch_size, num_heads, seq_length, seq_length]
        attention_mask = attention_mask.view(batch_size, 1, 1, seq_length)
        attention_mask = attention_mask.expand(batch_size, num_heads, seq_length, seq_length)
        attention_mask = attention_mask.to(dtype=hidden_states.dtype)

        # Create causal mask based on the order
        # A token at position i can attend to tokens at positions j where order[j] <= order[i]
        # Vectorized causal mask creation
        seq_indices = torch.arange(seq_length, device=device)
        expanded_seq_indices_i = seq_indices.unsqueeze(1).expand(seq_length, seq_length)
        expanded_seq_indices_j = seq_indices.unsqueeze(0).expand(seq_length, seq_length)

        causal_mask = torch.zeros((batch_size, seq_length, seq_length), device=device)
        for b in range(batch_size):
            # Get order for this batch
            order_b = order[b]
            # Get order values at positions i and j
            order_i = order_b[expanded_seq_indices_i]
            order_j = order_b[expanded_seq_indices_j]
            # Create mask where order_j <= order_i
            causal_mask[b] = (order_j <= order_i).float()

        # Reshape causal_mask to match attention_mask and combine them
        causal_mask = causal_mask.unsqueeze(1)  # [batch_size, 1, seq_length, seq_length]
        combined_mask = attention_mask * causal_mask

        for i, layer in enumerate(self.layers):
            if isinstance(layer, GatedCrossAttentionBlock):
                hidden_states = layer(hidden_states, processed_smiles)
            else:
                hidden_states = layer(hidden_states, attention_mask=combined_mask)[0]

        # Get logits
        logits = self.protGPT2_model.lm_head(hidden_states)

        if targets is None:
            if optimize:
                # inference-time mini-optimization: only forward the lm_head on the very last position
                return logits[:, [-1], :], None
            return logits, None

        # Compute loss against the targets
        # If targets are provided in original order, we need to shuffle them to match our order
        if targets is not None:
            shuffled_targets = torch.zeros_like(targets)
            for b in range(batch_size):
                # Reorder the targets according to the order
                shuffled_targets[b] = targets[b, order[b]]

            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                shuffled_targets.view(-1),
                ignore_index=-1
            )
        else:
            loss = None

        return logits, loss

    def custom_generate(self, smiles_string, max_length=200):
        device = next(self.parameters()).device

        # Get SMILES embeddings
        smiles_embeddings = self.polybert_encoder(smiles_string)
        processed_smiles = self.smiles_perceiver(smiles_embeddings)

        # Initialize with start token
        input_ids = torch.tensor([[self.protGPT2_tokenizer.bos_token_id]]).to(device)

        # Autoregressive generation
        for _ in range(max_length):
            inputs_embeds = self.protGPT2_model.transformer.wte(input_ids)
            inputs_embeds = self.positional_encoding(inputs_embeds)

            hidden_states = inputs_embeds
            cross_attn_idx = 0

            for i, layer in enumerate(self.layers):
                if isinstance(layer, GatedCrossAttentionBlock):
                    hidden_states = layer(hidden_states, processed_smiles)
                    cross_attn_idx += 1
                else:
                    hidden_states = layer(hidden_states, attention_mask=None)[0]

            next_token_logits = self.protGPT2_model.lm_head(hidden_states[:, -1, :])
            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)

            input_ids = torch.cat([input_ids, next_token], dim=-1)

            if next_token.item() == self.protGPT2_tokenizer.eos_token_id:
                break

        return self.protGPT2_tokenizer.decode(input_ids[0], skip_special_tokens=True)

    def generate(self, smiles_string, max_length=50):
        return self.custom_generate(smiles_string, max_length)

    def state_dict(self):
        state_dict = super().state_dict()
        state_dict['smiles_perceiver'] = self.smiles_perceiver.state_dict()
        state_dict['cross_attn'] = self.cross_attn.state_dict()
        state_dict['polybert_encoder'] = self.polybert_encoder.state_dict()
        return state_dict

    def load_state_dict(self, state_dict):
        smiles_perceiver_state = state_dict.pop('smiles_perceiver')
        cross_attn_state = state_dict.pop('cross_attn')
        polybert_encoder_state = state_dict.pop('polybert_encoder')

        super().load_state_dict(state_dict)

        self.smiles_perceiver.load_state_dict(smiles_perceiver_state)
        self.cross_attn.load_state_dict(cross_attn_state)
        self.polybert_encoder.load_state_dict(polybert_encoder_state)

In [15]:
import random

### training

In [16]:
import torch
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
import json
import pandas as pd
import os
import matplotlib.pyplot as plt

In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
import os
import random
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import math
from nltk.translate.bleu_score import sentence_bleu
import Levenshtein

# Add this function to model class to get outputs without computing loss internally
def add_forward_without_loss_to_model(model):
    """
    Attach a ``forward_without_loss`` method to ``model`` and return the model.

    The attached method runs the SMILES-conditioned layer stack and returns the
    LM-head logits, so the training loop can compute its own loss.
    Call this function before starting training.
    """
    def forward_without_loss(self, smiles_strings, targets=None):
        """Get model outputs without computing loss internally for SigmaProtFlamingo.

        Args:
            smiles_strings: A single SMILES string or a sequence of them.
            targets: Accepted for call-site compatibility; not used here. The
                decoder input is always a padded ``"<|endoftext|>"`` prompt.

        Returns:
            The output of ``protGPT2_model.lm_head`` applied to the final hidden
            states (one row of logits per prompt position).
        """
        device = next(self.parameters()).device
        batch_size = 1 if isinstance(smiles_strings, str) else len(smiles_strings)

        # Get SMILES embeddings through PolyBERT
        smiles_embeddings = self.polybert_encoder(smiles_strings)
        processed_smiles = self.smiles_perceiver(smiles_embeddings)

        # One padded start-token prompt per batch element. Tokenizing a single
        # prompt only produced one row, which made the mask reshape below fail
        # for batch_size > 1.
        gpt_input = self.protGPT2_tokenizer(
            ["<|endoftext|>"] * batch_size,
            return_tensors="pt",
            padding='max_length',
            max_length=self.max_len,
            truncation=True
        ).to(device)

        input_ids = gpt_input.input_ids.long()
        seq_length = input_ids.size(1)

        hidden_states = self.protGPT2_model.transformer.wte(input_ids)

        # Use left-to-right order for training with AAR focus
        order = torch.arange(seq_length, device=device).unsqueeze(0).repeat(batch_size, 1)

        # Double positional encoding: each token gets its own position and the
        # position of the token predicted next. roll(-1) already wraps the first
        # position around to the last slot.
        input_positions = order
        output_positions = torch.roll(order, -1, dims=1)

        hidden_states = self.positional_encoding(hidden_states, input_positions, output_positions)

        # Padding mask broadcast to [batch_size, num_heads, seq_length, seq_length]
        num_heads = self.protGPT2_model.config.n_head
        attention_mask = gpt_input.attention_mask.view(batch_size, 1, 1, seq_length)
        attention_mask = attention_mask.expand(batch_size, num_heads, seq_length, seq_length)
        attention_mask = attention_mask.to(dtype=hidden_states.dtype)

        # Causal mask derived from the order: query i may see key j when
        # order[j] <= order[i]. Built by broadcasting instead of a per-batch loop.
        causal_mask = (order.unsqueeze(1) <= order.unsqueeze(2)).float()
        causal_mask = causal_mask.unsqueeze(1)  # [batch_size, 1, seq_length, seq_length]

        # NOTE(review): this is a multiplicative 0/1 mask (1 = may attend). If the
        # non-cross-attention layers are Hugging Face GPT-2 blocks they would expect
        # an additive mask instead -- confirm against the layer implementation.
        combined_mask = attention_mask * causal_mask

        # Process through all layers
        for layer in self.layers:
            if isinstance(layer, GatedCrossAttentionBlock):
                hidden_states = layer(hidden_states, processed_smiles)
            else:
                hidden_states = layer(hidden_states, attention_mask=combined_mask)[0]

        # Get logits without computing loss
        return self.protGPT2_model.lm_head(hidden_states)

    # Bind the function to this model instance only (the class is left untouched)
    model.forward_without_loss = forward_without_loss.__get__(model, type(model))
    return model

def repetition_penalty_loss(predicted_token_ids, target_token_ids, pad_token_id):
    """
    Fraction of non-padding positions whose predicted token repeats its left neighbour.

    Positions count as valid when the *target* token at that position is not
    ``pad_token_id``; the first position has no left neighbour and is excluded.
    Returns a scalar tensor (0 when there are no valid positions).
    """
    # 1.0 wherever a predicted token equals the token immediately before it
    repeats = predicted_token_ids[:, 1:].eq(predicted_token_ids[:, :-1]).float()

    # Only positions 1..end that hold a real (non-pad) target token are scored
    keep = target_token_ids.ne(pad_token_id)[:, 1:].float()

    repeated_count = (repeats * keep).sum()
    # The epsilon keeps the ratio finite when every position is padding
    return repeated_count / (keep.sum() + 1e-8)

def sequence_diversity_loss(predicted_logits, pad_token_id, vocab_size):
    """
    Negative token-usage entropy of the argmax predictions, averaged over the batch.

    For every sequence the predicted (argmax) tokens that differ from
    ``pad_token_id`` are histogrammed over ``range(vocab_size)``; the entropy of
    that distribution is negated so that more varied predictions give a lower
    value. Sequences predicting only padding contribute 0.

    Args:
        predicted_logits: Logits whose last dimension indexes the vocabulary.
        pad_token_id: Token id excluded from the histogram.
        vocab_size: Number of token ids counted; larger ids are ignored.

    Returns:
        A scalar tensor, or the float ``0.0`` when no sequence contributed
        (including an empty batch). Because it is computed from argmax ids the
        value carries no gradient.
    """
    # Get predicted token ids
    predicted_token_ids = torch.argmax(predicted_logits, dim=-1)

    batch_size = predicted_token_ids.size(0)
    if batch_size == 0:
        # Avoid dividing by zero on an empty batch
        return 0.0

    diversity_loss = 0.0
    for row in predicted_token_ids:
        # Tokens of this sequence, excluding padding
        seq_tokens = row[row != pad_token_id]
        seq_len = seq_tokens.numel()
        if seq_len == 0:
            continue

        # One histogram call instead of one comparison per vocabulary entry.
        # The slice drops any ids >= vocab_size.
        token_counts = torch.bincount(seq_tokens, minlength=vocab_size)[:vocab_size].float()

        # Normalize to get probability distribution
        token_probs = token_counts / (seq_len + 1e-8)

        # Entropy over the tokens that actually occur (higher is more diverse)
        non_zero_probs = token_probs[token_probs > 0]
        entropy = -torch.sum(non_zero_probs * torch.log(non_zero_probs + 1e-8))

        # Negate: maximizing entropy corresponds to minimizing this loss
        diversity_loss += -entropy

    # Average over batch
    return diversity_loss / batch_size

def sequence_similarity_loss(predicted_tokens, target_tokens, tokenizer):
    """
    Mean length-normalized Levenshtein distance between decoded predictions and targets.

    Each row is decoded with ``skip_special_tokens=True``. Rows whose decoded
    target is empty add nothing to the sum but still count in the divisor,
    which is the number of rows in ``predicted_tokens``. Returns a plain float.
    """
    batch_size = predicted_tokens.shape[0]
    distances = []

    for row in range(batch_size):
        hypothesis = tokenizer.decode(predicted_tokens[row], skip_special_tokens=True)
        reference = tokenizer.decode(target_tokens[row], skip_special_tokens=True)

        if not reference:
            # Nothing to normalize by; skip the row
            continue

        # Edit distance scaled by the reference length
        distances.append(Levenshtein.distance(hypothesis, reference) / len(reference))

    return sum(distances, 0.0) / batch_size

def _plot_training_metrics(loss_df, output_path):
    """Save a three-panel figure (loss, perplexity, AAR vs. epoch) to ``output_path``."""
    if loss_df.empty:
        # Nothing was logged (e.g. num_epochs == 0), so there is nothing to draw
        return

    def plottable(column):
        # Hide sentinel (-1) and exploding (>= 1000, inf, NaN) perplexities so a
        # single bad epoch does not flatten the rest of the curve; masked
        # entries become NaN and show up as gaps.
        values = loss_df[column]
        return values.where((values != -1) & (values < 1000))

    plt.figure(figsize=(15, 10))

    plt.subplot(2, 2, 1)
    plt.plot(loss_df['epoch'], loss_df['train_loss'], label='Train')
    plt.plot(loss_df['epoch'], loss_df['val_loss'], label='Validation')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Loss vs. Epoch')

    plt.subplot(2, 2, 2)
    plt.plot(loss_df['epoch'], plottable('train_perplexity'), label='Train')
    plt.plot(loss_df['epoch'], plottable('val_perplexity'), label='Validation')
    plt.xlabel('Epoch')
    plt.ylabel('Perplexity')
    plt.legend()
    plt.title('Perplexity vs. Epoch')

    plt.subplot(2, 2, 3)
    plt.plot(loss_df['epoch'], loss_df['train_accuracy'], label='Train')
    plt.plot(loss_df['epoch'], loss_df['val_accuracy'], label='Validation')
    plt.xlabel('Epoch')
    plt.ylabel('Amino Acid Recovery (%)')
    plt.legend()
    plt.title('AAR vs. Epoch')

    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()


def train_with_improved_aar_objective(model, train_loader, val_loader, num_epochs, device,
                           curriculum_steps=0, l2_reg=1e-5, sample_smiles=None):
    """
    Train ``model`` with a weighted token cross-entropy plus auxiliary sequence terms.

    Per batch the loss is a label-smoothed cross-entropy in which wrongly
    predicted non-padding tokens are up-weighted, plus weighted repetition,
    diversity and sequence-similarity terms. After every epoch the model is
    validated, the validation samples are written to ``validation_results.json``,
    and a checkpoint is saved whenever validation AAR improves. At the end a CSV
    log and a metrics figure are written to the hard-coded Drive folder.

    Args:
        model: Model exposing ``protGPT2_tokenizer``, ``protGPT2_model`` and
            ``max_len``; ``forward_without_loss`` is attached here.
        train_loader: Iterable of batches with ``'smiles'`` and ``'proteins'``.
        val_loader: Loader handed to ``validate_with_metrics``.
        num_epochs: Number of epochs; also the cosine schedule's ``T_max``.
        device: Device for the tokenized targets and for checkpoint loading.
        curriculum_steps: Only checked for ``> 0``; when positive, the
            incorrect-token weight ramps up over the first half of training.
        l2_reg: AdamW weight decay.
        sample_smiles: Unused.

    Returns:
        None.

    NOTE(review): validation goes through ``validate_with_metrics`` (four return
    values), which is not defined in the visible part of this notebook.
    """
    import json

    # Add the forward_without_loss method to the model
    model = add_forward_without_loss_to_model(model)

    tokenizer = model.protGPT2_tokenizer
    pad_token_id = tokenizer.pad_token_id

    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=l2_reg)

    # Use label smoothing to prevent overconfident predictions
    criterion = nn.CrossEntropyLoss(
        ignore_index=pad_token_id,
        reduction='none',
        label_smoothing=0.1  # Add label smoothing to prevent overfitting to common tokens
    )
    # Plain (mean-reduced) criterion handed to the validation routine
    val_criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id)

    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    # Hyperparameters for different loss components
    lambda_rep = 0.2       # Weight for repetition penalty
    lambda_div = 0.1       # Weight for diversity loss
    lambda_seq = 0.15      # Weight for sequence-level similarity loss

    loss_log = []
    initial_checkpoint_path = "/content/drive/MyDrive/classes+projects/plastic_enzyme_project/2024/codes/sigma_ckpt_ca1_full_dataset/sigma_epoch_2.pth"
    new_checkpoint_dir = "/content/drive/MyDrive/classes+projects/plastic_enzyme_project/2024/codes/sigma_ckpt_full_dataset_enhanced_metrics"
    os.makedirs(new_checkpoint_dir, exist_ok=True)

    # Start below any reachable AAR so the first epoch is always checkpointed,
    # even when validation AAR is still 0%.
    best_val_aar = float('-inf')

    if os.path.exists(initial_checkpoint_path):
        print("Loading initial checkpoint...")
        # map_location lets a GPU-saved checkpoint load on the current device
        model.load_state_dict(torch.load(initial_checkpoint_path, map_location=device))

    vocab_size = model.protGPT2_model.config.vocab_size

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        total_correct = 0
        total_tokens = 0

        # Calculate curriculum ratio (if using curriculum learning)
        curriculum_ratio = min(1.0, epoch / (num_epochs / 2)) if curriculum_steps > 0 else 1.0
        print(f"Curriculum ratio: {curriculum_ratio:.2f}")

        for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")):
            smiles_strings = batch['smiles']
            proteins = batch['proteins']
            optimizer.zero_grad()

            target_encoding = tokenizer(
                proteins,
                return_tensors="pt",
                padding='max_length',
                max_length=model.max_len,
                truncation=True
            ).to(device)
            target_ids = target_encoding.input_ids

            # Get model outputs without computing loss yet
            outputs = model.forward_without_loss(smiles_strings, target_ids)

            # Per-token loss, reshaped back to one row per sequence
            token_loss = criterion(outputs.view(-1, outputs.size(-1)), target_ids.view(-1))
            token_loss = token_loss.view(outputs.size(0), -1)

            # Greedy predictions, used for the weighting mask and the metrics
            predicted_token_ids = torch.argmax(outputs, dim=-1)

            # Ensure predicted_token_ids has the same shape as the targets
            if predicted_token_ids.shape[1] < model.max_len:
                predicted_token_ids = torch.nn.functional.pad(
                    predicted_token_ids,
                    (0, model.max_len - predicted_token_ids.shape[1]),
                    value=pad_token_id
                )

            # Create a mask for non-padding tokens
            pad_mask = target_ids != pad_token_id

            # Compute per-token accuracy (True for correct non-pad predictions)
            token_correct = (predicted_token_ids == target_ids) & pad_mask

            # Wrong tokens weigh (1 + incorrect_weight), correct ones 1, padding 0.
            # incorrect_weight grows from 1.5 to 2.0 with the curriculum ratio.
            incorrect_weight = 1.5 + curriculum_ratio * 0.5
            weight_mask = (~token_correct).float() * incorrect_weight + 1.0
            weight_mask = weight_mask * pad_mask.float()  # Zero out padding tokens

            # Apply the weighting mask to the token losses
            weighted_loss = (token_loss * weight_mask).sum() / (weight_mask.sum() + 1e-8)

            # The three auxiliary terms below are computed from argmax ids or
            # decoded strings, so they change the reported total but do not
            # contribute gradients; only weighted_loss drives the update.

            # Calculate repetition penalty to discourage repeating the same amino acid
            rep_penalty = repetition_penalty_loss(
                predicted_token_ids,
                target_ids,
                pad_token_id
            )

            # Calculate diversity loss to encourage using a wide range of amino acids
            div_loss = sequence_diversity_loss(
                outputs,
                pad_token_id,
                vocab_size
            )

            # Calculate sequence-level similarity loss
            seq_loss = sequence_similarity_loss(
                predicted_token_ids,
                target_ids,
                tokenizer
            )

            # Combine all losses with appropriate weights
            total_loss_val = (
                weighted_loss +
                lambda_rep * rep_penalty +
                lambda_div * div_loss +
                lambda_seq * seq_loss
            )

            # Print all loss components every 50 batches for monitoring
            if batch_idx % 50 == 0:
                print(f"  Loss components: CE={weighted_loss:.4f}, Rep={rep_penalty:.4f}, Div={div_loss:.4f}, Seq={seq_loss:.4f}")

            # Use the combined loss for backpropagation
            total_loss_val.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # For logging purposes, calculate the standard (unweighted) loss
            standard_loss = token_loss[pad_mask].mean().item() if pad_mask.sum() > 0 else 0
            total_loss += standard_loss

            # Calculate AAR metrics for logging
            correct = token_correct.sum().item()
            total = pad_mask.sum().item()
            total_correct += correct
            total_tokens += total

            # Print batch statistics occasionally
            if batch_idx % 10 == 0:
                batch_aar = (correct / total * 100) if total > 0 else 0
                print(f"  Batch {batch_idx}: Loss={standard_loss:.4f}, AAR={batch_aar:.2f}%")

        # Guard both averages against an empty loader / all-padding epoch
        avg_loss = total_loss / max(len(train_loader), 1)
        amino_acid_recovery = (total_correct / total_tokens * 100) if total_tokens > 0 else 0.0

        # Prevent overflow when calculating perplexity
        try:
            perplexity = math.exp(avg_loss)
        except OverflowError:
            perplexity = float('inf')  # Return infinity if the loss is too high

        print(f"\nEpoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")
        print(f"Perplexity: {perplexity}")
        print(f"Amino Acid Recovery: {amino_acid_recovery:.2f}%")

        val_loss, val_perplexity, val_aar, val_results = validate_with_metrics(model, val_loader, val_criterion, device)
        print(f"Validation Loss: {val_loss:.4f}, Perplexity: {val_perplexity:.4f}, Amino Acid Recovery: {val_aar:.2f}%")

        # Save validation results; the context manager closes the file promptly
        with open("validation_results.json", "w") as results_file:
            json.dump(val_results, results_file, indent=4)

        loss_log.append({
            'epoch': epoch+1,
            'train_loss': avg_loss,
            'train_perplexity': perplexity,
            'train_accuracy': amino_acid_recovery,
            'val_loss': val_loss,
            'val_perplexity': val_perplexity,
            'val_accuracy': val_aar
        })

        checkpoint_path = os.path.join(new_checkpoint_dir, f"sigma_epoch_{epoch+1}.pth")

        # Keep a checkpoint only when validation AAR improves
        if val_aar > best_val_aar:
            best_val_aar = val_aar
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Checkpoint saved at {checkpoint_path} (Validation AAR improved to {best_val_aar:.2f}%)")

        scheduler.step()

    loss_df = pd.DataFrame(loss_log)
    loss_df.to_csv("/content/drive/MyDrive/classes+projects/plastic_enzyme_project/2024/codes/sigma_improved_aar_log.csv", index=False)

    # Plot training metrics
    _plot_training_metrics(
        loss_df,
        "/content/drive/MyDrive/classes+projects/plastic_enzyme_project/2024/codes/sigma_improved_aar_training_metrics.png"
    )

In [18]:
def validate_with_enhanced_metrics(model, val_loader, criterion, device):
    """
    Enhanced validation function that evaluates sequence fidelity in addition to AAR.

    Runs the model over ``val_loader`` without gradients and accumulates the
    model-reported loss, amino acid recovery (token accuracy over non-padding
    target positions), and per-sequence diversity, consecutive-repeat rate and
    normalized Levenshtein distance. For up to 50 randomly chosen batches the
    first example is stored with its own metrics.

    Args:
        model: Model returning ``(outputs, loss)`` when called with SMILES and
            ``targets``; must expose ``protGPT2_tokenizer`` and ``max_len``.
        val_loader: Sized iterable of batches with ``'smiles'`` and ``'proteins'``.
        criterion: Unused; the loss comes from the model call.
        device: Device the tokenized targets are moved to.

    Returns:
        ``(avg_loss, perplexity, amino_acid_recovery, saved_results,
        evaluation_metrics)``; AAR is a percentage and ``evaluation_metrics``
        is a dict of the aggregate values.
    """
    model.eval()
    tokenizer = model.protGPT2_tokenizer
    pad_token_id = tokenizer.pad_token_id

    total_loss = 0
    total_correct = 0
    total_tokens = 0

    # Track advanced metrics
    total_levenshtein = 0
    total_diversity = 0
    total_repetition = 0
    total_sequences = 0

    saved_results = []  # Store ground truth vs predicted sequences and SMILES

    with torch.no_grad():
        # Batch indices whose first example is kept for inspection
        sampled_batches = set(random.sample(range(len(val_loader)), min(50, len(val_loader))))

        for i, batch in enumerate(val_loader):
            smiles_strings = batch['smiles']
            proteins = batch['proteins']

            target_encoding = tokenizer(
                proteins,
                return_tensors="pt",
                padding='max_length',
                max_length=model.max_len,
                truncation=True
            ).to(device)
            target_ids = target_encoding.input_ids

            outputs, loss = model(
                smiles_strings,
                targets=target_ids
            )

            total_loss += loss.item()

            # Amino Acid Recovery Calculation (excluding padding tokens)
            predicted_token_ids = torch.argmax(outputs, dim=-1)
            predicted_token_ids = torch.nn.functional.pad(
                predicted_token_ids, (0, model.max_len - predicted_token_ids.shape[1]),
                value=pad_token_id
            )

            mask = target_ids != pad_token_id
            correct = (predicted_token_ids[mask] == target_ids[mask]).sum().item()
            total = mask.sum().item()
            total_correct += correct
            total_tokens += total

            # Per-sequence diversity, repetition and edit distance
            for b in range(len(predicted_token_ids)):
                pred_seq_tokens = predicted_token_ids[b][mask[b]]
                n_tokens = len(pred_seq_tokens)
                if n_tokens == 0:
                    continue

                # Share of distinct tokens among the predicted tokens
                seq_diversity = torch.unique(pred_seq_tokens).size(0) / n_tokens

                # Adjacent equal pairs, counted in one tensor comparison
                consecutive_repeats = (pred_seq_tokens[1:] == pred_seq_tokens[:-1]).sum().item()
                normalized_repeats = consecutive_repeats / max(1, n_tokens - 1)

                # Decode sequences for Levenshtein distance
                pred_seq = tokenizer.decode(predicted_token_ids[b], skip_special_tokens=True)
                true_seq = tokenizer.decode(target_ids[b], skip_special_tokens=True)

                # Calculate normalized Levenshtein distance
                if len(true_seq) > 0:
                    total_levenshtein += Levenshtein.distance(pred_seq, true_seq) / len(true_seq)

                total_diversity += seq_diversity
                total_repetition += normalized_repeats
                total_sequences += 1

            # Save randomly selected ground truth vs predicted sequences and SMILES
            if i in sampled_batches:
                ground_truth = tokenizer.decode(target_ids[0], skip_special_tokens=True)
                predicted = tokenizer.decode(predicted_token_ids[0], skip_special_tokens=True)

                # Calculate per-sequence AAR for this example
                seq_mask = mask[0]
                seq_correct = (predicted_token_ids[0][seq_mask] == target_ids[0][seq_mask]).sum().item()
                seq_total = seq_mask.sum().item()
                seq_aar = (seq_correct / seq_total * 100) if seq_total > 0 else 0

                # Longest run of one repeated character in the prediction
                longest_repeat = find_longest_repeat(predicted)

                # Cosine similarity of the character-count vectors
                aa_comp_similarity = amino_acid_composition_similarity(ground_truth, predicted)

                # Computed once and reused for the raw and normalized entries
                edit_distance = Levenshtein.distance(predicted, ground_truth)

                saved_results.append({
                    'SMILES': smiles_strings[0],
                    'Ground Truth': ground_truth,
                    'Predicted': predicted,
                    'Sequence AAR': f"{seq_aar:.2f}%",
                    'Levenshtein Distance': f"{edit_distance}",
                    'Normalized Levenshtein': f"{edit_distance / max(1, len(ground_truth)):.4f}",
                    'Longest Repeat': longest_repeat,
                    'AA Composition Similarity': f"{aa_comp_similarity:.4f}"
                })

    # Guard the averages against an empty loader / no non-padding tokens
    avg_loss = total_loss / max(len(val_loader), 1)
    amino_acid_recovery = (total_correct / total_tokens * 100) if total_tokens > 0 else 0.0

    # Calculate average advanced metrics
    avg_levenshtein = total_levenshtein / total_sequences if total_sequences > 0 else 0
    avg_diversity = total_diversity / total_sequences if total_sequences > 0 else 0
    avg_repetition = total_repetition / total_sequences if total_sequences > 0 else 0

    # Prevent overflow when calculating perplexity
    try:
        perplexity = math.exp(avg_loss)
    except OverflowError:
        perplexity = float('inf')  # Return infinity if the loss is too high

    # Add advanced metrics to the results
    evaluation_metrics = {
        'loss': avg_loss,
        'perplexity': perplexity,
        'aar': amino_acid_recovery,
        'levenshtein': avg_levenshtein,
        'diversity': avg_diversity,
        'repetition': avg_repetition
    }

    return avg_loss, perplexity, amino_acid_recovery, saved_results, evaluation_metrics

def find_longest_repeat(sequence):
    """Return the length of the longest run of one identical, consecutive element.

    An empty sequence gives 0; any non-empty sequence gives at least 1.
    """
    if not sequence:
        return 0

    best = run = 1
    for previous, current in zip(sequence, sequence[1:]):
        # Extend the run on a match, otherwise start a new run at this element
        run = run + 1 if current == previous else 1
        if run > best:
            best = run

    return best

def amino_acid_composition_similarity(seq1, seq2):
    """
    Cosine similarity between the character-count vectors of two sequences.

    Returns 0 if either sequence is empty; otherwise a value between 0 and 1,
    where 1 means both sequences use their characters in the same proportions.
    """
    if not (seq1 and seq2):
        return 0

    def tally(sequence):
        # Occurrences of every character in the sequence
        counts = {}
        for residue in sequence:
            counts[residue] = counts.get(residue, 0) + 1
        return counts

    first = tally(seq1)
    second = tally(seq2)

    # Characters missing from either sequence contribute zero to the dot
    # product, so only the shared ones need to be summed.
    dot_product = sum(count * second[aa] for aa, count in first.items() if aa in second)

    norm1 = math.sqrt(sum(count * count for count in first.values()))
    norm2 = math.sqrt(sum(count * count for count in second.values()))

    if norm1 == 0 or norm2 == 0:
        return 0

    return dot_product / (norm1 * norm2)

### inference + training

In [None]:
# Notebook driver cell: build the data pipeline and the model, then start training.

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load and preprocess the train/validation CSVs
train_data = preprocess_snp_data('/content/train_data.csv')
val_data = preprocess_snp_data('/content/val_data.csv')

# Optional 1% subsample for quick smoke runs (disabled):
# train_data = train_data.sample(frac=0.01, random_state=42)
# val_data = val_data.sample(frac=0.01, random_state=42)

train_data = filter_datasets(train_data)
val_data = filter_datasets(val_data)

# Longest protein across both splits, capped at 1024
max_length = max(
    train_data['protein_length'].max(),
    val_data['protein_length'].max()
)
max_length = min(max_length, 1024)  # Cap at 1024 or your desired maximum
print(f"Max sequence length: {max_length}")

# Create datasets
train_dataset = ProteinGenerationDataset(train_data, max_length)
val_dataset = ProteinGenerationDataset(val_data, max_length)

# Create dataloaders with custom collate function
train_loader = DataLoader(
    train_dataset,
    batch_size=1,  # Adjust based on your GPU memory
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    collate_fn=collate_fn
)

# Initialize model with sigma-gpt capabilities; max_len is tied to the dataset's
# max_length computed above
model = SigmaProtFlamingo(
    model_path='nferruz/ProtGPT2',
    max_len=max_length,
    cross_attn_every=1,
    dim_head=64,
    heads=8,
    perceiver_depth=2,
    perceiver_num_latents=64
).to(device)

num_epochs = 10

# curriculum_steps is half the total number of training batches.
# train_with_improved_aar_objective only checks whether it is > 0: a positive
# value makes the incorrect-token weight ramp up over the first half of the
# epochs. It does not control token ordering there.
curriculum_steps = int(0.5 * num_epochs * len(train_loader))  # Curriculum over first half of training
print("Starting training with sigma-gpt capabilities...")
train_with_improved_aar_objective(
    model,
    train_loader,
    val_loader,
    num_epochs,
    device,
    curriculum_steps=curriculum_steps
)

# # Generate and evaluate
# print("Generating proteins for test set...")
# test_results = generate_and_evaluate(model, test_loader, device)

# # Save results
# print("Saving results...")
# results_path = '/content/drive/MyDrive/classes+projects/plastic_enzyme_project/2024/codes/test_results.json'
# with open(results_path, 'w') as f:
#     json.dump(test_results, f, indent=2)

# print(f"Results saved to {results_path}")


Using device: cuda
Max sequence length: 914


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Starting training with sigma-gpt capabilities...
Curriculum ratio: 0.00


Epoch 1/10:   0%|          | 0/7956 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


  Loss components: CE=3952.0249, Rep=0.3056, Div=-1.7139, Seq=5.1408


Epoch 1/10:   0%|          | 1/7956 [00:04<10:19:52,  4.68s/it]

  Batch 0: Loss=3952.0249, AAR=0.00%


Epoch 1/10:   0%|          | 11/7956 [00:38<7:33:24,  3.42s/it]

  Batch 10: Loss=1212.5408, AAR=0.00%


Epoch 1/10:   0%|          | 21/7956 [01:12<7:26:10,  3.37s/it]

  Batch 20: Loss=847.3379, AAR=0.00%


Epoch 1/10:   0%|          | 31/7956 [01:46<7:26:49,  3.38s/it]

  Batch 30: Loss=798.9221, AAR=0.00%


Epoch 1/10:   1%|          | 41/7956 [02:19<7:28:36,  3.40s/it]

  Batch 40: Loss=613.4681, AAR=0.00%


Epoch 1/10:   1%|          | 50/7956 [02:50<7:24:16,  3.37s/it]

  Loss components: CE=553.9661, Rep=0.0000, Div=-3.3589, Seq=7.7932


Epoch 1/10:   1%|          | 51/7956 [02:53<7:25:13,  3.38s/it]

  Batch 50: Loss=553.9661, AAR=0.00%


Epoch 1/10:   1%|          | 61/7956 [03:27<7:24:33,  3.38s/it]

  Batch 60: Loss=421.7709, AAR=0.00%


Epoch 1/10:   1%|          | 71/7956 [04:01<7:23:47,  3.38s/it]

  Batch 70: Loss=309.9572, AAR=0.00%


Epoch 1/10:   1%|          | 81/7956 [04:35<7:24:35,  3.39s/it]

  Batch 80: Loss=310.3925, AAR=0.00%


Epoch 1/10:   1%|          | 91/7956 [05:08<7:20:16,  3.36s/it]

  Batch 90: Loss=310.7725, AAR=0.00%


Epoch 1/10:   1%|▏         | 100/7956 [05:38<7:21:03,  3.37s/it]

  Loss components: CE=267.5225, Rep=0.0000, Div=-3.2896, Seq=6.8078


Epoch 1/10:   1%|▏         | 101/7956 [05:42<7:19:05,  3.35s/it]

  Batch 100: Loss=267.5225, AAR=0.00%


Epoch 1/10:   1%|▏         | 111/7956 [06:16<7:22:22,  3.38s/it]

  Batch 110: Loss=258.4881, AAR=0.00%


Epoch 1/10:   2%|▏         | 121/7956 [06:50<7:28:10,  3.43s/it]

  Batch 120: Loss=233.2827, AAR=0.00%


Epoch 1/10:   2%|▏         | 131/7956 [07:23<7:16:48,  3.35s/it]

  Batch 130: Loss=234.5899, AAR=0.00%


Epoch 1/10:   2%|▏         | 141/7956 [07:57<7:15:10,  3.34s/it]

  Batch 140: Loss=228.8950, AAR=0.00%


Epoch 1/10:   2%|▏         | 150/7956 [08:27<7:20:14,  3.38s/it]

  Loss components: CE=199.8562, Rep=0.0000, Div=-3.9665, Seq=5.9921


Epoch 1/10:   2%|▏         | 151/7956 [08:31<7:19:27,  3.38s/it]

  Batch 150: Loss=199.8562, AAR=0.00%


Epoch 1/10:   2%|▏         | 161/7956 [09:04<7:17:38,  3.37s/it]

  Batch 160: Loss=221.5343, AAR=0.00%


Epoch 1/10:   2%|▏         | 171/7956 [09:38<7:18:53,  3.38s/it]

  Batch 170: Loss=201.7257, AAR=0.00%


Epoch 1/10:   2%|▏         | 181/7956 [10:12<7:15:48,  3.36s/it]

  Batch 180: Loss=179.8303, AAR=0.00%


Epoch 1/10:   2%|▏         | 187/7956 [10:32<7:25:02,  3.44s/it]

### generation

In [None]:
import torch
import json
from tqdm import tqdm

# Build the model that the trained weights will be loaded into
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Path to model checkpoint
checkpoint_path = "/content/drive/MyDrive/classes+projects/plastic_enzyme_project/2024/codes/sigma_checkpoint.pth"

# Instantiate the architecture; the hyperparameters must match those used when
# the checkpoint was trained or loading the weights will fail.
# NOTE(review): the training cell in this notebook builds the model with
# cross_attn_every=1, while this cell uses 3 -- confirm which setting the
# checkpoint at checkpoint_path was trained with.
model = SigmaProtFlamingo(
    model_path='nferruz/ProtGPT2',
    max_len=914,  # Ensure this matches the training max_len
    cross_attn_every=3,
    dim_head=64,
    heads=8,
    perceiver_depth=2,
    perceiver_num_latents=64
).to(device)



In [None]:
# Bare expression: in a notebook this echoes the class object, presumably as a
# quick check that ProteinGenerationDataset is defined in the current session.
ProteinGenerationDataset

In [None]:
# Load trained weights onto the current device and switch to inference mode
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()

# Load and filter the test data with the same helpers used for train/val
test_data = preprocess_snp_data('/content/augmented_test.csv')
test_data = filter_datasets(test_data)

# Create test dataset and dataloader; max_length=914 matches the max_len the
# model was constructed with in the generation cell
test_dataset = ProteinGenerationDataset(test_data,max_length = 914 )
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,  # keep the test order stable
    num_workers=2,
    collate_fn=collate_fn
)



In [None]:
def generate_autoregressively(model, smiles_string, max_length=914, temperature=1.0, random_order=False):
    """Generate protein autoregressively, with option to use random order.

    Starting from a BOS token at position 0, one token is sampled per step from
    the temperature-scaled softmax of the model's last-position logits and
    written to the next position in the generation order. Generation stops at
    EOS or when ``max_length`` positions have been visited.

    Args:
        model: Callable as ``model(smiles_string, order=..., optimize=True)``
            returning ``(logits, _)``; must expose ``protGPT2_tokenizer``.
        smiles_string: SMILES string passed to the model at every step.
        max_length: Number of positions in the generated buffer.
        temperature: Positive softmax temperature.
        random_order: If true, fill positions in a random permutation instead
            of left to right.

    Returns:
        The decoded sequence with padding and special tokens removed.

    Raises:
        ValueError: If ``temperature`` is not positive.

    NOTE(review): the model call receives only the SMILES string and the order
    prefix, never the tokens generated so far -- confirm that ``model.forward``
    is meant to work that way. With ``random_order=True`` the BOS token stays at
    buffer position 0 regardless of ``order[0, 0]``.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    device = next(model.parameters()).device
    tokenizer = model.protGPT2_tokenizer

    # Order in which buffer positions are filled
    if random_order:
        order = torch.randperm(max_length, device=device).unsqueeze(0)
    else:
        order = torch.arange(max_length, device=device).unsqueeze(0)

    # Output buffer: all padding except the start token
    generated_sequence = torch.full((1, max_length), tokenizer.pad_token_id, device=device)
    generated_sequence[0, 0] = tokenizer.bos_token_id  # Start token

    with torch.no_grad():
        for next_pos in range(1, max_length):
            # Only the order prefix up to the current step is given to the model
            current_order = order[:, :next_pos]

            logits, _ = model(
                smiles_string,
                order=current_order,
                optimize=True
            )

            # Apply temperature and sample from the last position's distribution
            probs = torch.softmax(logits[0, -1, :] / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()

            # Add the token to the generated sequence at the next position in the order
            generated_sequence[0, order[0, next_pos]] = next_token

            # Check for EOS token
            if next_token == tokenizer.eos_token_id:
                break

    # Decode the generated sequence
    generated_ids = generated_sequence[0].tolist()
    print('generated_ids',generated_ids)
    # Remove padding tokens
    generated_ids = [token_id for token_id in generated_ids if token_id != tokenizer.pad_token_id]
    seq = tokenizer.decode(generated_ids, skip_special_tokens=True)
    print('seq',seq)
    print("autoregressive gen done...")
    return seq


In [None]:
def generate_with_rejection_sampling(model, smiles_string, max_length=914, num_orders=5, temperature=1.0):
    """Generate a protein sequence with token-based rejection sampling.

    Each round proposes one token for every still-empty position, then
    replays those proposals under ``num_orders`` random orders. Within an
    order, a proposal is accepted with probability ``min(1, q / p)``, where
    ``p`` is the probability it was drawn with and ``q`` is its probability
    once the proposals accepted earlier in that order have been appended to
    the context order. Replay of an order stops at its first rejection. The
    order with the most accepted proposals wins (first one on ties) and its
    tokens are committed. Generation ends when every position is filled or
    an EOS token is committed.

    Args:
        model: Object exposing ``parameters()``, ``protGPT2_tokenizer`` (with
            ``bos_token_id``, ``pad_token_id``, ``eos_token_id`` and
            ``decode``) and usable by ``get_logits_for_position``.
        smiles_string: SMILES string the generation is conditioned on.
        max_length: Number of token positions in the generated sequence.
        num_orders: Number of random evaluation orders tried per round.
        temperature: Divisor applied to the logits before the softmax.

    Returns:
        The decoded sequence, with padding ids removed and special tokens
        skipped by the tokenizer.

    Raises:
        ValueError: If ``num_orders`` is smaller than 1.
    """
    if num_orders < 1:
        raise ValueError("num_orders must be at least 1")

    device = next(model.parameters()).device
    tokenizer = model.protGPT2_tokenizer
    pad_id = tokenizer.pad_token_id
    eos_id = tokenizer.eos_token_id

    # Sequence buffer: BOS at position 0, padding everywhere else.
    full_seq = torch.full((1, max_length), pad_id, device=device)
    full_seq[:, 0] = tokenizer.bos_token_id

    # `filled_positions` gives O(1) membership tests; `filled_order` remembers
    # the order in which positions were committed, so the `order` tensor handed
    # to the model follows generation order instead of set iteration order.
    filled_positions = {0}
    filled_order = [0]
    eos_reached = False

    while not eos_reached and len(filled_positions) < max_length:
        remaining_positions = [i for i in range(max_length) if i not in filled_positions]

        # Step 1: draw one proposal per remaining position, p(x~_i | X).
        # Unfilled slots of `full_seq` already hold the pad id, so it serves
        # directly as the context for every proposal in this round.
        candidate_tokens = {}
        proposal_probs = {}
        with torch.no_grad():
            for pos in remaining_positions:
                # Order that puts this position last.
                context_order = torch.tensor([filled_order + [pos]], device=device)
                logits = get_logits_for_position(model, full_seq, context_order, smiles_string, pos)

                probs = F.softmax(logits / temperature, dim=-1)
                token = torch.distributions.Categorical(probs).sample().item()

                candidate_tokens[pos] = token
                proposal_probs[pos] = probs[0, token].item()

        # Step 2: replay the proposals under several random orders.
        accepted_per_order = []
        for _ in range(num_orders):
            eval_order = random.sample(remaining_positions, len(remaining_positions))

            # One working copy per order; accepted proposals are written into
            # it as they are accepted.
            temp_seq = full_seq.clone()
            accepted_positions = []

            for pos in eval_order:
                context_order = torch.tensor(
                    [filled_order + accepted_positions + [pos]], device=device
                )

                # q(x~_i | X, x~_sigma<i): conditioned on the same SMILES
                # string the proposals were drawn with.
                with torch.no_grad():
                    cond_logits = get_logits_for_position(
                        model, temp_seq, context_order, smiles_string, pos
                    )
                    cond_probs = F.softmax(cond_logits / temperature, dim=-1)
                    cond_prob = cond_probs[0, candidate_tokens[pos]].item()

                acceptance_ratio = min(1.0, cond_prob / proposal_probs[pos])
                if random.random() >= acceptance_ratio:
                    # Stop at the first rejection in this order.
                    break

                accepted_positions.append(pos)
                temp_seq[0, pos] = candidate_tokens[pos]

            accepted_per_order.append(accepted_positions)

        # Step 3: commit the longest accepted prefix (max() keeps the first
        # order on ties).
        best_positions = max(accepted_per_order, key=len)
        for pos in best_positions:
            token = candidate_tokens[pos]
            full_seq[0, pos] = token
            filled_positions.add(pos)
            filled_order.append(pos)

            if token == eos_id:
                # End generation entirely, not just this commit loop.
                eos_reached = True
                break

    # Decode everything that was filled, dropping padding ids.
    return tokenizer.decode(
        [t for t in full_seq[0].tolist() if t != pad_id],
        skip_special_tokens=True
    )

def get_logits_for_position(model, sequence, order, smiles_string, target_position):
    """Return the model's logits at the final slot of ``order``.

    ``sequence`` and ``target_position`` are accepted so call sites stay
    uniform, but they are not read here: the forward pass is driven only by
    ``smiles_string`` and ``order``.
    """
    forward_kwargs = {"order": order, "optimize": True}

    # The model returns a pair; only the first element (the logits) is used.
    all_logits, _unused = model(smiles_string, **forward_kwargs)

    # The target position is the last entry of the order, so take the last
    # step along the second axis.
    last_slot = -1
    return all_logits[:, last_slot, :]

In [None]:
import time
import numpy as np

In [None]:
def evaluate_on_unique_smiles(model, test_loader, device, output_file="generated_proteins_comparison.json",
                              max_length=914, num_orders=5, temperature=1.0):
    """Generate proteins with both samplers for every unique SMILES in the test set.

    For each distinct SMILES string found under the ``'smiles'`` key of the
    loader's batches, runs ``generate_autoregressively`` (fixed order) and
    ``generate_with_rejection_sampling``, times each call, and writes all
    results to ``output_file`` as JSON.

    Args:
        model: Model passed to both generation functions; switched to eval
            mode first.
        test_loader: Iterable of batches; each batch must provide an iterable
            of SMILES strings under ``batch['smiles']``.
        device: Not used by this function; kept so existing callers work.
        output_file: Path of the JSON file to write.
        max_length: Sequence length forwarded to both generators.
        num_orders: Number of evaluation orders for rejection sampling.
        temperature: Sampling temperature forwarded to both generators.

    Returns:
        A list with one dict per SMILES, holding the generated protein and
        the elapsed seconds for each method.
    """
    model.eval()

    # De-duplicate while keeping first-seen order (dicts preserve insertion
    # order), so runs iterate the SMILES in a reproducible sequence.
    first_seen = {}
    for batch in test_loader:
        first_seen.update(dict.fromkeys(batch['smiles']))

    unique_smiles = list(first_seen)
    print(f"Found {len(unique_smiles)} unique SMILES in test set")

    results = []

    for i, smiles in enumerate(tqdm(unique_smiles, desc="Generating proteins")):
        # Time only the generation call; printing happens outside the timer.
        print('autoregressive generations...')
        print(f"Generating protein for SMILES: {smiles}")
        start_time = time.perf_counter()
        ar_protein = generate_autoregressively(
            model, smiles, max_length=max_length, temperature=temperature, random_order=False
        )
        ar_time = time.perf_counter() - start_time
        print(ar_protein)

        print('rejection sampling generations...')
        print(f"Generating protein for SMILES: {smiles}")
        start_time = time.perf_counter()
        rs_protein = generate_with_rejection_sampling(
            model, smiles, max_length=max_length, num_orders=num_orders, temperature=temperature
        )
        rs_time = time.perf_counter() - start_time
        print(rs_protein)

        results.append({
            'SMILES': smiles,
            'Autoregressive': {
                'protein': ar_protein,
                'time_seconds': ar_time
            },
            'Rejection_Sampling': {
                'protein': rs_protein,
                'time_seconds': rs_time
            }
        })

        # Print progress occasionally
        if (i + 1) % 5 == 0:
            print(f"\nCompleted {i+1}/{len(unique_smiles)}")
            print(f"Example - SMILES: {smiles}")
            print(f"Autoregressive: {ar_protein[:50]}... ({ar_time:.2f}s)")
            print(f"Rejection Sampling: {rs_protein[:50]}... ({rs_time:.2f}s)")

    # Save results to JSON (written even when empty, so the file always
    # reflects this run).
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"\nGeneration complete!")
    if not results:
        # Nothing to average; avoid NaN means over empty lists.
        print("No SMILES were found in the test set; no timings to report.")
        return results

    mean_ar = np.mean([r['Autoregressive']['time_seconds'] for r in results])
    mean_rs = np.mean([r['Rejection_Sampling']['time_seconds'] for r in results])

    print(f"Average autoregressive generation time: {mean_ar:.2f}s")
    print(f"Average rejection sampling generation time: {mean_rs:.2f}s")
    if mean_rs > 0:
        print(f"Speed improvement: {mean_ar/mean_rs:.2f}x")
    else:
        print("Speed improvement: n/a (rejection sampling time was zero)")

    return results

In [None]:
# Run the autoregressive vs. rejection-sampling comparison on the test set and
# store the per-SMILES results in a JSON file.
results = evaluate_on_unique_smiles(
    model,
    test_loader,
    device,
    output_file="sigma_gpt_comparison_results.json",
)