# EvoFill working demo



## 0. Dependencies

In [1]:
# Environment + dependency cell. Statement order matters here: the GPU selection
# comes before the torch import, and the chdir comes before the project imports.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # expose only GPU index 1 to this kernel

import sys
import json
import numpy as np
from pathlib import Path
from tqdm import tqdm
import torch
import mamba_ssm

# Record interpreter / library versions and the GPU in use for reproducibility.
print('Python ver:  ', sys.version)
print('Pytorch ver: ', torch.__version__)
print('Mamba ver:   ', mamba_ssm.__version__)
# NOTE(review): presumably raises when no CUDA device is visible — confirm if CPU-only use is expected.
print('GPU in use:  ', torch.cuda.get_device_name(torch.cuda.current_device()))

# Switch to the EvoFill repository root before importing the project package `src`.
# NOTE(review): presumably the `src.*` imports resolve via the current directory — confirm sys.path setup.
os.chdir('/mnt/qmtang/EvoFill/')
from src.data import GenotypeEncoder, ImputationDataset
from src.model import EvoFill
from src.utils import setup_workdir, precompute_maf, metrics_by_maf_with95ci, print_maf_stat_df_with95ci
from src.tensor2vcf import make_imputed_vcfgz_from_prob

Python ver:   3.10.19 | packaged by conda-forge | (main, Oct 13 2025, 14:08:27) [GCC 14.3.0]
Pytorch ver:  2.8.0+cu129
Mamba ver:    2.2.5
GPU in use:   NVIDIA H100 PCIe


## 1. Encoding all vcfs

In [2]:
# Run directory for this experiment; the encoding cells below use paths relative to it.
work_dir = Path('/mnt/qmtang/EvoFill_data/20251230_chr22/')
# NOTE(review): chdir runs before setup_workdir, so the directory must already exist
# at this point — confirm whether setup_workdir is meant to create it.
os.chdir(work_dir)
setup_workdir(work_dir)

### 1.1 1kGP cohorts

In [4]:
# Encode the 1kGP training VCF together with its evolutionary matrix.
# save2disk/save_dir point the encoder at <work_dir>/train for its on-disk output.
gt_enc = GenotypeEncoder(
    phased=False,
    gts012=False,
    save2disk=True,
    save_dir=Path(work_dir / "train"),
).encode_new(
    vcf_path="data/major_pops_train.vcf.gz",
    evo_mat="data/evo_mat_major_pops_train.tsv",
)

# Summary of the encoded panel.
for summary_line in (
    f"[DATA] {gt_enc.n_samples:,} Samples",
    f"[DATA] {gt_enc.n_variants:,} Variants Sites",
    f"[DATA] {gt_enc.seq_depth} seq_depth",
):
    print(summary_line)

[DATA] 总计 7,182 个位点  
[DATA] EvoMat shape: (3009, 3009)
[DATA] 结果已写入 /mnt/qmtang/EvoFill_data/20251225_chr22_IGL/train
[DATA] 位点矩阵 = (3009, 7182)，稀疏度 = 30.45%
[DATA] 位点字典 = {'0|0': 0, '0|1': 1, '1|1': 2, '.|.': -1}，字典深度 = 4
[DATA] 3,009 Samples
[DATA] 7,182 Variants Sites
[DATA] 4 seq_depth


### 1.2 1240k-panel cohorts

In [3]:
# Encode the 1240K augmentation VCF against the metadata written by the training
# encoder (same layout as Stage 1), with its own evolutionary matrix.
gt_enc_aug = GenotypeEncoder(
    phased=False,
    gts012=False,
    save2disk=True,
    save_dir=Path(work_dir / "augment"),
).encode_ref(
    ref_meta_json=work_dir / "train" / "gt_enc_meta.json",
    vcf_path="augment/augment_hg38.chr22.vcf.gz",
    evo_mat="augment/evo_mat_1240k.tsv",
)

