<a href="https://colab.research.google.com/github/ylacombe/musicgen-dreamboothing/blob/main/MusicGen_Dreamboothing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
# Print the installed PyTorch version.
print(torch.__version__)

In [None]:
# Shell escape: install the python3.12 package through apt.
!sudo apt-get install python3.12

In [None]:
# Install the musicgen-dreamboothing repository together with demucs and msclap.
!pip install --quiet git+https://github.com/ylacombe/musicgen-dreamboothing demucs msclap
# Upgrade transformers to the current GitHub version and datasets to its latest release.
!pip install -U git+https://github.com/huggingface/transformers datasets

In [None]:
from huggingface_hub import login

# Authenticate against the Hugging Face Hub.
# NOTE(review): "XXX" looks like a redacted placeholder — supply a real token
# before running, and avoid committing the real value to the notebook.
login(token="XXX")

In [None]:
from datasets import load_dataset

# Load only the first 30 examples of the train split.
dataset = load_dataset("yav1327/indian_songs", split='train[:30]')

In [None]:
# Bare expression: shows the dataset summary as the cell output.
dataset

In [None]:
from IPython.display import Audio

# Index the row first so only one example is accessed (indexing the column
# first may decode every row, depending on the datasets version), and play it
# at the sampling rate stored with the clip instead of an assumed 16 kHz.
first_clip = dataset[0]["filepath"]
Audio(first_clip["array"], rate=first_clip["sampling_rate"])

In [None]:
from demucs import pretrained
from demucs.apply import apply_model
from demucs.audio import convert_audio
from datasets import Audio
import torch

# Load the pretrained "htdemucs" model.
demucs = pretrained.get_model("htdemucs")
# Move it to the first GPU when at least one CUDA device is visible.
if torch.cuda.device_count() > 0:
    demucs.to("cuda:0")

# Name of the column that filter_stems writes its output to.
audio_column_name = "audio_files"

def wrap_audio(audio, sr):
    """Package a waveform tensor as an ``{"array", "sampling_rate"}`` dict.

    ``audio`` is brought to host memory and converted to a NumPy array;
    ``sr`` is stored unchanged under ``"sampling_rate"``.
    """
    samples = audio.cpu().numpy()
    return dict(array=samples, sampling_rate=sr)

def filter_stems(batch, rank=None):
    """Run Demucs on a batch and keep the mix of every stem except the last.

    Args:
        batch: Batched ``datasets.map`` input. ``batch["filepath"]`` is read
            as a list of dicts with ``"array"`` and ``"sampling_rate"`` keys.
        rank: Process rank supplied by ``map(..., with_rank=True)``; unused.

    Returns:
        The same ``batch`` with ``batch[audio_column_name]`` set to one
        ``wrap_audio`` dict per input clip, tagged with ``demucs.samplerate``.
    """
    device = "cpu" if torch.cuda.device_count() == 0 else "cuda:0"

    wavs = [
        convert_audio(
            torch.tensor(audio["array"][None], device=device).to(
                torch.float32
            ),
            # Resample from each clip's own rate; a hard-coded 44100 would
            # mis-resample any clip that is not stored at 44.1 kHz.
            audio["sampling_rate"],
            demucs.samplerate,
            demucs.audio_channels,
        ).T
        for audio in batch["filepath"]
    ]
    # Remember the unpadded lengths so the padding can be trimmed off again.
    wavs_length = [audio.shape[0] for audio in wavs]

    # Pad to a common length, then put the time axis back in last position.
    wavs = torch.nn.utils.rnn.pad_sequence(
        wavs, batch_first=True, padding_value=0.0
    ).transpose(1, 2)
    stems = apply_model(demucs, wavs)

    # Per clip: drop the last stem (presumably vocals for htdemucs — TODO
    # confirm), trim the padding, sum the remaining stems and average the
    # channel axis.
    batch[audio_column_name] = [
        wrap_audio(s[:-1, :, :length].sum(0).mean(0), demucs.samplerate)
        for (s, length) in zip(stems, wavs_length)
    ]

    return batch

