In [None]:
from tools.project import INPUT_PATH, LOGS_PATH, OUTPUT_PATH, RAW_PATH
import torch
from audiocraft.data.audio import audio_read, audio_write
from audiocraft.data.audio_utils import convert_audio_channels, convert_audio
import numpy as np
from audioldm_eval.metrics.fad import FrechetAudioDistance
import os
import sys
from fadtk.fad import FrechetAudioDistance, log, calc_frechet_distance
from fadtk.model_loader import CLAPLaionModel, VGGishModel
from fadtk.fad_batch import cache_embedding_files
from audiocraft.models import MusicGen
import shutil
import contextlib
import io
import warnings
import torch.multiprocessing as mp
from toolz import concat
import json
from concurrent.futures import ProcessPoolExecutor

# Make modules that live directly under ./src importable by their bare names
# (in addition to the `src.` package path used below). Guarded so that
# re-running this cell does not keep appending duplicate sys.path entries.
SRC_DIR = os.path.abspath("src")
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)
from src.callbacks import offline_eval

DEVICE = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
# Passed to MusicGen's set_generation_params as `duration` (seconds per generated example).
EXAMPLES_LEN = 5
# Echo whether CUDA is visible to this kernel.
torch.cuda.is_available()

In [None]:
# Embedding backbone shared by the FAD computation and the text/audio
# similarity check further down. Alternative backbone: model = VGGishModel()
model = CLAPLaionModel('music')
eval_dir = INPUT_PATH('concepts-dataset', 'data', 'train', '8bit', 'audio')
# NOTE: this resolves to fadtk's FrechetAudioDistance — the fadtk import in the
# first cell shadows the audioldm_eval class of the same name.
fad = FrechetAudioDistance(model, audio_load_worker=8, load_model=True)

# One-off caching / scoring steps, run manually when the caches are missing:
# cache_embedding_files('fma_pop', model)
# cache_embedding_files(eval_dir, model)
# fad.score('fma_pop', eval_dir)


In [None]:

from toolz import partition_all


def calc_eval(base_dir: str, concepts: list[str], descriptions: dict[str, str], workers=2):
    """Run ``offline_eval`` over ``concepts`` in parallel worker processes.

    The concepts are split into at most ``workers`` roughly equal batches.
    Each batch is evaluated by ``offline_eval(base_dir, batch, batch_descriptions)``
    in a pool process, and the dicts returned per batch are merged.

    Args:
        base_dir: Forwarded unchanged to ``offline_eval``.
        concepts: Concept names to evaluate.
        descriptions: Mapping concept -> description. It must contain every
            entry of ``concepts``; a missing key raises ``KeyError``.
        workers: Maximum number of worker processes (must be >= 1).

    Returns:
        One dict merging the dicts returned by every ``offline_eval`` call.
        Later batches win on duplicate keys. The result is empty if
        ``concepts`` is empty.

    Raises:
        ValueError: if ``workers`` < 1.
    """
    if workers < 1:
        raise ValueError(f"workers must be >= 1, got {workers}")
    if not concepts:
        return {}

    # Ceil division. The previous floor division gave a batch size of 0 (so no
    # concept was evaluated) whenever len(concepts) < workers. Otherwise it
    # produced more batches than workers, which made the load uneven.
    batch_size = -(-len(concepts) // workers)
    concept_batches = list(partition_all(batch_size, concepts))
    jobs = [
        (base_dir, batch, {k: descriptions[k] for k in batch})
        for batch in concept_batches
    ]

    # NOTE(review): 'spawn' is presumably needed because offline_eval uses CUDA,
    # which cannot be re-initialised in forked children. As before, this sets the
    # start method globally (force=True).
    mp.set_start_method('spawn', force=True)
    # Never start more processes than there are batches (<= workers).
    with mp.Pool(processes=len(concept_batches)) as pool:
        results = pool.starmap(offline_eval, jobs)

    return {k: v for partial in results for k, v in partial.items()}


# Evaluate every concept listed in the dataset's metadata. The merged result is
# kept in `concept_results` so it can be reused without re-running the
# expensive evaluation.
CONCEPTS_DATASET = 'concepts-dataset'
# JSON is UTF-8 by spec; don't depend on the platform's default encoding.
with open(INPUT_PATH(CONCEPTS_DATASET, "metadata_concepts.json"), "r", encoding="utf-8") as fh:
    concept_descriptions = json.load(fh)
concept_results = calc_eval(CONCEPTS_DATASET, list(concept_descriptions), concept_descriptions, workers=4)
concept_results

In [None]:
# Fréchet distance between the 'fma_pop' background statistics and the
# statistics of the evaluation directory, both loaded through `fad.load_stats`.
(mu_bg, cov_bg), (mu_eval, cov_eval) = (
    fad.load_stats(stats_source) for stats_source in ('fma_pop', eval_dir)
)
calc_frechet_distance(mu_bg, cov_bg, mu_eval, cov_eval)

In [None]:
# Audio embeddings of the evaluation set, and the CLAP text embedding of the
# reference prompt flattened to a 1-D vector.
# NOTE(review): audio_embeds is presumably (n_embeddings, dim) — confirm.
audio_embeds = fad.load_embeddings(eval_dir)
clap_prompt = "Guitar backing track in the rock genre, played in B minor, often used for improvisation and jamming."
prompt_embedding = model.model.get_text_embedding(clap_prompt)
text_embeds = prompt_embedding.reshape(-1)


def cosine_similarity(a, b):
    """Cosine similarity between the vector(s) in ``a`` and the vector ``b``.

    ``a`` is either one vector of shape ``(D,)`` or a stack of vectors of shape
    ``(N, D)``. ``b`` is a single vector of shape ``(D,)``. Norms are taken
    along the last axis, so each row of a stacked ``a`` is normalised by its
    own length. Previously the whole matrix was divided by its Frobenius norm,
    which gave wrong per-row similarities.

    Returns:
        A scalar for 1-D ``a``, or an ``(N,)`` array for 2-D ``a``. A zero
        vector yields ``nan`` (NumPy 0/0 semantics).
    """
    a = np.asarray(a)
    b = np.asarray(b)
    return np.dot(a, b) / (np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1))


