In [1]:
import copy
import json
import os
from collections import Counter

In [2]:
# Pairs of clitics that alternate between a sentence and its reinflection.
# Order matters: tag_two() and tag_four() try the pairs in this order and
# stop at the first pair that matches.
clitics = [
    ('كن', 'كم'),
    ('هم', 'هن'),
    ('كي', 'ك'),
    ('ه', 'ها'),
]

# Gender ('M' / 'F') assigned to each clitic when building clitic tags.
clitics_genders = {
    'كم': 'M',
    'كن': 'F',
    'ك': 'M',
    'كي': 'F',
    'ه': 'M',
    'ها': 'F',
    'هم': 'M',
    'هن': 'F',
}

# Every known clitic, in the same order as the keys of clitics_genders.
all_clitics = list(clitics_genders)

In [3]:
def read_data(path, encoding='utf-8'):
    """
    Read a text file and return its lines (trailing newlines are kept).

    Args:
        path: file to read.
        encoding: text encoding. Defaults to UTF-8 because the corpus files
            contain Arabic text and the platform default encoding varies.

    Returns:
        list of str, one entry per line.
    """
    with open(path, mode='r', encoding=encoding) as f:
        return f.readlines()

In [4]:
def write_data(path, data, encoding='utf-8'):
    """
    Write tagged sentences as '<id> <token> <tag>' lines.

    One line is written per token, and a blank line follows every sentence.

    Args:
        path: output file (overwritten if it exists).
        data: iterable of objects with `id`, `tokens` and `tags` attributes
            (e.g. TokenLevalInfo).
        encoding: text encoding. Defaults to UTF-8 so Arabic tokens are written
            the same way on every platform.

    Note:
        Tokens and tags are paired with zip(), so a sentence whose tag list is
        shorter than its token list is written truncated.
    """
    with open(path, mode='w', encoding=encoding) as f:
        for ex in data:
            for token, tag in zip(ex.tokens, ex.tags):
                f.write(ex.id + ' ' + token + ' ' + tag + '\n')
            # blank line separates sentences
            f.write('\n')

In [5]:
class TokenLevalInfo:
    """
    Token-level annotation of a single sentence.

    Attributes:
        tokens: the sentence's tokens.
        tags: one tag per token, aligned with `tokens`.
        sent_label: the sentence-level label.
        id: the sentence id.
    """
    def __init__(self, tokens, tags, sent_label, id):
        # attribute assignment order fixes the key order of to_dict()/JSON
        self.tokens, self.tags = tokens, tags
        self.sent_label, self.id = sent_label, id

    def __repr__(self):
        return self.to_json_str()

    def to_json_str(self):
        """Pretty-printed JSON (non-ASCII characters kept as-is)."""
        return json.dumps(self.to_dict(), ensure_ascii=False, indent=2)

    def to_dict(self):
        """Deep copy of the instance attributes, safe to mutate."""
        return copy.deepcopy(vars(self))

