In [1]:
!pip install fair-esm
!pip install biopython

Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl.metadata (37 kB)
Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m921.6 kB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fair-esm
Successfully installed fair-esm-2.0.0
Collecting biopython
  Downloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Downloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: biopython
Successfully installed biopython-1.85


In [2]:
from Bio import SeqIO
seq_path='/kaggle/input/uniref50-sub/uniref50_subsample.fasta'
sequences=[]
for seq_record in SeqIO.parse(seq_path, "fasta"):
    sequences.append(str(seq_record.seq))
print(len(sequences))

1000000


In [3]:
# ============================================================
# VAE + Fresh Small Surrogate (Identity adapter) — LEAKY TF MODE
#  - CE: teacher forcing with memory='encoder'  (정보 누설, CE≈0)
#  - Align: surrogate memory vs encoder memory (MSEᵐ + COS)
# ============================================================

import os, math, time, json, random, datetime
from typing import List, Tuple
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import esm

# -------------------
# Config
# -------------------
SEED   = 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(SEED); random.seed(SEED)
if DEVICE.type == "cuda": torch.cuda.manual_seed_all(SEED)

MAX_LEN     = 512
BATCH_SIZE  = 128
TRAIN_RATIO = 0.9

EMB_DIM     = 256
LATENT_DIM  = 256
NUM_LAYERS  = 4      # encoder/decoder layers (set to match your ckpt)
NUM_HEADS   = 4
FFN_DIM     = 512
DROPOUT     = 0.10

# Loss weights
W_CE  = 1.0
W_MSE = 5.0
W_COS = 5.0

# -------------------
# Data (EOS-less TF)
# -------------------
class ProteinDataset(Dataset):
    def __init__(self, sequences: List[str], alphabet, max_len: int = MAX_LEN):
        self.alphabet = alphabet
        self.max_len  = max_len
        self.tokens, self.lengths = [], []
        for s in sequences:
            ids = [alphabet.get_idx(c) for c in s][:max_len]
            if len(ids) == 0: continue
            self.tokens.append(torch.tensor(ids, dtype=torch.long))
            self.lengths.append(len(ids))
    def __len__(self): return len(self.tokens)
    def __getitem__(self, idx):
        return self.tokens[idx], self.lengths[idx]

def _collate_eosless(batch, pad_idx: int):
    seqs, lens = zip(*batch)
    x = pad_sequence(seqs, batch_first=True, padding_value=pad_idx)
    Lmax = x.size(1)
    m = torch.zeros((len(seqs), Lmax), dtype=torch.bool)
    for i, L in enumerate(lens): m[i, :L] = True
    return x, m

def make_loaders(sequences, alphabet, batch_size=BATCH_SIZE, train_ratio=TRAIN_RATIO, max_len=MAX_LEN, seed=SEED):
    PAD = getattr(alphabet, "pad_idx", alphabet.get_idx("<pad>"))
    ds = ProteinDataset(sequences, alphabet, max_len=max_len)
    n_tr = int(len(ds)*train_ratio); n_va = len(ds)-n_tr
    g = torch.Generator().manual_seed(seed)
    ds_tr, ds_va = random_split(ds, [n_tr, n_va], generator=g)
    dl_tr = DataLoader(ds_tr, batch_size=batch_size, shuffle=True,  num_workers=2,
                       pin_memory=(DEVICE.type=="cuda"),
                       collate_fn=lambda b: _collate_eosless(b, PAD), drop_last=True)
    dl_va = DataLoader(ds_va, batch_size=batch_size, shuffle=False, num_workers=2,
                       pin_memory=(DEVICE.type=="cuda"),
                       collate_fn=lambda b: _collate_eosless(b, PAD), drop_last=False)
    return dl_tr, dl_va, PAD

