In [703]:
import copy
import json
import math
import operator
import os
import re
from collections import Counter, defaultdict
from itertools import combinations

import kenlm

In [704]:
# The character-level language model was built with KenLM:
# bin/lmplz -o 5 < train.char.txt > char.arpa
# bin/build_binary char.arpa delme.bin
#
# The model location can be overridden through the KENLM_MODEL_PATH environment
# variable so the notebook is not tied to one machine; the original path is kept
# as the default for backward compatibility.
KENLM_MODEL_PATH = os.environ.get(
    'KENLM_MODEL_PATH',
    '/Users/ba63/Desktop/repos/gender-rewriting/kenlm/build/delme.bin')
model = kenlm.LanguageModel(KENLM_MODEL_PATH)

In [705]:
def score(words):
    """Score each word with the module-level language model.

    Each word is turned into a space-separated character sequence before it is
    handed to ``model.score``.

    Args:
        words: iterable of strings.

    Returns:
        A list of ``(word, model_score)`` tuples in input order.
    """
    results = []
    for word in words:
        spaced_chars = " ".join(list(word))
        results.append((word, model.score(spaced_chars)))
    return results

In [706]:
def read_data(path):
    """Read a UTF-8 text file and return its lines with surrounding whitespace removed.

    Args:
        path: location of the file to read.

    Returns:
        A list with one stripped string per line of the file (blank lines
        become empty strings).
    """
    stripped_lines = []
    with open(path, encoding="utf8") as handle:
        for raw_line in handle:
            stripped_lines.append(raw_line.strip())
    return stripped_lines

In [707]:
class TokenizedSent:
    """A sentence stored as an id plus parallel lists of tokens and tags."""

    def __init__(self, id, tokens, tags):
        self.id, self.tokens, self.tags = id, tokens, tags

    def to_dict(self):
        """Return a deep copy of the instance attributes as a dict."""
        return copy.deepcopy(vars(self))

    def to_json_str(self):
        """Serialize the attributes as indented JSON, keeping non-ASCII text as-is."""
        as_dict = self.to_dict()
        return json.dumps(as_dict, indent=2, ensure_ascii=False)

    def __repr__(self):
        # The JSON form doubles as the printable representation.
        return str(self.to_json_str())


In [708]:
def collate_stentences(raw_data):
    """Group token lines into sentences.

    Each non-blank line is split on whitespace and read as
    ``<sentence id> <token> <tag>``; a blank line closes the current sentence.

    Args:
        raw_data: iterable of lines (strings).

    Returns:
        A list of ``TokenizedSent`` objects in input order.

    Raises:
        IndexError: if a non-blank line has fewer than three fields.
    """
    tokenized_sents = []
    sent_id, tokens, tags = None, [], []
    for line in raw_data:
        fields = line.split()
        if fields:
            sent_id = fields[0]
            tokens.append(fields[1])
            tags.append(fields[2])
        else:
            # A blank line marks the end of the sentence collected so far.
            tokenized_sents.append(TokenizedSent(sent_id, tokens, tags))
            sent_id, tokens, tags = None, [], []

    # Flush the final sentence when the input does not end with a blank line;
    # previously those tokens were silently discarded.
    if tokens:
        tokenized_sents.append(TokenizedSent(sent_id, tokens, tags))

    return tokenized_sents

In [709]:
def edit_distance(src_word, trg_word):
    """Build the edit-distance table between two words.

    The shorter word (the source word on ties) indexes the rows and the other
    word indexes the columns. Cell ``[i][j]`` holds the number of unit-cost
    insertions, deletions and substitutions needed to turn the row word's
    suffix starting at ``i`` into the column word's suffix starting at ``j``,
    so ``[0][0]`` is the distance between the full words.

    Args:
        src_word: source string.
        trg_word: target string.

    Returns:
        A ``(len(shorter) + 1) x (len(longer) + 1)`` list of lists of ints.
    """
    if len(src_word) <= len(trg_word):
        row_word = src_word
    else:
        row_word = trg_word
    col_word = src_word if row_word == trg_word else trg_word

    n_rows, n_cols = len(row_word), len(col_word)
    table = [[0] * (n_cols + 1) for _ in range(n_rows + 1)]

    # Boundary cells: one word is exhausted, so only the rest of the other remains.
    for r in range(n_rows):
        table[r][n_cols] = n_rows - r
    for c in range(n_cols):
        table[n_rows][c] = n_cols - c

    # Fill from the bottom-right corner towards the origin.
    for r in reversed(range(n_rows)):
        for c in reversed(range(n_cols)):
            sub_cost = 0 if row_word[r] == col_word[c] else 1
            table[r][c] = min(
                table[r + 1][c] + 1,
                table[r][c + 1] + 1,
                table[r + 1][c + 1] + sub_cost,
            )

    return table


