In [None]:
import os
import sys
sys.path.insert(0, '../')

import torch
from omegaconf import OmegaConf

from audiocraft.models.loaders import load_lm_model_ckpt, _delete_param, load_compression_model
from audiocraft.models.musicgen import MusicGen
from audiocraft.models.musicgen import MusicGen

In [None]:
# Base pretrained model id. It supplies the LM config, the compression model and the
# description-projection weights used below.
checkpoint_def = 'facebook/musicgen-small'
# Fine-tuned checkpoint. The default absolute path is machine-specific; set the
# MUSICGEN_TRAINED_CKPT environment variable to point elsewhere without editing the notebook.
checkpoint_trained = os.environ.get(
    'MUSICGEN_TRAINED_CKPT',
    '/home/karlos/Documents/workspace/proj/music/trained_models/v0_22_apr26/dea3706f/checkpoint.th',
)

In [None]:
# Use the GPU whenever torch can see at least one; otherwise run on CPU.
device = 'cuda' if torch.cuda.device_count() else 'cpu'

# Loader options: no custom download cache, memory-saver mode off.
cache_dir = None
memory_saver = False

In [None]:
# Fine-tuned checkpoint package plus the experiment config it was trained with.
lm_model_ckpt = load_lm_model_ckpt(checkpoint_trained, cache_dir=cache_dir)
trained_xp_cfg = lm_model_ckpt['xp.cfg']
cfg = OmegaConf.create(trained_xp_cfg)

In [None]:
# Base (pretrained) checkpoint package plus its config; the LM is built from this config.
lm_model_ckpt_def = load_lm_model_ckpt(checkpoint_def, cache_dir=cache_dir)
base_xp_cfg = lm_model_ckpt_def['xp.cfg']
cfg_def = OmegaConf.create(base_xp_cfg)

## Load LM Model

In [None]:
# The LM is built from `cfg_def`, so the runtime device/dtype must be set there.
# Use the local `device`, not whatever device the training run recorded, and only use
# fp16 on GPU. `cfg` is attached to the model later as `lm_model.cfg`, so keep it
# in sync as well.
runtime_dtype = 'float32' if device == 'cpu' else 'float16'
for run_cfg in (cfg_def, cfg):
    run_cfg.device = device
    run_cfg.dtype = runtime_dtype
OmegaConf.update(cfg_def, "memory_saver.enable", memory_saver)
# Remove these conditioner options from the base config before building the model
# (presumably training-time settings — confirm).
_delete_param(cfg_def, 'conditioners.self_wav.chroma_stem.cache_path')
_delete_param(cfg_def, 'conditioners.args.merge_text_conditions_p')
_delete_param(cfg_def, 'conditioners.args.drop_desc_p')

In [None]:
from audiocraft.models.builders import get_lm_model

In [None]:
# Build the LM architecture from the base checkpoint's config; the weights are loaded further down.
lm_model = get_lm_model(cfg_def)

In [None]:
# State-dict keys of the description conditioner's output projection.
description_proj_prefix = 'condition_provider.conditioners.description.output_proj'
condition_weight = f'{description_proj_prefix}.weight'
condition_bias = f'{description_proj_prefix}.bias'

In [None]:
# Take the description projection weight and bias from the base checkpoint.
# The fine-tuned package nests its state dict under 'model'; the base package
# stores it directly under 'best_state'.
trained_state = lm_model_ckpt['best_state']['model']
base_state = lm_model_ckpt_def['best_state']
for param_name in (condition_weight, condition_bias):
    trained_state[param_name] = base_state[param_name]

In [None]:
# Load the patched fine-tuned weights, switch to inference mode and attach the
# fine-tuned run's config to the model.
trained_weights = lm_model_ckpt['best_state']['model']
lm_model.load_state_dict(trained_weights)
lm_model.train(False)
lm_model.cfg = cfg

## Compression Model

In [None]:
# Compression model from the base checkpoint on the runtime device; its `decode` turns generated tokens back into audio.
compression_model = load_compression_model(checkpoint_def, device=device)

In [None]:
# If the LM has a 'self_wav' conditioner, enable length matching at eval time and
# turn off its masking.
conditioners = lm_model.condition_provider.conditioners
if 'self_wav' in conditioners:
    self_wav = conditioners['self_wav']
    self_wav.match_len_on_eval = True
    self_wav._use_masking = False

## MusicGen

In [None]:
# Wrap the patched LM and the base compression model in the MusicGen generation API.
musicgen = MusicGen(checkpoint_def, compression_model, lm_model)

In [None]:
# Length of each generated clip, in seconds.
GENERATION_DURATION = 30
musicgen.set_generation_params(duration=GENERATION_DURATION)

## Generation

