In [1]:
# Install the library
# %pip install pythae

In [9]:
import torch
import torchvision.datasets as datasets

device = "cuda" if torch.cuda.is_available() else "cpu"

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
mnist_trainset = datasets.MNIST(root='../../data', train=True, download=True, transform=None)

train_dataset = mnist_trainset.data[:-10000].reshape(-1, 1, 28, 28) / 255.0
eval_dataset = mnist_trainset.data[-10000:].reshape(-1, 1, 28, 28) / 255.0

# train_dataset = train_dataset.float()
# eval_dataset = eval_dataset.float()


In [24]:
from pythae.models import AmortizedDualVAE, AmortizedDualVAEConfig
from pythae.trainers import BaseTrainerConfig
from pythae.pipelines.training import TrainingPipeline


In [33]:
config = BaseTrainerConfig(
    output_dir='my_amortized_dual_vae',
    learning_rate=1e-4,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_epochs=10,  # Increase for more training
    optimizer_cls="AdamW",
    optimizer_params={"weight_decay": 0.05, "betas": (0.91, 0.99)}
)

model_config = AmortizedDualVAEConfig(
    input_dim=(1, 28, 28),
    latent_dim=8,
    polynomial_order=2,
    langevin_steps=10,
    langevin_step_size=1e-2,
    langevin_n_samples=8
)

model = AmortizedDualVAE(
    model_config=model_config
)


In [31]:
# from pythae.trainers.training_callbacks import WandbCallback
# callbacks = []
# wandb_cb = WandbCallback()
# wandb_cb.setup(training_config=config, # training config
#     model_config=model_config, # model config
#     project_name="memvae", # specify your wandb project
# )
# callbacks.append(wandb_cb)

In [34]:
pipeline = TrainingPipeline(
    training_config=config,
    model=model
)


In [35]:
pipeline(
    train_data=train_dataset,
    eval_data=eval_dataset,
    # callbacks=callbacks
)


Preprocessing train data...
Checking train dataset...
Preprocessing eval data...

Checking eval dataset...
Using Base Trainer

Model passed sanity check !
Ready for training.

Created my_amortized_dual_vae/AmortizedDualVAE_training_2025-10-17_16-41-11. 
Training config, checkpoints and final model will be saved here.

Training params:
 - max_epochs: 10
 - per_device_train_batch_size: 64
 - per_device_eval_batch_size: 64
 - checkpoint saving every: None
Optimizer: AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.91, 0.99)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.0001
    maximize: False
    weight_decay: 0.05
)
Scheduler: None

Successfully launched training !



reconstruction loss: 93.14171600341797 score proxy: -0.000112745794467628 dual proxy: -0.17090295255184174 moment loss: 2.6276464462280273 lambda reg: 0.0013085969258099794


Training of epoch 1/10:   0%|          | 0/782 [00:00<?, ?batch/s]

reconstruction loss: 93.68435668945312 score proxy: -0.000772155646700412 dual proxy: -0.19156962633132935 moment loss: 4.51427698135376 lambda reg: 0.0012574406573548913
reconstruction loss: 93.03831481933594 score proxy: -0.0012864717282354832 dual proxy: -0.3680495023727417 moment loss: 5.4881205558776855 lambda reg: 0.0010726551990956068
reconstruction loss: 91.40686798095703 score proxy: -0.0009376737289130688 dual proxy: -0.4516788125038147 moment loss: 4.918226718902588 lambda reg: 0.0008982393774203956
reconstruction loss: 90.43258666992188 score proxy: -0.0006541461916640401 dual proxy: -0.5556646585464478 moment loss: 5.621401309967041 lambda reg: 0.0008056775550357997
reconstruction loss: 89.06132507324219 score proxy: -0.000811696401797235 dual proxy: -0.6698553562164307 moment loss: 5.7462849617004395 lambda reg: 0.0007528146379627287
reconstruction loss: 87.62519836425781 score proxy: -0.00036390850436873734 dual proxy: -0.7069571018218994 moment loss: 5.7181806564331055 

Eval of epoch 1/10:   0%|          | 0/157 [00:00<?, ?batch/s]

reconstruction loss: 26.72528076171875 score proxy: -4.129860099055804e-05 dual proxy: 0.004430482164025307 moment loss: 2.4332871437072754 lambda reg: 6.601163477171212e-05
reconstruction loss: 27.340547561645508 score proxy: 4.759374860441312e-05 dual proxy: -0.002817439381033182 moment loss: 3.921776533126831 lambda reg: 6.549809768330306e-05
reconstruction loss: 26.859291076660156 score proxy: -4.742295641335659e-05 dual proxy: 0.014001462608575821 moment loss: 4.523547172546387 lambda reg: 6.623959052376449e-05
reconstruction loss: 26.337745666503906 score proxy: 0.00012482910824473947 dual proxy: 0.03257916867733002 moment loss: 5.129490375518799 lambda reg: 6.506625504698604e-05
reconstruction loss: 27.82413673400879 score proxy: -9.664117533247918e-05 dual proxy: 0.014322788454592228 moment loss: 4.95318078994751 lambda reg: 6.646159454248846e-05
reconstruction loss: 26.464862823486328 score proxy: -1.0589945304673165e-05 dual proxy: 0.009104687720537186 moment loss: 5.28474521

