In [None]:
import pandas as pd, numpy as np, xarray as xr
from pathlib import Path
import re, yaml, copy, json
import helper, config_adapter
from helper import RenderJSON
import scipy.io

In [None]:
import plotly
# Initialise plotly's notebook mode before any figure is shown.
plotly.offline.init_notebook_mode()

# Options passed to `fig.show(config=...)` in the display cell.
plotly_config = dict(
    scrollZoom=True,
    displaylogo=False,
    toImageButtonOptions=dict(
        format='svg',  # one of png, svg, jpeg, webp
        filename='custom_image',
        height=None,
        width=None,
        scale=1,  # Multiply title/legend/axis/canvas sizes by this factor
    ),
    # Extra drawing tools added to the mode bar.
    modeBarButtonsToAdd=[
        'drawline',
        'drawopenpath',
        'drawclosedpath',
        'drawcircle',
        'drawrect',
        'eraseshape',
    ],
)

In [None]:
import itables
# Switch itables to notebook mode; `all_interactive=True` presumably makes every
# displayed DataFrame an interactive table -- confirm against the itables docs.
itables.init_notebook_mode(all_interactive=True )
# Global itables display options (size limit, page-length choices, export
# buttons, control layout). NOTE(review): exact semantics are defined by itables.
itables.options.maxBytes = "1MB"
itables.options.lengthMenu = [25, 10, 50, 100, 200]
itables.options.buttons = ["copyHtml5", "csvHtml5", "excelHtml5"]
itables.options.layout={"topEnd": "pageLength", "top1": "searchBuilder"}

In [None]:
# Load the run parameters from `params.yaml` in the current working directory.
# `read_text()` opens and closes the file itself, so no file handle is left
# open after parsing.
params = yaml.safe_load(Path("params.yaml").read_text())
RenderJSON(params)

In [None]:
# Load the configuration file named by `params["config_path"]` through the
# project's `config_adapter`. `config` is read by the processing, correction
# and display cells below.
config_path = Path(params["config_path"])
config = config_adapter.load(config_path)
RenderJSON(config)

In [None]:
# Read the annotation table, rename its columns to the `label` / `start` / `end`
# vocabulary used in the rest of the notebook, and order the rows by start time.
annotations = (
    pd.read_csv(params["annotation_path"])
    .rename(columns={"name": "label", "start_seconds": "start", "stop_seconds": "end"})
    .sort_values("start")
)
annotations

In [None]:
# Read the audio file and wrap the samples in a Dataset whose `t` coordinate is
# the sample index divided by the sampling rate.
fs, data = scipy.io.wavfile.read(params["audio_path"])
song = xr.Dataset({"data": xr.DataArray(data, dims="t")})
song["t"] = np.arange(data.size) / fs
# The sampling rate travels with the time coordinate; the signal-processing
# functions read it back from `attrs["fs"]`.
song["t"].attrs["fs"] = fs
song

In [None]:
# Analysis Dataset on a regular time grid sampled at `t_fs` (taken from the
# processing config): t = 0, 1/t_fs, 2/t_fs, ... up to the last grid point that
# does not exceed the largest time in `song`.
ds = xr.Dataset()
t_fs = config["processing"]["t_fs"]
ds["t"] = xr.DataArray(np.arange(int(song["t"].max()*t_fs)+1)/t_fs, dims="t")
ds


In [None]:
# xarray view of the annotation table: one entry per annotation (presumably a
# syllable), with the row dimension renamed from "index" to "syb".
syb = annotations.to_xarray().rename_dims({"index": "syb"})
syb

In [None]:
# For every time step of the analysis grid, find the annotation boundaries that
# surround it and derive signed distances to them.
#
# np.searchsorted (default side="left") returns, for each t, the number of
# boundaries strictly smaller than t. `index - 1` is therefore the last boundary
# before t and `index` the first boundary at or after t. Neighbours that do not
# exist (before the first / after the last boundary) become NaN.
n_syb = syb.sizes["syb"]

next_start_index = xr.DataArray(np.searchsorted(syb["start"], ds["t"]), dims="t")
prev_start = xr.where(next_start_index == 0, np.nan, syb["start"].isel(syb=next_start_index - 1))
# The inner xr.where only keeps the index valid for isel; the outer one masks
# those positions with NaN.
next_start = xr.where(next_start_index == n_syb, np.nan, syb["start"].isel(syb=xr.where(next_start_index == n_syb, 0, next_start_index)))

