In [None]:
import os
import numpy as np
import random
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, accuracy_score, mean_squared_error
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.svm import SVC
from sklearn.model_selection import KFold, train_test_split, GridSearchCV
from sklearn.decomposition import PCA


from openpyxl import load_workbook
from openpyxl.styles import PatternFill
from collections import Counter
from pprint import pprint


pd.set_option("display.width", 1000)     

In [None]:
# 1. Load 16000 Data * 16 boxes

In [None]:
base_dir = "/home/sujin/CosmoGasPeruser/data/Spectra_for_Sujin"
physics_values = [1, 2, 3, 4]
redshift_values = [0.1, 0.3, 2.2, 2.4]

# Flux arrays keyed first by redshift, then by physics model (both as strings):
# data_by_redshift = {
#     "0.1": {"1": flux_array, "2": flux_array, ...},
#     "0.3": {...},
#     ...
# }
# A (redshift, physics) key is only created once its flux.npy has loaded, so
# later cells never meet an empty placeholder (which would break
# np.concatenate and the `.size` checks further down).
data_by_redshift = {str(redshift): {} for redshift in redshift_values}

# Each folder is named "<physics>_<redshift>" and is expected to hold flux.npy.
for redshift in redshift_values:
    redshift = str(redshift)
    for physics in physics_values:
        physics = str(physics)
        folder_name = f"{physics}_{redshift}"
        folder = os.path.join(base_dir, folder_name)
        print(f"Processing folder: {folder_name}")

        try:
            flux = np.load(os.path.join(folder, "flux.npy"))
        except FileNotFoundError as e:
            print(f"Files not found in folder {folder}: {e}")
            continue
        except Exception as e:
            # Best-effort load: report the problem and move on to the next folder.
            print(f"An error occurred while processing folder {folder}: {e}")
            continue

        data_by_redshift[redshift][physics] = flux
        print(f"Flux data shape for redshift {redshift}, physics {physics}: {flux.shape}")

# Now data_by_redshift contains flux data for each redshift and physics combination


In [None]:
# The outer f-strings use double quotes: reusing the same quote character
# inside a replacement field is a SyntaxError on Python versions before 3.12.
example_spectrum = data_by_redshift['0.3']['4'][100]
print(f"Example of a data point: spectrum number 100 of physics no.4 in redshift=0.3 \n{example_spectrum}")
print(f"Shape: {example_spectrum.shape}")

In [None]:
# The selection keys live in variables so the legend and title always describe
# the spectrum that is actually plotted (they previously said "Redshift 0.3"
# while the data came from redshift 2.4).
example_redshift, example_physics, example_index = '2.4', '4', 100
flux = data_by_redshift[example_redshift][example_physics][example_index]
wavelength = np.arange(len(flux))  # pixel index; no physical wavelength grid is loaded here

# Plot the spectrum
plt.figure(figsize=(20, 4))
plt.plot(wavelength, flux,
         label=f'Spectrum (Physics {example_physics}, Redshift {example_redshift})',
         color='black', lw=1.5)

plt.xlabel('Wavelength index', fontsize=12)  # Use 'Wavelength' if real wavelength data is available
plt.ylabel('Flux', fontsize=12)
plt.title(f'Spectrum Plot: Physics {example_physics} at Redshift {example_redshift}', fontsize=14)
plt.grid(alpha=0.3)
plt.show()


In [None]:
#2. ERD

In [None]:
""" DATA SHAPE AND SIZE AS WELL AS GLOBAL MIN/MAX VALUE """
# Compute the global minimum and maximum values across all spectra in data_by_redshift
all_flux_values = np.concatenate([np.concatenate(list(physics_dict.values())) 
                                  for physics_dict in data_by_redshift.values()])
print(f"Total {all_flux_values.shape[0]} spectra, of {all_flux_values.shape[1]} features")

MIN_FLUX = np.min(all_flux_values)
MAX_FLUX = np.max(all_flux_values)
print(f"Global min flux value, max flux value: {MIN_FLUX}, {MAX_FLUX}")

In [None]:
# Normalize each spectrum in the range [0, 255]
def normalize_spectrum(spectrum, min_flux=None, max_flux=None):
    """Linearly rescale a spectrum to 8-bit grey levels.

    Parameters
    ----------
    spectrum : numpy array of flux values.
    min_flux, max_flux : optional bounds mapped to 0 and 255. When omitted, the
        notebook-level MIN_FLUX / MAX_FLUX are used, as before.

    Returns
    -------
    numpy.ndarray of dtype uint8 with the same shape as ``spectrum``.

    Notes
    -----
    Values are truncated (not rounded) by the uint8 cast and are not clipped,
    so inputs outside [min_flux, max_flux] do not saturate cleanly.
    """
    lo = MIN_FLUX if min_flux is None else min_flux
    hi = MAX_FLUX if max_flux is None else max_flux
    span = hi - lo
    if span == 0:
        # Flat range: avoid 0/0 (NaN cast to uint8 is undefined) and return black.
        return np.zeros(np.shape(spectrum), dtype=np.uint8)
    normalized = 255 * (spectrum - lo) / span
    return normalized.astype(np.uint8)
# One normalized spectrum wrapped in a list so imshow treats it as a 1-row image.
sample_spectra = [
    normalize_spectrum(data_by_redshift['2.4']['4'][100]),
]

# Render the row as a thin grey-level strip without axes.
plt.figure(figsize=(20, 0.2))
plt.imshow(
    sample_spectra,
    interpolation='none',
    aspect="auto",
    cmap="gray",
)
plt.axis("off")
plt.show()

In [None]:
""" GRAYSCALE REPRESENTATION """
def into_grayscale(redshift, data_dict, num_samples):
    """Show random spectra of each physics class as stacked grey-level rows.

    Parameters
    ----------
    redshift : label used only in the figure title.
    data_dict : mapping physics label -> sequence/array of spectra.
    num_samples : maximum number of spectra drawn (without replacement) per class.

    One figure is shown per non-empty class; nothing is returned.
    """
    for physics, spectra in data_dict.items():
        n_available = len(spectra)  # len() also works for list placeholders, unlike .size
        n_shown = min(num_samples, n_available)
        if n_shown <= 0:
            continue

        indices = random.sample(range(n_available), n_shown)
        stacked_samples = np.vstack([normalize_spectrum(spectra[idx]) for idx in indices])

        # Plot the stacked samples
        plt.figure(figsize=(20, 4))
        plt.imshow(stacked_samples, cmap="gray", aspect="auto", interpolation='none')
        plt.axis("off")  # Hide axes
        # Report the number of rows actually drawn instead of a hard-coded 20.
        plt.title(f"Redshift {redshift}, Physics {physics} : {n_shown} Arrays Concatenated Vertically")
        plt.show()

# Sample and visualize 20 spectra per redshift and physics class
# Grey-level preview: up to 20 random spectra for every (redshift, physics) pair.
num_samples = 20
for redshift, physics_dict in data_by_redshift.items():
    print(f"{'-' * 60}IN REDSHIFT {redshift}{'-' * 60}")
    into_grayscale(redshift, physics_dict, num_samples)

In [None]:
""" LOCAL MINIMA IDENTIFICATION PER EACH CLASS """

def local_minima(flux):
    """Find the strict interior local minima of a 1-D flux array.

    A sample counts as a minimum when it is strictly lower than both of its
    neighbours; the first and last samples are never reported, and flat
    plateaus are not detected.

    Returns
    -------
    (indices, values) : positions of the minima in ``flux`` and the flux there.
    """
    middle = flux[1:-1]
    below_left = middle < flux[:-2]
    below_right = middle < flux[2:]
    # +1 converts positions within `middle` back to positions within `flux`.
    positions = np.where(below_left & below_right)[0] + 1
    return positions, flux[positions]

# Example of local minima identification of a spectrum 
# Example array (flux values)
# Worked example: strict local minima of a single spectrum.
sample_spectrum = data_by_redshift["0.3"]["4"][10000]
sample_local_minima_indices, sample_local_minima_values = local_minima(sample_spectrum)

