<h4> Code partially from Choi et al. (2024): https://github.com/juice500ml/phonetic_semantic_probing/tree/24b85b648c6512d9fe4df4139c546482080fef4c

Copyright (c) 2022, Puyuan Peng All rights reserved.</h4>

In [None]:
import random
import pickle
from textgrids import TextGrid
from functools import partial
import torch
from collections import defaultdict
import json
from pathlib import Path
from itertools import product
import matplotlib.pyplot as plt

from utils_Phonetic_Semantic_Choi import *

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# When True, the analysis cell loads the cached per-seed distance pickles
# (only if both model files for that slicing/seed exist) instead of recomputing them.
LOAD_DISTS = False

In [2]:
# One wordmap / distance file is produced per seed; results are later aggregated across seeds.
SEEDS = [42, 12, 25, 31, 69]
# Number of dataframe rows drawn (without replacement) for each seed.
N_SAMPLES = 10000

In [3]:
def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Applies ``seed`` to Python's ``random`` module, NumPy's global generator
    and PyTorch (``manual_seed`` plus ``cuda.manual_seed_all``), in that order.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

In [4]:
# Pickled dataframes with the audio-slicing embeddings (one file per model).
w2v2_path_audio = 'embeddings/wav2vec2/LibriSpeech_audioslicing/w2v2_LibriSpeech_audioslicing.pkl' 

fast_vgs_plus_path_audio = 'embeddings/fast_vgs/LibriSpeech_audioslicing/fast_vgs_plus_librispeech_audioslicing.pkl'

# Pickled dataframes with the feature-slicing embeddings (one file per model).
w2v2_path_feat = 'embeddings/wav2vec2/LibriSpeech_featureslicing/w2v2_LibriSpeech_featureslicing.pkl'

fast_vgs_plus_path_feat = 'embeddings/fast_vgs/LibriSpeech_featureslicing/fast_vgs_plus_LibriSpeech_featureslicing.pkl'

# Output prefix for wordmaps, cached distances and semantic-category similarities.
# It is concatenated directly with file names, so it must end with a path separator.
data_path = 'exp2/' # TODO: set path

# Output prefix for the saved PDF figures (also concatenated directly with file names).
figure_path = 'exp2/figures/'

In [5]:
# load dataframes
def _read_pickle(pickle_path):
    """Return the object unpickled from ``pickle_path``.

    Unpickling can execute arbitrary code, so only point this at trusted files.
    """
    with open(pickle_path, "rb") as handle:
        return pickle.load(handle)


# audio slicing
df_w2v2_audio = _read_pickle(w2v2_path_audio)
df_fast_vgs_audio = _read_pickle(fast_vgs_plus_path_audio)

# feature slicing
df_w2v2_feat = _read_pickle(w2v2_path_feat)
df_fast_vgs_feat = _read_pickle(fast_vgs_plus_path_feat)

  w2v2 = pickle.load(f)
  fast_vgs_plus = pickle.load(f)


<h1> Create the wordmaps </h1>

In [6]:
# code from Choi et al. (2024): https://github.com/juice500ml/phonetic_semantic_probing/tree/24b85b648c6512d9fe4df4139c546482080fef4c 
#
# Build one "wordmap" per seed: a random sample of row positions together with the
# synonym and homophone maps derived from the sampled words. Each wordmap is pickled to
# `{data_path}wordmap_{seed}.pkl` and read back by the analysis cell.
# NOTE: per the recorded progress output, the homophone search takes roughly 30 minutes per seed.
for seed in SEEDS:
    # Re-seed all RNGs so the sample drawn below is reproducible for this seed.
    set_seed(seed)
    print(f"Seed: {seed}")

    # sample indices (does not matter which of the four dataframes you use here, as 
    # they all contain the same words. Words matter here, embeddings will come into play later)
    # random.sample draws without replacement and raises ValueError if N_SAMPLES exceeds the row count.
    indices = random.sample(range(len(df_w2v2_audio)), N_SAMPLES)

    # filter dataframes (positional selection of the sampled rows)
    df_filtered_w2v2 = df_w2v2_audio.iloc[indices]

    # get words (distinct word types among the sampled rows)
    words = set(df_filtered_w2v2.text.unique())

    # get phones for words; if a word occurs in several rows, the last row's phones win
    text2phones = {row.text: tuple(row.phones) for row in df_filtered_w2v2.itertuples()}

    # get synonym maps
    # get_synonym_map / get_homophone_map are not defined in this notebook (presumably they come
    # from the star import of utils_Phonetic_Semantic_Choi); threshold=-1 presumably disables
    # the helper's filtering, as the variable name suggests -- TODO confirm.
    synonym_map = get_synonym_map(words, df_filtered_w2v2, text2phones)
    not_filtered_synonym_map = get_synonym_map(words, df_filtered_w2v2, text2phones, threshold=-1)

    # get homophone maps (uses the unfiltered synonym map)
    homophone_map = get_homophone_map(words, not_filtered_synonym_map, df_filtered_w2v2, text2phones)

    # Persist the sampled positions and both maps for this seed.
    with open(f'{data_path}wordmap_{seed}.pkl', "wb") as f:
        pickle.dump({
            "indices": indices,
            "synonym_map": synonym_map,
            "homophone_map": homophone_map,
        }, f)

