# EvoFill working demo

ver 4.   new imputation loss with evo loss

ver 4.1  long range modules integrated in stage1 training

ver 4.2  stage 3 fine-tuning with under-represented population samples.

last update: 2025/11/13

## 0. Dependency

In [2]:
# Environment setup for the whole notebook: pin the visible GPU, import the
# third-party stack, report versions, then import the EvoFill project code.
import os
# Set before torch is imported (below) so that only GPU 0 is visible to CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "0" 

import sys
import json
import numpy as np
from pathlib import Path
from tqdm import tqdm
import torch
import mamba_ssm

# Version / hardware report; the GPU line queries the current CUDA device.
print('Python ver:  ', sys.version)
print('Pytorch ver: ', torch.__version__)
print('Mamba ver:   ', mamba_ssm.__version__)
print('GPU in use:  ', torch.cuda.get_device_name(torch.cuda.current_device()))

# Switch to the EvoFill checkout before importing from `src`
# (presumably so the relative `src` package resolves — confirm).
os.chdir('/mnt/qmtang/EvoFill/')
from src.data import GenotypeEncoder, ImputationDataset
from src.model import EvoFill
from src.utils import  precompute_maf, metrics_by_maf, print_maf_stat_df

Python ver:   3.10.19 | packaged by conda-forge | (main, Oct 13 2025, 14:08:27) [GCC 14.3.0]
Pytorch ver:  2.8.0+cu129
Mamba ver:    2.2.5
GPU in use:   NVIDIA H100 PCIe


In [None]:
# Root directory of this chr22 run. The process also moves into it, so the
# relative paths used by the following cells are resolved against it.
work_dir = Path('/mnt/qmtang/EvoFill_data') / '20251205_chr22'
os.chdir(work_dir)

## 1. Encoding all vcfs

### 1.1 1kGP cohorts

In [3]:
# Encode the 1kGP major-population VCF together with its evo matrix.
# The encoder is configured to save to disk under <work_dir>/train.
gt_enc = GenotypeEncoder(
    phased=False,
    gts012=False,
    save2disk=True,
    save_dir=work_dir / "train",
).encode_new(
    vcf_path="train/major_pops.vcf.gz",
    evo_mat="train/evo_mat_major_pops.tsv",
)

# Summary of the encoded training matrix.
print(f"[DATA] {gt_enc.n_samples:,} Samples")
print(f"[DATA] {gt_enc.n_variants:,} Variants Sites")
print(f"[DATA] {gt_enc.seq_depth} seq_depth")

[DATA] 总计 190,184 个位点  
[DATA] EvoMat shape: (9967, 9967)
[DATA] 结果已写入 /mnt/qmtang/EvoFill_data/20251205_chr22/train
[DATA] 位点矩阵 = (9967, 190184)，稀疏度 = 18.10%
[DATA] 位点字典 = {'0|0': 0, '0|1': 1, '1|1': 2, '.|.': -1}，字典深度 = 4
[DATA] 9,967 Samples
[DATA] 190,184 Variants Sites
[DATA] 4 seq_depth


### 1.2 aDNA cohorts

In [None]:
# Encode the aDNA (AADR) cohort against the metadata written by the training
# encoder, so it shares the same layout as Stage 1.
gt_enc_aug = GenotypeEncoder(
    phased=False,
    gts012=False,
    save2disk=True,
    save_dir=work_dir / "augment",
).encode_ref(
    ref_meta_json=work_dir / "train" / "gt_enc_meta.json",
    # default_gt="miss",
    vcf_path="augment/AADR_extracted_samples_renamed.vcf.gz",
    evo_mat="augment/evo_mat_aDNA.tsv",
)

# Summary of the encoded augmentation matrix.
print(f"[DATA] {gt_enc_aug.n_samples:,} Samples")
print(f"[DATA] {gt_enc_aug.n_variants:,} Variants Sites")
print(f"[DATA] {gt_enc_aug.seq_depth} seq_depth")

