In [None]:
import numpy as np
import pandas as pd
import lightgbm as lgb
import time 
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from typing import Optional, Union, List, Dict
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error
import math

###############################################################################
# 1) Noise schedules: exponential sigma(t) for the VE-SDE, linear beta(t) for the VP-SDE
###############################################################################
def exponential_schedule_get_value(t, min_val, max_val):
    """Geometric interpolation from ``min_val`` (t = 0) to ``max_val`` (t = 1).

    Computes ``min_val * (max_val / min_val) ** t``; ``t`` may be a scalar or a
    NumPy array, in which case the result is elementwise.
    """
    growth_ratio = max_val / min_val
    return min_val * growth_ratio ** t

def exponential_schedule_get_derivative(t, min_val, max_val):
    """Time derivative of the exponential schedule.

    With sigma(t) = min_val * (max_val / min_val) ** t, returns
    sigma(t) * log(max_val / min_val).
    """
    sigma = min_val * (max_val / min_val) ** t
    log_ratio = np.log(max_val / min_val)
    return sigma * log_ratio

def linear_schedule_get_value(t, min_val, max_val):
    """Linear interpolation from ``min_val`` (t = 0) to ``max_val`` (t = 1)."""
    span = max_val - min_val
    return min_val + span * t


###############################################################################
# 2) Forward SDE definitions: VP-SDE (linear beta) and VE-SDE (exponential sigma)
###############################################################################
def create_vpsde(hyperparam_min=0.01, hyperparam_max=20.0):
    """Build a variance-preserving SDE with a linear beta schedule.

    beta(t) = hyperparam_min + (hyperparam_max - hyperparam_min) * t and the
    forward SDE is dy = -0.5 * beta(t) * y dt + sqrt(beta(t)) dW.

    Returns a dict with keys "name", "hyperparam_min", "hyperparam_max",
    "drift_and_diffusion", "sample_prior" and "get_mean_std".
    """
    beta_span = hyperparam_max - hyperparam_min

    def drift_and_diffusion_vpsde(y, t):
        beta = hyperparam_min + beta_span * t
        return -0.5 * beta * y, np.sqrt(beta)

    def sample_prior_vpsde(shape, seed=None):
        # Standard normal prior.
        return np.random.default_rng(seed).normal(0.0, 1.0, size=shape)

    def get_mean_std_vpsde(y0, t):
        # Closed-form integral of beta over [0, t].
        integrated_beta = hyperparam_min * t + 0.5 * beta_span * (t ** 2)
        mean = y0 * np.exp(-0.5 * integrated_beta)
        std = np.broadcast_to(np.sqrt(1 - np.exp(-integrated_beta)), y0.shape)
        return mean, std

    return dict(
        name="vpsde",
        hyperparam_min=hyperparam_min,
        hyperparam_max=hyperparam_max,
        drift_and_diffusion=drift_and_diffusion_vpsde,
        sample_prior=sample_prior_vpsde,
        get_mean_std=get_mean_std_vpsde,
    )


def create_vesde(hyperparam_min=0.01, hyperparam_max=20.0):
    """Build a variance-exploding SDE with an exponential sigma schedule.

    sigma(t) = hyperparam_min * (hyperparam_max / hyperparam_min) ** t and the
    forward SDE has zero drift and diffusion sqrt(2 * sigma(t) * sigma'(t)).

    Returns a dict with keys "name", "hyperparam_min", "hyperparam_max",
    "drift_and_diffusion", "sample_prior", "get_mean_std" and "init_from_data".
    """

    def sigma_at(u):
        return hyperparam_min * (hyperparam_max / hyperparam_min) ** u

    def drift_and_diffusion_vesde(y, t):
        sigma = sigma_at(t)
        sigma_rate = sigma * np.log(hyperparam_max / hyperparam_min)
        return np.zeros_like(y), np.sqrt(2.0 * sigma * sigma_rate)

    def sample_prior_vesde(shape, seed=None):
        # Prior is N(0, hyperparam_max**2).
        return np.random.default_rng(seed).normal(0.0, hyperparam_max, size=shape)

    def get_mean_std_vesde(y0, t):
        # Mean is unchanged; variance grows by sigma(t)^2 - sigma(0)^2.
        variance_gap = sigma_at(t) ** 2 - sigma_at(0.0) ** 2
        return y0, np.broadcast_to(np.sqrt(variance_gap), y0.shape)

    def initialize_from_data(y0):
        # Returns (0.01, largest spread in the data): the range for 1-column
        # data, otherwise the largest pairwise Euclidean distance between rows.
        if y0.shape[1] == 1:
            spread = y0.max() - y0.min()
        else:
            pairwise = y0[:, np.newaxis, :] - y0
            spread = np.max(np.sqrt(np.sum(pairwise ** 2, axis=-1)))
        return 0.01, spread

    return dict(
        name="vesde",
        hyperparam_min=hyperparam_min,
        hyperparam_max=hyperparam_max,
        drift_and_diffusion=drift_and_diffusion_vesde,
        sample_prior=sample_prior_vesde,
        get_mean_std=get_mean_std_vesde,
        init_from_data=initialize_from_data,
    )

###############################################################################
# 3) VESDE Reverse: Euler-Maruyama
###############################################################################
def euler_maruyama_update(y, drift, diffusion, dt, rng):
    """One Euler–Maruyama step: ``y + drift*dt + diffusion*sqrt(dt)*z``.

    ``z`` is a standard-normal draw from ``rng`` with the same shape as ``y``.
    """
    increment = rng.normal(size=y.shape)
    deterministic = y + drift * dt
    return deterministic + diffusion * np.sqrt(dt) * increment

def euler_maruyama_integration(sde_dict, y0, t0, t1, n_steps, rng):
    """Integrate the forward SDE from ``t0`` to ``t1`` with Euler–Maruyama.

    Uses ``n_steps`` uniform steps. ``sde_dict["drift_and_diffusion"]`` is
    called as ``(y, t_arr)`` with ``t_arr`` of shape ``(n_rows, 1)``.
    Returns the final state.
    """
    coefficients = sde_dict["drift_and_diffusion"]
    step = (t1 - t0) / n_steps
    state = y0.copy()
    current_t = t0
    for _ in range(n_steps):
        drift, diffusion = coefficients(state, np.full((state.shape[0], 1), current_t))
        noise = rng.normal(size=state.shape)
        state = state + drift * step + diffusion * np.sqrt(step) * noise
        current_t += step
    return state

def euler_maruyama_reverse_integration(sde_dict, y0, original_t0, rev_total, n_steps, rng, score_fn):
    """Integrate the reverse-time SDE with Euler–Maruyama.

    Starts at forward time ``original_t0`` and moves backward by ``rev_total``
    in ``n_steps`` uniform steps. The reverse drift is
    ``-f(y, t) + g(y, t)**2 * score_fn(y, t)``.
    """
    coefficients = sde_dict["drift_and_diffusion"]
    step = rev_total / n_steps
    state = y0.copy()
    elapsed = 0.0
    for _ in range(n_steps):
        t_now = np.full((state.shape[0], 1), original_t0 - elapsed)
        fwd_drift, diffusion = coefficients(state, t_now)
        rev_drift = -fwd_drift + (diffusion ** 2) * score_fn(state, t_now)
        noise = rng.normal(size=state.shape)
        state = state + rev_drift * step + diffusion * np.sqrt(step) * noise
        elapsed += step
    return state

def _pndm_step_vpsde(x_t, noise_t, t_now, t_next, hyperparam_min, hyperparam_max):
    def compute_alpha_bar_lin(u):
        a = hyperparam_min
        b = hyperparam_max - hyperparam_min
        val = a*u + 0.5*b*(u**2)
        return np.exp(-val)
    
    ab_t = compute_alpha_bar_lin(t_now)
    ab_s = compute_alpha_bar_lin(t_next)

    def s_sqrt(x):
        return np.sqrt(np.maximum(x, 1e-40))
    
    sqrt_ab_t = s_sqrt(ab_t)
    sqrt_ab_s = s_sqrt(ab_s)
    diff_ab   = ab_s - ab_t

    denom1 = sqrt_ab_t * (sqrt_ab_s + sqrt_ab_t)
    denom2 = sqrt_ab_t * (s_sqrt((1.0-ab_s)*ab_t) + s_sqrt((1.0-ab_t)*ab_s))

    bracket = (x_t / denom1) - (noise_t / denom2)
    return x_t + diff_ab * bracket

###############################################################################
# 4) VESDE Reverse: DPM Solver
###############################################################################
def vesde_reverse_finalDPM(
    sde_dict, 
    y0, 
    original_t0, 
    rev_total, 
    n_steps, 
    rng, 
    noise_fn
):
    """First-order DPM-style reverse sampler for the exponential VE schedule.

    Steps the forward time down from ``original_t0`` by ``rev_total`` in
    ``n_steps`` uniform steps. Each step does
        y <- y + eps * sigma_t * (exp(h) - 1),  h = log(sigma_next / sigma_t + 1e-40)
    with ``eps = noise_fn(y, t_arr)`` and ``t_arr`` an (n, 1) array of the
    current forward time. ``rng`` is accepted but not used.
    """
    sig_min = sde_dict["hyperparam_min"]
    sig_max = sde_dict["hyperparam_max"]

    def sigma_of(u):
        return sig_min * (sig_max / sig_min) ** u

    step = rev_total / n_steps
    state = y0.copy()
    elapsed = 0.0

    for _ in range(n_steps):
        t_now = original_t0 - elapsed
        sigma_now = sigma_of(t_now)
        sigma_after = sigma_of(t_now - step)
        log_ratio = np.log(sigma_after / sigma_now + 1e-40)
        eps = noise_fn(state, np.full((state.shape[0], 1), t_now))
        state = state + eps * sigma_now * (np.exp(log_ratio) - 1)
        elapsed += step

    return state

###############################################################################
# A) Helper to choose sigma(t) from either "karras_linear" or "exponential"
###############################################################################
def get_sigma_value(
    t_value: float,
    schedule: str,
    hyperparam_min: float,
    hyperparam_max: float
) -> float:
    """Map a time value to a noise level sigma.

    ``"karras_linear"`` uses sigma = t. Any other schedule name uses the
    exponential schedule
    ``hyperparam_min * (hyperparam_max / hyperparam_min) ** t_value``.
    """
    if schedule != "karras_linear":
        return hyperparam_min * (hyperparam_max / hyperparam_min) ** t_value
    return t_value

###############################################################################
# B) Modified DPM sampler that includes Karras-churn
###############################################################################
def vesde_reverse_finalDPM_stoch(
    sde_dict,
    y0,
    original_t0,
    rev_total,
    n_steps,
    rng,
    noise_fn,
    S_churn=10.0,
    S_min=0.05,
    S_max=50.0,
    S_noise=1.0,
    schedule="karras_linear"
):
    """First-order DPM reverse sampler with Karras-style stochastic churn.

    The time grid runs from ``original_t0`` to ``original_t0 - rev_total`` in
    ``n_steps`` uniform steps. At each step:
      1. sigma_i = get_sigma_value(t_i, schedule, ...). If S_min <= sigma_i <= S_max,
         it is inflated to sigma_hat = sigma_i * (1 + min(S_churn/n_steps, sqrt(2)-1)).
         Gaussian noise of std S_noise * sqrt(sigma_hat^2 - sigma_i^2) is then added.
      2. Euler step in sigma: x <- x + eps * (sigma_{i+1} - sigma_hat), where
         eps = noise_fn(x, sigma_arr) and sigma_arr is (n, 1) filled with sigma_hat.

    NOTE(review): noise_fn receives sigma values here, not times. With
    schedule="exponential" the two differ — confirm the model is conditioned on sigma.
    """
    hp_min, hp_max = sde_dict["hyperparam_min"], sde_dict["hyperparam_max"]
    times = np.linspace(original_t0, original_t0 - rev_total, n_steps + 1)
    x = y0.copy()

    for t_cur, t_nxt in zip(times[:-1], times[1:]):
        sigma_cur = get_sigma_value(t_cur, schedule, hp_min, hp_max)
        in_churn_band = S_min <= sigma_cur <= S_max
        gamma = min(S_churn / n_steps, np.sqrt(2.0) - 1.0) if in_churn_band else 0.0
        sigma_hat = sigma_cur * (1.0 + gamma)

        if sigma_hat > sigma_cur:
            # Add just enough noise to move from sigma_cur to sigma_hat.
            extra_var = max(sigma_hat ** 2 - sigma_cur ** 2, 0.0)
            x = x + np.sqrt(extra_var) * rng.normal(scale=S_noise, size=x.shape)

        sigma_nxt = get_sigma_value(t_nxt, schedule, hp_min, hp_max)
        eps = noise_fn(x, np.full((x.shape[0], 1), sigma_hat))
        x = x + eps * (sigma_nxt - sigma_hat)

    return x

