In [None]:
from typing import *

from IPython.display import Audio

import numpy as np
import torch
from torch import Tensor
import librosa

from shared import *

In [None]:
# Location of the audio prompt that the notebook loads with librosa
# (relative path, so it is resolved against the current working directory).
PROMPT_PATH = './data/hum.wav'

In [None]:
def playHard(data):
    """Return an IPython ``Audio`` widget for ``data`` at ``ENCODEC_SR``.

    The samples are handed to ``Audio`` untouched (no fading at the edges).
    """
    widget = Audio(data, rate=ENCODEC_SR)
    return widget
def play(data, soft = .1):
    """Play ``data`` with a linear fade-in and fade-out at the edges.

    Args:
        data: 1-D sequence of audio samples. It is copied, never modified.
        soft: Length of each fade in seconds (converted to samples with
            ``ENCODEC_SR``). ``0`` or a negative value disables the fades;
            a value longer than the clip is clamped to the clip length.

    Returns:
        Whatever ``playHard`` returns for the faded samples.
    """
    # Copy the samples and append one extra sample of value 1.
    # NOTE(review): for ramps of 2+ samples the fade-out ends at 0, so this
    # extra sample is scaled to 0 as well — confirm what it is meant to do.
    t = np.concatenate([data, [1]])
    # Clamp the ramp length: a zero-length ramp would make ``t[-0:]`` select
    # the whole array, and a ramp longer than the clip cannot be broadcast.
    length = min(max(int(round(soft * ENCODEC_SR)), 0), len(t))
    if length > 0:
        t[:length ] = np.multiply(t[:length ], np.linspace(0, 1, length))
        t[-length:] = np.multiply(t[-length:], np.linspace(1, 0, length))
    return playHard(t)


In [None]:
# Decode the prompt file, asking librosa for the codec's sample rate, and
# double-check the rate that actually came back.
wave_np, sr = librosa.load(PROMPT_PATH, sr=ENCODEC_SR)
assert sr == ENCODEC_SR
# Copy the samples into a float32 tensor living on the working device.
wave = torch.tensor(wave_np, dtype=torch.float32, device=DEVICE)
wave.shape, wave.dtype

In [None]:
import math

from audiocraft.utils.notebook import display_audio
from audiocraft.models.musicgen import MusicGen
from audiocraft.models.multibanddiffusion import MultiBandDiffusion

In [None]:
# Set to True to additionally load the MultiBandDiffusion decoder.
USE_DIFFUSION_DECODER = False
# Load the model on DEVICE rather than a hard-coded 'cuda': every tensor in
# this notebook (prompt waveform, code buffers) is created on DEVICE, so the
# model has to live on the same device.
musicGen = MusicGen.get_pretrained('facebook/musicgen-small', device=DEVICE)
if USE_DIFFUSION_DECODER:
    # NOTE(review): loaded with the library's default device — confirm it
    # matches DEVICE before enabling this branch.
    mbd = MultiBandDiffusion.get_mbd_musicgen()

In [None]:
# Generation settings: sampling enabled, top-k restricted to 250, duration 30.
musicGen.set_generation_params(use_sampling=True, top_k=250, duration=30)

In [None]:
# Shorthand for the compression model attached to the MusicGen wrapper;
# the notebook calls its encode()/decode() methods directly.
encodec = musicGen.compression_model

In [None]:
# Round-trip the prompt through the codec without tracking gradients:
# encode to tokens, decode back, and keep the first item / first channel.
with torch.no_grad():
    codes, _ = encodec.encode(wave[None, None, :])
    recon: Tensor = encodec.decode(codes)[0, 0]   # type: ignore

In [None]:
# Listen to the codec reconstruction; the tensor is moved to the CPU and
# converted to a NumPy array before being handed to play().
play(recon.cpu().numpy())

In [None]:
# Sample rate used as the default `sample_rate` of get_bip_bip; the assert
# guarantees it agrees with ENCODEC_SR, the rate play() uses for playback.
TEMP_SR = 32000
assert ENCODEC_SR == TEMP_SR
def get_bip_bip(bip_duration=0.125, frequency=440,
                duration=0.5, sample_rate=TEMP_SR, device="cuda"):
    """Generates a series of bip bip at the given frequency.

    Args:
        bip_duration: Length in seconds of each beep and of each silence
            between beeps (one on/off period lasts ``2 * bip_duration``).
        frequency: Frequency of the cosine tone in Hz.
        duration: Total length of the signal in seconds.
        sample_rate: Number of samples per second.
        device: Device on which the tensor is created.

    Returns:
        Float tensor of shape ``(1, int(duration * sample_rate))``. Each
        period starts with silence and ends with the tone.
    """
    # Time axis in seconds, created on the requested device (previously
    # hard-coded to "cuda", which ignored the `device` argument).
    t = torch.arange(
        int(duration * sample_rate), device=device, dtype=torch.float) / sample_rate
    # Use the `frequency` argument (previously hard-coded to 440).
    wav = torch.cos(2 * math.pi * frequency * t)[None]
    # Position inside the current on/off period, normalised to [0, 1).
    tp = (t % (2 * bip_duration)) / (2 * bip_duration)
    # Gate: 0 during the first half of each period, 1 during the second half.
    envelope = (tp >= 0.5).float()
    return wav * envelope
