# EvoFill working demo

ver 4.   new imputation loss with evo loss

ver 4.1  long range modules integrated in stage1 training

ver 4.2  stage 3 fine-tuning with under-represented population (URP) samples.

last update: 2025/11/13

In [2]:
import os

# Restrict this kernel to GPU 0, then work from the EvoFill repository root
# (presumably so the `from src...` imports in a later cell resolve — confirm).
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})
os.chdir('/mnt/qmtang/EvoFill/')

## 0. Dependency

In [None]:
import sys
import json
import numpy as np
from pathlib import Path
import torch
import mamba_ssm

# Report the software stack and the GPU the notebook will use.
print('Python ver:  ', sys.version)
print('Pytorch ver: ', torch.__version__)
print('Mamba ver:   ', mamba_ssm.__version__)
# torch.cuda.current_device() raises when no CUDA device is usable; report that
# instead of crashing (the model-loading cell already falls back to CPU).
if torch.cuda.is_available():
    print('GPU in use:  ', torch.cuda.get_device_name(torch.cuda.current_device()))
else:
    print('GPU in use:  ', 'none (CUDA unavailable)')

Python ver:   3.10.19 | packaged by conda-forge | (main, Oct 13 2025, 14:08:27) [GCC 14.3.0]
Pytorch ver:  2.8.0+cu129
Mamba ver:    2.2.5
GPU in use:   NVIDIA H100 PCIe


In [None]:
from src.data import GenotypeEncoder, ImputationDataset
from src.model import EvoFill
from src.loss import ImputationLoss
from src.utils import setup_workdir, precompute_maf, metrics_by_maf, print_maf_stat_df

In [41]:
# Run directory for this experiment. `setup_workdir` is a project helper
# (presumably prepares the directory layout — confirm); afterwards the
# process cwd is moved there so relative paths like "pretrain/..." resolve.
_RUN_DIR = '/mnt/qmtang/EvoFill_data/20251121_chr22_v2'
work_dir = Path(_RUN_DIR)
setup_workdir(work_dir)
os.chdir(work_dir)

## 1. Encoding all vcfs

### 1.1 pretraining

In [42]:
# Pre-training encoding: `encode_new` builds the genotype encoding from the
# major-population VCF (plus its evolutionary-distance matrix) and saves the
# result under <work_dir>/pretrain. Paths are relative to the cwd set above.
pretrain_dir = Path(work_dir / "pretrain")
gt_enc = GenotypeEncoder(phased=False, gts012=False, save2disk=True, save_dir=pretrain_dir)
gt_enc = gt_enc.encode_new(
    vcf_path="pretrain/major_pops.vcf.gz",
    default_gt='ref',
    evo_mat="pretrain/evo_mat_major_pops.tsv",
)

for summary_line in (f"[DATA] {gt_enc.n_samples:,} Samples",
                     f"[DATA] {gt_enc.n_variants:,} Variants Sites",
                     f"[DATA] {gt_enc.seq_depth} seq_depth"):
    print(summary_line)

[DATA] 总计 14,867 个位点  
[DATA] EvoMat shape: (2904, 2904)
[DATA] 结果已写入 /mnt/qmtang/EvoFill_data/20251121_chr22_v2/pretrain
[DATA] 位点矩阵 = (2904, 14867)，稀疏度 = 45.53%
[DATA] 位点字典 = {'0|1': 1, '1|1': 2, '0|0': 0, '.|.': 3}，字典深度 = 4
[DATA] 2,904 Samples
[DATA] 14,867 Variants Sites
[DATA] 4 seq_depth


### 1.2 augmentation