def inverse_lambda_mapping(lambda_value, sigma_min, sigma_max):
    """Return the time t at which the exponential schedule equals 1 / lambda_value.

    Solves sigma_min * (sigma_max / sigma_min) ** t = 1 / lambda_value, i.e.
    t = -log(lambda_value * sigma_min) / log(sigma_max / sigma_min).
    """
    numerator = np.log(lambda_value * sigma_min)
    denominator = np.log(sigma_max / sigma_min)
    return -numerator / denominator

import numpy as np

def vesde_reverse_finalDPM_solver2_stoch(
    sde_dict,
    y0,
    original_t0,
    rev_total,
    n_steps,
    rng,
    noise_fn,
    S_churn=10.0,
    S_min=0.05,
    S_max=50.0,
    S_noise=1.0,
    schedule="karras_linear"
):
    """Second-order (explicit midpoint in sigma) reverse sampler with Karras churn.

    The time grid runs from ``original_t0`` to ``original_t0 - rev_total`` in
    ``n_steps`` uniform steps. Each step:
      1. sigma_i = get_sigma_value(t_i, ...). If S_min <= sigma_i <= S_max, it is
         inflated to sigma_hat = sigma_i * (1 + min(S_churn/n_steps, sqrt(2)-1)).
         Noise of std S_noise * sqrt(sigma_hat^2 - sigma_i^2) is then added.
      2. Midpoint rule for dx/dsigma = eps(x, sigma), with d = sigma_{i+1} - sigma_hat:
           x_mid = x + (d/2) * eps(x, sigma_hat)
           x     = x + d * eps(x_mid, sigma_hat + d/2)

    ``noise_fn`` is called as ``(x, sigma_arr)``, where ``sigma_arr`` is an
    (n, 1) array filled with a sigma value.

    The midpoint evaluation is conditioned on sigma_hat + d/2, the noise level
    ``x_mid`` actually has. Evaluating at sigma(t_mid) instead would be
    inconsistent whenever churn is active or the schedule is non-linear.
    """
    hp_min = sde_dict["hyperparam_min"]
    hp_max = sde_dict["hyperparam_max"]

    t_array = np.linspace(original_t0, original_t0 - rev_total, n_steps + 1)
    x = y0.copy()

    for i in range(n_steps):
        t_i = t_array[i]
        t_ip1 = t_array[i + 1]

        sigma_i = get_sigma_value(t_i, schedule, hp_min, hp_max)
        if S_min <= sigma_i <= S_max:
            gamma_i = min(S_churn / n_steps, np.sqrt(2.0) - 1.0)
        else:
            gamma_i = 0.0
        sigma_hat_i = sigma_i * (1.0 + gamma_i)

        if sigma_hat_i > sigma_i:
            # Churn: add noise to raise the noise level from sigma_i to sigma_hat_i.
            delta_sig_sq = sigma_hat_i**2 - sigma_i**2
            eps_i = rng.normal(scale=S_noise, size=x.shape)
            x = x + np.sqrt(delta_sig_sq) * eps_i

        sigma_ip1 = get_sigma_value(t_ip1, schedule, hp_min, hp_max)
        delta_sigma = sigma_ip1 - sigma_hat_i

        e_i = noise_fn(x, np.full((x.shape[0], 1), sigma_hat_i))

        # Half Euler step; x_mid now sits at noise level sigma_hat_i + delta/2.
        sigma_mid = sigma_hat_i + 0.5 * delta_sigma
        x_mid = x + 0.5 * delta_sigma * e_i

        e_mid = noise_fn(x_mid, np.full((x.shape[0], 1), sigma_mid))
        x = x + delta_sigma * e_mid

    return x

def vesde_reverse_finalDPM_solver3_stoch(
    sde_dict,
    y0,
    original_t0=1.0,
    rev_total=1.0,
    n_steps=20,
    rng=None,
    noise_fn=None,
    S_churn=10.0,
    S_min=0.05,
    S_max=50.0,
    S_noise=1.0,
    schedule="karras_linear"
):
    """Third-order (Heun RK3 in sigma) reverse sampler with Karras churn.

    The time grid runs from ``original_t0`` to ``original_t0 - rev_total`` in
    ``n_steps`` uniform steps. Each step:
      1. sigma_i = get_sigma_value(t_i, ...). If S_min <= sigma_i <= S_max, it is
         inflated to sigma_hat = sigma_i * (1 + min(S_churn/n_steps, sqrt(2)-1)).
         Noise of std S_noise * sqrt(sigma_hat^2 - sigma_i^2) is then added.
      2. Heun's third-order Runge–Kutta for dx/dsigma = eps(x, sigma), with
         d = sigma_{i+1} - sigma_hat:
           k0 = eps(x, sigma_hat)
           k1 = eps(x + d/3 * k0, sigma_hat + d/3)
           k2 = eps(x + 2d/3 * k1, sigma_hat + 2d/3)
           x  = x + d * (k0/4 + 3*k2/4)

    ``noise_fn`` is called as ``(x, sigma_arr)``, where ``sigma_arr`` is an
    (n, 1) array filled with a sigma value.

    If ``rng`` is None, a fresh unseeded ``np.random.default_rng()`` is used for
    the churn noise.
    """
    if rng is None:
        rng = np.random.default_rng()

    hp_min = sde_dict["hyperparam_min"]
    hp_max = sde_dict["hyperparam_max"]

    t_array = np.linspace(original_t0, original_t0 - rev_total, n_steps + 1)
    x = y0.copy()

    r1 = 1.0 / 3.0
    r2 = 2.0 / 3.0

    for i in range(n_steps):
        t_i = t_array[i]
        t_ip1 = t_array[i + 1]

        sigma_i = get_sigma_value(t_i, schedule, hp_min, hp_max)
        if S_min <= sigma_i <= S_max:
            gamma_i = min(S_churn / n_steps, np.sqrt(2.0) - 1.0)
        else:
            gamma_i = 0.0
        sigma_hat_i = sigma_i * (1.0 + gamma_i)

        if sigma_hat_i > sigma_i:
            # Churn: add noise to raise the noise level from sigma_i to sigma_hat_i.
            delta_sig_sq = sigma_hat_i**2 - sigma_i**2
            eps_i = rng.normal(scale=S_noise, size=x.shape)
            x = x + np.sqrt(delta_sig_sq) * eps_i

        sigma_ip1 = get_sigma_value(t_ip1, schedule, hp_min, hp_max)
        delta_sigma = sigma_ip1 - sigma_hat_i

        e0 = noise_fn(x, np.full((x.shape[0], 1), sigma_hat_i))

        sigma_1 = sigma_hat_i + r1 * delta_sigma
        x1 = x + r1 * delta_sigma * e0
        e1 = noise_fn(x1, np.full((x1.shape[0], 1), sigma_1))

        sigma_2 = sigma_hat_i + r2 * delta_sigma
        x2 = x + r2 * delta_sigma * e1
        e2 = noise_fn(x2, np.full((x2.shape[0], 1), sigma_2))

        # Heun RK3 weights (b = [1/4, 0, 3/4]). Using e2 alone is only first order.
        x = x + delta_sigma * (0.25 * e0 + 0.75 * e2)

    return x


###############################################################################
# 5) VESDE Reverse: Exponential Integration with Polynomial Noise
###############################################################################
def _trapezoid_integration(func, a, b, n=32):
    flipped = False
    if a > b:
        a, b = b, a
        flipped = True

    x = np.linspace(a, b, n+1)
    y = np.array([func(xi) for xi in x])
    h = (b - a)/n
    integral = 0.5 * h * (y[0] + 2.0*np.sum(y[1:-1]) + y[-1])
    if flipped:
        integral = -integral
    return integral

def _lagrange_basis_poly(tau, node_times, j):
    num = 1.0
    den = 1.0
    t_j = node_times[j]
    for m, t_m in enumerate(node_times):
        if m == j:
            continue
        num *= (tau - t_m)
        den *= (t_j - t_m)
    return num / (den + 1e-40)

def _precompute_coeffs_expint_poly(t_points, sde_dict, poly_order=1, n_substeps=64):
    """Per-interval quadrature weights for the polynomial-noise exponential integrator.

    For each consecutive pair (t_i, t_{i+1}) in ``t_points``, returns the list
        C[i][j] = integral_{t_i}^{t_{i+1}} w(tau) * L_j(tau) dtau,  j = 0..poly_order
    where w(tau) = 0.5 * g(tau)^2 / (sigma(tau) + 1e-40) with g^2 = 2*sigma*sigma'
    under the exponential schedule. L_j is the j-th Lagrange basis polynomial on
    ``poly_order + 1`` equally spaced nodes spanning that interval.

    The integral is signed: t_{i+1} < t_i integrates in the reversed orientation.
    It is approximated by an ``n_substeps``-panel trapezoidal rule.
    """

    def weight(tau):
        sig = exponential_schedule_get_value(
            tau, sde_dict["hyperparam_min"], sde_dict["hyperparam_max"]
        )
        dsig = exponential_schedule_get_derivative(
            tau, sde_dict["hyperparam_min"], sde_dict["hyperparam_max"]
        )
        g_squared = 2.0 * sig * dsig
        return 0.5 * g_squared * (1.0 / (sig + 1e-40))

    def trapezoid(func, lo, hi, panels):
        grid = np.linspace(lo, hi, panels + 1)
        values = [func(point) for point in grid]
        step = (hi - lo) / panels
        return 0.5 * step * (values[0] + 2.0 * sum(values[1:-1]) + values[-1])

    coefficients = []
    for t_start, t_end in zip(t_points[:-1], t_points[1:]):
        nodes = np.linspace(t_start, t_end, poly_order + 1)
        coefficients.append([
            trapezoid(
                lambda tau, j=j, nodes=nodes: weight(tau) * _lagrange_basis_poly(tau, nodes, j),
                t_start,
                t_end,
                n_substeps,
            )
            for j in range(poly_order + 1)
        ])
    return coefficients

