# EvoFill working demo

ver 4.   new imputation loss with evo loss

ver 4.1  long range modules integrated in stage1 training

ver 4.2  stage 3 fine tuning with under-reprensted population samples.

last update: 2025/11/13

In [23]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0" 
os.chdir('/mnt/qmtang/EvoFill/')

## 0. Dependency

In [24]:
import sys
import json
import numpy as np
from pathlib import Path
import pandas as pd
import torch
import mamba_ssm
from tqdm import tqdm
from itertools import combinations
from torch.utils.data import DataLoader
from torch.optim import AdamW, SparseAdam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.model_selection import train_test_split

print('Python ver:  ', sys.version)
print('Pytorch ver: ', torch.__version__)
print('Mamba ver:   ', mamba_ssm.__version__)
print('GPU in use:  ', torch.cuda.get_device_name(torch.cuda.current_device()))

Python ver:   3.10.19 | packaged by conda-forge | (main, Oct 13 2025, 14:08:27) [GCC 14.3.0]
Pytorch ver:  2.8.0+cu129
Mamba ver:    2.2.5
GPU in use:   NVIDIA H100 PCIe


In [25]:
from src.data import GenotypeEncoder, GenomicDataset, GenomicDataset_Missing, ImputationDataset
from src.model import EvoFill
from src.loss import ImputationLoss, ImputationLoss_Missing
from src.utils import setup_workdir, set_seed, precompute_maf, metrics_by_maf, print_maf_stat_df

In [None]:
work_dir = Path('/mnt/qmtang/EvoFill_data/1kGP_chr22')
setup_workdir(work_dir)
os.chdir(work_dir)

## 1. Pretraining (1kGP)

### 1.1 Dataloader

In [None]:
gt_enc = GenotypeEncoder(phased = False, gts012 = False, save2disk = True, save_dir = Path(work_dir / "pretrain"))
gt_enc = gt_enc.encode_new(vcf_path   = "/mnt/qmtang/EvoFill_data/1kGP_chr22/pretrain/major_pops.vcf.gz" ,
                           default_gt = 'ref',
                           evo_mat    = "/mnt/qmtang/EvoFill_data/1kGP_chr22/pretrain/evo_mat_major_pops.tsv")

print(f"[DATA] {gt_enc.n_samples:,} Samples")
print(f"[DATA] {gt_enc.n_variants:,} Variants Sites")
print(f"[DATA] {gt_enc.seq_depth} seq_depth")

[DATA] 总计 15,792 个位点  
[DATA] EvoMat shape: (2222, 2222)
[DATA] 结果已写入 /mnt/qmtang/EvoFill_data/20251118_v4.3/pretrain
[DATA] 位点矩阵 = (2222, 15792)，稀疏度 = 44.01%
[DATA] 位点字典 = {'0|1': 1, '1|1': 2, '0|0': 0, '.|.': 3}，字典深度 = 4
[DATA] 2,222 Samples
[DATA] 15,792 Variants Sites
[DATA] 4 seq_depth


In [None]:
gt_enc = GenotypeEncoder.loadfromdisk(Path(work_dir / "pretrain"))
print(f'{gt_enc.n_samples:,} samples, {gt_enc.n_variants:,} variants, {gt_enc.seq_depth} seq-depth.')

In [6]:
test_n_samples = 128
batch_size    = 16
max_mr        = 0.7
min_mr        = 0.3

x_train_indices, x_test_indices = train_test_split(
    range(gt_enc.n_samples),
    test_size=test_n_samples,
    random_state=3047,
    shuffle=True
)

print(f"{len(x_train_indices):,} samples in train")
print(f"{len(x_test_indices):,} samples in test")
# print(f"evo_mat: {gt_enc.evo_mat.shape}")

train_dataset = GenomicDataset(
    gt_enc,
    evo_mat=gt_enc.evo_mat,
    mask=True,
    masking_rates=(min_mr, max_mr),
    indices=x_train_indices
)

test_dataset = GenomicDataset(
    gt_enc,
    evo_mat=gt_enc.evo_mat,
    mask=True,
    masking_rates=(min_mr, max_mr),
    indices=x_test_indices
)

def collate_fn(batch):
    x_onehot = torch.stack([item[0] for item in batch])
    y_onehot = torch.stack([item[1] for item in batch])
    real_idx_list = [item[2] for item in batch]
    # 提取 evo_mat 子矩阵
    if train_dataset.evo_mat is not None:
        evo_mat_batch = train_dataset.evo_mat[np.ix_(real_idx_list, real_idx_list)]
        evo_mat_batch = torch.FloatTensor(evo_mat_batch)
    else:
        evo_mat_batch = torch.empty(0)

    return x_onehot, y_onehot, evo_mat_batch
    
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn
)

2,094 samples in train
128 samples in test


### 1.2 Model initialization

In [None]:
model_name  = 'chr22'
total_sites = gt_enc.n_variants
alleles     = gt_enc.seq_depth
chunk_size  = 32768
overlap     = 1024
d_model     = 64

set_seed(42)

# Setup device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.cuda.empty_cache()
print(f"Using device: {device}")

model = EvoFill(d_model, alleles, total_sites, chunk_size, overlap).to(device)
print(f"model[{model_name}] would have {model.n_chunks} chunks.")

criterion = ImputationLoss(use_r2=True, use_evo=True, r2_weight=1, evo_weight=4, evo_lambda=10)
# criterion = ImputationLoss(use_r2=True, use_evo=False)

meta = {
    "model_name": str(model_name),
    "total_sites": int(total_sites),
    "alleles": int(alleles),
    "chunk_size": int(chunk_size),
    "overlap": int(overlap),
    "d_model": int(d_model)
}
save_path = os.path.join(Path(work_dir / "models"), "model_meta.json")
with open(save_path, "w") as f:
    json.dump(meta, f, indent=4)

Using device: cuda
model[aadr_hg19_chr22] would have 1 chunks.


### 1.3 Chunk Module Training

In [8]:
verbose            = False
max_epochs         = 100
lr                 = 0.001
weight_decay       = 1e-5
earlystop_patience = 11
scheduler_factor   = 0.5
scheduler_patience = 5
scheduler_min_lr   = 1e-8


for cid in range(model.n_chunks):
    chunk_mask = model.chunk_masks[cid].cpu()
    chunk_maf, chunk_bin_cnt = precompute_maf(gt_enc.X_gt[:,chunk_mask.bool().cpu().numpy()].toarray(),  mask_int=gt_enc.seq_depth)
    if verbose:
        print(f"=== Chunk {cid + 1} STAT ===")
        maf_df = pd.DataFrame({
            'MAF_bin': ['(0.00, 0.05)', '(0.05, 0.10)', '(0.10, 0.20)',
                        '(0.20, 0.30)', '(0.30, 0.40)', '(0.40, 0.50)'],
            'Counts':  [f"{c}" for c in chunk_bin_cnt],
        })
        print(maf_df.to_string(index=False))

    # 收集所有稀疏参数（主要是嵌入层）
    sparse_params = []
    dense_params = []
    # 当前chunk的模块参数（密集）
    dense_params.extend(model.chunk_modules[cid].parameters())
    
    # 全局输出层的卷积参数（密集）
    dense_params.extend([model.global_out.w1, model.global_out.b1])
    dense_params.extend([model.global_out.w2, model.global_out.b2])
    # ULR默认启用
    if hasattr(model.global_out, 'ulr_mamba'):
        for name, param in model.global_out.ulr_mamba.named_parameters():
            if 'idx_embed' in name:
                sparse_params.append(param)
            else:
                dense_params.append(param)
    
    # 创建分离优化器
    optim_sparse = SparseAdam(sparse_params, lr=lr) if sparse_params else None
    optim_dense = AdamW(dense_params, lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999))
    
    # 学习率调度器
    scheduler_sparse = ReduceLROnPlateau(optim_sparse, mode='min', factor=scheduler_factor, 
                        patience=scheduler_patience, min_lr=scheduler_min_lr) if optim_sparse else None
    scheduler_dense = ReduceLROnPlateau(optim_dense, mode='min', factor=scheduler_factor, 
                        patience=scheduler_patience, min_lr=scheduler_min_lr)
    
    best_loss = float('inf')
    patience = earlystop_patience
    patience_counter = 0
    is_early_stopped = False
    train_logs_sum = None
    test_logs_sum   = None
    for epoch in range(max_epochs):
        model.train()
        train_loss = 0.0
        train_prob, train_gts, train_mask = [], [], []

        train_pbar = tqdm(train_loader, desc=f'Chunk {cid + 1}/{model.n_chunks}, Epoch {epoch + 1}/{max_epochs}',) # leave=False
        for batch_idx, (x, target, evo_mat) in enumerate(train_pbar):
            x,  target = x.to(device), target.to(device)
            if evo_mat.numel() == 0:
                evo_mat = None
            else:
                evo_mat = evo_mat.to(device)

            # 清零梯度
            if optim_sparse:
                optim_sparse.zero_grad()
            optim_dense.zero_grad()
            
            # 前向传播
            logits, prob, mask_idx = model(x, cid)
            loss, logs = criterion(logits[:, mask_idx], prob[:, mask_idx], target[:, mask_idx], evo_mat)
            
            # 反向传播
            loss.backward()
            
            # 更新参数
            if optim_sparse:
                optim_sparse.step()
            optim_dense.step()

            train_loss += loss.item()
            if train_logs_sum is None:          # 第一次初始化
                train_logs_sum = {k: 0.0 for k in logs}
            for k, v in logs.items():
                train_logs_sum[k] += v
            # train_pbar.set_postfix({'loss': loss.item(), 'ce':logs['ce'], 'r2':logs['r2'], 'evo':logs['evo']})

            # === 收集训练结果 ===
            miss_mask = x[:, mask_idx][..., -1].bool()         # 只关心被 mask 的位点
            train_prob.append(prob[:, mask_idx].detach())
            train_gts.append(target[:,mask_idx].detach())
            train_mask.append(miss_mask)

        # 训练集 MAF-acc
        train_prob = torch.cat(train_prob, dim=0)
        train_gts  = torch.cat(train_gts,    dim=0)
        train_mask = torch.cat(train_mask,   dim=0)

        # ----------- 验证循环同理 ------------
        model.eval()
        test_loss = 0.0
        test_prob, test_gts = [], []
        if test_logs_sum is None:
            test_logs_sum = {k: 0.0 for k in train_logs_sum}
        with torch.no_grad():
            for x, target, evo_mat in test_loader:
                x,  target = x.to(device), target.to(device)
                if evo_mat.numel() == 0:
                    evo_mat = None
                else:
                    evo_mat = evo_mat.to(device)
                logits, prob, mask_idx = model(x, cid)
                loss, logs = criterion(logits[:, mask_idx], prob[:, mask_idx], target[:,mask_idx], evo_mat) 
                test_loss += loss.item()
                for k, v in logs.items():
                    test_logs_sum[k] += v
                test_prob.append(prob[:, mask_idx].detach())
                test_gts.append(target[:,mask_idx].detach())

        test_prob = torch.cat(test_prob, dim=0)
        test_gts    = torch.cat(test_gts,    dim=0)
        avg_train_loss = train_loss / len(train_loader)
        avg_test_loss   = test_loss   / len(test_loader)
        avg_train_logs = {k: v / len(train_loader) for k, v in train_logs_sum.items()}
        avg_test_logs = {k: v / len(test_loader) for k, v in test_logs_sum.items()}
        
        # 更新学习率
        if scheduler_sparse:
            scheduler_sparse.step(avg_test_loss)
        scheduler_dense.step(avg_test_loss)
        
        current_denselr = optim_dense.param_groups[0]['lr']
        current_sparselr = optim_sparse.param_groups[0]['lr'] if optim_sparse else 0

        log_str = (f'Chunk {cid + 1}/{model.n_chunks}, '
            f'Epoch {epoch + 1}/{max_epochs}, '
            f'Total Loss, Train = {avg_train_loss:.1f}, '
            f'Test = {avg_test_loss:.1f}, '
            f'LR: {current_denselr:.2e}')

        log_str += '\n        Train'
        for k, v in avg_train_logs.items():
            log_str += f', {k}: {v:.1f}'
        log_str += '\n        Test '
        for k, v in avg_test_logs.items():
            log_str += f', {k}: {v:.1f}'
        print(log_str)

        # 清空累加器，供下一个 epoch 使用
        train_logs_sum = {k: 0.0 for k in train_logs_sum}
        test_logs_sum   = {k: 0.0 for k in test_logs_sum}
        
        # Early stopping
        if avg_test_loss < best_loss:
            best_loss = avg_test_loss
            patience_counter = 0
            # 只存当前 chunk 专家 + 全局层
            ckpt = {
                'chunk_id': cid,
                'chunk_embed_state': model.chunk_embeds[cid].state_dict(),
                'chunk_module_state': model.chunk_modules[cid].state_dict(),
                'global_out_state': model.global_out.state_dict(),
                'best_test_loss': best_loss,
            }
            torch.save(ckpt, f'{work_dir}/models/{model_name}_chunk[{cid}].pth')
            predres_with_bestloss = (train_prob, train_gts, test_prob, test_gts)
            if verbose:
                train_bins_metrics = metrics_by_maf(train_prob, train_gts, hap_map = gt_enc.hap_map, maf_vec = chunk_maf, mask=train_mask)
                test_bins_metrics   = metrics_by_maf(test_prob,  test_gts, hap_map = gt_enc.hap_map, maf_vec = chunk_maf, mask=None)
                print_maf_stat_df(chunk_bin_cnt,
                      {"train": train_bins_metrics,
                       "test":  test_bins_metrics})
                print(f'  --> updated {model_name}_chunk[{cid}].pth')
        else:
            patience_counter += 1
            if patience_counter >= earlystop_patience:
                is_early_stopped = True
                print(f'Chunk {cid + 1}/{model.n_chunks}, Early stopping triggered')
                train_prob, train_gts, test_prob, test_gts = predres_with_bestloss
                train_bins_metrics = metrics_by_maf(train_prob, train_gts, hap_map = gt_enc.hap_map, maf_vec = chunk_maf, mask=train_mask)
                test_bins_metrics   = metrics_by_maf(test_prob,   test_gts, hap_map = gt_enc.hap_map, maf_vec = chunk_maf, mask=None)
                print_maf_stat_df(chunk_bin_cnt,
                      {"train": train_bins_metrics,
                       "test":   test_bins_metrics})
                break

    if not is_early_stopped:
        train_bins_metrics = metrics_by_maf(train_prob, train_gts, hap_map = gt_enc.hap_map, maf_vec = chunk_maf, mask=train_mask)
        test_bins_metrics   = metrics_by_maf(test_prob,   test_gts, hap_map = gt_enc.hap_map, maf_vec = chunk_maf, mask=None)
        print_maf_stat_df(chunk_bin_cnt,
                      {"train": train_bins_metrics,
                       "test":   test_bins_metrics})

    best_ckpt_path = f'{work_dir}/models/{model_name}_chunk[{cid}].pth'
    best_ckpt = torch.load(best_ckpt_path, map_location='cpu')
    model.chunk_embeds[cid].load_state_dict(best_ckpt['chunk_embed_state'])
    model.chunk_modules[cid].load_state_dict(best_ckpt['chunk_module_state'])
    model.global_out.load_state_dict(best_ckpt['global_out_state'])
    print(f'  --> Chunk {cid + 1} loaded best weights (test_loss={best_ckpt["best_test_loss"]:.3f})')

    # 清理优化器
    del optim_sparse, optim_dense
    if scheduler_sparse:
        del scheduler_sparse
    del scheduler_dense
    torch.cuda.empty_cache()

