In [1]:
import numpy as np
import cupy as cp
import cupyx.scipy.signal as signal
import scipy
import pandas as pd

# import griddata
from scipy.interpolate import griddata
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objs as go
from plotly.subplots import make_subplots
import matplotlib.cm as cm

from astropy.modeling import models, fitting
import astropy.units as u
from astropy import constants as const
from astropy.stats import gaussian_sigma_to_fwhm,gaussian_fwhm_to_sigma

import datetime
import pickle
import importlib
import json
import time
import glob

import do_wavelet
# import the module to reimport the module
import importlib
from interpolate_to_uniform_grid import interpolate_to_uniform_grid

# Note: the wavelet code always gives a warning about matplotlib tight_layout




# Load the band table and keep the rows flagged as identified on the MEGS-A channel.
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
with open('data/selected_band.pkl', 'rb') as band_file:
    selected_band = pickle.load(band_file)

# Build both row masks on the freshly loaded table and apply them in one step
# (assumes selected_band is a pandas DataFrame -- TODO confirm).
is_identified = selected_band['Identified'] == True  # noqa: E712 (element-wise comparison)
is_megs_a = selected_band['Channel'] == 'MEGS-A'
selected_band = selected_band[is_identified & is_megs_a]
selected_band  # bare expression kept from the original; it is not the cell's last statement

# Pull the 'wavelength_full' array out of the .npz archive.
wavelength_archive = np.load('data/wavelength_full.npz')
wavelength_full = wavelength_archive['wavelength_full']


In [2]:
def _read_wavelet_pickle(pkl_path):
    """Unpickle and return the object stored at ``pkl_path``.

    NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
    """
    with open(pkl_path, 'rb') as handle:
        return pickle.load(handle)


# One table per wavelet statistic: amplitude, mean, and stddev.
wavelet_amplitude_df = _read_wavelet_pickle('data/wavelet_df/wavelet_amplitude_df.pkl')
wavelet_mean_df = _read_wavelet_pickle('data/wavelet_df/wavelet_mean_df.pkl')
wavelet_stddev_df = _read_wavelet_pickle('data/wavelet_df/wavelet_stddev_df.pkl')

## by_wavelength

In [3]:
# Scatter of the amplitude wavelet peak periods against wavelength, with the
# markers coloured by the 'temperature' column (labelled log(T) on the colour bar).
# Colorscale reference: https://plotly.com/python/colorscales/
# Other suggested scales: Viridis, inferno, magma, plasma, cividis, spectral
color_scale = 'thermal'

# Work on a copy so the loaded table is left untouched.
df = wavelet_amplitude_df.copy()

fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=df['amplitude periods'],
        y=df['wavelength'],
        mode='markers',
        marker=dict(
            color=df['temperature'],  # marker colour follows the 'temperature' column
            colorscale=color_scale,
            colorbar=dict(title=r'log(T) of the spectrum'),
            showscale=True,  # draw the colour bar
        ),
        showlegend=False,  # otherwise a legend will appear on the right
        text=df['line name'],  # hover text shows the 'line name' column
    )
)

# Centred title, axis labels, and a fixed 0-300 x-range.
fig.update_layout(title_text="Amplitude Wavelet Peaks", title_x=0.5)
fig.update_xaxes(title_text="Amplitude Periods (d)", range=[0, 300])
fig.update_yaxes(title_text="Wavelength (Angstrom)")

fig.show()
# Forward slashes avoid the invalid '\w' / '\p' escape sequences of the old
# backslash literal and match the 'data/...' paths used when loading.
fig.write_html('output/wavelet_explore/peak_scatter/by_wavelength/amplitude.html')


In [4]:

# Scatter of the mean wavelet peak periods against wavelength, with the
# markers coloured by the 'temperature' column (labelled log(T) on the colour bar).
# Other suggested scales: Viridis, inferno, magma, plasma, cividis, spectral
color_scale = 'thermal'

# Work on a copy so the loaded table is left untouched.
df = wavelet_mean_df.copy()

fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=df['mean periods'],
        y=df['wavelength'],
        mode='markers',
        marker=dict(
            color=df['temperature'],  # marker colour follows the 'temperature' column
            colorscale=color_scale,
            colorbar=dict(title=r'log(T) of the spectrum'),
            showscale=True,  # draw the colour bar
        ),
        showlegend=False,  # otherwise a legend will appear on the right
        text=df['line name'],  # hover text shows the 'line name' column
    )
)

