## 1. Install Dependencies and device setup

In [1]:
!pip install torch
!pip install fair-esm
!pip install biopython

Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch)
  Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch)
  Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch)
  Downloading nvidia_nvjitlink_cu12-12.4.127-py3-n

In [3]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.cuda.amp import autocast, GradScaler
from torch.nn.functional import cross_entropy, mse_loss, cosine_similarity
import esm
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence

In [4]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # GPU when visible, else CPU

## 2. Data Loading

In [5]:
from Bio import SeqIO
seq_path='/kaggle/input/uniref50-sub/uniref50_subsample.fasta'
# Keep only the raw residue string of every FASTA record.
sequences = [str(record.seq) for record in SeqIO.parse(seq_path, "fasta")]
print(len(sequences))

1000000


In [None]:
import random

# Full list of sequences loaded above.
items = sequences

# Draw k sequences without replacement. A fixed seed makes the subset reproducible,
# which later cells rely on when they rebuild the dataset and its train/val/test split.
# NOTE(review): re-running this cell without re-running the loading cell resamples
# from the already-subsampled list.
SAMPLE_SEED = 42
k = 14000
sequences = random.Random(SAMPLE_SEED).sample(items, min(k, len(items)))

## 3. Model & Dataset Definitions

In [6]:
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn

