<a href="https://colab.research.google.com/github/LigandBindingDomain/Decoding_Bias/blob/main/Calculate_All_models_likelihoods.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## ProteinMPNN


In [None]:
#@title Setup

import os
try:
  import colabdesign
except:
  os.system("pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.1")
  os.system("ln -s /usr/local/lib/python3.7/dist-packages/colabdesign colabdesign")

from colabdesign.mpnn import mk_mpnn_model, clear_mem
from colabdesign.shared.protein import pdb_to_string

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import HTML
import pandas as pd
import tqdm.notebook
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

from google.colab import files
from google.colab import data_table
data_table.enable_dataframe_formatter()
# Import necessary libraries
import os
import json
import pandas as pd
import numpy as np
import requests
import logging
import time
from tqdm import tqdm
from colabdesign.mpnn import mk_mpnn_model, clear_mem
from pathlib import Path
from scipy.special import softmax
from typing import Optional, Tuple, Dict, Any

# Root logger: INFO level with timestamped messages.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Constants
# Residue letter -> column of the 20-way (log-)probability arrays indexed in
# ProteinMPNNProcessor.calculate_sequence_score; order is "ARNDCQEGHILKMFPSTWYV".
AA_TO_INDEX = {residue: column for column, residue in enumerate("ARNDCQEGHILKMFPSTWYV")}
# Local cache directory for downloaded AlphaFold structures.
PDB_DIR = "pdb_files"
# ProteinMPNN weight-set name (same value as ProteinMPNNProcessor's default model_name).
MODEL_NAME = "v_48_020"

In [None]:
#@title Likelihood Predictor (Alphafold PDBs)

class ProteinMPNNProcessor:
    """
    Score protein sequences against their AlphaFold structures with ProteinMPNN.

    Structures are downloaded from the AlphaFold DB and cached locally. Proteins
    are read from a CSV with ``Entry`` (accession) and ``sequence`` columns and
    can be processed in chunks whose results are saved (and, in Colab,
    downloaded) as each chunk completes.
    """

    # Seconds to wait on the AlphaFold DB before abandoning a download.
    DOWNLOAD_TIMEOUT = 30

    def __init__(self, model_name="v_48_020"):
        self.model_name = model_name
        self.setup_logging()

    def setup_logging(self):
        """Configure root logging (logging.basicConfig is a no-op once handlers exist)."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s'
        )

    def create_model(self):
        """
        Create a ProteinMPNN model, optionally replacing its weights.

        Custom weights are used only when the instance has a ``customs`` attribute
        whose ``weights_path`` names an existing file (nothing in this class sets
        ``customs``). Any failure while loading them falls back to the vanilla
        weights.

        Returns:
            The model, or None if ``mk_mpnn_model`` itself failed.
        """
        try:
            mpnn_model = mk_mpnn_model(self.model_name)
        except Exception as e:
            logging.error(f"Error creating model: {e}")
            return None

        weights_path = getattr(getattr(self, 'customs', None), 'weights_path', None)
        if not weights_path:
            logging.info("Using vanilla model weights")
            return mpnn_model
        if not os.path.exists(weights_path):
            logging.warning(f"Custom weights file not found: {weights_path}")
            logging.info("Using vanilla model weights")
            return mpnn_model

        logging.info(f"Loading custom weights from {weights_path}")
        try:
            # Imported lazily: joblib is only needed for this optional path.
            import joblib
            custom_weights = joblib.load(weights_path)
            if 'model_state_dict' in custom_weights:
                weights_to_load = custom_weights['model_state_dict']
            else:
                weights_to_load = custom_weights
        except Exception as e:
            logging.error(f"Error loading custom weights file: {e}")
            logging.info("Using vanilla model weights")
            return mpnn_model

        if not (hasattr(mpnn_model, '_model') and hasattr(mpnn_model._model, 'params')):
            logging.error("Model structure not compatible for weight loading")
            logging.info("Using vanilla model weights")
            return mpnn_model

        try:
            # Passing through jnp.array applies JAX's default dtypes before the
            # tree is converted back to numpy arrays.
            jax_weights = {
                key: ({k: jnp.array(v) for k, v in value.items()}
                      if isinstance(value, dict) else jnp.array(value))
                for key, value in weights_to_load.items()
            }
            # jax.tree_map is deprecated (and removed in recent JAX releases).
            mpnn_model._model.params = jax.tree_util.tree_map(np.array, jax_weights)
            logging.info("Custom weights loaded successfully.")
        except Exception as e:
            logging.error(f"Error loading custom weights into model: {e}")
            logging.info("Using vanilla model weights")
        return mpnn_model

    def get_pdb(self, accession, output_dir=PDB_DIR):
        """
        Fetch the AlphaFold DB model (v4) for ``accession``, caching it in ``output_dir``.

        Returns:
            ``(path_to_pdb, True)`` on success (including a cache hit), or
            ``(None, False)`` when the request fails, returns a non-200 status,
            or the file cannot be written.
        """
        pdb_name = f"AF-{accession}-F1-model_v4.pdb"
        pdb_path = os.path.join(output_dir, pdb_name)
        os.makedirs(output_dir, exist_ok=True)

        if os.path.exists(pdb_path):
            return pdb_path, True

        url = f"https://alphafold.ebi.ac.uk/files/{pdb_name}"
        try:
            # The timeout keeps one unresponsive request from stalling the whole run.
            response = requests.get(url, timeout=self.DOWNLOAD_TIMEOUT)
            if response.status_code != 200:
                logging.warning(f"Failed to download {accession}: HTTP {response.status_code}")
                return None, False
            with open(pdb_path, 'wb') as f:
                f.write(response.content)
        except (requests.RequestException, OSError) as e:
            logging.error(f"Failed to download {accession}: {e}")
            return None, False
        return pdb_path, True

    def calculate_model_outputs(self, mpnn_model, sequence, pdb_path=None, chain="A", mode="unconditional"):
        """
        Run ProteinMPNN scoring for one protein.

        Args:
            mpnn_model: a model returned by ``mk_mpnn_model``.
            sequence: passed to ``prep_inputs`` only when ``pdb_path`` is not given.
                NOTE(review): with a PDB, ``score`` is called without this sequence,
                so the sequence read from the structure is what gets scored.
            pdb_path: structure file to prepare inputs from (chain ``chain``).
            mode: ``"unconditional"`` passes an all-zeros ``ar_mask``; any other value
                passes ``1 - I``. NOTE(review): presumably zeros = no sequence context
                and ``1 - I`` = condition on every other residue; confirm against colabdesign.

        Returns:
            The dict from ``mpnn_model.score``, with ``log_probs`` and ``probs`` added
            (first 20 columns after (log-)softmax over all logit columns) when logits
            are present; None on error.
        """
        try:
            if pdb_path:
                mpnn_model.prep_inputs(pdb_filename=pdb_path, chain=chain)
            else:
                mpnn_model.prep_inputs(sequence=sequence)

            try:
                L = sum(mpnn_model._lengths)
            except TypeError:
                logging.error("Model _lengths attribute is not iterable.")
                return None

            ar_mask = np.zeros((L, L)) if mode == "unconditional" else 1 - np.eye(L)
            result = mpnn_model.score(ar_mask=ar_mask)

            if result.get("logits") is not None:
                logits = np.array(result["logits"])
                # Normalise over all logit columns, then keep the 20 standard residues.
                result["log_probs"] = np.array(jax.nn.log_softmax(logits, axis=-1)[..., :20])
                result["probs"] = np.array(jax.nn.softmax(logits, axis=-1)[..., :20])

            return result
        except Exception as e:
            logging.error(f"Error in model computation: {e}")
            return None

    def process_single_protein(self, sequence, accession, chain="A"):
        """
        Score one protein against its AlphaFold structure.

        Returns:
            A metrics dict (``Entry``, ``sequence_score``, ``entropy``,
            ``sequence_length``, ``mean_confidence``), or None when the structure
            could not be fetched or the model computation failed.
        """
        pdb_path, success = self.get_pdb(accession)
        if not success:
            logging.warning(f"Skipping {accession} due to failed PDB download.")
            return None

        try:
            # NOTE(review): a new model is built for every protein; whether one
            # instance could be reused across clear_mem() calls is not established here.
            mpnn_model = mk_mpnn_model(self.model_name)
            model_output = self.calculate_model_outputs(mpnn_model, sequence, pdb_path, chain)
            if model_output is None:
                logging.warning(f"Skipping {accession}: model computation failed.")
                return None

            return {
                "Entry": accession,
                "sequence_score": self.calculate_sequence_score(model_output, sequence),
                "entropy": self.calculate_entropy(model_output),
                "sequence_length": len(sequence),
                "mean_confidence": (
                    float(model_output["probs"].max(axis=-1).mean())
                    if "probs" in model_output else None
                ),
            }
        except Exception as e:
            logging.error(f"Error processing {accession}: {e}")
            return None
        finally:
            # Runs on success and on failure, so errors do not skip the cleanup.
            clear_mem()

    def _process_dataframe(self, proteins_data, chain="A"):
        """Score every row of ``proteins_data`` (needs ``Entry`` and ``sequence`` columns)."""
        if not {'sequence', 'Entry'}.issubset(proteins_data.columns):
            logging.warning("Skipping all rows due to missing 'sequence' or 'Entry' column.")
            return pd.DataFrame()

        results = []
        for _, row in tqdm(proteins_data.iterrows(), total=len(proteins_data)):
            metrics = self.process_single_protein(row['sequence'], row['Entry'], chain)
            if metrics:
                results.append(metrics)
        return pd.DataFrame(results)

    def process_proteins(self, csv_file, start_idx=0, end_idx=None, chain="A"):
        """Score rows ``[start_idx, end_idx)`` of ``csv_file`` and return their metrics."""
        proteins_data = pd.read_csv(csv_file)
        # iloc[start:None] runs to the end, matching end_idx=None.
        return self._process_dataframe(proteins_data.iloc[start_idx:end_idx], chain)

    @staticmethod
    def _download_if_possible(filename, description):
        """Trigger a browser download via google.colab ``files``; warn if unavailable."""
        try:
            files.download(filename)
        except NameError:
            logging.warning(f"Could not download {description}. 'files.download' not found. Are you in a Colab environment?")

    def process_proteins_in_chunks(self, csv_file: str, chunk_size: int = 500, chain: str = "A"):
        """
        Score every protein in ``csv_file`` in chunks of ``chunk_size`` rows.

        Each chunk is written to ``proteinMPNN_results_chunk_<start>-<end>.csv`` and
        the combined table to ``proteinMPNN_results_all_chunks.csv``; in Colab
        both are offered for download.

        Returns:
            The combined results DataFrame (empty if the CSV does not exist).

        Raises:
            ValueError: if ``chunk_size`` is less than 1.
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

        try:
            proteins_data_full = pd.read_csv(csv_file)
        except FileNotFoundError:
            logging.error(f"CSV file not found at {csv_file}")
            return pd.DataFrame()

        total_proteins = len(proteins_data_full)
        total_chunks = (total_proteins + chunk_size - 1) // chunk_size
        logging.info(f"Starting processing of {total_proteins} proteins in {total_chunks} chunks from {csv_file}")

        chunk_results = []
        for chunk_number, start_idx in enumerate(range(0, total_proteins, chunk_size), start=1):
            end_idx = min(start_idx + chunk_size, total_proteins)
            logging.info(f"Processing chunk {chunk_number}/{total_chunks}: proteins {start_idx} to {end_idx-1}")

            # Slice the already-loaded table instead of re-reading the CSV per chunk.
            results_df = self._process_dataframe(proteins_data_full.iloc[start_idx:end_idx], chain)
            chunk_results.append(results_df)

            chunk_filename = f"proteinMPNN_results_chunk_{start_idx}-{end_idx-1}.csv"
            results_df.to_csv(chunk_filename, index=False)
            logging.info(f"Completed chunk {chunk_number}/{total_chunks}. Downloading results...")
            self._download_if_possible(chunk_filename, "chunk file")

            logging.info(f"Processed {end_idx}/{total_proteins} proteins ({(end_idx/total_proteins)*100:.1f}% complete)")

        logging.info("Processing complete!")

        # Concatenate once at the end rather than growing a DataFrame per chunk.
        all_results_df = pd.concat(chunk_results, ignore_index=True) if chunk_results else pd.DataFrame()

        final_output_filename = "proteinMPNN_results_all_chunks.csv"
        all_results_df.to_csv(final_output_filename, index=False)
        logging.info(f"All results saved to {final_output_filename}")
        self._download_if_possible(final_output_filename, "final results file")

        return all_results_df

    @staticmethod
    def calculate_sequence_score(model_output, sequence):
        """
        Mean log-probability of the observed residues of ``sequence``.

        Positions with a residue not in ``AA_TO_INDEX``, or beyond the rows of
        ``log_probs``, are skipped. Returns None when ``log_probs`` is missing or
        no position is usable.
        """
        if model_output is None or "log_probs" not in model_output:
            return None

        log_probs = model_output["log_probs"]
        score = 0.0
        valid_positions = 0

        for pos, aa in enumerate(sequence):
            aa_index = AA_TO_INDEX.get(aa)
            if aa_index is not None and pos < log_probs.shape[0]:
                score += log_probs[pos, aa_index]
                valid_positions += 1

        return score / valid_positions if valid_positions > 0 else None

    @staticmethod
    def calculate_entropy(model_output):
        """
        Mean per-position Shannon entropy (natural log) of ``probs``.

        Probabilities are clipped to ``[1e-10, 1 - 1e-10]`` before taking logs.
        Returns None when ``probs`` is missing or empty, no position entropy is
        finite, or an error occurs.
        """
        if model_output is None or "probs" not in model_output:
            return None

        try:
            probs = model_output["probs"]
            if probs.shape[0] == 0:
                logging.warning("Probability matrix is empty.")
                return None

            # Clipping avoids log(0).
            probs = np.clip(probs, 1e-10, 1 - 1e-10)
            positional_entropy = -np.sum(probs * np.log(probs), axis=1)

            return float(np.nanmean(positional_entropy)) if np.isfinite(positional_entropy).any() else None

        except Exception as e:
            logging.error(f"Error calculating entropy: {e}")
            return None

