In [1]:
from camel_tools.utils.charsets import UNICODE_PUNCT_SYMBOL_CHARSET
import string
import re
import copy

In [8]:
def read_lines(path, encoding='utf-8'):
    """Read a text file and return its lines with surrounding whitespace stripped.

    Args:
        path: Path of the file to read.
        encoding: Text encoding of the file. Defaults to UTF-8 so non-ASCII
            (e.g. Arabic) text decodes the same way regardless of the
            platform's locale default.

    Returns:
        A list with one stripped string per line of the file.
    """
    with open(path, encoding=encoding) as f:
        return [line.strip() for line in f]
    
def read_alignment(path, encoding='utf-8'):
    """Read a tab-separated alignment file into per-sentence examples.

    The first line of the file is skipped as a header. Every following line
    containing a tab is one aligned pair ``source<TAB>target``; a line with
    no tab (e.g. a blank line) closes the current example.

    Args:
        path: Path of the alignment file.
        encoding: Text encoding of the file (UTF-8 by default).

    Returns:
        A list of examples, each a list of ``(source, target)`` string tuples.
        Separator lines are not collapsed, so two consecutive separators
        produce an empty example.

    Raises:
        ValueError: if a line contains more than one tab.
    """
    example = []
    examples = []
    with open(path, encoding=encoding) as f:
        next(f, None)  # skip the header line
        for raw in f:
            fields = raw.replace('\n', '').split('\t')
            if len(fields) > 1:
                s, t = fields
                example.append((s, t))
            else:
                examples.append(example)
                example = []

    # The last example may not be followed by a separator line.
    if example:
        examples.append(example)

    return examples

In [9]:
def read_gold_m2_edits(path, encoding='utf-8'):
    """Read the correction strings of every sentence in an M2 file.

    Source lines (starting with ``S``) are skipped. Each annotation line is
    split on ``|||`` and contributes its correction field (index 2), or
    ``''`` when its edit type (index 1) is ``Delete``. A blank line closes
    the current sentence.

    Args:
        path: Path of the M2 file.
        encoding: Text encoding of the file (UTF-8 by default).

    Returns:
        A list with one entry per sentence, each the list of that sentence's
        correction strings in file order.
    """
    edits = []
    sent_edits = []
    in_sentence = False  # an S/A line was seen since the last flush
    with open(path, mode='r', encoding=encoding) as f:
        for line in f:
            line = line.strip()
            if not line:
                # Only a blank line that actually closes a sentence emits an
                # entry, so repeated blank lines cannot shift sentence indices.
                if in_sentence:
                    edits.append(sent_edits)
                    sent_edits = []
                    in_sentence = False
                continue

            in_sentence = True
            if line.startswith('S'):
                continue

            fields = line.split('|||')
            sent_edits.append('' if fields[1] == 'Delete' else fields[2])

    # The last sentence is not necessarily followed by a blank line.
    if in_sentence:
        edits.append(sent_edits)

    return edits

In [10]:
def create_m2_edits_per_ex(example):
    """Convert one aligned sentence into M2 edit lines and predicted corrections.

    Args:
        example: List of ``(source, target)`` segment pairs, as produced by
            ``read_alignment``. Source offsets are counted in
            whitespace-split tokens of the source segments.

    Returns:
        A tuple ``(m2, preds)``:
          * ``m2`` is the newline-joined M2 ``A`` lines (``''`` if no edits).
            A differing pair becomes ``Replace`` (both sides non-empty),
            ``Insert`` (empty source) or ``Delete`` (empty target).
          * ``preds`` is the correction string of each edit, in order
            (``''`` for deletions).
    """
    edits = []
    preds = []
    src_idx = 0
    for s, t in example:
        n_src = len(s.split())
        if s == t:
            # Unchanged segment: skip all of its source tokens. It may span
            # several tokens, or none for an empty/empty alignment.
            src_idx += n_src
            continue

        end = src_idx + n_src
        if s and t:
            edits.append(f'A {src_idx} {end}|||Replace|||{t}|||REQUIRED|||-NONE-|||0')
            preds.append(t)
        elif t:
            # Empty source: zero-width span, so end == src_idx.
            edits.append(f'A {src_idx} {src_idx}|||Insert|||{t}|||REQUIRED|||-NONE-|||0')
            preds.append(t)
        else:
            edits.append(f'A {src_idx} {end}|||Delete||||||REQUIRED|||-NONE-|||0')
            preds.append('')
        src_idx = end

    return "\n".join(edits), preds

def create_m2_edits(examples):
    """Run ``create_m2_edits_per_ex`` over every aligned sentence.

    Returns:
        A tuple ``(edits, preds)`` of two lists, index-aligned with
        ``examples``: the per-sentence M2 edit strings and the per-sentence
        lists of predicted correction strings.
    """
    per_sentence = [create_m2_edits_per_ex(ex) for ex in examples]
    all_edits = [m2 for m2, _ in per_sentence]
    all_preds = [pred for _, pred in per_sentence]
    return all_edits, all_preds

