# Libs

In [1]:
import os
import matplotlib.pyplot as plt
import pandas as pd
import wandb
from datetime import datetime

import pickle
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# S√©quence d'Entra√Ænement

L'entra√Ænement se fait en 4 phases it√©ratives, align√©es sur le framework LUCIDE pour unifier l'inf√©rence causale via un √©quilibre dynamique entre les distributions. Focus sur la coh√©rence bay√©sienne, en exploitant les GFlowNets pour g√©rer les espaces compositionnels et l'apprentissage distributionnel pour quantifier l'incertitude.

### Phase 1: Environmental Grounding
**Objectif** : Aligner les distributions de pr√©diction et de fr√©quence de l'environnement sur les observations environnementales.
- Apprentissage de la prior environnementale $p^{env}_\theta$ (mod√©lis√©e via GFlowNet, nomm√©e **llm_prior_env_model**) : Aligne sur les fr√©quences observ√©es.
  - Optimisation : $\theta^* = \arg \min_\theta \mathbb{E}_{x \sim p^{env}} \left[ \log \frac{p^{\text{env}}(x)}{p^{env}_\theta(x)}\right]$
- Apprentissage du mod√®le conditionnel Seq-to-Seq $p_\phi^{LLM}(y|x)$ : Entra√Ænement autoregressif standard sur donn√©es environnementales.
  - Optimisation : $\phi^* = \arg \min_\phi \mathbb{E}_{(x,y) \sim p^{\text{env}}} \left[ - \log p_{\phi}^{\text{LLM}}(y|x) \right]$

### Phase 2: Internal Belief Consolidation
**Objectif** : Mettre √† jour le syst√®me de croyances internes en utilisant les distributions apprises (analogue √† une exploration onirique des structures de croyances).
- Apprentissage de la prior interne $p^{internal}_\psi$ (mod√©lis√©e via GFlowNet, int√©gr√©e au prior unifi√© $p_{\theta,\psi}^{\text{prior}} \propto p_\theta^{env} \times p_\psi^{internal}$).
  - Cherche la coh√©rence entre prior et posterior : $p^{env}_\theta(x) \times p^{internal}_\psi(x) \times p_{\phi}^{LLM}(y|x) \propto p^{env}(x|y), \quad \forall x \sim p^{internal}$
  - Optimisation : $\psi^* = \arg \min_\psi \mathbb{E}_{x \sim p^{internal}_\psi} \left[ \left(\log \frac{Z_\psi^{\text{internal}} \times p_\psi^{\text{internal}}(x)}{R(x)}\right)^2 \right]$, o√π $R(x) = p^{\text{env}}_\theta(x) \times p^{\text{internal}}_\psi(x) \times p_\phi^{\text{LLM}} (y|x)$

### Phase 3: Adversarial Exploration
**Objectif** : D√©couvrir les s√©quences qui violent la coh√©rence bay√©sienne ‚Äì identifier les angles morts du raisonnement.
- Apprentissage de la distribution adversariale $p_\omega^{adv}$ (mod√©lis√©e via GFlowNet, nomm√©e LLM_ADVERSARIAL).
  - G√©n√®re des contextes o√π le mod√®le √©choue : $p^{env}_\theta(x) \times p^{internal}_\psi(x) \times p_{\phi}^{LLM}(y|x) \not\propto p^{env}(x|y), \quad \forall x \sim p^{adv}_\omega$
  - Maximise la divergence bay√©sienne : $\omega^* = \arg \max_\omega \mathbb{E}_{x \sim p^{adv}_\omega} \left[ \left(\log \frac{Z_\omega^{\text{adv}} \times p_\omega^{\text{adv}}(x)}{R(x)}\right)^2 \right]$


### Phase 4: Adversarial Correction
**Objectif** : Restaurer la coh√©rence bay√©sienne sur les contextes adversariaux ‚Äì apprendre des erreurs.
- Ajustement du mod√®le g√©n√©ratif Seq-to-Seq sur les exemples adversariaux.
  - Optimisation : $\phi^* = \arg \min_\phi \mathbb{E}_{x \sim p^{\text{adv}}_\omega} \left[ - \log \left( p^{\text{env}}_\theta(x) \times p^{\text{internal}}_\psi(x) \times p_\phi^{\text{LLM}}(y|x) \right) \right]$
- Boucle de renforcement : Int√®gre les ajustements pour maximiser la vraisemblance marginale $p(y)$ sans calculer explicitement le posterior intractable.

Il faut 4 fonctions de loss principales, adapt√©es aux phases :
- $\mathcal{L}_{Phase1}$ (pour $p^{env}_\theta$ et Seq-to-Seq) : Bas√©e sur KL-divergence et cross-entropy.
- $\mathcal{L}_{Phase2}$ (pour $p^{internal}_\psi$) : MSE sur log-ratios pour coh√©rence interne.
- $\mathcal{L}_{Phase3}$ (pour $p^{adv}_\omega$) : Maximisation via MSE invers√©e.
- $\mathcal{L}_{Phase4}$ (pour Seq-to-Seq sur adversariaux) : N√©gative log-likelihood pond√©r√©e.

Inputs : logprobs (B,), Z() pour les GFlowNets.

## Formatage des Entr√©es/Sorties

Adapt√© au dataset d'addition pour validation (a+b=c, a,b in [0,99], train/exclut [40,49], eval sur [40,49]). Vocab : digits 0-9, '+', '='. Utilise GFlowNets pour sampling compositionnel et RL distributionnel pour uncertainty.

# Model I/O Formatting

### üß† LLM PRIOR ENV (GFlowNet)
- **Input:** initial state  $s_0 = \langle \text{BOS} \rangle$
- **Output sequence example:**  $\langle \text{BOS} \rangle \, a\, a\, a \, + \, b\, b\, b \, = \langle \text{EOS} \rangle$
- **Max length:**  `LLM_PRIOR_ENV_MAX_LEN_PRED = 13`

---

### üîÅ Seq-to-Seq Model
- **Encoder input:** cleaned sequence (no BOS) $\text{"aaa + bbb =<EOS>"}$
  ‚Üí `SEQ_2_SEQ_MAX_LEN_ENCODER = 12`
- **Decoder input:** forced BOS token $\langle \text{BOS} \rangle \text{"xxxx"}$
  ‚Üí `SEQ_2_SEQ_MAX_LEN_DECODER = 5`

---

### ‚ö° LLM ADVERSARIAL (GFlowNet)
- **Input:**  $s_0 = \langle \text{BOS} \rangle$
- **Output sequence example:**  $\langle \text{BOS} \rangle \, a\, a\, a \, + \, b\, b\, b \, = \langle \text{EOS} \rangle$
- **Max length:**  
  `LLM_PRIOR_ENV_MAX_LEN_PRED = 13`

---

## üîç Model Outputs

Each model returns a tuple: $(\log p(\text{seq}) \in \mathbb{R}^B,\; \ln Z \in \mathbb{R})$

where:
- **$\log p(\text{seq})$**: sequence log-probability (per batch),
- **$\ln Z$**: learnable normalization scalar (partition energy).

In [2]:
def loss_LLM_PRIOR_ENV(out_LLM_PRIOR_ENV, out_Seq2Seq):
    """
    Compute the Trajectory Balance loss for the LLM PRIOR ENV (GFlowNet) phase.

    Args:
        out_LLM_PRIOR_ENV: tuple(logp_LLM_PRIOR_ENV (B,), lnZ_LLM_PRIOR_ENV scalar)
        out_Seq2Seq: logp_Seq2Seq (B,)

    Returns:
        loss_LLM_PRIOR_ENV: scalar tensor (for backprop on LLM_PRIOR_ENV)
    """
    logp_LLM_PRIOR_ENV, lnZ_LLM_PRIOR_ENV = out_LLM_PRIOR_ENV
    logp_Seq2Seq_detached = out_Seq2Seq.detach()

    # Reward = logp_prior_env + logp_seq2seq (detached)
    reward = logp_LLM_PRIOR_ENV + logp_Seq2Seq_detached

    # Trajectory Balance Loss: (reward - lnZ + logp_prior_env)^2
    loss = ((reward - lnZ_LLM_PRIOR_ENV + logp_LLM_PRIOR_ENV) ** 2).mean()

    return loss


