In [1]:
%load_ext autoreload
%reload_ext autoreload
%autoreload 2

In [2]:
from typing import List
import difflib
import functools
import json
import unicodedata
import copy
from collections import defaultdict, Counter
import re
import numpy as np
from utils import Dataset
from corruption_model_morph import CorruptModel, Rule, WordCorruptModel
import torch

from error_tagger.tagger import Model, inference_single
from error_tagger.utils import ErrorTagDataset
from transformers import pipeline
from camel_tools.utils.charsets import UNICODE_PUNCT_SYMBOL_CHARSET

In [3]:
def build_error_mle_model(edit_corrupt_model):
    """
    Maximum-likelihood estimate of P(error_tag | morph_feat).

    Reads the joint (error_tag, morph_feat) -> count table from
    `edit_corrupt_model.counts`. It drops the 'SPLIT' tag and every tag
    containing 'INSERT', then divides each remaining count by the total count
    of its morph_feat.

    Returns:
        defaultdict(list) mapping morph_feat -> [(error_tag, probability), ...],
        listed in the iteration order of the counts table.
    """
    def is_kept(tag):
        # SPLIT and insertion errors are not modelled per morph feature.
        return tag != 'SPLIT' and 'INSERT' not in tag

    filtered = {
        pair: count
        for pair, count in edit_corrupt_model.counts.items()
        if is_kept(pair[0])
    }

    feat_totals = {}
    for pair, count in filtered.items():
        feat = pair[1]
        feat_totals[feat] = count + feat_totals.get(feat, 0)

    # Sanity check: marginalising over tags must preserve the total mass.
    assert sum(filtered.values()) == sum(feat_totals.values())

    lookup = defaultdict(list)
    for pair, count in filtered.items():
        tag, feat = pair[0], pair[1]
        lookup[feat].append((tag, count / feat_totals[feat]))

    return lookup



def build_error_mle_model_full(data, ensure_ascii=False):
    """
    Maximum-likelihood estimate of P(error_tag | word, morph_feat).

    Counts, over every example in `data`, how often each
    (target token, serialized morph analysis) pair carries each ARETA tag.
    The counts are then normalized per pair.

    - Tokens tagged 'UC' or 'UNK' are skipped.
    - A 'SPLIT' token is counted only when it holds exactly two words. It is
      keyed on its first word and a one-element list holding the first
      analysis (`[ana[0]]`). SPLIT tokens with any other word count are ignored.

    Args:
        data: object exposing `.examples`. Each example has parallel
            `tgt_tokens`, `areta_tags` and `morph_feats` sequences.
        ensure_ascii: passed to `json.dumps` when serializing analyses. The
            default (False) produces the same keys as the
            `json.dumps(current_feat, ensure_ascii=False)` lookups done in
            `introduce_errors`. With True, any analysis containing non-ASCII
            text is escaped and can never be found by those lookups.

    Returns:
        defaultdict mapping (token, morph_json) -> defaultdict(tag -> probability).
        Missing inner keys read as 0.
    """
    model = defaultdict(lambda: defaultdict(lambda: 0))

    def serialize(ana):
        return json.dumps(ana, ensure_ascii=ensure_ascii)

    for example in data.examples:
        for token, tag, ana in zip(example.tgt_tokens, example.areta_tags, example.morph_feats):
            if tag in ('UC', 'UNK'):
                continue

            if tag == 'SPLIT':
                words = token.split()
                if len(words) != 2:
                    continue
                key = (words[0], serialize([ana[0]]))
            else:
                key = (token, serialize(ana))

            model[key][tag] += 1

    # Turn raw counts into conditional probabilities. The per-key total is
    # taken before any value is divided.
    for tag_counts in model.values():
        total = sum(tag_counts.values())
        for tag in tag_counts:
            tag_counts[tag] /= total

    return model


In [4]:
# Training data (QALB-14 "mix_train"): ARETA-tagged alignments plus the
# matching morphological-feature file.
qalb14_corruption_dir = '/scratch/ba63/gec/data/alignment/modeling_areta_tags_check/qalb14/corruption_data/'
data = Dataset(
    raw_data_path=qalb14_corruption_dir + 'mix_train.areta.txt',
    morph_feats_path=qalb14_corruption_dir + 'mix_train_morph.json',
)

