In [1]:
import copy
import difflib
import json
import os
import re
from collections import defaultdict

In [2]:
def load_data(path, encoding='utf-8'):
    """Read a tab-separated, token-aligned file into a list of examples.

    The first line of the file is skipped (presumably a header — TODO confirm).
    Every following line that has exactly three tab-separated fields is read as
    ``src_token<TAB>tgt_token<TAB>tag``. Any other line (e.g. a blank line)
    closes the current example; consecutive separator lines therefore produce
    empty examples, as before.

    Args:
        path: path of the file to read.
        encoding: text encoding of the file. Defaults to UTF-8 so the Arabic
            text is decoded identically regardless of the platform locale.

    Returns:
        A list of ``(src_tokens, tgt_tokens, tags)`` tuples of equal-length
        lists. Tags are stripped of surrounding whitespace (incl. the newline).
    """
    examples = []
    src_tokens, tgt_tokens, tags = [], [], []

    with open(path, mode='r', encoding=encoding) as f:
        next(f, None)  # skip the first (header) line
        for raw_line in f:
            fields = raw_line.split('\t')
            if len(fields) == 3:
                src, tgt, tag = fields
                src_tokens.append(src)
                tgt_tokens.append(tgt)
                tags.append(tag.strip())
            else:
                examples.append((src_tokens, tgt_tokens, tags))
                src_tokens, tgt_tokens, tags = [], [], []

    # Flush the last example when the file does not end with a separator line.
    # Test the accumulated list rather than per-line loop variables, which are
    # unbound when no token line was read (empty or header-only file).
    if tags:
        examples.append((src_tokens, tgt_tokens, tags))

    return examples

In [7]:
# ARETA tags that reached create_rule's generic insert branch (create_rule
# adds to this set as a side effect).
tricky_set = set()

class CorruptFactory:
    """Table of token-corruption rules and their relative frequencies.

    Attributes:
        model: ``{(pattern, tag): {rule: probability}}``; for every key the
            probabilities are the rule's share of all rules seen for that key.
        counts: ``{(pattern, tag): int}``, number of rules extracted per key.
        examples: ``{(pattern, tag): {rule: [{'src': ..., 'tgt': ...}, ...]}}``,
            the token pairs each rule was extracted from.
    """

    def __init__(self, model, counts, examples):
        self.model = model
        self.counts = counts
        self.examples = examples

    @classmethod
    def build(cls, data):
        """Learn the rule table from aligned examples.

        Args:
            data: iterable of ``(src_tokens, tgt_tokens, areta_tags)`` triples,
                e.g. as returned by ``load_data``.

        Pairs tagged 'UC' or 'UNK', and compound tags containing '+', are
        skipped. Rules come from ``create_rule(tgt_token, src_token, tag)``,
        i.e. they rewrite the tgt token into the src token (presumably the
        corruption direction — confirm); the stored examples accordingly put
        the tgt token under 'src' and the src token under 'tgt'.

        Raises:
            AssertionError: if the three lists of an example differ in length.
        """
        rule_probs = defaultdict(lambda: defaultdict(lambda: 0))
        rule_examples = defaultdict(lambda: defaultdict(lambda: list()))
        key_counts = dict()
        ignored_tags = ('UC', 'UNK')

        for src_tokens, tgt_tokens, areta_tags in data:
            assert len(src_tokens) == len(tgt_tokens) == len(areta_tags)

            for src_t, tgt_t, tag in zip(src_tokens, tgt_tokens, areta_tags):
                if tag in ignored_tags or '+' in tag:
                    continue

                for extracted in create_rule(tgt_t, src_t, tag) or []:
                    key = (extracted['pattern'], tag)
                    rule = extracted['rule']
                    rule_probs[key][rule] += 1
                    rule_examples[key][rule].append({'src': tgt_t, 'tgt': src_t})
                    key_counts[key] = key_counts.get(key, 0) + 1

        # Convert raw rule counts into per-key relative frequencies.
        for key, rules in rule_probs.items():
            total = float(key_counts[key])
            for rule in rules:
                rules[rule] /= total

        return cls(rule_probs, key_counts, rule_examples)

    def __len__(self):
        """Number of distinct ``(pattern, tag)`` keys."""
        return len(self.model)

    def __getitem__(self, tt_tag):
        """Return a plain-dict copy of ``{rule: probability}`` for a
        ``(pattern, tag)`` key, or None when the key was never learned."""
        pattern, tag = tt_tag
        # .get avoids inserting a missing key into the underlying defaultdict.
        rules = self.model.get((pattern, tag))
        if rules is None:
            return None
        return dict(rules)
    
