In [66]:
import json, csv, string, os
import jiwer
import numpy as np
from scipy.stats import describe

# Directory of reference transcripts; load_oyez reads '<case>.txt' from here.
trans_path = './oyez/'
# Directory of IBM API responses; load_ibm scans it for '<case>_<segment>...' files.
ibm_path = './ibm_trans/'
# Directory of Google API responses; load_google's call sites pass it to read '<case>.json'.
google_path = './google_trans/'
# When True, the evaluation loop prints per-case progress, lengths and WER.
verbose = True

In [67]:
# Load the data-split dictionary. Only its 't' entry is used in this notebook:
# it is iterated as a list of file names of the form '<case>.<ext>'.
with open('./set_dict.json') as split_file:
    set_dict = json.loads(split_file.read())

In [68]:
def load_ibm(case, pth=ibm_path, sec=1800):
    """Load the IBM transcript of one case from its per-segment JSON files.

    Every file in ``pth`` whose name contains ``case + '_'`` is treated as one
    segment of the recording. The integer found between the first and second
    underscore of the file name is the segment index; ``index * sec`` is added
    to every word timestamp so that times are relative to the start of the
    whole recording rather than to the start of the segment.

    Args:
        case: Case identifier used in the segment file names.
        pth: Directory holding the IBM JSON responses.
        sec: Offset added per segment index, in the same unit as the API
            timestamps (presumably seconds, i.e. 30-minute chunks by default
            -- TODO confirm against how the audio was split).

    Returns:
        A tuple ``(transcripts, words)`` where ``transcripts`` is the list of
        top-alternative transcript strings, one per API result, and ``words``
        is a flat list of ``[word, start, end]`` entries with shifted times.
        Both are in segment order.

    Raises:
        ValueError: If a matching file name has no integer segment index.
    """
    def seg_index(fname):
        return int(fname.split('_')[1])

    # os.listdir returns names in arbitrary order; sort by segment index so the
    # transcript and the word timings always come back chronologically.
    segs = sorted((p for p in os.listdir(pth) if case+'_' in p), key=seg_index)

    ibm_transcript = []
    wd = []
    for seg in segs:
        add = seg_index(seg) * sec
        with open(os.path.join(pth, seg)) as f:
            api = json.load(f)

        for result in api['results']:
            best = result['alternatives'][0]  # only the top hypothesis is used
            ibm_transcript.append(best['transcript'])
            wd.extend([w[0], w[1] + add, w[2] + add] for w in best['timestamps'])

    return ibm_transcript, wd


def load_google(case, pth=google_path):
    """Load the Google transcript of one case from ``<pth>/<case>.json``.

    The JSON file is an object whose values are sequences; element 0 of each
    value is collected as a transcript chunk and element 2 as that chunk's
    list of word-timing entries (presumably ``[word, start, end]`` -- the
    callers read indices 1 and 2 as start and end times).

    Args:
        case: Case identifier (file name without the ``.json`` extension).
        pth: Directory holding the Google JSON responses.

    Returns:
        A tuple ``(chunks, words)``: the transcript chunks in file order and
        the flattened list of word-timing entries.
    """
    # Honour the ``pth`` argument; the global default is only a fallback.
    with open(os.path.join(pth, case+'.json')) as f:
        file = json.load(f)
    g = []
    wd = []
    for entry in file.values():
        g.append(entry[0])
        wd.extend(entry[2])
    return g, wd

def load_oyez(case, pth=trans_path):
    """Load the reference transcript of one case from ``<pth>/<case>.txt``.

    Each non-filler line is expected to look like
    ``<t0> <t1> <speaker> <utterance...>`` with single-space separators.
    The transcript only has times per utterance, so per-word times are
    estimated by dividing the utterance span evenly among its tokens.

    Args:
        case: Case identifier (file name without the ``.txt`` extension).
        pth: Directory holding the reference transcripts.

    Returns:
        A tuple ``(utterances, words)``. ``utterances`` holds the text of each
        line, keeping the leading space and trailing newline (the leading space
        is what separates utterances once they are concatenated). ``words`` is
        a list of ``[word, start, end]`` entries with the estimated times.

    Raises:
        ValueError: If a non-filler line has fewer than three fields or
            non-numeric times.
    """
    with open(os.path.join(pth, case+'.txt'), 'r') as f:
        lines = f.readlines()

    oyez = []
    wd = []
    for line in lines:
        if len(line) < 3:  # filler line
            continue
        fields = line.split(' ', 3)
        t0, t1, spkr = fields[:3]
        # Take everything after the third field. Splitting on the speaker token
        # would cut the utterance short whenever that token recurs in the text.
        text = ' ' + fields[3] if len(fields) == 4 else ''
        oyez.append(text)

        temp = text.split(' ')
        # Empty tokens (e.g. the one before the leading space) still take a
        # time slot, so the step is computed over all tokens.
        stp = (float(t1)-float(t0))/len(temp)
        for idx, word in enumerate(temp):
            if word != '':
                begin = float(t0)+(stp*idx)
                # End just before the next slot so adjacent words don't overlap.
                wd.append([word, begin, begin+(stp-.0001)])
    return oyez, wd

