### setup

In [1]:
!pip install Levenshtein
!pip install einops
!pip install einops_exts
!pip install torch
!pip install transformers
!pip install tqdm
!pip install sentencepiece
# !pip install fair-esm



In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import pandas as pd
import re
import math
import json
from tqdm import tqdm
from einops import rearrange, repeat
# import esm

# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# # Load ESM-2 model
# esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# batch_converter = alphabet.get_batch_converter()
# esm_model = esm_model.to(device)  # Move to GPU if available
# esm_model.eval()  # Set to evaluation mode


Using device: cuda


### data

In [3]:
# Mount Google Drive at /content/drive; the checkpoint and loss-log paths used
# by the training and generation cells live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
def preprocess_snp_data(file_path):
    """Load a SMILES/protein CSV and add per-row length columns.

    Args:
        file_path: Anything accepted by ``pandas.read_csv`` that points to a
            CSV with ``smiles`` and ``protein_sequence`` columns.

    Returns:
        The loaded DataFrame with two extra integer columns,
        ``smiles_length`` and ``protein_length``. Cells that are not strings
        (for example missing values, which pandas reads as NaN) get length 0
        instead of raising ``TypeError``, so ``filter_datasets`` can drop
        those rows afterwards.
    """
    snp_df = pd.read_csv(file_path)

    def _text_length(value):
        # len() is undefined for NaN/float cells; treat them as empty text.
        return len(value) if isinstance(value, str) else 0

    snp_df['smiles_length'] = snp_df['smiles'].apply(_text_length)
    snp_df['protein_length'] = snp_df['protein_sequence'].apply(_text_length)

    return snp_df

def filter_datasets(dataset):
    """Return only the rows that carry both a SMILES string and a protein.

    A row is kept when ``smiles`` and ``protein_sequence`` are non-null and
    the matching ``smiles_length`` / ``protein_length`` columns are positive.
    """
    has_smiles = dataset['smiles'].notna() & (dataset['smiles_length'] > 0)
    has_protein = dataset['protein_sequence'].notna() & (dataset['protein_length'] > 0)
    return dataset[has_smiles & has_protein]

class ProteinGenerationDataset(Dataset):
    """Dataset over a DataFrame that yields ``(smiles, protein_sequence)`` pairs.

    ``max_length`` is stored on the instance; this class itself does not
    truncate or pad anything.
    """

    def __init__(self, dataframe, max_length):
        self.dataframe, self.max_length = dataframe, max_length

    def __len__(self):
        # One sample per DataFrame row
        return len(self.dataframe)

    def __getitem__(self, idx):
        # idx is positional (iloc), not an index label
        record = self.dataframe.iloc[idx]
        return record['smiles'], record['protein_sequence']

def collate_fn(batch, max_length=None):
    """Collate ``(smiles, protein)`` pairs into a padded batch.

    Args:
        batch: List of ``(smiles, protein)`` string tuples.
        max_length: Optional cap on the padded protein length. When omitted,
            a module-level ``max_length`` variable is used if one exists
            (this keeps ``DataLoader(collate_fn=collate_fn)`` working as
            before); if neither is available the batch is not capped.

    Returns:
        dict with
          - ``'smiles'``: list of the SMILES strings, untouched,
          - ``'proteins'``: list of protein strings truncated / right-padded
            with spaces to a common length,
          - ``'protein_masks'``: bool tensor, True on real residues and False
            on padding.
    """
    smiles, proteins = zip(*batch)

    # SMILES strings are passed through as-is; they are tokenized later
    smiles = list(smiles)

    if max_length is None:
        # Backward-compatible fallback: earlier versions read the notebook
        # global directly and failed with NameError when it was not defined.
        max_length = globals().get('max_length')

    # Longest protein in this batch, optionally capped
    max_protein_len = max(len(p) for p in proteins)
    if max_length is not None:
        # Plain int so the str/list repetition below is ordinary Python
        # repetition even if the cap arrives as a NumPy integer.
        max_protein_len = min(max_protein_len, int(max_length))

    padded_proteins = []
    protein_masks = []

    for protein in proteins:
        kept = protein[:max_protein_len]
        pad = max_protein_len - len(kept)
        padded_proteins.append(kept + ' ' * pad)
        protein_masks.append([1] * len(kept) + [0] * pad)

    return {
        'smiles': smiles,
        'proteins': padded_proteins,
        'protein_masks': torch.tensor(protein_masks, dtype=torch.bool)
    }

### utils

