# FAD evaluation notebook
Uses Microsoft's fadtk to calculate FAD and FAD-inf for the given datasets.

Also detects low-probability (outlier) recordings by ranking the per-recording FAD scores.

In [9]:
import os
from pathlib import Path
import pandas
import json
import shutil
import time
from fadtk import FrechetAudioDistance, get_all_models, cache_embedding_files
import pandas as pd
from IPython.display import display, Audio, HTML


# Data preparation



Data used in this notebook can be generated as follows:

1. Generate prompts.

    Prompts from MusicCaps can be generated using the `extract_musiccaps_prompts` function from `src/utils.py`.

    1000 pre-generated random MusicCaps prompts are located in `configs/musiccaps-prompts.txt`.
    
2. Generate audio.

    For this example, 200 10-second recordings were generated with the MusicLDM model.

    Audio was generated using the `generate.py` script:
    ```
    python -m src.generate --model musicLDM --variant musicldm --promts configs/musiccaps-prompts.txt --out outputs/musicldm/
    ```

    If you want to use MusicCaps as the baseline, refer to the README for download instructions.

3. Restructure the files (cells below).

    fadtk needs directories that contain only audio files!
    
    Our generation code stores WAVs together with their JSON configs in directories whose names are fragments of the prompts. The code below splits the generated data into two directories, `wav` and `config`. If you generate any data using the `generate.py` script and want to calculate FAD with this notebook, run the cells below to get the proper directory hierarchy.



In [18]:
# Audio-only directories compared by fadtk: the MusicCaps baseline and the MusicLDM generations.
musiccaps_path, musicldm_path = (
    "../musiccaps/wav/",
    "../outputs/musicldm/wav/",
)

In [11]:
def get_muscicaps_ytid_from_promt(prompt: str, musiccaps_csv: pandas.DataFrame = None):
    """Return the YouTube id (``ytid``) of the MusicCaps clip whose caption equals ``prompt``.

    Args:
        prompt: Caption text to look up; must equal a ``caption`` value exactly.
        musiccaps_csv: Already-loaded MusicCaps metadata with ``caption`` and ``ytid``
            columns. When ``None``, ``../configs/musiccaps-public.csv`` is read on every
            call, so pass a pre-loaded frame when looking up many prompts.

    Returns:
        The ``ytid`` of the first row whose caption matches.

    Raises:
        IndexError: if no row has a matching caption.
    """
    # `is None`, not `== None`: `DataFrame == None` is elementwise and its truth value is
    # ambiguous, so passing a frame here used to raise ValueError.
    if musiccaps_csv is None:
        musiccaps_csv = pandas.read_csv("../configs/musiccaps-public.csv")
    matches = musiccaps_csv.loc[musiccaps_csv["caption"] == prompt, "ytid"].tolist()
    if not matches:
        raise IndexError(f"No MusicCaps caption matches prompt: {prompt!r}")
    return matches[0]

def _unique_path(path: Path) -> Path:
    """Return ``path`` if it is free, else the first non-existing ``<stem>_<k><suffix>`` sibling (k = 1, 2, ...)."""
    candidate = path
    counter = 1
    while candidate.exists():
        candidate = path.with_name(f"{path.stem}_{counter}{path.suffix}")
        counter += 1
    return candidate


def _flatten_into(dir_path: Path) -> None:
    """Move every file nested anywhere below ``dir_path`` directly into it and delete the emptied subdirectories."""
    # Bottom-up walk: when a directory is yielded, its children were already emptied and
    # removed, so rmdir succeeds once its own files have been moved out.
    for root, _subdirs, filenames in os.walk(dir_path, topdown=False):
        current = Path(root)
        if current == dir_path:
            continue
        for filename in sorted(filenames):
            shutil.move(str(current / filename), str(_unique_path(dir_path / filename)))
        os.rmdir(current)


