Play slices of the waveforms:

In [1]:
import torchaudio
from IPython.display import Audio

waveform, _ = torchaudio.load("data/librispeech_subset/84-121123-0001.wav")
sample_rate=16000

start_time=1.1
end_time=1.46

start_frame = int(start_time * sample_rate)
end_frame = int(end_time * sample_rate)

waveform_slice = waveform[:, start_frame:end_frame]

Audio(waveform_slice.squeeze(), rate=16000)

Get the words & their indices for a .textgrid file from Librispeech

In [None]:
from pathlib import Path
import textgrids

    

input_dir = Path("data/all_textgrid")
files = list(input_dir.rglob("**/*.TextGrid"))
output_path = "data/words_and_indices.txt"

with open(output_path, "w", encoding="utf-8") as f:
    for file in files:
        grid = textgrids.TextGrid(file)

        words_tier = grid["words"]
        file_name = file.stem
        
        f.write(f"{file_name}:\n")
        for idx, interval in enumerate(words_tier, start=1):
            if idx == 1:
                continue
            f.write(f"{idx-2}: {interval.text}\n")
        f.write("\n")


2428-83699-0032
2428-83699-0002
2428-83699-0030
2428-83699-0025
2428-83699-0006
2428-83699-0001
2428-83699-0042
2428-83699-0035
2428-83699-0003
2428-83699-0015
2428-83699-0017
2428-83699-0034
2428-83699-0040
2428-83699-0016
2428-83699-0028
2428-83699-0031
2428-83699-0009
2428-83699-0039
2428-83699-0036
2428-83699-0019
2428-83699-0005
2428-83699-0011
2428-83699-0000
2428-83699-0010
2428-83699-0027
2428-83699-0013
2428-83699-0012
2428-83699-0038
2428-83699-0007
2428-83699-0008
2428-83699-0022
2428-83699-0041
2428-83699-0004
2428-83699-0018
2428-83699-0014
2428-83699-0023
2428-83699-0020
2428-83699-0024
2428-83699-0021
2428-83699-0037
2428-83699-0026
2428-83699-0029
2428-83699-0033
2428-83705-0026
2428-83705-0038
2428-83705-0002
2428-83705-0003
2428-83705-0011
2428-83705-0008
2428-83705-0018
2428-83705-0027
2428-83705-0001
2428-83705-0012
2428-83705-0035
2428-83705-0031
2428-83705-0024
2428-83705-0004
2428-83705-0013
2428-83705-0009
2428-83705-0029
2428-83705-0041
2428-83705-0023
2428-837

Fixing up the query code (basically getting batch processing ready)

In [None]:
import torch
from pathlib import Path
from tqdm import tqdm
from encode import encode
import numpy as np
from dtw import get_frame_num, compute_distance
from cluster import cluster
from eval import parse_text_to_dict
from utils import Cluster

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

indices_dict = parse_text_to_dict("data/words_and_indices.txt")
cluster_files = ['output/dtw/clusters/wavlm_base_8_d0.6.txt','output/dtw/clusters/wavlm_base_8_d0.55.txt','output/dtw/clusters/wavlm_base_8_d0.5.txt','output/dtw/clusters/wavlm_base_8_d0.7.txt']
query_dir = Path("data/query_set")
align_dir = Path("data/all_alignments")
encoding_dir = Path("encodings/librispeech_subset/")
output_dir = Path("output/dtw/")


for k, cluster_file in enumerate(cluster_files):
    cluster_file = Path(cluster_file)
    file_name = cluster_file.stem
    parts = file_name.split("_")

    model_name = parts[0] + "_" + parts[1]
    layer_num = int(parts[2])
    distance = float(f"0.{parts[3].split(".")[1]}")
    print(f"{model_name}, {layer_num}, {distance}")

    encodings_dict = encode(query_dir, None, model_name, layer_num, "wav")

    features = []
    file_names = {}
    index = 0
    for file in encodings_dict.keys():
        alignment_file = [a for a in list(align_dir.rglob("**/*.list")) if a.stem == file]
        if not alignment_file:
            continue
        else:
            alignment_file = alignment_file[0]

        with open(str(alignment_file), "r") as f:
            boundaries = [get_frame_num(float(line.strip()), 16000, 20) for line in f]
        
        new_encodings = torch.from_numpy(encodings_dict[file]).to(device)

        if len(new_encodings.shape) == 1:
            new_encodings = new_encodings.unsqueeze(0)

        for i in range(0,len(boundaries)-1):
            new_feature = new_encodings[boundaries[i]:boundaries[i+1], :]
            features.append(new_feature)
            file_names[index] = f"{file}_{i}"
            index += 1
        
    normalized_features = []
    for feature in features:
        norm_feature = torch.nn.functional.normalize(feature, p=2, dim=1)
        normalized_features.append(norm_feature) 
        
    old_features = []
    encoding_path = encoding_dir / model_name / str(layer_num)
    encoding_files = list(encoding_path.rglob("*.npy"))

    for file in encoding_files:
        alignment_file = [a for a in list(align_dir.rglob("*.list")) if a.stem == file.stem]

        if not alignment_file:
            continue
        else:
            alignment_file = alignment_file[0]
                    
        with open(str(alignment_file), "r") as f:
            boundaries = [get_frame_num(float(line.strip()), 16000, 20) for line in f]
            
        encodings = torch.from_numpy(np.load(file)).to(device)

        if len(encodings.shape) == 1:
            encodings = encodings.unsqueeze(0)

        for i in range(0,len(boundaries)-1):
            new_feature = encodings[boundaries[i]:boundaries[i+1], :]
            old_features.append(new_feature)
            file_names[index] = f"{file.stem}_{i}"
            index += 1

    normalized_old_features = []
    for feature in old_features:
        norm_feature = torch.nn.functional.normalize(feature, p=2, dim=1)
        normalized_old_features.append(norm_feature) 
    
    output_file = output_dir / model_name / str(layer_num) / "norm_distance_matrix.npy"
    norm_dist_mat = np.load(str(output_file))

    extra_num_features = len(normalized_features)
    num_features = len(normalized_old_features)
    
    new_norm_dist_mat = np.pad(norm_dist_mat, pad_width=(0,extra_num_features), constant_values=0)
    all_norm_features = normalized_old_features + normalized_features

    for i in tqdm(range(num_features +1, num_features + extra_num_features), "Calculating Distances"):
        for j in range(0, i):
            distance, norm_distance = compute_distance(i, j, all_norm_features)
            new_norm_dist_mat[j, i] = norm_distance


