In [1]:
import torch
from omegaconf import OmegaConf

from audiocraft.models.loaders import load_lm_model_ckpt, _delete_param, load_compression_model
from audiocraft.models.musicgen import MusicGen

A matching Triton is not available, some optimizations will not be enabled.
Error caught was: No module named 'triton'


In [2]:
# Reference checkpoint: later cells take the model config (cfg_def), the
# description-conditioner projection weights and the compression model from it.
checkpoint_def = 'facebook/musicgen-small'
# Trained checkpoint: later cells take the LM weights and the generation
# settings (cfg) from it.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
checkpoint_trained = 'C:/tmp/audiocraft_MusicGen/xps/d3d6f8af/checkpoint.th'

In [3]:
# Use the GPU whenever torch reports at least one CUDA device, else the CPU.
device = 'cuda' if torch.cuda.device_count() else 'cpu'

# Options consumed by later cells: `cache_dir` is passed to the checkpoint
# loader, `memory_saver` is written into the reference config.
cache_dir = None
memory_saver = False

In [4]:
# Load the trained checkpoint and wrap its stored experiment config ('xp.cfg')
# as an OmegaConf object; `cfg` later provides the generation settings and is
# attached to the model.
lm_model_ckpt = load_lm_model_ckpt(
    checkpoint_trained,
    cache_dir=cache_dir,
)
trained_xp_cfg = lm_model_ckpt['xp.cfg']
cfg = OmegaConf.create(trained_xp_cfg)

In [5]:
# Load the reference checkpoint the same way; `cfg_def` is the config the LM
# architecture is built from in a later cell.
lm_model_ckpt_def = load_lm_model_ckpt(
    checkpoint_def,
    cache_dir=cache_dir,
)
reference_xp_cfg = lm_model_ckpt_def['xp.cfg']
cfg_def = OmegaConf.create(reference_xp_cfg)

## Load LM Model

In [6]:
# Prepare the configs before building the LM.
#
# The model is built from `cfg_def`, so the runtime device and the matching
# dtype must be written there. Previously the dtype was derived from the
# device stored in the *trained* config and written only to `cfg`, which left
# the built model on whatever device/dtype the reference checkpoint recorded
# and ignored the `device` detected earlier in the notebook.
# float16 is only used off-CPU; CPU inference stays in float32.
lm_dtype = 'float32' if device == 'cpu' else 'float16'
cfg_def.device = device
cfg_def.dtype = lm_dtype
# Keep the trained config (attached to the model later) consistent with the
# settings the model actually runs with.
cfg.device = device
cfg.dtype = lm_dtype

OmegaConf.update(cfg_def, "memory_saver.enable", memory_saver)

# Drop parameters from the reference config before the model is built.
# NOTE(review): presumably training-only options that the builder does not
# accept at inference time — confirm against audiocraft's own loader.
_delete_param(cfg_def, 'conditioners.self_wav.chroma_stem.cache_path')
_delete_param(cfg_def, 'conditioners.args.merge_text_conditions_p')
_delete_param(cfg_def, 'conditioners.args.drop_desc_p')

In [7]:
from audiocraft.models.builders import get_lm_model

In [8]:
# Instantiate the LM from the reference config (`cfg_def`), not from the
# trained run's `cfg`; the trained weights are loaded into it in a later cell.
lm_model = get_lm_model(cfg_def)

In [9]:
# State-dict keys of the description conditioner's output projection. A later
# cell copies these two entries from the reference checkpoint into the trained
# state dict before it is loaded into the model.
condition_weight = 'condition_provider.conditioners.description.output_proj.weight'
condition_bias = 'condition_provider.conditioners.description.output_proj.bias'

In [10]:
# Overwrite (or insert) the description conditioner's output projection in the
# trained state dict with the tensors from the reference checkpoint.
# Note the layouts differ: the trained checkpoint nests the weights under
# best_state['model'], the reference checkpoint stores them in best_state.
trained_state = lm_model_ckpt['best_state']['model']
reference_state = lm_model_ckpt_def['best_state']
for state_key in (condition_weight, condition_bias):
    trained_state[state_key] = reference_state[state_key]

In [11]:
# Load the patched trained weights into the freshly built LM, switch it to
# evaluation mode, and attach the trained run's config for later cells.
trained_weights = lm_model_ckpt['best_state']['model']
lm_model.load_state_dict(trained_weights)
lm_model.eval()
lm_model.cfg = cfg

## Compression Model

