In [20]:
!pip install biopython  # install Biopython for BLAST access and parsing
!pip install omegaconf
!pip install py3Dmol

Collecting biopython
  Downloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Downloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.3 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/3.3 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m3.3/3.3 MB[0m [31m190.7 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m91.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: biopython
Successfully installed biopython-1.85
Collecting omegaconf
  Downloading omegaconf-2.3.0-py3-none-any.whl.metadata (3.9 kB)
Collecting antlr4-python3-runtime==4.9.* (from omegaconf)
  Downloading antlr4-python3-runtime-4.9.3.tar.gz (117 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m117.0/117.0 kB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Prepar

Collecting py3Dmol
  Downloading py3Dmol-2.4.2-py2.py3-none-any.whl.metadata (1.9 kB)
Downloading py3Dmol-2.4.2-py2.py3-none-any.whl (7.0 kB)
Installing collected packages: py3Dmol
Successfully installed py3Dmol-2.4.2


In [45]:
import torch
from transformers import AutoTokenizer, EsmForProteinFolding

# ESMFold structure-prediction model and its tokenizer. The imports live in this cell so it
# does not rely on a later cell having been executed first (Restart & Run All safe).
esm_tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
model_fold = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)

if torch.cuda.is_available():
    model_fold = model_fold.cuda()
    # Cast only the `esm` submodule to fp16 to save GPU memory; the rest keeps its loaded dtype.
    model_fold.esm = model_fold.esm.half()
torch.backends.cuda.matmul.allow_tf32 = True

Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [48]:

"""
Credit to https://github.com/karpathy/nanoGPT
"""
import torch
import torch.nn as nn
from torch.nn import functional as F

from Bio import SeqIO
from Bio.Blast import NCBIWWW, NCBIXML

from transformers import AutoTokenizer, EsmForProteinFolding


import argparse
import os
import time
from contextlib import nullcontext
from functools import partial
import random
import math
import inspect
from dataclasses import dataclass


# -----------------------------------------------------------------------------
# Utility Functions and Global Dicts
# -----------------------------------------------------------------------------

def exists(val):
    """Return True unless `val` is None (falsy values such as 0 or "" still exist)."""
    return not (val is None)

def print0(*args, **kwargs):
    """
    Forward to the built-in print(), but only on the process whose RANK environment
    variable is 0. A missing RANK is treated as 0, so single-process runs always print.
    """
    rank = int(os.environ.get("RANK", 0))
    if rank != 0:
        return
    print(*args, **kwargs)


# -----------------------------------------------------------------------------
# Token Dictionaries
# - token_dict: maps tokens (e.g., A, T, etc.) to integer IDs
# - token_dict_inv: inverse mapping from IDs back to tokens
# -----------------------------------------------------------------------------
# IDs are assigned by position: 4 special tokens, then the 26 uppercase letters, then "1", "2".
token_dict = {
    token: idx
    for idx, token in enumerate(
        ["<pad>", "<bos>", "<eos>", "<unk>"]
        + list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
        + ["1", "2"]
    )
}
token_dict_inv = dict((idx, token) for token, idx in token_dict.items())


# -----------------------------------------------------------------------------
# ProteinTokenizer
# -----------------------------------------------------------------------------
class ProteinTokenizer:
    """
    Converts raw protein sequences (strings) into lists of token IDs, and vice versa.
    Every character is one token. `encode` can prepend the <bos> token, `decode` maps IDs
    back to text (special tokens are rendered literally, e.g. "<bos>"), and
    `pad_sequences` pads or truncates ID lists to a fixed length.
    """
    def __init__(self, token_dict):
        self.token_dict = token_dict
        self.inv_token_dict = {v: k for k, v in token_dict.items()}
        self.unk_token = "<unk>"
        self.pad_token = "<pad>"
        self.bos_token = "<bos>"
        self.eos_token = "<eos>"
        self.pad_id = token_dict[self.pad_token]
        self.bos_id = token_dict[self.bos_token]
        self.eos_id = token_dict[self.eos_token]
        self.stop_tokens = [token_dict[self.eos_token]]

    def tokenize(self, sequence):
        """
        Splits a sequence string into individual characters (tokens).
        e.g., 'ABC' -> ['A', 'B', 'C'].
        """
        return list(sequence)

    def convert_tokens_to_ids(self, tokens):
        """
        Converts a list of tokens (e.g., ['A','B','C']) into their corresponding IDs,
        using the provided token_dict. Unknown tokens default to the <unk> ID.
        """
        unk_id = self.token_dict[self.unk_token]
        return [self.token_dict.get(token, unk_id) for token in tokens]

    def encode(self, sequence, add_special_tokens=True):
        """
        Goes from raw sequence string -> tokens -> token IDs.
        If add_special_tokens is True, a <bos> ID is prepended (no <eos> is appended).
        """
        tokens = self.tokenize(sequence)
        if add_special_tokens:
            tokens = [self.bos_token] + tokens
        return self.convert_tokens_to_ids(tokens)

    def decode(self, token_ids):
        """
        Converts token IDs back to a string (the inverse of `encode`).
        Raises KeyError for an ID that is not in the vocabulary.
        """
        tokens = [self.inv_token_dict[token_id] for token_id in token_ids]
        return "".join(tokens)

    def pad_sequences(self, sequences, padding_value=None, block_size=None):
        """
        Pads a list of sequences (each a list of token IDs) to a fixed block_size.
        Longer sequences are truncated; extra space is filled with padding_value
        (default = <pad> ID). When block_size is None the longest sequence's length
        is used, and an empty `sequences` simply returns [].
        """
        if block_size is None:
            # default=0 avoids ValueError from max() on an empty input
            block_size = max((len(seq) for seq in sequences), default=0)
        if padding_value is None:
            padding_value = self.pad_id

        padded = []
        for seq in sequences:
            seq = list(seq)[:block_size]  # truncate if longer than block_size
            padding_needed = max(0, block_size - len(seq))
            seq += [padding_value] * padding_needed
            padded.append(seq)

        return padded


# -----------------------------------------------------------------------------
# Layer Normalization
# -----------------------------------------------------------------------------
class LayerNorm(nn.Module):
    """
    Layer normalization over the last dimension (eps = 1e-5) with a learnable scale
    and an optional learnable shift. With bias=False no bias parameter is registered.
    """
    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        eps = 1e-5
        return F.layer_norm(
            input,
            self.weight.shape,
            weight=self.weight,
            bias=self.bias,
            eps=eps,
        )


# -----------------------------------------------------------------------------
# Self Attention
# -----------------------------------------------------------------------------
class CausalSelfAttention(nn.Module):
    """
    Multi-head self-attention computed with F.scaled_dot_product_attention.

    The fused Q/K/V projection is either one full-rank Linear (config.lora_dim == 0) or a
    low-rank factorization Linear(n_embd -> lora_dim) -> LayerNorm -> Linear(lora_dim -> 3*n_embd).
    The future-token mask is applied only when config.is_causal is True; no padding mask is used.
    """
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0, \
            "Embedding dimension must be divisible by number of heads."

        low_rank = config.lora_dim != 0
        if low_rank:
            # Factorized projection: n_embd -> lora_dim -> 3 * n_embd
            self.c_attn_a = nn.Linear(config.n_embd, config.lora_dim, bias=False)
            self.c_attn_norm = LayerNorm(config.lora_dim, bias=False)
            self.c_attn_b = nn.Linear(config.lora_dim, config.n_embd * 3, bias=False)
        else:
            self.c_attn = nn.Linear(config.n_embd, config.n_embd * 3, bias=False)

        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        # Not called in forward: attention dropout is applied by SDPA through dropout_p.
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        self.n_head, self.n_embd = config.n_head, config.n_embd
        self.dropout = config.dropout
        self.lora_dim = config.lora_dim
        self.is_causal = config.is_causal

    def forward(self, x):
        """
        x: (batch_size, sequence_length, embedding_size).
        Returns a tensor of the same shape.
        """
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        if self.lora_dim != 0:
            qkv = self.c_attn_b(self.c_attn_norm(self.c_attn_a(x)))
        else:
            qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)

        def to_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        attn_out = F.scaled_dot_product_attention(
            to_heads(q), to_heads(k), to_heads(v),
            attn_mask=None,
            dropout_p=self.dropout if self.training else 0,
            is_causal=self.is_causal,
        )
        # (B, n_head, T, head_dim) -> (B, T, C)
        merged = attn_out.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.c_proj(merged))


