ver0: 多 chunk modules 独立权重

ver0.1: 加样本特征标签（演化坐标）

ver3: chunk-wise 稀疏激活

## Dependency

In [1]:
import os; os.environ["CUDA_VISIBLE_DEVICES"] = "0" # 设置GPU
from cyvcf2 import VCF
import scipy.sparse as sp
import shutil
from typing import Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

from tqdm import tqdm

from sklearn.model_selection import train_test_split

from mamba_ssm import Mamba2
from mamba_ssm.modules.mamba2_simple import Mamba2Simple as Mamba2Block # 原Mamba2Block

## Data

In [2]:
class GenotypeEncoder:
    def __init__(self,
                 save_dir: str,
                 vcf_path: str,
                 ref_extra: Optional[str] = None,
                 phased: bool = True,
                 gts012: bool = False):
        self.vcf_path    = vcf_path
        self.ref_extra   = ref_extra
        self.phased      = phased if ref_extra is None else False # 是否把样本拆成单倍型
        self.gts012      = gts012
        # 其余成员先占位
        self.n_samples   = 0
        self.n_variants  = 0
        self.sample_ids  = []   # 后面读 VCF 时填充
        self.variant_ids = []

        self.X_gt        = None   # 最终返回的张量
        self.X_extra     = None   # extra 信息
        self.seq_depth   = None

        # 1) 读 VCF
        self.X_gt = self.load_gt()
        # 2) 读 extra
        self.X_extra = self.load_extra() if self.ref_extra else None

    def encode_gt(self, rec, n_samples, phase=False, gts012=True):
        """
        return:
            phase=False  -> (n_samples,)          剂量或基因型
            phase=True   -> (2*n_samples,)        单倍型

        encoding rule:
            gts012=True  -> 0/1/2/3  （3=missing）
            gts012=False -> 0/1/2/3/…/-1  （0=REF, 1+=ALT, -1=missing）
        """
        n = n_samples

        # ---------- 1. 单倍型模式 ----------
        if phase:
            out = np.empty(2 * n, dtype=np.int8)
            for i, gt in enumerate(rec.genotypes):
                a1, a2, _phased = gt
                # 缺失
                if a1 is None:
                    out[2*i]   = 3 if gts012 else -1
                else:
                    if gts012:                      # 压缩成 0/1/2
                        out[2*i] = 0 if a1 == 0 else (2 if a1 >= 2 else 1)
                    else:                           # 原值保留
                        out[2*i] = a1

                if a2 is None:
                    out[2*i+1] = 3 if gts012 else -1
                else:
                    if gts012:
                        out[2*i+1] = 0 if a2 == 0 else (2 if a2 >= 2 else 1)
                    else:
                        out[2*i+1] = a2
            return out

        # ---------- 2. 剂量模式 ----------
        else:
            out = np.empty(n, dtype=np.int8)
            for i, gt in enumerate(rec.genotypes):
                a1, a2, _phased = gt
                # 缺失
                if a1 is None or a2 is None:
                    out[i] = 3 if gts012 else -1
                else:
                    if gts012:
                        # 0/1/2 剂量
                        out[i] = (1 if a1 > 0 else 0) + (1 if a2 > 0 else 0)
                    else:
                        # 多等位剂量：把 ALT 编号直接相加
                        out[i] = (0 if a1 == 0 else a1) + (0 if a2 == 0 else a2)
            return out

    def load_extra(self) -> Optional[np.ndarray]:
        try:
            df = pd.read_csv(self.ref_extra, sep='\t', index_col=0)
            df = df.loc[self.sample_ids]          # 保证与 VCF 样本顺序一致
            print(f"[DATA] Extra dims: {df.shape}")
            return df.values.astype(np.float32)
        except Exception as e:
            print(f"[DATA] Extra features skipped: {e}")
            return None

    def load_gt(self):
        self.save_dir = "/mnt/qmtang/EvoFill/data/251023_chr22"
        interval = 10000

        cols, data, indptr = [], [], [0]

        vcf = VCF(self.vcf_path, gts012 = self.gts012)
        self.sample_ids = vcf.samples
        self.n_samples = len(self.sample_ids)
        self.n_variants = 0

        for rec in vcf:
            vec = self.encode_gt(rec, self.n_samples, phase=self.phased, gts012=self.gts012)
            nz_idx = np.flatnonzero(vec)
            cols.extend(nz_idx)
            data.extend(vec[nz_idx])
            indptr.append(indptr[-1] + len(nz_idx))

            self.n_variants += 1
            self.variant_ids.append(f"{rec.CHROM}:{rec.POS}_{rec.REF}/{','.join(rec.ALT)}")
            if self.n_variants % interval == 0:
                print(f'\r[DATA] 已编码 {self.n_variants:,} 个位点', end='', flush=True)

        print(f'\r[DATA] 总计 {self.n_variants:,} 个位点  ', flush=True)
        vcf.close()

        # 根据 phase_mode 决定行数
        n_rows = 2 * self.n_samples if self.phased else self.n_samples
        M = sp.csc_matrix((data, cols, indptr),
                        shape=(n_rows,self.n_variants),
                        dtype=np.int8)

        print(f'[DATA] 位点矩阵 = {M.shape}，稀疏度 = {M.nnz / (M.shape[0] * M.shape[1]):.2%}')
        if self.gts012:
            self.seq_depth = M.data.max()+1
        else:
            M.data[M.data == -1] = M.data.max() + 1
            self.seq_depth = M.data.max() + 1
        print(f'[DATA] gt alleles = [0 - {M.data.max()}], seq_depth = {self.seq_depth} ({M.data.max()} = 缺失)')

        os.makedirs(self.save_dir, exist_ok=True)          # 1. 不存在就创建

        # 2. 保存稀疏矩阵
        sp.save_npz(os.path.join(self.save_dir, "gt_matrix.npz"), M)

        # 3. 保存样本列表（顺序与矩阵行对应）
        with open(os.path.join(self.save_dir, "gt_samples.txt"), "w") as f:
            if self.phased:                      # 单倍型模式：写成 sample_A / sample_B
                for s in self.sample_ids:
                    f.write(f"{s}_A\n{s}_B\n")
            else:                               # 剂量模式
                for s in self.sample_ids:
                    f.write(f"{s}\n")

        # 4. 保存变异位点 ID（chr:pos/ref/alt）
        with open(os.path.join(self.save_dir, "gt_variants.txt"), "w") as f:
            for vid in self.variant_ids:
                f.write(vid + "\n")

        print(f"[DATA] 结果已写入 {self.save_dir}")
        return M

In [None]:
gt_enc = GenotypeEncoder(
    save_dir='/mnt/qmtang/EvoFill/data/251024_ver3_chr22_mini',
    vcf_path='/home/qmtang/GitHub/STICI-HPC/data/training_sets/ALL.chr22.training.samples.100k.any.type.0.01.maf.variants.vcf.gz',
    ref_extra='/mnt/qmtang/EvoFill/data/251020_ver01_chr22/pop_wasserstein.tsv',
    phased= True,
    gts012= False)

print(gt_enc.n_samples, gt_enc.n_variants)

[DATA] 总计 99,314 个位点  
[DATA] gt matrix = (2404, 99314)，稀疏度 = 28.10%
[DATA] gt alleles = [0 - 2], seq_depth = 4 (including missing)
[DATA] 结果已写入 /mnt/qmtang/EvoFill/data/251023_chr22
[DATA] Extra dims: (2404, 26)


<Compressed Sparse Column sparse matrix of dtype 'int8'
	with 67095358 stored elements and shape (2404, 99314)>

In [3]:
gt_enc = GenotypeEncoder(
    save_dir='/mnt/qmtang/EvoFill/data/251023_chr22',
    vcf_path='/mnt/NAS/Omics/DNA/1kGP/vcf/ALL.chr22.phase3_shapeit2_mvncall_integrated_v5b.20130502.genotypes.vcf.gz',
    ref_extra=None,
    phased= True,
    gts012= False)

print(gt_enc.n_samples, gt_enc.n_variants)

[DATA] 总计 1,103,547 个位点  
[DATA] 位点矩阵 = (5008, 1103547)，稀疏度 = 3.71%
[DATA] gt alleles = [0 - 8], seq_depth = 9 (8 = 缺失)
[DATA] 结果已写入 /mnt/qmtang/EvoFill/data/251023_chr22
2504 1103547


In [None]:
class GenomicDataset(Dataset):
    """Dataset class for genomic data with masking for training"""

    def __init__(self, x_gts, x_extra=None, seq_depth=4,
                 mask=True, masking_rates=(0.5, 0.99)):
        self.gts = x_gts
        self.x_extra = x_extra
        self.seq_depth = seq_depth
        self.mask = mask
        self.masking_rates = masking_rates

    def __len__(self):
        return len(self.gts)

    def __getitem__(self, idx):
        x       = self.gts[idx].copy()
        y       = self.gts[idx]
        if self.x_extra is not None:
            x_extra = self.x_extra[idx]
        else:
            x_extra = None

        if self.mask:
            # Apply masking
            seq_len = len(x)
            masking_rate = np.random.uniform(*self.masking_rates)
            mask_size = int(seq_len * masking_rate)
            mask_indices = np.random.choice(seq_len, mask_size, replace=False)
            x[mask_indices] = self.seq_depth - 1  # Missing value token

        # Convert to one-hot
        x_onehot = np.eye(self.seq_depth)[x]
        y_onehot = np.eye(self.seq_depth - 1)[y]

        return torch.FloatTensor(x_onehot),torch.FloatTensor(x_extra), torch.FloatTensor(y_onehot)

