In [None]:
import os, sys
import json

# The project root is the directory the kernel is currently running in.
vsongrecog_path = os.getcwd()

# Extend the import search path with, in this order:
# <root>/zsass, <root> itself, and <root>/whisper.
sys.path.extend([
    f"{vsongrecog_path}/zsass",
    vsongrecog_path,
    f"{vsongrecog_path}/whisper",
])

In [None]:
import importlib
import pathlib
from typing import Dict, List, Tuple

import torch
import librosa
import numpy as np
import matplotlib.pyplot as plt
import IPython.display as ipd
from py_linq import Enumerable as E

import zsass.htsat_config as htsat_config
import src.audioclassifier as audioclassifier
import whisper

In [None]:
import src.utils as utils
import src.audioset as aset
import src.chart_utils as chart_utils
import src.detector as detector
import src.transcriber as transcriber
import src.Identifier as identifier
import notebooks.myconfig as myconfig

def reload():
    """Re-import the project modules so that source edits are picked up by this kernel.

    Modules are reloaded one after another in a fixed order: the helper
    modules, then the classifier/detector/transcriber/identifier modules,
    and finally the notebook configuration.
    """
    modules_in_order = (
        utils,
        aset,
        chart_utils,
        audioclassifier,
        detector,
        transcriber,
        identifier,
        myconfig,
    )
    for module in modules_in_order:
        importlib.reload(module)

In [None]:
# Pick up any edits to the project modules before building the shared objects.
reload()

# Ontology and class-label table, both read from files under the project root.
onto = utils.Ontology(f'{vsongrecog_path}/ontology/ontology.json')
audioset = aset.AudioSet(f"{vsongrecog_path}/ast/egs/audioset/data/class_labels_indices.csv")
# Softmax module that normalises along dim 1.
softmax = torch.nn.Softmax(dim=1)

# Ontology lookup keyed by a regex for exactly "Singing" or "Music".
# NOTE(review): presumably this selects ontology entries by name; confirm against utils.Ontology / utils.reg.
interests = onto[utils.reg(r"^(Singing|Music)$")]

In [None]:
# Build the classifier from htsat_config (and its resume_checkpoint) plus the
# interests / ontology / label table created in the previous cell.
Audiocls = audioclassifier.AudioClassifier(htsat_config.resume_checkpoint, htsat_config, interests, onto, audioset)

## Detect

In [None]:
# Settings for the detection / labeling loops below.
reload()

# Directory that is globbed for input audio files; taken from notebooks/myconfig.
media_dir = myconfig.media_dir
# Threshold bundle passed to detector.Config. The meaning of each field is
# defined in src/detector (not visible here).
# NOTE(review): the trailing `#0.8` / `#10` comments look like previously used values; confirm.
thres_set = detector.Threshold(
    thres = 0.60,
    human_thres = 0.5,
    music_joint_thres = 0.4,
    human_joint_thres = 0.3,

    adj_thres=5.0, #0.8,
    long_thres=50,#10,
    hs_rate_thres=0.4,

    transc_long_thres=30,
    search_len=50,
)
# Remaining positional arguments of detector.Config (see the loops below).
cut_duration = 10
cache_itv = 3
print_itv = 10

In [None]:
reload()

def iter(media_dir, medias = ["*.flac", "*.ogg"]):
    """Yield the files in ``media_dir`` that match each glob pattern in ``medias``.

    Patterns are handled in the order given; within one pattern the matches
    are yielded in sorted path order.

    NOTE: this name shadows the built-in ``iter`` for the rest of the notebook.
    """
    for pattern in medias:
        yield from sorted(media_dir.glob(pattern))

# Run music detection for every media file that has no inference cache yet.
for wavfile in iter(media_dir):
    try:
        config = detector.Config(wavfile, cut_duration, thres_set, interests, cache_itv, print_itv)

        # Check existence: the per-file cache directory must hold at least one .pkl.
        infer_cache_dir = config.cache_dir / wavfile.stem
        # any() stops at the first match and does not rely on the built-in
        # `list` name staying unshadowed in this kernel.
        cache_exists = infer_cache_dir.exists() and any(infer_cache_dir.glob("*.pkl"))

        # Instantiate
        music_detector = detector.Detector(Audiocls, onto, audioset, config)

        if cache_exists:
            print("Skip Detection", wavfile.name)
            # NOTE(review): this `continue` also skips the "Beautify" step below,
            # so the final .txt is only (re)written right after a fresh detection.
            # Confirm that is intended.
            continue
        else:
            print("Detect", wavfile.name)
            music_intervals = music_detector.main(0)
            forauda = detector.to_audacity(music_intervals)
            # Intermediate label file: "<stem>_mid.txt" next to the input file.
            utils.save_text(config.input_path.parent / f"{config.input_path.stem}_mid.txt", forauda)

        # Beautify: rebuild the intervals from the concatenated cached results
        # and write the final "<stem>.txt" label file.
        all_abst_tensor, all_start, all_duration = detector.Detector.concat_cached_abst(music_detector)
        all_itvs = music_detector.abst_tensor2intervals(all_abst_tensor, all_start, all_duration)

        utils.save_text(config.input_path.parent / f"{config.input_path.stem}.txt", detector.to_audacity(all_itvs))
    except FileNotFoundError as ex:
        # A missing file is reported and the loop moves on to the next input.
        print(wavfile, ex)

