## Preparing the notebook

In [None]:
import sys

import torch
from omegaconf import OmegaConf

from audiocraft.models.loaders import load_lm_model_ckpt, _delete_param, load_compression_model
from audiocraft.models.musicgen import MusicGen
from IPython.display import Audio, display


import random
import numpy as np

In [None]:
# Path to the trained checkpoint; its config and LM weights are read from
# this file in the cells below.
checkpoint_trained = './additional_tools/checkpoints/best_state.th'

# Identifier of the MusicGen small checkpoint; used below for the base LM
# config, the compression model and the MusicGen wrapper.
checkpoint_def = 'facebook/musicgen-small'

In [None]:
# Run on the GPU when at least one CUDA device is visible, otherwise on the CPU.
device = 'cuda' if torch.cuda.device_count() else 'cpu'

# No explicit cache directory is handed to the checkpoint loader.
cache_dir = None

# Keep the memory-saver modification switched off.
memory_saver = False

In [None]:
# Load the trained checkpoint package and rebuild the experiment config it
# stores under the 'xp.cfg' key as an OmegaConf object.

lm_model_ckpt = load_lm_model_ckpt(checkpoint_trained, cache_dir=cache_dir)
cfg = OmegaConf.create(lm_model_ckpt['xp.cfg'])

In [None]:
# Load the MusicGen small checkpoint package and rebuild the experiment
# config it stores under the 'xp.cfg' key as an OmegaConf object.

lm_model_ckpt_def = load_lm_model_ckpt(checkpoint_def, cache_dir=cache_dir)
cfg_def = OmegaConf.create(lm_model_ckpt_def['xp.cfg'])

## Load LM Model

In [None]:
# Declare the runtime device/dtype and strip training-only parameters.
#
# The LM is built from `cfg_def` (see `get_lm_model(cfg_def)`), so that is the
# config whose device/dtype decide where and how the model is created; `cfg`
# is only attached to the model afterwards. Both are therefore aligned with
# the `device` detected earlier in the notebook instead of relying on whatever
# device happened to be stored inside the checkpoints.
for _cfg in (cfg, cfg_def):
    _cfg.device = str(device)
    # Half precision is only used off-CPU.
    _cfg.dtype = 'float32' if _cfg.device == 'cpu' else 'float16'
cfg.autocast = False

# Update the memory saver parameter
OmegaConf.update(cfg_def, "memory_saver.enable", memory_saver)

# Remove conditioner parameters that are not wanted when building the model
# for inference.
_delete_param(cfg_def, 'conditioners.self_wav.chroma_stem.cache_path')
_delete_param(cfg_def, 'conditioners.args.merge_text_conditions_p')
_delete_param(cfg_def, 'conditioners.args.drop_desc_p')

In [None]:
from audiocraft.models.builders import get_lm_model

In [None]:
# Build the LM from the MusicGen small config (`cfg_def`); its weights are
# loaded in a later cell.
lm_model = get_lm_model(cfg_def)

In [None]:
# State-dict keys of the description conditioner's output projection
# (linear layer weight and bias).
condition_weight = 'condition_provider.conditioners.description.output_proj.weight'
condition_bias = 'condition_provider.conditioners.description.output_proj.bias'

In [None]:
# Take the description output projection (linear layer, 768 -> 1024) from the
# MusicGen small checkpoint and place it into the trained state dict.
_trained_state = lm_model_ckpt['best_state']['model']
_default_state = lm_model_ckpt_def['best_state']
for _param_name in (condition_weight, condition_bias):
    _trained_state[_param_name] = _default_state[_param_name]

In [None]:
# Load the trained weights (with the output projection filled in above) into
# the LM, switch it to eval mode, and attach the trained checkpoint's config.
# NOTE(review): the model itself was built from `cfg_def`, while `cfg` is what
# gets attached here and later read for the sampling parameters.
lm_model.load_state_dict(lm_model_ckpt['best_state']['model'])
lm_model.eval()
lm_model.cfg = cfg

## Compression model