# ── Dataset & Collate ──
class ProteinDataset(Dataset):
    """Map-style dataset that encodes amino-acid strings to token-ID tensors.

    Each sequence is truncated to ``max_len`` characters and encoded character by
    character with ``alphabet.get_idx``; no BOS/EOS tokens are added.

    Args:
        sequences: iterable of sequence strings.
        alphabet: object exposing ``get_idx(token) -> int`` (an ESM alphabet here).
        max_len: truncation length. Defaults to the notebook-level ``MAX_LEN``,
            looked up when the dataset is constructed.
    """

    def __init__(self, sequences, alphabet, max_len=None):
        if max_len is None:
            max_len = MAX_LEN
        self.max_len   = max_len
        self.sequences = [seq[:max_len] for seq in sequences]
        self.alphabet  = alphabet

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return sequence ``idx`` as a 1-D ``torch.long`` tensor of token IDs."""
        idxs = [self.alphabet.get_idx(c) for c in self.sequences[idx]]
        return torch.tensor(idxs, dtype=torch.long)

def collate_fn(batch, pad_idx):
    """Right-pad a list of 1-D LongTensors into one (B, L_max) batch.

    Returns:
        (padded, mask): ``padded`` is filled with ``pad_idx`` beyond each sequence's
        end; ``mask`` is a bool tensor that is True wherever ``padded != pad_idx``.
    """
    tokens = pad_sequence(batch, batch_first=True, padding_value=pad_idx)
    return tokens, tokens.ne(pad_idx)

# ── Teacher & Encoder ──
class SmallTransformer(nn.Module):
    def __init__(self, vocab_size, emb_dim, layers, heads, ffn_dim, max_len, pad_idx):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.pos = nn.Parameter(torch.zeros(1, max_len, emb_dim))
        layer   = nn.TransformerEncoderLayer(
            d_model=emb_dim, nhead=heads,
            dim_feedforward=ffn_dim, batch_first=True,
            activation='gelu', dropout=DROPOUT
        )
        self.enc = nn.TransformerEncoder(layer, layers)
        self.ln  = nn.LayerNorm(emb_dim)

    def forward(self, x):
        mask = x != self.emb.padding_idx
        h    = self.emb(x) + self.pos[:, :x.size(1), :]
        h    = self.enc(h, src_key_padding_mask=~mask)
        return self.ln(h), mask

class BigTransformer(SmallTransformer):
    pass  # identical API

# ── Single‐stage VAE ──
class VAETransformerDecoder(nn.Module):
    """Sequence VAE built from a supplied encoder and an autoregressive Transformer decoder.

    The latent is a diagonal Gaussian predicted from the mask-weighted mean of the
    encoder states. The decoder is conditioned on the latent by adding a projection
    of ``z`` to every decoder input position.

    ``forward(x, mask) -> (logits, mu, logvar, h_enc, enc_mask)``
      * ``x``: (B, L) token IDs; ``mask``: (B, L) bool, True at non-pad positions.
      * ``logits``: (B, L, vocab_size), computed with teacher forcing (the decoder
        input is ``x`` shifted right by one, with ``bos_token`` in front).

    NOTE(review): the decoder also cross-attends to the full per-token encoder
    states (``memory=h_enc``), which contain the token being predicted. It can
    therefore reconstruct ``x`` without using ``z``, so high reconstruction accuracy
    does not imply an informative latent. Sampling code that passes an all-zero
    memory is then out of distribution. Confirm this is intended.
    """

    def __init__(self, encoder, vocab_size, latent_dim=LATENT_DIM,
                 emb_dim=EMB_DIM, num_layers=NUM_LAYERS, num_heads=NUM_HEADS,
                 ffn_dim=FFN_DIM, max_len=MAX_LEN, pad_token=0, bos_token=1):
        super().__init__()
        self.encoder = encoder
        self.pad_token, self.bos_token = pad_token, bos_token

        # Posterior heads (pooled state -> mu / logvar) and latent -> decoder width.
        self.to_mu = nn.Linear(emb_dim, latent_dim)
        self.to_logvar = nn.Linear(emb_dim, latent_dim)
        self.latent2emb = nn.Linear(latent_dim, emb_dim)

        # Decoder: token + learned positional embeddings into a causal TransformerDecoder.
        self.dec_emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_token)
        self.dec_pos = nn.Parameter(torch.zeros(1, max_len, emb_dim))
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=emb_dim, nhead=num_heads,
                dim_feedforward=ffn_dim, dropout=DROPOUT,
                batch_first=True,
            ),
            num_layers,
        )
        self.out = nn.Linear(emb_dim, vocab_size)

    def forward(self, x, mask):
        # Encode, then average the hidden states over non-pad positions.
        h_enc, enc_mask = self.encoder(x)
        keep = enc_mask.unsqueeze(-1)
        n_real = enc_mask.sum(1, True)
        pooled = (h_enc * keep).sum(1) / n_real

        # Reparameterisation: z = mu + eps * sigma with eps ~ N(0, I).
        mu = self.to_mu(pooled)
        logvar = self.to_logvar(pooled)
        noise = torch.randn_like(mu)
        z = mu + noise * torch.exp(0.5 * logvar)

        # Teacher forcing: the decoder sees x shifted right by one, BOS in front.
        batch_size, seq_len = x.size()
        shifted = torch.full((batch_size, seq_len), self.bos_token, device=x.device, dtype=torch.long)
        shifted[:, 1:] = x[:, :-1]
        latent_term = self.latent2emb(z).unsqueeze(1).expand(-1, seq_len, -1)
        tgt = self.dec_emb(shifted) + self.dec_pos[:, :seq_len, :] + latent_term

        causal = nn.Transformer.generate_square_subsequent_mask(seq_len).to(x.device)
        h_dec = self.decoder(
            tgt=tgt,
            memory=h_enc,
            tgt_mask=causal,
            tgt_key_padding_mask=~mask,
            memory_key_padding_mask=~enc_mask,
        )
        return self.out(h_dec), mu, logvar, h_enc, enc_mask


## 3.1 Load Teacher and alphabet

In [7]:
# load alphabet & teacher
import esm

# Only the alphabet is kept; the ESM-2 weights returned alongside it are discarded.
_, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
PAD_IDX     = alphabet.get_idx('<pad>')
BOS_IDX     = alphabet.get_idx('<cls>')
TEACHER_CKPT = '/kaggle/input/esms/transformers/default/1/distilled_embeddings_two_stage.pt'
ckpt        = torch.load(TEACHER_CKPT, map_location='cpu')
teacher     = SmallTransformer(
        len(alphabet.all_toks), EMB_DIM, NUM_LAYERS, NUM_HEADS, FFN_DIM, MAX_LEN, PAD_IDX
    ).to(device)
# strict=False tolerates key mismatches. Report them so that a teacher left with
# randomly initialised weights does not go unnoticed.
load_report = teacher.load_state_dict(ckpt['student_state_dict'], strict=False)
if load_report.missing_keys or load_report.unexpected_keys:
    print(f"[teacher] missing keys:    {load_report.missing_keys}")
    print(f"[teacher] unexpected keys: {load_report.unexpected_keys}")
teacher.eval()
# The teacher is a fixed feature extractor that is never optimised, so stop autograd
# from tracking (and accumulating gradients into) its parameters.
teacher.requires_grad_(False)
if torch.cuda.device_count() > 1:
    teacher = nn.DataParallel(teacher)

Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t33_650M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D-contact-regression.pt


FileNotFoundError: [Errno 2] No such file or directory: '/kaggle/input/esms/transformers/default/1/distilled_embeddings_two_stage.pt'

## 3.2 Define model

In [2]:
# ── Config ──
# Model / data
MAX_LEN       = 512     # residues kept per sequence (longer sequences are truncated)
BATCH_SIZE    = 64
LATENT_DIM    = 256
EMB_DIM       = 256
NUM_LAYERS    = 4
NUM_HEADS     = 4
FFN_DIM       = 512
DROPOUT       = 0.3

# Two-phase schedule: epochs 1..EPOCHS_PHASE1 use only the CE term (other weights 0);
# later epochs use LR_PHASE2 and switch on the teacher (cos / MSE) and KL terms.
LR_PHASE1     = 5e-4    # higher LR for the CE-only phase
LR_PHASE2     = 1e-4    # LR for the remaining epochs

EPOCHS_PHASE1 = 100     # CE-only for the first 100 epochs
TOTAL_EPOCHS  = 500

CE_WEIGHT1    = 100.0   # CE weight during phase 1
CE_WEIGHT2    = 1.0     # CE weight afterwards
KL_WEIGHT     = 0.1
COS_WEIGHT    = 5.0
MSE_WEIGHT    = 5.0

In [8]:
import torch.nn as nn
import torch
enc = BigTransformer(
        len(alphabet.all_toks), EMB_DIM, NUM_LAYERS, NUM_HEADS, FFN_DIM, MAX_LEN, PAD_IDX
    ).to(device)
vae = VAETransformerDecoder(
        encoder=enc,
        vocab_size=len(alphabet.all_toks),
        latent_dim=LATENT_DIM,
        emb_dim=EMB_DIM, num_layers=NUM_LAYERS, num_heads=NUM_HEADS,
        ffn_dim=FFN_DIM, max_len=MAX_LEN,
        pad_token=PAD_IDX, bos_token=BOS_IDX
    ).to(device)
# Wrap only when several GPUs are visible (same policy as the teacher). The checkpoint
# code unwraps through `vae.module` when it is present.
if torch.cuda.device_count() > 1:
    vae = nn.DataParallel(vae)

# torch.cuda.amp.GradScaler is deprecated. Loss scaling only matters for CUDA fp16,
# so disable it (making scale/step pass-throughs) when running on CPU.
scaler = torch.amp.GradScaler("cuda", enabled=device.type == "cuda")

  scaler = GradScaler()


## 4. Data Preparation

In [9]:
from torch.utils.data import DataLoader
ds        = ProteinDataset(sequences, alphabet)
# 80/10/10 split. The test size is derived from the other two so the lengths always
# sum to len(ds); `len(ds) - int(0.9*len(ds))` does not (e.g. n=19 gives 15+1+2=18).
t = int(0.8 * len(ds))
v = int(0.1 * len(ds))
s = len(ds) - t - v
# A seeded generator makes the split reproducible, so evaluation cells can rebuild it.
train_ds, val_ds, test_ds = random_split(
    ds, [t, v, s], generator=torch.Generator().manual_seed(42)
)
train_loader = DataLoader(train_ds, BATCH_SIZE, shuffle=True,
                              collate_fn=lambda b: collate_fn(b, PAD_IDX))
val_loader   = DataLoader(val_ds,   BATCH_SIZE, shuffle=False,
                              collate_fn=lambda b: collate_fn(b, PAD_IDX))


# 5. Training & Validation

## 5.1 Train and checkpoint

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.cuda.amp import autocast, GradScaler
from torch.nn.functional import cross_entropy, mse_loss, cosine_similarity
import esm
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence


STAT_KEYS = ("ce", "cos", "mse", "kl")
use_amp = device.type == "cuda"


def phase_settings(ep):
    """Return ``(lr, w_ce, w_cos, w_mse, w_kl)`` for the 1-based epoch ``ep``.

    Epochs ``1..EPOCHS_PHASE1`` use the CE term only. Later epochs use the lower LR
    and switch on the teacher (cosine / MSE) and KL terms.
    """
    if ep <= EPOCHS_PHASE1:
        return LR_PHASE1, CE_WEIGHT1, 0.0, 0.0, 0.0
    return LR_PHASE2, CE_WEIGHT2, COS_WEIGHT, MSE_WEIGHT, KL_WEIGHT


def compute_losses(x, mask):
    """Run one batch through the VAE and the teacher and return ``(ce, cos, mse, kl)``.

    * ce  - token cross-entropy of the reconstruction (PAD targets ignored).
    * cos - mean (1 - cosine similarity) between teacher features of the input and
            of the argmax reconstruction, over non-pad positions.
    * mse - mean squared teacher-feature difference over non-pad positions.
    * kl  - KL(q(z|x) || N(0, I)) averaged over batch *and* latent dimensions.
    """
    logits, mu, logvar, _, _ = vae(x, mask)
    ce = cross_entropy(
        logits.view(-1, logits.size(-1)),
        x.view(-1),
        ignore_index=PAD_IDX,
        label_smoothing=0.0,
    )
    # The reconstruction reaches the teacher through argmax, which is not
    # differentiable, so cos/mse can never send gradient into the VAE. Running the
    # teacher without autograd avoids building a graph whose gradients would only
    # land in (unoptimised) teacher parameters.
    # NOTE(review): as written, cos/mse are therefore monitoring metrics only; a
    # differentiable path (e.g. soft / Gumbel token embeddings) is needed to train on them.
    with torch.no_grad():
        orig_h, _ = teacher(x)
        recon_h, _ = teacher(logits.argmax(-1))
    cos = (1 - cosine_similarity(orig_h, recon_h, dim=-1)).masked_select(mask).mean()
    mse = (orig_h - recon_h).pow(2).mean(-1).masked_select(mask).mean()
    kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return ce, cos, mse, kl


def train_one_epoch(optimizer, weights, desc):
    """Run one optimisation pass over ``train_loader`` and return the mean per-batch stats."""
    w_ce, w_cos, w_mse, w_kl = weights
    totals = dict.fromkeys(STAT_KEYS, 0.0)
    vae.train()
    for x, mask in tqdm(train_loader, desc=desc):
        x, mask = x.to(device), mask.to(device)
        optimizer.zero_grad()
        with torch.autocast(device_type="cuda", enabled=use_amp):
            ce, cos, mse, kl = compute_losses(x, mask)
            # NOTE(review): the CE term enters as exp(ce) (perplexity), not ce; kept
            # unchanged. Confirm this is intended.
            loss = w_ce * torch.exp(ce) + w_cos * cos + w_mse * mse + w_kl * kl
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        for key, value in zip(STAT_KEYS, (ce, cos, mse, kl)):
            totals[key] += value.item()
    n_batches = max(len(train_loader), 1)
    return {key: total / n_batches for key, total in totals.items()}


@torch.no_grad()
def evaluate(loader):
    """Return mean per-batch stats on ``loader`` (VAE in eval mode, no autocast)."""
    totals = dict.fromkeys(STAT_KEYS, 0.0)
    vae.eval()
    for x, mask in loader:
        x, mask = x.to(device), mask.to(device)
        for key, value in zip(STAT_KEYS, compute_losses(x, mask)):
            totals[key] += value.item()
    n_batches = max(len(loader), 1)
    return {key: total / n_batches for key, total in totals.items()}


# Build the optimizer ONCE so AdamW's moment estimates and step count persist across
# epochs. Only the learning rate changes, at the phase boundary.
optimizer = optim.AdamW(vae.parameters(), lr=phase_settings(1)[0])

for ep in range(1, TOTAL_EPOCHS + 1):
    lr, *weights = phase_settings(ep)
    for group in optimizer.param_groups:
        group["lr"] = lr

    stats = train_one_epoch(optimizer, weights, desc=f"Train{ep}")
    vstats = evaluate(val_loader)

    print(
        f"Epoch {ep:2d} | "
        f"Train CE={stats['ce']:.3f} COS={stats['cos']:.3f} MSE={stats['mse']:.3f} KL={stats['kl']:.3f} | "
        f" Val CE={vstats['ce']:.3f} COS={vstats['cos']:.3f} MSE={vstats['mse']:.3f} KL={vstats['kl']:.3f}"
    )
    if ep % 10 == 0:
        # ─── Checkpoint save (unwrapped module so it loads with or without DataParallel) ───
        SAVE_PATH = f"/kaggle/working/vae_epoch{ep:03d}.pt"
        model_to_save = vae.module if hasattr(vae, "module") else vae
        torch.save({
            "epoch":    ep,
            "model_sd": model_to_save.state_dict(),
            "opt_sd":   optimizer.state_dict(),
            "scaler_sd": scaler.state_dict(),
        }, SAVE_PATH)
        print(f"Saved checkpoint to {SAVE_PATH}")

# 6. Load Saved VAE

## 6.1 Model definition

In [None]:
# ── Config ──
# Must match the values used for training, or the checkpoint will not load.
MAX_LEN       = 512     # residues kept per sequence (longer sequences are truncated)
BATCH_SIZE    = 64
LATENT_DIM    = 256
EMB_DIM       = 256
NUM_LAYERS    = 4
NUM_HEADS     = 4
FFN_DIM       = 512
DROPOUT       = 0.3

# Two-phase schedule: epochs 1..EPOCHS_PHASE1 use only the CE term (other weights 0);
# later epochs use LR_PHASE2 and switch on the teacher (cos / MSE) and KL terms.
LR_PHASE1     = 5e-4    # higher LR for the CE-only phase
LR_PHASE2     = 1e-4    # LR for the remaining epochs

EPOCHS_PHASE1 = 100     # CE-only for the first 100 epochs
TOTAL_EPOCHS  = 500

CE_WEIGHT1    = 100.0   # CE weight during phase 1
CE_WEIGHT2    = 1.0     # CE weight afterwards
KL_WEIGHT     = 0.1
COS_WEIGHT    = 5.0
MSE_WEIGHT    = 5.0

In [None]:
class ProteinDataset(Dataset):
    """Map-style dataset that encodes amino-acid strings to token-ID tensors.

    Each sequence is truncated to ``max_len`` characters and encoded character by
    character with ``alphabet.get_idx``; no BOS/EOS tokens are added.

    Args:
        sequences: iterable of sequence strings.
        alphabet: object exposing ``get_idx(token) -> int`` (an ESM alphabet here).
        max_len: truncation length. Defaults to the notebook-level ``MAX_LEN``,
            looked up when the dataset is constructed.
    """

    def __init__(self, sequences, alphabet, max_len=None):
        if max_len is None:
            max_len = MAX_LEN
        self.max_len   = max_len
        self.sequences = [seq[:max_len] for seq in sequences]
        self.alphabet  = alphabet

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return sequence ``idx`` as a 1-D ``torch.long`` tensor of token IDs."""
        idxs = [self.alphabet.get_idx(c) for c in self.sequences[idx]]
        return torch.tensor(idxs, dtype=torch.long)