In [5]:
# Model Components
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding.

    The table is stored in the buffer ``pe`` with shape
    ``(max_len, 1, d_model)``. ``forward`` slices it by ``x.size(0)``, so the
    first dimension of ``x`` is treated as the position axis and the encoding
    is broadcast over the second.

    Args:
        d_model: Embedding width. Odd values are supported (the last sine
            column simply has no cosine partner).
        max_len: Number of positions in the table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term

        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so drop the surplus cosine column instead of failing on the shapes.
        pe[:, 0, 1::2] = torch.cos(angles)[:, :d_model // 2]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding of positions ``0 .. x.size(0) - 1`` to ``x``."""
        return x + self.pe[:x.size(0)]

class DoublePositionalEncoding(nn.Module):
    """Sinusoidal encoding of two positions per token (input and output).

    The first ``d_model // 2`` channels receive the encoding of the token's
    input position, the remaining channels the encoding of its output
    position. Both lookup tables (buffers ``pe_input`` and ``pe_output``)
    hold the same sinusoidal values.

    Args:
        d_model: Embedding width; ``forward`` assumes it is even so the two
            halves have equal size.
        max_len: Number of positions in each table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.d_model = d_model
        half_dim = d_model // 2

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, half_dim, 2) * (-math.log(10000.0) / half_dim))
        angles = position * div_term

        table = torch.zeros(max_len, half_dim)
        table[:, 0::2] = torch.sin(angles)
        # Drop the surplus cosine column when half_dim is odd
        table[:, 1::2] = torch.cos(angles)[:, :half_dim // 2]

        # Two separate buffers are kept so existing checkpoints still load
        self.register_buffer('pe_input', table)
        self.register_buffer('pe_output', table.clone())

    def _resolve_positions(self, positions, batch_size, seq_length):
        """Return a ``[batch, seq]`` long index tensor; ``None`` means 0..seq-1."""
        device = self.pe_input.device
        if positions is None:
            return torch.arange(seq_length, device=device).unsqueeze(0).expand(batch_size, seq_length)
        positions = torch.as_tensor(positions, device=device).long()
        return positions[:batch_size, :seq_length]

    def forward(self, x, input_positions=None, output_positions=None):
        """Add the double positional encoding to ``x``.

        Args:
            x: Tensor of shape ``[batch, seq, d_model]``.
            input_positions: Optional ``[batch, seq]`` integer positions for
                the first half of the channels; defaults to ``0..seq-1``.
            output_positions: Optional ``[batch, seq]`` integer positions for
                the second half; defaults to ``0..seq-1``.

        Returns:
            ``x`` plus the encoding. A token whose input or output position
            lies beyond the table receives no encoding at all (zeros).
        """
        batch_size, seq_length, _ = x.shape
        half = self.d_model // 2

        input_idx = self._resolve_positions(input_positions, batch_size, seq_length)
        output_idx = self._resolve_positions(output_positions, batch_size, seq_length)

        # Tokens with any out-of-range position are left un-encoded
        in_range = (input_idx < self.pe_input.size(0)) & (output_idx < self.pe_output.size(0))
        # Clamp only so the lookup itself is valid; masked out right after
        input_idx = input_idx.clamp(max=self.pe_input.size(0) - 1)
        output_idx = output_idx.clamp(max=self.pe_output.size(0) - 1)

        pos_encoding = torch.zeros_like(x)
        pos_encoding[..., :half] = self.pe_input[input_idx]
        pos_encoding[..., half:] = self.pe_output[output_idx]
        pos_encoding = pos_encoding * in_range.unsqueeze(-1)

        return x + pos_encoding

class PerceiverAttention(nn.Module):
    """Multi-head attention in which latents query both the media and themselves.

    Queries are computed from ``latents``; keys and values are computed from
    the concatenation of ``x`` and ``latents`` along the sequence axis. Both
    inputs are layer-normalised first.
    """

    def __init__(self, dim, dim_head=64, heads=8):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        x: [batch_size, seq_len_x, dim]
        latents: [batch_size, seq_len_l, dim]
        returns: [batch_size, seq_len_l, dim]
        """
        def split_heads(t):
            # [b, n, h*d] -> [b, h, n, d]
            return rearrange(t, 'b n (h d) -> b h n d', h=self.heads)

        media = self.norm_media(x)
        normed_latents = self.norm_latents(latents)

        # Broadcast the latents over the batch when their batch size differs
        if normed_latents.size(0) != media.shape[0]:
            normed_latents = normed_latents.expand(media.shape[0], -1, -1)

        q = split_heads(self.to_q(normed_latents)) * self.scale

        # Keys/values cover the media tokens followed by the latents
        keys, values = self.to_kv(torch.cat((media, normed_latents), dim=1)).chunk(2, dim=-1)
        k, v = split_heads(keys), split_heads(values)

        attn = torch.einsum('b h i d, b h j d -> b h i j', q, k).softmax(dim=-1)
        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)

        return self.to_out(rearrange(out, 'b h n d -> b n (h d)'))

class GatedCrossAttentionBlock(nn.Module):
    """Cross-attention plus feed-forward, each behind a learned tanh gate.

    Both gates are initialised to 0, and tanh(0) = 0, so a freshly built
    block returns its input ``x`` unchanged.
    """

    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.attn = PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads)
        self.attn_gate = nn.Parameter(torch.tensor([0.]))
        self.ff = FeedForward(dim, mult=ff_mult)
        self.ff_gate = nn.Parameter(torch.tensor([0.]))

    def forward(self, x, media):
        """
        x: [batch_size, seq_len_x, dim]
        media: [batch_size, seq_len_m, dim]
        returns: tensor shaped like x (after any batch expansion)
        """
        n_x, n_media = x.shape[0], media.size(0)

        # Expand whichever input has the smaller batch to match the other
        if n_x > n_media:
            media = media.expand(n_x, -1, -1)
        elif n_media > n_x:
            x = x.expand(n_media, -1, -1)

        # x supplies the queries; media (together with x) supplies keys/values
        attended = self.attn(media, x)
        x = x + attended * self.attn_gate.tanh()
        x = x + self.ff(x) * self.ff_gate.tanh()
        return x

