In [1]:
import json
from tqdm import tqdm
import os

import random
random.seed(1249)

if not os.path.exists('./data/custom/a2_exp/'):
    os.makedirs('./data/custom/a2_exp/')

#Set ReaSCAN data path
data_path = './data/ReaSCAN-v1.1/'

data = [json.loads(line) for line in open(data_path + 'ReaSCAN-compositional/train.json', 'r')]
random.shuffle(data)

In [2]:
two_relative_clause_count = 0
one_relative_clause_count = 0 
simple_count = 0

for item in data:
    if 'and' in item['input_command']:
        two_relative_clause_count += 1
    if ('that' in item['input_command']) and ('and' not in item['input_command']):
        one_relative_clause_count += 1
    if 'that' not in item['input_command']:
        simple_count += 1
        
print("Simple Command Percentage: ", (simple_count/len(data))*100)
print("One Rel Clause Command Percentage: ", (one_relative_clause_count/len(data))*100)
print("Two Rel Clause Command Percentage: ", (two_relative_clause_count/len(data))*100)

Simple Command Percentage:  15.593027521575923
One Rel Clause Command Percentage:  36.98941306820919
Two Rel Clause Command Percentage:  47.417559410214885


### Total number of red square distractors

In [3]:
count_all_red_square = 0

for sample in data:
    for col in range(0, 6):
        for row in range(0, 6):
            if sample['situation'][col][row][4:] == [0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]:
                count_all_red_square += 1
                
print('Total number of red square distractors: ', count_all_red_square)

Total number of red square distractors:  390307


### Total number of situations with red square distractors

In [4]:
count_situation_with_red_square = 0

for sample in data:
    sample_rs = False
    for col in range(0, 6):
        for row in range(0, 6):
            if sample['situation'][col][row][4:] == [0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]:
                sample_rs = True
    if sample_rs == True:
        count_situation_with_red_square += 1
        
print('Total number of situations with red square distractors: ', count_situation_with_red_square)
print('Total number of situations: ', len(data))
print('Percentage of train data where situation contains red squares: ', count_situation_with_red_square/len(data))

Total number of situations with red square distractors:  262312
Total number of situations:  539722
Percentage of train data where situation contains red squares:  0.48601316974294173


### Split data based on red square distractor

In [5]:
rs_data = []
non_rs_data = []

for sample in data:
    sample_rs = False
    for col in range(0, 6):
        for row in range(0, 6):
            if sample['situation'][col][row][4:] == [0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]:
                sample_rs = True
    if sample_rs == True:
        rs_data.append(sample)
    else:
        non_rs_data.append(sample)

In [6]:
print(len(rs_data))
print(len(non_rs_data))

262312
277410


#### RS data 112000

In [7]:
two_relative_clause_count = 0
one_relative_clause_count = 0 
simple_count = 0

for item in rs_data:
    if 'and' in item['input_command']:
        two_relative_clause_count += 1
    if ('that' in item['input_command']) and ('and' not in item['input_command']):
        one_relative_clause_count += 1
    if 'that' not in item['input_command']:
        simple_count += 1
        
print("Simple Command Percentage in rs data: ", (simple_count/len(rs_data))*100)
print("One Rel Clause Command Percentage in rs data: ", (one_relative_clause_count/len(rs_data))*100)
print("Two Rel Clause Command Percentage in rs data: ", (two_relative_clause_count/len(rs_data))*100)

Simple Command Percentage in rs data:  19.7566256976425
One Rel Clause Command Percentage in rs data:  44.63806459483363
Two Rel Clause Command Percentage in rs data:  35.605309707523865


In [8]:
random.shuffle(rs_data)

sampled_rs_data = []

two_relative_clause_count = 0
one_relative_clause_count = 0 
simple_count = 0

for item in rs_data:
    if ('and' in item['input_command']) and (two_relative_clause_count < 53104):#53104
        sampled_rs_data.append(item)
        two_relative_clause_count += 1
    if (('that' in item['input_command']) and ('and' not in item['input_command'])) and (one_relative_clause_count < 41426):#41426
        sampled_rs_data.append(item)
        one_relative_clause_count += 1
    if ('that' not in item['input_command']) and (simple_count < 17470):#17470
        sampled_rs_data.append(item)
        simple_count += 1
        
len(sampled_rs_data)

112000

In [9]:
two_relative_clause_count = 0
one_relative_clause_count = 0 
simple_count = 0

for item in sampled_rs_data:
    if 'and' in item['input_command']:
        two_relative_clause_count += 1
    if ('that' in item['input_command']) and ('and' not in item['input_command']):
        one_relative_clause_count += 1
    if 'that' not in item['input_command']:
        simple_count += 1
        
print("Simple Command Percentage in sampled rs data: ", (simple_count/len(sampled_rs_data))*100)
print("One Rel Clause Command Percentage in sampled rs data: ", (one_relative_clause_count/len(sampled_rs_data))*100)
print("Two Rel Clause Command Percentage in sampled rs data: ", (two_relative_clause_count/len(sampled_rs_data))*100)

Simple Command Percentage in sampled rs data:  15.598214285714285
One Rel Clause Command Percentage in sampled rs data:  36.987500000000004
Two Rel Clause Command Percentage in sampled rs data:  47.41428571428571


#### Non RS data 112000

In [10]:
two_relative_clause_count = 0
one_relative_clause_count = 0 
simple_count = 0

for item in non_rs_data:
    if 'and' in item['input_command']:
        two_relative_clause_count += 1
    if ('that' in item['input_command']) and ('and' not in item['input_command']):
        one_relative_clause_count += 1
    if 'that' not in item['input_command']:
        simple_count += 1
        
print("Simple Command Percentage in non rs data: ", (simple_count/len(non_rs_data))*100)
print("One Rel Clause Command Percentage in non rs data: ", (one_relative_clause_count/len(non_rs_data))*100)
print("Two Rel Clause Command Percentage in non rs data: ", (two_relative_clause_count/len(non_rs_data))*100)

Simple Command Percentage in non rs data:  11.656032587145381
One Rel Clause Command Percentage in non rs data:  29.757038318734004
Two Rel Clause Command Percentage in non rs data:  58.58692909412062


In [11]:
random.shuffle(non_rs_data)

sampled_non_rs_data = []

two_relative_clause_count = 0
one_relative_clause_count = 0 
simple_count = 0

for item in non_rs_data:
    if ('and' in item['input_command']) and (two_relative_clause_count < 53104):#53104
        sampled_non_rs_data.append(item)
        two_relative_clause_count += 1
    if (('that' in item['input_command']) and ('and' not in item['input_command'])) and (one_relative_clause_count < 41426):#41426
        sampled_non_rs_data.append(item)
        one_relative_clause_count += 1
    if ('that' not in item['input_command']) and (simple_count < 17470):#17470
        sampled_non_rs_data.append(item)
        simple_count += 1
        
len(sampled_non_rs_data)

112000

In [12]:
two_relative_clause_count = 0
one_relative_clause_count = 0 
simple_count = 0

for item in sampled_non_rs_data:
    if 'and' in item['input_command']:
        two_relative_clause_count += 1
    if ('that' in item['input_command']) and ('and' not in item['input_command']):
        one_relative_clause_count += 1
    if 'that' not in item['input_command']:
        simple_count += 1
        
print("Simple Command Percentage in sampled non rs data: ", (simple_count/len(sampled_non_rs_data))*100)
print("One Rel Clause Command Percentage in sampled non rs data: ", (one_relative_clause_count/len(sampled_non_rs_data))*100)
print("Two Rel Clause Command Percentage in sampled non rs data: ", (two_relative_clause_count/len(sampled_non_rs_data))*100)

Simple Command Percentage in sampled non rs data:  15.598214285714285
One Rel Clause Command Percentage in sampled non rs data:  36.987500000000004
Two Rel Clause Command Percentage in sampled non rs data:  47.41428571428571


#### RS data 200000

In [13]:
two_relative_clause_count = 0
one_relative_clause_count = 0 
simple_count = 0

for item in rs_data:
    if 'and' in item['input_command']:
        two_relative_clause_count += 1
    if ('that' in item['input_command']) and ('and' not in item['input_command']):
        one_relative_clause_count += 1
    if 'that' not in item['input_command']:
        simple_count += 1
        
print("Simple Command Percentage in rs data: ", (simple_count/len(rs_data))*100)
print("One Rel Clause Command Percentage in rs data: ", (one_relative_clause_count/len(rs_data))*100)
print("Two Rel Clause Command Percentage in rs data: ", (two_relative_clause_count/len(rs_data))*100)

Simple Command Percentage in rs data:  19.7566256976425
One Rel Clause Command Percentage in rs data:  44.63806459483363
Two Rel Clause Command Percentage in rs data:  35.605309707523865


In [14]:
random.shuffle(rs_data)

sampled_rs_data_200000 = []

two_relative_clause_count = 0
one_relative_clause_count = 0 
simple_count = 0

for item in rs_data:
    if ('and' in item['input_command']) and (two_relative_clause_count < 93397):#53104
        sampled_rs_data_200000.append(item)
        two_relative_clause_count += 1
    if (('that' in item['input_command']) and ('and' not in item['input_command'])) and (one_relative_clause_count < 74000):#41426
        sampled_rs_data_200000.append(item)
        one_relative_clause_count += 1
    if ('that' not in item['input_command']) and (simple_count < 32603):#17470
        sampled_rs_data_200000.append(item)
        simple_count += 1
        
len(sampled_rs_data_200000)

200000

In [15]:
two_relative_clause_count = 0
one_relative_clause_count = 0 
simple_count = 0

for item in sampled_rs_data_200000:
    if 'and' in item['input_command']:
        two_relative_clause_count += 1
    if ('that' in item['input_command']) and ('and' not in item['input_command']):
        one_relative_clause_count += 1
    if 'that' not in item['input_command']:
        simple_count += 1
        
print("Simple Command Percentage in sampled rs data: ", (simple_count/len(sampled_rs_data_200000))*100)
print("One Rel Clause Command Percentage in sampled rs data: ", (one_relative_clause_count/len(sampled_rs_data_200000))*100)
print("Two Rel Clause Command Percentage in sampled rs data: ", (two_relative_clause_count/len(sampled_rs_data_200000))*100)