# -------------------
# Models
# -------------------
class SmallTransformer(nn.Module):
    def __init__(self, vocab_size: int, emb_dim: int, layers: int, heads: int,
                 ffn_dim: int, max_len: int, pad_idx: int):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.pos = nn.Parameter(torch.zeros(1, max_len, emb_dim))
        enc_layer = nn.TransformerEncoderLayer(
            d_model=emb_dim, nhead=heads, dim_feedforward=ffn_dim,
            batch_first=True, activation="gelu", dropout=DROPOUT
        )
        self.enc = nn.TransformerEncoder(enc_layer, layers)
        self.ln  = nn.LayerNorm(emb_dim)
    def forward(self, x: torch.Tensor):
        mask = x != self.emb.padding_idx
        h = self.emb(x) + self.pos[:, :x.size(1), :]
        h = self.enc(h, src_key_padding_mask=~mask)
        return self.ln(h), mask

class VAETransformerDecoder(nn.Module):
    def __init__(self, encoder: SmallTransformer, vocab_size: int,
                 latent_dim: int, emb_dim: int,
                 num_layers: int, num_heads: int, ffn_dim: int,
                 max_len: int, pad_token: int, bos_token: int):
        super().__init__()
        self.encoder   = encoder
        self.pad_token = pad_token
        self.bos_token = bos_token
        self.to_mu      = nn.Linear(emb_dim, latent_dim)
        self.to_logvar  = nn.Linear(emb_dim, latent_dim)
        self.latent2emb = nn.Linear(latent_dim, emb_dim)
        self.dec_emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_token)
        self.dec_pos = nn.Parameter(torch.zeros(1, max_len, emb_dim))
        dec_layer = nn.TransformerDecoderLayer(
            d_model=emb_dim, nhead=num_heads, dim_feedforward=ffn_dim,
            dropout=DROPOUT, batch_first=True
        )
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers)
        self.out     = nn.Linear(emb_dim, vocab_size)

class Z2MemorySurrogate(nn.Module):
    """Small surrogate: d_model=256, layers=2, heads=4."""
    def __init__(self, d_model: int, latent_dim: int, max_len: int,
                 layers: int = 2, heads: int = 4, ffn_dim: int = None, dropout: float = DROPOUT):
        super().__init__()
        if ffn_dim is None: ffn_dim = 3 * d_model
        self.pos   = nn.Parameter(torch.zeros(1, max_len, d_model))
        self.token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.z_proj= nn.Linear(latent_dim, d_model)
        self.z_ln  = nn.LayerNorm(d_model)
        enc_layer  = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=heads, dim_feedforward=ffn_dim,
            batch_first=True, activation="gelu", dropout=dropout
        )
        self.enc   = nn.TransformerEncoder(enc_layer, num_layers=layers)
        self.out_ln= nn.LayerNorm(d_model)
    def forward(self, z, mask_bool, causal_self: bool = False):
        B, L = mask_bool.shape
        base = self.token.expand(B, L, -1) + self.pos[:, :L, :]
        zemb = self.z_ln(self.z_proj(z)).unsqueeze(1).expand(-1, L, -1)
        h = base + zemb
        src_mask = None
        if causal_self:
            src_mask = torch.triu(torch.full((L, L), float('-inf'), device=h.device), diagonal=1)
        h = self.enc(h, mask=src_mask, src_key_padding_mask=~mask_bool)
        return self.out_ln(h), mask_bool

