ver0: 多 chunk modules 独立权重

ver0.1: 加样本特征标签（演化坐标）

ver3: chunk-wise 稀疏激活

## Dependency

In [1]:
import os; os.environ["CUDA_VISIBLE_DEVICES"] = "0" # 设置GPU
import gzip
import json
import shutil
from typing import Optional, Tuple, Dict, Set

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

from tqdm import tqdm

from sklearn.model_selection import train_test_split

from mamba_ssm import Mamba2
from mamba_ssm.modules.mamba2_simple import Mamba2Simple as Mamba2Block # 原Mamba2Block

## Data

In [2]:
class DataReader:
    def __init__(self, ref_vcf: str, ref_extra: Optional[str] = None):
        self.ref_vcf   = ref_vcf
        self.ref_extra = ref_extra

        # 1) 先扫一次 VCF：拿样本 ID 与「所有 GT 字符串」
        self.sample_ids, gt_strings = self._scan_all_gt()
        # 2) 建立映射
        self.hap_map   = self._make_hap_map(gt_strings)
        self.seq_depth = len(self.hap_map)        # 已含缺失位

        # 3) 正式读矩阵
        self.X_gt = self._read_body()
        self.total_sites = self.X_gt.shape[1]
        print(f'Loaded genotypes info: {self.X_gt.shape}')

        # 4) 读 extra
        self.X_extra = self._read_extra() if self.ref_extra else None
        if self.X_extra is not None:
            print(f'Loaded extra info: {self.X_extra.shape}')

        # 5) 打印
        self._print_summary()

    # ---- 工具 ----
    def _open(self): return gzip.open(self.ref_vcf, 'rt') if self.ref_vcf.endswith('.gz') else open(self.ref_vcf, 'rt')

    def _scan_all_gt(self) -> Tuple[list, Set[str]]:
        samples, gt_set = [], set()
        with self._open() as f:
            for line in f:
                if line.startswith('##'): continue
                if line.startswith('#CHROM'):
                    samples = line.strip().split('\t')[9:]
                    continue
                if not samples: continue
                parts = line.strip().split('\t')
                for fmt in parts[9:]:
                    gt = fmt.split(':')[0]
                    gt_set.add(gt)
        if not samples: raise RuntimeError('No sample IDs found')
        return samples, gt_set

    def _make_hap_map(self, gt_strings: Set[str]) -> Dict[str, int]:
        """给每个真实 GT 字符串一个唯一整数；缺失统一用 '.|.' 并强制占最后一档"""
        cleaned = set()
        for gt in gt_strings:
            if '.' in gt:
                cleaned.add('.|.')
            else:
                cleaned.add(gt)
        # 强制加入缺失项（即使当前没出现）
        cleaned.add('.|.')
        # 排序保证一致性：缺失放最后
        sorted_gts = sorted(cleaned, key=lambda x: (x == '.|.', x))
        hap_map = {gt: idx for idx, gt in enumerate(sorted_gts)}
        return hap_map

    def _read_body(self) -> np.ndarray:
        data = []
        with self._open() as f:
            for line in f:
                if line.startswith('#'): continue
                parts = line.strip().split('\t')
                row = []
                for fmt in parts[9:]:
                    gt = fmt.split(':')[0]
                    # 缺失处理
                    if '.' in gt:
                        row.append(self.hap_map['.|.'])
                    else:
                        row.append(self.hap_map[gt])
                data.append(row)
        return np.array(data, dtype=np.int32).T   # (n_samp, n_var)

    def _read_extra(self) -> Optional[np.ndarray]:
        try:
            df = pd.read_csv(self.ref_extra, sep='\t', index_col=0)
            df = df.loc[self.sample_ids]
            return df.values.astype(np.float32)
        except Exception as e:
            print(f'Extra features skipped: {e}')
            return None

    def _print_summary(self):
        uniq_codes = sorted(set(self.X_gt.flat))
        rev = {v: k for k, v in self.hap_map.items()}
        print('Unique genotypes in dataset:', [rev.get(c, f'unknown({c})') for c in uniq_codes])
        print('hap_map:')
        for k, v in list(self.hap_map.items())[:10]:
            print(f'  {k} -> {v}')
        print(f'self.seq_depth: {self.seq_depth}')

    # ---- 对外 ----
    def load(self) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        return self.X_gt, self.X_extra

In [47]:
dr = DataReader(
    ref_vcf='/home/qmtang/GitHub/STICI-HPC/data/training_sets/ALL.chr22.training.samples.100k.any.type.0.01.maf.variants.vcf.gz',
    ref_extra='/mnt/qmtang/EvoFill/data/251020_ver01_chr22/pop_wasserstein.tsv'
)
X_gt, X_extra = dr.load()

Loaded genotypes info: (2404, 99314)
Loaded extra info: (2404, 26)
Unique genotypes in dataset: ['0|0', '0|1', '1|0', '1|1']
hap_map:
  0|0 -> 0
  0|1 -> 1
  1|0 -> 2
  1|1 -> 3
  .|. -> 4
self.seq_depth: 5


In [29]:
X_extra

