## A Distributed Neural Network Simulation / How many Raspberry Pis would it take to run GPT-4?

Data parallelism and pipeline parallelism for large neural networks are both relatively straightforward concepts. Layer parallelism (often called tensor parallelism) sounds much more complicated, but is it really so bad? Do you need 256 GPUs to understand how it is implemented? Not at all!
  
Let's see what it looks like to execute a large neural network layer across many machines, and observe how internal memory usage and bandwidth scale, as well as total external network bandwidth. Then let's use these numbers to speculate about some interesting scenarios.
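The whole trick rests on one block-matrix identity: if a weight matrix is split by columns, each machine can multiply the full input by its own slice, and concatenating the per-machine results gives the full product. Because ReLU acts elementwise, it can be applied to each slice before concatenating. A minimal sketch, written with NumPy so it runs on its own (the array operations have direct `jax.numpy` equivalents); the shapes are arbitrary and chosen only for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, embed, hidden, n_machines = 3, 8, 32, 4

x = rng.standard_normal((batch, embed))
w = rng.standard_normal((embed, hidden))

# Full product on one machine, followed by an elementwise ReLU.
full = np.maximum(x @ w, 0.0)

# Each "machine" holds hidden // n_machines columns of w and sees the whole input.
slices = np.split(w, n_machines, axis=1)
pieces = [np.maximum(x @ w_i, 0.0) for w_i in slices]

# Concatenating the per-machine results along the column axis recovers the full product.
gathered = np.concatenate(pieces, axis=1)
print(np.allclose(full, gathered))  # True
```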

Todo:
- The simulation code is used as a reference to justify the actual calculations, which are quite simple. These should be extracted because they're useful!
- Make a simulation for training, not just inference
- Refactor into a virtual machine class to better organize code

In [1]:
from jax import grad, value_and_grad, jit, random, nn
import jax.numpy as jnp
import matplotlib.pyplot as plt

Initialize the parameters and input for our NN

In [2]:
seed = 30
key = random.PRNGKey(seed)
embedding_dim = 512
nonlinear_dim = 2048
batch_size = 6

expand_mat = random.normal(key, (embedding_dim, nonlinear_dim)) / nonlinear_dim ** 0.5
key, subkey = random.split(key)
reduce_mat = random.normal(subkey, (nonlinear_dim, embedding_dim)) / embedding_dim ** 0.5
key, subkey = random.split(key)
input_vec = random.normal(subkey, (batch_size, embedding_dim,)) / embedding_dim ** 0.5



In [3]:
print(expand_mat.shape)
print(reduce_mat.shape)
print(input_vec.shape)

(512, 2048)
(2048, 512)
(6, 512)


Run the layer on a single machine, for reference

In [4]:
def layer_single_machine(x):
  return nn.relu(x @ expand_mat) @ reduce_mat

Test for numerical stability by iterating the layer many times (hopefully the activations neither explode nor vanish)

In [5]:
def iterate_layer(layer_func, x, iterations):
  for i in range(iterations):
    x = layer_func(x)
  return x

In [6]:
iterate_layer(layer_single_machine, input_vec, 1).sum()

Array(-0.12661535, dtype=float32)

In [7]:
iterate_layer(layer_single_machine, input_vec, 10).sum()

Array(-0.09062261, dtype=float32)
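Summing the output is a coarse check, since positive and negative entries can cancel. A more direct check is how the mean square of the activations changes over one application of the layer. Under the usual random-initialization heuristic (Gaussian weights independent of the input), the scales used above multiply the mean square by `embedding_dim / nonlinear_dim` in the expand step, ReLU keeps about half of that, and the reduce step multiplies by `nonlinear_dim / embedding_dim`, so the mean square should roughly halve per layer. A NumPy sketch of that measurement, re-creating the same shapes and scales with NumPy's generator (so the numbers will not match the JAX cells):

```python
import numpy as np

rng = np.random.default_rng(30)
embedding_dim, nonlinear_dim, batch_size = 512, 2048, 6

# Same shapes and scaling as the JAX cells, but drawn from NumPy's generator,
# so the exact values differ from the notebook's.
expand_w = rng.standard_normal((embedding_dim, nonlinear_dim)) / nonlinear_dim ** 0.5
reduce_w = rng.standard_normal((nonlinear_dim, embedding_dim)) / embedding_dim ** 0.5
x = rng.standard_normal((batch_size, embedding_dim)) / embedding_dim ** 0.5

def layer(v):
    # relu(v @ expand) @ reduce, as in layer_single_machine
    return np.maximum(v @ expand_w, 0.0) @ reduce_w

def mean_square(v):
    return float(np.mean(v ** 2))

# Ratio of output to input mean square for one application of the layer.
ratio = mean_square(layer(x)) / mean_square(x)
print(f"mean-square ratio after one layer: {ratio:.3f}")
```

Under the same heuristic, a ratio near 0.5 means the activation RMS shrinks by roughly 1/√2 per application, so the ten-iteration test mainly guards against explosions; scaling either matrix by √2 would bring the ratio back near 1.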

Implement the layer distributed across N machines. Most of the parameters and computation in a transformer live in these simple MLPs, so they make a fairly good approximation of the whole model.

In [8]:
# each machine must use less memory than a single one
def layer_n_machines(x, n):
  # create virtual machines
  reduce_piece_size = embedding_dim // n
  expand_piece_size = nonlinear_dim // n
  # load/distribute parameters and input across machines
  cross_machine_bandwidth = x.size * n
  machines = [{
      "input_vec": x,
      "expand_piece": expand_mat[:, i*expand_piece_size:(i+1)*expand_piece_size],
      "reduce_piece": reduce_mat[:, i*reduce_piece_size:(i+1)*reduce_piece_size]
      } for i in range(n)]

  data_type_bytes = 2 # lets assume fp16
  total_params = expand_mat.size + reduce_mat.size
  machine_params = machines[0]['expand_piece'].size + machines[0]['reduce_piece'].size

  # this loop would run in parallel but we'll simulate it serially
  for m in machines:
    m["activation_piece"] = nn.relu(m["input_vec"] @ m["expand_piece"])
  internal_machine_bandwidth = (machines[0]["expand_piece"].size + machines[0]["input_vec"].size) * n

  # collect first matmul results (this requires a small all-gather across machines)
  machines[0]["full_activation"] = jnp.hstack([m["activation_piece"] for m in machines])
  # no data is actually moved here but pretend that it is
  for m in machines:
    m["full_activation"] = machines[0]["full_activation"]
  cross_machine_bandwidth += machines[0]["full_activation"].size * n

  # project back to embedding for next layer
  for m in machines:
    m["output_piece"] = m["full_activation"] @ m["reduce_piece"]
  internal_machine_bandwidth += (machines[0]["reduce_piece"].size + machines[0]["full_activation"].size) * n

  machines[0]["full_output"] = jnp.hstack([m["output_piece"] for m in machines])
  cross_machine_bandwidth += machines[0]["full_output"].size

  assumed_context_size = 1024

  info = {
      "batch_size": x.shape[0],
      "embed_dim": x.shape[1],
      "activation_dim": reduce_mat.shape[0],
      "virtual_machines": n,
      "total_params": total_params,
      "machine_params": machine_params,
      "fraction": machine_params / total_params,
      "internal_bandwidth_params": internal_machine_bandwidth,
      "cross_machine_bandwidth_params": cross_machine_bandwidth,
      "internal_bandwidth": f"{internal_machine_bandwidth * data_type_bytes / (1000*1000)} mb",
      "cross_machine_bandwidth": f"{cross_machine_bandwidth * data_type_bytes / (1000*1000)} mb",
      # WIP
      "kv_cache_memory": x.shape[1] * x.shape[0] * 2 * assumed_context_size,
      "attention_cross_machine_bandwidth": x.shape[1] * x.shape[0] * 2 * assumed_context_size
      #"machine_memory": machines
  }

  return machines[0]["full_output"], info
  # return machines[0]["full_output"], machines

In [9]:
result, info = layer_n_machines(input_vec, 4)
result.shape

(6, 512)

In [10]:
info

{'batch_size': 6,
 'embed_dim': 512,
 'activation_dim': 2048,
 'virtual_machines': 4,
 'total_params': 2097152,
 'machine_params': 524288,
 'fraction': 0.25,
 'internal_bandwidth_params': 2158592,
 'cross_machine_bandwidth_params': 64512,
 'internal_bandwidth': '4.317184 mb',
 'cross_machine_bandwidth': '0.129024 mb',
 'kv_cache_memory': 6291456,
 'attention_cross_machine_bandwidth': 6291456}

In [11]:
result, info = layer_n_machines(random.normal(key, (64, embedding_dim,)) / embedding_dim ** 0.5, 4)
info

{'batch_size': 64,
 'embed_dim': 512,
 'activation_dim': 2048,
 'virtual_machines': 4,
 'total_params': 2097152,
 'machine_params': 524288,
 'fraction': 0.25,
 'internal_bandwidth_params': 2752512,
 'cross_machine_bandwidth_params': 688128,
 'internal_bandwidth': '5.505024 mb',
 'cross_machine_bandwidth': '1.376256 mb',
 'kv_cache_memory': 67108864,
 'attention_cross_machine_bandwidth': 67108864}

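Increasing the batch size comes almost for free: a roughly 10x larger batch raises internal memory traffic only from 4.3 MB to 5.5 MB, because the weights are read once regardless of batch size. Cross-machine traffic, however, grows almost linearly with batch size (0.13 MB to 1.38 MB).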
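The counts that `layer_n_machines` tallies step by step have simple closed forms. Writing b for the batch size, e for the embedding dim, a for the activation dim and n for the number of machines: each machine reads its two weight slices (2·e·a/n values), the input (b·e) and the gathered activation (b·a), which summed over n machines is 2·e·a + n·b·(e + a). Across the network, the input is sent to every machine (n·b·e), the activation is all-gathered (n·b·a), and the output is gathered once (b·e). A plain-Python sketch of those formulas (the function name `layer_costs` is ours, not part of the notebook), checked against the numbers `layer_n_machines` prints in this notebook:

```python
def layer_costs(batch, embed, activation, n):
    """Closed-form counts (in values, not bytes) for the layer_n_machines simulation."""
    weights = 2 * embed * activation  # expand + reduce matrices
    return {
        "machine_params": weights // n,
        # every machine reads its weight slices, the input, and the gathered activation
        "internal_bandwidth_params": weights + n * batch * (embed + activation),
        # broadcast input + all-gather activation + gather output on one machine
        "cross_machine_bandwidth_params": n * batch * (embed + activation) + batch * embed,
    }

print(layer_costs(6, 512, 2048, 4))
# {'machine_params': 524288, 'internal_bandwidth_params': 2158592, 'cross_machine_bandwidth_params': 64512}
```

Both n·b·(e + a) terms grow with the number of machines, because every machine needs the full input and the full activation.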

In [12]:
result, info = layer_n_machines(input_vec, 128)
info

{'batch_size': 6,
 'embed_dim': 512,
 'activation_dim': 2048,
 'virtual_machines': 128,
 'total_params': 2097152,
 'machine_params': 16384,
 'fraction': 0.0078125,
 'internal_bandwidth_params': 4063232,
 'cross_machine_bandwidth_params': 1969152,
 'internal_bandwidth': '8.126464 mb',
 'cross_machine_bandwidth': '3.938304 mb',
 'kv_cache_memory': 6291456,
 'attention_cross_machine_bandwidth': 6291456}

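Internal memory usage and bandwidth scale beautifully! With 128 machines each one holds only 16,384 parameters (1/128 of the layer). The printed internal bandwidth is summed over all machines, so per machine it also drops, from about 540k values with 4 machines to about 32k with 128.
In this regime cross-machine networking becomes critical! Total network traffic grows from 0.13 MB with 4 machines to 3.9 MB with 128, because every machine needs the full input and the full activation.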
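One way to shrink that cross-machine term, used by Megatron-LM, is to split the reduce matrix by rows instead of columns, so that machine i multiplies its own slice of the activation by the matching rows. Each machine then produces a full-width partial output, and the partial outputs are summed (a reduce onto one machine, or an all-reduce) instead of all-gathering the four-times-wider activation. A NumPy sketch with small illustrative shapes, checked against the single-machine layer; the traffic counts assume a reduce onto one machine and, like the simulation, charge a full vector per machine:

```python
import numpy as np

rng = np.random.default_rng(1)
batch, embed, activation, n = 6, 64, 256, 8

x = rng.standard_normal((batch, embed))
expand_w = rng.standard_normal((embed, activation)) / activation ** 0.5
reduce_w = rng.standard_normal((activation, embed)) / embed ** 0.5

reference = np.maximum(x @ expand_w, 0.0) @ reduce_w

# Machine i holds column block i of expand_w and the matching row block i of reduce_w.
expand_slices = np.split(expand_w, n, axis=1)
reduce_slices = np.split(reduce_w, n, axis=0)

# Each machine computes a full-width partial output from its own activation slice.
partials = [np.maximum(x @ e_i, 0.0) @ r_i for e_i, r_i in zip(expand_slices, reduce_slices)]
output = np.sum(partials, axis=0)  # the reduce step: sum the partial outputs

# Values crossing the network: broadcast x to n machines, then send n partial outputs.
row_split_traffic = n * batch * embed + n * batch * embed
# The notebook's column split: broadcast x, all-gather the activation, gather the output.
column_split_traffic = n * batch * embed + n * batch * activation + batch * embed
print(np.allclose(reference, output), row_split_traffic, column_split_traffic)  # True 6144 15744
```

With the GPT-4-scale dimensions used later (e = 12288, a = 49152, n = 1024, b = 1), the same two formulas give about 25.2M values for the row split versus 62.9M for the column split.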
  


In [13]:
def layer_four_machines(x):
  return layer_n_machines(x, 4)[0]

In [14]:
def layer_eight_machines(x):
  return layer_n_machines(x, 8)[0]

In [15]:
def layer_lots_o_machines(x):
  return layer_n_machines(x, 256)[0]

In [16]:
iterate_layer(layer_single_machine, input_vec, 10).sum()

Array(-0.09062261, dtype=float32)

In [17]:
iterate_layer(layer_four_machines, input_vec, 10).sum()

Array(-0.09062257, dtype=float32)

In [18]:
iterate_layer(layer_eight_machines, input_vec, 10).sum()

Array(-0.09062257, dtype=float32)

In [19]:
iterate_layer(layer_lots_o_machines, input_vec, 10).sum()

Array(-0.0906228, dtype=float32)

The results of our simulations match the single-machine version up to float32 rounding in the last digits, yay!

Let's make some more assumptions to roughly approximate GPT-4: 1700 billion fp16 parameters, split across 8 GPT-3-ish experts, each with 192 layers, an embedding dim of 12288, and an activation dim 4x that. We're just going for a very rough ballpark, and we only count the MLP weights.

In [20]:
gpt4_params = 12288 * (12288 * 4) * 2 * 192 * 8
f"{gpt4_params / 1000000000000} trillion params"

'1.855425871872 trillion params'

Close enough :)

Technically, if you attached a large storage device, a single Raspberry Pi could run GPT-4, but it would be quite slow (that is a separate interesting question). Context size and the KV cache would make attention a significant factor for memory and bandwidth at large batch sizes, but for a batch size of 1 they are negligible. A Raspberry Pi can have 4 GB of RAM, so counting only the fp16 weights (2 bytes each) we'll need:

In [21]:
import math
# GB of fp16 weights (2 bytes per param) divided by 4 GB per pi, rounded up
n_pis = math.ceil(gpt4_params * 2 / 4e9)
n_pis

928

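Let's round up to 1024, a power of two that splits both the embedding dim (12288 / 1024 = 12) and the activation dim (49152 / 1024 = 48) evenly.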
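Even splits matter here because `layer_n_machines` slices both matrices with integer division, so the machine count should divide the embedding dim and the activation dim exactly. A small sketch (the helper `smallest_even_split` is hypothetical, not part of the notebook) that searches for the smallest such count at or above the weight-memory estimate, rounded up:

```python
import math

embedding_dim, nonlinear_dim = 12288, 12288 * 4
gpt4_params = embedding_dim * nonlinear_dim * 2 * 192 * 8

# Pis needed for the fp16 weights alone (2 bytes each, 4 GB per Pi), rounded up.
min_machines = math.ceil(gpt4_params * 2 / 4e9)

def smallest_even_split(lower_bound, dims):
    """Smallest machine count >= lower_bound that divides every dim exactly."""
    n = lower_bound
    while any(d % n for d in dims):
        n += 1
    return n

n = smallest_even_split(min_machines, (embedding_dim, nonlinear_dim))
print(min_machines, n, embedding_dim // n, nonlinear_dim // n)  # 928 1024 12 48
```

Since 12288 = 2^12 · 3, none of its divisors fall between 928 and 1023, so 1024 is the smallest count at or above the estimate that splits both matrices evenly.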

In [22]:
n_pies = 1024

In [23]:
embedding_dim = 12288
nonlinear_dim = 12288 * 4
batch_size = 1

expand_mat = random.normal(key, (embedding_dim, nonlinear_dim)) / nonlinear_dim ** 0.5
key, subkey = random.split(key)
reduce_mat = random.normal(subkey, (nonlinear_dim, embedding_dim)) / embedding_dim ** 0.5
key, subkey = random.split(key)
input_vec = random.normal(subkey, (batch_size, embedding_dim,)) / embedding_dim ** 0.5

In [24]:
result, info = layer_n_machines(input_vec, n_pies)

In [25]:
info

{'batch_size': 1,
 'embed_dim': 12288,
 'activation_dim': 49152,
 'virtual_machines': 1024,
 'total_params': 1207959552,
 'machine_params': 1179648,
 'fraction': 0.0009765625,
 'internal_bandwidth_params': 1270874112,
 'cross_machine_bandwidth_params': 62926848,
 'internal_bandwidth': '2541.748224 mb',
 'cross_machine_bandwidth': '125.853696 mb',
 'kv_cache_memory': 25165824,
 'attention_cross_machine_bandwidth': 25165824}

We can multiply these figures by 192 to account for all layers in a forward pass, and again by 8 for the eight experts.

In [26]:
f"{info['machine_params'] * 192 * 8 / 1_000_000_000:.3f}B params on each pi"

'1.812B params on each pi'
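Converting that to bytes shows how tight the fit is: at 2 bytes per fp16 parameter, the weights alone take about 3.62 GB of each Pi's 4 GB, before the operating system, activations, or KV cache. A quick check of the arithmetic, using the per-layer figure printed above:

```python
machine_params_per_layer = 1_179_648  # info['machine_params'] from the 1024-Pi run
layers, experts, bytes_per_param = 192, 8, 2  # the assumptions stated above (fp16)

weight_bytes = machine_params_per_layer * layers * experts * bytes_per_param
print(f"{weight_bytes / 1e9:.3f} GB of weights per Pi")  # 3.624 GB of weights per Pi
```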

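For bandwidth calculations, we'll assume only 2 of the 8 experts are active for each token. We can't know which ones in advance, so every Pi holds a slice of all eight, but the active experts use their Pis' resources at full capacity. Note that `internal_bandwidth_params` is summed over all machines, so it has to be divided by the number of Pis to get a per-Pi figure. This is WIP.

In [27]:
# per-pi internal memory traffic per token: 192 layers, 2 active experts, 2 bytes per fp16 value
f"{info['internal_bandwidth_params'] / n_pies * 192 * 2 * 2 / 1_000_000_000:.3f} GB internal memory traffic on each pi per token"

'0.953 GB internal memory traffic on each pi per token'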
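The same per-layer figures give a rough network budget. Under the simulation's accounting, every Pi receives the full input (e values) and the full gathered activation (a values) for each layer, so with 192 layers, 2 active experts, and fp16 values, each Pi's receive volume per token is computed below. The link speed is an assumption for illustration only: 1 Gbit/s, the nominal rate of the Raspberry Pi 4's Gigabit Ethernet port, ignoring protocol overhead, switch contention, and any overlap of communication with compute.

```python
embed, activation = 12288, 49152
layers, active_experts, bytes_per_value = 192, 2, 2

# Values each Pi receives per layer: broadcast input + all-gathered activation.
received_per_layer = embed + activation
received_bytes = received_per_layer * layers * active_experts * bytes_per_value

# Assumed link: 1 Gbit/s = 125e6 bytes/s with no overhead (an idealized lower bound).
link_bytes_per_s = 125e6
seconds_per_token = received_bytes / link_bytes_per_s
print(f"{received_bytes / 1e6:.1f} MB received per Pi per token, "
      f">= {seconds_per_token:.2f} s on the assumed link")
# 47.2 MB received per Pi per token, >= 0.38 s on the assumed link
```

Under these assumptions the network, not the Pis' memory, would set the token rate.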