In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pywt
from matplotlib import cm
from watpy.coredb.coredb import *
import plotly.express as px
import plotly.graph_objects as go
# Database object; its `sim[...]` entries are read in the waveform-loading cell below.
# NOTE(review): r".\CoRe_DB" uses a Windows path separator, so this literal only
# resolves to a "CoRe_DB" folder next to the notebook on Windows -- confirm
# whether other platforms need to be supported.
cdb = CoRe_db(r".\CoRe_DB")
from CoRe_Dataloader_ECSG import get_dataset
ds = get_dataset()
# Wavelet scales consumed by wt(): np.arange(scale_min, scale_max, dscale).
scale_min = 1
scale_max = 201
dscale = 0.1
from scipy.signal import argrelextrema
import math

In [None]:
def planck_window(j:int,N:int):
    """Build a symmetric taper: smooth rise, flat plateau of ones, mirrored fall.

    The rising edge is ``1 / (1 + exp(j/x - j/(j - x)))`` evaluated on
    ``j - 1`` points spread linearly over ``[0, j - 1]``, with its first
    sample forced to exactly 0.

    Parameters
    ----------
    j : int
        Controls the taper length; each edge contains ``j - 1`` samples.
    N : int
        Nominal window length.

    Returns
    -------
    numpy.ndarray
        1-D array of length ``(j - 1) + (N - 2*j) + (j - 1) == N - 2``.
        Callers that need exactly ``L`` samples must therefore pass
        ``N = L + 2``.
    """
    ramp_x = np.linspace(0, j - 1, j - 1)
    # Placeholder so that j / x does not divide by zero; the resulting sample
    # is overwritten with 0 right below.
    ramp_x[0] = 1
    rising = 1. / (1. + np.exp(j / ramp_x - j / (j - ramp_x)))
    rising[0] = 0

    plateau = np.ones(N - (j * 2))
    falling = rising[::-1]
    return np.concatenate((rising, plateau, falling))

def cut_at_lowest_envelope(hplus, hcross):
    """Drop the leading part of ``hplus`` up to the first envelope minimum.

    The signal is first cut at the index where ``hplus`` is largest. On the
    remaining tail, the envelope ``sqrt(hplus**2 + hcross**2)`` is searched
    for strict local minima, and everything before the first one is
    discarded as well.

    Parameters
    ----------
    hplus, hcross : numpy.ndarray
        The two polarisation series, sampled on the same grid.

    Returns
    -------
    numpy.ndarray
        Tail of ``hplus`` starting at the first local minimum of the envelope
        after the ``hplus`` maximum, or the tail starting at that maximum if
        the envelope has no local minimum.
    """
    start = np.argmax(hplus)
    tail = hplus[start:]
    envelope_tail = np.sqrt(hplus**2 + hcross**2)[start:]

    # argrelextrema returns a tuple with one index array per axis.
    minima = argrelextrema(envelope_tail, np.less)[0]
    if len(minima) == 0:
        return tail
    return tail[minima[0]:]
def wt(postmerger, sam_p, scale_stride=5, time_stride=45, max_width=400):
    """Continuous wavelet transform of a signal, normalised and decimated.

    Parameters
    ----------
    postmerger : array_like
        1-D signal to transform.
    sam_p : float
        Sampling period, forwarded to ``pywt.cwt`` as ``sampling_period``.
    scale_stride : int, optional
        Keep every ``scale_stride``-th scale row (default 5).
    time_stride : int, optional
        Keep every ``time_stride``-th time column (default 45).
    max_width : int, optional
        Maximum number of time columns kept after decimation (default 400).

    Returns
    -------
    numpy.ndarray
        ``|coefs| / ||coefs||_F`` subsampled to
        ``[::scale_stride, ::time_stride][:, :max_width]``.

    Notes
    -----
    The scales come from the module-level ``scale_min``, ``scale_max`` and
    ``dscale``. An all-zero input has zero norm and therefore yields NaNs.
    """
    scales = np.arange(scale_min, scale_max, dscale)

    # CWT on the signal using the Morlet wavelet; the returned frequencies
    # are not needed here.
    coefs, _freqs = pywt.cwt(postmerger, scales, 'morl', sampling_period=sam_p)

    # Normalise the magnitude by the Frobenius norm of the full (undecimated)
    # coefficient matrix, then decimate.
    Z = np.abs(coefs) / np.linalg.norm(coefs)
    return Z[::scale_stride, ::time_stride][:, :max_width]

