In [1]:
import torch
import pandas as pd
from pathlib import Path
from typing import List, Tuple, Dict, Union
import warnings
warnings.filterwarnings('ignore')

import sys, pathlib, os
project_root = pathlib.Path.home() / "projets" / "protein-generation"
# Make the project's `scripts` package importable; the membership check keeps
# sys.path free of duplicates when this cell is re-run.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

from scripts.evaluation.evaluation_metrics import *
from scripts.models.noised_dplm.vocabulary import ProteinVocabulary


In [2]:
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
import random                        

class MutationExperiment:
    """Compare three mutation strategies on fixed-length protein sequences.

    For each selected sequence and each mutation count, one random set of
    positions is drawn and mutated with every strategy. Each mutated sequence is
    then scored with pLDDT (folding model) and perplexity (ESM-2 language model),
    and one row per (sequence, mutation count, method) is written to a CSV.
    """

    def __init__(self, csv_path, device="cuda", seq_len=100, max_sequences=10):
        """Load the sequences and both scoring models.

        Args:
            csv_path: CSV file with a 'sequence' column.
            device: torch device string passed to the model loaders and used for inputs.
            seq_len: only sequences of exactly this length are kept.
            max_sequences: keep at most this many of them, in file order.
        """
        self.device = device
        self.vocab = ProteinVocabulary()
        self.seq_len = seq_len

        df = pd.read_csv(csv_path)
        # Missing entries (NaN) would make len() raise, so drop them first.
        candidates = df['sequence'].dropna().tolist()
        self.sequences = [seq for seq in candidates if len(seq) == seq_len][:max_sequences]

        self.ppl_model, self.ppl_tokenizer   = load_perplexity_model("facebook/esm2_t6_8M_UR50D", device)
        self.fold_model, self.fold_tokenizer = load_folding_model   ("facebook/esmfold_v1",   device)

    def mutate_sequence(self, sequence, positions, use_mask=True):
        """Return `sequence` with every index in `positions` replaced by a different residue.

        use_mask = True      -> ESM on the sequence with the tokenizer's mask token
                                at each position (context_masked)
        use_mask = False     -> ESM on the unmasked sequence (no_mask)
        use_mask = "random"  -> uniformly random residue, no ESM

        In both ESM modes the replacement is the residue of `self.vocab.ALPHABET`
        with the lowest predicted probability at that position. The original residue
        is excluded, as in the random mode, so every listed position really changes.
        """
        if len(positions) == 0:
            return sequence

        # "random" is truthy, so it must be tested before the boolean branches.
        if use_mask == "random":
            mutated = list(sequence)
            for pos in positions:
                original = mutated[pos]
                choices = [aa for aa in self.vocab.ALPHABET if aa != original]
                mutated[pos] = random.choice(choices)
            return ''.join(mutated)

        if use_mask:                              # context_masked
            masked_seq = list(sequence)
            for pos in positions:
                masked_seq[pos] = self.ppl_tokenizer.mask_token
            input_seq = ''.join(masked_seq)
        else:                                     # no_mask
            input_seq = sequence

        # ESM predictions for every position in a single forward pass.
        inputs = self.ppl_tokenizer(input_seq, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.ppl_model(**inputs)
            probs = torch.softmax(outputs.logits, dim=-1)

        # Token ids do not depend on the position: look them up once.
        token_ids = {
            aa: self.ppl_tokenizer.convert_tokens_to_ids(aa)
            for aa in self.vocab.ALPHABET
        }

        # Apply the mutations: least probable residue other than the original one.
        mutated_seq = list(sequence)
        for pos in positions:
            esm_pos = pos + 1  # offset for the start (<cls>) token at index 0
            original = sequence[pos]
            aa_probs = {
                aa: probs[0, esm_pos, tok_id].item()
                for aa, tok_id in token_ids.items()
                if aa != original
            }
            mutated_seq[pos] = min(aa_probs, key=aa_probs.get)

        return ''.join(mutated_seq)

    def _score(self, sequence):
        """Return (pLDDT, perplexity) of `sequence` using the loaded models."""
        plddt, _ = calculate_plddt(
            sequence, self.fold_model, self.fold_tokenizer, self.device
        )
        perplexity = calculate_perplexity(
            sequence, self.ppl_model, self.ppl_tokenizer, self.device
        )
        return plddt, perplexity

    def run_experiment(self, output_csv="results.csv", mutation_step=10):
        """Mutate and score every sequence, then write all rows to `output_csv`.

        Mutation counts are 0, mutation_step, 2 * mutation_step, ... up to
        `self.seq_len`. Positions come from the global torch RNG and the random
        method from the global `random` RNG, so seed both before calling.
        """
        results = []
        mutation_counts = list(range(0, self.seq_len + 1, mutation_step))
        method_configs = [
            ("context_masked", True),
            ("no_mask",       False),
            ("random",        "random")
        ]

        for seq_idx, seq in enumerate(tqdm(self.sequences)):
            for num_mut in mutation_counts:
                # One set of positions per (sequence, count), shared by all methods.
                positions = torch.randperm(len(seq))[:num_mut].tolist()

                for method, use_mask in method_configs:
                    mutated_seq = self.mutate_sequence(seq, positions, use_mask)
                    plddt, perplexity = self._score(mutated_seq)

                    results.append({
                        'sequence_id':    seq_idx,
                        'num_mutations':  num_mut,
                        'method':         method,
                        'plddt':          plddt,
                        'perplexity':     perplexity,
                        'mutated_sequence': mutated_seq
                    })

        pd.DataFrame(results).to_csv(output_csv, index=False)
        
# Seed both RNGs used by the experiment: torch (position sampling with
# torch.randperm) and `random` (the "random" mutation method). Seeding happens
# before the models are built so that any RNG they consume at load time is
# also reproducible.
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)

# Resolve the input from the project root defined in the first cell instead of
# a user-specific absolute path.
DATA_CSV = project_root / "data" / "seq_clean_L100.csv"
RESULTS_CSV = 'mutation_comparison_results.csv'

experiment = MutationExperiment(DATA_CSV)
experiment.run_experiment(RESULTS_CSV)


Loading perplexity model...


Some weights of the model checkpoint at facebook/esm2_t6_8M_UR50D were not used when initializing EsmForMaskedLM: ['esm.embeddings.position_embeddings.weight']
- This IS expected if you are initializing EsmForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


✓ Perplexity model loaded
Loading folding model...


Some weights of the model checkpoint at facebook/esmfold_v1 were not used when initializing EsmForProteinFolding: ['esm.embeddings.position_embeddings.weight']
- This IS expected if you are initializing EsmForProteinFolding from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForProteinFolding from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


✓ Folding model loaded


100%|██████████| 10/10 [08:36<00:00, 51.64s/it]


In [3]:
# ============================================
#  ANALYSE DES RÉSULTATS DE MUTATION
# ============================================
import pandas as pd
import numpy as np

# Load the per-(sequence, mutation count, method) results written by the experiment cell.
df = pd.read_csv('mutation_comparison_results.csv')

print("=== ANALYSE DES RÉSULTATS DE MUTATION ===\n")

# Note: rows with num_mutations == 0 hold the unmutated sequence for every
# method, so they pull the per-method averages below toward each other.
METRICS = ['plddt', 'perplexity']


def summarize(frame, by, decimals=3):
    """Mean of each column in METRICS grouped by `by`, rounded to `decimals`."""
    return frame.groupby(by)[METRICS].mean().round(decimals)


def pivot_by_method(summary):
    """Move the 'method' index level to columns named '<metric>_<method>'."""
    wide = summary.unstack('method')
    wide.columns = [f'{metric}_{method}' for metric, method in wide.columns]
    return wide


# ------------------------------------------------------------------
# 1. Per sequence (averaged over every mutation count)
# ------------------------------------------------------------------
print("1. MOYENNES PAR SÉQUENCE (moyennées sur tous les nombres de mutations)")
print("=" * 70)
seq_analysis = summarize(df, ['sequence_id', 'method'])
seq_pivot = pivot_by_method(seq_analysis)
print(seq_pivot, "\n")

# ------------------------------------------------------------------
# 2. Per mutation count (averaged over every sequence)
# ------------------------------------------------------------------
print("2. MOYENNES PAR NOMBRE DE MUTATIONS (moyennées sur toutes les séquences)")
print("=" * 75)
mut_analysis = summarize(df, ['num_mutations', 'method'])
mut_pivot = pivot_by_method(mut_analysis)
print(mut_pivot, "\n")

# ------------------------------------------------------------------
# 3. Global (averaged over sequences AND mutation counts)
# ------------------------------------------------------------------
print("3. MOYENNES GLOBALES (moyennées sur séquences ET mutations)")
print("=" * 60)
global_analysis = summarize(df, 'method')
print(global_analysis, "\n")

# ------------------------------------------------------------------
# 4. Differences between methods
# ------------------------------------------------------------------
print("4. DIFFÉRENCES ENTRE MÉTHODES")
print("=" * 35)

# Use unrounded per-method means; subtracting values already rounded to three
# decimals can be off by 0.001 once the difference is printed.
method_means = df.groupby('method')[['plddt', 'perplexity']].mean()


def diff(a, b, metric, table=None):
    """Return mean(metric | method a) - mean(metric | method b).

    `table` is indexed by method with one column per metric; it defaults to the
    unrounded `method_means` computed in this cell.
    """
    table = method_means if table is None else table
    return table.loc[a, metric] - table.loc[b, metric]


pairs = [
    ('context_masked', 'no_mask'),
    ('context_masked', 'random'),
    ('no_mask', 'random')
]

for a, b in pairs:
    print(f"Différences ({a} - {b}) :")
    print(f"  pLDDT     : {diff(a, b, 'plddt'):.3f}")
    print(f"  Perplexity: {diff(a, b, 'perplexity'):.3f}\n")


=== ANALYSE DES RÉSULTATS DE MUTATION ===

1. MOYENNES PAR SÉQUENCE (moyennées sur tous les nombres de mutations)
             plddt_context_masked  plddt_no_mask  plddt_random  \
sequence_id                                                      
0                          37.262         34.888        39.135   
1                          36.526         35.668        38.371   
2                          36.547         35.946        40.228   
3                          37.544         37.853        42.719   
4                          36.625         35.565        41.170   
5                          34.413         34.204        40.650   
6                          37.483         35.078        43.142   
7                          36.595         34.842        42.099   
8                          38.969         34.982        40.522   
9                          35.292         32.666        39.453   

             perplexity_context_masked  perplexity_no_mask  perplexity_random  
sequence_id  

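Conclusion : no_mask est la méthode qui dégrade le plus les séquences : pLDDT plus bas (donc confiance de repliement plus basse — inférieur à context_masked pour 9 séquences sur 10 dans le tableau 1) et perplexité plus haute (surprise plus grande). La mutation aléatoire (random) donne au contraire le pLDDT le plus élevé pour les 10 séquences.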