class ImputationDataset(Dataset):
    """Dataset for imputation (no masking needed)"""

    def __init__(self, data, seq_depth):
        self.data = data
        self.seq_depth = seq_depth

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]
        # Convert to one-hot without masking
        x_onehot = np.eye(self.seq_depth)[x]
        return torch.FloatTensor(x_onehot)

## Model

In [None]:
class GenoEmbedding(nn.Module):
    """Genomic embedding layer with positional encoding"""

    def __init__(self, n_alleles, n_snps, d_model):
        super().__init__()
        self.d_model = d_model
        self.n_alleles = n_alleles
        self.n_snps = n_snps

        # Allele embedding
        self.allele_embedding = nn.Parameter(torch.randn(n_alleles, d_model))

        # Positional embedding
        self.position_embedding = nn.Embedding(n_snps, d_model)

        # Initialize parameters
        nn.init.xavier_uniform_(self.allele_embedding)

    def forward(self, x):
        # x shape: (batch, seq_len, n_alleles) - one-hot encoded
        _, seq_len, _ = x.shape

        # Allele embedding
        embedded = torch.einsum('bsn,nd->bsd', x, self.allele_embedding)

        # Positional embedding
        positions = torch.arange(seq_len, device=x.device)
        pos_emb = self.position_embedding(positions).unsqueeze(0)

        return embedded + pos_emb

class BiMambaBlock(nn.Module):
    """Bidirectional Mamba block for genomic sequence processing"""

    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model

        # Forward and backward Mamba blocks
        self.mamba_forward = Mamba2(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand
        )

        self.mamba_backward = Mamba2(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand
        )

        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(d_model * 2, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
            nn.GELU()
        )

        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        # x shape: (batch, seq_len, d_model)
        residual = x

        # Bidirectional processing
        x_norm = self.norm1(x)

        # Forward direction
        forward_out = self.mamba_forward(x_norm)

        # Backward direction (flip sequence)
        x_backward = torch.flip(x_norm, dims=[1])
        backward_out = self.mamba_backward(x_backward)
        backward_out = torch.flip(backward_out, dims=[1])

        # Concatenate bidirectional outputs
        bi_out = torch.cat([forward_out, backward_out], dim=-1)

        # FFN
        ffn_out = self.ffn(bi_out)
        ffn_out = self.dropout(ffn_out)

        # Residual connection
        out = self.norm2(residual + ffn_out)

        return out

class ConvBlock(nn.Module):
    """Convolutional block for local pattern extraction"""

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

        self.conv1 = nn.Conv1d(d_model, d_model, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(d_model, d_model, kernel_size=5, padding=2)
        self.conv3 = nn.Conv1d(d_model, d_model, kernel_size=7, padding=3)

        self.conv_large1 = nn.Conv1d(d_model, d_model, kernel_size=7, padding=3)
        self.conv_large2 = nn.Conv1d(d_model, d_model, kernel_size=15, padding=7)

        self.conv_final = nn.Conv1d(d_model, d_model, kernel_size=3, padding=1)
        self.conv_reduce = nn.Conv1d(d_model, d_model, kernel_size=1)

        self.bn1 = nn.BatchNorm1d(d_model)
        self.bn2 = nn.BatchNorm1d(d_model)

        self.gelu = nn.GELU()

    def forward(self, x):
        # x shape: (batch, seq_len, d_model)
        x = x.transpose(1, 2)  # (batch, d_model, seq_len)

        xa = self.gelu(self.conv1(x))

        xb = self.gelu(self.conv2(xa))
        xb = self.gelu(self.conv3(xb))

        xc = self.gelu(self.conv_large1(xa))
        xc = self.gelu(self.conv_large2(xc))

        xa = xb + xc
        xa = self.gelu(self.conv_final(xa))
        xa = self.bn1(xa)
        xa = self.gelu(self.conv_reduce(xa))
        xa = self.bn2(xa)
        xa = self.gelu(xa)

        return xa.transpose(1, 2)  # (batch, seq_len, d_model)

class ExtraEmbedding(nn.Module):
    """
    输入:  (B, L)        L == extra_dim
    输出: (B, L, d_model)
    """
    def __init__(
        self,
        d_model: int,
        d_state: int = 64,
        d_conv: int  = 4,
        expand: int  = 2,
        headdim: int = 128,
        ngroups: int = 1,
        dropout: float = 0.1,
        **mamba_kwargs,
    ):
        super().__init__()
        self.d_model   = d_model

        # 1. 把 (B, L) 的 1-d 标量升到 d_model
        self.in_proj = nn.Linear(1, d_model, bias=False)

        # 2. 官方 Mamba2Simple：把 L 当序列长度，建模 L↔L
        self.mamba = Mamba2Block(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            headdim=headdim,
            ngroups=ngroups,
            **mamba_kwargs
        )

        # 3. Norm
        self.norm = nn.LayerNorm(d_model)


    def forward(self, x: torch.Tensor):
        """
        x: (B, L)  连续值或离散索引
        """
        # (B, L) -> (B, L, 1) -> (B, L, d_model)
        h = self.in_proj(x.unsqueeze(-1).float())   # 1-d 投影

        h = self.norm(h)

        # Mamba2Simple 要求输入 (B, L, d_model) 即可
        out = self.mamba(h)                           # SSD 全局建模
        return out

class StackMambaBlock(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=64,
        d_conv=4,
        expand=2,
        headdim=128,
        ngroups=1,
        chunk_size=256,
        dropout=0.0,
        d_embed_dropout=0.0,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.d_model = d_model

        # 距离矩阵嵌入
        self.extra_embed = ExtraEmbedding(d_model=d_model, dropout=d_embed_dropout)

        # 原归一化
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        # SSD 核心
        self.ssd = Mamba2Block(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            headdim=headdim,
            ngroups=ngroups,
            chunk_size=chunk_size,
            use_mem_eff_path=True,
            device=device,
            dtype=dtype,
        )

        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, d_model),
        )

    def forward(self, local_repr, global_repr, x_extra=None,
                start_offset=0, end_offset=0):
        """
        local_repr: (B, L, D)
        global_repr: (B, G, D)
        x_extra: 可选，(B,E) 
        """
        local_norm  = self.norm1(local_repr)
        global_norm = self.norm2(global_repr)

        # 1. 构造输入序列
        tokens = []
        if x_extra is not None:
            extra_token = self.extra_embed(x_extra)        # (B,E,D)
            tokens.append(extra_token)
        tokens.append(global_norm)
        tokens.append(local_norm)
        x = torch.cat(tokens, dim=1)               # [B, (E)+G+L, D]

        # 2. SSD 扫描
        x = self.ssd(x)                            # [B, (E)+G+L, D]

        # 3. 只取 local 部分
        local_len = local_norm.shape[1]
        x = x[:, -local_len:, :]                   # [B, L, D]

        # 4. pad 回原始长度
        if start_offset or end_offset:
            x = F.pad(x, (0, 0, start_offset, end_offset))

        # 5. 残差 + FFN
        x = x + local_norm
        x = self.norm3(x)
        x = self.ffn(x) + x
        return x

class ChunkModule(nn.Module):
    """Single chunk processing module with BiMamba"""

    def __init__(self, d_model, dropout_rate=0.2):
        super().__init__()
        self.d_model = d_model

        # BiMamba block
        self.bimamba_block = BiMambaBlock(d_model)

        # Convolutional blocks
        self.conv_block1 = ConvBlock(d_model)
        self.conv_block2 = ConvBlock(d_model)
        self.conv_block3 = ConvBlock(d_model)
        self.conv_block4 = ConvBlock(d_model)

        # Cross attention
        # self.cross_attention = CrossAttentionLayer(d_model, n_heads)
        self.cross_attention = StackMambaBlock(
            d_model=d_model,
            d_state=64,
            d_conv=4,
            expand=2,
            headdim=128,
            ngroups=1,
            chunk_size=256,
        )

        # Additional layers
        self.dense = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout_rate)
        self.gelu = nn.GELU()

    def forward(self, x, x_extra=None):
        # BiMamba processing
        xa0 = self.bimamba_block(x)

        # First conv block
        xa = self.conv_block1(xa0)
        xa_skip = self.conv_block2(xa)

        # Dense layer
        xa = self.gelu(self.dense(xa))
        xa = self.conv_block3(xa)

        # Cross attention
        xa = self.cross_attention(xa, xa0, x_extra)
        xa = self.dropout(xa)

        # Final conv block
        xa = self.conv_block4(xa)

        # Concatenate with skip connection
        xa = torch.cat([xa_skip, xa], dim=-1)

        return xa