# Centred title, axis labels, and a fixed 0-300 x-range.
fig.update_layout(title_text="Mean Wavelet Peaks", title_x=0.5)
fig.update_xaxes(title_text="Mean Periods (d)", range=[0, 300])
fig.update_yaxes(title_text="Wavelength (Angstrom)")

fig.show()
# Forward slashes avoid the invalid '\w' / '\p' escape sequences of the old
# backslash literal and match the 'data/...' paths used when loading.
fig.write_html('output/wavelet_explore/peak_scatter/by_wavelength/mean.html')


In [5]:

# Scatter of the stddev wavelet peak periods against wavelength, with the
# markers coloured by the 'temperature' column (labelled log(T) on the colour bar).
# Other suggested scales: Viridis, inferno, magma, plasma, cividis, spectral
color_scale = 'thermal'

# Work on a copy so the loaded table is left untouched.
df = wavelet_stddev_df.copy()

fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=df['stddev periods'],
        y=df['wavelength'],
        mode='markers',
        marker=dict(
            color=df['temperature'],  # marker colour follows the 'temperature' column
            colorscale=color_scale,
            colorbar=dict(title=r'log(T) of the spectrum'),
            showscale=True,  # draw the colour bar
        ),
        showlegend=False,  # otherwise a legend will appear on the right
        text=df['line name'],  # hover text shows the 'line name' column
    )
)

# Centred title, axis labels, and a fixed 0-300 x-range.
fig.update_layout(title_text="Stddev Wavelet Peaks", title_x=0.5)
fig.update_xaxes(title_text="Stddev Periods (d)", range=[0, 300])
fig.update_yaxes(title_text="Wavelength (Angstrom)")

fig.show()
# Forward slashes avoid the invalid '\w' / '\p' escape sequences of the old
# backslash literal and match the 'data/...' paths used when loading.
fig.write_html('output/wavelet_explore/peak_scatter/by_wavelength/stddev.html')


## by_temperature

In [6]:
# Scatter of the amplitude wavelet peak periods against the 'temperature'
# column (axis labelled log(T)); the same column also drives the marker colour.
# Colorscale reference: https://plotly.com/python/colorscales/
# Other suggested scales: Viridis, inferno, magma, plasma, cividis, spectral
color_scale = 'thermal'

# Work on a copy so the loaded table is left untouched.
df = wavelet_amplitude_df.copy()

fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=df['amplitude periods'],
        y=df['temperature'],
        mode='markers',
        marker=dict(
            color=df['temperature'],  # marker colour follows the 'temperature' column
            colorscale=color_scale,
            colorbar=dict(title=r'log(T) of the spectrum'),
            showscale=True,  # draw the colour bar
        ),
        showlegend=False,  # otherwise a legend will appear on the right
        text=df['line name'],  # hover text shows the 'line name' column
    )
)

# Centred title, axis labels, and a fixed 0-300 x-range.
fig.update_layout(title_text="Amplitude Wavelet Peaks", title_x=0.5)
fig.update_xaxes(title_text="Amplitude Periods (d)", range=[0, 300])
fig.update_yaxes(title_text="log(T)")

fig.show()
# Forward slashes avoid the invalid '\w' / '\p' escape sequences of the old
# backslash literal and match the 'data/...' paths used when loading.
fig.write_html('output/wavelet_explore/peak_scatter/by_temperature/amplitude.html')


In [7]:
# Scatter of the mean wavelet peak periods against the 'temperature'
# column (axis labelled log(T)); the same column also drives the marker colour.
# Colorscale reference: https://plotly.com/python/colorscales/
# Other suggested scales: Viridis, inferno, magma, plasma, cividis, spectral
color_scale = 'thermal'

# Work on a copy so the loaded table is left untouched.
df = wavelet_mean_df.copy()

fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=df['mean periods'],
        y=df['temperature'],
        mode='markers',
        marker=dict(
            color=df['temperature'],  # marker colour follows the 'temperature' column
            colorscale=color_scale,
            colorbar=dict(title=r'log(T) of the spectrum'),
            showscale=True,  # draw the colour bar
        ),
        showlegend=False,  # otherwise a legend will appear on the right
        text=df['line name'],  # hover text shows the 'line name' column
    )
)

# Centred title, axis labels, and a fixed 0-300 x-range.
fig.update_layout(title_text="Mean Wavelet Peaks", title_x=0.5)
fig.update_xaxes(title_text="Mean Periods (d)", range=[0, 300])
fig.update_yaxes(title_text="log(T)")

fig.show()
# Forward slashes avoid the invalid '\w' / '\p' escape sequences of the old
# backslash literal and match the 'data/...' paths used when loading.
fig.write_html('output/wavelet_explore/peak_scatter/by_temperature/mean.html')