array([[1.3819261 , 0.6733273 , 2.54835   , ..., 1.8984295 , 1.9717863 , 3.2552247 ],
       [0.79339385, 0.6408869 , 1.1961467 , ..., 1.4899036 , 2.3885481 , 2.7226706 ],
       [1.1014444 , 2.3403306 , 1.5370792 , ..., 2.2961452 , 2.28726   , 1.3032436 ],
       ...,
       [2.9782143 , 3.4962935 , 2.4113977 , ..., 1.9876934 , 1.383686  , 1.4305575 ],
       [1.2182752 , 1.1490815 , 1.7263185 , ..., 1.3928081 , 3.5453155 , 3.036527  ],
       [2.8082004 , 2.8113415 , 2.583803  , ..., 2.4467747 , 1.1540115 , 1.2323742 ]], dtype=float32)

In [28]:
X_gt

array([[0, 0, 1, ..., 2, 0, 0],
       [2, 0, 0, ..., 0, 0, 3],
       [0, 0, 0, ..., 0, 0, 3],
       ...,
       [1, 1, 0, ..., 0, 0, 2],
       [3, 2, 0, ..., 0, 0, 1],
       [0, 0, 0, ..., 0, 0, 3]], dtype=int32)

In [28]:
class GenomicDataset(Dataset):
    """Dataset class for genomic data with masking for training"""

    def __init__(self, x_gts, x_extra, seq_depth,
                 mask=True, masking_rates=(0.5, 0.99)):
        self.gts = x_gts
        self.x_extra = x_extra
        self.seq_depth = seq_depth
        self.mask = mask
        self.masking_rates = masking_rates

    def __len__(self):
        return len(self.gts)

    def __getitem__(self, idx):
        x       = self.gts[idx].copy()
        x_extra = self.x_extra[idx]
        y       = self.gts[idx]

        if self.mask:
            # Apply masking
            seq_len = len(x)
            masking_rate = np.random.uniform(*self.masking_rates)
            mask_size = int(seq_len * masking_rate)
            mask_indices = np.random.choice(seq_len, mask_size, replace=False)
            x[mask_indices] = self.seq_depth - 1  # Missing value token

        # Convert to one-hot
        x_onehot = np.eye(self.seq_depth)[x]
        y_onehot = np.eye(self.seq_depth - 1)[y]

        return torch.FloatTensor(x_onehot),torch.FloatTensor(x_extra), torch.FloatTensor(y_onehot)

class ImputationDataset(Dataset):
    """Dataset for imputation (no masking needed)"""

    def __init__(self, data, seq_depth):
        self.data = data
        self.seq_depth = seq_depth

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]
        # Convert to one-hot without masking
        x_onehot = np.eye(self.seq_depth)[x]
        return torch.FloatTensor(x_onehot)

## Model

In [4]:
class GenoEmbedding(nn.Module):
    """Genomic embedding layer with positional encoding"""

    def __init__(self, n_alleles, n_snps, d_model):
        super().__init__()
        self.d_model = d_model
        self.n_alleles = n_alleles
        self.n_snps = n_snps

        # Allele embedding
        self.allele_embedding = nn.Parameter(torch.randn(n_alleles, d_model))

        # Positional embedding
        self.position_embedding = nn.Embedding(n_snps, d_model)

        # Initialize parameters
        nn.init.xavier_uniform_(self.allele_embedding)

    def forward(self, x):
        # x shape: (batch, seq_len, n_alleles) - one-hot encoded
        batch_size, seq_len, _ = x.shape

        # Allele embedding
        embedded = torch.einsum('bsn,nd->bsd', x, self.allele_embedding)

        # Positional embedding
        positions = torch.arange(seq_len, device=x.device)
        pos_emb = self.position_embedding(positions).unsqueeze(0)

        return embedded + pos_emb

class BiMambaBlock(nn.Module):
    """Bidirectional Mamba block for genomic sequence processing"""

    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model

        # Forward and backward Mamba blocks
        self.mamba_forward = Mamba2(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand
        )

        self.mamba_backward = Mamba2(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand
        )

        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(d_model * 2, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
            nn.GELU()
        )

        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        # x shape: (batch, seq_len, d_model)
        residual = x

        # Bidirectional processing
        x_norm = self.norm1(x)

        # Forward direction
        forward_out = self.mamba_forward(x_norm)

        # Backward direction (flip sequence)
        x_backward = torch.flip(x_norm, dims=[1])
        backward_out = self.mamba_backward(x_backward)
        backward_out = torch.flip(backward_out, dims=[1])

        # Concatenate bidirectional outputs
        bi_out = torch.cat([forward_out, backward_out], dim=-1)

        # FFN
        ffn_out = self.ffn(bi_out)
        ffn_out = self.dropout(ffn_out)

        # Residual connection
        out = self.norm2(residual + ffn_out)

        return out