[DATA] 总计 14,867 个位点  
[DATA] EvoMat shape: (7062, 7062)
[DATA] 结果已写入 /mnt/qmtang/EvoFill_data/20251125_chr22/augment
[DATA] 位点矩阵 = (7062, 14867)，稀疏度 = 76.33%，缺失率 = 0.00%
[DATA] 位点字典 = {'0|1': 1, '1|1': 2, '0|0': 0, '.|.': -1, '0/1': 1, '1/1': 2, '0/0': 0, './.': -1}，字典深度 = 4
[DATA] 7,062 Samples
[DATA] 14,867 Variants Sites
[DATA] 4 seq_depth


### 1.3 finetuning

In [4]:
# Encode the 10% split of the minor populations used for fine-tuning, against
# the training encoder's metadata (same layout as Stage 1).
finetune_dir = work_dir / "finetune"
gt_enc_urp = GenotypeEncoder(
    phased=False,
    gts012=False,
    save2disk=True,
    save_dir=finetune_dir,
).encode_ref(
    ref_meta_json=work_dir / "train" / "gt_enc_meta.json",
    default_gt='ref',
    vcf_path=finetune_dir / "minor_pops.10pct.vcf.gz",
    evo_mat=finetune_dir / "evo_mat_minor_pops.10pct.tsv",
)

print(f'[URP] {gt_enc_urp.n_samples} samples, {gt_enc_urp.n_variants} variants')

[DATA] 总计 190,184 个位点  
[DATA] EvoMat shape: (29, 29)
[DATA] 结果已写入 /mnt/qmtang/EvoFill_data/20251204_chr22/finetune
[DATA] 位点矩阵 = (29, 190184)，稀疏度 = 26.40%，缺失率 = 0.00%
[DATA] 位点字典 = {'0|0': 0, '0|1': 1, '1|1': 2, '.|.': -1}，字典深度 = 4
[URP] 29 samples, 190184 variants


### 1.4 validation

In [5]:
# Encode the masked 90% split of the minor populations (the imputation input),
# against the training encoder's metadata (same layout as Stage 1).
impute_in_dir = work_dir / "impute_in"
gt_enc_imp = GenotypeEncoder(
    phased=False,
    gts012=False,
    save2disk=True,
    save_dir=impute_in_dir,
).encode_ref(
    ref_meta_json=work_dir / "train" / "gt_enc_meta.json",
    default_gt='ref',
    vcf_path=impute_in_dir / "minor_pops.90pct.masked90p.vcf.gz",
)

print(f'[INFER] {gt_enc_imp.n_samples} samples, {gt_enc_imp.n_variants} variants')

[DATA] 总计 190,184 个位点  
[DATA] 结果已写入 /mnt/qmtang/EvoFill_data/20251204_chr22/impute_in
[DATA] 位点矩阵 = (269, 190184)，稀疏度 = 92.60%，缺失率 = 90.01%
[DATA] 位点字典 = {'0|0': 0, '0|1': 1, '1|1': 2, '.|.': -1}，字典深度 = 4
[INFER] 269 samples, 190184 variants


## 2. Multi-GPU training

### 2.1 Pre-training

In [None]:
%%bash
# Stage 1 pre-training: launch stage1_training_ds.py through `accelerate`
# using the ds_zero3.yaml config. The job is started with nohup in the
# background; stdout and stderr both go to the log file.
# NOTE(review): the log name says chr6 while the rest of this notebook works
# on chr22 — confirm the intended file name.
cd /mnt/qmtang/EvoFill/
nohup env OMP_NUM_THREADS=4 \
  accelerate launch --config_file ds_zero3.yaml \
  stage1_training_ds.py \
  > logs/pre_chr6_251204.log 2>&1 &

### 2.1b Post-training (1kGP)

