In [1]:
from argparse import ArgumentParser, Namespace
from datetime import datetime
import gc
import logging
from pathlib import Path
from typing import Literal
import warnings
import numpy as np
import pandas as pd
import polars as pl
import xgboost as xgb
import optuna
from tqdm import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm
from sklearn.model_selection import train_test_split


warnings.filterwarnings('ignore')
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S')
logger = logging.getLogger(__name__)

SLICE_PARAMS: dict[str, int] = {
    'b5': 0, 'b4': 1, 'b3': 2, 'b2': 3, 'b1': 4,
    'a1': 6, 'a2': 7, 'a3': 8, 'a4': 9, 'a5': 10}


def read_df(f_path: Path, split: bool,
            train_test: Literal['train', 'test'],
            n_rows: int = 5000) -> pd.DataFrame:
    selected_columns: list[str] = ['predicted_beta', 'depth',
                                   'GC_skew_70', 'CpG_GC_ratio_70',
                                   'ShannonEntropy_70', 'BWT_ratio_70',
                                   'cpg', 'location', 'promoter', 'enhancer']
    df: pd.DataFrame
    if train_test == 'train':
        _df_ls: list[pl.DataFrame] = []
        for _df in (pl.scan_parquet(f_path)
                      .with_columns(pl.col('predicted_beta')
                                      .cut(breaks=list(range(5, 100, 5)),
                                           labels=[f'{i}-{i + 5}' for i in range(0, 100, 5)])
                                      .alias('beta_bin'))
                      .collect()
                      .partition_by('beta_bin', include_key=False)):
            if _df.shape[0] < n_rows:
                _df_ls.append(_df)
            else:
                _df_ls.append(_df.sample(n=n_rows))
        if split:
            selected_columns += [*SLICE_PARAMS.keys(), 'actual_beta']
            df = (pl.concat(_df_ls)
                    .lazy()
                    .with_columns(pl.when(pl.col('seq_5')
                                            .str
                                            .slice(5, 1) == 'C')
                                    .then(pl.col('seq_5'))
                                    .otherwise(pl.col('seq_5')
                                                 .str.reverse()
                                                 .str.replace_many(['A', 'C', 'G', 'T'],
                                                                   ['T', 'G', 'C', 'A']))
                                    .alias('processed_seq_5'))
                    .with_columns([pl.col('processed_seq_5')
                                     .str
                                     .slice(index, 1)
                                     .alias(name) for name, index in SLICE_PARAMS.items()])
                    .select(selected_columns)
                    .collect()
                    .to_pandas())
        else:
            selected_columns += ['seq_5', 'actual_beta']
            df = (pl.concat(_df_ls)
                    .select(selected_columns)
                    .to_pandas())
        del _df_ls
        gc.collect()
    else:
        if split:
            selected_columns += [*SLICE_PARAMS.keys()]
            df = (pl.scan_parquet(f_path)
                    .with_columns(pl.when(pl.col('seq_5')
                                            .str
                                            .slice(5, 1) == 'C')
                                    .then(pl.col('seq_5'))
                                    .otherwise(pl.col('seq_5')
                                                 .str.reverse()
                                                 .str.replace_many(['A', 'C', 'G', 'T'],
                                                                   ['T', 'G', 'C', 'A']))
                                    .alias('processed_seq_5'))
                    .with_columns([pl.col('processed_seq_5')
                                     .str
                                     .slice(index, 1)
                                     .alias(name) for name, index in SLICE_PARAMS.items()])
                    .select(['chrom', 'start', 'end'] + selected_columns)
                    .collect()
                    .to_pandas())
        else:
            selected_columns += ['seq_5']
            df = pd.read_parquet(f_path)[['chrom', 'start', 'end'] + selected_columns]
    return df


def slice_dataframe(df: pd.DataFrame, slice_size: int = 100_000) -> list[pd.DataFrame]:
    num_chunks: int = int(np.ceil(len(df) / slice_size))
    return [df.iloc[i * slice_size: (i + 1) * slice_size] for i in range(num_chunks)]



In [8]:
features: list[str] = ['predicted_beta', 'depth',
                       'GC_skew_70', 'CpG_GC_ratio_70',
                       'ShannonEntropy_70', 'BWT_ratio_70',
                       'cpg', 'location', 'promoter', 'enhancer']
hi_input_path='/mnt/eqa/zhangyuanfeng/methylation/best_pipeline/data/calibrated_high_depth/BS2_D5_1.parquet.lz4'

hd_df: pd.DataFrame = read_df(f_path=hi_input_path, split=True, train_test='train')

for col in ['cpg', 'location', 'promoter', 'enhancer',
            'b5', 'b4', 'b3', 'b2', 'b1',
            'a1', 'a2', 'a3', 'a4', 'a5']:
    hd_df[col] = hd_df[col].astype('category')

