In [None]:
# Core imports
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
import random
import time
import math
import warnings
from typing import List, Tuple, Dict, Union
import matplotlib.pyplot as plt
import statistics as stats
import faiss #type: ignore
import numpy as np
import torch
from tqdm import tqdm
from typing import List
from typing import Union, Tuple
from pathlib import Path
import torch
import torch.nn.functional as F
import numpy as np
import sys
import os
from pathlib import Path
from typing import Tuple, Optional

# Scientific computing
from scipy.spatial import ConvexHull
from scipy.linalg import eigvals, eig
from scipy.stats import entropy
from scipy.linalg import sqrtm
from scipy.stats import wasserstein_distance
from sklearn.decomposition import PCA
from tqdm.auto import tqdm
from sklearn.neighbors import KernelDensity

from top_pr import compute_top_pr as TopPR #type: ignore

# Hugging Face transformers
from transformers import EsmTokenizer, EsmForMaskedLM, EsmForProteinFolding #type: ignore

# External soft alignment
import sys, pathlib, os
project_root = pathlib.Path.home() / "projets" / "protein-generation"
sys.path.append(str(project_root))

from external.protein_embed_softalign.soft_align import soft_align

def load_perplexity_model(
    ppl_model_name: str,
    device: str = "cuda"
) -> Tuple[EsmForMaskedLM, EsmTokenizer]:
    """Load the masked-language model and tokenizer used for perplexity scoring.

    Args:
        ppl_model_name: Identifier handed to ``from_pretrained`` for both the
            tokenizer and the model.
        device: Device the model is moved to.

    Returns:
        ``(model, tokenizer)``; the model is in evaluation mode and has been
        moved to ``device``.
    """
    print("Loading perplexity model...")

    # Tokenizer first, then the weights, both resolved from the same name.
    tokenizer = EsmTokenizer.from_pretrained(ppl_model_name)
    model = EsmForMaskedLM.from_pretrained(ppl_model_name)

    # Inference only: switch to evaluation mode before placing on the device.
    model.eval()
    model.to(device)

    print("✓ Perplexity model loaded")
    return model, tokenizer


def load_folding_model(
    fold_model_name: str,
    device: str
) -> Tuple[EsmForProteinFolding, EsmTokenizer]:
    """Load the protein-folding model and its tokenizer.

    Args:
        fold_model_name: Identifier handed to ``from_pretrained`` for both the
            tokenizer and the model.
        device: Device the model is moved to (no default; must be supplied).

    Returns:
        ``(model, tokenizer)``; the model is in evaluation mode and has been
        moved to ``device``.
    """
    print("Loading folding model...")

    # Tokenizer first, then the weights, both resolved from the same name.
    tokenizer = EsmTokenizer.from_pretrained(fold_model_name)
    model = EsmForProteinFolding.from_pretrained(fold_model_name)

    # Inference only: switch to evaluation mode before placing on the device.
    model.eval()
    model.to(device)

    print("✓ Folding model loaded")
    return model, tokenizer


def get_sequence_embeddings(
    sequences: List[str],
    ppl_model: EsmForMaskedLM,
    ppl_tokenizer: EsmTokenizer,
    device: Union[str, torch.device],
    show_progress: bool = True,
    layer: int = 6,
) -> np.ndarray:
    """Embed each sequence as the mean of its per-residue hidden states.

    Every sequence is tokenized and run through ``ppl_model`` on its own; the
    hidden states of one layer are averaged over the residue positions
    (special tokens excluded) to give one vector per sequence.

    Args:
        sequences: Protein sequences to embed.
        ppl_model: Model called with ``output_hidden_states=True``.
        ppl_tokenizer: Tokenizer matching ``ppl_model``.
        device: Device the tokenized inputs are moved to; the model is
            expected to already live there.
        show_progress: Wrap the iteration in a ``tqdm`` progress bar.
        layer: Index into ``outputs.hidden_states``. Defaults to 6, the value
            that was previously hard-coded.

    Returns:
        ``np.array`` built from the list of per-sequence mean vectors, in the
        same order as ``sequences``.
    """
    embeddings = []

    # Setup iterator with or without progress bar
    iterator = tqdm(sequences, desc="Extracting embeddings") if show_progress else sequences

    with torch.no_grad():
        for seq in iterator:
            # Tokenize the protein sequence (one sequence per forward pass)
            inputs = ppl_tokenizer(seq, return_tensors="pt", padding=True, truncation=True)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Shape: [1, num_tokens, hidden_dim]
            outputs = ppl_model(**inputs, output_hidden_states=True)
            hidden_states = outputs.hidden_states[layer]

            # truncation=True can shorten the input, so len(seq) may exceed the
            # number of residue tokens actually present. Clamp to what the
            # model saw so the slice never reaches the trailing special token.
            # Assumes exactly one leading ([CLS]) and one trailing ([SEP])
            # special token.
            n_residues = min(len(seq), hidden_states.shape[1] - 2)

            # [CLS] is at position 0, residues start at position 1
            seq_embedding = hidden_states[0, 1:n_residues + 1].mean(dim=0).cpu().numpy()
            embeddings.append(seq_embedding)

    return np.array(embeddings)