# Summary of the encoded panel.
for summary_line in (
    f"[DATA] {gt_enc_aug.n_samples:,} Samples",
    f"[DATA] {gt_enc_aug.n_variants:,} Variants Sites",
    f"[DATA] {gt_enc_aug.seq_depth} seq_depth",
):
    print(summary_line)

[DATA] 总计 15,326 个位点  
[DATA] EvoMat shape: (2364, 2364)
[DATA] 结果已写入 /mnt/qmtang/EvoFill_data/20251230_chr22/augment
[DATA] 位点矩阵 = (2364, 15326)，稀疏度 = 53.57%，缺失率 = 24.09%
[DATA] 位点字典 = {'0|0': 0, '0|1': 1, '1|1': 2, '.|.': -1, '0/0': 0, '0/1': 1, '1/1': 2, './.': -1}，字典深度 = 4
[DATA] 2,364 Samples
[DATA] 15,326 Variants Sites
[DATA] 4 seq_depth


### 1.3 Mapping variant sites in the 1240K panel to the 1kGP panel

将1240K中的位点映射到千人基因组

In [None]:
from contextlib import closing

import numpy as np
from cyvcf2 import VCF

ref_vcf_path   = "data/major_pops_train.vcf.gz"
aDNA_vcf_path = 'augment/augment_hg38.chr22.vcf.gz'


def _site_key(variant):
    """Return the ``CHROM:POS:REF:ALT`` fingerprint of one VCF record.

    Multiple ALT alleles are joined with commas. The same key format is used
    for both VCFs (and, per the original note, the bcftools script) so that
    records can be matched by exact string equality.
    """
    return f'{variant.CHROM}:{variant.POS}:{variant.REF}:{",".join(variant.ALT)}'


# 1. Index the reference (large) VCF: fingerprint -> 0-based record index.
#    If a fingerprint occurs more than once, the last occurrence wins.
#    `closing` releases the reader even if iteration fails part-way.
with closing(VCF(ref_vcf_path)) as ref_vcf:
    big_index = {_site_key(variant): idx for idx, variant in enumerate(ref_vcf)}

# 2. For every record of the 1240K (small) VCF, look up its row in the
#    reference index; -1 marks a site that is absent from the reference panel.
with closing(VCF(aDNA_vcf_path)) as adna_vcf:
    mapping = [big_index.get(_site_key(variant), -1) for variant in adna_vcf]

mapping = np.array(mapping, dtype=np.int32)
print('1240k面板位点数:', len(mapping))
print('在参考面板中命中:', np.sum(mapping != -1))

# Persist the site map.
np.save('augment/sitesmap_1240k.npy', mapping)

古DNA VCF 位点数: 15326
在参考面板中命中: 14713


### 1.4 Validation cohort (masked input and ground truth)

In [8]:
# Encode the 90%-masked validation VCF (imputation input) against the training
# encoder's metadata (same layout as Stage 1); output goes to <work_dir>/impute_in.
gt_enc_imp = GenotypeEncoder(
    phased=False,
    gts012=False,
    save2disk=True,
    save_dir=Path(work_dir / "impute_in"),
).encode_ref(
    ref_meta_json=work_dir / "train" / "gt_enc_meta.json",
    vcf_path="./data/minor_pops_all.mask90p.vcf.gz",
)

n_imp_samples, n_imp_variants = gt_enc_imp.n_samples, gt_enc_imp.n_variants
print(f'[INFER] {n_imp_samples} samples, {n_imp_variants} variants')

[DATA] 总计 7,182 个位点  
[DATA] 结果已写入 /mnt/qmtang/EvoFill_data/20251225_chr22_IGL/impute_in
[DATA] 位点矩阵 = (93, 7182)，稀疏度 = 93.26%，缺失率 = 90.27%
[DATA] 位点字典 = {'0|0': 0, '0|1': 1, '1|1': 2, '.|.': -1, './.': -1}，字典深度 = 4
[INFER] 93 samples, 7182 variants


