In [None]:
import torch
from torch import nn
import torchaudio
from audiocraft.models import MusicGen

In [None]:
# Load the pretrained 'facebook/musicgen-melody' MusicGen checkpoint; the
# resulting `model` object is used as a notebook-wide global by later cells.
model = MusicGen.get_pretrained('facebook/musicgen-melody')

In [None]:
# Keep a handle on the language model and show three of its attributes,
# one per line: card + 1, dim and n_q.
# NOTE(review): card + 1 is presumably the vocabulary size including one
# special token - confirm against the LM implementation.
lm = model.lm

print(lm.card + 1, lm.dim, lm.n_q, sep="\n")

In [4]:
# Directory whose entries are listed and loaded by the processing loop later
# in the notebook; relative path, so it depends on the working directory.
data_path = "../data/toy"

In [5]:
def preprocess_waveform(filename, device='cuda', target_sample_rate=None):
    """Load one audio file as a mono waveform with a leading batch dimension.

    Args:
        filename: Path of the audio file, as a ``str`` or ``os.PathLike``.
            Any other type (e.g. a list of paths) is not supported yet.
        device: Device the returned tensor is moved to.
        target_sample_rate: When given and different from the file's native
            rate, the waveform is resampled to this rate. The default
            (``None``) keeps the native rate, exactly as before.
            NOTE(review): the codec consuming this waveform presumably
            expects one fixed rate - confirm and pass it here.

    Returns:
        The waveform, averaged over its first (channel) dimension when that
        dimension is larger than one, with an extra leading dimension added,
        on ``device``.

    Raises:
        NotImplementedError: If ``filename`` is not a single path.
    """
    import os  # local import: `os` is only imported by a later notebook cell

    # isinstance (rather than an exact type match) also accepts str
    # subclasses and pathlib paths.
    if not isinstance(filename, (str, os.PathLike)):
        print("DuBose make this function work with a whole list")
        raise NotImplementedError

    waveform, sample_rate = torchaudio.load(filename)

    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Optional resampling; skipped entirely unless the caller asks for it.
    if target_sample_rate is not None and sample_rate != target_sample_rate:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sample_rate, new_freq=target_sample_rate
        )

    return waveform.unsqueeze(0).to(device)


# Largely ripped from LMModel.generate() in lm.py
def get_patterns(prompt, max_gen_len=10000, device="cuda"):
    """Map prompt codes onto the LM's codebook interleaving pattern.

    The prompt is written at the start of a ``(B, K, max_gen_len)`` buffer
    whose remaining positions hold the unknown-token marker ``-1``; that
    buffer is then run through the pattern obtained from
    ``model.lm.pattern_provider``.

    Args:
        prompt: Codes of shape (B, K, T); T must be smaller than
            ``max_gen_len``.
        max_gen_len: Length of the time axis of the buffer.
        device: Device on which the buffer is allocated.

    Returns:
        The first element returned by ``pattern.build_pattern_sequence``
        (the interleaved sequence, [B, K, S]).

    Raises:
        AssertionError: If the prompt is not shorter than ``max_gen_len`` or
            the pattern has no step containing the prompt-length timestep.
    """
    B, K, T = prompt.shape
    start_offset = T
    assert start_offset < max_gen_len

    lm = model.lm
    pattern = lm.pattern_provider.get_pattern(max_gen_len)

    # this token is used as default value for codes that are not generated yet
    unknown_token = -1

    # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
    gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
    # filling the gen_codes with the prompt if needed
    gen_codes[..., :start_offset] = prompt
    # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
    # The special token is read from the LM instead of hard-coding 2048, so
    # the function stays correct for models with a different codebook size.
    gen_sequence, _, _ = pattern.build_pattern_sequence(gen_codes, lm.special_token_id)
    # sanity check: the pattern must contain a sequence step holding the
    # `start_offset` timestep
    start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
    assert start_offset_sequence is not None

    return gen_sequence