next_end_index = xr.DataArray(np.searchsorted(syb["end"], ds["t"]), dims="t")
prev_end = xr.where(next_end_index == 0, np.nan, syb["end"].isel(syb=next_end_index - 1))
next_end = xr.where(next_end_index == n_syb, np.nan, syb["end"].isel(syb=xr.where(next_end_index == n_syb, 0, next_end_index)))

# Stored as plain arrays along "t" -- presumably so that coordinates carried
# over from `syb` by the vectorised isel are not attached to `ds`.
ds["prev_start"] = xr.DataArray(prev_start.to_numpy(), dims="t")
ds["prev_end"] = xr.DataArray(prev_end.to_numpy(), dims="t")
ds["next_start"] = xr.DataArray(next_start.to_numpy(), dims="t")
ds["next_end"] = xr.DataArray(next_end.to_numpy(), dims="t")

# t lies inside an annotation when the last start before t is more recent than
# the last end before t. Missing neighbours are replaced by -inf on BOTH sides:
# with no previous start the test is False, and with a previous start but no
# previous end it is True even when that start is at t = 0 (filling the end
# with 0 would evaluate `0 > 0` and wrongly report "outside").
is_in_syb = prev_start.fillna(-np.inf) > prev_end.fillna(-np.inf)

# Signed distances: inside an annotation both are <= 0 (prev_start - t and
# t - next_end); outside both are >= 0 (t - prev_end and next_start - t).
ds["d_before"] = xr.where(is_in_syb, prev_start - ds["t"], ds["t"] - prev_end)
ds["d_after"] = xr.where(is_in_syb, ds["t"] - next_end, next_start - ds["t"])
ds["is_in_syb"] = is_in_syb
display(ds)

# Preview: rows from 10 ms before the first annotation start to 10 ms after the
# second annotation end (or the first one when there is a single annotation,
# which would otherwise raise an IndexError).
as_df = ds.to_dataframe().reset_index()
display(as_df.loc[(as_df["t"] >= syb["start"].isel(syb=0).item() - 0.01) & (as_df["t"] <= syb["end"].isel(syb=min(1, n_syb - 1)).item() + 0.01)])


In [None]:
def compute_envellope(arr_name, window_duration, window_type):
  """Compute a windowed amplitude envelope of a signal.

  Args:
    arr_name: "source_signal" to use `song["data"]`, otherwise the name of a
      variable of the notebook-level `ds`.
    window_duration: window length, multiplied by the signal's `fs` attribute
      to obtain the window size in samples.
    window_type: only "hanning" is supported.

  Returns:
    The mean over each window of |signal * window|, evaluated every `stride`
    samples. Windows containing NaN (e.g. the incomplete ones at the edges of a
    centred rolling window) are dropped. The returned `t` coordinate carries
    its sampling rate in `attrs["fs"]`.

  Raises:
    Exception: if `window_type` is not handled.
  """
  if arr_name == "source_signal":
    sig = song["data"]
  else:
    sig = ds[arr_name]
  sig_fs = sig["t"].attrs["fs"]
  win_size = int(np.round(window_duration*sig_fs))
  # Decimate towards the analysis rate `t_fs`. The stride is clamped to 1: when
  # the signal is sampled more slowly than `t_fs`, int(sig_fs/t_fs) is 0 and a
  # zero stride is an invalid slice step. The result then keeps the signal's own
  # rate, which is recorded in attrs["fs"] below for the caller to resample.
  stride = max(1, int(sig_fs/t_fs))
  if window_type == "hanning":
    window = xr.DataArray(np.hanning(win_size), dims="window")
  else:
    raise Exception(f'Unhandled windowtype {window_type}')
  tmp = sig.rolling(t=win_size, center=True).construct("window", stride = stride).dropna(dim="t", how="any")
  vol = np.abs(tmp * window).mean("window")
  vol["t"].attrs["fs"] = sig_fs/stride
  return vol


In [None]:

def sliding_mean(arr_name, window_duration):
  """Centred moving average of a signal.

  `arr_name` selects `song["data"]` when it equals "source_signal", otherwise
  the variable of that name in `ds`. `window_duration` is converted to a number
  of samples with the signal's `fs` attribute. Windows containing NaN are
  dropped. The result keeps the input sampling rate, stored on its `t`
  coordinate as `attrs["fs"]`.
  """
  source = song["data"] if arr_name == "source_signal" else ds[arr_name]
  rate = source["t"].attrs["fs"]
  n_window = int(np.round(window_duration * rate))

  # Materialise the rolling windows along a new "window" dimension, discard the
  # incomplete ones, then average each window.
  windows = source.rolling(t=n_window, center=True).construct("window")
  averaged = windows.dropna(dim="t", how="any").mean("window")

  averaged["t"].attrs["fs"] = rate
  return averaged

In [None]:
# Compute every derived signal listed in the processing config, in order, and
# store it in `ds` under its configured name (later entries may therefore refer
# to earlier ones).
for d in config["processing"]["data"]:
    name = d["name"]
    if d["method"] == "envellope":
        res = compute_envellope(**d["method_params"])
    elif d["method"] == "sliding_mean":
        res = sliding_mean(**d["method_params"])
    elif d["method"] == "pandas_eval":
        res = ds.eval(d["method_params"]["expr"]).reset_coords(drop=True)
    else:
        # Without this branch an unknown method would either fail with a
        # NameError or silently store the previous entry's result under `name`.
        raise ValueError(f'Unhandled processing method {d["method"]!r} for data entry {name!r}')
    # Bring time-dependent results onto the analysis grid unless their time
    # coordinate is already tagged with the analysis sampling rate.
    if "t" in res.dims and ("fs" not in res["t"].attrs or res["t"].attrs["fs"] != t_fs):
        res = res.interp(t=ds["t"])
        res["t"].attrs["fs"] = t_fs
    ds[name] = res
display(ds)

In [None]:
# Attach to every annotation the label and boundary of its neighbours.
# "^" / "$" stand for "before the first" / "after the last" annotation.
annotations = annotations.assign(
    next_label=annotations["label"].shift(-1, fill_value="$"),
    prev_label=annotations["label"].shift(1, fill_value="^"),
    next_start=annotations["start"].shift(-1, fill_value=np.inf),
    prev_end=annotations["end"].shift(1, fill_value=0),
)

# Every distinct (previous label, label) pair that occurs, plus the final
# (last label, "$") transition.
all_transitions = pd.Series(
    list(zip(annotations["prev_label"], annotations["label"]))
    + [(annotations["label"].iat[-1], "$")]
).drop_duplicates()
all_transitions

In [None]:
# Index the transition-specific overrides by their transition tuple, separately
# for starts and ends. The "transition" key itself is removed from each
# override's parameter dict.
specific_corrections = {
    which: {
        tuple(entry["transition"]): {key: value for key, value in entry.items() if key != "transition"}
        for entry in config["processing"]["specific_corrections"][which]
    }
    for which in ("start", "end")
}

# One row per observed transition: use the override when one exists, otherwise
# the configured default parameters.
corrections = pd.DataFrame()
corrections["transition"] = all_transitions

start_mapped = corrections["transition"].map(specific_corrections["start"])
corrections["start_params"] = np.where(
    start_mapped.isna(),
    config["processing"]["default_corrections"]["start"],
    start_mapped,
)

end_mapped = corrections["transition"].map(specific_corrections["end"])
corrections["end_params"] = np.where(
    end_mapped.isna(),
    config["processing"]["default_corrections"]["end"],
    end_mapped,
)
display(corrections)

# Final lookup form: {transition: {"start_params": ..., "end_params": ...}}.
corrections = corrections.set_index("transition").to_dict(orient="index")


