In [1]:
import os
import math
import string
import random
import numpy as np
from collections import defaultdict, Counter
import pandas as pd
from IPython.display import Markdown, display

In [2]:
np.random.seed(1234)

In [3]:
def printbold(string):
    display(Markdown('{}{}{}'.format('**', string, '**')))

In [4]:
def sign(x):
    if x < 0:
        return '-'
    elif x > 0:
        return '+'
    else:
        return 'none'

In [5]:
class Setup(object):
    def __init__(self, base_dir='/mnt/task_runtime/'):
        self.base_dir = base_dir 
        
        data_folder_path, results_folder_path = os.path.join(self.base_dir, 'lego-data'), os.path.join(self.base_dir, 'lego-results')
        for folder_path in [data_folder_path, results_folder_path]:
            if not os.path.exists(folder_path):
                os.makedirs(os.path.join(folder_path))
    
    def generate_datapoint(self, length):
        alphabet = list(string.ascii_lowercase)
        
        symbols = random.sample(alphabet, length)
        values = 2 * np.random.randint(2, size=(length, )) - 1
        
        pairs = list(zip(symbols, values))
        symbol_value_map = dict(pairs)
            
        eqs = ['{}={}'.format(symbols[0], values[0])]
        for i in range(1, length):
            eqs.append('{}={}{}'.format(symbols[i], sign(values[i]/values[i-1]), symbols[i-1]))
        np.random.shuffle(eqs)
        input_eqs = '; '.join(eqs)
        
        input_symbols_in_order = []
        for eq in eqs:
            input_symbols_in_order.append(eq.split('=')[0])
        
        input_order_labels = []
        for symbol in input_symbols_in_order: 
            input_order_labels.append('{}={}'.format(symbol, symbol_value_map[symbol]))
        
        resolution_order_labels = []
        for pair in pairs:
            resolution_order_labels.append('='.join([pair[0], str(pair[1])]))  
        return input_eqs, ';'.join(resolution_order_labels), ';'.join(input_order_labels), ';'.join(input_order_labels)+';aux: {}'.format(','.join(symbols))
    
    def train_test_split(self, df, frac=0.8):
        split_point = int(math.floor(frac * len(df)))
        df = df.sample(frac=1.)
        train_df, test_df = df[:split_point], df[split_point:]
        return train_df, test_df
        
    def save(self, dataset, path):
        dataset_str = '\n'.join(dataset)
        with open(path, 'w') as f:
            f.write(dataset_str)
    def get_path(self, obj_type, fname):
        if obj_type not in ['lego-data', 'lego-results']:
            raise ValueError('invalid object type!')
        return os.path.join(self.base_dir, obj_type, fname)
            
    def generate_dataset(self, train_seq_length, test_seq_length, ood_test_seq_length, persist=False, num_train_samples=8000,  num_test_samples=500, num_ood_test_samples=500, train_fname='lego-train.txt', test_fname='lego-test.txt', ood_test_fname='lego-ood-test.txt'):
        dataset, ood_test_dataset = set([]), set([])
        
        for _ in range(num_train_samples + num_test_samples):
            dataset.add(self.generate_datapoint(train_seq_length)) 
        
        for _ in range(num_ood_test_samples):
            ood_test_dataset.add(self.generate_datapoint(ood_test_seq_length))  
            
        df, ood_test_df = pd.DataFrame(dataset, columns=['input', 'resolution-order-labels', 'input-order-labels', 'labels-with-aux']), pd.DataFrame(ood_test_dataset, columns=['input', 'resolution-order-labels', 'input-order-labels', 'labels-with-aux'])  
        train_df, test_df = self.train_test_split(df)
        
        if persist:
            train_df.to_csv(self.get_path('lego-data', train_fname))
            test_df.to_csv(self.get_path('lego-data', test_fname))
            ood_test_df.to_csv(self.get_path('lego-data', ood_test_fname))
            
        return train_df, test_df, ood_test_df   

In [6]:
setup = Setup()

In [7]:
setup.generate_datapoint(5)

('i=-t; t=-g; g=+l; l=1; f=-i',
 'l=1;g=1;t=-1;i=1;f=-1',
 'i=1;t=-1;g=1;l=1;f=-1',
 'i=1;t=-1;g=1;l=1;f=-1;aux: l,g,t,i,f')

In [8]:
train_df, test_df, ood_test_df = setup.generate_dataset(6, 6, 12, persist=True)

In [9]:
train_df.head()

Unnamed: 0,input,resolution-order-labels,input-order-labels,labels-with-aux
5642,t=-q; r=+t; a=+p; q=+s; p=-r; s=1,s=1;q=1;t=-1;r=-1;p=1;a=1,t=-1;r=-1;a=1;q=1;p=1;s=1,"t=-1;r=-1;a=1;q=1;p=1;s=1;aux: s,q,t,r,p,a"
2515,d=-v; r=+d; s=+h; v=+s; w=1; h=-w,w=1;h=-1;s=-1;v=-1;d=1;r=1,d=1;r=1;s=-1;v=-1;w=1;h=-1,"d=1;r=1;s=-1;v=-1;w=1;h=-1;aux: w,h,s,v,d,r"
1815,o=-k; q=-o; d=-1; c=+q; m=-d; k=-m,d=-1;m=1;k=-1;o=1;q=-1;c=-1,o=1;q=-1;d=-1;c=-1;m=1;k=-1,"o=1;q=-1;d=-1;c=-1;m=1;k=-1;aux: d,m,k,o,q,c"
3346,h=-g; g=+p; r=+c; c=-1; p=+l; l=+r,c=-1;r=-1;l=-1;p=-1;g=-1;h=1,h=1;g=-1;r=-1;c=-1;p=-1;l=-1,"h=1;g=-1;r=-1;c=-1;p=-1;l=-1;aux: c,r,l,p,g,h"
4744,l=+w; j=1; w=-f; f=+k; k=+p; p=-j,j=1;p=-1;k=-1;f=-1;w=1;l=1,l=1;j=1;w=1;f=-1;k=-1;p=-1,"l=1;j=1;w=1;f=-1;k=-1;p=-1;aux: j,p,k,f,w,l"


