# In [None]:
from joblib import Parallel, delayed
import torch
from torch import nn
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
import os

from methods.toy_model_selection_method import ToyModelSelectionMethod

# Paths to the 6 training and testing datasets. Both lists are built from the
# same tuple of dataset names, so entry i of `train_files` and entry i of
# `test_files` always refer to the same dataset.
_DATASET_NAMES = (
    'Zbin_g_lin_f_lin',
    'Zcont_g_lin_f_lin',
    'Zbin_g_lin_f_log_case',
    'Zcont_g_lin_f_log_case',
    'Zbin_g_lin_f_sin_lin',
    'Zcont_g_lin_f_sin_lin',
)

train_files = [f'data/sec5.1/train_{name}.csv' for name in _DATASET_NAMES]

test_files = [f'data/sec5.1/test_{name}.csv' for name in _DATASET_NAMES]

# Define the dataset
class MyDataset(Dataset):
    """Minimal ``Dataset`` that serves items directly from ``data``.

    ``data`` can be any object supporting ``len()`` and ``[]`` indexing; it is
    stored by reference and exposed as the ``data`` attribute.
    """

    def __init__(self, data):
        # Keep a reference to the wrapped container (no copy is made).
        self.data = data

    def __getitem__(self, idx):
        """Return whatever ``data[idx]`` yields."""
        return self.data[idx]

    def __len__(self):
        """Return the number of items, i.e. ``len(data)``."""
        return len(self.data)

# Function to be parallelized
def fit_predict_deepGMM(run, X_train, Z_train, Y_train, X_val, Z_val, Y_val, X_test, X_test_grid):
    """Fit a fresh ``ToyModelSelectionMethod`` and predict on both test inputs.

    ``run`` identifies the repetition for the caller; it is not used inside
    this function. Every tensor is cast to double precision before it is
    passed to the model.

    Returns:
        A tuple ``(y_hat_deepGMM, y_hat_deepGMM_grid)`` of flattened NumPy
        arrays: the predictions for ``X_test`` and for ``X_test_grid``.
    """
    model = ToyModelSelectionMethod()

    # Training triple first, then the validation triple, all in double precision.
    fit_inputs = [
        tensor.double()
        for tensor in (X_train, Z_train, Y_train, X_val, Z_val, Y_val)
    ]
    model.fit(*fit_inputs, g_dev=None, verbose=True)

    # Detach from the autograd graph so the predictions convert to NumPy.
    y_hat_deepGMM, y_hat_deepGMM_grid = (
        model.predict(inputs.double()).flatten().detach().numpy()
        for inputs in (X_test, X_test_grid)
    )
    return y_hat_deepGMM, y_hat_deepGMM_grid

# Run the experiment on the selected datasets.
#
# Dataset numbering is 1-based and follows the order of `train_files` /
# `test_files`. `first_dgp` is the first dataset to process; set it to 1 to
# run every dataset.
first_dgp = 5
n_runs = 10  # independent fit/predict repetitions per dataset (one worker each)
results_dir = 'results/sec5.1'

# Create the output folder up front: without it `to_csv` would raise only
# after all the training runs for a dataset had already finished.
os.makedirs(results_dir, exist_ok=True)

for idx, (train_file, test_file) in enumerate(
        zip(train_files[first_dgp - 1:], test_files[first_dgp - 1:]), start=first_dgp):
    # Load train and test data; the first CSV row is a header and is skipped.
    train_data = torch.tensor(np.genfromtxt(train_file, delimiter=',', skip_header=1), dtype=torch.float32)
    test_data = np.genfromtxt(test_file, delimiter=',', skip_header=1)

    data_train_length = train_data.shape[0]
    print(f'Train data size for DGP{idx}: {data_train_length}')

    # 90% of the rows are used for training, the remaining 10% for validation.
    train_ratio = 0.9
    train_size = int(train_ratio * data_train_length)
    val_size = data_train_length - train_size

    # Randomly partition the rows into training and validation subsets.
    train_data_split, val_data_split = torch.utils.data.random_split(
        MyDataset(train_data), [train_size, val_size])

    # Select rows through each subset's own indices. Slicing
    # `subset.dataset.data` would read the full, unshuffled tensor and
    # silently turn the random split into a contiguous one.
    train_rows = train_data[torch.as_tensor(train_data_split.indices, dtype=torch.long)]
    val_rows = train_data[torch.as_tensor(val_data_split.indices, dtype=torch.long)]

    # Columns 1, 2 and 3 of the training CSV are used as Z, X and Y; each is
    # reshaped to a column vector.
    Z_train, X_train, Y_train = (train_rows[:, col].reshape(-1, 1) for col in (1, 2, 3))
    Z_val, X_val, Y_val = (val_rows[:, col].reshape(-1, 1) for col in (1, 2, 3))

    # Column 0 of the test CSV is used as X_test and column 4 as X_test_grid.
    X_test = torch.tensor(test_data[:, 0].astype(np.float32)).squeeze()
    X_test_grid = torch.tensor(test_data[:, 4].astype(np.float32)).squeeze()

    # Parallel execution of the training and prediction runs
    results = Parallel(n_jobs=n_runs)(
        delayed(fit_predict_deepGMM)(run, X_train, Z_train, Y_train, X_val, Z_val, Y_val, X_test, X_test_grid)
        for run in range(n_runs)
    )

    # One column per run ("Run_1", "Run_2", ...): predictions on X_test go to
    # the "mse" file, predictions on X_test_grid go to the "plot" file.
    df_mse = pd.DataFrame({
        f'Run_{run}': y_hat_deepGMM
        for run, (y_hat_deepGMM, _) in enumerate(results, start=1)
    })
    df_plot = pd.DataFrame({
        f'Run_{run}': y_hat_deepGMM_grid
        for run, (_, y_hat_deepGMM_grid) in enumerate(results, start=1)
    })

    # Save the results to CSV files for the current dataset
    df_mse.to_csv(f'{results_dir}/deepgmm_result_mse_dgp{idx}.csv', index=False)
    df_plot.to_csv(f'{results_dir}/deepgmm_result_plot_dgp{idx}.csv', index=False)