class ConvBlock(nn.Module):
    """Convolutional block for local pattern extraction"""

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

        self.conv1 = nn.Conv1d(d_model, d_model, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(d_model, d_model, kernel_size=5, padding=2)
        self.conv3 = nn.Conv1d(d_model, d_model, kernel_size=7, padding=3)

        self.conv_large1 = nn.Conv1d(d_model, d_model, kernel_size=7, padding=3)
        self.conv_large2 = nn.Conv1d(d_model, d_model, kernel_size=15, padding=7)

        self.conv_final = nn.Conv1d(d_model, d_model, kernel_size=3, padding=1)
        self.conv_reduce = nn.Conv1d(d_model, d_model, kernel_size=1)

        self.bn1 = nn.BatchNorm1d(d_model)
        self.bn2 = nn.BatchNorm1d(d_model)

        self.gelu = nn.GELU()

    def forward(self, x):
        # x shape: (batch, seq_len, d_model)
        x = x.transpose(1, 2)  # (batch, d_model, seq_len)

        xa = self.gelu(self.conv1(x))

        xb = self.gelu(self.conv2(xa))
        xb = self.gelu(self.conv3(xb))

        xc = self.gelu(self.conv_large1(xa))
        xc = self.gelu(self.conv_large2(xc))

        xa = xb + xc
        xa = self.gelu(self.conv_final(xa))
        xa = self.bn1(xa)
        xa = self.gelu(self.conv_reduce(xa))
        xa = self.bn2(xa)
        xa = self.gelu(xa)

        return xa.transpose(1, 2)  # (batch, seq_len, d_model)

class EvoEmbedding(nn.Module):
    """
    输入:  (B, L)        L == extra_dim
    输出: (B, L, d_model)
    """
    def __init__(
        self,
        d_model: int,
        d_state: int = 64,
        d_conv: int  = 4,
        expand: int  = 2,
        headdim: int = 128,
        ngroups: int = 1,
        dropout: float = 0.1,
        **mamba_kwargs,
    ):
        super().__init__()
        self.d_model   = d_model

        # 1. 把 (B, L) 的 1-d 标量升到 d_model
        self.in_proj = nn.Linear(1, d_model, bias=False)

        # 2. 官方 Mamba2Simple：把 L 当序列长度，建模 L↔L
        self.mamba = Mamba2Block(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            headdim=headdim,
            ngroups=ngroups,
            **mamba_kwargs
        )

        # 3. 残差 + Norm
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)


    def forward(self, x: torch.Tensor):
        """
        x: (B, L)  连续值或离散索引
        """
        # (B, L) -> (B, L, 1) -> (B, L, d_model)
        h = self.in_proj(x.unsqueeze(-1).float())   # 1-d 投影

        # Mamba2Simple 要求输入 (B, L, d_model) 即可
        h = self.mamba(h)                           # SSD 全局建模

        # 残差 + Norm
        out = self.norm(h + self.dropout(h))
        return out

class Mamba2CrossBlock(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=64,
        d_conv=4,
        expand=2,
        headdim=128,
        ngroups=1,
        chunk_size=256,
        dropout=0.0,
        d_embed_dropout=0.0,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.d_model = d_model

        # 距离矩阵嵌入
        self.extra_embed = EvoEmbedding(d_model=d_model, dropout=d_embed_dropout)

        # 原归一化
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        # SSD 核心
        self.ssd = Mamba2Block(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            headdim=headdim,
            ngroups=ngroups,
            chunk_size=chunk_size,
            use_mem_eff_path=True,
            device=device,
            dtype=dtype,
        )

        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, d_model),
        )

    def forward(self, local_repr, global_repr, x_extra=None,
                start_offset=0, end_offset=0):
        """
        local_repr: (B, L, D)
        global_repr: (B, G, D)
        x_extra: 可选，(B,E) 
        """
        local_norm  = self.norm1(local_repr)
        global_norm = self.norm2(global_repr)

        # 1. 构造输入序列
        tokens = []
        if x_extra is not None:
            extra_token = self.extra_embed(x_extra)        # (B,E,D)
            tokens.append(extra_token)
        tokens.append(global_norm)
        tokens.append(local_norm)
        x = torch.cat(tokens, dim=1)               # [B, (E)+G+L, D]

        # 2. SSD 扫描
        x = self.ssd(x)                            # [B, (E)+G+L, D]

        # 3. 只取 local 部分
        local_len = local_norm.shape[1]
        x = x[:, -local_len:, :]                   # [B, L, D]

        # 4. pad 回原始长度
        if start_offset or end_offset:
            x = F.pad(x, (0, 0, start_offset, end_offset))

        # 5. 残差 + FFN
        x = x + local_norm
        x = self.norm3(x)
        x = self.ffn(x) + x
        return x
        return x

class ChunkModule(nn.Module):
    """Single chunk processing module with BiMamba"""

    def __init__(self, d_model, dropout_rate=0.1):
        super().__init__()
        self.d_model = d_model

        # BiMamba block
        self.bimamba_block = BiMambaBlock(d_model)

        # Convolutional blocks
        self.conv_block1 = ConvBlock(d_model)
        self.conv_block2 = ConvBlock(d_model)
        self.conv_block3 = ConvBlock(d_model)

        # Cross attention
        # self.cross_attention = CrossAttentionLayer(d_model, n_heads)
        self.cross_attention = Mamba2CrossBlock(
            d_model=d_model,
            d_state=64,
            d_conv=4,
            expand=2,
            headdim=128,
            ngroups=1,
            chunk_size=256,
        )

        # Additional layers
        self.dense = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout_rate)
        self.gelu = nn.GELU()

    def forward(self, x, x_extra=None):
        # BiMamba processing
        xa0 = self.bimamba_block(x)

        # First conv block
        xa = self.conv_block1(xa0)
        xa_skip = self.conv_block2(xa)

        # Dense layer
        xa = self.gelu(self.dense(xa))
        xa = self.conv_block2(xa)

        # Cross attention
        xa = self.cross_attention(xa, xa0, x_extra)
        xa = self.dropout(xa)

        # Final conv block
        xa = self.conv_block3(xa)

        # Concatenate with skip connection
        xa = torch.cat([xa_skip, xa], dim=-1)

        return xa

