# Imports

In [None]:
#!pip install dm-haiku optax
#!pip install jax
#!pip install plotly

In [None]:
import jax
from jax import config
config.update("jax_enable_x64", True)
config.update('jax_default_matmul_precision', 'float32')
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
import optax
import haiku as hk
import numpy as np
import jax.numpy as jnp
from typing import Iterable, Iterator, NamedTuple, TypeVar, Any, MutableMapping, Tuple
import time
import math
import datetime
import json
import os
import plotly.graph_objs as go
import plotly.io as pio
import plotly.express as px
pio.renderers.default = 'colab'
from plotly.subplots import make_subplots


In [3]:
import plotly.graph_objs as go
from plotly.subplots import make_subplots

def plot_training(all_metrics):
  fig = make_subplots(rows=2, cols=2, subplot_titles=("Loss", "Accuracy", "L1 Norm", "L2 Norm"), vertical_spacing=0.1)

  color_dict = {'train': 'red', 'eval': 'blue'}
  for i, metric in enumerate(['loss', 'acc']):
    for t in ['train', 'eval']:
      trace = go.Scatter(
        x=[d['step'] for d in all_metrics],
        y=[d[f'{t}_{metric}'] for d in all_metrics],
        mode='lines+markers',
        name=f'{t.capitalize()} {metric.capitalize()}',
        line=dict(color=color_dict[t]),
        yaxis='y1' if metric == 'loss' else 'y2'
      )
      fig.add_trace(trace, row=1, col=i+1)
      if metric == 'loss':
        fig.update_yaxes(type='log', title_text=f'{metric.capitalize()}', row=1, col=i+1)
      else:
        fig.update_yaxes(title_text=f'{metric.capitalize()}', row=1, col=i+1)

  # Plotting L1 and L2 norms
  for i, norm in enumerate(['l1_norm', 'l2_norm']):
    trace = go.Scatter(
      x=[d['step'] for d in all_metrics],
      y=[d[norm] for d in all_metrics],
      mode='lines+markers',
      name=norm.replace('_', ' ').capitalize()
    )
    fig.add_trace(trace, row=2, col=i+1)
    fig.update_yaxes(title_text=norm.replace('_', ' ').capitalize(), row=2, col=i+1)

  fig.update_xaxes(row=1, col=1, range=[0, max([d['step'] for d in all_metrics])])
  fig.update_xaxes(row=1, col=2, range=[0, max([d['step'] for d in all_metrics])])
  fig.update_xaxes(row=2, col=1, range=[0, max([d['step'] for d in all_metrics])])
  fig.update_xaxes(row=2, col=2, range=[0, max([d['step'] for d in all_metrics])])

  fig.update_layout(height=800, hovermode='closest')
  fig.show()


In [4]:
def plot_weights(state):
  key_subkey_array = []
  for key, subdict in state.params.items():
    for subkey, array in subdict.items():
      key_subkey_array.append((key, subkey, array))

  zmin = min([np.min(array) for key, subkey, array in key_subkey_array]).item()
  zmax = max([np.max(array) for key, subkey, array in key_subkey_array]).item()
  zval = max(abs(zmin), zmax)*1

  N = len(key_subkey_array)
  grid_size = math.ceil(math.sqrt(N)) # find the nearest square grid

  fig = make_subplots(rows=grid_size, cols=grid_size, subplot_titles=[f"{key} {subkey}" for key, subkey, array in key_subkey_array], vertical_spacing=.1)

  for idx, (key, subkey, array) in enumerate(key_subkey_array):
      row = idx // grid_size + 1  # Calculate the appropriate row, col placement in grid
      col = idx % grid_size + 1
      trace = go.Heatmap(z=array, zmin=zval*-1, zmax=zval*1, zmid=0, colorscale='RdBu', name=f'{key} {subkey}')
      # trace = go.Heatmap(z=array, zmin=zval*-0.1, zmax=zval*0.1, zmid=0, colorscale='RdBu', name=f'{key} {subkey}')
      fig.add_trace(trace, row=row, col=col)

  fig.update_layout(height=400*grid_size, width=400*grid_size)
  fig.show()

