In [2]:
import numpy as np

def logsumexp(a, axis=None):
    """Numerically stable ``log(sum(exp(a)))`` along ``axis``.

    The maximum is factored out before exponentiating so that large inputs do
    not overflow.

    Args:
        a: array-like input.
        axis: axis (or tuple of axes) to reduce; ``None`` reduces over all
            elements.

    Returns:
        ndarray with the reduced axes removed (a 0-d array when ``axis`` is
        ``None``). Slices whose entries are all ``-inf`` yield ``-inf``.
    """
    a = np.asarray(a)
    m = np.max(a, axis=axis, keepdims=True)
    # A non-finite max (e.g. an all -inf slice) would make ``a - m`` evaluate
    # inf - inf = nan; shifting by 0 for those slices still yields the correct
    # -inf / +inf / nan result.
    m = np.where(np.isfinite(m), m, 0)
    with np.errstate(divide="ignore"):  # log(0) -> -inf is the intended result
        out = m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))
    # Squeeze only the reduced axes: a bare squeeze() would also drop
    # unrelated size-1 axes of the input (e.g. shape (1, 3) reduced on axis=1).
    return np.squeeze(out, axis=axis)

def build_column_states(n=10):
    """Enumerate every configuration of one binary column with ``n`` cells.

    State ``s`` encodes the column as a bit mask: bit ``r`` holds the value of
    row ``r + 1``. States are returned as ``uint16`` values ``0 .. 2**n - 1``,
    which is enough to hold every mask when ``n <= 16``.

    Raises:
        ValueError: if ``n`` is outside ``0..16`` (masks for larger columns
            cannot be represented as ``uint16``).
    """
    if not 0 <= n <= 16:
        raise ValueError(f"n must be in [0, 16] for uint16 column states, got {n}")
    return np.arange(1 << n, dtype=np.uint16)

def count_equal_adjacent_in_col(state, n=10):
    """Count vertically adjacent cell pairs with equal values in one column.

    ``state`` is the column bit mask (bit r = row r + 1). Bit r of
    ``state ^ (state >> 1)`` is set exactly when rows r + 1 and r + 2 differ,
    and only the low ``n - 1`` bits correspond to pairs inside the column.

    Returns:
        int: ``(n - 1)`` minus the number of differing adjacent pairs.
    """
    pair_mask = (1 << (n - 1)) - 1
    changes = (state ^ (state >> 1)) & pair_mask
    num_changes = bin(int(changes)).count("1")
    return (n - 1) - num_changes

