In [17]:
def write_data(path, src_sents, tgt_sents, tag_sents):
    """Write parallel source/target/tag sequences as a tab-separated file.

    Each aligned token becomes one line ``src_token<TAB>tgt_token<TAB>tag``.
    Every sentence (including the last) is followed by an empty line. This is
    the layout that ``read_alignment`` parses.

    Args:
        path: Output file path; opened with mode ``'w'``, so an existing file
            is overwritten.
        src_sents, tgt_sents, tag_sents: Parallel iterables with one token/tag
            sequence per sentence. As with ``zip``, iteration stops at the
            shortest input, both across sentences and within a sentence.
    """
    # Explicit UTF-8: tokens may be non-ASCII (e.g. Arabic), and open()'s
    # default encoding depends on the platform locale.
    with open(path, mode='w', encoding='utf-8') as f:
        for src, tgt, tags in zip(src_sents, tgt_sents, tag_sents):
            for src_token, tgt_token, tag in zip(src, tgt, tags):
                f.write(f'{src_token}\t{tgt_token}\t{tag}\n')
            f.write('\n')

In [18]:
def read_alignment(path):
    """Parse a tab-separated token alignment file into sentences.

    Each token line holds at least three tab-separated fields,
    ``src<TAB>tgt<TAB>tag``; fields past the third are ignored. Any line
    without a tab (normally an empty line) closes the current sentence.

    Notes:
        * Every separator line appends the current sentence, so consecutive
          separator lines produce empty sentences in the result.
        * A line with exactly two tab-separated fields raises ``IndexError``.

    Args:
        path: Path to the alignment file (read as UTF-8).

    Returns:
        list[list[tuple[str, str, str]]]: one list of ``(src, tgt, tag)``
        triples per sentence. A trailing sentence with no closing blank line
        is still included.
    """
    example = []
    examples = []
    # Explicit UTF-8 to match write_data and avoid locale-dependent decoding.
    with open(path, encoding='utf-8') as f:
        for line in f:
            fields = line.replace('\n', '').split('\t')
            if len(fields) > 1:
                example.append((fields[0], fields[1], fields[2]))
            else:
                examples.append(example)
                example = []

    if example:
        examples.append(example)

    return examples

In [64]:
def _attach_inserts(new_tgt, new_tags, ins_tokens, ins_tags, mark_replace=False):
    """Merge pending inserted target tokens and their tags into the last emitted column.

    ``new_tgt[-1]`` gets ``ins_tokens`` appended, space-separated. If
    ``new_tags[-1]`` is ``'UC'`` it is replaced by the ``'+'``-joined
    ``ins_tags``; otherwise ``ins_tags`` are appended with ``'+'``. With
    ``mark_replace=True``, each ``'+'``-separated component of the previous
    tag that does not already start with REPLACE/INSERT/DELETE is first
    prefixed with ``REPLACE_``. Mutates ``new_tgt`` and ``new_tags`` in place;
    both must be non-empty.
    """
    new_tgt[-1] = new_tgt[-1] + ' ' + ' '.join(ins_tokens)

    if new_tags[-1] == 'UC':
        new_tags[-1] = '+'.join(ins_tags)
        return

    prev_tag = new_tags[-1]
    if mark_replace:
        prev_tag = '+'.join(
            t if t.startswith(('REPLACE', 'INSERT', 'DELETE')) else f'REPLACE_{t}'
            for t in prev_tag.split('+'))
    new_tags[-1] = prev_tag + '+' + '+'.join(ins_tags)


