In [57]:
from time import time

import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import optax
import torch
import torch.nn as nn
from flax import linen
from jax.lib import xla_bridge
from turbanet import TurbaTrainState
import altair as alt
import pandas as pd


# Inputs

In [58]:
# GENERAL INPUTS
GPU = False

# NETWORK SHAPE INPUTS
hidden_size = 10
num_layers = 4

# TRAINING INPUTS
swarm_size = 100
epochs = 10000
lr = 1e-3
dataset_size = 100

In [59]:
# Set numpy/torch/flax seeds
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x1ed4a7341f0>

# Create Data

In [60]:
# Create random data
def make_spirals(n_samples, noise_std=0.0, rotations=1.0):
    ts = jnp.linspace(0, 1, n_samples)
    rs = ts**0.5
    thetas = rs * rotations * 2 * np.pi
    signs = np.random.randint(0, 2, (n_samples,)) * 2 - 1
    labels = (signs > 0).astype(int)

    xs = rs * signs * jnp.cos(thetas) + np.random.randn(n_samples) * noise_std
    ys = rs * signs * jnp.sin(thetas) + np.random.randn(n_samples) * noise_std
    points = jnp.stack([xs, ys], axis=1)
    return points, labels

In [61]:
def pred_grid():
    grid_size = 50
    width = 1.5
    x0s, x1s = jnp.meshgrid(
        np.linspace(-width, width, grid_size), np.linspace(-width, width, grid_size)
    )
    xs = np.stack([x0s, x1s]).transpose().reshape((-1, 2))

    return xs

In [62]:
def true_plot(xs, y):
    df = pd.DataFrame({"x": xs[:, 0], "y": xs[:, 1], "label": y})

    spirals_x_axis = alt.X("x", scale=alt.Scale(domain=[-1.5, 1.5], nice=False))
    spirals_y_axis = alt.Y("y", scale=alt.Scale(domain=[-1.5, 1.5], nice=False))

    spiral_chart = (
        alt.Chart(df, width=350, height=300)
        .mark_circle(stroke="white", size=80, opacity=1)
        .encode(x=spirals_x_axis, y=spirals_y_axis, color=alt.Color("label:N"))
    )

    return spiral_chart

In [63]:
def prediction_plot(xs, y):
    data = {"x": xs[:, 0], "y": xs[:, 1], "pred": np.exp(y)[:, 1]}
    df = pd.DataFrame(data)
    spirals_x_axis = alt.X("x", scale=alt.Scale(domain=[-1.5, 1.5], nice=False))
    spirals_y_axis = alt.Y("y", scale=alt.Scale(domain=[-1.5, 1.5], nice=False))
    pred_chart = (
        alt.Chart(df, width=350, height=300, title="Predictions from MLP")
        .mark_square(size=50, opacity=1)
        .encode(
            x=spirals_x_axis,
            y=spirals_y_axis,
            color=alt.Color("pred", scale=alt.Scale(scheme="blueorange")),
        )
    )

    return pred_chart

In [64]:
points, labels = make_spirals(dataset_size, noise_std=0.05)

In [65]:
spiral_chart = true_plot(points, labels)
spiral_chart


# Torch

