## SETUP

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
from pathlib import Path
from typing import List, Tuple, Dict, Union
import warnings
warnings.filterwarnings('ignore')

# Hugging Face transformers
from transformers import EsmTokenizer, EsmForMaskedLM #type: ignore

import sys, pathlib, os
project_root = pathlib.Path.home() / "projets" / "protein-generation"
sys.path.append(str(project_root))

from scripts.utils import *

# Device configuration: use the first CUDA GPU when one is available, otherwise the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
def load_perplexity_model(
    ppl_model_name: str, 
    device: str = "cuda"
) -> Tuple[EsmForMaskedLM, EsmTokenizer]:
    """Load a pretrained ESM masked language model together with its tokenizer.

    Args:
        ppl_model_name: Identifier handed to ``from_pretrained`` for both the
            tokenizer and the model.
        device: Device the model is moved to.

    Returns:
        ``(model, tokenizer)``; the model is in evaluation mode and already
        placed on ``device``.
    """
    print("Loading perplexity model...")

    tokenizer = EsmTokenizer.from_pretrained(ppl_model_name)

    # The model is only used for scoring: switch to eval mode, then move it
    model = EsmForMaskedLM.from_pretrained(ppl_model_name).eval().to(device)

    print("✓ Perplexity model loaded")
    return model, tokenizer

# Model names
PPL_MODEL_NAME = "facebook/esm2_t6_8M_UR50D"  # Checkpoint id used for perplexity calculation

# Module-level placeholders for the model and tokenizer; both are overwritten
# by the load call just below
ppl_model = None
ppl_tokenizer = None

# Load the perplexity model and its tokenizer onto the selected device
ppl_model, ppl_tokenizer = load_perplexity_model(ppl_model_name=PPL_MODEL_NAME, device=device)

print("Models loaded successfully!")

## FUNCTIONS

In [None]:

### VOCABULARY

class ProteinVocabularyMask:
    """Integer vocabulary for protein sequences: 20 amino acids plus a mask token.

    Amino acids get the ids 0..19 in the order of ``ALPHABET``; id 20 is the
    mask token, which decodes to ``'X'``.
    """

    def __init__(self):
        self.ALPHABET = "ACDEFGHIKLMNPQRSTVWY"
        self.MASK_TOKEN = 20
        self.VOCAB_SIZE = 21  # 20 amino acids + 1 mask token
        self.vocab_to_id = dict(zip(self.ALPHABET, range(len(self.ALPHABET))))
        self.id_to_vocab = dict(enumerate(self.ALPHABET))
        # The mask token has no entry in vocab_to_id: it can be decoded, not encoded
        self.id_to_vocab[self.MASK_TOKEN] = 'X'

    def encode_sequence(self, seq):
        """Encode an amino-acid string as a 1-D ``torch.long`` tensor of ids.

        Raises:
            KeyError: if ``seq`` contains a character outside ``ALPHABET``.
        """
        ids = []
        for residue in seq:
            ids.append(self.vocab_to_id[residue])
        return torch.tensor(ids, dtype=torch.long)

    def decode_sequence(self, tokens):
        """Decode an iterable of ids (tensor, array or list) into a string."""
        letters = (self.id_to_vocab[int(token)] for token in tokens)
        return ''.join(letters)

    def encode_batch(self, sequences):
        """Encode several sequences and stack them into one 2-D tensor."""
        encoded_rows = list(map(self.encode_sequence, sequences))
        return torch.stack(encoded_rows)


### NOISE

class NoiseSchedule:
    """Map a diffusion time ``t`` to a noise level (masking probability).

    Supported schedules:
        * ``'linear'``: ``t``
        * ``'cosine'``: ``1 - cos(t * pi / 2)``
        * ``'sqrt'``:   ``sqrt(t)``

    All three go from 0 at ``t = 0`` to 1 at ``t = 1``.
    """

    SCHEDULE_TYPES = ('linear', 'cosine', 'sqrt')

    def __init__(self, schedule_type='cosine'):
        """Store the schedule type.

        Raises:
            ValueError: if ``schedule_type`` is not a supported schedule.
        """
        if schedule_type not in self.SCHEDULE_TYPES:
            raise ValueError(
                f"Unknown schedule_type {schedule_type!r}; "
                f"expected one of {self.SCHEDULE_TYPES}"
            )
        self.schedule_type = schedule_type

    def get_noise_level(self, t):
        """Return the noise level at time ``t``.

        The linear schedule returns ``t`` unchanged; the other schedules go
        through NumPy and therefore return NumPy values.

        Raises:
            ValueError: if ``schedule_type`` was changed to an unsupported
                value after construction (instead of silently returning None).
        """
        if self.schedule_type == 'linear':
            return t
        if self.schedule_type == 'cosine':
            return 1 - np.cos(t * np.pi / 2)
        if self.schedule_type == 'sqrt':
            return np.sqrt(t)
        raise ValueError(
            f"Unknown schedule_type {self.schedule_type!r}; "
            f"expected one of {self.SCHEDULE_TYPES}"
        )


