In [27]:
import os

import chromadb
import numpy as np
import pandas as pd
import soundfile as sf
import torch
import wespeaker
from tqdm import tqdm

In [6]:
# Both metadata tables live in the same dataset folder.
dataset_dir = r"C:\Users\Nastya\Desktop\interp_dev-main\dataset"
df_train = pd.read_csv(dataset_dir + r"\train_data.csv")
df_test = pd.read_csv(dataset_dir + r"\test_data.csv")

# Audio file names (relative to the dataset folder), in table order.
train_audio_files = df_train["filename"].tolist()
test_audio_files = df_test["filename"].tolist()

In [10]:
# Local directory holding the pretrained WeSpeaker model.
pretrain_dir = r"C:\Users\Nastya\Desktop\interp_dev-main\voxblink2_samresnet34"
# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [28]:
def extract_embeddings(df, audio_files, device, pretrain_dir,
                       data_dir="C:/Users/Nastya/Desktop/interp_dev-main/dataset/",
                       label_column="accent"):
    """
    Extract one WeSpeaker speaker embedding per audio file, paired with its label.

    Parameters
    ----------
    df : pandas.DataFrame
        Metadata table with a ``"filename"`` column and a ``label_column``
        column. Every entry of ``audio_files`` must match exactly one row.
    audio_files : iterable of str
        File names, resolved relative to ``data_dir``. The output keeps this order.
    device : torch.device or str
        Passed to ``model.set_device``.
    pretrain_dir : str
        Model directory passed to ``wespeaker.load_model_local``.
    data_dir : str, optional
        Directory joined in front of each file name. Defaults to the dataset
        folder that was previously hard-coded.
    label_column : str, optional
        Column of ``df`` stored under the ``"label"`` key (default ``"accent"``).

    Returns
    -------
    list of dict
        One record per file with keys ``"file_path"`` (the path that was read),
        ``"embedding"`` (numpy array converted from the model output) and
        ``"label"`` (Python scalar taken from ``label_column``).

    Raises
    ------
    ValueError
        If any requested file name matches zero or several rows of ``df``.
        All names are checked before the model is loaded, so a bad name fails
        fast instead of after a long extraction run.
    """
    audio_files = list(audio_files)

    # Map each filename to its row positions once, instead of building a
    # boolean mask over the whole frame for every file (quadratic overall).
    rows_by_file = df.groupby("filename").indices
    unmatched = [name for name in audio_files if len(rows_by_file.get(name, ())) != 1]
    if unmatched:
        raise ValueError(
            f"{len(unmatched)} file(s) do not match exactly one row of "
            f"df['filename'], e.g. {unmatched[:5]}"
        )
    labels = df[label_column]

    model = wespeaker.load_model_local(pretrain_dir)
    model.set_device(device)

    embeddings = []
    for audio_file in tqdm(audio_files):
        file_path = os.path.join(data_dir, audio_file)

        # NOTE(review): soundfile returns floating-point samples (float64 by
        # default); confirm this amplitude scale is what the WeSpeaker front
        # end expects for its fbank features.
        data, sample_rate = sf.read(file_path)
        pcm = torch.from_numpy(data).float()
        if pcm.dim() == 1:
            pcm = pcm.unsqueeze(0)  # mono (frames,) -> (1, frames)
        elif pcm.dim() == 2:
            pcm = pcm.transpose(0, 1)  # (frames, channels) -> (channels, frames)

        embedding = model.extract_embedding_from_pcm(pcm, sample_rate)

        embeddings.append({
            "file_path": file_path,
            "embedding": embedding.detach().cpu().numpy(),
            # Exactly one position is guaranteed by the validation above.
            "label": labels.iloc[rows_by_file[audio_file]].item(),
        })

    return embeddings

In [None]:
# Embed every training clip; extract_embeddings loads the WeSpeaker model itself.
train_embeddings = extract_embeddings(
    df=df_train,
    audio_files=train_audio_files,
    device=device,
    pretrain_dir=pretrain_dir,
)

In [None]:
# Embed every test clip with the same model directory and device as the train split.
test_embeddings = extract_embeddings(
    df=df_test,
    audio_files=test_audio_files,
    device=device,
    pretrain_dir=pretrain_dir,
)

In [35]:
def save_to_chromadb(embeddings, db_path, split,
                     collection_name="nationality_embeddings", batch_size=1000):
    """
    Store embeddings and their metadata in a persistent ChromaDB collection.

    Records are written with ids ``"{split}_{i}"`` (``i`` = position in
    ``embeddings``) and metadata ``{"file_path", "label", "split"}``.

    Parameters
    ----------
    embeddings : list of dict
        Records with ``"file_path"``, ``"embedding"`` and ``"label"`` keys.
    db_path : str
        Directory of the persistent ChromaDB store.
    split : str
        Split name, used as the id prefix and stored in the metadata.
    collection_name : str, optional
        Target collection, created if missing (default ``"nationality_embeddings"``).
    batch_size : int, optional
        Maximum number of records per ``collection.add`` call. ChromaDB limits
        how many records a single insert may carry, so large splits are
        written in chunks.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.

    Notes
    -----
    This uses ``collection.add``. Re-running it does not refresh vectors
    already stored under the same ids. NOTE(review): confirm the behaviour for
    the installed chromadb version, and use ``collection.upsert`` if re-runs
    should overwrite.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    ids = [f"{split}_{i}" for i in range(len(embeddings))]
    # Plain Python lists are accepted by every chromadb release, whereas
    # support for numpy arrays in `add` has varied between versions.
    vectors = [np.asarray(item["embedding"]).tolist() for item in embeddings]
    metadatas = [
        {"file_path": item["file_path"], "label": item["label"], "split": split}
        for item in embeddings
    ]

    client = chromadb.PersistentClient(path=db_path)
    collection = client.get_or_create_collection(name=collection_name)

    # An empty split makes no `add` call at all.
    for start in range(0, len(ids), batch_size):
        stop = start + batch_size
        collection.add(
            ids=ids[start:stop],
            embeddings=vectors[start:stop],
            metadatas=metadatas[start:stop],
        )

In [38]:
# Both splits go into the same persistent store and collection; save_to_chromadb
# tells them apart by the "train_"/"test_" id prefix and the "split" metadata.
chroma_dir = r"C:\Users\Nastya\Desktop\interp_dev-main\Chroma"
save_to_chromadb(embeddings=train_embeddings, db_path=chroma_dir, split="train")
save_to_chromadb(embeddings=test_embeddings, db_path=chroma_dir, split="test")