#@title Run
processor = ProteinMPNNProcessor()

# Ask whether to score a whole CSV or a single protein.
process_mode = input("Enter mode ('csv' or 'single'): ").strip().lower()

if process_mode == 'single':
    accession = input("Enter AlphaFold PDB accession code (e.g., P0DTD1): ").strip()
    sequence = input("Enter the protein sequence: ").strip().upper()

    if accession and sequence:
        print(f"Processing single protein: {accession}")
        metrics = processor.process_single_protein(sequence, accession)

        if not metrics:
            print(f"Failed to process protein {accession}.")
        else:
            print("\nProteinMPNN Results:")
            for metric_name, metric_value in metrics.items():
                print(f"{metric_name}: {metric_value}")
    else:
        print("PDB accession code and sequence are required for single processing.")
elif process_mode == 'csv':
    # Input table; replace with your own CSV path.
    csv_path = "/content/uniprotkb_PETase_AND_reviewed_true_2025_10_30_enhanced.csv"
    processor.process_proteins_in_chunks(csv_path, chunk_size=1000)
else:
    print("Invalid mode selected. Please enter 'csv' or 'single'.")

Enter mode ('csv' or 'single'): csv


ERROR:root:CSV file not found at /content/uniprotkb_PETase_AND_reviewed_true_2025_10_30_enhanced.csv


In [None]:
#@title Run
processor = ProteinMPNNProcessor()

# Option to choose between processing CSV or single sequence
process_mode = input("Enter mode ('csv' or 'single'): ").strip().lower()

if process_mode == 'csv':
    # Prompt for CSV file upload
    print("Please upload a CSV file containing protein data:")
    uploaded = files.upload()
    if not uploaded:
        print("No file uploaded. Exiting.")
    else:
        # files.upload() returns a mapping keyed by filename; use the first file.
        csv_path = next(iter(uploaded))
        chunk_size = 1000
        processor.process_proteins_in_chunks(csv_path, chunk_size=chunk_size)
elif process_mode == 'single':
    accession = input("Enter AlphaFold PDB accession code (e.g., P0DTD1): ").strip()
    sequence = input("Enter the protein sequence: ").strip().upper()

    if not accession or not sequence:
        print("PDB accession code and sequence are required for single processing.")
    else:
        print(f"Processing single protein: {accession}")
        # The processor method is process_single_protein(sequence, accession).
        metrics = processor.process_single_protein(sequence, accession)

        if metrics:
            print("\nProteinMPNN Results:")
            for key, value in metrics.items():
                print(f"{key}: {value}")
        else:
            print(f"Failed to process protein {accession}.")
else:
    print("Invalid mode selected. Please enter 'csv' or 'single'.")

Enter mode ('csv' or 'single'): csv
Please upload a CSV file containing protein data:


Saving proteins_full_phylum_fixed_with_TM.csv to proteins_full_phylum_fixed_with_TM.csv


100%|██████████| 1000/1000 [04:55<00:00,  3.38it/s]


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

100%|██████████| 1000/1000 [04:53<00:00,  3.40it/s]


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

 74%|███████▎  | 735/1000 [03:35<01:17,  3.41it/s]


KeyboardInterrupt: 

## ESMIF

In [None]:
#@title Setup pt 1
# Install the ESM inverse-folding (ESM-IF1) dependencies into this runtime.

# numpy pinned to 1.26.4 (a 1.x release).
!pip install numpy==1.26.4

# torch 2.3.0 built for CUDA 12.1; the PyG wheels below are fetched for this exact torch/CUDA pair.
!pip install torch==2.3.0+cu121 torchvision torchaudio -f https://download.pytorch.org/whl/cu121/torch_stable.html

!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv torch-geometric -f https://data.pyg.org/whl/torch-2.3.0+cu121.html

!pip install biotite==0.39.0

!pip install git+https://github.com/facebookresearch/esm.git

# Kill this kernel process (signal 9) so the freshly installed versions are loaded on the
# next import. NOTE(review): relies on Colab restarting the runtime automatically.
import os
os.kill(os.getpid(), 9)


