# Loading libraries

In [1]:
import torch
from omegaconf import OmegaConf

from audiocraft.models.loaders import load_lm_model_ckpt, _delete_param, load_compression_model
from audiocraft.models.musicgen import MusicGen

from audiocraft.modules.conditioners import ConditioningAttributes, ClassifierFreeGuidanceDropout

A matching Triton is not available, some optimizations will not be enabled.
Error caught was: No module named 'triton'


## Defining the configs

In [2]:
import os

# Pretrained MusicGen id: its config builds the LM (cfg_def), its description
# projection is copied into the fine-tuned weights, and its compression model
# is reused for decoding.
checkpoint_def = 'facebook/musicgen-small'
# Fine-tuned solver checkpoint. Set MUSICGEN_TRAINED_CKPT to point elsewhere
# instead of editing this absolute local path.
checkpoint_trained = os.environ.get(
    'MUSICGEN_TRAINED_CKPT',
    'C:/tmp/audiocraft_MusicGen/xps/8e9a546c/checkpoint.th',
)

In [3]:
# Use the GPU whenever PyTorch reports at least one CUDA device.
device = 'cuda' if torch.cuda.device_count() else 'cpu'

In [4]:
# Loader options: `cache_dir` is forwarded to load_lm_model_ckpt (None = no
# explicit cache directory); `memory_saver` is written into the default
# config's `memory_saver.enable` before the LM is built.
cache_dir, memory_saver = None, False

In [5]:
# Fine-tuned checkpoint; its experiment config becomes `cfg` and is later
# attached to the LM as `lm_model.cfg`.
lm_model_ckpt = load_lm_model_ckpt(
    checkpoint_trained,
    cache_dir=cache_dir,
)
cfg = OmegaConf.create(lm_model_ckpt['xp.cfg'])

In [6]:
# Pretrained checkpoint; its experiment config (`cfg_def`) is the one used
# to instantiate the LM architecture.
lm_model_ckpt_def = load_lm_model_ckpt(
    checkpoint_def,
    cache_dir=cache_dir,
)
cfg_def = OmegaConf.create(lm_model_ckpt_def['xp.cfg'])

## Loading the LM model

In [7]:
from audiocraft.models.builders import get_lm_model

In [8]:
# Pin the configs to the device detected above and pick the matching dtype
# (float32 on CPU, float16 otherwise), mirroring what audiocraft's
# `load_lm_model` does before calling `get_lm_model`.
# `cfg_def` must be updated because it is the config `get_lm_model` builds
# from below; `cfg` is updated too because it is attached as `lm_model.cfg`.
lm_dtype = 'float32' if device == 'cpu' else 'float16'
cfg.device = cfg_def.device = str(device)
cfg.dtype = cfg_def.dtype = lm_dtype

In [9]:
# Apply the memory-saver flag to the default config, then strip the keys that
# audiocraft's own LM loader also removes before building the model.
OmegaConf.update(cfg_def, "memory_saver.enable", memory_saver)
for _stale_key in (
    'conditioners.self_wav.chroma_stem.cache_path',
    'conditioners.args.merge_text_conditions_p',
    'conditioners.args.drop_desc_p',
):
    _delete_param(cfg_def, _stale_key)

In [10]:
# Instantiate the LM architecture from the default (pretrained) config; the
# fine-tuned weights are loaded into it in a later cell.
lm_model = get_lm_model(cfg_def)

In [11]:
# State-dict keys of the text ("description") conditioner's output projection.
_output_proj_prefix = 'condition_provider.conditioners.description.output_proj'
condition_weight = f'{_output_proj_prefix}.weight'
condition_bias = f'{_output_proj_prefix}.bias'

In [12]:
# Replace the fine-tuned description projection (weight and bias) with the
# pretrained checkpoint's values. The trained solver checkpoint nests its
# weights under best_state['model']; the pretrained one stores them directly
# under best_state.
# NOTE(review): the reason for restoring the pretrained projection is not
# stated here; document it.
_trained_weights = lm_model_ckpt['best_state']['model']
_pretrained_weights = lm_model_ckpt_def['best_state']
for _param_key in (condition_weight, condition_bias):
    _trained_weights[_param_key] = _pretrained_weights[_param_key]

In [13]:
# Load the patched fine-tuned weights, switch to inference mode and attach
# the trained experiment config.
_patched_state = lm_model_ckpt['best_state']['model']
lm_model.load_state_dict(_patched_state)
lm_model.eval()
lm_model.cfg = cfg

