
# ProT‑Diff Debug‑First Notebook (ipynb 工作流)

这本 **ipynb** 设计成“先能跑‑再做强”的调试流程：先最小闭环（能采样出序列并存表），再逐步切换为论文设置。  
> 两种模式：  
> - **Quick Debug Mode**：不下载模型，几秒完成端到端（用伪嵌入器和小 U‑Net），确保环境 OK；  
> - **Full Mode**：使用 **ProtT5‑XL‑UniRef50** 编码/解码 + 扩散模型，贴合论文（可用 CPU/GPU）。

你可以先跑 Debug 模式通关，再切到 Full 模式进行正式训练/采样/筛选。  
最后会输出 **`generated_seqs.csv`**（待筛选）与 **`logs/`、`checkpoints/`** 等。


## 1) 环境准备（可重复执行）

In [None]:

# !pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# !pip install -U transformers einops numpy pandas scikit-learn tqdm matplotlib

import os, math, json, random, time
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Probe the optional `transformers` dependency. Any failure while importing it
# (missing package, broken install, ...) just switches the flag off so the
# notebook can still run with the fake embedder.
try:
    from transformers import T5EncoderModel, T5ForConditionalGeneration, AutoTokenizer
except Exception:
    TRANSFORMERS_OK = False
else:
    TRANSFORMERS_OK = True

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("DEVICE:", DEVICE, "Transformers:", TRANSFORMERS_OK)

# Output folders used later in the notebook; safe to re-run.
for _out_dir in ("logs", "checkpoints"):
    os.makedirs(_out_dir, exist_ok=True)


## 2) 配置（Debug/Full 两套参数一键切换）

In [None]:

from dataclasses import dataclass

@dataclass
class DiffConfig:
    """Single source of hyper-parameters for both Debug and Full mode.

    The notebook creates one module-level instance (``cfg``) and the other
    cells read their settings from it.
    """

    max_len: int = 48  # sequences are cut to, and embeddings padded to, this many positions
    dim: int = 1024  # per-position embedding width (channel count fed to the U-Net)
    train_steps: int = 2000  # passed as T (number of diffusion timesteps) to Diffusion
    sample_steps: int = 200  # number of reverse steps used when sampling in Full mode
    predict_x0: bool = True  # True: loss target is x0; False: loss target is the noise eps
    lr: float = 2e-4  # AdamW learning rate
    batch_size: int = 16  # DataLoader batch size in Full mode (Debug mode uses 4)
    epochs: int = 2  # training epochs in Full mode
    seed: int = 42  # seed applied to random, numpy and torch
    uniform_noise_sampling: bool = False  # True: draw diffusion noise from U(-1, 1) instead of N(0, 1)
    save_every: int = 1  # write a checkpoint whenever epoch % save_every == 0

    debug_mode: bool = True  # master switch between Debug and Full mode
    debug_dim: int = 128  # U-Net hidden width in Debug mode (Full mode uses 256)
    debug_epochs: int = 1  # training epochs in Debug mode
    debug_sample_n: int = 32  # number of sequences sampled in Debug mode

    sample_n: int = 200  # number of sequences sampled in Full mode
    ckpt_path: str = "checkpoints/diffusion.pt"  # checkpoint file written by Full-mode training

cfg = DiffConfig()
print(cfg)

# Apply the same seed to every RNG the notebook touches.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(cfg.seed)
if DEVICE == "cuda":
    torch.cuda.manual_seed_all(cfg.seed)


## 3) 数据：占位与加载（把你的训练序列替换进来）

In [None]:

# Placeholder training set: four short peptides repeated 64 times (256 items).
toy_train = 64 * [
    "GIGKFLKKAKKFGKAFVKILKK",
    "LLKKLLKKLLKKLL",
    "VQGLLKALKKAL",
    "KKLLKLLKLLKLLK",
]