def collate_fn(batch, pad_idx):
    """Right-pad a list of 1-D LongTensors into one (B, L_max) batch.

    Returns:
        (padded, mask): ``padded`` is filled with ``pad_idx`` beyond each sequence's
        end; ``mask`` is a bool tensor that is True wherever ``padded != pad_idx``.
    """
    tokens = pad_sequence(batch, batch_first=True, padding_value=pad_idx)
    return tokens, tokens.ne(pad_idx)
    
class SmallTransformer(nn.Module):
    def __init__(self, vocab_size, emb_dim, layers, heads, ffn_dim, max_len, pad_idx):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.pos = nn.Parameter(torch.zeros(1, max_len, emb_dim))
        layer   = nn.TransformerEncoderLayer(
            d_model=emb_dim, nhead=heads,
            dim_feedforward=ffn_dim, batch_first=True,
            activation='gelu', dropout=DROPOUT
        )
        self.enc = nn.TransformerEncoder(layer, layers)
        self.ln  = nn.LayerNorm(emb_dim)

    def forward(self, x):
        mask = x != self.emb.padding_idx
        h    = self.emb(x) + self.pos[:, :x.size(1), :]
        h    = self.enc(h, src_key_padding_mask=~mask)
        return self.ln(h), mask

class BigTransformer(SmallTransformer):
    pass  # identical API

