In [None]:
import sys
import numpy as np
import pandas as pd

#sys.path.append("../")
from helpers.trainer import train_mse, train_HSIC_IV
from models.kernel import CategoryKernel, RBFKernel
from models.hsicx import NNHSICX, LinearHSICX, PolyHSICX
from helpers.utils import med_sigma, to_torch

from joblib import Parallel, delayed
import matplotlib.pyplot as plt


In [None]:
# Define the function to process a single alpha
def process_alpha(alpha, runs, config_mse, config_hsic, instrument,
                  data_dir='data/sec5.4', results_dir='results/sec5.4'):
    """Fit HSIC-X ``runs`` times for one ``alpha`` and save the test predictions.

    Each run first trains an ``NNHSICX`` network with ``train_mse``. That
    network's weights warm-start a second ``NNHSICX`` network, which is then
    trained with ``train_HSIC_IV``. The HSIC fit's predictions are shifted by
    a constant so that their mean on the training inputs equals ``Y.mean()``.
    The shifted predictions on the test inputs are stored as one column per
    run.

    Args:
        alpha: Selects the ``*_Zcont_g_fct_f_softplusalpha{alpha}.csv``
            train/test files and names the output file.
        runs: Number of independent training runs (one output column each).
        config_mse: Config passed to ``train_mse``; must contain ``'lr'``.
        config_hsic: Config passed to ``train_HSIC_IV``; must contain ``'lr'``.
        instrument: ``'Binary'`` selects ``CategoryKernel`` for Z. Any other
            value selects ``RBFKernel`` with bandwidth ``med_sigma(Z)``.
        data_dir: Directory containing the train/test CSV files.
        results_dir: Directory the result CSV is written to. It is created
            if it does not already exist.

    Side effects:
        Writes ``{results_dir}/hsic_result_mse_alpha{alpha}.csv`` with columns
        ``Run_1`` .. ``Run_{runs}`` and no index column.
    """
    import os

    import torch

    # Create the output directory before the (slow) training loop. Otherwise
    # a missing directory would only surface at the final to_csv call.
    os.makedirs(results_dir, exist_ok=True)

    train_data = np.genfromtxt(
        f'{data_dir}/train_Zcont_g_fct_f_softplusalpha{alpha}.csv',
        delimiter=',', skip_header=1)
    test_data = np.genfromtxt(
        f'{data_dir}/test_Zcont_g_fct_f_softplusalpha{alpha}.csv',
        delimiter=',', skip_header=1)

    # Training CSV columns are read as (Z, X, Y); the test inputs are column 0.
    Z, X, Y = train_data[:, 0], train_data[:, 1], train_data[:, 2]
    X_test = test_data[:, 0].astype(np.float32)
    # X_test_grid = test_data[:, 2].astype(np.float32)

    # The Z-kernel bandwidth depends only on the data, so compute it once
    # rather than once per run.
    s_z = med_sigma(Z)

    df_mse = pd.DataFrame()
    # df_plot = pd.DataFrame()

    for run in range(runs):
        # Train the pure predictive (MSE) model; it is the warm start below.
        mse_net = NNHSICX(input_dim=1, lr=config_mse['lr'], lmd=-99)
        mse_net = train_mse(mse_net, config_mse, X, Y, Z)

        # HSIC IV model setup
        kernel_e = RBFKernel(sigma=1)
        if instrument == 'Binary':
            kernel_z = CategoryKernel()
        else:
            kernel_z = RBFKernel(sigma=s_z)

        hsic_net = NNHSICX(input_dim=1,
                           lr=config_hsic['lr'],
                           kernel_e=kernel_e,
                           kernel_z=kernel_z,
                           lmd=0)

        # load_state_dict expects a parameter mapping, not a module. Take the
        # state dict from the trained model. If train_mse already returned a
        # state dict (NOTE(review): confirm its return type), use it as-is.
        if hasattr(mse_net, 'state_dict'):
            mse_state = mse_net.state_dict()
        else:
            mse_state = mse_net
        hsic_net.load_state_dict(mse_state)
        hsic_net = train_HSIC_IV(hsic_net, config_hsic, X, Y, Z, verbose=True)

        # Inference only: no autograd graph is needed for the predictions.
        with torch.no_grad():
            # Shift predictions so their training-set mean matches Y's mean.
            intercept_adjust = Y.mean() - hsic_net(to_torch(X)).mean()
            y_hat_hsic = intercept_adjust + hsic_net(to_torch(X_test))
        # Flatten in case the network returns shape (n, 1), so the result
        # fits a single DataFrame column.
        y_hat_hsic = y_hat_hsic.detach().cpu().numpy().ravel().copy()

        # y_hat_hsic_grid = intercept_adjust + hsic_net(to_torch(X_test_grid))
        # y_hat_hsic_grid = y_hat_hsic_grid.detach().cpu().numpy().ravel().copy()

        df_mse[f'Run_{run+1}'] = y_hat_hsic
        # df_plot[f'Run_{run+1}'] = y_hat_hsic_grid

    # Save results for the current alpha
    df_mse.to_csv(f'{results_dir}/hsic_result_mse_alpha{alpha}.csv', index=False)
    # df_plot.to_csv(f'{results_dir}/hsic_result_plot_alpha{alpha}.csv', index=False)

    # Optional plot
    # plt.scatter(X_test_grid, df_plot.iloc[:, -1], s=1, color='red', label=f'HSIC-X Alpha {alpha}')
    # plt.scatter(X_test_grid, test_data[:, 3], s=1)
    # plt.legend()
    # plt.show()

# Parallel execution: run process_alpha once per alpha value with joblib.
if __name__ == "__main__":
    # Parallel and delayed come from the joblib import at the top of the file;
    # re-importing them here was redundant.

    # Input values
    alpha_values = [0, 1, 5]
    runs = 10
    config_hsic = {'batch_size': 256, 'lr': 1e-3,
                   'max_epoch': 700, 'num_restart': 4}

    config_mse = {'batch_size': 256, 'lr': 1e-3,
                  'max_epoch': 300}
    instrument = 'Continuous'  # or 'Binary'

    # The alphas are independent jobs. process_alpha writes its own CSV, so
    # the list returned by Parallel is not used.
    Parallel(n_jobs=-1)(
        delayed(process_alpha)(alpha, runs, config_mse, config_hsic, instrument)
        for alpha in alpha_values
    )