In [None]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from sklearn.preprocessing import OneHotEncoder

sys.path.append("../")
from helpers.trainer import train_mse, train_HSIC_IV
from models.kernel import CategoryKernel, RBFKernel
from models.hsicx import NNHSICX
from helpers.utils import med_sigma, to_torch


In [None]:
# Optimiser settings for the two training stages. `config_hsic` is passed to
# train_HSIC_IV (it additionally carries 'num_restart'); `config_mse` is passed
# to train_mse for the warm-start fit. 'lr' is also read directly by hsic_train.
config_hsic = dict(
    batch_size=256,
    lr=1e-3,
    max_epoch=1000,
    num_restart=4,
)

config_mse = dict(
    batch_size=256,
    lr=1e-3,
    max_epoch=500,
)

In [None]:
# Load the RPE1 dataset. Layout relied on by the positional indexing below:
# columns 0-8 are the nine gene predictors (their names double as intervention
# labels), column 9 is the response, and column 10 is used as the instrument.
# NOTE(review): confirm whether column 10 is the 'interventions' column itself.
dataset_rpe1 = pd.read_csv("R/sec6.2/dataset_rpe1.csv")
interv_genes = list(dataset_rpe1.columns[:9])

# Training cells: those targeting one of the nine genes, plus non-targeting controls.
allowed_interventions = [*interv_genes, "non-targeting"]
in_train = dataset_rpe1['interventions'].isin(allowed_interventions)
train_data = dataset_rpe1.loc[in_train].copy()

train_data['interventions'] = train_data['interventions'].astype('category')
train_data['Ztr'] = train_data.iloc[:, 10].astype('category')

Xtr = train_data.iloc[:, 0:9].values   # nine gene predictors
Ytr = train_data.iloc[:, 9].values     # response (10th column)

# Dense one-hot encoding of the categorical instrument.
encoder = OneHotEncoder(sparse_output=False)
Ztr_encoded = encoder.fit_transform(train_data[['Ztr']])

# NOTE(review): not used in the visible cells; kept in case later cells use it.
df_mse = pd.DataFrame()

In [None]:
# Held-out test cells: only the first nine (gene) columns are model inputs,
# converted to a fresh float32 array.
test_data_path = 'R/sec6.2/test_single_cell.csv'
test_data = pd.read_csv(test_data_path)
Xtest = test_data.iloc[:, :9].to_numpy(dtype=np.float32, copy=True)

In [None]:
# Instrument type and kernel choices.
# NOTE(review): hsic_train builds its own kernel_e / kernel_z locals, so the
# two kernel objects created here are not what the training runs use; only
# `instrument` is passed on.
instrument = 'Binary'  # Binary or Continuous
kernel_e = RBFKernel(sigma=1)
if instrument == 'Binary':
    kernel_z = CategoryKernel(one_hot=True)
else:
    # Median-heuristic bandwidth, only computed for a continuous instrument.
    kernel_z = RBFKernel(sigma=med_sigma(Ztr_encoded))

In [None]:
# Function to train and predict HSIC-X for a single run
def hsic_train(i, Xtr, Ytr, Ztr_encoded, config_mse, config_hsic, instrument):
    print(f"Starting HSIC-X Run {i + 1}/10...")

    # Train the Pure Predictive model
    mse_net = NNHSICX(input_dim=9, lr=config_mse['lr'], lmd=-99)
    mse_net = train_mse(mse_net, config_mse, Xtr, Ytr, Ztr_encoded)

    # HSIC IV model setup
    s_z = med_sigma(Ztr_encoded)
    kernel_e = RBFKernel(sigma=1)

    if instrument == 'Binary':
        kernel_z = CategoryKernel(one_hot=True)
    else:
        kernel_z = RBFKernel(sigma=s_z)

    # Train HSIC IV model
    hsic_net = NNHSICX(input_dim=9, 
                        lr=config_hsic['lr'], 
                        kernel_e=kernel_e, 
                        kernel_z=kernel_z, 
                        lmd=0)

    hsic_net.load_state_dict(mse_net)
    hsic_net = train_HSIC_IV(hsic_net, config_hsic, Xtr, Ytr, Ztr_encoded, verbose=True)

    # Make predictions
    intercept_adjust = Ytr.mean() - hsic_net(to_torch(Xtr)).mean()
    y_hat_hsic = intercept_adjust + hsic_net(to_torch(Xtest))
    y_hat_hsic = y_hat_hsic.detach().numpy().copy()

    return y_hat_hsic  # Return predictions

In [None]:
# Launch `num_repeats` independent HSIC-X fits via joblib (5 workers).
# `results` is a list of per-run test-set predictions, in run order.
# NOTE(review): no random seed is set, so runs are not reproducible.
num_repeats = 10
jobs = (
    delayed(hsic_train)(run_idx, Xtr, Ytr, Ztr_encoded,
                        config_mse, config_hsic, instrument)
    for run_idx in range(num_repeats)
)
results = Parallel(n_jobs=5)(jobs)

In [None]:
# Collect per-run test predictions into one table: one row per test cell, one
# column per run. Each prediction is flattened first because the network may
# return a column vector of shape (n, 1) (TODO confirm), which
# pd.DataFrame(list_of_2d_arrays) rejects. float64 matches what the original
# pd.DataFrame(results).T produced for 1-D predictions.
run_preds = np.column_stack(
    [np.asarray(pred, dtype=np.float64).ravel() for pred in results]
)
df_results = pd.DataFrame(
    run_preds,
    columns=[f'Run_{i+1}' for i in range(run_preds.shape[1])],
)

output_filename = 'results/hsicx_singlecell_10runs.csv'
# to_csv does not create missing directories.
os.makedirs(os.path.dirname(output_filename), exist_ok=True)
df_results.to_csv(output_filename, index=False)