In [None]:
from torch.utils.data import DataLoader
from clap import Clap
from clap.evaluate import eval_retrieval
from clap.datasets import ClapDataset
from clap.utils import get_target_device, load_clap_config

# Compute CLAP retrieval metrics on the AudioCaps and Clotho test splits

In [None]:
# Choose which model to evaluate: the audio/text encoder pair plus the
# config revision and the checkpoint revision to load for that pair.
audio_encoder, text_encoder = "htsat-tiny", "gpt2"
cfg_version, ckpt_version = 1, 3

# Load the audio-processing config for this encoder pair/revision, then pick
# the device everything below will run on.
config = load_clap_config(
    audio_encoder=audio_encoder,
    text_encoder=text_encoder,
    version=cfg_version,
)
device = get_target_device()

In [None]:
# Initialize evaluation datasets and dataloaders.
# Loaders are not shuffled, so every run iterates the test data in the same order.
EVAL_BATCH_SIZE = 64


def make_test_split(dataset_name: str, clap_config, batch_size: int = EVAL_BATCH_SIZE):
    """Build the "test" split of one dataset and an unshuffled DataLoader over it.

    Args:
        dataset_name: Dataset identifier understood by ClapDataset (e.g. "AudioCaps").
        clap_config: Config object returned by `load_clap_config`.
        batch_size: Number of samples per evaluation batch.

    Returns:
        A `(dataset, loader)` tuple.
    """
    dataset = ClapDataset(config=clap_config, datasets=[dataset_name], kinds=["test"])
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    return dataset, loader


audio_caps_eval_dataset, audio_caps_loader = make_test_split("AudioCaps", config)
clotho_eval_dataset, clotho_loader = make_test_split("Clotho", config)

In [None]:
# Load the trained model, move it to the target device and switch it to inference mode.
# NOTE(review): eval_retrieval may already call .eval() itself. Calling it here is
# harmless in that case, assuming Clap is a torch.nn.Module (implied by `.to(device)`),
# and it makes sure train-only behaviour such as dropout is off during evaluation.
clap = (
    Clap.from_ckpt(
        audio_encoder=audio_encoder,
        text_encoder=text_encoder,
        cfg_version=cfg_version,
        ckpt_version=ckpt_version,
    )
    .to(device)
    .eval()
)

In [None]:
# Evaluate the model on each test loader in turn (AudioCaps first, then Clotho).
audio_caps_metrics, clotho_metrics = (
    eval_retrieval(model=clap, test_loader=loader)
    for loader in (audio_caps_loader, clotho_loader)
)

In [None]:
# Print each AudioCaps metric as a left-aligned name column followed by a 4-decimal score.
print("Audio Caps:\n")
for metric, value in audio_caps_metrics.items():
    line = "{:14}: {:.4f}".format(metric, value)
    print(line)

In [None]:
# Print each Clotho metric as a left-aligned name column followed by a 4-decimal score.
print("Clotho:\n")
for metric, value in clotho_metrics.items():
    line = "{:14}: {:.4f}".format(metric, value)
    print(line)