# Unit-to-Unit Translation Inference (EN → KO)

**Pipeline**: EN unit (text) → UTUT Translation → KO unit (predicted) → CodeHiFiGAN Vocoder → Waveform

**Ground Truth**: KO unit (text) & KO WAV for comparison

**Data**: `aihub_a2a_unit` (en/ko unit text) + `aihub_a2a_wav` (ko wav)

In [1]:
import os
import sys
import json
import subprocess
import numpy as np
import torch
import soundfile as sf
import IPython.display as ipd
from collections import defaultdict

# Project paths
ROOT_DIR = "/home/2022113135"
PROJECT_DIR = os.path.join(ROOT_DIR, "gyucheol/NetfLips/av2av-main")
JJS_DIR = os.path.join(ROOT_DIR, "jjs/av2av")

INFERENCE_SCRIPT = os.path.join(PROJECT_DIR, "inference_unit2a.py")

sys.path.insert(0, PROJECT_DIR)
sys.path.insert(0, JJS_DIR)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ============================================================
# Configuration
# ============================================================

# --- Data Paths ---
EN_WAV_DIR = os.path.join(ROOT_DIR, "datasets/aihub_a2a_wav/test/en")    # SOURCE
EN_UNIT_DIR = os.path.join(ROOT_DIR, "datasets/aihub_a2a_unit/test/en")
KO_UNIT_DIR = os.path.join(ROOT_DIR, "datasets/aihub_a2a_unit/test/ko")   # GT
KO_WAV_DIR  = os.path.join(ROOT_DIR, "datasets/aihub_a2a_wav/test/ko")    # GT

# --- UTUT (Unit-to-Unit Translation) ---
UTUT_CHECKPOINT = os.path.join(
    JJS_DIR, "unit2unit/utut_finetune/utut_additional_ckpt/unit_mbart_multilingual_ft/en_ko/checkpoint_best.pt"
)
SRC_LANG = "en"
TGT_LANG = "ko"

# --- Vocoder (CodeHiFiGAN) ---
VOCODER_CHECKPOINT = os.path.join(PROJECT_DIR, "unit2av/checkpoint/zeroth-hubert/g_00500000")
VOCODER_CONFIG     = os.path.join(PROJECT_DIR, "unit2av/checkpoint/zeroth-hubert/config.json")

# --- Speaker Encoder ---
SPEAKER_ENCODER_PATH = os.path.join(PROJECT_DIR, "unit2av/encoder.pt")

# --- Output ---
OUTPUT_DIR = os.path.join(PROJECT_DIR, "output/unit2unit_inference")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Inference settings ---
MAX_SAMPLES_PER_SUBJECT = 1   # max samples to display per subject
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {DEVICE}")
print(f"UTUT checkpoint: {UTUT_CHECKPOINT}")
print(f"Vocoder checkpoint: {VOCODER_CHECKPOINT}")

Device: cuda
UTUT checkpoint: /home/2022113135/jjs/av2av/unit2unit/utut_finetune/utut_additional_ckpt/unit_mbart_multilingual_ft/en_ko/checkpoint_best.pt
Vocoder checkpoint: /home/2022113135/gyucheol/NetfLips/av2av-main/unit2av/checkpoint/zeroth-hubert/g_00500000


## 1. Find All Subjects & Build File Triplets

In [3]:
# Scan EN unit dir and find matching KO unit + KO wav triplets
en_files = sorted([f for f in os.listdir(EN_UNIT_DIR) if f.endswith('.txt')])

triplets = []          # list of dicts
subject_map = defaultdict(list)  # subject -> list of triplet indices

for fname in en_files:
    base = fname[:-4]  # strip .txt
    en_wav_path = os.path.join(EN_WAV_DIR, base + "_en.wav")
    en_unit_path = os.path.join(EN_UNIT_DIR, fname)
    ko_unit_path = os.path.join(KO_UNIT_DIR, fname)           # same filename
    ko_wav_path  = os.path.join(KO_WAV_DIR, base + ".wav")   # .txt -> .wav

    if not os.path.exists(ko_unit_path):
        continue
    if not os.path.exists(ko_wav_path):
        continue

    # Extract subject: e.g. 'et_c_005' from 'et_c_005_002_009_0026'
    parts = base.split('_')
    subject = '_'.join(parts[:3])  # et_c_005, et_k_003, ...

    triplets.append({
        "id": base,
        "subject": subject,
        "en_wav_path": en_wav_path,
        "en_unit_path": en_unit_path,
        "ko_unit_path": ko_unit_path,
        "ko_wav_path": ko_wav_path,
    })
    subject_map[subject].append(len(triplets) - 1)

