In [None]:
!git clone https://github.com/Ahnd6474/ZART

In [None]:
%cd /kaggle/working/ZART

In [None]:
!pip install biopython

In [None]:
from Bio import SeqIO
seq_path='/kaggle/input/uniref50-sub/uniref50_subsample.fasta'
# Keep only the raw residue string of every FASTA record.
sequences = [str(record.seq) for record in SeqIO.parse(seq_path, "fasta")]
print(len(sequences))

In [None]:
from typing import Tuple, Optional

import os
import math
import torch
import torch.nn as nn

# Default model hyper-parameters used as defaults by the classes below.
DROPOUT   = 0.1    # dropout rate for encoder / decoder / surrogate transformer layers
LATENT_DIM = 256   # size of the VAE latent vector z
EMB_DIM    = 256   # token-embedding / model width (d_model)
NUM_LAYERS = 4     # number of transformer layers (VAE decoder default)
NUM_HEADS  = 8     # attention heads (VAE decoder default)
FFN_DIM    = 512   # feed-forward hidden size (VAE decoder default)
MAX_LEN    = 512   # length of the learned positional tables; also the memory length used when no surrogate exists

# ===============================
# 1) 가우시안 정렬 바이어스 모듈 (업데이트)
# - 정규화 좌표(t∈[0,1], i∈[0,1])에서 중심 m_hat_norm = a*t + δ 학습
# - 실제 메모리 길이 M에 동적으로 대응
# - 밴드 하드마스크(선택) 포함
# ===============================
class CrossDiagBias(nn.Module):
    """
    Learnable Gaussian diagonal bias for cross-attention logits.

    Learns an alignment centre i ≈ a*t + δ in normalised coordinates
    (t, i ∈ [0, 1]) and adds -α*(i - m_hat)^2, measured in memory-index units,
    to the cross-attention logits. An optional hard band cuts off keys that are
    far from the centre.

    - a = 1 + a_span * tanh(u)   (slope around 1)
    - δ = d_span * tanh(v)       (normalised offset)
    """
    def __init__(
        self,
        init_alpha: float = 0.05,
        a_span: float = 0.5,
        d_span: float = 0.15,
        band_W_tokens: int = 8,
    ):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(float(init_alpha)))  # should stay > 0 (see clamp_alpha_)
        self.a_raw = nn.Parameter(torch.zeros(1))                   # squashed by tanh -> [-1, 1]
        self.d_raw = nn.Parameter(torch.zeros(1))
        self.a_span = float(a_span)
        self.d_span = float(d_span)
        self.band_W_tokens = int(band_W_tokens)

    @torch.no_grad()
    def clamp_alpha_(self, min_alpha: float = 1e-6, max_alpha: float = 1.0):
        """Clamp ``alpha`` into ``[min_alpha, max_alpha]`` in place."""
        self.alpha.clamp_(min_alpha, max_alpha)

    def forward(self, T: int, M: int, device: torch.device) -> torch.Tensor:
        """
        Return a (T, M) additive logit mask (the Gaussian bias itself).

        Row t holds ``-alpha * (i - m_hat_t)^2`` for memory index i. Here
        ``m_hat_t = clamp(a*t_norm + δ, 0, 1) * (M - 1)`` and ``t_norm`` is t
        spread linearly over [0, 1] across the T rows. When ``band_W_tokens > 0``,
        entries farther than W from the centre are set to -inf (hard cut).

        NOTE(review): the -inf band is combined with any memory_key_padding_mask
        by the caller. If every in-band key of a row is padding, that attention
        row ends up fully masked.
        """
        t = torch.linspace(0.0, 1.0, T, device=device)[:, None]  # (T,1) normalised query position

        a = 1.0 + self.a_span * torch.tanh(self.a_raw)           # slope
        d = self.d_span * torch.tanh(self.d_raw)                 # normalised offset

        m_hat_norm = (a * t + d).clamp(0.0, 1.0)                 # (T,1)
        m_hat_idx  = m_hat_norm * (M - 1)                        # (T,1) in [0, M-1]

        alpha = self.alpha.clamp_min(1e-6)
        i_idx = torch.arange(M, device=device, dtype=m_hat_idx.dtype)[None, :]  # (1,M)
        diff  = i_idx - m_hat_idx                                               # (T,M)
        bias  = - alpha * diff.pow(2)                                           # Gaussian bias

        if self.band_W_tokens > 0:
            # The band width is given in tokens and rescaled by M/T, so it keeps its meaning when M != T.
            W = max(1, int(self.band_W_tokens * M / max(T, 1)))
            band = diff.abs() <= W                                              # (T,M) bool
            bias = bias.masked_fill(~band, float('-inf'))                       # block keys outside the band
        return bias  # (T,M)