# Positions, flux values and total number of minima found.
print(sample_local_minima_indices, sample_local_minima_values, sep="\n")
print(f"This spectrum has {len(sample_local_minima_values)} local minimas")

# Overlay the detected minima on the spectrum.
plt.figure(figsize=(14, 4))
plt.plot(sample_spectrum, color="blue", linewidth=1, label="Flux")
plt.scatter(
    sample_local_minima_indices,
    sample_local_minima_values,
    color="red",
    zorder=5,
    label="Local Minima",
)
plt.title("Flux Array with Local Minima")
plt.xlabel("Wavelength Index")
plt.ylabel("Flux Value")
plt.legend()
plt.grid(True)
plt.show()


In [None]:
""" DISTRIBUTION OF LOCAL MINIMA COUNTS """

""" INDIVIDUAL BAR GRAPH FOR EACH REDSHIFT,PHYSICS """
def local_minima_occurances(spectra, redshift, physics):
    """Summarise and plot how many local minima each spectrum has.

    Parameters
    ----------
    spectra : iterable of 1-D flux arrays.
    redshift, physics : labels used in the histogram title.

    Returns
    -------
    (mean, std, median, min, max) of the per-spectrum local-minima counts.
    With no spectra the result is (nan, nan, nan, 0, 0) and nothing is plotted.
    """
    local_minima_counts = [len(local_minima(spectrum)[0]) for spectrum in spectra]

    if not local_minima_counts:
        print(f"No spectra available for redshift {redshift}, physics {physics}")
        return float("nan"), float("nan"), float("nan"), 0, 0

    # Analyze the distribution of the counts
    minima_count_mean = np.mean(local_minima_counts)
    minima_count_std = np.std(local_minima_counts)
    minima_count_median = np.median(local_minima_counts)
    minima_count_min = np.min(local_minima_counts)
    minima_count_max = np.max(local_minima_counts)

    print(f"Mean count of local minima: {minima_count_mean:.2f}")
    print(f"Standard deviation of local minima counts: {minima_count_std:.2f}")
    print(f"Median count of local minima: {minima_count_median:.2f}")
    print(f"Min count of local minima: {minima_count_min}")
    print(f"Max count of local minima: {minima_count_max}")

    # One bin per distinct count value. When every spectrum has the same count
    # the range is 0, which would give an illegal zero-width figure and zero
    # bins, so fall back to a small figure with a single bin.
    count_range = int(minima_count_max - minima_count_min)
    fig_width = count_range / 20 if count_range > 0 else 3
    plt.figure(figsize=(fig_width, 3))
    plt.hist(local_minima_counts, bins=max(count_range, 1), color='skyblue', edgecolor='black')
    plt.title(f"Redshift {redshift}, Physics {physics}")
    plt.xlabel("Number of Local Minima")
    plt.ylabel("Frequency")
    plt.grid(True)
    plt.show()

    return minima_count_mean, minima_count_std, minima_count_median, minima_count_min, minima_count_max

# One summary table per redshift: a row of count statistics for each physics class.
results = {str(key): pd.DataFrame() for key in data_by_redshift}
for redshift, physics_dict in data_by_redshift.items():
    print(f"\n{'-' * 40}local minima count distribution for redshift {redshift}{'-' * 40}")
    curr_result = []
    for physics, data in physics_dict.items():
        # Prints the statistics, draws the histogram and returns the numbers.
        mean_count, std_dev, median_count, min_count, max_count = local_minima_occurances(
            data, redshift, physics
        )
        curr_result.append(dict(zip(
            ("Physics", "Mean Count", "Standard Deviation", "Median Count", "Min Count", "Max Count"),
            (physics, mean_count, std_dev, median_count, min_count, max_count),
        )))
    results[redshift] = pd.DataFrame(curr_result)
    print(results[redshift])

In [None]:
""" Local Minima Distribution """

""" Since the bar plot above is too long and the length varies, cut the graph into half """
def local_minima_occurances_below_75(spectra, redshift, physics):
    """Statistics of local-minima counts for spectra with at most 75 minima.

    ``redshift`` and ``physics`` are accepted for call-site symmetry but are
    not used here.

    Returns
    -------
    (mean, std, median, min, max, (unique_counts, frequencies), n_kept).
    When no spectrum qualifies the result is (0, 0, 0, 0, 0, [], 0).
    """
    per_spectrum = (len(local_minima(spectrum)[0]) for spectrum in spectra)
    local_minima_counts = [count for count in per_spectrum if count <= 75]

    if not local_minima_counts:
        return 0, 0, 0, 0, 0, [], 0

    unique_counts, counts_frequency = np.unique(local_minima_counts, return_counts=True)

    return (
        np.mean(local_minima_counts),
        np.std(local_minima_counts),
        np.median(local_minima_counts),
        np.min(local_minima_counts),
        np.max(local_minima_counts),
        (unique_counts, counts_frequency),
        len(local_minima_counts),
    )


# Per redshift: overlay the count distributions (<= 75 minima) of all physics
# classes in one figure and tabulate their summary statistics.
results = {str(key): pd.DataFrame() for key in data_by_redshift}

for redshift, physics_dict in data_by_redshift.items():
    print(f"\n{'-' * 40} Local minima count distribution for redshift {redshift} {'-' * 40}")
    curr_result = []

    plt.figure(figsize=(12, 6))
    plt.title(f"Redshift {redshift}: Distribution of Local Minima Counts Below 75")
    plt.xlabel("Number of Local Minima")
    plt.ylabel("Frequency")
    plt.grid(True, linestyle='--', alpha=0.7)

    for physics, data in physics_dict.items():
        (mean_count, std_dev, median_count, min_count, max_count,
         freq_data, size) = local_minima_occurances_below_75(data, redshift, physics)

        curr_result.append(dict(zip(
            ("Physics", "Mean Count", "Standard Deviation", "Median Count",
             "Min Count", "Max Count", "Size"),
            (physics, mean_count, std_dev, median_count, min_count, max_count, size),
        )))

        # freq_data is an empty list when no spectrum of this class qualified.
        if len(freq_data) != 0:
            unique_counts, counts_frequency = freq_data
            plt.plot(unique_counts, counts_frequency, label=f"Physics: {physics}", marker='o')

    plt.legend()
    plt.show()

    results[redshift] = pd.DataFrame(curr_result)
    print(results[redshift])


In [None]:
def local_minima_occurances_above_75(spectra, redshift, physics):
    """Statistics of local-minima counts for spectra with more than 75 minima.

    ``redshift`` and ``physics`` are accepted for call-site symmetry but are
    not used here.

    Returns
    -------
    (mean, std, median, min, max, (unique_counts, frequencies), n_kept).
    When no spectrum qualifies the result is (0, 0, 0, 0, 0, ([], []), 0).
    """
    per_spectrum = (len(local_minima(spectrum)[0]) for spectrum in spectra)
    local_minima_counts = [count for count in per_spectrum if count > 75]

    if not local_minima_counts:
        return 0, 0, 0, 0, 0, ([], []), 0

    unique_counts, counts_frequency = np.unique(local_minima_counts, return_counts=True)

    return (
        np.mean(local_minima_counts),
        np.std(local_minima_counts),
        np.median(local_minima_counts),
        np.min(local_minima_counts),
        np.max(local_minima_counts),
        (unique_counts, counts_frequency),
        len(local_minima_counts),
    )


# Per redshift: overlay the count distributions (> 75 minima) of all physics
# classes in one figure and tabulate their summary statistics.
results = {str(redshift): pd.DataFrame() for redshift, _ in data_by_redshift.items()}

