In [None]:
# Standard library imports
import warnings
import time
from pprint import pprint

# Third-party library imports
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import cupy as cp
import tqdm
from scipy.interpolate import interp1d
import torchaudio
import torchaudio.functional as AF
import torchaudio.transforms as AT

# Jupyter notebook specific
%matplotlib inline

# Suppress warnings
warnings.filterwarnings("ignore", "Wswiglal-redir-stdio")

# Local imports - core functionality
from coredldev.dataset import CoReDataset
from coredldev.utilites.pipeline import pipeline

# Local imports - finders and sources
from coredldev.finders.distance_scaling.h5 import h5Finder
from coredldev.sources.distance_scaling.h5 import h5Source

# Local imports - preprocessing steps
from coredldev.preprocessing.raw_postmerger.detector_angle_mixing import DetectorAngleMixing
from coredldev.preprocessing.raw_postmerger.fast_detector_angle_mixing import detector_angle_mixing as fastdam
from coredldev.preprocessing.raw_postmerger.distance_scale import distance_scale
from coredldev.preprocessing.raw_postmerger.time_shift import time_shift
from coredldev.preprocessing.to_tensor import to_tensor_clean
from coredldev.preprocessing.ligo_noise.inject_noise import NoiseInjection1D as noise_injection_1d
from coredldev.preprocessing.wavelet_transforms.morlet import MorletWaveletTransform
from coredldev.preprocessing.whiten import TimeSeriesWhitener
from pycbc.types import TimeSeries

# Import PyCBC libraries
import pycbc.noise
import pycbc.psd
import pycbc.filter

import scipy
from scipy.signal import welch

from copy import deepcopy as dp

# Load frequency values for plotting.
# Despite the ".npy" extension, this file is parsed as plain text (np.genfromtxt),
# consistent with the np.loadtxt("freqs.npy") reads of the same file in later cells.
freqs = np.genfromtxt("freqs.npy")


In [None]:
# Discover a single configuration: shift percent 0, angles (0, 0, 0), distance 20.
datapoints, eosmap, remaining = h5Finder(
    shiftpercents=[0],
    angles=[(0, 0, 0)],
    distances=[20],
).get_datapoints()
source = h5Source(eos_to_index_map=eosmap)

# Raw samples (identity transform), used for step-by-step inspection below.
dataset = CoReDataset(source, datapoints, lambda x: x)

# The same samples pushed through the full preprocessing pipeline.
# NOTE(review): this pipeline uses noise_injection_1d() with its default PSD, while
# the step-by-step cell uses psd_file="CE40-asd.txt" -- confirm which is intended.
transformed_dataset = CoReDataset(
    source,
    datapoints,
    pipeline([
        DetectorAngleMixing(),
        distance_scale(),
        time_shift(),
        noise_injection_1d(),
        TimeSeriesWhitener(4, 2),
        to_tensor_clean(),
    ]),
)
morl = MorletWaveletTransform(freqs=np.loadtxt("freqs.npy"))

In [None]:
# Instantiate each preprocessing step on its own so it can be applied and
# inspected individually (see the step-by-step cell that follows).
fdam = DetectorAngleMixing()                        # detector / angle mixing
ds = distance_scale()                               # distance scaling
ts = time_shift()                                   # time shift
ni = noise_injection_1d(psd_file="CE40-asd.txt")    # noise injection driven by CE40-asd.txt
ttc = to_tensor_clean()                             # conversion to tensor
morl = MorletWaveletTransform(freqs=np.loadtxt("freqs.npy"))
tswhitener = TimeSeriesWhitener(4, 2)               # NOTE(review): meaning of (4, 2) not visible here

In [None]:
# Push one raw sample through each preprocessing stage by hand, keeping every
# intermediate (data1..data5 are used by later cells). Each stage receives a deep
# copy, so earlier intermediates are not mutated by later stages.
# The f-strings use single outer quotes: nesting "..." inside f"..." is a
# SyntaxError before Python 3.12. The printed text is unchanged.
data = dataset[400]
print(f'0 {data["hplus"].shape = }')
data1 = fdam(dp(data))
print(f'1 {data1["signal"].shape = }')
data2 = ds(dp(data1))
print(f'2 {data2["signal"].shape = }')
data3 = ts(dp(data2))
print(f'3 {data3["signal"].shape = }')
data4 = ni(dp(data3))
print(f'4 {data4["signal"].shape = }')
data5 = tswhitener(dp(data4))
print(f'5 {data5["signal"].shape = }')
print()

