## Use original parametric

In [None]:
import numpy as np
from sklearn.linear_model import Lasso
import utils
from CoRT_builder import CoRT
import importlib
import matplotlib.pyplot as plt
import oc
import parametric_optim
# Re-import the local helper modules so edits made on disk are picked up
# without restarting the kernel.
importlib.reload(utils)
importlib.reload(oc)

# --- Data-generation settings (all forwarded to CoRT_model.gen_data) ---
# n_target is also the size of the identity covariance built in the loop;
# p is the number of trailing Lasso coefficients treated as target coefficients.
n_target, n_source = 30, 10
p = 10
# NOTE(review): K / Ka / h are passed straight to gen_data; presumably the
# number of sources, the number of informative sources and a transfer
# parameter -- confirm against CoRT_builder.
K, Ka = 3, 1
h = 30
# Features with index < s are labelled as true signals by the loop below.
s_vector = [1] * 3
s = len(s_vector)

# --- Estimation settings ---
lamda = 0.1  # Lasso penalty, used for both CoRT and the sklearn Lasso fit
T = 3        # handed to find_similar_source / utils.split_target (fold count, presumably)
CoRT_model = CoRT(alpha=lamda)

# --- Experiment settings ---
alpha = 0.05      # rejection threshold applied to the stored p-values
iteration = 200   # number of Monte-Carlo repetitions

# One dict per tested feature: {"p_value", "is_signal", "feature_idx"}.
oc_results_storage = []
para_results_storage = []

