In [1]:
from data_processing import DataProcessor

In [4]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from scipy.stats import skew, kurtosis, sigmaclip
from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler, Normalizer
import pandas as pd
import torch as tc
import torch.utils.data as tcud

class DataProcessor:
    def __init__(self, data_path):
        self.data_path = data_path
        self.data = None
        self.train_data = None
        self.val_data = None
        
    def load_data(self):
        """Load dataset from a CSV file."""
        try:
            if self.data_path is None:
                raise ValueError("Data path is not provided.")
            #self.data = pd.read_csv(self.data_path).drop(['Unnamed: 0'], axis=1)
            self.data = pd.read_feather(self.data_path)
            print("Data loaded successfully.")
        except Exception as e:
            print(f"Error loading data: {e}")
        
    def data_info(self):
        if self.data is not None:
            return self.data.info()
        else:
            return 'Data is not loaded'
    
    def basic_statistics(self):
        """Compute basic statistics for numeric columns."""
        if self.data is not None:
            stats = self.data.describe().T
            stats["Skewness"] = self.data.skew()
            stats["Kurtosis"] = self.data.kurt()
            print("\nData statistics")
            print(stats)
            return stats
        else:
            raise ValueError
    
    def non_category_features(self):
        if self.data is not None:
            self.non_categorical_features = self.data.select_dtypes(exclude=['object']).columns
            return self
        else:
            return 'Data is not loaded'
    
    def missing_values(self):
        """Check for missing values in the dataset."""
        if self.data is not None:
            return self.data.isnull().sum()
        else:
            print("Data is not loaded.")
            return None
    
    def duplicated_values(self):
        if self.data is not None:
            return self.data.duplicated().sum()
        else:
            return 'Data is not loaded'
    
    def data_columns(self):
        if self.data is not None:
            return self.data.columns
        else:
            return 'Data is not loaded'
    
    def outlier_analysis(self):
        self.Q1 = self.data[self.data.columns].quantile(0.25)
        self.Q3 = self.data[self.data.columns].quantile(0.75)
        self.IQR = self.Q3 - self.Q1

        # define the boundary
        self.lower_bound = self.Q1 - 1.5 * self.IQR
        self.upper_bound = self.Q3 + 1.5 * self.IQR

        # checking outliers
        self.outliers = self.data[(self.data[self.data.columns] < self.lower_bound) | (self.data[self.data.columns] > self.upper_bound)]

        # number of outliers
        self.n_outliers = self.outliers.shape[0]

        # percentage outlier
        self.pct_outliers = (self.n_outliers / self.data.shape[0]) * 100

        return self.n_outliers, self.pct_outliers, self.Q1, self.Q3, self.IQR, self.lower_bound, self.upper_bound

    def analyze_outliers(self):
        outlier_summary = []
        for column in self.non_categorical_features:
            n_outliers, pct_outliers, Q1, Q3, IQR, lower, upper_bound = self.outlier_analysis()
            if n_outliers > 0:
                outlier_summary.append({
                    'Feature': column,
                    'Number of Outliers': n_outliers,
                    'Percentage (%)': pct_outliers,
                    'Q1': Q1,
                    'Q3': Q3,
                    'IQR': IQR,
                    'Lower Bound': lower,
                    'Upper Bound': upper_bound,
                })
        return pd.DataFrame(outlier_summary)

    def plot_distribution(self, features, output_dir, title_prefix=""):
        """
        Plot and save distribution histograms for the given features.
        Args:
            features (list): List of feature names to plot.
            output_dir (str): Directory to save the plots.
            title_prefix (str): Prefix for the plot titles (optional).
        """
        os.makedirs(output_dir, exist_ok=True)
        
        for feature in features:
            try:
                plt.figure(figsize=(10, 6), dpi=300)
                sns.histplot(self.data[feature], bins=50, kde=True, color='skyblue', edgecolor='black')
                plt.title(f"{title_prefix}Histogram of {feature}", fontsize=14)
                plt.xlabel(feature, fontsize=12)
                plt.ylabel("Frequency", fontsize=12)
                plt.xticks(fontsize=10)
                plt.yticks(fontsize=10)
                plt.tight_layout()
                
                # Save the plot
                file_path = os.path.join(output_dir, f"{feature}.png")
                plt.savefig(file_path, dpi=600.00)
                plt.close()
            except Exception as e:
                print(f"Error plotting feature {feature}: {e}")
    
    def data_distribution(self):
        """
        Generate and save histograms for the non-categorical features.
        """
        if self.data is None:
            print("Data is not loaded.")
            return
        
        output_dir = "../output/plots/distribution"
        non_categorical_features = self.data.select_dtypes(exclude=["object"]).columns.tolist()
        self.plot_distribution(non_categorical_features, output_dir, title_prefix="Original Data: ")
        print("\nHistograms for original data saved successfully.")

    
    def correlation_heatmap(self):
        """Create a correlation matrix heatmap."""
        corr_path = '../output/plots/correlation'
        os.makedirs(corr_path, exist_ok=True)
        
        corr_matrix = self.data.drop(['Image_ID'], axis=1)
        corr_matrix = corr_matrix.corr(numeric_only=True)
        plt.figure(figsize=(16,10), dpi=600.00)
        sns.heatmap(corr_matrix, annot=True, cmap='inferno', square=True)
        plt.tight_layout()
        plt.savefig(fname=f'{corr_path}/corr_matrix.png', dpi=600.00)
        plt.close()
        print('\nCorrelation matrix saved successfully')
        
    
    def split_data(self, test_size=0.2, random_state=42):
        """Split the data into training and validation sets."""
        try:
            self.train_set, self.val_set = train_test_split(self.data, test_size=test_size, random_state=random_state)
            print("\nData split into training and validation sets.")
        except Exception as e:
            print(f"Error splitting data: {e}")
    
    def handle_missing_values(self, strategy="mean"):
        """Handle missing values using the specified strategy."""
        if self.data is not None:
            imputer = SimpleImputer(strategy=strategy)
            numeric_cols = self.data.select_dtypes(include=["float64", "int64"]).columns
            self.data[numeric_cols] = imputer.fit_transform(self.data[numeric_cols])
            print("\nMissing values handled.")
        else:
            print("Data is not loaded.")

    def detect_outliers(self, columns=None, iqr_factor=1.5):
        """Detect outliers using the IQR method."""
        if columns is None:
            columns = self.data.select_dtypes(include=["float64", "int64"]).columns
        
        outliers = {}
        Q1 = self.data[columns].quantile(0.25)
        Q3 = self.data[columns].quantile(0.75)
        IQR = Q3 - Q1
        lower_bound = Q1 - iqr_factor * IQR
        upper_bound = Q3 + iqr_factor * IQR
        
        for column in columns:
            outliers[column] = {
                "lower_bound": lower_bound[column],
                "upper_bound": upper_bound[column],
                "outliers": self.data[column][(self.data[column] < lower_bound[column]) | (self.data[column] > upper_bound[column])].count()
            }
        
        return outliers

    def detect_skewed_cols(self, columns=None, skew_threshold=0.5):
        """Detect skewed columns based on the skew threshold."""
        if columns is None:
            columns = self.data.select_dtypes(include=["float64", "int64"]).columns
        skewness = self.data[columns].skew()
        return skewness[skewness > skew_threshold]

    def detect_clip_cols(self, columns, lower_percentile=0.01, upper_percentile=0.99):
        """Detect columns needing clipping based on percentiles."""
        lower_limits = self.data[columns].quantile(lower_percentile)
        upper_limits = self.data[columns].quantile(upper_percentile)

        clip_cols = {}
        for col in columns:
            outliers = self.data[col][(self.data[col] < lower_limits[col]) | (self.data[col] > upper_limits[col])].count()
            clip_cols[col] = {
                "lower_limit": lower_limits[col],
                "upper_limit": upper_limits[col],
                "outliers": outliers
            }
        return clip_cols

    def auto_handle_outliers(self, numerical_columns):
        """Automatically handle outliers in the dataset."""
        outliers = self.detect_outliers(columns=numerical_columns)
        cols_to_impute = [col for col, stats in outliers.items() if stats["outliers"] > 0]

        skewed_cols = self.detect_skewed_cols(columns=numerical_columns)
        cols_to_transform = skewed_cols.index.tolist()

        clip_cols = self.detect_clip_cols(numerical_columns)
        cols_to_clip = [col for col, stats in clip_cols.items() if stats["outliers"] > 0]

        imputer = SimpleImputer(strategy="median")
        self.data[cols_to_impute] = imputer.fit_transform(self.data[cols_to_impute])

        for col in cols_to_transform:
            self.data[col] = np.log1p(self.data[col])

        for col in cols_to_clip:
            lower_limit = self.data[col].quantile(0.01)
            upper_limit = self.data[col].quantile(0.99)
            self.data[col] = self.data[col].clip(lower=lower_limit, upper=upper_limit)

        return cols_to_impute, cols_to_transform, cols_to_clip

    def separate_obs_data(self, data):
        """Separate observation and feature data."""
        self.obs_data_columns = [
            "GroupID", "SnapNum", "Redshift", "RandomNumber", "GalaxyID", "DescendantID",
            "LastProgID", "TopLeafID", "GroupNumber", "SubGroupNumber", "nodeIndex",
            "PosInFile", "FileNum", "Image_ID"
        ]
        self.obs_data = data[self.obs_data_columns]
        self.features = data.drop(columns=self.obs_data_columns)
        return self.obs_data, self.features

    def scale_data(self, scaler_type="standard"):
        """Scale numeric features using the specified scaler type."""
        scaler = {
            "standard": StandardScaler(),
            "robust": RobustScaler(),
            "minmax": MinMaxScaler()
        }.get(scaler_type)
        if scaler is None:
            raise ValueError(f"Invalid scaler type: {scaler_type}. Choose from 'standard', 'robust', or 'minmax'.")
        
        self.train_data_sc, self.train_features = self.separate_obs_data(self.train_set)
        self.val_data_sc, self.val_features = self.separate_obs_data(self.val_set)
        
        self.train_data_sc = pd.DataFrame(scaler.fit_transform(self.train_features), columns=self.train_features.columns)
        self.val_data_sc = pd.DataFrame(scaler.transform(self.val_features), columns=self.val_features.columns)

        print(f"\nData scaled using {scaler_type} scaler.")
        return self.train_data_sc, self.val_data_sc

    def normalize_data(self):
        normalizer =  Normalizer()
        self.train_data_sc_dn = pd.DataFrame(normalizer.fit_transform(self.train_data_sc), columns=self.train_features.columns)
        self.val_data_sc_dn = pd.DataFrame(normalizer.transform(self.val_data_sc), columns=self.val_features.columns)
        print("\nData have been successfully normalized.")
        return self.train_data_sc_dn, self.val_data_sc_dn
    
    def integrate_data(self):
        """Integrate observation data and processed features."""
        # Extract observation data for training
        try:
            train_obs_data = self.train_set[self.obs_data_columns]
        except KeyError as e:
            raise KeyError(f"Missing columns in train_data: {e}")

        # Combine observation data with scaled and normalized features for training
        self.train_data_integrate = pd.concat(
            [
                train_obs_data.reset_index(drop=True),
                pd.DataFrame(self.train_data_sc_dn[self.train_features.columns].values, columns=self.train_features.columns)
            ],
            axis=1
        )

        # Extract observation data for validation
        try:
            val_obs_data = self.val_set[self.obs_data_columns]
        except KeyError as e:
            raise KeyError(f"Missing columns in val_data: {e}")

        # Combine observation data with scaled and normalized features for validation
        self.val_data_integrate = pd.concat(
            [
                val_obs_data.reset_index(drop=True),
                pd.DataFrame(self.val_data_sc_dn[self.val_features.columns].values, columns=self.val_features.columns)
            ],
            axis=1
        )
        
        self.train_data = self.train_data_integrate
        self.val_data = self.val_data_integrate

        return self.train_data, self.val_data

    '''def final_clip(self, coverage=0.85):
        """Clip data to remove extreme outliers using sigma clipping."""
        # Compute tail probabilities for clipping
        tail_prob = (1 - coverage) / 2
        lower_percentile = tail_prob
        upper_percentile = 1 - tail_prob

        # Iterate over each feature to clip
        for col in self.train_features.columns:
            # Perform sigma clipping for train and validation data
            clipped_train, _, _ = sigmaclip(self.train_data_sc_dn[col], lower_percentile, upper_percentile)
            clipped_val, _, _ = sigmaclip(self.val_data_sc_dn[col], lower_percentile, upper_percentile)

            # Ensure the clipped data is not None
            if clipped_train is None or clipped_val is None:
                raise ValueError(f"Sigma clipping failed for column: {col}")

            # Update the columns in the train and validation data
            self.train_data_sc_dn[col] = pd.Series(
                clipped_train, index=self.train_data_sc_dn.index[:len(clipped_train)]
            )
            self.val_data_sc_dn[col] = pd.Series(
                clipped_val, index=self.val_data_sc_dn.index[:len(clipped_val)]
            )

        self.train_data = self.train_data_sc_dn
        self.val_data = self.val_data_sc_dn
        
        print(f"Data have been successfully clipped by {coverage*100}%.")
        return self.train_data, self.val_data'''


    def preprocess_data(self, scaler_type="standard", coverage=0.85):
        """Run the full preprocessing pipeline."""
        self.handle_missing_values()
        non_categorical_features = self.data.select_dtypes(exclude=["object"]).columns.tolist()
        self.auto_handle_outliers(non_categorical_features)
        self.split_data()
        self.scale_data(scaler_type)
        self.normalize_data()
        self.integrate_data()
        #self.final_clip(coverage)
        print('\nData have been preprocessed')
        return self.train_data, self.val_data
    
    def preprocessed_distribution(self):
        """
        Generate and save histograms for preprocessed training data.
        """
        if self.train_data is None:
            print("Training data is not available.")
            return
        
        output_dir = "../output/plots/preprocessed"
        numeric_features = self.train_data.select_dtypes(include=["float64", "int64"]).columns.tolist()
        self.plot_distribution(numeric_features, output_dir, title_prefix="Preprocessed Data: ")
        print("\nHistograms for preprocessed data saved successfully.")   