def pad_width(Z, l=400):
    """Zero-pad a 2-D array along its second axis to exactly ``l`` columns.

    The padding is centred; when the number of missing columns is odd the
    extra column goes on the right-hand side.

    Parameters
    ----------
    Z : numpy.ndarray
        2-D array of shape ``(rows, cols)`` with ``cols <= l``.
    l : int, optional
        Target number of columns (default 400).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(rows, l)``.

    Raises
    ------
    ValueError
        If ``Z`` already has more than ``l`` columns.
    """
    # Take the row count from the input itself rather than assuming it equals
    # the target width, so non-square inputs can be padded too.
    rows, cwidth = Z.shape
    missing = l - cwidth
    if missing < 0:
        raise ValueError(
            f"input has {cwidth} columns, which exceeds the target width {l}"
        )
    left = missing // 2
    right = missing - left  # absorbs the leftover column when `missing` is odd
    return np.concatenate(
        (np.zeros((rows, left)), Z, np.zeros((rows, right))), axis=1
    )
    
def get_rampup(ts):
    """Return ``floor(6 * ln(len(ts) / 2))``, used as the taper length for ``ts``."""
    half_length = len(ts) / 2
    return math.floor(6 * math.log(half_length))

In [None]:
# Factor applied to the time column of the waveform file.
# NOTE(review): G*M_sun/c^3 is about 4.925e-6 s, so confirm both the 4.975e-6
# value and the "milliseconds" unit stated in the comment below.
time_con_f = 4.975e-6
gwf = cdb.sim["THC:0087"].run["R01"].data.read("rh_22")
# Columns 1 and 2 are passed on as the (hplus, hcross) pair below.
strain = gwf[:,1]
sustrain = gwf[:,2]
env = gwf[:,6]
pm_time = gwf[:,8]*time_con_f                         #converting to milliseconds
# N samples span N-1 intervals, so the mean sampling period divides the total
# time span by len - 1 (dividing by len underestimates it).
sam_p = (pm_time[-1] - pm_time[0])/(len(pm_time) - 1)
# Post-merger strain -> scalogram -> scalogram zero-padded to the default width.
o = cut_at_lowest_envelope(strain,sustrain)
o1 = wt(o,sam_p)
o2 = pad_width(o1)
print(sam_p,1/sam_p)
# `o` is a tail slice of `strain`, so its time stamps are the last len(o) entries.
pomtime = pm_time[-o.shape[0]:]


In [None]:
# Taper parameter derived from the length of the post-merger strain.
rampup = get_rampup(o)
# planck_window(j, N) concatenates (j-1) + (N-2j) + (j-1) = N-2 samples, so
# requesting len(o) + 2 yields a window with exactly len(o) samples.
pw = planck_window(rampup,o.shape[0]+2)


In [None]:
# Sanity check: shows the shape of `o` next to the window length plus two
# (the nominal N that was passed to planck_window).
o.shape, pw.shape[0]+2

In [None]:
# Taper the post-merger strain by multiplying it element-wise with the window.
o1dot5 = np.multiply(o,pw)

In [None]:
from plotly import io as io

