In [2]:
import pandas as pd
from encode import get_units, sample_files
from pathlib import Path

# Input locations: the word-alignment table and the audio it refers to.
align_dir = Path("data/alignments/dev-clean/")
align_path = align_dir / "alignments.csv"
audio_dir = Path("data/dev-clean")
audio_ext = ".flac"

# Encoding settings. Both values are passed straight to `get_units`;
# `gamma` also names the per-run sub-directory checked under `save_dir`.
gamma = 0.1
layer = 7
save_dir = Path("features/")

# Load the alignment table once.
align_df = pd.read_csv(align_path)

# NOTE(review): sample_size=-1 presumably means "use every file" -- confirm
# against `sample_files`.
paths, sample_size = sample_files(
    audio_dir=audio_dir, audio_ext=audio_ext, sample_size=-1
)

print(f"Sample size: {sample_size}")

# Skip encoding when the output directory for this gamma already exists.
if not (save_dir / str(gamma)).exists():
    get_units(paths, align_df, audio_dir, gamma, layer, save_dir)

Sample size: 2703


In [3]:
from dist import get_features, get_batch_of_paths, cal_dist_per_pair
from tqdm import tqdm
import numpy as np

# Settings for the pairwise-distance stage.
# `preloaded = True` skips the distance computation guarded by
# `if not preloaded` further down in this cell.
preloaded = True
# Used both to derive the batch count and as the second argument of
# `get_batch_of_paths`; presumably the maximum number of pairs per batch.
chunk_limit = 500000

# Directory that receives the temp_rows/cols/vals chunk files.
out_dir = Path(f"ladies/output/{gamma}/temp/")
out_dir.mkdir(exist_ok=True, parents=True)


def process_batch(batch, features):
    """Compute the distance record for every (i, j) index pair in `batch`.

    Pairs are handled one after another in the calling process. Each call
    hands `cal_dist_per_pair` the pair together with `features[i]` and
    `features[j]`.

    Args:
        batch: Iterable of (i, j) pairs used to index `features`.
        features: Indexable container of per-item features.

    Returns:
        list: One `cal_dist_per_pair` result per pair, in input order.
    """
    results = []
    for first, second in batch:
        payload = ((first, second), (features[first], features[second]))
        results.append(cal_dist_per_pair(payload))
    return results


if not preloaded:
    # Collect the feature files, ordered by the integer that follows the last
    # "_" in each file stem. The listing goes straight into `sorted`, so the
    # `paths` variable from the encoding cell is not overwritten.
    sorted_paths = sorted(
        Path(f"ladies/features/{gamma}").rglob("**/*.npy"),
        key=lambda x: int(x.stem.split("_")[-1]),
    )
    sample_size = len(sorted_paths)

    features = get_features(sorted_paths)

    # Number of unordered pairs, and the batch count rounded up.
    num_pairs = sample_size * (sample_size - 1) // 2
    num_batches = (num_pairs + chunk_limit - 1) // chunk_limit

    print(f"num_samples: {sample_size}")
    print(f"num_pairs: {num_pairs}")

    # NOTE(review): chunks are written under "ladies/output/..." while the
    # graph-building cell reads from "output/..." -- confirm both resolve to
    # the same directory.
    # Batches and pairs are processed sequentially; each batch is written to
    # its own trio of .npy files so memory use stays bounded.
    for chunk_idx, batch in enumerate(
        tqdm(
            get_batch_of_paths(sample_size, chunk_limit),
            total=num_batches,
            unit="batch",
            mininterval=10.0,
            desc="Processing Batches",
        )
    ):
        rows, cols, vals = [], [], []
        for i, j in batch:
            i, j, dist = cal_dist_per_pair(((i, j), (features[i], features[j])))
            rows.append(i)
            cols.append(j)
            vals.append(dist)

        np.save(out_dir / f"temp_rows_{chunk_idx}.npy", rows)
        np.save(out_dir / f"temp_cols_{chunk_idx}.npy", cols)
        np.save(out_dir / f"temp_vals_{chunk_idx}.npy", vals)

In [5]:
from cluster import build_graph_from_temp, adaptive_res_search
import pickle

# Clustering settings.
num_clusters = 13967  # target cluster count handed to the resolution search
num_chunks = 399  # passed to build_graph_from_temp along with temp_dir
use_preloaded_graph = False  # True: reuse graph.pkl when it is on disk

