**Environment setup** (run once in a terminal; after `conda init`, restart the shell so that `conda activate` works):

```bash
conda init
conda create -n get-speech-embeddings python=3.10 -y
conda activate get-speech-embeddings
pip install numpy==1.26.4
pip install torch==2.6.0+cu118 torchaudio==2.6.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
pip install coqui-tts
pip install pandas
pip install ipykernel
```

In [1]:
import os
import csv
from pathlib import Path
import torch
import torchaudio
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from tqdm import tqdm
from IPython.display import Audio
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Input: VoxCeleb2 test-split WAVs laid out as <ROOT>/<id>/<hash>/<number>.wav
ROOT = Path("../VoxCeleb2/vox2_test/wav")

# Outputs: metadata CSV plus the two tensor files it indexes (CSV row i <-> tensor row i)
_OUT_DIR = Path("../VoxCeleb2/full")
CSV_PATH = _OUT_DIR / "wav_metadata.csv"
EMB_PATH = _OUT_DIR / "wav_embeddings.pt"
GPT_PATH = _OUT_DIR / "wav_gpt_latents.pt"

SAVE_CHUNK = 64    # flush buffered results to disk every SAVE_CHUNK files
TARGET_SR = 22050  # default output sample rate of preprocess_wav

In [3]:
# Device Setup: prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

Using device: cuda


In [4]:
# Coqui XTTS v2 | Speech Embedding Model
config = XttsConfig()
config.load_json("xtts/config.json")
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="xtts/", use_deepspeed=False)
# Use the device selected in the setup cell instead of hard-coding CUDA, so
# the notebook also runs on CPU-only machines; ";" hides the large module repr.
model.to(device).eval();

In [5]:
def preprocess_wav(path, target_sr=TARGET_SR):
    """Load an audio file as a mono, 1-D waveform at ``target_sr`` Hz.

    Multi-channel audio is down-mixed by averaging over the channel axis;
    audio whose native rate differs from ``target_sr`` is resampled with
    ``torchaudio.functional.resample``.

    Args:
        path: file path accepted by ``torchaudio.load``.
        target_sr: desired output sample rate (defaults to ``TARGET_SR``).

    Returns:
        1-D tensor of samples.
    """
    waveform, orig_sr = torchaudio.load(path)
    mono = waveform.mean(dim=0, keepdim=True) if waveform.shape[0] > 1 else waveform
    if orig_sr == target_sr:
        return mono.squeeze(0)
    return torchaudio.functional.resample(mono, orig_sr, target_sr).squeeze(0)

In [6]:
def load_resume_index(csv_path=None):
    """Return the position in ``wav_files`` where processing should resume.

    Reads the metadata CSV and returns the value of the ``index`` column in
    its last data row plus one. Returns 0 when the CSV does not exist or holds
    only the header row.

    Args:
        csv_path: CSV file to inspect (str or Path). Defaults to the
            notebook-level ``CSV_PATH``.

    Returns:
        int: the next index to process.
    """
    csv_path = Path(CSV_PATH if csv_path is None else csv_path)
    if not csv_path.exists():
        return 0

    # Use a context manager so the handle is closed, and skip blank lines so
    # rows[-1] is always a real data row (a blank line would raise IndexError).
    with open(csv_path, newline="") as f:
        rows = [row for row in csv.reader(f) if row]

    if len(rows) <= 1:  # header only
        return 0

    last_index = int(rows[-1][0])
    print(f"[Resume] Last completed index: {last_index}")
    return last_index + 1

In [7]:
# Ensure the output directory (it also receives EMB_PATH / GPT_PATH) and the
# CSV with its header exist; open(..., "w") fails if the directory is missing.
CSV_PATH.parent.mkdir(parents=True, exist_ok=True)
if not CSV_PATH.exists():
    with open(CSV_PATH, "w", newline="") as f:
        csv.writer(f).writerow(["index", "id", "hash", "wav_num", "filepath"])

start_index = load_resume_index()

In [8]:
# Collect every WAV under ROOT. Sorting makes the order deterministic, which
# resuming depends on: start_index is a position in this list.
wav_files = list(ROOT.rglob("*.wav"))
wav_files.sort()
print(f"Found {len(wav_files)} WAV files.")

Found 36237 WAV files.