# ===============================
# 인코더
# ===============================
class SmallTransformer(nn.Module):
    """Simple Transformer encoder used in the notebook.

    Token embedding plus a learned positional table, followed by a stack of
    ``nn.TransformerEncoderLayer`` blocks (GELU) and a final LayerNorm.
    """

    def __init__(self, vocab_size: int, emb_dim: int, layers: int, heads: int,
                 ffn_dim: int, max_len: int, pad_idx: int):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.pos = nn.Parameter(torch.zeros(1, max_len, emb_dim))  # learned absolute positions
        block = nn.TransformerEncoderLayer(
            d_model=emb_dim,
            nhead=heads,
            dim_feedforward=ffn_dim,
            batch_first=True,
            activation="gelu",
            dropout=DROPOUT,
        )
        self.enc = nn.TransformerEncoder(block, layers)
        self.ln = nn.LayerNorm(emb_dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode token ids ``x`` (B, L).

        Returns ``(hidden, valid)``: ``hidden`` is the LayerNorm-ed encoder output
        and ``valid`` is a bool mask that is True where ``x`` is not the padding id
        (0 when the embedding has no padding index).
        """
        pad = 0 if self.emb.padding_idx is None else self.emb.padding_idx
        valid = x.ne(pad)
        seq_len = x.size(1)
        hidden = self.emb(x) + self.pos[:, :seq_len, :]
        hidden = self.enc(hidden, src_key_padding_mask=valid.logical_not())
        return self.ln(hidden), valid


# ===============================
# VAE 디코더
# ===============================
class VAETransformerDecoder(nn.Module):
    """VAE model from the notebook.

    The encoder output is mean-pooled over valid positions to produce ``mu`` and
    ``logvar``. A reparameterised ``z`` is added to every decoder input
    embedding. The decoder layers use PyTorch's defaults (ReLU activation,
    ``norm_first=False``).
    """

    def __init__(self, encoder: SmallTransformer, vocab_size: int,
                 latent_dim: int = LATENT_DIM, emb_dim: int = EMB_DIM,
                 num_layers: int = NUM_LAYERS, num_heads: int = NUM_HEADS,
                 ffn_dim: int = FFN_DIM, max_len: int = MAX_LEN,
                 pad_token: int = 0, bos_token: int = 1):
        super().__init__()
        self.encoder = encoder
        self.pad_token = pad_token
        self.bos_token = bos_token

        # Posterior heads and the latent -> embedding projection.
        self.to_mu = nn.Linear(emb_dim, latent_dim)
        self.to_logvar = nn.Linear(emb_dim, latent_dim)
        self.latent2emb = nn.Linear(latent_dim, emb_dim)

        # Autoregressive decoder.
        self.dec_emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_token)
        self.dec_pos = nn.Parameter(torch.zeros(1, max_len, emb_dim))
        dec_block = nn.TransformerDecoderLayer(
            d_model=emb_dim,
            nhead=num_heads,
            dim_feedforward=ffn_dim,
            dropout=DROPOUT,
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(dec_block, num_layers)
        self.out = nn.Linear(emb_dim, vocab_size)

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        """Teacher-forced pass.

        Args:
            x: token ids (B, L).
            mask: bool mask (B, L), True == valid; its negation is used as the
                decoder's ``tgt_key_padding_mask``.

        Returns:
            ``(logits, mu, logvar, h_enc, enc_mask)``.
        """
        h_enc, enc_mask = self.encoder(x)
        n_valid = enc_mask.sum(1, keepdim=True).clamp_min(1)   # avoid division by zero
        pooled = (h_enc * enc_mask.unsqueeze(-1)).sum(1) / n_valid
        mu = self.to_mu(pooled)
        logvar = self.to_logvar(pooled)
        noise = torch.randn_like(mu)
        z = mu + noise * torch.exp(0.5 * logvar)               # reparameterisation trick

        batch, length = x.size()
        # Right shift: position 0 holds BOS and position j holds x[j-1].
        dec_in = torch.full((batch, length), self.bos_token, device=x.device, dtype=torch.long)
        dec_in[:, 1:] = x[:, :-1]
        tgt = self.dec_emb(dec_in) + self.dec_pos[:, :length, :]
        tgt = tgt + self.latent2emb(z).unsqueeze(1).expand(-1, length, -1)

        causal = nn.Transformer.generate_square_subsequent_mask(length).to(x.device)
        h_dec = self.decoder(
            tgt=tgt,
            memory=h_enc,
            tgt_mask=causal,
            tgt_key_padding_mask=~mask,
            memory_key_padding_mask=~enc_mask,
        )
        return self.out(h_dec), mu, logvar, h_enc, enc_mask


# ===============================
# surrogate (z→memory) — note: its encoder layers use PyTorch's default post-LN (norm_first=False)
# ===============================
class Z2MemorySurrogate(nn.Module):
    """Small transformer that predicts decoder memory from latent ``z``.

    Each output position starts from a shared learned token plus a learned
    positional embedding, offset by the same projected and LayerNorm-ed ``z``.
    A stack of ``nn.TransformerEncoderLayer`` blocks (GELU; PyTorch's default
    post-LN) then mixes the positions, and a final LayerNorm is applied.
    """

    def __init__(
        self,
        d_model: int,
        latent_dim: int,
        max_len: int,
        layers: int = 2,
        heads: int = 4,
        ffn_dim: Optional[int] = None,
        dropout: float = DROPOUT,
    ) -> None:
        super().__init__()
        hidden_ffn = 3 * d_model if ffn_dim is None else ffn_dim
        self.pos = nn.Parameter(torch.zeros(1, max_len, d_model))   # learned positions
        self.token = nn.Parameter(torch.zeros(1, 1, d_model))       # shared query token
        self.z_proj = nn.Linear(latent_dim, d_model)
        self.z_ln = nn.LayerNorm(d_model)
        block = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=heads,
            dim_feedforward=hidden_ffn,
            batch_first=True,
            activation="gelu",
            dropout=dropout,
        )
        self.enc = nn.TransformerEncoder(block, num_layers=layers)
        self.out_ln = nn.LayerNorm(d_model)

    def forward(
        self, z: torch.Tensor, mask_bool: torch.Tensor, causal_self: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict a memory sequence.

        Args:
            z: latent codes [B, D] (fed through ``z_proj``).
            mask_bool: [B, L] bool mask, True == valid position.
            causal_self: if True, each position may only attend to itself and
                earlier positions (upper-triangular -inf mask).

        Returns:
            ``(memory, mask_bool)`` — memory is [B, L, d_model]; the mask is
            returned unchanged.
        """
        batch, length = mask_bool.shape
        z_tok = self.z_ln(self.z_proj(z)).unsqueeze(1).expand(-1, length, -1)
        h = self.token.expand(batch, length, -1) + self.pos[:, :length, :] + z_tok
        attn_mask = (
            torch.triu(torch.full((length, length), float("-inf"), device=h.device), diagonal=1)
            if causal_self
            else None
        )
        h = self.enc(h, mask=attn_mask, src_key_padding_mask=~mask_bool)
        return self.out_ln(h), mask_bool


