In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
# CLRS page 590 but I removed 2-3 edge
adjacency_matrix = torch.tensor(
    [
        [0, 1, 0, 0, 1],
        [1, 0, 0, 1, 1],
        [0, 0, 0, 1, 0],
        [0, 1, 1, 0, 1],
        [1, 1, 0, 1, 0],
    ],
    dtype=torch.float32,
)
adjacency_matrix

tensor([[0., 1., 0., 0., 1.],
        [1., 0., 0., 1., 1.],
        [0., 0., 0., 1., 0.],
        [0., 1., 1., 0., 1.],
        [1., 1., 0., 1., 0.]])

In [3]:
# Starting node (one-hot)
start = torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0])

# Target node (one-hot)
target = torch.tensor([0.0, 0.0, 1.0, 0.0, 0.0])

# 3 transition probability vectors (learnable parameters)
step1 = nn.Parameter(torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2]))
step2 = nn.Parameter(torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2]))
step3 = nn.Parameter(torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2]))

optimizer = optim.SGD([step1, step2, step3], lr=0.1)

# Training loop
for epoch in range(1000):
    optimizer.zero_grad()

    # Forward pass: apply transitions
    state = start
    state = adjacency_matrix @ (state * step1)
    state = adjacency_matrix @ (state * step2)
    state = adjacency_matrix @ (state * step3)

    # Loss: MSE to target
    loss = nn.MSELoss()(state, target)

    # Backward pass
    loss.backward()
    optimizer.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

print(f"\nFinal state ({torch.argmax(step1.data)}): {state}")
print(f"Step 1 ({torch.argmax(step1.data)}) probs: {step1.data}")
print(f"Step 2 ({torch.argmax(step2.data)}) probs: {step2.data}")
print(f"Step 3 ({torch.argmax(step3.data)}) probs: {step3.data}")

Epoch 0, Loss: 0.1944
Epoch 100, Loss: 0.1207
Epoch 200, Loss: 0.0000
Epoch 300, Loss: 0.0000
Epoch 400, Loss: 0.0000
Epoch 500, Loss: 0.0000
Epoch 600, Loss: 0.0000
Epoch 700, Loss: 0.0000
Epoch 800, Loss: 0.0000
Epoch 900, Loss: 0.0000

Final state (0): tensor([-8.9432e-08,  1.7890e-07,  1.0000e+00, -8.9432e-08,  1.7881e-07],
       grad_fn=<MvBackward0>)
Step 1 (0) probs: tensor([0.9725, 0.2000, 0.2000, 0.2000, 0.2000])
Step 2 (1) probs: tensor([0.2000, 0.7019, 0.2000, 0.2000, 0.7019])
Step 3 (3) probs: tensor([-7.3255e-01, -4.3828e-08,  2.0000e-01,  7.3256e-01, -8.7199e-08])