In [9]:
# Encode the unmasked validation VCF (ground truth) against the training
# encoder's metadata (same layout as Stage 1); output goes to <work_dir>/impute_out.
#
# Bound to `gt_enc_true` rather than re-using `gt_enc_imp`: the previous cell's
# `gt_enc_imp` holds the masked *input* encoding, and overwriting it with the
# truth made the meaning of the name depend on which cell ran last. The
# evaluation section reloads this encoding under the same `gt_enc_true` name.
gt_enc_true = GenotypeEncoder(phased=False, gts012=False, save2disk=True, save_dir = Path(work_dir / "impute_out"))
gt_enc_true = gt_enc_true.encode_ref(
        ref_meta_json = work_dir/"train"/"gt_enc_meta.json",
        vcf_path      = "./data/minor_pops_all.vcf.gz" )

print(f'[INFER] {gt_enc_true.n_samples} samples, {gt_enc_true.n_variants} variants')

[DATA] 总计 7,182 个位点  
[DATA] 结果已写入 /mnt/qmtang/EvoFill_data/20251225_chr22_IGL/impute_out
[DATA] 位点矩阵 = (93, 7182)，稀疏度 = 30.24%，缺失率 = 0.00%
[DATA] 位点字典 = {'0|0': 0, '0|1': 1, '1|1': 2, '.|.': -1, './.': -1}，字典深度 = 4
[INFER] 93 samples, 7182 variants


## 2. Training with multi-GPUs

### 2.1 training stage 1

In [None]:
%%bash
# Launch stage-1 training in the background from the repository root.
# Abort if the directory is missing so nothing is launched from the wrong place.
cd /mnt/qmtang/EvoFill/ || exit 1

# The output redirect below fails if logs/ does not exist yet.
mkdir -p logs

nohup env OMP_NUM_THREADS=8 accelerate launch --config_file ds_zero3.yaml stage1_training_ds.py > logs/chr22_260104.log 2>&1 &

# Follow the training log; this keeps the cell running until it is interrupted.
tail -f logs/chr22_260104.log

### 2.2 training stage 2

In [None]:
%%bash
# Launch stage-2 (hybrid) training in the background from the repository root.
# Abort if the directory is missing so nothing is launched from the wrong place.
cd /mnt/qmtang/EvoFill/ || exit 1

# The output redirect below fails if logs/ does not exist yet.
mkdir -p logs

# Stage 2 writes to its own log file: redirecting with '>' to the stage-1 log
# would truncate it and lose the stage-1 training record.
nohup env OMP_NUM_THREADS=8 accelerate launch --config_file ds_zero3.yaml hybrid_training_ds.py > logs/chr22_260104_stage2.log 2>&1 &

# Follow the training log; this keeps the cell running until it is interrupted.
tail -f logs/chr22_260104_stage2.log


### 2.3 Merge weight files

In [None]:
%%bash
# Merge the sharded checkpoint in this directory into a single fp32 weight file.
# Abort if the checkpoint directory is missing: the script is invoked with
# relative paths, so it must not run from any other location.
# NOTE(review): this run directory (20251204_chr22) differs from the work_dir used
# elsewhere in the notebook (20251230_chr22) — confirm it is the intended checkpoint.
cd /mnt/qmtang/EvoFill_data/20251204_chr22/models/checkpoint-stage1/ || exit 1
python zero_to_fp32.py . ./

## 3. Imputation

### 3.1 Load the trained model

Choose a working directory that contains a `<work_dir>/models` folder with a trained model and its `model_meta.json`.

In [2]:
work_dir = Path('/mnt/qmtang/EvoFill_data/20251230_chr22')
os.chdir(work_dir)
print(f"Work Dir: {work_dir}")

# ---- 1. Load the model ----
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hyper-parameters saved alongside the weights; read inside a context manager
# so the file handle is closed deterministically.
json_path = f"{work_dir}/models/model_meta.json"
with open(json_path) as meta_file:
    meta = json.load(meta_file)

