<a href="https://colab.research.google.com/github/SyntaxDiffusion/SyntaxNodes/blob/main/braindead_stable_audio_open_trainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title download and install
!git clone https://github.com/Stability-AI/stable-audio-tools.git
%cd /content/stable-audio-tools
!pip install -e .
%cd ..
!apt-get update -y
!apt-get install ffmpeg -y

!python -m pip install -U pip setuptools wheel
!python -m pip install --force-reinstall https://github.com/yt-dlp/yt-dlp/archive/master.tar.gz

!gdown 1a--6MqPu8PiDfxXoXPHok79kxCxrqrmW # model.safetensors

!pip install -U protobuf soundfile gdown wandb

from IPython.display import clear_output
clear_output()

!wandb login
#@markdown You'll need a wandb key to log training info, get one at https://wandb.ai/authorize and paste it in when asked

#@markdown If colab asks if you want to restart, *do not restart* instead press `cancel`.

#@markdown If colab gave you a T4 or L4 instead of A100, it will OOM immediately when training starts. Training stable audio open requires **at least** 27.6 gb vram, which an A100 has (40gb), but both of the other gpu types are too small (16 and 24gb). Restart until you get an A100, or else this notebook won't work.

#@markdown Alternatively, you can run this on other gpu clouds (eg, [runpod](https://www.runpod.io/console/deploy)) but you'll need to make modifications to the code for it to work there (replacing every instance of `/content/` with `/workspace/`, etc).

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
# @title scrape

import os
import yt_dlp

from functools import partial
from tqdm import tqdm
# Notebook-friendly progress bars: pinned to the first line and left visible when finished.
tqdm = partial(tqdm, position=0, leave=True)

youtube_playlist_link = "https://www.youtube.com/playlist?list=PLZ4DbyIWUwCq4V8bIEa8jm2ozHZVuREJP" # @param {type:"string"}

dataset_path = "/content/dataset" # @param {type:"string"}
os.makedirs(dataset_path, exist_ok=True)

# youtube scraper: fetch each video's best audio stream and convert it to a 128 kbps mp3 named
# after the video title (the filename later becomes the training prompt).
ydl_opts = {
    'format': 'bestaudio/best',
    'outtmpl': os.path.join(dataset_path, '%(title)s.%(ext)s'),
    'postprocessors': [{
        'key': 'FFmpegExtractAudio',
        'preferredcodec': 'mp3',
        'preferredquality': '128',
    }],
    'quiet': True,
    'extract_flat': True,
}

failed_urls = []
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
    info_dict = ydl.extract_info(youtube_playlist_link, download=False)
    if info_dict and 'entries' in info_dict:
        # Unavailable/private videos may come back as empty entries; skip them.
        entries = [entry for entry in info_dict['entries'] if entry]
    else:
        # Not a playlist: treat the link itself as the single video to download.
        entries = [{
            'title': (info_dict or {}).get('title', youtube_playlist_link),
            'url': youtube_playlist_link,
        }]

    for i, entry in enumerate(tqdm(entries, desc="downloading"), start=1):
        url = entry.get('url')
        title = entry.get('title', url)
        print(f"extracting {title} {url} ({i}/{len(entries)})")
        try:
            ydl.download([url])
        except Exception as err:  # keep going past per-video failures, but say why
            failed_urls.append(url)
            print(f"failed to download {url}: {err}")

print('done!')
print(f'got {len(os.listdir(dataset_path))} songs')
if failed_urls:
    print(f'{len(failed_urls)} download(s) failed')

In [None]:
# @title train

import math
import json
import os

# create config json files

seconds = 10 # @param {type:"integer"}

def calculate_samples(seconds, sample_rate=44100, downsampling_ratio=2048):
    """Convert a clip length in seconds into a training ``sample_size`` (in audio samples).

    ``seconds * sample_rate`` is rounded *up* to the next multiple of ``downsampling_ratio``.

    Args:
        seconds: clip length in seconds.
        sample_rate: audio sample rate; 44100 matches model_config["sample_rate"].
        downsampling_ratio: rounding granularity. The default 2048 is the autoencoder
            pretransform's "downsampling_ratio" in model_config (also the product of its encoder
            strides 2*4*4*8*8). Rounding to it presumably makes sample_size a whole number of
            latent frames; the previous hardcoded 1024 left half a frame over.

    Returns:
        The smallest multiple of ``downsampling_ratio`` that is >= seconds * sample_rate.
    """
    total_samples = seconds * sample_rate
    return math.ceil(total_samples / downsampling_ratio) * downsampling_ratio

# 10 s * 44100 Hz = 441000 samples -> 216 frames of 2048 samples.
assert calculate_samples(10) == 442368

sample_size = calculate_samples(seconds)

# Colab form field; copied into dataset_config["random_crop"] further down this cell.
random_crop = True # @param {type:"boolean"}

#@markdown `random_crop` should be left enabled when training on full songs. If you're training on loops or one-shots, turn it off, but then you'll need to skip the scraping step and upload your samples to `/content/dataset/*` instead.

