# Modelling audio with discrete tokens

In this notebook, we will use [AudioCraft](https://github.com/facebookresearch/audiocraft)
to fine-tune a [MusicGen](https://arxiv.org/abs/2306.05284) model on a custom audio dataset.

We have limited resources, so we will stick to the 300M-parameter model. We will also train in float32 precision, so our base memory usage is 300M parameters × 4 bytes (float32) × 4 copies (weights + gradients + Adam momentum + Adam second-moment estimate) = 4.8 GB.
Fine-tuning the 1.5B model instead would require 1.5B × 4 × 4 = 24 GB of base memory.

Note that this is only the "base" memory requirement; it does not account for activations. That is where activation checkpointing comes in handy, as we will see.

We could lower this usage with [LoRA](https://arxiv.org/abs/2106.09685), but we will start simple and fine-tune the full 300M model.



In [None]:
# NOTE(review): runs the project's `audio_mod_idessai.config` module as a script in a shell
# subprocess; presumably it prints/checks the paths used below (e.g. `config.EGS_FILE`) — confirm.
!python -m audio_mod_idessai.config

In [None]:
import torch
from audiocraft.models import loaders

# Pretrained EnCodec tokenizer distributed with MusicGen-small. It is never optimized in
# this notebook (only `lm` gets an optimizer), so switch it to inference mode once here.
# This model operates at 50 Hz, with a bandwidth of 2kbps.
compression_model = loaders.load_compression_model(
    'facebook/musicgen-small',
    device='cuda',
)
compression_model.eval();


In [None]:
from audiocraft.data.audio_dataset import load_audio_meta
from audiocraft.data.music_dataset import MusicDataset
from audiocraft.utils.utils import get_loader
from audiocraft.utils.notebook import display_audio
from audio_mod_idessai import config

# Index of the training audio files listed in the project config.
meta = load_audio_meta(config.EGS_FILE)
dset = MusicDataset(
    meta,
    segment_duration=15.,    # excerpt length (presumably seconds)
    shuffle=True,
    sample_rate=32000,       # the rate this notebook assumes everywhere, incl. display_audio
    channels=1,              # mono
    min_segment_ratio=0.8,
    return_info=True,        # items come back as (wav, info) pairs, see below
    num_samples=1_000_000,
)

# Listen to one example drawn from the dataset.
wav, info = dset[0]
display_audio(wav, 32000)

In [None]:
import torch
with torch.no_grad():
    # Round trip through the tokenizer: waveform -> discrete codes -> waveform.
    # `[None]` prepends a batch dimension of size 1.
    codes, _ = compression_model.encode(wav[None].cuda())
    decoded = compression_model.decode(codes)
    print(codes.shape)
    # Original first, then the reconstruction, to judge the tokenizer's quality by ear.
    for audio in (wav, decoded):
        display_audio(audio, 32000)

In [None]:
from flashy.utils import averager
from audio_mod_idessai import utils
import time

def _apply_layer(layer, *args, **kwargs):
    """Forward `layer` with `args`/`kwargs` through the project's `utils.simple_checkpoint`.

    Meant as a drop-in for `lm.transformer._apply_layer` (that assignment is kept
    commented out in the model-setup code). Presumably applies a custom activation
    checkpointing scheme — see `utils.simple_checkpoint` to confirm.
    """
    checkpoint_fn = utils.simple_checkpoint
    return checkpoint_fn(layer, *args, **kwargs)

# With `init = True`, this cell starts from a fresh pretrained LM, optimizer and loader;
# with `init = False`, it keeps the existing `lm` / `opt` / `loader` and just resets the
# bookkeeping below, so training resumes where it stopped.
init = True
# Set to True to add LoRA adapters to every transformer layer (via `utils.add_lora_`).
lora = False
if init:
    # Free the previous run before loading a new model. `opt` must be dropped too:
    # the optimizer holds references to the old parameters and their AdamW state
    # buffers, so deleting `lm` alone would keep all that GPU memory alive until
    # `opt` is reassigned below, i.e. after the new model is already loaded.
    if 'opt' in globals():
        del opt
    if 'lm' in globals():
        del lm
import gc
gc.collect()
torch.cuda.empty_cache()
# Reset the peak counters so `max_memory_allocated` in the training loop reports this
# run only. (`reset_max_memory_allocated` is deprecated in favour of this call.)
torch.cuda.reset_peak_memory_stats()

if init:
    loader = get_loader(dset, 1000000, batch_size=12, num_workers=4, seed=0, collate_fn=dset.collater, shuffle=True)
    lm = loaders.load_lm_model('facebook/musicgen-small', device='cuda')
    # Train in float32, as assumed by the memory estimate at the top of the notebook.
    lm.to(dtype=torch.float)
    # Activation checkpointing trades compute for memory: layer activations are not
    # stored for the backward pass but recomputed during it (one extra forward),
    # which cuts activation memory roughly by the number of layers.
    lm.transformer.checkpointing = 'torch'
    # Alternative: route every layer through the custom `_apply_layer` defined above.
    # lm.transformer._apply_layer = _apply_layer
    if lora:
        for layer in lm.transformer.layers:
            utils.add_lora_(layer)
    # Always use AdamW with weight_decay; weight decay is pretty much mandatory for Transformers.
    opt = torch.optim.AdamW(lm.parameters(), lr=1e-5, betas=(0.9, 0.95), weight_decay=0.1)
niters = 100_000
# Smoothed metrics for logging (presumably an exponential moving average, decay 0.99).
avg = averager(0.99)

def do_one(wav, infos, description='A techno track.'):
    """Run a single AdamW training step of `lm` on one batch.

    Args:
        wav: batch of waveforms as yielded by `loader`; moved to CUDA, then turned
            into discrete codes by `compression_model` (no gradients).
        infos: per-example info objects from the same batch. Their `description`
            attribute is overwritten in place, so every example gets the same prompt.
        description: text prompt applied to all examples. The default is the prompt
            this notebook fine-tunes on.

    Returns:
        dict with
          - 'ce': detached loss, i.e. the sum of `utils.compute_cross_entropy(...)`
            divided by the number of positions set in `res.mask`;
          - 'gn': global L2 norm of all parameter gradients, taken after backward and
            before the optimizer step (0 if no parameter received a gradient).
    """
    wav = wav.cuda()
    # The tokenizer is not trained, so encoding needs no autograd graph.
    with torch.no_grad():
        codes, _ = compression_model.encode(wav)
    for info in infos:
        info.description = description
    cas = [info.to_condition_attributes() for info in infos]
    res = lm.compute_predictions(codes, cas)
    ce = utils.compute_cross_entropy(res.logits, codes, res.mask)
    # Normalize by the number of masked-in positions, not by the tensor size.
    ce_tot = ce.sum() / res.mask.sum()
    ce_tot.backward()
    # Global grad norm = L2 norm of the per-parameter L2 norms. Read gradients via
    # `.detach()` (not the discouraged `.data`) and guard the no-gradient case so
    # the result is always a tensor.
    grads = [p.grad.detach() for p in lm.parameters() if p.grad is not None]
    if grads:
        grad_norm = torch.linalg.vector_norm(
            torch.stack([torch.linalg.vector_norm(g) for g in grads]))
    else:
        grad_norm = torch.zeros(())
    opt.step()
    opt.zero_grad()
    return {
        'ce': ce_tot.detach(),
        'gn': grad_norm,
    }


# Debug option: to sanity-check the pipeline by overfitting a single batch, uncomment
# the two lines below; the loop then sees that same batch at every step.
# batch = next(iter(loader))
# loader = [batch] * niters

last_step = 0
# perf_counter is monotonic, unlike time.time, so measured intervals cannot go negative.
last_time = time.perf_counter()
# Distinct names for the batch so the loop does not overwrite the global `wav`
# example used by the EnCodec round-trip cell above.
for idx, (batch_wav, batch_infos) in enumerate(loader):
    if idx == niters:
        break
    metrics = do_one(batch_wav, batch_infos)
    ametrics = avg(metrics)

    step = idx + 1
    if step % 10 == 0:
        now = time.perf_counter()
        # Throughput in batches per second, which is what the `btch/s` label reports.
        speed = (step - last_step) / (now - last_time)
        last_time, last_step = now, step
        mx_mem = torch.cuda.max_memory_allocated() / 1e9
        print(f"{step: 6d}: ce={ametrics['ce']:.5f} gn={ametrics['gn']:.3f}, mx_mem={mx_mem:.1f}GB, spd={speed:.1f} btch/s")



In [None]:
# MusicGen is just a wrapper to make generation easier!
from audiocraft.models.musicgen import MusicGen

def test_gen(num_samples=4):
    """Sample `num_samples` unconditional clips from the current `lm`.

    Wraps the global `compression_model` and `lm` in a `MusicGen` helper (named
    'test', `max_duration=30`) and returns the output of `generate_unconditional`,
    presumably a batch of waveforms — TODO confirm shape.

    `lm` is put in eval mode for sampling (no dropout) and restored to its previous
    train/eval mode afterwards, so a later training run is unaffected.
    """
    was_training = lm.training
    lm.eval()
    try:
        mg = MusicGen('test', compression_model, lm, max_duration=30)
        return mg.generate_unconditional(num_samples, progress=True)
    finally:
        lm.train(was_training)


out = test_gen()
display_audio(out, 32000)