# ── Single‐stage VAE ──
class VAETransformerDecoder(nn.Module):
    """Sequence VAE built from a supplied encoder and an autoregressive Transformer decoder.

    The latent is a diagonal Gaussian predicted from the mask-weighted mean of the
    encoder states. The decoder is conditioned on the latent by adding a projection
    of ``z`` to every decoder input position.

    ``forward(x, mask) -> (logits, mu, logvar, h_enc, enc_mask)``
      * ``x``: (B, L) token IDs; ``mask``: (B, L) bool, True at non-pad positions.
      * ``logits``: (B, L, vocab_size), computed with teacher forcing (the decoder
        input is ``x`` shifted right by one, with ``bos_token`` in front).

    NOTE(review): the decoder also cross-attends to the full per-token encoder
    states (``memory=h_enc``), which contain the token being predicted. It can
    therefore reconstruct ``x`` without using ``z``, so high reconstruction accuracy
    does not imply an informative latent. Sampling code that passes an all-zero
    memory is then out of distribution. Confirm this is intended.
    """

    def __init__(self, encoder, vocab_size, latent_dim=LATENT_DIM,
                 emb_dim=EMB_DIM, num_layers=NUM_LAYERS, num_heads=NUM_HEADS,
                 ffn_dim=FFN_DIM, max_len=MAX_LEN, pad_token=0, bos_token=1):
        super().__init__()
        self.encoder = encoder
        self.pad_token, self.bos_token = pad_token, bos_token

        # Posterior heads (pooled state -> mu / logvar) and latent -> decoder width.
        self.to_mu = nn.Linear(emb_dim, latent_dim)
        self.to_logvar = nn.Linear(emb_dim, latent_dim)
        self.latent2emb = nn.Linear(latent_dim, emb_dim)

        # Decoder: token + learned positional embeddings into a causal TransformerDecoder.
        self.dec_emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_token)
        self.dec_pos = nn.Parameter(torch.zeros(1, max_len, emb_dim))
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=emb_dim, nhead=num_heads,
                dim_feedforward=ffn_dim, dropout=DROPOUT,
                batch_first=True,
            ),
            num_layers,
        )
        self.out = nn.Linear(emb_dim, vocab_size)

    def forward(self, x, mask):
        # Encode, then average the hidden states over non-pad positions.
        h_enc, enc_mask = self.encoder(x)
        keep = enc_mask.unsqueeze(-1)
        n_real = enc_mask.sum(1, True)
        pooled = (h_enc * keep).sum(1) / n_real

        # Reparameterisation: z = mu + eps * sigma with eps ~ N(0, I).
        mu = self.to_mu(pooled)
        logvar = self.to_logvar(pooled)
        noise = torch.randn_like(mu)
        z = mu + noise * torch.exp(0.5 * logvar)

        # Teacher forcing: the decoder sees x shifted right by one, BOS in front.
        batch_size, seq_len = x.size()
        shifted = torch.full((batch_size, seq_len), self.bos_token, device=x.device, dtype=torch.long)
        shifted[:, 1:] = x[:, :-1]
        latent_term = self.latent2emb(z).unsqueeze(1).expand(-1, seq_len, -1)
        tgt = self.dec_emb(shifted) + self.dec_pos[:, :seq_len, :] + latent_term

        causal = nn.Transformer.generate_square_subsequent_mask(seq_len).to(x.device)
        h_dec = self.decoder(
            tgt=tgt,
            memory=h_enc,
            tgt_mask=causal,
            tgt_key_padding_mask=~mask,
            memory_key_padding_mask=~enc_mask,
        )
        return self.out(h_dec), mu, logvar, h_enc, enc_mask

