In [None]:
import os
import warnings
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from types import SimpleNamespace as _NS
import joblib
from scipy.stats import pearsonr
import matplotlib.colors as mcolors

warnings.filterwarnings("ignore")

from src.data import TRIMDataModule
from src.regression_model import Regression

# ================= Configuration =================
CONFIG = _NS(
    **{
        # --- Checkpoint of the trained regression model to evaluate ---
        "best_ckpt_path": "outputs/logs/version_0/checkpoints/best-epoch=022-val/r2_reg=0.795.ckpt",
        # --- CSV splits; the train split is only handed to the data module at construction ---
        "val_csv": "data/filtered_dataset/filtered_dataset.val.csv",
        "test_csv": "data/filtered_dataset/filtered_dataset.test.csv",
        "train_csv_for_init": "data/filtered_dataset/filtered_dataset.train.csv",
        # --- Fitted preprocessors: environment features and the target Z-score scaler ---
        "env_preproc_path": "data/preprocessor/env_preproc.joblib",
        "target_scaler_path": "data/preprocessor/target_TE_value_zscaler.pkl",
        # --- Directory receiving the plots and the inference cache ---
        "output_dir": "outputs/Validate_and_Plot",
        # When True, any existing inference cache is ignored and the model is re-run.
        "force_inference": False,
        # --- DataLoader settings ---
        "batch_size": 512,
        "num_workers": 16,
        # Forwarded to TRIMDataModule; presumably maximum 5'UTR / CDS+3'UTR lengths — confirm.
        "max_utr5": 1381,
        "max_cds_utr3": 11937,
        # --- Axis labels used on both scatter plots ---
        "xlabel": "Real TE",
        "ylabel": "Predicted TE",
    }
)
# ============================================


def load_model_automatic(ckpt_path):
    """
    Rebuild a ``Regression`` model from a Lightning checkpoint for inference.

    ``load_from_checkpoint`` reads the hyper-parameters stored in the checkpoint
    and instantiates the network from them, so no architecture arguments are
    needed here.

    Args:
        ckpt_path: Path to the ``.ckpt`` file.

    Returns:
        The restored model, loaded onto the CPU and switched to eval mode.

    Raises:
        FileNotFoundError: If ``ckpt_path`` does not exist.
    """
    print(f"[Info] Loading model automatically from: {ckpt_path}")

    checkpoint_present = os.path.exists(ckpt_path)
    if not checkpoint_present:
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # Mapping to CPU lets the checkpoint load on machines without a GPU.
    restored = Regression.load_from_checkpoint(ckpt_path, map_location="cpu")
    restored.eval()
    return restored


def _denorm_params(scaler):
    """Return ``(mean, std)`` used to map Z-scores back to the original target scale.

    Accepted scaler forms:
      * a dict with ``"mean"`` and ``"std"`` keys;
      * an sklearn-style scaler exposing ``mean_`` / ``scale_`` arrays
        (the first feature's entries are used);
      * any other object exposing ``mean`` / ``std`` attributes.
    """
    if isinstance(scaler, dict):
        return float(scaler["mean"]), float(scaler["std"])
    if hasattr(scaler, "mean_") and hasattr(scaler, "scale_"):  # sklearn StandardScaler
        return float(scaler.mean_[0]), float(scaler.scale_[0])
    # Custom TargetZScaler-like object
    return float(scaler.mean), float(scaler.std)