# ============================================
# 2) VAEWithSurrogate 확장: 바이어스/LN/게이트/체크포인트
# ============================================
class VAEWithSurrogate(nn.Module):
    """Wrapper bundling a VAE and a surrogate network.

    On top of the wrapped ``VAETransformerDecoder`` this adds:

    * ``sur_ln`` / ``sur_gate``: an optional LayerNorm and per-channel sigmoid
      gate applied to the surrogate-predicted memory;
    * ``diag_bias``: an optional ``CrossDiagBias`` added to cross-attention logits;
    * helpers that build surrogate memory, run one incremental decoding step,
      and save/restore checkpoints.
    """

    def __init__(
        self,
        vae: VAETransformerDecoder,
        surrogate: Optional[Z2MemorySurrogate] = None,
        use_sur_ln: bool = True,
        use_sur_gate: bool = True,           # per-channel gate on the surrogate memory
        use_diag_bias: bool = True,
        diag_init_alpha: float = 0.05,
        diag_a_span: float = 0.5,
        diag_d_span: float = 0.15,
        diag_band_W: int = 8,
    ) -> None:
        super().__init__()
        self.vae = vae
        self.surrogate = surrogate

        # Convenience aliases onto the wrapped VAE (kept from the original code).
        for name in [
            "encoder",
            "decoder",
            "dec_emb",
            "dec_pos",
            "latent2emb",
            "pad_token",
            "bos_token",
            "out",
        ]:
            setattr(self, name, getattr(vae, name))

        d_model = self.dec_emb.embedding_dim

        # (optional) LayerNorm + channel gate applied to the surrogate output.
        self.sur_ln   = nn.LayerNorm(d_model, elementwise_affine=True) if use_sur_ln else None
        self.use_sur_gate = bool(use_sur_gate)
        self.sur_gate = nn.Parameter(torch.full((d_model,), math.log(0.3/0.7))) if use_sur_gate else None  # sigmoid(init) == 0.3

        # (optional) Gaussian alignment bias (normalised coordinates, adapts to the memory length).
        self.diag_bias = CrossDiagBias(
            init_alpha=diag_init_alpha,
            a_span=diag_a_span,
            d_span=diag_d_span,
            band_W_tokens=diag_band_W,
        ) if use_diag_bias else None

        # Latent cache read by decode_step; training/inference loops must set it.
        self._z_cached: Optional[torch.Tensor] = None

    # -------------------------
    # Surrogate memory helper
    # -------------------------
    def build_surrogate_memory(self, z: torch.Tensor, x_gt_ids: Optional[torch.Tensor] = None):
        """
        Build cross-attention memory from latent ``z``.

        Args:
            z: latent codes, assumed [B, latent_dim].
            x_gt_ids: optional ground-truth ids [B, L] (or [L]). When given, the
                memory length and validity come from the encoder's padding mask.
                Otherwise every position of the surrogate's positional table is valid.

        Returns:
            memory (B, M, D) and mem_pad_mask (B, M), with True == PAD (ignored).
            Without a surrogate, memory is all zeros with MAX_LEN valid positions.
        """
        device = next(self.parameters()).device
        z = z.to(device)
        B = z.size(0)

        if self.surrogate is None:
            M = MAX_LEN
            memory = torch.zeros(B, M, self.dec_emb.embedding_dim, device=device)
            mem_valid = torch.ones(B, M, dtype=torch.bool, device=device)
        else:
            if x_gt_ids is not None:
                if x_gt_ids.dim() == 1:
                    x_gt_ids = x_gt_ids.unsqueeze(0)
                x_gt_ids = x_gt_ids.to(device)
                _, enc_mask = self.encoder(x_gt_ids)  # True==valid
                mem_valid = enc_mask.to(torch.bool)
            else:
                M = int(self.surrogate.pos.size(1))
                mem_valid = torch.ones(B, M, dtype=torch.bool, device=device)
            memory, _ = self.surrogate(z, mem_valid, causal_self=False)

        # (optional) normalise the surrogate output, then scale it by the channel gate.
        if self.sur_ln is not None:
            memory = self.sur_ln(memory)
        if self.use_sur_gate and self.sur_gate is not None:
            g = torch.sigmoid(self.sur_gate)  # [D]
            memory = memory * g               # matches the scale seen by K/V projections

        mem_pad_mask = ~mem_valid  # True==PAD (ignored)
        return memory, mem_pad_mask

    # -------------------------
    # One decoding step (+ Gaussian bias)
    # -------------------------
    def decode_step(self, prefix_ids: torch.Tensor,
                    memory: torch.Tensor, mem_pad_mask: torch.Tensor,
                    tokenizer=None, use_bias: bool = True,
                    target_len: Optional[int] = None) -> torch.Tensor:
        """
        Compute next-token logits for an autoregressive prefix.

        Args:
            prefix_ids: (B, T) ids decoded so far (starting with BOS).
            memory: (B, M, D) cross-attention memory.
            mem_pad_mask: (B, M) bool, True == PAD.
            tokenizer: optional; its ``pad_idx`` / ``pad_token_id`` is banned from
                sampling (falls back to ``self.pad_token``).
            use_bias: add the diagonal bias when ``diag_bias`` exists.
            target_len: full target length used to place the diagonal bias.
                Defaults to M, which matches teacher-forced training, where the
                target and the memory have the same length.

        Returns:
            (B, V) logits for the last prefix position, with PAD set to -inf.

        Raises:
            RuntimeError: if ``self._z_cached`` has not been set.
        """
        device = prefix_ids.device
        B, T = prefix_ids.size()
        M = memory.size(1)

        tok = self.dec_emb(prefix_ids)                 # (B,T,D)
        pos = self.dec_pos[:, :T, :]                   # (1,T,D)
        if self._z_cached is None:
            raise RuntimeError("self._z_cached가 없습니다. 학습/추론 루프에서 self._z_cached = z (B,D)로 세팅하세요.")
        z_emb = self.latent2emb(self._z_cached).unsqueeze(1).expand(-1, T, -1)
        tgt = tok + pos + z_emb

        tgt_mask = torch.triu(torch.full((T, T), float("-inf"), device=device), diagonal=1)

        memory_mask = None
        if use_bias and (self.diag_bias is not None):
            # CrossDiagBias spreads the query rows it is given over [0, 1]. If it
            # were asked for only T (prefix) rows, the newest token would always map
            # to the END of memory. Build it for the full target length instead (the
            # T == M geometry used in teacher-forced training) and keep the prefix rows.
            full_T = max(T, M if target_len is None else int(target_len))
            memory_mask = self.diag_bias(full_T, M, device=device)[:T]   # (T,M) additive mask

        h = self.decoder(
            tgt=tgt,
            memory=memory,
            tgt_mask=tgt_mask,                         # (T,T) float
            memory_mask=memory_mask,                   # (T,M) float (additive)
            memory_key_padding_mask=mem_pad_mask,      # (B,M) bool
        )
        if h.dim() == 3 and h.size(0) != B:  # (T,B,D) → (B,T,D)
            h = h.transpose(0, 1)
        logits = self.out(h)[:, -1, :]       # (B,V)

        # Never allow PAD to be sampled.
        if tokenizer is not None:
            pad_idx = getattr(tokenizer, "pad_idx", getattr(tokenizer, "pad_token_id", self.pad_token))
            if pad_idx is not None:
                logits[:, int(pad_idx)] = -float("inf")
        else:
            logits[:, int(self.pad_token)] = -float("inf")

        return logits

    # -------------------------
    # Checkpoint save / restore
    # -------------------------
    def save_checkpoint(self, path: str, optimizer: Optional[torch.optim.Optimizer] = None,
                        epoch: Optional[int] = None, step: Optional[int] = None,
                        extra: Optional[dict] = None):
        """Save ``state_dict()`` (diag_bias / sur_ln / sur_gate included) plus metadata to ``path``."""
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        ckpt = {
            "model_state": self.state_dict(),
            "epoch": epoch,
            "step": step,
            "extra": extra or {},
        }
        if optimizer is not None:
            ckpt["optimizer_state"] = optimizer.state_dict()
        torch.save(ckpt, path)

    @staticmethod
    def load_checkpoint(path: str, model: "VAEWithSurrogate",
                        optimizer: Optional[torch.optim.Optimizer] = None,
                        map_location: Optional[str] = None, strict: bool = True) -> dict:
        """Restore a checkpoint written by ``save_checkpoint``; returns the raw checkpoint dict."""
        ckpt = torch.load(path, map_location=map_location or "cpu")
        model.load_state_dict(ckpt["model_state"], strict=strict)
        if optimizer is not None and "optimizer_state" in ckpt:
            optimizer.load_state_dict(ckpt["optimizer_state"])
        return ckpt


# ===============================
# 4) 정렬 보조 손실 (분리 실행용; 지금은 호출하지 말 것)
# ===============================
def guided_alignment_kl(attn_probs: torch.Tensor, m_hat_idx: torch.Tensor, sigma: float) -> torch.Tensor:
    """
    Guided-alignment loss KL(attn || Gaussian(m_hat, sigma)), for the teacher-forcing path only.

    Args:
        attn_probs: attention probabilities (after softmax), e.g. [B,H,T,M]; the
            last axis indexes memory positions. Any leading shape is accepted.
        m_hat_idx: alignment centres in memory-index units, shaped like
            ``attn_probs`` without its last axis, e.g. [B,H,T]. An estimate from
            the last block/head is acceptable.
        sigma: Gaussian width in tokens (meant to start large and be annealed).

    Returns:
        Scalar (float32) KL, summed over memory positions and averaged over the rest.

    Raises:
        ValueError: if ``sigma`` is not positive.
    """
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma}")
    M = attn_probs.size(-1)
    # Work in float32 so squared distances (up to (M-1)^2) cannot overflow half precision.
    p = attn_probs.float()
    i_idx = torch.arange(M, device=attn_probs.device, dtype=torch.float32)
    m = m_hat_idx.float().unsqueeze(-1)
    # log_softmax normalises in log space. exp(-d^2 / 2σ^2) underflows to 0 for
    # small σ, which used to collapse the target into a uniform distribution.
    log_q = torch.log_softmax(-(i_idx - m).pow(2) / (2.0 * sigma ** 2), dim=-1)
    # Same 1e-9 probability floor as before, so far-off positions stay bounded.
    log_floor = math.log(1e-9)
    kl = p * (p.clamp_min(1e-9).log() - log_q.clamp_min(log_floor))
    return kl.sum(dim=-1).mean()  # scalar