def create_rule(orig, corr, areta_tag, negative_indices=True):
    """Extract character-level rules that rewrite ``orig`` into ``corr``.

    The edit script comes from ``difflib.SequenceMatcher.get_opcodes``;
    ``areta_tag`` decides how each non-equal opcode becomes a rule.
    ``CorruptFactory.build`` calls this with the tgt token as ``orig`` and the
    src token as ``corr``.

    Args:
        orig: token to be rewritten.
        corr: rewritten token.
        areta_tag: error tag of the pair. 'SPLIT' and tags containing
            'INSERT', 'DELETE' or 'MERGE' get dedicated handling.
        negative_indices: when True, an edit whose start offset is greater
            than ``(len(orig) + 1) // 2`` is re-indexed relative to the end of
            the word as ``offset - len(orig) - 1``.

    Returns:
        A list of ``{'pattern': ..., 'rule': ...}`` dicts:
          * SPLIT: pattern is the first space-separated word of ``orig`` and
            rule is the string 'merge' (both plain strings, not tuples).
          * replace: pattern ``(start, end, text)`` or, when end-relative,
            ``(start, None, suffix)``; rule ``('replace', new_text)``.
          * delete: same patterns with rule ``('delete', '')``; for tags
            containing 'INSERT' the pattern is ``(stripped_text,)``.
          * insert: for 'DELETE' tags pattern None and rule
            ``('insert', text)``; for 'MERGE' tags pattern
            ``(start, None, suffix)`` or ``(None, i1, prefix)``; otherwise
            ``(index, None, char)``. The latter two use rule
            ``('insert', i1, text)`` where ``i1`` is the (possibly
            end-relative) insertion offset.

    Side effect: tags reaching the generic insert branch are added to the
    module-level ``tricky_set``.
    """
    rules = []

    if areta_tag == 'SPLIT':  # merge corruption
        # Only the first word is kept as the candidate to be merged with the
        # following token(s).
        words = orig.split(' ')
        rules.append({'pattern': words[0], 'rule': 'merge'})
        return rules

    sm = difflib.SequenceMatcher(None, orig, corr, autojunk=False)

    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        # After this shift, ``i1 + 1`` equals ``original_i1 - len(orig)``,
        # the ordinary negative index of the edit start; that is why the
        # end-relative patterns below slice from ``i1 + 1``.
        if negative_indices and i1 > (len(orig) + 1) // 2:
            i1 = i1 - len(orig) - 1
            i2 = i2 - len(orig) - 1

        # Rule extraction must not depend on negative_indices; it used to be
        # nested under it, so negative_indices=False produced no rules.
        if tag == 'equal':
            continue

        if tag == 'replace':
            if i1 < 0:
                pattern = (i1 + 1, None, orig[i1 + 1:])
            else:
                pattern = (i1, i2, orig[i1:i2])
            rules.append({'pattern': pattern, 'rule': (tag, corr[j1:j2])})

        elif tag == 'delete':
            if 'INSERT' in areta_tag:
                # Position-free pattern: only the stripped deleted text.
                deleted = orig[i1 + 1:] if i1 < 0 else orig[i1:i2]
                pattern = (deleted.strip(),)
            elif i1 < 0:
                pattern = (i1 + 1, None, orig[i1 + 1:])
            else:
                pattern = (i1, i2, orig[i1:i2])
            rules.append({'pattern': pattern, 'rule': (tag, '')})

        elif tag == 'insert':
            inserted = corr[j1:j2]
            if 'DELETE' in areta_tag:  # insert corruption
                rules.append({'pattern': None, 'rule': (tag, inserted)})
            elif 'MERGE' in areta_tag:  # split corruption
                if i1 < 0:
                    pattern = (i1 + 1, None, orig[i1 + 1:])
                else:
                    pattern = (None, i1, orig[0:i1])
                rules.append({'pattern': pattern, 'rule': (tag, i1, inserted)})
            else:
                tricky_set.add(areta_tag)
                # Anchor on one character: orig[i1] at either word edge
                # (i1 == 0 or -1), otherwise orig[i1 - 1].
                # NOTE(review): for end-relative i1 < -1, orig[i1] is the
                # character right before the insertion point and orig[i1 - 1]
                # the one before that — confirm which neighbour is intended.
                # Also raises IndexError when orig is empty.
                anchor = i1 if i1 in (0, -1) else i1 - 1
                rules.append({'pattern': (anchor, None, orig[anchor]),
                              'rule': (tag, i1, inserted)})

    return rules

