# EvoFill working demo

ver 4.   new imputation loss with evo loss

ver 4.1  long range modules integrated in stage1 training

ver 4.2  stage 3 fine-tuning with under-represented population samples.

last update: 2025/11/13

## 0. Dependency

In [1]:
import os
# Must be set before torch initialises CUDA, so it stays ahead of `import torch`.
os.environ["CUDA_VISIBLE_DEVICES"] = "1" 

import sys
import json
import numpy as np
from pathlib import Path
from tqdm import tqdm
import torch
import mamba_ssm

print('Python ver:  ', sys.version)
print('Pytorch ver: ', torch.__version__)
print('Mamba ver:   ', mamba_ssm.__version__)
# get_device_name() raises when no CUDA device is visible; the model-loading cell
# falls back to CPU in that case, so only query the device name when CUDA exists.
if torch.cuda.is_available():
    print('GPU in use:  ', torch.cuda.get_device_name(torch.cuda.current_device()))
else:
    print('GPU in use:   none (CUDA unavailable, running on CPU)')

# Change into the repository root before importing the project package `src`
# (presumably so that `src` is importable from the current directory — confirm).
os.chdir('/mnt/qmtang/EvoFill/')
from src.data import GenotypeEncoder, ImputationDataset
from src.model import EvoFill
from src.utils import setup_workdir, precompute_maf, metrics_by_maf_with95ci, print_maf_stat_df_with95ci

Python ver:   3.10.19 | packaged by conda-forge | (main, Oct 13 2025, 14:08:27) [GCC 14.3.0]
Pytorch ver:  2.8.0+cu129
Mamba ver:    2.2.5
GPU in use:   NVIDIA H100 PCIe


## 1. Encoding all vcfs

In [3]:
# Working directory of this chr22 run; after the chdir, relative paths used by later
# cells (data/..., augment/...) resolve against it.
work_dir = Path('/mnt/qmtang/EvoFill_data/20251211_chr22/')
os.chdir(work_dir)
# Project helper from src.utils; presumably prepares the expected sub-directory layout — TODO confirm.
setup_workdir(work_dir)

### 1.1 1kGP cohorts

In [None]:
# Encode the major-population training VCF together with its EvoMat TSV; with
# save2disk=True the encoded result is written under <work_dir>/train.
gt_enc = GenotypeEncoder(
    phased=False,
    gts012=False,
    save2disk=True,
    save_dir=work_dir / "train",
).encode_new(
    vcf_path="data/major_pops_train.vcf.gz",
    evo_mat="data/evo_mat_major_pops_train.tsv",
)

print(
    f"[DATA] {gt_enc.n_samples:,} Samples",
    f"[DATA] {gt_enc.n_variants:,} Variants Sites",
    f"[DATA] {gt_enc.seq_depth} seq_depth",
    sep="\n",
)

[DATA] 总计 190,184 个位点  
[DATA] EvoMat shape: (3009, 3009)
[DATA] 结果已写入 /mnt/qmtang/EvoFill_data/20251211_chr22/train
[DATA] 位点矩阵 = (3009, 190184)，稀疏度 = 26.30%
[DATA] 位点字典 = {'0|1': 1, '0|0': 0, '1|1': 2, '.|.': -1}，字典深度 = 4
[DATA] 3,009 Samples
[DATA] 190,184 Variants Sites
[DATA] 4 seq_depth


### 1.2 aDNA cohorts

In [5]:
# Encode the aDNA cohort against the Stage-1 site layout (the training encoder's
# metadata JSON), plus its EvoMat TSV; results are written under <work_dir>/augment.
gt_enc_aug = GenotypeEncoder(
    phased=False,
    gts012=False,
    save2disk=True,
    save_dir=work_dir / "augment",
).encode_ref(
    ref_meta_json=work_dir / "train" / "gt_enc_meta.json",   # same layout as Stage 1
    vcf_path="augment/hg38_chr22.AADRl_plus_Shimao.renamed.vcf.gz",
    evo_mat="augment/evo_mat_aDNA.tsv",
)