wavlm_base, 8, 0.6


Encoding Audio Features: 100%|██████████| 1/1 [00:00<00:00,  2.67it/s]
Calculating Distances: 100%|██████████| 51/51 [08:40<00:00, 10.21s/it]


In [23]:
distance = 10
norm_dist_file = f"output/dtw/{model_name}/{layer_num}/norm_dist_mat_d{distance}.npy"
np.save(norm_dist_file, new_norm_dist_mat)


[[   0.         1388.03462564 2014.05046154 ... 2121.15911111
  1350.16543256 1552.08610909]
 [1388.03462564    0.         1939.898368   ... 2042.18683077
  1558.60114286 1373.96509767]
 [2014.05046154 1939.898368      0.         ... 1454.59987692
  2239.05015172 2069.42208   ]
 ...
 [2121.15911111 2042.18683077 1454.59987692 ...    0.
  2150.44778667 2119.59411613]
 [1350.16543256 1558.60114286 2239.05015172 ... 2150.44778667
     0.         1455.93627234]
 [1552.08610909 1373.96509767 2069.42208    ... 2119.59411613
  1455.93627234    0.        ]]


In [50]:
new_norm_dist_mat = np.load(f"output/dtw/{model_name}/{layer_num}/norm_dist_mat.npy")
print(new_norm_dist_mat)

[[0.         0.67775128 0.98342308 ... 1.03572222 0.65926047 0.75785455]
 [0.67775128 0.         0.947216   ... 0.99716154 0.76103571 0.6708814 ]
 [0.98342308 0.947216   0.         ... 0.71025385 1.09328621 1.01046   ]
 ...
 [1.03572222 0.99716154 0.71025385 ... 0.         1.05002333 1.03495806]
 [0.65926047 0.76103571 1.09328621 ... 1.05002333 0.         0.71090638]
 [0.75785455 0.6708814  1.01046    ... 1.03495806 0.71090638 0.        ]]


In [51]:
distance = 0.77
clusters = cluster(new_norm_dist_mat, file_names, model_name, layer_num, distance)

print(len(clusters))
appended_clusters = []
for i, clust in enumerate(clusters):
    new_cluster = Cluster(i)
    for j in range(len(clust)):
        filename = file_names[clust[j]]
        wordunit_id = j
        file_parts = filename.split("_")
        file_name = file_parts[0]
        index = int(file_parts[1])  

        new_cluster.add_word_unit(wordunit_id, index, file_name)
    
    for word_unit in new_cluster.word_dict:
        word = indices_dict[word_unit.file][word_unit.index]
        new_cluster.add_true_word(word)

        # if word_unit.file in encodings_dict.keys():
        # print(f"new word '{word}' added to cluster {new_cluster.id}")

    appended_clusters.append(new_cluster.true_word_dict)
    new_cluster.cluster_purity()

    if len(new_cluster.word_dict) > 1:
        print(f"purity : {new_cluster.purity*100}%")
        print(new_cluster.true_word_dict)

48
purity : 12.76595744680851%
['it', 'groaned', 'seems', 'to', 'me', 'more', 'and', 'more', 'as', 'i', 'live', 'longer', ' ', 'that', 'most', 'poetry', 'and', ' ', 'most', 'literature', ' ', 'and', 'particularly', 'the', 'literature', 'of', 'the', 'past', ' ', 'is', 'discordant', 'with', 'the', 'vastness', 'and', 'variety', ' ', 'the', 'reserves', 'and', 'resources', 'and', '<unk>', 'of', 'life', ' ', 'as', 'we', 'live', 'it', 'to', 'day', ' ', 'go', ' ', 'do', 'you', 'hear', ' ', 'worse', 'and', 'worse', ' ', 'he', 'is', ' ', 'even', 'presumed', 'to', 'be', 'the', "captive's", 'sweetheart', ' ', 'who', 'wheedles', 'the', 'flower', 'the', 'ring', 'and', 'the', 'prison', 'key', 'out', 'of', 'the', 'strict', 'virgins', 'for', 'his', 'own', 'purposes', ' ', 'and', 'flies', 'with', 'her', 'at', 'last', 'in', 'his', 'shallop', 'across', 'the', 'sea', ' ', 'to', 'live', 'with', 'her', 'happily', 'ever', 'after', ' ', 'at', 'this', 'moment', ' ', 'the', 'whole', 'soul', 'of', 'the', 'old', '