# Monte-Carlo loop. Each pass draws a fresh dataset, selects features with a
# Lasso fit on the CoRT-combined data, picks one selected feature at random
# and records two p-values for it: one from the over-conditioning interval
# and one from the parametric line search over z in [-20, 20].
for i in range(iteration):
    if i % 50 == 0:
        print(f"Processing iter: {i}")

    target_data, source_data = CoRT_model.gen_data(n_target, n_source, p, K, Ka, h, s_vector, s, "AR")
    similar_source_index = CoRT_model.find_similar_source(n_target, K, target_data, source_data, T=T, verbose=False)
    X_combined, y_combined = CoRT_model.prepare_CoRT_data(similar_source_index, source_data, target_data)

    model = Lasso(alpha=lamda, fit_intercept=False, tol=1e-10, max_iter=10000000)
    model.fit(X_combined, y_combined.ravel())
    # Only the trailing p coefficients are treated as the target coefficients.
    beta_hat_target = model.coef_[-p:]

    # Indices (within the target block) of the non-zero coefficients.
    active_indices = np.flatnonzero(beta_hat_target)

    if len(active_indices) == 0:
        # Report the loop counter (the original formatted the builtin `iter`).
        print(f"Iteration {i}: Lasso selected no features. Skipping.")
        continue

    # j is a position inside the active set; selected_feature_index is the
    # corresponding original feature index.
    j = np.random.choice(len(active_indices))
    selected_feature_index = active_indices[j]
    # Features with index below s are labelled as true signals.
    is_signal = (selected_feature_index < s)

    X_target = target_data["X"]
    y_target = target_data["y"]
    X_active, X_inactive = utils.get_active_X(beta_hat_target, X_target)

    etaj, etajTy = utils.construct_test_statistic(y_target, j, X_active)

    # Decompose y = a_global + b_global * (etaj^T y). The residual projector
    # must use the identity matrix, not Sigma; the two coincide here only
    # because Sigma is the identity.
    Sigma = np.eye(n_target)
    b_global = Sigma @ etaj @ np.linalg.pinv(etaj.T @ Sigma @ etaj)
    a_global = (np.eye(n_target) - b_global @ etaj.T) @ y_target

    folds = utils.split_target(T, X_target, y_target, n_target)

    # OVER-CONDITIONING
    # Each helper returns an interval [L, R] around the observed statistic;
    # oc.combine_Z merges the three into one interval.
    L_train, R_train = oc.get_Z_train(etajTy, folds, source_data, a_global, b_global, lamda, K, T)
    L_val, R_val = oc.get_Z_val(folds, T, K, a_global, b_global, etajTy, lamda, source_data)
    L_CoRT, R_CoRT, Az = oc.get_Z_CoRT(X_combined, similar_source_index, lamda, a_global, b_global, source_data, etajTy)

    L_final, R_final = oc.combine_Z(L_train, R_train, L_val, R_val, L_CoRT, R_CoRT)

    etaT_sigma_eta = (etaj.T @ Sigma @ etaj).item()
    sigma_z = np.sqrt(etaT_sigma_eta)
    truncated_cdf = utils.computed_truncated_cdf(L_final, R_final, etajTy, 0, sigma_z)
    # Two-sided p-value from the truncated CDF.
    oc_p_value = 2 * min(truncated_cdf, 1 - truncated_cdf)

    oc_results_storage.append({
        "p_value": oc_p_value,
        "is_signal": is_signal,
        "feature_idx": selected_feature_index
    })

    # PARAMETRIC
    # Walk z from z_k up to z_max, recording at every breakpoint the active
    # set that holds on the interval starting there.
    z_k = -20
    z_max = 20

    # Train entries keep their [L, R] in positions 3 and 4, validation
    # entries in positions 2 and 3 (see the update steps below).
    Z_train_list = parametric_optim.get_Z_train(z_k, folds, source_data, a_global, b_global, lamda, K, T)
    Z_val_list = parametric_optim.get_Z_val(z_k, folds, T, K, a_global, b_global, lamda, source_data)

    target_data_current = {"X": X_target, "y": a_global + z_k * b_global}
    similar_source_current = parametric_optim.find_similar_source(z_k, a_global, b_global, lamda,  n_target, K, target_data_current, source_data, T=T, verbose=False)
    X_combined_new, y_combined_new = CoRT_model.prepare_CoRT_data(similar_source_current, source_data, target_data_current)
    L_CoRT, R_CoRT, Az = parametric_optim.get_Z_CoRT(X_combined_new, similar_source_current, lamda, a_global, b_global, source_data, z_k)

    z_list = [z_k]
    Az_list = []

    while z_k < z_max:
        # Keep only the entries of Az at or beyond p * (#selected sources)
        # and shift them to 0-based target feature indices.
        offset = p * len(similar_source_current)
        Az_list.append(np.array([idx - offset for idx in Az if idx >= offset]))

        # Find the smallest right end-point among all current intervals and
        # remember which stage produced it. Ties go to the earlier stage
        # (TRAIN, then VAL, then CORT) because the comparison is strict.
        mn = z_max
        stopper = "MAX"
        candidates = (
            [(val[4], "TRAIN") for val in Z_train_list]
            + [(val[3], "VAL") for val in Z_val_list]
            + [(R_CoRT, "CORT")]
        )
        for bound, label in candidates:
            if mn > bound:
                mn = bound
                stopper = label

        R_final = mn

        if R_final - z_k < -1e-9:
            # The interval ends before the current point; nudge forward so
            # the walk cannot stall.
            print("[WARNING] R_final is before zk")
            z_k += 0.001

        # Step just past the breakpoint into the next interval.
        z_k = max(R_final, z_k) + 1e-5

        if z_k >= z_max:
            z_list.append(z_max)
            break
        z_list.append(z_k)

        # A change in an earlier stage invalidates every later stage, so the
        # refresh cascades TRAIN -> VAL -> CORT.
        if stopper == "TRAIN":
            for val in Z_train_list:
                # Only refresh train intervals that z_k has already passed.
                if val[4] <= z_k + 1e-9:
                    l, r = parametric_optim.update_Z_train(val, z_k, folds, source_data, a_global, b_global, lamda, K, T)
                    val[3] = l
                    val[4] = r

        if stopper in ("TRAIN", "VAL"):
            for val in Z_val_list:
                l, r = parametric_optim.update_Z_val(val, z_k, folds, T, K, a_global, b_global, lamda, source_data)
                val[2] = l
                val[3] = r

        if stopper in ("TRAIN", "VAL", "CORT"):
            target_data_current = {"X": X_target, "y": a_global + z_k * b_global}
            similar_source_current = parametric_optim.find_similar_source(z_k, a_global, b_global, lamda, n_target, K, target_data_current, source_data, T=T, verbose=False)
            X_combined_new, y_combined_new = CoRT_model.prepare_CoRT_data(similar_source_current, source_data, target_data_current)
            L_CoRT, R_CoRT, Az = parametric_optim.get_Z_CoRT(X_combined_new, similar_source_current, lamda, a_global, b_global, source_data, z_k)

    para_p_value = parametric_optim.pivot(active_indices, Az_list, z_list, etaj, etajTy, 0, Sigma)

    if para_p_value is None:
        print(" WARNING: p-value is None")
    elif para_p_value == 0:
        print(" WARNING: p-value is 0")

    para_results_storage.append({
        "p_value": para_p_value,
        "is_signal": is_signal,
        "feature_idx": selected_feature_index
    })