def vesde_reverse_exponential_integration_polynomial_noise(
    sde_dict, 
    y0,          
    original_t0, 
    rev_total,   
    n_steps,  
    noise_fn,    
    poly_order=1,
    seed=None,
    y_scaler=None
):
    rng = np.random.default_rng(seed if seed is not None else 42)
    if y0.shape[1] != 1:
        raise ValueError("Exponential Integrator (poly) is univariate-only in this demo.")

    t_array = np.linspace(original_t0, original_t0 - rev_total, n_steps+1)
    C_all = _precompute_coeffs_expint_poly(t_array, sde_dict, poly_order=poly_order, n_substeps=64)

    y_values = np.empty((n_steps+1, y0.shape[0], 1))
    y_values[0] = y0.copy()

    eps_buffer = {}
    x_init = y_values[0]
    t_init_val = np.full((x_init.shape[0],1), t_array[0])
    eps_buffer[0] = noise_fn(x_init, t_init_val)

    for i in range(n_steps):
        x_i = y_values[i]
        t_i = t_array[i]
        t_ip1 = t_array[i+1]
        coeffs_i = C_all[i]
        x_new = x_i.copy()

        if (i+1) not in eps_buffer:
            t_next_val = np.full((x_i.shape[0],1), t_ip1)
            eps_buffer[i+1] = noise_fn(x_i, t_next_val)  

        if poly_order == 1:
            if (i+1) not in eps_buffer:
                C0 = coeffs_i[0] + coeffs_i[1]
                eps_i_ = eps_buffer[i]
                x_new += C0 * eps_i_
            else:
                C0, C1 = coeffs_i[:2]
                eps_i_ = eps_buffer[i]
                eps_ip1_ = eps_buffer[i+1]
                x_new += (C0 * eps_i_ + C1 * eps_ip1_)

        elif poly_order == 2:
            if i == 0:
                if (i+1) not in eps_buffer:
                    C0 = coeffs_i[0] + coeffs_i[1]
                    eps_i_ = eps_buffer[i]
                    x_new += C0 * eps_i_
                else:
                    C0, C1 = coeffs_i[:2]
                    eps_i_ = eps_buffer[i]
                    eps_ip1_ = eps_buffer[i+1]
                    x_new += (C0 * eps_i_ + C1 * eps_ip1_)
            else:
                if (i+1) not in eps_buffer:
                    C0 = coeffs_i[0] + coeffs_i[1]
                    eps_i_ = eps_buffer[i]
                    x_new += C0 * eps_i_
                else:
                    C0, C1, C2 = coeffs_i[:3]
                    eps_im1_ = eps_buffer[i-1]
                    eps_i_ = eps_buffer[i]
                    eps_ip1_ = eps_buffer[i+1]
                    x_new += (C0 * eps_i_ + C1 * eps_ip1_ + C2 * eps_im1_)

        else:  
            if i < 2:
                if i == 1:
                    if (i+1) not in eps_buffer:
                        C0 = coeffs_i[0] + coeffs_i[1]
                        eps_i_ = eps_buffer[i]
                        x_new += C0 * eps_i_
                    else:
                        if len(coeffs_i) < 3:
                            C0 = coeffs_i[0] + coeffs_i[1]
                            eps_i_ = eps_buffer[i]
                            x_new += C0 * eps_i_
                        else:
                            C0, C1, C2 = coeffs_i[:3]
                            eps_im1_ = eps_buffer[i-1]
                            eps_i_ = eps_buffer[i]
                            eps_ip1_ = eps_buffer[i+1]
                            x_new += (C0*eps_i_ + C1*eps_ip1_ + C2*eps_im1_)
                else:  
                    if (i+1) not in eps_buffer:
                        C0 = coeffs_i[0] + coeffs_i[1]
                        eps_i_ = eps_buffer[i]
                        x_new += C0 * eps_i_
                    else:
                        C0, C1 = coeffs_i[:2]
                        eps_i_ = eps_buffer[i]
                        eps_ip1_ = eps_buffer[i+1]
                        x_new += (C0*eps_i_ + C1*eps_ip1_)
            else:
                if (i+1) not in eps_buffer:
                    if len(coeffs_i) < 3:
                        C0 = coeffs_i[0] + coeffs_i[1]
                        eps_i_ = eps_buffer[i]
                        x_new += C0 * eps_i_
                    else:
                        C0, C1, C2 = coeffs_i[:3]
                        eps_im1_ = eps_buffer[i-1]
                        eps_i_ = eps_buffer[i]
                        x_new += (C0*eps_i_ + C1*eps_i_ + C2*eps_im1_)
                else:
                    if len(coeffs_i) < 4:
                        C0, C1, C2 = coeffs_i[:3]
                        eps_im1_ = eps_buffer[i-1]
                        eps_i_ = eps_buffer[i]
                        eps_ip1_ = eps_buffer[i+1]
                        x_new += (C0*eps_i_ + C1*eps_ip1_ + C2*eps_im1_)
                    else:
                        C0, C1, C2, C3 = coeffs_i[:4]
                        eps_im2_ = eps_buffer[i-2]
                        eps_im1_ = eps_buffer[i-1]
                        eps_i_ = eps_buffer[i]
                        eps_ip1_ = eps_buffer[i+1]
                        x_new += (
                            C0*eps_i_
                            + C1*eps_ip1_
                            + C2*eps_im1_
                            + C3*eps_im2_
                        )

        y_values[i+1] = x_new
        t_next_arr = np.full((x_new.shape[0],1), t_ip1)
        eps_buffer[i+1] = noise_fn(x_new, t_next_arr)

    result = y_values[-1]
    
    if y_scaler is not None:
        result_flat = result.reshape(-1, 1)
        result_original = y_scaler.inverse_transform(result_flat)
        result = result_original.reshape(result.shape)
        
    return result


###############################################################################
# 7) Stochastic sampler (k-diff approach)
###############################################################################
def vesde_reverse_stochsampler_once(
    sde_dict,
    y0,
    n_steps=50,
    noise_fn=None,
    S_churn=10.0,
    S_min=0.05,
    S_max=50.0,
    S_noise=1.0,
    y_scaler=None,
    seed=None,
):
    """Karras-style stochastic sampler (Heun step plus churn), with sigma(t) = t.

    Integrates from t = 1 to t = 0 on a uniform grid of ``n_steps`` steps. The
    denoiser is ``D(x, t) = x - t * noise_fn(x, t)``. At each step:
      - If S_min <= t_i <= S_max, t_i is inflated to t_hat = t_i * (1 + gamma),
        with gamma = min(S_churn/n_steps, sqrt(2)-1). Noise of std
        S_noise * sqrt(t_hat^2 - t_i^2) is then added.
      - An Euler step to t_{i+1} is taken, followed by a Heun (trapezoidal)
        correction unless t_{i+1} is ~0.

    Args:
        sde_dict: Unused; kept so all samplers share a signature.
        y0: Initial state array.
        n_steps: Number of steps.
        noise_fn: Callable ``(x, t_arr) -> eps`` with ``t_arr`` of shape (n, 1).
        S_churn, S_min, S_max, S_noise: Churn parameters as described above.
        y_scaler: Optional object with ``inverse_transform``; if given, the
            result is mapped back through it as a column vector.
        seed: Seed for ``np.random.default_rng``; None means unseeded.

    Raises:
        ValueError: if ``noise_fn`` is None.
    """
    if noise_fn is None:
        raise ValueError("noise_fn must be provided.")

    rng = np.random.default_rng(seed)
    t_array = np.linspace(1.0, 0.0, n_steps + 1)
    x = y0.copy()

    def denoiser_fn(x_, t_):
        pred_noise = noise_fn(x_, t_)
        return x_ - t_ * pred_noise

    for i in range(n_steps):
        t_i = t_array[i]
        t_ip1 = t_array[i + 1]

        if S_min <= t_i <= S_max:
            gamma_i = min(S_churn / n_steps, np.sqrt(2.0) - 1.0)
        else:
            gamma_i = 0.0

        # Noise is drawn every step (scaled by 0 when gamma_i == 0).
        eps_i = rng.normal(size=x.shape, scale=S_noise)
        t_hat_i = t_i * (1.0 + gamma_i)
        sigma_delta = np.sqrt(np.maximum(t_hat_i**2 - t_i**2, 0.0))
        hat_x_i = x + sigma_delta * eps_i

        t_hat_arr = np.full((hat_x_i.shape[0], 1), t_hat_i)
        d_i = (hat_x_i - denoiser_fn(hat_x_i, t_hat_arr)) / (t_hat_i + 1e-40)

        x_next = hat_x_i + (t_ip1 - t_hat_i) * d_i
        if t_ip1 > 1e-14:
            # Heun correction: average the slopes at both ends.
            t_ip1_arr = np.full((hat_x_i.shape[0], 1), t_ip1)
            d_i_prime = (x_next - denoiser_fn(x_next, t_ip1_arr)) / (t_ip1 + 1e-40)
            x_next = hat_x_i + (t_ip1 - t_hat_i) * (0.5 * d_i + 0.5 * d_i_prime)

        x = x_next

    if y_scaler is not None:
        x_flat = x.reshape(-1, 1)
        x = y_scaler.inverse_transform(x_flat).reshape(x.shape)

    return x

###############################################################################
# 8) Unified SDE Integration Interface (modified to remove n_samples)
###############################################################################
def sdeint(
    sde_dict, 
    y0, 
    t0=0.0, 
    t1=1.0, 
    method="euler", 
    n_steps=20, 
    score_fn=None, 
    seed=None,
    y_scaler=None,
    time_budget=10.0,
    S_churn=10.0,
    S_min=0.05,
    S_max=50.0,
    S_noise=1.0,
    schedule="exponential"
):
    """Integrate an SDE forward (``t1 >= t0``) or in reverse time (``t1 < t0``).

    Forward methods: "euler".

    Reverse methods (require ``score_fn``):
      - "euler", "dpm", "pndm", "expint" (poly_order=3), "unipc" (p=2)
      - "stochsampler"
      - "dpm_stoch", "dpm_solver_2_stoch", "dpm_solver_3_stoch"

    The stochastic variants receive S_churn / S_min / S_max / S_noise. The "dpm*"
    variants also receive ``schedule``. For the DPM-family and stochsampler
    methods, ``score_fn`` is passed on as the sampler's ``noise_fn``.

    ``y_scaler`` is forwarded only to "pndm", "expint" and "stochsampler".
    ``time_budget`` is currently unused.

    NOTE(review): "pndm" calls ``pndm_vpsde_integration``, which is not defined
    in the visible part of this file — confirm it exists.

    Raises:
        ValueError: on an unknown method, "expint" in forward time, or a
            missing ``score_fn`` for a reverse pass.
    """
    rng = np.random.default_rng(seed)

    if t1 >= t0:
        # Forward pass
        if method == "euler":
            return euler_maruyama_integration(sde_dict, y0, t0, t1, n_steps, rng)
        if method == "expint":
            raise ValueError("'expint' is only for reverse-time in this demo.")
        raise ValueError(f"Unknown method {method} for forward SDE.")

    # Reverse pass
    if score_fn is None:
        raise ValueError("score_fn required for reverse pass.")
    rev_total = t0 - t1
    churn = {"S_churn": S_churn, "S_min": S_min, "S_max": S_max, "S_noise": S_noise}

    if method == "euler":
        samples = euler_maruyama_reverse_integration(sde_dict, y0, t0, rev_total, n_steps, rng, score_fn)
    elif method == "dpm":
        samples = vesde_reverse_finalDPM(sde_dict, y0, t0, rev_total, n_steps, rng, score_fn)
    elif method == "pndm":
        samples = pndm_vpsde_integration(sde_dict, y0, n_steps, score_fn, rng, y_scaler=y_scaler)
    elif method == "expint":
        samples = vesde_reverse_exponential_integration_polynomial_noise(
            sde_dict, y0, t0, rev_total, n_steps, score_fn, poly_order=3, y_scaler=y_scaler
        )
    elif method == "unipc":
        samples = unipc_vpsde_integration(
            sde_dict=sde_dict,
            y0=y0,
            n_steps=n_steps,
            score_fn=score_fn,
            rng=rng,
            p=2
        )
    elif method == "stochsampler":
        samples = vesde_reverse_stochsampler_once(
            sde_dict=sde_dict,
            y0=y0,
            n_steps=n_steps,
            noise_fn=score_fn,
            y_scaler=y_scaler,
            seed=seed,
            **churn,
        )
    elif method in ("dpm_stoch", "dpm_solver_2_stoch", "dpm_solver_3_stoch"):
        solver = {
            "dpm_stoch": vesde_reverse_finalDPM_stoch,
            "dpm_solver_2_stoch": vesde_reverse_finalDPM_solver2_stoch,
            "dpm_solver_3_stoch": vesde_reverse_finalDPM_solver3_stoch,
        }[method]
        samples = solver(
            sde_dict,
            y0,
            original_t0=t0,
            rev_total=rev_total,
            n_steps=n_steps,
            rng=rng,
            noise_fn=score_fn,
            schedule=schedule,
            **churn,
        )
    else:
        raise ValueError(f"Unknown method {method} for reverse SDE")

    return samples


###############################################################################
# 9) OU Process
###############################################################################
def create_ou_sde():
    """Build the Ornstein–Uhlenbeck SDE dy = -y dt + sqrt(2) dW as a dict of callables.

    Returns a dict with keys "name", "drift_and_diffusion", "sample_prior",
    "get_mean_std" and "init_from_data".
    """
    sqrt_two = np.sqrt(2.0)

    def drift_and_diffusion_ou(y, t):
        return -y, sqrt_two * np.ones_like(y)

    def sample_prior_ou(shape, seed=None):
        # Stationary distribution N(0, 1).
        return np.random.default_rng(seed).normal(0, 1, size=shape)

    def get_mean_std_ou(y0, t):
        # t is flattened and scales y0 row by row; std is broadcast to y0.shape.
        flat_t = t.reshape(-1)
        decay = np.exp(-flat_t)
        mean = (y0.T * decay).T
        spread = np.sqrt(1.0 - np.exp(-2 * flat_t))
        return mean, np.broadcast_to(spread.reshape(-1, 1), y0.shape)

    def init_from_data(y0):
        return 0.0, 1.0

    return dict(
        name="ou",
        drift_and_diffusion=drift_and_diffusion_ou,
        sample_prior=sample_prior_ou,
        get_mean_std=get_mean_std_ou,
        init_from_data=init_from_data,
    )