# ---------------- 全部 chunk 训练完成 -> 保存完整模型 ----------------
final_ckpt = {
    'model_state': model.state_dict(),
    'n_chunks': model.n_chunks,
    'chunk_size': model.chunk_size,
    'chunk_overlap': model.chunk_overlap,
}
torch.save(final_ckpt, f'{work_dir}/models/{model_name}_stage1.pth')
print(f'==> STAGE1 (Chunk Module) training finished: {work_dir}/models/{model_name}_stage1.pth')

Chunk 1/1, Epoch 1/100: 100%|██████████| 131/131 [01:15<00:00,  1.73it/s]


Chunk 1/1, Epoch 1/100, Total Loss, Train = 120527.3, Test = 86382.3, LR: 1.00e-03
        Train, ce: 120527.3, r2: -4290.3, evo: 10501.1
        Test , ce: 86382.3, r2: -5580.7, evo: 11290.6


Chunk 1/1, Epoch 2/100: 100%|██████████| 131/131 [00:42<00:00,  3.11it/s]


Chunk 1/1, Epoch 2/100, Total Loss, Train = 80379.1, Test = 82244.8, LR: 1.00e-03
        Train, ce: 80379.1, r2: -5744.3, evo: 12102.5
        Test , ce: 82244.8, r2: -6151.1, evo: 11976.8


Chunk 1/1, Epoch 3/100: 100%|██████████| 131/131 [00:42<00:00,  3.08it/s]


Chunk 1/1, Epoch 3/100, Total Loss, Train = 72650.1, Test = 66853.7, LR: 1.00e-03
        Train, ce: 72650.1, r2: -6083.0, evo: 11815.3
        Test , ce: 66853.7, r2: -6482.9, evo: 11368.0


Chunk 1/1, Epoch 4/100: 100%|██████████| 131/131 [00:42<00:00,  3.11it/s]


Chunk 1/1, Epoch 4/100, Total Loss, Train = 67181.4, Test = 63620.5, LR: 1.00e-03
        Train, ce: 67181.4, r2: -6318.6, evo: 11502.5
        Test , ce: 63620.5, r2: -6640.0, evo: 10723.9


Chunk 1/1, Epoch 5/100: 100%|██████████| 131/131 [00:42<00:00,  3.11it/s]


Chunk 1/1, Epoch 5/100, Total Loss, Train = 64276.2, Test = 62745.3, LR: 1.00e-03
        Train, ce: 64276.2, r2: -6439.3, evo: 11257.0
        Test , ce: 62745.3, r2: -6604.5, evo: 11142.8


Chunk 1/1, Epoch 6/100: 100%|██████████| 131/131 [00:42<00:00,  3.11it/s]


Chunk 1/1, Epoch 6/100, Total Loss, Train = 61431.3, Test = 61017.3, LR: 1.00e-03
        Train, ce: 61431.3, r2: -6542.6, evo: 11031.0
        Test , ce: 61017.3, r2: -6890.3, evo: 10676.8


Chunk 1/1, Epoch 7/100: 100%|██████████| 131/131 [00:41<00:00,  3.12it/s]


Chunk 1/1, Epoch 7/100, Total Loss, Train = 59598.6, Test = 56682.3, LR: 1.00e-03
        Train, ce: 59598.6, r2: -6630.2, evo: 10858.6
        Test , ce: 56682.3, r2: -7028.2, evo: 10335.8


Chunk 1/1, Epoch 8/100: 100%|██████████| 131/131 [00:42<00:00,  3.10it/s]


Chunk 1/1, Epoch 8/100, Total Loss, Train = 56972.5, Test = 59138.2, LR: 1.00e-03
        Train, ce: 56972.5, r2: -6723.1, evo: 10612.3
        Test , ce: 59138.2, r2: -6898.9, evo: 10494.4


Chunk 1/1, Epoch 9/100: 100%|██████████| 131/131 [00:41<00:00,  3.13it/s]


Chunk 1/1, Epoch 9/100, Total Loss, Train = 55705.4, Test = 51417.8, LR: 1.00e-03
        Train, ce: 55705.4, r2: -6766.3, evo: 10448.8
        Test , ce: 51417.8, r2: -7120.3, evo: 9721.6


Chunk 1/1, Epoch 10/100: 100%|██████████| 131/131 [00:42<00:00,  3.10it/s]


Chunk 1/1, Epoch 10/100, Total Loss, Train = 54229.2, Test = 52766.5, LR: 1.00e-03
        Train, ce: 54229.2, r2: -6806.8, evo: 10299.0
        Test , ce: 52766.5, r2: -7003.8, evo: 9664.1


Chunk 1/1, Epoch 11/100: 100%|██████████| 131/131 [00:42<00:00,  3.08it/s]


Chunk 1/1, Epoch 11/100, Total Loss, Train = 53164.3, Test = 48600.6, LR: 1.00e-03
        Train, ce: 53164.3, r2: -6843.9, evo: 10193.7
        Test , ce: 48600.6, r2: -7187.2, evo: 9356.5


Chunk 1/1, Epoch 12/100: 100%|██████████| 131/131 [00:43<00:00,  3.03it/s]


Chunk 1/1, Epoch 12/100, Total Loss, Train = 51827.3, Test = 48895.9, LR: 1.00e-03
        Train, ce: 51827.3, r2: -6908.9, evo: 10055.7
        Test , ce: 48895.9, r2: -7162.1, evo: 9568.7


Chunk 1/1, Epoch 13/100: 100%|██████████| 131/131 [00:42<00:00,  3.05it/s]


Chunk 1/1, Epoch 13/100, Total Loss, Train = 50985.4, Test = 48473.5, LR: 1.00e-03
        Train, ce: 50985.4, r2: -6935.7, evo: 9939.9
        Test , ce: 48473.5, r2: -7174.6, evo: 9675.8


Chunk 1/1, Epoch 14/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 14/100, Total Loss, Train = 50100.8, Test = 47107.0, LR: 1.00e-03
        Train, ce: 50100.8, r2: -6976.6, evo: 9851.0
        Test , ce: 47107.0, r2: -7248.4, evo: 9169.7


Chunk 1/1, Epoch 15/100: 100%|██████████| 131/131 [00:42<00:00,  3.08it/s]


Chunk 1/1, Epoch 15/100, Total Loss, Train = 48772.9, Test = 48043.3, LR: 1.00e-03
        Train, ce: 48772.9, r2: -7012.0, evo: 9706.4
        Test , ce: 48043.3, r2: -7262.7, evo: 9467.7


Chunk 1/1, Epoch 16/100: 100%|██████████| 131/131 [00:42<00:00,  3.06it/s]


Chunk 1/1, Epoch 16/100, Total Loss, Train = 48313.0, Test = 47164.7, LR: 1.00e-03
        Train, ce: 48313.0, r2: -7021.8, evo: 9658.4
        Test , ce: 47164.7, r2: -7169.9, evo: 9427.2


Chunk 1/1, Epoch 17/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 17/100, Total Loss, Train = 47312.7, Test = 45793.8, LR: 1.00e-03
        Train, ce: 47312.7, r2: -7078.7, evo: 9527.7
        Test , ce: 45793.8, r2: -7262.7, evo: 9469.9


Chunk 1/1, Epoch 18/100: 100%|██████████| 131/131 [00:42<00:00,  3.06it/s]


Chunk 1/1, Epoch 18/100, Total Loss, Train = 46800.2, Test = 43136.7, LR: 1.00e-03
        Train, ce: 46800.2, r2: -7073.8, evo: 9472.4
        Test , ce: 43136.7, r2: -7341.1, evo: 8847.2


Chunk 1/1, Epoch 19/100: 100%|██████████| 131/131 [00:42<00:00,  3.08it/s]


Chunk 1/1, Epoch 19/100, Total Loss, Train = 46111.3, Test = 48983.7, LR: 1.00e-03
        Train, ce: 46111.3, r2: -7112.4, evo: 9383.6
        Test , ce: 48983.7, r2: -7181.4, evo: 9780.2


Chunk 1/1, Epoch 20/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 20/100, Total Loss, Train = 45790.9, Test = 45293.1, LR: 1.00e-03
        Train, ce: 45790.9, r2: -7143.5, evo: 9325.8
        Test , ce: 45293.1, r2: -7306.2, evo: 9444.2


Chunk 1/1, Epoch 21/100: 100%|██████████| 131/131 [00:42<00:00,  3.05it/s]


Chunk 1/1, Epoch 21/100, Total Loss, Train = 45617.1, Test = 43222.5, LR: 1.00e-03
        Train, ce: 45617.1, r2: -7127.6, evo: 9328.7
        Test , ce: 43222.5, r2: -7344.9, evo: 8679.0


Chunk 1/1, Epoch 22/100: 100%|██████████| 131/131 [00:42<00:00,  3.08it/s]


Chunk 1/1, Epoch 22/100, Total Loss, Train = 44325.7, Test = 45835.8, LR: 1.00e-03
        Train, ce: 44325.7, r2: -7156.8, evo: 9175.4
        Test , ce: 45835.8, r2: -7233.5, evo: 9138.2


Chunk 1/1, Epoch 23/100: 100%|██████████| 131/131 [00:42<00:00,  3.10it/s]


Chunk 1/1, Epoch 23/100, Total Loss, Train = 44742.5, Test = 44976.3, LR: 1.00e-03
        Train, ce: 44742.5, r2: -7169.5, evo: 9238.6
        Test , ce: 44976.3, r2: -7405.2, evo: 8671.2


Chunk 1/1, Epoch 24/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 24/100, Total Loss, Train = 44035.7, Test = 41529.0, LR: 1.00e-03
        Train, ce: 44035.7, r2: -7195.5, evo: 9134.3
        Test , ce: 41529.0, r2: -7438.4, evo: 8464.0


Chunk 1/1, Epoch 25/100: 100%|██████████| 131/131 [00:42<00:00,  3.06it/s]


Chunk 1/1, Epoch 25/100, Total Loss, Train = 43619.6, Test = 44251.7, LR: 1.00e-03
        Train, ce: 43619.6, r2: -7192.3, evo: 9064.3
        Test , ce: 44251.7, r2: -7460.1, evo: 9258.9


Chunk 1/1, Epoch 26/100: 100%|██████████| 131/131 [00:42<00:00,  3.11it/s]


Chunk 1/1, Epoch 26/100, Total Loss, Train = 43063.8, Test = 41002.3, LR: 1.00e-03
        Train, ce: 43063.8, r2: -7250.0, evo: 9009.4
        Test , ce: 41002.3, r2: -7485.5, evo: 8721.3


Chunk 1/1, Epoch 27/100: 100%|██████████| 131/131 [00:42<00:00,  3.10it/s]


Chunk 1/1, Epoch 27/100, Total Loss, Train = 42776.3, Test = 41515.5, LR: 1.00e-03
        Train, ce: 42776.3, r2: -7231.9, evo: 8983.4
        Test , ce: 41515.5, r2: -7459.5, evo: 8947.3


Chunk 1/1, Epoch 28/100: 100%|██████████| 131/131 [00:42<00:00,  3.08it/s]


Chunk 1/1, Epoch 28/100, Total Loss, Train = 42459.9, Test = 41863.1, LR: 1.00e-03
        Train, ce: 42459.9, r2: -7244.3, evo: 8943.7
        Test , ce: 41863.1, r2: -7488.1, evo: 8590.8


Chunk 1/1, Epoch 29/100: 100%|██████████| 131/131 [00:41<00:00,  3.15it/s]


Chunk 1/1, Epoch 29/100, Total Loss, Train = 43047.2, Test = 39950.6, LR: 1.00e-03
        Train, ce: 43047.2, r2: -7240.0, evo: 9020.2
        Test , ce: 39950.6, r2: -7493.7, evo: 8583.1


Chunk 1/1, Epoch 30/100: 100%|██████████| 131/131 [00:42<00:00,  3.08it/s]


Chunk 1/1, Epoch 30/100, Total Loss, Train = 42192.5, Test = 41836.6, LR: 1.00e-03
        Train, ce: 42192.5, r2: -7258.6, evo: 8919.1
        Test , ce: 41836.6, r2: -7491.0, evo: 8744.5


Chunk 1/1, Epoch 31/100: 100%|██████████| 131/131 [00:41<00:00,  3.12it/s]


Chunk 1/1, Epoch 31/100, Total Loss, Train = 42435.7, Test = 42564.0, LR: 1.00e-03
        Train, ce: 42435.7, r2: -7242.5, evo: 8909.4
        Test , ce: 42564.0, r2: -7447.3, evo: 8703.9