class PerceiverResampler(nn.Module):
    """Compress a variable-length sequence into ``num_latents`` learned latents.

    Each of the ``depth`` layers applies ``PerceiverAttention`` (latents
    attending to the input) and a ``FeedForward``, both with residual
    connections.
    """

    def __init__(self, dim, depth, dim_head=64, heads=8, num_latents=64):
        super().__init__()
        # Learned latents are stored without a batch dimension
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        self.layers = nn.ModuleList([
            nn.ModuleList([
                PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                FeedForward(dim=dim),
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        """Return latents of shape ``[x.shape[0], num_latents, dim]``."""
        # One copy of the latents per batch element
        latents = repeat(self.latents, 'n d -> b n d', b=x.shape[0])

        for cross_attend, feed_forward in self.layers:
            latents = latents + cross_attend(x, latents)
            latents = latents + feed_forward(latents)

        return latents

class FeedForward(nn.Module):
    """Pre-norm MLP: LayerNorm -> Linear -> GELU -> Linear, all bias-free linears.

    The hidden width is ``dim * mult``; input and output width are ``dim``.
    """

    def __init__(self, dim, mult=4):
        super().__init__()
        hidden_dim = dim * mult
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim, bias=False),
            nn.GELU(),
            nn.Linear(hidden_dim, dim, bias=False),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

# class PerceiverResampler(nn.Module):
#     def __init__(self, dim, depth, dim_head=64, heads=8, num_latents=64):
#         super().__init__()
#         self.latents = nn.Parameter(torch.randn(num_latents, dim))
#         self.layers = nn.ModuleList([])

#         for _ in range(depth):
#             self.layers.append(nn.ModuleList([
#                 PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
#                 FeedForward(dim=dim)
#             ]))

#     def forward(self, x):
#         latents = repeat(self.latents, 'n d -> b n d', b=x.shape[0])

#         for attn, ff in self.layers:
#             latents = attn(x, latents) + latents
#             latents = ff(latents) + latents

#         return latents

# class GatedCrossAttentionBlock(nn.Module):
#     def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
#         super().__init__()
#         self.attn = PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads)
#         self.attn_gate = nn.Parameter(torch.tensor([0.]))
#         self.ff = FeedForward(dim, mult=ff_mult)
#         self.ff_gate = nn.Parameter(torch.tensor([0.]))

#     def forward(self, x, media):
#         gate = self.attn_gate.tanh()
#         x = self.attn(media, x) * gate + x
#         x = self.ff(x) * self.ff_gate.tanh() + x
#         return x

### PolyBert Encoder

In [6]:
from transformers import AutoTokenizer, AutoModel
import torch


In [7]:
# class PolyBERTEncoder(nn.Module):
#     def __init__(self, output_dim):
#         super().__init__()
#         self.polybert = AutoModel.from_pretrained('kuelumbus/polyBERT')
#         self.tokenizer = AutoTokenizer.from_pretrained('kuelumbus/polyBERT')
#         self.output_dim = output_dim
#         # Add a projection layer to match the required dimension
#         self.projection = nn.Linear(self.polybert.config.hidden_size, output_dim)

#     def mean_pooling(self, model_output, attention_mask):
#         token_embeddings = model_output[0]
#         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
#         return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

#     def forward(self, smiles_strings):
#         # Tokenize the SMILES strings
#         encoded_input = self.tokenizer(smiles_strings,
#                                      padding=True,
#                                      truncation=True,
#                                      return_tensors='pt').to(next(self.polybert.parameters()).device)

#         # Get PolyBERT embeddings
#         with torch.no_grad():
#             model_output = self.polybert(**encoded_input)

#         # Debug prints
#         print("Model Output Keys:", model_output.keys())  # Check available keys
#         # print("Last Hidden State:", model_output.last_hidden_state)
#         print("Last Hidden State Shape:", model_output.last_hidden_state.shape)

#         # Pool the embeddings
#         pooled_output = self.mean_pooling(model_output, encoded_input['attention_mask'])

#         # print("Pooled Output:", pooled_output)
#         print("Pooled Output Shape:", pooled_output.shape)

#         # Project to required dimension
#         projected_output = self.projection(pooled_output)

#         return projected_output

In [8]:
class PolyBERTEncoder(nn.Module):
    """Per-token SMILES embeddings from polyBERT, projected to ``output_dim``.

    The polyBERT backbone is evaluated under ``torch.no_grad()``, so only the
    linear projection receives gradients through this module.
    """

    def __init__(self, output_dim):
        super().__init__()
        checkpoint = 'kuelumbus/polyBERT'
        self.polybert = AutoModel.from_pretrained(checkpoint)
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.output_dim = output_dim
        # Maps every token embedding from the backbone width to output_dim
        self.projection = nn.Linear(self.polybert.config.hidden_size, output_dim)

    def forward(self, smiles_strings):
        """Return a ``[batch_size, seq_len, output_dim]`` tensor of token embeddings."""
        # Tokenize on the fly and move the batch to wherever the backbone lives
        backbone_device = next(self.polybert.parameters()).device
        encoded_input = self.tokenizer(
            smiles_strings,
            padding=True,
            truncation=True,
            return_tensors='pt',
        ).to(backbone_device)

        # Backbone forward pass without gradient tracking
        with torch.no_grad():
            token_states = self.polybert(**encoded_input).last_hidden_state

        # Project each token embedding: [batch_size, seq_len, output_dim]
        return self.projection(token_states)

### ProtFlamingo

In [9]:
import torch.nn.functional as F


In [10]:
class SigmaProtFlamingo(nn.Module):
    """ProtGPT2 decoder conditioned on SMILES through gated cross-attention.

    SMILES strings are embedded by ``PolyBERTEncoder``, compressed to a fixed
    number of latents by ``PerceiverResampler`` and injected into the GPT-2
    stack by ``GatedCrossAttentionBlock`` layers. Token embeddings receive a
    ``DoublePositionalEncoding`` derived from a per-sample ``order``.
    """

    # Sub-modules whose state dicts are additionally stored as nested entries
    # by ``state_dict`` (kept so existing checkpoints stay loadable).
    _NESTED_CHECKPOINT_KEYS = ('smiles_perceiver', 'cross_attn', 'polybert_encoder')

    def __init__(self, model_path, max_len, cross_attn_every=3, dim_head=64, heads=8, perceiver_depth=2, perceiver_num_latents=64):
        """
        Args:
            model_path: Passed to ``GPT2LMHeadModel.from_pretrained`` and
                ``GPT2Tokenizer.from_pretrained``.
            max_len: Padded decoder length and size of the positional tables.
            cross_attn_every: A cross-attention block is inserted after GPT
                block ``i`` whenever ``i % cross_attn_every == 0`` and ``i != 0``.
            dim_head, heads: Attention geometry of the perceiver and
                cross-attention blocks.
            perceiver_depth: Number of perceiver layers.
            perceiver_num_latents: Number of latents the SMILES are compressed to.
        """
        super().__init__()

        self.protGPT2_model = GPT2LMHeadModel.from_pretrained(model_path)
        self.protGPT2_tokenizer = GPT2Tokenizer.from_pretrained(model_path)
        self.max_len = max_len

        # GPT-2 tokenizers may ship without a pad token; reuse EOS for padding
        if self.protGPT2_tokenizer.pad_token is None:
            self.protGPT2_tokenizer.pad_token = self.protGPT2_tokenizer.eos_token
            self.protGPT2_model.config.pad_token_id = self.protGPT2_model.config.eos_token_id

        self.cross_attn_every = cross_attn_every

        # PolyBERT encoder for SMILES strings
        self.polybert_encoder = PolyBERTEncoder(self.protGPT2_model.config.n_embd)

        # Order-aware (input position, output position) encoding
        self.positional_encoding = DoublePositionalEncoding(self.protGPT2_model.config.n_embd, max_len=max_len)

        # Single perceiver resampler for SMILES embeddings
        self.smiles_perceiver = PerceiverResampler(
            dim=self.protGPT2_model.config.n_embd,
            depth=perceiver_depth,
            dim_head=dim_head,
            heads=heads,
            num_latents=perceiver_num_latents
        )

        # NOTE(review): these blocks are registered (and therefore saved in
        # checkpoints) but neither forward() nor custom_generate() calls them;
        # the blocks actually used are the ones interleaved in self.layers.
        # They are kept so existing checkpoints still load.
        num_gpt_layers = len(self.protGPT2_model.transformer.h)
        self.cross_attn = nn.ModuleList([
            GatedCrossAttentionBlock(dim=self.protGPT2_model.config.n_embd, dim_head=dim_head, heads=heads)
            for _ in range(num_gpt_layers)
        ])

        # GPT blocks (the same module objects as transformer.h) interleaved
        # with gated cross-attention blocks
        self.layers = nn.ModuleList()
        for i, block in enumerate(self.protGPT2_model.transformer.h):
            self.layers.append(block)
            if i % cross_attn_every == 0 and i != 0:
                self.layers.append(GatedCrossAttentionBlock(dim=self.protGPT2_model.config.n_embd, dim_head=dim_head, heads=heads))

    def forward(self, smiles_strings, order=None, targets=None, optimize=False, kv_cache=None, burst=False):
        """Run the conditioned decoder over a ``max_len``-long padded input.

        Args:
            smiles_strings: One SMILES string or a list of them.
            order: Optional ``[batch, seq]`` long tensor of positions; defaults
                to left-to-right. Shorter orders are completed with the
                remaining indices, longer ones are truncated.
            targets: Optional ``[batch, seq]`` token ids in original order;
                they are gathered by ``order`` before the loss is computed.
            optimize: With ``targets=None``, return only the last position's logits.
            kv_cache: Accepted for interface compatibility; not used here.
            burst: When True (or when ``optimize`` is True) the input tokens
                are not reordered.

        Returns:
            ``(logits, loss)``; ``loss`` is ``None`` when ``targets`` is ``None``.

        NOTE(review): the decoder input is always the single "<|endoftext|>"
        token padded to ``max_len``; ``targets`` only enter the loss and are
        never fed back as input. Confirm that this is intended.
        """
        device = next(self.parameters()).device

        # Get SMILES embeddings through PolyBERT
        smiles_embeddings = self.polybert_encoder(smiles_strings)
        processed_smiles = self.smiles_perceiver(smiles_embeddings)

        # Initialize with start token
        gpt_input = self.protGPT2_tokenizer.encode_plus(
            "<|endoftext|>",
            return_tensors="pt",
            padding='max_length',
            max_length=self.max_len,
            truncation=True
        ).to(device)

        batch_size = 1 if isinstance(smiles_strings, str) else len(smiles_strings)

        # encode_plus yields a single row; replicate it so every batch item
        # has its own row (previously any batch_size > 1 failed on reshape).
        input_ids = gpt_input.input_ids.long().repeat(batch_size, 1)
        padding_mask = gpt_input.attention_mask.repeat(batch_size, 1)
        seq_length = input_ids.size(1)

        # If no order is provided, use left-to-right
        if order is None:
            order = torch.arange(seq_length, device=device).unsqueeze(0).repeat(batch_size, 1)
        order = order.to(device=device, dtype=torch.long)

        # Make sure order is the right length
        if order.size(1) > seq_length:
            order = order[:, :seq_length]
        elif order.size(1) < seq_length:
            padding = torch.arange(order.size(1), seq_length, device=device).unsqueeze(0).repeat(batch_size, 1)
            order = torch.cat([order, padding], dim=1)

        # Outside of optimize/burst mode, present the tokens in `order`
        if not optimize and not burst:
            input_ids = torch.gather(input_ids, 1, order)
        hidden_states = self.protGPT2_model.transformer.wte(input_ids)

        # Input position: the position at this step of the order.
        # Output position: the position at the next step; torch.roll already
        # wraps the last step around to order[:, 0].
        input_positions = order
        output_positions = torch.roll(order, -1, dims=1)

        # Apply double positional encoding
        hidden_states = self.positional_encoding(hidden_states, input_positions, output_positions)

        # Visibility rule: query i may attend to key j when key j is not
        # padding and order[j] <= order[i].
        key_visible = padding_mask[:, None, None, :].to(torch.bool)                   # [B, 1, 1, T]
        order_causal = (order.unsqueeze(1) <= order.unsqueeze(2)).unsqueeze(1)        # [B, 1, T, T]
        allowed = key_visible & order_causal

        # GPT-2 blocks add the mask to the attention scores, so it has to be
        # additive (0 = keep, large negative = block), not a 0/1 multiplier.
        combined_mask = torch.zeros(allowed.shape, dtype=hidden_states.dtype, device=device)
        combined_mask = combined_mask.masked_fill(~allowed, torch.finfo(hidden_states.dtype).min)

        for layer in self.layers:
            if isinstance(layer, GatedCrossAttentionBlock):
                hidden_states = layer(hidden_states, processed_smiles)
            else:
                hidden_states = layer(hidden_states, attention_mask=combined_mask)[0]

        # NOTE(review): transformer.ln_f is not applied before lm_head here;
        # confirm whether skipping the final layer norm is intended.
        logits = self.protGPT2_model.lm_head(hidden_states)

        if targets is None:
            if optimize:
                # Only the last position's logits are returned
                return logits[:, [-1], :], None
            return logits, None

        # Targets arrive in original order; gather them into `order`
        shuffled_targets = torch.gather(targets, 1, order)

        # NOTE(review): ignore_index=-1 never matches a tokenizer id, so padded
        # target positions contribute to the loss; confirm this is intended.
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            shuffled_targets.reshape(-1),
            ignore_index=-1
        )

        return logits, loss

    def custom_generate(self, smiles_string, max_length=200):
        """Greedy left-to-right decoding conditioned on ``smiles_string``.

        Stops after ``max_length`` new tokens or as soon as the EOS token is
        produced, and returns the decoded text without special tokens.
        """
        device = next(self.parameters()).device

        # Generation needs no gradients; this avoids building a graph per step
        with torch.no_grad():
            # Get SMILES embeddings
            smiles_embeddings = self.polybert_encoder(smiles_string)
            processed_smiles = self.smiles_perceiver(smiles_embeddings)

            # Initialize with start token
            input_ids = torch.tensor([[self.protGPT2_tokenizer.bos_token_id]]).to(device)

            # Autoregressive generation: the whole prefix is re-run each step
            for _ in range(max_length):
                inputs_embeds = self.protGPT2_model.transformer.wte(input_ids)
                # Explicit None positions -> left-to-right positions for both
                # halves of the encoding (the call previously omitted the
                # required arguments and raised TypeError).
                inputs_embeds = self.positional_encoding(inputs_embeds, None, None)

                hidden_states = inputs_embeds
                for layer in self.layers:
                    if isinstance(layer, GatedCrossAttentionBlock):
                        hidden_states = layer(hidden_states, processed_smiles)
                    else:
                        hidden_states = layer(hidden_states, attention_mask=None)[0]

                next_token_logits = self.protGPT2_model.lm_head(hidden_states[:, -1, :])
                next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)

                input_ids = torch.cat([input_ids, next_token], dim=-1)

                if next_token.item() == self.protGPT2_tokenizer.eos_token_id:
                    break

        return self.protGPT2_tokenizer.decode(input_ids[0], skip_special_tokens=True)

    def generate(self, smiles_string, max_length=50):
        """Convenience wrapper around ``custom_generate``."""
        return self.custom_generate(smiles_string, max_length)

    def state_dict(self, *args, **kwargs):
        """Standard state dict plus nested copies of three sub-module state dicts.

        The nested ``'smiles_perceiver'``, ``'cross_attn'`` and
        ``'polybert_encoder'`` entries duplicate weights already present under
        their dotted keys; they are kept to preserve the checkpoint format.
        Positional/keyword arguments are forwarded to ``nn.Module.state_dict``.
        """
        state = super().state_dict(*args, **kwargs)
        for name in self._NESTED_CHECKPOINT_KEYS:
            state[name] = getattr(self, name).state_dict()
        return state

    def load_state_dict(self, state_dict, strict=True, **kwargs):
        """Load a checkpoint produced by ``state_dict``.

        The caller's mapping is not modified. Nested sub-module entries are
        loaded when present and skipped otherwise, so plain PyTorch state
        dicts load as well. Returns the result of ``nn.Module.load_state_dict``.
        """
        # Work on a shallow copy, carrying over PyTorch's version metadata
        metadata = getattr(state_dict, '_metadata', None)
        remaining = state_dict.copy()
        if metadata is not None:
            remaining._metadata = metadata

        nested = {name: remaining.pop(name, None) for name in self._NESTED_CHECKPOINT_KEYS}

        result = super().load_state_dict(remaining, strict=strict, **kwargs)

        for name, sub_state in nested.items():
            if sub_state is not None:
                getattr(self, name).load_state_dict(sub_state, strict=strict)

        return result

In [11]:
import random

### training

In [12]:
import torch
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
import json
import pandas as pd
import os
import matplotlib.pyplot as plt

In [13]:
def train_with_random_order(model, train_loader, val_loader, num_epochs, device,
                           curriculum_steps=0, l2_reg=1e-5, sample_smiles=None,
                           checkpoint_path="/content/drive/MyDrive/classes+projects/plastic_enzyme_project/2024/codes/sigma_checkpoint.pth",
                           loss_log_path="/content/drive/MyDrive/classes+projects/plastic_enzyme_project/2024/codes/sigma_loss_log.csv",
                           plot_path="sigma_loss_plot.png"):
    """Train ``model`` on randomly ordered (or left-to-right) sequences.

    Args:
        model: Model exposing ``protGPT2_tokenizer``, ``max_len`` and a
            forward call ``model(smiles, order=..., targets=...)`` that
            returns ``(outputs, loss)``.
        train_loader: Iterable of batches with ``'smiles'`` and ``'proteins'``.
        val_loader: Validation batches, passed to ``validate_with_random_order``.
        num_epochs: Number of passes over ``train_loader``.
        device: Device for the order/target tensors and for checkpoint loading.
        curriculum_steps: If > 0, step ``s`` uses a random order with
            probability ``min(1, s / curriculum_steps)`` and left-to-right
            otherwise; if 0, every step uses a random order.
        l2_reg: AdamW weight decay.
        sample_smiles: Optional SMILES strings to generate samples for after
            each epoch.
        checkpoint_path: File the model is resumed from (if it exists) and
            saved to after every epoch.
        loss_log_path: CSV file receiving the per-epoch losses.
        plot_path: Image file receiving the loss plot.
    """
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=l2_reg)
    # The loss itself is computed inside the model; this criterion is only
    # handed on to validate_with_random_order.
    criterion = nn.CrossEntropyLoss(ignore_index=model.protGPT2_tokenizer.pad_token_id)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    loss_log = []

    if os.path.exists(checkpoint_path):
        print("Loading checkpoint...")
        # map_location lets a checkpoint saved on GPU be resumed on CPU too
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    seq_length = int(model.max_len)
    step_counter = 0

    for epoch in range(num_epochs):
        model.train()
        batch_losses = []

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            step_counter += 1

            smiles_strings = batch['smiles']
            proteins = batch['proteins']
            batch_size = len(smiles_strings)

            # Curriculum: the share of randomly ordered steps grows linearly
            if curriculum_steps > 0:
                random_prob = min(1.0, step_counter / curriculum_steps)
                use_random = random.random() < random_prob
            else:
                use_random = True

            if use_random:
                # Independent random permutation per batch item
                order = torch.stack([torch.randperm(seq_length) for _ in range(batch_size)]).to(device)
            else:
                # Left-to-right order
                order = torch.arange(seq_length).unsqueeze(0).repeat(batch_size, 1).to(device)

            optimizer.zero_grad()

            # Encode the proteins
            target_encoding = model.protGPT2_tokenizer(
                proteins,
                return_tensors="pt",
                padding='max_length',
                max_length=model.max_len,
                truncation=True
            ).to(device)

            # Targets are passed in original order; the model reorders them
            outputs, loss = model(
                smiles_strings,
                order=order,
                targets=target_encoding.input_ids
            )

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            batch_losses.append(loss.item())

        # An empty loader yields NaN instead of a ZeroDivisionError
        avg_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')
        print(f"\nEpoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")
        print(f"Per-batch Losses: {batch_losses[:5]} ...")

        val_loss = validate_with_random_order(model, val_loader, criterion, device)
        print(f"Validation Loss: {val_loss:.4f}")

        loss_log.append({
            'epoch': epoch+1,
            'train_loss': avg_loss,
            'val_loss': val_loss
        })

        scheduler.step()
        torch.save(model.state_dict(), checkpoint_path)

        # Generate sample proteins after each epoch
        if sample_smiles:
            print("\nSample Generated Proteins:")
            for smiles in sample_smiles:
                # NOTE(review): generate_autoregressively is not defined in the
                # visible part of this notebook; confirm it exists before
                # passing sample_smiles.
                gen_protein_lr = generate_autoregressively(model, smiles, max_length=100, random_order=False)
                gen_protein_random = generate_autoregressively(model, smiles, max_length=100, random_order=True)
                print(f"SMILES: {smiles}")
                print(f"Generated Protein (L2R): {gen_protein_lr}")
                print(f"Generated Protein (Random): {gen_protein_random}\n")

    # Save loss log
    loss_df = pd.DataFrame(loss_log)
    loss_df.to_csv(loss_log_path, index=False)

    # Plot loss
    plt.figure(figsize=(8, 5))
    plt.plot(loss_df['epoch'], loss_df['train_loss'], label='Train Loss', marker='o')
    plt.plot(loss_df['epoch'], loss_df['val_loss'], label='Validation Loss', marker='s')
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.title("Training and Validation Loss")
    plt.savefig(plot_path)
    plt.show()