for redshift, physics_dict in data_by_redshift.items():
    print("\n" + "-" * 40 + f" Local minima count distribution for redshift {redshift} " + "-" * 40)
    curr_result = []
    plotted_any = False

    plt.figure(figsize=(12, 6))
    plt.title(f"Redshift {redshift}: Local Minima Counts Above 75")
    plt.xlabel("Number of Local Minima")
    plt.ylabel("Frequency")
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    for physics, data in physics_dict.items():
        mean_count, std_dev, median_count, min_count, max_count, freq_data, size = local_minima_occurances_above_75(
            data, redshift, physics
        )

        curr_result.append({
            "Physics": physics,
            "Mean Count": mean_count,
            "Standard Deviation": std_dev,
            "Median Count": median_count,
            "Min Count": min_count,
            "Max Count": max_count,
            "Size": size
        })

        if len(freq_data[0]) > 0:
            unique_counts, counts_frequency = freq_data
            plt.plot(unique_counts, counts_frequency, label=f"Physics: {physics}", marker='o')
            plotted_any = True

    results[redshift] = pd.DataFrame(curr_result)

    # Only add a legend when at least one curve exists (avoids the
    # "no artists with labels" warning).
    if plotted_any:
        plt.legend()

    # "Max Count" is 0 for classes with no spectrum above 75. Setting
    # xlim(75, 5) in that case would invert the axis, and max() on an empty
    # list would raise, so the limit is applied only when it is meaningful.
    largest_count = max((row["Max Count"] for row in curr_result), default=0)
    if largest_count > 75:
        plt.xlim(75, largest_count + 5)

    plt.show()

    print(results[redshift])



In [None]:
""" Local Minima Occurrence Index Distribution """
## At which indices do local minima happen mostly?

def get_all_minima_indices(spectra):
    """Collect the local-minimum positions of every spectrum into one flat list.

    Positions are kept in spectrum order and repeated once per spectrum in
    which they occur, so the list can be fed straight into a frequency count.
    """
    return [
        position
        for flux in spectra
        for position in local_minima(flux)[0]
    ]
    
def get_top_minima_indices(all_minima_indices):
    """Print and return the ten most frequent local-minimum positions.

    Returns
    -------
    list of (index, occurrences) pairs, most frequent first; ties keep the
    order in which the indices were first encountered.
    """
    top_indices = Counter(all_minima_indices).most_common(10)
    print("Top 10 indices with most local minima:")
    for position, occurrences in top_indices:
        print(f"Index {position}: {occurrences} occurrences")
    return top_indices
    
def local_minima_counter_histogram(all_minima_indices):
    """Histogram of how often each wavelength index hosts a local minimum.

    Parameters
    ----------
    all_minima_indices : flat list of local-minimum positions (with repeats).

    NOTE(review): the title reads the notebook-level ``redshift`` and
    ``physics`` variables, so it shows whatever the enclosing loop last
    assigned to them.
    """
    index_frequencies = Counter(all_minima_indices)
    if not index_frequencies:
        # zip(*[]) cannot be unpacked into two names; there is nothing to draw.
        print("No local minima to plot")
        return

    indices, frequencies = zip(*index_frequencies.items())

    # Roughly ten distinct indices per bin, but never fewer than one bin
    # (bins=0 is rejected by matplotlib when there are < 10 distinct indices).
    n_bins = max(1, int(len(indices) / 10))

    plt.figure(figsize=(12, 3))
    plt.hist(indices, bins=n_bins, weights=frequencies, color='skyblue', edgecolor='black')
    plt.xlabel("Index")
    plt.ylabel("Frequency")
    plt.grid(True)
    plt.title(f"Redshift {redshift}, Physics {physics} Histogram: Frequency of Local Minima Across Indices")
    plt.show()
    

def local_minima_counter_colorbar(all_minima_indices, n_features=2048):
    """Draw a one-row colour strip of local-minimum frequency per index.

    Parameters
    ----------
    all_minima_indices : flat list of local-minimum positions (with repeats).
    n_features : minimum number of index positions on the x axis (default
        2048, the previously hard-coded length). The axis grows automatically
        if an index beyond this length is present, instead of raising
        IndexError.

    NOTE(review): the title reads the notebook-level ``redshift`` and
    ``physics`` variables, so it shows whatever the enclosing loop last
    assigned to them.
    """
    positions = np.asarray(all_minima_indices, dtype=np.intp)
    # bincount counts occurrences per index and pads with zeros up to n_features.
    frequencies = np.bincount(positions, minlength=n_features).astype(float)
    indices = np.arange(len(frequencies))
    norm = plt.Normalize(vmin=frequencies.min(), vmax=frequencies.max())

    plt.figure(figsize=(20, 2))
    plt.scatter(indices, np.zeros_like(indices), c=frequencies, cmap='viridis', norm=norm, s=10)

    cbar = plt.colorbar()
    cbar.set_label("Frequency of Local Minima", rotation=270, labelpad=15)

    plt.title(f"Redshift {redshift}, Physics {physics} Colorbar: Frequency of Local Minima Across Indices")
    plt.xlabel("Index")
    plt.yticks([])
    plt.gca().axes.get_yaxis().set_visible(False)  # the y axis carries no information
    plt.grid(axis='x', linestyle='--', alpha=0.5)

    plt.show()



# Per redshift: plot where local minima occur for every physics class and
# tabulate the most frequent positions (one row per class).
results = {str(redshift):pd.DataFrame() for redshift, _ in data_by_redshift.items()}
for redshift, physics_dict in data_by_redshift.items():
    curr_result = []
    row_labels = []
    for physics, data in physics_dict.items():
        print("-"*60 + f"Redshift {redshift}, Physics {physics}", "-"*60)
        minima_indices = get_all_minima_indices(data)
        local_minima_counter_histogram(minima_indices)
        local_minima_counter_colorbar(minima_indices)
        curr_result.append([index for index,_ in get_top_minima_indices(minima_indices)])
        # Label rows with the real physics key so a missing class cannot shift
        # the labels of the classes that follow it.
        row_labels.append(f"Physics{physics}")
    print(f"Redshift {redshift}: Top 10 popular wavelengh indices with the most local minima occurances")
    top_table = pd.DataFrame(curr_result, index=row_labels)
    # Rank columns 1..k; k can be below 10 when fewer distinct indices exist.
    top_table.columns = list(range(1, top_table.shape[1] + 1))
    results[redshift] = top_table
    print(results[redshift])

## Conclusion: they happen everywhere


In [None]:
""" FURTHER EXPLORATION: IDENTIFY LOCAL MINIMA CENTER, WIDTH, DEPTH """
""" The method to fit the wells needs to be advised """
# Example flux array
# Work on the first 500 samples of one spectrum to keep the plot readable.
flux = data_by_redshift["0.3"]["4"][10000][:500]

# Strict interior minima, found with the same rule as the helper defined above.
local_minima_indices = local_minima(flux)[0]

# Filled below with one width per detected minimum.
minima_widths = []

def calculate_minimum_width(flux, minima_index):
    """Width (in samples) of the absorption well around a local minimum.

    Starting from the samples next to the minimum, each side is followed
    outwards for as long as the flux keeps rising, i.e. up to the nearest
    local maximum (or the array edge). The width is the distance between the
    two stopping points.

    Parameters
    ----------
    flux : 1-D array of flux values.
    minima_index : position of an interior local minimum in ``flux``.

    Returns
    -------
    int : right stopping index minus left stopping index.
    """
    left_index = minima_index - 1
    right_index = minima_index + 1

    # Climb the left wall: keep moving left while the next sample is higher.
    # (The comparison was previously inverted, so the walk either stopped at
    # once or ran down the far side of the neighbouring peak.)
    while left_index > 0 and flux[left_index - 1] > flux[left_index]:
        left_index -= 1

    # Climb the right wall: keep moving right while the next sample is higher.
    while right_index < len(flux) - 1 and flux[right_index] < flux[right_index + 1]:
        right_index += 1

    return right_index - left_index

# One width per detected minimum, in the same order as local_minima_indices.
minima_widths.extend(
    calculate_minimum_width(flux, position) for position in local_minima_indices
)

print(local_minima_indices)
print(minima_widths)

# Spectrum with minima marked and annotated by their widths.
plt.figure(figsize=(14, 6))
plt.plot(flux, color="blue", linewidth=1, label="Flux")
plt.scatter(
    local_minima_indices,
    flux[local_minima_indices],
    color="red",
    zorder=3,
    label="Local Minima",
)

for index, width in zip(local_minima_indices, minima_widths):
    plt.text(index, flux[index], f"{width:.0f}", color="green", fontsize=8)