In [710]:
def backtrack_dp(dp, src_word, trg_word):
    """Recover a character alignment from an edit-distance table.

    ``dp`` must be laid out as ``edit_distance`` builds it: rows follow the
    shorter word (the source word on ties), columns follow the other word.
    Walking from ``dp[0][0]``, each step emits one aligned column:

    * a column-word character with no row-word partner puts ``'+'`` in the
      row word's alignment;
    * a row-word character with no column-word partner puts ``'-'`` in the
      column word's alignment;
    * copies and substitutions emit both characters.

    Args:
        dp: edit-distance table for the two words.
        src_word: source string.
        trg_word: target string.

    Returns:
        ``(src_align, trg_align)``: two equal-length strings, in source/target
        order regardless of which word indexed the rows.

    Raises:
        ValueError: if ``dp`` is not a consistent table for the two words.
    """
    src_is_rows = len(src_word) <= len(trg_word)
    w1, w2 = (src_word, trg_word) if src_is_rows else (trg_word, src_word)

    i, j = 0, 0
    w1_align, w2_align = [], []

    while i < len(w1) and j < len(w2):
        if dp[i][j] == dp[i][j + 1] + 1:
            # w2[j] has no counterpart in w1
            w1_align.append('+')
            w2_align.append(w2[j])
            j += 1
        elif dp[i][j] == dp[i + 1][j] + 1:
            # w1[i] has no counterpart in w2
            w1_align.append(w1[i])
            w2_align.append('-')
            i += 1
        elif dp[i][j] in (dp[i + 1][j + 1], dp[i + 1][j + 1] + 1):
            # copy (equal cost) or substitution (cost + 1): both emit a pair
            w1_align.append(w1[i])
            w2_align.append(w2[j])
            i += 1
            j += 1
        else:
            # Without this guard an inconsistent table would loop forever.
            raise ValueError('dp table is inconsistent with the given words')

    # At most one of the two words still has unconsumed characters.
    for ch in w2[j:]:
        w1_align.append('+')
        w2_align.append(ch)

    # The leftover characters of w1 used to be dropped, truncating the
    # alignment (e.g. 'ba' vs 'ab' produced '+b' / 'ab').
    for ch in w1[i:]:
        w1_align.append(ch)
        w2_align.append('-')

    assert len(w1_align) == len(w2_align)

    w1_align, w2_align = ''.join(w1_align), ''.join(w2_align)
    if src_is_rows:
        return w1_align, w2_align
    return w2_align, w1_align

In [711]:
def get_rule(src_align, trg_align):
    """Turn an aligned word pair into a (source pattern, target pattern) rule.

    Positions where both alignments carry the same character are replaced by
    the placeholder ``'X'`` in both patterns; differing positions keep their
    own characters.

    Args:
        src_align: aligned source string.
        trg_align: aligned target string of the same length.

    Returns:
        ``(src_pattern, trg_pattern)`` as a tuple of strings.
    """
    assert len(src_align) == len(trg_align)
    src_parts, trg_parts = [], []
    for src_ch, trg_ch in zip(src_align, trg_align):
        unchanged = src_ch == trg_ch
        src_parts.append('X' if unchanged else src_ch)
        trg_parts.append('X' if unchanged else trg_ch)
    return (''.join(src_parts), ''.join(trg_parts))