--------------------------------------------------------------------------
Train loss: 36.4484
Eval loss: 31.4163
--------------------------------------------------------------------------


reconstruction loss: 26.480934143066406 score proxy: -3.57438693754375e-05 dual proxy: 0.011638443917036057 moment loss: 5.279701232910156 lambda reg: 6.37742705293931e-05
reconstruction loss: 26.88232421875 score proxy: 5.117100226925686e-05 dual proxy: 0.01504384446889162 moment loss: 5.200772762298584 lambda reg: 6.677021156065166e-05
reconstruction loss: 27.12312889099121 score proxy: 3.0941700970288366e-05 dual proxy: 0.0072875493206083775 moment loss: 4.782873630523682 lambda reg: 6.645204848609865e-05
reconstruction loss: 26.287832260131836 score proxy: 0.0001508752175141126 dual proxy: -0.013411919586360455 moment loss: 4.353382110595703 lambda reg: 6.398333061952144e-05
reconstruction loss: 27.511451721191406 score proxy: -2.4313083031302085e-06 dual proxy: 0.009451234713196754 moment loss: 4.426916122436523 lambda reg: 6.699829827994108e-05
reconstruction loss: 25.359729766845703 score proxy: -5.2527335355989635e-05 dual proxy: 0.007140297908335924 moment loss: 1.997866988182

Training of epoch 2/10:   0%|          | 0/782 [00:00<?, ?batch/s]

reconstruction loss: 26.529163360595703 score proxy: 2.6480835003894754e-05 dual proxy: 0.01279459148645401 moment loss: 2.226038694381714 lambda reg: 6.258337816689163e-05
reconstruction loss: 26.08391571044922 score proxy: -2.3622947992407717e-05 dual proxy: 0.006220195442438126 moment loss: 3.2849295139312744 lambda reg: 5.7498502428643405e-05
reconstruction loss: 29.418712615966797 score proxy: 8.552364306524396e-05 dual proxy: 0.010432351380586624 moment loss: 4.058676242828369 lambda reg: 5.2667011914309114e-05
reconstruction loss: 27.37160873413086 score proxy: -0.00010491296416148543 dual proxy: -0.0036242599599063396 moment loss: 4.663268089294434 lambda reg: 5.1757491746684536e-05
reconstruction loss: 28.669984817504883 score proxy: 0.0001156634752987884 dual proxy: 0.028328174725174904 moment loss: 5.137273788452148 lambda reg: 5.293364665703848e-05
reconstruction loss: 26.061779022216797 score proxy: -9.437845437787473e-05 dual proxy: 0.03022436425089836 moment loss: 4.7022

Eval of epoch 2/10:   0%|          | 0/157 [00:00<?, ?batch/s]

reconstruction loss: 26.72702407836914 score proxy: -6.211789150256664e-05 dual proxy: -0.0057139927521348 moment loss: 2.1639626026153564 lambda reg: 2.834516271832399e-05
reconstruction loss: 27.734886169433594 score proxy: -6.159054464660585e-05 dual proxy: -0.008277131244540215 moment loss: 3.9775052070617676 lambda reg: 2.7793603294412605e-05
reconstruction loss: 24.323549270629883 score proxy: 3.073101106565446e-05 dual proxy: 0.009791623800992966 moment loss: 4.919285774230957 lambda reg: 2.934589792857878e-05
reconstruction loss: 27.11113739013672 score proxy: -9.102065814659e-05 dual proxy: 0.004952840507030487 moment loss: 4.511111259460449 lambda reg: 2.847085488610901e-05
reconstruction loss: 24.999237060546875 score proxy: -4.662578794523142e-05 dual proxy: 0.008698014542460442 moment loss: 4.488208770751953 lambda reg: 2.8904532882734202e-05
reconstruction loss: 25.831575393676758 score proxy: -7.222737622214481e-06 dual proxy: 0.0155781339854002 moment loss: 4.7987561225

--------------------------------------------------------------------------
Train loss: 31.6431
Eval loss: 31.2285
--------------------------------------------------------------------------


reconstruction loss: 26.435665130615234 score proxy: 5.7921452025766484e-06 dual proxy: 0.012445218861103058 moment loss: 5.229084491729736 lambda reg: 2.8733174985973164e-05
reconstruction loss: 25.628829956054688 score proxy: 2.1662948711309582e-05 dual proxy: 0.005076810717582703 moment loss: 4.8964409828186035 lambda reg: 2.887124355765991e-05
reconstruction loss: 25.725074768066406 score proxy: -2.247913471364882e-05 dual proxy: -0.00013095280155539513 moment loss: 4.83314323425293 lambda reg: 2.7826536097563803e-05
reconstruction loss: 24.427234649658203 score proxy: 1.967790740309283e-05 dual proxy: -0.016233135014772415 moment loss: 3.5777900218963623 lambda reg: 2.923531792475842e-05