In [5]:
def plot_weight_checkpoints(hidden_w):
  import plotly.graph_objects as go

  fig = go.Figure()
  checkpoints, rows, cols = hidden_w.shape

  for row in range(rows):
    for col in range(cols):
      y_values = hidden_w[:, row, col]
      color = "rgb(000,255,0)" if row < 3 else "rgba(0, 0, 0, 0.1)"
      fig.add_trace(go.Scatter(
        y=y_values,
        mode='lines',
        showlegend=False,
        line=dict(color=color),
      ))

  fig.update_layout(
      title="hidden_w over training",
      xaxis_title="Checkpoint",
      yaxis_title="Value",
  )

  fig.show()

In [6]:
class TrainingState(NamedTuple):
  """Container for the training state."""
  params: hk.Params
  opt_state: optax.OptState
  rng: jax.Array
  step: jax.Array


In [7]:
class NpEncoder(json.JSONEncoder):
  """Save NP as json."""

  def default(self, o):
    if isinstance(o, np.integer):
      return int(o)
    if isinstance(o, np.floating):
      return float(o)
    if isinstance(o, np.ndarray):
      return o.tolist()

    if isinstance(o, jnp.integer):
      return int(o)
    if isinstance(o, jnp.floating):
      return float(o)
    if isinstance(o, jnp.ndarray):
      return o.tolist()

    return super(NpEncoder, self).default(o)


# Hyper parameters

