In [None]:
# 事前準備　はじめに実行

In [1]:
import os
import re
import time
import json
import torch
import commons
import utils
from models import SynthesizerTrn
from text_JP import cleaned_text_to_sequence, symbols
import pyopenjtalk
from text_JP.phonemize import Phonemizer

# This script is designed to be run in a Jupyter Notebook or an environment
# with IPython display capabilities.
from IPython.display import Audio, display

from pkg_resources import resource_filename


# 1 全体デコード 

In [3]:
# --- 1. Configuration (EDIT THESE PATHS) ---
# ===========================================
# Path to the config file of the trained multi-speaker model
config_path = "./logs/uudb_csj21/config.json"

# Path to the generator checkpoint of the trained model
# Find the latest G_****.pth file in your model directory
checkpoint_path = "./logs/uudb_csj21/G_3020000.pth"  # <-- IMPORTANT: UPDATE THIS PATH

# Text to be synthesized.
# Spans in square brackets are phonemized and kept wrapped as "[ ... ]";
# "、" / "。" become pause tokens ("sp"). Alternative sample sentence:
#   "最近、インターステラーを見たのですけど、すごく面白かったです。"
text_to_synthesize = "[あ]ちゃんと入ってないんだ 、 [あー]"

# ===========================================


# --- 2. Text Pre-processing Functions ---
def japanese_cleaner_revised(text):
    """Convert annotated Japanese text into a space-separated phoneme string.

    ``text`` is split on the following tokens, which are handled specially:

    * ``[...]`` -- a bracketed span. Its content is phonemized and emitted as
      ``[ <phonemes> ]``. An empty or whitespace-only span becomes ``[ ]``.
    * ``{cough}`` / ``<cough>`` -- emitted as ``<cough>``.
    * ``、`` / ``。`` -- emitted as the pause token ``sp``.

    Every other non-blank segment is converted to katakana with
    ``pyopenjtalk.g2p(..., kana=True)``. ``ヲ`` is then replaced by ``オ``, and the
    result is converted to phonemes with ``Phonemizer``.

    Args:
        text: input sentence.

    Returns:
        The tokens joined by single spaces (whitespace runs collapsed, ends stripped).
    """
    parts = re.split(r'({cough}|<cough>|\[.*?\]|[、。])', text)
    phonemizer = Phonemizer()

    def to_phonemes(segment):
        # Text -> katakana -> phonemes. ヲ is mapped to オ before phonemizing
        # (presumably because both are pronounced /o/ -- TODO confirm).
        kana = pyopenjtalk.g2p(segment, kana=True).replace('ヲ', 'オ')
        return phonemizer(kana)

    phoneme_parts = []
    for part in parts:
        if not part or part.isspace():
            continue
        # Use ``>= 2`` rather than ``> 2`` so that "[]" reaches the empty-span branch.
        # Otherwise it would be fed to g2p as literal text.
        if part.startswith('[') and part.endswith(']') and len(part) >= 2:
            content = part[1:-1]
            if not content.strip():
                phoneme_parts.append('[ ]')
            else:
                phoneme_parts.append(f'[ {to_phonemes(content)} ]')
        elif part in ('{cough}', '<cough>'):
            phoneme_parts.append('<cough>')
        elif part in ('、', '。'):
            phoneme_parts.append('sp')
        else:
            phoneme_parts.append(to_phonemes(part))

    final_text = ' '.join(phoneme_parts)
    return re.sub(r'\s+', ' ', final_text).strip()

def text_to_sequence_custom(text, hps):
    """Turn raw Japanese ``text`` into a LongTensor of symbol IDs for the model.

    The text is phonemized by ``japanese_cleaner_revised`` and mapped to IDs by
    ``cleaned_text_to_sequence``. When ``hps.data.add_blank`` is true, ID 0 is
    interspersed with ``commons.intersperse``.
    """
    symbol_ids = cleaned_text_to_sequence(japanese_cleaner_revised(text))
    if not hps.data.add_blank:
        return torch.LongTensor(symbol_ids)
    return torch.LongTensor(commons.intersperse(symbol_ids, 0))


