ver0: 多 chunk modules 独立权重

ver0.1: 加样本特征标签（演化坐标）

ver3: chunk-wise 稀疏激活

## Dependency

In [25]:
import os; os.environ["CUDA_VISIBLE_DEVICES"] = "0" # 设置GPU
from cyvcf2 import VCF
import scipy.sparse as sp
import json
import shutil
from typing import Dict, List, Optional, Tuple, Union
from itertools import combinations

import numpy as np
import random
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.checkpoint import checkpoint

from tqdm import tqdm

from sklearn.model_selection import train_test_split

from mamba_ssm import Mamba2
from mamba_ssm.modules.mamba2_simple import Mamba2Simple as Mamba2Block # 原Mamba2Block

## Data

In [None]:
class GenotypeEncoder:
    """Encode a VCF into a sparse int8 genotype matrix and persist it to ``save_dir``.

    Rows of the matrix are haplotypes (two per sample, written as ``<id>_A`` /
    ``<id>_B``) when ``phased`` is true, otherwise one row per sample; columns
    are variants in VCF order. Construction reads the VCF, optionally loads
    per-sample extra features from ``ref_extra``, and writes the matrix, the ID
    lists and a metadata JSON to ``save_dir``. :meth:`loadfromdisk` restores a
    saved encoder without re-reading the VCF.

    Encoding (``gts012``):
        True  -> 0/1/2 with 3 = missing; ``seq_depth`` is always 4.
        False -> raw allele indices (phased) or summed ALT indices (unphased);
                 missing calls are remapped to ``seq_depth - 1`` after reading.
    """

    def __init__(self,
                 save_dir: str,
                 vcf_path: str,
                 ref_extra: Optional[str] = None,
                 phased: bool = True,
                 gts012: bool = False):
        self.save_dir    = save_dir
        self.vcf_path    = vcf_path
        self.ref_extra   = ref_extra
        # Extra features are per sample, so rows cannot be split into haplotypes
        # when they are given: force the per-sample (dosage) layout.
        self.phased      = phased if ref_extra is None else False
        self.gts012      = gts012

        # Filled while reading the VCF.
        self.hap_map     = {}     # genotype string -> integer code
        self.n_samples   = 0
        self.n_variants  = 0
        self.sample_ids  = []
        self.variant_ids = []

        self.X_gt        = None   # scipy.sparse.csc_matrix, (rows, variants)
        self.X_extra     = None   # float32 extra features, or None
        self.seq_depth   = None   # number of codes including the missing token

        # 1) read the VCF (also writes matrix / ID files)
        self.X_gt = self.load_gt()
        # 2) read the optional extra features
        self.X_extra = self.load_extra() if self.ref_extra else None
        # 3) write metadata
        self.save_meta()

    @staticmethod
    def _is_missing(allele) -> bool:
        """Return True for a missing allele.

        Missing alleles may surface as ``None`` or as a negative index
        (cyvcf2's ``Variant.genotypes`` uses -1); both are treated as missing.
        """
        return allele is None or allele < 0

    def add_hap_map(self, key, val):
        """Record that genotype string ``key`` is encoded as integer ``val``.

        Raises:
            ValueError: if ``key`` was already recorded with a different code.
        """
        val = int(val)
        if self.hap_map.setdefault(key, val) != val:
            raise ValueError(f"[DATA] hap_map[{key}] inconsistent")

    def encode_gt(self, rec, n_samples, phase=False, gts012=True):
        """Encode one VCF record.

        Args:
            rec: record exposing ``genotypes`` as ``[a1, a2, phased]`` per sample.
            n_samples: number of samples in ``rec.genotypes``.
            phase: True -> one entry per haplotype, False -> one per sample.
            gts012: select the compressed 0/1/2/3 encoding (see below).

        Returns:
            phase=False -> int8 array of shape (n_samples,), dosage / genotype code
            phase=True  -> int8 array of shape (2*n_samples,), haplotype codes

        Encoding rule:
            gts012=True  -> 0/1/2/3  (3 = missing)
            gts012=False -> 0/1/2/…/-1  (0 = REF, 1+ = ALT, -1 = missing)

        Every (genotype string -> code) pair seen is recorded in ``hap_map``.
        """
        n = n_samples
        missing_code = 3 if gts012 else -1

        def _code(allele):
            if self._is_missing(allele):
                return missing_code
            if gts012:                       # compress to 0/1/2
                return 0 if allele == 0 else (2 if allele >= 2 else 1)
            return allele                    # keep the raw allele index

        # ---------- 1. haplotype mode ----------
        if phase:
            out = np.empty(2 * n, dtype=np.int8)
            for i, gt in enumerate(rec.genotypes):
                a1, a2, _phased = gt
                for j, allele in enumerate((a1, a2)):
                    pos = 2 * i + j
                    out[pos] = _code(allele)
                    key = '.' if self._is_missing(allele) else str(allele)
                    self.add_hap_map(key, out[pos])
            return out

        # ---------- 2. dosage mode ----------
        out = np.empty(n, dtype=np.int8)
        for i, gt in enumerate(rec.genotypes):
            a1, a2, is_phased = gt
            sep = '|' if is_phased else '/'
            if self._is_missing(a1) or self._is_missing(a2):
                out[i] = missing_code
            elif gts012:
                # 0/1/2 dosage: count non-REF alleles
                out[i] = (1 if a1 > 0 else 0) + (1 if a2 > 0 else 0)
            else:
                # multi-allelic "dosage": sum of the ALT indices
                out[i] = a1 + a2
            s1 = '.' if self._is_missing(a1) else str(a1)
            s2 = '.' if self._is_missing(a2) else str(a2)
            s1, s2 = sorted([s1, s2])        # unordered key, e.g. 1|0 -> 0|1
            self.add_hap_map(s1 + sep + s2, out[i])
        return out

    def load_extra(self) -> Optional[np.ndarray]:
        """Load per-sample extra features from the TSV at ``ref_extra``.

        Rows are reordered to match ``sample_ids``. This is best effort: any
        failure (missing file, unknown sample IDs, ...) is reported and None is
        returned.
        """
        try:
            df = pd.read_csv(self.ref_extra, sep='\t', index_col=0)
            df = df.loc[self.sample_ids]          # align with the VCF sample order
            print(f"[DATA] Extra dims: {df.shape}")
            return df.values.astype(np.float32)
        except Exception as e:
            print(f"[DATA] Extra features skipped: {e}")
            return None

    def _finalize_codes(self, M):
        """Set ``self.seq_depth`` and move missing calls to code ``seq_depth - 1``.

        Mutates ``M.data`` in place (gts012=False only) and rewrites the
        missing entries of ``hap_map`` accordingly.
        """
        if self.gts012:
            # Fixed vocabulary 0/1/2 + 3 (missing). Inferring it from the data
            # would yield 3 when nothing is missing, making the missing token
            # collide with code 2.
            self.seq_depth = 4
            return
        # Clamp at 0 so a matrix whose only nonzeros are -1 still gets REF=0, missing=1.
        max_code = max(int(M.data.max()), 0) if M.nnz else 0
        self.seq_depth = max_code + 2
        M.data[M.data == -1] = max_code + 1
        self.hap_map = {k: self.seq_depth - 1 if '.' in str(k) else v
                        for k, v in self.hap_map.items()}

    def load_gt(self):
        """Read the VCF, build the sparse genotype matrix and write it to ``save_dir``.

        Side effects: sets ``sample_ids``, ``n_samples``, ``n_variants``,
        ``variant_ids``, ``seq_depth`` and ``hap_map``; writes
        ``gt_matrix.npz``, ``gt_samples.txt`` and ``gt_variants.txt``.

        Returns:
            scipy.sparse.csc_matrix of shape (rows, n_variants), dtype int8.
        """
        interval = 10000

        # Per-variant (column) arrays of nonzero row indices / values, concatenated
        # once at the end instead of growing Python lists element by element.
        row_chunks, val_chunks, indptr = [], [], [0]

        vcf = VCF(self.vcf_path, gts012 = self.gts012)
        try:
            self.sample_ids = vcf.samples
            self.n_samples = len(self.sample_ids)
            self.n_variants = 0

            for rec in vcf:
                vec = self.encode_gt(rec, self.n_samples, phase=self.phased, gts012=self.gts012)
                nz_idx = np.flatnonzero(vec)
                row_chunks.append(nz_idx)
                val_chunks.append(vec[nz_idx])
                indptr.append(indptr[-1] + len(nz_idx))

                self.n_variants += 1
                self.variant_ids.append(f"{rec.CHROM}:{rec.POS}_{rec.REF}/{','.join(rec.ALT)}")
                if self.n_variants % interval == 0:
                    print(f'\r[DATA] 已编码 {self.n_variants:,} 个位点', end='', flush=True)

            print(f'\r[DATA] 总计 {self.n_variants:,} 个位点  ', flush=True)
        finally:
            vcf.close()   # release the file handle even if encoding fails

        rows = np.concatenate(row_chunks) if row_chunks else np.empty(0, dtype=np.int64)
        vals = np.concatenate(val_chunks) if val_chunks else np.empty(0, dtype=np.int8)

        # Row count depends on the haplotype / dosage layout.
        n_rows = 2 * self.n_samples if self.phased else self.n_samples
        M = sp.csc_matrix((vals, rows, indptr),
                          shape=(n_rows, self.n_variants),
                          dtype=np.int8)

        n_cells = M.shape[0] * M.shape[1]
        print(f'[DATA] 位点矩阵 = {M.shape}，稀疏度 = {M.nnz / max(n_cells, 1):.2%}')
        self._finalize_codes(M)

        max_code = int(M.data.max()) if M.nnz else 0
        print("[DATA] Hap Map: ",self.hap_map)
        print(f'[DATA] gt alleles = [0 - {max_code}], seq_depth = {self.seq_depth} ({self.seq_depth-1} 代表缺失)')

        os.makedirs(self.save_dir, exist_ok=True)

        # sparse matrix
        sp.save_npz(os.path.join(self.save_dir, "gt_matrix.npz"), M)

        # row labels, in matrix row order
        with open(os.path.join(self.save_dir, "gt_samples.txt"), "w") as f:
            if self.phased:                      # haplotype mode: sample_A / sample_B
                for s in self.sample_ids:
                    f.write(f"{s}_A\n{s}_B\n")
            else:                                # dosage mode
                for s in self.sample_ids:
                    f.write(f"{s}\n")

        # variant IDs (CHROM:POS_REF/ALT), in matrix column order
        with open(os.path.join(self.save_dir, "gt_variants.txt"), "w") as f:
            for vid in self.variant_ids:
                f.write(vid + "\n")

        print(f"[DATA] 结果已写入 {self.save_dir}")
        return M

    def save_meta(self):
        """Write ``gt_enc_meta.json`` (scalars stored as strings) and, if present, ``gt_extra.npy``."""
        def _make_json_safe(obj):
            """Recursively convert numpy values, tuples, sets and bytes into JSON-friendly types."""
            if isinstance(obj, dict):
                return {k: _make_json_safe(v) for k, v in obj.items()}
            if isinstance(obj, (list, tuple, set)):
                return [_make_json_safe(i) for i in obj]
            if isinstance(obj, np.ndarray):
                return _make_json_safe(obj.tolist())
            if isinstance(obj, (np.integer, np.floating)):
                return obj.item()
            if isinstance(obj, bytes):
                return obj.decode(errors='ignore')
            return obj
        meta = {
            "vcf_path"   : str(self.vcf_path),
            "ref_extra"  : str(self.ref_extra),
            "phased"     : str(self.phased),
            "gts012"     : str(self.gts012),
            "n_samples"  : str(self.n_samples),
            "n_variants" : str(self.n_variants),
            "seq_depth"  : str(self.seq_depth),
            "hap_map"    : _make_json_safe(self.hap_map),
        }
        with open(os.path.join(self.save_dir, "gt_enc_meta.json"), "w") as f:
            json.dump(meta, f, indent=2)

        if self.X_extra is not None:
            np.save(os.path.join(self.save_dir, "gt_extra.npy"), self.X_extra)

    @classmethod
    def loadfromdisk(cls, work_dir: str):
        """Rebuild a GenotypeEncoder from ``work_dir`` without reading the VCF.

        ``work_dir`` must contain:
            gt_matrix.npz      -> X_gt  (scipy.sparse matrix)
            gt_samples.txt     -> sample_ids
            gt_variants.txt    -> variant_ids
            gt_enc_meta.json   -> remaining scalar / boolean / path fields
        ``gt_extra.npy`` is loaded into ``X_extra`` when present.

        Raises:
            FileNotFoundError: if the metadata file is missing.
        """
        def _as_bool(value) -> bool:
            # save_meta stores booleans via str(); bool("False") would be True.
            if isinstance(value, str):
                return value.strip().lower() == "true"
            return bool(value)

        def _as_optional(value):
            # save_meta stores a None path as the string "None".
            return None if value in (None, "None") else value

        meta_path = os.path.join(work_dir, "gt_enc_meta.json")
        if not os.path.exists(meta_path):
            raise FileNotFoundError(f"{meta_path} 不存在，无法反序列化")
        with open(meta_path) as f:
            meta = json.load(f)

        # Bypass __init__, which would re-read the VCF.
        obj = cls.__new__(cls)
        obj.save_dir    = work_dir
        obj.vcf_path    = meta["vcf_path"]
        obj.ref_extra   = _as_optional(meta["ref_extra"])
        obj.phased      = _as_bool(meta["phased"])
        obj.gts012      = _as_bool(meta["gts012"])
        obj.n_samples   = int(meta["n_samples"])
        obj.n_variants  = int(meta["n_variants"])
        obj.seq_depth   = int(meta["seq_depth"])
        obj.hap_map     = meta["hap_map"]

        with open(os.path.join(work_dir, "gt_samples.txt")) as f:
            obj.sample_ids = [line.rstrip("\n") for line in f]
        with open(os.path.join(work_dir, "gt_variants.txt")) as f:
            obj.variant_ids = [line.rstrip("\n") for line in f]

        obj.X_gt = sp.load_npz(os.path.join(work_dir, "gt_matrix.npz"))

        extra_path = os.path.join(work_dir, "gt_extra.npy")
        obj.X_extra = np.load(extra_path) if os.path.exists(extra_path) else None

        return obj

