In [1]:
import os
import csv
import json

import numpy as np
# import matplotlib.pyplot as plt

# import sacrebleu
import soundfile as sf

import copy
import yaml
import math
import torch
from tqdm.notebook import tqdm

from IPython.display import display, Audio

def read_logs(path):
    """Load a JSON-lines log file.

    Args:
        path: Path of a text file holding one JSON document per line.

    Returns:
        A list with the decoded object of every non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Decode explicitly as UTF-8: the logs hold non-ASCII text and the
    # platform default encoding is locale dependent.
    with open(path, "r", encoding="utf-8") as log_file:
        # Iterate the handle lazily instead of materialising all lines first.
        stripped_lines = (raw_line.strip() for raw_line in log_file)
        return [json.loads(line) for line in stripped_lines if line != ""]

def write_logs(logs, path):
    """Dump ``logs`` to ``path`` as JSON lines.

    Each element of ``logs`` is serialised with ``json.dumps`` and written on
    its own line; an existing file at ``path`` is overwritten.
    """
    with open(path, "w") as sink:
        sink.writelines(json.dumps(entry) + "\n" for entry in logs)

def read_wav(wav_path):
    """Read audio through ``soundfile``, optionally restricted to a slice.

    ``wav_path`` is either a plain file path or ``"<path>:<offset>:<duration>"``.
    In the second form the two integer fields are forwarded to ``sf.read`` as
    ``start`` and ``frames``; otherwise ``start=0`` and ``frames=-1`` are used.

    Returns:
        The ``(data, samplerate)`` pair produced by ``sf.read``.
    """
    first_frame, n_frames = 0, -1
    if ':' in wav_path:
        wav_path, offset_field, duration_field = wav_path.split(':')
        first_frame = int(offset_field)
        n_frames = int(duration_field)
    data, samplerate = sf.read(wav_path, start=first_frame, frames=n_frames)
    return data, samplerate

def read_tsv(tsv_path):
    """Parse a tab-separated file with a header row into a list of dicts.

    Quoting is disabled (``csv.QUOTE_NONE``), so quote characters in a field
    are kept verbatim. Every data row becomes a plain ``dict`` keyed by the
    header names.
    """
    reader_options = dict(
        delimiter="\t",
        quotechar=None,
        doublequote=False,
        lineterminator="\n",
        quoting=csv.QUOTE_NONE,
    )
    with open(tsv_path) as handle:
        return [dict(row) for row in csv.DictReader(handle, **reader_options)]

def write_tsv(samples, tsv_path):
    """Write a list of dicts to ``tsv_path`` as a tab-separated file.

    The header is taken from the keys of the first sample; quoting is
    disabled (``csv.QUOTE_NONE``). ``samples`` must be non-empty.
    """
    writer_options = dict(
        delimiter="\t",
        quotechar=None,
        doublequote=False,
        lineterminator="\n",
        quoting=csv.QUOTE_NONE,
    )
    with open(tsv_path, "w") as sink:
        header = samples[0].keys()
        tsv_writer = csv.DictWriter(sink, header, **writer_options)
        tsv_writer.writeheader()
        tsv_writer.writerows(samples)

def play(audio_path):
    """Show an inline audio player for ``audio_path``.

    ``audio_path`` is handed to ``read_wav``, so the ``"<path>:<offset>:<duration>"``
    form is accepted as well. The player uses the sample rate reported by
    ``read_wav`` rather than a fixed 16 kHz, so files recorded at another
    rate play back at the correct speed.
    """
    waveform, sample_rate = read_wav(audio_path)
    display(Audio(waveform, rate=sample_rate))

In [2]:
# Load the training manifest through the shared helper instead of repeating
# the csv.DictReader set-up inline.
tsv_path = "/compute/babel-14-5/for_daniel/en-zh/train_nospeaker_traj_30_filtered.tsv"
samples = read_tsv(tsv_path)

In [3]:
len(samples)

106515

In [4]:
# Shuffle the sample indices with a fixed seed so that the sample inspected
# below is the same after every Restart & Run All.
SHUFFLE_SEED = 0
indices = list(range(len(samples)))
np.random.default_rng(SHUFFLE_SEED).shuffle(indices)

In [5]:
# Inspect one shuffled sample: play the audio in 0.96 s chunks and print the
# trajectory entry paired with each chunk.
sample = samples[indices[0]]
wav, sr = read_wav(sample['audio'])
# NOTE(review): eval runs whatever expression the TSV column contains.
trajectory = eval(sample['trajectory'])

step = int(sr * 0.96)
chunk_starts = range(0, len(wav), step)
for chunk_no, (start, action) in enumerate(zip(chunk_starts, trajectory)):
    chunk_audio = wav[start : start + step]
    display(Audio(chunk_audio, rate=sr, autoplay=False))
    print(chunk_no, "[T_START]", action, "[T_END]")

0 [T_START]  [T_END]


1 [T_START] 据疾病 [T_END]


2 [T_START] 防控中心（CDC）的 [T_END]


3 [T_START] 统计数据显示， [T_END]


4 [T_START]  在 [T_END]


5 [T_START] 2000年出生的新生儿人群中 [T_END]


6 [T_START] ， （这些孩子 [T_END]


7 [T_START] 现在大多7-8岁 [T_END]


8 [T_START] 左右） 每 [T_END]


9 [T_START] 三个白种人， [T_END]


10 [T_START]  每两个 [T_END]


11 [T_START] 非洲裔美国 [T_END]


12 [T_START] 人和西班牙 [T_END]


13 [T_START] 裔美国人中 就会有一个孩子患上 [T_END]


14 [T_START] 糖尿病 [T_END]


15 [T_START] 。  [T_END]


16 [T_START] 如果 [T_END]


17 [T_START] 这还不足以引起警惕的话， [T_END]


18 [T_START] 疾病防控中心进一步 [T_END]


19 [T_START] 表明， [T_END]


20 [T_START]  糖尿病多 [T_END]


21 [T_START] 出现在这些孩子们高中毕业之前。  [T_END]


22 [T_START] 这 [T_END]


23 [T_START] 就意味着 [T_END]


24 [T_START]  [T_END]


25 [T_START] 40%或45%的 [T_END]


26 [T_START]  [T_END]


27 [T_START]  学龄儿童 [T_END]


28 [T_START]  [T_END]


29 [T_START]  [T_END]


In [111]:
def _wrap_words(draw, text, font, max_width):
    """Greedily pack the whitespace-separated words of ``text`` into lines at most ``max_width`` pixels wide."""
    lines = []
    current_line = []
    for word in text.split():
        candidate = " ".join(current_line + [word])
        left, _, right, _ = draw.textbbox((0, 0), candidate, font=font)
        # A word that is too wide on its own still gets a line of its own;
        # without the `not current_line` guard an empty line would be emitted.
        if right - left <= max_width or not current_line:
            current_line.append(word)
        else:
            lines.append(" ".join(current_line))
            current_line = [word]
    if current_line:
        lines.append(" ".join(current_line))
    return lines


def create_demo_video(sample, output_path, fps=30,
                      font_path="Microsoft_YaHei_Bold.ttf", font_size=32):
    """Render a 1280x720 text-over-black demo video with the sample's audio.

    The audio is split into 0.96 s chunks. While chunk ``k`` is playing, the
    video shows the space-joined trajectory entries of all chunks before
    ``k``, wrapped to the frame width, with only the last four lines visible.
    A silent XVID ``.avi`` is written first and then muxed with the audio
    into ``output_path`` by ``ffmpeg`` (H.264 + AAC). Temporary files are
    removed afterwards.

    Args:
        sample: Mapping with ``'audio'`` (a path accepted by ``read_wav``)
            and ``'trajectory'`` (a list of strings, or a string that
            evaluates to one).
        output_path: Destination video file.
        fps: Frame rate of the generated video.
        font_path: TrueType font used for the text.
        font_size: Font size in points.

    Raises:
        OSError: If the font file cannot be loaded.
        subprocess.CalledProcessError: If ``ffmpeg`` exits with an error.
    """
    import subprocess

    import cv2
    import numpy as np
    from PIL import Image, ImageDraw, ImageFont

    frame_width, frame_height = 1280, 720
    max_text_width = 1200   # frame width minus a horizontal margin
    max_visible_lines = 4   # only the most recent lines are drawn
    top_margin = 100
    line_spacing = 10
    chunk_seconds = 0.96

    # Read audio and get parameters
    wav, sr = read_wav(sample['audio'])
    # NOTE(review): eval executes arbitrary code stored in the trajectory
    # string; only use with trusted data (ast.literal_eval would be safer).
    trajectory = eval(sample['trajectory']) if isinstance(sample['trajectory'], str) else sample['trajectory']

    step = int(sr * chunk_seconds)  # Samples per audio chunk
    total_frames = len(wav) * fps // sr

    # Load the font before any file is created so that a missing font does
    # not leave an open writer or a stray temp file behind.
    font = ImageFont.truetype(font_path, font_size)

    # Derive temp names from the output stem. A plain str.replace('.mp4', ...)
    # leaves the name unchanged for other extensions, and the cleanup below
    # would then delete the final output.
    output_stem = os.path.splitext(output_path)[0]
    temp_video = output_stem + '_temp.avi'
    temp_audio = output_stem + '_temp.wav'

    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    video = cv2.VideoWriter(temp_video, fourcc, fps, (frame_width, frame_height))

    try:
        rendered_chunk = None
        frame = None
        for frame_idx in tqdm(range(total_frames), desc="frame_index"):
            # Index of the audio chunk playing at this frame (uses the real
            # sample rate instead of assuming 16 kHz).
            chunk_idx = int(frame_idx / fps * sr / step)

            # The text only changes when the chunk changes, so render once
            # per chunk and reuse the frame in between.
            if chunk_idx != rendered_chunk:
                pil_img = Image.fromarray(
                    np.zeros((frame_height, frame_width, 3), dtype=np.uint8)
                )
                draw = ImageDraw.Draw(pil_img)

                # Accumulated text of all chunks before the current one
                text = " ".join(trajectory[:chunk_idx])
                lines = _wrap_words(draw, text, font, max_text_width)

                y = top_margin
                for line in lines[-max_visible_lines:]:
                    left, top, right, bottom = draw.textbbox((0, 0), line, font=font)
                    x = (frame_width - (right - left)) // 2  # Center text
                    draw.text((x, y), line, font=font, fill=(255, 255, 255))
                    y += (bottom - top) + line_spacing

                # PIL images are RGB while OpenCV writers expect BGR.
                frame = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
                rendered_chunk = chunk_idx

            video.write(frame)

        # Flush the silent video to disk before ffmpeg reads it.
        video.release()

        # Export exactly the audio that was rendered (honours any
        # offset/duration encoded in sample['audio']).
        sf.write(temp_audio, wav, sr)

        cmd = [
            'ffmpeg', '-y',
            '-i', temp_video,  # Video input
            '-i', temp_audio,  # Audio input
            '-c:v', 'libx264',  # Convert to H.264
            '-preset', 'medium',  # Encoding preset
            '-crf', '23',  # Quality setting
            '-c:a', 'aac',  # AAC audio codec
            '-strict', 'experimental',
            output_path
        ]
        # check=True surfaces ffmpeg failures instead of silently leaving
        # no (or a stale) output file.
        subprocess.run(cmd, check=True)

    finally:
        # Cleanup; releasing an already released writer is harmless.
        video.release()
        for temp_path in (temp_video, temp_audio):
            if os.path.exists(temp_path):
                os.remove(temp_path)

In [60]:
create_demo_video(samples[63944], output_path='demo.mp4', fps=60)

frame_index:   0%|          | 0/1728 [00:00<?, ?it/s]

ffmpeg version 5.1.6 Copyright (c) 2000-2024 the FFmpeg developers
  built with gcc 11 (GCC)
  configuration: --prefix=/usr --bindir=/usr/bin --datadir=/usr/share/ffmpeg --docdir=/usr/share/doc/ffmpeg --incdir=/usr/include/ffmpeg --libdir=/usr/lib64 --mandir=/usr/share/man --arch=x86_64 --optflags='-O2 -flto=auto -ffat-lto-objects -fexceptions -g -grecord-gcc-switches -pipe -Wall -Werror=format-security -Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -specs=/usr/lib/rpm/redhat/redhat-hardened-cc1 -fstack-protector-strong -specs=/usr/lib/rpm/redhat/redhat-annobin-cc1 -m64 -march=x86-64-v2 -mtune=generic -fasynchronous-unwind-tables -fstack-clash-protection -fcf-protection' --extra-ldflags='-Wl,-z,relro -Wl,--as-needed -Wl,-z,now -specs=/usr/lib/rpm/redhat/redhat-hardened-ld -specs=/usr/lib/rpm/redhat/redhat-annobin-cc1 ' --extra-cflags=' -I/usr/include/rav1e' --enable-libopencore-amrnb --enable-libopencore-amrwb --enable-libvo-amrwbenc --enable-version3 --enable-bzlib --disable-crysta

In [None]:
# The three result logs differ only in the target-language code, so build the
# paths from one template instead of repeating the full string.
LOG_PATH_TEMPLATE = (
    "/compute/babel-5-23/siqiouya/runs/en-{lang}/8B-traj-s2-v3.6/last.ckpt/"
    "simul-results-full-betterfilterbadwords/"
    "cache1000_seg1920_beam4_ms0_nrnl100_nrns5/instances.log"
)

# Only the first instance of each log is used for the demos.
log_zh = read_logs(LOG_PATH_TEMPLATE.format(lang="zh"))[0]
log_es = read_logs(LOG_PATH_TEMPLATE.format(lang="es"))[0]
log_de = read_logs(LOG_PATH_TEMPLATE.format(lang="de"))[0]

In [99]:
def convert_log_to_sample(log, unit='char'):
    """Turn one instance log into a sample for ``create_demo_video``.

    The prediction is split into tokens (single characters for
    ``unit='char'``, space-separated pieces otherwise). Each token is
    assigned to the first 0.96 s chunk whose end lies strictly after the
    token's delay. Delays and ``source_length`` are multiplied by 16 to
    convert them to sample counts (presumably milliseconds at 16 kHz;
    confirm against the log producer).

    Tokens whose delay is not strictly before the end of the last chunk are
    appended to the last chunk instead of being dropped, so the trajectory
    always contains every token that has a delay.

    Args:
        log: Mapping with ``'source'`` (sequence whose first item is the
            audio path), ``'source_length'``, ``'prediction'`` (str) and
            ``'delays'`` (one number per token).
        unit: ``'char'`` for character tokens, ``'word'`` for space-separated
            tokens joined back with spaces.

    Returns:
        ``{'audio': <path>, 'trajectory': [<text of chunk 0>, ...]}``.
    """
    samples_per_unit = 16              # factor applied to delays / source_length
    stepsize = int(0.96 * 16000)       # chunk length in samples
    n_frame = int(log['source_length'] * samples_per_unit)

    tokens = log['prediction'] if unit == 'char' else log['prediction'].split(' ')
    delays = log['delays']
    # Never index past the shorter of the two sequences.
    n_tokens = min(len(tokens), len(delays))

    chunks = []
    next_token = 0
    for offset in range(0, n_frame, stepsize):
        chunk_end = offset + stepsize
        emitted = []
        while next_token < n_tokens and int(delays[next_token]) * samples_per_unit < chunk_end:
            emitted.append(tokens[next_token])
            next_token += 1
        chunks.append(emitted)

    # Tokens released at or after the end of the last chunk (for example the
    # final flush once the whole source was read) belong to the last chunk.
    leftover = list(tokens[next_token:n_tokens])
    if leftover:
        if chunks:
            chunks[-1].extend(leftover)
        else:
            chunks.append(leftover)

    joiner = ' ' if unit == 'word' else ''
    sample = {}
    sample['audio'] = log['source'][0]
    sample['trajectory'] = [joiner.join(chunk) for chunk in chunks]
    return sample

In [106]:
# Chinese is tokenised per character, Spanish and German per word.
sample_zh, sample_es, sample_de = (
    convert_log_to_sample(instance_log, unit=token_unit)
    for instance_log, token_unit in (
        (log_zh, 'char'),
        (log_es, 'word'),
        (log_de, 'word'),
    )
)

In [107]:
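# Optional (disabled): restrict the Spanish demo to frames 0..460800 of the audio
# using the "<path>:<offset>:<duration>" form understood by read_wav.
# sample_es['audio'] += ':0:460800'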

In [112]:
# Render one demo video per target language.
for demo_sample, demo_path in (
    (sample_zh, 'demo_infinisst_zh.mp4'),
    (sample_es, 'demo_infinisst_es.mp4'),
    (sample_de, 'demo_infinisst_de.mp4'),
):
    create_demo_video(demo_sample, output_path=demo_path)

frame_index:   0%|          | 0/864 [00:00<?, ?it/s]

ffmpeg version 5.1.6 Copyright (c) 2000-2024 the FFmpeg developers
  built with gcc 11 (GCC)
  configuration: --prefix=/usr --bindir=/usr/bin --datadir=/usr/share/ffmpeg --docdir=/usr/share/doc/ffmpeg --incdir=/usr/include/ffmpeg --libdir=/usr/lib64 --mandir=/usr/share/man --arch=x86_64 --optflags='-O2 -flto=auto -ffat-lto-objects -fexceptions -g -grecord-gcc-switches -pipe -Wall -Werror=format-security -Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -specs=/usr/lib/rpm/redhat/redhat-hardened-cc1 -fstack-protector-strong -specs=/usr/lib/rpm/redhat/redhat-annobin-cc1 -m64 -march=x86-64-v2 -mtune=generic -fasynchronous-unwind-tables -fstack-clash-protection -fcf-protection' --extra-ldflags='-Wl,-z,relro -Wl,--as-needed -Wl,-z,now -specs=/usr/lib/rpm/redhat/redhat-hardened-ld -specs=/usr/lib/rpm/redhat/redhat-annobin-cc1 ' --extra-cflags=' -I/usr/include/rav1e' --enable-libopencore-amrnb --enable-libopencore-amrwb --enable-libvo-amrwbenc --enable-version3 --enable-bzlib --disable-crysta