In [5]:
prefix = 'merge'
suffix =  ['L0100N1504']
for i in range(len(suffix)):
    filename = f'../../data/{prefix}_{suffix[i]}.feather'
    dp = DataProcessor(data_path=filename)
    dp.load_data()
    dp.missing_values()
    dp.basic_statistics()
    dp.data_distribution()
    dp.correlation_heatmap()
    train_data, val_data = dp.preprocess_data(scaler_type="robust", coverage=0.85)
    dp.preprocessed_distribution()

Data loaded successfully.

Data statistics
                               count          mean           std  \
GroupID                   22109109.0  1.109590e+13  4.515858e+12   
SnapNum                   22109109.0  1.109115e+01  4.516415e+00   
Redshift                  22109109.0  3.864350e+00  2.210902e+00   
RandomNumber              22109109.0  5.002389e-01  2.886437e-01   
GroupMass                 22109109.0  5.599906e+11  4.779501e+12   
GroupCentreOfPotential_x  22109109.0  5.138570e+01  2.903008e+01   
GroupCentreOfPotential_y  22109109.0  5.234821e+01  2.848542e+01   
GroupCentreOfPotential_z  22109109.0  4.884773e+01  2.709691e+01   
NumOfSubhalos             22109109.0  2.183267e+01  1.572136e+02   
Group_M_Crit200           22109109.0  4.501973e+11  3.809841e+12   
Group_R_Crit200           22109109.0  2.144501e+01  5.143914e+01   
Group_M_Mean200           22109109.0  4.774301e+11  4.058466e+12   
Group_R_Mean200           22109109.0  2.268119e+01  5.591853e+01   
Group

