In [None]:
import sys
import os

# Local paths. Each can be overridden with an environment variable so the notebook runs on
# other machines without edits; the defaults are the original author's paths.
data_dir = os.environ.get("SIMULST_DATA_DIR", "/media/george/Data/mustc/en-de")
code_dir = os.environ.get("SIMULST_CODE_DIR", "/home/george/Projects/simulst")
fair_dir = os.environ.get("FAIRSEQ_DIR", "/home/george/utility/fairseq")

# Put fairseq first and the project second on sys.path. Earlier copies are dropped first so
# re-running this cell does not keep prepending duplicate entries.
sys.path[:] = [p for p in sys.path if p not in (code_dir, fair_dir)]
sys.path.insert(0, code_dir)
sys.path.insert(0, fair_dir)

# Experiment (checkpoint directory) name.
# NOTE: a later cell rebinds `model` to the built model object, so any cell that formats
# `model` into a path needs this cell re-run first.
model = "cif_de_sum_ctc0_3_lat0_0"

In [None]:
import examples.simultaneous_translation
from fairseq import (
    checkpoint_utils,
    options,
    quantization_utils,
    tasks,
    utils,
)
from torchinfo import summary
import logging
import os
import matplotlib.pyplot as plt
import torch
from fairseq_cli.generate import get_symbols_to_strip_from_output

In [None]:
# `model` must still be the experiment-name string from the setup cell. A later cell rebinds
# it to the built model, and formatting that object into the path would silently produce a
# nonsense checkpoint path.
assert isinstance(model, str), "`model` is no longer the experiment name; re-run the setup cell first"
checkpoint = f"{code_dir}/exp/checkpoints/{model}/checkpoint_best.pt"
use_cuda = True

# Overrides applied on top of the config stored inside the checkpoint.
overrides = {
    "user_dir": f"{code_dir}/codebase",
    "inference_config_yaml": f"{code_dir}/exp/infer_st.yaml",
    "data": data_dir,
    "gen_subset": "dev_st",
    "batch_size": 1,
    "beam": 1,
    "do_mtl": True,
}

# `states` holds the raw checkpoint dict; its "model" entry is loaded into the model later.
states = checkpoint_utils.load_checkpoint_to_cpu(
    path=checkpoint, arg_overrides=overrides, load_on_all_ranks=False)
cfg = states["cfg"]