# Number of worker processes used by the map call below.
num_proc = 1

# Apply filter_stems in batches of 8; with_rank=True makes map pass a `rank`
# argument to the function.
dataset = dataset.map(
    filter_stems,
    batched=True,
    batch_size=8,
    with_rank=True,
    num_proc=num_proc,
)
# Cast the newly written column to the datasets Audio feature (default settings).
dataset = dataset.cast_column(audio_column_name, Audio())

# Drop the notebook's reference to the Demucs model.
del demucs

In [None]:
from IPython.display import Audio

# Play the first example's "filepath" audio at its stored sampling rate.
# NOTE(review): this plays the "filepath" column, not the column that
# filter_stems wrote to (`audio_column_name`) — confirm which one is intended.
Audio(dataset[0]["filepath"]["array"], rate=dataset[0]["filepath"]["sampling_rate"])

In [None]:
from utils import instrument_classes, genre_labels, mood_theme_classes
# Show the candidate label lists imported from utils.
print("Genres", genre_labels)
print("Instruments:", instrument_classes)
print("Moods", mood_theme_classes)

In [None]:
from msclap import CLAP
import librosa
import tempfile
import torchaudio
import random
import numpy as np
import os

# Load the "2023" CLAP model and compute the text embeddings of every candidate
# label once, so enrich_text only has to embed the audio of each example.
clap_model = CLAP(version="2023", use_cuda=True)
instrument_embeddings = clap_model.get_text_embeddings(instrument_classes)
genre_embeddings = clap_model.get_text_embeddings(genre_labels)
mood_embeddings = clap_model.get_text_embeddings(mood_theme_classes)

def enrich_text(batch):
    """Attach a comma-separated text description to a single example.

    The description combines five items in random order: the dominant pitch
    class, the estimated tempo ("<value> bpm"), and the instrument, genre and
    mood labels whose text embedding is most similar to the clip's CLAP audio
    embedding.

    Args:
        batch: One example; ``batch["filepath"]`` is read as a dict with
            ``"array"`` and ``"sampling_rate"`` keys.

    Returns:
        The same ``batch`` with the description stored under ``"metadata"``.
    """
    audio, sampling_rate = (
        batch["filepath"]["array"],
        batch["filepath"]["sampling_rate"],
    )

    tempo, _ = librosa.beat.beat_track(y=audio, sr=sampling_rate)
    # Depending on the librosa version the tempo is a scalar or an array;
    # take the first element so the text reads "120.0 bpm", not "[120.] bpm".
    tempo = f"{np.round(float(np.ravel(tempo)[0]))} bpm"

    # Dominant pitch class: the chroma bin with the largest total energy.
    chroma = librosa.feature.chroma_stft(y=audio, sr=sampling_rate)
    key = np.argmax(np.sum(chroma, axis=1))
    key = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"][key]

    # CLAP embeds audio from file paths, so write the clip to a temporary wav.
    with tempfile.TemporaryDirectory() as tempdir:
        path = os.path.join(tempdir, "tmp.wav")
        torchaudio.save(path, torch.tensor(audio).unsqueeze(0), sampling_rate)
        audio_embeddings = clap_model.get_audio_embeddings([path])

    def best_label(text_embeddings, labels):
        # Pick the label whose text embedding scores highest against the clip.
        index = clap_model.compute_similarity(
            audio_embeddings, text_embeddings
        ).argmax(dim=1)[0]
        return labels[index]

    instrument = best_label(instrument_embeddings, instrument_classes)
    genre = best_label(genre_embeddings, genre_labels)
    mood = best_label(mood_embeddings, mood_theme_classes)

    metadata = [key, tempo, instrument, genre, mood]

    # Shuffle so the description does not always list items in the same order.
    random.shuffle(metadata)
    batch["metadata"] = ", ".join(metadata)
    return batch

# Run enrich_text on every example (not batched) to add the "metadata" column.
dataset = dataset.map(
    enrich_text,
    desc="add metadata",
)

