from pgmpy.models import FactorGraph
from pgmpy.factors.discrete import DiscreteFactor
from pgmpy.inference import BeliefPropagation

# Define the factor graph
G = FactorGraph()

# Add variable nodes to the graph
variables = ['x0', 'x1', 'x2', 'x3', 'x4']
for var in variables:
    G.add_node(var)

# Define the S-box as a factor over the five variables. Replace
# 'sbox_factor_values' with the actual S-box table: one value per joint
# assignment (2**5 = 32 values), row-major over x0..x4, so x0 is the most
# significant bit of the flat index.
# The placeholder is all ones (uniform, non-informative). An all-zero table
# would give every assignment zero weight, leaving the normalized marginals
# queried below undefined.
sbox_factor_values = [1] * 32
sbox_factor = DiscreteFactor(variables, [2] * 5, sbox_factor_values)
G.add_factors(sbox_factor)

# Connect each variable node to the S-box factor node. Adjust this to the
# actual S-box connections.
for var in variables:
    G.add_edge(var, sbox_factor)

# Define the XOR factors: add one factor per XOR operation in the graph, with
# the correct variables involved. A factor over three binary variables needs
# 2 * 2 * 2 = 8 values, laid out row-major over (x0, x1, x3). The constraint
# x0 XOR x1 = x3 has weight 1 where x3 == x0 ^ x1 and 0 elsewhere:
#   (x0, x1, x3) = 000 001 010 011 100 101 110 111
xor_factor_values = [1, 0, 0, 1, 0, 1, 1, 0]
xor_factor = DiscreteFactor(['x0', 'x1', 'x3'], [2, 2, 2], xor_factor_values)
G.add_factors(xor_factor)
G.add_edges_from([('x0', xor_factor), ('x1', xor_factor), ('x3', xor_factor)])

# Initialize belief propagation
bp = BeliefPropagation(G)

# Run belief propagation
bp.calibrate()

# Query marginal for each variable
for var in variables:
    print(f"Marginal for {var}:")
    print(bp.query(variables=[var]))


# In [None]:
# The XOR truth table for binary variables A, B, C (A XOR B = C).
# Each row is (A, B, C, weight) and lists one satisfying assignment.
xor_truth_table = [
    [0, 0, 0, 1],
    [0, 1, 1, 1],
    [1, 0, 1, 1],
    [1, 1, 0, 1],
]

# DiscreteFactor expects one value per joint assignment (2 * 2 * 2 = 8),
# row-major over (A, B, C). It does not accept a list of truth-table rows, so
# build the dense table from the rows. Assignments that are not listed keep
# weight 0, which rules them out.
xor_factor_values = [0] * 8
for a, b, c, weight in xor_truth_table:
    xor_factor_values[4 * a + 2 * b + c] = weight

# Create the XOR factor
xor_factor = DiscreteFactor(['A', 'B', 'C'], [2, 2, 2], xor_factor_values)

# Add the XOR factor to the graph. FactorGraph.add_factors rejects factors
# over variables that are not yet nodes, so add A, B, C first.
G.add_nodes_from(['A', 'B', 'C'])
G.add_factors(xor_factor)
G.add_edges_from([('A', xor_factor), ('B', xor_factor), ('C', xor_factor)])


# In [1]:
from pgmpy.models import FactorGraph
from pgmpy.factors.discrete import DiscreteFactor
from pgmpy.inference import BeliefPropagation
import numpy as np

# Initialize the factor graph
G = FactorGraph()

# Add variable nodes for each bit: x0..x4 are the S-box input bits and
# y0..y4 the output bits. Index 0 is the most significant bit, matching the
# '05b' formatting used for the table and the evidence below.
for i in range(5):
    G.add_node(f"x{i}")
    G.add_node(f"y{i}")

