In [None]:
import os

import torch
import torchaudio
from torch import nn
from audiocraft.models import MusicGen

# NOTE(review): presumably a debugging aid so that CUDA errors are reported at the
# call that triggers them rather than at a later, unrelated one — confirm, and
# remove once the notebook is stable.
%env CUDA_LAUNCH_BLOCKING=1

In [None]:
# Load the pretrained 'facebook/musicgen-melody' checkpoint. `model` is read as a
# notebook-level global by get_patterns() (model.lm.pattern_provider) and by the
# encoding loop (model.compression_model).
model = MusicGen.get_pretrained('facebook/musicgen-melody')

In [3]:
# Directory whose entries are listed and loaded as audio by the processing loop.
data_path = "../data/toy"

In [4]:
def preprocess_waveform(filename, device='cuda', target_sample_rate=None):
    """Load one audio file as a mono waveform batch of shape ``[1, 1, samples]``.

    Args:
        filename: Path of the audio file (``str`` or ``os.PathLike``).
        device: Device the returned tensor is moved to.
        target_sample_rate: When given and different from the file's native
            rate, the waveform is resampled to this rate. ``None`` (default)
            returns the audio at its native rate, as before.
            NOTE(review): the codec presumably expects one fixed rate
            (likely ``model.sample_rate``) — confirm and pass it here.

    Returns:
        Tensor of shape ``[1, 1, samples]`` on ``device``.

    Raises:
        NotImplementedError: If ``filename`` is not a single path.
    """
    # TODO(DuBose): make this function work with a whole list of files.
    if not isinstance(filename, (str, os.PathLike)):
        raise NotImplementedError(
            "preprocess_waveform only supports a single file path for now; "
            f"a list of files is not handled yet (got {type(filename).__name__})"
        )

    waveform, sample_rate = torchaudio.load(filename)

    # Convert to mono if stereo (or more channels) by averaging the channels.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample after the down-mix so only a single channel is processed.
    if target_sample_rate is not None and sample_rate != target_sample_rate:
        waveform = torchaudio.functional.resample(
            waveform, sample_rate, target_sample_rate
        )

    # Add the leading batch dimension: [channels, samples] -> [1, channels, samples].
    return waveform.unsqueeze(0).to(device)


# Largely ripped from LMModel.generate() in lm.py
def get_patterns(prompt, max_gen_len=10000, device="cuda", special_token=1024):
    """Interleave prompt codes with the LM's codebook pattern; return the known prefix.

    Adapted from ``LMModel.generate()``. The prompt is written into a
    ``[B, K, max_gen_len]`` buffer whose remaining timesteps hold a ``-1``
    placeholder, and that buffer is mapped through the pattern obtained from the
    notebook-level ``model.lm.pattern_provider``.

    Only the sequence steps before the first step that references a
    not-yet-generated timestep are returned. Previously the full sequence was
    returned, placeholders included, and ``-1`` is not a valid index for an
    embedding lookup downstream.

    NOTE(review): depending on the pattern layout, the dropped steps may also
    hold the last few prompt codes of some codebooks. If every prompt code must
    be kept, build the pattern for exactly ``T`` timesteps instead (as
    ``LMModel.compute_predictions`` appears to do) — confirm which is intended.

    Args:
        prompt: Integer code tensor of shape ``[B, K, T]``.
        max_gen_len: Number of timesteps the pattern is built for; must be
            strictly greater than ``T``.
        device: Device on which the work buffer is allocated.
        special_token: Value written at pattern steps that carry no code.
            NOTE(review): ``prep_input`` defaults to a 2048-entry vocabulary,
            in which case 1024 is also a valid code id; ``model.lm`` presumably
            exposes its own ``special_token_id`` — confirm and pass that
            (the embedding table then needs a row for it).

    Returns:
        Tensor of shape ``[B, K, S_prompt]`` containing only prompt codes and
        ``special_token`` values.

    Raises:
        AssertionError: If the prompt is not shorter than ``max_gen_len``, if the
            pattern has no step for the first ungenerated timestep, or if a
            placeholder is still present in the returned prefix.
    """
    B, K, T = prompt.shape
    start_offset = T
    assert start_offset < max_gen_len

    pattern = model.lm.pattern_provider.get_pattern(max_gen_len)

    # this token is used as default value for codes that are not generated yet
    unknown_token = -1

    # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
    gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
    # filling the gen_codes with the prompt
    gen_codes[..., :start_offset] = prompt
    # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
    gen_sequence, _, _ = pattern.build_pattern_sequence(gen_codes, special_token)
    # retrieve the start_offset in the sequence:
    # it is the first sequence step that contains the `start_offset` timestep
    start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
    assert start_offset_sequence is not None

    # Keep only the steps that are fully determined by the prompt.
    gen_sequence = gen_sequence[..., :start_offset_sequence]
    assert not (gen_sequence == unknown_token).any(), \
        "placeholder tokens leaked into the prompt prefix"

    print(gen_sequence.shape)

    return gen_sequence


