## Infer

In [None]:
import tqdm
import torch
import torch.nn.functional as F
from IPython.display import Audio
from scipy.io.wavfile import write as write_wav

import barkify.bark as bark 
from barkify.utils import Bestckpt
from barkify.bark import create_infer_model
from barkify.datas import PhonemeTokenizer

from omegaconf import OmegaConf
# Experiment configuration; its stage1 / stage2 sections are read by the cells below.
CONFIG_PATH = "configs/barkify.yaml"
x_dict = OmegaConf.load(CONFIG_PATH)

# Your data folder. Checkpoints are looked up under <start_path>/<x_dict.name>/stage_N.
start_path = "../work_env"

In [None]:
# --- Stage 1 (text -> semantic) token-id layout, as used by generate_stage1 ---
#   TEXT_TOKEN_NUM       : value used to right-pad the text prompt
#   TEXT_TOKEN_NUM + 1   : end-of-sequence marker for the semantic stream
#   TEXT_TOKEN_NUM + 2   : "start inferring" marker appended after the prompt
#   TEXT_TOKEN_NUM + 3.. : semantic tokens
# Ids below TEXT_TOKEN_NUM are presumably the text/phoneme vocabulary — TODO confirm
# against the stage-1 collate_fn.
TEXT_INPUT_LEN = x_dict.stage1.collate_fn.text_window
TEXT_TOKEN_NUM = x_dict.stage1.collate_fn.text_token_num
SEMANTIC_EOS_TOKEN, SEMANTIC_INFER_TOKEN = TEXT_TOKEN_NUM+1, TEXT_TOKEN_NUM+2

# --- Stage 2 (semantic -> coarse codec) token-id layout, as used by generate_stage2 ---
#   SEMANTIC_TOKEN_NUM       : value used to right-pad the semantic prompt
#   SEMANTIC_TOKEN_NUM + 1   : "start inferring" marker appended after the prompt
#   SEMANTIC_TOKEN_NUM + 2.. : COARSE_BOOK consecutive ranges of CODEC_TOKEN_NUM ids,
#                              one range per coarse codebook
COARSE_BOOK = x_dict.stage2.collate_fn.Q_size
SEMANTIC_TOKEN_NUM = x_dict.stage2.collate_fn.semantic_token_num
SEMANTIC_INPUT_LEN = x_dict.stage2.collate_fn.semantic_window
CODEC_TOKEN_NUM = x_dict.stage2.collate_fn.coarse_num
COARSE_INFER_TOKEN = SEMANTIC_TOKEN_NUM + 1

# Freshly constructed modules are in training mode; switch to eval mode explicitly so
# layers such as dropout are disabled while sampling. eval() is idempotent, so this is
# harmless if create_infer_model already does it.
stage1_model = create_infer_model(x_dict.stage1.model).cuda().eval()
stage2_model = create_infer_model(x_dict.stage2.model).cuda().eval()

tokenizer = PhonemeTokenizer()

