## Notebook for generating the dataset for toy example

In this notebook we generate the toy example data for sanity checking our methods. The data is $(x, y) \in \mathbb{R}^3 \times \{0, 1\}$ for any training pair in the dataset. We train on datasets with both explanations $(x_1, x_2)$ and $(x_2, x_3)$ (and any permutation thereof) and test on seperate test sets that only include one of the explanations. 

In [1]:
import os
import torch 
BASE_DIR = '../data/toy_example'

Generating the training examples, currently we only use the simple classification function: $\text{sign}(x_1 + x_2) > 0 \Rightarrow 1$.

In [18]:
def gen_data(func, train_size: int = 10000):
    training_examples = torch.randn((train_size, 3))
    training_examples[:, 2] = training_examples[:, 0]
    training_labels = func(training_examples).unsqueeze(dim=-1).int()

    return training_examples, training_labels

def simple_decs_func(training_examples: torch.Tensor):
    return (training_examples[:, 0] + training_examples[:, 1]) > 0

training_simple = gen_data(simple_decs_func)
training_simple[0][:5], training_simple[1][:5]

(tensor([[-1.1507, -1.1341, -1.1507],
         [ 0.6236, -0.8684,  0.6236],
         [-0.9229,  1.8587, -0.9229],
         [-1.1226, -1.3639, -1.1226],
         [-0.2069, -1.6822, -0.2069]]),
 tensor([[0],
         [0],
         [1],
         [0],
         [0]], dtype=torch.int32))

In [22]:
def gen_test_sets(func, test_size: int = 1000, inter: bool = False):
    expl_1_training = torch.randn((test_size, 3)) # explanation of (x1, x2)
    expl_1_training[:, 2] = expl_1_training[:, 0]
    expl_1_labels = func(expl_1_training).unsqueeze(dim=-1).int()
    if inter:
        expl_1_training[:, 2] = 20
    else:
        expl_1_training[:, 2] = torch.randn_like(expl_1_training[:, 2]) * 2 + 5 # rand out x3 so only x1, x2 can be used
    
    expl_2_training = torch.randn((test_size, 3)) # explanation of (x2, x3)
    expl_2_training[:, 2] = expl_2_training[:, 0]
    expl_2_labels = func(expl_2_training).unsqueeze(dim=-1).int()
    if inter:
        expl_2_training[:, 0] = 20
    else:
        expl_2_training[:, 0] = torch.randn_like(expl_1_training[:, 0]) * 2 + 5 # rand out x1 so only x3, x2 can be used
    
    return (expl_1_training, expl_1_labels), (expl_2_training, expl_2_labels)

test_sets = gen_test_sets(simple_decs_func, inter=True)

In [23]:
test_sets[0][0][:5], test_sets[0][1][:5]

(tensor([[-0.8198, -0.1908, 20.0000],
         [-0.2870, -1.2452, 20.0000],
         [ 0.6514,  1.1787, 20.0000],
         [ 1.2639, -0.4919, 20.0000],
         [ 1.6882, -0.2670, 20.0000]]),
 tensor([[0],
         [0],
         [1],
         [1],
         [1]], dtype=torch.int32))

In [24]:
# use .pt as its more optimized for torch tensors
save_path_simple = os.path.join(BASE_DIR, 'simple_func_dataset_int.pt')
torch.save({'training_x': training_simple[0],
            'training_y': training_simple[1],
            'test1_x': test_sets[0][0],
            'test1_y': test_sets[0][1],
            'test2_x': test_sets[1][0],
            'test2_y': test_sets[1][1],
            }, save_path_simple)

We can also try a more complicated function: $\text{sin}(x_1 + x_2) > 0 \Rightarrow 0$.

In [26]:
def sine_decs(training_examples):
    return torch.sin(training_examples[:, 0] + training_examples[:, 2]) > 0

training_sine = gen_data(sine_decs)
test_sets = gen_test_sets(sine_decs, inter=True)

save_path_sine = os.path.join(BASE_DIR, 'sine_func_dataset_int.pt')
torch.save({'training_x': training_sine[0],
            'training_y': training_sine[1],
            'test1_x': test_sets[0][0],
            'test1_y': test_sets[0][1],
            'test2_x': test_sets[1][0],
            'test2_y': test_sets[1][1],
            }, save_path_sine)

In [27]:
test_sets[1][0], test_sets[1][1][:5]

(tensor([[20.0000,  0.2269, -0.2570],
         [20.0000, -0.5824, -0.8815],
         [20.0000, -1.6443, -1.5357],
         ...,
         [20.0000,  0.8063,  0.2168],
         [20.0000,  0.6466,  1.0724],
         [20.0000, -0.4007,  0.7072]]),
 tensor([[0],
         [0],
         [0],
         [0],
         [1]], dtype=torch.int32))

We can also try a more RND-like approach. The goal of this is to make it so that the function that our methods have to learn is much more complicated than the ones above. 

In [29]:
import torch.nn as nn

def nn_func(training_examples: torch.Tensor):
    net = nn.Sequential(
        nn.Linear(2, 512), 
        nn.ReLU(),
        nn.Linear(512, 512), 
        nn.ReLU(),
        nn.Linear(512, 512),
        nn.ReLU(), 
        nn.Linear(512, 1), 
        nn.ReLU(),
        nn.Sigmoid(),
    )
    return (net(training_examples[:, 0:2]) > 0.5).squeeze()

training_nn = gen_data(nn_func)
test_sets = gen_test_sets(nn_func, inter=True)

save_path_sine = os.path.join(BASE_DIR, 'nn_func_dataset_int.pt')
torch.save({'training_x': training_nn[0],
            'training_y': training_nn[1],
            'test1_x': test_sets[0][0],
            'test1_y': test_sets[0][1],
            'test2_x': test_sets[1][0],
            'test2_y': test_sets[1][1],
            }, save_path_sine)