# Largely ripped from LMModel.forward() and __init__() in lm.py
def prep_input(sequence, vocab_size=2048, embed_dim=1536, emb=None):
    """Embed each codebook of a token sequence and sum the embeddings.

    Args:
        sequence: Integer tensor of shape ``[B, K, S]`` (batch, codebooks, steps).
        vocab_size: Rows of each freshly created embedding table. Ignored when
            ``emb`` is given.
        embed_dim: Width of each freshly created embedding table. Ignored when
            ``emb`` is given.
        emb: Optional sequence of ``K`` embedding modules, one per codebook
            (for example an existing ``nn.ModuleList``). When omitted, new
            randomly initialised ``nn.Embedding`` tables are created on
            ``sequence.device`` on every call, as before. Supplied modules are
            used where they are; they are not moved to another device.

    Returns:
        Tensor of shape ``[B, S, D]``: the sum over codebooks of the per-codebook
        embeddings (the integer ``0`` if ``K`` is 0, as before).

    Raises:
        ValueError: If ``emb`` does not contain exactly ``K`` modules.
        IndexError: If a token id is negative or not smaller than the number of
            rows of its embedding table.
    """
    device = sequence.device
    B, K, S = sequence.shape

    if emb is None:
        emb = nn.ModuleList(
            [nn.Embedding(vocab_size, embed_dim) for _ in range(K)]
        ).to(device)
    elif len(emb) != K:
        raise ValueError(
            f"expected {K} embedding modules (one per codebook), got {len(emb)}"
        )

    # Validate ids up front: an out-of-range id otherwise surfaces as an opaque
    # device-side assert on CUDA. min()/max() force a host sync, which is
    # acceptable for this debugging-oriented helper.
    for k in range(K):
        tokens = sequence[:, k]
        rows = getattr(emb[k], "num_embeddings", None)
        if rows is None or tokens.numel() == 0:
            continue
        lowest, highest = int(tokens.min()), int(tokens.max())
        if lowest < 0 or highest >= rows:
            raise IndexError(
                f"codebook {k}: token ids span [{lowest}, {highest}] but the "
                f"embedding table has {rows} rows (valid ids are 0..{rows - 1})"
            )

    # Apply each embedding layer to its corresponding codebook and sum the results
    input_ = sum(emb[k](sequence[:, k]) for k in range(K))

    return input_


In [None]:
import os
from audiocraft.modules.conditioners import ConditioningAttributes
import numpy as np

# Empty conditioning attributes; not consumed by the loop below yet.
blank_cond_attr = ConditioningAttributes()

# os.listdir() returns entries in arbitrary order and the loop stops after the
# first file, so sort the listing to make the processed file reproducible.
for file_name in sorted(os.listdir(data_path)):
    file = os.path.join(data_path, file_name)
    # Skip sub-directories (e.g. notebook checkpoint folders); only files are audio.
    if not os.path.isfile(file):
        continue

    waveform = preprocess_waveform(file)
    # Encoding is inference only: do not record an autograd graph for it.
    with torch.no_grad():
        codes, scale = model.compression_model.encode(waveform)
    del waveform

    # interleave the codes with the LM's codebook pattern so they can be embedded
    gen_sequence = get_patterns(codes)

    # Show which token ids are present in the sequence.
    print(np.unique(gen_sequence.cpu().numpy(), axis=None))

    # embed the token ids: one embedding vector of width 1536 per sequence step
    x = prep_input(gen_sequence)

    # TODO There is some kind of indexing issue, try moving everything to the
    # CPU for easier debugging

    # for layer in model.lm.transformer.layers:
    #     x = layer(x)
    #     print(x.shape)

    # test = model.lm(codes, conditions=None)
    # print(type(test))

    # Only the first file is processed while this is being debugged.
    break