In [None]:
# === TRAINING WITH TQDM ===
import math, torch
import torch.nn.functional as F
from torch.amp import autocast, GradScaler
# Autocast device type: "cuda" when a GPU is visible, otherwise "cpu".
AMP_DEV = "cuda" if torch.cuda.is_available() else "cpu"
# Module-level grad scaler; loss scaling is only enabled on CUDA.
# NOTE(review): train() builds its own GradScaler, so this instance looks unused in the visible code.
scaler  = GradScaler(AMP_DEV, enabled=(AMP_DEV == "cuda"))
from tqdm.auto import tqdm
from typing import Dict, Iterable
from contextlib import nullcontext
# sdpa_kernel 컨텍스트 래퍼 (새 API)
# sdpa 컨텍스트 (새 API)
from torch.amp import autocast, GradScaler
from contextlib import nullcontext
# Re-declares the autocast device type (same value as above).
AMP_DEV = "cuda" if torch.cuda.is_available() else "cpu"
# Second module-level GradScaler (upper-case name); NOTE(review): not referenced in the visible code.
SCALER  = GradScaler(AMP_DEV, enabled=(AMP_DEV=="cuda"))

try:
    from torch.nn.attention import sdpa_kernel, SDPBackend
except ImportError:
    # Older PyTorch without torch.nn.attention: backend selection becomes a no-op.
    def SDPA_CTX(which="math"):
        """Return a no-op context manager (SDPA backend selection unavailable)."""
        return nullcontext()
else:
    def SDPA_CTX(which="math"):
        """Return a context manager that restricts scaled-dot-product attention to one backend.

        ``which`` (case-insensitive): "flash", "mem"/"efficient"/"mem_efficient",
        "cudnn" (only if this PyTorch has it). Anything else selects the math backend.
        """
        w = str(which).lower()
        if w == "flash":
            backend = SDPBackend.FLASH_ATTENTION
        elif w in ("mem", "efficient", "mem_efficient"):
            backend = SDPBackend.EFFICIENT_ATTENTION
        elif w == "cudnn" and hasattr(SDPBackend, "CUDNN_ATTENTION"):
            backend = SDPBackend.CUDNN_ATTENTION
        else:
            backend = SDPBackend.MATH
        return sdpa_kernel(backend)

def kl_loss(mu, logvar, free_bits=0.0):
    """Mean over the batch of KL(N(mu, exp(logvar)) || N(0, I)).

    The KL is summed over latent dimensions for each sample. When
    ``free_bits > 0``, each per-sample KL is floored at ``free_bits`` before
    averaging.
    """
    per_sample = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(-1)  # [B]
    if not free_bits > 0:
        return per_sample.mean()
    floor = torch.full_like(per_sample, free_bits)
    return torch.maximum(per_sample, floor).mean()

def decode_full_with_bias(model: VAEWithSurrogate, x_in, z, memory, mem_pad_mask):
    """Teacher-forced decoding of all positions in one pass, with the Gaussian diagonal bias.

    Args:
        model: wrapper that provides dec_emb / dec_pos / latent2emb / decoder / out / diag_bias.
        x_in: right-shifted decoder input ids (B, T).
        z: latent codes added to every position after ``latent2emb``.
        memory: cross-attention memory (B, M, D).
        mem_pad_mask: (B, M) bool, True == PAD.

    Returns:
        Logits (B, T, V). Attention runs under the math SDPA backend.
    """
    n_batch, n_tgt = x_in.size()
    dev = x_in.device
    token_emb = model.dec_emb(x_in)
    latent = model.latent2emb(z).unsqueeze(1).expand(-1, n_tgt, -1)
    tgt = token_emb + model.dec_pos[:, :n_tgt, :] + latent

    causal = torch.triu(torch.full((n_tgt, n_tgt), float('-inf'), device=dev), diagonal=1)
    if model.diag_bias is None:
        cross_bias = None
    else:
        cross_bias = model.diag_bias(n_tgt, memory.size(1), device=dev)

    with SDPA_CTX("math"):
        hidden = model.decoder(
            tgt=tgt,
            memory=memory,
            tgt_mask=causal,
            memory_mask=cross_bias,
            memory_key_padding_mask=mem_pad_mask,
        )
    # Defensive: convert a sequence-first (T, B, D) output to batch-first.
    if hidden.dim() == 3 and hidden.size(0) != n_batch:
        hidden = hidden.transpose(0, 1)
    return model.out(hidden)  # (B,T,V)

def make_param_groups(model, base_lr=2e-4, wd=0.01):
    """
    Split trainable parameters into AdamW groups with and without weight decay.

    No decay is applied to:
      * parameters owned directly by any ``nn.LayerNorm`` or ``nn.Embedding``
        module, at any nesting depth;
      * parameters whose name ends in ``bias`` (this also matches e.g. the
        attention ``in_proj_bias``);
      * parameters whose qualified name contains ``ln`` (case-insensitive) or
        ``LayerNorm`` (name heuristic kept from the original implementation).
    Every other trainable parameter gets ``weight_decay=wd``. Frozen parameters are skipped.

    Returns:
        ``[decay_group, no_decay_group]``. Both groups are always present (possibly
        empty) and use ``lr=base_lr``.
    """
    no_decay_types = (nn.LayerNorm, nn.Embedding)
    # Check the module that actually owns each parameter, at any depth. Looking
    # only at the top-level child (e.g. ``vae`` or ``encoder``) missed nested
    # LayerNorm / Embedding layers, which then received weight decay.
    no_decay_ids = {
        id(p)
        for mod in model.modules()
        if isinstance(mod, no_decay_types)
        for p in mod.parameters(recurse=False)
    }
    decays, nodecays = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if n.endswith('bias') or id(p) in no_decay_ids or ('ln' in n.lower()) or ('LayerNorm' in n):
            nodecays.append(p)
        else:
            decays.append(p)
    return [
        {"params": decays,   "lr": base_lr, "weight_decay": wd},
        {"params": nodecays, "lr": base_lr, "weight_decay": 0.0},
    ]

def set_requires_grad(model: VAEWithSurrogate, *, ln_only: bool):
    """Toggle trainability for Phase-2.

    ``ln_only=False`` makes every parameter trainable. ``ln_only=True`` freezes
    everything except the surrogate LayerNorm (``sur_ln``) and gate (``sur_gate``).
    The diagonal-bias parameters (δ, a, α) stay frozen in that mode. To also
    train the decoder/encoder's internal LayerNorms, enable ``requires_grad`` on
    every ``nn.LayerNorm`` in ``model.modules()`` afterwards.
    """
    trainable = not ln_only
    for param in model.parameters():
        param.requires_grad = trainable
    if not ln_only:
        return

    # Surrogate LayerNorm / gate.
    if model.sur_ln is not None:
        for param in model.sur_ln.parameters(recurse=False):
            param.requires_grad = True
    if getattr(model, 'sur_gate', None) is not None:
        model.sur_gate.requires_grad = True

    # δ, a, α remain off (strictly LN/gate-only training).
    if model.diag_bias is not None:
        for param in model.diag_bias.parameters():
            param.requires_grad = False