def clean_trans(lst):
    """Flatten transcript pieces into one normalised string.

    The pieces are concatenated with no separator, every character in
    ``string.punctuation`` is dropped, newlines are removed (not replaced by a
    space) and the result is lower-cased.
    """
    drop_punctuation = str.maketrans('', '', string.punctuation)
    joined = ''.join(lst)
    return joined.translate(drop_punctuation).replace('\n', '').lower()
   

In [69]:
# Segment Pulling Functions
def word_time_list(lst):
    """Split word-timing entries into parallel start and end lists.

    Each entry is indexed as ``entry[1]`` (start) and ``entry[2]`` (end);
    element 0 is ignored here.
    """
    entries = list(lst)
    starts = [entry[1] for entry in entries]
    ends = [entry[2] for entry in entries]
    return starts, ends

def find_nearest(array, value):
    """Return the index of the element of ``array`` closest to ``value``.

    On ties the first (lowest) index wins, as with ``argmin``.
    """
    distances = np.absolute(np.asarray(array) - value)
    return np.argmin(distances)

def get_seg(st, e, name, start, end, trans=None):
    """Locate the words covering the time span that starts at ``st`` and lasts ``e``.

    ``start`` and ``end`` are the per-word start/end times; the word whose
    start is nearest ``st`` and the word whose end is nearest ``st + e`` bound
    the segment.

    Returns:
        ``(first_index, last_index, name)`` when ``trans`` is None. Otherwise
        ``[name, text, (st, st + e)]`` where ``text`` is the inclusive slice of
        ``trans.split(' ')`` between the two indices, re-joined with spaces.
    """
    seg_start = float(st)
    seg_end = seg_start + float(e)
    st_id = find_nearest(start, seg_start)
    en_id = find_nearest(end, seg_end)

    if trans is None:
        return st_id, en_id, name

    tokens = trans.split(' ')
    return [name, ' '.join(tokens[st_id:en_id+1]), (seg_start, seg_end)]

## Evaluating Speech2Text API Transcriptions

In [70]:
# Score both APIs on every case in set_dict['t']: the cleaned Oyez transcript
# is the reference and each API transcript is the hypothesis.
trans_wer = {}         # case -> {'IBM': wer, 'GOOG': wer}
ires, gres = [], []    # per-case WERs for IBM / Google, in processing order
for wav in set_dict['t']:
    case = wav.split('.')[0]
    if verbose:
        print("Processing Case:", case)

    # Reference transcript (Oyez)
    oyez, wl = load_oyez(case, trans_path)
    ground_truth = clean_trans(oyez)

    # IBM hypothesis
    ibm, iwd = load_ibm(case, ibm_path)
    i = clean_trans(ibm)

    # Google hypothesis
    google, gwd = load_google(case, google_path)
    g = clean_trans(google)

    # jiwer.wer takes the reference first, then the hypothesis.
    i_err = jiwer.wer(ground_truth, i)
    g_err = jiwer.wer(ground_truth, g)
    ires.append(i_err)
    gres.append(g_err)
    trans_wer[case] = {'IBM': i_err, 'GOOG': g_err}

    if verbose:
        # Token counts are printed in the order: reference, Google, IBM.
        word_counts = [len(text.split(' ')) for text in (ground_truth, g, i)]
        print("Lengths:", *word_counts)
        print("WER:", trans_wer[case])