def predict_and_inverse(model, dataloader, scaler, device="cuda"):
    """Run inference and convert Z-score targets/predictions back to real values.

    Args:
        model: Module called as ``model(x, env)``; its output is treated as
            Z-scored predictions with one value per sample.
        dataloader: Iterable of dict batches with ``"x"``, ``"env"`` and
            ``"TE_value"`` (Z-scored target) tensors.
        scaler: Target scaler providing mean/std (see ``_denorm_params``).
        device: Device the model and batches are moved to.

    Returns:
        ``(y_true_real, y_pred_real)`` as 1-D NumPy arrays, containing only the
        samples where both target and prediction are finite. Two empty arrays
        are returned if no such sample exists.

    Raises:
        ValueError: If a batch's prediction count differs from its target count.
    """
    model = model.to(device)
    # Inference must not run with dropout/batch-norm in training mode,
    # regardless of the mode the caller left the model in.
    model.eval()

    mu, std = _denorm_params(scaler)
    print(f"[Info] Scaler params loaded: Mean={mu:.4f}, Std={std:.4f}")
    print(f"[Info] Starting inference on {device}...")

    y_true_list = []
    y_pred_list = []
    with torch.no_grad():
        for batch in dataloader:
            x = batch["x"].to(device).float()
            env = batch["env"].to(device).float()
            # Flatten both to 1-D: a (B, 1) model output combined with a (B,)
            # target would otherwise broadcast the finiteness mask to (B, B).
            yt_z = batch["TE_value"].to(device).float().reshape(-1)
            yhat_z = model(x, env).reshape(-1)
            if yhat_z.numel() != yt_z.numel():
                raise ValueError(
                    f"Model produced {yhat_z.numel()} predictions for "
                    f"{yt_z.numel()} targets in one batch"
                )

            # Drop samples whose target or prediction is NaN/Inf.
            mask = torch.isfinite(yt_z) & torch.isfinite(yhat_z)
            if not mask.any():
                continue

            y_true_list.append(yt_z[mask].cpu().numpy())
            y_pred_list.append(yhat_z[mask].cpu().numpy())

    if len(y_true_list) == 0:
        print("[Warning] No valid data found!")
        return np.array([]), np.array([])

    y_true_z = np.concatenate(y_true_list)
    y_pred_z = np.concatenate(y_pred_list)

    # Denormalisation: real value = z * std + mean
    return y_true_z * std + mu, y_pred_z * std + mu


def _fit_stats(y_true, y_pred):
    """Return ``(r2, pcc)`` for paired 1-D float arrays; NaN where a statistic is undefined."""
    nan = float("nan")
    if y_true.size < 2:
        # R^2 and Pearson r need at least two points.
        return nan, nan
    sst = np.sum((y_true - np.mean(y_true)) ** 2)
    # R^2 is undefined when y_true is constant (zero total sum of squares).
    r2 = nan if sst == 0 else float(1 - np.sum((y_true - y_pred) ** 2) / sst)
    # Pearson r is undefined when either series is constant.
    if np.ptp(y_true) == 0 or np.ptp(y_pred) == 0:
        pcc = nan
    else:
        pcc, _ = pearsonr(y_true, y_pred)
        pcc = float(pcc)
    return r2, pcc


def _square_limits(y_true, y_pred, pad_frac=0.01):
    """Return shared ``[low, high]`` limits covering both series, padded by ``pad_frac`` of the span."""
    all_vals = np.concatenate([y_true, y_pred])
    if all_vals.size == 0:
        return [-1, 1]
    vmin, vmax = np.min(all_vals), np.max(all_vals)
    margin = (vmax - vmin) * pad_frac
    if margin == 0:
        # Every value is identical: pad by a fixed amount so the range is not empty.
        margin = 0.5
    return [vmin - margin, vmax + margin]


