### setup

In [1]:
!pip install Levenshtein
!pip install einops
!pip install einops_exts
!pip install torch
!pip install transformers
!pip install tqdm
!pip install sentencepiece
# !pip install fair-esm



In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import pandas as pd
import re
import math
import json
from tqdm import tqdm
from einops import rearrange, repeat
# import esm

# Set up GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# # Load ESM-2 model
# esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# batch_converter = alphabet.get_batch_converter()
# esm_model = esm_model.to(device)  # Move to GPU if available
# esm_model.eval()  # Set to evaluation mode


Using device: cuda


### data

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
def preprocess_snp_data(file_path):
    snp_df = pd.read_csv(file_path)

    # Basic preprocessing and length calculations
    snp_df['smiles_length'] = snp_df['smiles'].apply(len)
    snp_df['protein_length'] = snp_df['protein_sequence'].apply(len)

    return snp_df

def filter_datasets(dataset):
    return dataset[
        (dataset['smiles'].notna()) &
        (dataset['protein_sequence'].notna()) &
        (dataset['smiles_length'] > 0) &
        (dataset['protein_length'] > 0)
    ]

class ProteinGenerationDataset(Dataset):
    def __init__(self, dataframe, max_length):
        self.dataframe = dataframe
        self.max_length = max_length

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        return row['smiles'], row['protein_sequence']

def collate_fn(batch):
    """
    Custom collate function to handle padding within batches.
    Args:
        batch: List of tuples (smiles, protein)
    Returns:
        Padded and batched tensors
    """
    smiles, proteins = zip(*batch)

    # SMILES strings don't need padding as PolyBERT handles that internally
    smiles = list(smiles)

    # Get max length in this batch for proteins (not exceeding dataset max_length)
    max_protein_len = min(max(len(p) for p in proteins), max_length)

    # Pad proteins to max length in batch
    padded_proteins = []
    protein_masks = []

    for protein in proteins:
        if len(protein) > max_protein_len:
            padded = protein[:max_protein_len]
            mask = [1] * max_protein_len
        else:
            padded = protein + ' ' * (max_protein_len - len(protein))
            mask = [1] * len(protein) + [0] * (max_protein_len - len(protein))

        padded_proteins.append(padded)
        protein_masks.append(mask)

    return {
        'smiles': smiles,
        'proteins': padded_proteins,
        'protein_masks': torch.tensor(protein_masks, dtype=torch.bool)
    }

### utils

In [5]:
# Model Components
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:x.size(0)]

class DoublePositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # Use the full embedding dimension divided into two halves
        self.d_model = d_model
        half_dim = d_model // 2

        # Create position encodings for both input and output positions
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, half_dim, 2) * (-math.log(10000.0) / half_dim))

        # Input position encodings
        pe_input = torch.zeros(max_len, half_dim)
        pe_input[:, 0::2] = torch.sin(position * div_term)
        pe_input[:, 1::2] = torch.cos(position * div_term)

        # Output position encodings
        pe_output = torch.zeros(max_len, half_dim)
        pe_output[:, 0::2] = torch.sin(position * div_term)
        pe_output[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe_input', pe_input)
        self.register_buffer('pe_output', pe_output)

    def forward(self, x, input_positions, output_positions):
        batch_size, seq_length, _ = x.shape

        # Create a tensor of zeros with the same shape as the input
        pos_encoding = torch.zeros_like(x)

        # For each item in the batch
        for b in range(batch_size):
            for t in range(seq_length):
                # Get the input and output positions for this token
                input_pos = input_positions[b, t] if input_positions is not None else t
                output_pos = output_positions[b, t] if output_positions is not None else t

                if input_pos < self.pe_input.size(0) and output_pos < self.pe_output.size(0):
                    # Fill the first half with input position encoding
                    pos_encoding[b, t, :self.d_model//2] = self.pe_input[input_pos]
                    # Fill the second half with output position encoding
                    pos_encoding[b, t, self.d_model//2:] = self.pe_output[output_pos]

        return x + pos_encoding

class PerceiverAttention(nn.Module):
    def __init__(self, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        x: [batch_size, seq_len_x, dim]
        latents: [batch_size, seq_len_l, dim]
        """
        batch_size = x.shape[0]

        x = self.norm_media(x)
        latents = self.norm_latents(latents)

        # Ensure latents has correct batch size
        if latents.size(0) != batch_size:
            latents = latents.expand(batch_size, -1, -1)

        q = self.to_q(latents)
        q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads)
        q = q * self.scale

        # Ensure proper concatenation
        kv_input = torch.cat((x, latents), dim=1)  # concatenate along sequence dimension
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
        k = rearrange(k, 'b n (h d) -> b h n d', h=self.heads)
        v = rearrange(v, 'b n (h d) -> b h n d', h=self.heads)

        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)
        attn = sim.softmax(dim=-1)
        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')

        return self.to_out(out)

class GatedCrossAttentionBlock(nn.Module):
    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.attn = PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads)
        self.attn_gate = nn.Parameter(torch.tensor([0.]))
        self.ff = FeedForward(dim, mult=ff_mult)
        self.ff_gate = nn.Parameter(torch.tensor([0.]))

    def forward(self, x, media):
        """
        x: [batch_size, seq_len_x, dim]
        media: [batch_size, seq_len_m, dim]
        """
        batch_size = x.shape[0]
        target_batch_size = media.size(0)

        # Handle batch size mismatch
        if batch_size > target_batch_size:
            media = media.expand(batch_size, -1, -1)
        elif batch_size < target_batch_size:
            x = x.expand(target_batch_size, -1, -1)

        gate = self.attn_gate.tanh()
        x = self.attn(media, x) * gate + x
        x = self.ff(x) * self.ff_gate.tanh() + x
        return x

class PerceiverResampler(nn.Module):
    def __init__(self, dim, depth, dim_head=64, heads=8, num_latents=64):
        super().__init__()
        # Initialize latents without batch dimension
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                FeedForward(dim=dim)
            ]))

    def forward(self, x):
        batch_size = x.shape[0]
        # Expand latents to batch size
        latents = repeat(self.latents, 'n d -> b n d', b=batch_size)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        return latents

class FeedForward(nn.Module):
    def __init__(self, dim, mult=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * mult, bias=False),
            nn.GELU(),
            nn.Linear(dim * mult, dim, bias=False)
        )

    def forward(self, x):
        return self.net(x)

# class PerceiverResampler(nn.Module):
#     def __init__(self, dim, depth, dim_head=64, heads=8, num_latents=64):
#         super().__init__()
#         self.latents = nn.Parameter(torch.randn(num_latents, dim))
#         self.layers = nn.ModuleList([])

#         for _ in range(depth):
#             self.layers.append(nn.ModuleList([
#                 PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
#                 FeedForward(dim=dim)
#             ]))

#     def forward(self, x):
#         latents = repeat(self.latents, 'n d -> b n d', b=x.shape[0])

#         for attn, ff in self.layers:
#             latents = attn(x, latents) + latents
#             latents = ff(latents) + latents

#         return latents

# class GatedCrossAttentionBlock(nn.Module):
#     def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
#         super().__init__()
#         self.attn = PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads)
#         self.attn_gate = nn.Parameter(torch.tensor([0.]))
#         self.ff = FeedForward(dim, mult=ff_mult)
#         self.ff_gate = nn.Parameter(torch.tensor([0.]))

#     def forward(self, x, media):
#         gate = self.attn_gate.tanh()
#         x = self.attn(media, x) * gate + x
#         x = self.ff(x) * self.ff_gate.tanh() + x
#         return x

### PolyBert Encoder

In [6]:
from transformers import AutoTokenizer, AutoModel
import torch


In [7]:
# class PolyBERTEncoder(nn.Module):
#     def __init__(self, output_dim):
#         super().__init__()
#         self.polybert = AutoModel.from_pretrained('kuelumbus/polyBERT')
#         self.tokenizer = AutoTokenizer.from_pretrained('kuelumbus/polyBERT')
#         self.output_dim = output_dim
#         # Add a projection layer to match the required dimension
#         self.projection = nn.Linear(self.polybert.config.hidden_size, output_dim)

#     def mean_pooling(self, model_output, attention_mask):
#         token_embeddings = model_output[0]
#         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
#         return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

#     def forward(self, smiles_strings):
#         # Tokenize the SMILES strings
#         encoded_input = self.tokenizer(smiles_strings,
#                                      padding=True,
#                                      truncation=True,
#                                      return_tensors='pt').to(next(self.polybert.parameters()).device)

#         # Get PolyBERT embeddings
#         with torch.no_grad():
#             model_output = self.polybert(**encoded_input)

#         # Debug prints
#         print("Model Output Keys:", model_output.keys())  # Check available keys
#         # print("Last Hidden State:", model_output.last_hidden_state)
#         print("Last Hidden State Shape:", model_output.last_hidden_state.shape)

#         # Pool the embeddings
#         pooled_output = self.mean_pooling(model_output, encoded_input['attention_mask'])

#         # print("Pooled Output:", pooled_output)
#         print("Pooled Output Shape:", pooled_output.shape)

#         # Project to required dimension
#         projected_output = self.projection(pooled_output)

#         return projected_output

In [8]:
class PolyBERTEncoder(nn.Module):
    def __init__(self, output_dim):
        super().__init__()
        self.polybert = AutoModel.from_pretrained('kuelumbus/polyBERT')
        self.tokenizer = AutoTokenizer.from_pretrained('kuelumbus/polyBERT')
        self.output_dim = output_dim
        # Project each token embedding to required dimension
        self.projection = nn.Linear(self.polybert.config.hidden_size, output_dim)

    def forward(self, smiles_strings):
        # Tokenize the SMILES strings
        encoded_input = self.tokenizer(smiles_strings,
                                     padding=True,
                                     truncation=True,
                                     return_tensors='pt').to(next(self.polybert.parameters()).device)

        # Get PolyBERT embeddings
        with torch.no_grad():
            model_output = self.polybert(**encoded_input)

        # Debug prints
        # print("Model Output Keys:", model_output.keys())
        # print("Last Hidden State Shape:", model_output.last_hidden_state.shape)  # [batch_size, seq_len, hidden_size]

        # Get sequence embeddings
        sequence_embeddings = model_output.last_hidden_state

        # Project each token embedding to required dimension
        projected_output = self.projection(sequence_embeddings)  # [batch_size, seq_len, output_dim]
        # print("Projected Output Shape:", projected_output.shape)

        return projected_output

### ProtFlamingo

In [9]:
import torch.nn.functional as F


In [10]:
class SigmaProtFlamingo(nn.Module):
    def __init__(self, model_path, max_len, cross_attn_every=3, dim_head=64, heads=8, perceiver_depth=2, perceiver_num_latents=64):
        super().__init__()

        self.protGPT2_model = GPT2LMHeadModel.from_pretrained(model_path)
        self.protGPT2_tokenizer = GPT2Tokenizer.from_pretrained(model_path)
        self.max_len = max_len

        if self.protGPT2_tokenizer.pad_token is None:
            self.protGPT2_tokenizer.pad_token = self.protGPT2_tokenizer.eos_token
            self.protGPT2_model.config.pad_token_id = self.protGPT2_model.config.eos_token_id

        self.cross_attn_every = cross_attn_every

        # PolyBERT encoder for SMILES strings
        self.polybert_encoder = PolyBERTEncoder(self.protGPT2_model.config.n_embd)

        # Replace single positional encoding with double positional encoding
        self.positional_encoding = DoublePositionalEncoding(self.protGPT2_model.config.n_embd, max_len=max_len)

        # Single perceiver resampler for SMILES embeddings
        self.smiles_perceiver = PerceiverResampler(
            dim=self.protGPT2_model.config.n_embd,
            depth=perceiver_depth,
            dim_head=dim_head,
            heads=heads,
            num_latents=perceiver_num_latents
        )

        # Cross attention layers
        num_gpt_layers = len(self.protGPT2_model.transformer.h)
        self.cross_attn = nn.ModuleList([
            GatedCrossAttentionBlock(dim=self.protGPT2_model.config.n_embd, dim_head=dim_head, heads=heads)
            for _ in range(num_gpt_layers)
        ])

        # Combine GPT layers with cross attention
        self.layers = nn.ModuleList()
        for i, block in enumerate(self.protGPT2_model.transformer.h):
            self.layers.append(block)
            if i % cross_attn_every == 0 and i != 0:
                self.layers.append(GatedCrossAttentionBlock(dim=self.protGPT2_model.config.n_embd, dim_head=dim_head, heads=heads))

    def forward(self, smiles_strings, order=None, targets=None, optimize=False, kv_cache=None, burst=False):
        device = next(self.parameters()).device

        # Get SMILES embeddings through PolyBERT
        smiles_embeddings = self.polybert_encoder(smiles_strings)
        processed_smiles = self.smiles_perceiver(smiles_embeddings)

        # Initialize with start token
        gpt_input = self.protGPT2_tokenizer.encode_plus(
            "<|endoftext|>",
            return_tensors="pt",
            padding='max_length',
            max_length=self.max_len,
            truncation=True
        ).to(device)

        input_ids = gpt_input.input_ids.long()
        seq_length = input_ids.size(1)
        batch_size = 1 if isinstance(smiles_strings, str) else len(smiles_strings)

        hidden_states = self.protGPT2_model.transformer.wte(input_ids)

        # If no order is provided, use left-to-right
        if order is None:
            order = torch.arange(seq_length, device=device).unsqueeze(0).repeat(batch_size, 1)

        # Make sure order is the right length
        if order.size(1) > seq_length:
            order = order[:, :seq_length]
        elif order.size(1) < seq_length:
            # Pad order if needed
            padding = torch.arange(order.size(1), seq_length, device=device).unsqueeze(0).repeat(batch_size, 1)
            order = torch.cat([order, padding], dim=1)

        # Map the input tokens according to the order
        # When using random order, we need to reshuffle the input tokens
        if not optimize and not burst:  # Only shuffle during training
            reordered_input_ids = torch.zeros_like(input_ids)
            for b in range(batch_size):
                # Reorder the input tokens according to the order
                reordered_input_ids[b] = input_ids[b, order[b]]

            # Re-embed with reordered tokens
            hidden_states = self.protGPT2_model.transformer.wte(reordered_input_ids)

        # Get input and output positions from the order
        # Input positions: the current position in the order
        # Output positions: the next position in the order
        input_positions = order
        # Shift the order by 1 to get output positions (target positions)
        output_positions = torch.roll(order, -1, dims=1)
        # The last position wraps to the first position
        output_positions[:, -1] = order[:, 0]

        # Apply double positional encoding
        hidden_states = self.positional_encoding(hidden_states, input_positions, output_positions)

        # Create attention mask based on the order
        attention_mask = gpt_input.attention_mask
        num_heads = self.protGPT2_model.config.n_head

        # Create 4D attention mask [batch_size, num_heads, seq_length, seq_length]
        attention_mask = attention_mask.view(batch_size, 1, 1, seq_length)
        attention_mask = attention_mask.expand(batch_size, num_heads, seq_length, seq_length)
        attention_mask = attention_mask.to(dtype=hidden_states.dtype)

        # Create causal mask based on the order
        # A token at position i can attend to tokens at positions j where order[j] <= order[i]
        # Vectorized causal mask creation
        seq_indices = torch.arange(seq_length, device=device)
        expanded_seq_indices_i = seq_indices.unsqueeze(1).expand(seq_length, seq_length)
        expanded_seq_indices_j = seq_indices.unsqueeze(0).expand(seq_length, seq_length)

        causal_mask = torch.zeros((batch_size, seq_length, seq_length), device=device)
        for b in range(batch_size):
            # Get order for this batch
            order_b = order[b]
            # Get order values at positions i and j
            order_i = order_b[expanded_seq_indices_i]
            order_j = order_b[expanded_seq_indices_j]
            # Create mask where order_j <= order_i
            causal_mask[b] = (order_j <= order_i).float()

        # Reshape causal_mask to match attention_mask and combine them
        causal_mask = causal_mask.unsqueeze(1)  # [batch_size, 1, seq_length, seq_length]
        combined_mask = attention_mask * causal_mask

        for i, layer in enumerate(self.layers):
            if isinstance(layer, GatedCrossAttentionBlock):
                hidden_states = layer(hidden_states, processed_smiles)
            else:
                hidden_states = layer(hidden_states, attention_mask=combined_mask)[0]

        # Get logits
        logits = self.protGPT2_model.lm_head(hidden_states)

        if targets is None:
            if optimize:
                # inference-time mini-optimization: only forward the lm_head on the very last position
                return logits[:, [-1], :], None
            return logits, None

        # Compute loss against the targets
        # If targets are provided in original order, we need to shuffle them to match our order
        if targets is not None:
            shuffled_targets = torch.zeros_like(targets)
            for b in range(batch_size):
                # Reorder the targets according to the order
                shuffled_targets[b] = targets[b, order[b]]

            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                shuffled_targets.view(-1),
                ignore_index=-1
            )
        else:
            loss = None

        return logits, loss

    def custom_generate(self, smiles_string, max_length=200):
        device = next(self.parameters()).device

        # Get SMILES embeddings
        smiles_embeddings = self.polybert_encoder(smiles_string)
        processed_smiles = self.smiles_perceiver(smiles_embeddings)

        # Initialize with start token
        input_ids = torch.tensor([[self.protGPT2_tokenizer.bos_token_id]]).to(device)

        # Autoregressive generation
        for _ in range(max_length):
            inputs_embeds = self.protGPT2_model.transformer.wte(input_ids)
            inputs_embeds = self.positional_encoding(inputs_embeds)

            hidden_states = inputs_embeds
            cross_attn_idx = 0

            for i, layer in enumerate(self.layers):
                if isinstance(layer, GatedCrossAttentionBlock):
                    hidden_states = layer(hidden_states, processed_smiles)
                    cross_attn_idx += 1
                else:
                    hidden_states = layer(hidden_states, attention_mask=None)[0]

            next_token_logits = self.protGPT2_model.lm_head(hidden_states[:, -1, :])
            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)

            input_ids = torch.cat([input_ids, next_token], dim=-1)

            if next_token.item() == self.protGPT2_tokenizer.eos_token_id:
                break

        return self.protGPT2_tokenizer.decode(input_ids[0], skip_special_tokens=True)

    def generate(self, smiles_string, max_length=50):
        return self.custom_generate(smiles_string, max_length)

    def state_dict(self):
        state_dict = super().state_dict()
        state_dict['smiles_perceiver'] = self.smiles_perceiver.state_dict()
        state_dict['cross_attn'] = self.cross_attn.state_dict()
        state_dict['polybert_encoder'] = self.polybert_encoder.state_dict()
        return state_dict

    def load_state_dict(self, state_dict):
        smiles_perceiver_state = state_dict.pop('smiles_perceiver')
        cross_attn_state = state_dict.pop('cross_attn')
        polybert_encoder_state = state_dict.pop('polybert_encoder')

        super().load_state_dict(state_dict)

        self.smiles_perceiver.load_state_dict(smiles_perceiver_state)
        self.cross_attn.load_state_dict(cross_attn_state)
        self.polybert_encoder.load_state_dict(polybert_encoder_state)

In [11]:
import random

### training

In [12]:
import torch
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
import json
import pandas as pd
import os
import matplotlib.pyplot as plt

In [15]:
def train_with_random_order(model, train_loader, val_loader, num_epochs, device,
                           curriculum_steps=0, l2_reg=1e-5, sample_smiles=None):
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=l2_reg)
    criterion = nn.CrossEntropyLoss(ignore_index=model.protGPT2_tokenizer.pad_token_id)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    loss_log = []
    checkpoint_path = "/content/drive/MyDrive/classes+projects/plastic_enzyme_project/2024/codes/sigma_checkpoint.pth"

    if os.path.exists(checkpoint_path):
        print("Loading checkpoint...")
        model.load_state_dict(torch.load(checkpoint_path))

    # Total training steps for curriculum
    total_steps = num_epochs * len(train_loader)
    step_counter = 0

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        batch_losses = []

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            step_counter += 1

            smiles_strings = batch['smiles']
            proteins = batch['proteins']
            protein_masks = batch['protein_masks'].to(device)

            # Determine whether to use random or left-to-right order based on curriculum
            if curriculum_steps > 0:
                # Current percentage of random ordering
                random_prob = min(1.0, step_counter / curriculum_steps)
                use_random = random.random() < random_prob
            else:
                use_random = True

            batch_size = len(smiles_strings)
            seq_length = model.max_len

            # Generate order
            if use_random:
                # Create random permutation for each batch item
                order = torch.stack([torch.randperm(seq_length) for _ in range(batch_size)]).to(device)
            else:
                # Left-to-right order
                order = torch.arange(seq_length).unsqueeze(0).repeat(batch_size, 1).to(device)

            optimizer.zero_grad()

            # Encode the proteins
            target_encoding = model.protGPT2_tokenizer(
                proteins,
                return_tensors="pt",
                padding='max_length',
                max_length=model.max_len,
                truncation=True
            ).to(device)

            # Forward pass with order - the targets are already in original order
            # The model will handle shuffling the targets according to the order
            outputs, loss = model(
                smiles_strings,
                order=order,
                targets=target_encoding.input_ids
            )

            # Print the raw token IDs (before decoding)
            print(f"Ground Truth Token IDs: {target_encoding.input_ids[0]}")
            print(f"Generated Protein Token IDs: {outputs[0]}")
            print(outputs[0].shape)  # Check the shape of the output tensor
            print(target_encoding.input_ids[0].shape)  # Check the shape of the target tensor


            # Get the token IDs from the logits (choose the token with the highest probability)
            predicted_token_ids = torch.argmax(outputs, dim=-1)

            # Now decode the token IDs to a protein sequence
            generated_proteins = model.protGPT2_tokenizer.decode(predicted_token_ids[0], skip_special_tokens=True)
            ground_truth_proteins = model.protGPT2_tokenizer.decode(target_encoding.input_ids[0], skip_special_tokens=True)

            # Print the ground truth and generated sequences
            print(f"Ground Truth: {ground_truth_proteins}")
            print(f"Generated Protein: {generated_proteins}")

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
            batch_losses.append(loss.item())

        avg_loss = total_loss / len(train_loader)
        print(f"\nEpoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")
        print(f"Per-batch Losses: {batch_losses[:5]} ...")

        val_loss = validate_with_random_order(model, val_loader, criterion, device)
        print(f"Validation Loss: {val_loss:.4f}")

        loss_log.append({
            'epoch': epoch+1,
            'train_loss': avg_loss,
            'val_loss': val_loss
        })

        scheduler.step()
        torch.save(model.state_dict(), checkpoint_path)

        # Generate sample proteins after each epoch
        if sample_smiles:
            print("\nSample Generated Proteins:")
            for smiles in sample_smiles:
                # Try both random and left-to-right generation
                gen_protein_lr = generate_autoregressively(model, smiles, max_length=100, random_order=False)
                gen_protein_random = generate_autoregressively(model, smiles, max_length=100, random_order=True)
                print(f"SMILES: {smiles}")
                print(f"Generated Protein (L2R): {gen_protein_lr}")
                print(f"Generated Protein (Random): {gen_protein_random}\n")

    # Save loss log
    loss_df = pd.DataFrame(loss_log)
    loss_df.to_csv("/content/drive/MyDrive/classes+projects/plastic_enzyme_project/2024/codes/sigma_loss_log.csv", index=False)

    # Plot loss
    plt.figure(figsize=(8, 5))
    plt.plot(loss_df['epoch'], loss_df['train_loss'], label='Train Loss', marker='o')
    plt.plot(loss_df['epoch'], loss_df['val_loss'], label='Validation Loss', marker='s')
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.title("Training and Validation Loss")
    plt.savefig("sigma_loss_plot.png")
    plt.show()

def validate_with_random_order(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for batch in val_loader:
            smiles_strings = batch['smiles']
            proteins = batch['proteins']
            protein_masks = batch['protein_masks'].to(device)

            batch_size = len(smiles_strings)
            seq_length = model.max_len

            # Create random order for each item in the batch
            order = torch.stack([torch.randperm(seq_length) for _ in range(batch_size)]).to(device)

            # Encode the proteins
            target_encoding = model.protGPT2_tokenizer(
                proteins,
                return_tensors="pt",
                padding='max_length',
                max_length=model.max_len,
                truncation=True
            ).to(device)

            # Forward pass with order
            outputs, loss = model(
                smiles_strings,
                order=order,
                targets=target_encoding.input_ids
            )

            total_loss += loss.item()

    return total_loss / len(val_loader)

### inference + training

In [16]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load and preprocess data
train_data = preprocess_snp_data('/content/augmented_train.csv')
val_data = preprocess_snp_data('/content/augmented_val.csv')
test_data = preprocess_snp_data('/content/augmented_test.csv')

train_data = filter_datasets(train_data)
val_data = filter_datasets(val_data)
test_data = filter_datasets(test_data)

# Calculate max sequence length
max_length = max(
    train_data['protein_length'].max(),
    val_data['protein_length'].max(),
    test_data['protein_length'].max()
)
max_length = min(max_length, 1024)  # Cap at 1024 or your desired maximum
print(f"Max sequence length: {max_length}")

# Create datasets
train_dataset = ProteinGenerationDataset(train_data, max_length)
val_dataset = ProteinGenerationDataset(val_data, max_length)
test_dataset = ProteinGenerationDataset(test_data, max_length)

# Create dataloaders with custom collate function
train_loader = DataLoader(
    train_dataset,
    batch_size=1,  # Adjust based on your GPU memory
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    collate_fn=collate_fn
)

test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    collate_fn=collate_fn
)

# Initialize model with sigma-gpt capabilities
model = SigmaProtFlamingo(
    model_path='nferruz/ProtGPT2',
    max_len=max_length,
    cross_attn_every=3,
    dim_head=64,
    heads=8,
    perceiver_depth=2,
    perceiver_num_latents=64
).to(device)

num_epochs = 10

# Training loop with curriculum learning
# Start with 50% of sequences in left-to-right order and gradually increase to 100% random
curriculum_steps = int(0.5 * num_epochs * len(train_loader))  # Curriculum over first half of training
print("Starting training with sigma-gpt capabilities...")
train_with_random_order(
    model,
    train_loader,
    val_loader,
    num_epochs,
    device,
    curriculum_steps=curriculum_steps
)

# # Generate and evaluate
# print("Generating proteins for test set...")
# test_results = generate_and_evaluate(model, test_loader, device)

# # Save results
# print("Saving results...")
# results_path = '/content/drive/MyDrive/classes+projects/plastic_enzyme_project/2024/codes/test_results.json'
# with open(results_path, 'w') as f:
#     json.dump(test_results, f, indent=2)

# print(f"Results saved to {results_path}")


Using device: cuda
Max sequence length: 914
Starting training with sigma-gpt capabilities...
Loading checkpoint...


  model.load_state_dict(torch.load(checkpoint_path))
Epoch 1/10:   0%|          | 0/3111 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Ground Truth Token IDs: tensor([ 1811,  1039,   483, 12115,   333,  1598, 12126,   693,   280,  1431,
         2607,  4162, 25305,   331,  1817,  3734,   410,  4857,   279,  1768,
        29364,  3533, 19003,  1945,   342,  1227, 18786,   288,  1382,   410,
          732,   403,  1566,  7401,  9512,  3189,   285,   424,   621,   372,
         1382, 19839,   367,  3857,   691,   326,   500,  7062, 13528,   410,
        27266,  3393,  5663,  1549,   363,   721, 22684,   511, 11843,  1107,
        12129,   542,  2708,   325,  2734, 13889,   434,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     

Epoch 1/10:   0%|          | 1/3111 [00:01<1:08:31,  1.32s/it]

Ground Truth Token IDs: tensor([  468,  2613,   498,   619,  1747,  3179,   523,  3426, 11581,   463,
         3464,   483, 11144,   542,  1747,  1717,  1056,   361,  1321,   427,
         1856,  1090,   282,   761, 41525,  3939,   315,  4580,   468,  7057,
        46998,  1277,  1978,  7922,   805,   334,  1345, 42689, 14852,   841,
        46088,   556, 11852,   597, 16334,  1480,   473,  1615,  1345, 17922,
         5704, 29017,   722, 23267,   355,  3069,  1878,   284, 17498,   435,
        16338,  4569,  5182, 24261,   734,   969, 22028,   464,  1465,  1822,
         1520,  4320,   334,   706,   291,   851,    54,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     

Epoch 1/10:   0%|          | 2/3111 [00:02<49:15,  1.05it/s]  

Ground Truth Token IDs: tensor([43534,   368,  1969, 44486,  2212,   280,  1623,   332,  4060,  1201,
         1287,   631,  4818,  2156,   441, 29108,  1746,   271,   820,   296,
         1724, 22261, 11540,   428,  9272,   366,  1324,  6179,   407,   850,
        19785,   335,   866,  5007,   430,   466,    45,  1470,   537,  8676,
         1495,   265,  4523,   375,   400,  1298,  4324,  3568,   378,  1362,
          717,   916,   350,  4035,  1667,  2346, 14482,   267,   845, 14083,
          412,  8207,   458,  1819, 21243,   704,   327,  4202,   681,  6611,
          291,   532,   325,  1502,  1834,   299,  1877, 11540,   351,  1297,
          391,   377,  3367, 22445,  1567,   361,  5579, 10074,  2077,   568,
          507,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     

Epoch 1/10:   0%|          | 3/3111 [00:02<44:21,  1.17it/s]

Ground Truth Token IDs: tensor([14709, 31184,   588,   442, 31233,   789,   335,   608,   358,   866,
         3089,   296,  1492,   405,   296,  1645,  1032,  2597,   280,  1291,
          535,  4817,  1645,   487,  8238, 24085,  1614,   411, 18237,  2743,
        24782,   507,  6850,   665,  6602,  2504,   323,  1666,  1311,  1008,
          281,   457,   643,   407,  3757,   330,  3075,  1847,   393,   700,
           46,  1326, 35962,  1555,  7189,   378,   446,   579,   779, 27683,
        24775, 23920,   296,   457,   716,   278,  1431,   542,   291,  1856,
          287,  2816,   298,  1332,   281,  2280,   425, 12114, 24210,  1305,
         1407,   419,  4128,   623,   770,  3747,   692,   340,  7466,   422,
          712,  1422,   296,   753,   435,   477,   746,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     

Epoch 1/10:   0%|          | 4/3111 [00:03<40:58,  1.26it/s]

Ground Truth Token IDs: tensor([45460, 11577,   635,   789,  2593,  4987,   742, 15682,  3439,   331,
          684, 32886,  1191, 26838,  1564, 10755,   638,  1856,  1808,   325,
          778,   260,  1049,  3939, 47004,   445,  4976,   555,  1121, 13980,
         4316,   341,   734,  3142,   449,   420, 15562,   715,   331,   516,
        13184, 30813,  9147,  3027,   349,   473,   401,  2463,   356,  5480,
          357, 22453,   722,    57,  1847,  4250,  1035,  5849, 41129,   435,
        19996,  1646,   283,  2165,  7737,   625,   776, 19591, 16212,   413,
          493,   282,  3502,   361,   706,   334,   913,    54,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     

Epoch 1/10:   0%|          | 5/3111 [00:04<39:09,  1.32it/s]

Ground Truth Token IDs: tensor([  382, 25221,   774,   410,  1657,  1274,   262,  1174,   314,  1837,
        10524,   430,  3008,  1144,   488,  3148,  2201,  1045,   317,  1653,
          485,  2708,  8727,  1447,  5102,   731, 22605,   341,  2353,   408,
          794,   584,   275,   485,   298,  1679,  4197,   476,  8580, 24021,
          334,   611, 23802, 41838,  6331,  5448,   382, 12892,  1805, 15866,
         5460,   903,  4308,   301,  8396,  3516, 50167,  1682,    51,  1902,
        40707,  5982,  4817,   419,   496,  6797,   478,  3012,   744,  3321,
          319,  1431,   912,  5357,   978,   375,  1302,   261,  1114,   769,
        17642,   574,  7653,  1262,  3209,   275,  1253,   673,   366,  1450,
         3462,   446,   709,   373,  4510, 16564,  3350,   262,   642, 45802,
        17783,   365,   401,  6512,   377,  5867,    38,  1143, 17330,   510,
        37296,   513,   769,  2030,   523,   556,  6228, 19229,  7227,  1981,
          529,   375,   540,   257,  236

Epoch 1/10:   0%|          | 6/3111 [00:04<38:06,  1.36it/s]

Ground Truth Token IDs: tensor([21107,   623,  2369,  1747,  3179,   523,   291,  4029, 22696,  3464,
          483,   664,   327,   706, 31771,  1425, 27692,  1818,   332,  1856,
         1090, 39710, 41525,  3939,   467,  4580,   615,  7057,   464,   556,
          277,  1978,   835, 36696,   418, 48763,   420, 14852,   841, 48467,
          609, 11852,  6092,  1617,  1480,   473,  1814,  1345, 17922,  5704,
        29017,   722, 23267, 21685,  1103,   371, 16262,  1096, 12082, 12734,
          540,  2760, 15294,   734,  1617,   529,   708, 31700,   407, 40834,
          550,  1900,   706,   334,   639,    54,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     

Epoch 1/10:   0%|          | 7/3111 [00:05<37:13,  1.39it/s]

Ground Truth Token IDs: tensor([  339,  1782,  1366,  6443,   582,    57,   680,  1345,   604,   289,
          553,   696,  6241,   835,   280,  4056,  2912,  1069,   355,  1839,
          270,  3229,   355,   676,  2321,   614,  1957,  5061,   579,  8403,
          279,   684,   271,  1907, 22343, 43213,  1561,  5073,  3962,   640,
         5080,   567, 11942,  2321,   382,  4857,   753,   461,  1583,   747,
         3487,   462,  8533,   362,   674,  3566,   386, 10063,  4935,   728,
         7209,  2375,  5082,   807,   418, 24372,   345,  6344,  1733,  1594,
          487,  1612,   674,  1107,  1871,   662, 14890,  1369,   285,  2739,
          265, 36437,  2547,   496,   390, 30149,   267,  1160,   331,  5658,
          806,   331,  4587,   319,   690, 19271,   504,   278,   799,   260,
          573,  2955,  4036,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     

Epoch 1/10:   0%|          | 8/3111 [00:06<36:50,  1.40it/s]

Ground Truth Token IDs: tensor([19576,   329,   527,   430,   923,  1598,  7463,   693,   280,  1431,
         2607,  5305, 25305, 12599,   464,  1111,   841, 17243,  1768, 17895,
         3694, 13693,  1945,   342,  1227,   555, 21782,  1382,   410,   732,
          403,  1566, 12777,  9512,  3189,   285,   424,   551,   372,  1382,
        19839, 17088,   817,  3770,   500,  5981,   289,  1389,   370,   391,
          919,   284,  5663,  1549,   363,   721, 22684,  3494,   264, 21425,
          515,   936,  2189,  5168,  2734, 32167,   272,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     

Epoch 1/10:   0%|          | 9/3111 [00:06<36:20,  1.42it/s]

Ground Truth Token IDs: tensor([ 2110,   884,  1751,   570, 20654,  1385, 14609, 21055,   966,  6655,
          346,   296,  9172,   652,  2113,   429,  2147,   275,  1623,   372,
        14574,  1260, 41662,  4818,  2156,   441,   284,  1076, 23403,   931,
         1088,   544,   875,  4624,  9604,  3418, 25573,  1324,   468, 17522,
         9484, 21415,   866,  1985, 33091, 13411,  2034,   755,   426,   873,
         1925,   369,   587,  1515,   552,  1298,  4324,   733, 32207,  4918,
          916,   325,   655, 18670,   933, 41566, 14543,   415,  1496,   363,
         9556,   392,  1819,   446,   827,  1497,   938,   355, 11321,   688,
          532,  3224,   510, 17638,   625,   587,   295,   995,   351,  1124,
          391, 45419,  1499,   395,   463,    50,  1193, 10941, 39168,   748,
          279,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     

Epoch 1/10:   0%|          | 10/3111 [00:07<36:07,  1.43it/s]

Ground Truth Token IDs: tensor([  445,   536,   507, 13322, 22548,   875, 31653,  2015, 31253,  9341,
         3646,  4189,   295,   888, 29712, 12982,   791, 35169,  1105,  1490,
          353,  1880, 14524, 17464,   321, 29164,  2427,    57,   680,   394,
          817,   317,   795,   559,  9065,  1505,   434,   573,  6646,   317,
         1857,   333,   520,   389,   468, 21733,   458, 16048,  1746,   638,
          324,   418, 22503,  2031,   319,   851,   330, 17623,  1622,   414,
         6228,  3476,   343,  7084,   296,   413,   552,   372, 11936,  1705,
         3969,  3001,   925,   317, 30671, 19476,   731,  1629,  7591,  1134,
          353,  6878,   978,  2348,  2067, 10352,   326,   589, 10314, 30933,
          282,  1241,   798,   510,  4543, 24211,   838,   394, 15613,   859,
         4510, 13789, 11023, 13383,   726,   267,  1950,  1095,   357, 27393,
         1864,  9329,  1696,   379,   284,  5867, 11872,   740,   285,  3691,
          327,   954,   372,  2201,  985

Epoch 1/10:   0%|          | 11/3111 [00:08<36:23,  1.42it/s]

Ground Truth Token IDs: tensor([32578,  1076,  1733,   280,   854,  2321,   481, 14899,    55, 12971,
         5914,  3618,   746,  5004,   471,  6838, 38711,  1062,  1722,   333,
         1242,  1957,   955, 15731,  2302, 38153,   420,  5143,   358,  3052,
          416,  1658,  1535,   346,   440, 10980,   295,   759,  5307,   614,
         4118, 25332,  1434,   280,  1594,   591,  4471,   732,   355,  2199,
         7559,   678,   526,  7832,  1752,   425,  1488,   799,   485,   548,
         4029,  4236,   292, 11171,   281,  1976,   288,   850,  7326,  1345,
          567,   470,   437,   330,  8522,   435,   864,  1309,  7555,  2896,
          483,   320,    55,  3497,  8669,  6760,   294,  1766,   329,  1252,
        10252,  4575,  3012, 21457,  1486,   264,  4341,   292,   739,   348,
        42063,   260,  3229, 36552,  4626,  1751,   322,  1231,   542,   292,
          778,   388, 29572,   319,  3854,   299,  1480,  1357, 29420,  1983,
         1090,   780,   265,  2806, 4880

Epoch 1/10:   0%|          | 12/3111 [00:09<36:15,  1.42it/s]

Ground Truth Token IDs: tensor([ 1570,  6016,   586,   689,  1494,  2948,  2408,   426,  3542,  1505,
        30397,  1352,   331,  6538,  3887, 31304,  5486,   295,  1514,   682,
         4818,   376,  3548,   369,  8667,  1436,   278,   820,   296, 14529,
         2346,   382,  1506,  3128, 25573,  1324, 16828,   407,  1011,   531,
          370,   652, 20213,  3059, 17492,   334, 11429,  1403,  5413,  3766,
          400,  1298,   264,   669,   258,  1060,   830,   717,   916,   350,
           55,  1224,   357,  2836, 17956, 29188,   950,   391,  9058,   752,
          400, 18839,   926,   691,  5107,   681,  4773,   282, 43459,  2306,
        36896,  8075,   502,   435,  1570,  1288, 35703,  3367,   288,   549,
         1964,   284, 18925,  2393,   280,   471,  2373,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     

Epoch 1/10:   0%|          | 13/3111 [00:09<36:01,  1.43it/s]

Ground Truth Token IDs: tensor([ 1416, 15846,   709,  3461,  2848,   270,   830,   394,  1632,   389,
         7081,   383, 17990,  3383,   387,   846,   271,   818,   479, 12626,
          732,  8513,  3381,  1285,   499,   470,  1576,   566,  9464,  1375,
         1891,   407,  1650,  1777, 38389,  1195,  1139,   643, 14690, 29317,
          412,  1952,  8624,   362, 45717,   607, 48783,  1546,   400,   824,
          271,   548,   683,   378,  1362,  1640, 19884,  1102,   325,   730,
          397,   350, 42582,  1284,   280,   710,   325,  1779,  1094,  9365,
         1479, 18252,  1039,  1055,   386,   416,   723,   321,  1644,   359,
        48032,  1963,  3670,  4657,   517,   361,  2327,   369,   399,   474,
        41350,   525,   930,   910,   270,  1295,   430, 26915,  1027,  4021,
          963,   665,   395, 22518, 14364,  3924,   584,   296,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     

Epoch 1/10:   0%|          | 14/3111 [00:10<35:58,  1.43it/s]

Ground Truth Token IDs: tensor([ 2364,  1057,   331,  2408,   927,   808, 17426,   362,   598,  3755,
          510, 32528,   268,  1536, 17088,  2607,   289,  2349,  5896, 11965,
          743,   609,   318,   839,  4929,  9976,   315,   485,  1648,   637,
         8140,  3117,   452,   808, 14021,   802,   949, 24600,  2171,   330,
          413,  3270,   271,   424,   523, 25863,  2692,   619,  3457,   335,
         7162, 11966,   633,   329,  3306, 13396,   264,  1740,  2573, 33231,
         1521,   626,  4371, 36028,  6176,   923,   336,  1746,   417,   616,
        29553, 15640,  1361,   282,  1116,    33,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     

Epoch 1/10:   0%|          | 15/3111 [00:11<39:09,  1.32it/s]


KeyboardInterrupt: 

### generation

In [None]:
import torch
import json
from tqdm import tqdm

# Load the trained model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Path to model checkpoint
checkpoint_path = "/content/drive/MyDrive/classes+projects/plastic_enzyme_project/2024/codes/sigma_checkpoint.pth"

# Load model
model = SigmaProtFlamingo(
    model_path='nferruz/ProtGPT2',
    max_len=914,  # Ensure this matches the training max_len
    cross_attn_every=3,
    dim_head=64,
    heads=8,
    perceiver_depth=2,
    perceiver_num_latents=64
).to(device)



In [None]:
ProteinGenerationDataset

In [None]:
# Load trained weights
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()

# Load test data
test_data = preprocess_snp_data('/content/augmented_test.csv')
test_data = filter_datasets(test_data)

# Create test dataset and dataloader
test_dataset = ProteinGenerationDataset(test_data,max_length = 914 )
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    collate_fn=collate_fn
)



In [None]:
def generate_autoregressively(model, smiles_string, max_length=914, temperature=1.0, random_order=False):
    """Generate protein autoregressively, with option to use random order"""
    device = next(model.parameters()).device

    # Get SMILES embeddings
    smiles_embeddings = model.polybert_encoder([smiles_string])
    processed_smiles = model.smiles_perceiver(smiles_embeddings)

    # Initialize with start token
    input_ids = torch.tensor([[model.protGPT2_tokenizer.bos_token_id]], device=device)

    # If using random order, generate a random permutation
    if random_order:
        order = torch.randperm(max_length, device=device).unsqueeze(0)
    else:
        order = torch.arange(max_length, device=device).unsqueeze(0)

    # Track the current positions in the order
    current_pos = 0

    # Generated sequence in order's positions
    generated_sequence = torch.full((1, max_length), model.protGPT2_tokenizer.pad_token_id, device=device)
    generated_sequence[0, 0] = model.protGPT2_tokenizer.bos_token_id  # Start token

    while current_pos < max_length - 1:
        # Get the next position in the order
        next_pos = current_pos + 1

        # Forward pass to get next token prediction
        with torch.no_grad():
            # Use only the sequence up to the current position
            current_order = order[:, :next_pos]
            current_sequence = generated_sequence[:, current_order[0]]

            # Get logits for the next token
            logits, _ = model(
                smiles_string,
                order=current_order,
                optimize=True
            )

            # Apply temperature and sample
            logits = logits[0, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()

            # Add the token to the generated sequence at the next position in the order
            generated_sequence[0, order[0, next_pos]] = next_token

            # Check for EOS token
            if next_token == model.protGPT2_tokenizer.eos_token_id:
                break

            current_pos = next_pos

    # Decode the generated sequence
    generated_ids = generated_sequence[0].tolist()
    print('generated_ids',generated_ids)
    # Remove padding tokens
    generated_ids = [id for id in generated_ids if id != model.protGPT2_tokenizer.pad_token_id]
    seq = model.protGPT2_tokenizer.decode(generated_ids, skip_special_tokens=True)
    print('seq',seq)
    print("autoregressive gen done...")
    return seq


In [None]:
def generate_with_rejection_sampling(model, smiles_string, max_length=914, num_orders=5, temperature=1.0):
    """Generate protein using token-based rejection sampling with proper MH acceptance ratio"""
    device = next(model.parameters()).device

    # Get SMILES embeddings
    smiles_embeddings = model.polybert_encoder([smiles_string])
    processed_smiles = model.smiles_perceiver(smiles_embeddings)

    # Initialize with start token
    prompt = torch.tensor([[model.protGPT2_tokenizer.bos_token_id]], device=device)

    # Initialize full sequence with padding
    full_seq = torch.full((1, max_length), model.protGPT2_tokenizer.pad_token_id, device=device)
    full_seq[:, 0] = model.protGPT2_tokenizer.bos_token_id  # Start token

    # Track positions that have been filled
    filled_positions = {0}  # Start with position 0 filled

    while len(filled_positions) < max_length:
        remaining_positions = [i for i in range(max_length) if i not in filled_positions]
        if not remaining_positions:
            break

        # Step 1: Sample tokens at all remaining positions from marginal distribution
        # This is our proposal distribution p(x̃)
        candidate_tokens = {}
        proposal_probs = {}  # Store the probability of each proposal

        for pos in remaining_positions:
            # Create current filled sequence context
            current_context = torch.ones((1, max_length), device=device) * model.protGPT2_tokenizer.pad_token_id
            for filled_pos in filled_positions:
                current_context[0, filled_pos] = full_seq[0, filled_pos]

            # Get logits for this position given current context
            with torch.no_grad():
                # Order that puts this position last
                context_order = torch.tensor([list(filled_positions) + [pos]], device=device)

                logits = get_logits_for_position(model, current_context, context_order, smiles_string, pos)

                # Sample a token and record its probability
                logits = logits / temperature
                probs = F.softmax(logits, dim=-1)
                token_dist = torch.distributions.Categorical(probs)
                token = token_dist.sample().item()

                candidate_tokens[pos] = token
                proposal_probs[pos] = probs[0, token].item()

        # Step 2: Evaluate acceptance under different orders
        best_order_acceptances = []

        for _ in range(num_orders):
            # Create a random permutation of remaining positions
            eval_order = random.sample(remaining_positions, len(remaining_positions))

            accepted_tokens = []
            accepted_positions = []
            acceptance_ratios = []

            # Try to accept tokens in this order
            for pos in eval_order:
                # Create sequence with previously accepted tokens
                temp_seq = full_seq.clone()
                for acc_pos in accepted_positions:
                    temp_seq[0, acc_pos] = candidate_tokens[acc_pos]

                # Get conditional probability q(x̃|X,x̃σ<i)
                filled_plus_accepted = list(filled_positions) + accepted_positions
                context_order = torch.tensor([filled_plus_accepted + [pos]], device=device)

                with torch.no_grad():
                    cond_logits = get_logits_for_position(
                        model, temp_seq, context_order, processed_smiles, pos
                    )

                    cond_probs = F.softmax(cond_logits / temperature, dim=-1)
                    cond_prob = cond_probs[0, candidate_tokens[pos]].item()

                # Compute acceptance ratio r = q(x̃i|X,x̃σ<i) / p(x̃i|X)
                # Where p(x̃i|X) is the proposal probability
                acceptance_ratio = min(1.0, cond_prob / proposal_probs[pos])

                # Decide whether to accept
                if random.random() < acceptance_ratio:
                    accepted_tokens.append(candidate_tokens[pos])
                    accepted_positions.append(pos)
                    acceptance_ratios.append(acceptance_ratio)
                else:
                    # Stop at first rejection
                    break

            best_order_acceptances.append((accepted_positions, accepted_tokens, acceptance_ratios))

        # Step 3: Dynamic token acceptance
        best_order_idx = -1
        max_accepted = -1
        min_sequence_idx = -1

        for idx, (accepted_positions, _, acceptance_ratios) in enumerate(best_order_acceptances):
            if len(accepted_positions) > max_accepted:
                max_accepted = len(accepted_positions)
                best_order_idx = idx
                # Find the minimum position in the sequence where we see a rejection
                if len(accepted_positions) < len(remaining_positions):
                    min_sequence_idx = len(accepted_positions)
                else:
                    min_sequence_idx = len(remaining_positions)

        # No need to calculate min across orders if all orders accept all tokens
        if min_sequence_idx == -1:
            min_sequence_idx = len(remaining_positions)

        # Get the best order
        best_order = best_order_acceptances[best_order_idx]
        accepted_positions, accepted_tokens, _ = best_order

        # Limit acceptance to positions before the minimum rejection
        accepted_positions = accepted_positions[:min_sequence_idx]
        accepted_tokens = accepted_tokens[:min_sequence_idx]

        # Update the sequence with accepted tokens
        for pos, token in zip(accepted_positions, accepted_tokens):
            full_seq[0, pos] = token
            filled_positions.add(pos)

            # Check for EOS token
            if token == model.protGPT2_tokenizer.eos_token_id:
                break

    # Decode the generated sequence
    result = model.protGPT2_tokenizer.decode(
        [t for t in full_seq[0].tolist() if t != model.protGPT2_tokenizer.pad_token_id],
        skip_special_tokens=True
    )
    return result

def get_logits_for_position(model, sequence, order, smiles_string, target_position):
    """Helper function to get logits for a specific position"""
    # Run model forward pass
    logits, _ = model(
        smiles_string,  # Pass the SMILES string
        order=order,
        optimize=True
    )

    # Return logits for target position (last position in the order)
    return logits[:, -1, :]

In [None]:
import time
import numpy as np

In [None]:
def evaluate_on_unique_smiles(model, test_loader, device, output_file="generated_proteins_comparison.json"):
    """Generate proteins using both methods on unique SMILES from test set"""
    model.eval()

    # Collect unique SMILES from the test loader
    unique_smiles = set()
    for batch in test_loader:
        unique_smiles.update(batch['smiles'])

    unique_smiles = list(unique_smiles)  # Convert to list
    print(f"Found {len(unique_smiles)} unique SMILES in test set")

    results = []

    # Generate proteins using both methods and time each generation
    for i, smiles in enumerate(tqdm(unique_smiles, desc="Generating proteins")):
        # Track time for autoregressive generation
        start_time = time.time()
        print('autoregressive generations...')
        print(f"Generating protein for SMILES: {smiles}")
        ar_protein = generate_autoregressively(model, smiles, max_length=914, temperature=1.0, random_order=False)
        print(ar_protein)
        ar_time = time.time() - start_time

        # Track time for rejection sampling
        start_time = time.time()
        print('rejection sampling generations...')
        print(f"Generating protein for SMILES: {smiles}")
        rs_protein = generate_with_rejection_sampling(model, smiles, max_length=914, num_orders=5, temperature=1.0)
        print(rs_protein)
        rs_time = time.time() - start_time

        results.append({
            'SMILES': smiles,
            'Autoregressive': {
                'protein': ar_protein,
                'time_seconds': ar_time
            },
            'Rejection_Sampling': {
                'protein': rs_protein,
                'time_seconds': rs_time
            }
        })

        # Print progress occasionally
        if (i + 1) % 5 == 0:
            print(f"\nCompleted {i+1}/{len(unique_smiles)}")
            print(f"Example - SMILES: {smiles}")
            print(f"Autoregressive: {ar_protein[:50]}... ({ar_time:.2f}s)")
            print(f"Rejection Sampling: {rs_protein[:50]}... ({rs_time:.2f}s)")

    # Save results to JSON
    with open(output_file, 'w') as f:
        json.dump(results, f, indent=2)

    # Calculate and print average times
    ar_times = [r['Autoregressive']['time_seconds'] for r in results]
    rs_times = [r['Rejection_Sampling']['time_seconds'] for r in results]

    print(f"\nGeneration complete!")
    print(f"Average autoregressive generation time: {np.mean(ar_times):.2f}s")
    print(f"Average rejection sampling generation time: {np.mean(rs_times):.2f}s")
    print(f"Speed improvement: {np.mean(ar_times)/np.mean(rs_times):.2f}x")

    return results

In [None]:
results = evaluate_on_unique_smiles(model, test_loader, device, output_file="sigma_gpt_comparison_results.json")