Training of epoch 3/10:   0%|          | 0/782 [00:00<?, ?batch/s]

reconstruction loss: 26.553213119506836 score proxy: 1.6681116903782822e-05 dual proxy: -0.01727481000125408 moment loss: 2.1193785667419434 lambda reg: 2.9553786589531228e-05
reconstruction loss: 25.45513916015625 score proxy: -7.955031469464302e-05 dual proxy: -0.007473730016499758 moment loss: 3.8104770183563232 lambda reg: 3.347065285197459e-05
reconstruction loss: 27.274757385253906 score proxy: -4.282245572539978e-05 dual proxy: 0.018390391021966934 moment loss: 4.362298488616943 lambda reg: 3.94178023270797e-05
reconstruction loss: 28.239959716796875 score proxy: -1.0668892400644836e-06 dual proxy: 0.012724870815873146 moment loss: 4.794100284576416 lambda reg: 4.3508389353519306e-05
reconstruction loss: 26.088512420654297 score proxy: -9.332156332675368e-05 dual proxy: 0.014197764918208122 moment loss: 5.311352252960205 lambda reg: 4.327090937294997e-05
reconstruction loss: 26.58818817138672 score proxy: -3.6781631933990866e-05 dual proxy: 0.012413008138537407 moment loss: 5.10

Eval of epoch 3/10:   0%|          | 0/157 [00:00<?, ?batch/s]

reconstruction loss: 26.26140785217285 score proxy: -1.3554369616031181e-05 dual proxy: -0.007403998635709286 moment loss: 1.9373762607574463 lambda reg: 1.093659557227511e-05
reconstruction loss: 25.46640968322754 score proxy: -3.662818926386535e-05 dual proxy: -0.005393725819885731 moment loss: 2.965665817260742 lambda reg: 1.0595405001367908e-05
reconstruction loss: 25.02268409729004 score proxy: -6.352445780066773e-06 dual proxy: -0.0008855354972183704 moment loss: 4.044641017913818 lambda reg: 1.0504576493985951e-05
reconstruction loss: 25.27688980102539 score proxy: 5.396497726906091e-06 dual proxy: 0.0017155518289655447 moment loss: 5.312501430511475 lambda reg: 1.0624069545883685e-05
reconstruction loss: 26.548892974853516 score proxy: -2.4778113584034145e-05 dual proxy: 0.007635005749762058 moment loss: 4.94038724899292 lambda reg: 1.0529084647714626e-05
reconstruction loss: 24.81551170349121 score proxy: -1.8547201761975884e-05 dual proxy: 0.01115350890904665 moment loss: 5.7

--------------------------------------------------------------------------
Train loss: 31.6112
Eval loss: 31.3054
--------------------------------------------------------------------------


reconstruction loss: 28.322980880737305 score proxy: -3.13300256493676e-06 dual proxy: -0.007098549045622349 moment loss: 2.5275375843048096 lambda reg: 1.0978023055940866e-05


Training of epoch 4/10:   0%|          | 0/782 [00:00<?, ?batch/s]

reconstruction loss: 26.66313934326172 score proxy: -3.282140096416697e-05 dual proxy: -0.012970849871635437 moment loss: 2.0865635871887207 lambda reg: 1.0524706340220291e-05
reconstruction loss: 27.37034797668457 score proxy: -1.4569994164048694e-05 dual proxy: -0.014489303342998028 moment loss: 3.345989942550659 lambda reg: 1.2864552445535082e-05
reconstruction loss: 26.49928855895996 score proxy: -2.2163634639582597e-05 dual proxy: -0.003822420025244355 moment loss: 4.601412296295166 lambda reg: 1.4331353668239899e-05
reconstruction loss: 26.202594757080078 score proxy: 5.799958216812229e-06 dual proxy: -0.002360212616622448 moment loss: 4.372222900390625 lambda reg: 1.5546102076768875e-05
reconstruction loss: 26.201927185058594 score proxy: 1.6817784853628837e-05 dual proxy: 0.004207373596727848 moment loss: 4.960331916809082 lambda reg: 1.7269117961404845e-05
reconstruction loss: 26.833715438842773 score proxy: -3.814247975242324e-06 dual proxy: 0.00035416148602962494 moment loss

Eval of epoch 4/10:   0%|          | 0/157 [00:00<?, ?batch/s]

