In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from sklearn.preprocessing import OneHotEncoder

sys.path.append("../")
from helpers.trainer import train_mse, train_HSIC_IV
from helpers.utils import med_sigma, to_torch
from models.hsicx import NNHSICX
from models.kernel import CategoryKernel, RBFKernel

In [None]:
# Settings handed to train_HSIC_IV; 'lr' is also used when building the HSIC network.
config_hsic = dict(
    batch_size=256,
    lr=1e-3,
    max_epoch=1000,
    num_restart=4,
)

# Settings handed to train_mse; 'lr' is also used when building the MSE network.
config_mse = dict(
    batch_size=256,
    lr=1e-3,
    max_epoch=500,
)

In [None]:
# Columns are addressed by position below: 0-8 are the model inputs, 9 is the
# target, and 10 is read as the instrument.
dataset_rpe1 = pd.read_csv("data/sec6.2/dataset_rpe1.csv")

# The first nine column names are also the intervention labels kept for training.
interv_genes = dataset_rpe1.columns[:9].tolist()

# Keep the control rows plus the rows intervening on one of those nine genes,
# then store the label and the instrument (column 10) as categoricals.
train_data = dataset_rpe1.loc[
    dataset_rpe1['interventions'].isin([*interv_genes, "non-targeting"])
].assign(
    interventions=lambda frame: frame['interventions'].astype('category'),
    Ztr=lambda frame: frame.iloc[:, 10].astype('category'),
)

# Inputs (first 9 columns) and target (10th column) as plain arrays.
Xtr, Ytr = train_data.iloc[:, :9].values, train_data.iloc[:, 9].values

# One-hot encode the instrument as a dense array.
encoder = OneHotEncoder(sparse_output=False).fit(train_data[['Ztr']])
Ztr_encoded = encoder.transform(train_data[['Ztr']])

# Every intervention label seen in the training rows except the control;
# each one is held out once in the leave-one-out runs.
unique_interventions = list(
    filter(lambda label: label != "non-targeting", train_data['interventions'].unique())
)
num_iterations = len(unique_interventions)  # one run per held-out intervention

# Placeholder; replaced by the prediction table once the runs finish.
df_mse = pd.DataFrame()

In [None]:
# Evaluation inputs (the file is described as covering 50 test environments).
# Only the first nine columns are used, cast to float32.
test_data_path = 'data/sec6.2/test_single_cell.csv'
test_data = pd.read_csv(test_data_path)
Xtest = test_data.iloc[:, :9].to_numpy().astype(np.float32)

In [None]:
# Notebook-level kernel choices. process_gene_removal builds its own kernels
# from `instrument`, so these two instances are not passed to it.
instrument = 'Binary'  # Binary or Continuous
kernel_e = RBFKernel(sigma=1)
if instrument == 'Binary':
    # Categorical kernel on the one-hot instrument.
    kernel_z = CategoryKernel(one_hot=True)
else:
    # RBF kernel whose bandwidth comes from med_sigma on the encoded instrument.
    kernel_z = RBFKernel(sigma=med_sigma(Ztr_encoded))