subjects = sorted(subject_map.keys())

print(f"Total triplets found: {len(triplets)}")
print(f"Total subjects: {len(subjects)}")
print(f"\nSubjects (count):")
for s in subjects:
    print(f"  {s}: {len(subject_map[s])} files")

Total triplets found: 9412
Total subjects: 32

Subjects (count):
  et_c_005: 27 files
  et_c_010: 4 files
  et_c_012: 58 files
  et_c_014: 106 files
  et_k_003: 3173 files
  et_k_004: 408 files
  et_m_008: 328 files
  et_m_010: 714 files
  iv_K_018: 11 files
  iv_k_018: 1939 files
  iv_k_019: 381 files
  iv_m_001: 712 files
  md_c_001: 30 files
  md_c_002: 29 files
  md_c_007: 46 files
  md_c_012: 24 files
  md_c_013: 22 files
  md_c_016: 30 files
  md_c_018: 44 files
  md_k_018: 52 files
  md_p_001: 343 files
  md_s_003: 46 files
  md_s_005: 49 files
  md_s_006: 48 files
  md_s_007: 35 files
  md_s_008: 32 files
  md_s_010: 80 files
  md_s_011: 40 files
  md_s_014: 30 files
  md_s_016: 42 files
  md_t_001: 251 files
  md_t_002: 278 files


In [4]:
# Select a subset: up to MAX_SAMPLES_PER_SUBJECT per subject
selected_triplets = []
for s in subjects:
    indices = subject_map[s][:MAX_SAMPLES_PER_SUBJECT]
    for idx in indices:
        selected_triplets.append(triplets[idx])

print(f"Selected {len(selected_triplets)} samples for inference ({MAX_SAMPLES_PER_SUBJECT} per subject)")

Selected 32 samples for inference (1 per subject)


## 2. Load Models

- **UTUT**: loaded in-process (fairseq)
- **Speaker Encoder**: loaded in-process (to extract speaker embedding → `.pt`)
- **Vocoder**: called via `inference_unit2a.py` subprocess

In [5]:
# --- 2-1. Load UTUT Translation Model ---
from unit2unit.inference import load_model as load_utut_model
from util import process_units
from fairseq import utils
from fairseq_cli.generate import get_symbols_to_strip_from_output

print("Loading UTUT translation model...")
utut_task, utut_generator = load_utut_model(
    UTUT_CHECKPOINT, SRC_LANG, TGT_LANG, use_cuda=(DEVICE == "cuda")
)
print("UTUT model loaded.")

Loading UTUT translation model...


2026-02-04 06:51:25 | INFO | fairseq.tasks.translation | [en] dictionary: 1004 types
2026-02-04 06:51:25 | INFO | fairseq.tasks.translation | [ko] dictionary: 1004 types


UTUT model loaded.


In [6]:
# # --- 2-2. Load CodeHiFiGAN Vocoder ---
# from unit2av.model import CodeHiFiGANModel_spk
# from unit2av.utils import AttrDict

# print(f"Loading vocoder config from {VOCODER_CONFIG}...")
# with open(VOCODER_CONFIG) as f:
#     h = AttrDict(json.loads(f.read()))

# print("Initializing CodeHiFiGAN vocoder...")
# vocoder = CodeHiFiGANModel_spk(dict(h)).to(DEVICE)

# state_dict = torch.load(VOCODER_CHECKPOINT, map_location=DEVICE)
# if 'generator' in state_dict:
#     vocoder.load_state_dict(state_dict['generator'])
# else:
#     vocoder.load_state_dict(state_dict)

# vocoder.eval()
# vocoder.remove_weight_norm()
# print("Vocoder loaded.")

In [7]:
# --- 2-3. Load Speaker Encoder ---
from unit2av.model_speaker_encoder import SpeakerEncoder
SPEAKER_ENCODER_PATH="/home/2022113135/gyucheol/NetfLips/av2av-main/unit2av/encoder.pt"

