# SMC-guided D3PM generation with a simple prefix reward
In this notebook, we guide sampling using Sequential Monte Carlo (SMC) to encourage the first four residues to match the target prefix MSTQ.

In [2]:
# Imports
import os
import torch
import numpy as np
from pprint import pprint

from evodiff.pretrained import D3PM_UNIFORM_38M
from evodiff.smc_generate import generate_d3pm, generate_d3pm_smc, prefix_reward_mstq, batch_prefix_rewards

# Select device
if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

print(f"Using device: {device}")

  import pkg_resources


Using device: cpu


In [3]:
# Load D3PM model (uniform 38M) and tokenizer
model, collater, tokenizer, scheme, dt, Q_bar, Q = D3PM_UNIFORM_38M(return_all=True)
model = model.eval().to(device)
Q_bar = Q_bar.to(device)
Q = Q.to(device)
print("Scheme:", scheme, "Timesteps:", dt, "Tokenizer.K:", tokenizer.K)

sohl-dickstein
Scheme: d3pm Timesteps: 500 Tokenizer.K: 26


In [4]:
# Baseline sampling without SMC (ancestral D3PM)
seq_len = 64
batch_size = 30
with torch.no_grad():
    sample_base, strings_base = generate_d3pm(model, tokenizer, Q, Q_bar, dt, seq_len, batch_size=batch_size, device=str(device))

# Compute rewards and quick stats
rewards_base = [prefix_reward_mstq(s) for s in strings_base]
match_base = sum(1 for s in strings_base if s[:4] == 'MSTQ')
print(f"Baseline: exact MSTQ matches: {match_base}/{batch_size}; avg reward: {np.mean(rewards_base):.2f}")
print("Sample baseline sequences (first 5):")
pprint(strings_base[:5])

  0%|          | 1/499 [00:05<48:21,  5.83s/it]


KeyboardInterrupt: 

In [None]:
# SMC-guided sampling with different configurations
seq_len = 10
batch_size = 20

configs = [
    {"name": "SMC alpha=1.0 every=1", "reward_scale": 1.0, "smc_every": 1},
    {"name": "SMC alpha=3.0 every=1", "reward_scale": 3.0, "smc_every": 1},
    {"name": "SMC alpha=1.0 every=10", "reward_scale": 1.0, "smc_every": 10},
]

results = []
for cfg in configs:
    with torch.no_grad():
        sample_smc, strings_smc, rewards_smc = generate_d3pm_smc(
            model, tokenizer, Q, Q_bar, dt, seq_len, batch_size=batch_size, device=str(device),
            reward_scale=cfg["reward_scale"], smc_every=cfg["smc_every"]
)
    match = sum(1 for s in strings_smc if s[:4] == 'MSTQ')
    avg_reward = float(torch.mean(rewards_smc).cpu().item())
    results.append((cfg["name"], match, avg_reward))
    print(f"{cfg['name']}: exact MSTQ matches: {match}/{batch_size}; avg reward: {avg_reward:.2f}")
    print("Sample sequences (first 3):")
    pprint(strings_smc[:3])

NameError: name 'torch' is not defined

## Interprétation des effets de SMC et du scaling de la récompense
- Lorsque le scaling (alpha) augmente, la pondération exp(alpha * reward) privilégie davantage les particules dont le préfixe se rapproche de MSTQ. On observe donc une hausse de la fréquence de préfixes exacts MSTQ et une augmentation de la récompense moyenne, au prix d’une diversité potentiellement moindre (plus de duplications lors du rééchantillonnage).
- À l’inverse, diminuer alpha (jusqu’à 0) revient à annuler l’effet de la récompense et on récupère la distribution non guidée du modèle.
- La fréquence de SMC (smc_every) contrôle à quel rythme on « corrige » la population. SMC à chaque pas (every=1) pousse fortement vers MSTQ rapidement; des SMC plus espacés (par ex. every=10) laissent plus de liberté au modèle entre deux corrections, ce qui maintient davantage de diversité mais ralentit la convergence vers le préfixe cible.

En pratique, on choisit alpha et smc_every pour équilibrer ciblage du préfixe et diversité. Un bon point de départ est alpha ∈ [1, 2] avec every=1 pour une contrainte forte, ou every∈{5,10} pour un compromis.