In [None]:
# Load the EnCodec compression model of the MusicGen small checkpoint onto
# `device` and switch it to eval mode.
# The trailing semicolon keeps the notebook from echoing eval()'s return value.
compression_model = load_compression_model(checkpoint_def, device=device)
compression_model.eval();

In [None]:
# Default MusicGen setup step: when a 'self_wav' conditioner is present,
# enable length matching on eval and turn its masking off.
_conditioners = lm_model.condition_provider.conditioners
if 'self_wav' in _conditioners:
    _self_wav = _conditioners['self_wav']
    _self_wav.match_len_on_eval = True
    _self_wav._use_masking = False

## MusicGen

In [None]:
# Wrap the compression model and the LM loaded above in a MusicGen instance,
# named after the MusicGen small checkpoint.
musicgen = MusicGen(checkpoint_def, compression_model, lm_model)

In [None]:
# Set the duration of each generated clip, in seconds; `generate` below reads
# it back through `musicgen.duration`.
musicgen.set_generation_params(duration=15)

In [None]:
def generate(descriptions, seed=0):
    """Generate one audio clip per text description.

    Uses the notebook-level ``lm_model``, ``musicgen``, ``cfg`` and ``device``.

    Args:
        descriptions: List of text prompts. A single string is accepted and
            treated as a one-element list.
        seed: Seed applied to torch, ``random`` and numpy right before
            sampling, so repeated calls with the same prompts give the same
            result. Defaults to 0.

    Returns:
        The decoded audio as a detached CPU tensor; index it along the first
        axis to get the clip for each description.
    """
    # A bare string would otherwise be measured character by character when
    # computing `num_samples` below.
    if isinstance(descriptions, str):
        descriptions = [descriptions]

    with torch.no_grad():

        # Tokenize and encode the descriptions with the text conditioner
        text_conditioner = lm_model.condition_provider.conditioners['description']
        tokenized_descr = text_conditioner.tokenize(descriptions)
        desc_encoded = text_conditioner(tokenized_descr)

        # Duplicate every encoded tensor along the batch axis. MusicGen does
        # this with null conditions for the second half; since no condition
        # dropout was provided here, the description itself is repeated.
        desc_encoded = tuple(
            torch.cat([tensor, tensor], dim=0).to(device) for tensor in desc_encoded
        )
        condition_tensors = {'description': desc_encoded}

        # Sampling parameters come from the trained checkpoint's config
        generation_params = {
            'use_sampling': cfg.generate.lm.use_sampling,
            'temp': cfg.generate.lm.temp,
            'top_k': cfg.generate.lm.top_k,
            'top_p': cfg.generate.lm.top_p,
        }

        # Seed every random source for reproducible sampling
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)

        # Number of token frames to generate: duration (s) times the
        # compression model's frame rate, taken from the model rather than
        # hard-coded. int() keeps the length integral for fractional durations.
        total_gen_len = int(musicgen.duration * musicgen.frame_rate)

        # Generate the music tokens
        with musicgen.autocast:
            gen_tokens = musicgen.lm.generate(
                None, None, condition_tensors, max_gen_len=total_gen_len,
                num_samples=len(descriptions), **generation_params)

        # Decode the tokens back to audio using EnCodec
        gen_audio = musicgen.compression_model.decode(gen_tokens, None)

        return gen_audio.detach().cpu()

In [None]:
# Text prompts to generate music for; each one is passed to `generate` on its
# own in the loop below.
custom_descriptions = ['Romantic piano that can be used as Armenian pop music instrumental',
                       'Duduk for meditation and relaxing',
                       'Violin and piano romantic music for engagement',
                       'Armenian dance music with instrument mix',
                       'Music similar to Eghishi par',
                       'Arno Babajanian style solo piano']

In [None]:
# Generate and play one clip per prompt, printing the prompt above its player.
for description in custom_descriptions:
    print(description)
    gen_audio = generate([description])
    # Play back at the model's own sample rate instead of a hard-coded 32000,
    # so the audio stays correct if a different checkpoint is loaded.
    display(Audio(gen_audio[0].numpy(), rate=musicgen.sample_rate))