In [None]:
import numpy as np
from sklearn.linear_model import Lasso
import matplotlib.pyplot as plt
from joblib import Parallel, delayed 
import oc
import warnings
import copy


from CoRT_builder import CoRT
import utils
import parametric_optim

def run_single_trial(seed, n_target, n_source, p, K, Ka, h, lamda, s_vector, T,
                     z_min=-20, z_max=20):
    """Run one simulation trial and compute two selective p-values.

    The trial generates data with ``CoRT.gen_data``, picks similar sources,
    fits a Lasso on the combined design, and randomly selects one feature
    whose coefficient (among the last ``p`` Lasso coefficients) is nonzero.
    For that feature it computes:

    1. an "over-conditioning" p-value from the single interval obtained by
       combining the train / validation / CoRT intervals from ``oc``;
    2. a "parametric" p-value obtained by walking ``z`` from ``z_min`` to
       ``z_max`` along ``y(z) = a_global + z * b_global``, recording the
       breakpoints and the active set seen on each segment, and handing them
       to ``parametric_optim.pivot``.

    Parameters
    ----------
    seed : int
        Seed for the global NumPy RNG (set at the start of the trial).
    n_target, n_source, p, K, Ka, h : int
        Forwarded to ``CoRT.gen_data``. ``n_target`` is also the size of the
        identity covariance used for the test statistic, and ``p`` is the
        width of each coefficient block in the combined model.
        NOTE(review): presumably sample sizes, feature count, number of
        sources, number of informative sources and a shift level -- confirm
        against ``CoRT_builder``.
    lamda : float
        Lasso penalty; also passed to ``CoRT`` and to every interval helper.
    s_vector : sequence
        Forwarded to ``CoRT.gen_data``. Its length ``s`` defines which
        features count as signal (index ``< s``).
    T : int
        Forwarded to the source-selection and fold-splitting helpers.
        NOTE(review): presumably the number of folds -- confirm.
    z_min, z_max : number, optional
        Start and end of the parametric path. The defaults (-20, 20)
        reproduce the previously hard-coded range.

    Returns
    -------
    tuple of (dict, dict) or None
        ``(oc_result, parametric_result)``, each with keys ``"p_value"``,
        ``"is_signal"`` and ``"feature_idx"``. ``None`` when the Lasso
        selects no feature in the last ``p`` coefficients.
    """
    # Seed the global RNG so every parallel worker is reproducible.
    np.random.seed(seed)

    s = len(s_vector)
    CoRT_model = CoRT(alpha=lamda)

    target_data, source_data = CoRT_model.gen_data(
        n_target, n_source, p, K, Ka, h, s_vector, s, "AR"
    )
    similar_source_index = CoRT_model.find_similar_source(
        n_target, K, target_data, source_data, T=T, verbose=False
    )
    X_combined, y_combined = CoRT_model.prepare_CoRT_data(
        similar_source_index, source_data, target_data
    )

    model = Lasso(alpha=lamda, fit_intercept=False, tol=1e-10, max_iter=10000000)
    model.fit(X_combined, y_combined.ravel())
    # Only the last p coefficients are used for selection.
    beta_hat_target = model.coef_[-p:]

    # Indices of nonzero coefficients, in ascending order.
    active_indices = np.flatnonzero(beta_hat_target)
    if active_indices.size == 0:
        return None

    # Test one randomly chosen selected feature; j is its position within
    # the active set, selected_feature_index its position among p features.
    j = np.random.choice(len(active_indices))
    selected_feature_index = active_indices[j]
    is_signal = (selected_feature_index < s)

    X_target = target_data["X"]
    y_target = target_data["y"]
    X_active, _ = utils.get_active_X(beta_hat_target, X_target)

    etaj, etajTy = utils.construct_test_statistic(y_target, j, X_active)

    # Decompose y along the test direction: y(z) = a_global + z * b_global.
    Sigma = np.eye(n_target)
    b_global = Sigma @ etaj @ np.linalg.pinv(etaj.T @ Sigma @ etaj)
    a_global = (Sigma - b_global @ etaj.T) @ y_target

    folds = utils.split_target(T, X_target, y_target, n_target)

    # ------------------------------------------------------------------
    # Over-conditioning: one interval around the observed statistic.
    # ------------------------------------------------------------------
    L_train, R_train = oc.get_Z_train(
        etajTy, folds, source_data, a_global, b_global, lamda, K, T
    )
    L_val, R_val = oc.get_Z_val(
        folds, T, K, a_global, b_global, etajTy, lamda, source_data
    )
    L_CoRT_oc, R_CoRT_oc, _ = oc.get_Z_CoRT(
        X_combined, similar_source_index, lamda, a_global, b_global,
        source_data, etajTy
    )
    L_final, R_final = oc.combine_Z(
        L_train, R_train, L_val, R_val, L_CoRT_oc, R_CoRT_oc
    )

    etaT_sigma_eta = (etaj.T @ Sigma @ etaj).item()
    sigma_z = np.sqrt(etaT_sigma_eta)
    truncated_cdf = utils.computed_truncated_cdf(L_final, R_final, etajTy, 0, sigma_z)
    # Two-sided p-value from the truncated CDF.
    oc_p_value = 2 * min(truncated_cdf, 1 - truncated_cdf)

    oc_result_dict = {
        "p_value": oc_p_value,
        "is_signal": is_signal,
        "feature_idx": selected_feature_index
    }

    # ------------------------------------------------------------------
    # Parametric: follow the path z_min -> z_max.
    # ------------------------------------------------------------------
    def _cort_state_at(z):
        """Re-evaluate source selection and the CoRT interval at ``z``.

        Returns the selected sources, the right end of the CoRT interval
        containing ``z``, and the active set reported for that interval.
        """
        target_at_z = {"X": X_target, "y": a_global + z * b_global}
        sources_at_z = parametric_optim.find_similar_source(
            z, a_global, b_global, lamda, n_target, K, target_at_z,
            source_data, T=T, verbose=False
        )
        X_combined_at_z, _ = CoRT_model.prepare_CoRT_data(
            sources_at_z, source_data, target_at_z
        )
        _, right_end, active_set = parametric_optim.get_Z_CoRT(
            X_combined_at_z, sources_at_z, lamda, a_global, b_global,
            source_data, z
        )
        return sources_at_z, right_end, active_set

    z_k = z_min

    # Interval records are mutable sequences: train entries keep their
    # (left, right) ends at positions 3 and 4, validation entries at 2 and 3.
    Z_train_list = parametric_optim.get_Z_train(
        z_k, folds, source_data, a_global, b_global, lamda, K, T
    )
    Z_val_list = parametric_optim.get_Z_val(
        z_k, folds, T, K, a_global, b_global, lamda, source_data
    )
    similar_source_current, R_CoRT, Az = _cort_state_at(z_k)

    z_list = [z_k]   # breakpoints; one more entry than Az_list
    Az_list = []     # target-block active set on each segment

    while z_k < z_max:
        # Keep only indices that fall in the trailing (target) block of the
        # combined coefficient vector and shift them to 0..p-1.
        offset = p * len(similar_source_current)
        Az_list.append(
            np.sort(np.array([idx - offset for idx in Az if idx >= offset]))
        )

        # The segment ends at the smallest right boundary. `stopper` records
        # which stage owns it, which decides what must be recomputed.
        next_break = z_max
        stopper = None

        for interval in Z_train_list:
            if next_break - interval[4] > 1e-9:
                next_break = interval[4]
                stopper = "TRAIN"

        for interval in Z_val_list:
            if next_break - interval[3] > 1e-9:
                next_break = interval[3]
                stopper = "VAL"

        if next_break > R_CoRT:
            next_break = R_CoRT
            stopper = "CORT"

        # Step just past the boundary. If the boundary is stale (already
        # behind z_k), only nudge forward so the loop still makes progress.
        if next_break - z_k < -1e-9:
            z_k += 1e-5
        else:
            z_k = max(next_break, z_k) + 1e-5

        z_list.append(z_max if z_k >= z_max else z_k)

        # Stages are nested: a train change invalidates validation and CoRT,
        # a validation change invalidates CoRT. stopper is None only when no
        # boundary lies before z_max, in which case nothing is recomputed.
        if stopper == "TRAIN":
            for interval in Z_train_list:
                # Refresh only train intervals that z_k has moved past.
                if interval[4] <= z_k + 1e-9:
                    l, r = parametric_optim.update_Z_train(
                        interval, z_k, folds, source_data, a_global,
                        b_global, lamda, K, T
                    )
                    interval[3] = l
                    interval[4] = r

        if stopper in ("TRAIN", "VAL"):
            for interval in Z_val_list:
                l, r = parametric_optim.update_Z_val(
                    interval, z_k, folds, T, K, a_global, b_global, lamda,
                    source_data
                )
                interval[2] = l
                interval[3] = r

        if stopper is not None:
            similar_source_current, R_CoRT, Az = _cort_state_at(z_k)

    para_p_value = parametric_optim.pivot(
        active_indices, Az_list, z_list, etaj, etajTy, 0, Sigma
    )
    para_result_dict = {
        "p_value": para_p_value,
        "is_signal": is_signal,
        "feature_idx": selected_feature_index
    }

    return (oc_result_dict, para_result_dict)