In [None]:
def _masked_residue_surprisal(token_ids, ppl_model, ppl_tokenizer, vocab, device, chunk_size=64):
    """Score each residue by ``-log p(residue | rest of the sequence)`` under the masked LM.

    The sequence is tokenized once. For each position a copy of the token ids
    is made with that single position replaced by the tokenizer's mask token,
    and the copies are scored in batches of at most ``chunk_size``.

    Returns:
        1-D CPU tensor with one score per residue (higher = less expected).

    Raises:
        ValueError: if the tokenizer returns too few tokens for a
            one-token-per-residue layout.
    """
    seq_len = token_ids.shape[0]
    sequence_str = vocab.decode_sequence(token_ids.cpu().tolist())

    encoded = ppl_tokenizer(sequence_str, return_tensors="pt")
    base_inputs = {key: value.to(device) for key, value in encoded.items()}
    base_ids = base_inputs["input_ids"]
    if base_ids.shape[1] < seq_len + 1:
        raise ValueError(
            f"Tokenizer returned {base_ids.shape[1]} tokens for a sequence of "
            f"{seq_len} residues; expected one token per residue after a "
            f"leading special token"
        )

    # Residue i sits at token index i + 1 (one special token comes first), so
    # the token ids of the original residues can be read off the encoding.
    target_ids = base_ids[0, 1:seq_len + 1]
    mask_id = ppl_tokenizer.mask_token_id

    scores = []
    with torch.no_grad():
        for start in range(0, seq_len, chunk_size):
            stop = min(start + chunk_size, seq_len)
            rows = torch.arange(stop - start, device=device)
            token_pos = rows + start + 1
            # One copy of the sequence per position, each with one mask
            batch = {
                key: value.expand(stop - start, -1).clone()
                for key, value in base_inputs.items()
            }
            batch["input_ids"][rows, token_pos] = mask_id
            logits = ppl_model(**batch).logits
            log_probs = torch.log_softmax(logits[rows, token_pos].float(), dim=-1)
            scores.append(-log_probs[rows, target_ids[start:stop]])

    return torch.cat(scores).cpu()


def forward_diffusion(x, t, noise_schedule, mask_token, ppl_model, ppl_tokenizer, vocab, device):
    """Forward diffusion that masks the residues the language model finds least likely.

    For each sequence the number of masked positions is drawn from
    ``Binomial(L, noise_schedule.get_noise_level(t_i))``. The masked positions
    are then the ones with the highest score ``-log p(residue | rest)`` under
    ``ppl_model`` (a per-position surprisal, not a sequence perplexity).

    Args:
        x: ``(B, L)`` tensor of token ids understood by ``vocab``.
        t: Iterable of ``B`` diffusion times, one per sequence.
        noise_schedule: Object exposing ``get_noise_level(t)``.
        mask_token: Id written at the masked positions.
        ppl_model: Masked language model; called with the tokenizer output
            and expected to return an object with a ``logits`` attribute.
        ppl_tokenizer: Tokenizer matching ``ppl_model``.
        vocab: Vocabulary exposing ``decode_sequence``.
        device: Device on which ``ppl_model`` is evaluated.

    Returns:
        ``(xt, mask)``: a masked copy of ``x`` and the ``(B, L)`` boolean mask
        of the positions that were replaced. ``x`` itself is not modified.
    """
    B, L = x.shape

    # TODO: position selection is deterministic given the sequence; only the
    # number of masked positions is random.
    # Binomial because it is a sum of L i.i.d. Bernoulli(mask_prob) draws.
    count = torch.tensor(L, dtype=torch.float)
    num_masks = []
    for ti in t:
        # The schedule may return a NumPy scalar: convert to a Python float so
        # the probability tensor gets the same dtype as the count.
        mask_prob = float(noise_schedule.get_noise_level(float(ti)))
        mask_prob = min(max(mask_prob, 0.0), 1.0)
        drawn = torch.binomial(count, torch.tensor(mask_prob, dtype=count.dtype))
        num_masks.append(int(drawn.item()))

    mask = torch.zeros(B, L, dtype=torch.bool, device=x.device)

    for b in range(B):
        if num_masks[b] == 0:
            continue
        if num_masks[b] >= L:
            # Everything is masked: no need to rank positions
            mask[b] = True
            continue

        scores = _masked_residue_surprisal(x[b], ppl_model, ppl_tokenizer, vocab, device)
        _, top_indices = torch.topk(scores, num_masks[b])
        mask[b, top_indices.to(mask.device)] = True

    # Apply the masks on a copy
    xt = x.clone()
    xt[mask] = mask_token

    return xt, mask