X = hd_df[features]
y = hd_df['actual_beta']

In [9]:
(X_train_split,
    X_val,
    y_train_split,
    y_val)= train_test_split(
    X, y, test_size=0.2, random_state=28
)

In [12]:
def get_xgb_params(trial: optuna.Trial) -> dict[str, int | float | str | list[float]]:
    """获取XGBoost参数, 固定使用0.5分位数"""
    return {
        'objective': 'reg:quantileerror',
        'quantile_alpha': [0.5],  # 只训练中位数
        'eval_metric': 'quantile',
        'tree_method': 'gpu_hist',  # 固定使用GPU
        'n_estimators': trial.suggest_int('n_estimators', 50, 500),
        'max_depth': trial.suggest_int('max_depth', 3, 10),
        'learning_rate': trial.suggest_float('learning_rate', 1e-3, 0.3, log=True),
        'subsample': trial.suggest_float('subsample', 0.6, 1.0),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.6, 1.0),
        'gamma': trial.suggest_float('gamma', 1e-8, 1.0, log=True),
        'min_child_weight': trial.suggest_int('min_child_weight', 1, 20),
        'reg_alpha': trial.suggest_float('reg_alpha', 1e-8, 10.0, log=True),
        'reg_lambda': trial.suggest_float('reg_lambda', 1e-8, 10.0, log=True),
        'random_state': 42,
        'verbosity': 0
    }


def objective(trial: optuna.Trial, X_train: pd.DataFrame, y_train: pd.Series, 
              X_val: pd.DataFrame, y_val: pd.Series) -> float:
    """Optuna目标函数, 只优化0.5分位数"""
    params = get_xgb_params(trial)

    # 创建DMatrix
    dtrain = xgb.QuantileDMatrix(X_train, y_train, enable_categorical=True)
    dval = xgb.QuantileDMatrix(X_val, y_val, ref=dtrain, enable_categorical=True)

    # 训练模型
    model = xgb.train(
        params,
        dtrain,
        num_boost_round=params['n_estimators'],
        evals=[(dtrain, 'Train'), (dval, 'Test')],
        early_stopping_rounds=20,
        verbose_eval=False
    )

    # 预测验证集
    val_pred = model.predict(dval)

    q = 0.5
    loss = np.mean(np.maximum(q * (y_val - val_pred), (q - 1) * (y_val - val_pred)))

    return loss


study: optuna.Study = optuna.create_study(direction='minimize',
                                          sampler=optuna.samplers.TPESampler(seed=28))
study.optimize(
        lambda trial: objective(trial, X_train_split, y_train_split, X_val, y_val),
        n_trials=100
    )

[I 2025-07-30 21:17:54,523] A new study created in memory with name: no-name-965030d6-7303-4ca9-9e3f-d37048d03de6


