In [3]:
import os
import numpy as np
import pickle
import json
import matplotlib.pyplot as plt
import pandas as pd
from itertools import chain 
from itertools import groupby
from glob import glob
from tqdm import tqdm_notebook as tqdm

#### Setup 

In [4]:
# Root directory of the stored contribution arrays; a layer name is joined onto
# it further down to form the directory that is actually scanned.
contribution_path = 'Contribution/'
# Layer identifiers: 'conv' followed by 'rnn_0' .. 'rnn_4'.
layers = ['conv'] + ['rnn_{}'.format(depth) for depth in range(5)]

In [5]:
# Per-file metadata, keyed by file id. Later cells read the 'end_times',
# 'phones' and 'accent' entries of each record.
# JSON text is UTF-8, so decode it explicitly instead of relying on the
# platform's locale default.
with open('my_data/final-file-info.json', 'r', encoding='utf-8') as j:
    file_meta = json.load(j)

In [6]:
# Probe set: a header-less CSV whose three columns are named here as
# file id, accent and duration.
df_trans = pd.read_csv(
    'my_data/probes/test_1750.csv',
    header=None,
    names=['file', 'accent', 'duration'],
)
print(df_trans.head(20))
# File ids as a plain Python list, in CSV row order.
files_list = df_trans['file'].tolist()

                        file   accent  duration
0      common_voice_en_22042  england     2.880
1    common_voice_en_1177263  england     4.464
2   common_voice_en_17287554  england     1.584
3     common_voice_en_118123  england     4.344
4     common_voice_en_214810  england     2.784
5     common_voice_en_205613  england     2.784
6   common_voice_en_18130438  england     2.184
7     common_voice_en_216610  england     7.536
8     common_voice_en_485543  england     3.024
9     common_voice_en_606891  england     5.184
10     common_voice_en_54962  england     2.832
11  common_voice_en_17456660  england     5.760
12    common_voice_en_553961  england     3.456
13    common_voice_en_247127  england     3.096
14    common_voice_en_311828  england     3.144
15    common_voice_en_125156  england     5.232
16    common_voice_en_517178  england     4.344
17  common_voice_en_17254685  england     3.696
18      common_voice_en_3403  england     6.816
19     common_voice_en_83350  england   

In [7]:
# Sanity check: number of file ids read from the probe CSV.
print(len(files_list))

1750


##### Map time axis of representation to the input frames as per the convolutional layers used

In [8]:
def get_input_frame(current_frame):
    """Map a representation time step to a frame index on the input axis.

    Evaluates ``(current_frame - 1) * 2 + 11 - 2 * 5``, which for integers
    equals ``2 * current_frame - 1``.

    NOTE(review): the constants look like a stride of 2, a kernel width of 11
    and a padding of 5 taken from the convolutional front-end -- confirm
    against the model definition.
    """
    step, span, pad = 2, 11, 5
    return (current_frame - 1) * step + span - 2 * pad


In [9]:
# Phone labels counted as vowels when a representation's aligned phone is
# classified (membership is tested in get_rep_labels).
# NOTE(review): these look like TIMIT-style phone codes -- confirm.
vowel_phones = [
    'iy', 'ih', 'ix', 'ey', 'eh', 'er', 'ae',
    'aa', 'ao', 'ay', 'aw', 'ah', 'ax', 'axr',
    'ow', 'oy', 'uh', 'uw', 'ux',
]

##### Generate frame-level alignments and other related data

