# This is an attempt to fine tune the [music generator model](https://github.com/facebookresearch/audiocraft) developed by Meta.  This example references https://github.com/chavinlo/musicgen_trainer.git

MUSICGEN is an autoregressive transformer-based decoder conditioned on text or melodic representation. In this example, we will only work on fine tuning the transformer with or without the text prompts.

## Prerequisites

Prepare audio files (.wav) and text prompts (.txt) and put them into the same directory to be used for training. The audio and text prompts should have corresponding file names and only differ by their extensions. Text prompts are optional and usually contain descriptions of the audio; the title of the audio/song is a good place to start.

## Import libraries

In [None]:
import torchaudio
from audiocraft.models import MusicGen
from transformers import get_scheduler,SchedulerType
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
import random
from torch.utils.data import Dataset
from audiocraft.modules.conditioners import ClassifierFreeGuidanceDropout
import os
from tqdm import trange

## Define Pytorch Dataset class

In [None]:
class AudioDataset(Dataset):
    """Index the ``.wav`` files of a directory, optionally paired with text prompts.

    Items are file *paths*, not loaded data: ``dataset[i]`` returns
    ``(audio_path, text_prompt_path)``. The prompt path is the empty string
    when the dataset was built with ``use_text_prompt=False``.

    Args:
        data_dir: Directory scanned (non-recursively) for audio files.
        use_text_prompt: When true, every audio file must have a sibling
            ``<name>.txt`` file, whose path is stored alongside the audio.

    Raises:
        ValueError: If ``use_text_prompt`` is true and an audio file has no
            matching ``.txt`` file.
    """

    def __init__(self, data_dir, use_text_prompt=False):
        self.data_dir = data_dir
        self.data_map = []

        # os.listdir returns entries in arbitrary order; sorting makes an
        # index refer to the same file on every run.
        for filename in sorted(os.listdir(data_dir)):
            name, ext = os.path.splitext(filename)
            # Compare case-insensitively so ".WAV" files are not silently dropped.
            if ext.lower() != ".wav":
                continue

            entry = {"audio": os.path.join(data_dir, filename)}
            if use_text_prompt:
                prompt_path = os.path.join(data_dir, name + ".txt")
                if not os.path.exists(prompt_path):
                    raise ValueError(f"No text_prompt file for {name}")
                entry["text_prompt"] = prompt_path
            self.data_map.append(entry)

    def __len__(self):
        """Return the number of indexed audio files."""
        return len(self.data_map)

    def __getitem__(self, idx):
        """Return ``(audio_path, text_prompt_path)`` for item ``idx``.

        ``text_prompt_path`` is ``""`` when no prompt was recorded.
        """
        entry = self.data_map[idx]
        return entry["audio"], entry.get("text_prompt", "")

## Audio Preprocessing

Preprocess each audio file by:
1. resampling it to match the model's sample rate
2. converting it to monophonic audio
3. randomly sampling a fixed-length section of it, so that different sections of the music are extracted
4. compressing that section with the model's compression model into discrete codes, based on the number of codebooks used and their associated cardinality

In [None]:
def preprocess_audio(audio_path, model: MusicGen, duration: int = 30):
    """Load an audio file and encode a random ``duration``-second excerpt.

    The waveform is resampled to ``model.sample_rate``, averaged over its
    channels into a single channel, cropped to a randomly positioned window of
    ``duration`` seconds, moved to the GPU and passed through
    ``model.compression_model.encode``.

    Args:
        audio_path: Path of the audio file, read with ``torchaudio.load``.
        model: MusicGen model providing the target sample rate and the
            compression model.
        duration: Length of the excerpt in seconds.

    Returns:
        The codes returned by the compression model, or ``None`` when the
        resampled audio is shorter than ``duration`` seconds.
    """
    wav, sr = torchaudio.load(audio_path)
    # resample wav to model's sample rate
    wav = torchaudio.functional.resample(wav, sr, model.sample_rate)
    # convert to monophonic audio by averaging the channels
    wav = wav.mean(dim=0, keepdim=True)

    total_samples = wav.shape[1]
    if total_samples < model.sample_rate * duration:
        return None

    # number of samples in the excerpt
    end_sample = int(model.sample_rate * duration)
    # Randomize the start index. The largest valid start is
    # total_samples - end_sample, so the exclusive upper bound is one past it;
    # this also yields start 0 when the file is exactly `duration` long.
    start_sample = random.randrange(0, total_samples - end_sample + 1)
    wav = wav[:, start_sample : start_sample + end_sample]

    assert wav.shape[0] == 1  # ensure monophonic audio

    wav = wav.cuda()
    # add a leading dimension of size 1: (1, samples) -> (1, 1, samples)
    wav = wav.unsqueeze(1)

    with torch.no_grad():
        # vector quantization; presumably returns codes shaped
        # (1, num_codebooks, num_frames) — TODO confirm against audiocraft
        gen_audio = model.compression_model.encode(wav)

    codes, scale = gen_audio

    # the rest of the pipeline ignores any scale, so require that none is returned
    assert scale is None

    return codes