[Hidden Progress in Deep Learning](https://arxiv.org/pdf/2207.08799.pdf) // section 3.1, page 6

In [8]:
hyper = {
  'task': 'sparse_parity',
  'n': 30,
  'k': 3,
  'train_size': 900,
  'test_size': 1000,

  'hidden_size': int(32),
  'loss_fn': 'cross_entropy', # ['hinge', 'cross_entropy'],

  'optimizer': 'adam', # ['sgd', 'adam', 'adamw]
  'regularization': 'l1', # ['l1', 'l2', 'none],
  'weight_decay' : 2e-05,

  # 'optimizer': 'adamw', # ['sgd', 'adam', 'adamw]
  # 'regularization': 'l2', # ['l1', 'l2', 'none],
  # 'weight_decay' : .1,

  'w_init_scale': 2,
  'learning_rate': 0.003,
  'warm_up_steps': 1,
  'b1': .99,
  'b2': .98,

  # 'batch_size': 512,
  'log_every': 16,
  'save_every': 16,
  'max_steps': 4000,
  'seed': 5,
  'sweep_slug': 'sparse_parity'
}

# hyper['log_every'] = int(hyper['max_steps']/250)
# hyper['save_every'] = int(hyper['max_steps']/250)

# Model creation

In [9]:
unique_binary_strings = set()
while len(unique_binary_strings) < hyper['train_size'] + hyper['test_size']:
  binary_string = tuple(np.random.randint(2, size=hyper['n']))
  unique_binary_strings.add(binary_string)

inputs = np.array(list(unique_binary_strings), dtype=np.float32)
outputs = np.sum(inputs[:, :hyper['k']], axis=-1) % 2

# inputs = np.where(inputs==0, -1, inputs) # should we remap to 0?
# add a column of ones
ones_column = np.ones((inputs.shape[0], 1), dtype=np.float32)
inputs = np.concatenate((inputs, ones_column), axis=1)

indices = np.random.permutation(len(inputs))
split_idx = int(hyper['train_size'])
train_batch = inputs[indices[:split_idx]], outputs[indices[:split_idx]]
eval_batch = inputs[indices[split_idx:]], outputs[indices[split_idx:]]

In [10]:
  def l1_regularizer(weight_decay):
    def init_fn(state):
      return state

    def update_fn(updates, state, params=None):
      updates = jax.tree_map(
          lambda g, p: g + weight_decay * jnp.sign(p), updates, params
      )
      return updates, state

    return optax.GradientTransformation(init_fn, update_fn)

In [11]:
def l2_regularizer(weight_decay):
  def init_fn(state):
    return state

  def update_fn(updates, state, params=None):
    updates = jax.tree_map(
      lambda g, p: g + weight_decay * p, updates, params
    )
    return updates, state

  return optax.GradientTransformation(init_fn, update_fn)

In [12]:
learning_rate = optax.linear_schedule(0, hyper['learning_rate'], hyper['warm_up_steps'])

if hyper['optimizer'] == 'sgd':
  optimiser_base = optax.sgd(learning_rate)
elif hyper['optimizer'] == 'adam':
  optimiser_base = optax.adam(
      learning_rate, b1=hyper['b1'], b2=hyper['b2']
  )
elif hyper['optimizer'] == 'adamw':
  optimiser_base = optax.adamw(
      learning_rate=learning_rate,
      weight_decay=hyper['weight_decay'],
      b1=hyper['b1'],
      b2=hyper['b2'],
  )
else:
  raise ValueError(f"Unknown optimizer {hyper['optimizer']}.")

if hyper['regularization'] == 'l1':
  if hyper['optimizer'] == 'adamw':
    raise ValueError('Only supporting l1 with adam or sgd optimizer.')
  optimiser = optax.chain(
      l1_regularizer(hyper['weight_decay']),
      optimiser_base,
  )
elif hyper['regularization'] == 'l2':
  if hyper['optimizer'] == 'sgd':
    optimiser = optax.chain(
      l2_regularizer(hyper['weight_decay']),
      optimiser_base,
    )
  elif hyper['optimizer'] == 'adam':
    raise ValueError('Only supporting l2 with adamw optimizer.')
  elif hyper['optimizer'] == 'adamw':
    optimiser = optimiser_base
elif hyper['regularization'] == 'none':
  optimiser = optimiser_base
else:
  raise ValueError(f"Unknown regularization {hyper['regularization']}.")



In [13]:
def forward(inputs):
  w_init = hk.initializers.VarianceScaling(scale=hyper['w_init_scale'], mode='fan_in', distribution='truncated_normal')

  net = hk.Sequential([
    hk.Linear(hyper['hidden_size'], name='hidden', w_init=w_init, with_bias=False),
    jax.nn.relu,
    hk.Linear(1, name='out', with_bias=False),
  ])
  return net(inputs)[:, 0]

@hk.transform
def acc_fn(batch):
  inputs, targets = batch
  outputs = forward(inputs)

  return jnp.mean(jnp.sign(outputs*2 - 1) == jnp.sign(targets*2 - 1))

@hk.transform
def loss_fn(batch):
  inputs, targets = batch
  outputs = forward(inputs)

  if hyper['loss_fn'] == 'cross_entropy':
    epsilon = 1e-7
    outputs_sigmoid = jnp.clip(jax.nn.sigmoid(outputs), epsilon, 1. - epsilon)
    loss = -targets*jnp.log(outputs_sigmoid) - (1 - targets)*jnp.log(1 - outputs_sigmoid)

  if hyper['loss_fn'] == 'hinge':
    targets = 2*targets - 1
    loss = jnp.maximum(0, 1 - targets*outputs)

  return jnp.mean(loss)

@jax.jit
def update(state, batch):
  rng, new_rng = jax.random.split(state.rng)
  loss_and_grad_fn = jax.value_and_grad(loss_fn.apply)
  loss, gradients = loss_and_grad_fn(state.params, rng, batch)

  updates, new_opt_state = optimiser.update(gradients, state.opt_state, state.params)
  new_params = optax.apply_updates(state.params, updates)

  new_state = TrainingState(
    params=new_params,
    opt_state=new_opt_state,
    rng=new_rng,
    step=state.step + 1,
  )
  metrics = {'step': state.step, 'train_loss': loss}
  return new_state, metrics

@jax.jit
def init(rng, batch):
  rng, init_rng = jax.random.split(rng)
  initial_params = loss_fn.init(init_rng, batch)
  return TrainingState(
    params=initial_params,
    opt_state=optimiser.init(initial_params),
    rng=rng,
    step=np.array(0),
  )

In [14]:
# initialise model parameters
state = init(jax.random.PRNGKey(hyper['seed']), train_batch)

all_metrics = []
saved_checkpoints = []
prev_time = time.time()

# Training

In [15]:
for step in range(hyper['max_steps']):
  state, metrics = update(state, train_batch)
  if step % hyper['save_every'] == 0:
    saved_checkpoints.append({'step': step, 'state': state})
  if step % hyper['log_every'] == 0:
    steps_per_sec = hyper['log_every'] / (time.time() - prev_time+0.0001)
    prev_time = time.time()

    l1_norm = 0
    l2_norm = 0
    for param in jax.tree_util.tree_leaves(state.params):
      l1_norm += jnp.sum(jnp.abs(param))
      l2_norm += jnp.sum(jnp.square(param))
    l2_norm = jnp.sqrt(l2_norm)

    metrics |= {
      'train_acc': acc_fn.apply(state.params, state.rng, train_batch),
      'eval_acc': acc_fn.apply(state.params, state.rng, eval_batch),
      'eval_loss': loss_fn.apply(state.params, state.rng, eval_batch),
      'l1_norm': l1_norm,
      'l2_norm': l2_norm,
      'steps_per_sec': steps_per_sec,
    }
    all_metrics.append(metrics)

    print({k: (v.item() if hasattr(v, 'item') else v) for k, v in metrics.items()})



jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).



