In [None]:
import chromadb
import numpy as np
import pandas as pd
import soundfile as sf
import torch
import wespeaker

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Read each split's CSV manifest and pull its 'audio_path' column out as a
# plain Python list, keeping the DataFrame around for later label lookups.
df_train = pd.read_csv('dataset/train.csv')
train_audio_files = df_train['audio_path'].tolist()

df_test = pd.read_csv('dataset/test.csv')
test_audio_files = df_test['audio_path'].tolist()

In [85]:
def extract_embeddings(df, audio_files, device, pretrain_dir):
    """
    Extracts embeddings from audio files using the WeSpeaker model.

    Relies on ``soundfile`` being imported as ``sf`` in the notebook's
    import cell.

    Args:
        df: DataFrame with an 'audio_path' column and a 'class' column; the
            label of each file is taken from the first row whose
            'audio_path' equals the file's path.
        audio_files: iterable of audio file paths to process.
        device: device handed to ``model.set_device``.
        pretrain_dir: directory handed to ``wespeaker.load_model_local``.

    Returns:
        A list with one dict per input file, in input order, with keys
        'file_path' (str), 'embedding' (NumPy array) and 'label' (native
        Python value where the column held a NumPy scalar).

    Raises:
        IndexError: if any path in ``audio_files`` has no matching row in
            ``df``. Raised before the model is loaded so that a bad
            manifest fails fast instead of midway through extraction.
    """
    # Materialize once: the paths are traversed twice (validation + extraction).
    audio_files = list(audio_files)

    # Build the path -> label map a single time instead of scanning the whole
    # frame with a boolean mask for every file. `setdefault` keeps the FIRST
    # matching row, which is what the previous `.values[0]` lookup returned.
    labels_by_path = {}
    for path, label in zip(df['audio_path'], df['class']):
        labels_by_path.setdefault(path, label)

    unlabeled = [str(path) for path in audio_files if path not in labels_by_path]
    if unlabeled:
        # IndexError is kept for compatibility with the earlier `.values[0]`
        # lookup, which raised it when no row matched.
        raise IndexError(
            f"no 'class' row found in df for {len(unlabeled)} audio file(s), "
            f"e.g. {unlabeled[:5]}"
        )

    model = wespeaker.load_model_local(pretrain_dir)
    model.set_device(device)

    embeddings = []

    for file_path in audio_files:
        data, sample_rate = sf.read(file_path)

        pcm = torch.from_numpy(data).float()

        if len(pcm.shape) == 1:
            pcm = pcm.unsqueeze(0)  # mono: add a leading channel dimension
        elif len(pcm.shape) == 2:
            # multi-channel: sf.read yields (frames, channels); swap the axes
            pcm = pcm.transpose(0, 1)

        embedding = model.extract_embedding_from_pcm(pcm, sample_rate)
        embedding = embedding.cpu().numpy()

        label = labels_by_path[file_path]
        if isinstance(label, np.generic):
            # NumPy scalars are not plain int/float/str; unwrap them so the
            # label can be used directly as database metadata.
            label = label.item()

        embeddings.append({
            'file_path': str(file_path),
            'embedding': embedding,
            'label': label
        })

    return embeddings


def save_to_chromadb(embeddings, db_path, split, batch_size=1000):
    """
    Stores embeddings in ChromaDB.

    Opens (or creates) a persistent Chroma database at ``db_path`` and writes
    the items into the "gender_embeddings" collection. Records are written
    with ``upsert`` so that re-running the cell refreshes existing ids
    instead of colliding with them.

    Args:
        embeddings: list of dicts with 'file_path', 'embedding' and 'label'
            keys. An empty list is a no-op.
        db_path: filesystem path of the persistent Chroma database.
        split: split name; stored in each record's metadata and used as the
            id prefix (ids are "<split>_<position in embeddings>").
        batch_size: maximum number of records sent per upsert call, to stay
            below Chroma's per-call batch limit.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if not embeddings:
        # Nothing to write; avoid creating the database for an empty split.
        return

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    client = chromadb.PersistentClient(path=db_path)
    collection = client.get_or_create_collection(name="gender_embeddings")

    for start in range(0, len(embeddings), batch_size):
        batch = embeddings[start:start + batch_size]

        metadatas = []
        for item in batch:
            label = item['label']
            if isinstance(label, np.generic):
                # Metadata values must be plain Python scalars.
                label = label.item()
            metadatas.append({
                "file_path": item['file_path'], "label": label,
                "split": split
            })

        collection.upsert(
            # Ids use the global position so they match across batch sizes.
            ids=[f"{split}_{start + offset}" for offset in range(len(batch))],
            # Plain lists of Python numbers are accepted by every Chroma
            # version, unlike NumPy arrays.
            embeddings=[np.asarray(item['embedding']).tolist() for item in batch],
            metadatas=metadatas
        )

In [None]:
# Local directory holding the pretrained WeSpeaker model files.
pretrain_dir = 'voxblink2_samresnet34'

# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

test_embeddings = extract_embeddings(
    df_test,
    test_audio_files,
    device,
    pretrain_dir,
)

In [87]:
# Sanity check: number of records produced for the test split.
len(test_embeddings)

300

In [None]:
# Extract embeddings for the training split, reusing the device and model
# directory chosen for the test split.
train_embeddings = extract_embeddings(df_train, train_audio_files, device, pretrain_dir)

In [89]:
# Sanity check: number of records produced for the training split.
len(train_embeddings)

1800

In [90]:
# Persist both splits into the same on-disk database, train first; the split
# name is passed through so records can be told apart afterwards.
for split_name, split_embeddings in (
    ('train', train_embeddings),
    ('test', test_embeddings),
):
    save_to_chromadb(split_embeddings, 'chroma_db', split_name)