更新：

1. 交叉注意力改mamba block
2. 单数据chunk，单模型chunk

## Dependency

In [2]:
import os; os.environ["CUDA_VISIBLE_DEVICES"] = "1" # 设置用GPU1
import gzip
import json
import logging
import shutil
from typing import Union
from argparse import Namespace

import numpy as np
import pandas as pd
import datatable as dt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
from torch.utils.data import DataLoader, Dataset

from tqdm import tqdm
from joblib import Parallel, delayed
from sklearn.model_selection import train_test_split

from mamba_ssm import Mamba2
from mamba_ssm.modules.mamba2_simple import Mamba2Simple as Mamba2Block # 原Mamba2Block
from torch_optimizer import Lamb

## Data

In [3]:
SUPPORTED_FILE_FORMATS = {"vcf", "csv", "tsv"}
class DataReader:
    def __init__(self):
        self.target_is_gonna_be_phased = None
        self.target_set = None
        self.target_sample_value_index = 2
        self.ref_sample_value_index = 2
        self.target_file_extension = None
        self.allele_count = 2
        self.genotype_vals = None
        self.ref_is_phased = None
        self.reference_panel = None
        self.VARIANT_COUNT = 0
        self.is_phased = False
        self.MISSING_VALUE = None
        self.ref_is_hap = False
        self.target_is_hap = False
        self.ref_n_header_lines = []
        self.ref_n_data_header = ""
        self.target_n_header_lines = []
        self.target_n_data_header = ""
        self.ref_separator = None
        self.map_values_1_vec = np.vectorize(self.__map_hap_2_ind_parent_1)
        self.map_values_2_vec = np.vectorize(self.__map_hap_2_ind_parent_2)
        self.map_haps_to_vec = np.vectorize(self.__map_haps_2_ind)
        self.delimiter_dictionary = {"vcf": "\t", "csv": ",", "tsv": "\t", "infer": "\t"}
        self.ref_file_extension = "vcf"
        self.test_file_extension = "vcf"
        self.target_is_phased = True

    def __read_csv(self, file_path, is_vcf=False, is_reference=False, separator="\t", first_column_is_index=True,
                   comments="##") -> pd.DataFrame:
        """Read CSV/VCF files"""
        print("Reading the file...")
        data_header = None
        path_sep = "/" if "/" in file_path else os.path.sep
        line_counter = 0
        root, ext = os.path.splitext(file_path)
        with gzip.open(file_path, 'rt') if ext == '.gz' else open(file_path, 'rt') as f_in:
            while True:
                line = f_in.readline()
                if line.startswith(comments):
                    line_counter += 1
                    if is_reference:
                        self.ref_n_header_lines.append(line)
                    else:
                        self.target_n_header_lines.append(line)
                else:
                    data_header = line
                    break
        if data_header is None:
            raise IOError("The file only contains comments!")
        df = dt.fread(file=file_path, sep=separator, header=True, skip_to_line=line_counter + 1)
        df = df.to_pandas()
        if first_column_is_index:
            df.set_index(df.columns[0], inplace=True)
        return df

    def __find_file_extension(self, file_path, file_format, delimiter):
        separator = "\t"
        found_file_format = None
        if file_format not in ["infer"] + list(SUPPORTED_FILE_FORMATS):
            raise ValueError("File extension must be one of {'vcf', 'csv', 'tsv', 'infer'}.")
        if file_format == 'infer':
            file_name_tokenized = file_path.split(".")
            for possible_extension in file_name_tokenized[::-1]:
                if possible_extension in SUPPORTED_FILE_FORMATS:
                    found_file_format = possible_extension
                    separator = self.delimiter_dictionary[possible_extension] if delimiter is None else delimiter
                    break
            if found_file_format is None:
                logging.warning("Could not infer the file type. Using tsv as the last resort.")
                found_file_format = "tsv"
        else:
            found_file_format = file_format
            separator = self.delimiter_dictionary[file_format] if delimiter is None else delimiter
        return found_file_format, separator

    def assign_training_set(self, file_path: str, target_is_gonna_be_phased_or_haps: bool,
                            variants_as_columns: bool = False, delimiter=None, file_format="infer",
                            first_column_is_index=True, comments="##") -> None:
        self.target_is_gonna_be_phased = target_is_gonna_be_phased_or_haps
        self.ref_file_extension, self.ref_separator = self.__find_file_extension(file_path, file_format, delimiter)

        self.reference_panel = self.__read_csv(file_path, is_reference=True, is_vcf=False, separator=self.ref_separator,
                                               first_column_is_index=first_column_is_index,
                                               comments=comments) if self.ref_file_extension != 'vcf' else self.__read_csv(
            file_path, is_reference=True, is_vcf=True, separator='\t', first_column_is_index=False, comments="##")

        if self.ref_file_extension != "vcf":
            if variants_as_columns:
                self.reference_panel = self.reference_panel.transpose()
            self.reference_panel.reset_index(drop=False, inplace=True)
            self.reference_panel.rename(columns={self.reference_panel.columns[0]: "ID"}, inplace=True)
        else:
            self.ref_sample_value_index += 8

        self.ref_is_hap = not ("|" in self.reference_panel.iloc[0, self.ref_sample_value_index - 1] or "/" in
                               self.reference_panel.iloc[0, self.ref_sample_value_index - 1])
        self.ref_is_phased = "|" in self.reference_panel.iloc[0, self.ref_sample_value_index - 1]

        if self.ref_is_hap and not target_is_gonna_be_phased_or_haps:
            raise ValueError(
                "Reference contains haploids while target will be unphased diploids. Model cannot predict target.")

        if not (self.ref_is_phased or self.ref_is_hap) and target_is_gonna_be_phased_or_haps:
            raise ValueError(
                "Reference contains unphased diploids while target will be phased/haploid. Model cannot predict target.")

        self.VARIANT_COUNT = self.reference_panel.shape[0]
        print(
            f"{self.reference_panel.shape[1] - (self.ref_sample_value_index - 1)} {'haploid' if self.ref_is_hap else 'diploid'} samples with {self.VARIANT_COUNT} variants found!")

        self.is_phased = target_is_gonna_be_phased_or_haps and (self.ref_is_phased or self.ref_is_hap)

        original_allele_sep = "|" if self.ref_is_phased or self.ref_is_hap else "/"
        final_allele_sep = "|" if self.is_phased else "/"

        def get_diploid_alleles(genotype_vals):
            allele_set = set()
            for genotype_val in genotype_vals:
                if genotype_val not in [".", ".|.", "./."]:
                    if final_allele_sep in genotype_val:
                        v1, v2 = genotype_val.split(final_allele_sep)
                        allele_set.update([v1, v2])
                    else:
                        allele_set.add(genotype_val)  # For haploids
            return np.array(list(allele_set))

        genotype_vals = pd.unique(self.reference_panel.iloc[:, self.ref_sample_value_index - 1:].values.ravel('K'))
        print(f"DEBUG: Unique genotypes in dataset: {genotype_vals[:10]}...")  # Show first 10

        if self.ref_is_phased and not target_is_gonna_be_phased_or_haps:
            phased_to_unphased_dict = {}
            for i in range(genotype_vals.shape[0]):
                key = genotype_vals[i]
                if "|" in key and key not in [".", ".|."]:
                    v1, v2 = [int(s) for s in genotype_vals[i].split(original_allele_sep)]
                    genotype_vals[i] = f"{min(v1, v2)}/{max(v1, v2)}"
                    phased_to_unphased_dict[key] = genotype_vals[i]
            if phased_to_unphased_dict:
                self.reference_panel.iloc[:, self.ref_sample_value_index - 1:].replace(phased_to_unphased_dict,
                                                                                       inplace=True)

        self.genotype_vals = np.unique(genotype_vals)
        self.alleles = get_diploid_alleles(self.genotype_vals) if not self.ref_is_hap else self.genotype_vals
        self.allele_count = len(self.alleles)
        self.MISSING_VALUE = self.allele_count if self.is_phased else len(self.genotype_vals)

        print(f"DEBUG: self.genotype_vals: {self.genotype_vals}")
        print(f"DEBUG: self.alleles: {self.alleles}")
        print(f"DEBUG: is_phased: {self.is_phased}")

        if self.is_phased:
            self.hap_map = {str(v): i for i, v in enumerate(list(sorted(self.alleles)))}
            self.hap_map.update({".": self.MISSING_VALUE})
            self.r_hap_map = {i: k for k, i in self.hap_map.items()}
            self.map_preds_2_allele = np.vectorize(lambda x: self.r_hap_map[x])
            print(f"DEBUG: hap_map: {self.hap_map}")
        else:
            unphased_missing_genotype = "./."
            self.replacement_dict = {g: i for i, g in enumerate(list(sorted(self.genotype_vals)))}
            self.replacement_dict[unphased_missing_genotype] = self.MISSING_VALUE
            self.reverse_replacement_dict = {v: k for k, v in self.replacement_dict.items()}
            print(f"DEBUG: replacement_dict: {self.replacement_dict}")

        self.SEQ_DEPTH = self.allele_count + 1 if self.is_phased else len(self.genotype_vals) + 1
        print(f"DEBUG: self.SEQ_DEPTH: {self.SEQ_DEPTH}")

    def assign_test_set(self, file_path, variants_as_columns=False, delimiter=None,
                        file_format="infer", first_column_is_index=True, comments="##") -> None:
        """Assign test set for imputation"""
        if self.reference_panel is None:
            raise RuntimeError("First you need to use 'DataReader.assign_training_set(...) to assign a training set.'")

        self.target_file_extension, separator = self.__find_file_extension(file_path, file_format, delimiter)

        test_df = self.__read_csv(file_path, is_reference=False, is_vcf=False, separator=separator,
                                  first_column_is_index=first_column_is_index,
                                  comments=comments) if self.target_file_extension != 'vcf' else self.__read_csv(
            file_path, is_reference=False, is_vcf=True, separator='\t', first_column_is_index=False, comments="##")

        if self.target_file_extension != "vcf":
            if variants_as_columns:
                test_df = test_df.transpose()
            test_df.reset_index(drop=False, inplace=True)
            test_df.rename(columns={test_df.columns[0]: "ID"}, inplace=True)
        else:
            self.target_sample_value_index += 8

        self.target_is_hap = not ("|" in test_df.iloc[0, self.target_sample_value_index - 1] or "/" in
                                  test_df.iloc[0, self.target_sample_value_index - 1])
        is_phased = "|" in test_df.iloc[0, self.target_sample_value_index - 1]
        test_var_count = test_df.shape[0]
        print(f"{test_var_count} {'haplotype' if self.target_is_hap else 'diplotype'} variants found!")

        # Validate compatibility
        if (self.target_is_hap or is_phased) and not (self.ref_is_phased or self.ref_is_hap):
            raise RuntimeError("The training set contains unphased data. The target must be unphased as well.")
        if self.ref_is_hap and not (self.target_is_hap or is_phased):
            raise RuntimeError("The training set contains haploids. Target set should be phased or haploids.")

        # Merge with reference panel to align variants
        self.target_set = test_df.merge(right=self.reference_panel[["ID"]], on='ID', how='right')
        if self.target_file_extension == "vcf" == self.ref_file_extension:
            self.target_set[self.reference_panel.columns[:9]] = self.reference_panel[self.reference_panel.columns[:9]]

        self.target_set = self.target_set.astype('str')
        missing_value = "." if self.target_is_hap else ".|." if self.is_phased else "./."
        self.target_set.fillna(missing_value, inplace=True)
        self.target_set.replace("nan", missing_value, inplace=True)
        print("Target set assignment done!")

    def __map_hap_2_ind_parent_1(self, x) -> int:
        return self.hap_map[x.split('|')[0]]

    def __map_hap_2_ind_parent_2(self, x) -> int:
        return self.hap_map[x.split('|')[1]]

    def __map_haps_2_ind(self, x) -> int:
        return self.hap_map[x]

    def get_ref_set(self, starting_var_index=0, ending_var_index=0) -> np.ndarray:
        if 0 <= starting_var_index < ending_var_index:
            data = self.reference_panel.iloc[starting_var_index:ending_var_index, self.ref_sample_value_index - 1:]
        else:
            data = self.reference_panel.iloc[:, self.ref_sample_value_index - 1:]

        if self.is_phased:
            is_haps = "|" not in data.iloc[0, 0]
            if not is_haps:
                # diploids to hap vecs
                _x = np.empty((data.shape[1] * 2, data.shape[0]), dtype=np.int32)
                _x[0::2] = self.map_values_1_vec(data.values.T)
                _x[1::2] = self.map_values_2_vec(data.values.T)
                return _x
            else:
                return self.map_haps_to_vec(data.values.T)
        else:
            return data.replace(self.replacement_dict).values.T.astype(np.int32)

    def get_target_set(self, starting_var_index=0, ending_var_index=0) -> np.ndarray:
        """Get target data for imputation"""
        if 0 <= starting_var_index < ending_var_index:
            data = self.target_set.iloc[starting_var_index:ending_var_index, self.target_sample_value_index - 1:]
        else:
            data = self.target_set.iloc[:, self.target_sample_value_index - 1:]

        if self.is_phased:
            is_haps = "|" not in data.iloc[0, 0]
            if not is_haps:
                # diploids to hap vecs
                _x = np.empty((data.shape[1] * 2, data.shape[0]), dtype=np.int32)
                _x[0::2] = self.map_values_1_vec(data.values.T)
                _x[1::2] = self.map_values_2_vec(data.values.T)
                return _x
            else:
                return self.map_haps_to_vec(data.values.T)
        else:
            return data.replace(self.replacement_dict).values.T.astype(np.int32)

    def __convert_unphased_probs_to_genotypes(self, allele_probs) -> np.ndarray:
        """Convert unphased probabilities to genotypes"""
        n_samples, n_variants, n_alleles = allele_probs.shape
        genotypes = np.zeros((n_samples, n_variants), dtype=object)

        for i in tqdm(range(n_samples)):
            for j in range(n_variants):
                unphased_probs = allele_probs[i, j]
                variant_genotypes = np.vectorize(self.reverse_replacement_dict.get)(
                    np.argmax(unphased_probs, axis=-1)).flatten()
                genotypes[i, j] = variant_genotypes
        return genotypes

    def __convert_hap_probs_to_diploid_genotypes(self, allele_probs) -> np.ndarray:
        """Convert haplotype probabilities to diploid genotypes"""
        n_haploids, n_variants, n_alleles = allele_probs.shape

        if n_haploids % 2 != 0:
            raise ValueError("Number of haploids should be even.")

        n_samples = n_haploids // 2
        genotypes = np.empty((n_samples, n_variants), dtype=object)
        haploids_as_diploids = allele_probs.reshape((n_samples, 2, n_variants, -1))
        variant_genotypes = self.map_preds_2_allele(np.argmax(haploids_as_diploids, axis=-1))

        def process_variant_in_sample(haps_for_sample_at_variant, variant_genotypes_for_sample_at_variant):
            if n_alleles > 2:
                return '|'.join(variant_genotypes_for_sample_at_variant)
            else:
                # Output GP (genotype probabilities)
                phased_probs = np.outer(haps_for_sample_at_variant[0], haps_for_sample_at_variant[1]).flatten()
                unphased_probs = np.array([phased_probs[0], phased_probs[1] + phased_probs[2], phased_probs[-1]])
                unphased_probs_str = ",".join([f"{v:.6f}" for v in unphased_probs])
                alt_dosage = np.dot(unphased_probs, [0, 1, 2])
                return '|'.join(variant_genotypes_for_sample_at_variant) + f":{unphased_probs_str}:{alt_dosage:.3f}"

        def process_sample(i):
            return np.array([
                process_variant_in_sample(haploids_as_diploids[i, :, j, :], variant_genotypes[i, :, j])
                for j in range(n_variants)
            ])

        # Parallel processing
        genotypes = Parallel(n_jobs=-1)(delayed(process_sample)(i) for i in tqdm(range(n_samples)))
        return np.array(genotypes)

    def __convert_hap_probs_to_hap_genotypes(self, allele_probs) -> np.ndarray:
        """Convert hap probabilities to hap genotypes"""
        return np.argmax(allele_probs, axis=1).astype(str)

    def __get_headers_for_output(self, contain_probs, chr=22):
        """Get VCF headers for output file"""
        headers = [
            "##fileformat=VCFv4.2",
            '''##source=BiMamba v1.0.0''',
            '''##INFO=<ID=AF,Number=A,Type=Float,Description="Estimated Alternate Allele Frequency">''',
            '''##INFO=<ID=MAF,Number=1,Type=Float,Description="Estimated Minor Allele Frequency">''',
            '''##INFO=<ID=AVG_CS,Number=1,Type=Float,Description="Average Call Score">''',
            '''##INFO=<ID=IMPUTED,Number=0,Type=Flag,Description="Marker was imputed">''',
            '''##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">''',
        ]
        probs_headers = [
            '''##FORMAT=<ID=DS,Number=A,Type=Float,Description="Estimated Alternate Allele Dosage : [P(0/1)+2*P(1/1)]">''',
            '''##FORMAT=<ID=GP,Number=G,Type=Float,Description="Estimated Posterior Probabilities for Genotypes 0/0, 0/1 and 1/1">'''
        ]
        if contain_probs:
            headers.extend(probs_headers)
        return headers

    def __convert_genotypes_to_vcf(self, genotypes, pred_format="GT:GP:DS"):
        """Convert genotypes to VCF format"""
        new_vcf = self.target_set.copy()
        new_vcf[new_vcf.columns[self.target_sample_value_index - 1:]] = genotypes
        new_vcf["FORMAT"] = pred_format
        new_vcf["QUAL"] = "."
        new_vcf["FILTER"] = "."
        new_vcf["INFO"] = "IMPUTED"
        return new_vcf

    def preds_to_genotypes(self, predictions: Union[str, np.ndarray]) -> pd.DataFrame:
        """Convert predictions to genotypes"""
        if isinstance(predictions, str):
            preds = np.load(predictions)
        else:
            preds = predictions

        target_df = self.target_set.copy()
        if not self.is_phased:
            target_df[
                target_df.columns[self.target_sample_value_index - 1:]] = self.__convert_unphased_probs_to_genotypes(
                preds).T
        elif self.target_is_hap:
            target_df[
                target_df.columns[self.target_sample_value_index - 1:]] = self.__convert_hap_probs_to_hap_genotypes(
                preds).T
        else:
            pred_format = "GT:GP:DS" if preds.shape[-1] == 2 else "GT"
            target_df = self.__convert_genotypes_to_vcf(self.__convert_hap_probs_to_diploid_genotypes(preds).T,
                                                        pred_format)
        return target_df

    def write_ligated_results_to_file(self, df: pd.DataFrame, file_name: str, compress=True) -> str:
        """Write results to file"""
        to_write_format = self.ref_file_extension
        file_path = f"{file_name}.{to_write_format}.gz" if compress else f"{file_name}.{to_write_format}"

        with gzip.open(file_path, 'wt') if compress else open(file_path, 'wt') as f_out:
            # Write headers
            if self.ref_file_extension == "vcf":
                f_out.write(
                    "\n".join(self.__get_headers_for_output(contain_probs="GP" in df["FORMAT"].values[0])) + "\n")
            else:
                f_out.write("\n".join(self.ref_n_header_lines))

        # Append data
        df.to_csv(file_path, sep=self.ref_separator, mode='a', index=False)
        return file_path

