# QVIM-AES Submission Template

This is the submission template for the Query by Vocal Imitation challenge at the 2025 AES International Conference on Artificial Intelligence and Machine Learning for Audio.

The content of this notebook is inspired by the template provided by the task organizers of the [Sound Scene Synthesis Task of the DCASE Challenge 2024](https://dcase.community/challenge2024/task-sound-scene-synthesis).

<div class="alert alert-block alert-warning"> 
<b>Confidentiality Statement</b><br> As the organizers of this contest, we assure all participants that their submitted models and code will be treated with strict confidentiality. Submissions will only be accessed by the designated review team for evaluation purposes and will not be shared, distributed, or used beyond the scope of this challenge. Participants retain full ownership of their work. We will not claim any rights over the submitted materials, nor will we use them for any purpose outside of the challenge evaluation process. We appreciate your participation in this challenge.
</div>

#### How to create your submission
- Get familiar with the existing code blocks and the example provided below.
- Set the root path of your environment and your dataset below ("TODO: DEFINE YOUR PATHS HERE.").
- Set up your project ("TODO: SETUP YOUR PROJECT HERE.").
- Implement the retrieval interface below ("TODO: ADD YOUR IMPLEMENTATION HERE.").
    - Use the provided helper functions (`helpers.py`) to download your source code, model checkpoints, etc.
- Instantiate your retrieval model ("TODO: INSTANTIATE YOUR MODEL HERE.").
- Before **submitting your notebook**, run it in a clean conda environment (Python >= 3.10) on Ubuntu 24.04 and make sure the evaluation results are in line with your previous results.
- Submit your notebooks and the technical report as described on our [website](https://qvim-aes.github.io/).

##### Some Rules
- DO NOT modify the other code cells.
- DO NOT add new cells.
- Store your project WITHIN 'ROOT_PATH' and your data within 'DATA_PATH'.
- DO NOT use 'ROOT_PATH/output' folder; this is where we will store things.
- DO NOT change the working directory (e.g., `os.chdir('/path/to/a/dir/that/does/not/exist/on/my/machine')`).
- DO NOT use system commands (`!cd ~` or `os.system('cd ~')`, etc.) other than the ones used to set up your environment (i.e., install required packages with pip, conda, ...).

<div class="alert alert-block alert-danger"> 
Participants who submit malicious code will be disqualified.
</div>
    

In [1]:
"""
DO NOT MODIFY THIS BLOCK.
"""
# Install basic packages for template notebook.
!pip install librosa numpy pandas tqdm GitPython gdown==5.1.0

Defaulting to user installation because normal site-packages is not writeable


In [2]:
"""
DO NOT MODIFY THIS BLOCK.
"""
# some imports
import sys
import os

from abc import ABC, abstractmethod
from tqdm import tqdm
import numpy as np
import pandas as pd


## Description of the Retrieval Interface 
`QVIMModel` is the interface specification for all query by vocal imitation systems. Each submitted system is expected to subclass this interface and implement the `compute_similarities` method, which computes the similarities between all pairwise combinations of queries (vocal imitations) and items (reference sounds).

`compute_similarities` takes two dictionaries as input:
- `items` is a dictionary mapping the ids of the items to be retrieved (reference sounds) to the corresponding file paths.
- `queries` is a dictionary mapping query ids (vocal imitations) to the corresponding file paths.
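To make the expected return shape concrete, here is a minimal sketch of a subclass. It assumes the embeddings come from two hypothetical callables (`embed_item`, `embed_query`) that map a file path to a NumPy vector; `DotProductModel` is illustrative only and not part of the template, and the interface class is repeated so the snippet runs on its own.

```python
from abc import ABC, abstractmethod
from typing import Callable

import numpy as np


class QVIMModel(ABC):
    """Copy of the template interface so this snippet runs standalone."""

    @abstractmethod
    def compute_similarities(
        self, items: dict[str, str], queries: dict[str, str]
    ) -> dict[str, dict[str, float]]: ...


class DotProductModel(QVIMModel):
    """Hypothetical example: score each <query, item> pair by an embedding dot product."""

    def __init__(self, embed_item: Callable[[str], np.ndarray], embed_query: Callable[[str], np.ndarray]):
        self.embed_item = embed_item
        self.embed_query = embed_query

    def compute_similarities(self, items, queries):
        # Embed every reference sound once, then reuse it for all queries.
        item_embs = {item_id: self.embed_item(path) for item_id, path in items.items()}
        scores = {}
        for query_id, query_path in queries.items():
            q_emb = self.embed_query(query_path)
            # Higher score means more similar, as the interface requires.
            scores[query_id] = {
                item_id: float(np.dot(q_emb, i_emb)) for item_id, i_emb in item_embs.items()
            }
        return scores
```

For example, with item vectors `[1, 0]` and `[0, 1]` and a query vector `[0.8, 0.2]`, `item_1` scores 0.8 and `item_2` scores 0.2, so `item_1` ranks first.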

Participants are expected to load the sounds themselves, e.g., with `librosa.load`.
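The implementation in Step 3 loads each file with `librosa.load(..., mono=True, duration=...)` and then zero-pads or truncates it to a fixed number of samples. That second step can be sketched with NumPy alone; `to_fixed_length` is a hypothetical name, and the 32 kHz / 10 s values below are the ones the MobileNetV3 configuration in Step 3 uses.

```python
import numpy as np


def to_fixed_length(audio: np.ndarray, sample_rate: int, duration: float) -> np.ndarray:
    """Zero-pad or truncate a mono signal to exactly sample_rate * duration samples."""
    fixed_length = int(sample_rate * duration)
    out = np.zeros(fixed_length, dtype=np.float32)
    n = min(len(audio), fixed_length)
    out[:n] = audio[:n]  # copy what fits; the remainder stays zero (silence)
    return out
```

In practice this would follow `audio, _ = librosa.load(path, sr=32000, mono=True, duration=10.0)`, and `to_fixed_length(audio, 32000, 10.0)` then always yields 320000 samples, so every clip can be batched with the same shape.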

In [3]:
"""
DO NOT MODIFY THIS BLOCK.
"""

class QVIMModel(ABC):

    @abstractmethod
    def compute_similarities(
            self, items: dict[str, str], queries: dict[str, str]
    ) -> dict[str, dict[str, float]]:
        """Compute similarity scores between items to be retrieved and a set of queries.

        Each <query, item> pairing should be assigned a single floating point score, where higher
        scores indicate higher similarity.

        Args:
            items (dict[str, str]): A dictionary mapping ids of items to be retrieved to the corresponding file path
            queries (dict[str, str]): A dictionary mapping query ids to the corresponding file path

        Returns:
            scores (dict[str, dict[str, float]]): A dictionary mapping query ids to a dictionary of item
                ids and their corresponding similarity scores. E.g:
                {
                    "query_1": {
                        "item_1": 0.8,
                        "item_2": 0.6,
                        ...
                    },
                    "query_2": {
                        "item_1": 0.4,
                        "item_2": 0.9,
                        ...
                    },
                    ...
                }
        """
        pass

## Some Helper Functions

`helpers.py` contains some helpful functions for downloading code and model checkpoints from Google Drive, Git and public links.

The functions were taken (with slight modifications) from the submission template provided by the task organizers of [Task 7 of the DCASE Challenge 2024: Sound Scene Synthesis](https://dcase.community/challenge2024/task-sound-scene-synthesis).

In [4]:
# Note: a conflicting installed Python package was encountered when running this notebook; if `import helpers` picks up the wrong module, rename helpers.py (e.g., to my_helpers.py) and update both imports below.
import helpers
from helpers import google_drive_download, wget_download, git_clone_checkout, unpack_file

## Step 1: Set up your paths

- Define `ROOT_PATH`: this is where your project lives. For testing, we will replace it with our own ROOT_PATH. We recommend using the current working directory ('.').
- Define `DATA_PATH`: this is where your public development data lives. For testing, we will replace it with our own DATA_PATH. We recommend using 'data/qvim-dev'.
    

In [5]:
"""
TODO: DEFINE YOUR PATHS HERE.
"""

# replace this with your custom ROOT_PATH; this is where your code and checkpoints will be downloaded to
ROOT_PATH = "."

# path to the evaluation data; can be in ROOT_PATH
DATA_PATH = os.path.join(ROOT_PATH, "DEVUpdatedDataset")

In [6]:
helpers.ROOT_PATH = ROOT_PATH
os.makedirs(ROOT_PATH, exist_ok=True)
os.makedirs(DATA_PATH, exist_ok=True)
sys.path.append(os.path.join(ROOT_PATH))

## Step 2: Set up your environment, download checkpoints, etc.

Set up your project and install the required packages here.
The easiest way is to:
1) convert your implementation into a package,
2) clone the repository and check out the specific branch and commit,
3) install your package with `pip install -e name_of_your_fancy_package`.


Hints:
- Make sure your link to the repository and other URLs are publicly available.
- Use **shared public URLs** (e.g. a shared Google Drive, Dropbox, Zenodo link) to download checkpoints into `ROOT_PATH`.
- Use the provided helper functions (`google_drive_download`, `wget_download`, `git_clone_checkout`, and `unpack_file`).
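The checkpoints below are shared as Google Drive links of the form `https://drive.google.com/file/d/<FILE_ID>/view?usp=download`. How `google_drive_download` in `helpers.py` handles them is not shown here; as an assumption about what such a downloader needs, the sketch extracts the file ID from a share link, which is also a quick way to sanity-check a link before submitting. `extract_drive_file_id` is a hypothetical name.

```python
import re
from urllib.parse import parse_qs, urlparse

# Matches the ".../file/d/<FILE_ID>/..." form used by Drive share links.
_DRIVE_PATH_ID = re.compile(r"/file/d/([A-Za-z0-9_-]+)")


def extract_drive_file_id(shared_url: str) -> str:
    """Return the file ID from a Google Drive share link (hypothetical helper)."""
    match = _DRIVE_PATH_ID.search(shared_url)
    if match:
        return match.group(1)
    # Also accept links that carry the ID as a query parameter, e.g. ".../uc?id=<FILE_ID>".
    ids = parse_qs(urlparse(shared_url).query).get("id")
    if ids:
        return ids[0]
    raise ValueError(f"Not a recognizable Google Drive share URL: {shared_url}")
```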

In [7]:
"""
TODO: SETUP YOUR PROJECT HERE.
"""

git_clone_checkout(
    output_dir='qvim_baseline_rp', 
    url='https://github.com/RP335/qvim-baseline', 
    branch='main', 
    commit_sha='8222f2f4651b08a7d3a47026bce6948657c7bf2e' 
)
print("Cloned RP335/qvim-baseline repository.")


!pip install speechbrain hear21passt panns_inference wandb torchaudio==2.6.0 torch==2.6.0 lightning==2.5.1 librosa==0.11.0 torchvision==0.21.0
!pip install cuda-python
print("Required libraries installed.")


# Create the checkpoint directory before downloading into it.
os.makedirs(os.path.join(ROOT_PATH, "resources"), exist_ok=True)

# (filename, shared Google Drive URL) for every checkpoint and per-class MRR table.
RESOURCE_FILES = [
    ("mrr_values_passt.csv", "https://drive.google.com/file/d/1x_MnZFzpMQ9k6zDH_6ELG5nC8icluX-k/view?usp=download"),
    ("aug_baseline_latest.ckpt", "https://drive.google.com/file/d/1HglQg8wTQaHzV6eSVND99hfLfUkQnLPl/view?usp=download"),
    ("mrr_values_baseline_mobilenet.csv", "https://drive.google.com/file/d/1eFL3uYAbLJmJCNMmjvS4aus1WRu47cVH/view?usp=download"),
    ("passt_finetuned_1.ckpt", "https://drive.google.com/file/d/1SjyHPMjyBzSuSj1cWe2NKIOM0m3VUN3k/view?usp=download"),
    ("Cnn14_mAP=0.431.pth", "https://drive.google.com/file/d/1zbWcCrF_oopLsk4il5q6oMwl2-lfhwXT/view?usp=download"),
    ("panns_finetuned_2.ckpt", "https://drive.google.com/file/d/10QXhEjc0bimeonr6SDBix_MBnss6e4cy/view?usp=download"),
    ("mrr_values_panns_updated.csv", "https://drive.google.com/file/d/1R4lU8LR6bMOvO5Bp47TGypi_4fBUuevP/view?usp=download"),
    ("beats_finetuned_1.ckpt", "https://drive.google.com/file/d/1BeywuERMzb-RaC2Fh-1fzgMbk6qnQ1-t/view?usp=download"),
    ("mrr_values_beats.csv", "https://drive.google.com/file/d/1FV82UKeoobg0iBothH4TyAONxclUevPW/view?usp=download"),
    ("BEATs_iter3.pt", "https://drive.google.com/file/d/1NmDz4TFdf66nbxe1NK5-z2I416SIg9FH/view?usp=download"),
]
for filename, shared_url in RESOURCE_FILES:
    google_drive_download(filename=filename, shared_url=shared_url, relative_dir="resources")

print(f"Setup complete. Checkpoints are expected in '{os.path.join(ROOT_PATH, 'resources')}' directory.")

Directory already exists. Skipping clone.
Cloned RP335/qvim-baseline repository.
Defaulting to user installation because normal site-packages is not writeable


Defaulting to user installation because normal site-packages is not writeable
Required libraries installed.
Setup complete. Checkpoints are expected in './resources' directory.


## Step 3: Implement the QVIMModel Interface
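The `rrf` strategy of `FusionEnsembleModel` in the next cell ranks the items separately for each model by dot-product similarity and sums `1 / (rrf_k + rank)` across models. A standalone sketch for a single query follows; the function name is hypothetical. Ranks are 0-based here, as with the notebook's `enumerate`, whereas the usual RRF formulation counts ranks from 1.

```python
def reciprocal_rank_fusion(score_tables: dict[str, dict[str, float]], k: int = 60) -> dict[str, float]:
    """Fuse one query's per-model item scores with reciprocal rank fusion.

    score_tables maps a model name to {item_id: similarity}. Ranks are 0-based,
    as in the notebook, so each model's top item contributes 1 / k.
    """
    fused: dict[str, float] = {}
    for item_scores in score_tables.values():
        # Rank this model's items from most to least similar.
        ranked = sorted(item_scores, key=item_scores.get, reverse=True)
        for rank, item_id in enumerate(ranked):
            fused[item_id] = fused.get(item_id, 0.0) + 1.0 / (k + rank)
    return fused
```

Only ranks enter the fused score, so models whose raw similarities live on different scales (e.g., unnormalized dot products) can be combined without calibration; a larger `k` shrinks the gap between neighbouring ranks and thus the pull of any one model's top positions.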

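The class-aware branches rely on per-class MRR tables (presumably built from the `mrr_values_*.csv` files downloaded in Step 2). Two pieces of that logic are sketched below as standalone functions with hypothetical names: `softmax_model_weights` turns one class's MRRs into weights via a softmax with temperature `class_weighted_softmax_temp` (lower temperatures push the weights toward the single best model), and `select_model_for_class` mirrors the best-model-with-fallback rule. Because the baseline is itself one of the candidates, the best per-class MRR is never below the baseline's own, so with a factor of 1 or less (the default is 0.9) the fallback can never fire; a factor above 1, e.g. 1.1, would instead keep the baseline unless another model's MRR is at least 10% higher.

```python
import numpy as np


def softmax_model_weights(class_mrr: dict[str, float], temperature: float = 0.1) -> dict[str, float]:
    """Turn one class's per-model MRRs into weights with a temperature-scaled softmax."""
    temp = temperature if temperature > 1e-6 else 1e-3  # same zero-temperature guard as the notebook
    names = list(class_mrr)
    scaled = np.array([class_mrr[n] for n in names], dtype=float) / temp
    scaled -= scaled.max()  # subtract the max for numerical stability
    weights = np.exp(scaled)
    weights /= weights.sum()
    return {n: float(w) for n, w in zip(names, weights)}


def select_model_for_class(class_mrr: dict[str, float], baseline: str, factor: float = 0.9) -> str:
    """Pick the model with the highest per-class MRR, falling back to the baseline."""
    if not class_mrr:
        return baseline  # no MRR information for this class
    best = max(class_mrr, key=class_mrr.get)  # first maximum wins, like the notebook's strict '>'
    baseline_mrr = class_mrr.get(baseline, 0.0)
    if best != baseline and class_mrr[best] < baseline_mrr * factor and baseline_mrr > 0.001:
        return baseline
    return best
```

With MRRs of 0.5 and 0.4 at temperature 0.1, the scaled values differ by 1, so the weights are about 0.73 and 0.27.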
In [8]:
"""
TODO: ADD YOUR IMPLEMENTATION HERE.
"""

import numpy as np
import torch
import librosa
import argparse
from collections import OrderedDict
from tqdm import tqdm
import os
import sys

_TARGET_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {_TARGET_DEVICE}")

REPO_SRC_PATH = os.path.abspath(os.path.join(ROOT_PATH, "qvim_baseline_rp", "src"))
if REPO_SRC_PATH not in sys.path:
    sys.path.insert(0, REPO_SRC_PATH)
    print(f"Prepended to sys.path for local modules: {REPO_SRC_PATH}")

try:
    from qvim_mn_baseline.ex_qvim_original import QVIMModule as MobileNetOriginalQVIMModule
    from qvim_mn_baseline.ex_qvim_alt import QVIMModuleAlternate # For PaSST, PANNs, BEATs
    print("Successfully imported MobileNetOriginalQVIMModule and QVIMModuleAlternate.")
except ImportError as e:
    print(f"ERROR importing Lightning Modules: {e}")
    print(f"Ensure your cloned repo 'qvim_baseline_rp' is in '{ROOT_PATH}', contains ex_qvim_original.py and ex_qvim_alt.py, and sys.path is correct.")
    raise



class MobileNetV3Baseline(QVIMModel):
    def __init__(self, checkpoint_filename="baseline.ckpt", resources_dir="resources"):
        super(MobileNetV3Baseline, self).__init__()
        self.device = _TARGET_DEVICE
        self.config_runtime = argparse.Namespace(
            project='qvim_baseline_eval', num_workers=0, num_gpus=(1 if self.device.type == 'cuda' else 0), model_save_path=None,
            dataset_path='data', target_classes=[], pretrained_name='mn10_as', random_seed=None,
            batch_size=16, n_epochs=15, weight_decay=0.0, max_lr=0.0003, min_lr=0.0001,
            warmup_epochs=1, rampdown_epochs=7, initial_tau=0.07, tau_trainable=False,
            duration=10.0, sample_rate=32000, window_size=800, hop_size=320, n_fft=1024,
            n_mels=128, freqm=2, timem=200, fmin=0, fmax=(32000 // 2),
            fmin_aug_range=10, fmax_aug_range=2000
        )
        checkpoint_path = os.path.join(ROOT_PATH, resources_dir, checkpoint_filename)
        print(f"Loading MobileNetV3 baseline from: {checkpoint_path}")
        if not os.path.exists(checkpoint_path):
            checkpoint_path_alt = os.path.join(ROOT_PATH, resources_dir, "aug_baseline_latest.ckpt")
            if os.path.exists(checkpoint_path_alt): checkpoint_path = checkpoint_path_alt
            else: raise FileNotFoundError(f"MobileNetV3 baseline checkpoint '{checkpoint_filename}' or alternates not found in '{os.path.join(ROOT_PATH, resources_dir)}'")

        try:
            self.qvim_model = MobileNetOriginalQVIMModule.load_from_checkpoint(
                checkpoint_path, map_location=self.device,
                config=self.config_runtime, strict=False
            )
        except Exception as e_load:
            print(f"Direct load_from_checkpoint failed for MobileNetV3: {e_load}. Attempting manual state_dict load.")
            self.qvim_model = MobileNetOriginalQVIMModule(config=self.config_runtime)
            try:
                ckpt_data = torch.load(checkpoint_path, map_location=self.device, weights_only=True)
            except Exception:
                # Fallback: full unpickling; only safe for checkpoints from a trusted source
                ckpt_data = torch.load(checkpoint_path, map_location=self.device, weights_only=False)

            state_dict_to_load = ckpt_data.get('state_dict', ckpt_data) # Check if state_dict is nested
            self.qvim_model.load_state_dict(state_dict_to_load, strict=False)
            if not hasattr(self.qvim_model, 'config'): self.qvim_model.config = self.config_runtime # Ensure config is set

        self.qvim_model = self.qvim_model.eval().to(self.device)
        self.config_for_audio_loading = self.qvim_model.config # Use config from loaded model
        print(f"MobileNetV3 model ready on device: {self.device}")

    def load_audio(self, file_path: str) -> torch.Tensor:
        audio, _ = librosa.load(file_path, sr=self.config_for_audio_loading.sample_rate, mono=True, duration=self.config_for_audio_loading.duration)
        fixed_length = int(self.config_for_audio_loading.sample_rate * self.config_for_audio_loading.duration)
        array = np.zeros(fixed_length, dtype=np.float32)
        current_len = len(audio)
        if current_len < fixed_length: array[:current_len] = audio
        else: array = audio[:fixed_length].astype(np.float32)
        return torch.from_numpy(array).unsqueeze(0).to(self.device)

    def embed_item(self, file_path: str) -> np.ndarray:
        with torch.no_grad(): return self.qvim_model.forward_reference(self.load_audio(file_path)).detach().cpu().numpy().squeeze()
    def embed_query(self, file_path: str) -> np.ndarray:
        with torch.no_grad(): return self.qvim_model.forward_imitation(self.load_audio(file_path)).detach().cpu().numpy().squeeze()

    def compute_similarities(self, items: dict[str, str], queries: dict[str, str]) -> dict[str, dict[str, float]]:
        scores = {q_id: {} for q_id in queries.keys()}
        if not items or not queries: return scores
        item_embs = {item_id: self.embed_item(item_path) for item_id, item_path in tqdm(items.items(), desc="Embedding Items (MobileNetV3)")}
        query_embs = {query_id: self.embed_query(query_path) for query_id, query_path in tqdm(queries.items(), desc="Embedding Queries (MobileNetV3)")}
        for q_name, q_emb in tqdm(query_embs.items(), desc="Calculating Similarities (MobileNetV3)"):
            for i_name, i_emb in item_embs.items():
                scores[q_name][i_name] = float(np.dot(i_emb.flatten(), q_emb.flatten()))
        return scores

class FineTunedModelWrapper(QVIMModel):
    def __init__(self,
                 finetuned_checkpoint_filename: str,
                 model_type_name: str,
                 config_for_qvima_init: argparse.Namespace,
                 resources_dir="resources"):
        super(FineTunedModelWrapper, self).__init__()
        self.device = _TARGET_DEVICE
        self.model_type_name = model_type_name
        self.config_runtime_for_qvima = config_for_qvima_init

        actual_finetuned_checkpoint_path = os.path.join(ROOT_PATH, resources_dir, finetuned_checkpoint_filename)
        if not os.path.exists(actual_finetuned_checkpoint_path):
            raise FileNotFoundError(f"{model_type_name} fine-tuned checkpoint '{finetuned_checkpoint_filename}' not found in '{os.path.join(ROOT_PATH, resources_dir)}'")

        print(f"Loading {model_type_name} fine-tuned model from: {actual_finetuned_checkpoint_path}")

        if not hasattr(self.config_runtime_for_qvima, 'model_type') or self.config_runtime_for_qvima.model_type != self.model_type_name:
            print(f"Warning: Forcing model_type in config_for_qvima_init for {model_type_name} to '{self.model_type_name}'. Original: {getattr(self.config_runtime_for_qvima, 'model_type', 'None')}")
        self.config_runtime_for_qvima.model_type = self.model_type_name
        self.config_runtime_for_qvima.num_gpus = (1 if self.device.type == 'cuda' else 0)


        if self.model_type_name == "passt":
            try: from hear21passt.base import load_model as passt_loader_check # noqa
            except ImportError: raise ImportError("hear21passt library required for PaSST model but not found.")
        elif self.model_type_name == "panns":
            try: from panns_inference import AudioTagging as PannsModelCheck # noqa
            except ImportError: raise ImportError("panns_inference library required for PANNs model but not found.")
        elif self.model_type_name == "beats":
            try: from speechbrain.lobes.models.beats import BEATs as SpeechBrainBEATsModel_check # noqa
            except ImportError: raise ImportError("SpeechBrain library required for BEATs model but not found.")

        print(f"Initializing QVIMModuleAlternate for {self.model_type_name} with effective config: {vars(self.config_runtime_for_qvima)}")

        self.qvim_alternate_model = QVIMModuleAlternate.load_from_checkpoint(
            checkpoint_path=actual_finetuned_checkpoint_path,
            map_location=self.device,
            config=self.config_runtime_for_qvima,
            strict=False
        )
        self.qvim_alternate_model = self.qvim_alternate_model.eval().to(self.device)
        self.config_for_audio_loading = self.qvim_alternate_model.hparams # Use hparams from loaded PL module
        print(f"{model_type_name} model ready on device: {self.device}")

    def load_audio(self, file_path: str) -> torch.Tensor:
        audio, _ = librosa.load(file_path, sr=self.config_for_audio_loading.sample_rate, mono=True, duration=self.config_for_audio_loading.duration)
        fixed_length = int(self.config_for_audio_loading.sample_rate * self.config_for_audio_loading.duration)
        array = np.zeros(fixed_length, dtype=np.float32)
        current_len = len(audio)
        if current_len < fixed_length: array[:current_len] = audio
        else: array = audio[:fixed_length].astype(np.float32)
        return torch.from_numpy(array).unsqueeze(0).to(self.device)

    def embed_item(self, file_path: str) -> np.ndarray:
        with torch.no_grad(): return self.qvim_alternate_model.forward_reference(self.load_audio(file_path)).detach().cpu().numpy().squeeze()
    def embed_query(self, file_path: str) -> np.ndarray:
        with torch.no_grad(): return self.qvim_alternate_model.forward_imitation(self.load_audio(file_path)).detach().cpu().numpy().squeeze()

    def compute_similarities(self, items: dict[str, str], queries: dict[str, str]) -> dict[str, dict[str, float]]:
        scores = {q_id: {} for q_id in queries.keys()}
        if not items or not queries: return scores
        item_embs = {item_id: self.embed_item(item_path) for item_id, item_path in tqdm(items.items(), desc=f"Embedding Items ({self.model_type_name})")}
        query_embs = {query_id: self.embed_query(query_path) for query_id, query_path in tqdm(queries.items(), desc=f"Embedding Queries ({self.model_type_name})")}
        for q_name, q_emb in tqdm(query_embs.items(), desc=f"Calculating Similarities ({self.model_type_name})"):
            for i_name, i_emb in item_embs.items():
                if q_emb is not None and i_emb is not None and q_emb.size > 0 and i_emb.size > 0:
                    scores[q_name][i_name] = float(np.dot(i_emb.flatten(), q_emb.flatten()))
                else:
                    scores[q_name][i_name] = -float('inf')
        return scores


class FusionEnsembleModel(QVIMModel):
    def __init__(self,
                 model_wrappers: dict,
                 fusion_strategy: str = "class_hybrid_best_or_softmax_weighted_avg",
                 global_model_weights: dict = None,
                 per_class_model_mrr: dict = None,
                 query_id_to_class_map: dict = None,
                 strong_baseline_model_name: str = "mobilenet",
                 class_best_fallback_threshold_factor: float = 0.9, # Fallback if best_class_mrr < factor * baseline_class_mrr
                 class_weighted_softmax_temp: float = 0.1,
                 hybrid_best_mrr_advantage_threshold: float = 0.05, # If best_mrr > second_best_mrr + threshold, use best
                 rrf_k: int = 60,
                 debug_classes: list = None # List of classes to print detailed debug info for
                 ):
        super(FusionEnsembleModel, self).__init__()
        self.model_wrappers = model_wrappers
        self.fusion_strategy = fusion_strategy
        self.device = _TARGET_DEVICE # For consistency, though primarily individual models handle their devices

        self.global_model_weights = global_model_weights if global_model_weights else {}
        self.normalized_global_weights = {}
        if self.model_wrappers:
            active_weights = {name: self.global_model_weights.get(name, 1.0) for name in self.model_wrappers.keys()}
            total_weight = sum(active_weights.values())
            if total_weight > 0:
                self.normalized_global_weights = {name: weight / total_weight for name, weight in active_weights.items()}
            else:
                num_m = len(self.model_wrappers)
                self.normalized_global_weights = {name: 1.0 / num_m if num_m > 0 else 0 for name in self.model_wrappers.keys()}
        if self.fusion_strategy == "weighted_average_scores":
            print(f"  Normalized global weights for 'weighted_average_scores': {self.normalized_global_weights}")

        self.per_class_model_mrr = per_class_model_mrr if per_class_model_mrr else {}
        self.query_id_to_class_map = query_id_to_class_map if query_id_to_class_map else {}
        self.strong_baseline_model_name = strong_baseline_model_name if strong_baseline_model_name in self.model_wrappers else (list(self.model_wrappers.keys())[0] if self.model_wrappers else None)
        self.class_best_fallback_threshold_factor = class_best_fallback_threshold_factor
        self.class_weighted_softmax_temp = class_weighted_softmax_temp if class_weighted_softmax_temp > 1e-6 else 0.001 # Avoid zero temp
        self.hybrid_best_mrr_advantage_threshold = hybrid_best_mrr_advantage_threshold
        self.rrf_k = rrf_k
        self.debug_classes = debug_classes if debug_classes else []

        if "class_" in self.fusion_strategy and (not self.per_class_model_mrr or not self.query_id_to_class_map):
            print("Warning: Class-aware strategy selected, but per_class_model_mrr or query_id_to_class_map is missing/empty. Strategy might default.")

        print(f"\nFusionEnsembleModel initialized:")
        print(f"  Models: {list(self.model_wrappers.keys())}")
        print(f"  Fusion strategy: {self.fusion_strategy}")
        if self.strong_baseline_model_name: print(f"  Strong baseline model: {self.strong_baseline_model_name}")


    def _get_single_model_embeddings(self, file_path: str, embed_method_name: str, model_name: str) -> np.ndarray:
        model_wrapper = self.model_wrappers.get(model_name)
        if not model_wrapper: return np.array([])
        if embed_method_name == "item": return model_wrapper.embed_item(file_path)
        elif embed_method_name == "query": return model_wrapper.embed_query(file_path)
        raise ValueError(f"Invalid embed_method_name: {embed_method_name}")

    def _get_concatenated_embeddings(self, file_path: str, embed_method_name: str) -> np.ndarray:
        all_embeddings_list = []
        model_order = ["mobilenet", "panns", "passt", "beats"]
        for model_name in model_order:
            if model_name in self.model_wrappers:
                emb = self._get_single_model_embeddings(file_path, embed_method_name, model_name)
                if emb.size > 0: all_embeddings_list.append(emb.flatten())
        if not all_embeddings_list: return np.array([])
        return np.concatenate(all_embeddings_list)

    def embed_item(self, file_path: str) -> np.ndarray:
        if self.fusion_strategy == "embedding_concat":
            return self._get_concatenated_embeddings(file_path, "item")
        raise NotImplementedError("embed_item on FusionEnsembleModel is only supported when fusion_strategy='embedding_concat'.")

    def embed_query(self, file_path: str) -> np.ndarray:
        if self.fusion_strategy == "embedding_concat":
            return self._get_concatenated_embeddings(file_path, "query")
        raise NotImplementedError("embed_query on FusionEnsembleModel is only supported when fusion_strategy='embedding_concat'.")


    def compute_similarities(self, items: dict[str, str], queries: dict[str, str]) -> dict[str, dict[str, float]]:
        final_scores = {query_id: {} for query_id in queries.keys()}
        if not self.model_wrappers or not items or not queries: return final_scores

        if self.fusion_strategy == "embedding_concat":
            print(f"\nFusionEnsembleModel: Using '{self.fusion_strategy}'.")
            item_embs_fused = {item_id: self._get_concatenated_embeddings(item_path, "item")
                                   for item_id, item_path in tqdm(items.items(), desc="Items (Concat)")}
            query_embs_fused = {query_id: self._get_concatenated_embeddings(query_path, "query")
                                    for query_id, query_path in tqdm(queries.items(), desc="Queries (Concat)")}
            for query_id, q_emb_fused in tqdm(query_embs_fused.items(), desc="Similarities (Concat)"):
                for item_id, i_emb_fused in item_embs_fused.items():
                    if q_emb_fused.size > 0 and i_emb_fused.size > 0:
                        final_scores[query_id][item_id] = float(np.dot(i_emb_fused.flatten(), q_emb_fused.flatten()))
                    else: final_scores[query_id][item_id] = -float('inf')
            return final_scores

        all_item_embeddings = {name: {} for name in self.model_wrappers.keys()}
        all_query_embeddings = {name: {} for name in self.model_wrappers.keys()}
        print("\nFusionEnsembleModel: Pre-calculating all individual model embeddings...")
        for model_name, model_wrapper in self.model_wrappers.items():
            print(f"  Embedding items with {model_name}...")
            all_item_embeddings[model_name] = {item_id: model_wrapper.embed_item(item_path)
                                               for item_id, item_path in tqdm(items.items(), desc=f"Items ({model_name})")}
            print(f"  Embedding queries with {model_name}...")
            all_query_embeddings[model_name] = {query_id: model_wrapper.embed_query(query_path)
                                                for query_id, query_path in tqdm(queries.items(), desc=f"Queries ({model_name})")}

        print(f"\nFusionEnsembleModel: Applying '{self.fusion_strategy}' fusion...")

        if self.fusion_strategy == "rrf":
            all_model_ranks_for_query = {qid: {} for qid in queries.keys()}
            for query_id in tqdm(queries.keys(), desc="RRF: Generating Ranks"):
                for model_name in self.model_wrappers.keys():
                    q_emb = all_query_embeddings[model_name].get(query_id)
                    if q_emb is None or q_emb.size == 0: continue
                    current_model_item_scores = {
                        item_id: np.dot(i_emb.flatten(), q_emb.flatten())
                        for item_id, i_emb in all_item_embeddings[model_name].items()
                        if i_emb is not None and i_emb.size > 0
                    }
                    if not current_model_item_scores: continue
                    sorted_items = sorted(current_model_item_scores.items(), key=lambda x: x[1], reverse=True)
                    all_model_ranks_for_query[query_id][model_name] = {item_id: rank for rank, (item_id, _) in enumerate(sorted_items)}

            for query_id in tqdm(queries.keys(), desc="RRF: Fusing Ranks"):
                for item_id in items.keys():
                    rrf_score_val = 0.0
                    for model_name in self.model_wrappers.keys():
                        rank = all_model_ranks_for_query.get(query_id, {}).get(model_name, {}).get(item_id)
                        if rank is not None: rrf_score_val += 1.0 / (self.rrf_k + rank)
                    final_scores[query_id][item_id] = float(rrf_score_val)
            return final_scores

        # --- Score-based and Class-Aware Fusions ---
        for query_id in tqdm(queries.keys(), desc=f"Fusing Scores ({self.fusion_strategy})"):
            query_class = self.query_id_to_class_map.get(query_id, None)

            use_single_model_name_for_query = None
            current_query_weights = self.normalized_global_weights # Default unless overridden by class-aware logic


            if "class_" in self.fusion_strategy:
                if not query_class:
                    if query_id not in getattr(self, '_warned_unknown_class', set()): # Avoid repeated warnings
                        print(f"Warning: Query '{query_id}' has no class map. Using global/default logic.")
                        if not hasattr(self, '_warned_unknown_class'): self._warned_unknown_class = set()
                        self._warned_unknown_class.add(query_id)
                    if self.fusion_strategy == "class_best_model_selection_with_fallback":
                        use_single_model_name_for_query = self.strong_baseline_model_name

                elif self.per_class_model_mrr:
                    if self.fusion_strategy == "class_best_model_selection_with_fallback":
                        best_mrr_for_class = -1.0
                        best_model_candidate = None
                        for model_name_iter in self.model_wrappers.keys():
                            mrr = self.per_class_model_mrr.get(model_name_iter, {}).get(query_class, -1.0)
                            if mrr > best_mrr_for_class:
                                best_mrr_for_class = mrr
                                best_model_candidate = model_name_iter

                        chosen_model = best_model_candidate
                        baseline_mrr_for_this_class = self.per_class_model_mrr.get(self.strong_baseline_model_name, {}).get(query_class, 0.0)

                        if best_model_candidate and \
                           self.strong_baseline_model_name and \
                           best_model_candidate != self.strong_baseline_model_name and \
                           best_mrr_for_class < (baseline_mrr_for_this_class * self.class_best_fallback_threshold_factor) and \
                           baseline_mrr_for_this_class > 0.001:
                            chosen_model = self.strong_baseline_model_name
                            if query_class in self.debug_classes: print(f"  Fallback: Q_Class '{query_class}', Q_ID '{query_id}'. InitialBest: {best_model_candidate}({best_mrr_for_class:.3f}) < Baseline {self.strong_baseline_model_name}({baseline_mrr_for_this_class:.3f}) * {self.class_best_fallback_threshold_factor}. Using Baseline.")
                        elif not best_model_candidate:
                            chosen_model = self.strong_baseline_model_name
                            if query_class in self.debug_classes: print(f"  NoMRRInfo: Q_Class '{query_class}', Q_ID '{query_id}'. No model had MRR. Defaulting to {chosen_model}.")

                        # Use the chosen model if it is loaded; otherwise default to the first available wrapper
                        if chosen_model in self.model_wrappers:
                            use_single_model_name_for_query = chosen_model
                        else:
                            use_single_model_name_for_query = next(iter(self.model_wrappers), None)
                        # Log the selection for debug classes when no fallback was applied
                        # (fallbacks and missing-MRR defaults are already logged above)
                        if query_class in self.debug_classes and best_model_candidate is not None and use_single_model_name_for_query == best_model_candidate:
                            print(f"  Selection: Q_Class '{query_class}', Q_ID '{query_id}'. Chosen: {use_single_model_name_for_query} (MRR {best_mrr_for_class:.3f}). BaselineMRR: {baseline_mrr_for_this_class:.3f}. No fallback.")

                    elif self.fusion_strategy == "class_weighted_average_scores_softmax" or \
                         (self.fusion_strategy == "class_hybrid_best_or_softmax_weighted_avg" and not use_single_model_name_for_query):

                        class_mrrs_raw = {name: data.get(query_class, 0.0001) for name, data in self.per_class_model_mrr.items() if name in self.model_wrappers}
                        if class_mrrs_raw:
                            m_names = list(class_mrrs_raw.keys())
                            m_mrrs = np.array([class_mrrs_raw[name] for name in m_names])
                            m_mrrs_scaled = m_mrrs / self.class_weighted_softmax_temp
                            m_mrrs_stabilized = m_mrrs_scaled - np.max(m_mrrs_scaled)
                            exp_mrrs_temp = np.exp(m_mrrs_stabilized)
                            sum_exp_mrrs_temp = np.sum(exp_mrrs_temp)
                            if sum_exp_mrrs_temp > 1e-9:
                                current_query_weights = {name: w for name, w in zip(m_names, exp_mrrs_temp / sum_exp_mrrs_temp)}
                        if query_class in self.debug_classes: print(f"  Weights (Softmax): Q_Class '{query_class}', Q_ID '{query_id}'. Weights: {current_query_weights}")

                    if self.fusion_strategy == "class_hybrid_best_or_softmax_weighted_avg":
                        if query_class and self.per_class_model_mrr:
                            class_model_mrrs_list = []
                            for model_name_iter in self.model_wrappers.keys():
                                mrr = self.per_class_model_mrr.get(model_name_iter, {}).get(query_class, -1.0)
                                if mrr >= 0: class_model_mrrs_list.append({"name": model_name_iter, "mrr": mrr})

                            if len(class_model_mrrs_list) > 0:
                                class_model_mrrs_list.sort(key=lambda x: x["mrr"], reverse=True)
                                best_cand = class_model_mrrs_list[0]
                                if len(class_model_mrrs_list) == 1 or \
                                   (best_cand["mrr"] > (class_model_mrrs_list[1]["mrr"] + self.hybrid_best_mrr_advantage_threshold)):
                                    use_single_model_name_for_query = best_cand["name"]
                                    if query_class in self.debug_classes: print(f"  HybridChoice: Q_Class '{query_class}', Q_ID '{query_id}'. Using SINGLE BEST: {use_single_model_name_for_query} (MRR {best_cand['mrr']:.3f})")

            for item_id in items.keys():
                fused_score = -float('inf')
                if use_single_model_name_for_query:
                    # Single-model path: score with the selected model's embeddings only
                    q_emb = all_query_embeddings.get(use_single_model_name_for_query, {}).get(query_id)
                    i_emb = all_item_embeddings.get(use_single_model_name_for_query, {}).get(item_id)
                    if q_emb is not None and i_emb is not None and q_emb.size > 0 and i_emb.size > 0:
                        fused_score = np.dot(i_emb.flatten(), q_emb.flatten())

                else:
                    # Multi-model path: collect each model's similarity and its weight for this query
                    scores_from_models_list = []
                    weights_for_scores_list = []
                    active_weights_source = current_query_weights

                    for model_name in self.model_wrappers.keys():
                        # .get() avoids a KeyError if a model produced no embeddings at all
                        q_emb = all_query_embeddings.get(model_name, {}).get(query_id)
                        i_emb = all_item_embeddings.get(model_name, {}).get(item_id)
                        if q_emb is not None and i_emb is not None and q_emb.size > 0 and i_emb.size > 0:
                            sim = np.dot(i_emb.flatten(), q_emb.flatten())
                            scores_from_models_list.append(sim)
                            weights_for_scores_list.append(active_weights_source.get(model_name, 0))

                    if scores_from_models_list:
                        if self.fusion_strategy == "max_score":
                            fused_score = np.max(scores_from_models_list)
                        elif self.fusion_strategy == "average_scores":
                            fused_score = np.mean(scores_from_models_list)
                        elif self.fusion_strategy in ["weighted_average_scores",
                                                      "class_weighted_average_scores_softmax",
                                                      "class_hybrid_best_or_softmax_weighted_avg"]:  # hybrid lands here when no single model was chosen
                            sum_effective_weights = sum(weights_for_scores_list)
                            if sum_effective_weights > 1e-9:
                                normalized_effective_weights = [w / sum_effective_weights for w in weights_for_scores_list]
                                fused_score = np.sum(np.array(scores_from_models_list) * np.array(normalized_effective_weights))
                            else:
                                fused_score = np.mean(scores_from_models_list)  # fallback if weights sum to zero

                        else:
                            raise NotImplementedError(f"Unhandled score fusion strategy: {self.fusion_strategy} in final scoring.")

                final_scores[query_id][item_id] = float(fused_score)
        return final_scores

Using device: cuda
Prepended to sys.path for local modules: /usr/src/Python-3.11.5/notebooks/qvim_baseline_rp/src
Successfully imported MobileNetOriginalQVIMModule and QVIMModuleAlternate.
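The per-class branches in the fusion code above are easier to follow outside the class. Below is a minimal, standalone sketch of two of them, assuming per-class MRRs arrive as a plain `{model_name: mrr}` dict. `softmax_weights` and `hybrid_single_best` are hypothetical helper names, not part of the notebook. Unlike the method above, the sketch does not handle missing MRRs; the notebook substitutes 0.0001 in the softmax path and skips the model in the hybrid path.

```python
import math


def softmax_weights(class_mrrs, temp=0.05):
    """Turn per-class MRRs {model: mrr} into fusion weights with a
    temperature-scaled softmax, subtracting the max for numerical stability."""
    names = list(class_mrrs)
    scaled = [class_mrrs[n] / temp for n in names]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return {n: e / total for n, e in zip(names, exps)}


def hybrid_single_best(class_mrrs, advantage=0.03):
    """Return the best model if its MRR beats the runner-up by more than
    `advantage`; otherwise None, meaning 'use weighted averaging instead'."""
    ranked = sorted(class_mrrs.items(), key=lambda kv: kv[1], reverse=True)
    if not ranked:
        return None
    if len(ranked) == 1 or ranked[0][1] > ranked[1][1] + advantage:
        return ranked[0][0]
    return None


# Hypothetical per-class MRRs, for illustration only
mrrs = {"mobilenet": 0.60, "passt": 0.40, "panns": 0.50, "beats": 0.55}
weights = softmax_weights(mrrs, temp=0.05)
choice = hybrid_single_best(mrrs, advantage=0.03)  # mobilenet leads beats by 0.05 > 0.03
```

With `temp=0.05`, raising a model's class MRR by 0.05 multiplies its unnormalized weight by e ≈ 2.72, so moderate MRR gaps already yield nearly one-hot weights, which is consistent with the weight logs printed during fusion.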


# Step 4: Create an Instance of your QVIMModel

In [9]:
"""
TODO: INSTANTIATE YOUR MODEL HERE.
"""

import pandas as pd
import os
import argparse
from glob import glob
import torch # Ensure torch is imported for _DEVICE_FOR_INIT_LOAD
import traceback # For more detailed error printing

# Determine the target device for initial checkpoint loading (e.g. hparams extraction)
# Model wrappers themselves will use the _TARGET_DEVICE from Cell 1 for actual model loading.
_DEVICE_FOR_INIT_LOAD = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device for initial checkpoint data loading (e.g., hparams): {_DEVICE_FOR_INIT_LOAD}")


RESOURCES_DIR = "resources"
MOBILENET_CKPT_FILENAME = "aug_baseline_latest.ckpt"
FINETUNED_PASST_CKPT_FILENAME = "passt_finetuned_1.ckpt"
FINETUNED_PANNS_CKPT_FILENAME = "panns_finetuned_2.ckpt"
FINETUNED_BEATS_CKPT_FILENAME = "beats_finetuned_1.ckpt"
ORIGINAL_BEATS_ITER3_CKPT_PATH_FOR_INIT = os.path.join(ROOT_PATH, RESOURCES_DIR, "BEATs_iter3.pt")
ORIGINAL_PANNS_CNN14_CKPT_PATH_FOR_INIT = os.path.join(ROOT_PATH, RESOURCES_DIR, "Cnn14_mAP=0.431.pth")


print("Loading per-class MRR data...")
per_class_model_performance = {}
mrr_files_info = {
    "mobilenet": "mrr_values_baseline_mobilenet.csv",
    "passt": "mrr_values_passt.csv",
    "panns": "mrr_values_panns_updated.csv",
    "beats": "mrr_values_beats.csv"
}

for model_name, csv_filename in mrr_files_info.items():
    csv_path = os.path.join(ROOT_PATH, RESOURCES_DIR, csv_filename)
    if not os.path.exists(csv_path):
        csv_path_alt = os.path.join(ROOT_PATH, csv_filename)
        if os.path.exists(csv_path_alt): csv_path = csv_path_alt
        else:
            print(f"Warning: MRR CSV '{csv_filename}' for '{model_name}' not found. It will not be used in class-aware strategies.")
            per_class_model_performance[model_name] = {}
            continue
    try:
        df_mrr = pd.read_csv(csv_path)
        if "Class" in df_mrr.columns and "MRR" in df_mrr.columns:
            per_class_model_performance[model_name] = df_mrr.set_index("Class")["MRR"].to_dict()
            print(f"  Loaded per-class MRR for {model_name} from {csv_path}")
        else:
            print(f"  Warning: MRR CSV '{csv_filename}' for {model_name} is missing the 'Class' or 'MRR' column.")
            per_class_model_performance[model_name] = {}
    except Exception as e:
        print(f"  Error loading MRR CSV for {model_name} from '{csv_filename}': {e}")
        per_class_model_performance[model_name] = {}


# Query ID -> class map
print("\nCreating query_id_to_class_map...")
query_id_to_class_map = {}
dev_complete_csv_name = "DEVUpdateComplete.csv"
queries_dir_name_in_dev = "Queries"
dev_updated_dataset_dir = DATA_PATH

if not os.path.exists(os.path.join(dev_updated_dataset_dir, dev_complete_csv_name)):

    path_option_1 = os.path.join(DATA_PATH, "DEVUpdatedDataset")
    path_option_2 = os.path.join(ROOT_PATH, "data", "DEVUpdatedDataset")
    if os.path.exists(os.path.join(path_option_1, dev_complete_csv_name)):
        dev_updated_dataset_dir = path_option_1
    elif os.path.exists(os.path.join(path_option_2, dev_complete_csv_name)):
        dev_updated_dataset_dir = path_option_2
    else:
        print(f"Warning: Could not reliably locate DEVUpdatedDataset directory via DATA_PATH='{DATA_PATH}' or common alternatives.")


queries_base_path_for_glob = os.path.join(dev_updated_dataset_dir, queries_dir_name_in_dev)
dev_csv_to_load = os.path.join(dev_updated_dataset_dir, dev_complete_csv_name)

try:
    if os.path.isdir(queries_base_path_for_glob):
        temp_query_files_list = list(glob(os.path.join(queries_base_path_for_glob, "**", "*.wav"), recursive=True))
        for file_path in temp_query_files_list:
            query_filename = os.path.basename(file_path)
            query_class = os.path.basename(os.path.dirname(file_path))
            query_id_to_class_map[query_filename] = query_class
        print(f"  Created query_id_to_class_map with {len(query_id_to_class_map)} entries from globbing '{queries_base_path_for_glob}'.")
    else:
        print(f"Warning: Queries directory '{queries_base_path_for_glob}' not found for globbing. Relying on CSV.")

    if os.path.exists(dev_csv_to_load): # Supplement/override with CSV
        df_dev = pd.read_csv(dev_csv_to_load, skiprows=1)
        df_dev_melted = df_dev.melt(
            id_vars=[col for col in df_dev.columns if "Query" not in col],
            value_vars=["Query 1", "Query 2", "Query 3"], var_name="Query Type", value_name="Query"
        ).dropna(subset=['Query'])
        for _, row in df_dev_melted.iterrows():
            if pd.notna(row["Query"]) and pd.notna(row["Class"]):
                query_id_to_class_map[str(row["Query"])] = str(row["Class"])
        print(f"  Updated query_id_to_class_map from '{dev_csv_to_load}'. Total unique query IDs mapped: {len(query_id_to_class_map)}.")

    if not query_id_to_class_map: print("  CRITICAL WARNING: query_id_to_class_map is empty!")
except Exception as e: print(f"  Error creating query_id_to_class_map: {e}.")



print("\nInstantiating individual model wrappers...")
mobilenet_wrapper, passt_wrapper, panns_wrapper, beats_wrapper = None, None, None, None


# MobileNetV3
try:
    mobilenet_wrapper = MobileNetV3Baseline(checkpoint_filename=MOBILENET_CKPT_FILENAME, resources_dir=RESOURCES_DIR)
    print("MobileNetV3Baseline instantiated.")
except Exception as e: print(f"Could not instantiate MobileNetV3Baseline: {e}\n{traceback.format_exc()}")

# PaSST
try:
    passt_ft_ckpt_path = os.path.join(ROOT_PATH, RESOURCES_DIR, FINETUNED_PASST_CKPT_FILENAME)
    if os.path.exists(passt_ft_ckpt_path):
        passt_ckpt_data = torch.load(passt_ft_ckpt_path, map_location=_DEVICE_FOR_INIT_LOAD)
        hparams_passt = passt_ckpt_data.get("hyper_parameters", passt_ckpt_data.get("hparams", {}))
        config_passt_runtime = argparse.Namespace(**hparams_passt) if isinstance(hparams_passt, dict) else hparams_passt
        config_passt_runtime.model_type = "passt"
        config_passt_runtime.passt_input_type = getattr(config_passt_runtime, 'passt_input_type', 'raw')

        passt_wrapper = FineTunedModelWrapper(FINETUNED_PASST_CKPT_FILENAME, "passt", config_passt_runtime, RESOURCES_DIR)
        print("FineTunedModelWrapper for PaSST instantiated.")
    else: print(f"Fine-tuned PaSST ckpt not found at {passt_ft_ckpt_path}, skipping.")
except Exception as e: print(f"Could not instantiate PaSST wrapper: {e}\n{traceback.format_exc()}")

# PANNs
try:
    panns_ft_ckpt_path = os.path.join(ROOT_PATH, RESOURCES_DIR, FINETUNED_PANNS_CKPT_FILENAME)
    if os.path.exists(panns_ft_ckpt_path):
        panns_ckpt_data = torch.load(panns_ft_ckpt_path, map_location=_DEVICE_FOR_INIT_LOAD)
        hparams_panns = panns_ckpt_data.get("hyper_parameters", panns_ckpt_data.get("hparams", {}))
        config_panns_runtime = argparse.Namespace(**hparams_panns) if isinstance(hparams_panns, dict) else hparams_panns
        config_panns_runtime.model_type = "panns"
        config_panns_runtime.panns_input_type = getattr(config_panns_runtime, 'panns_input_type', 'raw')
        config_panns_runtime.panns_checkpoint_path = ORIGINAL_PANNS_CNN14_CKPT_PATH_FOR_INIT
        if not os.path.exists(config_panns_runtime.panns_checkpoint_path): print(f"Warning: Original PANNs Cnn14 ckpt not found at '{config_panns_runtime.panns_checkpoint_path}'.")

        panns_wrapper = FineTunedModelWrapper(FINETUNED_PANNS_CKPT_FILENAME, "panns", config_panns_runtime, RESOURCES_DIR)
        print("FineTunedModelWrapper for PANNs instantiated.")
    else: print(f"Fine-tuned PANNs ckpt not found at {panns_ft_ckpt_path}, skipping.")
except Exception as e: print(f"Could not instantiate PANNs wrapper: {e}\n{traceback.format_exc()}")

# BEATs
try:
    beats_ft_ckpt_path = os.path.join(ROOT_PATH, RESOURCES_DIR, FINETUNED_BEATS_CKPT_FILENAME)
    if os.path.exists(beats_ft_ckpt_path):
        beats_ckpt_data = torch.load(beats_ft_ckpt_path, map_location=_DEVICE_FOR_INIT_LOAD)
        hparams_beats = beats_ckpt_data.get("hyper_parameters", beats_ckpt_data.get("hparams", {}))
        config_beats_runtime = argparse.Namespace(**hparams_beats) if isinstance(hparams_beats, dict) else hparams_beats
        config_beats_runtime.model_type = "beats"
        config_beats_runtime.beats_checkpoint_path = ORIGINAL_BEATS_ITER3_CKPT_PATH_FOR_INIT
        if not os.path.exists(config_beats_runtime.beats_checkpoint_path): print(f"Warning: Original BEATs iter3 ckpt not found at '{config_beats_runtime.beats_checkpoint_path}'.")
        config_beats_runtime.beats_savedir = getattr(config_beats_runtime, 'beats_savedir', os.path.join(ROOT_PATH, "pretrained_models_cache", "beats_submission"))

        beats_wrapper = FineTunedModelWrapper(FINETUNED_BEATS_CKPT_FILENAME, "beats", config_beats_runtime, RESOURCES_DIR)
        print("FineTunedModelWrapper for BEATs instantiated.")
    else: print(f"Fine-tuned BEATs ckpt not found at {beats_ft_ckpt_path}, skipping.")
except Exception as e: print(f"Could not instantiate BEATs wrapper: {e}\n{traceback.format_exc()}")


models_to_fuse = {}
if mobilenet_wrapper: models_to_fuse["mobilenet"] = mobilenet_wrapper
if passt_wrapper: models_to_fuse["passt"] = passt_wrapper
if panns_wrapper: models_to_fuse["panns"] = panns_wrapper
if beats_wrapper: models_to_fuse["beats"] = beats_wrapper

if not models_to_fuse:
    raise RuntimeError("CRITICAL: No models were successfully instantiated. Cannot proceed with fusion.")
print(f"\nModels available for fusion: {list(models_to_fuse.keys())}")

# Global ensemble weights (e.g. for 'weighted_average_scores'), restricted to the models that loaded successfully
global_ensemble_weights = { "mobilenet": 0.45, "panns": 0.10, "passt": 0.20, "beats": 0.25 }
active_global_weights = {name: weight for name, weight in global_ensemble_weights.items() if name in models_to_fuse}


# --- 6. CHOOSE FUSION STRATEGY and Instantiate the FusionEnsembleModel ---
CHOSEN_FUSION_STRATEGY = "class_hybrid_best_or_softmax_weighted_avg"
# CHOSEN_FUSION_STRATEGY = "class_best_model_selection_with_fallback"
# CHOSEN_FUSION_STRATEGY = "class_weighted_average_scores_softmax"
# CHOSEN_FUSION_STRATEGY = "weighted_average_scores" # Uses global_ensemble_weights


DEBUG_CLASSES_LIST = ["MachineGun", "SwordShing", "Snoring", "CarHorn"]

print(f"\nInstantiating FusionEnsembleModel with strategy: {CHOSEN_FUSION_STRATEGY}")
QBVIM_MODEL = FusionEnsembleModel(
    model_wrappers=models_to_fuse,
    fusion_strategy=CHOSEN_FUSION_STRATEGY,
    global_model_weights=active_global_weights,
    per_class_model_mrr=per_class_model_performance,
    query_id_to_class_map=query_id_to_class_map,
    strong_baseline_model_name="mobilenet",
    class_best_fallback_threshold_factor=0.85,
    class_weighted_softmax_temp=0.05,
    hybrid_best_mrr_advantage_threshold=0.03,
    rrf_k=60,
    debug_classes=DEBUG_CLASSES_LIST # Pass the debug list
)
print("FusionEnsembleModel instantiated as QBVIM_MODEL.")

Device for initial checkpoint data loading (e.g., hparams): cuda
Loading per-class MRR data...
  Loaded per-class MRR for mobilenet from ./resources/mrr_values_baseline_mobilenet.csv
  Loaded per-class MRR for passt from ./resources/mrr_values_passt.csv
  Loaded per-class MRR for panns from ./resources/mrr_values_panns_updated.csv
  Loaded per-class MRR for beats from ./resources/mrr_values_beats.csv

Creating query_id_to_class_map...
  Created query_id_to_class_map with 1021 entries from globbing './DEVUpdatedDataset/Queries'.
  Updated query_id_to_class_map from './DEVUpdatedDataset/DEVUpdateComplete.csv'. Total unique query IDs mapped: 1021.

Instantiating individual model wrappers...
Loading MobileNetV3 baseline from: ./resources/aug_baseline_latest.ckpt




MobileNetV3 model ready on device: cuda
MobileNetV3Baseline instantiated.
Loading passt fine-tuned model from: ./resources/passt_finetuned_1.ckpt
Initializing QVIMModuleAlternate for passt with effective config:


 Loading PASST TRAINED ON AUDISET 


PaSST(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        

passt model ready on device: cuda
FineTunedModelWrapper for PaSST instantiated.
Loading panns fine-tuned model from: ./resources/panns_finetuned_2.ckpt
Initializing QVIMModuleAlternate for panns with effective config:
Checkpoint path: ./resources/Cnn14_mAP=0.431.pth
Using CPU.
panns model ready on device: cuda
FineTunedModelWrapper for PANNs instantiated.
Loading beats fine-tuned model from: ./resources/beats_finetuned_1.ckpt
Initializing QVIMModuleAlternate for beats with effective config:


beats model ready on device: cuda
FineTunedModelWrapper for BEATs instantiated.

Models available for fusion: ['mobilenet', 'passt', 'panns', 'beats']

Instantiating FusionEnsembleModel with strategy: class_hybrid_best_or_softmax_weighted_avg

FusionEnsembleModel initialized:
  Models: ['mobilenet', 'passt', 'panns', 'beats']
  Fusion strategy: class_hybrid_best_or_softmax_weighted_avg
  Strong baseline model: mobilenet
FusionEnsembleModel instantiated as QBVIM_MODEL.
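The instantiation above sets `class_best_fallback_threshold_factor=0.85`, which is read by the `class_best_model_selection_with_fallback` strategy (not the one selected in this run). Below is a minimal sketch of that rule, assuming the same `{model_name: mrr}` input; `choose_with_fallback` is a hypothetical name. As far as we can tell from the code shown, when the baseline model is itself loaded and has an MRR for the class, the winner's MRR is never below the baseline's, so with a factor below 1 the fallback condition cannot hold; it only changes the choice for factors above 1, as the usage lines show.

```python
def choose_with_fallback(class_mrrs, baseline, factor=0.85, min_baseline_mrr=0.001):
    """Pick the model with the highest per-class MRR, but revert to `baseline`
    when the winner's MRR is below factor * (baseline's MRR)."""
    best_name, best_mrr = None, -1.0
    for name, mrr in class_mrrs.items():
        if mrr > best_mrr:
            best_name, best_mrr = name, mrr
    if best_name is None:
        return baseline  # no MRR information for this class
    base_mrr = class_mrrs.get(baseline, 0.0)
    if (best_name != baseline
            and best_mrr < base_mrr * factor
            and base_mrr > min_baseline_mrr):
        return baseline
    return best_name


demo = {"mobilenet": 0.50, "passt": 0.52}
pick_085 = choose_with_fallback(demo, "mobilenet", factor=0.85)  # passt: 0.52 >= 0.85 * 0.50
pick_110 = choose_with_fallback(demo, "mobilenet", factor=1.10)  # mobilenet: 0.52 < 1.10 * 0.50
```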


## Create Predictions

To run this, download the development dataset and store it in `DATA_PATH`.

In [10]:
"""
DO NOT MODIFY THIS BLOCK.
"""
from glob import glob

items_path = os.path.join(DATA_PATH, "Items")
item_files = pd.DataFrame({'path': list(glob(os.path.join(items_path, "**", "*.wav"), recursive=True))})
item_files["Class"] = item_files['path'].transform(lambda x: x.split(os.path.sep)[-2])
item_files["Items"] = item_files['path'].transform(lambda x: x.split(os.path.sep)[-1])

queries_path = os.path.join(DATA_PATH, "Queries")
query_files = pd.DataFrame({'path': list(glob(os.path.join(queries_path, "**", "*.wav"), recursive=True))})
query_files["Class"] = query_files['path'].transform(lambda x: x.split(os.path.sep)[-2])
query_files["Query"] = query_files['path'].transform(lambda x: x.split(os.path.sep)[-1])

print("Total item files:", len(item_files))
print("Total query files:", len(query_files))

if len(query_files) == 0 or len(item_files) == 0:
    raise ValueError("No query files found! Download the development dataset and store it in 'DATA_PATH'.")

Total item files: 123
Total query files: 1021
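The listing cell above derives each file's class from its parent directory name and identifies files by file name alone. Below is a minimal sketch of the same indexing with `pathlib`; `index_wavs` is a hypothetical helper. Because keys are bare file names, two files with the same name in different class folders collapse into a single entry. This would plausibly explain why 123 item files are listed above while only 122 items are embedded during similarity computation: the `items` dict passed to `compute_similarities` is keyed the same way.

```python
import os
import tempfile
from pathlib import Path


def index_wavs(root):
    """Map each .wav file name under `root` to (class, path), where the class
    is the name of the file's parent directory."""
    index = {}
    for path in sorted(Path(root).rglob("*.wav")):
        # Later files with the same name overwrite earlier ones
        index[path.name] = (path.parent.name, str(path))
    return index


# Throwaway tree: the same file name appears in two class folders
with tempfile.TemporaryDirectory() as tmp:
    for cls in ("CarHorn", "Snoring"):
        os.makedirs(os.path.join(tmp, cls))
        Path(tmp, cls, "take1.wav").touch()
    demo_index = index_wavs(tmp)  # two files on disk, one dict entry
```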


In [11]:
"""
DO NOT MODIFY THIS BLOCK.
"""

scores = QBVIM_MODEL.compute_similarities(
    items = {row["Items"]: row["path"] for i, row in item_files.iterrows()},
    queries = {row["Query"]: row["path"] for i, row in query_files.iterrows()}
)


FusionEnsembleModel: Pre-calculating all individual model embeddings...
  Embedding items with mobilenet...


Items (mobilenet): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 122/122 [00:06<00:00, 17.54it/s]


  Embedding queries with mobilenet...


Queries (mobilenet): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1021/1021 [00:15<00:00, 65.64it/s]


  Embedding items with passt...


Items (passt):   1%|█▎                                                                                                                                                                | 1/122 [00:00<00:15,  7.82it/s]

x torch.Size([1, 1, 128, 1000])
self.norm(x) torch.Size([1, 768, 12, 99])
 patch_embed :  torch.Size([1, 768, 12, 99])
 self.time_new_pos_embed.shape torch.Size([1, 768, 1, 99])
 self.freq_new_pos_embed.shape torch.Size([1, 768, 12, 1])
X flattened torch.Size([1, 1188, 768])
 self.new_pos_embed.shape torch.Size([1, 2, 768])
 self.cls_tokens.shape torch.Size([1, 1, 768])
 self.dist_token.shape torch.Size([1, 1, 768])
 final sequence x torch.Size([1, 1190, 768])
 after 12 atten blocks x torch.Size([1, 1190, 768])
forward_features torch.Size([1, 768])
head torch.Size([1, 527])


Items (passt): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 122/122 [00:12<00:00,  9.63it/s]


  Embedding queries with passt...


Queries (passt): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1021/1021 [01:47<00:00,  9.54it/s]


  Embedding items with panns...


Items (panns): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 122/122 [00:02<00:00, 40.89it/s]


  Embedding queries with panns...


Queries (panns): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1021/1021 [00:24<00:00, 42.50it/s]


  Embedding items with beats...


Items (beats): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 122/122 [00:14<00:00,  8.54it/s]


  Embedding queries with beats...


Queries (beats): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1021/1021 [01:58<00:00,  8.65it/s]



FusionEnsembleModel: Applying 'class_hybrid_best_or_softmax_weighted_avg' fusion...


Fusing Scores (class_hybrid_best_or_softmax_weighted_avg):   8%|█████████                                                                                                          | 80/1021 [00:00<00:01, 797.31it/s]

  Weights (Softmax): Q_Class 'CarHorn', Q_ID 'CarHorn1-02.wav'. Weights: {'mobilenet': np.float64(0.9620296413859097), 'passt': np.float64(6.230119091520719e-05), 'panns': np.float64(0.01396292025030787), 'beats': np.float64(0.023945137172867136)}
  HybridChoice: Q_Class 'CarHorn', Q_ID 'CarHorn1-02.wav'. Using SINGLE BEST: mobilenet (MRR 0.573)
  Weights (Softmax): Q_Class 'CarHorn', Q_ID 'CarHorn12-03.wav'. Weights: {'mobilenet': np.float64(0.9620296413859097), 'passt': np.float64(6.230119091520719e-05), 'panns': np.float64(0.01396292025030787), 'beats': np.float64(0.023945137172867136)}
  HybridChoice: Q_Class 'CarHorn', Q_ID 'CarHorn12-03.wav'. Using SINGLE BEST: mobilenet (MRR 0.573)
  Weights (Softmax): Q_Class 'CarHorn', Q_ID 'CarHorn12-01.wav'. Weights: {'mobilenet': np.float64(0.9620296413859097), 'passt': np.float64(6.230119091520719e-05), 'panns': np.float64(0.01396292025030787), 'beats': np.float64(0.023945137172867136)}
  HybridChoice: Q_Class 'CarHorn', Q_ID 'CarHorn12-01

Fusing Scores (class_hybrid_best_or_softmax_weighted_avg):  35%|███████████████████████████████████████                                                                          | 353/1021 [00:00<00:00, 1249.62it/s]

  Weights (Softmax): Q_Class 'CarHorn', Q_ID 'CarHorn2-03.wav'. Weights: {'mobilenet': np.float64(0.9620296413859097), 'passt': np.float64(6.230119091520719e-05), 'panns': np.float64(0.01396292025030787), 'beats': np.float64(0.023945137172867136)}
  HybridChoice: Q_Class 'CarHorn', Q_ID 'CarHorn2-03.wav'. Using SINGLE BEST: mobilenet (MRR 0.573)
  Weights (Softmax): Q_Class 'CarHorn', Q_ID 'CarHorn18-03.wav'. Weights: {'mobilenet': np.float64(0.9620296413859097), 'passt': np.float64(6.230119091520719e-05), 'panns': np.float64(0.01396292025030787), 'beats': np.float64(0.023945137172867136)}
  HybridChoice: Q_Class 'CarHorn', Q_ID 'CarHorn18-03.wav'. Using SINGLE BEST: mobilenet (MRR 0.573)
  Weights (Softmax): Q_Class 'CarHorn', Q_ID 'CarHorn2-02.wav'. Weights: {'mobilenet': np.float64(0.9620296413859097), 'passt': np.float64(6.230119091520719e-05), 'panns': np.float64(0.01396292025030787), 'beats': np.float64(0.023945137172867136)}
  HybridChoice: Q_Class 'CarHorn', Q_ID 'CarHorn2-02.w

  Weights (Softmax): Q_Class 'MachineGun', Q_ID 'MachineGun32-02.wav'. Weights: {'mobilenet': np.float64(0.9407880357384607), 'passt': np.float64(5.340119579770671e-05), 'panns': np.float64(0.02618791015881366), 'beats': np.float64(0.03297065290692801)}
  HybridChoice: Q_Class 'MachineGun', Q_ID 'MachineGun32-02.wav'. Using SINGLE BEST: mobilenet (MRR 0.584)
  Weights (Softmax): Q_Class 'MachineGun', Q_ID 'MachineGun21-01.wav'. Weights: {'mobilenet': np.float64(0.9407880357384607), 'passt': np.float64(5.340119579770671e-05), 'panns': np.float64(0.02618791015881366), 'beats': np.float64(0.03297065290692801)}
  HybridChoice: Q_Class 'MachineGun', Q_ID 'MachineGun21-01.wav'. Using SINGLE BEST: mobilenet (MRR 0.584)
  Weights (Softmax): Q_Class 'MachineGun', Q_ID 'MachineGun4-02.wav'. Weights: {'mobilenet': np.float64(0.9407880357384607), 'passt': np.float64(5.340119579770671e-05), 'panns': np.float64(0.02618791015881366), 'beats': np.float64(0.03297065290692801)}
  HybridChoice: Q_Class '

Fusing Scores (class_hybrid_best_or_softmax_weighted_avg):  96%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▌    | 981/1021 [00:00<00:00, 1283.05it/s]

  Weights (Softmax): Q_Class 'Snoring', Q_ID 'Snoring11-01.wav'. Weights: {'mobilenet': np.float64(0.003949526640758883), 'passt': np.float64(0.9957893447393588), 'panns': np.float64(8.508161845815462e-05), 'beats': np.float64(0.00017604700142405795)}
  HybridChoice: Q_Class 'Snoring', Q_ID 'Snoring11-01.wav'. Using SINGLE BEST: passt (MRR 0.503)
  Weights (Softmax): Q_Class 'Snoring', Q_ID 'Snoring1-03.wav'. Weights: {'mobilenet': np.float64(0.003949526640758883), 'passt': np.float64(0.9957893447393588), 'panns': np.float64(8.508161845815462e-05), 'beats': np.float64(0.00017604700142405795)}
  HybridChoice: Q_Class 'Snoring', Q_ID 'Snoring1-03.wav'. Using SINGLE BEST: passt (MRR 0.503)
  Weights (Softmax): Q_Class 'Snoring', Q_ID 'Snoring11-02.wav'. Weights: {'mobilenet': np.float64(0.003949526640758883), 'passt': np.float64(0.9957893447393588), 'panns': np.float64(8.508161845815462e-05), 'beats': np.float64(0.00017604700142405795)}
  HybridChoice: Q_Class 'Snoring', Q_ID 'Snoring11-0

Fusing Scores (class_hybrid_best_or_softmax_weighted_avg): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1021/1021 [00:00<00:00, 1067.38it/s]
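Every query whose log is visible above took the single-best path. When no model leads its runner-up by more than `hybrid_best_mrr_advantage_threshold`, the hybrid strategy instead scores each item as a weighted mean of per-model dot-product similarities, with the weights renormalized over the models that actually produced embeddings. Below is a minimal sketch of that step, assuming similarities and weights arrive as plain dicts; `fuse_scores` is a hypothetical name.

```python
def fuse_scores(sims, weights):
    """Weighted mean of per-model similarities.

    sims:    {model_name: similarity} for the models that produced embeddings
    weights: {model_name: weight}; models absent here get weight 0
    Returns -inf when no model produced a similarity.
    """
    if not sims:
        return float("-inf")
    ws = [weights.get(name, 0.0) for name in sims]
    total = sum(ws)
    if total > 1e-9:
        # Renormalize so the weights of the contributing models sum to 1
        return sum(s * w for s, w in zip(sims.values(), ws)) / total
    # All weights ~0: fall back to the unweighted mean
    return sum(sims.values()) / len(sims)


# "beats" has a weight but no similarity, so only mobilenet and passt are renormalized
demo_fused = fuse_scores({"mobilenet": 0.8, "passt": 0.2},
                         {"mobilenet": 0.75, "passt": 0.25, "beats": 0.5})  # 0.8*0.75 + 0.2*0.25
demo_mean = fuse_scores({"mobilenet": 0.8, "passt": 0.2}, {})  # plain mean
```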


In [12]:
"""
DO NOT MODIFY THIS BLOCK.
"""

import json
os.makedirs(os.path.join(ROOT_PATH, "output"), exist_ok=True)

with open(os.path.join(ROOT_PATH, "output", "similarities.json"), "w") as f:
    json.dump(scores, f)
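`fused_score` starts at `-float('inf')` and keeps that value when an embedding is missing; `json.dump` then writes the non-standard token `-Infinity`, which Python's own `json.load` accepts but strict JSON parsers reject. When debugging locally, a check along these lines can catch that before the file is written. `check_similarities` is a hypothetical helper; `allow_nan=False` is the standard `json` option that makes serialization fail on NaN or infinite floats.

```python
import json
import math


def check_similarities(scores):
    """Validate a {query_id: {item_id: score}} dict: every query must be scored
    against the same item set and every score must be finite. Returns the
    strict-JSON serialization on success, raises ValueError otherwise."""
    item_sets = {frozenset(row) for row in scores.values()}
    if len(item_sets) > 1:
        raise ValueError("queries were scored against different item sets")
    bad = [(q, i) for q, row in scores.items() for i, s in row.items() if not math.isfinite(s)]
    if bad:
        raise ValueError(f"{len(bad)} non-finite score(s), e.g. {bad[0]}")
    return json.dumps(scores, allow_nan=False)


demo_json = check_similarities({"q1.wav": {"a.wav": 0.5, "b.wav": 0.1},
                                "q2.wav": {"a.wav": 0.2, "b.wav": 0.7}})
try:
    check_similarities({"q1.wav": {"a.wav": float("-inf")}})
    demo_rejected = False
except ValueError:
    demo_rejected = True
```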


## Evaluation on the Public Development Set

Computes the Reciprocal Rank (RR) for each query in the public development set. The RR is $1/r_i$, the reciprocal of the rank $r_i$ of the correct item for query $i$. Submissions will be ranked by the Mean Reciprocal Rank (MRR) over the queries $Q$ of a hidden test set:

$$MRR = \frac{1}{\lvert Q \rvert} \sum_{i=1}^{\lvert Q\rvert} \frac{1}{r_i}$$
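As a quick worked example of the formula, with ranks counted from 1: if the correct items for three queries are ranked 1st, 2nd and 4th, then MRR = (1 + 1/2 + 1/4)/3 ≈ 0.583. The evaluation cell obtains 0-based positions from `Index.get_loc`, which is why it adds 1 before inverting. A minimal sketch; `mean_reciprocal_rank` is a hypothetical helper:

```python
def mean_reciprocal_rank(ranks):
    """MRR over 1-based ranks r_i of the correct item: the mean of 1 / r_i."""
    if not ranks:
        raise ValueError("need at least one query")
    return sum(1.0 / r for r in ranks) / len(ranks)


# Three queries whose correct item was ranked 1st, 2nd and 4th
demo_mrr = mean_reciprocal_rank([1, 2, 4])  # 1.75 / 3
```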

In [15]:
"""
DO NOT MODIFY THIS BLOCK.
"""
import json

with open(os.path.join(ROOT_PATH, "output", "similarities.json"), "r") as f:
    scores = json.load(f)

rankings = pd.DataFrame(dict(
    **{ "id": [i for i in list(scores.keys())]},
    **{ k: [v[k] for v in  scores.values() ] for k in scores[list(scores.keys())[0]].keys()}
)).set_index("id")

df = pd.read_csv(
    os.path.join(DATA_PATH, "DEVUpdateComplete.csv"), skiprows=1
)[['Label', 'Class', 'Items', 'Query 1', 'Query 2', 'Query 3']]

df = df.melt(
    id_vars=[col for col in df.columns if "Query" not in col],
    value_vars=["Query 1", "Query 2", "Query 3"],
    var_name="Query Type",
    value_name="Query"
).dropna()

# remove missing files
rankings = rankings.loc[df["Query"].unique(), df["Items"].unique()]

# load file with ground truth, i.e., query->item mapping; column 0 is item, colum 1 query
ground_truth = {row['Query']: [row['Items']] for i, row in df.iterrows()}

# find the rank of the correct item (real recording) for each query (imitation)
position_of_correct = {}
missing_query_files = []
for query, correct_item_list in ground_truth.items():
    # Skip if query is not in the DataFrame
    if query not in rankings.index:
        missing_query_files.append(query)
        continue
    # Get row and sort items by similarity in descending order
    sorted_items = rankings.loc[query].sort_values(ascending=False)
    # Find rank of correct items
    position_of_correct[query] = {
        item: sorted_items.index.get_loc(item) for item in correct_item_list if item in sorted_items.index
    }
    assert len(position_of_correct[query]) == len(correct_item_list), f"Missing item! Got: {list(position_of_correct[query].keys())}. Expected: {correct_item_list}"

# compute MRR
normalized_rrs = []
for query, items_ranks in position_of_correct.items():
    rr, irr = [], [] # summed RR and ideal RR
    for i, (item, rank) in enumerate(items_ranks.items()):
        rr.append(1 / (rank + 1))
        irr.append(1 / (i + 1))
    normalized_rrs.append(sum(rr) / sum(irr)) # normalize the query's RR by the ideal RR
mrr = np.mean(normalized_rrs)

print("Missing query files: ", len(missing_query_files))
print("Missing query names: ", missing_query_files)
print("MRR random:", round((1/ np.arange(1,len(df["Items"].unique()))).mean(), 4))
print("MRR       :", round(mrr, 4))

Missing query files:  0
Missing query names:  []
MRR random: 0.0447
MRR       : 0.3191
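The next cell relaxes the task so that every item of the query's class counts as correct. Suppose a query has $k$ relevant items at 1-based ranks $r_1,\dots,r_k$. The code sums $1/r_j$ over those items and divides by the ideal sum $\sum_{i=1}^{k} 1/i$. The ideal sum is what you get when the relevant items fill the top $k$ positions. Below is a minimal sketch of that per-query score, using a hypothetical helper named `normalized_rr`:

```python
def normalized_rr(ranks: list[int]) -> float:
    """Sum of 1/r over all relevant items, divided by its ideal value.

    `ranks` holds the 1-based ranks of every relevant item for one query.
    The ideal value assumes the k relevant items occupy ranks 1..k.
    """
    k = len(ranks)
    rr = sum(1.0 / r for r in ranks)
    ideal = sum(1.0 / i for i in range(1, k + 1))
    return rr / ideal

# A query whose three same-class items are ranked 1st, 3rd and 6th.
print(round(normalized_rr([1, 3, 6]), 4))  # 0.8182
```

When a query has a single correct item, the ideal sum is 1. The score then reduces to the plain RR used in the cell above.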


In [16]:
"""
DO NOT MODIFY THIS BLOCK.
"""

ground_truth = {
    row["Query"]: [row_["Items"] for j, row_ in df.drop_duplicates("Items").iterrows() if row_["Class"] == row["Class"]] for i, row in df.drop_duplicates("Query").iterrows()
}

position_of_correct = {}
missing_query_files = []
for query, correct_item_list in ground_truth.items():
    # Skip if query is not in the DataFrame
    if query not in rankings.index:
        missing_query_files.append(query)
        continue
    # Get row and sort items by similarity in descending order
    sorted_items = rankings.loc[query].sort_values(ascending=False)
    # Find rank of correct items
    position_of_correct[query] = {item: sorted_items.index.get_loc(item) for item in correct_item_list if item in sorted_items.index}
    assert len(position_of_correct[query]) == len(correct_item_list), f"Missing item!"

# compute MRR
normalized_rrs = []
for query, items_ranks in position_of_correct.items():
    rr, irr = [], [] # summed RR and ideal RR
    for i, (item, rank) in enumerate(items_ranks.items()):
        rr.append(1 / (rank + 1))
        irr.append(1 / (i + 1))
    normalized_rrs.append(sum(rr) / sum(irr)) # normalize the query's RR by the ideal RR
mrr = np.mean(normalized_rrs)

# compute NDCG
normalized_dcg = []
ndcgs = {}
for query, items_ranks in position_of_correct.items():
    dcg, idcg = [], [] # per-item DCG terms and ideal DCG terms
    for i, (item, rank) in enumerate(items_ranks.items()):
        dcg.append(1 / np.log2(rank + 2))
        idcg.append(1 / np.log2(i + 2))
    normalized_dcg.append(sum(dcg) / sum(idcg)) # normalize DCG by the ideal DCG
    ndcgs[query] = sum(dcg) / sum(idcg)
ndcg = np.mean(normalized_dcg)

print("Class-wise MRR :", round(mrr, 4))
print("Class-wise NDCG:", round(ndcg, 4))

Class-wise MRR : 0.5564
Class-wise NDCG: 0.6912
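The NDCG in the cell above uses binary relevance. Each same-class item has gain 1, discounted by $\log_2(r+1)$ for a 1-based rank $r$. The code writes this as `rank + 2` because `get_loc` returns 0-based positions. The resulting sum is then divided by the DCG of an ideal ranking. Below is a minimal sketch for one query, using a hypothetical helper named `ndcg_binary`:

```python
import math

def ndcg_binary(ranks: list[int]) -> float:
    """NDCG with binary relevance for one query (hypothetical helper).

    `ranks` are the 1-based ranks of the relevant items. Each relevant item
    contributes a gain of 1 divided by log2(rank + 1).
    """
    dcg = sum(1.0 / math.log2(r + 1) for r in ranks)
    # Ideal DCG: the same number of relevant items placed at ranks 1..k.
    idcg = sum(1.0 / math.log2(i + 1) for i in range(1, len(ranks) + 1))
    return dcg / idcg

# Three relevant items ranked 1st, 3rd and 6th.
print(round(ndcg_binary([1, 3, 6]), 4))  # 0.8711
```

The logarithmic discount falls off more slowly than $1/r$, so relevant items further down the list still contribute noticeably. This is one reason the class-wise NDCG can come out higher than the class-wise MRR for the same rankings.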
