In [1]:
import os

# Device / XLA configuration for a CPU-only run.
# These variables are read when JAX initialises its backends, so this cell must
# run before `jax` is first imported (it is imported in a later cell).
# To use a GPU again, change BOTH CUDA_VISIBLE_DEVICES (e.g. "0") and JAX_PLATFORMS.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # number GPUs by PCI bus id, matching nvidia-smi
# Expose no CUDA devices. The previous value "None" also achieved this, but only
# because CUDA treats an invalid device id as "stop here"; an empty list says it directly.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Ask JAX for the CPU backend explicitly instead of letting it probe for (and warn
# about) an unavailable CUDA backend.
os.environ["JAX_PLATFORMS"] = "cpu"
# Only matters if a GPU is re-enabled: stop XLA from grabbing most GPU memory up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

In [10]:
import sys
sys.path.append("..")
%load_ext autoreload

import jax.numpy as jnp
from jax import random

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
%autoreload 2
from configs.EI_leaky import get_config

config = get_config()
config.seed = 3

In [12]:
from models import (
    ENN,
    NeuralOperator,
    _create_decoder_arch,
    _create_encoder_arch,
    _create_epiprior_arch,
    _create_epitrain_arch,
)

# Base network: encoder and decoder built from the config, wrapped in a NeuralOperator.
encoder = _create_encoder_arch(config.encoder_arch)
decoder = _create_decoder_arch(config.decoder_arch)
base_net = NeuralOperator(encoder, decoder)

# The two additional networks ENN combines with the base network
# (built from config.epitrain_arch and config.epiprior_arch).
epi_train = _create_epitrain_arch(config.epitrain_arch)
epi_prior = _create_epiprior_arch(config.epiprior_arch)

arch = ENN(
    base_net,
    epi_train,
    epi_prior,
    config.ensemble_size,
    config.scale,
    config.output_activation,
)

# All-ones placeholder inputs sized from the config, used only to trace the
# model for the parameter/shape summary below.
u, y, z = (
    jnp.ones(dim)
    for dim in (config.input_dim, config.query_dim, config.ensemble_size)
)
key = random.PRNGKey(config.seed)

# print() (not a bare expression) so the rich/ANSI-formatted table renders.
print(arch.tabulate(key, u, y, z))


[3m                                  ENN Summary                                   [0m
┏━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath         [0m[1m [0m┃[1m [0m[1mmodule       [0m[1m [0m┃[1m [0m[1minputs       [0m[1m [0m┃[1m [0m[1moutputs     [0m[1m [0m┃[1m [0m[1mparams       [0m[1m [0m┃
┡━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│               │ ENN           │ - [2mfloat32[0m[4]  │ [2mfloat32[0m[16]  │               │
│               │               │ - [2mfloat32[0m[2]  │              │               │
│               │               │ - [2mfloat32[0m[16] │              │               │
├───────────────┼───────────────┼───────────────┼──────────────┼───────────────┤
│ base_net      │ NeuralOperat… │ - [2mfloat32[0m[4]  │ -            │               │
│               │               │ - [2mfloat32[0m[2]  │ [2mfloat32[0m[32]  │               │
│    