In [712]:
# Raw token lines for the source corpus and for the four target corpora
# (file suffixes MM / FM / MF / FF).
train_data_src = read_data('new_tokens_data/train.arin.tokens')
train_data_trg_mm = read_data('new_tokens_data/train.ar.MM.tokens')
train_data_trg_fm = read_data('new_tokens_data/train.ar.FM.tokens')
train_data_trg_mf = read_data('new_tokens_data/train.ar.MF.tokens')
train_data_trg_ff = read_data('new_tokens_data/train.ar.FF.tokens')


# Group the token lines of each corpus into sentences.
tokenized_sents_src = collate_stentences(train_data_src)
tokenized_sents_trg_mm = collate_stentences(train_data_trg_mm)
tokenized_sents_trg_mf = collate_stentences(train_data_trg_mf)
tokenized_sents_trg_fm = collate_stentences(train_data_trg_fm)
tokenized_sents_trg_ff = collate_stentences(train_data_trg_ff)

# Every target corpus must contain as many sentences as the source corpus,
# because generate_rules() pairs sentences up by position.
assert len(tokenized_sents_src) == len(tokenized_sents_trg_mm)
assert len(tokenized_sents_src) == len(tokenized_sents_trg_fm)
assert len(tokenized_sents_src) == len(tokenized_sents_trg_mf)
assert len(tokenized_sents_src) == len(tokenized_sents_trg_ff)

In [713]:
class Rule:
    """Container for a source/target word pair, its edit distance and its rule."""

    def __init__(self, src_word, trg_word,
                 edit_dist, rule):
        self.src_word, self.trg_word = src_word, trg_word
        self.edit_dist, self.rule = edit_dist, rule

    def to_dict(self):
        """Return a deep copy of the instance attributes as a dict."""
        return copy.deepcopy(vars(self))

    def to_json_str(self):
        """Serialize the attributes as indented JSON, keeping non-ASCII text as-is."""
        as_dict = self.to_dict()
        return json.dumps(as_dict, indent=2, ensure_ascii=False)

    def __repr__(self):
        # The JSON form doubles as the printable representation.
        return str(self.to_json_str())

In [716]:
def generate_rules(src_sents, trg_sents):
    """Learn rewrite-rule log-probabilities from position-aligned sentences.

    Sentences and their tokens are paired by position. For every token pair
    that differs, the two words are aligned and reduced to a
    ``(src_pattern, trg_pattern)`` rule, and two tables of base-10 log
    relative frequencies are accumulated.

    Args:
        src_sents: source sentences exposing ``tokens`` and ``tags``.
        trg_sents: target sentences exposing ``tokens`` and ``tags``.

    Returns:
        ``(rules_probs, src_tag_probs)``, both nested defaultdicts whose
        missing entries read as 0:

        * ``rules_probs[(trg_tag, src_pattern)][trg_pattern]`` is
          log10 P(trg_pattern | trg_tag, src_pattern);
        * ``src_tag_probs[src_pattern][src_tag]`` is
          log10 of count(src_pattern, src_tag) / count(src_tag).

    Raises:
        AssertionError: if two differing tokens carry the same tag.
    """
    def _nested_counter():
        return defaultdict(lambda: defaultdict(int))

    pair_counts = _nested_counter()    # (trg_tag, src_pattern) -> trg_pattern -> count
    context_totals = {}                # (trg_tag, src_pattern) -> count
    rules_probs = _nested_counter()    # P(trg_pattern | trg_tag, src_pattern)
    src_tag_probs = _nested_counter()  # P(src_pattern | src_tag)
    src_tag_totals = {}                # src_tag -> count

    for src_sent, trg_sent in zip(src_sents, trg_sents):
        aligned_tokens = zip(src_sent.tokens, trg_sent.tokens,
                             src_sent.tags, trg_sent.tags)
        for src_token, trg_token, src_tag, trg_tag in aligned_tokens:
            if src_token == trg_token:
                continue
            assert src_tag != trg_tag

            dp_table = edit_distance(src_word=src_token, trg_word=trg_token)
            src_align, trg_align = backtrack_dp(dp_table, src_word=src_token,
                                                trg_word=trg_token)
            src_pattern, trg_pattern = get_rule(src_align, trg_align)

            context = (trg_tag, src_pattern)
            pair_counts[context][trg_pattern] += 1
            context_totals[context] = context_totals.get(context, 0) + 1

            src_tag_probs[src_pattern][src_tag] += 1
            src_tag_totals[src_tag] = src_tag_totals.get(src_tag, 0) + 1

    # Convert the raw counts into log10 relative frequencies.
    for context, trg_counts in pair_counts.items():
        total = float(context_totals[context])
        for trg_pattern, count in trg_counts.items():
            rules_probs[context][trg_pattern] = math.log10(count / total)

    # Values are replaced in place; the key sets do not change while iterating.
    for tag_counts in src_tag_probs.values():
        for src_tag in tag_counts:
            tag_counts[src_tag] = math.log10(tag_counts[src_tag] /
                                             src_tag_totals[src_tag])

    return rules_probs, src_tag_probs


