In [69]:
from rashomon.extract_pools import lattice_edges
from rashomon.hasse import enumerate_policies, is_policies_sorted

import time
import numpy as np


%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [60]:
M = 3
R = 4

policies = enumerate_policies(M, R)

In [70]:
is_policies_sorted(policies)

True

Original version

In [76]:
num_iter = 10

start_time = time.time()
for _ in range(num_iter):
    edges = lattice_edges(policies)
end_time = time.time()

print(f"Average time over {num_iter} iterations: {(end_time - start_time) / num_iter} seconds")

Average time over 10 iterations: 0.01595907211303711 seconds


New version with dictionary

In [11]:
def lattice_edges2(policies: list) -> list[tuple[int, int]]:
    """
    Enumerate the Hasse adjacencies
    """
    policy_map = {tuple(pol): i for i, pol in enumerate(policies)}
    edges = set()
    for i, pol in enumerate(policies):
        pol_tuple = tuple(pol)
        for dim in range(len(pol_tuple)):
            for delta in [-1, 1]:
                neighbor = list(pol_tuple)
                neighbor[dim] += delta
                neighbor_tuple = tuple(neighbor)
                if neighbor_tuple in policy_map:
                    j = policy_map[neighbor_tuple]
                    edges.add((min(i, j), max(i, j)))

    return list(edges)

In [62]:
num_iter = 1

start_time = time.time()
for _ in range(num_iter):
    edges2 = lattice_edges2(policies)
end_time = time.time()

print(f"Average time over {num_iter} iterations: {(end_time - start_time) / num_iter} seconds")

Average time over 1 iterations: 0.0010111331939697266 seconds


Version 3 utilizing ordering of policies

In [74]:
num_iter = 10

start_time = time.time()
for _ in range(num_iter):
    edges3 = lattice_edges(policies, sorted=True, M=M, R=R)
end_time = time.time()

print(f"Average time over {num_iter} iterations: {(end_time - start_time) / num_iter} seconds")

Average time over 10 iterations: 0.0 seconds


In [75]:
set(edges) == set(edges3)

True

Version 4 without using policies

In [71]:
def lattice_edges4( M: int, R: np.ndarray | int) -> list[tuple[int,int]]:
    if isinstance(R, int):
        R = np.array([R] * M)
    
    # Compute total number of policies
    total_policies = np.prod(R)
    # for val in R:
    #     total_policies *= val

    # Precompute offsets
    offsets = [1] * M
    for i in reversed(range(M - 1)):
        offsets[i] = offsets[i + 1] * R[i + 1]

    edges = []
    for i in range(total_policies):
        for d in range(M):
            neighbor = i + offsets[d]
            # First check that neighbor is within total range
            if neighbor < total_policies:
                # Check if dimension d is not at its boundary
                block_index = i // offsets[d]
                # block_base is the start of that dimension's 'block'
                block_base = (block_index // R[d]) * R[d]
                # If block_index - block_base < R[d] - 1, we haven't reached the boundary
                if (block_index - block_base) < (R[d] - 1):
                    edges.append((i, neighbor))

    return edges

In [72]:
num_iter = 1

start_time = time.time()
for _ in range(num_iter):
    edges4 = lattice_edges4(M, R)
end_time = time.time()

print(f"Average time over {num_iter} iterations: {(end_time - start_time) / num_iter} seconds")

Average time over 1 iterations: 0.2124629020690918 seconds


In [73]:
set(edges) == set(edges4)

True