def load_txt_lines(path):
    """Return the non-blank lines of a text file, stripped of whitespace.

    Args:
        path: Location of the text file (one entry per line).

    Returns:
        A list with one stripped string per non-blank line, or ``None`` when
        the file does not exist.
    """
    try:
        # Open directly instead of exists()+open(): no window in which the
        # file can vanish between the check and the read. The explicit
        # encoding keeps the result independent of the platform locale.
        with open(path, "r", encoding="utf-8") as f:
            return [ln.strip() for ln in f if ln.strip()]
    except (FileNotFoundError, NotADirectoryError):
        return None

# Prefer the user's own sequences; a missing or empty file falls back to the toy set.
user_train = load_txt_lines("data/train.txt")
if not user_train:
    train_seqs = toy_train
    print(f"Using toy set: {len(train_seqs)} sequences")
else:
    train_seqs = user_train
    print(f"Loaded {len(train_seqs)} sequences from data/train.txt")

# The 20 canonical amino-acid one-letter codes.
AA = set("ACDEFGHIKLMNPQRSTVWY")
def clean_seq(s):
    """Normalise one raw sequence string.

    Strips surrounding whitespace, upper-cases, removes every character that
    is not in ``AA`` and truncates the result to ``cfg.max_len`` characters.
    """
    normalized = s.strip().upper()
    kept = filter(AA.__contains__, normalized)
    return "".join(kept)[:cfg.max_len]

# Clean each sequence exactly once and filter on the cleaned value, so the
# cleaning work is not repeated for the length test.
train_seqs = [s for s in map(clean_seq, train_seqs) if 5 <= len(s) <= cfg.max_len]
print("After clean:", len(train_seqs))


## 4) 编码器/解码器：Debug 伪嵌入 vs ProtT5‑XL（Full）

In [None]:

class FakeEmbedder:
    """Model-free stand-in for the real encoder/decoder (Debug mode).

    ``encode`` turns each sequence into a pseudo-random ``(max_len, dim)``
    embedding whose unused trailing rows are zero. ``decode`` maps embeddings
    back to amino-acid strings by comparing rows with random prototypes.
    The two are not inverses of each other: this class only exercises shapes
    and plumbing.
    """

    def __init__(self, dim=None, max_len=None):
        """Store the embedding width and padded length.

        Args:
            dim: Embedding width; ``None`` means "use ``cfg.dim``".
            max_len: Padded sequence length; ``None`` means "use ``cfg.max_len``".
        """
        # Resolve defaults when the object is built rather than when the class
        # is defined, so the current cfg values are the ones picked up.
        self.dim = cfg.dim if dim is None else dim
        self.max_len = cfg.max_len if max_len is None else max_len

    def encode(self, seqs):
        """Embed a batch of sequences.

        Args:
            seqs: Iterable of amino-acid strings.

        Returns:
            float32 tensor of shape ``(len(seqs), max_len, dim)``; rows beyond
            each sequence's length are all zeros.
        """
        import zlib  # local import keeps this cell self-contained

        embs = []
        for s in seqs:
            L = min(len(s), self.max_len)
            # crc32 gives the same seed in every interpreter run; the built-in
            # hash() of a str is salted per process and would not.
            seed = zlib.crc32(s.encode("utf-8")) & 0xFFFFFFFF
            rng = np.random.RandomState(seed)
            e = rng.randn(L, self.dim).astype(np.float32) * 0.2
            # Small position-dependent ramp added on top of the noise.
            pos = np.linspace(0, 1, L, dtype=np.float32)[:, None]
            e += pos * 0.1
            pad = np.zeros((self.max_len - L, self.dim), dtype=np.float32)
            embs.append(np.vstack([e, pad]))
        if not embs:
            # np.stack rejects an empty list; return an empty batch instead.
            return torch.zeros(0, self.max_len, self.dim, dtype=torch.float32)
        return torch.tensor(np.stack(embs), dtype=torch.float32)

    def decode(self, embs):
        """Turn embeddings into amino-acid strings.

        Rows whose L2 norm is at most 0.05 are treated as padding and dropped.
        Every remaining row is assigned the residue of its best-matching
        prototype; the 20 prototypes are drawn afresh on every call.

        Args:
            embs: Tensor of shape ``(B, L, D)``.

        Returns:
            List of ``B`` strings; ``"G"`` is used when no row survives.
        """
        B, L, D = embs.shape
        thr = 0.05
        proto = torch.randn(20, D, device=embs.device)
        idx2aa = list("ACDEFGHIKLMNPQRSTVWY")
        outs = []
        for b in range(B):
            x = embs[b]
            mask = (x.norm(dim=-1) > thr)
            x = x[mask]
            if x.numel() == 0:
                outs.append("G")
                continue
            sim = torch.matmul(x, proto.T)  # rows x 20
            ids = sim.argmax(dim=-1).tolist()
            outs.append(''.join(idx2aa[i] for i in ids))
        return outs

