<a href="https://colab.research.google.com/github/FengDSP/colabs/blob/main/Tensor_Parallel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import threading

In [114]:
N = 40       # number of training examples
D = 20       # input feature dimension
H = 60       # hidden dimension (split evenly across tensor-parallel ranks)
LR = 0.01
EPOCHS = 10
SEED = 0

# Seed NumPy's global RNG so the synthetic data and every later np.random draw
# (e.g. the weight shards) are reproducible under Restart & Run All.
np.random.seed(SEED)

# Labels are drawn independently of the features, so the model can only
# memorize the training set -- sufficient for comparing two training loops.
labels = np.random.randint(low=0, high=2, size=N)
features = np.random.randn(N, D)

def sigmoid(x):
  """Elementwise logistic function: 1 / (1 + e^{-x})."""
  denominator = 1 + np.exp(-x)
  return 1 / denominator

def cross_entropy(y_pred, y_true, eps=1e-12):
  """Elementwise binary cross-entropy between predicted probabilities and targets.

  Args:
    y_pred: predicted probability of the positive class (array or scalar).
    y_true: targets, broadcastable against y_pred.
    eps: y_pred is clipped to [eps, 1 - eps] before taking logs.

  Returns:
    -y_true * log(y_pred) - (1 - y_true) * log(1 - y_pred), elementwise
    (not reduced; callers take the mean).
  """
  # sigmoid saturates to exactly 0.0 or 1.0 for large |logit|. Without the clip,
  # log(0) = -inf and 0 * -inf = nan poison the loss even for a correct prediction.
  # Values already inside [eps, 1 - eps] pass through unchanged.
  y_pred = np.clip(y_pred, eps, 1 - eps)
  return - y_true * np.log(y_pred) - (1 - y_true) * np.log(1 - y_pred)

# Tensor-parallel degree: the number of simulated ranks (threads).
TP = 2
# Column parallel -> H -> Row parallel -> logit
# Shard r holds columns [r*H/TP, (r+1)*H/TP) of the first layer and the matching
# rows of the projection, so the hidden dimension must split evenly.
assert H % TP == 0