# 5-bit S-box, input -> output (Ascon S-box values).
# NOTE(review): entries 0x0a, 0x0b, 0x0c, 0x10, 0x11, 0x17, 0x1c, 0x1d and
# 0x1e were mistyped (the table repeated outputs and was not a permutation);
# they now follow the Ascon specification.
sbox_hex_mapping = {
    0x00: 0x04, 0x01: 0x0b, 0x02: 0x1f, 0x03: 0x14, 0x04: 0x1a, 0x05: 0x15, 0x06: 0x09, 0x07: 0x02,
    0x08: 0x1b, 0x09: 0x05, 0x0a: 0x08, 0x0b: 0x12, 0x0c: 0x1d, 0x0d: 0x03, 0x0e: 0x06, 0x0f: 0x1c,
    0x10: 0x1e, 0x11: 0x13, 0x12: 0x07, 0x13: 0x0e, 0x14: 0x00, 0x15: 0x0d, 0x16: 0x11, 0x17: 0x18,
    0x18: 0x10, 0x19: 0x0c, 0x1a: 0x01, 0x1b: 0x19, 0x1c: 0x16, 0x1d: 0x0a, 0x1e: 0x0f, 0x1f: 0x17
}

# A valid S-box is a permutation of the 32 five-bit values; fail loudly on a typo.
if (sorted(sbox_hex_mapping) != list(range(32))
        or sorted(sbox_hex_mapping.values()) != list(range(32))):
    raise ValueError("sbox_hex_mapping must be a permutation of 0x00..0x1f")

# Binary-tuple view of the same mapping, e.g. (0, 0, 0, 0, 1) -> (0, 1, 0, 1, 1)
sbox_mapping = {tuple(map(int, format(x, '05b'))): tuple(map(int, format(sbox_hex_mapping[x], '05b')))
                for x in sbox_hex_mapping}

# Conditional table P(y | x): row = input value, column = output value.
# The S-box is deterministic, so each row holds a single 1.
sbox_factor_table = np.zeros((32, 32))
for input_val, output_val in sbox_hex_mapping.items():
    sbox_factor_table[input_val, output_val] = 1.0

# Rows already sum to 1 for a complete mapping. Normalizing keeps the table a
# valid conditional distribution if soft (probabilistic) entries are ever used.
sbox_factor_table /= sbox_factor_table.sum(axis=1, keepdims=True)

# Create a single S-box factor over all ten bits. flatten() is row-major, so
# flat index input * 32 + output lines up with the variable order
# x0..x4, y0..y4 (the first variable is the most significant).
sbox_factor = DiscreteFactor(
    ['x0', 'x1', 'x2', 'x3', 'x4', 'y0', 'y1', 'y2', 'y3', 'y4'],
    [2] * 10,
    sbox_factor_table.flatten(),
)
G.add_factors(sbox_factor)

# Add edges from every input and output bit to the S-box factor
for i in range(5):
    G.add_edge(f"x{i}", sbox_factor)
    G.add_edge(f"y{i}", sbox_factor)

# Convert input to binary (MSB first, i.e. x0 is the top bit)
input_value = 0x14
input_bits = list(map(int, format(input_value, '05b')))

# Observe the input: clamp each x_i with a one-hot unary evidence factor
for i, bit in enumerate(input_bits):
    evidence_factor = DiscreteFactor([f'x{i}'], [2], [1 - bit, bit])
    G.add_factors(evidence_factor)
    G.add_edge(f'x{i}', evidence_factor)

# pgmpy's BeliefPropagation operates on a junction tree built from the model,
# so a single calibrate() pass is exact; no iterate-until-converged loop is needed.
bp = BeliefPropagation(G)
bp.calibrate()

# Query and print marginal probabilities of each output bit
for i in range(5):
    print(bp.query(variables=[f"y{i}"]))

# +-------+-----------+
# | y0    |   phi(y0) |
# +=======+===========+
# | y0(0) |    1.0000 |
# +-------+-----------+
# | y0(1) |    0.0000 |
# +-------+-----------+
# +-------+-----------+
# | y1    |   phi(y1) |
# +=======+===========+
# | y1(0) |    1.0000 |
# +-------+-----------+
# | y1(1) |    0.0000 |
# +-------+-----------+
# +-------+-----------+
# | y2    |   phi(y2) |
# +=======+===========+
# | y2(0) |    1.0000 |
# +-------+-----------+
# | y2(1) |    0.0000 |
# +-------+-----------+
# +-------+-----------+
# | y3    |   phi(y3) |
# +=======+===========+
# | y3(0) |    1.0000 |
# +-------+-----------+
# | y3(1) |    0.0000 |
# +-------+-----------+
# +-------+-----------+
# | y4    |   phi(y4) |
# +=======+===========+
# | y4(0) |    1.0000 |
# +-------+-----------+
# | y4(1) |    0.0000 |
# +-------+-----------+


[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 1. 0. 0. 0. 0.]