def postprocess_alignment_no_span(src, tgt, tags):
    """Collapse a column-wise alignment so that inserted target tokens have no column of their own.

    ``src``, ``tgt`` and ``tags`` are parallel lists with one entry per
    alignment column. An empty string in ``src`` marks an inserted target
    token; an empty string in ``tgt`` (with non-empty ``src``) marks a
    deleted source token. Columns are handled as:

    * Keep    (``src == tgt``): copied through with its tag.
    * Replace (both non-empty, different): every ``'+'``-separated tag
      component becomes ``REPLACE_<component>``.
    * Insert  (``src == ''``): a run of inserts is buffered as target tokens
      with tags ``INSERT_<tag>``. The buffer is glued onto the last emitted
      column when the next Keep/Replace column is reached.
    * Delete  (``tgt == ''``): emitted with tag ``'DELETE'``.

    ``'<bos>'``/``'</eos>'`` sentinels (tag ``'UC'``) are added to both sides.
    The leading one guarantees inserts always have a column to attach to. The
    trailing one is a Keep column that flushes any pending inserts.

    Args:
        src: Source tokens, one per column.
        tgt: Target tokens, one per column; must have the same length as ``src``.
        tags: Tag per column (indexed in parallel with ``src``).

    Returns:
        tuple[list[str], list[str], list[str]]: ``(new_src, new_tgt, new_tags)``
        of equal length, including the sentinel columns.

    Raises:
        AssertionError: if ``src`` and ``tgt`` differ in length, or an
            output invariant fails.
    """
    assert len(src) == len(tgt)

    src = ['<bos>'] + src + ['</eos>']
    tgt = ['<bos>'] + tgt + ['</eos>']
    tags = ['UC'] + tags + ['UC']

    i, j = 0, 0
    new_src, new_tgt, new_tags = [], [], []

    # Inserted target tokens (and their tags) not yet attached to a column.
    append_tgt = []
    append_tag = []

    while i < len(src) and j < len(tgt):
        if src[i] == tgt[j]:  # Keep
            if append_tgt:
                _attach_inserts(new_tgt, new_tags, append_tgt, append_tag)
                append_tgt, append_tag = [], []

            new_src.append(src[i])
            new_tgt.append(tgt[j])
            new_tags.append(tags[i])
            i += 1
            j += 1

        elif src[i] != '' and tgt[j] != '':  # Replace
            if append_tgt:
                _attach_inserts(new_tgt, new_tags, append_tgt, append_tag,
                                mark_replace=True)
                append_tgt, append_tag = [], []

            new_src.append(src[i])
            new_tgt.append(tgt[j])
            new_tags.append('+'.join(f'REPLACE_{t}' for t in tags[i].split('+')))
            i += 1
            j += 1

        elif src[i] == '' and tgt[j] != '':  # Insert: consume the whole run
            # Do not reset the buffers here. Inserts buffered before an
            # intervening deletion are still pending and would be lost.
            while i < len(src) and j < len(tgt) and src[i] == '' and tgt[j] != '':
                append_tgt.append(tgt[j])
                append_tag.append(f'INSERT_{tags[i]}')
                i += 1
                j += 1

        else:  # Delete (tgt[j] == '' and src[i] != '')
            # NOTE(review): pending inserts are not flushed here. They end up
            # attached to this deleted column at the next Keep/Replace column.
            # Confirm this is the intended attachment point.
            new_src.append(src[i])
            new_tgt.append(tgt[j])
            new_tags.append('DELETE')
            i += 1
            j += 1

    # The trailing '</eos>' column is always a Keep that flushes the buffer,
    # so this should not trigger. It is handled like the Keep branch for safety.
    if append_tgt:
        _attach_inserts(new_tgt, new_tags, append_tgt, append_tag)

    assert len(new_tgt) == len(new_src)
    assert " ".join(new_tgt).split() == " ".join(tgt).split()
    assert " ".join(new_src).split() == " ".join(src).split()
    assert len(new_src) == len(new_tags)

    return new_src, new_tgt, new_tags


In [65]:
# def read_alignment(path):
#     example = []
#     examples = []
#     with open(path) as f:
#         for line in f.readlines()[1:]:
#             line = line.replace('\n', '').split('\t')
#             if len(line) > 1:
#                 s, t = line
#                 example.append((s, t))
#             else:
#                 examples.append(example)
#                 example = []

#         if example:
#             examples.append(example)

#     return examples

In [66]:
# def postprocess(src, tgt):
#     assert len(src) == len(tgt)
    
#     i, j = 0, 0
#     new_src, new_tgt = [], []
#     tags = []
    
#     prepend_tgt = []

#     while i < len(src) and j < len(tgt):
#         if src[i] == tgt[j]: # Keep
            
#             if prepend_tgt: 
#                 new_tgt.append(" ".join(prepend_tgt) + ' ' + tgt[j])
#             else:
#                 new_tgt.append(tgt[j])
            

#             new_src.append(src[i])
            
#             i += 1
#             j += 1
#             prepend_tgt = []
        
#         elif src[i] != '' and tgt[j] != '' or src[i] != '' and tgt[j] == '' : # Replace / Deletion
                