class VAEWithSurrogateBundle(nn.Module):
    def __init__(self, vae, surrogate, sur_adapter):
        super().__init__()
        self.vae, self.surrogate, self.sur_adapter = vae, surrogate, sur_adapter

    def _teacher_logits(self, x: torch.Tensor, x_mask: torch.Tensor,
                        memory: torch.Tensor, memory_mask: torch.Tensor,
                        z: torch.Tensor, inject_z: bool = True):
        B, L = x.size()
        dec_in = torch.full((B, L), self.vae.bos_token, device=x.device, dtype=torch.long)
        dec_in[:, 1:] = x[:, :-1]
        tgt = self.vae.dec_emb(dec_in) + self.vae.dec_pos[:, :L, :]
        if inject_z:
            z_emb = self.vae.latent2emb(z).unsqueeze(1).expand(-1, L, -1)
            tgt = tgt + z_emb
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(L).to(x.device)
        h_dec = self.vae.decoder(
            tgt=tgt, memory=memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=~x_mask,
            memory_key_padding_mask=(~memory_mask) if (memory_mask is not None) else None
        )
        return self.vae.out(h_dec)

    def forward(self, x: torch.Tensor,
                use_surrogate: bool = False,
                deterministic_z: bool = False,
                inject_z: bool = True):
        x_mask = (x != self.vae.pad_token)
        h_enc, enc_mask = self.vae.encoder(x)
        pooled = (h_enc * enc_mask.unsqueeze(-1)).sum(1) / enc_mask.sum(1, True).clamp_min(1)
        mu, logvar = self.vae.to_mu(pooled), self.vae.to_logvar(pooled)
        z = mu if deterministic_z else mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
        if not use_surrogate:
            logits = self._teacher_logits(x, x_mask, h_enc, enc_mask, z, inject_z=inject_z)
            return logits, mu, logvar, (h_enc, enc_mask, z, None, None)
        sur_mem, sur_mask = self.surrogate(z, enc_mask, causal_self=False)
        mem = sur_mem if (self.sur_adapter is None) else self.sur_adapter(sur_mem)
        logits = self._teacher_logits(x, x_mask, mem, sur_mask, z, inject_z=inject_z)
        return logits, mu, logvar, (h_enc, enc_mask, z, sur_mem, sur_mask)

    def save(self, path: str, extra_meta: dict = None):
        payload = {
            "bundle_version": 2,
            "vae":         self.vae.state_dict(),
            "surrogate":   self.surrogate.state_dict(),
            "sur_adapter": None,  # Identity
            "meta": {"saved_at": datetime.datetime.now().isoformat()}
        }
        if extra_meta: payload["meta"].update(extra_meta)
        torch.save(payload, path)
        print(f"[save] VAEWithSurrogateBundle → {path}")

# -------------------
# Loss helpers
# -------------------
def masked_ce(logits: torch.Tensor, tgt: torch.Tensor, mask: torch.Tensor, ignore_index: int, label_smoothing: float = 0.0):
    B,L,V = logits.shape
    flat_logits = logits.view(B*L, V)
    flat_tgt    = tgt.view(B*L)
    valid       = mask.view(B*L)
    if valid.sum() == 0: return flat_logits.new_zeros(())
    return F.cross_entropy(flat_logits[valid], flat_tgt[valid],
                           ignore_index=ignore_index, label_smoothing=label_smoothing)

def masked_mse_mean(a: torch.Tensor, b: torch.Tensor, mask: torch.Tensor):
    diff2 = (a - b).pow(2).mean(-1)   # mean over channels
    denom = mask.float().sum().clamp_min(1)
    return (diff2 * mask.float()).sum() / denom

def masked_cosine_loss(a: torch.Tensor, b: torch.Tensor, mask: torch.Tensor):
    cs = F.cosine_similarity(a, b, dim=-1, eps=1e-8)
    loss = 1.0 - cs
    denom = mask.float().sum().clamp_min(1)
    return (loss * mask.float()).sum() / denom

# -------------------
# Build bundle (load VAE ckpt['model_sd']; surrogate fresh; adapter Identity)
# -------------------
def build_bundle_from_vae_ckpt(vae_ckpt_path: str, alphabet) -> VAEWithSurrogateBundle:
    PAD = getattr(alphabet, "pad_idx", alphabet.get_idx("<pad>"))
    BOS = getattr(alphabet, "bos_idx", alphabet.get_idx("<cls>"))
    vocab = len(alphabet.all_toks)

    enc = SmallTransformer(vocab_size=vocab, emb_dim=EMB_DIM,
                           layers=NUM_LAYERS, heads=NUM_HEADS, ffn_dim=FFN_DIM,
                           max_len=MAX_LEN, pad_idx=PAD).to(DEVICE)
    vae = VAETransformerDecoder(encoder=enc, vocab_size=vocab,
                                latent_dim=LATENT_DIM, emb_dim=EMB_DIM,
                                num_layers=NUM_LAYERS, num_heads=NUM_HEADS, ffn_dim=FFN_DIM,
                                max_len=MAX_LEN, pad_token=PAD, bos_token=BOS).to(DEVICE)

    ckpt = torch.load(vae_ckpt_path, map_location="cpu")
    assert "model_sd" in ckpt, "VAE ckpt must contain 'model_sd'."
    miss = vae.load_state_dict(ckpt["model_sd"], strict=False)
    print(f"[VAE] loaded from 'model_sd'  missing={len(miss.missing_keys)}  unexpected={len(miss.unexpected_keys)}")

    sur = Z2MemorySurrogate(d_model=256, latent_dim=LATENT_DIM, max_len=MAX_LEN,
                            layers=2, heads=4, ffn_dim=3*256, dropout=DROPOUT).to(DEVICE)
    adapter = nn.Identity().to(DEVICE)  # ← 완전 제거(Identity)
    print("[SUR] initialized small surrogate (2L,4H,256D)")
    print("[ADAPTER] Identity (no params)")

    bundle = VAEWithSurrogateBundle(vae, sur, adapter).to(DEVICE)
    return bundle