: 

In [None]:
train_data

Unnamed: 0,GroupID,SnapNum,Redshift,RandomNumber,GalaxyID,DescendantID,LastProgID,TopLeafID,GroupNumber,SubGroupNumber,...,Velocity_z,KineticEnergy,MechanicalEnergy,TotalEnergy,Vmax,VmaxRadius,Mass,MassType_DM,HalfMassProjRad_DM,HalfMassRad_DM
0,1.600450e+13,16.0,1.006850,0.286341,521691.0,521690.0,521692.0,521692.0,26958.0,0.000000,...,0.248569,-0.023669,0.017343,0.017343,0.035503,-0.022474,-0.085861,-0.085861,-0.233358,-0.242843
1,2.100980e+13,21.0,0.551370,0.961817,769339.0,769338.0,769358.0,769358.0,19.0,3.367296,...,-0.024592,0.160463,-0.301631,-0.301631,0.167656,-0.259053,0.238621,0.238621,0.181598,0.176915
2,5.021100e+12,5.0,2.085618,0.745820,55630.0,44014.0,55637.0,55637.0,3.0,4.653960,...,0.161347,-0.086250,0.024818,0.024818,-0.113568,0.140645,-0.069746,-0.069746,0.082617,0.095744
3,1.602170e+13,16.0,1.006850,0.178628,554995.0,554994.0,554997.0,554997.0,27995.0,0.000000,...,0.191987,0.021152,0.062398,0.062398,-0.015111,-0.046603,-0.033633,-0.033633,-0.138668,-0.155800
4,2.700060e+13,27.0,0.095891,0.898629,1007655.0,1007654.0,1007657.0,1007657.0,31522.0,0.000000,...,0.107055,-0.006099,0.016522,0.016522,-0.010767,-0.069218,-0.016149,-0.016149,-0.045225,-0.039401
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
989292,7.011800e+12,7.0,1.869864,0.846827,110269.0,110268.0,110281.0,110281.0,4.0,4.584967,...,0.149101,-0.049523,0.010524,0.010524,-0.021734,-0.017178,-0.025604,-0.025604,-0.052263,-0.089542
989293,1.002420e+13,10.0,1.606165,0.834165,259179.0,259178.0,259187.0,259187.0,27103.0,0.000000,...,0.065746,-0.081042,0.022439,0.022439,-0.118938,0.045532,-0.038059,-0.038059,0.079832,0.076985
989294,8.001100e+12,8.0,1.797946,0.229255,131933.0,131932.0,131933.0,131933.0,20247.0,0.000000,...,-0.011174,0.009173,0.011692,0.011692,0.104640,-0.298039,-0.096150,-0.096150,-0.309351,-0.300594
989295,1.900620e+13,19.0,0.695206,0.395208,671156.0,671155.0,671156.0,671156.0,23249.0,0.000000,...,-0.054240,0.028138,0.030698,0.030698,0.113055,-0.148162,-0.077641,-0.077641,-0.299303,-0.292034


In [None]:
val_data

Unnamed: 0,GroupID,SnapNum,Redshift,RandomNumber,GalaxyID,DescendantID,LastProgID,TopLeafID,GroupNumber,SubGroupNumber,...,Velocity_z,KineticEnergy,MechanicalEnergy,TotalEnergy,Vmax,VmaxRadius,Mass,MassType_DM,HalfMassProjRad_DM,HalfMassRad_DM
0,1.300420e+13,13.000000,1.246576,0.889542,369409.00,369408.00,369410.00,369410.00,35569.0,0.0,...,-0.236024,-0.204067,0.092412,0.092412,-0.227511,0.049183,-0.211853,-0.211853,-0.093204,-0.116673
1,1.620111e+13,16.188377,1.213387,0.500460,1224255.79,1220855.79,1224259.74,1224259.74,37285.0,0.0,...,-0.096425,-0.236849,0.059817,0.059817,-0.311868,0.333410,-0.105520,-0.105520,0.260575,0.242065
2,1.400820e+13,14.000000,1.174658,0.338938,428336.00,428335.00,428348.00,428345.00,7287.0,0.0,...,-0.174425,0.195013,-0.582414,-0.582414,0.172523,0.001628,0.204440,0.204440,0.023075,0.034475
3,1.601750e+13,16.000000,1.006850,0.820254,546966.00,546965.00,546971.00,546971.00,11901.0,0.0,...,-0.193842,0.053084,-0.049739,-0.049739,-0.054049,0.467709,0.111081,0.111081,0.063972,0.115098
4,1.620111e+13,16.188377,1.213387,0.500460,1130235.00,1130234.00,1130246.00,1130246.00,9304.0,0.0,...,-0.504944,0.140580,-0.211823,-0.211823,0.076661,0.177330,0.180916,0.180916,0.091372,0.096988
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
247320,1.620111e+13,16.188377,1.213387,0.500460,1203642.00,1203641.00,1203644.00,1203644.00,36952.0,0.0,...,-0.215147,-0.198709,0.059952,0.059952,-0.311075,0.352225,-0.095216,-0.095216,0.251889,0.260286
247321,5.006000e+12,5.000000,2.085618,0.011711,41297.00,-1.00,41299.00,41299.00,34230.0,0.0,...,-0.071611,-0.082373,0.022851,0.022851,-0.119199,0.121107,-0.038142,-0.038142,0.103302,0.094600
247322,6.002800e+12,6.000000,1.941782,0.584804,63534.00,63533.00,63535.00,63535.00,17201.0,0.0,...,-0.010384,0.053763,-0.027674,-0.027674,0.080003,-0.381622,-0.008413,-0.008413,-0.162882,-0.198907
247323,2.102370e+13,21.000000,0.551370,0.763274,793866.00,793865.00,793895.00,793887.00,2467.0,0.0,...,0.046211,0.064466,-0.691228,-0.691228,0.051547,0.027859,0.079149,0.079149,0.045699,0.044586


In [None]:

def sigma_clipping(df, column, coverage_percentage):
    """
    Applies sigma clipping to a specified column in a Pandas DataFrame.

    Args:
        df (pd.DataFrame): The DataFrame containing the data.
        column (str): The column to apply sigma clipping on.
        coverage_percentage (float): The desired coverage percentage (e.g., 95 for 95%).

    Returns:
        pd.DataFrame: A DataFrame with outliers removed based on sigma clipping.
    """
    # Ensure coverage_percentage is valid
    if not (0 < coverage_percentage < 100):
        raise ValueError("Coverage percentage must be between 0 and 100.")
    
    # Calculate the z-score threshold based on the coverage percentage
    alpha = (100 - coverage_percentage) / 100
    z_threshold = abs(np.percentile(np.random.Generate(0, 1, 1000000), [alpha * 50, 100 - alpha * 50]))[1]
    
    # Calculate mean and standard deviation
    mean = df[column].mean()
    std = df[column].std()

    # Apply sigma clipping
    df_clipped = df[(df[column] >= mean - z_threshold * std) & (df[column] <= mean + z_threshold * std)]

    return df_clipped

In [None]:
train_data_clip = train_data
train_data_clip = sigma_clipping(train_data_clip, train_data.columns, 85)
train_data_clip.dropna().reset_index(drop=True)