In [12]:
# Load the compression model from the reference checkpoint onto `device`.
# NOTE(review): the LM is built separately from `cfg_def`; confirm both models
# end up on the same device.
compression_model = load_compression_model(checkpoint_def, device=device)



In [13]:
# If the LM has a 'self_wav' conditioner, set its two evaluation-time flags.
conditioners = lm_model.condition_provider.conditioners
if 'self_wav' in conditioners:
    self_wav_conditioner = conditioners['self_wav']
    self_wav_conditioner.match_len_on_eval = True
    self_wav_conditioner._use_masking = False

## MusicGen

In [14]:
# Assemble the MusicGen wrapper from the reference checkpoint name, the
# compression model, and the LM that now carries the trained weights.
musicgen = MusicGen(checkpoint_def, compression_model, lm_model)

In [15]:
# Set the generation duration to 30 (presumably seconds — a later cell
# multiplies `musicgen.duration` by the frame rate to get the token count).
musicgen.set_generation_params(duration=30)

## Generation

In [None]:
# Text prompt used to condition the generation; one prompt -> one sample.
prompt = 'A music that has been deeply rooted in the heart and soul of the people of Armenia, featuring traditional instruments like duduk and showcasing various moods such as dramatic, relaxing, and melancholic, creating an emotional connection with its listeners.'
music = musicgen.generate([prompt])

In [None]:
from IPython.display import Audio

# Play the first channel of the first generated sample. The playback rate is
# read from the model rather than hard-coded, so it stays correct if a
# different checkpoint / compression model is used.
Audio(music[0][0].detach().cpu(), rate=musicgen.sample_rate)

### Generation from precomputed condition tensors (`attributes.pt`)

In [None]:
# Load the saved conditioning attributes. `map_location=device` lets a file
# that was saved from a CUDA session be loaded on a CPU-only machine (and
# places the tensors on the runtime device directly).
# SECURITY: torch.load unpickles the file — only load files from a trusted source.
attributes = torch.load('../Test/attributes.pt', map_location=device)
condition_tensors = attributes['condition_tensors']

In [None]:
# Move every conditioning entry onto the runtime device (previously hard-coded
# to 'cuda', which fails on CPU-only hosts). Plain tensors are moved directly;
# lists/tuples are rebuilt as tuples of moved elements; any other value is
# left untouched. The dict is updated in place, so iterate over a snapshot.
for name, value in list(condition_tensors.items()):
    if isinstance(value, torch.Tensor):
        condition_tensors[name] = value.to(device)
    elif isinstance(value, (list, tuple)):
        condition_tensors[name] = tuple(item.to(device) for item in value)

# Sampling settings taken from the trained run's config.
generation_params = {
    'use_sampling': cfg.generate.lm.use_sampling,
    'temp': cfg.generate.lm.temp,
    'top_k': cfg.generate.lm.top_k,
    'top_p': cfg.generate.lm.top_p,
}

# Add a leading dimension to the first element of the 'description' entry
# (presumably a batch dimension — TODO confirm against how attributes.pt was
# written); the second element is kept unchanged.
# NOTE: not idempotent — re-running this cell without reloading
# `condition_tensors` adds another leading dimension.
description = condition_tensors['description']
condition_tensors['description'] = (description[0].unsqueeze(0), description[1])

In [None]:
# Token frames per second, read from the model instead of the hard-coded 50 so
# it stays correct for a different compression model.
compression_frame_rate = musicgen.frame_rate

In [None]:
# Number of token frames to generate. Cast to int because the product becomes
# a float as soon as the duration (or frame rate) is not an integer, and a
# generation length must be a whole number of steps.
total_gen_len = int(musicgen.duration * compression_frame_rate)
# Generate tokens directly from the precomputed condition tensors (no prompt,
# no conditioning attributes).
# NOTE(review): this relies on an `lm.generate` signature that accepts the
# condition tensors as third positional argument — verify against the
# installed audiocraft version.
gen_tokens = musicgen.lm.generate(
    None, None, condition_tensors, max_gen_len=total_gen_len,
    num_samples=1, **generation_params)

In [None]:
from IPython.display import Audio

# Decode the generated tokens back to a waveform. Inference only, so autograd
# tracking is disabled to avoid keeping the graph in memory.
with torch.no_grad():
    gen_audio = musicgen.compression_model.decode(gen_tokens, None)

# Play the first channel of the first sample at the model's own sample rate
# (instead of a hard-coded 32000).
Audio(gen_audio[0][0].detach().cpu(), rate=musicgen.sample_rate)