Seed: 42


100%|██████████| 3326/3326 [00:01<00:00, 1764.15it/s]
100%|██████████| 3326/3326 [00:00<00:00, 10477.64it/s]
Finding homophones: 100%|██████████| 3326/3326 [29:30<00:00,  1.88it/s]


Seed: 12


100%|██████████| 3352/3352 [00:00<00:00, 3765.73it/s]
100%|██████████| 3352/3352 [00:00<00:00, 10916.07it/s]
Finding homophones: 100%|██████████| 3352/3352 [30:06<00:00,  1.86it/s]


Seed: 25


100%|██████████| 3335/3335 [00:00<00:00, 3719.11it/s]
100%|██████████| 3335/3335 [00:00<00:00, 10663.12it/s]
Finding homophones: 100%|██████████| 3335/3335 [29:32<00:00,  1.88it/s]


Seed: 31


100%|██████████| 3354/3354 [00:00<00:00, 3882.21it/s]
100%|██████████| 3354/3354 [00:00<00:00, 10832.01it/s]
Finding homophones: 100%|██████████| 3354/3354 [30:07<00:00,  1.86it/s]


Seed: 69


100%|██████████| 3305/3305 [00:00<00:00, 3784.82it/s]
100%|██████████| 3305/3305 [00:00<00:00, 10722.07it/s]
Finding homophones: 100%|██████████| 3305/3305 [28:47<00:00,  1.91it/s]


In [None]:
# Semantic categories: a JSON object mapping each category name to an entry
# whose 'words' list holds the words belonging to that category.
sem_categories = json.loads(Path('exp1/semantic_categories.json').read_text())