Unnamed: 0,GroupID,SnapNum,Redshift,RandomNumber,GalaxyID,DescendantID,LastProgID,TopLeafID,GroupNumber,SubGroupNumber,...,Velocity_z,KineticEnergy,MechanicalEnergy,TotalEnergy,Vmax,VmaxRadius,Mass,MassType_DM,HalfMassProjRad_DM,HalfMassRad_DM
0,1.900270e+13,19.0,0.695206,0.232667,664610.0,664609.0,664625.0,664625.0,9196.0,0.0,...,0.096851,0.136929,-0.351693,-0.351693,0.135758,-0.193651,0.165770,0.165770,0.048835,0.065318
1,1.900160e+13,19.0,0.695206,0.887910,662597.0,662596.0,662606.0,662606.0,10048.0,0.0,...,0.132286,0.157683,-0.379617,-0.379617,0.151565,-0.074694,0.151655,0.151655,-0.019520,-0.019656
2,1.700200e+13,17.0,0.910959,0.673209,566495.0,566494.0,566505.0,566505.0,10873.0,0.0,...,-0.022154,0.127924,-0.207793,-0.207793,0.079525,0.057689,0.164824,0.164824,0.077088,0.070598
3,1.800180e+13,18.0,0.815069,0.743976,614978.0,614977.0,614982.0,614981.0,9211.0,0.0,...,-0.171808,0.136128,-0.235490,-0.235490,0.147548,0.085389,0.105995,0.105995,-0.063311,-0.079067
4,2.100190e+13,21.0,0.551370,0.206509,755492.0,755491.0,755503.0,755502.0,7231.0,0.0,...,-0.111483,0.153675,-0.430258,-0.430258,0.131762,-0.010662,0.174811,0.174811,0.079877,0.071138
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
102,1.900230e+13,19.0,0.695206,0.141115,663872.0,663871.0,663879.0,663879.0,9408.0,0.0,...,-0.101447,0.162445,-0.376567,-0.376567,0.145115,0.132346,0.163675,0.163675,0.017922,0.018945
103,1.600060e+13,16.0,1.006850,0.760488,514073.0,514072.0,514081.0,514079.0,5870.0,0.0,...,0.020623,0.156709,-0.418019,-0.418019,0.145459,0.016135,0.143580,0.143580,-0.010094,-0.018642
104,1.600200e+13,16.0,1.006850,0.353033,516755.0,516754.0,516766.0,516766.0,11752.0,0.0,...,0.104651,0.147148,-0.292155,-0.292155,0.152886,-0.038810,0.158498,0.158498,-0.023865,-0.024661
105,1.600140e+13,16.0,1.006850,0.297315,515585.0,515584.0,515594.0,515594.0,9912.0,0.0,...,-0.007567,0.167311,-0.392245,-0.392245,0.163040,-0.159221,0.158739,0.158739,-0.033080,-0.021696


In [None]:
from data_processing import DataProcessor,  DMDataset
from model_train import ModelTrainer
from model_hyperparameter_optimizer import RNNHyperParameterOptimizer
from rnn_model import RNNDMHaloMapper, LSTMDMHaloMapper, GRUDMHaloMapper
from redshift_analyzer import RedshiftAnalyzer
from model_visualization import ModelVisualizer
import torch
import torch.optim as tco
import torch.utils.data as tcud
import torch.nn as tcn

