In [None]:
!pip install jax
!pip install dm-haiku

Collecting dm-haiku
  Downloading dm_haiku-0.0.10-py3-none-any.whl (360 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m360.3/360.3 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
Collecting jmp>=0.0.2 (from dm-haiku)
  Downloading jmp-0.0.4-py3-none-any.whl (18 kB)
Installing collected packages: jmp, dm-haiku
Successfully installed dm-haiku-0.0.10 jmp-0.0.4


In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import optax

In [None]:
# List the devices JAX reports and keep a handle on the first one.
# The previous version evaluated `jax.devices()[0]` inside an `if` and threw the
# result away, so the cell neither selected nor displayed a device.
devices = jax.devices()
default_device = devices[0] if devices else None

# A bare name on the last line makes the notebook display it, which shows
# whether the first device is a CPU, GPU or TPU.
default_device


In [None]:
def init_bayes_linear(input_dim, output_dim, prior, rng):
    """Create the variational parameters of one Bayesian linear layer.

    Args:
        input_dim: Number of input features; first axis of the weight tensors.
        output_dim: Number of output features; second axis of the weight
            tensors and length of the bias vectors.
        prior: Prior description stored unchanged under the ``'prior'`` key.
            ``bayes_linear_normalq`` reads ``prior['mu']`` and ``prior['sigma']``.
        rng: JAX PRNG key. It is split into four subkeys, one per tensor.

    Returns:
        Dict with ``'weight_mus'`` and ``'weight_rhos'`` of shape
        ``(input_dim, output_dim)``, ``'bias_mus'`` and ``'bias_rhos'`` of shape
        ``(output_dim,)``, and ``'prior'``. Means are uniform in
        ``[-0.05, 0.05]`` and rhos are uniform in ``[-2.0, -1.0]``.
    """
    # One independent subkey per tensor. Drawing every tensor from the same key
    # makes them functions of one another: with identical shapes the uniform
    # bits are identical, so the rhos would just be a rescaled copy of the mus.
    w_mu_key, w_rho_key, b_mu_key, b_rho_key = jax.random.split(rng, 4)

    weight_shape = (input_dim, output_dim)
    bias_shape = (output_dim,)

    return {
        'weight_mus': jax.random.uniform(w_mu_key, weight_shape, minval=-0.05, maxval=0.05),
        'weight_rhos': jax.random.uniform(w_rho_key, weight_shape, minval=-2.0, maxval=-1.0),
        'bias_mus': jax.random.uniform(b_mu_key, bias_shape, minval=-0.05, maxval=0.05),
        'bias_rhos': jax.random.uniform(b_rho_key, bias_shape, minval=-2.0, maxval=-1.0),
        'prior': prior,
    }

def _next_fallback_key():
    """Return a PRNG key that is different on every call.

    Only used when the caller passes no key. The state is a plain Python call
    counter stored on the function object, so it advances on every eager call.
    Inside ``jax.jit`` the counter would only advance at trace time, so pass
    ``rng`` explicitly there.
    """
    _next_fallback_key.calls = getattr(_next_fallback_key, 'calls', 0) + 1
    return jax.random.PRNGKey(_next_fallback_key.calls)


def _gaussian_kl(mus, stds, prior_mu, prior_sigma):
    """Sum over all elements of KL(N(mus, stds^2) || N(prior_mu, prior_sigma^2))."""
    prior_cov, varpost_cov = prior_sigma ** 2, stds ** 2
    kl = 0.5 * (jnp.log(prior_cov / varpost_cov)).sum() - 0.5 * stds.size
    kl = kl + 0.5 * (varpost_cov / prior_cov).sum()
    kl = kl + 0.5 * ((mus - prior_mu) ** 2 / prior_cov).sum()
    return kl


def bayes_linear_normalq(params, x, sample=True, rng=None):
    """Apply a Bayesian linear layer with a factorised Gaussian posterior.

    Args:
        params: Dict with ``'weight_mus'``, ``'weight_rhos'``, ``'bias_mus'``,
            ``'bias_rhos'`` and ``'prior'`` (a mapping with ``'mu'`` and
            ``'sigma'``), as produced by ``init_bayes_linear``.
        x: Inputs; combined with the weights through ``jnp.dot(x, weights)``.
        sample: If true, draw weights and biases from the posterior with the
            reparameterisation trick and return the KL term. If false, use the
            posterior means and report a KL of ``0.0``.
        rng: Optional JAX PRNG key for the noise. The same key always gives the
            same sample. When omitted, a fresh key is taken for every call, so
            repeated calls produce different samples.

    Returns:
        Tuple ``(output, KL_loss)``.
    """
    weight_mus = params['weight_mus']
    bias_mus = params['bias_mus']

    if not sample:
        # Inference without sampling: use the means, KL loss is reported as zero.
        return jnp.dot(x, weight_mus) + bias_mus, 0.0

    # Hard-coded keys here would make every "sample" identical, which collapses
    # Monte Carlo averages and predictive spread to a single draw.
    if rng is None:
        rng = _next_fallback_key()
    weight_key, bias_key = jax.random.split(rng)

    # Sample gaussian noise for each weight and each bias
    weight_epsilons = jax.random.normal(weight_key, weight_mus.shape)
    bias_epsilons = jax.random.normal(bias_key, bias_mus.shape)

    # Stds are softplus(rho); jax.nn.softplus is the overflow-safe form of
    # log(1 + exp(rho)).
    weight_stds = jax.nn.softplus(params['weight_rhos'])
    bias_stds = jax.nn.softplus(params['bias_rhos'])

    # Reparameterised posterior samples
    weight_sample = weight_mus + weight_epsilons * weight_stds
    bias_sample = bias_mus + bias_epsilons * bias_stds

    output = jnp.dot(x, weight_sample) + bias_sample

    # The prior travels inside `params`, which the notebook differentiates and
    # feeds to the optimizer as a whole. Blocking the gradient keeps the prior
    # fixed instead of letting SGD train its mean and scale.
    prior = params['prior']
    prior_mu = jax.lax.stop_gradient(jnp.asarray(prior['mu']))
    prior_sigma = jax.lax.stop_gradient(jnp.asarray(prior['sigma']))

    # Closed-form KL between posterior and prior, weights plus biases
    KL_loss = _gaussian_kl(weight_mus, weight_stds, prior_mu, prior_sigma)
    KL_loss = KL_loss + _gaussian_kl(bias_mus, bias_stds, prior_mu, prior_sigma)

    return output, KL_loss


In [None]:
# Root PRNG key; the layer initialisers below receive it.
rng = jax.random.PRNGKey(0)

# Gaussian prior (mean and standard deviation) handed to every layer.
prior = dict(mu=0.0, sigma=0.1)

In [None]:
# Hidden layer: Bayesian linear map followed by a tanh non-linearity
def hidden_layer(x, params, sample=True):
    """Run one Bayesian linear layer and squash its output with tanh.

    Args:
        x: Layer input, forwarded to ``bayes_linear_normalq``.
        params: Parameter dict of this layer.
        sample: Forwarded to ``bayes_linear_normalq``.

    Returns:
        Tuple ``(tanh(linear output), kl_loss)``.
    """
    pre_activation, kl_loss = bayes_linear_normalq(params, x, sample=sample)
    activation = jax.nn.tanh(pre_activation)
    return activation, kl_loss

# Bayesian network with a single tanh hidden layer and a linear output layer
def neural_network(x, params, sample=True):
    """Forward pass through the hidden layer and the output layer.

    Args:
        x: Network input.
        params: Dict holding the per-layer parameter dicts under
            ``'hidden_layer'`` and ``'output_layer'``.
        sample: Forwarded to both layers.

    Returns:
        Tuple ``(output, kl_loss)`` where ``kl_loss`` is the sum of the KL
        terms reported by the two layers.
    """
    hidden_params, output_params = params['hidden_layer'], params['output_layer']

    hidden_output, hidden_kl_loss = hidden_layer(x, hidden_params, sample=sample)
    output, output_kl_loss = bayes_linear_normalq(output_params, hidden_output, sample=sample)

    return output, hidden_kl_loss + output_kl_loss


In [None]:
# Variational parameters of the hidden layer: 1 input feature -> 64 hidden units
hidden_layer_params = init_bayes_linear(1, 64, prior, rng)

In [None]:
from sklearn.model_selection import train_test_split
# Generate evenly spaced x values as a single-feature column
x = np.linspace(0, 1, 51).reshape(-1, 1)

# Calculate y values: one period of a sine wave plus Gaussian noise (std 0.1).
# The noise comes from a seeded generator so the dataset is the same on every
# run; the global NumPy RNG is unseeded, which made each run use different data
# even though the train/eval split below is seeded.
noise_rng = np.random.default_rng(42)
y = np.sin(x * 2 * np.pi) + noise_rng.normal(scale=0.1, size=x.shape)

# Split the dataset into a training set and an evaluation set (50/50)
X_train, X_eval, y_train, y_eval = train_test_split(x, y, test_size=0.5, random_state=42)



In [None]:
# The hidden layer was initialised from `rng` itself. Give the output layer its
# own key so the two layers are not built from the same random stream.
output_layer_rng = jax.random.fold_in(rng, 1)

params = {
    'hidden_layer': hidden_layer_params,
    'output_layer': init_bayes_linear(input_dim=64, output_dim=1, prior=prior, rng=output_layer_rng),
    # Log of the observation-noise std; exp(0) = 1 at initialisation.
    'log_noise' : jnp.array([0.0])
}

# Smoke test: one stochastic forward pass over the full input grid.
output, kl_loss = neural_network(x, params, sample=True)

In [None]:
import jax.numpy as jnp

def log_gaussian_loss(output, target, sigma, no_dim):
    """Negative Gaussian log-likelihood, summed over all elements.

    The constant ``0.5 * log(2 * pi)`` term is not included.

    Args:
        output: Predicted values.
        target: Observed values, broadcastable against ``output``.
        sigma: Standard deviation of the Gaussian likelihood.
        no_dim: Multiplier applied to the ``log(sigma)`` normalisation term.

    Returns:
        Scalar: minus the sum of ``-no_dim * log(sigma)`` and
        ``-0.5 * (target - output)^2 / sigma^2`` after broadcasting.
    """
    residual = target - output
    quadratic_term = -0.5 * jnp.square(residual) / jnp.square(sigma)
    normaliser_term = -no_dim * jnp.log(sigma)

    log_likelihood = jnp.sum(normaliser_term + quadratic_term)
    return -log_likelihood


In [None]:
def loss(params, x, y, rng, num_samples=10):
  """Monte Carlo estimate of the per-datapoint negative ELBO.

  Args:
    params: Network parameter tree; ``params['log_noise']`` holds the log of
      the observation-noise standard deviation.
    x: Inputs; ``x.shape[0]`` is used as the number of data points.
    y: Targets matching the network output.
    rng: Accepted so existing callers keep working; not consumed here because
      ``neural_network`` does not take a key.
    num_samples: Number of stochastic forward passes to average over.

  Returns:
    Tuple ``(total_loss, fit_loss, kl_loss)``. ``fit_loss`` and ``kl_loss`` are
    sums over the ``num_samples`` passes; ``total_loss`` is their sum divided by
    ``num_samples * x.shape[0]``.
  """
  fit_loss = 0.0
  kl_loss = 0.0
  noise_std = jnp.exp(params['log_noise'])  # loop-invariant

  for _ in range(num_samples):
    predictions, sample_kl = neural_network(x, params, sample=True)
    fit_loss += log_gaussian_loss(predictions, y, noise_std, 1)
    # Accumulate like fit_loss. Overwriting here kept only one pass's KL while
    # the normaliser below divides by num_samples, which shrank the KL term by
    # that factor relative to the data-fit term.
    kl_loss += sample_kl

  total_loss = (fit_loss + kl_loss) / (num_samples * x.shape[0])

  return total_loss, fit_loss, kl_loss


def elbo_loss(params, x, y, rng):
  """Scalar training objective: the first element returned by ``loss``.

  Exists so that ``jax.grad`` has a scalar-valued function to differentiate.
  """
  loss_terms = loss(params, x, y, rng)
  total_loss = loss_terms[0]
  return total_loss

In [None]:
# Training hyper-parameters.
# NOTE(review): `beta` is not referenced by any other cell visible here.
beta, num_steps = 1, 500

# Plain SGD with learning rate 0.1, plus its state for the parameter tree
opt = optax.sgd(0.1)
opt_state = opt.init(params)

In [None]:
# Best parameters seen so far and their loss. Starting from +inf guarantees the
# first measured loss replaces it; a finite sentinel such as 1000 would leave
# `best_params` at the untrained values whenever the loss stays above it.
best_params = params
best_loss = float('inf')

In [None]:
# Build the gradient function of the scalar objective once and reuse it.
elbo_grad = jax.grad(elbo_loss)

for step in range(num_steps):
    # One SGD step on the training split
    grads = elbo_grad(params, X_train, y_train, rng)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    # Re-evaluate the loss terms on the updated parameters
    total_loss, fit_loss, kl_loss = loss(params, X_train, y_train, rng)

    # Keep the parameters with the lowest training loss seen so far
    improved = total_loss < best_loss
    if improved:
        best_loss, best_params = total_loss, params

    # Report every 10th step only
    if step % 10 != 0:
        continue
    print(f"Step {step}, Loss: {total_loss:.4f}, Fit Loss: {fit_loss:.4f}, KL Loss: {kl_loss:.4f}, noise: { jnp.exp(params['log_noise'])[0]:.4f}")
    # The evaluation loss could be printed here as well if needed.

# Plotting of the predictions and uncertainties follows in the next cells.


Step 0, Loss: 1.9046, Fit Loss: 172.4726, KL Loss: 303.6797, noise: 1.3111
Step 10, Loss: 1.5681, Fit Loss: 125.5584, KL Loss: 266.4687, noise: 1.0851
Step 20, Loss: 1.1971, Fit Loss: 94.9912, KL Loss: 204.2810, noise: 0.9425
Step 30, Loss: 0.8637, Fit Loss: 72.8573, KL Loss: 143.0558, noise: 0.8477
Step 40, Loss: 0.3254, Fit Loss: 58.2337, KL Loss: 23.1229, noise: 0.7879
Step 50, Loss: 0.2975, Fit Loss: 50.1466, KL Loss: 24.2182, noise: 0.7535
Step 60, Loss: 0.2833, Fit Loss: 44.8599, KL Loss: 25.9667, noise: 0.7334
Step 70, Loss: 0.2747, Fit Loss: 41.2888, KL Loss: 27.3760, noise: 0.7205
Step 80, Loss: 0.2688, Fit Loss: 38.7290, KL Loss: 28.4617, noise: 0.7118
Step 90, Loss: 0.2640, Fit Loss: 36.7094, KL Loss: 29.2813, noise: 0.7053
Step 100, Loss: 0.2591, Fit Loss: 34.8804, KL Loss: 29.9048, noise: 0.6999
Step 110, Loss: 0.2533, Fit Loss: 32.9304, KL Loss: 30.4067, noise: 0.6947
Step 120, Loss: 0.2455, Fit Loss: 30.5155, KL Loss: 30.8632, noise: 0.6887
Step 130, Loss: 0.2342, Fit Lo

In [None]:
# Number of samples to draw from the posterior
num_samples = 100

# One stochastic forward pass per posterior sample (predictions only; the KL
# term returned alongside is not needed here)
predicted_values = [
    neural_network(x, best_params, sample=True)[0] for _ in range(num_samples)
]

# Stack once and reduce over the sample axis
stacked_predictions = jnp.stack(predicted_values)
mean_predictions = jnp.mean(stacked_predictions, axis=0)
std_predictions = jnp.std(stacked_predictions, axis=0)

# Total predictive std = sqrt(noise variance + spread across posterior samples).
# `log_noise` stores the LOG of the noise std (the loss uses exp(log_noise) as
# sigma), so it must be exponentiated before squaring.
noise_std = jnp.exp(best_params['log_noise'])
uncertainity = (noise_std ** 2 + std_predictions ** 2) ** 0.5

In [None]:
import matplotlib.pyplot as plt

# Sort everything by x with one shared permutation so that points, curve and
# band line up
sorted_indices = np.argsort(x.flatten())
x_sorted = x[sorted_indices]
y_sorted = y[sorted_indices]
mean_predictions_sorted = mean_predictions[sorted_indices]
std_predictions_sorted = std_predictions[sorted_indices]
# The band half-width has to follow the same ordering as the means it is drawn
# around; previously the unsorted array was combined with sorted means.
uncertainty_sorted = uncertainity[sorted_indices]

# Width of the shaded band, in standard deviations
num_std = 1
lower_band = (mean_predictions_sorted - num_std * uncertainty_sorted).squeeze()
upper_band = (mean_predictions_sorted + num_std * uncertainty_sorted).squeeze()

# Create a figure and axis
fig, ax = plt.subplots(figsize=(12, 6))

# Plot the observed data
ax.scatter(x_sorted, y_sorted, label='True Data', marker="o", alpha=0.7, lw=0.6, color='black')

# Plot the mean predictions
ax.plot(x_sorted, mean_predictions_sorted, label='Mean Predictions', color='blue')

# Shade mean +/- num_std predictive standard deviations
std_plt = ax.fill_between(
        x_sorted.squeeze(), lower_band, upper_band,
        color='lightgray', alpha=0.5 / num_std, label=f'Uncertainty ({num_std} std dev)')

# Set axis labels and title
ax.set_xlabel('X Data')
ax.set_ylabel('Y Data')
ax.set_title('Bayesian Neural Network Prediction Plot')

# Add a legend and grid
ax.legend()
ax.grid(True, linestyle='--', alpha=0.5)

# Show the plot
plt.show()