def validate_with_random_order(model, val_loader, criterion, device):
    """Average the model's loss over ``val_loader`` using random orders.

    Args:
        model: Model exposing ``eval()``, ``max_len``, ``protGPT2_tokenizer``
            and a call ``model(smiles, order=..., targets=...)`` returning
            ``(outputs, loss)``.
        val_loader: Iterable of batches with ``'smiles'`` and ``'proteins'``.
        criterion: Accepted for interface compatibility; the loss comes from
            the model itself and this argument is not used.
        device: Device the order and target tensors are moved to.

    Returns:
        Mean per-batch loss as a float, or NaN when ``val_loader`` yields no
        batches (instead of raising ``ZeroDivisionError``).
    """
    model.eval()
    seq_length = int(model.max_len)
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch in val_loader:
            smiles_strings = batch['smiles']
            proteins = batch['proteins']
            batch_size = len(smiles_strings)

            # Independent random permutation for each item in the batch
            order = torch.stack([torch.randperm(seq_length) for _ in range(batch_size)]).to(device)

            # Encode the proteins
            target_encoding = model.protGPT2_tokenizer(
                proteins,
                return_tensors="pt",
                padding='max_length',
                max_length=model.max_len,
                truncation=True
            ).to(device)

            # Forward pass with order
            _, loss = model(
                smiles_strings,
                order=order,
                targets=target_encoding.input_ids
            )

            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        return float('nan')
    return total_loss / num_batches

