In [1]:
import logging
import os
import pickle
import time
from functools import partial

import numpy as np
import arch, brax
import utils
from datasets import get_dataset

import jax
import jax.numpy as jnp
from jax import jit, vmap
from jax.example_libraries import optimizers, stax
from jax.tree_util import tree_map
from utils import get_calibration, jaxRNG

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Define all the args
model='sdenet'
output='output_old'
seed=0
stl=False
lr=0.0007
epochs=300
bs=128
test_bs=1000
nsamples=1
w_init=-1.0
b_init=-1.0
p_init=-1.0
pause_every=200
no_drift=False
ou_dw=True
kl_coef=0.001
diff_coef=0.0001
ds='cifar10'
no_xt=True
acc_grad=1
aug=0
remat=False
ema=0.999
meanfield_sdebnn=False
infer_w0=False
w0_prior_std=0.1
disable_test=False
verbose=True
nblocks='2-2-2'
nsteps_list = [160, 140, 100, 60, 20]
block_type=0
fx_dim=64
fx_actfn='softplus'
fw_dims='2-128-2'
fw_actfn='softplus'
lr_sched='constant'
rng_generator = jaxRNG(seed= seed)

train_loader, train_eval_loader, val_loader, test_loader, input_size, train_size = get_dataset(128, 1000, "cifar10")
num_batches = len(train_loader)
print(f"Number of batches: {num_batches}")
train_batches = utils.inf_generator(train_loader)

# SDEBNN specific
mf = partial(brax.MeanField, disable=True) if kl_coef == 0. else brax.MeanField
fw_dims = list(map(int, fw_dims.split("-")))
layers = [mf(arch.Augment( aug))]
nblocks = list(map(int, nblocks.split("-")))
opt_init, opt_update, get_params = optimizers.adam(7e-4)

# Load the checkpoint if it exists
checkpoint_path = os.path.join(output, 'best_model_checkpoint.pkl')
if os.path.exists(checkpoint_path):
    logging.warning("Loading checkpoints...")
    with open(checkpoint_path, "rb") as f:
        checkpoint = pickle.load(f)

    # Extract states from the checkpoint
    start_epoch = checkpoint['epoch']
    best_val_acc = checkpoint['best_val_acc']
    global_step = checkpoint['global_step']
    opt_state = checkpoint['optimizer_state']
    ema_params = checkpoint['ema_state']
    params = checkpoint['model_state']  # Loaded parameters
    print(f"Successfully loaded checkpoints for epoch {start_epoch} with best validation accuracy {best_val_acc}")
else:
    raise SystemExit("No checkpoint found!")
    
params = get_params(opt_state)

Files already downloaded and verified
Files already downloaded and verified




Number of batches: 351
Successfully loaded checkpoints for epoch 1 with best validation accuracy 0.28600001335144043


In [3]:
def _nll(params, batch, rng):
    inputs, targets = batch
    preds, kl, info_dict = _predict(params, inputs, rng=rng, full_output=False)
    nll = -jnp.mean(jnp.sum(preds * targets, axis=1))
    return preds, nll, kl, info_dict


@partial(jit, static_argnums=(3,))
def sep_loss(params, batch, rng, kl_coef):  # no backprop
    preds, nll, kl, _ = _nll(params, batch, rng)
    if kl_coef > 0:
        obj_loss = nll + kl * kl_coef
    else:
        obj_loss = nll
    _sep_loss = {'loss': obj_loss, 'kl': kl, 'nll': nll, 'preds': preds}
    return obj_loss, _sep_loss

@partial(jit, static_argnums=(3,))
def loss(params, batch, rng, kl_coef):  # backprop so checkpoint
    _, nll, kl, _ = jax.checkpoint(_nll)(params, batch, rng)
    if kl_coef > 0:
        return nll + kl * kl_coef
    else:
        return nll

@jit
def predict(params, inputs, rng): 
    return _predict(params, inputs, rng=rng, full_output=True)

@partial(jit, static_argnums=(2,))
def accuracy(params, data, nsamples, rng):
    inputs, targets = data
    target_class = jnp.argmax(targets, axis=1)
    rngs = jax.random.split(rng, nsamples)
    preds, _, info_dic = vmap(predict, in_axes=(None, None, 0))(params, inputs, rngs)
    preds = jnp.stack(preds, axis=0)
    avg_preds = preds.mean(0)
    predicted_class = jnp.argmax(avg_preds, axis=1)
    n_correct = jnp.sum(predicted_class == target_class)
    n_total = inputs.shape[0]
    wts = info_dic['sdebnn_w']
    wts = jnp.stack(wts, axis=0)
    avg_wts = wts.mean(0)
    return n_correct, n_total, avg_preds, avg_wts

