In [None]:
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
from datasets import load_dataset, Audio
import numpy as np
import torch
import chromadb
import json

In [None]:
# Read the shared project settings once; only the parsing needs the open file handle.
with open("../config.json", mode = "r") as f:
    data = json.load(f)

# Audio settings. SAMPLING_RATE is the rate the audio column is cast to and the
# rate passed to the feature extractor further down.
SAMPLING_RATE = data["sampling_rate"]
# NOTE(review): segment/overlap lengths are only referenced by the commented-out
# chunking experiment in this notebook; units are presumably seconds -- confirm.
SEGMENT_LEN = data["segment_length"]
OVERLAP_LEN = data["overlap_length"]

# ChromaDB location and target collection.
DB_PATH = data["database_path"]
COLLECTION_NAME = data["collection_name"]

In [None]:
# The feature extractor and the classifier are both restored from the same
# fine-tuning checkpoint, so the path is named once.
CHECKPOINT_PATH = "checkpoints-15-5/checkpoint-32094"

fineTunedExtractor = AutoFeatureExtractor.from_pretrained(CHECKPOINT_PATH)
fineTunedModel = AutoModelForAudioClassification.from_pretrained(
    CHECKPOINT_PATH,
    device_map = "cuda",  # place the model weights on the GPU
)

In [None]:
# Load the bird-recording dataset by its Hub identifier. No split is requested,
# so `dataset` holds whatever load_dataset returns for the whole repository.
dataset = load_dataset(path = "Saads/xecanto_birds")

#### Chunking

In [None]:
# def chunk_audio_fine_tuned(audio_array, chunk_length = 15, overlap = 5):
#     chunk_length = chunk_length * SAMPLING_RATE
#     overlap = overlap * SAMPLING_RATE
    
#     chunks = []
#     start = 0
#     while start + chunk_length <= len(audio_array):
#         chunks.append(audio_array[start : start + chunk_length])
#         start += (chunk_length - overlap)
    
#     # if start < len(audio_array):
#     #     last_chunk = audio_array[start:]
#     #     padded_last_chunk = np.pad(last_chunk, (0, chunk_length - len(last_chunk)))
#     #     chunks.append(padded_last_chunk)
    
#     return chunks

In [None]:
# def preprocess_fine_tuned(row):
#     chunks = chunk_audio_fine_tuned(row["audio"]["array"])
#     row["input_values"] = []
#     if(chunks):
#         inputs = fineTunedExtractor(chunks, sampling_rate = SAMPLING_RATE, return_tensors = "pt")
#         row["input_values"] = inputs["input_values"]
#     return row

In [None]:
# dataset = dataset.cast_column("audio", Audio(sampling_rate = SAMPLING_RATE))
# dataset = dataset.map(
#     preprocess_fine_tuned,
#     remove_columns = "audio",
#     batched = False,
#     num_proc = 16,
#     writer_batch_size = 200
# )

#### Whole Audio

In [None]:
def preprocess_fine_tuned(batched_data):
    """Run the fine-tuned feature extractor over one batch of recordings.

    Args:
        batched_data: Batch mapping whose ``"audio"`` entry is a sequence of
            items, each exposing its waveform under the ``"array"`` key.

    Returns:
        Whatever ``fineTunedExtractor`` returns for the collected waveforms
        when called with ``sampling_rate = SAMPLING_RATE``.
    """
    waveforms = []
    for sample in batched_data["audio"]:
        waveforms.append(sample["array"])
    return fineTunedExtractor(waveforms, sampling_rate = SAMPLING_RATE)

In [None]:
# 1) Re-cast the audio column to SAMPLING_RATE, the same rate that
#    preprocess_fine_tuned passes to the feature extractor.
# 2) Run the feature extractor over the data in batches and drop the raw
#    "audio" column from the result.
dataset = (
    dataset
    .cast_column("audio", Audio(sampling_rate = SAMPLING_RATE))
    .map(
        preprocess_fine_tuned,
        batched = True,
        batch_size = 32,
        remove_columns = "audio",
        num_proc = 16,
        writer_batch_size = 150,
    )
)

In [None]:
# chroma_client = chromadb.Client()
# Open the on-disk Chroma database located at DB_PATH.
chroma_client = chromadb.PersistentClient(path = DB_PATH)
# get_or_create_collection keeps this cell re-runnable: with a persistent store,
# create_collection raises as soon as the collection exists from an earlier run.
collection = chroma_client.get_or_create_collection(name = COLLECTION_NAME)

