Improve readability and add comments. #138

Merged
merged 1 commit on Nov 6, 2022
220 changes: 131 additions & 89 deletions gpjax/abstractions.py
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================

from typing import Callable, Dict, Optional
from typing import Callable, Dict, Optional, Tuple, Any, Union

import jax
import jax.numpy as jnp
@@ -33,83 +33,18 @@

@dataclass(frozen=True)
class InferenceState:
"""Imutable dataclass for storing optimised parameters and training history."""

params: Dict
history: Float[Array, "n_iters"]

def unpack(self):
return self.params, self.history


def progress_bar_scan(n_iters: int, log_rate: int):
"""Progress bar for Jax.lax scans (adapted from https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/)."""

tqdm_bars = {}
remainder = n_iters % log_rate

def _define_tqdm(args, transform):
tqdm_bars[0] = tqdm(range(n_iters))

def _update_tqdm(args, transform):
loss_val, arg = args
tqdm_bars[0].update(arg)
tqdm_bars[0].set_postfix({"Objective": f"{loss_val: .2f}"})

def _update_progress_bar(loss_val, i):
"""Updates tqdm progress bar of a JAX scan or loop."""
_ = lax.cond(
i == 0,
lambda _: host_callback.id_tap(_define_tqdm, None, result=i),
lambda _: i,
operand=None,
)

_ = lax.cond(
# update tqdm every multiple of `print_rate` except at the end
(i % log_rate == 0) & (i != n_iters - remainder),
lambda _: host_callback.id_tap(
_update_tqdm, (loss_val, log_rate), result=i
),
lambda _: i,
operand=None,
)

_ = lax.cond(
# update tqdm by `remainder`
i == n_iters - remainder,
lambda _: host_callback.id_tap(
_update_tqdm, (loss_val, remainder), result=i
),
lambda _: i,
operand=None,
)

def _close_tqdm(args, transform):
tqdm_bars[0].close()

def close_tqdm(result, i):
return lax.cond(
i == n_iters - 1,
lambda _: host_callback.id_tap(_close_tqdm, None, result=result),
lambda _: result,
operand=None,
)

def _progress_bar_scan(func):
"""Decorator that adds a progress bar to `body_fun` used in `lax.scan`."""

def wrapper_progress_bar(carry, x):
if type(x) is tuple:
iter_num, *_ = x
else:
iter_num = x
result = func(carry, x)
*_, loss_val = result
_update_progress_bar(loss_val, iter_num)
return close_tqdm(result, iter_num)

return wrapper_progress_bar

return _progress_bar_scan

def unpack(self) -> Tuple[Dict, Float[Array, "n_iters"]]:
"""Unpack parameters and training history into a tuple.

Returns:
Tuple[Dict, Float[Array, "n_iters"]]: Tuple of parameters and training history.
"""
return self.params, self.history

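As a quick orientation, a minimal sketch of the dataclass in use; the field values below are dummies, not output from a real optimisation run.

import jax.numpy as jnp

# Dummy stand-ins for the result of a real optimisation run.
state = InferenceState(
    params={"lengthscale": jnp.array(1.0)},
    history=jnp.zeros(100),
)
params, history = state.unpack()  # -> (Dict, Float[Array, "n_iters"])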

def fit(
@@ -137,36 +72,41 @@ def fit(

params, trainables, bijectors = parameter_state.unpack()

def loss(params):
# Define the optimisation loss function on the unconstrained space, with a stop-gradient rule for trainables that are set to False
def loss(params: Dict) -> Float[Array, "1"]:
params = trainable_params(params, trainables)
params = constrain(params, bijectors)
return objective(params)

iter_nums = jnp.arange(n_iters)

# Transform params to unconstrained space:
# Transform params to unconstrained space
params = unconstrain(params, bijectors)

# Initialise optimiser state
opt_state = optax_optim.init(params)

def step(carry, iter_num):
# Iteration loop numbers to scan over
iter_nums = jnp.arange(n_iters)

# Optimisation step
def step(carry, iter_num: int):
params, opt_state = carry
loss_val, loss_gradient = jax.value_and_grad(loss)(params)
updates, opt_state = optax_optim.update(loss_gradient, opt_state, params)
params = ox.apply_updates(params, updates)
carry = params, opt_state
return carry, loss_val

# Display progress bar if verbose is True
if verbose:
step = progress_bar_scan(n_iters, log_rate)(step)

# Run the optimisation loop
(params, _), history = jax.lax.scan(step, (params, opt_state), iter_nums)

# Transform params to constrained space:
# Transform final params to constrained space
params = constrain(params, bijectors)

inf_state = InferenceState(params=params, history=history)

return inf_state
return InferenceState(params=params, history=history)

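For orientation, a rough end-to-end sketch of `fit`. The model-building calls (`gpx.Dataset`, `gpx.Prior`, `gpx.initialise`, `marginal_log_likelihood`) are assumed from the GPJax API of this period, not from this diff, and the keyword names follow the parameters used in the body above.

import gpjax as gpx
import jax.numpy as jnp
import jax.random as jr
import optax as ox

key = jr.PRNGKey(123)

# Toy 1D regression data.
x = jnp.linspace(-3.0, 3.0, 50).reshape(-1, 1)
y = jnp.sin(x)
D = gpx.Dataset(X=x, y=y)

# Build a GP posterior and a negative marginal log-likelihood objective.
posterior = gpx.Prior(kernel=gpx.RBF()) * gpx.Gaussian(num_datapoints=D.n)
objective = posterior.marginal_log_likelihood(D, negative=True)
parameter_state = gpx.initialise(posterior, key)

# fit handles the constrained/unconstrained transforms internally.
inference_state = fit(
    objective,
    parameter_state,
    optax_optim=ox.adam(learning_rate=0.01),
    n_iters=500,
)
learned_params, history = inference_state.unpack()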

def fit_batches(
@@ -200,17 +140,23 @@ def fit_batches(

params, trainables, bijectors = parameter_state.unpack()

def loss(params, batch):
# Define the optimisation loss function on the unconstrained space, with a stop-gradient rule for trainables that are set to False
def loss(params: Dict, batch: Dataset) -> Float[Array, "1"]:
params = trainable_params(params, trainables)
params = constrain(params, bijectors)
return objective(params, batch)

# Transform params to unconstrained space
params = unconstrain(params, bijectors)

# Initialise optimiser state
opt_state = optax_optim.init(params)

# Mini-batch random keys and iteration loop numbers to scan over
keys = jr.split(key, n_iters)
iter_nums = jnp.arange(n_iters)

# Optimisation step
def step(carry, iter_num__and__key):
iter_num, key = iter_num__and__key
params, opt_state = carry
@@ -224,19 +170,21 @@ def step(carry, iter_num__and__key):
carry = params, opt_state
return carry, loss_val

# Display progress bar if verbose is True
if verbose:
step = progress_bar_scan(n_iters, log_rate)(step)

# Run the optimisation loop
(params, _), history = jax.lax.scan(step, (params, opt_state), (iter_nums, keys))

# Transform final params to constrained space
params = constrain(params, bijectors)
inf_state = InferenceState(params=params, history=history)

return inf_state
return InferenceState(params=params, history=history)

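A corresponding sketch for the batched variant. The keyword names follow the parameters visible in the body above; `elbo_objective` is a placeholder for a batch-aware objective (e.g. a negative ELBO), which this diff does not construct, and `parameter_state` and `D` are as in the `fit` sketch.

import jax.random as jr
import optax as ox

# Placeholder batch-aware objective: any callable of (params, batch).
def elbo_objective(params, batch):
    return objective(params)  # dummy: reuses the full-data objective above

inference_state = fit_batches(
    elbo_objective,
    parameter_state,
    train_data=D,
    optax_optim=ox.adam(learning_rate=0.01),
    key=jr.PRNGKey(42),
    batch_size=64,
    n_iters=2000,
)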

def get_batch(train_data: Dataset, batch_size: int, key: PRNGKeyType) -> Dataset:
"""Batch the data into mini-batches.
"""Batch the data into mini-batches. Sampling is done with replacement.

Args:
train_data (Dataset): The training dataset.
@@ -247,6 +195,7 @@ def get_batch(train_data: Dataset, batch_size: int, key: PRNGKeyType) -> Dataset
"""
x, y, n = train_data.X, train_data.y, train_data.n

# Subsample data indices with replacement to get the mini-batch
indices = jr.choice(key, n, (batch_size,), replace=True)

return Dataset(X=x[indices], y=y[indices])
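
For example, repeated calls with fresh keys draw fresh mini-batches (here `D` is any GPJax `Dataset`, as in the sketches above):

import jax.random as jr

batch_keys = jr.split(jr.PRNGKey(123), 3)

# Three independent 32-point mini-batches, each sampled with replacement.
batches = [get_batch(D, batch_size=32, key=k) for k in batch_keys]
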
@@ -285,18 +234,23 @@ def fit_natgrads(

params, trainables, bijectors = parameter_state.unpack()

# Transform params to unconstrained space
params = unconstrain(params, bijectors)

# Initialise optimiser states
hyper_state = hyper_optim.init(params)
moment_state = moment_optim.init(params)

# Build natural and hyperparameter gradient functions
nat_grads_fn, hyper_grads_fn = natural_gradients(
stochastic_vi, train_data, bijectors, trainables
)

# Mini-batch random keys and iteration loop numbers to scan over
keys = jax.random.split(key, n_iters)
iter_nums = jnp.arange(n_iters)

# Optimisation step
def step(carry, iter_num__and__key):
iter_num, key = iter_num__and__key
params, hyper_state, moment_state = carry
@@ -316,15 +270,103 @@ def step(carry, iter_num__and__key):
carry = params, hyper_state, moment_state
return carry, loss_val

# Display progress bar if verbose is True
if verbose:
step = progress_bar_scan(n_iters, log_rate)(step)

# Run the optimisation loop
(params, _, _), history = jax.lax.scan(
step, (params, hyper_state, moment_state), (iter_nums, keys)
)

# Transform final params to constrained space
params = constrain(params, bijectors)
inf_state = InferenceState(params=params, history=history)
return inf_state

return InferenceState(params=params, history=history)

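Finally, a sketch of the natural-gradient driver. `stochastic_vi` stands in for a GPJax stochastic variational inference object built elsewhere; the remaining keyword names are inferred from the body above, so treat them as assumptions.

import jax.random as jr
import optax as ox

# stochastic_vi: placeholder for a StochasticVI-style object (not built here).
inference_state = fit_natgrads(
    stochastic_vi,
    parameter_state,
    train_data=D,
    moment_optim=ox.sgd(1.0),  # step size for the natural-gradient moment updates
    hyper_optim=ox.adam(learning_rate=0.01),
    key=jr.PRNGKey(0),
    n_iters=2000,
)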

def progress_bar_scan(n_iters: int, log_rate: int) -> Callable:
"""Progress bar for Jax.lax scans (adapted from https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/)."""

tqdm_bars = {}
remainder = n_iters % log_rate

def _define_tqdm(args: Any, transform: Any) -> None:
"""Define a tqdm progress bar."""
tqdm_bars[0] = tqdm(range(n_iters))

def _update_tqdm(args: Any, transform: Any) -> None:
"""Update the tqdm progress bar with the latest objective value."""
loss_val, arg = args
tqdm_bars[0].update(arg)
tqdm_bars[0].set_postfix({"Objective": f"{loss_val: .2f}"})

def _close_tqdm(args: Any, transform: Any) -> None:
"""Close the tqdm progress bar."""
tqdm_bars[0].close()

def _callback(cond: bool, func: Callable, arg: Any) -> None:
"""Callback a function for a given argument if a condition is true."""
dummy_result = 0

def _do_callback(_) -> int:
"""Perform the callback."""
return host_callback.id_tap(func, arg, result=dummy_result)

def _not_callback(_) -> int:
"""Do nothing."""
return dummy_result

_ = lax.cond(cond, _do_callback, _not_callback, operand=None)


def _update_progress_bar(loss_val: Float[Array, "1"], iter_num: int) -> None:
"""Updates tqdm progress bar of a JAX scan or loop."""

# Conditions for iteration number
is_first: bool = iter_num == 0
is_multiple: bool = (iter_num % log_rate == 0) & (iter_num != n_iters - remainder)
is_remainder: bool = iter_num == n_iters - remainder
is_last: bool = iter_num == n_iters - 1

# Define progress bar, if first iteration
_callback(is_first, _define_tqdm, None)

# Update progress bar, if multiple of log_rate
_callback(is_multiple, _update_tqdm, (loss_val, log_rate))

# Update progress bar, if remainder
_callback(is_remainder, _update_tqdm, (loss_val, remainder))

# Close progress bar, if last iteration
_callback(is_last, _close_tqdm, None)


def _progress_bar_scan(body_fun: Callable) -> Callable:
"""Decorator that adds a progress bar to `body_fun` used in `lax.scan`."""

def wrapper_progress_bar(carry: Any, x: Union[tuple, int]) -> Any:

# Get iteration number
if type(x) is tuple:
iter_num, *_ = x
else:
iter_num = x

# Compute iteration step
result = body_fun(carry, x)

# Get loss value
*_, loss_val = result

# Update progress bar
_update_progress_bar(loss_val, iter_num)

return result

return wrapper_progress_bar

return _progress_bar_scan

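A minimal, self-contained sketch of the decorator on a toy scan; the body function here is invented for illustration and only needs to return a (carry, loss) pair, with the loss last.

import jax
import jax.numpy as jnp

n_iters = 100

@progress_bar_scan(n_iters=n_iters, log_rate=10)
def body(carry, iter_num):
    # Toy update; the decorator reads the loss from the last element
    # of the returned tuple.
    carry = carry * 0.99
    return carry, carry

final_carry, losses = jax.lax.scan(body, jnp.float32(1.0), jnp.arange(n_iters))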

__all__ = [
4 changes: 2 additions & 2 deletions gpjax/config.py
@@ -21,7 +21,7 @@

__config = None

FillTriangular = dx.Chain([tfb.FillTriangular()]) # TODO: Dan to chain methods.
FillTriangular = dx.Chain([tfb.FillTriangular(), ]) # TODO: Dan to chain methods.
Identity = dx.Lambda(forward=lambda x: x, inverse=lambda x: x)
Softplus = dx.Lambda(
forward=lambda x: jnp.log(1 + jnp.exp(x)),
@@ -75,7 +75,7 @@ def add_parameter(param_name: str, bijection: dx.Bijector) -> None:

Args:
param_name (str): The name of the parameter that is to be added.
bijection (tfb.Bijector): The bijection that should be used to unconstrain the parameter's value.
bijection (dx.Bijector): The bijection that should be used to unconstrain the parameter's value.
"""
lookup_name = f"{param_name}_transform"
get_defaults()
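
For example, a downstream user could register a positivity-constrained parameter (the parameter name below is hypothetical):

# Hypothetical parameter name; Softplus is the bijection defined above.
add_parameter("my_positive_param", Softplus)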