In [None]:
# Automatically reload imported modules before each cell runs, so edits to the
# (presumably locally modified) audiocraft package are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

# MusicGen-RepEng
Welcome to the MusicGen-RepEng demo Jupyter notebook. Here you will find a series of examples of how to use MusicGen with Representation Engineering.

First, we initialize MusicGen. You can choose a model from the following selection:
1. `facebook/musicgen-small` - 300M transformer decoder.
2. `facebook/musicgen-medium` - 1.5B transformer decoder.
3. `facebook/musicgen-melody` - 1.5B transformer decoder also supporting melody conditioning.
4. `facebook/musicgen-large` - 3.3B transformer decoder.

We will use the `facebook/musicgen-medium` variant for the purpose of this demonstration.

In [None]:
from audiocraft.models import MusicGen
from audiocraft.models import MultiBandDiffusion

# When True, the generation cells below also decode the returned tokens with
# MultiBandDiffusion and display that audio next to the default decoder's output.
USE_DIFFUSION_DECODER = False
# Using the medium model (1.5B); `large` may give better results at higher cost.
model = MusicGen.get_pretrained('facebook/musicgen-medium', device="cuda")
if USE_DIFFUSION_DECODER:
    mbd = MultiBandDiffusion.get_mbd_musicgen()

Next, let us configure the generation parameters. Specifically, you can control the following:
* `use_sampling` (bool, optional): use sampling if True, else do argmax decoding. Defaults to True.
* `top_k` (int, optional): top_k used for sampling. Defaults to 250.
* `top_p` (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
* `temperature` (float, optional): softmax temperature parameter. Defaults to 1.0.
* `duration` (float, optional): duration of the generated waveform. Defaults to 30.0.
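* `cfg_coef` (float, optional): coefficient used for classifier free guidance. Defaults to 3.0.
* `two_step_cfg` (bool, optional): if True, performs two separate forward passes for classifier free guidance instead of batching them together. Defaults to False.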

Any parameter not passed to `set_generation_params` is reset to its default value.

In [None]:
# Very short duration (0.02 s); later cells call set_generation_params again
# with the durations they actually need.
model.set_generation_params(
    use_sampling=True,
    top_k=250,
    duration=0.02
)

Next, we can go ahead and start generating music using one of the following modes:
* Unconditional samples using `model.generate_unconditional`
* Music continuation using `model.generate_continuation`
* Text-conditional samples using `model.generate`
* Melody-conditional samples using `model.generate_with_chroma`

In [None]:
import torchaudio

# Reference track used as the audio prompt below (absolute local path — adjust for your machine).
music, sr = torchaudio.load('/home/sake/MusicGenRepEng_Dataset/Rock/Alternative Rock/Nirvana - Smells Like Teen Spirit.mp3')

In [None]:
import torch

In [None]:
# Sanity check: repeating a (channels, samples) slice with repeat(2,1,1) adds a
# leading batch axis and equals stacking two copies along dim 0.
(music[:,:50].repeat(2,1,1) == torch.stack([music[:,:50], music[:,:50]], dim=0)).all()

In [None]:
from audiocraft.utils.notebook import display_audio

model.set_generation_params(
    use_sampling=True,
    top_k=250,
    duration=8,
    two_step_cfg=False
)

# Here we use the first 20 ms of the loaded track (repeated into a batch of 2)
# as the audio prompt, and continue it under two contrasting text descriptions.
res = model.generate_continuation(
    music[:,:int(sr*0.02)].repeat(2,1,1),
    sr, ['rock, energetic', 
            'rock, sleepy'], 
    progress=True)
# 32000 is assumed to equal model.sample_rate — TODO confirm.
display_audio(res, 32000)

In [None]:
res.shape

### Get Text Condition Representations (Hidden States)

In [None]:
import torchaudio

music, sr = torchaudio.load('/home/sake/MusicGenRepEng_Dataset/Rock/Alternative Rock/Nirvana - Smells Like Teen Spirit.mp3')

In [None]:
# Note: two_step_cfg=True here, unlike the generation cells which use False.
model.set_generation_params(
    use_sampling=True,
    top_k=250,
    duration=0.02,
    two_step_cfg=True
)

In [None]:
# Hidden states for a 20 ms prompt (batch of 2) under two contrasting text conditions.
hidden_states = model.get_hidden_states(music[:,:int(sr*0.02)].repeat(2,1,1), sr, ["fast tempo", "slow tempo"])

In [None]:
hidden_states[1][1].shape

In [None]:
# NOTE(review): duplicate of the previous cell; safe to delete.
hidden_states[1][1].shape

In [None]:
# Inspect how a text prompt becomes conditioning attributes (private audiocraft API).
attributes, _ = model._prepare_tokens_and_attributes(["techno, fast beats, happy, hard, joyful, tribal"], None)
attributes, _

In [None]:
# Tokenize the attributes and run the condition provider to obtain the text embeddings.
embeddings = model.lm.condition_provider(model.lm.condition_provider.tokenize(attributes))
embeddings

In [None]:
embeddings['description']

In [None]:
embeddings['description'][0].shape

In [None]:
# NOTE(review): rebinds `attributes` to a different object than the one built above;
# the cells below index it as attributes[0]['text'], so they depend on this cell.
attributes = model.get_hidden_states_text_condition(["techno, fast beats"])

In [None]:
attributes[0]['text']

In [None]:
model.lm.condition_provider.conditioners

In [None]:
model.lm.condition_provider

In [None]:
# NOTE(review): `attributes` now comes from get_hidden_states_text_condition, not from
# _prepare_tokens_and_attributes as in the tokenize call above — TODO confirm this still works.
model.lm.condition_provider(model.lm.condition_provider.tokenize(attributes))

In [None]:
model.lm.condition_provider.conditioners.description(attributes[0]['text']['description'])

In [None]:
model.lm.condition_provider(attributes[0]['text'])

### Get Representations (Hidden States)

In [None]:
import torchaudio

music, sr = torchaudio.load('/home/sake/MusicGenRepEng_Dataset/Rock/Alternative Rock/Nirvana - Smells Like Teen Spirit.mp3')

In [None]:
music.shape

In [None]:
i = 0

In [None]:
# The i-th 20-second window, in sample indices (sr samples per second).
input_music = music[:, 20*i*sr:20*(i+1)*sr]

In [None]:
# Hidden states for that window with no text condition (None).
rep = model.get_hidden_states(
    input_music, 
    sr, None, 
    progress=True)

In [None]:
# Hidden states for ten consecutive 20 s windows (0-200 s) of the track.
reps = []
for i in range(10):
    input_music = music[:, 20*i*sr:20*(i+1)*sr]
    rep = model.get_hidden_states(
        input_music,
        sr, None,
        progress=True)
    reps.append(rep)

In [None]:
len(reps)

In [None]:
import torch

In [None]:
# NOTE(review): uses `rep` (only the last window from the loop above), not `reps`.
# [:,500:1000] slices dim 1 of the stacked result — presumably timesteps, if `rep`
# is a list over timesteps; TODO confirm.
rep_vec = torch.stack(rep, dim=1)[:,500:1000]

In [None]:
rep_vec.shape

In [None]:
# Persist all ten windows; reloaded later as the Nirvana ("smells") reps.
torch.save(reps, "/home/sake/Nirvana - Smells Like Teen Spirit_MusicGenRepEng_Dataset_hidden_states_every20s_10t.pt")

In [None]:
from pathlib import Path
from tqdm import tqdm

import torch
import torchaudio

# Extract hidden states for the 30-50 s window of every separated track, average
# the [500:1000] slice over dim 1, and save one .pt per track under a mirrored
# "..._hidden_states_30-60_mid10" directory tree.
# NOTE(review): the directory name says 30-60 but the slice below is 30-50 s — confirm.
# NOTE(review): other cells call model.get_hidden_states; this one calls
# model.read_representations — confirm both exist on the patched model.
for path in tqdm(Path('/home/sake/MusicGenRepEng_Dataset_separated').rglob('*.mp3')):
    print("Representing: ", path)
    # Mirror the input tree under the hidden-states root and swap .mp3 -> .pt.
    out_path = Path(
        str(path).replace('MusicGenRepEng_Dataset', 'MusicGenRepEng_Dataset_hidden_states_30-60_mid10')
    ).with_suffix('.pt')
    out_path.parent.mkdir(parents=True, exist_ok=True)
    music, sr = torchaudio.load(str(path))
    input_music = music[:, 30*sr:50*sr]  # 30 s to 50 s, in samples
    rep = model.read_representations(
        input_music,
        sr, None,
        progress=True)
    rep_vec = torch.stack(rep, dim=1)[:,500:1000].mean(1)
    # Write to the mirrored output path. Saving to path.with_suffix('.pt') put the
    # file next to the source audio and left the output directory empty.
    torch.save(rep_vec, out_path)

### Get Control Vector - Text Pair

In [None]:
import torch
from tqdm import tqdm
from pathlib import Path
import numpy as np
from sklearn.decomposition import PCA

In [None]:
def project_onto_direction(H, direction):
    """Project the vectors in H onto `direction` and return scalar coordinates.

    Args:
        H: array-like of shape (..., d); the last axis holds the vectors.
        direction: vector of shape (d,). It need not be unit length — the
            result is divided by its Euclidean norm.

    Returns:
        Array of shape (...) with the signed length of each vector's projection.

    Raises:
        AssertionError: if `direction` has a non-finite (inf/NaN) or zero norm,
            in which case the projection is undefined.
    """
    mag = np.linalg.norm(direction)
    # isfinite also rejects NaN (isinf alone did not); a zero norm would divide by zero.
    assert np.isfinite(mag) and mag > 0, f"invalid direction norm: {mag}"
    return (H @ direction) / mag

In [None]:
representation_pairs = []
for path in tqdm(Path('/home/sake/MusicGenRepEng_Dataset_50ms_energetic_sleepy_mediummodel_rock_norm_nob4layer').rglob('*.pt')):
    loaded = torch.load(str(path))[-1]
    representation_pairs.append(loaded)

In [None]:
representation_pairs = torch.cat(representation_pairs, dim=0)
representation_pairs = representation_pairs.permute(1,0,2)
representation_pairs.shape

In [None]:
relative_layer_hiddens = {}

In [None]:
for layer, pair in enumerate(representation_pairs):
        relative_layer_hiddens[layer] = (
            pair[::2] - pair[1::2]
        )

In [None]:
relative_layer_hiddens[0].shape

In [None]:
for i in range(48):
    print(i, (relative_layer_hiddens[i][0]==0).all())

In [None]:
# Fit one control direction per layer: the first principal component of the
# positive-minus-negative differences. Its sign is then chosen so that the positive
# member of each pair projects higher more often than not.
directions = {}
for layer in range(len(relative_layer_hiddens)):
    # assert representation_pairs[layer].shape[0] == 110 * 2

    # fit layer directions
    # (sklearn's PCA also centres its input, so the explicit mean subtraction
    # does not change the fitted component.)
    train = np.vstack(
        relative_layer_hiddens[layer].to("cpu").numpy()
        - relative_layer_hiddens[layer].to("cpu").numpy().mean(axis=0, keepdims=True)
    )
    pca_model = PCA(n_components=1, whiten=False).fit(train)
    # shape (n_features,)
    directions[layer] = pca_model.components_.astype(np.float32).squeeze(axis=0)

    # calculate sign
    projected_hiddens = project_onto_direction(
        representation_pairs[layer].to("cpu").numpy(), directions[layer]
    )

    # order is [positive, negative, positive, negative, ...]
    positive_smaller_mean = np.mean(
        [
            projected_hiddens[i] < projected_hiddens[i + 1]
            for i in range(0, representation_pairs.shape[1], 2)
        ]
    )
    positive_larger_mean = np.mean(
        [
            projected_hiddens[i] > projected_hiddens[i + 1]
            for i in range(0, representation_pairs.shape[1], 2)
        ]
    )

    # Flip so positives project higher in the majority of pairs.
    if positive_smaller_mean > positive_larger_mean:  # type: ignore
        directions[layer] *= -1

In [None]:
len(directions)

In [None]:
# Index 47 assumes 48 layers (the count hard-coded in the sanity check above).
directions[47].shape

### Get Control Vector - Song Pair

In [None]:
ditto_reps = torch.load("/home/sake/Ditto-2-NewJeans_MusicGenRepEng_Dataset_hidden_states_every20s_10t.pt")

In [None]:
dexter_reps = torch.load("/home/sake/Ricardo Villalobos - Dexter [SED008]_MusicGenRepEng_Dataset_hidden_states_every20s_10t.pt")

In [None]:
ditto_reps[0][-1].shape

In [None]:
target = torch.stack([rep[-1] for rep in ditto_reps], dim=0)
target = target.squeeze(1).permute(1, 0, 2).cpu()
target.shape

In [None]:
reps = torch.stack([rep[-1] for rep in dexter_reps], dim=0)
reps = reps.squeeze(1).permute(1, 0, 2).cpu()
reps.shape

In [None]:
reps = torch.stack([rep[-1] for rep in reps], dim=0)
reps = reps.squeeze(1).permute(1, 0, 2).cpu()
reps.shape

In [None]:
import torch
from tqdm import tqdm
from pathlib import Path
import numpy as np
from sklearn.decomposition import PCA

In [None]:
def project_onto_direction(H, direction):
    """Project the vectors in H onto `direction` and return scalar coordinates.

    Args:
        H: array-like of shape (..., d); the last axis holds the vectors.
        direction: vector of shape (d,). It need not be unit length — the
            result is divided by its Euclidean norm.

    Returns:
        Array of shape (...) with the signed length of each vector's projection.

    Raises:
        AssertionError: if `direction` has a non-finite (inf/NaN) or zero norm,
            in which case the projection is undefined.
    """
    mag = np.linalg.norm(direction)
    # isfinite also rejects NaN (isinf alone did not); a zero norm would divide by zero.
    assert np.isfinite(mag) and mag > 0, f"invalid direction norm: {mag}"
    return (H @ direction) / mag

In [None]:
# Get Difference

diffs = target.cpu() - reps.cpu() # target - others (pos - neg)
diffs.shape

In [None]:
# Avg or Last Hidden State

directions = {}
for layer in tqdm(range(diffs.shape[0])):
    # assert diff[layer].shape[0] == len(inputs) * 2

    # fit layer directions
    train = np.vstack(
        diffs[layer]
        # - diffs[layer].mean(axis=0, keepdims=True)
    )
    pca_model = PCA(n_components=1, whiten=False).fit(train)
    # shape (n_features,)
    directions[layer] = pca_model.components_.astype(np.float32).squeeze(axis=0)
    # print(directions[layer].shape)
    # calculate sign
    # projected_hiddens = project_onto_direction(
    #     reps[layer], directions[layer]
    # )
    # # print(projected_hiddens[0])
    # target_projected_hiddens = project_onto_direction(
    #     target[layer], directions[layer]
    # )
    # # print(target_projected_hiddens[0])

    # # order is [positive, negative, positive, negative, ...]
    # positive_smaller_mean = np.mean(
    #     [
    #         target_projected_hiddens[0] < projected_hiddens[i] # target is smaller than others
    #         for i in range(0, reps.shape[1])
    #     ]
    # )
    # positive_larger_mean = np.mean(
    #     [
    #         target_projected_hiddens[0] > projected_hiddens[i] # target is larger than others
    #         for i in range(0, reps.shape[1])
    #     ]
    # )

    # if positive_smaller_mean > positive_larger_mean:  # type: ignore
    #     directions[layer] *= -1

In [None]:
ditto_smells_directions = directions

### Get Control Vector

In [None]:
import torch
from tqdm import tqdm
from pathlib import Path
from sklearn.decomposition import PCA

In [None]:
# Load per-track hidden states from disk and arrange them as
# (layers, batch, timesteps, hidden_states).
reps = []
for path in tqdm(Path('/home/sake/MusicGenRepEng_Dataset_hidden_states_30-60_mid10_non_avg_smallmodel').rglob('*.pt')):
    rep = torch.load(path)
    reps.append(rep.cpu())
reps = torch.stack(reps)
reps = reps.squeeze(1)
# reps = reps.permute(1, 0, 2) # (layers, batch, hidden_states)
reps = reps.permute(2, 0, 1, 3) # (layers, batch, timesteps, hidden_states)
# reps = reps[:,:,-1] # Last hidden_state
reps.shape

In [None]:
# NOTE(review): this `target` is overwritten by the next cell.
target = torch.load("/home/sake/MusicGenRepEng_Dataset_hidden_states_30-60_mid10/Ditto-2-NewJeans.pt")
target = target.permute(1, 0, 2).cpu() # (layers, batch, hidden_states)
target.shape

In [None]:
# NOTE(review): overrides the `target` loaded above with `rep_vec` from an earlier cell.
# The batch-extraction loop rebinds `rep_vec` to a .mean(1)-reduced tensor (one fewer
# dimension), on which this 4-axis permute would fail, so this depends on execution
# order. TODO: make the target source explicit.
target = rep_vec.cpu()
target = target.permute(2, 0, 1, 3) # (layers, batch, timesteps, hidden_states)
# target = target[:,:,-1] # Last hidden_state
target.shape

In [None]:
# Get Difference
# (presumably broadcasts a single target against every other track — TODO confirm batch sizes)

diffs = target.cpu() - reps.cpu() # target - others (pos - neg)
diffs.shape

In [None]:
import numpy as np

In [None]:
from sklearn.decomposition import PCA

In [None]:
def project_onto_direction(H, direction):
    """Project the vectors in H onto `direction` and return scalar coordinates.

    Args:
        H: array-like of shape (..., d); the last axis holds the vectors.
        direction: vector of shape (d,). It need not be unit length — the
            result is divided by its Euclidean norm.

    Returns:
        Array of shape (...) with the signed length of each vector's projection.

    Raises:
        AssertionError: if `direction` has a non-finite (inf/NaN) or zero norm,
            in which case the projection is undefined.
    """
    mag = np.linalg.norm(direction)
    # isfinite also rejects NaN (isinf alone did not); a zero norm would divide by zero.
    assert np.isfinite(mag) and mag > 0, f"invalid direction norm: {mag}"
    return (H @ direction) / mag

In [None]:
reps[0].shape

In [None]:
# Avg or Last Hidden State
# One direction per layer from PCA over the target-minus-others differences.
# diffs[layer] is 3-D here, so np.vstack concatenates its first two axes into rows.
# The sign is then chosen so the target projects higher than the other tracks
# more often than not.

directions = {}
for layer in tqdm(range(diffs.shape[0])):
    # assert diff[layer].shape[0] == len(inputs) * 2

    # fit layer directions
    train = np.vstack(
        diffs[layer]
        - diffs[layer].mean(axis=0, keepdims=True)
    )
    pca_model = PCA(n_components=1, whiten=False).fit(train)
    # shape (n_features,)
    directions[layer] = pca_model.components_.astype(np.float32).squeeze(axis=0)
    # print(directions[layer].shape)
    # calculate sign
    projected_hiddens = project_onto_direction(
        reps[layer], directions[layer]
    )
    # print(projected_hiddens[0])
    target_projected_hiddens = project_onto_direction(
        target[layer], directions[layer]
    )
    # print(target_projected_hiddens[0])

    # Compare the target's projection (batch index 0) against each other track's,
    # element-wise over the remaining axis; np.mean averages all comparisons.
    positive_smaller_mean = np.mean(
        [
            target_projected_hiddens[0] < projected_hiddens[i] # target is smaller than others
            for i in range(0, reps.shape[1])
        ]
    )
    positive_larger_mean = np.mean(
        [
            target_projected_hiddens[0] > projected_hiddens[i] # target is larger than others
            for i in range(0, reps.shape[1])
        ]
    )

    if positive_smaller_mean > positive_larger_mean:  # type: ignore
        directions[layer] *= -1

In [None]:
# Merge the (batch, timesteps) axes: (layers, batch*timesteps, hidden_states).
f_diffs = diffs.flatten(1,2).cpu()
f_target = target.flatten(1,2).cpu()
f_reps = reps.flatten(1,2).cpu()
f_diffs.shape, f_target.shape, f_reps.shape

In [None]:
# Multiple Hidden States
# Same procedure on the flattened tensors. Note: this overwrites `directions`
# from the previous cell.


directions = {}
for layer in tqdm(range(f_diffs.shape[0])):
    # assert diff[layer].shape[0] == len(inputs) * 2

    # fit layer directions
    train = np.vstack(
        f_diffs[layer]
        - f_diffs[layer].mean(axis=0, keepdims=True)
    )
    pca_model = PCA(n_components=1, whiten=False).fit(train)
    # shape (n_features,)
    directions[layer] = pca_model.components_.astype(np.float32).squeeze(axis=0)
    # calculate sign
    projected_hiddens = project_onto_direction(
        f_reps[layer], directions[layer]
    )
    target_projected_hiddens = project_onto_direction(
        f_target[layer], directions[layer]
    )

    # Compare each other-track row with the target row at the same position
    # (i modulo the target's flattened length).
    positive_smaller_mean = np.mean(
        [
            target_projected_hiddens[i%f_target.shape[1]] < projected_hiddens[i] # target is smaller than others
            for i in range(0, f_reps.shape[1])
        ]
    )
    positive_larger_mean = np.mean(
        [
            target_projected_hiddens[i%f_target.shape[1]] > projected_hiddens[i] # target is larger than others
            for i in range(0, f_reps.shape[1])
        ]
    )

    if positive_smaller_mean > positive_larger_mean:  # type: ignore
        directions[layer] *= -1

In [None]:
directions[0].shape

In [None]:
# NOTE(review): displays every layer's full vector; `len(directions)` is a lighter check.
directions

In [None]:
# NOTE(review): duplicate of the first cell above; safe to delete.
directions[0].shape

In [None]:
# Saves whichever `directions` dict was computed last (the "Multiple Hidden States"
# cell when the notebook is run top to bottom).
torch.save(directions, "/home/sake/Ditto-2-NewJeans_MusicGenRepEng_Dataset_hidden_states_30-60_non_avg_smallmodel_directions.pth")

### Inference with Control Vector

In [None]:
from audiocraft.utils.notebook import display_audio

In [None]:
import torchaudio

# Prompt audio for the control-vector continuation cells below.
music, sr = torchaudio.load('/home/sake/MusicGenRepEng_Dataset/Rock/Alternative Rock/Nirvana - Smells Like Teen Spirit.mp3')

In [None]:
import os
import random
import torch
import numpy as np
# From https://gist.github.com/gatheluck/c57e2a40e3122028ceaecc3cb0d152ac
def set_all_seeds(seed):
    """Seed every RNG this notebook relies on so generations are repeatable.

    Seeds Python's `random`, NumPy's global RNG and PyTorch (CPU and CUDA), and
    configures cuDNN for deterministic algorithm selection.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects Python subprocesses started after this call; the running
    # interpreter's hash seed is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # deterministic=True restricts cuDNN to deterministic kernels, but benchmark
    # mode may still select different algorithms between runs; disable it too.
    torch.backends.cudnn.benchmark = False

In [None]:
set_all_seeds(42)

In [None]:
# Settings shared by the control-vector generations below: 10 s per sample,
# top-k sampling, two_step_cfg disabled.
model.set_generation_params(
    use_sampling=True,
    top_k=250,
    duration=10,
    two_step_cfg=False
)

In [None]:
n = 4
# Sweep the control-vector coefficient from +0.15 down to -0.15 with the same seed
# and description, so only the steering strength changes between runs. One loop
# replaces the per-coefficient copies of this cell.
COEFFICIENTS = [0.15, 0.1, 0.07, 0.06, 0.05, 0.04, 0.03, 0.02, 0.01, 0.00,
                -0.01, -0.02, -0.03, -0.04, -0.05, -0.06, -0.07, -0.1, -0.15]
for coefficient in COEFFICIENTS:
    print(f"coefficient = {coefficient:+.2f}")
    # Re-seed before every run so each coefficient starts from identical RNG state.
    set_all_seeds(42)
    res = model.generate_with_control_vectors(descriptions=["rock"]*n, control_vectors=[directions], coefficients=[coefficient], sustains=[100], ramps=[250],
                                              before_layer=False, progress=True)
    display_audio(res, 32000)

In [None]:
# +0.15
n = 4
res = model.generate_continuation_with_control_vectors(music[:,:int(sr*0.02)].repeat(n,1,1), sr, 
                                                       control_vectors=[directions], coefficients=[0.15], sustains=[100], ramps=[250],
                                                       before_layer=False, descriptions=["rock"]*n, progress=True)
display_audio(res, 32000)

In [None]:
# +0.15
n = 4
res = model.generate_continuation_with_control_vectors(music[:,:int(sr*0.02)].repeat(n,1,1), sr, 
                                                       control_vectors=[directions, directions], coefficients=[0.15, 0.1], sustains=[100, 10], ramps=[250, 50],
                                                       before_layer=False, descriptions=["rock"]*n, progress=True)
display_audio(res, 32000)

### Music Continuation

In [None]:
import math
import torchaudio
import torch
from audiocraft.utils.notebook import display_audio

def get_bip_bip(bip_duration=0.125, frequency=440,
                duration=0.5, sample_rate=32000, device="cuda"):
    """Generates a series of bip bip at the given frequency.

    The signal is a cosine at `frequency` Hz gated by a square envelope with
    period `2 * bip_duration`: silent for the first half of each period and
    sounding for the second half.

    Args:
        bip_duration: length in seconds of each tone burst (and of each gap).
        frequency: tone frequency in Hz.
        duration: total length of the signal in seconds.
        sample_rate: samples per second.
        device: torch device the waveform is created on.

    Returns:
        Float tensor of shape (1, int(duration * sample_rate)).
    """
    # Honour the caller's `device` and `frequency` (previously hard-coded to
    # "cuda" and 440 Hz, so both arguments were silently ignored).
    t = torch.arange(
        int(duration * sample_rate), device=device, dtype=torch.float) / sample_rate
    wav = torch.cos(2 * math.pi * frequency * t)[None]
    tp = (t % (2 * bip_duration)) / (2 * bip_duration)
    envelope = (tp >= 0.5).float()
    return wav * envelope

In [None]:
# Here we use a synthetic signal to prompt both the tonality and the BPM
# of the generated audio.
# expand(2, -1, -1) adds a batch axis of 2 (one per description); get_bip_bip
# builds the signal on its default device, "cuda".
res = model.generate_continuation(
    get_bip_bip(0.125).expand(2, -1, -1), 
    32000, ['Jazz jazz and only jazz', 
            'Heartful EDM with beautiful synths and chords'], 
    progress=True)
display_audio(res, 32000)

In [None]:
# You can also use any audio from a file. Make sure to trim the file if it is too long!
prompt_waveform, prompt_sr = torchaudio.load("../assets/bach.mp3")
prompt_duration = 2
prompt_waveform = prompt_waveform[..., :int(prompt_duration * prompt_sr)]
# return_tokens=True makes `output` a pair: output[0] is displayed as audio and
# output[1] feeds the optional diffusion decoder.
output = model.generate_continuation(prompt_waveform, prompt_sample_rate=prompt_sr, progress=True, return_tokens=True)
display_audio(output[0], sample_rate=32000)
if USE_DIFFUSION_DECODER:
    out_diffusion = mbd.tokens_to_wav(output[1])
    display_audio(out_diffusion, sample_rate=32000)

### Text-conditional Generation

In [None]:
from audiocraft.utils.notebook import display_audio

# Uncomment any of the alternative prompts to generate a larger batch.
# return_tokens=True makes `output` a pair: output[0] is displayed as audio and
# output[1] feeds the optional diffusion decoder.
output = model.generate(
    descriptions=[
        #'80s pop track with bassy drums and synth',
        #'90s rock song with loud guitars and heavy drums',
        #'Progressive rock drum and bass solo',
        #'Punk Rock song with loud drum and power guitar',
        #'Bluesy guitar instrumental with soulful licks and a driving rhythm section',
        #'Jazz Funk song with slap bass and powerful saxophone',
        'drum and bass beat with intense percussions'
    ],
    progress=True, return_tokens=True
)
display_audio(output[0], sample_rate=32000)
if USE_DIFFUSION_DECODER:
    out_diffusion = mbd.tokens_to_wav(output[1])
    display_audio(out_diffusion, sample_rate=32000)

### Melody-conditional Generation

In [None]:
import torchaudio
from audiocraft.utils.notebook import display_audio

# NOTE: rebinds `model` to the melody variant (no explicit device, unlike the setup
# cell); re-run the setup cell to return to the medium model.
model = MusicGen.get_pretrained('facebook/musicgen-melody')
model.set_generation_params(duration=8)

# Use the same Bach recording as the melody for both descriptions (batch of 2).
melody_waveform, sr = torchaudio.load("../assets/bach.mp3")
melody_waveform = melody_waveform.unsqueeze(0).repeat(2, 1, 1)
output = model.generate_with_chroma(
    descriptions=[
        '80s pop track with bassy drums and synth',
        '90s rock song with loud guitars and heavy drums',
    ],
    melody_wavs=melody_waveform,
    melody_sample_rate=sr,
    progress=True, return_tokens=True
)
display_audio(output[0], sample_rate=32000)
if USE_DIFFUSION_DECODER:
    out_diffusion = mbd.tokens_to_wav(output[1])
    display_audio(out_diffusion, sample_rate=32000)