In [147]:
from time import time

import jax
import jax.numpy as jnp
import numpy as np
import torch
import torch.nn as nn
from flax import linen
from jax.lib import xla_bridge
from turbanet import TurbaTrainState
import pandas as pd


# Inputs

In [148]:
# GENERAL INPUTS
GPU = False
repeats = 10

# NETWORK SHAPE INPUTS
hidden_sizes = [16]  # 8 * np.arange(1, 65)
num_layers = [1]

# TRAINING INPUTS
lr = 1e-3
dataset_size = 128
swarm_sizes = 2 * np.arange(1, 65)
epochs = [1024]
batch_sizes = [128]

In [149]:
print(
    f"Combinations: {len(hidden_sizes) * len(num_layers) * len(swarm_sizes) * len(epochs) * len(batch_sizes)}"
)

Combinations: 64


In [150]:
# Set numpy/torch/flax seeds
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x21c5becd8b0>

# Create Data

In [151]:
# Create random data
def make_spirals(n_samples, noise_std=0.0, rotations=1.0):
    ts = jnp.linspace(0, 1, n_samples)
    rs = ts**0.5
    thetas = rs * rotations * 2 * np.pi
    signs = np.random.randint(0, 2, (n_samples,)) * 2 - 1
    labels = (signs > 0).astype(int)

    xs = rs * signs * jnp.cos(thetas) + np.random.randn(n_samples) * noise_std
    ys = rs * signs * jnp.sin(thetas) + np.random.randn(n_samples) * noise_std
    points = jnp.stack([xs, ys], axis=1)
    return points, labels

In [152]:
points, labels = make_spirals(dataset_size, noise_std=0.05)

# Torch