# -----------------------------------------------------------------------------
# MLP
# -----------------------------------------------------------------------------
class MLP(nn.Module):
    """
    Position-wise feed-forward block applied after attention:
    Linear(n_embd -> mlp_hidden_dim) -> GELU -> Linear(mlp_hidden_dim -> n_embd) -> Dropout.
    Both linear layers are bias-free.
    """
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, config.mlp_hidden_dim, bias=False)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(config.mlp_hidden_dim, config.n_embd, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))


# -----------------------------------------------------------------------------
# Transformer Block
# -----------------------------------------------------------------------------
class Block(nn.Module):
    """
    Pre-norm Transformer block with two residual branches:
        x = x + attention(LN1(x))
        x = x + MLP(LN2(x))
    """
    def __init__(self, config):
        super().__init__()
        self.ln1 = LayerNorm(config.n_embd, bias=False)
        self.sa = CausalSelfAttention(config)
        self.ln2 = LayerNorm(config.n_embd, bias=False)
        self.mlp = MLP(config)

    def forward(self, x):
        attended = x + self.sa(self.ln1(x))
        return attended + self.mlp(self.ln2(attended))


# -----------------------------------------------------------------------------
# GPT Configuration
# -----------------------------------------------------------------------------
@dataclass
class GPT_Config:
    """
    Holds key hyperparameters for building the GPT model.
    These will be passed to the GPT constructor.

    NOTE(review): the defaults of `block_size` and `mlp_hidden_dim` are evaluated once, when
    this class body runs, from the class-level defaults of `max_seq_len` and `n_embd`. They are
    therefore fixed at 1024 and 2048 and do NOT follow overrides such as GPT_Config(n_embd=512);
    pass them explicitly whenever the base values change.
    """
    n_embd: int = 1024      # Embedding dimension
    lora_dim: int = 256     # Rank of the low-rank QKV factorization in CausalSelfAttention; 0 = full-rank c_attn
    max_seq_len: int = 1024 # Max sequence length (context window); sizes the positional embedding table
    n_head: int = 32        # Number of attention heads (must divide n_embd)
    n_layer: int = 32       # Number of Transformer blocks
    dropout: float = 0.1    # Dropout rate (attention, residual and MLP outputs)
    vocab_size: int = 32    # Size of the vocabulary (number of tokens); token_dict has 32 entries
    ignore_index: int = -100 # Target ID skipped by the cross-entropy loss in GPT.forward
    block_size: int = max_seq_len # Typically the same as max_seq_len; NOTE(review): not read by the GPT class in this notebook
    seq_padding_idx: int = 0 # NOTE(review): not read by the GPT class in this notebook (seq_embedding hard-codes padding_idx=0)
    is_causal: bool = True  # Whether attention is causal (mask future tokens)
    mlp_hidden_dim: int = 2 * n_embd #MLP hidden dimension (output width of MLP.c_fc)


