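# Music Generation

Generate music from a text prompt using two pieces:
- a fine-tuned MusicGen language model, loaded from a local training checkpoint;
- the pretrained compression model of `facebook/musicgen-small`.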

In [1]:
import torch
from omegaconf import OmegaConf

from audiocraft.models.loaders import load_lm_model_ckpt, _delete_param, load_compression_model
from audiocraft.models.musicgen import MusicGen
from IPython.display import Audio, display

A matching Triton is not available, some optimizations will not be enabled.
Error caught was: No module named 'triton'


In [2]:
import warnings
warnings.filterwarnings('ignore')

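### Get the pretrained models

Set `checkpoint_trained` to your own fine-tuned training checkpoint. `checkpoint_def` is the base MusicGen model; its compression model is reused.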

In [3]:
# Checkpoint locations. Only `checkpoint_trained` needs editing: point it at your
# fine-tuned training checkpoint. `checkpoint_def` names the base model whose
# compression model is loaded further down.
checkpoint_trained, checkpoint_def = '../XP/checkpoint.th', 'facebook/musicgen-small'

In [4]:
# Use the GPU whenever at least one CUDA device is visible; otherwise fall back to CPU.
device = 'cuda' if torch.cuda.device_count() else 'cpu'

# No custom download cache; None presumably means the loaders' default location.
cache_dir = None

#### Loading the configuration

In [5]:
lm_model_ckpt = load_lm_model_ckpt(
    checkpoint_trained,
    cache_dir=cache_dir,
)
# The checkpoint carries the experiment config under the 'xp.cfg' key;
# rebuild an OmegaConf config from it so fields can be edited below.
cfg = OmegaConf.create(lm_model_ckpt['xp.cfg'])

In [6]:
# `cfg` comes from the training run, so `cfg.device` records the training machine's
# device, not this kernel's. Override it with the device detected in this session.
# Otherwise a GPU-trained checkpoint opened on a CPU-only machine would be built
# for 'cuda' in float16.
# NOTE(review): get_lm_model(cfg) below presumably reads cfg.device / cfg.dtype — confirm.
cfg.device = str(device)
# Half precision is poorly supported on CPU, so use float32 there.
cfg.dtype = 'float32' if cfg.device == 'cpu' else 'float16'
cfg.autocast = False

# Strip training-only settings, as audiocraft's own `load_lm_model` does.
# The cache path in particular refers to the training machine.
# NOTE(review): `_delete_param` is assumed to be a no-op for absent keys — verify
# against the installed audiocraft version.
_delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
_delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
_delete_param(cfg, 'conditioners.args.drop_desc_p')

#### Get the language model

In [7]:
from audiocraft.models.builders import get_lm_model

In [8]:
# Build the LM architecture described by `cfg`; the fine-tuned weights are loaded in the next cell.
# NOTE(review): presumably places the model on cfg.device with cfg.dtype — confirm in audiocraft builders.
lm_model = get_lm_model(cfg)

In [9]:
# Restore the fine-tuned weights, which this training checkpoint keeps under best_state['model'].
lm_model.load_state_dict(lm_model_ckpt['best_state']['model'])
# Switch to inference mode (train(False) is what .eval() does).
lm_model.train(False)
# Attach the inference-adjusted config to the model.
# NOTE(review): presumably read later by the MusicGen wrapper — confirm.
lm_model.cfg = cfg

#### Get the compression model from the default musicgen-small model

In [10]:
# Reuse the pretrained compression model from the base checkpoint on the detected device.
# Pass the same cache_dir as the LM checkpoint load, so changing it in the
# configuration cell affects both downloads consistently.
compression_model = load_compression_model(checkpoint_def, device=device, cache_dir=cache_dir)
compression_model.eval();

In [11]:
# Only some checkpoints have a 'self_wav' conditioner (presumably audio-based);
# when present, adjust its evaluation-time behaviour.
# NOTE(review): the meaning of these two flags is defined by audiocraft's conditioner class — confirm there.
if 'self_wav' in lm_model.condition_provider.conditioners:
    self_wav_conditioner = lm_model.condition_provider.conditioners['self_wav']
    self_wav_conditioner.match_len_on_eval = True
    self_wav_conditioner._use_masking = False

#### Initialize the MusicGen model

In [12]:
# Wrap the fine-tuned LM and the base compression model in the MusicGen inference API.
# NOTE(review): the first argument is presumably just the model's name/label — confirm.
musicgen = MusicGen(checkpoint_def, compression_model, lm_model)

#### Set the generation length to 15 seconds

In [13]:
# Length of each generated clip, in seconds (15 s, per the heading above).
musicgen.set_generation_params(duration=15)

### Generation

In [14]:
# NOTE(review): generation is presumably sampling-based, so results differ between runs
# unless a seed is set (e.g. torch.manual_seed) — confirm if reproducibility matters.
generation = musicgen.generate(["A piano romantic play"])   # List of text prompts describing the music to generate

### Play the generated audio

In [15]:
from IPython.display import Audio

In [16]:
# Play each generated clip at the model's own sample rate rather than a hard-coded 32 kHz.
# Keep clips separate: flattening the whole output with .view(-1) would splice every
# prompt (and every channel) into a single stream.
# Assumes `generation` is (batch, channels, samples) — TODO confirm. IPython's Audio
# interprets a 2-D array as (channels, samples).
for wav in generation:
    display(Audio(wav.cpu().numpy(), rate=musicgen.sample_rate))