In [6]:
def group_original_and_reinflections(data):
    """
    Group every original sentence with its reinflections.

    This function takes advantage of the way the data was created: an original
    sentence (an id without a '.') is immediately followed by its
    reinflections, whose ids contain a '.' (.1, .2, .3). Each reinflection is
    therefore added to the group of the most recent original sentence.

    Args:
        - data: iterable of (id, english, arabic, label) string tuples, e.g.
          zip() over the lines of the .ids, .en, .arin and .arin.label files.
          Every field is stripped, and '<s>' / '</s>' markers are removed from
          the Arabic sentence.

    Returns:
        A 4-tuple of parallel lists of lists, one sublist per original sentence:
        - all_org_and_re_ids: the original id followed by its reinflection ids
          (when applicable).
        - all_org_and_re_english: the corresponding English sentences.
        - all_org_and_re_arabic: the corresponding Arabic sentences.
        - all_org_and_re_labels: the corresponding labels.

    Raises:
        ValueError: if a reinflection id appears before any original id.
    """
    all_org_and_re_ids = []
    all_org_and_re_english = []
    all_org_and_re_arabic = []
    all_org_and_re_labels = []

    for ex in data:
        sent_id = ex[0].strip()
        english = ex[1].strip()
        # remove the <s> </s> markers first, then strip the whitespace they leave behind
        arabic = ex[2].replace('<s>', '').replace('</s>', '').strip()
        label = ex[3].strip()

        if '.' not in sent_id:
            # an original sentence opens a new group
            all_org_and_re_ids.append([sent_id])
            all_org_and_re_english.append([english])
            all_org_and_re_arabic.append([arabic])
            all_org_and_re_labels.append([label])
        else:
            if not all_org_and_re_ids:
                raise ValueError(f'Reinflection {sent_id!r} appears before any original sentence')
            all_org_and_re_ids[-1].append(sent_id)
            all_org_and_re_english[-1].append(english)
            all_org_and_re_arabic[-1].append(arabic)
            all_org_and_re_labels[-1].append(label)

    return all_org_and_re_ids, all_org_and_re_english, all_org_and_re_arabic, all_org_and_re_labels

In [7]:
def tag_input_corpus(dataset_path, split='train'):
    """
    Read one split of the corpus and tag all of its sentences at the token level.

    Four line-aligned files are read from ``dataset_path + split + suffix``,
    with suffix '.ids', '.en', '.arin' and '.arin.label'. The path is built by
    plain string concatenation, so `dataset_path` normally ends with a separator.

    Args:
        dataset_path: directory/prefix of the split's files.
        split: split name, e.g. 'train', 'dev' or 'test'.

    Returns:
        (tokens, tags, all_data) as returned by tag_tokens(), with one entry
        per tagged sentence.

    Raises:
        ValueError: if the four files do not have the same number of lines.
    """
    prefix = dataset_path + split
    ids = read_data(prefix + '.ids')
    en = read_data(prefix + '.en')
    ar = read_data(prefix + '.arin')
    labels = read_data(prefix + '.arin.label')

    # zip() silently stops at the shortest file, so a missing line would go unnoticed.
    if not len(ids) == len(en) == len(ar) == len(labels):
        raise ValueError(f'Parallel files for split {split!r} differ in length: '
                         f'ids={len(ids)}, en={len(en)}, arin={len(ar)}, labels={len(labels)}')

    # Grouping originals with their reinflections (English is not needed for tagging)
    grouped_ids, _grouped_en, grouped_ar, grouped_labels = group_original_and_reinflections(
        zip(ids, en, ar, labels))

    # Tagging
    tokens, tags, all_data = tag_tokens(grouped_ids, grouped_ar, grouped_labels)

    assert len(tokens) == len(tags) == len(all_data)
    return tokens, tags, all_data

In [8]:
def _count_tags(tag_lists, wanted):
    """Number of entries, summed over every list in `tag_lists`, equal to a tag in `wanted`."""
    return sum(tag_list.count(tag) for tag_list in tag_lists for tag in wanted)


def _tags_with_base(base):
    """Every '<base>+<clitic>' tag for `base`, with clitic in B, 1F, 1M, 2F, 2M."""
    return [f'{base}+{clitic}' for clitic in ('B', '1F', '1M', '2F', '2M')]


def _check_tag_balance(ids, sents_tags):
    """Print a diagnostic for every gender-balance invariant violated by a group's tags."""
    # number of B+B tokens should match across all sentences of the group
    if len({sent_tags.count('B+B') for sent_tags in sents_tags}) > 1:
        print(ids)

    # number of 1F and 1M base tags should match across the group
    if _count_tags(sents_tags, _tags_with_base('1F')) != _count_tags(sents_tags, _tags_with_base('1M')):
        print(f'First Person Problems: {ids}')

    # number of 2F and 2M base tags should match across the group
    if _count_tags(sents_tags, _tags_with_base('2F')) != _count_tags(sents_tags, _tags_with_base('2M')):
        print(f'Second Person Problems: {ids}')

    # masculine and feminine clitics on unchanged bases should match across the group
    if _count_tags(sents_tags, ['B+1M', 'B+2M']) != _count_tags(sents_tags, ['B+1F', 'B+2F']):
        print(f'Clitics Problems: {ids}')