class ProtT5Wrapper:
    """ProtT5-XL-UniRef50 encoder/decoder pair used in Full mode.

    ``encode`` maps sequences to fixed-size ``(cfg.max_len, D)`` embeddings
    with zero rows at padded positions; ``decode`` feeds embeddings to the
    seq2seq model as encoder outputs and generates sequences from them.
    """

    # ProtT5 expects the rare residues U/Z/O/B to be replaced by X.
    _RARE_TO_X = str.maketrans("UZOB", "XXXX")

    def __init__(self, device=DEVICE):
        """Load tokenizer, encoder and seq2seq model onto ``device``."""
        self.device = device
        from transformers import T5EncoderModel, T5ForConditionalGeneration, AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50")
        # NOTE(review): the seq2seq model presumably carries its own copy of
        # the encoder; sharing it could save memory - confirm before changing.
        self.encoder = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50").to(device)
        self.decoder = T5ForConditionalGeneration.from_pretrained("Rostlab/prot_t5_xl_uniref50").to(device)
        self.encoder.eval(); self.decoder.eval()

    @torch.no_grad()
    def encode(self, seqs):
        """Embed a batch of amino-acid strings.

        Args:
            seqs: List of amino-acid strings (one-letter codes, no spaces).

        Returns:
            float32 CPU tensor of shape ``(B, cfg.max_len, D)``; positions
            beyond a sequence's tokens are all zeros.
        """
        # The ProtT5 vocabulary is per-residue: residues must be separated by
        # spaces, otherwise a whole peptide is tokenized as unknown text.
        spaced = [" ".join(s.translate(self._RARE_TO_X)) for s in seqs]
        toks = self.tokenizer(spaced, return_tensors="pt", padding=True, truncation=True)
        toks = {k: v.to(self.device) for k, v in toks.items()}
        out = self.encoder(**toks).last_hidden_state  # (B, L, D)
        # Hidden states at padded token positions are not zero by themselves;
        # zero them so that decode() can recognise padding by row norm.
        out = out * toks["attention_mask"].unsqueeze(-1).to(out.dtype)
        B, L, D = out.shape
        if L > cfg.max_len:
            out = out[:, :cfg.max_len, :]
            L = cfg.max_len
        if L < cfg.max_len:
            pad = torch.zeros(B, cfg.max_len - L, D, device=out.device, dtype=out.dtype)
            out = torch.cat([out, pad], dim=1)
        return out.float().cpu()

    @torch.no_grad()
    def decode(self, embs):
        """Generate sequences from embeddings.

        Args:
            embs: Tensor of shape ``(B, L, D)``; rows with (near-)zero norm
                are treated as padding.

        Returns:
            List of ``B`` strings with all whitespace removed.
        """
        from transformers.modeling_outputs import BaseModelOutput
        x = embs.to(self.device)
        # Rows with (near-)zero norm are padding: hide them from cross-attention.
        mask = (x.norm(dim=-1) > 1e-4).long()
        # Keep at least one visible position so no sample is fully masked.
        mask[:, 0] = 1
        enc_out = BaseModelOutput(last_hidden_state=x)
        ids = self.decoder.generate(encoder_outputs=enc_out, attention_mask=mask,
                                    max_length=cfg.max_len+2)
        texts = self.tokenizer.batch_decode(ids, skip_special_tokens=True)
        # Tokens are decoded separated by spaces; downstream code expects a
        # contiguous residue string (length and repeat checks count characters).
        return ["".join(t.split()) for t in texts]