print(
    f"[DATA] {gt_enc_aug.n_samples:,} Samples",
    f"[DATA] {gt_enc_aug.n_variants:,} Variants Sites",
    f"[DATA] {gt_enc_aug.seq_depth} seq_depth",
    sep="\n",
)

[DATA] 总计 15,517 个位点  
[DATA] EvoMat shape: (8953, 8953)
[DATA] 结果已写入 /mnt/qmtang/EvoFill_data/20251211_chr22/augment
[DATA] 位点矩阵 = (8953, 15517)，稀疏度 = 64.16%，缺失率 = 49.41%
[DATA] 位点字典 = {'0|0': 0, '0|1': 1, '1|1': 2, '.|.': -1, '0/0': 0, '0/1': 1, '1/1': 2, './.': -1}，字典深度 = 4
[DATA] 8,953 Samples
[DATA] 15,517 Variants Sites
[DATA] 4 seq_depth


### 1.3 将1240K中的位点映射到千人基因组 (map the aDNA 1240K sites onto the 1kGP reference panel)

In [None]:
import numpy as np
from cyvcf2 import VCF

ref_vcf_path   = "data/major_pops_train.vcf.gz"
aDNA_vcf_path = 'augment/hg38_chr22.AADRl_plus_Shimao.renamed.vcf.gz'


def _site_key(variant):
    """Return the CHROM:POS:REF:ALT fingerprint of a cyvcf2 variant.

    Shared by both passes below so the reference index and the aDNA lookup always
    use the same key format (which must stay consistent with the bcftools script).
    """
    return f'{variant.CHROM}:{variant.POS}:{variant.REF}:{",".join(variant.ALT)}'


# 1. Index the large reference VCF: fingerprint -> 0-based row index.
big_index = {}
n_dup_keys = 0
ref_reader = VCF(ref_vcf_path)
try:
    for idx, variant in enumerate(ref_reader):
        key = _site_key(variant)
        if key in big_index:
            n_dup_keys += 1
        big_index[key] = idx          # a later duplicate overwrites the earlier row index
finally:
    ref_reader.close()
if n_dup_keys:
    print(f'[WARN] {n_dup_keys} duplicate CHROM:POS:REF:ALT keys in the reference VCF; '
          'the last occurrence is used')

# 2. Map every aDNA site to its reference row; -1 marks sites absent from the reference.
adna_reader = VCF(aDNA_vcf_path)
try:
    mapping = np.array([big_index.get(_site_key(v), -1) for v in adna_reader], dtype=np.int32)
finally:
    adna_reader.close()

print('古DNA VCF 位点数:', len(mapping))
print('在参考面板中命中:', np.sum(mapping != -1))

# Save the aDNA -> 1kGP row mapping.
np.save('augment/aDNA-1kGP_sitesmap.npy', mapping)

古DNA VCF 位点数: 15517
在参考 VCF 中命中: 14713


### 1.4 validation

In [4]:
# Encode the masked (mask90p) minor-population VCF against the Stage-1 site layout;
# this encoding, saved under <work_dir>/impute_in, is the model input for section 3.2.
gt_enc_imp = GenotypeEncoder(
    phased=False,
    gts012=False,
    save2disk=True,
    save_dir=work_dir / "impute_in",
).encode_ref(
    ref_meta_json=work_dir / "train" / "gt_enc_meta.json",   # same layout as Stage 1
    vcf_path="/mnt/qmtang/EvoFill_data/20251211_chr22/data/minor_pops_all.mask90p.vcf.gz",
)

print(f'[INFER] {gt_enc_imp.n_samples} samples, {gt_enc_imp.n_variants} variants')