# Prompts used for the periodic training demos (the "demo_cond" entries of model_config below).
demo_prompt_1 = "A beautiful orchestral symphony, classical music" # @param {type:"string"}
demo_prompt_2 = "A pop song about love and loss" # @param {type:"string"}
demo_prompt_3 = "Chill hip-hop beat, chillhop" # @param {type:"string"}
demo_prompt_4 = "Amen break 174 BPM" # @param {type:"string"}

#@markdown Note: this notebook uses the filename (in this case, the video title) as the prompt. For songs, this will typically be in the format `Artist - Song Title`, or occasionally just `Song Title`. Look at the style in which the video titles in your playlist are written, and mimic it in the demo prompts above; this is also how you will prompt the finished model. If you'd like to use a different prompting style, or a captioner's output, edit the `caption_py` string near the bottom of this cell to add whatever functionality you need. This notebook is just a starting point!

# Model definition passed to stable-audio-tools' train.py via --model-config (this dict is dumped
# to model_config.json further down in this cell).
# NOTE(review): the architecture fields presumably have to match the pretrained weights loaded with
# --pretrained-ckpt-path (model.safetensors); changing them will likely break loading — confirm first.
model_config = {
    "model_type": "diffusion_cond",
    "sample_size": sample_size,  # training crop length in samples, from calculate_samples(seconds)
    "sample_rate": 44100,  # should agree with the sample rate used by calculate_samples
    "audio_channels": 2,
    "model": {
        # Autoencoder mapping audio to/from the latent space the diffusion model operates in.
        "pretransform": {
            "type": "autoencoder",
            "iterate_batch": True,
            "config": {
                "encoder": {
                    "type": "oobleck",
                    "requires_grad": False,  # encoder weights are not trained
                    "config": {
                        "in_channels": 2,
                        "channels": 128,
                        "c_mults": [1, 2, 4, 8, 16],
                        "strides": [2, 4, 4, 8, 8],  # 2*4*4*8*8 = 2048, the downsampling_ratio below
                        "latent_dim": 128,  # 2x the decoder's 64; presumably mean + scale for the vae bottleneck — TODO confirm
                        "use_snake": True
                    }
                },
                "decoder": {
                    "type": "oobleck",
                    "config": {
                        "out_channels": 2,
                        "channels": 128,
                        "c_mults": [1, 2, 4, 8, 16],
                        "strides": [2, 4, 4, 8, 8],
                        "latent_dim": 64,
                        "use_snake": True,
                        "final_tanh": False
                    }
                },
                "bottleneck": {
                    "type": "vae"
                },
                "latent_dim": 64,
                "downsampling_ratio": 2048,
                "io_channels": 2
            }
        },
        # Conditioning signals: a T5 text embedding of the prompt plus two numeric timing values.
        "conditioning": {
            "configs": [
                {
                    "id": "prompt",  # supplied by get_custom_metadata() in caption.py (the filename)
                    "type": "t5",
                    "config": {
                        "t5_model_name": "t5-base",
                        "max_length": 128  # presumably the max prompt length in T5 tokens
                    }
                },
                {
                    "id": "seconds_start",  # the demo_cond entries below pass 0 here
                    "type": "number",
                    "config": {
                        "min_val": 0,
                        "max_val": 512
                    }
                },
                {
                    "id": "seconds_total",  # the demo_cond entries below pass `seconds` here
                    "type": "number",
                    "config": {
                        "min_val": 0,
                        "max_val": 512
                    }
                }
            ],
            "cond_dim": 768  # same width as the DiT's cond_token_dim below
        },
        "diffusion": {
            # All three conditioners feed cross-attention; the timing values are also used as
            # global conditioning.
            "cross_attention_cond_ids": ["prompt", "seconds_start", "seconds_total"],
            "global_cond_ids": ["seconds_start", "seconds_total"],
            "type": "dit",
            "config": {
                "io_channels": 64,  # equals the autoencoder latent_dim (64)
                "embed_dim": 1536,
                "depth": 24,
                "num_heads": 24,
                "cond_token_dim": 768,
                "global_cond_dim": 1536,  # equals embed_dim
                "project_cond_tokens": False,
                "transformer_type": "continuous_transformer"
            }
        },
        "io_channels": 64
    },
    "training": {
        "use_ema": True,
        "log_loss_info": False,
        "optimizer_configs": {
            "diffusion": {
                "optimizer": {
                    "type": "AdamW",
                    "config": {
                        "lr": 5e-5,
                        "betas": [0.9, 0.999],
                        "weight_decay": 1e-3
                    }
                },
                # NOTE(review): the meaning of inv_gamma/power/warmup is defined by stable-audio-tools'
                # InverseLR scheduler — confirm there before tuning.
                "scheduler": {
                    "type": "InverseLR",
                    "config": {
                        "inv_gamma": 1000000,
                        "power": 0.5,
                        "warmup": 0.99
                    }
                }
            }
        },
        "demo": {
            "demo_every": 2000,  # same interval as --checkpoint-every in the training command below
            "demo_steps": 250,
            "num_demos": 4,  # matches the four demo_cond entries
            "demo_cond": [
                {"prompt": f"{demo_prompt_1}", "seconds_start": 0, "seconds_total": seconds},
                {"prompt": f"{demo_prompt_2}", "seconds_start": 0, "seconds_total": seconds},
                {"prompt": f"{demo_prompt_3}", "seconds_start": 0, "seconds_total": seconds},
                {"prompt": f"{demo_prompt_4}", "seconds_start": 0, "seconds_total": seconds}
            ],
            "demo_cfg_scales": [4, 8]  # presumably classifier-free guidance scales for the demos
        }
    }
}