In [None]:
# Manual frequency-domain whitening of the signal after noise injection (data4),
# as a cross-check of the pipeline's whitener.
signal = data4["signal"]
sam_p = data4["params"]["sam_p"]  # sampling period (fs = 1 / sam_p)

# scipy.signal.welch returns a power spectral density with its default
# scaling="density". The name `asd` is kept because later cells use it;
# np.sqrt(asd) is the amplitude spectral density.
frequencies, asd = welch(signal, fs=1/sam_p, nperseg=20000)
print(frequencies.shape, asd.shape)

fft = np.fft.rfft(signal)
fft_freqs = np.fft.rfftfreq(len(signal), d=sam_p)
print(fft.shape, fft_freqs.shape)

# Resample the (coarser) Welch estimate onto the FFT frequency grid.
asd_interpolated = np.interp(fft_freqs, frequencies, asd)
print(asd_interpolated.shape)

# Whiten by dividing by the ASD.
# NOTE(review): the extra scalar divisor (2*pi*20000/sam_p + 1) only rescales the
# result and its origin is not explained here -- confirm it is intended. Also check
# the lowest bins: welch detrends each segment by default, so the PSD there may be
# tiny and the division can strongly amplify them.
whitened_freqs = fft / np.sqrt(asd_interpolated) / (2*np.pi*(20_000)/sam_p + 1)
print(whitened_freqs.shape)

# Pass n explicitly: without it irfft returns an even length, so an odd-length
# signal would come back one sample short.
whitened_ts = np.fft.irfft(whitened_freqs, n=len(signal))
assert len(whitened_ts) == len(signal)
print(whitened_ts.shape)

plt.plot(whitened_ts)
plt.title("Manually whitened time series")
plt.xlabel("Sample index")
plt.ylabel("Whitened amplitude")
plt.show()

In [None]:
# Time-domain view of the signal after noise injection.
plt.figure(figsize=(8, 6))
plt.plot(signal)
plt.title("Signal after noise injection")
plt.xlabel("Sample index")
plt.ylabel("Amplitude")
plt.show()

# NOTE: analyze_signal_frequencies is defined in a later cell. Run that cell first;
# on a fresh Restart & Run All this line raises NameError.
analyze_signal_frequencies(signal, sam_p, " ", True)
plt.show()

# Welch estimate: `asd` holds the PSD, so its square root is the ASD.
plt.figure(figsize=(8, 6))
plt.loglog(frequencies, np.sqrt(asd))
plt.title("Amplitude Spectral Density (ASD)")
plt.xlabel("Frequency (Hz)")
plt.ylabel("ASD (strain/√Hz)")
plt.grid(True, which="both", ls="--", alpha=0.5)
plt.show()

# Magnitude of the manually whitened spectrum. It lives on the FFT grid
# (fft_freqs), not the coarser Welch grid (`frequencies`). It is complex, so
# plot its absolute value.
plt.figure(figsize=(8, 6))
plt.loglog(fft_freqs, np.abs(whitened_freqs))
plt.title("Whitened spectrum magnitude")
plt.xlabel("Frequency (Hz)")
plt.ylabel("Magnitude")
plt.grid(True, which="both", ls="--", alpha=0.5)
plt.show()

In [None]:
def _frequency_extent(freq_bins, fft_magnitude, rel_threshold):
    """Return (lowest non-DC, highest) frequency whose magnitude exceeds rel_threshold * peak."""
    if fft_magnitude.size == 0:
        return 0.0, 0.0
    significant = np.flatnonzero(fft_magnitude > rel_threshold * fft_magnitude.max())
    max_freq = freq_bins[significant[-1]] if significant.size else 0.0
    # Exclude the DC bin (index 0) when looking for the lowest frequency.
    above_dc = significant[significant > 0]
    min_nonzero_freq = freq_bins[above_dc[0]] if above_dc.size else 0.0
    return min_nonzero_freq, max_freq