In [10]:
def get_frame_allignment(file, input_size, meta=None):
    """Group the first ``input_size`` input frames of ``file`` into phone runs.

    Frame ``i`` is represented by the midpoint of its analysis window,
    ``i * 0.01 + 0.02 / 2``, and is assigned to the first phone whose end
    time is strictly greater than that midpoint. Frames whose midpoint is
    not before any end time are left out of the result.

    Parameters
    ----------
    file
        Key into the metadata mapping.
    input_size : int
        Number of frames to consider, starting at frame 0.
    meta : mapping, optional
        ``{file: {'end_times': [...], 'phones': [...]}}``. Defaults to the
        notebook-level ``file_meta``.

    Returns
    -------
    allign_grouped : list of int
        One label per run of consecutive frames sharing a phone: the phone's
        position in the file's phone list, or ``-1`` / ``-2`` when that phone
        is a ``'pause'`` in first / last position.
    allign_indices : list of int
        Run boundaries; run ``k`` spans frames
        ``allign_indices[k]:allign_indices[k + 1]``. Always starts with 0.
    """
    if meta is None:
        meta = file_meta
    times = meta[file]['end_times']
    phones = meta[file]['phones']
    final_phone = len(times) - 1

    spec_stride = 0.01
    window_size = 0.02

    alligned = []
    for frame_idx in range(input_size):
        window_mid = frame_idx * spec_stride + (window_size / 2)
        phone_idx = next((j for j, end in enumerate(times) if window_mid < end), None)
        if phone_idx is None:
            # Midpoints only grow with the frame index, so once one frame lies
            # past every end time all later frames do too.
            break
        # The end-pause check wins when a single phone is both first and last.
        if phone_idx == final_phone and phones[phone_idx] == 'pause':
            phone_idx = -2  # marker for end pause
        elif phone_idx == 0 and phones[phone_idx] == 'pause':
            phone_idx = -1  # marker for start pause
        alligned.append(phone_idx)

    # Collapse consecutive identical labels into runs in a single pass and
    # record the cumulative frame count at each run boundary.
    allign_grouped = []
    allign_indices = [0]
    for label, run in groupby(alligned):
        allign_grouped.append(label)
        allign_indices.append(allign_indices[-1] + sum(1 for _ in run))

    return allign_grouped, allign_indices
    

In [11]:
# Smoke test on the first probe file: print the run labels and the frame
# boundaries of each run, limited to the first 1000 frames.
labels, indices = get_frame_allignment(files_list[0],1000)
print(labels)
print(indices)

[-1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, -2]
[0, 86, 94, 100, 111, 112, 119, 125, 130, 134, 138, 143, 148, 157, 170, 180, 189, 192, 200, 209, 219, 228, 287]


##### Generate alignment of representations with phones

In [12]:
def get_rep_labels(file, idx, meta=None):
    """Return the phone aligned with representation step ``idx`` of ``file``.

    ``idx`` is mapped to an input frame with ``get_input_frame``; the midpoint
    of that frame's window (``frame * 0.01 + 0.02 / 2``) is then assigned to
    the first phone whose end time is strictly greater than it.

    Parameters
    ----------
    file
        Key into the metadata mapping.
    idx : int
        Time step on the representation axis.
    meta : mapping, optional
        ``{file: {'end_times': [...], 'phones': [...]}}``. Defaults to the
        notebook-level ``file_meta``.

    Returns
    -------
    (alligned_phone, if_vowel)
        ``alligned_phone`` is the phone's position in the file's phone list,
        or ``-1`` / ``-2`` when that phone is a ``'pause'`` in first / last
        position. ``if_vowel`` is True when the phone label is listed in
        ``vowel_phones``; it is always False for the pause markers.

    Raises
    ------
    AssertionError
        If the window midpoint is not before any end time. Raised explicitly
        so the check survives ``python -O``.
    """
    if meta is None:
        meta = file_meta
    spec_stride = 0.01
    window_size = 0.02
    times = meta[file]['end_times']
    phones = meta[file]['phones']

    frame_idx = get_input_frame(idx)
    window_mid = frame_idx * spec_stride + (window_size / 2)

    alligned_phone = next((j for j, end in enumerate(times) if window_mid < end), None)
    if alligned_phone is None:
        raise AssertionError('found na allignments')

    # The end-pause check wins when a single phone is both first and last.
    if alligned_phone == len(times) - 1 and phones[alligned_phone] == 'pause':
        return -2, False  # marker for end pause
    if alligned_phone == 0 and phones[alligned_phone] == 'pause':
        return -1, False  # marker for start pause
    return alligned_phone, phones[alligned_phone] in vowel_phones
    

### Phone focus calculations and Neighbour Analysis

In [13]:
# Accent names used as keys by every accumulator below.
accent_names = ['us', 'indian', 'scotland', 'england', 'australia', 'canada', 'african']

# Per accent, a pair of lists: slot 0 collects the contribution percentage on
# the aligned phone, slot 1 collects whether that phone had the largest
# contribution. vowel_accents holds the same pair restricted to vowel phones.
mixing_accents = {name: ([], []) for name in accent_names}
vowel_accents = {name: ([], []) for name in accent_names}
# Per accent, seven lists: one per neighbour-distance bucket.
neighbours = {name: tuple([] for _ in range(7)) for name in accent_names}

