In [None]:
import torch
from utils import INPUT_DATA_DIR
import torchaudio
import json
from transformers import ClapModel, ClapProcessor
from typing import List, Dict

# Runtime configuration shared by the cells below.
# CUDA device (index 1) that the model and all input tensors are moved to.
device = "cuda:1"
# Checkpoint identifier handed to both ClapProcessor and ClapModel `from_pretrained`.
model_name = "laion/clap-htsat-fused"
# NOTE(review): not read by any other line visible in this notebook — confirm
# whether it is still needed (e.g. as a text truncation length).
max_tokens = 200
# Root directory that the audio locations from the feature index are resolved against.
base_dir = INPUT_DATA_DIR / "music-bench" / "datashare-instruments"
# NOTE(review): this module-level value is not read by any other visible line
# (`load_audio` has its own `model_sr` parameter), and the audio-embedding helper
# declares 48 kHz to the processor — confirm which rate is intended.
model_sr = 32000

In [None]:
# Fetch the CLAP weights and place them on the configured device, then load
# the processor that belongs to the same checkpoint.
model = ClapModel.from_pretrained(model_name)
model = model.to(device)
processor = ClapProcessor.from_pretrained(model_name)

# Inference only: switch the module to eval mode. Kept as the cell's last
# expression so the notebook still displays the model summary.
model.eval()

In [None]:
# Load the two inputs of the analysis from the shared "interp" directory.
# `feat` is indexed by feature key and iterated as a collection of audio
# locations further down; `desc` is indexed by the same keys and provides an
# "overall_summary" text per key.
# JSON text is UTF-8 by specification, so decode it explicitly instead of
# depending on the platform's locale encoding.
with open(INPUT_DATA_DIR / "interp" / "features_grouped.json", "r", encoding="utf-8") as fh:
    feat = json.load(fh)
with open(INPUT_DATA_DIR / "interp" / "final_descriptions.json", "r", encoding="utf-8") as fh:
    desc = json.load(fh)

In [None]:
def load_audio(base_dir, location, model_sr):
    """Load one audio clip and return its first channel at ``model_sr`` Hz.

    Args:
        base_dir: Directory (path object supporting ``/``) that ``location``
            is relative to.
        location: Relative path of the clip. A ``.wav`` extension is swapped
            for ``.mp3`` before loading — presumably the index lists ``.wav``
            names while the files on disk are MP3; confirm against the dataset.
        model_sr: Target sample rate in Hz.

    Returns:
        1-D numpy array holding the samples of channel 0 only; any further
        channels are discarded, not down-mixed.
    """
    audio_path = base_dir / location
    # Swap only the file extension. A plain ``str.replace`` over the whole path
    # would also rewrite a ".wav" that appears in a directory name or in the
    # middle of the file name.
    if audio_path.suffix == ".wav":
        audio_path = audio_path.with_suffix(".mp3")

    waveform, sr = torchaudio.load(str(audio_path))
    # Only build a resampler when the rates actually differ.
    if sr != model_sr:
        waveform = torchaudio.transforms.Resample(sr, model_sr)(waveform)
    # torchaudio returns (channels, frames); index 0 selects the first channel.
    return waveform.numpy()[0]


@torch.no_grad()
def embed_text(processor: ClapProcessor, model: ClapModel, text: str, device: torch.device):
    """Embed a single text with CLAP and return its unit-length vector.

    Args:
        processor: Processor used to tokenize ``text``.
        model: CLAP model providing ``get_text_features``; expected to already
            live on ``device``.
        text: The description to embed.
        device: Device the tokenized inputs are moved to (this notebook passes
            a device string such as ``"cuda:1"``).

    Returns:
        CPU tensor of shape ``(D,)``, L2-normalised along the feature axis.
    """
    # truncation=True clips over-long descriptions to the tokenizer's maximum
    # length instead of letting them overflow the text encoder's position range.
    inputs = processor(
        text=[text],
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(device)
    emb = model.get_text_features(**inputs)
    # Normalise so that a dot product with another unit vector is a cosine similarity.
    emb = emb / emb.norm(dim=-1, keepdim=True)
    return emb.squeeze(0).cpu()  # (D,)


@torch.no_grad()
def embed_audios(
    processor: ClapProcessor,
    model: ClapModel,
    audio_tensors: List[torch.Tensor],
    device: torch.device,
    batch_size: int,
    sampling_rate: int = 48_000,
):
    """Embed a list of waveforms with CLAP in mini-batches.

    Args:
        processor: Processor that turns raw waveforms into model inputs.
        model: CLAP model providing ``get_audio_features``; expected to already
            live on ``device``.
        audio_tensors: One waveform per clip. The caller in this notebook
            passes 1-D numpy arrays produced by ``load_audio``.
        device: Device the processed inputs are moved to.
        batch_size: Number of clips processed per forward pass.
        sampling_rate: Sample rate in Hz of the waveforms in ``audio_tensors``.
            The waveforms must really be at this rate; the value is forwarded
            to the processor unchanged. Defaults to 48 kHz, the value that was
            previously hard-coded here.

    Returns:
        CPU tensor with one L2-normalised embedding row per input clip, in
        input order.

    Note:
        An empty ``audio_tensors`` list reaches ``torch.cat`` with no tensors,
        which is expected to raise — callers should pass at least one clip.
    """
    embs = []
    for i in range(0, len(audio_tensors), batch_size):
        batch = audio_tensors[i : i + batch_size]
        inputs = processor(
            audios=batch,
            sampling_rate=sampling_rate,
            return_tensors="pt",
            padding=True,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        a_emb = model.get_audio_features(**inputs)
        # Unit-normalise each row so dot products are cosine similarities.
        a_emb = a_emb / a_emb.norm(dim=-1, keepdim=True)
        # Move results off the accelerator right away to keep device memory flat.
        embs.append(a_emb.cpu())
    return torch.cat(embs, dim=0)


# Score every feature key by how well its audio clips match its text summary:
# the mean cosine similarity between the summary embedding and each clip
# embedding.

# Resample to the rate the CLAP feature extractor itself is configured for.
# Loading at 32 kHz while the embedding helper declares 48 kHz to the processor
# makes the model interpret every clip as sped up by 1.5x.
clap_sr = processor.feature_extractor.sampling_rate
audio_batch_size = 10

results: Dict[str, float] = {}
for key, details in desc.items():
    text_emb = embed_text(processor, model, details["overall_summary"], device)

    audios = [load_audio(base_dir, p, clap_sr) for p in feat[key]]
    audio_embs = embed_audios(processor, model, audios, device, audio_batch_size)

    # Both sides are unit-normalised, so the matrix-vector product yields one
    # cosine similarity per clip; the score is their average.
    results[key] = (audio_embs @ text_emb).mean().item()

In [None]:
# Persist the per-key similarity scores as pretty-printed JSON next to the
# input files, then echo the mapping as the cell output.
ranked_path = INPUT_DATA_DIR / "interp" / "features_ranked.json"
with open(ranked_path, "w") as fh:
    fh.write(json.dumps(results, indent=4))
results