def _plot_spectrum_panel(freq_bins, fft_magnitude, min_freq, max_freq, title, ylabel, log_scale):
    """Draw one spectrum panel on the current axes, with min/max frequency markers."""
    if log_scale:
        plt.loglog(freq_bins, fft_magnitude)
    else:
        plt.plot(freq_bins, fft_magnitude)
    plt.grid(True, alpha=0.3)
    plt.title(title)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel(ylabel)
    # A marker at 0 Hz cannot be placed on a logarithmic x-axis, so skip it there.
    if not log_scale or min_freq > 0:
        plt.axvline(x=min_freq, color='g', linestyle='--',
                    label=f'Min non-zero freq: {min_freq:.2f} Hz')
    if not log_scale or max_freq > 0:
        plt.axvline(x=max_freq, color='r', linestyle='--',
                    label=f'Max freq: {max_freq:.2f} Hz')
    plt.legend()


def analyze_signal_frequencies(signal, sampling_period, name, plot=False, noise_rel_threshold=1e-10):
    """
    Compute the one-sided FFT magnitude spectrum of a signal and report the lowest
    and highest frequencies that carry non-negligible magnitude.

    A bin counts as non-negligible when its magnitude exceeds
    ``noise_rel_threshold * max(|FFT|)``.

    Parameters:
    -----------
    signal : array-like
        Real-valued time-domain samples (1-D).
    sampling_period : float
        Time between samples in seconds (1 / sampling frequency).
    name : str
        Label used in the printed summary and the plot title.
    plot : bool, optional
        If True, draw a two-panel figure (linear spectrum, then log-log spectrum)
        on a new figure. The figure is not shown; call ``plt.show()`` afterwards.
    noise_rel_threshold : float, optional
        Relative magnitude below which bins are treated as numerical noise.

    Returns:
    --------
    tuple
        (min_nonzero_freq, max_freq, freq_bins, fft_magnitude)
        - min_nonzero_freq: lowest non-DC frequency above the noise threshold (0.0 if none)
        - max_freq: highest frequency above the noise threshold (0.0 if none)
        - freq_bins: frequency of each rfft bin
        - fft_magnitude: |rfft(signal)| for each bin
    """
    fft_magnitude = np.abs(np.fft.rfft(signal))
    freq_bins = np.fft.rfftfreq(len(signal), d=sampling_period)

    min_nonzero_freq, max_freq = _frequency_extent(freq_bins, fft_magnitude, noise_rel_threshold)

    if plot:
        plt.figure(figsize=(8, 6))
        plt.subplot(211)
        _plot_spectrum_panel(freq_bins, fft_magnitude, min_nonzero_freq, max_freq,
                             f'Frequency Spectrum: {name}', 'Magnitude', log_scale=False)
        plt.subplot(212)
        _plot_spectrum_panel(freq_bins, fft_magnitude, min_nonzero_freq, max_freq,
                             'Frequency Spectrum (log scale)', 'Magnitude (log scale)', log_scale=True)
        plt.tight_layout()

    print(f"Signal name: {name}")
    print(f"Minimum non-zero frequency: {min_nonzero_freq:.2f} Hz")
    print(f"Maximum frequency: {max_freq:.2f} Hz")
    print()

    return min_nonzero_freq, max_freq, freq_bins, fft_magnitude

In [None]:
def post_whitening_frequency_analysis(reference_freqs=None, max_samples=None):
    """Summarise the frequency extent of each sample in ``transformed_dataset``.

    For each transformed sample, run ``analyze_signal_frequencies`` (no per-sample
    plot), using the sampling period of the raw sample with the same index in
    ``dataset``. Then plot:
    histograms of the per-sample minimum and maximum frequencies, and a histogram of
    all FFT bin frequencies pooled together.

    Parameters
    ----------
    reference_freqs : array-like, optional
        Frequencies to overlay (in red) on the pooled-bin histogram, e.g. the
        frequency grid ``freqs``. Skipped when None.
    max_samples : int, optional
        Analyse at most this many samples; None (default) means the whole dataset.
    """
    import itertools

    mins, maxs, bin_list = [], [], []
    for n, (data, metadata) in enumerate(itertools.islice(transformed_dataset, max_samples)):
        # The sampling period is read from the raw sample with the same index.
        sam_p = dataset[n]["params"]["sam_p"]
        minimum, maximum, bins, _ = analyze_signal_frequencies(data.numpy(), sam_p, f"{n} | {metadata}")
        mins.append(minimum)
        maxs.append(maximum)
        bin_list.append(bins)

    if not mins:
        # min()/max()/np.concatenate below would raise on empty input.
        print("No samples to analyse.")
        return

    print(min(mins), max(maxs))

    for values, title in ((mins, "Minimum Frequencies"), (maxs, "Maximum Frequencies")):
        plt.hist(values, bins=100)
        plt.title(title)
        plt.xlabel("Frequency (Hz)")
        plt.ylabel("Counts")
        plt.grid()
        plt.show()

    # Pool the FFT bin frequencies of every sample into one histogram.
    all_bins = np.concatenate(bin_list)
    plt.hist(all_bins, bins=1000, alpha=0.7, color='blue', label="FFT bin frequencies")
    if reference_freqs is not None:
        plt.hist(np.asarray(reference_freqs).ravel(), bins=1000, alpha=0.7, color='red',
                 label="Reference frequencies")
    plt.title("Distribution of FFT Bin Frequencies")
    plt.xlabel("Frequency (Hz)")
    plt.ylabel("Counts")
    plt.legend()
    plt.grid(alpha=0.3)
    plt.show()
