In [None]:
import pandas as pd, numpy as np, xarray as xr
from pathlib import Path
import re, yaml, copy, json
import helper, config_adapter
from helper import RenderJSON
import scipy.io

In [None]:
import plotly
plotly.offline.init_notebook_mode()

# Shared Plotly config passed to every fig.show() in this notebook: wheel zoom,
# no Plotly logo, SVG image export (height/width left unset), and the
# shape-drawing tools added to the mode bar.
_image_export_options = {
    'format': 'svg',  # one of png, svg, jpeg, webp
    'filename': 'custom_image',
    'height': None,
    'width': None,
    'scale': 1,  # Multiply title/legend/axis/canvas sizes by this factor
}
_drawing_buttons = [
    'drawline',
    'drawopenpath',
    'drawclosedpath',
    'drawcircle',
    'drawrect',
    'eraseshape',
]
plotly_config = dict(
    scrollZoom=True,
    displaylogo=False,
    toImageButtonOptions=_image_export_options,
    modeBarButtonsToAdd=_drawing_buttons,
)

In [None]:
import itables

# Render every DataFrame through itables, with copy/CSV/Excel export buttons,
# "pageLength" placed at topEnd and a "searchBuilder" panel at top1.
itables.init_notebook_mode(all_interactive=True)
for _option_name, _option_value in {
    "maxBytes": "1MB",
    "lengthMenu": [25, 10, 50, 100, 200],
    "buttons": ["copyHtml5", "csvHtml5", "excelHtml5"],
    "layout": {"topEnd": "pageLength", "top1": "searchBuilder"},
}.items():
    setattr(itables.options, _option_name, _option_value)

In [None]:
# Run parameters for this notebook (paths to the config, annotations, audio
# and output file). read_text() opens and closes the file in one call, so no
# file handle is left open in the kernel.
params = yaml.safe_load(Path("params.yaml").read_text())
RenderJSON(params)

In [None]:
# The processing configuration lives in its own file, located via params.yaml.
# Load it through the project's config adapter and render it so the settings
# used for this run are visible in the notebook output.
config_path = Path(params["config_path"])
config = config_adapter.load(config_path)
RenderJSON(config)

In [None]:
# Load the syllable annotations, switch to the short column names used below
# (label / start / end) and order them chronologically. The original row index
# is intentionally kept: it is carried into the xarray Dataset as "index" and
# dropped again before export.
annotations = (
    pd.read_csv(params["annotation_path"])
    .rename(columns={"name": "label", "start_seconds": "start", "stop_seconds": "end"})
    .sort_values("start")
)
annotations

In [None]:
# Load the recording and wrap it in a Dataset whose "t" coordinate is time in seconds.
fs, data = scipy.io.wavfile.read(params["audio_path"])
if data.ndim != 1:
    # A multi-channel file would otherwise fail inside xarray with an opaque
    # dimension-mismatch error; fail early with an actionable message.
    raise ValueError(
        f'Expected a mono recording, got samples of shape {data.shape} from {params["audio_path"]}'
    )
song = xr.Dataset()
song["data"] = xr.DataArray(data, dims="t")
# Sample index -> seconds; the length of the time axis, not the element count.
song["t"] = np.arange(data.shape[0]) / fs
song["t"].attrs["fs"] = fs
song

In [None]:
# Framing for the volume envelope: frames of `win_size` samples, advanced by
# `stride` samples, so roughly `approx_out_fs` frames per second.
volume_params = config["processing"]["volume"]
win_size = int(np.round(volume_params["window_duration"] * fs))
if win_size < 1:
    raise ValueError(
        f'window_duration={volume_params["window_duration"]} is shorter than one sample at fs={fs}'
    )
# A requested output rate above ~2*fs would round the stride to 0, which
# construct() cannot use; one sample per frame is the closest achievable rate.
stride = max(1, int(np.round(fs / volume_params["approx_out_fs"])))

