In [9]:
# Show the visible GPU(s), driver/CUDA versions and current memory usage before loading any models.
!nvidia-smi

Fri Dec  1 18:53:34 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   37C    P0    54W / 400W |   3395MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [10]:
!pip install -U voicebox-pytorch



In [11]:
import datetime
import os
import time
import urllib
import urllib.request

import torch
import torchaudio

from voicebox_pytorch import (
    VoiceBox,
    EncodecVoco,
    ConditionalFlowMatcherWrapper
)
from audiolm_pytorch import HubertWithKmeans

from einops import rearrange

from IPython import display

In [12]:
# Device that every model and tensor below is moved to. Fall back to the CPU when
# no CUDA device is visible, so the notebook still runs (slowly) without a GPU
# instead of failing on the first `.to("cuda")`.
accelerator = "cuda" if torch.cuda.is_available() else "cpu"

In [13]:
# Local cache for the HuBERT checkpoint and its k-means quantizer.
hubert_dir = "hubert"
hubert_ckpt_path = f"{hubert_dir}/hubert_base_ls960.pt"
hubert_quantizer_path = f"{hubert_dir}/hubert_base_ls960_L9_km2000_expresso.bin"

hubert_ckpt_download = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt"
hubert_quantizer_download = "https://dl.fbaipublicfiles.com/textless_nlp/expresso/checkpoints/hubert_base_ls960_L9_km2000_expresso.bin"


def _download_if_missing(url, path, label):
    """Download `url` to `path` unless `path` already exists.

    The payload is first written to `<path>.part` and renamed into place only
    after `urlretrieve` returns. An interrupted download therefore never leaves
    a truncated file at `path` that the existence check would skip on a re-run.

    Args:
        url: Source URL.
        path: Destination file path.
        label: Human-readable name used in the progress message.
    """
    if os.path.isfile(path):
        return
    print(f"Downloading {label} from: {url}")
    partial_path = f"{path}.part"
    urllib.request.urlretrieve(url, partial_path)
    os.replace(partial_path, path)


# exist_ok avoids a separate isdir check and keeps this cell re-runnable.
os.makedirs(hubert_dir, exist_ok=True)
_download_if_missing(hubert_ckpt_download, hubert_ckpt_path, "HuBERT model")
_download_if_missing(hubert_quantizer_download, hubert_quantizer_path, "HuBERT quantizer")

# HuBERT features quantized with the k-means model; its outputs (+1) are used
# as the semantic token ids that condition VoiceBox in the generation cell.
# NOTE(review): target_sample_hz matches the 24 kHz rate of the clips used
# below; confirm against audiolm_pytorch's resampling semantics.
wav2vec = HubertWithKmeans(
    checkpoint_path = hubert_ckpt_path,
    kmeans_path = hubert_quantizer_path,
    target_sample_hz = 24_000,
)
wav2vec = wav2vec.to(accelerator)

In [14]:
# Fetch the pretrained VoiceBox-small checkpoint once.
# `wget -O` truncates/creates the output file even when the request fails, and
# `-nc` then refuses to overwrite it. A failed run therefore left an empty
# voicebox_small.pt that was never re-downloaded. Download to a `.part` file and
# rename it only after the transfer completes.
voicebox_ckpt_path = "voicebox_small.pt"
voicebox_ckpt_download = "https://huggingface.co/lucasnewman/voicebox-small/resolve/main/voicebox_small.pt?download=true"

if not os.path.isfile(voicebox_ckpt_path):
    print(f"Downloading VoiceBox checkpoint from: {voicebox_ckpt_download}")
    voicebox_ckpt_partial = f"{voicebox_ckpt_path}.part"
    urllib.request.urlretrieve(voicebox_ckpt_download, voicebox_ckpt_partial)
    os.replace(voicebox_ckpt_partial, voicebox_ckpt_path)