In [48]:
# Augmentation set (ancient-DNA samples): encoded against the stage-1 metadata
# so the site layout matches pre-training; unseen genotypes default to 'miss'.
augment_dir = Path(work_dir / "augment")
stage1_meta = work_dir / "pretrain" / "gt_enc_meta.json"
gt_enc_aug = GenotypeEncoder(phased=False, gts012=False, save2disk=True, save_dir=augment_dir)
gt_enc_aug = gt_enc_aug.encode_ref(
    ref_meta_json=stage1_meta,
    default_gt='miss',
    vcf_path="augment/chr22_trimmed_AADR_renamed.vcf.gz",
    evo_mat="augment/evo_mat_aDNA.tsv",
)

for summary_line in (f"[DATA] {gt_enc_aug.n_samples:,} Samples",
                     f"[DATA] {gt_enc_aug.n_variants:,} Variants Sites",
                     f"[DATA] {gt_enc_aug.seq_depth} seq_depth"):
    print(summary_line)

[DATA] 总计 14,867 个位点  
[DATA] EvoMat shape: (17629, 17629)
[DATA] 结果已写入 /mnt/qmtang/EvoFill_data/20251121_chr22_v2/augment
[DATA] 位点矩阵 = (17629, 14867)，稀疏度 = 59.42%，缺失率 = 0.00%
[DATA] 位点字典 = {'0|1': 1, '1|1': 2, '0|0': 0, '.|.': 3, '0/1': 1, '1/1': 2, '0/0': 0, './.': 3}，字典深度 = 4
[DATA] 17,629 Samples
[DATA] 14,867 Variants Sites
[DATA] 4 seq_depth


### 1.3 finetuning

In [43]:
# Few-shot fine-tuning set (under-represented populations, 10% split), encoded
# against the stage-1 metadata so its sites line up with pre-training.
urp_dir = Path(work_dir / "finetune")
gt_enc_urp = GenotypeEncoder(phased=False, gts012=False, save2disk=True, save_dir=urp_dir)
gt_enc_urp = gt_enc_urp.encode_ref(
    ref_meta_json=work_dir / "pretrain" / "gt_enc_meta.json",
    default_gt='ref',
    vcf_path=urp_dir / "minor_pops.10pct.vcf.gz",
    evo_mat=urp_dir / "evo_mat_minor_pops.10pct.tsv",
)

print(f'[URP] {gt_enc_urp.n_samples} samples, {gt_enc_urp.n_variants} variants')

[DATA] 总计 14,867 个位点  
[DATA] EvoMat shape: (29, 29)
[DATA] 结果已写入 /mnt/qmtang/EvoFill_data/20251121_chr22_v2/finetune
[DATA] 位点矩阵 = (29, 14867)，稀疏度 = 45.87%，缺失率 = 0.00%
[DATA] 位点字典 = {'0|1': 1, '1|1': 2, '0|0': 0, '.|.': 3}，字典深度 = 4
[URP] 29 samples, 14867 variants


### 1.4 validation

In [44]:
# Imputation input (remaining 90% of minor-population samples with ~50% of the
# genotypes masked), encoded against the stage-1 metadata.
impute_in_dir = Path(work_dir / "impute_in")
gt_enc_imp = GenotypeEncoder(phased=False, gts012=False, save2disk=True, save_dir=impute_in_dir)
gt_enc_imp = gt_enc_imp.encode_ref(
    ref_meta_json=work_dir / "pretrain" / "gt_enc_meta.json",
    default_gt='ref',
    vcf_path=impute_in_dir / "minor_pops.90pct.masked50p.vcf.gz",
)

print(f'[INFER] {gt_enc_imp.n_samples} samples, {gt_enc_imp.n_variants} variants')

[DATA] 总计 14,867 个位点  
[DATA] 结果已写入 /mnt/qmtang/EvoFill_data/20251121_chr22_v2/impute_in
[DATA] 位点矩阵 = (269, 14867)，稀疏度 = 72.75%，缺失率 = 50.01%
[DATA] 位点字典 = {'0|1': 1, '1|1': 2, '0|0': 0, '.|.': 3}，字典深度 = 4
[INFER] 269 samples, 14867 variants


## 2. Pretraining