# -----------------------------------------------------------------------------
# GPT Model
# -----------------------------------------------------------------------------
class GPT(nn.Module):
    """
    A GPT-style language model for protein sequences, with learned positional embeddings,
    multiple Transformer blocks, and a final projection to predict the next token.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tokenizer = ProteinTokenizer(token_dict)
        # Positional embeddings map positions 0..(max_seq_len-1) to embedding vectors
        self.pos_embedding = nn.Embedding(config.max_seq_len, config.n_embd)
        # seq_embedding maps token IDs in [0..vocab_size-1] to embedding vectors;
        # padding_idx=0 is the <pad> ID in token_dict.
        self.seq_embedding = nn.Embedding(config.vocab_size, config.n_embd, padding_idx=0)
        # Final LayerNorm after stacking Transformer blocks
        self.ln = LayerNorm(config.n_embd, bias=False)

        # Create a stack of Transformer blocks
        self.transformer = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        # Final linear layer to map hidden states to vocab logits
        self.project = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Initialize weights for the entire model
        self.apply(self._init_weights)

        # Residual output projections get a smaller std, scaled by depth: 0.02 / sqrt(2 * n_layer)
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                torch.nn.init.normal_(
                    p, mean=0.0,
                    std=0.02 / math.sqrt(2 * config.n_layer)
                )

        # Print the number of parameters if rank=0
        print0("number of parameters: %.2fM" % (self.get_num_params() / 1e6,))

    def get_num_params(self):
        """Count all parameters (trainable or frozen)."""
        return sum(p.numel() for p in self.parameters())

    def _init_weights(self, module):
        """
        Default initialization: normal(0, 0.02) for nn.Linear weights and zeros for their
        biases. Other modules (embeddings, LayerNorm) keep their constructor defaults.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def forward(self, seq, targets=None):
        """
        Forward pass:
        seq is (batch_size, sequence_length) of token IDs.
        If targets is provided, we compute cross-entropy loss and return logits for every
        position. If not, we only return logits for the last position (B, 1, vocab_size),
        which is what token-by-token generation needs.

        Raises ValueError if sequence_length > config.max_seq_len (there is no learned
        position embedding beyond that; on CUDA the lookup would otherwise fail with an
        opaque device-side assert).
        """
        device = seq.device
        B, T = seq.size()
        if T > self.config.max_seq_len:
            raise ValueError(
                f"Sequence length {T} exceeds max_seq_len {self.config.max_seq_len}."
            )

        # Build position IDs and gather their embeddings
        pos = torch.arange(0, T, dtype=torch.long, device=device)
        pos = self.pos_embedding(pos)

        # Look up token embeddings
        seq = self.seq_embedding(seq)

        # Sum positional + token embeddings
        x = seq + pos

        # Pass through each Transformer block in turn
        for layer in self.transformer:
            x = layer(x)

        # Final layer norm
        x = self.ln(x)

        # If we have targets, compute the language modeling loss
        if exists(targets):
            logits = self.project(x)  # (B, T, vocab_size)
            # Flatten the batch and sequence dims for cross_entropy
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=self.config.ignore_index
            )
        else:
            # If no targets, return the logits for the last token only
            # (useful for token-by-token generation)
            logits = self.project(x[:, [-1], :])
            loss = None

        return logits, loss


    @torch.inference_mode()
    def generate(self, prefix, max_size, temperature=1.0, top_k=7, rep_penalty=5, ngram_block=4):
        """
        Generates protein sequences, starting from a prefix, up to max_size tokens in total.
          - prefix: (B, prefix_len) tensor of token IDs
          - temperature: logits are divided by this before sampling
          - top_k: sample among the k highest logits (<= 0 means the full distribution)
          - rep_penalty: block a token once it already ends the sequence this many times in a row
          - ngram_block: block tokens that would repeat an n-gram of this size

        Generation stops early once every row has produced <eos> at least once. With B > 1,
        rows that finished earlier keep receiving tokens until then, so callers should trim
        each row at its first <eos>.

        Raises ValueError if max_size <= prefix length.
        """
        generated = prefix.clone()  # shape [B, prefix_len]
        tokens_to_generate = max_size - prefix.size(1)
        if tokens_to_generate <= 0:
            raise ValueError(
                f"Desired size {max_size} <= prefix length {prefix.size(1)}."
            )

        # Per-row flag: has this row emitted <eos> yet? (B == 1 reduces to "stop at first <eos>")
        done = torch.zeros(generated.size(0), dtype=torch.bool, device=generated.device)

        for _ in range(tokens_to_generate):
            # Forward pass the entire current sequence to get logits
            logits, _ = self.forward(generated, targets=None)

            # Extract logits for the last token in the sequence
            next_token_logits = logits[:, -1, :]  # shape [B, vocab_size]

            # 1) repetition penalty
            next_token_logits = adjust_logits_for_repetition(
                next_token_logits, generated, rep_penalty=rep_penalty
            )

            # 2) n-gram blocking
            next_token_logits = adjust_logits_for_ngram_blocking(
                next_token_logits, generated, n=ngram_block
            )

            # Sample the next token from the adjusted logits
            next_token = self._sample_next_token(next_token_logits, temperature, top_k)

            # Append the sampled token
            generated = torch.cat([generated, next_token], dim=1)

            # Stop once every row has hit <eos> (works for any batch size, unlike .item())
            done |= next_token.squeeze(1) == self.tokenizer.eos_id
            if bool(done.all()):
                break

        return generated

    def _sample_next_token(self, logits, temperature=1.0, top_k=7):
        """
        Takes the (B, vocab_size) logits for a single step and returns (B, 1) sampled token IDs.
        - If top_k > 0, we only consider the top k tokens by logit (k is capped at vocab_size).
        - Otherwise we sample from the full distribution.
        """
        # Scale by temperature
        logits = logits / temperature

        if top_k > 0:
            # torch.topk raises if k exceeds the size of the last dimension
            k = min(top_k, logits.size(-1))
            top_logits, top_indices = torch.topk(logits, k, dim=-1)
            probs = F.softmax(top_logits, dim=-1)
            # Sample from those top_k
            indices = torch.multinomial(probs, num_samples=1)
            next_tokens = top_indices.gather(-1, indices)
        else:
            # Full distribution
            probs = F.softmax(logits, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1)

        return next_tokens


