# QVIM-AES Submission Template

This is the submission template for the Query by Vocal Imitation challenge at the 2025 AES International Conference on Artificial Intelligence and Machine Learning for Audio.

The content of this notebook is inspired by the template provided by the task organizers of the [Sound Scene Synthesis Task of the DCASE Challenge 2024](https://dcase.community/challenge2024/task-sound-scene-synthesis).

<div class="alert alert-block alert-warning"> 
<b>Confidentiality Statement</b><br> As the organizers of this contest, we assure all participants that their submitted models and code will be treated with strict confidentiality. Submissions will only be accessed by the designated review team for evaluation purposes and will not be shared, distributed, or used beyond the scope of this challenge. Participants retain full ownership of their work. We will not claim any rights over the submitted materials, nor will we use them for any purpose outside of the challenge evaluation process. We appreciate your participation in this challenge.
</div>

#### How to create your submission
- Get familiar with the existing code blocks and the example provided below.
- Set the root path of your environment and your dataset below ("TODO: DEFINE YOUR PATHS HERE.").
- Set up your project ("TODO: SETUP YOUR PROJECT HERE.").
- Implement the retrieval interface below ("TODO: ADD YOUR IMPLEMENTATION HERE.").
    - Use the provided helper functions (helpers) to download your source code, model checkpoints, etc.
- Instantiate your retrieval model ("TODO: INSTANTIATE YOUR MODEL HERE.").
- Before **submitting your notebook**, run this notebook in a clean conda environment (with Python >= 3.10) on Ubuntu 24.04 and make sure the evaluation results are in line with your previous results.
- Submit your notebooks and the technical report as described on our [website](https://qvim-aes.github.io/).
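To check that your results are "in line with your previous results", you can summarize a `scores` dictionary with mean reciprocal rank (MRR). The sketch below is not the official evaluation script. It assumes each query has exactly one correct item, supplied through a hypothetical `ground_truth` dict that maps query IDs to item IDs.

```python
def mean_reciprocal_rank(
    scores: dict[str, dict[str, float]],
    ground_truth: dict[str, str],
) -> float:
    """MRR over queries, assuming exactly one relevant item per query.

    scores: {query_id: {item_id: score}}, higher score = more similar.
    ground_truth: hypothetical {query_id: correct_item_id} mapping.
    """
    reciprocal_ranks = []
    for query_id, target_id in ground_truth.items():
        item_scores = scores[query_id]
        # Sort item IDs from most to least similar; list position 0 is rank 1.
        ranked = sorted(item_scores, key=item_scores.get, reverse=True)
        reciprocal_ranks.append(1.0 / (ranked.index(target_id) + 1))
    return sum(reciprocal_ranks) / len(reciprocal_ranks)


# Usage with toy scores: q1's target is ranked 1st, q2's target 2nd.
toy_scores = {"q1": {"a": 0.9, "b": 0.1}, "q2": {"a": 0.7, "b": 0.2}}
print(mean_reciprocal_rank(toy_scores, {"q1": "a", "q2": "b"}))  # 0.75
```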

##### Some Rules
- DO NOT modify the other code cells.
- DO NOT add new cells.
- Store your project WITHIN 'ROOT_PATH' and your data within 'DATA_PATH'.
- DO NOT use 'ROOT_PATH/output' folder; this is where we will store things.
- DO NOT change the working directory (e.g., `os.chdir('/path/to/a/dir/that/does/not/exist/on/my/machine')`).
- DO NOT use system commands (`!cd ~` or `os.system('cd ~')`, etc.) other than the ones used to set up your environment (i.e., install required packages with pip, conda, ...).
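To double-check the storage rules, you can use a small helper that flags any path resolving outside `ROOT_PATH` or inside `ROOT_PATH/output`. The function names below are hypothetical and not part of the template. The helper calls `os.path.realpath`, so `..` segments and symbolic links are resolved before paths are compared.

```python
import os


def is_inside(path: str, root: str) -> bool:
    """True if `path` resolves to `root` itself or to a location below it."""
    path_abs = os.path.realpath(path)
    root_abs = os.path.realpath(root)
    # commonpath compares whole path components, so "/proj2" is not inside "/proj".
    return os.path.commonpath([path_abs, root_abs]) == root_abs


def check_layout(root_path: str, paths: list[str]) -> list[str]:
    """Return the paths that break the rules: outside root_path, or inside root_path/output."""
    output_dir = os.path.join(root_path, "output")
    return [p for p in paths if not is_inside(p, root_path) or is_inside(p, output_dir)]


# Usage with hypothetical paths: only the checkpoint under resources/ passes.
print(check_layout("/srv/project", [
    "/srv/project/resources/model.ckpt",
    "/srv/project/output/tmp.wav",
    "/srv/project2/data.wav",
]))
```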

<div class="alert alert-block alert-danger"> 
Participants who submit malicious code will be disqualified.
</div>
    

In [1]:
"""
DO NOT MODIFY THIS BLOCK.
"""
# Install basic packages for template notebook.
!pip install librosa numpy pandas tqdm GitPython gdown==5.1.0


In [2]:
"""
DO NOT MODIFY THIS BLOCK.
"""
# some imports
import sys
import os

from abc import ABC, abstractmethod
from tqdm import tqdm
import numpy as np
import pandas as pd


## Description of the Retrieval Interface 
`QVIMModel` is the interface specification for all query by vocal imitation systems. Each submitted system is expected to subclass this interface and implement the `compute_similarities` method, which computes the similarities between all pairwise combinations of queries (vocal imitations) and items (reference sounds).

`compute_similarities` takes two dictionaries as input:
- `items` maps the IDs of the items to be retrieved (reference sounds) to their file paths.
- `queries` maps query IDs (vocal imitations) to their file paths.

Participants are expected to load the sounds themselves, e.g., with `librosa.load`.
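Clips differ in length. The implementation in Step 3 therefore loads each file with `librosa.load(file_path, sr=..., mono=True, duration=...)` and then zero-pads or truncates the waveform to a fixed number of samples. Here is a NumPy-only sketch of that pad/crop step on its own. The name `pad_or_truncate` is ours; librosa also provides a `librosa.util.fix_length` utility for the same purpose.

```python
import numpy as np


def pad_or_truncate(audio: np.ndarray, sample_rate: int, duration: float) -> np.ndarray:
    """Return a float32 copy of `audio` with exactly int(sample_rate * duration) samples.

    Shorter signals are zero-padded at the end; longer ones are cut.
    """
    target_len = int(sample_rate * duration)
    out = np.zeros(target_len, dtype=np.float32)
    n = min(len(audio), target_len)
    out[:n] = audio[:n]
    return out


# Usage: a 1.5 s clip at 32 kHz padded to the 10 s used by the baseline config.
clip = np.random.default_rng(0).standard_normal(48_000).astype(np.float32)
fixed = pad_or_truncate(clip, sample_rate=32_000, duration=10.0)
print(fixed.shape)  # (320000,)
```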

In [3]:
"""
DO NOT MODIFY THIS BLOCK.
"""

class QVIMModel(ABC):

    @abstractmethod
    def compute_similarities(
            self, items: dict[str, str], queries: dict[str, str]
    ) -> dict[str, dict[str, float]]:
        """Compute similarity scores between items to be retrieved and a set of queries.

        Each <query, item> pairing should be assigned a single floating point score, where higher
        scores indicate higher similarity.

        Args:
            items (dict[str, str]): A dictionary mapping ids of items to be retrieved to the corresponding file path
            queries (dict[str, str]): A dictionary mapping query ids to the corresponding file path

        Returns:
            scores (dict[str, dict[str, float]]): A dictionary mapping query ids to a dictionary of item
                ids and their corresponding similarity scores. E.g:
                {
                    "query_1": {
                        "item_1": 0.8,
                        "item_2": 0.6,
                        ...
                    },
                    "query_2": {
                        "item_1": 0.4,
                        "item_2": 0.9,
                        ...
                    },
                    ...
                }
        """
        pass
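The wrappers in Step 3 fill this nested dictionary with a Python double loop over `np.dot`. If all embeddings are stacked into matrices, a single matrix product computes every score at once. The sketch below is an optional alternative, not part of the template. It also L2-normalizes the rows so that scores are cosine similarities; drop that step if your embeddings are already normalized.

```python
import numpy as np


def scores_from_embeddings(
    item_ids: list[str],
    item_embs: np.ndarray,
    query_ids: list[str],
    query_embs: np.ndarray,
) -> dict[str, dict[str, float]]:
    """Return {query_id: {item_id: cosine similarity}} from stacked embeddings.

    item_embs has shape (num_items, dim) and query_embs (num_queries, dim);
    row order must match item_ids and query_ids.
    """
    def l2_normalize(x: np.ndarray) -> np.ndarray:
        norms = np.linalg.norm(x, axis=1, keepdims=True)
        return x / np.maximum(norms, 1e-12)  # guard against all-zero rows

    sim = l2_normalize(query_embs) @ l2_normalize(item_embs).T  # (num_queries, num_items)
    return {
        q_id: {i_id: float(sim[q_idx, i_idx]) for i_idx, i_id in enumerate(item_ids)}
        for q_idx, q_id in enumerate(query_ids)
    }


# Usage with toy 2-D embeddings.
scores = scores_from_embeddings(
    ["item_1", "item_2"], np.array([[1.0, 0.0], [0.0, 2.0]]),
    ["query_1", "query_2"], np.array([[3.0, 0.0], [1.0, 1.0]]),
)
print(scores["query_1"])  # {'item_1': 1.0, 'item_2': 0.0}
```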

## Some Helper Functions

`helpers.py` contains some helpful functions for downloading code and model checkpoints from Google Drive, Git and public links.

The functions were taken (with slight modifications) from the submission template provided by the task organizers of [Task 7 of the DCASE Challenge 2024: Sound Scene Synthesis](https://dcase.community/challenge2024/task-sound-scene-synthesis).

In [5]:
import helpers
from helpers import google_drive_download, wget_download, git_clone_checkout, unpack_file

# Step 1: Set up your paths

- Define `ROOT_PATH`; this is where your project lives; for testing, we'll replace it with our custom ROOT_PATH. We recommend using the current working directory ('.').
- Define `DATA_PATH`; this is where your public development data lives; for testing, we'll replace it with our custom DATA_PATH. We recommend using 'data/qvim-dev'.
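Once `DATA_PATH` is set, you can build the `items` and `queries` dictionaries for `compute_similarities` by scanning a folder of audio files. This hypothetical helper assumes the files sit directly inside one folder per role and that each file's stem serves as its ID. Adapt it to the actual layout of the development set. Because two files with the same stem would overwrite each other, the helper raises an error instead.

```python
import tempfile
from pathlib import Path

AUDIO_EXTENSIONS = (".wav", ".mp3", ".flac")


def collect_audio_files(directory: str, extensions: tuple[str, ...] = AUDIO_EXTENSIONS) -> dict[str, str]:
    """Map each audio file's stem (name without extension) to its path.

    Only files directly inside `directory` are considered; matching is case-insensitive.
    """
    mapping: dict[str, str] = {}
    for path in sorted(Path(directory).iterdir()):
        if not (path.is_file() and path.suffix.lower() in extensions):
            continue
        if path.stem in mapping:
            # e.g. "x.wav" and "x.mp3" would otherwise silently overwrite each other.
            raise ValueError(f"Duplicate ID {path.stem!r} in {directory}")
        mapping[path.stem] = str(path)
    return mapping


# Usage on a throwaway folder; in the notebook you would pass folders under DATA_PATH.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("dog_bark.wav", "siren.WAV", "notes.txt"):
        Path(tmp, name).touch()
    demo = collect_audio_files(tmp)
    print(sorted(demo))  # ['dog_bark', 'siren']
```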
    

In [6]:
"""
TODO: DEFINE YOUR PATHS HERE.
"""

# replace this with your custom ROOT_PATH; this is where your code and checkpoints will be downloaded to
ROOT_PATH = "."

# path to the evaluation data; can be in ROOT_PATH
DATA_PATH = os.path.join(ROOT_PATH, "DEVUpdatedDataset")

In [7]:
helpers.ROOT_PATH = ROOT_PATH
os.makedirs(ROOT_PATH, exist_ok=True)
os.makedirs(DATA_PATH, exist_ok=True)
sys.path.append(os.path.join(ROOT_PATH))

# Step 2: Set up your environment, download checkpoints, etc.

Set up your project and install the required packages here.
The easiest way is to:
1) convert your implementation into a package,
2) clone the repository and check out the specific branch and commit,
3) install your package with `pip install -e name_of_your_fancy_package`.


Hints:
- Make sure your link to the repository and other URLs are publicly available.
- Use **shared public URLs** (e.g. a shared Google Drive, Dropbox, Zenodo link) to download checkpoints into `ROOT_PATH`.
- Use the provided helper functions (`google_drive_download`, `wget_download`, `git_clone_checkout`, and `unpack_file`).
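Before submitting, confirm that every shared link you pass to `google_drive_download` actually contains a Drive file ID. The hypothetical helper below only parses the URL: it handles the `/file/d/<ID>/view` form used in Step 2 and the `?id=<ID>` form. It cannot tell whether a file is shared publicly, so also open each link in a private browser window.

```python
import re
from urllib.parse import parse_qs, urlparse


def drive_file_id(shared_url: str) -> str:
    """Extract the file ID from a Google Drive sharing link."""
    # Form used in Step 2: https://drive.google.com/file/d/<ID>/view?...
    match = re.search(r"/file/d/([A-Za-z0-9_-]+)", shared_url)
    if match:
        return match.group(1)
    # Alternative form: https://drive.google.com/...?id=<ID>
    ids = parse_qs(urlparse(shared_url).query).get("id")
    if ids:
        return ids[0]
    raise ValueError(f"No Google Drive file ID found in {shared_url!r}")


print(drive_file_id("https://drive.google.com/file/d/1HglQg8wTQaHzV6eSVND99hfLfUkQnLPl/view?usp=download"))
# 1HglQg8wTQaHzV6eSVND99hfLfUkQnLPl
```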

In [8]:
"""
TODO: SETUP YOUR PROJECT HERE.
"""

# Create the checkpoint directory up front so the downloads below have a target.
os.makedirs(os.path.join(ROOT_PATH, "resources"), exist_ok=True)

git_clone_checkout(
    output_dir='qvim_baseline_rp',
    url='https://github.com/RP335/qvim-baseline',
    branch='main',
    commit_sha='8222f2f4651b08a7d3a47026bce6948657c7bf2e'
)
print("Cloned RP335/qvim-baseline repository.")

!pip install speechbrain hear21passt panns_inference wandb torchaudio==2.6.0 torch==2.6.0 lightning==2.5.1 librosa==0.11.0 torchvision==0.21.0
!pip install cuda-python
print("Required libraries installed.")

# Download checkpoints (and the baseline MRR CSV) into ROOT_PATH/resources.
google_drive_download(filename="aug_baseline_latest.ckpt", shared_url="https://drive.google.com/file/d/1HglQg8wTQaHzV6eSVND99hfLfUkQnLPl/view?usp=download", relative_dir="resources")
google_drive_download(filename="mrr_values_baseline_mobilenet.csv", shared_url="https://drive.google.com/file/d/1eFL3uYAbLJmJCNMmjvS4aus1WRu47cVH/view?usp=download", relative_dir="resources")
google_drive_download(filename="passt_finetuned_1.ckpt", shared_url="https://drive.google.com/file/d/1SjyHPMjyBzSuSj1cWe2NKIOM0m3VUN3k/view?usp=download", relative_dir="resources")
google_drive_download(filename="Cnn14_mAP=0.431.pth", shared_url="https://drive.google.com/file/d/1zbWcCrF_oopLsk4il5q6oMwl2-lfhwXT/view?usp=download", relative_dir="resources")
google_drive_download(filename="panns_finetuned_2.ckpt", shared_url="https://drive.google.com/file/d/10QXhEjc0bimeonr6SDBix_MBnss6e4cy/view?usp=download", relative_dir="resources")
google_drive_download(filename="beats_finetuned_3.ckpt", shared_url="https://drive.google.com/file/d/1ryZLXpdfn6bM9qd49FQpNQys8jW0TS2r/view?usp=download", relative_dir="resources")
google_drive_download(filename="BEATs_iter3.pt", shared_url="https://drive.google.com/file/d/1NmDz4TFdf66nbxe1NK5-z2I416SIg9FH/view?usp=download", relative_dir="resources")

print(f"Setup complete. Checkpoints are expected in '{os.path.join(ROOT_PATH, 'resources')}' directory.")

Repository cloned to ./qvim_baseline_rp and checked out to main at commit 8222f2f4651b08a7d3a47026bce6948657c7bf2e.
Cloned RP335/qvim-baseline repository.

Defaulting to user installation because normal site-packages is not writeable
Required libraries installed.
Setup complete. Checkpoints are expected in './resources' directory.
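A download that fails or is interrupted may leave a checkpoint missing or empty. The error then only surfaces later, when a model is loaded. This sketch lists the files that Step 2 downloads and reports any that are absent or zero bytes. `missing_resources` is a hypothetical helper, not part of `helpers.py`.

```python
import os

# Files downloaded into ROOT_PATH/resources in Step 2.
EXPECTED_RESOURCES = (
    "aug_baseline_latest.ckpt",
    "mrr_values_baseline_mobilenet.csv",
    "passt_finetuned_1.ckpt",
    "Cnn14_mAP=0.431.pth",
    "panns_finetuned_2.ckpt",
    "beats_finetuned_3.ckpt",
    "BEATs_iter3.pt",
)


def missing_resources(resources_dir: str, filenames: tuple[str, ...] = EXPECTED_RESOURCES) -> list[str]:
    """Return the expected files that are absent or zero bytes in resources_dir."""
    missing = []
    for name in filenames:
        path = os.path.join(resources_dir, name)
        if not os.path.isfile(path) or os.path.getsize(path) == 0:
            missing.append(name)
    return missing


# Usage in this notebook (ROOT_PATH is defined in Step 1):
# missing = missing_resources(os.path.join(ROOT_PATH, "resources"))
# if missing: raise FileNotFoundError(f"Missing or empty resources: {missing}")
```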


# Step 3: Implement the QVIMModel Interface

In [9]:
"""
TODO: ADD YOUR IMPLEMENTATION HERE.
"""

import numpy as np
import torch
import librosa
import argparse
from collections import OrderedDict
from tqdm import tqdm
import os
import sys

# Determine the target device
_TARGET_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {_TARGET_DEVICE}")

PATH_TO_YOUR_REPO_SRC = os.path.abspath(os.path.join(ROOT_PATH, "qvim_baseline_rp", "src"))
if PATH_TO_YOUR_REPO_SRC not in sys.path:
    sys.path.insert(0, PATH_TO_YOUR_REPO_SRC)
    print(f"Prepended to sys.path for local modules: {PATH_TO_YOUR_REPO_SRC}")

try:
    # Assuming ex_qvim_original.py contains the baseline QVIMModule
    from qvim_mn_baseline.ex_qvim_original import QVIMModule as MobileNetOriginalQVIMModule
    from qvim_mn_baseline.ex_qvim_alt import QVIMModuleAlternate # For PaSST, PANNs, BEATs
    print("Successfully imported MobileNetOriginalQVIMModule and QVIMModuleAlternate.")
except ImportError as e:
    print(f"ERROR importing Lightning Modules: {e}")
    print(f"Ensure your cloned repo 'qvim_baseline_rp' is in '{ROOT_PATH}', contains ex_qvim_original.py and ex_qvim_alt.py, and sys.path is correct.")
    raise


# QVIMModel (the abstract retrieval interface) is defined in an earlier cell,
# so the classes below subclass it directly; no fallback definition is needed.


class MobileNetV3Baseline(QVIMModel):
    def __init__(self, checkpoint_filename="baseline.ckpt", resources_dir="resources"):
        super(MobileNetV3Baseline, self).__init__()
        self.device = _TARGET_DEVICE
        self.config_runtime = argparse.Namespace(
            project='qvim_baseline_eval', num_workers=0, num_gpus=(1 if self.device.type == 'cuda' else 0), model_save_path=None,
            dataset_path='data', target_classes=[], pretrained_name='mn10_as', random_seed=None,
            batch_size=16, n_epochs=15, weight_decay=0.0, max_lr=0.0003, min_lr=0.0001,
            warmup_epochs=1, rampdown_epochs=7, initial_tau=0.07, tau_trainable=False,
            duration=10.0, sample_rate=32000, window_size=800, hop_size=320, n_fft=1024,
            n_mels=128, freqm=2, timem=200, fmin=0, fmax=(32000 // 2),
            fmin_aug_range=10, fmax_aug_range=2000
        )
        checkpoint_path = os.path.join(ROOT_PATH, resources_dir, checkpoint_filename)
        print(f"Loading MobileNetV3 baseline from: {checkpoint_path}")
        if not os.path.exists(checkpoint_path):
            checkpoint_path_alt = os.path.join(ROOT_PATH, resources_dir, "aug_baseline_latest.ckpt")
            if os.path.exists(checkpoint_path_alt): checkpoint_path = checkpoint_path_alt
            else: raise FileNotFoundError(f"MobileNetV3 baseline checkpoint '{checkpoint_filename}' or alternates not found in '{os.path.join(ROOT_PATH, resources_dir)}'")

        try:
            self.qvim_model = MobileNetOriginalQVIMModule.load_from_checkpoint(
                checkpoint_path, map_location=self.device,
                config=self.config_runtime, strict=False
            )
        except Exception as e_load:
            print(f"Direct load_from_checkpoint failed for MobileNetV3: {e_load}. Attempting manual state_dict load.")
            self.qvim_model = MobileNetOriginalQVIMModule(config=self.config_runtime)
            try:
                ckpt_data = torch.load(checkpoint_path, map_location=self.device, weights_only=True)
            except Exception:
                # Lightning checkpoints may contain non-tensor objects that weights_only=True rejects.
                ckpt_data = torch.load(checkpoint_path, map_location=self.device, weights_only=False)

            state_dict_to_load = ckpt_data.get('state_dict', ckpt_data)
            self.qvim_model.load_state_dict(state_dict_to_load, strict=False)
            if not hasattr(self.qvim_model, 'config'): self.qvim_model.config = self.config_runtime

        self.qvim_model = self.qvim_model.eval().to(self.device)
        self.config_for_audio_loading = self.qvim_model.config
        print(f"MobileNetV3 model ready on device: {self.device}")

    def load_audio(self, file_path: str) -> torch.Tensor:
        audio, _ = librosa.load(file_path, sr=self.config_for_audio_loading.sample_rate, mono=True, duration=self.config_for_audio_loading.duration)
        fixed_length = int(self.config_for_audio_loading.sample_rate * self.config_for_audio_loading.duration)
        array = np.zeros(fixed_length, dtype=np.float32)
        current_len = len(audio)
        if current_len < fixed_length: array[:current_len] = audio
        else: array = audio[:fixed_length].astype(np.float32)
        return torch.from_numpy(array).unsqueeze(0).to(self.device)

    def embed_item(self, file_path: str) -> np.ndarray:
        with torch.no_grad(): return self.qvim_model.forward_reference(self.load_audio(file_path)).detach().cpu().numpy().squeeze()
    def embed_query(self, file_path: str) -> np.ndarray:
        with torch.no_grad(): return self.qvim_model.forward_imitation(self.load_audio(file_path)).detach().cpu().numpy().squeeze()

    def compute_similarities(self, items: dict[str, str], queries: dict[str, str]) -> dict[str, dict[str, float]]:
        scores = {q_id: {} for q_id in queries.keys()}
        if not items or not queries: return scores
        item_embs = {item_id: self.embed_item(item_path) for item_id, item_path in tqdm(items.items(), desc="Embedding Items (MobileNetV3)")}
        query_embs = {query_id: self.embed_query(query_path) for query_id, query_path in tqdm(queries.items(), desc="Embedding Queries (MobileNetV3)")}
        for q_name, q_emb in tqdm(query_embs.items(), desc="Calculating Similarities (MobileNetV3)"):
            for i_name, i_emb in item_embs.items():
                scores[q_name][i_name] = float(np.dot(i_emb.flatten(), q_emb.flatten()))
        return scores

class FineTunedModelWrapper(QVIMModel):
    def __init__(self,
                 finetuned_checkpoint_filename: str,
                 model_type_name: str,
                 config_for_qvima_init: argparse.Namespace,
                 resources_dir="resources"):
        super(FineTunedModelWrapper, self).__init__()
        self.device = _TARGET_DEVICE
        self.model_type_name = model_type_name
        self.config_runtime_for_qvima = config_for_qvima_init

        actual_finetuned_checkpoint_path = os.path.join(ROOT_PATH, resources_dir, finetuned_checkpoint_filename)
        if not os.path.exists(actual_finetuned_checkpoint_path):
            raise FileNotFoundError(f"{model_type_name} fine-tuned checkpoint '{finetuned_checkpoint_filename}' not found in '{os.path.join(ROOT_PATH, resources_dir)}'")

        print(f"Loading {model_type_name} fine-tuned model from: {actual_finetuned_checkpoint_path}")

        if not hasattr(self.config_runtime_for_qvima, 'model_type') or self.config_runtime_for_qvima.model_type != self.model_type_name:
            print(f"Warning: Forcing model_type in config_for_qvima_init for {model_type_name} to '{self.model_type_name}'. Original: {getattr(self.config_runtime_for_qvima, 'model_type', 'None')}")
        self.config_runtime_for_qvima.model_type = self.model_type_name
        self.config_runtime_for_qvima.num_gpus = (1 if self.device.type == 'cuda' else 0)

        if self.model_type_name == "passt":
            try: from hear21passt.base import load_model as passt_loader_check # noqa
            except ImportError: raise ImportError("hear21passt library required for PaSST model but not found.")
        elif self.model_type_name == "panns":
            try: from panns_inference import AudioTagging as PannsModelCheck # noqa
            except ImportError: raise ImportError("panns_inference library required for PANNs model but not found.")
        elif self.model_type_name == "beats":
            try: from speechbrain.lobes.models.beats import BEATs as SpeechBrainBEATsModel_check # noqa
            except ImportError: raise ImportError("SpeechBrain library required for BEATs model but not found.")

        print(f"Initializing QVIMModuleAlternate for {self.model_type_name}.")

        self.qvim_alternate_model = QVIMModuleAlternate.load_from_checkpoint(
            checkpoint_path=actual_finetuned_checkpoint_path,
            map_location=self.device,
            config=self.config_runtime_for_qvima,
            strict=False
        )
        self.qvim_alternate_model = self.qvim_alternate_model.eval().to(self.device)
        self.config_for_audio_loading = self.qvim_alternate_model.hparams
        print(f"{model_type_name} model ready on device: {self.device}")

    def load_audio(self, file_path: str) -> torch.Tensor:
        audio, _ = librosa.load(file_path, sr=self.config_for_audio_loading.sample_rate, mono=True, duration=self.config_for_audio_loading.duration)
        fixed_length = int(self.config_for_audio_loading.sample_rate * self.config_for_audio_loading.duration)
        array = np.zeros(fixed_length, dtype=np.float32)
        current_len = len(audio)
        if current_len < fixed_length: array[:current_len] = audio
        else: array = audio[:fixed_length].astype(np.float32)
        return torch.from_numpy(array).unsqueeze(0).to(self.device)

    def embed_item(self, file_path: str) -> np.ndarray:
        with torch.no_grad(): return self.qvim_alternate_model.forward_reference(self.load_audio(file_path)).detach().cpu().numpy().squeeze()
    def embed_query(self, file_path: str) -> np.ndarray:
        with torch.no_grad(): return self.qvim_alternate_model.forward_imitation(self.load_audio(file_path)).detach().cpu().numpy().squeeze()

    def compute_similarities(self, items: dict[str, str], queries: dict[str, str]) -> dict[str, dict[str, float]]:
        scores = {q_id: {} for q_id in queries.keys()}
        if not items or not queries: return scores
        item_embs = {item_id: self.embed_item(item_path) for item_id, item_path in tqdm(items.items(), desc=f"Embedding Items ({self.model_type_name})")}
        query_embs = {query_id: self.embed_query(query_path) for query_id, query_path in tqdm(queries.items(), desc=f"Embedding Queries ({self.model_type_name})")}
        for q_name, q_emb in tqdm(query_embs.items(), desc=f"Calculating Similarities ({self.model_type_name})"):
            for i_name, i_emb in item_embs.items():
                if q_emb is not None and i_emb is not None and q_emb.size > 0 and i_emb.size > 0:
                    scores[q_name][i_name] = float(np.dot(i_emb.flatten(), q_emb.flatten()))
                else:
                    scores[q_name][i_name] = -float('inf')
        return scores


class FusionEnsembleModel(QVIMModel):
    def __init__(self,
                 model_wrappers: dict,
                 fusion_strategy: str = "weighted_average_scores",
                 global_model_weights: dict = None,
                 rrf_k: int = 60
                 ):
        super(FusionEnsembleModel, self).__init__()
        self.model_wrappers = model_wrappers
        self.fusion_strategy = fusion_strategy
        self.device = _TARGET_DEVICE

        self.global_model_weights = global_model_weights if global_model_weights else {}
        self.normalized_global_weights = {}
        if self.model_wrappers:
            active_weights = {name: self.global_model_weights.get(name, 1.0) for name in self.model_wrappers.keys()}
            total_weight = sum(active_weights.values())
            if total_weight > 0:
                self.normalized_global_weights = {name: weight / total_weight for name, weight in active_weights.items()}
            else:
                num_m = len(self.model_wrappers)
                self.normalized_global_weights = {name: 1.0 / num_m if num_m > 0 else 0 for name in self.model_wrappers.keys()}
        
        if self.fusion_strategy == "weighted_average_scores":
            print(f"  Normalized global weights for 'weighted_average_scores': {self.normalized_global_weights}")

        self.rrf_k = rrf_k

        print(f"\nFusionEnsembleModel initialized:")
        print(f"  Models: {list(self.model_wrappers.keys())}")
        print(f"  Fusion strategy: {self.fusion_strategy}")


    def _get_single_model_embeddings(self, file_path: str, embed_method_name: str, model_name: str) -> np.ndarray:
        model_wrapper = self.model_wrappers.get(model_name)
        if not model_wrapper: return np.array([])
        if embed_method_name == "item": return model_wrapper.embed_item(file_path)
        elif embed_method_name == "query": return model_wrapper.embed_query(file_path)
        raise ValueError(f"Invalid embed_method_name: {embed_method_name}")

    def _get_concatenated_embeddings(self, file_path: str, embed_method_name: str) -> np.ndarray:
        all_embeddings_list = []
        model_order = ["mobilenet", "panns", "passt", "beats"]
        for model_name in model_order:
            if model_name in self.model_wrappers:
                emb = self._get_single_model_embeddings(file_path, embed_method_name, model_name)
                if emb.size > 0: all_embeddings_list.append(emb.flatten())
        if not all_embeddings_list: return np.array([])
        return np.concatenate(all_embeddings_list)

    def embed_item(self, file_path: str) -> np.ndarray:
        if self.fusion_strategy == "embedding_concat":
            return self._get_concatenated_embeddings(file_path, "item")
        raise NotImplementedError("Direct embed_item on FusionEnsembleModel is for 'embedding_concat'. Call on individual wrappers or use 'embedding_concat' strategy.")

    def embed_query(self, file_path: str) -> np.ndarray:
        if self.fusion_strategy == "embedding_concat":
            return self._get_concatenated_embeddings(file_path, "query")
        raise NotImplementedError("Direct embed_query on FusionEnsembleModel is for 'embedding_concat'. Call on individual wrappers or use 'embedding_concat' strategy.")


    def compute_similarities(self, items: dict[str, str], queries: dict[str, str]) -> dict[str, dict[str, float]]:
        final_scores = {query_id: {} for query_id in queries.keys()}
        if not self.model_wrappers or not items or not queries: return final_scores

        if self.fusion_strategy == "embedding_concat":
            print(f"\nFusionEnsembleModel: Using '{self.fusion_strategy}'.")
            item_embs_fused = {item_id: self._get_concatenated_embeddings(item_path, "item")
                               for item_id, item_path in tqdm(items.items(), desc="Items (Concat)")}
            query_embs_fused = {query_id: self._get_concatenated_embeddings(query_path, "query")
                                for query_id, query_path in tqdm(queries.items(), desc="Queries (Concat)")}
            for query_id, q_emb_fused in tqdm(query_embs_fused.items(), desc="Similarities (Concat)"):
                for item_id, i_emb_fused in item_embs_fused.items():
                    if q_emb_fused.size > 0 and i_emb_fused.size > 0:
                        final_scores[query_id][item_id] = float(np.dot(i_emb_fused.flatten(), q_emb_fused.flatten()))
                    else: final_scores[query_id][item_id] = -float('inf')
            return final_scores

        all_item_embeddings = {name: {} for name in self.model_wrappers.keys()}
        all_query_embeddings = {name: {} for name in self.model_wrappers.keys()}
        print("\nFusionEnsembleModel: Pre-calculating all individual model embeddings...")
        for model_name, model_wrapper in self.model_wrappers.items():
            print(f"  Embedding items with {model_name}...")
            all_item_embeddings[model_name] = {item_id: model_wrapper.embed_item(item_path)
                                               for item_id, item_path in tqdm(items.items(), desc=f"Items ({model_name})")}
            print(f"  Embedding queries with {model_name}...")
            all_query_embeddings[model_name] = {query_id: model_wrapper.embed_query(query_path)
                                                for query_id, query_path in tqdm(queries.items(), desc=f"Queries ({model_name})")}

        print(f"\nFusionEnsembleModel: Applying '{self.fusion_strategy}' fusion...")

        if self.fusion_strategy == "rrf":
            all_model_ranks_for_query = {qid: {} for qid in queries.keys()}
            for query_id in tqdm(queries.keys(), desc="RRF: Generating Ranks"):
                for model_name in self.model_wrappers.keys():
                    q_emb = all_query_embeddings[model_name].get(query_id)
                    if q_emb is None or q_emb.size == 0: continue
                    current_model_item_scores = {
                        item_id: np.dot(i_emb.flatten(), q_emb.flatten())
                        for item_id, i_emb in all_item_embeddings[model_name].items()
                        if i_emb is not None and i_emb.size > 0
                    }
                    if not current_model_item_scores: continue
                    sorted_items = sorted(current_model_item_scores.items(), key=lambda x: x[1], reverse=True)
                    # Standard RRF uses 1-based ranks: the best item gets rank 1, contributing 1 / (k + 1).
                    all_model_ranks_for_query[query_id][model_name] = {item_id: rank for rank, (item_id, _) in enumerate(sorted_items, start=1)}

            for query_id in tqdm(queries.keys(), desc="RRF: Fusing Ranks"):
                for item_id in items.keys():
                    rrf_score_val = 0.0
                    for model_name in self.model_wrappers.keys():
                        rank = all_model_ranks_for_query.get(query_id, {}).get(model_name, {}).get(item_id)
                        if rank is not None: rrf_score_val += 1.0 / (self.rrf_k + rank)
                    final_scores[query_id][item_id] = float(rrf_score_val)
            return final_scores

        elif self.fusion_strategy == "weighted_average_scores":
            for query_id in tqdm(queries.keys(), desc=f"Fusing Scores ({self.fusion_strategy})"):
                for item_id in items.keys():
                    fused_score = -float('inf')
                    scores_from_models_list = []
                    weights_for_scores_list = []
                    
                    active_weights_source = self.normalized_global_weights

                    for model_name in self.model_wrappers.keys():
                        # Ensure model has a global weight assigned; if not, it won't contribute
                        model_global_weight = active_weights_source.get(model_name)
                        if model_global_weight is None or model_global_weight == 0:
                            continue

                        q_emb = all_query_embeddings[model_name].get(query_id)
                        i_emb = all_item_embeddings[model_name].get(item_id)
                        if q_emb is not None and i_emb is not None and q_emb.size > 0 and i_emb.size > 0:
                            sim = np.dot(i_emb.flatten(), q_emb.flatten())
                            scores_from_models_list.append(sim)
                            weights_for_scores_list.append(model_global_weight)

                    if scores_from_models_list:
                        sum_effective_weights = sum(weights_for_scores_list)
                        if sum_effective_weights > 1e-9:
                            normalized_effective_weights = [w / sum_effective_weights for w in weights_for_scores_list]
                            fused_score = np.sum(np.array(scores_from_models_list) * np.array(normalized_effective_weights))
                        else:
                            fused_score = np.mean(scores_from_models_list)
                    
                    final_scores[query_id][item_id] = float(fused_score)
            return final_scores
            
        else:
            raise NotImplementedError(f"Fusion strategy '{self.fusion_strategy}' is not supported. Must be 'embedding_concat', 'rrf', or 'weighted_average_scores'.")

Using device: cuda
Prepended to sys.path for local modules: /home/ec2-user/notebooks/qvim_baseline_rp/src


Matplotlib is building the font cache; this may take a moment.


Successfully imported MobileNetOriginalQVIMModule and QVIMModuleAlternate.


--2025-06-02 11:20:31--  http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/class_labels_indices.csv
Resolving storage.googleapis.com (storage.googleapis.com)... 64.233.180.207, 172.253.115.207, 172.253.122.207, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|64.233.180.207|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 14675 (14K) [application/octet-stream]
Saving to: ‘/home/ec2-user/panns_data/class_labels_indices.csv’

     0K .......... ....                                       100% 6.82M=0.002s

2025-06-02 11:20:31 (6.82 MB/s) - ‘/home/ec2-user/panns_data/class_labels_indices.csv’ saved [14675/14675]
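The `weighted_average_scores` branch above scores a query/item pair as follows. For every model that produced both embeddings, it takes their dot product. It then renormalizes the global weights over those models and falls back to an unweighted mean if the weights sum to (almost) zero. The standalone sketch below mirrors that logic for a single pair. The function name `fuse_weighted_scores` and the toy embeddings are illustrative assumptions. So is returning `0.0` when no model contributes, since the initial value of `fused_score` is not shown in this excerpt.

```python
import numpy as np


def fuse_weighted_scores(query_embs, item_embs, weights, eps=1e-9):
    """Weighted average of per-model dot-product similarities for one query/item pair.

    query_embs, item_embs: dict model_name -> embedding array (or None if unavailable)
    weights: dict model_name -> global model weight
    """
    scores, used_weights = [], []
    for name, weight in weights.items():
        q, i = query_embs.get(name), item_embs.get(name)
        # Skip models that produced no (or an empty) embedding for either side
        if q is None or i is None or q.size == 0 or i.size == 0:
            continue
        scores.append(float(np.dot(i.flatten(), q.flatten())))
        used_weights.append(weight)

    if not scores:
        return 0.0  # assumption: value used when no model contributed
    total = sum(used_weights)
    if total > eps:
        # Renormalize over the contributing models so their weights sum to 1
        return float(np.dot(scores, np.asarray(used_weights) / total))
    return float(np.mean(scores))  # weights sum to ~0: unweighted mean


# Toy example: "passt" has a weight but no embeddings, so it drops out
toy_query = {"mobilenet": np.array([1.0, 0.0]), "beats": np.array([0.0, 1.0])}
toy_item = {"mobilenet": np.array([0.5, 0.5]), "beats": np.array([0.0, 1.0])}
toy_weights = {"mobilenet": 0.45, "passt": 0.20, "beats": 0.25}
print(round(fuse_weighted_scores(toy_query, toy_item, toy_weights), 4))  # 0.6786
```

Because the weights are renormalized per pair, a model that fails to load simply drops out and the remaining weights are rescaled. For this reason, `active_global_weights` does not need to sum to 1.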



# Step 4: Create an Instance of your QVIMModel

In [10]:
"""
TODO: INSTANTIATE YOUR MODEL HERE.
"""

import pandas as pd
import os
import argparse
from glob import glob
import torch 
import traceback

_DEVICE_FOR_INIT_LOAD = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device for initial checkpoint data loading (e.g., hparams): {_DEVICE_FOR_INIT_LOAD}")


RESOURCES_DIR = "resources"
MOBILENET_CKPT_FILENAME = "aug_baseline_latest.ckpt"
FINETUNED_PASST_CKPT_FILENAME = "passt_finetuned_1.ckpt"
FINETUNED_PANNS_CKPT_FILENAME = "panns_finetuned_2.ckpt"
FINETUNED_BEATS_CKPT_FILENAME = "beats_finetuned_3.ckpt"
ORIGINAL_BEATS_ITER3_CKPT_PATH_FOR_INIT = os.path.join(ROOT_PATH, RESOURCES_DIR, "BEATs_iter3.pt")
ORIGINAL_PANNS_CNN14_CKPT_PATH_FOR_INIT = os.path.join(ROOT_PATH, RESOURCES_DIR, "Cnn14_mAP=0.431.pth")


print("\nInstantiating individual model wrappers...")
mobilenet_wrapper, passt_wrapper, panns_wrapper, beats_wrapper = None, None, None, None

try:
    mobilenet_wrapper = MobileNetV3Baseline(checkpoint_filename=MOBILENET_CKPT_FILENAME, resources_dir=RESOURCES_DIR)
    print("MobileNetV3Baseline instantiated.")
except Exception as e:
    print(f"Could not instantiate MobileNetV3Baseline: {e}\n{traceback.format_exc()}")

try:
    passt_ft_ckpt_path = os.path.join(ROOT_PATH, RESOURCES_DIR, FINETUNED_PASST_CKPT_FILENAME)
    if os.path.exists(passt_ft_ckpt_path):
        # weights_only=False: since PyTorch 2.6, torch.load defaults to weights_only=True, which can
        # reject non-tensor objects (e.g. argparse.Namespace hparams) stored in Lightning checkpoints.
        # Only load checkpoints you trust this way.
        passt_ckpt_data = torch.load(passt_ft_ckpt_path, map_location=_DEVICE_FOR_INIT_LOAD, weights_only=False)
        hparams_passt = passt_ckpt_data.get("hyper_parameters", passt_ckpt_data.get("hparams", {}))
        config_passt_runtime = argparse.Namespace(**hparams_passt) if isinstance(hparams_passt, dict) else hparams_passt
        config_passt_runtime.model_type = "passt"
        config_passt_runtime.passt_input_type = getattr(config_passt_runtime, 'passt_input_type', 'raw')

        passt_wrapper = FineTunedModelWrapper(FINETUNED_PASST_CKPT_FILENAME, "passt", config_passt_runtime, RESOURCES_DIR)
        print("FineTunedModelWrapper for PaSST instantiated.")
    else:
        print(f"Fine-tuned PaSST ckpt not found at {passt_ft_ckpt_path}, skipping.")
except Exception as e:
    print(f"Could not instantiate PaSST wrapper: {e}\n{traceback.format_exc()}")

try:
    panns_ft_ckpt_path = os.path.join(ROOT_PATH, RESOURCES_DIR, FINETUNED_PANNS_CKPT_FILENAME)
    if os.path.exists(panns_ft_ckpt_path):
        # weights_only=False: PyTorch >= 2.6 defaults torch.load to weights_only=True, which can reject
        # non-tensor hparams (e.g. argparse.Namespace) in Lightning checkpoints. Trusted files only.
        panns_ckpt_data = torch.load(panns_ft_ckpt_path, map_location=_DEVICE_FOR_INIT_LOAD, weights_only=False)
        hparams_panns = panns_ckpt_data.get("hyper_parameters", panns_ckpt_data.get("hparams", {}))
        config_panns_runtime = argparse.Namespace(**hparams_panns) if isinstance(hparams_panns, dict) else hparams_panns
        config_panns_runtime.model_type = "panns"
        config_panns_runtime.panns_input_type = getattr(config_panns_runtime, 'panns_input_type', 'raw')
        config_panns_runtime.panns_checkpoint_path = ORIGINAL_PANNS_CNN14_CKPT_PATH_FOR_INIT
        if not os.path.exists(config_panns_runtime.panns_checkpoint_path):
            print(f"Warning: Original PANNs Cnn14 ckpt not found at '{config_panns_runtime.panns_checkpoint_path}'.")

        panns_wrapper = FineTunedModelWrapper(FINETUNED_PANNS_CKPT_FILENAME, "panns", config_panns_runtime, RESOURCES_DIR)
        print("FineTunedModelWrapper for PANNs instantiated.")
    else:
        print(f"Fine-tuned PANNs ckpt not found at {panns_ft_ckpt_path}, skipping.")
except Exception as e:
    print(f"Could not instantiate PANNs wrapper: {e}\n{traceback.format_exc()}")

try:
    beats_ft_ckpt_path = os.path.join(ROOT_PATH, RESOURCES_DIR, FINETUNED_BEATS_CKPT_FILENAME)
    if os.path.exists(beats_ft_ckpt_path):
        # weights_only=False: PyTorch >= 2.6 defaults torch.load to weights_only=True, which can reject
        # non-tensor hparams (e.g. argparse.Namespace) in Lightning checkpoints. Trusted files only.
        beats_ckpt_data = torch.load(beats_ft_ckpt_path, map_location=_DEVICE_FOR_INIT_LOAD, weights_only=False)
        hparams_beats = beats_ckpt_data.get("hyper_parameters", beats_ckpt_data.get("hparams", {}))
        config_beats_runtime = argparse.Namespace(**hparams_beats) if isinstance(hparams_beats, dict) else hparams_beats
        config_beats_runtime.model_type = "beats"
        config_beats_runtime.beats_checkpoint_path = ORIGINAL_BEATS_ITER3_CKPT_PATH_FOR_INIT
        if not os.path.exists(config_beats_runtime.beats_checkpoint_path):
            print(f"Warning: Original BEATs iter3 ckpt not found at '{config_beats_runtime.beats_checkpoint_path}'.")
        config_beats_runtime.beats_savedir = getattr(config_beats_runtime, 'beats_savedir', os.path.join(ROOT_PATH, "pretrained_models_cache", "beats_submission"))

        beats_wrapper = FineTunedModelWrapper(FINETUNED_BEATS_CKPT_FILENAME, "beats", config_beats_runtime, RESOURCES_DIR)
        print("FineTunedModelWrapper for BEATs instantiated.")
    else:
        print(f"Fine-tuned BEATs ckpt not found at {beats_ft_ckpt_path}, skipping.")
except Exception as e:
    print(f"Could not instantiate BEATs wrapper: {e}\n{traceback.format_exc()}")


models_to_fuse = {}
if mobilenet_wrapper is not None: models_to_fuse["mobilenet"] = mobilenet_wrapper
if passt_wrapper is not None: models_to_fuse["passt"] = passt_wrapper
if panns_wrapper is not None: models_to_fuse["panns"] = panns_wrapper
if beats_wrapper is not None: models_to_fuse["beats"] = beats_wrapper

if not models_to_fuse:
    raise RuntimeError("CRITICAL: No models were successfully instantiated. Cannot proceed with fusion.")
print(f"\nModels available for fusion: {list(models_to_fuse.keys())}")

global_ensemble_weights = { "mobilenet": 0.45, "panns": 0.10, "passt": 0.20, "beats": 0.25 }
active_global_weights = {name: weight for name, weight in global_ensemble_weights.items() if name in models_to_fuse}


# CHOOSE FUSION STRATEGY: "weighted_average_scores", "rrf", or "embedding_concat"
CHOSEN_FUSION_STRATEGY = "weighted_average_scores" 
# Or CHOSEN_FUSION_STRATEGY = "rrf"
# Or CHOSEN_FUSION_STRATEGY = "embedding_concat"


print(f"\nInstantiating FusionEnsembleModel with strategy: {CHOSEN_FUSION_STRATEGY}")
QBVIM_MODEL = FusionEnsembleModel(
    model_wrappers=models_to_fuse,
    fusion_strategy=CHOSEN_FUSION_STRATEGY,
    global_model_weights=active_global_weights,
    rrf_k=60,  # smoothing constant k, used when the "rrf" strategy is chosen
)
print("FusionEnsembleModel instantiated as QBVIM_MODEL.")

Device for initial checkpoint data loading (e.g., hparams): cuda

Instantiating individual model wrappers...
Loading MobileNetV3 baseline from: ./resources/aug_baseline_latest.ckpt
Could not instantiate MobileNetV3Baseline: MobileNetV3 baseline checkpoint 'aug_baseline_latest.ckpt' or alternates not found in './resources'
Traceback (most recent call last):
  File "/tmp/ipykernel_3599/3744478250.py", line 29, in <module>
    mobilenet_wrapper = MobileNetV3Baseline(checkpoint_filename=MOBILENET_CKPT_FILENAME, resources_dir=RESOURCES_DIR)
  File "/tmp/ipykernel_3599/3563947200.py", line 65, in __init__
    else: raise FileNotFoundError(f"MobileNetV3 baseline checkpoint '{checkpoint_filename}' or alternates not found in '{os.path.join(ROOT_PATH, resources_dir)}'")
FileNotFoundError: MobileNetV3 baseline checkpoint 'aug_baseline_latest.ckpt' or alternates not found in './resources'

Fine-tuned PaSST ckpt not found at ./resources/passt_finetuned_1.ckpt, skipping.
Fine-tuned PANNs ckpt not fo

RuntimeError: CRITICAL: No models were successfully instantiated. Cannot proceed with fusion.

beats model ready on device: cuda
FineTunedModelWrapper for BEATs instantiated.

Models available for fusion: ['mobilenet', 'passt', 'panns', 'beats']

Instantiating FusionEnsembleModel with strategy: class_hybrid_best_or_softmax_weighted_avg

FusionEnsembleModel initialized:
  Models: ['mobilenet', 'passt', 'panns', 'beats']
  Fusion strategy: class_hybrid_best_or_softmax_weighted_avg
  Strong baseline model: mobilenet
FusionEnsembleModel instantiated as QBVIM_MODEL.
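With `CHOSEN_FUSION_STRATEGY = "rrf"`, per-model rankings are combined by reciprocal rank fusion, with `rrf_k=60` as the smoothing constant. The notebook's own `rrf` branch is not part of this excerpt, so the sketch below is an assumption. It uses the common formulation in which an item's fused score is the sum over models $m$ of $w_m / (k + \mathrm{rank}_m)$, with 1-based ranks and optional per-model weights. `reciprocal_rank_fusion` is a hypothetical name.

```python
def reciprocal_rank_fusion(score_lists, k=60, weights=None):
    """Fuse per-model {item_id: similarity} dicts for one query with (weighted) RRF.

    Each model's similarities are turned into 1-based ranks; the fused score of an
    item is the sum over models m of weights[m] / (k + rank_m(item)).
    """
    if weights is None:
        weights = {name: 1.0 for name in score_lists}  # plain, unweighted RRF
    fused = {}
    for name, scores in score_lists.items():
        # Order item ids by descending similarity: position 1 is the best match
        ranked = sorted(scores, key=scores.get, reverse=True)
        for rank, item_id in enumerate(ranked, start=1):
            fused[item_id] = fused.get(item_id, 0.0) + weights.get(name, 1.0) / (k + rank)
    return fused


# Toy example: the two models disagree on whether "a" or "b" is the best match
per_model = {
    "mobilenet": {"a": 0.9, "b": 0.4, "c": 0.1},
    "beats": {"a": 0.2, "b": 0.8, "c": 0.1},
}
fused = reciprocal_rank_fusion(per_model, k=60, weights={"mobilenet": 0.45, "beats": 0.25})
print(sorted(fused, key=fused.get, reverse=True))  # ['a', 'b', 'c']
```

RRF only uses ranks, so it is insensitive to the differing similarity scales of the individual models. By contrast, `weighted_average_scores` mixes raw dot products directly.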


## Create Predictions

To run the following cells, download the development dataset and store it in `DATA_PATH`.

In [None]:
"""
DO NOT MODIFY THIS BLOCK.
"""
from glob import glob

items_path = os.path.join(DATA_PATH, "Items")
item_files = pd.DataFrame({'path': list(glob(os.path.join(items_path, "**", "*.wav"), recursive=True))})
item_files["Class"] = item_files['path'].transform(lambda x: x.split(os.path.sep)[-2])
item_files["Items"] = item_files['path'].transform(lambda x: x.split(os.path.sep)[-1])

queries_path = os.path.join(DATA_PATH, "Queries")
query_files = pd.DataFrame({'path': list(glob(os.path.join(queries_path, "**", "*.wav"), recursive=True))})
query_files["Class"] = query_files['path'].transform(lambda x: x.split(os.path.sep)[-2])
query_files["Query"] = query_files['path'].transform(lambda x: x.split(os.path.sep)[-1])

print("Total item files:", len(item_files))
print("Total query files:", len(query_files))

if len(query_files) == 0 or len(item_files) == 0:
    raise ValueError("No query files found! Download the development dataset and store it in 'DATA_PATH'.")

In [None]:
"""
DO NOT MODIFY THIS BLOCK.
"""

scores = QBVIM_MODEL.compute_similarities(
    items = {row["Items"]: row["path"] for i, row in item_files.iterrows()},
    queries = {row["Query"]: row["path"] for i, row in query_files.iterrows()}
)

Fusing Scores (class_hybrid_best_or_softmax_weighted_avg): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1021/1021 [00:00<00:00, 1067.38it/s]
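`compute_similarities` returns a nested dict `{query_id: {item_id: score}}`, which the next cell writes to `similarities.json`. The evaluation cell later rebuilds a DataFrame from the item keys of the first query. A query with a missing item would therefore raise a `KeyError` there. Python's `json` module also writes NaN as a non-standard `NaN` token. The hypothetical helper below checks for both problems. It is an optional sketch, not part of the template, and could be run on `scores` (e.g. as `check_similarities(scores, item_files["Items"])`) in a separate scratch session, since the template forbids adding cells.

```python
import math


def check_similarities(scores, expected_items=None):
    """Return a list of problems found in a {query_id: {item_id: score}} dict."""
    if not scores:
        return ["no queries"]
    # Reference item set: the expected items, or else the items scored for the first query
    if expected_items is not None:
        reference = set(expected_items)
    else:
        reference = set(next(iter(scores.values())))
    problems = []
    for query_id, row in scores.items():
        missing = reference - set(row)
        if missing:
            problems.append(f"{query_id}: {len(missing)} item(s) missing")
        bad = [item for item, s in row.items()
               if not isinstance(s, (int, float)) or not math.isfinite(s)]
        if bad:
            problems.append(f"{query_id}: {len(bad)} non-finite or non-numeric score(s)")
    return problems


# Toy input: q2.wav lacks a score for i2.wav and has a NaN score
example = {
    "q1.wav": {"i1.wav": 0.3, "i2.wav": 0.1},
    "q2.wav": {"i1.wav": float("nan")},
}
print(check_similarities(example))
# ['q2.wav: 1 item(s) missing', 'q2.wav: 1 non-finite or non-numeric score(s)']
```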


In [None]:
"""
DO NOT MODIFY THIS BLOCK.
"""

import json
os.makedirs(os.path.join(ROOT_PATH, "output"), exist_ok=True)

with open(os.path.join(ROOT_PATH, "output", "similarities.json"), "w") as f:
    json.dump(scores, f)


## Evaluation on the Public Development Set

The cells below compute the Reciprocal Rank (RR) for each query in the public development set. The RR of query $i$ is $1/r_i$, the reciprocal of the rank $r_i$ of its correct item (rank 1 is best). Submissions will be ranked by the Mean Reciprocal Rank (MRR) over the set of queries $Q$ of a hidden test set:

$$MRR = \frac{1}{\lvert Q \rvert} \sum_{i=1}^{\lvert Q\rvert} \frac{1}{r_i}$$
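As a worked example of the formula, consider three queries whose correct items are ranked 1st, 2nd and 4th. Then $MRR = (1 + 1/2 + 1/4)/3 \approx 0.5833$. The minimal sketch below computes the same quantity from 1-based ranks; `mean_reciprocal_rank` is an illustrative name. The evaluation cells below obtain 0-based positions via pandas `Index.get_loc`, which is why they use `1 / (rank + 1)`.

```python
import numpy as np


def mean_reciprocal_rank(ranks):
    """MRR over queries, given the 1-based rank r_i of each query's correct item."""
    ranks = np.asarray(ranks, dtype=float)
    return float(np.mean(1.0 / ranks))


print(round(mean_reciprocal_rank([1, 2, 4]), 4))  # 0.5833
```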

In [None]:
"""
DO NOT MODIFY THIS BLOCK.
"""
import json

with open(os.path.join(ROOT_PATH, "output", "similarities.json"), "r") as f:
    scores = json.load(f)

rankings = pd.DataFrame(dict(
    **{ "id": [i for i in list(scores.keys())]},
    **{ k: [v[k] for v in  scores.values() ] for k in scores[list(scores.keys())[0]].keys()}
)).set_index("id")

df = pd.read_csv(
    os.path.join(DATA_PATH, "DEVUpdateComplete.csv"), skiprows=1
)[['Label', 'Class', 'Items', 'Query 1', 'Query 2', 'Query 3']]

df = df.melt(
    id_vars=[col for col in df.columns if "Query" not in col],
    value_vars=["Query 1", "Query 2", "Query 3"],
    var_name="Query Type",
    value_name="Query"
).dropna()

# remove missing files
rankings = rankings.loc[df["Query"].unique(), df["Items"].unique()]

# load file with ground truth, i.e., query->item mapping; column 0 is item, colum 1 query
ground_truth = {row['Query']: [row['Items']] for i, row in df.iterrows()}

# find the rank of the correct item (real recording) for each query (imitation)
position_of_correct = {}
missing_query_files = []
for query, correct_item_list in ground_truth.items():
    # Skip if query is not in the DataFrame
    if query not in rankings.index:
        missing_query_files.append(query)
        continue
    # Get row and sort items by similarity in descending order
    sorted_items = rankings.loc[query].sort_values(ascending=False)
    # Find rank of correct items
    position_of_correct[query] = {
        item: sorted_items.index.get_loc(item) for item in correct_item_list if item in sorted_items.index
    }
    assert len(position_of_correct[query]) == len(correct_item_list), f"Missing item! Got: {list(position_of_correct[query].keys())}. Expected: {correct_item_list}"

# compute MRR
normalized_rrs = []
for query, items_ranks in position_of_correct.items():
    rr, irr = [], [] # summed RR and ideal RR
    for i, (item, rank) in enumerate(items_ranks.items()):
        rr.append(1 / (rank + 1))
        irr.append(1 / (i + 1))
    normalized_rrs.append(sum(rr) / sum(irr)) # normalize MRR with ideal one
mrr = np.mean(normalized_rrs)

print("Missing query files: ", len(missing_query_files))
print("Missing item files: ", missing_query_files)
print("MRR random:", round((1/ np.arange(1,len(df["Items"].unique()))).mean(), 4))
print("MRR       :", round(mrr, 4))
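The `MRR random` line estimates the MRR of a ranking that ignores the query. Suppose the single correct item lands at a uniformly random rank among the $N$ items. Its expected reciprocal rank is then $\frac{1}{N}\sum_{r=1}^{N} 1/r = H_N/N$, where $H_N$ is the $N$-th harmonic number. Note that `np.arange(1, N)` stops at $N-1$, so the printed value averages over ranks 1 to $N-1$ only, i.e. $H_{N-1}/(N-1)$; the gap shrinks as $N$ grows. The sketch below computes the closed form and checks it by simulation; the function names are illustrative.

```python
import numpy as np


def expected_random_mrr(num_items):
    """Expected RR when the single correct item lands at a uniformly random rank 1..N."""
    ranks = np.arange(1, num_items + 1)  # includes rank N
    return float(np.mean(1.0 / ranks))  # equals H_N / N


def simulated_random_mrr(num_items, num_queries=100_000, seed=0):
    """Monte Carlo check: draw a random 1-based rank per query and average 1/rank."""
    rng = np.random.default_rng(seed)
    ranks = rng.integers(1, num_items + 1, size=num_queries)  # upper bound is exclusive
    return float(np.mean(1.0 / ranks))


print(round(expected_random_mrr(4), 4))  # (1 + 1/2 + 1/3 + 1/4) / 4 = 0.5208
print(round(simulated_random_mrr(4), 2))
```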

In [None]:
"""
DO NOT MODIFY THIS BLOCK.
"""

ground_truth = {
    row["Query"]: [row_["Items"] for j, row_ in df.drop_duplicates("Items").iterrows() if row_["Class"] == row["Class"]] for i, row in df.drop_duplicates("Query").iterrows()
}

position_of_correct = {}
missing_query_files = []
for query, correct_item_list in ground_truth.items():
    # Skip if query is not in the DataFrame
    if query not in rankings.index:
        missing_query_files.append(query)
        continue
    # Get row and sort items by similarity in descending order
    sorted_items = rankings.loc[query].sort_values(ascending=False)
    # Find rank of correct items
    position_of_correct[query] = {item: sorted_items.index.get_loc(item) for item in correct_item_list if item in sorted_items.index}
    assert len(position_of_correct[query]) == len(correct_item_list), f"Missing item!"

# compute MRR
normalized_rrs = []
for query, items_ranks in position_of_correct.items():
    rr, irr = [], [] # summed RR and ideal RR
    for i, (item, rank) in enumerate(items_ranks.items()):
        rr.append(1 / (rank + 1))
        irr.append(1 / (i + 1))
    normalized_rrs.append(sum(rr) / sum(irr)) # normalize MRR with ideal one
mrr = np.mean(normalized_rrs)

# compute NDCG
normalized_dcg = []
ndcgs = {}
for query, items_ranks in position_of_correct.items():
    dcg, idcg = [], [] # summed RR and ideal RR
    for i, (item, rank) in enumerate(items_ranks.items()):
        dcg.append(1 / np.log2(rank + 2))
        idcg.append(1 / np.log2(i + 2))
    normalized_dcg.append(sum(dcg) / sum(idcg)) # normalize MRR with ideal one
    ndcgs[query] = sum(dcg) / sum(idcg)
ndcg = np.mean(normalized_dcg)

print("Class-wise MRR :", round(mrr, 4))
print("Class-wise NDCG:", round(ndcg, 4))
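The class-wise cell treats every item of the query's class as relevant. Let $R_q$ be the relevant items of query $q$ and $r_{q,j}$ the 1-based rank of item $j$ (the code's 0-based `rank` plus one). Each query's summed reciprocal rank and DCG are normalized by their values for an ideal ranking that places all relevant items first:

$$\mathrm{nRR}_q = \frac{\sum_{j \in R_q} 1/r_{q,j}}{\sum_{i=1}^{\lvert R_q \rvert} 1/i}, \qquad \mathrm{nDCG}_q = \frac{\sum_{j \in R_q} 1/\log_2(r_{q,j} + 1)}{\sum_{i=1}^{\lvert R_q \rvert} 1/\log_2(i + 1)}$$

The reported Class-wise MRR and Class-wise NDCG are the means of $\mathrm{nRR}_q$ and $\mathrm{nDCG}_q$ over all queries $Q$. Both equal 1 only when every relevant item outranks every irrelevant one. With a single relevant item per query, $\mathrm{nRR}_q$ reduces to $1/r_q$, as in the MRR formula.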