def plot_paper_style_scatter(y_true, y_pred, save_path, xlabel, ylabel):
    """
    Draw a publication-style real-vs-predicted density scatter plot and save it.

    Layers, bottom to top: a filled 2-D KDE heatmap ("GnBu" colormap with a
    PowerNorm), a faint rasterised scatter of every point, marginal KDEs on both
    axes, a dashed y = x reference line, and an R^2 / PCC / N annotation.

    Args:
        y_true: Real values (array-like, flattened to 1-D).
        y_pred: Predicted values, same length as ``y_true``.
        save_path: Output image path passed to ``JointGrid.savefig``.
        xlabel: Label of the x (real) axis.
        ylabel: Label of the y (predicted) axis.

    Raises:
        ValueError: If ``y_true`` and ``y_pred`` differ in length.

    Pairs where either value is NaN/Inf are dropped before plotting; ``N`` in
    the annotation counts only the remaining pairs. R^2 / PCC are shown as
    ``nan`` when undefined (fewer than two points or a constant series).
    """
    # 1. Global seaborn style (affects subsequent plots too)
    sns.set(style="ticks", font_scale=1.4)

    HEATMAP_CMAP = "GnBu"  # background heatmap gradient
    MAIN_COLOR = "#0868AC"  # deep teal for the marginal KDEs

    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred must have the same length, got {y_true.size} and {y_pred.size}"
        )
    finite = np.isfinite(y_true) & np.isfinite(y_pred)
    y_true, y_pred = y_true[finite], y_pred[finite]

    # 2. Statistics
    N = len(y_true)
    r2, pcc = _fit_stats(y_true, y_pred)

    # 3. JointGrid canvas
    g = sns.JointGrid(x=y_true, y=y_pred, height=8, ratio=5, space=0.1)

    # [Layer 1: Background Heatmap] PowerNorm(gamma<1) stretches low densities
    # so the gradient is visible outside the densest core.
    print("  -> Rendering Heatmap layer (Nature Teal style)...")
    g.plot_joint(
        sns.kdeplot,
        cmap=HEATMAP_CMAP,
        fill=True,
        levels=100,
        thresh=5e-2,  # lowest-density contours are left unfilled
        norm=mcolors.PowerNorm(gamma=0.4),
        alpha=0.9,
        zorder=0,
    )

    # [Layer 2: Foreground Scatter] tiny, nearly transparent dots; rasterised
    # so large N does not bloat vector outputs.
    print("  -> Rendering Scatter layer...")
    g.plot_joint(
        sns.scatterplot,
        s=1.5,
        color="#0F2C40",  # very dark blue, softer than pure black
        alpha=0.05,
        marker=".",
        edgecolor=None,
        zorder=1,
        rasterized=True,
    )

    # [Layer 3: Marginal Distributions]
    g.plot_marginals(sns.kdeplot, color=MAIN_COLOR, fill=True, alpha=0.2, linewidth=1.5)

    # 4. Decorate: square limits shared by both axes so y = x is the diagonal
    ax = g.ax_joint
    lims = _square_limits(y_true, y_pred)
    ax.plot(lims, lims, ls="--", c=".4", lw=1.2, zorder=2)
    ax.set_xlim(lims)
    ax.set_ylim(lims)
    ax.set_xlabel(xlabel, fontsize=40, fontweight="bold")
    ax.set_ylabel(ylabel, fontsize=40, fontweight="bold")
    ax.tick_params(labelsize=34)

    # 5. Statistics annotation in the lower-right corner
    stats_text = f"$R^2 = {r2:.3f}$\n$PCC = {pcc:.3f}$\n$N = {N}$"
    ax.text(
        0.96,
        0.04,
        stats_text,
        transform=ax.transAxes,
        ha="right",
        va="bottom",
        fontsize=40,
        fontweight="medium",
        color="#2C3E50",
    )

    # Hide tick labels on the marginal axes
    g.ax_marg_x.tick_params(labelbottom=False)
    g.ax_marg_y.tick_params(labelleft=False)

    # Save, then close this grid's figure explicitly (not just the "current" one)
    g.savefig(save_path, dpi=600, bbox_inches="tight", facecolor="white")
    print(f"[Plot] Saved to: {save_path}")
    plt.close(g.fig)


def _cache_key():
    """Return a string identifying the inputs that inference results depend on."""
    return "|".join(
        (
            CONFIG.best_ckpt_path,
            CONFIG.val_csv,
            CONFIG.test_csv,
            CONFIG.target_scaler_path,
        )
    )


def _load_cached_predictions(cache_file, cache_key):
    """Return ``(y_true_val, y_pred_val, y_true_test, y_pred_test)`` from the cache, or ``None``.

    ``None`` means the cache is missing, unreadable, or was not produced from the
    inputs described by ``cache_key`` (caches written before the key was stored
    are treated as not matching).
    """
    if not os.path.exists(cache_file):
        return None
    print(f"[Cache] Found cache: {cache_file}")
    try:
        # Context manager closes the underlying zip file handle.
        with np.load(cache_file) as data:
            stored_key = str(data["cache_key"]) if "cache_key" in data.files else None
            if stored_key != cache_key:
                print("[Cache] Cache does not match current checkpoint/data config; re-running inference.")
                return None
            arrays = tuple(
                data[name]
                for name in ("y_true_val", "y_pred_val", "y_true_test", "y_pred_test")
            )
    except Exception as e:  # best effort: any unreadable cache just triggers re-inference
        print(f"[Cache] Error: {e}")
        return None
    print("[Cache] Data loaded successfully!")
    return arrays