# --- 3. Main Synthesis Process ---
# Loads the model, converts `text_to_synthesize` into symbol IDs, synthesizes it
# for a single speaker, reports timing / real-time factor, and plays the result.
if not os.path.exists(config_path):
    print(f"ERROR: Config file not found at {config_path}")
elif not os.path.exists(checkpoint_path):
    print(f"ERROR: Checkpoint file not found at {checkpoint_path}")
    print("Please update the 'checkpoint_path' variable in this script to point to your trained model.")
else:
    # Load configuration
    hps = utils.get_hparams_from_file(config_path)

    # NOTE(review): inference is pinned to CPU (CUDA auto-detection was commented
    # out), presumably so the RTF below measures CPU speed -- confirm.
    device = "cpu"

    # Build the synthesizer and switch it to evaluation mode.
    print("Loading model...")
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model).to(device)
    _ = net_g.eval()

    # Load generator weights (the optimizer argument is None).
    print(f"Loading checkpoint from {checkpoint_path}...")
    _ = utils.load_checkpoint(checkpoint_path, net_g, None)

    # Text -> phonemes -> symbol IDs
    print(f"Original text: {text_to_synthesize}")
    stn_tst = text_to_sequence_custom(text_to_synthesize, hps)

    # Speaker ID to synthesize. Only this single speaker is rendered.
    i = 375
    print(f"\n--- Synthesizing for Speaker {i} ---")
    # perf_counter is monotonic and high-resolution, so it suits interval timing
    # better than the wall-clock time.time().
    start_time = time.perf_counter()

    with torch.no_grad():
        x_tst = stn_tst.unsqueeze(0).to(device)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
        sid = torch.LongTensor([i]).to(device)

        # Inference; keep batch item 0, channel 0 of the first output.
        audio = net_g.infer(x_tst, x_tst_lengths, sid=sid,
                            noise_scale=0.1,
                            noise_scale_w=1.0,
                            length_scale=1.0
                            )[0][0, 0].data.cpu().float().numpy()

    end_time = time.perf_counter()
    elapsed_time = end_time - start_time
    audio_duration = len(audio) / hps.data.sampling_rate
    # RTF = processing time / audio duration. An empty waveform would otherwise
    # raise ZeroDivisionError here.
    rtf = elapsed_time / audio_duration if audio_duration > 0 else float("inf")

    print(f"Audio duration: {audio_duration:.2f} seconds")
    print(f"Elapsed time: {elapsed_time:.2f} seconds")
    print(f"Real Time Factor (RTF): {rtf:.4f}")
    display(Audio(audio, rate=hps.data.sampling_rate, normalize=False))

    print("\nSynthesis complete.")



Loading model...
Mutli-stream iSTFT VITS
Loading checkpoint from ./logs/uudb_csj21/G_3020000.pth...
Original text: [あ]ちゃんと入ってないんだ 、 [あー]

--- Synthesizing for Speaker 375 ---
Audio duration: 1.81 seconds
Elapsed time: 0.10 seconds
Real Time Factor (RTF): 0.0542



Synthesis complete.


# 5 細切れデコード　スペクトログラムをoverlapで接合　相互相関で調整はしない

# 6 細切れデコード　スペクトログラムを接合　overlapあり　相互相関で調整あり

In [14]:
import os
import time
import torch
import torch.nn.functional as F
import numpy as np
import commons
import utils
from models import SynthesizerTrn
from text_JP import symbols
from scipy.io.wavfile import write

# ==========================================
# 1. Settings
# ==========================================
# Model files, the input filelist, and the directory that receives generated WAVs.
config_path, checkpoint_path = (
    "./logs/uudb_csj21/config.json",
    "./logs/uudb_csj21/G_3020000.pth",
)
input_txt_path, output_dir = "./filelists/csj_uudb_test_fine.txt", "output_wavs_batch"

# Sampling parameters shared by every synthesis condition below.
noise_scale, noise_scale_w, length_scale = 0.1, 1.0, 1.0

# Device selection: prefer CUDA when it is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# ==========================================
# 2. 共通クラス・関数定義
# ==========================================