def _tag_group_of_four(sents_tokens, labels, ar_sents):
    """Tag the 4 sentences of an MM/FM/MF/FF group with three tag_four() passes."""
    sent_1_tokens, sent_2_tokens, sent_3_tokens, sent_4_tokens = sents_tokens
    sent_1_label, sent_2_label, sent_3_label, sent_4_label = labels

    # By construction, sentence 1 and sentence 2 differ in 1st person *AND*
    # sentence 1 and sentence 3 differ in 2nd person *AND* sentence 2 and 4
    # differ in 2nd person, so we do multiple passes over the 4 sentences.

    # Pass 1: 1st-person differences between sentences 1 and 2.
    sent_1_tags_temp, sent_2_tags_temp = tag_four(sent_1_tokens=sent_1_tokens,
                                                  sent_1_label=sent_1_label,
                                                  sent_1_tags=['B+B'] * len(sent_1_tokens),
                                                  sent_2_tokens=sent_2_tokens,
                                                  sent_2_label=sent_2_label,
                                                  sent_2_tags=['B+B'] * len(sent_2_tokens),
                                                  person=1)

    # Positions tagged in pass 1; used below to detect if pass 2 overwrote them.
    sent_1_gen_map = {i: gen for i, gen in enumerate(sent_1_tags_temp) if gen != 'B+B'}

    # Pass 2: 2nd-person differences. Sentences 3 and 4 start from the pass-1
    # tags of sentences 1 and 2 respectively (presumably because they share
    # their 1st person by construction).
    sent_1_tags, sent_3_tags = tag_four(sent_1_tokens=sent_1_tokens,
                                        sent_1_label=sent_1_label,
                                        sent_1_tags=list(sent_1_tags_temp),
                                        sent_2_tokens=sent_3_tokens,
                                        sent_2_label=sent_3_label,
                                        sent_2_tags=list(sent_1_tags_temp),
                                        person=2)

    sent_2_tags, sent_4_tags = tag_four(sent_1_tokens=sent_2_tokens,
                                        sent_1_label=sent_2_label,
                                        sent_1_tags=list(sent_2_tags_temp),
                                        sent_2_tokens=sent_4_tokens,
                                        sent_2_label=sent_4_label,
                                        sent_2_tags=list(sent_2_tags_temp),
                                        person=2)

    for i, gen in sent_1_gen_map.items():
        if gen != sent_1_tags[i]:
            print(ar_sents)
            print()

    return [sent_1_tags, sent_2_tags, sent_3_tags, sent_4_tags]


def tag_tokens(grouped_ids, grouped_ar, grouped_labels):
    """
    Assign a '<base>+<clitic>' tag to every token of every grouped sentence.

    Each part of a tag is 'B' or a person+gender code ('1M', '1F', '2M', '2F').
    Groups are handled by size:
      - 1 sentence (BB case): every token is tagged 'B+B'.
      - 2 sentences (MB/FB or BM/BF cases): tagged pairwise by tag_two().
      - 4 sentences (MM, FM, MF, FF): tagged by three tag_four() passes.
    Groups of any other size are reported and skipped. Multi-sentence groups
    are checked for gender balance and violations are printed.

    Args:
        grouped_ids, grouped_ar, grouped_labels: parallel lists of lists, as
            returned by group_original_and_reinflections().

    Returns:
        (tokens, tags, tagged_data): flat, parallel lists with one entry per
        sentence. These are the sentence's whitespace tokens, its per-token
        tags, and a TokenLevalInfo.
    """
    tokens = []
    tags = []
    tagged_data = []

    for ids, ar_sents, labels in zip(grouped_ids, grouped_ar, grouped_labels):
        sents_tokens = [sent.split() for sent in ar_sents]

        if len(ids) == 1:  # BB case, label each token as B+B
            sents_tags = [['B+B'] * len(sents_tokens[0])]
        elif len(ids) == 2:  # MB/FB or BM/BF cases
            sents_tags = list(tag_two(sent_1_tokens=sents_tokens[0],
                                      sent_1_label=labels[0],
                                      sent_2_tokens=sents_tokens[1],
                                      sent_2_label=labels[1]))
        elif len(ids) == 4:  # MM, FM, MF, FF
            sents_tags = _tag_group_of_four(sents_tokens, labels, ar_sents)
        else:
            print(f'Unsupported group size {len(ids)}, skipping: {ids}')
            continue

        for sent_id, sent_tokens, sent_tags, sent_label in zip(ids, sents_tokens, sents_tags, labels):
            tokens.append(sent_tokens)
            tags.append(sent_tags)
            tagged_data.append(TokenLevalInfo(tokens=sent_tokens,
                                              tags=sent_tags,
                                              sent_label=sent_label,
                                              id=sent_id))

        if len(ids) > 1:
            _check_tag_balance(ids, sents_tags)

    return tokens, tags, tagged_data