reconstruction loss: 26.27301025390625 score proxy: -1.1420432201703079e-05 dual proxy: -0.01051557157188654 moment loss: 1.971543788909912 lambda reg: 7.977339009812567e-06
reconstruction loss: 27.19491195678711 score proxy: -4.316469130571932e-06 dual proxy: 0.0038287825882434845 moment loss: 3.907768964767456 lambda reg: 7.816109246050473e-06
reconstruction loss: 26.303537368774414 score proxy: -2.3053609766066074e-05 dual proxy: -0.0027138220611959696 moment loss: 4.215075492858887 lambda reg: 7.615371941938065e-06
reconstruction loss: 26.91199493408203 score proxy: 6.015654889779398e-06 dual proxy: 0.0009189122938551009 moment loss: 5.06514835357666 lambda reg: 8.303016329591628e-06
reconstruction loss: 26.154203414916992 score proxy: -1.262138357560616e-05 dual proxy: -0.0011317485477775335 moment loss: 4.781320095062256 lambda reg: 7.966422344907187e-06
reconstruction loss: 26.920095443725586 score proxy: 8.51242475619074e-06 dual proxy: -0.0039006881415843964 moment loss: 4.342

--------------------------------------------------------------------------
Train loss: 31.5649
Eval loss: 31.2829
--------------------------------------------------------------------------


reconstruction loss: 27.357458114624023 score proxy: 3.1530249543720856e-05 dual proxy: -0.0014289298560470343 moment loss: 4.656118869781494 lambda reg: 7.879948498157319e-06
reconstruction loss: 24.307485580444336 score proxy: -4.151628672843799e-05 dual proxy: -0.001836359966546297 moment loss: 4.736867904663086 lambda reg: 8.329161573783495e-06
reconstruction loss: 25.240596771240234 score proxy: -1.777807301550638e-05 dual proxy: -0.0018546388018876314 moment loss: 4.515100955963135 lambda reg: 7.656108209630474e-06
reconstruction loss: 28.272159576416016 score proxy: -3.0809640065854182e-06 dual proxy: 0.005001667886972427 moment loss: 5.276897430419922 lambda reg: 7.398536581604276e-06
reconstruction loss: 26.203786849975586 score proxy: 1.4944475879019592e-05 dual proxy: 0.008089179173111916 moment loss: 5.532924652099609 lambda reg: 7.493090834032046e-06
reconstruction loss: 31.09671401977539 score proxy: 1.0626584298734087e-05 dual proxy: -0.004788042977452278 moment loss: 2.

Training of epoch 5/10:   0%|          | 0/782 [00:00<?, ?batch/s]

reconstruction loss: 25.345855712890625 score proxy: -3.3710810384945944e-05 dual proxy: -0.013275747187435627 moment loss: 2.2889344692230225 lambda reg: 7.457216270267963e-06
reconstruction loss: 24.84992218017578 score proxy: -9.942469660018105e-06 dual proxy: -0.013643983751535416 moment loss: 3.712496519088745 lambda reg: 9.290000889450312e-06
reconstruction loss: 26.58391571044922 score proxy: 1.8729868315858766e-05 dual proxy: -0.0008593806996941566 moment loss: 4.367465019226074 lambda reg: 1.1154391359013971e-05
reconstruction loss: 25.55609130859375 score proxy: -1.2326156138442457e-05 dual proxy: -0.00592803256586194 moment loss: 5.312291145324707 lambda reg: 1.1119423106720205e-05
reconstruction loss: 28.087589263916016 score proxy: 1.3865515029465314e-05 dual proxy: -0.005084450356662273 moment loss: 4.906107425689697 lambda reg: 1.1848812391690444e-05
reconstruction loss: 26.280662536621094 score proxy: -2.936564442279632e-06 dual proxy: -0.00316503643989563 moment loss: 

Eval of epoch 5/10:   0%|          | 0/157 [00:00<?, ?batch/s]

reconstruction loss: 26.747135162353516 score proxy: 9.668183338362724e-07 dual proxy: 0.002564698690548539 moment loss: 2.042811393737793 lambda reg: 5.036296442995081e-06
reconstruction loss: 25.911991119384766 score proxy: 5.841138772666454e-06 dual proxy: 0.00210060877725482 moment loss: 4.166752815246582 lambda reg: 5.209642495174194e-06
reconstruction loss: 24.614898681640625 score proxy: -4.27657323598396e-06 dual proxy: 0.0029642328154295683 moment loss: 4.156556606292725 lambda reg: 5.319880528986687e-06
reconstruction loss: 26.58132553100586 score proxy: -4.081178303749766e-06 dual proxy: 0.006780048832297325 moment loss: 4.815718650817871 lambda reg: 5.298570613376796e-06
reconstruction loss: 26.74335479736328 score proxy: -7.67774872656446e-06 dual proxy: 0.006631805561482906 moment loss: 4.995423316955566 lambda reg: 5.232922831055475e-06
reconstruction loss: 27.389328002929688 score proxy: -1.7100941249736934e-06 dual proxy: 0.003936287015676498 moment loss: 4.70192241668

--------------------------------------------------------------------------
Train loss: 31.5465
Eval loss: 31.2383
--------------------------------------------------------------------------