### inference + training

In [None]:
# Pick the compute device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Read each split and drop rows without a usable SMILES/protein pair
train_data = filter_datasets(preprocess_snp_data('/content/augmented_train.csv'))
val_data = filter_datasets(preprocess_snp_data('/content/augmented_val.csv'))
test_data = filter_datasets(preprocess_snp_data('/content/augmented_test.csv'))

# Longest protein over all three splits, capped at 1024
longest_protein = max(split['protein_length'].max() for split in (train_data, val_data, test_data))
max_length = min(longest_protein, 1024)
print(f"Max sequence length: {max_length}")

# One dataset per split
train_dataset = ProteinGenerationDataset(train_data, max_length)
val_dataset = ProteinGenerationDataset(val_data, max_length)
test_dataset = ProteinGenerationDataset(test_data, max_length)

# Loaders share everything except shuffling (training split only).
# batch_size=1: adjust based on available GPU memory.
loader_options = dict(batch_size=1, num_workers=2, collate_fn=collate_fn)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

# Build the sigma-GPT style model on the selected device
model = SigmaProtFlamingo(
    model_path='nferruz/ProtGPT2',
    max_len=max_length,
    cross_attn_every=3,
    dim_head=64,
    heads=8,
    perceiver_depth=2,
    perceiver_num_latents=64
).to(device)

