In [6]:
import numpy as np
import networkx as nx
import os
import multiprocessing as mp
import math

from pathlib import Path
from tqdm.auto import tqdm

In [2]:
# Base directory for the input data, taken from the SLURM_TMPDIR environment
# variable. Indexing os.environ (instead of os.getenv) makes a missing variable
# fail with a KeyError that names it, rather than the opaque TypeError raised
# by Path(None).
root = Path(os.environ['SLURM_TMPDIR'])

In [8]:
# Load the packed graphs, one row per graph; each row is later expanded with
# np.unpackbits inside verify_dag.
dags_compressed = np.load(root / 'dags_7_final.npy')

In [9]:
# Display the array dimensions as a sanity check on what was loaded.
dags_compressed.shape

(1138779265, 7)

In [10]:
def verify_dag(batch, num_nodes=7):
    """Check every bit-packed adjacency matrix in ``batch`` for acyclicity.

    Each entry is expanded with ``np.unpackbits``, truncated to its first
    ``num_nodes ** 2`` bits, reshaped row-major into a
    ``(num_nodes, num_nodes)`` adjacency matrix, and tested with networkx.

    Args:
        batch: Sized iterable of 1-D ``uint8`` arrays (the only dtype
            ``np.unpackbits`` accepts), one packed adjacency matrix each.
        num_nodes: Number of nodes per graph. Defaults to 7, the value that
            was previously hard-coded, so single-argument callers (such as
            ``Pool.imap``) behave exactly as before.

    Returns:
        np.ndarray: Boolean array of length ``len(batch)``; element ``i`` is
        True iff networkx reports graph ``i`` as a directed acyclic graph.
    """
    num_bits = num_nodes ** 2
    is_dag = np.zeros((len(batch),), dtype=np.bool_)

    for i, dag_compressed in enumerate(batch):
        # `count` discards the trailing padding bits of the last packed byte.
        adjacency = np.unpackbits(dag_compressed, count=num_bits)
        adjacency = adjacency.reshape(num_nodes, num_nodes)

        # Non-zero entries become directed edges; acyclicity does not depend
        # on which index is treated as the source, so orientation is moot.
        graph = nx.from_numpy_array(adjacency, create_using=nx.DiGraph)
        is_dag[i] = nx.is_directed_acyclic_graph(graph)

    return is_dag

In [11]:
def batch_iterate(array, batch_size=1):
    """Yield consecutive slices of ``array`` holding up to ``batch_size`` items.

    The final slice is shorter when ``len(array)`` is not a multiple of
    ``batch_size``. Slices are produced lazily via ``array[start:stop]``, so
    the yielded type is whatever slicing ``array`` returns.
    """
    total = len(array)
    for start in range(0, total, batch_size):
        stop = start + batch_size
        yield array[start:stop]

In [13]:
num_dags = dags_compressed.shape[0]
batch_size = 2048 * 8  # graphs handed to a worker per task
num_workers = 60       # size of the worker pool

# Use `imap` rather than `imap_unordered`: results come back in submission
# order, so the per-batch flags stay aligned with the rows of
# `dags_compressed` when they are concatenated later. `imap_unordered` yields
# batches in completion order, which scrambles that correspondence.
with mp.Pool(num_workers) as pool:
    is_dag = list(tqdm(
        pool.imap(verify_dag, batch_iterate(dags_compressed, batch_size)),
        # One progress tick per batch; the last batch may be partial.
        total=math.ceil(num_dags / batch_size),
    ))

  0%|          | 0/69506 [00:00<?, ?it/s]

In [14]:
# Drop this name's reference to the input array; the per-batch results are
# already collected in `is_dag`.
del dags_compressed

In [15]:
# Join the per-batch boolean arrays into one flat array, in the order the
# batches appear in `is_dag`.
is_dag_2 = np.concatenate(is_dag, axis=0)

In [16]:
# The length should equal the number of graphs that were loaded.
is_dag_2.shape

(1138779265,)

In [18]:
# True only if every checked graph was reported acyclic.
np.all(is_dag_2)

True