### morgan fingerprints

In [3]:
!pip install rdkit



In [4]:
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
from sklearn.cluster import KMeans
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances

df_combined = pd.read_csv('sigma_data.csv')
df_combined = df_combined.drop(columns=[col for col in ['cluster', 'Cluster'] if col in df_combined.columns], errors='ignore')

def smiles_to_morgan_fp(smiles, radius=2, nBits=2048):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    return np.array(AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits))

fingerprints = []
valid_indices = []

for i, smiles in enumerate(df_combined['smiles']):
    try:
        fp = smiles_to_morgan_fp(smiles)
        if fp is not None:
            fingerprints.append(fp)
            valid_indices.append(i)
    except Exception as e:
        print(f"Error processing SMILES {smiles}: {e}")

X = np.array(fingerprints)
n_clusters = min(40, len(X))

kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=15)
cluster_labels = kmeans.fit_predict(X)

df_valid = df_combined.iloc[valid_indices].copy()
df_valid['cluster'] = cluster_labels

def subset_clusters(df, subset_fraction=0.5):
    return df.groupby('cluster', group_keys=False).apply(lambda x: x.sample(frac=subset_fraction, random_state=42))

df_subset = subset_clusters(df_valid, subset_fraction=0.5)
cluster_centers = kmeans.cluster_centers_

def split_by_cluster_dissimilarity(df, cluster_centers, val_size=0.3):
    unique_clusters = df['cluster'].unique()
    n_clusters = len(unique_clusters)
    n_val = int(n_clusters * val_size)
    distances = euclidean_distances(cluster_centers)
    np.random.seed(42)
    current_cluster = np.random.choice(unique_clusters)
    val_clusters = [current_cluster]
    remaining_clusters = set(unique_clusters) - {current_cluster}

    while len(val_clusters) < n_val:
        avg_distances = [(cluster, np.mean([distances[cluster, val_cluster] for val_cluster in val_clusters])) for cluster in remaining_clusters]
        next_cluster = max(avg_distances, key=lambda x: x[1])[0]
        val_clusters.append(next_cluster)
        remaining_clusters.remove(next_cluster)

    train_clusters = list(set(unique_clusters) - set(val_clusters))
    return df[df['cluster'].isin(train_clusters)], df[df['cluster'].isin(val_clusters)]

train_df, val_df = split_by_cluster_dissimilarity(df_subset, cluster_centers)

print(f"Training set: {len(train_df)} samples, {train_df['cluster'].nunique()} clusters")
print(f"Validation set: {len(val_df)} samples, {val_df['cluster'].nunique()} clusters")

train_clusters = set(train_df['cluster'].unique())
val_clusters = set(val_df['cluster'].unique())
print(f"Cluster overlap between train and val: {len(train_clusters.intersection(val_clusters))}")

train_df.to_csv('train_data.csv', index=False)
val_df.to_csv('val_data.csv', index=False)

train_fps = np.array([fingerprints[valid_indices.index(i)] for i in train_df.index if i in valid_indices])
val_fps = np.array([fingerprints[valid_indices.index(i)] for i in val_df.index if i in valid_indices])

def calculate_avg_tanimoto(fps):
    sum_sim, count = 0, 0
    for i in range(len(fps)):
        for j in range(i+1, len(fps)):
            intersection, union = np.sum(fps[i] & fps[j]), np.sum(fps[i] | fps[j])
            if union > 0:
                sum_sim += intersection / union
                count += 1
    return sum_sim / count if count > 0 else 0

if len(train_fps) > 1 and len(val_fps) > 1:
    train_diversity = calculate_avg_tanimoto(train_fps)
    val_diversity = calculate_avg_tanimoto(val_fps)
    print(f"Average train set similarity: {train_diversity:.4f}")
    print(f"Average validation set similarity: {val_diversity:.4f}")
    print(f"Validation set is {('more diverse' if val_diversity < train_diversity else 'less diverse')} than training set")


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  return fit_method(estimator, *args, **kwargs)
  return df.groupby('cluster', group_keys=False).apply(lambda x: x.sample(frac=subset_fraction, random_state=42))


Training set: 3980 samples, 17 clusters
Validation set: 536 samples, 6 clusters
Cluster overlap between train and val: 0
Average train set similarity: 0.4520
Average validation set similarity: 0.3789
Validation set is more diverse than training set


In [5]:
val_df

Unnamed: 0,Plastic Type,Enzyme Name,protein_sequence,smiles,protein_length,synthetic,cluster
2673,Nylon,Polyamidase,MTVAEYAAHDATGLAELVRRGQVSAAEVATAARTALEAVNPELCAV...,NCCCCNCCCC(=O)O,479,True,3
2642,Nylon,Nylon_hydrolase,MNTTPVHALTLITGGPAVDPAPRPAGEPAAGGPGKAAEDLVPLRSD...,NCCCCNCCCC(=O)O,355,True,3
2668,Nylon,Polyamidase,MDVAEYAAHDATGLADLIRAGQVSAAEVATAAKTALAAVEPELAAV...,NCCCCNCCCC(=O)O,479,True,3
2627,Nylon,Nylon_hydrolase,MNATPAHALTGIDSGIAVTPAPRLGGDEVFGGSGNAAFDLVPVAST...,NCCCCNCCCC(=O)O,355,True,3
2561,Nylon,Hydrolase,MNTTPVHALTGIDSGIAVDPAPRLAGPPVPGGPGDDAFDLAPGRST...,NCCCCNCCCC(=O)O,355,True,3
...,...,...,...,...,...,...,...
876,PBAT_PBS_PBSA_PCL_PET_PHB_PLA_PHA,Cutinase,MRRRRQAGTGARAGRARAIGVAVLALAVLVGAVGGVAGAEVSTAQD...,[*]OCCCCOC(=O)CC(=O)O[*].[*]OCCCCOC(=O)CC(=O)O...,304,True,16
883,PBAT_PBS_PBSA_PCL_PET_PHB_PLA_PHA,Cutinase,MRIRRSAGAGARARGRRAIVVMTTALAVLVGAVGGVAGAEVATAPD...,[*]OCCCCOC(=O)CC(=O)O[*].[*]OCCCCOC(=O)CC(=O)O...,304,True,16
893,PBAT_PBS_PBSA_PCL_PET_PHB_PLA_PHA,Cutinase,MRIRRQAGTGARRRMARAIGVYTTALAVLTGAVGGVAGAEVATAQD...,[*]OCCCCOC(=O)CC(=O)O[*].[*]OCCCCOC(=O)CC(=O)O...,304,True,16
911,PBAT_PBS_PBSA_PCL_PET_PHB_PLA_PHA,Cutinase,MRIRRSAETGARASRARRITVVTTAVAVLVGAVGGVAGAEVSDAAD...,[*]OCCCCOC(=O)CC(=O)O[*].[*]OCCCCOC(=O)CC(=O)O...,304,True,16


In [6]:
train_df

