In [None]:

from pathlib import Path

import numpy as np
import pandas as pd
import torch
from joblib import Parallel, delayed
from sklearn.preprocessing import OneHotEncoder
from torch.utils.data import Dataset

from methods.toy_model_selection_method import ToyModelSelectionMethod # change input dimension! (f model(s) for X, g model for Y)


# Read the RPE1 training table from the same CSV that the HSIC-X experiment uses
dataset_rpe1 = pd.read_csv("data/sec6.2/dataset_rpe1.csv")

# The first nine column names are the intervention labels to keep, together
# with the "non-targeting" rows (same selection as HSIC-X)
interv_genes = dataset_rpe1.columns[:9].tolist()
train_data = dataset_rpe1.loc[
    dataset_rpe1['interventions'].isin([*interv_genes, "non-targeting"])
].copy()

# Store the intervention label as a categorical, then derive `Ztr` as a
# categorical copy of the column at position 10
train_data['interventions'] = train_data['interventions'].astype('category')
train_data['Ztr'] = train_data.iloc[:, 10].astype('category')

# Distinct intervention labels, leaving out "non-targeting"
unique_interventions = [
    label for label in train_data['interventions'].unique() if label != "non-targeting"
]

# Test inputs (from 50 test environments): parse the CSV without its header
# row and keep the first nine columns as the predictor matrix
test_data_path = 'data/sec6.2/test_single_cell.csv'
test_data = torch.tensor(
    np.genfromtxt(test_data_path, delimiter=',', skip_header=1),
    dtype=torch.float32,
)
Xtest = test_data[:, :9].reshape(-1, 9)

In [None]:
# Define function to process a single intervention removal
def gmm_train_stability(i, gene_to_remove, train_data, interv_genes):
    """Fit a DeepGMM model with one intervention held out and predict on `Xtest`.

    Rows whose `interventions` value equals `gene_to_remove` are dropped, the
    remaining rows are randomly split 90/10 into training and validation
    sets, a `ToyModelSelectionMethod` is fitted on them, and predictions are
    made for the notebook-level `Xtest` tensor.

    Parameters
    ----------
    i : int
        Position of this run; echoed back so callers can match results to runs.
    gene_to_remove : str
        Intervention label whose rows are excluded from fitting.
    train_data : pandas.DataFrame
        Table read positionally: columns 0-8 are the predictors, column 9 is
        the target; it must also contain the `interventions` and `Ztr` columns.
    interv_genes : list of str
        All intervention labels; every one except `gene_to_remove` is kept,
        together with "non-targeting".

    Returns
    -------
    tuple
        `(i, predictions)` where `predictions` is a NumPy copy of whatever
        `ToyModelSelectionMethod.predict` returns for `Xtest`.

    Notes
    -----
    `Xtest` is read from the enclosing notebook namespace, not passed in.
    The split draws from torch's global random generator.
    """
    print(f"Iteration {i+1}: Removing intervention {gene_to_remove}")

    # Keep "non-targeting" plus every intervention except the held-out one
    kept_environments = ["non-targeting"] + [
        gene for gene in interv_genes if gene != gene_to_remove
    ]
    filtered_data = train_data[train_data['interventions'].isin(kept_environments)]

    # One-hot encode `Ztr` AFTER removing the environment; no category is
    # dropped (the encoder is created without a `drop` argument)
    encoder = OneHotEncoder(sparse_output=False)
    Ztr_encoded = encoder.fit_transform(filtered_data[['Ztr']])

    # Predictors (first 9 columns), target (10th column) and encoded instrument
    X_all = torch.tensor(filtered_data.iloc[:, :9].values, dtype=torch.float32)
    Y_all = torch.tensor(filtered_data.iloc[:, 9].values.reshape(-1, 1), dtype=torch.float32)
    Z_all = torch.tensor(Ztr_encoded, dtype=torch.float32)

    # Random 90/10 split. A single permutation is applied to X, Z and Y so the
    # rows stay aligned. (Slicing the un-shuffled tensors by position would
    # always put the last rows of the table into the validation set.)
    train_ratio = 0.9
    n_samples = len(X_all)
    train_size = int(train_ratio * n_samples)
    shuffled = torch.randperm(n_samples)
    train_idx, val_idx = shuffled[:train_size], shuffled[train_size:]

    X_train, Z_train, Y_train = X_all[train_idx], Z_all[train_idx], Y_all[train_idx]
    X_val, Z_val, Y_val = X_all[val_idx], Z_all[val_idx], Y_all[val_idx]

    # Train DeepGMM model (inputs are passed as float64 tensors)
    deepGMM = ToyModelSelectionMethod()
    deepGMM.fit(X_train.double(), Z_train.double(), Y_train.double(),
                X_val.double(), Z_val.double(), Y_val.double(),
                g_dev=None, verbose=True)

    # Make predictions on the same test data as HSIC-X. `as_tensor` avoids
    # re-wrapping an existing tensor with `torch.tensor`, which warns.
    y_hat_deepGMM = deepGMM.predict(torch.as_tensor(Xtest, dtype=torch.float32).double())
    y_hat_deepGMM = y_hat_deepGMM.detach().numpy().copy()

    return i, y_hat_deepGMM  # Return index and predictions


# Fit one model per held-out intervention, spreading the fits over worker processes
if __name__ == "__main__":
    num_iterations = len(unique_interventions)

    # Each task receives its position and the intervention it has to leave out
    results = Parallel(n_jobs=9)(
        delayed(gmm_train_stability)(run_idx, held_out_gene, train_data, interv_genes)
        for run_idx, held_out_gene in enumerate(unique_interventions)
    )

    # One column of flattened predictions per run, labelled Run_1, Run_2, ...
    df_results = pd.DataFrame(
        dict(
            (f'Run_{run_idx+1}', predictions.flatten())
            for run_idx, predictions in results
        )
    )

In [None]:
# Label each prediction column with the intervention that was held out for that run
df_results.columns = unique_interventions

# Create the output folder first so the save cannot fail on a missing directory
results_path = Path("results/sec6.2/deepgmm_singlecell_pairwise.csv")
results_path.parent.mkdir(parents=True, exist_ok=True)
df_results.to_csv(results_path, index=False)