In [None]:
# Route log records to the notebook's stdout. The verbosity comes from $LOGLEVEL and
# defaults to INFO.
log_level = os.environ.get("LOGLEVEL", "INFO").upper()
logging.basicConfig(
    stream=sys.stdout,
    level=log_level,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("fairseq_cli.train")

In [None]:
# Import the project code from `user_dir`, then build the task, model and criterion from the
# checkpoint config and log what was built.
utils.import_user_module(cfg.common)

task = tasks.setup_task(cfg.task)
# NOTE: this rebinds `model`, which was the experiment-name string until now.
model = task.build_model(cfg.model)
criterion = task.build_criterion(cfg.criterion)

logger.info(summary(model))
for label, component in (("task", task), ("model", model), ("criterion", criterion)):
    logger.info("{}: {}".format(label, component.__class__.__name__))

In [None]:
# Build a fresh model and load the checkpoint weights into it. Log the file that was actually
# loaded: cfg.common_eval.path is not overridden above, so it need not match `checkpoint`.
logger.info("loading model(s) from {}".format(checkpoint))
model = task.build_model(cfg.model)
model.load_state_dict(
    states["model"], strict=True, model_cfg=cfg.model
)

# Move to GPU if requested and apply fairseq's inference-time preparation.
if use_cuda:
    model.cuda()
model.prepare_for_inference_(cfg)

In [None]:
# Load the generation split (gen_subset is set to "dev_st" in `overrides`) into the task.
gen_subset = cfg.dataset.gen_subset
task.load_dataset(gen_subset, task_cfg=cfg.task)

In [None]:
# Batch iterator over the generation split (one epoch, no shuffling), plus the sequence
# generator, dictionaries and optional tokenizer/BPE used by the cells below.
# NOTE: `itr` is single-pass; once a later cell has consumed it, re-run this cell to rescan.
gen_dataset = task.dataset(cfg.dataset.gen_subset)
max_positions = utils.resolve_max_positions(task.max_positions(), model.max_positions())
epoch_itr = task.get_batch_iterator(
    dataset=gen_dataset,
    max_tokens=cfg.dataset.max_tokens,
    max_sentences=cfg.dataset.batch_size,
    max_positions=max_positions,
    ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
    required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
    seed=cfg.common.seed,
    num_shards=cfg.distributed_training.distributed_world_size,
    shard_id=cfg.distributed_training.distributed_rank,
    num_workers=cfg.dataset.num_workers,
    data_buffer_size=cfg.dataset.data_buffer_size,
)
itr = epoch_itr.next_epoch_itr(shuffle=False)

generator = task.build_generator([model], cfg.generation)

# Dictionaries plus optional tokenizer / BPE (either may be None).
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
tokenizer = task.build_tokenizer(cfg.tokenizer)
bpe = task.build_bpe(cfg.bpe)

def encode_fn(x):
    """Encode raw text: apply the tokenizer, then BPE, skipping whichever is None."""
    for stage in (tokenizer, bpe):
        if stage is not None:
            x = stage.encode(x)
    return x

def decode_fn(x):
    """Invert encode_fn: undo BPE first, then detokenize, skipping whichever is None."""
    for stage in (bpe, tokenizer):
        if stage is not None:
            x = stage.decode(x)
    return x

In [None]:
# Pick one utterance by dataset id, run inference on its speech inputs only, and show the
# raw hypothesis string.
# `.item()` on the id assumes one sentence per batch (batch_size is 1 in `overrides`).
idx = 480
# `itr` is single-pass: this scan consumes it. Re-run the iterator cell before choosing an
# id that has already been passed.
for sample in itr:
    # Check the id on the CPU batch first; only the selected batch is moved to the GPU.
    if "net_input" in sample and sample["id"].item() == idx:
        break
else:
    raise LookupError(
        f"sample id {idx} not found in the remaining batches; re-run the iterator cell"
    )
if use_cuda:
    sample = utils.move_to_cuda(sample)

# Keep the full batch (including the transcript fields used later) as `sample_w_asr`, and
# pass only the speech inputs to inference.
sample_w_asr = sample
sample = {
    "id": sample_w_asr["id"],
    "net_input": {
        "src_tokens": sample_w_asr["net_input"]["src_tokens"],
        "src_lengths": sample_w_asr["net_input"]["src_lengths"],
    }
}
translations = task.inference_step(
    generator, [model], sample
)
# remove_bpe=None keeps sentencepiece pieces; "sentencepiece" would give detokenized text.
utils.post_process_prediction(
    hypo_tokens=translations[0][0]["tokens"].int().cpu(),
    src_str=None,
    alignment=None,
    align_dict=None,
    tgt_dict=tgt_dict,
    remove_bpe=None,
    extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
)[1]

In [None]:
# Best hypothesis of the selected utterance, with sentencepiece merged back into plain text.
best_hypo_tokens = translations[0][0]["tokens"].int().cpu()
symbols_to_strip = get_symbols_to_strip_from_output(generator)
utils.post_process_prediction(
    hypo_tokens=best_hypo_tokens,
    src_str=None,
    alignment=None,
    align_dict=None,
    tgt_dict=tgt_dict,
    remove_bpe="sentencepiece",
    extra_symbols_to_ignore=symbols_to_strip,
)[1]

In [None]:
# Re-run only the ST encoder on the selected utterance. `extra` is its output dict; later
# cells read its "alpha", "ctc_logits" and "encoder_padding_mask" entries.
net_input = sample["net_input"]
src_tokens = net_input["src_tokens"]
src_lengths = net_input["src_lengths"]
# Best hypothesis as a batch of one, cast with type_as to match src_lengths.
target = translations[0][0]["tokens"].unsqueeze(0).type_as(src_lengths)
# Hypothesis rotated right by one along time, so its last token (presumably EOS) comes first.
prev_output_tokens = torch.roll(target, 1, 1)
extra = model.encoder(src_tokens, src_lengths)

In [None]:
# One audio path per line. `paths[idx]` is used below, so the list is assumed to be in the
# same order as the dataset ids — TODO confirm.
# Built from `code_dir` instead of repeating its value as a literal.
wav_list_path = f"{code_dir}/eval/data/dev.wav_list"
with open(wav_list_path, "r") as f:
    paths = [line.strip() for line in f]

In [None]:
import torchaudio
import IPython
# Load the waveform of utterance `idx` and play it inline.
audio_path = paths[idx]
wav, sr = torchaudio.load(audio_path)
IPython.display.Audio(wav, rate=sr)

In [None]:
# CIF weights of the first (only) utterance, detached from the autograd graph.
alpha = extra["alpha"][0].detach()
# Floor of the running sum of alpha at each encoder step (presumably the number of tokens
# fired so far). print_segments draws a translation boundary wherever this value changes.
cif_steps = torch.floor(torch.cumsum(alpha, dim=-1))

In [None]:
# Second model: the CTC ASR system whose encoder CTC logits are used for the source-side
# forced alignment below. It reuses the same config overrides as the ST model.
asr_checkpoint = f"{code_dir}/exp/checkpoints/ctc_s2s_asr/checkpoint_best.pt"
asr_states = checkpoint_utils.load_checkpoint_to_cpu(
    path=asr_checkpoint,
    arg_overrides=overrides,
    load_on_all_ranks=False,
)
asr_cfg = asr_states["cfg"]

asr_task = tasks.setup_task(asr_cfg.task)
asr_model = asr_task.build_model(asr_cfg.model)
asr_model.load_state_dict(asr_states["model"], strict=True, model_cfg=asr_cfg.model)

if use_cuda:
    asr_model.cuda()
asr_model.prepare_for_inference_(asr_cfg)

In [None]:
from codebase.criterion.best_alignment import best_alignment
# Align the reference transcript of the selected utterance against the ASR model's CTC
# logits. The resulting `states` is plotted by print_segments.
# NOTE: this rebinds the global `states` (previously the ST checkpoint dict).
encoder_out = asr_model.encoder(
    sample['net_input']['src_tokens'],
    sample['net_input']['src_lengths'],
)
# Work on a copy so the penalty below does not edit the model's output tensor in place
# (which also keeps the edit safe if that tensor is a view autograd still needs).
ctc_logits = encoder_out["ctc_logits"][0].clone()
encoder_mask = encoder_out["encoder_padding_mask"][0]
asr_target = sample_w_asr['net_input']['src_txt_tokens']
asr_lengths = sample_w_asr['net_input']['src_txt_lengths']

# Lower the blank (index 0) logit so the alignment is pushed towards emitting labels.
# This is a visualisation knob, not a model setting.
BLANK_PENALTY = 10
ctc_logits[..., 0] -= BLANK_PENALTY
states = best_alignment(
    ctc_logits.log_softmax(-1).transpose(0, 1),
    asr_target,
    (~encoder_mask).sum(-1),
    asr_lengths,
    blank=0,
    as_labels=False
)

In [None]:
import numpy as np
import matplotlib
# Matplotlib 3.6 renamed the bundled seaborn style to "seaborn-v0_8" (the old name was
# removed in 3.8). Use the new name when available and fall back for older versions.
_SEABORN_STYLE = "seaborn-v0_8" if "seaborn-v0_8" in matplotlib.style.available else "seaborn"
matplotlib.style.use(_SEABORN_STYLE)
def print_segments(wav, states, labels, steps, preds, align, out_path="policy.pdf"):
    """Draw the CTC source segmentation and the CIF translation segmentation over a waveform.

    The waveform is plotted in the middle of the figure:

    * Source tokens are shaded and labelled along the top line.
    * Translation tokens are labelled between vertical CIF boundaries along the bottom line.
    * Each ``align`` pair is joined by a magenta line: dotted when the next pair's target index
      is not smaller, solid otherwise.

    The figure is saved to ``out_path``.

    Args:
        wav: waveform tensor (e.g. from ``torchaudio.load``); it is squeezed before plotting.
        states: ``best_alignment(..., as_labels=False)`` output for one utterance, with a leading
            dim of 1 that is squeezed. As decoded here, even values are blanks and an odd value
            ``s`` belongs to label ``s // 2``.
        labels: source token ids (leading dim of 1), decoded with ``src_dict``. Drawing stops at EOS.
        steps: per-encoder-frame CIF counts (leading dim of 1). A translation boundary is drawn
            wherever the count changes. Must have the same length as ``states``.
        preds: translation token ids (leading dim of 1), decoded with ``tgt_dict``. Drawing stops
            at EOS.
        align: list of ``(source_segment_index, target_segment_index)`` pairs; may be empty.
        out_path: file the figure is saved to (default ``"policy.pdf"``, as before).

    Relies on the notebook globals ``sr``, ``src_dict``, ``tgt_dict`` and ``plt``.
    """
    padding = 5000  # extra x-range (in samples) past the last frame
    yscale = 0.6
    texttop = 0.75
    textbot = -0.75
    font = 18

    def ratio(x):
        # Encoder frame index -> index of the last waveform sample of that frame, using
        # 4 * 10 ms = 40 ms of audio per encoder frame.
        return (x + 1) * 4 * 10 * sr // 1000 - 1

    # --- source transcription (top) ---
    wav = wav * yscale
    states = states.squeeze(0)
    labels = labels.squeeze(0)
    T = states.size(-1)

    blanks = states % 2 == 0
    tgt_idx = states.div(2, rounding_mode="floor").masked_fill(blanks, -1)
    next_id = tgt_idx.roll(-1, dims=0)
    prev_id = tgt_idx.roll(1, dims=0)
    # Frame indices live on the masks' device: a CPU tensor cannot be indexed by a CUDA mask.
    frames = torch.arange(T, device=states.device)
    # A label segment starts where the label differs from the previous frame, and ends where
    # it differs from the next one.
    l_bound = frames[(tgt_idx != prev_id) & ~blanks].tolist()
    r_bound = frames[(tgt_idx != next_id) & ~blanks].tolist()

    fig = plt.figure(figsize=(12, 8 * yscale), dpi=100)
    ax = fig.add_subplot(111)
    ax.plot(wav.squeeze().cpu().numpy() * texttop)
    offset = texttop * yscale
    ax.hlines(offset, colors="black", xmin=0, xmax=ratio(T) + padding)
    for l, r, tok in zip(l_bound, r_bound, labels):
        if tok == src_dict.eos():
            break
        x0, x1 = ratio(l), ratio(r)
        w = src_dict.string([tok], "sentencepiece")
        ax.axvspan(x0, x1, alpha=0.1, color="red", ymin=(textbot + 1) / 2)
        ax.annotate(w, (x0, offset + 0.05), ha="left", fontsize=font)

    # --- translation (bottom) ---
    steps = steps.squeeze(0)
    preds = preds.squeeze(0)
    next_steps = steps.roll(-1, dims=0)
    # A translation token ends wherever the CIF count changes at the next frame.
    r_bound_trans = torch.arange(T, device=steps.device)[steps != next_steps].tolist()
    l_bound_trans = [0] + r_bound_trans
    offset = textbot * yscale
    ax.hlines(offset, colors="black", xmin=0, xmax=ratio(T) + padding)
    for l, r, tok in zip(l_bound_trans, r_bound_trans, preds):
        if tok == tgt_dict.eos():
            break
        x0, x1 = ratio(l), ratio(r)
        # Translation ids index the target vocabulary (src_dict was wrongly used here before).
        w = tgt_dict.string([tok], "sentencepiece")
        ax.annotate(w, ((x0 + x1) / 2, offset - 0.1), ha="center", fontsize=font)
        ax.vlines(x1, colors="black", ymin=-1, ymax=offset)

    # --- alignment links ---
    # Dotted when the next pair's target index is not smaller, solid otherwise.
    next_j = [j for _, j in align[1:]] + [align[-1][1]] if align else []
    for (i, j), n_j in zip(align, next_j):
        src = ratio((l_bound[i] + r_bound[i]) / 2)
        tgt = ratio((l_bound_trans[j] + r_bound_trans[j]) / 2)
        ax.plot((src, tgt), (texttop * yscale, textbot * yscale), ":m" if j <= n_j else "-m")

    # Relabel the x axis from sample indices to milliseconds.
    xticks = ax.get_xticks()
    ax.set_xticks(xticks)
    ax.set_xticklabels((xticks * 1000 / sr).astype(int), fontsize=font * 0.9)
    ax.set_xlabel("Time (ms)", fontsize=font * 0.9)
    ax.set_yticks([])
    ax.set_ylim(-yscale, yscale)
    ax.set_xlim(0, wav.size(-1) + padding)
    fig.savefig(out_path, bbox_inches="tight", pad_inches=0.05)


# Manually specified links: source segment k (k = 0..14) is paired with target segment
# target_segments[k].
target_segments = [0, 1, 3, 4, 7, 5, 6, 8, 9, 9, 10, 11, 12, 13, 14]
align = list(enumerate(target_segments))
print_segments(wav, states, asr_target, cif_steps, target, align)

In [None]:
# Reference transcript of the selected utterance. These are source-side tokens, so decode
# them with the source dictionary, as print_segments does.
utils.post_process_prediction(
    hypo_tokens=asr_target[0].cpu(),
    src_str=None,
    alignment=None,
    align_dict=None,
    tgt_dict=src_dict,
    remove_bpe=None,  # keep sentencepiece pieces so they match the plotted segments
    extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
)[1]

In [None]:
# Play a short excerpt of the utterance. If sr == 16000, these bounds are the last samples of
# encoder frames 23 and 29 under print_segments' 40 ms-per-frame mapping — TODO confirm sr.
excerpt = wav[:, 15359:19199]
IPython.display.Audio(excerpt, rate=sr)

In [None]:
# Best hypothesis again, at the sentencepiece level (no BPE removal), for comparison with
# the plot.
raw_hypo_tokens = translations[0][0]["tokens"].int().cpu()
utils.post_process_prediction(
    hypo_tokens=raw_hypo_tokens,
    src_str=None,
    alignment=None,
    align_dict=None,
    tgt_dict=tgt_dict,
    remove_bpe=None,
    extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
)[1]

In [None]:

# Align the ST hypothesis `target` against the CTC logits in the ST encoder output `extra`.
# NOTE: this rebinds the globals `ctc_logits`, `encoder_mask` and `states`.
ctc_logits = extra["ctc_logits"][0]
encoder_mask = extra["encoder_padding_mask"][0]
states = best_alignment(
    ctc_logits.log_softmax(-1).transpose(0, 1),
    target,
    (~encoder_mask).sum(-1),
    # Hypothesis length, excluding padding (use the dictionary's pad id, not a literal 1).
    target.ne(tgt_dict.pad()).sum(-1),
    blank=0,
    as_labels=False
)

In [None]:
# Per-frame CIF counts. `steps` exists only as a print_segments parameter and is undefined
# at notebook scope; the notebook-level tensor is `cif_steps`.
cif_steps

In [None]:
# Label index of each alignment state (odd states are labels, even states are blanks).
# Uses the same explicit floor division as print_segments; tensor `//` warns on some torch
# versions.
states.div(2, rounding_mode="floor")

In [None]:
# Plot the label index (state // 2) of the alignment at every encoder frame.
fig, ax = plt.subplots(figsize=(12, 8), dpi=100)
state_labels = states.squeeze(0).cpu().numpy() // 2
ax.plot(state_labels)

In [None]:
# NOTE(review): `print_boundary` is not defined anywhere in this notebook, so this cell
# raises NameError on a fresh kernel (it presumably came from a deleted cell). Define it
# or remove this cell.
print_boundary(wav, states.div(2, rounding_mode='floor'), (states % 2 == 1))

In [None]:
import pandas as pd
import json
# Per-sentence evaluation log for this model (presumably written by SimulEval), one JSON
# object per line. Keep the hypothesis, reference, sentence BLEU, Average Lagging (AL) and
# reference length, keyed by line number.
data = {}
with open("mustc_de-results/cif_de_sum_ctc0_3_lat0_0/instances.log", "r") as log_file:
    for line_no, raw_line in enumerate(log_file):
        record = json.loads(raw_line)
        data[line_no] = dict(
            prediction=record["prediction"],
            reference=record["reference"],
            bleu=record["metric"]["sentence_bleu"],
            AL=record["metric"]["latency"]["AL"],
            reference_length=record["reference_length"],
        )

In [None]:
# One row per sentence, lowest sentence BLEU first.
sentence_results = pd.DataFrame.from_dict(data, orient="index")
sentence_results.sort_values(by="bleu")