## Loading the compression model

In [14]:
# Pretrained audio codec used to decode generated tokens. `cache_dir` is
# forwarded for consistency with the LM checkpoint loads above.
compression_model = load_compression_model(checkpoint_def, device=device, cache_dir=cache_dir)



In [15]:
# If the LM has a 'self_wav' conditioner, enable length matching at eval time
# and disable its masking.
_conditioners = lm_model.condition_provider.conditioners
if 'self_wav' in _conditioners:
    _self_wav = _conditioners['self_wav']
    _self_wav.match_len_on_eval = True
    _self_wav._use_masking = False

## Building the MusicGen model

In [16]:
# Combine the pretrained codec and the patched, fine-tuned LM into a MusicGen.
# `checkpoint_def` is presumably used only as the model name here; the weights
# come from `compression_model` and `lm_model`.
musicgen = MusicGen(checkpoint_def, compression_model, lm_model)

In [17]:
# Length, in seconds, of every clip generated below.
GENERATION_DURATION_S = 15
musicgen.set_generation_params(duration=GENERATION_DURATION_S)

## Generation from description (default conditioning path)

In [18]:
# The same Armenian-music prompt, used twice.
_armenian_prompt = (
    'A music that has been deeply rooted in the heart and soul of the people '
    'of Armenia, featuring traditional instruments like duduk and showcasing '
    'various moods such as dramatic, relaxing, and melancholic, creating an '
    'emotional connection with its listeners.'
)
descriptions = [_armenian_prompt, _armenian_prompt]

In [19]:
# One ConditioningAttributes per prompt, with the text under 'description';
# no audio prompt is used.
attributes = [
    ConditioningAttributes(text={'description': text})
    for text in descriptions
]
prompt_tokens = None

In [20]:
# Maximum generation length in frames: frame rate x duration.
total_gen_len = int(musicgen.frame_rate * musicgen.duration)
# No precomputed conditioning: `generate` receives the raw `attributes`.
condition_tensors = None

In [21]:
# Generate tokens from the text attributes under the model's autocast context.
# Generation beyond max_duration is not implemented here, so fail loudly
# instead of leaving `gen_tokens` undefined (or stale from an earlier run).
if musicgen.duration > musicgen.max_duration:
    raise ValueError(
        f"duration ({musicgen.duration}s) exceeds max_duration "
        f"({musicgen.max_duration}s); extended generation is not supported here"
    )
with musicgen.autocast:
    gen_tokens = musicgen.lm.generate(
        prompt_tokens, attributes, condition_tensors,
        callback=None, max_gen_len=total_gen_len, **musicgen.generation_params)

In [22]:
# Decode tokens to a waveform without building an autograd graph (inference
# only; audiocraft's own generate_audio also decodes under no_grad).
with torch.no_grad():
    gen_audio = musicgen.compression_model.decode(gen_tokens, None)

In [23]:
from IPython.display import Audio

# Play gen_audio[0][0] (presumably first sample, first channel) at the
# model's own sample rate rather than a hard-coded 32000.
Audio(gen_audio[0][0].detach().cpu().numpy(), rate=musicgen.sample_rate)

## Generation from description (our method: precomputed condition tensors)

In [24]:
# Same Armenian-music prompt as the default run, used twice.
_armenian_prompt = (
    'A music that has been deeply rooted in the heart and soul of the people '
    'of Armenia, featuring traditional instruments like duduk and showcasing '
    'various moods such as dramatic, relaxing, and melancholic, creating an '
    'emotional connection with its listeners.'
)
descriptions = [_armenian_prompt, _armenian_prompt]

In [25]:
# Wrap each prompt as ConditioningAttributes under the 'description' key;
# no audio prompt.
attributes = [
    ConditioningAttributes(text={'description': text})
    for text in descriptions
]
prompt_tokens = None

In [26]:
# Compute conditioning tensors explicitly (tokenize, then run the condition
# provider) so they can be passed straight to `generate`. These are
# inference-only, so no autograd graph is kept.
with torch.no_grad():
    tokenized = musicgen.lm.condition_provider.tokenize(attributes)
    condition_tensors = musicgen.lm.condition_provider(tokenized)

In [27]:
# Maximum generation length in frames: frame rate x duration.
total_gen_len = int(musicgen.frame_rate * musicgen.duration)