def exact_joint_top_bottom(beta, n=10, L=10):
    """Exact joint of the top and bottom cells in the last column of an n x L grid.

    The grid (n rows, L columns) of binary cells is scored in the log domain by
    the potentials below. Each column is treated as one variable with 2**n
    states (bit r = row r + 1), which turns the grid into a chain. A forward
    pass of log-domain messages then solves the chain exactly:

    * unary    log psi_t(x)  = beta * (equal vertically adjacent pairs in x)
    * pairwise log psi(a, b) = beta * (n - Hamming(a, b)),
      i.e. beta * (equal horizontal pairs between neighbouring columns)

    Args:
        beta: coupling strength.
        n: number of rows; must satisfy 1 <= n <= 16 (states are uint16).
        L: number of columns; must be >= 1.

    Returns:
        2x2 float64 array ``joint`` with
        ``joint[top, bottom] = P(x_{1,L} = top, x_{n,L} = bottom)``.

    Raises:
        ValueError: if ``n`` or ``L`` is out of range.

    Note:
        Memory is dominated by the (2**n, 2**n) float64 pairwise table
        (about 8 MB for n = 10), so this is only practical for small n.
    """
    if not 1 <= n <= 16:
        raise ValueError(f"n must be in [1, 16], got {n}")
    if L < 1:
        raise ValueError(f"L must be >= 1, got {L}")

    states = build_column_states(n)
    S = len(states)  # 2**n

    # ----- unary log potential: log psi_t(x_t)
    eq_adj = np.array([count_equal_adjacent_in_col(s, n) for s in states], dtype=np.int16)
    log_unary = beta * eq_adj.astype(np.float64)  # (S,)

    # ----- pairwise log potential: log psi(a, b) = beta * (n - popcount(a ^ b))
    # The potential depends on (a, b) only through a ^ b, and every XOR of two
    # n-bit states is < S, so one popcount table over 0..S-1 serves all pairs.
    codes = np.arange(S, dtype=np.int64)
    popcount = np.zeros(S, dtype=np.int64)
    for r in range(n):
        popcount += (codes >> r) & 1
    log_pair_by_xor = beta * (n - popcount).astype(np.float64)  # (S,)
    log_pair = log_pair_by_xor[states[:, None] ^ states[None, :]]  # (S, S)

    # ----- forward messages (log domain); only the latest column is needed
    log_alpha = log_unary.copy()
    for _ in range(1, L):
        # Sum out the previous column state a (axis 0) for each current state b.
        log_alpha = log_unary + logsumexp(log_alpha[:, None] + log_pair, axis=0)

    # Nothing lies to the right of column L, so log_alpha is its unnormalised
    # log-marginal; normalise it in the log domain before exponentiating.
    marg_last = np.exp(log_alpha - logsumexp(log_alpha))

    # ----- joint P(x_top, x_bottom): top cell is bit 0, bottom cell is bit n-1
    top = (states & 1).astype(np.intp)
    bottom = ((states >> (n - 1)) & 1).astype(np.intp)
    joint = np.bincount(2 * top + bottom, weights=marg_last, minlength=4).reshape(2, 2)

    joint /= joint.sum()  # guard against floating-point drift
    return joint

def print_joint(joint, title="", n=10, L=10):
    """Print a joint table of (top, bottom) cell values together with its sums.

    Args:
        joint: ndarray, expected to be 2x2. Rows index the top cell x_{1,L}
            and columns index the bottom cell x_{n,L}.
        title: optional heading, printed first when non-empty.
        n: number of grid rows; used only to label the bottom cell.
        L: number of grid columns; used only to label the cells.
    """
    if title:
        print(title)
    # The label is derived from the grid size instead of being hard-coded;
    # the defaults n=10, L=10 reproduce the original text exactly.
    print(f"Rows: x_{{1,{L}}} (top) = 0,1 ; Cols: x_{{{n},{L}}} (bottom) = 0,1")
    print(np.array_str(joint, precision=6, suppress_small=False))
    print("Row sums:", joint.sum(axis=1))
    print("Col sums:", joint.sum(axis=0))
    print("Total:", joint.sum())
    print()

if __name__ == "__main__":
    # Run exact inference on the default 10 x 10 grid for a strong (4.0),
    # moderate (1.0) and weak (0.01) coupling, printing each joint table.
    for beta in (4.0, 1.0, 0.01):
        print_joint(
            exact_joint_top_bottom(beta=beta, n=10, L=10),
            title=f"Exact inference (Method 1), beta={beta}",
        )


Exact inference (Method 1), beta=4.0
Rows: x_{1,10} (top) = 0,1 ; Cols: x_{10,10} (bottom) = 0,1
[[4.996520e-01 3.479759e-04]
 [3.479759e-04 4.996520e-01]]
Row sums: [0.5 0.5]
Col sums: [0.5 0.5]
Total: 1.0

Exact inference (Method 1), beta=1.0
Rows: x_{1,10} (top) = 0,1 ; Cols: x_{10,10} (bottom) = 0,1
[[0.280447 0.219553]
 [0.219553 0.280447]]
Row sums: [0.5 0.5]
Col sums: [0.5 0.5]
Total: 1.0

Exact inference (Method 1), beta=0.01
Rows: x_{1,10} (top) = 0,1 ; Cols: x_{10,10} (bottom) = 0,1
[[0.25 0.25]
 [0.25 0.25]]
Row sums: [0.5 0.5]
Col sums: [0.5 0.5]
Total: 0.9999999999999999