[I 2025-07-30 21:18:01,057] Trial 0 finished with value: 4.390315532684326 and parameters: {'n_estimators': 378, 'max_depth': 7, 'learning_rate': 0.0020396640426346663, 'subsample': 0.7590369474800469, 'colsample_bytree': 0.9125232825093904, 'gamma': 0.00012244578877765376, 'min_child_weight': 4, 'reg_alpha': 0.48041424592671783, 'reg_lambda': 3.965948165384268}. Best is trial 0 with value: 4.390315532684326.
[I 2025-07-30 21:18:06,882] Trial 1 finished with value: 0.3265712559223175 and parameters: {'n_estimators': 493, 'max_depth': 4, 'learning_rate': 0.25396693609158416, 'subsample': 0.6939233418405771, 'colsample_bytree': 0.6105415394039483, 'gamma': 0.37023949489210145, 'min_child_weight': 19, 'reg_alpha': 4.4236617352233705e-06, 'reg_lambda': 3.136630201373402e-07}. Best is trial 1 with value: 0.3265712559223175.
[I 2025-07-30 21:18:11,664] Trial 2 finished with value: 2.743213653564453 and parameters: {'n_estimators': 228, 'max_depth': 7, 'learning_rate': 0.005794894929578646, '

In [14]:
best_params = get_xgb_params(study.best_trial)

# 移除n_estimators, 使用early stopping
n_estimators = best_params.pop('n_estimators')

dtrain = xgb.QuantileDMatrix(X, y, enable_categorical=True)
dval = xgb.QuantileDMatrix(X_val, y_val, ref=dtrain, enable_categorical=True)

model = xgb.train(
    best_params,
    dtrain,
    num_boost_round=n_estimators,
    evals=[(dval, 'validation')],
    early_stopping_rounds=50,
    verbose_eval=10
)

[0]	validation-quantile:7.67989
[10]	validation-quantile:3.40057
[20]	validation-quantile:1.66675
[30]	validation-quantile:0.84833
[40]	validation-quantile:0.50864
[50]	validation-quantile:0.39397
[60]	validation-quantile:0.35525
[70]	validation-quantile:0.33855
[80]	validation-quantile:0.32969
[90]	validation-quantile:0.32611
[100]	validation-quantile:0.32240
[110]	validation-quantile:0.31951
[120]	validation-quantile:0.31723
[130]	validation-quantile:0.31491
[140]	validation-quantile:0.31251
[150]	validation-quantile:0.31111
[160]	validation-quantile:0.30964
[170]	validation-quantile:0.30826
[180]	validation-quantile:0.30700
[190]	validation-quantile:0.30570
[200]	validation-quantile:0.30452
[210]	validation-quantile:0.30374
[220]	validation-quantile:0.30273
[230]	validation-quantile:0.30168
[240]	validation-quantile:0.30097
[250]	validation-quantile:0.30028
[260]	validation-quantile:0.29942
[270]	validation-quantile:0.29858
[280]	validation-quantile:0.29793
[290]	validation-quantile

In [15]:
mi_input_path='/mnt/eqa/zhangyuanfeng/methylation/best_pipeline/data/medium_depth/BS2_D5_1.parquet.lz4'
md_df = read_df(f_path=mi_input_path, split=True, train_test='test')
for col in ['cpg', 'location', 'promoter', 'enhancer',
            'b5', 'b4', 'b3', 'b2', 'b1',
            'a1', 'a2', 'a3', 'a4', 'a5']:
    md_df[col] = md_df[col].astype('category')

slices = slice_dataframe(df=md_df)

In [17]:
slice_d = xgb.QuantileDMatrix(slices[0][features], ref=dtrain, enable_categorical=True)

In [18]:
predictions = model.predict(slice_d)
slices[0]['actual_beta'] = predictions

In [19]:
slices[0]

Unnamed: 0,chrom,start,end,predicted_beta,depth,GC_skew_70,CpG_GC_ratio_70,ShannonEntropy_70,BWT_ratio_70,cpg,...,b4,b3,b2,b1,a1,a2,a3,a4,a5,actual_beta
0,chr1,10471,10472,100.000000,7,-0.511628,0.116279,14.649351,1.832707,cpg_inter,...,G,T,A,C,G,C,G,A,G,74.477859
1,chr1,10620,10621,100.000000,9,0.047619,0.333333,11.873684,1.803982,cpg_inter,...,G,C,A,A,G,G,C,G,G,75.082214
2,chr1,10781,10782,100.000000,6,0.182609,0.365217,19.448276,1.615711,cpg_inter,...,G,G,C,G,G,G,C,G,C,75.304825
3,chr1,10783,10784,100.000000,6,0.172414,0.370690,18.800000,1.607113,cpg_inter,...,C,C,G,G,G,C,G,G,C,75.227005
4,chr1,10786,10787,100.000000,5,0.145299,0.376068,19.118644,1.600962,cpg_inter,...,G,C,G,C,G,G,C,G,C,75.286575
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
99995,chr1,11757708,11757709,100.000000,6,0.098592,0.140845,11.058824,1.982006,cpg_inter,...,A,A,A,C,G,G,G,C,T,75.350266
99996,chr1,11757743,11757744,50.000000,6,0.117647,0.147059,10.742857,1.978385,cpg_inter,...,C,A,T,G,G,T,C,T,G,43.401722
99997,chr1,11758549,11758550,33.333333,9,0.207547,0.094340,10.254545,1.919861,cpg_shelve,...,A,G,G,A,G,A,A,A,A,33.529198
99998,chr1,11758685,11758686,42.857143,7,0.151515,0.121212,10.444444,1.973826,cpg_shelve,...,G,G,T,G,G,G,T,G,G,39.561565


In [21]:
truset = '/mnt/eqa/zhangyuanfeng/methylation/quartet_reference/single_c/ensembl/full_seq_info/D5.parquet.lz4'

df: pl.DataFrame = (pl.from_dataframe(slices[0])
                      .lazy()
                      .select(['chrom', 'start', 'end', 'predicted_beta', 'actual_beta'])
                      .join(other=pl.scan_parquet(truset)
                                    .filter(pl.col('in_hcr'))
                                    .select('chrom', 'start', 'end', 'beta_pyro'),
                            on=['chrom', 'start', 'end'], how='inner')
                      .collect())

In [22]:
original_rmse = (df.select(
    (pl.col('predicted_beta') - pl.col('beta_pyro')).pow(2)
    .mean()
    .sqrt()
    .alias('rmse')).item())

corrected_rmse = (df.select(
    (pl.col('actual_beta') - pl.col('beta_pyro')).pow(2)
    .mean()
    .sqrt()
    .alias('rmse')).item())

print(f'Original RMSE: {original_rmse:.4f}, Corrected RMSE: {corrected_rmse:.4f}')

Original RMSE: 20.2552, Corrected RMSE: 7.9402