In [28]:
# Generate directly from the precomputed `condition_tensors` (no prompt, no
# attributes). Fail loudly past max_duration instead of leaving `gen_tokens`
# undefined or stale from an earlier run.
# NOTE(review): `condition_tensors` comes from two identical prompts and no
# num_samples is passed; confirm how `generate` splits that batch.
if musicgen.duration > musicgen.max_duration:
    raise ValueError(
        f"duration ({musicgen.duration}s) exceeds max_duration "
        f"({musicgen.max_duration}s); extended generation is not supported here"
    )
with musicgen.autocast:
    gen_tokens = musicgen.lm.generate(
        None, None, condition_tensors,
        callback=None, max_gen_len=total_gen_len, **musicgen.generation_params)

In [29]:
# Decode tokens to a waveform without building an autograd graph.
with torch.no_grad():
    gen_audio = musicgen.compression_model.decode(gen_tokens, None)

In [30]:
from IPython.display import Audio

# Play gen_audio[0][0] (presumably first sample, first channel) at the
# model's own sample rate rather than a hard-coded 32000.
Audio(gen_audio[0][0].detach().cpu().numpy(), rate=musicgen.sample_rate)

## Generation from training data (saved condition tensors)

In [31]:
# Load conditioning tensors saved during training. map_location places every
# tensor on the detected device, so a file saved from a CUDA run also loads on
# a CPU-only machine.
# NOTE: torch.load unpickles arbitrary objects; only load trusted files.
attributes = torch.load('../Test/attributes.pt', map_location=device)
condition_tensors = attributes['condition_tensors']

In [32]:
# Move every saved conditioning entry to the active device (previously
# hard-coded to 'cuda', which fails on CPU-only machines). Tensor entries are
# moved directly; list/tuple entries are moved element-wise and stored as a
# tuple.
for key, value in condition_tensors.items():
    if isinstance(value, torch.Tensor):
        condition_tensors[key] = value.to(device)
    elif isinstance(value, (list, tuple)):
        condition_tensors[key] = tuple(item.to(device) for item in value)

In [35]:
# Add a leading dimension (presumably batch) to the first element of the saved
# 'description' entry; the second element is kept unchanged.
# NOTE: not idempotent; re-running this cell adds another dimension.
_saved_description = condition_tensors['description']
_desc_first, _desc_second = _saved_description[0], _saved_description[1]
condition_tensors['description'] = (_desc_first.unsqueeze(0), _desc_second)

In [40]:
# Generate one sample from the saved training-time conditioning. It runs under
# the model's autocast context and duration guard, like the other generation
# cells.
total_gen_len = int(musicgen.duration * musicgen.frame_rate)
if musicgen.duration > musicgen.max_duration:
    raise ValueError(
        f"duration ({musicgen.duration}s) exceeds max_duration "
        f"({musicgen.max_duration}s); extended generation is not supported here"
    )
with musicgen.autocast:
    gen_tokens = musicgen.lm.generate(
        None, None, condition_tensors, max_gen_len=total_gen_len,
        num_samples=1, **musicgen.generation_params)

In [41]:
# Decode tokens to a waveform without building an autograd graph.
with torch.no_grad():
    gen_audio = musicgen.compression_model.decode(gen_tokens, None)

In [42]:
from IPython.display import Audio

# Play gen_audio[0][0] (presumably first sample, first channel) at the
# model's own sample rate rather than a hard-coded 32000.
Audio(gen_audio[0][0].detach().cpu().numpy(), rate=musicgen.sample_rate)

## Decode audio tokens generated during training

In [56]:
# Load tokens generated during training directly onto the detected device, so
# a file saved from a CUDA run also loads on a CPU-only machine.
# NOTE: torch.load unpickles arbitrary objects; only load trusted files.
gen_tokens = torch.load('./read_this/generated4.pt', map_location=device)

In [57]:
# The saved tokens presumably lack a batch dimension; unsqueeze(0) adds one
# before decoding. Decoding is inference-only, so no autograd graph is kept.
with torch.no_grad():
    gen_audio = musicgen.compression_model.decode(gen_tokens.to(device).unsqueeze(0), None)

In [58]:
from IPython.display import Audio

# Play gen_audio[0][0] (presumably first sample, first channel) at the
# model's own sample rate rather than a hard-coded 32000.
Audio(gen_audio[0][0].detach().cpu().numpy(), rate=musicgen.sample_rate)