# Drop the notebook's references to the CLAP model and the label embeddings.
del clap_model, instrument_embeddings, genre_embeddings, mood_embeddings


In [None]:
# Show the description generated for the first example.
print(dataset[0]["metadata"])

In [None]:
from transformers import (
    AutoProcessor,
    AutoModelForTextToWaveform,
)

# Load the processor and the model from the "facebook/musicgen-melody" checkpoint.
processor = AutoProcessor.from_pretrained("facebook/musicgen-melody")
model = AutoModelForTextToWaveform.from_pretrained("facebook/musicgen-melody")

# Freeze the text encoder and the audio encoder.
model.freeze_text_encoder()
model.freeze_audio_encoder()

In [None]:
from IPython.display import Audio
import torch

# Use the first GPU when one is visible, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.device_count()>0 else "cpu")

model.to(device)

# Tokenize a single text prompt and move the tensors to the model's device.
inputs = processor(
    text=["Generate bollywood sad song"],
    padding=True,
    return_tensors="pt",
).to(device)

# Sample 256 new tokens with guidance_scale=3.
audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256)


# Play the generated waveform at 32 kHz.
Audio(audio_values.cpu().numpy().squeeze(), rate=32000)

In [None]:
from transformers import AutoFeatureExtractor
from datasets import Audio

# Text prefix prepended to every example's description.
instance_prompt = "punk"

# Feature extractor matching the model's audio encoder.
audio_encoder_feature_extractor = AutoFeatureExtractor.from_pretrained(
    model.config.audio_encoder._name_or_path,
)

# The stem-filtering step stores its output in `audio_column_name`; no cell
# creates an "audio" column, so read that column rather than a hard-coded name.
dataset_sampling_rate = dataset[0][audio_column_name]["sampling_rate"]

# resample audio if necessary
if dataset_sampling_rate != audio_encoder_feature_extractor.sampling_rate:
    dataset = dataset.cast_column(
        audio_column_name,
        Audio(
            sampling_rate=audio_encoder_feature_extractor.sampling_rate
        ),
    )


# Preprocessing the datasets.
# We need to read the audio files as arrays and tokenize the targets.
def prepare_audio_features(batch):
    """Tokenize the text description and extract audio-encoder inputs.

    Args:
        batch: One example with a ``"metadata"`` string and an audio dict
            (``"array"``, ``"sampling_rate"``) under ``audio_column_name``.

    Returns:
        The same ``batch`` with ``"input_ids"`` (tokenized
        ``"<instance_prompt>, <metadata>"``), ``"labels"`` (the feature
        extractor's ``input_values``) and ``"target_length"`` (number of
        samples in the squeezed waveform) added.
    """
    # tokenize the prompt-prefixed description
    metadata = batch["metadata"]
    metadata = f"{instance_prompt}, {metadata}"
    batch["input_ids"] = processor.tokenizer(metadata)["input_ids"]

    # load audio
    target_sample = batch[audio_column_name]
    labels = audio_encoder_feature_extractor(
        target_sample["array"], sampling_rate=target_sample["sampling_rate"]
    )
    batch["labels"] = labels["input_values"]

    # take length of raw audio waveform
    batch["target_length"] = len(target_sample["array"].squeeze())
    return batch


# Replace all existing columns with the three produced above.
dataset = dataset.map(
    prepare_audio_features,
    remove_columns=dataset.column_names,
    num_proc=2,
    desc="preprocess datasets",
)

In [None]:

# NOTE: despite its name, `audio_decoder` is bound to the model's audio *encoder*.
audio_decoder = model.audio_encoder
num_codebooks = model.decoder.config.num_codebooks
audio_encoder_pad_token_id = model.config.decoder.pad_token_id

# Tensor of shape (1, 1, num_codebooks, 1) filled with the pad token id; it is
# prepended to the encoded codes in apply_audio_decoder.
pad_labels = torch.ones((1, 1, num_codebooks, 1)) * audio_encoder_pad_token_id