In [4]:
class GenomicDataset(Dataset):
    """Dataset class for genomic data with masking for training"""

    def __init__(self, data, targets, seq_depth,
                 training=True, masking_rates=(0.5, 0.99)):
        self.data = data
        self.targets = targets
        self.seq_depth = seq_depth
        self.training = training
        self.masking_rates = masking_rates

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx].copy()
        y = self.targets[idx]

        if self.training:
            # Apply masking
            seq_len = len(x)
            masking_rate = np.random.uniform(*self.masking_rates)
            mask_size = int(seq_len * masking_rate)
            mask_indices = np.random.choice(seq_len, mask_size, replace=False)
            x[mask_indices] = self.seq_depth - 1  # Missing value token

        # Convert to one-hot
        x_onehot = np.eye(self.seq_depth)[x]
        y_onehot = np.eye(self.seq_depth - 1)[y]

        return torch.FloatTensor(x_onehot), torch.FloatTensor(y_onehot)

class ImputationDataset(Dataset):
    """Dataset for imputation (no masking needed)"""

    def __init__(self, data, seq_depth):
        self.data = data
        self.seq_depth = seq_depth

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]
        # Convert to one-hot without masking
        x_onehot = np.eye(self.seq_depth)[x]
        return torch.FloatTensor(x_onehot)

## Model

In [None]:
class BiMambaBlock(nn.Module):
    """Bidirectional Mamba block for genomic sequence processing"""

    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model

        # Forward and backward Mamba blocks
        self.mamba_forward = Mamba2(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand
        )

        self.mamba_backward = Mamba2(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand
        )

        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(d_model * 2, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
            nn.GELU()
        )

        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        # x shape: (batch, seq_len, d_model)
        residual = x

        # Bidirectional processing
        x_norm = self.norm1(x)

        # Forward direction
        forward_out = self.mamba_forward(x_norm)

        # Backward direction (flip sequence)
        x_backward = torch.flip(x_norm, dims=[1])
        backward_out = self.mamba_backward(x_backward)
        backward_out = torch.flip(backward_out, dims=[1])

        # Concatenate bidirectional outputs
        bi_out = torch.cat([forward_out, backward_out], dim=-1)

        # FFN
        ffn_out = self.ffn(bi_out)
        ffn_out = self.dropout(ffn_out)

        # Residual connection
        out = self.norm2(residual + ffn_out)

        return out