Chunk 1/1, Epoch 32/100: 100%|██████████| 131/131 [00:42<00:00,  3.10it/s]


Chunk 1/1, Epoch 32/100, Total Loss, Train = 42050.2, Test = 40922.7, LR: 1.00e-03
        Train, ce: 42050.2, r2: -7257.7, evo: 8893.6
        Test , ce: 40922.7, r2: -7464.1, evo: 8997.9


Chunk 1/1, Epoch 33/100: 100%|██████████| 131/131 [00:42<00:00,  3.12it/s]


Chunk 1/1, Epoch 33/100, Total Loss, Train = 41554.0, Test = 42766.6, LR: 1.00e-03
        Train, ce: 41554.0, r2: -7275.4, evo: 8804.4
        Test , ce: 42766.6, r2: -7528.9, evo: 9241.3


Chunk 1/1, Epoch 34/100: 100%|██████████| 131/131 [00:42<00:00,  3.07it/s]


Chunk 1/1, Epoch 34/100, Total Loss, Train = 41025.5, Test = 37915.8, LR: 1.00e-03
        Train, ce: 41025.5, r2: -7305.5, evo: 8763.4
        Test , ce: 37915.8, r2: -7565.8, evo: 8264.5


Chunk 1/1, Epoch 35/100: 100%|██████████| 131/131 [00:42<00:00,  3.12it/s]


Chunk 1/1, Epoch 35/100, Total Loss, Train = 40916.4, Test = 37253.3, LR: 1.00e-03
        Train, ce: 40916.4, r2: -7317.5, evo: 8753.9
        Test , ce: 37253.3, r2: -7601.1, evo: 8159.0


Chunk 1/1, Epoch 36/100: 100%|██████████| 131/131 [00:41<00:00,  3.13it/s]


Chunk 1/1, Epoch 36/100, Total Loss, Train = 40746.4, Test = 39772.3, LR: 1.00e-03
        Train, ce: 40746.4, r2: -7318.6, evo: 8724.9
        Test , ce: 39772.3, r2: -7492.0, evo: 8441.2


Chunk 1/1, Epoch 37/100: 100%|██████████| 131/131 [00:42<00:00,  3.07it/s]


Chunk 1/1, Epoch 37/100, Total Loss, Train = 40883.3, Test = 39794.6, LR: 1.00e-03
        Train, ce: 40883.3, r2: -7295.1, evo: 8745.9
        Test , ce: 39794.6, r2: -7506.7, evo: 8314.0


Chunk 1/1, Epoch 38/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 38/100, Total Loss, Train = 39733.3, Test = 39021.6, LR: 1.00e-03
        Train, ce: 39733.3, r2: -7352.9, evo: 8581.8
        Test , ce: 39021.6, r2: -7546.8, evo: 8091.2


Chunk 1/1, Epoch 39/100: 100%|██████████| 131/131 [00:42<00:00,  3.08it/s]


Chunk 1/1, Epoch 39/100, Total Loss, Train = 40391.1, Test = 42246.7, LR: 1.00e-03
        Train, ce: 40391.1, r2: -7313.8, evo: 8674.6
        Test , ce: 42246.7, r2: -7439.2, evo: 9076.3


Chunk 1/1, Epoch 40/100: 100%|██████████| 131/131 [00:42<00:00,  3.10it/s]


Chunk 1/1, Epoch 40/100, Total Loss, Train = 39596.4, Test = 38623.2, LR: 1.00e-03
        Train, ce: 39596.4, r2: -7356.5, evo: 8594.6
        Test , ce: 38623.2, r2: -7596.0, evo: 8050.3


Chunk 1/1, Epoch 41/100: 100%|██████████| 131/131 [00:42<00:00,  3.10it/s]


Chunk 1/1, Epoch 41/100, Total Loss, Train = 39994.4, Test = 39691.2, LR: 5.00e-04
        Train, ce: 39994.4, r2: -7331.2, evo: 8623.3
        Test , ce: 39691.2, r2: -7621.0, evo: 8542.6


Chunk 1/1, Epoch 42/100: 100%|██████████| 131/131 [00:42<00:00,  3.06it/s]


Chunk 1/1, Epoch 42/100, Total Loss, Train = 38070.2, Test = 35743.7, LR: 5.00e-04
        Train, ce: 38070.2, r2: -7393.2, evo: 8359.7
        Test , ce: 35743.7, r2: -7675.8, evo: 7955.1


Chunk 1/1, Epoch 43/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 43/100, Total Loss, Train = 37159.6, Test = 35162.7, LR: 5.00e-04
        Train, ce: 37159.6, r2: -7432.9, evo: 8249.3
        Test , ce: 35162.7, r2: -7639.5, evo: 7796.2


Chunk 1/1, Epoch 44/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 44/100, Total Loss, Train = 36851.9, Test = 34954.7, LR: 5.00e-04
        Train, ce: 36851.9, r2: -7430.8, evo: 8215.2
        Test , ce: 34954.7, r2: -7686.0, evo: 7982.7


Chunk 1/1, Epoch 45/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 45/100, Total Loss, Train = 36771.3, Test = 36510.4, LR: 5.00e-04
        Train, ce: 36771.3, r2: -7437.9, evo: 8202.5
        Test , ce: 36510.4, r2: -7632.2, evo: 8242.4


Chunk 1/1, Epoch 46/100: 100%|██████████| 131/131 [00:42<00:00,  3.06it/s]


Chunk 1/1, Epoch 46/100, Total Loss, Train = 36930.4, Test = 34632.0, LR: 5.00e-04
        Train, ce: 36930.4, r2: -7428.4, evo: 8233.4
        Test , ce: 34632.0, r2: -7684.4, evo: 8149.1


Chunk 1/1, Epoch 47/100: 100%|██████████| 131/131 [00:42<00:00,  3.07it/s]


Chunk 1/1, Epoch 47/100, Total Loss, Train = 37327.2, Test = 34197.5, LR: 5.00e-04
        Train, ce: 37327.2, r2: -7440.8, evo: 8291.4
        Test , ce: 34197.5, r2: -7692.5, evo: 7943.8


Chunk 1/1, Epoch 48/100: 100%|██████████| 131/131 [00:44<00:00,  2.96it/s]


Chunk 1/1, Epoch 48/100, Total Loss, Train = 36613.4, Test = 34131.7, LR: 5.00e-04
        Train, ce: 36613.4, r2: -7450.5, evo: 8190.8
        Test , ce: 34131.7, r2: -7726.4, evo: 7737.1


Chunk 1/1, Epoch 49/100: 100%|██████████| 131/131 [00:42<00:00,  3.10it/s]


Chunk 1/1, Epoch 49/100, Total Loss, Train = 36668.1, Test = 38048.2, LR: 5.00e-04
        Train, ce: 36668.1, r2: -7445.3, evo: 8178.3
        Test , ce: 38048.2, r2: -7621.6, evo: 8487.6


Chunk 1/1, Epoch 50/100: 100%|██████████| 131/131 [00:43<00:00,  3.00it/s]


Chunk 1/1, Epoch 50/100, Total Loss, Train = 36906.8, Test = 34559.1, LR: 5.00e-04
        Train, ce: 36906.8, r2: -7437.2, evo: 8237.3
        Test , ce: 34559.1, r2: -7677.8, evo: 7903.4


Chunk 1/1, Epoch 51/100: 100%|██████████| 131/131 [00:41<00:00,  3.13it/s]


Chunk 1/1, Epoch 51/100, Total Loss, Train = 36295.9, Test = 35303.5, LR: 5.00e-04
        Train, ce: 36295.9, r2: -7481.5, evo: 8146.6
        Test , ce: 35303.5, r2: -7699.7, evo: 8018.8


Chunk 1/1, Epoch 52/100: 100%|██████████| 131/131 [00:43<00:00,  3.02it/s]


Chunk 1/1, Epoch 52/100, Total Loss, Train = 35961.0, Test = 34936.7, LR: 5.00e-04
        Train, ce: 35961.0, r2: -7444.9, evo: 8090.2
        Test , ce: 34936.7, r2: -7702.9, evo: 7713.2


Chunk 1/1, Epoch 53/100: 100%|██████████| 131/131 [00:42<00:00,  3.10it/s]


Chunk 1/1, Epoch 53/100, Total Loss, Train = 35734.9, Test = 36771.2, LR: 5.00e-04
        Train, ce: 35734.9, r2: -7479.0, evo: 8064.4
        Test , ce: 36771.2, r2: -7665.1, evo: 8597.8


Chunk 1/1, Epoch 54/100: 100%|██████████| 131/131 [00:43<00:00,  3.04it/s]


Chunk 1/1, Epoch 54/100, Total Loss, Train = 36059.7, Test = 34472.1, LR: 2.50e-04
        Train, ce: 36059.7, r2: -7443.5, evo: 8114.8
        Test , ce: 34472.1, r2: -7707.1, evo: 7735.5


Chunk 1/1, Epoch 55/100: 100%|██████████| 131/131 [00:41<00:00,  3.13it/s]


Chunk 1/1, Epoch 55/100, Total Loss, Train = 34868.5, Test = 32262.8, LR: 2.50e-04
        Train, ce: 34868.5, r2: -7499.2, evo: 7933.6
        Test , ce: 32262.8, r2: -7784.7, evo: 7495.3


Chunk 1/1, Epoch 56/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 56/100, Total Loss, Train = 34689.1, Test = 35879.6, LR: 2.50e-04
        Train, ce: 34689.1, r2: -7512.1, evo: 7935.2
        Test , ce: 35879.6, r2: -7649.0, evo: 7896.9


Chunk 1/1, Epoch 57/100: 100%|██████████| 131/131 [00:42<00:00,  3.10it/s]


Chunk 1/1, Epoch 57/100, Total Loss, Train = 34342.1, Test = 35891.9, LR: 2.50e-04
        Train, ce: 34342.1, r2: -7507.4, evo: 7880.8
        Test , ce: 35891.9, r2: -7681.7, evo: 8124.3


Chunk 1/1, Epoch 58/100: 100%|██████████| 131/131 [00:42<00:00,  3.05it/s]


Chunk 1/1, Epoch 58/100, Total Loss, Train = 35033.4, Test = 32273.6, LR: 2.50e-04
        Train, ce: 35033.4, r2: -7471.6, evo: 7989.3
        Test , ce: 32273.6, r2: -7755.4, evo: 7564.2


Chunk 1/1, Epoch 59/100: 100%|██████████| 131/131 [00:42<00:00,  3.10it/s]


Chunk 1/1, Epoch 59/100, Total Loss, Train = 34459.3, Test = 33960.1, LR: 2.50e-04
        Train, ce: 34459.3, r2: -7496.0, evo: 7894.5
        Test , ce: 33960.1, r2: -7701.0, evo: 7824.4


Chunk 1/1, Epoch 60/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 60/100, Total Loss, Train = 34318.5, Test = 33634.1, LR: 2.50e-04
        Train, ce: 34318.5, r2: -7530.6, evo: 7866.5
        Test , ce: 33634.1, r2: -7725.9, evo: 7785.3


Chunk 1/1, Epoch 61/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 61/100, Total Loss, Train = 33913.6, Test = 33244.3, LR: 1.25e-04
        Train, ce: 33913.6, r2: -7547.3, evo: 7820.5
        Test , ce: 33244.3, r2: -7725.5, evo: 7587.0


Chunk 1/1, Epoch 62/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 62/100, Total Loss, Train = 33160.0, Test = 31765.4, LR: 1.25e-04
        Train, ce: 33160.0, r2: -7545.9, evo: 7690.3
        Test , ce: 31765.4, r2: -7781.6, evo: 7466.5


Chunk 1/1, Epoch 63/100: 100%|██████████| 131/131 [00:42<00:00,  3.05it/s]


Chunk 1/1, Epoch 63/100, Total Loss, Train = 33544.6, Test = 32283.7, LR: 1.25e-04
        Train, ce: 33544.6, r2: -7551.9, evo: 7771.4
        Test , ce: 32283.7, r2: -7780.8, evo: 7523.1


Chunk 1/1, Epoch 64/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 64/100, Total Loss, Train = 33077.5, Test = 31741.3, LR: 1.25e-04
        Train, ce: 33077.5, r2: -7571.9, evo: 7693.1
        Test , ce: 31741.3, r2: -7798.5, evo: 7424.8


Chunk 1/1, Epoch 65/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 65/100, Total Loss, Train = 32999.2, Test = 32322.9, LR: 1.25e-04
        Train, ce: 32999.2, r2: -7564.0, evo: 7689.9
        Test , ce: 32322.9, r2: -7762.1, evo: 7605.2


Chunk 1/1, Epoch 66/100: 100%|██████████| 131/131 [00:41<00:00,  3.14it/s]


Chunk 1/1, Epoch 66/100, Total Loss, Train = 33209.1, Test = 31356.4, LR: 1.25e-04
        Train, ce: 33209.1, r2: -7541.2, evo: 7708.7
        Test , ce: 31356.4, r2: -7798.6, evo: 7451.9


Chunk 1/1, Epoch 67/100: 100%|██████████| 131/131 [00:42<00:00,  3.10it/s]


Chunk 1/1, Epoch 67/100, Total Loss, Train = 33691.5, Test = 30723.8, LR: 1.25e-04
        Train, ce: 33691.5, r2: -7532.6, evo: 7787.5
        Test , ce: 30723.8, r2: -7808.6, evo: 7357.2


Chunk 1/1, Epoch 68/100: 100%|██████████| 131/131 [00:43<00:00,  3.04it/s]


Chunk 1/1, Epoch 68/100, Total Loss, Train = 32678.0, Test = 31461.9, LR: 1.25e-04
        Train, ce: 32678.0, r2: -7582.2, evo: 7639.0
        Test , ce: 31461.9, r2: -7806.5, evo: 7453.5


Chunk 1/1, Epoch 69/100: 100%|██████████| 131/131 [00:42<00:00,  3.10it/s]


Chunk 1/1, Epoch 69/100, Total Loss, Train = 33415.8, Test = 30071.4, LR: 1.25e-04
        Train, ce: 33415.8, r2: -7522.7, evo: 7760.1
        Test , ce: 30071.4, r2: -7839.7, evo: 7221.5


