In [1]:
# Automatically reload edited project modules (the src.* imports below) before
# each cell executes, so helper changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
# NOTE(review): `seed` is imported but not called in the visible cells. If any
# later analysis is stochastic (random / KMeans / SpectralClustering), set seeds
# here so Restart & Run All reproduces the results.
from random import seed, shuffle
import os

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import seaborn as sns
sns.set_theme(style="white")

from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

from sklearn.cluster import KMeans, SpectralClustering

from scipy.spatial.distance import directed_hausdorff, euclidean, cosine, pdist
from scipy import stats
from sklearn.preprocessing import StandardScaler


# Project-local helpers (reloaded automatically via %autoreload 2 above).
from src.downsample import downsamp_audio
import src.dimension_reducer as dr
import src.distance_metrics as dm
import src.plotting as plotting
import src.emb_manipulator as em


from IPython.display import clear_output

import warnings
# NOTE(review): this silences *every* warning for the whole kernel session,
# which can hide pandas/sklearn problems; consider filtering specific
# categories/modules instead.
warnings.filterwarnings('ignore')

# Directory where the already-generated embeddings (and eGeMAPS features) are
# stored. Set the EMBEDDING_DIR environment variable to run this notebook on
# another machine; the default is the original author's local path.
embedding_dir = os.environ.get(
    'EMBEDDING_DIR',
    '/Users/rahulbrito/Documents/projects/infantvoice/data/embeddings',
)

# Diarized embeddings; the first CSV column is used as the index.
file = '03016_diarized_pyv2.csv'
emb = pd.read_csv(os.path.join(embedding_dir, file), index_col=0)
# NOTE(review): the meaning of the second argument (1) is defined in
# src.emb_manipulator.resample_data -- confirm it and give it a named constant.
emb_down_unscaled = em.resample_data(emb, 1)

# NOTE: participant-average embeddings (`emb_a`) are computed from the
# *standardized* embeddings in the next cell. The unscaled average that used to
# be computed here was always overwritten there, so it is no longer computed.

# eGeMAPS features for each participant (88 features per the original note).
# 'start'/'end' are dropped so only feature columns and part_id remain.
gemap_file = '040122_segmented_postpartum_moms_gemaps_2sec.csv'
gemap_unscaled = pd.read_csv(os.path.join(embedding_dir, gemap_file), index_col=0)
gemap_unscaled = gemap_unscaled.drop(columns=['start', 'end'])

# Standardize each feature set with its own scaler. Previously one
# StandardScaler was re-fit on the eGeMAPS data, discarding the embedding
# statistics, so the embedding scaling could not be inverted or reapplied.
# NOTE: pd.DataFrame(ndarray) yields integer column labels 0..n-1; the original
# feature names are not carried over.
emb_scaler = StandardScaler()
emb_down = pd.DataFrame(emb_scaler.fit_transform(emb_down_unscaled.drop(columns='part_id')))
# .to_numpy() assigns by position: the scaled frame has a fresh RangeIndex,
# which may not align with the index of the unscaled frame.
emb_down['part_id'] = emb_down_unscaled.part_id.to_numpy()

gemap_scaler = StandardScaler()
gemap = pd.DataFrame(gemap_scaler.fit_transform(gemap_unscaled.drop(columns='part_id')))
gemap['part_id'] = gemap_unscaled.part_id.to_numpy()

# Backward compatibility: after this cell `scaler` used to hold the eGeMAPS fit.
scaler = gemap_scaler

OMP: Info #271: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


In [2]:
# Average embedding for each participant, computed from the standardized
# per-segment embeddings (`emb_down`) built in the first cell. This replaces
# any `emb_a` that existed before this cell ran.
emb_a = em.embedding_averager(emb_down)

In [18]:
# Pairwise distance matrix between the participant-average embeddings, computed
# by the project helper dm.cos_distance (which appears to print
# "Processing row i, col j" progress while it runs).
# NOTE(review): the saved display of this matrix shows negative entries and, in
# the rows shown, an all-NaN column (label 20). A plain cosine distance lies in
# [0, 2] and has no missing pairs, so confirm what dm.cos_distance returns and
# how it labels rows/columns.
cos_dist_high_dim = dm.cos_distance(emb_a)

