# Setup

In [None]:
import json
import numpy as np
from tqdm import tqdm
import meeteval

import sys
sys.path.append('../')
from diarization_utils.utils import transcript_preserving_speaker_transfer
from utils.metrics import calculate_cpwer, calculate_swer, preprocess_str
from utils.data import extract_text_and_spk

# Data Preparation

In [None]:
# Load the test set. Every line of the file is parsed as JSON and the
# 'utterances' list of the first parsed line is kept.
with open('./data/full_test.json', 'r') as file:
    parsed_lines = list(map(json.loads, file))
data = parsed_lines[0]['utterances']
# Index the utterances by their id for direct lookup.
data2 = {utt['utterance_id']: utt for utt in data}

In [None]:
# Remove utterances whose hypothesis text contains a degenerate repetition:
# a run of 10 words holding 1 distinct word, 20 words holding exactly 2
# distinct words, or 30 words holding exactly 3 distinct words (i.e. a
# phrase of up to 3 words repeated 10 times).
def _has_repetitive_window(words, window, n_distinct):
    """Return True if some `window` consecutive words contain exactly `n_distinct` distinct words."""
    # `+ 1` so the last window (the one ending at the final word) is checked
    # as well; without it a repetition at the very end goes undetected.
    return any(
        len(set(words[start:start + window])) == n_distinct
        for start in range(len(words) - window + 1)
    )


# (window length, number of distinct words) for phrases of 1, 2 and 3 words.
_REPETITION_PATTERNS = ((10, 1), (20, 2), (30, 3))

# Ids of the utterances flagged as repetitive (each id recorded once).
consecutive_words = []
for utt in data:
    words = utt['hyp_text'].split()
    if any(
        _has_repetitive_window(words, window, n_distinct)
        for window, n_distinct in _REPETITION_PATTERNS
    ):
        consecutive_words.append(utt['utterance_id'])

# Set membership keeps the filtering pass linear in the number of utterances.
_flagged_ids = set(consecutive_words)
data = [x for x in data if x['utterance_id'] not in _flagged_ids]
# Rebuild the id -> utterance index so it matches the filtered list.
data2 = {x['utterance_id']: x for x in data}


# Baseline Results

In [None]:
# Per-utterance baseline metrics, computed with the unmodified hypothesis
# speaker labels ('hyp_spk') against the reference text and speakers.
baseline_wer = []

baseline_cpwer = []
baseline_swer = []

# Iterate over the entries directly; indexing list(data2.keys())[i] would
# rebuild the key list on every iteration.
for utt in tqdm(data2.values(), total=len(data2)):
    hyp_text = utt['hyp_text']
    hyp_spk = utt['hyp_spk']
    ref_text = utt['ref_text']
    ref_spk = utt['ref_spk']

    baseline_cpwer.append(calculate_cpwer(hyp_text, hyp_spk, ref_text, ref_spk))
    baseline_swer.append(calculate_swer(hyp_text, hyp_spk, ref_text, ref_spk))

    # Speaker-agnostic WER between reference and hypothesis text; the
    # returned error_rate is multiplied by 100.
    wer = meeteval.wer.wer.siso.siso_word_error_rate(
        reference=ref_text,
        hypothesis=hyp_text,
    ).error_rate * 100
    baseline_wer.append(wer)


In [None]:
# Report mean and standard deviation of every baseline metric. The values are
# printed explicitly: written as bare expressions, only the last one would be
# echoed by a notebook cell (and none when run as a plain script).
print('baseline WER   (mean, std):', np.mean(baseline_wer), np.std(baseline_wer))
print('baseline cpWER (mean, std):', np.mean(baseline_cpwer), np.std(baseline_cpwer))
print('baseline SWER  (mean, std):', np.mean(baseline_swer), np.std(baseline_swer))


# Fine-tuned Results

In [None]:
# Load the fine-tuned model outputs and attach each completion to the
# utterance it belongs to.
filepath = './results/model_predictions.json'