# Rebuild the architecture from the stored hyper-parameters. Values are cast
# with int() because the constructor arguments are taken straight from JSON.
model = EvoFill(
    n_alleles=int(meta["alleles"]),
    total_sites=int(meta["total_sites"]),
    chunk_size=int(meta["chunk_size"]),
    chunk_overlap=int(meta["overlap"]),
    d_model=int(meta["d_model"]),
    d_state=int(meta["d_state"]),
    headdim=int(meta["headdim"]),
    bimamba_layers=int(meta["bimamba_layers"]),
    stack_mamba_layers=int(meta["stack_mamba_layers"])
).to(device)

# Checkpoint path is derived from work_dir (previously a duplicated absolute
# path), so changing work_dir above is enough to switch runs.
# Alternative checkpoint: f"{work_dir}/models/pytorch_model_stage2.bin"
state_dict = torch.load(f"{work_dir}/models/hg38_chr22_v1.0.bin", map_location="cpu")

model.load_state_dict(state_dict)
total_params = sum(p.numel() for p in model.parameters())
model.eval()  # inference mode for the imputation cells that follow
print(f'[INF] Model[{meta["model_name"]}] loaded.')
print(f"[INF] Total params: {total_params:,}")

Work Dir: /mnt/qmtang/EvoFill_data/20251230_chr22


[INF] Model[hg38_chr22] loaded.
[INF] Total params: 39,503,723


### 3.2 Inferring

In [3]:
# ---- 1. Load the encoded (masked) genotypes to impute ----
gt_enc_imp = GenotypeEncoder.loadfromdisk(Path(work_dir / "impute_in"))
# Results are written next to the encoded ground truth in <work_dir>/impute_out.
out_dir = os.path.join(work_dir, 'impute_out')

print(f'[INFER] {gt_enc_imp.n_samples} samples, {gt_enc_imp.n_variants} variants')

# ---- 2. Build the inference Dataset / DataLoader ----
imp_dataset = ImputationDataset(
    x_gts_sparse=gt_enc_imp.X_gt,
    seq_depth=gt_enc_imp.seq_depth,
    indices=None                 # optionally restrict to specific sample indices
)
imp_dataset.print_missing_stat()          # report the input missing rate

def collate_fn(batch):
    """Collate dataset items into one batch (no targets at inference time).

    Args:
        batch: sequence of dataset items; ``item[0]`` is the per-sample input
            tensor and ``item[1]`` the sample index that accompanies it.

    Returns:
        Tuple ``(x_onehot, real_idx_list)``: the inputs stacked along a new
        leading dimension, and the list of accompanying indices.
    """
    x_onehot = torch.stack([item[0] for item in batch])
    real_idx_list = [item[1] for item in batch]
    return x_onehot, real_idx_list

imp_loader = torch.utils.data.DataLoader(
    imp_dataset,
    batch_size=1,
    shuffle=False,               # keep sample order so outputs align with the input rows
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn
)

# ---- 3. Run the model ----
y_prob = []
y_mask = []
with torch.no_grad():
    for x_onehot, real_idx in tqdm(imp_loader, desc='Imputing'):
        x_onehot = x_onehot.to(device)
        _, prob, _ = model(x_onehot)
        # Last input channel is read as the missing-genotype flag
        # (presumably set by ImputationDataset — verify against its encoding).
        miss_mask = x_onehot[..., -1].bool()
        # Move each batch off the device right away so per-batch outputs do
        # not pile up in GPU memory until the end of the loop.
        y_prob.append(prob.cpu())
        y_mask.append(miss_mask.cpu())
y_prob = torch.cat(y_prob, dim=0).numpy()
y_mask = torch.cat(y_mask, dim=0).numpy()

# ---- 4. Save ----
os.makedirs(out_dir, exist_ok=True)
np.save(os.path.join(out_dir, 'impute_prob.npy'), y_prob)
np.save(os.path.join(out_dir, 'impute_mask.npy'), y_mask)
print(f'[INF] 概率矩阵已保存 → {out_dir}/impute_prob.npy '
      f'with shape = {y_prob.shape} ')


