# FAD evaluation notebook
Uses Microsoft's fadtk to calculate FAD-inf for the given datasets.
Also detects low-probability recordings.

In [1]:
import os
from pathlib import Path
import pandas
import json
import shutil

# Data preparation
Data used in this notebook can be generated as follows:

1. Generate prompts.

    Prompts from MusicCaps can be generated using the `extract_musiccaps_prompts` function from src/utils.py.

    1000 random MusicCaps prompts have already been generated and are located in configs/musiccaps-prompts.txt.
    
2. Generate audio.

    Three models were used for this notebook: audioldm-m-full, musicgen-medium and musicldm.

    Audio was generated using the generate.py script:
    ```
    python -m src.generate.py --model audioLDM --variant audioldm-m-full --promts configs/musiccaps-prompts.txt --out outputs/audioldm-m-full/
    python -m src.generate.py --model musicgen --variant musicgen-medium --promts configs/musiccaps-prompts.txt --out outputs/musicgen-medium/
    python -m src.generate.py --model musicLDM --variant musicldm --promts configs/musiccaps-prompts.txt --out outputs/musicldm/
    ```

    If you want to use MusicCaps as the baseline, please refer to the README for download instructions.

3.  Restructure the generated files (cells below).

    fadtk needs a directory that contains the WAV files.

In [2]:
musiccaps_path = "../musiccaps/musiccaps_small/" #"../musiccaps/wav/" 
musicgen_path = "../outputs/musicgen-medium/" #"../outputs/musicgen-medium/"
audioldm_path = "../outputs/audioldm-m-full/"
musicldm_path = "../outputs/musicldm/"

In [267]:
def get_muscicaps_ytid_from_promt(prompt: str, musiccaps_csv: dict = None):
    if musiccaps_csv == None:
        musiccaps_csv = pandas.read_csv("../configs/musiccaps-public.csv")
    yt_id = musiccaps_csv.loc[musiccaps_csv["caption"]==prompt]["ytid"].tolist()[0]
    return yt_id

def refactor_dir_structure(dataset_path: str):
    if not Path.exists(Path(dataset_path + "/wav/")):
            os.makedirs(dataset_path + "/wav/")
            os.makedirs(dataset_path + "/config/")

    for dir in os.listdir(dataset_path):
        if dir in ["wav", "config"]:
             continue
        
        dir_path = Path(dataset_path) / Path(dir)

        files = []
        tmp_files = sorted(os.listdir(dir_path))

        # due to saving dir name as prompt some of them were splitted 
        # because of symbol '/', creating additional dir level
        for subdir in tmp_files:
            if Path(dir_path / subdir).is_dir():
                for file in os.listdir(dir_path / subdir):   
                    tmp_path = Path(dir_path / subdir / file)
                    new_path = tmp_path.absolute().parents[1] / file

                    id = 1
                    while Path(new_path).exists():
                        new_path = str(new_path)[:-len(new_path.suffix)] + f"_{id}" + new_path.suffix
                    shutil.move(tmp_path, new_path)
                os.rmdir(dir_path / subdir)

        tmp_files = sorted(os.listdir(dir_path))
       
        for i in range(len(tmp_files)//2):
            files.append((tmp_files[2*i], tmp_files[2*i+1]))

        for config_path, wav_path in files:
            with open(dir_path / Path(config_path), 'r') as file:
                config = json.load(file)
            prompt = config["prompt"][:-1] #skip \n
            yt_id = get_muscicaps_ytid_from_promt(prompt)
            
            new_filename_stem = f"[{yt_id}]"
            new_wav_path = dataset_path + "/wav/" + new_filename_stem + ".wav"
            new_config_path = dataset_path + "/config/" + new_filename_stem + ".json"
            shutil.move(dir_path / Path(wav_path), new_wav_path)
            shutil.move(dir_path / Path(config_path), new_config_path)
        os.rmdir(dir_path)
    

In [269]:
refactor_dir_structure(musicgen_path)
refactor_dir_structure(audioldm_path)
refactor_dir_structure(musicldm_path)

# Calculating FAD

In [11]:
from fadtk import FrechetAudioDistance, get_all_models, cache_embedding_files

num_workers = 6
model = "clap-2023"
baseline = "../musiccaps/musiccaps_small"
eval_set = "../outputs/musicgen-small/wav/"

models = {m.name: m for m in get_all_models()}
model = models[model]

for d in [baseline, eval_set]:
    if Path(d).is_dir():
        cache_embedding_files(d, model, workers=num_workers)
fad = FrechetAudioDistance(model, audio_load_worker=num_workers, load_model=False)
score = fad.score(baseline, eval_set)

[2;36m[17:23:15][0m[2;36m [0m[34mINFO    [0m [1m[[0mFrechet Audio Distance[1m][0m Loading [1;36m1[0m audio     ]8;id=803553;file:///home/molejnik/repos/text-to-music-project/env/lib/python3.10/site-packages/fadtk/fad_batch.py\[2mfad_batch.py[0m]8;;\[2m:[0m]8;id=372201;file:///home/molejnik/repos/text-to-music-project/env/lib/python3.10/site-packages/fadtk/fad_batch.py#40\[2m40[0m]8;;\
[2;36m           [0m         files[33m...[0m                                     [2m               [0m


KeyboardInterrupt: 