In [717]:
# Learn one rule table per target corpus (MM, MF, FM, FF), always pairing the
# same source sentences with the corresponding target sentences.
rules_probs_mm, src_pattrn_probs_mm = generate_rules(src_sents=tokenized_sents_src,
                                                     trg_sents=tokenized_sents_trg_mm)

rules_probs_mf, src_pattrn_probs_mf = generate_rules(src_sents=tokenized_sents_src,
                                                     trg_sents=tokenized_sents_trg_mf)

rules_probs_fm, src_pattrn_probs_fm = generate_rules(src_sents=tokenized_sents_src,
                                                     trg_sents=tokenized_sents_trg_fm)

rules_probs_ff, src_pattrn_probs_ff = generate_rules(src_sents=tokenized_sents_src,
                                                     trg_sents=tokenized_sents_trg_ff)


In [718]:
# Sanity check: show the first (target tag, source pattern) entry of the MM
# table together with the source-tag log-probabilities of its source pattern.
print('***** Target MM Rules map *****')
for key in rules_probs_mm:
    trg_gender, src_pattern = key
    print(f'{key} : {dict(rules_probs_mm[key])}')
    print(src_pattrn_probs_mm[src_pattern])
    # only the first entry is inspected
    break

***** Target MM Rules map *****
('1M+B', 'XXXXXة') : {'XXXXX+': -0.13199016723600726, 'XXXXXا': -0.5840007013726386, 'XXXXXأ': -2.8344207036815328}
defaultdict(<function generate_rules.<locals>.<lambda>.<locals>.<lambda> at 0x7fa51b4baaf0>, {'1F+B': -0.7084047232776474, '2F+B': -1.5332922578057235})


In [719]:
# Number of distinct (target tag, source pattern, target pattern) rules per table.
rules_count_mm = sum(len(targets) for targets in rules_probs_mm.values())
rules_count_mf = sum(len(targets) for targets in rules_probs_mf.values())
rules_count_fm = sum(len(targets) for targets in rules_probs_fm.values())
rules_count_ff = sum(len(targets) for targets in rules_probs_ff.values())

In [720]:
# Report the rule count of each table. The FM and MF counts were previously
# printed under each other's labels; each label now shows its own table's count.
print(f'There are {rules_count_mm} in the input to target MM RBR model')
print(f'There are {rules_count_fm} in the input to target FM RBR model')
print(f'There are {rules_count_mf} in the input to target MF RBR model')
print(f'There are {rules_count_ff} in the input to target FF RBR model')

There are 889 in the input to target MM RBR model
There are 890 in the input to target FM RBR model
There are 890 in the input to target MF RBR model
There are 892 in the input to target FF RBR model


In [721]:
# Worked example: align one word pair. Note that str1 is passed as the source
# word and str2 as the target word.
str2 = 'أسود'
str1 = 'سوداء'
dp = edit_distance(src_word=str1, trg_word=str2)

src_align, trg_align = backtrack_dp(dp, str1, str2)
print('')
print(src_align)
print(trg_align)


-سوداء
أسود++


In [722]:
# Reduce the alignment from the previous cell to a rule: identical positions
# become 'X', differing positions keep their characters.
src_pattern, trg_pattern= get_rule(src_align, trg_align)
print(src_pattern)
print(trg_pattern)

-XXXاء
أXXX++