class DMHaloMapper:
    def __init__(self, data=None, device=None):
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.dp = DataProcessor(data_path=data)
        self.mt = ModelTrainer()
        self.dm = DMDataset()
        self.hp = None
        self.mv = ModelVisualizer()
        self.rnn = RNNDMHaloMapper()
        self.lstm = LSTMDMHaloMapper()
        self.gru = GRUDMHaloMapper()
        self.train_data = None
        self.val_data = None
        self.X = None
        self.y = None
        self.features = None
        self.targets = None
        self.redshift_values = None
        self.unique_redshifts = None
        self.best_hyperparameter = None


    def data_process(self):
        dp.load_data()
        dp.missing_values()
        dp.basic_statistics()
        dp.data_distribution()
        dp.correlation_heatmap()
        self.train_data, self.val_data = dp.preprocess_data(scaler_type="robust", coverage=0.85)
        dp.preprocessed_distribution()
    
    def data_to_tensor(self):
        features = [
            'Redshift',
            'NumOfSubhalos',
            'Velocity_x',
            'Velocity_y',
            'Velocity_z',
            'MassType_DM',
            'CentreOfPotential_x',
            'CentreOfPotential_y',
            'CentreOfPotential_z',
            'GroupCentreOfPotential_x',
            'GroupCentreOfPotential_y',
            'GroupCentreOfPotential_z',
            'Vmax',
            'VmaxRadius',
            'Group_R_Crit200'
        ]
        
        targets = ['Group_M_Crit200']
        
        self.X = self.train_data[features]
        self.y = self.train_data[targets].values.reshape(-1, 1)
        
        self.features, self.targets = self.dm(self.X, self.y)
    
    def process_redshift(self):
        self.ra = RedshiftAnalyzer(self.train_data, self.features)
        self.redshift_values, self.unique_redshifts = self.ra.process_redshift_data()
    
    def discover_hyperparameter(self):
        print(f"Starting discover_hyperparameter for {len(self.models_to_optimize)} models...")
        self.models_to_optimize = [self.rnn, self.lstm, self.gru]
        
        optimization_results = []
        
        self.hp = RNNHyperParameterOptimizer(
            X=self.X,
            y=self.y,
            base_model_class=RNNDMHaloMapper,  # You can change this to LSTMDMHaloMapper or GRUDMHaloMapper
            device=self.device,
            num_features=self.X.shape[1],
            pop_size=20,
            maxiter=100,
            cross_val_folds=13
        )
        
        # Run optimization
        optimization_result = self.hp.optimize()
        
        if optimization_result:
            self.best_hyperparameters = optimization_result['best_params']
            print("Best hyperparameters found:")
            for param, value in self.best_hyperparameters.items():
                print(f"{param}: {value}")
        else:
            print("Hyperparameter optimization failed. Using default hyperparameters.")
            self.best_hyperparameters = {
                'hidden_size': 256,
                'num_layers': 2,
                'dropout_rate': 0.2,
                'learning_rate': 0.001,
                'batch_size': 64,
                'weight_decay': 0.0001,
                'num_epochs': 10
            }
        
    def train_rnn_model(self):
        """
        Train the RNN, LSTM, and GRU models using the best hyperparameters.
        """
        if not self.best_hyperparameters:
            print("No hyperparameters found. Running discover_hyperparameter first.")
            self.discover_hyperparameter()

        input_size = self.X.shape[1]
        output_size = 1

        # Define models with the best hyperparameters
        models = {
            'RNN': RNNDMHaloMapper(
                input_size=input_size,
                hidden_size=self.best_hyperparameters['hidden_size'],
                output_size=output_size,
                num_layers=self.best_hyperparameters['num_layers'],
                dropout_rate=self.best_hyperparameters['dropout_rate'],
                learning_rate=self.best_hyperparameters['learning_rate'],
                batch_size=self.best_hyperparameters['batch_size'],
                weight_decay=self.best_hyperparameters['weight_decay'],
                num_epochs=self.best_hyperparameters['num_epochs']
            ),
            'LSTM': LSTMDMHaloMapper(
                input_size=input_size,
                hidden_size=self.best_hyperparameters['hidden_size'],
                output_size=output_size,
                num_layers=self.best_hyperparameters['num_layers'],
                dropout_rate=self.best_hyperparameters['dropout_rate'],
                learning_rate=self.best_hyperparameters['learning_rate'],
                batch_size=self.best_hyperparameters['batch_size'],
                weight_decay=self.best_hyperparameters['weight_decay'],
                num_epochs=self.best_hyperparameters['num_epochs']
            ),
            'GRU': GRUDMHaloMapper(
                input_size=input_size,
                hidden_size=self.best_hyperparameters['hidden_size'],
                output_size=output_size,
                num_layers=self.best_hyperparameters['num_layers'],
                dropout_rate=self.best_hyperparameters['dropout_rate'],
                learning_rate=self.best_hyperparameters['learning_rate'],
                batch_size=self.best_hyperparameters['batch_size'],
                weight_decay=self.best_hyperparameters['weight_decay'],
                num_epochs=self.best_hyperparameters['num_epochs']
            )
        }

        X_train, X_val, y_train, y_val = train_test_split(self.X, self.y, test_size=0.25, random_state=42)

        self.X_val = X_val
        self.y_val = y_val
        
        train_dataset = DMDataset(X_train, y_train)
        val_dataset = DMDataset(X_val, y_val)

        train_loader = DMDataset.data_loader(train_dataset, batch_size=self.best_hyperparameters['batch_size'])
        val_loader = DMDataset.data_loader(val_dataset, batch_size=self.best_hyperparameters['batch_size'])
        '''train_loader = tcud.DataLoader(
            train_dataset,
            batch_size=self.best_hyperparameters['batch_size'],
            shuffle=True,
            num_workers=0,
            pin_memory=True if self.device.type == 'cuda' else False
        )

        val_loader = tcud.DataLoader(
            val_dataset,
            batch_size=self.best_hyperparameters['batch_size'],
            shuffle=False,
            num_workers=0,
            pin_memory=True if self.device.type == 'cuda' else False
        )'''

        model_results = {}

        for name, model in models.items():
            print(f"\nTraining {name} model with optimized hyperparameters...")
            optimizer = tco.Adam(model.parameters(), lr=self.best_hyperparameters['learning_rate'], weight_decay=self.best_hyperparameters['weight_decay'])
            criterion = tcn.MSELoss()

            history, model_dir = self.mt.train_model(
                model=model,
                train_loader=train_loader,
                val_loader=val_loader,
                criterion=criterion,
                optimizer=optimizer,
                num_epochs=self.best_hyperparameters['num_epochs'],
                device=self.device,
                model_name=name,
                save_dir='./best_model'
            )

            model_results[name] = {
                'history': history,
                'model_dir': model_dir
            }

        return model_results
    
    def process_model(self):
        """
        Process the validation dataset, make predictions, and store results for visualization.
        """
        if not hasattr(self, 'models'):
            print("No trained models found. Run train_rnn_model first.")
            return

        # Load validation data
        X_val_tensor = tc.tensor(self.X_val.values, dtype=tc.float32)
        self.y_val = self.y[self.X_val.index]

        # Make predictions on the validation set
        self.val_predictions = {}
        for name, model in self.models.items():
            model.eval()
            with tc.no_grad():
                outputs = model(X_val_tensor.to(self.device))
                if isinstance(outputs, tuple):
                    outputs = outputs[0]
                self.val_predictions[name] = outputs.cpu().numpy()

        print("Validation dataset processed and predictions stored.")
    
    def model_visualization(self):
        """
        Visualize the model's predictions, including dark matter distribution, triaxial model, and redshift evolution.
        """
        if not hasattr(self, 'val_predictions'):
            print("No validation predictions found. Run process_model first.")
            return

        # Extract spatial coordinates for visualization
        X_val_tensor = tc.tensor(self.X_val.values, dtype=tc.float32)
        column_indices = [self.X.columns.get_loc(f) for f in 
            [
                'GroupCentreOfPotential_x',
                'GroupCentreOfPotential_y',
                'GroupCentreOfPotential_z'
            ]
        ]
        spatial_coords = X_val_tensor[:, column_indices].numpy()

        # Initialize ModelVisualizer
        self.mv = ModelVisualizer(
            predictions=self.val_predictions['RNN'],  # Use RNN predictions as an example
            true_values=self.y_val,
            spatial_coords=spatial_coords,
            redshifts=self.redshift_values[self.X_val.index],
            save_dir='./visualizations'
        )

        # Generate and save visualizations
        self.mv.visualize_dm_distribution(a=1.0, b=0.8, c=0.6)

    
    def run_process(self):
        self.data_process()
        self.data_to_tensor()
        self.process_redshift()
        self.discover_hyperparameter()
        self.train_rnn_model()
        self.model_visualization()
        #self.process_model()