reconstruction loss: 25.007827758789062 score proxy: 1.9507895103743067e-06 dual proxy: -0.006207426078617573 moment loss: 5.149723529815674 lambda reg: 5.222523668635404e-06
reconstruction loss: 25.17641830444336 score proxy: -1.961943053174764e-05 dual proxy: 0.014619926922023296 moment loss: 2.5774168968200684 lambda reg: 5.4893589549465105e-06


Training of epoch 6/10:   0%|          | 0/782 [00:00<?, ?batch/s]

reconstruction loss: 26.723033905029297 score proxy: -6.999808874752489e-07 dual proxy: 0.00498330220580101 moment loss: 2.165546417236328 lambda reg: 5.393564151745522e-06
reconstruction loss: 26.36151885986328 score proxy: 5.54777807337814e-06 dual proxy: 0.004787840880453587 moment loss: 3.9384305477142334 lambda reg: 5.944173153693555e-06
reconstruction loss: 26.246692657470703 score proxy: -7.661253221158404e-06 dual proxy: 0.004938116297125816 moment loss: 4.803707122802734 lambda reg: 6.856347226857906e-06
reconstruction loss: 26.350303649902344 score proxy: -4.846876436204184e-06 dual proxy: 0.0016620296519249678 moment loss: 5.138463020324707 lambda reg: 7.192907560238382e-06
reconstruction loss: 25.50567626953125 score proxy: -1.2370698641461786e-05 dual proxy: 0.00521549116820097 moment loss: 4.645705223083496 lambda reg: 7.177790394052863e-06
reconstruction loss: 25.74224853515625 score proxy: -2.859526102838572e-05 dual proxy: 0.00437310291454196 moment loss: 3.86525487899

Eval of epoch 6/10:   0%|          | 0/157 [00:00<?, ?batch/s]

reconstruction loss: 26.333980560302734 score proxy: -1.2239504030731041e-05 dual proxy: 0.00814099982380867 moment loss: 2.3707661628723145 lambda reg: 5.490234798344318e-06
reconstruction loss: 26.294862747192383 score proxy: 8.607113159087021e-06 dual proxy: 0.006824172101914883 moment loss: 4.235520839691162 lambda reg: 5.572454938373994e-06
reconstruction loss: 26.014095306396484 score proxy: -1.2694278666458558e-05 dual proxy: 0.00568865891546011 moment loss: 4.659155368804932 lambda reg: 5.483068434841698e-06
reconstruction loss: 24.780132293701172 score proxy: -5.25534278494888e-06 dual proxy: 0.005757277365773916 moment loss: 5.009825229644775 lambda reg: 5.479860647028545e-06
reconstruction loss: 26.026992797851562 score proxy: 1.9903479824279202e-06 dual proxy: 0.00847805105149746 moment loss: 5.57415771484375 lambda reg: 5.498313839780167e-06
reconstruction loss: 25.836008071899414 score proxy: 3.9402020775014535e-06 dual proxy: -0.0005209196824580431 moment loss: 5.9963111

--------------------------------------------------------------------------
Train loss: 31.5567
Eval loss: 31.2265
--------------------------------------------------------------------------


reconstruction loss: 26.005199432373047 score proxy: 8.243132469942793e-06 dual proxy: 0.0056790513917803764 moment loss: 5.17700719833374 lambda reg: 5.3721514632343315e-06
reconstruction loss: 27.1304931640625 score proxy: -3.839140390482498e-06 dual proxy: 0.0035927612334489822 moment loss: 4.139556884765625 lambda reg: 5.609108029602794e-06
reconstruction loss: 27.293405532836914 score proxy: -4.457274826563662e-06 dual proxy: -0.0012595714069902897 moment loss: 4.694136142730713 lambda reg: 5.5701539167785086e-06
reconstruction loss: 26.557907104492188 score proxy: 1.1804687346739229e-05 dual proxy: 0.0014802154619246721 moment loss: 1.834517478942871 lambda reg: 5.5604855333513115e-06


Training of epoch 7/10:   0%|          | 0/782 [00:00<?, ?batch/s]

reconstruction loss: 26.852054595947266 score proxy: -3.081417844441603e-06 dual proxy: 0.00868238601833582 moment loss: 2.3194546699523926 lambda reg: 5.526802397071151e-06
reconstruction loss: 27.35951042175293 score proxy: 5.05825710206409e-06 dual proxy: 0.0015806821174919605 moment loss: 4.056830406188965 lambda reg: 6.034052148606861e-06
reconstruction loss: 26.087078094482422 score proxy: -1.2685934052569792e-05 dual proxy: -0.0014056931249797344 moment loss: 4.833179473876953 lambda reg: 6.2896779127186164e-06
reconstruction loss: 26.212491989135742 score proxy: 1.2188186701678205e-05 dual proxy: -0.002080123871564865 moment loss: 5.071850299835205 lambda reg: 6.4045943872770295e-06
reconstruction loss: 28.166868209838867 score proxy: -8.848629477142822e-06 dual proxy: 0.0002819958608597517 moment loss: 5.031605243682861 lambda reg: 6.386704626493156e-06
reconstruction loss: 27.51740264892578 score proxy: 6.32230603514472e-06 dual proxy: 0.0021312166936695576 moment loss: 4.997