class ScaledEmbedding(nn.Embedding):
    """``nn.Embedding`` that remembers an optional learning rate of its own.

    All positional and keyword arguments other than ``lr`` are forwarded
    unchanged to ``nn.Embedding``.
    """

    def __init__(self, *args, lr=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Learning rate reported by make_optim_group(); None means "unset".
        self.lr = lr

    def make_optim_group(self):
        """Return a parameter-group dict for this layer.

        The dict always has a ``"params"`` entry listing the layer's
        parameters, and an ``"lr"`` entry only when a learning rate was set.
        """
        own_params = [param for param in self.parameters()]
        if self.lr is None:
            return {"params": own_params}
        return {"params": own_params, "lr": self.lr}


def prep_input(sequence, pad_token=-1, embed_dim=1536, emb_lr=1.0):
    """Embed every codebook of ``sequence`` and sum the embeddings.

    One ``ScaledEmbedding`` table is created per codebook on every call, so
    the tables are freshly (randomly) initialised each time.
    NOTE(review): if the intent is to feed the pretrained transformer, the
    LM's own embedding tables are presumably what should be used - confirm.

    Args:
        sequence: Integer tokens of shape (B, K, S). Positions equal to
            ``pad_token`` are treated as padding.
        pad_token: Value marking padded / unknown positions.
        embed_dim: Size of each embedding vector.
        emb_lr: Learning rate stored on each embedding table.

    Returns:
        Sum over the K codebooks of the per-codebook embeddings, shape
        (B, S, embed_dim). Positions that are padding in every codebook are
        all-zero.
    """
    device = sequence.device
    _, K, _ = sequence.shape

    # Real tokens occupy [0, max_token]; one extra index past them is
    # reserved for padding so that it never collides with a real token
    # (previously the largest token shared the zeroed padding row).
    # The clamp to -1 keeps an all-padding input from producing an empty table.
    max_token = max(int(sequence.max().item()), -1)
    pad_index = max_token + 1
    vocab_size = pad_index + 1
    emb = nn.ModuleList(
        [ScaledEmbedding(vocab_size, embed_dim, padding_idx=pad_index, lr=emb_lr) for _ in range(K)]
    ).to(device)

    # Apply each embedding layer to its corresponding codebook and sum the results
    embedded = []
    for k in range(K):
        codes_k = sequence[:, k]
        # Redirect padding positions to the reserved index, whose embedding
        # row is the zero vector (padding_idx).
        seq_k = codes_k.masked_fill(codes_k == pad_token, pad_index)
        embedded.append(emb[k](seq_k))

    return sum(embedded)

In [None]:
import os
from audiocraft.modules.conditioners import ConditioningAttributes
import numpy as np

# Empty conditioning attributes; not used by the loop below.
blank_cond_attr = ConditioningAttributes()

# os.listdir() returns entries in arbitrary order and only the first usable
# file is processed (see the `break`), so sort to make the run reproducible.
for file in sorted(os.listdir(data_path)):
    file = os.path.join(data_path, file)
    if not os.path.isfile(file):
        # skip sub-directories: they cannot be loaded as audio
        continue
    waveform = preprocess_waveform(file)
    # Encoding only produces codes for the LM; no gradient is needed through
    # the codec, so avoid building an autograd graph for it.
    with torch.no_grad():
        codes, scale = model.compression_model.encode(waveform)
    del waveform

    # transform the codes so that they match the embedding dimension (1536)
    gen_sequence = get_patterns(codes)
    x = prep_input(gen_sequence)

    # Push the embedded sequence through every transformer layer in order,
    # printing the layer index as progress.
    for i, layer in enumerate(model.lm.transformer.layers):
        #layer.self_attn.attention_as_float32 = False
        print(i)
        # NOTE(review): cast to float16 before each layer - presumably to
        # match the dtype of the pretrained weights; confirm.
        x = x.half()
        x = layer(x)

    # Debug run: stop after the first file.
    break