## Labeling

In [None]:
import google.generativeai as genai
import os
# https://zenn.dev/layerx/articles/e13030eb8e364a
# https://qiita.com/kccs_kai-morita/items/7cc6510b8f483c31bf6e

In [None]:
def gen_gemini():
    """Configure the Gemini client and return a ``gemini-pro`` model handle.

    The API key is read from the ``API_KEY`` environment variable; a missing
    variable raises ``KeyError``.
    """
    api_key = os.environ['API_KEY']
    genai.configure(api_key=api_key)
    gemini_model = genai.GenerativeModel(model_name='gemini-pro')
    return gemini_model

# Language model handle shared by the cells below.
lang_model = gen_gemini()

In [None]:
# Load the "small" model via whisper.load_model; this `model` is what the
# labeling loop hands to transcriber.Transcriber.
model = whisper.load_model("small")

In [None]:
import time
import browser_cookie3
reload()


# Search engine name and Firefox cookie jar handed to identifier.Identifier below.
engine_type = "bing"
cj = browser_cookie3.firefox()

end_expand = 5.0

# Transcribe and identify every media file, reusing caches where they exist.
for wavfile in iter(media_dir):
    # Set only when identification is served entirely from cache; in that case
    # the throttling sleep at the end of the iteration is skipped.
    skipped = False
    try:
        config = detector.Config(wavfile, cut_duration, thres_set, interests, cache_itv, print_itv)

        music_transcriber = transcriber.Transcriber(model, config)
        trans_cache_file = config.transcript_cache_dir / wavfile.stem / f"{wavfile.stem}.pkl"
        output_file = config.output_dir / wavfile.stem / f"{wavfile.stem}.pkl"

        # Transcription: reuse the cache unless it is missing or empty.
        if trans_cache_file.exists():
            transcriptions = music_transcriber.get_cache()
            if not len(transcriptions):
                print("re Transcribe", wavfile.stem)
                transcriptions = music_transcriber.main()
            else:
                print("Skip transcription", wavfile.name)
        else:
            print("Transcribe", wavfile.stem)
            # main() is expected to write the cache as well.
            transcriptions = music_transcriber.main()

        music_identifier = identifier.Identifier(engine_type, transcriptions, lang_model, config, cookie_jar=cj)

        # Identification: redo when the cached result is empty, resume when it
        # is shorter than the transcription list, otherwise reuse it.
        if output_file.exists():
            identified = music_identifier.get_cache()
            if not len(identified):
                print("re Identify", wavfile.stem)
                identified = music_identifier.main()
            elif len(identified) != len(transcriptions):
                print("Resume identify", wavfile.stem)
                identified = music_identifier.main(identified)
            else:
                print("Skip identification", wavfile.name)
                skipped = True
        else:
            print("Identify", wavfile.stem)
            identified = music_identifier.main()

        print("  ", len(identified), "items")
        music_identifier.save_csv(media_dir, wavfile, identified)

        if skipped:
            continue
    #except grpc.RpcError as e:
    except Exception as e:
        print("ERR", e)
        if "Too Many" in str(e):
            # Rate limited: abort the whole run. A bare `raise` keeps the
            # original traceback intact.
            raise
        # Any other failure: report it, rebuild the Gemini handle and carry on
        # with the next file (after the sleep below).
        lang_model = gen_gemini()

    print("Sleep...")
    time.sleep(2*60) # Throttle between files to avoid heavy access to the external services
    print("End sleep")

In [None]:
# Scratch check: call identifier.google with a sample query and fresh Firefox cookies.
reload()

cj = browser_cookie3.firefox()
result = identifier.google("今日", cj)

In [None]:
# Scratch check: call identifier.search with a sample query, the "bing" engine
# and the cookie jar `cj` created in an earlier cell.
reload()
result = identifier.search("天気", "bing", cj)

In [None]:
reload()

# Prints False: an empty list is falsy, so double negation yields False.
print(not not [])
# Call whatever identifier.test() returns with the `result` of the last search cell.
identifier.test()(result)

In [None]:
import src.utils as utils

# How many transcriptions are currently loaded, then a look at one of them.
print(len(transcriptions))

idx = 3
transcription = transcriptions[idx]

# Displayed as (segment start formatted by utils.sec2time, language, text).
(
    utils.sec2time(transcription.segment[0]),
    transcription.lang,
    transcription.text,
)

In [None]:
import csv
from dateutil import parser