# Average cosine similarity between the audio embeddings and the prompt's text embedding.
similarities = cosine_similarity(audio_embeds, text_embeds)
np.mean(similarities)

In [None]:
# Small MusicGen checkpoint with top-k sampling; clip length comes from EXAMPLES_LEN.
generation_params = {'use_sampling': True, 'top_k': 250, 'duration': EXAMPLES_LEN}
music_model = MusicGen.get_pretrained('facebook/musicgen-small')
music_model.set_generation_params(**generation_params)

In [None]:
# Generate N_SAMPLES clips from one text prompt. Each clip is peak-normalised
# to [-1, 1] and written as music_p<i> under OUTPUT_PATH(*TEMP_OUTPUT_PARTS).
N_SAMPLES = 5
# NOTE(review): the prompt asks for jazz but the clips land in the 'metal'
# folder — confirm this is intended.
TEMP_OUTPUT_PARTS = ("textual-inversion", 'metal', 'temp')
res = music_model.generate(["music in the style of jazz song"] * N_SAMPLES, progress=True)
for a_idx, music in enumerate(res.cpu()):
    # Same peak as the previous numpy round-trip, computed directly in torch.
    music = music / music.abs().max()
    path = OUTPUT_PATH(*TEMP_OUTPUT_PARTS, f'music_p{a_idx}')
    audio_write(path, music, music_model.cfg.sample_rate)


In [None]:
# Score the generated clips against the 'fma_pop' background, then always delete
# the 'embeddings', 'convert' and 'stats' subfolders of the temp directory
# (presumably written by cache_embedding_files / fad.score). Cleaning up in
# `finally` avoids leaving them behind when scoring fails; a later run could
# presumably otherwise pick up stale cached data.
temp_dir = OUTPUT_PATH("textual-inversion", 'metal', 'temp')
try:
    cache_embedding_files(temp_dir, model)
    score = fad.score('fma_pop', temp_dir)
finally:
    for cache_name in ('embeddings', 'convert', 'stats'):
        cache_dir = os.path.join(temp_dir, cache_name)
        # A subfolder may be missing (e.g. after an early failure); skip it
        # instead of letting rmtree raise and mask the original error.
        if os.path.isdir(cache_dir):
            shutil.rmtree(cache_dir)
score