Processing Case: 17-1705
Lengths: 11819 10758 10936
WER: {'IBM': 0.19497915373015168, 'GOOG': 0.11647298855672847}
Processing Case: 17-530
Lengths: 9107 8753 8857
WER: {'IBM': 0.17274774774774776, 'GOOG': 0.08963963963963964}
Processing Case: 17-459
Lengths: 10888 10192 10417
WER: {'IBM': 0.16965055050263284, 'GOOG': 0.09832455720440401}
Processing Case: 17-1174
Lengths: 10885 10472 10585
WER: {'IBM': 0.1767257638626933, 'GOOG': 0.0950584685024519}
Processing Case: 17-1272
Lengths: 10756 10290 10388
WER: {'IBM': 0.16077755434262186, 'GOOG': 0.09451307095662166}
Processing Case: 17-130
Lengths: 10768 10227 10504
WER: {'IBM': 0.17370846936815826, 'GOOG': 0.0893988861148454}
Processing Case: 17-269
Lengths: 12192 11435 11575
WER: {'IBM': 0.1845917845576956, 'GOOG': 0.10533492415203681}
Processing Case: 17-1104
Lengths: 10964 10462 10623
WER: {'IBM': 0.16443236148870347, 'GOOG': 0.09093465829192837}
Processing Case: 17-571
Lengths: 10726 10312 10763
WER: {'IBM': 0.16096385542168676, 'GOOG'

In [71]:
# Summary statistics of the per-case IBM word error rates.
describe(ires)

DescribeResult(nobs=25, minmax=(0.15090464547677263, 0.2337782340862423), mean=0.17301548814918197, variance=0.00031276932294871027, skewness=1.536009818660894, kurtosis=3.60327914386448)

In [72]:
# Summary statistics of the per-case Google word error rates.
describe(gres)

DescribeResult(nobs=25, minmax=(0.07966768692610407, 0.13367556468172484), mean=0.0960380072322987, variance=0.00016770994183634756, skewness=1.1043509574977306, kurtosis=0.9542885813840596)

In [75]:
# Persist the per-case WER table (case -> {'IBM', 'GOOG'}) as JSON.
with open('./S2T_API_WER.json', 'w') as wer_file:
    wer_file.write(json.dumps(trans_wer))

# Generate BERT Sequences

In [77]:
# Directory of diarization (RTTM) files that drives segment extraction below.
# Exactly one of the two assignments should be active:
#   './rttm/' reads '<case>.rttm' and the result is saved as toBERT_oyez.json;
#   './rdsv/' reads '<case>_rdsv.rttm' and the result is saved as toBERT_rdsv.json.
rttm_path = './rttm/' # gold standard diarization
#rttm_path = './rdsv/' # RDSV predicted diarization

In [78]:
# For every case, cut each transcript (Oyez, IBM, Google) into the speaker
# turns listed in the diarization file, keeping only turns by a justice that
# last longer than 3 (in the unit of the RTTM duration field).
# NOTE(review): get_seg indexes trans.split(' ') with positions taken from the
# word-timing lists, which assumes both tokenisations line up one-to-one --
# confirm for the Oyez text, whose utterances each begin with a space.
to_BERT = {}
for wav in set_dict['t']:
    case = wav.split('.')[0]

    # Predicted (RDSV) diarization files carry an '_rdsv' suffix.
    rttm_name = case + ('_rdsv.rttm' if 'rdsv' in rttm_path else '.rttm')
    with open(rttm_path+rttm_name, newline='\n') as f:
        # Each row is expected to come back as a single field, which is split
        # on spaces below.
        case_diary = list(csv.reader(f))

    oyez, owd = load_oyez(case, trans_path)
    ibm, iwd = load_ibm(case, ibm_path)
    google, gwd = load_google(case, google_path)

    ostl, oetl = word_time_list(owd)
    istl, ietl = word_time_list(iwd)
    gstl, getl = word_time_list(gwd)

    # Concatenate each transcript once per case rather than once per turn.
    oyez_text = ''.join(oyez)
    ibm_text = ''.join(ibm)
    google_text = ''.join(google)

    oyez_toBERT = []
    ibm_toBERT = []
    goog_toBERT = []
    for item in case_diary:
        if not item:
            continue  # blank line in the diarization file
        temp = item[0].split(' ')
        # temp[3] = turn start, temp[4] = turn duration, temp[7] = speaker label
        if temp[7].endswith('scotus_justice') and float(temp[4])>3:
            oyez_toBERT.append(get_seg(temp[3], temp[4], temp[7], ostl, oetl, trans=oyez_text))
            ibm_toBERT.append(get_seg(temp[3], temp[4], temp[7], istl, ietl, trans=ibm_text))
            goog_toBERT.append(get_seg(temp[3], temp[4], temp[7], gstl, getl, trans=google_text))

    to_BERT[case]= {'OYEZ': oyez_toBERT, 'IBM': ibm_toBERT, 'GOOG': goog_toBERT}

In [79]:
# Save the extracted segments. Gold-standard diarization ('rttm') is written to
# toBERT_oyez.json; any other source is named after its directory.
diar_source = rttm_path.split('/')[1]
bert_json = "./toBERT_oyez.json" if diar_source=='rttm' else "./toBERT_"+diar_source+".json"
with open(bert_json, "w") as outfile:
    json.dump(to_BERT, outfile)