def loss_LLM_ADVERSERIAL(out_LLM_PRIOR_ENV, out_LLM_ADVERSERIAL, out_Seq2Seq):
    """
    Compute the adversarial GFlowNet loss.

    Args:
        out_LLM_PRIOR_ENV: tuple(logp_prior_env (B,), lnZ_prior_env scalar)
        out_LLM_ADVERSERIAL: tuple(logp_adverserial (B,), lnZ_adverserial scalar)
        out_Seq2Seq: tuple(logp_seq2seq (B,), lnZ_seq2seq scalar)

    Returns:
        loss_LLM_ADVERSERIAL: scalar tensor (for backprop on LLM_ADVERSERIAL)
    """
    logp_prior_env, lnZ_prior_env = out_LLM_PRIOR_ENV
    logp_adverserial, lnZ_adverserial = out_LLM_ADVERSERIAL
    logp_seq2seq, _ = out_Seq2Seq

    # D√©tachement des termes exog√®nes
    logp_prior_env_detached = logp_prior_env.detach()
    logp_seq2seq_detached = logp_seq2seq.detach()

    # Adversarial loss = (reward)^(-2)
    loss = (logp_prior_env_detached + logp_seq2seq_detached + logp_adverserial - lnZ_adverserial + logp_adverserial).pow(-2).mean()

    return loss


def loss_LLM_RENFORCEMENT(out_LLM_PRIOR_ENV, out_LLM_ADVERSERIAL, out_Seq2Seq):
    """
    Compute the reinforcement (Seq-to-Seq) loss.

    Args:
        out_LLM_PRIOR_ENV: tuple(logp_prior_env (B,), lnZ_prior_env scalar)
        out_LLM_ADVERSERIAL: tuple(logp_adversarial (B,), lnZ_adversarial scalar)
        out_Seq2Seq: tuple(logp_seq2seq (B,), lnZ_seq2seq (B,))

    Returns:
        loss_LLM_RENFORCEMENT: scalar tensor (for backprop on Seq-to-Seq)
    """
    logp_prior_env, _ = out_LLM_PRIOR_ENV
    logp_adversarial, _ = out_LLM_ADVERSERIAL
    logp_seq2seq, lnZ_seq2seq = out_Seq2Seq

    # D√©tachement des termes exog√®nes
    logp_prior_env_detached = logp_prior_env.detach()
    logp_adversarial_detached = logp_adversarial.detach()

    # Reinforcement loss = (reward)^(2)
    loss = (logp_prior_env_detached + logp_seq2seq + logp_adversarial_detached - lnZ_seq2seq + logp_seq2seq).pow(2).mean()

    return loss

In [3]:
class LLMDecoder(nn.Module):
    def __init__(self, vocab_size, d_model=32, n_layers=1, n_heads=1, dropout=0.1, lnZ_shape=()):
        """
        Transformer Decoder pour LLM_PRIOR_ENV ou LLM_ADVERSERIAL.
        Peut fonctionner en mode on-policy (g√©n√©ration) ou off-policy (log prob).
        """
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model, n_heads, dim_feedforward=4*d_model, dropout=dropout, batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_layers)
        self.output = nn.Linear(d_model, vocab_size)
        self.lnZ = nn.Parameter(torch.zeros(lnZ_shape, requires_grad=True))
        self.d_model = d_model

    # ======================================
    #   UTILITAIRES MASQUAGE AUTOREGRESSIF
    # ======================================
    def _causal_mask(self, T):
        """Masque triangulaire inf√©rieur pour bloquer la fuite future (autoregressif)."""
        return torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

    # ======================================
    #   FORWARD AUTOREGRESSIF
    # ======================================
    def forward(self, input_ids):
        """
        input_ids: (B, T)
        Sort: logits (B, T, vocab_size)
        """
        B, T = input_ids.shape
        x = self.embed(input_ids) * (self.d_model ** 0.5)
        mask = self._causal_mask(T).to(x.device)
        h = self.decoder(x, x, tgt_mask=mask)
        return self.output(h)

    # ======================================
    #   MODE OFF-POLICY : log p(seq)
    # ======================================
    def log_prob(self, tokens):
        """
        Calcule la log-probabilit√© totale d'une s√©quence donn√©e (off-policy).
        tokens: (B, T)
        Retourne: (logp_total (B,), lnZ)
        """
        B, T = tokens.shape
        logits = self.forward(tokens[:, :-1])                 # (B, T-1, vocab)
        logp = F.log_softmax(logits, dim=-1)
        next_tokens = tokens[:, 1:].unsqueeze(-1)
        token_logp = logp.gather(2, next_tokens).squeeze(-1)  # (B, T-1)
        logp_total = token_logp.sum(dim=1)                    # (B,)
        return logp_total, self.lnZ

    # ======================================
    #   MODE ON-POLICY : g√©n√©ration par sampling
    # ======================================
    def generate(self, batch_size, max_len, bos_token_id, eos_token_id, device):
        """
        G√©n√©ration batch√©e diff√©rentiable sur logp(seq).
        Retourne: tokens (B, T_gen), logp_total (B,), lnZ
        """
        tokens = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        logp_total = torch.zeros(batch_size, device=device)

        for _ in range(max_len):
            logits = self.forward(tokens)[:, -1, :]        # (B, vocab)
            logp = F.log_softmax(logits, dim=-1)           # (B, vocab)

            next_tokens = torch.multinomial(logp.exp(), 1) # (B, 1)
            next_logp = logp.gather(1, next_tokens).squeeze(1)

            logp_total += next_logp * (~finished)
            tokens = torch.cat([tokens, next_tokens], dim=1)

            finished |= next_tokens.squeeze(1).eq(eos_token_id)
            if finished.all():
                break

        return tokens, logp_total, self.lnZ
    