###############################################################################
# 10) Underdamped Langevin (used as corrector step in OU)
###############################################################################
def underdamped_langevin_step(y, v, dt, gamma, rng, score_fn, t_local):
    """One underdamped-Langevin update of position ``y`` and velocity ``v``.

    The position moves first (y + dt*v). The velocity is then updated with
    friction ``-gamma*v``, the term ``-score_fn(y_new, t_local)``, and Gaussian
    noise of std sqrt(2*gamma*dt).

    NOTE(review): the velocity update subtracts score_fn. If score_fn returns
    grad log p, the textbook update adds it — confirm what score_fn returns here.
    """
    moved = y + dt * v
    force = score_fn(moved, t_local)
    kick = rng.normal(size=v.shape)
    friction_and_force = -gamma * v - force
    v_next = v + dt * friction_and_force + np.sqrt(2.0 * gamma * dt) * kick
    return moved, v_next

###############################################################################
# 11) OU Reverse Solvers
###############################################################################
def ou_reverse_step_exponential(y, dt, rng, score_fn, t_local):
    """Exponential-integrator step: e^dt * y + (e^dt - 1) * score_fn(y, t_local).

    ``rng`` is accepted but not used.
    """
    growth = np.exp(dt)
    return growth * y + (growth - 1.0) * score_fn(y, t_local)

def ou_reverse_step_midpoint(y, dt, rng, alpha, score_fn, t_n, t_n_alpha):
    """Randomized-midpoint exponential step.

    First advances by alpha*dt with the exponential rule to get ``y_half``,
    evaluating score_fn at ``t_n``. Then returns
        e^dt * y + dt * e^((1 - alpha) * dt) * score_fn(y_half, t_n_alpha).
    ``rng`` is accepted but not used.
    """
    partial_growth = np.exp(alpha * dt)
    y_half = partial_growth * y + (partial_growth - 1.0) * score_fn(y, t_n)
    score_mid = score_fn(y_half, t_n_alpha)
    weight = dt * np.exp((1.0 - alpha) * dt)
    return np.exp(dt) * y + weight * score_mid

def ou_reverse_integration_once(
    sde_dict,
    y0,
    original_t0,
    rev_total,
    n_steps,
    rng,
    score_fn,
    gamma,
    corrector_steps,
    dt_corr,
    predictor_method="exponential",
    y_scaler=None 
):
    """Reverse-time OU sampler: predictor step plus optional Langevin corrector.

    Moves the forward time from ``original_t0`` down by ``rev_total`` in
    ``n_steps`` uniform steps. Each step:
      - applies the predictor, either "exponential" or "random_midpoint"
        (the latter draws alpha ~ U(0, 1) from ``rng``);
      - then runs ``corrector_steps`` underdamped-Langevin steps of size
        ``dt_corr`` with friction ``gamma``.
    The Langevin velocity starts at zero and persists across steps.

    Args:
        sde_dict: Unused here; kept for a uniform signature.
        y0: Initial state array.
        original_t0, rev_total, n_steps: Reverse-time grid definition.
        rng: ``np.random.Generator`` used by the predictor and corrector.
        score_fn: Callable ``(y, t_arr)`` with ``t_arr`` of shape (n, 1).
        gamma, corrector_steps, dt_corr: Corrector parameters.
        predictor_method: "exponential" or "random_midpoint".
        y_scaler: Optional object with ``inverse_transform``, applied to the result.

    Raises:
        ValueError: on an unknown ``predictor_method``.
    """
    dt = rev_total / n_steps
    y = y0.copy()
    v = np.zeros_like(y)
    t_rev = 0.0

    for _ in range(n_steps):
        forward_t = original_t0 - t_rev
        t_arr = np.full((y.shape[0], 1), forward_t)

        if predictor_method == "exponential":
            y = ou_reverse_step_exponential(y, dt, rng, score_fn, t_arr)
        elif predictor_method == "random_midpoint":
            alpha = rng.uniform()
            # The intermediate state is alpha*dt further along in *reverse* time,
            # i.e. at forward time forward_t - alpha*dt.
            t_n_alpha = np.full((y.shape[0], 1), forward_t - alpha * dt)
            y = ou_reverse_step_midpoint(y, dt, rng, alpha, score_fn, t_arr, t_n_alpha)
        else:
            raise ValueError(f"Unknown predictor method '{predictor_method}'")

        t_rev += dt
        # NOTE(review): the corrector is conditioned on t_arr (the time before the
        # predictor step), whereas PC samplers usually use the post-step time.
        # Confirm intent.
        for _corr in range(corrector_steps):
            y, v = underdamped_langevin_step(y, v, dt_corr, gamma, rng, score_fn, t_arr)

    if y_scaler is not None:
        y_flat = y.reshape(-1, 1)
        y_original = y_scaler.inverse_transform(y_flat)
        y = y_original.reshape(y.shape)

    return y



def sdeint_ou_reverse(
    sde_dict,
    y0,
    t0=1.0,
    t1=0.0,
    n_steps=20,
    predictor_method="exponential",
    gamma=0.1,
    corrector_steps=1,
    dt_corr=0.001,
    score_fn=None,
    seed=None,
    y_scaler=None
):
    """Reverse-time OU sampling from ``t0`` down to ``t1``.

    A thin wrapper around ``ou_reverse_integration_once`` with a seeded
    ``np.random.default_rng(seed)``. ``y_scaler`` is forwarded so the result
    is returned in original units when one is supplied; previously it was
    accepted but silently ignored.

    Raises:
        ValueError: if ``score_fn`` is None.
    """
    if score_fn is None:
        raise ValueError("score_fn required for reverse OU pass.")
    rng = np.random.default_rng(seed)
    rev_total = t0 - t1
    return ou_reverse_integration_once(
        sde_dict,
        y0,
        t0,
        rev_total,
        n_steps,
        rng,
        score_fn,
        gamma,
        corrector_steps,
        dt_corr,
        predictor_method,
        y_scaler=y_scaler,
    )

###############################################################################
# 12) Data Scaling and Sample Classes
###############################################################################
class ScalerMixedTypes:
    """Standard-scale numeric columns and integer-encode categorical ones.

    Every column not listed in ``categorical_columns`` is treated as numeric
    and passed through an sklearn ``StandardScaler``. Each categorical column
    maps its values to their position in ``Series.unique()`` order at fit time.
    Values unseen at fit time become -1 on ``transform``; codes with no known
    category become 'unknown' on ``inverse_transform``.
    """

    def __init__(self, categorical_columns=None):
        self._scaler = StandardScaler()
        self._cat_cols = [] if categorical_columns is None else categorical_columns
        self._label_encoders = {}  # not populated by this class
        self._category_maps = {}   # column -> {category value: integer code}
        self._is_fitted = False

    def _require_fitted(self):
        if not self._is_fitted:
            raise ValueError("Scaler not fitted.")

    def fit(self, X):
        """Learn scaling statistics and category codes from DataFrame ``X``; returns self."""
        if self._cat_cols is None:
            self._cat_cols = []
        categorical = self._cat_cols
        self._num_cols = [name for name in X.columns if name not in categorical]

        if self._num_cols:
            self._scaler.fit(X[self._num_cols])

        for name in categorical:
            seen = X[name].unique()
            self._category_maps[name] = dict(zip(seen, range(len(seen))))

        self._is_fitted = True
        return self

    def transform(self, X):
        """Return a copy of ``X`` with numeric columns scaled and categories encoded."""
        self._require_fitted()
        result = X.copy()
        if self._num_cols:
            result[self._num_cols] = self._scaler.transform(X[self._num_cols])
        for name in self._cat_cols:
            result[name] = X[name].map(
                lambda value, name=name: self._category_maps[name].get(value, -1)
            )
        return result

    def fit_transform(self, X):
        """Fit on ``X`` and return its transformed copy."""
        return self.fit(X).transform(X)

    def inverse_transform(self, X):
        """Undo ``transform``: unscale numeric columns and decode category codes."""
        self._require_fitted()
        result = X.copy()
        if self._num_cols:
            result[self._num_cols] = self._scaler.inverse_transform(X[self._num_cols])

        for name in self._cat_cols:
            decode = {code: value for value, code in self._category_maps[name].items()}
            result[name] = X[name].map(
                lambda code, decode=decode: decode.get(code, 'unknown')
            )
        return result

class Samples:
    """Thin wrapper around a 2-D or 3-D NumPy array of samples."""

    def __init__(self, arr):
        if not 2 <= arr.ndim <= 3:
            raise ValueError("Samples must have 2 or 3 dims.")
        self._samples = arr

    @property
    def shape(self):
        """Shape of the wrapped array."""
        return self._samples.shape

    def to_numpy(self):
        """Return the wrapped array itself (not a copy)."""
        return self._samples


###############################################################################
# 13) Multi-step UniPC for VP-SDE (one sample version)
###############################################################################
def _compute_alpha_bar_lin(u, hyper_min, hyper_max):
    a = hyper_min
    b = hyper_max - hyper_min
    val = a*u + 0.5*b*(u**2)
    return np.exp(-val)

def unipc_vpsde_integration(
    sde_dict,
    y0,
    n_steps,
    score_fn,
    rng=None,
    p=2,
    y_scaler=None
):
    """Integrate a VP-SDE sample from t=1 to t=0 with UniPC (noise-prediction form).

    Each step applies a first-order (DDIM / DPM-Solver-1) predictor. From the
    second step on, when ``p >= 2``, it also applies the order-2 UniC corrector
    with B(h) = expm1(h). The schedule is the linear-beta one used by
    ``create_vpsde``:
    alpha_bar(t) = exp(-(beta_min*t + 0.5*(beta_max - beta_min)*t**2)),
    alpha = sqrt(alpha_bar), sigma = sqrt(1 - alpha_bar).

    The step size is measured in half log-SNR, lambda = log(alpha / sigma). This
    quantity increases towards t=0, so h > 0 and the predictor removes the
    predicted noise. (The previous version used lambda = -log(alpha_bar), whose
    step sign is reversed, and its corrector coefficients used phi_n instead of
    phi_{n+1}.)

    Args:
        sde_dict: dict providing "hyperparam_min" and "hyperparam_max".
        y0: (n, 1) starting state at t=1.
        n_steps: number of uniform steps in t; must be >= 1.
        score_fn: callable ``(x, t) -> eps`` returning the *noise* prediction for
            ``x``. ``t`` is passed as an (n, 1) array filled with the current time.
        rng: unused; kept for interface compatibility.
        p: solver order. ``p >= 2`` enables the UniC corrector; ``p == 1`` is
            plain DDIM.
        y_scaler: unused; kept for interface compatibility.

    Returns:
        (n, 1) array holding the state at t=0.

    Raises:
        ValueError: if ``n_steps < 1``.
    """
    if n_steps < 1:
        raise ValueError("n_steps must be >= 1.")

    beta_min = sde_dict["hyperparam_min"]
    beta_span = sde_dict["hyperparam_max"] - beta_min

    def alpha_sigma(t):
        # Same integral as _compute_alpha_bar_lin / the VP-SDE marginals.
        alpha_bar = np.exp(-(beta_min * t + 0.5 * beta_span * t ** 2))
        return np.sqrt(alpha_bar), np.sqrt(max(1.0 - alpha_bar, 0.0))

    x = np.array(y0, dtype=np.float64, copy=True)
    n_rows = x.shape[0]

    def time_column(t):
        return np.full((n_rows, 1), t)

    times = np.arange(n_steps, -1, -1) / n_steps  # 1.0, ..., 0.0
    alpha_prev, sigma_prev = alpha_sigma(times[0])
    lam_prev = np.log(alpha_prev) - np.log(sigma_prev)
    eps_prev = score_fn(x, time_column(times[0]))
    older = None  # (lambda, eps) one step further back, used by the corrector

    for t_cur in times[1:]:
        alpha_cur, sigma_cur = alpha_sigma(t_cur)
        ratio = alpha_cur / alpha_prev

        if sigma_cur == 0.0:
            # alpha_bar == 1 (t=0), so lambda -> +inf. Use the closed-form limit
            # sigma_cur * expm1(h) -> ratio * sigma_prev. The corrector term
            # vanishes in this limit, and no further model evaluation is needed.
            x = ratio * (x - sigma_prev * eps_prev)
            break

        lam_cur = np.log(alpha_cur) - np.log(sigma_cur)
        h = lam_cur - lam_prev  # > 0 when stepping towards t=0
        b_h = np.expm1(h)

        # First-order (DDIM) predictor.
        x_pred = ratio * x - sigma_cur * b_h * eps_prev

        if p >= 2 and older is not None:
            # Order-2 UniC corrector. As in the reference UniPC, the model output
            # at the predicted point is reused as the next step's history entry.
            eps_pred = score_fn(x_pred, time_column(t_cur))
            lam_older, eps_older = older
            r_1 = (lam_older - lam_prev) / h
            d_1 = (eps_older - eps_prev) / r_1
            d_t = eps_pred - eps_prev
            h_phi_2 = b_h / h - 1.0        # h * phi_2(h)
            h_phi_3 = h_phi_2 / h - 0.5    # h * phi_3(h)
            rhos = np.linalg.solve(
                np.array([[1.0, 1.0], [r_1, 1.0]]),
                np.array([h_phi_2, 2.0 * h_phi_3]) / b_h,
            )
            x_new = x_pred - sigma_cur * b_h * (rhos[0] * d_1 + rhos[1] * d_t)
            eps_new = eps_pred
        else:
            x_new = x_pred
            eps_new = score_fn(x_new, time_column(t_cur))

        older = (lam_prev, eps_prev)
        x, eps_prev = x_new, eps_new
        alpha_prev, sigma_prev, lam_prev = alpha_cur, sigma_cur, lam_cur

    return x.reshape(-1, 1)