class ConvBlock(nn.Module):
    """Convolutional block for local pattern extraction"""

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

        self.conv1 = nn.Conv1d(d_model, d_model, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(d_model, d_model, kernel_size=5, padding=2)
        self.conv3 = nn.Conv1d(d_model, d_model, kernel_size=7, padding=3)

        self.conv_large1 = nn.Conv1d(d_model, d_model, kernel_size=7, padding=3)
        self.conv_large2 = nn.Conv1d(d_model, d_model, kernel_size=15, padding=7)

        self.conv_final = nn.Conv1d(d_model, d_model, kernel_size=3, padding=1)
        self.conv_reduce = nn.Conv1d(d_model, d_model, kernel_size=1)

        self.bn1 = nn.BatchNorm1d(d_model)
        self.bn2 = nn.BatchNorm1d(d_model)

        self.gelu = nn.GELU()

    def forward(self, x):
        # x shape: (batch, seq_len, d_model)
        x = x.transpose(1, 2)  # (batch, d_model, seq_len)

        xa = self.gelu(self.conv1(x))

        xb = self.gelu(self.conv2(xa))
        xb = self.gelu(self.conv3(xb))

        xc = self.gelu(self.conv_large1(xa))
        xc = self.gelu(self.conv_large2(xc))

        xa = xb + xc
        xa = self.gelu(self.conv_final(xa))
        xa = self.bn1(xa)
        xa = self.gelu(self.conv_reduce(xa))
        xa = self.bn2(xa)
        xa = self.gelu(xa)

        return xa.transpose(1, 2)  # (batch, seq_len, d_model)

class CrossAttentionLayer(nn.Module):
    """Cross attention for integrating local and global features"""
    def __init__(self, d_model, n_heads=8):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.cross_attention = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            batch_first=True
        )

        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, d_model),
            nn.GELU()
        )

    def forward(self, local_repr, global_repr, start_offset=0, end_offset=0):
        local_norm = self.norm1(local_repr)
        global_norm = self.norm2(global_repr)

        # Apply offsets
        if start_offset > 0 or end_offset > 0:
            query = local_norm[:, start_offset:local_norm.shape[1] - end_offset]
        else:
            query = local_norm

        key = value = global_norm

        # Cross attention
        attn_output, _ = self.cross_attention(query, key, value)

        # Pad attn_output back to original length if offsets were applied
        if start_offset > 0 or end_offset > 0:
            pad_left = start_offset
            pad_right = end_offset
            attn_output = torch.nn.functional.pad(attn_output, (0, 0, pad_left, pad_right), mode='constant', value=0)

        # Skip connection
        attn_output = attn_output + local_norm  # Changed from +query to +local_norm

        # FFN
        attn_output = self.norm3(attn_output)
        ffn_output = self.ffn(attn_output)
        output = ffn_output + attn_output

        return output

class Mamba2CrossBlock(nn.Module):
    """
    用 Mamba2Simple 替代 MultiheadAttention 的交叉块。
    接口与原来 CrossAttentionLayer 保持一致，可直接替换。
    """
    def __init__(
        self,
        d_model,
        n_heads=8,            # 仅保留接口，无实际意义
        d_state=64,
        d_conv=4,
        expand=2,
        headdim=128,
        ngroups=1,
        chunk_size=256,
        dropout=0.0,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.d_model = d_model

        # 1. 归一化
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        # 2. Mamba2Simple 作为交叉建模核心
        self.ssd = Mamba2Block(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            headdim=headdim,
            ngroups=ngroups,
            chunk_size=chunk_size,
            use_mem_eff_path=True,
            device=device,
            dtype=dtype,
        )

        # 3. FFN 保持不变
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, d_model),
        )

    def forward(self, local_repr, global_repr, start_offset=0, end_offset=0):
        """
        local_repr: [B, L, D]
        global_repr: [B, G, D]
        输出:       [B, L, D]  （长度与 local_repr 保持一致）
        """
        B, L, D = local_repr.shape

        # 1. 归一化
        local_norm  = self.norm1(local_repr)
        global_norm = self.norm2(global_repr)

        # 2. 如果用了 offset，先截断 query 长度
        if start_offset > 0 or end_offset > 0:
            query = local_norm[:, start_offset:L - end_offset]
        else:
            query = local_norm

        # 3. 拼接：global 在前，local 在后，让 SSD 扫描时 local 能看到 global
        x = torch.cat([global_norm, query], dim=1)   # [B, G + L', D]
        x = self.ssd(x)                              # [B, G + L', D]
        x = x[:, global_norm.shape[1]:, :]           # 只取 local 对应部分 [B, L', D]

        # 4. pad 回原始长度（如果之前截断过）
        if start_offset > 0 or end_offset > 0:
            x = F.pad(x, (0, 0, start_offset, end_offset))  # [B, L, D]

        # 5. 残差连接
        x = x + local_norm

        # 6. FFN
        x = self.norm3(x)
        x = self.ffn(x) + x

        return x

class GenoEmbedding(nn.Module):
    """Genomic embedding layer with positional encoding"""

    def __init__(self, n_alleles, n_snps, d_model):
        super().__init__()
        self.d_model = d_model
        self.n_alleles = n_alleles
        self.n_snps = n_snps

        # Allele embedding
        self.allele_embedding = nn.Parameter(torch.randn(n_alleles, d_model))

        # Positional embedding
        self.position_embedding = nn.Embedding(n_snps, d_model)

        # Initialize parameters
        nn.init.xavier_uniform_(self.allele_embedding)

    def forward(self, x):
        # x shape: (batch, seq_len, n_alleles) - one-hot encoded
        batch_size, seq_len, _ = x.shape

        # Allele embedding
        embedded = torch.einsum('bsn,nd->bsd', x, self.allele_embedding)

        # Positional embedding
        positions = torch.arange(seq_len, device=x.device)
        pos_emb = self.position_embedding(positions).unsqueeze(0)

        return embedded + pos_emb

class ChunkModule(nn.Module):
    """Single chunk processing module with BiMamba"""

    def __init__(self, d_model, start_offset=0, end_offset=0, dropout_rate=0.1):
        super().__init__()
        self.d_model = d_model
        self.start_offset = start_offset
        self.end_offset = end_offset

        # BiMamba block
        self.bimamba_block = BiMambaBlock(d_model)

        # Convolutional blocks
        self.conv_block1 = ConvBlock(d_model)
        self.conv_block2 = ConvBlock(d_model)
        self.conv_block3 = ConvBlock(d_model)

        # Cross attention
        # self.cross_attention = CrossAttentionLayer(d_model, n_heads)
        self.cross_attention = Mamba2CrossBlock(
            d_model=d_model,
            d_state=64,
            d_conv=4,
            expand=2,
            headdim=128,
            ngroups=1,
            chunk_size=256,
        )

        # Additional layers
        self.dense = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout_rate)
        self.gelu = nn.GELU()

    def forward(self, x):
        # BiMamba processing
        xa0 = self.bimamba_block(x)

        # First conv block
        xa = self.conv_block1(xa0)
        xa_skip = self.conv_block2(xa)

        # Dense layer
        xa = self.gelu(self.dense(xa))
        xa = self.conv_block2(xa)

        # Cross attention
        xa = self.cross_attention(xa, xa0, self.start_offset, self.end_offset)
        xa = self.dropout(xa)

        # Final conv block
        xa = self.conv_block3(xa)

        # Concatenate with skip connection
        xa = torch.cat([xa_skip, xa], dim=-1)

        return xa

class EvoFill(nn.Module):
    """Main BiMamba model for genomic imputation"""

    def __init__(self,
                 d_model,
                 chunk_size=2048,
                 chunk_overlap=64,
                 dropout_rate=0.1):
        super().__init__()
        self.d_model = d_model
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.dropout_rate = dropout_rate

        # Will be set during build
        self.seq_len = None
        self.n_alleles = None
        self.embedding = None
        self.chunk_modules = nn.ModuleList()
        self.final_conv = None
        self.output_conv = None

    def build(self, seq_len, n_alleles):
        """Build the model with specific sequence length and allele count"""
        self.seq_len = seq_len
        self.n_alleles = n_alleles

        # Embedding layer
        self.embedding = GenoEmbedding(n_alleles, seq_len, self.d_model)

        # Calculate chunks
        chunk_starts = list(range(0, seq_len, self.chunk_size))
        chunk_ends = [min(cs + self.chunk_size, seq_len) for cs in chunk_starts]
        mask_starts = [max(0, cs - self.chunk_overlap) for cs in chunk_starts]
        mask_ends = [min(ce + self.chunk_overlap, seq_len) for ce in chunk_ends]

        # Create chunk modules
        for i, cs in enumerate(chunk_starts):
            start_offset = cs - mask_starts[i]
            end_offset = mask_ends[i] - chunk_ends[i]

            chunk_module = ChunkModule(
                d_model=self.d_model,
                start_offset=start_offset,
                end_offset=end_offset,
                dropout_rate=self.dropout_rate
            )
            self.chunk_modules.append(chunk_module)

        # Store chunk information
        self.chunk_starts = chunk_starts
        self.chunk_ends = chunk_ends
        self.mask_starts = mask_starts
        self.mask_ends = mask_ends

        # Final layers
        self.final_conv = nn.Conv1d(self.d_model * 2, self.d_model // 2,
                                    kernel_size=5, padding=2)
        self.output_conv = nn.Conv1d(self.d_model // 2, n_alleles - 1,
                                     kernel_size=5, padding=2)
        # self.output_proj = nn.Linear(self.d_model * 2, n_alleles - 1)

        self.gelu = nn.GELU()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # x shape: (batch, seq_len, n_alleles)
        if self.embedding is None:
            raise RuntimeError("Model not built. Call build() first.")

        # Embedding
        x_embedded = self.embedding(x)

        # Process chunks
        chunk_outputs = []
        for i, chunk_module in enumerate(self.chunk_modules):
            chunk_input = x_embedded[:, self.mask_starts[i]:self.mask_ends[i]]
            chunk_output = chunk_module(chunk_input)
            chunk_outputs.append(chunk_output)

        # Concatenate chunks along sequence dimension
        x_concat = torch.cat(chunk_outputs, dim=1)

        # # Final processing
        x_concat = x_concat.transpose(1, 2)  # (batch, features, seq_len)
        x_final = self.gelu(self.final_conv(x_concat))
        x_output = self.output_conv(x_final)
        x_output = x_output.transpose(1, 2)  # (batch, seq_len, n_alleles-1)
        # x_output = self.output_proj(x_concat) 

        # Apply offsets
        x_output = x_output[:, :self.seq_len]

        x_output = self.softmax(x_output)

        return x_output

In [7]:
n_alleles = 4  # 包含missing
model = EvoFill(
    d_model=256,
    chunk_size=5120,
    chunk_overlap=64, 
    dropout_rate=0.1,
).cuda()

B, L = 2, 5120
model.build(seq_len=L, n_alleles=n_alleles)
model = model.cuda()  

# 1. 生成输入
x = torch.randint(0, n_alleles, (B, L)).long().cuda()   # {0,1,2,3} 3=missing