In [726]:
def match_rule(word, src_gender, trg_gender, rules_probs, src_pattrn_probs):
    """Find the learned rules that apply to ``word`` for a target gender.

    A rule applies when its target gender equals ``trg_gender`` and its source
    pattern matches the whole word. The pattern is turned into a regex by
    dropping the ``'+'``/``'-'`` alignment markers, mapping every ``'X'`` to a
    one-character capture group and matching all other characters literally.

    Args:
        word: source token to rewrite.
        src_gender: tag of the source token.
        trg_gender: tag the rewritten token should carry.
        rules_probs: mapping ``(trg_gender, src_pattern) -> {trg_pattern: log_prob}``.
        src_pattrn_probs: mapping ``src_pattern -> {src_gender: log_prob}``.

    Returns:
        A list of dicts (keys ``src_word``, ``src_pattern``, ``trg_gender``,
        ``src_regex``, ``targets``, ``src_pattern_prob``) sorted by
        ``src_pattern_prob`` in descending order. A source pattern that has no
        entry for ``src_gender`` gets ``-inf``, so it sorts last.
    """
    matched_rules = []
    for (target_gender, src_pattern), targets in rules_probs.items():
        # Cheap gender check first; the regex is only built for candidates.
        if target_gender != trg_gender:
            continue

        # Literal characters are escaped so regex metacharacters occurring in
        # the data (e.g. '.') only match themselves.
        pattern = ''.join('(.)' if ch == 'X' else re.escape(ch)
                          for ch in src_pattern if ch not in '+-')

        if re.fullmatch(pattern, word) is None:
            continue

        # Use .get() instead of indexing: on the defaultdicts produced by
        # generate_rules, indexing an unseen gender returned 0 (log-prob of a
        # certain event, i.e. the best possible rank) and inserted the key.
        src_pattern_prob = src_pattrn_probs.get(src_pattern, {}).get(
            src_gender, float('-inf'))

        matched_rules.append({'src_word': word,
                              'src_pattern': src_pattern,
                              'trg_gender': target_gender,
                              'src_regex': pattern,
                              'targets': dict(targets),
                              'src_pattern_prob': src_pattern_prob})

    return sorted(matched_rules, key=lambda x: x['src_pattern_prob'], reverse=True)

In [727]:
def generate_token(target_rule, src_word, src_pattern):
    """Generate a token from a source word, a source regex and a target pattern.

    The target pattern is converted into a ``re.sub`` replacement template:
    ``'+'`` and ``'-'`` alignment markers are dropped, the k-th ``'X'`` becomes
    a reference to the k-th capture group of ``src_pattern`` and every other
    character is emitted literally.

    Args:
        target_rule: target pattern, e.g. ``'XXXX+'``.
        src_word: the word to rewrite.
        src_pattern: regex with one capture group per ``'X'`` of the rule.

    Returns:
        ``src_word`` with every match of ``src_pattern`` replaced by the
        expanded target pattern.
    """
    template_parts = []
    group_index = 0
    for ch in target_rule:
        if ch in '+-':
            continue
        if ch == 'X':
            group_index += 1
            # \g<n> stays unambiguous when the next literal character is a
            # digit, whereas '\1' followed by '1' is read as group 11.
            template_parts.append(f'\\g<{group_index}>')
        elif ch == '\\':
            # a literal backslash must be escaped inside a replacement template
            template_parts.append('\\\\')
        else:
            template_parts.append(ch)

    # generate target words
    return re.sub(src_pattern, ''.join(template_parts), src_word)

In [728]:
def _score_rule_targets(matched_rule, top_target_only):
    """Apply the target patterns of one matched rule and score the results."""
    targets = list(matched_rule['targets'].items())
    if top_target_only:
        # keep only the most probable target pattern of this rule
        targets = [max(targets, key=lambda x: x[1])]

    scored = dict()
    for tgt_rule, tgt_rule_prob in targets:
        generated_token = generate_token(tgt_rule, matched_rule['src_word'],
                                         matched_rule['src_regex'])

        # the language model is queried on space-separated characters
        log_prob_kenlm = model.score(" ".join(generated_token))

        scored[generated_token] = log_prob_kenlm + tgt_rule_prob + matched_rule['src_pattern_prob']
    return scored