In [5]:
# Edit-level corruption model built from the training data. Later cells use its
# `.model` keys ((error tag, morph_feat) pairs), `.counts` and `.examples`.
edit_corrupt_model = CorruptModel.build(data)

In [6]:
# Word-level corruption model. introduce_errors indexes it with
# (tag, morph_feat_json, word) and, when the entry is non-empty, takes the
# corruption with the highest stored value (presumably a count - TODO confirm).
word_corrupt_model = WordCorruptModel.build(data)

In [7]:
# MLE table P(error_tag | word, morph_feat), passed to introduce_errors as
# `error_mle_model` to score candidate tags. The morph-only variant is kept
# below but disabled.
error_mle_model_full = build_error_mle_model_full(data)
# error_mle_model = build_error_mle_model(edit_corrupt_model)

In [8]:
# Load the neural error-tag classifier (architecture config + weights) and the
# vectorizer that inference_single uses to encode its inputs.
with open('error_tagger/model_w_morph.config.json', encoding='utf-8') as f:
    model_config = json.load(f)
error_nn_model = Model(**model_config)
# map_location='cpu' lets a checkpoint saved from a GPU load on CPU-only hosts.
# load_state_dict then copies the values into the freshly built model.
error_nn_model.load_state_dict(torch.load('error_tagger/model_w_morph.pt', map_location='cpu'))
# The model is only used for inference below. Switch to eval mode so that
# train-only behaviour (e.g. dropout, if Model has any) is disabled.
error_nn_model.eval()
vectorizer = ErrorTagDataset.load_vectorizer('error_tagger/vectorizer.txt')

In [9]:
# Summary statistics for the edit- and word-level corruption models.
keys = list(edit_corrupt_model.model.keys())
areta_tags = {k[0] for k in keys}

# For each (error tag, morph_feat) key, accumulate the number of rules and the
# summed rule counts.
total_rules = 0
counts_by_key = {}
for k in keys:
    rules_for_key = edit_corrupt_model[k]
    total_rules += len(rules_for_key)
    counts_by_key[k] = sum(rules_for_key[r] for r in rules_for_key)
# counts_by_key ends up as (key, total count) pairs, most frequent first.
counts_by_key = sorted(counts_by_key.items(), key=lambda kv: kv[1], reverse=True)

print('Edit Corruption Model Stats',
      '--------------------',
      f'Total number of (error tag, morph_feat) pairs: {len(keys)}',
      f'Total number of rules: {total_rules}',
      '',
      '',
      sep='\n')
print('Word Corruption Model Stats',
      '--------------------',
      f'Total number of (error tag, morph_feat, word) mappings: {len(word_corrupt_model.model.keys())}',
      sep='\n')

Edit Corruption Model Stats
--------------------
Total number of (error tag, morph_feat) pairs: 13595
Total number of rules: 42055


Word Corruption Model Stats
--------------------
Total number of (error tag, morph_feat, word) mappings: 81892


In [None]:
# Dump every rule stored for tags whose name contains `error`, with its count
# and the first recorded example. `set_m` gathers the morph_feats seen with
# those tags.
error = 'INSERT_XM'
set_m = set()
for key in keys:
    if error not in key[0]:
        continue

    print(key)
    set_m.add(key[1])

    for x in edit_corrupt_model[key]:
        print(f'Rule: {x}')
        print(f'Count: {edit_corrupt_model[key][x]}')
        print(f'Example: {edit_corrupt_model.examples[key][x][0]}')
        print()

In [None]:
# error_tags = list(edit_corrupt_model.model.keys())

# error_probs = {error: edit_corrupt_model.counts[error]/total_error_counts for error in error_tags}

