In [1]:
import os
import json
import numpy as np
import jiwer
from collections import defaultdict
import inflect
import string
# Create one shared inflect engine; digits_to_words below calls
# p.number_to_words() on digit-only tokens.
p = inflect.engine()

# Define a custom function to replace digits with words
def digits_to_words(text):
    """Strip ASCII punctuation and spell out digit-only tokens.

    Every character in ``string.punctuation`` is deleted, the text is split
    on whitespace, and each token for which ``str.isdigit()`` is true is
    replaced by ``p.number_to_words(token)``. Tokens are re-joined with
    single spaces.
    """
    punctuation_table = str.maketrans('', '', string.punctuation)
    spelled_out = []
    for token in text.translate(punctuation_table).split():
        if token.isdigit():
            token = p.number_to_words(token)
        spelled_out.append(token)
    return " ".join(spelled_out)


def get_wer(gt, pred):
    """Compute the word error rate of ``pred`` against ``gt`` after normalization.

    Normalization order:
      1. expand common English contractions (needs the apostrophes intact),
      2. ``digits_to_words``: delete ASCII punctuation, spell out digit tokens,
      3. jiwer pipeline: lowercase, collapse/strip whitespace, remove any
         remaining punctuation, split into words.

    Args:
        gt: Reference transcript as a string.
        pred: Hypothesis transcript as a string.

    Returns:
        The value returned by ``jiwer.wer`` for the normalized pair.
    """
    # Contractions have to be expanded BEFORE digits_to_words runs: that helper
    # deletes every character in string.punctuation, apostrophes included, so
    # an ExpandCommonEnglishContractions step placed after it never matches
    # ("don't" has already become "dont").
    expand_contractions = jiwer.ExpandCommonEnglishContractions()
    gt_transformed = digits_to_words(expand_contractions(gt))
    pred_transformed = digits_to_words(expand_contractions(pred))

    transforms = jiwer.Compose(
        [
            jiwer.RemoveEmptyStrings(),
            jiwer.ToLowerCase(),
            jiwer.RemoveMultipleSpaces(),
            jiwer.Strip(),
            # number_to_words may emit hyphens/commas; they are removed here.
            jiwer.RemovePunctuation(),
            jiwer.ReduceToListOfListOfWords(),
        ]
    )
    # NOTE(review): newer jiwer releases appear to name this keyword
    # `reference_transform` — confirm against the installed version.
    wer = jiwer.wer(
        gt_transformed,
        pred_transformed,
        truth_transform=transforms,
        hypothesis_transform=transforms,
    )
    return wer

In [2]:
# Gather the predicted transcript of every folder that contains a pred.json,
# keyed by the folder's base name.
res = defaultdict(dict)
for folder, _subdirs, filenames in os.walk("/scratch/zar8jw/Audio/word_time"):
    if "pred.json" not in filenames:
        continue
    with open(os.path.join(folder, "pred.json"), 'r') as handle:
        prediction = json.load(handle)

    # The first element of each "word-level timestamp" entry is the word.
    spoken_words = [entry[0] for entry in prediction["word-level timestamp"]]
    res[os.path.basename(folder)]["pred"] = " ".join(spoken_words)

In [3]:
res

