In [187]:
import torch
import math
import functools
from itertools import product
from torch import nn

In [188]:
D = 4
H = 8
R0 = 0.1
batch = 10

In [189]:
def reward(pos: torch.Tensor):
    reward_ = R0
    reward_ += functools.reduce(lambda a, b: a * b, [0.25 < abs(pos[d] / (H - 1) - 0.5) <= 0.5 for d in range(D)])
    reward_ += functools.reduce(lambda a, b: a * b, [0.30 < abs(pos[d] / (H - 1) - 0.5) < 0.4 for d in range(D)])
    return reward_


rewards = torch.zeros(*[H for i in range(D)])
coords = [range(H) for _ in range(D)]

for coord in product(*coords):
    rewards[tuple(coord)] = reward(torch.tensor(coord))

rewards /= rewards.sum()

In [190]:
class GFlowNet(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()

        self.hidded_size = hidden_size
        self.linear_1 = nn.Linear(D * H, self.hidded_size)
        self.linear_2 = nn.Linear(self.hidded_size, D + 2)

        self.activation = nn.LeakyReLU()

    def forward(self, x: torch.tensor) -> torch.tensor:
        pred = self.activation(self.linear_1(x))
        pred = self.linear_2(pred)

        return pred


model1 = GFlowNet(128)
optimizer1 = torch.optim.Adam(model1.parameters(), lr=0.001)

In [191]:
def code(state):
    res = torch.zeros(batch, H * D)
    for k in range(batch):
        for d in range(D):
            res[k][int(state[k][d]) + d * H] = 1.0
    return res

In [192]:
def empirical_loss(samples):
    counter = torch.zeros(*[H for i in range(D)])
    for sample in samples:
        counter[tuple(sample)] += 1
    counter /= counter.sum()
    return (rewards - counter).abs().sum()

In [196]:
def train(model: GFlowNet, optimizer, loss_fn, n_epochs):
    visited = []
    for epoch in range(1, n_epochs + 1):
        optimizer.zero_grad()

        x = loss_fn(model)
        loss = x[0]
        loss.backward()
        optimizer.step()

        for state in x[1]:
            visited.append(state)

        if epoch % 50 == 0: print(f"Epoch {epoch}, loss = {empirical_loss(visited[-20000:])}")

In [197]:
def pred_next(probs, is_sampled, states):
    nxt = torch.distributions.categorical.Categorical(probs).sample()

    correct = True
    for k in range(batch):
        if is_sampled[k] or nxt[k] == D: continue
        if states[k][nxt[k]] == H - 1: correct = False

    if not correct: return pred_next(probs, is_sampled, states)
    return nxt

In [198]:
def detailed_balance_loss(model: GFlowNet):
    is_sampled = torch.zeros(batch)
    states = torch.zeros(batch, D, dtype=torch.int64)
    ways = [[] for i in range(batch)]

    finished = False
    while not finished:
        pred = model(code(states))
        probs = pred[:, 0: -1]
        flow = pred[:, -1]

        finished = True
        for j in range(0, batch):
            if is_sampled[j]: continue
            finished = False
            for i in range(0, D):
                if states[j][i] == H - 1:
                    probs[j][i] = -float("inf")

        probs = torch.nn.Softmax(dim=1)(probs)

        nxt = pred_next(probs, is_sampled, states)

        for k in range(0, batch):
            if is_sampled[k]: continue
            cnt_b = max(1, functools.reduce(lambda x, y: x + (y >= 1), states[k]))
            ways[k].append([flow[k], probs[k][nxt[k]].log(), math.log(1 / cnt_b)])
            if nxt[k] == D:
                ways[k].append([math.log(reward(states[k])), 0, 0])
                is_sampled[k] = True
            else:
                states[k][nxt[k]] += 1

    loss = 0
    for j in range(batch):
        for i in range(0, len(ways[j]) - 1):
            loss += (ways[j][i][0] + ways[j][i][1] - ways[j][i + 1][0] - ways[j][i + 1][2]) ** 2
    loss /= batch
    
    return loss, states

In [199]:
train(model1, optimizer1, detailed_balance_loss, 10000)

Epoch 50, loss = 1.8829225301742554
Epoch 100, loss = 1.836082100868225
Epoch 150, loss = 1.7919085025787354
Epoch 200, loss = 1.7343473434448242
Epoch 250, loss = 1.6825450658798218
Epoch 300, loss = 1.62959885597229
Epoch 350, loss = 1.5723975896835327
Epoch 400, loss = 1.5110070705413818
Epoch 450, loss = 1.4572871923446655
Epoch 500, loss = 1.4095252752304077
Epoch 550, loss = 1.3659405708312988
Epoch 600, loss = 1.3230280876159668
Epoch 650, loss = 1.2846542596817017
Epoch 700, loss = 1.2519886493682861
Epoch 750, loss = 1.2312843799591064
Epoch 800, loss = 1.2120141983032227
Epoch 850, loss = 1.1957565546035767
Epoch 900, loss = 1.1803474426269531
Epoch 950, loss = 1.1628798246383667
Epoch 1000, loss = 1.1514924764633179
Epoch 1050, loss = 1.1411404609680176
Epoch 1100, loss = 1.1320475339889526
Epoch 1150, loss = 1.1205109357833862
Epoch 1200, loss = 1.1133849620819092
Epoch 1250, loss = 1.104933500289917
Epoch 1300, loss = 1.095579981803894
Epoch 1350, loss = 1.0885746479034424