In [None]:
class RNNHyperparameterOptimizer:
    def __init__(self, X, y, base_model_class, device, num_features, 
                 pop_size=20, maxiter=10, cross_val_folds=3):
        self.X = X
        self.y = y
        self.base_model_class = base_model_class
        self.device = device
        self.num_features = num_features
        self.pop_size = pop_size
        self.maxiter = maxiter
        self.cross_val_folds = cross_val_folds
        
        # Define parameter bounds
        self.bounds = [
            (32, 512),      # hidden_size
            (1, 4),         # num_layers
            (0.0, 0.5),     # dropout_rate
            (0.0001, 0.1),  # learning_rate
            (16, 128),      # batch_size
            (0.0, 0.1)      # weight_decay
            (1, 100)        # number of epochs
        ]
    
    def create_model(self, params):
        """Create model with given hyperparameters"""
        hidden_size = int(params[0])
        num_layers = int(params[1])
        dropout_rate = params[2]
        learning_rate = params[3]
        batch_size = int(params[4])
        weight_decay = params[5]
        num_epochs = int(params[6])
        
        return self.base_model_class(
            input_size=self.num_features,
            hidden_size=hidden_size,
            output_size=1,
            num_layers=num_layers,
            dropout_rate=dropout_rate,
            learning_rate=learning_rate,
            batch_size=batch_size,
            weight_decay=weight_decay,
            num_epochs=num_epochs,
            nonlinearity='tanh'
        )
    
    def objective_function(self, params):
        """Objective function to minimize"""
        try:
            # Extract parameters
            hidden_size = int(params[0])
            num_layers = int(params[1])
            dropout_rate = params[2]
            learning_rate = params[3]
            batch_size = int(params[4])
            weight_decay = params[5]
            
            # Initialize cross-validation
            kf = KFold(n_splits=self.cross_val_folds, shuffle=True, random_state=42)
            cv_scores = []
            
            for train_idx, val_idx in kf.split(self.X):
                # Split data
                X_train, X_val = self.X.iloc[train_idx], self.X.iloc[val_idx]
                y_train, y_val = self.y[train_idx], self.y[val_idx]
                
                # Create datasets and dataloaders
                train_dataset = DMDataset(X_train, y_train)
                val_dataset = DMDataset(X_val, y_val)
                
                train_loader = tcud.DataLoader(
                    train_dataset,
                    batch_size=batch_size,
                    shuffle=True,
                    num_workers=0,
                    pin_memory=True if self.device.type == 'cuda' else False
                )
                
                val_loader = tcud.DataLoader(
                    val_dataset,
                    batch_size=batch_size,
                    shuffle=False,
                    num_workers=0,
                    pin_memory=True if self.device.type == 'cuda' else False
                )
                
                # Create and train model
                model = self.create_model(params).to(self.device)
                optimizer = tco.Adam(model.parameters(), 
                                    lr=learning_rate, 
                                    weight_decay=weight_decay)
                criterion = tcn.MSELoss()
                
                # Quick training loop
                model.train()
                for epoch in range(5):  # Reduced epochs for optimization
                    for batch_features, batch_targets in train_loader:
                        batch_features = batch_features.to(self.device)
                        batch_targets = batch_targets.to(self.device)
                        
                        optimizer.zero_grad()
                        outputs = model(batch_features)
                        loss = criterion(outputs, batch_targets)
                        loss.backward()
                        optimizer.step()
                
                # Evaluate
                model.eval()
                val_loss = 0
                with tc.no_grad():
                    for batch_features, batch_targets in val_loader:
                        batch_features = batch_features.to(self.device)
                        batch_targets = batch_targets.to(self.device)
                        outputs = model(batch_features)
                        val_loss += criterion(outputs, batch_targets).item()
                
                cv_scores.append(val_loss / len(val_loader))
            
            mean_score = np.mean(cv_scores)
            print(f"Trial completed - Score: {mean_score:.4f}, Params: {params}")
            return mean_score
            
        except Exception as e:
            print(f"Error in objective function: {e}")
            return float('inf')
    
    def optimize(self):
        """Run differential evolution"""
        try:
            result = differential_evolution(
                func=self.objective_function,
                bounds=self.bounds,
                maxiter=self.maxiter,
                popsize=self.pop_size,
                mutation=(0.5, 1.0),
                recombination=0.7,
                seed=42,
                disp=True,
                workers=-1
            )
            
            # Process results
            best_params = {
                'hidden_size': int(result.x[0]),
                'num_layers': int(result.x[1]),
                'dropout_rate': result.x[2],
                'learning_rate': result.x[3],
                'batch_size': int(result.x[4]),
                'weight_decay': result.x[5],
                'num_epochs': int(result.x[6])
            }
            
            return {
                'best_params': best_params,
                'best_score': result.fun,
                'convergence': result.success,
                'iterations': result.nit,
                'optimization_result': result
            }
            
        except Exception as e:
            print(f"Error in optimization: {e}")
            return None
        
        
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device, model_name, save_dir):
    model = model.to(device)
    best_val_loss = float('inf')
    patience = 10
    patience_counter = 0
    metrics_calculator = CustomMetrics()
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_metrics': [],
        'val_metrics': []
    }
    
    # Create directories for saving results
    model_dir = os.path.join(save_dir, f"{model_name}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}")
    os.makedirs(model_dir, exist_ok=True)
    
    best_model_path = os.path.join(model_dir, "best_model.pth")
    train_history_path = os.path.join(model_dir, "train_history.npy")
    model_comparison_path = os.path.join(model_dir, "model_comparison_results.npy")
    plot_path = os.path.join(model_dir, "dmhalo_mapper.html")

    # Wrap the epochs loop in tqdm
    for epoch in tqdm(range(num_epochs), desc="Epoch Progress"):
        model.train()
        train_loss = 0
        
        epoch_train_metrics =   {
                                    'mse': 0, 
                                    'rmse': 0, 
                                    'mae': 0, 
                                    'r2': 0,
                                    'gaussian_nll': 0,
                                    'poisson_nll': 0
                                }

        # Wrap the training batches loop in tqdm
        for batch_features, batch_targets in train_loader:
            try: 
                batch_features= batch_features.to(device)
                batch_targets = batch_targets.to(device)

                optimizer.zero_grad()
                outputs = model(batch_features)
                if isinstance(outputs, tuple):
                    outputs = outputs[0]
                loss = criterion(outputs, batch_targets)
                loss.backward()
                
                tcn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                
                optimizer.step()

                train_loss += loss.item()
            
                # Calculate training metrics for this batch
                batch_metrics = metrics_calculator.calculate_all_metrics(batch_targets, outputs)
                for metric in batch_metrics:
                    epoch_train_metrics[metric] += batch_metrics[metric]
            
            except RuntimeError as e:
                print(f"Error in training batch: {e}")
                continue        

        # Average the training metrics
        num_batches = len(train_loader)
        avg_train_loss = train_loss / num_batches
        avg_train_metrics = {
            metric: value / num_batches 
            for metric, value in epoch_train_metrics.items()
        }

        # Validation phase
        model.eval()
        val_loss = 0
        
        epoch_val_metrics =   {
                                    'mse': 0, 
                                    'rmse': 0, 
                                    'mae': 0, 
                                    'r2': 0,
                                    'gaussian_nll': 0, 
                                    'poisson_nll': 0
                                }
        
        with tc.no_grad():
            for batch_features, batch_targets in val_loader:
                try:
                    batch_features= batch_features.to(device)
                    batch_targets = batch_targets.to(device)
                    
                    optimizer.zero_grad()

                    outputs = model(batch_features)
                    if isinstance(outputs, tuple):
                        outputs = outputs[0]
                    loss = criterion(outputs, batch_targets)

                    val_loss += loss.item()
                
                    # Calculate validation metrics for this batch
                    batch_metrics = metrics_calculator.calculate_all_metrics(batch_targets, outputs)
                    for metric in batch_metrics:
                        epoch_val_metrics[metric] += batch_metrics[metric]

                except RuntimeError as e:
                    print(f"Error in validation batch: {e}")
                    continue


        # Average the validation metrics
        num_val_batches = len(val_loader)
        avg_val_loss = val_loss / num_val_batches
        avg_val_metrics = {
            metric: value / num_val_batches 
            for metric, value in epoch_val_metrics.items()
        }

        # Early stopping check
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            tc.save(model.state_dict(), best_model_path)
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping triggered at epoch {epoch + 1}")
            break

        # Update history
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['train_metrics'].append(avg_train_metrics)
        history['val_metrics'].append(avg_val_metrics)
        
        # Print progress every 5 epochs
        if (epoch + 1) % 5 == 0:
            print(f"\nEpoch [{epoch+1}/{num_epochs}]")
            print(f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
            print("\nTraining Metrics:")
            for metric, value in avg_train_metrics.items():
                print(f"{metric}: {value:.4f}", end="  ")
            print("\nValidation Metrics:")
            for metric, value in avg_val_metrics.items():
                print(f"{metric}: {value:.4f}", end="  ")
            print("\n")
            
    # Save training history
    np.save(train_history_path, history)

    print(f"Model {model} training complete. Best model saved to {best_model_path}")
    return history, model_dir

def visualize_dm_distribution(predictions, true_values, spatial_coords, redshifts=None, save_path=None, 
                            a=1.0, b=0.8, c=0.6):  # Add axis ratios for triaxiality
    try:
        # Ensure all inputs are numpy arrays and properly shaped
        predictions = np.array(predictions).reshape(-1)
        true_values = np.array(true_values).reshape(-1)
        spatial_coords = np.array(spatial_coords)
        
        if redshifts is not None:
            redshifts = np.array(redshifts)
            unique_redshifts = np.sort(np.unique(redshifts))[::-1]  # Sort in descending order
            
        if spatial_coords.shape[1] != 3:
            raise ValueError("spatial_coords must have shape (n_samples, 3)")

        # Create subplots: 3D scatter, triaxial view, and animation controls
        fig = sp.make_subplots(
            rows=1, cols=3,
            specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}, {'type': 'scatter3d'}]],
            subplot_titles=[
                "Dark Matter Distribution",
                "Triaxial Model",
                "Evolution View"
            ]
        )

        # Create frames for animation if redshifts are provided
        frames = []
        if redshifts is not None:
            for z in unique_redshifts:
                mask = (redshifts == z)
                
                frame_data = [
                    # 3D scatter for DM distribution
                    go.Scatter3d(
                        x=spatial_coords[mask, 0],
                        y=spatial_coords[mask, 1],
                        z=spatial_coords[mask, 2],
                        mode='markers',
                        marker=dict(
                            size=3,
                            color=predictions[mask],
                            colorscale='Viridis',
                            opacity=0.8
                        ),
                        name=f'z={z:.2f}'
                    ),
                    # Triaxial model stays constant
                    *create_triaxial_surfaces(a, b, c),
                    # Evolution view
                    go.Scatter3d(
                        x=spatial_coords[mask, 0],
                        y=spatial_coords[mask, 1],
                        z=spatial_coords[mask, 2],
                        mode='markers',
                        marker=dict(
                            size=3,
                            color=predictions[mask],
                            colorscale='Inferno',
                            opacity=0.8
                        ),
                        name=f'Evolution z={z:.2f}'
                    )
                ]
                
                frames.append(go.Frame(
                    data=frame_data,
                    name=f'z={z:.2f}'
                ))

        # Initial state
        scatter3d = go.Scatter3d(
            x=spatial_coords[:, 0],
            y=spatial_coords[:, 1],
            z=spatial_coords[:, 2],
            mode='markers',
            marker=dict(
                size=3,
                color=predictions,
                colorscale='Viridis',
                opacity=0.8,
                colorbar=dict(title="Predicted Mass")
            ),
            name='Dark Matter Halos'
        )
        fig.add_trace(scatter3d, row=1, col=1)

        # Add triaxial model
        triaxial_surfaces = create_triaxial_surfaces(a, b, c)
        for surface in triaxial_surfaces:
            fig.add_trace(surface, row=1, col=2)

        # Add evolution view
        evolution_scatter = go.Scatter3d(
            x=spatial_coords[:, 0],
            y=spatial_coords[:, 1],
            z=spatial_coords[:, 2],
            mode='markers',
            marker=dict(
                size=3,
                color=predictions,
                colorscale='Inferno',
                opacity=0.8
            ),
            name='Evolution View'
        )
        fig.add_trace(evolution_scatter, row=1, col=3)

        # Update layout
        fig.update_layout(
            title='Dark Matter Halo Distribution with Triaxial Model and Evolution',
            width=1800,
            height=600,
            showlegend=True,
            updatemenus=[{
                'buttons': [
                    {
                        'args': [None, {'frame': {'duration': 50, 'redraw': True},
                                      'fromcurrent': True}],
                        'label': 'Play',
                        'method': 'animate'
                    },
                    {
                        'args': [[None], {'frame': {'duration': 0, 'redraw': True},
                                        'mode': 'immediate',
                                        'transition': {'duration': 0}}],
                        'label': 'Pause',
                        'method': 'animate'
                    }
                ],
                'type': 'buttons',
                'showactive': False,
                'x': 0.1,
                'y': 1.1,
                'xanchor': 'right',
                'yanchor': 'top'
            }],
            sliders=[{
                'currentvalue': {'prefix': 'Redshift: '},
                'steps': [
                    {
                        'args': [[f'z={z:.2f}'], {'frame': {'duration': 50, 'redraw': True},
                                                'mode': 'immediate'}],
                        'label': f'{z:.2f}',
                        'method': 'animate'
                    }
                    for z in unique_redshifts
                ]
            }] if redshifts is not None else None
        )

        # Add frames to the figure
        fig.frames = frames

        # Save to file if specified
        if save_path:
            fig.write_html(save_path)
            print(f"Visualization saved to {save_path}")

            '''# Create and save GIF
            if redshifts is not None:
                gif_path = save_path.replace('.html', '.gif')
                create_evolution_gif(spatial_coords, predictions, redshifts, gif_path)
'''
        return fig

    except Exception as e:
        print(f"Error in visualization: {e}")
        return None