def generate_using_rule(matched_rules, pick_top_overall_rule=False, pick_top_target_rule=False):
    """
    Returns the generated token(s) and their scores given the matched rules.

    Each candidate is scored as
    ``model.score(candidate characters) + target-pattern log-prob + source-pattern log-prob``.

    Args:
        matched_rules: list of rule dicts as returned by ``match_rule``.
        pick_top_overall_rule: when True, only the matched rule with the
            highest ``src_pattern_prob`` is applied; otherwise all of them are.
        pick_top_target_rule: when True, only the most probable target pattern
            of each applied rule is used; otherwise all of its target patterns
            are, because one (src_pattern, trg_gender) can have several.

    Returns:
        A dict mapping each generated token to its score; empty when
        ``matched_rules`` is empty. When several rules produce the same token,
        the score computed last is kept.
        NOTE(review): keeping the best score may be preferable — confirm intent.

    Note: no frequency threshold is applied here; rules seen only once in
          training are used like any other.
    """
    # max() below would raise on an empty list.
    if not matched_rules:
        return dict()

    if pick_top_overall_rule:
        rules_to_apply = [max(matched_rules, key=lambda x: x['src_pattern_prob'])]
    else:
        rules_to_apply = matched_rules

    # Always a dict: the top-overall / all-targets branch used to rebind this
    # to a list and then fail with TypeError on the first assignment.
    generated_tokens = dict()
    for matched_rule in rules_to_apply:
        generated_tokens.update(_score_rule_targets(matched_rule, pick_top_target_rule))

    return generated_tokens

P(tgt_pattern | src_pattern, tgt_gender)
P(src_pattern)

log P(src_pattern) + log P(tgt_pattern | src_pattern, tgt_gender) + log p_lm(tgt_word)

P(src_pattern | src_gender)?

In [734]:
# Look up the rules of the MM table that fully match this word, rewriting from
# source tag '2F+B' to target tag '2M+B'.
matched_rules = match_rule(word='صديقة',
                           src_gender='2F+B',
                           trg_gender='2M+B',
                           rules_probs=rules_probs_mm,
                           src_pattrn_probs=src_pattrn_probs_mm)
matched_rules