# -----------------------------------------------------------------------------
# Generation Helpers (block repeated tokens, repeated n-grams, etc.)
# -----------------------------------------------------------------------------
def adjust_logits_for_repetition(logits, generated_seq, rep_penalty=4):
    """
    Block a token that is already repeating too much at the end of a sequence.

    For each row, measure the length of the run of identical tokens at the end of
    generated_seq[row]. If that run is at least `rep_penalty` long, that token's logit
    is set to -inf. The input `logits` is not modified; a new tensor is returned.
    """
    out = logits.clone()
    for row in range(out.size(0)):
        history = generated_seq[row].tolist()
        if not history:
            continue
        last = history[-1]
        run = 0
        for tok in reversed(history):
            if tok != last:
                break
            run += 1
        if run >= rep_penalty:
            out[row, last] = float('-inf')
    return out

def adjust_logits_for_ngram_blocking(logits, generated_seq, n=3):
    """
    Block tokens that would recreate an n-gram already present in the sequence.

    For each row, the last (n - 1) tokens form the current context. Every earlier position
    where that same context occurs is looked up, and the token that followed it there gets
    its logit set to -inf. n < 2 has no meaningful context, so logits are returned unchanged.
    The input `logits` is not modified; a new tensor is returned.
    """
    logits = logits.clone()
    if n < 2:
        return logits

    ctx_len = n - 1
    for i in range(logits.size(0)):
        # Convert once to a Python list instead of calling .tolist() on every window.
        seq = generated_seq[i].tolist()
        if len(seq) < ctx_len:
            continue

        context = seq[-ctx_len:]
        # start ranges over positions whose context window is followed by another token
        banned_tokens = {
            seq[start + ctx_len]
            for start in range(len(seq) - ctx_len)
            if seq[start:start + ctx_len] == context
        }

        for token in banned_tokens:
            logits[i, token] = float('-inf')
    return logits




In [114]:

# -----------------------------------------------------------------------------
# Main Inference Function
# -----------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="Batched Inference for protein GPT model")
parser.add_argument(
    "--ckpt", type=str, default="/content/drive/MyDrive/Language Model/ckpt_lora2.pt",
    help="Path to the pretrained model checkpoint"
)
parser.add_argument(
    "--sampling_method", type=str, choices=["top_k", "greedy", "top_p"],
    default="top_k", help="Sampling method to use"
)
parser.add_argument(
    "--temperature", type=float, default=0.9,
    help="Sampling temperature"
)
parser.add_argument(
    "--top_k", type=int, default=4,
    help="Top-k value for top-k sampling"
)
parser.add_argument(
    "--top_p", type=float, default=0.5,
    help="Top-p threshold for top-p (nucleus) sampling"
)
parser.add_argument(
    "--num_samples", type=int, default=4,
    help="Number of sequences to generate per prefix"
)
parser.add_argument(
    "--ngram", type=int, default=4,
    help="N-gram size for blocking repetitive patterns"
)
parser.add_argument(
    "--rep_penalty", type=int, default=4,
    help="Repetition penalty for blocking repetitive patterns"
)
parser.add_argument(
    "--max_size", type=int, default=1024,
    help="Maximum sequence length to generate"
)
parser.add_argument(
    "--dtype", type=str, default="bfloat16", choices=["float32", "bfloat16", "float16"],
    help="Data type to use for computation"
)
# BooleanOptionalAction gives --amp / --no-amp; a plain untyped option would turn
# "--amp False" into the truthy string "False".
parser.add_argument(
    "--amp", action=argparse.BooleanOptionalAction, default=True,
    help="Whether to use automatic mixed precision"
)
args = parser.parse_args('')  # For interactive testing, you can override with a custom arg string.

