In [None]:
!pip install jax jaxlib
!pip install --quiet --upgrade scipy
!pip install --quiet jax jaxlib optax

In [None]:
import jax
from jax.scipy.stats import norm
import jax.numpy as jnp
from scipy.stats import norm
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.optimize import minimize_scalar, brentq, minimize
from scipy.special import gamma
from numpy.polynomial.legendre import leggauss
import warnings
warnings.filterwarnings('ignore')
from math import log
from numpy.random import default_rng, SeedSequence
from scipy.stats import kstwobign, cramervonmises, uniform
from joblib import Parallel, delayed
from itertools import zip_longest
from collections import OrderedDict
from scipy.stats import uniform, cramervonmises, kstwobign
from joblib import Parallel, delayed
from itertools import zip_longest
import matplotlib.patches as mpatches


# ==============================================================
# ETELL Simulation
# ==============================================================


In [None]:
# ==============================================================
# ETLL Simulation - DESIGN B
# ==============================================================


class ETLLSimulation_DesignB:
    """
    Monte Carlo study of "Design B" Kumaraswamy-weighted L-estimators for the
    ETLL (ETELL) model with a known threshold θ.

    Model (as simulated by ``generate_etll_sample``): with U ~ Uniform(0, 1),

        base(u) = 2^β − (2^β − 1)·u,   t(u) = base(u)^(1/β) − 1,
        Q(u)    = θ · t(u)^(−1/α),     so  log Q(u) = log θ − ℓ(u)/α,

    where ℓ(u) = log t(u).  Design B applies two different weight functions
    J₁ ≠ J₂ (Kumaraswamy pdfs, each integrating to one) to the same
    transform h(x) = log x.  The population weighted means then satisfy

        μ_k = log θ − c_{w,k}(β)/α,   c_{w,k}(β) = ∫ J_k(u) ℓ(u; β) du,

    hence, with τ_w = c_{w,2} − c_{w,1} and Δ_w = μ₂ − μ₁,

        α = −τ_w(β)/Δ_w   and   Ψ(β) = c_{w,1}/τ_w + (log θ − μ₁)/Δ_w = 0.

    Robustness features of the β solver:
      - several grid/bracket attempts for the root of Ψ(β)
      - an algebraically equivalent J₁↔J₂ relabelled formulation as a retry
      - a bounded-minimisation fallback on |Ψ(β)|
    """

    def __init__(self, theta=1.0, alpha=2.0, beta=0.5, n_quad=250,
                 use_det_re=True, use_numeric_info=False, rng=None):
        """
        Parameters
        ----------
        theta : float
            Known threshold θ; simulated values θ·t^(−1/α) with t ∈ (0, 1]
            are ≥ θ.
        alpha, beta : float
            True model parameters, used for simulation, standardisation and
            as reference values.
        n_quad : int
            Number of Gauss–Legendre nodes (mapped to (0, 1)) used for every
            integral in this class.
        use_det_re, use_numeric_info : bool
            Stored configuration flags; no method of this class reads them.
        rng : None, int, SeedSequence or Generator
            Passed to ``np.random.default_rng``.
        """
        self.theta = float(theta)
        self.alpha_true = float(alpha)
        self.beta_true = float(beta)
        self.use_det_re = bool(use_det_re)
        self.use_numeric_info = bool(use_numeric_info)

        # Gauss–Legendre rule on [−1, 1] mapped affinely to (0, 1).
        self.nodes, self.weights = leggauss(n_quad)
        self.u = 0.5 * (self.nodes + 1.0)
        self.w = 0.5 * self.weights
        self._K = None  # lazily built kernel matrix (see _kernel_matrix)
        self.rng = np.random.default_rng(rng)

    # ------------------------------------------------------------------
    # Basic building blocks
    # ------------------------------------------------------------------

    @staticmethod
    def kumaraswamy_weight(u, a, b):
        """
        Kumaraswamy(a, b) density J(u) = a·b·u^(a−1)·(1 − u^a)^(b−1).

        For large shape parameters the density is evaluated in log space to
        avoid overflow/underflow of the individual powers.  Any NaN/±inf that
        still arises (e.g. at u = 0 with a < 1) is mapped to 0.
        """
        with np.errstate(all='ignore'):
            if a > 20 or b > 20:
                result = a * b * np.exp((a - 1) * np.log(u) + (b - 1) * np.log(1 - u**a))
            else:
                result = a * b * (u ** (a - 1.0)) * ((1.0 - u**a) ** (b - 1.0))
            return np.nan_to_num(result, nan=0.0, posinf=0.0, neginf=0.0)

    @staticmethod
    def _beta_over_expm1(beta):
        """Return β/(2^β − 1), using its limit 1/ln 2 at β = 0."""
        ln2 = np.log(2.0)
        if abs(beta) < 1e-12:
            return 1.0 / ln2
        return beta / np.expm1(beta * ln2)

    def _stable_terms(self, beta, u):
        """
        Evaluate the model's building blocks on the array ``u`` for shape β.

        Returns
        -------
        base : 2^β − (2^β − 1)·u
        t    : e^x − 1 with x = log(base)/β  (x = (1 − u)·ln 2 when β ≈ 0)
        ell  : log t, set to −inf where t ≤ 0
        g    : 1 / ((1 − e^(−x))·base), using the series 1/x + 1/2 + x/12
               for |x| < 1e-7.  d log Q/du equals (2^β − 1)/(αβ)·g.
        """
        two_pow_beta = np.exp(beta * np.log(2.0))
        base = two_pow_beta - (two_pow_beta - 1.0) * u

        if abs(beta) < 1e-8:
            x = (1.0 - u) * np.log(2.0)
        else:
            x = np.log(np.maximum(base, 1e-15)) / beta

        t = np.expm1(x)
        ell = np.where(t > 0.0, np.log(t), -np.inf)

        ratio = np.empty_like(x)
        small = np.abs(x) < 1e-7
        ratio[~small] = 1.0 / (1.0 - np.exp(-x[~small]))
        # Laurent expansion of 1/(1 - e^{-x}) avoids cancellation near x = 0.
        ratio[small] = 1.0 / (x[small] + 1e-30) + 0.5 + x[small] / 12.0

        g = ratio / base
        return base, t, ell, g

    def _kernel_matrix(self):
        """Cached Brownian-bridge covariance min(u_i, u_j) − u_i·u_j on the grid."""
        if self._K is None:
            u = self.u
            self._K = np.minimum(u[:, None], u[None, :]) - (u[:, None] * u[None, :])
        return self._K

    def generate_etll_sample(self, n):
        """Draw ``n`` observations by inverse-transform sampling x = Q(U)."""
        u = self.rng.uniform(0.0, 1.0, int(n))
        if abs(self.beta_true) < 1e-8:
            t = np.expm1((1.0 - u) * np.log(2.0))
        else:
            two_pow_beta = np.exp(self.beta_true * np.log(2.0))
            base = two_pow_beta - (two_pow_beta - 1.0) * u
            x = np.log(base) / self.beta_true
            t = np.expm1(x)

        x = self.theta * np.power(t, -1.0 / self.alpha_true)
        return x

    def compute_cw_k(self, beta, k, a, b):
        """
        Quadrature approximation of c_{w,k}(β) = ∫ J(u; a, b)·ℓ(u; β) du.

        ``k`` is only a label and does not enter the computation.  Nodes with
        t ≤ 1e-15 contribute 0.  Returns NaN if no node qualifies or the sum
        is not finite.
        """
        u, w = self.u, self.w
        J = self.kumaraswamy_weight(u, a, b)
        _, t, ell, _ = self._stable_terms(beta, u)

        mask = t > 1e-15
        if not np.any(mask):
            return np.nan

        integrand = J * np.where(mask, ell, 0.0)
        result = np.sum(w * integrand)

        return result if np.isfinite(result) else np.nan

    def tau_w(self, beta, a1, b1, a2, b2):
        """τ_w(β) = c_{w,2}(β) − c_{w,1}(β); NaN if either term is not finite."""
        c1 = self.compute_cw_k(beta, 1, a1, b1)
        c2 = self.compute_cw_k(beta, 2, a2, b2)

        if not (np.isfinite(c1) and np.isfinite(c2)):
            return np.nan

        return c2 - c1

    def _alpha_from_beta(self, beta, Delta_w, a1, b1, a2, b2):
        """
        α implied by β through α = −τ_w(β)/Δ_w (NaN if τ_w is not finite).

        Relabelling J₁↔J₂ negates both τ_w and Δ_w, so this value is the same
        whichever formulation produced β — no sign flip is ever needed.
        """
        tau = self.tau_w(beta, a1, b1, a2, b2)
        if not np.isfinite(tau):
            return np.nan
        return -tau / Delta_w

    @staticmethod
    def _is_plausible(alpha_hat, beta_hat):
        """Acceptance rule for a root: finite, 0.01 < α̂ < 100 and |β̂| < 10."""
        return bool(np.isfinite(alpha_hat) and np.isfinite(beta_hat)
                    and 0.01 < alpha_hat < 100 and abs(beta_hat) < 10)

    # ------------------------------------------------------------------
    # Design B estimator
    # ------------------------------------------------------------------

    def solve_beta_designB_robust(self, mu1, mu2, a1, b1, a2, b2, verbose=False):
        """
        Solve Ψ(β) = 0 for β given weighted means μ₁, μ₂.

        Strategies, in order:
          1. grid scan + ``brentq`` on the standard Ψ over two β ranges;
          2. the same on the J₁↔J₂ relabelled equation
             c_{w,2}/(c_{w,1} − c_{w,2}) + (log θ − μ₂)/(μ₁ − μ₂) = 0,
             which is algebraically −Ψ(β) and serves as a numerical retry;
          3. bounded minimisation of |Ψ(β)| on [−5, 5].
        A candidate is accepted only if α̂ = −τ_w(β̂)/Δ_w is plausible
        (see ``_is_plausible``).

        Returns
        -------
        (beta_hat, swapped) : (float, bool)
            beta_hat is NaN when every strategy fails.  ``swapped`` reports
            whether the relabelled formulation produced the root; it is
            informational only, since α̂ = −τ_w(β̂)/(μ₂ − μ₁) either way.
        """
        Delta_w = mu2 - mu1

        if abs(Delta_w) < 1e-12:
            if verbose:
                print(f"    Δ_w too small: {Delta_w:.2e}")
            return np.nan, False

        log_theta = np.log(self.theta)

        def Psi(beta, swap=False):
            """Standard Ψ(β), or its J₁↔J₂ relabelled form when swap=True."""
            try:
                c1 = self.compute_cw_k(beta, 1, a1, b1)
                c2 = self.compute_cw_k(beta, 2, a2, b2)
                if not (np.isfinite(c1) and np.isfinite(c2)):
                    return np.nan

                if swap:
                    # Relabelling swaps c₁↔c₂ and μ₁↔μ₂, negating τ_w and Δ_w.
                    c_num, tau, mu_ref, delta = c2, c1 - c2, mu2, -Delta_w
                else:
                    c_num, tau, mu_ref, delta = c1, c2 - c1, mu1, Delta_w

                if abs(tau) < 1e-14:
                    # Pole of Ψ: return a large value with τ's sign.
                    return 1e10 * np.sign(tau + 1e-15)

                result = c_num / tau + (log_theta - mu_ref) / delta
                return result if np.isfinite(result) else np.nan
            except Exception:
                return np.nan

        def accept(beta_hat):
            alpha_hat = self._alpha_from_beta(beta_hat, Delta_w, a1, b1, a2, b2)
            return self._is_plausible(alpha_hat, beta_hat)

        strategies = [
            ("standard", False, (-3.5, 3.5, 150)),
            ("standard", False, (-4.5, 4.5, 200)),
            ("swapped", True, (-3.5, 3.5, 150)),
            ("swapped", True, (-4.5, 4.5, 200)),
        ]

        for strategy_name, swap, (beta_min, beta_max, n_grid) in strategies:
            grid = np.linspace(beta_min, beta_max, n_grid)
            psi_vals = np.array([Psi(b, swap=swap) for b in grid])

            valid = np.isfinite(psi_vals)
            if np.sum(valid) < 3:
                continue

            psi_valid = psi_vals[valid]
            beta_valid = grid[valid]
            sign_changes = np.flatnonzero(np.diff(np.sign(psi_valid)) != 0)

            for idx in sign_changes:
                # Psi is deterministic, so the grid values are the endpoint values.
                if psi_valid[idx] * psi_valid[idx + 1] >= 0:
                    continue
                lo, hi = beta_valid[idx], beta_valid[idx + 1]
                try:
                    beta_hat = brentq(lambda bb: Psi(bb, swap=swap), lo, hi,
                                      xtol=1e-10, rtol=1e-9, maxiter=500)
                    ok = accept(beta_hat)
                except Exception:
                    continue
                if ok:
                    if verbose and swap:
                        print(f"    Success with {strategy_name} strategy")
                    return beta_hat, swap

        # Fallback: minimise |Ψ(β)| when no usable sign change was found.
        for swap in (False, True):
            def objective(beta, swap=swap):
                val = Psi(beta, swap=swap)
                return abs(val) if np.isfinite(val) else 1e10

            try:
                result = minimize_scalar(objective, bounds=(-5, 5), method='bounded',
                                         options={'maxiter': 500})
                ok = bool(result.success and result.fun < 0.01 and accept(result.x))
            except Exception:
                continue

            if ok:
                if verbose and swap:
                    print(f"    Success with optimization {['standard','swapped'][swap]}")
                return result.x, swap

        if verbose:
            print("    All strategies failed")

        return np.nan, False

    def kumaraswamy_l_estimator_designB_robust(self, x, a1, b1, a2, b2, verbose=False):
        """
        Design B L-estimate (α̂, β̂) from a sample ``x``.

        Observations below θ are discarded.  The order statistics are weighted
        at plotting positions i/(n + 1), giving μ̂_k = mean(J_k · log x_(i)).
        β̂ solves Ψ(β) = 0 and α̂ = −τ_w(β̂)/(μ̂₂ − μ̂₁).  Returns (NaN, NaN) if
        fewer than 3 usable observations remain, or the solve fails.  Also
        returns (NaN, NaN) if α̂ is non-positive or non-finite, or if the
        estimates fall outside the sanity window around the true parameters.
        """
        x = np.asarray(x)
        x = x[x >= self.theta]
        n = x.size
        if n < 3:
            return np.nan, np.nan

        xs = np.sort(x)
        uo = np.arange(1, n + 1) / (n + 1.0)

        J1 = self.kumaraswamy_weight(uo, a1, b1)
        J2 = self.kumaraswamy_weight(uo, a2, b2)

        lx = np.log(xs)
        mu1 = np.mean(J1 * lx)
        mu2 = np.mean(J2 * lx)
        Delta_w = mu2 - mu1
        label = f"  J₁({a1},{b1})×J₂({a2},{b2})"

        if abs(Delta_w) < 1e-12:
            if verbose:
                print(f"{label}: Δ_w ≈ 0")
            return np.nan, np.nan

        beta_hat, _ = self.solve_beta_designB_robust(
            mu1, mu2, a1, b1, a2, b2, verbose=verbose
        )

        if np.isnan(beta_hat):
            if verbose:
                print(f"{label}: β solve failed")
            return np.nan, np.nan

        # α̂ = −τ_w/Δ_w regardless of which formulation found β̂.
        tau = self.tau_w(beta_hat, a1, b1, a2, b2)
        if not np.isfinite(tau) or abs(tau) < 1e-12:
            if verbose:
                print(f"{label}: τ_w invalid")
            return np.nan, np.nan

        alpha_hat = -tau / Delta_w

        if alpha_hat <= 0:
            if verbose:
                print(f"{label}: α̂={alpha_hat:.3f} ≤ 0")
            return np.nan, np.nan

        if not np.isfinite(alpha_hat):
            if verbose:
                print(f"{label}: α̂ not finite")
            return np.nan, np.nan

        # Simulation-study sanity window around the known true parameters.
        if (abs(alpha_hat - self.alpha_true) > 5 * self.alpha_true or
                abs(beta_hat - self.beta_true) > 10):
            if verbose:
                print(f"{label}: estimates too far from truth")
            return np.nan, np.nan

        return alpha_hat, beta_hat

    # ------------------------------------------------------------------
    # Asymptotics
    # ------------------------------------------------------------------

    def _lambda_w_pair(self, alpha, beta, ai, bi, aj, bj):
        """
        Λ_{w,ij} = ∫∫ J_i(u) J_j(v) (min(u,v) − uv) g(u) g(v) du dv.

        Evaluated by tensor-product quadrature.  ``alpha`` is unused here; the
        α-dependent prefactor is applied in ``Sigma_mu_designB``.
        """
        u, w = self.u, self.w
        Ji = self.kumaraswamy_weight(u, ai, bi)
        Jj = self.kumaraswamy_weight(u, aj, bj)
        _, t, _, g = self._stable_terms(beta, u)

        g = np.where(t > 0, g, 0.0)

        W = (w * Ji)[:, None] * (w * Jj)[None, :]
        K = self._kernel_matrix()
        G = g[:, None] * g[None, :]

        return np.sum(W * K * G)

    def Sigma_mu_designB(self, alpha, beta, a1, b1, a2, b2):
        """
        Asymptotic covariance (for n = 1) of (μ̂₁, μ̂₂).

        This is δ²·Λ_w with δ = (2^β − 1)/(αβ), the constant factor of
        d log Q/du.  δ tends to ln 2/α as β → 0.
        """
        delta_sq = (1.0 / (alpha * self._beta_over_expm1(beta))) ** 2

        L11 = self._lambda_w_pair(alpha, beta, a1, b1, a1, b1)
        L12 = self._lambda_w_pair(alpha, beta, a1, b1, a2, b2)
        L22 = self._lambda_w_pair(alpha, beta, a2, b2, a2, b2)

        S = delta_sq * np.array([[L11, L12], [L12, L22]])
        return 0.5 * (S + S.T)

    def mle_etll(self, x):
        """
        Maximum-likelihood estimate (α̂, β̂) with θ known.

        Log-density used: log α + log(β/(2^β − 1)) + (β − 1)·log(1 + (θ/x)^α)
        − (1 + α)·log(x/θ), up to an additive constant.  It is maximised with
        L-BFGS-B on α ∈ [0.05, 10], β ∈ [−3, 3].  A retry starts from the
        true parameters.  Returns (NaN, NaN) if fewer than 5 observations
        ≥ θ remain or the optimiser fails.
        """
        xv = np.asarray(x)
        xv = xv[xv >= self.theta]
        n = xv.size
        if n < 5:
            return np.nan, np.nan

        ln2 = np.log(2.0)
        log_x_over_theta = np.log(xv / self.theta)

        def nll(params):
            alpha, beta = params
            if alpha <= 0 or np.abs(beta) > 5:
                return 1e10
            with np.errstate(all='ignore'):
                if abs(beta) < 1e-8:
                    const = -np.log(ln2)  # limit of log(β/(2^β − 1)) at β = 0
                else:
                    two_b = np.exp(beta * ln2)
                    const = np.log(np.abs(beta)) - np.log(np.abs(two_b - 1.0))

                ratio = (self.theta / xv) ** alpha
                ll = (n * np.log(alpha) + n * const
                      + (beta - 1.0) * np.sum(np.log1p(ratio))
                      - (1.0 + alpha) * np.sum(log_x_over_theta))
            # Non-finite values would derail L-BFGS-B; treat them as infeasible.
            return -ll if np.isfinite(ll) else 1e10

        lx = np.log(xv)
        m1 = lx.mean()
        m2 = (lx ** 2).mean()
        alpha0 = 1.0 / np.sqrt(max(m2 - m1**2, 1e-4))
        beta0 = np.clip(self.beta_true, -3.0, 3.0)
        bounds = [(0.05, 10.0), (-3.0, 3.0)]

        try:
            res = minimize(nll, x0=[alpha0, beta0], bounds=bounds, method="L-BFGS-B")
            if res.success and res.fun < 1e9:
                return res.x[0], res.x[1]

            res = minimize(nll, x0=[self.alpha_true, self.beta_true],
                           bounds=bounds, method="L-BFGS-B")
            return (res.x[0], res.x[1]) if res.success else (np.nan, np.nan)
        except Exception:
            return np.nan, np.nan

    def fisher_information(self, alpha, beta, n):
        """
        Expected Fisher information of n i.i.d. observations at (α, β), θ known.

        With t = (θ/x)^α ∈ (0, 1), whose density is k·(1 + t)^(β−1) with
        k = β/(2^β − 1), the per-observation entries are

            I_αα = [1 − (β − 1)·k·∫₀¹ t·log²t·(1 + t)^(β−3) dt] / α²
            I_αβ = −(k/α)·∫₀¹ t·log t·(1 + t)^(β−2) dt
            I_ββ = 1/β² − 2^β·ln²2/(2^β − 1)²        (→ ln²2/12 as β → 0)

        The integrals use the instance's Gauss–Legendre rule on (0, 1).  All
        entries are regular at β ∈ {0, 1, 2}.  I_αα scales exactly as 1/α²,
        because x depends on α only through log x = log θ − log t/α.
        """
        t, w = self.u, self.w
        ln2 = np.log(2.0)
        k = self._beta_over_expm1(beta)
        log_t = np.log(t)

        J_aa = np.sum(w * t * log_t ** 2 * (1.0 + t) ** (beta - 3.0))
        J_ab = np.sum(w * t * log_t * (1.0 + t) ** (beta - 2.0))

        Iaa = (1.0 - (beta - 1.0) * k * J_aa) / alpha ** 2
        Iab = -k * J_ab / alpha
        if abs(beta) < 1e-4:
            # Closed form suffers catastrophic cancellation near β = 0.
            Ibb = ln2 ** 2 / 12.0
        else:
            em1 = np.expm1(beta * ln2)
            Ibb = 1.0 / beta ** 2 - (em1 + 1.0) * ln2 ** 2 / em1 ** 2

        return n * np.array([[Iaa, Iab], [Iab, Ibb]])

    def compute_theoretical_are_designB(self, a1, b1, a2, b2):
        """
        Asymptotic relative efficiency sqrt(det Σ_MLE / det Σ_L) for Design B.

        μ₁, μ₂ are the population weighted means under the true parameters.
        The Jacobian ∂(α, β)/∂(μ₁, μ₂) comes from central finite differences
        of the solver (step 1e-6).  Returns NaN if any step fails, or if the
        result falls outside the sanity range (1e-6, 2].
        """
        u, w = self.u, self.w

        if abs(self.beta_true) < 1e-8:
            t = np.expm1((1.0 - u) * np.log(2.0))
        else:
            two_b = np.exp(self.beta_true * np.log(2.0))
            base = two_b - (two_b - 1.0) * u
            t = np.expm1(np.log(base) / self.beta_true)

        log_q = np.log(self.theta * np.power(t, -1.0 / self.alpha_true))

        J1 = self.kumaraswamy_weight(u, a1, b1)
        J2 = self.kumaraswamy_weight(u, a2, b2)
        mu1 = np.sum(w * J1 * log_q)
        mu2 = np.sum(w * J2 * log_q)

        beta_hat, _ = self.solve_beta_designB_robust(mu1, mu2, a1, b1, a2, b2)
        if np.isnan(beta_hat):
            return np.nan

        alpha_hat = self._alpha_from_beta(beta_hat, mu2 - mu1, a1, b1, a2, b2)
        if not (np.isfinite(alpha_hat) and np.isfinite(beta_hat) and alpha_hat > 0):
            return np.nan

        try:
            # Both covariances are scaled by the same n, so the ratio is n-free.
            n_large = 5000
            Sigma_mu = self.Sigma_mu_designB(alpha_hat, beta_hat, a1, b1, a2, b2) / n_large

            eps = 1e-6

            def solve_pair(m1, m2):
                """(α, β) solving the Design B equations for means (m1, m2)."""
                try:
                    b, _ = self.solve_beta_designB_robust(m1, m2, a1, b1, a2, b2)
                    return self._alpha_from_beta(b, m2 - m1, a1, b1, a2, b2), b
                except Exception:
                    return alpha_hat, beta_hat

            a_p1, b_p1 = solve_pair(mu1 + eps, mu2)
            a_m1, b_m1 = solve_pair(mu1 - eps, mu2)
            a_p2, b_p2 = solve_pair(mu1, mu2 + eps)
            a_m2, b_m2 = solve_pair(mu1, mu2 - eps)

            D = np.array([
                [(a_p1 - a_m1) / (2 * eps), (a_p2 - a_m2) / (2 * eps)],
                [(b_p1 - b_m1) / (2 * eps), (b_p2 - b_m2) / (2 * eps)]
            ])

            if not np.all(np.isfinite(D)) or abs(np.linalg.det(D)) < 1e-12:
                return np.nan

            S_L = D @ Sigma_mu @ D.T
            S_MLE = np.linalg.inv(self.fisher_information(alpha_hat, beta_hat, n_large))

            det_S_L = np.linalg.det(S_L)
            det_S_MLE = np.linalg.det(S_MLE)
            if det_S_L <= 0 or det_S_MLE <= 0:
                return np.nan

            ARE = np.sqrt(det_S_MLE / det_S_L)

            if not np.isfinite(ARE) or ARE < 1e-6 or ARE > 2:
                return np.nan

            return ARE
        except Exception:
            return np.nan

    # ------------------------------------------------------------------
    # Simulation driver
    # ------------------------------------------------------------------

    @staticmethod
    def _pair_key(a1, b1, a2, b2):
        """Result-dictionary key for a (J₁, J₂) configuration."""
        return f"J1({a1},{b1})×J2({a2},{b2})"

    @staticmethod
    def _winsorize_pair(a_vals, b_vals, p=0.00):
        """Clip each coordinate to its own [p, 1 − p] order statistics (p=0: no-op)."""
        ax = np.asarray(a_vals, float)
        bx = np.asarray(b_vals, float)
        if ax.size < 3:
            return ax, bx
        lo = int(np.floor(p * ax.size))
        hi = int(np.ceil((1 - p) * ax.size))
        axs = np.sort(ax)
        bxs = np.sort(bx)
        a_lo, a_hi = axs[lo], axs[min(hi, ax.size - 1)]
        b_lo, b_hi = bxs[lo], bxs[min(hi, bx.size - 1)]
        return np.clip(ax, a_lo, a_hi), np.clip(bx, b_lo, b_hi)

    @classmethod
    def _det_re(cls, a_list, b_list, S_ref):
        """
        Finite-sample RE sqrt(det S_ref / det S_emp).

        S_emp is the sample covariance of the (α̂, β̂) pairs, with a 1e-9
        ridge.  Returns NaN if there are fewer than 3 pairs, S_ref is None,
        or det S_emp ≤ 0.
        """
        if S_ref is None:
            return np.nan
        vals = np.c_[a_list, b_list]
        if vals.shape[0] < 3:
            return np.nan
        a_vals, b_vals = cls._winsorize_pair(vals[:, 0], vals[:, 1])
        S = np.cov(np.c_[a_vals, b_vals], rowvar=False, ddof=1)
        S = 0.5 * (S + S.T) + 1e-9 * np.eye(2)
        den = np.linalg.det(S)
        num = np.linalg.det(S_ref)
        return np.sqrt(num / den) if den > 0 else np.nan

    def _mle_reference_cov(self, mle_alpha, mle_beta, n, ref_at):
        """
        Inverse Fisher information at n for the RE denominator, or None.

        For ref_at == "batch" it is evaluated at the batch MLE means, which
        requires at least one MLE.  Otherwise it uses the true parameters.
        Returns None if the information matrix is singular.
        """
        if ref_at == "batch":
            if len(mle_alpha) == 0:
                return None
            alpha_ref = float(np.mean(mle_alpha))
            beta_ref = float(np.mean(mle_beta))
        else:
            alpha_ref, beta_ref = self.alpha_true, self.beta_true
        try:
            return np.linalg.inv(self.fisher_information(alpha_ref, beta_ref, n))
        except np.linalg.LinAlgError:
            return None

    def _summarize_estimates(self, alphas, betas, S_ref, re_asymptotic):
        """Per-batch summary standardised by the true parameters."""
        avals = np.asarray(alphas, float)
        bvals = np.asarray(betas, float)
        # NOTE(review): beta_mean/beta_se divide by beta_true and are undefined if it is 0.
        return {
            "alpha_mean": np.mean(avals) / self.alpha_true,
            "alpha_se": np.std(avals, ddof=1) / (self.alpha_true * np.sqrt(len(avals))),
            "beta_mean": np.mean(bvals) / self.beta_true,
            "beta_se": np.std(bvals, ddof=1) / (self.beta_true * np.sqrt(len(bvals))),
            "re": self._det_re(avals, bvals, S_ref),
            "re_asymptotic": re_asymptotic,
        }

    @staticmethod
    def _aggregate_batches(batch_stats):
        """Average per-batch summaries over batches, ignoring non-finite values."""
        def collect(key, field):
            return np.array([b[key][field] for b in batch_stats
                             if key in b and field in b[key]
                             and np.isfinite(b[key][field])])

        final = {}
        for key in dict.fromkeys(k for b in batch_stats for k in b):
            a_mean = collect(key, "alpha_mean")
            if a_mean.size == 0:
                continue
            a_se = collect(key, "alpha_se")
            b_mean = collect(key, "beta_mean")
            b_se = collect(key, "beta_se")
            re_vals = collect(key, "re")
            re_inf = collect(key, "re_asymptotic")

            final[key] = {
                "alpha_mean": a_mean.mean(),
                "alpha_se": a_se.mean() if a_se.size > 0 else np.nan,
                "beta_mean": b_mean.mean() if b_mean.size > 0 else np.nan,
                "beta_se": b_se.mean() if b_se.size > 0 else np.nan,
                "re": re_vals.mean() if re_vals.size > 0 else np.nan,
                "re_se": re_vals.std(ddof=1) / np.sqrt(re_vals.size) if re_vals.size > 1 else np.nan,
                "re_asymptotic": (1.0 if key == "MLE" else (re_inf.mean() if re_inf.size > 0 else np.nan))
            }
        return final

    def run_simulation_with_re_se_designB(self, sample_sizes, j_pairs,
                                          n_batches=10, sims_per_batch=100,
                                          verbose=True, ref_at="true"):
        """
        Run the Design B Monte Carlo study.

        For each n, ``n_batches`` batches of ``sims_per_batch`` samples are
        drawn.  For each sample the MLE and every J-pair L-estimator are
        computed.  Per-batch summaries are then averaged over batches.

        Returns
        -------
        dict
            ``{n: {key: summary}}`` where key is "MLE" or
            ``"J1(a1,b1)×J2(a2,b2)"``.  Each summary has fields alpha_mean,
            alpha_se, beta_mean, beta_se, re, re_se and re_asymptotic.
        """
        pairs = [((a1, b1), (a2, b2)) for (a1, b1), (a2, b2) in j_pairs]
        pair_keys = [self._pair_key(a1, b1, a2, b2) for (a1, b1), (a2, b2) in pairs]

        # The asymptotic ARE depends only on the true parameters and weights,
        # not on n or on simulated data, so compute it once.
        are_inf = {pair: self.compute_theoretical_are_designB(*pair[0], *pair[1])
                   for pair in pairs}

        all_results = {}

        for n in sample_sizes:
            if verbose:
                print(f"\nRunning n={n} with {n_batches} batches...")

            batch_stats = []

            for bidx in range(n_batches):
                if verbose and bidx % 5 == 0:
                    print(f"  Batch {bidx + 1}/{n_batches}")

                est = {"MLE": {"alpha": [], "beta": []}}
                for key in pair_keys:
                    est[key] = {"alpha": [], "beta": []}

                for _ in range(sims_per_batch):
                    x = self.generate_etll_sample(n)

                    a_mle, b_mle = self.mle_etll(x)
                    if np.isfinite(a_mle) and np.isfinite(b_mle):
                        est["MLE"]["alpha"].append(a_mle)
                        est["MLE"]["beta"].append(b_mle)

                    for ((a1, b1), (a2, b2)), key in zip(pairs, pair_keys):
                        ak, bk = self.kumaraswamy_l_estimator_designB_robust(
                            x, a1, b1, a2, b2, verbose=False
                        )
                        if np.isfinite(ak) and np.isfinite(bk):
                            est[key]["alpha"].append(ak)
                            est[key]["beta"].append(bk)

                S_ref = self._mle_reference_cov(
                    est["MLE"]["alpha"], est["MLE"]["beta"], n, ref_at)

                batch = {}
                if est["MLE"]["alpha"]:
                    batch["MLE"] = self._summarize_estimates(
                        est["MLE"]["alpha"], est["MLE"]["beta"], S_ref, 1.0)
                for pair, key in zip(pairs, pair_keys):
                    if est[key]["alpha"]:
                        batch[key] = self._summarize_estimates(
                            est[key]["alpha"], est[key]["beta"], S_ref, are_inf[pair])

                batch_stats.append(batch)

            all_results[n] = self._aggregate_batches(batch_stats)

        return all_results

    def print_results_table_designB(self, results, sample_sizes, j_pairs):
        """Print standardised means (with SEs) and relative efficiencies for Design B."""
        print("\n" + "=" * 140)
        print(f"Design B: Standardized MEAN and RE from ETELL(α={self.alpha_true}, β={self.beta_true}, θ={self.theta})")
        print("Different J (J₁ ≠ J₂), Same h = log(x)")
        print("=" * 140)

        col_w, last_w = 14, 10
        header = "Weight Config".ljust(30)
        for n in sample_sizes:
            header += f"{f'n={n}':^{col_w * 2}}"
        header += f"{'n→∞':^{last_w * 2}}"
        print(header)

        sub = "J₁(a₁,b₁)×J₂(a₂,b₂)".ljust(30)
        for _ in sample_sizes:
            sub += f"{'α̂/α':>{col_w}}{'β̂/β':>{col_w}}"
        sub += f"{'α̂/α':>{last_w}}{'β̂/β':>{last_w}}"
        print(sub)

        print("\nMEAN VALUES:")

        def row_for(key, label=None):
            out = (label or key).ljust(30)
            for n in sample_sizes:
                if n in results and key in results[n]:
                    s = results[n][key]
                    out += f"{s['alpha_mean']:5.2f}({(s['alpha_se'] if np.isfinite(s['alpha_se']) else np.nan):.3f})".rjust(col_w)
                    out += f"{s['beta_mean']:5.2f}({(s['beta_se'] if np.isfinite(s['beta_se']) else np.nan):.3f})".rjust(col_w)
                else:
                    out += f"{'---':>{col_w * 2}}"
            out += f"{'1.00':>{last_w}}{'1.00':>{last_w}}"
            print(out)

        row_for("MLE", "MLE")
        for (a1, b1), (a2, b2) in j_pairs:
            row_for(self._pair_key(a1, b1, a2, b2), f"J₁({a1},{b1})×J₂({a2},{b2})")

        print("\n" + "-" * 140)
        print("RELATIVE EFFICIENCY:")

        def re_row(key, label=None):
            out = (label or key).ljust(30)
            for n in sample_sizes:
                if n in results and key in results[n]:
                    s = results[n][key]
                    re = s.get("re", np.nan)
                    se = s.get("re_se", np.nan)
                    out += f"{(re if np.isfinite(re) else np.nan):5.3f}({(se if np.isfinite(se) else np.nan):.3f})".rjust(col_w)
                else:
                    out += f"{'---':>{col_w}}"

            # The asymptotic RE does not depend on n; read it from the first n.
            n0 = sample_sizes[0]
            if n0 in results and key in results[n0]:
                out += f"{results[n0][key]['re_asymptotic']:5.3f}".rjust(last_w)
            else:
                out += f"{'---':>{last_w}}"
            print(out)

        re_row("MLE", "MLE")
        for (a1, b1), (a2, b2) in j_pairs:
            re_row(self._pair_key(a1, b1, a2, b2), f"J₁({a1},{b1})×J₂({a2},{b2})")