# Dataset definition for train.py: an "audio_dir" dataset over the scraped files. The per-file
# metadata (the prompt) comes from get_custom_metadata() in caption.py.
dataset_config = {
    "dataset_type": "audio_dir",
    "datasets": [
        {
            "id": "dataset",
            "path": dataset_path,
            "custom_metadata_module": "/content/caption.py"
        }
    ],
    "random_crop": random_crop
}

# Each file's prompt is its filename without directory or extension (e.g. the video title).
caption_py = """import os

def get_custom_metadata(info, audio):
    caption = info["relpath"]
    caption = os.path.splitext(os.path.basename(caption))[0]
    return {"prompt": caption}
"""

# Write the generated files to the absolute /content paths that everything else reads them from
# (custom_metadata_module above and the --dataset-config / --model-config arguments), instead of
# depending on the current working directory happening to be /content.
config_dir = "/content"
os.makedirs(config_dir, exist_ok=True)

with open(os.path.join(config_dir, "model_config.json"), "w") as f:
    json.dump(model_config, f, indent=2)

with open(os.path.join(config_dir, "dataset_config.json"), "w") as f:
    json.dump(dataset_config, f, indent=2)

with open(os.path.join(config_dir, "caption.py"), "w") as f:
    f.write(caption_py)

# actually train

import os
os.environ['TOKENIZERS_PARALLELISM'] = "false"

command = (
    "python stable-audio-tools/train.py"
    " --config-file stable-audio-tools/defaults.ini"
    " --dataset-config /content/dataset_config.json"
    " --model-config /content/model_config.json"
    " --precision 16-mixed"
    " --batch-size 1"
    " --num-gpus 1"
    " --num-workers 8"
    " --seed 1234"
    " --name train_stable_audio_open"
    " --save-dir checkpoints"
    " --checkpoint-every 2000"
    " --pretrained-ckpt-path /content/model.safetensors"
)

!{command}

print('done!')

In [None]:
# @title unwrap and export

import os
import re

def find_last_modified_file(root_folder, suffix=None):
    """Return the path of the most recently modified file anywhere under ``root_folder``.

    Args:
        root_folder: directory to search recursively. A missing directory yields no files.
        suffix: optional filename suffix filter (e.g. ".ckpt"). None, the default, considers
            every file.

    Returns:
        Full path of the newest matching file, or None when there is none.
    """
    newest_path = None
    newest_mtime = None
    for dirpath, _dirnames, filenames in os.walk(root_folder):
        for filename in filenames:
            if suffix is not None and not filename.endswith(suffix):
                continue
            full_path = os.path.join(dirpath, filename)
            try:
                mtime = os.path.getmtime(full_path)
            except OSError:
                # The file vanished (e.g. a checkpoint being rotated) between listing and stat.
                continue
            # A None sentinel, rather than 0, so files with mtime <= 0 are still found.
            if newest_mtime is None or mtime > newest_mtime:
                newest_mtime = mtime
                newest_path = full_path
    return newest_path

def extract_step_count(file_path):
    """Read the ``step=N`` counter from a checkpoint path, rounded to the nearest thousand.

    Returns None when the path contains no ``step=<digits>`` token.
    """
    found = re.search(r"step=(\d+)", file_path)
    if found is None:
        return None
    raw_steps = int(found.group(1))
    return round(raw_steps, -3)

checkpoint_path = find_last_modified_file('/content/checkpoints')

if checkpoint_path is None:
    print("no checkpoints found. \nAre you sure the training cell ran long enough to save one? aka 2000 steps")
    exit()

name = f"stable_audio_open_finetuned_{extract_step_count(checkpoint_path)}k"

command = (
    "python stable-audio-tools/unwrap_model.py"
    " --model-config /content/model_config.json"
    f" --ckpt-path {checkpoint_path}"
    f" --name {name}"
)

!{command}

unwrapped_checkpoint = f"/content/{name}.ckpt"

print("done!")
print(f"checkpoint saved to {unwrapped_checkpoint}")

In [None]:
# @title run gradio interface

command = (
    "python stable-audio-tools/run_gradio.py"
    " --model-config /content/model_config.json"
    f" --ckpt-path {unwrapped_checkpoint}"
    " --share"
)

!{command}