## 6.2 load model

In [None]:
# load alphabet & teacher
import esm

# Only the alphabet is kept; it must match training. The ESM-2 weights are discarded.
# Loaded once (previously twice, which downloaded and built the 650M model two times).
_, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# BOS ("begin sequence") and PAD token IDs.
PAD_IDX     = alphabet.get_idx('<pad>')
BOS_IDX     = alphabet.get_idx('<cls>')

# ── 0) Config ── (device is defined before it is used to build the model)
device      = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CHECKPOINT  = "/kaggle/input/esms-vae/pytorch/default2/1/vae_epoch380.pt"
noise_scale = 0.2   # std of the extra Gaussian noise added to z in the latent test; tweak as needed

# ── 1) Build the model (not wrapped in DataParallel) and load the checkpoint ──
enc = BigTransformer(
        len(alphabet.all_toks), EMB_DIM, NUM_LAYERS, NUM_HEADS, FFN_DIM, MAX_LEN, PAD_IDX
    ).to(device)
vae = VAETransformerDecoder(
        encoder=enc,
        vocab_size=len(alphabet.all_toks),
        latent_dim=LATENT_DIM,
        emb_dim=EMB_DIM, num_layers=NUM_LAYERS, num_heads=NUM_HEADS,
        ffn_dim=FFN_DIM, max_len=MAX_LEN,
        pad_token=PAD_IDX, bos_token=BOS_IDX
    ).to(device)
ckpt = torch.load(CHECKPOINT, map_location=device)
# Checkpoints are saved from the unwrapped module, so load into `.module` if wrapped.
model = vae.module if hasattr(vae, "module") else vae
model.load_state_dict(ckpt["model_sd"])
model.to(device).eval()

## 6.3 helper function

