In [None]:
import pandas as pd, numpy as np, xarray as xr
from pathlib import Path
import re, yaml, copy, json
import helper, config_adapter
from helper import RenderJSON
import scipy.io

In [None]:
import plotly

plotly.offline.init_notebook_mode()

# Settings for the modebar "download plot as image" button.
image_export_options = dict(
    format='svg',  # one of png, svg, jpeg, webp
    filename='custom_image',
    height=None,
    width=None,
    scale=1,  # Multiply title/legend/axis/canvas sizes by this factor
)

# Drawing/annotation tools appended to the default modebar.
extra_modebar_buttons = [
    'drawline',
    'drawopenpath',
    'drawclosedpath',
    'drawcircle',
    'drawrect',
    'eraseshape',
]

# Passed to fig.show(config=...) when figures are displayed.
plotly_config = {
    'scrollZoom': True,
    'displaylogo': False,
    'toImageButtonOptions': image_export_options,
    'modeBarButtonsToAdd': extra_modebar_buttons,
}

In [None]:
import itables

# Render every displayed DataFrame as an interactive table.
itables.init_notebook_mode(all_interactive=True)

# Global table defaults, applied (in this order) after interactive mode is enabled.
_itables_settings = {
    "maxBytes": "1MB",
    "lengthMenu": [25, 10, 50, 100, 200],
    "buttons": ["copyHtml5", "csvHtml5", "excelHtml5"],
    "layout": {"topEnd": "pageLength", "top1": "searchBuilder"},
}
for _option_name, _option_value in _itables_settings.items():
    setattr(itables.options, _option_name, _option_value)

In [None]:
# Run parameters (input/output paths); resolved relative to the current working directory.
PARAMS_PATH = Path("params.yaml")
# Use a context manager so the file handle is closed once parsing is done
# (the previous one-liner left it open for the lifetime of the kernel).
with PARAMS_PATH.open("r") as params_file:
    params = yaml.safe_load(params_file)
RenderJSON(params)

In [None]:
# The processing configuration lives in its own file, whose location is given by params.
raw_config_location = params["config_path"]
config_path = Path(raw_config_location)
config = config_adapter.load(config_path)
RenderJSON(config)

In [None]:
# Map the annotation CSV's column names onto the short names used in this notebook.
ANNOTATION_COLUMN_MAP = {"name": "label", "start_seconds": "start", "stop_seconds": "end"}

# Syllable annotations in onset order; the original row labels are kept as the index.
annotations = (
    pd.read_csv(params["annotation_path"])
    .rename(columns=ANNOTATION_COLUMN_MAP)
    .sort_values("start")
)
annotations

In [None]:
fs, data = scipy.io.wavfile.read(params["audio_path"])
# The envelope/correction pipeline below works on a single 1-D signal. Fail early
# with a readable message instead of xarray's opaque dims mismatch when the WAV
# is multi-channel (wavfile.read returns shape (n_samples, n_channels) then).
if data.ndim != 1:
    raise ValueError(
        f"Expected a mono recording, got samples of shape {data.shape} "
        f"from {params['audio_path']}; downmix or pick a channel first"
    )
# Raw samples indexed by time in seconds; the sample rate is kept on the coordinate.
song = xr.Dataset()
song["data"] = xr.DataArray(data, dims="t")
song["t"] = np.arange(data.shape[0]) / fs
song["t"].attrs["fs"] = fs
song

In [None]:
volume_params = config["processing"]["volume"]
# Frame length and hop, converted from seconds / target rate into samples.
win_size = int(np.round(volume_params["window_duration"]*fs))
stride = int(np.round(fs/volume_params["approx_out_fs"]))
if win_size < 1 or stride < 1:
    # e.g. approx_out_fs > 2*fs rounds the hop to 0 samples, which rolling().construct cannot use.
    raise ValueError(
        f"Volume framing needs at least 1 sample per window and per hop; "
        f"got win_size={win_size}, stride={stride} at fs={fs}"
    )

# Supported taper names (config value) -> numpy window constructor.
VOLUME_WINDOW_FUNCTIONS = {
    "hanning": np.hanning,
    "hamming": np.hamming,
    "blackman": np.blackman,
    "bartlett": np.bartlett,
}
window_function = VOLUME_WINDOW_FUNCTIONS.get(volume_params["window_type"])
if window_function is None:
    # ValueError subclasses Exception, so existing `except Exception` handlers still catch it.
    raise ValueError(f'Unhandled windowtype {volume_params["window_type"]}')
window = xr.DataArray(window_function(win_size), dims="window")

# Centered sliding frames of win_size samples every `stride` samples; frames that
# overrun the signal edges contain NaN padding and are dropped.
tmp = song["data"].rolling(t=win_size, center=True).construct("window", stride = stride).dropna(dim="t", how="any")
tmp

In [None]:

# Frames are taken every `stride` samples, so the envelope's sample rate is fs/stride.
volume_fs = fs / stride

# log10 of the mean absolute tapered amplitude in each frame.
volume = xr.Dataset()
volume["volume"] = np.log10(np.abs(tmp * window).mean("window"))
volume["t"].attrs["fs"] = volume_fs