def name_filter(file_path: pathlib.Path):
    """Decide whether ``file_path`` should be exported; currently every file is accepted."""
    # To limit the run to particular files, test the stem instead, e.g.:
    #return "7-20" in file_path.stem
    return True

def iter(directory=None, medias=("*.flac", "*.ogg")):
    """Yield audio files matching ``medias`` from a media directory.

    This definition replaces the earlier ``iter(media_dir, medias)`` in the
    kernel, so it accepts the same positional call forms; otherwise re-running
    the Detect / Labeling cells after this one fails with a TypeError.

    Args:
        directory: Directory to glob. When omitted, the notebook-level
            ``media_dir`` is used (the original zero-argument behaviour).
        medias: Glob patterns, processed in the given order.

    Yields:
        Matching paths, in whatever order ``glob`` returns them (not sorted).

    NOTE: this name shadows the built-in ``iter``.
    """
    root = media_dir if directory is None else directory
    for typ in medias:
        yield from root.glob(typ)

# Seconds added to the end of every segment before it is written out.
end_expand = 5.0

# Write one CSV per media file from its pickled identification result.
for soundfile in iter():
    if not name_filter(soundfile):
        print("Skip", soundfile.name)
        continue
    print("Write", soundfile.name)

    pkl_file = media_dir / "identified" / soundfile.stem / f"{soundfile.stem}.pkl"
    csv_file = media_dir / "identified" / soundfile.stem / f"{soundfile.stem}.csv"
    csv_data = []

    if not pkl_file.exists():
        print(pkl_file, "not found")
        continue

    transcriptions: List[transcriber.Transcription] = utils.load_pickle(pkl_file)

    for idx, transcription in enumerate(transcriptions):
        title = transcription.title
        if not title:
            # Untitled entries are kept under a placeholder name rather than skipped.
            title = "NoName" + str(idx)

        start, end = transcription.segment
        end += end_expand

        # Row layout: start, end (expanded), title, artist.
        csv_data.append([start, end, title, transcription.artist])

    with open(csv_file, "w", newline="") as file:
        mywriter = csv.writer(file, delimiter=",")
        # Rows still go through np.array exactly as before, so the formatting
        # of the written cells is unchanged.
        mywriter.writerows(np.array(csv_data))

## Identify draft code

In [None]:
# reset for tmp use
# Clear the stored search / LLM results of the selected entries so that they
# can be identified again.
for idx, tmp in enumerate(transcriptions):
    if idx not in [4]:#[2, 7]:
        continue
    tmp.search_result = []
    tmp.llm_result = ""

In [None]:
reload()
# Pass every transcription through Transcription.Renew after the reload.
# NOTE(review): presumably this rebuilds the objects against the reloaded class; confirm.
transcriptions = [transcriber.Transcription.Renew(t) for t in transcriptions]
# Same argument layout as the Identifier call in the Labeling loop: engine name
# first, cookie jar by keyword. `engine_type` and `cj` are defined in that cell.
music_identifier = identifier.Identifier(engine_type, transcriptions, lang_model, config, cookie_jar=cj)

In [None]:
# Run the identifier; its return value replaces `transcriptions`.
transcriptions = music_identifier.main()

In [None]:
# Look at one entry after identification.
idx = 4
transcription = transcriptions[idx]

# Displayed as (segment start formatted by utils.sec2time, language, text, title).
(
    utils.sec2time(transcription.segment[0]),
    transcription.lang,
    transcription.text,
    transcription.title,
)

In [None]:
reload()

# Entry under inspection.
tmp = transcriptions[7]
#tmp.search_result = None
#tmp.llm_result = ""
# Called through the class with the existing instance passed as `self`.
# NOTE(review): presumably so the freshly reloaded method body is used; confirm.
search_word = identifier.Identifier.get_search_word(music_identifier, tmp)
# Compare the search word stored on the entry with the newly computed one.
print(tmp.search_word,"\n", search_word)

tmp.search_result

In [None]:
# Ask the identifier to guess the song for `tmp`; the return value is rebound
# to `tmp`, and its artist / title are shown as the cell output.
tmp = music_identifier.guess_song(tmp)
tmp.artist, tmp.title

In [None]:
# Bullet list of the search results, one "- ..." line per entry.
# Deliberately not named `list`: rebinding the built-in at notebook level makes
# every later `list(...)` call in this kernel fail with a TypeError.
candidate_lines = "\n".join(f"- {line}" for line in tmp.search_result)
prompt = f"""
Guess the artist and the song name in human understandable Japanese as possible as you can from the list and return it in json format include "artist" and "title" as keys.
If you can't identify them uniquely, the guess from top item of list that include the word "歌詞".

{candidate_lines}
"""
print(prompt)

# Send the prompt to the language model and show the text of the first response part.
response = lang_model.generate_content(prompt)
#response = music_identifier.model.generate_content(prompt)
response.parts[0].text