# Determine device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Set random seeds
torch.manual_seed(2001)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Choose device and precision
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
# Honour --dtype instead of hard-coding bfloat16; float32 means "no reduced precision",
# so autocast is disabled in that case.
ctx = nullcontext() if device == "cpu" else torch.autocast(
    device_type="cuda",
    dtype=ptdtype,
    enabled=bool(args.amp) and ptdtype != torch.float32,
)



In [50]:
# -----------------------------------------------------------------------------
# Sampling Functions
# -----------------------------------------------------------------------------
def sample_top_k(logits, top_k):
    """
    Sample the next token using top-k sampling.
      1. Select the top_k highest-scoring tokens in 'logits'.
      2. Convert these logits to probabilities via softmax.
      3. Sample a single token index from these top_k candidates.

    top_k is capped at the vocabulary size (torch.topk would raise otherwise), and
    top_k <= 0 means "sample from the full distribution", matching GPT._sample_next_token.

    Returns:
        LongTensor of shape (batch_size, 1) with the sampled token IDs.
    """
    vocab_size = logits.size(-1)
    k = vocab_size if top_k <= 0 else min(top_k, vocab_size)
    top_logits, top_indices = torch.topk(logits, k, dim=-1)
    probs = F.softmax(top_logits, dim=-1)
    indices = torch.multinomial(probs, num_samples=1)
    next_tokens = top_indices.gather(-1, indices)
    return next_tokens

def sample_greedy(logits):
    """
    Greedy decoding: pick the highest-scoring token in each row of 'logits'.
    The result keeps a trailing singleton dimension, i.e. (batch_size, 1).
    """
    best = torch.argmax(logits, dim=-1)
    return best.unsqueeze(-1)

def sample_top_p(logits, top_p):
    """
    Sample the next token using top-p (nucleus) sampling.
      1. Compute the probability distribution via softmax.
      2. Sort tokens by probability (descending) and take a cumulative sum.
      3. Drop a token once the tokens ranked above it already hold at least `top_p`
         of the mass. The token that crosses the threshold is kept, and the most
         likely token is always kept, so the nucleus is never empty.
      4. Renormalize the remaining probabilities and sample from them.

    Args:
        logits: (batch_size, vocab_size) scores; the caller applies temperature.
        top_p: nucleus mass threshold.

    Returns:
        LongTensor of shape (batch_size, 1) with the sampled token IDs.
    """
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Probability mass of all tokens ranked strictly above each token.
    mass_above = cumulative_probs - sorted_probs
    remove = mass_above >= top_p
    remove[..., 0] = False  # always keep the top token (also makes top_p <= 0 act greedily)

    sorted_probs = sorted_probs.masked_fill(remove, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)

    # Index into the sorted order, then map back to vocabulary IDs: (batch_size, 1)
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_indices.gather(-1, choice)


# -----------------------------------------------------------------------------
# Batched Generation Function
# -----------------------------------------------------------------------------
def generate_batch(
    model,
    tokenizer,
    prefixes,
    max_sizes,
    sampling_method="top_k",
    sampling_args=None,
    rep_penalty=4,
    ngram_block=3
):
    """
    Generate sequences in batch for a list of prefixes using various sampling methods.

    Unfinished sequences that currently have the same length are run through the model
    together; sequences are never padded. `model(..., targets=None)` only returns logits
    for the LAST position, so a right-padded shorter row would be scored at a <pad>
    position, and the repetition / n-gram penalties would see the pad tokens.

    Args:
        model (nn.Module):
            The GPT model (already loaded).
        tokenizer (ProteinTokenizer):
            The tokenizer instance (for encoding/decoding).
        prefixes (list of str):
            List of prefix strings to start generation from.
        max_sizes (list of int):
            List of maximum total lengths (including <bos>) for each prefix.
        sampling_method (str):
            One of "top_k", "greedy", or "top_p".
        sampling_args (dict, optional):
            Dictionary with keys like "temperature", "top_k", "top_p" to control sampling.
            Defaults to {"temperature": 1.0, "top_k": 7, "top_p": 0.5}.
        rep_penalty (int):
            Threshold for consecutive repetition penalty (blocks repeated tokens).
        ngram_block (int):
            N value for n-gram blocking (to block repeated n-grams).

    Returns:
        list of torch.Tensor:
            One tensor of shape [1, L_generated] per prefix, in input order. Each starts
            with <bos> and stops after <eos> is sampled or once its length reaches
            max_sizes[i] (at least one token is always generated).

    Raises:
        ValueError: if `sampling_method` is not supported.
    """
    # Default sampling arguments
    if sampling_args is None:
        sampling_args = {"temperature": 1.0, "top_k": 7, "top_p": 0.5}

    temperature = sampling_args.get("temperature", 1.0)
    top_k = sampling_args.get("top_k", 7)
    top_p = sampling_args.get("top_p", 0.5)

    samplers = {
        "top_k": lambda lg: sample_top_k(lg / temperature, top_k),
        "greedy": sample_greedy,  # argmax is unaffected by (positive) temperature
        "top_p": lambda lg: sample_top_p(lg / temperature, top_p),
    }
    if sampling_method not in samplers:
        raise ValueError(f"Unknown sampling method: {sampling_method}")
    sample_fn = samplers[sampling_method]

    device = next(model.parameters()).device

    # tokenizer.encode prepends <bos>; each prefix becomes a (1, L) tensor.
    batch = [
        torch.tensor(tokenizer.encode(prefix), dtype=torch.long, device=device).unsqueeze(0)
        for prefix in prefixes
    ]

    # Track finished sequences
    finished = [False] * len(batch)

    # Inference only: skip building autograd graphs for every forward pass.
    with torch.no_grad():
        while not all(finished):
            # Group unfinished sequences by current length so no padding is needed.
            groups = {}
            for i, seq in enumerate(batch):
                if not finished[i]:
                    groups.setdefault(seq.size(1), []).append(i)

            for indices in groups.values():
                group = torch.cat([batch[i] for i in indices], dim=0)  # (G, L)

                logits, _ = model(group, targets=None)
                next_logits = logits[:, -1, :]  # (G, vocab_size)

                # Apply repetition and n-gram penalties on the real (unpadded) tokens
                next_logits = adjust_logits_for_repetition(next_logits, group, rep_penalty)
                next_logits = adjust_logits_for_ngram_blocking(next_logits, group, n=ngram_block)

                next_tokens = sample_fn(next_logits)

                for row, i in enumerate(indices):
                    token = next_tokens[row].reshape(1, 1)
                    batch[i] = torch.cat([batch[i], token], dim=1)

                    # Stop on <eos> or once the requested total length is reached
                    if token.item() == tokenizer.eos_id or batch[i].size(1) >= max_sizes[i]:
                        finished[i] = True

    return batch


