In [15]:
import pickle
import jax
from jax.example_libraries import stax
from jax import grad, jit, vmap
import jax.numpy as jnp
import numpy as np
from tqdm import tqdm
from pathlib import Path
from datasets import get_dataset
import time
from functools import partial

from utils import get_calibration, jaxRNG
from sdebnn_classification import evaluate
import brax
import arch

In [9]:
def _nll(params, batch, rng):
    inputs, targets = batch
    preds, kl, info_dict = _predict(params, inputs, rng=rng, full_output=False)
    nll = -jnp.mean(jnp.sum(preds * targets, axis=1))
    return preds, nll, kl, info_dict

def predict(params, inputs, rng): 
    return _predict(params, inputs, rng=rng, full_output=True)

@partial(jit, static_argnums=(2,))
def accuracy(params, data, nsamples, rng):
    inputs, targets = data
    target_class = jnp.argmax(targets, axis=1)
    rngs = jax.random.split(rng, nsamples)
    preds, _, info_dic = vmap(predict, in_axes=(None, None, 0))(params, inputs, rngs)
    preds = jnp.stack(preds, axis=0)
    avg_preds = preds.mean(0)
    predicted_class = jnp.argmax(avg_preds, axis=1)
    n_correct = jnp.sum(predicted_class == target_class)
    n_total = inputs.shape[0]
    wts = info_dic['sdebnn_w']
    wts = jnp.stack(wts, axis=0)
    avg_wts = wts.mean(0)
    return n_correct, n_total, avg_preds, avg_wts

def evaluate(params, data_loader, input_size, nsamples, rng_generator):
    n_total = 0
    n_correct = 0
    nll = 0
    kl = 0
    logits = np.array([])
    wts = np.array([])
    labels = np.array([])
    for inputs, targets in data_loader:
        targets = jax.nn.one_hot(jnp.array(targets), num_classes=10)
        inputs = jnp.array(inputs).reshape((-1,) + (input_size[-1],) + input_size[:2])
        inputs = jnp.transpose(inputs, (0, 2, 3, 1))  # Permute from NCHW to NHWC
        batch_correct, batch_total, _logits, _wts = accuracy(
            params, (inputs, targets), nsamples, rng_generator.next(),
        ) # _logits (nbatch, nclass)
        n_correct = n_correct + batch_correct
        _, batch_nll, batch_kl, _ = jit(_nll)(params, (inputs, targets), rng_generator.next())
        if n_total == 0:
            logits = np.array(_logits)
            wts = np.array(_wts)
            labels = np.array(targets)
        else:
            logits = np.concatenate([logits, np.array(_logits)], axis=0)
            wts = np.concatenate([wts, np.array(_wts)], axis=0)
            labels = np.concatenate([labels, targets], axis=0)
        n_total = n_total + batch_total
        nll = nll + batch_nll
        kl = kl + batch_kl
    return n_correct / n_total, jnp.stack(logits, axis=0), labels, nll / n_total, kl / n_total, jnp.stack(wts, axis=0)

In [17]:
def _nll(params, batch, rng):
    inputs, targets = batch
    print("Inputs shape in nll", inputs.shape)
    preds, kl, info_dict = _predict(params, inputs, rng=rng, full_output=False)
    nll = -jnp.mean(jnp.sum(preds * targets, axis=1))
    return preds, nll, kl, info_dict

def predict(params, inputs, rng):
    print("Entering predict function")
    print("Inputs shape in predict:", inputs.shape)
    print("RNG state in predict:", rng)
    print("data type of inputs:", type(inputs))

    # Assuming _predict is a function called within predict
    try:
        result = _predict(params, inputs, rng=rng, full_output=True)
    except Exception as e:
        print("Error encountered in _predict:", e)
        print("Inputs:", inputs)
        raise e

    return result

@partial(jit, static_argnums=(2,))
def accuracy(params, data, nsamples, rng):
    inputs, targets = data
    target_class = jnp.argmax(targets, axis=1)
    rngs = jax.random.split(rng, nsamples)

    print('Inputs shape in accuracy', inputs.shape)
    print('rngs shape in accuracy', rngs.shape)

    # Select a single input and a single RNG state
    single_input = inputs[0]
    single_rng_state = rngs[0]

    # Call the predict function directly
    single_pred = predict(params, single_input, single_rng_state)

    # Print the output to check if it works correctly
    print('Single Prediction', single_pred)

    # Directly pass inputs as an array to vmap
    preds, _, info_dic = vmap(predict, in_axes=(None, 0, 0))(params, inputs, rngs)
    print("Shape after vmap in accuracy:", preds.shape)
    preds = jnp.stack(preds, axis=0)
    avg_preds = preds.mean(0)
    predicted_class = jnp.argmax(avg_preds, axis=1)
    n_correct = jnp.sum(predicted_class == target_class)
    n_total = inputs.shape[0]
    wts = info_dic['sdebnn_w']
    wts = jnp.stack(wts, axis=0)
    avg_wts = wts.mean(0)
    return n_correct, n_total, avg_preds, avg_wts

