# In [None]:
"""Mask/Replace Objects/Directions"""
import json
import os
import itertools
from tqdm import tqdm
import random
import regex
import string
import copy
import stanza  # Use StanfordNLP POS Tagger.

# Splits text on runs of non-word characters. The capturing group makes
# split() keep the separators (punctuation) so they can become tokens too.
SENTENCE_SPLIT_REGEX = regex.compile(r'(\W+)')
# Stanza English pipeline (tokenize, multi-word-token expansion, POS tagging)
# used by find_noun_positions. It is constructed at module level, i.e. once,
# when this cell runs.
pos_tagger = stanza.Pipeline('en', processors='tokenize,mwt,pos')


def check_dir(d):
    """Ensure the directory ``d`` exists, creating it and any missing parents.

    Prints ``'<d> \tEXISTS!'`` if the directory was already there and
    ``'<d> \tCREATED!'`` if it had to be created.

    Args:
        d: Path of the directory to ensure.

    Raises:
        FileExistsError: If ``d`` exists but is not a directory.
    """
    if os.path.isdir(d):
        print(d, '\tEXISTS!')
        return
    # os.makedirs (unlike os.mkdir) also creates missing parent directories,
    # e.g. '../../touchdown' when '../../touchdown/data/' is requested.
    # exist_ok tolerates the directory appearing between the check and here.
    os.makedirs(d, exist_ok=True)
    print(d, '\tCREATED!')


# Dataset name: raw splits are read from ./raw_<dataset>/ and the perturbed
# copies are written under ../../<dataset>/data/.
dataset = 'touchdown'
src_data_dir = f'./raw_{dataset}/'
dst_data_dir = f'../../{dataset}/data/'
check_dir(dst_data_dir)  # make sure the output root exists
# '%s' is filled with a phase name, e.g. './raw_touchdown//train.json'
# (the doubled slash is harmless for path resolution).
input_filename_pattern = f'{src_data_dir}/%s.json'

# Data splits to process.
phases = ['train', 'dev', 'test']

# Token written in place of a masked word.
mask_token = '[MASK]'
# Direction words. Words in the same group are treated as interchangeable;
# a replacement direction is always drawn from a *different* group.
direction_groups = [['front', 'forward'], ['left'], ['right'], ['stop'], ['back']]
# Flat list of all direction words (used for membership tests).
directions = list(itertools.chain.from_iterable(direction_groups))
# Numeric words up to twenty. A replacement is drawn from the *same* group, so
# cardinals are swapped with cardinals and ordinals with ordinals.
numeric_groups = [
    # cardinal
    ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20',
    'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen', 'seventeen', 'eighteen', 'nineteen', 'twenty'], 
    # ordinal
    ['1st', '2nd', '3rd', '4th', '5th', '6th', '7th', '8th', '9th', '10th', '11th', '12th', '13th', '14th', '15th', '16th', '17th', '18th', '19th', '20th', 
    'first', 'second', 'third', 'fourth', 'fifth', 'sixth', 'seventh', 'eighth', 'ninth', 'tenth', 'eleventh', 'twelfth', 'thirteenth', 'fourteenth', 'fifteenth', 'sixteenth', 'seventeenth', 'eighteenth', 'nineteenth', 'twentieth']
]
# Flat list of all numeric words (used for membership tests).
numerics = list(itertools.chain.from_iterable(numeric_groups))