class LongRangeModule(nn.Module):
    """
        total_sites   : 序列最大长度（决定 Embedding 词表大小）
        d_model   : 输入特征维
        chunk_size: 距离阈值
        cos_cutoff: 余弦相似度绝对值阈值
        d_emb     : embedding 维，默认 d_model//2
    """
    def __init__(self, total_sites, d_model, chunk_size=128, cos_cutoff=0.8, d_emb=None):
        super().__init__()
        self.total_sites = total_sites
        self.d_model = d_model
        self.chunk_size = chunk_size
        self.cos_cutoff = cos_cutoff
        self.d_emb = d_emb or (d_model // 2)

        # 两个稀疏梯度 的 Embedding
        self.emb_i = nn.Embedding(self.total_sites, self.d_emb, sparse=True)
        self.emb_j = nn.Embedding(self.total_sites, self.d_emb, sparse=True)


    def forward(self, x, mask):
        """
        x    : (B, L, d_model)
        mask : (L,)  0/1 或 True/False
        return: 同 shape 的 x_out
        """
        # 1. 有效位点
        idx = torch.where(mask == 1)[0]          # (N_valid,)
        N_valid = idx.size(0)
        if N_valid == 0:
            return x

        # 2. 距离矩阵 & 是否有 far j
        dist = torch.abs(idx[:, None] - idx[None, :])          # (N_valid, N_valid)
        far_mask = dist > self.chunk_size
        if far_mask.sum() == 0:           # 一个 far j 都没有直接返回
            return x 

        emb_i_w = self.emb_i(idx)
        emb_j_w = self.emb_j(idx)
        cos_sim = torch.abs(F.cosine_similarity(emb_i_w.unsqueeze(1),
                                                emb_j_w.unsqueeze(0), dim=-1))

        # 4. 过滤
        valid_j_mask = far_mask & (cos_sim > self.cos_cutoff)

        # 5. 加权更新
        x_out = x.clone()
        for row, i_global in enumerate(idx):
            j_local_mask = valid_j_mask[row]
            num_j = j_local_mask.sum()
            if num_j == 0:
                continue
            j_local = torch.where(j_local_mask)[0]
            j_global = idx[j_local]
            weights = cos_sim[row, j_local] / num_j
            xj_weighted = (x[:, j_global] * weights.view(1, -1, 1)).sum(dim=1)
            x_out[:, i_global] = (x[:, i_global] + xj_weighted) / 2
        return x_out

class GlobalOut(nn.Module):
    def __init__(self,
                 d_model: int,
                 n_alleles: int,
                 total_sites: int,
                 chunk_size: int,
                 cos_cutoff: float = 0.8):
        super().__init__()
        self.proj_in = nn.Linear(2 * d_model, d_model//2)
        self.long_range = LongRangeModule(
                total_sites=total_sites,
                d_model=d_model//2,
                chunk_size=chunk_size,
                cos_cutoff=cos_cutoff
            )
        self.proj_out = nn.Linear(d_model//2, n_alleles - 1)


    def forward(self, x, mask):
        x = self.proj_in(x)                       # (B, L, d_model)
        x = self.long_range(x, mask)              # + 稀疏长程信号
        x = self.proj_out(x)                      # (B, L, n_alleles-1)
        x = torch.where(mask.unsqueeze(0).unsqueeze(-1).bool(),
                        x, torch.tensor(-float('inf'), device=x.device))
        return F.softmax(x, dim=-1)

class EvoFill(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_alleles: int,
        total_sites: int,
        chunk_size: int = 8192,
        chunk_overlap: int = 64,
        dropout_rate: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_alleles = n_alleles
        self.total_sites = total_sites
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

        # 1. chunk 边界
        stride = chunk_size - chunk_overlap
        starts = [i * stride for i in range((total_sites - 1) // stride + 1)]
        ends = [min(s + chunk_size, total_sites) for s in starts]
        self.register_buffer("starts", torch.tensor(starts, dtype=torch.long))
        self.register_buffer("ends", torch.tensor(ends, dtype=torch.long))
        self.n_chunks = len(starts)

        # 2. 每 chunk 一份嵌入 & 处理模块（常驻 GPU，但训练时只激活一个）
        self.chunk_embeds = nn.ModuleList(
            GenoEmbedding(n_alleles, e - s, d_model) for s, e in zip(starts, ends)
        )
        self.chunk_modules = nn.ModuleList(
            ChunkModule(d_model, dropout_rate) for s, e in zip(starts, ends)
        )

        # 3. 全局输出层（始终 GPU）
        self.global_out = GlobalOut(d_model, n_alleles, total_sites, chunk_size)

        # 4. chunk 掩码表  (n_chunks, L)
        masks = torch.stack(
            [torch.arange(total_sites).ge(s) & torch.arange(total_sites).lt(e)
             for s, e in zip(starts, ends)]
        ).float()
        self.register_buffer("chunk_masks", masks)

    def forward(self, x: torch.Tensor, chunk_id: int,
                x_extra: Optional[torch.Tensor] = None):
        """
        x:       (B, len, n_alleles)  对应chunk部分的序列 one-hot
        chunk_id: 0..n_chunks-1
        x_extra:  (B, extra_dim) or None
        return:   (B, len, n_alleles-1)  对应chunk部分的softmax 概率
        """
        B, _ , _ = x.shape
        device = x.device
        s, e = self.starts[chunk_id].item(), self.ends[chunk_id].item()
        mask = self.chunk_masks[chunk_id]          # (L,)  当前 chunk 覆盖区

        # 1. 确保输入 x 的形状与对应 chunk 吻合
        assert x.shape[1] == e-s

        # 2. 当前 chunk 嵌入 & 处理
        z = self.chunk_embeds[chunk_id](x)   # (B, len, d_model)
        z = self.chunk_modules[chunk_id](z, x_extra)  # (B, len, 2*d_model)

        # 3. 拼回全长度，其余 nan
        z_full = torch.full((B, self.total_sites, 2 * self.d_model), float('nan'), device=device)
        z_full[:, s:e] = z                        # (B, L, 2*d_model)

        # 4. 全局卷积只激活带状区
        out = self.global_out(z_full, mask)       # (B, L, n_alleles-1)
        
        # 5. 只返回对应chunk的logits
        return out[:, torch.where(mask)[0]]    # (B, len, n_alleles-1)
    
    @torch.no_grad()
    def infer(self, x: torch.Tensor,
              x_extra: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        推理接口：遍历全部 chunk，重叠区特征取平均，再统一过全局卷积
        x:       (B, L, n_alleles)
        x_extra: (B, extra_dim)
        return:  (B, L, n_alleles-1)  softmax 概率
        """
        B, L, _ = x.shape
        device = x.device
        d_out = 2 * self.d_model          # z 通道数

        # 累加器
        z_sum = torch.zeros(B, d_out, L, device=device)   # (B, 2*d_model, L)
        z_cnt = torch.zeros(1, 1, L, device=device)        # (1, 1, L)

        # 1. 遍历所有 chunk，拼回全长度并累加
        for cid in range(self.n_chunks):
            mask = self.chunk_masks[cid]          # (L,)  0/1
            idx  = torch.where(mask)[0]           # 当前 chunk 位点
            s, e = self.starts[cid].item(), self.ends[cid].item()

            # 与 forward 完全相同：chunk -> z
            x_slice = x[:, s:e]
            z = self.chunk_embeds[cid](x_slice)
            z = self.chunk_modules[cid](z, x_extra)        # (B, len, 2*d_model)
            z = z.transpose(1, 2)                          # (B, 2*d_model, len)

            # 写回全长度 & 累加
            z_sum[..., idx] += z[..., idx - s]             # 局部→全局对齐
            z_cnt[..., idx] += 1

        # 2. 重叠区平均
        z_full = z_sum / z_cnt.clamp_min(1.0)              # (B, 2*d_model, L)

        # 3. 统一过全局卷积一次
        #    global_out 需要 mask：全 1 即可（所有位点都有效）
        full_mask = torch.ones(L, dtype=torch.float, device=device)
        return self.global_out(z_full, full_mask)          # (B, L, n_alleles-1)


In [None]:
# ---------- 假数据 ----------
B, L, A = 4, 12345, 3
d_model = 64
chunk_size, overlap = 4096, 64
device = 'cuda' if torch.cuda.is_available() else 'cpu'

x_train = torch.zeros(B, L, A, device=device)
allele = torch.randint(0, A, (B, L), device=device)
x_train.scatter_(2, allele.unsqueeze(-1), 1)
x_extra = torch.randn(B, 10, device=device)
y_train = torch.randn(B, L, A-1, device=device)

print(f"x_train: {x_train.shape}")
print(f"x_extra: {x_extra.shape}")
print(f"y_train: {y_train.shape}")
print("")
# ---------- 模型 &损失 ----------
model = EvoFill(d_model, A, L, chunk_size, overlap).to(device)
print(f"model chunks: {model.n_chunks}")
criterion = nn.MSELoss(reduction='mean')

# ---------- 训练循环 ----------
epochs_per_chunk = 3
for cid in range(model.n_chunks):
    mask = model.chunk_masks[cid]
    idx = torch.where(mask)[0]
    x_band = x_train[:,idx]
    y_band = y_train[:,idx]
    print(f"x_band: {x_band.shape}")
    print(f"x_extra: {x_extra.shape}")
    print(f"y_band: {y_band.shape}")
    opt = torch.optim.AdamW(
        list(model.chunk_embeds[cid].parameters()) +
        list(model.chunk_modules[cid].parameters()) +
        list(model.global_out.parameters()), lr=3e-4)
    for epoch in range(epochs_per_chunk):
        opt.zero_grad()
        pred = model(x_band, cid, x_extra)

        loss = criterion(pred, y_band)        # 默认 reduction='mean'
        loss.backward()
        opt.step()
        print("")

        print(f"pred: {pred.shape}")
        print(f'chunk {cid}/{model.n_chunks-1} | epoch {epoch+1}/{epochs_per_chunk} | loss {loss.item():.4f}')

x_train: torch.Size([4, 12345, 3])
x_extra: torch.Size([4, 10])
y_train: torch.Size([4, 12345, 2])

model chunks: 4


x_train: torch.Size([4, 12345, 3])
x_extra: torch.Size([4, 10])
pred: torch.Size([4, 12345, 2])
pred_band: torch.Size([4, 4096, 2])
y_band: torch.Size([4, 4096, 2])
chunk 0/3 | epoch 1/3 | loss 1.3536

x_train: torch.Size([4, 12345, 3])
x_extra: torch.Size([4, 10])
pred: torch.Size([4, 12345, 2])
pred_band: torch.Size([4, 4096, 2])
y_band: torch.Size([4, 4096, 2])
chunk 0/3 | epoch 2/3 | loss 1.2774

x_train: torch.Size([4, 12345, 3])
x_extra: torch.Size([4, 10])
pred: torch.Size([4, 12345, 2])
pred_band: torch.Size([4, 4096, 2])
y_band: torch.Size([4, 4096, 2])
chunk 0/3 | epoch 3/3 | loss 1.2199

x_train: torch.Size([4, 12345, 3])
x_extra: torch.Size([4, 10])
pred: torch.Size([4, 12345, 2])
pred_band: torch.Size([4, 4096, 2])
y_band: torch.Size([4, 4096, 2])
chunk 1/3 | epoch 1/3 | loss 1.3241

x_train: torch.Size([4, 12345, 3])
x_extra: torch.Size([4, 10])
pred: torc

## Loss

In [None]:
class ImputationLoss(nn.Module):
    """Custom loss function for genomic imputation"""

    def __init__(self, use_r2=True):
        super().__init__()
        self.use_r2_loss = use_r2
        self.ce_loss = nn.CrossEntropyLoss(reduction='sum')
        self.kl_loss = nn.KLDivLoss(reduction='sum')

    def calculate_minimac_r2(self, pred_alt_allele_probs, gt_alt_af):
        """Calculate Minimac-style RÂ² metric"""
        mask = torch.logical_or(torch.eq(gt_alt_af, 0.0), torch.eq(gt_alt_af, 1.0))
        gt_alt_af = torch.where(mask, 0.5, gt_alt_af)
        denom = gt_alt_af * (1.0 - gt_alt_af)
        denom = torch.where(denom < 0.01, 0.01, denom)
        r2 = torch.mean(torch.square(pred_alt_allele_probs - gt_alt_af), dim=0) / denom
        r2 = torch.where(mask, torch.zeros_like(r2), r2)
        return r2

    def forward(self, y_pred, y_true):
        y_true = y_true.float()

        # Convert to proper format for losses
        y_true_ce = torch.argmax(y_true, dim=-1)  # For CrossEntropy
        y_pred_log = torch.log(y_pred + 1e-8)  # For KL divergence

        # Basic losses
        ce_loss = self.ce_loss(y_pred.view(-1, y_pred.size(-1)), y_true_ce.view(-1))
        kl_loss = self.kl_loss(y_pred_log.view(-1, y_pred.size(-1)),
                               y_true.view(-1, y_true.size(-1)))

        total_loss = ce_loss + kl_loss

        if self.use_r2_loss:
            batch_size = y_true.size(0)
            group_size = 4
            num_full_groups = batch_size // group_size

            if num_full_groups > 0:
                y_true_grouped = y_true[:num_full_groups * group_size].view(
                    num_full_groups, group_size, *y_true.shape[1:])
                y_pred_grouped = y_pred[:num_full_groups * group_size].view(
                    num_full_groups, group_size, *y_pred.shape[1:])

                r2_loss = 0.0
                for i in range(num_full_groups):
                    gt_alt_af = torch.count_nonzero(
                        torch.argmax(y_true_grouped[i], dim=-1), dim=0
                    ).float() / group_size

                    pred_alt_allele_probs = torch.sum(y_pred_grouped[i][:, :, 1:], dim=-1)
                    r2_loss += -torch.sum(self.calculate_minimac_r2(
                        pred_alt_allele_probs, gt_alt_af)) * group_size

                total_loss += r2_loss

        return total_loss, None

## Train

工具函数

In [None]:
def create_directories(save_dir, models_dir="models", outputs="out") -> None:
    """Create necessary directories"""
    for dd in [save_dir, f"{save_dir}/{models_dir}", f"{save_dir}/{outputs}"]:
        if not os.path.exists(dd):
            os.makedirs(dd)

def clear_dir(path) -> None:
    """Clear directory if it exists"""
    if os.path.exists(path):
        shutil.rmtree(path)

MAF_BINS = [(0.00, 0.05), (0.05, 0.10), (0.10, 0.20),
            (0.20, 0.30), (0.30, 0.40), (0.40, 0.50)]

def precompute_maf(gts_np, mask_int=-1):
    """
    gts_np: (N, L)  int64
    return:
        maf: (L,) float32
        bin_cnt: list[int] 长度 6，对应 6 个 bin 的位点数量
    """
    L = gts_np.shape[1]
    maf = np.zeros(L, dtype=np.float32)
    bin_cnt = [0] * 6

    for l in range(L):
        alleles = gts_np[:, l]
        alleles = alleles[alleles != mask_int]   # 去掉缺失
        if alleles.size == 0:
            maf[l] = 0.0
            continue

        uniq, cnt = np.unique(alleles, return_counts=True)
        total = cnt.sum()
        freq = cnt / total
        freq[::-1].sort()
        maf_val = freq[1] if len(freq) > 1 else 0.0
        maf[l] = maf_val

        # 统计 bin
        for i, (lo, hi) in enumerate(MAF_BINS):
            if lo <= maf_val < hi:
                bin_cnt[i] += 1
                break

    return maf, bin_cnt

def imputation_maf_accuracy_epoch(all_logits, all_gts, global_maf, mask=None):
    """
    all_logits: (N, L, C)
    all_gts:    (N, L, C) one-hot
    global_maf: (L,)
    mask:       (N, L) 或 None
    return:     list[float] 长度 6
    """
    # 1. 预测 vs 真实
    all_gts = all_gts.argmax(dim=-1)      # (N, L)
    preds   = all_logits.argmax(dim=-1)   # (N, L)

    # 2. 如果没有外部 mask，就默认全 1
    if mask is None:
        mask = torch.ones_like(all_gts, dtype=torch.bool)   # (N, L)
    correct = (preds == all_gts) & mask                   # (N, L)

    # 3. MAF 条件 -> (1, L) 再广播到 (N, L)
    maf = global_maf.unsqueeze(0)                         # (1, L)

    # 4. 分 bin 计算
    accs = []
    for lo, hi in MAF_BINS:
        maf_bin = mask & (maf >= lo) & (maf < hi)                # (1, L)
        n_cor = (correct & maf_bin).sum()
        n_tot = maf_bin.sum()
        accs.append(100*(n_cor / n_tot).item() if n_tot > 0 else 0.0)
    return accs

In [None]:
work_dir      = '/home/qmtang/mnt_qmtang/EvoFill/data/251023_ver3_chr22'
val_n_samples = 128
max_mr             = 0.7
min_mr             = 0.3
batch_size_per_gpu = 8

# Setup device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create directories
create_directories(work_dir)

gt_enc = GenotypeEncoder(
    save_dir='/mnt/qmtang/EvoFill/data/251024_ver3_chr22_mini',
    vcf_path='/home/qmtang/GitHub/STICI-HPC/data/training_sets/ALL.chr22.training.samples.100k.any.type.0.01.maf.variants.vcf.gz',
    ref_extra='/mnt/qmtang/EvoFill/data/251020_ver01_chr22/pop_wasserstein.tsv',
    phased= True,
    gts012= False)

print(gt_enc.n_samples, 'samples,', gt_enc.n_variants, 'variants,', gt_enc.seq_depth, 'alleles.')

x_train_indices, x_valid_indices = train_test_split(
    range(gt_enc.n_samples),
    test_size=val_n_samples,
    random_state=3047,
    shuffle=True
)
print(f"{len(x_train_indices)} samples in train")
print(f"{len(x_valid_indices)} samples in val")

In [None]:
model_name  = 'hg19_chr22'
total_sites = gt_enc.n_variants
alleles     = gt_enc.seq_depth
chunk_size  = 65536
overlap     = 4096
d_model     = 128

model = EvoFill(d_model, alleles, total_sites, chunk_size, overlap).to(device)
print(f"model[{model_name}] chunks={model.n_chunks}")

criterion = ImputationLoss(use_r2=True)

model[hg19_chr22] chunks=7


In [None]:
max_epochs_per_chunk = 100
lr                 = 0.001
weight_decay       = 1e-5
earlystop_patience = 9
verbose            = True

for cid in range(model.n_chunks):
    chunk_mask = model.chunk_masks[cid].cpu()

    train_dataset = GenomicDataset(
        train_gt, train_extra, dr.seq_depth,
        mask=True, masking_rates=(min_mr, max_mr)
    )
    val_dataset = GenomicDataset(
        val_gt, val_extra, dr.seq_depth,
        mask=True, masking_rates=(min_mr, max_mr)
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size_per_gpu,
                        shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size_per_gpu,
                        shuffle=False, num_workers=4, pin_memory=True)

    chunk_maf, chunk_bin_cnt = precompute_maf(X_gt[:, torch.where(chunk_mask)[0]],  mask_int=dr.seq_depth)
    chunk_maf = torch.from_numpy(chunk_maf).to(device)

    if verbose:
        print(f"=== Chunk {cid + 1} STAT ===")
        maf_df = pd.DataFrame({
            'MAF_bin': ['(0.00, 0.05)', '(0.05, 0.10)', '(0.10, 0.20)',
                        '(0.20, 0.30)', '(0.30, 0.40)', '(0.40, 0.50)'],
            'Counts':  [f"{c}" for c in chunk_bin_cnt],
        })
        print(maf_df.to_string(index=False))

    optimizer = AdamW(list(model.chunk_embeds[cid].parameters()) +
                      list(model.chunk_modules[cid].parameters()) +
                      list(model.global_out.parameters()), 
                      lr=lr, weight_decay=weight_decay)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, min_lr=1e-7)
    best_loss = float('inf')
    patience = earlystop_patience
    patience_counter = 0
    
    for epoch in range(max_epochs_per_chunk):
        model.train()
        train_loss = 0.0
        train_logits, train_gts, train_mask = [], [], []

        train_pbar = tqdm(train_loader, desc=f'Chunk {cid + 1}/{model.n_chunks}, Epoch {epoch + 1}/{max_epochs_per_chunk}', leave=False)
        for batch_idx, (x, x_extra, target) in enumerate(train_pbar):
            x, x_extra, target = x.to(device), x_extra.to(device), target.to(device)

            optimizer.zero_grad()
            pred = model(x, cid, x_extra)
            # pred = model(x, cid, None)

            idx       = torch.where(chunk_mask)[0]          # (n_active,)
            pred_band = pred[:, idx]                  # (B, n_active, n_alleles-1)
            y_band    = target[:, idx]               # (B, n_active, n_alleles-1)

            loss, logs = criterion(pred_band, y_band)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_pbar.set_postfix({'loss': loss.item()})

            # === 收集训练结果 ===
            miss_mask = x[:, idx][..., -1].bool()         # 只关心被 mask 的位点
            train_logits.append(pred_band.detach())
            train_gts.append(y_band.detach())
            train_mask.append(miss_mask)

        # 训练集 MAF-acc
        train_logits = torch.cat(train_logits, dim=0)
        train_gts    = torch.cat(train_gts,    dim=0)
        train_mask   = torch.cat(train_mask,   dim=0)
        train_maf_accs = imputation_maf_accuracy_epoch(train_logits, train_gts, chunk_maf, mask=train_mask)

        # ----------- 验证循环同理 ------------
        model.eval()
        val_loss = 0.0
        val_logits, val_gts = [], []
        with torch.no_grad():
            for x, x_extra, target in val_loader:
                x, x_extra, target = x.to(device), x_extra.to(device), target.to(device)
                pred = model(x, cid, x_extra)
                # pred = model(x, cid, None)

                idx       = torch.where(chunk_mask)[0]          # (n_active,)
                pred_band = pred[:, idx]                  # (B, n_active, n_alleles-1)
                y_band    = target[:, idx]               # (B, n_active, n_alleles-1)

                loss, logs = criterion(pred_band, y_band)

                val_loss += loss.item()

                val_logits.append(pred_band)
                val_gts.append(y_band)

        val_logits = torch.cat(val_logits, dim=0)
        val_gts    = torch.cat(val_gts,    dim=0)
        val_maf_accs = imputation_maf_accuracy_epoch(
            val_logits, val_gts, chunk_maf,  mask=None,) # 计算所有位点

        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss   = val_loss   / len(val_loader)

        scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]['lr']

        print(f'Chunk {cid + 1}/{model.n_chunks}, '
            f'Epoch {epoch + 1}/{max_epochs_per_chunk}, '
            f'Train Loss: {avg_train_loss:.1f}, '
            f'Val Loss: {avg_val_loss:.1f}, '
            f'LR: {current_lr:.2e}')

        # 用 DataFrame 打印 MAF-bin 结果
        if verbose:
            maf_df = pd.DataFrame({
                'MAF_bin': ['(0.00, 0.05)', '(0.05, 0.10)', '(0.10, 0.20)',
                            '(0.20, 0.30)', '(0.30, 0.40)', '(0.40, 0.50)'],
                'Counts':  [f"{c}" for c in chunk_bin_cnt],
                'Train':   [f"{acc:.2f}" for acc in train_maf_accs],
                'Val':     [f"{acc:.2f}" for acc in val_maf_accs]
            })
            print(maf_df.to_string(index=False))

        # Early stopping
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            patience_counter = 0
            # 只存当前 chunk 专家 + 全局层
            ckpt = {
                'chunk_id': cid,
                'chunk_embed_state': model.chunk_embeds[cid].state_dict(),
                'chunk_module_state': model.chunk_modules[cid].state_dict(),
                'global_out_state': model.global_out.state_dict(),
                'best_val_loss': best_loss,
            }
            torch.save(ckpt, f'{work_dir}/models/{model_name}_chunk{cid}.pth')
            print(f'  --> updated {model_name}_chunk{cid}.pth')
        else:
            patience_counter += 1
            if patience_counter >= earlystop_patience:
                print('Early stopping triggered')
                break

    del optimizer, scheduler
    torch.cuda.empty_cache() if torch.cuda.is_available() else None

# ---------------- 全部 chunk 训练完成 -> 保存最终完整模型 ----------------
final_ckpt = {
    'model_state': model.state_dict(),
    'n_chunks': model.n_chunks,
    'chunk_size': model.chunk_size,
    'chunk_overlap': model.chunk_overlap,
}
torch.save(final_ckpt, f'{work_dir}/models/{model_name}_final.pth')
print(f'==> Final model saved to {work_dir}/models/{model_name}_final.pth')

=== Chunk 1 STAT ===
     MAF_bin Counts
(0.00, 0.05)   4427
(0.05, 0.10)   2598
(0.10, 0.20)   3011
(0.20, 0.30)   2460
(0.30, 0.40)   3132
(0.40, 0.50)    756


                                                                                       

Chunk 1/7, Epoch 1/100, Train Loss: 67618.2556, Val Loss: 53165.0034
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 96.02 98.28
(0.05, 0.10)   2598 92.08 96.84
(0.10, 0.20)   3011 86.17 94.59
(0.20, 0.30)   2460 78.25 91.95
(0.30, 0.40)   3132 76.56 91.51
(0.40, 0.50)    756 71.95 89.33
  --> updated hg19_chr22_chunk0.pth


                                                                                       

Chunk 1/7, Epoch 2/100, Train Loss: 49651.8990, Val Loss: 48643.8154
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 96.78 98.32
(0.05, 0.10)   2598 93.99 97.11
(0.10, 0.20)   3011 90.27 95.45
(0.20, 0.30)   2460 85.51 93.09
(0.30, 0.40)   3132 84.92 92.67
(0.40, 0.50)    756 80.80 90.55
  --> updated hg19_chr22_chunk0.pth


                                                                                       

Chunk 1/7, Epoch 3/100, Train Loss: 45129.7220, Val Loss: 44591.2644
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 96.99 98.45
(0.05, 0.10)   2598 94.67 97.41
(0.10, 0.20)   3011 91.76 95.94
(0.20, 0.30)   2460 87.76 93.92
(0.30, 0.40)   3132 87.36 93.88
(0.40, 0.50)    756 83.60 91.70
  --> updated hg19_chr22_chunk0.pth


                                                                                       

Chunk 1/7, Epoch 4/100, Train Loss: 42414.9077, Val Loss: 41514.3794
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 97.14 98.48
(0.05, 0.10)   2598 95.16 97.58
(0.10, 0.20)   3011 92.63 96.26
(0.20, 0.30)   2460 89.01 94.49
(0.30, 0.40)   3132 88.70 94.38
(0.40, 0.50)    756 85.15 92.80
  --> updated hg19_chr22_chunk0.pth


                                                                                       

Chunk 1/7, Epoch 5/100, Train Loss: 40632.2416, Val Loss: 40514.8638
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 97.26 98.58
(0.05, 0.10)   2598 95.43 97.80
(0.10, 0.20)   3011 93.17 96.72
(0.20, 0.30)   2460 89.84 94.84
(0.30, 0.40)   3132 89.54 94.53
(0.40, 0.50)    756 86.22 92.88
  --> updated hg19_chr22_chunk0.pth


                                                                                       

Chunk 1/7, Epoch 6/100, Train Loss: 38844.3344, Val Loss: 41107.4380
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 97.38 98.42
(0.05, 0.10)   2598 95.70 97.55
(0.10, 0.20)   3011 93.64 96.46
(0.20, 0.30)   2460 90.52 94.58
(0.30, 0.40)   3132 90.30 94.57
(0.40, 0.50)    756 87.09 92.99


                                                                                       

Chunk 1/7, Epoch 7/100, Train Loss: 37977.2959, Val Loss: 36902.1831
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 97.41 98.69
(0.05, 0.10)   2598 95.81 97.96
(0.10, 0.20)   3011 93.93 97.04
(0.20, 0.30)   2460 90.98 95.62
(0.30, 0.40)   3132 90.72 95.48
(0.40, 0.50)    756 87.73 94.08
  --> updated hg19_chr22_chunk0.pth


                                                                                       

Chunk 1/7, Epoch 8/100, Train Loss: 37059.0523, Val Loss: 36709.6687
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 97.49 98.59
(0.05, 0.10)   2598 96.00 97.86
(0.10, 0.20)   3011 94.23 97.08
(0.20, 0.30)   2460 91.34 95.65
(0.30, 0.40)   3132 91.16 95.63
(0.40, 0.50)    756 88.31 94.30
  --> updated hg19_chr22_chunk0.pth


                                                                                       

Chunk 1/7, Epoch 9/100, Train Loss: 36451.4998, Val Loss: 36397.6926
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 97.53 98.61
(0.05, 0.10)   2598 96.12 97.85
(0.10, 0.20)   3011 94.46 97.09
(0.20, 0.30)   2460 91.68 95.66
(0.30, 0.40)   3132 91.42 95.60
(0.40, 0.50)    756 88.60 94.08
  --> updated hg19_chr22_chunk0.pth


                                                                                        

Chunk 1/7, Epoch 10/100, Train Loss: 35908.5264, Val Loss: 37215.7029
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 97.57 98.73
(0.05, 0.10)   2598 96.25 98.11
(0.10, 0.20)   3011 94.63 97.29
(0.20, 0.30)   2460 91.90 95.51
(0.30, 0.40)   3132 91.67 95.18
(0.40, 0.50)    756 88.87 93.62


                                                                                        

Chunk 1/7, Epoch 11/100, Train Loss: 35058.2274, Val Loss: 36138.2747
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 97.63 98.70
(0.05, 0.10)   2598 96.37 98.02
(0.10, 0.20)   3011 94.85 97.19
(0.20, 0.30)   2460 92.16 95.73
(0.30, 0.40)   3132 91.92 95.61
(0.40, 0.50)    756 89.22 94.16
  --> updated hg19_chr22_chunk0.pth


                                                                                        

Chunk 1/7, Epoch 12/100, Train Loss: 34097.8942, Val Loss: 35301.5845
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 97.70 98.74
(0.05, 0.10)   2598 96.53 98.18
(0.10, 0.20)   3011 95.07 97.46
(0.20, 0.30)   2460 92.45 96.02
(0.30, 0.40)   3132 92.24 95.85
(0.40, 0.50)    756 89.55 94.39
  --> updated hg19_chr22_chunk0.pth


                                                                                        

Chunk 1/7, Epoch 13/100, Train Loss: 34020.8935, Val Loss: 36738.9705
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 97.72 98.53
(0.05, 0.10)   2598 96.58 97.78
(0.10, 0.20)   3011 95.15 96.85
(0.20, 0.30)   2460 92.61 95.59
(0.30, 0.40)   3132 92.37 95.49
(0.40, 0.50)    756 89.73 94.15


                                                                                        

Chunk 1/7, Epoch 14/100, Train Loss: 33391.5159, Val Loss: 33542.2427
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 97.75 98.79
(0.05, 0.10)   2598 96.67 98.20
(0.10, 0.20)   3011 95.33 97.54
(0.20, 0.30)   2460 92.80 96.25
(0.30, 0.40)   3132 92.57 96.10
(0.40, 0.50)    756 90.00 94.92
  --> updated hg19_chr22_chunk0.pth


                                                                                        

Chunk 1/7, Epoch 15/100, Train Loss: 33555.1070, Val Loss: 34762.1387
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 97.78 98.69
(0.05, 0.10)   2598 96.69 98.13
(0.10, 0.20)   3011 95.38 97.48
(0.20, 0.30)   2460 92.82 96.03
(0.30, 0.40)   3132 92.59 95.88
(0.40, 0.50)    756 89.99 94.44


                                                                                        

Chunk 1/7, Epoch 16/100, Train Loss: 32965.8106, Val Loss: 32834.5178
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 97.79 98.84
(0.05, 0.10)   2598 96.79 98.38
(0.10, 0.20)   3011 95.49 97.71
(0.20, 0.30)   2460 92.99 96.45
(0.30, 0.40)   3132 92.79 96.38
(0.40, 0.50)    756 90.20 95.10
  --> updated hg19_chr22_chunk0.pth


                                                                                        

Chunk 1/7, Epoch 17/100, Train Loss: 32673.8635, Val Loss: 33551.9792
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 97.83 98.79
(0.05, 0.10)   2598 96.82 98.36
(0.10, 0.20)   3011 95.59 97.72
(0.20, 0.30)   2460 93.09 96.34
(0.30, 0.40)   3132 92.87 96.17
(0.40, 0.50)    756 90.23 94.77


                                                                                        

Chunk 1/7, Epoch 18/100, Train Loss: 32535.2306, Val Loss: 33134.3237
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 97.86 98.70
(0.05, 0.10)   2598 96.89 98.13
(0.10, 0.20)   3011 95.61 97.50
(0.20, 0.30)   2460 93.16 96.35
(0.30, 0.40)   3132 92.93 96.20
(0.40, 0.50)    756 90.39 94.92


                                                                                        

Chunk 1/7, Epoch 19/100, Train Loss: 32537.5521, Val Loss: 31918.0554
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 97.86 98.91
(0.05, 0.10)   2598 96.90 98.43
(0.10, 0.20)   3011 95.65 97.86
(0.20, 0.30)   2460 93.24 96.54
(0.30, 0.40)   3132 93.00 96.44
(0.40, 0.50)    756 90.47 95.14
  --> updated hg19_chr22_chunk0.pth


                                                                                        

Chunk 1/7, Epoch 20/100, Train Loss: 32051.4818, Val Loss: 32012.1780
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 97.88 98.89
(0.05, 0.10)   2598 96.95 98.42
(0.10, 0.20)   3011 95.73 97.79
(0.20, 0.30)   2460 93.31 96.52
(0.30, 0.40)   3132 93.05 96.46
(0.40, 0.50)    756 90.56 95.18


                                                                                        

Chunk 1/7, Epoch 21/100, Train Loss: 31346.3132, Val Loss: 32103.2646
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 97.93 98.84
(0.05, 0.10)   2598 97.04 98.35
(0.10, 0.20)   3011 95.85 97.74
(0.20, 0.30)   2460 93.45 96.54
(0.30, 0.40)   3132 93.23 96.48
(0.40, 0.50)    756 90.74 95.08


                                                                                        

Chunk 1/7, Epoch 22/100, Train Loss: 31322.2982, Val Loss: 32519.8120
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 97.96 98.89
(0.05, 0.10)   2598 97.10 98.41
(0.10, 0.20)   3011 95.90 97.80
(0.20, 0.30)   2460 93.54 96.43
(0.30, 0.40)   3132 93.36 96.32
(0.40, 0.50)    756 90.88 94.86


                                                                                        

Chunk 1/7, Epoch 23/100, Train Loss: 31340.3900, Val Loss: 31747.6533
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 97.99 98.93
(0.05, 0.10)   2598 97.15 98.50
(0.10, 0.20)   3011 95.97 97.88
(0.20, 0.30)   2460 93.58 96.69
(0.30, 0.40)   3132 93.42 96.44
(0.40, 0.50)    756 90.94 95.10
  --> updated hg19_chr22_chunk0.pth


                                                                                        

Chunk 1/7, Epoch 24/100, Train Loss: 31075.3055, Val Loss: 30945.6965
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.00 98.94
(0.05, 0.10)   2598 97.16 98.53
(0.10, 0.20)   3011 96.00 97.89
(0.20, 0.30)   2460 93.65 96.72
(0.30, 0.40)   3132 93.47 96.69
(0.40, 0.50)    756 90.94 95.40
  --> updated hg19_chr22_chunk0.pth


                                                                                        

Chunk 1/7, Epoch 25/100, Train Loss: 31065.5087, Val Loss: 31794.3696
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 97.99 98.90
(0.05, 0.10)   2598 97.16 98.40
(0.10, 0.20)   3011 96.04 97.79
(0.20, 0.30)   2460 93.67 96.45
(0.30, 0.40)   3132 93.48 96.39
(0.40, 0.50)    756 90.99 95.12


                                                                                        

Chunk 1/7, Epoch 26/100, Train Loss: 30666.7435, Val Loss: 31248.1863
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.04 98.88
(0.05, 0.10)   2598 97.23 98.48
(0.10, 0.20)   3011 96.12 97.87
(0.20, 0.30)   2460 93.80 96.67
(0.30, 0.40)   3132 93.59 96.54
(0.40, 0.50)    756 91.15 95.41


                                                                                        

Chunk 1/7, Epoch 27/100, Train Loss: 30646.9043, Val Loss: 31202.0081
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.05 98.93
(0.05, 0.10)   2598 97.27 98.55
(0.10, 0.20)   3011 96.16 97.95
(0.20, 0.30)   2460 93.84 96.74
(0.30, 0.40)   3132 93.65 96.66
(0.40, 0.50)    756 91.15 95.15


                                                                                        

Chunk 1/7, Epoch 28/100, Train Loss: 30686.0120, Val Loss: 30916.3350
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.03 98.91
(0.05, 0.10)   2598 97.27 98.47
(0.10, 0.20)   3011 96.14 97.96
(0.20, 0.30)   2460 93.85 96.70
(0.30, 0.40)   3132 93.63 96.60
(0.40, 0.50)    756 91.17 95.40
  --> updated hg19_chr22_chunk0.pth


                                                                                        

Chunk 1/7, Epoch 29/100, Train Loss: 30170.8684, Val Loss: 31611.9001
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.09 98.90
(0.05, 0.10)   2598 97.33 98.42
(0.10, 0.20)   3011 96.21 97.84
(0.20, 0.30)   2460 93.92 96.57
(0.30, 0.40)   3132 93.73 96.43
(0.40, 0.50)    756 91.32 95.18


                                                                                        

Chunk 1/7, Epoch 30/100, Train Loss: 30295.2628, Val Loss: 29615.0149
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.12 99.04
(0.05, 0.10)   2598 97.38 98.67
(0.10, 0.20)   3011 96.25 98.15
(0.20, 0.30)   2460 93.97 97.01
(0.30, 0.40)   3132 93.78 96.94
(0.40, 0.50)    756 91.31 95.68
  --> updated hg19_chr22_chunk0.pth


                                                                                        

Chunk 1/7, Epoch 31/100, Train Loss: 29995.3658, Val Loss: 29912.1619
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.15 98.98
(0.05, 0.10)   2598 97.39 98.63
(0.10, 0.20)   3011 96.29 98.09
(0.20, 0.30)   2460 94.01 96.86
(0.30, 0.40)   3132 93.84 96.83
(0.40, 0.50)    756 91.43 95.58


                                                                                        

Chunk 1/7, Epoch 32/100, Train Loss: 29834.5134, Val Loss: 30606.4775
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.15 99.04
(0.05, 0.10)   2598 97.44 98.65
(0.10, 0.20)   3011 96.31 98.11
(0.20, 0.30)   2460 94.07 96.88
(0.30, 0.40)   3132 93.90 96.67
(0.40, 0.50)    756 91.48 95.38


                                                                                        

Chunk 1/7, Epoch 33/100, Train Loss: 29532.9486, Val Loss: 29903.0730
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.16 99.01
(0.05, 0.10)   2598 97.48 98.63
(0.10, 0.20)   3011 96.35 98.11
(0.20, 0.30)   2460 94.12 96.92
(0.30, 0.40)   3132 93.94 96.87
(0.40, 0.50)    756 91.51 95.56


                                                                                        

Chunk 1/7, Epoch 34/100, Train Loss: 29513.2648, Val Loss: 30918.5757
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.20 98.93
(0.05, 0.10)   2598 97.50 98.54
(0.10, 0.20)   3011 96.40 97.95
(0.20, 0.30)   2460 94.17 96.68
(0.30, 0.40)   3132 93.98 96.63
(0.40, 0.50)    756 91.59 95.28


                                                                                        

Chunk 1/7, Epoch 35/100, Train Loss: 28499.0482, Val Loss: 28951.9624
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.31 99.05
(0.05, 0.10)   2598 97.70 98.77
(0.10, 0.20)   3011 96.64 98.22
(0.20, 0.30)   2460 94.58 97.11
(0.30, 0.40)   3132 94.41 97.00
(0.40, 0.50)    756 92.05 95.80
  --> updated hg19_chr22_chunk0.pth


                                                                                        

Chunk 1/7, Epoch 36/100, Train Loss: 28546.1047, Val Loss: 28876.0408
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.33 99.09
(0.05, 0.10)   2598 97.71 98.78
(0.10, 0.20)   3011 96.69 98.24
(0.20, 0.30)   2460 94.58 97.14
(0.30, 0.40)   3132 94.43 97.03
(0.40, 0.50)    756 92.10 95.68
  --> updated hg19_chr22_chunk0.pth


                                                                                        

Chunk 1/7, Epoch 37/100, Train Loss: 28163.5161, Val Loss: 28918.3027
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.37 99.06
(0.05, 0.10)   2598 97.78 98.76
(0.10, 0.20)   3011 96.76 98.26
(0.20, 0.30)   2460 94.69 97.05
(0.30, 0.40)   3132 94.56 97.01
(0.40, 0.50)    756 92.28 95.70


                                                                                        

Chunk 1/7, Epoch 38/100, Train Loss: 28294.2665, Val Loss: 28281.3203
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.38 99.12
(0.05, 0.10)   2598 97.80 98.81
(0.10, 0.20)   3011 96.74 98.29
(0.20, 0.30)   2460 94.66 97.26
(0.30, 0.40)   3132 94.54 97.16
(0.40, 0.50)    756 92.26 95.99
  --> updated hg19_chr22_chunk0.pth


                                                                                        

Chunk 1/7, Epoch 39/100, Train Loss: 28223.2701, Val Loss: 27591.3950
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.38 99.19
(0.05, 0.10)   2598 97.81 98.86
(0.10, 0.20)   3011 96.75 98.39
(0.20, 0.30)   2460 94.71 97.31
(0.30, 0.40)   3132 94.58 97.33
(0.40, 0.50)    756 92.24 96.13
  --> updated hg19_chr22_chunk0.pth


                                                                                        

Chunk 1/7, Epoch 40/100, Train Loss: 27960.1902, Val Loss: 28568.0654
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.40 99.09
(0.05, 0.10)   2598 97.82 98.76
(0.10, 0.20)   3011 96.77 98.23
(0.20, 0.30)   2460 94.72 97.17
(0.30, 0.40)   3132 94.60 97.12
(0.40, 0.50)    756 92.32 95.94


                                                                                        

Chunk 1/7, Epoch 41/100, Train Loss: 28126.1007, Val Loss: 27822.9087
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.39 99.14
(0.05, 0.10)   2598 97.82 98.86
(0.10, 0.20)   3011 96.78 98.36
(0.20, 0.30)   2460 94.72 97.31
(0.30, 0.40)   3132 94.58 97.32
(0.40, 0.50)    756 92.31 96.14


                                                                                        

Chunk 1/7, Epoch 42/100, Train Loss: 27857.1578, Val Loss: 28358.2510
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.41 99.12
(0.05, 0.10)   2598 97.85 98.82
(0.10, 0.20)   3011 96.80 98.30
(0.20, 0.30)   2460 94.81 97.21
(0.30, 0.40)   3132 94.64 97.17
(0.40, 0.50)    756 92.37 95.90


                                                                                        

Chunk 1/7, Epoch 43/100, Train Loss: 27831.2923, Val Loss: 27793.9089
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.43 99.14
(0.05, 0.10)   2598 97.86 98.84
(0.10, 0.20)   3011 96.82 98.38
(0.20, 0.30)   2460 94.79 97.35
(0.30, 0.40)   3132 94.65 97.30
(0.40, 0.50)    756 92.37 96.17


                                                                                        

Chunk 1/7, Epoch 44/100, Train Loss: 27537.3535, Val Loss: 28134.6240
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.49 99.11
(0.05, 0.10)   2598 97.96 98.85
(0.10, 0.20)   3011 96.93 98.32
(0.20, 0.30)   2460 94.96 97.27
(0.30, 0.40)   3132 94.84 97.20
(0.40, 0.50)    756 92.61 95.85


                                                                                        

Chunk 1/7, Epoch 45/100, Train Loss: 27235.0241, Val Loss: 27088.9617
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.51 99.19
(0.05, 0.10)   2598 98.01 98.95
(0.10, 0.20)   3011 97.01 98.45
(0.20, 0.30)   2460 95.05 97.53
(0.30, 0.40)   3132 94.93 97.43
(0.40, 0.50)    756 92.68 96.29
  --> updated hg19_chr22_chunk0.pth


                                                                                        

Chunk 1/7, Epoch 46/100, Train Loss: 27246.8969, Val Loss: 28086.5127
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.52 99.13
(0.05, 0.10)   2598 98.00 98.83
(0.10, 0.20)   3011 97.01 98.36
(0.20, 0.30)   2460 95.05 97.24
(0.30, 0.40)   3132 94.93 97.23
(0.40, 0.50)    756 92.69 96.00


                                                                                        

Chunk 1/7, Epoch 47/100, Train Loss: 27382.0169, Val Loss: 27174.5803
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.53 99.21
(0.05, 0.10)   2598 98.02 98.94
(0.10, 0.20)   3011 97.01 98.46
(0.20, 0.30)   2460 95.05 97.50
(0.30, 0.40)   3132 94.94 97.42
(0.40, 0.50)    756 92.71 96.30


                                                                                        

Chunk 1/7, Epoch 48/100, Train Loss: 27290.9461, Val Loss: 27644.9390
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.52 99.15
(0.05, 0.10)   2598 98.02 98.92
(0.10, 0.20)   3011 97.02 98.39
(0.20, 0.30)   2460 95.08 97.37
(0.30, 0.40)   3132 94.93 97.30
(0.40, 0.50)    756 92.74 96.15


                                                                                        

Chunk 1/7, Epoch 49/100, Train Loss: 27252.2110, Val Loss: 27291.0588
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.51 99.20
(0.05, 0.10)   2598 98.00 98.93
(0.10, 0.20)   3011 97.02 98.43
(0.20, 0.30)   2460 95.03 97.41
(0.30, 0.40)   3132 94.93 97.36
(0.40, 0.50)    756 92.75 96.14


                                                                                        

Chunk 1/7, Epoch 50/100, Train Loss: 27089.4823, Val Loss: 26658.6794
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.55 99.26
(0.05, 0.10)   2598 98.07 99.00
(0.10, 0.20)   3011 97.08 98.52
(0.20, 0.30)   2460 95.16 97.51
(0.30, 0.40)   3132 95.03 97.47
(0.40, 0.50)    756 92.83 96.33
  --> updated hg19_chr22_chunk0.pth


                                                                                        

Chunk 1/7, Epoch 51/100, Train Loss: 26738.8923, Val Loss: 27913.2075
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.60 99.15
(0.05, 0.10)   2598 98.11 98.89
(0.10, 0.20)   3011 97.12 98.33
(0.20, 0.30)   2460 95.24 97.24
(0.30, 0.40)   3132 95.12 97.22
(0.40, 0.50)    756 92.94 96.04


                                                                                        

Chunk 1/7, Epoch 52/100, Train Loss: 26873.1979, Val Loss: 26841.7009
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.57 99.22
(0.05, 0.10)   2598 98.11 98.99
(0.10, 0.20)   3011 97.13 98.50
(0.20, 0.30)   2460 95.25 97.48
(0.30, 0.40)   3132 95.12 97.49
(0.40, 0.50)    756 92.93 96.32


                                                                                        

Chunk 1/7, Epoch 53/100, Train Loss: 26738.0207, Val Loss: 26660.1152
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.61 99.20
(0.05, 0.10)   2598 98.14 98.97
(0.10, 0.20)   3011 97.14 98.53
(0.20, 0.30)   2460 95.26 97.55
(0.30, 0.40)   3132 95.13 97.50
(0.40, 0.50)    756 92.96 96.44


                                                                                        

Chunk 1/7, Epoch 54/100, Train Loss: 26888.9819, Val Loss: 26798.7788
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.61 99.22
(0.05, 0.10)   2598 98.12 98.98
(0.10, 0.20)   3011 97.14 98.49
(0.20, 0.30)   2460 95.27 97.56
(0.30, 0.40)   3132 95.14 97.48
(0.40, 0.50)    756 92.95 96.31


                                                                                        

Chunk 1/7, Epoch 55/100, Train Loss: 26785.7917, Val Loss: 26685.4629
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.61 99.24
(0.05, 0.10)   2598 98.14 99.01
(0.10, 0.20)   3011 97.15 98.55
(0.20, 0.30)   2460 95.25 97.54
(0.30, 0.40)   3132 95.15 97.48
(0.40, 0.50)    756 92.98 96.39


                                                                                        

Chunk 1/7, Epoch 56/100, Train Loss: 26673.3350, Val Loss: 26652.5403
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.62 99.26
(0.05, 0.10)   2598 98.14 99.01
(0.10, 0.20)   3011 97.18 98.49
(0.20, 0.30)   2460 95.30 97.50
(0.30, 0.40)   3132 95.18 97.51
(0.40, 0.50)    756 93.02 96.42
  --> updated hg19_chr22_chunk0.pth


                                                                                        

Chunk 1/7, Epoch 57/100, Train Loss: 26337.6363, Val Loss: 26866.5115
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.63 99.19
(0.05, 0.10)   2598 98.17 98.96
(0.10, 0.20)   3011 97.19 98.47
(0.20, 0.30)   2460 95.32 97.55
(0.30, 0.40)   3132 95.21 97.50
(0.40, 0.50)    756 93.06 96.31


                                                                                        

Chunk 1/7, Epoch 58/100, Train Loss: 26395.4459, Val Loss: 26881.9001
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.65 99.26
(0.05, 0.10)   2598 98.20 98.98
(0.10, 0.20)   3011 97.21 98.47
(0.20, 0.30)   2460 95.37 97.51
(0.30, 0.40)   3132 95.25 97.43
(0.40, 0.50)    756 93.10 96.17


                                                                                        

Chunk 1/7, Epoch 59/100, Train Loss: 26517.7646, Val Loss: 26594.4312
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.64 99.28
(0.05, 0.10)   2598 98.20 99.02
(0.10, 0.20)   3011 97.22 98.50
(0.20, 0.30)   2460 95.37 97.53
(0.30, 0.40)   3132 95.26 97.48
(0.40, 0.50)    756 93.08 96.26
  --> updated hg19_chr22_chunk0.pth


                                                                                        

Chunk 1/7, Epoch 60/100, Train Loss: 26466.6212, Val Loss: 26923.2698
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.64 99.22
(0.05, 0.10)   2598 98.20 98.99
(0.10, 0.20)   3011 97.21 98.46
(0.20, 0.30)   2460 95.35 97.50
(0.30, 0.40)   3132 95.23 97.45
(0.40, 0.50)    756 93.07 96.24


                                                                                        

Chunk 1/7, Epoch 61/100, Train Loss: 26496.6832, Val Loss: 26293.4482
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.63 99.24
(0.05, 0.10)   2598 98.18 99.02
(0.10, 0.20)   3011 97.20 98.59
(0.20, 0.30)   2460 95.34 97.59
(0.30, 0.40)   3132 95.23 97.61
(0.40, 0.50)    756 93.07 96.53
  --> updated hg19_chr22_chunk0.pth


                                                                                        

Chunk 1/7, Epoch 62/100, Train Loss: 26485.2509, Val Loss: 26685.5393
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.64 99.23
(0.05, 0.10)   2598 98.18 98.99
(0.10, 0.20)   3011 97.20 98.52
(0.20, 0.30)   2460 95.35 97.56
(0.30, 0.40)   3132 95.23 97.46
(0.40, 0.50)    756 93.07 96.36


                                                                                        

Chunk 1/7, Epoch 63/100, Train Loss: 26565.8941, Val Loss: 26552.0071
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.64 99.24
(0.05, 0.10)   2598 98.18 99.03
(0.10, 0.20)   3011 97.22 98.51
(0.20, 0.30)   2460 95.36 97.56
(0.30, 0.40)   3132 95.25 97.49
(0.40, 0.50)    756 93.10 96.47


                                                                                        

Chunk 1/7, Epoch 64/100, Train Loss: 26435.7651, Val Loss: 26964.0107
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.63 99.20
(0.05, 0.10)   2598 98.18 98.95
(0.10, 0.20)   3011 97.20 98.47
(0.20, 0.30)   2460 95.35 97.45
(0.30, 0.40)   3132 95.21 97.44
(0.40, 0.50)    756 93.05 96.36


                                                                                        

Chunk 1/7, Epoch 65/100, Train Loss: 26228.8341, Val Loss: 27653.3293
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.65 99.15
(0.05, 0.10)   2598 98.20 98.94
(0.10, 0.20)   3011 97.23 98.41
(0.20, 0.30)   2460 95.39 97.28
(0.30, 0.40)   3132 95.28 97.24
(0.40, 0.50)    756 93.12 96.11


                                                                                        

Chunk 1/7, Epoch 66/100, Train Loss: 26124.3559, Val Loss: 26754.9080
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.65 99.22
(0.05, 0.10)   2598 98.21 98.98
(0.10, 0.20)   3011 97.24 98.54
(0.20, 0.30)   2460 95.37 97.52
(0.30, 0.40)   3132 95.26 97.43
(0.40, 0.50)    756 93.14 96.31


                                                                                        

Chunk 1/7, Epoch 67/100, Train Loss: 26428.5944, Val Loss: 26440.5212
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.67 99.25
(0.05, 0.10)   2598 98.21 99.00
(0.10, 0.20)   3011 97.23 98.54
(0.20, 0.30)   2460 95.40 97.57
(0.30, 0.40)   3132 95.28 97.56
(0.40, 0.50)    756 93.17 96.41


                                                                                        

Chunk 1/7, Epoch 68/100, Train Loss: 26344.9654, Val Loss: 27083.2000
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.66 99.21
(0.05, 0.10)   2598 98.21 98.95
(0.10, 0.20)   3011 97.23 98.45
(0.20, 0.30)   2460 95.40 97.41
(0.30, 0.40)   3132 95.29 97.47
(0.40, 0.50)    756 93.14 96.30


                                                                                        

Chunk 1/7, Epoch 69/100, Train Loss: 26223.4657, Val Loss: 27198.1580
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.69 99.19
(0.05, 0.10)   2598 98.25 98.92
(0.10, 0.20)   3011 97.28 98.45
(0.20, 0.30)   2460 95.45 97.41
(0.30, 0.40)   3132 95.34 97.38
(0.40, 0.50)    756 93.19 96.28


                                                                                        

Chunk 1/7, Epoch 70/100, Train Loss: 26303.1813, Val Loss: 26787.3787
     MAF_bin Counts Train   Val
(0.00, 0.05)   4427 98.66 99.25
(0.05, 0.10)   2598 98.23 99.02
(0.10, 0.20)   3011 97.25 98.50
(0.20, 0.30)   2460 95.41 97.49
(0.30, 0.40)   3132 95.33 97.47
(0.40, 0.50)    756 93.17 96.32
Early stopping triggered
=== Chunk 2 STAT ===
     MAF_bin Counts
(0.00, 0.05)   4318
(0.05, 0.10)   2967
(0.10, 0.20)   2900
(0.20, 0.30)   2408
(0.30, 0.40)   2797
(0.40, 0.50)    994


                                                                                       

KeyboardInterrupt: 

In [None]:
# ---------- 1. 必须与训练时完全一致 ----------
d_model       = 64
n_alleles     = 3
total_sites   = 12345
chunk_size    = 4096
chunk_overlap = 64
device        = 'cuda' if torch.cuda.is_available() else 'cpu'

# 重建模型
model = EvoFill(
    d_model=d_model,
    n_alleles=n_alleles,
    total_sites=total_sites,
    chunk_size=chunk_size,
    chunk_overlap=chunk_overlap
).to(device)

# ---------- 2. 加载最终权重 ----------
ckpt = torch.load('exp1/models/final_model.pth', map_location=device)
model.load_state_dict(ckpt['model_state'])

# ---------- 3. 切换推理模式 ----------
model.eval()
with torch.no_grad():
    pred = model(x, chunk_id=0, x_extra=x_extra)