Eval of epoch 7/10:   0%|          | 0/157 [00:00<?, ?batch/s]

reconstruction loss: 25.764842987060547 score proxy: -1.5424739103764296e-05 dual proxy: 0.007843002676963806 moment loss: 2.3256428241729736 lambda reg: 6.443955953727709e-06
reconstruction loss: 25.843843460083008 score proxy: -4.811472081200918e-06 dual proxy: 0.004687848035246134 moment loss: 4.112798690795898 lambda reg: 6.4768100855872035e-06
reconstruction loss: 25.466550827026367 score proxy: 8.686483852216043e-06 dual proxy: 0.00527163315564394 moment loss: 4.461951732635498 lambda reg: 6.555872914759675e-06
reconstruction loss: 27.520065307617188 score proxy: -2.5200231448252453e-06 dual proxy: 0.007857829332351685 moment loss: 4.806238651275635 lambda reg: 6.523425327031873e-06
reconstruction loss: 26.844135284423828 score proxy: -3.564557118806988e-06 dual proxy: 0.009951358661055565 moment loss: 3.846027135848999 lambda reg: 6.6313073148194235e-06
reconstruction loss: 24.633726119995117 score proxy: -8.084482033154927e-06 dual proxy: 0.005603360943496227 moment loss: 4.569

--------------------------------------------------------------------------
Train loss: 31.5124
Eval loss: 31.1343
--------------------------------------------------------------------------


reconstruction loss: 26.00545883178711 score proxy: 2.9156919936212944e-06 dual proxy: -0.0012520723976194859 moment loss: 5.285318851470947 lambda reg: 6.500196377601242e-06
reconstruction loss: 27.66976547241211 score proxy: -1.1588880624913145e-05 dual proxy: 0.004859759472310543 moment loss: 4.899773120880127 lambda reg: 6.598357686016243e-06
reconstruction loss: 28.43149185180664 score proxy: -9.579743345966563e-05 dual proxy: -0.0025673075579106808 moment loss: 2.624108076095581 lambda reg: 6.57100872558658e-06


Training of epoch 8/10:   0%|          | 0/782 [00:00<?, ?batch/s]

reconstruction loss: 25.9505615234375 score proxy: -2.3330860585701885e-06 dual proxy: 0.003760577877983451 moment loss: 1.973718523979187 lambda reg: 6.487244718300644e-06
reconstruction loss: 25.951190948486328 score proxy: -3.922277301171562e-06 dual proxy: -0.005155536346137524 moment loss: 3.483469009399414 lambda reg: 6.782064701837953e-06
reconstruction loss: 26.670454025268555 score proxy: -1.0084911991725676e-05 dual proxy: 0.0023264954797923565 moment loss: 4.378018856048584 lambda reg: 7.255631317093503e-06
reconstruction loss: 27.187503814697266 score proxy: 9.449113349546678e-06 dual proxy: 0.013890884816646576 moment loss: 4.929195404052734 lambda reg: 7.570403340650955e-06
reconstruction loss: 27.796409606933594 score proxy: 1.1367009392415639e-05 dual proxy: 0.0032957715447992086 moment loss: 4.855687618255615 lambda reg: 7.427695891237818e-06
reconstruction loss: 27.377696990966797 score proxy: 4.326539510657312e-06 dual proxy: 0.0007703949231654406 moment loss: 4.8578

Eval of epoch 8/10:   0%|          | 0/157 [00:00<?, ?batch/s]

reconstruction loss: 25.08160972595215 score proxy: -6.712952654197579e-06 dual proxy: 0.007541295140981674 moment loss: 2.2489712238311768 lambda reg: 5.8183600231132004e-06
reconstruction loss: 25.41676139831543 score proxy: -2.1600192212645197e-06 dual proxy: 0.009503725916147232 moment loss: 3.9715182781219482 lambda reg: 5.717127351090312e-06
reconstruction loss: 27.45226287841797 score proxy: -1.6797044111172e-07 dual proxy: 0.011289101094007492 moment loss: 4.553957939147949 lambda reg: 5.487054295372218e-06
reconstruction loss: 27.01565933227539 score proxy: -1.1375362191756722e-05 dual proxy: 0.013195658102631569 moment loss: 4.613250255584717 lambda reg: 5.683355084329378e-06
reconstruction loss: 27.87017250061035 score proxy: -1.0065498827316333e-05 dual proxy: 0.009434057399630547 moment loss: 4.690145969390869 lambda reg: 5.646522822644329e-06
reconstruction loss: 26.676908493041992 score proxy: -1.0780110642372165e-05 dual proxy: 0.002698172815144062 moment loss: 4.726359