def compute_phi_p(p, z):
    """Return the array [z**n * n! * varphi_n(z) for n = 1..p].

    The varphi functions follow the recursion varphi_0(z) = exp(z) and
    varphi_{k+1}(z) = (varphi_k(z) - 1/k!) / z. Terms up to varphi_{p+1}
    are generated, although only the first p are used.
    """
    phis = [np.exp(z)]
    for k in range(p + 1):
        phis.append((phis[-1] - 1.0 / math.factorial(k)) / z)
    return np.array([z ** n * math.factorial(n) * phis[n] for n in range(1, p + 1)])

###############################################################################
# 14) LightGBM-based Score/Noise Model
###############################################################################
def make_training_data(X, y, sde_dict, n_repeats=10, eval_percent=0.1, seed=None):
    """Build (features, labels) for denoising training, plus an optional validation set.

    Each example pairs the target ``y``, perturbed to a random time t ~ U(0, 1)
    via ``sde_dict["get_mean_std"]``, with its features. Feature rows are
    ``[perturbed_y, X, t]``; labels are ``-z``, where z is the standard-normal
    draw used for the perturbation. The training split is repeated
    ``n_repeats`` times, so each row is seen at several noise levels.

    Args:
        X, y: 2-D features and targets.
        sde_dict: dict providing ``"get_mean_std"(y, t) -> (mean, std)``.
        n_repeats: number of times the training split is tiled.
        eval_percent: held-out fraction; 0 disables the validation set.
        seed: seeds both the split and the noise/time draws.

    Returns:
        ``(feats_tr, labels_tr, feats_val, labels_val)``; the last two are
        None when ``eval_percent <= 0``.
    """
    if eval_percent > 0:
        X_fit, X_hold, y_fit, y_hold = train_test_split(X, y, test_size=eval_percent, random_state=seed)
    else:
        X_fit, y_fit, X_hold, y_hold = X, y, None, None

    rng = np.random.default_rng(seed)
    get_mean_std = sde_dict["get_mean_std"]

    def _perturb(X_part, y_part):
        # Draw order (times, then noise) is fixed so seeded runs are reproducible.
        t = rng.uniform(0, 1, size=(X_part.shape[0], 1))
        z = rng.normal(size=y_part.shape)
        mean, std = get_mean_std(y_part, t)
        return np.hstack([mean + std * z, X_part, t]), -z

    feats_tr, labels_tr = _perturb(
        np.tile(X_fit, (n_repeats, 1)),
        np.tile(y_fit, (n_repeats, 1)),
    )

    feats_val = labels_val = None
    if X_hold is not None and y_hold is not None:
        feats_val, labels_val = _perturb(X_hold, y_hold)
    return feats_tr, labels_tr, feats_val, labels_val

def train_lightgbm_model(X, y, X_val=None, y_val=None, seed=None, **lgbm_args):
    """Fit a silent ``LGBMRegressor`` and return it.

    When both ``X_val`` and ``y_val`` are given they are passed as
    ``eval_set``; extra keyword arguments go to the regressor constructor.
    """
    model = lgb.LGBMRegressor(random_state=seed, verbose=-1, **lgbm_args)
    fit_kwargs = {}
    if X_val is not None and y_val is not None:
        fit_kwargs["eval_set"] = [(X_val, y_val)]
    model.fit(X, y, **fit_kwargs)
    return model