In [9]:
# Reload previously saved tensors when resuming; otherwise start from empty
# tensors of the expected widths (512 for speaker embeddings, 1024 for GPT latents).
def _load_or_empty(path, width):
    """Return the tensor stored at ``path``, or an empty ``(0, width)`` tensor if the file is absent."""
    if path.exists():
        return torch.load(path)
    return torch.zeros((0, width))

all_embeddings = _load_or_empty(EMB_PATH, 512)
all_gpt_latents = _load_or_empty(GPT_PATH, 1024)

In [10]:
# Extract XTTS conditioning outputs for every WAV not yet recorded.
#
# Invariant: row i of the CSV describes row i of both tensor files. To keep
# it, CSV rows are buffered and written in the same flush as the tensors.
# Writing them per file would let the CSV run ahead of the saved tensors
# whenever the kernel died between checkpoints, misaligning rows on resume.

# Refuse to append if the state is already out of sync. This can happen after
# a crash mid-flush, or when this cell is re-run without re-running the resume
# cells above (which would append duplicate rows).
n_emb, n_lat = all_embeddings.shape[0], all_gpt_latents.shape[0]
if n_emb != start_index or n_lat != start_index:
    raise RuntimeError(
        f"Resume mismatch: CSV resumes at index {start_index}, but the tensors hold "
        f"{n_emb} embeddings and {n_lat} GPT latents. Re-run the setup cells, or "
        f"repair the files so their row counts agree."
    )

cache_embs = []
cache_latents = []
cache_rows = []
next_index = start_index


def _atomic_torch_save(tensor, dest):
    """Save ``tensor`` to ``dest`` via a temp file + rename, so a crash never leaves a truncated file."""
    tmp = dest.with_name(dest.name + ".tmp")
    torch.save(tensor, tmp)
    os.replace(tmp, dest)


def _flush_cache():
    """Append buffered embeddings, latents and CSV rows to disk, then clear the buffers.

    Globals are only updated after every write succeeds, so calling this again
    after a failed flush does not duplicate rows in memory. Each flush rewrites
    the full tensor files, so its cost grows with the number of saved rows;
    raise SAVE_CHUNK if checkpointing dominates runtime.
    """
    global all_embeddings, all_gpt_latents
    if not cache_embs:
        return
    new_embs = torch.cat([all_embeddings] + cache_embs, dim=0)
    new_latents = torch.cat([all_gpt_latents] + cache_latents, dim=0)
    _atomic_torch_save(new_embs, EMB_PATH)
    _atomic_torch_save(new_latents, GPT_PATH)
    with open(CSV_PATH, "a", newline="") as f:
        csv.writer(f).writerows(cache_rows)
    all_embeddings, all_gpt_latents = new_embs, new_latents
    cache_embs.clear()
    cache_latents.clear()
    cache_rows.clear()


try:
    for path in tqdm(wav_files[start_index:], desc="Processing WAVs"):
        # Layout: <ROOT>/<id>/<hash>/<number>.wav
        id_name = path.parts[-3]
        hash_name = path.parts[-2]
        wav_num = int(path.stem)

        # get_conditioning_latents reads the audio directly from audio_path, so
        # no separate preprocess_wav() call is made here (its result was unused).
        gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=[str(path)])

        speaker_embedding = speaker_embedding.cpu().squeeze(-1)
        gpt_cond_latent = gpt_cond_latent.cpu().mean(dim=1)

        cache_embs.append(speaker_embedding)
        cache_latents.append(gpt_cond_latent)
        cache_rows.append([next_index, id_name, hash_name, wav_num, str(path)])
        next_index += 1

        # Periodic save
        if len(cache_embs) >= SAVE_CHUNK:
            _flush_cache()
finally:
    # Persist every fully processed file, including on KeyboardInterrupt or error.
    _flush_cache()

Processing WAVs: 100%|██████████| 36237/36237 [11:01<00:00, 54.78it/s]


In [11]:
# Persist any embeddings/latents still buffered after the loop (no-op when the buffers are empty).
if cache_embs:
    all_gpt_latents = torch.cat([all_gpt_latents, *cache_latents], dim=0)
    all_embeddings = torch.cat([all_embeddings, *cache_embs], dim=0)
    torch.save(all_embeddings, EMB_PATH)
    torch.save(all_gpt_latents, GPT_PATH)

