<a href="https://colab.research.google.com/github/Joab-S/kokoro-voice/blob/main/Kokoro_TTS_and_Voice_Blending.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [30]:
!git lfs install
!git clone https://huggingface.co/hexgrad/Kokoro-82M
%cd Kokoro-82M
!apt-get -qq -y install espeak-ng > /dev/null 2>&1
!pip install -q phonemizer torch transformers scipy munch
!pip install -U kokoro-onnx

!pip install loguru
# from kokoro import en, espeak, ja, zh
!pip install -q "misaki[en, ja, zh]"

Updated git hooks.
Git LFS initialized.
Cloning into 'Kokoro-82M'...
remote: Enumerating objects: 421, done.[K
remote: Counting objects: 100% (421/421), done.[K
remote: Compressing objects: 100% (199/199), done.[K
remote: Total 421 (delta 240), reused 391 (delta 221), pack-reused 0 (from 0)[K
Receiving objects: 100% (421/421), 1.83 MiB | 2.88 MiB/s, done.
Resolving deltas: 100% (240/240), done.
Filtering content: 100% (61/61), 344.32 MiB | 59.35 MiB/s, done.
/content/Kokoro-82M/Kokoro-82M/Kokoro-82M/Kokoro-82M
Collecting loguru
  Downloading loguru-0.7.3-py3-none-any.whl.metadata (22 kB)
Downloading loguru-0.7.3-py3-none-any.whl (61 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.6/61.6 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: loguru
Successfully installed loguru-0.7.3


In [31]:
import numpy as np
from scipy.io.wavfile import write
from IPython.display import display, Audio

import torch

In [45]:
# ADAPTED from https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
# from kokoro.custom_stft import CustomSTFT
from torch.nn.utils import weight_norm
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)

def get_padding(kernel_size, dilation=1):
    return int((kernel_size*dilation - dilation)/2)


class AdaIN1d(nn.Module):
    def __init__(self, style_dim, num_features):
        super().__init__()
        # affine should be False, however there's a bug in the old torch.onnx.export (not newer dynamo) that causes the channel dimension to be lost if affine=False. When affine is true, there's additional learnably parameters. This shouldn't really matter setting it to True, since we're in inference mode
        self.norm = nn.InstanceNorm1d(num_features, affine=True)
        self.fc = nn.Linear(style_dim, num_features*2)

    def forward(self, x, s):
        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
        return (1 + gamma) * self.norm(x) + beta


class AdaINResBlock1(nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
        super(AdaINResBlock1, self).__init__()
        self.convs1 = nn.ModuleList([
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                                  padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                                  padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                                  padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)
        self.convs2 = nn.ModuleList([
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
                                  padding=get_padding(kernel_size, 1))),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
                                  padding=get_padding(kernel_size, 1))),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
                                  padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)
        self.adain1 = nn.ModuleList([
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
        ])
        self.adain2 = nn.ModuleList([
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
        ])
        self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
        self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])

    def forward(self, x, s):
        for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
            xt = n1(x, s)
            xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2)  # Snake1D
            xt = c1(xt)
            xt = n2(xt, s)
            xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2)  # Snake1D
            xt = c2(xt)
            x = xt + x
        return x