In [None]:
# ── Helpers ──
def reconstruction_accuracy(orig: str, recon: str) -> float:
    """Percentage of positions at which ``orig`` and ``recon`` agree.

    Both arguments must have the same length (checked with an assertion). Any pair of
    equal-length sequences works, e.g. two strings or two lists of tokens.
    """
    assert len(orig) == len(recon), "Lengths must match"
    n_equal = sum(1 for a, b in zip(orig, recon) if a == b)
    return n_equal / len(orig) * 100.0

def decode_batch(id_seqs, alphabet, pad_idx):
    """Turn a 2-D tensor of token IDs into strings, one per row.

    Each row is cut at its first ``pad_idx``. The remaining IDs are mapped through
    ``alphabet.get_tok`` and concatenated.
    """
    decoded = []
    for row in id_seqs.cpu().tolist():
        if pad_idx in row:
            row = row[:row.index(pad_idx)]
        decoded.append("".join(alphabet.get_tok(i) for i in row))
    return decoded

## 6.4 load data

In [None]:

# ── 2) Prepare a batch x, mask ──
# Rebuild the dataset and re-create the training-time 80/10/10 split with the same
# seeded generator. An unseeded re-split would mix training sequences into "test".
# `ProteinDataset` and `collate_fn` come from section 6.1.
# NOTE(review): this reproduces the training test split only if `sequences` is the
# same list used for training and training used the same sizes and seed (42).
ds        = ProteinDataset(sequences, alphabet)
t = int(0.8 * len(ds))
v = int(0.1 * len(ds))
s = len(ds) - t - v   # derived so the lengths always sum to len(ds)
train_ds, val_ds, test_ds = random_split(
    ds, [t, v, s], generator=torch.Generator().manual_seed(42)
)
test_loader = DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=lambda b: collate_fn(b, PAD_IDX)
)
# Grab one batch for the latent-space experiments below.
x, mask = next(iter(test_loader))
x, mask = x.to(device), mask.to(device)

## 6.4.1 Latent-space noise test

In [None]:
# ── 3) Forward through encoder ──
with torch.no_grad():
    logits_clean, mu, logvar, h_enc, enc_mask = model(x, mask)

    # Sample z ~ q(z|x), then a perturbed copy z + noise_scale * N(0, I).
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    z   = mu + std * eps

    eps2    = torch.randn_like(std)
    z_noisy = z + noise_scale * eps2

    # diagnostics
    delta = (z_noisy - z).view(z.size(0), -1).norm(dim=1)
    print(f"│z_noisy − z│ mean: {delta.mean().item():.4f}, std: {delta.std().item():.4f}")

    z_emb_clean = model.latent2emb(z)
    emb_norm    = z_emb_clean.view(z_emb_clean.size(0), -1).norm(dim=1)
    print(f"‖latent2emb(z)‖ mean: {emb_norm.mean().item():.4f}")


# ── 4) Decode clean + latent-noise versions ──
@torch.no_grad()
def decode_from(z_latent):
    """Decode the current batch with latent ``z_latent`` and return argmax token IDs (B, L).

    Reads the globals ``x``, ``mask``, ``h_enc`` and ``enc_mask`` set above.

    NOTE(review): the decoder input is the ground-truth ``x`` shifted right, and it
    still cross-attends to ``h_enc``. This therefore measures sensitivity to ``z``
    under teacher forcing, not free-running generation from ``z``.
    """
    B, L = x.size()
    dec_in = torch.full((B, L), BOS_IDX, device=device, dtype=torch.long)
    dec_in[:, 1:] = x[:, :-1]
    emb       = model.dec_emb(dec_in) + model.dec_pos[:, :L, :]
    z_emb     = model.latent2emb(z_latent).unsqueeze(1).expand(-1, L, -1)
    dec_input = emb + z_emb

    tgt_mask  = nn.Transformer.generate_square_subsequent_mask(L).to(device)
    h_dec = model.decoder(
        tgt=dec_input,
        memory=h_enc,
        tgt_mask=tgt_mask,
        tgt_key_padding_mask=~mask,
        memory_key_padding_mask=~enc_mask
    )
    logits = model.out(h_dec)
    return logits.argmax(-1)

recon_clean = decode_from(z)
recon_noisy = decode_from(z_noisy)

# Share of non-pad positions whose argmax token changes under the latent noise.
n_changed = ((recon_clean != recon_noisy) & mask).sum().item()
print(f"Tokens changed by latent noise: {100.0 * n_changed / max(mask.sum().item(), 1):.2f}%")

In [None]:
# Plot the distribution of mu and sigma for the single batch encoded above.
import matplotlib.pyplot as plt
import numpy as np
mus = mu.detach().cpu().numpy().flatten()
logvars = logvar.detach().cpu().numpy().flatten()
fig, (ax_mu, ax_sigma) = plt.subplots(1, 2, figsize=(10, 4))
ax_mu.hist(mus, bins=50)
ax_mu.set(title="mu distribution", xlabel="mu", ylabel="count")
ax_sigma.hist(np.exp(0.5 * logvars), bins=50)
ax_sigma.set(title="sigma distribution", xlabel="sigma = exp(0.5 * logvar)", ylabel="count")
fig.tight_layout()
plt.show()


## 6.5 Posterior Collapse Check