In [None]:
# Mean pairwise cosine similarity between words that share a semantic category,
# computed per layer for both models and both slicing strategies.
# For each slicing the result is pickled to `{data_path}sem_cats_dists_{slicing}.pkl`
# as {"w2v2": [...], "fast_vgs_plus": [...]} with one mean value per layer.
for slicing in ['audioslice', 'featslice']:
    if slicing == 'audioslice':
        df_w2v2 = df_w2v2_audio
        df_fast_vgs_plus = df_fast_vgs_audio
    else:
        df_w2v2 = df_w2v2_feat
        df_fast_vgs_plus = df_fast_vgs_feat

    # ---- Layer-independent work, done once per slicing ----
    # Which word pairs are compared and which rows belong to each word does not depend
    # on the layer, so both are resolved here instead of once per layer and per pair.
    rows_w2v2 = {}           # word -> rows of df_w2v2 with that text
    rows_fast_vgs_plus = {}  # word -> rows of df_fast_vgs_plus with that text
    word_pairs = set()       # unordered pairs already seen (also across categories)
    words = []               # first words of the retained pairs, in encounter order
    category_pairs = []      # [(category, [(word_i, word_j), ...]), ...] in iteration order
    for cat, data in sem_categories.items():
        cat_pairs = []
        for word_i in data['words']:
            for word_j in data['words']:
                pair = tuple(sorted((word_i, word_j)))
                if word_i == word_j or pair in word_pairs:
                    continue
                word_pairs.add(pair)
                for word in (word_i, word_j):
                    if word not in rows_w2v2:
                        rows_w2v2[word] = df_w2v2[df_w2v2['text'] == word]
                        rows_fast_vgs_plus[word] = df_fast_vgs_plus[df_fast_vgs_plus['text'] == word]
                # Skip pairs for which either word has no utterance in the data.
                if len(rows_w2v2[word_i]) == 0 or len(rows_w2v2[word_j]) == 0:
                    continue
                if word_i not in words:
                    words.append(word_i)
                cat_pairs.append((word_i, word_j))
        category_pairs.append((cat, cat_pairs))

    # ---- Per-layer similarities ----
    cos_sims_w2v2 = []
    cos_sims_fast_vgs_plus = []
    for layer_idx in range(12):
        print(f'########### Layer {layer_idx} ###########')
        layer = f'layer_{layer_idx}'

        # Embeddings of this layer for every utterance of every word, in row order.
        embs_w2v2 = {
            word: [emb[layer] for emb in rows['w2v2_embeddings']]
            for word, rows in rows_w2v2.items()
        }
        embs_fast_vgs_plus = {
            word: [emb[layer] for emb in rows['fast_vgs_plus_embeddings']]
            for word, rows in rows_fast_vgs_plus.items()
        }

        layer_dists_w2v2 = []
        layer_dists_fast_vgs_plus = []
        for cat, cat_pairs in category_pairs:
            print(f'########### Category {cat} ###########')
            for word_i, word_j in cat_pairs:
                # All utterances of word_i against all utterances of word_j.
                layer_dists_w2v2.extend(
                    cos_sim(emb_i, emb_j)
                    for emb_i in embs_w2v2[word_i]
                    for emb_j in embs_w2v2[word_j]
                )
                layer_dists_fast_vgs_plus.extend(
                    cos_sim(emb_i, emb_j)
                    for emb_i in embs_fast_vgs_plus[word_i]
                    for emb_j in embs_fast_vgs_plus[word_j]
                )

        # One scalar per layer: the mean over every utterance pair of every word pair.
        cos_sims_w2v2.append(np.mean(layer_dists_w2v2))
        cos_sims_fast_vgs_plus.append(np.mean(layer_dists_fast_vgs_plus))

    with open(f'{data_path}sem_cats_dists_{slicing}.pkl', "wb") as f:
        pickle.dump({
            "w2v2": cos_sims_w2v2,
            "fast_vgs_plus": cos_sims_fast_vgs_plus,
        }, f)
    

<h1>Analysis</h1>

In [None]:
# code based on Choi et al. (2024): https://github.com/juice500ml/phonetic_semantic_probing/tree/24b85b648c6512d9fe4df4139c546482080fef4c 

# For every slicing strategy and seed, compute per-layer cosine similarities for each
# sampler in `samplers` and cache them as
#   {data_path}dists/w2v2_{slicing}_seed-{seed}.pkl
#   {data_path}dists/fast_vgs_plus_{slicing}_seed-{seed}.pkl
# Each file holds a dict: sampler name -> list (one entry per layer) of arrays with the
# mean similarity of every distinct (left text, right text) pair.
# With LOAD_DISTS set, a seed whose two cache files already exist is not recomputed.
# NOTE(review): set_seed is not called here; if the samplers draw randomly, the sampled
# pairs are not tied to `seed` -- confirm whether that is intended.
dists_dir = Path(data_path) / 'dists'
dists_dir.mkdir(parents=True, exist_ok=True)  # the dump below fails if the directory is missing