graph_path = Path(f"output/{gamma}/graph.pkl")
temp_dir = Path(f"output/{gamma}/temp")
temp_dir.mkdir(exist_ok=True, parents=True)  # Ensure the directory exists

if not use_preloaded_graph or not graph_path.exists():
    # Build the graph from the temp directory and persist it for later runs.
    g = build_graph_from_temp(temp_dir, num_chunks)
    g.write_pickle(str(graph_path))
    print(f"Graph built and saved to {graph_path}")
else:
    # Reuse the pickled graph written by a previous run.
    with graph_path.open("rb") as graph_file:
        g = pickle.load(graph_file)
    print(f"Loaded precomputed graph from {graph_path}")


Calculating total: 100%|██████████| 399/399 [01:01<00:00,  6.47it/s]


total_size: 1993108816, sample_size: 63137


Getting Temp Info: 100%|██████████| 399/399 [04:20<00:00,  1.53it/s]


Graph built and saved to output/0.1/graph.pkl
Best resolution found: 0.027 with cluster difference: 58


In [6]:
# Partitions saved by earlier runs, named "partition_r<resolution>.csv".
partition_files = list(Path(f"output/{gamma}").glob("partition_r*.csv"))

if partition_files:
    # Parse the resolution out of each file name and load its table.
    res_partitions = [
        (float(csv_path.stem.split("_r")[1]), pd.read_csv(csv_path))
        for csv_path in partition_files
    ]

    # Keep the saved partition with the smallest resolution.
    best_res, best_partition_df = min(res_partitions, key=lambda pair: pair[0])
else:
    # Nothing cached: run the resolution search against the target count.
    best_res, best_partition = adaptive_res_search(g, num_clusters)

    # One row per node: its index and the cluster it was assigned to.
    membership = best_partition.membership
    best_partition_df = pd.DataFrame(
        {"node": range(len(membership)), "cluster": membership}
    )

    # Persist the result so later runs take the branch above.
    best_partition_df.to_csv(
        f"output/{gamma}/partition_r{round(best_res, 3)}.csv", index=False
    )

# Compare the number of distinct clusters with the requested target.
actual_clusters = len(set(best_partition_df["cluster"]))
diff = abs(actual_clusters - num_clusters)

print(f"Best resolution found: {best_res:.3f} with cluster difference: {diff}")

Best resolution found: 0.027 with cluster difference: 58


In [63]:
def get_phones_and_texts(gamma, align_dir):
    """Return per-word phones and texts for the features of one gamma run.

    Results are cached in ``features/{gamma}/texts_and_phones.csv``. When the
    cache exists it is read back; otherwise the table is rebuilt from the
    ``.npy`` files under ``features/{gamma}`` and ``alignments.csv`` in
    `align_dir`, then written to the cache.

    When rebuilding, each file stem is split on "_": part 0 is matched against
    the "filename" column and part 1 (as an int) against "word_id". The first
    matching row supplies the text and phones. Files are ordered by the
    integer after the last "_" in the stem.

    Note that freshly built phones are stored space-separated but split on ","
    on the way out, so each returned tuple holds one space-joined string.

    Args:
        gamma: Run identifier used to build the ``features/{gamma}`` paths.
        align_dir: Directory (path-like supporting ``/``) holding
            ``alignments.csv``; only read when the cache is missing.

    Returns:
        tuple: ``(phones, texts)`` where `phones` is a pandas Series of tuples
        and `texts` is a list taken from the "text" column.
    """
    cache_path = Path(f"features/{gamma}/texts_and_phones.csv")

    def split_outputs(frame):
        # Split the "phones" strings on "," and return the texts as a list.
        phone_tuples = frame["phones"].apply(lambda x: tuple(x.split(",")))
        return phone_tuples, frame["text"].tolist()

    if cache_path.exists():
        outputs = split_outputs(pd.read_csv(cache_path))
        print(f"Loaded texts from {cache_path}")
        return outputs

    feature_paths = sorted(
        Path(f"features/{gamma}").rglob("**/*.npy"),
        key=lambda x: int(x.stem.split("_")[-1]),
    )
    alignments = pd.read_csv(align_dir / "alignments.csv")

    records = []
    for path in tqdm(feature_paths, desc="Appending Text and Phones"):
        stem_parts = path.stem.split("_")
        same_file = alignments[alignments["filename"] == stem_parts[0]]
        word_row = same_file[same_file["word_id"] == int(stem_parts[1])]
        text = str(word_row["text"].iloc[0])
        # Stored comma-separated in the alignments; re-joined with spaces here.
        spaced_phones = " ".join(word_row["phones"].iloc[0].split(","))
        records.append((text, spaced_phones))

    table = pd.DataFrame(records, columns=["text", "phones"])
    table.to_csv(cache_path, index=False)
    print(f"Saved texts to {cache_path}")

    return split_outputs(table)