class TorchSTFT(torch.nn.Module):
    """Minimal inverse-STFT helper that rebuilds a waveform from magnitude and phase.

    The periodic Hann window is registered as a non-persistent buffer. As a result,
    ``module.to(device)`` moves it with the module (``nn.Module.to`` ignores plain
    tensor attributes), and it is not written to ``state_dict``.
    """

    def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        # NOTE(review): the ``window`` argument is accepted but ignored; a Hann window is always used.
        self.register_buffer("window", torch.hann_window(win_length, periodic=True), persistent=False)

    def inverse(self, magnitude, phase):
        """Inverse STFT of ``magnitude * exp(1j * phase)``.

        ``torch.istft`` expects ``(..., Freq, Time)``; any leading batch dimension is kept.

        Returns:
            The waveform with a channel axis inserted at dim 1.
        """
        complex_spec = magnitude * torch.exp(phase * 1j)
        inverse_transform = torch.istft(
            complex_spec,
            self.filter_length, self.hop_length, self.win_length,
            # Still moved defensively in case the module and the input differ in device.
            window=self.window.to(complex_spec.device),
        )
        return inverse_transform.unsqueeze(1)

def find_best_frame_shift(ref_spec, target_spec, search_range=5):
    """Estimate the frame lag between two magnitude spectrograms by cross-correlation.

    Both inputs are log-compressed (``log(x + 1e-6)``) and mean-centred along time.
    ``target_spec`` is then zero-padded by ``search_range`` frames on each side and
    correlated with ``ref_spec`` using ``F.conv1d`` (which computes cross-correlation).
    A 4-D input ``(B, S, F, T)`` is flattened to ``(B, S*F, T)``, so all sub-bands
    and bins are correlated jointly.

    Args:
        ref_spec: non-negative reference, ``(B, C, T)`` or ``(B, S, F, T)``.
        target_spec: tensor of the same shape as ``ref_spec``.
        search_range: maximum absolute lag, in frames, to test.

    Returns:
        int lag ``s`` in ``[-search_range, search_range]`` that maximizes
        ``sum_j target[..., j + s] * ref[..., j]``. A positive ``s`` means the content
        of ``target_spec`` appears ``s`` frames later than in ``ref_spec``, so
        dropping the first ``s`` target frames aligns the two. Only batch element 0
        is used.
    """
    if ref_spec.dim() == 4:
        b, s, f, t = ref_spec.shape
        ref_spec = ref_spec.reshape(b, s * f, t).contiguous()
        target_spec = target_spec.reshape(b, s * f, t).contiguous()

    ref_log = torch.log(ref_spec + 1e-6)
    target_log = torch.log(target_spec + 1e-6)
    ref_log = ref_log - torch.mean(ref_log, dim=-1, keepdim=True)
    target_log = target_log - torch.mean(target_log, dim=-1, keepdim=True)

    pad_target = F.pad(target_log, (search_range, search_range))
    # ref_log acts as the conv weight (out_channels = batch size), so the output is
    # (B, B, 2*search_range+1). Entry [0, 0] pairs target 0 with ref 0. Taking argmax
    # over the flattened tensor would give an out-of-range lag whenever B > 1.
    cross_corr = F.conv1d(pad_target, ref_log)
    max_idx = torch.argmax(cross_corr[0, 0])
    return int(max_idx.item()) - search_range

def get_text_from_phonemes(phonemes, hps):
    """Map a space-separated phoneme string (as stored in the filelist) to symbol IDs.

    Square brackets are removed and the string is split on single spaces. Tokens
    not present in ``symbols`` are silently dropped, including the empty strings
    produced by repeated spaces. When ``hps.data.add_blank`` is true, ID 0 is
    interspersed with ``commons.intersperse``.

    Returns:
        torch.LongTensor of symbol IDs.
    """
    lookup = dict(zip(symbols, range(len(symbols))))
    tokens = phonemes.replace("[", "").replace("]", "").strip().split(" ")
    ids = [lookup[tok] for tok in tokens if tok in lookup]
    if hps.data.add_blank:
        ids = commons.intersperse(ids, 0)
    return torch.LongTensor(ids)