# In [8]:
def load_data(datafile):
    """Load a JSON-Lines file into a list of decoded records.

    Args:
        datafile: Path to a UTF-8 file holding one JSON document per line.

    Returns:
        list: The decoded documents in file order. Blank (whitespace-only)
        lines, such as an extra trailing newline, are skipped rather than
        making ``json.loads`` fail.
    """
    # Iterate the file lazily instead of materialising readlines(); JSON is
    # UTF-8 by specification, so don't depend on the platform locale.
    with open(datafile, encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def tokenize_sentence(sentence):
    """Normalise ``sentence`` into lower-cased, space-separated tokens.

    The sentence is split on runs of non-word characters, keeping the
    separators (see ``SENTENCE_SPLIT_REGEX``). Every non-empty piece is
    stripped and lower-cased. A piece made solely of punctuation is broken
    into single characters (``'!?'`` -> ``'!'``, ``'?'``), unless it consists
    only of full stops (``'..'`` stays whole).

    Returns:
        str: The tokens joined by single spaces (a string, not a list).
    """
    pieces = (p.strip().lower() for p in SENTENCE_SPLIT_REGEX.split(sentence.strip()))
    tokens = []
    for piece in pieces:
        if not piece:
            continue
        only_punct = all(ch in string.punctuation for ch in piece)
        only_dots = all(ch == '.' for ch in piece)
        if only_punct and not only_dots:
            tokens.extend(piece)  # one token per punctuation character
        else:
            tokens.append(piece)
    return ' '.join(tokens)


def load_objects(path=None):
    """Load the object vocabulary used for the ``replace_object`` setting.

    Args:
        path: JSON file to read. Defaults to ``objects.json`` inside
            ``src_data_dir``, the previously hard-coded location.

    Returns:
        The decoded JSON content. Replacements are drawn from it with
        ``random.sample``, so it must be a sequence. NOTE(review): presumably
        a list of object words; confirm against the data file.
    """
    if path is None:
        path = os.path.join(src_data_dir, 'objects.json')
    # Context manager so the handle is closed deterministically.
    with open(path, encoding='utf-8') as f:
        return json.load(f)


def find_noun_positions(instr):
    """POS-tag an instruction and locate its nouns.

    The instruction is normalised with ``tokenize_sentence`` and tagged by the
    stanza ``pos_tagger``. A word tagged ``NOUN`` or ``PROPN`` counts as a noun
    unless its text is one of the ``directions`` words; those are handled by
    the direction perturbations instead.

    Returns:
        tuple: ``(tokens, noun_positions)``.
            ``tokens`` is the list of word texts across every sentence
            stanza produced.
            ``noun_positions`` holds indices into that flat list.
    """
    doc = pos_tagger(tokenize_sentence(instr))
    # Flatten all sentences so that positions are global across the text.
    all_words = [w for sent in doc.sentences for w in sent.words]
    tokens = [w.text for w in all_words]
    noun_positions = [
        i for i, w in enumerate(all_words)
        if w.pos in ['NOUN', 'PROPN'] and w.text not in directions
    ]
    return tokens, noun_positions


def _process_objects(phase, mask_rate, intermediate, objects=None, replace=False, random_mask=False):
    """Mask, or replace with random objects, the noun tokens of each item.

    Args:
        phase: Split name; only used in the progress message.
        mask_rate: Fraction of noun positions to perturb per instruction;
            ``int(mask_rate * n_nouns)`` positions are chosen.
        intermediate: Items carrying ``'tokens'`` (list of token lists) and
            ``'noun_positions'`` (list of index lists, parallel to
            ``'tokens'``). The items are not modified.
        objects: Sequence of replacement words; required when ``replace``.
        replace: If True, chosen positions receive a random entry of
            ``objects``; otherwise they become ``mask_token``.
        random_mask: Controlled trial: draw the same number of positions
            uniformly from all tokens instead of from the nouns.

    Returns:
        list: Deep copies of the items. ``'navigation_text'`` is set to the
        first perturbed instruction, and ``'tokens'`` / ``'noun_positions'``
        are removed.

    Raises:
        ValueError: If ``replace`` is requested without any ``objects``.
    """
    if replace and not objects:
        # Fail fast with a clear message instead of a TypeError raised by
        # random.sample(None, 1) deep inside the loop.
        raise ValueError('replace=True requires a non-empty `objects` sequence')

    print('Processing %s-%.2f' % (phase, mask_rate))
    rst = []

    for item_ in tqdm(intermediate):
        # A single deep copy suffices: the nested token lists belong to `item`
        # alone, so they can be edited in place without touching `item_`.
        item = copy.deepcopy(item_)
        rst_instructions = []
        for tokens, noun_positions in zip(item['tokens'], item['noun_positions']):
            n_masks = int(mask_rate * len(noun_positions))
            if random_mask:  # controlled trial
                # range() indexes like the list it replaces, so the sampled
                # positions for a given seed are unchanged.
                mask_positions = random.sample(range(len(tokens)), n_masks)
            else:
                mask_positions = random.sample(noun_positions, n_masks)
            for pos in mask_positions:
                if replace:
                    tokens[pos] = random.sample(objects, 1)[0]  # randomly select an object
                else:  # mask
                    tokens[pos] = mask_token
            rst_instructions.append(' '.join(tokens))
        # Intermediate items built in this file hold one instruction each.
        item['navigation_text'] = rst_instructions[0]
        del item['tokens']
        del item['noun_positions']
        rst.append(item)

    print('#item:\t', len(rst))
    return rst


def _find_replacing_direction(direction):
    """Find a replacement in the other groups"""
    dg = direction_groups
    for idx, group in enumerate(dg):
        if direction in group:
            break
    cors = list(range(len(direction_groups)))
    cors.remove(idx)
    cands = []
    for i in cors:
        cands = cands + dg[i]
    replacement = random.sample(cands, 1)[0]
    return replacement


def _process_directions(phase, mask_rate, intermediate, replace=False, random_mask=False):
    """Mask, or swap for a conflicting direction, the direction words of each item.

    For each token list in ``item['tokens']``:
      * Collect the positions holding a word from ``directions``.
      * Sample ``int(mask_rate * count)`` of them.
      * Set each sampled position to ``mask_token``, or, when ``replace`` is
        set, to a word from a different direction group
        (``_find_replacing_direction``).
    With ``random_mask`` (controlled trial), the same number of positions is
    drawn from the whole token list instead.

    Returns a list of deep-copied items. ``'navigation_text'`` is set to the
    first perturbed instruction, and ``'tokens'`` and ``'noun_positions'`` are
    removed.
    """
    print('Processing %s-%.2f' % (phase, mask_rate))
    results = []

    for source_item in tqdm(intermediate):
        out_item = copy.deepcopy(source_item)
        perturbed = []
        for words in copy.deepcopy(out_item['tokens']):
            hits = [i for i, w in enumerate(words) if w in directions]
            budget = int(mask_rate * len(hits))
            # random_mask: controlled trial over every position.
            pool = list(range(len(words))) if random_mask else hits
            for i in random.sample(pool, budget):
                words[i] = _find_replacing_direction(words[i]) if replace else mask_token
            perturbed.append(' '.join(words))

        out_item['navigation_text'] = perturbed[0]
        del out_item['tokens']
        del out_item['noun_positions']
        results.append(out_item)

    print('#item:\t', len(results))

    return results


def _find_replacing_numeric(numeric):
    """Find a replacement in the same group"""
    for group in numeric_groups:
        if numeric in group:
            while True:
                replacement = random.sample(group, 1)[0]
                if replacement != numeric:
                    break
            break
    return replacement


def _process_numerics(phase, mask_rate, intermediate, replace=False, random_mask=False):
    """Mask, or swap for another number, the numeric words of each item.

    For each token list in ``item['tokens']``:
      * Collect the positions whose lower-cased token is in ``numerics``.
      * Sample ``int(mask_rate * count)`` of them.
      * Set each sampled position to ``mask_token``, or, when ``replace`` is
        set, to another word of the same numeric group
        (``_find_replacing_numeric``).
    With ``random_mask`` (controlled trial), the same number of positions is
    drawn from the whole token list instead.

    Returns a list of deep-copied items. ``'navigation_text'`` is set to the
    first perturbed instruction, and ``'tokens'`` is removed.
    """
    print('Processing %s-%.2f' % (phase, mask_rate))
    output = []

    for original in tqdm(intermediate):
        item = copy.deepcopy(original)
        texts = []
        for toks in copy.deepcopy(item['tokens']):
            numeric_idx = [k for k, t in enumerate(toks) if t.lower() in numerics]
            how_many = int(mask_rate * len(numeric_idx))
            candidates = list(range(len(toks))) if random_mask else numeric_idx  # random_mask: controlled trial
            chosen = random.sample(candidates, how_many)

            for k in chosen:
                toks[k] = _find_replacing_numeric(toks[k].lower()) if replace else mask_token
            texts.append(' '.join(toks))

        item['navigation_text'] = texts[0]
        del item['tokens']
        output.append(item)

    print('#item:\t', len(output))

    return output


def _create_label(setting):
    """
    Create label which is the abbrev of the setting.
    'replace_object' -> 'ro'
    """
    return ''.join(w[0] for w in setting.split('_'))


def _save_result(rst, phase, mask_rate, setting, repeat_idx):
    """Write ``rst`` as JSON Lines into the directory for one setting and run.

    The output file is
    ``<dst_data_dir>/<setting>/<label><mask_rate:.2f>_<repeat_idx>/<phase>.json``,
    where ``label`` is ``_create_label(setting)``; for example
    ``mask_object/mo1.00_0/train.json``. Both directory levels are created
    with ``check_dir`` if needed.
    """
    run_name = f'{_create_label(setting)}{mask_rate:.2f}_{repeat_idx}'
    setting_dir = os.path.join(dst_data_dir, setting)
    run_dir = os.path.join(setting_dir, run_name)
    for directory in (setting_dir, run_dir):
        check_dir(directory)

    out_path = os.path.join(run_dir, f'{phase}.json')
    with open(out_path, 'w') as out_file:
        for record in rst:
            json.dump(record, out_file)
            out_file.write('\n')


def process_object_and_direction_tokens(phase, data, mask_rates, objects=None, repeat_idx=0):
    """Generate every object and direction perturbation dataset for ``phase``.

    Tokenisation and POS tagging are computed once and cached in
    ``<src_data_dir>/intermediate/<phase>.json``. That cache is a list of
    items with ``'tokens'`` and ``'noun_positions'`` added, and it is reused
    on later calls. For each rate in ``mask_rates``, six settings are saved
    via ``_save_result``: mask / replace / random-mask, for objects and for
    directions.

    Args:
        phase: Split name (e.g. 'train').
        data: Records with a ``'navigation_text'`` field; left unmodified.
        mask_rates: Iterable of fractions of target tokens to perturb.
        objects: Replacement vocabulary for the ``replace_object`` setting.
        repeat_idx: Seed for ``random`` and suffix of the output directories.
    """
    print('Processing %s' % phase)
    random.seed(repeat_idx)

    intermediate_dir = os.path.join(src_data_dir, 'intermediate')
    check_dir(intermediate_dir)
    intermediate_filename = os.path.join(intermediate_dir, f'{phase}.json')
    if os.path.exists(intermediate_filename):
        with open(intermediate_filename, 'r') as fin:
            intermediate = json.load(fin)
        print('intermediate loaded from %s' % intermediate_filename)
    else:
        print('generating intermediate from scratch.')
        intermediate = []
        for record in tqdm(data):
            # Shallow copy so the caller's records (reused across repeats)
            # do not gain the helper keys.
            item = dict(record)
            item['tokens'] = []
            item['noun_positions'] = []
            for instr in [item['navigation_text']]:
                tokens, noun_positions = find_noun_positions(instr)
                item['tokens'].append(tokens)
                item['noun_positions'].append(noun_positions)
            intermediate.append(item)

        with open(intermediate_filename, 'w') as fout:
            json.dump(intermediate, fout)

    # (setting, processor, extra kwargs), kept in the original order: all runs
    # share one seeded RNG stream, so order determines the sampled positions.
    runs = [
        ('mask_object', _process_objects, {}),
        ('replace_object', _process_objects, {'objects': objects, 'replace': True}),
        ('random_mask_for_object', _process_objects, {'random_mask': True}),  # controlled trial
        ('mask_direction', _process_directions, {}),
        ('replace_direction', _process_directions, {'replace': True}),
        ('random_mask_for_direction', _process_directions, {'random_mask': True}),  # controlled trial
    ]
    for mask_rate in mask_rates:
        for setting, process, kwargs in runs:
            rst = process(phase, mask_rate, intermediate, **kwargs)
            _save_result(rst, phase, mask_rate, setting=setting, repeat_idx=repeat_idx)


def process_numeric_tokens(phase, data, mask_rates, repeat_idx=0):
    """Generate the numeric-token perturbation datasets for ``phase``.

    Only records whose instruction contains a word from ``numerics`` are used.
    That subset, with the tokenised instruction under ``'tokens'``, is cached
    in ``<src_data_dir>/intermediate/<phase>_numeric.json``. When the cache
    is built, the unperturbed subset is also written to
    ``<dst_data_dir>/numeric_default/<phase>.json``. For each rate in
    ``mask_rates``, the mask / replace / random-mask settings are then saved.

    Args:
        phase: Split name (e.g. 'train').
        data: Records with a ``'navigation_text'`` field; left unmodified.
        mask_rates: Iterable of fractions of numeric tokens to perturb.
        repeat_idx: Seed for ``random`` and suffix of the output directories.
    """
    print('Processing %s' % phase)
    random.seed(repeat_idx)

    intermediate_dir = os.path.join(src_data_dir, 'intermediate')
    check_dir(intermediate_dir)
    intermediate_filename = os.path.join(intermediate_dir, f'{phase}_numeric.json')
    if os.path.exists(intermediate_filename):
        with open(intermediate_filename, 'r') as fin:
            intermediate = json.load(fin)
        print('intermediate loaded from %s' % intermediate_filename)
    else:
        print('generating intermediate for the subset containing numeric tokens from scratch.')
        intermediate = []
        numeric_vocab = set(numerics)
        for record in tqdm(data):
            token_lists = []
            contain_numeric = False
            for instr in [record['navigation_text']]:
                instr = tokenize_sentence(instr)
                token_lists.append(instr.split())
                if numeric_vocab & set(instr.lower().split()):
                    contain_numeric = True
            if contain_numeric:
                # New dict so the caller's record is not mutated.
                intermediate.append({**record, 'tokens': token_lists})

        with open(intermediate_filename, 'w') as fout:
            json.dump(intermediate, fout)
            print('#item:\t', len(intermediate))

        # Store the unperturbed subset that contains numeric tokens. The helper
        # 'tokens' key is dropped so these records match the other outputs.
        cur_dst_dir = os.path.join(dst_data_dir, 'numeric_default')
        check_dir(cur_dst_dir)
        with open(os.path.join(cur_dst_dir, f'{phase}.json'), 'w') as fout:
            for item in intermediate:
                json.dump({k: v for k, v in item.items() if k != 'tokens'}, fout)
                fout.write('\n')

    print('#input_item:\t', len(intermediate))

    # Kept in the original order: the runs share one seeded RNG stream.
    runs = [
        ('mask_numeric', {}),
        ('replace_numeric', {'replace': True}),
        ('random_mask_for_numeric', {'random_mask': True}),  # controlled trial
    ]
    for mask_rate in mask_rates:
        for setting, kwargs in runs:
            rst = _process_numerics(phase, mask_rate, intermediate, **kwargs)
            _save_result(rst, phase, mask_rate, setting=setting, repeat_idx=repeat_idx)

# In [None]:
"""Process Numeric Tokens"""
# Only full masking is generated here; an earlier sweep used
# [0.2 * i for i in range(6)].
mask_rates = [1.]
for phase in phases:
    data = load_data(input_filename_pattern % phase)
    for i in range(5):  # repeat 5 times; the repeat index is also the random seed
        process_numeric_tokens(phase, data, mask_rates, repeat_idx=i)

# In [None]:
"""Process Object Tokens and Direction Tokens"""
# The object vocabulary and the mask rates do not depend on the phase, so
# they are set up once instead of inside the loop.
objects = load_objects()
mask_rates = [1.]  # earlier sweep: [0.2 * i for i in range(6)]
for phase in phases:
    data = load_data(input_filename_pattern % phase)
    for i in range(5):  # repeat 5 times; the repeat index is also the random seed
        process_object_and_direction_tokens(phase, data, mask_rates, objects, repeat_idx=i)