### SETUP AND IMPORTS

In [1]:
# Core imports
import pandas as pd
import numpy as np
import torch
import random
import time
import math
import warnings
from typing import List, Tuple, Dict, Union
import matplotlib.pyplot as plt
import seaborn as sns
import statistics as stats
import faiss
from tqdm import tqdm

# Scientific computing
from scipy.spatial import ConvexHull
from scipy.linalg import sqrtm
from sklearn.decomposition import PCA
from tqdm.auto import tqdm

# Hugging Face transformers
from transformers import EsmTokenizer, EsmForMaskedLM, EsmForProteinFolding

# Project-local imports (soft alignment + metric helpers).
# The repository root is appended to sys.path before the two project imports
# below; presumably the `external` and `scripts` packages live under it.
import sys, pathlib, os
# Repository location is derived from the current user's home directory.
project_root = pathlib.Path.home() / "projets" / "protein-generation"
sys.path.append(str(project_root))
from external.protein_embed_softalign.soft_align import soft_align
# NOTE(review): the star import hides where names come from. It presumably
# provides load_perplexity_model, load_folding_model, calculate_perplexity and
# calculate_plddt used in later cells — confirm and import them explicitly.
from scripts.protein_metrics.calculate_metrics import *

# Silence every Python warning so cell output stays readable.
# NOTE(review): this is a blanket filter — it also hides deprecation warnings.
warnings.filterwarnings("ignore")

# Seed the three RNGs (stdlib, NumPy, PyTorch) from one constant so they
# cannot drift apart when the seed is changed.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Plot appearance
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print("✓ All imports successful")

# --- Model checkpoints ---
PPL_MODEL_NAME = "facebook/esm2_t6_8M_UR50D"    # checkpoint used for perplexity
FOLD_MODEL_NAME = "facebook/esmfold_v1"         # checkpoint used for structure prediction

# Run on CUDA when torch reports it as available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Notebook-level handles for the models and tokenizers, reset before loading.
ppl_model = ppl_tokenizer = None
fold_model = fold_tokenizer = None

# Each loader's result is unpacked into a (model, tokenizer) pair.
# NOTE(review): the loaders presumably come from the star import of
# scripts.protein_metrics.calculate_metrics — confirm.
ppl_model, ppl_tokenizer = load_perplexity_model(
    ppl_model_name=PPL_MODEL_NAME,
    device=device,
)
fold_model, fold_tokenizer = load_folding_model(
    fold_model_name=FOLD_MODEL_NAME,
    device=device,
)

  from .autonotebook import tqdm as notebook_tqdm


✓ All imports successful
Loading perplexity model...


Some weights of the model checkpoint at facebook/esm2_t6_8M_UR50D were not used when initializing EsmForMaskedLM: ['esm.embeddings.position_embeddings.weight']
- This IS expected if you are initializing EsmForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


✓ Perplexity model loaded
Loading folding model...


Some weights of the model checkpoint at facebook/esmfold_v1 were not used when initializing EsmForProteinFolding: ['esm.embeddings.position_embeddings.weight']
- This IS expected if you are initializing EsmForProteinFolding from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForProteinFolding from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


✓ Folding model loaded


### TEST QUALITY

In [3]:
# Load the training set relative to `project_root` (defined in the setup cell)
# instead of a user-specific absolute path, so the notebook follows the same
# repository location as the project imports.
# The file carries a .tsv extension but is parsed as comma-separated.
raw_df = pd.read_csv(project_root / "data" / "train.tsv", sep=',')

# Upper-case one-letter residue codes accepted by is_clean (20 letters).
AA_SET = set("ACDEFGHIKLMNPQRSTVWY")

def is_clean(seq: str) -> bool:
    """Tell whether `seq` is usable for the quality test.

    A sequence is clean when it is a `str`, is at least 100 characters long,
    and contains only characters from `AA_SET` (the check is case-sensitive).
    Non-string values yield False.
    """
    if not isinstance(seq, str):
        return False
    if len(seq) < 100:
        return False
    # Every character of seq must belong to the allowed alphabet.
    return AA_SET.issuperset(seq)