In [None]:
# Encode the trimmed chr22 training VCF together with the per-sample extra
# feature table; outputs go to work_dir. (Because ref_extra is given, the
# encoder falls back to the per-sample dosage layout even though phased=True.)
work_dir = '/mnt/qmtang/EvoFill/data/251027_ver3_chr22_trim'
gt_enc = GenotypeEncoder(
    save_dir=work_dir,
    vcf_path='/home/qmtang/GitHub/STICI-HPC/data/training_sets/ALL.chr22.training.samples.100k.any.type.0.01.maf.variants.vcf.gz',
    ref_extra='/mnt/qmtang/EvoFill/data/251020_ver01_chr22/pop_wasserstein.tsv',
    phased=True,
    gts012=False,
)

# Summary: one line per statistic.
print(
    f"[DATA] {gt_enc.n_samples:,} Samples",
    f"[DATA] {gt_enc.n_variants:,} Variants Sites",
    f"[DATA] {gt_enc.seq_depth} seq_depth",
    sep="\n",
)

[DATA] 总计 99,314 个位点  
[DATA] 位点矩阵 = (2404, 99314)，稀疏度 = 28.10%
[DATA] {'0|0': 0, '0|1': 1, '1|1': 2}
[DATA] gt alleles = [0 - 2], seq_depth = 4 (3 代表缺失)
[DATA] 结果已写入 /mnt/qmtang/EvoFill/data/251027_ver3_chr22_trim
[DATA] Extra dims: (2404, 26)
[DATA] 2,404 Samples
[DATA] 99,314 Variants Sites
[DATA] 4 seq_depth
[DATA] Hap Map: {'0|0': 0, '0|1': 1, '1|1': 2}


In [8]:
# Encode the full 1kGP phase-3 chr22 VCF as haplotypes (no extra features).
work_dir = '/mnt/qmtang/EvoFill/data/251027_ver3_chr22'
gt_enc = GenotypeEncoder(
    save_dir=work_dir,
    vcf_path='/mnt/NAS/Omics/DNA/1kGP/vcf/ALL.chr22.phase3_shapeit2_mvncall_integrated_v5b.20130502.genotypes.vcf.gz',
    ref_extra=None,
    phased=True,
    gts012=False,
)

# Summary: one line per statistic.
print(
    f"[DATA] {gt_enc.n_samples:,} Samples",
    f"[DATA] {gt_enc.n_variants:,} Variants Sites",
    f"[DATA] {gt_enc.seq_depth} seq_depth",
    sep="\n",
)

[DATA] 总计 1,103,547 个位点  
[DATA] 位点矩阵 = (5008, 1103547)，稀疏度 = 3.71%
[DATA] gt alleles = [0 - 8], seq_depth = 10 (9 代表缺失)
[DATA] 结果已写入 /mnt/qmtang/EvoFill/data/251023_chr22
[DATA] 2,504 Samples
[DATA] 1,103,547 Variants Sites
[DATA] 10 seq_depth


In [18]:
class GenomicDataset(Dataset):
    """Per-row genotype dataset that randomly masks input sites for training.

    Each item is ``(x_onehot, x_extra, y_onehot)``:
      * ``x_onehot``: float32 (n_sites, seq_depth) one-hot of the masked input;
        masked sites carry the missing token ``seq_depth - 1``.
      * ``x_extra``: float32 extra features of the row, or an empty tensor.
      * ``y_onehot``: float32 (n_sites, seq_depth - 1) one-hot of the unmasked
        row; sites already missing in the data (code ``seq_depth - 1``) get an
        all-zero row instead of raising.
    """
    def __init__(self, x_gts_sparse, x_extra=None, seq_depth=4,
                 mask=True, masking_rates=(0.5, 0.99), indices=None):
        """
        x_gts_sparse: scipy sparse matrix (rows x sites) of integer codes.
            Row slicing is cheapest on CSR; a CSC matrix works but each row
            lookup scans the whole matrix, so consider ``.tocsr()`` once.
        x_extra: per-row extra features (indexed like the matrix rows) or None.
        seq_depth: number of input codes including the missing token.
        mask: whether to apply random masking in ``__getitem__``.
        masking_rates: (low, high) of the uniform per-item masking fraction.
        indices: optional subset of row indices (e.g. a train/valid split).
        """
        self.gts_sparse = x_gts_sparse
        self.x_extra = x_extra
        self.seq_depth = seq_depth
        self.mask = mask
        self.masking_rates = masking_rates
        self.indices = indices if indices is not None else np.arange(x_gts_sparse.shape[0])

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        real_idx = self.indices[idx]
        # ravel (not squeeze) keeps a 1-D array even when there is a single site.
        x = self.gts_sparse[real_idx].toarray().ravel().astype(np.int8)
        y = x.copy()

        if self.mask:
            seq_len = x.shape[0]
            masking_rate = np.random.uniform(*self.masking_rates)
            mask_size = int(seq_len * masking_rate)
            mask_indices = np.random.choice(seq_len, mask_size, replace=False)
            x[mask_indices] = self.seq_depth - 1  # missing token

        x_onehot = torch.from_numpy(np.eye(self.seq_depth, dtype=np.float32)[x])
        # seq_depth rows x (seq_depth-1) columns: the missing code maps to a zero row.
        y_onehot = torch.from_numpy(
            np.eye(self.seq_depth, self.seq_depth - 1, dtype=np.float32)[y])

        if self.x_extra is not None:
            x_extra = torch.FloatTensor(self.x_extra[real_idx])
        else:
            x_extra = torch.empty(0)

        return x_onehot, x_extra, y_onehot

class ImputationDataset(Dataset):
    """Unmasked one-hot dataset used at imputation time.

    ``data[idx]`` must be an integer-coded sequence usable as a numpy index;
    each item is returned as a float one-hot tensor of shape
    ``(len(data[idx]), seq_depth)``.
    """

    def __init__(self, data, seq_depth):
        self.data, self.seq_depth = data, seq_depth

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        identity = np.eye(self.seq_depth)
        return torch.FloatTensor(identity[self.data[idx]])

