# EvoFill all-in-one

本Notebook打包了项目涉及的所有模块与函数，依赖环境: `/home/qmtang/miniconda3/envs/mamba`，可参考通过以下命令克隆：
```bash
conda create --name mamba --clone /home/qmtang/miniconda3/envs/mamba
```

设置工作目录

In [None]:
import os
# Switch to the project root; presumably needed so that the project-relative
# import `from src.utils import load_config` used below resolves — confirm.
os.chdir(r'/home/qmtang/mnt_qmtang/EvoFill') 

## 1. Dataloader

In [1]:
import json
from types import SimpleNamespace
def load_config(path: str) -> SimpleNamespace:
    """Read a JSON config file and return it as nested SimpleNamespace objects.

    Every JSON object (including objects inside arrays) becomes a
    SimpleNamespace, so settings can be read with attribute access, e.g.
    ``cfg.data.path``. JSON arrays and scalars are returned unchanged.
    """
    def to_namespace(mapping):
        # json calls this hook on every decoded object, innermost objects first.
        fields = {}
        for key, value in mapping.items():
            fields[key] = to_namespace(value) if isinstance(value, dict) else value
        return SimpleNamespace(**fields)

    with open(path) as handle:
        config = json.load(handle, object_hook=to_namespace)
    return config

注：在针对变异位点坐标的嵌入上，这里的处理方式和STICI原版稍有不同

In [2]:
import os
import math
import json
import numpy as np
import torch
from tqdm import tqdm
from cyvcf2 import VCF

from src.utils import load_config


def build_quaternion(chrom, pos, chrom_len_dict, chrom_start_dict, genome_len):
    """
    Build the 4-d log4 coordinate of one variant site.

    Args:
        chrom: chromosome name (str or int). A leading literal ``"chr"`` prefix
            is removed before looking it up in the two dicts.
        pos: position on the chromosome (anything ``int()`` accepts); must be > 0.
        chrom_len_dict: chromosome name -> chromosome length.
        chrom_start_dict: chromosome name -> offset of that chromosome in the
            concatenated genome.
        genome_len: total genome length.

    Returns:
        list of 4 Python floats:
        [log4(pos), log4(chrom_len), log4(chrom_start + pos), log4(genome_len)].

    Raises:
        KeyError: if the (prefix-stripped) chromosome is not in the dicts.
        ValueError: if one of the four values is <= 0 (math domain error).
    """
    def _log4(x):
        return math.log(x) / math.log(4)

    chrom = str(chrom)
    # Remove only the literal "chr" prefix. str.strip('chr') would strip any of
    # the characters 'c'/'h'/'r' from BOTH ends (e.g. "hs37d5" -> "s37d5").
    if chrom.startswith('chr'):
        chrom = chrom[len('chr'):]
    pos = int(pos)
    c_len = chrom_len_dict[chrom]
    c_start = chrom_start_dict[chrom]
    abs_pos = c_start + pos
    return [
        _log4(pos),
        _log4(c_len),
        _log4(abs_pos),
        _log4(genome_len),
    ]


def read_vcf(path: str, phased: bool, genome_json: str):
    """
    Parse a VCF into allele-index genotypes plus per-site 4-d coordinates.

    Args:
        path: VCF / VCF.gz path readable by cyvcf2.
        phased: split genotype strings on '|' if True, otherwise on '/'.
        genome_json: JSON file with keys "chrom_len" (dict), "chrom_start"
            (dict) and "genome_len" (int), as consumed by build_quaternion.

    Returns:
        gts: torch.Tensor int8, shape (n_haplotypes, n_snps). For each variant
            the alleles of every call are listed in sample order, so after the
            transpose rows 2*i and 2*i+1 belong to samples[i] (assumes every
            call is diploid). Values are allele indices (0 = REF, k = k-th ALT),
            -1 for a missing allele.
        samples: list[str], sample names from the VCF header.
        var_depth_index: torch.Tensor int8, shape (n_snps,): number of alleles
            (REF + ALTs) at each site.
        global_depth: int, the largest allele index observed over all sites.
        quat_tensor: torch.Tensor float32, shape (n_snps, 4): build_quaternion
            coordinates of each site.

    Raises:
        KeyError: if a genotype allele is neither '.' nor one of REF/ALT, or the
            chromosome is missing from the genome metadata.
    """
    # ---- 0. genome metadata ----
    with open(genome_json) as f:
        gmeta = json.load(f)
    chrom_len = gmeta["chrom_len"]        # dict[str, int]
    chrom_start = gmeta["chrom_start"]    # dict[str, int]
    genome_len = gmeta["genome_len"]      # int

    missing_calls = {'.|.', './.'}
    sep = '|' if phased else '/'

    # A throw-away reader only counts records for the progress bar.
    counter = VCF(path)
    try:
        total = sum(1 for _ in counter)
    finally:
        counter.close()

    vcf = VCF(path)
    samples = list(vcf.samples)

    gts_list = []
    var_depth_list = []
    quat_list = []

    try:
        for var in tqdm(vcf, total=total, desc="Parsing VCF"):
            alleles = [var.REF] + var.ALT
            allele2idx = {a: i for i, a in enumerate(alleles)}

            row = []
            for gt_str in var.gt_bases:
                if gt_str in missing_calls:
                    # extend (not append): each call must add two scalar
                    # entries so every row has the same length for np.vstack.
                    row.extend((-1, -1))
                else:
                    # A half-missing call such as 'A|.' maps '.' to -1.
                    row.extend(-1 if a == '.' else allele2idx[a]
                               for a in gt_str.split(sep))
            gts_list.append(np.array(row, dtype=np.int32))

            var_depth_list.append(len(alleles))

            # 4-d site coordinate
            quat_list.append(build_quaternion(var.CHROM, var.POS, chrom_len,
                                              chrom_start, genome_len))
    finally:
        vcf.close()

    gts = np.vstack(gts_list).T.astype(np.int32)   # (n_haplotypes, n_snps)
    flat = gts[gts >= 0]
    global_depth = int(flat.max())

    gts = torch.tensor(gts, dtype=torch.int8)
    var_depth_index = torch.tensor(var_depth_list, dtype=torch.int8)
    quat_tensor = torch.tensor(quat_list, dtype=torch.float32)

    return gts, samples, var_depth_index, global_depth, quat_tensor