[DATA] 总计 190,184 个位点  
[DATA] 结果已写入 /mnt/qmtang/EvoFill_data/20251211_chr22/impute_in
[DATA] 位点矩阵 = (93, 190184)，稀疏度 = 92.38%，缺失率 = 89.93%
[DATA] 位点字典 = {'0|0': 0, '0|1': 1, '1|1': 2, '.|.': -1, '0/0': 0, '0/1': 1, '1/1': 2, './.': -1}，字典深度 = 4
[INFER] 93 samples, 190184 variants


In [5]:
# Encode the unmasked minor-population VCF (the ground truth) against the Stage-1
# site layout; section 3.4 reloads it from <work_dir>/impute_out for evaluation.
gt_enc_imp = GenotypeEncoder(
    phased=False,
    gts012=False,
    save2disk=True,
    save_dir=work_dir / "impute_out",
).encode_ref(
    ref_meta_json=work_dir / "train" / "gt_enc_meta.json",   # same layout as Stage 1
    vcf_path="/mnt/qmtang/EvoFill_data/20251211_chr22/data/minor_pops_all.vcf.gz",
)

print(f'[INFER] {gt_enc_imp.n_samples} samples, {gt_enc_imp.n_variants} variants')

[DATA] 总计 190,184 个位点  
[DATA] 结果已写入 /mnt/qmtang/EvoFill_data/20251211_chr22/impute_out
[DATA] 位点矩阵 = (93, 190184)，稀疏度 = 24.31%，缺失率 = 0.00%
[DATA] 位点字典 = {'0|0': 0, '0|1': 1, '1|1': 2, '.|.': -1, '0/0': 0, '0/1': 1, '1/1': 2, './.': -1}，字典深度 = 4
[INFER] 93 samples, 190184 variants


## 2. Training with multi-GPUs

### 2.1 Pre-training

In [None]:
%%bash
# Launch Stage-1 pre-training in the background (nohup ... &) through `accelerate launch`
# with ds_zero3.yaml (presumably a DeepSpeed ZeRO-3 config); output goes to the log files.
cd /mnt/qmtang/EvoFill/
# NOTE(review): this first launch pins GPU 1 and logs to an HLA log dated 251206, yet runs the
# same stage1_training_ds.py / config as the chr22 launch below, and both start concurrently.
# It looks like a leftover from an earlier run — confirm before re-executing this cell.
nohup env OMP_NUM_THREADS=4 CUDA_VISIBLE_DEVICES=1 \
  accelerate launch --config_file ds_zero3.yaml \
  stage1_training_ds.py \
  > logs/pre_HLA_251206.log 2>&1 &

# chr22 run. No CUDA_VISIBLE_DEVICES is set on this line, so the visible devices presumably come
# from the kernel environment (the dependency cell sets CUDA_VISIBLE_DEVICES="1") — confirm intent.
nohup env OMP_NUM_THREADS=8 \
  accelerate launch --config_file ds_zero3.yaml \
  stage1_training_ds.py \
  > logs/pre_chr22_251211-2.log 2>&1 &

# `tail -f` never exits on its own: this cell keeps running until it is interrupted.
tail -f logs/pre_chr22_251211-2.log

### 2.2 Augmentation

In [None]:
%%bash
# Launch Stage-2 augmentation training in the background; output goes to the log file.
# NOTE(review): the log suffix "2512011" has one digit more than the date-style suffixes used
# elsewhere (e.g. 251211) — likely a typo; check the name when looking for this log.
cd /mnt/qmtang/EvoFill/
nohup env OMP_NUM_THREADS=4 accelerate launch --config_file ds_zero3.yaml stage2_augmentation_ds.py > logs/aug_chr22_2512011.log 2>&1 &

### 2.3 Merge weight files

In [None]:
%%bash
# Consolidate the training checkpoint into plain fp32 weights with the zero_to_fp32.py script
# stored inside the checkpoint directory (presumably DeepSpeed's helper — confirm arguments
# against the DeepSpeed version in use).
# NOTE(review): this path is under 20251204_chr22 and converts checkpoint-stage1, while section 3
# loads 20251211_chr22/models/pytorch_model_stage2.bin — confirm the run directory and stage.
cd /mnt/qmtang/EvoFill_data/20251204_chr22/models/checkpoint-stage1/
python zero_to_fp32.py . ./

