In [None]:
import os, sys, csv, time
import importlib
import json, re

#os.environ["CUDA_VISIBLE_DEVICES"] = ""
sys.path.append(os.path.join('../'))

import pickle
import librosa.display
import scipy
import numpy as np
import matplotlib.pyplot as plt
import ipympl
import ipywidgets as wdg
from ipywidgets import Layout, HBox, VBox
import pygame

import torchaudio
import torch
import IPython.display as ipd

import src.recog as recog
import src.utils as utils
import nbutils

sys.path.append(os.path.join('../ast/src'))
from models import ASTModel

os.environ['TORCH_HOME'] = '../ast/pretrained_models'

# Newer torchaudio releases deprecate the global backend setter (and may drop it),
# so only call it where it still exists; otherwise Restart & Run All would stop here.
if hasattr(torchaudio, "set_audio_backend"):
    torchaudio.set_audio_backend("soundfile")

def reload():
    """Re-import the project helper modules so edits on disk take effect.

    Reloads ``recog``, ``utils`` and ``nbutils`` (in that order) in the running
    kernel. Objects created before the reload keep their old classes.
    """
    for module in (recog, utils, nbutils):
        importlib.reload(module)

## Set variables and load models and other requirements

In [None]:
# Audio front-end settings.
sample_freq = 16000   # samples per second
mel_bins = 128
target_length = 1024  # passed to recog.Recog below

# Both values are durations in *seconds*: 10e-3 s = 10 ms, 25e-3 s = 25 ms.
interval = 10e-3
win_length = 25e-3

# hop_length is number of samples between successive frames.
hop_length = int(sample_freq * interval)

# Build the recogniser only when the kernel does not already hold a usable one,
# so re-running this cell does not construct it again.
if "ast_model" not in globals() or not ast_model:
    ast_model = recog.Recog(ASTModel, target_length, "../ast/egs/audioset/data/class_labels_indices.csv")

In [None]:
# Load the ontology description from JSON.
ontology = utils.Ontology(
    '../ontology/ontology.json'
)

# Select the entries of interest. The pattern is anchored, so it matches only the
# full strings "Singing" or "Music"; what it is matched against is decided by
# utils.Ontology's item lookup (presumably the entry name -- TODO confirm).
interests = ontology[
    utils.reg(r"^(Singing|Music)$")
]

In [None]:
# Input recording and the directory name handed to the detector / cache loader.
filename = "../media/2022-3-5_TA_last4.flac"
save_dir = "test"

# Sample rate passed to the detector.
sr = 44_100

# Timing parameters forwarded to recog.Recog.detect_music_main.
# NOTE(review): presumably all in seconds (5 min clips, 10 min stop) -- confirm there.
start = 0
clip_len = 5 * 60
delta = 0.5
duration = 2
stop_time = 10 * 60

## Do Inference

In [None]:
# Pick up edits to the recog module without restarting the kernel.
importlib.reload(recog)

# Start from the series a previous run left in the kernel, falling back to an empty
# dict (presumably what utils.if_not_defined does -- TODO confirm in src/utils).
infer_series = utils.if_not_defined(__name__, "infer_series", {})

# Called through the class with `ast_model` passed explicitly as the first argument.
# Only the first of the four returned values is kept.
infer_series, _, _, _ = recog.Recog.detect_music_main(
    ast_model, filename, save_dir,
    start, clip_len, delta, duration,
    ontology, interests, sr,
    stop_time=stop_time,
    infer_series=infer_series,
)

## Load inference cache

In [None]:
# Pick up edits to the recog module without restarting the kernel.
importlib.reload(recog)

# Read back what is stored under `save_dir`; load_wav=False is passed, so
# presumably the audio itself is not loaded here (TODO confirm in load_cache).
metadata = recog.Recog.load_cache(
    save_dir,
    load_sr=44100,
    load_mono=True,
    load_wav=False,
)

### Merge music intervals and save them as Audacity label text