In [12]:
# Check CSV File. Row i must describe row i of the tensor files, so the
# "index" column has to be exactly 0..n-1 with no gaps or duplicates.
df = pd.read_csv(CSV_PATH)
print(len(df))
assert df["index"].tolist() == list(range(len(df))), "CSV index column is not contiguous 0..n-1"
df.head()

36237


Unnamed: 0,index,id,hash,wav_num,filepath
0,0,id00017,01dfn2spqyE,1,..\VoxCeleb2\vox2_test\wav\id00017\01dfn2spqyE...
1,1,id00017,5MkXgwdrmJw,2,..\VoxCeleb2\vox2_test\wav\id00017\5MkXgwdrmJw...
2,2,id00017,7t6lfzvVaTM,3,..\VoxCeleb2\vox2_test\wav\id00017\7t6lfzvVaTM...
3,3,id00017,7t6lfzvVaTM,4,..\VoxCeleb2\vox2_test\wav\id00017\7t6lfzvVaTM...
4,4,id00017,7t6lfzvVaTM,5,..\VoxCeleb2\vox2_test\wav\id00017\7t6lfzvVaTM...


In [13]:
# Sanity-check the saved speaker embeddings: overall shape, any all-zero rows, first values.
embeddings = torch.load(EMB_PATH)
print("Speaker embeddings shape:", embeddings.shape)
print("Has zeros:", embeddings.abs().sum(dim=1).eq(0).any())
print(embeddings[0, :6])

Speaker embeddings shape: torch.Size([36237, 512])
Has zeros: tensor(False)
tensor([ 0.0008, -0.0557, -0.0161,  0.0582, -0.0567,  0.0020])


In [14]:
# Check Speaker Embeddings: show one CSV row next to summary stats of its embedding.
index = 67

row = df.iloc[index]
print(row)

vec = embeddings[index]
for label, stat in (
    ("Norm:", torch.norm(vec)),
    ("Min value:", vec.min()),
    ("Max value:", vec.max()),
):
    print(label, stat.item())
print("Embedding vector:", vec[:6])

index                                                      67
id                                                    id00017
hash                                              kN2iEvnTHc4
wav_num                                                   135
filepath    ..\VoxCeleb2\vox2_test\wav\id00017\kN2iEvnTHc4...
Name: 67, dtype: object
Norm: 1.0
Min value: -0.2600252628326416
Max value: 0.15831495821475983
Embedding vector: tensor([ 0.0268, -0.0380, -0.0178,  0.0490, -0.0299,  0.0232])


In [15]:
# Sanity-check the saved GPT latents: overall shape, any all-zero rows, first values.
gpt_latents = torch.load(GPT_PATH)
print("GPT latents shape:", gpt_latents.shape)
print("Has zeros:", gpt_latents.abs().sum(dim=1).eq(0).any())
print(gpt_latents[0, :6])

GPT latents shape: torch.Size([36237, 1024])
Has zeros: tensor(False)
tensor([0.1205, 0.9428, 0.1296, 0.7079, 0.2082, 0.2103])


In [16]:
# Check GPT Latents for one row. These are GPT conditioning latents averaged
# over dim 1 during extraction, so their scale differs from the speaker embeddings.
index = 67

row = df.iloc[index]
print(row)

vec = gpt_latents[index]

print("Norm:", torch.norm(vec).item())
print("Min value:", vec.min().item())
print("Max value:", vec.max().item())
print("GPT latent vector:", vec[:6])

index                                                      67
id                                                    id00017
hash                                              kN2iEvnTHc4
wav_num                                                   135
filepath    ..\VoxCeleb2\vox2_test\wav\id00017\kN2iEvnTHc4...
Name: 67, dtype: object
Norm: 28.044885635375977
Min value: -10.159658432006836
Max value: 8.519124031066895
Embedding vector: tensor([ 0.3297, -0.3601, -0.4744,  0.4089, -0.3257, -0.5071])


In [17]:
# Listen to the clip for the inspected row, down-mixed to mono at its native sample rate.
wav_path = row["filepath"]
wav, sr = torchaudio.load(wav_path)
channel_count = wav.shape[0]
if channel_count > 1:
    wav = wav.mean(0, keepdim=True)
Audio(wav[0].numpy(), rate=sr)