--------------------------------------------------------------------------
Train loss: 31.4758
Eval loss: 31.0908
--------------------------------------------------------------------------


reconstruction loss: 22.733020782470703 score proxy: 4.419525794219226e-06 dual proxy: 0.0022649418096989393 moment loss: 1.8370532989501953 lambda reg: 5.5640944083279464e-06


Training of epoch 9/10:   0%|          | 0/782 [00:00<?, ?batch/s]

reconstruction loss: 25.684734344482422 score proxy: -7.197700142569374e-06 dual proxy: 0.002251278143376112 moment loss: 1.977719783782959 lambda reg: 5.867972504347563e-06
reconstruction loss: 27.168392181396484 score proxy: -3.2414179713669e-06 dual proxy: 0.0029999790713191032 moment loss: 3.727226734161377 lambda reg: 5.687196789949667e-06
reconstruction loss: 25.420654296875 score proxy: -6.604038389923517e-06 dual proxy: 0.005930798128247261 moment loss: 4.303477764129639 lambda reg: 5.318843250279315e-06
reconstruction loss: 27.415966033935547 score proxy: 7.066590114845894e-06 dual proxy: 0.0024220002815127373 moment loss: 4.992206573486328 lambda reg: 4.600084594130749e-06
reconstruction loss: 25.538982391357422 score proxy: -4.78834124351124e-07 dual proxy: 0.0064475396648049355 moment loss: 5.118744850158691 lambda reg: 3.675189191199024e-06
reconstruction loss: 26.35834503173828 score proxy: 1.9457461348793004e-06 dual proxy: 0.0038365228101611137 moment loss: 4.0046458244

Eval of epoch 9/10:   0%|          | 0/157 [00:00<?, ?batch/s]

reconstruction loss: 26.241703033447266 score proxy: 4.0619605101710476e-07 dual proxy: -0.006716425530612469 moment loss: 2.202383041381836 lambda reg: 2.3918892111396417e-06
reconstruction loss: 26.48577117919922 score proxy: 1.3113378827256383e-06 dual proxy: 0.0009362236596643925 moment loss: 4.088375568389893 lambda reg: 2.499274614820024e-06
reconstruction loss: 26.603755950927734 score proxy: -4.4418663946999004e-07 dual proxy: 0.0013523148372769356 moment loss: 4.536779403686523 lambda reg: 2.393736622252618e-06
reconstruction loss: 24.693313598632812 score proxy: -4.9492896323499735e-06 dual proxy: -0.0013585290871560574 moment loss: 4.456467628479004 lambda reg: 2.3182631139206933e-06
reconstruction loss: 26.782176971435547 score proxy: -2.342345396755263e-06 dual proxy: 0.0017905613640323281 moment loss: 4.824227333068848 lambda reg: 2.3053271434037015e-06
reconstruction loss: 27.623746871948242 score proxy: 2.4566627416788833e-06 dual proxy: -0.00017156358808279037 moment l

--------------------------------------------------------------------------
Train loss: 31.5084
Eval loss: 31.1956
--------------------------------------------------------------------------


reconstruction loss: 25.929763793945312 score proxy: 5.7285164984932635e-06 dual proxy: 0.004106852691620588 moment loss: 5.754253387451172 lambda reg: 2.409507487755036e-06
reconstruction loss: 24.632678985595703 score proxy: 2.076840064546559e-05 dual proxy: -0.0014897034270688891 moment loss: 2.157636880874634 lambda reg: 2.506281134628807e-06


Training of epoch 10/10:   0%|          | 0/782 [00:00<?, ?batch/s]

reconstruction loss: 26.514766693115234 score proxy: 1.3055999659172812e-07 dual proxy: -0.009132102131843567 moment loss: 1.7902250289916992 lambda reg: 2.4854155071807327e-06
reconstruction loss: 26.544496536254883 score proxy: -2.963933866340085e-06 dual proxy: -0.007030144799500704 moment loss: 3.2827632427215576 lambda reg: 2.720144038903527e-06
reconstruction loss: 28.194480895996094 score proxy: 8.373542186745908e-06 dual proxy: -0.0011652801185846329 moment loss: 5.116941928863525 lambda reg: 3.3289409202552633e-06
reconstruction loss: 24.66786003112793 score proxy: -8.91418312676251e-06 dual proxy: -0.00109903234988451 moment loss: 5.336143493652344 lambda reg: 3.809722329606302e-06
reconstruction loss: 27.045833587646484 score proxy: 4.5312426664168015e-06 dual proxy: 0.005578506737947464 moment loss: 5.518813610076904 lambda reg: 4.333797733124811e-06
reconstruction loss: 26.828197479248047 score proxy: -2.6764396352518816e-06 dual proxy: 0.004255939275026321 moment loss: 4.

Eval of epoch 10/10:   0%|          | 0/157 [00:00<?, ?batch/s]