In [None]:
# Refresh recog / utils / nbutils before using them.
reload()

#nbutils.update_metadata(nbutils.state_map(recog), metadata) # only for jupyter

# merge_intervals returns two collections; both are printed, one per line.
music_itvs, itvs = recog.Recog.merge_intervals(metadata)
print(music_itvs, itvs, sep="\n")

In [None]:
import pathlib

# One line per merged music interval, the interval's fields joined by tabs
# (the tab-separated layout used for Audacity label files).
aud_txt = "\n".join("\t".join(map(str, itv)) for itv in music_itvs.values())
print(aud_txt)
flac_path = pathlib.Path(filename)

# Written next to the source recording as "<stem>_detect.txt"; the encoding is
# pinned so the file does not depend on the platform default.
with open(flac_path.parent / f"{flac_path.stem}_detect.txt", "w", encoding="utf-8") as f:
    f.write(aud_txt)

## Plots

In [None]:
reload()

%matplotlib widget
nbutils.plot_all(metadata, 0, ontology, pygame, recog)

In [None]:
reload()

%matplotlib widget
nbutils.plot_all(metadata, 1, ontology, pygame, recog)

## Draft code

In [None]:
test_sr = 44_100

# Load a mono excerpt of the recording: from 3:33 up to 7:35.
# Note that this rebinds the notebook-level `sr` to the rate returned by librosa.
y, sr = librosa.load(
    filename,
    sr=test_sr,
    mono=True,
    offset=3*60 + 33,
    duration=(7*60 + 35) - (3*60 + 33),
)

# Turn the 1-D signal into a single-row 2-D array of shape (1, n_samples).
y = y.reshape(1, -1)

In [None]:
# NOTE: these rebind delta / duration / start from the parameter cell further up.
delta = 0.5
duration = 2
start = 0  #+100
cut_length = 60 + 30

# Reuse the series a previous run left in the kernel, else start with an empty dict
# (presumably what utils.if_not_defined does -- TODO confirm in src/utils).
tmp_series = utils.if_not_defined(__name__, "tmp_series", {})

# Slice `cut_length` seconds beginning at `start` seconds out of the loaded excerpt.
wav = y[..., int(test_sr * start):int(test_sr * (start + cut_length))]

# Mean absolute amplitude over the whole excerpt (not only the slice).
entire_abs_mean = np.mean(np.abs(y))

In [None]:
# Pick up edits to recog / utils without restarting the kernel.
importlib.reload(recog)
importlib.reload(utils)

# `start` is passed twice on purpose, exactly as the positional signature is used here.
detect_series, detect_d_series, conc_series, states = recog.Recog.detect_music(
    ast_model, tmp_series,
    start, delta, duration, start,
    wav, test_sr,
    ontology, interests,
    entire_abs_mean=entire_abs_mean,
)

In [None]:
%matplotlib widget
importlib.reload(recog)
tmp={}