# ==============================================================
# Runner with ALL J-pairs
# ==============================================================

def run_etll_simulation_study_designB_robust(alpha_true=2.0, beta_true=0.5, theta_true=1.0,
                                             sample_sizes=None, j_pairs=None,
                                             n_batches=50, sims_per_batch=200,
                                             seed=123, n_quad=250, verbose=True):
    """
    Run the Design B simulation study and print its results table.

    Every argument defaults to the original study configuration, so calling
    with no arguments reproduces the original run.

    Parameters
    ----------
    alpha_true, beta_true, theta_true : float
        True ETLL parameters used to simulate data.
    sample_sizes : sequence of int, optional
        Defaults to [100, 250, 500, 1000].
    j_pairs : sequence of ((a1, b1), (a2, b2)), optional
        Kumaraswamy weight pairs; defaults to the six original configurations.
    n_batches, sims_per_batch : int
        Monte Carlo batch layout.
    seed : int or None
        Seed for the simulation's random generator.
    n_quad : int
        Number of Gauss–Legendre quadrature nodes.
    verbose : bool
        Print progress while simulating.

    Returns
    -------
    dict
        ``{n: {key: summary}}`` as produced by
        ``ETLLSimulation_DesignB.run_simulation_with_re_se_designB``.
    """
    if sample_sizes is None:
        sample_sizes = [100, 250, 500, 1000]
    if j_pairs is None:
        j_pairs = [
            ((1.0, 1.0), (1.0, 2.0)),
            ((1.0, 1.0), (0.3, 1.0)),
            ((1.0, 1.0), (1.2, 1.8)),
            ((1.0, 1.0), (0.8, 1.0)),
            ((1.0, 1.0), (4.0, 12.0)),
            ((1.0, 1.0), (2.0, 1.0)),
        ]

    print("=" * 70)
    print("ETLL Simulation Study - DESIGN B")
    print("Different J (J₁ ≠ J₂), Same h = log(x)")
    print(f"True params: α={alpha_true}, β={beta_true}, θ={theta_true}")
    print("=" * 70)

    sim = ETLLSimulation_DesignB(
        theta=theta_true,
        alpha=alpha_true,
        beta=beta_true,
        n_quad=n_quad,
        use_det_re=True,
        use_numeric_info=False,
        rng=seed,
    )

    results = sim.run_simulation_with_re_se_designB(
        sample_sizes=sample_sizes,
        j_pairs=j_pairs,
        n_batches=n_batches,
        sims_per_batch=sims_per_batch,
        verbose=verbose,
        ref_at="true",
    )

    sim.print_results_table_designB(results, sample_sizes, j_pairs)
    return results