In [3]:
# Load the project config and make sure the processed-data directory exists.
cfg = load_config("/home/qmtang/mnt_qmtang/EvoFill/config/config.json")
os.makedirs(cfg.data.path, exist_ok=True)

# `phased` selects the genotype separator in read_vcf ('|' if truthy, '/' otherwise).
# NOTE(review): the flag is read from the config key "tihp" — confirm the key name.
phased = bool(cfg.data.tihp)
genome_json = cfg.data.genome_json

In [4]:
import gzip
# Print the first `n_lines` raw lines of the training VCF as a quick format check.
# gzip.open in text mode ('rt') is used for .gz files, plain open otherwise.
filename = cfg.data.train_vcf
opener = gzip.open if filename.endswith('.gz') else open

n_lines = 10
with opener(filename, 'rt') as f:
    lines = 0
    for line in f:
        print(line.rstrip('\n'))
        lines += 1
        if lines >= n_lines:
            break

##fileformat=VCFv4.2
##source=tskit 0.6.4
##FILTER=<ID=PASS,Description="All filters passed">
##contig=<ID=chr22,length=50818468>
##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">
#CHROM	POS	ID	REF	ALT	QUAL	FILTER	INFO	FORMAT	tsk_0	tsk_1	tsk_2	tsk_3	tsk_4	tsk_5	tsk_6	tsk_7	tsk_8	tsk_9	tsk_10	tsk_11	tsk_12	tsk_13	tsk_14	tsk_15	tsk_16	tsk_17	tsk_18	tsk_19	tsk_20	tsk_21	tsk_22	tsk_23	tsk_24	tsk_25	tsk_26	tsk_27	tsk_28	tsk_29	tsk_30	tsk_31	tsk_32	tsk_33	tsk_34	tsk_35	tsk_36	tsk_37	tsk_38	tsk_39	tsk_40	tsk_41	tsk_42	tsk_43	tsk_44	tsk_45	tsk_46	tsk_47	tsk_48	tsk_49	tsk_50	tsk_51	tsk_52	tsk_53	tsk_54	tsk_55	tsk_56	tsk_57	tsk_58	tsk_59	tsk_60	tsk_61	tsk_62	tsk_63	tsk_64	tsk_65	tsk_66	tsk_67	tsk_68	tsk_69	tsk_70	tsk_71	tsk_72	tsk_73	tsk_74	tsk_75	tsk_76	tsk_77	tsk_78	tsk_79	tsk_80	tsk_81	tsk_82	tsk_83	tsk_84	tsk_85	tsk_86	tsk_87	tsk_88	tsk_89	tsk_90	tsk_91	tsk_92	tsk_93	tsk_94	tsk_95	tsk_96	tsk_97	tsk_98	tsk_99	tsk_100	tsk_101	tsk_102	tsk_103	tsk_104	tsk_105	tsk_106	tsk_107	tsk_108	t

In [5]:
# ---------- training set ----------
# Parse the training VCF, report the largest observed allele index ("depth"),
# and save genotypes, per-site coordinates and per-site allele counts to train.pt.
train_gts, train_samples, var_depth_index, global_depth, quat_train = read_vcf(
    cfg.data.train_vcf, phased, genome_json)
print(f"Inferred unified depth = {global_depth}")

torch.save({'gts': train_gts, 'coords':quat_train, 'var_depths':var_depth_index},
            os.path.join(cfg.data.path, "train.pt"))

print(f"Saved train.pt | gts={tuple(train_gts.shape)} | coords={tuple(quat_train.shape)}")

Parsing VCF:   3%|▎         | 1926/56145 [00:02<01:09, 775.65it/s]


KeyboardInterrupt: 

train_gts 张量形状 = (单倍型，变异位点)，每个二倍体样本占相邻两行

In [None]:
# Genotype matrix from read_vcf: one row per haplotype (two per diploid sample),
# one column per variant site; -1 marks a missing allele.
print(train_gts)

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], dtype=torch.int8)


quat_train 张量形状 = (变异位点，4)

变异位点4维坐标 = 位点所在染色体上的位置，位点所在染色体长度，位点在全基因组上的位置，全基因组长度