In [9]:
def replace_last(s, replace_what, replace_with):
    """
    Replace the last occurrence of `replace_what` in `s` with `replace_with`.

    Returns `s` unchanged when `replace_what` does not occur in it. An empty
    `replace_what` raises ValueError (from str.rpartition).
    """
    head, sep, tail = s.rpartition(replace_what)
    if not sep:
        # With no match, rpartition() returns ('', '', s); without this guard
        # `replace_with` would be prepended to `s`.
        return s
    return head + replace_with + tail

In [10]:
def tag_two(sent_1_tokens, sent_1_label, sent_2_tokens, sent_2_label):
    """
    Tag a pair of sentences that differ in the gender of a single person.

    Tokens are compared position by position (zip of the two token lists).
    When the two tokens end in an alternating clitic pair (see `clitics`), the
    clitic part gets '<person><gender>' and the bases are compared with
    compare_two(). Otherwise the whole tokens are compared and the clitic
    part is 'B'.

    Args:
        sent_1_tokens, sent_2_tokens: token lists of the two sentences.
        sent_1_label, sent_2_label: sentence labels (MB/FB or BM/BF).

    Returns:
        (sent_1_tags, sent_2_tags): lists of '<base>+<clitic>' tags.
    """
    def match_clitics(token_1, token_2):
        """Return (clitic_1, clitic_2) for the first alternating pair ending the tokens, else None."""
        for first, second in clitics:
            if ((token_1.endswith(first) and token_2.endswith(second))
                    or (token_1.endswith(second) and token_2.endswith(first))):
                clitic_1 = first if token_1.endswith(first) else second
                clitic_2 = first if token_2.endswith(first) else second
                return clitic_1, clitic_2
        return None

    sent_1_tags, sent_2_tags = [], []
    for token_1, token_2 in zip(sent_1_tokens, sent_2_tokens):
        matched = match_clitics(token_1, token_2)
        if matched is None:
            # no alternating clitic: only the base can carry gender
            base_tag_1, base_tag_2 = compare_two(token_1, sent_1_label, token_2, sent_2_label)
            clitic_tag_1 = clitic_tag_2 = 'B'
        else:
            clitic_1, clitic_2 = matched
            base_tag_1, base_tag_2 = compare_two(replace_last(token_1, clitic_1, ''), sent_1_label,
                                                 replace_last(token_2, clitic_2, ''), sent_2_label)
            # labels starting with 'B' (BM/BF) concern the 2nd person, others the 1st
            person = '2' if sent_1_label[0] == 'B' else '1'
            clitic_tag_1 = person + clitics_genders[clitic_1]
            clitic_tag_2 = person + clitics_genders[clitic_2]

        sent_1_tags.append(f'{base_tag_1}+{clitic_tag_1}')
        sent_2_tags.append(f'{base_tag_2}+{clitic_tag_2}')

    return sent_1_tags, sent_2_tags