def build_embedder():
    """Pick the embedder for the current run.

    Returns a ``FakeEmbedder`` in Debug mode or when ``transformers`` could
    not be imported. Otherwise tries to build a ``ProtT5Wrapper`` and falls
    back to ``FakeEmbedder`` (after printing the error) if that fails.
    """
    use_fake = cfg.debug_mode or not TRANSFORMERS_OK
    if not use_fake:
        try:
            print("Loading ProtT5‑XL…")
            return ProtT5Wrapper()
        except Exception as e:
            print("ProtT5 load failed, fallback to FakeEmbedder:", e)
            return FakeEmbedder()
    print("Using FakeEmbedder (Debug Mode ON or transformers not available).")
    return FakeEmbedder()

embedder = build_embedder()


## 5) 扩散模型：轻量 1D U‑Net（预测 x0）

In [None]:

class ResidualBlock(nn.Module):
    """Two (GroupNorm -> SiLU -> Conv1d k=3) stages wrapped in a skip connection.

    Operates on ``(B, ch, L)`` tensors and keeps their shape. GroupNorm uses
    8 groups, so ``ch`` must be divisible by 8.
    """

    def __init__(self, ch):
        super().__init__()
        stages = []
        for _ in range(2):
            stages.extend([
                nn.GroupNorm(8, ch),
                nn.SiLU(),
                nn.Conv1d(ch, ch, kernel_size=3, padding=1),
            ])
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        residual = self.net(x)
        return x + residual

class TinyUNet1D(nn.Module):
    """Small 1D convolutional denoiser.

    Input and output are ``(B, L, D)``; internally the tensor is transposed
    to ``(B, D, L)`` for the Conv1d layers. The length is halved once by a
    strided convolution and restored by nearest-neighbour upsampling.
    ``forward`` takes only the noisy tensor.
    """

    def __init__(self, dim_in, dim_model, predict_x0=True):
        super().__init__()
        self.predict_x0 = predict_x0
        width = dim_model
        self.inp = nn.Conv1d(dim_in, width, kernel_size=1)
        down_layers = [
            ResidualBlock(width),
            nn.Conv1d(width, width, kernel_size=2, stride=2),  # L -> L // 2
            ResidualBlock(width),
        ]
        self.down = nn.Sequential(*down_layers)
        self.bot = ResidualBlock(width)
        up_layers = [
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv1d(width, width, kernel_size=3, padding=1),
            ResidualBlock(width),
        ]
        self.up = nn.Sequential(*up_layers)
        self.out = nn.Conv1d(width, dim_in, kernel_size=1)

    def forward(self, x):
        h = self.inp(x.transpose(1, 2))  # (B, C, L)
        target_len = h.size(-1)
        u = self.up(self.bot(self.down(h)))
        # An odd L loses one position in the stride-2 conv; resize back.
        if u.size(-1) != target_len:
            u = F.interpolate(u, size=target_len, mode="nearest")
        return self.out(u).transpose(1, 2)  # (B, L, D)

def make_model():
    """Build the denoiser on ``DEVICE``.

    The input width is ``cfg.dim``; the hidden width is ``cfg.debug_dim`` in
    Debug mode and 256 otherwise.
    """
    if cfg.debug_mode:
        hidden = cfg.debug_dim
    else:
        hidden = 256
    net = TinyUNet1D(cfg.dim, hidden, predict_x0=cfg.predict_x0)
    return net.to(DEVICE)