plt.title("Flux Array with Local Minima and Widths")
plt.xlabel("Wavelength Index")
plt.ylabel("Flux Value")
plt.legend()
plt.grid(True)

plt.show()

# Text summary: position, flux at the minimum (reported as "Depth") and width.
for index, width in zip(local_minima_indices, minima_widths):
    print(f"Local Minimum at Index {index}, Depth: {flux[index]:.4f}, Width: {width}")

# Not really clear how subtle the absorption well should be to be disregarded

In [None]:
# 3. Dataset

In [None]:
# 3-a. Original dataset by redshift: spectra = {'0.1': [[]], '0.3': [[]], ... }
# Flat per-redshift dataset: spectra[redshift] stacks the spectra of every
# physics class, spectra_labels[redshift] holds the matching physics label.
spectra = {str(redshift): [] for redshift in redshift_values}
spectra_labels = {str(redshift): [] for redshift in redshift_values}

for redshift, physics_dict in data_by_redshift.items():
    class_blocks = []
    label_blocks = []
    for physics, fluxes in physics_dict.items():
        n_rows = len(fluxes)
        if n_rows == 0:
            continue  # nothing to contribute for this class
        class_blocks.append(np.asarray(fluxes))
        label_blocks.append(np.repeat(physics, n_rows))  # one label per spectrum
    # Stack whole class arrays at once instead of appending row by row.
    spectra[redshift] = np.concatenate(class_blocks) if class_blocks else np.array([])
    spectra_labels[redshift] = np.concatenate(label_blocks) if label_blocks else np.array([])


In [None]:
# Concatenate all spectra arrays across redshifts
# Stack the per-redshift arrays into one (n_spectra, n_features) array.
all_spectra = np.concatenate(list(spectra.values()))
print(all_spectra.shape)

In [None]:
# Reduced data for each redshifts: "pca_data_by_redshift"
# For every redshift, fit a full PCA on all spectra of that redshift and record
# how many leading components are needed to reach each cumulative
# explained-variance level.
# Percent label -> fraction; kept as literals so the thresholds match exactly.
variance_levels = {95: 0.95, 90: 0.90, 85: 0.85, 80: 0.80,
                   75: 0.75, 70: 0.70, 65: 0.65, 60: 0.60}
n_components_by_level = {level: {} for level in variance_levels}

for redshift, curr_spectra in spectra.items():
    print(f"\nPCA analysis for redshift {redshift}")
    print(f"Total {curr_spectra.shape[0]} spectra, each of {curr_spectra.shape[1]} long")
    pca = PCA().fit(curr_spectra)  # Fit PCA without specifying n_components
    cumulative_variance = np.cumsum(pca.explained_variance_ratio_)

    # Create the figure BEFORE plotting so the requested size applies to the
    # curve (it was previously created afterwards, leaving an empty figure).
    plt.figure(figsize=(4, 2.5))
    plt.plot(cumulative_variance)
    plt.xlabel('Number of Components')
    plt.ylabel('Cumulative Explained Variance')
    plt.title('Explained Variance vs. Number of Components')
    plt.show()

    for level, fraction in variance_levels.items():
        # First position where the cumulative curve reaches the level; +1
        # converts the zero-based position into a component count.
        n_components = np.where(cumulative_variance >= fraction)[0][0] + 1
        n_components_by_level[level][redshift] = n_components
        print(f"Number of components to retain {level}% variance for redshift {redshift} : {n_components}")

# Names used by the later cells: {redshift: n_components} per variance level.
var_95_n_components = n_components_by_level[95]
var_90_n_components = n_components_by_level[90]
var_85_n_components = n_components_by_level[85]
var_80_n_components = n_components_by_level[80]
var_75_n_components = n_components_by_level[75]
var_70_n_components = n_components_by_level[70]
var_65_n_components = n_components_by_level[65]
var_60_n_components = n_components_by_level[60]

In [None]:
def pca_transform(dict_n_components, fit_per_class=True):
    """Project the spectra of every redshift onto a reduced PCA basis.

    Parameters
    ----------
    dict_n_components : mapping redshift -> number of components to keep.
    fit_per_class : if True (default, the original behaviour) a separate PCA
        is fitted on each physics class. If False, one PCA is fitted on all
        classes of a redshift together and then applied to each class.

    Returns
    -------
    dict : {redshift: {physics: array of shape (n_spectra, n_components)}},
    containing only the (redshift, physics) pairs that actually have data.

    Notes
    -----
    NOTE(review): with ``fit_per_class=True`` every class gets its own
    principal axes, so column k of one class is not the same feature as
    column k of another. A classifier trained on the stacked result is
    comparing coordinates from different bases; prefer
    ``fit_per_class=False`` when the output is used for classification.
    """
    reduced_data_by_redshift = {}
    for redshift, physics_dict in data_by_redshift.items():
        n_components = dict_n_components[redshift]  # Get n_components for this redshift
        print(f"\nApplying PCA on redshift {redshift} with n_components = {n_components}")

        # Keep only classes with data. This replaces the hard-coded
        # `del ...['0.1']['4']`, which raised KeyError when that key was absent
        # and discarded real results when the data was present.
        available = {physics: data for physics, data in physics_dict.items() if len(data) > 0}

        shared_pca = None
        if not fit_per_class and available:
            shared_pca = PCA(n_components=n_components).fit(np.concatenate(list(available.values())))

        reduced = {}
        for physics, data in available.items():
            if shared_pca is None:
                pca_transformed = PCA(n_components=n_components).fit_transform(data)
            else:
                pca_transformed = shared_pca.transform(data)
            reduced[physics] = pca_transformed
            print(f"Transformed shape for physics {physics}: {pca_transformed.shape}")

        reduced_data_by_redshift[redshift] = reduced

    return reduced_data_by_redshift

In [None]:
# Reduce every class using the component counts recorded for the 95% level.
print("Reducing data with 95% variance")
pca_95_data_by_redshift = pca_transform(var_95_n_components)   # var_95_n_components = {'0.1':111, '0.3':98, '2.2':106, '2.4':109}

In [None]:
# Reduce every class using the component counts recorded for the 90% level.
print("Reducing data with 90% variance")
pca_90_data_by_redshift = pca_transform(var_90_n_components)    # var_90_n_components = {'0.1':84, '0.3':73, '2.2':81, '2.4':83}

In [None]:
# Reduce every class using the component counts recorded for the 85% level.
print("Reducing data with 85% variance")
pca_85_data_by_redshift = pca_transform(var_85_n_components)  

In [None]:
# Reduce every class using the component counts recorded for the 80% level.
print("Reducing data with 80% variance")
pca_80_data_by_redshift = pca_transform(var_80_n_components)    

In [None]:
# Reduce every class using the component counts recorded for the 75% level.
print("Reducing data with 75% variance")
pca_75_data_by_redshift = pca_transform(var_75_n_components)   

In [None]:
# Reduce every class using the component counts recorded for the 70% level.
print("Reducing data with 70% variance")
pca_70_data_by_redshift = pca_transform(var_70_n_components)   

In [None]:
# Reduce every class using the component counts recorded for the 65% level.
print("Reducing data with 65% variance")
pca_65_data_by_redshift = pca_transform(var_65_n_components)   

In [None]:
# Reduce every class using the component counts recorded for the 60% level.
print("Reducing data with 60% variance")
pca_60_data_by_redshift = pca_transform(var_60_n_components)   

In [None]:
""" How PCA transformation looks like for spectrum data """
num_samples = 3  # Number of spectra to visualize per redshift and physics