reconstruction loss: 25.970117568969727 score proxy: -2.3279735614778474e-06 dual proxy: -0.01529279351234436 moment loss: 2.1358211040496826 lambda reg: 3.422851250434178e-06
reconstruction loss: 25.945186614990234 score proxy: -4.243329101427662e-07 dual proxy: -0.006598319858312607 moment loss: 3.736896276473999 lambda reg: 3.424039960009395e-06
reconstruction loss: 26.985748291015625 score proxy: -4.4009109956277825e-07 dual proxy: 0.001371920108795166 moment loss: 5.398155212402344 lambda reg: 3.4414817946526455e-06
reconstruction loss: 25.757396697998047 score proxy: -2.7419619073043577e-06 dual proxy: 0.0019487671088427305 moment loss: 4.640598297119141 lambda reg: 3.402826678211568e-06
reconstruction loss: 26.168270111083984 score proxy: 1.675361886555038e-06 dual proxy: 0.004609321244060993 moment loss: 4.440212249755859 lambda reg: 3.3394280762877315e-06
reconstruction loss: 26.62261962890625 score proxy: 3.162555913149845e-06 dual proxy: 0.009946536272764206 moment loss: 5.1

--------------------------------------------------------------------------
Train loss: 31.4805
Eval loss: 31.135
--------------------------------------------------------------------------
Training ended!
Saved final model in my_amortized_dual_vae/AmortizedDualVAE_training_2025-10-17_16-41-11/final_model


reconstruction loss: 26.129722595214844 score proxy: -4.155600890953792e-06 dual proxy: -0.0016548571875318885 moment loss: 4.974730968475342 lambda reg: 3.456297690718202e-06
reconstruction loss: 26.6221866607666 score proxy: 8.957892987382365e-07 dual proxy: 0.0015907464548945427 moment loss: 5.213537693023682 lambda reg: 3.4146976304327836e-06
reconstruction loss: 25.19781494140625 score proxy: 1.3084020338283153e-06 dual proxy: 0.0005945172742940485 moment loss: 4.88966703414917 lambda reg: 3.3404808164050337e-06
reconstruction loss: 26.106449127197266 score proxy: -3.5196117096347734e-06 dual proxy: 0.004585757851600647 moment loss: 4.813777923583984 lambda reg: 3.4072797916451236e-06
reconstruction loss: 25.267841339111328 score proxy: -4.99122461405932e-06 dual proxy: -0.007853000424802303 moment loss: 1.92812979221344 lambda reg: 3.536919848556863e-06


In [36]:
import os
from pythae.models import AutoModel


In [37]:
last_training = sorted(os.listdir(config.output_dir))[-1]
trained_model = AutoModel.load_from_folder(os.path.join(config.output_dir, last_training, 'final_model'))
trained_model = trained_model.to(device)
trained_model.eval()


NameError: Cannot reload automatically the model configuration... The model name in the `model_config.json may be corrupted. Got `AmortizedDualVAEConfig`

## Sampling with the energy-based latent prior


In [None]:
import matplotlib.pyplot as plt

def sample_amortized_dual_vae(model, num_samples=25, reference_batch=None):
    device = next(model.parameters()).device
    model.eval()
    with torch.no_grad():
        if reference_batch is None:
            idx = torch.randint(0, train_dataset.shape[0], (num_samples,))
            reference_batch = train_dataset[idx]
        else:
            reference_batch = reference_batch[:num_samples]
        reference_batch = reference_batch.to(device)
        moments = model.encoder(reference_batch)["embedding"]
        lam = model.lambda_net(moments)
        latent_samples = model.sampler.sample(lam, model.basis)[:, -1, :]
        decoded = model.decoder(latent_samples)["reconstruction"]
    return decoded.cpu()


In [None]:
generated = sample_amortized_dual_vae(trained_model, num_samples=25)


In [None]:
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

for i in range(5):
    for j in range(5):
        axes[i][j].imshow(generated[i*5 + j].cpu().squeeze(0), cmap='gray')
        axes[i][j].axis('off')
plt.tight_layout(pad=0.)


## Visualizing reconstructions


In [None]:
reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()


In [None]:
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

for i in range(5):
    for j in range(5):
        axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')
        axes[i][j].axis('off')
plt.tight_layout(pad=0.)


In [None]:
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

for i in range(5):
    for j in range(5):
        axes[i][j].imshow(eval_dataset[i*5 + j].cpu().squeeze(0), cmap='gray')
        axes[i][j].axis('off')
plt.tight_layout(pad=0.)


## Visualizing interpolations


In [None]:
interpolations = trained_model.interpolate(
    eval_dataset[:5].to(device),
    eval_dataset[5:10].to(device),
    granularity=10
).detach().cpu()


In [None]:
fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))

for i in range(5):
    for j in range(10):
        axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')
        axes[i][j].axis('off')
plt.tight_layout(pad=0.)
