In [188]:
import os
import uuid
from typing import Sequence

import torch
from clickhouse_connect.driver import Client
from pydub import AudioSegment
import numpy as np
import math
import clickhouse_connect
from loguru import logger
from transformers import Wav2Vec2Processor, Wav2Vec2Model

In [149]:
def normalize(arr: np.array) -> np.array:
    """
    Normalize a numpy array.

    This function takes a numpy array and returns a normalized version of the array.
    Normalization scales the array so that its norm (length) is 1, while maintaining the 
    direction of the vector.

    Parameters
    ----------
    arr : np.array
        The input numpy array to be normalized.

    Returns
    -------
    np.array
        The normalized numpy array. If the norm of the input array is 0, the original array 
        is returned unchanged.

    Examples
    --------
    >>> import numpy as np
    >>> arr = np.array([1, 2, 3])
    >>> normalize(arr)
    array([0.26726124, 0.53452248, 0.80178373])
    
    >>> arr = np.array([0, 0, 0])
    >>> normalize(arr)
    array([0, 0, 0])
    """
    norm = np.linalg.norm(arr)
    if norm == 0:
        return arr
    return arr / norm

In [46]:
def split_on_chunk(path2mp3: str, chunk_size: int = 1024, sample_rate: int = 16000) -> AudioSegment:
    """
    Splits an MP3 audio file into chunks of a specified size and returns them as a generator.

    Args:
        path2mp3 (str): The file path to the MP3 audio file.
        chunk_size (int, optional): The size of each chunk in milliseconds. Defaults to 1024 ms.
        sample_rate (int, optional): The sample rate to set for the audio in Hz. Defaults to 16000 Hz.

    Yields:
        AudioSegment: A chunk of the audio as an AudioSegment object.

    Example:
        for chunk in split_on_chunk("example.mp3", chunk_size=5000, sample_rate=22050):
            # Process each chunk
            pass

    Note:
        - The audio is converted to mono and the specified sample rate before splitting.
        - The last chunk may be shorter than the specified chunk size if the total length of the audio
          is not a multiple of chunk_size.
    """
    audio = AudioSegment.from_mp3(path2mp3)
    audio = audio.set_frame_rate(sample_rate).set_channels(1)
    
    total_length = len(audio)
    num_chunks = math.ceil(total_length / chunk_size)
    
    for i in range(num_chunks):
        start_time = i * chunk_size
        end_time = min((i + 1) * chunk_size, total_length)
        chunk = audio[start_time:end_time]
        yield chunk

In [49]:
def get_embedding(chunk: AudioSegment, processor: Wav2Vec2Processor, model: Wav2Vec2Model) -> torch.Tensor:
    """
    Extracts an embedding from an audio chunk using a Wav2Vec2 model.

    Args:
        chunk (AudioSegment): The audio chunk to process, represented as an AudioSegment object.
        processor (Wav2Vec2Processor): The Wav2Vec2 processor to preprocess the audio chunk.
        model (Wav2Vec2Model): The pre-trained Wav2Vec2 model to generate the embedding.

    Returns:
        torch.Tensor: The embedding generated by the Wav2Vec2 model.

    Example:
        processor = Wav2Vec2Processor.from_pretrained('facebook/wav2vec2-base-960h')
        model = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base-960h')
        embedding = get_embedding(audio_chunk, processor, model)

    Note:
        - The audio chunk is converted to a numpy array of float32 samples.
        - The processor converts the audio array to the appropriate input format for the model.
        - The model's last hidden state is returned as the embedding.
    """
    audio_array = np.array(chunk.get_array_of_samples(), dtype=np.float32)
    audio_input = processor(audio_array, sampling_rate=chunk.frame_rate, return_tensors="pt", padding=True).input_values
    with torch.no_grad():    
        embedding = model(audio_input).last_hidden_state
    return embedding