In [None]:

### DENOISING TRANSFORMER

class DenoisingTransformer(nn.Module):
    """Transformer encoder that predicts amino-acid logits at every position.

    The input embedding is the sum of a token embedding, a learned position
    embedding and an MLP embedding of the diffusion time. The output head
    has 20 classes (the amino acids only, no mask token).
    """

    def __init__(self, vocab_size, seq_length, d_model=256, n_heads=8, n_layers=6, dropout=0.1):
        super().__init__()

        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(seq_length, d_model)
        self.time_emb = nn.Sequential(
            nn.Linear(1, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=n_heads,
                dim_feedforward=d_model * 4,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=n_layers,
        )

        # Only the 20 amino acids can be predicted
        self.output_head = nn.Linear(d_model, 20)

    def forward(self, x, t):
        """Return logits of shape ``(B, L, 20)``.

        Args:
            x: ``(B, L)`` tensor of token ids.
            t: ``(B, 1)`` tensor of diffusion times.
        """
        _, seq_len = x.shape
        positions = torch.arange(seq_len, device=x.device)

        token_part = self.token_emb(x)
        position_part = self.pos_emb(positions).unsqueeze(0)
        # (B, 1) -> (B, 1, 1) -> (B, 1, d_model), then repeated along the sequence
        time_part = self.time_emb(t.unsqueeze(1)).expand(-1, seq_len, -1)

        hidden = token_part + position_part + time_part
        return self.output_head(self.transformer(hidden))




In [None]:

### INFERENCE

@torch.no_grad()
def denoise_step(model, x, noise_schedule, t_current, t_next, mask_token):
    """Run one reverse-diffusion step from ``t_current`` to ``t_next``.

    Each currently masked position is revealed independently with probability
    ``(noise(t_current) - noise(t_next)) / noise(t_current)`` (1.0 when the
    current noise level is at most 1e-6). A revealed position receives a
    token sampled from the model's predicted distribution.

    Returns:
        A new tensor; ``x`` itself is left untouched.
    """
    batch_size, seq_len = x.shape

    level_now = noise_schedule.get_noise_level(t_current)
    level_after = noise_schedule.get_noise_level(t_next)

    # Reveal probability (the original author checked this formula by hand)
    reveal_prob = (level_now - level_after) / level_now if level_now > 1e-6 else 1.0

    # Model prediction at the current time
    time_input = torch.full((batch_size, 1), t_current, device=x.device)
    token_probs = F.softmax(model(x, time_input), dim=-1)

    # Only positions that are still masked can be revealed
    still_masked = x == mask_token
    draws = torch.rand(batch_size, seq_len, device=x.device)
    to_reveal = (draws < reveal_prob) & still_masked

    result = x.clone()
    if not to_reveal.any():
        return result

    result[to_reveal] = torch.multinomial(token_probs[to_reveal], 1).squeeze(-1)
    return result


@torch.no_grad()
def generate_sequences(model, n_samples, seq_length, noise_schedule, dt, mask_token):
    """Generate sequences by iterative denoising, starting from all-mask input.

    Time runs from ``t = 1`` down to ``t = 0`` in steps of ``dt``; each step
    calls ``denoise_step``. Any position still masked after the loop is
    sampled once more from the model at ``t = 0``.

    Args:
        model: Denoising model; it is switched to evaluation mode.
        n_samples: Number of sequences to generate.
        seq_length: Length of each generated sequence.
        noise_schedule: Object exposing ``get_noise_level(t)``.
        dt: Positive time step.
        mask_token: Id of the mask token.

    Returns:
        ``(n_samples, seq_length)`` long tensor on the model's device.

    Raises:
        ValueError: if ``dt`` is not strictly positive (the loop below would
            otherwise never terminate).
    """
    # `not dt > 0` also rejects NaN
    if not dt > 0:
        raise ValueError(f"dt must be a positive step size, got {dt!r}")

    model.eval()

    device = next(model.parameters()).device
    x = torch.full((n_samples, seq_length), mask_token, dtype=torch.long, device=device)

    # Iterative denoising
    t = 1.0
    while t > 0:
        t_next = max(t - dt, 0.0)
        x = denoise_step(model, x, noise_schedule, t, t_next, mask_token)
        t = t_next

    # Clean up any remaining masks
    remaining = x == mask_token
    if remaining.any():
        t_tensor = torch.zeros((n_samples, 1), device=x.device)
        probs = F.softmax(model(x, t_tensor), dim=-1)
        x[remaining] = torch.multinomial(probs[remaining], 1).squeeze(-1)

    return x




In [None]:

### LOSS AND TRAINING

def compute_loss(model, x0, noise_schedule, mask_token, ppl_model, ppl_tokenizer, vocab, device):
    """Compute the denoising loss for one batch.

    A random time is drawn per sequence, the batch is corrupted with
    ``forward_diffusion`` and the model is asked to recover the original
    tokens. The loss is the mean cross-entropy over the masked positions only.

    Returns:
        ``(loss, mask_ratio)`` where ``mask_ratio`` is the fraction of masked
        positions in the batch. When nothing was masked the loss is a zero
        tensor with ``requires_grad=True`` and the ratio is ``0.0``.
    """
    batch_size, _ = x0.shape

    # One uniform random time per sequence
    timesteps = torch.rand(batch_size, device=x0.device)

    # Corrupt the batch
    noisy, masked = forward_diffusion(
        x0, timesteps, noise_schedule, mask_token, ppl_model, ppl_tokenizer, vocab, device
    )

    # Model predictions on the corrupted batch
    predictions = model(noisy, timesteps.unsqueeze(1))

    n_masked = masked.sum()
    if n_masked == 0:
        return torch.tensor(0.0, device=x0.device, requires_grad=True), 0.0

    # Only masked positions contribute to the loss
    masked_loss = F.cross_entropy(predictions[masked], x0[masked], reduction='mean')
    return masked_loss, n_masked.item() / masked.numel()


def train_step(model, batch, optimizer, noise_schedule, mask_token, ppl_model, ppl_tokenizer, vocab, device):
    """Run a single optimization step on ``batch``.

    Returns:
        ``(loss_value, mask_ratio)`` with the loss as a Python float.
    """
    model.train()

    step_loss, ratio = compute_loss(
        model,
        batch,
        noise_schedule,
        mask_token,
        ppl_model,
        ppl_tokenizer,
        vocab,
        device,
    )

    # Clear the previous gradients, backpropagate, update the weights
    optimizer.zero_grad()
    step_loss.backward()
    optimizer.step()

    return step_loss.item(), ratio


def train_model(model, dataloader, optimizer, noise_schedule_fn, n_epochs, mask_token, ppl_model, ppl_tokenizer, vocab, device):
    """Full training loop.

    Runs ``n_epochs`` passes over ``dataloader`` with one ``train_step`` per
    batch, and prints the mean loss every 10 epochs.

    Returns:
        List with the mean training loss of each epoch.
    """
    history = []

    for epoch in tqdm(range(n_epochs), desc="Training"):
        batch_losses = []

        for batch_data in dataloader:
            # Each item holds the token tensor first; move it to the model's device
            tokens = batch_data[0].to(next(model.parameters()).device)
            step_loss, _ = train_step(
                model, tokens, optimizer, noise_schedule_fn, mask_token,
                ppl_model, ppl_tokenizer, vocab, device,
            )
            batch_losses.append(step_loss)

        avg_loss = np.mean(batch_losses)
        history.append(avg_loss)

        if not epoch % 10:
            print(f"Epoch {epoch}: Loss = {avg_loss:.4f}")

    return history



## RUN

In [None]:
# %% Main experiment workflow
print("=== STARTING EXPERIMENT ===")
# Resolve the config file from project_root (defined in the setup cell) so the
# path follows the same location that was put on sys.path, instead of a
# user-specific absolute path
config_path = project_root / "configs" / "base_config.yaml"
config = load_experiment_config(str(config_path))
config = setup_experiment_directory(config)

print(f"Experiment name: {config['experiment']['name']}")
print(f"Experiment directory: {config['exp_dir']}")

# Immediately save the configuration
save_experiment_config(config, config['exp_dir'])

In [None]:

# Device and data setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the PyTorch and NumPy random generators
torch.manual_seed(config['training']['seed'])
np.random.seed(config['training']['seed'])

# Load the training sequences and keep at most n_samples of them
protein_data = pd.read_csv(config['data']['input_file'])
sequences = protein_data['sequence'].tolist()
sequences = sequences[:config['training']['n_samples']]
print(f'{len(sequences)} séquences')
sequence_lengths = {len(seq) for seq in sequences}
print(sequence_lengths)

# encode_batch stacks the encoded sequences into one tensor, which requires a
# single common length; fail here with an explicit message instead
if len(sequence_lengths) > 1:
    raise ValueError(
        f"All training sequences must have the same length, "
        f"found lengths {sorted(sequence_lengths)}"
    )

# Dataset creation
vocabulary = ProteinVocabularyMask()
encoded = vocabulary.encode_batch(sequences)
dataset = torch.utils.data.TensorDataset(encoded)
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=config['training']['batch_size'],
                                         shuffle=True)