with open(filepath, 'r') as file:
    finetuned = json.load(file)

# Prediction keys in lexicographic order.
# NOTE(review): completions are later paired with 'prompts_unprocessed' by
# position. If the segment suffix after '_seg' is a non-zero-padded number,
# lexicographic order puts e.g. 'seg10' before 'seg2' -- confirm key format.
keys = sorted(finetuned)
for seg_key in keys:
    # The utterance id is everything before the first '_seg' in the key.
    utt_id = seg_key.split('_seg')[0]
    completions = data2[utt_id]['completions_llm']
    # Keep only the text that follows the '### Answer' marker.
    completions.append(finetuned[seg_key].split('### Answer\n\n')[1].strip())


In [None]:
# Collect the ids of utterances that have fewer LLM completions than
# prompts, i.e. whose model output is incomplete.
unfinalized_outputs = [
    utt_id
    for utt_id, entry in data2.items()
    if len(entry['completions_llm']) < len(entry['prompts_unprocessed'])
]


In [None]:
# data3 is the finalized data: every entry of data2 whose id was not
# recorded in unfinalized_outputs.
data3 = {}
for utt_id, entry in data2.items():
    if utt_id in unfinalized_outputs:
        continue
    data3[utt_id] = entry



In [None]:
# Per-utterance metrics after replacing the hypothesis speaker labels with
# the labels recovered from the fine-tuned model completions.
results_cpwer = []
results_swer = []

# Iterate over the entries directly; indexing list(data3.keys())[i] would
# rebuild the key list on every iteration.
for entry in tqdm(data3.values(), total=len(data3)):
    speakers_ref = entry['hyp_spk_oracle']
    speakers_input = entry['hyp_spk']
    words_input = entry['hyp_text']
    words_trans = entry['ref_text']
    speakers_trans = entry['ref_spk']
    prompts = entry['prompts_unprocessed']

    # Speaker-label strings of the individual segments, in segment order.
    segment_speakers = []
    for j, completion in enumerate(entry['completions_llm']):
        # Named prompt_text/completion_text to avoid shadowing the builtin `input`.
        prompt_text = preprocess_str(prompts[j])
        completion_text = preprocess_str(completion)

        # extract text and speaker
        words_in, speakers_in = extract_text_and_spk(prompt_text)
        words_out, speakers_out = extract_text_and_spk(completion_text)

        # Transfer the completion's speakers onto the prompt's words; the
        # result must carry one label per prompt speaker token.
        speakers_out2 = transcript_preserving_speaker_transfer(
            words_out, speakers_out, words_in, speakers_in
        )
        assert len(speakers_out2.split()) == len(speakers_in.split())

        segment_speakers.append(speakers_out2)

    # Concatenate the per-segment labels into one space-separated string.
    speakers_pred = " ".join(segment_speakers)
    # Sanity checks: predicted, input and oracle labels all align with the words.
    assert len(speakers_pred.split()) == len(speakers_input.split())
    assert len(words_input.split()) == len(speakers_input.split())
    assert len(words_input.split()) == len(speakers_ref.split())

    results_cpwer.append(
        calculate_cpwer(words_input, speakers_pred, words_trans, speakers_trans)
    )
    results_swer.append(
        calculate_swer(words_input, speakers_pred, words_trans, speakers_trans)
    )



In [None]:
# Report fine-tuned vs. baseline mean and standard deviation. The values are
# printed explicitly: written as bare expressions, only the last one would be
# echoed by a notebook cell (and none when run as a plain script).
print('fine-tuned cpWER (mean, std):', np.mean(results_cpwer), np.std(results_cpwer))
print('baseline   cpWER (mean, std):', np.mean(baseline_cpwer), np.std(baseline_cpwer))

print('fine-tuned SWER  (mean, std):', np.mean(results_swer), np.std(results_swer))
print('baseline   SWER  (mean, std):', np.mean(baseline_swer), np.std(baseline_swer))