In [8]:
# Scatter of the stddev wavelet peak periods against the 'temperature'
# column (axis labelled log(T)); the same column also drives the marker colour.
# Colorscale reference: https://plotly.com/python/colorscales/
# Other suggested scales: Viridis, inferno, magma, plasma, cividis, spectral
color_scale = 'thermal'

# Work on a copy so the loaded table is left untouched.
df = wavelet_stddev_df.copy()

fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=df['stddev periods'],
        y=df['temperature'],
        mode='markers',
        marker=dict(
            color=df['temperature'],  # marker colour follows the 'temperature' column
            colorscale=color_scale,
            colorbar=dict(title=r'log(T) of the spectrum'),
            showscale=True,  # draw the colour bar
        ),
        showlegend=False,  # otherwise a legend will appear on the right
        text=df['line name'],  # hover text shows the 'line name' column
    )
)

# Centred title, axis labels, and a fixed 0-300 x-range.
fig.update_layout(title_text="Stddev Wavelet Peaks", title_x=0.5)
fig.update_xaxes(title_text="Stddev Periods (d)", range=[0, 300])
fig.update_yaxes(title_text="log(T)")

fig.show()
# Forward slashes avoid the invalid '\w' / '\p' escape sequences of the old
# backslash literal and match the 'data/...' paths used when loading.
fig.write_html('output/wavelet_explore/peak_scatter/by_temperature/stddev.html')


## combined

In [9]:
# Reload the three wavelet peak tables from disk for the combined figures.
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
wavelet_pickle_paths = (
    'data/wavelet_df/wavelet_amplitude_df.pkl',
    'data/wavelet_df/wavelet_mean_df.pkl',
    'data/wavelet_df/wavelet_stddev_df.pkl',
)

loaded_tables = []
for pickle_path in wavelet_pickle_paths:
    with open(pickle_path, 'rb') as handle:
        loaded_tables.append(pickle.load(handle))

# Unpack in the same order as the path tuple above.
wavelet_amplitude_df, wavelet_mean_df, wavelet_stddev_df = loaded_tables

In [10]:
# 3x2 grid of wavelet-peak scatters: one row per statistic (amplitude, mean,
# stddev). The left column plots the 'wavelength' column and the right column
# plots the 'temperature' column (labelled log(T)). Every marker is coloured
# by 'temperature'.
fig = make_subplots(rows=3, cols=2)
color_scale = 'thermal'

# (trace name, source table, period column, x-axis title) -- one entry per row.
row_specs = [
    ('Amplitude', wavelet_amplitude_df, 'amplitude periods', "Amplitude Periods (d)"),
    ('Mean', wavelet_mean_df, 'mean periods', "mean Periods (d)"),
    ('Stddev', wavelet_stddev_df, 'stddev periods', "stddev Periods (d)"),
]
# (y column, y-axis title) -- one entry per column.
col_specs = [
    ('wavelength', "Wavelength (Angstrom)"),
    ('temperature', "log(T)"),
]

for row, (trace_name, peaks_df, period_col, x_title) in enumerate(row_specs, start=1):
    for col, (y_col, y_title) in enumerate(col_specs, start=1):
        fig.add_trace(
            go.Scatter(
                x=peaks_df[period_col],
                y=peaks_df[y_col],
                mode='markers',
                marker=dict(
                    color=peaks_df['temperature'],
                    # All six panels share one colour axis, so they use the
                    # same colorscale and range and a single colour bar is
                    # drawn. Previously two panels fell back to the default
                    # colorscale and four colour bars were stacked on top of
                    # each other.
                    coloraxis='coloraxis',
                ),
                showlegend=False,  # otherwise a legend will appear on the right
                text=peaks_df['line name'],  # hover text shows the 'line name' column
                name=trace_name,
            ),
            row=row,
            col=col,
        )
        # Link every x-axis to the first one and pin the visible range.
        fig.update_xaxes(
            title_text=x_title, matches='x', range=[0, 300], row=row, col=col
        )
        fig.update_yaxes(title_text=y_title, row=row, col=col)

# Single layout call. The earlier 1200x1000 size was always overwritten by
# this 900x900 setting, so only the effective values are kept.
fig.update_layout(
    title_text="Wavelet Peak",
    title_x=0.5,
    autosize=False,
    width=900,
    height=900,
    coloraxis=dict(
        colorscale=color_scale,
        colorbar=dict(title=r'log(T)'),
    ),
)