# -------------------
# Teacher-forced logits with selectable memory (LEAKY TF)
# -------------------
def tf_logits_with_memory(bundle, x, xmask, z, mem, mem_mask, inject_z=True):
    return bundle._teacher_logits(x, xmask, mem, mem_mask, z, inject_z=inject_z)

# -------------------
# Training (LEAKY TF): CE uses encoder memory; Align mem_sur ↔ h_enc
# -------------------
def finetune_surrogate_adapter_leaky(
    bundle: VAEWithSurrogateBundle,
    sequences,
    alphabet,
    out_best="./vae_sur_leaky_best.pt",
    out_last="./vae_sur_leaky_last.pt",
    epochs=3,
    batch_size=128,
    lr=1e-4,
    weight_decay=0.01,
    label_smoothing=0.0,
    grad_clip=1.0,
    log_every=20,
):
    dl_tr, dl_va, PAD = make_loaders(sequences, alphabet, batch_size=batch_size)

    # Freeze VAE, train surrogate (+ adapter if any; Identity → 0 params)
    for p in bundle.vae.parameters(): p.requires_grad = False
    for p in bundle.surrogate.parameters(): p.requires_grad = True
    adp_params = list(getattr(bundle.sur_adapter, "parameters", lambda: [])())
    for p in adp_params: p.requires_grad = True

    groups = [{"params":[p for p in bundle.surrogate.parameters() if p.requires_grad]}]
    if any(p.requires_grad for p in adp_params):
        groups.append({"params":[p for p in adp_params if p.requires_grad]})

    opt = torch.optim.AdamW(groups, lr=lr, weight_decay=weight_decay, betas=(0.9, 0.95))
    scaler = torch.amp.GradScaler('cuda', enabled=(DEVICE.type=="cuda"))

    print(f"[info] trainable params = {sum(p.numel() for g in groups for p in g['params']):,}")
    print(f"[info] frozen    params = {sum(p.numel() for p in bundle.vae.parameters()):,}")

    best = float("inf")

    for ep in range(1, epochs+1):
        bundle.train()
        pbar = tqdm(dl_tr, desc=f"[Train(leaky enc-mem) ep={ep}/{epochs}]", dynamic_ncols=True)
        run = defaultdict(float)

        for i, (x, xmask) in enumerate(pbar, start=1):
            x, xmask = x.to(DEVICE), xmask.to(DEVICE)
            opt.zero_grad(set_to_none=True)

            # Encode → z
            h_enc, enc_mask = bundle.vae.encoder(x)
            pooled = (h_enc * enc_mask.unsqueeze(-1)).sum(1) / enc_mask.sum(1, True).clamp_min(1)
            mu, logvar = bundle.vae.to_mu(pooled), bundle.vae.to_logvar(pooled)
            z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)

            # Surrogate mem (to be aligned)
            sur_mem, sur_mask = bundle.surrogate(z, enc_mask, causal_self=False)
            mem_sur = sur_mem if (bundle.sur_adapter is None) else bundle.sur_adapter(sur_mem)

            with torch.amp.autocast('cuda', enabled=(DEVICE.type=="cuda")):
                # ★ CE with LEAKY memory = encoder h_enc
                logits = tf_logits_with_memory(bundle, x, xmask, z, h_enc, enc_mask, inject_z=True)
                vm_ce  = (xmask & enc_mask)
                ce = masked_ce(logits, x, vm_ce, ignore_index=PAD, label_smoothing=label_smoothing)

                # Align surrogate memory ↔ encoder memory
                vm_align = (xmask & enc_mask & sur_mask)
                mse = masked_mse_mean(mem_sur, h_enc, vm_align)
                cos = masked_cosine_loss(mem_sur, h_enc, vm_align)

                loss = W_CE*ce + W_MSE*mse + W_COS*cos

            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            if grad_clip and grad_clip>0:
                torch.nn.utils.clip_grad_norm_([p for g in groups for p in g["params"]], grad_clip)
            scaler.step(opt); scaler.update()

            # logs
            run["loss"] += float(loss); run["ce"] += float(ce); run["mse"] += float(mse); run["cos"] += float(cos); run["n"] += 1
            if i % log_every == 0 or i == 1:
                pbar.set_postfix({"loss": f"{float(loss):.3f}", "ce": f"{float(ce):.4f}", "mse": f"{float(mse):.3f}", "cos": f"{float(cos):.3f}"})

        # ---- epoch summary
        n = max(1, int(run["n"]))
        tr = {k: run[k]/n for k in ["loss","ce","mse","cos"]}

        # ---- validation (deterministic z=μ)
        bundle.eval()
        val = defaultdict(float)
        with torch.no_grad():
            for x, xmask in tqdm(dl_va, desc="[Val(leaky)]", dynamic_ncols=True):
                x, xmask = x.to(DEVICE), xmask.to(DEVICE)
                h_enc, enc_mask = bundle.vae.encoder(x)
                pooled = (h_enc * enc_mask.unsqueeze(-1)).sum(1) / enc_mask.sum(1, True).clamp_min(1)
                mu, logvar = bundle.vae.to_mu(pooled), bundle.vae.to_logvar(pooled)
                z = mu  # deterministic for eval

                sur_mem, sur_mask = bundle.surrogate(z, enc_mask, causal_self=False)
                mem_sur = sur_mem if (bundle.sur_adapter is None) else bundle.sur_adapter(sur_mem)

                logits = tf_logits_with_memory(bundle, x, xmask, z, h_enc, enc_mask, inject_z=True)
                vm_ce  = (xmask & enc_mask)
                ce = masked_ce(logits, x, vm_ce, ignore_index=PAD, label_smoothing=0.0)
                vm_align = (xmask & enc_mask & sur_mask)
                mse = masked_mse_mean(mem_sur, h_enc, vm_align)
                cos = masked_cosine_loss(mem_sur, h_enc, vm_align)

                total = W_CE*ce + W_MSE*mse + W_COS*cos
                val["loss"] += float(total); val["ce"] += float(ce); val["mse"] += float(mse); val["cos"] += float(cos); val["n"] += 1

        n = max(1, int(val["n"]))
        va = {k: val[k]/n for k in ["loss","ce","mse","cos"]}

        print(f"[Ep {ep}] train: loss={tr['loss']:.3f} ce={tr['ce']:.4f} mse={tr['mse']:.3f} cos={tr['cos']:.3f} | "
              f"valid: loss={va['loss']:.3f} ce={va['ce']:.4f} mse={va['mse']:.3f} cos={va['cos']:.3f}")

        if va["loss"] < best:
            best = va["loss"]
            bundle.save(out_best, extra_meta={"best_val_total": best, "leaky_tf": "encoder"})
            print(f"[SAVE] best → {out_best}")

    bundle.save(out_last, extra_meta={"best_val_total": best, "leaky_tf": "encoder"})
    print(f"[SAVE] last → {out_last}")
    return {"best_total": best}

