In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
from pathlib import Path
from typing import List, Tuple, Dict, Union
import warnings
warnings.filterwarnings('ignore')

# Hugging Face transformers
from transformers import EsmTokenizer, EsmForMaskedLM
import sys, pathlib, os

project_root = pathlib.Path.home() / "projets" / "protein-generation"
sys.path.append(str(project_root))
from scripts.utils import *
from scripts.models.DPLM.noising.training_perp_optimized import *

# Device configuration: prefer the GPU whenever PyTorch can see one,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"[INFO] Using device: {device}")

def load_perplexity_model(
    ppl_model_name: str,
    device: Union[str, torch.device] = "cuda"
) -> Tuple[EsmForMaskedLM, EsmTokenizer]:
    """Load an ESM masked-LM and its tokenizer for perplexity-based noising.

    Args:
        ppl_model_name: Identifier handed to ``from_pretrained`` for both the
            tokenizer and the model (e.g. a Hugging Face hub name).
        device: Where the model weights are moved after loading. Either a
            device string such as ``"cuda"`` / ``"cpu"`` or a ``torch.device``.

    Returns:
        ``(model, tokenizer)``; the model is in evaluation mode and has been
        moved to ``device``.
    """
    print(f"[INFO] Loading perplexity model: {ppl_model_name}")

    # Tokenizer and weights are both resolved from the same identifier.
    ppl_tokenizer = EsmTokenizer.from_pretrained(ppl_model_name)
    ppl_model = EsmForMaskedLM.from_pretrained(ppl_model_name)

    # The model is only used for scoring here, so switch it to evaluation
    # mode before moving it to the target device.
    ppl_model.eval()
    ppl_model.to(device)

    print("[SUCCESS] Perplexity model loaded successfully")
    return ppl_model, ppl_tokenizer

# Model configuration: checkpoint used to score sequences.
PPL_MODEL_NAME = "facebook/esm2_t6_8M_UR50D"

# Load the perplexity model once, framed by a banner in the cell output.
print("=" * 60, "LOADING MODELS", "=" * 60, sep="\n")
ppl_model, ppl_tokenizer = load_perplexity_model(PPL_MODEL_NAME, device)
print("[SUCCESS] All models loaded successfully\n")

def _aa_probs_at(text, token_index, aa_token_ids, model, tokenizer, device):
    """Run one forward pass on ``text`` and return the probabilities of ``aa_token_ids`` at ``token_index``."""
    encoded = tokenizer(text, return_tensors="pt")
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        logits = model(**encoded).logits

    # Normalise over the vocabulary axis, then keep only the token of interest.
    position_probs = torch.softmax(logits, dim=-1)[0, token_index]
    # One host transfer for all requested ids instead of one `.item()` each.
    return position_probs[aa_token_ids].tolist()


def test_masked_vs_unmasked_prediction(sequence: str, test_position: int, model, tokenizer, device) -> dict:
    """
    Compare predictions when a position is masked vs unmasked

    The model is run twice: once on ``sequence`` as given and once with the
    residue at ``test_position`` replaced by ``tokenizer.mask_token``. For each
    of the 20 standard amino acids the probability at that position is read
    from both runs.

    Args:
        sequence: protein sequence string
        test_position: position to test (0-indexed, must satisfy
            ``0 <= test_position < len(sequence)``)
        model: ESM model
        tokenizer: ESM tokenizer
        device: torch device the tokenized inputs are moved to

    Returns:
        dict with keys ``original_aa``, ``position``, ``sequence``,
        ``masked_sequence``, and three per-amino-acid dicts:
        ``unmasked_probs``, ``masked_probs`` and ``prob_differences``
        (masked minus unmasked).

    Raises:
        IndexError: if ``test_position`` is outside the sequence. Negative
            values are rejected as well: they would otherwise build a
            corrupted masked sequence and read the wrong token.
    """
    if not 0 <= test_position < len(sequence):
        raise IndexError(
            f"test_position {test_position} is out of range "
            f"for a sequence of length {len(sequence)}"
        )

    # Get the original amino acid at test position
    original_aa = sequence[test_position]

    # Same sequence with the tested residue replaced by the mask token
    masked_sequence = sequence[:test_position] + tokenizer.mask_token + sequence[test_position+1:]

    # Account for [CLS] token (+1 offset); assumes the tokenizer prepends
    # exactly one special token and emits one token per residue.
    esm_pos = test_position + 1

    # The 20 standard amino acids, in the order used for the result dicts
    amino_acids = ('A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I',
                   'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V')
    aa_token_ids = [tokenizer.convert_tokens_to_ids(aa) for aa in amino_acids]

    # Test 1: original sequence (unmasked); Test 2: masked sequence
    unmasked = _aa_probs_at(sequence, esm_pos, aa_token_ids, model, tokenizer, device)
    masked = _aa_probs_at(masked_sequence, esm_pos, aa_token_ids, model, tokenizer, device)

    return {
        'original_aa': original_aa,
        'position': test_position,
        'sequence': sequence,
        'masked_sequence': masked_sequence,
        'unmasked_probs': dict(zip(amino_acids, unmasked)),
        'masked_probs': dict(zip(amino_acids, masked)),
        'prob_differences': {
            aa: p_masked - p_unmasked
            for aa, p_unmasked, p_masked in zip(amino_acids, unmasked, masked)
        },
    }

