# Augmentation by paraphrasing

## Init & Load Seed Data

In [None]:
import json
import os
import re

import openai
from tqdm import tqdm

In [None]:
# Select the task domain; the seed file is read from that directory.
DOMAIN = "drone-planning/"
# DOMAIN = "clean-up/"
# DOMAIN = "pick-and-place/"

# The seed file is JSONL: every line is decoded into one seed record.
with open(DOMAIN + "train_seed.jsonl") as f:
    train_seed = list(map(json.loads, f))

In [None]:
# Map each seed sentence to a (so far empty) list that will hold its paraphrases.
eng_seeds = {record['natural']: list() for record in train_seed}

## Augmentation Code
Prompting GPT-3 seems to work best in this case.

In [None]:
# The OpenAI API key is read from the OPENAI_API_KEY environment variable so
# the secret does not have to be written into (and saved with) this notebook.
# Keys are created at https://beta.openai.com/account/api-keys
# If the variable is not set, the placeholder below is used and must be
# replaced by hand before any request is made.
openai.api_key = os.environ.get("OPENAI_API_KEY", "TO_BE_SET")

In [None]:
def normalize(sentence):
    """Return ``sentence`` with a capitalized first letter and a final period.

    A lowercase first character is upper-cased, and a period is appended
    when the sentence does not already end with one. All other characters
    are left untouched.

    An empty string is returned unchanged instead of raising ``IndexError``.
    """
    if not sentence:
        # Nothing to capitalize or terminate; indexing would fail on "".
        return sentence
    if sentence[0].islower():
        sentence = sentence[0].upper() + sentence[1:]
    if not sentence.endswith('.'):
        sentence += '.'
    return sentence

def _tidy_paraphrase_ending(paraphrase):
    """Remove trailing whitespace and, if it was present, close the sentence.

    A paraphrase without trailing whitespace is returned as-is. When trailing
    whitespace is removed and the remaining text does not end in '.', '!' or
    '?', a period is appended.
    """
    trimmed = paraphrase.rstrip()
    had_trailing_space = trimmed != paraphrase
    if had_trailing_space and trimmed and not trimmed.endswith(('.', '!', '?')):
        trimmed += '.'
    return trimmed


def parse_sentences_from_response(response, expected_count=10):
    """Split a numbered-list completion into a list of paraphrases.

    The prompt already ends with "1.", so ``response`` is expected to start
    with the text of the first item, followed by lines "2. ...", "3. ...",
    and so on.

    Args:
        response: Raw completion text.
        expected_count: Number of numbered items the completion must contain.

    Returns:
        A list of ``expected_count`` paraphrases with the numbering removed
        and trailing whitespace cleaned up.

    Raises:
        ValueError: If the number of lines differs from ``expected_count`` or
            a line does not start with its expected "<n>. " prefix.
    """
    lines = response.split('\n')
    # A completion that ends with a newline should not count as an extra item.
    while lines and not lines[-1].strip():
        lines.pop()
    if len(lines) != expected_count:
        raise ValueError(
            "expected " + str(expected_count) + " lines, got " + str(len(lines))
        )

    # Restore the "1." that is part of the prompt rather than the completion.
    lines[0] = "1." + lines[0]

    paraphrases = []
    for idx, line in enumerate(lines):
        prefix = str(idx + 1) + '. '
        if not line.startswith(prefix):
            raise ValueError(
                "line " + str(idx + 1) + " does not start with " + repr(prefix)
            )
        paraphrases.append(_tidy_paraphrase_ending(line[len(prefix):]))
    return paraphrases


# Prompt template for the completion model. "SOURCE-TO-BE-PLACED" is replaced
# with the sentence to paraphrase. The template ends with "1." so that the
# completion continues the numbered list; parse_sentences_from_response puts
# that "1." back in front of the first returned line.
PROMPT = """Rephrase the source sentence in 10 different ways. Make the outputs as diverse as possible.

Source: 
SOURCE-TO-BE-PLACED

Outputs:
1."""
def rephrase_a_sentence(sentence):
    """Request paraphrases of ``sentence`` from the completion model.

    The sentence is passed through ``normalize`` and substituted into
    ``PROMPT`` before the request is sent.

    Args:
        sentence: The source sentence to paraphrase.

    Returns:
        The list returned by ``parse_sentences_from_response`` on success.
        If the completion cannot be parsed, the tuple ``(output, "ERROR")``
        where ``output`` is the raw completion text.
    """
    response = openai.Completion.create(
        model="text-davinci-002",
        prompt=PROMPT.replace("SOURCE-TO-BE-PLACED", normalize(sentence)),
        temperature=0.7,
        max_tokens=512,
        top_p=1,
        best_of=1,
        frequency_penalty=0.1,
        presence_penalty=0
        )
    output = response['choices'][0]['text']
    try:
        paraphrases = parse_sentences_from_response(output)
    except Exception:
        # Malformed completions are reported and returned for inspection.
        # `Exception` (not a bare except) lets Ctrl-C still interrupt a run.
        print("Error in parsing response")
        print(output)
        return output, "ERROR"
    # Return the result parsed above instead of parsing the text again.
    return paraphrases

In [None]:
# Smoke test: paraphrase one sample instruction (this sends a real API request).
O = rephrase_a_sentence("Go to the red room or go to the green room to finally go to the blue room.")