# -------------------
# Smoke check (eval + deterministic μ): ENC vs SUR TF metrics
# -------------------
@torch.no_grad()
def smoke_check(bundle, sequences, alphabet, batch_size=8, nbatches=2):
    dl, _, PAD = make_loaders(sequences, alphabet, batch_size=batch_size, train_ratio=1.0)

    def masked_acc(logits, tgt, mask):
        if mask.sum() == 0: return 0.0
        pred = logits.argmax(dim=-1)
        correct = ((pred == tgt) & mask).float().sum()
        total   = mask.float().sum().clamp_min(1)
        return float(correct / total)

    was_training = bundle.training
    bundle.eval()

    enc_ce = enc_acc = 0.0
    sur_ce = sur_acc = 0.0
    first_done = False

    it = iter(dl)
    for _ in tqdm(range(nbatches), desc="[Smoke|eval] ENC vs SUR (TF)", dynamic_ncols=True):
        try: x, mlen = next(it)
        except StopIteration: it = iter(dl); x, mlen = next(it)
        x, mlen = x.to(DEVICE), mlen.to(DEVICE)

        # deterministic z
        h_enc, enc_mask = bundle.vae.encoder(x)
        pooled = (h_enc * enc_mask.unsqueeze(-1)).sum(1) / enc_mask.sum(1, True).clamp_min(1)
        mu = bundle.vae.to_mu(pooled); z = mu

        # ENC TF logits (leaky)
        logits_e = bundle._teacher_logits(x, mlen, h_enc, enc_mask, z, inject_z=True)
        vm_e = (mlen & enc_mask)
        ce_e = masked_ce(logits_e, x, vm_e, ignore_index=PAD); acc_e = masked_acc(logits_e, x, vm_e)

        # SUR TF logits
        sur_mem, sur_mask = bundle.surrogate(z, enc_mask, causal_self=False)
        mem_s = sur_mem if (bundle.sur_adapter is None) else bundle.sur_adapter(sur_mem)
        logits_s = bundle._teacher_logits(x, mlen, mem_s, sur_mask, z, inject_z=True)
        vm_s = (mlen & sur_mask)
        ce_s = masked_ce(logits_s, x, vm_s, ignore_index=PAD); acc_s = masked_acc(logits_s, x, vm_s)

        enc_ce += float(ce_e); enc_acc += float(acc_e)
        sur_ce += float(ce_s); sur_acc += float(acc_s)

        if not first_done:
            B, L = x.shape
            print(f"[shapes] x={(B,L)}  h_enc={tuple(h_enc.shape)}  sur_mem={tuple(sur_mem.shape)}")
            first_done = True

    enc_ce /= max(1, nbatches); sur_ce /= max(1, nbatches)
    enc_acc/= max(1, nbatches); sur_acc/= max(1, nbatches)
    print(f"[ENC-TF]  CE={enc_ce:.4f}  PPL={math.exp(enc_ce):.2f}  ACC@1={enc_acc*100:.2f}%")
    print(f"[SUR-TF]  CE={sur_ce:.4f}  PPL={math.exp(sur_ce):.2f}  ACC@1={sur_acc*100:.2f}%")
    print(f"[Δ SUR-ENC] ΔCE={sur_ce-enc_ce:+.4f}  ΔACC={(sur_acc-enc_acc)*100:+.2f}%")

    if was_training: bundle.train()