## Model

In [26]:
class GenoEmbedding(nn.Module):
    """Map one-hot allele vectors to ``d_model`` features and add a learned position embedding.

    Args:
        n_alleles: size of the one-hot (last) input dimension.
        n_snps: number of positions the position table covers; inputs may be
            at most this long.
        d_model: output feature size.
    """

    def __init__(self, n_alleles, n_snps, d_model):
        super().__init__()
        self.d_model = d_model
        self.n_alleles = n_alleles
        self.n_snps = n_snps

        # (n_alleles, d_model) allele table, re-initialised with Xavier below.
        self.allele_embedding = nn.Parameter(torch.randn(n_alleles, d_model))
        # one learned vector per position index
        self.position_embedding = nn.Embedding(n_snps, d_model)

        nn.init.xavier_uniform_(self.allele_embedding)

    def forward(self, x):
        """x: (batch, seq_len, n_alleles) one-hot -> (batch, seq_len, d_model)."""
        _batch, seq_len, _width = x.shape

        # Soft lookup: each one-hot row selects (or mixes) rows of the allele table.
        allele_part = torch.einsum('bsn,nd->bsd', x, self.allele_embedding)

        offsets = torch.arange(seq_len, device=x.device)
        return allele_part + self.position_embedding(offsets)[None]

class BiMambaBlock(nn.Module):
    """Bidirectional Mamba2 layer: pre-norm, forward + reversed scans, FFN, post-norm residual.

    Input and output: (batch, seq_len, d_model).
    """

    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model

        def _directional():
            # separate weights for each scan direction
            return Mamba2(
                d_model=d_model,
                d_state=d_state,
                d_conv=d_conv,
                expand=expand
            )

        self.mamba_forward = _directional()
        self.mamba_backward = _directional()

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Fuses the concatenated (2*d_model) bidirectional features back to d_model.
        self.ffn = nn.Sequential(
            nn.Linear(d_model * 2, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
            nn.GELU()
        )

        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        h = self.norm1(x)

        ahead = self.mamba_forward(h)
        # Scan the reversed sequence, then flip back so positions line up.
        behind = self.mamba_backward(h.flip(1)).flip(1)

        mixed = self.dropout(self.ffn(torch.cat((ahead, behind), dim=-1)))
        return self.norm2(x + mixed)

class ConvBlock(nn.Module):
    """Multi-kernel 1-D convolution block for local pattern extraction.

    A k=3 stem feeds two branches (k=5→7 and k=7→15); their sum passes through
    a k=3 conv + BatchNorm and a 1x1 conv + BatchNorm, all with GELU.
    Every conv uses "same" padding, so (batch, seq_len, d_model) is preserved.
    """

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

        # stem and short-range branch
        self.conv1 = nn.Conv1d(d_model, d_model, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(d_model, d_model, kernel_size=5, padding=2)
        self.conv3 = nn.Conv1d(d_model, d_model, kernel_size=7, padding=3)

        # wider-receptive-field branch
        self.conv_large1 = nn.Conv1d(d_model, d_model, kernel_size=7, padding=3)
        self.conv_large2 = nn.Conv1d(d_model, d_model, kernel_size=15, padding=7)

        # head
        self.conv_final = nn.Conv1d(d_model, d_model, kernel_size=3, padding=1)
        self.conv_reduce = nn.Conv1d(d_model, d_model, kernel_size=1)

        self.bn1 = nn.BatchNorm1d(d_model)
        self.bn2 = nn.BatchNorm1d(d_model)

        self.gelu = nn.GELU()

    def forward(self, x):
        act = self.gelu
        h = x.transpose(1, 2)  # Conv1d wants (batch, channels, seq_len)

        stem = act(self.conv1(h))
        fine = act(self.conv3(act(self.conv2(stem))))
        coarse = act(self.conv_large2(act(self.conv_large1(stem))))

        h = self.bn1(act(self.conv_final(fine + coarse)))
        h = self.bn2(act(self.conv_reduce(h)))
        return act(h).transpose(1, 2)

class ExtraEmbedding(nn.Module):
    """Lift a per-sample feature vector into a token sequence and mix it with Mamba2.

    Input:  (B, L) with L == number of extra features.
    Output: (B, L, d_model); each feature becomes one token.

    NOTE(review): ``dropout`` is accepted but not applied anywhere in this module.
    """
    def __init__(
        self,
        d_model: int,
        d_state: int = 64,
        d_conv: int  = 4,
        expand: int  = 2,
        headdim: int = 128,
        ngroups: int = 1,
        dropout: float = 0.1,
        **mamba_kwargs,
    ):
        super().__init__()
        self.d_model   = d_model

        # scalar feature -> d_model vector (shared across features)
        self.in_proj = nn.Linear(1, d_model, bias=False)

        # treats the feature axis as the sequence axis
        self.mamba = Mamba2Block(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            headdim=headdim,
            ngroups=ngroups,
            **mamba_kwargs
        )

        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor):
        """x: (B, L) values, cast to float -> (B, L, d_model)."""
        tokens = self.in_proj(x.float().unsqueeze(-1))
        return self.mamba(self.norm(tokens))

class StackMambaBlock(nn.Module):
    """Cross-context mixer: scan [extra?; global; local] with one Mamba2 and keep the local part.

    Stands in for cross attention: the local tokens are placed last, so the
    scan has already seen the (optional) extra-feature tokens and the global
    tokens when it reaches them.
    """
    def __init__(
        self,
        d_model,
        d_state=64,
        d_conv=4,
        expand=2,
        headdim=128,
        ngroups=1,
        chunk_size=256,
        dropout=0.0,
        d_embed_dropout=0.0,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.d_model = d_model

        # embeds the per-sample extra feature vector as tokens
        self.extra_embed = ExtraEmbedding(d_model=d_model, dropout=d_embed_dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        # SSD core
        self.ssd = Mamba2Block(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            headdim=headdim,
            ngroups=ngroups,
            chunk_size=chunk_size,
            use_mem_eff_path=True,
            device=device,
            dtype=dtype,
        )

        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, d_model),
        )

    def forward(self, local_repr, global_repr, x_extra=None,
                start_offset=0, end_offset=0):
        """
        local_repr:  (B, L, D)
        global_repr: (B, G, D)
        x_extra:     optional (B, E) extra features
        start_offset / end_offset: number of zero rows padded before / after
            the local output along the sequence axis.

        Returns (B, start_offset + L + end_offset, D).
        """
        local_norm  = self.norm1(local_repr)
        global_norm = self.norm2(global_repr)

        # 1. build the scanned sequence [B, (E)+G+L, D]
        tokens = []
        if x_extra is not None:
            tokens.append(self.extra_embed(x_extra))   # (B, E, D)
        tokens.append(global_norm)
        tokens.append(local_norm)
        x = self.ssd(torch.cat(tokens, dim=1))

        # 2. keep the trailing local part; slice from an explicit start because
        #    x[:, -0:] would return the whole sequence when L == 0.
        local_len = local_norm.shape[1]
        x = x[:, x.shape[1] - local_len:, :]

        # 3. residual before padding: padding both terms with zeros and adding
        #    is the same thing, and avoids a length mismatch when offsets are set.
        x = x + local_norm
        if start_offset or end_offset:
            x = F.pad(x, (0, 0, start_offset, end_offset))

        # 4. norm + FFN with residual
        x = self.norm3(x)
        x = self.ffn(x) + x
        return x

class ChunkModule(nn.Module):
    """Process one chunk: BiMamba context, conv stack, Mamba-based cross mixing.

    Input (B, len, d_model) -> output (B, len, 2*d_model): the conv skip branch
    concatenated with the cross-mixed branch.
    """

    def __init__(self, d_model, dropout_rate=0.2):
        super().__init__()
        self.d_model = d_model

        self.bimamba_block = BiMambaBlock(d_model)

        self.conv_block1 = ConvBlock(d_model)
        self.conv_block2 = ConvBlock(d_model)
        self.conv_block3 = ConvBlock(d_model)
        self.conv_block4 = ConvBlock(d_model)

        # Mamba-based replacement for a cross-attention layer.
        self.cross_attention = StackMambaBlock(
            d_model=d_model,
            d_state=64,
            d_conv=4,
            expand=2,
            headdim=128,
            ngroups=1,
            chunk_size=256,
        )

        self.dense = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout_rate)
        self.gelu = nn.GELU()

    def forward(self, x, x_extra=None):
        context = self.bimamba_block(x)

        local = self.conv_block1(context)
        skip = self.conv_block2(local)

        local = self.conv_block3(self.gelu(self.dense(local)))
        # BiMamba output acts as the "global" context for the local branch.
        local = self.dropout(self.cross_attention(local, context, x_extra))
        local = self.conv_block4(local)

        return torch.cat((skip, local), dim=-1)

class UltraLongRangeMamba(nn.Module):
    """Stack of BiMamba layers applied only to the selected (active) positions.

    Cost is linear in the number of selected positions.
    """
    def __init__(self, d_model, d_state=64, d_conv=4, expand=2,
                 n_layers=2, dropout=0.1):
        super().__init__()
        self.d_inner = int(d_model * expand)
        blocks = [
            BiMambaBlock(d_model, d_state=d_state, d_conv=d_conv, expand=expand)
            for _ in range(n_layers)
        ]
        self.layers = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, idx):
        """
        x:   (B, L_all, d_model) tensor indexed along dim 1 by ``idx``
        idx: (M,) positions (into dim 1 of ``x``) to process
        Returns (B, M, d_model).
        """
        h = self.dropout(x[:, idx])
        for block in self.layers:
            h = block(h)
        return self.norm(h)