def _run_inference(device):
    """Build the data module, load scaler and model, and predict on the val and test splits."""
    print("[Inference] Starting pipeline...")
    dm = TRIMDataModule(
        train_csv=CONFIG.train_csv_for_init,
        val_csv=CONFIG.val_csv,
        test_csv=CONFIG.test_csv,
        env_preproc_path=CONFIG.env_preproc_path,
        target_scaler_path=CONFIG.target_scaler_path,
        batch_size=CONFIG.batch_size,
        num_workers=CONFIG.num_workers,
        max_utr5=CONFIG.max_utr5,
        max_cds_utr3=CONFIG.max_cds_utr3,
    )
    # NOTE(review): only the "fit" stage is set up, yet test_dataloader() is used
    # below; this relies on TRIMDataModule preparing the test split here — confirm.
    dm.setup(stage="fit")
    # joblib.load unpickles: only point this at a trusted, locally produced file.
    scaler = joblib.load(CONFIG.target_scaler_path)
    model = load_model_automatic(CONFIG.best_ckpt_path)

    y_true_val, y_pred_val = predict_and_inverse(
        model, dm.val_dataloader(), scaler, device
    )
    y_true_test, y_pred_test = predict_and_inverse(
        model, dm.test_dataloader(), scaler, device
    )
    return y_true_val, y_pred_val, y_true_test, y_pred_test


def main():
    """Predict (or reload cached predictions) for val/test and save one scatter plot per split.

    Results are cached in ``<output_dir>/inference_cache.npz`` together with a
    key built from the checkpoint, data and scaler paths; the cache is reused
    only when that key matches and ``CONFIG.force_inference`` is False.
    """
    os.makedirs(CONFIG.output_dir, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # ================= Caching Logic =================
    cache_file = os.path.join(CONFIG.output_dir, "inference_cache.npz")
    cache_key = _cache_key()

    results = None
    if not CONFIG.force_inference:
        results = _load_cached_predictions(cache_file, cache_key)

    if results is None:
        results = _run_inference(device)
        y_true_val, y_pred_val, y_true_test, y_pred_test = results
        np.savez(
            cache_file,
            y_true_val=y_true_val,
            y_pred_val=y_pred_val,
            y_true_test=y_true_test,
            y_pred_test=y_pred_test,
            cache_key=np.array(cache_key),  # 0-d unicode array; loads without pickle
        )

    y_true_val, y_pred_val, y_true_test, y_pred_test = results

    # ================= Plotting =================
    for split_name, y_true, y_pred, file_name in (
        ("Validation", y_true_val, y_pred_val, "validate_val.png"),
        ("Test", y_true_test, y_pred_test, "validate_test.png"),
    ):
        print(f"\n[Plotting] Generating {split_name} Plot...")
        plot_paper_style_scatter(
            y_true,
            y_pred,
            os.path.join(CONFIG.output_dir, file_name),
            CONFIG.xlabel,
            CONFIG.ylabel,
        )

    print("\nDone.")


if __name__ == "__main__":
    main()

Using device: cuda
[Cache] Found cache: outputs/Validate_and_Plot/inference_cache.npz
[Cache] Data loaded successfully!

[Plotting] Generating Validation Plot...
  -> Rendering Heatmap layer (Nature Teal style)...
  -> Rendering Scatter layer...
[Plot] Saved to: outputs/Validate_and_Plot/validate_val.png

[Plotting] Generating Test Plot...
  -> Rendering Heatmap layer (Nature Teal style)...
  -> Rendering Scatter layer...
[Plot] Saved to: outputs/Validate_and_Plot/validate_test.png

Done.