# -------------------------
# Phase-0: 인코더 메모리 + TF (δ,a,α 빠른 수렴)
# -------------------------
def run_phase0_tf_encmem(model: VAEWithSurrogate, batch, optimizer, scaler,
                         beta: float, free_bits: float = 0.0, pad_idx: int = 0) -> Dict[str,float]:
    """
    One Phase-0 optimisation step: encoder memory + teacher forcing.

    The decoder cross-attends to the real encoder output (with the diagonal bias,
    via ``decode_full_with_bias``), so δ, a and α are fitted against true memory.

    Args:
        model: the VAEWithSurrogate being trained.
        batch: anything ``_unpack_batch`` accepts.
        optimizer, scaler: the optimiser and its ``GradScaler``.
        beta: weight of the KL term.
        free_bits: per-sample KL floor passed to ``kl_loss``.
        pad_idx: target id ignored by the NLL.

    Returns:
        ``{"loss", "nll", "kl"}`` as Python floats.

    Raises:
        RuntimeError: if the loss is not connected to any trainable parameter.
    """
    model.train()
    x, mask_bool = _unpack_batch(batch, model, pad_idx)
    device = next(model.parameters()).device
    x = x.to(device); mask_bool = mask_bool.to(device)
    if not any(p.requires_grad for p in model.parameters()):
        unfreeze_all_(model)

    optimizer.zero_grad(set_to_none=True)
    with autocast(AMP_DEV):
        h_enc, enc_mask = model.encoder(x)
        # Masked mean pooling over valid encoder positions.
        denom = enc_mask.sum(1, keepdim=True).clamp_min(1)
        pooled = (h_enc * enc_mask.unsqueeze(-1)).sum(1) / denom
        mu, logvar = model.vae.to_mu(pooled), model.vae.to_logvar(pooled)
        # NOTE(review): sampling noise is halved here (Phase-2 uses the full std) — confirm intended.
        z = mu + 0.5 * torch.randn_like(mu) * torch.exp(0.5 * logvar)
        model._z_cached = z

        memory = h_enc
        mem_pad_mask = ~enc_mask

        B, T = x.size()
        dec_in = torch.full((B, T), model.bos_token, device=device, dtype=torch.long)
        dec_in[:, 1:] = x[:, :-1]  # shift right: BOS, x[0], ..., x[T-2]

        logits = decode_full_with_bias(model, dec_in, z, memory, mem_pad_mask)
        loss_nll = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                                   x.reshape(-1), ignore_index=pad_idx)
        loss_kl  = kl_loss(mu, logvar, free_bits)
        loss = loss_nll + beta * loss_kl

    if not loss.requires_grad:
        raise RuntimeError("loss has no grad_fn — parameters may be frozen or graph got detached.")
    scaler.scale(loss).backward()
    # The gradients still carry the loss-scale factor here. Unscale them first so
    # that clipping to 1.0 applies to the true gradient norm (PyTorch AMP recipe).
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    return {"loss": float(loss.item()), "nll": float(loss_nll.item()), "kl": float(loss_kl.item())}

# -------------------------
# Phase-2: 서러게이트 메모리 + TF (K 없음, LN/게이트 학습)
# -------------------------
def run_phase2_tf_surmem(model: VAEWithSurrogate, batch, optimizer, scaler,
                         beta: float, free_bits: float = 0.0, pad_idx: int = 0) -> Dict[str,float]:
    """
    One Phase-2 optimisation step: surrogate memory + teacher forcing.

    The memory is predicted from ``z`` by ``model.build_surrogate_memory``, which
    includes the sur_ln / sur_gate path. Its length and validity mask follow the
    ground-truth batch.

    Args:
        model: the VAEWithSurrogate being trained.
        batch: anything ``_unpack_batch`` accepts.
        optimizer, scaler: the optimiser and its ``GradScaler``.
        beta: weight of the KL term.
        free_bits: per-sample KL floor passed to ``kl_loss``.
        pad_idx: target id ignored by the NLL.

    Returns:
        ``{"loss", "nll", "kl"}`` as Python floats.
    """
    model.train()
    x, mask_bool = _unpack_batch(batch, model, pad_idx)
    device = next(model.parameters()).device
    x = x.to(device); mask_bool = mask_bool.to(device)

    optimizer.zero_grad(set_to_none=True)
    with autocast(AMP_DEV):
        # Encoder + reparameterised z
        h_enc, enc_mask = model.encoder(x)
        denom = enc_mask.sum(1, keepdim=True).clamp_min(1)
        pooled = (h_enc * enc_mask.unsqueeze(-1)).sum(1) / denom
        mu, logvar = model.vae.to_mu(pooled), model.vae.to_logvar(pooled)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
        model._z_cached = z

        # Surrogate memory (LN / gate path included)
        memory, mem_pad_mask = model.build_surrogate_memory(z, x_gt_ids=x)

        B, T = x.size()
        dec_in = torch.full((B, T), model.bos_token, device=device, dtype=torch.long)
        dec_in[:, 1:] = x[:, :-1]  # shift right: BOS, x[0], ..., x[T-2]

        logits = decode_full_with_bias(model, dec_in, z, memory, mem_pad_mask)
        loss_nll = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                                   x.reshape(-1), ignore_index=pad_idx)
        loss_kl  = kl_loss(mu, logvar, free_bits)
        loss = loss_nll + beta * loss_kl

    scaler.scale(loss).backward()
    # Unscale before clipping so the threshold applies to the true gradient norm.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    return {"loss": float(loss.item()), "nll": float(loss_nll.item()), "kl": float(loss_kl.item())}


# ---------- 배치 언팩 유틸 (어디든 공용으로 추가) ----------
def _pad_index_from(model, pad_idx_arg: int | None):
    if pad_idx_arg is not None:
        return int(pad_idx_arg)
    if getattr(model, "dec_emb", None) is not None and model.dec_emb.padding_idx is not None:
        return int(model.dec_emb.padding_idx)
    return int(getattr(model, "pad_token", 0))

def _unpack_batch(batch, model, pad_idx_arg: int | None):
    """
    반환: x(LongTensor [B,T]), mask_bool(BoolTensor [B,T], True=valid)
    - dict: HuggingFace 스타일 지원 (input_ids, attention_mask 등)
    - tuple/list: (x, mask) 또는 (x,)
    - tensor: x만
    """
    pad = _pad_index_from(model, pad_idx_arg)

    # 1) dict(HF) 처리
    if isinstance(batch, Mapping):
        # x 후보 키 우선순위
        for k in ("input_ids", "x", "ids", "tokens"):
            if k in batch:
                x = batch[k]
                break
        else:
            # 첫 LongTensor를 x로 사용
            x = next(v for v in batch.values()
                     if torch.is_tensor(v) and v.dtype in (torch.long, torch.int64))

        # mask 후보 키
        mask_bool = None
        for k in ("attention_mask", "mask", "padding_mask", "valid_mask"):
            if k in batch:
                m = batch[k]
                # attention_mask가 float(0/1)일 수 있음
                mask_bool = (m > 0.5) if torch.is_floating_point(m) else m.bool()
                break
        if mask_bool is None:
            mask_bool = (x != pad)
        return x, mask_bool

    # 2) tuple/list
    if isinstance(batch, (list, tuple)):
        if len(batch) >= 2:
            x, mask_bool = batch[0], batch[1]
        else:
            x = batch[0]
            mask_bool = (x != pad)
        return x, mask_bool

    # 3) 단일 텐서
    if torch.is_tensor(batch):
        x = batch
        mask_bool = (x != pad)
        return x, mask_bool

    raise TypeError(f"Unsupported batch type: {type(batch)}")

