In [None]:
#%pip install fairseq

In [42]:
import bark

In [26]:
from bark.generation import load_codec_model, generate_text_semantic
from encodec.utils import convert_audio

import torchaudio
import torch

# Load Bark's EnCodec codec model; it is used below to convert the reference clip to the
# codec's format (model.sample_rate / model.channels) and to encode it into discrete codes (model.encode).
model = load_codec_model(use_gpu=True)

In [22]:
# Load and pre-process the audio waveform
# The clip you want to clone. Existing Bark voice prompts I checked are around 7 seconds long,
# so 5-10 seconds is probably fine.
# NOTE(review): an earlier version of this comment said the clip "will get truncated"; nothing in
# this notebook truncates it — confirm where (or whether) that happens.
audio_filepath = 'christopher_lee.wav'
# Derive the device from the codec model itself (placed by load_codec_model(use_gpu=...)) so the
# waveform can never end up on a different device than the model that encodes it.
device = next(model.parameters()).device
wav, sr = torchaudio.load(audio_filepath)
# Resample / remix to the codec model's expected sample rate and channel count.
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
# Add a leading batch dimension and move to the model's device.
wav = wav.unsqueeze(0).to(device)

In [23]:
# Extract discrete codes from EnCodec
with torch.no_grad():
    encoded_frames = model.encode(wav)
# Take the code tensor ([0]) of every encoded frame and concatenate along time. Squeeze only the
# leading batch dimension added by unsqueeze(0) when loading, so the result is always 2-D
# [n_q, T]; a bare .squeeze() would also drop n_q or T if either happened to be 1, breaking
# the codes[:2, :] slice used when saving the prompt.
codes = torch.cat([frame[0] for frame in encoded_frames], dim=-1).squeeze(0)  # [n_q, T]

In [2]:
import os
import subprocess
import sys

import importlib.util

# Unfortunately, fairseq kmeans package resolution is borked on my machine, so manually adding it
# TODO: Fix this
# Locate the simple_kmeans example directory inside whichever fairseq installation this kernel
# actually imports. The previous hard-coded <git root>/venv/lib/python3.10/... path broke on any
# other Python version, venv location, or when not running inside a git checkout.
_fairseq_spec = importlib.util.find_spec("fairseq")
if _fairseq_spec is None or not _fairseq_spec.submodule_search_locations:
    raise ImportError("fairseq is not installed in this kernel's environment (see the %pip cell at the top)")
feature_utils_path = os.path.join(
    list(_fairseq_spec.submodule_search_locations)[0], "examples", "hubert", "simple_kmeans"
)

# Add the path to sys.path only once, so re-running this cell doesn't keep appending duplicates
if feature_utils_path not in sys.path:
    sys.path.append(feature_utils_path)
from fairseq.examples.hubert.simple_kmeans.dump_hubert_feature import HubertFeatureReader


In [9]:
# TODO: DELETE THIS
from torch.hub import download_url_to_file

# Single source of truth for the HuBERT checkpoint location (previously spelled two different ways).
HUBERT_CKPT_PATH = os.path.join("models", "hubert_base_ls960.pt")

if not os.path.exists(HUBERT_CKPT_PATH):
    # Yes, hard-coding the URL of the model is jank. Too bad!
    # Update this if this changes! https://github.com/facebookresearch/textlesslib/blob/698e6a039375bac0cd5f1b8683beeec5e8f702c0/textless/checkpoint_manager/__init__.py#L20
    # Make sure the target directory exists before downloading into it (fresh checkouts may lack models/).
    os.makedirs(os.path.dirname(HUBERT_CKPT_PATH), exist_ok=True)
    download_url_to_file("https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt", HUBERT_CKPT_PATH)

# Read features from HuBERT layer 6 (consistent with the 'layer6.km500' labels in the
# checkpoint's task config printed in this cell's output).
reader = HubertFeatureReader(
    ckpt_path=HUBERT_CKPT_PATH,
    layer=6
)

