In [1]:
# create a three layer MLP in torch

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import tqdm

In [2]:
# create a three layer MLP in torch
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        self.layer1 = nn.Linear(input_size, hidden_size, bias=True)
        self.layer2 = nn.Linear(hidden_size, hidden_size, bias=True)
        self.layer3 = nn.Linear(hidden_size, output_size, bias=True)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.relu(self.layer1(x))
        x = self.relu(self.layer2(x))
        x = self.softmax(self.layer3(x))
        return x




In [4]:
def generate_grouped_data(n):
    """Generates n data points with 4 balanced groups,
       (w==x, y==z), (w==x, y!=z), (w!=x, y==z), (w!=x, y!=z)
       For the heirarchical equality task""" 
    
    # Initialize the empty list for data and labels
    data, labels = [], []
    for i in range(n):
        # Define a pattern for each group based on the remainder when i is divided by 4
        group_pattern = i % 4
        
        # Depending on the pattern, generate w, x, y, z
        if group_pattern == 0:  # w==x, y!=z
            w = x = np.random.uniform(-0.5, 0.5)
            y, z = np.random.uniform(-0.5, 0.5, 2)
        elif group_pattern == 1:  # w==x, y==z
            w = x = np.random.uniform(-0.5, 0.5)
            y = z = np.random.uniform(-0.5, 0.5)
        elif group_pattern == 2:  # w!=x, y==z
            w, x = np.random.uniform(-0.5, 0.5, 2)
            y = z = np.random.uniform(-0.5, 0.5)
        elif group_pattern == 3:  # w!=x, y!=z
            w, x, y, z = np.random.uniform(-0.5, 0.5, 4)
        
        # Decide the label based on the given condition
        label = 1 if (w == x) == (y == z) else 0

        # Append the generated data and label to the respective lists
        data.append([w, x, y, z])
        labels.append(label)

    # Convert the lists to numpy arrays for further processing
    data = np.array(data)
    labels = np.array(labels)

    return data, labels


In [5]:
# generate the data
data, labels = generate_grouped_data(100000)

In [6]:
# train the network

# set the random seed
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# define the hyperparameters
input_size = 4
hidden_size = 16
output_size = 2
num_epochs = 10000
batch_size = 10
learning_rate = 0.01

# define the model
model = MLP(input_size, hidden_size, output_size)

# define the loss function and the optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=1000, verbose=True)

# convert the data and labels to torch tensors
data, labels = generate_grouped_data(100000)
data = torch.from_numpy(data).float()
labels = torch.from_numpy(labels).long()

val_data, val_labels = generate_grouped_data(100)
val_data = torch.from_numpy(val_data).float()
val_labels = torch.from_numpy(val_labels).long()

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

data = data.to(device)
labels = labels.to(device)
val_data = val_data.to(device)
val_labels = val_labels.to(device)
model = model.to(device)