#                 if prepend_tgt: 
#                     new_tgt.append(" ".join(prepend_tgt) + ' ' + tgt[j])
#                 else:
#                     new_tgt.append(tgt[j])

#                 new_src.append(src[i])

#                 i += 1
#                 j += 1
#                 prepend_tgt = []
        
#         else: # Insertion
                
#             prepend_tgt = []

#             while i < len(src) and j < len(tgt) and src[i] == '' and tgt[j] != '':

#                 prepend_tgt.append(tgt[j])
#                 j += 1
#                 i += 1
        
#     if prepend_tgt:
#         new_tgt[-1] = new_tgt[-1] + ' ' + ' '.join(prepend_tgt)
    
#     assert len(new_tgt) == len(new_src)
#     assert " ".join(new_tgt).split() == " ".join(tgt).split()
#     assert " ".join(new_src).split() == " ".join(src).split()
    
#     return new_src, new_tgt

In [67]:
# def postprocess(src, tgt):
#     assert len(src) == len(tgt)
    
#     i, j = 0, 0
#     new_src, new_tgt = [], []
#     tags = []
    
#     prepend_src, prepend_tgt = [], []

#     while i < len(src) and j < len(tgt):
#         if src[i] == tgt[j]: # Keep
            
#             if prepend_tgt: 
#                 new_tgt.append(" ".join(prepend_tgt) + ' ' + tgt[j])
#             else:
#                 new_tgt.append(tgt[j])
            
#             if prepend_src:
#                 new_src.append(" ".join(prepend_src) + ' ' +src[i])
            
#             else:
#                 new_src.append(src[i])
            
#             i += 1
#             j += 1
#             prepend_src, prepend_tgt = [], []
        
#         else:
#             if src[i] != '' and tgt[j] != '': # Replace
                
#                 if prepend_tgt: 
#                     new_tgt.append(" ".join(prepend_tgt) + ' ' + tgt[j])
#                 else:
#                     new_tgt.append(tgt[j])

#                 if prepend_src:
#                     new_src.append(" ".join(prepend_src) + ' ' +src[i])

#                 else:
#                     new_src.append(src[i])

#                 i += 1
#                 j += 1
#                 prepend_src, prepend_tgt = [], []
#             else:
#                 prepend_src, prepend_tgt = [], []

#                 while i < len(src) and j < len(tgt) and src[i] == '' and tgt[j] != '':
                    
#                     prepend_tgt.append(tgt[j])
#                     j += 1
#                     i += 1
                
#                 while i < len(src) and j < len(tgt) and src[i] != '' and tgt[j] == '':
#                     prepend_src.append(src[i])
#                     i += 1
#                     j += 1
        
#     if prepend_tgt:
#         new_tgt[-1] = new_tgt[-1] + ' ' + ' '.join(prepend_tgt)
    
#     if prepend_src:
#         new_src[-1] = new_src[-1] + ' ' + ' '.join(prepend_src)
    
    
#     assert len(new_tgt) == len(new_src)
#     assert " ".join(new_tgt).split() == " ".join(tgt).split()
#     assert " ".join(new_src).split() == " ".join(src).split()
    
#     return new_src, new_tgt

In [74]:
# Token alignment file to post-process; relative to the notebook's working directory.
DEV_ALIGNMENT_PATH = '../../arabic_error_type_annotation/qalb15_dev.my_align.txt'
alignment = read_alignment(DEV_ALIGNMENT_PATH)

In [75]:
# Collapse each aligned sentence and collect the three parallel outputs.
src_sents, tgt_sents, tag_sents = [], [], []
for sent in alignment:
    # Split the per-token (src, tgt, tag) tuples into three parallel lists.
    src = [triple[0] for triple in sent]
    tgt = [triple[1] for triple in sent]
    tags = [triple[2] for triple in sent]

    src_, tgt_, tags_ = postprocess_alignment_no_span(src, tgt, tags)

    src_sents += [src_]
    tgt_sents += [tgt_]
    tag_sents += [tags_]

In [76]:
# NOTE(review): absolute, machine-specific output path; consider moving it to a config cell.
OUTPUT_PATH = '/scratch/ba63/gec/data/alignment/qalb15/qalb15_dev.areta+.txt'
write_data(OUTPUT_PATH, src_sents, tgt_sents, tag_sents)