记录值为原碱基数取log4

In [None]:
# Per-site coordinates: one row of four log4 values per variant site.
print(quat_train)

tensor([[ 2.6610, 12.7994, 15.6976, 15.7621],
        [ 3.7047, 12.7994, 15.6976, 15.7621],
        [ 4.0900, 12.7994, 15.6976, 15.7621],
        ...,
        [11.1267, 12.7994, 15.6989, 15.7621],
        [11.1267, 12.7994, 15.6989, 15.7621],
        [11.1267, 12.7994, 15.6989, 15.7621]])


In [None]:
# ---------- validation set ----------
# Parse the validation VCF and save it in the same layout as train.pt.
# NOTE(review): 'var_depths' stores the TRAINING set's var_depth_index; this
# presumes the validation VCF has the same sites in the same order — confirm.
val_gts, val_samples, _, _, quat_val = read_vcf(
    cfg.data.val_vcf, phased, genome_json)

torch.save({'gts': val_gts, 'coords':quat_val, 'var_depths':var_depth_index},
            os.path.join(cfg.data.path, "val.pt"))

print(f"Saved val.pt   | gts={tuple(val_gts.shape)} | coords={tuple(quat_val.shape)}")

Parsing VCF: 100%|██████████| 56145/56145 [00:37<00:00, 1503.54it/s]


Saved val.pt   | gts=(2000, 56145) | coords=(56145, 4)


## 2. Model

将STICI中原有的 `chunk_module` 内的注意力、全连接层等模块统一替换为 `BiMamba2Block`

以下几个地方需要注意和STICI的差别：

1. chunk 划分标准：按 chunk_size 分割？按 n_chunks 分割？尾部 chunk 位点数不足的部分如何处理？ 
2. 分 chunk 后的 concat 部分，如何处理 chunk 与 chunk 之间 overlap 的位点？
3. STICI 在 concat chunk 后，又经过了两层 Conv1D，文章示意图上未标明。
 

In [1]:
import math
import torch
import torch.nn as nn
from typing import Optional
import torch.nn.functional as F
from mamba_ssm import Mamba2  # 官方实现



class CatEmbeddings(nn.Module):
    """
    Allele embedding plus a linear projection of the site coordinates, then LayerNorm.

    A missing allele (-1) is remapped to index ``n_cats``, which is the
    embedding's ``padding_idx``; that row starts as a zero vector and receives
    no gradient, so a missing site is represented by its coordinates only.
    """
    def __init__(self, n_cats: int, d_model: int, coord_dim: int = 4):
        super().__init__()
        # Creation order (embedding, projection, norm) matches the original so
        # seeded initialisation and state_dict keys are unchanged.
        self.allele_embed = nn.Embedding(num_embeddings=n_cats + 1,
                                         embedding_dim=d_model,
                                         padding_idx=n_cats)
        self.coord_proj = nn.Linear(in_features=coord_dim, out_features=d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, x_coord: torch.Tensor):
        """
        Args:
            x: (B, L) integer allele indices; -1 marks a missing allele.
            x_coord: (L, coord_dim) float coordinates shared across the batch.
        Returns:
            (B, L, d_model) normalized embeddings.
        """
        pad_id = self.allele_embed.padding_idx
        allele_ids = torch.where(x == -1, torch.full_like(x, pad_id), x)
        allele_vec = self.allele_embed(allele_ids)       # (B, L, d)
        coord_vec = self.coord_proj(x_coord)[None]       # (1, L, d)
        return self.norm(allele_vec + coord_vec)

class BiMamba2Block(nn.Module):
    """
    Optionally bidirectional wrapper around ``mamba_ssm.Mamba2``.

    The forward SSM processes the sequence as given. When ``bidirectional`` is
    True, a second SSM processes the sequence reversed along dim 1; its output
    is flipped back and merged with the forward output by addition ("add") or
    element-wise product ("ew_multiply"). With ``bidirectional_weight_tie`` the
    reverse SSM reuses the forward SSM's ``in_proj`` / ``out_proj`` weight and
    bias objects. All remaining keyword arguments go to both Mamba2 constructors.

    Raises:
        NotImplementedError: for an unknown ``bidirectional_strategy`` when
            ``bidirectional`` is True.
    """
    def __init__(
        self,
        d_model: int,
        bidirectional: bool = True,
        bidirectional_strategy: str = "add",      # "add" | "ew_multiply"
        bidirectional_weight_tie: bool = True,
        # ---- forwarded to Mamba2 ----
        d_state: int = 128,
        expand: int = 2,
        d_conv: int = 4,
        conv_bias: bool = True,
        bias: bool = False,
        headdim: int = 64,
        ngroups: int = 1,
        **mamba2_kwargs,
    ):
        super().__init__()
        if bidirectional and bidirectional_strategy not in {"add", "ew_multiply"}:
            raise NotImplementedError(bidirectional_strategy)

        self.bidirectional = bidirectional
        self.strategy = bidirectional_strategy

        ssm_config = dict(
            d_model=d_model,
            d_state=d_state,
            expand=expand,
            d_conv=d_conv,
            conv_bias=conv_bias,
            bias=bias,
            headdim=headdim,
            ngroups=ngroups,
            **mamba2_kwargs,
        )
        # The forward SSM is built first, then the reverse one (if any).
        self.mamba_fwd = Mamba2(**ssm_config)
        self.mamba_rev = Mamba2(**ssm_config) if bidirectional else None

        if bidirectional and bidirectional_weight_tie:
            # Share the projection parameters (bias may be None when bias=False).
            for layer_name in ("in_proj", "out_proj"):
                src_layer = getattr(self.mamba_fwd, layer_name)
                dst_layer = getattr(self.mamba_rev, layer_name)
                dst_layer.weight = src_layer.weight
                dst_layer.bias = src_layer.bias

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """Run the SSM(s) over ``x`` (sequence along dim 1); ``mask`` is accepted but unused."""
        fwd_out = self.mamba_fwd(x)
        if not self.bidirectional:
            return fwd_out

        rev_out = self.mamba_rev(torch.flip(x, dims=[1]))
        rev_out = torch.flip(rev_out, dims=[1])
        if self.strategy == "add":
            return fwd_out + rev_out
        if self.strategy == "ew_multiply":
            return fwd_out * rev_out
        raise RuntimeError(self.strategy)


