In [10]:
import math, random, torch, torch.nn as nn
from torch.utils.data import DataLoader

# Vocabulaire

In [None]:
SPECIAL      = ["<pad>", "<bos>", "<eos>"]
BASE_CHARS   = list("0123456789+= ")
VOCAB        = SPECIAL + BASE_CHARS
PAD, BOS, EOS = SPECIAL
char2idx     = {ch: i for i, ch in enumerate(VOCAB)}
idx2char     = {i: ch for ch, i in char2idx.items()}
VOCAB_SIZE   = len(VOCAB)

INPUT_LEN  = 9                   # «99 + 99 » = 9 car
MAX_LEN = INPUT_LEN + 2          # 9 + 2 = 11

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Env

In [26]:
TARGET_EXPR = "2 + 2"                         # 7 caractères (inclut espaces)
def tokens_to_str(tokens):
    return ''.join(idx2char[t] for t in tokens).rstrip()

In [27]:
class AddSeqEnv:
    def __init__(self, max_len=MAX_LEN):
        self.max_len = max_len
    def reset(self):
        self.state = [char2idx[BOS]]
        return self.state
    def step(self, action):
        self.state.append(action)
        done = (action == char2idx[EOS])

        # Abort si longueur dépasse la borne
        if not done and len(self.state) >= self.max_len:
            return self.state, 0.01, True

        if done:
            expr = tokens_to_str(self.state[1:-1])   # sans BOS/EOS
            reward = 1.0 if expr == TARGET_EXPR else 0.01
            return self.state, reward, True

        return self.state, 0.0, False

# Flow Net

In [28]:
# ----------------- Policy réseau = log-flux Fθ(s,a)
class FlowNet(nn.Module):
    def __init__(self, d_model=128, n_heads=4, n_layers=2):
        super().__init__()
        self.emb = nn.Embedding(VOCAB_SIZE, d_model, padding_idx=char2idx[PAD])
        self.pos = nn.Embedding(MAX_LEN + 1, d_model)

        enc_layer = nn.TransformerEncoderLayer(
            d_model, n_heads, dropout=0.1, batch_first=True
        )
        self.tr = nn.TransformerEncoder(enc_layer, n_layers)

        self.fc = nn.Linear(d_model, VOCAB_SIZE)
        self.fc.weight = self.emb.weight          # weight tying
        self.logZ = nn.Parameter(torch.zeros(()))

    @staticmethod
    def causal_mask(sz, device):
        # shape (sz, sz) : True au-dessus de la diagonale
        return torch.triu(torch.full((sz, sz), float('-inf'), device=device), 1)

    def forward(self, prefix):
        """
        prefix : (B, L)  – renvoie log p(a | s) pour chaque token du vocab.
        """
        B, L = prefix.shape
        pos = torch.arange(L, device=prefix.device).unsqueeze(0)
        x = self.emb(prefix) + self.pos(pos)

        mask = self.causal_mask(L, prefix.device)
        x = self.tr(x, mask=mask)                # (B, L, d_model)

        logits = self.fc(x)                      # (B, L, V)
        return torch.log_softmax(logits[:, -1], -1)  # proba du prochain token


In [29]:
# ---------- Trajectory Balance loss (corrigé)
def tb_loss(model, trajectories, rewards):
    losses = []
    for tokens, R in zip(trajectories, rewards):
        logfwd = 0.0
        # parcours des transitions s_k -> a_k (= tokens[k+1])
        for k in range(len(tokens) - 1):
            prefix = torch.tensor([tokens[:k+1]], device=DEVICE)  # s_k
            logp   = model(prefix)                                # log π(a|s)
            logfwd = logfwd + logp[0, tokens[k+1]]
        logbwd = 0.0        # backward uniform (placeholder)
        TB = logfwd + model.logZ - math.log(R) - logbwd
        losses.append(TB**2)
    return torch.stack(losses).mean()


In [30]:
def sample_trajectory(model, env):
    state = env.reset()
    done  = False
    while not done:
        prefix = torch.tensor([state], device=DEVICE)
        logp   = model(prefix)               # (1, V)
        # Catégoriel => choix
        action = torch.distributions.Categorical(logits=logp).sample().item()
        state, reward, done = env.step(action)
    return state, reward

In [31]:
def train_flow(model, steps=2, batch_size=32):
    env     = AddSeqEnv()
    opt     = torch.optim.Adam(model.parameters(), lr=3e-4)
    for step in range(1, steps+1):
        trajs, rewards = [], []
        for _ in range(batch_size):
            t, r = sample_trajectory(model, env)
            trajs.append(t); rewards.append(r)
        loss = tb_loss(model, trajs, rewards)
        opt.zero_grad(); loss.backward(); opt.step()

        if step % 1 == 0:
            print(f"[{step:05d}] loss={loss.item():.4f}  logZ={model.logZ.item():.2f}")


# Generation

In [32]:
def generate(model, start="<bos>", max_len=MAX_LEN):
    model.eval()
    ids = [char2idx[start]] if start != "<bos>" else [char2idx[BOS]]
    for _ in range(max_len):
        prefix = torch.tensor([ids], device=DEVICE)
        logp   = model(prefix)
        tok    = int(logp.argmax(-1))
        ids.append(tok)
        if tok == char2idx[EOS]:
            break
    return ''.join(idx2char[i] for i in ids[1:-1])   # sans BOS/EOS

In [33]:
torch.manual_seed(0)
model = FlowNet().to(DEVICE)
print("Entraînement...")
train_flow(model, steps=3, batch_size=64)   # test rapide

for _ in range(5):
    print(">>", generate(model))

Entraînement...
[00001] loss=21.2071  logZ=-0.00
[00002] loss=20.2666  logZ=-0.00
[00003] loss=21.2019  logZ=-0.00
>> <bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos>
>> <bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos>
>> <bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos>
>> <bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos>
>> <bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos><bos>