In [None]:
%%bash
# Launch stage-1 pre-training in the background with `accelerate`
# (config: ds_zero3.yaml); stdout and stderr are redirected to the log file.
# The stray `%%!` line that followed the command was removed: inside a %%bash
# cell it is executed by bash as a command and only produced an error.
cd /mnt/qmtang/EvoFill/
nohup env OMP_NUM_THREADS=8 \
  accelerate launch --config_file ds_zero3.yaml \
  train_stage1_deepspeed.py \
  > logs/pretrian_chr22_251125.log 2>&1 &

## 3. Augmentation

In [None]:
%%bash
# Launch augmentation training (train_stage3_deepspeed.py) in the background
# with `accelerate` (config: ds_zero3.yaml); all output goes to the log file.
cd /mnt/qmtang/EvoFill/
OMP_NUM_THREADS=4 nohup accelerate launch \
  --config_file ds_zero3.yaml \
  train_stage3_deepspeed.py \
  > logs/aug_chr22_251121.log 2>&1 &

### Ultra-Long-Range (ULR) LD Module Training (Stage 2)

In [None]:
# ---------------------------------------------------------------------------
# Stage 2: train the ultra-long-range (ULR) LD module on every pair of chunks.
# Only parameters of `model.global_out` (w1/b1/w2/b2 and, if present, its
# `ulr_mamba` sub-module) are handed to the optimizers below.
# NOTE(review): this cell relies on names not defined in the visible notebook
# (`model`, `model_name`, `train_loader`, `test_loader`, `device`, `tqdm`);
# make sure they exist before running it.
# ---------------------------------------------------------------------------
from itertools import combinations

from torch.optim import AdamW, SparseAdam
from torch.optim.lr_scheduler import ReduceLROnPlateau

max_epochs_per_pair = 100
lr                 = 5e-4
weight_decay       = 1e-5
earlystop_patience = 15
batch_size         = 8
verbose            = True

criterion = ImputationLoss(use_r2=True, use_evo=True, r2_weight=1, evo_weight=4, evo_lambda=10)


def _print_pair_maf(snapshot, maf_vec, bin_cnt):
    """Print per-MAF-bin train/val metrics for one saved epoch snapshot.

    `snapshot` is the tuple (train_prob, train_gts, train_mask, val_prob, val_gts)
    collected during a single epoch, so the train mask always matches the
    predictions it is applied to.
    """
    tr_prob, tr_gts, tr_mask, va_prob, va_gts = snapshot
    train_bins_metrics = metrics_by_maf(tr_prob, tr_gts, gt_enc.hap_map, maf_vec, mask=tr_mask)
    val_bins_metrics   = metrics_by_maf(va_prob, va_gts, gt_enc.hap_map, maf_vec, mask=None)
    print_maf_stat_df(bin_cnt,
                      {"train": train_bins_metrics,
                       "val":   val_bins_metrics})


# ----------- Alternative: load weights chunk by chunk -----------
# for cid in range(model.n_chunks):
#     chunk_file = f'{work_dir}/models/{model_name}_chunk_{cid}.pth'
#     ckpt = torch.load(chunk_file, map_location='cpu')
#     model.chunk_embeds[cid].load_state_dict(ckpt['chunk_embed_state'])
#     model.chunk_modules[cid].load_state_dict(ckpt['chunk_module_state'])

# ----------- Load the full stage-1 weights -----------
ckpt = torch.load(f'{work_dir}/models/{model_name}_stage1.pth', map_location='cpu')
model.load_state_dict(ckpt['model_state'])

# Chunk experts are not given to any optimizer below; the original note says
# they are frozen (requires_grad=False) — confirm where that is set.
model.eval()