def create_triaxial_surfaces(a, b, c, n_points=50):
    """Create triaxial model surfaces."""
    phi = np.linspace(0, 2*np.pi, n_points)
    theta = np.linspace(-np.pi/2, np.pi/2, n_points)
    phi, theta = np.meshgrid(phi, theta)

    surfaces = []
    
    # Create outer shell
    x_outer = a * np.cos(theta) * np.cos(phi)
    y_outer = b * np.cos(theta) * np.sin(phi)
    z_outer = c * np.sin(theta)
    
    surfaces.append(go.Surface(
        x=x_outer, y=y_outer, z=z_outer,
        opacity=0.3,
        colorscale='Blues',
        showscale=False,
        name='Outer Shell'
    ))
    
    # Create inner shell
    x_inner = 0.7*a * np.cos(theta) * np.cos(phi)
    y_inner = 0.7*b * np.cos(theta) * np.sin(phi)
    z_inner = 0.7*c * np.sin(theta)
    
    surfaces.append(go.Surface(
        x=x_inner, y=y_inner, z=z_inner,
        opacity=0.3,
        colorscale='Reds',
        showscale=False,
        name='Inner Shell'
    ))
    
    return surfaces

def create_evolution_gif(spatial_coords, predictions, redshifts, save_path):
    """Create a GIF showing the evolution of dark matter halos with redshift."""
    import imageio
    
    unique_redshifts = np.sort(np.unique(redshifts))[::-1]
    frames = []
    
    for z in unique_redshifts:
        mask = (redshifts == z)
        
        # Create static figure for this redshift
        fig = go.Figure(data=[
            go.Scatter3d(
                x=spatial_coords[mask, 0],
                y=spatial_coords[mask, 1],
                z=spatial_coords[mask, 2],
                mode='markers',
                marker=dict(
                    size=3,
                    color=predictions[mask],
                    colorscale='Viridis',
                    opacity=0.8
                )
            )
        ])
        
        fig.update_layout(
            title=f'Dark Matter Distribution at z={z:.2f}',
            width=800,
            height=800
        )
        
        # Convert to image
        img_bytes = fig.to_image(format="png")
        frames.append(imageio.imread(img_bytes))
    
    # Save as GIF
    imageio.mimsave(save_path, frames, duration=50)
    print(f"Evolution GIF saved to {save_path}")
    
def process_redshift_data(data, selected_features):
    """
    Process redshift data from the EAGLE simulation dataset.
    
    Parameters:
    data : DataFrame
        The input EAGLE simulation data
    selected_features : list
        List of selected features including 'Redshift'
        
    Returns:
    numpy.ndarray
        Processed redshift values
    """
    if 'Redshift' not in selected_features:
        raise ValueError("'Redshift' must be included in selected_features")
    
    # Extract redshift values
    redshift_values = data['Redshift'].values
    
    # Sort unique redshift values in descending order
    unique_redshifts = np.sort(np.unique(redshift_values))[::-1]
    
    print(f"Redshift range: {unique_redshifts.min():.2f} to {unique_redshifts.max():.2f}")
    print(f"Number of unique redshift values: {len(unique_redshifts)}")
    
    return redshift_values, unique_redshifts

def analyze_redshift_evolution(predictions, true_values, redshifts, save_dir, model_name):
    """
    Analyze how predictions vary with redshift.
    
    Parameters:
    predictions : array-like
        Model predictions
    true_values : array-like
        True values
    redshifts : array-like
        Corresponding redshift values
    save_dir : str
        Directory to save analysis results
    model_name : str
        Name of the model being analyzed
    """
    unique_z = np.sort(np.unique(redshifts))
    metrics_by_z = []
    
    for z in unique_z:
        mask = (redshifts == z)
        z_metrics = CustomMetrics.calculate_all_metrics(
            true_values[mask],
            predictions[mask]
        )
        z_metrics['redshift'] = z
        metrics_by_z.append(z_metrics)
    
    # Convert to DataFrame and save
    metrics_df = pd.DataFrame(metrics_by_z)
    metrics_df.to_csv(os.path.join(save_dir, f'{model_name}_redshift_evolution.csv'))
    
    # Create evolution plots
    fig = go.Figure()
    for metric in ['mse', 'rmse', 'mae', 'r2']:
        fig.add_trace(go.Scatter(
            x=metrics_df['redshift'],
            y=metrics_df[metric],
            mode='lines+markers',
            name=metric.upper()
        ))
    
    fig.update_layout(
        title=f'{model_name} Metrics Evolution with Redshift',
        xaxis_title='Redshift',
        yaxis_title='Metric Value',
        showlegend=True
    )
    
    fig.write_html(os.path.join(save_dir, f'{model_name}_redshift_evolution.html'))
    
    