In [None]:
def compute_new_bounds(correction_params, syb_t, prev_t, next_t):
    """Re-estimate one annotation boundary from the signals stored in `ds`.

    The search window around `syb_t` is limited by
    `correction_params["correction_limits"]` (offsets relative to `syb_t`) and
    never extends past the midpoint towards the neighbouring boundaries
    `prev_t` / `next_t` (with a 1e-5 margin).

    Args:
        correction_params: mapping with keys "correction_limits" ([min, max]
            offsets), "expr" (expression evaluated on the windowed `ds`) and
            "method" ("first_all_true" or "min").
        syb_t: current time of the boundary.
        prev_t: time of the previous boundary.
        next_t: time of the next boundary.

    Returns:
        float: for "first_all_true", the latest time in the window at which
        `expr` is false (the lower window bound when it is true everywhere);
        for "min", the time at which `expr` is minimal.

    Raises:
        ValueError: if `correction_params["method"]` is not a known method.
    """
    [corr_min, corr_max] = correction_params["correction_limits"]
    min_bound = max((prev_t+syb_t)/2+0.00001, syb_t+corr_min)
    max_bound = min((next_t+syb_t)/2 -0.00001, syb_t+corr_max)
    arr = ds.sel(t=slice(min_bound, max_bound)).eval(correction_params["expr"])
    method = correction_params["method"]
    if method=="first_all_true":
        return float(arr["t"].where(~arr).max().fillna(min_bound))
    if method=="min":
        return float(arr["t"].isel(t=arr.argmin("t")))
    # Falling through would return None, which would end up silently in the
    # corrected annotations.
    raise ValueError(f"Unhandled correction method {method!r}")

In [None]:
# Re-estimate the start and the end of every annotation. Each boundary is
# searched between its two neighbouring boundaries, with the correction
# parameters of the transition it belongs to.
new_bounds = {"start": [], "end": []}
for row in annotations.to_dict(orient="index").values():
    label = row["label"]
    for which in ("start", "end"):
        syb_t = row[which]
        if which == "start":
            # A start sits between the previous annotation's end and its own end.
            prev_t, next_t = row["prev_end"], row["end"]
            transition = (row["prev_label"], label)
        else:
            # An end sits between its own start and the next annotation's start.
            prev_t, next_t = row["start"], row["next_start"]
            transition = (label, row["next_label"])
        correction_params = corrections[transition][f'{which}_params']
        new_bounds[which].append(compute_new_bounds(correction_params, syb_t, prev_t, next_t))

annotations["new_start"] = new_bounds["start"]
annotations["new_end"] = new_bounds["end"]
annotations

In [None]:
# Sanity check: a corrected end must not be later than the corrected start of
# the following annotation (the last row compares against NaN, i.e. False).
overlaps = annotations["new_end"].gt(annotations["new_start"].shift(-1))
noverlap = overlaps.sum()
if noverlap > 0:
    display(annotations)
    raise Exception("Overlap problem...")

In [None]:
# Write the result with the input file's column names for the corrected bounds;
# the original bounds are kept under "uncorrected_*".
out_annotations = annotations.rename(
    columns={
        "start": "uncorrected_start",
        "end": "uncorrected_end",
        "new_start": "start_seconds",
        "new_end": "stop_seconds",
        "label": "name",
    }
)
out_annotations.to_csv(params["out_annotations"], index=False)
out_annotations