2023-05-11 20:59:06 | INFO | fairseq.tasks.hubert_pretraining | current directory is /home/ritsuko/projects/ai/audio/bark
2023-05-11 20:59:06 | INFO | fairseq.tasks.hubert_pretraining | HubertPretrainingTask Config {'_name': 'hubert_pretraining', 'data': '/checkpoint/wnhsu/data/librispeech/960h/iter/250K_50hz_km100_mp0_65_v2', 'fine_tuning': False, 'labels': ['layer6.km500'], 'label_dir': None, 'label_rate': 50.0, 'sample_rate': 16000, 'normalize': False, 'enable_padding': False, 'max_keep_size': None, 'max_sample_size': 250000, 'min_sample_size': 32000, 'single_target': False, 'random_crop': True, 'pad_audio': False}
2023-05-11 20:59:06 | INFO | fairseq.models.hubert.hubert | HubertModel Config: {'_name': 'hubert', 'label_rate': 50.0, 'extractor_mode': default, 'encoder_layers': 12, 'encoder_embed_dim': 768, 'encoder_ffn_embed_dim': 3072, 'encoder_attention_heads': 12, 'activation_fn': gelu, 'layer_type': transformer, 'dropout': 0.1, 'attention_dropout': 0.1, 'activation_dropout': 0.0

In [24]:
# Extract HuBERT features directly from the audio file path. The reader loads the file itself, so
# the convert_audio preprocessing applied to `wav` above does not affect these features.
semantic_hubert_feats = reader.get_feats(audio_filepath)


In [29]:
import torch
import torch.nn as nn

class hubert_to_wte_projection(nn.Module):
    """A single linear layer mapping HuBERT features to a wider embedding space.

    Args:
        input_dim: size of the incoming feature vectors (768 matches the HuBERT base
            ``encoder_embed_dim``).
        output_dim: size of the projected vectors.

    The layer is stored under the attribute ``proj``; saved checkpoints use the keys
    ``proj.weight`` / ``proj.bias``, so that attribute name must stay as is.
    """

    def __init__(self, input_dim=768, output_dim=1024):
        super(hubert_to_wte_projection, self).__init__()
        self.proj = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Apply the linear projection over the last dimension of ``x``."""
        projected = self.proj(x)
        return projected

# Build the projection on the selected device and load the trained weights into it.
proj = hubert_to_wte_projection()
proj.to(device)
# Inference only. The module is a single nn.Linear, so this is a no-op today, but it keeps things
# correct if dropout/normalization is ever added.
proj.eval()
# map_location lets a checkpoint saved from one device (e.g. CUDA) load on another (e.g. the CPU
# option advertised for `device`) instead of failing inside torch.load.
proj.load_state_dict(torch.load(os.path.join("models", "hubert_bark_proj_20230511-220924.pt"), map_location=device))

<All keys matched successfully>

In [33]:
# Project the HuBERT features with the trained linear map and keep the result as a NumPy array for
# saving. The features are moved to the projection's device first, in case the reader returns them
# on a different device than `device`.
# NOTE(review): the result is a float matrix (last dim presumably 1024), not discrete token ids.
with torch.no_grad():
    semantic_emb = proj(semantic_hubert_feats.to(device)).cpu().numpy()

In [31]:
# Move codes to CPU and convert to NumPy for saving. Guarded so re-running this cell (when `codes`
# is already a NumPy array) doesn't fail with AttributeError on `.cpu()`.
if torch.is_tensor(codes):
    codes = codes.cpu().numpy()

In [34]:
import numpy as np
import os

voice_name = 'output' # whatever you want the name of the voice to be
output_path = os.path.join('bark', 'assets', 'prompts', voice_name + '.npz')
# Make sure the prompts directory exists so np.savez doesn't fail on a fresh checkout.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
# fine_prompt: all codec codebooks; coarse_prompt: the first two codebooks; semantic_prompt: the
# projected HuBERT features.
# NOTE(review): semantic_emb is a float embedding matrix, not discrete semantic token ids — confirm
# the Bark build in use accepts that. generate_audio in the generation section raised
# AssertionError in the saved run, possibly because of this prompt format.
np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_emb)

In [None]:
# That's it! Now head over to generate.ipynb and pass your voice_name as the `history_prompt`.

In [None]:
# Here's the generation code, copy-pasted for convenience

In [1]:
from bark.api import generate_audio
from transformers import BertTokenizer
from bark.generation import SAMPLE_RATE, preload_models, codec_decode, generate_coarse, generate_fine, generate_text_semantic

# --- Prompt and speaker ---
# The text to speak, and the voice preset to speak it with (your custom voice name goes here).
text_prompt = "Hello, my name is Serpy. And, uh — and I like pizza. [laughs]"
voice_name = "output"

# --- Tokenizer ---
# NOTE(review): `tokenizer` is not referenced by any later cell in this notebook; consider removing it.
tokenizer = BertTokenizer.from_pretrained(
    "bert-base-multilingual-cased"
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Download (on first use) and load every Bark sub-model: full-size variants (no *_use_small), all on
# the GPU, with ./models as the model directory. force_reload=False presumably reuses weights that
# are already present.
preload_models(
    path="models",
    force_reload=False,
    text_use_gpu=True, text_use_small=False,
    coarse_use_gpu=True, coarse_use_small=False,
    fine_use_gpu=True, fine_use_small=False,
    codec_use_gpu=True,
)

In [3]:
# Simple generation: a single call goes from text straight to an audio array.
# NOTE(review): in the saved run this raised AssertionError right after the progress bar reached
# 100%; presumably the custom voice prompt's format (float semantic_prompt) is not what a later
# stage expects — confirm before relying on this voice.
audio_array = generate_audio(
    text_prompt,
    history_prompt=voice_name,
    text_temp=0.7,
    waveform_temp=0.7,
)

  2%|▏         | 2/100 [00:00<00:05, 17.09it/s]

Next:torch.Size([1, 1])
Out: torch.Size([1, 0])
X: torch.Size([1, 257, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 1])
X: torch.Size([1, 258, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 2])
X: torch.Size([1, 259, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 3])
X: torch.Size([1, 260, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 4])
X: torch.Size([1, 261, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 5])
X: torch.Size([1, 262, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 6])
X: torch.Size([1, 263, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 7])
X: torch.Size([1, 264, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 8])
X: torch.Size([1, 265, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 9])
X: torch.Size([1, 266, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 10])
X: torch.Size([1, 267, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 11])
X: torch.Size([1, 268, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 12])
X: torch.Size

  7%|▋         | 7/100 [00:00<00:04, 19.40it/s]

Next:torch.Size([1, 1])
Out: torch.Size([1, 30])
X: torch.Size([1, 287, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 31])
X: torch.Size([1, 288, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 32])
X: torch.Size([1, 289, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 33])
X: torch.Size([1, 290, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 34])
X: torch.Size([1, 291, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 35])
X: torch.Size([1, 292, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 36])
X: torch.Size([1, 293, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 37])
X: torch.Size([1, 294, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 38])
X: torch.Size([1, 295, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 39])
X: torch.Size([1, 296, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 40])
X: torch.Size([1, 297, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 41])
X: torch.Size([1, 298, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 42])
X: 

 10%|█         | 10/100 [00:00<00:04, 20.14it/s]

Next:torch.Size([1, 1])
Out: torch.Size([1, 61])
X: torch.Size([1, 318, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 62])
X: torch.Size([1, 319, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 63])
X: torch.Size([1, 320, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 64])
X: torch.Size([1, 321, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 65])
X: torch.Size([1, 322, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 66])
X: torch.Size([1, 323, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 67])
X: torch.Size([1, 324, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 68])
X: torch.Size([1, 325, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 69])
X: torch.Size([1, 326, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 70])
X: torch.Size([1, 327, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 71])
X: torch.Size([1, 328, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 72])
X: torch.Size([1, 329, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 73])
X: 

 16%|█▌        | 16/100 [00:00<00:04, 20.08it/s]

Next:torch.Size([1, 1])
Out: torch.Size([1, 94])
X: torch.Size([1, 351, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 95])
X: torch.Size([1, 352, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 96])
X: torch.Size([1, 353, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 97])
X: torch.Size([1, 354, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 98])
X: torch.Size([1, 355, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 99])
X: torch.Size([1, 356, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 100])
X: torch.Size([1, 357, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 101])
X: torch.Size([1, 358, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 102])
X: torch.Size([1, 359, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 103])
X: torch.Size([1, 360, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 104])
X: torch.Size([1, 361, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 105])
X: torch.Size([1, 362, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 10

 19%|█▉        | 19/100 [00:00<00:04, 19.74it/s]

Next:torch.Size([1, 1])
Out: torch.Size([1, 126])
X: torch.Size([1, 383, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 127])
X: torch.Size([1, 384, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 128])
X: torch.Size([1, 385, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 129])
X: torch.Size([1, 386, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 130])
X: torch.Size([1, 387, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 131])
X: torch.Size([1, 388, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 132])
X: torch.Size([1, 389, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 133])
X: torch.Size([1, 390, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 134])
X: torch.Size([1, 391, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 135])
X: torch.Size([1, 392, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 136])
X: torch.Size([1, 393, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 137])
X: torch.Size([1, 394, 1024])
Next:torch.Size([1, 1])
Out: torch.Size(

 24%|██▍       | 24/100 [00:01<00:03, 19.40it/s]

Next:torch.Size([1, 1])
Out: torch.Size([1, 156])
X: torch.Size([1, 413, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 157])
X: torch.Size([1, 414, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 158])
X: torch.Size([1, 415, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 159])
X: torch.Size([1, 416, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 160])
X: torch.Size([1, 417, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 161])
X: torch.Size([1, 418, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 162])
X: torch.Size([1, 419, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 163])
X: torch.Size([1, 420, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 164])
X: torch.Size([1, 421, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 165])
X: torch.Size([1, 422, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 166])
X: torch.Size([1, 423, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 167])
X: torch.Size([1, 424, 1024])
Next:torch.Size([1, 1])
Out: torch.Size(

 27%|██▋       | 27/100 [00:01<00:03, 20.23it/s]

Next:torch.Size([1, 1])
Out: torch.Size([1, 186])
X: torch.Size([1, 443, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 187])
X: torch.Size([1, 444, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 188])
X: torch.Size([1, 445, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 189])
X: torch.Size([1, 446, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 190])
X: torch.Size([1, 447, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 191])
X: torch.Size([1, 448, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 192])
X: torch.Size([1, 449, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 193])
X: torch.Size([1, 450, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 194])
X: torch.Size([1, 451, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 195])
X: torch.Size([1, 452, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 196])
X: torch.Size([1, 453, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 197])
X: torch.Size([1, 454, 1024])
Next:torch.Size([1, 1])
Out: torch.Size(

 33%|███▎      | 33/100 [00:01<00:03, 21.16it/s]

Next:torch.Size([1, 1])
Out: torch.Size([1, 220])
X: torch.Size([1, 477, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 221])
X: torch.Size([1, 478, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 222])
X: torch.Size([1, 479, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 223])
X: torch.Size([1, 480, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 224])
X: torch.Size([1, 481, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 225])
X: torch.Size([1, 482, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 226])
X: torch.Size([1, 483, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 227])
X: torch.Size([1, 484, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 228])
X: torch.Size([1, 485, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 229])
X: torch.Size([1, 486, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 230])
X: torch.Size([1, 487, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 231])
X: torch.Size([1, 488, 1024])
Next:torch.Size([1, 1])
Out: torch.Size(

 36%|███▌      | 36/100 [00:01<00:03, 21.29it/s]

Next:torch.Size([1, 1])
Out: torch.Size([1, 255])
X: torch.Size([1, 512, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 256])
X: torch.Size([1, 513, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 257])
X: torch.Size([1, 514, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 258])
X: torch.Size([1, 515, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 259])
X: torch.Size([1, 516, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 260])
X: torch.Size([1, 517, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 261])
X: torch.Size([1, 518, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 262])
X: torch.Size([1, 519, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 263])
X: torch.Size([1, 520, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 264])
X: torch.Size([1, 521, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 265])
X: torch.Size([1, 522, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 266])
X: torch.Size([1, 523, 1024])
Next:torch.Size([1, 1])
Out: torch.Size(

 39%|███▉      | 39/100 [00:01<00:02, 20.83it/s]

Next:torch.Size([1, 1])
Out: torch.Size([1, 287])
X: torch.Size([1, 544, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 288])
X: torch.Size([1, 545, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 289])
X: torch.Size([1, 546, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 290])
X: torch.Size([1, 547, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 291])
X: torch.Size([1, 548, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 292])
X: torch.Size([1, 549, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 293])
X: torch.Size([1, 550, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 294])
X: torch.Size([1, 551, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 295])
X: torch.Size([1, 552, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 296])
X: torch.Size([1, 553, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 297])
X: torch.Size([1, 554, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 298])
X: torch.Size([1, 555, 1024])
Next:torch.Size([1, 1])
Out: torch.Size(

 45%|████▌     | 45/100 [00:02<00:02, 21.00it/s]

Next:torch.Size([1, 1])
Out: torch.Size([1, 319])
X: torch.Size([1, 576, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 320])
X: torch.Size([1, 577, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 321])
X: torch.Size([1, 578, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 322])
X: torch.Size([1, 579, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 323])
X: torch.Size([1, 580, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 324])
X: torch.Size([1, 581, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 325])
X: torch.Size([1, 582, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 326])
X: torch.Size([1, 583, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 327])
X: torch.Size([1, 584, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 328])
X: torch.Size([1, 585, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 329])
X: torch.Size([1, 586, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 330])
X: torch.Size([1, 587, 1024])
Next:torch.Size([1, 1])
Out: torch.Size(

 48%|████▊     | 48/100 [00:02<00:02, 21.23it/s]

Next:torch.Size([1, 1])
Out: torch.Size([1, 353])
X: torch.Size([1, 610, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 354])
X: torch.Size([1, 611, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 355])
X: torch.Size([1, 612, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 356])
X: torch.Size([1, 613, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 357])
X: torch.Size([1, 614, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 358])
X: torch.Size([1, 615, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 359])
X: torch.Size([1, 616, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 360])
X: torch.Size([1, 617, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 361])
X: torch.Size([1, 618, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 362])
X: torch.Size([1, 619, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 363])
X: torch.Size([1, 620, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 364])
X: torch.Size([1, 621, 1024])
Next:torch.Size([1, 1])
Out: torch.Size(

 54%|█████▍    | 54/100 [00:02<00:02, 21.26it/s]

Next:torch.Size([1, 1])
Out: torch.Size([1, 386])
X: torch.Size([1, 643, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 387])
X: torch.Size([1, 644, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 388])
X: torch.Size([1, 645, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 389])
X: torch.Size([1, 646, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 390])
X: torch.Size([1, 647, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 391])
X: torch.Size([1, 648, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 392])
X: torch.Size([1, 649, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 393])
X: torch.Size([1, 650, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 394])
X: torch.Size([1, 651, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 395])
X: torch.Size([1, 652, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 396])
X: torch.Size([1, 653, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 397])
X: torch.Size([1, 654, 1024])
Next:torch.Size([1, 1])
Out: torch.Size(

 57%|█████▋    | 57/100 [00:02<00:02, 21.06it/s]

Next:torch.Size([1, 1])
Out: torch.Size([1, 419])
X: torch.Size([1, 676, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 420])
X: torch.Size([1, 677, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 421])
X: torch.Size([1, 678, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 422])
X: torch.Size([1, 679, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 423])
X: torch.Size([1, 680, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 424])
X: torch.Size([1, 681, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 425])
X: torch.Size([1, 682, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 426])
X: torch.Size([1, 683, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 427])
X: torch.Size([1, 684, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 428])
X: torch.Size([1, 685, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 429])
X: torch.Size([1, 686, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 430])
X: torch.Size([1, 687, 1024])
Next:torch.Size([1, 1])
Out: torch.Size(

 60%|██████    | 60/100 [00:02<00:01, 20.36it/s]

Next:torch.Size([1, 1])
Out: torch.Size([1, 449])
X: torch.Size([1, 706, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 450])
X: torch.Size([1, 707, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 451])
X: torch.Size([1, 708, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 452])
X: torch.Size([1, 709, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 453])
X: torch.Size([1, 710, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 454])
X: torch.Size([1, 711, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 455])
X: torch.Size([1, 712, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 456])
X: torch.Size([1, 713, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 457])
X: torch.Size([1, 714, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 458])
X: torch.Size([1, 715, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 459])
X: torch.Size([1, 716, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 460])
X: torch.Size([1, 717, 1024])
Next:torch.Size([1, 1])
Out: torch.Size(

100%|██████████| 100/100 [00:03<00:00, 32.06it/s]


Next:torch.Size([1, 1])
Out: torch.Size([1, 480])
X: torch.Size([1, 737, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 481])
X: torch.Size([1, 738, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 482])
X: torch.Size([1, 739, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 483])
X: torch.Size([1, 740, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 484])
X: torch.Size([1, 741, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 485])
X: torch.Size([1, 742, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 486])
X: torch.Size([1, 743, 1024])
Next:torch.Size([1, 1])
Out: torch.Size([1, 487])
X: torch.Size([1, 744, 1024])


AssertionError: 

In [None]:
# Generation with more control: run each stage explicitly, all conditioned on the same voice.
# Stage 1: text -> semantic
x_semantic = generate_text_semantic(
    text_prompt, history_prompt=voice_name, temp=0.7, top_k=50, top_p=0.95
)
# Stage 2: semantic -> coarse
x_coarse_gen = generate_coarse(
    x_semantic, history_prompt=voice_name, temp=0.7, top_k=50, top_p=0.95
)
# Stage 3: coarse -> fine (lower temperature than the earlier stages)
x_fine_gen = generate_fine(x_coarse_gen, history_prompt=voice_name, temp=0.5)
# Decode the fine output into an audio array
audio_array = codec_decode(x_fine_gen)

In [None]:
from IPython.display import Audio
# Render an inline audio player for the generated waveform at Bark's sample rate
Audio(data=audio_array, rate=SAMPLE_RATE)

In [None]:
from scipy.io.wavfile import write as write_wav
import os

# save audio
# Relative by default: the previous "/output/audio.wav" pointed at the filesystem root, which
# usually doesn't exist (or isn't writable).
filepath = "output/audio.wav" # change this to your desired output path
# Create the target directory if needed so the write doesn't fail on a missing folder.
os.makedirs(os.path.dirname(filepath) or ".", exist_ok=True)
write_wav(filepath, SAMPLE_RATE, audio_array)