In [1]:
import os
import sqlite3
import wget
from omegaconf import OmegaConf
import json
from nemo.collections.asr.parts.utils.decoder_timestamps_utils import (
    ASRDecoderTimeStamps,
)
from nemo.collections.asr.parts.utils.diarization_utils import OfflineDiarWithASR
from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels

import pprint

# Shared pretty-printer (4-space indent) for inspecting nested results
# interactively; nothing in the cells visible here calls it.
pp = pprint.PrettyPrinter(indent=4)


def getAllFiles(db_path="roderick.db"):
    """Return every row of the ``files`` table whose status is ``'waiting'``.

    Args:
        db_path: Path of the SQLite database to query. Defaults to
            ``"roderick.db"``, resolved against the current working directory.

    Returns:
        A list of row tuples as produced by ``cursor.fetchall()``; the list is
        empty when no row has status ``'waiting'``.

    Raises:
        sqlite3.Error: If the query fails (for example when the ``files``
            table does not exist). The connection is closed before the error
            propagates.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        # Read-only query, so there is nothing to commit.
        cursor.execute("""SELECT * FROM files WHERE status='waiting'""")
        return cursor.fetchall()
    finally:
        # Release the database handle even when the query raises.
        conn.close()


def findOne(id, db_path="roderick.db"):
    """Look up rows of the ``files`` table by id.

    Despite the name, the result is the full ``fetchall()`` list, so callers
    index into it (``results[0]``) or iterate over it.

    Args:
        id: Value matched against the ``id`` column.
        db_path: Path of the SQLite database to query. Defaults to
            ``"roderick.db"``, resolved against the current working directory.

    Returns:
        A list of matching row tuples; empty when no row has that id.

    Raises:
        sqlite3.Error: If the query fails. The connection is closed before
            the error propagates.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        # Read-only query, so there is nothing to commit.
        cursor.execute("""SELECT * FROM files WHERE id=?""", (id,))
        return cursor.fetchall()
    finally:
        # Release the database handle even when the query raises.
        conn.close()