# Generate the default beep, take its single row, and bring it to the host
# as a NumPy array for playback.
bipbip = get_bip_bip()[0].cpu().numpy()
play(bipbip)

In [None]:
# Add leading axes in front of the waveform so it can be used as a batched
# prompt (assuming `wave` is 1-D this yields shape (1, 1, samples)).
prompt = wave[None].expand(1, -1, -1).to(DEVICE)
prompt.shape, prompt.dtype

In [None]:
# res = musicGen.generate_continuation(
#     prompt, 
#     32000, 
#     [
#         None, 
#         # 'Random dude humming jazz', 
#         # 'Heartful EDM with beautiful synths and chords', 
#     ], 
#     progress=True, 
# )
# display_audio(res, 32000)


In [None]:
from audiocraft.modules.conditioners import ClassifierFreeGuidanceDropout
from audiocraft.solvers.musicgen import MusicGenSolver

In [None]:
# Let MusicGen prepare the conditioning attributes and prompt tokens for a
# single item with no text description.
attributes, prompt_tokens = musicGen._prepare_tokens_and_attributes([None], prompt[:1])
assert prompt_tokens is not None
print(attributes)
print(prompt_tokens.shape)
# The tokens must be identical, in value and dtype, to the ones obtained by
# calling the codec's encode() directly.
assert bool((prompt_tokens == codes).all())
assert prompt_tokens.dtype == codes.dtype

In [None]:
# Build the conditioning tensors for the language model inside MusicGen's
# autocast context.
with musicGen.autocast:
    lm = musicGen.lm
    # Dropout with p=1.0 applied to `attributes`; presumably this yields the
    # "null" (fully dropped) version of every condition — verify against the
    # audiocraft ClassifierFreeGuidanceDropout documentation.
    null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(attributes)
    print(attributes)
    print(null_conditions)
    # Original conditions first, null conditions second.
    conditions = attributes + null_conditions
    # Two-step conditioning: tokenize, then run the condition provider.
    tokenized = lm.condition_provider.tokenize(conditions)
    cfg_conditions = lm.condition_provider(tokenized)
    print(cfg_conditions)
# Show the dtype of each entry stored under the 'description' key.
[x.dtype for x in cfg_conditions['description']]

In [None]:
# The three dimensions of `codes` (presumably batch, codebooks, time steps —
# confirm against the codec's encode() contract).
B, K, T = codes.shape
# The prompt occupies the first T steps of the buffer.
start_offset = T
# Placeholder value for positions that have not been filled yet.
unknown_token = -1
# NOTE(review): hard-coded buffer length; the slice assignment below fails
# if the prompt has more than this many steps.
max_gen_len = 1500
# (B, K, max_gen_len) buffer of placeholders on the working device.
gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=DEVICE)
# Copy the prompt tokens into the first `start_offset` steps.
gen_codes[..., :start_offset] = codes
# Ask the LM's pattern provider for a pattern of this length and lay the
# buffer out accordingly, passing the LM's special token id.
pattern = lm.pattern_provider.get_pattern(max_gen_len)
gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, lm.special_token_id)
gen_sequence.shape

In [None]:
# Translate the prompt length (in time steps) into a step index of the
# pattern sequence, as reported by the pattern itself.
start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
print(f'{start_offset_sequence = }')
with musicGen.autocast:
    # Keep only the sequence steps before `start_offset_sequence`, and the
    # matching part of the mask broadcast over the batch.
    curr_sequence = gen_sequence[..., :start_offset_sequence]
    curr_mask = mask[None, ..., :start_offset_sequence].expand(B, -1, -1)

    # check coherence between mask and sequence: every position the mask
    # marks as False must already hold the special token
    assert (curr_sequence == torch.where(curr_mask, curr_sequence, lm.special_token_id)).all()
    # should never happen as gen_sequence is filled progressively: no
    # placeholder may remain in the part that is fed to the model
    assert not (curr_sequence == unknown_token).any()

    # Duplicate the batch so it lines up with the conditions, which were
    # built as `attributes + null_conditions` (twice as many entries).
    db_sequence = torch.cat([curr_sequence, curr_sequence], dim=0)
    print(db_sequence.shape, db_sequence.dtype)
    # Single forward pass with an empty conditions list and the precomputed
    # condition tensors.
    # NOTE(review): there is no torch.no_grad() in this cell — confirm that
    # gradient tracking is intended here.
    out = lm.forward(db_sequence, [], condition_tensors=cfg_conditions)
    print(out.shape)