In [None]:
from pathlib import Path
import torch
import nemo.collections.asr as nemo_asr
from datetime import datetime
from tqdm.auto import tqdm

In [None]:
!pip install pysrt

In [None]:
import pysrt
from pydub import AudioSegment

In [None]:
# Every input/output for one diarized recording lives in SPEECH_DIR and shares
# the RECORDING_STEM file-name prefix; edit these two values to process a
# different recording.
SPEECH_DIR = Path("/opt/whisper-diarization/speech")
RECORDING_STEM = "001-output"

result_srt = SPEECH_DIR / f"{RECORDING_STEM}.srt"  # subtitle file with speaker labels (input)
origin_wav = SPEECH_DIR / f"{RECORDING_STEM}.wav"  # source audio (input)
speaker_response = SPEECH_DIR / f"{RECORDING_STEM}.json"  # per-clip metadata (written at the end)
output_path = SPEECH_DIR / "parts"  # one exported clip per subtitle (output)

# Validate the inputs before creating anything on disk.
assert result_srt.exists(), f"Must be exists: {result_srt}"
assert origin_wav.exists(), f"Must be exists: {origin_wav}"

output_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# TitaNet-Large speaker-embedding model. Loading it onto `device` (chosen in the
# device cell) makes that cell control where the model lives; otherwise the
# `device` variable would be computed but never used.
speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
    model_name="titanet_large",
    map_location=device,
)
speaker_model

In [None]:
# Parse the subtitle file into an iterable of subtitle items (text + start/end).
subs = pysrt.open(path=result_srt)

In [None]:
def convert_to_ms(time_srt) -> int:
    """Return a subtitle timestamp as an integer number of milliseconds.

    ``pysrt.SubRipTime`` already stores its value as total milliseconds in
    ``ordinal``. Reading it directly avoids a round-trip through
    ``datetime.time``, which cannot represent timestamps of 24 hours or more
    (``time.hour`` must be < 24). Objects without ``ordinal`` fall back to
    ``to_time()``.

    Args:
        time_srt: a ``pysrt.SubRipTime``, or any object exposing either
            ``ordinal`` (milliseconds) or ``to_time()`` returning a
            ``datetime.time``.

    Returns:
        The timestamp in whole milliseconds (sub-millisecond parts truncated).
    """
    ordinal = getattr(time_srt, "ordinal", None)
    if ordinal is not None:
        return int(ordinal)

    pure_time = time_srt.to_time()
    total_seconds = (pure_time.hour * 60 + pure_time.minute) * 60 + pure_time.second
    return total_seconds * 1000 + pure_time.microsecond // 1000


# Load the full recording once; pydub AudioSegments are sliced in milliseconds.
assert origin_wav.exists(), f"Original wav: {origin_wav}"
input_data = AudioSegment.from_wav(origin_wav)

# Cut one clip per subtitle and save it as "<subtitle index>.wav" in output_path.
# tqdm (given the total) reports progress, replacing the several debug prints
# per subtitle that flooded the cell output.
# NOTE(review): clips keep the source's channel count and sample rate; if the
# speaker model expects mono input, consider `.set_channels(1)` — TODO confirm.
for index, sub in tqdm(enumerate(subs), total=len(subs), desc="exporting clips"):
    start_time = convert_to_ms(sub.start)
    end_time = convert_to_ms(sub.end)
    split = input_data[start_time:end_time]
    # export() returns the file object it opened for the target path; close it
    # explicitly instead of relying on garbage collection.
    split.export(Path(output_path, f"{index}.wav"), format="wav").close()

In [None]:
def get_speaker_id(text: str) -> int:
    """Extract the numeric speaker id from a subtitle's text.

    The label is everything before the first ``:``; its second space-separated
    token is parsed as the id. For example, ``"Speaker 3: hello"`` gives ``3``.
    (The "Speaker N" label format is inferred from this parsing — confirm it
    against the transcript producer.)

    Raises:
        IndexError: if the label has no second space-separated token.
        ValueError: if that token is not an integer.
    """
    label, _, _ = text.partition(":")
    return int(label.split(" ")[1])


# One metadata record per subtitle, keyed by the subtitle's position in the SRT
# file. That position is also the name of the clip exported for it.
meta_dict = {}
for index, sub in tqdm(enumerate(subs)):
    speaker_id = get_speaker_id(sub.text)
    record = {
        "text": sub.text,
        "start": sub.start.to_time().isoformat(),
        "stop": sub.end.to_time().isoformat(),
        "speaker_id": speaker_id,
        "file_path": str(Path(output_path, f'{index}.wav').absolute()),
        # Placeholder until the speaker-verification pass sets it.
        "main_speaker": False,
    }
    meta_dict[index] = record

In [None]:
# meta_dict

In [None]:
import IPython

# Spot-check one exported clip by ear. The index is arbitrary, but this notebook
# only writes clips for indices 0 .. len(subs) - 1.
preview_index = 996
preview_clip = Path(output_path, f"{preview_index}.wav")
# Pass the path as `filename=` so a missing clip raises FileNotFoundError
# instead of the path string being treated as raw audio data.
IPython.display.Audio(filename=str(preview_clip))

In [None]:
# Hand-picked subtitle index whose clip is the reference voice; the
# verification loop compares every other clip against it.
main_speaker_key = 12
# Raises KeyError if the index is not present in meta_dict.
main_speaker = meta_dict[main_speaker_key]
main_speaker

In [None]:
# speaker_model.verify_speakers?

In [None]:
# decision = speaker_model.verify_speakers(main_speaker["file_path"], meta_dict[0]["file_path"])
# print(decision)

In [None]:
# speaker_model.verify_speakers?

In [None]:
# Similarity cut-off passed to NeMo's verify_speakers(); tuned by hand.
SPEAKER_MATCH_THRESHOLD = 0.65

# Mark each clip with whether NeMo judges it to be the same speaker as the
# reference clip (meta_dict[main_speaker_key]). The reference is marked directly.
# NOTE(review): verify_speakers() receives the reference file on every call; if
# it re-embeds that file each time, caching the reference embedding could
# roughly halve the work — TODO confirm against the installed NeMo version.
for index, meta in tqdm(meta_dict.items(), total=len(meta_dict), desc="verifying speakers"):
    if index == main_speaker_key:
        meta["main_speaker"] = True
        continue

    meta["main_speaker"] = speaker_model.verify_speakers(
        main_speaker["file_path"],
        meta["file_path"],
        threshold=SPEAKER_MATCH_THRESHOLD,
    )

In [None]:
# meta_dict

In [None]:
import json

# Persist the per-clip metadata. JSON object keys are always strings, so the
# integer subtitle indices come back as "0", "1", ... when the file is reloaded.
# The context manager guarantees the file is flushed and closed after the dump.
with speaker_response.open("w", encoding="utf-8") as response_file:
    json.dump(meta_dict, response_file)

In [None]:
# Spot-check a single metadata record by subtitle index (KeyError if absent).
meta_dict[966]

In [None]:
# Highest subtitle index recorded (iterating a dict yields its keys).
max(meta_dict)