In [27]:
# Re-create `neighbours` with seven empty buckets per accent.
# NOTE(review): presumably run to clear earlier results before processing
# another layer -- confirm.
neighbours = {
    accent: tuple([] for _ in range(7))
    for accent in ('us', 'indian', 'scotland', 'england', 'australia', 'canada', 'african')
}

In [36]:
# Layer whose stored contribution arrays the analysis loops read.
target_layer = 'rnn_4'
# The arrays for a layer are looked up under <contribution_path>/<target_layer>.
target_path = os.path.join(contribution_path,target_layer)
print(target_path)


Contribution/rnn_4


In [464]:
# Phone focus for `target_layer`: for every stored contribution array, record
# the percentage of contribution that falls on the phone the representation is
# aligned with, and whether that phone receives the largest contribution.
n_failed = 0
for file in files_list:
    labels, indices = get_frame_allignment(file, 1500)
    # The glob is a prefix match, so an id ending in '..._123' would also pick
    # up arrays stored for '..._1234'. Drop candidates whose name continues the
    # file id with another digit.
    desired_files = [
        f for f in glob(target_path + '/{}*.npy'.format(file))
        if not os.path.basename(f)[len(file):len(file) + 1].isdigit()
    ]
    for f in desired_files:
        try:
            # The representation index is the third-from-last '_' field of the
            # path for rnn layers and the second-from-last field otherwise.
            name_fields = f.split('_')
            index = name_fields[-3] if 'rnn' in target_layer else name_fields[-2]

            out, if_vowel = get_rep_labels(file, int(index))
            if out < 0:
                continue  # aligned with the leading or trailing pause
            out_id = labels.index(out)

            arr = np.load(f)
            if np.sum(arr) == 0:
                continue  # all-zero contribution array carries no information
            contr = np.sum(arr[indices[out_id]:indices[out_id + 1]]) * 100
            if np.isnan(contr):
                continue

            # Total contribution of every phone run, in run order.
            all_contr = [np.sum(arr[indices[i]:indices[i + 1]]) for i in range(len(labels))]
            max_idx = np.argmax(np.asarray(all_contr))
            cond = (max_idx == out_id)

            accent = file_meta[file]['accent']
            mixing_accents[accent][0].append(contr)
            mixing_accents[accent][1].append(cond)
            if if_vowel:
                vowel_accents[accent][0].append(contr)
                vowel_accents[accent][1].append(cond)
        except Exception:
            # Best effort: skip arrays that cannot be aligned or loaded, but
            # keep a count so the number of failures stays visible. Unlike a
            # bare `except:`, this lets KeyboardInterrupt stop the cell.
            n_failed += 1
            continue
print('skipped {} files because of errors'.format(n_failed))
#         break
#     break
    
    
    

In [37]:
class Contr:
    """Wrapper around a sequence that reads out-of-range positions as 0."""

    def __init__(self, arr):
        self.arr = arr
        # Length is captured once, at construction time.
        self.len = len(arr)

    def fetch(self, idx):
        """Return ``arr[idx]`` when ``0 <= idx < len``, otherwise 0.

        Negative positions count as out of range; they do not wrap around.
        """
        if 0 <= idx < self.len:
            return self.arr[idx]
        return 0

    def fetch_range(self, r):
        """Sum ``fetch(i)`` for every ``i`` in the inclusive range ``r``.

        ``r`` is a ``(start, stop)`` pair. The running total starts at 0.0,
        so an empty or fully out-of-range span yields 0.0.
        """
        first, last = r
        total = 0.0
        position = first
        while position <= last:
            total += self.fetch(position)
            position += 1
        return total