def update_ema(ema_params, params, momentum=0.999):
    return tree_map(lambda e, p: e * momentum + p * (1 - momentum), ema_params, params)

def evaluate(params, data_loader, input_size, nsamples, rng_generator, kl_coef):
    n_total = 0
    n_correct = 0
    nll = 0
    kl = 0
    logits = np.array([])
    wts = np.array([])
    labels = np.array([])
    for inputs, targets in data_loader:
        targets = jax.nn.one_hot(jnp.array(targets), num_classes=10)
        inputs = jnp.array(inputs).reshape((-1,) + (input_size[-1],) + input_size[:2])
        inputs = jnp.transpose(inputs, (0, 2, 3, 1))  # Permute from NCHW to NHWC
        batch_correct, batch_total, _logits, _wts = accuracy(
            params, (inputs, targets), nsamples, rng_generator.next()
        )
        n_correct = n_correct + batch_correct
        _, batch_nll, batch_kl, _ = jit(_nll)(params, (inputs, targets), rng_generator.next())
        if n_total == 0:
            logits = np.array(_logits)
            wts = np.array(_wts)
            labels = np.array(targets)
        else:
            logits = np.concatenate([logits, np.array(_logits)], axis=0)
            wts = np.concatenate([wts, np.array(_wts)], axis=0)
            labels = np.concatenate([labels, targets], axis=0)
        n_total = n_total + batch_total
        nll = nll + batch_nll
        kl = kl + batch_kl
    return n_correct / n_total, jnp.stack(logits, axis=0), labels, nll / n_total, kl / n_total, jnp.stack(wts, axis=0)

inference_times = []
eces = []

print("Starting experiment")
for nsteps in nsteps_list:
    for i, nb in enumerate(nblocks):
        fw = arch.MLP(fw_dims, actfn=fw_actfn, xt=no_xt, ou_dw=ou_dw, nonzero_w=w_init, nonzero_b=b_init, p_scale=p_init)  # weight network is time dependent
        if meanfield_sdebnn:
            layers.extend([mf(brax.SDEBNN(block_type,
                                            fx_dim,
                                            fx_actfn,
                                            fw,
                                            diff_coef=diff_coef,
                                            stl=stl,
                                            xt=no_xt,
                                            nsteps=nsteps,
                                            remat=remat,
                                            w_drift=not no_drift,
                                            stax_api=True,
                                            infer_initial_state=infer_w0,
                                            initial_state_prior_std=w0_prior_std)) for _ in range(nb)
            ])
        else:
            layers.extend([brax.SDEBNN( block_type,
                                        fx_dim,
                                        fx_actfn,
                                        fw,
                                        diff_coef=diff_coef,
                                        stl=stl,
                                        xt=no_xt,
                                        nsteps=nsteps,
                                        remat=remat,
                                        w_drift=not no_drift,
                                        infer_initial_state=infer_w0,
                                        initial_state_prior_std=w0_prior_std) for _ in range(nb)
            ])
        if i < len(nblocks) - 1:
            layers.append(mf(arch.SqueezeDownsample(2)))
    layers.append(mf(stax.serial(stax.Flatten, stax.Dense(10), stax.LogSoftmax)))

    init_random_params, _predict = brax.bnn_serial(*layers)

    # for inputs, targets in tqdm(test_loader): # evaluate already deals with the loop
    start_time = time.time()
    acc, logits, targets, nll, _, _ = evaluate(params, test_loader, input_size, nsamples, rng_generator, kl_coef=kl_coef)

    # Calculate inference time
    inference_time = time.time() - start_time
    inference_times.append(inference_time)

    # Convert logits to probabilities
    probabilities = jax.nn.softmax(logits)
        
    # Calculate ECE
    # TODO: might be better to use utils.ECE or utils.compute_acc_bin
    cal = get_calibration(targets, probabilities) 
    eces.append(cal['ece'])
    print(f"nsteps: {nsteps} - completed")

print("Experiment run sucessfully!")

Starting experiment


2024-01-09 14:04:14.251635: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 4.46GiB (rounded to 4788077568)requested by op 
2024-01-09 14:04:14.251887: W external/tsl/tsl/framework/bfc_allocator.cc:497] ****________________________________________________________________________________________________
2024-01-09 14:04:14.253011: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 4788077520 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   32.62MiB
              constant allocation:     4.0KiB
        maybe_live_out allocation:   80.25MiB
     preallocated temp allocation:    4.46GiB
                 total allocation:    4.57GiB
