In [84]:
from time import time

import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import optax
import torch
import torch.nn as nn
from flax import linen
from jax.lib import xla_bridge
from turbanet import TurbaTrainState
import altair as alt
import pandas as pd


# Inputs

In [85]:
# GENERAL INPUTS
GPU = False

# NETWORK SHAPE INPUTS
hidden_size = 10
num_layers = 4

# TRAINING INPUTS
swarm_size = 10
epochs = 10000
lr = 1e-3
dataset_size = 100

In [86]:
# Set numpy/torch/flax seeds
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x1df1a5b8110>

# Create Data

In [87]:
# Create random data
def make_spirals(n_samples, noise_std=0.0, rotations=1.0):
    ts = jnp.linspace(0, 1, n_samples)
    rs = ts**0.5
    thetas = rs * rotations * 2 * np.pi
    signs = np.random.randint(0, 2, (n_samples,)) * 2 - 1
    labels = (signs > 0).astype(int)

    xs = rs * signs * jnp.cos(thetas) + np.random.randn(n_samples) * noise_std
    ys = rs * signs * jnp.sin(thetas) + np.random.randn(n_samples) * noise_std
    points = jnp.stack([xs, ys], axis=1)
    return points, labels

In [88]:
def pred_grid():
    grid_size = 50
    width = 1.5
    x0s, x1s = jnp.meshgrid(
        np.linspace(-width, width, grid_size), np.linspace(-width, width, grid_size)
    )
    xs = np.stack([x0s, x1s]).transpose().reshape((-1, 2))

    return xs

In [89]:
def true_plot(xs, y):
    df = pd.DataFrame({"x": xs[:, 0], "y": xs[:, 1], "label": y})

    spirals_x_axis = alt.X("x", scale=alt.Scale(domain=[-1.5, 1.5], nice=False))
    spirals_y_axis = alt.Y("y", scale=alt.Scale(domain=[-1.5, 1.5], nice=False))

    spiral_chart = (
        alt.Chart(df, width=350, height=300)
        .mark_circle(stroke="white", size=80, opacity=1)
        .encode(x=spirals_x_axis, y=spirals_y_axis, color=alt.Color("label:N"))
    )

    return spiral_chart

In [90]:
def prediction_plot(xs, y):
    data = {"x": xs[:, 0], "y": xs[:, 1], "pred": np.exp(y)[:, 1]}
    df = pd.DataFrame(data)
    spirals_x_axis = alt.X("x", scale=alt.Scale(domain=[-1.5, 1.5], nice=False))
    spirals_y_axis = alt.Y("y", scale=alt.Scale(domain=[-1.5, 1.5], nice=False))
    pred_chart = (
        alt.Chart(df, width=350, height=300, title="Predictions from MLP")
        .mark_square(size=50, opacity=1)
        .encode(
            x=spirals_x_axis,
            y=spirals_y_axis,
            color=alt.Color("pred", scale=alt.Scale(scheme="blueorange")),
        )
    )

    return pred_chart

In [91]:
points, labels = make_spirals(dataset_size, noise_std=0.05)

In [92]:
spiral_chart = true_plot(points, labels)
spiral_chart


# Torch

