## Notebook for generating the dataset for toy example

In this notebook we generate the toy example data for sanity checking our methods. The data is $(x, y) \in \mathbb{R}^3 \times \{0, 1\}$ for any training pair in the dataset. We train on datasets with both explanations $(x_1, x_2)$ and $(x_2, x_3)$ (and any permutation thereof) and test on seperate test sets that only include one of the explanations. 

In [32]:
import os
import torch 

Generating the training examples, currently we only use the simple classification function: $\text{sign}(x_1 + x_2) > 0 \Rightarrow 1$.

In [8]:
def gen_data(func, train_size: int = 10000):
    training_examples = torch.randn((train_size, 3))
    training_examples[:, 2] = training_examples[:, 0]
    training_labels = func(training_examples).unsqueeze(dim=-1).int()

    return training_examples, training_labels

def simple_decs_func(training_examples: torch.Tensor):
    return (training_examples[:, 0] + training_examples[:, 1]) > 0

training_simple = gen_data(simple_decs_func)
training_simple[0][:5], training_simple[1][:5]

(tensor([[-0.5617,  0.0576, -0.5617],
         [ 0.0563, -0.6701,  0.0563],
         [ 0.1022,  0.9445,  0.1022],
         [-1.4113, -1.4094, -1.4113],
         [ 0.8871, -0.0460,  0.8871]]),
 tensor([[0],
         [0],
         [1],
         [0],
         [1]], dtype=torch.int32))

In [9]:
def gen_test_sets(func, test_size: int = 1000):
    expl_1_training = torch.randn((test_size, 3)) # explanation of (x1, x2)
    expl_1_labels = func(expl_1_training).unsqueeze(dim=-1).int()
    expl_1_training[:, 2] = torch.randn_like(expl_1_training[:, 2]) * 2 + 5 # rand out x3 so only x1, x2 can be used
    
    expl_2_training = torch.randn((test_size, 3)) # explanation of (x2, x3)
    expl_2_labels = func(expl_2_training).unsqueeze(dim=-1).int()
    expl_2_training[:, 2] = expl_2_training[:, 0]
    expl_2_training[:, 0] = torch.randn_like(expl_1_training[:, 0]) * 2 + 5 # rand out x1 so only x3, x2 can be used
    
    return (expl_1_training, expl_1_labels), (expl_2_training, expl_2_labels)

test_sets = gen_test_sets(simple_decs_func)

In [10]:
test_sets[0][0][:5], test_sets[0][1][:5]

(tensor([[ 0.5008,  0.5345,  3.5692],
         [ 0.5979, -0.0629,  4.6133],
         [-1.4677, -0.8215,  6.9923],
         [ 0.1728,  2.3955,  1.4230],
         [ 1.1203, -0.4027,  1.1848]]),
 tensor([[1],
         [1],
         [0],
         [1],
         [1]], dtype=torch.int32))

In [11]:
# use .pt as its more optimized for torch tensors
BASE_DIR = '../data/toy_example'
save_path_simple = os.path.join(BASE_DIR, 'simple_func_dataset.pt')
torch.save({'training_x': training_simple[0],
            'training_y': training_simple[1],
            'test1_x': test_sets[0][0],
            'test1_y': test_sets[0][1],
            'test2_x': test_sets[1][0],
            'test2_y': test_sets[1][1],
            }, save_path_simple)

We can also try a more complicated function: $\text{sin}(x_1 + x_2) > 0 \Rightarrow 0$.

In [13]:
def sine_decs(training_examples):
    return torch.sin(training_examples[:, 0] + training_examples[:, 2]) > 0

training_sine = gen_data(sine_decs)
test_sets = gen_test_sets(sine_decs)

save_path_sine = os.path.join(BASE_DIR, 'sine_func_dataset.pt')
torch.save({'training_x': training_sine[0],
            'training_y': training_sine[1],
            'test1_x': test_sets[0][0],
            'test1_y': test_sets[0][1],
            'test2_x': test_sets[1][0],
            'test2_y': test_sets[1][1],
            }, save_path_sine)

In [14]:
test_sets[1][0], test_sets[1][1][:5]

(tensor([[ 4.2298, -1.3915, -1.0504],
         [ 5.3804, -0.6354, -0.9392],
         [ 4.9189,  1.0054, -0.7379],
         ...,
         [ 9.0359, -0.2256,  0.0445],
         [ 3.7391, -0.3364,  1.1681],
         [ 5.3135,  0.0984, -0.6612]]),
 tensor([[1],
         [0],
         [0],
         [0],
         [0]], dtype=torch.int32))

In [5]:
import torch
from torch import nn, Tensor
from torch.nn.functional import softmax, gumbel_softmax
from sparse_generalization.layers.bern_mha import MultiHeadAttentionBern

# Assuming your class MultiHeadAttentionBern is already defined and imported

# Hyperparameters
batch_size = 2
seq_len = 4
embed_size = 8
num_heads = 1
dropout = 0.0
temp = 0.5
hard = True

# Dummy input
x = torch.randn(batch_size, seq_len, embed_size, requires_grad=True)

# Instantiate the layer
mha_bern = MultiHeadAttentionBern(embed_size=embed_size,
                                  num_heads=num_heads,
                                  dropout=dropout,
                                  temp=temp,
                                  hard=hard)

# Forward pass
output, adjacency = mha_bern(x, x, x)

print("Output shape:", output.shape)           # Expected: (batch_size, seq_len, embed_size)
print("Adjacency shape:", adjacency.shape)     # Expected: (batch_size, num_heads, seq_len, seq_len)

adjacency


Output shape: torch.Size([2, 4, 8])
Adjacency shape: torch.Size([2, 4, 4])


tensor([[[1., 0., 1., 0.],
         [0., 1., 1., 0.],
         [1., 1., 0., 1.],
         [1., 1., 0., 0.]],

        [[1., 1., 1., 1.],
         [1., 0., 0., 0.],
         [0., 0., 0., 1.],
         [1., 0., 1., 0.]]], grad_fn=<MeanBackward1>)