In [None]:
# ── Posterior collapse check (KL per latent dimension over the test split) ──
# Distinct names are used so the single-batch `x`, `mask`, `mu` and `logvar` from the
# cells above are not overwritten.
print("\n## Posterior Collapse Check")
ACTIVE_KL_THRESHOLD = 0.01   # heuristic cut-off for counting a dimension as "active"
with torch.no_grad():
    mu_chunks, logvar_chunks = [], []
    for xb, mb in test_loader:
        xb, mb = xb.to(device), mb.to(device)
        _, mu_b, logvar_b, _, _ = model(xb, mb)
        mu_chunks.append(mu_b)
        logvar_chunks.append(logvar_b)
    mu_all = torch.cat(mu_chunks, dim=0)
    logvar_all = torch.cat(logvar_chunks, dim=0)
    # KL(q(z|x) || N(0, I)) per dimension, averaged over all test sequences.
    kl_per_dim = (0.5 * (mu_all.pow(2) + logvar_all.exp() - 1 - logvar_all)).mean(0)
    print("KL per dimension:", kl_per_dim.cpu().numpy())
    n_active = int((kl_per_dim > ACTIVE_KL_THRESHOLD).sum().item())
    print(f"Active dims (KL > {ACTIVE_KL_THRESHOLD}): {n_active}/{kl_per_dim.numel()}")

## 6.6 Novel generation

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm, trange

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Unwrap nn.DataParallel if present; the sub-modules live on `.module`.
vae_module = vae.module if hasattr(vae, "module") else vae
dec_emb     = vae_module.dec_emb
dec_pos     = vae_module.dec_pos
latent2emb  = vae_module.latent2emb
decoder     = vae_module.decoder
out_proj    = vae_module.out

# Constants (PAD_IDX / BOS_IDX come from the model-loading cell)
MAX_LEN    = 512
LATENT_DIM = 256
EMB_DIM    = dec_emb.embedding_dim


def generate_from_z(z, max_len=MAX_LEN):
    """Greedy autoregressive decoding from latent codes.

    Args:
        z: (B, LATENT_DIM) latent tensor on ``device``.
        max_len: decoder length including the BOS slot; must not exceed the model's
            positional-embedding length.

    Returns:
        A list of B strings built from ``alphabet.all_toks``, with BOS and PAD dropped.
        PAD is treated as end-of-sequence: once a row emits it, the row stays PAD.

    NOTE(review): training cross-attended to real encoder states, but here the
    decoder sees an all-zero memory, so samples are out of the training
    distribution. Training ignored PAD targets, so the model may never emit PAD, in
    which case rows run to ``max_len``.
    """
    batch = z.size(0)
    # 1) Initial decoder input: BOS followed by PAD.
    generated = torch.full((batch, max_len), PAD_IDX, device=device, dtype=torch.long)
    generated[:, 0] = BOS_IDX
    finished = torch.zeros(batch, dtype=torch.bool, device=device)

    # 2) Causal mask, built directly on the target device.
    tgt_mask = torch.triu(torch.full((max_len, max_len), float('-inf'), device=device),
                          diagonal=1)
    # 3) Dummy (all-zero) encoder memory.
    memory = torch.zeros(batch, max_len, EMB_DIM, device=device)

    with torch.no_grad():
        z_emb_one = latent2emb(z).unsqueeze(1)                          # (B, 1, E), same every step
        for t in trange(1, max_len):
            # Token + positional + latent embeddings of the current prefix.
            tok_emb = dec_emb(generated[:, :t])                          # (B, t, E)
            pos_emb = dec_pos[:, :t, :]                                  # (1, t, E)
            tgt     = tok_emb + pos_emb + z_emb_one.expand(-1, t, -1)    # (B, t, E)

            dec_out = decoder(
                tgt=tgt,
                memory=memory,
                tgt_mask=tgt_mask[:t, :t],
                memory_key_padding_mask=None,
                tgt_key_padding_mask=None
            )  # (B, t, E)

            # Greedy next token; rows that already ended keep emitting PAD.
            next_token = out_proj(dec_out)[:, -1].argmax(-1)             # (B,)
            next_token = next_token.masked_fill(finished, PAD_IDX)
            generated[:, t] = next_token
            finished |= next_token == PAD_IDX

            # Stop once every row has ended.
            if finished.all():
                break

    # 4) Map token IDs back to strings.
    out_seqs = []
    for seq in generated.cpu().tolist():
        toks = [alphabet.all_toks[i] for i in seq if i not in (PAD_IDX, BOS_IDX)]
        out_seqs.append("".join(toks))

    return out_seqs

# -- Example usage and novelty computation --
n = 1000
z_rand      = torch.randn(n, LATENT_DIM, device=device)
gen_seqs    = generate_from_z(z_rand, max_len=MAX_LEN)

def seq_identity(a, b):
    """Ungapped positional identity after right-padding both strings with '-'.

    Returns a fraction in [0, 1]; two empty strings count as identical (1.0).
    """
    L = max(len(a), len(b))
    if L == 0:
        return 1.0
    a2, b2 = a.ljust(L, '-'), b.ljust(L, '-')
    return sum(x==y for x,y in zip(a2, b2)) / L

# NOTE(review): this is O(n_generated * n_reference * length) in pure Python and can
# take a very long time; consider a proper aligner (e.g. MMseqs2) for real novelty.
max_id = []
for s in tqdm(gen_seqs, desc="Compute novelty"):
    ids = [seq_identity(s, ref) for ref in sequences]
    max_id.append(max(ids))