In [38]:
# Neighbour analysis for `target_layer`: for every stored contribution array,
# record how much contribution falls on the phones at increasing distance from
# the phone the representation is aligned with.
n_failed = 0
for file in files_list:
    labels, indices = get_frame_allignment(file, 1500)
    # The glob is a prefix match, so an id ending in '..._123' would also pick
    # up arrays stored for '..._1234'. Drop candidates whose name continues the
    # file id with another digit.
    desired_files = [
        f for f in glob(target_path + '/{}*.npy'.format(file))
        if not os.path.basename(f)[len(file):len(file) + 1].isdigit()
    ]
    for f in desired_files:
        try:
            # The representation index is the third-from-last '_' field of the
            # path for rnn layers and the second-from-last field otherwise.
            name_fields = f.split('_')
            index = name_fields[-3] if 'rnn' in target_layer else name_fields[-2]

            out = get_rep_labels(file, int(index))[0]
            if out < 0:
                continue  # aligned with the leading or trailing pause
            out_id = labels.index(out)

            arr = np.load(f)
            if np.sum(arr) == 0:
                continue  # all-zero contribution array carries no information
            contr = np.sum(arr[indices[out_id]:indices[out_id + 1]]) * 100
            if np.isnan(contr):
                continue

            # Total contribution of every phone run, in run order; Contr reads
            # runs beyond either end of the utterance as 0.
            all_contr = [np.sum(arr[indices[i]:indices[i + 1]]) for i in range(len(labels))]
            neigh_contr = Contr(all_contr)

            accent = file_meta[file]['accent']
            buckets = neighbours[accent]
            # Buckets 0-2: the phone pair at distance 1, 2 and 3. These use
            # fetch() rather than fetch_range() on purpose, so the stored
            # values keep the array's own float type.
            for k, dist in enumerate((1, 2, 3)):
                buckets[k].append(neigh_contr.fetch(out_id - dist) + neigh_contr.fetch(out_id + dist))
            # Buckets 3-6: distance bands 4-5, 6-8, 9-11 and 12-100, both sides.
            for k, (near, far) in enumerate(((4, 5), (6, 8), (9, 11), (12, 100)), start=3):
                buckets[k].append(
                    neigh_contr.fetch_range((out_id - far, out_id - near))
                    + neigh_contr.fetch_range((out_id + near, out_id + far))
                )
        except Exception:
            # Best effort: skip arrays that cannot be aligned or loaded, but
            # keep a count so the number of failures stays visible. Unlike a
            # bare `except:`, this lets KeyboardInterrupt stop the cell.
            n_failed += 1
            continue
print('skipped {} files because of errors'.format(n_failed))
#         break
#     break
    
    
    

##### Phone focus and binary phone focus (calculated one layer at a time), shown here for SPEC

In [465]:
# Per accent: mean and standard deviation of the phone-focus percentage, then
# the percentage of representations whose aligned phone had the largest
# contribution (binary phone focus).
for a, (focus_values, focus_hits) in mixing_accents.items():
    print(a)
    if not focus_hits:
        # Nothing collected for this accent; the hit rate would divide by zero.
        continue
    contr_arr = np.asarray(focus_values)
    print(contr_arr.mean(), contr_arr.std(), 100.0 * sum(focus_hits) / len(focus_hits))
# for a in vowel_accents.keys():
#     try:
#         print(a)
# #         
#         contr_arr = np.asarray(vowel_accents[a][0])
# #         print(mixing_accents[a][1])
# #         print(mixing_accents[a][0])
# #         break
#         print(contr_arr.mean(),contr_arr.std(),100.0*sum(vowel_accents[a][1])/(len(vowel_accents[a][1])))
#     except:
#         continue

us
62.54442978064803 25.070968167836934 73.87542644110717
indian
63.9602528876121 25.894985414401255 74.68457723934468
scotland
62.3642934689578 25.634266641694573 73.61333169351863
england
62.21513071446245 25.27033001413601 73.41286889850953
australia
62.430332227020145 25.491467533674932 73.52960547981294
canada
61.2306828972328 25.180897913225586 72.55708840692144
african
64.16524100009383 25.735879414367638 74.89567318251702


##### Neighbour Analysis values for layer RNN_4

In [39]:
# Per accent: mean contribution of each of the seven neighbour-distance
# buckets, scaled by 100. Iterates `neighbours` itself, which is the mapping
# being read. An empty bucket yields nan.
for a in neighbours:
    print(a)
    arr = [100 * np.asarray(bucket).mean() for bucket in neighbours[a]]
    print(arr)