# ==========================================
# Main Execution Block
# ==========================================
# Simulation settings. Every name below is forwarded unchanged to
# run_single_trial, except `alpha` and `iteration`.
n_target, n_source = 30, 10
p = 10
K, Ka = 3, 1
h = 30
lamda = 0.1
# Significance level; nothing in the live code of this cell reads it
# (only the commented-out reporting code below does).
alpha = 0.05
# One entry, so run_single_trial treats only feature index 0 as signal.
s_vector = [1] * 1
T = 3
# Number of independent trials; the trial number doubles as the RNG seed.
iteration = 1000

print(f"Starting {iteration} iterations in parallel...")

# Fan the trials out over every available core (n_jobs=-1). `results` holds
# one entry per trial, in seed order: a pair of result dicts, or None when
# the trial selected no feature.
results = Parallel(n_jobs=-1, verbose=10)(
    delayed(run_single_trial)(
        seed=trial_seed,
        n_target=n_target,
        n_source=n_source,
        p=p,
        K=K,
        Ka=Ka,
        h=h,
        lamda=lamda,
        s_vector=s_vector,
        T=T,
    )
    for trial_seed in range(iteration)
)

# print("\n\n")
# print("-" * 50 + "OVER-CONDITIONING" +"-" * 50)

# oc_results_storage = [res[0] for res in results if res is not None]
# para_results_storage = [res[1] for res in results if res is not None]