def _rejection_summary(cases, level):
    """Return (rejections, rate) for the p-values stored in *cases*.

    Entries whose "p_value" is None cannot be compared with *level* and are
    left out of both the count and the denominator. When nothing remains the
    rate is NaN, so an empty group no longer raises ZeroDivisionError.
    """
    p_values = [c['p_value'] for c in cases if c['p_value'] is not None]
    rejections = sum(1 for pv in p_values if pv <= level)
    rate = rejections / len(p_values) if p_values else float("nan")
    return rejections, rate


# Show Over-conditioning result
# FPR is measured on the non-signal features, TPR on the signal features.
oc_is_signal_cases = [r for r in oc_results_storage if r['is_signal']]
oc_not_signal_cases = [r for r in oc_results_storage if not r['is_signal']]

oc_false_positives, oc_fpr = _rejection_summary(oc_not_signal_cases, alpha)
print(f"Over-conditioning FPR: {oc_fpr:.4f} (Target: {alpha})")

oc_true_positives, oc_tpr = _rejection_summary(oc_is_signal_cases, alpha)
print(f"Over-conditioning TPR: {oc_tpr:.4f}")
print("="*50)

# Show parametric result
para_is_signal_cases = [r for r in para_results_storage if r['is_signal']]
para_not_signal_cases = [r for r in para_results_storage if not r['is_signal']]

para_false_positives, para_fpr = _rejection_summary(para_not_signal_cases, alpha)
print(f"Parametric FPR: {para_fpr:.4f} (Target: {alpha})")

para_true_positives, para_tpr = _rejection_summary(para_is_signal_cases, alpha)
print(f"Parametric TPR: {para_tpr:.4f}")

Processing iter: 0


KeyboardInterrupt: 

## Use optimized parametric

In [None]:
import numpy as np
from sklearn.linear_model import Lasso
import utils
from CoRT_builder import CoRT
import parametric_optim
import importlib
import matplotlib.pyplot as plt
import oc
# Re-import the local helper modules so edits made on disk are picked up
# without restarting the kernel.
importlib.reload(utils)
importlib.reload(oc)

# --- Data-generation settings (all forwarded to CoRT_model.gen_data) ---
# n_target is also the size of the identity covariance built in the loop;
# p is the number of trailing Lasso coefficients treated as target coefficients.
n_target, n_source = 30, 10
p = 10
# NOTE(review): K / Ka / h are passed straight to gen_data; presumably the
# number of sources, the number of informative sources and a transfer
# parameter -- confirm against CoRT_builder.
K, Ka = 3, 1
h = 30
# Features with index < s are labelled as true signals by the loop below.
s_vector = [1, 1, 1]
s = len(s_vector)

# --- Estimation settings ---
lamda = 0.1  # Lasso penalty, used for both CoRT and the sklearn Lasso fit
T = 3        # handed to find_similar_source / utils.split_target (fold count, presumably)
CoRT_model = CoRT(alpha=lamda)

# --- Experiment settings ---
alpha = 0.05       # rejection threshold applied to the stored p-values
iteration = 1000   # number of Monte-Carlo repetitions

# One dict per tested feature: {"p_value", "is_signal", "feature_idx"}.
oc_results_storage = []
para_results_storage = []

