In [1]:
from time import time

import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import optax
import torch
import torch.nn as nn
from flax import linen
from jax.lib import xla_bridge
from turbanet import TurbaTrainState
import altair as alt
import pandas as pd


# Inputs

In [2]:
# GENERAL INPUTS
# When True, the Torch section uses CUDA if it is available and the Turba
# section raises unless the JAX backend is a GPU.
GPU = False

# NETWORK SHAPE INPUTS
hidden_size = 512  # width of every hidden layer
num_layers = 2  # number of hidden (Linear/Dense + ReLU) layers

# TRAINING INPUTS
swarm_size = 10  # number of models trained per framework
epochs = 100  # full-batch optimisation steps per model
lr = 1e-3  # Adam learning rate
dataset_size = 100  # number of spiral points generated

In [3]:
# Seed the NumPy and Torch global RNGs (the JAX section derives explicit PRNG keys from per-model seeds instead)
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x1c9987dc130>

# Create Data

In [4]:
# Create random data
def make_spirals(n_samples, noise_std=0.0, rotations=1.0):
    """Generate a two-arm spiral classification dataset.

    Points are placed at radius ``sqrt(t)`` for ``t`` evenly spaced in [0, 1];
    each point is randomly assigned to one of the two arms (the second arm is
    the first mirrored through the origin) and perturbed by Gaussian noise.
    Randomness comes from the global NumPy RNG.

    Args:
        n_samples: number of points to generate.
        noise_std: standard deviation of the noise added to each coordinate.
        rotations: number of full turns the spiral makes between radius 0 and 1.

    Returns:
        ``(points, labels)``: a JAX array of shape ``(n_samples, 2)`` and a
        NumPy integer array of shape ``(n_samples,)`` with values in {0, 1}.
    """
    radii = jnp.linspace(0, 1, n_samples) ** 0.5
    angles = radii * rotations * 2 * np.pi

    # Draw order matters for reproducibility under a fixed NumPy seed:
    # arm signs first, then x-noise, then y-noise.
    arm = np.random.randint(0, 2, (n_samples,)) * 2 - 1
    noise_x = np.random.randn(n_samples) * noise_std
    noise_y = np.random.randn(n_samples) * noise_std

    signed_radii = radii * arm
    coords = jnp.stack(
        [
            signed_radii * jnp.cos(angles) + noise_x,
            signed_radii * jnp.sin(angles) + noise_y,
        ],
        axis=1,
    )
    return coords, (arm > 0).astype(int)

In [5]:
def pred_grid(grid_size=50, width=1.5):
    """Return a regular grid of 2-D points covering a square around the origin.

    Args:
        grid_size: number of grid points along each axis.
        width: half-width of the square; both axes span ``[-width, width]``.

    Returns:
        NumPy float32 array of shape ``(grid_size ** 2, 2)``. Rows are ordered
        with the first coordinate varying slowest.
    """
    axis = np.linspace(-width, width, grid_size)
    # Plain NumPy is enough here: the result is consumed as a NumPy array
    # (e.g. by torch.from_numpy), so there is no need to go through JAX.
    x0s, x1s = np.meshgrid(axis, axis)
    xs = np.stack([x0s, x1s]).transpose().reshape((-1, 2))

    # Keep the single-precision dtype the grid had when it was built via JAX.
    return xs.astype(np.float32)

In [6]:
def true_plot(xs, y):
    """Build an Altair scatter chart of labelled 2-D points.

    Args:
        xs: array indexable as ``xs[:, 0]`` / ``xs[:, 1]`` giving point coordinates.
        y: one label per point; encoded as a nominal colour.

    Returns:
        The Altair chart (350x300, both axes fixed to [-1.5, 1.5]).
    """

    def _axis(channel, field):
        # Both axes share the same fixed, non-"nice" domain.
        return channel(field, scale=alt.Scale(domain=[-1.5, 1.5], nice=False))

    frame = pd.DataFrame({"x": xs[:, 0], "y": xs[:, 1], "label": y})
    base = alt.Chart(frame, width=350, height=300)
    return base.mark_circle(stroke="white", size=80, opacity=1).encode(
        x=_axis(alt.X, "x"),
        y=_axis(alt.Y, "y"),
        color=alt.Color("label:N"),
    )

In [7]:
def prediction_plot(xs, y):
    """Build an Altair heat-map style chart of model predictions on a grid.

    Args:
        xs: array indexable as ``xs[:, 0]`` / ``xs[:, 1]`` giving grid coordinates.
        y: per-point log-probabilities with the class on axis 1; column 1 is
            exponentiated and used as the colour value.

    Returns:
        The Altair chart (350x300, both axes fixed to [-1.5, 1.5]).
    """

    def _axis(channel, field):
        # Both axes share the same fixed, non-"nice" domain.
        return channel(field, scale=alt.Scale(domain=[-1.5, 1.5], nice=False))

    frame = pd.DataFrame({"x": xs[:, 0], "y": xs[:, 1], "pred": np.exp(y)[:, 1]})
    base = alt.Chart(frame, width=350, height=300, title="Predictions from MLP")
    return base.mark_square(size=50, opacity=1).encode(
        x=_axis(alt.X, "x"),
        y=_axis(alt.Y, "y"),
        color=alt.Color("pred", scale=alt.Scale(scheme="blueorange")),
    )

In [8]:
# Generate the noisy spiral dataset shared by the Torch, Turba and JAX sections
points, labels = make_spirals(dataset_size, noise_std=0.05)

In [9]:
# Chart of the labelled training points; also overlaid on each prediction plot
spiral_chart = true_plot(points, labels)
spiral_chart


# Torch