# -----------------------------------------------------------------------------
# Load Model Checkpoint
# -----------------------------------------------------------------------------
# NOTE: weights_only=False fully unpickles the file (arbitrary code execution risk) —
# only load checkpoints you trust. Being explicit keeps today's behaviour if the torch
# default changes, and silences the FutureWarning.
checkpoint = torch.load(args.ckpt, map_location=device, weights_only=False)
if "model_args" not in checkpoint:
    raise ValueError("Checkpoint does not contain model_args")
if "model" not in checkpoint:
    raise ValueError("Checkpoint does not contain model")
model_args = checkpoint["model_args"]


# Instantiate the GPT model config and create the model
cfg = GPT_Config(**model_args)
model = GPT(cfg)

# State dicts saved from a torch.compile'd module prefix every key with "_orig_mod.".
unwanted_prefix = "_orig_mod."
state_dict = {
    (key[len(unwanted_prefix):] if key.startswith(unwanted_prefix) else key): value
    for key, value in checkpoint["model"].items()
}
model.load_state_dict(state_dict)
model.to(device)
model.eval()

# Create tokenizer
tokenizer = ProteinTokenizer(token_dict)

# Sampling arguments (temperature, top_k, top_p, etc.)
sampling_args = {
    "temperature": args.temperature,
    "top_k": args.top_k,
    "top_p": args.top_p
}



  checkpoint = torch.load(args.ckpt, map_location=device)


number of parameters: 202.52M


In [134]:
# -----------------------------------------------------------------------------
# Sample Generation
# -----------------------------------------------------------------------------
# Example list of prefix strings (just 1 long prefix here) and their maximum lengths
sample_prefixes = [
    "MNVLIIGSGGREHALAWKVAQSPLA"
]
sizes = [400]

# Replicate each prefix num_samples times so all samples come from one batched call
batched_prefixes = []
batched_sizes = []
for prefix, size in zip(sample_prefixes, sizes):
    for _ in range(args.num_samples):
        batched_prefixes.append(prefix)
        batched_sizes.append(size)

# Perform batched generation
generated_batch = generate_batch(
    model,
    tokenizer,
    batched_prefixes,
    batched_sizes,
    sampling_method=args.sampling_method,
    sampling_args=sampling_args,
    rep_penalty=args.rep_penalty,
    ngram_block=args.ngram
)

# -----------------------------------------------------------------------------
# Print Results
# -----------------------------------------------------------------------------
# The reported length counts residue tokens only; special tokens (which decode to
# multi-character strings such as "<bos>") are excluded by ID, not by string arithmetic.
special_ids = {
    tokenizer.pad_id,
    tokenizer.bos_id,
    tokenizer.eos_id,
    tokenizer.token_dict[tokenizer.unk_token],
}
num_prefixes = len(sample_prefixes)
for i in range(num_prefixes):
    print(f"Generating samples for prefix: {sample_prefixes[i]}")
    for j in range(args.num_samples):
        idx = i * args.num_samples + j
        gen_tensor = generated_batch[idx]
        # gen_tensor is shape [1, L]; decode its single row
        token_ids = gen_tensor[0].tolist()
        generated_text = tokenizer.decode(token_ids)
        residue_count = sum(1 for tid in token_ids if tid not in special_ids)
        print(f"len: {residue_count} - Sample {j+1}:\n{generated_text}\n{'-'*80}")