# Split trainable parameters: `idx_embed` tables go to SparseAdam (which needs
# sparse gradients, i.e. embedding layers), everything else to AdamW.
sparse_params = []
dense_params = []
dense_params.extend([model.global_out.w1, model.global_out.b1])
dense_params.extend([model.global_out.w2, model.global_out.b2])
# The ULR module is enabled by default.
if hasattr(model.global_out, 'ulr_mamba'):
    for name, param in model.global_out.ulr_mamba.named_parameters():
        if 'idx_embed' in name:
            sparse_params.append(param)
        else:
            dense_params.append(param)

# Separate optimizers; the sparse one only exists when there are sparse params.
optim_sparse = SparseAdam(sparse_params, lr=lr) if sparse_params else None
optim_dense = AdamW(dense_params, lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999))

# Learning-rate schedulers (halve the LR after 5 epochs without improvement).
scheduler_sparse = (ReduceLROnPlateau(optim_sparse, mode='min', factor=0.5, patience=5, min_lr=1e-9)
                    if optim_sparse is not None else None)
scheduler_dense = ReduceLROnPlateau(optim_dense, mode='min', factor=0.5, patience=5, min_lr=1e-9)

# Iterate only over the optimizers/schedulers that actually exist.
optimizers = [opt for opt in (optim_sparse, optim_dense) if opt is not None]
schedulers = [sch for sch in (scheduler_sparse, scheduler_dense) if sch is not None]

pair_list = list(combinations(range(model.n_chunks), 2))
np.random.shuffle(pair_list)          # visit chunk pairs in random order
total_pairs = len(pair_list)