class GlobalOut(nn.Module):
    """Output head: striped Conv1 -> optional ULR Mamba fusion -> Conv2 -> softmax.

    Only positions selected by ``mask`` are computed. The ULR stage
    (``ulr_mamba`` + ``gate``) is off by default and toggled with
    :meth:`set_ulr_enabled`.
    """
    def __init__(self, d_model, n_alleles, total_sites, chunk_size,
                 kernel=5, pad=2, stripe=4096,
                 d_state=64, d_conv=4, expand=2, n_mamba_layers=2):
        super().__init__()
        self.k, self.p = kernel, pad
        self.stripe = stripe
        self.total_sites = total_sites
        self.n_alleles = n_alleles

        # -------------- 1) local conv weights --------------
        # Conv1: 2*d_model -> d_model//2
        self.w1 = nn.Parameter(torch.empty(d_model // 2, 2 * d_model, kernel))
        self.b1 = nn.Parameter(torch.zeros(d_model // 2))
        # Conv2: d_model//2 -> n_alleles-1
        self.w2 = nn.Parameter(torch.empty(n_alleles - 1, d_model // 2, kernel))
        self.b2 = nn.Parameter(torch.zeros(n_alleles - 1))
        nn.init.kaiming_normal_(self.w1)
        nn.init.kaiming_normal_(self.w2)

        # -------------- 2) ULR stage (Mamba2) --------------
        self.ulr_mamba = UltraLongRangeMamba(
            d_model=d_model//2,          # same width as the Conv1 output
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            n_layers=n_mamba_layers,
        )
        self.gate = nn.Linear(d_model, 2)   # [local; global] -> 2 weights
        self.norm = nn.LayerNorm(d_model // 2)

        # -------------- 3) switch --------------
        self.skip_ulr = True
        self.set_ulr_enabled(False)

    # ============ two-stage switch ============
    def set_ulr_enabled(self, enabled: bool):
        """Enable/disable the ULR stage and freeze/unfreeze its parameters."""
        self.skip_ulr = not enabled
        for p in self.ulr_mamba.parameters():
            p.requires_grad = enabled
        for p in self.gate.parameters():
            p.requires_grad = enabled

    def forward(self, x, mask):
        """
        x:    (B, L, 2*d_model)
        mask: (L,) 0/1 or bool; positions to predict
        return: (B, L, n_alleles-1) probabilities at masked positions; rows
            outside ``mask`` are all zeros (previously NaN from a softmax
            over -inf).
        """
        x = x.transpose(1, 2)  # (B, 2*d_model, L)
        batch, seq_len = x.shape[0], x.shape[2]
        n_out = self.w2.shape[0]
        idx = torch.where(mask)[0]                # M selected positions
        n = idx.shape[0]

        out = x.new_zeros((batch, seq_len, n_out))
        if n == 0:
            return out

        # ---- 1) Conv1 per stripe: 2*d_model -> d_model//2 ----
        # NOTE: each stripe is convolved independently with zero padding, and
        # gathered positions are treated as adjacent even across gaps in idx.
        h_local = []
        for i in range(0, n, self.stripe):
            idx_i = idx[i:i + self.stripe]
            x_i = x[..., idx_i].contiguous()      # (B, 2*d_model, <=stripe)
            y1 = checkpoint(self._band_conv1, x_i, self.w1, self.b1, use_reentrant=False)
            h_local.append(y1)
        h_local = torch.cat(h_local, dim=2).transpose(1, 2)  # (B, M, d_model//2)

        # ---- 2) optional ULR stage ----
        if self.skip_ulr:
            fused = h_local
        else:
            # h_local is already gathered to the M selected positions, so it is
            # indexed with 0..M-1, not with the global coordinates in idx.
            local_pos = torch.arange(n, device=x.device)
            h_global = self.ulr_mamba(h_local, local_pos)     # (B, M, d_model//2)
            gate_in = torch.cat([h_local, h_global], dim=-1)  # (B, M, d_model)
            w = torch.softmax(self.gate(gate_in), dim=-1)     # (B, M, 2)
            fused = w[..., 0:1] * h_local + w[..., 1:2] * h_global
            fused = self.norm(fused)                          # (B, M, d_model//2)

        # ---- 3) Conv2: d_model//2 -> n_alleles-1, softmax over alleles ----
        y_final = F.conv1d(fused.transpose(1, 2), self.w2, self.b2, padding=self.p)  # (B, C, M)
        out[:, idx] = F.softmax(y_final, dim=1).transpose(1, 2)
        return out

    # ---------- helpers ----------
    def _band_conv1(self, x, w, b):
        return F.gelu(F.conv1d(x, w, b, padding=self.p))

class EvoFill(nn.Module):
    """Chunked imputation model: per-chunk embedding + ChunkModule, shared GlobalOut head.

    The site axis is split into overlapping chunks (stride
    ``chunk_size - chunk_overlap``); each chunk has its own weights, and only
    the chunks passed to ``forward`` are run.
    """
    def __init__(
        self,
        d_model: int,
        n_alleles: int,
        total_sites: int,
        chunk_size: int = 8192,
        chunk_overlap: int = 64,
        dropout_rate: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_alleles = n_alleles
        self.total_sites = total_sites
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

        # 1. chunk boundaries
        stride = chunk_size - chunk_overlap
        starts = [i * stride for i in range((total_sites - 1) // stride + 1)]
        ends = [min(s + chunk_size, total_sites) for s in starts]
        self.register_buffer("starts", torch.tensor(starts, dtype=torch.long))
        self.register_buffer("ends", torch.tensor(ends, dtype=torch.long))
        self.n_chunks = len(starts)

        # 2. one embedding + processing module per chunk
        self.chunk_embeds = nn.ModuleList(
            GenoEmbedding(n_alleles, e - s, d_model) for s, e in zip(starts, ends)
        )
        self.chunk_modules = nn.ModuleList(
            ChunkModule(d_model, dropout_rate) for s, e in zip(starts, ends)
        )

        # 3. shared output head
        self.global_out = GlobalOut(d_model, n_alleles, total_sites, chunk_size)

        # 4. per-chunk site masks, (n_chunks, L).
        # NOTE(review): dense n_chunks x L float buffer; large for long inputs.
        masks = torch.stack(
            [torch.arange(total_sites).ge(s) & torch.arange(total_sites).lt(e)
             for s, e in zip(starts, ends)]
        ).float()
        self.register_buffer("chunk_masks", masks)

    def forward(self,
            x: torch.Tensor,                 # (B, L, n_alleles) one-hot
            chunk_id: Union[int, List[int]],
            x_extra: Optional[torch.Tensor] = None
            ):
        """Run the selected chunk(s) and the output head.

        Args:
            x: (B, L, n_alleles) one-hot input over all sites.
            chunk_id: a chunk index (int, numpy integer or 0-d tensor) or an
                iterable of chunk indices.
            x_extra: optional per-sample extra features; ignored when its batch
                size differs from ``x``.

        Returns:
            (out, idx): ``out`` is (B, L, n_alleles-1) from GlobalOut and
            ``idx`` the site indices covered by the selected chunks.

        Raises:
            ValueError: if ``chunk_id`` is an empty collection.
        """
        batch_size = x.shape[0]
        device = x.device
        if x_extra is not None and x_extra.shape[0] != batch_size:
            x_extra = None

        # Normalise to a list of Python ints; numpy / 0-d tensor scalars used to
        # fall into the "list" branch and break it.
        if isinstance(chunk_id, (int, np.integer)) or (
                torch.is_tensor(chunk_id) and chunk_id.dim() == 0):
            chunk_ids = [int(chunk_id)]
        else:
            chunk_ids = [int(c) for c in chunk_id]
        if not chunk_ids:
            raise ValueError("chunk_id must select at least one chunk")

        mask = self.chunk_masks[chunk_ids].sum(dim=0).bool()  # union of the chunks

        z_acc   = torch.zeros(batch_size, self.total_sites, 2 * self.d_model, device=device)
        cnt_acc = torch.zeros(self.total_sites, device=device)

        # 1. run each selected chunk
        for cid in chunk_ids:
            s, e = self.starts[cid].item(), self.ends[cid].item()
            z = self.chunk_embeds[cid](x[:, s:e])          # (B, len, d_model)
            z = self.chunk_modules[cid](z, x_extra)        # (B, len, 2*d_model)
            z_acc[:, s:e] += z
            cnt_acc[s:e]  += 1

        # 2. average over overlapping chunks
        cnt_acc = cnt_acc.clamp(min=1)
        z_full  = z_acc / cnt_acc.unsqueeze(0).unsqueeze(-1)     # (B, L, 2*d_model)

        # 3. output head
        out = self.global_out(z_full, mask)                      # (B, L, n_alleles-1)

        # 4. predictions plus the covered site indices
        return out, torch.where(mask)[0]



假数据测试

In [None]:
# Smoke test on random data: one-hot genotypes, random extra features.
B, L, A = 8, 12345, 3
d_model = 64
chunk_size, overlap = 4096, 64
device = 'cuda' if torch.cuda.is_available() else 'cpu'

x_train = torch.zeros(B, L, A, device=device)
allele = torch.randint(0, A, (B, L), device=device)
x_train.scatter_(2, allele.unsqueeze(-1), 1)
x_extra = torch.randn(B, 10, device=device)
y_train = torch.randn(B, L, A-1, device=device)

print(f"x_train: {x_train.shape}")
print(f"x_extra: {x_extra.shape}")
print(f"y_train: {y_train.shape}")
print("")
# ---------- model ----------
model = EvoFill(d_model, A, L, chunk_size, overlap).to(device)
print(f"model chunks: {model.n_chunks}")

print("单 chunk 测试")
cid = 0
model.global_out.set_ulr_enabled(False)
pred, mask_idx = model(x_train, cid, x_extra)
print(pred.shape)

print("多 chunk 测试")
cids = [0, 2]
model.global_out.set_ulr_enabled(True)
# Pass the list (previously `cid` was passed, so the multi-chunk path never ran).
pred, mask_idx = model(x_train, cids, x_extra)
print(pred.shape)


x_train: torch.Size([8, 12345, 3])
x_extra: torch.Size([8, 10])
y_train: torch.Size([8, 12345, 2])

model chunks: 4
torch.Size([8, 12345, 2])
torch.Size([8, 12345, 2])


## Loss

In [20]:
class ImputationLoss(nn.Module):
    """Custom loss function for genomic imputation.

    ``forward`` returns ``(total_loss, None)``. ``total_loss`` is made of
    three terms:

    * sum-reduced cross-entropy;
    * sum-reduced KL divergence;
    * when ``use_r2`` is True, a negative Minimac-style R² term. It is
      computed over consecutive groups of 4 samples, and a trailing
      partial group is ignored.
    """

    def __init__(self, use_r2=True):
        super().__init__()
        self.use_r2_loss = use_r2
        self.ce_loss = nn.CrossEntropyLoss(reduction='sum')
        self.kl_loss = nn.KLDivLoss(reduction='sum')

    def calculate_minimac_r2(self, pred_alt_allele_probs, gt_alt_af):
        """Calculate Minimac-style R² metric per site.

        Args:
            pred_alt_allele_probs: (G, L) predicted alt-allele probability for
                each of the G samples of a group.
            gt_alt_af: (L,) true alt-allele frequency within the group.

        Returns:
            (L,) tensor: mean over the group of the squared deviation of the
            predictions from the group AF, divided by AF*(1-AF). The
            denominator is floored at 0.01. Sites whose AF is exactly 0 or 1
            get 0.
        """
        mask = torch.logical_or(torch.eq(gt_alt_af, 0.0), torch.eq(gt_alt_af, 1.0))
        # Replace monomorphic AFs by 0.5 to keep the denominator well-defined;
        # those sites are zeroed out below anyway.
        gt_alt_af = torch.where(mask, 0.5, gt_alt_af)
        denom = gt_alt_af * (1.0 - gt_alt_af)
        denom = torch.where(denom < 0.01, 0.01, denom)
        r2 = torch.mean(torch.square(pred_alt_allele_probs - gt_alt_af), dim=0) / denom
        r2 = torch.where(mask, torch.zeros_like(r2), r2)
        return r2

    def forward(self, y_pred, y_true):
        """Return ``(total_loss, None)`` for (B, L, C) predictions and targets.

        ``y_true`` is cast to float. Its argmax over the last axis is used as
        the cross-entropy class target.
        """
        y_true = y_true.float()
        n_cls = y_pred.size(-1)

        # Class indices for CrossEntropy, log-probabilities for KL divergence.
        y_true_ce = torch.argmax(y_true, dim=-1)
        y_pred_log = torch.log(y_pred + 1e-8)

        # NOTE(review): nn.CrossEntropyLoss applies log_softmax to its input
        # (treats y_pred as logits), whereas the KL term treats y_pred as
        # probabilities. Confirm which one the model actually emits.
        # reshape (not view) so non-contiguous inputs are accepted too.
        ce_loss = self.ce_loss(y_pred.reshape(-1, n_cls), y_true_ce.reshape(-1))
        kl_loss = self.kl_loss(y_pred_log.reshape(-1, n_cls),
                               y_true.reshape(-1, y_true.size(-1)))

        total_loss = ce_loss + kl_loss

        if self.use_r2_loss:
            batch_size = y_true.size(0)
            group_size = 4
            num_full_groups = batch_size // group_size

            if num_full_groups > 0:
                n_used = num_full_groups * group_size
                y_true_grouped = y_true[:n_used].reshape(
                    num_full_groups, group_size, *y_true.shape[1:])
                y_pred_grouped = y_pred[:n_used].reshape(
                    num_full_groups, group_size, *y_pred.shape[1:])

                r2_loss = 0.0
                for i in range(num_full_groups):
                    # Fraction of the group's samples whose true class index is non-zero.
                    gt_alt_af = torch.count_nonzero(
                        torch.argmax(y_true_grouped[i], dim=-1), dim=0
                    ).float() / group_size

                    # Probability mass on every class except class 0.
                    pred_alt_allele_probs = torch.sum(y_pred_grouped[i][:, :, 1:], dim=-1)
                    r2_loss += -torch.sum(self.calculate_minimac_r2(
                        pred_alt_allele_probs, gt_alt_af)) * group_size

                total_loss += r2_loss

        return total_loss, None

## Train

工具函数

In [None]:
def set_seed(seed=42):
    """Seed every RNG used for training and pin cuDNN to deterministic kernels.

    Args:
        seed: value handed to Python's ``random``, NumPy, the PyTorch CPU
            generator, the current CUDA device and all CUDA devices.
    """
    seeders = (
        random.seed,                  # Python built-in RNG
        np.random.seed,               # NumPy global RNG
        torch.manual_seed,            # PyTorch (CPU + CUDA default)
        torch.cuda.manual_seed,       # current GPU
        torch.cuda.manual_seed_all,   # every GPU (multi-GPU training)
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def create_directories(save_dir, models_dir="models", outputs="out") -> None:
    """Create ``save_dir`` plus its ``models_dir`` and ``outputs`` subdirectories.

    Missing parent directories are created as well. Directories that already
    exist are left untouched.

    Args:
        save_dir: root working directory.
        models_dir: name of the checkpoint subdirectory.
        outputs: name of the output subdirectory.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    for dd in (save_dir, f"{save_dir}/{models_dir}", f"{save_dir}/{outputs}"):
        os.makedirs(dd, exist_ok=True)

def clear_dir(path) -> None:
    """Recursively delete ``path``; silently does nothing when it is absent."""
    if not os.path.exists(path):
        return
    shutil.rmtree(path)

def precompute_maf(gts_np, mask_int=-1):
    """Per-site minor-allele frequency and a histogram over ``MAF_BINS``.

    Args:
        gts_np: (N, L) integer allele matrix (rows are samples, columns sites).
        mask_int: code marking a missing call; such entries are ignored.

    Returns:
        maf: (L,) float32. The frequency of the second most common allele
            among the non-missing calls of each site, or 0 when the site is
            monomorphic or entirely missing.
        bin_cnt: list of 6 ints. Number of sites whose MAF falls in each
            ``[lo, hi)`` interval of ``MAF_BINS``. Entirely-missing sites are
            not counted.
    """
    n_sites = gts_np.shape[1]
    maf = np.zeros(n_sites, dtype=np.float32)
    bin_cnt = [0] * 6

    for site in range(n_sites):
        column = gts_np[:, site]
        observed = column[column != mask_int]   # drop missing calls
        if observed.size == 0:
            continue                            # maf[site] stays 0.0, not binned

        _, counts = np.unique(observed, return_counts=True)
        freq_desc = np.sort(counts / counts.sum())[::-1]
        site_maf = freq_desc[1] if freq_desc.shape[0] > 1 else 0.0
        maf[site] = site_maf

        hit = next((i for i, (lo, hi) in enumerate(MAF_BINS) if lo <= site_maf < hi), None)
        if hit is not None:
            bin_cnt[hit] += 1

    return maf, bin_cnt

def build_geno3_map_from_hapmap(hap_map: dict) -> np.ndarray:
    """Map every genotype class to one of 3 dosage classes.

    Genotype strings are visited in ascending order of their ``hap_map``
    index. The fully-missing genotypes ``'.|.'`` and ``'./.'`` are skipped.

    Output classes:

    * 0 when both alleles are 0.
    * 1 when the alleles differ, or when either allele is not an integer.
    * 2 when both alleles are the same non-zero integer.

    A genotype without a ``|`` or ``/`` separator counts as two copies of
    itself.

    Returns:
        int64 array with one entry per non-missing genotype.
    """
    def _classify(gt) -> int:
        if '|' in gt:
            sep = '|'
        elif '/' in gt:
            sep = '/'
        else:
            sep = None
        first, second = gt.split(sep) if sep else (gt, gt)
        try:
            a_idx, b_idx = int(first), int(second)
        except Exception:
            return 1
        if a_idx == b_idx == 0:
            return 0
        return 1 if a_idx != b_idx else 2

    ordered = (gt for gt, _ in sorted(hap_map.items(), key=lambda kv: kv[1]))
    classes = [_classify(gt) for gt in ordered if gt not in ('.|.', './.')]
    return np.array(classes, dtype=np.int64)

# ---------- 2. Thread-safe cache ----------
# MAF bin edges; this file tests them as half-open intervals lo <= maf < hi.
# NOTE(review): a site with MAF exactly 0.50 therefore falls in no bin.
MAF_BINS = [(0.00, 0.05), (0.05, 0.10), (0.10, 0.20),
            (0.20, 0.30), (0.30, 0.40), (0.40, 0.50)]
# CPU tensors built from hap_map, filled lazily by get_geno3_map_tensor.
_GENO3_CACHE: Dict[int, torch.Tensor] = {}
# Guards _GENO3_CACHE. NOTE(review): a multiprocessing lock is used although
# the cache dict itself is per-process; threading.Lock may suffice — confirm.
_GENO3_LOCK = torch.multiprocessing.Lock()

def get_geno3_map_tensor(C_orig: int, hap_map, device: torch.device) -> torch.Tensor:
    """Return the (C_orig,) class-index -> 3-class genotype map on ``device``.

    The map is built once per (class count, hap_map contents) from
    ``build_geno3_map_from_hapmap``. It is cached as a CPU tensor in
    ``_GENO3_CACHE`` under ``_GENO3_LOCK``.

    Raises:
        RuntimeError: if the map built from ``hap_map`` does not have exactly
            ``C_orig`` entries.
    """
    # Key on the hap_map contents as well as the class count: keying on
    # C_orig alone returned a stale map for a different hap_map of equal size.
    key = (int(C_orig), frozenset(hap_map.items()))
    with _GENO3_LOCK:
        cached = _GENO3_CACHE.get(key)
        if cached is None:
            arr = build_geno3_map_from_hapmap(hap_map)
            if arr.shape[0] != C_orig:
                raise RuntimeError(f"三分类映射长度{arr.shape[0]}与类别数{C_orig}不符")
            cached = torch.from_numpy(arr)
            _GENO3_CACHE[key] = cached
    return cached.to(device)

# ---------- 3. 三分类聚合 ----------
def aggregate_three_classes(prob: torch.Tensor, y_true: torch.Tensor, hap_map) -> Tuple[torch.Tensor, torch.Tensor]:
    """Collapse C-class probabilities and targets into 3 genotype classes.

    Args:
        prob: (N, L, C) class probabilities.
        y_true: (N, L, C) targets in the same class layout (presumably
            one-hot — TODO confirm).
        hap_map: genotype-string -> class-index dict used to build the
            C -> 3 class map.

    Returns:
        ``(prob3, y3)``, both (N, L, 3), where each 3-class value is the sum
        of its member classes. ``prob3`` is renormalised over the last axis
        (denominator clamped at 1e-8). ``y3`` is not renormalised.
    """
    _, _, n_cls = prob.shape
    dev = prob.device
    geno3 = get_geno3_map_tensor(n_cls, hap_map, dev).long()
    # (C, 3) membership matrix: row c has a single 1 in column geno3[c].
    membership = F.one_hot(geno3, num_classes=3).to(device=dev, dtype=torch.get_default_dtype())

    prob3 = torch.einsum('nlc,ck->nlk', prob, membership)
    y3 = torch.einsum('nlc,ck->nlk', y_true, membership)
    row_total = prob3.sum(-1, keepdim=True).clamp(min=1e-8)
    return prob3 / row_total, y3

# ---------- 4. 向量化计算 3 个指标 ----------
def _compute_site_metrics(prob3: torch.Tensor,
                          y3: torch.Tensor,
                          mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    一次性返回 (INFO, MaCH-Rsq, IQS) 三个 (L,) 向量
    prob3/y3: (N,L,3)  mask: (N,L)
    """
    # dosage / W / p_alt
    p_ref, p_het, p_hom = prob3.unbind(-1)
    dosage = p_het + 2*p_hom
    W_score = p_het + 4*p_hom

    # 按位点求平均
    n_valid = mask.sum(0)                        # (L,)
    AF = 0.5 * (dosage * mask).sum(0) / n_valid.clamp(min=1)
    denom_info = AF * (1 - AF)

    # INFO
    var_want = ((W_score - dosage.square()) * mask).sum(0) / n_valid.clamp(min=1)
    info = 1 - 0.5 * var_want / denom_info.clamp(min=1e-8)
    info = info.clamp(0, 1)

    # MaCH-Rsq
    # 真实剂量
    true_dosage = (y3[..., 1] + 2*y3[..., 2]).float()        # (N,L)
    # 预测剂量
    pred_dosage = dosage                                       # (N,L) 前面已算好
    # 有效样本均值
    mean_true = (true_dosage * mask).sum(0) / n_valid.clamp(min=1)
    mean_pred = (pred_dosage * mask).sum(0) / n_valid.clamp(min=1)

    # 分子：协方差（= 预测对真实解释的方差）
    num = ((pred_dosage - mean_pred.unsqueeze(0)) *
           (true_dosage - mean_true.unsqueeze(0)) * mask).sum(0) / n_valid.clamp(min=1)

    # 分母：由真实剂量得到的 AF*(1-AF)
    AF = mean_true / 2.0
    denom = AF * (1 - AF)
    mach = num / denom.clamp(min=1e-8)
    mach = mach.clamp(0, 1)

    # IQS (Cohen's kappa)
    pred_cls = prob3.argmax(-1)                  # (N,L)
    true_cls = y3.argmax(-1)
    agree = (pred_cls == true_cls) & mask        # (N,L)
    Po = (agree.sum(0)).float() / n_valid.clamp(min=1)
    Pe = torch.zeros_like(Po)
    for c in range(3):
        p_c = ((pred_cls == c) & mask).sum(0).float() / n_valid.clamp(min=1)
        t_c = ((true_cls == c) & mask).sum(0).float() / n_valid.clamp(min=1)
        Pe += p_c * t_c
    iqs = (Po - Pe) / (1 - Pe).clamp(min=1e-8)
    iqs = iqs.clamp(-1, 1)

    # 无效位点填 0
    invalid = n_valid == 0
    info[invalid] = 0
    mach[invalid] = 0
    iqs[invalid]  = 0
    return info, mach, iqs

# ---------- 5. 唯一对外接口 ----------
def metrics_by_maf(prob: torch.Tensor,
                   y_true: torch.Tensor,
                   hap_map,
                   maf_vec: torch.Tensor,
                   bins: List[Tuple[float, float]] = MAF_BINS,
                   mask: Optional[torch.Tensor] = None
                   ) -> Dict[str, List[float]]:
    """Accuracy, INFO, MaCH-Rsq and IQS aggregated per MAF bin.

    Args:
        prob: (N, L, C) predicted class probabilities.
        y_true: (N, L, C) targets in the same class layout.
        hap_map: genotype-string -> class-index dict used to collapse the C
            classes into 3 genotype classes.
        maf_vec: (L,) per-site MAF used for binning.
        bins: MAF intervals, tested as ``lo <= maf < hi``.
        mask: (N, L) bool selecting the entries to score; every entry when None.

    Returns:
        ``{'Acc': [...], 'INFO': [...], 'MaCH': [...], 'IQS': [...]}``. Each
        list follows the order of ``bins``, and an empty bin reports 0.

        * Acc is pooled over all valid entries in the bin.
        * INFO, MaCH and IQS are the mean of per-site values over sites that
          have at least one valid entry.
    """
    N, L, _ = prob.shape
    device = prob.device
    if mask is None:
        mask = torch.ones((N, L), dtype=torch.bool, device=device)

    # Collapse to 3 genotype classes.
    prob3, y3 = aggregate_three_classes(prob, y_true, hap_map)

    # --- accuracy, vectorised per bin ---
    preds = prob3.argmax(-1)
    gts   = y3.argmax(-1)
    correct = (preds == gts) & mask                      # (N,L)
    maf_b = maf_vec.unsqueeze(0)                         # (1,L)
    acc_bins = []
    for lo, hi in bins:
        mbin = mask & (maf_b >= lo) & (maf_b < hi)
        n_cor = (correct & mbin).sum()
        n_tot = mbin.sum()
        acc_bins.append((n_cor / n_tot).item() if n_tot > 0 else 0.)

    # --- per-site INFO / MaCH / IQS, averaged per bin ---
    info_all, mach_all, iqs_all = _compute_site_metrics(prob3, y3, mask)
    # _compute_site_metrics writes 0 for sites without any valid entry. That 0
    # is a placeholder, not a score, so keep such sites out of the bin means.
    site_has_data = mask.sum(0) > 0                      # (L,)
    info_bins, mach_bins, iqs_bins = [], [], []
    for lo, hi in bins:
        idx = (maf_vec >= lo) & (maf_vec < hi) & site_has_data
        if idx.sum() == 0:
            info_bins.append(0.); mach_bins.append(0.); iqs_bins.append(0.)
        else:
            info_bins.append(info_all[idx].mean().item())
            mach_bins.append(mach_all[idx].mean().item())
            iqs_bins.append(iqs_all[idx].mean().item())

    return {'Acc': acc_bins, 'INFO': info_bins,
            'MaCH': mach_bins, 'IQS': iqs_bins}

# ---------- 6. 打印 ----------
def print_maf_stat_df(chunk_bin_cnt: List[int],
                      train_bins_metrics: Dict[str, List[float]],
                      val_bins_metrics: Dict[str, List[float]],
                      bins: Optional[List[Tuple[float, float]]] = None):
    """Print a per-MAF-bin table of site counts and train/val metrics.

    Args:
        chunk_bin_cnt: number of sites per bin.
        train_bins_metrics: dict with 'Acc', 'INFO', 'MaCH', 'IQS' lists
            (one value per bin), e.g. as returned by ``metrics_by_maf``.
        val_bins_metrics: same layout for the validation set.
        bins: MAF intervals used for the row labels; defaults to ``MAF_BINS``
            so labels cannot drift from the bins actually used.
    """
    if bins is None:
        bins = MAF_BINS
    columns = {
        'MAF_bin': [f"({lo:.2f}, {hi:.2f})" for lo, hi in bins],
        'Counts':  [f"{c}" for c in chunk_bin_cnt],
    }
    # Column order: Train_X, Val_X for X in Acc, INFO, MaCH, IQS.
    for metric in ('Acc', 'INFO', 'MaCH', 'IQS'):
        columns[f'Train_{metric}'] = [f"{v:.3f}" for v in train_bins_metrics[metric]]
        columns[f'Val_{metric}'] = [f"{v:.3f}" for v in val_bins_metrics[metric]]
    maf_df = pd.DataFrame(columns)
    print(maf_df.to_string(index=False))

Load data

In [28]:
# Alternative (untrimmed) dataset:
# work_dir = "/mnt/qmtang/EvoFill/data/251027_ver3_chr22/"
work_dir = '/mnt/qmtang/EvoFill/data/251027_ver3_chr22_trim'
print(f"Work Dir: {work_dir}")
# Ensure work_dir plus its models/ and out/ subdirectories exist.
create_directories(work_dir)

# Number of samples held out for validation.
val_n_samples = 128

# Setup device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.cuda.empty_cache()
print(f"Using device: {device}")

# NOTE(review): presumably restores a GenotypeEncoder previously saved to
# work_dir (X_gt, X_extra, hap_map, seq_depth, ...) — confirm.
gt_enc = GenotypeEncoder.loadfromdisk(work_dir)
print(f'{gt_enc.n_samples:,} samples, {gt_enc.n_variants:,} variants, {gt_enc.seq_depth} seq-depth.')

# Split sample indices (not data) into train/val with a fixed seed.
x_train_indices, x_valid_indices = train_test_split(
    range(gt_enc.n_samples),
    test_size=val_n_samples,
    random_state=3047,
    shuffle=True
)
print(f"{len(x_train_indices):,} samples in train")
print(f"{len(x_valid_indices):,} samples in val")

Work Dir: /mnt/qmtang/EvoFill/data/251027_ver3_chr22_trim
Using device: cuda
2,404 samples, 99,314 variants, 4 seq-depth.
2,276 samples in train
128 samples in val


Initialize model

In [29]:
# Prefix of every checkpoint file written under {work_dir}/models/.
model_name  = 'hg19_chr22_trim'
total_sites = gt_enc.n_variants
alleles     = gt_enc.seq_depth
chunk_size  = 32768
overlap     = 1024
d_model     = 64

# Seed before constructing the model so parameter initialisation is
# reproducible (assuming EvoFill draws its init from torch's RNG).
set_seed(42)

model = EvoFill(d_model, alleles, total_sites, chunk_size, overlap).to(device)
print(f"model[{model_name}] would have {model.n_chunks} chunks.")

# Shared by both training stages below.
criterion = ImputationLoss(use_r2=True)

model[hg19_chr22_trim] would have 4 chunks.


### STAGE 1: Chunk Module Training

In [None]:
# STAGE 1: train each chunk expert (embedding + chunk module) together with
# GlobalOut's local weights, one chunk at a time, with the ULR branch off.
model.global_out.set_ulr_enabled(False)

batch_size         = 8
max_epochs         = 100
lr                 = 0.001
weight_decay       = 1e-5
earlystop_patience = 13
max_mr             = 0.7
min_mr             = 0.3
verbose            = False

train_dataset = GenomicDataset(
    gt_enc.X_gt,
    x_extra=gt_enc.X_extra,
    seq_depth=gt_enc.seq_depth,
    mask=True,
    masking_rates=(min_mr, max_mr),
    indices=x_train_indices
)

val_dataset = GenomicDataset(
    gt_enc.X_gt,
    x_extra=gt_enc.X_extra,
    seq_depth=gt_enc.seq_depth,
    mask=True,
    masking_rates=(min_mr, max_mr),
    indices=x_valid_indices
)

train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size,
                        shuffle=False, num_workers=4, pin_memory=True)

for cid in range(model.n_chunks):
    # Per-site MAF of the sites covered by this chunk. Missing calls are coded
    # as gt_enc.seq_depth and are excluded.
    chunk_mask = model.chunk_masks[cid].cpu()
    chunk_maf, chunk_bin_cnt = precompute_maf(
        gt_enc.X_gt[:, chunk_mask.bool().numpy()].toarray(),
        mask_int=gt_enc.seq_depth)
    chunk_maf = torch.from_numpy(chunk_maf).to(device)
    if verbose:
        print(f"=== Chunk {cid + 1} STAT ===")
        maf_df = pd.DataFrame({
            'MAF_bin': [f"({lo:.2f}, {hi:.2f})" for lo, hi in MAF_BINS],
            'Counts':  [f"{c}" for c in chunk_bin_cnt],
        })
        print(maf_df.to_string(index=False))

    # Optimise only this chunk's expert plus GlobalOut's local weights.
    trainable = (list(model.chunk_embeds[cid].parameters()) +
                 list(model.chunk_modules[cid].parameters()) +
                 [model.global_out.w1, model.global_out.b1,
                  model.global_out.w2, model.global_out.b2])
    optimizer = AdamW(trainable, lr=lr, weight_decay=weight_decay)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, min_lr=1e-8)
    ckpt_path = f'{work_dir}/models/{model_name}_chunk{cid}.pth'
    best_loss = float('inf')
    patience_counter = 0
    # (train_logits, train_gts, train_mask, val_logits, val_gts) of the best
    # epoch. The train mask is stored with its logits because the random
    # masking differs from epoch to epoch.
    predres_with_bestloss = None

    for epoch in range(max_epochs):
        model.train()
        train_loss = 0.0
        train_logits, train_gts, train_mask = [], [], []

        train_pbar = tqdm(train_loader, desc=f'Chunk {cid + 1}/{model.n_chunks}, Epoch {epoch + 1}/{max_epochs}', leave=False)
        for x, x_extra, target in train_pbar:
            x, target = x.to(device), target.to(device)
            x_extra = x_extra.to(device) if x_extra.numel() > 0 else None

            optimizer.zero_grad()
            pred, mask_idx = model(x, cid, x_extra)
            pred_sel, target_sel = pred[:, mask_idx], target[:, mask_idx]
            loss, _ = criterion(pred_sel, target_sel)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_pbar.set_postfix({'loss': loss.item()})

            # Collect results; presumably the last input channel flags the
            # masked (to-be-imputed) entries — confirm against GenomicDataset.
            train_logits.append(pred_sel.detach())
            train_gts.append(target_sel.detach())
            train_mask.append(x[:, mask_idx][..., -1].bool())

        train_logits = torch.cat(train_logits, dim=0)
        train_gts    = torch.cat(train_gts,    dim=0)
        train_mask   = torch.cat(train_mask,   dim=0)

        # ----------- validation -----------
        model.eval()
        val_loss = 0.0
        val_logits, val_gts = [], []
        with torch.no_grad():
            for x, x_extra, target in val_loader:
                x, target = x.to(device), target.to(device)
                x_extra = x_extra.to(device) if x_extra.numel() > 0 else None
                pred, mask_idx = model(x, cid, x_extra)
                pred_sel, target_sel = pred[:, mask_idx], target[:, mask_idx]
                loss, _ = criterion(pred_sel, target_sel)
                val_loss += loss.item()
                val_logits.append(pred_sel.detach())
                val_gts.append(target_sel.detach())

        val_logits = torch.cat(val_logits, dim=0)
        val_gts    = torch.cat(val_gts,    dim=0)
        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss   = val_loss   / len(val_loader)

        scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]['lr']

        print(f'Chunk {cid + 1}/{model.n_chunks}, '
              f'Epoch {epoch + 1}/{max_epochs}, '
              f'Train Loss: {avg_train_loss:.1f}, '
              f'Val Loss: {avg_val_loss:.1f}, '
              f'LR: {current_lr:.2e}')

        # Early stopping
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            patience_counter = 0
            # Save only this chunk's expert plus the global output layer.
            ckpt = {
                'chunk_id': cid,
                'chunk_embed_state': model.chunk_embeds[cid].state_dict(),
                'chunk_module_state': model.chunk_modules[cid].state_dict(),
                'global_out_state': model.global_out.state_dict(),
                'best_val_loss': best_loss,
            }
            torch.save(ckpt, ckpt_path)
            predres_with_bestloss = (train_logits, train_gts, train_mask, val_logits, val_gts)
            if verbose:
                train_bins_metrics = metrics_by_maf(train_logits, train_gts, gt_enc.hap_map, chunk_maf, mask=train_mask)
                val_bins_metrics   = metrics_by_maf(val_logits,   val_gts, gt_enc.hap_map, chunk_maf, mask=None)
                print_maf_stat_df(chunk_bin_cnt, train_bins_metrics, val_bins_metrics)
                print(f'  --> updated {model_name}_chunk{cid}.pth')
        else:
            patience_counter += 1
            if patience_counter >= earlystop_patience:
                print(f'Chunk {cid + 1}/{model.n_chunks}, Early stopping triggered')
                break

    if predres_with_bestloss is not None:
        # Always report the best epoch, whether or not early stopping fired.
        (best_train_logits, best_train_gts, best_train_mask,
         best_val_logits, best_val_gts) = predres_with_bestloss
        train_bins_metrics = metrics_by_maf(best_train_logits, best_train_gts, gt_enc.hap_map, chunk_maf, mask=best_train_mask)
        val_bins_metrics   = metrics_by_maf(best_val_logits,   best_val_gts, gt_enc.hap_map, chunk_maf, mask=None)
        print_maf_stat_df(chunk_bin_cnt, train_bins_metrics, val_bins_metrics)

        # Roll the in-memory weights back to the best epoch, so the next chunk
        # starts from, and the stage-1 checkpoint below contains, the best
        # weights rather than the last epoch's.
        best_ckpt = torch.load(ckpt_path, map_location=device)
        model.chunk_embeds[cid].load_state_dict(best_ckpt['chunk_embed_state'])
        model.chunk_modules[cid].load_state_dict(best_ckpt['chunk_module_state'])
        model.global_out.load_state_dict(best_ckpt['global_out_state'])

    # Release this chunk's optimiser state and collected GPU tensors.
    del optimizer, scheduler
    predres_with_bestloss = best_ckpt = None
    train_logits = train_gts = train_mask = val_logits = val_gts = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# ---------------- All chunks trained -> save the full model ----------------
final_ckpt = {
    'model_state': model.state_dict(),
    'n_chunks': model.n_chunks,
    'chunk_size': model.chunk_size,
    'chunk_overlap': model.chunk_overlap,
}
torch.save(final_ckpt, f'{work_dir}/models/{model_name}_stage1.pth')
print(f'==> STAGE1 (Chunk Module) training finished: {work_dir}/models/{model_name}_stage1.pth')

                                                                                       

Chunk 1/4, Epoch 1/100, Train Loss: 72410.7, Val Loss: 71993.1, LR: 1.00e-03


                                                                                       

Chunk 1/4, Epoch 2/100, Train Loss: 69847.0, Val Loss: 72826.6, LR: 1.00e-03


                                                                                       

Chunk 1/4, Epoch 3/100, Train Loss: 71025.5, Val Loss: 69634.1, LR: 1.00e-03


                                                                                       

Chunk 1/4, Epoch 4/100, Train Loss: 70020.6, Val Loss: 74399.9, LR: 1.00e-03


                                                                                       

Chunk 1/4, Epoch 5/100, Train Loss: 68704.4, Val Loss: 69484.2, LR: 1.00e-03


                                                                                       

Chunk 1/4, Epoch 6/100, Train Loss: 68544.6, Val Loss: 67193.0, LR: 1.00e-03


                                                                                       

Chunk 1/4, Epoch 7/100, Train Loss: 68802.1, Val Loss: 68758.5, LR: 1.00e-03


                                                                                       

Chunk 1/4, Epoch 8/100, Train Loss: 68900.9, Val Loss: 69451.4, LR: 1.00e-03


                                                                                       

Chunk 1/4, Epoch 9/100, Train Loss: 67918.7, Val Loss: 67256.9, LR: 1.00e-03


                                                                                        

Chunk 1/4, Epoch 10/100, Train Loss: 67859.7, Val Loss: 68598.2, LR: 5.00e-04


                                                                                        

Chunk 1/4, Epoch 11/100, Train Loss: 65436.2, Val Loss: 64980.8, LR: 5.00e-04


                                                                                        

Chunk 1/4, Epoch 12/100, Train Loss: 64971.0, Val Loss: 64206.6, LR: 5.00e-04


                                                                                        

Chunk 1/4, Epoch 13/100, Train Loss: 64833.6, Val Loss: 64920.5, LR: 5.00e-04


                                                                                        

Chunk 1/4, Epoch 14/100, Train Loss: 64495.2, Val Loss: 66000.7, LR: 5.00e-04


                                                                                        

Chunk 1/4, Epoch 15/100, Train Loss: 64730.5, Val Loss: 64204.8, LR: 5.00e-04


                                                                                        

Chunk 1/4, Epoch 16/100, Train Loss: 64565.5, Val Loss: 62841.1, LR: 5.00e-04


                                                                                        

Chunk 1/4, Epoch 17/100, Train Loss: 64014.0, Val Loss: 64050.9, LR: 5.00e-04


                                                                                        

Chunk 1/4, Epoch 18/100, Train Loss: 64851.4, Val Loss: 63889.1, LR: 5.00e-04


                                                                                        

Chunk 1/4, Epoch 19/100, Train Loss: 64235.0, Val Loss: 65281.4, LR: 5.00e-04


Chunk 1/4, Epoch 20/100:  39%|███▊      | 110/285 [00:16<00:25,  6.93it/s, loss=6.17e+4]

### STAGE 2: Ultra-Long-Range LD Module Training

In [None]:
# ============ Hyper-parameters ============
max_epochs_per_pair = 10
lr                 = 5e-4
weight_decay       = 1e-5
# NOTE(review): patience (15) > max_epochs_per_pair (10), so early stopping
# can never fire; lower one of them if early stopping is intended.
earlystop_patience = 15
batch_size         = 4
min_mr, max_mr     = 0.4, 0.6
verbose            = True
# ==============================

train_dataset = GenomicDataset(
    gt_enc.X_gt,
    x_extra=gt_enc.X_extra,
    seq_depth=gt_enc.seq_depth,
    mask=True,
    masking_rates=(min_mr, max_mr),
    indices=x_train_indices
)

val_dataset = GenomicDataset(
    gt_enc.X_gt,
    x_extra=gt_enc.X_extra,
    seq_depth=gt_enc.seq_depth,
    mask=True,
    masking_rates=(min_mr, max_mr),
    indices=x_valid_indices
)

train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size,
                        shuffle=False, num_workers=4, pin_memory=True)


# ----------- Load the finished stage-1 model -----------
ckpt = torch.load(f'{work_dir}/models/{model_name}_stage1.pth', map_location='cpu')
model.load_state_dict(ckpt['model_state'])
model.global_out.set_ulr_enabled(True)          # enable the ULR branch
# NOTE(review): this eval() is undone by model.train() at the start of every
# epoch, and no parameter is frozen via requires_grad. The chunk experts only
# stay fixed because they are not handed to the optimiser below.
model.eval()

# Optimiser over the ULR branch only.
# Assumes GlobalOut has a `norm` submodule (original comment: "if present") — TODO confirm.
ulr_params = (list(model.global_out.ulr_mamba.parameters()) +
              list(model.global_out.gate.parameters()) +
              list(model.global_out.norm.parameters()))
optimizer = AdamW(ulr_params, lr=lr, weight_decay=weight_decay)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
                              patience=5, min_lr=1e-9)


pair_list = list(combinations(range(model.n_chunks), 2))
np.random.shuffle(pair_list)          # random pair order
total_pairs = len(pair_list)

for pair_idx, (cid1, cid2) in enumerate(pair_list, 1):
    # ====== Union of the two chunks' sites ======
    union_mask = (model.chunk_masks[cid1] + model.chunk_masks[cid2]).clamp(max=1).bool()

    # MAF over the union (the metrics below must use these, not stage-1's
    # leftover chunk_maf / chunk_bin_cnt, whose length differs).
    union_maf, union_bin_cnt = precompute_maf(
        gt_enc.X_gt[:, union_mask.cpu().numpy()].toarray(),
        mask_int=gt_enc.seq_depth
    )
    union_maf = torch.from_numpy(union_maf).to(device)
    ckpt_path = f'{work_dir}/models/{model_name}_ulr_ld_{cid1}_{cid2}.pth'

    # ====== Early-stopping state ======
    best_loss = float('inf')
    patience_counter = 0
    # (train_logits, train_gts, train_mask, val_logits, val_gts) of the best epoch.
    predres_with_bestloss = None

    # ====== Training loop ======
    for epoch in range(max_epochs_per_pair):
        model.train()
        train_loss = 0.0
        train_logits, train_gts, train_mask = [], [], []

        pbar = tqdm(train_loader,
                    desc=f'Pair {pair_idx}/{total_pairs}  '
                         f'{cid1}-{cid2}  Epoch {epoch+1}/{max_epochs_per_pair}',
                    leave=False)
        for x, x_extra, target in pbar:
            x = x.to(device)
            target = target.to(device)
            x_extra = x_extra.to(device) if x_extra.numel() > 0 else None

            optimizer.zero_grad()

            pred, mask_idx = model(x, [cid1, cid2], x_extra)
            pred_sel, target_sel = pred[:, mask_idx], target[:, mask_idx]
            loss, _ = criterion(pred_sel, target_sel)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            pbar.set_postfix({'loss': loss.item()})

            # Collect results; select the same sites (mask_idx) as the logits.
            train_logits.append(pred_sel.detach())
            train_gts.append(target_sel.detach())
            train_mask.append(x[:, mask_idx][..., -1].bool())

        train_logits = torch.cat(train_logits, dim=0)
        train_gts    = torch.cat(train_gts,    dim=0)
        train_mask   = torch.cat(train_mask,   dim=0)

        # ----------- validation -----------
        model.eval()
        val_loss = 0.0
        val_logits, val_gts = [], []
        with torch.no_grad():
            for x, x_extra, target in val_loader:
                x = x.to(device)
                target = target.to(device)
                x_extra = x_extra.to(device) if x_extra.numel() else None
                pred, mask_idx = model(x, [cid1, cid2], x_extra)
                pred_sel, target_sel = pred[:, mask_idx], target[:, mask_idx]
                loss, _ = criterion(pred_sel, target_sel)
                val_loss += loss.item()
                val_logits.append(pred_sel)
                val_gts.append(target_sel)

        val_logits = torch.cat(val_logits, dim=0)
        val_gts    = torch.cat(val_gts,    dim=0)

        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss   = val_loss   / len(val_loader)
        scheduler.step(avg_val_loss)

        # Early stopping
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            patience_counter = 0
            torch.save({
                'pair': (cid1, cid2),
                'ulr_state': {
                    'ulr_mamba': model.global_out.ulr_mamba.state_dict(),
                    'gate':      model.global_out.gate.state_dict(),
                    # norm is optimised too, so it must be checkpointed as well.
                    'norm':      model.global_out.norm.state_dict(),
                },
                'best_val_loss': best_loss,
                'epoch': epoch,
            }, ckpt_path)
            predres_with_bestloss = (train_logits, train_gts, train_mask, val_logits, val_gts)
            if verbose:
                train_bins_metrics = metrics_by_maf(train_logits, train_gts, gt_enc.hap_map, union_maf, mask=train_mask)
                val_bins_metrics   = metrics_by_maf(val_logits,   val_gts, gt_enc.hap_map, union_maf, mask=None)
                print_maf_stat_df(union_bin_cnt, train_bins_metrics, val_bins_metrics)
                print(f'  --> updated ulr_ld_{cid1}_{cid2}.pth')
        else:
            patience_counter += 1
            if patience_counter >= earlystop_patience:
                print(f'Pair {cid1}-{cid2} early stopping')
                break

    if predres_with_bestloss is not None:
        # Report the best epoch (train mask kept alongside its own logits).
        (best_train_logits, best_train_gts, best_train_mask,
         best_val_logits, best_val_gts) = predres_with_bestloss
        train_bins_metrics = metrics_by_maf(best_train_logits, best_train_gts, gt_enc.hap_map, union_maf, mask=best_train_mask)
        val_bins_metrics   = metrics_by_maf(best_val_logits,   best_val_gts, gt_enc.hap_map, union_maf, mask=None)
        print_maf_stat_df(union_bin_cnt, train_bins_metrics, val_bins_metrics)

        # Continue from, and finally save, the best ULR weights of this pair.
        best_state = torch.load(ckpt_path, map_location=device)['ulr_state']
        model.global_out.ulr_mamba.load_state_dict(best_state['ulr_mamba'])
        model.global_out.gate.load_state_dict(best_state['gate'])
        model.global_out.norm.load_state_dict(best_state['norm'])

    # Release this pair's collected GPU tensors.
    predres_with_bestloss = None
    train_logits = train_gts = train_mask = val_logits = val_gts = None
    torch.cuda.empty_cache()

# ----------- All pairs done -> save the final model -----------
torch.save({
    'model_state': model.state_dict(),
    'ulr_enabled': True,
}, f'{work_dir}/models/{model_name}_stage2_final.pth')
print(f'==> STAGE2 (Ultra_LR-LD) training finished: {work_dir}/models/{model_name}_stage2_final.pth')

                                                                                           

NameError: name 'imputation_maf_accuracy_epoch' is not defined

## Inference

In [None]:
# ---------- 1. 必须与训练时完全一致 ----------
d_model       = 64
n_alleles     = 3
total_sites   = 12345
chunk_size    = 4096
chunk_overlap = 64
device        = 'cuda' if torch.cuda.is_available() else 'cpu'

# 重建模型
model = EvoFill(
    d_model=d_model,
    n_alleles=n_alleles,
    total_sites=total_sites,
    chunk_size=chunk_size,
    chunk_overlap=chunk_overlap
).to(device)

# ---------- 2. 加载最终权重 ----------
ckpt = torch.load('exp1/models/final_model.pth', map_location=device)
model.load_state_dict(ckpt['model_state'])

# ---------- 3. 切换推理模式 ----------
model.eval()
with torch.no_grad():
    pred = model(x, chunk_id=0, x_extra=x_extra)