Generating samples for prefix: MNVLIIGSGGREHALAWKVAQSPLA
len: 399 - Sample 1:
<bos>MNVLIIGSGGREHALAWKVAQSPLAQKIFIAPGNAGTALEPTLQNVAIDVSDHQALVDFALKNNVDLTVVGPEAPLVIGVVDAFRAAGLAIFGPSKAAAQLEGSKAFTKDFLARHNIPTGKYQNFTEADAALAYVREQGAPIVIKADGLAAGKGVTVAMTLAEAEAAIKDMLAGNAFGDAGSRVVIEEFLDGEEASFFVLCDGKNVLPMATSQDHKRVGDADTGPNTGGMGAYSPAPVVTPQVHARVMREVIQPTVQGMAQDGTTYTGFLYAGLMITPDGPKVIEYNCRFGDPETQVVLPRLKSDLVELLEASAQGKLGDVSIEWDARAAVTVVMAAGGYPGKYETGKVISGLDEAAKLDGVQVFHAGTKLDEQGNVVTNGGRVLCVTALGDTVKQAQD
--------------------------------------------------------------------------------
len: 399 - Sample 2:
<bos>MNVLIIGSGGREHALAWKVAQSPLAQKIFIAPGNAGTAQVAENVAIAADDVPGLVRFAKAEAVDFTVVGPEAPLVAGVVDAFRAAGLRIFGPTQAAAQLEGSKAFTKDFLARHKIPTAAYQNFTEIEPALAYVRERGAPIVVKADGLAAGKGVIVAMTLEEAHAAVDDMLGGNFGAAGAEVVIEEFLDGEEASFIVMVDGENVLPMATSQDHKRVGDGDTGPNTGGMGAYSPAPIITEQVHARVMKEVILPTVKGMAADGSPYTGFLYAGLMIAPDGPQVIEFNCRMGDPETQPIMMRLKSDLVELCLAACNGKLADAAIEWSEQAALTVVMAAKGYPGSYAKGKPISGLDDAARMEGVEVFHAGTKREGDKLVTNGGRVLCVTSLGATVAEAQKRAYQ
----------

In [135]:
# Choose one of the generated sequences above for downstream BLAST / ESMFold analysis.
# This is Sample 2 from the generation output, copied verbatim without its leading "<bos>" tag.
# Hard-coding it means re-running the (sampling-based) generation cell does not change the
# sequence analysed below.
sequence = "MNVLIIGSGGREHALAWKVAQSPLAQKIFIAPGNAGTAQVAENVAIAADDVPGLVRFAKAEAVDFTVVGPEAPLVAGVVDAFRAAGLRIFGPTQAAAQLEGSKAFTKDFLARHKIPTAAYQNFTEIEPALAYVRERGAPIVVKADGLAAGKGVIVAMTLEEAHAAVDDMLGGNFGAAGAEVVIEEFLDGEEASFIVMVDGENVLPMATSQDHKRVGDGDTGPNTGGMGAYSPAPIITEQVHARVMKEVILPTVKGMAADGSPYTGFLYAGLMIAPDGPQVIEFNCRMGDPETQPIMMRLKSDLVELCLAACNGKLADAAIEWSEQAALTVVMAAKGYPGSYAKGKPISGLDDAARMEGVEVFHAGTKREGDKLVTNGGRVLCVTSLGATVAEAQKRAYQ"


In [136]:
# Submit the chosen sequence to NCBI BLASTp and report whether it resembles a known protein.
# qblast is given the bare amino-acid string directly; no FASTA header is built here.
query_seq = sequence  # from previous step
blast_program = "blastp"
database = "swissprot"  # using Swiss-Prot database for curated sequences
evalue_cutoff = 1e-3  # a top-hit E-value below this is reported as significant

print("Submitting BLASTp search...")
# Remote NCBI request: can be slow and requires network access.
result_handle = NCBIWWW.qblast(blast_program, database, query_seq)
try:
    # Parse the BLAST results (XML format by default)
    blast_record = NCBIXML.read(result_handle)
finally:
    # Release the handle even if parsing raises.
    result_handle.close()

# Check if there are any hits
if not blast_record.alignments:
    print("No hits found: the sequence appears to be novel (no similar sequences in database).")
else:
    top_hit = blast_record.alignments[0]
    # E-value of the first HSP of the first-listed alignment.
    e_val = top_hit.hsps[0].expect
    print(f"Top hit: {top_hit.title}")
    print(f"E-value: {e_val:.2e}")
    if e_val < evalue_cutoff:
        print("Significant hit found - the sequence has similarity to a known protein.")
    else:
        print("No significant similarities found (sequence is likely novel or very divergent).")


Submitting BLASTp search...
Top hit: sp|Q8Z334.1| RecName: Full=Phosphoribosylamine--glycine ligase; AltName: Full=GARS; AltName: Full=Glycinamide ribonucleotide synthetase; AltName: Full=Phosphoribosylglycinamide synthetase [Salmonella enterica subsp. enterica serovar Typhi]
E-value: 0.00e+00
Significant hit found - the sequence has similarity to a known protein.


In [137]:
from Bio import SeqIO