# Monte-Carlo loop. Each pass draws a fresh dataset, selects features with a
# Lasso fit on the CoRT-combined data, picks one selected feature at random
# and records two p-values for it: one from the over-conditioning interval
# and one from the parametric line search over z in [-20, 20].
for i in range(iteration):
    if i % 100 == 0:
        print(f"Processing iter: {i}")

    target_data, source_data = CoRT_model.gen_data(n_target, n_source, p, K, Ka, h, s_vector, s, "AR")
    similar_source_index = CoRT_model.find_similar_source(n_target, K, target_data, source_data, T=T, verbose=False)
    X_combined, y_combined = CoRT_model.prepare_CoRT_data(similar_source_index, source_data, target_data)

    model = Lasso(alpha=lamda, fit_intercept=False, tol=1e-10, max_iter=10000000)
    model.fit(X_combined, y_combined.ravel())
    # Only the trailing p coefficients are treated as the target coefficients.
    beta_hat_target = model.coef_[-p:]

    # Indices (within the target block) of the non-zero coefficients.
    active_indices = np.flatnonzero(beta_hat_target)

    if len(active_indices) == 0:
        # Report the loop counter (the original formatted the builtin `iter`).
        print(f"Iteration {i}: Lasso selected no features. Skipping.")
        continue

    # j is a position inside the active set; selected_feature_index is the
    # corresponding original feature index.
    j = np.random.choice(len(active_indices))
    selected_feature_index = active_indices[j]
    # Features with index below s are labelled as true signals.
    is_signal = (selected_feature_index < s)

    X_target = target_data["X"]
    y_target = target_data["y"]
    X_active, X_inactive = utils.get_active_X(beta_hat_target, X_target)

    etaj, etajTy = utils.construct_test_statistic(y_target, j, X_active)

    # Decompose y = a_global + b_global * (etaj^T y). The residual projector
    # must use the identity matrix, not Sigma; the two coincide here only
    # because Sigma is the identity.
    Sigma = np.eye(n_target)
    b_global = Sigma @ etaj @ np.linalg.pinv(etaj.T @ Sigma @ etaj)
    a_global = (np.eye(n_target) - b_global @ etaj.T) @ y_target

    folds = utils.split_target(T, X_target, y_target, n_target)

    # OVER-CONDITIONING
    # Each helper returns an interval [L, R] around the observed statistic;
    # oc.combine_Z merges the three into one interval.
    L_train, R_train = oc.get_Z_train(etajTy, folds, source_data, a_global, b_global, lamda, K, T)
    L_val, R_val = oc.get_Z_val(folds, T, K, a_global, b_global, etajTy, lamda, source_data)
    L_CoRT, R_CoRT, Az = oc.get_Z_CoRT(X_combined, similar_source_index, lamda, a_global, b_global, source_data, etajTy)

    L_final, R_final = oc.combine_Z(L_train, R_train, L_val, R_val, L_CoRT, R_CoRT)

    etaT_sigma_eta = (etaj.T @ Sigma @ etaj).item()
    sigma_z = np.sqrt(etaT_sigma_eta)
    truncated_cdf = utils.computed_truncated_cdf(L_final, R_final, etajTy, 0, sigma_z)
    # Two-sided p-value from the truncated CDF.
    oc_p_value = 2 * min(truncated_cdf, 1 - truncated_cdf)

    oc_results_storage.append({
        "p_value": oc_p_value,
        "is_signal": is_signal,
        "feature_idx": selected_feature_index
    })

    # PARAMETRIC
    # Walk z from z_k up to z_max, recording at every breakpoint the active
    # set that holds on the interval starting there.
    z_k = -20
    z_max = 20

    # Train entries keep their [L, R] in positions 3 and 4, validation
    # entries in positions 2 and 3 (see the update steps below).
    Z_train_list = parametric_optim.get_Z_train(z_k, folds, source_data, a_global, b_global, lamda, K, T)
    Z_val_list = parametric_optim.get_Z_val(z_k, folds, T, K, a_global, b_global, lamda, source_data)

    target_data_current = {"X": X_target, "y": a_global + z_k * b_global}
    similar_source_current = parametric_optim.find_similar_source(z_k, a_global, b_global, lamda,  n_target, K, target_data_current, source_data, T=T, verbose=False)
    X_combined_new, y_combined_new = CoRT_model.prepare_CoRT_data(similar_source_current, source_data, target_data_current)
    L_CoRT, R_CoRT, Az = parametric_optim.get_Z_CoRT(X_combined_new, similar_source_current, lamda, a_global, b_global, source_data, z_k)

    z_list = [z_k]
    Az_list = []

    while z_k < z_max:
        # Keep only the entries of Az at or beyond p * (#selected sources)
        # and shift them to 0-based target feature indices.
        offset = p * len(similar_source_current)
        Az_list.append(np.array([idx - offset for idx in Az if idx >= offset]))

        # Find the smallest right end-point among all current intervals and
        # remember which stage produced it. Ties go to the earlier stage
        # (TRAIN, then VAL, then CORT) because the comparison is strict.
        mn = z_max
        stopper = "MAX"
        candidates = (
            [(val[4], "TRAIN") for val in Z_train_list]
            + [(val[3], "VAL") for val in Z_val_list]
            + [(R_CoRT, "CORT")]
        )
        for bound, label in candidates:
            if mn > bound:
                mn = bound
                stopper = label

        R_final = mn

        if R_final - z_k < -1e-9:
            # The interval ends before the current point; nudge forward so
            # the walk cannot stall.
            print("[WARNING] R_final is before zk")
            z_k += 0.001

        # Step just past the breakpoint into the next interval.
        z_k = max(R_final, z_k) + 1e-5

        if z_k >= z_max:
            z_list.append(z_max)
            break
        z_list.append(z_k)

        # A change in an earlier stage invalidates every later stage, so the
        # refresh cascades TRAIN -> VAL -> CORT.
        if stopper == "TRAIN":
            for val in Z_train_list:
                # Only refresh train intervals that z_k has already passed.
                if val[4] <= z_k + 1e-9:
                    l, r = parametric_optim.update_Z_train(val, z_k, folds, source_data, a_global, b_global, lamda, K, T)
                    val[3] = l
                    val[4] = r

        if stopper in ("TRAIN", "VAL"):
            for val in Z_val_list:
                l, r = parametric_optim.update_Z_val(val, z_k, folds, T, K, a_global, b_global, lamda, source_data)
                val[2] = l
                val[3] = r

        if stopper in ("TRAIN", "VAL", "CORT"):
            target_data_current = {"X": X_target, "y": a_global + z_k * b_global}
            similar_source_current = parametric_optim.find_similar_source(z_k, a_global, b_global, lamda, n_target, K, target_data_current, source_data, T=T, verbose=False)
            X_combined_new, y_combined_new = CoRT_model.prepare_CoRT_data(similar_source_current, source_data, target_data_current)
            L_CoRT, R_CoRT, Az = parametric_optim.get_Z_CoRT(X_combined_new, similar_source_current, lamda, a_global, b_global, source_data, z_k)

    para_p_value = parametric_optim.pivot(active_indices, Az_list, z_list, etaj, etajTy, 0, Sigma)

    if para_p_value is None:
        print(" WARNING: p-value is None")
    elif para_p_value == 0:
        print(" WARNING: p-value is 0")

    para_results_storage.append({
        "p_value": para_p_value,
        "is_signal": is_signal,
        "feature_idx": selected_feature_index
    })