print("Loading speaker encoder...")
speaker_encoder = SpeakerEncoder(SPEAKER_ENCODER_PATH)
if DEVICE == "cuda":
    speaker_encoder = speaker_encoder.cuda()
print("Speaker encoder loaded.")

  warn("Unable to import 'webrtcvad'. This package enables noise removal and is recommended.")


Loading speaker encoder...
Speaker encoder loaded.


## 3. Inference Functions

In [8]:
def read_unit_text(path):
    """Read a unit text file and return list of int."""
    with open(path) as f:
        units = list(map(int, f.readline().strip().split()))
    return units


def translate_units(en_units, utut_task, utut_generator, use_cuda=True):
    """
    Unit-to-Unit Translation: EN units -> KO units.
    Returns predicted unit string and list of ints.
    """
    # Reduce consecutive duplicates and encode
    reduced = process_units(en_units, reduce=True)
    unit_tensor = utut_task.source_dictionary.encode_line(
        " ".join(map(str, reduced)),
        add_if_not_exist=False,
        append_eos=True,
    ).long()

    # Prepend BOS, append source language tag
    unit_tensor = torch.cat([
        unit_tensor.new([utut_task.source_dictionary.bos()]),
        unit_tensor,
        unit_tensor.new([utut_task.source_dictionary.index(f"[{SRC_LANG}]")])
    ])

    sample = {"net_input": {
        "src_tokens": unit_tensor.view(1, -1),
    }}
    if use_cuda:
        sample = utils.move_to_cuda(sample)

    # Run translation
    pred = utut_task.inference_step(
        utut_generator,
        None,
        sample,
    )[0][0]

    # Decode predicted tokens to unit string
    pred_str = utut_task.target_dictionary.string(
        pred["tokens"].int().cpu(),
        extra_symbols_to_ignore=get_symbols_to_strip_from_output(utut_generator)
    )

    # Convert to list of int
    pred_units = [int(x) for x in pred_str.strip().split() if x.isdigit()]
    return pred_units, pred_str


def save_input_pt(units, ko_wav_path, speaker_encoder, output_pt_path):
    """
    Save predicted units + speaker embedding as .pt file
    for inference_unit2a.py consumption.
    """
    code = torch.LongTensor(units)
    spkr_embed = speaker_encoder.get_embed(ko_wav_path)
    # Before (2D — causes IndexError on transpose(1,2)):
    # spkr_tensor = torch.from_numpy(spkr_embed).float().view(1, -1)     # (1, 256)

    # After (3D — matches what the model expects):
    spkr_tensor = torch.from_numpy(spkr_embed).float().view(1, 1, -1)  # (1, 1, 256)


    torch.save({"code": code, "spkr": spkr_tensor}, output_pt_path)


def run_vocoder(input_pt_path, output_dir, device="cuda"):
    """
    Call inference_unit2a.py via subprocess to synthesize waveform.
    Returns the generated wav path.
    """
    command = [
        "python", INFERENCE_SCRIPT,
        "--checkpoint", VOCODER_CHECKPOINT,
        "--config", VOCODER_CONFIG,
        "--input_file", input_pt_path,
        "--output_folder", output_dir,
        "--device", device,
    ]
    result = subprocess.run(command, capture_output=True, text=True, cwd=PROJECT_DIR)
    if result.returncode != 0:
        print(f"  [ERROR] {result.stderr}")
    return result

## 4. Run Inference: Unit-to-Unit Translation + Waveform Synthesis

In [None]:
results = []
use_cuda = (DEVICE == "cuda")

# Temp dir for .pt files passed to inference_unit2a.py
PT_DIR = os.path.join(OUTPUT_DIR, "pt_inputs")
os.makedirs(PT_DIR, exist_ok=True)