One-hot encode the audio codes based on the cardinality of the codebooks

In [None]:
def one_hot_encode(tensor, num_classes=2048):
    """Return a float32 one-hot encoding of an integer code tensor.

    Args:
        tensor: Integer tensor of class indices, each in ``[0, num_classes)``.
        num_classes: Size of the appended one-hot dimension.

    Returns:
        A float32 tensor of shape ``tensor.shape + (num_classes,)`` holding 1
        at each index given by ``tensor`` and 0 elsewhere. The result lives on
        the same device as ``tensor``.
    """
    # Vectorized: avoids a Python loop with one `.item()` call (and, for GPU
    # tensors, one device synchronisation) per element.
    one_hot = torch.nn.functional.one_hot(tensor.long(), num_classes=num_classes)
    return one_hot.to(torch.float32)

Initialize parameters

In [None]:
# --- data and model -------------------------------------------------------
dataset_path = 'input your dataset directory here'
model_id = 'small'  # passed to MusicGen.get_pretrained
use_text_prompt = False  # pair each .wav with a .txt prompt when True
use_cfg = True # classifier free guidance dropout

# --- optimisation ---------------------------------------------------------
epochs = 100
batch_size = 1
grad_acc = 1  # number of batches between optimizer steps
lr = 1e-5
weight_decay = 1e-5
warmup_steps = 10

# --- checkpointing --------------------------------------------------------
save_step = None  # checkpoint interval in steps; None disables checkpoints
save_models = save_step is not None
save_path = "./models/"
os.makedirs(save_path, exist_ok=True)

Training loop
1. Get the pretrained model
2. Cast the precision of the language model to float32 - using a lower precision will result in NaN loss values
3. Load the dataset and create the data loader
4. Initialise the model optimizer, learning rate scheduler and loss function
5. Train the language model by calculating the loss of the output with respect to the input codes. No shifting of logits is required, as mentioned in the comments of compute_predictions in lm.py. The loss is calculated only for the first codebook, which is the most important one. The other codebooks encode the quantization error left by the first codebook. (Note: the code selects `codes[0]`, `logits[0]` and `mask[0]`, i.e. the first entry along the leading dimension; if that dimension is the batch, this selects the first batch item with all of its codebooks rather than the first codebook - verify that this matches the intent.)

In [None]:
# Load the pretrained checkpoint and train the language model in full precision.
model = MusicGen.get_pretrained(model_id)
model.lm = model.lm.to(torch.float32)  # loss will result in na with lower precision
model.lm.cuda()

if not use_text_prompt:  # remove cross attention layers if not using text_prompt
    for layer in model.lm.transformer.layers:
        layer.cross_attention = None

dataset = AudioDataset(dataset_path, use_text_prompt=use_text_prompt)
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

model.lm.train()

optimizer = AdamW(
    model.lm.transformer.parameters(),  # only train the transformer
    lr=lr,
    betas=(0.9, 0.95),
    weight_decay=weight_decay,
)

scheduler = get_scheduler(
    SchedulerType.COSINE,
    optimizer,
    warmup_steps,
    int(epochs * len(train_dataloader) / grad_acc),
)

criterion = nn.CrossEntropyLoss()

current_step = 0  # optimizer updates performed so far (integer)
micro_step = 0  # backward passes performed so far

# Gradients are cleared only after an optimizer update, so they accumulate
# over `grad_acc` backward passes. Clearing them at the start of every batch
# would discard all but the last batch of each accumulation window.
optimizer.zero_grad()