max_id = np.array(max_id)
print("Percent ≤30% identity:", np.mean(max_id <= 0.30)*100)
print("Median identity:       ", np.median(max_id))
print("90th percentile:       ", np.percentile(max_id, 90))
print("Max identity:          ", max_id.max())

## 6.7 Reconstruction test

In [10]:
import torch

# 1) Load the checkpoint
BOS_IDX     = alphabet.get_idx('<cls>')
PAD_IDX     = alphabet.get_idx('<pad>')
CHECKPOINT_PATH = "/kaggle/input/esms-vae/pytorch/default2/1/vae_epoch380.pt"  # point this at the latest checkpoint
ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
model_to_save = vae.module if hasattr(vae, "module") else vae
model_to_save.load_state_dict(ckpt["model_sd"])
model_to_save.eval()

# Re-create the training-time 80/10/10 split (same sizes, same seeded generator).
# The previous hard-coded [900000, 100000] split failed whenever `sequences` was
# subsampled, and being unseeded it placed training sequences in the "test" set.
ds        = ProteinDataset(sequences, alphabet)
n_train = int(0.8 * len(ds))
n_val   = int(0.1 * len(ds))
n_test  = len(ds) - n_train - n_val
train_ds, val_ds, test_ds = random_split(
    ds, [n_train, n_val, n_test], generator=torch.Generator().manual_seed(42)
)

test_loader   = DataLoader(test_ds,   BATCH_SIZE, shuffle=False,
                              collate_fn=lambda b: collate_fn(b, PAD_IDX))

# 2) Token-level reconstruction accuracy on the test split (non-pad positions only).
# NOTE(review): this is teacher-forced next-token accuracy with the decoder also
# attending to the encoder states of the same sequence; it is not sampling from z.
correct, total = 0, 0
with torch.no_grad():
    for x, mask in test_loader:
        x, mask = x.to(device), mask.to(device)
        logits, *_ = vae(x, mask)  # only the logits are needed
        preds = logits.argmax(dim=-1)
        correct += ((preds == x) & mask).sum().item()
        total   += mask.sum().item()

recon_acc = correct / total if total else float("nan")
print(f"Final Reconstruction Accuracy: {recon_acc*100:.5f}%")

# 3) Print reconstructions of the first few test sequences
#    (`reconstruction_accuracy` is defined in the helper cell, section 6.3).
from torch.nn.functional import pad

n_show = 5
shown = 0

with torch.no_grad():
    for x, mask in test_loader:
        x, mask = x.to(device), mask.to(device)
        logits, *_ = vae(x, mask)
        preds = logits.argmax(dim=-1)

        for orig_ids, pred_ids, m in zip(x, preds, mask):
            # Keep only the real (non-pad) tokens of this sequence.
            length = m.sum().item()
            orig_seq = [alphabet.all_toks[i] for i in orig_ids[:length].tolist()]
            pred_seq = [alphabet.all_toks[i] for i in pred_ids[:length].tolist()]

            acc = reconstruction_accuracy(orig_seq, pred_seq)

            print(f"=== Example {shown+1} ===")
            print(f'reconstruction accuracy: {acc}')

            print("Original:       ", "".join(orig_seq))
            print("Reconstructed:  ", "".join(pred_seq))
            print()

            shown += 1
            if shown >= n_show:
                break
        if shown >= n_show:
            break

  output = torch._nested_tensor_from_mask(


Final Reconstruction Accuracy: 97.17987%
=== Example 1 ===
reconstruction accuracy: 100.0
Original:        MNLSMTDRDNATATSDSSRTACSVSRAAGPAVQLRIGRLRRTIGHRDHVRPGRDPVGEQLREHRAGDQAAFPGR
Reconstructed:   MNLSMTDRDNATATSDSSRTACSVSRAAGPAVQLRIGRLRRTIGHRDHVRPGRDPVGEQLREHRAGDQAAFPGR

=== Example 2 ===
reconstruction accuracy: 98.82352941176471
Original:        MIGYGNEEFGYKLWDPEKQKIVRSRDIVFHEHETIKDMEKNVVSTKLTYEGNLDEEIFMEQLEGFKVKGKENMVCKLKKSMYGLK
Reconstructed:   MIGYGNEEFGYKLWDPEKQKIVRSRDIVFHEHETIKDMEKNVVSTKLTYEGNLDEEIFMEQLEGFKVKGKENMVCKLKSSMYGLK

=== Example 3 ===
reconstruction accuracy: 96.98795180722891
Original:        MRVVRWLDTGLNSLNFLLHQISNLILMLIMFLTTFDVIGRALFNHSITGAYELTELGSAIVIFFTLAVTHKYKEHVAVGFLVDKLSAKKKAMIEGLVDLFIFVLILIMSFQLINEAMRLMERGTTTTDLGLPIYTFILIVSIGSFIFAFVALANGIKSMIEAVKKS
Reconstructed:   MRVVRWLDTGLNSLNFLLHQISNLILMLIMFLTTFDVIGRALFNHSITGAYELTELGSAIVIFTTLAVTHKYKEHVAVGFLVDKLSAKKAAMIEGLVDLFIFVLILMMSFQLINEAMRLMERGTTTDDLGLPIYTFILIVSIGSFIFAVVALANGIKSMIEAVKKS

=== Example 4 ===
reconstru

## 7. Evaluate Loaded VAE