In [None]:
%%bash
# Stage 2 post-training: launch stage2_1kGP-posttrain_ds.py through
# `accelerate` with the ds_zero3.yaml config, in the background under nohup.
# stdout and stderr are both redirected to the log file.
cd /mnt/qmtang/EvoFill/
OMP_NUM_THREADS=8 nohup accelerate launch \
  --config_file ds_zero3.yaml \
  stage2_1kGP-posttrain_ds.py \
  > logs/post_chr22_251127.log 2>&1 &

### 2.2 Fine-tuning (Few-shot URP)

In [None]:
%%bash
# Fine-tuning: launch stage4_finetuning_ds.py through `accelerate` with the
# ds_zero3.yaml config, in the background under nohup. stdout and stderr are
# both redirected to the log file.
cd /mnt/qmtang/EvoFill/
OMP_NUM_THREADS=8 nohup accelerate launch \
  --config_file ds_zero3.yaml \
  stage4_finetuning_ds.py \
  > logs/finetune_chr22_251125.log 2>&1 &

### 2.3 Merge weight files

In [None]:
%%bash
# Run zero_to_fp32.py inside the stage-1 checkpoint directory, reading from
# "." and writing to "./" (presumably the DeepSpeed helper that consolidates
# sharded ZeRO weights into one fp32 state dict — confirm).
# NOTE(review): this path uses the 20251204_chr22 run and "checkpoint-stage1",
# whereas the loading cell reads 20251205_chr22 and
# "checkpoint-stage1-1207-195430" — confirm which checkpoint is intended.
cd /mnt/qmtang/EvoFill_data/20251204_chr22/models/checkpoint-stage1/
python zero_to_fp32.py . ./

## 3. Imputation

### 3.1 Load the trained model

Choose a `work_dir` that contains a `models/` directory holding the trained model.

In [7]:
work_dir = Path('/mnt/qmtang/EvoFill_data/20251205_chr22')
os.chdir(work_dir)
print(f"Work Dir: {work_dir}")

# ---- 1. Load the model ----
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hyper-parameters are read from the JSON metadata stored next to the
# checkpoints. The file is opened in a context manager so the handle is
# closed deterministically.
json_path = f"{work_dir}/models/model_meta.json"
with open(json_path, encoding="utf-8") as meta_fh:
    meta = json.load(meta_fh)

# Rebuild the network with the stored hyper-parameters, then move it to the
# selected device.
model = EvoFill(
    n_alleles=int(meta["alleles"]),
    total_sites=int(meta["total_sites"]),
    chunk_size=int(meta["chunk_size"]),
    chunk_overlap=int(meta["overlap"]),
    d_model=int(meta["d_model"]),
    d_state=int(meta["d_state"]),
    headdim=int(meta["headdim"]),
).to(device)

# The weights are deserialised on the CPU first; load_state_dict copies them
# into the parameters already living on `device`.
ckpt_path = f"{work_dir}/models/checkpoint-stage1-1207-195430/pytorch_model.bin"
state_dict = torch.load(ckpt_path, map_location="cpu")

model.load_state_dict(state_dict)
total_params = sum(p.numel() for p in model.parameters())
model.eval()  # inference mode for the cells that follow
print(f'[INF] Model[{meta["model_name"]}] loaded.')
print(f"[INF] Total params: {total_params:,}")

Work Dir: /mnt/qmtang/EvoFill_data/20251205_chr22
[INF] Model[hg38_chr22plus] loaded.
[INF] Total params: 14,792,729


### 3.2 Inferring

In [None]:
# Re-encode the masked imputation input against the training encoder's
# metadata (same layout as Stage 1) and save it under <work_dir>/impute_in.
# A commented-out alternative file name was kept next to the path in the
# original cell: "HLA_minor_pops.90pct.masked90p.vcf.gz".
impute_in_dir = work_dir / "impute_in"
gt_enc_imp = GenotypeEncoder(
    phased=False,
    gts012=False,
    save2disk=True,
    save_dir=impute_in_dir,
).encode_ref(
    ref_meta_json=work_dir / "train" / "gt_enc_meta.json",
    vcf_path=impute_in_dir / "minor_pops.90pct.masked90p.vcf.gz",
)
print(f'[INFER] {gt_enc_imp.n_samples} samples, {gt_enc_imp.n_variants} variants')