# Pass reference_freqs=freqs to overlay the frequency grid loaded at the top.
post_whitening_frequency_analysis()

In [None]:
# Collect timing statistics for each preprocessing step
def _time_pipeline(stages, num_samples):
    """Time loading plus each stage for the first `num_samples` items of the global `dataset`.

    Returns a dict of per-sample timing lists with keys "Initial data loading",
    then one key per stage label, then "Total" (the per-sample sum of the steps).
    """
    timing_data = {"Initial data loading": []}
    timing_data.update((label, []) for label, _ in stages)
    timing_data["Total"] = []

    for i in tqdm.tqdm(range(num_samples), desc="Benchmarking"):
        # perf_counter is intended for interval timing; time.time is wall-clock.
        start = time.perf_counter()
        sample = dataset[i]
        elapsed = time.perf_counter() - start
        timing_data["Initial data loading"].append(elapsed)
        total = elapsed

        for label, transform in stages:
            start = time.perf_counter()
            sample = transform(sample)
            elapsed = time.perf_counter() - start
            timing_data[label].append(elapsed)
            total += elapsed

        timing_data["Total"].append(total)
    return timing_data


def _print_timing_stats(stats):
    """Print per-step mean/std and each step's share of the mean total time."""
    print("\nTiming Statistics (in seconds):")
    print("-" * 50)
    print(f"{'Step':<25} {'Mean':<15} {'Std Dev':<15}")
    print("-" * 50)
    for step, stat in stats.items():
        print(f"{step:<25} {stat['mean']:<15.6f} {stat['std']:<15.6f}")

    total_mean = stats["Total"]["mean"]
    print("\nPercentage Contribution to Total Processing Time:")
    print("-" * 50)
    for step, stat in stats.items():
        if step == "Total":
            continue
        percentage = (stat['mean'] / total_mean) * 100
        print(f"{step:<25} {percentage:>6.2f}%")


def _plot_timing_stats(stats):
    """Bar chart of mean ± std per step, plus a pie chart of each step's share of the total."""
    steps = list(stats)
    means = [stats[step]['mean'] for step in steps]
    stds = [stats[step]['std'] for step in steps]

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar(steps, means, yerr=stds, capsize=10)
    ax.set_xlabel('Processing Step')
    ax.set_ylabel('Time (seconds)')
    ax.set_title('Mean Processing Time per Step with Standard Deviation')
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    fig.tight_layout()
    plt.show()

    labels = [step for step in steps if step != "Total"]
    sizes = [stats[step]['mean'] for step in labels]
    total_mean = stats["Total"]["mean"]
    percentages = [(size / total_mean) * 100 for size in sizes]

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.pie(sizes, labels=[f"{label}\n({perc:.1f}%)" for label, perc in zip(labels, percentages)],
           autopct='%1.1f%%', startangle=140, shadow=True)
    ax.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle.
    ax.set_title('Percentage Contribution of Each Step to Total Processing Time')
    plt.show()


