In [11]:
import jax
import jax.numpy as jnp
from jax import random, jit
from typing import Any
from jax.tree_util import tree_map, tree_leaves
import jax
import haiku as hk
import jax.numpy as jnp
from jax import random
import numpy as np

#@jit
def get_noise(params, key):
    noise = {}
    keys = list(random.split(key, num=len(tree_leaves(params))))
    for p in params:

        noise[p] = {}
        noise[p]["w"] = jnp.array(random.normal(keys.pop(0), (list(params[p]["w"].shape))))
        noise[p]["b"] = jnp.array(random.normal(keys.pop(0), (list(params[p]["b"].shape))))

    return noise

class SGLDOptimizer:
    def __init__(self, step_size, noise_scale, key=random.PRNGKey(0)):
        self.step_size = step_size
        self.noise_scale = noise_scale
        self.key = key

    def clip_grads(self, grads, max_norm):
        total_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in tree_leaves(grads)))
        clip_coef = max_norm / (total_norm + 1e-6)  # add small epsilon to prevent division by zero
        clipped_grads = tree_map(lambda g: jnp.where(total_norm < max_norm, g, g * clip_coef), grads)
        return clipped_grads

    def step(self, params, grads):
        keys = random.split(self.key, num=len(tree_leaves(params)))
        noise = get_noise(params, self.key)
        #ns = [random.normal(keys[i], shapes[i]) for i in range(len(shapes))]
        #noise = jax.tree_map(lambda key, shape: random.normal(key, shape), keys, shapes)
        #noise = jax.tree_map(lambda n: n, ns)
        self.key, _ = random.split(self.key)
        return jax.tree_map(lambda p, g, n: p - self.step_size * g + self.noise_scale * n, params, grads, noise)

    def update(self, params, grads):
        clipped_grads = self.clip_grads(grads, max_norm=0.001)
        return self.step(params, grads)

    def update_key(self, key):
        self.key = random.PRNGKey(key)

    def init(self, s):
        return None

In [39]:
from experimental.moons_t import moon_task
import optax
from experimental.breastcancer_t import breast_task
from experimental.forest_t import forest_task

all = []
for step in [4]:
    opts = [SGLDOptimizer(step_size=1, noise_scale=0.5)]
    val_accs = []
    test_accs = []

    for key in [1]:# [42,43,44, 45, 46]:
        temp = []
        test = []
        for idx, opt in enumerate(opts):
            task = forest_task(opt, state=key)
            opt.update_key(key)
            test_acc, vals = task.train(50)
            temp.append(vals)
            test.append(test_acc)

        val_accs.append(temp)
        test_accs.append(test)
        all.append(test)
print(all)

[[-0.16930796 -0.5336697   0.50510466 -0.3733526  -0.34466195  0.8229179
  -0.914063    0.27780005 -0.3927096   0.05986269  0.54167914 -0.113103
   0.50406206 -0.07532346 -0.17175382  0.00976585  0.33459872  0.08726917
   0.83390874 -0.28129542]
 [ 0.02745692  0.5155683  -0.17874728 -0.51364464 -0.9502729   1.2443926
  -0.7738298   0.0233749   0.19563253 -0.26782173  0.6324736   0.02083994
  -0.6476148  -0.6033127  -0.6130086   0.6671368   0.11945832 -0.5239995
   0.29456615 -0.15642594]
 [-0.5066736  -0.66881573  0.6241565  -0.22436388 -0.94914     1.1596916
   0.11991     0.0171384  -0.40513635 -0.07098374 -0.32773316 -0.7045973
   0.55826277 -0.85131466 -0.76413023 -1.5970472  -0.47501498  0.06981955
   0.5514524  -0.5285338 ]
 [-0.45557302 -1.0361345   0.16729447  0.20838869 -0.20008455  0.4492564
  -0.15803963 -0.28258362  0.6466403  -0.505051    0.4839645   0.33382368
   0.60986924  0.5531687  -0.20233342 -0.66703445 -1.3096002  -0.5247959
   0.48440304  0.33828533]
 [-0.02162497

In [2]:
from experimental.moons_t import moon_task
import optax
from experimental.breastcancer_t import breast_task

opts = [optax.sgd(0.005), optax.adam(0.0025), optax.noisy_sgd(0.005, 0.001, 0.75), optax.adamw(0.001)]
opts = [SGLDOptimizer(step_size=1, noise_scale=0.05)]
val_accs = []
test_accs = []

for key in [42,43,44, 45, 46]:
    temp = []
    test = []
    for idx, opt in enumerate(opts):
        task = moon_task(state=key)
        opt.update_key(key)
        test_acc, vals = task.train(opt)
        temp.append(vals)
        test.append(test_acc)

    val_accs.append(temp)
    test_accs.append(test)
print(test_accs)

[[Array(0.866, dtype=float32)], [Array(0.6228, dtype=float32)], [Array(0.5712, dtype=float32)], [Array(0.6192, dtype=float32)], [Array(0.72679996, dtype=float32)]]


In [None]:
import matplotlib.pyplot as plt
import numpy as np

vals_np = np.array(val_accs)


# Compute the column-wise mean
column_means = vals_np.mean(axis=0)
std = vals_np.std(axis=0)

# Create plot
plt.figure(figsize=(6, 6))

# Assuming the x-axis represents the epochs
epochs = np.arange(1, 10001)  # or whatever the number of epochs/steps is

# Assuming the optimizers are as follows
optimizers = ['SGD', 'Adam', 'SGLD', 'AdamW']

for idx, optimizer in enumerate(optimizers):
    plt.plot(epochs, column_means[idx], label=optimizer)
    #plt.fill_between(epochs, column_means[idx] - std[idx], column_means[idx] + std[idx], alpha=0.2)

# for idx, optimizer in enumerate(optimizers):
#     plt.plot(epochs[1000:], column_means[idx][1000:], label=optimizer)
#     plt.fill_between(epochs[1000:], column_means[idx][1000:] - std[idx][1000:], column_means[idx][1000:] + std[idx][1000:], alpha=0.2)

plt.xlabel('Epoch')
plt.ylabel('Average Performance')
plt.title('Circles task optimiser validation performance')
plt.legend()

plt.show()

In [5]:
tests_np = np.array(test_accs)


# Compute the column-wise mean
column_means = tests_np.mean(axis=0)
std = tests_np.std(axis=0)
print(test_accs)
print(column_means)
print(std)

[[Array(0.49319997, dtype=float32)], [Array(0.49319997, dtype=float32)], [Array(0.5068, dtype=float32)], [Array(0.4116, dtype=float32)], [Array(0.49319997, dtype=float32)]]
[0.4796]
[0.03440558]
