In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import csv
import json

import numpy as np
import matplotlib.pyplot as plt

import sacrebleu
import soundfile as sf

import yaml
from tqdm.notebook import tqdm

def read_logs(path):
    """Load a JSON-lines file into a list of decoded objects.

    Each line is stripped of surrounding whitespace; lines that are empty
    after stripping are skipped, every other line is parsed with
    ``json.loads``.

    Args:
        path: Path of the JSON-lines file to read.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.
    """
    with open(path, "r") as handle:
        stripped = (raw.strip() for raw in handle)
        return [json.loads(entry) for entry in stripped if entry != ""]

def read_wav(wav_path):
    """Read a slice of an audio file described by ``"<path>:<offset>:<frames>"``.

    Args:
        wav_path: Spec string whose last two ``:``-separated fields are the
            integer start frame and the integer number of frames to read;
            everything before them is the audio file path.

    Returns:
        The ``(data, samplerate)`` pair returned by ``soundfile.read``.

    Raises:
        ValueError: If the spec does not contain two trailing ``:`` fields or
            they are not integers.
    """
    # Split from the right so a path that itself contains ':' (drive letters,
    # URL-like prefixes) is kept intact instead of breaking the 3-way unpack.
    file_path, offset, duration = wav_path.rsplit(':', 2)
    source, rate = sf.read(file_path, start=int(offset), frames=int(duration))
    return source, rate

def read_tsv(tsv_path):
    """Read a tab-separated file with a header row into a list of dicts.

    Quoting is disabled (``csv.QUOTE_NONE``), so quote characters in a field
    are returned verbatim as part of the value.

    Args:
        tsv_path: Path of the UTF-8 encoded TSV file.

    Returns:
        A list with one ``dict`` per data row, keyed by the header names.
    """
    dialect_opts = dict(
        delimiter="\t",
        quotechar=None,
        doublequote=False,
        lineterminator="\n",
        quoting=csv.QUOTE_NONE,
    )
    with open(tsv_path, encoding='utf-8') as handle:
        rows = csv.DictReader(handle, **dialect_opts)
        return list(map(dict, rows))

def write_tsv(samples, tsv_path):
    """Write a list of dicts to a tab-separated file with a header row.

    The column order is taken from the keys of the first sample. Quoting is
    disabled (``csv.QUOTE_NONE``), so field values are written verbatim.

    Args:
        samples: Non-empty sequence of dicts sharing the same keys.
        tsv_path: Destination path; the file is created or overwritten.

    Raises:
        IndexError: If ``samples`` is empty. The destination file is left
            untouched in that case.
    """
    # Resolve the header BEFORE opening the file: opening with "w" truncates,
    # so failing on an empty list afterwards would wipe an existing file.
    fieldnames = list(samples[0].keys())
    # newline="" lets the csv writer control line endings itself, as the csv
    # module documentation requires for files handed to a writer.
    with open(tsv_path, "w", encoding='utf-8', newline="") as w:
        writer = csv.DictWriter(
            w,
            fieldnames,
            delimiter="\t",
            quotechar=None,
            doublequote=False,
            lineterminator="\n",
            quoting=csv.QUOTE_NONE,
        )
        writer.writeheader()
        writer.writerows(samples)

def play(audio_path):
    """Render an inline audio player for an audio spec understood by ``read_wav``.

    Args:
        audio_path: Spec string passed unchanged to ``read_wav``.
    """
    from IPython.display import display, Audio
    # Use the sample rate reported for the file instead of assuming 16 kHz, so
    # audio stored at another rate is not played back at the wrong speed.
    waveform, sample_rate = read_wav(audio_path)
    display(Audio(waveform, rate=sample_rate))

In [96]:
# Select which split's TSV manifest to load. The commented-out pair below is
# the alternative 'dev' configuration.
base_split = 'train'
split = 'train_st_zh_ft_traj_30_filtered_po10k_gpt-4o-mini-2024-07-18_fa_traj'
# base_split = 'dev'
# split = 'dev_st_zh_traj_30_gpt-4o-mini-2024-07-18_fa_traj'
tsv_path = "/compute/babel-14-5/siqiouya/en-zh/{}.tsv".format(split)
# Reuse the read_tsv helper defined in the first cell instead of repeating
# the identical csv.DictReader configuration inline.
samples = read_tsv(tsv_path)

