<a href="https://colab.research.google.com/github/Gr3gP/Misc-Projects/blob/main/Parallel_NN_with_JAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Parallelizing neural networks with JAX

http://willwhitney.com/parallel-training-jax.html

In [1]:
!pip install --upgrade -q git+https://github.com/google/flax.git

  Building wheel for flax (setup.py) ... [?25l[?25hdone


In [2]:
import jax
from jax import numpy as jnp, random as jr
from flax import linen as nn, optim
import numpy as np
import pandas as pd
import altair as alt
import time

##JAX Basics

In [3]:
#scaler function and its gradient
def x_squared(x):
    return x**2
g = jax.grad(x_squared)
x = 2.
print(f'x_squared({x}): {x_squared(x)}')
print(f'd/dx x_squared({x}): {g(x)}')

#use vmap to vectorize the functions
v_x_squared = jax.vmap(x_squared)
v_g = jax.vmap(g)
xs = jnp.linspace(-5, 5, 100)
ys = v_x_squared(xs)
gs = v_g(xs)

x_squared(2.0): 4.0
d/dx x_squared(2.0): 4.0


In [4]:
#stick results into a dataframe and plot
df = pd.concat([
    pd.DataFrame({'x': xs, 'y': ys, 'kind': 'value'}),  #the values f(x)
    pd.DataFrame({'x': xs, 'y': gs, 'kind': 'grad'}),  #the gradients g(x)
])
chart = alt.Chart(df).mark_line(size=3).encode(x='x', y='y', color='kind')
chart

##Two Spirals dataset

In [5]:
def make_spirals(n_samples, noise_std=0., rotations=1.):
    ts = jnp.linspace(0, 1, n_samples)
    rs = ts ** 0.5
    thetas = rs * rotations * 2 * np.pi
    signs = np.random.randint(0, 2, (n_samples,)) * 2 - 1
    labels = (signs > 0).astype(int)

    xs = rs * signs * jnp.cos(thetas) + np.random.randn(n_samples) * noise_std
    ys = rs * signs * jnp.sin(thetas) + np.random.randn(n_samples) * noise_std
    points = jnp.stack([xs, ys], axis=1)
    return points, labels

In [6]:
points, labels = make_spirals(100, noise_std=0.05)
df = pd.DataFrame({'x': points[:, 0], 'y': points[:, 1], 'label': labels})

spirals_x_axis = alt.X('x', scale=alt.Scale(domain=[-1.5, 1.5], nice=False))
spirals_y_axis = alt.Y('y', scale=alt.Scale(domain=[-1.5, 1.5], nice=False))

spiral_chart = alt.Chart(df, width=350, height=300).mark_circle(stroke='white', size=80, opacity=1).encode(
    x=spirals_x_axis, y=spirals_y_axis,
    color=alt.Color('label:N'))
spiral_chart.save('two_spirals.html')
spiral_chart

##A Simple Classifier

In [7]:
class MLPclassifier(nn.Module):
    hidden_layers: int = 2
    hidden_dim: int = 512
    n_classes: int = 2

    @nn.compact
    def __call__(self, x):
        for layer in range(self.hidden_layers):
            x = nn.Dense(self.hidden_dim)(x)
            x = nn.relu(x)
        x = nn.Dense(self.n_classes)(x)
        x = nn.log_softmax(x)
        return x

##Helper Fuinctions for initializing and training

In [8]:
# Somewhat confusingly, instantiating a Flax module gives you an object
# which contains functions, NOT state
classifier_fns = MLPclassifier()

def cross_entropy(logprobs, labels):
    one_hot_labels = jax.nn.one_hot(labels, logprobs.shape[1])
    return -jnp.mean(jnp.sum(one_hot_labels * logprobs, axis=-1))

def loss_fn(params, batch):
    logits = classifier_fns.apply({'params': params}, batch[0])
    loss = jnp.mean(cross_entropy(logits, batch[1]))
    return loss

loss_and_grad_fn = jax.value_and_grad(loss_fn)

##API for Classifier

In [9]:
def init_fn(input_shape, seed):
    rng = jr.PRNGKey(jnp.array(seed, int))
    dummy_input = jnp.ones((1, *input_shape))
    params = classifier_fns.init(rng, dummy_input)['params']
    optimizer_def = optim.Adam(learning_rate=1e-3)
    optimizer = optimizer_def.create(params)
    return optimizer

@jax.jit  # jit makes it go brrr
def train_step_fn(optimizer, batch):
    loss = loss_fn(optimizer.target, batch)
    loss, grad = loss_and_grad_fn(optimizer.target, batch)
    optimizer = optimizer.apply_gradient(grad)
    return optimizer, loss

@jax.jit  # jit makes it go brrr
def predict_fn(optimizer, x):
    x = jnp.array(x)
    return classifier_fns.apply({'params': optimizer.target}, x)

##Running the Network

In [10]:
model_state = init_fn(input_shape=(2,), seed=0)
for i in range(100):
    model_state, loss = train_step_fn(model_state, (points, labels))
print(loss)

0.011014481


In [11]:
def all_preds(model_state):
    grid_size = 50
    width = 1.5
    x0s, x1s = jnp.meshgrid(jnp.linspace(-width, width, grid_size),
                            jnp.linspace(-width, width, grid_size))
    xs = jnp.stack([x0s, x1s]).transpose().reshape((-1,2))
    preds = predict_fn(model_state, xs)
    return xs, preds

xs, preds = all_preds(model_state)

In [12]:
data = {'x': xs[:, 0], 'y': xs[:,1], 'pred': jnp.exp(preds)[:,1]}
df = pd.DataFrame(data)
pred_chart = alt.Chart(df, width=240, height=240, title='Predictions from MLP').mark_square(size=50, opacity=1).encode(
    x=spirals_x_axis,
    y=spirals_y_axis,
    color=alt.Color('pred', scale=alt.Scale(scheme='blueorange')),
)
chart = pred_chart + spiral_chart
chart.save('mlp_pred.html')
chart

##Parallelizing Training

In [13]:
parallel_init_fn = jax.vmap(init_fn, in_axes=(None, 0))
parallel_train_step_fn = jax.vmap(train_step_fn, in_axes=(0, None))

N = 100
seeds = jnp.linspace(0, N -1, N)


model_states = parallel_init_fn((2,), seeds)
for i in range(100):
    model_states, losses = parallel_train_step_fn(model_states, (points, labels))
print(losses)

[0.01101448 0.00828065 0.01413962 0.01402676 0.0120001  0.01741043
 0.00512169 0.00512169 0.0082025  0.00954386 0.01202601 0.0164557
 0.01076173 0.01197245 0.01197245 0.00991478 0.01067753 0.02455184
 0.00819156 0.01156905 0.01042922 0.0081734  0.01310009 0.01529821
 0.01352658 0.01352658 0.00870252 0.01645845 0.01645845 0.01423884
 0.01187097 0.01187097 0.02138541 0.01279157 0.01054461 0.01624773
 0.0122004  0.01287798 0.01047612 0.01744691 0.01001278 0.0117713
 0.01385813 0.00918852 0.0077262  0.00844093 0.01741947 0.02381307
 0.00805076 0.0104053  0.0104053  0.01342507 0.00727525 0.01403918
 0.00982792 0.01512396 0.01512396 0.01597394 0.00459991 0.01817561
 0.01237269 0.0066689  0.0066689  0.00900659 0.00766002 0.01004525
 0.01449491 0.01713726 0.01207718 0.0067038  0.00864453 0.0096061
 0.00977924 0.01745882 0.01134077 0.00939371 0.00814393 0.01109007
 0.01262139 0.00896234 0.01081684 0.01036016 0.00874946 0.01414874
 0.00842146 0.0081682  0.01597267 0.0085169  0.0096699  0.0098504

Plotting each networks predictions

In [14]:
parallel_all_preds = jax.vmap(all_preds)
xs, batched_preds = parallel_all_preds(model_states)
xs = xs[0]

In [15]:
charts = []
for preds in batched_preds:
    data = {'x': xs[:, 0], 'y': xs[:, 1], 'pred': jnp.exp(preds)[:, 1]}
    df = pd.DataFrame(data)
    single_chart = alt.Chart(df, width=240, height=240).mark_square(size=50, opacity=1).encode(
        x=spirals_x_axis, 
        y=spirals_y_axis,
        color=alt.Color('pred', scale=alt.Scale(scheme='blueorange')),
    )
    chart = single_chart + spiral_chart
    charts.append(chart)
chart = alt.hconcat(*charts[:2])
chart.save('multi_mlp_pred.html')

# look how super similar these are!
chart

##Bootstrapped ensembles

In [16]:
def get_first_seed(dataset_index):
    return jr.split(jr.PRNGKey(dataset_index))[0,0]

get_first_seed(0)

DeviceArray(4146024105, dtype=uint32)

In [17]:
@jax.jit
def get_example(data_x, data_y, dataset_index, i):
    '''gets example i from the bootstrapped dataset with index dataset_index.'''
    first_seed = get_first_seed(dataset_index)
    dataset_size = data_x.shape[0]

    #only use dataset_size distinct seeds
    #this ensures out bootstrap=sampled dataset indludes exactly
    #dataset_size points
    i = i % dataset_size

    point_seed = first_seed + i
    point_index = jr.randint(jr.PRNGKey(point_seed), shape=(),
                             minval=0, maxval=dataset_size)
    x_i = jax.lax.dynamic_index_in_dim(data_x, point_index,
                                       keepdims=False)
    y_i = jax.lax.dynamic_index_in_dim(data_y, point_index,
                                       keepdims=False)
    return x_i, y_i

get_example(points, labels, 0, 0)

(DeviceArray([ 0.6852453, -0.6219504], dtype=float32),
 DeviceArray(1, dtype=int32))

In [18]:
def bootstrap_multi_iterator(dataset, dataset_indices):
    """Creates an iterator which, at each step, returns a batch of batches.

    The kth batch is sampled from the bootstrapped resample of `dataset`
    with seed `seeds[k]`."""
    
    batch_size = 32
    dataset_indices = jnp.array(dataset_indices)
    data_x, data_y = dataset
    dataset_size = len(data_x)
    
    get_example_from_dataset = jax.partial(get_example, data_x, data_y)
    
    # for sampling a batch of data from one dataset
    get_batch = jax.vmap(get_example_from_dataset, in_axes=(None, 0))
    # for sampling a batch of data from _each_ dataset
    get_multibatch = jax.vmap(get_batch, in_axes=(0, None))

    def iterate_multibatch():
        """Construct an iterator which runs forever, at each step returning
        a batch of batches."""
        i = 0
        while True:
            indices = jnp.arange(i, i + batch_size, dtype=jnp.int32)
            yield get_multibatch(dataset_indices, indices)
            i += batch_size

    loader_iter = iterate_multibatch()
    return loader_iter

In [19]:
# same as before
parallel_init_fn = jax.vmap(init_fn, in_axes=(None, 0))
# vmap over both inputs now
bootstrap_train_step_fn = jax.vmap(train_step_fn, in_axes=(0, 0))

# make seeds 0 to N-1, which we use for initializing the network and bootstrapping
N = 100
seeds = jnp.linspace(0, N - 1, N).astype(jnp.int32)

model_states = parallel_init_fn((2,), seeds)
data_iterator = bootstrap_multi_iterator((points, labels), dataset_indices=seeds)
for i in range(100):
    x_batch, y_batch = next(data_iterator)
    model_states, losses = bootstrap_train_step_fn(model_states, (x_batch, y_batch))
print(losses)

[0.17217566 0.07021429 0.22784956 0.15737695 0.28838077 0.16491456
 0.12017149 0.12017149 0.04706799 0.07939906 0.19745559 0.09936184
 0.14735948 0.15642619 0.15642619 0.2030158  0.16241509 0.2560954
 0.1071447  0.11495585 0.15770862 0.14941296 0.11878432 0.13803345
 0.1446739  0.1446739  0.13501322 0.18519846 0.18519846 0.17757086
 0.19745876 0.19745876 0.12189123 0.12803112 0.16965379 0.25821382
 0.1898676  0.22086965 0.13562468 0.12363767 0.16972274 0.13059716
 0.15614039 0.11135425 0.13036603 0.16665283 0.12970185 0.19811992
 0.08771882 0.06234949 0.06234949 0.11147568 0.10417018 0.13134733
 0.09851369 0.11364654 0.11364654 0.06340489 0.20267308 0.10697231
 0.13302371 0.10823582 0.10823582 0.13198532 0.05534911 0.15741289
 0.13875891 0.21109484 0.159263   0.13070336 0.15960759 0.15980692
 0.10037327 0.19786653 0.16268903 0.20332766 0.17046006 0.14609909
 0.18178934 0.14673467 0.12522048 0.1099593  0.10829762 0.11537246
 0.12960209 0.19647539 0.13700926 0.21169646 0.21124282 0.12766

In [20]:
xs, batched_preds = parallel_all_preds(model_states)
xs = xs[0]
charts = []
for preds in batched_preds:
    data = {'x': xs[:, 0], 'y': xs[:, 1], 'pred': jnp.exp(preds)[:, 1]}
    df = pd.DataFrame(data)
    bootstrap_pred_chart = alt.Chart(df, width=240, height=240).mark_square(size=50, opacity=1).encode(
        x=spirals_x_axis,
        y=spirals_y_axis,
        color=alt.Color('pred', scale=alt.Scale(scheme='blueorange')),
    )
    chart = bootstrap_pred_chart + spiral_chart
    charts.append(chart)
chart = alt.hconcat(*charts[:2])
chart.save('cootstrap_mlp_pred.html')

#more varied
chart

In [22]:
# ensemble predictions across our models
batched_probs = jnp.exp(batched_preds)
bootstrapped_probs = jnp.mean(batched_probs, axis=0)

data = {'x': xs[:, 0], 'y': xs[:, 1], 'pred': bootstrapped_probs[:, 1]}
df = pd.DataFrame(data)
ensemble_chart = alt.Chart(df, width=240, height=240, title="Predictions from bootstrap").mark_square(size=50, opacity=1).encode(
    x=spirals_x_axis, 
    y=spirals_y_axis,
    color=alt.Color('pred', scale=alt.Scale(scheme='blueorange', domain=[0, 1])),
)
chart = ensemble_chart + spiral_chart
chart.save('ensemble_mlp_pred.html')
chart

##Model Comparison with a single model trained on whole dataset

In [24]:
bootstrap_compare_chart = (pred_chart + spiral_chart) | (ensemble_chart + spiral_chart)
bootstrap_compare_chart.save('bootstrap_compare_pred.html')
bootstrap_compare_chart