In [8]:
## Standard imports
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
import tqdm
import arviz as az

## Relative imports
from astroddpm.runners import Diffuser, config_from_id, get_samples
from astroddpm.analysis.validationMetrics import powerSpectrum
from astroddpm.utils.plot import check_nearest_epoch, plot_losses, check_training_samples, plot_comparaison
from astroddpm.diffusion.dm import DiscreteSBM
from astroddpm.diffusion.stochastic.sde import DiscreteVPSDE, ContinuousSDE, ContinuousVPSDE
from astroddpm.diffusion.stochastic.solver import get_schedule
from astroddpm.diffusion.models.network import ResUNet, FFResUNet
import astroddpm.utils.colormap_custom ## For CMB colormap (do not remove)
from astroddpm.moment.models import SigmaMomentModel, SigmaMomentNetwork

## Imports for inference
from inference.cmb_ps import CMBPS
from inference.utils import unnormalize_phi, normalize_phi, log_prior_phi_sigma, sample_prior_phi, log_likelihood_eps_phi_sigma, get_phi_bounds
from inference.hmc import HMC

## Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
MODEL_ID = 'ContinuousSBM_ContinuousVPSDE_I_BPROJ_bottleneck_32_firstc_10_phi_beta_cosine_betamax_0.5_betamin_0.001'

In [10]:
placeholder_dm = DiscreteSBM(DiscreteVPSDE(1000), ResUNet())
diffuser = Diffuser(placeholder_dm)

No model id found


In [11]:
NEW_MODEL_ID = 'ContinuousSBM_VPSDE_I_FFResUnet_b32_f10_phi_beta_cosine_betamax_0.5_betamin_0.001'

In [12]:
config = config_from_id(MODEL_ID)
config['diffusion_model']['network']['type'] = 'FFResUNet'
config['model_id'] = NEW_MODEL_ID
config['epochs'] = 1000
diffuser.load(config=config, also_ckpt=False)

Loading the diffuser from a config dict.
{'in_c': 1, 'out_c': 1, 'first_c': 10, 'sizes': [256, 128, 64, 32], 'num_blocks': 1, 'n_steps': 1000, 'time_emb_dim': 100, 'dropout': 0, 'attention': [], 'normalisation': 'GN', 'padding_mode': 'circular', 'eps_norm': 1e-05, 'skiprescale': True, 'type': 'FFResUNet', 'discretization': 'continuous', 'embedding_mode': 'fourier', 'has_phi': True, 'phi_shape': 2, 'phi_embed_dim': 100, 'n_ff_min': 6, 'n_ff_max': 8}


In [6]:
diffuser.ckpt_dir, diffuser.model_id

('/mnt/home/dheurtel/ceph/02_checkpoints',
 'ContinuousSBM_VPSDE_I_FFResUnet_b32_f10_phi_beta_cosine_betamax_0.5_betamin_0.001')

In [7]:
diffuser.train()

Epochs provided as attribute, using it. Be warned the model will start training at self.epoch until self.epochs-1.
Successfully saved checkpoint to /mnt/home/dheurtel/ceph/02_checkpoints/ContinuousSBM_VPSDE_I_FFResUnet_b32_f10_phi_beta_cosine_betamax_0.5_betamin_0.001


  0%|          | 0/1000 [00:00<?, ?it/s]

Successfully saved checkpoint to /mnt/home/dheurtel/ceph/02_checkpoints/ContinuousSBM_VPSDE_I_FFResUnet_b32_f10_phi_beta_cosine_betamax_0.5_betamin_0.001


 20%|██        | 200/1000 [10:25<40:43,  3.05s/it]

Successfully saved checkpoint to /mnt/home/dheurtel/ceph/02_checkpoints/ContinuousSBM_VPSDE_I_FFResUnet_b32_f10_phi_beta_cosine_betamax_0.5_betamin_0.001


 40%|████      | 400/1000 [20:55<31:32,  3.15s/it]  

Successfully saved checkpoint to /mnt/home/dheurtel/ceph/02_checkpoints/ContinuousSBM_VPSDE_I_FFResUnet_b32_f10_phi_beta_cosine_betamax_0.5_betamin_0.001


 60%|██████    | 600/1000 [31:21<20:18,  3.05s/it]

Successfully saved checkpoint to /mnt/home/dheurtel/ceph/02_checkpoints/ContinuousSBM_VPSDE_I_FFResUnet_b32_f10_phi_beta_cosine_betamax_0.5_betamin_0.001


 80%|████████  | 800/1000 [41:45<10:02,  3.01s/it]

Successfully saved checkpoint to /mnt/home/dheurtel/ceph/02_checkpoints/ContinuousSBM_VPSDE_I_FFResUnet_b32_f10_phi_beta_cosine_betamax_0.5_betamin_0.001


100%|█████████▉| 999/1000 [52:03<00:03,  3.02s/it]

Training finished. Final sampling and checkpointing.


100%|██████████| 1000/1000 [52:53<00:00,  3.17s/it]

Successfully saved checkpoint to /mnt/home/dheurtel/ceph/02_checkpoints/ContinuousSBM_VPSDE_I_FFResUnet_b32_f10_phi_beta_cosine_betamax_0.5_betamin_0.001





([2.188047409057617,
  2.000959873199463,
  2.029564142227173,
  2.054314613342285,
  2.0126476287841797,
  2.021949291229248,
  2.042091131210327,
  1.9976739883422852,
  1.9847989082336426,
  2.0388343334198,
  2.0471558570861816,
  1.9821927547454834,
  2.0175559520721436,
  1.9419875144958496,
  1.9223124980926514,
  1.9245061874389648,
  1.9238828420639038,
  1.9047985076904297,
  1.8950001001358032,
  1.857552409172058,
  1.8650524616241455,
  1.8705286979675293,
  1.8282530307769775,
  1.8133245706558228,
  1.7840209007263184,
  1.8085107803344727,
  1.8058676719665527,
  1.7267217636108398,
  1.6867895126342773,
  1.6249372959136963,
  1.7230890989303589,
  1.700169563293457,
  1.5882043838500977,
  1.5841604471206665,
  1.5815153121948242,
  1.5649998188018799,
  1.5538280010223389,
  1.563020944595337,
  1.5565659999847412,
  1.4301810264587402,
  1.462592363357544,
  1.512300968170166,
  1.4773964881896973,
  1.4556829929351807,
  1.4252891540527344,
  1.46888267993927,
  1.