In [11]:
def tag_four(sent_1_tokens, sent_1_label, sent_1_tags, sent_2_tokens, sent_2_label, sent_2_tags, person=1):
    """
    One tagging pass over two sentences of a 4-sentence (MM/FM/MF/FF) group.

    Only positions where the two tokens differ are (re)tagged, and the
    difference is attributed to `person`:
      - Position still 'B+B' in both sentences: if the tokens end in an
        alternating clitic pair (see `clitics`), the clitic part gets
        '<person><gender>' and the base part is tagged with compare_four()
        when the bases differ. Otherwise the whole token is tagged with
        compare_four() and the clitic part stays 'B'.
      - Position already tagged by an earlier pass: the new difference is
        assigned to whichever part (clitic or base) is still 'B' in both
        sentences. Unexplained changes are printed and the tags are kept.

    `sent_1_tags` and `sent_2_tags` are modified in place and also returned.

    Args:
        sent_1_tokens, sent_2_tokens: token lists of the two sentences.
        sent_1_label, sent_2_label: sentence labels, e.g. 'MM', 'FM'.
        sent_1_tags, sent_2_tags: current '<base>+<clitic>' tags.
        person: 1 or 2, the person in which the two sentences differ.

    Returns:
        (sent_1_tags, sent_2_tags)

    Raises:
        AssertionError: if an already-tagged differing position has neither
            its base nor its clitic still 'B' in both sentences.
    """
    for i, (token_1, token_2) in enumerate(zip(sent_1_tokens, sent_2_tokens)):
        if token_1 == token_2:
            continue

        token_tag_1 = sent_1_tags[i].split('+')[0]
        token_tag_2 = sent_2_tags[i].split('+')[0]
        clitic_tag_1 = sent_1_tags[i].split('+')[1]
        clitic_tag_2 = sent_2_tags[i].split('+')[1]

        if sent_1_tags[i] == 'B+B' and sent_2_tags[i] == 'B+B':
            # untagged position: look for an alternating clitic pair first
            clitic_found = False
            for clitic_combo in clitics:
                if (token_1.endswith(clitic_combo[0]) and token_2.endswith(clitic_combo[1]) or
                        token_1.endswith(clitic_combo[1]) and token_2.endswith(clitic_combo[0])):
                    token_1_clitic = clitic_combo[0] if token_1.endswith(clitic_combo[0]) else clitic_combo[1]
                    token_2_clitic = clitic_combo[0] if token_2.endswith(clitic_combo[0]) else clitic_combo[1]

                    token_1_base = replace_last(token_1, token_1_clitic, '')
                    token_2_base = replace_last(token_2, token_2_clitic, '')

                    # the base is gendered only if it differs as well
                    if token_1_base != token_2_base:
                        token_tag_1, token_tag_2 = compare_four(token_1_base, sent_1_label,
                                                                token_2_base, sent_2_label,
                                                                person=person)

                    clitic_tag_1 = str(person) + clitics_genders[token_1_clitic]
                    clitic_tag_2 = str(person) + clitics_genders[token_2_clitic]
                    clitic_found = True
                    break

            # no alternating clitic: tag the whole token as the base, clitic stays B
            if not clitic_found:
                token_tag_1, token_tag_2 = compare_four(token_1, sent_1_label, token_2, sent_2_label,
                                                        person=person)

        else:
            # Already tagged by an earlier pass: the change must be in whichever
            # part is still 'B' in both sentences, or be a spelling difference.
            clitic_change = clitic_tag_1 == clitic_tag_2 == 'B'
            base_change = token_tag_1 == token_tag_2 == 'B'
            assert clitic_change != base_change, (
                f'Cannot attribute change {token_1!r} -> {token_2!r} '
                f'with tags {sent_1_tags[i]!r} / {sent_2_tags[i]!r}')

            clitic_found = False
            if clitic_change:
                for clitic_combo in clitics:
                    if (token_1.endswith(clitic_combo[0]) and token_2.endswith(clitic_combo[1]) or
                            token_1.endswith(clitic_combo[1]) and token_2.endswith(clitic_combo[0])):
                        token_1_clitic = clitic_combo[0] if token_1.endswith(clitic_combo[0]) else clitic_combo[1]
                        token_2_clitic = clitic_combo[0] if token_2.endswith(clitic_combo[0]) else clitic_combo[1]

                        clitic_tag_1 = str(person) + clitics_genders[token_1_clitic]
                        clitic_tag_2 = str(person) + clitics_genders[token_2_clitic]
                        clitic_found = True
                        break

            elif base_change:
                # both tokens keep the same clitic, so tag their (differing) bases
                for clitic in all_clitics:
                    if token_1.endswith(clitic) and token_2.endswith(clitic):
                        token_1_base = replace_last(token_1, clitic, '')
                        token_2_base = replace_last(token_2, clitic, '')
                        token_tag_1, token_tag_2 = compare_four(token_1_base, sent_1_label,
                                                                token_2_base, sent_2_label,
                                                                person=person)
                        break

            # No clitic has been found and the change is not in the base, so
            # the change is most likely a spelling error: report it instead of
            # stopping in the debugger, and keep the existing tags.
            if not clitic_found and not base_change:
                print(f'Unexplained change (possible spelling error): {token_1!r} vs {token_2!r}')

        sent_1_tags[i] = f'{token_tag_1}+{clitic_tag_1}'
        sent_2_tags[i] = f'{token_tag_2}+{clitic_tag_2}'

    return sent_1_tags, sent_2_tags