def calculate_plddt(seq: str, fold_model, fold_tokenizer, device="cuda") -> float:
    """Return the mean of the folding model's pLDDT output, multiplied by 100.

    Args:
        seq: Protein sequence to fold.
        fold_model: Callable model whose output exposes a ``plddt`` tensor.
        fold_tokenizer: Tokenizer for ``fold_model``; called with
            ``add_special_tokens=False``.
        device: Device the tokenized inputs are moved to.

    Returns:
        Mean over every element of ``plddt[0]`` after scaling by 100, as a
        Python float. NOTE(review): the x100 presumably converts a 0-1 score
        to a 0-100 scale - confirm against the model's output convention.
    """
    encoded = fold_tokenizer(seq, return_tensors="pt", add_special_tokens=False)

    model_inputs = {}
    for name, tensor in encoded.items():
        model_inputs[name] = tensor.to(device)

    with torch.no_grad():
        prediction = fold_model(**model_inputs)
        # Scaling is done in place on the first batch element of the output.
        scaled = prediction.plddt[0].mul_(100)
        mean_plddt = scaled.mean()

    return float(mean_plddt)



def calculate_perplexity(
    seq: str,
    ppl_model: EsmForMaskedLM,
    ppl_tokenizer: EsmTokenizer,
    device: Union[torch.device, str] = "cuda",
    max_batch_size: Optional[int] = None,
) -> float:
    """Compute the masked pseudo-perplexity of a sequence.

    Each residue position is masked in turn (one batch row per position), the
    model predicts the masked token, and the negative log-probabilities of the
    original tokens are averaged and exponentiated.

    Args:
        seq: Protein sequence to score.
        ppl_model: Masked-language model returning ``logits``.
        ppl_tokenizer: Tokenizer matching ``ppl_model``; must expose
            ``mask_token_id``.
        device: Device the tokenized inputs are moved to.
        max_batch_size: Maximum number of masked copies per forward pass.
            ``None`` (default) scores all positions in a single pass, as
            before; a smaller value bounds memory use on long sequences.

    Returns:
        ``exp`` of the mean negative log-likelihood over the scored positions.

    Raises:
        ValueError: If there is no residue to score or ``max_batch_size`` is
            smaller than 1.
    """
    if max_batch_size is not None and max_batch_size < 1:
        raise ValueError("max_batch_size must be at least 1")

    # Tokenize sequence
    inputs = ppl_tokenizer(seq, return_tensors="pt")
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    # Number of positions to score. Clamped to the tokens actually present so
    # that indexing never runs past the input or onto the trailing special
    # token (assumes one leading [CLS] and one trailing special token).
    seq_len = min(len(seq), input_ids.shape[1] - 2)
    if seq_len <= 0:
        raise ValueError("Cannot compute perplexity of an empty sequence")

    chunk_size = seq_len if max_batch_size is None else max_batch_size
    mask_token_id = ppl_tokenizer.mask_token_id
    total_loss = 0.0

    with torch.no_grad():
        for chunk_start in range(0, seq_len, chunk_size):
            chunk_end = min(chunk_start + chunk_size, seq_len)

            # Token positions masked in this chunk (+1 to skip [CLS]) and the
            # batch row that carries each of them.
            positions = torch.arange(chunk_start, chunk_end, device=input_ids.device) + 1
            rows = torch.arange(positions.numel(), device=input_ids.device)

            # One copy of the sequence per position, each with its own mask.
            masked_ids = input_ids.repeat(positions.numel(), 1)
            masked_ids[rows, positions] = mask_token_id

            logits = ppl_model(
                input_ids=masked_ids,
                attention_mask=attention_mask.repeat(positions.numel(), 1),
            ).logits  # Shape: [rows, num_tokens, vocab_size]

            # cross_entropy = log_softmax + NLL of the original token; summing
            # on-device avoids one host sync per residue.
            total_loss += F.cross_entropy(
                logits[rows, positions].float(),
                input_ids[0, positions],
                reduction="sum",
            ).item()

    # Calculate perplexity
    avg_loss = total_loss / seq_len
    return math.exp(avg_loss)