In [8]:
# Directory holding the ARETA-tagged QALB-2014 alignment files. Set the
# QALB_ARETA_DIR environment variable to point elsewhere instead of editing
# the path here; the default reproduces the original location.
DATA_DIR = os.environ.get(
    'QALB_ARETA_DIR',
    '/scratch/ba63/gec/data/alignment/modeling_areta_tags_check/qalb14')
data = load_data(os.path.join(DATA_DIR, 'qalb14_train.areta+.txt'))

In [9]:
# Learn the corruption-rule table from the training examples.
model = CorruptFactory.build(data)
# Number of distinct (pattern, tag) keys in the table.
len(model)

7182

In [73]:
# Distinct ARETA tags for which at least one rule was learned.
types = {tag for _pattern, tag in model.model}
types

{'DELETE',
 'INSERT_PM',
 'INSERT_PM^*',
 'INSERT_PM^2',
 'INSERT_XM',
 'INSERT_XM^*',
 'INSERT_XM^2',
 'MERGE',
 'REPLACE_M',
 'REPLACE_MI',
 'REPLACE_MT',
 'REPLACE_O',
 'REPLACE_OA',
 'REPLACE_OC',
 'REPLACE_OD',
 'REPLACE_OH',
 'REPLACE_OM',
 'REPLACE_OR',
 'REPLACE_OT',
 'REPLACE_OW',
 'REPLACE_PC',
 'REPLACE_PM',
 'REPLACE_PT',
 'REPLACE_S',
 'REPLACE_SF',
 'REPLACE_SW',
 'REPLACE_X',
 'REPLACE_XC',
 'REPLACE_XF',
 'REPLACE_XG',
 'REPLACE_XM',
 'REPLACE_XN',
 'REPLACE_XT',
 'SPLIT'}

In [23]:
# Token pairs behind each rule learned for pattern (1, 2, 'ل') under REPLACE_XM. Caution: model.examples is a defaultdict, so a never-seen key is inserted on lookup.
model.examples[(1, 2, 'ل'), 'REPLACE_XM']

defaultdict(<function __main__.CorruptFactory.build.<locals>.<lambda>.<locals>.<lambda>()>,
            {('delete', ''): [{'src': 'فلماذا', 'tgt': 'فماذا'},
              {'src': 'ولتداول', 'tgt': 'وتداول'},
              {'src': 'فلتنتظر', 'tgt': 'فتنتظر'},
              {'src': 'ولكن', 'tgt': 'وكن'},
              {'src': 'ولنفس', 'tgt': 'ونفس'},
              {'src': 'وليدافعوا', 'tgt': 'ويدافعون'},
              {'src': 'وليغيروا', 'tgt': 'ويغيروا'}],
             ('replace', 'ا'): [{'src': 'وللمصريين', 'tgt': 'والمصريين'},
              {'src': 'وللجنوبيين', 'tgt': 'والجنوبيين'}]})