Chunk 1/1, Epoch 70/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 70/100, Total Loss, Train = 33030.5, Test = 33451.7, LR: 1.25e-04
        Train, ce: 33030.5, r2: -7556.3, evo: 7701.5
        Test , ce: 33451.7, r2: -7744.0, evo: 7679.8


Chunk 1/1, Epoch 71/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 71/100, Total Loss, Train = 33194.8, Test = 31211.2, LR: 1.25e-04
        Train, ce: 33194.8, r2: -7544.6, evo: 7719.7
        Test , ce: 31211.2, r2: -7797.2, evo: 7492.6


Chunk 1/1, Epoch 72/100: 100%|██████████| 131/131 [00:43<00:00,  3.05it/s]


Chunk 1/1, Epoch 72/100, Total Loss, Train = 33246.5, Test = 30846.3, LR: 1.25e-04
        Train, ce: 33246.5, r2: -7553.5, evo: 7725.6
        Test , ce: 30846.3, r2: -7825.7, evo: 7381.9


Chunk 1/1, Epoch 73/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 73/100, Total Loss, Train = 33571.9, Test = 33594.1, LR: 1.25e-04
        Train, ce: 33571.9, r2: -7543.5, evo: 7787.6
        Test , ce: 33594.1, r2: -7733.7, evo: 7738.0


Chunk 1/1, Epoch 74/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 74/100, Total Loss, Train = 33439.4, Test = 31939.1, LR: 1.25e-04
        Train, ce: 33439.4, r2: -7555.4, evo: 7763.4
        Test , ce: 31939.1, r2: -7782.9, evo: 7534.4


Chunk 1/1, Epoch 75/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 75/100, Total Loss, Train = 33327.2, Test = 31401.9, LR: 6.25e-05
        Train, ce: 33327.2, r2: -7539.9, evo: 7746.4
        Test , ce: 31401.9, r2: -7821.9, evo: 7494.3


Chunk 1/1, Epoch 76/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 76/100, Total Loss, Train = 32506.5, Test = 31173.4, LR: 6.25e-05
        Train, ce: 32506.5, r2: -7585.8, evo: 7609.0
        Test , ce: 31173.4, r2: -7797.2, evo: 7447.3


Chunk 1/1, Epoch 77/100: 100%|██████████| 131/131 [00:42<00:00,  3.05it/s]


Chunk 1/1, Epoch 77/100, Total Loss, Train = 32810.4, Test = 32009.5, LR: 6.25e-05
        Train, ce: 32810.4, r2: -7553.5, evo: 7661.5
        Test , ce: 32009.5, r2: -7777.5, evo: 7408.2


Chunk 1/1, Epoch 78/100: 100%|██████████| 131/131 [00:42<00:00,  3.10it/s]


Chunk 1/1, Epoch 78/100, Total Loss, Train = 32676.1, Test = 29179.0, LR: 6.25e-05
        Train, ce: 32676.1, r2: -7579.2, evo: 7655.6
        Test , ce: 29179.0, r2: -7876.7, evo: 7117.6


Chunk 1/1, Epoch 79/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 79/100, Total Loss, Train = 32706.7, Test = 31547.8, LR: 6.25e-05
        Train, ce: 32706.7, r2: -7567.8, evo: 7647.7
        Test , ce: 31547.8, r2: -7815.8, evo: 7482.4


Chunk 1/1, Epoch 80/100: 100%|██████████| 131/131 [00:42<00:00,  3.10it/s]


Chunk 1/1, Epoch 80/100, Total Loss, Train = 32409.4, Test = 31407.7, LR: 6.25e-05
        Train, ce: 32409.4, r2: -7608.2, evo: 7605.9
        Test , ce: 31407.7, r2: -7780.4, evo: 7424.3


Chunk 1/1, Epoch 81/100: 100%|██████████| 131/131 [00:42<00:00,  3.05it/s]


Chunk 1/1, Epoch 81/100, Total Loss, Train = 32352.2, Test = 30379.8, LR: 6.25e-05
        Train, ce: 32352.2, r2: -7565.2, evo: 7603.3
        Test , ce: 30379.8, r2: -7818.7, evo: 7275.9


Chunk 1/1, Epoch 82/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 82/100, Total Loss, Train = 32285.4, Test = 30226.7, LR: 6.25e-05
        Train, ce: 32285.4, r2: -7585.1, evo: 7574.2
        Test , ce: 30226.7, r2: -7830.6, evo: 7190.1


Chunk 1/1, Epoch 83/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 83/100, Total Loss, Train = 32446.1, Test = 27905.1, LR: 6.25e-05
        Train, ce: 32446.1, r2: -7580.7, evo: 7613.5
        Test , ce: 27905.1, r2: -7894.7, evo: 6873.5


Chunk 1/1, Epoch 84/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 84/100, Total Loss, Train = 32827.8, Test = 32299.2, LR: 6.25e-05
        Train, ce: 32827.8, r2: -7578.1, evo: 7662.0
        Test , ce: 32299.2, r2: -7773.2, evo: 7578.1


Chunk 1/1, Epoch 85/100: 100%|██████████| 131/131 [00:42<00:00,  3.05it/s]


Chunk 1/1, Epoch 85/100, Total Loss, Train = 32238.5, Test = 33403.9, LR: 6.25e-05
        Train, ce: 32238.5, r2: -7563.6, evo: 7594.6
        Test , ce: 33403.9, r2: -7737.2, evo: 7646.2


Chunk 1/1, Epoch 86/100: 100%|██████████| 131/131 [00:41<00:00,  3.14it/s]


Chunk 1/1, Epoch 86/100, Total Loss, Train = 32854.7, Test = 33008.6, LR: 6.25e-05
        Train, ce: 32854.7, r2: -7567.7, evo: 7670.9
        Test , ce: 33008.6, r2: -7748.7, evo: 7538.2


Chunk 1/1, Epoch 87/100: 100%|██████████| 131/131 [00:41<00:00,  3.14it/s]


Chunk 1/1, Epoch 87/100, Total Loss, Train = 32237.4, Test = 32442.1, LR: 6.25e-05
        Train, ce: 32237.4, r2: -7579.5, evo: 7577.1
        Test , ce: 32442.1, r2: -7767.7, evo: 7636.0


Chunk 1/1, Epoch 88/100: 100%|██████████| 131/131 [00:41<00:00,  3.14it/s]


Chunk 1/1, Epoch 88/100, Total Loss, Train = 32939.8, Test = 31055.6, LR: 6.25e-05
        Train, ce: 32939.8, r2: -7569.0, evo: 7679.6
        Test , ce: 31055.6, r2: -7805.2, evo: 7344.2


Chunk 1/1, Epoch 89/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 89/100, Total Loss, Train = 32006.5, Test = 29926.6, LR: 3.13e-05
        Train, ce: 32006.5, r2: -7597.4, evo: 7547.9
        Test , ce: 29926.6, r2: -7860.4, evo: 7205.5


Chunk 1/1, Epoch 90/100: 100%|██████████| 131/131 [00:43<00:00,  3.05it/s]


Chunk 1/1, Epoch 90/100, Total Loss, Train = 32643.8, Test = 29333.0, LR: 3.13e-05
        Train, ce: 32643.8, r2: -7563.1, evo: 7639.0
        Test , ce: 29333.0, r2: -7844.0, evo: 7062.7


Chunk 1/1, Epoch 91/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 91/100, Total Loss, Train = 32523.3, Test = 32343.8, LR: 3.13e-05
        Train, ce: 32523.3, r2: -7582.1, evo: 7615.6
        Test , ce: 32343.8, r2: -7775.8, evo: 7553.9


Chunk 1/1, Epoch 92/100: 100%|██████████| 131/131 [00:42<00:00,  3.09it/s]


Chunk 1/1, Epoch 92/100, Total Loss, Train = 32620.8, Test = 29271.6, LR: 3.13e-05
        Train, ce: 32620.8, r2: -7585.9, evo: 7636.7
        Test , ce: 29271.6, r2: -7855.9, evo: 7090.8


Chunk 1/1, Epoch 93/100: 100%|██████████| 131/131 [00:42<00:00,  3.11it/s]


Chunk 1/1, Epoch 93/100, Total Loss, Train = 31704.5, Test = 33162.3, LR: 3.13e-05
        Train, ce: 31704.5, r2: -7603.4, evo: 7506.9
        Test , ce: 33162.3, r2: -7752.1, evo: 7754.4


Chunk 1/1, Epoch 94/100: 100%|██████████| 131/131 [00:42<00:00,  3.10it/s]


Chunk 1/1, Epoch 94/100, Total Loss, Train = 32679.5, Test = 30799.4, LR: 3.13e-05
        Train, ce: 32679.5, r2: -7558.9, evo: 7636.7
        Test , ce: 30799.4, r2: -7826.6, evo: 7377.7
Chunk 1/1, Early stopping triggered
     MAF_bin Counts test_Acc test_INFO test_IQS test_MaCH train_Acc train_INFO train_IQS train_MaCH
(0.00, 0.05)   1573    0.990     0.486    0.513     0.857     0.990      0.557     0.613      0.871
(0.05, 0.10)   1310    0.975     0.754    0.793     0.977     0.972      0.732     0.773      0.996
(0.10, 0.20)   2533    0.965     0.837    0.864     0.997     0.961      0.807     0.842      1.000
(0.20, 0.30)   3400    0.949     0.876    0.888     1.000     0.943      0.850     0.873      1.000
(0.30, 0.40)   5601    0.942     0.895    0.899     1.000     0.933      0.871     0.884      1.000
(0.40, 0.50)   1375    0.940     0.896    0.899     1.000     0.931      0.875     0.885      1.000
  --> Chunk 1 loaded best weights (test_loss=27905.074)
==> STAGE1 (Chunk M

### 1.4 Ultra-Long-Range LD Module Training

In [None]:
max_epochs_per_pair = 100
lr                 = 5e-4
weight_decay       = 1e-5
earlystop_patience = 15
batch_size         = 8
verbose            = True

criterion = ImputationLoss(use_r2=True, use_evo=True, r2_weight=1, evo_weight=4, evo_lambda=10)

# ----------- 逐个 chunk 加载权重-----------
# for cid in range(model.n_chunks):
#     chunk_file = f'{work_dir}/models/{model_name}_chunk_{cid}.pth'
#     ckpt = torch.load(chunk_file, map_location='cpu')
#     model.chunk_embeds[cid].load_state_dict(ckpt['chunk_embed_state'])
#     model.chunk_modules[cid].load_state_dict(ckpt['chunk_module_state'])

# ----------- 加载第一阶段完整权重-----------
ckpt = torch.load(f'{work_dir}/models/{model_name}_stage1.pth', map_location='cpu')
model.load_state_dict(ckpt['model_state'])

model.eval()        # chunk 专家冻结（requires_grad=False）

# 收集所有稀疏参数（主要是嵌入层）
sparse_params = []
dense_params = []
# 全局输出层的卷积参数（密集）
dense_params.extend([model.global_out.w1, model.global_out.b1])
dense_params.extend([model.global_out.w2, model.global_out.b2])
# ULR默认启用
if hasattr(model.global_out, 'ulr_mamba'):
    for name, param in model.global_out.ulr_mamba.named_parameters():
        if 'idx_embed' in name:
            sparse_params.append(param)
        else:
            dense_params.append(param)

# 创建分离优化器
optim_sparse = SparseAdam(sparse_params, lr=lr) if sparse_params else None
optim_dense = AdamW(dense_params, lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999))

# 学习率调度器
scheduler_sparse = ReduceLROnPlateau(optim_sparse, mode='min', factor=0.5, patience=5, min_lr=1e-9) if optim_sparse else None
scheduler_dense = ReduceLROnPlateau(optim_dense, mode='min', factor=0.5, patience=5, min_lr=1e-9)

pair_list = list(combinations(range(model.n_chunks), 2))
np.random.shuffle(pair_list)          # 打乱
total_pairs = len(pair_list)