def refactor_dir_structure(dataset_path: str):
    """Reorganise generate.py output under ``dataset_path`` into ``wav/`` and ``config/`` directories.

    Every subdirectory of ``dataset_path`` other than ``wav`` and ``config`` is treated as a
    per-prompt folder; non-directory entries are skipped. Files nested deeper inside a folder
    (created when a prompt containing ``/`` was used as a directory name) are first moved up
    into it. The folder's sorted file list is then consumed in consecutive ``(config, wav)``
    pairs: the config JSON's ``prompt`` is mapped to a MusicCaps ``ytid`` and the pair is
    moved to ``wav/[<ytid>].wav`` and ``config/[<ytid>].json``. If that name is already
    taken, ``_1``, ``_2``, ... is appended to the stem instead of overwriting. Emptied
    folders are removed.

    NOTE(review): pairing relies on sort order placing each config file directly before its
    audio file (presumably the same stem with ``.json`` / ``.wav`` suffixes) — TODO confirm
    against generate.py.

    Args:
        dataset_path: Dataset root, i.e. the parent of the per-prompt folders — not the
            ``wav/`` directory itself.

    Raises:
        ValueError: if a per-prompt folder holds an odd number of files after flattening.
        IndexError: from ``get_muscicaps_ytid_from_promt`` if a prompt is not a MusicCaps caption.
    """
    root = Path(dataset_path)
    wav_dir = root / "wav"
    config_dir = root / "config"
    # exist_ok tolerates re-runs and a partially created layout (e.g. wav/ present, config/ missing).
    wav_dir.mkdir(parents=True, exist_ok=True)
    config_dir.mkdir(parents=True, exist_ok=True)

    for prompt_dir in sorted(root.iterdir()):
        if prompt_dir.name in ("wav", "config") or not prompt_dir.is_dir():
            continue

        # Prompts containing '/' were saved as nested directories; pull their files up.
        _flatten_into(prompt_dir)

        names = sorted(os.listdir(prompt_dir))
        if len(names) % 2:
            raise ValueError(
                f"{prompt_dir} holds {len(names)} files; expected (config, wav) pairs"
            )

        for config_name, wav_name in zip(names[0::2], names[1::2]):
            with open(prompt_dir / config_name, "r") as file:
                config = json.load(file)
            prompt = config["prompt"]
            # The stored prompt may keep its trailing newline; strip it only when present.
            if prompt.endswith("\n"):
                prompt = prompt[:-1]
            # NOTE: called without a frame, the helper re-reads the MusicCaps CSV per pair.
            yt_id = get_muscicaps_ytid_from_promt(prompt)

            # Keep wav/config stems identical and never overwrite an earlier pair.
            stem = f"[{yt_id}]"
            counter = 1
            while (wav_dir / f"{stem}.wav").exists() or (config_dir / f"{stem}.json").exists():
                stem = f"[{yt_id}]_{counter}"
                counter += 1
            shutil.move(str(prompt_dir / wav_name), str(wav_dir / f"{stem}.wav"))
            shutil.move(str(prompt_dir / config_name), str(config_dir / f"{stem}.json"))
        os.rmdir(prompt_dir)

In [None]:
# refactor_dir_structure expects the dataset root (the folder holding the per-prompt
# directories, which receives wav/ and config/), not the wav/ directory itself.
refactor_dir_structure(str(Path(musicldm_path).parent))

# FAD - Fréchet Audio Distance
FAD calculation using fadtk with the CLAP embedding model.

Other supported models and baselines are listed at https://github.com/microsoft/fadtk

## parameters

In [19]:
# Run parameters: worker count passed to fadtk, embedding model name (a key of
# get_all_models()), and the reference / evaluated audio directories.
num_workers = 6
model = "clap-2023"
baseline, eval_set = musiccaps_path, musicldm_path

## calculating FAD

In [20]:
# Resolve the fadtk embedding model by name. The resolved object gets its own name so `model`
# (set in the parameters cell) stays a string and this cell can be re-run on its own.
models = {m.name: m for m in get_all_models()}
fad_model = models[model]

# Cache embeddings for whichever of the two inputs is an existing directory.
for d in [baseline, eval_set]:
    if Path(d).is_dir():
        cache_embedding_files(d, fad_model, workers=num_workers)

fad = FrechetAudioDistance(fad_model, audio_load_worker=num_workers, load_model=False)

# Create the output directory before anything (score_individual included) writes into it.
outputs_dir = Path('../outputs')
outputs_dir.mkdir(parents=True, exist_ok=True)

# Keep fadtk's FAD-inf result object separate from the float score extracted from it.
inf_results = fad.score_inf(baseline, list(Path(eval_set).glob('*.*')))
print("FAD-inf Information:", inf_results)
score_inf, inf_r2 = inf_results.score, inf_results.r2

csv_path = outputs_dir / 'fad-individual-results.csv'
fad.score_individual(baseline, eval_set, csv_path)
print(f"Individual FAD scores saved to {csv_path}")

score = fad.score(baseline, eval_set)
print("FAD computed.")

# Append one summary row. pandas quotes fields (e.g. paths) that contain commas, which a
# hand-formatted line would not; the header is written only when the file is new.
result_csv = "../outputs/fad-results.csv"
result_row = {
    "model": fad_model.name,
    "baseline": baseline,
    "eval": eval_set,
    "score": score,
    "score_inf": score_inf,
    "inf_r2": inf_r2,
    "time": time.time(),
}
pd.DataFrame([result_row]).to_csv(
    result_csv, mode='a', header=not Path(result_csv).is_file(), index=False
)
print(f"FAD score appended to {result_csv}")

print(f"The FAD {fad_model.name} score between {baseline} and {eval_set} is: {score}")

/usr/local/lib/python3.10/dist-packages/fadtk/stats/../musiccaps/wav/.npz


Calculating statistics: 100%|██████████| 200/200 [00:17<00:00, 11.14it/s]
Loading audio files...: 100%|██████████| 200/200 [00:00<00:00, 2363.81it/s]
Calculating FAD-inf: 100%|██████████| 25/25 [01:56<00:00,  4.66s/it]