class TorchSTFT(nn.Module):
    def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        assert window == 'hann', window
        self.window = torch.hann_window(win_length, periodic=True, dtype=torch.float32)

    def transform(self, input_data):
        forward_transform = torch.stft(
            input_data,
            self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
            return_complex=True)
        return torch.abs(forward_transform), torch.angle(forward_transform)

    def inverse(self, magnitude, phase):
        inverse_transform = torch.istft(
            magnitude * torch.exp(phase * 1j),
            self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
        return inverse_transform.unsqueeze(-2)  # unsqueeze to stay consistent with conv_transpose1d implementation

    def forward(self, input_data):
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction


class SineGen(nn.Module):
    """ Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine-wavefrom (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_thoreshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SinGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
        segment is always sin(torch.pi) or cos(0)
    """
    def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0,
                 flag_for_pulse=False):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.flag_for_pulse = flag_for_pulse
        self.upsample_scale = upsample_scale

    def _f02uv(self, f0):
        # generate uv signal
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    def _f02sine(self, f0_values):
        """ f0_values: (batchsize, length, dim)
            where dim indicates fundamental tone and overtones
        """
        # convert to F0 in rad. The interger part n can be ignored
        # because 2 * torch.pi * n doesn't affect phase
        rad_values = (f0_values / self.sampling_rate) % 1
        # initial phase noise (no noise for fundamental component)
        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
        # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
        if not self.flag_for_pulse:
            rad_values = F.interpolate(rad_values.transpose(1, 2), scale_factor=1/self.upsample_scale, mode="linear").transpose(1, 2)
            phase = torch.cumsum(rad_values, dim=1) * 2 * torch.pi
            phase = F.interpolate(phase.transpose(1, 2) * self.upsample_scale, scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
            sines = torch.sin(phase)
        else:
            # If necessary, make sure that the first time step of every
            # voiced segments is sin(pi) or cos(0)
            # This is used for pulse-train generation
            # identify the last time step in unvoiced segments
            uv = self._f02uv(f0_values)
            uv_1 = torch.roll(uv, shifts=-1, dims=1)
            uv_1[:, -1, :] = 1
            u_loc = (uv < 1) * (uv_1 > 0)
            # get the instantanouse phase
            tmp_cumsum = torch.cumsum(rad_values, dim=1)
            # different batch needs to be processed differently
            for idx in range(f0_values.shape[0]):
                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
                # stores the accumulation of i.phase within
                # each voiced segments
                tmp_cumsum[idx, :, :] = 0
                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
            # rad_values - tmp_cumsum: remove the accumulation of i.phase
            # within the previous voiced segment.
            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
            # get the sines
            sines = torch.cos(i_phase * 2 * torch.pi)
        return sines

    def forward(self, f0):
        """ sine_tensor, uv = forward(f0)
        input F0: tensor(batchsize=1, length, dim=1)
                  f0 for unvoiced steps should be 0
        output sine_tensor: tensor(batchsize=1, length, dim)
        output uv: tensor(batchsize=1, length, 1)
        """
        f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
        # fundamental component
        fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
        # generate sine waveforms
        sine_waves = self._f02sine(fn) * self.sine_amp
        # generate uv signal
        # uv = torch.ones(f0.shape)
        # uv = uv * (f0 > self.voiced_threshold)
        uv = self._f02uv(f0)
        # noise: for unvoiced should be similar to sine_amp
        #        std = self.sine_amp/3 -> max value ~ self.sine_amp
        #        for voiced regions is self.noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)
        # first: set the unvoiced part to 0 by uv
        # then: additive noise
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise


class SourceModuleHnNSF(nn.Module):
    """ SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling_rate in Hz
    harmonic_num: number of harmonic above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that amplitude of noise in unvoiced is decided
        by sine_amp
    voiced_threshold: threhold to set U/V given F0 (default: 0)
    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length 1)
    uv (batchsize, length, 1)
    """
    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = add_noise_std
        # to produce sine waveforms
        self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshod)
        # to merge source harmonics into a single excitation
        self.l_linear = nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length 1)
        """
        # source for harmonic branch
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
        # source for noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv


class Generator(nn.Module):
    def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, disable_complex=False):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.m_source = SourceModuleHnNSF(
                    sampling_rate=24000,
                    upsample_scale=math.prod(upsample_rates) * gen_istft_hop_size,
                    harmonic_num=8, voiced_threshod=10)
        self.f0_upsamp = nn.Upsample(scale_factor=math.prod(upsample_rates) * gen_istft_hop_size)
        self.noise_convs = nn.ModuleList()
        self.noise_res = nn.ModuleList()
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(weight_norm(
                nn.ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
                                   k, u, padding=(k-u)//2)))
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel//(2**(i+1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes,resblock_dilation_sizes)):
                self.resblocks.append(AdaINResBlock1(ch, k, d, style_dim))
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            if i + 1 < len(upsample_rates):
                stride_f0 = math.prod(upsample_rates[i + 1:])
                self.noise_convs.append(nn.Conv1d(
                    gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
                self.noise_res.append(AdaINResBlock1(c_cur, 7, [1,3,5], style_dim))
            else:
                self.noise_convs.append(nn.Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1))
                self.noise_res.append(AdaINResBlock1(c_cur, 11, [1,3,5], style_dim))
        self.post_n_fft = gen_istft_n_fft
        self.conv_post = weight_norm(nn.Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = nn.ReflectionPad1d((1, 0))
        self.stft = (
            CustomSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
            if disable_complex
            else TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
        )

    def forward(self, x, s, f0):
        with torch.no_grad():
            f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
            har_source, noi_source, uv = self.m_source(f0)
            har_source = har_source.transpose(1, 2).squeeze(1)
            har_spec, har_phase = self.stft.transform(har_source)
            har = torch.cat([har_spec, har_phase], dim=1)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, negative_slope=0.1)
            x_source = self.noise_convs[i](har)
            x_source = self.noise_res[i](x_source, s)
            x = self.ups[i](x)
            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)
            x = x + x_source
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x, s)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x, s)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
        phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
        return self.stft.inverse(spec, phase)


class UpSample1d(nn.Module):
    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        if self.layer_type == 'none':
            return x
        else:
            return F.interpolate(x, scale_factor=2, mode='nearest')


class AdainResBlk1d(nn.Module):
    def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2), upsample='none', dropout_p=0.0):
        super().__init__()
        self.actv = actv
        self.upsample_type = upsample
        self.upsample = UpSample1d(upsample)
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)
        self.dropout = nn.Dropout(dropout_p)
        if upsample == 'none':
            self.pool = nn.Identity()
        else:
            self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))

    def _build_weights(self, dim_in, dim_out, style_dim):
        self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
        self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
        self.norm1 = AdaIN1d(style_dim, dim_in)
        self.norm2 = AdaIN1d(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))

    def _shortcut(self, x):
        x = self.upsample(x)
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x, s):
        x = self.norm1(x, s)
        x = self.actv(x)
        x = self.pool(x)
        x = self.conv1(self.dropout(x))
        x = self.norm2(x, s)
        x = self.actv(x)
        x = self.conv2(self.dropout(x))
        return x

    def forward(self, x, s):
        out = self._residual(x, s)
        out = (out + self._shortcut(x)) * torch.rsqrt(torch.tensor(2))
        return out


class Decoder(nn.Module):
    def __init__(self, dim_in, style_dim, dim_out,
                 resblock_kernel_sizes,
                 upsample_rates,
                 upsample_initial_channel,
                 resblock_dilation_sizes,
                 upsample_kernel_sizes,
                 gen_istft_n_fft, gen_istft_hop_size,
                 disable_complex=False):
        super().__init__()
        self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
        self.decode = nn.ModuleList()
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
        self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
        self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
        self.asr_res = nn.Sequential(weight_norm(nn.Conv1d(512, 64, kernel_size=1)))
        self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
                                   upsample_initial_channel, resblock_dilation_sizes,
                                   upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, disable_complex=disable_complex)

    def forward(self, asr, F0_curve, N, s):
        F0 = self.F0_conv(F0_curve.unsqueeze(1))
        N = self.N_conv(N.unsqueeze(1))
        x = torch.cat([asr, F0, N], axis=1)
        x = self.encode(x, s)
        asr_res = self.asr_res(asr)
        res = True
        for block in self.decode:
            if res:
                x = torch.cat([x, asr_res, F0, N], axis=1)
            x = block(x, s)
            if block.upsample_type != "none":
                res = False
        x = self.generator(x, s, F0_curve)
        return x

In [40]:
# https://github.com/yl4579/StyleTTS2/blob/main/models.py
from torch.nn.utils import weight_norm
from transformers import AlbertModel
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearNorm(nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(LinearNorm, self).__init__()
        self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias)
        nn.init.xavier_uniform_(self.linear_layer.weight, gain=nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        return self.linear_layer(x)


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)


class TextEncoder(nn.Module):
    def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
        super().__init__()
        self.embedding = nn.Embedding(n_symbols, channels)
        padding = (kernel_size - 1) // 2
        self.cnn = nn.ModuleList()
        for _ in range(depth):
            self.cnn.append(nn.Sequential(
                weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
                LayerNorm(channels),
                actv,
                nn.Dropout(0.2),
            ))
        self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)

    def forward(self, x, input_lengths, m):
        x = self.embedding(x)  # [B, T, emb]
        x = x.transpose(1, 2)  # [B, emb, T]
        m = m.unsqueeze(1)
        x.masked_fill_(m, 0.0)
        for c in self.cnn:
            x = c(x)
            x.masked_fill_(m, 0.0)
        x = x.transpose(1, 2)  # [B, T, chn]
        lengths = input_lengths if input_lengths.device == torch.device('cpu') else input_lengths.to('cpu')
        x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        self.lstm.flatten_parameters()
        x, _ = self.lstm(x)
        x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
        x = x.transpose(-1, -2)
        x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]], device=x.device)
        x_pad[:, :, :x.shape[-1]] = x
        x = x_pad
        x.masked_fill_(m, 0.0)
        return x


class AdaLayerNorm(nn.Module):
    def __init__(self, style_dim, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps
        self.fc = nn.Linear(style_dim, channels*2)

    def forward(self, x, s):
        x = x.transpose(-1, -2)
        x = x.transpose(1, -1)
        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
        gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), eps=self.eps)
        x = (1 + gamma) * x + beta
        return x.transpose(1, -1).transpose(-1, -2)


class ProsodyPredictor(nn.Module):
    def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
        super().__init__()
        self.text_encoder = DurationEncoder(sty_dim=style_dim, d_model=d_hid,nlayers=nlayers, dropout=dropout)
        self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
        self.duration_proj = LinearNorm(d_hid, max_dur)
        self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
        self.F0 = nn.ModuleList()
        self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
        self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
        self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
        self.N = nn.ModuleList()
        self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
        self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
        self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
        self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
        self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)

    def forward(self, texts, style, text_lengths, alignment, m):
        d = self.text_encoder(texts, style, text_lengths, m)
        m = m.unsqueeze(1)
        lengths = text_lengths if text_lengths.device == torch.device('cpu') else text_lengths.to('cpu')
        x = nn.utils.rnn.pack_padded_sequence(d, lengths, batch_first=True, enforce_sorted=False)
        self.lstm.flatten_parameters()
        x, _ = self.lstm(x)
        x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
        x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]], device=x.device)
        x_pad[:, :x.shape[1], :] = x
        x = x_pad
        duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=False))
        en = (d.transpose(-1, -2) @ alignment)
        return duration.squeeze(-1), en

    def F0Ntrain(self, x, s):
        x, _ = self.shared(x.transpose(-1, -2))
        F0 = x.transpose(-1, -2)
        for block in self.F0:
            F0 = block(F0, s)
        F0 = self.F0_proj(F0)
        N = x.transpose(-1, -2)
        for block in self.N:
            N = block(N, s)
        N = self.N_proj(N)
        return F0.squeeze(1), N.squeeze(1)


class DurationEncoder(nn.Module):
    def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
        super().__init__()
        self.lstms = nn.ModuleList()
        for _ in range(nlayers):
            self.lstms.append(nn.LSTM(d_model + sty_dim, d_model // 2, num_layers=1, batch_first=True, bidirectional=True, dropout=dropout))
            self.lstms.append(AdaLayerNorm(sty_dim, d_model))
        self.dropout = dropout
        self.d_model = d_model
        self.sty_dim = sty_dim

    def forward(self, x, style, text_lengths, m):
        masks = m
        x = x.permute(2, 0, 1)
        s = style.expand(x.shape[0], x.shape[1], -1)
        x = torch.cat([x, s], axis=-1)
        x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
        x = x.transpose(0, 1)
        x = x.transpose(-1, -2)
        for block in self.lstms:
            if isinstance(block, AdaLayerNorm):
                x = block(x.transpose(-1, -2), style).transpose(-1, -2)
                x = torch.cat([x, s.permute(1, 2, 0)], axis=1)
                x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
            else:
                lengths = text_lengths if text_lengths.device == torch.device('cpu') else text_lengths.to('cpu')
                x = x.transpose(-1, -2)
                x = nn.utils.rnn.pack_padded_sequence(
                    x, lengths, batch_first=True, enforce_sorted=False)
                block.flatten_parameters()
                x, _ = block(x)
                x, _ = nn.utils.rnn.pad_packed_sequence(
                    x, batch_first=True)
                x = F.dropout(x, p=self.dropout, training=False)
                x = x.transpose(-1, -2)
                x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]], device=x.device)
                x_pad[:, :, :x.shape[-1]] = x
                x = x_pad

        return x.transpose(-1, -2)


# https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
class CustomAlbert(AlbertModel):
    def forward(self, *args, **kwargs):
        outputs = super().forward(*args, **kwargs)
        return outputs.last_hidden_state

In [41]:
from dataclasses import dataclass
from huggingface_hub import hf_hub_download
from loguru import logger
from transformers import AlbertConfig
from typing import Dict, Optional, Union
import json
import torch

class KModel(torch.nn.Module):
    '''
    KModel is a torch.nn.Module with 2 main responsibilities:
    1. Init weights, downloading config.json + model.pth from HF if needed
    2. forward(phonemes: str, ref_s: FloatTensor) -> (audio: FloatTensor)

    You likely only need one KModel instance, and it can be reused across
    multiple KPipelines to avoid redundant memory allocation.

    Unlike KPipeline, KModel is language-blind.

    KModel stores self.vocab and thus knows how to map phonemes -> input_ids,
    so there is no need to repeatedly download config.json outside of KModel.
    '''

    MODEL_NAMES = {
        'hexgrad/Kokoro-82M': 'kokoro-v1_0.pth',
        'hexgrad/Kokoro-82M-v1.1-zh': 'kokoro-v1_1-zh.pth',
    }

    def __init__(
        self,
        repo_id: Optional[str] = None,
        config: Union[Dict, str, None] = None,
        model: Optional[str] = None,
        disable_complex: bool = False
    ):
        super().__init__()
        if repo_id is None:
            repo_id = 'hexgrad/Kokoro-82M'
            print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
        self.repo_id = repo_id
        if not isinstance(config, dict):
            if not config:
                logger.debug("No config provided, downloading from HF")
                config = hf_hub_download(repo_id=repo_id, filename='config.json')
            with open(config, 'r', encoding='utf-8') as r:
                config = json.load(r)
                logger.debug(f"Loaded config: {config}")
        self.vocab = config['vocab']
        self.bert = CustomAlbert(AlbertConfig(vocab_size=config['n_token'], **config['plbert']))
        self.bert_encoder = torch.nn.Linear(self.bert.config.hidden_size, config['hidden_dim'])
        self.context_length = self.bert.config.max_position_embeddings
        self.predictor = ProsodyPredictor(
            style_dim=config['style_dim'], d_hid=config['hidden_dim'],
            nlayers=config['n_layer'], max_dur=config['max_dur'], dropout=config['dropout']
        )
        self.text_encoder = TextEncoder(
            channels=config['hidden_dim'], kernel_size=config['text_encoder_kernel_size'],
            depth=config['n_layer'], n_symbols=config['n_token']
        )
        self.decoder = Decoder(
            dim_in=config['hidden_dim'], style_dim=config['style_dim'],
            dim_out=config['n_mels'], disable_complex=disable_complex, **config['istftnet']
        )
        if not model:
            model = hf_hub_download(repo_id=repo_id, filename=KModel.MODEL_NAMES[repo_id])
        for key, state_dict in torch.load(model, map_location='cpu', weights_only=True).items():
            assert hasattr(self, key), key
            try:
                getattr(self, key).load_state_dict(state_dict)
            except:
                logger.debug(f"Did not load {key} from state_dict")
                state_dict = {k[7:]: v for k, v in state_dict.items()}
                getattr(self, key).load_state_dict(state_dict, strict=False)

    @property
    def device(self):
        return self.bert.device

    @dataclass
    class Output:
        audio: torch.FloatTensor
        pred_dur: Optional[torch.LongTensor] = None

    @torch.no_grad()
    def forward_with_tokens(
        self,
        input_ids: torch.LongTensor,
        ref_s: torch.FloatTensor,
        speed: float = 1
    ) -> tuple[torch.FloatTensor, torch.LongTensor]:
        input_lengths = torch.full(
            (input_ids.shape[0],),
            input_ids.shape[-1],
            device=input_ids.device,
            dtype=torch.long
        )

        text_mask = torch.arange(input_lengths.max()).unsqueeze(0).expand(input_lengths.shape[0], -1).type_as(input_lengths)
        text_mask = torch.gt(text_mask+1, input_lengths.unsqueeze(1)).to(self.device)
        bert_dur = self.bert(input_ids, attention_mask=(~text_mask).int())
        d_en = self.bert_encoder(bert_dur).transpose(-1, -2)
        s = ref_s[:, 128:]
        d = self.predictor.text_encoder(d_en, s, input_lengths, text_mask)
        x, _ = self.predictor.lstm(d)
        duration = self.predictor.duration_proj(x)
        duration = torch.sigmoid(duration).sum(axis=-1) / speed
        pred_dur = torch.round(duration).clamp(min=1).long().squeeze()
        indices = torch.repeat_interleave(torch.arange(input_ids.shape[1], device=self.device), pred_dur)
        pred_aln_trg = torch.zeros((input_ids.shape[1], indices.shape[0]), device=self.device)
        pred_aln_trg[indices, torch.arange(indices.shape[0])] = 1
        pred_aln_trg = pred_aln_trg.unsqueeze(0).to(self.device)
        en = d.transpose(-1, -2) @ pred_aln_trg
        F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
        t_en = self.text_encoder(input_ids, input_lengths, text_mask)
        asr = t_en @ pred_aln_trg
        audio = self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze()
        return audio, pred_dur

    def forward(
        self,
        phonemes: str,
        ref_s: torch.FloatTensor,
        speed: float = 1,
        return_output: bool = False
    ) -> Union['KModel.Output', torch.FloatTensor]:
        input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes)))
        logger.debug(f"phonemes: {phonemes} -> input_ids: {input_ids}")
        assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length)
        input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(self.device)
        ref_s = ref_s.to(self.device)
        audio, pred_dur = self.forward_with_tokens(input_ids, ref_s, speed)
        audio = audio.squeeze().cpu()
        pred_dur = pred_dur.cpu() if pred_dur is not None else None
        logger.debug(f"pred_dur: {pred_dur}")
        return self.Output(audio=audio, pred_dur=pred_dur) if return_output else audio

class KModelForONNX(torch.nn.Module):
    def __init__(self, kmodel: KModel):
        super().__init__()
        self.kmodel = kmodel

    def forward(
        self,
        input_ids: torch.LongTensor,
        ref_s: torch.FloatTensor,
        speed: float = 1
    ) -> tuple[torch.FloatTensor, torch.LongTensor]:
        waveform, duration = self.kmodel.forward_with_tokens(input_ids, ref_s, speed)
        return waveform, duration

In [42]:
from dataclasses import dataclass
from huggingface_hub import hf_hub_download
from loguru import logger
from misaki import en, espeak
from typing import Callable, Generator, List, Optional, Tuple, Union
import re
import torch
import os

ALIASES = {
    'en-us': 'a',
    'en-gb': 'b',
    'es': 'e',
    'fr-fr': 'f',
    'hi': 'h',
    'it': 'i',
    'pt-br': 'p',
    'ja': 'j',
    'zh': 'z',
}

LANG_CODES = dict(
    # pip install misaki[en]
    a='American English',
    b='British English',

    # espeak-ng
    e='es',
    f='fr-fr',
    h='hi',
    i='it',
    p='pt-br',

    # pip install misaki[ja]
    j='Japanese',

    # pip install misaki[zh]
    z='Mandarin Chinese',
)

class KPipeline:
    '''
    KPipeline is a language-aware support class with 2 main responsibilities:
    1. Perform language-specific G2P, mapping (and chunking) text -> phonemes
    2. Manage and store voices, lazily downloaded from HF if needed

    You are expected to have one KPipeline per language. If you have multiple
    KPipelines, you should reuse one KModel instance across all of them.

    KPipeline is designed to work with a KModel, but this is not required.
    There are 2 ways to pass an existing model into a pipeline:
    1. On init: us_pipeline = KPipeline(lang_code='a', model=model)
    2. On call: us_pipeline(text, voice, model=model)

    By default, KPipeline will automatically initialize its own KModel. To
    suppress this, construct a "quiet" KPipeline with model=False.

    A "quiet" KPipeline yields (graphemes, phonemes, None) without generating
    any audio. You can use this to phonemize and chunk your text in advance.

    A "loud" KPipeline _with_ a model yields (graphemes, phonemes, audio).
    '''
    def __init__(
        self,
        lang_code: str,
        repo_id: Optional[str] = None,
        model: Union[KModel, bool] = True,
        trf: bool = False,
        en_callable: Optional[Callable[[str], str]] = None,
        device: Optional[str] = None
    ):
        """Initialize a KPipeline.

        Args:
            lang_code: Language code for G2P processing
            model: KModel instance, True to create new model, False for no model
            trf: Whether to use transformer-based G2P
            device: Override default device selection ('cuda' or 'cpu', or None for auto)
                   If None, will auto-select cuda if available
                   If 'cuda' and not available, will explicitly raise an error
        """
        if repo_id is None:
            repo_id = 'hexgrad/Kokoro-82M'
            print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
        self.repo_id = repo_id
        lang_code = lang_code.lower()
        lang_code = ALIASES.get(lang_code, lang_code)
        assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
        self.lang_code = lang_code
        self.model = None
        if isinstance(model, KModel):
            self.model = model
        elif model:
            if device == 'cuda' and not torch.cuda.is_available():
                raise RuntimeError("CUDA requested but not available")
            if device == 'mps' and not torch.backends.mps.is_available():
                raise RuntimeError("MPS requested but not available")
            if device == 'mps' and os.environ.get('PYTORCH_ENABLE_MPS_FALLBACK') != '1':
                raise RuntimeError("MPS requested but fallback not enabled")
            if device is None:
                if torch.cuda.is_available():
                    device = 'cuda'
                elif os.environ.get('PYTORCH_ENABLE_MPS_FALLBACK') == '1' and torch.backends.mps.is_available():
                    device = 'mps'
                else:
                    device = 'cpu'
            try:
                self.model = KModel(repo_id=repo_id).to(device).eval()
            except RuntimeError as e:
                if device == 'cuda':
                    raise RuntimeError(f"""Failed to initialize model on CUDA: {e}.
                                       Try setting device='cpu' or check CUDA installation.""")
                raise
        self.voices = {}
        if lang_code in 'ab':
            try:
                fallback = espeak.EspeakFallback(british=lang_code=='b')
            except Exception as e:
                logger.warning("EspeakFallback not Enabled: OOD words will be skipped")
                logger.warning({str(e)})
                fallback = None
            self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback, unk='')
        elif lang_code == 'j':
            try:
                from misaki import ja
                self.g2p = ja.JAG2P()
            except ImportError:
                logger.error("You need to `pip install misaki[ja]` to use lang_code='j'")
                raise
        elif lang_code == 'z':
            try:
                from misaki import zh
                self.g2p = zh.ZHG2P(
                    version=None if repo_id.endswith('/Kokoro-82M') else '1.1',
                    en_callable=en_callable
                )
            except ImportError:
                logger.error("You need to `pip install misaki[zh]` to use lang_code='z'")
                raise
        else:
            language = LANG_CODES[lang_code]
            logger.warning(f"Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.")
            self.g2p = espeak.EspeakG2P(language=language)

    def load_single_voice(self, voice: str):
        if voice in self.voices:
            return self.voices[voice]
        if voice.endswith('.pt'):
            f = voice
        else:
            f = hf_hub_download(repo_id=self.repo_id, filename=f'voices/{voice}.pt')
            if not voice.startswith(self.lang_code):
                v = LANG_CODES.get(voice, voice)
                p = LANG_CODES.get(self.lang_code, self.lang_code)
                logger.warning(f'Language mismatch, loading {v} voice into {p} pipeline.')
        pack = torch.load(f, weights_only=True)
        self.voices[voice] = pack
        return pack

    """
    load_voice is a helper function that lazily downloads and loads a voice:
    Single voice can be requested (e.g. 'af_bella') or multiple voices (e.g. 'af_bella,af_jessica').
    If multiple voices are requested, they are averaged.
    Delimiter is optional and defaults to ','.
    """
    def load_voice(self, voice: Union[str, torch.FloatTensor], delimiter: str = ",") -> torch.FloatTensor:
        if isinstance(voice, torch.FloatTensor):
            return voice
        if voice in self.voices:
            return self.voices[voice]
        logger.debug(f"Loading voice: {voice}")
        packs = [self.load_single_voice(v) for v in voice.split(delimiter)]
        if len(packs) == 1:
            return packs[0]
        self.voices[voice] = torch.mean(torch.stack(packs), dim=0)
        return self.voices[voice]

    @staticmethod
    def tokens_to_ps(tokens: List[en.MToken]) -> str:
        return ''.join(t.phonemes + (' ' if t.whitespace else '') for t in tokens).strip()

    @staticmethod
    def waterfall_last(
        tokens: List[en.MToken],
        next_count: int,
        waterfall: List[str] = ['!.?…', ':;', ',—'],
        bumps: List[str] = [')', '”']
    ) -> int:
        for w in waterfall:
            z = next((i for i, t in reversed(list(enumerate(tokens))) if t.phonemes in set(w)), None)
            if z is None:
                continue
            z += 1
            if z < len(tokens) and tokens[z].phonemes in bumps:
                z += 1
            if next_count - len(KPipeline.tokens_to_ps(tokens[:z])) <= 510:
                return z
        return len(tokens)

    @staticmethod
    def tokens_to_text(tokens: List[en.MToken]) -> str:
        return ''.join(t.text + t.whitespace for t in tokens).strip()

    def en_tokenize(
        self,
        tokens: List[en.MToken]
    ) -> Generator[Tuple[str, str, List[en.MToken]], None, None]:
        tks = []
        pcount = 0
        for t in tokens:
            # American English: ɾ => T
            t.phonemes = '' if t.phonemes is None else t.phonemes#.replace('ɾ', 'T')
            next_ps = t.phonemes + (' ' if t.whitespace else '')
            next_pcount = pcount + len(next_ps.rstrip())
            if next_pcount > 510:
                z = KPipeline.waterfall_last(tks, next_pcount)
                text = KPipeline.tokens_to_text(tks[:z])
                logger.debug(f"Chunking text at {z}: '{text[:30]}{'...' if len(text) > 30 else ''}'")
                ps = KPipeline.tokens_to_ps(tks[:z])
                yield text, ps, tks[:z]
                tks = tks[z:]
                pcount = len(KPipeline.tokens_to_ps(tks))
                if not tks:
                    next_ps = next_ps.lstrip()
            tks.append(t)
            pcount += len(next_ps)
        if tks:
            text = KPipeline.tokens_to_text(tks)
            ps = KPipeline.tokens_to_ps(tks)
            yield ''.join(text).strip(), ''.join(ps).strip(), tks

    @staticmethod
    def infer(
        model: KModel,
        ps: str,
        pack: torch.FloatTensor,
        speed: Union[float, Callable[[int], float]] = 1
    ) -> KModel.Output:
        if callable(speed):
            speed = speed(len(ps))
        return model(ps, pack[len(ps)-1], speed, return_output=True)

    def generate_from_tokens(
        self,
        tokens: Union[str, List[en.MToken]],
        voice: str,
        speed: float = 1,
        model: Optional[KModel] = None
    ) -> Generator['KPipeline.Result', None, None]:
        """Generate audio from either raw phonemes or pre-processed tokens.

        Args:
            tokens: Either a phoneme string or list of pre-processed MTokens
            voice: The voice to use for synthesis
            speed: Speech speed modifier (default: 1)
            model: Optional KModel instance (uses pipeline's model if not provided)

        Yields:
            KPipeline.Result containing the input tokens and generated audio

        Raises:
            ValueError: If no voice is provided or token sequence exceeds model limits
        """
        model = model or self.model
        if model and voice is None:
            raise ValueError('Specify a voice: pipeline.generate_from_tokens(..., voice="af_heart")')

        pack = self.load_voice(voice).to(model.device) if model else None

        # Handle raw phoneme string
        if isinstance(tokens, str):
            logger.debug("Processing phonemes from raw string")
            if len(tokens) > 510:
                raise ValueError(f'Phoneme string too long: {len(tokens)} > 510')
            output = KPipeline.infer(model, tokens, pack, speed) if model else None
            yield self.Result(graphemes='', phonemes=tokens, output=output)
            return

        logger.debug("Processing MTokens")
        # Handle pre-processed tokens
        for gs, ps, tks in self.en_tokenize(tokens):
            if not ps:
                continue
            elif len(ps) > 510:
                logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
                logger.warning("Truncating to 510 characters")
                ps = ps[:510]
            output = KPipeline.infer(model, ps, pack, speed) if model else None
            if output is not None and output.pred_dur is not None:
                KPipeline.join_timestamps(tks, output.pred_dur)
            yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output)

    @staticmethod
    def join_timestamps(tokens: List[en.MToken], pred_dur: torch.LongTensor):
        # Multiply by 600 to go from pred_dur frames to sample_rate 24000
        # Equivalent to dividing pred_dur frames by 40 to get timestamp in seconds
        # We will count nice round half-frames, so the divisor is 80
        MAGIC_DIVISOR = 80
        if not tokens or len(pred_dur) < 3:
            # We expect at least 3: <bos>, token, <eos>
            return
        # We track 2 counts, measured in half-frames: (left, right)
        # This way we can cut space characters in half
        # TODO: Is -3 an appropriate offset?
        left = right = 2 * max(0, pred_dur[0].item() - 3)
        # Updates:
        # left = right + (2 * token_dur) + space_dur
        # right = left + space_dur
        i = 1
        for t in tokens:
            if i >= len(pred_dur)-1:
                break
            if not t.phonemes:
                if t.whitespace:
                    i += 1
                    left = right + pred_dur[i].item()
                    right = left + pred_dur[i].item()
                    i += 1
                continue
            j = i + len(t.phonemes)
            if j >= len(pred_dur):
                break
            t.start_ts = left / MAGIC_DIVISOR
            token_dur = pred_dur[i: j].sum().item()
            space_dur = pred_dur[j].item() if t.whitespace else 0
            left = right + (2 * token_dur) + space_dur
            t.end_ts = left / MAGIC_DIVISOR
            right = left + space_dur
            i = j + (1 if t.whitespace else 0)

    @dataclass
    class Result:
        graphemes: str
        phonemes: str
        tokens: Optional[List[en.MToken]] = None
        output: Optional[KModel.Output] = None
        text_index: Optional[int] = None

        @property
        def audio(self) -> Optional[torch.FloatTensor]:
            return None if self.output is None else self.output.audio

        @property
        def pred_dur(self) -> Optional[torch.LongTensor]:
            return None if self.output is None else self.output.pred_dur

        ### MARK: BEGIN BACKWARD COMPAT ###
        def __iter__(self):
            yield self.graphemes
            yield self.phonemes
            yield self.audio

        def __getitem__(self, index):
            return [self.graphemes, self.phonemes, self.audio][index]

        def __len__(self):
            return 3
        #### MARK: END BACKWARD COMPAT ####

    def __call__(
        self,
        text: Union[str, List[str]],
        voice: Optional[str] = None,
        speed: Union[float, Callable[[int], float]] = 1,
        split_pattern: Optional[str] = r'\n+',
        model: Optional[KModel] = None
    ) -> Generator['KPipeline.Result', None, None]:
        model = model or self.model
        if model and voice is None:
            raise ValueError('Specify a voice: en_us_pipeline(text="Hello world!", voice="af_heart")')
        pack = self.load_voice(voice).to(model.device) if model else None

        # Convert input to list of segments
        if isinstance(text, str):
            text = re.split(split_pattern, text.strip()) if split_pattern else [text]

        # Process each segment
        for graphemes_index, graphemes in enumerate(text):
            if not graphemes.strip():  # Skip empty segments
                continue

            # English processing (unchanged)
            if self.lang_code in 'ab':
                logger.debug(f"Processing English text: {graphemes[:50]}{'...' if len(graphemes) > 50 else ''}")
                _, tokens = self.g2p(graphemes)
                for gs, ps, tks in self.en_tokenize(tokens):
                    if not ps:
                        continue
                    elif len(ps) > 510:
                        logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
                        ps = ps[:510]
                    output = KPipeline.infer(model, ps, pack, speed) if model else None
                    if output is not None and output.pred_dur is not None:
                        KPipeline.join_timestamps(tks, output.pred_dur)
                    yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output, text_index=graphemes_index)

            # Non-English processing with chunking
            else:
                # Split long text into smaller chunks (roughly 400 characters each)
                # Using sentence boundaries when possible
                chunk_size = 400
                chunks = []

                # Try to split on sentence boundaries first
                sentences = re.split(r'([.!?]+)', graphemes)
                current_chunk = ""

                for i in range(0, len(sentences), 2):
                    sentence = sentences[i]
                    # Add the punctuation back if it exists
                    if i + 1 < len(sentences):
                        sentence += sentences[i + 1]

                    if len(current_chunk) + len(sentence) <= chunk_size:
                        current_chunk += sentence
                    else:
                        if current_chunk:
                            chunks.append(current_chunk.strip())
                        current_chunk = sentence

                if current_chunk:
                    chunks.append(current_chunk.strip())

                # If no chunks were created (no sentence boundaries), fall back to character-based chunking
                if not chunks:
                    chunks = [graphemes[i:i+chunk_size] for i in range(0, len(graphemes), chunk_size)]

                # Process each chunk
                for chunk in chunks:
                    if not chunk.strip():
                        continue

                    ps, _ = self.g2p(chunk)
                    if not ps:
                        continue
                    elif len(ps) > 510:
                        logger.warning(f'Truncating len(ps) == {len(ps)} > 510')
                        ps = ps[:510]

                    output = KPipeline.infer(model, ps, pack, speed) if model else None
                    yield self.Result(graphemes=chunk, phonemes=ps, output=output, text_index=graphemes_index)

In [43]:
"""Models module for Kokoro TTS Local"""
from typing import Optional, Tuple, List
import torch
import os
import json
import codecs
from pathlib import Path
import numpy as np
import shutil
import threading

# Set environment variables for proper encoding
os.environ["PYTHONIOENCODING"] = "utf-8"
# Disable symlinks warning
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

# Setup for safer monkey-patching
import atexit
import signal
import sys

# Track whether patches have been applied
_patches_applied = {
    'json_load': False,
    'load_voice': False
}

def _cleanup_monkey_patches():
    """Restore original functions that were monkey-patched"""
    try:
        if _patches_applied['json_load'] and _original_json_load is not None:
            restore_json_load()
            _patches_applied['json_load'] = False
            print("Restored original json.load function")
    except Exception as e:
        print(f"Warning: Error restoring json.load: {e}")

    try:
        if _patches_applied['load_voice']:
            restore_original_load_voice()
            _patches_applied['load_voice'] = False
            print("Restored original KPipeline.load_voice function")
    except Exception as e:
        print(f"Warning: Error restoring KPipeline.load_voice: {e}")

# Register cleanup for normal exit
atexit.register(_cleanup_monkey_patches)

# Register cleanup for signals
for sig in [signal.SIGINT, signal.SIGTERM]:
    try:
        signal.signal(sig, lambda signum, frame: (
            print(f"\nReceived signal {signum}, cleaning up..."),
            _cleanup_monkey_patches(),
            sys.exit(1)
        ))
    except (ValueError, AttributeError):
        # Some signals might not be available on all platforms
        pass

# List of available voice files (54 voices across 8 languages)
VOICE_FILES = [
    # American English Female voices (11 voices)
    "af_heart.pt", "af_alloy.pt", "af_aoede.pt", "af_bella.pt", "af_jessica.pt",
    "af_kore.pt", "af_nicole.pt", "af_nova.pt", "af_river.pt", "af_sarah.pt", "af_sky.pt",

    # American English Male voices (9 voices)
    "am_adam.pt", "am_echo.pt", "am_eric.pt", "am_fenrir.pt", "am_liam.pt",
    "am_michael.pt", "am_onyx.pt", "am_puck.pt", "am_santa.pt",

    # British English Female voices (4 voices)
    "bf_alice.pt", "bf_emma.pt", "bf_isabella.pt", "bf_lily.pt",

    # British English Male voices (4 voices)
    "bm_daniel.pt", "bm_fable.pt", "bm_george.pt", "bm_lewis.pt",

    # Japanese voices (5 voices)
    "jf_alpha.pt", "jf_gongitsune.pt", "jf_nezumi.pt", "jf_tebukuro.pt", "jm_kumo.pt",

    # Mandarin Chinese voices (8 voices)
    "zf_xiaobei.pt", "zf_xiaoni.pt", "zf_xiaoxiao.pt", "zf_xiaoyi.pt",
    "zm_yunjian.pt", "zm_yunxi.pt", "zm_yunxia.pt", "zm_yunyang.pt",

    # Spanish voices (3 voices)
    "ef_dora.pt", "em_alex.pt", "em_santa.pt",

    # French voices (1 voice)
    "ff_siwis.pt",

    # Hindi voices (4 voices)
    "hf_alpha.pt", "hf_beta.pt", "hm_omega.pt", "hm_psi.pt",

    # Italian voices (2 voices)
    "if_sara.pt", "im_nicola.pt",

    # Brazilian Portuguese voices (3 voices)
    "pf_dora.pt", "pm_alex.pt", "pm_santa.pt"
]

# Language code mapping for different languages
LANGUAGE_CODES = {
    'a': 'American English',
    'b': 'British English',
    'j': 'Japanese',
    'z': 'Mandarin Chinese',
    'e': 'Spanish',
    'f': 'French',
    'h': 'Hindi',
    'i': 'Italian',
    'p': 'Brazilian Portuguese'
}

# Patch KPipeline's load_voice method to use weights_only=False
original_load_voice = KPipeline.load_voice

def patched_load_voice(self, voice_path):
    """Load voice model with weights_only=False for compatibility"""
    if not os.path.exists(voice_path):
        raise FileNotFoundError(f"Voice file not found: {voice_path}")
    voice_name = Path(voice_path).stem
    try:
        voice_model = torch.load(voice_path, weights_only=False)
        if voice_model is None:
            raise ValueError(f"Failed to load voice model from {voice_path}")
        # Ensure device is set
        if not hasattr(self, 'device'):
            self.device = 'cpu'
        # Move model to device and store in voices dictionary
        self.voices[voice_name] = voice_model.to(self.device)
        return self.voices[voice_name]
    except Exception as e:
        print(f"Error loading voice {voice_name}: {e}")
        raise

# Apply the patch
KPipeline.load_voice = patched_load_voice
_patches_applied['load_voice'] = True

# Store original function for restoration if needed
def restore_original_load_voice():
    global _patches_applied
    if _patches_applied['load_voice']:
        KPipeline.load_voice = original_load_voice
        _patches_applied['load_voice'] = False

def patch_json_load():
    """Patch json.load to handle UTF-8 encoded files with special characters"""
    global _patches_applied, _original_json_load
    original_load = json.load
    _original_json_load = original_load  # Store for restoration

    def custom_load(fp, *args, **kwargs):
        try:
            # Try reading with UTF-8 encoding
            if hasattr(fp, 'buffer'):
                content = fp.buffer.read().decode('utf-8')
            else:
                content = fp.read()
            try:
                return json.loads(content)
            except json.JSONDecodeError as e:
                print(f"JSON parsing error: {e}")
                raise
        except UnicodeDecodeError:
            # If UTF-8 fails, try with utf-8-sig for files with BOM
            fp.seek(0)
            content = fp.read()
            if isinstance(content, bytes):
                content = content.decode('utf-8-sig', errors='replace')
            try:
                return json.loads(content)
            except json.JSONDecodeError as e:
                print(f"JSON parsing error: {e}")
                raise

    json.load = custom_load
    _patches_applied['json_load'] = True
    return original_load  # Return original for restoration

# Store the original load function for potential restoration
_original_json_load = None

def restore_json_load():
    """Restore the original json.load function"""
    global _original_json_load, _patches_applied
    if _original_json_load is not None and _patches_applied['json_load']:
        json.load = _original_json_load
        _original_json_load = None
        _patches_applied['json_load'] = False

def load_config(config_path: str) -> dict:
    """Load configuration file with proper encoding handling"""
    try:
        with codecs.open(config_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except UnicodeDecodeError:
        # Fallback to utf-8-sig if regular utf-8 fails
        with codecs.open(config_path, 'r', encoding='utf-8-sig') as f:
            return json.load(f)

# Initialize espeak-ng
phonemizer_available = False  # Global flag to track if phonemizer is working
try:
    from phonemizer.backend.espeak.wrapper import EspeakWrapper
    from phonemizer import phonemize
    import espeakng_loader

    # Make library available first
    library_path = espeakng_loader.get_library_path()
    data_path = espeakng_loader.get_data_path()
    espeakng_loader.make_library_available()

    # Set up espeak-ng paths
    EspeakWrapper.library_path = library_path
    EspeakWrapper.data_path = data_path

    # Verify espeak-ng is working
    try:
        test_phonemes = phonemize('test', language='en-us')
        if test_phonemes:
            phonemizer_available = True
            print("Phonemizer successfully initialized")
        else:
            print("Note: Phonemization returned empty result")
            print("TTS will work, but phoneme visualization will be disabled")
    except Exception as e:
        # Continue without espeak functionality
        print(f"Note: Phonemizer not available: {e}")
        print("TTS will work, but phoneme visualization will be disabled")

except ImportError as e:
    print(f"Note: Phonemizer packages not installed: {e}")
    print("TTS will work, but phoneme visualization will be disabled")
    # Rather than automatically installing packages, inform the user
    print("If you want phoneme visualization, manually install required packages:")
    print("pip install espeakng-loader phonemizer-fork")

# Initialize pipeline globally with thread safety
_pipeline = None
_pipeline_lock = threading.RLock()  # Reentrant lock for thread safety

def download_voice_files(voice_files=None, repo_version="main", required_count=1):
    """Download voice files from Hugging Face.

    Args:
        voice_files: Optional list of voice files to download. If None, download all VOICE_FILES.
        repo_version: Version/tag of the repository to use (default: "main")
        required_count: Minimum number of voices required (default: 1)

    Returns:
        List of successfully downloaded voice files

    Raises:
        ValueError: If fewer than required_count voices could be downloaded
    """
    # Use absolute path for voices directory
    voices_dir = Path(os.path.abspath("voices"))
    voices_dir.mkdir(exist_ok=True)

    # Import here to avoid startup dependency
    from huggingface_hub import hf_hub_download
    downloaded_voices = []
    failed_voices = []

    # If specific voice files are requested, use those. Otherwise use all.
    files_to_download = voice_files if voice_files is not None else VOICE_FILES
    total_files = len(files_to_download)

    print(f"\nDownloading voice files... ({total_files} total files)")

    # Check for existing voice files first
    existing_files = []
    for voice_file in files_to_download:
        voice_path = voices_dir / voice_file
        if voice_path.exists():
            print(f"Voice file {voice_file} already exists")
            downloaded_voices.append(voice_file)
            existing_files.append(voice_file)

    # Remove existing files from the download list
    files_to_download = [f for f in files_to_download if f not in existing_files]
    if not files_to_download and downloaded_voices:
        print(f"All required voice files already exist ({len(downloaded_voices)} files)")
        return downloaded_voices

    # Proceed with downloading missing files
    retry_count = 3
    try:
        import tempfile
        with tempfile.TemporaryDirectory() as temp_dir:
            for voice_file in files_to_download:
                # Full path where the voice file should be
                voice_path = voices_dir / voice_file

                # Try with retries
                for attempt in range(retry_count):
                    try:
                        print(f"Downloading {voice_file}... (attempt {attempt+1}/{retry_count})")
                        # Download to a temporary location first
                        temp_path = hf_hub_download(
                            repo_id="hexgrad/Kokoro-82M",
                            filename=f"voices/{voice_file}",
                            local_dir=temp_dir,
                            force_download=True,
                            revision=repo_version
                        )

                        # Move the file to the correct location
                        os.makedirs(os.path.dirname(str(voice_path)), exist_ok=True)
                        shutil.copy2(temp_path, str(voice_path))  # Use copy2 instead of move

                        # Verify file integrity
                        if os.path.getsize(str(voice_path)) > 0:
                            downloaded_voices.append(voice_file)
                            print(f"Successfully downloaded {voice_file}")
                            break  # Success, exit retry loop
                        else:
                            print(f"Warning: Downloaded file {voice_file} has zero size, retrying...")
                            os.remove(str(voice_path))  # Remove invalid file
                            if attempt == retry_count - 1:
                                failed_voices.append(voice_file)
                    except (IOError, OSError, ValueError, FileNotFoundError, ConnectionError) as e:
                        print(f"Warning: Failed to download {voice_file} (attempt {attempt+1}): {e}")
                        if attempt == retry_count - 1:
                            failed_voices.append(voice_file)
                            print(f"Error: Failed all {retry_count} attempts to download {voice_file}")
    except Exception as e:
        print(f"Error during voice download process: {e}")
        import traceback
        traceback.print_exc()

    # Report results
    if failed_voices:
        print(f"Warning: Failed to download {len(failed_voices)} voice files: {', '.join(failed_voices)}")

    if not downloaded_voices:
        error_msg = "No voice files could be downloaded. Please check your internet connection."
        print(f"Error: {error_msg}")
        raise ValueError(error_msg)
    elif len(downloaded_voices) < required_count:
        error_msg = f"Only {len(downloaded_voices)} voice files could be downloaded, but {required_count} were required."
        print(f"Error: {error_msg}")
        raise ValueError(error_msg)
    else:
        print(f"Successfully processed {len(downloaded_voices)} voice files")

    return downloaded_voices

def build_model(model_path: str, device: str, repo_version: str = "main") -> KPipeline:
    """Build and return the Kokoro pipeline with proper encoding configuration

    Args:
        model_path: Path to the model file or None to use default
        device: Device to use ('cuda' or 'cpu')
        repo_version: Version/tag of the repository to use (default: "main")

    Returns:
        Initialized KPipeline instance
    """
    global _pipeline, _pipeline_lock

    # Use a lock for thread safety
    with _pipeline_lock:
        # Double-check pattern to avoid race conditions
        if _pipeline is not None:
            return _pipeline

        try:
            # Patch json loading before initializing pipeline
            patch_json_load()

            # Download model if it doesn't exist
            if model_path is None:
                model_path = 'kokoro-v1_0.pth'

            model_path = os.path.abspath(model_path)
            if not os.path.exists(model_path):
                print(f"Downloading model file {model_path}...")
                try:
                    from huggingface_hub import hf_hub_download
                    model_path = hf_hub_download(
                        repo_id="hexgrad/Kokoro-82M",
                        filename="kokoro-v1_0.pth",
                        local_dir=".",
                        force_download=True,
                        revision=repo_version
                    )
                    print(f"Model downloaded to {model_path}")
                except Exception as e:
                    print(f"Error downloading model: {e}")
                    raise ValueError(f"Could not download model: {e}") from e

            # Download config if it doesn't exist
            config_path = os.path.abspath("config.json")
            if not os.path.exists(config_path):
                print("Downloading config file...")
                try:
                    config_path = hf_hub_download(
                        repo_id="hexgrad/Kokoro-82M",
                        filename="config.json",
                        local_dir=".",
                        force_download=True,
                        revision=repo_version
                    )
                    print(f"Config downloaded to {config_path}")
                except Exception as e:
                    print(f"Error downloading config: {e}")
                    raise ValueError(f"Could not download config: {e}") from e

            # Download voice files - require at least one voice
            try:
                downloaded_voices = download_voice_files(repo_version=repo_version, required_count=1)
            except ValueError as e:
                print(f"Error: Voice files download failed: {e}")
                raise ValueError("Voice files download failed") from e

            # Validate language code
            lang_code = 'a'  # Default to 'a' for American English
            supported_codes = list(LANGUAGE_CODES.keys())
            if lang_code not in supported_codes:
                print(f"Warning: Unsupported language code '{lang_code}'. Using 'a' (American English).")
                print(f"Supported language codes: {', '.join(supported_codes)}")
                lang_code = 'a'

            # Initialize pipeline with validated language code
            pipeline_instance = KPipeline(lang_code=lang_code)
            if pipeline_instance is None:
                raise ValueError("Failed to initialize KPipeline - pipeline is None")

            # Store device parameter for reference in other operations
            pipeline_instance.device = device

            # Initialize voices dictionary if it doesn't exist
            if not hasattr(pipeline_instance, 'voices'):
                pipeline_instance.voices = {}

            # Try to load the first available voice with improved error handling
            voice_loaded = False
            for voice_file in downloaded_voices:
                voice_path = os.path.abspath(os.path.join("voices", voice_file))
                if os.path.exists(voice_path):
                    try:
                        pipeline_instance.load_voice(voice_path)
                        print(f"Successfully loaded voice: {voice_file}")
                        voice_loaded = True
                        break  # Successfully loaded a voice
                    except Exception as e:
                        print(f"Warning: Failed to load voice {voice_file}: {e}")
                        continue

            if not voice_loaded:
                print("Warning: Could not load any voice models")

            # Set the global _pipeline only after successful initialization
            _pipeline = pipeline_instance

        except Exception as e:
            print(f"Error initializing pipeline: {e}")
            # Restore original json.load on error
            restore_json_load()
            raise

        return _pipeline

def list_available_voices() -> List[str]:
    """List all available voice models"""
    # Always use absolute path for consistency
    voices_dir = Path(os.path.abspath("voices"))

    # Create voices directory if it doesn't exist
    if not voices_dir.exists():
        print(f"Creating voices directory at {voices_dir}")
        voices_dir.mkdir(exist_ok=True)
        return []

    # Get all .pt files in the voices directory
    voice_files = list(voices_dir.glob("*.pt"))

    # If we found voice files, return them
    if voice_files:
        return [f.stem for f in sorted(voice_files, key=lambda f: f.stem.lower())]

    # If no voice files in standard location, check if we need to do a one-time migration
    # This is legacy support for older installations
    alt_voices_path = Path(".") / "voices"
    if alt_voices_path.exists() and alt_voices_path.is_dir() and alt_voices_path != voices_dir:
        print(f"Checking alternative voice location: {alt_voices_path.absolute()}")
        alt_voice_files = list(alt_voices_path.glob("*.pt"))

        if alt_voice_files:
            print(f"Found {len(alt_voice_files)} voice files in alternate location")
            print("Moving files to the standard voices directory...")

            # Process files in a batch for efficiency
            files_moved = 0
            for voice_file in alt_voice_files:
                target_path = voices_dir / voice_file.name
                if not target_path.exists():
                    try:
                        # Use copy2 to preserve metadata, then remove original if successful
                        shutil.copy2(str(voice_file), str(target_path))
                        files_moved += 1
                    except (OSError, IOError) as e:
                        print(f"Error copying {voice_file.name}: {e}")

            if files_moved > 0:
                print(f"Successfully moved {files_moved} voice files")
                return [f.stem for f in sorted(voices_dir.glob("*.pt"), key=lambda f: f.stem.lower())]

    print("No voice files found. Please run the application again to download voices.")
    return []

def get_language_code_from_voice(voice_name: str) -> str:
    """Get the appropriate language code from a voice name

    Args:
        voice_name: Name of the voice (e.g., 'af_bella', 'jf_alpha')

    Returns:
        Language code for the voice
    """
    # Extract prefix from voice name
    prefix = voice_name[:2] if len(voice_name) >= 2 else 'af'

    # Map voice prefixes to language codes
    prefix_to_lang = {
        'af': 'a', 'am': 'a',  # American English
        'bf': 'b', 'bm': 'b',  # British English
        'jf': 'j', 'jm': 'j',  # Japanese
        'zf': 'z', 'zm': 'z',  # Mandarin Chinese
        'ef': 'e', 'em': 'e',  # Spanish
        'ff': 'f', 'fm': 'f',  # French
        'hf': 'h', 'hm': 'h',  # Hindi
        'if': 'i', 'im': 'i',  # Italian
        'pf': 'p', 'pm': 'p',  # Brazilian Portuguese
    }

    return prefix_to_lang.get(prefix, 'a')  # Default to American English

def load_voice(voice_name: str, device: str) -> torch.Tensor:
    """Load a voice model in a thread-safe manner

    Args:
        voice_name: Name of the voice to load (with or without .pt extension)
        device: Device to use ('cuda' or 'cpu')

    Returns:
        Loaded voice model tensor

    Raises:
        ValueError: If voice file not found or loading fails
    """
    pipeline = build_model(None, device)

    # Format voice path correctly - strip .pt if it was included
    voice_name = voice_name.replace('.pt', '')
    voice_path = os.path.abspath(os.path.join("voices", f"{voice_name}.pt"))

    if not os.path.exists(voice_path):
        raise ValueError(f"Voice file not found: {voice_path}")

    # Use a lock to ensure thread safety when loading voices
    with _pipeline_lock:
        # Check if voice is already loaded
        if hasattr(pipeline, 'voices') and voice_name in pipeline.voices:
            return pipeline.voices[voice_name]

        # Load voice if not already loaded
        return pipeline.load_voice(voice_path)

def generate_speech(
    model: KPipeline,
    text: str,
    voice: str,
    lang: str = 'a',
    device: str = 'cpu',
    speed: float = 1.0
) -> Tuple[Optional[torch.Tensor], Optional[str]]:
    """Generate speech using the Kokoro pipeline in a thread-safe manner

    Args:
        model: KPipeline instance
        text: Text to synthesize
        voice: Voice name (e.g. 'af_bella')
        lang: Language code ('a' for American English, 'b' for British English)
        device: Device to use ('cuda' or 'cpu')
        speed: Speech speed multiplier (default: 1.0)

    Returns:
        Tuple of (audio tensor, phonemes string) or (None, None) on error
    """
    global _pipeline_lock

    try:
        if model is None:
            raise ValueError("Model is None - pipeline not properly initialized")

        # Format voice name and path
        voice_name = voice.replace('.pt', '')
        voice_path = os.path.abspath(os.path.join("voices", f"{voice_name}.pt"))

        # Check if voice file exists
        if not os.path.exists(voice_path):
            raise ValueError(f"Voice file not found: {voice_path}")

        # Thread-safe initialization of model properties and voice loading
        with _pipeline_lock:
            # Initialize voices dictionary if it doesn't exist
            if not hasattr(model, 'voices'):
                model.voices = {}

            # Ensure device is set
            if not hasattr(model, 'device'):
                model.device = device

            # Ensure voice is loaded before generating
            if voice_name not in model.voices:
                print(f"Loading voice {voice_name}...")
                try:
                    model.load_voice(voice_path)
                    if voice_name not in model.voices:
                        raise ValueError("Voice load succeeded but voice not in model.voices dictionary")
                except Exception as e:
                    raise ValueError(f"Failed to load voice {voice_name}: {e}")

        # Generate speech (outside the lock for better concurrency)
        print(f"Generating speech with device: {model.device}")
        generator = model(
            text,
            voice=voice_path,
            speed=speed,
            split_pattern=r'\n+'
        )

        # Get first generated segment and convert numpy array to tensor if needed
        for gs, ps, audio in generator:
            if audio is not None:
                if isinstance(audio, np.ndarray):
                    audio = torch.from_numpy(audio).float()
                return audio, ps

        return None, None
    except (ValueError, FileNotFoundError, RuntimeError, KeyError, AttributeError, TypeError) as e:
        print(f"Error generating speech: {e}")
        return None, None
    except Exception as e:
        print(f"Unexpected error during speech generation: {e}")
        import traceback
        traceback.print_exc()
        return None, None

Phonemizer successfully initialized


In [59]:
def generate(model, text, voicepack, lang, speed=1.0):

    # G2P baseado na linguagem
    if lang in "ab":
        fallback = espeak.EspeakFallback(british=(lang == 'b'))
        g2p = en.G2P(trf=False, british=(lang == 'b'), fallback=fallback, unk='')
        _, tokens = g2p(text)
        ps = ''.join(t.phonemes + (' ' if t.whitespace else '') for t in tokens).strip()

    elif lang == "j":
        g2p = ja.JAG2P()
        ps, _ = g2p(text)

    elif lang == "z":
        g2p = zh.ZHG2P()
        ps, _ = g2p(text)

    else:
        language = LANG_CODES[lang]
        g2p = espeak.EspeakG2P(language=language)
        ps, _ = g2p(text)

    # Limite de segurança
    if len(ps) > 510:
        logger.warning(f"Phoneme string muito longa ({len(ps)} > 510), truncando.")
        ps = ps[:510]

    # Chamada direta do KModel com return_output
    output = model(ps, voicepack[len(ps)-1], speed, return_output=True)
    audio = output.audio

    return audio, ps


In [82]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pipeline = build_model('kokoro-v0_19.pth', device)
MODEL = pipeline.model

# Language is determined by the first letter of the VOICE_NAME:
# 🇺🇸 'a' => American English => en-us
# 🇬🇧 'b' => British English => en-gb

VOICE_NAME = [
    'af', # Default voice is a 50-50 mix of Bella & Sarah
    'af_bella', 'af_sarah', 'am_adam', 'am_michael',
    'bf_emma', 'bf_isabella', 'bm_george', 'bm_lewis',
    'af_nicole', 'af_sky',
][8]

VOICEPACK = torch.load(f'voices/{VOICE_NAME}.pt', weights_only=True).to(device)
print(f'Loaded voice: {VOICE_NAME}')

Loaded voice: bm_lewis


In [83]:
text = "How could I know? It's an unanswerable question. AGI is Like asking an unborn child if they'll lead a good life. They haven't even been born."

In [84]:
audio, phonemes = generate(MODEL, text, VOICEPACK, lang="p")

[32m2025-06-11 17:38:44.411[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mforward[0m:[36m127[0m - [34m[1mphonemes: ˈow kˈowld ˈi knˈow? ˈits ˈɐ̃ŋ ˌunɐ̃ŋsweɾˈably kˌesʧiˈoŋ. ˌaʒˌeˈi ˈiz lˈikj askˈiŋɡ ˈɐ̃ŋ ũŋbˈoɾən ʃˈiʊd ˈif tˈAl leˈad a ɡˈʊd lˈify. tˈA avˈAŋʧ ˈevAŋ bˈin bˈoɾən. -> input_ids: [156, 57, 65, 16, 53, 156, 57, 65, 54, 46, 16, 156, 51, 16, 53, 56, 156, 57, 65, 6, 16, 156, 51, 62, 61, 16, 156, 70, 17, 112, 16, 157, 63, 56, 70, 17, 112, 61, 65, 47, 125, 156, 43, 44, 54, 67, 16, 53, 157, 47, 61, 133, 51, 156, 57, 112, 4, 16, 157, 43, 147, 157, 47, 156, 51, 16, 156, 51, 68, 16, 54, 156, 51, 53, 52, 16, 43, 61, 53, 156, 51, 112, 92, 16, 156, 70, 17, 112, 16, 63, 17, 112, 44, 156, 57, 125, 83, 56, 16, 131, 156, 51, 135, 46, 16, 156, 51, 48, 16, 62, 156, 24, 54, 16, 54, 47, 156, 43, 46, 16, 43, 16, 92, 156, 135, 46, 16, 54, 156, 51, 48, 67, 4, 16, 62, 156, 24, 16, 43, 64, 156, 24, 112, 133, 16, 156, 47, 64, 24, 112, 16, 44, 156, 51, 56, 16, 44, 156, 57, 125, 83, 56, 4]

In [85]:
# 4️⃣ Display the 24khz audio and print the output phonemes
display(Audio(data=audio, rate=24000, autoplay=True))
print(phonemes)

ˈow kˈowld ˈi knˈow? ˈits ˈɐ̃ŋ ˌunɐ̃ŋsweɾˈably kˌesʧiˈoŋ. ˌaʒˌeˈi ˈiz lˈikj askˈiŋɡ ˈɐ̃ŋ ũŋbˈoɾən ʃˈiʊd ˈif tˈAl leˈad a ɡˈʊd lˈify. tˈA avˈAŋʧ ˈevAŋ bˈin bˈoɾən.


In [80]:
import numpy as np
from scipy.io.wavfile import write

# Supondo que audio ainda é torch.Tensor
# Primeiro: mova para CPU (caso esteja na GPU)
if audio.is_cuda:
    audio = audio.cpu()

# Converta para numpy
audio_np = audio.numpy()

# Normalize e converta para int16
audio_int16 = (audio_np * 32767).astype(np.int16)

# Salve o arquivo WAV
rate = 24000
write("output.wav", rate, audio_int16)

# Reproduza o áudio (no Colab, por exemplo)
from IPython.display import Audio
Audio(filename="output.wav", rate=rate, autoplay=True)

In [None]:
VOICEPACK.shape


# 3️⃣ Call generate, which returns 24khz audio and the phonemes used


torch.Size([511, 1, 256])

## Blending Voices

In [None]:
VOICEPACK_01 = torch.load(f'voices/bf_emma.pt', weights_only=True).to(device)
VOICEPACK_02 = torch.load(f'voices/bf_isabella.pt', weights_only=True).to(device)
VOICEPACK_03 = torch.load(f'voices/bm_lewis.pt', weights_only=True).to(device)

In [None]:
from kokoro import generate

audio, out_ps = generate(MODEL,
                         text,
                         VOICEPACK_01,
                         lang=VOICE_NAME[0])
display(Audio(data=audio, rate=24000, autoplay=True))
print(out_ps)

hˌaʊ kʊd aɪ nˈəʊ? ɪts ɐn ʌnˈansəɹəbəl kwˈɛstʃən. ˌeɪdʒˌiːˈaɪ ɪz lˈaɪk ˈaskɪŋ ɐn ʌnbˈɔːn tʃˈaɪld ɪf ðeɪl lˈiːd ɐ ɡˈʊd lˈaɪf. ðeɪ hˈavənt ˈiːvən bˌiːn bˈɔːn.


In [None]:
audio, out_ps = generate(MODEL,
                         text,
                         VOICEPACK_02,
                         lang=VOICE_NAME[0])
display(Audio(data=audio, rate=24000, autoplay=True))

In [None]:
audio, out_ps = generate(MODEL,
                         text,
                         VOICEPACK_03,
                         lang=VOICE_NAME[0])
display(Audio(data=audio, rate=24000, autoplay=True))

In [None]:
VOICEPACK_01.shape

torch.Size([511, 1, 256])

In [None]:
# Average the tensors
bf_average = (VOICEPACK_01 + VOICEPACK_02) / 2

In [None]:

audio, out_ps = generate(MODEL,
                         text,
                         bf_average,
                         lang=VOICE_NAME[0])

In [None]:
display(Audio(data=audio, rate=24000, autoplay=True))
print(out_ps)

hˌaʊ kʊd aɪ nˈəʊ? ɪts ɐn ʌnˈansəɹəbəl kwˈɛstʃən. ˌeɪdʒˌiːˈaɪ ɪz lˈaɪk ˈaskɪŋ ɐn ʌnbˈɔːn tʃˈaɪld ɪf ðeɪl lˈiːd ɐ ɡˈʊd lˈaɪf. ðeɪ hˈavənt ˈiːvən bˌiːn bˈɔːn.


In [None]:
VOICEPACK_01 = torch.load(f'voices/bf_emma.pt', weights_only=True).to(device)
VOICEPACK_02 = torch.load(f'voices/bf_isabella.pt', weights_only=True).to(device)
VOICEPACK_03 = torch.load(f'voices/bm_lewis.pt', weights_only=True).to(device)

# Weights for the two tensors
weight_a = 0.5
weight_b = 0.5

# Weighted average
weighted_voice = (weight_a * VOICEPACK_01 + weight_b * VOICEPACK_03) / (weight_a + weight_b)

In [None]:


audio, out_ps = generate(MODEL,
                         text,
                         weighted_voice,
                         lang=VOICE_NAME[0])
display(Audio(data=audio, rate=24000, autoplay=True))

In [None]:
t = 0.7  # interpolation factor, 0 <= t <= 1
interpolated_tensor = (1 - t) * VOICEPACK_01 + t * VOICEPACK_03


audio, out_ps = generate(MODEL,
                         text,
                         interpolated_tensor,
                         lang=VOICE_NAME[0])
display(Audio(data=audio, rate=24000, autoplay=True))

In [None]:
t = 0.5  # interpolation factor, 0 <= t <= 1
interpolated_tensor = (1 - t) * VOICEPACK_01 + t * VOICEPACK_02


audio, out_ps = generate(MODEL,
                         text,
                         interpolated_tensor,
                         lang=VOICE_NAME[0])
display(Audio(data=audio, rate=24000, autoplay=True))

In [None]:
import torch

def slerp(a, b, t):
    """
    Perform spherical interpolation (slerp) between two tensors a and b.

    Args:
        a (torch.Tensor): First tensor (assumed to be normalized).
        b (torch.Tensor): Second tensor (assumed to be normalized).
        t (float): Interpolation parameter in [0, 1].

    Returns:
        torch.Tensor: The interpolated tensor.
    """
    # Ensure shapes
    assert a.shape == b.shape, "Tensors must have the same shape"

    # Compute the dot product along the last dimension
    dot = torch.sum(a * b, dim=-1, keepdim=True)
    dot = torch.clamp(dot, -1.0, 1.0)  # Ensure it's within [-1, 1]

    # Compute the angle theta
    theta = torch.acos(dot)

    # Compute the sine of theta
    sin_theta = torch.sin(theta)

    # Small epsilon for checking near-zero sine values
    epsilon = 1e-6

    # If sin_theta is near zero, fallback to linear interpolation
    if sin_theta.max().item() < epsilon:
        return (1 - t) * a + t * b

    # Otherwise, use the spherical interpolation formula
    sin_t_theta = torch.sin(t * theta)
    sin_one_minus_t_theta = torch.sin((1 - t) * theta)

    interpolated = (sin_one_minus_t_theta / sin_theta) * a + (sin_t_theta / sin_theta) * b
    return interpolated


In [None]:
t = 0.5  # interpolation factor
interpolated_tensor = slerp(VOICEPACK_03,VOICEPACK_02,t)


audio, out_ps = generate(MODEL,
                         text,
                         interpolated_tensor,
                         lang=VOICE_NAME[0])
display(Audio(data=audio, rate=24000, autoplay=True))

In [None]:
interpolated_tensor.shape

torch.Size([511, 1, 256])