In [15]:
# VoiceBox-small architecture. These hyperparameters must reproduce the tensor
# shapes stored in voicebox_small.pt. load_state_dict below is strict by default
# and raises on missing/unexpected keys or size mismatches.
model = VoiceBox(
    dim = 512,
    dim_cond_emb = 512,
    audio_enc_dec = EncodecVoco(),
    num_cond_tokens = 2001,  # presumably 2000 k-means ids + 1 padding id (wav2vec outputs are offset by +1) — confirm
    depth = 12,
    dim_head = 64,
    heads = 16,
    ff_mult = 4,
    attn_qk_norm = False,
    num_register_tokens = 0,
    use_gateloop_layers = False,
)

cfm_wrapper = ConditionalFlowMatcherWrapper(
    voicebox = model,
    cond_drop_prob = 0.2
)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code,
# and this checkpoint comes from a third-party host. On torch versions that
# support it, prefer torch.load(..., weights_only=True) for a plain state dict.
state_dict = torch.load("voicebox_small.pt", map_location = 'cpu')
cfm_wrapper.load_state_dict(state_dict)

# The notebook only samples from the model. Switch to eval mode so any
# training-only layer behaviour (e.g. dropout) is disabled, then move the
# weights to the selected device.
cfm_wrapper = cfm_wrapper.to(accelerator).eval()

print("Parameters: ", sum(p.numel() for p in cfm_wrapper.parameters()))

Parameters:  143517502


In [16]:
!wget -qnc "https://google.github.io/df-conformer/librittsr/data/1841_r960_3.wav"

wave, _ = torchaudio.load('1841_r960_3.wav')

print("Original:")
display.display(display.Audio(wave, rate=24_000))

wave = wave.to(accelerator)
semantic_token_ids = (wav2vec(wave) + 1).to(accelerator) # offset for the padding token

# unconditional generation

start_date = datetime.datetime.now()

output_wave = cfm_wrapper.sample(
    semantic_token_ids = semantic_token_ids,
    steps = 32,
    cond_scale = 1.3
)
elapsed_time = (datetime.datetime.now() - start_date).total_seconds()

output_wave = rearrange(output_wave, "1 1 n -> 1 n")
output_duration = float(output_wave.shape[1]) / 24000
realtime_mult = output_duration / elapsed_time

print(f"\nGenerated sample of duration {output_duration:0.2f}s in {elapsed_time}s ({realtime_mult:0.2f}x realtime)\n\n")

print("Unconditionally generated:")
display.display(display.Audio(output_wave.cpu(), rate=24_000))

# in-filling and style transfer

!wget -qnc "https://google.github.io/df-conformer/librittsr/data/5717_r960_0.wav"

cond_wave, _ = torchaudio.load('5717_r960_0.wav')
cond_wave = cond_wave.to(accelerator)
cond_semantic_token_ids = (wav2vec(cond_wave) + 1).to(accelerator) # offset for the padding token

cond = torch.cat([cond_wave, wave], dim = -1)
cond_mask_copy = torch.zeros_like(cond_semantic_token_ids, dtype = torch.bool)
cond_mask_infill = torch.ones_like(semantic_token_ids, dtype = torch.bool)
cond_mask = torch.cat([cond_mask_copy, cond_mask_infill], dim = -1).to(accelerator)

start_date = datetime.datetime.now()

infilled_wave = cfm_wrapper.sample(
    cond = cond,
    cond_mask = cond_mask,
    semantic_token_ids = torch.cat([cond_semantic_token_ids, semantic_token_ids], dim = -1),
    steps = 32,
    cond_scale = 1.3
)
elapsed_time = (datetime.datetime.now() - start_date).total_seconds()

infilled_wave = rearrange(infilled_wave, "1 1 n -> 1 n")

# crop the conditioning wave from the output

infilled_wave = infilled_wave[:, cond_wave.shape[1]:]
infilled_duration = float(infilled_wave.shape[1]) / 24000
infilled_realtime_mult = infilled_duration / elapsed_time

print(f"\nGenerated sample of duration {output_duration:0.2f}s in {elapsed_time}s ({infilled_realtime_mult:0.2f}x realtime)\n\n")

print("Infilled:")
display.display(display.Audio(infilled_wave.cpu(), rate=24_000))

Original:



Generated sample of duration 5.04s in 2.308327s (2.18x realtime)


Unconditionally generated:



Generated sample of duration 5.04s in 2.338161s (2.15x realtime)


Infilled:
