In [7]:
import os 
import random
import string
import json
import pandas as pd
from collections import Counter

In [8]:
def set_seed(seed_val = 37):
    # setting the seed
    random.seed(seed_val)

set_seed(seed_val=37)

In [9]:
# TODO: clean string
def clean_str(ans):
    try:
        return ans.translate(str.maketrans('', '', string.punctuation)).lower().strip()
    except AttributeError:
        return None

In [10]:
ans_df = pd.read_csv('flan_t5_large_org_ans.csv')
train_path = '../full_data/FairytaleQA/train.json'
train_data = []
with open(train_path, 'r') as infile:
    for line in infile:
        train_data.append(json.loads(line))
train_df = pd.DataFrame(train_data)
print('Train data length:', len(train_df))

Train data length: 8548


In [11]:
ans_df.columns

Index(['story_name', 'content', 'answer', 'question', 'local_or_sum',
       'attribute', 'ex_or_im', 'generated_question', 'generated_answer'],
      dtype='object')

In [12]:
# Filter based on exact match of answers
select_rows = []
for i, row in ans_df.iterrows():
    if clean_str(row['answer']) == clean_str(row['generated_answer']):
        select_rows.append(row.values)

In [13]:
# NOTE: Shuffle rows
random.shuffle(select_rows)

In [14]:
em_ans_df = pd.DataFrame(select_rows, columns=ans_df.columns)
em_ans_df.head()

Unnamed: 0,story_name,content,answer,question,local_or_sum,attribute,ex_or_im,generated_question,generated_answer
0,ola-storbaekkjen,"now once , at the time of the hay harvest , he...",into the hill .,where was ola taken ?,local,setting,explicit,where did the man take ola?,into the hill.
1,the-adventures-of-gilla-na-chreck-an-gour,tom would not bring the flail into the palace ...,frightened .,how did the danes feel when they heard the sto...,local,feeling,explicit,how did ohers feel because they heard about th...,frightened.
2,the-mouse-the-bird-and-the-sausage,when the mouse had made up her fire and drawn ...,go to bed and sleep their fill till the next m...,what do they do after a meal ?,local,action,explicit,what did the bird do after the meal?,go to bed and sleep their fill till the next m...
3,prince-featherhead-and-the-princess-celandine,' why should you not ? ' said the princess tim...,the portrait .,what did the princess ask to be allowed to see ?,local,action,explicit,what did celandine decide was meant for the pr...,the portrait.
4,momotaro-story-of-son-of-peach,on hearing this the old man and his wife were ...,he had come out of a peach .,why did the couple name the boy 'momotaro' ?,local,causal relationship,explicit,why was momotaro the child's name?,he had come out of a peach.


In [15]:
def get_attribute_count(df):
    count = dict(Counter(df['attribute']))
    return count

In [16]:
train_attr_count = get_attribute_count(train_df)
train_attr_count

{'causal relationship': 2368,
 'character': 962,
 'action': 2694,
 'setting': 523,
 'outcome resolution': 811,
 'feeling': 824,
 'prediction': 366}

In [17]:
em_attr_count = get_attribute_count(em_ans_df)
em_attr_count

{'setting': 7461,
 'feeling': 12387,
 'action': 26284,
 'causal relationship': 11947,
 'prediction': 2208,
 'character': 14598,
 'outcome resolution': 4306}

# Complete Data Augment

In [23]:
# NOTE: Append complete data
aug_df = pd.DataFrame(select_rows, columns=ans_df.columns)
aug_df_data = aug_df.to_dict('records')
combine_full_data = train_data + aug_df_data
combine_full_df = pd.DataFrame(combine_full_data)
combine_full_attr_stats = get_attribute_count(combine_full_df)
combine_full_attr_stats

{'causal relationship': 14315,
 'character': 15560,
 'action': 28978,
 'setting': 7984,
 'outcome resolution': 5117,
 'feeling': 13211,
 'prediction': 2574}

In [24]:
with open('train_org_data_full_aug.json', 'w') as outfile:
    for d in combine_full_data:
        json.dump(d, outfile)
        outfile.write('\n')

# Selective Augment

In [25]:
total_count = {}
min_attr, min_count = '', 11111111
for attr_name, attr_count in train_attr_count.items():
    total_count[attr_name] = attr_count + em_attr_count[attr_name]
    if total_count[attr_name] < min_count:
        min_attr, min_count = attr_name, total_count[attr_name]
print(total_count)
print(min_attr, min_count)

{'causal relationship': 14315, 'character': 15560, 'action': 28978, 'setting': 7984, 'outcome resolution': 5117, 'feeling': 13211, 'prediction': 2574}
prediction 2574


In [26]:
max_augment = {}
for attr_name, attr_count in train_attr_count.items():
    max_augment[attr_name] = min_count - attr_count
max_augment

{'causal relationship': 206,
 'character': 1612,
 'action': -120,
 'setting': 2051,
 'outcome resolution': 1763,
 'feeling': 1750,
 'prediction': 2208}

In [27]:
# Selectively augment 
include_rows = []
for i, row in em_ans_df.iterrows():
    if max_augment[row['attribute']] > 0:
        include_rows.append(row.values)
        max_augment[row['attribute']] -= 1

In [28]:
include_df = pd.DataFrame(include_rows, columns=ans_df.columns)
include_df_data = include_df.to_dict('records')
print(len(include_df_data))

9590


In [29]:
combine_data = train_data + include_df_data
print(len(combine_data))

18138


In [30]:
combine_df = pd.DataFrame(combine_data)
combine_attr_stats = get_attribute_count(combine_df)
combine_attr_stats

{'causal relationship': 2574,
 'character': 2574,
 'action': 2694,
 'setting': 2574,
 'outcome resolution': 2574,
 'feeling': 2574,
 'prediction': 2574}

In [32]:
with open('train_org_data_bal_aug.json', 'w') as outfile:
    for d in combine_data:
        json.dump(d, outfile)
        outfile.write('\n')