In [None]:
def _load_stage_weights(model, stage_dir, prefix="model."):
    """Load the best checkpoint found under ``stage_dir`` into ``model``.

    The checkpoint's ``state_dict`` keys carry a wrapper prefix (``"model."``);
    everything up to and including the first occurrence of that prefix is removed
    so the keys line up with the bare inference model. Keys that do not contain
    the prefix are ignored; loading stays strict, so any parameter the model
    needs but the checkpoint lacks still raises.

    Args:
        model: module exposing ``load_state_dict``.
        stage_dir: directory handed to ``Bestckpt`` to pick the checkpoint file.
        prefix: wrapper prefix to strip from every key.

    Returns:
        Whatever ``model.load_state_dict`` returns (the key-matching report).
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only load checkpoints you trust.
    # map_location="cpu" keeps the rest of the checkpoint off the GPU; load_state_dict
    # copies the weights onto the model's own device.
    checkpoint = torch.load(Bestckpt(stage_dir), map_location="cpu")

    stripped = {}
    for key, value in checkpoint['state_dict'].items():
        # partition() cuts at the first occurrence only, so a later "model." inside
        # the remaining key is left intact.
        _, found, remainder = key.partition(prefix)
        if found:
            stripped[remainder] = value

    return model.load_state_dict(stripped)


_load_stage_weights(stage1_model, f"{start_path}/{x_dict.name}/stage_1")
_load_stage_weights(stage2_model, f"{start_path}/{x_dict.name}/stage_2")

In [None]:
def generate_stage1(
    x, 
    model,
    tempature = 0.60,
    max_steps = 512,
):
    """Autoregressively sample semantic tokens for a tokenized text prompt.

    The prompt is right-padded with ``TEXT_TOKEN_NUM`` up to ``TEXT_INPUT_LEN``
    columns and ``SEMANTIC_INFER_TOKEN`` is appended. Tokens are then sampled one
    at a time until the end-of-sequence slot is drawn or ``max_steps`` tokens have
    been produced. Sampling runs under ``torch.no_grad()`` so no autograd graph is
    kept alive through the KV cache.

    Args:
        x: integer tensor of text token ids. Assumed to be of shape (1, T): only
            batch row 0 of the logits is read.
        model: callable as ``model(tokens, use_cache=True, past_kv=cache)`` that
            returns ``(logits, cache)``.
        tempature: softmax temperature (parameter spelling kept for compatibility).
        max_steps: upper bound on the number of sampled tokens.

    Returns:
        Integer tensor of shape (1, n), n <= max_steps, holding the sampled
        semantic ids re-based so the first semantic token is 0.
    """
    first_semantic_id = TEXT_TOKEN_NUM + 3

    # NOTE(review): if the prompt is longer than TEXT_INPUT_LEN the pad width is
    # negative; F.pad is expected to crop in that case — confirm.
    x = F.pad(x, (0, TEXT_INPUT_LEN-x.shape[1]), mode='constant', value=TEXT_TOKEN_NUM)
    infer_token = torch.tensor([[SEMANTIC_INFER_TOKEN]], dtype=x.dtype, device=x.device)
    x = torch.cat([x, infer_token], dim=1)

    text_len = x.shape[1]
    kv_cache = None

    with torch.no_grad():
        for _ in tqdm.trange(max_steps):

            # The full prompt is fed once; afterwards only the newest token is needed
            # because earlier positions live in the KV cache.
            x_input = x if kv_cache is None else x[:, [-1]]

            logits, kv_cache = model(x_input, use_cache=True, past_kv=kv_cache)

            # Read the last sequence position: identical to position 0 when the model
            # emits a single position, and still the right one if it emits them all.
            step_logits = logits[0, -1]

            # Candidate set = every semantic token, with the EOS logit moved to the end.
            relevant_logits = torch.hstack(
                (step_logits[first_semantic_id:], step_logits[[SEMANTIC_EOS_TOKEN]])
            )

            probs = F.softmax(relevant_logits / tempature, dim=-1)
            item_next = torch.multinomial(probs, num_samples=1)

            # The last candidate slot is EOS.
            if item_next.item() == len(relevant_logits) - 1:
                break

            x = torch.cat((x, item_next[None] + first_semantic_id), dim=1)

    return x[:, text_len:] - first_semantic_id

In [None]:
def generate_stage2(
    x, 
    model,
    tempature = 0.6,
    max_steps = 768
):
    """Autoregressively sample coarse codec tokens for a semantic prompt.

    The prompt is right-padded with ``SEMANTIC_TOKEN_NUM`` up to
    ``SEMANTIC_INPUT_LEN`` columns and ``COARSE_INFER_TOKEN`` is appended. Tokens
    are then sampled round-robin over the ``COARSE_BOOK`` codebooks: step ``i``
    only considers the id range belonging to codebook ``i % COARSE_BOOK``.
    Sampling runs under ``torch.no_grad()`` so no autograd graph is kept alive
    through the KV cache.

    Args:
        x: integer tensor of semantic token ids. Assumed to be of shape (1, T):
            only batch row 0 of the logits is read.
        model: callable as ``model(tokens, use_cache=True, past_kv=cache)`` that
            returns ``(logits, cache)``.
        tempature: softmax temperature (parameter spelling kept for compatibility).
        max_steps: number of tokens to sample. It is rounded down to a multiple of
            ``COARSE_BOOK`` so the result always consists of complete frames.

    Returns:
        Integer tensor of shape (COARSE_BOOK, max_steps // COARSE_BOOK); row ``q``
        holds the codes of codebook ``q``, re-based so each codebook starts at 0.
    """
    codebook_base = SEMANTIC_TOKEN_NUM + 2

    # A trailing partial frame cannot be reshaped into (frames, COARSE_BOOK), so
    # only whole frames are generated.
    n_frames = max_steps // COARSE_BOOK
    total_steps = n_frames * COARSE_BOOK

    # NOTE(review): if the prompt is longer than SEMANTIC_INPUT_LEN the pad width is
    # negative; F.pad is expected to crop in that case — confirm.
    x = F.pad(x, (0, SEMANTIC_INPUT_LEN-x.shape[1]), mode='constant', value=SEMANTIC_TOKEN_NUM)
    infer_token = torch.tensor([[COARSE_INFER_TOKEN]], dtype=x.dtype, device=x.device)
    x = torch.cat([x, infer_token], dim=1)

    semantic_len = x.shape[1]
    kv_cache = None

    with torch.no_grad():
        for i in tqdm.trange(total_steps):

            # First id of the codebook that is due at this step.
            start = codebook_base + (i % COARSE_BOOK) * CODEC_TOKEN_NUM

            # The full prompt is fed once; afterwards only the newest token is needed
            # because earlier positions live in the KV cache.
            x_input = x if kv_cache is None else x[:, [-1]]

            logits, kv_cache = model(x_input, use_cache=True, past_kv=kv_cache)

            # Read the last sequence position: identical to position 0 when the model
            # emits a single position, and still the right one if it emits them all.
            relevant_logits = logits[0, -1, start : start + CODEC_TOKEN_NUM]

            probs = F.softmax(relevant_logits / tempature, dim=-1)
            item_next = torch.multinomial(probs, num_samples=1)
            x = torch.cat((x, item_next[None] + start), dim=1)

    # One row per frame, one column per codebook; then remove each codebook's id offset.
    frames = x[:, semantic_len:].reshape(n_frames, COARSE_BOOK)
    offsets = codebook_base + CODEC_TOKEN_NUM * torch.arange(
        COARSE_BOOK, dtype=frames.dtype, device=frames.device
    )

    return (frames - offsets).T

In [None]:
# Sentence to synthesize.
tgt_text = "At a given signal, they reenacted the event. Baker's movements were timed with a stopwatch."

# Tokenize the text and wrap it as a batch of one on the GPU.
tokens = tokenizer(tgt_text)
dummy_tokenized = torch.tensor([tokens], device="cuda")

# Stage 1 turns the text tokens into semantic tokens; stage 2 turns those into coarse codes.
dummy_semantic = generate_stage1(x=dummy_tokenized, model=stage1_model)
dummy_coarse = generate_stage2(x=dummy_semantic, model=stage2_model)

In [None]:
# generate_fine is handed the coarse codes as a NumPy array on the host.
coarse_codes = dummy_coarse.detach().cpu().numpy()
dummy_fine = bark.generate_fine(coarse_codes, history_prompt=None)
audio_array = bark.codec_decode(dummy_fine)

# play text in notebook
SAMPLE_RATE = 24000
Audio(audio_array, rate=SAMPLE_RATE)

# write_wav("bark_generation.wav", 24000, audio_array)