In [None]:
from scipy import signal
# Real part of a Morlet wavelet, sampled for the illustration figure below.
# NOTE: SciPy's signature is morlet(M, w=5.0, s=1.0, complete=True). The values
# are passed positionally, so the variable named `s` fills SciPy's `w`
# (fundamental frequency) and the variable named `w` fills SciPy's `s`
# (scaling). The fallback reproduces exactly that mapping.
M = 201
s = 4.0
w = .55
try:
    wavelet = signal.morlet(M, s, w)
except AttributeError:
    # scipy.signal.morlet is deprecated and missing from recent SciPy
    # releases; evaluate the same closed form (complete=True) with NumPy.
    _x = np.linspace(-w * 2 * np.pi, w * 2 * np.pi, M)
    wavelet = np.exp(1j * s * _x) - np.exp(-0.5 * (s ** 2))
    wavelet = wavelet * np.exp(-0.5 * (_x ** 2)) * np.pi ** (-0.25)
wavelet = np.array(wavelet).real


In [None]:
# Dark theme for the Morlet-wavelet illustration.
layout = go.Layout(
    paper_bgcolor='rgba(30,31,36,255)',
    plot_bgcolor='rgba(30,31,36,255)',
    font=dict(family='Overpass, monospace', size=16, color="rgb(241, 194, 50)"),
    legend=dict(orientation="h"),
)

# Single trace: the sampled wavelet against its sample index.
fig = go.Figure(data=[go.Scatter(y=wavelet)], layout=layout)

# Size, titles and legend placement applied in one pass.
fig.update_layout(
    autosize=False,
    width=400,
    height=400,
    title='Morlet Wavelet',
    xaxis_title='Scale',
    yaxis_title='Multiplier',
    legend=dict(yanchor="middle", y=1.1, xanchor="center", x=0.5),
)

# The target directory must already exist.
io.write_image(fig,"./plots/wides/morlet.png",format="png")

fig.show()

In [80]:
# Dark theme for the wavelet-transform heatmap.
layout = go.Layout(
    paper_bgcolor='rgba(30,31,36,255)',
    plot_bgcolor='rgba(30,31,36,255)',
    font=dict(family='Overpass, monospace', size=32, color="rgb(241, 194, 50)"),
    legend=dict(orientation="h"),
)

# NOTE(review): `o2dot5` is not assigned in any cell visible in this notebook,
# so this cell raises NameError on a fresh kernel -- add the cell that
# computes it (or restore the deleted one) before this figure.
fig = go.Figure(
    data=[go.Heatmap(z = o2dot5,colorscale="cividis")],
    layout=layout,
)

# Size, titles and legend placement applied in one pass.
fig.update_layout(
    autosize=False,
    width=1300,
    height=600,
    title='Morlet Wavelet Transform Output',
    xaxis_title='Time',
    yaxis_title='Frequency',
    legend=dict(yanchor="middle", y=1.1, xanchor="center", x=0.5),
)

# The target directory must already exist.
io.write_image(fig, "./plots/wides/sg.png", format="png")

fig.show()


In [None]:
from CoRe_Dataloader_ECSG import load_raw_from_pth_file,dataset
# Column 0 of `b` is read as the per-sample EOS index in a later cell; `a` is
# not used by the visible cells.
# NOTE(review): both names are rebound by the np.unique cell further down, so
# this cell must be re-run before re-running the `eos_list` cell.
a,b = load_raw_from_pth_file()

In [None]:
# `names` is indexed by EOS id below to turn ids into labels; `counts` is not
# used by the visible cells.
names,counts = dataset.ueoss,dataset.ueosscounts

In [None]:
# First column of `b` (used below as the EOS index of every sample), converted
# to a NumPy array.
eos_list = b[:,0].cpu().numpy()

In [None]:
# Translate every EOS index into its label.
named_list = [names[int(idx)] for idx in eos_list]
    

In [None]:
# Display the per-sample EOS labels (one entry per sample, so this can be long).
named_list