for pair_idx, (cid1, cid2) in enumerate(pair_list, 1):
    # ====== 构造并集 mask ======
    union_mask = (model.chunk_masks[cid1] + model.chunk_masks[cid2]).clamp(max=1).bool()
    train_logs_sum = None
    val_logs_sum   = None
    
    # 并集 MAF
    union_maf, union_bin_cnt = precompute_maf(
        gt_enc.X_gt[:, union_mask.cpu().numpy()].toarray(),
        mask_int=gt_enc.seq_depth
    )

    # ====== 早停变量 ======
    best_loss = float('inf')
    patience_counter = 0
    is_early_stopped = False

    # ====== 训练循环 ======
    for epoch in range(max_epochs_per_pair):
        model.train()
        train_loss = 0.0
        train_prob, train_gts, train_mask = [], [], []

        pbar = tqdm(train_loader,
                    desc=f'Comb {pair_idx}/{total_pairs}  '
                         f'{cid1+1}-{cid2+1}  Epoch {epoch+1}/{max_epochs_per_pair}',
                    leave=False)
        for x, target, evo_mat in pbar:
            x,  target = x.to(device), target.to(device)
            if evo_mat.numel() == 0:
                evo_mat = None
            else:
                evo_mat = evo_mat.to(device)

            optim_sparse.zero_grad()
            optim_dense.zero_grad()

            logits, prob, mask_idx = model(x, [cid1, cid2])
            loss, logs = criterion(logits[:, mask_idx], prob[:, mask_idx], target[:,mask_idx], evo_mat) 
            loss.backward()

            optim_sparse.step()   # 只更新嵌入表
            optim_dense.step()    # 更新其余所有参数

            train_loss += loss.item()
            if train_logs_sum is None:          # 第一次初始化
                train_logs_sum = {k: 0.0 for k in logs}
            for k, v in logs.items():
                train_logs_sum[k] += v
            # pbar.set_postfix({'loss': loss.item(), 'ce':logs['ce'], 'r2':logs['r2'], 'evo':logs['evo']})

            # 收集指标
            miss_mask = x[:,union_mask][..., -1].bool()
            train_prob.append(prob[:, mask_idx].detach())
            train_gts.append(target[:,mask_idx].detach())
            train_mask.append(miss_mask)

        # 训练集 MAF
        train_prob = torch.cat(train_prob, dim=0)
        train_gts  = torch.cat(train_gts,    dim=0)
        train_mask = torch.cat(train_mask,   dim=0)

        # ----------- 验证 -----------
        model.eval()
        val_loss = 0.0
        val_prob, val_gts = [], []
        with torch.no_grad():
            if val_logs_sum is None:
                val_logs_sum = {k: 0.0 for k in train_logs_sum}
            for x, target, evo_mat in test_loader:
                x = x.to(device)
                target = target.to(device)
                evo_mat = evo_mat.to(device) if evo_mat.numel() else None
                logits, prob, mask_idx = model(x, [cid1, cid2])
                loss, logs = criterion(logits[:, mask_idx], prob[:, mask_idx], target[:,mask_idx], evo_mat)
                val_loss += loss.item()
                for k, v in logs.items():
                    val_logs_sum[k] += v
                val_prob.append(prob[:,mask_idx])
                val_gts.append(target[:,mask_idx])

        val_prob = torch.cat(val_prob, dim=0)
        val_gts  = torch.cat(val_gts,    dim=0)

        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss   = val_loss   / len(test_loader)
        avg_train_logs = {k: v / len(train_loader) for k, v in train_logs_sum.items()}
        avg_val_logs = {k: v / len(test_loader) for k, v in val_logs_sum.items()}
        
        scheduler_sparse.step(val_loss)
        scheduler_dense.step(val_loss)

        current_denselr = optim_dense.param_groups[0]['lr']
        current_sparselr = optim_sparse.param_groups[0]['lr']

        log_str = (f'Comb {pair_idx}/{total_pairs}  '
            f'{cid1+1}-{cid2+1}  Epoch {epoch+1}/{max_epochs_per_pair} '
            f'Total Loss, Train = {avg_train_loss:.1f}, '
            f'Val = {avg_val_loss:.1f}, '
            f'dense LR: {current_denselr:.2e}, '
            f'sparse LR: {current_sparselr:.2e}')
        log_str += '\n        Train'
        for k, v in avg_train_logs.items():
            log_str += f', {k}: {v:.1f}'
        log_str += '\n        Val  '
        for k, v in avg_val_logs.items():
            log_str += f', {k}: {v:.1f}'
        print(log_str)
        # 清空累加器，供下一个 epoch 使用
        train_logs_sum = {k: 0.0 for k in train_logs_sum}
        val_logs_sum   = {k: 0.0 for k in val_logs_sum}
        # 早停
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            patience_counter = 0
            torch.save({
                'comb': (cid1, cid2),
                'global_out': model.global_out.state_dict(),
                'best_val_loss': best_loss,
                'epoch': epoch,
            }, f'{work_dir}/models/{model_name}_chunk[{cid1}-{cid2}].pth')
            # MAF 表格
            predres_with_bestloss = (train_prob, train_gts, val_prob, val_gts)
            if verbose:
                train_bins_metrics = metrics_by_maf(train_prob, train_gts, gt_enc.hap_map, union_maf, mask=train_mask)
                val_bins_metrics   = metrics_by_maf(val_prob,   val_gts, gt_enc.hap_map, union_maf, mask=None)
                print_maf_stat_df(chunk_bin_cnt,
                      {"train": train_bins_metrics,
                       "val":   val_bins_metrics})
                print(f'  --> updated {model_name}_chunk[{cid1+1}-{cid2+1}].pth')
        else:
            patience_counter += 1
            if patience_counter >= earlystop_patience:
                is_early_stopped = True
                print(f'Pair {cid1+1}-{cid2+1} early stopping')
                train_prob, train_gts, val_prob, val_gts = predres_with_bestloss
                train_bins_metrics = metrics_by_maf(train_prob, train_gts, gt_enc.hap_map, union_maf, mask=train_mask)
                val_bins_metrics   = metrics_by_maf(val_prob,   val_gts, gt_enc.hap_map, union_maf, mask=None)
                print_maf_stat_df(chunk_bin_cnt,
                      {"train": train_bins_metrics,
                       "val":   val_bins_metrics})
                break
            
    if not is_early_stopped:
        predres_with_bestloss = (train_prob, train_gts, val_prob, val_gts)
        train_bins_metrics = metrics_by_maf(train_prob, train_gts, gt_enc.hap_map, union_maf, mask=train_mask)
        val_bins_metrics   = metrics_by_maf(val_prob,   val_gts, gt_enc.hap_map, union_maf, mask=None)
        print_maf_stat_df(chunk_bin_cnt,
                      {"train": train_bins_metrics,
                       "val":   val_bins_metrics})

    # del optimizer, scheduler
    torch.cuda.empty_cache()

# ----------- 全部 pair 结束 -> 保存最终模型 -----------
torch.save({
    'model_state': model.state_dict(),
    'ulr_enabled': True,
}, f'{work_dir}/models/{model_name}_stage2_final.pth')
print(f'==> STAGE2 training finished: {work_dir}/models/{model_name}_stage2_final.pth')

## 2. Augmentation (aDNA)

### 2.1 Data loading

In [11]:
gt_enc_aug = GenotypeEncoder(phased=False, gts012=False, save2disk = True, save_dir = Path(work_dir / "augment"))
gt_enc_aug = gt_enc_aug.encode_ref(
        ref_meta_json = work_dir/"pretrain"/"gt_enc_meta.json",   # 与 Stage1 同构
        default_gt    = 'miss',
        vcf_path      = "/mnt/qmtang/EvoFill_data/20251118_v4.3/AADR_chr22_15972sites.hg19.top1000_nearest_BEB_CDX_GIH.vcf.gz",
        evo_mat       = "/mnt/qmtang/EvoFill_data/20251118_v4.3/AADR_chr22_evomat.tsv")

print(f"[DATA] {gt_enc_aug.n_samples:,} Samples")
print(f"[DATA] {gt_enc_aug.n_variants:,} Variants Sites")
print(f"[DATA] {gt_enc_aug.seq_depth} seq_depth")

[DATA] 总计 15,792 个位点  
[DATA] EvoMat shape: (1000, 1000)
[DATA] 结果已写入 /mnt/qmtang/EvoFill_data/20251118_v4.3/augment
[DATA] 位点矩阵 = (1000, 15792)，稀疏度 = 80.47%，缺失率 = 0.00%
[DATA] 位点字典 = {'0|0': 0, '0|1': 1, '1|1': 2, '.|.': 3, '0/0': 0, '0/1': 1, '1/1': 2, './.': 3}，字典深度 = 4
[DATA] 1,000 Samples
[DATA] 15,792 Variants Sites
[DATA] 4 seq_depth


In [None]:
gt_enc_aug = GenotypeEncoder.loadfromdisk(Path(work_dir / "augment"))
print(f'{gt_enc_aug.n_samples:,} samples, {gt_enc_aug.n_variants:,} variants, {gt_enc_aug.seq_depth} seq-depth.')

### 2.2 Model loading

In [20]:
# %%  载入 Stage-2 最终权重
ckpt = torch.load(f'{work_dir}/models/{model_name}_stage1.pth', map_location='cpu')
model.load_state_dict(ckpt['model_state'])

criterion = ImputationLoss_Missing(use_r2=True, use_evo=False, r2_weight=1)
print('[Stage3] Stage-1 weights loaded.')

[Stage3] Stage-1 weights loaded.


### 2.3 Training

In [21]:
test_n_samples = 128
batch_size    = 16
max_mr        = 0.7
min_mr        = 0.3

x_train_indices, x_test_indices = train_test_split(
    range(gt_enc_aug.n_samples),
    test_size=test_n_samples,
    random_state=3047,
    shuffle=True
)

print(f"{len(x_train_indices):,} samples in train")
print(f"{len(x_test_indices):,} samples in test")
# print(f"evo_mat: {gt_enc.evo_mat.shape}")

train_dataset = GenomicDataset_Missing(
    gt_enc_aug,
    evo_mat=gt_enc_aug.evo_mat,
    mask=True,
    masking_rates=(min_mr, max_mr),
    indices=x_train_indices
)

test_dataset = GenomicDataset_Missing(
    gt_enc_aug,
    evo_mat=gt_enc_aug.evo_mat,
    mask=True,
    masking_rates=(min_mr, max_mr),
    indices=x_test_indices
)

def collate_fn(batch):
    x_onehot = torch.stack([item[0] for item in batch])
    y_onehot = torch.stack([item[1] for item in batch])
    real_idx_list = [item[2] for item in batch]
    # 提取 evo_mat 子矩阵
    if train_dataset.evo_mat is not None:
        evo_mat_batch = train_dataset.evo_mat[np.ix_(real_idx_list, real_idx_list)]
        evo_mat_batch = torch.FloatTensor(evo_mat_batch)
    else:
        evo_mat_batch = torch.empty(0)

    return x_onehot, y_onehot, evo_mat_batch
    
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn
)

872 samples in train
128 samples in test


In [None]:
verbose            = False
max_epochs         = 100
lr                 = 0.001
weight_decay       = 1e-5
earlystop_patience = 11
scheduler_factor   = 0.5
scheduler_patience = 5
scheduler_min_lr   = 1e-8


for cid in range(model.n_chunks):
    chunk_mask = model.chunk_masks[cid].cpu()
    chunk_maf, chunk_bin_cnt = precompute_maf(gt_enc.X_gt[:,chunk_mask.bool().cpu().numpy()].toarray(),  mask_int=gt_enc.seq_depth)
    if verbose:
        print(f"=== Chunk {cid + 1} STAT ===")
        maf_df = pd.DataFrame({
            'MAF_bin': ['(0.00, 0.05)', '(0.05, 0.10)', '(0.10, 0.20)',
                        '(0.20, 0.30)', '(0.30, 0.40)', '(0.40, 0.50)'],
            'Counts':  [f"{c}" for c in chunk_bin_cnt],
        })
        print(maf_df.to_string(index=False))

    # 收集所有稀疏参数（主要是嵌入层）
    sparse_params = []
    dense_params = []
    # 当前chunk的模块参数（密集）
    dense_params.extend(model.chunk_modules[cid].parameters())
    
    # 全局输出层的卷积参数（密集）
    dense_params.extend([model.global_out.w1, model.global_out.b1])
    dense_params.extend([model.global_out.w2, model.global_out.b2])
    # ULR默认启用
    if hasattr(model.global_out, 'ulr_mamba'):
        for name, param in model.global_out.ulr_mamba.named_parameters():
            if 'idx_embed' in name:
                sparse_params.append(param)
            else:
                dense_params.append(param)
    
    # 创建分离优化器
    optim_sparse = SparseAdam(sparse_params, lr=lr) if sparse_params else None
    optim_dense = AdamW(dense_params, lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999))
    
    # 学习率调度器
    scheduler_sparse = ReduceLROnPlateau(optim_sparse, mode='min', factor=scheduler_factor, 
                        patience=scheduler_patience, min_lr=scheduler_min_lr) if optim_sparse else None
    scheduler_dense = ReduceLROnPlateau(optim_dense, mode='min', factor=scheduler_factor, 
                        patience=scheduler_patience, min_lr=scheduler_min_lr)
    
    best_loss = float('inf')
    patience = earlystop_patience
    patience_counter = 0
    is_early_stopped = False
    train_logs_sum = None
    test_logs_sum   = None
    for epoch in range(max_epochs):
        model.train()
        train_loss = 0.0
        train_prob, train_gts, train_mask = [], [], []

        train_pbar = tqdm(train_loader, desc=f'Chunk {cid + 1}/{model.n_chunks}, Epoch {epoch + 1}/{max_epochs}',) # leave=False
        for batch_idx, (x, target, evo_mat) in enumerate(train_pbar):
            x,  target = x.to(device), target.to(device)
            if evo_mat.numel() == 0:
                evo_mat = None
            else:
                evo_mat = evo_mat.to(device)

            # 清零梯度
            if optim_sparse:
                optim_sparse.zero_grad()
            optim_dense.zero_grad()
            
            # 前向传播
            logits, prob, mask_idx = model(x, cid)
            loss, logs = criterion(logits[:, mask_idx], prob[:, mask_idx], target[:, mask_idx], evo_mat)
            
            # 反向传播
            loss.backward()
            
            # 更新参数
            if optim_sparse:
                optim_sparse.step()
            optim_dense.step()

            train_loss += loss.item()
            if train_logs_sum is None:          # 第一次初始化
                train_logs_sum = {k: 0.0 for k in logs}
            for k, v in logs.items():
                train_logs_sum[k] += v
            # train_pbar.set_postfix({'loss': loss.item(), 'ce':logs['ce'], 'r2':logs['r2'], 'evo':logs['evo']})

            # === 收集训练结果 ===
            miss_mask = x[:, mask_idx][..., -1].bool()         # 只关心被 mask 的位点
            train_prob.append(prob[:, mask_idx].detach())
            train_gts.append(target[:,mask_idx].detach())
            train_mask.append(miss_mask)

        # 训练集 MAF-acc
        train_prob = torch.cat(train_prob, dim=0)
        train_gts  = torch.cat(train_gts,    dim=0)
        train_mask = torch.cat(train_mask,   dim=0)

        # ----------- 验证循环同理 ------------
        model.eval()
        test_loss = 0.0
        test_prob, test_gts = [], []
        if test_logs_sum is None:
            test_logs_sum = {k: 0.0 for k in train_logs_sum}
        with torch.no_grad():
            for x, target, evo_mat in test_loader:
                x,  target = x.to(device), target.to(device)
                if evo_mat.numel() == 0:
                    evo_mat = None
                else:
                    evo_mat = evo_mat.to(device)
                logits, prob, mask_idx = model(x, cid)
                loss, logs = criterion(logits[:, mask_idx], prob[:, mask_idx], target[:,mask_idx], evo_mat) 
                test_loss += loss.item()
                for k, v in logs.items():
                    test_logs_sum[k] += v
                test_prob.append(prob[:, mask_idx].detach())
                test_gts.append(target[:,mask_idx].detach())

        test_prob = torch.cat(test_prob, dim=0)
        test_gts    = torch.cat(test_gts,    dim=0)
        avg_train_loss = train_loss / len(train_loader)
        avg_test_loss   = test_loss   / len(test_loader)
        avg_train_logs = {k: v / len(train_loader) for k, v in train_logs_sum.items()}
        avg_test_logs = {k: v / len(test_loader) for k, v in test_logs_sum.items()}
        
        # 更新学习率
        if scheduler_sparse:
            scheduler_sparse.step(avg_test_loss)
        scheduler_dense.step(avg_test_loss)
        
        current_denselr = optim_dense.param_groups[0]['lr']
        current_sparselr = optim_sparse.param_groups[0]['lr'] if optim_sparse else 0

        log_str = (f'Chunk {cid + 1}/{model.n_chunks}, '
            f'Epoch {epoch + 1}/{max_epochs}, '
            f'Total Loss, Train = {avg_train_loss:.1f}, '
            f'Test = {avg_test_loss:.1f}, '
            f'LR: {current_denselr:.2e}')

        log_str += '\n        Train'
        for k, v in avg_train_logs.items():
            log_str += f', {k}: {v:.1f}'
        log_str += '\n        Test '
        for k, v in avg_test_logs.items():
            log_str += f', {k}: {v:.1f}'
        print(log_str)

        # 清空累加器，供下一个 epoch 使用
        train_logs_sum = {k: 0.0 for k in train_logs_sum}
        test_logs_sum   = {k: 0.0 for k in test_logs_sum}
        
        # Early stopping
        if avg_test_loss < best_loss:
            best_loss = avg_test_loss
            patience_counter = 0
            # 只存当前 chunk 专家 + 全局层
            ckpt = {
                'chunk_id': cid,
                'chunk_embed_state': model.chunk_embeds[cid].state_dict(),
                'chunk_module_state': model.chunk_modules[cid].state_dict(),
                'global_out_state': model.global_out.state_dict(),
                'best_test_loss': best_loss,
            }
            torch.save(ckpt, f'{work_dir}/models/{model_name}_stage3_chunk[{cid}].pth')
            predres_with_bestloss = (train_prob, train_gts, test_prob, test_gts)
            if verbose:
                # train_bins_metrics = metrics_by_maf(train_prob, train_gts, hap_map = gt_enc.hap_map, maf_vec = chunk_maf, mask=train_mask)
                # test_bins_metrics   = metrics_by_maf(test_prob,  test_gts, hap_map = gt_enc.hap_map, maf_vec = chunk_maf, mask=None)
                # print_maf_stat_df(chunk_bin_cnt,
                #       {"train": train_bins_metrics,
                #        "test":  test_bins_metrics})
                print(f'  --> updated {model_name}_stage3_chunk[{cid}].pth')
        else:
            patience_counter += 1
            if patience_counter >= earlystop_patience:
                is_early_stopped = True
                print(f'Chunk {cid + 1}/{model.n_chunks}, Early stopping triggered')
                # train_prob, train_gts, test_prob, test_gts = predres_with_bestloss
                # train_bins_metrics = metrics_by_maf(train_prob, train_gts, hap_map = gt_enc.hap_map, maf_vec = chunk_maf, mask=train_mask)
                # test_bins_metrics   = metrics_by_maf(test_prob,   test_gts, hap_map = gt_enc.hap_map, maf_vec = chunk_maf, mask=None)
                # print_maf_stat_df(chunk_bin_cnt,
                #       {"train": train_bins_metrics,
                #        "test":   test_bins_metrics})
                break

    # if not is_early_stopped:
    #     train_bins_metrics = metrics_by_maf(train_prob, train_gts, hap_map = gt_enc.hap_map, maf_vec = chunk_maf, mask=train_mask)
    #     test_bins_metrics   = metrics_by_maf(test_prob,   test_gts, hap_map = gt_enc.hap_map, maf_vec = chunk_maf, mask=None)
    #     print_maf_stat_df(chunk_bin_cnt,
    #                   {"train": train_bins_metrics,
    #                    "test":   test_bins_metrics})

    best_ckpt_path = f'{work_dir}/models/{model_name}_stage3_chunk[{cid}].pth'
    best_ckpt = torch.load(best_ckpt_path, map_location='cpu')
    model.chunk_embeds[cid].load_state_dict(best_ckpt['chunk_embed_state'])
    model.chunk_modules[cid].load_state_dict(best_ckpt['chunk_module_state'])
    model.global_out.load_state_dict(best_ckpt['global_out_state'])
    print(f'  --> Chunk {cid + 1} loaded best weights (test_loss={best_ckpt["best_test_loss"]:.3f})')

    # 清理优化器
    del optim_sparse, optim_dense
    if scheduler_sparse:
        del scheduler_sparse
    del scheduler_dense
    torch.cuda.empty_cache()