# train the model
for epoch in range(num_epochs):
    # forward pass
    outputs = model(data)
    loss = criterion(outputs, labels)
    
    # backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step(loss)
    
    if (epoch+1) % 100 == 0:
        print ('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

# test the model
with torch.no_grad():
    # forward pass
    outputs = model(val_data)
    _, predicted = torch.max(outputs.data, 1)

    # compute the accuracy
    total = val_labels.size(0)
    correct = (predicted == val_labels).sum().item()
    print('Accuracy of the network on the 100 test data: {} %'.format(100 * correct / total))

Epoch [100/10000], Loss: 0.5622
Epoch [200/10000], Loss: 0.4283
Epoch [300/10000], Loss: 0.3946
Epoch [400/10000], Loss: 0.3785
Epoch [500/10000], Loss: 0.3682
Epoch [600/10000], Loss: 0.3628
Epoch [700/10000], Loss: 0.3797


## Causal DAGS

In [35]:
import networkx as nx
from collections import defaultdict
import random
import matplotlib.pyplot as plt

class DeterministicDAG:
    def __init__(self):
        self.G = nx.DiGraph()
        self.funcs = defaultdict(lambda: lambda x: x)
        self.validators = defaultdict(lambda: lambda x: True)
        self.samplers = defaultdict(lambda: None)

    def add_node(self, node, func=None, sampler=None, validator=None):
        self.G.add_node(node)
        if func is not None:
            self.funcs[node] = func
        if sampler is not None:
            self.samplers[node] = sampler
            if validator is None:
                valid_values = {sampler() for _ in range(1000)}
                self.validators[node] = lambda x: x in valid_values
        elif validator is not None:
            self.validators[node] = validator

    def add_edge(self, node1, node2):
        if node1 not in self.G or node2 not in self.G:
            raise ValueError(f"Nodes must exist before an edge can be created between them.")
        self.G.add_edge(node1, node2)

    def set_value(self, node, value):
        if not self.validators[node](value):
            raise ValueError(f"Invalid value {value} for node {node}. Value does not pass the validator.")
        self.G.nodes[node]['value'] = value

    def compute_value(self, node):
        if node in self.funcs:
            parents = list(self.G.predecessors(node))
            parent_values = [self.get_value(p) for p in parents] if parents else None
            if parent_values is None:
                self.set_value(node, None)
            elif len(parent_values) == 1:
                self.set_value(node, self.funcs[node](parent_values[0]))
            else:
                self.set_value(node, self.funcs[node](*parent_values))

    def set_inputs(self, **inputs):
        input_nodes = self.get_roots()
        for node, value in inputs.items():
            if node not in input_nodes:
                raise ValueError(f"Node {node} is not an input node.")
            self.G.nodes[node]['value'] = value

    def compute_all_values(self):
        [self.compute_value(node) for node in nx.topological_sort(self.G) if 'value' not in self.G.nodes[node]]

    def intervene(self, node, value):
        if not self.validators[node](value):
            raise ValueError(f"Invalid value {value} for node {node}")
        self.G.nodes[node]['value'] = value
        children = list(nx.descendants(self.G, node))
        for child in children:
            if 'value' in self.G.nodes[child]:
                del self.G.nodes[child]['value']
        self.compute_all_values()

    def set_inputs(self, **inputs):
        input_nodes = self.get_roots()
        for node, value in inputs.items():
            if node not in input_nodes:
                raise ValueError(f"Node {node} is not an input node.")
            self.G.nodes[node]['value'] = value

    def intervene_and_set_inputs(self, intervention_node, intervention_value, **inputs):
        self.set_inputs(**inputs)
        self.intervene(intervention_node, intervention_value)

    def sample_intervention(self, node):
        sampler = self.samplers[node]
        if sampler is None:
            raise ValueError(f"No sampler defined for node {node}")
        value = sampler()
        self.intervene(node, value)
        return value
    
    def get_roots(self):
        return [node for node, degree in self.G.in_degree() if degree == 0]

    def visualize(self):
        nx.draw_networkx(self.G, with_labels=True)
        plt.show()

In [91]:
class DeterministicDAG:
    def __init__(self):
        self.G = nx.DiGraph()
        self.funcs = defaultdict(lambda: lambda x: x)
        self.validators = defaultdict(lambda: lambda x: True)
        self.samplers = defaultdict(lambda: None)

    def add_node(self, node, func=None, sampler=None, validator=None):
        if sampler is None or validator is None:
            raise ValueError("Both a validator and a sampler must be provided when adding a node.")
        self.G.add_node(node)
        self.G.nodes[node]['intervened'] = False
        if func is not None:
            self.funcs[node] = func
        self.samplers[node] = sampler
        self.validators[node] = validator

    def add_edge(self, node1, node2):
        if node1 not in self.G or node2 not in self.G:
            raise ValueError(f"Nodes must exist before an edge can be created between them.")
        self.G.add_edge(node1, node2)

    def set_value(self, node, value):
        if not self.validators[node](value):
            raise ValueError(f"Invalid value {value} for node {node}. Value does not pass the validator.")
        self.G.nodes[node]['value'] = value
        self.G.nodes[node]['intervened'] = False

    def get_value(self, node):
        return self.G.nodes[node]['value']

    def compute_node(self, node):
        if self.G.nodes[node]['intervened']:
            return  # Do not recalculate the value of a node that has been intervened on
        if node in self.funcs:
            parents = list(self.G.predecessors(node))
            parent_values = [self.get_value(p) for p in parents] if parents else None
            if parent_values is None:
                self.set_value(node, None)
            elif len(parent_values) == 1:
                self.set_value(node, self.funcs[node](parent_values[0]))
            else:
                self.set_value(node, self.funcs[node](*parent_values))

    def run_inputs(self):
        for node in nx.topological_sort(self.G):
            if 'value' not in self.G.nodes[node]:
                self.compute_node(node)

    def intervene(self, node, value):
        if not self.validators[node](value):
            raise ValueError(f"Invalid value {value} for node {node}")
        self.G.nodes[node]['value'] = value
        self.G.nodes[node]['intervened'] = True  # Indicate that this node has been intervened on
        children = list(nx.descendants(self.G, node))
        for child in children:
            if 'value' in self.G.nodes[child]:
                del self.G.nodes[child]['value']
            self.G.nodes[child]['intervened'] = False  # Reset the 'intervened' attribute of the children nodes

    def run_with_intervention(self, intervention_node, intervention_value, **inputs):
        self.set_inputs(**inputs)
        self.intervene(intervention_node, intervention_value)
        self.run_inputs()

    def set_inputs(self, **inputs):
        input_nodes = self.get_roots()
        for node, value in inputs.items():
            if node not in input_nodes:
                raise ValueError(f"Node {node} is not an input node.")
            self.set_value(node, value)
        # Reset the 'intervened' attribute of all nodes
        for node in self.G.nodes:
            self.G.nodes[node]['intervened'] = False

    def reset_interventions(self):
        for node in self.G.nodes:
            self.G.nodes[node]['intervened'] = False

    def get_roots(self):
        return [node for node, degree in self.G.in_degree() if degree == 0]

    def visualize(self):
        nx.draw_networkx(self.G, with_labels=True)
        plt.show()


def float_validator(value):
    return -0.5 <= value <= 0.5

def float_sampler():
    return random.uniform(-0.5, 0.5)

def bool_validator(value):
    return value in {True, False}

def bool_sampler():
    return random.choice([True, False])

def copy_func(x):
    return x

def compare_func(value1, value2):
    return value1 == value2

dag = DeterministicDAG()

# Define the nodes
dag.add_node('x1', validator=float_validator, sampler=float_sampler)
dag.add_node('x2', validator=float_validator, sampler=float_sampler)
dag.add_node('x3', validator=float_validator, sampler=float_sampler)
dag.add_node('x4', validator=float_validator, sampler=float_sampler)
dag.add_node('c1', copy_func, validator=float_validator, sampler=float_sampler)
dag.add_node('c2', copy_func, validator=float_validator, sampler=float_sampler)
dag.add_node('c3', copy_func, validator=float_validator, sampler=float_sampler)
dag.add_node('c4', copy_func, validator=float_validator, sampler=float_sampler)
dag.add_node('b1', compare_func, validator=bool_validator, sampler=bool_sampler)
dag.add_node('b2', compare_func, validator=bool_validator, sampler=bool_sampler)
dag.add_node('y', compare_func, validator=bool_validator, sampler=bool_sampler)

# Define the edges
edges = [('x1', 'c1'), ('x2', 'c2'), ('x3', 'c3'), ('x4', 'c4'), 
         ('c1', 'b1'), ('c2', 'b1'), ('c3', 'b2'), ('c4', 'b2'), 
         ('b1', 'y'), ('b2', 'y')]
for edge in edges:
    dag.add_edge(*edge)


dag.set_inputs(x1=0.1, x2=0.2, x3=0.3, x4=0.4)

dag.run_inputs()

# print out all values
for node in dag.G.nodes:
    print(f"{node}: {dag.get_value(node)}")


import random
num_iterations = 100

for _ in range(num_iterations):
    # Select a node at random and perform a random intervention
    node = random.choice(list(dag.G.nodes))
    value = dag.samplers[node]()

    # reset all interventions
    dag.reset_interventions()

    dag.intervene(node, value)

    # Set input values
    input_values = {node: dag.samplers[node]() for node in dag.get_roots()}
    dag.set_inputs(**input_values)

    # Run the DAG
    dag.run_inputs()

    # Get the values of all nodes
    node_values = {node: dag.get_value(node) for node in dag.G.nodes}
    print(node_values)



x1: 0.1
x2: 0.2
x3: 0.3
x4: 0.4
c1: 0.1
c2: 0.2
c3: 0.3
c4: 0.4
b1: False
b2: False
y: True
{'x1': -0.016712011731899734, 'x2': -0.29636009562963006, 'x3': -0.4981568393843341, 'x4': 0.198991711803439, 'c1': 0.1, 'c2': 0.2, 'c3': 0.3, 'c4': 0.4, 'b1': False, 'b2': False, 'y': True}
{'x1': -0.4922233505641358, 'x2': -0.2014398789818792, 'x3': 0.2686342595428415, 'x4': 0.1289203785446209, 'c1': 0.1, 'c2': 0.2, 'c3': 0.3, 'c4': 0.4, 'b1': False, 'b2': False, 'y': True}
{'x1': -0.07579092900235929, 'x2': 0.25135249995764286, 'x3': -0.4074469530718122, 'x4': -0.002320403157995421, 'c1': 0.1, 'c2': 0.2, 'c3': 0.3, 'c4': 0.4, 'b1': True, 'b2': False, 'y': False}
{'x1': -0.09469489544564202, 'x2': 0.13206560152780367, 'x3': -0.4783806195793181, 'x4': -0.22992159512849142, 'c1': 0.1, 'c2': 0.2, 'c3': 0.3, 'c4': -0.22992159512849142, 'b1': True, 'b2': False, 'y': False}
{'x1': -0.4997467901067748, 'x2': -0.1015713920734681, 'x3': 0.3905618085087551, 'x4': 0.20996471789654325, 'c1': 0.1, 'c2': 0.

In [57]:
dag.validators['x1'].__code__

<code object <lambda> at 0x00000194193BD7C0, file "C:\Users\adamimos\AppData\Local\Temp\ipykernel_18600\4211659493.py", line 10>

In [30]:
dag.get_roots()

['x1', 'x2', 'x3', 'x4']

In [146]:
import platform
import matplotlib
import networkx as nx

print("Python version:", platform.python_version())
print("Matplotlib version:", matplotlib.__version__)
print("NetworkX version:", nx.__version__)


Python version: 3.9.16
Matplotlib version: 3.7.1
NetworkX version: 2.8.4