# Summarise the first `max_hits` BLAST alignments: identity, E-value, aligned region and query coverage.
max_hits = 5
if blast_record.alignments:
    query_len = len(query_seq)
    print(f"\nTop {min(max_hits, len(blast_record.alignments))} BLAST hits:")
    for i, alignment in enumerate(blast_record.alignments[:max_hits], start=1):
        hsp = alignment.hsps[0]  # the first HSP (highest scoring segment)
        identity = hsp.identities
        align_len = hsp.align_length
        perc_id = (identity / align_len) * 100
        # Share of the query spanned by this HSP; BLAST coordinates are 1-based and inclusive.
        query_cov = (abs(hsp.query_end - hsp.query_start) + 1) / query_len * 100
        print(f"Hit {i}: {alignment.title}")
        print(f" - Subject length: {alignment.length}, Identity: {identity}/{align_len} (~{perc_id:.1f}%), E-value: {hsp.expect:.1e}")
        print(f" - Aligned region: Query {hsp.query_start}-{hsp.query_end} aligned to Subject {hsp.sbjct_start}-{hsp.sbjct_end} (query coverage ~{query_cov:.1f}%)")
        print()



Top 5 BLAST hits:
Hit 1: sp|Q8Z334.1| RecName: Full=Phosphoribosylamine--glycine ligase; AltName: Full=GARS; AltName: Full=Glycinamide ribonucleotide synthetase; AltName: Full=Phosphoribosylglycinamide synthetase [Salmonella enterica subsp. enterica serovar Typhi]
 - Length: 429, Identity: 297/403 (~73.7%), E-value: 0.0e+00
 - Alignment coverage: Query 1-398 aligned to Subject 1-402

Hit 2: sp|P26977.1| RecName: Full=Phosphoribosylamine--glycine ligase; AltName: Full=GARS; AltName: Full=Glycinamide ribonucleotide synthetase; AltName: Full=Phosphoribosylglycinamide synthetase [Salmonella enterica subsp. enterica serovar Typhimurium str. LT2]
 - Length: 429, Identity: 297/403 (~73.7%), E-value: 0.0e+00
 - Alignment coverage: Query 1-398 aligned to Subject 1-402

Hit 3: sp|P57829.1| RecName: Full=Phosphoribosylamine--glycine ligase; AltName: Full=GARS; AltName: Full=Glycinamide ribonucleotide synthetase; AltName: Full=Phosphoribosylglycinamide synthetase [Pasteurella multocida subsp. mul

In [138]:
# Tokenise the chosen sequence for ESMFold without special tokens, then move it to the GPU.
encoded = esm_tokenizer([sequence], return_tensors="pt", add_special_tokens=False)
tokenized_input = encoded['input_ids'].cuda()

# Inference only: disable gradient tracking for the folding forward pass.
with torch.no_grad():
    output = model_fold(tokenized_input)

In [139]:
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37

def convert_outputs_to_pdb(outputs):
    """Turn ESMFold model outputs into PDB-format strings, one string per batch entry.

    Args:
        outputs: the folding model's output mapping; reads "positions", "aatype",
            "atom37_atom_exists", "residue_index", "plddt" and, if present, "chain_index".

    Returns:
        list[str]: PDB text for each batch entry, with pLDDT written into the B-factor column.
    """
    # Last entry of outputs["positions"] (presumably the final structure-module iteration),
    # converted from the atom14 to the atom37 layout, then moved to NumPy.
    atom37_positions = atom14_to_atom37(outputs["positions"][-1], outputs).cpu().numpy()
    np_outputs = {name: tensor.to("cpu").numpy() for name, tensor in outputs.items()}
    atom37_mask = np_outputs["atom37_atom_exists"]
    has_chains = "chain_index" in np_outputs

    pdb_strings = []
    for b, aatype in enumerate(np_outputs["aatype"]):
        protein = OFProtein(
            aatype=aatype,
            atom_positions=atom37_positions[b],
            atom_mask=atom37_mask[b],
            # +1 shifts residue numbering, presumably to 1-based indices for PDB output.
            residue_index=np_outputs["residue_index"][b] + 1,
            b_factors=np_outputs["plddt"][b],
            chain_index=np_outputs["chain_index"][b] if has_chains else None,
        )
        pdb_strings.append(to_pdb(protein))
    return pdb_strings

pdb = convert_outputs_to_pdb(output)

In [140]:
import py3Dmol

# Load the predicted structure (all PDB strings concatenated) into an 800x400 3Dmol viewer
# and draw it as a cartoon using 3Dmol's 'spectrum' colouring.
view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width=800, height=400)
pdb_text = "".join(pdb)
view.addModel(pdb_text, 'pdb')
view.setStyle({'model': -1}, {"cartoon": {'color': 'spectrum'}})

<py3Dmol.view at 0x79553a576950>

In [141]:
# The plddt field is scaled from 0-1 on earlier versions of ESMFold but will be updated
# to match AlphaFold's scale of 0-100 in future versions. Detect which scale this model
# produced and choose colour-gradient bounds to match, so the code works with either.
plddt_is_fractional = torch.max(output['plddt']) <= 1.0
vmin, vmax = (0.5, 0.95) if plddt_is_fractional else (50, 95)

# pLDDT was written into the PDB B-factor column, so colour the cartoon by 'b'.
view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min': vmin,'max': vmax}}})

<py3Dmol.view at 0x79553a576950>

In [143]:
# Mean of all pLDDT values as a one-number confidence summary (0-1 or 0-100 scale, see the check above).
torch.mean(output['plddt'])

tensor(0.8872, device='cuda:0')