# Frame-to-frame change in log volume; the first frame has no predecessor and is NaN.
previous_volume = volume["volume"].shift(t=1)
volume["incr"] = volume["volume"] - previous_volume
volume

In [None]:
# One entry per annotated syllable along a "syb" dimension; the pandas row
# labels survive as the non-index "index" coordinate.
syllable_table = annotations.to_xarray()
syb = syllable_table.rename_dims({"index": "syb"})
syb

In [None]:
# Candidate boundary offsets in seconds: one per envelope frame, starting at the
# lower correction limit and spanning (upper - lower) seconds. The offsets double
# as the "window_t" coordinate so idxmax/idxmin return an offset directly.
correction_limits = config["processing"]["correction_limits"]
n_offsets = int(np.round((correction_limits[1] - correction_limits[0]) * volume_fs))
window = xr.DataArray(np.arange(n_offsets) / volume_fs + correction_limits[0], dims="window_t")
window["window_t"] = window

# Onset: the candidate with the steepest volume rise, kept strictly before the
# syllable's original end and after the previous syllable's original end.
start_candidates = window + syb["start"]
previous_end = syb["end"].shift(syb=1, fill_value=-np.inf)
start_allowed = (start_candidates < syb["end"]) & (start_candidates > previous_end)
start_rise = volume["incr"].sel(t=start_candidates, method="nearest")
syb["new_start"] = start_rise.where(start_allowed).idxmax("window_t") + syb["start"]

# Offset: the candidate with the steepest volume fall, kept strictly after the
# syllable's original start and before the next syllable's original start.
end_candidates = window + syb["end"]
next_start = syb["start"].shift(syb=-1, fill_value=np.inf)
end_allowed = (end_candidates > syb["start"]) & (end_candidates < next_start)
end_fall = volume["incr"].sel(t=end_candidates, method="nearest")
syb["new_end"] = end_fall.where(end_allowed).idxmin("window_t") + syb["end"]
syb

In [None]:
# Syllables whose corrected offset runs past the next syllable's corrected onset.
next_new_start = syb["new_start"].shift(syb=-1)
noverlap = (syb["new_end"] > next_new_start).sum().item()

# For those syllables, redo the offset search bounded by the next *corrected* onset.
# Keep the first pass's guard that the offset stays after the syllable's own start;
# without it, short syllables with a negative correction limit could get an offset
# placed before their onset.
repair_candidates = window + syb["end"]
repair_allowed = (repair_candidates > syb["start"]) & (
    repair_candidates < syb["new_start"].shift(syb=-1, fill_value=np.inf)
)
repaired_end = (
    volume["incr"].sel(t=repair_candidates, method="nearest")
    .where(repair_allowed)
    .idxmin("window_t")
    + syb["end"]
)
syb["new_end"] = xr.where(syb["new_end"] > next_new_start, repaired_end, syb["new_end"])

# Flag anything still overlapping after the repair.
syb["overlaps"] = syb["new_end"] > syb["new_start"].shift(syb=-1)
noverlap, syb["overlaps"].sum().item()

In [None]:
# Internal names -> output CSV column names (corrected bounds take the input's
# start_seconds/stop_seconds names; the originals are kept as uncorrected_*).
OUT_COLUMN_NAMES = {
    "start": "uncorrected_start",
    "end": "uncorrected_end",
    "new_start": "start_seconds",
    "new_end": "stop_seconds",
    "label": "name",
}
out_annotations = (
    syb.to_dataframe()
    .drop(columns="index")  # original pandas row labels, not needed downstream
    .rename(columns=OUT_COLUMN_NAMES)
)
out_annotations.to_csv(params["out_annotations"], index=False)
out_annotations

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Raw waveform on the primary y-axis, log volume envelope on the secondary one.
fig = make_subplots(specs=[[{"secondary_y": True}]])
fig.add_trace(go.Scatter(
        x=song["t"].to_numpy(),
        y=song["data"].to_numpy(),
        showlegend=True,
        name="song",
))
fig.add_trace(go.Scatter(
        x=volume["t"].to_numpy(),
        y=volume["volume"].to_numpy(),
        showlegend=True,
        name="volume",
), secondary_y=True)

# One record per syllable; "records" iterates rows in order without needing a unique index.
for row in out_annotations.to_dict(orient="records"):
    # Corrected boundaries, labelled with the syllable name.
    fig.add_vrect(x0=row["start_seconds"], x1=row["stop_seconds"],
                label=dict(
                    text=row["name"],
                    textposition="top center",
                    font=dict(size=20, family="Times New Roman", color="MediumPurple"),
                ),
                line=dict(color="MediumPurple"))
    # Original (uncorrected) boundaries for comparison.
    fig.add_vrect(x0=row["uncorrected_start"], x1=row["uncorrected_end"],
                line=dict(color="yellow", dash="dot"))

fig.update_layout(hovermode='x unified')
# The two y-axes carry different quantities, so label both.
fig.update_xaxes(title_text="time (s)")
fig.update_yaxes(title_text="amplitude (raw samples)", secondary_y=False)
fig.update_yaxes(title_text="log10 volume", secondary_y=True)
fig.show(config = plotly_config)