for slicing in ['audioslice', 'featslice']:
    if slicing == 'audioslice':
        df_w2v2 = df_w2v2_audio
        df_fast_vgs_plus = df_fast_vgs_audio
    else:
        df_w2v2 = df_w2v2_feat
        df_fast_vgs_plus = df_fast_vgs_feat

    for seed in SEEDS:
        dists_path_w2v2 = dists_dir / f'w2v2_{slicing}_seed-{seed}.pkl'
        dists_path_fast_vgs_plus = dists_dir / f'fast_vgs_plus_{slicing}_seed-{seed}.pkl'

        if LOAD_DISTS and dists_path_w2v2.exists() and dists_path_fast_vgs_plus.exists():
            with open(dists_path_w2v2, 'rb') as f:
                dists_w2v2 = pickle.load(f)
            with open(dists_path_fast_vgs_plus, 'rb') as f:
                dists_fast_vgs_plus = pickle.load(f)
            continue

        # Read the wordmap from the same location the wordmap cell writes it to.
        with open(f'{data_path}wordmap_{seed}.pkl', 'rb') as f:
            wordmap = pickle.load(f)

        # Restrict both models' dataframes to the rows sampled for this seed.
        df_w2v2_filtered = df_w2v2.iloc[wordmap['indices']]
        df_fast_vgs_plus_filtered = df_fast_vgs_plus.iloc[wordmap['indices']]

        dists_w2v2 = defaultdict(list)
        dists_fast_vgs_plus = defaultdict(list)
        for layer_idx in tqdm(range(12)):
            layer = f'layer_{layer_idx}'
            for name, sampler in samplers.items():
                # (left text, right text) -> similarities of all sampled utterance pairs
                accumulator_w2v2 = defaultdict(list)
                accumulator_fast_vgs_plus = defaultdict(list)
                # The sampler yields pairs of index labels; the same pairs are applied to both models.
                for left, right in tqdm(sampler(df_w2v2_filtered, wordmap), desc=f'{name} {layer}'):
                    w2v2_left = df_w2v2_filtered.loc[left]
                    w2v2_right = df_w2v2_filtered.loc[right]
                    accumulator_w2v2[(w2v2_left.text, w2v2_right.text)].append(
                        cos_sim(w2v2_left.w2v2_embeddings[layer], w2v2_right.w2v2_embeddings[layer])
                    )

                    fast_left = df_fast_vgs_plus_filtered.loc[left]
                    fast_right = df_fast_vgs_plus_filtered.loc[right]
                    accumulator_fast_vgs_plus[(fast_left.text, fast_right.text)].append(
                        cos_sim(fast_left.fast_vgs_plus_embeddings[layer], fast_right.fast_vgs_plus_embeddings[layer])
                    )
                # Average within each text pair first, so frequent pairs do not dominate.
                dists_w2v2[name].append(np.array([np.array(v).mean() for v in accumulator_w2v2.values()]))
                dists_fast_vgs_plus[name].append(np.array([np.array(v).mean() for v in accumulator_fast_vgs_plus.values()]))

        with open(dists_path_w2v2, 'wb') as f:
            pickle.dump(dists_w2v2, f)
        with open(dists_path_fast_vgs_plus, 'wb') as f:
            pickle.dump(dists_fast_vgs_plus, f)


<h1>Plots</h1>

<h4>Audio slicing</h4>

In [None]:
# Figure specification for the audio-slicing plot.
# "left"/"right" are file-name templates inside `{data_path}dists`; "seed-x" is replaced
# by "seed-<seed>" when loading. They must match the names written by the analysis cell,
# i.e. `<model>_audioslice_seed-<seed>.pkl`.
pairs = {
    "w2v2_fast_vgs_plus": {
        "left": "w2v2_audioslice_seed-x.pkl",
        "right": "fast_vgs_plus_audioslice_seed-x.pkl",
        "seeds": [42, 12, 25, 31, 69],
        "speakers": ("everyone", ),
        "normalizer": ("subtract", "subtract"),
        "legend": {"loc": "center", "bbox_to_anchor": (-0.2, 1.6), "ncols": 3},
        "title": ("wav2vec 2.0-base (Norm.)", "FaST-VGS+ (Norm.)"),
    }
}

# Curve key -> (matplotlib colour, marker, legend label).
keys_with_sem = {
    "random": ("C0", "", "Random"),
    "synonym": ("C1", "x", "Synonym"),
    "homophone": ("C2", "o", "Near homophone"),
    "speaker": ("C3", "|", "Same speaker"),
    "same_word": ("C4", "^", "Same word"),
    "semantic_categories": ("C5", "s", "Semantic categories"),
}

In [None]:
# code based on Choi et al. (2024): https://github.com/juice500ml/phonetic_semantic_probing/tree/24b85b648c6512d9fe4df4139c546482080fef4c 

