
Commit

Improve readability and add comments.
daniel-dodd committed Nov 6, 2022
1 parent 8542439 commit 9ee096a
Showing 7 changed files with 429 additions and 260 deletions.
220 changes: 131 additions & 89 deletions gpjax/abstractions.py
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================

from typing import Callable, Dict, Optional
from typing import Callable, Dict, Optional, Tuple, Any, Union

import jax
import jax.numpy as jnp
@@ -33,83 +33,18 @@

@dataclass(frozen=True)
class InferenceState:
"""Imutable dataclass for storing optimised parameters and training history."""

params: Dict
history: Float[Array, "n_iters"]

def unpack(self):
return self.params, self.history


def progress_bar_scan(n_iters: int, log_rate: int):
"""Progress bar for Jax.lax scans (adapted from https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/)."""

tqdm_bars = {}
remainder = n_iters % log_rate

def _define_tqdm(args, transform):
tqdm_bars[0] = tqdm(range(n_iters))

def _update_tqdm(args, transform):
loss_val, arg = args
tqdm_bars[0].update(arg)
tqdm_bars[0].set_postfix({"Objective": f"{loss_val: .2f}"})

def _update_progress_bar(loss_val, i):
"""Updates tqdm progress bar of a JAX scan or loop."""
_ = lax.cond(
i == 0,
lambda _: host_callback.id_tap(_define_tqdm, None, result=i),
lambda _: i,
operand=None,
)

_ = lax.cond(
# update tqdm every multiple of `print_rate` except at the end
(i % log_rate == 0) & (i != n_iters - remainder),
lambda _: host_callback.id_tap(
_update_tqdm, (loss_val, log_rate), result=i
),
lambda _: i,
operand=None,
)

_ = lax.cond(
# update tqdm by `remainder`
i == n_iters - remainder,
lambda _: host_callback.id_tap(
_update_tqdm, (loss_val, remainder), result=i
),
lambda _: i,
operand=None,
)

def _close_tqdm(args, transform):
tqdm_bars[0].close()

def close_tqdm(result, i):
return lax.cond(
i == n_iters - 1,
lambda _: host_callback.id_tap(_close_tqdm, None, result=result),
lambda _: result,
operand=None,
)

def _progress_bar_scan(func):
"""Decorator that adds a progress bar to `body_fun` used in `lax.scan`."""

def wrapper_progress_bar(carry, x):
if type(x) is tuple:
iter_num, *_ = x
else:
iter_num = x
result = func(carry, x)
*_, loss_val = result
_update_progress_bar(loss_val, iter_num)
return close_tqdm(result, iter_num)

return wrapper_progress_bar
return _progress_bar_scan

def unpack(self) -> Tuple[Dict, Float[Array, "n_iters"]]:
"""Unpack parameters and training history into a tuple.

Returns:
Tuple[Dict, Float[Array, "n_iters"]]: Tuple of parameters and training history.
"""
return self.params, self.history


def fit(
@@ -137,36 +72,41 @@ def fit(

params, trainables, bijectors = parameter_state.unpack()

def loss(params):
# Define optimisation loss function on unconstrained space, with a stop gradient rule for trainables that are set to False
def loss(params: Dict) -> Float[Array, "1"]:
params = trainable_params(params, trainables)
params = constrain(params, bijectors)
return objective(params)

iter_nums = jnp.arange(n_iters)

# Transform params to unconstrained space:
# Transform params to unconstrained space
params = unconstrain(params, bijectors)

# Initialise optimiser state
opt_state = optax_optim.init(params)

def step(carry, iter_num):
# Iteration loop numbers to scan over
iter_nums = jnp.arange(n_iters)

# Optimisation step
def step(carry, iter_num: int):
params, opt_state = carry
loss_val, loss_gradient = jax.value_and_grad(loss)(params)
updates, opt_state = optax_optim.update(loss_gradient, opt_state, params)
params = ox.apply_updates(params, updates)
carry = params, opt_state
return carry, loss_val

# Display progress bar if verbose is True
if verbose:
step = progress_bar_scan(n_iters, log_rate)(step)

# Run the optimisation loop
(params, _), history = jax.lax.scan(step, (params, opt_state), iter_nums)

# Transform params to constrained space:
# Transform final params to constrained space
params = constrain(params, bijectors)

inf_state = InferenceState(params=params, history=history)

return inf_state
return InferenceState(params=params, history=history)
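
For orientation, here is a minimal usage sketch of the `fit` abstraction. It assumes the GPJax API of this era (`gpx.Dataset`, `gpx.Prior`, `gpx.initialise`, and a negative marginal log-likelihood objective), none of which is shown in this diff, so treat it as illustrative rather than canonical:

import jax.numpy as jnp
import jax.random as jr
import optax as ox
import gpjax as gpx

# Toy 1D regression data (purely illustrative).
key = jr.PRNGKey(123)
X = jnp.linspace(-3.0, 3.0, 50).reshape(-1, 1)
y = jnp.sin(X) + 0.1 * jr.normal(key, X.shape)
D = gpx.Dataset(X=X, y=y)

# Assumed API: build a posterior and a negative marginal log-likelihood objective.
posterior = gpx.Prior(kernel=gpx.RBF()) * gpx.Gaussian(num_datapoints=D.n)
objective = posterior.marginal_log_likelihood(D, negative=True)

# Optimise; `fit` returns an InferenceState holding the params and loss history.
parameter_state = gpx.initialise(posterior, key)
inference_state = gpx.fit(
    objective=objective,
    parameter_state=parameter_state,
    optax_optim=ox.adam(learning_rate=0.01),
    n_iters=500,
)
learned_params, history = inference_state.unpack()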


def fit_batches(
@@ -200,17 +140,23 @@ def fit_batches(

params, trainables, bijectors = parameter_state.unpack()

def loss(params, batch):
# Define optimisation loss function on unconstrained space, with a stop gradient rule for trainables that are set to False
def loss(params: Dict, batch: Dataset) -> Float[Array, "1"]:
params = trainable_params(params, trainables)
params = constrain(params, bijectors)
return objective(params, batch)

# Transform params to unconstrained space
params = unconstrain(params, bijectors)

# Initialise optimiser state
opt_state = optax_optim.init(params)

# Mini-batch random keys and iteration loop numbers to scan over
keys = jr.split(key, n_iters)
iter_nums = jnp.arange(n_iters)

# Optimisation step
def step(carry, iter_num__and__key):
iter_num, key = iter_num__and__key
params, opt_state = carry
@@ -224,19 +170,21 @@ def step(carry, iter_num__and__key):
carry = params, opt_state
return carry, loss_val

# Display progress bar if verbose is True
if verbose:
step = progress_bar_scan(n_iters, log_rate)(step)

# Run the optimisation loop
(params, _), history = jax.lax.scan(step, (params, opt_state), (iter_nums, keys))

# Transform final params to constrained space
params = constrain(params, bijectors)
inf_state = InferenceState(params=params, history=history)

return inf_state
return InferenceState(params=params, history=history)
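
A hedged sketch of driving `fit_batches`. The two-argument objective (e.g. a negative ELBO from a stochastic variational family) and the surrounding set-up are assumed rather than shown in this hunk:

import jax.random as jr
import optax as ox
import gpjax as gpx

# Assumed: `objective(params, batch)` is a stochastic loss such as a negative ELBO,
# and `parameter_state` / `D` were built as in the `fit` sketch above.
inference_state = gpx.fit_batches(
    objective=objective,
    parameter_state=parameter_state,
    train_data=D,
    optax_optim=ox.adam(learning_rate=0.01),
    key=jr.PRNGKey(42),
    batch_size=64,
    n_iters=2000,
)
params, history = inference_state.unpack()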


def get_batch(train_data: Dataset, batch_size: int, key: PRNGKeyType) -> Dataset:
"""Batch the data into mini-batches.
"""Batch the data into mini-batches. Sampling is done with replacement.
Args:
train_data (Dataset): The training dataset.
@@ -247,6 +195,7 @@ def get_batch(train_data: Dataset, batch_size: int, key: PRNGKeyType) -> Dataset
"""
x, y, n = train_data.X, train_data.y, train_data.n

# Subsample data indices with replacement to get the mini-batch
indicies = jr.choice(key, n, (batch_size,), replace=True)

return Dataset(X=x[indicies], y=y[indicies])
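
The with-replacement subsampling used above is easy to reproduce in isolation; a small self-contained sketch with toy shapes:

import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
X = jnp.arange(10.0).reshape(-1, 1)  # 10 inputs
y = 2.0 * X                          # 10 targets
n, batch_size = X.shape[0], 4

# Draw `batch_size` row indices uniformly with replacement, then gather the rows.
idx = jr.choice(key, n, (batch_size,), replace=True)
X_batch, y_batch = X[idx], y[idx]
print(X_batch.shape, y_batch.shape)  # (4, 1) (4, 1)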
@@ -285,18 +234,23 @@ def fit_natgrads(

params, trainables, bijectors = parameter_state.unpack()

# Transform params to unconstrained space
params = unconstrain(params, bijectors)

# Initialise optimiser states
hyper_state = hyper_optim.init(params)
moment_state = moment_optim.init(params)

# Build natural and hyperparameter gradient functions
nat_grads_fn, hyper_grads_fn = natural_gradients(
stochastic_vi, train_data, bijectors, trainables
)

# Mini-batch random keys and iteration loop numbers to scan over
keys = jax.random.split(key, n_iters)
iter_nums = jnp.arange(n_iters)

# Optimisation step
def step(carry, iter_num__and__key):
iter_num, key = iter_num__and__key
params, hyper_state, moment_state = carry
@@ -316,15 +270,103 @@ def step(carry, iter_num__and__key):
carry = params, hyper_state, moment_state
return carry, loss_val

# Display progress bar if verbose is True
if verbose:
step = progress_bar_scan(n_iters, log_rate)(step)

# Run the optimisation loop
(params, _, _), history = jax.lax.scan(
step, (params, hyper_state, moment_state), (iter_nums, keys)
)

# Transform final params to constrained space
params = constrain(params, bijectors)
inf_state = InferenceState(params=params, history=history)
return inf_state

return InferenceState(params=params, history=history)
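
A hedged call sketch for `fit_natgrads`. The argument names mirror identifiers visible in this diff (`stochastic_vi`, `train_data`, `moment_optim`, `hyper_optim`), but the construction of the StochasticVI object and the exact keyword signature are assumptions:

import jax.random as jr
import optax as ox
import gpjax as gpx

# Assumed: `stochastic_vi` is a StochasticVI object with a natural-parameterised
# variational family, `parameter_state` its initialised parameters, and `D` a
# gpx.Dataset; none of these are constructed in this diff.
inference_state = gpx.fit_natgrads(
    stochastic_vi,
    parameter_state=parameter_state,
    train_data=D,
    key=jr.PRNGKey(42),
    batch_size=128,
    n_iters=5000,
    moment_optim=ox.sgd(1.0),   # natural-gradient (moment) updates
    hyper_optim=ox.adam(1e-2),  # hyperparameter updates
)
params, history = inference_state.unpack()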


def progress_bar_scan(n_iters: int, log_rate: int) -> Callable:
"""Progress bar for Jax.lax scans (adapted from https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/)."""

tqdm_bars = {}
remainder = n_iters % log_rate

def _define_tqdm(args: Any, transform: Any) -> None:
"""Define a tqdm progress bar."""
tqdm_bars[0] = tqdm(range(n_iters))

def _update_tqdm(args: Any, transform: Any) -> None:
"""Update the tqdm progress bar with the latest objective value."""
loss_val, arg = args
tqdm_bars[0].update(arg)
tqdm_bars[0].set_postfix({"Objective": f"{loss_val: .2f}"})

def _close_tqdm(args: Any, transform: Any) -> None:
"""Close the tqdm progress bar."""
tqdm_bars[0].close()

def _callback(cond: bool, func: Callable, arg: Any) -> None:
"""Callback a function for a given argument if a condition is true."""
dummy_result = 0

def _do_callback(_) -> int:
"""Perform the callback."""
return host_callback.id_tap(func, arg, result=dummy_result)

def _not_callback(_) -> int:
"""Do nothing."""
return dummy_result

_ = lax.cond(cond, _do_callback, _not_callback, operand=None)


def _update_progress_bar(loss_val: Float[Array, "1"], iter_num: int) -> None:
"""Updates tqdm progress bar of a JAX scan or loop."""

# Conditions for iteration number
is_first: bool = iter_num == 0
is_multiple: bool = (iter_num % log_rate == 0) & (iter_num != n_iters - remainder)
is_remainder: bool = iter_num == n_iters - remainder
is_last: bool = iter_num == n_iters - 1

# Define progress bar, if first iteration
_callback(is_first, _define_tqdm, None)

# Update progress bar, if multiple of log_rate
_callback(is_multiple, _update_tqdm, (loss_val, log_rate))

# Update progress bar, if remainder
_callback(is_remainder, _update_tqdm, (loss_val, remainder))

# Close progress bar, if last iteration
_callback(is_last, _close_tqdm, None)


def _progress_bar_scan(body_fun: Callable) -> Callable:
"""Decorator that adds a progress bar to `body_fun` used in `lax.scan`."""

def wrapper_progress_bar(carry: Any, x: Union[tuple, int]) -> Any:

# Get iteration number
if type(x) is tuple:
iter_num, *_ = x
else:
iter_num = x

# Compute iteration step
result = body_fun(carry, x)

# Get loss value
*_, loss_val = result

# Update progress bar
_update_progress_bar(loss_val, iter_num)

return result

return wrapper_progress_bar

return _progress_bar_scan
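
Beyond `fit`, the decorator can wrap any hand-written `lax.scan` body whose output tuple ends with a loss value; a minimal self-contained sketch (the quadratic "loss" is purely illustrative):

import jax
import jax.numpy as jnp
from gpjax.abstractions import progress_bar_scan

n_iters, log_rate = 1000, 10

@progress_bar_scan(n_iters, log_rate)
def body(carry, iter_num):
    # `carry` is the current scalar "parameter"; the final element of the
    # returned tuple must be the loss, as `wrapper_progress_bar` expects.
    loss = (carry - 3.0) ** 2
    carry = carry - 0.02 * (carry - 3.0)  # gradient step on the quadratic
    return carry, loss

x_final, losses = jax.lax.scan(body, jnp.array(10.0), jnp.arange(n_iters))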


__all__ = [
4 changes: 2 additions & 2 deletions gpjax/config.py
@@ -21,7 +21,7 @@

__config = None

FillTriangular = dx.Chain([tfb.FillTriangular()]) # TODO: Dan to chain methods.
FillTriangular = dx.Chain([tfb.FillTriangular(), ]) # TODO: Dan to chain methods.
Identity = dx.Lambda(forward=lambda x: x, inverse=lambda x: x)
Softplus = dx.Lambda(
forward=lambda x: jnp.log(1 + jnp.exp(x)),
@@ -75,7 +75,7 @@ def add_parameter(param_name: str, bijection: dx.Bijector) -> None:
Args:
param_name (str): The name of the parameter that is to be added.
bijection (tfb.Bijector): The bijection that should be used to unconstrain the parameter's value.
bijection (dx.Bijector): The bijection that should be used to unconstrain the parameter's value.
"""
lookup_name = f"{param_name}_transform"
get_defaults()
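
For reference, a hedged sketch of registering a transform for a new parameter via `add_parameter` (the parameter name `my_parameter` is purely illustrative):

from gpjax.config import Softplus, add_parameter

# Register a positivity-preserving bijector for a custom parameter, so that
# optimisation is performed on its unconstrained representation.
add_parameter("my_parameter", Softplus)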