## 3. Imputation

### 3.1 Load the trained model

Choose a work directory whose `<work_dir>/models` folder contains the trained model (`model_meta.json` and the merged weight file).

In [2]:
work_dir = Path('/mnt/qmtang/EvoFill_data/20251211_chr22')
os.chdir(work_dir)
print(f"Work Dir: {work_dir}")

# ---- 1. Load the model ----
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Model hyper-parameters are stored next to the weights; read them with a context
# manager so the file handle is closed immediately.
json_path = f"{work_dir}/models/model_meta.json"
with open(json_path, encoding="utf-8") as meta_fh:
    meta = json.load(meta_fh)
model = EvoFill(
    n_alleles=int(meta["alleles"]),
    total_sites=int(meta["total_sites"]),
    chunk_size=int(meta["chunk_size"]),
    chunk_overlap=int(meta["overlap"]),
    d_model=int(meta["d_model"]),
    d_state=int(meta["d_state"]),
    headdim=int(meta["headdim"]),
    bimamba_layers=int(meta["bimamba_layers"]),
    stack_mamba_layers=int(meta["stack_mamba_layers"])
).to(device)

# Stage-2 weights are used here; point this at pytorch_model_stage1.bin to evaluate Stage 1.
# weights_only=True (already the default in this PyTorch version) restricts unpickling to tensors.
weights_path = f"{work_dir}/models/pytorch_model_stage2.bin"
state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)

model.load_state_dict(state_dict)
total_params = sum(p.numel() for p in model.parameters())
model.eval()
print(f'[INF] Model[{meta["model_name"]}] loaded.')
print(f"[INF] Total params: {total_params:,}")

Work Dir: /mnt/qmtang/EvoFill_data/20251211_chr22
[INF] Model[hg38_chr22] loaded.
[INF] Total params: 39,503,723


### 3.2 Inferring

In [3]:
gt_enc_imp = GenotypeEncoder.loadfromdisk(Path(work_dir / "impute_in"))

print(f'[INFER] {gt_enc_imp.n_samples} samples, {gt_enc_imp.n_variants} variants')

# ---- 2. Build the inference Dataset / DataLoader ----
imp_dataset = ImputationDataset(
    x_gts_sparse=gt_enc_imp.X_gt,
    seq_depth=gt_enc_imp.seq_depth,
    indices=None                 # optionally pass a subset of sample indices
)
imp_dataset.print_missing_stat()          # report the raw missing rate

def collate_fn(batch):
    """Stack the one-hot inputs of a batch.

    Returns (x_onehot, real_idx_list); there is no target y at inference time.
    """
    x_onehot = torch.stack([item[0] for item in batch])
    real_idx_list = [item[1] for item in batch]
    return x_onehot, real_idx_list

imp_loader = torch.utils.data.DataLoader(
    imp_dataset,
    batch_size=1,
    shuffle=False,               # keep sample order so rows of y_prob match gt_enc_imp samples
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn
)

y_prob = []
y_mask = []
with torch.no_grad():
    for x_onehot, real_idx in tqdm(imp_loader, desc='Imputing'):
        x_onehot = x_onehot.to(device, non_blocking=True)
        _, prob, _ = model(x_onehot)
        # Assumes the last one-hot channel flags missing genotypes — TODO confirm in ImputationDataset.
        miss_mask = x_onehot[..., -1].bool()
        # Move every batch off the GPU right away so device memory does not grow with cohort size.
        y_prob.append(prob.cpu())
        y_mask.append(miss_mask.cpu())
y_prob = torch.cat(y_prob, dim=0).numpy()
y_mask = torch.cat(y_mask, dim=0).numpy()
# 4. Save
out_dir = os.path.join(work_dir, 'impute_out')
os.makedirs(out_dir, exist_ok=True)
np.save(os.path.join(out_dir, 'impute_prob.npy'), y_prob)
np.save(os.path.join(out_dir, 'impute_mask.npy'), y_mask)
print(f'[INF] 概率矩阵已保存 → {out_dir}/impute_prob.npy '
      f'with shape = {y_prob.shape} ')