def istft_finalize(net_g, full_complex_spec):
    """Reconstruct a waveform from a complex spectrogram stitched outside the decoder.

    Supports both single-stream and multi-stream (sub-band) iSTFT decoders.
    All work runs on the device of ``full_complex_spec``, so the function does not
    depend on any global ``device`` setting.

    Args:
        net_g: synthesizer whose ``dec`` exposes ``gen_istft_n_fft`` and
            ``gen_istft_hop_size``. For multi-stream it also exposes ``subbands``,
            plus ``updown_filter``/``multistream_conv_post`` when ``net_g.ms_istft_vits``
            is truthy.
        full_complex_spec: complex tensor. Multi-stream expects ``[B, S, F, T]``.
            Single-stream passes it straight to ``torch.istft``, so presumably
            ``[B, F, T]`` -- TODO confirm.

    Returns:
        1-D float32 numpy array: batch item 0, channel 0 of the reconstructed audio.
    """
    target_device = full_complex_spec.device
    final_spec = torch.abs(full_complex_spec)
    final_phase = torch.angle(full_complex_spec)

    stft = TorchSTFT(
        filter_length=net_g.dec.gen_istft_n_fft,
        hop_length=net_g.dec.gen_istft_hop_size,
        win_length=net_g.dec.gen_istft_n_fft
    ).to(target_device)

    subbands = getattr(net_g.dec, 'subbands', 1)
    if subbands > 1:
        # Multi-stream: [B, S, F, T] -> [B*S, F, T] so each sub-band is inverted separately.
        b, s, f, t = final_spec.shape
        y_mb_hat = stft.inverse(final_spec.reshape(b * s, f, t),
                                final_phase.reshape(b * s, f, t))  # -> [B*S, 1, Time_sub]
        y_mb_hat = y_mb_hat.squeeze(1).view(b, s, -1)             # -> [B, S, Time_sub]

        # Synthesis filter bank
        if net_g.ms_istft_vits:
            # Upsample/merge sub-bands with the decoder's own ``updown_filter``.
            y_mb_hat = F.conv_transpose1d(
                y_mb_hat,
                net_g.dec.updown_filter.to(target_device) * subbands,
                stride=subbands
            )
            audio_tensor = net_g.dec.multistream_conv_post(y_mb_hat)
        else:
            # PQMF synthesis if the module is importable; otherwise a plain sum over sub-bands.
            try:
                from pqmf import PQMF
                pqmf = PQMF(target_device)
                audio_tensor = pqmf.synthesis(y_mb_hat.unsqueeze(2))
            except ImportError:
                audio_tensor = torch.sum(y_mb_hat, dim=1, keepdim=True)
    else:
        # Single-stream iSTFT
        audio_tensor = stft.inverse(final_spec, final_phase)

    return audio_tensor[0, 0].detach().cpu().float().numpy()


# ==========================================
# 3. 各合成条件の関数
# ==========================================

# (1) Normal: whole-utterance synthesis
def synthesize_cond1(net_g, x_tst, x_tst_lengths, sid):
    """Condition (1): baseline synthesis of the whole utterance via ``net_g.infer``.

    The module-level ``noise_scale``, ``noise_scale_w`` and ``length_scale`` are
    forwarded. ``infer`` handles the multi-stream reconstruction internally.

    Returns:
        1-D float32 numpy array: batch 0, channel 0 of ``infer``'s first output.
    """
    infer_outputs = net_g.infer(
        x_tst,
        x_tst_lengths,
        sid=sid,
        noise_scale=noise_scale,
        noise_scale_w=noise_scale_w,
        length_scale=length_scale,
    )
    waveform = infer_outputs[0][0, 0]
    return waveform.data.cpu().float().numpy()