# RNG draw order matters for reproducibility: every first-layer shard is sampled
# before any projection shard.
layer_tp = [0.1 * np.random.randn(D, H // TP) for _ in range(TP)]  # TP x (D, H // TP)
bias_tp = [np.zeros(H // TP) for _ in range(TP)]                   # TP x (H // TP,)
proj_tp = [np.random.randn(H // TP) for _ in range(TP)]            # TP x (H // TP,)
# The output bias is a single value; it is replicated, not sharded.
p_bias_tp = np.zeros(1)

In [118]:
def train():
  """Single-process reference: full-batch gradient descent on the unsharded MLP.

  Reassembles the full weights by concatenating the TP shards along the hidden
  axis, then runs EPOCHS steps of manual forward/backward, printing the mean
  binary cross-entropy each epoch. The shard globals are not modified, because
  np.concatenate and .copy() produce new arrays.
  """
  layer = np.concatenate(layer_tp, axis=1)  # D, H
  bias = np.concatenate(bias_tp)            # H
  proj = np.concatenate(proj_tp)            # H
  p_bias = p_bias_tp.copy()                 # 1

  x = features  # N, D -- the same full batch every epoch
  for e in range(EPOCHS):
    # --- forward ---
    preact = x @ layer + bias               # N, H
    hidden = np.maximum(0, preact)          # ReLU, N, H
    prob = sigmoid(hidden @ proj + p_bias)  # N

    loss = np.mean(cross_entropy(prob, labels))
    print(f"Epoch {e} loss={loss}")

    # --- backward ---
    # prob - labels is the per-example d(loss)/d(logit); summing over examples
    # via the matmuls below gives the gradient of the *summed* (not mean) loss.
    d_logit = prob - labels                              # N
    d_preact = np.outer(d_logit, proj) * (preact > 0)    # N, H
    # All gradients are computed before any parameter is touched.
    grads = (
        (layer, x.T @ d_preact),          # D, H
        (bias, d_preact.sum(axis=0)),     # H
        (proj, hidden.T @ d_logit),       # H
        (p_bias, d_logit.sum()),          # scalar, broadcast into shape (1,)
    )
    for param, grad in grads:
      param -= LR * grad  # in-place SGD step on the local copies

# Single-process baseline starting from the same initial weights (concatenated
# shards); compare its per-epoch losses with the tensor-parallel run further down.
train()

Epoch 0 loss=0.9487139530194348
Epoch 1 loss=0.44234263503102617
Epoch 2 loss=0.1752792993882155
Epoch 3 loss=0.09966124766672273
Epoch 4 loss=0.07375039541089778
Epoch 5 loss=0.06169647865232828
Epoch 6 loss=0.053369447126086336
Epoch 7 loss=0.047316068306977585
Epoch 8 loss=0.042476203622106634
Epoch 9 loss=0.03858467879545534


In [119]:
class Comm(object):
  """In-process stand-in for a collective-communication group of threads.

  Each of `world_size` threads calls `all_reduce` with its local contribution.
  Every caller blocks until all have contributed, then receives its own copy of
  the element-wise sum. Contributions are summed in lock-acquisition order, so
  with more than two float contributors the result may vary in the last bits
  between runs.
  """

  def __init__(self, world_size=None):
    # Defaults to the notebook-level TP degree for backward compatibility.
    world_size = TP if world_size is None else world_size
    self.value = None
    self.reduce_barrier = threading.Barrier(world_size)  # all contributions added
    self.reset_barrier = threading.Barrier(world_size)   # all ranks have read the sum
    self.return_barrier = threading.Barrier(world_size)  # reset done before any rank re-enters
    self.lock = threading.Lock()

  def all_reduce(self, x, op=None):
    """Sum `x` across all ranks and return the total to every rank.

    Args:
      x: this rank's contribution (NumPy array or scalar); it is not modified.
      op: reduction op; only 'sum' is supported.

    Returns:
      The sum of all ranks' contributions. Array results are a fresh copy per
      rank, so mutating the result cannot affect another rank.

    Raises:
      AssertionError: if op != 'sum' or x is None.
    """
    assert op == 'sum', f'Not supporting {op} yet'
    assert x is not None
    with self.lock:
      # Out-of-place add: never write into a caller's array, and let NumPy
      # promote dtypes (e.g. int + float) instead of failing on an in-place cast.
      self.value = x if self.value is None else self.value + x
    self.reduce_barrier.wait()
    result = self.value
    # Barrier.wait() returns 0 for exactly one thread; that thread clears the
    # accumulator for the next collective.
    if self.reset_barrier.wait() == 0:
      self.value = None
    self.return_barrier.wait()
    return result.copy() if isinstance(result, np.ndarray) else result

In [120]:
def train_worker(i, comm):
  """Train rank `i`'s shards of the MLP with tensor parallelism.

  Rank `i` owns column shard `layer_tp[i]` / `bias_tp[i]` of the first layer and
  the matching row shard `proj_tp[i]` of the projection. Each epoch needs one
  collective: an all_reduce summing the partial logits. Only rank 0 adds and
  updates the output bias and prints the loss; its printed losses can be
  compared with `train()`.

  Args:
    i: this rank's index, 0 <= i < TP.
    comm: Comm shared by all ranks. Every rank must make the same sequence of
      all_reduce calls, or the barriers deadlock.
  """
  layer = layer_tp[i].copy()  # D, H // TP
  bias = bias_tp[i].copy()    # H // TP
  proj = proj_tp[i].copy()    # H // TP
  p_bias = p_bias_tp.copy()   # 1 (only applied/updated on rank 0)
  for e in range(EPOCHS):
    # forward
    x = features
    preact = x @ layer + bias  # N, H // TP
    activation = np.maximum(0, preact)  # N, H // TP
    # Row-parallel projection: each rank produces a partial logit over its hidden
    # slice; the output bias is added on exactly one rank so it is counted once.
    logits = activation @ proj  # N (partial sum)
    if i == 0:
      logits += p_bias
    logits = comm.all_reduce(logits, op='sum')  # N, full logits on every rank

    # The full logits are now replicated, so every rank derives d_logit itself.
    # No second all_reduce is needed to broadcast it from rank 0.
    prob = sigmoid(logits)
    # Gradient of the *summed* BCE w.r.t. the logits (the printed loss is the
    # mean), matching train().
    d_logit = prob - labels  # N
    if i == 0:
      # loss
      loss = cross_entropy(prob, labels).mean()
      print(f"Epoch {e} loss={loss}")
      d_p_bias = d_logit.sum()

    # backward through this rank's shards only
    d_proj = activation.T @ d_logit   # H // TP
    d_activation = d_logit.reshape(-1, 1) @ proj.reshape(1, -1)    # N, H // TP
    d_preact = d_activation * (preact > 0)  # N, H // TP
    d_bias = d_preact.sum(axis=0)  # H // TP
    d_layer = x.T @ d_preact  # D, H // TP

    layer -= LR * d_layer
    bias -= LR * d_bias
    proj -= LR * d_proj
    if i == 0:
      p_bias -= LR * d_p_bias

In [121]:
# One thread per tensor-parallel rank, all sharing a single Comm so their
# all_reduce calls rendezvous on the same barriers.
comm = Comm()
threads = [threading.Thread(target=train_worker, args=(i, comm)) for i in range(TP)]
for t in threads:
  t.start()

# Wait for every rank to finish all epochs.
for t in threads:
  t.join()

Epoch 0 loss=0.9487139530194348
Epoch 1 loss=0.44234263503102617
Epoch 2 loss=0.17527929938821551
Epoch 3 loss=0.09966124766672273
Epoch 4 loss=0.0737503954108978
Epoch 5 loss=0.06169647865232828
Epoch 6 loss=0.05336944712608632
Epoch 7 loss=0.0473160683069776
Epoch 8 loss=0.042476203622106634
Epoch 9 loss=0.03858467879545534