In [11]:
def compare_edits(src_sents, tgt_sents, gold_edits, my_edits,
                  out_path='alignment_eval_checkkk.txt'):
    """Write a side-by-side report of sentences whose edits differ from gold.

    For each sentence index of ``src_sents`` where the gold edit list differs
    from the predicted one, the report gets:
      * the source and target sentences;
      * one ``gold<TAB>predicted`` row per edit position, with the shorter
        list padded with ``''``;
      * a blank line.
    Every field is wrapped in ``<s>`` markers so leading and trailing
    whitespace stays visible.

    Args:
        src_sents: Source sentences.
        tgt_sents: Target sentences, index-aligned with ``src_sents``.
        gold_edits: Per-sentence lists of gold correction strings.
        my_edits: Per-sentence lists of predicted correction strings.
        out_path: Where the report is written (UTF-8, overwritten).

    Returns:
        The number of mismatching sentences, which is also printed.
    """
    bugs = 0
    with open(out_path, mode='w', encoding='utf-8') as f:
        for i in range(len(src_sents)):
            g_edits = gold_edits[i]
            m_edits = my_edits[i]
            if g_edits == m_edits:
                continue

            bugs += 1
            f.write(f'<s>{src_sents[i]}<s>\n')
            f.write(f'<s>{tgt_sents[i]}<s>\n')

            # Pair edits positionally, padding the shorter side with ''.
            for j in range(max(len(g_edits), len(m_edits))):
                g = g_edits[j] if j < len(g_edits) else ''
                m = m_edits[j] if j < len(m_edits) else ''
                f.write(f'<s>{g}<s>\t<s>{m}<s>\n')

            f.write('\n')

    print(bugs)
    return bugs

In [12]:
def recover_tgt(src, m2_edits):
    """Apply M2 edits to a tokenized source sentence and return the result.

    Args:
        src: Source sentence as a list of tokens.
        m2_edits: Newline-separated M2 ``A`` lines whose ``start end`` spans
            index into ``src``. They are expected to be sorted by start
            offset and non-overlapping.

    Returns:
        The edited sentence as a single space-joined string.
          * Blank lines are skipped, so ``''`` yields ``" ".join(src)``.
          * M2 ``noop`` / negative-span lines are skipped.
          * Edits with an empty correction add nothing.
    """
    tgt = []
    curr_idx = 0
    for line in m2_edits.split('\n'):
        if not line.strip():
            continue

        fields = line.split('|||')
        start, end = (int(x) for x in fields[0].split()[1:3])
        # "A -1 -1|||noop|||..." marks a sentence without corrections.
        if fields[1] == 'noop' or start < 0:
            continue

        # Copy the untouched source tokens that precede this edit's span.
        tgt += src[curr_idx:start]

        # Insert/replace: emit the correction. An empty correction is
        # skipped so it does not leave a double space in the joined output.
        if fields[1] != 'Delete' and fields[2]:
            tgt.append(fields[2])

        curr_idx = end

    tgt += src[curr_idx:]
    return " ".join(tgt)

In [13]:
def fix_m2_edit(m2_edits, tgt_sent):
    """Replace each Insert/Replace correction with the matching target tokens.

    The alignment-derived correction strings may differ from the tokens in
    the reference target. This function walks the edits in order and keeps
    a running target offset:
      * source tokens between edits map 1:1 to target tokens;
      * an Insert/Replace consumes as many target tokens as its current
        correction has;
      * a Delete consumes none.
    The consumed slice of ``tgt_sent`` becomes the new correction.

    Args:
        m2_edits: Newline-separated M2 ``A`` lines, sorted by start offset.
        tgt_sent: Reference target sentence (whitespace-tokenized).

    Returns:
        The edits re-joined with ``\\n``, with the correction field of each
        Insert/Replace rewritten; ``''`` when ``m2_edits`` is empty.
    """
    if not m2_edits:
        return ''

    tgt_tokens = tgt_sent.split()
    fixed = []
    curr_idx = 0
    tgt_idx = 0
    for line in m2_edits.split('\n'):
        fields = line.split('|||')
        start, end = (int(x) for x in fields[0].split()[1:3])

        # Unedited source tokens before this span are copied 1:1 to the target.
        tgt_idx += start - curr_idx

        # Recover the original target tokens that are inserted or replaced.
        if fields[1] in ('Insert', 'Replace'):
            n_tok = len(fields[2].split())
            fields[2] = " ".join(tgt_tokens[tgt_idx:tgt_idx + n_tok])
            tgt_idx += n_tok

        curr_idx = end
        fixed.append('|||'.join(fields))

    return "\n".join(fixed)

In [15]:
# Data locations: QALB-2014 tune split and its precomputed token alignment.
QALB_2014_DIR = '/scratch/ba63/gec/data/QALB-0.9.1-Dec03-2021-SharedTasks/data/2014'
ALIGNMENT_PATH = '/scratch/ba63/gec/data/alignment/qalb14/qalb14_tune.txt'

gold_edits = read_gold_m2_edits(f'{QALB_2014_DIR}/tune/QALB-2014-L1-Tune.m2')
src_sents = read_lines(f'{QALB_2014_DIR}/tune/QALB-2014-L1-Tune.sent.no_ids.clean')
tgt_sents = read_lines(f'{QALB_2014_DIR}/tune/QALB-2014-L1-Tune.cor.no_ids')

# Build M2 edits (and plain correction strings) from the alignment.
alignment = read_alignment(ALIGNMENT_PATH)
my_edits, preds = create_m2_edits(alignment)

In [16]:
# Rewrite each sentence's corrections with the exact reference-target tokens;
# sentences without edits keep an empty edit string.
fixed_edits = [
    fix_m2_edit(m2, tgt) if m2 else ''
    for m2, tgt in zip(my_edits, tgt_sents)
]

In [17]:
# Sanity check: applying the fixed edits to the source should reproduce the
# reference target. Print the index of every sentence where it does not.
for idx, m2 in enumerate(fixed_edits):
    if not m2:
        continue
    if recover_tgt(src_sents[idx].split(), m2) != tgt_sents[idx]:
        print(idx)

In [18]:
# Compare gold corrections against the alignment's plain correction strings
# (`preds`), not the M2-formatted `my_edits`.
compare_edits(src_sents, tgt_sents, gold_edits, my_edits=preds);

147