for pair_idx, (cid1, cid2) in enumerate(pair_list, 1):
    # ====== union of the two chunks' site masks ======
    union_mask = (model.chunk_masks[cid1] + model.chunk_masks[cid2]).clamp(max=1).bool()
    train_logs_sum = None
    val_logs_sum   = None

    # MAF over the union sites (used for the per-bin report)
    union_maf, union_bin_cnt = precompute_maf(
        gt_enc.X_gt[:, union_mask.cpu().numpy()].toarray(),
        mask_int=gt_enc.seq_depth
    )

    # ====== early-stopping state ======
    best_loss = float('inf')
    best_snapshot = None          # predictions of the lowest-val-loss epoch
    patience_counter = 0

    # ====== training loop ======
    for epoch in range(max_epochs_per_pair):
        model.train()
        train_loss = 0.0
        train_prob, train_gts, train_mask = [], [], []

        pbar = tqdm(train_loader,
                    desc=f'Comb {pair_idx}/{total_pairs}  '
                         f'{cid1+1}-{cid2+1}  Epoch {epoch+1}/{max_epochs_per_pair}',
                    leave=False)
        for x, target, evo_mat in pbar:
            x, target = x.to(device), target.to(device)
            # An empty evo_mat means "no evolutionary matrix for this batch".
            evo_mat = evo_mat.to(device) if evo_mat.numel() else None

            for opt in optimizers:
                opt.zero_grad()

            logits, prob, mask_idx = model(x, [cid1, cid2])
            loss, logs = criterion(logits[:, mask_idx], prob[:, mask_idx], target[:, mask_idx], evo_mat)
            loss.backward()

            # SparseAdam updates only the embedding tables, AdamW the rest.
            for opt in optimizers:
                opt.step()

            train_loss += loss.item()
            if train_logs_sum is None:          # first batch: create the accumulator
                train_logs_sum = {k: 0.0 for k in logs}
            for k, v in logs.items():
                train_logs_sum[k] += v

            # Collect predictions for the MAF report; the last input channel
            # marks missing genotypes.
            miss_mask = x[:, union_mask][..., -1].bool()
            train_prob.append(prob[:, mask_idx].detach())
            train_gts.append(target[:, mask_idx].detach())
            train_mask.append(miss_mask)

        train_prob = torch.cat(train_prob, dim=0)
        train_gts  = torch.cat(train_gts,  dim=0)
        train_mask = torch.cat(train_mask, dim=0)

        # ----------- validation -----------
        model.eval()
        val_loss = 0.0
        val_prob, val_gts = [], []
        with torch.no_grad():
            if val_logs_sum is None:
                val_logs_sum = {k: 0.0 for k in train_logs_sum}
            for x, target, evo_mat in test_loader:
                x = x.to(device)
                target = target.to(device)
                evo_mat = evo_mat.to(device) if evo_mat.numel() else None
                logits, prob, mask_idx = model(x, [cid1, cid2])
                loss, logs = criterion(logits[:, mask_idx], prob[:, mask_idx], target[:, mask_idx], evo_mat)
                val_loss += loss.item()
                for k, v in logs.items():
                    val_logs_sum[k] += v
                val_prob.append(prob[:, mask_idx])
                val_gts.append(target[:, mask_idx])

        val_prob = torch.cat(val_prob, dim=0)
        val_gts  = torch.cat(val_gts,  dim=0)

        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss   = val_loss   / len(test_loader)
        avg_train_logs = {k: v / len(train_loader) for k, v in train_logs_sum.items()}
        avg_val_logs   = {k: v / len(test_loader) for k, v in val_logs_sum.items()}

        # Step on the same (averaged) metric that drives early stopping.
        for sch in schedulers:
            sch.step(avg_val_loss)

        current_denselr = optim_dense.param_groups[0]['lr']
        sparse_lr_str = (f"{optim_sparse.param_groups[0]['lr']:.2e}"
                         if optim_sparse is not None else 'n/a')

        log_str = (f'Comb {pair_idx}/{total_pairs}  '
            f'{cid1+1}-{cid2+1}  Epoch {epoch+1}/{max_epochs_per_pair} '
            f'Total Loss, Train = {avg_train_loss:.1f}, '
            f'Val = {avg_val_loss:.1f}, '
            f'dense LR: {current_denselr:.2e}, '
            f'sparse LR: {sparse_lr_str}')
        log_str += '\n        Train'
        for k, v in avg_train_logs.items():
            log_str += f', {k}: {v:.1f}'
        log_str += '\n        Val  '
        for k, v in avg_val_logs.items():
            log_str += f', {k}: {v:.1f}'
        print(log_str)
        # Reset the accumulators for the next epoch.
        train_logs_sum = {k: 0.0 for k in train_logs_sum}
        val_logs_sum   = {k: 0.0 for k in val_logs_sum}

        # Early stopping / checkpointing on the averaged validation loss.
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            patience_counter = 0
            ckpt_name = f'{model_name}_chunk[{cid1}-{cid2}].pth'
            torch.save({
                'comb': (cid1, cid2),
                'global_out': model.global_out.state_dict(),
                'best_val_loss': best_loss,
                'epoch': epoch,
            }, f'{work_dir}/models/{ckpt_name}')
            # Keep this epoch's predictions together with its own train mask.
            best_snapshot = (train_prob, train_gts, train_mask, val_prob, val_gts)
            if verbose:
                _print_pair_maf(best_snapshot, union_maf, union_bin_cnt)
                print(f'  --> updated {ckpt_name}')
        else:
            patience_counter += 1
            if patience_counter >= earlystop_patience:
                print(f'Pair {cid1+1}-{cid2+1} early stopping')
                break

    # Report the metrics of the best (lowest val loss) epoch of this pair,
    # whether or not training stopped early.
    if best_snapshot is not None:
        _print_pair_maf(best_snapshot, union_maf, union_bin_cnt)

    torch.cuda.empty_cache()

# ----------- All pairs done -> save the final model -----------
torch.save({
    'model_state': model.state_dict(),
    'ulr_enabled': True,
}, f'{work_dir}/models/{model_name}_stage2_final.pth')
print(f'==> STAGE2 training finished: {work_dir}/models/{model_name}_stage2_final.pth')

## 3. Fine-tuning (Few-shot URP)

