In [18]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import csv
import json

import numpy as np
import matplotlib.pyplot as plt

import sacrebleu
import soundfile as sf

import yaml
from tqdm.notebook import tqdm

from IPython.display import display, Audio

def read_logs(path):
    """Parse a JSON-lines file into a list of decoded values.

    Lines that are empty after stripping whitespace are skipped.

    Args:
        path: path to a text file holding one JSON document per line.

    Returns:
        list of the decoded JSON values, in file order.
    """
    with open(path, "r") as reader:
        stripped_lines = (line.strip() for line in reader)
        return [json.loads(entry) for entry in stripped_lines if entry]

def read_wav(wav_path):
    """Load audio samples, optionally only a sub-range of the file.

    Args:
        wav_path: either a plain audio file path, or ``<path>:<offset>:<frames>``
            where ``offset`` and ``frames`` are integers passed to
            ``soundfile.read`` as ``start`` and ``frames``.

    Returns:
        ``(samples, sample_rate)`` as returned by ``soundfile.read``.
    """
    if ':' in wav_path:
        # Split from the right so the file path itself may contain ':'
        # (plain split(':') fails to unpack for such paths).
        wav_path, offset, duration = wav_path.rsplit(':', 2)
        offset = int(offset)
        duration = int(duration)
    else:
        offset = 0
        duration = -1  # frames=-1: read through to the end of the file
    source, rate = sf.read(wav_path, start=offset, frames=duration)
    return source, rate

def read_tsv(tsv_path):
    """Read a tab-separated file with a header row into a list of dicts.

    Quoting is disabled (QUOTE_NONE), so quote characters are kept verbatim,
    a tab always separates fields and a line break always ends a row.

    Args:
        tsv_path: path to the TSV file; decoded as UTF-8.

    Returns:
        list with one ``{column: value}`` dict per data row, in file order.
    """
    # newline="" is what the csv module expects for file objects; explicit
    # UTF-8 avoids depending on the platform's locale encoding.
    with open(tsv_path, encoding="utf-8", newline="") as f:
        reader = csv.DictReader(
            f,
            delimiter="\t",
            quotechar=None,
            doublequote=False,
            lineterminator="\n",
            quoting=csv.QUOTE_NONE,
        )
        return [dict(row) for row in reader]

def write_tsv(samples, tsv_path, fieldnames=None):
    """Write a sequence of dicts as a tab-separated file with a header row.

    Args:
        samples: sequence of dicts to write, one row each.
        tsv_path: output path; overwritten, encoded as UTF-8.
        fieldnames: optional explicit column order. Defaults to the keys of
            ``samples[0]`` (which is what lets an empty ``samples`` be written
            as a header-only file when given).

    Raises:
        IndexError: if ``samples`` is empty and ``fieldnames`` is None. This is
            raised before the file is opened, so an existing file is left intact.
        csv.Error: if a value contains a tab or line break (quoting is disabled
            and no escapechar is configured).
    """
    # Resolve the columns before opening: opening with "w" truncates the file,
    # so failing afterwards would wipe an existing file.
    if fieldnames is None:
        fieldnames = list(samples[0].keys())
    with open(tsv_path, "w", encoding="utf-8", newline="") as w:
        writer = csv.DictWriter(
            w,
            fieldnames,
            delimiter="\t",
            quotechar=None,
            doublequote=False,
            lineterminator="\n",
            quoting=csv.QUOTE_NONE,
        )
        writer.writeheader()
        writer.writerows(samples)

def play(audio_path, rate=None):
    """Render an inline audio player for ``audio_path``.

    Args:
        audio_path: any path accepted by ``read_wav`` (plain path or
            ``path:offset:frames``).
        rate: playback sample rate in Hz; defaults to the rate stored in the
            file itself.
    """
    samples, file_rate = read_wav(audio_path)
    # Use the file's own rate by default so non-16 kHz audio plays at the
    # correct speed instead of being forced to 16000 Hz.
    display(Audio(samples, rate=file_rate if rate is None else rate))

# Instance logs: load instances.log and rebuild per-step output trajectories

In [None]:
# List the available simul-results subdirectories (presumably one per decoding
# configuration, e.g. cache4000_seg960_beam1_ms0 used below).
!ls /compute/babel-5-23/siqiouya/runs/8B-traj-s2-v3.0/last.ckpt/simul-results-full/

In [28]:
# Which results to inspect. `multiplier` scales the base 960 ms segment length;
# the results subdirectory name embeds the resulting segment length.
RESULTS_DIR = "/compute/babel-5-23/siqiouya/runs/8B-traj-s2-v3.0/last.ckpt/simul-results-full"
BASE_SEGMENT_MS = 960
multiplier = 1
logs = read_logs(
    os.path.join(
        RESULTS_DIR,
        "cache4000_seg{}_beam1_ms0".format(multiplier * BASE_SEGMENT_MS),
        "instances.log",
    )
)

In [31]:
# Re-bucket every instance's streaming output into fixed source steps of
# 0.96 s * multiplier: new_trajs[i][k] is the text emitted after the first k+1
# steps of audio had been read (and before the next step arrived).
# Assumes 'source_length'/'delays' are in ms and audio is 16 kHz (hence the
# *16 ms->samples conversion) — TODO confirm against the log writer.
stepsize = int(0.96 * 16000) * multiplier  # samples per step


def rebucket_trajectory(log, stepsize):
    """Group one instance's emitted units by the source step they were emitted in.

    Args:
        log: one decoded instances.log entry; reads 'source_length', 'delays'
            (one per emitted unit) and 'prediction' (indexed in step with 'delays').
        stepsize: step length in samples.

    Returns:
        list with one string per step (ceil(int(source_length * 16) / stepsize)
        entries), each the concatenation of the units assigned to that step.
    """
    n_frame = int(log['source_length'] * 16)
    delays = log['delays']
    prediction = log['prediction']
    # NOTE(review): units are concatenated with no separator, which suits
    # character-level output; word-level output would need ' '.join — confirm.
    idx = -1
    traj = []
    for offset in range(0, n_frame, stepsize):
        step_end = offset + stepsize
        text = ""
        # `<=`, not `<`: a unit whose delay equals the step boundary was emitted
        # right after this step was read, so it belongs to this step. A strict
        # `<` pushes every unit one step late (step 0 always empty) and drops
        # units flushed at the end of a step-aligned source.
        while idx + 1 < len(delays) and int(delays[idx + 1]) * 16 <= step_end:
            idx += 1
            text += prediction[idx]
        traj.append(text)
    # Defensive: never silently drop units whose delay lies past the last step.
    if traj and idx + 1 < len(delays):
        traj[-1] += "".join(prediction[idx + 1 : len(delays)])
    return traj


new_trajs = [rebucket_trajectory(log, stepsize) for log in logs]

In [None]:
# Inspect one instance step by step: play each audio chunk, then print the text
# that `new_trajs` assigned to that chunk.
idx = 1  # index into `logs` of the instance to inspect
wav, sr = read_wav(logs[idx]['source'][0])
trajectory = new_trajs[idx]
# The trajectories were bucketed assuming 16 kHz audio (source_length * 16,
# 0.96 * 16000); at any other rate the chunks below would not line up with them.
assert sr == 16000, f"expected 16 kHz audio, got {sr} Hz"

step = int(sr * 0.96) * multiplier
# zip stops at the shorter of (audio chunks, trajectory steps).
for step_idx, (i, action) in enumerate(zip(range(0, len(wav), step), trajectory)):
    display(Audio(wav[i : i + step], rate=sr, autoplay=False))
    print(step_idx, "[T_START]", action, "[T_END]")