In [153]:
class TorchModel(nn.Module):
    def __init__(self, hidden_size: int, num_layers: int):
        super(TorchModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.stack = nn.Sequential(
            nn.Linear(2, hidden_size),
            nn.ReLU(),
            *(nn.Linear(hidden_size, hidden_size), nn.ReLU()) * (num_layers - 1),
            nn.Linear(hidden_size, 2),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        return self.stack(x)


In [None]:
def torch_train(torch_models, torch_optimizers, epochs, X, y):
    for torch_model, torch_optimizer in zip(torch_models, torch_optimizers):
        torch_model.train()
        for _ in range(epochs):
            for batch_input, batch_label in zip(X, y):
                torch_optimizer.zero_grad()
                y_pred = torch_model(batch_input)
                loss = torch.nn.functional.cross_entropy(y_pred, batch_label)
                loss.backward()
                torch_optimizer.step()

In [154]:
hidden = []
layers = []
swarm = []
epoch = []
batch = []
times = []
losses = []
for hidden_size in hidden_sizes:
    for num_layer in num_layers:
        for swarm_size in swarm_sizes:
            for epoch_num in epochs:
                for batch_size in batch_sizes:
                    # Create models
                    torch_models = [TorchModel(hidden_size, num_layer) for _ in range(swarm_size)]

                    # Create optimizers
                    torch_optimizers = [
                        torch.optim.Adam(torch_model.parameters(), lr=lr)
                        for torch_model in torch_models
                    ]

                    # Set torch to use GPU if available
                    device = torch.device("cpu")
                    if GPU and torch.cuda.is_available():
                        device = torch.device("cuda")
                        for torch_model in torch_models:
                            torch_model.to(device)

                    # Prepare data
                    X_train_torch = torch.from_numpy(
                        np.array(points.reshape(-1, batch_size, 2))
                    ).float()
                    y_train_torch = torch.from_numpy(np.array(labels.reshape(-1, batch_size)))

                    # Move to GPU if available
                    if GPU:
                        X_train_torch = X_train_torch.to(device)
                        y_train_torch = y_train_torch.to(device)

                    # Train
                    train_times = np.zeros(repeats)
                    for r in range(repeats):
                        start = time()
                        torch_train(
                            torch_models, torch_optimizers, epoch_num, X_train_torch, y_train_torch
                        )
                        train_times[r] = time() - start

                    train_time = train_times.mean()

                    # Print results
                    print(
                        f"Hidden Nodes: {hidden_size}, "
                        f"Layers: {num_layer}, "
                        f"Swarm: {swarm_size}, "
                        f"Epochs: {epoch_num}, "
                        f"Batch: {batch_size}, "
                        f"Time: {train_time}, "
                    )

                    # Save results
                    hidden.append(hidden_size)
                    layers.append(num_layer)
                    swarm.append(swarm_size)
                    epoch.append(epoch_num)
                    batch.append(batch_size)
                    times.append(train_time)

# Output results as dataframe
torch_data = pd.DataFrame(
    {
        "Hidden": hidden,
        "Layers": layers,
        "Swarm": swarm,
        "Epoch": epoch,
        "Batch": batch,
        "Time": times,
    }
)

In [None]:
# Save timing data
torch_data.to_csv("../../data/output/timing/torch_swarm_size_data.csv", index=False)


# Turba

In [155]:
def cross_entropy_turba(params, input, output, apply_fn):
    log_probs = apply_fn({"params": params}, input)
    labels = jax.nn.one_hot(output, log_probs.shape[1])
    loss = -jnp.mean(jnp.sum(labels * log_probs, axis=1))
    return loss, labels

In [156]:
class JaxModel(linen.Module):
    hidden_layers: int = 1
    hidden_dim: int = 32

    @linen.compact
    def __call__(self, x):
        for layer in range(self.hidden_layers):
            x = linen.Dense(self.hidden_dim)(x)
            x = linen.relu(x)
        x = linen.Dense(2)(x)
        x = linen.log_softmax(x)
        return x


In [157]:
def turba_train(turba_state, epochs, X, y):
    for _ in range(epochs):
        for batch_input, batch_label in zip(X, y):
            turba_state, _, _ = turba_state.train(batch_input, batch_label, cross_entropy_turba)


In [158]:
hidden = []
layers = []
swarm = []
epoch = []
batch = []
times = []

for hidden_size in hidden_sizes:
    for num_layer in num_layers:
        for swarm_size in swarm_sizes:
            for epoch_num in epochs:
                for batch_size in batch_sizes:
                    # Create models
                    turba_model = JaxModel(hidden_layers=num_layer, hidden_dim=hidden_size)
                    turba_state = TurbaTrainState.swarm(
                        turba_model, swarm_size, 2, learning_rate=lr
                    )

                    # Set Turba to use GPU if available
                    if GPU and xla_bridge.get_backend().platform != "gpu":
                        raise RuntimeError("GPU support not available for Turba.")

                    # Prepare data
                    X_train_turba = jnp.array(
                        np.expand_dims(points.reshape(-1, batch_size, 2), axis=1).repeat(
                            swarm_size, axis=1
                        ),
                        dtype=jnp.float32,
                    )
                    y_train_turba = jnp.array(
                        np.expand_dims(labels.reshape(-1, batch_size), axis=1).repeat(
                            swarm_size, axis=1
                        )
                    )

                    # Train
                    train_times = np.zeros(repeats)
                    for r in range(repeats):
                        start = time()
                        turba_train(turba_state, epoch_num, X_train_turba, y_train_turba)
                        train_times[r] = time() - start

                    train_time = train_times.mean()

                    # Print results
                    print(
                        f"Hidden Nodes: {hidden_size}, "
                        f"Layers: {num_layer}, "
                        f"Swarm: {swarm_size}, "
                        f"Epochs: {epoch_num}, "
                        f"Batch: {batch_size}, "
                        f"Time: {train_time}, "
                    )

                    # Save results
                    hidden.append(hidden_size)
                    layers.append(num_layer)
                    swarm.append(swarm_size)
                    epoch.append(epoch_num)
                    batch.append(batch_size)
                    times.append(train_time)

# Output results as dataframe
turba_data = pd.DataFrame(
    {
        "Hidden": hidden,
        "Layers": layers,
        "Swarm": swarm,
        "Epoch": epoch,
        "Batch": batch,
        "Time": times,
    }
)

Hidden Nodes: 16, Layers: 1, Swarm: 2, Epochs: 1024, Batch: 128, Time: 0.6248784065246582, 
Hidden Nodes: 16, Layers: 1, Swarm: 4, Epochs: 1024, Batch: 128, Time: 0.6310785293579102, 
Hidden Nodes: 16, Layers: 1, Swarm: 6, Epochs: 1024, Batch: 128, Time: 0.6246379137039184, 
Hidden Nodes: 16, Layers: 1, Swarm: 8, Epochs: 1024, Batch: 128, Time: 0.6039792537689209, 
Hidden Nodes: 16, Layers: 1, Swarm: 10, Epochs: 1024, Batch: 128, Time: 0.6036675930023193, 
Hidden Nodes: 16, Layers: 1, Swarm: 12, Epochs: 1024, Batch: 128, Time: 0.6016450881958008, 
Hidden Nodes: 16, Layers: 1, Swarm: 14, Epochs: 1024, Batch: 128, Time: 0.6072637557983398, 
Hidden Nodes: 16, Layers: 1, Swarm: 16, Epochs: 1024, Batch: 128, Time: 0.6071071147918701, 
Hidden Nodes: 16, Layers: 1, Swarm: 18, Epochs: 1024, Batch: 128, Time: 0.6074105978012085, 
Hidden Nodes: 16, Layers: 1, Swarm: 20, Epochs: 1024, Batch: 128, Time: 0.607907748222351, 
Hidden Nodes: 16, Layers: 1, Swarm: 22, Epochs: 1024, Batch: 128, Time: 0.6

In [160]:
# Save timing data
turba_data.to_csv("../../data/output/timing/turba_swarm_size_data.csv", index=False)