def onclick(event):
    """Matplotlib ``button_press_event`` handler: play the audio from the clicked time on.

    The click's x coordinate is used as a time offset into the global array ``y``
    (sample index ``sr * xdata``). The selected samples are stored in
    ``tmp["y_cut"]`` and the playing pygame sound in ``tmp["sound"]`` so the Stop
    button handler can reach it.

    Clicks that carry no data coordinates (outside the axes) are ignored, as are
    clicks made while the pygame mixer is not initialized.
    """
    # Outside the axes xdata/ydata are None; both the '%f' formatting and the
    # index arithmetic below would raise on that.
    if event.xdata is None or event.ydata is None:
        return

    text = 'event.button=%d,  event.x=%d, event.y=%d, event.xdata=%f, event.ydata=%f' % (event.button, event.x, event.y, event.xdata, event.ydata)
    print(text)

    mixer_state = pygame.mixer.get_init()  # (frequency, format, channels) or None
    if mixer_state is None:
        # e.g. after the Quit button has shut the mixer down
        print("pygame mixer is not initialized; nothing played")
        return
    n_channels = mixer_state[2]

    # Clamp at 0: a click in the margin left of t=0 would otherwise become a negative
    # index and silently select the tail of the recording.
    first_sample = max(int(sr * event.xdata), 0)
    tmp["y_cut"] = y[:, first_sample:]
    if tmp["y_cut"].shape[1] == 0:
        return  # clicked at or beyond the end of the audio

    # Keep a single playback at a time; otherwise clips overlap and the Stop button
    # can only reach the most recent one.
    previous = tmp.get("sound")
    if previous is not None:
        previous.stop()

    # Float -> signed 16-bit. Assumes samples are floats nominally in [-1, 1]
    # (TODO confirm); clipping and the 32767 scale keep +1.0 from wrapping around.
    pcm = (32767 * np.clip(tmp["y_cut"], -1.0, 1.0)).astype(np.int16).transpose(1, 0)

    # make_sound needs the array layout to match the mixer: 1-D for a mono mixer,
    # (n_samples, n_channels) otherwise. Duplicate the single channel when needed.
    if n_channels == 1:
        pcm = pcm[:, 0]
    elif pcm.shape[1] != n_channels:
        pcm = np.repeat(pcm[:, :1], n_channels, axis=1)

    sound = pygame.sndarray.make_sound(np.ascontiguousarray(pcm))
    sound.play()
    tmp["sound"] = sound

def on_click_callback(clicked_button: wdg.Button) -> None:
    """Stop button handler: halt the sound started by the last plot click, if any.

    Args:
        clicked_button: the button that fired the event (unused).
    """
    # Nothing is stored under "sound" until the plot has been clicked once, so a
    # direct tmp["sound"] lookup would raise KeyError when Stop is pressed first.
    sound = tmp.get("sound")
    # Skip when the mixer has been shut down (Quit button): stopping a sound
    # without an initialized mixer raises a pygame error.
    if sound is not None and pygame.mixer.get_init():
        sound.stop()

def plot_all():
    """Draw the five-row diagnostic figure and return its playback controls.

    Reads the notebook globals ``wav``, ``sr``, ``start``, ``tmp_series``,
    ``conc_series``, ``detect_series`` and ``detect_d_series``. Initialises the
    pygame mixer and connects ``onclick`` to mouse presses on the figure.

    Returns:
        An ``HBox`` holding the "Stop" and "Quit" buttons.
    """
    stop_button = wdg.Button(description='Stop')
    quit_button = wdg.Button(description="Quit")
    stop_button.on_click(on_click_callback)
    quit_button.on_click(lambda b: pygame.mixer.quit())

    # for stop, pygame.mixer.quit()
    pygame.mixer.pre_init(sr, size=-16, channels=2)
    pygame.mixer.init()

    # Five stacked axes sharing one x axis, one per view of the same clip.
    fig, ax = plt.subplots(nrows=5, sharex=True)
    wave_ax, class_ax, state_ax, detect_ax, detect_d_ax = ax

    recog.Recog.waveplot(plt, librosa, wav, sr, ax=wave_ax, offset=start, max_sr=2)
    recog.Recog.classifyshow(ontology, tmp_series, ax=class_ax)
    shown_states = [recog.State.Music, recog.State.Talking, recog.State.Other]
    recog.Recog.state_show(shown_states, conc_series, ax=state_ax)
    recog.Recog.detect_show(detect_series, xstep=4, ax=detect_ax)
    recog.Recog.detect_show(detect_d_series, xstep=5, ax=detect_d_ax)

    fig.set_size_inches(16, 8)
    for axis in ax:
        axis.xaxis.grid()

    # Clicking anywhere in the figure triggers playback via onclick.
    fig.canvas.mpl_connect('button_press_event', onclick)

    return HBox([stop_button, quit_button])

# Kept as the bare last expression of the cell so the returned HBox (Stop / Quit
# buttons) is rendered as the cell output.
plot_all()