In [18]:
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import interpolate
from astropy.io import fits
from astropy.io import ascii as astropy_ascii
from astropy.table import Table
from tqdm.auto import tqdm

In [13]:
class SpectrumProcessor:
    def __init__(self, metadata_path, spectra_dir, output_dir=None):
        """
        Initialize the processor with paths to metadata and spectra.

        Parameters:
        -----------
        metadata_path : str
            Path to the metadata CSV file
        spectra_dir : str
            Directory containing the spectra files
        output_dir : str, optional
            Directory to save processed spectra
        """
        self.metadata_path = metadata_path
        self.spectra_dir = spectra_dir
        self.output_dir = output_dir or os.path.join(spectra_dir, "processed")

        # Create output directory if it doesn't exist
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Load metadata
        self.metadata = pd.read_csv(metadata_path)

        # Define standard wavelength grid for resampling
        # This range covers optical to near-IR (common for supernovae)
        self.wavelength_grid = np.linspace(3500, 10000, 1000)

    def read_spectrum(self, filepath):
        """
        Read a spectrum file in any format.

        Parameters:
        -----------
        filepath : str
            Path to the spectrum file

        Returns:
        --------
        wavelength : numpy.ndarray
            Wavelength array
        flux : numpy.ndarray
            Flux array
        """
        filename = os.path.basename(filepath)
        extension = os.path.splitext(filename)[1].lower()

        try:
            # ECSV format (DESI spectra)
            if extension == ".ecsv":
                data = Table.read(filepath, format="ascii.ecsv")
                wavelength = np.array(data["WAVE"])
                flux = np.array(data["FLUX"])

            # ASCII format for ePESSTO data
            elif extension == ".asci":
                data = np.loadtxt(filepath)
                wavelength = data[:, 0]
                flux = data[:, 1]

            # CAT format
            elif extension == ".cat":
                data = np.loadtxt(filepath)
                wavelength = data[:, 0]
                flux = data[:, 1]

            # DAT format (model spectra)
            elif extension == ".dat":
                data = np.loadtxt(filepath)
                wavelength = data[:, 0]
                flux = data[:, 1]

            # FITS format
            elif extension in [".fits", ".fit", ".fts"]:
                with fits.open(filepath) as hdul:
                    # Try to figure out the correct extension and column names
                    # This might need adjustment based on your specific FITS files
                    for hdu in hdul:
                        if isinstance(hdu, fits.BinTableHDU):
                            table = Table(hdu.data)
                            # Look for common column names for wavelength and flux
                            for wave_col in [
                                "WAVE",
                                "WAVELENGTH",
                                "LAMBDA",
                                "LOGLAM",
                                "CRVAL1",
                            ]:
                                if wave_col in table.colnames:
                                    wavelength = np.array(table[wave_col])
                                    break
                            for flux_col in ["FLUX", "FLAM", "SPEC", "DATA"]:
                                if flux_col in table.colnames:
                                    flux = np.array(table[flux_col])
                                    break
                            break

            # Cal_galsub format
            elif "_galsub" in filename or ".cal_galsub" in filename:
                data = np.loadtxt(filepath, skiprows=2)
                wavelength = data[:, 0]
                flux = data[:, 1]

            # Standard ASCII format with two columns (wavelength, flux)
            else:
                # Try different reading approaches
                try:
                    # Try standard loadtxt first
                    data = np.loadtxt(filepath)
                    wavelength = data[:, 0]
                    flux = data[:, 1]
                except:
                    # If that fails, try astropy.io.ascii which can handle more formats
                    data = astropy_ascii.read(filepath)
                    # Try to identify the wavelength and flux columns
                    if len(data.colnames) >= 2:
                        wavelength = np.array(data[data.colnames[0]])
                        flux = np.array(data[data.colnames[1]])
                    else:
                        raise ValueError(
                            f"Could not identify wavelength and flux columns in {filepath}"
                        )

            return wavelength, flux

        except Exception as e:
            print(f"Error reading {filepath}: {str(e)}")
            return None, None

    def preprocess_spectrum(self, wavelength, flux):
        """
        Preprocess a spectrum by:
        1. Removing NaN or infinite values
        2. Removing negative wavelengths
        3. Sorting by wavelength

        Parameters:
        -----------
        wavelength : numpy.ndarray
            Wavelength array
        flux : numpy.ndarray
            Flux array

        Returns:
        --------
        wavelength : numpy.ndarray
            Cleaned wavelength array
        flux : numpy.ndarray
            Cleaned flux array
        """
        if wavelength is None or flux is None:
            return None, None

        # Make sure arrays are the same length
        min_len = min(len(wavelength), len(flux))
        wavelength = wavelength[:min_len]
        flux = flux[:min_len]

        # Remove NaN and inf values
        mask = np.isfinite(wavelength) & np.isfinite(flux)
        wavelength = wavelength[mask]
        flux = flux[mask]

        # Remove negative wavelengths
        mask = wavelength > 0
        wavelength = wavelength[mask]
        flux = flux[mask]

        # Sort by wavelength
        sort_idx = np.argsort(wavelength)
        wavelength = wavelength[sort_idx]
        flux = flux[sort_idx]

        return wavelength, flux

    def normalize_spectrum(self, flux):
        """
        Normalize a spectrum by:
        1. Subtracting the mean
        2. Dividing by the standard deviation

        Parameters:
        -----------
        flux : numpy.ndarray
            Flux array

        Returns:
        --------
        norm_flux : numpy.ndarray
            Normalized flux array
        """
        if flux is None or len(flux) == 0:
            return None

        # Remove outliers (optional)
        # This can help with extreme values that might skew normalization
        q1, q3 = np.percentile(flux, [1, 99])
        iqr = q3 - q1
        mask = (flux >= q1 - 1.5 * iqr) & (flux <= q3 + 1.5 * iqr)

        if np.sum(mask) > len(flux) * 0.5:  # If we still have enough data
            mean = np.mean(flux[mask])
            std = np.std(flux[mask])
        else:
            mean = np.mean(flux)
            std = np.std(flux)

        # Avoid division by zero
        if std == 0:
            std = 1.0

        norm_flux = (flux - mean) / std

        return norm_flux

    def resample_spectrum(self, wavelength, flux):
        """
        Resample a spectrum to a standard wavelength grid using interpolation.

        Parameters:
        -----------
        wavelength : numpy.ndarray
            Original wavelength array
        flux : numpy.ndarray
            Original flux array

        Returns:
        --------
        resampled_flux : numpy.ndarray
            Resampled flux array on the standard wavelength grid
        """
        if wavelength is None or flux is None or len(wavelength) < 2:
            return None

        # Check wavelength coverage against the standard grid
        if (
            wavelength.min() > self.wavelength_grid.min()
            or wavelength.max() < self.wavelength_grid.max()
        ):
            # Spectrum doesn't cover the entire standard grid
            # We'll interpolate within range and set values outside to NaN
            valid_range = (self.wavelength_grid >= wavelength.min()) & (
                self.wavelength_grid <= wavelength.max()
            )

            # Create interpolation function for the valid range
            f = interpolate.interp1d(
                wavelength, flux, bounds_error=False, fill_value=np.nan
            )
            resampled_flux = np.full_like(self.wavelength_grid, np.nan)
            resampled_flux[valid_range] = f(self.wavelength_grid[valid_range])
        else:
            # Spectrum covers the entire standard grid
            f = interpolate.interp1d(
                wavelength, flux, bounds_error=False, fill_value=np.nan
            )
            resampled_flux = f(self.wavelength_grid)

        return resampled_flux

    def process_all_spectra(self):
        """
        Process all spectra in the directory and create a tensor for machine learning.

        Returns:
        --------
        spectra_tensor : numpy.ndarray
            3D tensor containing all processed spectra
            Shape: (n_spectra, n_wavelength_points, 1)
        metadata_subset : pandas.DataFrame
            Metadata for the successfully processed spectra
        """
        processed_spectra = []
        successful_indices = []

        for i, row in tqdm(
            self.metadata.iterrows(),
            total=len(self.metadata),
            desc="Processing spectra",
        ):
            try:
                # Get filename from metadata
                filename = row.get("filename", "")
                if not filename:
                    # If filename not in metadata, try to match based on other fields
                    # This depends on your metadata structure
                    continue

                filepath = os.path.join(self.spectra_dir, filename)
                if not os.path.exists(filepath):
                    # Try to find the file by fuzzy matching
                    possible_files = glob.glob(
                        os.path.join(
                            self.spectra_dir, f"*{os.path.splitext(filename)[0]}*"
                        )
                    )
                    if possible_files:
                        filepath = possible_files[0]
                    else:
                        print(f"Could not find file for {filename}")
                        continue

                # Read and process the spectrum
                wavelength, flux = self.read_spectrum(filepath)
                wavelength, flux = self.preprocess_spectrum(wavelength, flux)

                # Check if we have valid data
                if wavelength is None or flux is None or len(wavelength) < 10:
                    print(f"Insufficient data in {filename}, skipping")
                    continue

                # Normalize and resample
                norm_flux = self.normalize_spectrum(flux)
                resampled_flux = self.resample_spectrum(wavelength, norm_flux)

                # Check if resampling was successful
                if resampled_flux is None or np.all(np.isnan(resampled_flux)):
                    print(f"Resampling failed for {filename}, skipping")
                    continue

                # Fill NaN values with zeros or interpolate
                # This is a simple approach; you might want something more sophisticated
                mask = np.isnan(resampled_flux)
                if np.any(mask):
                    # If less than 30% NaN, interpolate
                    if np.sum(mask) < 0.3 * len(resampled_flux):
                        valid_indices = np.where(~mask)[0]
                        valid_values = resampled_flux[valid_indices]
                        nan_indices = np.where(mask)[0]

                        # Simple linear interpolation
                        interp_values = np.interp(
                            nan_indices,
                            valid_indices,
                            valid_values,
                            left=0,  # Or another strategy for extrapolation
                            right=0,
                        )
                        resampled_flux[nan_indices] = interp_values
                    else:
                        # Too many NaNs, skip this spectrum
                        print(f"Too many NaN values in {filename}, skipping")
                        continue

                # Reshape for the tensor
                reshaped_flux = resampled_flux.reshape(1, -1, 1)
                processed_spectra.append(reshaped_flux)
                successful_indices.append(i)

                # Optional: save the processed spectrum
                if self.output_dir:
                    output_file = os.path.join(
                        self.output_dir, f"processed_{os.path.basename(filename)}.npy"
                    )
                    np.save(output_file, reshaped_flux)

            except Exception as e:
                print(f"Error processing {row.get('filename', 'unknown')}: {str(e)}")
                continue

        # Combine all processed spectra into a tensor
        if not processed_spectra:
            print("No spectra were successfully processed")
            return None, None

        # Convert list of arrays to a single 3D tensor
        # Each spectrum is a 1 x n_wavelength_points x 1 array
        # We stack along the first dimension to get n_spectra x n_wavelength_points x 1
        spectra_tensor = np.vstack(processed_spectra)

        # Get the subset of metadata for successfully processed spectra
        metadata_subset = self.metadata.iloc[successful_indices].reset_index(drop=True)

        return spectra_tensor, metadata_subset

    def visualize_spectra(self, spectra_tensor, n_samples=5):
        """
        Visualize a random sample of processed spectra.

        Parameters:
        -----------
        spectra_tensor : numpy.ndarray
            3D tensor containing all processed spectra
        n_samples : int, optional
            Number of spectra to visualize
        """
        if spectra_tensor is None or spectra_tensor.shape[0] == 0:
            print("No spectra to visualize")
            return

        # Select random indices
        n_spectra = spectra_tensor.shape[0]
        indices = np.random.choice(n_spectra, min(n_samples, n_spectra), replace=False)

        # Create plot
        plt.figure(figsize=(12, 8))
        for i, idx in enumerate(indices):
            plt.subplot(n_samples, 1, i + 1)
            plt.plot(self.wavelength_grid, spectra_tensor[idx, :, 0])
            plt.title(f"Spectrum {idx}")
            plt.xlabel("Wavelength (Å)")
            plt.ylabel("Normalized Flux")

        plt.tight_layout()
        plt.savefig(os.path.join(self.output_dir, "sample_spectra.png"))
        plt.close()

In [14]:
metadata_path = "../1. download ALL wise data/wiserep_spectra_combined.csv"
spectra_dir = "../1. download ALL wise data/wiserep_data/spectra/"
output_dir = "output/"

In [15]:
# Create processor
processor = SpectrumProcessor(
    metadata_path=metadata_path,
    spectra_dir=spectra_dir,
    output_dir=output_dir,
)

  self.metadata = pd.read_csv(metadata_path)


In [16]:
print("Processing spectra...")

Processing spectra...


In [17]:
spectra_tensor, metadata_subset = processor.process_all_spectra()

Processing spectra:   0%|          | 0/54005 [00:00<?, ?it/s]

No spectra were successfully processed