In [10]:
test_df.head()

Unnamed: 0,input,resolution-order-labels,input-order-labels,labels-with-aux
7340,n=-1; o=-n; t=+o; e=-i; i=-m; m=+t,n=-1;o=1;t=1;m=1;i=-1;e=1,n=-1;o=1;t=1;e=1;i=-1;m=1,"n=-1;o=1;t=1;e=1;i=-1;m=1;aux: n,o,t,m,i,e"
5631,l=-j; c=+y; q=+c; y=-1; j=+q; d=-l,y=-1;c=-1;q=-1;j=-1;l=1;d=-1,l=1;c=-1;q=-1;y=-1;j=-1;d=-1,"l=1;c=-1;q=-1;y=-1;j=-1;d=-1;aux: y,c,q,j,l,d"
2093,m=-z; p=+h; z=-1; e=-g; h=+e; g=+m,z=-1;m=1;g=1;e=-1;h=-1;p=-1,m=1;p=-1;z=-1;e=-1;h=-1;g=1,"m=1;p=-1;z=-1;e=-1;h=-1;g=1;aux: z,m,g,e,h,p"
5872,s=+r; r=-g; h=+n; l=-h; g=-l; n=-1,n=-1;h=-1;l=1;g=-1;r=1;s=1,s=1;r=1;h=-1;l=1;g=-1;n=-1,"s=1;r=1;h=-1;l=1;g=-1;n=-1;aux: n,h,l,g,r,s"
7522,a=+k; b=-a; s=+l; k=+s; p=+b; l=-1,l=-1;s=-1;k=-1;a=-1;b=1;p=1,a=-1;b=1;s=-1;k=-1;p=1;l=-1,"a=-1;b=1;s=-1;k=-1;p=1;l=-1;aux: l,s,k,a,b,p"


In [11]:
ood_test_df.head()

Unnamed: 0,input,resolution-order-labels,input-order-labels,labels-with-aux
0,g=+e; j=-d; d=-m; e=1; m=+g; a=-t; t=-o; w=+n;...,e=1;g=1;m=1;d=-1;j=1;c=-1;q=-1;o=1;t=-1;a=1;n=...,g=1;j=1;d=-1;e=1;m=1;a=1;t=-1;w=-1;c=-1;o=1;n=...,g=1;j=1;d=-1;e=1;m=1;a=1;t=-1;w=-1;c=-1;o=1;n=...
1,q=+g; p=-n; w=+i; h=-k; o=+p; i=-y; k=-1; g=+o...,k=-1;h=1;y=-1;i=1;w=1;n=1;p=-1;o=-1;g=-1;q=-1;...,q=-1;p=-1;w=1;h=1;o=-1;i=1;k=-1;g=-1;a=1;c=-1;...,q=-1;p=-1;w=1;h=1;o=-1;i=1;k=-1;g=-1;a=1;c=-1;...
2,y=+j; e=-l; g=-q; q=-x; x=-h; l=-y; j=+i; h=+n...,n=-1;h=-1;x=1;q=-1;g=1;b=-1;t=1;i=-1;j=-1;y=-1...,y=-1;e=-1;g=1;q=-1;x=1;l=1;j=-1;h=-1;t=1;n=-1;...,y=-1;e=-1;g=1;q=-1;x=1;l=1;j=-1;h=-1;t=1;n=-1;...
3,u=-n; g=-w; w=+z; r=+l; a=-g; n=+a; p=+c; l=+v...,v=-1;l=-1;r=-1;c=1;p=1;z=1;w=1;g=-1;a=1;n=1;u=...,u=-1;g=-1;w=1;r=-1;a=1;n=1;p=1;l=-1;v=-1;z=1;c...,u=-1;g=-1;w=1;r=-1;a=1;n=1;p=1;l=-1;v=-1;z=1;c...
4,o=+h; s=+a; v=+b; c=-w; b=-1; e=-o; a=+v; f=+i...,b=-1;v=-1;a=-1;s=-1;i=1;f=1;z=1;w=-1;c=1;h=-1;...,o=-1;s=-1;v=-1;c=1;b=-1;e=1;a=-1;f=1;i=1;h=-1;...,o=-1;s=-1;v=-1;c=1;b=-1;e=1;a=-1;f=1;i=1;h=-1;...


In [12]:
train_df.shape[0], test_df.shape[0], ood_test_df.shape[0]

(6800, 1700, 500)

In [14]:
ood_test_df.iloc[0, :]['input']

'g=+e; j=-d; d=-m; e=1; m=+g; a=-t; t=-o; w=+n; c=-j; o=-q; n=-a; q=+c'

In [15]:
ood_test_df.iloc[0, :]['resolution-order-labels']

'e=1;g=1;m=1;d=-1;j=1;c=-1;q=-1;o=1;t=-1;a=1;n=-1;w=-1'