In [25]:
# Rule probabilities for pattern ('من',) under INSERT_XM (None if never learned).
model[('من',), 'INSERT_XM']

{('delete', ''): 1.0}

In [24]:
# List every key that contains the tag 'INSERT_XM', followed by its operations.
for pattern_key in model.model:
    if 'INSERT_XM' not in pattern_key:
        continue
    print(f'Pattern {pattern_key}')
    for operation in model[pattern_key]:
        print(f'Operation: {operation}')
    print()
        

Pattern (('من',), 'INSERT_XM')
Operation: ('delete', '')

Pattern (('هناك',), 'INSERT_XM')
Operation: ('delete', '')

Pattern (('الإرهابيون',), 'INSERT_XM')
Operation: ('delete', '')

Pattern (('القرن',), 'INSERT_XM')
Operation: ('delete', '')

Pattern (('أن',), 'INSERT_XM')
Operation: ('delete', '')

Pattern (('في',), 'INSERT_XM')
Operation: ('delete', '')

Pattern (('بين',), 'INSERT_XM')
Operation: ('delete', '')

Pattern (('هذا',), 'INSERT_XM')
Operation: ('delete', '')

Pattern (('سنوات',), 'INSERT_XM')
Operation: ('delete', '')

Pattern (('منه',), 'INSERT_XM')
Operation: ('delete', '')

Pattern (('هل',), 'INSERT_XM')
Operation: ('delete', '')

Pattern (('بشار',), 'INSERT_XM')
Operation: ('delete', '')

Pattern (('بها',), 'INSERT_XM')
Operation: ('delete', '')

Pattern (('إما',), 'INSERT_XM')
Operation: ('delete', '')

Pattern (('على',), 'INSERT_XM')
Operation: ('delete', '')

Pattern (('ليس',), 'INSERT_XM')
Operation: ('delete', '')

Pattern (('الذي',), 'INSERT_XM')
Operation: ('d

In [36]:
# Total number of rules extracted for rare (pattern, tag) keys, i.e. keys seen
# fewer than RARE_KEY_THRESHOLD times.
# NOTE(review): this cell's loop body was missing, so the cell did not parse.
# Summing the counts is an assumed reconstruction (a plain key count could not
# exceed len(model), which was smaller than the recorded value) — confirm the
# intended metric.
RARE_KEY_THRESHOLD = 3
bla = 0
for key in model.model.keys():
    if model.counts[key] < RARE_KEY_THRESHOLD:
        bla += model.counts[key]

In [37]:
bla  # display the accumulator computed in the previous cell

9557

In [106]:
def create_rule(orig, corr, negative_indices=True):
    sm = difflib.SequenceMatcher(None,
                             orig,
                             corr,
                             autojunk=False)
    
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        i1_orig = i1
        i2_orig = i2

        if negative_indices:
            if i1 >= (len(orig) + 1) // 2:
                i1 = i1 - len(orig) - 1
                i2 = i2 - len(orig) - 1

            if tag == 'replace':
                print(f'Applicable: {[i1, i2, orig[i1:i2]]}')
                print([tag, i1, i2, corr[j1:j2]])
            
            elif tag == 'insert':
                print([tag, i1, corr[j1:j2]])
            
            elif tag == 'delete':
                print([tag, i1, i2])

In [107]:
create_rule('إلى', 
           'الى')

Applicable: [0, 1, 'إ']
['replace', 0, 1, 'ا']


In [None]:
# NOTE(review): stray scratch value in a never-executed cell; consider removing.
'الصينية .'