In [None]:
import json
import pandas as pd

from tqdm import tqdm
import sys
sys.path.append('.')
from diarization_utils.utils import transcript_preserving_speaker_transfer
from utils.data import true_labels, aws_labels, create_diarized_text
from utils.metrics import calculate_swer

In [None]:
# Load the file paths - a dataframe with one row per utterance and the columns:
#   utt_id             - the file name (used as the utterance id)
#   file_path_original - path to the original (reference) transcription file
#   file_path_asr      - path to the ASR transcribed file
FILE_PATHS_CSV = './data/file_paths.csv'
file_paths_df = pd.read_csv(FILE_PATHS_CSV)

# Fail early with a clear message if the CSV does not have the expected schema,
# instead of a KeyError partway through the processing loop below.
missing_columns = {'utt_id', 'file_path_original', 'file_path_asr'} - set(file_paths_df.columns)
assert not missing_columns, f"{FILE_PATHS_CSV} is missing columns: {sorted(missing_columns)}"

We run a for-loop across all the files in the dataframe, where for each file we:

1. Extract words, speaker labels and timings from the reference transcript.
2. Use these timings to deduce the part of the audio file that corresponds to the reference transcript.
3. Extract the words and speaker labels from the ASR output that fall within this audio segment.
4. Skip the file if the ASR output has only one speaker.
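5. Decide whether the ASR speaker labels should be kept as-is or swapped: compute the speaker-attributed word error rate (SWER) against the reference for both labellings and keep the one with the lower SWER.
6. Transfer the speaker labels from the reference transcript onto the ASR output using transcript-preserving speaker transfer (TPST), which keeps the ASR words and assigns them "oracle" speaker labels.
7. Append the results for the file to the `res['utterances']` list.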


In [None]:
# Swap map for the two ASR speaker IDs. Labels other than '1'/'2' are left as-is
# rather than being merged into '2'.
# NOTE(review): assumes `aws_labels` emits '1' / '2' for two-speaker audio - confirm.
SPEAKER_SWAP = {'1': '2', '2': '1'}

res = {
    'utterances': []
}

for row in tqdm(file_paths_df.itertuples(index=False), total=len(file_paths_df)):
    utterance_id = row.utt_id
    original_file = row.file_path_original
    aws_file = row.file_path_asr

    # Reference words, speaker labels and per-word (start, end) times.
    ref_labels, ref_words, ref_times = true_labels(original_file)
    if not ref_times:
        # Empty reference: there is no time window to cut the ASR output to.
        continue
    start_time, end_time = ref_times[0][0], ref_times[-1][1]

    # ASR words and speaker labels restricted to the reference time window.
    hyp_words, hyp_labels1 = aws_labels(aws_file, start_time, end_time)

    # Skip segments with fewer than two ASR speakers (empty or single speaker):
    # there is no speaker assignment to resolve, and an empty hypothesis cannot be scored.
    if len(set(hyp_labels1)) < 2:
        continue

    # ASR speaker IDs are arbitrary, so also consider the labelling with the
    # two speakers swapped.
    hyp_labels2 = [SPEAKER_SWAP.get(label, label) for label in hyp_labels1]

    # Keep whichever labelling agrees better with the reference, i.e. has the
    # lower SWER (ties go to the swapped labelling).
    swer1 = calculate_swer(ref_words, ref_labels, hyp_words, hyp_labels1)
    swer2 = calculate_swer(ref_words, ref_labels, hyp_words, hyp_labels2)
    hyp_labels = hyp_labels1 if swer1 < swer2 else hyp_labels2

    ref_text, ref_spk = " ".join(ref_words), " ".join(ref_labels)
    hyp_text, hyp_spk = " ".join(hyp_words), " ".join(hyp_labels)

    # Transfer the reference speaker labels onto the ASR transcript (TPST): the
    # reference is the source and the ASR output the target, so the result is a
    # space-separated speaker string for the ASR words (paired with hyp_words below).
    hyp_spk_oracle = transcript_preserving_speaker_transfer(ref_text, ref_spk, hyp_text, hyp_spk)

    res['utterances'].append({
        'utterance_id': utterance_id,
        'ref_text': ref_text,
        'ref_spk': ref_spk,
        'ref_diarized_text': create_diarized_text(ref_words, ref_labels),
        'hyp_text': hyp_text,
        'hyp_spk': hyp_spk,
        'hyp_diarized_text': create_diarized_text(hyp_words, hyp_labels),
        'hyp_spk_oracle': hyp_spk_oracle,
        'hyp_diarized_text_oracle': create_diarized_text(hyp_words, hyp_spk_oracle.split(' ')),
    })


In [None]:
# Persist all processed utterances as a single JSON document.
output_path = './data/processed_data.json'
with open(output_path, 'w') as json_file:
    json.dump(res, json_file)