# Audio-slicing figure: one panel per model showing, per layer, the cosine similarity of
# each sampler (mean over seeds with its confidence bound as a shaded band) plus the
# semantic-category similarity (a single curve without a band).
# With normalizer "subtract" the per-layer "random" mean is subtracted from every curve.
for pair_key, meta in pairs.items():

    with open(f'{data_path}sem_cats_dists_audioslice.pkl', "rb") as f:
        sem_cats_dists = pickle.load(f)

    fig, axes = plt.subplots(1, 2, figsize=(8, 3), sharey=True)
    for loc, ax, normalizer, title in zip(("left", "right"), axes, meta["normalizer"], meta["title"]):
        # One dict per (seed, speaker): sampler name -> per-layer mean similarity.
        seedwise_dists = []
        for seed, _speaker in product(meta["seeds"], meta["speakers"]):
            dists_path = Path(f'{data_path}dists') / meta[loc].replace("seed-x", f"seed-{seed}")
            with open(dists_path, "rb") as f:
                dists = pickle.load(f)
            # NOTE(review): the feature-slicing figure averages with np.nanmean here instead;
            # confirm which of the two is intended.
            seedwise_dists.append({
                k: [mean_confidence_interval(v)[0] for v in vs]
                for k, vs in dists.items()
            })

        # Aggregate across seeds: sampler name -> per-layer result of mean_confidence_interval,
        # of which element 0 is used as the value and element 1 as the band half-width.
        n_layers = len(seedwise_dists[0]["random"])
        agg_dists = {
            key: [
                mean_confidence_interval([per_seed[key][layer] for per_seed in seedwise_dists])
                for layer in range(n_layers)
            ]
            for key in seedwise_dists[0]
        }
        random_mean = np.array([t[0] for t in agg_dists["random"]])

        # Sampler curves with their confidence bands.
        for key, tuples in agg_dists.items():
            value = np.array([t[0] for t in tuples])
            bound = np.array([t[1] for t in tuples])
            if normalizer == "subtract":
                value = value - random_mean
            color, marker, label = keys_with_sem[key]
            style = "dotted" if (key == "random" and normalizer == "subtract") else "solid"
            ax.plot(value, label=label, marker=marker, color=color, linestyle=style)
            ax.fill_between(np.arange(len(value)), value-bound, value+bound, alpha=0.2)

        # Semantic-category curve: a plain list of per-layer means, so no band is drawn.
        sem_value = np.asarray(sem_cats_dists['w2v2' if loc == 'left' else 'fast_vgs_plus'])
        if normalizer == "subtract":
            sem_value = sem_value - random_mean
        color, marker, label = keys_with_sem['semantic_categories']
        ax.plot(sem_value, label=label, marker=marker, color=color, linestyle="solid")

        ax.set_yticks([0.0, 0.2, 0.4])
        # On a y-axis this switches the primary (left) tick labels on, which sharey
        # would otherwise hide on the right-hand panel.
        ax.yaxis.set_tick_params(labelbottom=True)
        ax.set_xticks([0, 5, 10])
        ax.set_title(title)
        ax.set_ylabel("Norm. Cos. Sim." if normalizer == "subtract" else "Cos. sim.")
        ax.set_xlabel("Layer Index")
        if title in ("HuBERT-large (Norm.)", "Audio slicing (Norm.)", "Center pooling (Norm.)", "Centroid pooling (Norm.)"):
            ax.set_ylim(-0.02, 0.3)
    plt.tight_layout()

    # Legend built from the last (right) panel, with entries in a fixed order.
    handles, labels = ax.get_legend_handles_labels()

    label_order = ['Random', 'Near homophone', 'Same word',
                     'Synonym', 'Same speaker', 'Semantic categories']

    label_to_handle = dict(zip(labels, handles))

    sorted_labels = [label for label in label_order if label in label_to_handle]
    sorted_handles = [label_to_handle[label] for label in sorted_labels]

    plt.legend(sorted_handles, sorted_labels, **meta["legend"])

    plt.suptitle('Audio slicing', y=1.025, x=0.55, fontsize=18, fontstyle='italic')
    Path(figure_path).mkdir(parents=True, exist_ok=True)  # savefig does not create directories
    plt.savefig(f"{figure_path}{pair_key}.pdf", bbox_inches="tight")
    plt.show()

<h4>Feature slicing</h4>

In [None]:
# Figure specification for the feature-slicing plot (rebinds `pairs` from the audio-slicing section).
# "left"/"right" are file-name templates inside the dists directory; "seed-x" is replaced
# by "seed-<seed>" when loading.
# The "legend" entry is not read by the feature-slicing plotting cell, which draws no legend.
pairs = {
    "w2v2_fast_vgs_plus_feature_slice": {
        "left": "w2v2_featslice_seed-x.pkl",
        "right": "fast_vgs_plus_featslice_seed-x.pkl",
        "seeds": [42, 12, 25, 31, 69],
        "speakers": ("everyone", ),
        "normalizer": ("subtract", "subtract"),
        "legend": {"loc": "center", "bbox_to_anchor": (-0.2, 1.4), "ncols": 3},
        "title": ("wav2vec 2.0-base (Norm.)", "FaST-VGS+ (Norm.)"),
    }
}