# (2) Audio Chunk: decode z in independent chunks and concatenate the waveforms
def synthesize_cond2(net_g, x_tst, x_tst_lengths, sid, chunk_size=10):
    """Condition (2): chunk the latent ``z`` and splice the decoded waveforms.

    ``z`` is masked by ``y_mask`` and then cut along its last axis into consecutive,
    non-overlapping pieces of ``chunk_size`` frames. Each piece is decoded on its
    own by ``net_g.dec``, and the waveforms are concatenated with no overlap or
    smoothing.

    Returns:
        1-D float32 numpy array: batch 0, channel 0 of the concatenated waveform.
    """
    _attn, y_mask, (z, _z_p, _m_p, _logs_p), _timings = net_g.infer_z_only(
        x_tst,
        x_tst_lengths,
        sid=sid,
        noise_scale=noise_scale,
        noise_scale_w=noise_scale_w,
        length_scale=length_scale,
    )
    masked_z = z * y_mask

    speaker_cond = None
    if net_g.n_speakers > 0:
        speaker_cond = net_g.emb_g(sid).unsqueeze(-1)

    n_frames = masked_z.shape[2]
    waveforms = []
    for start in range(0, n_frames, chunk_size):
        # Slicing clamps at n_frames, so the last piece may be shorter.
        wave, _, _, _ = net_g.dec(masked_z[:, :, start:start + chunk_size], g=speaker_cond)
        waveforms.append(wave)

    return torch.cat(waveforms, dim=2)[0, 0].data.cpu().float().numpy()

# (3) Spec Fixed Ratio OLA
def synthesize_cond3(net_g, x_tst, x_tst_lengths, sid, z_chunk_size=10, z_hop_size=5):
    """Condition (3): chunked decoding with spectrogram overlap-add at a fixed frame ratio.

    The masked latent ``z`` is cut into windows of ``z_chunk_size`` frames, advanced
    by ``z_hop_size`` frames. Each window is decoded into a complex spectrogram
    ``spec * exp(1j * phase)``. That spectrogram is appended to the running one, with
    a linear cross-fade over the overlapping frames. The spectrogram/latent frame
    ratio is measured once, on the first window, and converts ``z_hop_size`` into
    spectrogram frames. No alignment correction is applied.

    Args:
        net_g: synthesizer exposing ``infer_z_only``, ``dec``, ``n_speakers`` and ``emb_g``.
        x_tst, x_tst_lengths, sid: model inputs forwarded to ``infer_z_only``.
        z_chunk_size: window length in latent frames.
        z_hop_size: window advance in latent frames.

    Returns:
        Whatever ``istft_finalize`` returns for the stitched spectrogram.

    Raises:
        ValueError: if the latent sequence is empty (nothing to decode).
    """
    _, y_mask, (z, _, _, _), _ = net_g.infer_z_only(
        x_tst, x_tst_lengths, sid=sid, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale
    )
    z = z * y_mask
    g = net_g.emb_g(sid).unsqueeze(-1) if net_g.n_speakers > 0 else None

    z_total = z.shape[2]
    if z_total == 0:
        raise ValueError("latent sequence is empty; nothing to decode")
    full_spec = None
    ratio = None

    for idx in range(0, z_total, z_hop_size):
        z_chunk = z[:, :, idx:min(idx + z_chunk_size, z_total)]
        _, _, spec, phase = net_g.dec(z_chunk, g=g)
        comp_chunk = spec * torch.exp(1j * phase)

        # Spectrogram frames per latent frame, measured on the first window only.
        if ratio is None:
            ratio = comp_chunk.shape[-1] / z_chunk.shape[-1] if z_chunk.shape[-1] > 0 else 1.0

        if full_spec is None:
            full_spec = comp_chunk
        else:
            hop = int(z_hop_size * ratio)
            overlap = comp_chunk.shape[-1] - hop
            # The fade cannot be longer than what has been accumulated so far.
            ov_len = min(overlap, full_spec.shape[-1]) if overlap > 0 else 0
            if ov_len > 0:
                # Fade-in weights shaped to broadcast over every leading dim, so both
                # [B, F, T] and [B, S, F, T] spectrograms work. A fixed 4-D view would
                # promote a 3-D spectrogram to 4-D and break the concatenation.
                alpha = torch.linspace(0.0, 1.0, ov_len, device=full_spec.device)
                alpha = alpha.view(*([1] * (full_spec.dim() - 1)), ov_len)
                merged = full_spec[..., -ov_len:] * (1 - alpha) + comp_chunk[..., :ov_len] * alpha
                full_spec = torch.cat([full_spec[..., :-ov_len], merged, comp_chunk[..., ov_len:]], dim=-1)
            else:
                full_spec = torch.cat([full_spec, comp_chunk], dim=-1)
        if idx + z_chunk_size >= z_total:
            break

    return istft_finalize(net_g, full_spec)