In [4]:
class Seq2SeqWithCritic(nn.Module):
    def __init__(self, vocab_size, d_model=32, n_layers=1, n_heads=1, dropout=0.1):
        """
        Transformer Seq2Seq avec un critic (Z-network) qui renvoie lnZ(B,)
        """
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Parameter(torch.zeros(1, 20, d_model))  # sinusoidal-like positional encoding
        self.d_model = d_model

        # --- Encoder ---
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads,
            dim_feedforward=4*d_model, dropout=dropout, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # --- Decoder ---
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=n_heads,
            dim_feedforward=4*d_model, dropout=dropout, batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_layers)

        # --- Output head ---
        self.output = nn.Linear(d_model, vocab_size)

        # --- Critic head (predicts lnZ_i par s√©quence) ---
        self.critic = nn.Linear(d_model, 1)

    # ====================================================
    # Helper causal mask for autoregressive decoding
    # ====================================================
    def _causal_mask(self, T):
        return torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

    # ====================================================
    # Forward (teacher-forced decoding)
    # ====================================================
    def forward(self, encoder_input_ids, decoder_input_ids):
        """
        Args:
            encoder_input_ids: (B, T_enc)
            decoder_input_ids: (B, T_dec)
        Returns:
            logits: (B, T_dec, vocab)
            lnZ: (B,)
        """
        B, T_enc = encoder_input_ids.shape
        B, T_dec = decoder_input_ids.shape

        # Embedding + positional encoding
        enc = self.embed(encoder_input_ids) * (self.d_model ** 0.5)
        dec = self.embed(decoder_input_ids) * (self.d_model ** 0.5)
        enc = enc + self.pos_embed[:, :T_enc, :]
        dec = dec + self.pos_embed[:, :T_dec, :]

        # Encoder
        memory = self.encoder(enc)

        # Critic : moyenne spatiale ‚Üí pr√©diction lnZ_i pour chaque s√©quence
        memory_mean = memory.mean(dim=1)               # (B, d_model)
        lnZ = self.critic(memory_mean).squeeze(1)      # (B,)

        # Decoder autoregressif
        mask = self._causal_mask(T_dec).to(dec.device)
        h = self.decoder(dec, memory, tgt_mask=mask)

        logits = self.output(h)  # (B, T_dec, vocab)
        return logits, lnZ

    # ====================================================
    # Mode OFF-POLICY : log-probabilit√© exacte d'une s√©quence
    # ====================================================
    def log_prob(self, encoder_input_ids, full_target_ids):
        """
        Calcule la log-prob totale de la s√©quence cible donn√©e.
        """
        decoder_input_ids = full_target_ids[:, :-1]     # tokens d√©j√† vus (input)
        target_ids = full_target_ids[:, 1:]             # tokens √† pr√©dire (labels)
        
        logits, lnZ = self.forward(encoder_input_ids, decoder_input_ids)
        logp = F.log_softmax(logits, dim=-1)                       # (B, T_dec, vocab)
        next_tokens = target_ids.unsqueeze(-1)                     # (B, T_dec, 1)
        token_logp = logp.gather(2, next_tokens).squeeze(-1)       # (B, T_dec)
        logp_total = token_logp.sum(dim=1)                         # (B,)
        return logp_total, lnZ

    # ====================================================
    # Mode ON-POLICY : g√©n√©ration batch√©e
    # ====================================================
    def generate(self, encoder_input_ids, max_len, bos_token_id, eos_token_id, device):
        B = encoder_input_ids.size(0)
        dec_tokens = torch.full((B, 1), bos_token_id, dtype=torch.long, device=device)
        finished = torch.zeros(B, dtype=torch.bool, device=device)
        logp_total = torch.zeros(B, device=device)

        enc = self.embed(encoder_input_ids) * (self.d_model ** 0.5)
        enc = enc + self.pos_embed[:, :encoder_input_ids.size(1), :]
        memory = self.encoder(enc)

        memory_mean = memory.mean(dim=1)
        lnZ = self.critic(memory_mean).squeeze(1)  # critic d√©pend de l‚Äôentr√©e (B,)

        for _ in range(max_len):
            dec = self.embed(dec_tokens) * (self.d_model ** 0.5)
            dec = dec + self.pos_embed[:, :dec_tokens.size(1), :]
            mask = self._causal_mask(dec_tokens.size(1)).to(dec.device)
            h = self.decoder(dec, memory, tgt_mask=mask)
            logits = self.output(h)[:, -1, :]              # (B, vocab)
            logp = F.log_softmax(logits, dim=-1)
            next_tokens = torch.multinomial(logp.exp(), 1) # (B, 1)
            next_logp = logp.gather(1, next_tokens).squeeze(1)

            logp_total += next_logp * (~finished)
            dec_tokens = torch.cat([dec_tokens, next_tokens], dim=1)
            finished |= next_tokens.squeeze(1).eq(eos_token_id)

            if finished.all():
                break

        return dec_tokens, logp_total, lnZ

In [5]:
class GRUDecoder(nn.Module):
    def __init__(self, vocab_size, hidden_size=64, num_layers=1, dropout=0.1, lnZ_shape=()):
        """
        GRU Decoder-only pour LLM_PRIOR_ENV ou LLM_ADVERSERIAL (√©quivalent lightweight √† LLMDecoder).
        Plug-and-play : m√™mes interfaces forward/log_prob/generate.
        """
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.gru = nn.GRU(
            hidden_size, hidden_size, num_layers=num_layers,
            batch_first=True, dropout=dropout if num_layers > 1 else 0
        )
        self.output = nn.Linear(hidden_size, vocab_size)
        self.lnZ = nn.Parameter(torch.zeros(lnZ_shape, requires_grad=True))
        self.hidden_size = hidden_size

    def forward(self, input_ids):
        """
        input_ids: (B, T)
        Retourne: logits (B, T, vocab_size)
        """
        B, T = input_ids.shape
        x = self.embed(input_ids)
        out, _ = self.gru(x)
        return self.output(out)

    def log_prob(self, tokens):
        """
        Calcule la log-probabilit√© totale d'une s√©quence donn√©e (off-policy).
        tokens: (B, T)
        Retourne: (logp_total (B,), lnZ)
        """
        B, T = tokens.shape
        logits = self.forward(tokens[:, :-1])                 # (B, T-1, vocab)
        logp = F.log_softmax(logits, dim=-1)
        next_tokens = tokens[:, 1:].unsqueeze(-1)
        token_logp = logp.gather(2, next_tokens).squeeze(-1)  # (B, T-1)
        logp_total = token_logp.sum(dim=1)                    # (B,)
        return logp_total, self.lnZ

    def generate(self, batch_size, max_len, bos_token_id, eos_token_id, device):
        """
        G√©n√©ration batch√©e diff√©rentiable sur logp(seq).
        Retourne: tokens (B, T_gen), logp_total (B,), lnZ
        """
        tokens = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        logp_total = torch.zeros(batch_size, device=device)
        h = None  # Hidden initialis√© √† z√©ro par GRU

        for _ in range(max_len):
            inp = tokens[:, -1:]  # (B, 1)
            x = self.embed(inp)
            out, h = self.gru(x, h)
            logits = self.output(out[:, -1, :])               # (B, vocab)
            logp = F.log_softmax(logits, dim=-1)

            next_tokens = torch.multinomial(logp.exp(), 1)    # (B, 1)
            next_logp = logp.gather(1, next_tokens).squeeze(1)

            logp_total += next_logp * (~finished)
            tokens = torch.cat([tokens, next_tokens], dim=1)

            finished |= next_tokens.squeeze(1).eq(eos_token_id)
            if finished.all():
                break

        return tokens, logp_total, self.lnZ
    