[INFER] 93 samples, 190184 variants
[ImputationDataset] 93 samples, missing rate = 89.93%


Imputing: 100%|██████████| 93/93 [00:51<00:00,  1.82it/s]


[INF] 概率矩阵已保存 → /mnt/qmtang/EvoFill_data/20251211_chr22/impute_out/impute_prob.npy with shape = (93, 190184, 3) 


### 3.4 Evaluating the imputation results

In [4]:
gt_enc_true = GenotypeEncoder.loadfromdisk(Path(work_dir / "impute_out"))
y_true = gt_enc_true.X_gt.toarray()

# hap_map encodes missing calls ('.|.', './.') as -1. A negative index into np.eye would
# silently select the last one-hot row (the '1|1' class), so refuse to score such truth.
n_missing_true = int((y_true < 0).sum())
if n_missing_true:
    raise ValueError(f'[EVAL] ground truth contains {n_missing_true:,} missing genotypes; '
                     'they must be excluded before one-hot encoding')
# Predictions and truth must cover the same (sample, site) grid.
if tuple(y_prob.shape[:2]) != tuple(y_true.shape):
    raise ValueError(f'[EVAL] prediction grid {tuple(y_prob.shape[:2])} does not match '
                     f'truth grid {tuple(y_true.shape)}')

maf, bin_cnt = precompute_maf(y_true)
y_true_oh = np.eye(gt_enc_true.seq_depth - 1)[y_true]
bins_metrics   = metrics_by_maf_with95ci(y_prob, y_true_oh, hap_map = gt_enc_true.hap_map, maf_vec = maf, mask=y_mask)
print_maf_stat_df_with95ci(bin_cnt,{'val': bins_metrics})

     MAF_bin Counts val_Acc val_INFO val_IQS val_MaCH   val_Acc_CI95  val_INFO_CI95   val_IQS_CI95  val_MaCH_CI95
(0.00, 0.05) 109870   0.994    0.145   0.142    0.752 [0.994, 0.994] [0.143, 0.147] [0.140, 0.144] [0.749, 0.754]
(0.05, 0.10)  15765   0.952    0.644   0.656    0.733 [0.951, 0.953] [0.638, 0.650] [0.650, 0.663] [0.726, 0.739]
(0.10, 0.20)  22521   0.934    0.717   0.725    0.806 [0.932, 0.935] [0.713, 0.722] [0.720, 0.730] [0.801, 0.810]
(0.20, 0.30)  14896   0.915    0.751   0.749    0.841 [0.913, 0.917] [0.746, 0.757] [0.743, 0.754] [0.836, 0.846]
(0.30, 0.40)  13774   0.911    0.772   0.764    0.857 [0.909, 0.914] [0.766, 0.777] [0.759, 0.770] [0.852, 0.863]
(0.40, 0.50)  12984   0.910    0.777   0.761    0.857 [0.907, 0.912] [0.772, 0.783] [0.756, 0.767] [0.852, 0.863]


测试集：90% 少数族裔样本（CDX），90% masked 位点

训练集：hg38 chr22 190184 variants

**= 1kGP ONLY =**

|      MAF_bin | Counts | val_Acc | val_INFO | val_IQS | val_MaCH |
| :----------: | :----: | ------: | -------: | ------: | -------: |
| (0.00, 0.05) |   1546 |   0.972 |    0.060 |   0.068 |    0.166 |
| (0.05, 0.10) |   1246 |   0.931 |    0.198 |   0.216 |    0.367 |
| (0.10, 0.20) |   2205 |   0.875 |    0.332 |   0.353 |    0.545 |
| (0.20, 0.30) |   3112 |   0.798 |    0.487 |   0.495 |    0.754 |
| (0.30, 0.40) |   5193 |   0.755 |    0.550 |   0.552 |    0.853 |
| (0.40, 0.50) |   1565 |   0.739 |    0.548 |   0.545 |    0.845 |

     MAF_bin Counts val_Acc val_INFO val_IQS val_MaCH   val_Acc_CI95  val_INFO_CI95   val_IQS_CI95  val_MaCH_CI95