from typing import Iterable, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast, GradScaler
from tqdm.auto import tqdm

# Re-declared for this section: autocast device type ("cuda" if a GPU is visible, else "cpu").
AMP_DEV = "cuda" if torch.cuda.is_available() else "cpu"

def unfreeze_all_(model: nn.Module):
    """Make every parameter of ``model`` trainable (``requires_grad=True``), in place."""
    for param in model.parameters():
        param.requires_grad_(True)

def freeze_all_(model: nn.Module):
    """Freeze every parameter of ``model`` (``requires_grad=False``), in place."""
    for param in model.parameters():
        param.requires_grad_(False)

def enable_ln_gate_only_(model):  # Phase-2: train only the surrogate LN / gate
    """Freeze everything, then re-enable only the surrogate LayerNorm and gate.

    The diagonal-bias parameters (δ, a, α) are explicitly kept frozen.
    Attributes that are missing or None are skipped.
    """
    for param in model.parameters():
        param.requires_grad = False

    sur_ln = getattr(model, "sur_ln", None)
    if sur_ln is not None:
        for param in sur_ln.parameters():
            param.requires_grad = True

    sur_gate = getattr(model, "sur_gate", None)
    if sur_gate is not None:
        sur_gate.requires_grad = True

    diag_bias = getattr(model, "diag_bias", None)
    if diag_bias is not None:
        for param in diag_bias.parameters():
            param.requires_grad = False