def calculate_perplexity_simple(
    seq: str,
    ppl_model: EsmForMaskedLM,
    ppl_tokenizer: EsmTokenizer,
    device: Union[torch.device, str] = "cuda"
) -> float:
    """Exponentiate the model's own loss from one unmasked forward pass.

    The tokenized sequence is passed both as input and as ``labels``; no
    position is masked, so this is not the masked pseudo-perplexity computed
    by ``calculate_perplexity``.
    NOTE(review): in the recorded notebook runs this returns roughly 1.4-1.7
    where ``calculate_perplexity`` returns roughly 5-19 for the same
    sequences - treat the two as different metrics.

    Args:
        seq: Protein sequence to score.
        ppl_model: Masked-language model that returns ``loss`` when given
            ``labels``.
        ppl_tokenizer: Tokenizer matching ``ppl_model``.
        device: Device the tokenized inputs are moved to.

    Returns:
        ``exp(loss)`` as a Python float.
    """
    encoded = ppl_tokenizer(seq, return_tensors="pt").to(device)

    with torch.no_grad():
        # The loss is already averaged per token.
        mean_loss = ppl_model(**encoded, labels=encoded['input_ids']).loss.item()

    return math.exp(mean_loss)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Smoke test of calculate_plddt, calculate_perplexity and
# calculate_perplexity_simple on training sequences, printing each score
# together with its wall-clock time.
import time

# Test sequences of varying lengths, read from the training CSV
df_train = pd.read_csv("/home/arthur/projets/protein-generation/experiments/models/noised_dplm_simple/training_sequences.csv")
train = df_train["sequence"].tolist()

sequences = train[:10]  # Use the first 10 sequences for the test

# Load the models
# NOTE(review): selects "cuda:1" whenever CUDA is available - confirm the
# machine exposes a second GPU, otherwise change the index.
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
ppl_model, ppl_tokenizer = load_perplexity_model("facebook/esm2_t6_8M_UR50D", device)
fold_model, fold_tokenizer = load_folding_model("facebook/esmfold_v1", device)

# Run the three metrics on every sequence
for i, seq in enumerate(sequences):
    print(f"\nSéquence {i+1} ({len(seq)} AA):")
    
    # pLDDT test
    start = time.time()
    plddt_mean = calculate_plddt(seq, fold_model, fold_tokenizer, device)
    plddt_time = time.time() - start
    print(f"  pLDDT: {plddt_mean:.2f} ({plddt_time:.3f}s)")
    
    # Masked (one position at a time) perplexity test
    start = time.time()
    perplexity = calculate_perplexity(seq, ppl_model, ppl_tokenizer, device)
    ppl_time = time.time() - start
    print(f"  Perplexity: {perplexity:.2f} ({ppl_time:.3f}s)")
    
    # Simple (single unmasked forward pass) perplexity test; reuses ppl_time
    start = time.time()
    simple_perplexity = calculate_perplexity_simple(seq, ppl_model, ppl_tokenizer, device)
    ppl_time = time.time() - start
    print(f"  Simple Perplexity: {simple_perplexity:.2f} ({ppl_time:.3f}s)")

Using device: cuda:1
Loading perplexity model...


Some weights of the model checkpoint at facebook/esm2_t6_8M_UR50D were not used when initializing EsmForMaskedLM: ['esm.embeddings.position_embeddings.weight']
- This IS expected if you are initializing EsmForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


✓ Perplexity model loaded
Loading folding model...


Some weights of the model checkpoint at facebook/esmfold_v1 were not used when initializing EsmForProteinFolding: ['esm.embeddings.position_embeddings.weight']
- This IS expected if you are initializing EsmForProteinFolding from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForProteinFolding from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


✓ Folding model loaded

Séquence 1 (267 AA):
  pLDDT: 79.65 (6.294s)
  Perplexity: 10.37 (0.265s)
  Simple Perplexity: 1.48 (0.010s)

Séquence 2 (224 AA):
  pLDDT: 69.86 (4.059s)
  Perplexity: 17.38 (0.172s)
  Simple Perplexity: 1.66 (0.006s)

Séquence 3 (299 AA):
  pLDDT: 75.57 (8.057s)
  Perplexity: 9.76 (0.340s)
  Simple Perplexity: 1.50 (0.007s)

Séquence 4 (216 AA):
  pLDDT: 87.33 (3.795s)
  Perplexity: 9.42 (0.160s)
  Simple Perplexity: 1.47 (0.007s)

Séquence 5 (298 AA):
  pLDDT: 83.23 (8.127s)
  Perplexity: 7.29 (0.340s)
  Simple Perplexity: 1.48 (0.007s)

Séquence 6 (281 AA):
  pLDDT: 52.47 (7.097s)
  Perplexity: 13.58 (0.300s)
  Simple Perplexity: 1.54 (0.007s)

Séquence 7 (166 AA):
  pLDDT: 67.33 (2.360s)
  Perplexity: 7.51 (0.093s)
  Simple Perplexity: 1.40 (0.006s)