# Model creation
model = DenoisingTransformer(
    vocab_size=vocabulary.VOCAB_SIZE,
    seq_length=config['model']['seq_length'],
    d_model=config['model']['d_model'],
    n_heads=config['model']['n_heads'],
    n_layers=config['model']['n_layers'],
    dropout=config['model']['dropout']
    ).to(device)

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=config['training']['learning_rate'])

# Noise schedule
noise_schedule = NoiseSchedule(config['diffusion']['noise_schedule'])
    

In [None]:
print("\n=== TRAINING ===")
# train_model also needs the perplexity model, its tokenizer, the vocabulary
# and the device, because they are passed down to forward_diffusion
losses = train_model(
    model=model,
    dataloader=dataloader,
    optimizer=optimizer,
    noise_schedule_fn=noise_schedule,
    n_epochs=config['training']['n_epochs'],
    mask_token=vocabulary.MASK_TOKEN,
    ppl_model=ppl_model,
    ppl_tokenizer=ppl_tokenizer,
    vocab=vocabulary,
    device=device
)

# Plot losses
plot_and_save_losses(config, losses)

In [None]:
# Sequence generation
print("\n=== GENERATION ===")
generated_tokens = generate_sequences(
    model=model,
    n_samples=config['generation']['n_samples'],
    seq_length=config['model']['seq_length'],
    noise_schedule=noise_schedule,
    dt=config['generation']['dt'],
    mask_token=vocabulary.MASK_TOKEN
)

# Decode sequences; move the tokens to the CPU once, since decoding reads
# them one element at a time
generated_sequences = [vocabulary.decode_sequence(seq) for seq in generated_tokens.cpu()]

# Display and save results
display_sample_sequences(config, generated_sequences)

In [None]:

print("\n=== SAVING RESULTS ===")
# Hand the config, model, loss history and generated sequences to the project's save helper
save_results(config, model, losses, generated_sequences)

# Final summary
print(f"\n=== EXPERIMENT COMPLETE ===")
print(f"Name: {config['experiment']['name']}")
print(f"Directory: {config['exp_dir']}")
print(f"Training sequences: {len(sequences)}")
print(f"Epochs: {config['training']['n_epochs']}")
print(f"Final loss: {losses[-1]:.4f}")  # mean loss of the last epoch
print(f"Generated sequences: {len(generated_sequences)}")

# List everything currently present in the experiment directory
print("\nSaved files:")
for fname in os.listdir(config['exp_dir']):
    print(f"  - {fname}")

print("\nExperiment finished successfully!")