In [None]:
# Show the smoke-test result: a list of paraphrases, or (output, "ERROR") if parsing failed.
O

## Run Augmentation

In [None]:
# Number of distinct seed sentences that will be paraphrased.
len(eng_seeds)

In [None]:
# Peek at the first seed sentence.
list(eng_seeds.keys())[0]

In [None]:
def paraphrase_done(eng_seeds):
    """Return True when every seed sentence has at least one paraphrase.

    Args:
        eng_seeds: Mapping from seed sentence to its list of paraphrases.

    Returns:
        False as soon as one seed has an empty list, True otherwise
        (including for an empty mapping).
    """
    # No progress bar here: this is a cheap predicate called once per pass.
    return all(len(extended) > 0 for extended in eng_seeds.values())

# Upper bound on passes over the seeds, so that a seed whose completion never
# parses cannot keep the loop (and the API requests) running forever.
MAX_PARAPHRASE_ROUNDS = 5

for _ in range(MAX_PARAPHRASE_ROUNDS):
    if paraphrase_done(eng_seeds):
        break
    for eng_seed, extended in tqdm(eng_seeds.items()):
        if len(extended) > 0:
            continue
        result = rephrase_a_sentence(eng_seed)
        # On a parse failure rephrase_a_sentence returns (output, "ERROR").
        # Only a real list of paraphrases may be stored; otherwise the raw
        # output and the string "ERROR" would end up in the training data.
        if isinstance(result, list):
            extended.extend(result)

if not paraphrase_done(eng_seeds):
    missing = sum(1 for extended in eng_seeds.values() if len(extended) == 0)
    print("Paraphrasing incomplete: " + str(missing) + " seed(s) still have no paraphrases")

In [None]:
# Inspect the seed sentence -> paraphrases mapping after augmentation.
eng_seeds

### Dump as Training Data

In [None]:
# Show the first seed record.
train_seed[0]

In [None]:
# Write every seed record followed by one record per paraphrase; a paraphrase
# record reuses the 'canonical' and 'formula' fields of its seed.
with open(DOMAIN + "syn-aug.train.jsonl", 'w') as f:
    for seed in train_seed:
        f.write(json.dumps(seed) + '\n')
        f.writelines(
            json.dumps({
                'natural': aug_eng,
                'canonical': seed['canonical'],
                'formula': seed['formula'],
            }) + '\n'
            for aug_eng in eng_seeds[seed['natural']]
        )

In [None]:
# Seed-only training file: one JSON line per seed record, no paraphrases.
with open(DOMAIN + "syn.train.jsonl", 'w') as f:
    f.writelines(json.dumps(seed) + '\n' for seed in train_seed)

### Normalize the natural language form 

In [None]:
if DOMAIN in ("clean-up/", "pick-and-place/"):
    # In clean up and in pick and place, golden natural language data comes
    # without a period at the end and without capitalization at the beginning,
    # so the augmented file is rewritten in that form. Both domains use the
    # same normalization.
    def clean_up_normalize(sentence):
        """Lower-case a leading capital letter and drop one trailing period.

        An empty string is returned unchanged.
        """
        # Slicing and endswith() avoid an IndexError on an empty string.
        if sentence[:1].isupper():
            sentence = sentence[0].lower() + sentence[1:]
        if sentence.endswith('.'):
            sentence = sentence[:-1]
        return sentence

    aug_path = DOMAIN + "syn-aug.train.jsonl"

    # Load every record first: the same file is reopened for writing below.
    with open(aug_path, 'r') as f:
        buffer = [json.loads(line) for line in f]

    with open(aug_path, 'w') as f:
        for dp in buffer:
            f.write(json.dumps({
                'natural': clean_up_normalize(dp['natural']),
                'canonical': dp['canonical'],
                'formula': dp['formula']
            }) + '\n')

In [None]:
if DOMAIN == "drone-planning/":
    # In drone planning, golden natural language data comes with a
    # "space + period" at the end and no capitalization at the beginning.
    def clean_up_normalize(sentence):
        """Convert a sentence to the drone-planning surface form.

        Steps: lower-case a leading capital letter, strip trailing spaces,
        periods and exclamation marks, append a single period, then make
        sure every '.' and ',' is preceded by exactly one space.

        Returns an empty string when nothing is left after stripping.
        Applying the function to its own output does not change it.
        """
        if sentence[:1].isupper():
            sentence = sentence[0].lower() + sentence[1:]
        # rstrip() is safe on empty / all-punctuation input, unlike indexing.
        sentence = sentence.rstrip(' .!')
        if not sentence:
            return sentence
        sentence += '.'
        # Collapse any spaces already in front of the mark before adding one,
        # so text that is already tokenized ("a , b") does not gain extra
        # spaces each time this cell is run.
        return re.sub(r' *([.,])', r' \1', sentence)

    # This cell rewrites the seed-only file in place. To normalize the
    # augmented data instead, use "syn-aug.train.jsonl" here.
    target_path = DOMAIN + "syn.train.jsonl"

    # Load every record first: the same file is reopened for writing below.
    with open(target_path, 'r') as f:
        buffer = [json.loads(line) for line in f]

    with open(target_path, 'w') as f:
        for dp in buffer:
            f.write(json.dumps({
                'natural': clean_up_normalize(dp['natural']),
                'canonical': dp['canonical'],
                'formula': dp['formula']
            }) + '\n')