In [9]:
# Reload the encoded imputation input that an earlier cell saved to disk.
gt_enc_imp = GenotypeEncoder.loadfromdisk(Path(work_dir / "impute_in"))

print(f'[INFER] {gt_enc_imp.n_samples} samples, {gt_enc_imp.n_variants} variants')

# ---- 2. Build the inference Dataset / DataLoader ----
imp_dataset = ImputationDataset(
    x_gts_sparse=gt_enc_imp.X_gt,
    seq_depth=gt_enc_imp.seq_depth,
    indices=None                 # an explicit list of sample indices may be passed
)
imp_dataset.print_missing_stat()          # report the raw missing rate


def collate_fn(batch):
    """Stack the per-sample tensors of a batch; there is no target tensor.

    Args:
        batch: sequence of dataset items; element 0 of each item is a tensor
            and element 1 is that item's index entry.

    Returns:
        A tuple ``(x_onehot, real_idx_list)``: the stacked tensors and the
        list of index entries in batch order.
    """
    x_onehot = torch.stack([item[0] for item in batch])
    real_idx_list = [item[1] for item in batch]
    return x_onehot, real_idx_list   # no y


imp_loader = torch.utils.data.DataLoader(
    imp_dataset,
    batch_size=64,
    shuffle=False,               # keep sample order aligned with gt_enc_imp
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn
)

# ---- 3. Forward pass ----
y_prob = []
y_mask = []
with torch.no_grad():
    for x_onehot, _real_idx in tqdm(imp_loader, desc='Imputing'):
        x_onehot = x_onehot.to(device)
        _, prob, _ = model(x_onehot)
        # The last channel of the input is used as the "missing" flag
        # (presumably 1 where the genotype was masked — confirm in
        # ImputationDataset).
        miss_mask = x_onehot[..., -1].bool()
        # Move each batch to host memory right away so that only one batch
        # of outputs is resident on the GPU at any time.
        y_prob.append(prob.cpu())
        y_mask.append(miss_mask.cpu())
y_prob = torch.cat(y_prob, dim=0).numpy()
y_mask = torch.cat(y_mask, dim=0).numpy()

# ---- 4. Save ----
out_dir = os.path.join(work_dir, 'impute_out')
os.makedirs(out_dir, exist_ok=True)
np.save(os.path.join(out_dir, 'impute_prob.npy'), y_prob)
np.save(os.path.join(out_dir, 'impute_mask.npy'), y_mask)
print(f'[INF] 概率矩阵已保存 → {out_dir}/impute_prob.npy '
      f'with shape = {y_prob.shape} ')


[INFER] 269 samples, 190184 variants
[ImputationDataset] 269 samples, missing rate = 90.01%


Imputing: 100%|██████████| 5/5 [00:10<00:00,  2.18s/it]


[INF] 概率矩阵已保存 → /mnt/qmtang/EvoFill_data/20251205_chr22/impute_out/impute_prob.npy with shape = (269, 190184, 3) 


### 3.4 Evaluating the imputation results

In [10]:
# Encode the unmasked ground-truth VCF in memory only (save2disk=False),
# against the training encoder's metadata (same layout as Stage 1).
truth_vcf = work_dir / "impute_out" / "minor_pops.90pct.vcf.gz"
gt_enc_true = GenotypeEncoder(
    phased=False,
    gts012=False,
    save2disk=False,
).encode_ref(
    ref_meta_json=work_dir / "train" / "gt_enc_meta.json",
    vcf_path=truth_vcf,
)

# Dense truth matrix and per-site MAF statistics.
y_true = gt_enc_true.X_gt.toarray()
maf, bin_cnt = precompute_maf(y_true)