class BandConv1d(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias, mask, kernel, pad):
        # x: (B, C, L)   mask: (L,)  0/1
        ctx.save_for_backward(x, weight, bias, mask)
        ctx.kernel, ctx.pad = kernel, pad
        L = x.shape[-1]
        # 1. 把有效区抽出来（连续内存）
        idx = torch.where(mask)[0]
        x_band = x[..., idx]                      # (B, C, len_band)
        # 2. 标准 conv1d
        y_band = F.conv1d(x_band, weight, bias, padding=pad)
        # 3. 建全长度空张量，再把结果写回去
        y = torch.full((x.shape[0], weight.shape[0], L),
                       float('nan'), device=x.device)
        y[..., idx] = y_band
        return y

    @staticmethod
    def backward(ctx, grad_output):
        x, weight, bias, mask = ctx.saved_tensors
        kernel, pad = ctx.kernel, ctx.pad
        idx = torch.where(mask)[0]
        # 只把带状区梯度拿出来
        grad_band = grad_output[..., idx]
        x_band = x[..., idx]
        # 计算输入/权重梯度
        grad_x_band = torch.nn.grad.conv1d_input(
            x_band.shape, weight, grad_band, padding=pad)
        grad_weight = torch.nn.grad.conv1d_weight(
            x_band, weight.shape, grad_band, padding=pad)
        grad_bias   = grad_band.sum(dim=[0,2]) if bias is not None else None
        # 把输入梯度写回原位，其余区域无梯度（=None）
        grad_x = torch.full_like(x, float('nan'))
        grad_x[..., idx] = grad_x_band
        return grad_x, grad_weight, grad_bias, None, None, None

class SparseGlobalConv(nn.Module):
    def __init__(self, c_in, c_out, kernel=5, pad=2):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(c_out, c_in, kernel))
        self.bias   = nn.Parameter(torch.zeros(c_out))
        self.kernel, self.pad = kernel, pad
        nn.init.kaiming_normal_(self.weight)

    def forward(self, x, mask):
        return BandConv1d.apply(x, self.weight, self.bias, mask,
                                self.kernel, self.pad)