In [64]:
# Load the per-word phones and texts for this gamma (read from the CSV cache
# when it exists, otherwise rebuilt from the feature files and alignments).
phones, texts = get_phones_and_texts(gamma, align_dir)
# Print the phones Series as a quick sanity check of the loaded data.
print(phones)

Loaded texts from features/0.1/texts_and_phones.csv
0                             (sil,)
1                             (sil,)
2                             (sil,)
3                             (sil,)
4                             (sil,)
                    ...             
63132                     (K AH1 M,)
63133    (S P IH1 R IH0 CH W AH0 L,)
63134                 (AH0 G EH1 N,)
63135                          (sp,)
63136     (S UW0 P IH1 R IY0 ER0 Z,)
Name: phones, Length: 63137, dtype: object


In [65]:
import re

# Build (index, phones, text) triples, one per word.
# Each entry of `phones` is a tuple whose first element is a space-separated
# phone string. It is split into individual phones, the "sil"/"sp" markers are
# dropped, and the digits 0/1/2 are stripped from the remaining phones.
tuple_phones = []
for word_idx, (word_phones, text) in enumerate(zip(phones, texts)):
    word_phones_tuple = tuple(
        re.sub(r"[012]", "", phn)
        for phn in word_phones[0].split(" ")
        if phn != "sil" and phn != "sp"
    )
    tuple_phones.append((word_idx, word_phones_tuple, text))

# Free the raw inputs; re-run the loading cell before re-running this one.
del phones, texts

In [66]:
# Spot-check: print entries 10001 through 10011 as "PH-PH-... [text]".
for _, phone_seq, word in tuple_phones[10001:10012]:
    print(f"{'-'.join(phone_seq)} [{word}]")

D-IH-S-P-L-EY [display]
T-AH [to]
T-EY-K [take]
D-AO-G [dog]
AH-N [an]
AH-V [of]
DH-IY [the]
F-ER [for]
DH-AH-T [that]
AH [a]
DH-AH [the]


In [89]:
import editdistance
import statistics
import itertools


def distance(p, q):
    """Return the edit distance between `p` and `q`, scaled by the longer length.

    Args:
        p: First sequence (anything supporting ``len`` and accepted by
            ``editdistance.eval``).
        q: Second sequence of the same kind.

    Returns:
        float: ``editdistance.eval(p, q) / max(len(p), len(q))``, or ``0.0``
        when both sequences are empty.
    """
    longest = max(len(p), len(q))
    # Two empty sequences: report zero rather than dividing by zero.
    if longest <= 0:
        return 0.0
    return editdistance.eval(p, q) / longest


def ned(clusters, num_clusters):
    """Compute the normalized edit distance (NED) within each cluster.

    Items are grouped by their first element (the cluster id). For every group
    with at least two items, `distance` is evaluated on element 1 of each
    unordered pair of items.

    Args:
        clusters: Sequence of tuples whose element 0 is the cluster id and
            whose element 1 is the sequence compared by `distance`.
        num_clusters: Expected number of groups; only used as the progress
            bar total.

    Returns:
        tuple: ``(overall_mean, per_cluster)``. ``overall_mean`` is the mean
        over all pairwise distances. ``per_cluster`` is a list of
        ``(cluster_id, items, mean_distance)`` for clusters with at least two
        items. When there are no pairs to compare (empty input or only
        singleton clusters), the result is ``(0.0, [])`` so callers can always
        unpack two values.
    """
    if not clusters:
        return 0.0, []

    # groupby only merges adjacent items, so order by cluster id first.
    ordered = sorted(clusters, key=lambda x: x[0])

    distances = []
    distances_per_cluster = []
    for idx, group in tqdm(
        itertools.groupby(ordered, key=lambda x: x[0]),
        total=num_clusters,
        desc="Clustering",
    ):
        group_list = list(group)

        # A single item has no pair to compare against.
        if len(group_list) < 2:
            continue

        clust_distances = [
            distance(p[1], q[1]) for p, q in itertools.combinations(group_list, 2)
        ]
        distances.extend(clust_distances)
        distances_per_cluster.append(
            (idx, group_list, statistics.mean(clust_distances))
        )

    # Every cluster was a singleton: statistics.mean would raise on no data.
    if not distances:
        return 0.0, []

    return statistics.mean(distances), distances_per_cluster