class ChunkModule(nn.Module):
    """Stack of ``n_layers`` BiMamba2Block layers applied one after another to a chunk.

    Extra keyword arguments are passed to every BiMamba2Block.
    """
    def __init__(self, d_model: int, n_layers: int, **mamba_kwargs):
        super().__init__()
        layers = []
        for _ in range(n_layers):
            layers.append(BiMamba2Block(d_model=d_model, **mamba_kwargs))
        self.blocks = nn.ModuleList(layers)

    def forward(self, x_chunk):
        """Apply every block in order; shape is preserved."""
        hidden = x_chunk
        for layer in self.blocks:
            hidden = layer(hidden)
        return hidden

class EvoFill(nn.Module):
    """
    Chunked bidirectional-Mamba2 genotype imputation model.

    Pipeline:
      1. CatEmbeddings on alleles + site coordinates.
      2. Split the sequence into overlapping windows of ``chunk_size`` sites.
         The stride is ``chunk_size - chunk_overlap``.
      3. Run one shared ChunkModule over every window.
      4. Average the window outputs where windows overlap.
      5. Apply Conv1d(k=3) + GELU, then Conv1d(k=1), to get ``n_cats`` logits per site.

    Raises:
        ValueError: if ``chunk_overlap`` is not in ``[0, chunk_size)`` (the
            window stride would be zero or negative).
    """
    def __init__(
        self,
        n_cats: int,
        chunk_size: int,
        d_model: int = 256,
        n_layers: int = 4,
        chunk_overlap: int = 64,          # overlap length between windows
        **mamba_kwargs,
    ):
        super().__init__()
        if not 0 <= chunk_overlap < chunk_size:
            raise ValueError(
                f"chunk_overlap must be in [0, chunk_size); got "
                f"chunk_overlap={chunk_overlap}, chunk_size={chunk_size}")
        self.n_cats = n_cats
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

        self.embed = CatEmbeddings(n_cats, d_model)
        # A single ChunkModule: all windows share its weights.
        self.chunk_module = ChunkModule(d_model, n_layers, **mamba_kwargs)

        self.length_proj = nn.Sequential(
            nn.Conv1d(d_model, d_model, kernel_size=3, padding=1),
            nn.GELU(),
        )
        self.out_conv = nn.Conv1d(d_model, n_cats, kernel_size=1)

    def forward(self, x: torch.Tensor, x_coord: torch.Tensor):
        """
        Args:
            x: (B, L) integer allele indices; -1 marks a missing/masked allele.
            x_coord: (L, 4) float per-site coordinates, shared across the batch.
        Returns:
            (B, L, n_cats) logits; position i corresponds to site i of ``x``.
        """
        B, L_orig = x.shape

        # 1. embeddings
        h = self.embed(x, x_coord)                          # (B, L_orig, d)

        # 2. sliding-window layout
        chunk_size = self.chunk_size
        overlap = self.chunk_overlap
        step = chunk_size - overlap
        # At least one window, also when L_orig <= overlap (the formula alone
        # would give zero windows and an all-zero output).
        n_chunks = max(1, math.ceil((L_orig - overlap) / step))

        # 3. right-pad so the last window is full
        pad_len = n_chunks * step + overlap - L_orig
        if pad_len > 0:
            h = F.pad(h, (0, 0, 0, pad_len))                # (B, L_pad, d)
        L_pad = h.shape[1]

        # 4. run the shared module on every window and accumulate its outputs,
        #    counting how many windows cover each position
        out_buf = h.new_zeros(B, L_pad, h.shape[-1])
        coverage = h.new_zeros(L_pad)
        for i in range(n_chunks):
            start = i * step
            end = start + chunk_size
            out_buf[:, start:end, :] += self.chunk_module(h[:, start:end, :])
            coverage[start:end] += 1

        # 5. average over overlapping windows
        out_buf = out_buf / coverage.clamp_min(1).unsqueeze(-1)

        # 6. drop the right padding. Cropping (instead of resampling the padded
        #    sequence to L_orig) keeps output position i aligned with site i.
        out_buf = out_buf[:, :L_orig, :]
        out = self.length_proj(out_buf.transpose(1, 2))     # (B, d, L_orig)

        # 7. per-site logits
        logits = self.out_conv(out).transpose(1, 2)         # (B, L_orig, n_cats)
        return logits