# 2. -1 -> 3，并构造 one-hot（4 维）
x_map = x.clone()
x_onehot = torch.zeros(B, L, n_alleles, device='cuda')
x_onehot.scatter_(2, x_map.unsqueeze(-1), 1)

# 3. 前向
with torch.no_grad():
    probs = model(x_onehot)          # shape: (B,L,3)

# 4. 简单校验
assert torch.allclose(probs.sum(dim=-1), torch.ones(B, L, device='cuda'), atol=1e-5), \
    "概率未归一"
print("✅ 含缺失数据前向通过，输出形状:", probs.shape)

✅ 含缺失数据前向通过，输出形状: torch.Size([2, 5120, 3])


## Loss

In [8]:
class ImputationLoss(nn.Module):
    """Custom loss function for genomic imputation"""

    def __init__(self, use_r2=True, 
                 use_focal=False, #  all dummy 
                 group_size=None,
                 gamma=None,
                 alpha=None,
                 eps=None,
                 use_gradnorm=None,
                 gn_alpha=None,
                 gn_lr_w=None,):
        super().__init__()
        self.use_r2_loss = use_r2
        self.ce_loss = nn.CrossEntropyLoss(reduction='sum')
        self.kl_loss = nn.KLDivLoss(reduction='sum')

    def calculate_minimac_r2(self, pred_alt_allele_probs, gt_alt_af):
        """Calculate Minimac-style RÂ² metric"""
        mask = torch.logical_or(torch.eq(gt_alt_af, 0.0), torch.eq(gt_alt_af, 1.0))
        gt_alt_af = torch.where(mask, 0.5, gt_alt_af)
        denom = gt_alt_af * (1.0 - gt_alt_af)
        denom = torch.where(denom < 0.01, 0.01, denom)
        r2 = torch.mean(torch.square(pred_alt_allele_probs - gt_alt_af), dim=0) / denom
        r2 = torch.where(mask, torch.zeros_like(r2), r2)
        return r2

    def forward(self, y_pred, y_true):
        y_true = y_true.float()

        # Convert to proper format for losses
        y_true_ce = torch.argmax(y_true, dim=-1)  # For CrossEntropy
        y_pred_log = torch.log(y_pred + 1e-8)  # For KL divergence

        # Basic losses
        ce_loss = self.ce_loss(y_pred.view(-1, y_pred.size(-1)), y_true_ce.view(-1))
        kl_loss = self.kl_loss(y_pred_log.view(-1, y_pred.size(-1)),
                               y_true.view(-1, y_true.size(-1)))

        total_loss = ce_loss + kl_loss

        if self.use_r2_loss:
            batch_size = y_true.size(0)
            group_size = 4
            num_full_groups = batch_size // group_size

            if num_full_groups > 0:
                y_true_grouped = y_true[:num_full_groups * group_size].view(
                    num_full_groups, group_size, *y_true.shape[1:])
                y_pred_grouped = y_pred[:num_full_groups * group_size].view(
                    num_full_groups, group_size, *y_pred.shape[1:])

                r2_loss = 0.0
                for i in range(num_full_groups):
                    gt_alt_af = torch.count_nonzero(
                        torch.argmax(y_true_grouped[i], dim=-1), dim=0
                    ).float() / group_size

                    pred_alt_allele_probs = torch.sum(y_pred_grouped[i][:, :, 1:], dim=-1)
                    r2_loss += -torch.sum(self.calculate_minimac_r2(
                        pred_alt_allele_probs, gt_alt_af)) * group_size

                total_loss += r2_loss

        return total_loss, None

## Train

In [None]:
def remove_similar_rows(array):
    """Remove duplicate haploids from training set"""
    print("Finding duplicate haploids in training set.")
    unique_array = np.unique(array, axis=0)
    print(f"Removed {len(array) - len(unique_array)} rows. {len(unique_array)} training samples remaining.")
    return unique_array

def create_directories(save_dir, models_dir="models", outputs="out") -> None:
    """Create necessary directories"""
    for dd in [save_dir, f"{save_dir}/{models_dir}", f"{save_dir}/{outputs}"]:
        if not os.path.exists(dd):
            os.makedirs(dd)

def clear_dir(path) -> None:
    """Clear directory if it exists"""
    if os.path.exists(path):
        shutil.rmtree(path)

def load_chunk_info(save_dir, break_points):
    """Load chunk training status information"""
    chunk_info = {ww: False for ww in list(range(len(break_points) - 1))}
    if os.path.isfile(f"{save_dir}/models/chunks_info.json"):
        with open(f"{save_dir}/models/chunks_info.json", 'r') as f:
            loaded_chunks_info = json.load(f)
            if isinstance(loaded_chunks_info, dict) and len(loaded_chunks_info) == len(chunk_info):
                print("Resuming the training...")
                chunk_info = {int(k): v for k, v in loaded_chunks_info.items()}
    return chunk_info

def save_chunk_status(save_dir, chunk_info) -> None:
    """Save chunk training status information"""
    with open(f"{save_dir}/models/chunks_info.json", "w") as outfile:
        json.dump(chunk_info, outfile)

def create_model(args, seq_len, n_alleles):
    """Create BiMamba model"""
    model = EvoFill(
        d_model=args.embed_dim,
        chunk_size=args.cs,
        chunk_overlap=args.co,
        dropout_rate=0.1
    )

    # Build the model
    model.build(seq_len, n_alleles)
    return model

In [10]:
MAF_BINS = [(0.00, 0.05), (0.05, 0.10), (0.10, 0.20),
            (0.20, 0.30), (0.30, 0.40), (0.40, 0.50)]

def precompute_maf(gts_np, mask_int=-1):
    """
    gts_np: (N, L)  int64
    return:
        maf: (L,) float32
        bin_cnt: list[int] 长度 6，对应 6 个 bin 的位点数量
    """
    L = gts_np.shape[1]
    maf = np.zeros(L, dtype=np.float32)
    bin_cnt = [0] * 6

    for l in range(L):
        alleles = gts_np[:, l]
        alleles = alleles[alleles != mask_int]   # 去掉缺失
        if alleles.size == 0:
            maf[l] = 0.0
            continue

        uniq, cnt = np.unique(alleles, return_counts=True)
        total = cnt.sum()
        freq = cnt / total
        freq[::-1].sort()
        maf_val = freq[1] if len(freq) > 1 else 0.0
        maf[l] = maf_val

        # 统计 bin
        for i, (lo, hi) in enumerate(MAF_BINS):
            if lo <= maf_val < hi:
                bin_cnt[i] += 1
                break

    return maf, bin_cnt

def imputation_maf_accuracy_epoch(all_logits, all_gts, global_maf, mask=None):
    """
    all_logits: (N, L, C)
    all_gts:    (N, L, C) one-hot
    global_maf: (L,)
    mask:       (N, L) 或 None
    return:     list[float] 长度 6
    """
    # 1. 预测 vs 真实
    all_gts = all_gts.argmax(dim=-1)      # (N, L)
    preds   = all_logits.argmax(dim=-1)   # (N, L)

    # 2. 如果没有外部 mask，就默认全 1
    if mask is None:
        mask = torch.ones_like(all_gts, dtype=torch.bool)   # (N, L)
    correct = (preds == all_gts) & mask                   # (N, L)

    # 3. MAF 条件 -> (1, L) 再广播到 (N, L)
    maf = global_maf.unsqueeze(0)                         # (1, L)

    # 4. 分 bin 计算
    accs = []
    for lo, hi in MAF_BINS:
        maf_bin = mask & (maf >= lo) & (maf < hi)                # (1, L)
        n_cor = (correct & maf_bin).sum()
        n_tot = maf_bin.sum()
        accs.append(100*(n_cor / n_tot).item() if n_tot > 0 else 0.0)
    return accs

In [13]:
# ---------------- 以下即命令行参数对应的行内变量 ----------------
mode                 = 'train'
restart_training     = True          # 对应命令行 1
ref                  = '/home/qmtang/mnt_qmtang/EvoFill/data/sim_s2_250901/YRI_CEU_chr22.vcf.gz'
tihp                 = True          # 对应命令行 1
save_dir             = '/home/qmtang/mnt_qmtang/EvoFill/data/mam_YRI_mamba2'
co                   = 64
cs                   = 8192
max_mr               = 0.7
min_mr               = 0.3
epochs               = 100
embed_dim            = 64
lr                   = 0.001
weight_decay         = 1e-5
batch_size_per_gpu   = 8
use_r2               = True
use_focal            = True
earlystop_patience   = 9
verbose              = 1
# -------------------------------------------------------------

# 组装成 Namespace
args = Namespace(
    restart_training=restart_training,
    ref=ref,
    tihp=tihp,
    save_dir=save_dir,
    co=co,
    cs=cs,
    max_mr=max_mr,
    min_mr=min_mr,
    epochs=epochs,
    embed_dim=embed_dim,
    lr=lr,
    weight_decay=weight_decay,
    batch_size_per_gpu=batch_size_per_gpu,
    use_r2=use_r2,
    earlystop_patience=earlystop_patience,
    verbose=verbose,
)

assert args.max_mr > 0
assert args.min_mr > 0
assert args.max_mr >= args.min_mr

# Setup device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create directories
create_directories(args.save_dir)

# Load data
dr = DataReader()
dr.assign_training_set(
    file_path=args.ref,
    target_is_gonna_be_phased_or_haps=args.tihp,
    variants_as_columns=getattr(args, 'ref_vac', False),
    delimiter=getattr(args, 'ref_sep', None),
    file_format=getattr(args, 'ref_file_format', 'infer'),
    first_column_is_index=getattr(args, 'ref_fcai', True),
    comments=getattr(args, 'ref_comment', '##')
)

# Split data for validation
n_samples = dr.get_ref_set(0, 1).shape[0]
val_n_samples = args.batch_size_per_gpu * getattr(args, 'val_n_batches', 8)
x_train_indices, x_valid_indices = train_test_split(
    range(n_samples),
    test_size=val_n_samples,
    random_state=getattr(args, 'random_seed', 2022),
    shuffle=True
)

args.sites_per_model = dr.VARIANT_COUNT

# Get data for this chunk
ref_set = dr.get_ref_set(0, dr.VARIANT_COUNT).astype(np.int32)
print(f"Data shape: {ref_set.shape}")

# Remove duplicates from training
ref_set_train = remove_similar_rows(ref_set[x_train_indices])
ref_set_val = ref_set[x_valid_indices]

# MAF bins counts
chunk_maf, chunk_bin_cnt = precompute_maf(
    ref_set_train, 
    mask_int=-1
)
chunk_maf = torch.from_numpy(chunk_maf).to(device)          # (L_chunk,)
if args.verbose:
    print('MAF-bin counts:', chunk_bin_cnt)

# Create targets (same as input for reconstruction)
target_train = ref_set_train
target_val = ref_set_val