# One-hot encode the truth by indexing an identity matrix with the genotype
# codes; the number of classes is seq_depth minus one.
# NOTE(review): a negative (missing) code in y_true would select the last
# row of the identity matrix — assumes the truth has no missing calls.
n_classes = gt_enc_true.seq_depth - 1
y_true_oh = np.eye(n_classes)[y_true]

# Metrics per MAF bin, restricted by y_mask, then the summary table.
bins_metrics = metrics_by_maf(
    y_prob,
    y_true_oh,
    hap_map=gt_enc_true.hap_map,
    maf_vec=maf,
    mask=y_mask,
)
print_maf_stat_df(bin_cnt, {'val': bins_metrics})

[DATA] 总计 190,184 个位点  
[DATA] 位点矩阵 = (269, 190184)，稀疏度 = 25.95%，缺失率 = 0.00%
[DATA] 位点字典 = {'0|0': 0, '0|1': 1, '1|1': 2, '.|.': -1}，字典深度 = 4
     MAF_bin Counts val_Acc val_INFO val_IQS val_MaCH
(0.00, 0.05)  93213   0.970    0.142   0.169    0.287
(0.05, 0.10)  22681   0.909    0.387   0.428    0.626
(0.10, 0.20)  25627   0.859    0.514   0.557    0.801
(0.20, 0.30)  19171   0.822    0.595   0.640    0.912
(0.30, 0.40)  15660   0.805    0.633   0.673    0.960
(0.40, 0.50)  13727   0.788    0.631   0.669    0.977


chr22 AADR 中总计包含 14,867 个位点（删去1kGP中不存在、非 biallelic 位点）

测试集：90% 少数族裔样本（CDX, BEB, ASW），90% masked 位点

训练集：

**= 1kGP ONLY =**

|      MAF_bin | Counts | val_Acc | val_INFO | val_IQS | val_MaCH |
| :----------: | :----: | ------: | -------: | ------: | -------: |
| (0.00, 0.05) |   1546 |   0.972 |    0.060 |   0.068 |    0.166 |
| (0.05, 0.10) |   1246 |   0.931 |    0.198 |   0.216 |    0.367 |
| (0.10, 0.20) |   2205 |   0.875 |    0.332 |   0.353 |    0.545 |
| (0.20, 0.30) |   3112 |   0.798 |    0.487 |   0.495 |    0.754 |
| (0.30, 0.40) |   5193 |   0.755 |    0.550 |   0.552 |    0.853 |
| (0.40, 0.50) |   1565 |   0.739 |    0.548 |   0.545 |    0.845 |

**= aDNA -> 1kGP =** 

|      MAF_bin | Counts | val_Acc | val_INFO | val_IQS | val_MaCH |
| :----------: | :----: | ------: | -------: | ------: | -------: |
| (0.00, 0.05) |   1546 |   0.972 |    0.089 |   0.085 |    0.193 |
| (0.05, 0.10) |   1246 |   0.935 |    0.280 |   0.283 |    0.463 |
| (0.10, 0.20) |   2205 |   0.890 |    0.452 |   0.456 |    0.685 |
| (0.20, 0.30) |   3112 |   0.824 |    0.592 |   0.574 |    0.850 |
| (0.30, 0.40) |   5193 |   0.787 |    0.645 |   0.614 |    0.906 |
| (0.40, 0.50) |   1565 |   0.770 |    0.634 |   0.599 |    0.894 |

**= 1kGP -> aDNA =** 

|      MAF_bin | Counts | val_Acc | val_INFO | val_IQS | val_MaCH |
| :----------: | :----: | ------: | -------: | ------: | -------: |
| (0.00, 0.05) |   1546 |   0.938 |    0.109 |   0.227 |    0.734 |
| (0.05, 0.10) |   1246 |   0.893 |    0.174 |   0.396 |    0.633 |
| (0.10, 0.20) |   2205 |   0.810 |    0.196 |   0.435 |    0.545 |
| (0.20, 0.30) |   3112 |   0.690 |    0.144 |   0.433 |    0.499 |
| (0.30, 0.40) |   5193 |   0.640 |    0.101 |   0.429 |    0.503 |
| (0.40, 0.50) |   1565 |   0.604 |    0.102 |   0.398 |    0.464 |