Processing row 19, col 19


In [29]:
# Leave-one-out (LOO): for each participant, drop that participant's averaged
# embedding and recompute the distance matrix on the remaining participants.
# Both dicts below are keyed by the participant that was removed.
part_list = emb_a.part_id.unique().astype('int32')

# loo = {participant_removed: emb_a without that participant's row(s)}
loo = {
    p: emb_a[emb_a.part_id != p].reset_index(drop=True)
    for p in part_list
}

# Guard against a silent dtype mismatch: if part_id were e.g. strings, the
# comparison with the int32 ids would never match and nothing would be removed.
assert all(len(e_a) < len(emb_a) for e_a in loo.values()), \
    "leave-one-out did not remove any rows for some participant"

# loo_cos_dist = {participant_removed: distance matrix of remaining participants}
loo_cos_dist = {rm: dm.cos_distance(e_a) for rm, e_a in loo.items()}

Processing row 18, col 18


In [23]:
# Pad each LOO distance matrix back out to include the removed participant: it
# gets an all-NaN row and column, and rows/columns are sorted by label, so that
# (presumably) each LOO matrix lines up label-for-label with the full matrix.
# NOTE: this must run *after* the cell that builds `loo_cos_dist`. In the saved
# notebook this cell ran as In [23], before that cell's In [29].
def _pad_removed_participant(dist, removed):
    """Return a copy of `dist` with a NaN row and NaN column for `removed`.

    Rows and columns of the result are sorted by label; `dist` is not modified.
    """
    padded = dist.copy()
    padded.loc[removed] = np.nan   # adds the row if it is missing (enlargement)
    padded[removed] = np.nan       # adds the column if it is missing
    return padded.sort_index().sort_index(axis=1)


loo_cos_dist = {
    removed: _pad_removed_participant(dist, removed)
    for removed, dist in loo_cos_dist.items()
}

assert all(
    m.loc[r].isna().all() and m[r].isna().all() for r, m in loo_cos_dist.items()
), "removed participant was not masked in its LOO matrix"


In [30]:
# Inspect the LOO distance matrix for the run that removed participant 1.
# NOTE(review): the saved output below has unsorted columns and a populated
# column "1", so it does not reflect the padding cell. That cell ran as
# In [23], before the LOO dicts were rebuilt in In [29]. Restart & Run All to
# refresh it.
loo_cos_dist[1]

Unnamed: 0,2,3,4,5,6,7,8,9,10,11,...,13,14,15,16,17,18,19,20,0,1
2,,0.217101,0.391459,-0.410343,0.273236,0.210096,-0.254224,0.388528,0.856732,0.166042,...,-0.069807,0.265675,0.382556,0.22465,0.335398,0.891615,,,-0.299673,-0.47848
3,0.288701,,-0.279639,0.173309,0.439523,0.064131,-0.441179,0.439674,0.26347,0.149715,...,0.706096,0.622583,0.403319,0.794409,0.56926,0.111406,,,0.180469,0.958438
4,0.52233,-0.233488,,0.633006,0.368594,0.318787,0.156515,0.109289,-0.239514,0.604197,...,0.843642,0.527146,0.434609,0.469878,0.847789,0.319425,,,-0.242367,0.143077
5,-0.402278,0.134319,0.549329,,0.337285,0.050836,-0.5127,0.891487,1.172273,-0.122692,...,0.514062,0.663307,0.300013,0.536071,0.19807,0.491717,,,0.422731,-1.128859
6,0.301171,0.399475,0.271294,0.328617,,0.30422,0.362478,0.026757,0.442282,1.053588,...,-0.155213,0.118187,0.875787,-0.081445,0.432825,0.358287,,,-0.984138,-1.409435
7,0.294916,0.075241,0.277721,0.103158,0.360538,,-0.035475,0.947879,0.176788,0.258427,...,0.412324,0.640725,0.069206,1.198772,0.666098,-0.300411,,,-0.014556,0.731339
8,-0.163338,-0.431144,0.15191,-0.447026,0.464408,-0.00906,,0.940104,0.810127,0.24876,...,0.398232,0.891953,0.205239,0.664842,-0.014077,0.814337,,,1.478499,-0.456201
9,0.403982,0.386462,-0.018385,0.885343,-0.006233,0.908196,0.830014,,-0.301551,0.1366,...,0.36393,-0.819816,0.715307,-0.220143,-0.053988,0.577855,,,-1.530605,0.927408
10,0.949409,0.257544,-0.310313,1.204689,0.482297,0.155144,0.738191,-0.227817,,0.399346,...,0.51955,-0.204535,-0.184037,0.165872,0.37736,-0.454661,,,0.351442,1.023711
11,0.279932,0.191407,0.602725,-0.045429,1.155675,0.290167,0.239978,0.244998,0.449848,,...,0.783413,0.311869,-0.682666,0.601359,0.186146,0.396099,,,0.708604,1.680991