# Curve key -> (matplotlib colour, marker, legend label).
keys_with_sem = {
    "random": ("C0", "", "Random"),
    "synonym": ("C1", "x", "Synonym"),
    "homophone": ("C2", "o", "Near homophone"),
    "speaker": ("C3", "|", "Same speaker"),
    "same_word": ("C4", "^", "Same word"),
    "semantic_categories": ("C5", "s", "Semantic categories"),
}

In [None]:
# code based on Choi et al. (2024): https://github.com/juice500ml/phonetic_semantic_probing/tree/24b85b648c6512d9fe4df4139c546482080fef4c 

# Feature-slicing figure: one panel per model showing, per layer, the cosine similarity of
# each sampler (mean over seeds with its confidence bound as a shaded band) plus the
# semantic-category similarity (a single curve without a band).
# With normalizer "subtract" the per-layer "random" mean is subtracted from every curve.
# No legend is drawn on this figure: the original cell prepared ordered legend handles
# but never passed them to plt.legend, and that bookkeeping is dropped here.
for pair_key, meta in pairs.items():

    with open(f'{data_path}sem_cats_dists_featslice.pkl', "rb") as f:
        sem_cats_dists = pickle.load(f)

    fig, axes = plt.subplots(1, 2, figsize=(8, 3), sharey=True)
    for loc, ax, normalizer, title in zip(("left", "right"), axes, meta["normalizer"], meta["title"]):
        # One dict per (seed, speaker): sampler name -> per-layer NaN-ignoring mean similarity.
        seedwise_dists = []
        for seed, _speaker in product(meta["seeds"], meta["speakers"]):
            # Read from the configured output location, where the analysis cell writes.
            dists_path = Path(f'{data_path}dists') / meta[loc].replace("seed-x", f"seed-{seed}")
            with open(dists_path, "rb") as f:
                dists = pickle.load(f)
            seedwise_dists.append({
                k: [np.nanmean(v) for v in vs]
                for k, vs in dists.items()
            })

        # Aggregate across seeds: sampler name -> per-layer result of mean_confidence_interval,
        # of which element 0 is used as the value and element 1 as the band half-width.
        n_layers = len(seedwise_dists[0]["random"])
        agg_dists = {
            key: [
                mean_confidence_interval([per_seed[key][layer] for per_seed in seedwise_dists])
                for layer in range(n_layers)
            ]
            for key in seedwise_dists[0]
        }
        random_mean = np.array([t[0] for t in agg_dists["random"]])

        # Sampler curves with their confidence bands.
        for key, tuples in agg_dists.items():
            value = np.array([t[0] for t in tuples])
            bound = np.array([t[1] for t in tuples])
            if normalizer == "subtract":
                value = value - random_mean
            color, marker, label = keys_with_sem[key]
            style = "dotted" if (key == "random" and normalizer == "subtract") else "solid"
            ax.plot(value, label=label, marker=marker, color=color, linestyle=style)
            ax.fill_between(np.arange(len(value)), value-bound, value+bound, alpha=0.2)

        # Semantic-category curve: a plain list of per-layer means, so no band is drawn.
        sem_value = np.asarray(sem_cats_dists['w2v2' if loc == 'left' else 'fast_vgs_plus'])
        if normalizer == "subtract":
            sem_value = sem_value - random_mean
        color, marker, label = keys_with_sem['semantic_categories']
        ax.plot(sem_value, label=label, marker=marker, color=color, linestyle="solid")

        ax.set_yticks([0.0, 0.015, 0.03, 0.045])
        # On a y-axis this switches the primary (left) tick labels on, which sharey
        # would otherwise hide on the right-hand panel.
        ax.yaxis.set_tick_params(labelbottom=True)
        ax.set_xticks([0, 5, 10])
        ax.set_title(title)
        ax.set_ylabel("Norm. Cos. Sim." if normalizer == "subtract" else "Cos. sim.")
        ax.set_xlabel("Layer Index")
        if title in ("HuBERT-large (Norm.)", "Audio slicing (Norm.)", "Center pooling (Norm.)", "Centroid pooling (Norm.)"):
            ax.set_ylim(-0.02, 0.3)
    plt.tight_layout()

    plt.suptitle('Feature slicing', y=1.025, x=0.55, fontsize=18, fontstyle='italic')
    Path(figure_path).mkdir(parents=True, exist_ok=True)  # savefig does not create directories
    plt.savefig(f"{figure_path}{pair_key}.pdf", bbox_inches="tight")
    plt.show()