In [93]:
class TorchModel(nn.Module):
    def __init__(self, hidden_size: int, num_layers: int):
        super(TorchModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.stack = nn.Sequential(
            nn.Linear(2, hidden_size),
            nn.ReLU(),
            *(nn.Linear(hidden_size, hidden_size), nn.ReLU()) * (num_layers - 1),
            nn.Linear(hidden_size, 2),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        return self.stack(x)


## Create torch model

In [94]:
torch_models = [TorchModel(hidden_size, num_layers) for _ in range(swarm_size)]
torch_optimizers = [
    torch.optim.Adam(torch_model.parameters(), lr=lr) for torch_model in torch_models
]

In [95]:
# Set torch to use GPU if available
device = torch.device("cpu")
if GPU and torch.cuda.is_available():
    device = torch.device("cuda")
    for torch_model in torch_models:
        torch_model.to(device)

In [96]:
# Convert to torch tensors
X_train_torch = torch.from_numpy(np.array(points)).float()
y_train_torch = torch.from_numpy(np.array(labels))

In [97]:
# Move to GPU if available
if GPU:
    X_train_torch = X_train_torch.to(device)
    y_train_torch = y_train_torch.to(device)

## Train torch model

In [98]:
start = time()
k = 0
for torch_model, torch_optimizer in zip(torch_models, torch_optimizers):
    k += 1
    torch_model.train()
    for epoch in range(epochs):
        torch_optimizer.zero_grad()
        y_pred = torch_model(X_train_torch)
        loss = torch.nn.functional.cross_entropy(y_pred, y_train_torch)
        loss.backward()
        torch_optimizer.step()

        if epoch % 100 == 0:
            print(f"Model {k} - Epoch {epoch} Loss: {loss.item()}")

print(f"torch time: {time() - start}")

Model 1 - Epoch 0 Loss: 0.6887051463127136
Model 1 - Epoch 100 Loss: 0.5867904424667358
Model 1 - Epoch 200 Loss: 0.4025283455848694
Model 1 - Epoch 300 Loss: 0.3540712296962738
Model 1 - Epoch 400 Loss: 0.33100658655166626
Model 1 - Epoch 500 Loss: 0.2858526110649109
Model 1 - Epoch 600 Loss: 0.1611620932817459
Model 1 - Epoch 700 Loss: 0.11111099272966385
Model 1 - Epoch 800 Loss: 0.09106907248497009
Model 1 - Epoch 900 Loss: 0.061246857047080994
Model 1 - Epoch 1000 Loss: 0.04531754553318024
Model 1 - Epoch 1100 Loss: 0.034055355936288834
Model 1 - Epoch 1200 Loss: 0.024046994745731354
Model 1 - Epoch 1300 Loss: 0.011068102903664112
Model 1 - Epoch 1400 Loss: 0.00511524872854352
Model 1 - Epoch 1500 Loss: 0.0029998261015862226
Model 1 - Epoch 1600 Loss: 0.0019837573636323214
Model 1 - Epoch 1700 Loss: 0.0012865954777225852
Model 1 - Epoch 1800 Loss: 0.0009132098639383912
Model 1 - Epoch 1900 Loss: 0.0007132573518902063
Model 1 - Epoch 2000 Loss: 0.0005803103558719158
Model 1 - Epoch

## Visualize torch predictions

In [99]:
xs = pred_grid()

In [100]:
# Take average of predictions across all models
y = np.zeros((len(torch_models), xs.shape[0], 2))
for idx, model in enumerate(torch_models):
    model.eval()
    with torch.no_grad():
        y_pred = model(torch.from_numpy(xs).float().to(device))
        y[idx] = y_pred.cpu().numpy()

y = np.mean(y, axis=0)

In [101]:
# Plot predictions
pred_chart = prediction_plot(xs, y)
chart = pred_chart + spiral_chart
chart

# Turba

In [102]:
def cross_entropy_turba(params, input, output, apply_fn):
    log_probs = apply_fn({"params": params}, input)
    labels = jax.nn.one_hot(output, log_probs.shape[1])
    loss = -jnp.mean(jnp.sum(labels * log_probs, axis=1))
    return loss, labels

In [103]:
class JaxModel(linen.Module):
    hidden_layers: int = 1
    hidden_dim: int = 32

    @linen.compact
    def __call__(self, x):
        for layer in range(self.hidden_layers):
            x = linen.Dense(self.hidden_dim)(x)
            x = linen.relu(x)
        x = linen.Dense(2)(x)
        x = linen.log_softmax(x)
        return x


In [104]:
# Define optimizer
optimizer = optax.adam(learning_rate=lr)

## Create Turba model

In [105]:
turba_model = JaxModel(hidden_layers=num_layers, hidden_dim=hidden_size)
turba_state = TurbaTrainState.swarm(turba_model, optimizer, swarm_size, points)

In [106]:
# Set Turba to use GPU if available
if GPU and xla_bridge.get_backend().platform != "gpu":
    raise RuntimeError("GPU support not available for Turba.")

## Train Turba model

In [107]:
# Convert to jnp arrays
X_train_turba = jnp.array(
    np.expand_dims(points, axis=0).repeat(swarm_size, axis=0), dtype=jnp.float32
)
y_train_turba = jnp.array(np.expand_dims(labels, axis=0).repeat(swarm_size, axis=0))

In [108]:
start = time()
for epoch in range(epochs):
    turba_state, loss, _ = turba_state.train(X_train_turba, y_train_turba, cross_entropy_turba)

    if epoch % 100 == 0:
        print(f"Epoch {epoch} Loss: {loss.mean()}")

print(f"turba time: {time() - start}")

Epoch 0 Loss: 0.6928533315658569
Epoch 100 Loss: 0.5026736259460449
Epoch 200 Loss: 0.27960872650146484
Epoch 300 Loss: 0.07157839834690094
Epoch 400 Loss: 0.021062644198536873
Epoch 500 Loss: 0.008877117186784744
Epoch 600 Loss: 0.004373231902718544
Epoch 700 Loss: 0.002478718990460038
Epoch 800 Loss: 0.0015482979360967875
Epoch 900 Loss: 0.0010448326356709003
Epoch 1000 Loss: 0.0007497220649383962
Epoch 1100 Loss: 0.000562878733035177
Epoch 1200 Loss: 0.00043735187500715256
Epoch 1300 Loss: 0.00034847145434468985
Epoch 1400 Loss: 0.00028323879814706743
Epoch 1500 Loss: 0.00023399262863676995
Epoch 1600 Loss: 0.00019596815400291234
Epoch 1700 Loss: 0.00016607035649940372
Epoch 1800 Loss: 0.00014117803948465735
Epoch 1900 Loss: 0.00012117482401663437
Epoch 2000 Loss: 0.00010485952225280926
Epoch 2100 Loss: 9.155629231827334e-05
Epoch 2200 Loss: 8.053541387198493e-05
Epoch 2300 Loss: 7.123780233087018e-05
Epoch 2400 Loss: 6.332913471851498e-05
Epoch 2500 Loss: 5.655695713358e-05
Epoch 2

## Visualize turba predictions

In [109]:
xs = pred_grid()
xs = jnp.array(np.expand_dims(xs, axis=0).repeat(swarm_size, axis=0), dtype=jnp.float32)

In [110]:
# Take average of predictions across all models
y = turba_state.predict(xs)
y = np.mean(y, axis=0)

In [111]:
# Convert to numpy
xs = np.array(xs[0])
y = np.array(y)

In [112]:
# Plot predictions
pred_chart = prediction_plot(xs, y)
chart = pred_chart + spiral_chart
chart

# Jax

In [113]:
# Instantiate model functions
classifier_fns = JaxModel(hidden_layers=num_layers, hidden_dim=hidden_size)

In [114]:
# Define optimizer
optimizer = optax.adam(learning_rate=lr)

In [115]:
# Define loss function
def cross_entropy(logprobs, labels):
    one_hot_labels = jax.nn.one_hot(labels, logprobs.shape[1])
    return -jnp.mean(jnp.sum(one_hot_labels * logprobs, axis=-1))

In [116]:
# Create wrapper loss function and gradient
def loss_fn(params, batch):
    logits = classifier_fns.apply({"params": params}, batch[0])
    loss = jnp.mean(cross_entropy(logits, batch[1]))
    return loss


loss_and_grad_fn = jax.value_and_grad(loss_fn)

In [117]:
# Define initialization function
def init_fn(input_shape, seed):
    rng = jr.PRNGKey(seed)
    dummy_input = jnp.ones((1, input_shape))
    params = classifier_fns.init(rng, dummy_input)["params"]
    opt_state = optimizer.init(params)
    return params, opt_state

In [118]:
# Define training function
@jax.jit
def train_step_fn(params, opt_state, batch):
    loss, grad = loss_and_grad_fn(params, batch)
    updates, opt_state = optimizer.update(grad, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

In [119]:
# Define prediction function
@jax.jit
def predict_fn(params, x):
    x = jnp.array(x)
    return classifier_fns.apply({"params": params}, x)

In [120]:
# Create vectorized functions
parallel_init_fn = jax.vmap(init_fn, in_axes=(None, 0))
parallel_train_step_fn = jax.vmap(train_step_fn, in_axes=(0, 0, None))
parallel_predict_fn = jax.vmap(predict_fn, in_axes=(0, None))

## Create Jax Model

In [121]:
# Initialize models
seeds = jnp.linspace(0, swarm_size - 1, swarm_size).astype(int)
model_states, opt_states = parallel_init_fn(2, seeds)

## Train Jax model

In [122]:
# Training loop
start = time()
for i in range(epochs):
    model_states, opt_states, _ = parallel_train_step_fn(
        model_states, opt_states, (points, labels)
    )

print(f"jax time: {time() - start}")

jax time: 8.655934572219849


## Visualize jax predictions

In [123]:
# Create prediction grid
xs = jnp.array(pred_grid())

In [124]:
# Take average of predictions across all models
y = parallel_predict_fn(model_states, xs)
y = jnp.mean(y, axis=0)

In [125]:
# Convert to numpy
xs = np.array(xs)
y = np.array(y)

In [126]:
# Plot predictions
pred_chart = prediction_plot(xs, y)
chart = pred_chart + spiral_chart
chart