# Supported frame-weighting windows (all symmetric NumPy window functions).
_window_functions = {
    "hanning": np.hanning,
    "hamming": np.hamming,
    "blackman": np.blackman,
    "bartlett": np.bartlett,
}
if volume_params["window_type"] not in _window_functions:
    raise ValueError(
        f'Unhandled window type {volume_params["window_type"]} (supported: {", ".join(_window_functions)})'
    )
window = xr.DataArray(_window_functions[volume_params["window_type"]](win_size), dims="window")

# One row per frame along "t", with the frame's samples along "window"; frames
# containing NaN (i.e. padded past the edges of the recording) are dropped.
tmp = song["data"].rolling(t=win_size, center=True).construct("window", stride=stride).dropna(dim="t", how="any")
tmp

In [None]:

# `window` must still be the frame-weighting window from the windowing cell.
# If a later cell has shadowed it, tmp * window would silently broadcast into a
# 2-D array instead of weighting each frame.
if window.dims != ("window",):
    raise RuntimeError("`window` no longer holds the volume weighting window; re-run the windowing cell first")

volume = xr.Dataset()
# Volume = mean absolute value of the window-weighted samples of each frame.
volume["volume"] = np.abs(tmp * window).mean("window")
volume_scale = volume_params["scale"]
if volume_scale == "log":
    # Frames of exact digital silence give log10(0) = -inf.
    volume["volume"] = np.log10(volume["volume"])
elif volume_scale != "linear":
    raise ValueError(f"Unknown scale {volume_scale!r} (expected 'log' or 'linear')")
# Frames are `stride` samples apart, so the envelope is sampled at fs / stride Hz.
volume_fs = fs / stride
volume["t"].attrs["fs"] = volume_fs
volume

In [None]:
# One entry per annotated syllable: convert the sorted annotation table to an
# xarray Dataset and name its row dimension "syb". The DataFrame index is kept
# as the "index" variable along that dimension.
syb = annotations.to_xarray().rename_dims({"index": "syb"})
syb

In [None]:
# Offsets (s) relative to each annotated boundary that are searched for the
# corrected boundary, spaced at the volume-envelope sampling period.
# Kept under its own name: reusing `window` would clobber the frame-weighting
# window used by the volume cell and corrupt the envelope if that cell is re-run.
correction_lo = config["processing"]["correction_limits"][0]
correction_hi = config["processing"]["correction_limits"][1]
if correction_hi <= correction_lo:
    raise ValueError(f"correction_limits must be increasing, got {config['processing']['correction_limits']}")
_n_correction_offsets = int(np.round((correction_hi - correction_lo) * volume_fs))
correction_window = xr.DataArray(np.arange(_n_correction_offsets) / volume_fs + correction_lo, dims="window_t")
correction_window["window_t"] = correction_window


def get_volume_window(syb_t, prev_t, next_t):
    """Sample the volume envelope at ``syb_t + correction_window``.

    Parameters
    ----------
    syb_t : xr.DataArray
        Boundary time (s) per syllable (dimension ``syb``).
    prev_t, next_t : xr.DataArray
        Exclusive lower / upper time limits per syllable.

    Returns
    -------
    xr.DataArray
        Nearest envelope sample for every (syllable, offset) pair, over the
        ``syb`` and ``window_t`` dimensions; NaN where the sampled time is not
        strictly between ``prev_t`` and ``next_t``.
    """
    query_t = correction_window + syb_t
    return volume["volume"].sel(t=query_t, method="nearest").where((query_t < next_t) & (query_t > prev_t))

syb_window_volume_start = get_volume_window(syb["start"], syb["end"].shift(syb=1, fill_value=-np.inf), syb["end"])
syb_window_volume_end = get_volume_window(syb["end"], syb["start"], syb["start"].shift(syb=-1, fill_value=np.inf))

display(syb_window_volume_start)
display(syb_window_volume_end)