model = make_model()
# Total number of parameters, reported in millions.
sum_p = sum(map(torch.numel, model.parameters()))
print("Model params:", sum_p/1e6, "M")


## 6) 扩散过程工具（sqrt schedule / DDPM 采样）

In [None]:

def sqrt_schedule(T):
    """Return ``T`` beta values that grow like the square root of time.

    ``beta_i = 0.01 * sqrt(u_i)`` with ``u_i`` evenly spaced in
    ``[1e-4, 1]``, clamped to ``[1e-6, 0.02]``.
    """
    grid = torch.linspace(1e-4, 1.0, T)
    raw = grid.pow(0.5).mul(1e-2)
    return raw.clamp(min=1e-6, max=0.02)

class Diffusion:
    """Forward noising, training loss and reverse sampling for the latent model.

    Args:
        T: Number of diffusion timesteps.
        predict_x0: If True the model output is the clean sample x0 (this is
            also the loss target); otherwise it is the noise eps.
        uniform_noise: If True noise is drawn from U(-1, 1) instead of N(0, 1).

    The model is always called as ``model(x)``; no timestep is passed to it.
    """

    def __init__(self, T, predict_x0=True, uniform_noise=False):
        self.T = T
        self.predict_x0 = predict_x0
        self.betas = sqrt_schedule(T).to(DEVICE)
        self.alphas = 1. - self.betas
        self.alphas_bar = torch.cumprod(self.alphas, dim=0)
        self.uniform = uniform_noise

    def _noise_like(self, ref):
        """Draw noise with the shape, dtype and device of ``ref``."""
        if self.uniform:
            return torch.empty_like(ref).uniform_(-1, 1)
        return torch.randn_like(ref)

    def q_sample(self, x0, t):
        """Noise ``x0`` to timestep ``t``.

        Args:
            x0: Clean batch, shape ``(B, L, D)``.
            t: Integer timesteps, shape ``(B,)``.

        Returns:
            ``(x_t, eps)`` where ``x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps``.
        """
        a_bar = self.alphas_bar[t].view(-1, 1, 1)
        eps = self._noise_like(x0)
        return a_bar.sqrt()*x0 + (1 - a_bar).sqrt()*eps, eps

    def p_losses(self, model, x0, t):
        """MSE between the model output on ``x_t`` and x0 (or eps)."""
        xt, eps = self.q_sample(x0, t)
        pred = model(xt)
        target = x0 if self.predict_x0 else eps
        return F.mse_loss(pred, target)

    @torch.no_grad()
    def p_sample_loop(self, model, shape, steps=200):
        """Run the reverse process on a (possibly strided) timestep grid.

        Args:
            model: Denoiser, called as ``model(x)``.
            shape: ``(B, L, D)`` of the batch to generate.
            steps: Number of reverse steps, evenly spaced between ``T - 1``
                and 0.

        Returns:
            Tensor of the requested shape.
        """
        B, L, D = shape
        x = self._noise_like(torch.empty(B, L, D, device=self.betas.device))
        ts = torch.linspace(self.T-1, 0, steps).long().tolist()
        prev_ts = ts[1:] + [-1]  # -1 marks "fully denoised"
        for ti, tp in tqdm(list(zip(ts, prev_ts)), desc="Sampling", leave=False):
            a_bar = self.alphas_bar[ti]
            out = model(x)
            # Interpret the output the same way p_losses trains it.
            if self.predict_x0:
                x0 = out
            else:
                x0 = (x - (1 - a_bar).sqrt()*out) / (a_bar.sqrt() + 1e-8)
            if tp < 0:
                # Last step: the posterior collapses onto the x0 estimate.
                x = x0
                continue
            # Gaussian posterior q(x_tp | x_ti, x0) for the jump ti -> tp,
            # using the effective alpha/beta of the whole jump so that
            # skipped timesteps are accounted for.
            a_bar_prev = self.alphas_bar[tp]
            alpha_eff = a_bar / a_bar_prev
            beta_eff = 1 - alpha_eff
            denom = (1 - a_bar).clamp_min(1e-8)
            mean = (a_bar_prev.sqrt()*beta_eff/denom)*x0 \
                + (alpha_eff.sqrt()*(1 - a_bar_prev)/denom)*x
            var = (beta_eff*(1 - a_bar_prev)/denom).clamp_min(0)
            x = mean + var.sqrt()*self._noise_like(x)
        return x