for redshift, physics_dict in data_by_redshift.items():
    for physics, original_spectra in physics_dict.items():

        sample_indices = np.random.choice(len(original_spectra), num_samples, replace=False)
        original_samples = np.array([original_spectra[idx] for idx in sample_indices])

        # Project and reconstruct with the SAME fitted model. Previously the
        # coefficients came from an earlier, separate fit while the inverse
        # transform used a model re-fitted here. Two fits are not guaranteed to
        # share identical axes (e.g. sign flips, or a randomized SVD solver),
        # which corrupts the reconstruction.
        pca_90 = PCA(n_components=var_90_n_components[redshift])
        pca_90.fit(original_spectra)
        pca_90_samples = pca_90.transform(original_samples)
        reconstructed_90_samples = pca_90.inverse_transform(pca_90_samples)

        pca_95 = PCA(n_components=var_95_n_components[redshift])
        pca_95.fit(original_spectra)
        pca_95_samples = pca_95.transform(original_samples)
        reconstructed_95_samples = pca_95.inverse_transform(pca_95_samples)

        # Figure 1: original vs reconstructed spectra, with error metrics.
        fig, axes = plt.subplots(num_samples, 3, figsize=(15, num_samples * 2))
        fig.suptitle(f"Redshift {redshift}, Physics {physics} - Original vs PCA-90% vs PCA-95% Reconstructed Spectra", fontsize=14)

        for i in range(num_samples):
            mse_90 = mean_squared_error(original_samples[i], reconstructed_90_samples[i])
            mse_95 = mean_squared_error(original_samples[i], reconstructed_95_samples[i])
            cos_sim_90 = cosine_similarity(original_samples[i].reshape(1, -1), reconstructed_90_samples[i].reshape(1, -1))[0, 0]
            cos_sim_95 = cosine_similarity(original_samples[i].reshape(1, -1), reconstructed_95_samples[i].reshape(1, -1))[0, 0]

            axes[i, 0].plot(original_samples[i], color='orange')
            axes[i, 0].set_title("Original")

            axes[i, 1].plot(reconstructed_90_samples[i], color='green')
            axes[i, 1].set_title("PCA-90% Reconstructed")
            axes[i, 1].text(0.15, -0.0, f"MSE: {mse_90:.4f}\nCosine Sim: {cos_sim_90:.4f}", transform=axes[i, 1].transAxes,
                            fontsize=8, ha='center', va='bottom',
                            bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.3))

            axes[i, 2].plot(reconstructed_95_samples[i], color='blue')
            axes[i, 2].set_title("PCA-95% Reconstructed")
            axes[i, 2].text(0.15, -0.0, f"MSE: {mse_95:.4f}\nCosine Sim: {cos_sim_95:.4f}", transform=axes[i, 2].transAxes,
                            fontsize=8, ha='center', va='bottom',
                            bbox=dict(boxstyle="round,pad=0.3", edgecolor="green", facecolor="lightgreen", alpha=0.3))

        plt.tight_layout(rect=[0, 0, 1, 0.96])  # leave room for the suptitle
        plt.show()

        # Figure 2: original spectra next to their reduced coefficient vectors.
        fig, axes = plt.subplots(num_samples, 3, figsize=(15, num_samples * 2))
        fig.suptitle(f"Redshift {redshift}, Physics {physics} - Original vs PCA-90% vs PCA-95% Reduced Spectra", fontsize=14)

        for i in range(num_samples):
            axes[i, 0].plot(original_samples[i], color='orange')
            axes[i, 0].set_title("Original")

            axes[i, 1].plot(pca_90_samples[i], color='green')
            axes[i, 1].set_title("PCA-90% Reduced")

            axes[i, 2].plot(pca_95_samples[i], color='blue')
            axes[i, 2].set_title("PCA-95% Reduced")

        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.show()

In [None]:
# 4. Train SVM with k-fold cross validation and hyperparameter tuning

In [None]:
def dump_gs_results_analysis(gs_results, excel_file):
    """Write grid-search tables to an Excel workbook and colour notable rows.

    Parameters
    ----------
    gs_results : mapping redshift -> fitted grid-search object exposing
        ``cv_results_`` with ``mean_train_score`` and ``mean_test_score``.
    excel_file : path of the workbook to create (one sheet per redshift,
        named ``redshift_<redshift>``).

    Each sheet gets two extra columns: ``overfit`` (train minus test score)
    and ``highlight``. Rows with a perfect training score are marked 'red';
    rows with the smallest train/test gap are marked 'green', which takes
    precedence when both apply.
    """
    gs_df = {}
    for redshift, grid_search in gs_results.items():
        table = pd.DataFrame.from_dict(grid_search.cv_results_)
        table['overfit'] = table['mean_train_score'] - table['mean_test_score']

        gap = (table['mean_test_score'] - table['mean_train_score']).abs()
        perfect_train = table['mean_train_score'] == 1
        smallest_gap = gap == gap.min()

        table['highlight'] = ''
        table.loc[perfect_train, 'highlight'] = 'red'
        table.loc[smallest_gap, 'highlight'] = 'green'  # applied last, so it wins

        gs_df[redshift] = table

    # First pass: dump the plain tables.
    with pd.ExcelWriter(excel_file, engine='openpyxl') as writer:
        for redshift, table in gs_df.items():
            table.to_excel(writer, sheet_name=f'redshift_{redshift}', index=False)

    # Second pass: reopen the workbook and fill the flagged rows.
    fills = {
        'red': PatternFill(start_color="E6B8B7", end_color="E6B8B7", fill_type="solid"),
        'green': PatternFill(start_color="B7DEE8", end_color="B7DEE8", fill_type="solid"),
    }
    wb = load_workbook(excel_file)
    for redshift, table in gs_df.items():
        sheet = wb[f'redshift_{redshift}']
        n_columns = len(table.columns)
        # Row 1 is the header, so data rows start at 2.
        for row_idx, highlight in enumerate(table['highlight'], start=2):
            fill = fills.get(highlight)
            if fill is None:
                continue
            for col_idx in range(1, n_columns + 1):
                sheet.cell(row=row_idx, column=col_idx).fill = fill

    wb.save(excel_file)


def svm_tune_hyperparam(param_grid, size_per_class, pca_ed_data_by_redshift, n_comp, random_state=None):
    """Grid-search an SVC per redshift on a balanced random subsample.

    Parameters
    ----------
    param_grid : parameter grid passed to GridSearchCV.
    size_per_class : number of spectra drawn (without replacement) from each
        physics class.
    pca_ed_data_by_redshift : {redshift: {physics: 2-D feature array}}.
    n_comp : label used only in the output file name
        ``gs_results_<n_comp>_<size_per_class>.xlsx``.
    random_state : optional seed for the subsampling and the K-fold shuffle.
        ``None`` (default) keeps the original unseeded behaviour.

    Returns
    -------
    dict : redshift -> fitted GridSearchCV object.

    Raises
    ------
    ValueError : if a class holds fewer than ``size_per_class`` spectra.
    """
    grid_search_results = {}
    # Fall back to NumPy's global generator when unseeded, as before.
    sampler = np.random if random_state is None else np.random.RandomState(random_state)

    for redshift, physics_dict in pca_ed_data_by_redshift.items():

        print(f"\nTraining SVM for redshift {redshift}")
        X = []
        y = []
        print(f"redshift: {redshift}")
        for physics, data in physics_dict.items():
            if data.shape[0] < size_per_class:
                raise ValueError(
                    f"Redshift {redshift}, physics {physics}: only {data.shape[0]} spectra "
                    f"available, cannot sample {size_per_class} without replacement"
                )
            indices = sampler.choice(data.shape[0], size=size_per_class, replace=False)
            X.append(data[indices])
            y.extend([physics] * size_per_class)

        X = np.vstack(X)
        y = np.array(y).astype(int)

        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=40)
        print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)

        grid_search = GridSearchCV(
            SVC(),
            param_grid,
            cv=KFold(n_splits=5, shuffle=True, random_state=random_state),
            scoring='accuracy',
            verbose=1,
            return_train_score=True,
        )

        print("Performing grid search...")
        grid_search.fit(X_train, y_train)

        best_params = grid_search.best_params_

        print(f"Best parameters for redshift {redshift}: {best_params}")
        print(f"Best cross-validation accuracy: {grid_search.best_score_:.4f}")
        # The 20% hold-out was previously split off but never evaluated.
        print(f"Held-out test accuracy: {grid_search.score(X_test, y_test):.4f}")

        grid_search_results[redshift] = grid_search
        # Rewritten after every redshift so partial results survive an interruption.
        dump_gs_results_analysis(grid_search_results, f"gs_results_{n_comp}_{size_per_class}.xlsx")

    return grid_search_results