def analyze_results(results):
    """Print a report for one masked-vs-unmasked comparison and return it.

    Args:
        results: dict providing ``position``, ``original_aa``, ``sequence``,
            ``masked_sequence`` and the per-amino-acid dicts
            ``unmasked_probs``, ``masked_probs`` and ``prob_differences``.

    Returns:
        The same ``results`` object, unchanged.
    """
    position = results['position']
    target_aa = results['original_aa']
    print(f"\nAnalyse pour la position {position} (AA original: {target_aa})")

    # Short preview: first 20 and last 10 characters of each sequence
    full_seq = results['sequence']
    print(f"Séquence originale: {full_seq[:20]}...{full_seq[-10:]}")
    masked_seq = results['masked_sequence']
    print(f"Séquence masquée:   {masked_seq[:20]}...{masked_seq[-10:]}")
    print("\n" + "=" * 80)

    # Rank amino acids by probability, highest first, for both runs
    probs_plain = results['unmasked_probs']
    probs_masked = results['masked_probs']
    ranking_plain = sorted(probs_plain, key=probs_plain.get, reverse=True)
    ranking_masked = sorted(probs_masked, key=probs_masked.get, reverse=True)

    print("\nTop 5 prédictions - Séquence NON masquée:")
    for rank, aa in enumerate(ranking_plain[:5], start=1):
        suffix = " ← ORIGINAL" if aa == results['original_aa'] else ""
        print(f"{rank}. {aa}: {probs_plain[aa]:.4f}{suffix}")

    print("\nTop 5 prédictions - Séquence masquée:")
    for rank, aa in enumerate(ranking_masked[:5], start=1):
        suffix = " ← ORIGINAL" if aa == results['original_aa'] else ""
        print(f"{rank}. {aa}: {probs_masked[aa]:.4f}{suffix}")

    # Probability assigned to the true residue with and without the mask
    p_plain = results['unmasked_probs'][results['original_aa']]
    p_masked = results['masked_probs'][results['original_aa']]

    print(f"\nProbabilité de l'AA original ({results['original_aa']}):")
    print(f"  Non masqué: {p_plain:.4f}")
    print(f"  Masqué:     {p_masked:.4f}")
    print(f"  Différence: {p_masked - p_plain:.4f}")

    # Summary statistics over the per-amino-acid probability shifts
    diffs = list(results['prob_differences'].values())
    diff_mean = np.mean(diffs)
    diff_std = np.std(diffs)
    diff_max = max(diffs)
    diff_min = min(diffs)

    print("\nStatistiques des différences (masqué - non masqué):")
    print(f"  Moyenne: {diff_mean:.4f}")
    print(f"  Écart-type: {diff_std:.4f}")
    print(f"  Min: {diff_min:.4f}")
    print(f"  Max: {diff_max:.4f}")

    return results

# Main experiment: compare masked and unmasked predictions at a few
# positions of each test protein and print a report for every position.
print("=" * 60, "TEST: MASKED vs NON-MASKED PREDICTIONS", "=" * 60, sep="\n")

# Test sequences (you can modify these)
test_sequences = [
    "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG",
    "ARNDCEQGHILKMFPSTWYV",  # All amino acids
    "MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGSAQVKAHGKKVADALTNAVAHVDDMPNALSALSDLHAHKLRVDPVNFKLLSHCLLVTLAAHLPAEFTPAVHASLDKFLASVSTVLTSKYR"  # Hemoglobin alpha
]

# Probe several positions in every sequence
for seq_idx, sequence in enumerate(test_sequences):
    print("\n" + "=" * 60)
    print(f"SÉQUENCE {seq_idx + 1}: {len(sequence)} AA")
    print("=" * 60)

    # Positions at the 25 %, 50 % and 75 % marks of the sequence
    test_positions = [(quarter * len(sequence)) // 4 for quarter in (1, 2, 3)]

    for pos in test_positions:
        if pos >= len(sequence):
            # Only reachable for an empty sequence, where every mark is 0.
            continue

        # Best effort: report a failure for this position and keep going.
        try:
            results = test_masked_vs_unmasked_prediction(
                sequence, pos, ppl_model, ppl_tokenizer, device
            )
            analyze_results(results)
            print("\n" + "-" * 60)
        except Exception as err:
            print(f"Erreur pour la position {pos}: {err}")

print("\n" + "=" * 60)
print("TEST TERMINÉ")
print("=" * 60)

[INFO] Using device: cuda
LOADING MODELS
[INFO] Loading perplexity model: facebook/esm2_t6_8M_UR50D


Some weights of the model checkpoint at facebook/esm2_t6_8M_UR50D were not used when initializing EsmForMaskedLM: ['esm.embeddings.position_embeddings.weight']
- This IS expected if you are initializing EsmForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


[SUCCESS] Perplexity model loaded successfully
[SUCCESS] All models loaded successfully

TEST: MASKED vs NON-MASKED PREDICTIONS

SÉQUENCE 1: 65 AA

Analyse pour la position 16 (AA original: E)
Séquence originale: MKTVRQERLKSIVRILERSK...TPRGYVLAGG
Séquence masquée:   MKTVRQERLKSIVRIL<mas...TPRGYVLAGG


Top 5 prédictions - Séquence NON masquée:
1. E: 0.8151 ← ORIGINAL
2. L: 0.0313
3. R: 0.0228
4. A: 0.0222
5. S: 0.0183

Top 5 prédictions - Séquence masquée:
1. L: 0.1894
2. E: 0.1223 ← ORIGINAL
3. A: 0.1011
4. R: 0.0998
5. S: 0.0756

Probabilité de l'AA original (E):
  Non masqué: 0.8151
  Masqué:     0.1223
  Différence: -0.6927

Statistiques des différences (masqué - non masqué):
  Moyenne: -0.0000
  Écart-type: 0.1627
  Min: -0.6927
  Max: 0.1581

------------------------------------------------------------

Analyse pour la position 32 (AA original: S)
Séquence originale: MKTVRQERLKSIVRILERSK...TPRGYVLAGG
Séquence masquée:   MKTVRQERLKSIVRILERSK...TPRGYVLAGG


Top 5 prédictions - Séque