def predict_score_lightgbm(models, sde_dict, y, X, t):
    """Score estimates from per-output noise models: ``model.predict(feats) / std``.

    Each model was trained (see ``make_training_data``) to predict ``-z`` from
    ``[y, X, t]`` features. Dividing by the marginal std from
    ``sde_dict["get_mean_std"]`` (floored at 1e-6) gives the returned score.

    Args:
        models: one fitted regressor per output dimension (needs ``.predict``).
        sde_dict: dict providing ``"get_mean_std"(y, t) -> (mean, std)``.
        y: perturbed targets, shape (n,) or (n, y_dim).
        X: conditioning features (array or DataFrame). If it has a different
            number of rows than ``y``, it is tiled and truncated to n rows.
        t: time(s). Accepts a scalar, (n,), (n, 1) or (1, 1); broadcast to (n, 1).

    Returns:
        (n, len(models)) array of score estimates.
    """
    y = np.asarray(y)
    if y.ndim == 1:
        y = y.reshape(-1, 1)
    n_rows = y.shape[0]

    X = np.asarray(X)
    if X.shape[0] != n_rows:
        # Ceil-divide, then trim. The old floor division left X short whenever n
        # was not a multiple of X's rows, and np.hstack then failed.
        reps = -(-n_rows // X.shape[0])
        X = np.tile(X, (reps, 1))[:n_rows]

    t = np.asarray(t, dtype=float)
    if t.ndim == 0:
        t = t.reshape(1, 1)
    elif t.ndim == 1:
        t = t.reshape(-1, 1)
    if t.shape[0] != n_rows:
        t = np.broadcast_to(t, (n_rows, t.shape[1]))

    feats = np.hstack([y, X, t])
    _, std = sde_dict["get_mean_std"](y, t)
    std = np.maximum(std, 1e-6)
    return np.column_stack(
        [mdl.predict(feats) / std[:, i] for i, mdl in enumerate(models)]
    )

def predict_noise_lightgbm_batched(models, sde_dict, y, X, t, batch_size=1000):
    """Predict the perturbation noise for ``y`` in row batches.

    Models are trained on ``-z`` (see ``make_training_data``), so their
    predictions are negated to return noise estimates.

    Args:
        models: one fitted regressor per output dimension (needs ``.predict``).
        sde_dict: unused; kept for interface parity with ``predict_score_lightgbm``.
        y: (n,), (n, d), or higher-rank array (flattened to (-1, last_dim)).
        X: features (array or DataFrame). If it has a different number of rows
            than ``y``, it is tiled and truncated to match.
        t: time(s). Accepts a scalar, (n,), (n, k) or (1, k); broadcast to n rows.
            Scalars previously crashed on ``t.ndim`` / ``t.shape[0]``.
        batch_size: rows per ``predict`` call; must be positive.

    Returns:
        (n, len(models)) array of noise predictions.

    Raises:
        ValueError: if ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer.")

    y = np.asarray(y)
    if y.ndim > 2:
        y = y.reshape(-1, y.shape[-1])
    elif y.ndim == 1:
        y = y.reshape(-1, 1)
    n_samples = y.shape[0]

    X = np.asarray(X)  # also converts DataFrames
    if X.shape[0] != n_samples:
        reps = -(-n_samples // X.shape[0])  # ceil division
        X = np.tile(X, (reps, 1))[:n_samples]

    t = np.asarray(t, dtype=float)
    if t.ndim == 0:
        t = t.reshape(1, 1)
    elif t.ndim == 1:
        t = t.reshape(-1, 1)
    if t.shape[0] != n_samples:
        t = np.broadcast_to(t, (n_samples, t.shape[1]))

    noises = np.zeros((n_samples, len(models)))
    for start in range(0, n_samples, batch_size):
        stop = min(start + batch_size, n_samples)
        feats = np.hstack([y[start:stop], X[start:stop], t[start:stop]])
        for j, mdl in enumerate(models):
            noises[start:stop, j] = -mdl.predict(feats)
    return noises

def fit_treeffuser(X, y, sde_dict, n_repeats=10, eval_percent=0.1, seed=None, **lgbm_args):
    """Train one LightGBM noise model per column of ``y``.

    Training data come from ``make_training_data``. Each model regresses
    the corresponding column of the ``-z`` labels and, when a validation split
    exists, gets that split as its eval set.

    Returns:
        List of fitted models, ordered by output dimension.
    """
    feats_tr, labels_tr, feats_val, labels_val = make_training_data(
        X, y, sde_dict, n_repeats, eval_percent, seed
    )
    has_val = labels_val is not None
    return [
        train_lightgbm_model(
            feats_tr,
            labels_tr[:, dim],
            feats_val,
            labels_val[:, dim] if has_val else None,
            seed=seed,
            **lgbm_args,
        )
        for dim in range(y.shape[1])
    ]


###############################################################################
# 15) Time-budgeted sampling functions (each draws samples until a wall-clock budget expires)
###############################################################################

def sample_treeffuser(
    X,
    sde_dict,
    models,
    x_scaler,
    y_scaler,
    time_budget=10.0, 
    n_steps=50,
    seed=None
):
    """Draw conditional samples with ``sdeint(method="euler")`` until the budget expires.

    Each iteration produces one sample per row of ``X``. A prior draw is
    integrated from t=1 to t=0 using score estimates from
    ``predict_score_lightgbm`` and then mapped back through ``y_scaler``.

    Args:
        X: conditioning features, one row per target.
        sde_dict: SDE dict; ``"sample_prior"`` is called here, and the dict is
            also passed to ``sdeint``.
        models: per-output models used by ``predict_score_lightgbm``.
        x_scaler: unused; kept so all samplers share one signature.
        y_scaler: optional object with ``inverse_transform``; skipped when None.
        time_budget: wall-clock seconds. The budget is checked before each
            sample, so the last sample may finish after it.
        n_steps: integration steps per trajectory.
        seed: seeds the sampler RNG. Both the prior draw and the ``sdeint`` seed
            are derived from it. Previously the prior was drawn unseeded, so
            ``seed`` did not make runs reproducible.

    Returns:
        Array of shape (n_samples, X.shape[0], 1); the first dimension is 0 when
        the budget expires before the first sample.
    """
    start_t = time.time()
    rng = np.random.default_rng(seed)
    n_rows = X.shape[0]
    all_samples = []

    def score_fn(y_local, t_local):
        return predict_score_lightgbm(models, sde_dict, y_local, X, t_local)

    while time.time() - start_t < time_budget:
        y0 = sde_dict["sample_prior"]((1, n_rows, 1), seed=rng.integers(1e9))
        single_sample = sdeint(
            sde_dict,
            y0.reshape(n_rows, 1),
            t0=1.0,
            t1=0.0,
            method="euler",
            n_steps=n_steps,
            score_fn=score_fn,
            seed=rng.integers(1e9),
        )
        single_sample_3d = single_sample.reshape((1, n_rows, 1))
        if y_scaler is not None:
            single_sample_3d[0] = y_scaler.inverse_transform(single_sample_3d[0])
        all_samples.append(single_sample_3d)

    if not all_samples:
        return np.empty((0, n_rows, 1))
    return np.concatenate(all_samples, axis=0)

def sample_treeffuser_unipc(
    X,
    sde_dict,
    models,
    x_scaler,
    y_scaler,
    time_budget=10.0,
    n_steps=20,
    seed=None,
    p=2
):
    """Draw conditional samples with ``unipc_vpsde_integration`` until the budget expires.

    Args:
        X: conditioning features, one row per target.
        sde_dict: VP-SDE dict; ``"sample_prior"`` is called here.
        models: per-output models used by ``predict_noise_lightgbm_batched``.
        x_scaler: unused; kept so all samplers share one signature.
        y_scaler: optional object with ``inverse_transform``; skipped when None.
        time_budget: wall-clock seconds, checked before each sample.
        n_steps: UniPC steps per trajectory.
        seed: seeds the sampler RNG. The prior draw is now derived from it;
            previously it was unseeded.
        p: UniPC order, forwarded to the integrator.

    Returns:
        Array of shape (n_samples, X.shape[0], 1), possibly with 0 samples.
    """
    start_t = time.time()
    rng = np.random.default_rng(seed)
    n_rows = X.shape[0]
    all_samples = []

    def score_fn(y_local, t_local):
        # Despite the name, this returns *noise* predictions
        # (predict_noise_lightgbm_batched), which the UniPC integrator consumes.
        return predict_noise_lightgbm_batched(models, sde_dict, y_local, X, t_local)

    while time.time() - start_t < time_budget:
        y0 = sde_dict["sample_prior"]((1, n_rows, 1), seed=rng.integers(1e9))
        single_sample = unipc_vpsde_integration(
            sde_dict,
            y0.reshape(n_rows, 1),
            n_steps=n_steps,
            score_fn=score_fn,
            rng=rng,
            p=p,
        )
        single_sample_3d = single_sample.reshape(1, n_rows, 1)
        if y_scaler is not None:
            single_sample_3d[0] = y_scaler.inverse_transform(single_sample_3d[0])
        all_samples.append(single_sample_3d)

    if not all_samples:
        return np.empty((0, n_rows, 1))
    return np.concatenate(all_samples, axis=0)


def sample_treeffuser_dpm(
    X,
    sde_dict,
    models,
    x_scaler,
    y_scaler,
    time_budget=10.0,
    n_steps=20,
    seed=None
):
    """Draw conditional samples with ``sdeint(method="dpm")`` until the budget expires.

    Args:
        X: conditioning features, one row per target.
        sde_dict: SDE dict; ``"sample_prior"`` is called here, and the dict is
            also passed to ``sdeint``.
        models: per-output models used by ``predict_noise_lightgbm_batched``.
        x_scaler: unused; kept so all samplers share one signature.
        y_scaler: optional object with ``inverse_transform``; skipped when None.
        time_budget: wall-clock seconds, checked before each sample.
        n_steps: solver steps per trajectory.
        seed: seeds the sampler RNG. The prior draw and the ``sdeint`` seed are
            both derived from it; previously the prior was unseeded.

    Returns:
        Array of shape (n_samples, X.shape[0], 1), possibly with 0 samples.
    """
    start_t = time.time()
    rng = np.random.default_rng(seed)
    n_rows = X.shape[0]
    all_samples = []

    def score_fn(y_local, t_local):
        # Returns noise predictions; the name matches sdeint's keyword.
        return predict_noise_lightgbm_batched(models, sde_dict, y_local, X, t_local)

    while time.time() - start_t < time_budget:
        y0 = sde_dict["sample_prior"]((1, n_rows, 1), seed=rng.integers(1e9))
        single_sample = sdeint(
            sde_dict,
            y0.reshape(n_rows, 1),
            t0=1.0,
            t1=0.0,
            method="dpm",
            n_steps=n_steps,
            score_fn=score_fn,
            seed=rng.integers(1e9),
        )
        single_sample_3d = single_sample.reshape((1, n_rows, 1))
        if y_scaler is not None:
            single_sample_3d[0] = y_scaler.inverse_transform(single_sample_3d[0])
        all_samples.append(single_sample_3d)

    if not all_samples:
        return np.empty((0, n_rows, 1))
    return np.concatenate(all_samples, axis=0)

def sample_treeffuser_dpm_stoch(
    X,
    sde_dict,
    models,
    x_scaler,
    y_scaler,
    time_budget=10.0,
    n_steps=20,
    seed=None,
    S_churn=10.0,
    S_min=0.05,
    S_max=50.0,
    S_noise=1.0,
    schedule="exponential"
):
    """Draw samples with ``sdeint(method="dpm_stoch")`` until the time budget expires.

    For each sample, the prior seed and then the ``sdeint`` seed are drawn from
    one RNG seeded by ``seed``. The S_* churn parameters and ``schedule`` are
    forwarded unchanged. ``x_scaler`` is unused. Returns an array of shape
    (n_samples, X.shape[0], 1), which may hold 0 samples.
    """
    start_t = time.time()
    rng = np.random.default_rng(seed)
    n_rows = X.shape[0]
    churn_kwargs = dict(
        S_churn=S_churn,
        S_min=S_min,
        S_max=S_max,
        S_noise=S_noise,
        schedule=schedule,
    )
    collected = []

    def score_fn(y_local, t_local):
        return predict_noise_lightgbm_batched(models, sde_dict, y_local, X, t_local)

    while True:
        if time.time() - start_t >= time_budget:
            break
        prior = sde_dict["sample_prior"]((1, n_rows, 1), seed=rng.integers(1e9))
        draw = sdeint(
            sde_dict,
            prior.reshape(n_rows, 1),
            t0=1.0,
            t1=0.0,
            method="dpm_stoch",
            n_steps=n_steps,
            score_fn=score_fn,
            seed=rng.integers(1e9),
            **churn_kwargs,
        )
        draw = draw.reshape((1, n_rows, 1))
        if y_scaler:
            draw[0] = y_scaler.inverse_transform(draw[0])
        collected.append(draw)

    return np.concatenate(collected, axis=0) if collected else np.empty((0, n_rows, 1))

def sample_treeffuser_dpm_solver2_stoch(
    X,
    sde_dict,
    models,
    x_scaler,
    y_scaler,
    time_budget=10.0,
    n_steps=5,
    seed=None,
    S_churn=10.0, S_min=0.05, S_max=50.0, S_noise=1.0,
    schedule = 'exponential'
):
    """Draw samples with ``sdeint(method="dpm_solver_2_stoch")`` until the budget expires.

    The prior seed and then the ``sdeint`` seed are taken from an RNG seeded by
    ``seed`` for every sample. Churn parameters and ``schedule`` are forwarded
    as-is, and ``x_scaler`` is unused. Returns an array of shape
    (n_samples, X.shape[0], 1), possibly with 0 samples.
    """
    start_t = time.time()
    rng = np.random.default_rng(seed)
    n_rows = X.shape[0]
    collected = []

    def score_fn(y_local, sigma_local):
        return predict_noise_lightgbm_batched(models, sde_dict, y_local, X, sigma_local)

    while True:
        if time.time() - start_t >= time_budget:
            break

        prior = sde_dict["sample_prior"]((1, n_rows, 1), seed=rng.integers(1e9))
        draw = sdeint(
            sde_dict,
            prior.reshape(-1, 1),
            t0=1.0,
            t1=0.0,
            method="dpm_solver_2_stoch",
            n_steps=n_steps,
            score_fn=score_fn,
            seed=rng.integers(1e9),
            S_churn=S_churn,
            S_min=S_min,
            S_max=S_max,
            S_noise=S_noise,
            schedule=schedule,
        )
        draw = draw.reshape((1, n_rows, 1))
        if y_scaler:
            draw[0] = y_scaler.inverse_transform(draw[0])
        collected.append(draw)

    if not collected:
        return np.empty((0, n_rows, 1))
    return np.concatenate(collected, axis=0)


def sample_treeffuser_dpm_solver3_stoch(
    X,
    sde_dict,
    models,
    x_scaler,
    y_scaler,
    time_budget=10.0,
    n_steps=5,
    seed=None,
    S_churn=10.0, 
    S_min=0.05, 
    S_max=50.0, 
    S_noise=1.0,
    schedule="karras_linear"
):
    """Draw samples with ``vesde_reverse_finalDPM_solver3_stoch`` until the budget expires.

    Each prior is seeded from an RNG built from ``seed``. That same RNG object
    is passed to the solver. ``x_scaler`` is unused. Returns an array of shape
    (n_samples, X.shape[0], 1), possibly with 0 samples.
    """
    start_t = time.time()
    rng = np.random.default_rng(seed)
    n_rows = X.shape[0]
    solver_kwargs = {
        "original_t0": 1.0,
        "rev_total": 1.0,
        "n_steps": n_steps,
        "S_churn": S_churn,
        "S_min": S_min,
        "S_max": S_max,
        "S_noise": S_noise,
        "schedule": schedule,
    }
    collected = []

    def noise_fn(y_local, sigma_local):
        return predict_noise_lightgbm_batched(models, sde_dict, y_local, X, sigma_local)

    while True:
        if time.time() - start_t >= time_budget:
            break

        prior = sde_dict["sample_prior"]((1, n_rows, 1), seed=rng.integers(1e9))
        draw = vesde_reverse_finalDPM_solver3_stoch(
            sde_dict,
            prior.reshape(-1, 1),
            rng=rng,
            noise_fn=noise_fn,
            **solver_kwargs,
        )
        draw = draw.reshape((1, n_rows, 1))
        if y_scaler:
            draw[0] = y_scaler.inverse_transform(draw[0])
        collected.append(draw)

    return np.concatenate(collected, axis=0) if collected else np.empty((0, n_rows, 1))


def sample_treeffuser_expint(
    X,
    sde_dict,
    models,
    x_scaler,
    y_scaler,
    time_budget=10.0,
    n_steps=20,
    seed=None
):
    """Draw samples with the polynomial-noise exponential integrator until the budget expires.

    Each trajectory is produced by
    ``vesde_reverse_exponential_integration_polynomial_noise`` with
    ``poly_order=3``.

    Args:
        X: conditioning features, one row per target.
        sde_dict: SDE dict; ``"sample_prior"`` is called here.
        models: per-output models used by ``predict_noise_lightgbm_batched``.
        x_scaler: unused; kept so all samplers share one signature.
        y_scaler: optional object with ``inverse_transform``; skipped when None.
        time_budget: wall-clock seconds, checked before each sample.
        n_steps: integrator steps per trajectory.
        seed: seeds the sampler RNG. The prior and integrator seeds are both
            derived from it; previously the prior was unseeded.

    Returns:
        Array of shape (n_samples, X.shape[0], 1), possibly with 0 samples.
    """
    start_t = time.time()
    rng = np.random.default_rng(seed)
    n_rows = X.shape[0]
    all_samples = []

    def noise_fn(y_local, t_local):
        return predict_noise_lightgbm_batched(models, sde_dict, y_local, X, t_local)

    while time.time() - start_t < time_budget:
        y0 = sde_dict["sample_prior"]((1, n_rows, 1), seed=rng.integers(1e9))
        single_sample = vesde_reverse_exponential_integration_polynomial_noise(
            sde_dict,
            y0.reshape(n_rows, 1),
            original_t0=1.0,
            rev_total=1.0,
            n_steps=n_steps,
            noise_fn=noise_fn,
            poly_order=3,
            seed=rng.integers(1e9),
        )
        single_sample_3d = single_sample.reshape((1, n_rows, 1))
        if y_scaler is not None:
            single_sample_3d[0] = y_scaler.inverse_transform(single_sample_3d[0])
        all_samples.append(single_sample_3d)

    if not all_samples:
        return np.empty((0, n_rows, 1))
    return np.concatenate(all_samples, axis=0)


def sample_treeffuser_stochsampler(
    X,
    sde_dict,
    models,
    x_scaler,
    y_scaler,
    time_budget=10.0,
    n_steps=20,
    seed=None,
    S_churn=10.0,
    S_min=0.05,
    S_max=50.0,
    S_noise=1.0
):
    """Draw samples with ``vesde_reverse_stochsampler_once`` until the budget expires.

    Args:
        X: conditioning features, one row per target.
        sde_dict: SDE dict; ``"sample_prior"`` is called here.
        models: per-output models used by ``predict_noise_lightgbm_batched``.
        x_scaler: unused; kept so all samplers share one signature.
        y_scaler: optional object with ``inverse_transform``; skipped when None.
        time_budget: wall-clock seconds, checked before each sample.
        n_steps: sampler steps per trajectory.
        seed: seeds the prior draws. Previously the RNG was created but never
            used, so runs were not reproducible.
        S_churn, S_min, S_max, S_noise: forwarded to the sampler.

    Returns:
        Array of shape (n_samples, X.shape[0], 1), possibly with 0 samples.
    """
    start_t = time.time()
    rng = np.random.default_rng(seed)
    n_rows = X.shape[0]
    all_samples = []

    def noise_fn(y_local, t_local):
        return predict_noise_lightgbm_batched(models, sde_dict, y_local, X, t_local)

    while time.time() - start_t < time_budget:
        y0 = sde_dict["sample_prior"]((1, n_rows, 1), seed=rng.integers(1e9))
        # NOTE(review): unlike the other samplers, this passes the 3-D (1, n, 1)
        # prior rather than the flattened (n, 1) one. Confirm that
        # vesde_reverse_stochsampler_once expects this shape.
        single_sample = vesde_reverse_stochsampler_once(
            sde_dict,
            y0,
            n_steps=n_steps,
            noise_fn=noise_fn,
            S_churn=S_churn,
            S_min=S_min,
            S_max=S_max,
            S_noise=S_noise,
        )
        single_sample_3d = single_sample.reshape((1, n_rows, 1))
        if y_scaler is not None:
            single_sample_3d[0] = y_scaler.inverse_transform(single_sample_3d[0])
        all_samples.append(single_sample_3d)

    if not all_samples:
        return np.empty((0, n_rows, 1))
    return np.concatenate(all_samples, axis=0)

import time

def pndm_vpsde_integration(
    sde_dict,
    y0,
    n_steps,
    noise_fn,  
    rng,
    rk4_warmup=4
):
    """Integrate a VP-SDE sample from t=1 to t=0 with PNDM.

    The first ``min(rk4_warmup, n_steps)`` unit steps use the pseudo
    Runge-Kutta (PRK) scheme. The remaining steps use pseudo linear multistep
    (PLMS): an Adams-Bashforth combination of the most recent raw noise
    predictions, at 4th order once four are available and lower orders before
    that. (Previously, history was looked up at t-1, t-2, t-3, which are never
    populated, so the multistep phase silently degraded to first order.)

    Time runs over integers n_steps..0 internally. It is divided by
    ``n_steps`` before reaching ``noise_fn`` and ``_pndm_step_vpsde``.
    ``_pndm_step_vpsde`` is not visible here; presumably it is the PNDM
    transfer step (x, eps, t_from, t_to, beta_min, beta_max) -> x. Confirm.

    Args:
        sde_dict: dict providing "hyperparam_min" and "hyperparam_max".
        y0: initial state; reshaped to (n, 1).
        n_steps: number of unit steps.
        noise_fn: callable ``(x, t_scalar) -> eps`` noise prediction.
        rng: unused; kept for interface compatibility.
        rk4_warmup: number of PRK warm-up steps.

    Returns:
        The (n, 1) state after the last step.
    """
    x = y0.reshape(-1, 1).copy()
    hyper_min = sde_dict["hyperparam_min"]
    hyper_max = sde_dict["hyperparam_max"]

    dt = 1.0
    t_now = float(n_steps)
    # Raw noise predictions eps(x_t, t), one per completed step, most recent last.
    history = []

    def do_step(x_in, noise_in, t_in, t_out):
        return _pndm_step_vpsde(
            x_in, noise_in, t_in / n_steps, t_out / n_steps, hyper_min, hyper_max
        )

    def plms_combination(hist):
        # Adams-Bashforth weights of increasing order (as in PNDM's PLMS warm start).
        if len(hist) == 1:
            return hist[-1]
        if len(hist) == 2:
            return (3.0 * hist[-1] - hist[-2]) / 2.0
        if len(hist) == 3:
            return (23.0 * hist[-1] - 16.0 * hist[-2] + 5.0 * hist[-3]) / 12.0
        return (55.0 * hist[-1] - 59.0 * hist[-2] + 37.0 * hist[-3] - 9.0 * hist[-4]) / 24.0

    n_warmup = max(0, min(rk4_warmup, n_steps))

    for _ in range(n_warmup):
        t_half = t_now - 0.5 * dt
        t_next = t_now - dt

        e1 = noise_fn(x, t_now / n_steps)
        x1 = do_step(x, e1, t_now, t_half)
        e2 = noise_fn(x1, t_half / n_steps)
        x2 = do_step(x, e2, t_now, t_half)
        e3 = noise_fn(x2, t_half / n_steps)
        x3 = do_step(x, e3, t_now, t_next)
        e4 = noise_fn(x3, t_next / n_steps)

        history.append(e1)
        x = do_step(x, (e1 + 2.0 * e2 + 2.0 * e3 + e4) / 6.0, t_now, t_next)
        t_now = t_next

    for _ in range(n_steps - n_warmup):
        e_t = noise_fn(x, t_now / n_steps)
        history.append(e_t)
        del history[:-4]  # only the four most recent predictions are ever used
        x = do_step(x, plms_combination(history), t_now, t_now - dt)
        t_now -= dt

    return x


def sample_treeffuser_pndm(
    X,
    sde_dict,
    models,
    x_scaler,
    y_scaler,
    time_budget=10.0,
    n_steps=20,
    seed=None,
    rk4_warmup=4
):
    """Draw samples with ``pndm_vpsde_integration`` until the time budget expires.

    Args:
        X: conditioning features, one row per target.
        sde_dict: VP-SDE dict; ``"sample_prior"`` is called here.
        models: per-output models used by ``predict_noise_lightgbm_batched``.
        x_scaler: unused; kept so all samplers share one signature.
        y_scaler: optional object with ``inverse_transform``; skipped when None.
        time_budget: wall-clock seconds, checked before each sample.
        n_steps: PNDM steps per trajectory.
        seed: seeds the sampler RNG. The prior draw is now derived from it;
            previously it was unseeded.
        rk4_warmup: PRK warm-up steps, forwarded to the integrator.

    Returns:
        Array of shape (n_samples, X.shape[0], 1), possibly with 0 samples.
    """
    start_t = time.time()
    rng = np.random.default_rng(seed)
    n_rows = X.shape[0]
    all_samples = []

    def noise_fn(y_local, t_local):
        # The integrator passes a scalar time; lift it to a 2-D column.
        if np.isscalar(t_local):
            t_local = np.array([[t_local]])
        elif t_local.ndim == 1:
            t_local = t_local.reshape(-1, 1)
        return predict_noise_lightgbm_batched(models, sde_dict, y_local, X, t_local)

    while time.time() - start_t < time_budget:
        y0 = sde_dict["sample_prior"]((1, n_rows, 1), seed=rng.integers(1e9))
        single_sample = pndm_vpsde_integration(
            sde_dict,
            y0.reshape(-1, 1),
            n_steps=n_steps,
            noise_fn=noise_fn,
            rng=rng,
            rk4_warmup=rk4_warmup,
        )
        single_sample_3d = single_sample.reshape((1, n_rows, 1))
        if y_scaler is not None:
            single_sample_3d[0] = y_scaler.inverse_transform(single_sample_3d[0])
        all_samples.append(single_sample_3d)

    if not all_samples:
        return np.empty((0, n_rows, 1))
    return np.concatenate(all_samples, axis=0)



def sample_treeffuser_ou_exponential(
    X,
    sde_dict,
    models,
    x_scaler,
    y_scaler,
    time_budget=10.0,
    n_steps=20,
    gamma=0.1,
    corrector_steps=1,
    dt_corr=0.001,
    seed=None
):
    """Draw samples with ``ou_reverse_integration_once`` (exponential predictor) until the budget expires.

    Args:
        X: conditioning features, one row per target.
        sde_dict: OU SDE dict; ``"sample_prior"`` is called here.
        models: per-output models used by ``predict_score_lightgbm``.
        x_scaler: unused; kept so all samplers share one signature.
        y_scaler: optional object with ``inverse_transform``; skipped when None.
        time_budget: wall-clock seconds, checked before each sample.
        n_steps: predictor steps per trajectory.
        gamma, corrector_steps, dt_corr: forwarded to the integrator.
        seed: seeds the sampler RNG. The prior draw is now derived from it;
            previously it was unseeded.

    Returns:
        Array of shape (n_samples, X.shape[0], 1), possibly with 0 samples.
    """
    start_t = time.time()
    rng = np.random.default_rng(seed)
    n_rows = X.shape[0]
    all_samples = []

    def score_fn(y_local, t_local):
        return predict_score_lightgbm(models, sde_dict, y_local, X, t_local)

    while time.time() - start_t < time_budget:
        y0 = sde_dict["sample_prior"]((1, n_rows, 1), seed=rng.integers(1e9))
        single_sample = ou_reverse_integration_once(
            sde_dict,
            y0.reshape(-1, 1),
            original_t0=1.0,
            rev_total=1.0,
            n_steps=n_steps,
            rng=rng,
            score_fn=score_fn,
            gamma=gamma,
            corrector_steps=corrector_steps,
            dt_corr=dt_corr,
            predictor_method="exponential",
        )
        single_sample_3d = single_sample.reshape((1, n_rows, 1))
        if y_scaler is not None:
            single_sample_3d[0] = y_scaler.inverse_transform(single_sample_3d[0])
        all_samples.append(single_sample_3d)

    if not all_samples:
        return np.empty((0, n_rows, 1))
    return np.concatenate(all_samples, axis=0)

def sample_treeffuser_ou_midpoint(
    X,
    sde_dict,
    models,
    x_scaler,
    y_scaler,
    time_budget=10.0,
    n_steps=20,
    gamma=0.1,
    corrector_steps=1,
    dt_corr=0.001,
    seed=None
):
    """Draw samples with ``ou_reverse_integration_once`` (random-midpoint predictor) until the budget expires.

    Args:
        X: conditioning features, one row per target.
        sde_dict: OU SDE dict; ``"sample_prior"`` is called here.
        models: per-output models used by ``predict_score_lightgbm``.
        x_scaler: unused; kept so all samplers share one signature.
        y_scaler: optional object with ``inverse_transform``; skipped when None.
        time_budget: wall-clock seconds, checked before each sample.
        n_steps: predictor steps per trajectory.
        gamma, corrector_steps, dt_corr: forwarded to the integrator.
        seed: seeds the sampler RNG. The prior draw is now derived from it;
            previously it was unseeded.

    Returns:
        Array of shape (n_samples, X.shape[0], 1), possibly with 0 samples.
    """
    start_t = time.time()
    rng = np.random.default_rng(seed)
    n_rows = X.shape[0]
    all_samples = []

    def score_fn(y_local, t_local):
        return predict_score_lightgbm(models, sde_dict, y_local, X, t_local)

    while time.time() - start_t < time_budget:
        y0 = sde_dict["sample_prior"]((1, n_rows, 1), seed=rng.integers(1e9))
        single_sample = ou_reverse_integration_once(
            sde_dict,
            y0.reshape(-1, 1),
            original_t0=1.0,
            rev_total=1.0,
            n_steps=n_steps,
            rng=rng,
            score_fn=score_fn,
            gamma=gamma,
            corrector_steps=corrector_steps,
            dt_corr=dt_corr,
            predictor_method="random_midpoint",
        )
        single_sample_3d = single_sample.reshape((1, n_rows, 1))
        if y_scaler is not None:
            single_sample_3d[0] = y_scaler.inverse_transform(single_sample_3d[0])
        all_samples.append(single_sample_3d)

    if not all_samples:
        return np.empty((0, n_rows, 1))
    return np.concatenate(all_samples, axis=0)

In [None]:
import numpy as np
import pandas as pd
import time
import matplotlib.pyplot as plt

################################################################################
# 1) Import the data and basic transformations
################################################################################
# Read the three input tables: calendar, wide daily sales, and sell prices.
calendar_df = pd.read_csv("calendar.csv")
sales_train_df = pd.read_csv("sales_train_validation.csv")
sell_prices_df = pd.read_csv("sell_prices.csv")

# Parse dates once, then expose day / month / year as separate calendar columns.
calendar_df["date"] = pd.to_datetime(calendar_df["date"])
calendar_df = calendar_df.assign(
    day=calendar_df["date"].dt.day,
    month=calendar_df["date"].dt.month,
    year=calendar_df["date"].dt.year,
)

def convert_sales_data_from_wide_to_long(sales_df_wide, max_rows=3000):
    """Reshape wide daily sales (one ``d_<n>`` column per day) into long format.

    The first column of ``sales_df_wide`` is dropped, and only the first
    ``max_rows`` rows are kept.

    Args:
        sales_df_wide: DataFrame whose columns after the first are the five
            index variables (item_id, dept_id, cat_id, store_id, state_id)
            plus ``d_<n>`` day columns.
        max_rows: number of leading rows to keep; ``None`` keeps all of them.
            Defaults to 3000, the subset size that used to be hard-coded.

    Returns:
        Long DataFrame with the index variables, ``d`` (strings "d_<n>") and
        ``sales``.
    """
    index_vars = ["item_id", "dept_id", "cat_id", "store_id", "state_id"]
    sales_df_long = pd.wide_to_long(
        sales_df_wide.iloc[:max_rows, 1:],
        i=index_vars,
        j="day",
        stubnames=["d"],
        sep="_",
    ).reset_index()
    # wide_to_long puts the values under the stub name ("d") and the suffix under j.
    sales_df_long = sales_df_long.rename(columns={"d": "sales", "day": "d"})
    sales_df_long["d"] = "d_" + sales_df_long["d"].astype("str")
    return sales_df_long

sales_train_df_long = convert_sales_data_from_wide_to_long(sales_train_df)
print("Rows after converting to long format:", sales_train_df_long.shape[0])

# Raw-string regex: a plain "\d" is an invalid escape sequence and warns on modern
# Python. expand=False returns a Series instead of a one-column DataFrame.
sales_train_df_long["day_number"] = (
    sales_train_df_long["d"].str.extract(r"(\d+)", expand=False).astype(int)
)
data = sales_train_df_long[sales_train_df_long["day_number"] <= 365].copy()
print("Rows after filtering for day_number <= 365:", data.shape[0])

data_index_vars = ["item_id", "dept_id", "cat_id", "store_id", "state_id"]
data.sort_values(data_index_vars + ["day_number"], inplace=True)

# Per-series lagged sales; sorting above makes shift() follow day order.
n_lags = 30
for lag in range(1, n_lags + 1):
    data[f"sales_lag_{lag}"] = data.groupby(data_index_vars)["sales"].shift(lag)

print("Rows after creating lag features (before dropping NaNs):", data.shape[0])
data = data.dropna()
print("Rows after dropping NaNs:", data.shape[0])

data = data.merge(calendar_df).merge(sell_prices_df)
print("Rows after merging with calendar and prices:", data.shape[0])

categorical_columns = [
    "item_id", "dept_id", "cat_id", "store_id", "state_id",
    "d", "wm_yr_wk", "weekday", "event_name_1", "event_type_1",
    "event_name_2", "event_type_2", "snap_CA", "snap_TX", "snap_WI"
]
for c in categorical_columns:
    if c in data.columns:
        data[c] = data[c].astype("category")

# Temporal split: days up to the cutoff train, later days test.
train_cutoff = 300
is_train = data["day_number"] <= train_cutoff
train_data = data[is_train].copy()
test_data = data[~is_train].copy()

print("Rows in train_data:", train_data.shape[0])
print("Rows in test_data:", test_data.shape[0])

y_name = "sales"
x_names = [
    col for col in data.columns
    if col not in [y_name, "day_number", "date"] 
]

X_train = train_data[x_names]
y_train = train_data[y_name].values.reshape(-1, 1)
X_test = test_data[x_names]
y_test = test_data[y_name].values.reshape(-1, 1)

print("Final rows in X_train:", X_train.shape[0])
print("Final rows in X_test:", X_test.shape[0])

from sklearn.preprocessing import StandardScaler
x_scaler = StandardScaler()
y_scaler = StandardScaler()

# Only non-categorical features are scaled and used as model inputs.
X_train_sc = x_scaler.fit_transform(X_train.select_dtypes(exclude=["category"]))
X_test_sc = x_scaler.transform(X_test.select_dtypes(exclude=["category"]))

Y_train_sc = y_scaler.fit_transform(y_train)
Y_test_sc = y_scaler.transform(y_test)

# %%
# Shared data-augmentation and LightGBM settings for all three diffusion models.
_treeffuser_fit_kwargs = dict(
    n_repeats=10,
    eval_percent=0.1,
    seed=42,
    n_estimators=100,
    learning_rate=0.1,
    max_depth=8,
)

sde_dict_vesde = create_vesde(0.1, 1.0)
models_vesde = fit_treeffuser(X_train_sc, Y_train_sc, sde_dict_vesde, **_treeffuser_fit_kwargs)

sde_dict_vpsde = create_vpsde(hyperparam_min=0.1, hyperparam_max=1.0)
models_vpsde = fit_treeffuser(X_train_sc, Y_train_sc, sde_dict_vpsde, **_treeffuser_fit_kwargs)

sde_dict_ou = create_ou_sde()
models_ou = fit_treeffuser(X_train_sc, Y_train_sc, sde_dict_ou, **_treeffuser_fit_kwargs)

del _treeffuser_fit_kwargs


# %%
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import r2_score, mean_squared_error
import properscoring as ps
from typing import List, Dict, Callable
import time
import matplotlib.cm as cm

def newsvendor_utility(y_true, quantity_ordered, prices, stocking_cost):
    """Realised newsvendor profit per item.

    Returns revenue on units actually sold, ``prices * min(demand, order)``,
    minus the cost of every unit stocked, ``stocking_cost * order``.
    """
    units_sold = np.minimum(y_true, quantity_ordered)
    revenue = prices * units_sold
    return revenue - stocking_cost * quantity_ordered

def newsvendor_optimal_quantity(y_samples, prices, stocking_cost):
    """Order quantity per item: the critical-ratio quantile of its predictive samples.

    The ratio is alpha = (price - cost) / price, clipped to [0, 1]. A NaN ratio
    (price == cost == 0) is treated as 0. Previously, ratios above 1 (negative
    cost) or NaN made ``np.quantile`` raise.

    Args:
        y_samples: (n_samples, n_items) array of predictive samples.
        prices: per-item prices of length n_items, or a scalar.
        stocking_cost: per-item stocking costs of length n_items, or a scalar.

    Returns:
        (n_items,) array of order quantities.
    """
    y_samples = np.asarray(y_samples)
    n_items = y_samples.shape[1]
    prices = np.asarray(prices, dtype=float)

    with np.errstate(divide="ignore", invalid="ignore"):
        alpha = (prices - stocking_cost) / prices
    alpha = np.clip(np.nan_to_num(alpha, nan=0.0), 0.0, 1.0)
    # Scalars broadcast to every item; per-item arrays must have length n_items.
    alpha = np.broadcast_to(alpha, (n_items,))

    return np.array([np.quantile(y_samples[:, i], alpha[i]) for i in range(n_items)])

def calculate_crps(y_true: np.ndarray, samples: np.ndarray) -> float:
    """Mean CRPS of an ensemble forecast over all observations.

    The orientation of ``samples`` is decided by matching an axis to
    ``len(y_true)``, preferring the (n_samples, n_items) layout built by
    ``evaluate_samples``. The previous ``shape[0] < shape[1]`` heuristic
    mis-read that layout whenever n_samples >= n_items.

    Args:
        y_true: observations; flattened to (n_items,).
        samples: 2-D ensemble, either (n_samples, n_items) or (n_items, n_samples).

    Returns:
        Mean of properscoring's per-observation ``crps_ensemble`` scores.

    Raises:
        ValueError: if ``samples`` is not 2-D or has no axis of length n_items.
    """
    y_true = np.ravel(y_true)
    samples = np.asarray(samples)
    n_items = y_true.shape[0]

    if samples.ndim != 2:
        raise ValueError("samples must be a 2-D array.")
    if samples.shape[1] == n_items:
        ensemble = samples.T
    elif samples.shape[0] == n_items:
        ensemble = samples
    else:
        raise ValueError("samples must have an axis of length len(y_true).")

    # crps_ensemble scores every observation against its ensemble (last axis) at once.
    return float(np.mean(ps.crps_ensemble(y_true, ensemble)))

def evaluate_samples(y_true: np.ndarray, samples: np.ndarray, prices: np.ndarray, stocking_cost: np.ndarray) -> Dict[str, float]:
    """Score predictive samples against observed targets.

    Args:
        y_true: observed targets; flattened to (n_items,).
        samples: (n_samples, n_items) or (n_samples, n_items, 1) predictive samples.
        prices: per-item prices for the newsvendor evaluation.
        stocking_cost: per-item stocking costs for the newsvendor evaluation.

    Returns:
        Dict with 'R2' and 'RMSE' of the sample mean, mean 'CRPS', and the total
        newsvendor 'Profit'. All four values are NaN when ``samples`` contains
        no samples (e.g. a sampler whose time budget expired before its first
        draw). Previously the NaN mean made ``r2_score`` raise and abort the
        experiment loop.
    """
    y_true = y_true.ravel()
    samples_2d = samples.squeeze(axis=2) if samples.ndim == 3 else samples

    print("In evaluate_samples:")
    print(f"y_true shape after ravel: {y_true.shape}")
    print(f"samples_2d shape: {samples_2d.shape}")

    if samples_2d.shape[0] == 0:
        return {'R2': np.nan, 'RMSE': np.nan, 'CRPS': np.nan, 'Profit': np.nan}

    y_pred_mean = np.mean(samples_2d, axis=0)
    optimal_q = newsvendor_optimal_quantity(samples_2d, prices, stocking_cost)
    profit_arr = newsvendor_utility(y_true, optimal_q, prices, stocking_cost)
    total_profit = profit_arr.sum()

    return {
        'R2': r2_score(y_true, y_pred_mean),
        'RMSE': np.sqrt(mean_squared_error(y_true, y_pred_mean)),
        'CRPS': calculate_crps(y_true, samples_2d),
        'Profit': total_profit
    }

sampling_methods = {
    'DPM': (sample_treeffuser_dpm, 'vesde'),
    'Euler': (sample_treeffuser, 'vesde'),
    'PNDM': (sample_treeffuser_pndm, 'vpsde'),
    'ExpInt': (sample_treeffuser_expint, 'vesde'),
    'UniPC': (sample_treeffuser_unipc, 'vpsde'),
    'StochSampler': (sample_treeffuser_stochsampler, 'vesde'),
    'DPM_Stoch': (sample_treeffuser_dpm_stoch, 'vesde'),
    'OU-Exp': (sample_treeffuser_ou_exponential, 'ou'),
    'OU-Mid': (sample_treeffuser_ou_midpoint, 'ou'),
    'DPM_2_stoch': (sample_treeffuser_dpm_solver2_stoch, 'vesde'),
    'DPM_3_stoch': (sample_treeffuser_dpm_solver3_stoch, 'vesde')
}

# Wall-clock budgets (seconds) given to every sampler.
time_budgets = [5, 10, 15, 20, 25, 30]

results = []
saved_samples = {}

# Newsvendor economics: stocking cost chosen so that price = cost * (1 + profit_margin).
prices = X_test["sell_price"].fillna(1.0).values
profit_margin = 0.5
stocking_cost = prices / (1 + profit_margin)

for budget in time_budgets:
    print(f"\nRunning experiments for {budget} seconds budget...")

    for method_name, (method_fn, sde_type) in sampling_methods.items():
        print(f"  Running {method_name}...")

        if sde_type == 'vpsde':
            sde_dict, models = sde_dict_vpsde, models_vpsde
        elif sde_type == 'vesde':
            sde_dict, models = sde_dict_vesde, models_vesde
        else:
            sde_dict, models = sde_dict_ou, models_ou

        # Per-method step counts; methods not listed use 3 steps.
        n_steps = {'Euler': 15, 'OU-Mid': 5}.get(method_name, 3)

        samples = method_fn(
            X=X_test_sc,
            sde_dict=sde_dict,
            models=models,
            x_scaler=x_scaler,
            y_scaler=y_scaler,
            time_budget=budget,
            n_steps=n_steps,
            seed=42,
        )
        saved_samples[(method_name, budget)] = samples

        metrics = evaluate_samples(y_test, samples, prices, stocking_cost)
        results.append({'Method': method_name, 'Time_Budget': budget, **metrics})

results_df = pd.DataFrame(results)

print("\nResults Summary:")
summary_table = results_df.pivot_table(
    index='Method',
    columns='Time_Budget',
    values=['R2', 'RMSE', 'CRPS', 'Profit'],
)
print(summary_table)