# Create datasets
train_dataset = GenomicDataset(
    ref_set_train, target_train, dr.SEQ_DEPTH,
    training=True,
    masking_rates=(args.min_mr, args.max_mr)
)

val_dataset = GenomicDataset(
    ref_set_val, target_val, dr.SEQ_DEPTH,
    training=False,
    masking_rates=(args.min_mr, args.max_mr)
)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=args.batch_size_per_gpu,
                            shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size_per_gpu,
                        shuffle=False, num_workers=4, pin_memory=True)

# Create model
seq_len = ref_set.shape[1]
model = create_model(args, seq_len, dr.SEQ_DEPTH)
model.to(device)

# Loss and optimizer
criterion = ImputationLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# optimizer = Lamb(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3, min_lr=1e-7
)



Using device: cuda
Reading the file...
2000 diploid samples with 56145 variants found!
DEBUG: Unique genotypes in dataset: ['0|0' '0|1' '1|0' '1|1' '2|0' '2|1' '2|2' '1|2' '0|2']...
DEBUG: self.genotype_vals: ['0|0' '0|1' '0|2' '1|0' '1|1' '1|2' '2|0' '2|1' '2|2']
DEBUG: self.alleles: ['2' '0' '1']
DEBUG: is_phased: True
DEBUG: hap_map: {'0': 0, '1': 1, '2': 2, '.': 3}
DEBUG: self.SEQ_DEPTH: 4
Data shape: (4000, 56145)
Finding duplicate haploids in training set.
Removed 307 rows. 3629 training samples remaining.
MAF-bin counts: [50330, 1640, 1518, 1008, 816, 833]


In [None]:
# Training loop
best_loss = float('inf')
patience = args.earlystop_patience
patience_counter = 0

for epoch in range(args.epochs):
    model.train()
    train_loss = 0.0
    train_logits, train_gts, train_mask = [], [], []

    train_pbar = tqdm(train_loader, desc=f'Epoch {epoch + 1}/{args.epochs}', leave=False)
    for batch_idx, (data, target) in enumerate(train_pbar):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss, logs = criterion(output, target)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_pbar.set_postfix({'loss': loss.item()})

        # === 收集训练结果 ===
        mask = data[..., -1].bool()         # 只关心被 mask 的位点
        train_logits.append(output.detach())
        train_gts.append(target.detach())
        train_mask.append(mask)

    # 训练集 MAF-acc
    train_logits = torch.cat(train_logits, dim=0)
    train_gts    = torch.cat(train_gts,    dim=0)
    train_mask   = torch.cat(train_mask,   dim=0)
    # 只保留有效位点（去掉 offset  padding）
    train_mask   = train_mask  [:, : train_mask.shape[1]]
    train_maf_accs = imputation_maf_accuracy_epoch(train_logits, train_gts, chunk_maf, mask=train_mask)

    # ----------- 验证循环同理 ------------
    model.eval()
    val_loss = 0.0
    val_logits, val_gts = [], []
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss, logs = criterion(output, target)
            val_loss += loss.item()

            mask = data[..., -1].bool()
            val_logits.append(output)
            val_gts.append(target)

    val_logits = torch.cat(val_logits, dim=0)
    val_gts    = torch.cat(val_gts,    dim=0)
    val_maf_accs = imputation_maf_accuracy_epoch(
        val_logits, val_gts, chunk_maf,  mask=None,)

    avg_train_loss = train_loss / len(train_loader)
    avg_val_loss   = val_loss   / len(val_loader)

    if args.verbose >= 1:
        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss   = val_loss   / len(val_loader)
        print(f'Epoch {epoch + 1}/{args.epochs}, '
              f'Train Loss: {avg_train_loss:.1f}, Val Loss: {avg_val_loss:.1f}')

        # 用 DataFrame 打印 MAF-bin 结果
        maf_df = pd.DataFrame({
            'MAF_bin': ['(0.00, 0.05)', '(0.05, 0.10)', '(0.10, 0.20)',
                        '(0.20, 0.30)', '(0.30, 0.40)', '(0.40, 0.50)'],
            'Counts':  [f"{c}" for c in chunk_bin_cnt],
            'Train':   [f"{acc:.2f}" for acc in train_maf_accs],
            'Val':     [f"{acc:.2f}" for acc in val_maf_accs]
        })
        print(maf_df.to_string(index=False))

    scheduler.step(avg_val_loss)

    # Early stopping
    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        patience_counter = 0
        # Save best model
        import os
        os.makedirs(f'{args.save_dir}/models', exist_ok=True)
        torch.save(model.state_dict(), f'{args.save_dir}/models/best.pth')
    else:
        patience_counter += 1
        if patience_counter >= patience:
            if args.verbose >= 1:
                print('Early stopping triggered')
            break

os.makedirs(f'{args.save_dir}/models', exist_ok=True)
torch.save(model.state_dict(), f'{args.save_dir}/models/final.pth')

                                                                            

Epoch 1/100, Train Loss: 271125.2892, Val Loss: 265617.1875
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.66 99.72
(0.05, 0.10)   1640 93.12 94.52
(0.10, 0.20)   1518 86.45 88.61
(0.20, 0.30)   1008 77.29 80.90
(0.30, 0.40)    816 69.08 75.72
(0.40, 0.50)    833 62.35 71.91


                                                                            

Epoch 2/100, Train Loss: 266508.6040, Val Loss: 261718.9590
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.66 99.73
(0.05, 0.10)   1640 93.44 94.94
(0.10, 0.20)   1518 87.21 88.97
(0.20, 0.30)   1008 78.69 81.61
(0.30, 0.40)    816 72.37 76.98
(0.40, 0.50)    833 67.82 74.06


                                                                            

Epoch 3/100, Train Loss: 264465.3053, Val Loss: 259382.3398
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.66 99.73
(0.05, 0.10)   1640 93.58 95.03
(0.10, 0.20)   1518 87.53 89.34
(0.20, 0.30)   1008 79.30 82.00
(0.30, 0.40)    816 73.48 78.21
(0.40, 0.50)    833 69.89 75.68


                                                                            

Epoch 4/100, Train Loss: 263145.1474, Val Loss: 257953.2578
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.67 99.74
(0.05, 0.10)   1640 93.77 95.04
(0.10, 0.20)   1518 87.81 89.64
(0.20, 0.30)   1008 79.77 82.25
(0.30, 0.40)    816 74.20 78.69
(0.40, 0.50)    833 70.90 76.77


                                                                            

Epoch 5/100, Train Loss: 262259.5344, Val Loss: 256868.1973
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.67 99.74
(0.05, 0.10)   1640 93.80 95.17
(0.10, 0.20)   1518 87.94 89.59
(0.20, 0.30)   1008 79.94 82.57
(0.30, 0.40)    816 74.57 78.82
(0.40, 0.50)    833 71.91 77.89


                                                                            

Epoch 6/100, Train Loss: 261264.5605, Val Loss: 255886.5293
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.67 99.74
(0.05, 0.10)   1640 93.88 95.17
(0.10, 0.20)   1518 88.08 89.80
(0.20, 0.30)   1008 80.14 82.96
(0.30, 0.40)    816 75.41 79.51
(0.40, 0.50)    833 73.27 78.57


                                                                            

Epoch 7/100, Train Loss: 259595.6375, Val Loss: 253079.4238
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.68 99.75
(0.05, 0.10)   1640 93.96 95.35
(0.10, 0.20)   1518 88.23 89.89
(0.20, 0.30)   1008 80.78 84.65
(0.30, 0.40)    816 76.48 81.45
(0.40, 0.50)    833 75.26 81.11


                                                                            

Epoch 8/100, Train Loss: 256936.0507, Val Loss: 250140.4629
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.68 99.75
(0.05, 0.10)   1640 94.10 95.63
(0.10, 0.20)   1518 88.80 91.13
(0.20, 0.30)   1008 82.55 86.27
(0.30, 0.40)    816 78.37 83.80
(0.40, 0.50)    833 77.44 82.97


                                                                            

Epoch 9/100, Train Loss: 254903.6298, Val Loss: 248511.8516
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.68 99.76
(0.05, 0.10)   1640 94.30 95.80
(0.10, 0.20)   1518 89.26 91.58
(0.20, 0.30)   1008 83.58 87.36
(0.30, 0.40)    816 79.83 84.89
(0.40, 0.50)    833 78.82 83.97


                                                                             

Epoch 10/100, Train Loss: 253669.5660, Val Loss: 247223.1406
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.69 99.76
(0.05, 0.10)   1640 94.45 95.96
(0.10, 0.20)   1518 89.64 91.83
(0.20, 0.30)   1008 84.24 87.61
(0.30, 0.40)    816 80.67 85.74
(0.40, 0.50)    833 79.51 84.68


                                                                             

Epoch 11/100, Train Loss: 252817.6837, Val Loss: 246733.2285
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.69 99.76
(0.05, 0.10)   1640 94.56 96.11
(0.10, 0.20)   1518 89.77 92.11
(0.20, 0.30)   1008 84.52 87.84
(0.30, 0.40)    816 81.07 85.96
(0.40, 0.50)    833 79.95 84.99


                                                                             

Epoch 12/100, Train Loss: 251971.9122, Val Loss: 245874.2012
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.69 99.77
(0.05, 0.10)   1640 94.68 96.21
(0.10, 0.20)   1518 90.03 92.25
(0.20, 0.30)   1008 84.89 88.35
(0.30, 0.40)    816 81.65 86.33
(0.40, 0.50)    833 80.49 85.44


                                                                             

Epoch 13/100, Train Loss: 251368.4075, Val Loss: 245261.7656
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.70 99.77
(0.05, 0.10)   1640 94.75 96.33
(0.10, 0.20)   1518 90.14 92.58
(0.20, 0.30)   1008 85.15 88.66
(0.30, 0.40)    816 82.03 86.74
(0.40, 0.50)    833 80.89 85.60


                                                                             

Epoch 14/100, Train Loss: 250820.5390, Val Loss: 244844.7852
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.70 99.78
(0.05, 0.10)   1640 94.85 96.35
(0.10, 0.20)   1518 90.37 92.64
(0.20, 0.30)   1008 85.41 88.68
(0.30, 0.40)    816 82.30 86.81
(0.40, 0.50)    833 81.12 85.80


                                                                             

Epoch 15/100, Train Loss: 250431.3028, Val Loss: 244527.2246
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.70 99.78
(0.05, 0.10)   1640 94.91 96.38
(0.10, 0.20)   1518 90.39 92.64
(0.20, 0.30)   1008 85.57 88.85
(0.30, 0.40)    816 82.56 87.05
(0.40, 0.50)    833 81.41 86.05


                                                                             

Epoch 16/100, Train Loss: 250038.2895, Val Loss: 244301.5781
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.70 99.77
(0.05, 0.10)   1640 94.93 96.38
(0.10, 0.20)   1518 90.45 92.65
(0.20, 0.30)   1008 85.66 88.94
(0.30, 0.40)    816 82.71 87.26
(0.40, 0.50)    833 81.57 86.20


                                                                             