Séquence 8 (213 AA):
  pLDDT: 84.05 (3.730s)
  Perplexity: 7.21 (0.156s)
  Simple Perplexity: 1.43 (0.007s)

Séquence 9 (247 AA):
  pLDDT: 80.83 (5.078s)
  Perplexity: 8.42 (0.215s)
  Simple Perplexity: 1.45 (0

In [8]:
# Smoke test of calculate_plddt, calculate_perplexity and
# calculate_perplexity_simple on generated sequences, printing each score
# together with its wall-clock time.
import time

# Training sequences (loaded for reference; not scored in this cell)
df_train = pd.read_csv("/home/arthur/projets/protein-generation/experiments/models/noised_dplm_simple/training_sequences.csv")
train = df_train["sequence"].tolist()

# Sequences generated by the masked model
df_gen_masked = pd.read_csv('/home/arthur/projets/protein-generation/experiments/models/masked_dplm_simple/generated_sequences.csv')
demasked = df_gen_masked["sequence"].tolist()


sequences = demasked[:10]  # Use the first 10 sequences for the test

# Load the models
# NOTE(review): selects "cuda:1" whenever CUDA is available - confirm the
# machine exposes a second GPU, otherwise change the index.
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
ppl_model, ppl_tokenizer = load_perplexity_model("facebook/esm2_t6_8M_UR50D", device)
fold_model, fold_tokenizer = load_folding_model("facebook/esmfold_v1", device)

# Run the three metrics on every sequence
for i, seq in enumerate(sequences):
    print(f"\nSéquence {i+1} ({len(seq)} AA):")

    # pLDDT test. calculate_plddt returns a single float, so it must not be
    # tuple-unpacked (doing so raises TypeError).
    start = time.time()
    plddt_mean = calculate_plddt(seq, fold_model, fold_tokenizer, device)
    plddt_time = time.time() - start
    print(f"  pLDDT: {plddt_mean:.2f} ({plddt_time:.3f}s)")

    # Masked (one position at a time) perplexity test
    start = time.time()
    perplexity = calculate_perplexity(seq, ppl_model, ppl_tokenizer, device)
    ppl_time = time.time() - start
    print(f"  Perplexity: {perplexity:.2f} ({ppl_time:.3f}s)")

    # Simple (single unmasked forward pass) perplexity test; reuses ppl_time
    start = time.time()
    simple_perplexity = calculate_perplexity_simple(seq, ppl_model, ppl_tokenizer, device)
    ppl_time = time.time() - start
    print(f"  Simple Perplexity: {simple_perplexity:.2f} ({ppl_time:.3f}s)")

Using device: cuda:1
Loading perplexity model...


Some weights of the model checkpoint at facebook/esm2_t6_8M_UR50D were not used when initializing EsmForMaskedLM: ['esm.embeddings.position_embeddings.weight']
- This IS expected if you are initializing EsmForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


✓ Perplexity model loaded
Loading folding model...


Some weights of the model checkpoint at facebook/esmfold_v1 were not used when initializing EsmForProteinFolding: ['esm.embeddings.position_embeddings.weight']
- This IS expected if you are initializing EsmForProteinFolding from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForProteinFolding from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


✓ Folding model loaded

Séquence 1 (100 AA):
  pLDDT: 59.40 (1.043s)
  Perplexity: 5.09 (0.033s)
  Simple Perplexity: 1.47 (0.007s)

Séquence 2 (100 AA):
  pLDDT: 40.14 (1.049s)
  Perplexity: 18.65 (0.033s)
  Simple Perplexity: 1.67 (0.007s)

Séquence 3 (100 AA):
  pLDDT: 43.49 (1.029s)
  Perplexity: 16.02 (0.033s)
  Simple Perplexity: 1.56 (0.007s)

Séquence 4 (100 AA):
  pLDDT: 65.54 (1.044s)
  Perplexity: 7.99 (0.033s)
  Simple Perplexity: 1.49 (0.007s)

Séquence 5 (100 AA):
  pLDDT: 49.44 (1.046s)
  Perplexity: 17.27 (0.033s)
  Simple Perplexity: 1.64 (0.007s)

Séquence 6 (100 AA):
  pLDDT: 60.34 (1.048s)
  Perplexity: 6.40 (0.033s)
  Simple Perplexity: 1.56 (0.007s)

Séquence 7 (100 AA):
  pLDDT: 38.84 (1.035s)
  Perplexity: 16.86 (0.033s)
  Simple Perplexity: 1.62 (0.007s)

Séquence 8 (100 AA):
  pLDDT: 43.82 (1.020s)
  Perplexity: 18.51 (0.033s)
  Simple Perplexity: 1.62 (0.007s)

Séquence 9 (100 AA):
  pLDDT: 38.38 (1.049s)
  Perplexity: 17.83 (0.033s)
  Simple Perplexity: 1.61