In [97]:
# Decoding configuration. All of these except latency_multiplier are
# interpolated into the log directory name when the instance logs are loaded.
src_segment_size = 960  # presumably milliseconds (960 ms = 0.96 s) — TODO confirm
latency_multiplier = 1  # not referenced by the cells visible in this notebook
max_llm_cache_size = 4000
no_repeat_ngram_lookback = 100
no_repeat_ngram_size = 3
beam = 1
ms = 0

In [98]:
n_split = 8
ckpt_dir = "/compute/babel-5-23/siqiouya/runs/8B-traj-s2-v3.3/last.ckpt/"
greedy_predictions = []

# Directory layout: <ckpt_dir>/<run_name>/<decode_tag>/<shard index>/instances.log
run_name = "greedy_train_chunk30_po10k"
# run_name = "greedy_dev_chunk30"  # alternative: logs for the 'dev' split
decode_tag = (
    f"cache{max_llm_cache_size}_seg{src_segment_size}_beam{beam}_ms{ms}"
    f"_nrnl{no_repeat_ngram_lookback}_nrns{no_repeat_ngram_size}"
)

# Concatenate the per-shard logs in shard order. read_logs (first cell) already
# does the strip / skip-blank / json.loads work, and os.path.join avoids the
# doubled '/' that the trailing slash on ckpt_dir would otherwise produce.
logs = []
for i in range(n_split):
    log_path = os.path.join(ckpt_dir, run_name, decode_tag, str(i), "instances.log")
    logs.extend(read_logs(log_path))

In [99]:
# Re-bucket each logged prediction into fixed-size audio windows and store the
# per-window text pieces on the sample under the 'sampling' key.

# samples and logs are paired purely by position; zip() would silently stop at
# the shorter one and leave trailing samples without a 'sampling' entry.
if len(samples) != len(logs):
    raise ValueError(
        "samples and logs must align one-to-one, got {} samples and {} logs".format(
            len(samples), len(logs)
        )
    )

# Window length in audio frames; loop-invariant, so computed once.
# presumably 0.96 s at a 16 kHz sample rate — TODO confirm
stepsize = int(0.96 * 16000)

for sample, log in zip(samples, logs):
    n_frame = int(sample['n_frames'])
    delays = log['delays']
    idx = -1  # index of the last prediction element already consumed
    new_traj = []
    for offset in range(0, n_frame, stepsize):
        window_end = offset + stepsize
        text = ""
        # Consume every not-yet-used element whose delay falls before the end
        # of this window. The factor 16 presumably converts milliseconds to
        # frames at 16 kHz — TODO confirm.
        while idx + 1 < len(delays) and int(delays[idx + 1]) * 16 < window_end:
            idx += 1
            text += log['prediction'][idx]
        new_traj.append(text)
    # NOTE(review): elements whose scaled delay is >= the end of the final
    # window are never appended to any chunk — confirm this is intended.
    sample['sampling'] = new_traj

In [103]:
# Inspect the 'trajectory' field of the last loaded sample.
print(samples[-1]['trajectory'])

['', '但是', '如果在', '学校表现良好', '以及', '在生活中', '取决于', '远不止你', '', '快速', '且', '轻松学习的能力？', '所以我离开', '了教室，', '然后我去', '了研究生院，以成为一名', '', '心理学家。我', '开始研究', '儿童和成年人', '在各种', '超级具有挑战性的', '环境中，', ' 在每个', '研究中我的问题', '是谁在', '这里成功', '以及', '为什么？我', '的研究团队和我']


In [102]:
# Listen to the audio segment referenced by the last sample's 'audio' field.
play(samples[-1]['audio'])

In [104]:
# Keep only the samples whose target text does not contain '香肠';
# cnt is the number of samples that were dropped.
samples_f = [x for x in samples if '香肠' not in x['tgt_text']]
cnt = len(samples) - len(samples_f)
cnt

384

In [105]:
# Persist the filtered samples next to the source manifest, with a
# '_sample' suffix appended to the split name.
write_path = f"/compute/babel-14-5/siqiouya/en-zh/{split}_sample.tsv"
write_tsv(samples_f, write_path)