for i, triplet in enumerate(selected_triplets):
    sample_id = triplet["id"]
    print(f"[{i+1}/{len(selected_triplets)}] {sample_id}")

    # 1) Read EN units
    en_units = read_unit_text(triplet["en_unit_path"])

    # 2) Read GT KO units
    gt_ko_units = read_unit_text(triplet["ko_unit_path"])

    # 3) UTUT Translation: EN -> KO (predicted)
    pred_ko_units, pred_ko_str = translate_units(
        en_units, utut_task, utut_generator, use_cuda=use_cuda
    )

    # 4) Save predicted units + speaker embedding as .pt
    #    inference_unit2a.py strips last 13 chars ("_preprocessed") from basename
    pt_path = os.path.join(PT_DIR, f"{sample_id}_preprocessed.pt")
    save_input_pt(pred_ko_units, triplet["ko_wav_path"], speaker_encoder, pt_path)

    # 5) Run vocoder via inference_unit2a.py subprocess
    run_vocoder(pt_path, OUTPUT_DIR, device=DEVICE)

    # 6) Locate generated wav (inference_unit2a.py naming: {base}_{step}step.wav)
    output_wav_path = os.path.join(OUTPUT_DIR, f"{sample_id}_500000step.wav")

    # 7) Save predicted unit text
    output_unit_path = os.path.join(OUTPUT_DIR, f"{sample_id}_pred_unit.txt")
    with open(output_unit_path, 'w') as f:
        f.write(' '.join(map(str, pred_ko_units)))

    results.append({
        **triplet,
        "en_units": en_units,
        "gt_ko_units": gt_ko_units,
        "pred_ko_units": pred_ko_units,
        "synth_wav_path": output_wav_path,
    })

print(f"\nDone. {len(results)} samples processed.")
print(f"Output directory: {OUTPUT_DIR}")

[1/32] et_c_005_002_009_0026
[2/32] et_c_010_002_004_0014


## 5. Display Results

In [None]:
current_subject = None

print("-" * 80)
print("Unit-to-Unit Translation Inference Results (EN -> KO)")
print("-" * 80)

for r in results:
    # Subject header
    if r["subject"] != current_subject:
        current_subject = r["subject"]
        print("\n" + "#" * 80)
        print(f"  SUBJECT: {current_subject}")
        print("#" * 80 + "\n")

    print("=" * 70)
    print(f"  File ID: {r['id']}")
    print(f"  EN units (first 20):  {r['en_units'][:20]} ...")
    print(f"\n  GT KO units (first 20): {r['gt_ko_units'][:20]} ...")
    print(f"\n  Pred KO units (first 20): {r['pred_ko_units'][:20]} ...")
    print(f"  EN unit length: {len(r['en_units'])} | GT KO: {len(r['gt_ko_units'])} | Pred KO: {len(r['pred_ko_units'])}")
    print()

    # Source EN audio
    print(f"  [Src] English Audio: {os.path.basename(r['en_wav_path'])}")
    ipd.display(ipd.Audio(filename=r['en_wav_path']))


    # GT KO audio
    print(f"  [GT] Korean Audio: {os.path.basename(r['ko_wav_path'])}")
    ipd.display(ipd.Audio(r["ko_wav_path"], rate=16000))

    # Synthesized audio
    synth_path = r["synth_wav_path"]
    if os.path.exists(synth_path):
        print(f"  [Pred] Synthesized Audio: {os.path.basename(synth_path)}")
        ipd.display(ipd.Audio(synth_path, rate=16000))
    else:
        print(f"  [Pred] Generated file not found at: {synth_path}")

    print()

--------------------------------------------------------------------------------
Unit-to-Unit Translation Inference Results (EN -> KO)
--------------------------------------------------------------------------------

################################################################################
  SUBJECT: et_c_005
################################################################################

  File ID: et_c_005_002_009_0026
  EN units (first 20):  [501, 501, 501, 501, 501, 501, 991, 991, 501, 501, 501, 501, 501, 501, 501, 501, 501, 501, 501, 501] ...

  GT KO units (first 20): [43, 843, 474, 825, 825, 825, 825, 825, 681, 359, 874, 822, 255, 416, 565, 565, 565, 565, 217, 217] ...

  Pred KO units (first 20): [501, 501, 501, 501, 501, 501, 501, 501, 501, 501, 501, 501, 501, 501, 501, 501, 501, 501, 501, 501] ...
  EN unit length: 334 | GT KO: 140 | Pred KO: 199

  [Src] English Audio: et_c_005_002_009_0026.wav


FileNotFoundError: [Errno 2] No such file or directory: '/home/2022113135/datasets/aihub_a2a_wav/test/en/et_c_005_002_009_0026.wav'

In [None]:
ipd.display(ipd.Audio('/home/2022113135/datasets/aihub_a2a_wav/test/en/et_c_005_002_009_0026_en.wav'))