# NOTE(review): the module is only moved when exactly one GPU is visible,
# whereas other cells test `> 0` — confirm this is intentional.
if torch.cuda.device_count() == 1:
    audio_decoder.to("cuda")

def apply_audio_decoder(batch):
    """Replace ``batch["labels"]`` with delay-pattern-masked audio codes.

    The stored ``"labels"`` are encoded with ``audio_decoder.encode`` (no
    gradients), ``pad_labels`` is prepended along the last dimension, the
    decoder's delay pattern mask is built and applied, and the first position
    along dimension 1 is dropped before the result is moved to the CPU.

    Args:
        batch: One example whose ``"labels"`` can be turned into a tensor.

    Returns:
        The same ``batch`` with ``"labels"`` overwritten.
    """

    with torch.no_grad():
        labels = audio_decoder.encode(
            torch.tensor(batch["labels"]).to(audio_decoder.device)
        )["audio_codes"]

    # add pad token column
    labels = torch.cat(
        [pad_labels.to(labels.device).to(labels.dtype), labels], dim=-1
    )

    # build the mask for a sequence `num_codebooks` positions longer than the codes
    labels, delay_pattern_mask = model.decoder.build_delay_pattern_mask(
        labels.squeeze(0),
        audio_encoder_pad_token_id,
        labels.shape[-1] + num_codebooks,
    )

    labels = model.decoder.apply_delay_pattern_mask(labels, delay_pattern_mask)

    # the first time step is associated to a row full of BOS, let's get rid of it
    batch["labels"] = labels[:, 1:].cpu()
    return batch

# Encodec doesn't truly support batching,
# so pass the samples one by one to the GPU (non-batched map, single process).
dataset = dataset.map(
    apply_audio_decoder,
    num_proc=1,
    desc="Apply encodec",
)

In [None]:
from peft import LoraConfig, get_peft_model

# TODO(YL): add modularity here
# Module-name suffixes that receive LoRA adapters: the projection layers, the
# attention and feed-forward linears, and one LM head plus one token embedding
# per entry of `model.decoder.lm_heads`. "lm_heads.0" is produced by the
# generated list, so it is no longer also listed by hand.
target_modules = (
    [
        "enc_to_dec_proj",
        "audio_enc_to_dec_proj",
        "k_proj",
        "v_proj",
        "q_proj",
        "out_proj",
        "fc1",
        "fc2",
    ]
    + [f"lm_heads.{i}" for i in range(len(model.decoder.lm_heads))]
    + [f"embed_tokens.{i}" for i in range(len(model.decoder.lm_heads))]
)

config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=target_modules,
    lora_dropout=0.05,
    bias="none",
)
# Called before wrapping the model with PEFT; gradient checkpointing is
# enabled in the training arguments.
model.enable_input_require_grads()
model = get_peft_model(model, config)
model.print_trainable_parameters()

In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from dataclasses import dataclass

class MusicgenTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer whose padding helper handles 3-D tensors."""

    def _pad_tensors_to_max_len(self, tensor, max_length):
        """Pad or truncate ``tensor`` along dimension 1 to ``max_length``.

        The fill value is the tokenizer's pad token id (falling back to its
        EOS token id) when the tokenizer exposes ``pad_token_id``; otherwise
        the model config's ``pad_token_id``.

        Raises:
            ValueError: If no tokenizer pad id is available and the model
                config does not define ``pad_token_id`` either.
        """
        tok = self.tokenizer
        if tok is not None and hasattr(tok, "pad_token_id"):
            # If PAD token is not defined at least EOS token has to be defined
            fill_id = tok.pad_token_id
            if fill_id is None:
                fill_id = tok.eos_token_id
        else:
            fill_id = self.model.config.pad_token_id
            if fill_id is None:
                raise ValueError(
                    "Pad_token_id must be set in the configuration of the model, in order to pad tensors"
                )

        # Start from a tensor full of the fill value, then copy in whatever
        # part of the input fits.
        out_shape = (tensor.shape[0], max_length, tensor.shape[2])
        result = fill_id * tensor.new_ones(out_shape)
        keep = min(max_length, tensor.shape[1])
        result[:, :keep] = tensor[:, :keep]
        return result


@dataclass
class DataCollatorMusicGenWithPadding:
    """
    Collate preprocessed examples into one dynamically padded batch.
    Args:
        processor (:class:`~transformers.AutoProcessor`)
            The processor whose tokenizer pads the text inputs.
    """

    processor: AutoProcessor

    def __call__(
        self, features
    ):
        # Labels and text inputs have unrelated lengths, so each side is
        # padded separately with its own method.
        per_example_labels = [
            torch.tensor(item["labels"]).transpose(0, 1) for item in features
        ]
        # Result is (bsz, seq_len, num_codebooks); padded positions hold -100.
        padded_labels = torch.nn.utils.rnn.pad_sequence(
            per_example_labels, batch_first=True, padding_value=-100
        )

        text_fields = [{"input_ids": item["input_ids"]} for item in features]
        text_batch = self.processor.tokenizer.pad(text_fields, return_tensors="pt")

        return {"labels": padded_labels, **text_batch}

# Instantiate the custom data collator with the processor loaded earlier.
data_collator = DataCollatorMusicGenWithPadding(
    processor=processor,
)



In [None]:

# Training hyper-parameters. The Hub repository name is given through
# `hub_model_id`, which replaces the deprecated `push_to_hub_model_id`.
training_args = Seq2SeqTrainingArguments(
    output_dir="./output/",
    num_train_epochs=4,
    gradient_accumulation_steps=8,
    gradient_checkpointing=True,
    per_device_train_batch_size=2,
    learning_rate=2e-4,
    weight_decay=0.1,
    adam_beta2=0.99,
    fp16=True,
    dataloader_num_workers=2,
    logging_steps=2,
    report_to="none",
    push_to_hub=True,
    hub_model_id="music-gen-indian-songs",
)


# Initialize MusicgenTrainer; the processor is handed over through the
# `tokenizer` argument.
trainer = MusicgenTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
    tokenizer=processor,
)

train_result = trainer.train()

# Persist the trained model and the trainer state.
trainer.save_model()
trainer.save_state()


# Extra metadata passed to push_to_hub.
# NOTE(review): the "tiny-punk" tag does not obviously match the
# "yav1327/indian_songs" dataset — confirm it is intended.
kwargs = {
    "finetuned_from": "facebook/musicgen-melody",
    "tasks": "text-to-audio",
    "tags": ["text-to-audio", "tiny-punk"],
    "dataset": "yav1327/indian_songs",
}

trainer.push_to_hub(**kwargs)


In [None]:
from peft import PeftConfig, PeftModel
from transformers import AutoModelForTextToWaveform, AutoProcessor
import torch

# Use the first GPU when one is visible, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.device_count()>0 else "cpu")

# NOTE(review): this repository name differs from the name the training cell
# pushes to ("music-gen-indian-songs") — confirm which adapter should be loaded.
repo_id = "yav1327/musicgen-melody-lora-punk-colab"

# Load the base model named in the adapter config in float16, then attach the
# adapter weights from `repo_id`.
config = PeftConfig.from_pretrained(repo_id)
model = AutoModelForTextToWaveform.from_pretrained(config.base_model_name_or_path, torch_dtype=torch.float16)
model = PeftModel.from_pretrained(model, repo_id).to(device)

processor = AutoProcessor.from_pretrained(repo_id)

# Same prompt and generation settings as the baseline generation cell.
inputs = processor(
    text=["Generate bollywood sad song"],
    padding=True,
    return_tensors="pt",
).to(device)
audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256)

In [None]:
from IPython.display import Audio

# Play at 32 kHz, the rate the baseline generation cell uses for the same
# model's output; rate=44100 would play the clip too fast and too high.
Audio(audio_values.cpu().numpy().squeeze(), rate=32000)