[INFER] 93 samples, 190184 variants
[INFER] 93 samples, missing rate = 89.93%


Imputing: 100%|██████████| 93/93 [00:25<00:00,  3.66it/s]


[INF] 概率矩阵已保存 → /mnt/qmtang/EvoFill_data/20251230_chr22/impute_out/impute_prob.npy with shape = (93, 190184, 3) 


### 3.4 Evaluating the imputation results

In [4]:
# Load the encoded ground truth and densify it for metric computation.
gt_enc_true = GenotypeEncoder.loadfromdisk(Path(work_dir / "impute_out"))
y_true = gt_enc_true.X_gt.toarray()

# Per-site minor-allele frequency and the number of sites per MAF bin.
maf, bin_cnt = precompute_maf(y_true)

# One-hot encode the truth by indexing rows of an identity matrix with
# (seq_depth - 1) classes.
# NOTE(review): a -1 (missing) code in y_true would silently select the last
# row — presumably the truth set has no missing calls; verify.
n_genotype_classes = gt_enc_true.seq_depth - 1
y_true_oh = np.eye(n_genotype_classes)[y_true]

# Metrics per MAF bin with 95% confidence intervals, restricted by `y_mask`.
bins_metrics = metrics_by_maf_with95ci(
    y_prob,
    y_true_oh,
    hap_map=gt_enc_true.hap_map,
    maf_vec=maf,
    mask=y_mask,
)
print_maf_stat_df_with95ci(bin_cnt, {'val': bins_metrics})

     MAF_bin Counts val_Acc val_INFO val_IQS val_MaCH   val_Acc_CI95  val_INFO_CI95   val_IQS_CI95  val_MaCH_CI95
(0.00, 0.05) 109870   0.996    0.153   0.156    0.583 [0.996, 0.996] [0.152, 0.155] [0.154, 0.158] [0.580, 0.585]
(0.05, 0.10)  15765   0.964    0.695   0.714    0.780 [0.963, 0.965] [0.689, 0.701] [0.708, 0.720] [0.774, 0.786]
(0.10, 0.20)  22521   0.953    0.757   0.774    0.837 [0.952, 0.954] [0.752, 0.761] [0.769, 0.779] [0.833, 0.842]
(0.20, 0.30)  14896   0.938    0.779   0.790    0.861 [0.936, 0.940] [0.774, 0.784] [0.785, 0.796] [0.856, 0.866]
(0.30, 0.40)  13774   0.935    0.799   0.800    0.867 [0.933, 0.937] [0.793, 0.804] [0.794, 0.805] [0.862, 0.873]
(0.40, 0.50)  12984   0.933    0.799   0.794    0.865 [0.931, 0.935] [0.793, 0.804] [0.788, 0.799] [0.859, 0.870]


### 3.5 Saving to .vcf

In [None]:
# Inputs / output for converting a saved probability matrix back into a VCF.
# NOTE(review): these paths point at a different run (20251211_chr22, major_pops_val)
# than the one imputed above (20251230_chr22, minor_pops_all) — confirm they are intended.
PROB_NPY = "/mnt/qmtang/EvoFill_data/20251211_chr22/impute_out2/impute_prob.npy"
MASK_VCF = "/mnt/qmtang/EvoFill_data/20251211_chr22/data/major_pops_val.mask90p.vcf.gz"
OUT_VCFGZ = "/home/zqyin/mamba_test/Other_Model_test/chr22_major/Evofill/major_pops_val.imputed.from_prob.vcf.gz"

# Value passed as `digits=` to the writer. It was referenced but never defined
# anywhere in the notebook, so the call failed with NameError on a fresh kernel.
# NOTE(review): 3 is a placeholder choice — set it to the precision you need.
FLOAT_DIGITS = 3

OUT_VCFGZ_CREATED = make_imputed_vcfgz_from_prob(PROB_NPY, MASK_VCF, OUT_VCFGZ, digits=FLOAT_DIGITS)