{'step': 0, 'train_loss': 0.7091265320777893, 'train_acc': 0.4788888990879059, 'eval_acc': 0.4870000183582306, 'eval_loss': 0.7143385410308838, 'l1_norm': 209.67837524414062, 'l2_norm': 8.033172607421875, 'steps_per_sec': 63.430322033878674}
{'step': 16, 'train_loss': 0.6830173134803772, 'train_acc': 0.4933333396911621, 'eval_acc': 0.4870000183582306, 'eval_loss': 0.7022015452384949, 'l1_norm': 207.8489532470703, 'l2_norm': 7.976490020751953, 'steps_per_sec': 29.86813074407524}
{'step': 32, 'train_loss': 0.6706527471542358, 'train_acc': 0.4888888895511627, 'eval_acc': 0.5, 'eval_loss': 0.6998621225357056, 'l1_norm': 208.2666473388672, 'l2_norm': 8.014400482177734, 'steps_per_sec': 160000.0}
{'step': 48, 'train_loss': 0.6606603264808655, 'train_acc': 0.5488889217376709, 'eval_acc': 0.49900001287460327, 'eval_loss': 0.7070659399032593, 'l1_norm': 211.1100616455078, 'l2_norm': 8.132970809936523, 'steps_per_sec': 1010.4104238275379}
{'step': 64, 'train_loss': 0.6497780084609985, 'train_acc

# Plot training

In [26]:
plot_training(all_metrics)

In [25]:
plot_weights(state)

In [18]:
out_w = np.asarray([d['state'].params['out']['w'] for d in saved_checkpoints])
hidden_w = np.asarray([d['state'].params['hidden']['w'] for d in saved_checkpoints])

In [19]:
plot_weight_checkpoints(hidden_w)

# Export model

In [20]:
workdir = 'sparse_parity/'

In [21]:
def save_model(hyper, all_metrics, saved_checkpoints, train_batch):
  """Save model checkpoints to CNS."""
  sweep_str = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
  ckpt_dir = workdir + 'sweeps/' + hyper['sweep_slug'] + '/' + sweep_str + '/'

  os.makedirs(os.path.dirname(ckpt_dir))

  with open(os.path.join(ckpt_dir, 'hyper.json'), 'w') as f:
    f.write(json.dumps(hyper))

  with open(os.path.join(ckpt_dir, 'metrics.json'), 'w') as f:
    f.write(json.dumps(all_metrics, cls=NpEncoder))

  with open(os.path.join(ckpt_dir, 'train_batch.npy'), 'wb') as f:
    np.save(f, train_batch[0])

  with open(os.path.join(ckpt_dir, f'out_t_w.npy'), 'wb') as f:
    np.save(f, np.asarray(out_w.transpose([0, 2, 1])))

  key_subkey_array = []
  for key, subdict in saved_checkpoints[0]['state'].params.items():
    for subkey, array in subdict.items():
      key_subkey_array.append((key, subkey))

  for key, subkey in key_subkey_array:
    slug = (key + '_' +subkey).replace('~', '')
    array = [d['state'].params[key][subkey] for d in saved_checkpoints]
    with open(os.path.join(ckpt_dir, f'{slug}.npy'), 'wb') as f:
      np.save(f, np.asarray(array))

  print(ckpt_dir)