# (4) Spec Corrected Ratio OLA
def synthesize_cond4(net_g, x_tst, x_tst_lengths, sid, z_chunk_size=10, z_hop_size=5, search_range=2):
    """Condition (4): like condition (3), plus cross-correlation alignment before each cross-fade.

    For every window after the first, the overlapping region is compared in
    magnitude: the running spectrogram's tail against the new chunk's head. The
    comparison uses ``find_best_frame_shift`` with ``search_range``. A positive lag
    (new chunk lags) drops that many leading frames of the new chunk. A negative lag
    would need left padding, which is not done (clipped to 0).

    Args:
        net_g: synthesizer exposing ``infer_z_only``, ``dec``, ``n_speakers`` and ``emb_g``.
        x_tst, x_tst_lengths, sid: model inputs forwarded to ``infer_z_only``.
        z_chunk_size: window length in latent frames.
        z_hop_size: window advance in latent frames.
        search_range: maximum lag (spectrogram frames) searched.

    Returns:
        Whatever ``istft_finalize`` returns for the stitched spectrogram.

    Raises:
        ValueError: if the latent sequence is empty (nothing to decode).
    """
    _, y_mask, (z, _, _, _), _ = net_g.infer_z_only(
        x_tst, x_tst_lengths, sid=sid, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale
    )
    z = z * y_mask
    g = net_g.emb_g(sid).unsqueeze(-1) if net_g.n_speakers > 0 else None

    z_total = z.shape[2]
    if z_total == 0:
        raise ValueError("latent sequence is empty; nothing to decode")
    full_spec = None
    ratio = None

    for idx in range(0, z_total, z_hop_size):
        z_chunk = z[:, :, idx:min(idx + z_chunk_size, z_total)]
        _, _, spec, phase = net_g.dec(z_chunk, g=g)
        comp_chunk = spec * torch.exp(1j * phase)

        # Spectrogram frames per latent frame, measured on the first window only.
        if ratio is None:
            ratio = comp_chunk.shape[-1] / z_chunk.shape[-1] if z_chunk.shape[-1] > 0 else 1.0

        if full_spec is None:
            full_spec = comp_chunk
        else:
            hop = int(z_hop_size * ratio)
            overlap = comp_chunk.shape[-1] - hop
            if overlap <= 0:
                full_spec = torch.cat([full_spec, comp_chunk], dim=-1)
            else:
                prev = full_spec[..., -overlap:]
                curr = comp_chunk[..., :overlap]

                shift = 0
                # Align only when both sides cover the full overlap.
                if prev.shape[-1] == overlap and curr.shape[-1] == overlap:
                    shift = find_best_frame_shift(torch.abs(prev), torch.abs(curr), search_range)

                # find_best_frame_shift returns s > 0 when the new chunk's content lags the
                # existing tail by s frames, so dropping its first s frames realigns it.
                # (Using -shift here would trim on the wrong sign and double the misalignment.)
                start_off = max(0, min(shift, search_range))
                aligned = comp_chunk[..., start_off:]
                cross_len = min(overlap, aligned.shape[-1], full_spec.shape[-1])

                if cross_len > 0:
                    # Fade-in weights broadcast over all leading dims ([B, F, T] or [B, S, F, T]).
                    alpha = torch.linspace(0.0, 1.0, cross_len, device=full_spec.device)
                    alpha = alpha.view(*([1] * (full_spec.dim() - 1)), cross_len)
                    merged = full_spec[..., -cross_len:] * (1 - alpha) + aligned[..., :cross_len] * alpha
                    full_spec = torch.cat([full_spec[..., :-cross_len], merged, aligned[..., cross_len:]], dim=-1)
                else:
                    full_spec = torch.cat([full_spec, aligned], dim=-1)
        if idx + z_chunk_size >= z_total:
            break

    return istft_finalize(net_g, full_spec)