Epoch 17/100, Train Loss: 249609.0867, Val Loss: 243965.3633
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.71 99.78
(0.05, 0.10)   1640 94.99 96.44
(0.10, 0.20)   1518 90.56 92.75
(0.20, 0.30)   1008 85.92 89.04
(0.30, 0.40)    816 83.09 87.25
(0.40, 0.50)    833 81.84 86.36


                                                                             

Epoch 18/100, Train Loss: 249305.2038, Val Loss: 243493.6270
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.71 99.79
(0.05, 0.10)   1640 95.04 96.53
(0.10, 0.20)   1518 90.65 92.88
(0.20, 0.30)   1008 86.03 89.36
(0.30, 0.40)    816 83.20 87.53
(0.40, 0.50)    833 81.98 86.50


                                                                             

Epoch 19/100, Train Loss: 249107.1645, Val Loss: 243334.1406
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.71 99.79
(0.05, 0.10)   1640 95.06 96.56
(0.10, 0.20)   1518 90.66 92.95
(0.20, 0.30)   1008 86.05 89.30
(0.30, 0.40)    816 83.29 87.89
(0.40, 0.50)    833 82.15 86.82


                                                                             

Epoch 20/100, Train Loss: 248779.7616, Val Loss: 242843.5039
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.71 99.79
(0.05, 0.10)   1640 95.10 96.55
(0.10, 0.20)   1518 90.79 93.03
(0.20, 0.30)   1008 86.24 89.48
(0.30, 0.40)    816 83.51 88.03
(0.40, 0.50)    833 82.35 87.01


                                                                             

Epoch 21/100, Train Loss: 248442.2996, Val Loss: 242729.6641
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.71 99.79
(0.05, 0.10)   1640 95.18 96.61
(0.10, 0.20)   1518 90.88 93.05
(0.20, 0.30)   1008 86.40 89.57
(0.30, 0.40)    816 83.77 87.99
(0.40, 0.50)    833 82.59 86.92


                                                                             

Epoch 22/100, Train Loss: 248138.7041, Val Loss: 242432.0078
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.71 99.79
(0.05, 0.10)   1640 95.16 96.66
(0.10, 0.20)   1518 90.93 93.05
(0.20, 0.30)   1008 86.46 89.56
(0.30, 0.40)    816 83.93 88.09
(0.40, 0.50)    833 82.82 87.30


                                                                             

Epoch 23/100, Train Loss: 247745.9491, Val Loss: 242055.6680
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.72 99.79
(0.05, 0.10)   1640 95.24 96.72
(0.10, 0.20)   1518 91.02 93.23
(0.20, 0.30)   1008 86.60 89.64
(0.30, 0.40)    816 84.09 88.47
(0.40, 0.50)    833 83.12 87.33


                                                                             

Epoch 24/100, Train Loss: 247296.9679, Val Loss: 241117.2930
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.72 99.79
(0.05, 0.10)   1640 95.28 96.73
(0.10, 0.20)   1518 91.08 93.35
(0.20, 0.30)   1008 86.85 90.38
(0.30, 0.40)    816 84.57 89.06
(0.40, 0.50)    833 83.63 88.39


                                                                             

Epoch 25/100, Train Loss: 245762.5667, Val Loss: 239707.7949
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.72 99.79
(0.05, 0.10)   1640 95.37 96.84
(0.10, 0.20)   1518 91.35 93.84
(0.20, 0.30)   1008 87.68 91.53
(0.30, 0.40)    816 85.60 90.09
(0.40, 0.50)    833 84.91 89.53


                                                                             

Epoch 26/100, Train Loss: 244688.7644, Val Loss: 238522.3203
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.72 99.80
(0.05, 0.10)   1640 95.42 96.97
(0.10, 0.20)   1518 91.62 94.26
(0.20, 0.30)   1008 88.44 92.16
(0.30, 0.40)    816 86.36 90.72
(0.40, 0.50)    833 85.65 90.16


                                                                             

Epoch 27/100, Train Loss: 243739.5354, Val Loss: 237615.5527
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.73 99.80
(0.05, 0.10)   1640 95.56 97.09
(0.10, 0.20)   1518 91.94 94.50
(0.20, 0.30)   1008 88.82 92.63
(0.30, 0.40)    816 86.77 90.99
(0.40, 0.50)    833 86.25 90.55


                                                                             

Epoch 28/100, Train Loss: 243178.8515, Val Loss: 237125.5137
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.73 99.81
(0.05, 0.10)   1640 95.66 97.26
(0.10, 0.20)   1518 92.11 94.78
(0.20, 0.30)   1008 89.11 92.92
(0.30, 0.40)    816 86.91 91.14
(0.40, 0.50)    833 86.49 90.54


                                                                             

Epoch 29/100, Train Loss: 242560.9877, Val Loss: 236664.7578
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.73 99.81
(0.05, 0.10)   1640 95.72 97.31
(0.10, 0.20)   1518 92.31 94.82
(0.20, 0.30)   1008 89.37 93.00
(0.30, 0.40)    816 87.15 91.42
(0.40, 0.50)    833 86.67 90.96


                                                                             

Epoch 30/100, Train Loss: 241994.7000, Val Loss: 236038.7910
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.74 99.82
(0.05, 0.10)   1640 95.84 97.40
(0.10, 0.20)   1518 92.51 95.05
(0.20, 0.30)   1008 89.64 93.28
(0.30, 0.40)    816 87.45 91.62
(0.40, 0.50)    833 86.93 91.03


                                                                             

Epoch 31/100, Train Loss: 241710.0214, Val Loss: 235645.2637
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.74 99.82
(0.05, 0.10)   1640 95.91 97.45
(0.10, 0.20)   1518 92.59 95.20
(0.20, 0.30)   1008 89.78 93.42
(0.30, 0.40)    816 87.57 91.82
(0.40, 0.50)    833 87.13 91.27


                                                                             

Epoch 32/100, Train Loss: 241278.3359, Val Loss: 235708.6582
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.74 99.82
(0.05, 0.10)   1640 95.94 97.45
(0.10, 0.20)   1518 92.64 95.26
(0.20, 0.30)   1008 90.01 93.46
(0.30, 0.40)    816 87.77 91.71
(0.40, 0.50)    833 87.29 91.21


                                                                             

Epoch 33/100, Train Loss: 241051.6144, Val Loss: 235128.6836
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.74 99.83
(0.05, 0.10)   1640 96.04 97.57
(0.10, 0.20)   1518 92.77 95.33
(0.20, 0.30)   1008 90.09 93.60
(0.30, 0.40)    816 87.89 91.93
(0.40, 0.50)    833 87.39 91.55


                                                                             

Epoch 34/100, Train Loss: 240665.3759, Val Loss: 235078.0312
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.75 99.83
(0.05, 0.10)   1640 96.06 97.58
(0.10, 0.20)   1518 92.87 95.34
(0.20, 0.30)   1008 90.25 93.58
(0.30, 0.40)    816 87.95 91.97
(0.40, 0.50)    833 87.55 91.49


                                                                             

Epoch 35/100, Train Loss: 240310.8680, Val Loss: 234658.1660
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.75 99.83
(0.05, 0.10)   1640 96.15 97.64
(0.10, 0.20)   1518 92.95 95.50
(0.20, 0.30)   1008 90.39 93.84
(0.30, 0.40)    816 88.13 92.19
(0.40, 0.50)    833 87.66 91.58


                                                                             

Epoch 36/100, Train Loss: 240215.8564, Val Loss: 234625.9844
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.75 99.84
(0.05, 0.10)   1640 96.20 97.63
(0.10, 0.20)   1518 93.04 95.48
(0.20, 0.30)   1008 90.50 93.73
(0.30, 0.40)    816 88.24 92.13
(0.40, 0.50)    833 87.87 91.66


                                                                             

Epoch 37/100, Train Loss: 239897.0538, Val Loss: 234521.5059
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.75 99.83
(0.05, 0.10)   1640 96.20 97.63
(0.10, 0.20)   1518 93.10 95.49
(0.20, 0.30)   1008 90.57 93.90
(0.30, 0.40)    816 88.34 92.30
(0.40, 0.50)    833 87.91 91.84


                                                                             

Epoch 38/100, Train Loss: 239721.9153, Val Loss: 234482.9902
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.75 99.84
(0.05, 0.10)   1640 96.22 97.62
(0.10, 0.20)   1518 93.11 95.41
(0.20, 0.30)   1008 90.66 93.90
(0.30, 0.40)    816 88.36 92.27
(0.40, 0.50)    833 88.02 91.75


                                                                             

Epoch 39/100, Train Loss: 239613.9912, Val Loss: 233950.3594
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.76 99.84
(0.05, 0.10)   1640 96.26 97.69
(0.10, 0.20)   1518 93.19 95.63
(0.20, 0.30)   1008 90.71 94.08
(0.30, 0.40)    816 88.47 92.41
(0.40, 0.50)    833 88.08 92.03


                                                                             

Epoch 40/100, Train Loss: 239349.4492, Val Loss: 233809.6582
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.76 99.85
(0.05, 0.10)   1640 96.32 97.76
(0.10, 0.20)   1518 93.23 95.66
(0.20, 0.30)   1008 90.82 94.16
(0.30, 0.40)    816 88.58 92.71
(0.40, 0.50)    833 88.17 92.23


                                                                             

Epoch 41/100, Train Loss: 239101.6829, Val Loss: 233595.4551
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.76 99.84
(0.05, 0.10)   1640 96.37 97.74
(0.10, 0.20)   1518 93.32 95.67
(0.20, 0.30)   1008 90.97 94.18
(0.30, 0.40)    816 88.76 92.67
(0.40, 0.50)    833 88.46 92.18


                                                                             

Epoch 42/100, Train Loss: 238858.7378, Val Loss: 233509.8203
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.76 99.85
(0.05, 0.10)   1640 96.36 97.86
(0.10, 0.20)   1518 93.38 95.70
(0.20, 0.30)   1008 91.06 94.27
(0.30, 0.40)    816 88.87 92.68
(0.40, 0.50)    833 88.50 92.23


                                                                             

Epoch 43/100, Train Loss: 238794.3227, Val Loss: 233524.3438
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.76 99.85
(0.05, 0.10)   1640 96.41 97.82
(0.10, 0.20)   1518 93.44 95.67
(0.20, 0.30)   1008 91.08 94.18
(0.30, 0.40)    816 88.93 92.75
(0.40, 0.50)    833 88.53 92.22


                                                                             

Epoch 44/100, Train Loss: 238633.6368, Val Loss: 233339.4707
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.77 99.85
(0.05, 0.10)   1640 96.43 97.81
(0.10, 0.20)   1518 93.47 95.80
(0.20, 0.30)   1008 91.11 94.26
(0.30, 0.40)    816 88.95 92.79
(0.40, 0.50)    833 88.58 92.40


                                                                             