num_epochs = 10

# Curriculum over the first half of training: train_with_random_order picks a
# random order with probability min(1, step / curriculum_steps), so training
# starts almost entirely left-to-right and is fully random from then on.
curriculum_steps = int(0.5 * num_epochs * len(train_loader))
print("Starting training with sigma-gpt capabilities...")
train_with_random_order(
    model,
    train_loader,
    val_loader,
    num_epochs,
    device,
    curriculum_steps=curriculum_steps
)

# # Generate and evaluate
# print("Generating proteins for test set...")
# test_results = generate_and_evaluate(model, test_loader, device)

# # Save results
# print("Saving results...")
# results_path = '/content/drive/MyDrive/classes+projects/plastic_enzyme_project/2024/codes/test_results.json'
# with open(results_path, 'w') as f:
#     json.dump(test_results, f, indent=2)

# print(f"Results saved to {results_path}")


Using device: cuda
Max sequence length: 914


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Starting training with sigma-gpt capabilities...


Epoch 1/10:   0%|          | 0/3111 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


### generation

In [None]:
import torch
import json
from tqdm import tqdm

# Load the trained model
# This cell only selects the device and builds the architecture; the checkpoint weights
# are restored in a later cell via load_state_dict.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Path to model checkpoint
# NOTE(review): absolute Google Drive path — requires the Drive mount from the data section.
checkpoint_path = "/content/drive/MyDrive/classes+projects/plastic_enzyme_project/2024/codes/latest_checkpoint.pth"