In [67]:
# Attach each node's phones and word to its cluster assignment.
cluster_tuples = []
seen_nodes = set()  # To track nodes we've already added

# Index the phone/word pairs by node id once, instead of scanning the whole
# list for every node. setdefault keeps the first entry for a repeated id,
# matching a front-to-back scan.
node_lookup = {}
for node, phone, word in tuple_phones:
    node_lookup.setdefault(node, (phone, word))

for node_id, cluster in tqdm(
    zip(best_partition_df["node"], best_partition_df["cluster"]),
    total=len(best_partition_df["node"]),
    desc="Creating Clusters",
):
    # Skip nodes already added and nodes with no phone/word entry.
    if node_id in seen_nodes or node_id not in node_lookup:
        continue
    phone, word = node_lookup[node_id]
    cluster_tuples.append((cluster, phone, word))
    seen_nodes.add(node_id)  # Mark this node as added

Creating Clusters: 100%|██████████| 63137/63137 [02:47<00:00, 377.00it/s] 


In [90]:
# The second argument only sets the progress-bar total. Use the measured
# number of distinct clusters: `num_clusters - diff` equals it only when the
# partition has fewer clusters than the target, since `diff` is absolute.
ned_val, dist_per_cluster = ned(cluster_tuples, actual_clusters)
print(f"NED: {ned_val}")

Clustering: 100%|██████████| 13909/13909 [00:05<00:00, 2331.82it/s]


NED: 0.1078547757852461


In [107]:
from collections import Counter

cluster_counters = {}  # cluster id -> Counter of (phone string, word) pairs
cluster_mean_dist = {}  # cluster id -> mean pairwise distance of that cluster

# `dist_per_cluster` is a list that omits single-item clusters, so a cluster
# id is not a valid list position. Key the mean distance by cluster id.
for cluster_id, group_list, dist in dist_per_cluster:
    cluster_counters[cluster_id] = Counter(
        ("-".join(phn), wrd) for _, phn, wrd in group_list
    )
    cluster_mean_dist[cluster_id] = dist

# Print results per cluster, most frequent (phones, word) pair first.
for cluster_id, counter in cluster_counters.items():
    print(
        f"{'-' * 50}\nCluster {cluster_id}: {cluster_mean_dist[cluster_id]}\n{'-' * 50}"
    )
    for (phoneme, word), count in counter.most_common():
        print(f"{phoneme:8} [{word:5}] -> {count} times")


--------------------------------------------------
Cluster 0: 0.07360234466157677
--------------------------------------------------
DH-AH    [the  ] -> 2358 times
DH-IY    [the  ] -> 138 times
AH       [a    ] -> 35 times
DH-AH-T  [that ] -> 11 times
DH-AH-M  [them ] -> 4 times
DH-IH-S  [this ] -> 2 times
         [  nan] -> 2 times
DH-OW    [though] -> 1 times
EY       [a    ] -> 1 times
DH-EH-R  [there] -> 1 times
DH-EH-R  [their] -> 1 times
DH-AH-N  [than ] -> 1 times
DH-EY    [they ] -> 1 times
--------------------------------------------------
Cluster 1: 0.0
--------------------------------------------------
         [  nan] -> 1554 times
--------------------------------------------------
Cluster 2: 0.037255411134966526
--------------------------------------------------
AH-V     [of   ] -> 1107 times
HH-AE-V  [have ] -> 17 times
AH       [a    ] -> 7 times
IH-F     [if   ] -> 4 times
EY       [a    ] -> 1 times
OW-V-ER  [over ] -> 1 times
         [  nan] -> 1 times
-------------