In [None]:
# Define function to process a single gene removal
def process_gene_removal(i, gene_to_remove, train_data, Xtr, Ytr, Ztr_encoded, interv_genes, config_mse, config_hsic, instrument):
    """Hold out one gene's intervention, fit MSE then HSIC-IV models, and predict on ``Xtest``.

    Rows whose ``interventions`` label equals ``gene_to_remove`` are dropped;
    the "non-targeting" rows and every other label in ``interv_genes`` are kept.
    A network trained with ``train_mse`` provides the starting weights for the
    HSIC-IV network. The final predictions are shifted by a constant so that
    their mean on the retained training rows equals the mean of the retained
    targets.

    Note: the evaluation inputs are read from the notebook-level ``Xtest``; it
    is not a parameter, so that variable must exist before this function runs.

    Args:
        i: Zero-based run index; returned unchanged so callers can match
            results to jobs.
        gene_to_remove: Intervention label to exclude from training.
        train_data: DataFrame with an ``interventions`` column, row-aligned
            with ``Xtr``, ``Ytr`` and ``Ztr_encoded``.
        Xtr: 2-D feature array; its column count sets the network input size.
        Ytr: Target array.
        Ztr_encoded: Encoded instrument array.
        interv_genes: Intervention labels eligible for training.
        config_mse: Settings for ``train_mse``; ``'lr'`` is also read here.
        config_hsic: Settings for ``train_HSIC_IV``; ``'lr'`` is also read here.
        instrument: ``'Binary'`` selects ``CategoryKernel(one_hot=True)`` for
            the instrument; any other value selects an RBF kernel whose sigma
            comes from ``med_sigma``.

    Returns:
        Tuple ``(i, y_hat_hsic)`` where ``y_hat_hsic`` is a NumPy array holding
        the intercept-adjusted HSIC-IV predictions for ``Xtest``.
    """
    print(f"\nIteration {i+1}: Removing intervention {gene_to_remove}")

    # Environments retained for this run: the control plus every gene except
    # the held-out one.
    kept_envs = ["non-targeting"] + [g for g in interv_genes if g != gene_to_remove]

    # Convert to a plain boolean array so the NumPy arrays are masked purely by
    # position, independent of train_data's (non-contiguous) index labels.
    valid_rows = train_data['interventions'].isin(kept_envs).to_numpy()

    # Progress/debug output: how many rows survive the filter.
    print(f"Valid rows count: {valid_rows.sum()} out of {len(train_data)}")

    Xtr_excl, Ytr_excl, Ztr_excl = Xtr[valid_rows], Ytr[valid_rows], Ztr_encoded[valid_rows]

    # Size the networks from the data instead of hard-coding the feature count.
    n_features = Xtr_excl.shape[1]

    # Pure predictive (MSE) model. lmd=-99 is passed through unchanged;
    # presumably it marks the MSE-only configuration — confirm in NNHSICX.
    mse_net = NNHSICX(input_dim=n_features, lr=config_mse['lr'], lmd=-99)
    mse_net = train_mse(mse_net, config_mse, Xtr_excl, Ytr_excl, Ztr_excl)

    # Kernels for the HSIC-IV objective. The median-heuristic bandwidth is only
    # needed for the RBF instrument kernel, so it is not computed for 'Binary'.
    kernel_e = RBFKernel(sigma=1)
    if instrument == 'Binary':
        kernel_z = CategoryKernel(one_hot=True)
    else:
        kernel_z = RBFKernel(sigma=med_sigma(Ztr_excl))

    hsic_net = NNHSICX(input_dim=n_features,
                       lr=config_hsic['lr'],
                       kernel_e=kernel_e,
                       kernel_z=kernel_z,
                       lmd=0)

    # Start from the MSE solution: whatever train_mse returned is loaded as the
    # initial weights (presumably a state dict — confirm in helpers.trainer).
    hsic_net.load_state_dict(mse_net)
    hsic_net = train_HSIC_IV(hsic_net, config_hsic, Xtr_excl, Ytr_excl, Ztr_excl, verbose=True)

    # Constant shift that aligns the mean prediction on the retained training
    # rows with the mean of the retained targets.
    intercept_adjust = Ytr_excl.mean() - hsic_net(to_torch(Xtr_excl)).mean()

    # Predict on the notebook-level test inputs and detach into a NumPy copy.
    y_hat_hsic = intercept_adjust + hsic_net(to_torch(Xtest))
    y_hat_hsic = y_hat_hsic.detach().numpy().copy()

    return i, y_hat_hsic  # Return index and predictions


# Leave-one-intervention-out runs, executed in parallel.
if __name__ == "__main__":
    num_iterations = len(unique_interventions)

    # One job per held-out intervention, spread over nine workers. Each job
    # returns (run index, predictions).
    results = Parallel(n_jobs=9)(
        delayed(process_gene_removal)(
            run_idx, held_out_gene,
            train_data, Xtr, Ytr, Ztr_encoded,
            interv_genes, config_mse, config_hsic, instrument,
        )
        for run_idx, held_out_gene in enumerate(unique_interventions)
    )

    # One column of predictions per run, keyed Run_1, Run_2, ...
    df_mse = pd.DataFrame(dict((f'Run_{i+1}', preds) for i, preds in results))

In [None]:
# Label each prediction column with the intervention that was held out for it;
# the runs were launched by enumerating `unique_interventions`, so column k
# belongs to the k-th entry.
df_mse.columns = unique_interventions

# DataFrame.to_csv does not create missing parent directories, so make sure
# the results folder exists before writing (otherwise the output of a long
# training run would be lost to an OSError).
output_path = Path("results/sec6.2/hsic_singlecell_pairwise.csv")
output_path.parent.mkdir(parents=True, exist_ok=True)
df_mse.to_csv(output_path, index=False)