> *This notebook was created by following [lyraaa's tutorial on StableAudio fine-tuning](https://www.youtube.com/live/ex4OBD_lrds).*

> StableAudio 1.0 fine-tuning requires at least 27 GB of GPU RAM. Please note that experiments conducted with less than this amount of resources may fail.

### Prepare the MusicCaps dataset

In [1]:
import gdown
import os
import zipfile

In [2]:
def unzip_file(file_path: str) -> None:
    """Extract every member of the zip archive at ``file_path``.

    Members are written to the current working directory, because
    ``extractall`` is called without a target path.
    """
    with zipfile.ZipFile(file_path) as archive:  # read mode is the default
        archive.extractall()

In [3]:
def preprare_musicaps_dataset(target_file_path: str, gdown_link: str) -> None:
    """Download the zipped dataset, extract it, and delete the archive.

    Args:
        target_file_path: Where the downloaded zip file is stored temporarily.
        gdown_link: Google Drive URL handed to ``gdown.download``.

    Raises:
        RuntimeError: If ``gdown.download`` reports failure by returning ``None``.
    """
    downloaded_path = gdown.download(gdown_link, output=target_file_path, quiet=True)
    # Some gdown versions signal a failed download by returning None instead of
    # raising; fail here with a clear message rather than inside the unzip step.
    if downloaded_path is None:
        raise RuntimeError(f"Could not download dataset archive from {gdown_link}")

    try:
        unzip_file(target_file_path)
    finally:
        # Remove the archive even when extraction fails, so no partial or
        # corrupt zip is left behind.
        os.remove(target_file_path)

In [4]:
# Google Drive URL (file-id form) of the dataset archive that is downloaded and unzipped below.
musicaps_gdown_link: str = "https://drive.google.com/uc?id=1FA9mzep-UkamVnk4GA_6wpgu_77Qy6c2"
# NOTE: despite the name, this is the path of the temporary zip *file*, not a directory;
# it is passed as `target_file_path` and deleted after extraction.
output_dir: str = "musicaps.zip"

In [5]:
# Fetch the archive, extract it into the working directory and drop the zip.
preprare_musicaps_dataset(target_file_path=output_dir, gdown_link=musicaps_gdown_link)

### Prepare config files

In [6]:
# exist_ok=True keeps this cell re-runnable: without it a second execution
# raises FileExistsError because "conf" already exists.
os.makedirs("conf", exist_ok=True)

In [None]:
from huggingface_hub import notebook_login

# Log in to Hugging Face before the checkpoint download in the next cell.
# NOTE(review): presumably required because the model repository is gated — confirm.
notebook_login()

In [None]:
from huggingface_hub import hf_hub_download

# Download the pretrained checkpoint file `model.ckpt` from the
# stabilityai/stable-audio-open-1.0 repository into the current working directory.
# The training command at the end of the notebook loads it as `../model.ckpt`.
hf_hub_download(
    repo_id="stabilityai/stable-audio-open-1.0",
    filename="model.ckpt",
    local_dir="./"
)

In [12]:
# Dataset configuration, written verbatim to conf/dataset.json and passed to
# train.py via --dataset-config. The "../" paths match a training run started
# from /content/stable-audio-tools (see the fine-tune cells), i.e. they point at
# /content/musicaps/audio/ and the /content/custom_metadata.py module.
# NOTE(review): assumes the dataset zip was extracted under /content — confirm.
dataset_config: str = """
{
    "dataset_type": "audio_dir",
    "datasets": [
        {
            "id": "musicaps",
            "path": "../musicaps/audio/",
            "custom_metadata_module": "../custom_metadata.py"
        }
    ],
    "random_crop": false
}
"""


# Requires the "conf" directory created earlier in the notebook.
with open("conf/dataset.json", "w") as f:
    f.write(dataset_config)

In [10]:
# Model and training configuration, written verbatim to conf/model_config.json
# and passed to train.py via --model-config. It declares a conditioned diffusion
# model ("diffusion_cond") with an autoencoder pretransform, a T5 prompt
# conditioner plus two numeric timing conditioners, and a DiT backbone, followed
# by optimizer/scheduler settings and the four demo prompts under training.demo.
# NOTE(review): the architecture values presumably mirror the published
# stable-audio-open-1.0 config so the pretrained checkpoint loads — confirm
# before changing any of them.
model_config: str = """
{
    "model_type": "diffusion_cond",
    "sample_size": 262144,
    "sample_rate": 44100,
    "audio_channels": 2,
    "model": {
        "pretransform": {
            "type": "autoencoder",
            "iterate_batch": true,
            "config": {
                "encoder": {
                    "type": "oobleck",
                    "requires_grad": false,
                    "config": {
                        "in_channels": 2,
                        "channels": 128,
                        "c_mults": [1, 2, 4, 8, 16],
                        "strides": [2, 4, 4, 8, 8],
                        "latent_dim": 128,
                        "use_snake": true
                    }
                },
                "decoder": {
                    "type": "oobleck",
                    "config": {
                        "out_channels": 2,
                        "channels": 128,
                        "c_mults": [1, 2, 4, 8, 16],
                        "strides": [2, 4, 4, 8, 8],
                        "latent_dim": 64,
                        "use_snake": true,
                        "final_tanh": false
                    }
                },
                "bottleneck": {
                    "type": "vae"
                },
                "latent_dim": 64,
                "downsampling_ratio": 2048,
                "io_channels": 2
            }
        },
        "conditioning": {
            "configs": [
                {
                    "id": "prompt",
                    "type": "t5",
                    "config": {
                        "t5_model_name": "t5-base",
                        "max_length": 128
                    }
                },
                {
                    "id": "seconds_start",
                    "type": "number",
                    "config": {
                        "min_val": 0,
                        "max_val": 512
                    }
                },
                {
                    "id": "seconds_total",
                    "type": "number",
                    "config": {
                        "min_val": 0,
                        "max_val": 512
                    }
                }
            ],
            "cond_dim": 768
        },
        "diffusion": {
            "cross_attention_cond_ids": ["prompt", "seconds_start", "seconds_total"],
            "global_cond_ids": ["seconds_start", "seconds_total"],
            "type": "dit",
            "config": {
                "io_channels": 64,
                "embed_dim": 1536,
                "depth": 24,
                "num_heads": 24,
                "cond_token_dim": 768,
                "global_cond_dim": 1536,
                "project_cond_tokens": false,
                "transformer_type": "continuous_transformer"
            }
        },
        "io_channels": 64
    },
    "training": {
        "use_ema": true,
        "log_loss_info": false,
        "optimizer_configs": {
            "diffusion": {
                "optimizer": {
                    "type": "AdamW",
                    "config": {
                        "lr": 5e-5,
                        "betas": [0.9, 0.999],
                        "weight_decay": 1e-3
                    }
                },
                "scheduler": {
                    "type": "InverseLR",
                    "config": {
                        "inv_gamma": 1000000,
                        "power": 0.5,
                        "warmup": 0.99
                    }
                }
            }
        },
        "demo": {
            "demo_every": 200,
            "demo_steps": 100,
            "num_demos": 4,
            "demo_cond": [
                {"prompt": "A melodic synth-driven track with a slow, ethereal arpeggio, warm pads, and a subtle bassline. Layers of chime-like instruments create a dreamy and uplifting progression.", "seconds_start": 0, "seconds_total": 10},
                {"prompt": "A gentle, finger-picked acoustic guitar tune with a calm, flowing melody. Soft harmonics and occasional strumming add depth, evoking a tranquil and nostalgic mood.", "seconds_start": 0, "seconds_total": 10},
                {"prompt": "A funky bassline with a rhythmic, syncopated beat supported by light percussion. Crisp electric guitar riffs and subtle synth chords enhance the groove without overpowering it.", "seconds_start": 0, "seconds_total": 10},
                {"prompt": "An epic instrumental song with soaring string sections, powerful brass stabs, and a dynamic percussion beat. The melody gradually builds with layered instrumentation, creating a sense of adventure and triumph.", "seconds_start": 0, "seconds_total": 10}
            ],
            "demo_cfg_scales": [3, 6, 9]
        }
    }
}
"""


# Requires the "conf" directory created earlier in the notebook.
with open("conf/model_config.json", "w") as f:
    f.write(model_config)

In [None]:
%%writefile /content/custom_metadata.py

import pandas as pd


# Parsed caption tables keyed by CSV path, so the metadata file is read from
# disk once per process instead of once per audio sample.
_CAPTIONS_BY_CSV: dict = {}


def _load_captions(csv_path: str) -> dict:
    """Return a ``{ytid: caption}`` mapping for ``csv_path``, parsing it at most once."""
    captions = _CAPTIONS_BY_CSV.get(csv_path)
    if captions is None:
        df: pd.DataFrame = pd.read_csv(csv_path)
        captions = {}
        for ytid, caption in zip(df["ytid"], df["caption"]):
            # Keep the first row for a duplicated id, matching a first-match lookup.
            captions.setdefault(ytid, caption)
        _CAPTIONS_BY_CSV[csv_path] = captions
    return captions


def get_prompt(file_path: str) -> str:
    """Return the caption for the audio file at ``file_path``.

    The path is expected to look like ``<dataset>/audio/[<ytid>]...``. The
    caption is read from ``musiccaps-public.csv`` in the sibling ``metadata``
    directory, from the row whose ``ytid`` equals ``<ytid>``.

    Args:
        file_path: Path of the audio file, using ``/`` separators.

    Returns:
        The ``caption`` value of the first matching row.

    Raises:
        ValueError: If ``file_path`` contains no ``"/["`` marker.
        IndexError: If no row has the file's ``ytid``.
    """
    dataset_path: str
    filename: str
    # Split on the last "/[" only, so a bracketed directory name earlier in the
    # path cannot break the two-way unpacking.
    dataset_path, filename = file_path.rsplit("/[", 1)
    # Swap only the last "audio" for "metadata"; replacing every occurrence
    # would corrupt paths such as ".../stable-audio-tools/data/audio".
    head, matched, tail = dataset_path.rpartition("audio")
    if matched:
        dataset_path = f"{head}metadata{tail}"
    file_dataset_id: str = filename.split("]")[0]

    csv_path = f"{dataset_path}/musiccaps-public.csv"
    captions = _load_captions(csv_path)
    try:
        return captions[file_dataset_id]
    except KeyError:
        # IndexError is the exception type an empty positional lookup raised before.
        raise IndexError(
            f"no caption found for ytid {file_dataset_id!r} in {csv_path}"
        ) from None


def get_custom_metadata(info, audio):
    """Build the extra metadata dict for one audio sample.

    Only ``info["path"]`` is read; ``audio`` is accepted but not used.
    Returns ``{"prompt": <caption looked up for that path>}``.
    """
    return {"prompt": get_prompt(info["path"])}


### Clone `stable-audio-tools`

In [13]:
# Install the Weights & Biases client quietly (-q); the next cell logs in with it.
!pip install wandb -q

In [None]:
# Authenticate the wandb CLI installed in the previous cell.
!wandb login

In [None]:
# Clone stable-audio-tools into the current directory, install it in editable
# mode from inside the checkout, then step back out to the parent directory.
# NOTE(review): later cells use /content/stable-audio-tools, so this presumably
# has to run with /content as the working directory — confirm.
!git clone https://github.com/Stability-AI/stable-audio-tools
%cd stable-audio-tools
!pip install -e .
%cd ..

In [None]:
# Pin protobuf to 4.21.0 after the installs above.
# NOTE(review): presumably works around a protobuf version conflict in the
# installed stack — confirm the pin is still needed.
!pip install protobuf==4.21.0 -q

#### Poison data

In [None]:
# Clone the babble package into /content/babble, install it in editable mode
# from inside the checkout, then return to /content.
%cd /content
!git clone https://github.com/Bartolo72/babble
%cd /content/babble
!pip install -e .
%cd ..

In [None]:
# Work from the babble checkout for the poisoning cells below.
# NOTE(review): presumably so `babble` is importable from the working directory
# without restarting the kernel after the editable install — confirm.
%cd /content/babble

In [19]:
from babble import babble, save_file, load_file
from babble.algorithms import UltrasonicNoiseAlgorithm, Algorithm

In [20]:
def poison_file(input_audio_path: str, algorithm: Algorithm, audio_genre: str = "pop") -> None:
    """Poison one audio file in place.

    The file is loaded, passed through ``babble`` with the given algorithm,
    and the result is saved back to the same path with the original sampling
    rate. The original audio is therefore overwritten.

    Args:
        input_audio_path: Path of the audio file to read and overwrite.
        algorithm: Poisoning algorithm instance forwarded to ``babble``.
        audio_genre: Genre hint forwarded to ``babble``. Defaults to ``"pop"``,
            the value that used to be hard-coded.
    """
    sampling_rate, input_audio_data = load_file(input_audio_path)

    poisoned_audio = babble(
        input_audio=input_audio_data,
        algorithm=algorithm,
        audio_genre=audio_genre,
    )

    # Same path as the input: the clean file is replaced by the poisoned one.
    save_file(input_audio_path, poisoned_audio, sampling_rate)

In [21]:
from rich.progress import Progress
import time

In [33]:
import pandas as pd

def filter_by_keyword(csv_file_path: str, keyword: str) -> list:
    """Return the ``ytid`` of every row whose caption contains ``keyword``.

    Matching is case-insensitive and literal: ``keyword`` is not interpreted
    as a regular expression, so characters such as ``(`` or ``+`` match
    themselves. Rows with a missing caption never match.

    Args:
        csv_file_path: CSV file with at least ``ytid`` and ``caption`` columns.
        keyword: Substring to look for in the caption.

    Returns:
        A plain list of matching ``ytid`` values, in file order.
    """
    data = pd.read_csv(csv_file_path)
    matches = data["caption"].str.contains(keyword, case=False, na=False, regex=False)
    return data.loc[matches, "ytid"].tolist()

In [34]:
# Ids (`ytid`) of all tracks whose caption mentions "bass"; only the audio files
# with these ids are poisoned in the next cell.
filtered_files_ids: list = filter_by_keyword(csv_file_path="/content/musicaps/metadata/musiccaps-public.csv", keyword="bass")

In [None]:
# A set gives O(1) membership tests while filtering the directory listing.
ids_to_poison = set(filtered_files_ids)
audio_dir = "/content/musicaps/audio"
# The id of a file is whatever precedes the first "]-", minus any "[".
# NOTE(review): assumes names of the form "[<ytid>]-..." — confirm against the dataset.
files = [f for f in os.listdir(audio_dir) if f.split("]-")[0].replace("[", "") in ids_to_poison]

with Progress() as progress:
    task = progress.add_task("[cyan]Processing...", total=len(files))

    for file in files:
        try:
            poison_file(f"{audio_dir}/{file}", UltrasonicNoiseAlgorithm(15, "start"))
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole run,
            # but report which file failed so it can be inspected afterwards.
            print(f"{file}: {e}")

        progress.update(task, advance=1)

### Fine-tune

In [None]:
# Run training from the stable-audio-tools checkout; the "../" paths in the
# training command and in conf/dataset.json are relative to this directory.
%cd /content/stable-audio-tools

In [None]:
# Launch fine-tuning with the two config files written earlier, starting from
# the downloaded checkpoint (../model.ckpt). Checkpoints are saved under
# ../checkpoints with --checkpoint-every 1000 (presumably training steps).
# NOTE(review): assumes conf/ and model.ckpt were created under /content — confirm.
!python3 train.py \
    --dataset-config ../conf/dataset.json \
    --model-config ../conf/model_config.json \
    --name stable_audio_open_finetune \
    --save-dir ../checkpoints \
    --checkpoint-every 1000 \
    --batch-size 32 \
    --seed 128 \
    --pretrained-ckpt-path ../model.ckpt