In [None]:
import numpy as np
import faiss
import msclap
from msclap import CLAP
from tools.project import INPUT_PATH, LOGS_PATH, OUTPUT_PATH, MODELS_PATH
from toolz import partition_all, concat
import os
import torch
from audioldm_eval.metrics.fad import FrechetAudioDistance

# Torch device string for this notebook.
# NOTE(review): no cell visible here reads DEVICE (the CLAP model is configured
# via `use_cuda`, and `clap_sim` spells out 'cuda') — confirm it is still needed.
DEVICE = 'cuda'

In [None]:
# Shared CLAP model (2023 checkpoint, CUDA enabled); `get_dir_embeds` and
# `clap_sim` below read it as a module-level global.
clap_model = CLAP(version='2023', use_cuda=True)

In [None]:
@torch.no_grad()
def get_dir_embeds(dir_path: str, recalc=True, batch_size: int = 20):
    """Compute CLAP audio embeddings for every file in a directory.

    The result is saved next to the directory as ``clap_feature_<dirname>.pt``
    and can be reloaded from there on later calls.

    Args:
        dir_path: Directory whose files are embedded. A trailing separator is
            tolerated; the path is normalised before the cache name is derived.
        recalc: When true (the default) the cache is ignored and overwritten;
            pass ``False`` to reuse an existing cache file.
        batch_size: Number of files handed to the CLAP model per call.

    Returns:
        A CPU tensor with one row per file, ordered by sorted file name.

    Raises:
        ValueError: If the directory contains no regular files.
    """
    # Normalise first: with a trailing slash ``basename`` would be empty and the
    # cache file would end up *inside* the directory being scanned.
    dir_path = os.path.normpath(dir_path)
    dir_name = os.path.basename(dir_path)
    cache_path = os.path.join(os.path.dirname(dir_path), f'clap_feature_{dir_name}.pt')
    if not recalc and os.path.exists(cache_path):
        return torch.load(cache_path)

    # ``os.listdir`` returns entries in arbitrary order; sorting keeps the row
    # order (and therefore the cache) reproducible. Sub-directories are skipped
    # because only files can be decoded as audio.
    candidates = (os.path.join(dir_path, name) for name in os.listdir(dir_path))
    file_paths = sorted(p for p in candidates if os.path.isfile(p))
    if not file_paths:
        raise ValueError(f'no files to embed in {dir_path!r}')

    per_batch = [
        clap_model.get_audio_embeddings(list(batch))
        for batch in partition_all(batch_size, file_paths)
    ]
    # Flatten the per-batch outputs into single rows before stacking.
    embeds = torch.stack(list(concat(per_batch))).detach().cpu()
    torch.save(embeds, cache_path)
    return embeds


# Smoke-run on the 'fad' reference directory and display the embeddings as a
# NumPy array. `recalc` defaults to True, so this recomputes and rewrites the cache.
get_dir_embeds(INPUT_PATH("textual-inversion-v3", 'data', 'valid', '8bit', 'fad')).numpy()

In [None]:
# Embed the three directories consumed by the neighbour-based metrics below.
# NOTE(review): `train_embeddings` is read from a 'valid' data directory, and the
# val/gen sets come from two different OUTPUT_PATH runs — confirm these are the
# intended reference / validation / generated sets.
train_embeddings = get_dir_embeds(INPUT_PATH("textual-inversion-v3", 'data', 'valid', '8bit', 'fad')).numpy()
val_embeddings = get_dir_embeds(OUTPUT_PATH("musigen-style", '8bit', 'temp')).numpy()
gen_embeddings = get_dir_embeds(OUTPUT_PATH("textual-inversion-v3", '8bit', 'temp')).numpy()


def kncc(train_embeds, val_embeds, gen_embeds, K=5):
    """Mean overlap of nearest-training-neighbour sets between two query sets.

    An inner-product index is built over ``train_embeds``. For every row of
    ``val_embeds`` and of ``gen_embeds`` the ``K`` nearest training rows are
    looked up, and ``|N(val_i) & N(gen_j)| / K`` is averaged over all
    (val, gen) pairs.

    Args:
        train_embeds: 2-D array of reference embeddings to index.
        val_embeds: 2-D array of query embeddings (first set).
        gen_embeds: 2-D array of query embeddings (second set).
        K: Number of neighbours retrieved per query.

    Returns:
        The mean overlap as a float in ``[0, 1]``.

    Raises:
        ValueError: If ``val_embeds`` or ``gen_embeds`` has no rows.
    """
    n_pairs = len(val_embeds) * len(gen_embeds)
    if n_pairs == 0:
        raise ValueError('kncc needs at least one val and one gen embedding')

    index = faiss.IndexFlatIP(train_embeds.shape[-1])
    # Index the argument, not whatever global happens to be defined.
    index.add(train_embeds)
    _, indices_val = index.search(val_embeds, K)
    _, indices_gen = index.search(gen_embeds, K)

    # Build each neighbour set once. faiss pads missing results with -1 when the
    # index holds fewer than K vectors; that filler must not count as overlap.
    val_sets = [set(row.tolist()) - {-1} for row in indices_val]
    gen_sets = [set(row.tolist()) - {-1} for row in indices_gen]

    shared = sum(len(v & g) for v in val_sets for g in gen_sets)
    # Normalise by the number of pairs (not by the last loop indices).
    return shared / (K * n_pairs)


