In [None]:
import os

from torch.utils.data import DataLoader

from clap import Clap
from clap.datasets import ClapDataset
from clap.evaluate import eval_retrieval
from clap.utils import get_target_device, load_clap_config

# Evaluate retrieval performance of CLAP on AudioCaps and Clotho

In [None]:
# Load config for audio processing and get target device.
# The project root defaults to the original author's checkout; set the
# CLAP_PROJECT_ROOT environment variable to run this notebook from another
# location without editing the cell.
PROJECT_ROOT = os.environ.get(
    "CLAP_PROJECT_ROOT",
    r"C:\Users\leon\Documents\ML_Projects\Custom-CLAP",
)
config_path = os.path.join(
    PROJECT_ROOT, "clap", "configs", "clap_cnn14_distilroberta-base_vTestDistillation.yml"
)
# NOTE(review): "distilroberta-bas" (no trailing "e") is reproduced exactly as
# it was written here -- confirm it matches the checkpoint file name on disk.
ckpt_path = os.path.join(
    PROJECT_ROOT, "clap", "checkpoints", "clap_cnn14_distilroberta-bas_vStage1_Clotho_Test.ckpt"
)

# Fail early with an actionable message instead of surfacing a missing
# checkpoint only after the datasets have already been built.
for _required_file in (config_path, ckpt_path):
    if not os.path.isfile(_required_file):
        raise FileNotFoundError(
            f"{_required_file} not found; set CLAP_PROJECT_ROOT to your Custom-CLAP checkout."
        )

config = load_clap_config(config_path=config_path)
device = get_target_device()

In [None]:
# Initialize evaluation datasets and dataloaders.
# Dataset folders are resolved relative to the package directory that holds the
# config file (<package>/configs/<file>.yml -> <package>/datasets/<name>), so
# the user-specific root is only spelled out once, in `config_path`.
datasets_root = os.path.join(os.path.dirname(os.path.dirname(config_path)), "datasets")

# Both loaders reuse the batch size stored under the config's "training" section.
eval_batch_size = config["training"]["batch_size"]

audio_caps_eval_dataset = ClapDataset(
    config=config,
    datasets=["AudioCaps"],
    kinds=["test"],
    datasets_paths=[os.path.join(datasets_root, "audiocaps")],
)
audio_caps_loader = DataLoader(audio_caps_eval_dataset, batch_size=eval_batch_size, shuffle=False)

clotho_eval_dataset = ClapDataset(
    config=config,
    datasets=["Clotho"],
    kinds=["test"],
    datasets_paths=[os.path.join(datasets_root, "clotho")],
)
clotho_loader = DataLoader(clotho_eval_dataset, batch_size=eval_batch_size, shuffle=False)

In [None]:
# Load trained model, move it to the target device and put it in inference mode.
# NOTE(review): assumes Clap is a torch.nn.Module (it already exposes
# `.to(device)`); `.eval()` returns the module itself and is idempotent, so it
# is harmless if `from_ckpt` or `eval_retrieval` already switch modes -- confirm.
clap = Clap.from_ckpt(config_path=config_path, ckpt_path=ckpt_path).to(device).eval()

In [None]:
# Compute retrieval metrics for the AudioCaps loader
audio_caps_metrics = eval_retrieval(
    model=clap,
    eval_loader=audio_caps_loader,
)

In [None]:
# Compute retrieval metrics for the Clotho loader
clotho_metrics = eval_retrieval(
    model=clap,
    eval_loader=clotho_loader,
)

In [None]:
# Report the AudioCaps scores: a header, then one "name: value" line per metric
# (name padded to 14 characters, value shown with 4 decimals).
print("Audio Caps:\n")
audio_caps_rows = audio_caps_metrics.items()
for name, score in audio_caps_rows:
    print(f"{name:14}: {score:.4f}")

In [None]:
# Report the Clotho scores: a header, then one "name: value" line per metric
# (name padded to 14 characters, value shown with 4 decimals).
print("Clotho:\n")
clotho_rows = clotho_metrics.items()
for name, score in clotho_rows:
    print(f"{name:14}: {score:.4f}")