In [3]:
from dataset import SpuriousBooleanSampling
from utils import generate_parity_func
import torch

# Define Function

In [7]:
# The dataset needs two functions: one for the core feature and one for the spurious feature.
# Both are built here with the parity-function helper provided in our package.
# generate_parity_func takes the indices of the coordinates the parity function is defined on.
core_function = generate_parity_func([0, 1, 2])
spurious_function = generate_parity_func([0, 1])
# Alternatively, you can define your own boolean function. It should take a boolean torch tensor as input and return the result as a boolean torch tensor.
# For example, a boolean function that gives +1 if the sum of x is at least 2 (and -1 otherwise) is defined as follows:
def example(x):
    """Threshold the per-row sum of ``x`` at 2.

    Sums ``x`` along dimension 1 and returns a tensor that holds +1 wherever
    that sum is at least 2 and -1 everywhere else.
    """
    row_totals = x.sum(dim=1)
    plus_one, minus_one = torch.tensor(1), torch.tensor(-1)
    return torch.where(row_totals >= 2, plus_one, minus_one)

# Construct Dataset

In [12]:
# Params:
# core_len: The number of variables the core function is defined on.
# spurious_len: The number of variables the spurious function is defined on.
# core_func, spurious_func: Functions defined following the guidance above.
# c: Corresponds to the \lambda in our paper, also called the confounder strength.
# sample_num: The number of training samples drawn for one epoch. For efficiency reasons, we implement it this way.
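# bypass_bias_check: if set to False, we check whether the spurious function is severely biased and report an error if so;
# setting it to True presumably skips this check (the name says "bypass") -- TODO: confirm against SpuriousBooleanSampling.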
# There are 4 options for sampling_method: ["pure", "on_request", "buffer", "auto"].
# "pure": the dataset is constructed when initialized, which means the whole sample space is iterated through to get the labels of
# the spurious and core functions. Thus, if core_len + spurious_len is small (less than 30), we should choose "pure" for efficiency.
# "on_request": the dataset is constructed in place. At each epoch we repeatedly draw uniform random batches until we have enough samples that satisfy
#  the distribution D_\lambda.
# "buffer": not implemented yet.
# "auto": if the spurious length is less than 15 we use "pure"; otherwise we use "on_request".

# IMPORTANT: For efficiency reasons, the dataset is currently constructed in a way that differs from the distribution defined in the paper. The two
# distributions are identical only if the spurious function is UNBIASED. Note that this holds for all the
# spurious functions we studied in the paper. An update that makes the distribution identical to the one
# defined in the paper, regardless of whether the spurious function is unbiased, should come soon.
# Build the dataset from the parity core function and the custom `example` spurious function.
dataset = SpuriousBooleanSampling(
    core_len=10,
    spurious_len=10,
    core_func=core_function,
    spurious_func=example,
    c=0.9,
    sample_num=10000,
    batch_size=64,
    sampling_method="pure",
    bypass_bias_check=False,
    device="cpu",
)

Spurious function has bias ratio: 0.37059998512268066


In [9]:
# Peek at the first batch only: each item yielded by the dataset unpacks into
# the inputs plus the core, group and spurious labels.
for batch in dataset:
    x, core_label, group_label, spurious_label = batch
    print(x)
    break

tensor([[-1., -1., -1.,  ...,  1., -1., -1.],
        [-1., -1.,  1.,  ...,  1.,  1., -1.],
        [-1.,  1.,  1.,  ..., -1., -1.,  1.],
        ...,
        [-1., -1.,  1.,  ...,  1.,  1., -1.],
        [ 1., -1., -1.,  ...,  1.,  1.,  1.],
        [ 1., -1., -1.,  ...,  1.,  1., -1.]])
