In [2]:
"""Evaluate the models on the MIR task of cross-modal retrieval"""

import torch
from torch.utils.data import DataLoader
from transformers import ClapProcessor, ClapModel
import torch.nn.functional as F

import sys
sys.path.append('..')
from Dataset.DALIDataset import DALIDataset
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import json
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_SAVE_FOLDER = "./model"
SEED = 42
torch.manual_seed(SEED)

<torch._C.Generator at 0x208a886d330>

In [4]:
processor = ClapProcessor.from_pretrained("laion/larger_clap_general")

In [5]:
model = ClapModel.from_pretrained("laion/larger_clap_general").to(device)
# model.load_state_dict(torch.load(os.path.join(MODEL_SAVE_FOLDER, "best_model.pt")))

In [6]:
def collate_fn(batch):
    text, audio_data = zip(*batch)
    waveforms, sample_rates = zip(*audio_data)
    max_len = max(w.shape[1] for w in waveforms)
    padded_waveforms = torch.stack([torch.nn.functional.pad(w, (0, max_len - w.shape[1])) for w in waveforms])
    return text, padded_waveforms, torch.tensor(sample_rates)

In [7]:
# Load the dataset
batch_size = 8 # paper had 768
dataset = DALIDataset(use_sentiment=False)
dataset_len = len(dataset) #[0.01, 0.99]
train_set, val_set, test_set = torch.utils.data.random_split(dataset, [8, 16, dataset_len - 24], generator=torch.Generator().manual_seed(SEED))
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)




In [8]:
model.eval()
tqdm = tqdm(test_loader)
batch_size = 8

ks = [2 ** i for i in range(batch_size) if 2 ** i < batch_size]
top_k_accs = {k: [] for k in ks}
kl_divs = []


for batch in tqdm:
    lyrics, audio, sample_rates = batch
    audio = list(audio.squeeze().numpy())
    inputs = processor(text=lyrics, audios=audio, return_tensors="pt", padding=True, sampling_rate=48000)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)

    # get the classifications of the audio across the text
    audio_distribution = F.log_softmax(outputs.logits_per_audio, dim=-1)

    # get the top k accuracy
    for k in ks:
        top_k = torch.topk(audio_distribution, k, dim=-1).indices
        top_k_acc = (top_k == torch.arange(batch_size).unsqueeze(-1)).any(dim=-1).sum().item() / batch_size
        top_k_accs[k].append(top_k_acc)
    

    # get the classification of the text across the other text
    text_embeds = outputs.text_embeds
    text_distribution = F.log_softmax(F.cosine_similarity(text_embeds, text_embeds, dim=-1), dim=-1)

    # calculate the kl divergence of the audio distribution from the text distribution
    kl_div = F.kl_div(audio_distribution, text_distribution, reduction="batchmean", log_target=True)
    kl_divs.append(kl_div)

    # set the progress bar description
    tqdm.set_description(f"KL Div: {np.mean(kl_divs):.2f}, Top 1 Acc: {np.mean(top_k_accs[1]):.2f}")

    if len(kl_divs) > 1:
        break


Performing cross-modal retrieval


KL Div: 2.06, Top 1 Acc: 0.19:   0%|          | 1/6708 [00:24<46:23:14, 24.90s/it]


In [9]:
np.mean(top_k_accs[1])

NameError: name 'np' is not defined

In [67]:
top_k_accs, total_kl_div

({1: 0.25, 2: 0.375, 4: 0.5}, 1.8358449935913086)

In [65]:
F.kl_div(torch.log(audio_distribution), text_distribution, reduction="batchmean")

tensor(1.8358)

In [61]:
total_kl_div

-2.204441547393799

In [52]:
k = 2

top_k = torch.topk(audio_classifications, k, dim=-1).indices
top_k_acc = (top_k == torch.arange(k).unsqueeze(-1)).sum().item()

tensor([[5, 1],
        [1, 7],
        [1, 4],
        [4, 1],
        [4, 0],
        [1, 5],
        [5, 1],
        [5, 4]])

0.375

In [45]:
for k in ks:
    
    

RuntimeError: The size of tensor a (8) must match the size of tensor b (2) at non-singleton dimension 0