Unnamed: 0,Plastic Type,Enzyme Name,protein_sequence,smiles,protein_length,synthetic,cluster
3126,PBS,PBS_depolymerase,MHLSRGACDRPFKKETTMTHTFSVRALLAAGALLASAAVSAQTNPY...,[*]OCCCCOC(=O)CC(=O)O[*],304,True,0
3120,PBS,PBS_depolymerase,MTLPLTRPEIPFKEETTMRVHFSVRALLAAGALLASAAVSAQTNPY...,[*]OCCCCOC(=O)CC(=O)O[*],304,True,0
3430,PBSA,Lipase,MVRSMRSRVVAAAVALAMSGAALAGTTAATTAATATAATAATAATA...,[*]OCCCCOC(=O)CC(=O)O[*],370,True,0
3393,PBSA,Lipase,MNLVGHSQGGLTSRYVAAVAPDLVASVTTIGTPHRGSEAADFVQSV...,[*]OCCCCOC(=O)CC(=O)O[*],240,True,0
3381,PBSA,Lipase,MNLVGHSQGGLTSRYVAAVAPDLVASVTTIGTPHRGSEFADFVQSI...,[*]OCCCCOC(=O)CC(=O)O[*],240,True,0
...,...,...,...,...,...,...,...
3648,PLA_PHA_PES_PCL,Lipase,MTETLLYRDMNRAQLDAAYNNTAAVPDFPGIYAAYQARSAAFYASA...,[*]OC(C)C(=O)[*].[*]OC(C)CC(=O)[*].[*]OCCOC(=O...,275,True,22
3640,PLA_PHA_PES_PCL,Lipase,MTMTLLYRDMNQAQLDAAYNNTQAVPDFPGIYAALQARSASFYASA...,[*]OC(C)C(=O)[*].[*]OC(C)CC(=O)[*].[*]OCCOC(=O...,275,True,22
3704,PLA_PHA_PES_PCL,Lipase,MSTLSWVRTVNRTLGWVAPGLVARKMRALFMTPRKRLPRDWELPLL...,[*]OC(C)C(=O)[*].[*]OC(C)CC(=O)[*].[*]OCCOC(=O...,277,True,22
3646,PLA_PHA_PES_PCL,Lipase,MTPTLLYRDMNQAQLDAAYNNTQAVPDFPGIYAAFQARSASFYASA...,[*]OC(C)C(=O)[*].[*]OC(C)CC(=O)[*].[*]OCCOC(=O...,275,True,22


### setup

In [7]:
!pip install Levenshtein
!pip install einops
!pip install einops_exts
!pip install torch
!pip install transformers
!pip install tqdm
!pip install sentencepiece
# !pip install fair-esm

Collecting Levenshtein
  Downloading levenshtein-0.27.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.6 kB)
Collecting rapidfuzz<4.0.0,>=3.9.0 (from Levenshtein)
  Downloading rapidfuzz-3.12.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading levenshtein-0.27.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (161 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m161.7/161.7 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading rapidfuzz-3.12.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m3.1/3.1 MB[0m [31m43.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz, Levenshtein
Successfully installed Levenshtein-0.27.1 rapid

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import pandas as pd
import re
import math
import json
from tqdm import tqdm
from einops import rearrange, repeat
# import esm

# Set up GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# # Load ESM-2 model
# esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# batch_converter = alphabet.get_batch_converter()
# esm_model = esm_model.to(device)  # Move to GPU if available
# esm_model.eval()  # Set to evaluation mode


Using device: cuda


### data

In [9]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [10]:
def preprocess_snp_data(file_path):
    snp_df = pd.read_csv(file_path)

    # Basic preprocessing and length calculations
    snp_df['smiles_length'] = snp_df['smiles'].apply(len)
    snp_df['protein_length'] = snp_df['protein_sequence'].apply(len)

    return snp_df

def filter_datasets(dataset):
    return dataset[
        (dataset['smiles'].notna()) &
        (dataset['protein_sequence'].notna()) &
        (dataset['smiles_length'] > 0) &
        (dataset['protein_length'] > 0)
    ]

class ProteinGenerationDataset(Dataset):
    def __init__(self, dataframe, max_length):
        self.dataframe = dataframe
        self.max_length = max_length

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        return row['smiles'], row['protein_sequence']

def collate_fn(batch):
    """
    Custom collate function to handle padding within batches.
    Args:
        batch: List of tuples (smiles, protein)
    Returns:
        Padded and batched tensors
    """
    smiles, proteins = zip(*batch)

    # SMILES strings don't need padding as PolyBERT handles that internally
    smiles = list(smiles)

    # Get max length in this batch for proteins (not exceeding dataset max_length)
    max_protein_len = min(max(len(p) for p in proteins), max_length)

    # Pad proteins to max length in batch
    padded_proteins = []
    protein_masks = []

    for protein in proteins:
        if len(protein) > max_protein_len:
            padded = protein[:max_protein_len]
            mask = [1] * max_protein_len
        else:
            padded = protein + ' ' * (max_protein_len - len(protein))
            mask = [1] * len(protein) + [0] * (max_protein_len - len(protein))

        padded_proteins.append(padded)
        protein_masks.append(mask)

    return {
        'smiles': smiles,
        'proteins': padded_proteins,
        'protein_masks': torch.tensor(protein_masks, dtype=torch.bool)
    }

### utils

In [11]:
# Model Components
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:x.size(0)]

class DoublePositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # Use the full embedding dimension divided into two halves
        self.d_model = d_model
        half_dim = d_model // 2

        # Create position encodings for both input and output positions
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, half_dim, 2) * (-math.log(10000.0) / half_dim))

        # Input position encodings
        pe_input = torch.zeros(max_len, half_dim)
        pe_input[:, 0::2] = torch.sin(position * div_term)
        pe_input[:, 1::2] = torch.cos(position * div_term)

        # Output position encodings
        pe_output = torch.zeros(max_len, half_dim)
        pe_output[:, 0::2] = torch.sin(position * div_term)
        pe_output[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe_input', pe_input)
        self.register_buffer('pe_output', pe_output)

    def forward(self, x, input_positions, output_positions):
        batch_size, seq_length, _ = x.shape

        # Create a tensor of zeros with the same shape as the input
        pos_encoding = torch.zeros_like(x)

        # For each item in the batch
        for b in range(batch_size):
            for t in range(seq_length):
                # Get the input and output positions for this token
                input_pos = input_positions[b, t] if input_positions is not None else t
                output_pos = output_positions[b, t] if output_positions is not None else t

                if input_pos < self.pe_input.size(0) and output_pos < self.pe_output.size(0):
                    # Fill the first half with input position encoding
                    pos_encoding[b, t, :self.d_model//2] = self.pe_input[input_pos]
                    # Fill the second half with output position encoding
                    pos_encoding[b, t, self.d_model//2:] = self.pe_output[output_pos]

        return x + pos_encoding

class PerceiverAttention(nn.Module):
    def __init__(self, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        x: [batch_size, seq_len_x, dim]
        latents: [batch_size, seq_len_l, dim]
        """
        batch_size = x.shape[0]

        x = self.norm_media(x)
        latents = self.norm_latents(latents)

        # Ensure latents has correct batch size
        if latents.size(0) != batch_size:
            latents = latents.expand(batch_size, -1, -1)

        q = self.to_q(latents)
        q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads)
        q = q * self.scale

        # Ensure proper concatenation
        kv_input = torch.cat((x, latents), dim=1)  # concatenate along sequence dimension
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
        k = rearrange(k, 'b n (h d) -> b h n d', h=self.heads)
        v = rearrange(v, 'b n (h d) -> b h n d', h=self.heads)

        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)
        attn = sim.softmax(dim=-1)
        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')

        return self.to_out(out)

class GatedCrossAttentionBlock(nn.Module):
    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.attn = PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads)
        self.attn_gate = nn.Parameter(torch.tensor([0.]))
        self.ff = FeedForward(dim, mult=ff_mult)
        self.ff_gate = nn.Parameter(torch.tensor([0.]))

    def forward(self, x, media):
        """
        x: [batch_size, seq_len_x, dim]
        media: [batch_size, seq_len_m, dim]
        """
        batch_size = x.shape[0]
        target_batch_size = media.size(0)

        # Handle batch size mismatch
        if batch_size > target_batch_size:
            media = media.expand(batch_size, -1, -1)
        elif batch_size < target_batch_size:
            x = x.expand(target_batch_size, -1, -1)

        gate = self.attn_gate.tanh()
        x = self.attn(media, x) * gate + x
        x = self.ff(x) * self.ff_gate.tanh() + x
        return x

class PerceiverResampler(nn.Module):
    def __init__(self, dim, depth, dim_head=64, heads=8, num_latents=64):
        super().__init__()
        # Initialize latents without batch dimension
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                FeedForward(dim=dim)
            ]))

    def forward(self, x):
        batch_size = x.shape[0]
        # Expand latents to batch size
        latents = repeat(self.latents, 'n d -> b n d', b=batch_size)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        return latents

class FeedForward(nn.Module):
    def __init__(self, dim, mult=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * mult, bias=False),
            nn.GELU(),
            nn.Linear(dim * mult, dim, bias=False)
        )

    def forward(self, x):
        return self.net(x)

# class PerceiverResampler(nn.Module):
#     def __init__(self, dim, depth, dim_head=64, heads=8, num_latents=64):
#         super().__init__()
#         self.latents = nn.Parameter(torch.randn(num_latents, dim))
#         self.layers = nn.ModuleList([])

#         for _ in range(depth):
#             self.layers.append(nn.ModuleList([
#                 PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
#                 FeedForward(dim=dim)
#             ]))

#     def forward(self, x):
#         latents = repeat(self.latents, 'n d -> b n d', b=x.shape[0])

#         for attn, ff in self.layers:
#             latents = attn(x, latents) + latents
#             latents = ff(latents) + latents

#         return latents

# class GatedCrossAttentionBlock(nn.Module):
#     def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
#         super().__init__()
#         self.attn = PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads)
#         self.attn_gate = nn.Parameter(torch.tensor([0.]))
#         self.ff = FeedForward(dim, mult=ff_mult)
#         self.ff_gate = nn.Parameter(torch.tensor([0.]))

#     def forward(self, x, media):
#         gate = self.attn_gate.tanh()
#         x = self.attn(media, x) * gate + x
#         x = self.ff(x) * self.ff_gate.tanh() + x
#         return x

### PolyBert Encoder

In [12]:
from transformers import AutoTokenizer, AutoModel
import torch


In [13]:
# class PolyBERTEncoder(nn.Module):
#     def __init__(self, output_dim):
#         super().__init__()
#         self.polybert = AutoModel.from_pretrained('kuelumbus/polyBERT')
#         self.tokenizer = AutoTokenizer.from_pretrained('kuelumbus/polyBERT')
#         self.output_dim = output_dim
#         # Add a projection layer to match the required dimension
#         self.projection = nn.Linear(self.polybert.config.hidden_size, output_dim)

#     def mean_pooling(self, model_output, attention_mask):
#         token_embeddings = model_output[0]
#         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
#         return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

#     def forward(self, smiles_strings):
#         # Tokenize the SMILES strings
#         encoded_input = self.tokenizer(smiles_strings,
#                                      padding=True,
#                                      truncation=True,
#                                      return_tensors='pt').to(next(self.polybert.parameters()).device)

#         # Get PolyBERT embeddings
#         with torch.no_grad():
#             model_output = self.polybert(**encoded_input)

#         # Debug prints
#         print("Model Output Keys:", model_output.keys())  # Check available keys
#         # print("Last Hidden State:", model_output.last_hidden_state)
#         print("Last Hidden State Shape:", model_output.last_hidden_state.shape)

#         # Pool the embeddings
#         pooled_output = self.mean_pooling(model_output, encoded_input['attention_mask'])

#         # print("Pooled Output:", pooled_output)
#         print("Pooled Output Shape:", pooled_output.shape)

#         # Project to required dimension
#         projected_output = self.projection(pooled_output)

#         return projected_output

In [14]:
class PolyBERTEncoder(nn.Module):
    def __init__(self, output_dim):
        super().__init__()
        self.polybert = AutoModel.from_pretrained('kuelumbus/polyBERT')
        self.tokenizer = AutoTokenizer.from_pretrained('kuelumbus/polyBERT')
        self.output_dim = output_dim
        # Project each token embedding to required dimension
        self.projection = nn.Linear(self.polybert.config.hidden_size, output_dim)

    def forward(self, smiles_strings):
        # Tokenize the SMILES strings
        encoded_input = self.tokenizer(smiles_strings,
                                     padding=True,
                                     truncation=True,
                                     return_tensors='pt').to(next(self.polybert.parameters()).device)

        # Get PolyBERT embeddings
        with torch.no_grad():
            model_output = self.polybert(**encoded_input)

        # Debug prints
        # print("Model Output Keys:", model_output.keys())
        # print("Last Hidden State Shape:", model_output.last_hidden_state.shape)  # [batch_size, seq_len, hidden_size]

        # Get sequence embeddings
        sequence_embeddings = model_output.last_hidden_state

        # Project each token embedding to required dimension
        projected_output = self.projection(sequence_embeddings)  # [batch_size, seq_len, output_dim]
        # print("Projected Output Shape:", projected_output.shape)

        return projected_output

### ProtFlamingo

In [15]:
import torch.nn.functional as F


In [16]:
class SigmaProtFlamingo(nn.Module):
    def __init__(self, model_path, max_len, cross_attn_every=3, dim_head=64, heads=8, perceiver_depth=2, perceiver_num_latents=64):
        super().__init__()

        self.protGPT2_model = GPT2LMHeadModel.from_pretrained(model_path)
        self.protGPT2_tokenizer = GPT2Tokenizer.from_pretrained(model_path)
        self.max_len = max_len

        if self.protGPT2_tokenizer.pad_token is None:
            self.protGPT2_tokenizer.pad_token = self.protGPT2_tokenizer.eos_token
            self.protGPT2_model.config.pad_token_id = self.protGPT2_model.config.eos_token_id

        self.cross_attn_every = cross_attn_every

        # PolyBERT encoder for SMILES strings
        self.polybert_encoder = PolyBERTEncoder(self.protGPT2_model.config.n_embd)

        # Replace single positional encoding with double positional encoding
        self.positional_encoding = DoublePositionalEncoding(self.protGPT2_model.config.n_embd, max_len=max_len)

        # Single perceiver resampler for SMILES embeddings
        self.smiles_perceiver = PerceiverResampler(
            dim=self.protGPT2_model.config.n_embd,
            depth=perceiver_depth,
            dim_head=dim_head,
            heads=heads,
            num_latents=perceiver_num_latents
        )

        # Cross attention layers
        num_gpt_layers = len(self.protGPT2_model.transformer.h)
        self.cross_attn = nn.ModuleList([
            GatedCrossAttentionBlock(dim=self.protGPT2_model.config.n_embd, dim_head=dim_head, heads=heads)
            for _ in range(num_gpt_layers)
        ])

        # Combine GPT layers with cross attention
        self.layers = nn.ModuleList()
        for i, block in enumerate(self.protGPT2_model.transformer.h):
            self.layers.append(block)
            if i % cross_attn_every == 0 and i != 0:
                self.layers.append(GatedCrossAttentionBlock(dim=self.protGPT2_model.config.n_embd, dim_head=dim_head, heads=heads))

    def forward(self, smiles_strings, order=None, targets=None, optimize=False, kv_cache=None, burst=False):
        device = next(self.parameters()).device

        # Get SMILES embeddings through PolyBERT
        smiles_embeddings = self.polybert_encoder(smiles_strings)
        processed_smiles = self.smiles_perceiver(smiles_embeddings)

        # Initialize with start token
        gpt_input = self.protGPT2_tokenizer.encode_plus(
            "<|endoftext|>",
            return_tensors="pt",
            padding='max_length',
            max_length=self.max_len,
            truncation=True
        ).to(device)

        input_ids = gpt_input.input_ids.long()
        seq_length = input_ids.size(1)
        batch_size = 1 if isinstance(smiles_strings, str) else len(smiles_strings)

        hidden_states = self.protGPT2_model.transformer.wte(input_ids)

        # If no order is provided, use left-to-right
        if order is None:
            order = torch.arange(seq_length, device=device).unsqueeze(0).repeat(batch_size, 1)

        # Make sure order is the right length
        if order.size(1) > seq_length:
            order = order[:, :seq_length]
        elif order.size(1) < seq_length:
            # Pad order if needed
            padding = torch.arange(order.size(1), seq_length, device=device).unsqueeze(0).repeat(batch_size, 1)
            order = torch.cat([order, padding], dim=1)

        # Map the input tokens according to the order
        # When using random order, we need to reshuffle the input tokens
        if not optimize and not burst:  # Only shuffle during training
            reordered_input_ids = torch.zeros_like(input_ids)
            for b in range(batch_size):
                # Reorder the input tokens according to the order
                reordered_input_ids[b] = input_ids[b, order[b]]

            # Re-embed with reordered tokens
            hidden_states = self.protGPT2_model.transformer.wte(reordered_input_ids)

        # Get input and output positions from the order
        # Input positions: the current position in the order
        # Output positions: the next position in the order
        input_positions = order
        # Shift the order by 1 to get output positions (target positions)
        output_positions = torch.roll(order, -1, dims=1)
        # The last position wraps to the first position
        output_positions[:, -1] = order[:, 0]

        # Apply double positional encoding
        hidden_states = self.positional_encoding(hidden_states, input_positions, output_positions)

        # Create attention mask based on the order
        attention_mask = gpt_input.attention_mask
        num_heads = self.protGPT2_model.config.n_head

        # Create 4D attention mask [batch_size, num_heads, seq_length, seq_length]
        attention_mask = attention_mask.view(batch_size, 1, 1, seq_length)
        attention_mask = attention_mask.expand(batch_size, num_heads, seq_length, seq_length)
        attention_mask = attention_mask.to(dtype=hidden_states.dtype)

        # Create causal mask based on the order
        # A token at position i can attend to tokens at positions j where order[j] <= order[i]
        # Vectorized causal mask creation
        seq_indices = torch.arange(seq_length, device=device)
        expanded_seq_indices_i = seq_indices.unsqueeze(1).expand(seq_length, seq_length)
        expanded_seq_indices_j = seq_indices.unsqueeze(0).expand(seq_length, seq_length)

        causal_mask = torch.zeros((batch_size, seq_length, seq_length), device=device)
        for b in range(batch_size):
            # Get order for this batch
            order_b = order[b]
            # Get order values at positions i and j
            order_i = order_b[expanded_seq_indices_i]
            order_j = order_b[expanded_seq_indices_j]
            # Create mask where order_j <= order_i
            causal_mask[b] = (order_j <= order_i).float()

        # Reshape causal_mask to match attention_mask and combine them
        causal_mask = causal_mask.unsqueeze(1)  # [batch_size, 1, seq_length, seq_length]
        combined_mask = attention_mask * causal_mask

        for i, layer in enumerate(self.layers):
            if isinstance(layer, GatedCrossAttentionBlock):
                hidden_states = layer(hidden_states, processed_smiles)
            else:
                hidden_states = layer(hidden_states, attention_mask=combined_mask)[0]

        # Get logits
        logits = self.protGPT2_model.lm_head(hidden_states)

        if targets is None:
            if optimize:
                # inference-time mini-optimization: only forward the lm_head on the very last position
                return logits[:, [-1], :], None
            return logits, None

        # Compute loss against the targets
        # If targets are provided in original order, we need to shuffle them to match our order
        if targets is not None:
            shuffled_targets = torch.zeros_like(targets)
            for b in range(batch_size):
                # Reorder the targets according to the order
                shuffled_targets[b] = targets[b, order[b]]

            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                shuffled_targets.view(-1),
                ignore_index=-1
            )
        else:
            loss = None

        return logits, loss

    def custom_generate(self, smiles_string, max_length=200):
        device = next(self.parameters()).device

        # Get SMILES embeddings
        smiles_embeddings = self.polybert_encoder(smiles_string)
        processed_smiles = self.smiles_perceiver(smiles_embeddings)

        # Initialize with start token
        input_ids = torch.tensor([[self.protGPT2_tokenizer.bos_token_id]]).to(device)

        # Autoregressive generation
        for _ in range(max_length):
            inputs_embeds = self.protGPT2_model.transformer.wte(input_ids)
            inputs_embeds = self.positional_encoding(inputs_embeds)

            hidden_states = inputs_embeds
            cross_attn_idx = 0

            for i, layer in enumerate(self.layers):
                if isinstance(layer, GatedCrossAttentionBlock):
                    hidden_states = layer(hidden_states, processed_smiles)
                    cross_attn_idx += 1
                else:
                    hidden_states = layer(hidden_states, attention_mask=None)[0]

            next_token_logits = self.protGPT2_model.lm_head(hidden_states[:, -1, :])
            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)

            input_ids = torch.cat([input_ids, next_token], dim=-1)

            if next_token.item() == self.protGPT2_tokenizer.eos_token_id:
                break

        return self.protGPT2_tokenizer.decode(input_ids[0], skip_special_tokens=True)

    def generate(self, smiles_string, max_length=50):
        return self.custom_generate(smiles_string, max_length)

    def state_dict(self):
        state_dict = super().state_dict()
        state_dict['smiles_perceiver'] = self.smiles_perceiver.state_dict()
        state_dict['cross_attn'] = self.cross_attn.state_dict()
        state_dict['polybert_encoder'] = self.polybert_encoder.state_dict()
        return state_dict

    def load_state_dict(self, state_dict):
        smiles_perceiver_state = state_dict.pop('smiles_perceiver')
        cross_attn_state = state_dict.pop('cross_attn')
        polybert_encoder_state = state_dict.pop('polybert_encoder')

        super().load_state_dict(state_dict)

        self.smiles_perceiver.load_state_dict(smiles_perceiver_state)
        self.cross_attn.load_state_dict(cross_attn_state)
        self.polybert_encoder.load_state_dict(polybert_encoder_state)

In [17]:
import random

### training

In [18]:
import torch
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
import json
import pandas as pd
import os
import matplotlib.pyplot as plt

In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
import os
import random
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import math
from nltk.translate.bleu_score import sentence_bleu
import Levenshtein



def print_model_structure(model):
    print("\n===== MODEL STRUCTURE ANALYSIS =====")

    # 1. Check which layers have cross-attention
    cross_attn_locations = []
    for i, layer in enumerate(model.layers):
        if isinstance(layer, GatedCrossAttentionBlock):
            cross_attn_locations.append(i)

    print(f"\nüìå CROSS-ATTENTION LAYERS:")
    print(f"  Total cross-attention blocks: {len(cross_attn_locations)}")
    print(f"  Located at positions: {cross_attn_locations}")

    # 2. Check parameter freezing by group
    frozen_info = {}
    trainable_info = {}
    total_frozen = 0
    total_trainable = 0

    # Get lm_head status if it exists
    lm_head_status = "NOT FOUND"
    if hasattr(model.protGPT2_model, 'lm_head'):
        if isinstance(model.protGPT2_model.lm_head, GatedCrossAttentionBlock):
            lm_head_status = "REPLACED with GatedCrossAttentionBlock"
        else:
            lm_head_trainable = all(p.requires_grad for p in model.protGPT2_model.lm_head.parameters())
            lm_head_status = "TRAINABLE" if lm_head_trainable else "FROZEN"

            # Count parameters
            lm_head_params = sum(p.numel() for p in model.protGPT2_model.lm_head.parameters())
            if lm_head_trainable:
                total_trainable += lm_head_params
                trainable_info["lm_head"] = lm_head_params
            else:
                total_frozen += lm_head_params
                frozen_info["lm_head"] = lm_head_params

    # Check transformer layer status
    for i in range(36):  # Assuming 36 transformer layers
        layer_name = f"transformer.h.{i}"

        # Find parameters for this layer
        layer_params = []
        for name, param in model.protGPT2_model.named_parameters():
            if f"transformer.h.{i}." in name:
                layer_params.append(param)

        if layer_params:
            layer_trainable = all(p.requires_grad for p in layer_params)
            layer_status = "TRAINABLE" if layer_trainable else "FROZEN"

            # Count parameters
            layer_param_count = sum(p.numel() for p in layer_params)
            if layer_trainable:
                total_trainable += layer_param_count
                trainable_info[layer_name] = layer_param_count
            else:
                total_frozen += layer_param_count
                frozen_info[layer_name] = layer_param_count

    # Print layer freezing status
    print(f"\nüìå LAYER FREEZING STATUS:")
    for i in range(36):
        layer_name = f"transformer.h.{i}"
        if layer_name in trainable_info:
            print(f"  Layer {i:2d}: ‚úÖ TRAINABLE ({trainable_info[layer_name]:,} params)")
        elif layer_name in frozen_info:
            print(f"  Layer {i:2d}: ‚ùÑÔ∏è FROZEN ({frozen_info[layer_name]:,} params)")
        else:
            print(f"  Layer {i:2d}: ‚ö†Ô∏è NOT FOUND")

    print(f"\n  LM Head: {lm_head_status}")

    # Print overall stats
    total_params = total_trainable + total_frozen
    print(f"\nüìå PARAMETER SUMMARY:")
    print(f"  Total parameters: {total_params:,}")
    print(f"  Trainable parameters: {total_trainable:,} ({total_trainable/total_params:.2%})")
    print(f"  Frozen parameters: {total_frozen:,} ({total_frozen/total_params:.2%})")

    # Cross-attention specific info
    print(f"\nüìå CROSS-ATTENTION DETAILS:")
    cross_attn_count = 0
    for i, layer in enumerate(model.layers):
        if isinstance(layer, GatedCrossAttentionBlock):
            cross_attn_count += 1
            cross_attn_params = sum(p.numel() for p in layer.parameters())
            trainable = all(p.requires_grad for p in layer.parameters())
            print(f"  Cross-Attention #{cross_attn_count} (index {i}): {'‚úÖ TRAINABLE' if trainable else '‚ùÑÔ∏è FROZEN'} ({cross_attn_params:,} params)")

    print("\n===================================\n")

# Add this function to model class to get outputs without computing loss internally
def add_forward_without_loss_to_model(model):
    """
    Adds a new method to the model to get outputs without computing loss.
    Call this function before starting training.
    """
    def forward_without_loss(self, smiles_strings, targets=None):
        """Get model outputs without computing loss internally for SigmaProtFlamingo"""
        device = next(self.parameters()).device

        # Get SMILES embeddings through PolyBERT
        smiles_embeddings = self.polybert_encoder(smiles_strings)
        processed_smiles = self.smiles_perceiver(smiles_embeddings)

        # Initialize with start token
        gpt_input = self.protGPT2_tokenizer(
            "<|endoftext|>",
            return_tensors="pt",
            padding='max_length',
            max_length=self.max_len,
            truncation=True
        ).to(device)

        input_ids = gpt_input.input_ids.long()
        seq_length = input_ids.size(1)
        batch_size = 1 if isinstance(smiles_strings, str) else len(smiles_strings)

        hidden_states = self.protGPT2_model.transformer.wte(input_ids)

        # Use left-to-right order for training with AAR focus
        order = torch.arange(seq_length, device=device).unsqueeze(0).repeat(batch_size, 1)

        # Apply double positional encoding
        input_positions = order
        output_positions = torch.roll(order, -1, dims=1)
        output_positions[:, -1] = order[:, 0]

        hidden_states = self.positional_encoding(hidden_states, input_positions, output_positions)

        # Create attention mask
        attention_mask = gpt_input.attention_mask
        num_heads = self.protGPT2_model.config.n_head

        # Create 4D attention mask [batch_size, num_heads, seq_length, seq_length]
        attention_mask = attention_mask.view(batch_size, 1, 1, seq_length)
        attention_mask = attention_mask.expand(batch_size, num_heads, seq_length, seq_length)
        attention_mask = attention_mask.to(dtype=hidden_states.dtype)

        # Create causal mask based on left-to-right order
        seq_indices = torch.arange(seq_length, device=device)
        expanded_seq_indices_i = seq_indices.unsqueeze(1).expand(seq_length, seq_length)
        expanded_seq_indices_j = seq_indices.unsqueeze(0).expand(seq_length, seq_length)

        causal_mask = torch.zeros((batch_size, seq_length, seq_length), device=device)
        for b in range(batch_size):
            order_b = order[b]
            order_i = order_b[expanded_seq_indices_i]
            order_j = order_b[expanded_seq_indices_j]
            causal_mask[b] = (order_j <= order_i).float()

        # Reshape causal_mask to match attention_mask and combine them
        causal_mask = causal_mask.unsqueeze(1)  # [batch_size, 1, seq_length, seq_length]
        combined_mask = attention_mask * causal_mask

        # Process through all layers
        for i, layer in enumerate(self.layers):
            if isinstance(layer, GatedCrossAttentionBlock):
                hidden_states = layer(hidden_states, processed_smiles)
            else:
                hidden_states = layer(hidden_states, attention_mask=combined_mask)[0]

        # Get logits without computing loss
        logits = self.protGPT2_model.lm_head(hidden_states)

        return logits

    # Add the method to the model
    model.forward_without_loss = forward_without_loss.__get__(model, type(model))
    return model

def repetition_penalty_loss(predicted_token_ids, target_token_ids, pad_token_id):
    """
    Penalizes repeated amino acid tokens in a sequence.
    Returns a penalty value for consecutive repeated tokens.
    """
    # Ignore padding tokens
    mask = target_token_ids != pad_token_id

    # Shift the sequence by one token to compare
    prev_tokens = predicted_token_ids[:, :-1]
    curr_tokens = predicted_token_ids[:, 1:]

    # Compute a repetition mask (1 if consecutive tokens are the same, 0 otherwise)
    repetition_mask = (prev_tokens == curr_tokens).float()

    # Apply the mask to only count valid (non-pad) regions
    valid_mask = mask[:, 1:].float()
    repetition_penalty = (repetition_mask * valid_mask).sum() / (valid_mask.sum() + 1e-8)

    return repetition_penalty

def sequence_diversity_loss(predicted_logits, pad_token_id, vocab_size):
    """
    Encourages diversity in token distribution across the sequence.
    Penalizes sequences that use a limited set of amino acids.
    """
    # Get predicted token ids
    predicted_token_ids = torch.argmax(predicted_logits, dim=-1)

    # Create mask to ignore padding
    mask = predicted_token_ids != pad_token_id

    # Count frequency of each token
    batch_size = predicted_token_ids.size(0)
    token_counts = torch.zeros(batch_size, vocab_size, device=predicted_logits.device)

    # For each sequence in the batch
    diversity_loss = 0.0
    for b in range(batch_size):
        # Count tokens in this sequence (excluding padding)
        seq_tokens = predicted_token_ids[b][mask[b]]
        if len(seq_tokens) == 0:
            continue

        # Count each token
        for t in range(vocab_size):
            token_counts[b, t] = (seq_tokens == t).sum()

        # Normalize to get probability distribution
        token_probs = token_counts[b] / (len(seq_tokens) + 1e-8)

        # Calculate entropy (higher is more diverse)
        # We negate entropy to make it a loss (lower is better)
        non_zero_probs = token_probs[token_probs > 0]
        entropy = -torch.sum(non_zero_probs * torch.log(non_zero_probs + 1e-8))

        # Add to batch loss with negation (we want to maximize entropy = diversity)
        diversity_loss += -entropy

    # Average over batch
    return diversity_loss / batch_size

def sequence_similarity_loss(predicted_tokens, target_tokens, tokenizer):
    """
    Calculate a loss based on sequence-level similarity.
    Uses normalized Levenshtein distance.
    """
    batch_size = predicted_tokens.shape[0]
    total_distance = 0.0

    for i in range(batch_size):
        # Convert token IDs to strings, skipping padding and special tokens
        pred_seq = tokenizer.decode(predicted_tokens[i], skip_special_tokens=True)
        target_seq = tokenizer.decode(target_tokens[i], skip_special_tokens=True)

        # Calculate normalized Levenshtein distance
        if len(target_seq) > 0:
            distance = Levenshtein.distance(pred_seq, target_seq) / max(len(target_seq), 1)
            total_distance += distance

    return total_distance / batch_size


def apply_token_masking(input_ids, tokenizer, mask_prob=0.15):
    masked_input_ids = input_ids.clone()
    labels = input_ids.clone()

    # Create mask for tokens that can be masked (exclude padding)
    padding_mask = input_ids != tokenizer.pad_token_id

    # Generate random mask with specified probability
    random_mask = torch.rand(input_ids.shape, device=input_ids.device) < mask_prob

    # Only apply masking to non-padding tokens
    mask_indices = padding_mask & random_mask

    # Replace masked tokens with mask token
    if tokenizer.mask_token_id is not None:
        mask_token_id = tokenizer.mask_token_id
    else:
        # If model doesn't have a mask token, use a special token or UNK token
        mask_token_id = tokenizer.unk_token_id

    masked_input_ids[mask_indices] = mask_token_id

    # For the loss computation, we only want to predict the masked tokens
    # Set labels to -100 for non-masked tokens (CrossEntropyLoss will ignore these)
    labels[~mask_indices] = -100

    return masked_input_ids, labels


def train_with_improved_aar_objective(model, train_loader, val_loader, num_epochs, device,
                           curriculum_steps=0, l2_reg=1e-5, sample_smiles=None):
    # Add the forward_without_loss method to the model
    model = add_forward_without_loss_to_model(model)

    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-5, weight_decay=l2_reg)

    # Use label smoothing to prevent overconfident predictions
    criterion = nn.CrossEntropyLoss(
        ignore_index=-100,  # Changed from pad_token_id to -100 for MLM
        reduction='none',
        label_smoothing=0.1
    )

    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    # Hyperparameters for different loss components
    lambda_rep = 0.2       # Weight for repetition penalty
    # lambda_div = 0.1       # Weight for diversity loss
    lambda_seq = 0.15      # Weight for sequence-level similarity loss

    # MLM parameters
    use_mlm = True
    mlm_prob = 0.15

    loss_log = []
    new_checkpoint_dir = "/content/drive/MyDrive/classes+projects/plastic_enzyme_project/2024/codes/sigma_ckpt_full_dataset_enhanced_metrics"
    os.makedirs(new_checkpoint_dir, exist_ok=True)

    best_val_aar = 0.0  # Track best validation AAR instead of loss

    vocab_size = model.protGPT2_model.config.vocab_size

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        total_correct = 0
        total_tokens = 0

        # Calculate curriculum ratio (if using curriculum learning)
        curriculum_ratio = min(1.0, epoch / (num_epochs / 2)) if curriculum_steps > 0 else 1.0
        print(f"Curriculum ratio: {curriculum_ratio:.2f}")

        for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")):
            smiles_strings = batch['smiles']
            proteins = batch['proteins']
            optimizer.zero_grad()

            target_encoding = model.protGPT2_tokenizer(
                proteins,
                return_tensors="pt",
                padding='max_length',
                max_length=model.max_len,
                truncation=True
            ).to(device)

            # Alternate between MLM and autoregressive training
            use_masking_this_batch = use_mlm and (epoch % 2 == 0 or batch_idx % 2 == 0)
            if use_masking_this_batch:
                masked_input_ids, mlm_labels = apply_token_masking(
                    target_encoding.input_ids,
                    model.protGPT2_tokenizer,
                    mask_prob=mlm_prob
                )

                # Get model outputs with masked input
                outputs = model.forward_without_loss(smiles_strings, masked_input_ids)

                # Calculate loss only on masked positions
                token_loss = criterion(outputs.view(-1, outputs.size(-1)), mlm_labels.view(-1))
                token_loss = token_loss.view(outputs.size(0), -1)

                # Create mask for valid tokens (not -100)
                valid_mask = mlm_labels != -100
            else:
                # Original approach without masking
                outputs = model.forward_without_loss(smiles_strings, target_encoding.input_ids)
                token_loss = criterion(outputs.view(-1, outputs.size(-1)), target_encoding.input_ids.view(-1))
                token_loss = token_loss.view(outputs.size(0), -1)
                valid_mask = target_encoding.input_ids != model.protGPT2_tokenizer.pad_token_id

            # Calculate token-level accuracy
            predicted_token_ids = torch.argmax(outputs, dim=-1)

            # Ensure predicted_token_ids has the same shape
            if predicted_token_ids.shape[1] < model.max_len:
                predicted_token_ids = torch.nn.functional.pad(
                    predicted_token_ids,
                    (0, model.max_len - predicted_token_ids.shape[1]),
                    value=model.protGPT2_tokenizer.pad_token_id
                )

            # For MLM, we only check accuracy on masked positions
            if use_mlm:
                original_ids = target_encoding.input_ids.clone()
                mask_positions = mlm_labels != -100

                # Compute accuracy only on masked positions
                token_correct = (predicted_token_ids == original_ids) & mask_positions
                pad_mask = mask_positions
            else:
                # Original accuracy calculation
                pad_mask = target_encoding.input_ids != model.protGPT2_tokenizer.pad_token_id
                token_correct = (predicted_token_ids == target_encoding.input_ids) & pad_mask

            # Create a weighting mask that balances correct and incorrect predictions
            incorrect_weight = 1.5 + curriculum_ratio * 0.5  # Weight increases from 1.5 to 2.0 over training
            weight_mask = (~token_correct).float() * incorrect_weight + 1.0
            weight_mask = weight_mask * valid_mask.float()  # Zero out padding or non-masked tokens

            # Apply the weighting mask to the token losses
            weighted_loss = (token_loss * weight_mask).sum() / (weight_mask.sum() + 1e-8)

            # Calculate repetition penalty
            rep_penalty = repetition_penalty_loss(
                predicted_token_ids,
                target_encoding.input_ids,
                model.protGPT2_tokenizer.pad_token_id
            )

            # # Calculate diversity loss
            # div_loss = sequence_diversity_loss(
            #     outputs,
            #     model.protGPT2_tokenizer.pad_token_id,
            #     vocab_size
            # )

            # Calculate sequence-level similarity loss
            seq_loss = sequence_similarity_loss(
                predicted_token_ids,
                target_encoding.input_ids,
                model.protGPT2_tokenizer
            )

            # Combine all losses with appropriate weights
            total_loss_val = (
                weighted_loss +
                lambda_rep * rep_penalty +
                # lambda_div * div_loss +
                lambda_seq * seq_loss
            )

            # Print all loss components every 50 batches for monitoring
            if batch_idx % 50 == 0:
                print(f"  Loss components: CE={weighted_loss:.4f}, Rep={rep_penalty:.4f}, Seq={seq_loss:.4f}")

            # Use the combined loss for backpropagation
            total_loss_val.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # For logging purposes, calculate the standard loss
            standard_loss = token_loss[valid_mask].mean().item() if valid_mask.sum() > 0 else 0
            total_loss += standard_loss

            # Calculate AAR metrics for logging
            correct = token_correct.sum().item()
            total = pad_mask.sum().item()
            total_correct += correct
            total_tokens += total

            # Print batch statistics occasionally
            if batch_idx % 10 == 0:
                batch_aar = (correct / total * 100) if total > 0 else 0
                print(f"  Batch {batch_idx}: Loss={standard_loss:.4f}, AAR={batch_aar:.2f}%")

        avg_loss = total_loss / len(train_loader)
        amino_acid_recovery = total_correct / total_tokens * 100

        # Prevent overflow when calculating perplexity
        try:
            perplexity = math.exp(avg_loss)
        except OverflowError:
            perplexity = float('inf')  # Return infinity if the loss is too high

        print(f"\nEpoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")
        print(f"Perplexity: {perplexity}")
        print(f"Amino Acid Recovery: {amino_acid_recovery:.2f}%")

        val_loss, val_perplexity, val_aar, val_results = validate_with_enhanced_metrics(model, val_loader, nn.CrossEntropyLoss(ignore_index=model.protGPT2_tokenizer.pad_token_id), device)
        print(f"Validation Loss: {val_loss:.4f}, Perplexity: {val_perplexity:.4f}, Amino Acid Recovery: {val_aar:.2f}%")

        # Save validation results
        import json
        json.dump(val_results, open("validation_results.json", "w"), indent=4)

        loss_log.append({
            'epoch': epoch+1,
            'train_loss': avg_loss,
            'train_perplexity': perplexity,
            'train_accuracy': amino_acid_recovery,
            'val_loss': val_loss,
            'val_perplexity': val_perplexity,
            'val_accuracy': val_aar
        })

        checkpoint_path = os.path.join(new_checkpoint_dir, f"sigma_epoch_{epoch+1}.pth")

        # Save checkpoint based on better AAR, but also consider perplexity
        if val_aar > best_val_aar:
            best_val_aar = val_aar
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Checkpoint saved at {checkpoint_path} (Validation AAR improved to {best_val_aar:.2f}%)")

        scheduler.step()

    loss_df = pd.DataFrame(loss_log)
    loss_df.to_csv("/content/drive/MyDrive/classes+projects/plastic_enzyme_project/2024/codes/sigma_improved_aar_log.csv", index=False)

    # Plot training metrics
    plt.figure(figsize=(15, 10))

    # Filter out invalid values for plotting
    loss_df['train_perplexity_plot'] = loss_df['train_perplexity'].apply(lambda x: x if x != -1 and x < 1000 else None)
    loss_df['val_perplexity_plot'] = loss_df['val_perplexity'].apply(lambda x: x if x != -1 and x < 1000 else None)

    plt.subplot(2, 2, 1)
    plt.plot(loss_df['epoch'], loss_df['train_loss'], label='Train')
    plt.plot(loss_df['epoch'], loss_df['val_loss'], label='Validation')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Loss vs. Epoch')

    plt.subplot(2, 2, 2)
    plt.plot(loss_df['epoch'], loss_df['train_perplexity'], label='Train')
    plt.plot(loss_df['epoch'], loss_df['val_perplexity'], label='Validation')
    plt.xlabel('Epoch')
    plt.ylabel('Perplexity')
    plt.legend()
    plt.title('Perplexity vs. Epoch')

    plt.subplot(2, 2, 3)
    plt.plot(loss_df['epoch'], loss_df['train_accuracy'], label='Train')
    plt.plot(loss_df['epoch'], loss_df['val_accuracy'], label='Validation')
    plt.xlabel('Epoch')
    plt.ylabel('Amino Acid Recovery (%)')
    plt.legend()
    plt.title('AAR vs. Epoch')

    plt.tight_layout()
    plt.savefig("/content/drive/MyDrive/classes+projects/plastic_enzyme_project/2024/codes/sigma_improved_aar_training_metrics.png")
    plt.close()

In [20]:
def validate_with_enhanced_metrics(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0
    total_correct = 0
    total_tokens = 0

    # Track advanced metrics
    total_levenshtein = 0
    total_diversity = 0
    total_repetition = 0
    total_sequences = 0

    saved_results = []  # Store ground truth vs predicted sequences and SMILES

    with torch.no_grad():
        sampled_batches = random.sample(range(len(val_loader)), min(50, len(val_loader)))

        for i, batch in enumerate(val_loader):
            smiles_strings = batch['smiles']
            proteins = batch['proteins']
            protein_masks = batch['protein_masks'].to(device)

            target_encoding = model.protGPT2_tokenizer(
                proteins,
                return_tensors="pt",
                padding='max_length',
                max_length=model.max_len,
                truncation=True
            ).to(device)

            # For validation, we don't use MLM - we want to evaluate on full sequence prediction
            outputs = model.forward_without_loss(smiles_strings, target_encoding.input_ids)

            # Calculate loss on full sequence prediction
            loss_fn = nn.CrossEntropyLoss(ignore_index=model.protGPT2_tokenizer.pad_token_id)
            loss = loss_fn(outputs.view(-1, outputs.size(-1)), target_encoding.input_ids.view(-1))

            total_loss += loss.item()

            # Amino Acid Recovery Calculation (excluding padding tokens)
            predicted_token_ids = torch.argmax(outputs, dim=-1)
            predicted_token_ids = torch.nn.functional.pad(
                predicted_token_ids, (0, model.max_len - predicted_token_ids.shape[1]),
                value=model.protGPT2_tokenizer.pad_token_id
            )

            mask = target_encoding.input_ids != model.protGPT2_tokenizer.pad_token_id
            correct = (predicted_token_ids[mask] == target_encoding.input_ids[mask]).sum().item()
            total = mask.sum().item()
            total_correct += correct
            total_tokens += total

            # Calculate diversity for each sequence
            for b in range(len(predicted_token_ids)):
                pred_seq_tokens = predicted_token_ids[b][mask[b]]
                if len(pred_seq_tokens) == 0:
                    continue

                # Count unique tokens
                unique_tokens = torch.unique(pred_seq_tokens).size(0)
                seq_diversity = unique_tokens / len(pred_seq_tokens)

                # Measure repetition - look for consecutive repeated tokens
                consecutive_repeats = 0
                for t in range(1, len(pred_seq_tokens)):
                    if pred_seq_tokens[t] == pred_seq_tokens[t-1]:
                        consecutive_repeats += 1

                normalized_repeats = consecutive_repeats / max(1, len(pred_seq_tokens) - 1)

                # Decode sequences for Levenshtein distance
                pred_seq = model.protGPT2_tokenizer.decode(predicted_token_ids[b], skip_special_tokens=True)
                true_seq = model.protGPT2_tokenizer.decode(target_encoding.input_ids[b], skip_special_tokens=True)

                # Calculate normalized Levenshtein distance
                if len(true_seq) > 0:
                    levenshtein_dist = Levenshtein.distance(pred_seq, true_seq) / len(true_seq)
                    total_levenshtein += levenshtein_dist

                total_diversity += seq_diversity
                total_repetition += normalized_repeats
                total_sequences += 1

            # Save randomly selected ground truth vs predicted sequences and SMILES
            if i in sampled_batches:
                ground_truth = model.protGPT2_tokenizer.decode(target_encoding.input_ids[0], skip_special_tokens=True)
                predicted = model.protGPT2_tokenizer.decode(predicted_token_ids[0], skip_special_tokens=True)

                # Calculate per-sequence AAR for this example
                seq_mask = target_encoding.input_ids[0] != model.protGPT2_tokenizer.pad_token_id
                seq_correct = (predicted_token_ids[0][seq_mask] == target_encoding.input_ids[0][seq_mask]).sum().item()
                seq_total = seq_mask.sum().item()
                seq_aar = (seq_correct / seq_total * 100) if seq_total > 0 else 0

                # Calculate additional metrics for this sample
                # 1. Count longest repeated segment
                pred_seq = model.protGPT2_tokenizer.decode(predicted_token_ids[0], skip_special_tokens=True)
                longest_repeat = find_longest_repeat(pred_seq)

                # 2. Calculate amino acid composition similarity
                aa_comp_similarity = amino_acid_composition_similarity(ground_truth, predicted)

                saved_results.append({
                    'SMILES': smiles_strings[0],
                    'Ground Truth': ground_truth,
                    'Predicted': predicted,
                    'Sequence AAR': f"{seq_aar:.2f}%",
                    'Levenshtein Distance': f"{Levenshtein.distance(predicted, ground_truth)}",
                    'Normalized Levenshtein': f"{Levenshtein.distance(predicted, ground_truth) / max(1, len(ground_truth)):.4f}",
                    'Longest Repeat': longest_repeat,
                    'AA Composition Similarity': f"{aa_comp_similarity:.4f}"
                })

    avg_loss = total_loss / len(val_loader)
    amino_acid_recovery = total_correct / total_tokens * 100

    # Calculate average advanced metrics
    avg_levenshtein = total_levenshtein / total_sequences if total_sequences > 0 else 0
    avg_diversity = total_diversity / total_sequences if total_sequences > 0 else 0
    avg_repetition = total_repetition / total_sequences if total_sequences > 0 else 0

    # Prevent overflow when calculating perplexity
    try:
        perplexity = math.exp(avg_loss)
    except OverflowError:
        perplexity = float('inf')  # Return infinity if the loss is too high

    # Add advanced metrics to the results
    evaluation_metrics = {
        'loss': avg_loss,
        'perplexity': perplexity,
        'aar': amino_acid_recovery,
        'levenshtein': avg_levenshtein,
        'diversity': avg_diversity,
        'repetition': avg_repetition
    }

    return avg_loss, perplexity, amino_acid_recovery, saved_results

def find_longest_repeat(sequence):
    """Find the longest repeated substring in the sequence."""
    if not sequence:
        return 0

    longest = 0
    current = 1

    for i in range(1, len(sequence)):
        if sequence[i] == sequence[i-1]:
            current += 1
        else:
            longest = max(longest, current)
            current = 1

    longest = max(longest, current)
    return longest

def amino_acid_composition_similarity(seq1, seq2):
    """
    Calculate the similarity between the amino acid compositions of two sequences.
    Returns a value between 0 and 1, where 1 means identical composition.
    """
    if not seq1 or not seq2:
        return 0

    # Count amino acids in each sequence
    aa_count1 = {}
    aa_count2 = {}

    for aa in seq1:
        aa_count1[aa] = aa_count1.get(aa, 0) + 1

    for aa in seq2:
        aa_count2[aa] = aa_count2.get(aa, 0) + 1

    # Get the union of all amino acids
    all_aas = set(aa_count1.keys()) | set(aa_count2.keys())

    # Calculate cosine similarity
    dot_product = sum(aa_count1.get(aa, 0) * aa_count2.get(aa, 0) for aa in all_aas)

    norm1 = math.sqrt(sum(count**2 for count in aa_count1.values()))
    norm2 = math.sqrt(sum(count**2 for count in aa_count2.values()))

    if norm1 == 0 or norm2 == 0:
        return 0

    return dot_product / (norm1 * norm2)

### inference + training

#### all frozen _ 34+35 has cross_attn

In [29]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load and preprocess data
train_data = preprocess_snp_data('/content/train_data.csv')
val_data = preprocess_snp_data('/content/val_data.csv')

# train_data = train_data.sample(frac=0.01, random_state=42)
# val_data = val_data.sample(frac=0.01, random_state=42)

train_data = filter_datasets(train_data)
val_data = filter_datasets(val_data)

# Calculate max sequence length
max_length = max(
    train_data['protein_length'].max(),
    val_data['protein_length'].max()
)
max_length = min(max_length, 1024)  # Cap at 1024 or your desired maximum
print(f"Max sequence length: {max_length}")

# Create datasets
train_dataset = ProteinGenerationDataset(train_data, max_length)
val_dataset = ProteinGenerationDataset(val_data, max_length)

# Create dataloaders with custom collate function
train_loader = DataLoader(
    train_dataset,
    batch_size=1,  # Adjust based on your GPU memory
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    collate_fn=collate_fn
)

# # Initialize model with sigma-gpt capabilities
# model = SigmaProtFlamingo(
#     model_path='nferruz/ProtGPT2',
#     max_len=max_length,
#     cross_attn_every=2,
#     dim_head=64,
#     heads=8,
#     perceiver_depth=2,
#     perceiver_num_latents=64
# ).to(device)


# Initialize model with sigma-gpt capabilities but without any cross-attention initially
model = SigmaProtFlamingo(
    model_path='nferruz/ProtGPT2',
    max_len=max_length,
    cross_attn_every=999,
    dim_head=64,
    heads=8,
    perceiver_depth=2,
    perceiver_num_latents=64
)  # Don't move to device yet

# The model structure shows that model.layers contains the transformer blocks
# model.protGPT2_model.transformer.h contains the GPT2Blocks

# Let's recreate the layers list with cross-attention only after the last two blocks
new_layers = []

# First, get all the original transformer blocks
transformer_blocks = model.protGPT2_model.transformer.h

# Total number of transformer blocks
num_blocks = len(transformer_blocks)
print(f"Total transformer blocks: {num_blocks}")

# Add each transformer block, with cross-attention after the last two blocks
for i, block in enumerate(transformer_blocks):
    # Add the transformer block
    new_layers.append(block)

    # Add cross-attention after the last two blocks
    if i == num_blocks - 2 or i == num_blocks - 1:
        print(f"Adding cross-attention after block {i}")
        new_layers.append(GatedCrossAttentionBlock(
            dim=model.protGPT2_model.config.n_embd,
            dim_head=64,
            heads=8
        ))

# Replace the model's layers with our new sequence
model.layers = nn.ModuleList(new_layers)


# Now move the entire model to the device after modifying it
model = model.to(device)

# Count how many cross-attention blocks were added
cross_attn_count = sum(1 for layer in model.layers if isinstance(layer, GatedCrossAttentionBlock))
print(f"Added {cross_attn_count} cross-attention blocks")


# Add this line after replacing model.layers with new_layers and moving to device
model = model.to(device)

# Print more detailed layer structure first
for name, _ in model.protGPT2_model.named_parameters():
    if 'transformer.h' in name:
        print(name)
        break  # Just print one example to see the structure


# Check the highest layer index in the model
max_layer_idx = -1
for name, _ in model.protGPT2_model.named_parameters():
    if 'transformer.h.' in name:
        # Extract the layer index which comes after 'transformer.h.'
        parts = name.split('.')
        if len(parts) > 2:
            try:
                layer_idx = int(parts[2])
                max_layer_idx = max(max_layer_idx, layer_idx)
            except ValueError:
                continue

print(f"Total number of transformer layers: {max_layer_idx + 1}")

# Then modify the freezing code to match the actual structure
# This assumes the layer indexing is inside the parameter names
for name, param in model.protGPT2_model.named_parameters():
  param.requires_grad = False  # Freeze everything else


print_model_structure(model)


###___________________________________________________________________________________

# # Directly check if 'lm_head' exists as an attribute
# if hasattr(model.protGPT2_model, 'lm_head'):
#     print("lm_head exists as an attribute!")
#     print(model.protGPT2_model.lm_head)

#     # Check if it has parameters
#     if hasattr(model.protGPT2_model.lm_head, 'parameters'):
#         print("lm_head has parameters!")

#         # Check requires_grad for lm_head manually
#         for param in model.protGPT2_model.lm_head.parameters():
#             print(f"lm_head requires_grad: {param.requires_grad}")
#     else:
#         print("WARNING: lm_head has no registered parameters!")
# else:
#     print("WARNING: lm_head does not exist as an attribute!")

# # Unfreeze lm_head manually
# if hasattr(model.protGPT2_model, 'lm_head'):
#     for param in model.protGPT2_model.lm_head.parameters():
#         param.requires_grad = True
#     print("lm_head manually unfrozen!")

# # Verify if lm_head is now trainable
# for param in model.protGPT2_model.lm_head.parameters():
#     print(f"lm_head requires_grad: {param.requires_grad}")

# Verify which parameters are trainable
total_params = 0
trainable_params = 0
for name, param in model.named_parameters():
    total_params += param.numel()
    if param.requires_grad:
        trainable_params += param.numel()

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,} ({trainable_params/total_params:.2%})")

num_epochs = 10

# Training loop with curriculum learning
# Start with 50% of sequences in left-to-right order and gradually increase to 100% random
curriculum_steps = int(0.5 * num_epochs * len(train_loader))  # Curriculum over first half of training
# print("Starting training with sigma-gpt capabilities...")
# train_with_improved_aar_objective(
#     model,
#     train_loader,
#     val_loader,
#     num_epochs,
#     device,
#     curriculum_steps=curriculum_steps
# )

###___________________________________________________________________________________

# # Generate and evaluate
# print("Generating proteins for test set...")
# test_results = generate_and_evaluate(model, test_loader, device)

# # Save results
# print("Saving results...")
# results_path = '/content/drive/MyDrive/classes+projects/plastic_enzyme_project/2024/codes/test_results.json'
# with open(results_path, 'w') as f:
#     json.dump(test_results, f, indent=2)

# print(f"Results saved to {results_path}")


Using device: cuda
Max sequence length: 914
Total transformer blocks: 36
Adding cross-attention after block 34
Adding cross-attention after block 35
Added 2 cross-attention blocks
transformer.h.0.ln_1.weight
Total number of transformer layers: 36

===== MODEL STRUCTURE ANALYSIS =====

üìå CROSS-ATTENTION LAYERS:
  Total cross-attention blocks: 2
  Located at positions: [35, 37]

üìå LAYER FREEZING STATUS:
  Layer  0: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  1: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  2: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  3: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  4: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  5: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  6: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  7: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  8: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  9: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer 10: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer 11: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer 12: ‚ùÑÔ∏è FROZEN (19,677,440

#### all but lm head is frozen _ cross attn 34+35+lmhead

In [31]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load and preprocess data
train_data = preprocess_snp_data('/content/train_data.csv')
val_data = preprocess_snp_data('/content/val_data.csv')

# train_data = train_data.sample(frac=0.01, random_state=42)
# val_data = val_data.sample(frac=0.01, random_state=42)

train_data = filter_datasets(train_data)
val_data = filter_datasets(val_data)

# Calculate max sequence length
max_length = max(
    train_data['protein_length'].max(),
    val_data['protein_length'].max()
)
max_length = min(max_length, 1024)  # Cap at 1024 or your desired maximum
print(f"Max sequence length: {max_length}")

# Create datasets
train_dataset = ProteinGenerationDataset(train_data, max_length)
val_dataset = ProteinGenerationDataset(val_data, max_length)

# Create dataloaders with custom collate function
train_loader = DataLoader(
    train_dataset,
    batch_size=1,  # Adjust based on your GPU memory
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    collate_fn=collate_fn
)

# # Initialize model with sigma-gpt capabilities
# model = SigmaProtFlamingo(
#     model_path='nferruz/ProtGPT2',
#     max_len=max_length,
#     cross_attn_every=2,
#     dim_head=64,
#     heads=8,
#     perceiver_depth=2,
#     perceiver_num_latents=64
# ).to(device)


# Initialize model with sigma-gpt capabilities but without any cross-attention initially
model = SigmaProtFlamingo(
    model_path='nferruz/ProtGPT2',
    max_len=max_length,
    cross_attn_every=999,
    dim_head=64,
    heads=8,
    perceiver_depth=2,
    perceiver_num_latents=64
)  # Don't move to device yet

# The model structure shows that model.layers contains the transformer blocks
# model.protGPT2_model.transformer.h contains the GPT2Blocks

# Let's recreate the layers list with cross-attention only after the last two blocks
new_layers = []

# First, get all the original transformer blocks
transformer_blocks = model.protGPT2_model.transformer.h

# Total number of transformer blocks
num_blocks = len(transformer_blocks)
print(f"Total transformer blocks: {num_blocks}")

# Add each transformer block, with cross-attention after the last two blocks
for i, block in enumerate(transformer_blocks):
    # Add the transformer block
    new_layers.append(block)

    # Add cross-attention after the last two blocks
    if i == num_blocks - 2 or i == num_blocks - 1:
        print(f"Adding cross-attention after block {i}")
        new_layers.append(GatedCrossAttentionBlock(
            dim=model.protGPT2_model.config.n_embd,
            dim_head=64,
            heads=8
        ))

# Replace the model's layers with our new sequence
model.layers = nn.ModuleList(new_layers)

# Now move the entire model to the device after modifying it
model = model.to(device)

# Count how many cross-attention blocks were added
cross_attn_count = sum(1 for layer in model.layers if isinstance(layer, GatedCrossAttentionBlock))
print(f"Added {cross_attn_count} cross-attention blocks")


# Print more detailed layer structure first
for name, _ in model.protGPT2_model.named_parameters():
    if 'transformer.h' in name:
        print(name)
        break  # Just print one example to see the structure


# Check the highest layer index in the model
max_layer_idx = -1
for name, _ in model.protGPT2_model.named_parameters():
    if 'transformer.h.' in name:
        # Extract the layer index which comes after 'transformer.h.'
        parts = name.split('.')
        if len(parts) > 2:
            try:
                layer_idx = int(parts[2])
                max_layer_idx = max(max_layer_idx, layer_idx)
            except ValueError:
                continue

print(f"Total number of transformer layers: {max_layer_idx + 1}")

# Then modify the freezing code to match the actual structure
# This assumes the layer indexing is inside the parameter names
for name, param in model.protGPT2_model.named_parameters():
    param.requires_grad = False  # Freeze everything else






###___________________________________________________________________________________

# # Directly check if 'lm_head' exists as an attribute
# if hasattr(model.protGPT2_model, 'lm_head'):
#     print("lm_head exists as an attribute!")
#     print(model.protGPT2_model.lm_head)

#     # Check if it has parameters
#     if hasattr(model.protGPT2_model.lm_head, 'parameters'):
#         print("lm_head has parameters!")

#         # Check requires_grad for lm_head manually
#         for param in model.protGPT2_model.lm_head.parameters():
#             print(f"lm_head requires_grad: {param.requires_grad}")
#     else:
#         print("WARNING: lm_head has no registered parameters!")
# else:
#     print("WARNING: lm_head does not exist as an attribute!")

# Unfreeze lm_head manually
if hasattr(model.protGPT2_model, 'lm_head'):
    for param in model.protGPT2_model.lm_head.parameters():
        param.requires_grad = True
    print("lm_head manually unfrozen!")

# Verify if lm_head is now trainable
for param in model.protGPT2_model.lm_head.parameters():
    print(f"lm_head requires_grad: {param.requires_grad}")


print_model_structure(model)

# Verify which parameters are trainable
total_params = 0
trainable_params = 0
for name, param in model.named_parameters():
    total_params += param.numel()
    if param.requires_grad:
        trainable_params += param.numel()

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,} ({trainable_params/total_params:.2%})")

num_epochs = 10

# Training loop with curriculum learning
# Start with 50% of sequences in left-to-right order and gradually increase to 100% random
curriculum_steps = int(0.5 * num_epochs * len(train_loader))  # Curriculum over first half of training
# print("Starting training with sigma-gpt capabilities...")
# train_with_improved_aar_objective(
#     model,
#     train_loader,
#     val_loader,
#     num_epochs,
#     device,
#     curriculum_steps=curriculum_steps
# )

###___________________________________________________________________________________

# # Generate and evaluate
# print("Generating proteins for test set...")
# test_results = generate_and_evaluate(model, test_loader, device)

# # Save results
# print("Saving results...")
# results_path = '/content/drive/MyDrive/classes+projects/plastic_enzyme_project/2024/codes/test_results.json'
# with open(results_path, 'w') as f:
#     json.dump(test_results, f, indent=2)

# print(f"Results saved to {results_path}")


Using device: cuda
Max sequence length: 914
Total transformer blocks: 36
Adding cross-attention after block 34
Adding cross-attention after block 35
Added 2 cross-attention blocks
transformer.h.0.ln_1.weight
Total number of transformer layers: 36
lm_head manually unfrozen!
lm_head requires_grad: True

===== MODEL STRUCTURE ANALYSIS =====

üìå CROSS-ATTENTION LAYERS:
  Total cross-attention blocks: 2
  Located at positions: [35, 37]

üìå LAYER FREEZING STATUS:
  Layer  0: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  1: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  2: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  3: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  4: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  5: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  6: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  7: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  8: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  9: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer 10: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer 11: ‚ùÑÔ∏è FROZEN (1

#### unfreeze 34+35 and do cross attn with 34+35

In [27]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load and preprocess data
train_data = preprocess_snp_data('/content/train_data.csv')
val_data = preprocess_snp_data('/content/val_data.csv')

# train_data = train_data.sample(frac=0.01, random_state=42)
# val_data = val_data.sample(frac=0.01, random_state=42)

train_data = filter_datasets(train_data)
val_data = filter_datasets(val_data)

# Calculate max sequence length
max_length = max(
    train_data['protein_length'].max(),
    val_data['protein_length'].max()
)
max_length = min(max_length, 1024)  # Cap at 1024 or your desired maximum
print(f"Max sequence length: {max_length}")

# Create datasets
train_dataset = ProteinGenerationDataset(train_data, max_length)
val_dataset = ProteinGenerationDataset(val_data, max_length)

# Create dataloaders with custom collate function
train_loader = DataLoader(
    train_dataset,
    batch_size=1,  # Adjust based on your GPU memory
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    collate_fn=collate_fn
)

# # Initialize model with sigma-gpt capabilities
# model = SigmaProtFlamingo(
#     model_path='nferruz/ProtGPT2',
#     max_len=max_length,
#     cross_attn_every=2,
#     dim_head=64,
#     heads=8,
#     perceiver_depth=2,
#     perceiver_num_latents=64
# ).to(device)


# Initialize model with sigma-gpt capabilities but without any cross-attention initially
model = SigmaProtFlamingo(
    model_path='nferruz/ProtGPT2',
    max_len=max_length,
    cross_attn_every=999,
    dim_head=64,
    heads=8,
    perceiver_depth=2,
    perceiver_num_latents=64
)  # Don't move to device yet

# The model structure shows that model.layers contains the transformer blocks
# model.protGPT2_model.transformer.h contains the GPT2Blocks

# Let's recreate the layers list with cross-attention only after the last two blocks
new_layers = []

# First, get all the original transformer blocks
transformer_blocks = model.protGPT2_model.transformer.h

# Total number of transformer blocks
num_blocks = len(transformer_blocks)
print(f"Total transformer blocks: {num_blocks}")

# Add each transformer block, with cross-attention after the last two blocks
for i, block in enumerate(transformer_blocks):
    # Add the transformer block
    new_layers.append(block)

    # Add cross-attention after the last two blocks
    if i == num_blocks - 2 or i == num_blocks - 1:
        print(f"Adding cross-attention after block {i}")
        new_layers.append(GatedCrossAttentionBlock(
            dim=model.protGPT2_model.config.n_embd,
            dim_head=64,
            heads=8
        ))

# Replace the model's layers with our new sequence
model.layers = nn.ModuleList(new_layers)

# Now move the entire model to the device after modifying it
model = model.to(device)

# Count how many cross-attention blocks were added
cross_attn_count = sum(1 for layer in model.layers if isinstance(layer, GatedCrossAttentionBlock))
print(f"Added {cross_attn_count} cross-attention blocks")

# Print more detailed layer structure first
for name, _ in model.protGPT2_model.named_parameters():
    if 'transformer.h' in name:
        print(name)
        break  # Just print one example to see the structure


# Check the highest layer index in the model
max_layer_idx = -1
for name, _ in model.protGPT2_model.named_parameters():
    if 'transformer.h.' in name:
        # Extract the layer index which comes after 'transformer.h.'
        parts = name.split('.')
        if len(parts) > 2:
            try:
                layer_idx = int(parts[2])
                max_layer_idx = max(max_layer_idx, layer_idx)
            except ValueError:
                continue

print(f"Total number of transformer layers: {max_layer_idx + 1}")

# Then modify the freezing code to match the actual structure
# This assumes the layer indexing is inside the parameter names
for name, param in model.protGPT2_model.named_parameters():
    if 'lm_head' in name or 'transformer.h.34' in name or 'transformer.h.35' in name:
        param.requires_grad = True  # Unfreeze
    else:
        param.requires_grad = False  # Freeze everything else



print_model_structure(model)


###___________________________________________________________________________________

# # Directly check if 'lm_head' exists as an attribute
# if hasattr(model.protGPT2_model, 'lm_head'):
#     print("lm_head exists as an attribute!")
#     print(model.protGPT2_model.lm_head)

#     # Check if it has parameters
#     if hasattr(model.protGPT2_model.lm_head, 'parameters'):
#         print("lm_head has parameters!")

#         # Check requires_grad for lm_head manually
#         for param in model.protGPT2_model.lm_head.parameters():
#             print(f"lm_head requires_grad: {param.requires_grad}")
#     else:
#         print("WARNING: lm_head has no registered parameters!")
# else:
#     print("WARNING: lm_head does not exist as an attribute!")

# # Unfreeze lm_head manually
# if hasattr(model.protGPT2_model, 'lm_head'):
#     for param in model.protGPT2_model.lm_head.parameters():
#         param.requires_grad = True
#     print("lm_head manually unfrozen!")

# # Verify if lm_head is now trainable
# for param in model.protGPT2_model.lm_head.parameters():
#     print(f"lm_head requires_grad: {param.requires_grad}")

# Verify which parameters are trainable
total_params = 0
trainable_params = 0
for name, param in model.named_parameters():
    total_params += param.numel()
    if param.requires_grad:
        trainable_params += param.numel()

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,} ({trainable_params/total_params:.2%})")

num_epochs = 10

# Training loop with curriculum learning
# Start with 50% of sequences in left-to-right order and gradually increase to 100% random
curriculum_steps = int(0.5 * num_epochs * len(train_loader))  # Curriculum over first half of training
# print("Starting training with sigma-gpt capabilities...")
# train_with_improved_aar_objective(
#     model,
#     train_loader,
#     val_loader,
#     num_epochs,
#     device,
#     curriculum_steps=curriculum_steps
# )

###___________________________________________________________________________________

# # Generate and evaluate
# print("Generating proteins for test set...")
# test_results = generate_and_evaluate(model, test_loader, device)

# # Save results
# print("Saving results...")
# results_path = '/content/drive/MyDrive/classes+projects/plastic_enzyme_project/2024/codes/test_results.json'
# with open(results_path, 'w') as f:
#     json.dump(test_results, f, indent=2)

# print(f"Results saved to {results_path}")


Using device: cuda
Max sequence length: 914
Total transformer blocks: 36
Adding cross-attention after block 34
Adding cross-attention after block 35
Added 2 cross-attention blocks
transformer.h.0.ln_1.weight
Total number of transformer layers: 36

===== MODEL STRUCTURE ANALYSIS =====

üìå CROSS-ATTENTION LAYERS:
  Total cross-attention blocks: 2
  Located at positions: [35, 37]

üìå LAYER FREEZING STATUS:
  Layer  0: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  1: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  2: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  3: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  4: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  5: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  6: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  7: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  8: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  9: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer 10: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer 11: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer 12: ‚ùÑÔ∏è FROZEN (19,677,440

#### unfreeze 34+35+lm head and add cross attn to those

In [None]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load and preprocess data
train_data = preprocess_snp_data('/content/train_data.csv')
val_data = preprocess_snp_data('/content/val_data.csv')

# train_data = train_data.sample(frac=0.01, random_state=42)
# val_data = val_data.sample(frac=0.01, random_state=42)

train_data = filter_datasets(train_data)
val_data = filter_datasets(val_data)

# Calculate max sequence length
max_length = max(
    train_data['protein_length'].max(),
    val_data['protein_length'].max()
)
max_length = min(max_length, 1024)  # Cap at 1024 or your desired maximum
print(f"Max sequence length: {max_length}")

# Create datasets
train_dataset = ProteinGenerationDataset(train_data, max_length)
val_dataset = ProteinGenerationDataset(val_data, max_length)

# Create dataloaders with custom collate function
train_loader = DataLoader(
    train_dataset,
    batch_size=1,  # Adjust based on your GPU memory
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    collate_fn=collate_fn
)

# # Initialize model with sigma-gpt capabilities
# model = SigmaProtFlamingo(
#     model_path='nferruz/ProtGPT2',
#     max_len=max_length,
#     cross_attn_every=2,
#     dim_head=64,
#     heads=8,
#     perceiver_depth=2,
#     perceiver_num_latents=64
# ).to(device)


# Initialize model with sigma-gpt capabilities but without any cross-attention initially
model = SigmaProtFlamingo(
    model_path='nferruz/ProtGPT2',
    max_len=max_length,
    cross_attn_every=999,
    dim_head=64,
    heads=8,
    perceiver_depth=2,
    perceiver_num_latents=64
)  # Don't move to device yet

# The model structure shows that model.layers contains the transformer blocks
# model.protGPT2_model.transformer.h contains the GPT2Blocks

# Let's recreate the layers list with cross-attention only after the last two blocks
new_layers = []

# First, get all the original transformer blocks
transformer_blocks = model.protGPT2_model.transformer.h

# Total number of transformer blocks
num_blocks = len(transformer_blocks)
print(f"Total transformer blocks: {num_blocks}")

# Add each transformer block, with cross-attention after the last two blocks
for i, block in enumerate(transformer_blocks):
    # Add the transformer block
    new_layers.append(block)

    # Add cross-attention after the last two blocks
    if i == num_blocks - 2 or i == num_blocks - 1:
        print(f"Adding cross-attention after block {i}")
        new_layers.append(GatedCrossAttentionBlock(
            dim=model.protGPT2_model.config.n_embd,
            dim_head=64,
            heads=8
        ))

# Replace the model's layers with our new sequence
model.layers = nn.ModuleList(new_layers)

# Now move the entire model to the device after modifying it
model = model.to(device)

# Count how many cross-attention blocks were added
cross_attn_count = sum(1 for layer in model.layers if isinstance(layer, GatedCrossAttentionBlock))
print(f"Added {cross_attn_count} cross-attention blocks")

# Print more detailed layer structure first
for name, _ in model.protGPT2_model.named_parameters():
    if 'transformer.h' in name:
        print(name)
        break  # Just print one example to see the structure


# Check the highest layer index in the model
max_layer_idx = -1
for name, _ in model.protGPT2_model.named_parameters():
    if 'transformer.h.' in name:
        # Extract the layer index which comes after 'transformer.h.'
        parts = name.split('.')
        if len(parts) > 2:
            try:
                layer_idx = int(parts[2])
                max_layer_idx = max(max_layer_idx, layer_idx)
            except ValueError:
                continue

print(f"Total number of transformer layers: {max_layer_idx + 1}")

# Then modify the freezing code to match the actual structure
# This assumes the layer indexing is inside the parameter names
for name, param in model.protGPT2_model.named_parameters():
    if 'lm_head' in name or 'transformer.h.34' in name or 'transformer.h.35' in name:
        param.requires_grad = True  # Unfreeze
    else:
        param.requires_grad = False  # Freeze everything else

##___________________________________________________________________________________

# Directly check if 'lm_head' exists as an attribute
if hasattr(model.protGPT2_model, 'lm_head'):
    print("lm_head exists as an attribute!")
    print(model.protGPT2_model.lm_head)

    # Check if it has parameters
    if hasattr(model.protGPT2_model.lm_head, 'parameters'):
        print("lm_head has parameters!")

        # Check requires_grad for lm_head manually
        for param in model.protGPT2_model.lm_head.parameters():
            print(f"lm_head requires_grad: {param.requires_grad}")
    else:
        print("WARNING: lm_head has no registered parameters!")
else:
    print("WARNING: lm_head does not exist as an attribute!")

# Unfreeze lm_head manually
if hasattr(model.protGPT2_model, 'lm_head'):
    for param in model.protGPT2_model.lm_head.parameters():
        param.requires_grad = True
    print("lm_head manually unfrozen!")

# Verify if lm_head is now trainable
for param in model.protGPT2_model.lm_head.parameters():
    print(f"lm_head requires_grad: {param.requires_grad}")

# Verify which parameters are trainable
total_params = 0
trainable_params = 0
for name, param in model.named_parameters():
    total_params += param.numel()
    if param.requires_grad:
        trainable_params += param.numel()

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,} ({trainable_params/total_params:.2%})")

num_epochs = 10


print_model_structure(model)


# Training loop with curriculum learning
# Start with 50% of sequences in left-to-right order and gradually increase to 100% random
curriculum_steps = int(0.5 * num_epochs * len(train_loader))  # Curriculum over first half of training
print("Starting training with sigma-gpt capabilities...")
train_with_improved_aar_objective(
    model,
    train_loader,
    val_loader,
    num_epochs,
    device,
    curriculum_steps=curriculum_steps
)

###___________________________________________________________________________________

# # Generate and evaluate
# print("Generating proteins for test set...")
# test_results = generate_and_evaluate(model, test_loader, device)

# # Save results
# print("Saving results...")
# results_path = '/content/drive/MyDrive/classes+projects/plastic_enzyme_project/2024/codes/test_results.json'
# with open(results_path, 'w') as f:
#     json.dump(test_results, f, indent=2)

# print(f"Results saved to {results_path}")


Using device: cuda
Max sequence length: 914
Total transformer blocks: 36
Adding cross-attention after block 34
Adding cross-attention after block 35
Added 2 cross-attention blocks
transformer.h.0.ln_1.weight
Total number of transformer layers: 36
lm_head exists as an attribute!
Linear(in_features=1280, out_features=50257, bias=False)
lm_head has parameters!
lm_head requires_grad: False
lm_head manually unfrozen!
lm_head requires_grad: True
Total parameters: 1,429,527,700
Trainable parameters: 759,181,460 (53.11%)

===== MODEL STRUCTURE ANALYSIS =====

üìå CROSS-ATTENTION LAYERS:
  Total cross-attention blocks: 2
  Located at positions: [35, 37]

üìå LAYER FREEZING STATUS:
  Layer  0: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  1: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  2: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  3: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  4: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  5: ‚ùÑÔ∏è FROZEN (19,677,440 params)
  Layer  6: ‚ùÑÔ∏è FROZEN (19,677,440 par

Epoch 1/10:   0%|          | 0/3980 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


  Loss components: CE=3724.5242, Rep=0.2105, Seq=7.0833


Epoch 1/10:   0%|          | 1/3980 [00:02<2:38:17,  2.39s/it]

  Batch 0: Loss=3724.5242, AAR=0.00%


Epoch 1/10:   0%|          | 11/3980 [00:07<33:49,  1.96it/s]

  Batch 10: Loss=4470.9326, AAR=0.00%


Epoch 1/10:   1%|          | 21/3980 [00:12<32:03,  2.06it/s]

  Batch 20: Loss=2587.9871, AAR=0.00%


Epoch 1/10:   1%|          | 31/3980 [00:17<32:59,  2.00it/s]

  Batch 30: Loss=3171.5122, AAR=0.00%


Epoch 1/10:   1%|          | 41/3980 [00:22<32:14,  2.04it/s]

  Batch 40: Loss=2025.2548, AAR=0.00%


Epoch 1/10:   1%|‚ñè         | 51/3980 [00:26<31:59,  2.05it/s]

  Loss components: CE=1903.2087, Rep=0.0430, Seq=8.9424
  Batch 50: Loss=1903.2090, AAR=0.00%


Epoch 1/10:   2%|‚ñè         | 61/3980 [00:31<31:47,  2.05it/s]

  Batch 60: Loss=1626.1965, AAR=0.00%


Epoch 1/10:   2%|‚ñè         | 71/3980 [00:36<31:38,  2.06it/s]

  Batch 70: Loss=1828.9250, AAR=0.00%


Epoch 1/10:   2%|‚ñè         | 81/3980 [00:41<31:38,  2.05it/s]

  Batch 80: Loss=2046.9462, AAR=0.00%


Epoch 1/10:   2%|‚ñè         | 91/3980 [00:46<31:52,  2.03it/s]

  Batch 90: Loss=1727.7098, AAR=0.00%


Epoch 1/10:   3%|‚ñé         | 101/3980 [00:51<31:28,  2.05it/s]

  Loss components: CE=697.2571, Rep=0.0102, Seq=8.3209
  Batch 100: Loss=697.2571, AAR=0.00%


Epoch 1/10:   3%|‚ñé         | 111/3980 [00:56<31:50,  2.03it/s]

  Batch 110: Loss=1760.5056, AAR=0.00%


Epoch 1/10:   3%|‚ñé         | 121/3980 [01:01<31:52,  2.02it/s]

  Batch 120: Loss=1324.5081, AAR=0.00%


Epoch 1/10:   3%|‚ñé         | 131/3980 [01:06<31:07,  2.06it/s]

  Batch 130: Loss=1711.7574, AAR=0.00%


Epoch 1/10:   4%|‚ñé         | 141/3980 [01:10<30:47,  2.08it/s]

  Batch 140: Loss=1618.8890, AAR=0.00%


Epoch 1/10:   4%|‚ñç         | 151/3980 [01:15<31:35,  2.02it/s]

  Loss components: CE=894.1496, Rep=0.0206, Seq=8.4175
  Batch 150: Loss=894.1496, AAR=0.00%


Epoch 1/10:   4%|‚ñç         | 161/3980 [01:20<30:47,  2.07it/s]

  Batch 160: Loss=1078.1924, AAR=0.00%


Epoch 1/10:   4%|‚ñç         | 171/3980 [01:25<31:39,  2.01it/s]

  Batch 170: Loss=1226.4070, AAR=0.00%


Epoch 1/10:   5%|‚ñç         | 181/3980 [01:30<30:58,  2.04it/s]

  Batch 180: Loss=821.2664, AAR=0.00%


Epoch 1/10:   5%|‚ñç         | 191/3980 [01:35<30:34,  2.07it/s]

  Batch 190: Loss=794.2172, AAR=0.00%


Epoch 1/10:   5%|‚ñå         | 201/3980 [01:40<30:37,  2.06it/s]

  Loss components: CE=1057.6624, Rep=0.0333, Seq=10.5076
  Batch 200: Loss=1057.6625, AAR=0.00%


Epoch 1/10:   5%|‚ñå         | 211/3980 [01:45<30:14,  2.08it/s]

  Batch 210: Loss=912.4845, AAR=0.00%


Epoch 1/10:   6%|‚ñå         | 221/3980 [01:50<30:44,  2.04it/s]

  Batch 220: Loss=1064.3438, AAR=0.00%


Epoch 1/10:   6%|‚ñå         | 231/3980 [01:55<31:03,  2.01it/s]

  Batch 230: Loss=896.9280, AAR=0.00%


Epoch 1/10:   6%|‚ñå         | 241/3980 [02:00<30:28,  2.05it/s]

  Batch 240: Loss=590.0178, AAR=0.00%


Epoch 1/10:   6%|‚ñã         | 251/3980 [02:04<30:24,  2.04it/s]

  Loss components: CE=1153.0062, Rep=0.0127, Seq=10.0645
  Batch 250: Loss=1153.0063, AAR=0.00%


Epoch 1/10:   7%|‚ñã         | 261/3980 [02:09<30:18,  2.04it/s]

  Batch 260: Loss=1551.5060, AAR=0.00%


Epoch 1/10:   7%|‚ñã         | 271/3980 [02:14<30:05,  2.05it/s]

  Batch 270: Loss=682.8907, AAR=0.00%


Epoch 1/10:   7%|‚ñã         | 281/3980 [02:19<30:10,  2.04it/s]

  Batch 280: Loss=1056.3319, AAR=0.00%


Epoch 1/10:   7%|‚ñã         | 291/3980 [02:24<30:07,  2.04it/s]

  Batch 290: Loss=1675.0182, AAR=0.00%


Epoch 1/10:   8%|‚ñä         | 301/3980 [02:29<29:52,  2.05it/s]

  Loss components: CE=957.1619, Rep=0.0061, Seq=5.0447
  Batch 300: Loss=957.1618, AAR=0.00%


Epoch 1/10:   8%|‚ñä         | 311/3980 [02:34<29:38,  2.06it/s]

  Batch 310: Loss=597.6630, AAR=0.00%


Epoch 1/10:   8%|‚ñä         | 321/3980 [02:39<29:39,  2.06it/s]

  Batch 320: Loss=968.7035, AAR=0.00%


Epoch 1/10:   8%|‚ñä         | 331/3980 [02:44<30:03,  2.02it/s]

  Batch 330: Loss=505.2824, AAR=0.00%


Epoch 1/10:   9%|‚ñä         | 341/3980 [02:48<29:27,  2.06it/s]

  Batch 340: Loss=920.2732, AAR=0.00%


Epoch 1/10:   9%|‚ñâ         | 351/3980 [02:53<29:21,  2.06it/s]

  Loss components: CE=833.8425, Rep=0.0000, Seq=9.2194
  Batch 350: Loss=833.8425, AAR=0.00%


Epoch 1/10:   9%|‚ñâ         | 361/3980 [02:58<29:34,  2.04it/s]

  Batch 360: Loss=845.9545, AAR=0.00%


Epoch 1/10:   9%|‚ñâ         | 371/3980 [03:03<29:18,  2.05it/s]

  Batch 370: Loss=973.6693, AAR=0.00%


Epoch 1/10:  10%|‚ñâ         | 381/3980 [03:08<29:40,  2.02it/s]

  Batch 380: Loss=899.8007, AAR=0.00%


Epoch 1/10:  10%|‚ñâ         | 391/3980 [03:13<28:59,  2.06it/s]

  Batch 390: Loss=801.7186, AAR=0.00%


Epoch 1/10:  10%|‚ñà         | 401/3980 [03:18<29:14,  2.04it/s]

  Loss components: CE=1251.2941, Rep=0.0102, Seq=8.1574
  Batch 400: Loss=1251.2941, AAR=0.00%


Epoch 1/10:  10%|‚ñà         | 411/3980 [03:23<28:44,  2.07it/s]

  Batch 410: Loss=586.0203, AAR=0.00%


Epoch 1/10:  11%|‚ñà         | 421/3980 [03:27<28:39,  2.07it/s]

  Batch 420: Loss=658.6631, AAR=0.00%


Epoch 1/10:  11%|‚ñà         | 431/3980 [03:32<28:29,  2.08it/s]

  Batch 430: Loss=591.3684, AAR=0.00%


Epoch 1/10:  11%|‚ñà         | 441/3980 [03:37<29:09,  2.02it/s]

  Batch 440: Loss=679.3245, AAR=0.00%


Epoch 1/10:  11%|‚ñà‚ñè        | 451/3980 [03:42<28:54,  2.03it/s]

  Loss components: CE=921.2969, Rep=0.0112, Seq=9.3034
  Batch 450: Loss=921.2969, AAR=0.00%


### generation

In [None]:
import torch
import json
from tqdm import tqdm

# Load the trained model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Path to model checkpoint
checkpoint_path = "/content/drive/MyDrive/classes+projects/plastic_enzyme_project/2024/codes/sigma_checkpoint.pth"

# Load
model = SigmaProtFlamingo(
    model_path='nferruz/ProtGPT2',
    max_len=914,  # Ensure this matches the training max_len
    cross_attn_every=3,
    dim_head=64,
    heads=8,
    perceiver_depth=2,
    perceiver_num_latents=64
).to(device)



In [None]:
ProteinGenerationDataset

In [None]:
# Load trained weights
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()

# Load test data
test_data = preprocess_snp_data('/content/augmented_test.csv')
test_data = filter_datasets(test_data)

# Create test dataset and dataloader
test_dataset = ProteinGenerationDataset(test_data,max_length = 914 )
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    collate_fn=collate_fn
)



In [None]:
def generate_autoregressively(model, smiles_string, max_length=914, temperature=1.0, random_order=False):
    """Generate protein autoregressively, with option to use random order"""
    device = next(model.parameters()).device

    # Get SMILES embeddings
    smiles_embeddings = model.polybert_encoder([smiles_string])
    processed_smiles = model.smiles_perceiver(smiles_embeddings)

    # Initialize with start token
    input_ids = torch.tensor([[model.protGPT2_tokenizer.bos_token_id]], device=device)

    # If using random order, generate a random permutation
    if random_order:
        order = torch.randperm(max_length, device=device).unsqueeze(0)
    else:
        order = torch.arange(max_length, device=device).unsqueeze(0)

    # Track the current positions in the order
    current_pos = 0

    # Generated sequence in order's positions
    generated_sequence = torch.full((1, max_length), model.protGPT2_tokenizer.pad_token_id, device=device)
    generated_sequence[0, 0] = model.protGPT2_tokenizer.bos_token_id  # Start token

    while current_pos < max_length - 1:
        # Get the next position in the order
        next_pos = current_pos + 1

        # Forward pass to get next token prediction
        with torch.no_grad():
            # Use only the sequence up to the current position
            current_order = order[:, :next_pos]
            current_sequence = generated_sequence[:, current_order[0]]

            # Get logits for the next token
            logits, _ = model(
                smiles_string,
                order=current_order,
                optimize=True
            )

            # Apply temperature and sample
            logits = logits[0, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()

            # Add the token to the generated sequence at the next position in the order
            generated_sequence[0, order[0, next_pos]] = next_token

            # Check for EOS token
            if next_token == model.protGPT2_tokenizer.eos_token_id:
                break

            current_pos = next_pos

    # Decode the generated sequence
    generated_ids = generated_sequence[0].tolist()
    print('generated_ids',generated_ids)
    # Remove padding tokens
    generated_ids = [id for id in generated_ids if id != model.protGPT2_tokenizer.pad_token_id]
    seq = model.protGPT2_tokenizer.decode(generated_ids, skip_special_tokens=True)
    print('seq',seq)
    print("autoregressive gen done...")
    return seq


In [None]:
def generate_with_rejection_sampling(model, smiles_string, max_length=914, num_orders=5, temperature=1.0):
    """Generate protein using token-based rejection sampling with proper MH acceptance ratio"""
    device = next(model.parameters()).device

    # Get SMILES embeddings
    smiles_embeddings = model.polybert_encoder([smiles_string])
    processed_smiles = model.smiles_perceiver(smiles_embeddings)

    # Initialize with start token
    prompt = torch.tensor([[model.protGPT2_tokenizer.bos_token_id]], device=device)

    # Initialize full sequence with padding
    full_seq = torch.full((1, max_length), model.protGPT2_tokenizer.pad_token_id, device=device)
    full_seq[:, 0] = model.protGPT2_tokenizer.bos_token_id  # Start token

    # Track positions that have been filled
    filled_positions = {0}  # Start with position 0 filled

    while len(filled_positions) < max_length:
        remaining_positions = [i for i in range(max_length) if i not in filled_positions]
        if not remaining_positions:
            break

        # Step 1: Sample tokens at all remaining positions from marginal distribution
        # This is our proposal distribution p(xÃÉ)
        candidate_tokens = {}
        proposal_probs = {}  # Store the probability of each proposal

        for pos in remaining_positions:
            # Create current filled sequence context
            current_context = torch.ones((1, max_length), device=device) * model.protGPT2_tokenizer.pad_token_id
            for filled_pos in filled_positions:
                current_context[0, filled_pos] = full_seq[0, filled_pos]

            # Get logits for this position given current context
            with torch.no_grad():
                # Order that puts this position last
                context_order = torch.tensor([list(filled_positions) + [pos]], device=device)

                logits = get_logits_for_position(model, current_context, context_order, smiles_string, pos)

                # Sample a token and record its probability
                logits = logits / temperature
                probs = F.softmax(logits, dim=-1)
                token_dist = torch.distributions.Categorical(probs)
                token = token_dist.sample().item()

                candidate_tokens[pos] = token
                proposal_probs[pos] = probs[0, token].item()

        # Step 2: Evaluate acceptance under different orders
        best_order_acceptances = []

        for _ in range(num_orders):
            # Create a random permutation of remaining positions
            eval_order = random.sample(remaining_positions, len(remaining_positions))

            accepted_tokens = []
            accepted_positions = []
            acceptance_ratios = []

            # Try to accept tokens in this order
            for pos in eval_order:
                # Create sequence with previously accepted tokens
                temp_seq = full_seq.clone()
                for acc_pos in accepted_positions:
                    temp_seq[0, acc_pos] = candidate_tokens[acc_pos]

                # Get conditional probability q(xÃÉ|X,xÃÉœÉ<i)
                filled_plus_accepted = list(filled_positions) + accepted_positions
                context_order = torch.tensor([filled_plus_accepted + [pos]], device=device)

                with torch.no_grad():
                    cond_logits = get_logits_for_position(
                        model, temp_seq, context_order, processed_smiles, pos
                    )

                    cond_probs = F.softmax(cond_logits / temperature, dim=-1)
                    cond_prob = cond_probs[0, candidate_tokens[pos]].item()

                # Compute acceptance ratio r = q(xÃÉi|X,xÃÉœÉ<i) / p(xÃÉi|X)
                # Where p(xÃÉi|X) is the proposal probability
                acceptance_ratio = min(1.0, cond_prob / proposal_probs[pos])

                # Decide whether to accept
                if random.random() < acceptance_ratio:
                    accepted_tokens.append(candidate_tokens[pos])
                    accepted_positions.append(pos)
                    acceptance_ratios.append(acceptance_ratio)
                else:
                    # Stop at first rejection
                    break

            best_order_acceptances.append((accepted_positions, accepted_tokens, acceptance_ratios))

        # Step 3: Dynamic token acceptance
        best_order_idx = -1
        max_accepted = -1
        min_sequence_idx = -1

        for idx, (accepted_positions, _, acceptance_ratios) in enumerate(best_order_acceptances):
            if len(accepted_positions) > max_accepted:
                max_accepted = len(accepted_positions)
                best_order_idx = idx
                # Find the minimum position in the sequence where we see a rejection
                if len(accepted_positions) < len(remaining_positions):
                    min_sequence_idx = len(accepted_positions)
                else:
                    min_sequence_idx = len(remaining_positions)

        # No need to calculate min across orders if all orders accept all tokens
        if min_sequence_idx == -1:
            min_sequence_idx = len(remaining_positions)

        # Get the best order
        best_order = best_order_acceptances[best_order_idx]
        accepted_positions, accepted_tokens, _ = best_order

        # Limit acceptance to positions before the minimum rejection
        accepted_positions = accepted_positions[:min_sequence_idx]
        accepted_tokens = accepted_tokens[:min_sequence_idx]

        # Update the sequence with accepted tokens
        for pos, token in zip(accepted_positions, accepted_tokens):
            full_seq[0, pos] = token
            filled_positions.add(pos)

            # Check for EOS token
            if token == model.protGPT2_tokenizer.eos_token_id:
                break

    # Decode the generated sequence
    result = model.protGPT2_tokenizer.decode(
        [t for t in full_seq[0].tolist() if t != model.protGPT2_tokenizer.pad_token_id],
        skip_special_tokens=True
    )
    return result

def get_logits_for_position(model, sequence, order, smiles_string, target_position):
    """Helper function to get logits for a specific position"""
    # Run model forward pass
    logits, _ = model(
        smiles_string,  # Pass the SMILES string
        order=order,
        optimize=True
    )

    # Return logits for target position (last position in the order)
    return logits[:, -1, :]

In [None]:
import time
import numpy as np

In [None]:
def evaluate_on_unique_smiles(model, test_loader, device, output_file="generated_proteins_comparison.json"):
    """Generate proteins using both methods on unique SMILES from test set"""
    model.eval()

    # Collect unique SMILES from the test loader
    unique_smiles = set()
    for batch in test_loader:
        unique_smiles.update(batch['smiles'])

    unique_smiles = list(unique_smiles)  # Convert to list
    print(f"Found {len(unique_smiles)} unique SMILES in test set")

    results = []

    # Generate proteins using both methods and time each generation
    for i, smiles in enumerate(tqdm(unique_smiles, desc="Generating proteins")):
        # Track time for autoregressive generation
        start_time = time.time()
        print('autoregressive generations...')
        print(f"Generating protein for SMILES: {smiles}")
        ar_protein = generate_autoregressively(model, smiles, max_length=914, temperature=1.0, random_order=False)
        print(ar_protein)
        ar_time = time.time() - start_time

        # Track time for rejection sampling
        start_time = time.time()
        print('rejection sampling generations...')
        print(f"Generating protein for SMILES: {smiles}")
        rs_protein = generate_with_rejection_sampling(model, smiles, max_length=914, num_orders=5, temperature=1.0)
        print(rs_protein)
        rs_time = time.time() - start_time

        results.append({
            'SMILES': smiles,
            'Autoregressive': {
                'protein': ar_protein,
                'time_seconds': ar_time
            },
            'Rejection_Sampling': {
                'protein': rs_protein,
                'time_seconds': rs_time
            }
        })

        # Print progress occasionally
        if (i + 1) % 5 == 0:
            print(f"\nCompleted {i+1}/{len(unique_smiles)}")
            print(f"Example - SMILES: {smiles}")
            print(f"Autoregressive: {ar_protein[:50]}... ({ar_time:.2f}s)")
            print(f"Rejection Sampling: {rs_protein[:50]}... ({rs_time:.2f}s)")

    # Save results to JSON
    with open(output_file, 'w') as f:
        json.dump(results, f, indent=2)

    # Calculate and print average times
    ar_times = [r['Autoregressive']['time_seconds'] for r in results]
    rs_times = [r['Rejection_Sampling']['time_seconds'] for r in results]

    print(f"\nGeneration complete!")
    print(f"Average autoregressive generation time: {np.mean(ar_times):.2f}s")
    print(f"Average rejection sampling generation time: {np.mean(rs_times):.2f}s")
    print(f"Speed improvement: {np.mean(ar_times)/np.mean(rs_times):.2f}x")

    return results

In [None]:
results = evaluate_on_unique_smiles(model, test_loader, device, output_file="sigma_gpt_comparison_results.json")