def _rejection_summary(cases, level):
    """Return (rejections, rate) for the p-values stored in *cases*.

    Entries whose "p_value" is None cannot be compared with *level* and are
    left out of both the count and the denominator. When nothing remains the
    rate is NaN, so an empty group no longer raises ZeroDivisionError.
    """
    p_values = [c['p_value'] for c in cases if c['p_value'] is not None]
    rejections = sum(1 for pv in p_values if pv <= level)
    rate = rejections / len(p_values) if p_values else float("nan")
    return rejections, rate


# Show Over-conditioning result
# FPR is measured on the non-signal features, TPR on the signal features.
oc_is_signal_cases = [r for r in oc_results_storage if r['is_signal']]
oc_not_signal_cases = [r for r in oc_results_storage if not r['is_signal']]

oc_false_positives, oc_fpr = _rejection_summary(oc_not_signal_cases, alpha)
print(f"Over-conditioning FPR: {oc_fpr:.4f} (Target: {alpha})")

oc_true_positives, oc_tpr = _rejection_summary(oc_is_signal_cases, alpha)
print(f"Over-conditioning TPR: {oc_tpr:.4f}")
print("="*50)

# Show parametric result
para_is_signal_cases = [r for r in para_results_storage if r['is_signal']]
para_not_signal_cases = [r for r in para_results_storage if not r['is_signal']]

para_false_positives, para_fpr = _rejection_summary(para_not_signal_cases, alpha)
print(f"Parametric FPR: {para_fpr:.4f} (Target: {alpha})")

para_true_positives, para_tpr = _rejection_summary(para_is_signal_cases, alpha)
print(f"Parametric TPR: {para_tpr:.4f}")

Processing iter: 0
Processing iter: 20
Processing iter: 40
Processing iter: 60
Processing iter: 80
Processing iter: 100
Processing iter: 120
Processing iter: 140
Processing iter: 160
Processing iter: 180
Processing iter: 200
Processing iter: 220
Processing iter: 240
Processing iter: 260
Processing iter: 280
Processing iter: 300
Processing iter: 320
Processing iter: 340
Processing iter: 360
Processing iter: 380
Processing iter: 400
Processing iter: 420
Processing iter: 440
Processing iter: 460
Processing iter: 480
Processing iter: 500
Processing iter: 520
Processing iter: 540
Processing iter: 560
Processing iter: 580
Processing iter: 600
Processing iter: 620
Processing iter: 640
Processing iter: 660
Processing iter: 680
Processing iter: 700
Processing iter: 720
Processing iter: 740
Processing iter: 760
Processing iter: 780
Processing iter: 800
Processing iter: 820
Processing iter: 840
Processing iter: 860
Processing iter: 880
Processing iter: 900
Processing iter: 920
Processing iter: 94