In [None]:
%%bash
cd /mnt/qmtang/EvoFill/
nohup env OMP_NUM_THREADS=4 \
  accelerate launch --config_file ds_zero3.yaml \
  finetuning_deepspeed.py \
  > logs/finetune_chr22_251124.log 2>&1 &

## 4. Imputation

### 3.1 Load the trained model

Set `work_dir` to a directory whose `models/` sub-directory contains the trained checkpoint and its `model_meta.json`.

In [8]:
work_dir = Path("/mnt/qmtang/EvoFill_data/20251121_chr22_v2/")
print(f"Work Dir: {work_dir}")

# ---- 1. Load the model ----
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Architecture hyper-parameters are stored next to the checkpoints. Read them
# inside a `with` block so the file handle is closed deterministically.
json_path = f"{work_dir}/models/model_meta.json"
with open(json_path, encoding="utf-8") as meta_file:
    meta = json.load(meta_file)
model = EvoFill(
    d_model=int(meta["d_model"]),
    n_alleles=int(meta["alleles"]),
    total_sites=int(meta["total_sites"]),
    chunk_size=int(meta["chunk_size"]),
    chunk_overlap=int(meta["overlap"])
).to(device)

# Select which checkpoint to evaluate (stage-1, stage-3, or a fine-tuned one).
# ckpt = torch.load(f'{work_dir}/models/{meta["model_name"]}_stage1.pth', map_location=device)
# ckpt = torch.load(f'{work_dir}/models/{meta["model_name"]}_stage3.pth', map_location=device)
ckpt = torch.load(f'{work_dir}/models/{meta["model_name"]}_CDX_BEB_ASW.pth', map_location=device)
model.load_state_dict(ckpt['model_state'])
model.eval()
print(f'[INF] Model[{meta["model_name"]}] loaded.')

Work Dir: /mnt/qmtang/EvoFill_data/20251121_chr22_v2
[INF] Model[chr22_trim] loaded.


### 3.2 Encode the .vcf file to be imputed

In [7]:
# Encode the VCF to be imputed against the stage-1 metadata so its sites
# match what the model was trained on.
impute_in_dir = Path(work_dir / "impute_in")
gt_enc_imp = GenotypeEncoder(phased=False, gts012=False, save2disk=True, save_dir=impute_in_dir)
gt_enc_imp = gt_enc_imp.encode_ref(
    ref_meta_json=work_dir / "pretrain" / "gt_enc_meta.json",
    default_gt='ref',
    vcf_path=impute_in_dir / "minor_pops.90pct.masked50p.vcf.gz",
)

print(f'[INFER] {gt_enc_imp.n_samples} samples, {gt_enc_imp.n_variants} variants')

# ---- 2. Build the inference dataset ----
# indices=None uses every sample; a list of sample indices may be passed instead.
imp_dataset = ImputationDataset(
    x_gts_sparse=gt_enc_imp.X_gt,
    seq_depth=gt_enc_imp.seq_depth,
    indices=None,
)
imp_dataset.print_missing_stat()          # report the raw missing rate

def collate_fn(batch):
    """Collate inference items into a batch.

    Each item is indexed as (x_onehot, real_idx). The x_onehot tensors are
    stacked along a new leading dimension; the real_idx values are returned
    as a plain list. There is no target (y) at inference time.
    """
    onehots, sample_ids = [], []
    for item in batch:
        onehots.append(item[0])
        sample_ids.append(item[1])
    return torch.stack(onehots), sample_ids

# Unshuffled loader so predictions come back in dataset (sample) order.
imp_loader_kwargs = dict(
    batch_size=64,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn,
)
imp_loader = torch.utils.data.DataLoader(imp_dataset, **imp_loader_kwargs)

[DATA] 总计 14,867 个位点  
[DATA] 结果已写入 /mnt/qmtang/EvoFill_data/20251121_chr22_v2/impute_in
[DATA] 位点矩阵 = (269, 14867)，稀疏度 = 72.75%，缺失率 = 50.01%
[DATA] 位点字典 = {'0|1': 1, '1|1': 2, '0|0': 0, '.|.': 3}，字典深度 = 4
[INFER] 269 samples, 14867 variants
[ImputationDataset] 269 samples, missing rate = 50.01%