# Forward slashes avoid the invalid '\w' / '\p' / '\c' escape sequences of the
# old backslash literal and match the 'data/...' paths used when loading.
fig.write_html('output/wavelet_explore/peak_scatter/combine/combined_wavelet.html')
fig.show()

In [11]:
# Three stacked scatters (amplitude, mean, stddev) of wavelet peak periods
# against the 'wavelength' column, with markers coloured by the 'temperature'
# column (labelled log(T) on the colour bar).
fig = make_subplots(rows=3, cols=1)
color_scale = 'thermal'

# (trace name, source table, period column, x-axis title) -- one entry per row.
row_specs = [
    ('Amplitude', wavelet_amplitude_df, 'amplitude periods', "Amplitude Periods (d)"),
    ('Mean', wavelet_mean_df, 'mean periods', "mean Periods (d)"),
    ('Stddev', wavelet_stddev_df, 'stddev periods', "stddev Periods (d)"),
]

# The grid has a single column. The old range loop also visited col=2, which
# does not exist in this figure.
col = 1
for row, (trace_name, peaks_df, period_col, x_title) in enumerate(row_specs, start=1):
    fig.add_trace(
        go.Scatter(
            x=peaks_df[period_col],
            y=peaks_df['wavelength'],
            mode='markers',
            marker=dict(
                color=peaks_df['temperature'],
                # One shared colour axis gives a single colour bar and a
                # common colour range, instead of three colour bars drawn at
                # the same position.
                coloraxis='coloraxis',
            ),
            showlegend=False,  # otherwise a legend will appear on the right
            text=peaks_df['line name'],  # hover text shows the 'line name' column
            name=trace_name,
        ),
        row=row,
        col=col,
    )
    # Link all panels to the first axes and pin the visible ranges.
    fig.update_xaxes(
        title_text=x_title, matches='x', range=[0, 300], row=row, col=col
    )
    fig.update_yaxes(
        title_text="Wavelength (Angstrom)", matches='y', range=[130, 380],
        row=row, col=col,
    )

fig.update_layout(
    autosize=False,
    width=800,
    height=1000,
    title_text="Wavelet Peak",
    title_x=0.5,
    coloraxis=dict(
        colorscale=color_scale,
        colorbar=dict(title=r'log(T)'),
    ),
)

# Forward slashes avoid the invalid '\w' / '\p' / '\c' escape sequences of the
# old backslash literal and match the 'data/...' paths used when loading.
fig.write_html('output/wavelet_explore/peak_scatter/combine/wavelength.html')
fig.show()

In [12]:
# Three stacked scatters (amplitude, mean, stddev) of wavelet peak periods
# against the 'temperature' column (axis labelled log(T)); the same column
# also drives the marker colour.
fig = make_subplots(rows=3, cols=1)
color_scale = 'thermal'

# (trace name, source table, period column, x-axis title) -- one entry per row.
row_specs = [
    ('Amplitude', wavelet_amplitude_df, 'amplitude periods', "Amplitude Periods (d)"),
    ('Mean', wavelet_mean_df, 'mean periods', "mean Periods (d)"),
    ('Stddev', wavelet_stddev_df, 'stddev periods', "stddev Periods (d)"),
]

col = 1  # the grid has a single column
for row, (trace_name, peaks_df, period_col, x_title) in enumerate(row_specs, start=1):
    fig.add_trace(
        go.Scatter(
            x=peaks_df[period_col],
            y=peaks_df['temperature'],
            mode='markers',
            marker=dict(
                color=peaks_df['temperature'],
                # One shared colour axis gives a single colour bar and a
                # common colour range, instead of three colour bars drawn at
                # the same position.
                coloraxis='coloraxis',
            ),
            showlegend=False,  # otherwise a legend will appear on the right
            text=peaks_df['line name'],  # hover text shows the 'line name' column
            name=trace_name,
        ),
        row=row,
        col=col,
    )
    # Link all panels to the first axes and pin the visible ranges.
    fig.update_xaxes(
        title_text=x_title, matches='x', range=[0, 300], row=row, col=col
    )
    fig.update_yaxes(
        title_text="log(T)", matches='y', range=[4.6, 7.3], row=row, col=col
    )

fig.update_layout(
    autosize=False,
    width=800,
    height=1000,
    title_text="Wavelet Peak",
    title_x=0.5,
    coloraxis=dict(
        colorscale=color_scale,
        colorbar=dict(title=r'log(T)'),
    ),
)

# Forward slashes avoid the invalid '\w' / '\p' / '\c' escape sequences of the
# old backslash literal and match the 'data/...' paths used when loading.
fig.write_html('output/wavelet_explore/peak_scatter/combine/temperature.html')
fig.show()