In [10]:
class TorchModel(nn.Module):
    """Fully connected classifier for 2-D inputs with two output classes.

    The network has ``num_layers`` hidden ``Linear`` + ``ReLU`` layers of width
    ``hidden_size`` followed by a 2-unit output layer and ``LogSoftmax``, so
    ``forward`` returns log-probabilities of shape ``(batch, 2)``.

    Args:
        hidden_size: width of every hidden layer.
        num_layers: number of hidden layers (at least 1).
    """

    def __init__(self, hidden_size: int, num_layers: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        layers = [nn.Linear(2, hidden_size), nn.ReLU()]
        # Create a fresh Linear for every extra hidden layer. Repeating a tuple
        # of modules would insert the *same* Linear instance several times,
        # silently tying the weights of all those layers together.
        for _ in range(num_layers - 1):
            layers += [nn.Linear(hidden_size, hidden_size), nn.ReLU()]
        layers += [nn.Linear(hidden_size, 2), nn.LogSoftmax(dim=1)]

        self.stack = nn.Sequential(*layers)

    def forward(self, x):
        """Return log-probabilities for a batch ``x`` of shape ``(batch, 2)``."""
        return self.stack(x)


## Create torch model

In [11]:
# One separately constructed model, with its own Adam optimizer, per swarm member
torch_models = [TorchModel(hidden_size, num_layers) for _ in range(swarm_size)]
torch_optimizers = [
    torch.optim.Adam(torch_model.parameters(), lr=lr) for torch_model in torch_models
]

In [12]:
# Default to the CPU; switch to CUDA (and move every model there) only when a
# GPU run is requested and CUDA is actually available
device = torch.device("cpu")
if GPU and torch.cuda.is_available():
    device = torch.device("cuda")
    for torch_model in torch_models:
        torch_model.to(device)

In [13]:
# Convert to torch tensors
X_train_torch = torch.from_numpy(np.array(points)).float()
# Class-index targets must be int64 for torch's classification losses; the
# NumPy default integer can be 32-bit on some platforms, so cast explicitly.
y_train_torch = torch.from_numpy(np.array(labels)).long()

In [14]:
# Move the training tensors to the selected device (which is still the CPU
# when CUDA is unavailable)
if GPU:
    X_train_torch = X_train_torch.to(device)
    y_train_torch = y_train_torch.to(device)

## Train torch model

In [15]:
# Train every model in the swarm sequentially with full-batch gradient steps
start = time()
for torch_model, torch_optimizer in zip(torch_models, torch_optimizers):
    torch_model.train()
    for epoch in range(epochs):
        torch_optimizer.zero_grad()
        y_pred = torch_model(X_train_torch)
        # The model already ends in LogSoftmax, so feed its log-probabilities
        # to the negative log-likelihood loss directly; cross_entropy would
        # apply log_softmax a second time.
        loss = torch.nn.functional.nll_loss(y_pred, y_train_torch)
        loss.backward()
        torch_optimizer.step()

        if epoch % 100 == 0:
            print(f"Epoch {epoch} loss: {loss.item()}")

# CUDA kernels run asynchronously; wait for them before reading the clock so
# the reported time covers the actual computation.
if device.type == "cuda":
    torch.cuda.synchronize()
print(f"torch time: {time() - start}")


Epoch 0 loss: 0.7015222311019897
Epoch 0 loss: 0.6937174797058105
Epoch 0 loss: 0.6985547542572021
Epoch 0 loss: 0.6963149905204773
Epoch 0 loss: 0.7113945484161377
Epoch 0 loss: 0.6931594014167786
Epoch 0 loss: 0.7011054158210754
Epoch 0 loss: 0.6697982549667358
Epoch 0 loss: 0.6965584754943848
Epoch 0 loss: 0.6896067261695862
Epoch 0 loss: 0.7039291858673096
Epoch 0 loss: 0.6945608258247375
Epoch 0 loss: 0.7157065868377686
Epoch 0 loss: 0.6814790964126587
Epoch 0 loss: 0.6988835334777832
Epoch 0 loss: 0.6978043913841248
Epoch 0 loss: 0.6915125846862793
Epoch 0 loss: 0.6837044358253479
Epoch 0 loss: 0.6896480321884155
Epoch 0 loss: 0.6977105140686035
Epoch 0 loss: 0.6858565807342529
Epoch 0 loss: 0.6873359680175781
Epoch 0 loss: 0.6769067645072937
Epoch 0 loss: 0.6840882301330566
Epoch 0 loss: 0.6821351051330566
Epoch 0 loss: 0.6944190859794617
Epoch 0 loss: 0.686427891254425
Epoch 0 loss: 0.6898047924041748
Epoch 0 loss: 0.6783340573310852
Epoch 0 loss: 0.682384192943573
Epoch 0 loss

## Visualize torch predictions

In [16]:
# Grid of 2-D points on which the trained models are evaluated
xs = pred_grid()

In [17]:
# Take average of predictions across all models.
# Each model emits log-probabilities, so the mean is taken in log space; the
# plotting helper exponentiates the result afterwards.
y = np.zeros((len(torch_models), xs.shape[0], 2))  # (model, grid point, class)
for idx, model in enumerate(torch_models):
    model.eval()
    with torch.no_grad():
        y_pred = model(torch.from_numpy(xs).float().to(device))
        y[idx] = y_pred.cpu().numpy()

y = np.mean(y, axis=0)

In [18]:
# Plot the averaged predictions together with the training points
pred_chart = prediction_plot(xs, y)
chart = pred_chart + spiral_chart
chart

# Turba

In [19]:
def cross_entropy_turba(params, input, output, apply_fn):
    """Cross-entropy loss in the form passed to ``TurbaTrainState.train``.

    Args:
        params: model parameters, forwarded as ``{"params": params}``.
        input: batch of model inputs.
        output: integer class labels, one per example.
        apply_fn: callable ``apply_fn(variables, input)`` returning
            log-probabilities with the class on axis 1.

    Returns:
        ``(loss, one_hot_labels)``: the mean negative log-likelihood and the
        one-hot encoding of ``output``.
    """
    predicted = apply_fn({"params": params}, input)
    targets = jax.nn.one_hot(output, predicted.shape[1])
    # Log-probability assigned to the true class of each example.
    per_example = jnp.sum(targets * predicted, axis=1)
    return -jnp.mean(per_example), targets

In [20]:
class JaxModel(linen.Module):
    """Flax MLP classifier with two output classes.

    Applies ``hidden_layers`` Dense + ReLU layers of width ``hidden_dim``,
    then a 2-unit Dense layer followed by ``log_softmax``, so the call returns
    log-probabilities.
    """

    # Number of hidden Dense + ReLU layers.
    hidden_layers: int = 1
    # Width of each hidden layer.
    hidden_dim: int = 32

    @linen.compact
    def __call__(self, x):
        activations = x
        for _ in range(self.hidden_layers):
            activations = linen.relu(linen.Dense(self.hidden_dim)(activations))
        return linen.log_softmax(linen.Dense(2)(activations))


## Create Turba model

In [21]:
# Build the Flax module and create the swarm training state from it.
# NOTE(review): the positional `2` is presumably the input feature dimension —
# confirm against TurbaTrainState.swarm.
turba_model = JaxModel(hidden_layers=num_layers, hidden_dim=hidden_size)
turba_state = TurbaTrainState.swarm(turba_model, swarm_size, 2, learning_rate=lr)

In [22]:
# Fail fast if a GPU run was requested but JAX is not running on a GPU backend.
# This only checks the backend; it does not move anything to the GPU.
# jax.default_backend() is the public API for this query.
if GPU and jax.default_backend() != "gpu":
    raise RuntimeError("GPU support not available for Turba.")

## Train Turba model

In [27]:
# Convert to jnp arrays, repeating the dataset along a new leading axis so
# there is one copy per swarm member
X_train_turba = jnp.array(
    np.expand_dims(points, axis=0).repeat(swarm_size, axis=0), dtype=jnp.float32
)
y_train_turba = jnp.array(np.expand_dims(labels, axis=0).repeat(swarm_size, axis=0))

In [31]:
# Train the whole swarm with one call per epoch
start = time()
loss = None  # keeps the synchronisation below valid even when epochs == 0
for epoch in range(epochs):
    turba_state, loss, _ = turba_state.train(X_train_turba, y_train_turba, cross_entropy_turba)

    if epoch % 100 == 0:
        print(f"Epoch {epoch} loss: {loss}")

# JAX dispatches work asynchronously, so wait for the final step to finish
# before reading the clock; otherwise the timer can stop while computation is
# still in flight.
jax.block_until_ready((turba_state, loss))
print(f"turba time: {time() - start}")

Epoch 0 loss: [0.01989749 0.03048518 0.03464876 0.02263453 0.01999506 0.02648546
 0.02325751 0.01524994 0.02470766 0.01522729 0.02157579 0.01798083
 0.01980194 0.0249024  0.02892171 0.01710601 0.03406153 0.02562087
 0.03471166 0.02443377 0.01598468 0.02878641 0.01890839 0.01943718
 0.02351479 0.02322996 0.03430641 0.0248089  0.01478038 0.01825798
 0.02000794 0.02747967 0.01924886 0.02542999 0.01999106 0.02613681
 0.02028572 0.02586551 0.02908243 0.028736   0.02953729 0.02592747
 0.01461057 0.01988927 0.03377891 0.01806061 0.02562007 0.01588437
 0.02826984 0.02414952 0.02154514 0.02308157 0.04319855 0.02688492
 0.02199867 0.02354035 0.02794691 0.02689847 0.02900563 0.03002342
 0.01861075 0.01483687 0.01572811 0.01587019 0.02848407 0.02595776
 0.01180284 0.03085091 0.02113145 0.03587299 0.01627695 0.02291079
 0.02168578 0.01760775 0.0212584  0.02311016 0.01844833 0.02182013
 0.0185155  0.03029763 0.02586266 0.02827227 0.02188956 0.02197225
 0.02457011 0.01489021 0.01664164 0.02650387 0.0

## Visualize turba predictions

In [124]:
# Build the evaluation grid and repeat it along a new leading axis, one copy
# per swarm member
xs = pred_grid()
xs = jnp.array(np.expand_dims(xs, axis=0).repeat(swarm_size, axis=0), dtype=jnp.float32)

In [125]:
# Take average of predictions across all models (mean over the leading axis).
# NOTE(review): assumes predict returns one block of outputs per swarm member
# along axis 0 — confirm against TurbaTrainState.predict.
y = turba_state.predict(xs)
y = np.mean(y, axis=0)

In [126]:
# Convert to numpy, keeping only the first copy of the repeated grid.
# NOTE: because of `xs[0]` this cell is not idempotent — re-running it without
# first re-running the grid cell indexes into the already-reduced array.
xs = np.array(xs[0])
y = np.array(y)

In [127]:
# Plot the averaged predictions together with the training points
pred_chart = prediction_plot(xs, y)
chart = pred_chart + spiral_chart
chart

# Jax

In [128]:
# Instantiate model functions: the Flax module whose `init` and `apply` are
# used by the functions defined in this section
classifier_fns = JaxModel(hidden_layers=num_layers, hidden_dim=hidden_size)

In [129]:
# Define optimizer: Adam with the learning rate from the inputs section
optimizer = optax.adam(learning_rate=lr)

In [130]:
# Define loss function
def cross_entropy(logprobs, labels):
    """Mean negative log-likelihood of integer ``labels`` under ``logprobs``.

    Args:
        logprobs: log-probabilities with the class on axis 1.
        labels: integer class index for each example.

    Returns:
        Scalar loss averaged over all examples.
    """
    num_classes = logprobs.shape[1]
    targets = jax.nn.one_hot(labels, num_classes)
    # Log-probability assigned to the true class of each example.
    true_class_logprob = (targets * logprobs).sum(axis=-1)
    return -true_class_logprob.mean()

In [131]:
# Create wrapper loss function and gradient
def loss_fn(params, batch):
    """Mean cross-entropy of the classifier on one batch.

    Args:
        params: model parameters, forwarded to ``classifier_fns.apply``.
        batch: pair ``(inputs, integer labels)``.

    Returns:
        Scalar loss.
    """
    # The model ends in log_softmax, so these are log-probabilities, not logits.
    log_probs = classifier_fns.apply({"params": params}, batch[0])
    # cross_entropy already averages over the batch, so no further reduction
    # is needed.
    return cross_entropy(log_probs, batch[1])


loss_and_grad_fn = jax.value_and_grad(loss_fn)

In [132]:
# Define initialization function
def init_fn(input_shape, seed):
    """Initialise one model and its optimizer state.

    Args:
        input_shape: number of input features; a dummy batch of shape
            ``(1, input_shape)`` is used to initialise the parameters.
        seed: integer seed for the parameter-initialisation PRNG key.

    Returns:
        ``(params, opt_state)`` for a single model.
    """
    key = jr.PRNGKey(seed)
    params = classifier_fns.init(key, jnp.ones((1, input_shape)))["params"]
    return params, optimizer.init(params)

In [133]:
# Define training function
@jax.jit
def train_step_fn(params, opt_state, batch):
    """Run one optimisation step for a single model.

    Args:
        params: current model parameters.
        opt_state: current optimizer state.
        batch: pair ``(inputs, integer labels)``.

    Returns:
        ``(new_params, new_opt_state, loss)`` where ``loss`` was evaluated at
        the parameters passed in.
    """
    loss, grads = loss_and_grad_fn(params, batch)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(params, updates), new_opt_state, loss

In [134]:
# Define prediction function
@jax.jit
def predict_fn(params, x):
    """Return the model's log-probabilities for inputs ``x`` given ``params``."""
    inputs = jnp.array(x)
    variables = {"params": params}
    return classifier_fns.apply(variables, inputs)

In [135]:
# Create vectorized functions that operate on the whole swarm at once:
# - init: the same input shape for every model, one seed per model
# - train step: per-model params and optimizer state, one shared batch
# - predict: per-model params, one shared input
parallel_init_fn = jax.vmap(init_fn, in_axes=(None, 0))
parallel_train_step_fn = jax.vmap(train_step_fn, in_axes=(0, 0, None))
parallel_predict_fn = jax.vmap(predict_fn, in_axes=(0, None))

## Create Jax Model

In [136]:
# Initialize models: one integer seed (0 .. swarm_size - 1) per swarm member.
# jnp.arange yields exact integers, whereas casting a float linspace truncates
# and could produce duplicate seeds if a value lands just below an integer.
seeds = jnp.arange(swarm_size)
model_states, opt_states = parallel_init_fn(2, seeds)

## Train Jax model

In [138]:
# Training loop
start = time()
for i in range(epochs):
    model_states, opt_states, _ = parallel_train_step_fn(
        model_states, opt_states, (points, labels)
    )

# JAX dispatches work asynchronously, so wait for the final step to finish
# before reading the clock; otherwise the timer can stop while computation is
# still in flight.
jax.block_until_ready((model_states, opt_states))
print(f"jax time: {time() - start}")

jax time: 0.4891629219055176


## Visualize jax predictions

In [139]:
# Create prediction grid as a JAX array (shared by every model in the swarm)
xs = jnp.array(pred_grid())

In [141]:
# Take average of predictions across all models.
# The models emit log-probabilities, so the mean is taken in log space; the
# plotting helper exponentiates the result afterwards.
y = parallel_predict_fn(model_states, xs)
y = jnp.mean(y, axis=0)

array([[-40.792522,   0.      ],
       [-39.48198 ,   0.      ],
       [-38.161472,   0.      ],
       ...,
       [  0.      , -38.477222],
       [  0.      , -39.51838 ],
       [  0.      , -40.54807 ]], shape=(2500, 2), dtype=float32)

In [142]:
# Convert the JAX arrays to NumPy before handing them to the plotting helper
xs = np.array(xs)
y = np.array(y)

array([[-40.792522,   0.      ],
       [-39.48198 ,   0.      ],
       [-38.161472,   0.      ],
       ...,
       [  0.      , -38.477222],
       [  0.      , -39.51838 ],
       [  0.      , -40.54807 ]], shape=(2500, 2), dtype=float32)

In [143]:
# Plot the averaged predictions together with the training points
pred_chart = prediction_plot(xs, y)
chart = pred_chart + spiral_chart
chart