### 3.3 Inferring

In [9]:
# Run the model over every batch, keeping the predicted probabilities and a
# mask of the input sites that were missing (last one-hot channel).
prob_batches, mask_batches = [], []
with torch.no_grad():
    for batch_x, _batch_ids in tqdm(imp_loader, desc='Imputing'):
        batch_x = batch_x.to(device)
        _, batch_prob, _ = model(batch_x)
        prob_batches.append(batch_prob)
        mask_batches.append(batch_x[..., -1].bool())
y_prob = torch.cat(prob_batches, dim=0).cpu().numpy()
y_mask = torch.cat(mask_batches, dim=0).cpu().numpy()

# Persist both arrays under <work_dir>/impute_out.
out_dir = os.path.join(work_dir, 'impute_out')
os.makedirs(out_dir, exist_ok=True)
for file_name, array in (('impute_prob.npy', y_prob), ('impute_mask.npy', y_mask)):
    np.save(os.path.join(out_dir, file_name), array)
print(f'[INF] 概率矩阵已保存 → {out_dir}/impute_prob.npy '
      f'with shape = {y_prob.shape} ')


Imputing: 100%|██████████| 5/5 [00:17<00:00,  3.51s/it]


[INF] 概率矩阵已保存 → /mnt/qmtang/EvoFill_data/20251121_chr22_v2/impute_out/impute_prob.npy with shape = (269, 14867, 3) 


### 3.4 Evaluating the imputation results

In [10]:
# Score the imputed probabilities against the un-masked ground-truth VCF,
# reported per MAF bin. `mask=y_mask` presumably restricts scoring to the
# sites that were missing in the imputation input — confirm in metrics_by_maf.
truth_vcf = work_dir / "impute_out" / "minor_pops.90pct.vcf.gz"
gt_enc_true = GenotypeEncoder(phased=False, gts012=False, save2disk=False).encode_ref(
    ref_meta_json=work_dir / "pretrain" / "gt_enc_meta.json",
    default_gt='ref',
    vcf_path=truth_vcf,
)
y_true = gt_enc_true.X_gt.toarray()
maf, bin_cnt = precompute_maf(y_true, mask_int=gt_enc_true.seq_depth)
n_classes = gt_enc_true.seq_depth - 1          # genotype classes, excluding the missing code
y_true_oh = np.eye(n_classes)[y_true]          # integer codes -> one-hot
bins_metrics = metrics_by_maf(y_prob, y_true_oh,
                              hap_map=gt_enc_true.hap_map,
                              maf_vec=maf,
                              mask=y_mask)
print_maf_stat_df(bin_cnt, {'val': bins_metrics})

[DATA] 总计 14,867 个位点  
[DATA] 位点矩阵 = (269, 14867)，稀疏度 = 45.44%，缺失率 = 0.00%
[DATA] 位点字典 = {'0|1': 1, '1|1': 2, '0|0': 0, '.|.': 3}，字典深度 = 4
     MAF_bin Counts val_Acc val_INFO val_IQS val_MaCH
(0.00, 0.05)   1546   0.978    0.340   0.340    0.608
(0.05, 0.10)   1246   0.957    0.580   0.596    0.818
(0.10, 0.20)   2205   0.933    0.705   0.713    0.921
(0.20, 0.30)   3112   0.904    0.789   0.781    0.969
(0.30, 0.40)   5193   0.888    0.822   0.801    0.981
(0.40, 0.50)   1565   0.877    0.812   0.788    0.974


= STAGE 1 =