模型单测，检查输出张量形状

In [None]:
# unit test
# Smoke test on GPU: random inputs with values in {-1, 0, 1} and length 1800
# (not a multiple of the window stride 512 - 64 = 448) should produce logits of
# shape (batch, length, n_cats).
model = EvoFill(
    n_cats=3,
    chunk_size=512,
    chunk_overlap=64,
    d_model=256,
    n_layers=4,
    d_state=128,
    expand=2,
).cuda()

x = torch.randint(-1, 2, (2, 1800)).cuda() # values -1 (missing), 0, 1
x_coord = torch.randn(1800, 4).cuda()
logits = model(x, x_coord)
print(logits.shape)

torch.Size([2, 1800, 3])


## 3. Loss

STICI 的 `ImputationLoss` 计算本身以及Pytorch复现存在问题：

`tf.keras.losses.KLDivergence` 本身要求输入为概率分布（内部只对输入做裁剪、不会再做 `softmax`，见下表）；若在损失计算中再对模型输出做一次 `softmax`，就会出现重复 softmax 的问题。

然而 STICI 最后一层（`STICI_V1.1.py`, L387）：

`self.last_conv = layers.Conv1D(self.in_channel - 1, 5, padding='same', activation=tf.nn.softmax)`

输出值已经做过一次`softmax`，如此一来会导致计算的KL散度偏大。

同时，Pytorch 中的`nn.KLDivLoss`输入要求为`softmax`处理后的`log-probabilities`，其余差异见下表：

| 特性                | `tf.keras.losses.KLDivergence` | `nn.KLDivLoss`                        |
| ----------------- | ------------------------------ | ------------------------------------- |
| 输入格式              | 概率                             | 输入：log-probabilities，目标：probabilities |
| 是否需手动取 log        | ❌                              | ✅                                     |
| 是否自动裁剪输入          | ✅                              | ❌                                     |
| 默认归约方式            | `sum_over_batch_size`          | `mean`（官方推荐改用 `batchmean`）          |
| 是否支持 `log_target` | ❌                              | ✅（可选）                                 |

在 `y_true` 为 one-hot 编码时，真实分布 `y_true` 的熵为0，交叉熵和KL散度应该相等，两者累加无意义。

在以下代码中采取了和STICI不同的处理方法：保留MCE，删除KL，R2取10/log（而非负数），用GradNorm平衡MCE和R2损失。实际性能表现待评估。

In [3]:
from torch.autograd import grad

class GradNormLoss(nn.Module):
    """
    GradNorm (Chen et al., 2018) adaptive weighting of ``num_tasks`` losses.

    ``forward`` records the per-task losses and returns sum_i w_i * L_i for the
    main optimizer. ``gradnorm_step`` then performs one manual SGD step on the
    task weights ``w`` using the GradNorm objective and renormalizes them so
    they sum to ``num_tasks``.
    """
    def __init__(self, num_tasks=2, alpha=1.5, lr_w=1e-3, eps=1e-8):
        super().__init__()
        self.num_tasks = num_tasks
        self.alpha = alpha      # strength of the pull towards equal training rates
        self.lr_w = lr_w        # learning rate of the task weights
        self.eps = eps
        self.w = nn.Parameter(torch.ones(num_tasks))          # task weights
        self.register_buffer('L0', torch.zeros(num_tasks))    # losses at the first forward
        self.initialized = False

    def forward(self, losses: torch.Tensor):
        """
        Args:
            losses: 1-D tensor of length ``num_tasks`` holding the per-task
                losses, still attached to the model graph.
        Returns:
            Scalar weighted sum of the losses.
        """
        if not self.initialized:
            self.L0 = losses.detach().clone()
            self.initialized = True

        self.L_t = losses
        return (self.w * losses).sum()

    def gradnorm_step(self, shared_params, retain_graph=False):
        """
        Update ``w`` once; no-op before the first ``forward``.

        Call after ``forward`` and BEFORE the main loss's ``backward()``: the
        task-loss graphs are traversed here with ``retain_graph=True`` so the
        subsequent backward still works.

        Args:
            shared_params: parameters shared by all tasks. The gradient w.r.t.
                the FIRST of them is used to measure the per-task gradient norms.
            retain_graph: kept for backward compatibility; ignored (the graphs
                are always retained).
        """
        if not self.initialized:
            return

        # 1. per-task gradient norms G_i = ||grad_W (w_i * L_i)||.
        #    grad_W L_i does not depend on w, so no higher-order graph is needed;
        #    only the multiplication by w_i must stay differentiable.
        G_t = []
        for i in range(self.num_tasks):
            g = grad(self.L_t[i], shared_params, retain_graph=True)[0]
            G_t.append(torch.norm(g.detach() * self.w[i]) + self.eps)
        G_t = torch.stack(G_t)                      # [T]

        # 2. relative inverse training rates r_i(t)
        tilde_L_t = (self.L_t / self.L0).detach()
        r_t = tilde_L_t / tilde_L_t.mean()

        # 3. target norms, treated as constants as in the paper, so the gradient
        #    flows only through G_t. r_t is clamped positive: a non-positive
        #    ratio would make r_t ** alpha NaN for non-integer alpha and poison w.
        target = (G_t.mean() * r_t.clamp_min(self.eps) ** self.alpha).detach()

        # 4. GradNorm loss: mean_i |G_i(t) - target_i|
        l_grad = F.l1_loss(G_t, target)

        # 5. one SGD step on w only, then renormalize to sum to num_tasks
        w_grad = torch.autograd.grad(l_grad, self.w)[0]
        self.w.grad = w_grad
        with torch.no_grad():
            new_w = self.w - self.lr_w * w_grad
            new_w = new_w * (self.num_tasks / new_w.sum())
            # Assign .data rather than copy_(): w was saved by the forward
            # multiplication, and an in-place copy_ would bump its version
            # counter and break the main loss.backward() that follows.
            self.w.data = new_w