def benchmark(num_samples=10000):
    """Time data loading and each preprocessing step over the first `num_samples` samples.

    Prints mean/std per step and percentage contributions, and plots a bar chart and
    a pie chart.
    NOTE(review): the whitening step (TimeSeriesWhitener) used in
    `transformed_dataset` is not part of this benchmark.

    Parameters
    ----------
    num_samples : int
        Number of leading samples of `dataset` to process; must be >= 1 and no
        larger than the dataset.

    Returns
    -------
    dict
        {step: {"mean": float, "std": float}} in seconds.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be >= 1")

    stages = [
        ("Detector angle mixing", DetectorAngleMixing()),
        ("Distance scaling", distance_scale()),
        ("Time shifting", time_shift()),
        ("Noise injection", noise_injection_1d()),
        ("To tensor clean", to_tensor_clean()),
    ]
    timing_data = _time_pipeline(stages, num_samples)
    stats = {step: {"mean": np.mean(times), "std": np.std(times)}
             for step, times in timing_data.items()}
    _print_timing_stats(stats)
    _plot_timing_stats(stats)
    return stats
#benchmark()

In [None]:
def plot_steps(stages=None):
    """Compare the signal after each preprocessing stage.

    Draws two figures:
    1. All stages overlaid on a symlog y-axis, with an inset showing the final stage
       on a linear axis.
    2. One subplot per stage, plus an extra last subplot repeating the final stage
       on a symlog y-axis.

    Parameters
    ----------
    stages : list of (title, signal) pairs, optional
        Signals in processing order. Defaults to the "signal" entries of the
        globals data1..data5 produced by the step-by-step cell.
    """
    from mpl_toolkits.axes_grid1.inset_locator import inset_axes

    if stages is None:
        stages = [
            ("After Detector Angle Mixing", data1["signal"]),
            ("After Distance Scaling", data2["signal"]),
            ("After Time Shifting", data3["signal"]),
            ("After Noise Injection", data4["signal"]),
            ("After Whitening", data5["signal"]),
        ]
    if not stages:
        raise ValueError("stages must not be empty")
    colors = ["b", "g", "r", "c", "m"]
    final_title, final_signal = stages[-1]

    # Figure 1: every stage overlaid.
    fig, ax = plt.subplots(figsize=(12, 8))
    for i, (title, sig) in enumerate(stages, start=1):
        ax.plot(np.arange(len(sig)), sig, label=f"data{i} - {title}")
    # Linear region for |y| < 1e-25 keeps the small pre-whitening values visible.
    ax.set_yscale('symlog', linthresh=1e-25)
    ax.grid(True, which="both", ls="-", alpha=0.2)
    ax.legend(loc="best")
    ax.set_title("Signal Comparison Across Processing Steps")
    ax.set_xlabel("Sample Index")
    ax.set_ylabel("Signal Amplitude (symlog scale)")

    # Inset with the final (whitened) stage on a linear axis.
    axins = inset_axes(ax, width="40%", height="30%", loc="upper right")
    axins.plot(np.arange(len(final_signal)), final_signal, 'r-')
    axins.set_title("Zoomed view of whitened signal")
    axins.grid(True, alpha=0.2)

    fig.tight_layout()
    plt.show()

    # Figure 2: one panel per stage, plus a symlog panel for the final stage.
    n_stages = len(stages)
    fig, axes = plt.subplots(n_stages + 1, 1, figsize=(12, 15), sharex=True)
    for i, (panel, (title, sig)) in enumerate(zip(axes, stages)):
        panel.plot(np.arange(len(sig)), sig, colors[i % len(colors)] + '-')
        panel.set_title(title)
        panel.grid(True, which="both", ls="-", alpha=0.2)

    symlog_panel = axes[-1]
    symlog_panel.plot(np.arange(len(final_signal)), final_signal,
                      colors[(n_stages - 1) % len(colors)] + '-')
    symlog_panel.set_title(f"{final_title} (Symlog)")
    # Apply symlog to the panel titled "(Symlog)"; the previous panel stays linear.
    symlog_panel.set_yscale('symlog')
    symlog_panel.grid(True, which="both", ls="-", alpha=0.2)

    # Invisible full-figure axes used only to carry shared labels and a title.
    label_ax = fig.add_subplot(111, frameon=False)
    label_ax.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
    label_ax.set_xlabel("Sample Index")
    label_ax.set_ylabel("Signal Amplitude")
    label_ax.set_title("Signal Processing Steps Comparison", fontsize=16, pad=20)

    fig.tight_layout()
    plt.show()
plot_steps()