In [None]:
# Define hyperparameter grid
# First, coarse grid over all three kernels.
# NOTE(review): SVC ignores `gamma` for the linear kernel, so the linear
# entries are fitted three times with identical results.
param_grid_1 = {
    'C': [1, 10, 100],
    'gamma': [0.1, 1, 10],
    'kernel': ['rbf', 'linear', 'poly']
}

# The grid defined above is `param_grid_1`; the bare name `param_grid` does
# not exist and raised NameError.
pca_85_300_gs_results = svm_tune_hyperparam(param_grid_1, 300, pca_85_data_by_redshift, 85)
pca_80_300_gs_results = svm_tune_hyperparam(param_grid_1, 300, pca_80_data_by_redshift, 80)


In [None]:
# Polynomial kernel only, smaller C values, 500 spectra per class.
param_grid_2 = dict(
    C=[1, 0.1, 0.01],
    gamma=[0.1, 1, 10],
    kernel=['poly'],
)

pca_85_500_gs_results = svm_tune_hyperparam(param_grid_2, 500, pca_85_data_by_redshift, 85)
pca_80_500_gs_results = svm_tune_hyperparam(param_grid_2, 500, pca_80_data_by_redshift, 80)


In [None]:
# Polynomial kernel, C shifted one decade lower, 800 spectra per class.
param_grid_3 = dict(
    C=[0.1, 0.01, 0.001],
    gamma=[0.1, 1, 10],
    kernel=['poly'],
)

pca_85_800_gs_results = svm_tune_hyperparam(param_grid_3, 800, pca_85_data_by_redshift, 85)
pca_80_800_gs_results = svm_tune_hyperparam(param_grid_3, 800, pca_80_data_by_redshift, 80)


In [None]:
# Polynomial kernel, C shifted another decade lower, 1000 spectra per class.
param_grid_4 = dict(
    C=[0.01, 0.001, 0.0001],
    gamma=[0.1, 1, 10],
    kernel=['poly'],
)

pca_85_1000_gs_results = svm_tune_hyperparam(param_grid_4, 1000, pca_85_data_by_redshift, 85)
pca_80_1000_gs_results = svm_tune_hyperparam(param_grid_4, 1000, pca_80_data_by_redshift, 80)


In [None]:
# Polynomial kernel, finer C values around 0.001, 2000 spectra per class.
param_grid_5 = {
    'C': [0.005, 0.001, 0.0005],
    'gamma': [0.1, 1, 10],
    'kernel': ['poly']
}

# These runs use 2000 spectra per class, so they are stored under *_2000_*
# names; the previous *_500_* names overwrote the results of the 500-sample runs.
pca_85_2000_gs_results = svm_tune_hyperparam(param_grid_5, 2000, pca_85_data_by_redshift, 85)
pca_80_2000_gs_results = svm_tune_hyperparam(param_grid_5, 2000, pca_80_data_by_redshift, 80)


In [None]:
# Same grid values as the previous param_grid_5 cell, applied to the 75% and
# 70% variance reductions (2000 spectra per class).
param_grid_5 = dict(
    C=[0.005, 0.001, 0.0005],
    gamma=[0.1, 1, 10],
    kernel=['poly'],
)

pca_75_2000_gs_results = svm_tune_hyperparam(param_grid_5, 2000, pca_75_data_by_redshift, 75)
pca_70_2000_gs_results = svm_tune_hyperparam(param_grid_5, 2000, pca_70_data_by_redshift, 70)


In [None]:
# Same grid values again, applied to the 65% and 60% variance reductions
# (2000 spectra per class).
param_grid_5 = dict(
    C=[0.005, 0.001, 0.0005],
    gamma=[0.1, 1, 10],
    kernel=['poly'],
)

pca_65_2000_gs_results = svm_tune_hyperparam(param_grid_5, 2000, pca_65_data_by_redshift, 65)
pca_60_2000_gs_results = svm_tune_hyperparam(param_grid_5, 2000, pca_60_data_by_redshift, 60)


In [None]:
# Small C and small gamma across all three kernels, 2500 spectra per class.
param_grid_6 = dict(
    C=[0.005, 0.001, 0.0005],
    gamma=[0.005, 0.01, 0.05],
    kernel=['linear', 'poly', 'rbf'],
)

pca_85_2500_gs_results = svm_tune_hyperparam(param_grid_6, 2500, pca_85_data_by_redshift, 85)
pca_80_2500_gs_results = svm_tune_hyperparam(param_grid_6, 2500, pca_80_data_by_redshift, 80)


In [None]:
# Same grid (param_grid_6, defined in an earlier cell) and sample size, applied
# to the 75% and 70% variance reductions.
pca_75_2500_gs_results = svm_tune_hyperparam(param_grid_6, 2500, pca_75_data_by_redshift, 75)
pca_70_2500_gs_results = svm_tune_hyperparam(param_grid_6, 2500, pca_70_data_by_redshift, 70)


In [None]:
# 5. Evaluation and Result Analysis

In [None]:
# Merge every gs_results_<n_comp>_<data_size>.xlsx workbook in base_dir into
# one workbook, sheet by sheet, tagging each row with its n_comp / data_size.
base_dir = "./"
output_file = "consolidated_gs_results.xlsx"

consolidated_data = {
    "redshift_0.1": [],
    "redshift_0.3": [],
    "redshift_2.2": [],
    "redshift_2.4": []
}

# sorted() makes the row order of the output independent of directory order.
for file_name in sorted(os.listdir(base_dir)):
    if not (file_name.endswith(".xlsx") and file_name.startswith("gs_results_")):
        continue
    print(f"Processing {file_name}")
    file_path = os.path.join(base_dir, file_name)

    # File names are expected to look like gs_results_<n_comp>_<data_size>.xlsx.
    name_parts = file_name.replace(".xlsx", "").split("_")[2:]
    if len(name_parts) != 2 or not all(part.isdigit() for part in name_parts):
        print(f"Skipping {file_name}: cannot parse n_comp/data_size from the name")
        continue
    n_comp, data_size = name_parts
    print(f"n_comp={n_comp}, data_size={data_size}")

    # sheet_name=None loads every sheet in one pass and closes the file afterwards.
    for sheet_name, df in pd.read_excel(file_path, sheet_name=None).items():
        print(f"On sheet {sheet_name}")
        df['n_comp'] = int(n_comp)
        df['data_size'] = int(data_size)
        # setdefault tolerates sheets for redshifts not listed above.
        consolidated_data.setdefault(sheet_name, []).append(df)

# pd.concat rejects an empty list and a workbook needs at least one sheet, so
# only sheets that received data are written.
non_empty = {name: frames for name, frames in consolidated_data.items() if frames}
if non_empty:
    with pd.ExcelWriter(output_file, engine='openpyxl') as writer:
        for sheet_name, data_list in non_empty.items():
            combined_df = pd.concat(data_list, ignore_index=True)
            combined_df.to_excel(writer, sheet_name=sheet_name, index=False)
    print(f"Consolidation complete! Results saved to {output_file}.")
else:
    print(f"No gs_results_*.xlsx files found in {base_dir}; nothing written.")


In [None]:
def performance_vs_overfit(excel_file="consolidated_gs_results.xlsx"):
    """Scatter-plot overfit against mean test score, one figure per sheet.

    Sheets lacking an 'overfit' or 'mean_test_score' column are skipped with
    a printed message.

    NOTE(review): a later cell defines performance_vs_overfit again (with a
    colour scale); on a top-to-bottom run that definition replaces this one.

    Parameters
    ----------
    excel_file : str, optional
        Workbook to read. Defaults to the consolidated grid-search results.
    """
    required_columns = ('overfit', 'mean_test_score')

    # Open the workbook once; the context manager closes the file handle.
    with pd.ExcelFile(excel_file) as workbook:
        for sheet in workbook.sheet_names:
            df = pd.read_excel(workbook, sheet_name=sheet)

            if any(col not in df.columns for col in required_columns):
                print(f"Skipping sheet {sheet}: required columns not found.")
                continue

            plt.figure(figsize=(8, 5))
            plt.scatter(df['mean_test_score'], df['overfit'], alpha=0.7, c='blue', label='Overfit vs Test Score')
            # Horizontal reference at overfit == 0.
            plt.axhline(0, color='red', linestyle='--', linewidth=1, label='No Overfit Line')

            plt.title(f"Overfit vs Mean Test Score for {sheet}")
            plt.xlabel("Mean Test Score")
            plt.ylabel("Overfit (Train - Test Score)")
            plt.legend()
            plt.grid(alpha=0.5)

            # Show each figure as it is produced, as factors_heatmap does.
            plt.show()