def count_trainables_(model) -> int:
    """Return the number of scalar elements across parameters with ``requires_grad=True``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

from typing import Iterable, Dict, Optional
import math, os, glob
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast, GradScaler
from tqdm.auto import tqdm

# Autocast device type used by train() and the step functions ("cuda" when available, else "cpu").
AMP_DEV = "cuda" if torch.cuda.is_available() else "cpu"

def count_trainables_(model: nn.Module) -> int:
    """Count parameter elements that the optimiser would update (``requires_grad=True``)."""
    trainable = filter(lambda p: p.requires_grad, model.parameters())
    return sum(p.numel() for p in trainable)

def _first_batch_val_nll(model, val_loader, pad_idx: int, device, use_surrogate: bool) -> float:
    """Teacher-forced NLL on the first validation batch with dropout disabled (posterior mean z)."""
    was_training = model.training
    model.eval()  # dropout off, so val_nll is deterministic and comparable across epochs
    try:
        with torch.no_grad():
            batch_val = next(iter(val_loader))
            x_val, _ = _unpack_batch(batch_val, model, pad_idx)
            xv = x_val.to(device)

            h_enc, enc_mask = model.encoder(xv)
            denom = enc_mask.sum(1, keepdim=True).clamp_min(1)
            pooled = (h_enc * enc_mask.unsqueeze(-1)).sum(1) / denom
            z = model.vae.to_mu(pooled)  # posterior mean, no sampling
            model._z_cached = z

            if use_surrogate:
                memory, mem_mask = model.build_surrogate_memory(z, x_gt_ids=xv)
            else:
                memory, mem_mask = h_enc, ~enc_mask

            B, T = xv.size()
            dec_in = torch.full((B, T), model.bos_token, device=device, dtype=torch.long)
            dec_in[:, 1:] = xv[:, :-1]

            logits = decode_full_with_bias(model, dec_in, z, memory, mem_mask)
            return F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                xv.reshape(-1),
                ignore_index=pad_idx,
            ).item()
    finally:
        model.train(was_training)


def train(
    model: "VAEWithSurrogate",
    train_loader: Iterable,
    val_loader: Iterable,
    pad_idx: int,
    epochs: int = 5,
    lr: float = 2e-4,
    wd: float = 0.01,
    beta_max: float = 1.0,
    fb: float = 0.5,
    phase0_epochs: int = 1,
    freeze_ln_in_phase2: bool = True,
    alpha_warm_steps: int = 3000,
    ckpt_dir: str = "./ckpts",
    save_every_epoch: bool = True,
    save_best: bool = True,
    keep_last_k: Optional[int] = 3,
) -> list[Dict[str, float]]:
    """
    Two-phase training loop.

    Phase-0 (epochs < phase0_epochs): encoder memory + teacher forcing. All
        parameters are trainable, including δ, a and α of the diagonal bias.
    Phase-2 (remaining epochs): surrogate memory + teacher forcing. By default
        only the surrogate LayerNorm/gate are trainable (freeze_ln_in_phase2=True).

    Schedules:
      * α is ramped from 0.5*alpha_base to alpha_base over ``alpha_warm_steps``
        steps. After that it is left to the optimiser.
      * the KL weight β ramps linearly to ``beta_max`` over the first 20% of steps.
      * the diagonal band width shrinks by 2 tokens (floor 4) at epochs 1 and 3.

    Required helpers: make_param_groups, run_phase0_tf_encmem, run_phase2_tf_surmem,
    decode_full_with_bias, _unpack_batch, unfreeze_all_, enable_ln_gate_only_,
    count_trainables_.

    Returns:
        One record per epoch: train_loss / train_nll / train_kl come from the last
        step of the epoch; val_nll comes from the first validation batch.

    Raises:
        RuntimeError: if nothing is trainable after the switch to Phase-2.
    """
    device = next(model.parameters()).device
    scaler = GradScaler(AMP_DEV, enabled=(AMP_DEV == "cuda"))
    diag = getattr(model, "diag_bias", None)

    # Phase-0 starts with everything trainable.
    unfreeze_all_(model)
    optim = torch.optim.AdamW(make_param_groups(model, lr, wd), betas=(0.9, 0.95))

    # α warm-up setup
    alpha_base = float(diag.alpha.item()) if diag is not None else 0.0
    total_steps = epochs * len(train_loader)
    alpha_warm_steps = min(alpha_warm_steps, total_steps) if diag is not None else 0

    os.makedirs(ckpt_dir, exist_ok=True)
    best_val = math.inf
    history: list[Dict[str, float]] = []
    global_step = 0
    val_nll = float("nan")  # defined up front so the final save also works when epochs == 0

    for ep in range(epochs):
        # Entering Phase-2: (un)freeze and rebuild the optimiser for the new trainable set.
        if ep == phase0_epochs:
            if freeze_ln_in_phase2:
                enable_ln_gate_only_(model)
            else:
                unfreeze_all_(model)
            optim = torch.optim.AdamW(make_param_groups(model, lr, wd), betas=(0.9, 0.95))
            ntr = count_trainables_(model)
            if ntr == 0:
                raise RuntimeError("No trainable parameters after phase switch.")
            print(f"[Phase-2 start] trainable params: {ntr:,}")

        pbar = tqdm(train_loader, total=len(train_loader),
                    desc=f"Epoch {ep+1}/{epochs}", leave=False)
        last_stats: Dict[str, float] = {"loss": 0.0, "nll": 0.0, "kl": 0.0}

        # Step-wise band narrowing (optional)
        if diag is not None and hasattr(diag, "band_W_tokens") and ep in (1, 3):
            diag.band_W_tokens = max(4, int(diag.band_W_tokens) - 2)

        for batch in pbar:
            # α warm-up. Stop overriding α once the ramp is done; otherwise the
            # learned value would be reset to alpha_base on every step.
            if diag is not None and 0 < alpha_warm_steps and global_step <= alpha_warm_steps:
                with torch.no_grad():
                    s = min(1.0, global_step / max(1, alpha_warm_steps))
                    diag.alpha.fill_(alpha_base * (0.5 + 0.5 * s))

            # KL β warm-up (first 20% of all steps)
            beta = beta_max * min(1.0, global_step / max(1, int(0.2 * total_steps)))

            if ep < phase0_epochs:
                stats = run_phase0_tf_encmem(model, batch, optim, scaler, beta, free_bits=fb, pad_idx=pad_idx)
            else:
                stats = run_phase2_tf_surmem(model, batch, optim, scaler, beta, free_bits=fb, pad_idx=pad_idx)

            last_stats = stats
            pbar.set_postfix(loss=f"{stats['loss']:.3f}",
                             nll=f"{stats['nll']:.3f}",
                             kl=f"{stats['kl']:.3f}")
            global_step += 1

        # ---- validation (NLL) ----
        val_nll = _first_batch_val_nll(model, val_loader, pad_idx, device,
                                       use_surrogate=(ep >= phase0_epochs))

        print(f"[ep {ep}] train_loss={last_stats['loss']:.3f} "
              f"nll={last_stats['nll']:.3f} kl={last_stats['kl']:.3f}  "
              f"val_nll={val_nll:.3f}")

        # record
        rec = {"epoch": ep, "train_loss": last_stats["loss"], "train_nll": last_stats["nll"],
               "train_kl": last_stats["kl"], "val_nll": float(val_nll)}
        history.append(rec)

        # ---- checkpoints ----
        if save_every_epoch:
            ep_path = os.path.join(ckpt_dir, f"ep{ep:03d}_val{val_nll:.3f}.pt")
            model.save_checkpoint(
                ep_path,
                optimizer=optim,
                epoch=ep,
                step=global_step,
                extra={"val_nll": float(val_nll)}
            )
            if keep_last_k is not None and keep_last_k > 0:
                # Zero-padded epoch numbers make lexical order equal to epoch order.
                ckpts = sorted(glob.glob(os.path.join(ckpt_dir, "ep*.pt")))
                for p in ckpts[:-keep_last_k]:
                    try:
                        os.remove(p)
                    except OSError:
                        pass  # best-effort pruning

        if save_best and val_nll < best_val:
            best_val = val_nll
            best_path = os.path.join(ckpt_dir, "best.pt")
            model.save_checkpoint(
                best_path,
                optimizer=optim,
                epoch=ep,
                step=global_step,
                extra={"val_nll": float(val_nll)}
            )

    # Final save
    final_path = os.path.join(ckpt_dir, "final.pt")
    model.save_checkpoint(
        final_path,
        optimizer=optim,
        epoch=epochs - 1,
        step=global_step,
        extra={"val_nll": float(val_nll)}
    )
    return history


In [None]:
from vae_module import Tokenizer, Config, load_vae, encode, decode, SequenceDataset
import torch
cfg = Config(model_path="models/vae_sur.pt")
tok = Tokenizer.from_esm()
tokenizer = tok
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Architecture hyperparameters. These presumably must match the values the
# checkpoint at cfg.model_path was trained with, or the weights will not line up.
DROPOUT = 0.1
LATENT_DIM = 256
EMB_DIM = 256
NUM_LAYERS = 4
NUM_HEADS = 8
FFN_DIM = 512
MAX_LEN = 512
vocab_size = len(tokenizer.vocab)
pad_idx = tokenizer.pad_idx
bos_idx = tokenizer.bos_idx

enc = SmallTransformer(
    vocab_size,
    EMB_DIM,
    NUM_LAYERS,
    NUM_HEADS,
    FFN_DIM,
    MAX_LEN,
    pad_idx,
).to(device)

vae = VAETransformerDecoder(
    encoder=enc,
    vocab_size=vocab_size,
    pad_token=pad_idx,
    bos_token=bos_idx,
).to(device)

checkpoint = torch.load(cfg.model_path, map_location=device)

sur = Z2MemorySurrogate(
    d_model=EMB_DIM,
    latent_dim=LATENT_DIM,
    max_len=MAX_LEN,
    layers=2,
    heads=4,
    ffn_dim=3 * EMB_DIM,
    dropout=DROPOUT,
).to(device)
model = VAEWithSurrogate(vae, sur, use_sur_ln=True, use_diag_bias=True).to(device)

# strict=False lets a partially matching checkpoint load, but it also silently
# accepts one that matches nothing at all (e.g. a wrapper dict rather than a raw
# state_dict). Report the key mismatches so that situation is visible.
load_result = model.load_state_dict(checkpoint, strict=False)
missing = list(getattr(load_result, "missing_keys", []))
unexpected = list(getattr(load_result, "unexpected_keys", []))
if missing:
    print(f"[load] {len(missing)} missing keys (not in checkpoint), e.g. {missing[:5]}")
if unexpected:
    print(f"[load] {len(unexpected)} unexpected keys (ignored), e.g. {unexpected[:5]}")
model.to(device)
print(f"loaded to {device}")

In [None]:
from torch.utils.data import DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence

VAL_RATIO = 0.2  # fraction of sequences held out for validation
SEED = 42        # fixed seed so the split is reproducible

full_ds = SequenceDataset(sequences, tok, MAX_LEN)
PAD_IDX = tok.get_idx("<pad>")

n_total = len(full_ds)
n_val = max(int(n_total * VAL_RATIO), 1)  # always keep at least one validation example
n_train = n_total - n_val

g_split = torch.Generator()
g_split.manual_seed(SEED)
train_ds, val_ds = random_split(full_ds, lengths=[n_train, n_val], generator=g_split)
def collate_truncate(batch, pad_idx: int, max_len: int = 512):
    """Right-truncate every sequence to ``max_len`` and right-pad the batch.

    Each item may be a dict with an ``"input_ids"`` entry, a tuple/list whose
    first element holds the ids, or the ids themselves. The ids may be a tensor
    or anything ``torch.as_tensor`` accepts (e.g. a NumPy array). The batch is
    padded to the length of its longest truncated sequence, which is at most
    ``max_len``, not always ``max_len``.

    Args:
        batch: list of dataset items, as described above.
        pad_idx: token id used to fill positions past each sequence's end.
        max_len: maximum number of tokens kept per sequence.

    Returns:
        ``{"input_ids": LongTensor of shape (len(batch), L)}`` with ``L <= max_len``.
        An empty batch yields shape ``(0, 0)``.
    """
    seqs = []
    for item in batch:
        if isinstance(item, dict):
            ids = item["input_ids"]
        elif isinstance(item, (tuple, list)):
            ids = item[0]
        else:
            ids = item
        # Right truncation; as_tensor is a no-op for an existing long tensor.
        seqs.append(torch.as_tensor(ids, dtype=torch.long)[:max_len])

    width = max((s.size(0) for s in seqs), default=0)
    out = torch.full((len(seqs), width), pad_idx, dtype=torch.long)
    for row, s in enumerate(seqs):
        out[row, : s.size(0)] = s
    return {"input_ids": out}

# Make DataLoader workers reproducible (optional).
def seed_worker(worker_id):
    """Seed Python's ``random`` and NumPy's RNG inside a DataLoader worker.

    PyTorch already seeds each worker's torch RNG with ``base_seed + worker_id``.
    Here ``base_seed`` is drawn from the DataLoader's ``generator``, so it is
    reproducible when that generator is seeded. PyTorch does not seed ``random``
    or ``numpy``, so they are derived from that per-worker torch seed, following
    the PyTorch reproducibility notes. Unlike a constant ``SEED + worker_id``,
    this does not clobber torch's seed and gives fresh streams each epoch.

    ``numpy``/``random`` are imported locally so this cell does not depend on
    imports that only appear in a later notebook cell.

    Args:
        worker_id: index of the worker. It is unused because ``torch.initial_seed()``
            already includes it, but it is required by the ``worker_init_fn`` signature.
    """
    import random

    import numpy as np

    worker_seed = torch.initial_seed() % 2**32  # NumPy seeds must fit in 32 bits
    np.random.seed(worker_seed)
    random.seed(worker_seed)

from functools import partial

# Separate generators: a DataLoader draws worker base seeds (and, when
# shuffling, indices) from its generator. Sharing one generator would make the
# training shuffle order depend on how often the validation loader was iterated.
g_loader = torch.Generator().manual_seed(SEED)
g_val_loader = torch.Generator().manual_seed(SEED)

# partial instead of a lambda: picklable, so it also works with the "spawn"
# worker start method. PAD_IDX/MAX_LEN are bound here, when the loaders are built.
_collate = partial(collate_truncate, pad_idx=PAD_IDX, max_len=MAX_LEN)

train_loader = DataLoader(
    train_ds, batch_size=32, shuffle=True, num_workers=2, pin_memory=True,
    worker_init_fn=seed_worker, generator=g_loader, collate_fn=_collate,
)
val_loader = DataLoader(
    val_ds, batch_size=32, shuffle=False, num_workers=2, pin_memory=True,
    worker_init_fn=seed_worker, generator=g_val_loader, collate_fn=_collate,
)

In [None]:
# ==== 0) 준비: tokenizer / dataset / dataloader ====
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
import numpy as np
import random
from collections.abc import Mapping

# ==== 2) Training run (tqdm progress built in; Phase 2 is teacher-forcing only, optionally training only LN/gate params) ====
CKPT_DIR = "/ckpts"

history = train(
    model,
    train_loader,
    val_loader,
    pad_idx=pad_idx,
    epochs=3,                 # total number of epochs
    lr=2e-4, wd=0.01,
    beta_max=0.1, fb=0.5,     # max KL weight beta, free-bits threshold
    # NOTE(review): the original comment described phase 0 as "ep < 1" (encoder
    # memory + TF) and surrogate memory + TF from ep >= 1, which corresponds to
    # phase0_epochs=1. With phase0_epochs=3 == epochs the surrogate phase may
    # never run — confirm the intended value.
    phase0_epochs=3,
    freeze_ln_in_phase2=True,  # Phase 2: train only LN/gate parameters
    alpha_warm_steps=3000,
    ckpt_dir=CKPT_DIR,
    save_every_epoch=True,
    save_best=True,
    keep_last_k=3,
)

# ==== 3) Checkpoints ====
# train() already writes os.path.join(CKPT_DIR, "final.pt") together with the
# optimizer state, epoch, step and val_nll. A bare save_checkpoint() to the same
# path would overwrite that file without this metadata, so no extra save here.
# To reload later:
# ckpt = VAEWithSurrogate.load_checkpoint(os.path.join(CKPT_DIR, "final.pt"), model, map_location="cpu", strict=True)

# ==== 4) Simple greedy generation — free-running decode with surrogate memory + diagonal bias ====
@torch.no_grad()
def generate_greedy(model: "VAEWithSurrogate", x_ref_ids, out_len: int):
    """Greedily decode ``out_len`` tokens from the latent of a reference sequence.

    The reference is encoded and mask-mean-pooled, then mapped to ``z = mu``
    (the posterior mean, no sampling). A surrogate memory is built from ``z``,
    and decoding starts from BOS. Each step appends the argmax token.

    Args:
        model: a VAEWithSurrogate, or any object exposing the same attributes
            (``encoder``, ``vae.to_mu``, ``build_surrogate_memory``, ``decode_step``,
            ``bos_token``).
        x_ref_ids: reference token ids of shape (1, T_ref). This is used to
            obtain ``z`` and the memory length. A 1-D (T_ref,) tensor is treated
            as a batch of one.
        out_len: number of tokens to generate. BOS is not included in the output.

    Returns:
        A list of ``out_len`` generated token ids.

    Side effects:
        Sets ``model._z_cached`` to ``z``. The model's train/eval mode is
        restored on exit.
    """
    was_training = model.training
    model.eval()
    try:
        device = next(model.parameters()).device
        x_ref_ids = x_ref_ids.to(device)
        if x_ref_ids.dim() == 1:
            x_ref_ids = x_ref_ids.unsqueeze(0)

        # Masked mean over encoder positions -> z = mu.
        # NOTE(review): assumes enc_mask is 1/True on real tokens; if it is a
        # padding mask (True = pad) this averages over padding instead — confirm.
        h_enc, enc_mask = model.encoder(x_ref_ids)
        denom = enc_mask.sum(1, keepdim=True).clamp_min(1)
        pooled = (h_enc * enc_mask.unsqueeze(-1)).sum(1) / denom
        z = model.vae.to_mu(pooled)
        model._z_cached = z

        # Build the surrogate memory from z.
        memory, mem_pad_mask = model.build_surrogate_memory(z, x_gt_ids=x_ref_ids)

        # Free-running decode. The full prefix is passed to decode_step at every step.
        dec_in = torch.full((1, 1), model.bos_token, device=device, dtype=torch.long)
        toks = []
        for _ in range(out_len):
            logits = model.decode_step(dec_in, memory, mem_pad_mask, tokenizer=None, use_bias=True)
            y = logits.argmax(-1)  # assumes decode_step returns (1, V) -> y is (1,)
            toks.append(int(y.item()))
            dec_in = torch.cat([dec_in, y.unsqueeze(0)], dim=1)
        return toks  # generated token ids
    finally:
        model.train(was_training)

# Example: use the first validation sequence as the reference.
val_batch = next(iter(val_loader))
# The collate function yields {"input_ids": (B, L)}. Index the dict instead of
# tuple-unpacking it, which would raise ValueError for a one-key dict.
x0 = val_batch["input_ids"] if isinstance(val_batch, dict) else val_batch[0]
sample_ids = generate_greedy(model, x0[:1], out_len=50)
print(sample_ids[:20])