def evaluate(params, data_loader, input_size, nsamples, rng_generator):
    n_total = 0
    n_correct = 0
    nll = 0
    kl = 0
    logits = np.array([])
    wts = np.array([])
    labels = np.array([])
    for inputs, targets in data_loader:
        targets = jax.nn.one_hot(jnp.array(targets), num_classes=10)
        inputs = jnp.array(inputs).reshape((-1,) + (input_size[-1],) + input_size[:2])
        inputs = jnp.transpose(inputs, (0, 2, 3, 1))  # Permute from NCHW to NHWC
        
        print("Inputs shape in evaluate", inputs.shape) # Should be ((1000, 32, 32, 3),)
        batch_correct, batch_total, _logits, _wts = accuracy(
            params, (inputs, targets), nsamples, rng_generator.next(),
        ) # _logits (nbatch, nclass)
        n_correct = n_correct + batch_correct
        _, batch_nll, batch_kl, _ = jit(_nll)(params, (inputs, targets), rng_generator.next())
        if n_total == 0:
            logits = np.array(_logits)
            wts = np.array(_wts)
            labels = np.array(targets)
        else:
            logits = np.concatenate([logits, np.array(_logits)], axis=0)
            wts = np.concatenate([wts, np.array(_wts)], axis=0)
            labels = np.concatenate([labels, targets], axis=0)
        n_total = n_total + batch_total
        nll = nll + batch_nll
        kl = kl + batch_kl
    return n_correct / n_total, jnp.stack(logits, axis=0), labels, nll / n_total, kl / n_total, jnp.stack(wts, axis=0)

# Build model
dt_list = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
fw_dims = list(map(int, "2-128-2".split("-")))
aug = 0
nsamples = 1000
rng_generator = jaxRNG(0)
kl_coef = 1e-3

mf = partial(brax.MeanField, disable=True) if kl_coef == 0. else brax.MeanField
layers = [mf(arch.Augment(aug))]
nblocks = list(map(int, "2-2-2".split("-")))
_, _, _, test_loader, input_size, _ = get_dataset(128, 1000, "cifar10")

inference_times = []
eces = []

for dt in dt_list:
    inference_times_dt = []
    eces_dt = []
    for i, nb in enumerate(nblocks):
        fw = arch.MLP(fw_dims, actfn="softplus", xt=False, ou_dw=False, nonzero_w=-1., nonzero_b=-1., p_scale=-1.)  # weight network is time dependent
        layers.extend([brax.SDEBNN(fx_block_type=0,
                                    fx_dim=64, # Shouldn't this be 128?
                                    fx_actfn="softplus",
                                    fw=fw,
                                    diff_coef=1e-4,
                                    stl=False,
                                    xt=False,
                                    nsteps=20,
                                    dt=dt,
                                    remat=False,
                                    w_drift=True,
                                    infer_initial_state=False,
                                    initial_state_prior_std=0.1) for _ in range(nb)
            ])
        if i < len(nblocks) - 1:
            layers.append(mf(arch.SqueezeDownsample(2)))
    layers.append(mf(stax.serial(stax.Flatten, stax.Dense(10), stax.LogSoftmax)))

    _, _predict = brax.bnn_serial(*layers)

    checkpoint_path = 'output/best_model_checkpoint.pkl'
    with open(checkpoint_path, 'rb') as f:
        checkpoint = pickle.load(f)

    params = checkpoint['model_state']

    # for inputs, targets in tqdm(test_loader): # evaluate already deasl with the loop
    start_time = time.time()
    _, logits, targets, _, _, _ = evaluate(params, test_loader, input_size, nsamples, rng_generator)

    # Calculate inference time
    inference_time = time.time() - start_time
    inference_times.append(inference_time)

    # Convert logits to probabilities
    probabilities = jax.nn.softmax(logits)
        
    # Calculate ECE
    cal = get_calibration(targets, probabilities)
    eces.append(cal['ece'])

Files already downloaded and verified
Files already downloaded and verified
Inputs shape in evaluate (1000, 32, 32, 3)
Inputs shape in accuracy (1000, 32, 32, 3)
rngs shape in accuracy (1000, 2)
Entering predict function
Inputs shape in predict: (32, 32, 3)
RNG state in predict: Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=1/0)>
data type of inputs: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
Error encountered in _predict: tuple index out of range
Inputs: Traced<ShapedArray(float32[32,32,3])>with<DynamicJaxprTrace(level=1/0)>


IndexError: tuple index out of range