# Shared diffusion helper configured from cfg.
diff = Diffusion(
    cfg.train_steps,
    predict_x0=cfg.predict_x0,
    uniform_noise=cfg.uniform_noise_sampling,
)


## 7) 数据集封装：把序列 → (48, 1024) 嵌入

In [None]:

class EmbedDataset(Dataset):
    """Dataset over raw sequence strings.

    Items are returned as plain strings. The embedder is stored on the
    instance but is not used by ``__getitem__``.
    """

    def __init__(self, seqs, embedder):
        self.seqs, self.embedder = seqs, embedder

    def __getitem__(self, idx):
        return self.seqs[idx]

    def __len__(self):
        return len(self.seqs)

def collate_embed(batch):
    """Collate a list of sequence strings by encoding it with the global ``embedder``."""
    return embedder.encode(batch)

ds = EmbedDataset(train_seqs, embedder)
# Debug mode uses a fixed batch of 4; incomplete final batches are dropped.
dl = DataLoader(
    ds,
    batch_size=cfg.batch_size if not cfg.debug_mode else 4,
    shuffle=True,
    collate_fn=collate_embed,
    drop_last=True,
)
print("Batches:", len(dl))


## 8) 训练（先 Debug 快跑一遍）

In [None]:

def train(model, dl, epochs, ckpt):
    """Train ``model`` with the diffusion loss and checkpoint it.

    Args:
        model: Denoiser to optimise (switched to train mode).
        dl: Iterable yielding batches of clean embeddings.
        epochs: Number of passes over ``dl``.
        ckpt: File path for ``torch.save``.

    Reads the globals ``cfg`` (lr, save_every), ``diff`` and ``DEVICE``.
    """
    model.train()
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
    step = 0
    for ep in range(1, epochs + 1):
        pbar = tqdm(dl, desc=f"Epoch {ep}/{epochs}")
        for batch in pbar:
            x0 = batch.to(DEVICE)
            n_items = x0.size(0)
            # One random timestep per sample.
            t = torch.randint(0, diff.T, (n_items,), device=DEVICE)
            loss = diff.p_losses(model, x0, t)
            opt.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            step += 1
            pbar.set_postfix(loss=float(loss))
        if ep % cfg.save_every != 0:
            continue
        payload = {"model": model.state_dict(), "cfg": cfg.__dict__}
        torch.save(payload, ckpt)
        print("Saved:", ckpt)

# Full mode trains with the configured epochs/checkpoint; Debug mode uses the
# short schedule and a separate checkpoint file.
if not cfg.debug_mode:
    print(">>> Full training …")
    train(model, dl, epochs=cfg.epochs, ckpt=cfg.ckpt_path)
else:
    print(">>> Debug training …")
    train(model, dl, epochs=cfg.debug_epochs, ckpt="checkpoints/debug.pt")


## 9) 采样 + 解码 → 待筛选序列 `generated_seqs.csv`

In [None]:

@torch.no_grad()
def sample_and_decode(n=100):
    """Sample ``n`` latents in one batch and decode them to sequences.

    Uses 50 reverse steps in Debug mode and ``cfg.sample_steps`` otherwise.
    Returns whatever ``embedder.decode`` returns for the sampled batch.
    """
    model.eval()
    n_steps = 50 if cfg.debug_mode else cfg.sample_steps
    latents = diff.p_sample_loop(model, (n, cfg.max_len, cfg.dim), steps=n_steps)
    # Move to CPU before handing the batch to the embedder.
    if DEVICE != "cpu":
        latents = latents.float().cpu()
    return embedder.decode(latents)