if __name__ == "__main__":
    # Numerical edge cases (log of 0, overflow in trial parameters) are
    # expected during root finding and optimisation; keep the log readable.
    warnings.filterwarnings("ignore", category=RuntimeWarning)

    t0 = time.perf_counter()
    results = run_etll_simulation_study_designB_robust()
    runtime_s = time.perf_counter() - t0
    print(f"\n⏱️ Total runtime: {runtime_s:.2f} s")

ETLL Simulation Study - DESIGN B
Different J (J₁ ≠ J₂), Same h = log(x)
True params: α=2.0, β=0.5, θ=1.0

Running n=100 with 50 batches...
  Batch 1/50
  Batch 6/50
  Batch 11/50
  Batch 16/50
  Batch 21/50
  Batch 26/50
  Batch 31/50
  Batch 36/50
  Batch 41/50
  Batch 46/50

Running n=250 with 50 batches...
  Batch 1/50
  Batch 6/50
  Batch 11/50
  Batch 16/50
  Batch 21/50
  Batch 26/50
  Batch 31/50
  Batch 36/50
  Batch 41/50
  Batch 46/50

Running n=500 with 50 batches...
  Batch 1/50
  Batch 6/50
  Batch 11/50
  Batch 16/50
  Batch 21/50
  Batch 26/50
  Batch 31/50
  Batch 36/50
  Batch 41/50
  Batch 46/50

Running n=1000 with 50 batches...
  Batch 1/50
  Batch 6/50
  Batch 11/50
  Batch 16/50
  Batch 21/50
  Batch 26/50
  Batch 31/50
  Batch 36/50
  Batch 41/50
  Batch 46/50

Design B: Standardized MEAN and RE from ETELL(α=2.0, β=0.5, θ=1.0)
Different J (J₁ ≠ J₂), Same h = log(x)
Weight Config                            n=100                       n=250                       n=