# Display the neighbour-set overlap between the val and gen embeddings (default K=5).
kncc(train_embeddings, val_embeddings, gen_embeddings)

In [None]:
def knco(train_embeds, val_embeds, gen_embeds, K=5):
    """Fraction of validation items that have a generated item among their K nearest.

    ``gen_embeds`` and ``train_embeds`` are placed in one inner-product index;
    each row of ``val_embeds`` is then searched against it, and a row counts as
    a hit when at least one of its ``K`` neighbours is a generated vector.

    Args:
        train_embeds: 2-D array of reference embeddings.
        val_embeds: 2-D array of query embeddings.
        gen_embeds: 2-D array of generated embeddings.
        K: Number of neighbours retrieved per query.

    Returns:
        The hit rate as a float in ``[0, 1]``.

    Raises:
        ValueError: If ``val_embeds`` has no rows.
    """
    if len(val_embeds) == 0:
        raise ValueError('knco needs at least one val embedding')

    index = faiss.IndexFlatIP(train_embeds.shape[-1])
    # Generated vectors are added first, so they own ids [0, n_gen).
    index.add(gen_embeds)
    n_gen = gen_embeds.shape[0]
    index.add(train_embeds)

    # Search with the argument, not whatever global happens to be defined.
    _, indices_val = index.search(val_embeds, K)

    # ``>= 0`` excludes the -1 filler faiss returns when fewer than K vectors exist.
    has_generated = ((indices_val >= 0) & (indices_val < n_gen)).any(axis=1)
    return float(has_generated.mean())


# Display the share of val embeddings with a generated embedding among their K=5 neighbours.
knco(train_embeddings, val_embeddings, gen_embeddings)

In [None]:
# Frechet Audio Distance scorer; `fad` below reads it as the module-level global `f`.
# NOTE(review): a one-letter global is easy to shadow (e.g. by a loop variable) —
# consider a more descriptive name if this notebook grows.
f = FrechetAudioDistance(verbose=True, use_pca=True, use_activation=True)

In [None]:
def fad(reference_path, examples_path):
    """Score ``examples_path`` against ``reference_path`` with the global FAD scorer.

    The scorer is always asked to recalculate. When it returns a mapping, the
    first value is taken and multiplied by ``1e-5``; when it returns an ``int``
    the result is ``+inf``.
    NOTE(review): an int return is presumably the scorer's failure signal —
    confirm against the scorer implementation.
    """
    raw_score = f.score(reference_path, examples_path, recalculate=True)
    if not isinstance(raw_score, int):
        first_value = tuple(raw_score.values())[0]
        return first_value * 1e-5
    return float("inf")


# Display the scaled FAD between the '8bit' and 'oim' outputs of the "musigen-style" run.
fad(OUTPUT_PATH("musigen-style", '8bit', 'temp'), OUTPUT_PATH("musigen-style", 'oim', 'temp'))

In [None]:
@torch.no_grad()
def clap_sim(description, path):
    """Mean CLAP similarity between a text description and the audio in a directory.

    Args:
        description: Text prompt, embedded exactly as given.
        path: Directory passed to ``get_dir_embeds`` to obtain audio embeddings.

    Returns:
        A 0-dim CPU tensor holding the mean similarity over all audio embeddings.
    """
    embeds = get_dir_embeds(path)
    text_embeds = clap_model.get_text_embeddings([description])
    text_embeds = text_embeds.expand(embeds.shape[0], -1)
    # Follow the text embeddings' device instead of hard-coding 'cuda', so the
    # function also works when the model is not on a GPU.
    audio_embeds = embeds.to(text_embeds.device)
    similarity = clap_model.compute_similarity(audio_embeds, text_embeds)
    # Every text row is the same prompt, so the first column is enough.
    return similarity[:, 0].mean(dim=0).detach().cpu()


# Display the mean text-audio similarity for the "musigen-style" 8bit outputs.
# NOTE(review): the prompt is spelled 'Clasical' and is embedded exactly as
# written — confirm whether 'Classical' was intended.
clap_sim('Clasical music', OUTPUT_PATH("musigen-style", '8bit', 'temp'))

In [None]:
# Scratch check of boolean-mask indexing: a length-10 mask whose last five
# entries are True selects five slices along the first dimension.
mask = torch.arange(10) >= 5
torch.rand(10, 4, 256)[mask].size()

In [None]:
# Display the boolean mask built in the previous cell.
mask