In [None]:
# Select how the corrected boundary is located inside each correction window.
# Both variants define compute_new_start / compute_new_end, which reduce a
# volume window over "window_t" to one offset (s) per syllable.
if config["processing"]["method"] == "derivative":
    def compute_new_start(swv: xr.DataArray):
        """Offset of the largest first difference (steepest rise) of the volume."""
        return (swv - swv.shift(window_t=1)).idxmax("window_t")

    def compute_new_end(swv: xr.DataArray):
        """Offset of the smallest first difference (steepest fall) of the volume."""
        return (swv - swv.shift(window_t=1)).idxmin("window_t")
elif config["processing"]["method"] == "threshold":
    def _neighbour_bounds(bounds: xr.DataArray, times: xr.DataArray):
        """For each time, the last bound strictly before it and the first bound at or after it.

        Returns ``(prev_bound, next_bound)`` along dimension "t"; NaN where no such bound exists.
        """
        n_bounds = bounds.sizes["syb"]
        next_index = xr.DataArray(np.searchsorted(bounds, times), dims="t")
        # next_index - 1 == -1 selects the last bound; it is masked to NaN right away.
        prev_bound = xr.where(next_index == 0, np.nan, bounds.isel(syb=next_index - 1))
        # Clamp the out-of-range index before isel, then mask that case to NaN.
        next_bound = xr.where(
            next_index == n_bounds, np.nan, bounds.isel(syb=xr.where(next_index == n_bounds, 0, next_index))
        )
        return prev_bound, next_bound

    prev_start, next_start = _neighbour_bounds(syb["start"], volume["t"])
    prev_end, next_end = _neighbour_bounds(syb["end"], volume["t"])
    volume_tmp = volume.copy()
    volume_tmp["prev_start"] = prev_start
    volume_tmp["prev_end"] = prev_end
    volume_tmp["next_start"] = next_start
    volume_tmp["next_end"] = next_end
    # A frame is inside a syllable iff the latest start is after the latest end.
    # A missing previous end must compare as -inf (not 0), otherwise frames of a
    # syllable starting at t=0 would be classified as outside it.
    is_in_syb = prev_start.fillna(-np.inf) > prev_end.fillna(-np.inf)
    volume_tmp["d_before"] = xr.where(is_in_syb, prev_start - volume["t"], volume["t"] - prev_end)
    volume_tmp["d_after"] = xr.where(is_in_syb, volume["t"] - next_end, next_start - volume["t"])
    volume_tmp["is_in_syb"] = is_in_syb
    volume_df = volume_tmp.to_dataframe().reset_index()

    display(volume_tmp)

    # The threshold is a user expression evaluated against the per-frame table volume_df.
    th_value = volume_df.eval(config["processing"]["method_params"]["threshold_expr"])
    display(RenderJSON(dict(threshhold_value=th_value)))
    # Preview the frames around the first two syllables (or the only one).
    _preview_last_syb = min(1, syb.sizes["syb"] - 1)
    display(volume_df.loc[
        (volume_df["t"] >= syb["start"].isel(syb=0).item() - 0.01)
        & (volume_df["t"] <= syb["end"].isel(syb=_preview_last_syb).item() + 0.01)
    ])

    def compute_new_start(swv: xr.DataArray):
        """Latest offset whose volume is below th_value (the largest offset if there is none)."""
        return swv["window_t"].where(swv < th_value).max("window_t").fillna(swv["window_t"].max())

    def compute_new_end(swv: xr.DataArray):
        """Earliest offset whose volume is above th_value (the smallest offset if there is none)."""
        return swv["window_t"].where(swv > th_value).min("window_t").fillna(swv["window_t"].min())
else:
    raise Exception(f'Unknown method {config["processing"]["method"]}')


In [None]:

# Corrected boundary = annotated boundary + offset chosen by the selected method.
syb["new_start"] = compute_new_start(syb_window_volume_start) + syb["start"]
syb["new_end"] = compute_new_end(syb_window_volume_end) + syb["end"]
# A syllable overlaps when its corrected end passes the next syllable's
# corrected start; the last syllable has no successor (+inf).
overlaps = syb["new_end"] > syb["new_start"].shift(syb=-1, fill_value=np.inf)
noverlap = overlaps.sum().item()
display(f'noverlap={noverlap}')
display(syb.to_dataframe())