# Number of sequences to draw depends on the mode.
if cfg.debug_mode:
    N = cfg.debug_sample_n
else:
    N = cfg.sample_n
gen = sample_and_decode(N)
df = pd.DataFrame({"seq": gen})
df.to_csv("generated_seqs.csv", index=False)
print("Saved generated_seqs.csv with", len(df), "rows")
df.head()


## 10) 快速体检过滤（长度/重复/净电荷/R+K占比等）

In [None]:

# Side-chain pKa values used by net_charge.
pka = {"K":10.5, "R":12.5, "H":6.0, "D":3.9, "E":4.1}
def net_charge(seq, ph=7.4):
    """Approximate net charge of ``seq`` at the given pH.

    K/R/H each contribute ``1 / (1 + 10**(ph - pKa))`` and D/E each subtract
    ``1 / (1 + 10**(pKa - ph))``. Only these five side chains are counted.
    """
    pos = 0
    neg = 0
    for aa in seq:
        if aa in "KRH":
            pos += 1/(1+10**(ph-pka.get(aa, 0)))
        if aa in "DE":
            neg += 1/(1+10**(pka.get(aa, 14)-ph))
    return pos - neg

def has_long_repeat(seq, k=6):
    """Return True if some residue occurs ``k`` times in a row in ``seq``."""
    return any(aa * k in seq for aa in set(seq))

def rk_ratio(seq):
    """Fraction of residues in ``seq`` that are R or K (0.0 for an empty sequence)."""
    if not seq:
        return 0.0
    hits = 0
    for aa in seq:
        if aa in "RK":
            hits += 1
    return hits / len(seq)

# keep_default_na=False: with pandas' default NA parsing, the peptide "NA"
# (Asn-Ala) and empty strings are read back as NaN, and the string-based
# apply() calls below would then fail on a float.
df = pd.read_csv("generated_seqs.csv", keep_default_na=False, dtype={"seq": str})
df["len"] = df.seq.str.len()
df["charge"] = df.seq.apply(net_charge)
df["rk_ratio"] = df.seq.apply(rk_ratio)
df["long_repeat"] = df.seq.apply(has_long_repeat)
# Keep sequences of length 5-48 with no long homopolymer run, a positive net
# charge and at most 40% R+K.
filtered = df[
    (df["len"].between(5, 48)) &
    (~df["long_repeat"]) &
    (df["charge"] > 0) &
    (df["rk_ratio"] <= 0.40)
].copy()

filtered.to_csv("generated_seqs.filtered.csv", index=False)
print("After quick filters:", len(filtered), "kept; saved -> generated_seqs.filtered.csv")
filtered.head()



## 11) 切到 Full 模式（贴论文设置）

1. 在 **配置** 单元格把 `debug_mode=False`，并调整 `epochs`、`batch_size`、`sample_n`；  
2. 确保安装 `transformers` 并可以拉取 **Rostlab/prot_t5_xl_uniref50**；  
3. 把你的训练集写到 `data/train.txt`（每行一个序列，长度 5–48）；  
4. 从 **第 2 节**（配置）开始依次重新运行到 **第 10 节**（第 2、3 节必须重跑，`debug_mode` 的修改和 `data/train.txt` 才会生效），得到 `generated_seqs.filtered.csv`；  
5. 后续把该 CSV 喂给判别器 + 完整理化筛选（另附 notebook）。

> 论文关键点（用于参数对齐）：“latent 固定为 `(48,1024)`；扩散时间步数 T=2000（对应配置项 `train_steps`，不是优化器的迭代次数）；采样时下采样为 200 步（对应 `sample_steps`）；
> 解码前根据行范数阈值去掉接近 0 的 padding 行”。
