In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from dag_gflownet.scores.base import BaseScore
from dag_gflownet.scores.priors import UniformPrior
from dag_gflownet.scores.bge_score import BGeScore
import numpy as np
import pandas as pd
import networkx as nx
from tqdm import tqdm

In [3]:
data = pd.read_csv('/home/zehao/causal/jax-dag-gflownet/R_e_data/causal_BH_ell.csv')
data = (data - data.mean()) / data.std()  # Standardize data

In [4]:
bge = BGeScore(data, UniformPrior())

In [5]:
# generate a random DAG of 7 nodes
G = nx.fast_gnp_random_graph(7, 0.5, directed=True)

In [6]:
# show the adjacency matrix
nx.to_numpy_array(G)

array([[0., 1., 1., 0., 1., 0., 1.],
       [0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 1., 0., 0., 1.],
       [0., 1., 0., 0., 0., 1., 1.],
       [0., 1., 1., 1., 0., 1., 0.],
       [1., 1., 0., 1., 1., 0., 1.],
       [1., 1., 1., 0., 0., 1., 0.]])

In [7]:
bge.score(G)

AttributeError: 'BGeScore' object has no attribute '_cache_local_scores'

In [5]:
correct_num_dags=[1, 1, 3, 25, 543, 29281, 3781503, 1138779265]

In [6]:
import itertools

In [10]:
def generate_all_dags_with_adj(n):
    graphs = []
    for adj_mat in tqdm(itertools.product([0, 1], repeat=n**2), total=2**(n**2)):
        adj_mat = np.array(adj_mat).reshape(n, n)
        G = nx.from_numpy_array(adj_mat, create_using=nx.DiGraph)
        if nx.is_directed_acyclic_graph(G):
            graphs.append(G)

    return graphs

In [65]:
len(generate_all_dags_with_adj(3)), correct_num_dags[3]

100%|██████████| 512/512 [00:00<00:00, 27975.87it/s]


(25, 25)

In [66]:
len(generate_all_dags_with_adj(4)), correct_num_dags[4]

100%|██████████| 65536/65536 [00:01<00:00, 33235.30it/s]


(543, 543)

In [None]:
len(generate_all_dags_with_adj(7))
## it will take way too long if starting from adjacency matrix.

  0%|          | 56732/562949953421312 [00:02<5677483:28:57, 27543.01it/s]Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7fc487f25010>>
Traceback (most recent call last):
  File "/home/zehao/anaconda3/envs/causal/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 770, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 
  0%|          | 364984/562949953421312 [00:13<5971303:46:33, 26187.75it/s]

In [59]:
# let's try to start from edges
from itertools import combinations

def contains_reverse_pairs(tup):
    # Convert each element into a tuple if not already one, then into a set for efficient searching
    seen = set()
    
    for item in tup:
        # Check if the reverse of the current item exists in the set
        if (item[1], item[0]) in seen:
            return True  # A reverse pair exists
        seen.add(item)  # Add the current item to the set for future checks
    
    return False  # No reverse pairs found
            

def generate_all_dags(n):
    """Generate all possible DAGs for n labeled nodes using networkx."""
    G = nx.DiGraph()
    G.add_nodes_from(range(n))  # Add n nodes to the graph
    all_edges = [(i, j) for i in range(n) for j in range(n) if i != j]
    all_dags = np.zeros((correct_num_dags[n], n, n),dtype=np.int8)
    ind=0

    for r in tqdm(range(len(all_edges) + 1)):
        for edges in combinations(all_edges, r):
            if not contains_reverse_pairs(edges):
                G.clear_edges()
                G.add_edges_from(edges)
                if nx.is_directed_acyclic_graph(G):
                    # convert to numpy array
                    adj_mat = nx.to_numpy_array(G)
                    # make sure the datatype is int 
                    #adj_mat = adj_mat.astype(int)
                    
                    all_dags[ind] = adj_mat
                    ind+=1

    return all_dags

In [62]:
n = 3
dags = generate_all_dags(n)
len(dags), correct_num_dags[n]

100%|██████████| 7/7 [00:00<00:00, 2256.04it/s]


(25, 25)

In [63]:
n = 4
dags = generate_all_dags(n)
len(dags), correct_num_dags[n]

100%|██████████| 13/13 [00:00<00:00, 477.38it/s]


(543, 543)

In [64]:
n = 5
dags = generate_all_dags(n)
len(dags), correct_num_dags[n]

100%|██████████| 21/21 [00:01<00:00, 10.93it/s]


(29281, 29281)

In [65]:
n = 6
dags = generate_all_dags(n)
len(dags), correct_num_dags[n]

100%|██████████| 31/31 [17:23<00:00, 33.66s/it]


(3781503, 3781503)