In [None]:

# Resolve overlaps: re-search the end of each overlapping syllable, this time
# bounded by its own corrected start and the next syllable's corrected start.
# The overlap mask is recomputed here instead of reusing the `overlaps` global
# from the previous cell (which this cell reassigns), so re-runs are self-contained.
_next_new_start = syb["new_start"].shift(syb=-1, fill_value=np.inf)
_initial_overlaps = syb["new_end"] > _next_new_start
syb["new_end"] = xr.where(
    _initial_overlaps,
    compute_new_end(get_volume_window(syb["end"], syb["new_start"], _next_new_start)) + syb["end"],
    syb["new_end"],
)
overlaps = syb["new_end"] > syb["new_start"].shift(syb=-1, fill_value=np.inf)
noverlap = overlaps.sum().item()
if noverlap != 0:
    # Show both members of every remaining overlapping pair before failing.
    overlapped = syb["new_end"].shift(syb=1, fill_value=-np.inf) > syb["new_start"]
    display(syb.where(overlaps | overlapped, drop=True).to_dataframe())
    raise Exception("Overlapping in correct bounds...")
syb_df = syb.to_dataframe()
display(syb_df)

In [None]:
# Export: drop the carried-over DataFrame index, give the corrected bounds the
# input CSV's column names and keep the original bounds alongside them.
_export_column_names = dict(
    label="name",
    start="uncorrected_start",
    end="uncorrected_end",
    new_start="start_seconds",
    new_end="stop_seconds",
)
out_annotations = syb_df.drop(columns="index").rename(columns=_export_column_names)
out_annotations.to_csv(params["out_annotations"], index=False)


In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Two stacked panels sharing the x axis: raw waveform (row 1) and volume
# envelope (row 2). Corrected syllables are drawn as labelled MediumPurple
# rectangles, uncorrected ones as dotted yellow outlines.
fig = make_subplots(rows=2, cols=1, row_heights=[0.5, 0.5], shared_xaxes=True)
for _panel_row, (_x, _y, _trace_name) in enumerate(
    [
        (song["t"], song["data"], "song"),
        (volume["t"], volume["volume"], "volume"),
    ],
    start=1,
):
    fig.add_trace(
        go.Scatter(x=_x.to_numpy(), y=_y.to_numpy(), showlegend=True, name=_trace_name),
        row=_panel_row,
        col=1,
    )


def _label_font(color):
    """Font shared by all shape labels in this figure."""
    return dict(size=20, family="Times New Roman", color=color)


for annotation in out_annotations.to_dict(orient="records"):
    fig.add_vrect(
        x0=annotation["start_seconds"],
        x1=annotation["stop_seconds"],
        label=dict(text=annotation["name"], textposition="top center", font=_label_font("MediumPurple")),
        line=dict(color="MediumPurple"),
    )
    fig.add_vrect(
        x0=annotation["uncorrected_start"],
        x1=annotation["uncorrected_end"],
        line=dict(color="yellow", dash="dot"),
    )

if config["processing"]["method"] == "threshold":
    fig.add_hline(
        y=th_value,
        row=2,
        label=dict(text="threshold", textposition="end", font=_label_font("black"), yanchor="bottom"),
        line=dict(color="black"),
    )

# Move all traces and rectangles onto the first x axis (used together with
# hoversubplots="axis" below); the threshold line spans the full x domain.
fig.update_traces(xaxis='x')
fig.update_shapes(selector=dict(type="rect"), xref="x")
fig.update_shapes(selector=dict(type="line"), xref="x domain")

fig.update_layout(hovermode='x unified', hoversubplots="axis", xaxis_showticklabels=True)
fig.show(config=plotly_config)