# Expérience 5.3 (papier) — *Decoder frozen* + comparaison des familles de postérieurs

Ici, on reproduit l’expérience 5.3 du papier : l’idée est de **geler le décodeur** pour que le vrai postérieur $p_\theta(z|x)$ reste identique d’une famille à l’autre, et ainsi comparer uniquement l’effet de la famille variationnelle $q(z|x)$.

Concrètement, on fait :
1) On entraîne un VAE de base (ici avec $q$ Gaussien), puis on **sauvegarde le décodeur**.  
2) On recharge ce même décodeur (figé) et on entraîne seulement des **encodeurs “small”** avec différentes familles :  
- Gaussien (FFG),  
- Flow,  
- ContextualFlow.  

Ensuite, sur un **subset fixe** (par ex. 100 points de test), on calcule :
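- $\log \hat p(x)$ (IWAE, et optionnellement AIS ; les deux sont des bornes inférieures stochastiques de $\log p(x)$, donc on garde la plus grande des deux estimations),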
- $\mathcal{L}[q]$ (ELBO amorti : encodeur),
- $\mathcal{L}[q^*]$ (ELBO “local” : on optimise $q$ pour chaque point).

Et on en déduit les gaps :
- approximation gap : $\log \hat p(x) - \mathcal{L}[q^*]$  
- amortization gap : $\mathcal{L}[q^*] - \mathcal{L}[q]$  
- inference gap : $\log \hat p(x) - \mathcal{L}[q]$

Question étudiée : à décodeur strictement identique, est-ce qu’une famille de postérieurs plus flexible (Flow / ContextualFlow) réduit **seulement** l’approximation gap, ou bien aussi l’amortization gap (c’est-à-dire que l’encodeur amorti se rapproche davantage de l’optimum local $q^*$, donc généralise mieux) ?

# Import


In [1]:
import sys, numpy as np, csv
from pathlib import Path
import torch
import torch.nn.functional as F
import torch.optim as optim

base_dir = Path.cwd().parent
sys.path.insert(0, str(base_dir / 'models'))
sys.path.insert(0, str(base_dir / 'models' / 'utils'))

from vae_2 import VAE
from inference_net import standard
from distributions import Gaussian, Flow, ContextualFlow
from optimize_local_q import optimize_local_q_dist
from ais3 import test_ais

# Dataset & Device


In [None]:
# Load the binarized MNIST splits. np.load on an .npz archive returns an
# NpzFile that keeps the file handle open; using it as a context manager
# closes the handle once the three arrays have been read into memory.
with np.load('binarized_mnist.npz') as data:
    train_x, valid_x, test_x = data['train_data'], data['valid_data'], data['test_data']
print(f"Train {train_x.shape}, Valid {valid_x.shape}, Test {test_x.shape}")

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print(f"Device actif : {device}")

# Définitions des fonctions utiles
- d’entraînement
- d’évaluation des gaps

In [None]:
@torch.no_grad()
def estimate_logp_IWAE(model, X, K=5000, batch_size=64):
    """Estimate the dataset-average log p(x) over ``X`` with the IWAE(K) bound.

    Args:
        model: VAE exposing ``forward2(x, k=K)``; its first output is taken as
            the IWAE estimate for the batch (assumed to be the batch mean —
            TODO confirm against ``vae_2.VAE.forward2``).
        X: numpy array of datapoints, one row per example.
        K: number of importance samples per datapoint.
        batch_size: number of datapoints per forward pass.

    Returns:
        The per-datapoint average as a Python float, or ``nan`` if ``X`` is empty.
    """
    loader = torch.utils.data.DataLoader(torch.from_numpy(X).float(),
                                         batch_size=batch_size, shuffle=False)
    total, count = 0.0, 0
    for xb in loader:
        xb = xb.to(device)
        val, _, _ = model.forward2(xb, k=K)
        # Weight each batch by its size: the last batch is usually smaller, and
        # an unweighted mean of batch means would over-weight its datapoints.
        n = xb.shape[0]
        total += val.item() * n
        count += n
    return total / count if count else float('nan')

def estimate_logp_AIS(model, X, K=50, n_intermediate=500, batch_size=200):
    """Estimate the dataset-average log p(x) over ``X`` with AIS, batch by batch.

    Args:
        model: VAE passed to ``test_ais``; it is switched to eval mode.
        X: datapoints, one row per example (numpy array or torch tensor).
        K: number of AIS chains, forwarded as ``k``.
        n_intermediate: number of intermediate distributions.
        batch_size: number of datapoints per ``test_ais`` call.

    Returns:
        The size-weighted average over the batches that succeeded, as a float,
        or ``np.nan`` if no batch succeeded (or ``X`` is empty).
    """
    model.eval()
    total, count = 0.0, 0
    for start in range(0, len(X), batch_size):
        xb = X[start:start + batch_size]
        if not isinstance(xb, np.ndarray):
            xb = xb.detach().cpu().numpy()
        try:
            est = test_ais(model=model, data_x=xb, batch_size=xb.shape[0],
                           display=0, k=K, n_intermediate_dists=n_intermediate)
            if torch.is_tensor(est):
                est = est.item()
            est = float(est)
        except Exception as e:
            # Deliberate best effort: a failing batch is reported and skipped
            # so that one AIS failure does not abort the whole experiment.
            print(f"⚠️ AIS batch {start//batch_size}: {e}")
            continue
        # Presumably `est` is the batch-mean log p(x) — TODO confirm in ais3.
        # Weight by batch size so a short final batch is not over-weighted.
        n = xb.shape[0]
        total += est * n
        count += n
    return total / count if count else np.nan

def estimate_logp(model, X, method="IWAE", K_iwae=5000, K_ais=50, n_intermediate=500):
    """Estimate the average log p(x) over ``X`` with IWAE, AIS, or both.

    Args:
        model: VAE to evaluate.
        X: numpy array of datapoints, one row per example.
        method: ``"IWAE"``, ``"AIS"`` or ``"BOTH"``.
        K_iwae: importance samples for IWAE.
        K_ais: number of AIS chains.
        n_intermediate: number of AIS intermediate distributions.

    Returns:
        A float for ``"IWAE"``/``"AIS"``; a dict ``{"IWAE": float, "AIS": float}``
        for ``"BOTH"``.

    Raises:
        ValueError: if ``method`` is not one of the three supported names
            (previously this silently returned ``None``).
    """
    if method == "IWAE":
        return estimate_logp_IWAE(model, X, K_iwae)
    if method == "AIS":
        return estimate_logp_AIS(model, X, K_ais, n_intermediate)
    if method == "BOTH":
        return {
            "IWAE": estimate_logp_IWAE(model, X, K_iwae),
            "AIS": estimate_logp_AIS(model, X, K_ais, n_intermediate)
        }
    raise ValueError(
        f"Unknown log p(x) estimation method {method!r}; expected 'IWAE', 'AIS' or 'BOTH'."
    )

@torch.no_grad()
def amortized_elbo(model, X, k=1, batch_size=64):
    """Compute L[q_phi], the amortized ELBO averaged over the rows of ``X``.

    Runs the model's standard forward pass with ``warmup=1.0`` so that the
    ELBO is evaluated without any warm-up down-weighting.

    Args:
        model: VAE exposing ``forward(x, k=..., warmup=...)``; its first output
            is taken as the batch ELBO (assumed to be a batch mean — TODO confirm).
        X: numpy array of datapoints, one row per example.
        k: number of samples passed to ``forward``.
        batch_size: number of datapoints per forward pass.

    Returns:
        The per-datapoint average as a Python float, or ``nan`` if ``X`` is empty.
    """
    loader = torch.utils.data.DataLoader(torch.from_numpy(X).float(),
                                         batch_size=batch_size, shuffle=False)
    total, count = 0.0, 0
    for xb in loader:
        xb = xb.to(device)
        elbo, _, _ = model.forward(xb, k=k, warmup=1.0)
        # Weight by batch size so a smaller final batch is not over-weighted.
        n = xb.shape[0]
        total += elbo.item() * n
        count += n
    return total / count if count else float('nan')

def locally_optimized_elbo(model, X, n_points=10):
    """Compute L[q*], the ELBO of a per-datapoint optimized q, averaged over points.

    For each of the first ``n_points`` rows of ``X``, a fresh distribution of
    the same class as ``model.hyper_params['q']`` is created, initialized from
    that object's parameters when their shapes match, and optimized by
    ``optimize_local_q_dist`` against ``model.logposterior_func2`` for that
    single datapoint.

    Args:
        model: VAE exposing ``hyper_params['q']`` and ``logposterior_func2``.
        X: numpy array of datapoints, one row per example.
        n_points: maximum number of datapoints to optimize (this is expensive).

    Returns:
        The mean optimized ELBO as a float, or ``nan`` when no point is evaluated.
    """
    n = min(n_points, len(X))
    if n <= 0:
        return float('nan')

    q_template = model.hyper_params['q']
    q_class = type(q_template)
    vals = []
    for i in range(n):
        x = torch.from_numpy(X[i]).float().view(1, -1).to(device)

        def logpost(z, x=x):
            # Unnormalized log posterior of this single datapoint.
            return model.logposterior_func2(x=x, z=z)

        q_local = q_class(model.hyper_params).to(device)
        try:
            q_local.load_state_dict(q_template.state_dict(), strict=False)
        except RuntimeError:
            # strict=False tolerates missing/unexpected keys but still raises on
            # shape mismatches; in that case keep q_local's fresh initialization.
            pass
        vae_star, _ = optimize_local_q_dist(logpost, model.hyper_params, x, q_local)
        vals.append(float(vae_star.item()))
    return float(np.mean(vals))

def train_base_vae_ffg(train_x, max_epochs=500, batch_size=64, z_size=50,
                       warmup_T=100.0,
                       decoder_path="decoder_base_qFFG.pt",
                       encoder_path="encoder_base_qFFG.pt"):
    """Train the base VAE with a fully-factorized Gaussian q and save its weights.

    Encoder and decoder are trained jointly. The saved decoder is what
    ``train_encoder_only`` later reloads and freezes.

    Args:
        train_x: numpy training matrix, one datapoint per row.
        max_epochs: number of passes over ``train_x``.
        batch_size: minibatch size.
        z_size: latent dimensionality.
        warmup_T: length (in epochs) of the linear warm-up schedule passed to
            ``model.forward(..., warmup=...)``. The weight is
            ``(epoch - 1) / warmup_T`` capped at 1, so it only reaches 1 when
            ``max_epochs > warmup_T``. A non-positive value disables warm-up.
        decoder_path: destination of ``model.generator.state_dict()``.
        encoder_path: destination of ``model.q_dist.state_dict()``.

    Returns:
        ``(decoder_path, encoder_path, hyper)``, where ``hyper`` is the
        hyper-parameter dict used to build the model (including ``'q'``).
    """
    x_size = train_x.shape[1]
    hyper = {
        'x_size': x_size, 'z_size': z_size, 'act_func': F.elu,
        'encoder_arch': [[x_size, 200], [200, 200], [200, 2 * z_size]],
        'decoder_arch': [[z_size, 200], [200, 200], [200, x_size]],
        'q_dist': standard, 'cuda': int(device.type == 'cuda'), 'hnf': 0,
        'context_size': 0
    }
    hyper['q'] = Gaussian(hyper)
    model = VAE(hyper).to(device)

    X = torch.from_numpy(train_x).float()
    y = torch.zeros(len(X))
    loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(X, y),
                                         batch_size=batch_size, shuffle=True)
    # Encoder and decoder parameters are optimized together for the base model.
    opt = optim.Adam(list(model.q_dist.parameters()) + list(model.generator.parameters()), lr=1e-3)

    for ep in range(1, max_epochs + 1):
        # Warm-up weight is constant within an epoch.
        # Presumably it scales the KL/prior term inside VAE.forward — TODO confirm in vae_2.
        warm = 1.0 if warmup_T <= 0 else min((ep - 1) / warmup_T, 1.0)
        for xb, _ in loader:
            xb = xb.to(device)
            opt.zero_grad()
            elbo, _, _ = model.forward(xb, k=1, warmup=warm)
            (-elbo).backward()
            opt.step()
        if ep % 10 == 0:
            # Reports the last minibatch's objective, not an epoch average.
            print(f"[BASE] epoch {ep}/{max_epochs} ELBO={elbo.item():.3f}")

    torch.save(model.generator.state_dict(), decoder_path)
    torch.save(model.q_dist.state_dict(), encoder_path)
    print(" Base model saved")
    return decoder_path, encoder_path, hyper

def train_encoder_only(train_x, frozen_decoder_path, q_family, q_name, base_hyper,
                       max_epochs=500, batch_size=64):
    """Train a fresh encoder of family ``q_family`` against a frozen decoder.

    The decoder weights are loaded from ``frozen_decoder_path`` and their
    gradients are disabled. Only ``model.q_dist`` is optimized, so every family
    is fitted to the same generative model.

    Args:
        train_x: numpy training matrix, one datapoint per row.
        frozen_decoder_path: checkpoint of the base model's ``generator``.
        q_family: variational family class (Gaussian, Flow or ContextualFlow).
        q_name: tag used in log lines and in the saved file name.
        base_hyper: hyper-parameters of the base model (copied, not mutated).
        max_epochs: number of passes over ``train_x``.
        batch_size: minibatch size.

    Returns:
        ``(enc_path, hyper)``: path of the saved ``q_dist`` weights and the
        hyper-parameter dict (with its new ``'q'``) used to build the model.
    """
    x_size = base_hyper['x_size']
    z_size = base_hyper['z_size']
    hyper = dict(base_hyper)

    # Two-layer encoder whose width/output depend on the family; the
    # ContextualFlow encoder additionally emits a 128-dim context vector.
    if q_family is ContextualFlow:
        ctx_dim, hidden, out_dim = 128, 100, 2 * z_size + 128
    elif q_family is Gaussian:
        ctx_dim, hidden, out_dim = 0, z_size, 2 * z_size
    else:
        ctx_dim, hidden, out_dim = 0, 100, 2 * z_size
    hyper['context_size'] = ctx_dim
    hyper['encoder_arch'] = [[x_size, hidden], [hidden, out_dim]]
    hyper['q'] = q_family(hyper)

    model = VAE(hyper).to(device)

    # Reload the shared decoder and freeze it to isolate the effect of q(z|x).
    decoder_state = torch.load(frozen_decoder_path, map_location=device)
    model.generator.load_state_dict(decoder_state)
    model.generator.requires_grad_(False)

    # Only the encoder parameters are handed to the optimizer.
    opt = optim.Adam(model.q_dist.parameters(), lr=1e-3)

    inputs = torch.from_numpy(train_x).float()
    dataset = torch.utils.data.TensorDataset(inputs, torch.zeros(len(inputs)))
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    for ep in range(1, max_epochs + 1):
        for batch, _ in loader:
            opt.zero_grad()
            elbo, _, _ = model.forward(batch.to(device), k=1, warmup=1.0)
            (-elbo).backward()
            opt.step()
        if ep % 10 == 0:
            print(f"[{q_name}] epoch {ep}/{max_epochs} ELBO={elbo.item():.3f}")

    enc_path = f"encoder_{q_name}.pt"
    torch.save(model.q_dist.state_dict(), enc_path)
    print(f" Saved {enc_path}")
    return enc_path, hyper

# Fonction qui lance l’expérience

In [None]:
def _pretrained_family_hypers(x_size, z_size=50, context_size=128):
    """Rebuild per-family hyper-parameter dicts matching the saved checkpoints.

    All three share the decoder layout of ``train_base_vae_ffg``. They differ
    only by the "small" encoder layout produced by ``train_encoder_only``, so
    that ``load_state_dict`` finds tensors of the expected shapes.

    Returns:
        ``(hyper_ffg, hyper_flow, hyper_cflow)``.
    """
    base = {
        'x_size': x_size, 'z_size': z_size, 'act_func': F.elu,
        'decoder_arch': [[z_size, 200], [200, 200], [200, x_size]],
        'q_dist': standard, 'cuda': int(device.type == 'cuda'), 'hnf': 0,
        'context_size': 0,
    }
    hyper_ffg = dict(base, encoder_arch=[[x_size, z_size], [z_size, 2 * z_size]])
    hyper_flow = dict(base, encoder_arch=[[x_size, 100], [100, 2 * z_size]])
    hyper_cflow = dict(base, encoder_arch=[[x_size, 100], [100, 2 * z_size + context_size]],
                       context_size=context_size)
    return hyper_ffg, hyper_flow, hyper_cflow


def _build_frozen_decoder_model(q_class, enc_path, dec_path, hyper):
    """Build a VAE with family ``q_class``, load shared decoder + encoder weights.

    The model is put in eval mode so that every evaluation below runs in the
    same mode whether or not AIS (which calls ``model.eval()``) is used.
    """
    hyper = dict(hyper)
    hyper['q'] = q_class(hyper)
    model = VAE(hyper).to(device)
    # Same decoder for every family ("decoder frozen" experiment).
    model.generator.load_state_dict(torch.load(dec_path, map_location=device))
    # Encoder weights must match hyper['encoder_arch'].
    model.q_dist.load_state_dict(torch.load(enc_path, map_location=device))
    model.hyper_params = hyper
    model.eval()
    return model


def _estimate_family_logp(model, X, K_logp, use_ais, K_ais, T_ais):
    """Return log p̂(x): IWAE alone, or the larger finite value of IWAE and AIS."""
    if not use_ais:
        return estimate_logp(model, X, method="IWAE", K_iwae=K_logp)
    estimates = estimate_logp(model, X, method="BOTH", K_iwae=K_logp,
                              K_ais=K_ais, n_intermediate=T_ais)
    # Both are stochastic lower bounds of log p(x), so the larger is tighter.
    # AIS yields nan when every batch failed: ignore non-finite values.
    finite = [v for v in estimates.values() if v is not None and np.isfinite(v)]
    return max(finite) if finite else float('nan')


def run_exp_5_3(train_x, eval_x,
                K_logp=10000, n_local=25,
                use_pretrained=False, pretrained_paths=None,
                use_ais=False, K_ais=50, T_ais=200):
    """Run experiment 5.3: frozen decoder, compare inference gaps across q families.

    A fixed random subset (seed 0) of up to 100 evaluation points is drawn.
    log p̂(x), L[q] and L[q*] are all averaged over the SAME points, namely the
    first ``n_local`` points of that subset. L[q*] needs one optimization per
    point, and the gaps are only meaningful as differences of averages over
    matched datapoints.

    Args:
        train_x: training matrix (used only when ``use_pretrained`` is False).
        eval_x: evaluation matrix, one datapoint per row.
        K_logp: IWAE importance samples.
        n_local: number of points evaluated (bounded by the subset size).
        use_pretrained: load checkpoints instead of training.
        pretrained_paths: dict with keys ``'decoder'``, ``'enc_ffg'``,
            ``'enc_flow'`` and optionally ``'enc_cflow'``.
        use_ais: also estimate log p(x) with AIS and keep the larger estimate.
        K_ais: number of AIS chains.
        T_ais: number of AIS intermediate distributions.

    Returns:
        Dict mapping ``"ffg"``/``"flow"``/``"cflow"`` to
        ``(logp, Lq, Lq*, (approx_gap, amort_gap, infer_gap))``. ``"cflow"`` is
        omitted when pretrained mode has no ``'enc_cflow'`` checkpoint.

    Raises:
        ValueError: if ``use_pretrained`` is True but ``pretrained_paths`` is missing.
    """
    # Fixed evaluation subset for a stable, cheap comparison.
    rng = np.random.default_rng(0)
    idx_subset = rng.choice(len(eval_x), size=min(100, len(eval_x)), replace=False)
    subset = eval_x[idx_subset]
    gap_points = subset[:max(n_local, 0)]

    out_dir = Path("exp53_logs")
    out_dir.mkdir(exist_ok=True)
    csv_path = out_dir / "exp53_results.csv"

    # Each family: (result key, display name, q class, encoder checkpoint, hyper).
    if use_pretrained:
        if not pretrained_paths:
            raise ValueError(
                "use_pretrained=True requires pretrained_paths with keys "
                "'decoder', 'enc_ffg', 'enc_flow' (and optionally 'enc_cflow')."
            )
        print(" Mode Pretrained activé : Chargement des poids existants...")
        dec_path = pretrained_paths['decoder']
        hyper_ffg, hyper_flow, hyper_cflow = _pretrained_family_hypers(eval_x.shape[1])
        families = [
            ("ffg", "qFFG_small", Gaussian, pretrained_paths['enc_ffg'], hyper_ffg),
            ("flow", "qFlow_small", Flow, pretrained_paths['enc_flow'], hyper_flow),
        ]
        enc_cflow = pretrained_paths.get('enc_cflow')
        if enc_cflow is not None:
            families.append(("cflow", "qCFlow_small", ContextualFlow, enc_cflow, hyper_cflow))
        else:
            print(" No 'enc_cflow' checkpoint given: skipping qCFlow_small.")
    else:
        # Train the base VAE (q FFG), then small encoders against its frozen decoder.
        dec_path, _, base_hyper = train_base_vae_ffg(train_x, max_epochs=100)
        families = []
        for key, name, q_class in (("ffg", "qFFG_small", Gaussian),
                                   ("flow", "qFlow_small", Flow),
                                   ("cflow", "qCFlow_small", ContextualFlow)):
            enc_path, hyper = train_encoder_only(train_x, dec_path, q_class, name,
                                                 base_hyper, max_epochs=80)
            families.append((key, name, q_class, enc_path, hyper))

    if use_ais:
        print(" Estimation log p(x) avec AIS + IWAE ...")

    results, rows = {}, []
    for key, name, q_class, enc_path, hyper in families:
        model = _build_frozen_decoder_model(q_class, enc_path, dec_path, hyper)
        logp = _estimate_family_logp(model, gap_points, K_logp, use_ais, K_ais, T_ais)
        Lq = amortized_elbo(model, gap_points)
        Lqs = locally_optimized_elbo(model, gap_points, n_points=n_local)

        # Approximation, amortization and total inference gaps.
        gaps = (logp - Lqs, Lqs - Lq, logp - Lq)
        A, M, I = gaps
        print(f"\n{name}: logp={logp:.2f}, Lq*={Lqs:.2f}, Lq={Lq:.2f}, A={A:.2f}, M={M:.2f}, I={I:.2f}")
        rows.append([name] + [round(v, 2) for v in (logp, Lq, Lqs, A, M, I)])
        results[key] = (logp, Lq, Lqs, gaps)

    write_header = not csv_path.exists()
    with open(csv_path, "a", newline='') as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(["family", "logp", "Lq", "Lq*", "approx_gap", "amort_gap", "infer_gap"])
        writer.writerows(rows)

    print(f"\n Résultats sauvegardés dans {csv_path}")
    return results

if __name__ == "__main__":
    # Checkpoints used in pretrained mode: one shared (frozen) decoder and one
    # "small" encoder per variational family.
    my_pretrained_paths = dict(
        decoder='decoder_base_qFFG.pt',
        enc_ffg='encoder_qFFG_small.pt',
        enc_flow='encoder_qFlow_small.pt',
        enc_cflow='encoder_qCFlow_small.pt',
    )

    # Experiment 5.3 on the test split: IWAE with K=5000, AIS with 20 chains and
    # 100 intermediate distributions, local optimization on 15 points.
    results = run_exp_5_3(
        train_x=train_x,
        eval_x=test_x,
        K_logp=5000,
        n_local=15,
        use_pretrained=True,
        pretrained_paths=my_pretrained_paths,
        use_ais=True,
        K_ais=20,
        T_ais=100,
    )