# ---------------- 全部 chunk 训练完成 -> 保存完整模型 ----------------
final_ckpt = {
    'model_state': model.state_dict(),
    'n_chunks': model.n_chunks,
    'chunk_size': model.chunk_size,
    'chunk_overlap': model.chunk_overlap,
}
torch.save(final_ckpt, f'{work_dir}/models/{model_name}_stage3.pth')
print(f'==> STAGE1 (Chunk Module) training finished: {work_dir}/models/{model_name}_stage3.pth')

Chunk 1/1, Epoch 1/100: 100%|██████████| 55/55 [00:11<00:00,  4.80it/s]


Chunk 1/1, Epoch 1/100, Total Loss, Train = 190986.5, Test = 167024.9, LR: 1.00e-03
        Train, ce: 190986.5, r2: 60370.5, evo: 0.0
        Test , ce: 167024.9, r2: 30550.4, evo: 0.0


Chunk 1/1, Epoch 2/100: 100%|██████████| 55/55 [00:11<00:00,  4.82it/s]


Chunk 1/1, Epoch 2/100, Total Loss, Train = 159780.4, Test = 162852.4, LR: 1.00e-03
        Train, ce: 159780.4, r2: 41093.5, evo: 0.0
        Test , ce: 162852.4, r2: 24829.3, evo: 0.0


Chunk 1/1, Epoch 3/100: 100%|██████████| 55/55 [00:11<00:00,  4.80it/s]


Chunk 1/1, Epoch 3/100, Total Loss, Train = 156245.2, Test = 163274.3, LR: 1.00e-03
        Train, ce: 156245.2, r2: 40025.4, evo: 0.0
        Test , ce: 163274.3, r2: 22934.9, evo: 0.0


Chunk 1/1, Epoch 4/100: 100%|██████████| 55/55 [00:11<00:00,  4.81it/s]


Chunk 1/1, Epoch 4/100, Total Loss, Train = 153915.3, Test = 156155.9, LR: 1.00e-03
        Train, ce: 153915.3, r2: 40296.4, evo: 0.0
        Test , ce: 156155.9, r2: 27817.6, evo: 0.0


Chunk 1/1, Epoch 5/100: 100%|██████████| 55/55 [00:11<00:00,  4.81it/s]


Chunk 1/1, Epoch 5/100, Total Loss, Train = 151412.4, Test = 153858.9, LR: 1.00e-03
        Train, ce: 151412.4, r2: 39744.9, evo: 0.0
        Test , ce: 153858.9, r2: 27504.6, evo: 0.0


Chunk 1/1, Epoch 6/100: 100%|██████████| 55/55 [00:11<00:00,  4.82it/s]


Chunk 1/1, Epoch 6/100, Total Loss, Train = 150589.7, Test = 151266.2, LR: 1.00e-03
        Train, ce: 150589.7, r2: 39605.1, evo: 0.0
        Test , ce: 151266.2, r2: 31627.3, evo: 0.0


Chunk 1/1, Epoch 7/100: 100%|██████████| 55/55 [00:11<00:00,  4.82it/s]


Chunk 1/1, Epoch 7/100, Total Loss, Train = 149373.8, Test = 152286.2, LR: 1.00e-03
        Train, ce: 149373.8, r2: 39785.9, evo: 0.0
        Test , ce: 152286.2, r2: 29077.2, evo: 0.0


Chunk 1/1, Epoch 8/100: 100%|██████████| 55/55 [00:11<00:00,  4.80it/s]


Chunk 1/1, Epoch 8/100, Total Loss, Train = 148687.6, Test = 150371.3, LR: 1.00e-03
        Train, ce: 148687.6, r2: 39614.9, evo: 0.0
        Test , ce: 150371.3, r2: 30935.9, evo: 0.0


Chunk 1/1, Epoch 9/100: 100%|██████████| 55/55 [00:11<00:00,  4.81it/s]


Chunk 1/1, Epoch 9/100, Total Loss, Train = 148327.9, Test = 152531.7, LR: 1.00e-03
        Train, ce: 148327.9, r2: 39621.5, evo: 0.0
        Test , ce: 152531.7, r2: 30649.6, evo: 0.0


Chunk 1/1, Epoch 10/100: 100%|██████████| 55/55 [00:11<00:00,  4.80it/s]


Chunk 1/1, Epoch 10/100, Total Loss, Train = 148236.1, Test = 150289.5, LR: 1.00e-03
        Train, ce: 148236.1, r2: 39726.4, evo: 0.0
        Test , ce: 150289.5, r2: 34000.2, evo: 0.0


Chunk 1/1, Epoch 11/100: 100%|██████████| 55/55 [00:11<00:00,  4.79it/s]


Chunk 1/1, Epoch 11/100, Total Loss, Train = 147266.4, Test = 148462.7, LR: 1.00e-03
        Train, ce: 147266.4, r2: 39185.5, evo: 0.0
        Test , ce: 148462.7, r2: 31742.5, evo: 0.0


Chunk 1/1, Epoch 12/100: 100%|██████████| 55/55 [00:11<00:00,  4.80it/s]


Chunk 1/1, Epoch 12/100, Total Loss, Train = 147063.1, Test = 151408.5, LR: 1.00e-03
        Train, ce: 147063.1, r2: 39592.9, evo: 0.0
        Test , ce: 151408.5, r2: 26994.4, evo: 0.0


Chunk 1/1, Epoch 13/100: 100%|██████████| 55/55 [00:11<00:00,  4.79it/s]


Chunk 1/1, Epoch 13/100, Total Loss, Train = 147040.6, Test = 148724.5, LR: 1.00e-03
        Train, ce: 147040.6, r2: 39326.5, evo: 0.0
        Test , ce: 148724.5, r2: 36312.4, evo: 0.0


Chunk 1/1, Epoch 14/100: 100%|██████████| 55/55 [00:11<00:00,  4.79it/s]


Chunk 1/1, Epoch 14/100, Total Loss, Train = 146496.1, Test = 147788.8, LR: 1.00e-03
        Train, ce: 146496.1, r2: 39472.8, evo: 0.0
        Test , ce: 147788.8, r2: 32030.1, evo: 0.0


Chunk 1/1, Epoch 15/100: 100%|██████████| 55/55 [00:11<00:00,  4.81it/s]


Chunk 1/1, Epoch 15/100, Total Loss, Train = 146072.2, Test = 149599.6, LR: 1.00e-03
        Train, ce: 146072.2, r2: 39204.1, evo: 0.0
        Test , ce: 149599.6, r2: 28820.4, evo: 0.0


Chunk 1/1, Epoch 16/100: 100%|██████████| 55/55 [00:11<00:00,  4.80it/s]


Chunk 1/1, Epoch 16/100, Total Loss, Train = 145780.8, Test = 149371.7, LR: 1.00e-03
        Train, ce: 145780.8, r2: 39277.3, evo: 0.0
        Test , ce: 149371.7, r2: 32808.2, evo: 0.0


Chunk 1/1, Epoch 17/100: 100%|██████████| 55/55 [00:11<00:00,  4.79it/s]


Chunk 1/1, Epoch 17/100, Total Loss, Train = 145493.2, Test = 149614.1, LR: 1.00e-03
        Train, ce: 145493.2, r2: 39427.3, evo: 0.0
        Test , ce: 149614.1, r2: 26761.0, evo: 0.0


Chunk 1/1, Epoch 18/100: 100%|██████████| 55/55 [00:11<00:00,  4.79it/s]


Chunk 1/1, Epoch 18/100, Total Loss, Train = 146274.2, Test = 148207.5, LR: 1.00e-03
        Train, ce: 146274.2, r2: 39440.8, evo: 0.0
        Test , ce: 148207.5, r2: 30226.5, evo: 0.0


Chunk 1/1, Epoch 19/100: 100%|██████████| 55/55 [00:11<00:00,  4.81it/s]


Chunk 1/1, Epoch 19/100, Total Loss, Train = 145723.7, Test = 148526.8, LR: 1.00e-03
        Train, ce: 145723.7, r2: 39466.8, evo: 0.0
        Test , ce: 148526.8, r2: 28796.2, evo: 0.0


Chunk 1/1, Epoch 20/100: 100%|██████████| 55/55 [00:11<00:00,  4.78it/s]


Chunk 1/1, Epoch 20/100, Total Loss, Train = 145755.6, Test = 149820.2, LR: 5.00e-04
        Train, ce: 145755.6, r2: 39411.4, evo: 0.0
        Test , ce: 149820.2, r2: 30132.6, evo: 0.0


Chunk 1/1, Epoch 21/100: 100%|██████████| 55/55 [00:11<00:00,  4.81it/s]


Chunk 1/1, Epoch 21/100, Total Loss, Train = 145325.3, Test = 145563.5, LR: 5.00e-04
        Train, ce: 145325.3, r2: 39228.8, evo: 0.0
        Test , ce: 145563.5, r2: 32847.8, evo: 0.0


Chunk 1/1, Epoch 22/100: 100%|██████████| 55/55 [00:11<00:00,  4.81it/s]


Chunk 1/1, Epoch 22/100, Total Loss, Train = 144090.2, Test = 145922.6, LR: 5.00e-04
        Train, ce: 144090.2, r2: 38949.4, evo: 0.0
        Test , ce: 145922.6, r2: 33702.1, evo: 0.0


Chunk 1/1, Epoch 23/100: 100%|██████████| 55/55 [00:11<00:00,  4.80it/s]