class ImputationLoss(nn.Module):
    """
    Masked cross-entropy plus an optional group-wise allele-frequency penalty.

    Labels equal to -1 are ignored. The two terms are either added
    (``use_grad_norm=False``) or combined by a GradNormLoss with learnable task
    weights (``use_grad_norm=True``).
    """
    def __init__(self, use_r2_loss=True, group_size=4, eps=1e-8,
                 use_grad_norm=False, gn_alpha=1.5, gn_lr_w=1e-3):
        super().__init__()
        self.use_r2_loss = use_r2_loss
        self.group_size = group_size
        self.eps = eps
        self.use_gn = use_grad_norm
        if self.use_gn:
            self.gn_loss = GradNormLoss(num_tasks=2, alpha=gn_alpha, lr_w=gn_lr_w)

    # ---------- helpers ---------- #
    def _calc_r2(self, pred_alt_prob: torch.Tensor, gt_alt_af: torch.Tensor):
        """
        Per-site penalty (pred_af - gt_af)^2 / max(gt_af * (1 - gt_af), 0.01).

        Sites whose true AF is exactly 0 or 1 contribute 0.
        """
        mask = ((gt_alt_af == 0.0) | (gt_alt_af == 1.0))
        gt_alt_af = torch.where(mask, 0.5, gt_alt_af)
        denom = gt_alt_af * (1.0 - gt_alt_af)
        denom = torch.clamp(denom, min=0.01)
        r2 = ((pred_alt_prob - gt_alt_af) ** 2) / denom
        r2 = torch.where(mask, 0.0, r2)
        return r2

    # ---------- R2 loss (10 / log(penalty)) ---------- #
    def _r2_loss(self, y_pred: torch.Tensor, y_true: torch.Tensor, mask_valid: torch.Tensor):
        """
        Group the batch into chunks of ``group_size`` rows. For each group,
        compare predicted vs. true alt-allele frequency per site. Return
        10 / log(sum of size-weighted group penalties).

        NOTE(review): 10 / log(x) is not monotonic in the penalty x. It is
        negative for x < 1, singular at x = 1 and decreasing for x > 1, so
        minimizing it does not minimize the penalty. Kept as designed; revisit.
        """
        B, V, C = y_pred.shape
        G = self.group_size
        num_full, rem = divmod(B, G)

        prob = F.softmax(y_pred, dim=-1)
        # Expected alt-allele count: sum_k k * p_k (= p1 + 2 * p2 for three
        # classes; also valid for C = 2 or C > 3).
        dosage_weights = torch.arange(C, device=prob.device, dtype=prob.dtype)
        alt_prob = (prob * dosage_weights).sum(dim=-1)

        r2_penalty = y_pred.new_zeros(())

        def one_group(sl):
            gt_sl = y_true[sl]                     # (g_size, V)
            mask_sl = mask_valid[sl]               # (g_size, V)
            alt_sl = alt_prob[sl]                  # (g_size, V)

            n_valid = mask_sl.sum(dim=0) + self.eps
            gt_alt_af = (gt_sl * mask_sl).sum(dim=0) / n_valid
            pred_alt_af = (alt_sl * mask_sl).sum(dim=0) / n_valid

            r2 = self._calc_r2(pred_alt_af, gt_alt_af)          # (V,)
            return r2.sum() * (sl.stop - sl.start)              # weight by group size

        # full groups
        for g in range(num_full):
            r2_penalty = r2_penalty + one_group(slice(g * G, (g + 1) * G))

        # remaining rows
        if rem:
            r2_penalty = r2_penalty + one_group(slice(num_full * G, B))

        return 10.0 / torch.log(r2_penalty + self.eps)

    # ---------- forward ---------- #
    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred: (B, V, C) logits.
            y_true: (B, V) integer labels; -1 marks positions to ignore.
        Returns:
            Scalar loss tensor.
        """
        mask_valid = (y_true != -1)
        y_true_m = y_true.masked_fill(~mask_valid, 0)

        # 1. masked cross-entropy, averaged over valid positions
        log_p = F.log_softmax(y_pred, dim=-1)
        ce = -log_p.gather(dim=-1, index=y_true_m.long().unsqueeze(-1)).squeeze(-1)
        ce = (ce * mask_valid).sum() / (mask_valid.sum() + self.eps)

        # 2. R² term. A zero tensor (not a Python float) keeps `r2` stackable
        #    and addable even when the term is disabled.
        r2 = ce.new_zeros(())
        if self.use_r2_loss:
            r2 = self._r2_loss(y_pred, y_true, mask_valid)

        # 3. GradNorm weighting or plain sum
        if self.use_gn:
            return self.gn_loss(torch.stack([ce, r2]))
        return ce + r2



## 4. Training

单卡，早停，AdamW优化器

In [4]:
import torch
from pathlib import Path

from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW

class GenotypeDataset(Dataset):
    """
    Row-wise dataset over a genotype matrix with optional random input masking.

    Item ``idx`` is ``(model_input, target, coords)``:
      * ``target`` is row ``idx`` of the labels, as int64.
      * ``model_input`` is a copy of ``target`` in which each entry is
        independently set to -1 with probability ``mask_ratio``.
      * ``coords`` is the full float32 coordinate tensor shared by all rows.
    """
    def __init__(self, gts, coords, mask_ratio=0.0):
        self.gt_true = gts.long()          # complete labels
        self.coords = coords.float()
        self.mask_ratio = mask_ratio

    def __len__(self):
        return self.gt_true.size(0)

    def __getitem__(self, idx):
        target = self.gt_true[idx]
        model_input = target.clone()
        if self.mask_ratio > 0:
            # Extra random masking of the input only; the target is untouched.
            hide = torch.rand_like(model_input.float()) < self.mask_ratio
            model_input[hide] = -1
        return model_input, target, self.coords

def collate_fn(batch):
    """
    Collate ``(gt_mask, gt_true, coords)`` samples into one batch.

    The two genotype fields are stacked along a new leading batch dimension.
    ``coords`` is taken from the first sample only, because every sample
    carries the same coordinate tensor.
    """
    def stack_field(pos):
        return torch.stack([sample[pos] for sample in batch], dim=0)

    return stack_field(0), stack_field(1), batch[0][2]

def build_loader(pt_path, batch_size, shuffle, mask_ratio):
    """
    Load a dict saved with ``torch.save`` (keys 'gts' and 'coords' are read)
    and wrap it in a DataLoader over GenotypeDataset.

    ``drop_last`` follows ``shuffle``: a shuffled (training) loader drops the
    incomplete last batch, an unshuffled one keeps it.
    """
    # torch.load unpickles the file; only use it on trusted local data.
    payload = torch.load(pt_path)
    ds = GenotypeDataset(payload['gts'], payload['coords'], mask_ratio=mask_ratio)
    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=shuffle,
        collate_fn=collate_fn,
    )
    return DataLoader(ds, **loader_options)

def imputation_accuracy(logits, gts, mask):
    """
    Accuracy of argmax predictions at the positions selected by ``mask``.

    Args:
        logits: (B, L, C) scores.
        gts: (B, L) integer labels.
        mask: (B, L) bool; True marks positions to score.
    Returns:
        0-d float tensor. It is 0.0 when ``mask`` selects nothing; the
        unguarded 0/0 would give NaN.
    """
    preds = torch.argmax(logits, dim=-1)  # (B, L)
    correct = (preds == gts) & mask
    n_scored = mask.sum()
    return correct.sum().float() / n_scored.clamp_min(1).float()

def train_epoch(model, loader, criterion, optimizer, device):
    """
    Run one training epoch.

    Per batch:
      1. Forward on the masked input.
      2. Compute ``criterion`` against the full labels.
      3. Optionally update the GradNorm task weights.
      4. Backpropagate and step the optimizer.

    Returns:
        (mean loss over batches, accuracy over positions that were masked in
        the input AND have a known label). Each is 0.0 when there is nothing
        to average.
    """
    model.train()
    total_loss, total_acc, total_mask = 0.0, 0.0, 0
    n_batches = 0

    # GradNorm measures per-task gradient norms on these layers.
    shared_params = list(model.length_proj.parameters()) + list(model.out_conv.parameters())
    assert len(shared_params) > 0
    assert all(p.requires_grad for p in shared_params)

    pbar = tqdm(loader, leave=False)
    for gt_mask, gt_true, coords in pbar:
        gt_mask, gt_true, coords = gt_mask.to(device), gt_true.to(device), \
                                   coords.to(device)

        logits = model(gt_mask, coords)  # (B, L, n_cats)
        loss = criterion(logits, gt_true)

        if criterion.use_gn:
            # Runs before loss.backward(): it re-traverses the task-loss graphs.
            criterion.gn_loss.gradnorm_step(shared_params, retain_graph=False)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Score only positions hidden in the input that also have a known label;
        # sites already missing in the data (label -1) can never be "correct".
        mask = (gt_mask == -1) & (gt_true != -1)
        acc = imputation_accuracy(logits, gt_true, mask)
        n_masked = mask.sum().item()
        total_loss += loss.item()
        if n_masked:
            total_acc += acc.item() * n_masked
        total_mask += n_masked
        n_batches += 1
        pbar.set_postfix(loss=f"{loss.item():.4f}", acc=f"{acc.item():.4f}")

    return total_loss / max(n_batches, 1), total_acc / max(total_mask, 1)


@torch.no_grad()
def validate(model, loader, criterion, device):
    """
    Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns:
        (mean ``criterion`` loss over batches, accuracy over positions that were
        masked in the input AND have a known label). Each is 0.0 when there is
        nothing to average.

    NOTE(review): when ``criterion`` uses GradNorm, the loss is weighted by the
    current task weights, so values from different epochs are on different scales.
    """
    model.eval()
    total_loss, total_acc, total_mask = 0.0, 0.0, 0
    n_batches = 0

    pbar = tqdm(loader, leave=False, desc='validate')
    for gt_mask, gt_true, coords in pbar:
        gt_mask, gt_true, coords = gt_mask.to(device), \
                                   gt_true.to(device), \
                                   coords.to(device)

        logits = model(gt_mask, coords)          # (B, L, n_cats)
        loss = criterion(logits, gt_true)        # against the full labels

        # Score only hidden positions with a known label (label -1 = missing in data).
        mask = (gt_mask == -1) & (gt_true != -1)
        acc = imputation_accuracy(logits, gt_true, mask)
        n_masked = mask.sum().item()

        total_loss += loss.item()
        if n_masked:
            total_acc += acc.item() * n_masked
        total_mask += n_masked
        n_batches += 1

        pbar.set_postfix(loss=f"{loss.item():.4f}", acc=f"{acc.item():.4f}")

    return total_loss / max(n_batches, 1), total_acc / max(total_mask, 1)


class EarlyStopper:
    """
    Patience-based early stopping that keeps a CPU copy of the best weights.

    Call once per epoch with the monitored metric and the model.
      * The first call records the baseline and returns False.
      * Later calls return True once ``counter`` (consecutive calls that did not
        beat the best by more than ``min_delta``) reaches ``patience``.
      * Lower is better for mode='min'; higher is better for mode='max'.
    """
    def __init__(self, patience=10, min_delta=0.0, mode='min'):
        assert mode in {'min', 'max'}
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best = None
        self.best_state = None

    @staticmethod
    def _snapshot(model):
        # CPU clone so later training steps cannot overwrite the saved weights.
        return {name: tensor.cpu().clone() for name, tensor in model.state_dict().items()}

    def _improves(self, metric):
        if self.mode == 'min':
            return metric < self.best - self.min_delta
        return metric > self.best + self.min_delta

    def _record(self, metric, model):
        self.best = metric
        self.best_state = self._snapshot(model)

    def __call__(self, metric, model):
        if self.best is None:
            self._record(metric, model)
            return False

        if self._improves(metric):
            self._record(metric, model)
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience

In [None]:
import os
import torch
from pathlib import Path
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from tqdm import tqdm

# Full training run: build loaders, model, loss and optimizer from the config,
# train with early stopping on the validation loss, then save the best weights.
cfg = load_config("/home/qmtang/mnt_qmtang/EvoFill/config/config.json")
device = torch.device(cfg.train.device)
torch.manual_seed(42)

# Data
# NOTE(review): the validation loader also masks inputs at random (mask_ratio);
# the RNG is seeded only once above, so each epoch's validation uses different
# masks and val metrics include that noise.
train_loader = build_loader(
    Path(cfg.data.path) / "train.pt",
    batch_size=cfg.train.batch_size,
    shuffle=True,
    mask_ratio=cfg.train.mask_ratio,
)
val_loader = build_loader(
    Path(cfg.data.path) / "val.pt",
    batch_size=cfg.train.batch_size,
    shuffle=False,
    mask_ratio=cfg.train.mask_ratio,
)

# Model & optimizer
model = EvoFill(**vars(cfg.model)).to(device)

criterion = ImputationLoss(use_r2_loss=True,
                        use_grad_norm=True,
                        gn_alpha=0.8,
                        gn_lr_w=cfg.train.lr/10).to(device) # task-weight LR, kept below the model LR

optimizer = AdamW(model.parameters(), lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)

# NOTE(review): with use_grad_norm=True the criterion returns sum(w_i * L_i)
# and w is updated every training step, so val_loss drifts in scale across
# epochs; consider early-stopping on an unweighted loss or on val_acc.
early_stopper = EarlyStopper(patience=cfg.train.patience,
                                min_delta=cfg.train.min_delta,
                                mode='min')

# Training loop
for epoch in range(1, cfg.train.num_epochs + 1):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    print(f"Epoch {epoch:03d} | train loss {train_loss:.4f} acc {train_acc:.4f} | "
            f"val loss {val_loss:.4f} acc {val_acc:.4f}")
    if early_stopper(val_loss, model):
        print(f"Early stopping triggered at epoch {epoch}")
        break

# Save the best model
# `epoch - early_stopper.counter` is the epoch of the last improvement.
# NOTE(review): `epoch` is undefined if cfg.train.num_epochs < 1.
save_dir = Path(cfg.train.save)
save_dir.mkdir(parents=True, exist_ok=True)
torch.save(early_stopper.best_state, save_dir / "evofill_best.pt")
print(f"Best model saved to {save_dir / 'evofill_best.pt'} (epoch {epoch - early_stopper.counter})")

## 5. Imputation

## 6. Evaluation