Epoch 45/100, Train Loss: 238547.8044, Val Loss: 233265.1660
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.76 99.85
(0.05, 0.10)   1640 96.43 97.74
(0.10, 0.20)   1518 93.47 95.73
(0.20, 0.30)   1008 91.14 94.19
(0.30, 0.40)    816 89.03 92.96
(0.40, 0.50)    833 88.64 92.30


                                                                             

Epoch 46/100, Train Loss: 238385.0336, Val Loss: 233148.3594
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.77 99.85
(0.05, 0.10)   1640 96.42 97.80
(0.10, 0.20)   1518 93.47 95.76
(0.20, 0.30)   1008 91.21 94.34
(0.30, 0.40)    816 89.17 93.07
(0.40, 0.50)    833 88.76 92.54


                                                                             

Epoch 47/100, Train Loss: 238108.1473, Val Loss: 232934.8906
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.77 99.85
(0.05, 0.10)   1640 96.53 97.85
(0.10, 0.20)   1518 93.60 95.90
(0.20, 0.30)   1008 91.35 94.47
(0.30, 0.40)    816 89.33 93.30
(0.40, 0.50)    833 88.90 92.53


                                                                             

Epoch 48/100, Train Loss: 237907.6207, Val Loss: 232610.1660
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.77 99.85
(0.05, 0.10)   1640 96.53 97.92
(0.10, 0.20)   1518 93.68 95.98
(0.20, 0.30)   1008 91.45 94.50
(0.30, 0.40)    816 89.46 93.30
(0.40, 0.50)    833 89.00 92.68


                                                                             

Epoch 49/100, Train Loss: 237741.4574, Val Loss: 232553.3223
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.77 99.85
(0.05, 0.10)   1640 96.53 97.91
(0.10, 0.20)   1518 93.73 95.95
(0.20, 0.30)   1008 91.46 94.53
(0.30, 0.40)    816 89.53 93.27
(0.40, 0.50)    833 89.10 92.85


                                                                             

Epoch 50/100, Train Loss: 237547.8993, Val Loss: 232414.2520
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.77 99.85
(0.05, 0.10)   1640 96.54 97.92
(0.10, 0.20)   1518 93.80 95.96
(0.20, 0.30)   1008 91.51 94.64
(0.30, 0.40)    816 89.61 93.26
(0.40, 0.50)    833 89.21 92.85


                                                                             

Epoch 51/100, Train Loss: 237493.4320, Val Loss: 232457.0469
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.77 99.86
(0.05, 0.10)   1640 96.56 97.90
(0.10, 0.20)   1518 93.81 95.95
(0.20, 0.30)   1008 91.59 94.56
(0.30, 0.40)    816 89.62 93.28
(0.40, 0.50)    833 89.24 92.86


                                                                             

Epoch 52/100, Train Loss: 237208.0549, Val Loss: 232198.2148
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.77 99.86
(0.05, 0.10)   1640 96.60 97.98
(0.10, 0.20)   1518 93.86 96.05
(0.20, 0.30)   1008 91.65 94.73
(0.30, 0.40)    816 89.77 93.47
(0.40, 0.50)    833 89.37 92.95


                                                                             

Epoch 53/100, Train Loss: 237034.7593, Val Loss: 232038.2637
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.78 99.85
(0.05, 0.10)   1640 96.65 98.04
(0.10, 0.20)   1518 93.97 96.16
(0.20, 0.30)   1008 91.79 94.76
(0.30, 0.40)    816 89.92 93.41
(0.40, 0.50)    833 89.54 93.03


                                                                             

Epoch 54/100, Train Loss: 236960.6775, Val Loss: 231931.6758
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.78 99.86
(0.05, 0.10)   1640 96.64 98.03
(0.10, 0.20)   1518 93.98 96.17
(0.20, 0.30)   1008 91.83 94.79
(0.30, 0.40)    816 89.97 93.57
(0.40, 0.50)    833 89.61 93.21


                                                                             

Epoch 55/100, Train Loss: 236740.8724, Val Loss: 231746.1309
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.78 99.86
(0.05, 0.10)   1640 96.65 97.98
(0.10, 0.20)   1518 94.03 96.20
(0.20, 0.30)   1008 91.94 94.85
(0.30, 0.40)    816 90.09 93.73
(0.40, 0.50)    833 89.72 93.37


                                                                             

Epoch 56/100, Train Loss: 236659.6075, Val Loss: 231563.9629
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.78 99.86
(0.05, 0.10)   1640 96.69 98.02
(0.10, 0.20)   1518 94.08 96.25
(0.20, 0.30)   1008 91.95 94.93
(0.30, 0.40)    816 90.13 93.82
(0.40, 0.50)    833 89.80 93.49


                                                                             

Epoch 57/100, Train Loss: 236292.5391, Val Loss: 231470.3379
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.78 99.86
(0.05, 0.10)   1640 96.75 98.02
(0.10, 0.20)   1518 94.18 96.32
(0.20, 0.30)   1008 92.06 95.03
(0.30, 0.40)    816 90.31 93.89
(0.40, 0.50)    833 90.00 93.56


                                                                             

Epoch 58/100, Train Loss: 236213.1411, Val Loss: 231363.3477
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.78 99.86
(0.05, 0.10)   1640 96.75 98.06
(0.10, 0.20)   1518 94.20 96.41
(0.20, 0.30)   1008 92.13 95.11
(0.30, 0.40)    816 90.33 93.80
(0.40, 0.50)    833 90.01 93.59


                                                                             

Epoch 59/100, Train Loss: 236147.1211, Val Loss: 231234.9707
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.78 99.86
(0.05, 0.10)   1640 96.77 98.10
(0.10, 0.20)   1518 94.24 96.41
(0.20, 0.30)   1008 92.13 95.09
(0.30, 0.40)    816 90.40 93.82
(0.40, 0.50)    833 90.13 93.61


                                                                             

Epoch 60/100, Train Loss: 235981.6807, Val Loss: 231037.1797
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.78 99.86
(0.05, 0.10)   1640 96.80 98.12
(0.10, 0.20)   1518 94.29 96.41
(0.20, 0.30)   1008 92.24 95.21
(0.30, 0.40)    816 90.56 94.04
(0.40, 0.50)    833 90.26 93.84


                                                                             

Epoch 61/100, Train Loss: 235681.7765, Val Loss: 230817.7148
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.79 99.87
(0.05, 0.10)   1640 96.85 98.17
(0.10, 0.20)   1518 94.35 96.46
(0.20, 0.30)   1008 92.33 95.18
(0.30, 0.40)    816 90.70 94.12
(0.40, 0.50)    833 90.47 94.02


                                                                             

Epoch 62/100, Train Loss: 235463.3370, Val Loss: 230606.1621
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.78 99.87
(0.05, 0.10)   1640 96.86 98.17
(0.10, 0.20)   1518 94.41 96.53
(0.20, 0.30)   1008 92.43 95.30
(0.30, 0.40)    816 90.71 94.22
(0.40, 0.50)    833 90.57 94.00


                                                                             

Epoch 63/100, Train Loss: 235339.1933, Val Loss: 230741.3145
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.79 99.87
(0.05, 0.10)   1640 96.92 98.17
(0.10, 0.20)   1518 94.46 96.47
(0.20, 0.30)   1008 92.47 95.26
(0.30, 0.40)    816 90.82 94.10
(0.40, 0.50)    833 90.67 94.04


                                                                             

Epoch 64/100, Train Loss: 235217.4863, Val Loss: 230596.3594
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.79 99.87
(0.05, 0.10)   1640 96.91 98.14
(0.10, 0.20)   1518 94.50 96.41
(0.20, 0.30)   1008 92.51 95.31
(0.30, 0.40)    816 90.86 94.25
(0.40, 0.50)    833 90.73 94.16


                                                                             

Epoch 65/100, Train Loss: 235001.8992, Val Loss: 230243.7637
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.79 99.87
(0.05, 0.10)   1640 96.95 98.23
(0.10, 0.20)   1518 94.53 96.61
(0.20, 0.30)   1008 92.53 95.49
(0.30, 0.40)    816 90.96 94.34
(0.40, 0.50)    833 90.91 94.28


                                                                             

Epoch 66/100, Train Loss: 234905.8769, Val Loss: 230151.3691
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.79 99.87
(0.05, 0.10)   1640 96.96 98.27
(0.10, 0.20)   1518 94.59 96.63
(0.20, 0.30)   1008 92.63 95.52
(0.30, 0.40)    816 90.99 94.44
(0.40, 0.50)    833 90.96 94.32


                                                                             

Epoch 67/100, Train Loss: 234642.9123, Val Loss: 230037.7637
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.79 99.87
(0.05, 0.10)   1640 97.03 98.26
(0.10, 0.20)   1518 94.69 96.70
(0.20, 0.30)   1008 92.77 95.54
(0.30, 0.40)    816 91.16 94.46
(0.40, 0.50)    833 91.15 94.42


                                                                             

Epoch 68/100, Train Loss: 234501.5508, Val Loss: 229907.2949
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.79 99.87
(0.05, 0.10)   1640 97.03 98.29
(0.10, 0.20)   1518 94.69 96.71
(0.20, 0.30)   1008 92.77 95.60
(0.30, 0.40)    816 91.19 94.53
(0.40, 0.50)    833 91.23 94.41


                                                                             

Epoch 69/100, Train Loss: 234328.8725, Val Loss: 229931.4062
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.79 99.87
(0.05, 0.10)   1640 97.04 98.29
(0.10, 0.20)   1518 94.74 96.70
(0.20, 0.30)   1008 92.89 95.53
(0.30, 0.40)    816 91.29 94.47
(0.40, 0.50)    833 91.26 94.55


                                                                             

Epoch 70/100, Train Loss: 234148.4194, Val Loss: 229581.8906
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.79 99.87
(0.05, 0.10)   1640 97.10 98.33
(0.10, 0.20)   1518 94.77 96.85
(0.20, 0.30)   1008 92.93 95.62
(0.30, 0.40)    816 91.35 94.54
(0.40, 0.50)    833 91.37 94.68


                                                                             

Epoch 71/100, Train Loss: 234161.6867, Val Loss: 229758.5566
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.79 99.87
(0.05, 0.10)   1640 97.09 98.33
(0.10, 0.20)   1518 94.77 96.83
(0.20, 0.30)   1008 92.94 95.59
(0.30, 0.40)    816 91.36 94.51
(0.40, 0.50)    833 91.45 94.54


                                                                             

Epoch 72/100, Train Loss: 234039.2073, Val Loss: 229580.9238
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.80 99.87
(0.05, 0.10)   1640 97.13 98.36
(0.10, 0.20)   1518 94.83 96.83
(0.20, 0.30)   1008 93.01 95.71
(0.30, 0.40)    816 91.43 94.65
(0.40, 0.50)    833 91.48 94.67


                                                                             

Epoch 73/100, Train Loss: 233933.0560, Val Loss: 229500.0566
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.80 99.87
(0.05, 0.10)   1640 97.12 98.39
(0.10, 0.20)   1518 94.86 96.83
(0.20, 0.30)   1008 93.03 95.69
(0.30, 0.40)    816 91.45 94.71
(0.40, 0.50)    833 91.50 94.73


                                                                             