In [180]:
class TorchModel(nn.Module):
    def __init__(self, hidden_size: int, num_layers: int):
        super(TorchModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.stack = nn.Sequential(
            nn.Linear(2, hidden_size),
            nn.ReLU(),
            *(nn.Linear(hidden_size, hidden_size), nn.ReLU()) * (num_layers - 1),
            nn.Linear(hidden_size, 2),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        return self.stack(x)


## Create torch model

In [181]:
torch_models = [TorchModel(hidden_size, num_layers) for _ in range(swarm_size)]
torch_optimizers = [
    torch.optim.Adam(torch_model.parameters(), lr=lr) for torch_model in torch_models
]

In [182]:
# Set torch to use GPU if available
device = torch.device("cpu")
if GPU and torch.cuda.is_available():
    device = torch.device("cuda")
    for torch_model in torch_models:
        torch_model.to(device)

In [183]:
# Convert to torch tensors
X_train_torch = torch.from_numpy(np.array(points)).float()
y_train_torch = torch.from_numpy(np.array(labels))

In [184]:
# Move to GPU if available
if GPU:
    X_train_torch = X_train_torch.to(device)
    y_train_torch = y_train_torch.to(device)

## Train torch model

In [185]:
start = time()
for torch_model, torch_optimizer in zip(torch_models, torch_optimizers):
    torch_model.train()
    for epoch in range(epochs):
        torch_optimizer.zero_grad()
        y_pred = torch_model(X_train_torch)
        loss = torch.nn.functional.cross_entropy(y_pred, y_train_torch)
        loss.backward()
        torch_optimizer.step()

        if epoch % 100 == 0:
            print(f"Epoch {epoch} loss: {loss.item()}")

print(f"torch time: {time() - start}")


Epoch 0 loss: 0.6887051463127136
Epoch 100 loss: 0.5867904424667358
Epoch 200 loss: 0.4025283455848694
Epoch 300 loss: 0.3540712296962738
Epoch 400 loss: 0.33100658655166626
Epoch 500 loss: 0.2858526110649109
Epoch 600 loss: 0.1611620932817459
Epoch 700 loss: 0.11111099272966385
Epoch 800 loss: 0.09106907248497009
Epoch 900 loss: 0.061246857047080994
Epoch 1000 loss: 0.04531754553318024
Epoch 1100 loss: 0.034055355936288834
Epoch 1200 loss: 0.024046994745731354
Epoch 1300 loss: 0.011068102903664112
Epoch 1400 loss: 0.00511524872854352
Epoch 1500 loss: 0.0029998261015862226
Epoch 1600 loss: 0.0019837573636323214
Epoch 1700 loss: 0.0012865954777225852
Epoch 1800 loss: 0.0009132098639383912
Epoch 1900 loss: 0.0007132573518902063
Epoch 2000 loss: 0.0005803103558719158
Epoch 2100 loss: 0.00047619157703593373
Epoch 2200 loss: 0.0003966972872149199
Epoch 2300 loss: 0.00033582383184693754
Epoch 2400 loss: 0.00028986940742470324
Epoch 2500 loss: 0.000251501303864643
Epoch 2600 loss: 0.000220759

## Visualize torch predictions

In [186]:
xs = pred_grid()

In [187]:
# Take average of predictions across all models
y = np.zeros((len(torch_models), xs.shape[0], 2))
for idx, model in enumerate(torch_models):
    model.eval()
    with torch.no_grad():
        y_pred = model(torch.from_numpy(xs).float().to(device))
        y[idx] = y_pred.cpu().numpy()

y = np.mean(y, axis=0)

In [188]:
# Plot predictions
pred_chart = prediction_plot(xs, y)
chart = pred_chart + spiral_chart
chart

# Turba

In [189]:
def cross_entropy_turba(params, input, output, apply_fn):
    log_probs = apply_fn({"params": params}, input)
    labels = jax.nn.one_hot(output, log_probs.shape[1])
    loss = -jnp.mean(jnp.sum(labels * log_probs, axis=1))
    return loss, labels

In [190]:
class JaxModel(linen.Module):
    hidden_layers: int = 1
    hidden_dim: int = 32

    @linen.compact
    def __call__(self, x):
        for layer in range(self.hidden_layers):
            x = linen.Dense(self.hidden_dim)(x)
            x = linen.relu(x)
        x = linen.Dense(2)(x)
        x = linen.log_softmax(x)
        return x


## Create Turba model

In [191]:
turba_model = JaxModel(hidden_layers=num_layers, hidden_dim=hidden_size)
turba_state = TurbaTrainState.swarm(turba_model, swarm_size, 2, learning_rate=lr)

In [192]:
# Set Turba to use GPU if available
if GPU and xla_bridge.get_backend().platform != "gpu":
    raise RuntimeError("GPU support not available for Turba.")

## Train Turba model

In [193]:
# Convert to jnp arrays
X_train_turba = jnp.array(
    np.expand_dims(points, axis=0).repeat(swarm_size, axis=0), dtype=jnp.float32
)
y_train_turba = jnp.array(np.expand_dims(labels, axis=0).repeat(swarm_size, axis=0))

In [194]:
start = time()
for epoch in range(epochs):
    turba_state, loss, _ = turba_state.train(X_train_turba, y_train_turba, cross_entropy_turba)

    if epoch % 100 == 0:
        print(f"Epoch {epoch} loss: {loss}")

print(f"turba time: {time() - start}")

Epoch 0 loss: [0.67958844 0.7082458  0.6731628  0.70743454 0.7004765  0.7207977
 0.683908   0.6559397  0.6970501  0.70192975 0.7222053  0.6991862
 0.69488704 0.6659269  0.6926319  0.70674866 0.6979809  0.7010445
 0.6965111  0.7517106  0.68101907 0.73060495 0.68738884 0.6700931
 0.7136264  0.6672694  0.7059455  0.69047624 0.6861186  0.69844496
 0.6898758  0.7208336  0.70387554 0.7011206  0.68034416 0.67961454
 0.65429574 0.7173012  0.7234945  0.82181656 0.7025114  0.665722
 0.7032417  0.68895733 0.6827858  0.69404256 0.720831   0.6792034
 0.69312423 0.7473461  0.679561   0.69430846 0.69439054 0.679148
 0.69200844 0.71456593 0.6543778  0.7239205  0.7248822  0.70320815
 0.6419729  0.654889   0.67512137 0.6937935  0.690005   0.69489944
 0.71887356 0.69426847 0.6841061  0.6844227  0.66184133 0.7150691
 0.67832303 0.6845207  0.6737011  0.6832448  0.7223361  0.66230226
 0.7093251  0.7091563  0.6802206  0.6846037  0.7245321  0.691593
 0.6862986  0.6680152  0.69452107 0.67552334 0.7092732  0.74

## Visualize turba predictions

In [195]:
xs = pred_grid()
xs = jnp.array(np.expand_dims(xs, axis=0).repeat(swarm_size, axis=0), dtype=jnp.float32)

In [196]:
# Take average of predictions across all models
y = turba_state.predict(xs)
y = np.mean(y, axis=0)

In [197]:
# Convert to numpy
xs = np.array(xs[0])
y = np.array(y)

In [198]:
# Plot predictions
pred_chart = prediction_plot(xs, y)
chart = pred_chart + spiral_chart
chart

# Jax

In [199]:
# Instantiate model functions
classifier_fns = JaxModel(hidden_layers=num_layers, hidden_dim=hidden_size)

In [200]:
# Define optimizer
optimizer = optax.adam(learning_rate=lr)

In [201]:
# Define loss function
def cross_entropy(logprobs, labels):
    one_hot_labels = jax.nn.one_hot(labels, logprobs.shape[1])
    return -jnp.mean(jnp.sum(one_hot_labels * logprobs, axis=-1))

In [202]:
# Create wrapper loss function and gradient
def loss_fn(params, batch):
    logits = classifier_fns.apply({"params": params}, batch[0])
    loss = jnp.mean(cross_entropy(logits, batch[1]))
    return loss


loss_and_grad_fn = jax.value_and_grad(loss_fn)

In [203]:
# Define initialization function
def init_fn(input_shape, seed):
    rng = jr.PRNGKey(seed)
    dummy_input = jnp.ones((1, input_shape))
    params = classifier_fns.init(rng, dummy_input)["params"]
    opt_state = optimizer.init(params)
    return params, opt_state

In [204]:
# Define training function
@jax.jit
def train_step_fn(params, opt_state, batch):
    loss, grad = loss_and_grad_fn(params, batch)
    updates, opt_state = optimizer.update(grad, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

In [205]:
# Define prediction function
@jax.jit
def predict_fn(params, x):
    x = jnp.array(x)
    return classifier_fns.apply({"params": params}, x)

In [206]:
# Create vectorized functions
parallel_init_fn = jax.vmap(init_fn, in_axes=(None, 0))
parallel_train_step_fn = jax.vmap(train_step_fn, in_axes=(0, 0, None))
parallel_predict_fn = jax.vmap(predict_fn, in_axes=(0, None))

## Create Jax Model

In [207]:
# Initialize models
seeds = jnp.linspace(0, swarm_size - 1, swarm_size).astype(int)
model_states, opt_states = parallel_init_fn(2, seeds)

## Train Jax model

In [208]:
# Training loop
start = time()
for i in range(epochs):
    model_states, opt_states, _ = parallel_train_step_fn(
        model_states, opt_states, (points, labels)
    )

print(f"jax time: {time() - start}")

jax time: 32.80638933181763


## Visualize jax predictions

In [209]:
# Create prediction grid
xs = jnp.array(pred_grid())

In [210]:
# Take average of predictions across all models
y = parallel_predict_fn(model_states, xs)
y = jnp.mean(y, axis=0)

In [211]:
# Convert to numpy
xs = np.array(xs)
y = np.array(y)

In [212]:
# Plot predictions
pred_chart = prediction_plot(xs, y)
chart = pred_chart + spiral_chart
chart