In [None]:
# chroma_client.delete_collection(COLLECTION_NAME)

#### Chunked

In [None]:
# def add_embedding_chromaDB(row, index):
#     for subIdx, chunk in enumerate(row["input_values"]):
#         metadataDict = {"name": row["common_name"], "url": row["url"]}
#         inputs = torch.tensor(chunk)
#         with torch.no_grad():
#             outputs = fineTunedModel(inputs.unsqueeze(0), output_hidden_states = True)
#             logits = outputs.logits
#             hidden_states = outputs.hidden_states

#         probabilities = torch.nn.functional.softmax(logits, dim = -1)
#         values, idxs = torch.topk(probabilities, k = 5)
#         values = values.numpy()[0]
#         idxs = idxs.numpy()[0]
#         for i in range(len(idxs)):
#             metadataDict[f"pred_name_{i + 1}"] = fineTunedModel.config.id2label[idxs[i]]
#             metadataDict[f"pred_prob_{i + 1}"] = values[i].item() * 100
        
#         embeddings = hidden_states[-1]
#         embeddings = embeddings.mean(dim = 1)[0].numpy()
        
#         collection.add(
#             embeddings = [embeddings],
#             metadatas = [metadataDict],
#             ids = [f"{index}_{subIdx}"]
#         )

In [None]:
# datasetEmbeddings.map(
#     add_embedding_chromaDB,
#     batched = False,
#     num_proc = 1,
#     writer_batch_size = 1000,
#     with_indices = True
# )

#### Whole Audio

In [None]:
def add_embedding_chromaDB(row, index):
    """Embed one preprocessed recording and store it in the Chroma collection.

    Runs ``fineTunedModel`` on the row's ``input_values``, records the top
    predicted labels (at most 5) and their probabilities (in percent) as
    metadata, and adds the mean over dim 1 of the last hidden state as the
    embedding under the id ``id_{index}``.

    Args:
        row: Mapping with ``"common_name"``, ``"url"`` and ``"input_values"``.
            ``input_values`` is assumed to be a single un-batched example as
            produced by the feature extractor -- TODO confirm.
        index: Value used to build the Chroma id for this row.

    Returns:
        None. The function is called for its side effect on ``collection``.
    """
    metadataDict = {"name": row["common_name"], "url": row["url"]}

    # Follow the model's own device instead of hard-coding "cuda", so the
    # input always lands where the weights are.
    device = fineTunedModel.device
    inputs = torch.tensor(row["input_values"]).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = fineTunedModel(inputs, output_hidden_states = True)
        logits = outputs.logits
        hidden_states = outputs.hidden_states

    probabilities = torch.nn.functional.softmax(logits, dim = -1)
    # torch.topk raises when k exceeds the number of classes, so clamp it.
    top_k = min(5, probabilities.shape[-1])
    values, idxs = torch.topk(probabilities, k = top_k)
    # Index 0 selects the only item in the batch; tolist() yields plain
    # Python floats/ints, usable as metadata values and id2label keys.
    values = values[0].tolist()
    idxs = idxs[0].tolist()
    for rank, (label_id, prob) in enumerate(zip(idxs, values), start = 1):
        metadataDict[f"pred_name_{rank}"] = fineTunedModel.config.id2label[label_id]
        metadataDict[f"pred_prob_{rank}"] = prob * 100

    # Average the final hidden state over dim 1 (presumably the frame/time
    # axis -- confirm) to get one fixed-size vector for the recording.
    embeddings = hidden_states[-1].mean(dim = 1)[0].cpu().tolist()

    collection.add(
        embeddings = [embeddings],
        metadatas = [metadataDict],
        ids = [f"id_{index}"]
    )

    # Release the input tensor and, on GPU, hand cached blocks back to CUDA.
    del inputs
    if device.type == "cuda":
        torch.cuda.empty_cache()

In [None]:
# Push every preprocessed row through add_embedding_chromaDB. The mapped function
# returns nothing; map is used purely to iterate the rows with their indices.
# `dataset` is the preprocessed output of the earlier map cell (the previous name
# `datasetEmbeddings` was never defined in this notebook).
# NOTE(review): if `dataset` contains several splits, indices restart per split and
# the generated "id_{index}" values would collide -- confirm the split layout.
dataset.map(
    add_embedding_chromaDB,
    batched = False,
    num_proc = 1,
    writer_batch_size = 1000,
    with_indices = True,
    # A cached map result would skip calling the function, and with it the
    # writes to the collection, so always execute it.
    load_from_cache_file = False
)