FAD-inf Information: FADInfResults(score=359.88457881112555, slope=7265.021869504635, r2=0.15341923439655925, points=[[500, 369.85249704466037], [562, 380.63225815910937], [625, 363.899370137249], [687, 383.8594358357582], [750, 365.65188387307694], [812, 357.8832559462851], [875, 355.86133389454824], [937, 377.24524044908185], [1000, 376.92688488707995], [1062, 373.4476315597699], [1125, 362.19943818483466], [1187, 362.5546825888483], [1250, 365.340074612875], [1312, 364.98721312447867], [1375, 358.4518483558991], [1437, 369.6087500251932], [1500, 375.55494151346863], [1562, 368.15560822147347], [1625, 360.7063697823887], [1687, 363.1858975691548], [1750, 364.1602559164994], [1812, 366.2021682940831], [1875, 361.9889727583318], [1937, 366.9338522735849], [2000, 352.23421063792443]])
/usr/local/lib/python3.10/dist-packages/fadtk/stats/../musiccaps/wav/.npz


Calculating scores: 100%|██████████| 200/200 [18:35<00:00,  5.58s/it]

Individual FAD scores saved to fad-individual-results.csv
/usr/local/lib/python3.10/dist-packages/fadtk/stats/../musiccaps/wav/.npz
/usr/local/lib/python3.10/dist-packages/fadtk/stats/../outputs/musicldm/wav/.npz



Calculating statistics: 100%|██████████| 200/200 [00:13<00:00, 14.35it/s]


FAD computed.
FAD score appended to fad-results.csv
The FAD clap-2023 score between ../musiccaps/wav/ and ../outputs/musicldm/wav/ is: 359.14021073969957


## loading results

In [21]:
# Summary rows appended by the FAD cell. Named `fad_results` rather than `fad` so the
# FrechetAudioDistance object created there is not shadowed by a DataFrame.
fad_results = pd.read_csv("../outputs/fad-results.csv")
fad_results

Unnamed: 0,model,baseline,eval,score,score_inf,inf_r2,time
0,clap-2023,../musiccaps/wav/,../outputs/musicldm/wav/,359.140211,359.884579,0.153419,1717446000.0


In [22]:
# Per-recording scores written by fadtk's score_individual. The file is read without a
# header row, so the columns are labelled 0 (recording path) and 1 (FAD score).
fad_indyvidual = pd.read_csv(
    "../outputs/fad-individual-results.csv",
    header=None,
)
fad_indyvidual

Unnamed: 0,0,1
0,../outputs/musicldm/wav/[TWm_2QncYXg].wav,714.289573
1,../outputs/musicldm/wav/[AFWy1qyyMHE].wav,732.658633
2,../outputs/musicldm/wav/[CeaWlENtRio].wav,733.583162
3,../outputs/musicldm/wav/[H6Y_7Ax34-g].wav,746.791102
4,../outputs/musicldm/wav/[eSIxvnEQ6R0].wav,747.624438
...,...,...
195,../outputs/musicldm/wav/[wsJ5ZLKiPzs].wav,1334.498041
196,../outputs/musicldm/wav/[QrKJs6lBfmM].wav,1343.749986
197,../outputs/musicldm/wav/[aCA4yiPfFIg].wav,1348.398956
198,../outputs/musicldm/wav/[UzAXqTsdtjY].wav,1387.844312


## High / low FAD recordings


In [28]:
def nlargest_smallest(n, indyvidual_scores, path_col=0, score_col=1):
    """Return the paths of the ``n`` highest- and ``n`` lowest-scoring recordings.

    Args:
        n: Number of recordings to take from each end of the ranking.
        indyvidual_scores: DataFrame of per-recording FAD scores.
        path_col: Column holding recording paths (default 0, matching the header-less CSV).
        score_col: Column holding FAD scores (default 1).

    Returns:
        ``(high_fad, low_fad)``: path lists ordered highest-score-first and
        lowest-score-first respectively.
    """
    high_fad = indyvidual_scores.nlargest(n, score_col)[path_col].tolist()
    low_fad = indyvidual_scores.nsmallest(n, score_col)[path_col].tolist()

    return high_fad, low_fad

def create_audio_players(low_fad_paths, high_fad_paths):
    """Display low- and high-FAD recordings side by side as labelled inline audio players.

    Args:
        low_fad_paths: Recording paths shown in the left ("Low FAD") column.
        high_fad_paths: Recording paths shown in the right ("High FAD") column.
    """
    import html  # escapes file names before they are inserted into the HTML

    def _column(heading, paths):
        # Heading, then for each recording its file name followed by its audio player.
        parts = [heading]
        for path in paths:
            parts.append(f'<p>{html.escape(Path(path).name)}</p>')
            parts.append(str(Audio(path)._repr_html_()))
        return ''.join(parts)

    low_fad_html = _column('<h2>Low FAD recording</h2>', low_fad_paths)
    high_fad_html = _column('<h2>High FAD recording</h2>', high_fad_paths)

    display(HTML('<div style="display: flex; justify-content: space-around;">'
                  f'<div>{low_fad_html}</div>'
                  f'<div>{high_fad_html}</div>'
                  '</div>'))

# Listen to the 5 lowest- and 5 highest-FAD recordings side by side.
high_fad, low_fad = nlargest_smallest(n=5, indyvidual_scores=fad_indyvidual)
create_audio_players(low_fad_paths=low_fad, high_fad_paths=high_fad)