class GlobalOut(nn.Module):
    def __init__(self, d_model, n_alleles):
        super().__init__()
        self.final_conv  = SparseGlobalConv(2*d_model, d_model//2)
        self.output_conv = SparseGlobalConv(d_model//2, n_alleles-1)  # no missing
        self.gelu = nn.GELU()

    def forward(self, x, mask):
        # mask: (L,)  0/1  当前 chunk 要激活的位点
        x = self.gelu(self.final_conv(x, mask))
        x = self.output_conv(x, mask)        # (B, n_alleles-1, L)
        # 无效区填 -inf，后面 CrossEntropy 会自动忽略
        x = torch.where(mask.unsqueeze(0).unsqueeze(0).bool(),
                        x, torch.tensor(-float('inf'), device=x.device))
        return F.softmax(x.transpose(1,2), dim=-1)

class EvoFill(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_alleles: int,
        total_sites: int,
        chunk_size: int = 8192,
        chunk_overlap: int = 64,
        dropout_rate: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_alleles = n_alleles
        self.total_sites = total_sites
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

        # 1. chunk 边界
        stride = chunk_size - chunk_overlap
        starts = [i * stride for i in range((total_sites - 1) // stride + 1)]
        ends = [min(s + chunk_size, total_sites) for s in starts]
        self.register_buffer("starts", torch.tensor(starts, dtype=torch.long))
        self.register_buffer("ends", torch.tensor(ends, dtype=torch.long))
        self.n_chunks = len(starts)

        # 2. 每 chunk 一份嵌入 & 处理模块（常驻 GPU，但训练时只激活一个）
        self.chunk_embeds = nn.ModuleList(
            GenoEmbedding(n_alleles, e - s, d_model) for s, e in zip(starts, ends)
        )
        self.chunk_modules = nn.ModuleList(
            ChunkModule(d_model, dropout_rate) for s, e in zip(starts, ends)
        )

        # 3. 全局输出层（始终 GPU）
        self.global_out = GlobalOut(d_model, n_alleles)

        # 4. chunk 掩码表  (n_chunks, L)
        masks = torch.stack(
            [torch.arange(total_sites).ge(s) & torch.arange(total_sites).lt(e)
             for s, e in zip(starts, ends)]
        ).float()
        self.register_buffer("chunk_masks", masks)

    def forward(self, x: torch.Tensor, chunk_id: int,
                x_extra: Optional[torch.Tensor] = None):
        """
        x:       (B, L, n_alleles)  完整序列 one-hot
        chunk_id: 0..n_chunks-1
        x_extra:  (B, extra_dim)
        return:   (B, L, n_alleles-1)  softmax 概率
        """
        B, L, _ = x.shape
        device = x.device
        s, e = self.starts[chunk_id].item(), self.ends[chunk_id].item()
        mask = self.chunk_masks[chunk_id]          # (L,)  当前 chunk 覆盖区

        # 1. 只取当前 chunk 输入切片
        x_slice = x[:, s:e]                        # (B, len, n_alleles)

        # 2. 当前 chunk 嵌入 & 处理
        z = self.chunk_embeds[chunk_id](x_slice)   # (B, len, d_model)
        z = self.chunk_modules[chunk_id](z, x_extra)  # (B, len, 2*d_model)

        # 3. 拼回全长度，其余 nan
        z_full = torch.full((B, L, 2 * self.d_model), float('nan'), device=device)
        z_full[:, s:e] = z
        z_full = z_full.transpose(1, 2)            # (B, 2*d_model, L)

        # 4. 全局卷积只激活带状区
        return self.global_out(z_full, mask)       # (B, L, n_alleles-1)

In [None]:
# ---------- 假数据 ----------
B, L, A = 4, 12345, 3
d_model = 64
chunk_size, overlap = 4096, 64
device = 'cuda' if torch.cuda.is_available() else 'cpu'

x_train = torch.zeros(B, L, A, device=device)
allele = torch.randint(0, A, (B, L), device=device)
x_train.scatter_(2, allele.unsqueeze(-1), 1)
x_extra = torch.randn(B, 10, device=device)
y_train = torch.randn(B, L, A-1, device=device)

print(f"x_train: {x_train.shape}")
print(f"x_extra: {x_extra.shape}")
print(f"y_train: {y_train.shape}")
print("")
# ---------- 模型 &损失 ----------
model = EvoFill(d_model, A, L, chunk_size, overlap).to(device)
print(f"model chunks: {model.n_chunks}")
criterion = nn.MSELoss(reduction='mean')

# ---------- 训练循环 ----------
epochs_per_chunk = 3
for cid in range(model.n_chunks):
    mask = model.chunk_masks[cid]
    opt = torch.optim.AdamW(
          list(model.chunk_embeds[cid].parameters()) +
          list(model.chunk_modules[cid].parameters()) +
          list(model.global_out.parameters()), lr=3e-4)
    for epoch in range(epochs_per_chunk):
        opt.zero_grad()
        pred = model(x_train, cid, x_extra)

        idx = torch.where(mask)[0]                 # (n_active,)
        pred_band  = pred[:, idx]                  # (B, n_active, n_alleles-1)
        y_band     = y_train[:, idx]               # (B, n_active, n_alleles-1)

        loss = criterion(pred_band, y_band)        # 默认 reduction='mean'
        loss.backward()
        opt.step()
        print("")
        print(f"x_train: {x_train.shape}")
        print(f"x_extra: {x_extra.shape}")
        print(f"pred: {pred.shape}")

        print(f"pred_band: {pred_band.shape}")
        print(f"y_band: {y_band.shape}")

        print(f'chunk {cid}/{model.n_chunks-1} | epoch {epoch+1}/{epochs_per_chunk} | loss {loss.item():.4f}')

x_train: torch.Size([4, 12345, 3])
x_extra: torch.Size([4, 10])
y_train: torch.Size([4, 12345, 2])

model chunks: 4


x_train: torch.Size([4, 12345, 3])
x_extra: torch.Size([4, 10])
pred: torch.Size([4, 12345, 2])
pred_band: torch.Size([4, 4096, 2])
y_band: torch.Size([4, 4096, 2])
chunk 0/3 | epoch 1/3 | loss 1.3536

x_train: torch.Size([4, 12345, 3])
x_extra: torch.Size([4, 10])
pred: torch.Size([4, 12345, 2])
pred_band: torch.Size([4, 4096, 2])
y_band: torch.Size([4, 4096, 2])
chunk 0/3 | epoch 2/3 | loss 1.2774

x_train: torch.Size([4, 12345, 3])
x_extra: torch.Size([4, 10])
pred: torch.Size([4, 12345, 2])
pred_band: torch.Size([4, 4096, 2])
y_band: torch.Size([4, 4096, 2])
chunk 0/3 | epoch 3/3 | loss 1.2199

x_train: torch.Size([4, 12345, 3])
x_extra: torch.Size([4, 10])
pred: torch.Size([4, 12345, 2])
pred_band: torch.Size([4, 4096, 2])
y_band: torch.Size([4, 4096, 2])
chunk 1/3 | epoch 1/3 | loss 1.3241

x_train: torch.Size([4, 12345, 3])
x_extra: torch.Size([4, 10])
pred: torc

## Loss

In [5]:
class ImputationLoss(nn.Module):
    """Custom loss function for genomic imputation"""

    def __init__(self, use_r2=True, 
                 use_focal=False, #  all dummy 
                 group_size=None,
                 gamma=None,
                 alpha=None,
                 eps=None,
                 use_gradnorm=None,
                 gn_alpha=None,
                 gn_lr_w=None,):
        super().__init__()
        self.use_r2_loss = use_r2
        self.ce_loss = nn.CrossEntropyLoss(reduction='sum')
        self.kl_loss = nn.KLDivLoss(reduction='sum')

    def calculate_minimac_r2(self, pred_alt_allele_probs, gt_alt_af):
        """Calculate Minimac-style RÂ² metric"""
        mask = torch.logical_or(torch.eq(gt_alt_af, 0.0), torch.eq(gt_alt_af, 1.0))
        gt_alt_af = torch.where(mask, 0.5, gt_alt_af)
        denom = gt_alt_af * (1.0 - gt_alt_af)
        denom = torch.where(denom < 0.01, 0.01, denom)
        r2 = torch.mean(torch.square(pred_alt_allele_probs - gt_alt_af), dim=0) / denom
        r2 = torch.where(mask, torch.zeros_like(r2), r2)
        return r2

    def forward(self, y_pred, y_true):
        y_true = y_true.float()

        # Convert to proper format for losses
        y_true_ce = torch.argmax(y_true, dim=-1)  # For CrossEntropy
        y_pred_log = torch.log(y_pred + 1e-8)  # For KL divergence

        # Basic losses
        ce_loss = self.ce_loss(y_pred.view(-1, y_pred.size(-1)), y_true_ce.view(-1))
        kl_loss = self.kl_loss(y_pred_log.view(-1, y_pred.size(-1)),
                               y_true.view(-1, y_true.size(-1)))

        total_loss = ce_loss + kl_loss

        if self.use_r2_loss:
            batch_size = y_true.size(0)
            group_size = 4
            num_full_groups = batch_size // group_size

            if num_full_groups > 0:
                y_true_grouped = y_true[:num_full_groups * group_size].view(
                    num_full_groups, group_size, *y_true.shape[1:])
                y_pred_grouped = y_pred[:num_full_groups * group_size].view(
                    num_full_groups, group_size, *y_pred.shape[1:])

                r2_loss = 0.0
                for i in range(num_full_groups):
                    gt_alt_af = torch.count_nonzero(
                        torch.argmax(y_true_grouped[i], dim=-1), dim=0
                    ).float() / group_size

                    pred_alt_allele_probs = torch.sum(y_pred_grouped[i][:, :, 1:], dim=-1)
                    r2_loss += -torch.sum(self.calculate_minimac_r2(
                        pred_alt_allele_probs, gt_alt_af)) * group_size

                total_loss += r2_loss

        return total_loss, None

## Train

In [6]:
def create_directories(save_dir, models_dir="models", outputs="out") -> None:
    """Create necessary directories"""
    for dd in [save_dir, f"{save_dir}/{models_dir}", f"{save_dir}/{outputs}"]:
        if not os.path.exists(dd):
            os.makedirs(dd)

def clear_dir(path) -> None:
    """Clear directory if it exists"""
    if os.path.exists(path):
        shutil.rmtree(path)


In [7]:
MAF_BINS = [(0.00, 0.05), (0.05, 0.10), (0.10, 0.20),
            (0.20, 0.30), (0.30, 0.40), (0.40, 0.50)]

def precompute_maf(gts_np, mask_int=-1):
    """
    gts_np: (N, L)  int64
    return:
        maf: (L,) float32
        bin_cnt: list[int] 长度 6，对应 6 个 bin 的位点数量
    """
    L = gts_np.shape[1]
    maf = np.zeros(L, dtype=np.float32)
    bin_cnt = [0] * 6

    for l in range(L):
        alleles = gts_np[:, l]
        alleles = alleles[alleles != mask_int]   # 去掉缺失
        if alleles.size == 0:
            maf[l] = 0.0
            continue

        uniq, cnt = np.unique(alleles, return_counts=True)
        total = cnt.sum()
        freq = cnt / total
        freq[::-1].sort()
        maf_val = freq[1] if len(freq) > 1 else 0.0
        maf[l] = maf_val

        # 统计 bin
        for i, (lo, hi) in enumerate(MAF_BINS):
            if lo <= maf_val < hi:
                bin_cnt[i] += 1
                break

    return maf, bin_cnt

def imputation_maf_accuracy_epoch(all_logits, all_gts, global_maf, mask=None):
    """
    all_logits: (N, L, C)
    all_gts:    (N, L, C) one-hot
    global_maf: (L,)
    mask:       (N, L) 或 None
    return:     list[float] 长度 6
    """
    # 1. 预测 vs 真实
    all_gts = all_gts.argmax(dim=-1)      # (N, L)
    preds   = all_logits.argmax(dim=-1)   # (N, L)

    # 2. 如果没有外部 mask，就默认全 1
    if mask is None:
        mask = torch.ones_like(all_gts, dtype=torch.bool)   # (N, L)
    correct = (preds == all_gts) & mask                   # (N, L)

    # 3. MAF 条件 -> (1, L) 再广播到 (N, L)
    maf = global_maf.unsqueeze(0)                         # (1, L)

    # 4. 分 bin 计算
    accs = []
    for lo, hi in MAF_BINS:
        maf_bin = mask & (maf >= lo) & (maf < hi)                # (1, L)
        n_cor = (correct & maf_bin).sum()
        n_tot = maf_bin.sum()
        accs.append(100*(n_cor / n_tot).item() if n_tot > 0 else 0.0)
    return accs

In [29]:
work_dir      = '/home/qmtang/mnt_qmtang/EvoFill/data/251021_ver01_chr22'
ref_vcf       = "/home/qmtang/GitHub/STICI-HPC/data/training_sets/ALL.chr22.training.samples.100k.any.type.0.01.maf.variants.vcf.gz"
ref_extra     = "/mnt/qmtang/EvoFill/data/251020_ver01_chr22/pop_wasserstein.tsv"

val_n_samples = 64
max_mr             = 0.7
min_mr             = 0.3
batch_size_per_gpu = 8

# Setup device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create directories
create_directories(work_dir)

# Load data
dr = DataReader(
    ref_vcf=ref_vcf,
    ref_extra=ref_extra
)
X_gt, X_extra = dr.load()

# # Split data for validation
n_samples, total_sites = X_gt.shape
x_train_indices, x_valid_indices = train_test_split(
    range(n_samples),
    test_size=val_n_samples,
    random_state=3047,
    shuffle=True
)

train_gt = X_gt[x_train_indices,:]
train_extra = X_extra[x_train_indices,:]

val_gt = X_gt[x_valid_indices,:]
val_extra = X_extra[x_valid_indices,:]

# Create datasets
train_dataset = GenomicDataset(
    train_gt, train_extra, dr.seq_depth,
    mask=True, masking_rates=(min_mr, max_mr)
)
val_dataset = GenomicDataset(
    val_gt, val_extra, dr.seq_depth,
    mask=True, masking_rates=(min_mr, max_mr)
)
train_loader = DataLoader(train_dataset, batch_size=batch_size_per_gpu,
                        shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size_per_gpu,
                        shuffle=False, num_workers=4, pin_memory=True)

print(f"{len(x_train_indices)} samples in train")
print(f"{len(x_valid_indices)} samples in val")

Using device: cuda
Loaded genotypes info: (2404, 99314)
Loaded extra info: (2404, 26)
Unique genotypes in dataset: ['0|0', '0|1', '1|0', '1|1']
hap_map:
  0|0 -> 0
  0|1 -> 1
  1|0 -> 2
  1|1 -> 3
  .|. -> 4
self.seq_depth: 5
2340 samples in train
64 samples in val


In [30]:
model_name = 'hg19_chr22'
chunk_size = 8192
overlap    = 64
d_model    = 64
alleles    = dr.seq_depth

model = EvoFill(d_model, alleles, total_sites, chunk_size, overlap).to(device)
print(f"model[{model_name}] chunks={model.n_chunks}")

criterion = ImputationLoss(use_r2=True)

model[hg19_chr22] chunks=13


In [None]:

epochs_per_chunk   = 10
lr                 = 0.001
weight_decay       = 1e-5
earlystop_patience = 9

for cid in range(model.n_chunks):
    chunk_mask = model.chunk_masks[cid].cpu()

    chunk_maf, chunk_bin_cnt = precompute_maf(X_gt[:, torch.where(chunk_mask)[0]],  mask_int=dr.seq_depth)
    chunk_maf = torch.from_numpy(chunk_maf).to(device)
    maf_df = pd.DataFrame({
        'MAF_bin': ['(0.00, 0.05)', '(0.05, 0.10)', '(0.10, 0.20)',
                    '(0.20, 0.30)', '(0.30, 0.40)', '(0.40, 0.50)'],
        'Counts':  [f"{c}" for c in chunk_bin_cnt],
    })
    print(maf_df.to_string(index=False))

    optimizer = AdamW(list(model.chunk_embeds[cid].parameters()) +
                      list(model.chunk_modules[cid].parameters()) +
                      list(model.global_out.parameters()), 
                      lr=lr, weight_decay=weight_decay)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, min_lr=1e-7)
    best_loss = float('inf')
    patience = earlystop_patience
    patience_counter = 0
    
    for epoch in range(epochs_per_chunk):
        model.train()
        train_loss = 0.0
        train_logits, train_gts, train_mask = [], [], []

        train_pbar = tqdm(train_loader, desc=f'Chunk {cid}/{model.n_chunks}, Epoch {epoch + 1}/{epochs_per_chunk}', leave=False)
        for batch_idx, (x, x_extra, target) in enumerate(train_pbar):
            x, x_extra, target = x.to(device), x_extra.to(device), target.to(device)

            optimizer.zero_grad()
            pred = model(x, cid, x_extra)

            idx       = torch.where(chunk_mask)[0]          # (n_active,)
            pred_band = pred[:, idx]                  # (B, n_active, n_alleles-1)
            y_band    = target[:, idx]               # (B, n_active, n_alleles-1)

            loss, logs = criterion(pred_band, y_band)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_pbar.set_postfix({'loss': loss.item()})

            # === 收集训练结果 ===
            miss_mask = x[:, idx][..., -1].bool()         # 只关心被 mask 的位点
            train_logits.append(pred_band.detach())
            train_gts.append(y_band.detach())
            train_mask.append(miss_mask)

        # 训练集 MAF-acc
        train_logits = torch.cat(train_logits, dim=0)
        train_gts    = torch.cat(train_gts,    dim=0)
        train_mask   = torch.cat(train_mask,   dim=0)
        train_maf_accs = imputation_maf_accuracy_epoch(train_logits, train_gts, chunk_maf, mask=train_mask)

        # ----------- 验证循环同理 ------------
        model.eval()
        val_loss = 0.0
        val_logits, val_gts = [], []
        with torch.no_grad():
            for x, x_extra, target in val_loader:
                x, x_extra, target = x.to(device), x_extra.to(device), target.to(device)
                pred = model(x, cid, x_extra)

                idx       = torch.where(chunk_mask)[0]          # (n_active,)
                pred_band = pred[:, idx]                  # (B, n_active, n_alleles-1)
                y_band    = target[:, idx]               # (B, n_active, n_alleles-1)

                loss, logs = criterion(pred_band, y_band)

                val_loss += loss.item()

                val_logits.append(pred_band)
                val_gts.append(y_band)

        val_logits = torch.cat(val_logits, dim=0)
        val_gts    = torch.cat(val_gts,    dim=0)
        val_maf_accs = imputation_maf_accuracy_epoch(
            val_logits, val_gts, chunk_maf,  mask=None,) # 计算所有位点

        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss   = val_loss   / len(val_loader)

        scheduler.step(avg_val_loss)

        print(f'Chunk {cid}/{model.n_chunks}, Epoch {epoch + 1}/{epochs_per_chunk}, '
                f'Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

        # 用 DataFrame 打印 MAF-bin 结果
        maf_df = pd.DataFrame({
            'MAF_bin': ['(0.00, 0.05)', '(0.05, 0.10)', '(0.10, 0.20)',
                        '(0.20, 0.30)', '(0.30, 0.40)', '(0.40, 0.50)'],
            'Counts':  [f"{c}" for c in chunk_bin_cnt],
            'Train':   [f"{acc:.2f}" for acc in train_maf_accs],
            'Val':     [f"{acc:.2f}" for acc in val_maf_accs]
        })
        print(maf_df.to_string(index=False))

        # Early stopping
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            patience_counter = 0
            # Save best model
            import os
            os.makedirs(f'{work_dir}/models', exist_ok=True)
            # torch.save(model.state_dict(), f'{work_dir}/models/weights.pth')
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print('Early stopping triggered')
                break

    del optimizer, scheduler
    torch.cuda.empty_cache() if torch.cuda.is_available() else None

     MAF_bin Counts
(0.00, 0.05)   3145
(0.05, 0.10)   1552
(0.10, 0.20)   1942
(0.20, 0.30)   1446
(0.30, 0.40)     85
(0.40, 0.50)     22


                                                                           

Chunk 0, Epoch 1/10, Train Loss: 69292.5245, Val Loss: 66277.8574
     MAF_bin Counts Train   Val
(0.00, 0.05)   3145 93.78 95.63
(0.05, 0.10)   1552 83.49 90.09
(0.10, 0.20)   1942 69.64 81.87
(0.20, 0.30)   1446 58.54 76.44
(0.30, 0.40)     85 43.98 69.36
(0.40, 0.50)     22 44.68 68.39


                                                                           

Chunk 0, Epoch 2/10, Train Loss: 49928.0296, Val Loss: 65222.2310
     MAF_bin Counts Train   Val
(0.00, 0.05)   3145 94.95 96.78
(0.05, 0.10)   1552 86.59 91.92
(0.10, 0.20)   1942 76.90 84.44
(0.20, 0.30)   1446 70.70 78.97
(0.30, 0.40)     85 51.11 71.67
(0.40, 0.50)     22 51.75 69.89


                                                                           

Chunk 0, Epoch 3/10, Train Loss: 46159.0682, Val Loss: 70372.1309
     MAF_bin Counts Train   Val
(0.00, 0.05)   3145 95.18 96.41
(0.05, 0.10)   1552 87.93 90.93
(0.10, 0.20)   1942 79.53 82.92
(0.20, 0.30)   1446 74.14 77.15
(0.30, 0.40)     85 53.12 70.17
(0.40, 0.50)     22 52.96 67.33


                                                                           

Chunk 0, Epoch 4/10, Train Loss: 44259.7210, Val Loss: 65322.7339
     MAF_bin Counts Train   Val
(0.00, 0.05)   3145 95.35 96.87
(0.05, 0.10)   1552 88.73 92.66
(0.10, 0.20)   1942 81.03 85.40
(0.20, 0.30)   1446 76.16 79.95
(0.30, 0.40)     85 54.22 72.89
(0.40, 0.50)     22 53.58 67.83


                                                                           

Chunk 0, Epoch 5/10, Train Loss: 42616.2359, Val Loss: 67731.4341
     MAF_bin Counts Train   Val
(0.00, 0.05)   3145 95.45 96.73
(0.05, 0.10)   1552 89.32 91.81
(0.10, 0.20)   1942 82.24 84.13
(0.20, 0.30)   1446 77.71 78.22
(0.30, 0.40)     85 55.34 72.72
(0.40, 0.50)     22 54.88 66.90


                                                                           

Chunk 0, Epoch 6/10, Train Loss: 41701.9468, Val Loss: 74443.4268
     MAF_bin Counts Train   Val
(0.00, 0.05)   3145 95.56 96.34
(0.05, 0.10)   1552 89.75 91.25
(0.10, 0.20)   1942 82.98 82.87
(0.20, 0.30)   1446 78.76 76.13
(0.30, 0.40)     85 56.15 69.52
(0.40, 0.50)     22 55.61 67.12


                                                                           

Chunk 0, Epoch 7/10, Train Loss: 40288.1210, Val Loss: 74371.6104
     MAF_bin Counts Train   Val
(0.00, 0.05)   3145 95.73 96.34
(0.05, 0.10)   1552 90.51 91.27
(0.10, 0.20)   1942 84.38 82.91
(0.20, 0.30)   1446 80.32 76.03
(0.30, 0.40)     85 57.38 69.65
(0.40, 0.50)     22 55.97 67.61


                                                                           

Chunk 0, Epoch 8/10, Train Loss: 39683.8185, Val Loss: 67533.2114
     MAF_bin Counts Train   Val
(0.00, 0.05)   3145 95.77 96.94
(0.05, 0.10)   1552 90.71 92.63
(0.10, 0.20)   1942 84.87 84.80
(0.20, 0.30)   1446 80.97 79.17
(0.30, 0.40)     85 57.71 72.17
(0.40, 0.50)     22 55.86 68.82


                                                                           

In [24]:
chunk_maf.shape

torch.Size([8192])