In [12]:
def compare_one(token, sent_label, person):
    """
    Tag a gendered token of a single sentence from its label and person.

    `token` is not inspected; the tag depends only on `sent_label` and
    `person`. In `sent_label` ('MM', 'FM', 'MF' or 'FF'), the first letter is
    the 1st-person gender and the second letter is the 2nd-person gender.

    Args:
        token: unused, kept for interface compatibility.
        sent_label: sentence label.
        person: 1 or 2, given either as an int (as tag_four()/compare_four()
            use) or as a string ('1', '2').

    Returns:
        '1M', '1F', '2M' or '2F', or None for any other label or person.
    """
    person = str(person)
    if sent_label not in ('MM', 'FM', 'MF', 'FF') or person not in ('1', '2'):
        return None
    return person + sent_label[int(person) - 1]

In [13]:
def compare_two(token_1, sent_1_label, token_2, sent_2_label):
    """
    Base tags for two aligned tokens of a sentence pair (MB/FB or BM/BF).

    Identical tokens are tagged ('B', 'B'). Differing tokens get the
    person+gender tags implied by the label pair. An unknown label pair gives
    (None, None).
    """
    if token_1 == token_2:
        return 'B', 'B'

    tags_by_label_pair = {
        ('FB', 'MB'): ('1F', '1M'),
        ('MB', 'FB'): ('1M', '1F'),
        ('BF', 'BM'): ('2F', '2M'),
        ('BM', 'BF'): ('2M', '2F'),
    }
    return tags_by_label_pair.get((sent_1_label, sent_2_label), (None, None))

In [14]:
def compare_four(token_1, sent_1_label, token_2, sent_2_label, person=1):
    """
    Base tags for two aligned tokens from different sentences of an MM/FM/MF/FF group.

    The tokens themselves are not inspected. For person p, each sentence gets
    '<p><gender>', where the gender is letter p of its label (the first letter
    is the 1st person, the second letter is the 2nd person).
    E.g. ('MF', 'FF', person=2) -> ('2F', '2F').

    Returns (None, None) when the labels are equal, when either label is not
    one of 'MM', 'FM', 'MF', 'FF', or when person is neither 1 nor 2.
    """
    gendered_labels = ('MM', 'FM', 'MF', 'FF')
    if (sent_1_label == sent_2_label
            or sent_1_label not in gendered_labels
            or sent_2_label not in gendered_labels):
        return None, None

    if person == 1:
        prefix, slot = '1', 0
    elif person == 2:
        prefix, slot = '2', 1
    else:
        return None, None

    return prefix + sent_1_label[slot], prefix + sent_2_label[slot]