|      MAF_bin | Counts | val_Acc | val_INFO | val_IQS | val_MaCH |
| :----------: | :----: | ------: | -------: | ------: | -------: |
| (0.00, 0.05) |   1546 |  0.978  |    0.299 |  0.319  |    0.572 |
| (0.05, 0.10) |   1246 |  0.957  |    0.539 |  0.585  |    0.796 |
| (0.10, 0.20) |   2205 |  0.934  |    0.672 |  0.705  |    0.901 |
| (0.20, 0.30) |   3112 |  0.904  |    0.768 |  0.779  |    0.960 |
| (0.30, 0.40) |   5193 |  0.888  |    0.804 |  0.801  |    0.976 |
| (0.40, 0.50) |   1565 |  0.878  |    0.795 |  0.790  |    0.970 |


### 3.5 Saving to .vcf

In [38]:
import re
import subprocess

from cyvcf2 import VCF, Writer

# 0. Paths. The template must be the very VCF that was encoded for imputation
#    (gt_enc_imp), so that its records and samples line up with y_prob.
ref_vcf = str(work_dir / "impute_in" / "minor_pops.90pct.masked50p.vcf.gz")
out_vcf = os.path.join(out_dir, 'imputed.vcf.gz')

# 1. Dimensions come from the encoder that produced y_prob, not from the
#    pre-training encoder (which has a different number of samples).
n_site = gt_enc_imp.n_variants
n_samp = gt_enc_imp.n_samples
n_alleles = gt_enc_imp.seq_depth - 1
if y_prob.shape != (n_samp, n_site, n_alleles):
    raise ValueError(f'y_prob shape {y_prob.shape} != expected {(n_samp, n_site, n_alleles)}')

# 2. Reverse mapping: class index -> allele pair. Labels may be phased ('0|1')
#    or unphased ('0/1'); both separators are accepted.
rev_hap_map = {v: k for k, v in gt_enc_imp.hap_map.items()}
idx2alleles = {idx: [int(a) for a in re.split(r'[|/]', rev_hap_map[idx])]
               for idx in range(n_alleles)}

samp2idx = {sid: i for i, sid in enumerate(gt_enc_imp.sample_ids)}

# 3. Open the template VCF and map each VCF sample column to its row in
#    y_prob; the VCF column order need not match the encoder's sample order.
invcf = VCF(ref_vcf)
invcf.set_samples(list(gt_enc_imp.sample_ids))
col2row = [samp2idx[sid] for sid in invcf.samples]

out = Writer(out_vcf, invcf, mode='wz')

n_records = 0
try:
    for rec_idx, rec in enumerate(invcf):
        if rec_idx >= n_site:
            raise ValueError(f'{ref_vcf} has more records than the {n_site} encoded sites')
        # `rec.genotypes` builds a fresh list on every access: read it once.
        genotypes = rec.genotypes              # per sample: [allele1, allele2, phased]
        changed = False
        for col, old_gt in enumerate(genotypes):
            if old_gt[0] == -1 or old_gt[1] == -1:     # missing -> fill with argmax
                best_idx = int(y_prob[col2row[col], rec_idx, :].argmax())
                a1, a2 = idx2alleles[best_idx]
                phased = old_gt[2] if old_gt[2] != -1 else 1
                genotypes[col] = [a1, a2, phased]
                changed = True
        if changed:
            # Assign through the genotypes setter so cyvcf2 encodes GT itself.
            rec.genotypes = genotypes
        out.write_record(rec)
        n_records += 1
finally:
    invcf.close()
    out.close()

if n_records != n_site:
    raise ValueError(f'{ref_vcf} has {n_records} records but y_prob covers {n_site} sites')

# 4. Index the output; fail loudly if tabix is missing or errors out.
subprocess.run(['tabix', '-fp', 'vcf', out_vcf], check=True)
print(f'[INF] 缺失位点填充完成 → {out_vcf}')

[INF] 缺失位点填充完成 → /mnt/qmtang/EvoFill_data/20251107_ver4/impute_out/imputed.vcf.gz
