In [1]:
# Reload imported modules automatically before executing each cell, so edits to
# library source files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

# MusicGen-RepEng
Welcome to the MusicGen-RepEng demo Jupyter notebook. Here you will find a series of self-contained examples showing how to use MusicGen with Representation Engineering.

First, we initialize MusicGen. You can choose a model from the following selection:
1. `facebook/musicgen-small` - 300M transformer decoder.
2. `facebook/musicgen-medium` - 1.5B transformer decoder.
3. `facebook/musicgen-melody` - 1.5B transformer decoder also supporting melody conditioning.
4. `facebook/musicgen-large` - 3.3B transformer decoder.

We will use the `facebook/musicgen-small` variant for the purpose of this demonstration.

In [79]:
from audiocraft.models import MusicGen, MultiBandDiffusion

# When True, also load the MultiBandDiffusion decoder (`mbd`) so generated tokens can be re-decoded with it.
USE_DIFFUSION_DECODER = False

# The small checkpoint keeps the demo fast; `medium` or `large` give better results.
model = MusicGen.get_pretrained('facebook/musicgen-small', device="cuda")

if USE_DIFFUSION_DECODER:
    mbd = MultiBandDiffusion.get_mbd_musicgen()



Next, let us configure the generation parameters. Specifically, you can control the following:
* `use_sampling` (bool, optional): use sampling if True, else do argmax decoding. Defaults to True.
* `top_k` (int, optional): top_k used for sampling. Defaults to 250.
* `top_p` (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
* `temperature` (float, optional): softmax temperature parameter. Defaults to 1.0.
* `duration` (float, optional): duration of the generated waveform. Defaults to 30.0.
* `cfg_coef` (float, optional): coefficient used for classifier free guidance. Defaults to 3.0.

Any parameter you do not set keeps its default value.

In [3]:
# Top-k sampling (k=250), 20-second clips; all other parameters keep their defaults.
model.set_generation_params(use_sampling=True, top_k=250, duration=20)

Next, we can go ahead and start generating music using one of the following modes:
* Unconditional samples using `model.generate_unconditional`
* Music continuation using `model.generate_continuation`
* Text-conditional samples using `model.generate`
* Melody-conditional samples using `model.generate_with_chroma`

### Get Representations (Hidden States)

In [97]:
import torchaudio

# Instrumental stem of the reference track (presumably produced beforehand with demucs, per the path).
music_path = '/home/sake/demucs/separated/mdx_extra/Ditto-2-NewJeans/no_vocals.mp3'
music, sr = torchaudio.load(music_path)

In [98]:
# Inspect the loaded waveform; presumably (channels, samples) as returned by torchaudio.load — confirm.
music.shape

torch.Size([2, 8182656])

In [99]:
# Keep the 30 s – 50 s window of the track (sample indices scale with `sr`).
# NOTE(review): the saved filenames below say "30-60", but this slice ends at 50 s — confirm which is intended.
input_music = music[:, 30*sr:50*sr]

In [101]:
# Extract hidden states for the clip; the third positional argument is passed as None
# (presumably "no text description" — confirm against get_hidden_states).
rep = model.get_hidden_states(input_music, sr, None, progress=True)

    30 /   1503

  1503 /   1503

In [16]:
# Number of entries returned by get_hidden_states.
len(rep)

1003

In [26]:
import torch

In [102]:
# Stack the returned entries along dim 1, then keep entries 500..999 (500 steps) along that dim.
rep_vec = torch.stack(rep, dim=1)[:,500:1000]

In [103]:
# Sanity-check the shape of the stacked, sliced representation.
rep_vec.shape

torch.Size([1, 500, 24, 1024])

In [105]:
# Persist the non-averaged representation of the reference clip.
rep_vec_path = "/home/sake/Ditto-2-NewJeans_MusicGenRepEng_Dataset_hidden_states_30-60_mid10_non_avg_smallmodel.pt"
torch.save(rep_vec, rep_vec_path)

In [None]:
from pathlib import Path
from tqdm import tqdm

# For every separated track, extract hidden states for the same 30 s – 50 s window used above,
# average steps 500..999, and save one `.pt` file into a mirrored
# `..._hidden_states_30-60_mid10...` directory tree.
for path in tqdm(Path('/home/sake/MusicGenRepEng_Dataset_separated').rglob('*.mp3')):
    print("Representing: ", path)
    # Mirror the input tree into the hidden-state directory and swap the extension to .pt.
    # (Previously the file was saved with `path.with_suffix('.pt')`, i.e. next to the source mp3,
    # and the computed out_path was never used.)
    out_path = Path(
        str(path).replace('MusicGenRepEng_Dataset', 'MusicGenRepEng_Dataset_hidden_states_30-60_mid10')
    ).with_suffix('.pt')
    out_path.parent.mkdir(parents=True, exist_ok=True)
    music, sr = torchaudio.load(str(path))
    input_music = music[:, 30*sr:50*sr]
    # Same extraction call as the single-track cell above (which ran successfully).
    rep = model.get_hidden_states(
        input_music,
        sr, None,
        progress=True)
    # Average over the kept steps (dim 1) to get one vector per layer.
    rep_vec = torch.stack(rep, dim=1)[:, 500:1000].mean(1)
    torch.save(rep_vec, out_path)

### Get Control Vector

In [93]:
import torch
from tqdm import tqdm
from pathlib import Path

In [126]:
# Load the non-averaged hidden states of every dataset track and stack them along a new track axis.
rep_dir = Path('/home/sake/MusicGenRepEng_Dataset_hidden_states_30-60_mid10_non_avg_smallmodel')
reps = torch.stack([torch.load(pt_path).cpu() for pt_path in tqdm(rep_dir.rglob('*.pt'))])
# Drop the singleton dim 1, then move layers to the front: (layers, tracks, timesteps, hidden_states).
reps = reps.squeeze(1).permute(2, 0, 1, 3)
# Keep only the final timestep (not the final layer) -> (layers, tracks, hidden_states).
reps = reps[:, :, -1]
reps.shape

9it [00:00, 89.32it/s]

110it [00:02, 49.98it/s]


torch.Size([24, 110, 1024])

In [8]:
# Time-averaged hidden states of the target track, reordered to (layers, batch, hidden_states).
target = torch.load("/home/sake/MusicGenRepEng_Dataset_hidden_states_30-60_mid10/Ditto-2-NewJeans.pt").permute(1, 0, 2).cpu()
target.shape

torch.Size([24, 1, 1024])

In [131]:
# Use the in-memory, non-averaged target instead: reorder to (layers, batch, timesteps, hidden_states)
# and keep only the final timestep.
target = rep_vec.permute(2, 0, 1, 3)[:, :, -1].cpu()
target.shape

torch.Size([24, 1, 1024])

In [128]:
# Positive-minus-negative differences: the target broadcast against every dataset track.
diffs = torch.sub(target.cpu(), reps.cpu())
diffs.shape

torch.Size([24, 110, 1024])

In [10]:
import numpy as np

In [11]:
from sklearn.decomposition import PCA

In [12]:
def project_onto_direction(H, direction):
    """Project each row of H onto a direction vector.

    Args:
        H: matrix of shape (n, d); in this notebook a CPU torch tensor or numpy array.
        direction: vector of shape (d,).

    Returns:
        Scalar projection of every row of H: ``(H @ direction) / ||direction||``.

    Raises:
        ValueError: if ``direction`` has a zero, infinite or NaN norm.
    """
    mag = np.linalg.norm(direction)
    # The previous `not isinf` check let a zero norm (division by zero) and NaN through.
    if not np.isfinite(mag) or mag == 0:
        raise ValueError(f"direction must have a finite, non-zero norm (got {mag})")
    return (H @ direction) / mag

In [129]:
# Shape of the first layer's slice of `reps`.
reps[0].shape

torch.Size([110, 1024])

In [115]:
# Merge dims 1 and 2 (e.g. tracks x timesteps for 4-D inputs) so each timestep becomes its own row.
f_diffs, f_target, f_reps = (tensor.flatten(1, 2).cpu() for tensor in (diffs, target, reps))
f_diffs.shape, f_target.shape, f_reps.shape

(torch.Size([24, 55000, 1024]),
 torch.Size([24, 500, 1024]),
 torch.Size([24, 55000, 1024]))

In [146]:
# For each layer, take the first principal component of the (target - other) differences
# as the control direction, then orient its sign so the target projects above most other tracks.
directions = {}
for layer in tqdm(range(diffs.shape[0])):
    # Centre along the track axis, then flatten to 2-D rows for PCA
    # (equivalent to the previous np.vstack over the tensor, for any rank).
    layer_diffs = diffs[layer].cpu().numpy()
    train = (layer_diffs - layer_diffs.mean(axis=0, keepdims=True)).reshape(-1, layer_diffs.shape[-1])
    pca_model = PCA(n_components=1, whiten=False).fit(train)
    # shape (n_features,)
    directions[layer] = pca_model.components_.astype(np.float32).squeeze(axis=0)

    # PCA's sign is arbitrary: compare the target's projection against every other track.
    projected_hiddens = np.asarray(project_onto_direction(reps[layer], directions[layer]))
    target_projected_hiddens = np.asarray(project_onto_direction(target[layer], directions[layer]))
    target_score = target_projected_hiddens[0]

    # Vectorized replacement for the per-track Python list comprehensions.
    positive_smaller_mean = np.mean(target_score < projected_hiddens)  # target below others
    positive_larger_mean = np.mean(target_score > projected_hiddens)   # target above others

    if positive_smaller_mean > positive_larger_mean:
        directions[layer] *= -1

  8%|▊         | 2/24 [00:00<00:01, 16.42it/s]

100%|██████████| 24/24 [00:01<00:00, 15.16it/s]


In [116]:
# Same procedure as the last-timestep variant, but on the flattened (track, timestep) rows:
# PCA on the differences, then orient each direction so the target projects above most rows.
directions = {}
n_rows = f_reps.shape[1]
n_target_steps = f_target.shape[1]
# Row i of f_reps is compared with target row i % n_target_steps. This assumes f_reps was flattened
# track-major from (tracks, timesteps), matching the shapes shown for f_reps and f_target.
target_index = np.arange(n_rows) % n_target_steps
for layer in tqdm(range(f_diffs.shape[0])):
    # Centre along the row axis; reshape keeps the input 2-D (equivalent to the previous np.vstack).
    layer_diffs = f_diffs[layer].cpu().numpy()
    train = (layer_diffs - layer_diffs.mean(axis=0, keepdims=True)).reshape(-1, layer_diffs.shape[-1])
    pca_model = PCA(n_components=1, whiten=False).fit(train)
    # shape (n_features,)
    directions[layer] = pca_model.components_.astype(np.float32).squeeze(axis=0)

    # PCA's sign is arbitrary: compare projections of matching timesteps.
    projected_hiddens = np.asarray(project_onto_direction(f_reps[layer], directions[layer]))
    target_projected_hiddens = np.asarray(project_onto_direction(f_target[layer], directions[layer]))
    target_scores = target_projected_hiddens[target_index]

    # Vectorized replacement for the 55k-iteration Python list comprehensions.
    positive_smaller_mean = np.mean(target_scores < projected_hiddens)  # target below others
    positive_larger_mean = np.mean(target_scores > projected_hiddens)   # target above others

    if positive_smaller_mean > positive_larger_mean:
        directions[layer] *= -1

  0%|          | 0/24 [00:00<?, ?it/s]

100%|██████████| 24/24 [01:17<00:00,  3.21s/it]


In [133]:
# Each per-layer direction is a 1-D vector.
directions[0].shape

(1024,)

In [134]:
# Mapping layer index -> float32 direction vector.
directions

{0: array([-3.0706025e-05,  2.8330598e-03,  2.3901025e-03, ...,
         1.3627790e-02, -1.1222321e-02, -4.5703858e-02], dtype=float32),
 1: array([ 0.01482962,  0.00067971, -0.00088191, ...,  0.00183522,
        -0.01989052, -0.07816414], dtype=float32),
 2: array([-0.01118825, -0.02277423,  0.02056845, ..., -0.04227544,
         0.03643353, -0.02939701], dtype=float32),
 3: array([ 0.05020184,  0.01184384, -0.00752417, ..., -0.03280373,
         0.02858786,  0.01039631], dtype=float32),
 4: array([ 0.03886262, -0.00087491,  0.01859564, ..., -0.02553524,
         0.03741219, -0.00121752], dtype=float32),
 5: array([-0.05889817, -0.03270496, -0.02083655, ...,  0.02700403,
        -0.02045577,  0.01640975], dtype=float32),
 6: array([-0.00050306,  0.02076481,  0.08304001, ..., -0.02417303,
         0.01543669, -0.03410292], dtype=float32),
 7: array([-0.00135256,  0.02330022, -0.02909057, ..., -0.0064423 ,
        -0.01486332,  0.02940181], dtype=float32),
 8: array([ 0.12973118, -0.072

In [147]:
# Persist the per-layer control directions for later inference.
directions_path = "/home/sake/Ditto-2-NewJeans_MusicGenRepEng_Dataset_hidden_states_30-60_20last_smallmodel_directions.pth"
torch.save(directions, directions_path)

### Inference with Control Vector

In [None]:
from audiocraft.utils.notebook import display_audio

In [119]:
# Top-k sampling (k=250), 30-second clips for control-vector inference.
model.set_generation_params(use_sampling=True, top_k=250, duration=30)

In [148]:
# Generate one sample steered by the per-layer `directions`.
# `[None]` presumably means a single prompt with no text description — confirm against generate_with_control_vectors.
res = model.generate_with_control_vectors([None], directions, progress=True)

     9 /   1503

   838 /   1503

In [None]:
# Play back the steered sample at 32 kHz, the rate used throughout this notebook.
display_audio(res, sample_rate=32000)

### Music Continuation

In [None]:
import math
import torchaudio
import torch
from audiocraft.utils.notebook import display_audio

def get_bip_bip(bip_duration=0.125, frequency=440,
                duration=0.5, sample_rate=32000, device="cuda"):
    """Generate a gated sine tone ("bip bip") at the given frequency.

    Each cycle lasts ``2 * bip_duration`` seconds: silence for the first half, tone for the second.

    Args:
        bip_duration: length in seconds of each silent / tone half-cycle.
        frequency: tone frequency in Hz.
        duration: total signal length in seconds.
        sample_rate: samples per second.
        device: torch device on which the signal is created.

    Returns:
        Float tensor of shape ``(1, int(duration * sample_rate))``.
    """
    # `device` and `frequency` were previously ignored (hard-coded "cuda" and 440 Hz).
    t = torch.arange(
        int(duration * sample_rate), device=device, dtype=torch.float) / sample_rate
    wav = torch.cos(2 * math.pi * frequency * t)[None]
    # Position within the current cycle, in [0, 1); the tone is on for the second half.
    tp = (t % (2 * bip_duration)) / (2 * bip_duration)
    envelope = (tp >= 0.5).float()
    return wav * envelope

In [None]:
# Prompt with a synthetic bip signal to set both the tonality and the BPM of the generated audio.
bip_prompt = get_bip_bip(0.125).expand(2, -1, -1)  # one copy of the prompt per description
bip_descriptions = [
    'Jazz jazz and only jazz',
    'Heartful EDM with beautiful synths and chords',
]
res = model.generate_continuation(bip_prompt, 32000, bip_descriptions, progress=True)
display_audio(res, 32000)

In [None]:
# Any audio file can serve as a prompt; trim long files to a short excerpt first.
prompt_waveform, prompt_sr = torchaudio.load("../assets/bach.mp3")
prompt_duration = 2  # seconds
n_prompt_samples = int(prompt_duration * prompt_sr)
prompt_waveform = prompt_waveform[..., :n_prompt_samples]
output = model.generate_continuation(
    prompt_waveform,
    prompt_sample_rate=prompt_sr,
    progress=True,
    return_tokens=True,
)
display_audio(output[0], sample_rate=32000)
if USE_DIFFUSION_DECODER:
    # Re-decode the same tokens with MultiBandDiffusion.
    out_diffusion = mbd.tokens_to_wav(output[1])
    display_audio(out_diffusion, sample_rate=32000)

### Text-conditional Generation

In [None]:
from audiocraft.utils.notebook import display_audio

# Other prompts worth trying:
#   '80s pop track with bassy drums and synth'
#   '90s rock song with loud guitars and heavy drums'
#   'Progressive rock drum and bass solo'
#   'Punk Rock song with loud drum and power guitar'
#   'Bluesy guitar instrumental with soulful licks and a driving rhythm section'
#   'Jazz Funk song with slap bass and powerful saxophone'
text_prompts = ['drum and bass beat with intense percussions']
output = model.generate(descriptions=text_prompts, progress=True, return_tokens=True)
display_audio(output[0], sample_rate=32000)
if USE_DIFFUSION_DECODER:
    # Re-decode the same tokens with MultiBandDiffusion.
    out_diffusion = mbd.tokens_to_wav(output[1])
    display_audio(out_diffusion, sample_rate=32000)

### Melody-conditional Generation

In [None]:
import torchaudio
from audiocraft.utils.notebook import display_audio

# Load the melody checkpoint under its own name. Reassigning `model` here would silently swap
# the small model used by every earlier cell (hidden state on re-runs).
melody_model = MusicGen.get_pretrained('facebook/musicgen-melody')
melody_model.set_generation_params(duration=8)

# Use a separate sample-rate name so the reference track's `sr` from earlier is not clobbered.
melody_waveform, melody_sr = torchaudio.load("../assets/bach.mp3")
# Add a leading batch dim and repeat it once per description below.
melody_waveform = melody_waveform.unsqueeze(0).repeat(2, 1, 1)
output = melody_model.generate_with_chroma(
    descriptions=[
        '80s pop track with bassy drums and synth',
        '90s rock song with loud guitars and heavy drums',
    ],
    melody_wavs=melody_waveform,
    melody_sample_rate=melody_sr,
    progress=True, return_tokens=True
)
display_audio(output[0], sample_rate=32000)
if USE_DIFFUSION_DECODER:
    out_diffusion = mbd.tokens_to_wav(output[1])
    display_audio(out_diffusion, sample_rate=32000)