In [None]:
from torch.utils.data import DataLoader
from clap import Clap
from clap.evaluate import eval_zero_shot_classification
from clap.datasets import ClapDataset
from clap.utils import get_target_device, load_clap_config

# Evaluate Zero-Shot audio classification performance on the ESC-50 dataset

In [None]:
# Experiment selection: the encoder pair to evaluate, plus the config and
# checkpoint versions that later cells pass on to the loaders.
audio_encoder = "htsat-tiny"
text_encoder = "gpt2"
cfg_version = 1
ckpt_version = 3

# Load the config (used for audio processing) for the chosen encoder pair,
# then resolve the device the model will be moved to.
config = load_clap_config(
    audio_encoder=audio_encoder,
    text_encoder=text_encoder,
    version=cfg_version,
)
device = get_target_device()

In [None]:
# Build the ESC50 dataset from its "train", "val" and "test" kinds and wrap it
# in a DataLoader that yields batches of 64 without shuffling.
esc50_dataset = ClapDataset(
    config=config,
    datasets=["ESC50"],
    kinds=["train", "val", "test"],
)
esc50_dataloader = DataLoader(
    esc50_dataset,
    shuffle=False,
    batch_size=64,
)

# The class mapping loader returns a pair; only the first element
# (class_to_idx) is used in this notebook.
class_to_idx, _ = ClapDataset.load_esc50_class_mapping()

In [None]:
# Restore the pretrained CLAP model for the selected encoders and versions,
# then move it onto the target device.
# NOTE(review): whether from_ckpt leaves the model in eval mode is not visible
# here -- confirm before relying on the reported accuracy.
model = Clap.from_ckpt(
    audio_encoder=audio_encoder,
    text_encoder=text_encoder,
    cfg_version=cfg_version,
    ckpt_version=ckpt_version,
).to(device)

In [None]:
# Build one text prompt per key of class_to_idx by appending the key to a
# fixed prefix.
# NOTE(review): presumably the keys are human-readable class names; if they
# contain separators such as underscores, the prompts will too -- verify.
prompt = "This is the sound of "
prompts = [prompt + class_name for class_name in class_to_idx.keys()]

In [None]:
# Compute text embeddings for the class prompts with the loaded model.
# The result is passed to the zero-shot evaluation as `class_embeddings`.
# NOTE(review): assumes get_text_embeddings handles tokenization and device
# placement internally, since raw strings are passed in -- TODO confirm.
class_embeddings = model.get_text_embeddings(prompts)

In [None]:
# Run the zero-shot evaluation over the ESC50 loader, using the prompt
# embeddings as the class representations, and keep the resulting accuracy.
acc = eval_zero_shot_classification(
    model=model,
    eval_loader=esc50_dataloader,
    class_embeddings=class_embeddings,
)

In [None]:
# Report the zero-shot accuracy obtained on ESC50
print(f"ESC50 Accuracy: {acc}")