defaultdict(dict,
            {'GX010336_encoded_trimmed': {'pred': 'Hey are you okay Checking for pulse for breathing Neither Call my one one get an AED Checking for pulse No pulse Applying pads Adult patient If the patient is a child press the child button Do not touch the patient Analyzing heart rhythm Do not touch the patient Analyzing heart rhythm Shock advised Do not touch the patient Everyone clear Press the flashing shock button Shock delivered Begin CPR Ten Twenty Thirty Give two breaths Ten Twenty Thirty Give two breaths Ten Twenty Thirty Give two breaths'},
             'GX010321_encoded_trimmed': {'pred': 'Assessing patient . No pulse , no breaths . Doing CPR . Applying defib Adult patient . If the patient is a child , press the child button . Do not touch the patient . Analyzing heart rhythm . Do not touch the patient . Analyzing heart rhythm . Shock advised . Do not touch the patient . Everyone clear . Press the flashing shock button . Shock delivered . Begin CPR . 10 , 2

In [4]:
# Attach the human transcript ("gt") to each recording already present in res.
file_names = [key.replace("_trimmed", "_human.json") for key in res]
path = "/scratch/zar8jw/Audio/manual_check_transcripts"
for folder, _subdirs, filenames in os.walk(path):
    for transcript_name in filenames:
        if transcript_name not in file_names:
            continue
        with open(os.path.join(folder, transcript_name), 'r') as handle:
            gt_file = json.load(handle)

        # Utterances are concatenated in file order, each followed by a space.
        gt_text = "".join(segment["Utterance"] + " " for segment in gt_file)
        recording = transcript_name.replace("_human.json", "_trimmed")
        res[recording]["gt"] = gt_text

In [5]:
res

defaultdict(dict,
            {'GX010336_encoded_trimmed': {'pred': 'Hey are you okay Checking for pulse for breathing Neither Call my one one get an AED Checking for pulse No pulse Applying pads Adult patient If the patient is a child press the child button Do not touch the patient Analyzing heart rhythm Do not touch the patient Analyzing heart rhythm Shock advised Do not touch the patient Everyone clear Press the flashing shock button Shock delivered Begin CPR Ten Twenty Thirty Give two breaths Ten Twenty Thirty Give two breaths Ten Twenty Thirty Give two breaths',
              'gt': 'Yeah. Okay. Ready? Hey, are you okay? Checking for pulse. For Breathing. Call 911 and get an AED. Checking for pulse. No pulse. Applying pads. Adult patient, If the patient is a child, press the child button. Do not touch the patient, analyzing heart rhythm. Do not touch the patient, analyzing heart rhythm. Shock advised. Do not touch the patient. Everyone clear. Press the flashing shock button. Shock 

In [6]:
# Score each recording, print its WER, then print the average over recordings.
all_wer = []
for recording, transcripts in res.items():
    score = get_wer(transcripts["gt"], transcripts["pred"])
    print(recording, score)
    all_wer.append(score)
print(f"avg: {np.average(all_wer)}")

GX010336_encoded_trimmed 0.11702127659574468
GX010321_encoded_trimmed 0.0
GX010325_encoded_trimmed 0.1111111111111111
GX010318_encoded_trimmed 0.25
GX010332_encoded_trimmed 0.6488888888888888
GX010322_encoded_trimmed 0.10144927536231885
GX010324_encoded_trimmed 0.0
GX010364_encoded_trimmed 0.0970873786407767
GX010323_encoded_trimmed 0.02857142857142857
GX010319_encoded_trimmed 0.29473684210526313
avg: 0.16488662012755317


### timestamp ###

In [7]:
import string
def rm_punc(text):
    """Return ``text`` without ASCII punctuation and without outer whitespace.

    Every character in ``string.punctuation`` is deleted; whitespace inside
    the text is left as it is.
    """
    punctuation_table = str.maketrans('', '', string.punctuation)
    return text.translate(punctuation_table).strip()

In [8]:
import re
import vertexai
from vertexai.generative_models import GenerativeModel, Part, SafetySetting, HarmBlockThreshold
from sklearn.metrics import mean_absolute_error
# One SafetySetting per harm category, each with its block threshold set to OFF.
safety_settings = [
    SafetySetting(
        category=harm_category,
        threshold=SafetySetting.HarmBlockThreshold.OFF,
    )
    for harm_category in (
        SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
    )
]

def gemini_generate(prompt, audio, max_token=8192, temperature=1, top_p=0.95):
    """Send a prompt, optionally with audio, to Gemini and return the reply text.

    Args:
        prompt: Text prompt.
        audio: Extra content sent together with the prompt, or None for a
            text-only request. NOTE(review): presumably a vertexai ``Part``
            holding audio — confirm against callers.
        max_token: Value for ``max_output_tokens``.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling parameter.

    Returns:
        ``response.text`` of the non-streaming ``generate_content`` call.
    """
    generation_config = {
        "max_output_tokens": max_token,
        "temperature": temperature,
        "top_p": top_p,
    }

    vertexai.init(project="coherent-code-440919-v3", location="us-central1")
    model = GenerativeModel(
        "gemini-1.5-flash-002",
    )

    # Identity check instead of `audio != None`, which would dispatch to the
    # object's own comparison operator. Text-only requests still pass the bare
    # prompt string, exactly as before.
    contents = prompt if audio is None else [prompt, audio]
    response = model.generate_content(
        contents,
        generation_config=generation_config,
        safety_settings=safety_settings,
        stream=False,
    )
    return response.text

def extract_json(response, pattern=None):
    """Extract and parse the last JSON snippet found in a model response.

    Args:
        response: Response text, or an already-parsed list/dict (returned
            unchanged).
        pattern: Optional regular expression selecting the JSON text. When
            omitted, a pattern matching a list of
            ``{"role": ..., "utterance": ...}`` objects is used. A looser
            pattern such as ``r'\\[.*\\]'`` can be passed for other shapes.

    Returns:
        A tuple ``(error, data)``:
          * ``(None, parsed)`` on success,
          * ``(None, [])`` when the response says no audio was provided,
          * ``("No JSON object found in the text.", None)`` when nothing matches,
          * ``(JSONDecodeError, matched_text)`` when the match is not valid JSON.
    """
    # Already-parsed input needs no extraction. The previous check compared
    # type(response) with the `json` module itself, which is never equal, so
    # parsed input fell through to re.findall and raised TypeError.
    if isinstance(response, (list, dict)):
        return None, response

    if "Please provide the audio recording." in response or "no audio" in response:
        return None, []

    if not pattern:
        pattern = r'\[\s*{\s*"role":\s*".+?",\s*"utterance":\s*".+?"\s*}(?:,\s*{\s*"role":\s*".+?",\s*"utterance":\s*".+?"\s*})*\s*\]'

    # DOTALL lets a match span several lines of the response.
    matches = re.findall(pattern, response, re.DOTALL)

    if not matches:
        print('====='*100)
        print("No JSON object found in the text.")
        print(response)
        print('====='*100)
        return "No JSON object found in the text.", None

    # When several snippets match, the last one is used.
    json_data = matches[-1]

    try:
        return None, json.loads(json_data)
    except json.JSONDecodeError as e:
        print('*****'*100)
        print(f"Error decoding JSON: {e}")
        print(response)
        print(json_data)
        print('*****'*100)
        # Hand back the raw text so the caller can ask for a repair.
        return e, json_data

def handleError(prompt, audio, next_response, model="gemini", pattern=None):
    """Parse a model response as JSON, asking the model to repair it if needed.

    The response is run through ``extract_json``. While that reports an
    error, the model is asked (text-only, temperature 0.3) to fix the JSON,
    up to 9 times. If an error still remains, one final repair request is
    sent and its reply is loaded directly with ``json.loads``.

    Args:
        prompt: Not used by the repair logic; kept for existing callers.
        audio: Not used; repair requests are always sent with ``audio=None``.
        next_response: Response text to parse.
        model: ``"gemini"`` or ``"gpt4o"``; selects the generate function.
        pattern: Optional regex forwarded to ``extract_json``. Pass the same
            pattern that was used for the first extraction; with the default
            (None) ``extract_json`` falls back to its role/utterance pattern.

    Returns:
        The parsed JSON data.

    Raises:
        ValueError: A repair is needed but ``model`` is not supported.
        json.JSONDecodeError: The final repair reply is still not valid JSON.
    """
    if model == "gpt4o":
        # NOTE(review): gpt4o_generate is not defined in the visible part of
        # this notebook — confirm it exists before using model="gpt4o".
        generate = gpt4o_generate
    elif model == "gemini":
        generate = gemini_generate
    else:
        generate = None

    latest_response = next_response
    error, next_response_dict = extract_json(latest_response, pattern=pattern)

    cnt = 1
    while error and cnt < 10:
        if generate is None:
            # Previously this surfaced as "'NoneType' object is not callable".
            raise ValueError(f"Unsupported model for JSON repair: {model!r}")

        # When nothing could be extracted, extract_json returns None as data;
        # show the model the raw response instead of the literal "None".
        broken = latest_response if next_response_dict is None else next_response_dict
        repair_prompt = f"""There is an Error decoding JSON: {error} in the following json data
        {broken}, Can you fix the error and return the correct json format. Directly return the json without explanation.
        """

        latest_response = generate(repair_prompt, audio=None, temperature=0.3)
        error, next_response_dict = extract_json(latest_response, pattern=pattern)
        cnt += 1

    if error:
        # Last attempt: the reply is parsed as-is, so a failure raises here.
        broken = latest_response if next_response_dict is None else next_response_dict
        repair_prompt = f"""There is an Error decoding JSON: {error} in the following json data
        {broken}, Can you fix the error and return the correct json format. Make sure it can be loaded using python (json.loads()). Directly return the json without explanation.
        """

        latest_response = generate(repair_prompt, audio=None, temperature=0.3)
        next_response_dict = json.loads(latest_response)
    return next_response_dict



In [10]:

# For every recording folder, ask Gemini to align predicted and ground-truth
# word timestamps, then store the mean absolute error of the aligned pairs.
path = "/scratch/zar8jw/Audio/word_time"
res = []
for root, dirs, files in os.walk(path):
    name = os.path.basename(root)

    # Score a folder only when both files are present; a folder with ground
    # truth but no pred.json used to abort the whole run with FileNotFoundError.
    if "gt_" + name + '.json' in files and "pred.json" in files:
        gt_file_path = os.path.join(root, "gt_" + name + '.json')
        pred_file_path = os.path.join(root, "pred.json")

        with open(gt_file_path, 'r') as f:
            gt_dct = json.load(f)

        # [normalized word, second element of the entry] for each ground-truth entry.
        gt = [[rm_punc(each[0].lower()), each[1]] for each in gt_dct]

        with open(pred_file_path, 'r') as f:
            pred_dct = json.load(f)

        pred = [[rm_punc(each[0].lower()), each[1]] for each in pred_dct["word-level timestamp"]]

        prompt = f"""Match two events based on the event type, do not give me the python code, \
        directly return the result (the time for prediction and groundtruth), ignore unmatched events, return results defined as follows
        [
            [word, pred time, gt time],
        ]

        prediction: {pred}
        ground-truth: {gt}
        """

        # NOTE(review): gemini_generate defaults to temperature=1, so the
        # alignment (and the MAE) can differ between runs.
        output = gemini_generate(prompt, audio=None)
        error, jsondata = extract_json(output, pattern=r'\[.*\]')
        if error:
            jsondata = handleError(prompt, audio=None, next_response=output, model="gemini")
        print(jsondata)

        # Keep only rows that carry both a predicted and a ground-truth time.
        missing = (None, 'none', "None", "")
        refined_jsondata = [
            m for m in jsondata
            if m[1] not in missing and m[2] not in missing
        ]

        # extract_json can return an empty list, and every row may be filtered
        # out; mean_absolute_error raises ValueError on empty input, so skip.
        if not refined_jsondata:
            print(name, "no matched events, skipped")
            continue

        gt_time = [each[2] for each in refined_jsondata]
        pred_time = [each[1] for each in refined_jsondata]
        mae = mean_absolute_error(gt_time, pred_time)
        res.append(mae)
        print(name, mae)

[['hey', 7.71999979019165, 7.771], ['are', 8.039999961853027, 7.952], ['you', 8.15999984741211, 8.132], ['okay', 8.34000015258789, 8.313], ['checking', 9.84000015258789, 9.759], ['for', 10.020000457763672, 9.94], ['pulse', 10.260000228881836, 10.12], ['for', 10.640000343322754, 12.831], ['breathing', 10.8, 13.192], ['call', 15.100000381469727, 15.542], ['checking', 61.52000045776367, 62.167], ['for', 62.5, 62.528], ['pulse', 63.119998931884766, 62.709], ['no', 63.959999084472656, 63.879], ['pulse', 64.55999755859375, 64.242], ['applying', 65.0199966430664, 64.605], ['pads', 65.0199966430664, 65.149], ['adult', 72.08000183105469, 72.045], ['patient', 72.63999938964844, 72.771], ['if', 74.31999969482422, 74.041], ['the', 74.44000244140625, 74.404], ['patient', 74.87999725341797, 74.586], ['is', 74.87999725341797, 74.949], ['a', 75.16000366210938, 75.13], ['child', 75.72000122070312, 75.311], ['press', 76.26000213623047, 75.856], ['the', 76.30000305175781, 76.4], ['child', 76.839996337890

In [11]:
# Average of the per-recording MAE values appended to `res` by the cell above.
np.mean(res)

22.091660558444715