Epoch 74/100, Train Loss: 233768.3574, Val Loss: 229506.3340
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.80 99.87
(0.05, 0.10)   1640 97.16 98.35
(0.10, 0.20)   1518 94.90 96.80
(0.20, 0.30)   1008 93.07 95.73
(0.30, 0.40)    816 91.60 94.68
(0.40, 0.50)    833 91.62 94.78


                                                                             

Epoch 75/100, Train Loss: 233742.0390, Val Loss: 229265.0547
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.80 99.87
(0.05, 0.10)   1640 97.14 98.45
(0.10, 0.20)   1518 94.92 96.84
(0.20, 0.30)   1008 93.12 95.82
(0.30, 0.40)    816 91.53 94.68
(0.40, 0.50)    833 91.64 94.93


                                                                             

Epoch 76/100, Train Loss: 233613.4163, Val Loss: 229258.8848
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.80 99.88
(0.05, 0.10)   1640 97.18 98.43
(0.10, 0.20)   1518 94.94 96.85
(0.20, 0.30)   1008 93.11 95.79
(0.30, 0.40)    816 91.61 94.69
(0.40, 0.50)    833 91.61 94.86


                                                                             

Epoch 77/100, Train Loss: 233436.3069, Val Loss: 229215.8945
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.80 99.88
(0.05, 0.10)   1640 97.22 98.44
(0.10, 0.20)   1518 95.01 96.89
(0.20, 0.30)   1008 93.21 95.83
(0.30, 0.40)    816 91.70 94.81
(0.40, 0.50)    833 91.79 94.98


                                                                             

Epoch 78/100, Train Loss: 233382.1537, Val Loss: 229099.0117
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.80 99.88
(0.05, 0.10)   1640 97.23 98.46
(0.10, 0.20)   1518 95.01 96.94
(0.20, 0.30)   1008 93.28 95.85
(0.30, 0.40)    816 91.76 94.94
(0.40, 0.50)    833 91.81 94.97


                                                                             

Epoch 79/100, Train Loss: 233330.3687, Val Loss: 228971.9434
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.80 99.88
(0.05, 0.10)   1640 97.24 98.46
(0.10, 0.20)   1518 95.04 96.96
(0.20, 0.30)   1008 93.22 95.99
(0.30, 0.40)    816 91.77 95.06
(0.40, 0.50)    833 91.79 95.00


                                                                             

Epoch 80/100, Train Loss: 233147.1453, Val Loss: 228953.3906
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.80 99.88
(0.05, 0.10)   1640 97.27 98.50
(0.10, 0.20)   1518 95.10 96.95
(0.20, 0.30)   1008 93.36 95.94
(0.30, 0.40)    816 91.86 95.03
(0.40, 0.50)    833 91.90 94.97


                                                                             

Epoch 81/100, Train Loss: 233099.2278, Val Loss: 229031.9219
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.80 99.88
(0.05, 0.10)   1640 97.28 98.46
(0.10, 0.20)   1518 95.09 96.93
(0.20, 0.30)   1008 93.35 95.90
(0.30, 0.40)    816 91.86 94.76
(0.40, 0.50)    833 91.91 95.00


                                                                             

Epoch 82/100, Train Loss: 232995.2113, Val Loss: 228777.2871
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.80 99.88
(0.05, 0.10)   1640 97.30 98.48
(0.10, 0.20)   1518 95.15 97.02
(0.20, 0.30)   1008 93.44 96.00
(0.30, 0.40)    816 91.98 94.97
(0.40, 0.50)    833 91.98 95.07


                                                                             

Epoch 83/100, Train Loss: 232964.7416, Val Loss: 228745.0059
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.80 99.88
(0.05, 0.10)   1640 97.29 98.49
(0.10, 0.20)   1518 95.11 97.06
(0.20, 0.30)   1008 93.44 95.98
(0.30, 0.40)    816 91.98 95.07
(0.40, 0.50)    833 91.97 95.15


                                                                             

Epoch 84/100, Train Loss: 232802.6227, Val Loss: 228663.0957
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.81 99.88
(0.05, 0.10)   1640 97.34 98.53
(0.10, 0.20)   1518 95.22 97.07
(0.20, 0.30)   1008 93.49 96.05
(0.30, 0.40)    816 92.20 95.14
(0.40, 0.50)    833 92.23 95.22


                                                                             

Epoch 85/100, Train Loss: 232622.6439, Val Loss: 228426.5000
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.81 99.88
(0.05, 0.10)   1640 97.34 98.58
(0.10, 0.20)   1518 95.21 97.14
(0.20, 0.30)   1008 93.51 96.16
(0.30, 0.40)    816 92.22 95.43
(0.40, 0.50)    833 92.29 95.30


                                                                             

Epoch 86/100, Train Loss: 232488.2369, Val Loss: 228221.0176
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.81 99.88
(0.05, 0.10)   1640 97.37 98.58
(0.10, 0.20)   1518 95.27 97.10
(0.20, 0.30)   1008 93.60 96.23
(0.30, 0.40)    816 92.41 95.57
(0.40, 0.50)    833 92.38 95.48


                                                                             

Epoch 87/100, Train Loss: 232317.2899, Val Loss: 228259.6758
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.81 99.88
(0.05, 0.10)   1640 97.38 98.58
(0.10, 0.20)   1518 95.30 97.12
(0.20, 0.30)   1008 93.60 96.20
(0.30, 0.40)    816 92.46 95.58
(0.40, 0.50)    833 92.39 95.42


                                                                             

Epoch 88/100, Train Loss: 232174.7125, Val Loss: 228174.9043
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.81 99.88
(0.05, 0.10)   1640 97.41 98.60
(0.10, 0.20)   1518 95.34 97.16
(0.20, 0.30)   1008 93.70 96.25
(0.30, 0.40)    816 92.71 95.61
(0.40, 0.50)    833 92.60 95.61


                                                                             

Epoch 89/100, Train Loss: 231928.2353, Val Loss: 227912.0566
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.81 99.89
(0.05, 0.10)   1640 97.44 98.64
(0.10, 0.20)   1518 95.39 97.21
(0.20, 0.30)   1008 93.76 96.29
(0.30, 0.40)    816 92.75 95.76
(0.40, 0.50)    833 92.73 95.68


                                                                             

Epoch 90/100, Train Loss: 231896.5062, Val Loss: 227816.3340
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.81 99.89
(0.05, 0.10)   1640 97.46 98.65
(0.10, 0.20)   1518 95.42 97.20
(0.20, 0.30)   1008 93.83 96.33
(0.30, 0.40)    816 92.86 95.81
(0.40, 0.50)    833 92.87 95.86


                                                                             

Epoch 91/100, Train Loss: 231653.0817, Val Loss: 227598.5391
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.81 99.89
(0.05, 0.10)   1640 97.47 98.68
(0.10, 0.20)   1518 95.45 97.23
(0.20, 0.30)   1008 93.90 96.47
(0.30, 0.40)    816 93.01 95.99
(0.40, 0.50)    833 92.94 95.99


                                                                             

Epoch 92/100, Train Loss: 231536.4251, Val Loss: 227463.9766
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.81 99.89
(0.05, 0.10)   1640 97.47 98.71
(0.10, 0.20)   1518 95.46 97.28
(0.20, 0.30)   1008 93.93 96.49
(0.30, 0.40)    816 93.13 96.25
(0.40, 0.50)    833 93.05 95.93


                                                                             

Epoch 93/100, Train Loss: 231247.3214, Val Loss: 227285.7109
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.81 99.89
(0.05, 0.10)   1640 97.54 98.69
(0.10, 0.20)   1518 95.59 97.30
(0.20, 0.30)   1008 94.13 96.63
(0.30, 0.40)    816 93.42 96.31
(0.40, 0.50)    833 93.19 96.04


                                                                             

Epoch 94/100, Train Loss: 231184.0242, Val Loss: 227254.8691
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.81 99.89
(0.05, 0.10)   1640 97.54 98.74
(0.10, 0.20)   1518 95.59 97.35
(0.20, 0.30)   1008 94.13 96.67
(0.30, 0.40)    816 93.57 96.31
(0.40, 0.50)    833 93.24 96.12


                                                                             

Epoch 95/100, Train Loss: 230723.1837, Val Loss: 226940.5996
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.82 99.89
(0.05, 0.10)   1640 97.60 98.75
(0.10, 0.20)   1518 95.69 97.44
(0.20, 0.30)   1008 94.21 96.68
(0.30, 0.40)    816 93.76 96.56
(0.40, 0.50)    833 93.38 96.25


                                                                             

Epoch 96/100, Train Loss: 230616.7096, Val Loss: 226936.4727
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.82 99.89
(0.05, 0.10)   1640 97.60 98.78
(0.10, 0.20)   1518 95.69 97.45
(0.20, 0.30)   1008 94.30 96.72
(0.30, 0.40)    816 93.86 96.58
(0.40, 0.50)    833 93.53 96.37


                                                                             

Epoch 97/100, Train Loss: 230528.3870, Val Loss: 226718.8574
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.82 99.89
(0.05, 0.10)   1640 97.63 98.79
(0.10, 0.20)   1518 95.76 97.54
(0.20, 0.30)   1008 94.31 96.84
(0.30, 0.40)    816 93.96 96.64
(0.40, 0.50)    833 93.56 96.53


                                                                             

Epoch 98/100, Train Loss: 230322.7818, Val Loss: 226599.6191
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.82 99.89
(0.05, 0.10)   1640 97.66 98.81
(0.10, 0.20)   1518 95.78 97.45
(0.20, 0.30)   1008 94.39 96.87
(0.30, 0.40)    816 94.06 96.74
(0.40, 0.50)    833 93.68 96.59


                                                                             

Epoch 99/100, Train Loss: 230204.9569, Val Loss: 226583.7344
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.82 99.89
(0.05, 0.10)   1640 97.69 98.87
(0.10, 0.20)   1518 95.85 97.51
(0.20, 0.30)   1008 94.46 96.88
(0.30, 0.40)    816 94.17 96.83
(0.40, 0.50)    833 93.78 96.57


                                                                              

Epoch 100/100, Train Loss: 230074.0283, Val Loss: 226321.5273
     MAF_bin Counts Train   Val
(0.00, 0.05)  50330 99.82 99.90
(0.05, 0.10)   1640 97.71 98.90
(0.10, 0.20)   1518 95.90 97.58
(0.20, 0.30)   1008 94.53 96.96
(0.30, 0.40)    816 94.26 96.89
(0.40, 0.50)    833 93.82 96.65


In [17]:
os.makedirs(f'{args.save_dir}/models', exist_ok=True)
torch.save(model.state_dict(), f'{args.save_dir}/models/final.pth')

In [18]:
# Clean up
del ref_set, train_dataset, val_dataset, train_loader, val_loader, model
torch.cuda.empty_cache() if torch.cuda.is_available() else None