[{'src_word': 'صديقة',
  'src_pattern': 'XXXXة',
  'trg_gender': '2M+B',
  'src_regex': '(.)(.)(.)(.)ة',
  'targets': {'XXXX+': -0.10905672430265556,
   'XXXXا': -0.6618986929604364,
   'XXXXه': -2.5483894181329183,
   'XXXXي': -2.8494194137968996},
  'src_pattern_prob': -1.363300740620943},
 {'src_word': 'صديقة',
  'src_pattern': 'XXيXX',
  'trg_gender': '2M+B',
  'src_regex': '(.)(.)ي(.)(.)',
  'targets': {'XX+XX': -0.01639041618816937, 'XXاXX': -1.4313637641589874},
  'src_pattern_prob': -1.9362583502445982},
 {'src_word': 'صديقة',
  'src_pattern': 'XXXX+X',
  'trg_gender': '2M+B',
  'src_regex': '(.)(.)(.)(.)(.)',
  'targets': {'XXXXوX': -0.025235823416073402,
   'XXXXتX': -1.4913616938342726,
   'XXXXيX': -1.6163004304425725},
  'src_pattern_prob': -2.1192984692556074},
 {'src_word': 'صديقة',
  'src_pattern': 'XXXXX+',
  'trg_gender': '2M+B',
  'src_regex': '(.)(.)(.)(.)(.)',
  'targets': {'XXXXXة': -0.36797678529459443,
   'XXXXXم': -0.5440680443502757,
   'XXXXXي': -0.5440680443

In [735]:
# Apply every target pattern of every matched rule. The result defaults to an
# empty dict so the ranking below never reads an undefined or stale
# `generated_tokens` when no rule matched.
generated_tokens = (
    generate_using_rule(matched_rules, pick_top_overall_rule=False,
                        pick_top_target_rule=False)
    if matched_rules else {}
)

# candidates ranked from best to worst score
sorted(generated_tokens.items(), key=lambda x: x[1], reverse=True)

[('صديق', -6.961509896411636),
 ('صدقة', -8.274114735518461),
 ('صديقه', -9.37475912085103),
 ('صدق', -9.378628014220087),
 ('صديقا', -9.412446995989827),
 ('صديقي', -9.524039028776974),
 ('صديقوة', -10.249052229378224),
 ('صديقية', -11.195691444253844),
 ('صديقتة', -11.202252474949011),
 ('صداقة', -11.468139387352805),
 ('صديوقة', -14.85277304825775),
 ('صديقنوة', -16.066170929564326),
 ('صديقةة', -16.707800247049253),
 ('صدأيقة', -16.917578143983352),
 ('صديقةي', -17.155691547303665),
 ('صدييوقة', -17.275850696473587),
 ('صدييقوة', -17.372599838866083),
 ('صديقةم', -17.556175632386672),
 ('صديقةين', -18.492648361815302),
 ('صديقةوا', -18.614831207884638)]

In [736]:
# Apply only the most probable target pattern of every matched rule. The result
# defaults to an empty dict so the ranking below never reads an undefined or
# stale `generated_tokens` when no rule matched.
generated_tokens = (
    generate_using_rule(matched_rules, pick_top_overall_rule=False,
                        pick_top_target_rule=True)
    if matched_rules else {}
)

# candidates ranked from best to worst score
sorted(generated_tokens.items(), key=lambda x: x[1], reverse=True)

[('صديق', -6.961509896411636),
 ('صدقة', -8.274114735518461),
 ('صدق', -9.378628014220087),
 ('صديقوة', -10.249052229378224),
 ('صديوقة', -14.85277304825775),
 ('صديقنوة', -16.066170929564326),
 ('صديقةة', -16.707800247049253),
 ('صدأيقة', -16.917578143983352),
 ('صدييوقة', -17.275850696473587),
 ('صدييقوة', -17.372599838866083),
 ('صديقةوا', -18.614831207884638)]

#### Getting stats on the frequency of the rules
#### How many rules appeared only once, twice, 3 times, etc.?

In [38]:
def get_stats(rules):
    """Print how many rules share each stored value (e.g. each frequency).

    Args:
        rules: mapping ``key -> {rule: count}``.

    Prints one line per distinct count, ordered by the number of rules that
    have that count (largest group first); returns nothing.
    """
    # how many rules appeared once, twice, 3 times,...?
    rules_per_count = Counter()
    for targets in rules.values():
        for count in targets.values():
            rules_per_count[count] += 1

    # Stable sort: counts with equally many rules keep first-seen order.
    ordered_counts = sorted(rules_per_count, key=rules_per_count.get, reverse=True)
    for count in ordered_counts:
        print(f'{rules_per_count[count]} rule(s) appeared {count} times')

In [39]:
# NOTE(review): `new_rules_mm` is not defined in any cell visible here, and
# get_stats() groups rules by their stored value — presumably raw counts rather
# than the log-probabilities returned by generate_rules(); confirm the input.
get_stats(new_rules_mm)

404 rule(s) appeared 1 times
120 rule(s) appeared 2 times
60 rule(s) appeared 4 times
59 rule(s) appeared 3 times
41 rule(s) appeared 5 times
13 rule(s) appeared 12 times
13 rule(s) appeared 8 times
13 rule(s) appeared 11 times
12 rule(s) appeared 9 times
10 rule(s) appeared 7 times
9 rule(s) appeared 13 times
7 rule(s) appeared 14 times
7 rule(s) appeared 6 times
6 rule(s) appeared 10 times
6 rule(s) appeared 18 times
5 rule(s) appeared 20 times
4 rule(s) appeared 17 times
4 rule(s) appeared 26 times
4 rule(s) appeared 15 times
3 rule(s) appeared 37 times
3 rule(s) appeared 44 times
2 rule(s) appeared 51 times
2 rule(s) appeared 28 times
2 rule(s) appeared 57 times
2 rule(s) appeared 19 times
2 rule(s) appeared 143 times
2 rule(s) appeared 22 times
2 rule(s) appeared 63 times
2 rule(s) appeared 93 times
2 rule(s) appeared 21 times
2 rule(s) appeared 29 times
1 rule(s) appeared 504 times
1 rule(s) appeared 178 times
1 rule(s) appeared 850 times
1 rule(s) appeared 373 times
1 rule(s) ap