In [1]:
import numpy as np

class CPT:
    """A probability table whose array axes are labelled with variable names.

    Attributes:
        array: The probability values as a NumPy array.
        axes: Mapping from variable name to the array axis that indexes it.
    """

    def __init__(self, array, axes):
        """Wrap ``array`` and label its axes.

        Args:
            array: Array-like of probabilities; converted with ``np.asarray``.
            axes: Either a sequence of names, one per axis in axis order
                (e.g. ``["v1", "h1"]``), or a ``{name: axis}`` dict such as
                another CPT's ``axes``.

        Raises:
            ValueError: if names repeat, a dict's axis indices are not
                ``0..n-1``, or the number of names differs from ``array.ndim``.
        """
        self.array = np.asarray(array)
        if isinstance(axes, dict):
            # Order names by their axis index, not by dict insertion order, so a
            # mapping built in any order still labels the right axes.
            if sorted(axes.values()) != list(range(len(axes))):
                raise ValueError(f"axis indices must be 0..{len(axes) - 1}, got {axes}")
            names = sorted(axes, key=axes.get)
        else:
            names = list(axes)
        if len(set(names)) != len(names):
            # A repeated name would silently drop an axis from the mapping below.
            raise ValueError(f"duplicate axis names: {names}")
        if len(names) != self.array.ndim:
            raise ValueError(
                f"got {len(names)} axis names {names} for an array with "
                f"{self.array.ndim} dimensions"
            )
        self.axes = {name: axis for axis, name in enumerate(names)}

    def other_axes(self, other_axis_name):
        """Return the axis indices of every variable except ``other_axis_name``."""
        return tuple(axis for name, axis in self.axes.items() if other_axis_name != name)

    def is_joint_prob(self):
        """True if all entries together sum to 1 (within ``np.isclose`` tolerance)."""
        return np.all(np.isclose(np.sum(self.array), 1.0))

    def is_cond_prob(self, var_name):
        """True if the table sums to 1 over ``var_name`` for every setting of the other axes.

        That is, every slice along ``var_name`` is a distribution over
        ``var_name``, as expected of ``p(var_name | other variables)``.
        """
        return np.all(np.isclose(np.sum(self.array, axis=self.axes[var_name]), 1.0))

    def __repr__(self) -> str:
        """Show the raw array together with the name -> axis mapping."""
        return str((self.array, self.axes))

In [2]:
# Conditional table p(v1 | h1): axis 0 indexes v1 (2 states), axis 1 indexes
# h1 (3 states); each column is a distribution over v1 for one h1 state.
p_v1_given_h1 = CPT(np.array([
    [0.4, 0.8, 0.9],
    [0.6, 0.2, 0.1]
]), ["v1", "h1"])
assert p_v1_given_h1.is_cond_prob("v1"), "each column of p(v1 | h1) must sum to 1"
p_v1_given_h1

(array([[0.4, 0.8, 0.9],
       [0.6, 0.2, 0.1]]), {'v1': 0, 'h1': 1})

In [3]:
# Marginal p(h1) over the 3 states of h1.
p_h1 = CPT(np.array([.6, .3, .1]), ['h1'])
assert p_h1.is_joint_prob(), "p(h1) must sum to 1"
p_h1

(array([0.6, 0.3, 0.1]), {'h1': 0})

In [51]:
# Scratch check: the (name, axis) pair with the lowest axis index, i.e. the
# variable sitting on axis 0 of p(v1 | h1).
min(p_v1_given_h1.axes.items(), key=lambda name_and_axis: name_and_axis[1])

('v1', 0)

In [7]:
def tile_to_shape_along_axis(arr, target_shape, target_axis):
    """Repeat a scalar or 1-D array so that it fills ``target_shape``.

    The values of ``arr`` vary along ``target_axis`` and are copied unchanged
    along every other axis. For example, a length-3 array tiled to shape
    ``(2, 3)`` along axis 1 gives two identical rows.

    Args:
        arr: A scalar (0-d) or 1-D array-like. A 1-D input must have length
            ``target_shape[target_axis]``; a scalar is repeated everywhere.
        target_shape: Shape of the result (any sequence of ints).
        target_axis: Axis of the result that ``arr``'s values run along.
            Negative values count from the end, as in NumPy.

    Returns:
        A new, writable array of shape ``tuple(target_shape)`` with ``arr``'s dtype.

    Raises:
        AssertionError: if ``target_axis`` is out of range, or a 1-D ``arr``
            does not match the length of the target axis.
        NotImplementedError: if ``arr`` has more than one dimension.
    """
    arr = np.asarray(arr)
    # Accept lists as well as tuples; ndarray.shape is always a tuple.
    target_shape = tuple(target_shape)
    ndim = len(target_shape)
    assert -ndim <= target_axis < ndim, (
        f"target_axis {target_axis} is out of range for shape {target_shape}"
    )
    target_axis %= ndim  # normalise negative axes, as NumPy does

    if arr.ndim == 0:
        # A scalar broadcasts to any shape directly.
        source = arr
    elif arr.ndim == 1:
        # The 1-D array must run along the target axis...
        assert arr.shape[0] == target_shape[target_axis], (
            f"array of length {arr.shape[0]} cannot fill axis {target_axis} "
            f"of size {target_shape[target_axis]}"
        )
        # ...so give it size-1 axes everywhere else and let broadcasting fill them.
        view_shape = [1] * ndim
        view_shape[target_axis] = arr.shape[0]
        source = arr.reshape(view_shape)
    else:
        raise NotImplementedError(
            f"only scalars and 1-D arrays can be tiled, got a {arr.ndim}-D array"
        )

    # broadcast_to returns a read-only view that shares memory with `arr`;
    # copy so callers get an independent, writable array.
    return np.broadcast_to(source, target_shape).copy()