Peak buffers:
	Buffer 1:
		Size: 1.97GiB
		XLA Label: fusion
		Shape: f32[159,1,3334848]

	Buffer 2:
		Size: 1.97GiB
		Operator: op_name="jit(accuracy)

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 4788077520 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   32.62MiB
              constant allocation:     4.0KiB
        maybe_live_out allocation:   80.25MiB
     preallocated temp allocation:    4.46GiB
                 total allocation:    4.57GiB
Peak buffers:
	Buffer 1:
		Size: 1.97GiB
		XLA Label: fusion
		Shape: f32[159,1,3334848]
		==========================

	Buffer 2:
		Size: 1.97GiB
		Operator: op_name="jit(accuracy)/jit(main)/vmap(jit(predict))/jit(_stochastic_integrate)/broadcast_in_dim[shape=(159, 1, 3334848) broadcast_dimensions=()]" source_file="/nfs/home/sergioco/test_code/part_stoch_inf_deep/jax_code/sdeint.py" source_line=135
		XLA Label: fusion
		Shape: f32[159,1,3334848]
		==========================

	Buffer 3:
		Size: 80.21MiB
		Operator: op_name="jit(accuracy)/jit(main)/vmap(jit(predict))/slice[start_indices=(0, 0, 3072000) limit_indices=(1, 160, 3203424) strides=None]" source_file="/nfs/home/sergioco/test_code/part_stoch_inf_deep/jax_code/brax.py" source_line=375
		XLA Label: fusion
		Shape: f32[1,160,131424]
		==========================

	Buffer 4:
		Size: 12.72MiB
		Operator: op_name="jit(accuracy)/jit(main)/vmap(jit(predict))/jit(sample_with_rep)/jit(_normal)/jit(_normal_real)/mul" source_file="/nfs/home/sergioco/test_code/part_stoch_inf_deep/jax_code/brownian.py" source_line=72 deduplicated_name="fusion.215"
		XLA Label: fusion
		Shape: f32[1,3334848]
		==========================

	Buffer 5:
		Size: 11.72MiB
		Entry Parameter Subshape: f32[1000,32,32,3]
		==========================

	Buffer 6:
		Size: 6.36MiB
		Operator: op_name="jit(accuracy)/jit(main)/vmap(jit(predict))/jit(sample_with_rep)/jit(_normal)/jit(_normal_real)/jit(_uniform)/slice[start_indices=(1667424,) limit_indices=(3334848,) strides=None]" source_file="/nfs/home/sergioco/test_code/part_stoch_inf_deep/jax_code/brownian.py" source_line=72 deduplicated_name="fusion.346"
		XLA Label: fusion
		Shape: u32[1667424]
		==========================

	Buffer 7:
		Size: 6.36MiB
		Operator: op_name="jit(accuracy)/jit(main)/vmap(jit(predict))/jit(sample_with_rep)/jit(_normal)/jit(_normal_real)/jit(_uniform)/slice[start_indices=(0,) limit_indices=(1667424,) strides=None]" source_file="/nfs/home/sergioco/test_code/part_stoch_inf_deep/jax_code/brownian.py" source_line=72 deduplicated_name="fusion.347"
		XLA Label: fusion
		Shape: u32[1667424]
		==========================

	Buffer 8:
		Size: 1.00MiB
		Entry Parameter Subshape: f32[2,131424]
		==========================

	Buffer 9:
		Size: 1.00MiB
		Entry Parameter Subshape: f32[131424,2]
		==========================

	Buffer 10:
		Size: 1.00MiB
		Entry Parameter Subshape: f32[2,131424]
		==========================

	Buffer 11:
		Size: 1.00MiB
		Entry Parameter Subshape: f32[131424,2]
		==========================

	Buffer 12:
		Size: 699.9KiB
		Entry Parameter Subshape: f32[2,89592]
		==========================

	Buffer 13:
		Size: 699.9KiB
		Entry Parameter Subshape: f32[89592,2]
		==========================

	Buffer 14:
		Size: 699.9KiB
		Entry Parameter Subshape: f32[2,89592]
		==========================

	Buffer 15:
		Size: 699.9KiB
		Entry Parameter Subshape: f32[89592,2]
		==========================



In [None]:
print(inference_times)

[23.52790117263794, 12.648282527923584, 12.61378526687622, 12.629003524780273, 12.65815019607544]
[12.65815019607544, 12.629003524780273, 12.61378526687622, 12.648282527923584, 23.52790117263794]


In [None]:
print(eces)

[0.11246883371607906, 0.09643606737773742, 0.09320424692634721, 0.09814559212528466, 0.10699682006875459]
