# EvoFill working demo

ver 4.   new imputation loss with evo loss

ver 4.1  long range modules integrated in stage1 training

ver 4.2  stage 3 fine tuning with under-reprensted population samples.

last update: 2025/11/13

## Content
- [0. Dependency](#0.-Dependency)  
- [1. Genotype(.vcf) encoding](#1.-Genotype(.vcf)-encoding)  
- [2. Training a new model](#2.-Training-a-new-model)  
  - [2.1 Dataloader](#2.1-Dataloader)  
  - [2.2 Model initialization](#2.2-Model-initialization)  
  - [2.3 stage 1: Chunk Module Training](#2.3-stage-1:-Chunk-Module-Training)  
  - [2.4 stage 2: Ultra-Long-Range LD Module Training](#2.4-stage-2:-Ultra-Long-Range-LD-Module-Training)  
  - [2.5 stage 3:](#2.5-)
- [3. Imputation using trained model](#3.-Imputation-using-trained-model)  
  - [3.1 Load the trained model](#3.1-Load-the-trained-model)  
  - [3.2 Encode .vcf file need be impute](#3.2-Encode-.vcf-file-need-be-impute)  
  - [3.3 Inferring](#3.3-Inferring)  
  - [3.4 Evaluating the imputation results](#3.4-Evaluating-the-imputation-results)  
  - [3.5 Saving to .vcf](#3.5-Saving-to-.vcf)

In jupyter notebook, we use single GPU here for demo.

For parallel GPU training, please use other framework with `./train.py`, like `torch run --nproc_per_node=8 train.py` 

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0" 
os.chdir('/mnt/qmtang/EvoFill/')

## 0. Dependency

In [2]:
import sys
import json
import numpy as np
from pathlib import Path
import pandas as pd
import torch
import mamba_ssm
from tqdm import tqdm
from itertools import combinations
from torch.utils.data import DataLoader
from torch.optim import AdamW, SparseAdam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.model_selection import train_test_split

print('Python ver:  ', sys.version)
print('Pytorch ver: ', torch.__version__)
print('Mamba ver:   ', mamba_ssm.__version__)
print('GPU in use:  ', torch.cuda.get_device_name(torch.cuda.current_device()))

Python ver:   3.10.19 | packaged by conda-forge | (main, Oct 13 2025, 14:08:27) [GCC 14.3.0]
Pytorch ver:  2.8.0+cu129
Mamba ver:    2.2.5
GPU in use:   NVIDIA H100 PCIe


In [None]:
from src.data import GenotypeEncoder, GenomicDataset, GenomicDataset_Missing, ImputationDataset
from src.model import EvoFill
from src.loss import ImputationLoss, ImputationLoss_Missing
from src.utils import create_directories, set_seed, precompute_maf, metrics_by_maf, print_maf_stat_df

In [4]:
work_dir = Path('/mnt/qmtang/EvoFill_data/20251107_ver4')
create_directories(work_dir)
os.chdir(work_dir)

## 1. Genotype(.vcf) encoding

option 1. with extra evolutionary information

In [5]:
gt_enc = GenotypeEncoder(phased = False, gts012 = False, 
                         save2disk = True, save_dir = Path(work_dir / "pre_train"))

gt_enc = gt_enc.encode_new(vcf_path = Path(work_dir / "pre_train" / "major_pops.vcf.gz"),
                           evo_mat = Path(work_dir / "pre_train" / "evo_mat_major_pops.tsv"))

print(f"[DATA] {gt_enc.n_samples:,} Samples")
print(f"[DATA] {gt_enc.n_variants:,} Variants Sites")
print(f"[DATA] {gt_enc.seq_depth} seq_depth")

[DATA] 总计 99,314 个位点  
[DATA] EvoMat shape: (2236, 2236)
[DATA] 结果已写入 /mnt/qmtang/EvoFill_data/20251107_ver4/pre_train
[DATA] 位点矩阵 = (2236, 99314)，稀疏度 = 28.19%
[DATA] 位点字典 = {'0|0': 0, '0|1': 1, '1|1': 2, '.|.': 3}，字典深度 = 4
[DATA] 2,236 Samples
[DATA] 99,314 Variants Sites
[DATA] 4 seq_depth


option 2. no extra evolutionary information

In [None]:
gt_enc = GenotypeEncoder(phased = False, gts012 = False, 
                         save2disk = True, save_dir = Path(work_dir / "pre_train"))

gt_enc = gt_enc.encode_new(vcf_path = Path(work_dir / "pre_train" / "major_pops.vcf.gz"),
                           evo_mat = None)

print(f"[DATA] {gt_enc.n_samples:,} Samples")
print(f"[DATA] {gt_enc.n_variants:,} Variants Sites")
print(f"[DATA] {gt_enc.seq_depth} seq_depth")

[DATA] 总计 99,314 个位点  
[DATA] 结果已写入 /mnt/qmtang/EvoFill_data/20251110_ver4/
[DATA] 位点矩阵 = (2236, 99314)，稀疏度 = 28.19%
[DATA] 位点字典 = {'0|0': 0, '0|1': 1, '1|1': 2, '.|.': 3}，字典深度 = 4
[DATA] 2,236 Samples
[DATA] 99,314 Variants Sites
[DATA] 4 seq_depth


## 2. Training a new model

Choose a path which including GenotypeEncoder processed files.

### 2.1 Dataloader

In [5]:
test_n_samples = 128
batch_size    = 16
max_mr        = 0.7
min_mr        = 0.3

gt_enc = GenotypeEncoder.loadfromdisk(Path(work_dir / "pre_train"))
print(f'{gt_enc.n_samples:,} samples, {gt_enc.n_variants:,} variants, {gt_enc.seq_depth} seq-depth.')

x_train_indices, x_test_indices = train_test_split(
    range(gt_enc.n_samples),
    test_size=test_n_samples,
    random_state=3047,
    shuffle=True
)

print(f"{len(x_train_indices):,} samples in train")
print(f"{len(x_test_indices):,} samples in test")
# print(f"evo_mat: {gt_enc.evo_mat.shape}")

train_dataset = GenomicDataset(
    gt_enc.X_gt,
    evo_mat=gt_enc.evo_mat,
    seq_depth=gt_enc.seq_depth,
    mask=True,
    masking_rates=(min_mr, max_mr),
    indices=x_train_indices
)

test_dataset = GenomicDataset(
    gt_enc.X_gt,
    evo_mat=gt_enc.evo_mat,
    seq_depth=gt_enc.seq_depth,
    mask=True,
    masking_rates=(min_mr, max_mr),
    indices=x_test_indices
)

def collate_fn(batch):
    x_onehot = torch.stack([item[0] for item in batch])
    y_onehot = torch.stack([item[1] for item in batch])
    real_idx_list = [item[2] for item in batch]

    # 提取 evo_mat 子矩阵
    if train_dataset.evo_mat is not None:
        evo_mat_batch = train_dataset.evo_mat[np.ix_(real_idx_list, real_idx_list)]
        evo_mat_batch = torch.FloatTensor(evo_mat_batch)
    else:
        evo_mat_batch = torch.empty(0)

    return x_onehot, y_onehot, evo_mat_batch
    
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn
)

2,236 samples, 99,314 variants, 4 seq-depth.
2,108 samples in train
128 samples in test


### 2.2 Model initialization

In [6]:
model_name  = 'hg19_chr22trim'
total_sites = gt_enc.n_variants
alleles     = gt_enc.seq_depth
chunk_size  = 32768
overlap     = 1024
d_model     = 64

set_seed(42)

# Setup device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.cuda.empty_cache()
print(f"Using device: {device}")

model = EvoFill(d_model, alleles, total_sites, chunk_size, overlap).to(device)
print(f"model[{model_name}] would have {model.n_chunks} chunks.")

criterion = ImputationLoss(use_r2=True, use_evo=True, r2_weight=1, evo_weight=4, evo_lambda=10)
# criterion = ImputationLoss(use_r2=True, use_evo=False)

meta = {
    "model_name": model_name,
    "total_sites": total_sites,
    "alleles": alleles,
    "chunk_size": chunk_size,
    "overlap": overlap,
    "d_model": d_model
}
save_path = os.path.join(Path(work_dir / "models"), "model_meta.json")
with open(save_path, "w") as f:
    json.dump(meta, f, indent=4)

Using device: cuda
model[hg19_chr22trim] would have 4 chunks.


### 2.3 stage 1: Chunk Module Training

In [None]:
verbose            = False
max_epochs         = 100
lr                 = 0.001
weight_decay       = 1e-5
earlystop_patience = 11
scheduler_factor   = 0.5
scheduler_patience = 5
scheduler_min_lr   = 1e-8


for cid in range(model.n_chunks):
    chunk_mask = model.chunk_masks[cid].cpu()
    chunk_maf, chunk_bin_cnt = precompute_maf(gt_enc.X_gt[:,chunk_mask.bool().cpu().numpy()].toarray(),  mask_int=gt_enc.seq_depth)
    if verbose:
        print(f"=== Chunk {cid + 1} STAT ===")
        maf_df = pd.DataFrame({
            'MAF_bin': ['(0.00, 0.05)', '(0.05, 0.10)', '(0.10, 0.20)',
                        '(0.20, 0.30)', '(0.30, 0.40)', '(0.40, 0.50)'],
            'Counts':  [f"{c}" for c in chunk_bin_cnt],
        })
        print(maf_df.to_string(index=False))

    # 收集所有稀疏参数（主要是嵌入层）
    sparse_params = []
    dense_params = []
    # 当前chunk的模块参数（密集）
    dense_params.extend(model.chunk_modules[cid].parameters())
    
    # 全局输出层的卷积参数（密集）
    dense_params.extend([model.global_out.w1, model.global_out.b1])
    dense_params.extend([model.global_out.w2, model.global_out.b2])
    # ULR默认启用
    if hasattr(model.global_out, 'ulr_mamba'):
        for name, param in model.global_out.ulr_mamba.named_parameters():
            if 'idx_embed' in name:
                sparse_params.append(param)
            else:
                dense_params.append(param)
    
    # 创建分离优化器
    optim_sparse = SparseAdam(sparse_params, lr=lr) if sparse_params else None
    optim_dense = AdamW(dense_params, lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999))
    
    # 学习率调度器
    scheduler_sparse = ReduceLROnPlateau(optim_sparse, mode='min', factor=scheduler_factor, 
                        patience=scheduler_patience, min_lr=scheduler_min_lr) if optim_sparse else None
    scheduler_dense = ReduceLROnPlateau(optim_dense, mode='min', factor=scheduler_factor, 
                        patience=scheduler_patience, min_lr=scheduler_min_lr)
    
    best_loss = float('inf')
    patience = earlystop_patience
    patience_counter = 0
    is_early_stopped = False
    train_logs_sum = None
    test_logs_sum   = None
    for epoch in range(max_epochs):
        model.train()
        train_loss = 0.0
        train_prob, train_gts, train_mask = [], [], []

        train_pbar = tqdm(train_loader, desc=f'Chunk {cid + 1}/{model.n_chunks}, Epoch {epoch + 1}/{max_epochs}',) # leave=False
        for batch_idx, (x, target, evo_mat) in enumerate(train_pbar):
            x,  target = x.to(device), target.to(device)
            if evo_mat.numel() == 0:
                evo_mat = None
            else:
                evo_mat = evo_mat.to(device)

            # 清零梯度
            if optim_sparse:
                optim_sparse.zero_grad()
            optim_dense.zero_grad()
            
            # 前向传播
            logits, prob, mask_idx = model(x, cid)
            loss, logs = criterion(logits[:, mask_idx], prob[:, mask_idx], target[:, mask_idx], evo_mat)
            
            # 反向传播
            loss.backward()
            
            # 更新参数
            if optim_sparse:
                optim_sparse.step()
            optim_dense.step()

            train_loss += loss.item()
            if train_logs_sum is None:          # 第一次初始化
                train_logs_sum = {k: 0.0 for k in logs}
            for k, v in logs.items():
                train_logs_sum[k] += v
            # train_pbar.set_postfix({'loss': loss.item(), 'ce':logs['ce'], 'r2':logs['r2'], 'evo':logs['evo']})

            # === 收集训练结果 ===
            miss_mask = x[:, mask_idx][..., -1].bool()         # 只关心被 mask 的位点
            train_prob.append(prob[:, mask_idx].detach())
            train_gts.append(target[:,mask_idx].detach())
            train_mask.append(miss_mask)

        # 训练集 MAF-acc
        train_prob = torch.cat(train_prob, dim=0)
        train_gts  = torch.cat(train_gts,    dim=0)
        train_mask = torch.cat(train_mask,   dim=0)

        # ----------- 验证循环同理 ------------
        model.eval()
        test_loss = 0.0
        test_prob, test_gts = [], []
        if test_logs_sum is None:
            test_logs_sum = {k: 0.0 for k in train_logs_sum}
        with torch.no_grad():
            for x, target, evo_mat in test_loader:
                x,  target = x.to(device), target.to(device)
                if evo_mat.numel() == 0:
                    evo_mat = None
                else:
                    evo_mat = evo_mat.to(device)
                logits, prob, mask_idx = model(x, cid)
                loss, logs = criterion(logits[:, mask_idx], prob[:, mask_idx], target[:,mask_idx], evo_mat) 
                test_loss += loss.item()
                for k, v in logs.items():
                    test_logs_sum[k] += v
                test_prob.append(prob[:, mask_idx].detach())
                test_gts.append(target[:,mask_idx].detach())

        test_prob = torch.cat(test_prob, dim=0)
        test_gts    = torch.cat(test_gts,    dim=0)
        avg_train_loss = train_loss / len(train_loader)
        avg_test_loss   = test_loss   / len(test_loader)
        avg_train_logs = {k: v / len(train_loader) for k, v in train_logs_sum.items()}
        avg_test_logs = {k: v / len(test_loader) for k, v in test_logs_sum.items()}
        
        # 更新学习率
        if scheduler_sparse:
            scheduler_sparse.step(avg_test_loss)
        scheduler_dense.step(avg_test_loss)
        
        current_denselr = optim_dense.param_groups[0]['lr']
        current_sparselr = optim_sparse.param_groups[0]['lr'] if optim_sparse else 0

        log_str = (f'Chunk {cid + 1}/{model.n_chunks}, '
            f'Epoch {epoch + 1}/{max_epochs}, '
            f'Total Loss, Train = {avg_train_loss:.1f}, '
            f'Test = {avg_test_loss:.1f}, '
            f'LR: {current_denselr:.2e}')

        log_str += '\n        Train'
        for k, v in avg_train_logs.items():
            log_str += f', {k}: {v:.1f}'
        log_str += '\n        Test '
        for k, v in avg_test_logs.items():
            log_str += f', {k}: {v:.1f}'
        print(log_str)

        # 清空累加器，供下一个 epoch 使用
        train_logs_sum = {k: 0.0 for k in train_logs_sum}
        test_logs_sum   = {k: 0.0 for k in test_logs_sum}
        
        # Early stopping
        if avg_test_loss < best_loss:
            best_loss = avg_test_loss
            patience_counter = 0
            # 只存当前 chunk 专家 + 全局层
            ckpt = {
                'chunk_id': cid,
                'chunk_embed_state': model.chunk_embeds[cid].state_dict(),
                'chunk_module_state': model.chunk_modules[cid].state_dict(),
                'global_out_state': model.global_out.state_dict(),
                'best_val_loss': best_loss,
            }
            torch.save(ckpt, f'{work_dir}/models/{model_name}_chunk[{cid}].pth')
            predres_with_bestloss = (train_prob, train_gts, test_prob, test_gts)
            if verbose:
                train_bins_metrics = metrics_by_maf(train_prob, train_gts, hap_map = gt_enc.hap_map, maf_vec = chunk_maf, mask=train_mask)
                test_bins_metrics   = metrics_by_maf(test_prob,  test_gts, hap_map = gt_enc.hap_map, maf_vec = chunk_maf, mask=None)
                print_maf_stat_df(chunk_bin_cnt,
                      {"train": train_bins_metrics,
                       "test":  test_bins_metrics})
                print(f'  --> updated {model_name}_chunk[{cid}].pth')
        else:
            patience_counter += 1
            if patience_counter >= earlystop_patience:
                is_early_stopped = True
                print(f'Chunk {cid + 1}/{model.n_chunks}, Early stopping triggered')
                train_prob, train_gts, test_prob, test_gts = predres_with_bestloss
                train_bins_metrics = metrics_by_maf(train_prob, train_gts, hap_map = gt_enc.hap_map, maf_vec = chunk_maf, mask=train_mask)
                test_bins_metrics   = metrics_by_maf(test_prob,   test_gts, hap_map = gt_enc.hap_map, maf_vec = chunk_maf, mask=None)
                print_maf_stat_df(chunk_bin_cnt,
                      {"train": train_bins_metrics,
                       "test":   test_bins_metrics})
                break

    if not is_early_stopped:
        train_bins_metrics = metrics_by_maf(train_prob, train_gts, hap_map = gt_enc.hap_map, maf_vec = chunk_maf, mask=train_mask)
        test_bins_metrics   = metrics_by_maf(test_prob,   test_gts, hap_map = gt_enc.hap_map, maf_vec = chunk_maf, mask=None)
        print_maf_stat_df(chunk_bin_cnt,
                      {"train": train_bins_metrics,
                       "test":   test_bins_metrics})

    # 清理优化器
    del optim_sparse, optim_dense
    if scheduler_sparse:
        del scheduler_sparse
    del scheduler_dense
    torch.cuda.empty_cache()

# ---------------- 全部 chunk 训练完成 -> 保存完整模型 ----------------
final_ckpt = {
    'model_state': model.state_dict(),
    'n_chunks': model.n_chunks,
    'chunk_size': model.chunk_size,
    'chunk_overlap': model.chunk_overlap,
}
torch.save(final_ckpt, f'{work_dir}/models/{model_name}_stage1.pth')
print(f'==> STAGE1 (Chunk Module) training finished: {work_dir}/models/{model_name}_stage1.pth')

Chunk 1/4, Epoch 1/100: 100%|██████████| 132/132 [01:44<00:00,  1.26it/s]


Chunk 1/4, Epoch 1/100, Total Loss, Train = 170168.9, Val = 91773.3, LR: 1.00e-03
        Train, ce: 170168.9, r2: -8356.9, evo: 20506.9
        Val  , ce: 91773.3, r2: -10818.7, evo: 19573.6


Chunk 1/4, Epoch 2/100: 100%|██████████| 132/132 [01:09<00:00,  1.89it/s]


Chunk 1/4, Epoch 2/100, Total Loss, Train = 78159.1, Val = 66247.0, LR: 1.00e-03
        Train, ce: 78159.1, r2: -11029.0, evo: 18623.2
        Val  , ce: 66247.0, r2: -11783.8, evo: 15786.5


Chunk 1/4, Epoch 3/100: 100%|██████████| 132/132 [01:10<00:00,  1.86it/s]


Chunk 1/4, Epoch 3/100, Total Loss, Train = 63727.7, Val = 62764.0, LR: 1.00e-03
        Train, ce: 63727.7, r2: -11627.8, evo: 16587.0
        Val  , ce: 62764.0, r2: -12092.8, evo: 14085.9


Chunk 1/4, Epoch 4/100: 100%|██████████| 132/132 [01:11<00:00,  1.85it/s]


Chunk 1/4, Epoch 4/100, Total Loss, Train = 56631.2, Val = 53795.7, LR: 1.00e-03
        Train, ce: 56631.2, r2: -11866.9, evo: 15397.9
        Val  , ce: 53795.7, r2: -12344.6, evo: 13168.1


Chunk 1/4, Epoch 5/100: 100%|██████████| 132/132 [01:09<00:00,  1.90it/s]


Chunk 1/4, Epoch 5/100, Total Loss, Train = 51525.9, Val = 50172.6, LR: 1.00e-03
        Train, ce: 51525.9, r2: -12095.0, evo: 14443.2
        Val  , ce: 50172.6, r2: -12584.9, evo: 14575.9


Chunk 1/4, Epoch 6/100: 100%|██████████| 132/132 [01:10<00:00,  1.87it/s]


Chunk 1/4, Epoch 6/100, Total Loss, Train = 48550.5, Val = 47136.6, LR: 1.00e-03
        Train, ce: 48550.5, r2: -12193.9, evo: 13843.2
        Val  , ce: 47136.6, r2: -12581.9, evo: 13128.7


Chunk 1/4, Epoch 7/100: 100%|██████████| 132/132 [01:10<00:00,  1.87it/s]


Chunk 1/4, Epoch 7/100, Total Loss, Train = 45062.8, Val = 49123.1, LR: 1.00e-03
        Train, ce: 45062.8, r2: -12300.4, evo: 13201.4
        Val  , ce: 49123.1, r2: -12708.4, evo: 11923.0


Chunk 1/4, Epoch 8/100: 100%|██████████| 132/132 [01:10<00:00,  1.87it/s]


Chunk 1/4, Epoch 8/100, Total Loss, Train = 42822.0, Val = 40542.0, LR: 1.00e-03
        Train, ce: 42822.0, r2: -12397.0, evo: 12717.6
        Val  , ce: 40542.0, r2: -12862.7, evo: 12651.0


Chunk 1/4, Epoch 9/100: 100%|██████████| 132/132 [01:10<00:00,  1.88it/s]


Chunk 1/4, Epoch 9/100, Total Loss, Train = 41126.7, Val = 39755.2, LR: 1.00e-03
        Train, ce: 41126.7, r2: -12485.3, evo: 12407.4
        Val  , ce: 39755.2, r2: -12860.7, evo: 11628.4


Chunk 1/4, Epoch 10/100: 100%|██████████| 132/132 [01:11<00:00,  1.86it/s]


Chunk 1/4, Epoch 10/100, Total Loss, Train = 39622.1, Val = 41968.3, LR: 1.00e-03
        Train, ce: 39622.1, r2: -12526.8, evo: 12106.9
        Val  , ce: 41968.3, r2: -12951.7, evo: 11494.8


Chunk 1/4, Epoch 11/100: 100%|██████████| 132/132 [01:10<00:00,  1.87it/s]


Chunk 1/4, Epoch 11/100, Total Loss, Train = 38179.9, Val = 37982.1, LR: 1.00e-03
        Train, ce: 38179.9, r2: -12567.1, evo: 11796.1
        Val  , ce: 37982.1, r2: -13016.1, evo: 10898.5


Chunk 1/4, Epoch 12/100: 100%|██████████| 132/132 [01:10<00:00,  1.86it/s]


Chunk 1/4, Epoch 12/100, Total Loss, Train = 37036.7, Val = 33193.1, LR: 1.00e-03
        Train, ce: 37036.7, r2: -12636.9, evo: 11556.7
        Val  , ce: 33193.1, r2: -13104.0, evo: 10495.1


Chunk 1/4, Epoch 13/100: 100%|██████████| 132/132 [01:10<00:00,  1.87it/s]


Chunk 1/4, Epoch 13/100, Total Loss, Train = 35963.6, Val = 36950.4, LR: 1.00e-03
        Train, ce: 35963.6, r2: -12637.1, evo: 11323.3
        Val  , ce: 36950.4, r2: -13084.7, evo: 11767.1


Chunk 1/4, Epoch 14/100: 100%|██████████| 132/132 [01:10<00:00,  1.87it/s]


Chunk 1/4, Epoch 14/100, Total Loss, Train = 35045.5, Val = 33393.2, LR: 1.00e-03
        Train, ce: 35045.5, r2: -12676.5, evo: 11147.6
        Val  , ce: 33393.2, r2: -13115.4, evo: 10424.2


Chunk 1/4, Epoch 15/100: 100%|██████████| 132/132 [01:10<00:00,  1.88it/s]


Chunk 1/4, Epoch 15/100, Total Loss, Train = 34044.2, Val = 34036.0, LR: 1.00e-03
        Train, ce: 34044.2, r2: -12741.7, evo: 10916.0
        Val  , ce: 34036.0, r2: -13061.5, evo: 11010.6


Chunk 1/4, Epoch 16/100: 100%|██████████| 132/132 [01:09<00:00,  1.89it/s]


Chunk 1/4, Epoch 16/100, Total Loss, Train = 33233.6, Val = 31978.2, LR: 1.00e-03
        Train, ce: 33233.6, r2: -12745.8, evo: 10748.4
        Val  , ce: 31978.2, r2: -13217.2, evo: 11108.8


Chunk 1/4, Epoch 17/100: 100%|██████████| 132/132 [01:10<00:00,  1.87it/s]


Chunk 1/4, Epoch 17/100, Total Loss, Train = 32216.4, Val = 33772.2, LR: 1.00e-03
        Train, ce: 32216.4, r2: -12786.8, evo: 10523.2
        Val  , ce: 33772.2, r2: -13136.1, evo: 9535.3


Chunk 1/4, Epoch 18/100: 100%|██████████| 132/132 [01:09<00:00,  1.89it/s]


Chunk 1/4, Epoch 18/100, Total Loss, Train = 31794.8, Val = 31105.5, LR: 1.00e-03
        Train, ce: 31794.8, r2: -12769.1, evo: 10430.8
        Val  , ce: 31105.5, r2: -13235.9, evo: 10520.7


Chunk 1/4, Epoch 19/100: 100%|██████████| 132/132 [01:10<00:00,  1.87it/s]


Chunk 1/4, Epoch 19/100, Total Loss, Train = 30945.1, Val = 30810.7, LR: 1.00e-03
        Train, ce: 30945.1, r2: -12783.8, evo: 10253.0
        Val  , ce: 30810.7, r2: -13183.4, evo: 10436.5


Chunk 1/4, Epoch 20/100: 100%|██████████| 132/132 [01:10<00:00,  1.88it/s]


Chunk 1/4, Epoch 20/100, Total Loss, Train = 30227.3, Val = 27421.6, LR: 1.00e-03
        Train, ce: 30227.3, r2: -12840.2, evo: 10098.0
        Val  , ce: 27421.6, r2: -13258.8, evo: 8944.6


Chunk 1/4, Epoch 21/100: 100%|██████████| 132/132 [01:11<00:00,  1.86it/s]


Chunk 1/4, Epoch 21/100, Total Loss, Train = 29891.5, Val = 28463.6, LR: 1.00e-03
        Train, ce: 29891.5, r2: -12841.9, evo: 10000.0
        Val  , ce: 28463.6, r2: -13200.8, evo: 9162.6


Chunk 1/4, Epoch 22/100: 100%|██████████| 132/132 [01:10<00:00,  1.88it/s]


Chunk 1/4, Epoch 22/100, Total Loss, Train = 29424.8, Val = 29185.7, LR: 1.00e-03
        Train, ce: 29424.8, r2: -12881.5, evo: 9917.1
        Val  , ce: 29185.7, r2: -13307.9, evo: 10184.9


Chunk 1/4, Epoch 23/100: 100%|██████████| 132/132 [01:10<00:00,  1.88it/s]


Chunk 1/4, Epoch 23/100, Total Loss, Train = 28265.1, Val = 29257.5, LR: 1.00e-03
        Train, ce: 28265.1, r2: -12902.1, evo: 9669.4
        Val  , ce: 29257.5, r2: -13210.0, evo: 10191.4


Chunk 1/4, Epoch 24/100: 100%|██████████| 132/132 [01:10<00:00,  1.88it/s]


Chunk 1/4, Epoch 24/100, Total Loss, Train = 28565.7, Val = 25880.7, LR: 1.00e-03
        Train, ce: 28565.7, r2: -12889.8, evo: 9736.1
        Val  , ce: 25880.7, r2: -13357.2, evo: 9409.9


Chunk 1/4, Epoch 25/100: 100%|██████████| 132/132 [01:10<00:00,  1.87it/s]


Chunk 1/4, Epoch 25/100, Total Loss, Train = 27559.5, Val = 26757.9, LR: 1.00e-03
        Train, ce: 27559.5, r2: -12927.2, evo: 9506.8
        Val  , ce: 26757.9, r2: -13347.2, evo: 9299.6


Chunk 1/4, Epoch 26/100: 100%|██████████| 132/132 [01:09<00:00,  1.89it/s]


Chunk 1/4, Epoch 26/100, Total Loss, Train = 27623.2, Val = 26087.7, LR: 1.00e-03
        Train, ce: 27623.2, r2: -12897.1, evo: 9522.1
        Val  , ce: 26087.7, r2: -13288.6, evo: 8739.0


Chunk 1/4, Epoch 27/100: 100%|██████████| 132/132 [01:10<00:00,  1.88it/s]


Chunk 1/4, Epoch 27/100, Total Loss, Train = 27359.9, Val = 25483.6, LR: 1.00e-03
        Train, ce: 27359.9, r2: -12917.7, evo: 9464.7
        Val  , ce: 25483.6, r2: -13322.5, evo: 8199.5


Chunk 1/4, Epoch 28/100: 100%|██████████| 132/132 [01:10<00:00,  1.87it/s]


Chunk 1/4, Epoch 28/100, Total Loss, Train = 26509.6, Val = 25699.5, LR: 1.00e-03
        Train, ce: 26509.6, r2: -12991.3, evo: 9275.7
        Val  , ce: 25699.5, r2: -13393.6, evo: 8619.5


Chunk 1/4, Epoch 29/100: 100%|██████████| 132/132 [01:10<00:00,  1.87it/s]


Chunk 1/4, Epoch 29/100, Total Loss, Train = 25965.6, Val = 24639.7, LR: 1.00e-03
        Train, ce: 25965.6, r2: -12952.9, evo: 9159.3
        Val  , ce: 24639.7, r2: -13310.1, evo: 8155.4


Chunk 1/4, Epoch 30/100: 100%|██████████| 132/132 [01:10<00:00,  1.87it/s]


Chunk 1/4, Epoch 30/100, Total Loss, Train = 25678.8, Val = 24185.3, LR: 1.00e-03
        Train, ce: 25678.8, r2: -12944.7, evo: 9095.0
        Val  , ce: 24185.3, r2: -13343.8, evo: 7620.4


Chunk 1/4, Epoch 31/100: 100%|██████████| 132/132 [01:10<00:00,  1.87it/s]


Chunk 1/4, Epoch 31/100, Total Loss, Train = 25461.7, Val = 23995.9, LR: 1.00e-03
        Train, ce: 25461.7, r2: -12963.9, evo: 9060.5
        Val  , ce: 23995.9, r2: -13428.9, evo: 8452.0


Chunk 1/4, Epoch 32/100: 100%|██████████| 132/132 [01:10<00:00,  1.88it/s]


Chunk 1/4, Epoch 32/100, Total Loss, Train = 25278.4, Val = 23361.2, LR: 1.00e-03
        Train, ce: 25278.4, r2: -12987.9, evo: 9015.4
        Val  , ce: 23361.2, r2: -13423.3, evo: 8625.9


Chunk 1/4, Epoch 33/100:  27%|██▋       | 36/132 [00:21<00:56,  1.69it/s]


KeyboardInterrupt: 

### 2.4 stage 2: Ultra-Long-Range LD Module Training

In [None]:
max_epochs_per_pair = 100
lr                 = 5e-4
weight_decay       = 1e-5
earlystop_patience = 15
batch_size         = 8
verbose            = True

criterion = ImputationLoss(use_r2=True, use_evo=True, r2_weight=1, evo_weight=4, evo_lambda=10)

# ----------- 逐个 chunk 加载权重-----------
# for cid in range(model.n_chunks):
#     chunk_file = f'{work_dir}/models/{model_name}_chunk_{cid}.pth'
#     ckpt = torch.load(chunk_file, map_location='cpu')
#     model.chunk_embeds[cid].load_state_dict(ckpt['chunk_embed_state'])
#     model.chunk_modules[cid].load_state_dict(ckpt['chunk_module_state'])

# ----------- 加载第一阶段完整权重-----------
ckpt = torch.load(f'{work_dir}/models/{model_name}_stage1.pth', map_location='cpu')
model.load_state_dict(ckpt['model_state'])

model.eval()        # chunk 专家冻结（requires_grad=False）

# 收集所有稀疏参数（主要是嵌入层）
sparse_params = []
dense_params = []
# 全局输出层的卷积参数（密集）
dense_params.extend([model.global_out.w1, model.global_out.b1])
dense_params.extend([model.global_out.w2, model.global_out.b2])
# ULR默认启用
if hasattr(model.global_out, 'ulr_mamba'):
    for name, param in model.global_out.ulr_mamba.named_parameters():
        if 'idx_embed' in name:
            sparse_params.append(param)
        else:
            dense_params.append(param)

# 创建分离优化器
optim_sparse = SparseAdam(sparse_params, lr=lr) if sparse_params else None
optim_dense = AdamW(dense_params, lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999))

# 学习率调度器
scheduler_sparse = ReduceLROnPlateau(optim_sparse, mode='min', factor=0.5, patience=5, min_lr=1e-9) if optim_sparse else None
scheduler_dense = ReduceLROnPlateau(optim_dense, mode='min', factor=0.5, patience=5, min_lr=1e-9)

pair_list = list(combinations(range(model.n_chunks), 2))
np.random.shuffle(pair_list)          # 打乱
total_pairs = len(pair_list)

for pair_idx, (cid1, cid2) in enumerate(pair_list, 1):
    # ====== 构造并集 mask ======
    union_mask = (model.chunk_masks[cid1] + model.chunk_masks[cid2]).clamp(max=1).bool()
    train_logs_sum = None
    val_logs_sum   = None
    
    # 并集 MAF
    union_maf, union_bin_cnt = precompute_maf(
        gt_enc.X_gt[:, union_mask.cpu().numpy()].toarray(),
        mask_int=gt_enc.seq_depth
    )

    # ====== 早停变量 ======
    best_loss = float('inf')
    patience_counter = 0
    is_early_stopped = False

    # ====== 训练循环 ======
    for epoch in range(max_epochs_per_pair):
        model.train()
        train_loss = 0.0
        train_prob, train_gts, train_mask = [], [], []

        pbar = tqdm(train_loader,
                    desc=f'Comb {pair_idx}/{total_pairs}  '
                         f'{cid1+1}-{cid2+1}  Epoch {epoch+1}/{max_epochs_per_pair}',
                    leave=False)
        for x, target, evo_mat in pbar:
            x,  target = x.to(device), target.to(device)
            if evo_mat.numel() == 0:
                evo_mat = None
            else:
                evo_mat = evo_mat.to(device)

            optim_sparse.zero_grad()
            optim_dense.zero_grad()

            logits, prob, mask_idx = model(x, [cid1, cid2])
            loss, logs = criterion(logits[:, mask_idx], prob[:, mask_idx], target[:,mask_idx], evo_mat) 
            loss.backward()

            optim_sparse.step()   # 只更新嵌入表
            optim_dense.step()    # 更新其余所有参数

            train_loss += loss.item()
            if train_logs_sum is None:          # 第一次初始化
                train_logs_sum = {k: 0.0 for k in logs}
            for k, v in logs.items():
                train_logs_sum[k] += v
            # pbar.set_postfix({'loss': loss.item(), 'ce':logs['ce'], 'r2':logs['r2'], 'evo':logs['evo']})

            # 收集指标
            miss_mask = x[:,union_mask][..., -1].bool()
            train_prob.append(prob[:, mask_idx].detach())
            train_gts.append(target[:,mask_idx].detach())
            train_mask.append(miss_mask)

        # 训练集 MAF
        train_prob = torch.cat(train_prob, dim=0)
        train_gts  = torch.cat(train_gts,    dim=0)
        train_mask = torch.cat(train_mask,   dim=0)

        # ----------- 验证 -----------
        model.eval()
        val_loss = 0.0
        val_prob, val_gts = [], []
        with torch.no_grad():
            if val_logs_sum is None:
                val_logs_sum = {k: 0.0 for k in train_logs_sum}
            for x, target, evo_mat in test_loader:
                x = x.to(device)
                target = target.to(device)
                evo_mat = evo_mat.to(device) if evo_mat.numel() else None
                logits, prob, mask_idx = model(x, [cid1, cid2])
                loss, logs = criterion(logits[:, mask_idx], prob[:, mask_idx], target[:,mask_idx], evo_mat)
                val_loss += loss.item()
                for k, v in logs.items():
                    val_logs_sum[k] += v
                val_prob.append(prob[:,mask_idx])
                val_gts.append(target[:,mask_idx])

        val_prob = torch.cat(val_prob, dim=0)
        val_gts  = torch.cat(val_gts,    dim=0)

        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss   = val_loss   / len(test_loader)
        avg_train_logs = {k: v / len(train_loader) for k, v in train_logs_sum.items()}
        avg_val_logs = {k: v / len(test_loader) for k, v in val_logs_sum.items()}
        
        scheduler_sparse.step(val_loss)
        scheduler_dense.step(val_loss)

        current_denselr = optim_dense.param_groups[0]['lr']
        current_sparselr = optim_sparse.param_groups[0]['lr']

        log_str = (f'Comb {pair_idx}/{total_pairs}  '
            f'{cid1+1}-{cid2+1}  Epoch {epoch+1}/{max_epochs_per_pair} '
            f'Total Loss, Train = {avg_train_loss:.1f}, '
            f'Val = {avg_val_loss:.1f}, '
            f'dense LR: {current_denselr:.2e}, '
            f'sparse LR: {current_sparselr:.2e}')
        log_str += '\n        Train'
        for k, v in avg_train_logs.items():
            log_str += f', {k}: {v:.1f}'
        log_str += '\n        Val  '
        for k, v in avg_val_logs.items():
            log_str += f', {k}: {v:.1f}'
        print(log_str)
        # 清空累加器，供下一个 epoch 使用
        train_logs_sum = {k: 0.0 for k in train_logs_sum}
        val_logs_sum   = {k: 0.0 for k in val_logs_sum}
        # 早停
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            patience_counter = 0
            torch.save({
                'comb': (cid1, cid2),
                'global_out': model.global_out.state_dict(),
                'best_val_loss': best_loss,
                'epoch': epoch,
            }, f'{work_dir}/models/{model_name}_chunk[{cid1}-{cid2}].pth')
            # MAF 表格
            predres_with_bestloss = (train_prob, train_gts, val_prob, val_gts)
            if verbose:
                train_bins_metrics = metrics_by_maf(train_prob, train_gts, gt_enc.hap_map, union_maf, mask=train_mask)
                val_bins_metrics   = metrics_by_maf(val_prob,   val_gts, gt_enc.hap_map, union_maf, mask=None)
                print_maf_stat_df(chunk_bin_cnt,
                      {"train": train_bins_metrics,
                       "val":   val_bins_metrics})
                print(f'  --> updated {model_name}_chunk[{cid1+1}-{cid2+1}].pth')
        else:
            patience_counter += 1
            if patience_counter >= earlystop_patience:
                is_early_stopped = True
                print(f'Pair {cid1+1}-{cid2+1} early stopping')
                train_prob, train_gts, val_prob, val_gts = predres_with_bestloss
                train_bins_metrics = metrics_by_maf(train_prob, train_gts, gt_enc.hap_map, union_maf, mask=train_mask)
                val_bins_metrics   = metrics_by_maf(val_prob,   val_gts, gt_enc.hap_map, union_maf, mask=None)
                print_maf_stat_df(chunk_bin_cnt,
                      {"train": train_bins_metrics,
                       "val":   val_bins_metrics})
                break
            
    if not is_early_stopped:
        predres_with_bestloss = (train_prob, train_gts, val_prob, val_gts)
        train_bins_metrics = metrics_by_maf(train_prob, train_gts, gt_enc.hap_map, union_maf, mask=train_mask)
        val_bins_metrics   = metrics_by_maf(val_prob,   val_gts, gt_enc.hap_map, union_maf, mask=None)
        print_maf_stat_df(chunk_bin_cnt,
                      {"train": train_bins_metrics,
                       "val":   val_bins_metrics})

    # del optimizer, scheduler
    torch.cuda.empty_cache()

# ----------- 全部 pair 结束 -> 保存最终模型 -----------
torch.save({
    'model_state': model.state_dict(),
    'ulr_enabled': True,
}, f'{work_dir}/models/{model_name}_stage2_final.pth')
print(f'==> STAGE2 training finished: {work_dir}/models/{model_name}_stage2_final.pth')

### 2.5 stage 3: fine tuning with URP samples 

In [7]:
# %%  载入 URP 微调数据
gt_enc_urp = GenotypeEncoder(phased=False, gts012=False, save2disk=False)
gt_enc_urp = gt_enc_urp.encode_ref(
        ref_meta_json = work_dir/"pre_train"/"gt_enc_meta.json",   # 与 Stage1 同构
        vcf_path      = work_dir/"urp_finetune"/"minor_pops.10pct.vcf.gz",
        evo_mat       = work_dir/"urp_finetune"/"evo_mat_minor_pops.10pct.tsv")

print(f'[URP] {gt_enc_urp.n_samples} samples, {gt_enc_urp.n_variants} variants')

[DATA] 总计 99,314 个位点  
[DATA] EvoMat shape: (16, 16)
[DATA] 位点矩阵 = (16, 99314)，稀疏度 = 27.35%，缺失率 = 0.00%
[DATA] 位点字典 = {'0|0': 0, '0|1': 1, '1|1': 2, '.|.': 3}，字典深度 = 4
[URP] 16 samples, 99314 variants


In [8]:
from sklearn.model_selection import KFold
# %%  Stage-3 超参与配置
model_name  = 'hg19_chr22trim'
stage3_tag       = 'stage3'
max_epochs       = 50
warmup_epochs    = 3
lr_dense         = 1e-4          # GlobalOut 中稠密参数
lr_sparse        = 5e-5          # ULR 中的 idx_embed
weight_decay     = 1e-4
earlystop_pat    = 9
mask_rate_range  = (0.2, 0.6)    # 数据增强：随机缺失率
k_fold           = 5             # 交叉验证
batch_size       = 4             # 样本少，用小 batch
accumulate_grad  = 2             # 梯度累加，等效 batch=8

# %%  重新建立「微调」Dataset / Loader
urp_dataset = GenomicDataset(
        gt_enc_urp.X_gt,
        evo_mat      = gt_enc_urp.evo_mat,
        seq_depth    = gt_enc_urp.seq_depth,
        mask         = True,
        masking_rates= mask_rate_range,
        indices      = None)               # 全部用于微调

def collate_fn(batch):
    x = torch.stack([b[0] for b in batch])
    y = torch.stack([b[1] for b in batch])
    idx = [b[2] for b in batch]
    if gt_enc_urp.evo_mat is not None:
        evo = gt_enc_urp.evo_mat[np.ix_(idx, idx)]
        evo = torch.FloatTensor(evo)
    else:
        evo = torch.empty(0)
    return x, y, evo

urp_loader = DataLoader(urp_dataset, batch_size=batch_size,
                        shuffle=True, num_workers=4,
                        pin_memory=True, collate_fn=collate_fn)

# 1. 准备 URP 数据
urp_idx = np.arange(gt_enc_urp.n_samples)
kf = KFold(n_splits=k_fold, shuffle=True, random_state=42)


In [9]:
# %%  载入 Stage-2 最终权重
ckpt = torch.load(f'{work_dir}/models/{model_name}_stage1.pth', map_location='cpu')
model.load_state_dict(ckpt['model_state'])
print('[Stage3] Stage-2 weights loaded.')

# %%  参数分组 & 优化器
for p in model.parameters():                # 先全部冻结
    p.requires_grad = False

# 只解冻需要的部分
trainable_dense, trainable_sparse = [], []
# 1. GlobalOut 全部
for name, p in model.global_out.named_parameters():
    if 'idx_embed' in name:
        trainable_sparse.append(p)
    else:
        trainable_dense.append(p)
# 2. Chunk-Embedding（可选，若显存紧张可留冻）
for emb in model.chunk_embeds:
    for p in emb.parameters():
        trainable_dense.append(p)

for p in trainable_dense+trainable_sparse:
    p.requires_grad = True

opt_dense  = AdamW(trainable_dense,  lr=lr_dense,
                   weight_decay=weight_decay, betas=(0.9, 0.999))
opt_sparse = SparseAdam(trainable_sparse, lr=lr_sparse)

# 余弦退火 + 热身
def lr_lambda(epoch):
    if epoch < warmup_epochs:
        return epoch / warmup_epochs
    return 0.5*(1+np.cos(np.pi*(epoch-warmup_epochs)/(max_epochs-warmup_epochs)))

sched_dense  = torch.optim.lr_scheduler.LambdaLR(opt_dense,  lr_lambda)
sched_sparse = torch.optim.lr_scheduler.LambdaLR(opt_sparse, lr_lambda)

# %%  训练循环
criterion = ImputationLoss(use_r2=True, use_evo=True,
                           r2_weight=1, evo_weight=4, evo_lambda=10)

best_avg_val_loss, patience_counter = np.inf, 0

for epoch in range(max_epochs):
    model.train()
    fold_val_loss = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(urp_idx)):
        # ---- 当前折数据 ----
        train_dataset = GenomicDataset(
                gt_enc_urp.X_gt, evo_mat=gt_enc_urp.evo_mat,
                seq_depth=gt_enc_urp.seq_depth, mask=True,
                masking_rates=(0.2, 0.6), indices=train_idx)
        val_dataset   = GenomicDataset(
                gt_enc_urp.X_gt, evo_mat=gt_enc_urp.evo_mat,
                seq_depth=gt_enc_urp.seq_depth, mask=True,
                masking_rates=(0.2, 0.6), indices=val_idx)

        train_loader = DataLoader(train_dataset, batch_size=8,
                                  shuffle=True, num_workers=2,
                                  collate_fn=collate_fn, pin_memory=True)
        val_loader   = DataLoader(val_dataset, batch_size=8,
                                  shuffle=False, num_workers=2,
                                  collate_fn=collate_fn, pin_memory=True)

        # ---- 训练 ----
        for step, (x, y, evo) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            evo  = evo.to(device) if evo.numel() else None

            logits, prob, mask_idx = model(x)
            loss, _ = criterion(logits[:, mask_idx], prob[:, mask_idx],
                                y[:, mask_idx], evo)
            loss.backward()

            if (step+1) % accumulate_grad == 0 or (step+1) == len(train_loader):
                opt_dense.step(); opt_sparse.step()
                opt_dense.zero_grad(set_to_none=True); opt_sparse.zero_grad(set_to_none=True)

        # ---- 验证 ----
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for x, y, evo in val_loader:
                x, y = x.to(device), y.to(device)
                evo  = evo.to(device) if evo.numel() else None
                logits, prob, mask_idx = model(x)
                loss, _ = criterion(logits[:, mask_idx], prob[:, mask_idx],
                                    y[:, mask_idx], evo)
                val_loss += loss.item()
        fold_val_loss.append(val_loss / len(val_loader))

    # ---- epoch 级日志 & 调度 ----
    avg_val_loss = np.mean(fold_val_loss)
    print(f'Epoch {epoch+1}: avg_val_loss={avg_val_loss:.3f}, '
          f'lr_dense={opt_dense.param_groups[0]["lr"]:.1e}')
    sched_dense.step(); sched_sparse.step()

    # ---- 早停 ----
    if avg_val_loss < best_avg_val_loss:
        best_avg_val_loss = avg_val_loss
        patience_counter = 0
        torch.save({'model_state': model.state_dict(),
                    'epoch': epoch,
                    'avg_val_loss': avg_val_loss},
                   f'{work_dir}/models/{model_name}_{stage3_tag}_best.pth')
    else:
        patience_counter += 1
        if patience_counter >= earlystop_pat:
            print('Early stopping triggered.')
            break

torch.save({'model_state': model.state_dict(),
            'stage3_tag': stage3_tag},
           f'{work_dir}/models/{model_name}_{stage3_tag}_final.pth')
print(f'==> Stage-3 KFold-loss fine-tuning finished. Best avg_val_loss={best_avg_val_loss:.3f}')

[Stage3] Stage-2 weights loaded.
Epoch 1: avg_val_loss=68964.580, lr_dense=0.0e+00
Epoch 2: avg_val_loss=60720.118, lr_dense=3.3e-05
Epoch 3: avg_val_loss=49563.400, lr_dense=6.7e-05
Epoch 4: avg_val_loss=43253.447, lr_dense=1.0e-04
Epoch 5: avg_val_loss=41547.148, lr_dense=1.0e-04
Epoch 6: avg_val_loss=34705.204, lr_dense=1.0e-04
Epoch 7: avg_val_loss=34851.017, lr_dense=9.9e-05
Epoch 8: avg_val_loss=31193.683, lr_dense=9.8e-05
Epoch 9: avg_val_loss=25390.002, lr_dense=9.7e-05
Epoch 10: avg_val_loss=32677.517, lr_dense=9.6e-05
Epoch 11: avg_val_loss=29019.098, lr_dense=9.5e-05
Epoch 12: avg_val_loss=25310.655, lr_dense=9.3e-05
Epoch 13: avg_val_loss=30964.865, lr_dense=9.1e-05
Epoch 14: avg_val_loss=22583.397, lr_dense=8.9e-05
Epoch 15: avg_val_loss=26063.507, lr_dense=8.7e-05
Epoch 16: avg_val_loss=21442.901, lr_dense=8.5e-05
Epoch 17: avg_val_loss=25226.832, lr_dense=8.2e-05
Epoch 18: avg_val_loss=21774.188, lr_dense=8.0e-05
Epoch 19: avg_val_loss=22828.542, lr_dense=7.7e-05
Epoch 2

## 3. Imputation using trained model

### 3.1 Load the trained model

Choose a path where including `<work_dir>/model` and have trained model.

In [None]:
print(f"Work Dir: {work_dir}")

# ---- 1. 加载模型 ----
device = 'cuda' if torch.cuda.is_available() else 'cpu'

json_path = f"{work_dir}/models/model_meta.json"
meta = json.load(open(json_path))
model = EvoFill(
    d_model=int(meta["d_model"]),
    n_alleles=int(meta["alleles"]),
    total_sites=int(meta["total_sites"]),
    chunk_size=int(meta["chunk_size"]),
    chunk_overlap=int(meta["overlap"])
).to(device)

# ckpt = torch.load(f'{work_dir}/models/{meta["model_name"]}_stage1.pth', map_location=device)
# ckpt = torch.load(f'{work_dir}/models/{meta["model_name"]}_stage2_best.pth', map_location=device)
ckpt = torch.load(f'{work_dir}/models/{meta["model_name"]}_stage3_best.pth', map_location=device)
model.load_state_dict(ckpt['model_state'])
model.eval()
print(f'[INF] Model[{meta["model_name"]}] loaded.')

Work Dir: /mnt/qmtang/EvoFill_data/20251107_ver4
[INF] Model[hg19_chr22trim] loaded.


### 3.2 Encode .vcf file need be impute

In [6]:
gt_enc_imp = GenotypeEncoder(phased=False, gts012=False, save2disk=False)
gt_enc_imp = gt_enc_imp.encode_ref(
        ref_meta_json = work_dir/"pre_train"/"gt_enc_meta.json",   # 与 Stage1 同构
        vcf_path      = work_dir/"impute_in"/"minor_pops.90pct.masked50p.vcf.gz" )

print(f'[INFER] {gt_enc_imp.n_samples} samples, {gt_enc_imp.n_variants} variants')

# ---- 2. 构建推理 Dataset / Loader ----
imp_dataset = ImputationDataset(
    x_gts_sparse=gt_enc_imp.X_gt,
    seq_depth=gt_enc_imp.seq_depth,
    indices=None                 # 可传入指定样本索引
)
imp_dataset.print_missing_stat()          # 查看原始缺失比例

def collate_fn(batch):
    x_onehot = torch.stack([item[0] for item in batch])
    real_idx_list = [item[1] for item in batch]
    return x_onehot, real_idx_list   # 无 y

imp_loader = torch.utils.data.DataLoader(
    imp_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn
)

[DATA] 总计 99,314 个位点  
[DATA] 位点矩阵 = (152, 99314)，稀疏度 = 63.43%，缺失率 = 50.01%
[DATA] 位点字典 = {'0|0': 0, '0|1': 1, '1|1': 2, '.|.': 3}，字典深度 = 4
[INFER] 152 samples, 99314 variants
[ImputationDataset] 152 samples, missing rate = 50.01%


### 3.3 Inferring

In [7]:
y_prob = []
y_mask = []
with torch.no_grad():
    for x_onehot, real_idx in tqdm(imp_loader, desc='Imputing'):
        x_onehot = x_onehot.to(device)
        _, prob, _ = model(x_onehot)
        miss_mask = x_onehot[..., -1].bool()
        y_prob.append(prob)
        y_mask.append(miss_mask)
y_prob = torch.cat(y_prob, dim=0).cpu().numpy()
y_mask = torch.cat(y_mask, dim=0).cpu().numpy()
# 4. 保存
out_dir = os.path.join(work_dir, 'impute_out')
os.makedirs(out_dir, exist_ok=True)
np.save(os.path.join(out_dir, 'impute_prob.npy'), y_prob)
np.save(os.path.join(out_dir, 'impute_mask.npy'), y_mask)
print(f'[INF] 概率矩阵已保存 → {out_dir}/impute_prob.npy '
      f'with shape = {y_prob.shape} ')


Imputing: 100%|██████████| 3/3 [00:20<00:00,  6.79s/it]


[INF] 概率矩阵已保存 → /mnt/qmtang/EvoFill_data/20251107_ver4/impute_out/impute_prob.npy with shape = (152, 99314, 3) 


### 3.4 Evaluating the imputation results

In [8]:
gt_enc_true = GenotypeEncoder(phased=False, gts012=False, save2disk=False)
gt_enc_true = gt_enc_true.encode_ref(
        ref_meta_json = work_dir/"pre_train"/"gt_enc_meta.json",   # 与 Stage1 同构
        vcf_path      = work_dir/"impute_out"/"minor_pops.90pct.vcf.gz" )
y_true = gt_enc_true.X_gt.toarray()
maf, bin_cnt = precompute_maf(y_true,  mask_int=gt_enc_true.seq_depth)
y_true_oh = np.eye(gt_enc_true.seq_depth - 1)[y_true]
bins_metrics   = metrics_by_maf(y_prob, y_true_oh, hap_map = gt_enc_true.hap_map, maf_vec = maf, mask=y_mask)
print_maf_stat_df(bin_cnt,{'val': bins_metrics})

[DATA] 总计 99,314 个位点  
[DATA] 位点矩阵 = (152, 99314)，稀疏度 = 26.84%，缺失率 = 0.00%
[DATA] 位点字典 = {'0|0': 0, '0|1': 1, '1|1': 2, '.|.': 3}，字典深度 = 4
     MAF_bin Counts val_Acc val_INFO val_IQS val_MaCH
(0.00, 0.05)  41365   0.976    0.379   0.082    0.939
(0.05, 0.10)   9108   0.930    0.506   0.384    0.770
(0.10, 0.20)  11962   0.898    0.614   0.544    0.836
(0.20, 0.30)  12887   0.857    0.725   0.668    0.919
(0.30, 0.40)  17081   0.844    0.774   0.720    0.962
(0.40, 0.50)   6909   0.840    0.784   0.722    0.960


### 3.5 Saving to .vcf

In [38]:
from cyvcf2 import VCF, Writer

# 0. 路径
ref_vcf = "/mnt/qmtang/EvoFill_data/20251107_ver4/minor_pops.masked30p.vcf.gz"
out_vcf = os.path.join(out_dir, 'imputed.vcf.gz')

n_site = gt_enc.n_variants
n_samp = gt_enc.n_samples
n_alleles = gt_enc.seq_depth - 1
assert y_prob.shape == (n_samp, n_site, n_alleles)

# 2. 反向映射  idx -> '0|0' / '0|1' / ...
rev_hap_map = {v: k for k, v in gt_enc.hap_map.items()}

samp2idx = {sid: i for i, sid in enumerate(gt_enc.sample_ids)}

# 4. 打开参考 VCF
invcf = VCF(ref_vcf)
tmpl  = invcf
tmpl.set_samples(gt_enc.sample_ids)   # 替换样本列

out = Writer(out_vcf, tmpl, mode='wz')

for rec_idx, rec in enumerate(invcf):
    # 当前位点全部样本的 GT
    gt_int_pairs = []
    for samp_idx, sample_id in enumerate(gt_enc.sample_ids):
        old_gt = rec.genotypes[samp_idx]          # [allele1, allele2, phased]
        if old_gt[0] == -1 or old_gt[1] == -1:    # 缺失
            prob_vec = y_prob[samp_idx, rec_idx, :].ravel()
            best_idx = int(prob_vec.argmax())
            gt_str   = rev_hap_map[best_idx]
            alleles  = list(map(int, gt_str.split('|')))
            phased   = old_gt[2] if old_gt[2] != -1 else 1
            gt_int_pairs.append([alleles[0], alleles[1], phased])
        else:                                       # 非缺失，保持原样
            gt_int_pairs.append(old_gt)

    # 转成 int8 二维数组  (n_sample, 3)  last dim = [a1,a2,phased]
    gt_array = np.array(gt_int_pairs, dtype=np.int8)
    rec.set_format('GT', gt_array)
    out.write_record(rec)

invcf.close()
out.close()

# 5. tabix
os.system(f'tabix -fp vcf {out_vcf}')
print(f'[INF] 缺失位点填充完成 → {out_vcf}')

[INF] 缺失位点填充完成 → /mnt/qmtang/EvoFill_data/20251107_ver4/impute_out/imputed.vcf.gz
