In [2]:
from time import time

import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import optax
import torch
import torch.nn as nn
from flax import linen
from jax.lib import xla_bridge
from turbanet import TurbaTrainState
import altair as alt
import pandas as pd


# Inputs

In [3]:
# GENERAL INPUTS
GPU = False

# NETWORK SHAPE INPUTS
hidden_size = 8
num_layers = 3

# TRAINING INPUTS
swarm_size = 1
epochs = 10000
lr = 1e-3
dataset_size = 100

In [4]:
# Set numpy/torch/flax seeds
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7083100d1bf0>

# Create Data

In [5]:
# Create random data
def make_spirals(n_samples, noise_std=0.0, rotations=1.0):
    ts = jnp.linspace(0, 1, n_samples)
    rs = ts**0.5
    thetas = rs * rotations * 2 * np.pi
    signs = np.random.randint(0, 2, (n_samples,)) * 2 - 1
    labels = (signs > 0).astype(int)

    xs = rs * signs * jnp.cos(thetas) + np.random.randn(n_samples) * noise_std
    ys = rs * signs * jnp.sin(thetas) + np.random.randn(n_samples) * noise_std
    points = jnp.stack([xs, ys], axis=1)
    return points, labels

In [6]:
def pred_grid():
    grid_size = 50
    width = 1.5
    x0s, x1s = jnp.meshgrid(
        np.linspace(-width, width, grid_size), np.linspace(-width, width, grid_size)
    )
    xs = np.stack([x0s, x1s]).transpose().reshape((-1, 2))

    return xs

In [7]:
def true_plot(xs, y):
    df = pd.DataFrame({"x": xs[:, 0], "y": xs[:, 1], "label": y})

    spirals_x_axis = alt.X("x", scale=alt.Scale(domain=[-1.5, 1.5], nice=False))
    spirals_y_axis = alt.Y("y", scale=alt.Scale(domain=[-1.5, 1.5], nice=False))

    spiral_chart = (
        alt.Chart(df, width=350, height=300)
        .mark_circle(stroke="white", size=80, opacity=1)
        .encode(x=spirals_x_axis, y=spirals_y_axis, color=alt.Color("label:N"))
    )

    return spiral_chart

In [8]:
def prediction_plot(xs, y):
    data = {"x": xs[:, 0], "y": xs[:, 1], "pred": np.exp(y)[:, 1]}
    df = pd.DataFrame(data)
    spirals_x_axis = alt.X("x", scale=alt.Scale(domain=[-1.5, 1.5], nice=False))
    spirals_y_axis = alt.Y("y", scale=alt.Scale(domain=[-1.5, 1.5], nice=False))
    pred_chart = (
        alt.Chart(df, width=350, height=300, title="Predictions from MLP")
        .mark_square(size=50, opacity=1)
        .encode(
            x=spirals_x_axis,
            y=spirals_y_axis,
            color=alt.Color("pred", scale=alt.Scale(scheme="blueorange")),
        )
    )

    return pred_chart

In [9]:
points, labels = make_spirals(dataset_size, noise_std=0.05)

In [10]:
spiral_chart = true_plot(points, labels)
spiral_chart


# Torch

