In [6]:
from pathlib import Path
from typing import List

from more_itertools import chunked
import pandas as pd
import torch

from src.config.project_paths import get_data_file_path, get_model_save_dir, get_project_root_path
from speechbrain.inference.speaker import EncoderClassifier
from src.embedding.create_embedding import batch_create_speechbrain_embedding
from src.embedding.embedded_audio import EmbeddedAudio
from tqdm.auto import tqdm
import pickle

In [7]:
# Device used for embedding creation: a CUDA GPU when one is available, otherwise the CPU.
DEVICE: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pretrained model to download and use for embedding creation. Must be a SpeechBrain model
# compatible with the EncoderClassifier class.
MODEL_NAME: str = "speechbrain/spkrec-ecapa-voxceleb"
# Path to a JSON file (a list of records) whose "wav_path" field holds project-relative audio file paths.
ANNOTATION_PATH: Path = get_data_file_path("annotations_with_metadata.json")
# How many audio files to pass through the model at the same time. Must be >= 1.
BATCH_SIZE: int = 1
# Fail fast on an invalid batch size here, instead of with a ZeroDivisionError or an empty
# embedding run several cells later.
if BATCH_SIZE < 1:
    raise ValueError(f"BATCH_SIZE must be >= 1, got {BATCH_SIZE}")
# Once all the embeddings are calculated, they are pickled to this path. '/' in the model name is
# replaced with '-' so the name can be used inside a single file name.
EMBEDDING_PICKLE_PATH: Path = get_data_file_path(f"raw_audio_embeddings_{MODEL_NAME.replace('/', '-')}.pkl")

In [8]:
# Load the pretrained SpeechBrain speaker encoder named by MODEL_NAME, storing its files in the
# project's model save directory for that model, and run it on DEVICE.
model_save_dir = get_model_save_dir(MODEL_NAME)
model = EncoderClassifier.from_hparams(
    source=MODEL_NAME,
    savedir=model_save_dir,
    run_opts={"device": str(DEVICE)},
)

In [9]:
# Read the annotation records and extract the project-relative audio file paths.
annotation_df = pd.read_json(ANNOTATION_PATH, orient="records")
rel_audio_paths: List[str] = annotation_df["wav_path"].tolist()
# Group the paths into lists of at most BATCH_SIZE items. NOTE: more_itertools.chunked returns a
# lazy iterator, so this object can only be iterated over once.
rel_audio_path_batches = chunked(rel_audio_paths, n=BATCH_SIZE)

In [None]:
# Ceiling division: every full batch counts once, plus one more for any trailing partial batch.
full_batches, leftover_paths = divmod(len(rel_audio_paths), BATCH_SIZE)
number_of_batches = full_batches + 1 if leftover_paths > 0 else full_batches
print(f"Number of batches: {number_of_batches}")

In [None]:
# Embed every annotated audio file, BATCH_SIZE files at a time, pairing each embedding with its
# project-relative path.
# The batches are re-created from `rel_audio_paths` here instead of reusing a `chunked` iterator
# built in an earlier cell: that iterator is one-shot, so re-running this cell would otherwise
# silently produce an empty `embeddings` list.
project_root = get_project_root_path()
embeddings = []
for rel_audio_path_batch in tqdm(chunked(rel_audio_paths, BATCH_SIZE), desc="Creating embeddings",
                                 total=number_of_batches):
    abs_audio_path_batch = [project_root / rel_path for rel_path in rel_audio_path_batch]
    audio_embeddings = batch_create_speechbrain_embedding(model, abs_audio_path_batch)
    embedded_audio_list = [EmbeddedAudio(audio_rel_path=audio_rel_path, embedding=embedding) for
                           audio_rel_path, embedding in zip(rel_audio_path_batch, audio_embeddings)]
    # zip() stops at the shorter input; fail loudly rather than silently dropping audio files.
    if len(embedded_audio_list) != len(rel_audio_path_batch):
        raise RuntimeError(
            f"Got {len(embedded_audio_list)} embeddings for a batch of {len(rel_audio_path_batch)} "
            f"audio files: {rel_audio_path_batch}"
        )

    embeddings.extend(embedded_audio_list)

In [14]:
# Persist all computed embeddings.
# Guard: never overwrite the pickle with an incomplete result (e.g. an empty list from re-running
# the embedding loop with an already-exhausted batch iterator).
if len(embeddings) != len(rel_audio_paths):
    raise RuntimeError(
        f"Expected {len(rel_audio_paths)} embeddings but have {len(embeddings)}; "
        f"not writing {EMBEDDING_PICKLE_PATH}"
    )
# Write-only binary mode is enough; the file is never read back through this handle.
# NOTE: pickle.load on this file can execute arbitrary code — only load pickles you produced yourself.
with open(EMBEDDING_PICKLE_PATH, "wb") as output_file:
    pickle.dump(embeddings, output_file)