# Load model
# NOTE(review): the training cell above instantiates SigmaProtFlamingo, while this cell builds
# ProtFlamingo with the same keyword arguments. The generation helpers further down call
# model(..., order=..., optimize=True); confirm that ProtFlamingo accepts those arguments and
# that the checkpoint's state dict was saved from this class and not from SigmaProtFlamingo.
model = ProtFlamingo(
    model_path='nferruz/ProtGPT2',
    max_len=914,  # Ensure this matches the training max_len
    cross_attn_every=3,
    dim_head=64,
    heads=8,
    perceiver_depth=2,
    perceiver_num_latents=64
).to(device)



In [None]:
# Bare expression: shows the class object as this cell's output.
# NOTE(review): looks like a leftover check that the class is defined in the current kernel —
# confirm and consider removing the cell.
ProteinGenerationDataset

In [None]:
# Load trained weights
# NOTE(review): load_state_dict is handed whatever torch.load returns, so the file must contain
# a bare state dict; if the training loop saved a wrapper dict (e.g. with optimizer state),
# the relevant entry has to be selected first — confirm against the code that wrote the file.
# NOTE(review): depending on the installed torch version's `weights_only` default, torch.load
# may unpickle arbitrary objects; only load checkpoints from a trusted location.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
# Switch to inference mode (disables training-only layers such as dropout).
model.eval()

# Load test data
# preprocess_snp_data reads the CSV and adds the smiles_length / protein_length columns;
# filter_datasets then keeps only rows with non-null, non-empty SMILES and protein sequence.
test_data = preprocess_snp_data('/content/augmented_test.csv')
test_data = filter_datasets(test_data)

# Create test dataset and dataloader
# max_length is stored on the dataset; 914 mirrors the max_len used to build the model above.
test_dataset = ProteinGenerationDataset(test_data,max_length = 914 )
# One example per batch, in file order, so outputs can be matched back to rows.
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    collate_fn=collate_fn
)



In [None]:
# # Function to compare generated vs. ground truth proteins for unique SMILES
# def evaluate_generation(model, test_loader, device, output_file="generated_vs_ground_truth.json"):
#     model.eval()
#     results = []

#     with torch.no_grad():
#         # Collect unique SMILES from test_loader
#         unique_smiles = set()
#         smiles_to_target = {}  # Map unique SMILES to ground truth

#         for batch in test_loader:
#             for smiles, target_protein in zip(batch['smiles'], batch['proteins']):
#                 unique_smiles.add(smiles)
#                 smiles_to_target[smiles] = target_protein.strip()  # Store ground truth

#         unique_smiles = list(unique_smiles)  # Convert set to list
#         print(f"Total Unique SMILES: {len(unique_smiles)}")

#         # Generate proteins for unique SMILES
#         for smiles in tqdm(unique_smiles, desc="Generating Proteins for Unique SMILES"):
#             generated_proteins = [model.generate(smiles, max_length=914) for _ in range(1)]  # Generate 1 sequence

#             print(f"\nSMILES: {smiles}")
#             for i, gen_protein in enumerate(generated_proteins):
#                 print(f"Generated Protein {i+1}: {gen_protein}")
#             print(f"Ground Truth Protein: {smiles_to_target[smiles]}")

#             results.append({
#                 'SMILES': smiles,
#                 'Generated Proteins': generated_proteins,
#                 'Ground Truth Protein': smiles_to_target[smiles]
#             })

#     # Save results to a JSON file
#     with open(output_file, 'w') as f:
#         json.dump(results, f, indent=2)

#     print(f"\nResults saved to {output_file}")
#     return results

# # Run evaluation
# output_path = "/content/drive/MyDrive/classes+projects/plastic_enzyme_project/2024/codes/generated_vs_ground_truth.json"
# test_results = evaluate_generation(model, test_loader, device, output_file=output_path)

# # Print first few examples
# print("\nExample Comparisons:")
# for i, result in enumerate(test_results[:5]):  # Show first 5 examples
#     print(f"\nExample {i+1}:")
#     print(f"SMILES: {result['SMILES']}")
#     for j, gen_protein in enumerate(result['Generated Proteins']):
#         print(f"Generated Protein {j+1}: {gen_protein}")
#     print(f"Ground Truth Protein: {result['Ground Truth Protein']}")