In [None]:
# Dark theme for the EOS distribution histogram.
layout = go.Layout(
    paper_bgcolor='rgba(30,31,36,255)',
    plot_bgcolor='rgba(30,31,36,255)',
    font=dict(family='Overpass, monospace', size=32, color="rgb(241, 194, 50)"),
    legend=dict(orientation="h"),
)

# One bar per EOS label, normalised so that the bar heights sum to 1.
# NOTE(review): only `x` is supplied, so confirm that `ybins` has any effect here.
fig = go.Figure(
    data=[
        go.Histogram(
            x=named_list,
            histnorm='probability',
            ybins=dict(start=1,end=19,size=1),
        )
    ],
    layout=layout,
)
fig.update_layout(
    autosize=False,
    width=1800,
    height=1000,
    title='EOS Distribution in the CoRe Dataset',
    xaxis_title='EOS Number (19 Total)',
    yaxis_title='Proportion of Total',
    legend=dict(yanchor="middle", y=1.1, xanchor="center", x=0.5),
)
fig.show()


In [None]:
# NOTE(review): presumably one accuracy fraction per EOS -- it is multiplied by
# 100 and plotted as "Accuracy in %" below. Confirm how data.npy was produced
# and in which EOS order its entries are stored.
data = np.load("data.npy")

In [None]:
# Dark theme for the accuracy bar chart.
layout = go.Layout(
    paper_bgcolor='rgba(30,31,36,255)',
    plot_bgcolor='rgba(30,31,36,255)',
    font=dict(family='Overpass, monospace', size=32, color="rgb(241, 194, 50)"),
    legend=dict(orientation="h"),
)

# list(set(...)) yields the labels in an arbitrary order that can change from
# one kernel session to the next, which pairs each bar with an arbitrary EOS.
# Sorting makes the label order deterministic.
# NOTE(review): this assumes `data` is stored in sorted-label order -- confirm
# against the code that wrote data.npy.
eos_labels = sorted(set(named_list))

fig = go.Figure(
    data=[go.Bar(y=data*100, x=eos_labels)],
    layout=layout,
)
fig.update_layout(
    autosize=False,
    width=1800,
    height=1000,
    title='Accuracy versus EOS',
    xaxis_title='Equation of State',
    yaxis_title='Accuracy in %',
    legend=dict(yanchor="middle", y=1.1, xanchor="center", x=0.5),
)
fig.show()


In [None]:
# Sorted unique EOS labels and the number of samples carrying each label.
# NOTE: this rebinds `a` and `b`, which previously held the output of
# load_raw_from_pth_file(); re-run that cell before re-running `eos_list`.
a,b = np.unique(named_list,return_counts=True)
c = len(named_list)  # total number of samples
b = b/c  # counts -> fraction of the dataset per EOS

In [None]:
# Dark theme for the accuracy / dataset-share comparison chart.
layout = go.Layout(
    paper_bgcolor='rgba(30,31,36,255)',
    plot_bgcolor='rgba(30,31,36,255)',
    font=dict(family='Overpass, monospace', size=32, color="rgb(241, 194, 50)"),
    legend=dict(orientation="h"),
)

# Both traces use the sorted unique labels `a`. The first trace previously used
# list(set(named_list)), whose order is arbitrary and could pair accuracies
# with the wrong EOS and disagree with the second trace.
# NOTE(review): this assumes `data` is stored in sorted-label order -- confirm
# against the code that wrote data.npy.
fig = go.Figure(
    data=[
        go.Bar(y=data*100, x=a, name="Accuracy"),
        go.Bar(y=b*100, x=a, name="Share of dataset"),
    ],
    layout=layout,
)
fig.update_layout(
    autosize=False,
    width=1800,
    height=1000,
    title='Accuracy versus EOS',
    xaxis_title='Equation of State',
    yaxis_title='Accuracy in %',
    legend=dict(yanchor="middle", y=1.1, xanchor="center", x=0.5),
)
fig.show()