us
[22.638480365276337, 12.542382574099609, 7.952615837516744, 9.385792697883067, 7.614086639646633, 4.851851712807524, 14.444850730930911]
indian
[21.280723810195923, 12.305844389243482, 8.195414496937591, 10.123701281435645, 8.182893904021277, 4.914469927645419, 13.960731343402333]
scotland
[21.29308432340622, 12.132624955503704, 7.9527269285730995, 9.562700520589063, 7.945317385160462, 4.971957288337058, 15.219619409445926]
england
[21.94535266152243, 12.496599320776916, 8.170664550136404, 9.697629917171719, 7.984599989060589, 5.00539373145459, 14.395554327181582]
australia
[22.04861491918564, 12.597264878502923, 8.086613376326612, 9.744694454159763, 8.074721421988148, 4.995766007349288, 13.901542000918287]
canada
[22.8696346282959, 12.696123461599603, 8.067566000259314, 9.395388820129348, 7.578003661423454, 4.830882189746786, 13.786903177871535]
african
[21.585925028024786, 11.85953984030524, 7.679363588716809, 9.199444571711155, 7.494925847285976, 4.779391716000584, 15.44301008758

In [None]:
# conv 
# us
# [32.7407032251358, 4.129967677019115, 0.5069611499109884, 0.07540823335304898, 0.0020524582412921005, 0.0, 0.0]
# indian
# [30.90875744819641, 4.187893504497704, 0.6535981730305066, 0.238160239311137, 0.046958024210296455, 0.003738099867418384, 0.0002974208116719145]
# scotland
# [32.13910460472107, 4.590714934856013, 0.7024039413508787, 0.17588622736358267, 0.02393076658670436, 0.0026134167411263247, 0.0]
# england
# [32.66556247854946, 4.368573959888614, 0.6079838453871551, 0.11953031163447808, 0.01943777885757124, 0.002580319954526009, 0.0004979741529129851]
# australia
# [32.34032988548279, 4.472068407601206, 0.632637141106141, 0.11507221019796406, 0.009235619216210154, 1.952351730773e-05, 0.0]
# canada
# [33.539554476737976, 4.574416147792803, 0.5713643466006877, 0.07515107721887157, 0.007279535102926467, 0.001223377025871137, 0.0]
# african
# [31.233730379140706, 4.009490048996913, 0.5022208455794831, 0.07889555733309296, 0.00871781443660488, 0.0008204423270378301, 6.012520126919877e-05]

# rnn_0
# us
# [27.49955654144287, 5.13346373603394, 2.0309174663958474, 2.6767877794440995, 3.0200369344228926, 2.362625394259192, 7.852535240260654]
# indian
# [26.19927227497101, 5.090281970997703, 2.077755862808442, 2.7060146116646786, 2.9382837094459093, 2.2580741789313943, 7.560508918494634]
# scotland
# [26.919716596603394, 5.392323765776727, 2.1341611514668424, 2.674973549457239, 2.9833333826550628, 2.3151264046490434, 8.07114260450446]
# england
# [27.37573602958441, 5.295955654012483, 2.1334719217250915, 2.7150068105729908, 3.0929771560113015, 2.4152857833745416, 7.717478458531311]
# australia
# [27.247267961502075, 5.34691052375551, 2.105283901850338, 2.673661797066982, 3.067990261786627, 2.3905664178446235, 7.4642531857985]
# canada
# [28.169283270835876, 5.448183687807504, 2.077020231947467, 2.636450567206863, 3.0098093966374746, 2.369359405710461, 7.507328188736076]
# african
# [26.15024838869291, 4.877831960924951, 1.9293778762854086, 2.5777538785687097, 2.9440899393857025, 2.303586629368907, 8.30029107582431]

# rnn_1
# us
# [23.962603509426117, 8.507628458622953, 4.465992478354428, 5.71455936704617, 6.195458850237518, 4.829749567876604, 16.16509212889644]
# indian
# [23.195047676563263, 8.404100489357473, 4.470128897364782, 5.671139707643684, 6.008702689557492, 4.620739478180099, 15.721930947546001]
# scotland
# [23.04048240184784, 8.462198564993297, 4.534059670150146, 5.696610933284059, 6.151703288604992, 4.786315666033271, 16.790108252422357]
# england
# [23.631045862128857, 8.5951688668769, 4.632277741647281, 5.784687362959637, 6.338151672549959, 4.936726220444522, 16.040685559373344]
# australia
# [23.752808570861816, 8.661556305105101, 4.572080661245361, 5.7502741717005055, 6.307751578932401, 4.898818396068181, 15.481466421210937]
# canada
# [24.332773685455322, 8.688939460844866, 4.547470026116145, 5.6507561064439, 6.15434375223156, 4.845393838018302, 15.452406634770167]
# african
# [22.747576780279, 8.000965653438755, 4.25131166990766, 5.47551236200026, 6.0265592081826975, 4.722378272694236, 17.231608823099382]

# rnn_2
# us
# [24.472643435001373, 11.416035009853484, 6.433550829242235, 7.172085362697694, 6.3121468964634735, 4.558863821094617, 14.998742030734071]
# indian
# [23.51335436105728, 11.387106717103787, 6.649511262522023, 7.51940392086301, 6.3137742982306, 4.358723318068517, 14.450793554558317]
# scotland
# [23.225121200084686, 11.172301102451147, 6.492206765380377, 7.298659402172287, 6.401033581395596, 4.550037674420537, 15.69172966370581]
# england
# [23.913722893040283, 11.459426143523336, 6.633741471145262, 7.348491074447801, 6.48838323695089, 4.668961221382211, 14.981185589184959]
# australia
# [24.096539616584778, 11.600811539104825, 6.5722590834003265, 7.360531804260481, 6.491923510091799, 4.633371833046721, 14.393180537147854]
# canada
# [24.683956801891327, 11.572113914395796, 6.535514317857894, 7.151745245651895, 6.264582652847311, 4.568181247712387, 14.349837376059055]
# african
# [23.278117922402455, 10.784819814874142, 6.167853194400013, 6.9110016395469005, 6.130243824584725, 4.448534439984776, 16.056985357788058]

# rnn_3
# us
# [23.23090434074402, 12.174146764277422, 7.504013566115788, 8.723696850925187, 7.227962272073496, 4.761639574511886, 14.509863650139298]
# indian
# [21.902601420879364, 12.02527095338879, 7.787259836802122, 9.424352349689228, 7.673399363880054, 4.747935228341197, 13.992496343298566]
# scotland
# [21.891167759895325, 11.795758316695142, 7.507693259656952, 8.898633270182273, 7.511580442325496, 4.847434087985344, 15.216503226337569]
# england
# [22.57687249452284, 12.14707995605362, 7.730428943421168, 9.010231605380561, 7.541451811973519, 4.898843390369724, 14.478454061859688]
# australia
# [22.68569767475128, 12.253803979128968, 7.636378218902687, 9.0469665974413, 7.621676064515722, 4.872514438164382, 13.939816792289179]
# canada
# [23.474161326885223, 12.297419011472826, 7.59384786153115, 8.712447004216948, 7.161918183838477, 4.73416186686994, 13.869042442307574]
# african
# [22.116470338617592, 11.47845570845182, 7.220993539734151, 8.530101589118326, 7.0719255800387915, 4.653464676917603, 15.48961048674086]

# rnn_4
# us
# [22.638480365276337, 12.542382574099609, 7.952615837516744, 9.385792697883067, 7.614086639646633, 4.851851712807524, 14.444850730930911]
# indian
# [21.280723810195923, 12.305844389243482, 8.195414496937591, 10.123701281435645, 8.182893904021277, 4.914469927645419, 13.960731343402333]
# scotland
# [21.29308432340622, 12.132624955503704, 7.9527269285730995, 9.562700520589063, 7.945317385160462, 4.971957288337058, 15.219619409445926]
# england
# [21.94535266152243, 12.496599320776916, 8.170664550136404, 9.697629917171719, 7.984599989060589, 5.00539373145459, 14.395554327181582]
# australia
# [22.04861491918564, 12.597264878502923, 8.086613376326612, 9.744694454159763, 8.074721421988148, 4.995766007349288, 13.901542000918287]
# canada
# [22.8696346282959, 12.696123461599603, 8.067566000259314, 9.395388820129348, 7.578003661423454, 4.830882189746786, 13.786903177871535]
# african
# [21.585925028024786, 11.85953984030524, 7.679363588716809, 9.199444571711155, 7.494925847285976, 4.779391716000584, 15.443010087585678]



In [None]:
failed for file: common_voice_en_117181 1 22
failed for file: common_voice_en_16666058 1 25