In [15]:
def tag_balanced_corpora(tagged_input):
    """
    Build the four target corpora (MM, FM, MF, FF) from the tagged input.

    Every example is copied into each target compatible with its label:
      - 'BB' goes once into all four targets.
      - A label fixing one person (e.g. 'MB') goes twice into the two targets
        that agree with it (MM and MF).
      - A fully specified label (e.g. 'FM') goes four times into its own target.
    Examples with any other label are dropped. All four targets must end up
    with the same length as `tagged_input` (asserted).

    Returns:
        (target_mm_tagged, target_fm_tagged, target_mf_tagged, target_ff_tagged)
    """
    # sentence label -> ((target corpus, number of copies), ...)
    routing = {
        'BB': (('MM', 1), ('FM', 1), ('MF', 1), ('FF', 1)),
        'MB': (('MM', 2), ('MF', 2)),
        'FB': (('FF', 2), ('FM', 2)),
        'BM': (('MM', 2), ('FM', 2)),
        'BF': (('FF', 2), ('MF', 2)),
        'MM': (('MM', 4),),
        'FM': (('FM', 4),),
        'MF': (('MF', 4),),
        'FF': (('FF', 4),),
    }
    targets = {name: [] for name in ('MM', 'FM', 'MF', 'FF')}

    for ex in tagged_input:
        for name, copies in routing.get(ex.sent_label, ()):
            targets[name].extend([ex] * copies)

    assert len(targets['MM']) == len(targets['FM']) == len(targets['MF']) \
        == len(targets['FF']) == len(tagged_input)

    return targets['MM'], targets['FM'], targets['MF'], targets['FF']

In [16]:
# Tag the train split and write the input corpus plus the four target corpora.
split = 'train'
# Root of the corpus checkout; override with the APGC_ROOT environment variable.
CORPUS_ROOT = os.environ.get('APGC_ROOT',
                             '/Users/ba63/Desktop/repos/apgcv2.0-internal/'
                             'Arabic-parallel-gender-corpus-v-2.0/')
OUTPUT_DIR = 'new_tokens_data'
# write_data() opens its output with mode='w', which fails if the directory is missing.
os.makedirs(OUTPUT_DIR, exist_ok=True)

tokens, tags, tagged_all_data = tag_input_corpus(os.path.join(CORPUS_ROOT, split, ''), split=split)

write_data(path=os.path.join(OUTPUT_DIR, split + '.arin.tokens.new'),
           data=tagged_all_data)

target_mm_tagged, target_fm_tagged, target_mf_tagged, target_ff_tagged = tag_balanced_corpora(tagged_all_data)

for target_name, target_data in [('MM', target_mm_tagged), ('FM', target_fm_tagged),
                                 ('MF', target_mf_tagged), ('FF', target_ff_tagged)]:
    write_data(path=os.path.join(OUTPUT_DIR, f'{split}.ar.{target_name}.tokens.new'),
               data=target_data)

Counter([x for s in tags for x in s])

['انا ( ستانلي ) ، مرشدكم الشجاع إلى عالما رائعا من الارتجال', 'انا ( ستانلي ) ، مرشدتكم الشجاعة إلى عالما رائعا من الارتجال', 'انا ( ستانلي ) ، مرشدكن الشجاع إلى عالما رائعا من الارتجال', 'انا ( ستانلي ) ، مرشدتكن الشجاعة إلى عالما رائعا من الارتجال']

['و أنا هذا الفتى الذي تريده ؟ هذه غلطة ( كلاهان )', 'و أنا هذه الفتاة التي تريدها ؟ هذه غلطة ( كلاهان )', 'و أنا هذا الفتى الذي تريدينه ؟ هذه غلطة ( كلاهان )', 'و أنا هذه الفتاة التي تريدينها ؟ هذه غلطة ( كلاهان )']

4