In [None]:
def generate_autoregressively(model, smiles_string, max_length=100, temperature=1.0, random_order=False):
    """Sample a protein one token at a time, left-to-right or in a random position order.

    Position 0 of the output buffer always holds the BOS token. The remaining
    positions are visited either in increasing order or, when ``random_order``
    is true, in a random permutation of ``1 .. max_length - 1``. At every step
    the model is queried with the prefix of the order visited so far, a token
    is sampled from the temperature-scaled distribution taken from
    ``logits[0, -1, :]``, and it is written to the next position in the order.
    Sampling stops early as soon as the EOS token is drawn.

    Args:
        model: Object exposing ``parameters()``, a ``protGPT2_tokenizer``
            attribute (with ``bos_token_id``, ``eos_token_id``,
            ``pad_token_id`` and ``decode``) and callable as
            ``model(smiles_string, order=..., optimize=True)`` returning a
            ``(logits, _)`` pair.
        smiles_string: SMILES string forwarded unchanged to the model.
        max_length: Size of the output buffer, including the BOS slot.
        temperature: Positive divisor applied to the logits before softmax.
        random_order: Visit positions in a random permutation instead of
            left-to-right.

    Returns:
        Whatever ``protGPT2_tokenizer.decode(ids, skip_special_tokens=True)``
        returns for the non-padding ids of the buffer, read in position order.

    Raises:
        ValueError: If ``max_length`` is below 1 or ``temperature`` is not
            positive.
    """
    if max_length < 1:
        raise ValueError("max_length must be at least 1")
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    device = next(model.parameters()).device
    tokenizer = model.protGPT2_tokenizer

    if random_order:
        # BOS lives at position 0, so position 0 must be the first entry of the
        # order; only the remaining positions are shuffled. A plain
        # randperm(max_length) could schedule position 0 later (overwriting BOS)
        # and would leave whichever position came first permanently unfilled.
        shuffled_tail = torch.randperm(max_length - 1, device=device) + 1
        bos_slot = torch.zeros(1, dtype=shuffled_tail.dtype, device=device)
        order = torch.cat([bos_slot, shuffled_tail]).unsqueeze(0)
    else:
        order = torch.arange(max_length, device=device).unsqueeze(0)

    # Output buffer indexed by sequence position; unfilled slots stay as padding.
    generated_sequence = torch.full((1, max_length), tokenizer.pad_token_id, device=device)
    generated_sequence[0, 0] = tokenizer.bos_token_id

    with torch.no_grad():
        for step in range(1, max_length):
            # Positions visited so far, in visiting order.
            current_order = order[:, :step]

            # NOTE(review): only the SMILES string and the order prefix are passed;
            # the tokens already written to `generated_sequence` are not. Unless the
            # model tracks them internally, predictions are not conditioned on earlier
            # samples — confirm against the model's forward signature.
            logits, _ = model(
                smiles_string,
                order=current_order,
                optimize=True
            )

            # torch.softmax avoids depending on a module-level `F` alias.
            probs = torch.softmax(logits[0, -1, :] / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()

            target_pos = order[0, step].item()
            generated_sequence[0, target_pos] = next_token

            if next_token == tokenizer.eos_token_id:
                break

    # Drop unfilled (padding) slots, then let the tokenizer strip special tokens.
    generated_ids = [tok for tok in generated_sequence[0].tolist() if tok != tokenizer.pad_token_id]
    return tokenizer.decode(generated_ids, skip_special_tokens=True)

In [None]:
def generate_with_rejection_sampling(model, smiles_string, max_length=100, num_orders=5, temperature=1.0):
    """Fill a fixed-length protein buffer in bursts using token-level accept/reject rounds.

    Position 0 is seeded with the BOS token. Each round then:

    1. samples one candidate token for every unfilled position, querying the
       model with an order containing only that position;
    2. tries ``num_orders`` random visiting orders of the unfilled positions;
       within an order, each candidate is re-scored with an order made of the
       already filled positions, the positions accepted so far in this order,
       and the candidate position last; it is accepted with probability
       ``min(1, p * vocab_size)`` and the order is abandoned at the first
       rejection;
    3. commits the candidates of the order that accepted the most positions,
       or, if no order accepted anything, the candidate of the first unfilled
       position so that every round makes progress.

    Args:
        model: Object exposing ``parameters()``, a ``protGPT2_tokenizer``
            attribute (with ``bos_token_id``, ``eos_token_id``,
            ``pad_token_id`` and ``decode``) and callable as
            ``model(smiles_string, order=..., optimize=True)`` returning a
            ``(logits, _)`` pair.
        smiles_string: SMILES string forwarded unchanged to the model.
        max_length: Number of positions in the buffer, including the BOS slot.
        num_orders: Number of random visiting orders tried per round.
        temperature: Positive divisor applied to the logits before softmax.

    Returns:
        Whatever ``protGPT2_tokenizer.decode(ids, skip_special_tokens=True)``
        returns for the non-padding ids of the buffer, read in position order.

    Raises:
        ValueError: If ``max_length`` is below 1 or ``temperature`` is not
            positive.
    """
    # Function-scope import: no module-level `import random` is visible in the notebook.
    import random

    if max_length < 1:
        raise ValueError("max_length must be at least 1")
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    device = next(model.parameters()).device
    tokenizer = model.protGPT2_tokenizer

    # Output buffer indexed by sequence position; unfilled slots stay as padding.
    generated_sequence = torch.full((1, max_length), tokenizer.pad_token_id, device=device)
    generated_sequence[0, 0] = tokenizer.bos_token_id

    filled_positions = {0}  # BOS slot is filled from the start

    def _last_step_probs(order_positions):
        """Return the temperature-scaled distribution from logits[0, -1, :] for this order."""
        # NOTE(review): only the SMILES string and the position order are passed; neither
        # the committed tokens nor the tentative candidates reach the model. Unless the
        # model tracks them internally, scores are not conditioned on token content —
        # confirm against the model's forward signature.
        order = torch.tensor([order_positions], device=device)
        with torch.no_grad():
            logits, _ = model(
                smiles_string,
                order=order,
                optimize=True
            )
        # torch.softmax avoids depending on a module-level `F` alias.
        return torch.softmax(logits[0, -1, :] / temperature, dim=-1)

    while len(filled_positions) < max_length:
        remaining_positions = [i for i in range(max_length) if i not in filled_positions]

        # Step 1: one candidate token per unfilled position.
        candidate_tokens = {}
        for pos in remaining_positions:
            probs = _last_step_probs([pos])
            candidate_tokens[pos] = torch.multinomial(probs, num_samples=1).item()

        # Step 2: evaluate the candidates under several random visiting orders.
        committed = sorted(filled_positions)  # invariant for the whole round
        best_accepted = []

        for _ in range(num_orders):
            rand_order = random.sample(remaining_positions, len(remaining_positions))
            accepted_positions = []

            for pos in rand_order:
                context = sorted(committed + accepted_positions)
                probs = _last_step_probs(context + [pos])
                prob = probs[candidate_tokens[pos]].item()

                # Decision to accept - simplified for now
                if random.random() < min(1.0, prob * len(probs)):
                    accepted_positions.append(pos)
                else:
                    # Stop at first rejection
                    break

            if len(accepted_positions) > len(best_accepted):
                best_accepted = accepted_positions

        # Step 3: commit the longest accepted run.
        for pos in best_accepted:
            generated_sequence[0, pos] = candidate_tokens[pos]
            filled_positions.add(pos)

            # NOTE(review): an EOS only ends this commit loop — later accepted positions of
            # the round are dropped and re-sampled next round; generation itself continues
            # until every position is filled. Confirm whether EOS should stop generation.
            if candidate_tokens[pos] == tokenizer.eos_token_id:
                break

        # If nothing was accepted, just accept one token to avoid getting stuck
        if not best_accepted:
            pos = remaining_positions[0]
            generated_sequence[0, pos] = candidate_tokens[pos]
            filled_positions.add(pos)

    # The buffer is already indexed by position, so no reordering is needed:
    # drop padding slots and let the tokenizer strip special tokens.
    ordered_sequence = [tok for tok in generated_sequence[0].tolist() if tok != tokenizer.pad_token_id]
    return tokenizer.decode(ordered_sequence, skip_special_tokens=True)