for epoch in range(epochs):
    for batch_idx, (audio, text_prompt) in enumerate(train_dataloader):
        all_codes = []
        texts = []

        # where audio and text_prompt are just paths
        for audio_path, prompt_path in zip(audio, text_prompt):
            sample_codes = preprocess_audio(audio_path, model)  # returns tensor
            if sample_codes is None:
                # preprocess_audio returns None for clips shorter than the window
                continue

            if use_cfg:
                # duplicate the codes to pair with the conditioned and the
                # null-conditioned copies built below
                sample_codes = torch.cat([sample_codes, sample_codes], dim=0)

            all_codes.append(sample_codes)
            if use_text_prompt:
                # context manager closes the prompt file deterministically
                with open(prompt_path, "r") as prompt_file:
                    texts.append(prompt_file.read().strip())

        # Skip before building any conditioning: nothing usable in this batch.
        if not all_codes:
            continue

        if use_text_prompt:
            attributes, _ = model._prepare_tokens_and_attributes(texts, None)
            conditions = attributes
            if use_cfg:
                null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
                conditions = conditions + null_conditions
            tokenized = model.lm.condition_provider.tokenize(conditions)
            condition_tensors = model.lm.condition_provider(tokenized)

        codes = torch.cat(all_codes, dim=0)

        # run ops in mixed precision
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            if use_text_prompt:
                lm_output = model.lm.compute_predictions(
                    codes=codes, conditions=[], condition_tensors=condition_tensors
                )
            else:
                lm_output = model.lm.compute_predictions(
                    codes=codes, conditions=[]
                )

            # NOTE(review): index 0 selects the first entry along the leading
            # dimension, which looks like the batch dimension (codes are
            # concatenated along dim 0 above) — so only the first item of the
            # batch contributes to the loss. Confirm this is intended.
            codes = codes[0]
            logits = lm_output.logits[0]
            mask = lm_output.mask[0]

            codes = one_hot_encode(codes, num_classes=2048)

            codes = codes.cuda()
            logits = logits.cuda()
            mask = mask.cuda()

            # flatten and keep only the positions the model marks as valid
            mask = mask.view(-1)
            masked_logits = logits.view(-1, 2048)[mask]
            masked_codes = codes.view(-1, 2048)[mask]

            loss = criterion(masked_logits, masked_codes)

        # Scale so the accumulated gradient is the mean over the window rather
        # than the sum (a no-op when grad_acc == 1).
        (loss / grad_acc).backward()
        micro_step += 1

        print(f"Epoch: {epoch}/{epochs}, Batch: {batch_idx}/{len(train_dataloader)}, Loss: {loss.item()}")

        # Count processed batches rather than batch_idx, so that a skipped
        # batch cannot cause an accumulation boundary to be missed.
        if micro_step % grad_acc != 0:
            continue

        torch.nn.utils.clip_grad_norm_(model.lm.parameters(), 0.5)

        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        current_step += 1

        if save_models and current_step % save_step == 0:
            torch.save(model.lm.state_dict(), f"{save_path}/lm_{current_step}.pt")

## Generation

Initialize parameters

In [None]:
# Text description used for generation; only read when use_text_prompt is True.
prompt = 'input your text prompt here'
# Used with model.frame_rate to compute 'max_gen_len' for each generation call.
duration = 30
# Number of generation passes; their outputs are concatenated along the last axis.
sample_loops = 4
# The values below are forwarded unchanged to model.lm.generate through
# model.generation_params.
use_sampling = 1
two_step_cfg = 0
top_k = 250
top_p = 0.0
temperature = 1.0  # forwarded under the key 'temp'
cfg_coef = 3.0
# NOTE: this rebinds `save_path`, which the training section used as the
# checkpoint directory; re-run the training parameter cell before training again.
save_path = 'output.wav'

In [None]:
# Build the conditioning attributes (and initial prompt tokens) for the text
# prompt. The second argument is None, as in the training loop. When
# use_text_prompt is False neither name is assigned here; the generation loop
# does not read them in that case.
if use_text_prompt:
    attributes, prompt_tokens = model._prepare_tokens_and_attributes([prompt], None)

In [None]:
# Keyword arguments that the generation loop unpacks into model.lm.generate.
# The token budget is the requested duration converted with the model's frame rate.
model.generation_params = dict(
    max_gen_len=int(duration * model.frame_rate),
    use_sampling=use_sampling,
    temp=temperature,
    top_k=top_k,
    top_p=top_p,
    cfg_coef=cfg_coef,
    two_step_cfg=two_step_cfg,
)

Generate tokens autoregressively

In [None]:
# Run `sample_loops` generation passes and join their tokens along the last axis.
total = []
model.lm.eval()
for _ in trange(sample_loops):
    with model.autocast:
        if not use_text_prompt:
            # Unconditional: every pass is generated independently.
            gen_tokens = model.lm.generate(None, None, callback=None, **model.generation_params)
            total.append(gen_tokens)
            continue

        gen_tokens = model.lm.generate(prompt_tokens, attributes, callback=None, **model.generation_params)
        # Keep only what comes after the prompt, so the tokens fed in as the
        # prompt are not stored a second time.
        prompt_len = 0 if prompt_tokens is None else prompt_tokens.shape[-1]
        total.append(gen_tokens[..., prompt_len:])
        # The trailing half of this pass becomes the prompt of the next one.
        half = -gen_tokens.shape[-1] // 2
        prompt_tokens = gen_tokens[..., half:]
gen_tokens = torch.cat(total, -1)

In [None]:
# Sanity check before decoding: the concatenated tokens must be a 3-D tensor
# (presumably batch, codebooks, frames — TODO confirm).
assert gen_tokens.dim() == 3

Construct audio representation from generated codes

In [None]:
# Decode the generated tokens with the compression model; gradient tracking is
# disabled. The second argument is None, matching the scale of None that
# preprocess_audio asserts when encoding.
with torch.no_grad():
    gen_audio = model.compression_model.decode(gen_tokens, None)

In [None]:
# Move the decoded audio to the CPU and write its first entry (index 0 of the
# leading dimension) to `save_path` at the model's sample rate.
gen_audio = gen_audio.cpu()
torchaudio.save(save_path, gen_audio[0], model.sample_rate)