Collecting numpy==1.26.4
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/61.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.3/18.3 MB[0m [31m100.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 2.0.2
    Uninstalling numpy-2.0.2:
      Successfully uninstalled numpy-2.0.2
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
thinc 8.3.6 requires numpy<3.0.0,>=2.0.0, but y

In [None]:
#@title Setup pt 2
import json
import pandas as pd
import numpy as np
import requests
import torch
import torch.nn.functional as F
import esm
from tqdm import tqdm
from pathlib import Path
import logging
from typing import Optional, Tuple, List
from biotite.sequence.io.fasta import FastaFile, get_sequences
from datetime import datetime
from google.colab import drive
from scipy import stats
import matplotlib.pyplot as plt
import seaborn as sns
import json
import os
import esm

# Mount Google Drive at /content/drive.
drive.mount('/content/drive')

# Constants
# Local cache directory for AlphaFold structures (used by ESMIFProcessor).
PDB_DIR = "pdb_files"
# ESM-IF1 checkpoint name. NOTE: ESMIFProcessor calls
# esm.pretrained.esm_if1_gvp4_t16_142M_UR50() directly and does not read this constant.
MODEL_NAME = "esm_if1_gvp4_t16_142M_UR50"

Mounted at /content/drive


In [None]:
#@title Likelihood Predictor

class ESMIFProcessor:
    """
    Score protein sequences against AlphaFold structures with ESM-IF1.

    Structures are downloaded from the AlphaFold DB into ``PDB_DIR``; proteins
    come from a CSV with ``Entry`` and ``sequence`` columns. Results are written
    to ``output_file``.
    """

    # Column order of the per-protein results table.
    RESULT_COLUMNS = [
        'Entry', 'sequence_length', 'valid_positions', 'percent_valid',
        'total_score', 'valid_pos_score',
        'mean_pos_score', 'min_pos_score', 'max_pos_score', 'std_pos_score',
        'has_missing_coords', 'num_missing_coords'
    ]

    def __init__(self, output_file: str = "esmif_results.csv", batch_processing: bool = True):
        self.output_file = output_file
        self.batch_processing = batch_processing
        # Use GPU if available
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        os.makedirs(PDB_DIR, exist_ok=True)
        self.model, self.alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
        self.model = self.model.eval().to(self.device)

    def download_pdb(self, accession: str) -> Tuple[Optional[str], bool]:
        """
        Fetch the AlphaFold DB model (v4) for ``accession`` into ``PDB_DIR``.

        Returns:
            ``(path, True)`` on success or cache hit, ``(None, False)`` otherwise.
        """
        pdb_path = os.path.join(PDB_DIR, f"AF-{accession}-F1-model_v4.pdb")
        url = f"https://alphafold.ebi.ac.uk/files/AF-{accession}-F1-model_v4.pdb"

        if os.path.exists(pdb_path):
            return pdb_path, True

        try:
            response = requests.get(url, timeout=10)
            if response.status_code == 200:
                with open(pdb_path, 'wb') as f:
                    f.write(response.content)
                return pdb_path, True
            else:
                logging.warning(f"Failed to download {accession}")
        except requests.RequestException as e:
            logging.error(f"Request failed for {accession}: {e}")
        except OSError as e:
            # A failed write should skip this protein, not abort the batch.
            logging.error(f"Could not save structure for {accession}: {e}")

        return None, False

    def score_protein(self, sequence: str, pdb_path: str, chain: str = "A") -> Optional[dict]:
        """
        Score ``sequence`` on the backbone of ``pdb_path`` (chain ``chain``).

        ``total_score`` / ``valid_pos_score`` are the two values returned by
        ``esm.inverse_folding.util.score_sequence`` (average log-likelihood over
        all positions / over positions with finite coordinates). The ``*_pos_score``
        statistics summarise the per-position ``loss`` from ``get_sequence_loss`` at
        positions with coordinates. NOTE(review): presumably a negative
        log-likelihood, i.e. lower is better and of opposite sign to ``total_score``.

        Returns:
            A dict of scores, or None if the lengths disagree or scoring fails.
        """
        try:
            coords, native_seq = esm.inverse_folding.util.load_coords(pdb_path, chain)

            # A mismatch otherwise surfaces later as an opaque numpy broadcast error.
            if len(coords) != len(sequence):
                logging.warning(
                    f"Skipping {pdb_path}: sequence length {len(sequence)} does not match "
                    f"structure length {len(coords)} (chain {chain})"
                )
                return None

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                loss, target_padding_mask = esm.inverse_folding.util.get_sequence_loss(
                    self.model, self.alphabet, coords, sequence
                )
                ll_fullseq, ll_withcoord = esm.inverse_folding.util.score_sequence(
                    self.model, self.alphabet, coords, sequence
                )

            # Positions whose backbone coordinates are all finite.
            coord_mask = np.all(np.isfinite(coords), axis=(-1, -2))
            valid_positions = np.sum(coord_mask)
            valid_pos_losses = loss[coord_mask]

            return {
                'sequence_length': len(sequence),
                'valid_positions': int(valid_positions),
                'percent_valid': float(valid_positions/len(sequence) * 100),

                'total_score': float(ll_fullseq),
                'valid_pos_score': float(ll_withcoord),

                'mean_pos_score': float(np.mean(valid_pos_losses)),
                'min_pos_score': float(np.min(valid_pos_losses)),
                'max_pos_score': float(np.max(valid_pos_losses)),
                'std_pos_score': float(np.std(valid_pos_losses)),

                'has_missing_coords': bool(valid_positions < len(sequence)),
                'num_missing_coords': int(len(sequence) - valid_positions)
            }

        except Exception as e:
            logging.error(f"Error scoring {pdb_path}: {e}")
            return None

    def process_single_protein(self, accession: str, sequence: str, chain: str = "A") -> Optional[dict]:
        """Download and score one protein; returns its score dict (with ``Entry`` first) or None."""
        pdb_path, success = self.download_pdb(accession)
        if not success:
            return None
        score_dict = self.score_protein(sequence, pdb_path, chain)
        if score_dict is None:
            return None
        return {'Entry': accession, **score_dict}

    def _score_rows(self, df_subset: pd.DataFrame, start_idx: int, end_idx: int, chain: str = "A") -> pd.DataFrame:
        """Score every row of ``df_subset``; always returns the ``RESULT_COLUMNS`` schema."""
        results = []
        for _, row in tqdm(df_subset.iterrows(), total=len(df_subset),
                           desc=f"Processing proteins {start_idx} to {end_idx-1}"):
            score_dict = self.process_single_protein(row['Entry'], row['sequence'], chain)
            if score_dict is not None:
                results.append(score_dict)

        # columns= keeps the schema even when every protein in the range failed.
        return pd.DataFrame(results, columns=self.RESULT_COLUMNS)

    def process_proteins(self, csv_file: str, start_idx: int = 0, end_idx: Optional[int] = None, chain: str = "A") -> pd.DataFrame:
        """Process rows ``[start_idx, end_idx)`` of the input CSV file."""
        df = pd.read_csv(csv_file)
        if end_idx is None:
            end_idx = len(df)
        return self._score_rows(df.iloc[start_idx:end_idx], start_idx, end_idx, chain)

    def process_all(self, csv_file: str, start_row: int = 0):
        """Process all proteins from ``start_row`` on, in batches or in one pass."""
        if self.batch_processing:
            self.process_in_batches(csv_file, start_row=start_row)
            return

        results_df = self.process_proteins(csv_file, start_idx=start_row)
        # Persist results before reporting where they were saved.
        results_df.to_csv(self.output_file, index=False)

        print("\nProcessing Summary:")
        print(f"Total proteins processed: {len(results_df)}")
        print(f"Mean total score: {results_df['total_score'].mean():.3f}")
        print(f"Mean valid position score: {results_df['valid_pos_score'].mean():.3f}")
        print(f"Average % valid positions: {results_df['percent_valid'].mean():.1f}%")

        print(f"\nResults saved to {self.output_file}")

    @staticmethod
    def _num_batches(total: int, start_row: int, batch_size: int) -> int:
        """Number of ``batch_size`` batches needed for rows ``start_row .. total-1`` (ceiling)."""
        remaining = max(0, total - start_row)
        return (remaining + batch_size - 1) // batch_size

    def process_in_batches(self, csv_file: str, batch_size: int = 50, start_row: int = 0, chain: str = "A"):
        """
        Process proteins in batches from ``start_row``, rewriting ``output_file``
        with all results so far after each batch.

        Raises:
            ValueError: if ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        df = pd.read_csv(csv_file)
        total_proteins = len(df)
        # Ceiling division: no trailing empty batch when rows divide evenly.
        num_batches = self._num_batches(total_proteins, start_row, batch_size)

        all_results = []
        for batch_idx in range(num_batches):
            start_idx = start_row + batch_idx * batch_size
            end_idx = min(start_idx + batch_size, total_proteins)

            print(f"\nProcessing batch {batch_idx + 1}/{num_batches}")

            # Slice the already-loaded table instead of re-reading the CSV per batch.
            results_df = self._score_rows(df.iloc[start_idx:end_idx], start_idx, end_idx, chain)
            all_results.append(results_df)

            # Save intermediate results
            batch_results = pd.concat(all_results)
            batch_results.to_csv(self.output_file, index=False)

            print(f"Processed {len(results_df)} proteins in current batch")
            print(f"Total proteins processed so far: {len(batch_results)}")



In [None]:
# @title Run
if __name__ == "__main__":
    processor = ESMIFProcessor(batch_processing=True)

    # Option to choose between processing CSV or single sequence
    process_mode = input("Enter mode ('csv' or 'single'): ").strip().lower()

    if process_mode == 'csv':
        # Existing CSV processing logic
        processor.process_all("/content/ACID_output_with_properties.csv", start_row=0)
    elif process_mode == 'single':
        accession = input("Enter AlphaFold PDB accession code (e.g., P0DTD1): ").strip()
        sequence = input("Enter the protein sequence: ").strip().upper()

        if not accession or not sequence:
            print("PDB accession code and sequence are required for single processing.")
        else:
            print(f"Processing single protein: {accession}")
            # Download the structure, then score the sequence on it.
            pdb_path, downloaded = processor.download_pdb(accession)
            score_dict = processor.score_protein(sequence, pdb_path) if downloaded else None

            if score_dict:
                score_dict = {'Entry': accession, **score_dict}
                print("\nESMIF Results:")
                for key, value in score_dict.items():
                    print(f"{key}: {value}")
            else:
                print(f"Failed to process protein {accession}.")
    else:
        print("Invalid mode selected. Please enter 'csv' or 'single'.")

Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm_if1_gvp4_t16_142M_UR50.pt" to /root/.cache/torch/hub/checkpoints/esm_if1_gvp4_t16_142M_UR50.pt


Enter mode ('csv' or 'single'): csv

Processing batch 1/3


Processing proteins 0 to 49:  54%|█████▍    | 27/50 [05:17<06:16, 16.35s/it]ERROR:root:Error scoring pdb_files/AF-P0CP74-F1-model_v4.pdb: operands could not be broadcast together with shapes (634,) (637,) 
Processing proteins 0 to 49: 100%|██████████| 50/50 [09:58<00:00, 11.96s/it]


Processed 49 proteins in current batch
Total proteins processed so far: 49

Processing batch 2/3


Processing proteins 50 to 99: 100%|██████████| 50/50 [09:07<00:00, 10.94s/it]


Processed 50 proteins in current batch
Total proteins processed so far: 99

Processing batch 3/3


Processing proteins 100 to 101: 100%|██████████| 2/2 [00:11<00:00,  5.88s/it]

Processed 2 proteins in current batch
Total proteins processed so far: 101





## MIF and MIF-ST

In [None]:
#@title Setup

!pip install git+https://github.com/microsoft/protein-sequence-models.git
!git clone https://github.com/microsoft/protein-sequence-models.git
!pip install -q torch biotite

import os
import torch
import pandas as pd
import numpy as np
import requests
import glob
from tqdm import tqdm
from sequence_models.pretrained import load_model_and_alphabet
from sequence_models.pdb_utils import parse_PDB, process_coords
from sequence_models.collaters import StructureCollater
from google.colab import files

Collecting git+https://github.com/microsoft/protein-sequence-models.git
  Cloning https://github.com/microsoft/protein-sequence-models.git to /tmp/pip-req-build-rl9r1wrk
  Running command git clone --filter=blob:none --quiet https://github.com/microsoft/protein-sequence-models.git /tmp/pip-req-build-rl9r1wrk
  Resolved https://github.com/microsoft/protein-sequence-models.git to commit af695772c4a1c056d930c95ec7e6428aa042f5cd
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: sequence-models
  Building wheel for sequence-models (setup.py) ... [?25l[?25hdone
  Created wheel for sequence-models: filename=sequence_models-1.8.0-py3-none-any.whl size=59506 sha256=c8e8dc39ffda1c309dedb3d2205be209d159bf308d849d7c654b5dea8482179a
  Stored in directory: /tmp/pip-ephem-wheel-cache-09ixoi1_/wheels/47/f9/4c/7668455fae11af20d53510b1e3e971b9e01b49ca3bd7f868a0
Successfully built sequence-models
Installing collected packages: sequence-models
Successfully inst

In [None]:
#@title Likelihood Predictor


def load_model(model_name="mif"):
    """Return ``(model, collater)`` for a pretrained MIF / MIF-ST checkpoint.

    The model is switched to evaluation mode before it is handed back.
    """
    loaded = load_model_and_alphabet(model_name)
    net, structure_collater = loaded
    net.eval()
    return net, structure_collater

def get_pdb(accession, output_dir="pdb_files", timeout=60):
    """Fetch the AlphaFold model PDB for ``accession`` and cache it locally.

    Parameters
    ----------
    accession : str
        Accession used to build the AlphaFold DB URL
        (``AF-<accession>-F1-model_v4.pdb``).
    output_dir : str
        Local cache directory; created if it does not exist.
    timeout : float
        Seconds to wait for the HTTP request. Without a timeout a stalled
        connection would block the whole batch indefinitely.

    Returns
    -------
    tuple
        ``(pdb_path, True)`` when the file is available locally (cached or
        freshly downloaded), ``(None, False)`` when it could not be obtained.
    """
    pdb_path = os.path.join(output_dir, f"{accession}.pdb")
    base_url = f"https://alphafold.ebi.ac.uk/files/AF-{accession}-F1-model_v4.pdb"
    os.makedirs(output_dir, exist_ok=True)

    # Reuse a cached copy, but never an empty file left over from a failed write.
    if os.path.isfile(pdb_path) and os.path.getsize(pdb_path) > 0:
        return pdb_path, True

    try:
        response = requests.get(base_url, timeout=timeout)
    except requests.RequestException as e:
        print(f"Failed to download {accession}: {e}")
        return None, False

    # Report non-200 responses instead of failing silently.
    if response.status_code != 200:
        print(f"Failed to download {accession}: HTTP {response.status_code}")
        return None, False

    # Write to a temporary file and rename it into place, so an interrupted
    # write cannot leave a truncated PDB that later calls would accept as cached.
    tmp_path = pdb_path + ".part"
    try:
        with open(tmp_path, 'wb') as f:
            f.write(response.content)
        os.replace(tmp_path, pdb_path)
    except OSError as e:
        print(f"Failed to save {accession}: {e}")
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        return None, False
    return pdb_path, True

# Run MIF on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def calculate_mif_log_probs(model, collater, sequence, pdb_path):
    """Return the mean per-residue log-probability MIF assigns to a structure's sequence.

    Parameters
    ----------
    model : torch.nn.Module
        MIF / MIF-ST model, already moved to ``device``.
    collater : callable
        Collater returned together with the model by ``load_model_and_alphabet``.
    sequence : str
        Expected protein sequence. NOTE(review): currently unused; the residue
        sequence parsed from ``pdb_path`` is what gets scored. For AlphaFold
        models the two should agree, but this is not checked here.
    pdb_path : str
        PDB file whose backbone (N, CA, C) coordinates condition the model.

    Returns
    -------
    float
        Average over residues of the log-probability of the observed residue.
        NOTE(review): residues are not masked, so each position's prediction
        can see its own identity; a masked pseudo-likelihood would need one
        forward pass per position.
    """
    coords, wt, _ = parse_PDB(pdb_path)
    coords_dict = {'N': coords[:, 0], 'CA': coords[:, 1], 'C': coords[:, 2]}
    dist, omega, theta, phi = process_coords(coords_dict)

    batch = [[wt, torch.tensor(dist, dtype=torch.float, device=device),
              torch.tensor(omega, dtype=torch.float, device=device),
              torch.tensor(theta, dtype=torch.float, device=device),
              torch.tensor(phi, dtype=torch.float, device=device)]]

    src, nodes, edges, connections, edge_mask = collater(batch)
    src, nodes, edges, connections, edge_mask = (
        src.to(device), nodes.to(device), edges.to(device), connections.to(device), edge_mask.to(device)
    )

    with torch.no_grad():
        logits = model(src, nodes, edges, connections, edge_mask, result='logits')
        # Raw logits averaged over the whole vocabulary are not a likelihood:
        # normalise over the vocabulary, then keep only the observed residue.
        log_probs = torch.log_softmax(logits, dim=-1)
        # assumes logits are (batch, length, vocab) aligned with src (batch, length) — TODO confirm
        token_log_probs = log_probs.gather(-1, src.long().unsqueeze(-1)).squeeze(-1)
        avg_logp = token_log_probs.mean().item()

    return avg_logp

def process_proteins_in_chunks(csv_file, model_name="mif", chunk_size=50, output_dir="results"):
    """Score every protein in a CSV with MIF, saving results after each chunk.

    Parameters
    ----------
    csv_file : str
        CSV with an ``Entry`` column (AlphaFold accession) and a ``sequence`` column.
    model_name : str
        Checkpoint name passed to ``load_model_and_alphabet`` (e.g. ``"mif"``).
    chunk_size : int
        Rows scored before a per-chunk ``.csv`` / ``.pt`` file is written.
    output_dir : str
        Directory for the per-chunk and final result files.

    Returns
    -------
    pandas.DataFrame
        Combined ``Entry`` / ``MIF_Likelihood`` table, also written to
        ``mif_likelihoods_final.csv`` and offered for download.

    Proteins whose structure cannot be downloaded or scored are skipped and
    reported at the end rather than aborting the run.
    """
    model, collater = load_model_and_alphabet(model_name)
    model.to(device)
    # nn.Module starts in training mode; eval() keeps scores deterministic
    # (disables dropout, if the model has any).
    model.eval()
    os.makedirs(output_dir, exist_ok=True)

    proteins_data = pd.read_csv(csv_file)
    total_proteins = len(proteins_data)
    results = []
    failed = []

    for chunk_idx, start in enumerate(range(0, total_proteins, chunk_size), start=1):
        chunk = proteins_data.iloc[start:start + chunk_size]
        chunk_results = []

        for _, row in tqdm(chunk.iterrows(), total=len(chunk),
                           desc=f"Processing chunk {chunk_idx}", unit="protein"):
            accession = row['Entry']
            sequence = row['sequence']
            pdb_path, success = get_pdb(accession)
            if not success:
                failed.append(accession)
                continue
            try:
                log_likelihood = calculate_mif_log_probs(model, collater, sequence, pdb_path)
            except Exception as e:  # one bad structure (parse error, OOM) must not abort the batch
                print(f"Failed to score {accession}: {e}")
                failed.append(accession)
                continue
            chunk_results.append({"Entry": accession, "MIF_Likelihood": log_likelihood})

        # Explicit columns so even an all-failed chunk writes a proper header.
        chunk_df = pd.DataFrame(chunk_results, columns=["Entry", "MIF_Likelihood"])
        chunk_file = os.path.join(output_dir, f"mif_likelihoods_chunk_{chunk_idx}.csv")
        chunk_df.to_csv(chunk_file, index=False)
        torch.save(chunk_df, chunk_file.replace(".csv", ".pt"))

        results.extend(chunk_results)

    if failed:
        print(f"Skipped {len(failed)} proteins: {', '.join(map(str, failed))}")

    final_results_df = pd.DataFrame(results, columns=["Entry", "MIF_Likelihood"])
    final_results_file = os.path.join(output_dir, "mif_likelihoods_final.csv")
    final_results_df.to_csv(final_results_file, index=False)
    torch.save(final_results_df, final_results_file.replace(".csv", ".pt"))
    files.download(final_results_file)
    return final_results_df



In [None]:
#@title Run

csv_path = "/content/ACID_output_with_properties.csv"


def process_single_mif_protein(accession, sequence, model_name="mif"):
    """Score a single AlphaFold structure with MIF.

    Downloads (or reuses) the AlphaFold PDB for ``accession`` and returns the
    value from ``calculate_mif_log_probs``, or ``None`` if the structure could
    not be fetched or scored.
    """
    pdb_path, success = get_pdb(accession)
    if not success:
        return None
    model, collater = load_model_and_alphabet(model_name)
    model.to(device)
    model.eval()
    try:
        return calculate_mif_log_probs(model, collater, sequence, pdb_path)
    except Exception as e:  # report here; the caller prints its own failure message
        print(f"Error scoring {accession}: {e}")
        return None


# Option to choose between processing CSV or single sequence
process_mode = input("Enter mode ('csv' or 'single'): ").strip().lower()

if process_mode == 'csv':
    # Existing CSV processing logic
    process_proteins_in_chunks(csv_path, model_name="mif", chunk_size=500)
elif process_mode == 'single':
    accession = input("Enter AlphaFold PDB accession code (e.g., P0DTD1): ").strip()
    sequence = input("Enter the protein sequence: ").strip().upper()

    if not accession or not sequence:
        print("PDB accession code and sequence are required for single processing.")
    else:
        print(f"Processing single protein: {accession}")
        mif_likelihood = process_single_mif_protein(accession, sequence, model_name="mif")

        if mif_likelihood is not None:
            print("\nMIF Results:")
            print(f"Entry: {accession}")
            print(f"MIF_Likelihood: {mif_likelihood}")
        else:
            print(f"Failed to process protein {accession}.")
else:
    print("Invalid mode selected. Please enter 'csv' or 'single'.")

Enter mode ('csv' or 'single'): csv


Downloading: "https://zenodo.org/record/6573779/files/mif.pt?download=1" to /root/.cache/torch/hub/checkpoints/mif.pt
Processing chunk 1: 100%|██████████| 102/102 [01:31<00:00,  1.11protein/s]


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

# ESM2 pseudo-log-likelihood (PPPL)


In [None]:
# ESM2 Protein Sequence Likelihood Calculator for Google Colab
# =====================================================

# Install required packages
!pip install -q fair-esm

import torch
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
import os
from google.colab import files
import matplotlib.pyplot as plt
import time

# Report the PyTorch build and whether a CUDA GPU is visible (and which one).
print(f"PyTorch version: {torch.__version__}")
_cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {_cuda_ok}")
if _cuda_ok:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/93.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25hPyTorch version: 2.8.0+cu126
CUDA available: True
CUDA device: Tesla T4


In [None]:


# Function to calculate ESM2 likelihoods
def calculate_esm2_likelihoods(sequences, model_size="650M", batch_size=1,
                               max_length=1022, show_progress=True,
                               optimize_memory=True):
    """
    Calculate ESM2 pseudo-log-likelihoods for protein sequences in Google Colab.

    Each residue is masked in turn (one forward pass per residue) and the
    log-probability of the true residue at that position is recorded; a
    sequence's score is the mean of these values.

    Parameters:
    -----------
    sequences : list
        Protein sequences. Strings are stripped of surrounding whitespace and
        upper-cased (the ESM tokenizer would otherwise turn whitespace or
        lower-case letters into misaligned/unknown tokens). Non-string entries
        (e.g. NaN from a CSV) are scored as NaN.
    model_size : str, default="650M"
        Size of ESM2 model to use ("650M", "3B", or "15B")
    batch_size : int, default=1
        Number of sequences taken per outer-loop step (each sequence is still
        scored on its own).
    max_length : int, default=1022
        Maximum sequence length to process; longer sequences are truncated.
    show_progress : bool, default=True
        Whether to show progress bars
    optimize_memory : bool, default=True
        Periodically empty the CUDA cache while scoring.

    Returns:
    --------
    list
        One PPPL score per input sequence; NaN for empty sequences, sequences
        containing '*', and non-string entries.

    Raises:
    -------
    ValueError
        If ``model_size`` is not one of the supported sizes.
    """
    model_map = {
        "650M": "esm2_t33_650M_UR50D",
        "3B": "esm2_t36_3B_UR50D",
        "15B": "esm2_t48_15B_UR50D"
    }
    model_name = model_map.get(model_size)
    if not model_name:
        raise ValueError(f"Unknown model size: {model_size}. Choose from {list(model_map.keys())}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"Loading ESM2-{model_size} model on {device}...")
    model, alphabet = torch.hub.load("facebookresearch/esm:main", model_name)
    model = model.to(device)
    model.eval()
    batch_converter = alphabet.get_batch_converter()

    # Normalise, then trim. Non-strings become "" and are reported as invalid below.
    sequences = [
        seq.strip().upper()[:max_length] if isinstance(seq, str) else ""
        for seq in sequences
    ]

    results = []
    invalid_indices = []

    starts = range(0, len(sequences), batch_size)
    progress_bar = tqdm(starts, desc="Processing sequences") if show_progress else starts

    for i in progress_bar:
        batch_sequences = sequences[i:i + batch_size]

        for seq_idx, sequence in enumerate(batch_sequences):
            # Skip sequences with stop codons or that are empty
            if '*' in sequence or len(sequence) == 0:
                results.append(float('nan'))
                invalid_indices.append(i + seq_idx)
                continue

            # Tokenise once; token 0 is the BOS token, so residue k sits at index k + 1.
            _, _, base_tokens = batch_converter([(f"seq_{seq_idx}", sequence)])
            base_tokens = base_tokens.to(device)

            log_probs = []
            with torch.no_grad():
                for pos in range(1, len(sequence) + 1):
                    batch_tokens = base_tokens.clone()
                    batch_tokens[0, pos] = alphabet.mask_idx

                    # Only logits are used, so no hidden representations are requested.
                    outputs = model(batch_tokens, repr_layers=[], return_contacts=False)
                    # log-softmax at the masked position only: same value, far less work.
                    position_log_probs = torch.log_softmax(outputs["logits"][0, pos], dim=-1)

                    true_aa_idx = alphabet.get_idx(sequence[pos - 1])
                    log_probs.append(position_log_probs[true_aa_idx].item())

                    if optimize_memory and torch.cuda.is_available() and pos % 50 == 0:
                        torch.cuda.empty_cache()

            results.append(sum(log_probs) / len(log_probs) if log_probs else float('nan'))

        if optimize_memory and torch.cuda.is_available():
            torch.cuda.empty_cache()

    if invalid_indices and show_progress:
        print(f"Warning: {len(invalid_indices)} sequences were skipped due to invalid characters or empty sequences.")

    return results

# Function to process a CSV file
def process_csv_file(input_file=None, output_file=None, model_size="650M",
                     max_sequences=None, sequence_column="sequence"):
    """
    Process protein sequences from a CSV file and calculate their ESM2 likelihoods.
    For Google Colab environment with file upload/download support.

    Parameters:
    -----------
    input_file : str, default=None
        Path to input CSV file. If None, prompts for upload.
    output_file : str, default=None
        Path to output CSV file. If None, uses input filename with _ESM2_results suffix.
    model_size : str, default="650M"
        Size of ESM2 model to use
    max_sequences : int, default=None
        Maximum number of sequences to process (the first rows of the file).
    sequence_column : str, default="sequence"
        Name of the column containing protein sequences

    Returns:
    --------
    pandas.DataFrame or None
        The processed rows with an added ``ESM2_<model_size>_pppl`` column, or
        None if no file was uploaded.

    Raises:
    -------
    ValueError
        If ``sequence_column`` is not present in the CSV.
    """
    if input_file is None:
        print("Please upload a CSV file containing protein sequences:")
        uploaded = files.upload()
        if not uploaded:
            print("No file uploaded. Exiting.")
            return
        input_file = list(uploaded.keys())[0]

    if output_file is None:
        base_name = os.path.splitext(input_file)[0]
        output_file = f"{base_name}_ESM2_results.csv"

    print(f"Reading {input_file}...")
    df = pd.read_csv(input_file)

    if sequence_column not in df.columns:
        print(f"Available columns: {', '.join(df.columns)}")
        raise ValueError(f"Input CSV must contain a '{sequence_column}' column")

    original_count = len(df)
    if max_sequences is not None and max_sequences < original_count:
        # .copy() so the score column is added to an independent frame rather
        # than to a slice of the full CSV (avoids SettingWithCopyWarning).
        df = df.head(max_sequences).copy()
        print(f"Processing {max_sequences} sequences out of {original_count} total.")
    else:
        print(f"Processing all {original_count} sequences.")

    start_time = time.time()

    # Empty CSV cells are read as NaN floats; pass "" instead so every entry
    # given to the scorer is a string (empty sequences are scored as NaN).
    sequences = df[sequence_column].fillna("").astype(str).tolist()

    likelihoods = calculate_esm2_likelihoods(sequences, model_size=model_size)
    # Stop the clock before saving/downloading so the timing reflects computation.
    elapsed_time = time.time() - start_time

    df[f"ESM2_{model_size}_pppl"] = likelihoods

    df.to_csv(output_file, index=False)
    print(f"Results saved to {output_file}")

    files.download(output_file)

    print(f"Processing completed in {elapsed_time:.2f} seconds.")

    valid_scores = [score for score in likelihoods if not np.isnan(score)]
    if valid_scores:
        plt.figure(figsize=(10, 6))
        plt.hist(valid_scores, bins=30, alpha=0.7)
        plt.title(f"Distribution of ESM2-{model_size} PPPL Scores")
        plt.xlabel("PPPL Score")
        plt.ylabel("Frequency")
        plt.grid(alpha=0.3)
        plt.show()

    return df

# Function to process a single sequence
def process_single_sequence(sequence=None, model_size="650M"):
    """
    Score one protein sequence with ESM2 and print a short report.

    Parameters:
    -----------
    sequence : str, default=None
        Sequence to score; read interactively with ``input()`` when omitted.
        Surrounding whitespace is stripped and letters are upper-cased.
    model_size : str, default="650M"
        ESM2 variant forwarded to ``calculate_esm2_likelihoods``.

    Returns:
    --------
    float
        The PPPL score computed for the sequence.
    """
    raw = input("Enter a protein sequence: ") if sequence is None else sequence
    sequence = raw.strip().upper()

    # Only warn about unexpected letters; scoring still proceeds.
    allowed = frozenset("ACDEFGHIKLMNPQRSTVWY")
    unexpected = {aa for aa in sequence if aa not in allowed}
    if unexpected:
        print(f"Warning: Sequence contains invalid amino acids: {', '.join(unexpected)}")

    n_residues = len(sequence)
    print(f"Calculating ESM2-{model_size} likelihood for sequence (length {n_residues})...")

    t0 = time.time()
    scores = calculate_esm2_likelihoods([sequence], model_size=model_size, show_progress=False)
    elapsed_time = time.time() - t0
    pppl = scores[0]

    print("\nResults:")
    print(f"Sequence: {sequence[:20]}... (length {n_residues})")
    print(f"ESM2-{model_size} PPPL Score: {pppl:.6f}")
    print(f"Calculation Time: {elapsed_time:.2f} seconds")

    return pppl

# Function to validate implementation
def validation_test(sequences, expected_scores, model_size="650M", tolerance=0.01):
    """
    Validate the implementation by comparing with expected scores.

    Parameters:
    -----------
    sequences : list
        List of protein sequences to test
    expected_scores : list
        List of expected PPPL scores, one per sequence
    model_size : str, default="650M"
        Size of ESM2 model to use
    tolerance : float, default=0.01
        A score matches when its absolute difference from the expected value
        is strictly below this.

    Returns:
    --------
    bool
        True if every sequence matched within tolerance.

    Raises:
    -------
    ValueError
        If ``sequences`` and ``expected_scores`` differ in length (zip would
        otherwise silently ignore the extras).
    """
    # Materialise once: the inputs are iterated twice (scoring and reporting).
    sequences = list(sequences)
    expected_scores = list(expected_scores)
    if len(sequences) != len(expected_scores):
        raise ValueError(
            f"Got {len(sequences)} sequences but {len(expected_scores)} expected scores"
        )

    calculated_scores = calculate_esm2_likelihoods(sequences, model_size=model_size, show_progress=True)

    print("\nValidation Results:")
    print("------------------")
    all_match = True

    for i, (seq, expected, calculated) in enumerate(zip(sequences, expected_scores, calculated_scores)):
        difference = abs(expected - calculated)
        # A NaN score (unscorable sequence) compares False and counts as a mismatch.
        matches = difference < tolerance

        print(f"Sequence {i+1} (length {len(seq)}):")
        print(f"  Calculated PPPL: {calculated:.6f}")
        print(f"  Expected PPPL:   {expected:.6f}")
        print(f"  Difference:      {difference:.6f}")
        print(f"  Within tolerance: {'✓' if matches else '✗'}")
        print()

        if not matches:
            all_match = False

    if all_match:
        print("✅ All scores match within tolerance!")
    else:
        print("⚠️ Some scores do not match the expected values.")

    return all_match

In [None]:
# Example 1: Process a single sequence
sequence = "MLIVINYKTYNESIGNRGLEIAKIAEKVSEESGITIGVAPQFVDLRMIVENVNIPVYAQHIDNINPGSHTGHILAEAIKDCGCKGTLINHSEKRMLLADIEAVINKCKNLGLETIVCTNNINTSKAVAALSPDYIAVEPPELIGTGIPVSKANPEVVEGTVRAVKEINKDVKVLCGAGISKGEDVKAALDLGAEGVLLASGVVKAKNVEEAIRELIKF"
pppl = process_single_sequence(sequence, model_size="650M")

# The examples below are disabled with real comments (not a bare string
# literal, which the notebook would echo as the cell's output).

# Example 2: Process sequences from a CSV file (uncomment to run).
# The CSV must contain a 'sequence' column.
# df_results = process_csv_file(input_file='/content/Archaea_only.csv', model_size="650M", max_sequences=300)

# Example 3: Validation with known scores (uncomment to run).
# NOTE(review): ESM2-15B has ~15e9 parameters (~60 GB in fp32); it will not
# fit on a 16 GB T4 — use a larger GPU or a smaller model_size.
# test_sequences = [
#     "MENDKGQLVELYVPRKCSATNRIIKAKDHASVQISIAKVDEDGRAIAGENITYALSGYVRGRGEADDSLNRLAQQDGLLKNVWSYSR",
#     "MARGPKKHLKRLAAPSHWMLDKLSGTYAPRPSAGPHKLRESLPLVVFLRNRLKYALNGREVKAIMMQQHVQVDGKVRTDTTYPAGFMDVITLEATNEHFRLVYDVKGKFAVHRISAEEAAYKLGKVKKVQLGKKGVPYVVTHDGRTIRYPDPLIRANDTVKIDLATGKIDDFIKFDTGRLVMVTGGRNLGRVGVIVHREKHEGGFDLVHIKDALENTFVTRLSNVFVIGTEAGKPWVSLPKGKGIKLSISEERDRRRAQQGL"
# ]
# expected_scores = [-0.466516291, -0.992957033]  # Replace with your expected scores
# validation_test(test_sequences, expected_scores, model_size="15B")

Calculating ESM2-650M likelihood for sequence (length 218)...
Loading ESM2-650M model on cuda...
Downloading: "https://github.com/facebookresearch/esm/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t33_650M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D-contact-regression.pt

Results:
Sequence: MLIVINYKTYNESIGNRGLE... (length 218)
ESM2-650M PPPL Score: -1.053744
Calculation Time: 43.16 seconds


'\n\n# Example 2: Process sequences from a CSV file\n# Uncomment and run this cell to process sequences from a CSV file\n\n# Upload a CSV file with a \'sequence\' column\n\ndf_results = process_csv_file(input_file=\'/content/Archaea_only.csv\',model_size="650M", max_sequences=300)\n\n\n# Example 3: Validation with known scores\n# If you have sequences with known PPPL scores, you can validate the implementation\n\n# Replace with your sequences and their known scores\ntest_sequences = [\n    "MENDKGQLVELYVPRKCSATNRIIKAKDHASVQISIAKVDEDGRAIAGENITYALSGYVRGRGEADDSLNRLAQQDGLLKNVWSYSR",\n    "MARGPKKHLKRLAAPSHWMLDKLSGTYAPRPSAGPHKLRESLPLVVFLRNRLKYALNGREVKAIMMQQHVQVDGKVRTDTTYPAGFMDVITLEATNEHFRLVYDVKGKFAVHRISAEEAAYKLGKVKKVQLGKKGVPYVVTHDGRTIRYPDPLIRANDTVKIDLATGKIDDFIKFDTGRLVMVTGGRNLGRVGVIVHREKHEGGFDLVHIKDALENTFVTRLSNVFVIGTEAGKPWVSLPKGKGIKLSISEERDRRRAQQGL"\n]\nexpected_scores = [-0.466516291, -0.992957033]  # Replace with your expected scores\n\nvalidation_test(test_sequences, expected_scores, mode

In [None]:
# ESM2 Protein Sequence Likelihood Calculator for Google Colab
# Optimized for ESM2-15B model with limited memory

# Install required packages
!pip install -q fair-esm

import torch
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
import os
from google.colab import files
import matplotlib.pyplot as plt
import time
import gc  # For aggressive garbage collection

# Report the PyTorch build, GPU visibility and, when a GPU is present, its name
# and total memory (printed under the "Available GPU memory" label).
print(f"PyTorch version: {torch.__version__}")
_has_cuda = torch.cuda.is_available()
print(f"CUDA available: {_has_cuda}")
if _has_cuda:
    _gpu_name = torch.cuda.get_device_name(0)
    _gpu_total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"CUDA device: {_gpu_name}")
    print(f"Available GPU memory: {_gpu_total_gb:.2f} GB")

# Function to calculate ESM2 likelihoods with extreme memory optimization
def calculate_esm2_likelihoods_memory_efficient(sequences, model_size="650M",
                                               max_length=1022, show_progress=True):
    """
    Calculate ESM2 pseudo-log-likelihoods for protein sequences with extreme memory optimization.

    Every residue is masked in turn (one forward pass per residue) and the mean
    log-probability of the true residues is returned. Tensors are released and
    the CUDA cache emptied after every pass, and the model is deleted before
    returning.

    Parameters:
    -----------
    sequences : list
        Protein sequences. Strings are stripped of surrounding whitespace and
        upper-cased (whitespace/lower-case text would tokenize into extra or
        unknown tokens and misalign positions). Non-strings score NaN.
    model_size : str, default="650M"
        Size of ESM2 model to use ("650M", "3B", or "15B")
    max_length : int, default=1022
        Maximum sequence length to process; longer sequences are truncated.
    show_progress : bool, default=True
        Whether to show progress bars

    Returns:
    --------
    list
        One PPPL score per sequence. NaN for non-strings, sequences containing
        '*' or shorter than 2 residues, and sequences hitting a non-OOM
        RuntimeError. On CUDA out-of-memory part-way through a sequence, the
        average over the positions scored so far is returned (a partial score,
        announced on stdout).

    Raises:
    -------
    ValueError
        If ``model_size`` is not one of the supported sizes.
    """
    model_map = {
        "650M": "esm2_t33_650M_UR50D",
        "3B": "esm2_t36_3B_UR50D",
        "15B": "esm2_t48_15B_UR50D"
    }
    model_name = model_map.get(model_size)
    if not model_name:
        raise ValueError(f"Unknown model size: {model_size}. Choose from {list(model_map.keys())}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    on_gpu = torch.cuda.is_available()

    if on_gpu:
        initial_memory = torch.cuda.memory_allocated() / 1e9
        print(f"Initial GPU memory usage: {initial_memory:.2f} GB")
        # Aggressive memory clearing before the weights are loaded
        torch.cuda.empty_cache()
        gc.collect()

    print(f"Loading ESM2-{model_size} model on {device}...")
    model, alphabet = torch.hub.load("facebookresearch/esm:main", model_name)
    model = model.to(device)
    model.eval()
    batch_converter = alphabet.get_batch_converter()

    if on_gpu:
        model_memory = (torch.cuda.memory_allocated() / 1e9) - initial_memory
        print(f"Model loaded. Using {model_memory:.2f} GB of GPU memory")
        print(f"Remaining GPU memory: {(torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()) / 1e9:.2f} GB")

    # Normalise, then trim. Non-strings become "" and are reported as invalid below.
    sequences = [
        seq.strip().upper()[:max_length] if isinstance(seq, str) else ""
        for seq in sequences
    ]

    results = []
    invalid_indices = []

    indices = range(len(sequences))
    progress_bar = tqdm(indices, desc="Processing sequences") if show_progress else indices

    for i in progress_bar:
        sequence = sequences[i]

        # Skip sequences with stop codons or that are too short
        if '*' in sequence or len(sequence) < 2:
            results.append(float('nan'))
            invalid_indices.append(i)
            continue

        log_probs = []

        try:
            # Tokenise once on the CPU; token 0 is BOS, so residue k sits at index k + 1.
            _, _, base_tokens = batch_converter([(f"seq_{i}", sequence)])

            for pos in range(1, len(sequence) + 1):
                batch_tokens = base_tokens.clone()
                batch_tokens[0, pos] = alphabet.mask_idx
                batch_tokens = batch_tokens.to(device)

                with torch.no_grad():
                    # Only logits are needed, so no hidden representations are kept.
                    outputs = model(batch_tokens, repr_layers=[], return_contacts=False)
                    position_log_probs = torch.log_softmax(outputs["logits"][0, pos], dim=-1)

                    true_aa_idx = alphabet.get_idx(sequence[pos - 1])
                    log_probs.append(position_log_probs[true_aa_idx].item())

                # Drop references so the tensors can be freed before the next pass.
                del outputs, position_log_probs, batch_tokens

                if on_gpu:
                    torch.cuda.empty_cache()
                    gc.collect()

                # For 15B model, show memory usage every 10 positions
                if model_size == "15B" and on_gpu and pos % 10 == 0:
                    current_memory = torch.cuda.memory_allocated() / 1e9
                    print(f"  Position {pos}/{len(sequence)}, Memory: {current_memory:.2f} GB")

            results.append(sum(log_probs) / len(log_probs) if log_probs else float('nan'))

        except RuntimeError as e:
            if "out of memory" in str(e):
                print(f"Out of memory at sequence {i}, position {len(log_probs) + 1}")
                # Try to recover and continue with next sequence
                if on_gpu:
                    torch.cuda.empty_cache()
                    gc.collect()

                # Record partial result if possible
                if log_probs:
                    partial_pppl = sum(log_probs) / len(log_probs)
                    print(f"  Partial PPPL from {len(log_probs)}/{len(sequence)} positions: {partial_pppl:.6f}")
                    results.append(partial_pppl)
                else:
                    results.append(float('nan'))
            else:
                print(f"Error processing sequence {i}: {e}")
                results.append(float('nan'))

    if on_gpu:
        final_memory = torch.cuda.memory_allocated() / 1e9
        print(f"Final GPU memory usage: {final_memory:.2f} GB")

    if invalid_indices and show_progress:
        print(f"Warning: {len(invalid_indices)} sequences were skipped due to invalid characters or empty sequences.")

    # Final cleanup
    del model
    if on_gpu:
        torch.cuda.empty_cache()
        gc.collect()

    return results

# Function to process a single sequence with extreme memory optimization
def process_single_sequence_memory_efficient(sequence=None, model_size="650M"):
    """
    Score one protein sequence with the memory-efficient ESM2 scorer and print a report.

    Parameters:
    -----------
    sequence : str, default=None
        Sequence to score; read with ``input()`` when omitted. Surrounding
        whitespace is stripped and letters are upper-cased.
    model_size : str, default="650M"
        ESM2 variant forwarded to ``calculate_esm2_likelihoods_memory_efficient``.

    Returns:
    --------
    float
        The PPPL score computed for the sequence.
    """
    text = sequence if sequence is not None else input("Enter a protein sequence: ")
    sequence = text.strip().upper()

    # Warn (but keep going) when the sequence has letters outside the 20 standard residues.
    standard = frozenset("ACDEFGHIKLMNPQRSTVWY")
    bad_letters = {ch for ch in sequence if ch not in standard}
    if bad_letters:
        print(f"Warning: Sequence contains invalid amino acids: {', '.join(bad_letters)}")

    length = len(sequence)
    print(f"Calculating ESM2-{model_size} likelihood for sequence (length {length})...")

    began = time.time()
    scored = calculate_esm2_likelihoods_memory_efficient([sequence], model_size=model_size, show_progress=False)
    elapsed_time = time.time() - began
    pppl = scored[0]

    print("\nResults:")
    print(f"Sequence: {sequence[:20]}... (length {length})")
    print(f"ESM2-{model_size} PPPL Score: {pppl:.6f}")
    print(f"Calculation Time: {elapsed_time:.2f} seconds")

    return pppl

# Function to validate implementation with extreme memory optimization
def validation_test_memory_efficient(sequences, expected_scores, model_size="650M", tolerance=0.01):
    """
    Validate the implementation by comparing with expected scores.
    Optimized for minimal memory usage: sequences are scored one call at a time.

    Parameters:
    -----------
    sequences : list
        List of protein sequences to test
    expected_scores : list
        List of expected PPPL scores, one per sequence
    model_size : str, default="650M"
        Size of ESM2 model to use
    tolerance : float, default=0.01
        A score matches when its absolute difference from the expected value
        is strictly below this.

    Returns:
    --------
    bool
        True if every sequence matched within tolerance.

    Raises:
    -------
    ValueError
        If ``sequences`` and ``expected_scores`` differ in length (zip would
        otherwise silently ignore the extras).
    """
    sequences = list(sequences)
    expected_scores = list(expected_scores)
    if len(sequences) != len(expected_scores):
        raise ValueError(
            f"Got {len(sequences)} sequences but {len(expected_scores)} expected scores"
        )

    print(f"Running validation with ESM2-{model_size} model for {len(sequences)} sequences...")

    calculated_scores = []

    for i, seq in enumerate(sequences):
        print(f"\nProcessing sequence {i+1}/{len(sequences)} (length: {len(seq)})")

        start_time = time.time()
        # One call per sequence: the scorer loads and deletes the model each
        # time, trading speed for lower peak memory.
        score = calculate_esm2_likelihoods_memory_efficient([seq], model_size=model_size, show_progress=False)[0]
        elapsed_time = time.time() - start_time

        calculated_scores.append(score)
        print(f"  PPPL Score: {score:.6f}")
        print(f"  Time taken: {elapsed_time:.2f} seconds")

    print("\nValidation Results:")
    print("------------------")
    all_match = True

    for i, (seq, expected, calculated) in enumerate(zip(sequences, expected_scores, calculated_scores)):
        difference = abs(expected - calculated)
        # A NaN score (unscorable sequence) compares False and counts as a mismatch.
        matches = difference < tolerance

        print(f"Sequence {i+1} (length {len(seq)}):")
        print(f"  Calculated PPPL: {calculated:.6f}")
        print(f"  Expected PPPL:   {expected:.6f}")
        print(f"  Difference:      {difference:.6f}")
        print(f"  Within tolerance: {'✓' if matches else '✗'}")
        print()

        if not matches:
            all_match = False

    if all_match:
        print("✅ All scores match within tolerance!")
    else:
        print("⚠️ Some scores do not match the expected values.")

    return all_match

# Optimize memory before running the model
def optimize_memory_for_15B(min_free_gb=12, critical_free_gb=8):
    """
    Free cached memory and report whether the GPU has room for the ESM2-15B model.

    Runs Python garbage collection, releases PyTorch's cached-but-unused CUDA
    blocks, prints a summary of the current CUDA device's memory, and prints
    warnings when free memory falls below ``min_free_gb``. A louder warning is
    printed below ``critical_free_gb``.

    Parameters:
    -----------
    min_free_gb : float, default=12
        Free-memory threshold in GB used for the first warning and the return value.
    critical_free_gb : float, default=8
        Free-memory threshold in GB below which a runtime restart is suggested.

    Returns:
    --------
    bool
        True if CUDA is available and at least ``min_free_gb`` GB is free on
        the current device, False otherwise.
    """
    print("Optimizing memory for ESM2-15B model...")

    # Collect first: tensors kept alive only by unreachable Python objects are
    # released here, and empty_cache() can only hand back unreferenced blocks.
    gc.collect()

    if not torch.cuda.is_available():
        return False

    torch.cuda.empty_cache()

    # Query one device consistently for every number below.
    device = torch.cuda.current_device()
    # mem_get_info asks the driver, so it accounts for the CUDA context, the
    # caching allocator's reserved pool and other processes; total minus
    # memory_allocated() would overstate what is actually free.
    free_bytes, total_bytes = torch.cuda.mem_get_info(device)
    total_memory = total_bytes / 1e9
    allocated_memory = torch.cuda.memory_allocated(device) / 1e9
    free_memory = free_bytes / 1e9

    print(f"GPU: {torch.cuda.get_device_name(device)}")
    print(f"Total GPU memory: {total_memory:.2f} GB")
    print(f"Currently allocated: {allocated_memory:.2f} GB")
    print(f"Free memory: {free_memory:.2f} GB")

    # NOTE(review): 15B parameters in fp16 is ~30 GB of weights alone; confirm
    # the 12 GB default matches how the model is actually loaded elsewhere.
    if free_memory < min_free_gb:
        print(f"⚠️ Warning: Less than {min_free_gb}GB of free GPU memory.")
        print(f"   ESM2-15B requires at least {min_free_gb}GB of GPU memory.")
        print("   Consider using ESM2-650M or ESM2-3B instead.")
        print("   Or restart your Colab runtime to free memory.")

        # Suggest restart if memory is really low
        if free_memory < critical_free_gb:
            print("\n⚠️ CRITICAL: Insufficient memory available.")
            print("   Please restart your Colab runtime before running.")
            print("   Runtime → Restart runtime")

    # Reuse the measurement above instead of recomputing with a different formula.
    return free_memory >= min_free_gb

PyTorch version: 2.8.0+cu126
CUDA available: True
CUDA device: Tesla T4
Available GPU memory: 15.83 GB


In [None]:

# --- Example 1 -------------------------------------------------------------
# Score one protein with the memory-efficient ESM2-650M pipeline; the PPPL
# score is kept in `pppl`.
sequence = "MLIVINYKTYNESIGNRGLEIAKIAEKVSEESGITIGVAPQFVDLRMIVENVNIPVYAQHIDNINPGSHTGHILAEAIKDCGCKGTLINHSEKRMLLADIEAVINKCKNLGLETIVCTNNINTSKAVAALSPDYIAVEPPELIGTGIPVSKANPEVVEGTVRAVKEINKDVKVLCGAGISKGEDVKAALDLGAEGVLLASGVVKAKNVEEAIRELIKF"
pppl = process_single_sequence_memory_efficient(
    sequence,
    model_size="650M",
)

# --- Example 2 -------------------------------------------------------------
# Reference sequences and their known PPPL scores (paired by position) for
# checking the implementation. This cell only defines the data; to run the
# check, call:
#   validation_test_memory_efficient(test_sequences, expected_scores, model_size="650M")
test_sequences = [
    # expected PPPL: -1.376412861
    "MGGLEKKKYERGSATNYITRNKARKKLQLSLADFRRLCILKGIYPHEPKHKKKVNKGSTAARTFYLIKDIRFLLHEPIVNKFREYKVFVRKLRKAYGKSEWNTVERLKDNKPNYKLDHIIKERYPTFIDALRDLDDALSMCFLFSTFPRTGKCHVQTIQLCRRLTVEFMHYIIAARALRKVFLSIKGIYYQAEVLGQPIVWITPYAFSHDHPTDVDYRVMATFTEFYTTLLGFVNFRLYQLLNLHYPPKLEGQAQAEAKAGEGTYALDSESCMEKLAALSASLARVVVPATEEEAEVDEFPTDGEMSAQEEDRRKELEAQEKHKKLFEGLKFFLNREVPREALAFIIRSFGGEVSWDKSLCIGATYDVTDSRITHQIVDRPGQQTSVIGRCYVQPQWVFDSVNARLLLPVAEYFSGVQLPPHLSPFVTEKEGDYVPPEKLKLLALQRGEDPGNLNESEEEEEEDDNNEGDGDEEGENEEEEEDAEAGSEKEEEARLAALEEQRMEGKKPRVMAGTLKLEDKQRLAQEEESEAKRLAIMMMKKREKYLYQKIMFGKRRKIREANKLAEKRKAHDEAVRSEKKAKKARPE",
    # expected PPPL: -1.078029289
    "MDSSVIQRKKVAVIGGGLVGSLQACFLAKRNFQIDVYEAREDTRVATFTRGRSINLALSHRGRQALKAVGLEDQIVSQGIPMRARMIHSLSGKKSAIPYGTKSQYILSVSRENLNKDLLTAAEKYPNVKMHFNHRLLKCNPEEGMITVLGSDKVPKDVTCDLIVGCDGAYSTVRSHLMKKPRFDYSQQYIPHGYMELTIPPKNGDYAMEPNYLHIWPRNTFMMIALPNMNKSFTCTLFMPFEEFEKLLTSNDVVDFFQKYFPDAIPLIGEKLLVQDFFLLPAQPMISVKCSSFHFKSHCVLLGDAAHAIVPFFGQGMNAGFEDCLVFDELMDKFSNDLSLCLPVFSRLRIPDDHAISDLSMYNYIEMRAHVNSSWFIFQKNMERFLHAIMPSTFIPLYTMVTFSRIRYHEAVQRWHWQKKVINKGLFFLGSLIAISSTYLLIHYMSPRSFLRLRRPWNWIAHFRNTTCFPAKAVDSLEQISNLISR",
]
expected_scores = [
    -1.376412861,
    -1.078029289,
]


Calculating ESM2-650M likelihood for sequence (length 218)...
Initial GPU memory usage: 0.01 GB
Loading ESM2-650M model on cuda...


Using cache found in /root/.cache/torch/hub/facebookresearch_esm_main


Model loaded. Using 2.67 GB of GPU memory
Remaining GPU memory: 13.15 GB
Final GPU memory usage: 2.69 GB

Results:
Sequence: MLIVINYKTYNESIGNRGLE... (length 218)
ESM2-650M PPPL Score: -1.053744
Calculation Time: 77.09 seconds
