In [82]:
import sys
import os
from collections import defaultdict
import sacrebleu

temp_path = 'ende_{}_wmt14.en-de.transformer.para_ft'
dfs_path = temp_path.format('dfstopk')
beam_path = temp_path.format('beam')
beam_size = 10

In [124]:
def read_split_files(file_path, beam_size):
    all_lines = dict()
    for file in os.listdir(file_path):
        cur_path = os.path.join(beam_path, file)
        with open(cur_path, 'r') as f:
            lines = f.readlines()
        num_suffix = file.split('.')[-1]
        all_lines[int(num_suffix)] = lines
    max_range = max([int(key) for key in all_lines])
    assert max_range + 1 == len(all_lines)

    ret = defaultdict(list)
    for i in range(max_range+1):
        offset = beam_size * i
        for line in all_lines[i]:
            splits = line.strip().split('|||')
            assert len(splits) == 3
            idx, sent, score = splits
            true_idx = offset + int(idx)
            ret[true_idx].append((float(score), sent.strip()))
    return ret

In [84]:
# call delbpe
import subprocess

def call_delbpe(input_dir):
    output_dir = input_dir + '.delbpe'
    print(output_dir)
    subprocess.run(['mkdir', '-p', output_dir])
    for file in os.listdir(input_dir):
        if not file.endswith('delbpe'):
            abs_path = os.path.join(input_dir, file)
            out_path = os.path.join(output_dir, file)
            subprocess.run(['bash', 'delbpe.sh', abs_path])
            subprocess.run(['mv', abs_path+'.delbpe', out_path])
            
    return output_dir

In [112]:
# call delbpe
import subprocess

def call_detok(input_dir):
    output_dir = input_dir + '.detok'
    print(output_dir)
    subprocess.run(['mkdir', '-p', output_dir])
    for file in os.listdir(input_dir):
        if not file.endswith('detok'):
            abs_path = os.path.join(input_dir, file)
            out_path = os.path.join(output_dir, file)
            f1 = open(abs_path, 'r')
            f2 = open(out_path, 'w')
            subprocess.run(['perl', 'detokenizer.perl'], stdin=f1, stdout=f2)
            f1.close()
            f2.close()
#             subprocess.run(['mv', abs_path+'.delbpe', ])
            
    return output_dir

In [85]:
delbpe_beam_path = call_delbpe(beam_path)
delbpe_dfs_path = call_delbpe(dfs_path)

ende_beam_wmt14.en-de.transformer.para_ft.delbpe
ende_dfstopk_wmt14.en-de.transformer.para_ft.delbpe


In [113]:
detok_delbpe_beam_path = call_detok(delbpe_beam_path)
detok_delbpe_dfs_path = call_detok(delbpe_beam_path)

ende_beam_wmt14.en-de.transformer.para_ft.delbpe.detok
ende_beam_wmt14.en-de.transformer.para_ft.delbpe.detok


In [125]:
beam_outputs = read_split_files(detok_delbpe_beam_path, beam_size)
dfs_outputs = read_split_files(detok_delbpe_dfs_path, beam_size)

In [126]:
def read(file):
    with open(file, 'r') as f:
        lines = f.readlines()
    lines = [line.strip() for line in lines]
    return lines
ref_file = './test.de.tok.detok'
refs = read(ref_file)

In [127]:
len(refs)

3003

In [128]:
def score_all_outputs(output_dict, refs):
    scores = []
    for idx in output_dict:
        cur_ref = refs[idx]
        candidates = output_dict[idx]
        candidates = [item[1] for item in candidates]
        cur_scores = [sacrebleu.sentence_bleu(item, cur_ref) for item in candidates]
        scores.append(cur_scores)
    return scores
beam_scores = score_all_outputs(beam_outputs, refs)

In [129]:
all_scores = []
for i in range(len(beam_scores)):
    for item in beam_scores[i]:
        all_scores.append(item.score)
print(sum(all_scores) / len(all_scores))

18.250411330999558


In [130]:
def get_top1(outputs):
    ret = [None] * len(outputs)
    for key in outputs:
        ret[key] = outputs[key][0][1]
    return ret
def write_file(sents, file):
    with open(file, 'w') as f:
        for sent in sents:
            f.write(sent + '\n')
beam_top1s = get_top1(beam_outputs)
write_file(beam_top1s, 'beam.outs')

In [106]:
beam_scores[50][1].score

17.671315921862618

In [107]:
print(len(beam_outputs))
print(len(dfs_outputs))

3003
3003