# 2.  Keep the first 10 clean sequences.
#     The scan stops as soon as 10 have been collected instead of filtering
#     the whole "sequence" column first and slicing afterwards.
clean_seqs = []
for seq in raw_df["sequence"]:
    if is_clean(seq):
        clean_seqs.append(seq)
        if len(clean_seqs) == 10:
            break
# Fail loudly if the data set does not contain enough usable sequences.
assert len(clean_seqs) == 10, "Moins de 10 séquences valides trouvées !"


In [None]:
# 3.  Metric functions (thin wrappers around the project helpers)
def get_perplexity(seq: str) -> float:
    """Return the result of `calculate_perplexity` for `seq`.

    Forwards the notebook-level `ppl_model`, `ppl_tokenizer` and `device`,
    with a fixed `batch_size` of 32.
    """
    ppl_kwargs = {
        "ppl_model": ppl_model,
        "ppl_tokenizer": ppl_tokenizer,
        "device": device,
        "batch_size": 32,
    }
    return calculate_perplexity(seq, **ppl_kwargs)

def get_plddt(seq: str) -> float:
    """Return the first item of the pair produced by `calculate_plddt` for `seq`.

    Forwards the notebook-level `fold_model`, `fold_tokenizer` and `device`.
    The second item of the returned pair is discarded.
    """
    fold_kwargs = {
        "fold_model": fold_model,
        "fold_tokenizer": fold_tokenizer,
        "device": device,
    }
    mean_score, _unused = calculate_plddt(seq, **fold_kwargs)
    return mean_score

# 4.  Main loop: score each sequence at full length and truncated to its
#     first 100 residues, then record both scores and their difference.
results = []
for idx, full_seq in enumerate(clean_seqs):
    short_seq = full_seq[:100]

    # Full-length metrics first, then the truncated ones (same call order
    # as before: perplexity, then pLDDT).
    ppl_full, plddt_full = get_perplexity(full_seq), get_plddt(full_seq)
    ppl_short, plddt_short = get_perplexity(short_seq), get_plddt(short_seq)

    # Key insertion order fixes the column order of the results table.
    row = {"id": idx, "len_full": len(full_seq)}
    row.update({
        "ppl_full": ppl_full,
        "ppl_100":  ppl_short,
        "Δppl":     ppl_short - ppl_full,
    })
    row.update({
        "plddt_full": plddt_full,
        "plddt_100":  plddt_short,
        "Δplddt":    plddt_short - plddt_full,
    })
    results.append(row)

# 5.  Summary: per-sequence table, then the mean of every metric column
df_res = pd.DataFrame(results)
display(df_res)

# Columns averaged in the printed summary (id and len_full are left out).
metric_cols = ["ppl_full", "ppl_100", "Δppl", "plddt_full", "plddt_100", "Δplddt"]
print("\n=== Moyennes sur 10 séquences ===")
print(df_res[metric_cols].mean())

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Unnamed: 0,id,len_full,ppl_full,ppl_100,Δppl,plddt_full,plddt_100,Δplddt
0,0,267,10.36629,10.648608,0.282318,79.651863,54.84063,-24.811234
1,1,224,17.380379,18.315645,0.935266,69.855713,69.21936,-0.636353
2,2,378,8.890858,17.070048,8.17919,86.330635,68.484192,-17.846443
3,3,299,9.760826,13.936968,4.176141,75.565346,51.252251,-24.313095
4,4,418,7.324702,11.89431,4.569609,83.163017,46.313927,-36.849091
5,5,565,11.640944,13.218821,1.577877,69.142036,59.937927,-9.204109
6,6,524,6.383735,15.574412,9.190677,77.200768,39.604534,-37.596233
7,7,216,9.423461,14.556885,5.133424,87.327797,72.876869,-14.450928
8,8,298,7.288765,14.857831,7.569067,83.230148,42.113281,-41.116867
9,9,434,13.763902,13.62217,-0.141731,68.357468,44.123409,-24.234058



=== Moyennes sur 10 séquences ===
ppl_full      10.222386
ppl_100       14.369570
Δppl           4.147184
plddt_full    77.982479
plddt_100     54.876638
Δplddt       -23.105841
dtype: float64


: 