class GRUSeq2SeqWithCritic(nn.Module):
    def __init__(self, vocab_size, hidden_size=64, num_layers=1, dropout=0.1):
        """
        GRU Seq2Seq avec critic (√©quivalent lightweight √† Seq2SeqWithCritic).
        Plug-and-play : m√™mes interfaces forward/log_prob/generate.
        Pas de positional encoding (r√©current g√®re l'ordre).
        """
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.encoder_gru = nn.GRU(
            hidden_size, hidden_size, num_layers=num_layers,
            batch_first=True, dropout=dropout if num_layers > 1 else 0
        )
        self.decoder_gru = nn.GRU(
            hidden_size, hidden_size, num_layers=num_layers,
            batch_first=True, dropout=dropout if num_layers > 1 else 0
        )
        self.output = nn.Linear(hidden_size, vocab_size)
        self.critic = nn.Linear(hidden_size, 1)
        self.hidden_size = hidden_size

    def forward(self, encoder_input_ids, decoder_input_ids):
        """
        Args:
            encoder_input_ids: (B, T_enc)
            decoder_input_ids: (B, T_dec)
        Returns:
            logits: (B, T_dec, vocab)
            lnZ: (B,)
        """
        B, T_enc = encoder_input_ids.shape
        B, T_dec = decoder_input_ids.shape

        # Encoder
        enc_x = self.embed(encoder_input_ids)
        enc_out, enc_h = self.encoder_gru(enc_x)
        memory_mean = enc_out.mean(dim=1)                     # (B, hidden)
        lnZ = self.critic(memory_mean).squeeze(1)             # (B,)

        # Decoder (teacher-forced)
        dec_x = self.embed(decoder_input_ids)
        dec_out, _ = self.decoder_gru(dec_x, enc_h)
        logits = self.output(dec_out)                         # (B, T_dec, vocab)
        return logits, lnZ

    def log_prob(self, encoder_input_ids, full_target_ids):
        """
        Calcule la log-prob totale de la s√©quence cible donn√©e.
        """
        decoder_input_ids = full_target_ids[:, :-1]
        target_ids = full_target_ids[:, 1:]
        
        logits, lnZ = self.forward(encoder_input_ids, decoder_input_ids)
        logp = F.log_softmax(logits, dim=-1)
        next_tokens = target_ids.unsqueeze(-1)
        token_logp = logp.gather(2, next_tokens).squeeze(-1)
        logp_total = token_logp.sum(dim=1)
        return logp_total, lnZ

    def generate(self, encoder_input_ids, max_len, bos_token_id, eos_token_id, device):
        """
        G√©n√©ration batch√©e autoregressive.
        Retourne: tokens (B, T_gen), logp_total (B,), lnZ (B,)
        """
        B = encoder_input_ids.size(0)
        dec_tokens = torch.full((B, 1), bos_token_id, dtype=torch.long, device=device)
        finished = torch.zeros(B, dtype=torch.bool, device=device)
        logp_total = torch.zeros(B, device=device)

        # Encoder une fois
        enc_x = self.embed(encoder_input_ids)
        enc_out, enc_h = self.encoder_gru(enc_x)
        memory_mean = enc_out.mean(dim=1)
        lnZ = self.critic(memory_mean).squeeze(1)

        dec_h = enc_h  # Init decoder avec encoder hidden

        for _ in range(max_len):
            inp = dec_tokens[:, -1:]  # (B, 1)
            x = self.embed(inp)
            dec_out, dec_h = self.decoder_gru(x, dec_h)
            logits = self.output(dec_out[:, -1, :])           # (B, vocab)
            logp = F.log_softmax(logits, dim=-1)
            next_tokens = torch.multinomial(logp.exp(), 1)
            next_logp = logp.gather(1, next_tokens).squeeze(1)

            logp_total += next_logp * (~finished)
            dec_tokens = torch.cat([dec_tokens, next_tokens], dim=1)
            finished |= next_tokens.squeeze(1).eq(eos_token_id)

            if finished.all():
                break

        return dec_tokens, logp_total, lnZ