def updateStatus(id, status, db_path="roderick.db"):
    """Set the ``status`` column of the ``files`` row(s) with the given id.

    Args:
        id: Value matched against the ``id`` column.
        status: New status string to store (the notebook uses ``"done"`` and
            ``"failed"``).
        db_path: Path of the SQLite database to update. Defaults to
            ``"roderick.db"``, resolved against the current working directory.

    Returns:
        None.

    Raises:
        sqlite3.Error: If the UPDATE or the commit fails. The connection is
            closed before the error propagates, without committing.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute(
            """UPDATE files SET status=? WHERE id=?""",
            (status, id),
        )
        conn.commit()
    finally:
        # Release the database handle even when the UPDATE or commit raises.
        conn.close()


def read_file(path_to_file):
    """Read a text file and return its lines without line terminators.

    Args:
        path_to_file: Path of the file to read (opened in text mode with the
            platform default encoding).

    Returns:
        A list of strings, one per line, as split by ``str.splitlines``; an
        empty list for an empty file.
    """
    with open(path_to_file) as handle:
        return handle.read().splitlines()


def diarize(file):
    """Run NeMo ASR and speaker diarization on one row of the ``files`` table.

    Args:
        file: An 8-field row ``(id, path, filename, showname, episode, title,
            duration, status)``. Only ``path``, ``filename``, ``showname`` and
            ``episode`` are used.

    Returns:
        None. Results are left on disk in the data directory.

    Side effects:
        * Creates ``data/<showname>/<episode>/`` and ``model/`` under the
          current working directory.
        * Downloads the NeMo inference config into ``model/`` when it is not
          already there.
        * Writes ``input_manifest.json`` into the data directory and points
          the diarizer's ``out_dir`` at that same directory.

    Raises:
        ValueError: If ``file`` does not have exactly eight fields.
        FileNotFoundError: If no predicted RTTM file exists for the episode
            after diarization has run.
    """
    # Only four of the eight columns are needed; the rest are discarded
    # (underscore names also avoid shadowing the builtin ``id``).
    (_id, path, filename, showname, episode, _title, _duration, _status) = file

    root_dir = os.getcwd()
    data_dir = os.path.join(root_dir, "data", showname, episode)
    model_dir = os.path.join(root_dir, "model")

    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)

    # Strip only the trailing extension. str.replace(".wav", "") would also
    # delete ".wav" occurring in the middle of a name.
    # NOTE(review): assumes NeMo names the predicted RTTM after the audio
    # file's base name without extension, and that `filename` is the base name
    # of `path` -- confirm.
    episode_title = os.path.splitext(filename)[0]

    # Can be meeting or telephonic based on domain type of the audio file
    domain_type = "meeting"
    config_file_name = f"diar_infer_{domain_type}.yaml"
    config_url = (
        "https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/"
        f"speaker_tasks/diarization/conf/inference/{config_file_name}"
    )

    # Reuse a previously downloaded config; otherwise fetch it into model_dir.
    config_path = os.path.join(model_dir, config_file_name)
    if not os.path.exists(config_path):
        config_path = wget.download(config_url, model_dir)

    cfg = OmegaConf.load(config_path)

    # Single-entry manifest describing the audio file to process.
    manifest_path = os.path.join(data_dir, "input_manifest.json")
    meta = {
        "audio_filepath": path,
        "offset": 0,
        "duration": None,
        "label": "infer",
        "text": "-",
        "num_speakers": None,
        "rttm_filepath": None,
        "uem_filepath": None,
    }
    with open(manifest_path, "w") as fp:
        json.dump(meta, fp)
        fp.write("\n")

    cfg.diarizer.manifest_filepath = manifest_path
    # Directory to store intermediate files and prediction outputs
    cfg.diarizer.out_dir = data_dir
    cfg.diarizer.speaker_embeddings.model_path = "titanet_large"
    cfg.diarizer.clustering.parameters.oracle_num_speakers = False

    # Using Neural VAD and Conformer ASR
    cfg.diarizer.vad.model_path = "vad_multilingual_marblenet"
    cfg.diarizer.asr.model_path = "stt_en_conformer_ctc_large"
    cfg.diarizer.oracle_vad = False  # ----> Not using oracle VAD
    cfg.diarizer.asr.parameters.asr_based_vad = False

    # ASR pass: word hypotheses and their timestamps.
    asr_decoder_ts = ASRDecoderTimeStamps(cfg.diarizer)
    asr_model = asr_decoder_ts.set_asr_model()
    word_hyp, word_ts_hyp = asr_decoder_ts.run_ASR(asr_model)

    # Diarization pass, sharing the ASR decoder's timestamp anchor offset.
    asr_diar_offline = OfflineDiarWithASR(cfg.diarizer)
    asr_diar_offline.word_ts_anchor_offset = asr_decoder_ts.word_ts_anchor_offset

    diar_hyp, _diar_score = asr_diar_offline.run_diarization(cfg, word_ts_hyp)

    # Fail loudly if diarization left no RTTM for this episode, so the caller
    # can mark the file as failed instead of silently reporting success.
    predicted_rttm_path = os.path.join(
        data_dir, "pred_rttms", f"{episode_title}.rttm"
    )
    if not os.path.isfile(predicted_rttm_path):
        raise FileNotFoundError(
            f"Diarization produced no RTTM file at {predicted_rttm_path}"
        )

    # Combine words and speaker labels into the final transcript.
    # NOTE(review): presumably this also writes the transcript files into
    # out_dir -- confirm against the NeMo documentation.
    asr_diar_offline.get_transcript_with_speaker_labels(
        diar_hyp, word_hyp, word_ts_hyp
    )

In [2]:
# Select the rows to process. For a full batch over every file still marked
# 'waiting', use the commented-out call instead of the single-row lookup.
# results = getAllFiles()
results = findOne(6)  # list of rows whose id is 6

# Number of rows selected (displayed as the cell output).
len(results)

1

In [None]:
# Run the pipeline on the first selected row only.
# Raises IndexError when `results` is empty.
# Note: the loop in the following cell iterates over all of `results`, so
# this same row is diarized again there.
file = results[0]
diarize(file)

In [None]:
# Diarize every selected row and record the outcome in the database.
for file in results:
    # The id is the first column of the row; a dedicated name avoids
    # shadowing the builtin `id` at notebook scope.
    file_id = file[0]
    try:
        diarize(file)
        updateStatus(file_id, "done")
    except Exception as err:
        # `except Exception` rather than a bare `except:` so that interrupting
        # the kernel (KeyboardInterrupt) stops the batch instead of marking
        # the current file as failed and moving on. Report the reason so a
        # failure can be diagnosed later.
        print(f"diarize failed for file id {file_id}: {err!r}")
        updateStatus(file_id, "failed")