In [None]:
def factors_heatmap(file_path='consolidated_gs_results.xlsx', sheet_names=None):
    """Draw a correlation heatmap of tuning factors and scores per sheet.

    For each sheet the hyperparameter, score, n_comp and data_size columns are
    selected, non-numeric columns are one-hot encoded, and the pairwise
    correlation matrix is shown as an annotated heatmap.

    Sheets that are absent from the workbook, or that lack any of the required
    columns, are skipped with a printed message instead of raising.

    Parameters
    ----------
    file_path : str, optional
        Workbook to read. Defaults to the consolidated grid-search results.
    sheet_names : list of str, optional
        Sheets to plot. Defaults to the four redshift sheets.
    """
    if sheet_names is None:
        sheet_names = ['redshift_0.1', 'redshift_0.3', 'redshift_2.2', 'redshift_2.4']

    columns_to_include = ['param_C', 'param_gamma', 'param_kernel', 'mean_test_score',
                          'mean_train_score', 'overfit', 'n_comp', 'data_size']

    # Open the workbook once; the context manager closes the file handle.
    with pd.ExcelFile(file_path) as workbook:
        available_sheets = set(workbook.sheet_names)

        for sheet in sheet_names:
            if sheet not in available_sheets:
                print(f"Skipping sheet {sheet}: not found in {file_path}.")
                continue

            data = pd.read_excel(workbook, sheet_name=sheet)
            if any(col not in data.columns for col in columns_to_include):
                print(f"Skipping sheet {sheet}: required columns not found.")
                continue

            selected_data = data[columns_to_include]
            selected_data = pd.get_dummies(selected_data, drop_first=False)  # One-hot encoding
            correlation_matrix = selected_data.corr()

            plt.figure(figsize=(10, 8))
            sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', fmt=".2f", cbar=True)
            plt.title(f"Factors correlation Heatmap for {sheet}")
            plt.show()

# Draw one correlation heatmap per redshift sheet of consolidated_gs_results.xlsx.
factors_heatmap()

In [None]:
# Grid 7: polynomial kernel only, three C values x three gamma values
# (gamma between 0.01 and 0.1).
# NOTE(review): if svm_tune_hyperparam leaves SVC's coef0 at its default of 0,
# the poly kernel is (gamma * <x, x'>) ** degree, so gamma only rescales the
# kernel and trades off directly against C -- confirm before reading C and
# gamma as independent factors.
param_grid_7 = {
    'C': [0.005, 0.001, 0.01],  
    'gamma': [0.01, 0.03, 0.1],  
    'kernel': ['poly'] 
}

# One run per PCA dataset (85, 80, 75, 70 components); the second argument
# (3000) is presumably the data size -- confirm against svm_tune_hyperparam.
pca_85_3000_gs_results = svm_tune_hyperparam(param_grid_7, 3000, pca_85_data_by_redshift, 85)
pca_80_3000_gs_results = svm_tune_hyperparam(param_grid_7, 3000, pca_80_data_by_redshift, 80)
pca_75_3000_gs_results = svm_tune_hyperparam(param_grid_7, 3000, pca_75_data_by_redshift, 75)
pca_70_3000_gs_results = svm_tune_hyperparam(param_grid_7, 3000, pca_70_data_by_redshift, 70)


In [None]:
# Grid 8: polynomial kernel only, same C values as grid 7 but much larger
# gamma values (0.8 to 2).
param_grid_8 = {
    'C': [0.005, 0.001, 0.01],  
    'gamma': [0.8, 1, 2],  
    'kernel': ['poly'] 
}

# One run per PCA dataset (85, 80, 75, 70 components); the second argument
# (3300) is presumably the data size -- confirm against svm_tune_hyperparam.
pca_85_3300_gs_results = svm_tune_hyperparam(param_grid_8, 3300, pca_85_data_by_redshift, 85)
pca_80_3300_gs_results = svm_tune_hyperparam(param_grid_8, 3300, pca_80_data_by_redshift, 80)
pca_75_3300_gs_results = svm_tune_hyperparam(param_grid_8, 3300, pca_75_data_by_redshift, 75)
pca_70_3300_gs_results = svm_tune_hyperparam(param_grid_8, 3300, pca_70_data_by_redshift, 70)


In [None]:
# Redraw the correlation heatmaps.
# NOTE(review): factors_heatmap reads consolidated_gs_results.xlsx, which in the
# visible cells is written only by the consolidation cell in section 5. Unless
# that cell is re-run after the grid searches above, this plots the same data
# as the previous heatmap call -- confirm the intended execution order.
factors_heatmap()

In [None]:
# 6. Hyperparameter Tuning again

In [None]:
# Grid 9: polynomial kernel only, same C values as grids 7 and 8 with gamma
# pushed higher still (2 to 4).
param_grid_9 = {
    'C': [0.005, 0.001, 0.01],  
    'gamma': [2, 3, 4],  
    'kernel': ['poly'] 
}

# One run per PCA dataset (85, 80, 75, 70 components); the second argument
# (3500) is presumably the data size -- confirm against svm_tune_hyperparam.
pca_85_3500_gs_results = svm_tune_hyperparam(param_grid_9, 3500, pca_85_data_by_redshift, 85)
pca_80_3500_gs_results = svm_tune_hyperparam(param_grid_9, 3500, pca_80_data_by_redshift, 80)
pca_75_3500_gs_results = svm_tune_hyperparam(param_grid_9, 3500, pca_75_data_by_redshift, 75)
pca_70_3500_gs_results = svm_tune_hyperparam(param_grid_9, 3500, pca_70_data_by_redshift, 70)


In [None]:
# Grid 10: polynomial kernel only, back to small gamma values (0.01 to 0.03)
# with the same C values as grids 7-9.
# NOTE(review): the name param_grid_10 is assigned again, with different
# values, in the section 8 cell further down; whichever cell ran last wins.
param_grid_10 = {
    'C': [0.005, 0.001, 0.01],  
    'gamma': [0.01, 0.02, 0.03],  
    'kernel': ['poly'] 
}

# One run per PCA dataset (85, 80, 75, 70 components); the second argument
# (3700) is presumably the data size -- confirm against svm_tune_hyperparam.
pca_85_3700_gs_results = svm_tune_hyperparam(param_grid_10, 3700, pca_85_data_by_redshift, 85)
pca_80_3700_gs_results = svm_tune_hyperparam(param_grid_10, 3700, pca_80_data_by_redshift, 80)
pca_75_3700_gs_results = svm_tune_hyperparam(param_grid_10, 3700, pca_75_data_by_redshift, 75)
pca_70_3700_gs_results = svm_tune_hyperparam(param_grid_10, 3700, pca_70_data_by_redshift, 70)