(0.00, 0.05) 109870   0.994    0.145   0.141    0.757 [0.993, 0.994] [0.143, 0.147] [0.139, 0.143] [0.755, 0.760]
(0.05, 0.10)  15765   0.951    0.643   0.653    0.732 [0.950, 0.952] [0.636, 0.649] [0.646, 0.659] [0.725, 0.738]
(0.10, 0.20)  22521   0.932    0.719   0.723    0.807 [0.931, 0.934] [0.714, 0.723] [0.718, 0.728] [0.802, 0.812]
(0.20, 0.30)  14896   0.915    0.756   0.749    0.842 [0.913, 0.917] [0.751, 0.762] [0.743, 0.754] [0.837, 0.848]
(0.30, 0.40)  13774   0.909    0.778   0.761    0.858 [0.907, 0.912] [0.772, 0.783] [0.755, 0.767] [0.852, 0.863]
(0.40, 0.50)  12984   0.909    0.782   0.760    0.856 [0.906, 0.912] [0.777, 0.788] [0.754, 0.766] [0.851, 0.862]

     MAF_bin Counts val_Acc val_INFO val_IQS val_MaCH   val_Acc_CI95  val_INFO_CI95   val_IQS_CI95  val_MaCH_CI95
(0.00, 0.05)  83276   0.976    0.400   0.419    0.557 [0.976, 0.976] [0.397, 0.403] [0.416, 0.422] [0.554, 0.560]
(0.05, 0.10)  29657   0.954    0.640   0.656    0.779 [0.953, 0.954] [0.636, 0.644] [0.651, 0.660] [0.775, 0.784]
(0.10, 0.20)  27650   0.931    0.711   0.719    0.843 [0.930, 0.932] [0.707, 0.715] [0.715, 0.723] [0.839, 0.847]
(0.20, 0.30)  19387   0.916    0.757   0.755    0.874 [0.914, 0.917] [0.753, 0.762] [0.750, 0.759] [0.870, 0.878]
(0.30, 0.40)  15843   0.909    0.769   0.761    0.877 [0.907, 0.911] [0.765, 0.774] [0.756, 0.766] [0.872, 0.882]
(0.40, 0.50)  14047   0.897    0.767   0.751    0.879 [0.895, 0.900] [0.762, 0.772] [0.746, 0.756] [0.874, 0.884]

**= aDNA -> 1kGP =** 

|      MAF_bin | Counts | val_Acc | val_INFO | val_IQS | val_MaCH |
| :----------: | :----: | ------: | -------: | ------: | -------: |
| (0.00, 0.05) |   1546 |   0.972 |    0.089 |   0.085 |    0.193 |
| (0.05, 0.10) |   1246 |   0.935 |    0.280 |   0.283 |    0.463 |
| (0.10, 0.20) |   2205 |   0.890 |    0.452 |   0.456 |    0.685 |
| (0.20, 0.30) |   3112 |   0.824 |    0.592 |   0.574 |    0.850 |
| (0.30, 0.40) |   5193 |   0.787 |    0.645 |   0.614 |    0.906 |
| (0.40, 0.50) |   1565 |   0.770 |    0.634 |   0.599 |    0.894 |

**= 1kGP -> aDNA =** 