In [177]:
def save_embedding(embedding: np.array, title: str, client: Client) -> None:
    """
    Saves an embedding to a database.

    Args:
        embedding (np.array): The embedding to save, represented as a NumPy array.
        title (str): The title or label associated with the embedding.
        client (Client): The database client used to execute the insert command.

    Returns:
        None

    Example:
        client = get_database_client()
        save_embedding(embedding, "sample_title", client)

    Note:
        - A unique identifier (UUID) is generated for each embedding entry.
        - The embedding is converted to a list before being saved to the database.
        - The client must have a method `command` that executes SQL commands.

    Raises:
        Any exceptions raised by the `client.command` method will propagate to the caller.
    """
    id = uuid.uuid4()
    client.command('INSERT INTO `embedding` (`id`, `vector`, `title`) VALUES (%s, %s, %s)', (id, embedding.tolist(), title))
    

In [189]:
def search(embedding: np.array, top_k: int, client: Client) -> Sequence[Sequence]:
    """
    Searches for the top-k most similar embeddings in the database based on cosine similarity.

    Args:
        embedding (np.array): The query embedding to search for, represented as a NumPy array.
        top_k (int): The number of top similar embeddings to return.
        client (Client): The database client used to execute the search query.

    Returns:
        Sequence[Sequence]: A sequence of results, where each result is a sequence containing the ID, title, 
                            and cosine similarity of the matching embeddings.

    Example:
        client = get_database_client()
        results = search(query_embedding, top_k=5, client=client)
        for result in results:
            print(result)

    Note:
        - The query is constructed using the embedding converted to a list.
        - The cosine similarity is calculated to determine the similarity between embeddings.
        - The client must have a method `query` that executes the search query.

    Raises:
        Any exceptions raised by the `client.query` method will propagate to the caller.
    """
    query = f'''
    WITH {embedding.tolist()} as query_vector
    SELECT id, title, 
    arraySum(x -> x * x, vector) * arraySum(x -> x * x, query_vector) != 0
    ? arraySum((x, y) -> x * y, vector, query_vector) / sqrt(arraySum(x -> x * x, vector) * arraySum(x -> x * x, query_vector))
    : 0 AS cosine_distance
    FROM embedding
    WHERE length(query_vector) == length(vector)
    ORDER BY cosine_distance DESC 
    LIMIT {top_k}
    '''
    
    result = client.query(query, settings={"max_query_size": "10000000000000"})
    return result.result_rows

## Initialize Wav2Vec2 Processor and Model

Initialize the Wav2Vec2Processor and Wav2Vec2Model using pre-trained models from Facebook.

In [13]:
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2Model: ['lm_head.bias', 'lm_head.weight']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Connect to ClickHouse

Connect to a ClickHouse database client running on localhost at port 8123.

In [15]:
client = clickhouse_connect.get_client(host='localhost', port=8123)

## List Files in Dataset Directory

List all files in the "dataset" directory.

In [146]:
files = os.listdir("dataset")
files

['posledina_lubov.mp3', 'den-h.mp3', 'test.mp3', 'cvetok.mp3']

## Save Embeddings for Each File

Loop through each file in the dataset directory, split it into chunks, compute embeddings, normalize them, and save the embeddings to ClickHouse.

In [180]:
for file in files:
    logger.debug(f"starts saving the file {file}")
    path = os.path.join("dataset", file)
    for index, chunk in enumerate(split_on_chunk(path)):
        embedding = get_embedding(chunk, processor, model)
        embedding = embedding.squeeze().flatten().numpy()
        embedding = normalize(embedding)
        save_embedding(embedding, file, client)

[32m2024-06-09 20:04:48.469[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [34m[1mstarts saving the file posledina_lubov.mp3[0m
[32m2024-06-09 20:05:05.563[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [34m[1mstarts saving the file den-h.mp3[0m
[32m2024-06-09 20:05:23.414[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [34m[1mstarts saving the file test.mp3[0m
[32m2024-06-09 20:05:38.877[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [34m[1mstarts saving the file cvetok.mp3[0m


## Search Embeddings

For a given test file, split it into chunks, compute embeddings, and search for the top 2 nearest embeddings in the database.

In [None]:
for chunk in split_on_chunk("dataset/test.mp3", chunk_size=2024):
    embedding = get_embedding(chunk, processor, model)
    embedding = embedding.squeeze().flatten().numpy()
    searched = search(embedding , 2, client)
    print(searched[0])