Chunk 1/1, Epoch 23/100, Total Loss, Train = 144181.4, Test = 145680.7, LR: 5.00e-04
        Train, ce: 144181.4, r2: 39512.6, evo: 0.0
        Test , ce: 145680.7, r2: 32171.7, evo: 0.0


Chunk 1/1, Epoch 24/100: 100%|██████████| 55/55 [00:11<00:00,  4.80it/s]


Chunk 1/1, Epoch 24/100, Total Loss, Train = 144638.4, Test = 146413.2, LR: 5.00e-04
        Train, ce: 144638.4, r2: 39226.5, evo: 0.0
        Test , ce: 146413.2, r2: 32184.8, evo: 0.0


Chunk 1/1, Epoch 25/100: 100%|██████████| 55/55 [00:11<00:00,  4.79it/s]


Chunk 1/1, Epoch 25/100, Total Loss, Train = 144822.6, Test = 145109.5, LR: 5.00e-04
        Train, ce: 144822.6, r2: 38994.4, evo: 0.0
        Test , ce: 145109.5, r2: 35508.5, evo: 0.0


Chunk 1/1, Epoch 26/100: 100%|██████████| 55/55 [00:11<00:00,  4.80it/s]


Chunk 1/1, Epoch 26/100, Total Loss, Train = 143922.9, Test = 145727.8, LR: 5.00e-04
        Train, ce: 143922.9, r2: 39227.7, evo: 0.0
        Test , ce: 145727.8, r2: 33313.2, evo: 0.0


Chunk 1/1, Epoch 27/100: 100%|██████████| 55/55 [00:11<00:00,  4.79it/s]


Chunk 1/1, Epoch 27/100, Total Loss, Train = 144173.7, Test = 145802.3, LR: 5.00e-04
        Train, ce: 144173.7, r2: 39287.5, evo: 0.0
        Test , ce: 145802.3, r2: 33722.6, evo: 0.0


Chunk 1/1, Epoch 28/100: 100%|██████████| 55/55 [00:11<00:00,  4.80it/s]


Chunk 1/1, Epoch 28/100, Total Loss, Train = 144055.0, Test = 145036.6, LR: 5.00e-04
        Train, ce: 144055.0, r2: 39129.4, evo: 0.0
        Test , ce: 145036.6, r2: 32260.0, evo: 0.0


Chunk 1/1, Epoch 29/100: 100%|██████████| 55/55 [00:11<00:00,  4.77it/s]


Chunk 1/1, Epoch 29/100, Total Loss, Train = 144167.2, Test = 146024.6, LR: 5.00e-04
        Train, ce: 144167.2, r2: 39140.2, evo: 0.0
        Test , ce: 146024.6, r2: 33255.7, evo: 0.0


Chunk 1/1, Epoch 30/100: 100%|██████████| 55/55 [00:11<00:00,  4.80it/s]


Chunk 1/1, Epoch 30/100, Total Loss, Train = 143661.0, Test = 145554.3, LR: 5.00e-04
        Train, ce: 143661.0, r2: 39074.5, evo: 0.0
        Test , ce: 145554.3, r2: 36269.8, evo: 0.0


Chunk 1/1, Epoch 31/100: 100%|██████████| 55/55 [00:11<00:00,  4.81it/s]


Chunk 1/1, Epoch 31/100, Total Loss, Train = 143859.1, Test = 145446.0, LR: 5.00e-04
        Train, ce: 143859.1, r2: 39090.6, evo: 0.0
        Test , ce: 145446.0, r2: 39625.4, evo: 0.0


Chunk 1/1, Epoch 32/100: 100%|██████████| 55/55 [00:11<00:00,  4.80it/s]


Chunk 1/1, Epoch 32/100, Total Loss, Train = 144367.9, Test = 146265.3, LR: 5.00e-04
        Train, ce: 144367.9, r2: 39469.8, evo: 0.0
        Test , ce: 146265.3, r2: 34358.9, evo: 0.0


Chunk 1/1, Epoch 33/100: 100%|██████████| 55/55 [00:11<00:00,  4.80it/s]


Chunk 1/1, Epoch 33/100, Total Loss, Train = 143894.5, Test = 147259.5, LR: 5.00e-04
        Train, ce: 143894.5, r2: 39221.6, evo: 0.0
        Test , ce: 147259.5, r2: 31247.3, evo: 0.0


Chunk 1/1, Epoch 34/100: 100%|██████████| 55/55 [00:11<00:00,  4.81it/s]


Chunk 1/1, Epoch 34/100, Total Loss, Train = 144320.5, Test = 145375.6, LR: 2.50e-04
        Train, ce: 144320.5, r2: 39111.1, evo: 0.0
        Test , ce: 145375.6, r2: 32696.6, evo: 0.0


Chunk 1/1, Epoch 35/100: 100%|██████████| 55/55 [00:11<00:00,  4.81it/s]


Chunk 1/1, Epoch 35/100, Total Loss, Train = 143575.3, Test = 143868.1, LR: 2.50e-04
        Train, ce: 143575.3, r2: 38999.5, evo: 0.0
        Test , ce: 143868.1, r2: 36598.1, evo: 0.0


Chunk 1/1, Epoch 36/100: 100%|██████████| 55/55 [00:11<00:00,  4.80it/s]


Chunk 1/1, Epoch 36/100, Total Loss, Train = 143020.4, Test = 144089.4, LR: 2.50e-04
        Train, ce: 143020.4, r2: 39111.3, evo: 0.0
        Test , ce: 144089.4, r2: 37500.4, evo: 0.0


Chunk 1/1, Epoch 37/100: 100%|██████████| 55/55 [00:11<00:00,  4.80it/s]


Chunk 1/1, Epoch 37/100, Total Loss, Train = 143383.4, Test = 144967.6, LR: 2.50e-04
        Train, ce: 143383.4, r2: 39071.8, evo: 0.0
        Test , ce: 144967.6, r2: 36052.3, evo: 0.0


Chunk 1/1, Epoch 38/100: 100%|██████████| 55/55 [00:11<00:00,  4.80it/s]


Chunk 1/1, Epoch 38/100, Total Loss, Train = 143383.2, Test = 142912.5, LR: 2.50e-04
        Train, ce: 143383.2, r2: 39288.7, evo: 0.0
        Test , ce: 142912.5, r2: 36705.7, evo: 0.0


Chunk 1/1, Epoch 39/100: 100%|██████████| 55/55 [00:11<00:00,  4.80it/s]


Chunk 1/1, Epoch 39/100, Total Loss, Train = 142784.1, Test = 144883.6, LR: 2.50e-04
        Train, ce: 142784.1, r2: 39024.9, evo: 0.0
        Test , ce: 144883.6, r2: 35617.7, evo: 0.0


Chunk 1/1, Epoch 40/100: 100%|██████████| 55/55 [00:11<00:00,  4.79it/s]


Chunk 1/1, Epoch 40/100, Total Loss, Train = 142265.3, Test = 144190.4, LR: 2.50e-04
        Train, ce: 142265.3, r2: 39108.4, evo: 0.0
        Test , ce: 144190.4, r2: 35709.3, evo: 0.0


Chunk 1/1, Epoch 41/100: 100%|██████████| 55/55 [00:11<00:00,  4.80it/s]


Chunk 1/1, Epoch 41/100, Total Loss, Train = 143051.2, Test = 144734.2, LR: 2.50e-04
        Train, ce: 143051.2, r2: 39027.5, evo: 0.0
        Test , ce: 144734.2, r2: 38441.3, evo: 0.0


Chunk 1/1, Epoch 42/100: 100%|██████████| 55/55 [00:11<00:00,  4.79it/s]


Chunk 1/1, Epoch 42/100, Total Loss, Train = 142923.4, Test = 144159.4, LR: 2.50e-04
        Train, ce: 142923.4, r2: 39382.1, evo: 0.0
        Test , ce: 144159.4, r2: 34075.6, evo: 0.0


Chunk 1/1, Epoch 43/100: 100%|██████████| 55/55 [00:11<00:00,  4.78it/s]


Chunk 1/1, Epoch 43/100, Total Loss, Train = 142947.3, Test = 144624.6, LR: 2.50e-04
        Train, ce: 142947.3, r2: 39066.0, evo: 0.0
        Test , ce: 144624.6, r2: 37303.1, evo: 0.0


Chunk 1/1, Epoch 44/100: 100%|██████████| 55/55 [00:11<00:00,  4.79it/s]


Chunk 1/1, Epoch 44/100, Total Loss, Train = 143089.0, Test = 143889.2, LR: 1.25e-04
        Train, ce: 143089.0, r2: 39118.8, evo: 0.0
        Test , ce: 143889.2, r2: 34093.8, evo: 0.0


Chunk 1/1, Epoch 45/100: 100%|██████████| 55/55 [00:11<00:00,  4.80it/s]


Chunk 1/1, Epoch 45/100, Total Loss, Train = 142329.8, Test = 143452.9, LR: 1.25e-04
        Train, ce: 142329.8, r2: 38867.8, evo: 0.0
        Test , ce: 143452.9, r2: 37415.4, evo: 0.0


Chunk 1/1, Epoch 46/100: 100%|██████████| 55/55 [00:11<00:00,  4.82it/s]


Chunk 1/1, Epoch 46/100, Total Loss, Train = 142555.0, Test = 143354.4, LR: 1.25e-04
        Train, ce: 142555.0, r2: 39321.2, evo: 0.0
        Test , ce: 143354.4, r2: 37591.4, evo: 0.0


Chunk 1/1, Epoch 47/100: 100%|██████████| 55/55 [00:11<00:00,  4.81it/s]


Chunk 1/1, Epoch 47/100, Total Loss, Train = 142646.1, Test = 144497.9, LR: 1.25e-04
        Train, ce: 142646.1, r2: 39073.2, evo: 0.0
        Test , ce: 144497.9, r2: 35796.9, evo: 0.0


Chunk 1/1, Epoch 48/100: 100%|██████████| 55/55 [00:11<00:00,  4.80it/s]


Chunk 1/1, Epoch 48/100, Total Loss, Train = 142403.3, Test = 143236.3, LR: 1.25e-04
        Train, ce: 142403.3, r2: 39164.3, evo: 0.0
        Test , ce: 143236.3, r2: 35914.6, evo: 0.0


Chunk 1/1, Epoch 49/100: 100%|██████████| 55/55 [00:11<00:00,  4.80it/s]


Chunk 1/1, Epoch 49/100, Total Loss, Train = 142245.2, Test = 143638.8, LR: 1.25e-04
        Train, ce: 142245.2, r2: 38802.6, evo: 0.0
        Test , ce: 143638.8, r2: 39411.5, evo: 0.0
Chunk 1/1, Early stopping triggered


ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 3 is different from 4)

## 3. Fine-tuning (Few-shot URP)

In [7]:
# %%  载入 URP 微调数据
gt_enc_urp = GenotypeEncoder(phased=False, gts012=False, save2disk=False)
gt_enc_urp = gt_enc_urp.encode_ref(
        ref_meta_json = work_dir/"pre_train"/"gt_enc_meta.json",   # 与 Stage1 同构
        vcf_path      = work_dir/"urp_finetune"/"minor_pops.10pct.vcf.gz",
        evo_mat       = work_dir/"urp_finetune"/"evo_mat_minor_pops.10pct.tsv")

print(f'[URP] {gt_enc_urp.n_samples} samples, {gt_enc_urp.n_variants} variants')

[DATA] 总计 99,314 个位点  
[DATA] EvoMat shape: (16, 16)
[DATA] 位点矩阵 = (16, 99314)，稀疏度 = 27.35%，缺失率 = 0.00%
[DATA] 位点字典 = {'0|0': 0, '0|1': 1, '1|1': 2, '.|.': 3}，字典深度 = 4
[URP] 16 samples, 99314 variants


In [8]:
from sklearn.model_selection import KFold
# %%  Stage-3 超参与配置
model_name  = 'hg19_chr22trim'
stage3_tag       = 'stage3'
max_epochs       = 50
warmup_epochs    = 3
lr_dense         = 1e-4          # GlobalOut 中稠密参数
lr_sparse        = 5e-5          # ULR 中的 idx_embed
weight_decay     = 1e-4
earlystop_pat    = 9
mask_rate_range  = (0.2, 0.6)    # 数据增强：随机缺失率
k_fold           = 5             # 交叉验证
batch_size       = 4             # 样本少，用小 batch
accumulate_grad  = 2             # 梯度累加，等效 batch=8

# %%  重新建立「微调」Dataset / Loader
urp_dataset = GenomicDataset(
        gt_enc_urp.X_gt,
        evo_mat      = gt_enc_urp.evo_mat,
        seq_depth    = gt_enc_urp.seq_depth,
        mask         = True,
        masking_rates= mask_rate_range,
        indices      = None)               # 全部用于微调

def collate_fn(batch):
    x = torch.stack([b[0] for b in batch])
    y = torch.stack([b[1] for b in batch])
    idx = [b[2] for b in batch]
    if gt_enc_urp.evo_mat is not None:
        evo = gt_enc_urp.evo_mat[np.ix_(idx, idx)]
        evo = torch.FloatTensor(evo)
    else:
        evo = torch.empty(0)
    return x, y, evo

urp_loader = DataLoader(urp_dataset, batch_size=batch_size,
                        shuffle=True, num_workers=4,
                        pin_memory=True, collate_fn=collate_fn)

# 1. 准备 URP 数据
urp_idx = np.arange(gt_enc_urp.n_samples)
kf = KFold(n_splits=k_fold, shuffle=True, random_state=42)


In [9]:
# %%  载入 Stage-2 最终权重
ckpt = torch.load(f'{work_dir}/models/{model_name}_stage1.pth', map_location='cpu')
model.load_state_dict(ckpt['model_state'])
print('[Stage3] Stage-2 weights loaded.')

# %%  参数分组 & 优化器
for p in model.parameters():                # 先全部冻结
    p.requires_grad = False