### 3.5 Saving to .vcf

In [38]:
import subprocess

from cyvcf2 import VCF, Writer

# 0. Paths
# The template must be the very VCF that was encoded into gt_enc_imp (and
# therefore produced y_prob); otherwise samples and sites do not line up.
ref_vcf = work_dir / "impute_in" / "minor_pops.90pct.masked90p.vcf.gz"
out_vcf = os.path.join(out_dir, 'imputed.vcf.gz')

# 1. Shape check against the encoder that produced the model input.
n_site = gt_enc_imp.n_variants
n_samp = gt_enc_imp.n_samples
n_alleles = gt_enc_imp.seq_depth - 1
if y_prob.shape != (n_samp, n_site, n_alleles):
    raise ValueError(
        f"y_prob has shape {y_prob.shape}, expected {(n_samp, n_site, n_alleles)}"
    )

# 2. Reverse mapping: class index -> (allele1, allele2).
# Negative codes (the missing genotype) are skipped, and both '|' and '/'
# separated keys are accepted because the encoder dictionary may hold either.
allele_lut = {}
for gt_str, code in gt_enc_imp.hap_map.items():
    if code < 0:
        continue
    allele_lut[code] = tuple(int(a) for a in gt_str.replace('/', '|').split('|'))

# 3. Most probable class for every (sample, site), computed once.
best_code = y_prob.argmax(axis=-1)

# Row of y_prob belonging to each sample ID.
samp2idx = {sid: i for i, sid in enumerate(gt_enc_imp.sample_ids)}

# 4. Open the template VCF and stream the records into the output.
invcf = VCF(str(ref_vcf))
try:
    invcf.set_samples(gt_enc_imp.sample_ids)   # restrict the sample columns
    # Use the VCF's own column order to look up rows of y_prob, so the
    # result does not depend on the encoder and the file agreeing on order.
    rows = [samp2idx[sid] for sid in invcf.samples]

    out = Writer(out_vcf, invcf, mode='wz')
    n_written = 0
    try:
        # NOTE(review): assumes the records of ref_vcf appear in the same
        # order as the encoder's variant axis — confirm against encode_ref.
        for rec_idx, rec in enumerate(invcf):
            if rec_idx >= n_site:
                raise ValueError(
                    f"{ref_vcf} has more records than the {n_site} encoded sites"
                )
            # Parse the genotypes once per record:
            # one [allele1, allele2, phased] entry per sample.
            calls = rec.genotypes
            changed = False
            for col, row in enumerate(rows):
                old_gt = calls[col]
                if old_gt[0] == -1 or old_gt[1] == -1:    # missing call
                    a1, a2 = allele_lut[int(best_code[row, rec_idx])]
                    calls[col] = [a1, a2, old_gt[2]]      # keep the phase flag
                    changed = True
            if changed:
                # Assigning to `genotypes` lets cyvcf2 encode the calls.
                rec.genotypes = calls
            out.write_record(rec)
            n_written += 1
    finally:
        out.close()
finally:
    invcf.close()

if n_written != n_site:
    raise ValueError(
        f"{ref_vcf} yielded {n_written} records but {n_site} sites were encoded"
    )

# 5. tabix index; a failing tabix call raises instead of being ignored.
subprocess.run(['tabix', '-f', '-p', 'vcf', out_vcf], check=True)
print(f'[INF] 缺失位点填充完成 → {out_vcf}')

[INF] 缺失位点填充完成 → /mnt/qmtang/EvoFill_data/20251107_ver4/impute_out/imputed.vcf.gz