|      MAF_bin | Counts | val_Acc | val_INFO | val_IQS | val_MaCH |
| :----------: | :----: | ------: | -------: | ------: | -------: |
| (0.00, 0.05) |   1546 |   0.938 |    0.109 |   0.227 |    0.734 |
| (0.05, 0.10) |   1246 |   0.893 |    0.174 |   0.396 |    0.633 |
| (0.10, 0.20) |   2205 |   0.810 |    0.196 |   0.435 |    0.545 |
| (0.20, 0.30) |   3112 |   0.690 |    0.144 |   0.433 |    0.499 |
| (0.30, 0.40) |   5193 |   0.640 |    0.101 |   0.429 |    0.503 |
| (0.40, 0.50) |   1565 |   0.604 |    0.102 |   0.398 |    0.464 |


### 3.5 Saving to .vcf

In [38]:
import re
import subprocess

from cyvcf2 import VCF, Writer

# 0. Paths. The input must be the SAME masked VCF that was encoded into <work_dir>/impute_in
#    (section 1.4); otherwise sites/samples would not line up with the rows of y_prob.
ref_vcf = os.path.join(work_dir, 'data', 'minor_pops_all.mask90p.vcf.gz')
out_dir = os.path.join(work_dir, 'impute_out')
out_vcf = os.path.join(out_dir, 'imputed.vcf.gz')

# 1. y_prob was produced from gt_enc_imp (loaded in 3.2), not the training encoder gt_enc.
n_site = gt_enc_imp.n_variants
n_samp = gt_enc_imp.n_samples
n_alleles = gt_enc_imp.seq_depth - 1
if y_prob.shape != (n_samp, n_site, n_alleles):
    raise ValueError(f'y_prob shape {y_prob.shape} != expected {(n_samp, n_site, n_alleles)}')

# 2. Reverse map: class index -> (allele1, allele2). hap_map can hold both phased ('0|1') and
#    unphased ('0/1') spellings of the same class plus -1 for missing, so parse both separators.
idx2alleles = {}
for gt_str, cls in gt_enc_imp.hap_map.items():
    if cls < 0 or cls in idx2alleles:
        continue
    allele1, allele2 = (int(a) for a in re.split(r'[|/]', gt_str))
    idx2alleles[cls] = (allele1, allele2)

samp2idx = {sid: i for i, sid in enumerate(gt_enc_imp.sample_ids)}

# 3. Most likely genotype class for every (sample, site), computed once up front.
best_cls = y_prob.argmax(axis=-1)

# 4. Open the input VCF, keep only the encoded samples, and map each VCF sample column
#    to its row in y_prob (rec.genotypes follows the reader's sample order).
invcf = VCF(ref_vcf)
invcf.set_samples(list(gt_enc_imp.sample_ids))
vcf2prob = [samp2idx[sid] for sid in invcf.samples]

out = Writer(out_vcf, invcf, mode='wz')
n_written = 0
try:
    for rec_idx, rec in enumerate(invcf):
        if rec_idx >= n_site:
            raise ValueError(f'{ref_vcf} has more records than the {n_site} encoded sites')
        # Read genotypes once per record ([allele1, allele2, phased] per sample); cyvcf2 only
        # applies changes when the list is assigned back to rec.genotypes.
        genos = rec.genotypes
        for col, old_gt in enumerate(genos):
            if old_gt[0] == -1 or old_gt[1] == -1:      # missing -> fill with the prediction
                allele1, allele2 = idx2alleles[int(best_cls[vcf2prob[col], rec_idx])]
                genos[col] = [allele1, allele2, old_gt[2]]   # keep the original phase flag
        # Assigning .genotypes encodes GT correctly; set_format('GT', ...) does not.
        rec.genotypes = genos
        out.write_record(rec)
        n_written += 1
finally:
    out.close()
    invcf.close()

if n_written != n_site:
    raise ValueError(f'wrote {n_written} records but y_prob covers {n_site} sites')

# 5. Index the bgzipped output; fail loudly if tabix is missing or errors out.
subprocess.run(['tabix', '-f', '-p', 'vcf', out_vcf], check=True)
print(f'[INF] 缺失位点填充完成 → {out_vcf}')

[INF] 缺失位点填充完成 → /mnt/qmtang/EvoFill_data/20251107_ver4/impute_out/imputed.vcf.gz