Counter({'B+B': 385693,
         '1F+B': 3490,
         '1M+B': 3490,
         '2M+B': 16320,
         '2F+B': 16320,
         'B+2M': 1042,
         'B+2F': 1042,
         '2M+2M': 22,
         '2F+2F': 22,
         '1M+1M': 9,
         '1F+1F': 9,
         'B+1M': 28,
         'B+1F': 28,
         '1M+2M': 1,
         '1F+2M': 1,
         '1M+2F': 1,
         '1F+2F': 1,
         '2M+1M': 1,
         '2M+1F': 1,
         '2F+1M': 1,
         '2F+1F': 1})

In [99]:
# Tag the dev split and write the input corpus plus the four target corpora.
split = 'dev'
# Root of the corpus checkout; override with the APGC_ROOT environment variable.
CORPUS_ROOT = os.environ.get('APGC_ROOT',
                             '/Users/ba63/Desktop/repos/apgcv2.0-internal/'
                             'Arabic-parallel-gender-corpus-v-2.0/')
OUTPUT_DIR = 'new_tokens_data'
# write_data() opens its output with mode='w', which fails if the directory is missing.
os.makedirs(OUTPUT_DIR, exist_ok=True)

tokens, tags, tagged_all_data = tag_input_corpus(os.path.join(CORPUS_ROOT, split, ''), split=split)

write_data(path=os.path.join(OUTPUT_DIR, split + '.arin.tokens.new'),
           data=tagged_all_data)

target_mm_tagged, target_fm_tagged, target_mf_tagged, target_ff_tagged = tag_balanced_corpora(tagged_all_data)

for target_name, target_data in [('MM', target_mm_tagged), ('FM', target_fm_tagged),
                                 ('MF', target_mf_tagged), ('FF', target_ff_tagged)]:
    write_data(path=os.path.join(OUTPUT_DIR, f'{split}.ar.{target_name}.tokens.new'),
               data=target_data)

Counter([x for s in tags for x in s])

4


Counter({'B+B': 44629,
         '2M+B': 1787,
         '2F+B': 1787,
         '1M+B': 422,
         '1F+B': 422,
         'B+2F': 98,
         'B+2M': 98,
         'B+1F': 5,
         'B+1M': 5,
         '2M+2M': 2,
         '2F+2F': 2})

In [100]:
# Tag the test split and write the input corpus plus the four target corpora.
split = 'test'
# Root of the corpus checkout; override with the APGC_ROOT environment variable.
CORPUS_ROOT = os.environ.get('APGC_ROOT',
                             '/Users/ba63/Desktop/repos/apgcv2.0-internal/'
                             'Arabic-parallel-gender-corpus-v-2.0/')
OUTPUT_DIR = 'new_tokens_data'
# write_data() opens its output with mode='w', which fails if the directory is missing.
os.makedirs(OUTPUT_DIR, exist_ok=True)

tokens, tags, tagged_all_data = tag_input_corpus(os.path.join(CORPUS_ROOT, split, ''), split=split)

write_data(path=os.path.join(OUTPUT_DIR, split + '.arin.tokens.new'),
           data=tagged_all_data)

target_mm_tagged, target_fm_tagged, target_mf_tagged, target_ff_tagged = tag_balanced_corpora(tagged_all_data)

for target_name, target_data in [('MM', target_mm_tagged), ('FM', target_fm_tagged),
                                 ('MF', target_mf_tagged), ('FF', target_ff_tagged)]:
    write_data(path=os.path.join(OUTPUT_DIR, f'{split}.ar.{target_name}.tokens.new'),
               data=target_data)

Counter([x for s in tags for x in s])

4


Counter({'B+B': 108411,
         '2M+B': 4548,
         '2F+B': 4548,
         'B+2M': 279,
         'B+2F': 279,
         '1M+B': 958,
         '1F+B': 958,
         '2M+2M': 8,
         '2F+2F': 8,
         'B+1F': 10,
         'B+1M': 10,
         '1F+1F': 1,
         '1M+1M': 1})