In [31]:
# Display the full distance matrix (no participant removed) for comparison with
# the leave-one-out matrices.
cos_dist_high_dim

Unnamed: 0,1,2,3,4,5,6,7,8,9,10,...,12,13,14,15,16,17,18,19,20,0
1,,-0.524982,-0.104664,0.177554,-0.245412,0.345779,-0.622762,-0.056831,0.96065,-0.911788,...,0.40893,-0.383922,-0.285506,-0.856318,1.186898,-0.0942,0.173075,0.727196,,-1.034519
2,-0.42278,,-0.020543,0.780773,0.174813,-0.459622,-0.599794,0.543622,-0.166353,0.737046,...,1.024189,-0.072137,-0.619513,0.699957,0.596164,-0.412874,0.196211,0.155536,,-2.106496
3,0.002475,-0.015769,,0.129161,0.322977,-0.720796,0.180236,0.045899,-0.51274,0.306998,...,-0.017029,0.310751,-0.289094,0.141856,0.264641,0.126744,0.217754,0.930589,,-0.983303
4,0.322767,0.785812,0.204929,,-0.568331,0.026798,0.394426,-0.144359,-0.752457,0.37064,...,-0.037935,0.484452,0.720661,0.582243,0.290754,0.857312,0.530967,-0.072093,,-1.501521
5,0.040703,0.330965,0.509647,-0.458609,,0.615618,0.303065,0.187571,0.013915,-0.04046,...,0.54401,0.852153,0.899662,0.464483,0.330108,0.441186,0.904002,0.195243,,-0.15866
6,0.484374,-0.378581,-0.696305,0.021176,0.532648,,0.262736,-0.161687,-0.844161,0.932833,...,-0.386742,-0.397854,0.47075,0.632493,0.160826,0.526061,0.03383,0.416662,,0.119358
7,-0.454116,-0.5351,0.221193,0.367058,0.163382,0.225731,,0.168584,0.278004,-0.143156,...,1.119436,0.335288,-0.40024,-0.040129,0.884978,-0.265743,0.348239,0.245186,,-0.755833
8,0.19267,0.659125,0.213035,-0.055888,0.171917,-0.063057,0.292688,,-0.232257,1.003003,...,0.101266,-0.059645,0.338348,0.604629,-0.129461,1.375804,0.660663,-0.601337,,0.928553
9,1.188653,-0.003341,-0.384659,-0.716441,0.004824,-0.767782,0.42648,-0.239759,,0.993329,...,0.088887,0.991945,0.320009,0.914618,0.041629,0.691177,-0.2503,0.831276,,0.018371
10,-0.818651,0.768502,0.355287,0.350083,-0.221352,0.938835,-0.179742,0.955832,0.877487,,...,-0.054729,0.069401,0.275369,-1.197528,0.683143,-0.443587,-0.303753,0.527363,,-0.110387