In [None]:
# Placeholder inserted for 'insert' corruptions; filled later by the MLM.
_MASK = '[MASK]'
# Suffix marking a token that must be glued to the following token (SPLIT).
_SPLIT_MARKER = '_SPLIT'
# Operations sampled for each selected token, and their default probabilities.
_OPERATIONS = ('replace', 'insert', 'delete')
_DEFAULT_OP_PROBS = (0.85, 0.14, 0.01)


@functools.lru_cache(maxsize=None)
def _load_fill_mask(model_path):
    """Build the top-1 fill-mask pipeline for `model_path` once and reuse it."""
    return pipeline('fill-mask', model=model_path, top_k=1)


def _operation_probs(op_probs):
    """Return sampling probabilities aligned with `_OPERATIONS`.

    `op_probs` may be:
    - None, to use the defaults;
    - a dict keyed by operation name, where missing operations get 0;
    - a 3-element sequence in `_OPERATIONS` order.
    """
    if op_probs is None:
        return list(_DEFAULT_OP_PROBS)
    if isinstance(op_probs, dict):
        return [op_probs.get(op, 0.0) for op in _OPERATIONS]
    return list(op_probs)


def _select_tag(current_token, current_feat, str_feat, compatible_tags,
                error_mle_model, error_nn_model, vectorizer):
    """Choose the corruption tag for one token, or 'OOV' if no scorer proposes one.

    Both the MLE table and the NN tagger's top-5 are restricted to
    `compatible_tags`. The single highest-probability tag wins, and the NN
    wins ties.
    """
    compatible = set(compatible_tags)

    mle = error_mle_model.get((current_token, str_feat))
    mle_tags = []
    if mle is not None:
        mle_tags = sorted(((t, p) for t, p in mle.items() if t in compatible),
                          key=lambda x: x[-1], reverse=True)

    nn = inference_single(error_nn_model, vectorizer, current_token, current_feat)
    nn_tags = [(t, p) for t, p in nn['top5'] if t in compatible]

    # NN candidates come first: `max` keeps the first maximum it sees, so the
    # NN tag is chosen when both scorers give the same probability.
    candidates = nn_tags + mle_tags
    tag = max(candidates, key=lambda x: x[-1])[0] if candidates else 'OOV'

    # Punctuation: when the MLE table supports both punctuation errors, pick one
    # uniformly. `mle_tags` holds (tag, prob) pairs, so compare the tag names.
    if current_token in UNICODE_PUNCT_SYMBOL_CHARSET:
        mle_tag_names = {t for t, _ in mle_tags}
        if {'INSERT_PM', 'REPLACE_PC'} <= mle_tag_names:
            tag = str(np.random.choice(['INSERT_PM', 'REPLACE_PC'], p=[0.5, 0.5]))

    return tag


def _corrupt_token(tag, current_token, str_feat, compatible_rules, word_corruption_model):
    """Apply `tag` to `current_token` and return the corrupted surface form.

    If the word-level model has corruptions for (tag, str_feat, token), the one
    with the highest stored value is used. Otherwise the highest-count
    compatible edit rule is applied. A rule returning None (expected only for
    SPLIT) yields the token with the split marker appended.
    """
    global_corruptions = word_corruption_model[(tag, str_feat, current_token)]
    if len(global_corruptions) != 0:
        corrupted = max(global_corruptions, key=global_corruptions.get)
    else:
        rules = compatible_rules[(tag, str_feat)]
        corrupted = Rule.from_str(max(rules, key=rules.get)).apply(current_token)

    if corrupted is None:
        assert tag == 'SPLIT'
        return current_token + _SPLIT_MARKER
    return corrupted


def _merge_split_markers(sentence):
    """Glue each `_SPLIT`-marked token to the next token; drop a trailing marker."""
    return re.sub(r'_SPLIT(?:\s|$)', '', sentence)


def _fill_masks(sentence, fill_mask):
    """Replace each `[MASK]` token with the fill-mask model's top prediction.

    Subword predictions (containing '##') are dropped, so that insertion is
    skipped.
    """
    filled = fill_mask(sentence)
    # The pipeline returns a list of candidate dicts for a single mask, and a
    # list of such lists when the input contains several masks.
    if isinstance(filled[0], dict):
        predictions = [filled[0]['token_str']]
    else:
        predictions = [candidates[0]['token_str'] for candidates in filled]

    remaining = iter(predictions)
    output = []
    for tok in sentence.split(' '):
        if tok != _MASK:
            output.append(tok)
            continue
        prediction = next(remaining)
        if '##' not in prediction:
            output.append(prediction)

    assert _MASK not in output
    return ' '.join(output)