In [None]:
# Optional interactive figure: configured signals (and optionally a
# spectrogram) stacked in subplots sharing the time axis, with the corrected and
# uncorrected annotations drawn on top. Skipped when config["display"] is falsy.
if config["display"]:
  import plotly.graph_objects as go
  from plotly.subplots import make_subplots
  # Optional time window to display; defaults to everything.
  if "t_limits" in config["display"] and config["display"]["t_limits"]:
    t_min = config["display"]["t_limits"]["min"]
    t_max = config["display"]["t_limits"]["max"]
  else:
    t_min, t_max = (-np.inf, np.inf)
  sigs= config["display"]["signals"]
  if config["display"]["show_spectrogram"]:
    # The spectrogram is added as an extra signal in its own "_spec" subplot.
    sigs = dict(song_display_spectrogram=dict(dest="_spec"), **sigs)
    # Short-time Fourier transform: 512-sample Hanning windows every 128 samples,
    # computed on the displayed window plus a 0.2 margin on each side.
    win_size = 512
    stride = 128
    spectro_window = song["data"].sel(t=slice(t_min-0.2, t_max+0.2)).rolling(t=win_size, min_periods=win_size, center=True).construct("window_t", stride=stride)
    spectro_window = spectro_window * xr.DataArray(np.hanning(spectro_window.sizes["window_t"]), dims="window_t")
    fft = xr.apply_ufunc(np.fft.rfft, spectro_window, input_core_dims=[["window_t"]], output_core_dims=[["f"]])
    fft["f"] = np.fft.rfftfreq(spectro_window.sizes["window_t"], 1/fs)
    # Only the 2000-8000 band of the frequency axis is displayed.
    fft = fft.sel(f=slice(2000, 8000))
    psd = np.abs(fft)**2
    display_psd = np.log10(psd)
    nds=xr.Dataset()
    # Values below half of the maximum log-power are clipped to that level.
    nds["song_display_spectrogram"] = np.maximum(display_psd, display_psd.max()/2)
    # Resample every analysis signal onto the spectrogram time axis.
    for sig in ds:
      if "t" in ds[sig].dims:
        try:
          nds[sig] = ds[sig].interp(t=nds["t"])
        except Exception:
          # Best-effort fallback when interpolation is not possible (presumably
          # for non-numeric variables such as booleans): take the nearest sample.
          # `Exception` rather than a bare `except` so that KeyboardInterrupt and
          # SystemExit are not swallowed.
          nds[sig] = ds[sig].sel(t=nds["t"], method="nearest")
      else:
        nds[sig] = ds[sig]
  else:
    nds = ds.sel(t=slice(t_min, t_max))
  
  # One row per displayed signal: target subplot and dimensionality.
  sigs_info = pd.DataFrame()
  sigs_info["sig_name"] = list(sigs.keys())
  sigs_info["subplot_name"] = [i["dest"] for i in sigs.values()]
  sigs_info["sig_ndim"] = sigs_info["sig_name"].apply(lambda n: nds[n].ndim)
  sigs_info["subplot_num"] = sigs_info.groupby("subplot_name").ngroup()

  # One row per subplot; its relative height is proportional to the largest
  # dimensionality among its signals (2-D heatmaps get more room than 1-D traces).
  plots_info = sigs_info.groupby(["subplot_name", "subplot_num"])["sig_ndim"].max().reset_index()
  plots_info["subplot_height"] = plots_info["sig_ndim"]/plots_info["sig_ndim"].sum()
  plots_info=plots_info.sort_values("subplot_num")


  n_subplots = len(plots_info.index)
  fig = make_subplots(rows=n_subplots, cols=1, row_heights=plots_info["subplot_height"].to_list(), shared_xaxes=True)

  # Draw each signal according to its shape: time series -> line, scalar ->
  # horizontal line, 2-D array -> heatmap.
  for row in sigs_info.to_dict(orient="index").values():
    arr = nds[row["sig_name"]]
    if arr.dims == ("t",):
      fig.add_trace(go.Scatter(x=arr["t"].to_numpy(),y=arr.to_numpy(),showlegend=True,name=row["sig_name"]), row=row["subplot_num"]+1, col=1)
    elif arr.ndim == 0:
      fig.add_hline(y=float(arr), row=row["subplot_num"]+1, col=1, label = dict(
        text=row["sig_name"],
        ), line=dict(color="black"), showlegend=True)
    elif arr.ndim == 2:
      other_dim = [d for d in arr.dims if not d=="t"][0]
      fig.add_trace(go.Heatmap(x=arr["t"].values, y=arr[other_dim].values, z= arr.transpose(other_dim, "t").values,
                     name=row["sig_name"],showlegend=False, showscale=False)
        ,row=row["subplot_num"]+1, col=1
      )
    else:
      raise Exception("Not handled")


  # Zero-width rectangles whose only purpose is to create one legend entry per
  # annotation style.
  fig.add_vrect(x0=0, x1=0,label = dict(text="corrected_labels"),line=dict(color="MediumPurple"), showlegend=True, row=1, col=1)
  fig.add_vrect(x0=0, x1=0,label = dict(text="uncorrected_labels"),line=dict(color="yellow", dash="dot"), showlegend=True, row=1, col=1)
  
  # Corrected bounds (solid purple, labelled) and original bounds (dotted yellow)
  # of every annotation.
  for i, row in enumerate(out_annotations.to_dict(orient="index").values()):
    fig.add_vrect(x0=row["start_seconds"], x1=row["stop_seconds"],
      label = dict(
      text=row["name"],
      textposition="top center",
      font=dict(size=20, family="Times New Roman", color="MediumPurple"),
      ),
      line=dict(color="MediumPurple"))
    fig.add_vrect(x0=row["uncorrected_start"], x1=row["uncorrected_end"], line=dict(color="yellow", dash="dot"))

  fig.update_layout(hovermode='x unified', hoversubplots="axis", xaxis_showticklabels=True)
  fig.show(config = plotly_config)