# 只解冻需要的部分
trainable_dense, trainable_sparse = [], []
# 1. GlobalOut 全部
for name, p in model.global_out.named_parameters():
    if 'idx_embed' in name:
        trainable_sparse.append(p)
    else:
        trainable_dense.append(p)
# 2. Chunk-Embedding（可选，若显存紧张可留冻）
for emb in model.chunk_embeds:
    for p in emb.parameters():
        trainable_dense.append(p)

for p in trainable_dense+trainable_sparse:
    p.requires_grad = True

opt_dense  = AdamW(trainable_dense,  lr=lr_dense,
                   weight_decay=weight_decay, betas=(0.9, 0.999))
opt_sparse = SparseAdam(trainable_sparse, lr=lr_sparse)

# 余弦退火 + 热身
def lr_lambda(epoch):
    if epoch < warmup_epochs:
        return epoch / warmup_epochs
    return 0.5*(1+np.cos(np.pi*(epoch-warmup_epochs)/(max_epochs-warmup_epochs)))

sched_dense  = torch.optim.lr_scheduler.LambdaLR(opt_dense,  lr_lambda)
sched_sparse = torch.optim.lr_scheduler.LambdaLR(opt_sparse, lr_lambda)

# %%  训练循环
criterion = ImputationLoss(use_r2=True, use_evo=True,
                           r2_weight=1, evo_weight=4, evo_lambda=10)

best_avg_val_loss, patience_counter = np.inf, 0

for epoch in range(max_epochs):
    model.train()
    fold_val_loss = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(urp_idx)):
        # ---- 当前折数据 ----
        train_dataset = GenomicDataset(
                gt_enc_urp.X_gt, evo_mat=gt_enc_urp.evo_mat,
                seq_depth=gt_enc_urp.seq_depth, mask=True,
                masking_rates=(0.2, 0.6), indices=train_idx)
        val_dataset   = GenomicDataset(
                gt_enc_urp.X_gt, evo_mat=gt_enc_urp.evo_mat,
                seq_depth=gt_enc_urp.seq_depth, mask=True,
                masking_rates=(0.2, 0.6), indices=val_idx)

        train_loader = DataLoader(train_dataset, batch_size=8,
                                  shuffle=True, num_workers=2,
                                  collate_fn=collate_fn, pin_memory=True)
        val_loader   = DataLoader(val_dataset, batch_size=8,
                                  shuffle=False, num_workers=2,
                                  collate_fn=collate_fn, pin_memory=True)

        # ---- 训练 ----
        for step, (x, y, evo) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            evo  = evo.to(device) if evo.numel() else None

            logits, prob, mask_idx = model(x)
            loss, _ = criterion(logits[:, mask_idx], prob[:, mask_idx],
                                y[:, mask_idx], evo)
            loss.backward()

            if (step+1) % accumulate_grad == 0 or (step+1) == len(train_loader):
                opt_dense.step(); opt_sparse.step()
                opt_dense.zero_grad(set_to_none=True); opt_sparse.zero_grad(set_to_none=True)

        # ---- 验证 ----
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for x, y, evo in val_loader:
                x, y = x.to(device), y.to(device)
                evo  = evo.to(device) if evo.numel() else None
                logits, prob, mask_idx = model(x)
                loss, _ = criterion(logits[:, mask_idx], prob[:, mask_idx],
                                    y[:, mask_idx], evo)
                val_loss += loss.item()
        fold_val_loss.append(val_loss / len(val_loader))

    # ---- epoch 级日志 & 调度 ----
    avg_val_loss = np.mean(fold_val_loss)
    print(f'Epoch {epoch+1}: avg_val_loss={avg_val_loss:.3f}, '
          f'lr_dense={opt_dense.param_groups[0]["lr"]:.1e}')
    sched_dense.step(); sched_sparse.step()

    # ---- 早停 ----
    if avg_val_loss < best_avg_val_loss:
        best_avg_val_loss = avg_val_loss
        patience_counter = 0
        torch.save({'model_state': model.state_dict(),
                    'epoch': epoch,
                    'avg_val_loss': avg_val_loss},
                   f'{work_dir}/models/{model_name}_{stage3_tag}_best.pth')
    else:
        patience_counter += 1
        if patience_counter >= earlystop_pat:
            print('Early stopping triggered.')
            break

torch.save({'model_state': model.state_dict(),
            'stage3_tag': stage3_tag},
           f'{work_dir}/models/{model_name}_{stage3_tag}_final.pth')
print(f'==> Stage-3 KFold-loss fine-tuning finished. Best avg_val_loss={best_avg_val_loss:.3f}')

[Stage3] Stage-2 weights loaded.
Epoch 1: avg_val_loss=68964.580, lr_dense=0.0e+00
Epoch 2: avg_val_loss=60720.118, lr_dense=3.3e-05
Epoch 3: avg_val_loss=49563.400, lr_dense=6.7e-05
Epoch 4: avg_val_loss=43253.447, lr_dense=1.0e-04
Epoch 5: avg_val_loss=41547.148, lr_dense=1.0e-04
Epoch 6: avg_val_loss=34705.204, lr_dense=1.0e-04
Epoch 7: avg_val_loss=34851.017, lr_dense=9.9e-05
Epoch 8: avg_val_loss=31193.683, lr_dense=9.8e-05
Epoch 9: avg_val_loss=25390.002, lr_dense=9.7e-05
Epoch 10: avg_val_loss=32677.517, lr_dense=9.6e-05
Epoch 11: avg_val_loss=29019.098, lr_dense=9.5e-05
Epoch 12: avg_val_loss=25310.655, lr_dense=9.3e-05
Epoch 13: avg_val_loss=30964.865, lr_dense=9.1e-05
Epoch 14: avg_val_loss=22583.397, lr_dense=8.9e-05
Epoch 15: avg_val_loss=26063.507, lr_dense=8.7e-05
Epoch 16: avg_val_loss=21442.901, lr_dense=8.5e-05
Epoch 17: avg_val_loss=25226.832, lr_dense=8.2e-05
Epoch 18: avg_val_loss=21774.188, lr_dense=8.0e-05
Epoch 19: avg_val_loss=22828.542, lr_dense=7.7e-05
Epoch 2

## 4. Imputation

### 3.1 Load the trained model

Choose a path where including `<work_dir>/model` and have trained model.

In [None]:
print(f"Work Dir: {work_dir}")

# ---- 1. 加载模型 ----
device = 'cuda' if torch.cuda.is_available() else 'cpu'

json_path = f"{work_dir}/models/model_meta.json"
meta = json.load(open(json_path))
model = EvoFill(
    d_model=int(meta["d_model"]),
    n_alleles=int(meta["alleles"]),
    total_sites=int(meta["total_sites"]),
    chunk_size=int(meta["chunk_size"]),
    chunk_overlap=int(meta["overlap"])
).to(device)

# ckpt = torch.load(f'{work_dir}/models/{meta["model_name"]}_stage1.pth', map_location=device)
# ckpt = torch.load(f'{work_dir}/models/{meta["model_name"]}_stage2_best.pth', map_location=device)
ckpt = torch.load(f'{work_dir}/models/{meta["model_name"]}_stage3_best.pth', map_location=device)
model.load_state_dict(ckpt['model_state'])
model.eval()
print(f'[INF] Model[{meta["model_name"]}] loaded.')

Work Dir: /mnt/qmtang/EvoFill_data/20251107_ver4
[INF] Model[hg19_chr22trim] loaded.


### 3.2 Encode .vcf file need be impute

In [6]:
gt_enc_imp = GenotypeEncoder(phased=False, gts012=False, save2disk=False)
gt_enc_imp = gt_enc_imp.encode_ref(
        ref_meta_json = work_dir/"pre_train"/"gt_enc_meta.json",   # 与 Stage1 同构
        vcf_path      = work_dir/"impute_in"/"minor_pops.90pct.masked50p.vcf.gz" )

print(f'[INFER] {gt_enc_imp.n_samples} samples, {gt_enc_imp.n_variants} variants')

# ---- 2. 构建推理 Dataset / Loader ----
imp_dataset = ImputationDataset(
    x_gts_sparse=gt_enc_imp.X_gt,
    seq_depth=gt_enc_imp.seq_depth,
    indices=None                 # 可传入指定样本索引
)
imp_dataset.print_missing_stat()          # 查看原始缺失比例

def collate_fn(batch):
    x_onehot = torch.stack([item[0] for item in batch])
    real_idx_list = [item[1] for item in batch]
    return x_onehot, real_idx_list   # 无 y

imp_loader = torch.utils.data.DataLoader(
    imp_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn
)

[DATA] 总计 99,314 个位点  
[DATA] 位点矩阵 = (152, 99314)，稀疏度 = 63.43%，缺失率 = 50.01%
[DATA] 位点字典 = {'0|0': 0, '0|1': 1, '1|1': 2, '.|.': 3}，字典深度 = 4
[INFER] 152 samples, 99314 variants
[ImputationDataset] 152 samples, missing rate = 50.01%


### 3.3 Inferring

In [7]:
y_prob = []
y_mask = []
with torch.no_grad():
    for x_onehot, real_idx in tqdm(imp_loader, desc='Imputing'):
        x_onehot = x_onehot.to(device)
        _, prob, _ = model(x_onehot)
        miss_mask = x_onehot[..., -1].bool()
        y_prob.append(prob)
        y_mask.append(miss_mask)
y_prob = torch.cat(y_prob, dim=0).cpu().numpy()
y_mask = torch.cat(y_mask, dim=0).cpu().numpy()
# 4. 保存
out_dir = os.path.join(work_dir, 'impute_out')
os.makedirs(out_dir, exist_ok=True)
np.save(os.path.join(out_dir, 'impute_prob.npy'), y_prob)
np.save(os.path.join(out_dir, 'impute_mask.npy'), y_mask)
print(f'[INF] 概率矩阵已保存 → {out_dir}/impute_prob.npy '
      f'with shape = {y_prob.shape} ')


Imputing: 100%|██████████| 3/3 [00:20<00:00,  6.79s/it]


[INF] 概率矩阵已保存 → /mnt/qmtang/EvoFill_data/20251107_ver4/impute_out/impute_prob.npy with shape = (152, 99314, 3) 


### 3.4 Evaluating the imputation results

In [8]:
gt_enc_true = GenotypeEncoder(phased=False, gts012=False, save2disk=False)
gt_enc_true = gt_enc_true.encode_ref(
        ref_meta_json = work_dir/"pre_train"/"gt_enc_meta.json",   # 与 Stage1 同构
        vcf_path      = work_dir/"impute_out"/"minor_pops.90pct.vcf.gz" )
y_true = gt_enc_true.X_gt.toarray()
maf, bin_cnt = precompute_maf(y_true,  mask_int=gt_enc_true.seq_depth)
y_true_oh = np.eye(gt_enc_true.seq_depth - 1)[y_true]
bins_metrics   = metrics_by_maf(y_prob, y_true_oh, hap_map = gt_enc_true.hap_map, maf_vec = maf, mask=y_mask)
print_maf_stat_df(bin_cnt,{'val': bins_metrics})

[DATA] 总计 99,314 个位点  
[DATA] 位点矩阵 = (152, 99314)，稀疏度 = 26.84%，缺失率 = 0.00%
[DATA] 位点字典 = {'0|0': 0, '0|1': 1, '1|1': 2, '.|.': 3}，字典深度 = 4
     MAF_bin Counts val_Acc val_INFO val_IQS val_MaCH
(0.00, 0.05)  41365   0.976    0.379   0.082    0.939
(0.05, 0.10)   9108   0.930    0.506   0.384    0.770
(0.10, 0.20)  11962   0.898    0.614   0.544    0.836
(0.20, 0.30)  12887   0.857    0.725   0.668    0.919
(0.30, 0.40)  17081   0.844    0.774   0.720    0.962
(0.40, 0.50)   6909   0.840    0.784   0.722    0.960


### 3.5 Saving to .vcf

In [38]:
from cyvcf2 import VCF, Writer

# 0. 路径
ref_vcf = "/mnt/qmtang/EvoFill_data/20251107_ver4/minor_pops.masked30p.vcf.gz"
out_vcf = os.path.join(out_dir, 'imputed.vcf.gz')

n_site = gt_enc.n_variants
n_samp = gt_enc.n_samples
n_alleles = gt_enc.seq_depth - 1
assert y_prob.shape == (n_samp, n_site, n_alleles)

# 2. 反向映射  idx -> '0|0' / '0|1' / ...
rev_hap_map = {v: k for k, v in gt_enc.hap_map.items()}

samp2idx = {sid: i for i, sid in enumerate(gt_enc.sample_ids)}

# 4. 打开参考 VCF
invcf = VCF(ref_vcf)
tmpl  = invcf
tmpl.set_samples(gt_enc.sample_ids)   # 替换样本列

out = Writer(out_vcf, tmpl, mode='wz')

for rec_idx, rec in enumerate(invcf):
    # 当前位点全部样本的 GT
    gt_int_pairs = []
    for samp_idx, sample_id in enumerate(gt_enc.sample_ids):
        old_gt = rec.genotypes[samp_idx]          # [allele1, allele2, phased]
        if old_gt[0] == -1 or old_gt[1] == -1:    # 缺失
            prob_vec = y_prob[samp_idx, rec_idx, :].ravel()
            best_idx = int(prob_vec.argmax())
            gt_str   = rev_hap_map[best_idx]
            alleles  = list(map(int, gt_str.split('|')))
            phased   = old_gt[2] if old_gt[2] != -1 else 1
            gt_int_pairs.append([alleles[0], alleles[1], phased])
        else:                                       # 非缺失，保持原样
            gt_int_pairs.append(old_gt)

    # 转成 int8 二维数组  (n_sample, 3)  last dim = [a1,a2,phased]
    gt_array = np.array(gt_int_pairs, dtype=np.int8)
    rec.set_format('GT', gt_array)
    out.write_record(rec)

invcf.close()
out.close()

# 5. tabix
os.system(f'tabix -fp vcf {out_vcf}')
print(f'[INF] 缺失位点填充完成 → {out_vcf}')

[INF] 缺失位点填充完成 → /mnt/qmtang/EvoFill_data/20251107_ver4/impute_out/imputed.vcf.gz