def introduce_errors(tokens, morph_feats, error_prop, std_dev,
                     edit_corruption_model,
                     word_corruption_model,
                     error_mle_model,
                     error_nn_model,
                     vectorizer,
                     op_probs=None,
                     fill_mask=None,
                     mlm_path='/scratch/ba63/gec/mlm'):
    """Inject synthetic errors into a clean, tokenized sentence.

    The number of corrupted tokens is round(N(error_prop, std_dev) * len(tokens)),
    clipped to [0, len(tokens)]. Each selected token gets one sampled operation:

    - 'replace': pick an error tag compatible with the token and apply it.
      Tokens with no compatible tag, or whose chosen tag is 'OOV', are kept
      unchanged.
    - 'insert': keep the token and add a [MASK] after it. The mask is then
      filled by a fill-mask model; subword predictions are dropped.
    - 'delete': drop the token.

    Args:
        tokens: target tokens of the sentence.
        morph_feats: one JSON-serializable morphological analysis per token.
        error_prop, std_dev: mean / standard deviation of the sampled error rate.
        edit_corruption_model: rule model passed to `get_compatible_tags`.
        word_corruption_model: indexed by (tag, morph_json, word).
        error_mle_model: dict-like (word, morph_json) -> {tag: prob}.
        error_nn_model, vectorizer: passed to `inference_single`.
        op_probs: None (defaults 0.85 / 0.14 / 0.01 for replace / insert /
            delete), a dict {operation: prob}, or a sequence in that order.
        fill_mask: callable fill-mask pipeline (top-1). When None, one is built
            from `mlm_path` the first time a mask needs filling, and cached.
        mlm_path: model path used to build the default fill-mask pipeline.

    Returns:
        (corrupted_sentence, corruption_tags). The tags list, in token order,
        contains the chosen tag of every replace ('OOV' when no scorer proposed
        a compatible tag), plus 'INSERT' and 'DELETE' for those operations.
        Replace attempts with no compatible tag record nothing.
    """
    num_errors = int(np.round(np.random.normal(error_prop, std_dev) * len(tokens)))
    num_errors = min(max(0, num_errors), len(tokens))  # num_errors in [0, len(tokens)]

    corruptions_tags = []
    if num_errors == 0:
        return ' '.join(tokens), corruptions_tags

    # A set gives O(1) membership tests in the loop below.
    token_ids_to_modify = set(np.random.choice(len(tokens), num_errors, replace=False).tolist())
    probs = _operation_probs(op_probs)

    new_sentence = []
    for token_id, current_token in enumerate(tokens):
        if token_id not in token_ids_to_modify:
            new_sentence.append(current_token)
            continue

        current_feat = morph_feats[token_id]
        str_feat = json.dumps(current_feat, ensure_ascii=False)
        operation = np.random.choice(_OPERATIONS, p=probs)

        if operation == 'replace':
            compatible_tags, compatible_rules = get_compatible_tags(
                current_token, str_feat, edit_corruption_model)
            if not compatible_tags:
                # No rule can corrupt this token: keep it rather than dropping it.
                new_sentence.append(current_token)
                continue

            tag = _select_tag(current_token, current_feat, str_feat, compatible_tags,
                              error_mle_model, error_nn_model, vectorizer)
            corruptions_tags.append(tag)
            if tag == 'OOV':
                new_sentence.append(current_token)
            else:
                new_sentence.append(_corrupt_token(tag, current_token, str_feat,
                                                   compatible_rules, word_corruption_model))

        elif operation == 'delete':
            corruptions_tags.append('DELETE')

        elif operation == 'insert':
            new_sentence.extend([current_token, _MASK])
            corruptions_tags.append('INSERT')

    sentence = _merge_split_markers(' '.join(new_sentence))

    if _MASK in sentence:
        if fill_mask is None:
            fill_mask = _load_fill_mask(mlm_path)
        sentence = _fill_masks(sentence, fill_mask)

    return re.sub(r' +', ' ', sentence), corruptions_tags



        
