# EvoFill working demo

ver 4. new imputation loss with evo loss

last update: 2025/11/4

## Content
- [0. Dependency](#0.-Dependency)  
- [1. Genotype(.vcf) encoding](#1.-Genotype(.vcf)-encoding)  
- [2. Training a new model](#2.-Training-a-new-model)  
  - [2.1 Dataloader](#2.1-Dataloader)  
  - [2.2 Model initialization](#2.2-Model-initialization)  
  - [2.3 stage 1: Chunk Module Training](#2.3-stage-1:-Chunk-Module-Training)  
  - [2.4 stage 2: Ultra-Long-Range LD Module Training](#2.4-stage-2:-Ultra-Long-Range-LD-Module-Training)  
- [3. Imputation using trained model](#3.-Imputation-using-trained-model)  
  - [3.1 Load the trained model](#3.1-Load-the-trained-model)  
  - [3.2 Encode .vcf file need be impute](#3.2-Encode-.vcf-file-need-be-impute)  
  - [3.3 Inferring](#3.3-Inferring)  
  - [3.4 Evaluating the imputation results](#3.4-Evaluating-the-imputation-results)  

This Jupyter notebook uses a single GPU for the demo.

For multi-GPU training, launch `./train.py` with a distributed launcher instead, e.g. `torchrun --nproc_per_node=8 train.py`.

In [1]:
# Expose only GPU 0 to this process; this cell runs before torch is imported in the next cell.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0" 

## 0. Dependency

In [None]:
import sys
import json
import numpy as np
from scipy import sparse
import pandas as pd
import torch
import mamba_ssm
from tqdm import tqdm
from itertools import combinations
from torch.utils.data import DataLoader
from torch.optim import Adam, AdamW, SparseAdam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.model_selection import train_test_split

# Report the interpreter and library versions this notebook is running with.
print('Python ver:  ', sys.version)
print('Pytorch ver: ', torch.__version__)
print('Mamba ver:   ', mamba_ssm.__version__)
# Querying the device name raises when no CUDA device is visible, so fall back
# to a plain label (later cells already select 'cpu' when CUDA is unavailable).
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(torch.cuda.current_device())
else:
    gpu_name = 'none (CPU only)'
print('GPU in use:  ', gpu_name)

Python ver:   3.10.19 | packaged by conda-forge | (main, Oct 13 2025, 14:08:27) [GCC 14.3.0]
Pytorch ver:  2.8.0+cu129
Mamba ver:    2.2.5
GPU in use:   NVIDIA H100 PCIe


In [3]:
# Switch to the EvoFill checkout; presumably needed so the `src.*` imports in the next cell resolve.
os.chdir('/home/qmtang/EvoFill/')

In [None]:
from src.data import GenotypeEncoder, GenomicDataset, ImputationDataset
from src.model import EvoFill
from src.loss import ImputationLoss
from src.utils import create_directories, set_seed, precompute_maf, metrics_by_maf, print_maf_stat_df

## 1. Genotype(.vcf) encoding

option 1. with extra evolutionary information

In [None]:
# Encode the training VCF together with an evo_mat file; results are written to `work_dir`.
work_dir = '/mnt/qmtang/EvoFill/data/251105_ver4_chr22trim'
train_vcf = '/home/qmtang/GitHub/STICI-HPC/data/training_sets/ALL.chr22.training.samples.100k.any.type.0.01.maf.variants.vcf.gz'
evo_mat_tsv = "/mnt/qmtang/EvoFill/data/251104_ver3_chr22trim_ex/evo_mat.tsv"

gt_enc = GenotypeEncoder(save_dir=work_dir, phased=True, gts012=False)
gt_enc = gt_enc.encode_new(vcf_path=train_vcf, evo_mat=evo_mat_tsv)

# Summary of the encoded data set.
for summary_line in (
    f"[DATA] {gt_enc.n_samples:,} Samples",
    f"[DATA] {gt_enc.n_variants:,} Variants Sites",
    f"[DATA] {gt_enc.seq_depth} seq_depth",
):
    print(summary_line)

[DATA] 总计 99,314 个位点  
[DATA] 位点矩阵 = (2404, 99314)，稀疏度 = 28.10%
[DATA] EvoMat shape: (2404, 2404)
[DATA] 结果已写入 /mnt/qmtang/EvoFill/data/251105_ver4_chr22trim
[DATA] 2,404 Samples
[DATA] 99,314 Variants Sites
[DATA] 4 seq_depth


option 2. vcf files only

In [None]:
# Encode the training VCF without any evo_mat file; results are written to `work_dir`.
work_dir = '/mnt/qmtang/EvoFill/data/251105_ver4_chr22trim'
train_vcf = '/home/qmtang/GitHub/STICI-HPC/data/training_sets/ALL.chr22.training.samples.100k.any.type.0.01.maf.variants.vcf.gz'

gt_enc = GenotypeEncoder(save_dir=work_dir, phased=True, gts012=False)
gt_enc = gt_enc.encode_new(vcf_path=train_vcf, evo_mat=None)

# Summary of the encoded data set.
for summary_line in (
    f"[DATA] {gt_enc.n_samples:,} Samples",
    f"[DATA] {gt_enc.n_variants:,} Variants Sites",
    f"[DATA] {gt_enc.seq_depth} seq_depth",
):
    print(summary_line)

## 2. Training a new model

Choose a path that contains the files produced by `GenotypeEncoder`.

In [5]:
# Directory holding the encoded data; GenotypeEncoder.loadfromdisk reads it in section 2.1.
work_dir = '/mnt/qmtang/EvoFill/data/251105_ver4_chr22trim'

### 2.1 Dataloader

In [6]:
# Reload the encoded genotypes from `work_dir` and build the train / validation datasets.
print(f"Work Dir: {work_dir}")
create_directories(work_dir)

val_n_samples = 128   # number of samples held out for validation
batch_size    = 16
max_mr        = 0.7   # upper bound handed to GenomicDataset(masking_rates=...)
min_mr        = 0.3   # lower bound handed to GenomicDataset(masking_rates=...)

# Setup device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.cuda.empty_cache()
print(f"Using device: {device}")

gt_enc = GenotypeEncoder.loadfromdisk(work_dir)
print(f'{gt_enc.n_samples:,} samples, {gt_enc.n_variants:,} variants, {gt_enc.seq_depth} seq-depth.')

# Fixed random_state keeps the train/validation split reproducible across runs.
x_train_indices, x_valid_indices = train_test_split(
    range(gt_enc.n_samples),
    test_size=val_n_samples,
    random_state=3047,
    shuffle=True
)

print(f"{len(x_train_indices):,} samples in train")
print(f"{len(x_valid_indices):,} samples in val")
# evo_mat is optional (section 1, option 2 encodes with evo_mat=None and the
# collate function checks for None), so do not assume it has a `.shape`.
evo_mat_shape = None if gt_enc.evo_mat is None else gt_enc.evo_mat.shape
print(f"evo_mat: {evo_mat_shape}")

train_dataset = GenomicDataset(
    gt_enc.X_gt,
    evo_mat=gt_enc.evo_mat,
    seq_depth=gt_enc.seq_depth,
    mask=True,
    masking_rates=(min_mr, max_mr),
    indices=x_train_indices
)

# The validation set is built with the same masking settings as the training set.
val_dataset = GenomicDataset(
    gt_enc.X_gt,
    evo_mat=gt_enc.evo_mat,
    seq_depth=gt_enc.seq_depth,
    mask=True,
    masking_rates=(min_mr, max_mr),
    indices=x_valid_indices
)

def collate_fn(batch):
    """Stack a list of dataset items into batch tensors.

    Each item is indexed as ``(x_onehot, y_onehot, real_idx)``.  Returns the
    stacked inputs, the stacked targets, and the square sub-matrix of
    ``train_dataset.evo_mat`` restricted to the batch's ``real_idx`` values
    (an empty tensor when ``train_dataset.evo_mat`` is ``None``).
    """
    inputs, targets, sample_ids = [], [], []
    for item in batch:
        inputs.append(item[0])
        targets.append(item[1])
        sample_ids.append(item[2])
    stacked_x = torch.stack(inputs)
    stacked_y = torch.stack(targets)

    full_evo = train_dataset.evo_mat
    if full_evo is None:
        evo_block = torch.empty(0)
    else:
        # Rows and columns are both restricted to the samples in this batch.
        evo_block = torch.FloatTensor(full_evo[np.ix_(sample_ids, sample_ids)])

    return stacked_x, stacked_y, evo_block
    
# Both loaders share every setting except shuffling: only the training loader is shuffled.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn,
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

Work Dir: /mnt/qmtang/EvoFill/data/251105_ver4_chr22trim
Using device: cuda
2,404 samples, 99,314 variants, 4 seq-depth.
2,276 samples in train
128 samples in val
evo_mat: (2404, 2404)


### 2.2 Model initialization

In [7]:
# Network hyper-parameters; the same values are written to model_meta.json below.
model_name  = 'hg19_chr22trim'
total_sites = gt_enc.n_variants
alleles     = gt_enc.seq_depth
chunk_size  = 32768
overlap     = 1024
d_model     = 64

# Seed before the model is constructed.
set_seed(42)

model = EvoFill(d_model, alleles, total_sites, chunk_size, overlap).to(device)
print(f"model[{model_name}] would have {model.n_chunks} chunks.")

criterion = ImputationLoss(
    use_r2=True,
    use_evo=True,
    r2_weight=1,
    evo_weight=4,
    evo_lambda=10,
)
# criterion = ImputationLoss(use_r2=True, use_evo=False)

# Persist the constructor arguments next to the encoded data.
meta = dict(
    model_name=model_name,
    total_sites=total_sites,
    alleles=alleles,
    chunk_size=chunk_size,
    overlap=overlap,
    d_model=d_model,
)
save_path = os.path.join(work_dir, "model_meta.json")
with open(save_path, "w") as meta_file:
    json.dump(meta, meta_file, indent=4)

model[hg19_chr22trim] would have 4 chunks.


### 2.3 stage 1: Chunk Module Training

In [None]:
verbose            = False
max_epochs         = 100
lr                 = 0.001
weight_decay       = 1e-5
earlystop_patience = 13

# Stage 1 trains one chunk at a time with the ULR branch of GlobalOut switched off.
model.global_out.set_ulr_enabled(False)

for cid in range(model.n_chunks):
    # MAF vector and per-bin site counts for the sites selected by this chunk's mask.
    chunk_mask = model.chunk_masks[cid].cpu()
    chunk_maf, chunk_bin_cnt = precompute_maf(
        gt_enc.X_gt[:, chunk_mask.bool().numpy()].toarray(),
        mask_int=gt_enc.seq_depth,
    )
    chunk_maf = torch.from_numpy(chunk_maf).to(device)
    if verbose:
        print(f"=== Chunk {cid + 1} STAT ===")
        maf_df = pd.DataFrame({
            'MAF_bin': ['(0.00, 0.05)', '(0.05, 0.10)', '(0.10, 0.20)',
                        '(0.20, 0.30)', '(0.30, 0.40)', '(0.40, 0.50)'],
            'Counts':  [f"{c}" for c in chunk_bin_cnt],
        })
        print(maf_df.to_string(index=False))

    # Only the current chunk's embedding + module and GlobalOut's w1/b1/w2/b2
    # are handed to the optimizer; every other parameter is left untouched.
    trainable = (list(model.chunk_embeds[cid].parameters()) +
                 list(model.chunk_modules[cid].parameters()) +
                 [model.global_out.w1, model.global_out.b1,
                  model.global_out.w2, model.global_out.b2])

    optimizer = AdamW(trainable, lr=lr, weight_decay=weight_decay)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, min_lr=1e-8)
    best_loss = float('inf')
    patience_counter = 0
    is_early_stopped = False
    train_logs_sum = None
    val_logs_sum   = None
    # Reset per chunk so a snapshot from a previous chunk can never be reported,
    # and so the early-stop branch is safe even if no epoch ever improved.
    predres_with_bestloss = None
    for epoch in range(max_epochs):
        model.train()
        train_loss = 0.0
        train_prob, train_gts, train_mask = [], [], []

        train_pbar = tqdm(train_loader, desc=f'Chunk {cid + 1}/{model.n_chunks}, Epoch {epoch + 1}/{max_epochs}')
        for x, target, evo_mat in train_pbar:
            x, target = x.to(device), target.to(device)
            # The collate function sends an empty tensor when there is no evo matrix.
            evo_mat = evo_mat.to(device) if evo_mat.numel() else None

            optimizer.zero_grad()
            logits, prob, mask_idx = model(x, cid)
            loss, logs = criterion(logits[:, mask_idx],
                                   prob[:, mask_idx],
                                   target[:, mask_idx],
                                   evo_mat)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            if train_logs_sum is None:          # first batch: create the accumulators
                train_logs_sum = {k: 0.0 for k in logs}
            for k, v in logs.items():
                train_logs_sum[k] += v

            # === collect training predictions ===
            # Last input channel of the chunk's sites, used as the mask for the
            # MAF report (assumed to flag the masked sites — TODO confirm).
            miss_mask = x[:, mask_idx][..., -1].bool()
            train_prob.append(prob[:, mask_idx].detach())
            train_gts.append(target[:, mask_idx].detach())
            train_mask.append(miss_mask)

        train_prob = torch.cat(train_prob, dim=0)
        train_gts  = torch.cat(train_gts,  dim=0)
        train_mask = torch.cat(train_mask, dim=0)

        # ----------- validation loop, same structure ------------
        model.eval()
        val_loss = 0.0
        val_prob, val_gts = [], []
        if val_logs_sum is None:
            val_logs_sum = {k: 0.0 for k in train_logs_sum}
        with torch.no_grad():
            for x, target, evo_mat in val_loader:
                x, target = x.to(device), target.to(device)
                evo_mat = evo_mat.to(device) if evo_mat.numel() else None
                logits, prob, mask_idx = model(x, cid)
                loss, logs = criterion(logits[:, mask_idx], prob[:, mask_idx], target[:, mask_idx], evo_mat)
                val_loss += loss.item()
                for k, v in logs.items():
                    val_logs_sum[k] += v
                val_prob.append(prob[:, mask_idx].detach())
                val_gts.append(target[:, mask_idx].detach())

        val_prob = torch.cat(val_prob, dim=0)
        val_gts  = torch.cat(val_gts,  dim=0)
        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss   = val_loss   / len(val_loader)
        avg_train_logs = {k: v / len(train_loader) for k, v in train_logs_sum.items()}
        avg_val_logs   = {k: v / len(val_loader) for k, v in val_logs_sum.items()}

        scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]['lr']

        log_str = (f'Chunk {cid + 1}/{model.n_chunks}, '
            f'Epoch {epoch + 1}/{max_epochs}, '
            f'Total Loss, Train = {avg_train_loss:.1f}, '
            f'Val = {avg_val_loss:.1f}, '
            f'LR: {current_lr:.2e}')
        log_str += '\n        Train'
        for k, v in avg_train_logs.items():
            log_str += f', {k}: {v:.1f}'
        log_str += '\n        Val  '
        for k, v in avg_val_logs.items():
            log_str += f', {k}: {v:.1f}'
        print(log_str)

        # Reset the accumulators for the next epoch.
        train_logs_sum = {k: 0.0 for k in train_logs_sum}
        val_logs_sum   = {k: 0.0 for k in val_logs_sum}

        # Early stopping
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            patience_counter = 0
            # Save only the current chunk expert plus the global output layer.
            ckpt = {
                'chunk_id': cid,
                'chunk_embed_state': model.chunk_embeds[cid].state_dict(),
                'chunk_module_state': model.chunk_modules[cid].state_dict(),
                'global_out_state': model.global_out.state_dict(),
                'best_val_loss': best_loss,
            }
            torch.save(ckpt, f'{work_dir}/models/{model_name}_chunk_{cid}.pth')
            # Keep the mask with the predictions it belongs to: the training loader
            # is shuffled, so a later epoch's mask would not line up with these rows.
            predres_with_bestloss = (train_prob, train_gts, train_mask, val_prob, val_gts)
            if verbose:
                train_bins_metrics = metrics_by_maf(train_prob, train_gts, hap_map = gt_enc.hap_map, maf_vec = chunk_maf, mask=train_mask)
                val_bins_metrics   = metrics_by_maf(val_prob,   val_gts, hap_map = gt_enc.hap_map, maf_vec = chunk_maf, mask=None)
                print_maf_stat_df(chunk_bin_cnt,train_bins_metrics,val_bins_metrics)
                print(f'  --> updated {model_name}_chunk_{cid}.pth')
        else:
            patience_counter += 1
            if patience_counter >= earlystop_patience:
                is_early_stopped = True
                print(f'Chunk {cid + 1}/{model.n_chunks}, Early stopping triggered')
                # Report the best epoch; fall back to the current epoch if none improved.
                if predres_with_bestloss is not None:
                    train_prob, train_gts, train_mask, val_prob, val_gts = predres_with_bestloss
                train_bins_metrics = metrics_by_maf(train_prob, train_gts, hap_map = gt_enc.hap_map, maf_vec = chunk_maf, mask=train_mask)
                val_bins_metrics   = metrics_by_maf(val_prob,   val_gts, hap_map = gt_enc.hap_map, maf_vec = chunk_maf, mask=None)
                print_maf_stat_df(chunk_bin_cnt,train_bins_metrics,val_bins_metrics)
                break

    if not is_early_stopped:
        # All epochs were used: report the results of the last epoch.
        train_bins_metrics = metrics_by_maf(train_prob, train_gts, hap_map = gt_enc.hap_map, maf_vec = chunk_maf, mask=train_mask)
        val_bins_metrics   = metrics_by_maf(val_prob,   val_gts, hap_map = gt_enc.hap_map, maf_vec = chunk_maf, mask=None)
        print_maf_stat_df(chunk_bin_cnt,train_bins_metrics,val_bins_metrics)
    del optimizer, scheduler
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# ---------------- all chunks trained -> save the full model ----------------
final_ckpt = {
    'model_state': model.state_dict(),
    'n_chunks': model.n_chunks,
    'chunk_size': model.chunk_size,
    'chunk_overlap': model.chunk_overlap,
}
# Build the path once so the log line names the file that was actually written
# (the message previously said "_stage1.pth" while the file is "_stage_1.pth").
stage1_path = f'{work_dir}/models/{model_name}_stage_1.pth'
torch.save(final_ckpt, stage1_path)
print(f'==> STAGE1 (Chunk Module) training finished: {stage1_path}')

### 2.4 stage 2: Ultra-Long-Range LD Module Training

In [None]:
# Stage-2 settings, stage-1 weights and the two optimizers used for every chunk pair.
max_epochs_per_pair = 100
lr                 = 5e-4  # NOTE(review): not passed to either optimizer below (both use lr=1e-4) — confirm
weight_decay       = 1e-5  # NOTE(review): optim_dense hard-codes the same value instead of reading this
earlystop_patience = 15
batch_size         = 8  # NOTE(review): train_loader/val_loader were already built in 2.1, so this has no effect here
verbose            = True

criterion = ImputationLoss(use_r2=True, use_evo=True, r2_weight=1, evo_weight=4, evo_lambda=10)

# ----------- alternative: load the weights chunk by chunk -----------
# for cid in range(model.n_chunks):
#     chunk_file = f'{work_dir}/models/{model_name}_chunk_{cid}.pth'
#     ckpt = torch.load(chunk_file, map_location='cpu')
#     model.chunk_embeds[cid].load_state_dict(ckpt['chunk_embed_state'])
#     model.chunk_modules[cid].load_state_dict(ckpt['chunk_module_state'])

# ----------- load the complete stage-1 weights -----------
ckpt = torch.load(f'{work_dir}/models/{model_name}_stage_1.pth', map_location='cpu')
model.load_state_dict(ckpt['model_state'])

model.eval()        # NOTE(review): original note says the chunk experts are frozen (requires_grad=False), but no requires_grad flag is changed here — confirm
model.global_out.set_ulr_enabled(True)  # switch the ULR branch on

# Separate optimizers: the index embedding table gets SparseAdam
embed_weight = model.global_out.ulr_mamba.idx_embed.weight   # per the original note, created with sparse=True — confirm
optim_sparse = SparseAdam([embed_weight], lr=1e-4)

# 2. All remaining trainable parameters of global_out (the embedding table is excluded by name)
dense_params = [
    p for n, p in model.global_out.named_parameters()
    if p.requires_grad and 'idx_embed.weight' not in n
]

optim_dense = Adam(dense_params, lr=1e-4, weight_decay=1e-5, betas=(0.9, 0.999))

scheduler_sparse = ReduceLROnPlateau(optim_sparse, mode='min', factor=0.5,
                                     patience=5, min_lr=1e-9)
scheduler_dense  = ReduceLROnPlateau(optim_dense,  mode='min', factor=0.5,
                                     patience=5, min_lr=1e-9)

# Every unordered pair of chunks is trained once, in random order.
pair_list = list(combinations(range(model.n_chunks), 2))
np.random.shuffle(pair_list)          # shuffle in place
total_pairs = len(pair_list)

for pair_idx, (cid1, cid2) in enumerate(pair_list, 1):
    # ====== union of the two chunk masks ======
    union_mask = (model.chunk_masks[cid1] + model.chunk_masks[cid2]).clamp(max=1).bool()
    train_logs_sum = None
    val_logs_sum   = None

    # MAF vector and per-bin site counts over the union of both chunks.
    # NOTE(review): stage 1 converts its MAF vector to a tensor on `device` before
    # calling metrics_by_maf; here it is passed as returned by precompute_maf — confirm.
    union_maf, union_bin_cnt = precompute_maf(
        gt_enc.X_gt[:, union_mask.cpu().numpy()].toarray(),
        mask_int=gt_enc.seq_depth
    )

    # ====== early-stopping state ======
    best_loss = float('inf')
    patience_counter = 0
    is_early_stopped = False
    # Reset per pair so a snapshot from a previous pair can never be reported,
    # and so the early-stop branch is safe even if no epoch ever improved.
    predres_with_bestloss = None
    # The checkpoint file keeps the 0-based chunk ids; reuse this name in the log line
    # so the message matches the file that was actually written.
    pair_ckpt_name = f'{model_name}_chunk_{cid1}-{cid2}.pth'

    # ====== training loop ======
    for epoch in range(max_epochs_per_pair):
        # NOTE(review): model.train() puts every sub-module in training mode, including
        # the chunk experts; if they are meant to stay frozen during stage 2, only
        # model.global_out should be switched to train mode — confirm.
        model.train()
        train_loss = 0.0
        train_prob, train_gts, train_mask = [], [], []

        pbar = tqdm(train_loader,
                    desc=f'Comb {pair_idx}/{total_pairs}  '
                         f'{cid1+1}-{cid2+1}  Epoch {epoch+1}/{max_epochs_per_pair}',
                    leave=False)
        for x, target, evo_mat in pbar:
            x, target = x.to(device), target.to(device)
            # The collate function sends an empty tensor when there is no evo matrix.
            evo_mat = evo_mat.to(device) if evo_mat.numel() else None

            optim_sparse.zero_grad()
            optim_dense.zero_grad()

            logits, prob, mask_idx = model(x, [cid1, cid2])
            loss, logs = criterion(logits[:, mask_idx], prob[:, mask_idx], target[:, mask_idx], evo_mat)
            loss.backward()

            optim_sparse.step()   # updates the embedding table only
            optim_dense.step()    # updates all other optimised parameters

            train_loss += loss.item()
            if train_logs_sum is None:          # first batch: create the accumulators
                train_logs_sum = {k: 0.0 for k in logs}
            for k, v in logs.items():
                train_logs_sum[k] += v

            # Collect predictions for the MAF report.
            miss_mask = x[:, union_mask][..., -1].bool()
            train_prob.append(prob[:, mask_idx].detach())
            train_gts.append(target[:, mask_idx].detach())
            train_mask.append(miss_mask)

        train_prob = torch.cat(train_prob, dim=0)
        train_gts  = torch.cat(train_gts,  dim=0)
        train_mask = torch.cat(train_mask, dim=0)

        # ----------- validation -----------
        model.eval()
        val_loss = 0.0
        val_prob, val_gts = [], []
        with torch.no_grad():
            if val_logs_sum is None:
                val_logs_sum = {k: 0.0 for k in train_logs_sum}
            for x, target, evo_mat in val_loader:
                x = x.to(device)
                target = target.to(device)
                evo_mat = evo_mat.to(device) if evo_mat.numel() else None
                logits, prob, mask_idx = model(x, [cid1, cid2])
                loss, logs = criterion(logits[:, mask_idx], prob[:, mask_idx], target[:, mask_idx], evo_mat)
                val_loss += loss.item()
                for k, v in logs.items():
                    val_logs_sum[k] += v
                val_prob.append(prob[:, mask_idx])
                val_gts.append(target[:, mask_idx])

        val_prob = torch.cat(val_prob, dim=0)
        val_gts  = torch.cat(val_gts,  dim=0)

        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss   = val_loss   / len(val_loader)
        avg_train_logs = {k: v / len(train_loader) for k, v in train_logs_sum.items()}
        avg_val_logs   = {k: v / len(val_loader) for k, v in val_logs_sum.items()}

        # Step on the mean validation loss: the same quantity that is logged, used for
        # early stopping, and fed to the scheduler in stage 1.
        scheduler_sparse.step(avg_val_loss)
        scheduler_dense.step(avg_val_loss)

        current_denselr = optim_dense.param_groups[0]['lr']
        current_sparselr = optim_sparse.param_groups[0]['lr']

        log_str = (f'Comb {pair_idx}/{total_pairs}  '
            f'{cid1+1}-{cid2+1}  Epoch {epoch+1}/{max_epochs_per_pair} '
            f'Total Loss, Train = {avg_train_loss:.1f}, '
            f'Val = {avg_val_loss:.1f}, '
            f'dense LR: {current_denselr:.2e}, '
            f'sparse LR: {current_sparselr:.2e}')
        log_str += '\n        Train'
        for k, v in avg_train_logs.items():
            log_str += f', {k}: {v:.1f}'
        log_str += '\n        Val  '
        for k, v in avg_val_logs.items():
            log_str += f', {k}: {v:.1f}'
        print(log_str)
        # Reset the accumulators for the next epoch.
        train_logs_sum = {k: 0.0 for k in train_logs_sum}
        val_logs_sum   = {k: 0.0 for k in val_logs_sum}
        # Early stopping
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            patience_counter = 0
            torch.save({
                'comb': (cid1, cid2),
                'global_out': model.global_out.state_dict(),
                'best_val_loss': best_loss,
                'epoch': epoch,
            }, f'{work_dir}/models/{pair_ckpt_name}')
            # Keep the mask with the predictions it belongs to: the training loader
            # is shuffled, so a later epoch's mask would not line up with these rows.
            predres_with_bestloss = (train_prob, train_gts, train_mask, val_prob, val_gts)
            if verbose:
                train_bins_metrics = metrics_by_maf(train_prob, train_gts, gt_enc.hap_map, union_maf, mask=train_mask)
                val_bins_metrics   = metrics_by_maf(val_prob,   val_gts, gt_enc.hap_map, union_maf, mask=None)
                print_maf_stat_df(union_bin_cnt,train_bins_metrics,val_bins_metrics)
                print(f'  --> updated {pair_ckpt_name}')
        else:
            patience_counter += 1
            if patience_counter >= earlystop_patience:
                is_early_stopped = True
                print(f'Pair {cid1+1}-{cid2+1} early stopping')
                # Report the best epoch; fall back to the current epoch if none improved.
                if predres_with_bestloss is not None:
                    train_prob, train_gts, train_mask, val_prob, val_gts = predres_with_bestloss
                train_bins_metrics = metrics_by_maf(train_prob, train_gts, gt_enc.hap_map, union_maf, mask=train_mask)
                val_bins_metrics   = metrics_by_maf(val_prob,   val_gts, gt_enc.hap_map, union_maf, mask=None)
                print_maf_stat_df(union_bin_cnt,train_bins_metrics,val_bins_metrics)
                break

    if not is_early_stopped:
        # All epochs were used: report the results of the last epoch.
        train_bins_metrics = metrics_by_maf(train_prob, train_gts, gt_enc.hap_map, union_maf, mask=train_mask)
        val_bins_metrics   = metrics_by_maf(val_prob,   val_gts, gt_enc.hap_map, union_maf, mask=None)
        print_maf_stat_df(union_bin_cnt,train_bins_metrics,val_bins_metrics)

    # The optimizers and schedulers are shared by all pairs, so they are not deleted here.
    torch.cuda.empty_cache()

# ----------- all pairs finished -> save the final model -----------
stage2_path = f'{work_dir}/models/{model_name}_stage2_final.pth'
stage2_ckpt = {
    'model_state': model.state_dict(),
    'ulr_enabled': True,
}
torch.save(stage2_ckpt, stage2_path)
print(f'==> STAGE2 training finished: {stage2_path}')

## 3. Imputation using trained model

### 3.1 Load the trained model

Choose a `work_dir` that contains a `models/` subdirectory with the trained model.

In [None]:
work_dir = '/mnt/qmtang/EvoFill/data/251105_ver4_chr22trim'

# ---- 1. load the model ----
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gt_enc = GenotypeEncoder.loadfromdisk(work_dir)

# Rebuild the network from the constructor arguments saved at training time.
json_path = f"{work_dir}/model_meta.json"
with open(json_path) as meta_file:      # close the file handle deterministically
    meta = json.load(meta_file)
model = EvoFill(
    d_model=int(meta["d_model"]),
    alleles=int(meta["alleles"]),
    total_sites=int(meta["total_sites"]),
    chunk_size=int(meta["chunk_size"]),
    overlap=int(meta["overlap"])
).to(device)

ckpt = torch.load(f'{work_dir}/models/{meta["model_name"]}_stage2_final.pth', map_location=device)
model.load_state_dict(ckpt['model_state'])
# The stage-2 checkpoint records whether the ULR branch was enabled during training;
# apply it to the freshly built model instead of relying on the constructor default.
if 'ulr_enabled' in ckpt:
    model.global_out.set_ulr_enabled(bool(ckpt['ulr_enabled']))
model.eval()
print(f'[INF] 模型加载完成，ULR-enabled={model.global_out.ulr_enabled}')

### 3.2 Encode .vcf file need be impute

In [None]:
# Encode the VCF to be imputed against the metadata of the training encoding (ref_meta_json).
target_dir = "/mnt/qmtang/EvoFill/data/251105_ver4_chr22trim_2/"
gt_enc = GenotypeEncoder(save_dir=target_dir, phased=True, gts012=False)
gt_enc = gt_enc.encode_ref(
    ref_meta_json="/mnt/qmtang/EvoFill/data/251105_ver4_chr22trim/gt_enc_meta.json",
    vcf_path='/home/qmtang/GitHub/STICI-HPC/data/training_sets/ALL.chr22.training.samples.100k.any.type.0.01.maf.variants.vcf.gz',
    evo_mat="/mnt/qmtang/EvoFill/data/251104_ver3_chr22trim_ex/evo_mat.tsv",
)

[DATA] 总计 99,314 个位点  
[DATA] 位点矩阵 = (2404, 99314)，稀疏度 = 28.10%
[DATA] EvoMat shape: (2404, 2404)
[DATA] 结果已写入 /mnt/qmtang/EvoFill/data/251105_ver4_chr22trim_2/


<src.data.GenotypeEncoder at 0x712ec5c972b0>

In [None]:
# NOTE(review): this reloads the encoder from `work_dir` (the training directory set in 3.1),
# not from the save_dir given to GenotypeEncoder in the previous cell — confirm which data set
# is meant to be imputed.
gt_enc = GenotypeEncoder.loadfromdisk(work_dir)
print(f'{gt_enc.n_samples:,} samples, {gt_enc.n_variants:,} variants, {gt_enc.seq_depth} seq-depth.')

In [None]:
# ---- 2. build the inference Dataset / Loader ----
# indices=None is passed here; a list of sample indices can be given instead.
imp_dataset = ImputationDataset(x_gts_sparse=gt_enc.X_gt, seq_depth=gt_enc.seq_depth, indices=None)
# Show the missing-rate statistics of the raw input.
imp_dataset.print_missing_stat()

def collate_fn(batch):
    """Stack the inputs of a batch and collect the sample indices.

    Each item is indexed as ``(x_onehot, real_idx)``; there is no target at
    inference time.  Returns the stacked inputs and the list of indices.
    """
    encoded, sample_ids = [], []
    for sample in batch:
        encoded.append(sample[0])
        sample_ids.append(sample[1])
    return torch.stack(encoded), sample_ids

# Unshuffled loader so batches arrive in dataset order.
imp_loader_kwargs = dict(
    batch_size=64,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn,
)
imp_loader = torch.utils.data.DataLoader(imp_dataset, **imp_loader_kwargs)

### 3.3 Inferring

In [None]:
# Impute every sample and store one probability row per (sample, site) pair.
#
# Fixes relative to the previous version of this cell:
#   * all post-processing happens on the CPU (a CPU tensor was indexed with a CUDA
#     mask and .numpy() was called on CUDA tensors);
#   * the sparse matrix is assembled from per-batch blocks, so row
#     `sample * n_site + site` really holds that pair's class probabilities (the
#     CSR `indices` were sample row numbers and the site column was never used);
#   * the number of classes comes from the model output instead of a hard-coded 3.
prob_blocks = []                 # one CSR block of shape (B * L, n_classes) per batch
n_site = gt_enc.n_variants
offset = 0                       # number of samples processed so far

with torch.no_grad():
    for x_onehot, real_idx in tqdm(imp_loader, desc='Imputing'):
        # x_onehot itself stays on the CPU; only the copy fed to the model is moved.
        _, prob, _ = model(x_onehot.to(device))     # (B, L, n_classes)
        prob = prob.cpu()

        B, L, n_cls = prob.shape
        if L != n_site:
            raise ValueError(f'model returned {L} sites, expected {n_site}')

        # 1. Missing sites (last input channel == 1) keep the predicted probabilities.
        miss_mask = x_onehot[..., -1] == 1                       # (B, L)

        # 2. Observed sites get a one-hot row for the observed class.  The missing
        #    channel is excluded from the argmax so the index is always a valid class.
        gt_obs = x_onehot[..., :-1].argmax(-1)                   # (B, L)
        onehot_p = torch.nn.functional.one_hot(gt_obs, n_cls).to(prob.dtype)

        merged = torch.where(miss_mask.unsqueeze(-1), prob, onehot_p)
        # Row-major reshape keeps the (sample, site) order inside the block.
        prob_blocks.append(sparse.csr_matrix(merged.reshape(B * L, n_cls).numpy()))

        offset += B

# ---- 4. stack the batches into one sparse matrix of shape (offset * n_site, n_classes) ----
prob_sp = sparse.vstack(prob_blocks, format='csr')

# ---- 5. save ----
out_dir  = os.path.join(work_dir, 'impute_out')
os.makedirs(out_dir, exist_ok=True)
sparse.save_npz(os.path.join(out_dir, 'impute_prob_sparse.npz'), prob_sp)
print(f'[INF] 稀疏概率矩阵已保存 → {out_dir}/impute_prob_sparse.npz')

### 3.4 Evaluating the imputation results