In [11]:
class TorchModel(nn.Module):
    def __init__(self, hidden_size: int, num_layers: int):
        super(TorchModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.stack = nn.Sequential(
            nn.Linear(2, hidden_size),
            nn.ReLU(),
            *(nn.Linear(hidden_size, hidden_size), nn.ReLU()) * (num_layers - 1),
            nn.Linear(hidden_size, 2),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        return self.stack(x)


## Create torch model

In [12]:
torch_models = [TorchModel(hidden_size, num_layers) for _ in range(swarm_size)]
torch_optimizers = [
    torch.optim.Adam(torch_model.parameters(), lr=lr) for torch_model in torch_models
]

In [13]:
# Set torch to use GPU if available
device = torch.device("cpu")
if GPU and torch.cuda.is_available():
    device = torch.device("cuda")
    for torch_model in torch_models:
        torch_model.to(device)

In [14]:
# Convert to torch tensors
X_train_torch = torch.from_numpy(np.array(points)).float()
y_train_torch = torch.from_numpy(np.array(labels))

In [15]:
# Move to GPU if available
if GPU:
    X_train_torch = X_train_torch.to(device)
    y_train_torch = y_train_torch.to(device)

## Train torch model

In [16]:
start = time()
for torch_model, torch_optimizer in zip(torch_models, torch_optimizers):
    torch_model.train()
    for epoch in range(epochs):
        torch_optimizer.zero_grad()
        y_pred = torch_model(X_train_torch)
        loss = torch.nn.functional.cross_entropy(y_pred, y_train_torch)
        loss.backward()
        torch_optimizer.step()

        if epoch % 100 == 0:
            print(f"Epoch {epoch} loss: {loss.item()}")

print(f"torch time: {time() - start}")


Epoch 0 loss: 0.6828737854957581
Epoch 100 loss: 0.6152295470237732
Epoch 200 loss: 0.481781929731369
Epoch 300 loss: 0.43467628955841064
Epoch 400 loss: 0.40698665380477905
Epoch 500 loss: 0.35786542296409607
Epoch 600 loss: 0.28868746757507324
Epoch 700 loss: 0.21677570044994354
Epoch 800 loss: 0.13285386562347412
Epoch 900 loss: 0.09474566578865051
Epoch 1000 loss: 0.08112376183271408
Epoch 1100 loss: 0.07077959924936295
Epoch 1200 loss: 0.06130019202828407
Epoch 1300 loss: 0.05346008390188217
Epoch 1400 loss: 0.04405168071389198
Epoch 1500 loss: 0.03755630925297737
Epoch 1600 loss: 0.03350276127457619
Epoch 1700 loss: 0.029528360813856125
Epoch 1800 loss: 0.025065308436751366
Epoch 1900 loss: 0.021160736680030823
Epoch 2000 loss: 0.01766040548682213
Epoch 2100 loss: 0.011051957495510578
Epoch 2200 loss: 0.00791778415441513
Epoch 2300 loss: 0.005935355555266142
Epoch 2400 loss: 0.004643777385354042
Epoch 2500 loss: 0.0037033907137811184
Epoch 2600 loss: 0.0030174434650689363
Epoch 2

## Visualize torch predictions

In [17]:
xs = pred_grid()

In [18]:
# Take average of predictions across all models
y = np.zeros((len(torch_models), xs.shape[0], 2))
for idx, model in enumerate(torch_models):
    model.eval()
    with torch.no_grad():
        y_pred = model(torch.from_numpy(xs).float().to(device))
        y[idx] = y_pred.cpu().numpy()

y = np.mean(y, axis=0)

In [19]:
# Plot predictions
pred_chart = prediction_plot(xs, y)
chart = pred_chart + spiral_chart
chart

# Turba

In [20]:
def cross_entropy_turba(params, batch, apply_fn):
    log_probs = apply_fn({"params": params}, batch["input"])
    labels = jax.nn.one_hot(batch["output"], log_probs.shape[1])
    loss = -jnp.mean(jnp.sum(labels * log_probs, axis=-1))
    return loss, labels

In [21]:
class JaxModel(linen.Module):
    hidden_layers: int = 1
    hidden_dim: int = 32

    @linen.compact
    def __call__(self, x):
        for layer in range(self.hidden_layers):
            x = linen.Dense(self.hidden_dim)(x)
            x = linen.relu(x)
        x = linen.Dense(2)(x)
        x = linen.log_softmax(x)
        return x


## Create Turba model

In [22]:
turba_model = JaxModel(hidden_size, num_layers)
turba_state = TurbaTrainState.swarm(turba_model, swarm_size, 2, learning_rate=lr)

In [23]:
# Set Turba to use GPU if available
if GPU and xla_bridge.get_backend().platform != "gpu":
    raise RuntimeError("GPU support not available for Turba.")

## Train Turba model

In [46]:
from turbanet.loss import softmax_cross_entropy

# Convert to jnp arrays
X_train_turba = jnp.array(
    np.expand_dims(points, axis=0).repeat(swarm_size, axis=0), dtype=jnp.float32
)
y_train_turba = jnp.array(np.expand_dims(labels, axis=0).repeat(swarm_size, axis=0))

start = time()
for epoch in range(epochs):
    turba_state, _, loss = turba_state.train(X_train_turba, y_train_turba, softmax_cross_entropy)

    if epoch % 100 == 0:
        print(f"Epoch {epoch} loss: {loss}")

print(f"turba time: {time() - start}")

Epoch 0 loss: [0.68592983]
Epoch 100 loss: [0.6859298]


KeyboardInterrupt: 

## Visualize turba predictions

In [31]:
xs = pred_grid()
xs = jnp.array(np.expand_dims(xs, axis=0).repeat(swarm_size, axis=0), dtype=jnp.float32)
xs.shape

(1, 2500, 2)

In [34]:
# Take average of predictions across all models
y = turba_state.predict(xs)
y = np.mean(y, axis=0)

In [41]:
xs.shape

(1, 2500, 2)

In [42]:
xs = np.array(xs[0])
y = np.array(y)

In [43]:
# Plot predictions
pred_chart = prediction_plot(xs, y)
chart = pred_chart + spiral_chart
chart

# Jax

In [None]:
classifier_fns = JaxModel()

In [None]:
optimizer = optax.adam(learning_rate=lr)

In [None]:
def cross_entropy(logprobs, labels):
    one_hot_labels = jax.nn.one_hot(labels, logprobs.shape[1])
    return -jnp.mean(jnp.sum(one_hot_labels * logprobs, axis=-1))


In [None]:
def loss_fn(params, batch):
    logits = classifier_fns.apply({"params": params}, batch[0])
    loss = jnp.mean(cross_entropy(logits, batch[1]))
    return loss, logits


loss_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

In [None]:
def init_fn(input_shape, seed):
    rng = jr.PRNGKey(seed)
    dummy_input = jnp.ones((1, input_shape))
    params = classifier_fns.init(rng, dummy_input)["params"]
    opt_state = optimizer.init(params)
    return params, opt_state

In [None]:
@jax.jit
def train_step_fn(params, opt_state, batch):
    (loss, grad), output = loss_and_grad_fn(params, batch)
    updates, opt_state = optimizer.update(grad, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss, output

In [None]:
parallel_init_fn = jax.vmap(init_fn, in_axes=(None, 0))
parallel_train_step_fn = jax.vmap(train_step_fn, in_axes=(0, 0, None))

## Create Jax Model

In [None]:
seeds = jnp.linspace(0, swarm_size - 1, swarm_size).astype(int)
model_states, opt_states = parallel_init_fn(2, seeds)

## Train Jax model

In [None]:
start = time()
for i in range(epochs):
    model_states, opt_states, _ = parallel_train_step_fn(
        model_states, opt_states, (points, labels)
    )

print(f"jax time: {time() - start}")