In [None]:
# Armenian / duduk prompts: one generated waveform per prompt.
# Typos fixed ("herritage" -> "heritage", "violing" -> "violin") so the text conditioner
# sees real words.
music = musicgen.generate([
    'Generate duduk with relaxing attributes that shows Armenian heritage',
    'Duduk',
    'duduk with violin and piano in the background',
    'Armenian',
    'Armenian folk music',
])

In [None]:
from IPython.display import Audio, display

# One audio player per generated sample (channel 0), at the model's own sample rate.
for waveform in music:
    display(Audio(waveform[0].detach().cpu(), rate=musicgen.sample_rate))

In [None]:
# Piano-focused prompts: one generated waveform per prompt.
piano_prompts = [
    'Piano in Armenian style',
    'Relaxing piano',
    'Piano',
    'A music that has been described as classical, contemporary, impressionist, and modern, with piano playing as an essential instrument, creates a relaxing atmosphere where emotions are evoked through various moods.',
    'Classical, contemporary, modern piano',
    'combination of piano, violin, and cello instruments for film soundtrack',
]
music = musicgen.generate(piano_prompts)

In [None]:
from IPython.display import Audio, display

# One audio player per generated sample (channel 0), at the model's own sample rate.
for waveform in music:
    display(Audio(waveform[0].detach().cpu(), rate=musicgen.sample_rate))

In [None]:
# Duduk prompts varying the intended use: one generated waveform per prompt.
duduk_usage_prompts = [
    'Duduk to be used as a film soundtrack',
    'Relaxing duduk',
    'Active duduk for dances',
]
music = musicgen.generate(duduk_usage_prompts)

In [None]:
from IPython.display import Audio, display

# One audio player per generated sample (channel 0), at the model's own sample rate.
for waveform in music:
    display(Audio(waveform[0].detach().cpu(), rate=musicgen.sample_rate))

In [None]:
# Duduk + piano prompts: one generated waveform per prompt.
duduk_piano_prompts = [
    'A music that has elements of classical, contemporary, and traditional Armenian folk music, featuring the duduk as an instrument, accompanied by a piano, evokes a sense of relaxation, meditation, and emotional connection with the audience.',
    'duduk piano for meditation',
    'Armenian folk music with instruments duduk and piano',
]
music = musicgen.generate(duduk_piano_prompts)

In [None]:
from IPython.display import Audio, display

# One audio player per generated sample (channel 0), at the model's own sample rate.
for waveform in music:
    display(Audio(waveform[0].detach().cpu(), rate=musicgen.sample_rate))

### Generate from saved condition tensors (`attributes.pt`)

In [None]:
# Saved conditioning tensors to generate from.
# NOTE: torch.load unpickles the file (arbitrary code execution), so only load a
# trusted attributes.pt. map_location puts tensors on the runtime device, so a file
# saved from a GPU session still loads on a CPU-only machine.
attributes = torch.load('../Test/attributes.pt', map_location=device)
condition_tensors = attributes['condition_tensors']

In [None]:
# Move every saved condition onto the runtime device. Values are either a single
# tensor or a list/tuple of tensors; sequences are normalised to tuples.
for name, value in condition_tensors.items():
    if isinstance(value, torch.Tensor):
        condition_tensors[name] = value.to(device)
    elif isinstance(value, (list, tuple)):
        condition_tensors[name] = tuple(t.to(device) for t in value)

# Sampling settings from the fine-tuned run's config.
generation_params = {
    'use_sampling': cfg.generate.lm.use_sampling,
    'temp': cfg.generate.lm.temp,
    'top_k': cfg.generate.lm.top_k,
    'top_p': cfg.generate.lm.top_p,
}

# Add a leading batch dim to the description condition. The guard keeps this cell
# idempotent: re-running it no longer stacks extra dims.
# Assumes the saved description embedding is 2-D, (T, D) — TODO confirm.
description = condition_tensors['description']
if description[0].dim() == 2:
    condition_tensors['description'] = (description[0].unsqueeze(0), description[1])

In [None]:
# Token frames per second of audio, read from the loaded compression model instead of
# a hard-coded constant, so it stays correct if the base checkpoint changes.
compression_frame_rate = musicgen.frame_rate

In [None]:
# Number of token steps to sample = duration x frame rate. Cast to int because it is
# passed as the step count `max_gen_len`, and duration may be set as a float.
total_gen_len = int(musicgen.duration * compression_frame_rate)
gen_tokens = musicgen.lm.generate(
    None, None, condition_tensors, max_gen_len=total_gen_len,
    num_samples=1, **generation_params)

In [None]:
from IPython.display import Audio

# Decode the sampled tokens back to a waveform and play the first sample (channel 0)
# at the model's own sample rate.
gen_audio = musicgen.compression_model.decode(gen_tokens, None)
Audio(gen_audio[0][0].detach().cpu(), rate=musicgen.sample_rate)