Simple Command Percentage in sampled rs data:  16.3015
One Rel Clause Command Percentage in sampled rs data:  37.0
Two Rel Clause Command Percentage in sampled rs data:  46.698499999999996


#### Non RS data 200000

In [16]:
two_relative_clause_count = 0
one_relative_clause_count = 0 
simple_count = 0

for item in non_rs_data:
    if 'and' in item['input_command']:
        two_relative_clause_count += 1
    if ('that' in item['input_command']) and ('and' not in item['input_command']):
        one_relative_clause_count += 1
    if 'that' not in item['input_command']:
        simple_count += 1
        
print("Simple Command Percentage in non rs data: ", (simple_count/len(non_rs_data))*100)
print("One Rel Clause Command Percentage in non rs data: ", (one_relative_clause_count/len(non_rs_data))*100)
print("Two Rel Clause Command Percentage in non rs data: ", (two_relative_clause_count/len(non_rs_data))*100)

Simple Command Percentage in non rs data:  11.656032587145381
One Rel Clause Command Percentage in non rs data:  29.757038318734004
Two Rel Clause Command Percentage in non rs data:  58.58692909412062


In [17]:
random.shuffle(non_rs_data)

sampled_non_rs_data_200000 = []

two_relative_clause_count = 0
one_relative_clause_count = 0 
simple_count = 0

for item in non_rs_data:
    if ('and' in item['input_command']) and (two_relative_clause_count < 94945):#53104
        sampled_non_rs_data_200000.append(item)
        two_relative_clause_count += 1
    if (('that' in item['input_command']) and ('and' not in item['input_command'])) and (one_relative_clause_count < 73935):#41426
        sampled_non_rs_data_200000.append(item)
        one_relative_clause_count += 1
    if ('that' not in item['input_command']) and (simple_count < 31120):#17470
        sampled_non_rs_data_200000.append(item)
        simple_count += 1
        
len(sampled_non_rs_data_200000)

200000

In [18]:
two_relative_clause_count = 0
one_relative_clause_count = 0 
simple_count = 0

for item in sampled_non_rs_data_200000:
    if 'and' in item['input_command']:
        two_relative_clause_count += 1
    if ('that' in item['input_command']) and ('and' not in item['input_command']):
        one_relative_clause_count += 1
    if 'that' not in item['input_command']:
        simple_count += 1
        
print("Simple Command Percentage in sampled non rs data: ", (simple_count/len(sampled_non_rs_data_200000))*100)
print("One Rel Clause Command Percentage in sampled non rs data: ", (one_relative_clause_count/len(sampled_non_rs_data_200000))*100)
print("Two Rel Clause Command Percentage in sampled non rs data: ", (two_relative_clause_count/len(sampled_non_rs_data_200000))*100)

Simple Command Percentage in sampled non rs data:  15.559999999999999
One Rel Clause Command Percentage in sampled non rs data:  36.9675
Two Rel Clause Command Percentage in sampled non rs data:  47.472500000000004


#### Random 200000

In [19]:
data_200000 = random.sample(data, 200000)

In [20]:
two_relative_clause_count = 0
one_relative_clause_count = 0 
simple_count = 0

for item in data_200000:
    if 'and' in item['input_command']:
        two_relative_clause_count += 1
    if ('that' in item['input_command']) and ('and' not in item['input_command']):
        one_relative_clause_count += 1
    if 'that' not in item['input_command']:
        simple_count += 1
        
print("Simple Command Percentage in sampled non rs data: ", (simple_count/len(data_200000))*100)
print("One Rel Clause Command Percentage in sampled non rs data: ", (one_relative_clause_count/len(data_200000))*100)
print("Two Rel Clause Command Percentage in sampled non rs data: ", (two_relative_clause_count/len(data_200000))*100)

Simple Command Percentage in sampled non rs data:  15.6095
One Rel Clause Command Percentage in sampled non rs data:  36.8375
Two Rel Clause Command Percentage in sampled non rs data:  47.553


### Write data

In [21]:
with open('./data/custom/a2_exp/sampled_rs_data_112000.json', 'w') as f:
    for line in sampled_rs_data:
        f.write(json.dumps(line) + '\n')

In [22]:
with open('./data/custom/a2_exp/sampled_non_rs_data_112000.json', 'w') as f:
    for line in sampled_non_rs_data:
        f.write(json.dumps(line) + '\n')

In [21]:
with open('./data/custom/a2_exp/sampled_rs_data_200000.json', 'w') as f:
    for line in sampled_rs_data_200000:
        f.write(json.dumps(line) + '\n')

In [22]:
with open('./data/custom/a2_exp/sampled_non_rs_data_200000.json', 'w') as f:
    for line in sampled_non_rs_data_200000:
        f.write(json.dumps(line) + '\n')

In [23]:
with open('./data/custom/a2_exp/train_random_200000.json', 'w') as f:
    for line in data_200000:
        f.write(json.dumps(line) + '\n')