def get_compatible_tags(word, morph_feats, model):
    """Find the error tags that have at least one rule applicable to `word`.

    Only keys of `model.model` whose morph_feat equals `morph_feats` and whose
    tag does not contain 'DELETE' are considered. Each rule of such a key is
    parsed with `Rule.from_str` and tested with `is_applicable(word)`.

    Returns:
        (tags, rules): `tags` lists the compatible error tags. `rules` is a
        nested defaultdict mapping (tag, morph_feat) key -> rule string -> count,
        holding only the applicable rules.
    """
    tags_found = set()
    rules_by_key = defaultdict(lambda: defaultdict(lambda: 0))

    for key in model.model.keys():
        if key[1] != morph_feats or 'DELETE' in key[0]:
            continue
        for rule_str in model[key]:
            if not Rule.from_str(rule_str).is_applicable(word):
                continue
            tags_found.add(key[0])
            rules_by_key[key][rule_str] = model[key][rule_str]

    return list(tags_found), rules_by_key

In [None]:
# QALB-14 tune split, used below as the source of clean sentences to corrupt.
qalb14_corruption_dir = '/scratch/ba63/gec/data/alignment/modeling_areta_tags_check/qalb14/corruption_data/'
dev_dataset = Dataset(
    raw_data_path=qalb14_corruption_dir + 'qalb14_tune.areta.txt',
    morph_feats_path=qalb14_corruption_dir + 'qalb14_tune_morph.json',
)

In [None]:
# Evaluation subset of the dev set: indices of sentences with between
# MIN_TEST_LEN and MAX_TEST_LEN target tokens (inclusive), where every target
# token is exactly one whitespace-separated word.
# TODO: multi-word target tokens are not handled yet, so such sentences are skipped.
MIN_TEST_LEN, MAX_TEST_LEN = 40, 60
testing_sents_idx = [
    i
    for i, example in enumerate(dev_dataset.examples)
    if MIN_TEST_LEN <= len(example.tgt_tokens) <= MAX_TEST_LEN
    and all(len(tok.split()) == 1 for tok in example.tgt_tokens)
]

In [None]:
# Corruption hyper-parameters. introduce_errors draws the sentence's error rate
# from N(error_prop, std) and multiplies it by the sentence length.
error_prop = 0.20
std = 0.2

# introduce_errors samples from NumPy's global RNG. Seed it so the
# corruptions are reproducible across runs.
SEED = 42
np.random.seed(SEED)

In [None]:
# Corrupt the first 20 selected dev sentences. Each clean/corrupted pair is
# echoed between banner lines; the pairs and all corruption tags are collected.
total_tags = []
corrupted_sentences = []
banner = '**************************************************'

for idx in testing_sents_idx[:20]:
    tokens, morph_feats = dev_dataset.examples[idx].tgt_tokens, dev_dataset.examples[idx].morph_feats

    corrupted, corruption_tags = introduce_errors(
        tokens=tokens,
        morph_feats=morph_feats,
        error_prop=error_prop,
        std_dev=std,
        edit_corruption_model=edit_corrupt_model,
        word_corruption_model=word_corrupt_model,
        error_mle_model=error_mle_model_full,
        error_nn_model=error_nn_model,
        vectorizer=vectorizer,
    )

    clean = ' '.join(tokens)
    corrupted_sentences.append((clean, corrupted))
    print(banner, banner, clean, corrupted, banner, banner, sep='\n')
    total_tags.extend(corruption_tags)

In [None]:
# tags_cnt = Counter(total_tags)

In [None]:
# tags_cnt

In [None]:
# Show only the sentences that the corruption step actually changed, each
# pair followed by a blank line.
for correct, corrupted in corrupted_sentences:
    if correct == corrupted:
        continue
    print(correct)
    print(corrupted)
    print()