<a href="https://colab.research.google.com/github/FengDSP/colabs/blob/main/Tensor_Parallel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [61]:
import numpy as np
import threading

In [153]:
N = 10
D = 20
H = 40
LR = 0.001
EPOCHS = 20

labels = np.random.randint(low=0, high=2, size=(N))
features = np.random.randn(N, D)

def sigmoid(x):
  return 1 / (1 + np.exp(-x))

def cross_entropy(y_pred, y_true):
  return - y_true * np.log(y_pred) - (1 - y_true) * np.log(1 - y_pred)

TP = 2
# Column parallel -> H -> Row parallel -> logit
assert H % TP == 0

layer_tp = [np.random.randn(D, H // TP) * 0.1 for _ in range(TP)]
bias_tp = [np.zeros(H // TP) for _ in range(TP)]
proj_tp = [np.random.randn(H // TP) for _ in range(TP)]
p_bias_tp = np.zeros(1)

In [158]:
def train():
  layer = np.concatenate(layer_tp, axis=1)
  bias = np.concatenate(bias_tp)
  proj = np.concatenate(proj_tp)
  p_bias = p_bias_tp.copy()

  for e in range(EPOCHS):
    # forward
    x = features
    preact = x @ layer + bias
    activation = np.maximum(0, preact)  # N, H
    logits = activation @ proj + p_bias
    prob = sigmoid(logits)

    # loss
    loss = cross_entropy(prob, labels).mean()
    print(f"Epoch {e} loss={loss}")

    # backward
    d_logit = prob - labels  # N
    d_p_bias = d_logit.sum()
    d_proj = activation.T @ d_logit   # H, 1
    d_activation = d_logit.reshape(-1, 1) @ proj.reshape(1, -1)    # N, H
    d_preact = d_activation * (preact > 0)  # N, H
    d_bias = d_preact.sum(axis=0)  # H
    d_layer = x.T @ d_preact  # D, H

    layer -= LR * d_layer
    bias -= LR * d_bias
    proj -= LR * d_proj
    p_bias -= LR * d_p_bias

train()

Epoch 0 loss=0.6685087322760365
Epoch 1 loss=0.5463913051939092
Epoch 2 loss=0.44949604756574846
Epoch 3 loss=0.3747345190150955
Epoch 4 loss=0.3183985839512868
Epoch 5 loss=0.27618571016018956
Epoch 6 loss=0.24490588669293967
Epoch 7 loss=0.21992188152893805
Epoch 8 loss=0.19964563485403025
Epoch 9 loss=0.18278670993041032
Epoch 10 loss=0.1688279026633367
Epoch 11 loss=0.15672002634955073
Epoch 12 loss=0.14628174447523565
Epoch 13 loss=0.13721570448804848
Epoch 14 loss=0.12927649340361408
Epoch 15 loss=0.12218339901263808
Epoch 16 loss=0.1158466689141086
Epoch 17 loss=0.11021568026407172
Epoch 18 loss=0.10524478165190343
Epoch 19 loss=0.10067106574464216


In [162]:
barrier = threading.Barrier(TP)
to_reduce = []
lock = threading.Lock()

def all_reduce(x, op=None):
  assert op == 'sum'
  with lock:
    if not to_reduce:
      to_reduce.append(x)
    else:
      to_reduce[0] += x
  barrier.wait()
  result = to_reduce[0]
  if barrier.wait() == 0:
    to_reduce.clear()
  return result

In [159]:
def train_worker(i):
  layer = layer_tp[i].copy()
  bias = bias_tp[i].copy()
  proj = proj_tp[i].copy()
  p_bias = p_bias_tp.copy()
  for e in range(EPOCHS):
    # forward
    x = features
    preact = x @ layer + bias
    activation = np.maximum(0, preact)  # N, H
    logits = activation @ proj
    if i == 0:
      logits += p_bias
    logits = all_reduce(logits, op='sum')
    if i == 0:
      prob = sigmoid(logits)
      # loss
      loss = cross_entropy(prob, labels).mean()
      print(f"Epoch {e} loss={loss}")

      # backward
      d_logit = prob - labels  # N
      # print(f'd_logit={d_logit}')
      d_p_bias = d_logit.sum()
    else:
      d_logit = np.zeros_like(logits)
    d_logit = all_reduce(d_logit, op='sum')
    # print(f'd_logit={d_logit} after all readuce')
    d_proj = activation.T @ d_logit   # H, 1
    d_activation = d_logit.reshape(-1, 1) @ proj.reshape(1, -1)    # N, H
    d_preact = d_activation * (preact > 0)  # N, H
    d_bias = d_preact.sum(axis=0)  # H
    d_layer = x.T @ d_preact  # D, H

    layer -= LR * d_layer
    bias -= LR * d_bias
    proj -= LR * d_proj
    if i == 0:
      p_bias -= LR * d_p_bias

In [165]:
threads = []
for i in range(TP):
  t = threading.Thread(target=train_worker, args=(i, ))
  t.start()
  threads.append(t)

for t in threads:
  t.join()

Epoch 0 loss=4.912591801666268
Epoch 1 loss=2.7430450759191065
Epoch 2 loss=1.5102569861675312
Epoch 3 loss=1.278795563554971
Epoch 4 loss=1.6803453768683796
Epoch 5 loss=2.905304303601011
Epoch 6 loss=3.309683561604799
Epoch 7 loss=3.8384999270550253
Epoch 8 loss=3.643791742226925
Epoch 9 loss=3.2039173561849914
Epoch 10 loss=3.270571429564888
Epoch 11 loss=3.654928254525329
Epoch 12 loss=4.448823261521026
Epoch 13 loss=5.068537052084682
Epoch 14 loss=4.5439236683669915
Epoch 15 loss=4.039457339368953
Epoch 16 loss=3.4477388872852033
Epoch 17 loss=2.9010270219900334
Epoch 18 loss=2.633577601868524
Epoch 19 loss=2.9158517711979997