In [None]:
def performance_vs_overfit():
    """Plot overfit against mean test score for every sheet of the consolidated workbook.

    Points are coloured by data_size / n_comp. A sheet is skipped, with a
    printed message, when any of the columns 'overfit', 'mean_test_score',
    'data_size' or 'n_comp' is absent. One figure is shown per plotted sheet.
    """
    workbook_path = "consolidated_gs_results.xlsx"
    needed_columns = {'overfit', 'mean_test_score', 'data_size', 'n_comp'}

    for sheet_name in pd.ExcelFile(workbook_path).sheet_names:
        results = pd.read_excel(workbook_path, sheet_name=sheet_name)

        # Guard clause: nothing to plot without all four columns.
        if not needed_columns.issubset(results.columns):
            print(f"Skipping sheet {sheet_name}: required columns not found.")
            continue

        # Colour value for each grid-search row.
        results = results.assign(
            size_to_comp_ratio=results['data_size'] / results['n_comp']
        )

        plt.figure(figsize=(8, 5))
        point_style = {
            'c': results['size_to_comp_ratio'],
            'cmap': 'viridis',
            'alpha': 0.7,
            'label': 'Overfit vs Test Score',
        }
        points = plt.scatter(results['mean_test_score'], results['overfit'], **point_style)
        plt.axhline(0, color='red', linestyle='--', linewidth=1, label='No Overfit Line')

        # The colour bar explains the ratio encoded by the point colours.
        plt.colorbar(points).set_label('Data Size / n_feature Ratio')

        plt.title(f"Overfit vs Mean Test Score for {sheet_name}")
        plt.xlabel("Mean Test Score")
        plt.ylabel("Overfit (Train - Test Score)")
        plt.legend()
        plt.grid(alpha=0.5)

        plt.show()

# Draw the ratio-coloured overfit scatter plot for every sheet of
# consolidated_gs_results.xlsx.
performance_vs_overfit()

In [None]:
# Redraw the correlation heatmaps.
# NOTE(review): this reads consolidated_gs_results.xlsx as it currently exists
# on disk; no visible cell in this section regenerates it, so re-run the
# consolidation cell first if the newer grid-search runs should be included.
factors_heatmap()

In [None]:
# 7. Retrain with the best hyperparameters on the full dataset
# This part is skipped because of limited computational resources.

In [None]:
def svm_train_with_best_hyperparam(grid_search_results, full_data, test_size=0.2, random_state=42):
    """Refit one SVC per redshift with the best grid-search parameters.

    The original body referenced best_params, X_train, y_train, X_test and
    y_test without defining them, never used full_data, and wrote every model
    to the same file name. This version builds the train/test split from
    full_data and saves one model file per redshift.

    Parameters
    ----------
    grid_search_results : dict
        Maps redshift -> fitted grid-search object exposing best_params_,
        best_score_ and cv_results_. Each value is replaced in place by its
        cv_results_ once the model for that redshift has been trained.
    full_data : dict
        Maps redshift -> {physics label -> 2-D array of samples}.
        NOTE(review): assumed to share its keys with grid_search_results and
        to follow the data_by_redshift layout -- confirm against the caller.
    test_size : float, optional
        Fraction of the samples held out for the reported test accuracy.
    random_state : int, optional
        Seed for the stratified train/test split.

    Returns
    -------
    dict
        Maps redshift -> fitted SVC.

    Raises
    ------
    KeyError
        If a redshift in grid_search_results is missing from full_data.
    """
    # Function-scope imports: the original code used these names, but the
    # notebook's import cell does not provide them. joblib is installed as a
    # dependency of scikit-learn.
    import logging
    import joblib

    trained_models = {}

    # Iterate over a snapshot because the values are reassigned in the loop.
    for redshift, grid_search in list(grid_search_results.items()):
        best_params = grid_search.best_params_
        print(f"Best parameters for redshift {redshift}: {best_params}")
        print(f"Best cross-validation accuracy: {grid_search.best_score_:.4f}")

        # Stack all physics classes into one feature matrix, labelled by
        # physics key. Empty entries are ignored.
        physics_dict = {
            physics: np.asarray(samples)
            for physics, samples in full_data[redshift].items()
            if len(samples) > 0
        }
        X = np.vstack(list(physics_dict.values()))
        y = np.concatenate([
            np.full(len(samples), physics) for physics, samples in physics_dict.items()
        ])
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, stratify=y, random_state=random_state
        )

        best_svm = SVC(**best_params)
        best_svm.fit(X_train, y_train)

        # Include the redshift so the models do not overwrite each other.
        model_filename = f"svm_pca_{redshift}.pkl"
        joblib.dump(best_svm, model_filename)
        logging.info(f"Model for redshift {redshift} saved to {model_filename}")

        train_accuracy = accuracy_score(y_train, best_svm.predict(X_train))
        test_accuracy = accuracy_score(y_test, best_svm.predict(X_test))
        print(f"Train accuracy for redshift {redshift}: {train_accuracy:.4f}")
        print(f"Test accuracy for redshift {redshift}: {test_accuracy:.4f}")

        grid_search_results[redshift] = grid_search.cv_results_
        trained_models[redshift] = best_svm

    return trained_models



In [None]:
# 8. Training an SVM on extracted features (6 features)
# This takes too long to run, even with a very small sample of the data,
# so it is skipped.

In [None]:
def sample_local_minima(data_by_redshift, num_samples=1000, sample_size=100, seed=None):
    """Build six summary-statistic features from random groups of spectra.

    For every (redshift, physics) entry, the first value returned by
    local_minima is collected for each spectrum. Then, num_samples times,
    sample_size distinct spectra are drawn without replacement, their arrays
    are concatenated, and the result is reduced to
    [mean, min, max, std, median, count], with the first five rounded to four
    decimals.

    NOTE(review): the statistics are taken over the first return value of
    local_minima (named `indices`); the second return value is discarded.
    Confirm that positions, rather than flux values, are the intended feature.

    Parameters
    ----------
    data_by_redshift : dict
        Maps redshift -> {physics -> iterable of spectra}.
    num_samples : int, optional
        Number of feature rows generated per (redshift, physics) entry.
    sample_size : int, optional
        Number of spectra pooled into each feature row.
    seed : int, optional
        If given, draws come from a private random.Random(seed), so the
        result is reproducible. If None, the global `random` state is used,
        as before.

    Returns
    -------
    dict
        Same keys as data_by_redshift; each leaf is an array with one row of
        six statistics per sample.

    Raises
    ------
    ValueError
        If sample_size exceeds the number of spectra in an entry.
    """
    rng = random if seed is None else random.Random(seed)

    new_dict = {}
    for redshift, physics_dict in data_by_redshift.items():
        new_dict[redshift] = {}
        for physics, flux_data in physics_dict.items():
            # Compute the minima once per spectrum; the sampling below reuses them.
            all_local_minima = []
            for spectrum in flux_data:
                indices, _ = local_minima(spectrum)
                all_local_minima.append(indices)

            n_spectra = len(all_local_minima)
            if sample_size > n_spectra:
                raise ValueError(
                    f"sample_size={sample_size} exceeds the {n_spectra} spectra available "
                    f"for redshift {redshift}, physics {physics}"
                )

            new_dataset = []
            for _ in range(num_samples):
                sampled_indices = rng.sample(range(n_spectra), sample_size)
                concatenated = np.concatenate([all_local_minima[i] for i in sampled_indices])
                new_dataset.append([
                    round(np.mean(concatenated), 4),
                    round(np.min(concatenated), 4),
                    round(np.max(concatenated), 4),
                    round(np.std(concatenated), 4),
                    round(np.median(concatenated), 4),
                    len(concatenated),
                ])

            new_dict[redshift][physics] = np.array(new_dataset)

    return new_dict

# Build 2000 feature rows per (redshift, physics) entry, each pooling the
# local minima of 2000 randomly drawn spectra.
local_minima_stats_by_redshift = sample_local_minima(data_by_redshift, num_samples=2000, sample_size=2000)
# NOTE(review): this reassigns param_grid_10, which the section 6 cell defines
# with different values; a distinct name would avoid confusion on re-runs.
# NOTE(review): scikit-learn's SVC does not use `gamma` with the 'linear'
# kernel, so the linear fits are repeated once per gamma value.
param_grid_10 = {
    'C': [0.005, 0.001, 0.0005],  
    'gamma': [0.1, 0.5, 1],  
    'kernel': ['linear', 'rbf', 'poly'] 
}
# Below training code takes too long
# NOTE(review): the six features mix index statistics with a raw count and may
# sit on very different scales. If svm_tune_hyperparam does not standardise its
# inputs, that is a plausible cause of the slow SVC training -- confirm.
sample_local_minima_100_gs_results = svm_tune_hyperparam(param_grid_10, 100, local_minima_stats_by_redshift, 6)