# ============================================================
# Usage
# ============================================================
# 0) alphabet & sequences 준비
_, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# sequences = [...]  # List[str]  ← 사용자 데이터
assert 'sequences' in globals(), "Provide your `sequences: List[str]`."

# 1) VAE ckpt 로딩 + surrogate fresh (adapter=Identity)
VAE_CKPT = "/kaggle/input/esms-vae/pytorch/default2/1/vae_epoch380.pt"
bundle = build_bundle_from_vae_ckpt(VAE_CKPT, alphabet)
bundle.save("/kaggle/working/vae_sur_leaky_init.pt")

# 2) 스모크 체크 (eval + μ)
smoke_check(bundle, sequences, alphabet, batch_size=8, nbatches=2)

# 3) 파인튜닝 (LEAKY: CE with encoder memory)
stats = finetune_surrogate_adapter_leaky(
    bundle, sequences, alphabet,
    out_best="/kaggle/working/vae_sur_leaky_best.pt",
    out_last="/kaggle/working/vae_sur_leaky_last.pt",
    epochs=4, batch_size=128, lr=1e-4, log_every=20
)
print(stats)


Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t33_650M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D-contact-regression.pt


[VAE] loaded from 'model_sd'  missing=0  unexpected=0
[SUR] initialized small surrogate (2L,4H,256D)
[ADAPTER] Identity (no params)
[save] VAEWithSurrogateBundle → /kaggle/working/vae_sur_leaky_init.pt


  output = torch._nested_tensor_from_mask(
[Smoke|eval] ENC vs SUR (TF): 100%|██████████| 2/2 [00:03<00:00,  1.79s/it]

[shapes] x=(8, 512)  h_enc=(8, 512, 256)  sur_mem=(8, 512, 256)
[ENC-TF]  CE=0.0875  PPL=1.09  ACC@1=96.75%
[SUR-TF]  CE=6.6109  PPL=743.15  ACC@1=7.94%
[Δ SUR-ENC] ΔCE=+6.5234  ΔACC=-88.81%





[info] trainable params = 1,515,008
[info] frozen    params = 5,756,961


[Train(leaky enc-mem) ep=1/4]: 100%|██████████| 7031/7031 [1:08:10<00:00,  1.72it/s, loss=0.161, ce=0.1021, mse=0.009, cos=0.003]
[Val(leaky)]: 100%|██████████| 782/782 [02:53<00:00,  4.50it/s]


[Ep 1] train: loss=0.989 ce=0.0982 mse=0.130 cos=0.048 | valid: loss=0.334 ce=0.0843 mse=0.038 cos=0.012
[save] VAEWithSurrogateBundle → /kaggle/working/vae_sur_leaky_best.pt
[SAVE] best → /kaggle/working/vae_sur_leaky_best.pt


[Train(leaky enc-mem) ep=2/4]: 100%|██████████| 7031/7031 [1:08:28<00:00,  1.71it/s, loss=0.149, ce=0.0997, mse=0.007, cos=0.002]
[Val(leaky)]: 100%|██████████| 782/782 [02:54<00:00,  4.49it/s]


[Ep 2] train: loss=0.151 ce=0.0982 mse=0.008 cos=0.003 | valid: loss=0.130 ce=0.0843 mse=0.007 cos=0.002
[save] VAEWithSurrogateBundle → /kaggle/working/vae_sur_leaky_best.pt
[SAVE] best → /kaggle/working/vae_sur_leaky_best.pt


[Train(leaky enc-mem) ep=3/4]: 100%|██████████| 7031/7031 [1:08:26<00:00,  1.71it/s, loss=0.143, ce=0.0963, mse=0.007, cos=0.002]
[Val(leaky)]: 100%|██████████| 782/782 [02:53<00:00,  4.50it/s]


[Ep 3] train: loss=0.146 ce=0.0982 mse=0.007 cos=0.002 | valid: loss=0.129 ce=0.0843 mse=0.007 cos=0.002
[save] VAEWithSurrogateBundle → /kaggle/working/vae_sur_leaky_best.pt
[SAVE] best → /kaggle/working/vae_sur_leaky_best.pt


[Train(leaky enc-mem) ep=4/4]: 100%|██████████| 7031/7031 [1:08:23<00:00,  1.71it/s, loss=0.153, ce=0.1056, mse=0.007, cos=0.002]
[Val(leaky)]: 100%|██████████| 782/782 [02:53<00:00,  4.50it/s]


[Ep 4] train: loss=0.145 ce=0.0982 mse=0.007 cos=0.002 | valid: loss=0.128 ce=0.0843 mse=0.007 cos=0.002
[save] VAEWithSurrogateBundle → /kaggle/working/vae_sur_leaky_best.pt
[SAVE] best → /kaggle/working/vae_sur_leaky_best.pt
[save] VAEWithSurrogateBundle → /kaggle/working/vae_sur_leaky_last.pt
[SAVE] last → /kaggle/working/vae_sur_leaky_last.pt
{'best_total': 0.12818251283424895}