# ==========================================
# 4. Main loop
# ==========================================
# Renders every utterance in `input_txt_path` under conditions (1)-(4) and writes
# one WAV per condition into `output_dir`.
os.makedirs(output_dir, exist_ok=True)
print(f"Output directory created: {output_dir}")

hps = utils.get_hparams_from_file(config_path)
net_g = SynthesizerTrn(
    len(symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    n_speakers=hps.data.n_speakers,
    **hps.model).to(device)
_ = net_g.eval()
utils.load_checkpoint(checkpoint_path, net_g, None)
print("Model loaded.")

# (file-name suffix, synthesis function), rendered in this order for every row.
conditions = [
    ("cond1", synthesize_cond1),
    ("cond2", synthesize_cond2),
    ("cond3", synthesize_cond3),
    ("cond4", synthesize_cond4),
]

with open(input_txt_path, "r", encoding="utf-8") as f:
    lines = f.readlines()

print(f"Start processing {len(lines)} lines...")

for i, line in enumerate(lines):
    line = line.strip()
    if not line:
        continue
    # Filelist row: <wav path>|<speaker id>|<space-separated phonemes>[|...]
    parts = line.split("|")
    if len(parts) < 3:
        continue

    # Strip only a trailing ".wav". str.replace would also delete ".wav" found
    # elsewhere in the file name.
    original_fname = os.path.basename(parts[0])
    if original_fname.endswith(".wav"):
        original_fname = original_fname[:-len(".wav")]

    try:
        # Parsing happens inside the try, so a malformed row is reported and
        # skipped instead of aborting the whole batch.
        spk_id = int(parts[1])
        phonemes = parts[2]
        stn_tst = get_text_from_phonemes(phonemes, hps)
        with torch.no_grad():
            x_tst = stn_tst.to(device).unsqueeze(0)
            x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
            sid = torch.LongTensor([spk_id]).to(device)

            for suffix, synthesize in conditions:
                audio = synthesize(net_g, x_tst, x_tst_lengths, sid)
                write(os.path.join(output_dir, f"{original_fname}_{suffix}.wav"), hps.data.sampling_rate, audio)

        print(f"[{i+1}/{len(lines)}] Saved: {original_fname}")
    except Exception as e:
        print(f"Error processing {original_fname}: {e}")
        import traceback
        traceback.print_exc()

print("All tasks finished.")

Using device: cuda
Output directory created: output_wavs_batch
Mutli-stream iSTFT VITS
Model loaded.
Start processing 16 lines...
[1/16] Saved: FJK_C051_118
[2/16] Saved: FJK_C051_170
[3/16] Saved: FKC_C031_002
[4/16] Saved: FMS_C051_072
[5/16] Saved: FMT_C041_134
[6/16] Saved: FMT_C041_259
[7/16] Saved: FTH_C004_044
[8/16] Saved: FTH_C005_152
[9/16] Saved: FTS_C002_107
[10/16] Saved: FTS_C002_175
[11/16] Saved: FTS_C004_126
[12/16] Saved: FTS_C006_050
[13/16] Saved: FTS_C007_137
[14/16] Saved: FUE_C033_134
[15/16] Saved: FYH_C042_090
[16/16] Saved: FYH_C043_064
All tasks finished.


In [None]:
def synthesize_cond5(net_g, x_tst, x_tst_lengths, sid, z_chunk_size=15, z_hop_size=5, search_range=2):
    """Condition (5): chunked decoding, aligned against the previous chunk's raw tail.

    This is like condition (4), except for the alignment reference. Here the
    reference is the tail of the previously placed chunk (shift applied, not yet
    cross-faded), not the already-blended tail of the running spectrogram. The
    cross-fade itself still blends the running spectrogram's tail with the new chunk.

    A positive lag from ``find_best_frame_shift`` (new chunk lags) drops that many
    leading frames of the new chunk. A negative lag would need left padding, which
    is not done (clipped to 0).

    Args:
        net_g: synthesizer exposing ``infer_z_only``, ``dec``, ``n_speakers`` and ``emb_g``.
        x_tst, x_tst_lengths, sid: model inputs forwarded to ``infer_z_only``.
        z_chunk_size: window length in latent frames.
        z_hop_size: window advance in latent frames.
        search_range: maximum lag (spectrogram frames) searched.

    Returns:
        Whatever ``istft_finalize`` returns for the stitched spectrogram.

    Raises:
        ValueError: if the latent sequence is empty (nothing to decode).
    """
    # 1. Generate the latent z (shared with the other conditions).
    _, y_mask, (z, _, _, _), _ = net_g.infer_z_only(
        x_tst, x_tst_lengths, sid=sid, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale
    )
    z = z * y_mask
    g = net_g.emb_g(sid).unsqueeze(-1) if net_g.n_speakers > 0 else None

    z_total = z.shape[2]
    if z_total == 0:
        raise ValueError("latent sequence is empty; nothing to decode")
    full_spec = None
    ratio = None

    # Tail of the previously placed chunk (shifted, not cross-faded): the alignment reference.
    prev_raw_tail = None

    # Overlap between consecutive windows in latent frames (e.g. 15 - 5 = 10).
    # Clamped to 0 when the hop exceeds the window.
    z_overlap_len = max(0, z_chunk_size - z_hop_size)

    for idx in range(0, z_total, z_hop_size):
        z_chunk = z[:, :, idx:min(idx + z_chunk_size, z_total)]

        # Decode -> complex spectrogram
        _, _, spec, phase = net_g.dec(z_chunk, g=g)
        comp_chunk = spec * torch.exp(1j * phase)

        # Spectrogram frames per latent frame, measured on the first window only.
        if ratio is None:
            ratio = comp_chunk.shape[-1] / z_chunk.shape[-1] if z_chunk.shape[-1] > 0 else 1.0

        # Overlap length expressed in spectrogram frames.
        spec_overlap_len = int(z_overlap_len * ratio)

        if full_spec is None:
            placed = comp_chunk
            full_spec = comp_chunk
        else:
            # --- Alignment by cross-correlation ---
            curr_ref = comp_chunk[..., :spec_overlap_len]

            shift = 0
            # Correlate only when both sides cover the full (non-empty) overlap.
            if (spec_overlap_len > 0
                    and prev_raw_tail is not None
                    and prev_raw_tail.shape[-1] == spec_overlap_len
                    and curr_ref.shape[-1] == spec_overlap_len):
                shift = find_best_frame_shift(torch.abs(prev_raw_tail), torch.abs(curr_ref), search_range)

            # shift > 0: the new chunk lags, so trim its first `shift` frames.
            # shift < 0: it would need left padding, which is clipped to 0 here.
            start_off = max(0, min(shift, search_range))
            placed = comp_chunk[..., start_off:]

            # --- Cross-fade ---
            cross_len = min(spec_overlap_len, placed.shape[-1], full_spec.shape[-1])
            if cross_len > 0:
                # Fade-in weights broadcast over all leading dims ([B, F, T] or [B, S, F, T]).
                alpha = torch.linspace(0.0, 1.0, cross_len, device=full_spec.device)
                alpha = alpha.view(*([1] * (full_spec.dim() - 1)), cross_len)
                # Blend the running (already mixed) tail with the raw head of the new chunk.
                merged = full_spec[..., -cross_len:] * (1 - alpha) + placed[..., :cross_len] * alpha
                full_spec = torch.cat([full_spec[..., :-cross_len], merged, placed[..., cross_len:]], dim=-1)
            else:
                # No overlap: plain concatenation.
                full_spec = torch.cat([full_spec, placed], dim=-1)

        # Keep the raw (shifted, not cross-faded) tail of the chunk just placed for the next step.
        if spec_overlap_len > 0 and placed.shape[-1] >= spec_overlap_len:
            prev_raw_tail = placed[..., -spec_overlap_len:]
        else:
            prev_raw_tail = placed

        if idx + z_chunk_size >= z_total:
            break

    return istft_finalize(net_g, full_spec)