def tile_to_other_dist_along_axis_name(tiling_labeled_array: CPT, target_array: CPT):
    """Broadcast a single-variable CPT onto the axes of another CPT by name.

    The variable of ``tiling_labeled_array`` is located in ``target_array`` by
    its label, and its values are repeated along every other axis of
    ``target_array``. For example, ``p(h1)`` is tiled to the ``(v1, h1)`` shape of
    ``p(v1 | h1)``. The two arrays can then be multiplied element-wise with
    variables aligned by label rather than by NumPy's positional broadcasting.

    Args:
        tiling_labeled_array: CPT over exactly one variable.
        target_array: CPT whose shape and axis labels the result takes.

    Returns:
        A new CPT with ``target_array``'s shape and axis names.

    Raises:
        AssertionError: if ``tiling_labeled_array`` does not have exactly one axis.
        KeyError: if its variable is not one of ``target_array``'s axes.
    """
    assert len(tiling_labeled_array.axes) == 1
    # Unpack the single axis name directly.
    (target_axis_label,) = tiling_labeled_array.axes
    if target_axis_label not in target_array.axes:
        raise KeyError(
            f"variable {target_axis_label!r} is not an axis of the target CPT "
            f"(axes: {list(target_array.axes)})"
        )
    # Pass the names ordered by axis index so the result's labelling does not
    # depend on the insertion order of target_array.axes.
    axis_names = sorted(target_array.axes, key=target_array.axes.get)

    return CPT(
        tile_to_shape_along_axis(
            tiling_labeled_array.array,
            target_array.array.shape,
            target_array.axes[target_axis_label],
        ),
        axes=axis_names,
    )

# Repeat p(h1) along the v1 axis so it has the same (v1, h1) shape and axis
# labels as p(v1 | h1); the h1 axis is located by name, not by position.
tiled_p_h1 = tile_to_other_dist_along_axis_name(p_h1, p_v1_given_h1)
tiled_p_h1

(array([[0.6, 0.3, 0.1],
       [0.6, 0.3, 0.1]]), {'v1': 0, 'h1': 1})

In [5]:
# Raw table of p(v1 | h1): axis 0 is v1, axis 1 is h1.
p_v1_given_h1.array

array([[0.4, 0.8, 0.9],
       [0.6, 0.2, 0.1]])

In [6]:
# Raw marginal table p(h1), one entry per h1 state.
p_h1.array

array([0.6, 0.3, 0.1])

In [8]:
# Redisplay the tiled p(h1) next to the raw arrays: each v1 row is a copy of p(h1).
tiled_p_h1

(array([[0.6, 0.3, 0.1],
       [0.6, 0.3, 0.1]]), {'v1': 0, 'h1': 1})

In [9]:
# Joint p(v1, h1) = p(v1 | h1) * p(h1) via plain NumPy broadcasting: the 1-D
# p_h1 is aligned with the *last* axis of p_v1_given_h1, which is correct here
# only because h1 happens to be axis 1 (positional, not label-based alignment).
p_v1_given_h1.array * p_h1.array

array([[0.24, 0.24, 0.09],
       [0.36, 0.06, 0.01]])

In [10]:
# Same product using the label-aligned tiled_p_h1; the output matches the
# broadcast version in the previous cell.
p_v1_given_h1.array * tiled_p_h1.array

array([[0.24, 0.24, 0.09],
       [0.36, 0.06, 0.01]])

In [4]:
# Chain rule: p(v1, h1) = p(v1 | h1) * p(h1). Multiply by the label-aligned
# tiled_p_h1 rather than the raw 1-D p_h1, which NumPy would line up with the
# last axis purely by position. Then check that the joint sums to 1.
p_v1_h1 = CPT(p_v1_given_h1.array * tiled_p_h1.array, ["v1", "h1"])
p_v1_h1.is_joint_prob()

True

In [None]:
class Node:
    def __init__(self, ) -> None:
        