# oc_is_signal_cases = [r for r in oc_results_storage if r['is_signal']]
# oc_not_signal_cases = [r for r in oc_results_storage if not r['is_signal']]

# oc_false_positives = sum(1 for c in oc_not_signal_cases if c['p_value'] <= alpha)
# oc_fpr = oc_false_positives / len(oc_not_signal_cases)
# print(f"Over-conditioning FPR: {oc_fpr:.4f} (Target: {alpha})")

# oc_true_positives = sum(1 for r in oc_is_signal_cases if r['p_value'] <= alpha)
# oc_tpr = oc_true_positives / len(oc_is_signal_cases)
# print(f"Over-conditioning TPR: {oc_tpr:.4f}")
# print("\n\n")

# # Show parametric result 
# print("-" * 50 + "PARAMETRIC" + "-" * 50)
# para_is_signal_cases = [r for r in para_results_storage if r['is_signal']]
# para_not_signal_cases = [r for r in para_results_storage if not r['is_signal']]

# para_false_positives = sum(1 for c in para_not_signal_cases if c['p_value'] <= alpha)
# para_fpr = para_false_positives / len(para_not_signal_cases)
# print(f"Parametric FPR: {para_fpr:.4f} (Target: {alpha})")

# para_true_positives = sum(1 for r in para_is_signal_cases if r['p_value'] <= alpha)
# para_tpr = para_true_positives / len(para_is_signal_cases)
# print(f"Parametric TPR: {para_tpr:.4f}")

Starting 1000 iterations in parallel...


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 24 concurrent workers.
[Parallel(n_jobs=-1)]: Done   2 tasks      | elapsed:    5.9s
[Parallel(n_jobs=-1)]: Done  13 tasks      | elapsed:    6.9s
[Parallel(n_jobs=-1)]: Done  24 tasks      | elapsed:    7.9s
[Parallel(n_jobs=-1)]: Done  37 tasks      | elapsed:   11.8s
[Parallel(n_jobs=-1)]: Done  50 tasks      | elapsed:   15.0s
[Parallel(n_jobs=-1)]: Done  65 tasks      | elapsed:   17.0s
[Parallel(n_jobs=-1)]: Done  80 tasks      | elapsed:   20.5s
[Parallel(n_jobs=-1)]: Done  97 tasks      | elapsed:   22.7s
[Parallel(n_jobs=-1)]: Done 114 tasks      | elapsed:   26.2s
[Parallel(n_jobs=-1)]: Done 133 tasks      | elapsed:   30.3s
[Parallel(n_jobs=-1)]: Done 152 tasks      | elapsed:   33.9s
[Parallel(n_jobs=-1)]: Done 173 tasks      | elapsed:   38.4s
[Parallel(n_jobs=-1)]: Done 194 tasks      | elapsed:   42.5s
[Parallel(n_jobs=-1)]: Done 217 tasks      | elapsed:   47.1s
[Parallel(n_jobs=-1)]: Done 240 tasks      | elapsed:  




--------------------------------------------------OVER-CONDITIONING--------------------------------------------------
Over-conditioning FPR: 0.0460 (Target: 0.05)
Over-conditioning TPR: 0.2308



--------------------------------------------------PARAMETRIC--------------------------------------------------
Parametric FPR: 0.0509 (Target: 0.05)
Parametric TPR: 0.7692


[Parallel(n_jobs=-1)]: Done 1000 out of 1000 | elapsed:  3.5min finished
