In [None]:
import numpy as np
import pandas as pd
from joblib import Parallel, delayed

#sys.path.append("../")
from helpers.trainer import train_mse, train_HSIC_IV
from models.kernel import CategoryKernel, RBFKernel
from models.hsicx import NNHSICX
from helpers.utils import med_sigma, to_torch

# Training hyper-parameters.
# `config_hsic` is passed to train_HSIC_IV, `config_mse` to train_mse.
config_hsic = dict(
    batch_size=256,
    lr=1e-3,
    max_epoch=700,
    num_restart=4,
)
config_mse = dict(
    batch_size=256,
    lr=5e-4,
    max_epoch=300,
)

# DGP paths (train and test).
# Both lists are generated from the same ordered list of DGP names, so
# train_files[k] and test_files[k] always refer to the same DGP.  The order is:
# for each outcome function (lin, log_case, sin_lin), first the Zbin file and
# then the Zcont file.
_DATA_DIR = 'data/sec5.1'
_DGP_NAMES = [
    f'Z{z_kind}_g_lin_f_{f_kind}'
    for f_kind in ('lin', 'log_case', 'sin_lin')
    for z_kind in ('bin', 'cont')
]

train_files = [f'{_DATA_DIR}/train_{dgp_name}.csv' for dgp_name in _DGP_NAMES]

test_files = [f'{_DATA_DIR}/test_{dgp_name}.csv' for dgp_name in _DGP_NAMES]

# Function to run HSIC and MSE models for a single DGP
def run_for_dgp(train_file, test_file, dgp_idx, n_reps=10, n_jobs=10):
    """Fit the HSIC-IV model repeatedly on one DGP and save its predictions.

    For every repetition an MSE network is trained first and handed to the
    HSIC network via ``load_state_dict`` before the HSIC-IV training starts.
    The intercept-adjusted HSIC predictions on the two test inputs are
    collected as one column per repetition (``Run_1`` ... ``Run_<n_reps>``)
    and written to
    ``results/sec5.1/hsic_result_mse_dgp<dgp_idx>.csv`` and
    ``results/sec5.1/hsic_result_plot_dgp<dgp_idx>.csv``.

    Args:
        train_file: CSV with a header row; columns 1, 2 and 3 are read as
            the instrument Z, the covariate X and the outcome Y.
        test_file: CSV with a header row; column 0 is read as the test
            inputs and column 4 as the grid of inputs used for plotting.
        dgp_idx: Identifier of the DGP; only used in the output file names.
        n_reps: Number of independent repetitions (default 10).
        n_jobs: Number of joblib workers (default 10).

    Raises:
        ValueError: If ``n_reps`` is smaller than 1.
    """
    import os  # local import: only needed to create the output directory

    if n_reps < 1:
        raise ValueError(f'n_reps must be at least 1, got {n_reps}')

    # Load data
    train_data = np.genfromtxt(train_file, delimiter=',', skip_header=1)
    test_data = np.genfromtxt(test_file, delimiter=',', skip_header=1)

    # NOTE(review): column positions are taken as-is from the original code;
    # confirm them against the CSV headers.
    Z, X, Y = train_data[:, 1], train_data[:, 2], train_data[:, 3]
    X_test = test_data[:, 0].astype(np.float32)
    X_test_grid = test_data[:, 4].astype(np.float32)

    # Decide the instrument type from the instrument column itself.  The
    # previous rule (even dgp_idx -> Binary, odd -> Continuous) was inverted
    # for callers that number the DGPs from 1 with the Zbin file first, which
    # paired the binary instruments with the RBF kernel and vice versa.
    z_values = np.unique(Z[~np.isnan(Z)])
    if z_values.size <= 2:
        instrument = 'Binary'
    else:
        instrument = 'Continuous'

    # Function to run a single repetition
    def rep_function(i):
        # Pure predictive (MSE) fit; the trained network is handed to the
        # HSIC network below.
        mse_net = NNHSICX(input_dim=1, lr=config_mse['lr'], lmd=-99)
        mse_net = train_mse(mse_net, config_mse, X, Y, Z)

        # HSIC IV setup
        kernel_e = RBFKernel(sigma=1)

        if instrument == 'Binary':
            kernel_z = CategoryKernel()
        else:
            # The bandwidth from med_sigma is only needed for the RBF kernel.
            kernel_z = RBFKernel(sigma=med_sigma(Z))

        # Non-regularized HSIC IV
        hsic_net = NNHSICX(input_dim=1, lr=config_hsic['lr'], kernel_e=kernel_e, kernel_z=kernel_z, lmd=0)
        # NOTE(review): the model object itself is passed here, not
        # mse_net.state_dict().  A stock torch.nn.Module.load_state_dict
        # rejects that, so this presumably relies on NNHSICX overriding the
        # method -- confirm in models/hsicx.py.
        hsic_net.load_state_dict(mse_net)
        hsic_net = train_HSIC_IV(hsic_net, config_hsic, X, Y, Z, verbose=True)

        # Shift the predictions so that their mean on the training inputs
        # equals the mean of Y.
        intercept_adjust = Y.mean() - hsic_net(to_torch(X)).mean().detach()

        y_hat_hsic = intercept_adjust + hsic_net(to_torch(X_test)).detach()
        y_hat_hsic = y_hat_hsic.numpy().copy()

        y_hat_hsic_grid = intercept_adjust + hsic_net(to_torch(X_test_grid)).detach()
        y_hat_hsic_grid = y_hat_hsic_grid.numpy().copy()

        # Each repetition fills its own frames.  Mutating frames shared
        # through the closure only worked because process-based workers get
        # pickled copies; with n_jobs=1 or a threading backend every
        # repetition would return the same accumulated frame and the concat
        # below would duplicate columns.
        rep_mse = pd.DataFrame()
        rep_plot = pd.DataFrame()
        rep_mse[f'Run_{i+1}'] = y_hat_hsic
        rep_plot[f'Run_{i+1}'] = y_hat_hsic_grid

        return rep_mse, rep_plot

    # Run the repetitions in parallel
    results = Parallel(n_jobs=n_jobs)(delayed(rep_function)(i=i) for i in range(n_reps))

    # Combine the results
    df_mse_combined = pd.concat([result[0] for result in results], axis=1)
    df_plot_combined = pd.concat([result[1] for result in results], axis=1)

    # Save results for the current DGP; create the directory first so a
    # missing folder does not discard the finished training runs.
    out_dir = 'results/sec5.1'
    os.makedirs(out_dir, exist_ok=True)
    df_mse_combined.to_csv(f'{out_dir}/hsic_result_mse_dgp{dgp_idx}.csv', index=False)
    df_plot_combined.to_csv(f'{out_dir}/hsic_result_plot_dgp{dgp_idx}.csv', index=False)

# Process every (train, test) file pair in order, numbering the DGPs from 1.
dgp_file_pairs = zip(train_files, test_files)
for idx, (train_file, test_file) in enumerate(dgp_file_pairs, start=1):
    run_for_dgp(train_file=train_file, test_file=test_file, dgp_idx=idx)