In [6]:
class AdditionDataset(Dataset):
    def __init__(self, file_path="addition_dataset_train.pkl"):
        with open(file_path, "rb") as f:
            self.data = pickle.load(f)["data"]

    def encode(self, text, max_len, add_bos_eos=False):
        ids = [char2idx[ch] for ch in text]
        if add_bos_eos:
            ids = [char2idx[BOS]] + ids + [char2idx[EOS]]
        ids = ids[:max_len]
        ids += [char2idx[PAD]] * (max_len - len(ids))
        return torch.tensor(ids, dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        inp, tgt = self.data[idx]
        x = self.encode(inp, MAX_LEN_INP, add_bos_eos=True)
        y = self.encode(tgt, MAX_LEN_OUT, add_bos_eos=True)
        return x, y

In [7]:
def Bayesian_information_offpolicy(seq2seq_model, llm_model, encoder_input_ids, decoder_input_ids, target_ids):
    """
    Calcule KL(P_seq2seq || P_llm) et entropies normalis√©es (off-policy).
    Args:
        encoder_input_ids: (B, T_enc)
        decoder_input_ids: (B, T_dec)
        target_ids: (B, T_dec)
    Returns:
        kl: ()
        H_norm: ()
    """
    # log-probs (B,)
    logp_seq2seq, _ = seq2seq_model.log_prob(encoder_input_ids, target_ids)
    logp_llm, _ = llm_model.log_prob(decoder_input_ids)

    # KL et entropies
    kl = (logp_seq2seq - logp_llm).mean()          # ()
    H_seq2seq = -(logp_seq2seq).mean()             # ()
    H_llm = -(logp_llm).mean()                     # ()

    # Normalisation
    T_norm = torch.log(torch.tensor(decoder_input_ids.size(1), dtype=torch.float32, device=logp_seq2seq.device))
    H_norm = (H_seq2seq + H_llm) / (2 * T_norm)    # ()

    return kl, H_norm

In [8]:
def Bayesian_information_onpolicy(seq2seq_model, llm_model, bos_token_id, eos_token_id, max_len, device, B=16):
    """
    KL(P_seq2seq || P_llm) et entropies normalis√©es (on-policy)
    """
    # G√©n√©ration par Seq2Seq
    with torch.no_grad():
        encoder_input_ids = torch.randint(3, 50, (B, 8), device=device)  # dummy input encoder
    seq2seq_tokens, logp_seq2seq, _ = seq2seq_model.generate(
        encoder_input_ids, max_len, bos_token_id, eos_token_id, device
    )

    # √âvaluation du LLM sur les s√©quences g√©n√©r√©es
    logp_llm, _ = llm_model.log_prob(seq2seq_tokens)

    # KL divergence
    kl = (logp_seq2seq - logp_llm).mean()

    # Entropies brutes
    H_seq2seq = -(logp_seq2seq).mean()
    H_llm = -(logp_llm).mean()

    # Normalisation par log(longueur moyenne des s√©quences g√©n√©r√©es)
    avg_len = seq2seq_tokens.size(1)
    T_norm = torch.log(torch.tensor(avg_len, dtype=torch.float32, device=logp_seq2seq.device))
    H_norm = (H_seq2seq + H_llm) / (2 * T_norm)

    return kl, H_norm

In [9]:
SPECIAL = ["<pad>", "<bos>", "<eos>"] # pad for padding, bos for beginning of sequence, eos for end of sequence
BASE_CHARS = list("0123456789+= ")
PAD, BOS, EOS = SPECIAL
VOCAB = SPECIAL + BASE_CHARS
VOCAB_SIZE = len(VOCAB)

char2idx = {ch: i for i, ch in enumerate(VOCAB)}
idx2char = {i: ch for ch, i in char2idx.items()}

MAX_LEN_INP = 13   # max length input ("aaa + bbb =<EOS>") for llm and 
MAX_LEN_OUT = 6    # max length output ("xxx")
DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

BATCH_SIZE = 128
EPOCHS = 1
LR = 1e-4

In [10]:
train_dataset = AdditionDataset("addition_dataset_train.pkl")
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [11]:
# Models
llm_prior = GRUDecoder(vocab_size=VOCAB_SIZE).to(DEVICE)
seq2seq = GRUSeq2SeqWithCritic(vocab_size=VOCAB_SIZE).to(DEVICE)
seq2seq.eval()  # frozen for phase 1

optimizer = optim.Adam(llm_prior.parameters(), lr=LR)

# Loss (Trajectory Balance)
def loss_LLM_PRIOR_ENV(out_LLM_PRIOR_ENV, out_Seq2Seq):
    logp_LLM_PRIOR_ENV, lnZ_LLM_PRIOR_ENV = out_LLM_PRIOR_ENV
    logp_Seq2Seq_detached = out_Seq2Seq.detach()
    reward = logp_LLM_PRIOR_ENV + logp_Seq2Seq_detached
    loss = ((reward - lnZ_LLM_PRIOR_ENV + logp_LLM_PRIOR_ENV) ** 2).mean()
    return loss


# Training loop
for epoch in range(EPOCHS):
    llm_prior.train()
    running_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for batch in pbar:
        x_in, y_out = [b.to(DEVICE) for b in batch]

        # Compute log-probabilities
        logp_prior, lnZ_prior = llm_prior.log_prob(x_in)  # (B,), scalar
        with torch.no_grad():
            logp_seq2seq, _ = seq2seq.log_prob(
                encoder_input_ids=x_in,
                full_target_ids=x_in[:, 1:]
            )

        # Loss + backprop
        loss = loss_LLM_PRIOR_ENV((logp_prior, lnZ_prior), logp_seq2seq)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(llm_prior.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()
        pbar.set_postfix({"loss": running_loss / (pbar.n + 1)})

    print(f"‚úÖ Epoch {epoch+1}/{EPOCHS} ‚Äî mean loss: {running_loss / len(train_loader):.6f}")


Epoch 1/1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 78/78 [00:01<00:00, 43.87it/s, loss=9.11e+3]

‚úÖ Epoch 1/1 ‚Äî mean loss: 8527.744022





In [12]:
# ======================================
# PHASE 2: LLM ADVERSARIAL TRAINING
# ======================================

llm_adv = LLMDecoder(vocab_size=VOCAB_SIZE).to(DEVICE)
optimizer_adv = optim.Adam(llm_adv.parameters(), lr=LR)

# Adversarial loss
def loss_LLM_ADVERSARIAL(out_LLM_PRIOR_ENV, out_LLM_ADVERSARIAL, out_Seq2Seq):
    logp_prior_env, lnZ_prior_env = out_LLM_PRIOR_ENV
    logp_adversarial, lnZ_adversarial = out_LLM_ADVERSARIAL
    logp_seq2seq, _ = out_Seq2Seq

    logp_prior_env_det = logp_prior_env.detach()
    logp_seq2seq_det = logp_seq2seq.detach()

    loss = (
        logp_prior_env_det
        + logp_seq2seq_det
        + logp_adversarial
        - lnZ_adversarial
        + logp_adversarial
    ).pow(-2).mean()

    return loss


for epoch in range(EPOCHS):
    llm_adv.train()
    running_loss = 0.0
    pbar = tqdm(train_loader, desc=f"[Adversarial] Epoch {epoch+1}/{EPOCHS}")

    for batch in pbar:
        x_in, y_out = [b.to(DEVICE) for b in batch]

        # Forward: prior, adversarial, seq2seq
        logp_prior, lnZ_prior = llm_prior.log_prob(x_in)
        logp_adv, lnZ_adv = llm_adv.log_prob(x_in)

        with torch.no_grad():
            logp_seq2seq, _ = seq2seq.log_prob(
                encoder_input_ids=x_in,
                full_target_ids=x_in[:, 1:]
            )

        loss = loss_LLM_ADVERSARIAL(
            (logp_prior, lnZ_prior),
            (logp_adv, lnZ_adv),
            (logp_seq2seq, None)
        )

        optimizer_adv.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(llm_adv.parameters(), 1.0)
        optimizer_adv.step()

        running_loss += loss.item()
        pbar.set_postfix({"loss": running_loss / (pbar.n + 1)})

    print(f"‚úÖ [Adversarial] Epoch {epoch+1} ‚Äî mean loss: {running_loss / len(train_loader):.6f}")

torch.save(llm_adv.state_dict(), "llm_adversarial_offpolicy.pt")
print("üíæ Saved model to llm_adversarial_offpolicy.pt")

[Adversarial] Epoch 1/1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 78/78 [00:01<00:00, 40.71it/s, loss=6.05e-5] 

‚úÖ [Adversarial] Epoch 1 ‚Äî mean loss: 0.000058
üíæ Saved model to llm_adversarial_offpolicy.pt





In [13]:
# ======================================
# PHASE 3: SEQ-TO-SEQ REINFORCEMENT TRAINING
# ======================================

optimizer_seq2seq = optim.Adam(seq2seq.parameters(), lr=LR)

def loss_LLM_REINFORCEMENT(out_LLM_PRIOR_ENV, out_LLM_ADVERSARIAL, out_Seq2Seq):
    logp_prior_env, _ = out_LLM_PRIOR_ENV
    logp_adversarial, _ = out_LLM_ADVERSARIAL
    logp_seq2seq, lnZ_seq2seq = out_Seq2Seq

    logp_prior_env_det = logp_prior_env.detach()
    logp_adversarial_det = logp_adversarial.detach()

    loss = (
        logp_prior_env_det
        + logp_seq2seq
        + logp_adversarial_det
        - lnZ_seq2seq
        + logp_seq2seq
    ).pow(2).mean()

    return loss


for epoch in range(EPOCHS):
    seq2seq.train()
    running_loss = 0.0
    pbar = tqdm(train_loader, desc=f"[Reinforce] Epoch {epoch+1}/{EPOCHS}")

    for batch in pbar:
        x_in, y_out = [b.to(DEVICE) for b in batch]

        # Forward passes
        logp_prior, lnZ_prior = llm_prior.log_prob(x_in)
        logp_adv, lnZ_adv = llm_adv.log_prob(x_in)
        logp_seq2seq, lnZ_seq2seq = seq2seq.log_prob(
            encoder_input_ids=x_in[:, :-1],
            full_target_ids=x_in[:, 1:]
        )

        loss = loss_LLM_REINFORCEMENT(
            (logp_prior, lnZ_prior),
            (logp_adv, lnZ_adv),
            (logp_seq2seq, lnZ_seq2seq)
        )

        optimizer_seq2seq.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(seq2seq.parameters(), 1.0)
        optimizer_seq2seq.step()

        running_loss += loss.item()
        pbar.set_postfix({"loss": running_loss / (pbar.n + 1)})

    print(f"‚úÖ [Reinforce] Epoch {epoch+1} ‚Äî mean loss: {running_loss / len(train_loader):.6f}")

torch.save(seq2seq.state_dict(), "seq2seq_reinforce.pt")
print("üíæ Saved model to seq2seq_reinforce.pt")

[Reinforce] Epoch 1/1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 78/78 [00:01<00:00, 39.61it/s, loss=1.53e+4]


‚úÖ [Reinforce] Epoch 1 ‚Äî mean loss: 14912.064328
üíæ Saved model to seq2seq_reinforce.pt


In [14]:
# ======================================================
# CHECKPOINT MANAGER
# ======================================================

def save_checkpoint(epoch, llm_prior, llm_adv, seq2seq, opt_prior, opt_adv, opt_seq2seq, metrics, path="checkpoints"):
    os.makedirs(path, exist_ok=True)
    ckpt = {
        "epoch": epoch,
        "llm_prior": llm_prior.state_dict(),
        "llm_adv": llm_adv.state_dict(),
        "seq2seq": seq2seq.state_dict(),
        "opt_prior": opt_prior.state_dict(),
        "opt_adv": opt_adv.state_dict(),
        "opt_seq2seq": opt_seq2seq.state_dict(),
        "metrics": metrics
    }
    torch.save(ckpt, os.path.join(path, f"epoch_{epoch}.pt"))
    # Save metrics log
    with open(os.path.join(path, "log.json"), "a") as f:
        json.dump({"epoch": epoch, **metrics}, f)
        f.write("\n")
    print(f"üíæ Saved checkpoint for epoch {epoch}")

def load_latest_checkpoint(llm_prior, llm_adv, seq2seq, opt_prior, opt_adv, opt_seq2seq, path="checkpoints"):
    if not os.path.exists(path):
        return 0
    ckpts = [f for f in os.listdir(path) if f.endswith(".pt")]
    if not ckpts:
        return 0
    ckpts.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
    last_ckpt = os.path.join(path, ckpts[-1])
    data = torch.load(last_ckpt, map_location=DEVICE)
    llm_prior.load_state_dict(data["llm_prior"])
    llm_adv.load_state_dict(data["llm_adv"])
    seq2seq.load_state_dict(data["seq2seq"])
    opt_prior.load_state_dict(data["opt_prior"])
    opt_adv.load_state_dict(data["opt_adv"])
    opt_seq2seq.load_state_dict(data["opt_seq2seq"])
    print(f"üîÅ Resumed from checkpoint: {ckpts[-1]}")
    return data["epoch"] + 1

In [15]:
def entropy_from_logp(logp, normalize=True):
    """
    Approximate the (normalized) entropy from log-probabilities of sequences.
    
    Args:
        logp: torch.Tensor (B,) 
            Log-probabilities of each sequence in the batch (log p(seq))
        normalize: bool
            If True, normalize entropy by log(batch_size)
    
    Returns:
        entropy: torch.Tensor (scalar)
            Estimated (normalized) entropy value
    """
    # Convert to probability (avoid overflow by clamping)
    p = torch.exp(logp - logp.max())  # numerical stability
    p = p / p.sum(dim=0, keepdim=True)  # normalize to form a distribution
    
    # Entropy = - Œ£ p * log p
    H = -torch.sum(p * torch.log(p + 1e-12))
    
    # Normalize by log(batch_size) (so entropy ‚àà [0, 1])
    if normalize:
        H = H / torch.log(torch.tensor(float(len(logp)), device=logp.device) + 1e-9)
    
    return H

In [16]:
import wandb

SAVE_EVERY = 2
# Initialize Weights & Biases
wandb.init(
    project="GFlowNet-Seq2Seq-Bayesian-Coherence",
    name=f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{DEVICE}",
    config={
        "epochs": EPOCHS,
        "lr": 3e-4,
        "batch_size": train_loader.batch_size,
        "device": DEVICE,
        "save_every": SAVE_EVERY
    }
)

[34m[1mwandb[0m: Currently logged in as: [33marthurmaffre[0m ([33marthurmaffre-alone[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [17]:
def log_and_plot_wandb(metrics, epoch):
    wandb.log({
        "Loss/L_prior": metrics["L_prior"],
        "Loss/L_adv": metrics["L_adv"],
        "Loss/L_reinf": metrics["L_reinf"],
        "Entropy/H_adv": metrics["H_adv"],
        "Entropy/H_seq": metrics["H_seq"],
        # "KL/KL_prior_posterior": metrics["KL_prior_posterior"],
        "epoch": epoch
    })

In [18]:
def kl_prior_posterior(train_loader, llm_prior, seq2seq, device):
    """
    Approximate KL divergence between prior p_prior and posterior p_posterior ‚àù p_prior * p_seq2seq.
    """
    kl_values = []
    for batch in train_loader:
        x_in, _ = [b.to(device) for b in batch]
        with torch.no_grad():
            logp_prior, _ = llm_prior.log_prob(x_in)
            logp_seq2seq, _ = seq2seq.log_prob(
                encoder_input_ids=x_in,
                decoder_input_ids=x_in[:, :-1],
                target_ids=x_in[:, 1:]
            )
        # Posterior (unnormalized)
        logp_post = logp_prior + logp_seq2seq
        # KL(p_post || p_prior) ‚âà E[log p_post - log p_prior] = E[log p_seq2seq]
        kl = (logp_post - logp_prior).mean().item()
        kl_values.append(kl)
    return float(torch.tensor(kl_values).mean())

In [19]:
# ======================================================
# TOKEN DECODER UTILITY
# ======================================================

def decode(tokens):
    """
    Convertit un tenseur de tokens en texte lisible.
    Exemple : [<bos>, '1', '2', '+', '3', '=', <eos>] ‚Üí '12+3='
    """
    texts = []
    for seq in tokens:
        text = "".join(
            idx2char[int(i)]
            for i in seq
            if idx2char[int(i)] not in ["<pad>"]
        )
        texts.append(
            text.replace("<bos>", "").replace("<eos>", "").strip()
        )
    return texts


def train_epoch_three_stages(
    train_loader,
    llm_prior, llm_adv, seq2seq,
    opt_seq2seq_phase2, opt_adv, opt_seq2seq_phase4, opt_prior_phase1,
    device,
    dataset=None
):
    """
    Effectue un cycle complet d'entra√Ænement :
    1. Flow Matching (Prior)
    2. Seq2Seq Reinforcement (on-policy)
    3. Adversarial Divergence
    4. Off-Policy Seq2Seq (pond√©r√© par le Prior)
    """

    metrics = {}
    
    # ======================================================
    # 4Ô∏è‚É£ PHASE OFF-POLICY ‚Äî Seq2Seq training on real data weighted by Prior
    # ======================================================
    seq2seq.train()
    llm_prior.eval()
    llm_adv.eval()

    total_loss, total_H = 0.0, 0.0

    total_loss = 0.0
    num_passes = 1  # üëà combien de fois tu veux repasser sur le dataset

    for pass_id in range(num_passes):
        pbar = tqdm(train_loader, desc=f"Phase 4 ‚Äî OFF-POLICY pass {pass_id+1}/{num_passes}")
        for batch in pbar:
            x_in, y_out = [b.to(device) for b in batch]

            # Compute prior probability on inputs (off-policy weighting)
            with torch.no_grad():
                logp_prior, _ = llm_prior.log_prob(x_in)

            # Seq2Seq log-prob under its own model
            logp_seq2seq, _ = seq2seq.log_prob(
                encoder_input_ids=x_in,
                full_target_ids=y_out
            )

            # Weighted loss (off-policy correction)
            loss_4 = - (logp_seq2seq).mean()  # weighted NLL

            opt_seq2seq_phase4.zero_grad()
            loss_4.backward()
            torch.nn.utils.clip_grad_norm_(seq2seq.parameters(), 1.0)
            opt_seq2seq_phase4.step()

            total_loss += loss_4.item()
            pbar.set_postfix({"L_offpolicy": total_loss / (pbar.n + 1)})

    metrics["L_offpolicy"] = total_loss / len(train_loader)
    wandb.log({"L_offpolicy": metrics["L_offpolicy"]})


    # ======================================================
    # 1Ô∏è‚É£ PHASE PRIOR ENV ‚Äî Flow Matching (with adversarial regularization)
    # ======================================================
    llm_prior.train()
    seq2seq.eval()
    llm_adv.eval()

    total_loss, total_H = 0.0, 0.0
    pbar = tqdm(train_loader, desc="Phase 1 ‚Äî LLM PRIOR ENV (flow matching + adv)")

    alpha = 1  # importance du seq2seq reward
    beta_entropy = 1e-3  # r√©gularisation d'entropie

    for batch in pbar:
        x_in, _ = [b.to(device) for b in batch]
        if x_in.shape != torch.Size([128]):
            continue

        # Compute log probs under the prior
        _, logp_prior, lnZ_prior = llm_prior.generate(batch_size=train_loader.batch_size,
            max_len=MAX_LEN_INP,
            bos_token_id=char2idx["<bos>"],
            eos_token_id=char2idx["<eos>"],
            device=device)

        with torch.no_grad():
            # Reward from seq2seq (proxy for p*(y|x))
            logp_seq2seq, _ = seq2seq.log_prob(
                encoder_input_ids=x_in,
                full_target_ids=x_in[:, :-1]
            )
            # Pressure from adversarial model
            logp_adv, _ = llm_adv.log_prob(x_in)

        # --- Flow matching target with adversarial term ---
        energy = -logp_prior -logp_seq2seq.detach()
        loss_1 = (logp_prior + energy + lnZ_prior).pow(2).mean()

        # --- Loss ---

        # --- Optimization ---
        opt_prior_phase1.zero_grad()
        loss_1.backward()
        opt_prior_phase1.step()

        total_loss += loss_1.item()

        pbar.set_postfix({
            "L_prior": total_loss / (pbar.n + 1)
        })

    metrics["L_prior"] = total_loss / len(train_loader)
    metrics["H_prior"] = total_H / len(train_loader)
    wandb.log({
        "L_prior": metrics["L_prior"],
        "H_prior": metrics["H_prior"],
    })



    # ======================================================
    # 2Ô∏è‚É£ PHASE REINFORCEMENT ‚Äî Weighted Flow Matching
    # ======================================================
    seq2seq.train()
    llm_prior.eval()
    llm_adv.eval()

    total_loss, total_H = 0.0, 0.0
    num_batches = len(train_loader)
    pbar = tqdm(range(num_batches), desc="Phase 2 ‚Äî SEQ2SEQ REINFORCE (weighted)")

    beta_entropy = 1e-3

    for _ in pbar:
        tokens_adv, logp_adv, _ = llm_adv.generate(
            batch_size=train_loader.batch_size,
            max_len=MAX_LEN_INP,
            bos_token_id=char2idx["<bos>"],
            eos_token_id=char2idx["<eos>"],
            device=device
        )

        seq_tokens, logp_seq2seq, lnZ_seq2seq = seq2seq.generate(
            encoder_input_ids=tokens_adv,
            max_len=MAX_LEN_OUT,
            bos_token_id=char2idx["<bos>"],
            eos_token_id=char2idx["<eos>"],
            device=device
        )

        with torch.no_grad():
            logp_prior, _ = llm_prior.log_prob(tokens_adv)

        # --- importance weighting ---

        # --- flow-matching loss ---
        
        #weight = logp_prior.detach() - logp_adv.detach()
        reward = logp_prior.detach() + logp_seq2seq

        loss_2 = - reward.mean()

        opt_seq2seq_phase2.zero_grad()
        loss_2.backward()
        torch.nn.utils.clip_grad_norm_(seq2seq.parameters(), 1.0)
        opt_seq2seq_phase2.step()

        total_loss += loss_2.item()
        pbar.set_postfix({
            "L_reinf": total_loss / (pbar.n + 1),
            "H_seq": total_H / (pbar.n + 1)
        })

    metrics["L_reinf"] = total_loss / num_batches
    metrics["H_seq"] = total_H / num_batches
    wandb.log({"L_reinf": metrics["L_reinf"], "H_seq": metrics["H_seq"]})
    
    
    # ======================================================
    # 3Ô∏è‚É£ PHASE ADVERSARIAL ‚Äî Divergence Amplification
    # ======================================================
    llm_adv.train()
    llm_prior.eval()
    seq2seq.eval()

    total_loss, total_H = 0.0, 0.0
    num_batches = len(train_loader)
    pbar = tqdm(range(num_batches), desc="Phase 3 ‚Äî LLM ADVERSARIAL (divergence)")

    for _ in pbar:
        tokens_adv, logp_adv, lnZ_adv = llm_adv.generate(
            batch_size=train_loader.batch_size,
            max_len=MAX_LEN_INP,
            bos_token_id=char2idx["<bos>"],
            eos_token_id=char2idx["<eos>"],
            device=device
        )

        with torch.no_grad():
            logp_prior, lnZ_prior = llm_prior.log_prob(tokens_adv)
            logp_seq2seq, _ = seq2seq.log_prob(
                encoder_input_ids=tokens_adv,
                full_target_ids=tokens_adv[:, 1:]
            )

        # Weighting (importance correction)
        weight = torch.exp(logp_prior.detach() - logp_adv.detach())
        weight = weight / weight.sum()
        # Divergence penalty (breaks coherence)
        divergence_term = - logp_seq2seq.detach() - logp_prior.detach()

        loss_3 = (weight * (lnZ_adv + logp_adv - divergence_term).pow(2)).mean()

        opt_adv.zero_grad()
        loss_3.backward()
        torch.nn.utils.clip_grad_norm_(llm_adv.parameters(), 1.0)
        opt_adv.step()

        total_loss += loss_3.item()
        total_H += entropy_from_logp(logp_adv).item()
        pbar.set_postfix({
            "L_adv": total_loss / (pbar.n + 1),
            "H_adv": total_H / (pbar.n + 1)
        })

    metrics["L_adv"] = total_loss / num_batches
    metrics["H_adv"] = total_H / num_batches
    wandb.log({"L_adv": metrics["L_adv"], "H_adv": metrics["H_adv"]})


    

    # ======================================================
    # 4Ô∏è‚É£ KL Divergence between Prior and Posterior
    # ======================================================
    # kl_post = kl_prior_posterior(train_loader, llm_prior, seq2seq, device)
    # metrics["KL_prior_posterior"] = kl_post
    # wandb.log({"KL_prior_posterior": kl_post})
    # print(f"üåå KL(Prior ‚Äñ Posterior): {kl_post:.6f}")


    # ======================================================
    # 5Ô∏è‚É£ GENERATION EXAMPLE (display + log to W&B)
    # ======================================================
    with torch.no_grad():
        if dataset is not None:
            x, y = dataset[torch.randint(0, len(dataset), (1,)).item()]
            x, y = x.unsqueeze(0).to(device), y.unsqueeze(0).to(device)

            prior_tokens, _, _ = llm_prior.generate(
                batch_size=1,
                max_len=MAX_LEN_INP,
                bos_token_id=char2idx["<bos>"],
                eos_token_id=char2idx["<eos>"],
                device=device
            )
            seq_tokens, _, _ = seq2seq.generate(
                encoder_input_ids=x[:, 1:],
                max_len=MAX_LEN_OUT,
                bos_token_id=char2idx["<bos>"],
                eos_token_id=char2idx["<eos>"],
                device=device
            )
            adv_tokens, _, _ = llm_adv.generate(
                batch_size=1,
                max_len=MAX_LEN_INP,
                bos_token_id=char2idx["<bos>"],
                eos_token_id=char2idx["<eos>"],
                device=device
            )

            # D√©codage en texte
            input_str, target_str = decode(x)[0], decode(y)[0]
            prior_str, seq_str, adv_str = decode(prior_tokens)[0], decode(seq_tokens)[0], decode(adv_tokens)[0]

            print("\nüß† Example Generation:")
            print(f"Input     : {input_str}")
            print(f"Target    : {target_str}")
            print(f"Prior Gen : {prior_str}")
            print(f"Seq2Seq   : {seq_str}")
            print(f"Adv Gen   : {adv_str}")

            wandb.log({
                "Generations/Sample": wandb.Table(
                    columns=["Input", "Target", "Prior", "Seq2Seq", "Adversarial"],
                    data=[[input_str, target_str, prior_str, seq_str, adv_str]],
                )
            })

    return metrics

In [20]:
# ======================================================
# CONFIG
# ======================================================
EPOCHS = 1
SAVE_EVERY = 2
CHECKPOINT_PATH = "checkpoints"
NEW_MODEL = True  # ‚¨ÖÔ∏è change √† False pour reprendre le training
DEVICE = "mps"
SPECIAL = ["<pad>", "<bos>", "<eos>"]
BASE_CHARS = list("0123456789+=")
PAD, BOS, EOS = SPECIAL
VOCAB = SPECIAL + BASE_CHARS
VOCAB_SIZE = len(VOCAB)

# ======================================================
# INITIALIZE / RESET MODELS
# ======================================================
def init_models_and_optimizers(vocab_size):
    llm_prior = LLMDecoder(vocab_size).to(DEVICE)
    llm_adv = LLMDecoder(vocab_size).to(DEVICE)
    seq2seq = Seq2SeqWithCritic(vocab_size).to(DEVICE)

    # --- Optimizers s√©par√©s par phase ---
    opt_prior_phase1 = torch.optim.Adam(llm_prior.parameters(), lr=4e-5)
    opt_adv = torch.optim.Adam(llm_adv.parameters(), lr=4e-5)
    opt_seq2seq_phase2 = torch.optim.Adam(seq2seq.parameters(), lr=4e-5)
    opt_seq2seq_phase4 = torch.optim.Adam(seq2seq.parameters(), lr=4e-5)  # off-policy

    return (
        llm_prior, llm_adv, seq2seq,
        opt_prior_phase1, opt_adv,
        opt_seq2seq_phase2, opt_seq2seq_phase4
    )



In [21]:

# ======================================================
# MAIN TRAIN LOOP
# ======================================================

# (Re)init models
(
    llm_prior, llm_adv, seq2seq,
    opt_prior_phase1, opt_adv,
    opt_seq2seq_phase2, opt_seq2seq_phase4
) = init_models_and_optimizers(VOCAB_SIZE)

if NEW_MODEL:
    print("üöÄ Starting fresh training ‚Äî resetting models and optimizers.")
    start_epoch = 0
    if os.path.exists(CHECKPOINT_PATH):
        for f in os.listdir(CHECKPOINT_PATH):
            os.remove(os.path.join(CHECKPOINT_PATH, f))
else:
    print("üîÅ Attempting to resume from last checkpoint...")
    start_epoch = load_latest_checkpoint(
        llm_prior, llm_adv, seq2seq,
        opt_prior_phase1, opt_adv,
        opt_seq2seq_phase2, opt_seq2seq_phase4,
        path=CHECKPOINT_PATH
    )

for epoch in range(start_epoch, EPOCHS):
    metrics = train_epoch_three_stages(
        train_loader=train_loader,
        llm_prior=llm_prior, llm_adv=llm_adv, seq2seq=seq2seq, opt_adv=opt_adv,
        opt_seq2seq_phase2=opt_seq2seq_phase2, opt_seq2seq_phase4=opt_seq2seq_phase4,
        device=DEVICE, opt_prior_phase1=opt_prior_phase1,
        dataset=train_dataset
    )

    log_and_plot_wandb(metrics, epoch + 1)

    print(
        f"\nEpoch {epoch+1}/{EPOCHS} ‚Äî "
        f"L_prior={metrics['L_prior']:.4f} | "
        f"L_adv={metrics['L_adv']:.4f} | "
        f"L_reinf={metrics['L_reinf']:.4f} | "
        f"L_offpolicy={metrics['L_offpolicy']:.4f} | "
        f"H_adv={metrics['H_adv']:.4f} | "
        f"H_seq={metrics['H_seq']:.4f}"
    )

    if (epoch + 1) % SAVE_EVERY == -1:
        save_checkpoint(
            epoch,
            llm_prior, llm_adv, seq2seq,
            opt_prior_phase1, opt_adv,
            opt_seq2seq_phase2, opt_seq2seq_phase4,
            metrics
        )



üöÄ Starting fresh training ‚Äî resetting models and optimizers.


Phase 4 ‚Äî OFF-POLICY pass 1/1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 78/78 [00:03<00:00, 25.03it/s, L_offpolicy=14.5]
Phase 1 ‚Äî LLM PRIOR ENV (flow matching + adv): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 78/78 [00:00<00:00, 340.40it/s]
Phase 2 ‚Äî SEQ2SEQ REINFORCE (weighted): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 78/78 [00:28<00:00,  2.75it/s, L_reinf=49.3, H_seq=0]
Phase 3 ‚Äî LLM ADVERSARIAL (divergence): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 78/78 [00:17<00:00,  4.47it/s, L_adv=86.4, H_adv=0.423]



üß† Example Generation:
Input     : 69+26=
Target    : 95
Prior Gen : 
Seq2Seq   : +836
Adv Gen   : 9882469

Epoch 1/1 ‚Äî L_prior=0.0000 | L_adv=86.4387 | L_reinf=49.2809 | L_offpolicy=13.8989 | H_adv=0.4227 | H_seq=0.0000


In [22]:
def decode(tokens):
    """Convert tensor of token ids -> readable string."""
    texts = []
    for seq in tokens:
        text = "".join(idx2char[int(i)] for i in seq if idx2char[int(i)] not in ["<pad>"])
        texts.append(text.replace("<bos>", "").replace("<eos>", ""))
    return texts

@torch.no_grad()
def show_generation_examples(llm_prior, llm_adv, seq2seq, dataset, device, n=5):
    llm_prior.eval()
    llm_adv.eval()
    seq2seq.eval()

    print("\n==============================")
    print("üß†  GENERATION EXAMPLES")
    print("==============================")

    for i in range(n):
        x, y = dataset[i]
        x, y = x.unsqueeze(0).to(device), y.unsqueeze(0).to(device)

        # -----------------------------------------
        # LLM PRIOR ENV generation
        # -----------------------------------------
        prior_tokens, logp_prior, lnZ_prior = llm_prior.generate(
            batch_size=1,
            max_len=MAX_LEN_INP,
            bos_token_id=char2idx["<bos>"],
            eos_token_id=char2idx["<eos>"],
            device=device
        )

        # -----------------------------------------
        # Seq2Seq generation (teacher-forced encoder)
        # -----------------------------------------
        encoder_input_ids = x[:, 1:]  # remove <bos> for encoder
        seq_tokens, logp_seq, lnZ_seq = seq2seq.generate(
            encoder_input_ids,
            max_len=MAX_LEN_OUT,
            bos_token_id=char2idx["<bos>"],
            eos_token_id=char2idx["<eos>"],
            device=device
        )

        # -----------------------------------------
        # Optional: LLM ADVERSARIAL generation
        # -----------------------------------------
        adv_tokens, logp_adv, lnZ_adv = llm_adv.generate(
            batch_size=1,
            max_len=MAX_LEN_INP,
            bos_token_id=char2idx["<bos>"],
            eos_token_id=char2idx["<eos>"],
            device=device
        )

        print(f"\n--- Example {i+1} ---")
        print(f"Input      : {decode(x)[0]}")
        print(f"Target     : {decode(y)[0]}")
        print(f"Prior Gen  : {decode(prior_tokens)[0]}")
        print(f"Seq2Seq Gen: {decode(seq_tokens)[0]}")
        print(f"Adv Gen    : {decode(adv_tokens)[0]}")
        print(f"logp_prior={logp_prior.item():.2f}, logp_seq={logp_seq.item():.2f}, logp_adv={logp_adv.item():.2f}")
        print(f"lnZ_prior={lnZ_prior.item():.2f}, lnZ_seq={lnZ_seq.item():.2f}, lnZ_adv={lnZ_adv.item():.2f}")


show_generation_examples(llm_prior, llm_adv, seq2seq, dataset=train_dataset, device=DEVICE, n=5)


üß†  GENERATION EXAMPLES

--- Example 1 ---
Input      : 37+71=
Target     : 108
Prior Gen  : 
Seq2Seq Gen: 815310
Adv Gen    : =32+89=1
logp_prior=-1.54, logp_seq=-13.89, logp_adv=-34.36
lnZ_prior=0.00, lnZ_seq=-0.35, lnZ_adv=0.00

--- Example 2 ---
Input      : 66+72=
Target     : 138
Prior Gen  : 
Seq2Seq Gen: 905600
Adv Gen    : 9=3+
logp_prior=-6.68, logp_seq=-15.87, logp_adv=-19.06
lnZ_prior=0.00, lnZ_seq=-0.32, lnZ_adv=0.00

--- Example 3 ---
Input      : 72+61=
Target     : 133
Prior Gen  : 3
Seq2Seq Gen: 4+31=
Adv Gen    : 01843639188
logp_prior=-6.36, logp_seq=-14.48, logp_adv=-32.09
lnZ_prior=0.00, lnZ_seq=-0.32, lnZ_adv=0.00

--- Example 4 ---
Input      : 7+60=
Target     : 67
Prior Gen  : 758+805+549
Seq2Seq Gen: 8354
Adv Gen    : 
logp_prior=-35.36, logp_seq=-13.54, logp_adv=-3.48
lnZ_prior=0.00, lnZ_seq=-0.26, lnZ_adv=0.00

--- Example 5 ---
Input      : 37+79=
Target     : 116
Prior Gen  : 42++1313=0
Seq2Seq Gen: =213+6
Adv Gen    : =
logp_prior=-34.39, logp_seq=-16.