def main():
    
    selected_features = [
        'Redshift',
        'NumOfSubhalos',
        'Velocity_x',
        'Velocity_y',
        'Velocity_z',
        'MassType_DM',
        'Mass',
        'CentreOfPotential_x',
        'CentreOfPotential_y',
        'CentreOfPotential_z',
        'GroupCentreOfPotential_x',
        'GroupCentreOfPotential_y',
        'GroupCentreOfPotential_z',
        'Vmax',
        'VmaxRadius',
        'Group_R_Crit200'
    ]
    
    
    try:
        #Model hyperparameter
        input_size = len(selected_features)
        hidden_size = 256
        output_size = 1
        learning_rate = 0.001
        batch_size = 64
        num_epochs = 10
        weight_decay = 0.001
        epsilon = 1e-08
    
        device = tc.device('cuda' if tc.cuda.is_available() else 'cpu')
        print(f'Using device: {device}')
        
        save_dir = './best_model'
    
        X = train_set_w_obs[selected_features]
        y = train_set_w_obs['Group_M_Crit200'].values.reshape(-1, 1)

        # Process redshift data
        redshift_values, unique_redshifts = process_redshift_data(train_set_w_obs, selected_features)

    
        X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.25, random_state=42)

        # Get corresponding redshift values for validation set
        val_indices = X_val.index
        redshift_val = redshift_values[val_indices]

    
        train_dataset = DMDataset(X_train, y_train)
        val_dataset = DMDataset(X_val, y_val)
    
        # Create data loaders
        train_loader = tcud.DataLoader(
            train_dataset, 
            batch_size=batch_size, 
            shuffle=True,
            num_workers=0,
            pin_memory=True if device.type == 'cuda' else False
        )
        
        val_loader = tcud.DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=True if device.type == 'cuda' else False
        )
    
        # Initialize models
        models = {
            'RNN': RNNDMHaloMapper(input_size, hidden_size, output_size),
            'LSTM': LSTMDMHaloMapper(input_size, hidden_size, output_size),
            'GRU': GRUDMHaloMapper(input_size, hidden_size, output_size)
        }
        
        '''models = {
            'RNN': RNNDMHaloMapper(input_size, hidden_size, output_size)
        }'''
        
        model_results = {}
    
        for name, model in models.items():
            print(f"\nTraining {name} model:")
            optimizer = tco.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
            
            criterion = tcn.MSELoss()
        
            # Pass tqdm into train_model if it has an internal loop
            history, model_dir = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device, name, save_dir)


            #sanitized_model_name = re.sub(r'[<>:"/\\|?*\n]', '_', str(model))

            # Create directories for saving results
            #model_dir = os.path.join(save_dir, f"{name}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}")
            #os.makedirs(model_dir, exist_ok=True)
            # Paths
            best_model_path = os.path.join(model_dir, "best_model.pth")
            train_history_path = os.path.join(model_dir, f"training_history_{name}.csv")
            model_comparison_path = os.path.join(save_dir, "model_comparison_results.csv")
            plot_path = os.path.join(model_dir, f"dmhalo_distribution_{name}.html")
        
            # Load best model weights
            model.load_state_dict(tc.load(best_model_path))
            model.eval()
            
            # Make predictions on validation set
            with tc.no_grad():
                val_predictions = []
                for batch_features, _ in val_loader:
                    batch_features = batch_features.to(device)
                    outputs = model(batch_features)
                    if isinstance(outputs, tuple):
                        outputs = outputs[0]
                    val_predictions.extend(outputs.cpu().numpy())
                
                val_predictions = np.array(val_predictions)
            
            if isinstance(X_val, pd.DataFrame):
                # Convert DataFrame to tensor
                X_val_tensor = tc.tensor(X_val.values, dtype=tc.float32)
            else:
                # Directly convert to tensor if it's already NumPy or similar
                X_val_tensor = tc.tensor(X_val, dtype=tc.float32)

            # Map feature names to indices
            column_indices = [selected_features.index(f) for f in 
                [
                    'GroupCentreOfPotential_x',
                    'GroupCentreOfPotential_y',
                    'GroupCentreOfPotential_z'
                ]
            ]

            # Slice the tensor
            spatial_coords = X_val_tensor[:, column_indices]
                        
            # Create visualization
            fig = visualize_dm_distribution(
                predictions=val_predictions,
                true_values=y_val,
                spatial_coords=spatial_coords,
                redshifts=redshift_val,  # Add redshift values here
                save_path=plot_path
                #a=1.0, b=0.8, c=0.6  # Triaxial parameters
            )
            
            '''# Analyze triaxial fit
            fit_stats = analyze_triaxial_fit(
                    spatial_coords=spatial_coords,
                    predictions=val_predictions,
                    save_dir=model_dir
            )
            
            print(f"\nTriaxial Fit Statistics for {name}:")
            print(f"Axis Ratios (b/a, c/a): {fit_stats['axis_ratios'][1]:.3f}, {fit_stats['axis_ratios'][2]:.3f}")
            print(f"Percentage of halos inside model: {fit_stats['percent_inside']:.2f}%")'''
            
            # Additional analysis specific to redshift evolution
            analyze_redshift_evolution(
                predictions=val_predictions,
                true_values=y_val,
                redshifts=redshift_val,
                save_dir=model_dir,
                model_name=name
            )
            
            # Calculate final metrics
            metrics = CustomMetrics.calculate_all_metrics(
                tc.tensor(y_val),
                tc.tensor(val_predictions)
            )
            
            # Store results
            model_results[name] = {
                'history': history,
                'metrics': metrics,
                'predictions': val_predictions
            }
            
            print(f"\nFinal metrics for {name} model:")
            for metric, value in metrics.items():
                print(f"{metric}: {value:.4f}")
        
            # Compare models and save results
            results_df = pd.DataFrame({
                            name: {
                                'MSE': results.get('metrics', {}).get('mse', np.nan),
                                #'MSE': model_results['metrics']['mse']
                                
                                'RMSE': results.get('metrics', {}).get('rmse', np.nan),
                                #'RMSE': model_results['metrics']['rmse']
                                
                                'MAE': results.get('metrics', {}).get('mae', np.nan),
                                #'MAE': model_results['metrics']['mae']

                                'R2': results.get('metrics', {}).get('r2', np.nan),
                                #'R2': model_results['metrics']['r2']

                                'GaussianNLL': results.get('metrics', {}).get('gaussian_nll', np.nan),
                                #'GaussianNLL': model_results['metrics']['gaussian_nll']

                                'PoissonNLL': results.get('metrics', {}).get('poisson_nll', np.nan),
                                #'PoissonNLL': model_results['metrics']['poisson_nll']

                            }
                            for name, results in model_results.items()
                        })

            
            results_df.to_csv(model_comparison_path)
            print(f"\nModel comparison results saved to {model_comparison_path}")

            
            # Save training histories
            for name, results in model_results.items():
                history_df = pd.DataFrame({
                    'train_loss': results['history']['train_loss'],
                    'val_loss': results['history']['val_loss']
                })
                history_df.to_csv(train_history_path)
                print(f"Training history for {name} saved to {train_history_path}")
        
    except Exception as e:
        print(f"Error in main execution: {e}")
        raise
            
        '''visualize_dm_distribution(
            predictions=model(X_val_tensor).detach().cpu().numpy(),
            true_values=y_val,
            save_path=f'../plots/dmhalo_vis/{name}_halo_distribution.html'
        )'''

if __name__=='__main__':
    main()