# Tensor Parallelism
References:
- Demystifying Tensor Parallelism (https://robotchinwag.com/posts/demystifying-tensor-parallelism/#pairwise-sharding)

Purpose: scale model training beyond a single GPU so that larger models can be trained

Approach: leverage 3D parallelism, specifically tensor parallelism here, to split layer weights across devices

Result: each device holds only a shard of the weights, so total memory capacity grows and throughput can improve, but the all-gather and all-reduce operations carry a large communication cost

Definitions:
- pairwise sharding = each device (or rank) is assigned a *pair* of weight shards, $(A_i, B_i)$, and the operations between them run locally on the device (no collective operation is needed until the end). Usually $A_i$ is a column shard and $B_i$ is a row shard, so a single all-reduce is needed after the second matmul. In the backward pass the transposes flip the sharding ($B_i^T$ is column-sharded, $A_i^T$ is row-sharded), so it also ends in an all-reduce: the backward gradients are pairwise sharded too
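- forward, row-sharded weight = each rank produces a full-shape partial sum, so an all-reduce (sum) is needed at the end to combine the partial outputs
- backward w.r.t. X of that row-sharded layer = the transposed weight is col-sharded, so each rank produces a slice of the derivative and an all-gather is needed at the end to combine the partial derivatives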
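
A minimal single-process sketch of the two cases above, assuming a 2-way split simulated with plain tensor slices. The names `X`, `A` and the shapes are illustrative only, and a sum / `torch.cat` stand in for the real all-reduce / all-gather collectives.

```python
import torch

torch.manual_seed(0)
X = torch.randn(4, 6, dtype=torch.float64)  # (batch, d_in)
A = torch.randn(6, 8, dtype=torch.float64)  # (d_in, d_out)
Y_full = X @ A

# Row-sharded A: each simulated rank holds half of A's rows and the matching
# columns of X. Every rank produces a full-shape partial sum, so the partials
# are combined with a sum (all-reduce).
Y_row = X[:, :3] @ A[:3, :] + X[:, 3:] @ A[3:, :]

# Column-sharded A: each simulated rank holds half of A's columns and the full X.
# Every rank produces a slice of the output columns, so the slices are combined
# with a concatenation (all-gather).
Y_col = torch.cat([X @ A[:, :4], X @ A[:, 4:]], dim=-1)

assert torch.allclose(Y_full, Y_row)
assert torch.allclose(Y_full, Y_col)
```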

Notes:
- V1.1 - it looks like we use an all-reduce whenever the full result is needed as input to something like an activation function (ReLU), and an all-gather otherwise?
- V1.2 - actually, because ReLU is a pointwise operation, we can skip the collective after the first matmul: shard the first weight col-wise, apply ReLU to each local shard, shard the second weight row-wise, and finish with a single all-reduce. This saves 1 collective operation, which is important to reduce networking bottlenecks.
- This pattern of column-wise sharding -> row-wise sharding is "pairwise sharding"; it removes the need for an all-gather between the two matmuls. It can be applied to the attention layer in a Transformer model as well (Q: how?)
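
A small sketch of why the pointwise ReLU makes this work, under the same simulated 2-way split (the tensors `Z0`, `Z1`, `P0`, `P1` are made-up illustrations). Column shards are disjoint output slices, so ReLU can run per shard; row shards are partial sums, so ReLU would have to wait for the reduced result.

```python
import torch

torch.manual_seed(0)

# Column-sharded first matmul: each rank holds a disjoint slice of the output
# columns, so applying ReLU per shard equals applying it to the gathered tensor.
Z0 = torch.randn(4, 3)
Z1 = torch.randn(4, 3)
gathered = torch.cat([Z0, Z1], dim=-1)
relu_commutes_with_concat = torch.equal(
    torch.relu(gathered),
    torch.cat([torch.relu(Z0), torch.relu(Z1)], dim=-1),
)

# Row-sharded first matmul: each rank holds a partial sum of the same entries,
# and ReLU does not commute with the sum.
P0 = torch.tensor([[1.0, -2.0]])
P1 = torch.tensor([[-3.0, 5.0]])
relu_after_sum = torch.relu(P0 + P1)               # ReLU([-2, 3]) = [0, 3]
relu_before_sum = torch.relu(P0) + torch.relu(P1)  # [1, 0] + [0, 5] = [1, 5]

assert relu_commutes_with_concat
assert not torch.equal(relu_after_sum, relu_before_sum)
```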


Gradients of tensor parallel layer
- the backwards of a pairwise sharded layer is also pairwise sharded, because we operate in reverse order and the sharding is flipped because of transposes
- a forward col-wise sharded matmul (whose output slices are combined with an all-gather) becomes row-wise sharded in backprop, which uses an all-reduce to sum the partial gradients
- Q: Why transpose A in the backprop computation of a matmul?
    - A: Y = XA, X' = Y' @ A^T. Each row j of A tells us how much dimension j of X contributes to each output dimension of Y. When we transpose, row j becomes column j of A^T, so given the gradients for each dimension of Y, the gradient for dimension j of X is the sum of those gradients weighted by the contributions listed in column j of A^T.
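
As a quick check of the points above, the sketch below compares `X' = Y' @ A^T` against autograd and then recomputes it from two column shards of `A`. The shapes and the random upstream gradient `dY` are assumptions made for illustration; the final sum stands in for the all-reduce.

```python
import torch

torch.manual_seed(0)
X = torch.randn(4, 6, dtype=torch.float64, requires_grad=True)
A = torch.randn(6, 8, dtype=torch.float64)
dY = torch.randn(4, 8, dtype=torch.float64)  # stand-in for the upstream gradient Y'

# Reference gradient from autograd.
Y = X @ A
Y.backward(dY)

# Manual formula: X' = Y' @ A^T
dX_manual = dY @ A.T

# Forward col-sharded A: rank i holds A[:, cols_i] and receives dY[:, cols_i].
# In the backward pass A_i^T is a row shard, so each rank produces a full-shape
# partial gradient and the partials are summed (all-reduce).
dX_0 = dY[:, :4] @ A[:, :4].T
dX_1 = dY[:, 4:] @ A[:, 4:].T
dX_sharded = dX_0 + dX_1

assert torch.allclose(X.grad, dX_manual)
assert torch.allclose(X.grad, dX_sharded)
```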

Action items:
- COMPLETE - Q: we know that the gradient w.r.t. X uses flipped sharding, because the forward is Y = XA and the backward is X' = Y' @ A^T, but what about the gradients of the weights themselves?
Well actually, no collective operations are needed to update the weights - each rank/node computes the gradient of its own weight shard from tensors it already holds and updates that shard locally.
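
A sketch of that claim under the same simulated 2-way split (shapes and names are illustrative assumptions): the locally computed shard gradients match the corresponding slices of the full weight gradient from autograd, for both the col-sharded and the row-sharded weight.

```python
import torch

torch.manual_seed(0)

# First linear layer, weight col-sharded: Y = X @ A, so A' = X^T @ Y'.
X = torch.randn(4, 6, dtype=torch.float64)
A = torch.randn(6, 8, dtype=torch.float64, requires_grad=True)
dY = torch.randn(4, 8, dtype=torch.float64)  # stand-in upstream gradient
(X @ A).backward(dY)

# Rank i holds A[:, cols_i], the full X, and its own slice dY[:, cols_i].
dA_0 = X.T @ dY[:, :4]
dA_1 = X.T @ dY[:, 4:]
assert torch.allclose(A.grad[:, :4], dA_0)
assert torch.allclose(A.grad[:, 4:], dA_1)

# Second linear layer, weight row-sharded: Out = H @ B, so B' = H^T @ Out'.
H = torch.randn(4, 8, dtype=torch.float64)
B = torch.randn(8, 6, dtype=torch.float64, requires_grad=True)
dOut = torch.randn(4, 6, dtype=torch.float64)  # replicated on every rank
(H @ B).backward(dOut)

# Rank i holds B[rows_i, :] and the matching activation columns H[:, rows_i].
dB_0 = H[:, :4].T @ dOut
dB_1 = H[:, 4:].T @ dOut
assert torch.allclose(B.grad[:4, :], dB_0)
assert torch.allclose(B.grad[4:, :], dB_1)
```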

- TODO - Q: How to improve the pairwise sharding implementation so it's a bit more comprehensive?
Mostly done; it should be pretty good conceptually. The only remaining step is to actually use torch.distributed.all_reduce instead of the simulated self.all_reduce().
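
One possible shape for that next step, as a hedged sketch rather than a finished implementation. The helpers `all_reduce_sum` and `row_parallel_matmul` are hypothetical names, and the fallback to a plain copy when no process group is initialized is an assumption added so the snippet also runs in a single process.

```python
import torch
import torch.distributed as dist


def all_reduce_sum(partial: torch.Tensor) -> torch.Tensor:
    """Sum `partial` across ranks (hypothetical helper).

    If no process group is initialized, return a copy of the local partial
    so the function can still be exercised in one process.
    """
    out = partial.clone()
    if dist.is_available() and dist.is_initialized():
        # In-place and blocking by default: returns once the reduction is done.
        dist.all_reduce(out, op=dist.ReduceOp.SUM)
    return out


def row_parallel_matmul(H_local: torch.Tensor, W_local: torch.Tensor) -> torch.Tensor:
    """Second matmul of the pair: local row shard of W, then all-reduce."""
    return all_reduce_sum(H_local @ W_local)


# Single-process demonstration: without a process group each call returns the
# local partial, so the sum is done by hand here.
torch.manual_seed(0)
H_act = torch.randn(4, 8, dtype=torch.float64)
W = torch.randn(8, 6, dtype=torch.float64)
partials = [
    row_parallel_matmul(H_act[:, :4], W[:4]),
    row_parallel_matmul(H_act[:, 4:], W[4:]),
]
Y_sum = partials[0] + partials[1]
assert torch.allclose(Y_sum, H_act @ W)
```

In a real multi-process launch, each rank would first call `torch.distributed.init_process_group` and hold only its own shard, so the manual sum at the end would disappear.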

- TODO - Q: How can we apply pairwise sharding in the attention layer?
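
One common answer, stated here as an assumption rather than something derived in these notes, is to shard by attention heads: the Q/K/V projections are col-sharded so each rank owns a subset of heads, and the output projection is row-sharded, followed by one all-reduce. The sketch below simulates that with 2 ranks; `attention`, the shapes, and the lack of masking and biases are simplifications for illustration.

```python
import torch

torch.manual_seed(0)
B, N, D, n_heads = 2, 5, 16, 4
X = torch.randn(B, N, D, dtype=torch.float64)
Wq, Wk, Wv, Wo = (torch.randn(D, D, dtype=torch.float64) / D**0.5 for _ in range(4))


def attention(X, Wq, Wk, Wv, Wo, heads):
    """Multi-head self-attention (no mask, no bias) over `heads` heads."""
    B, N, _ = X.shape

    def split_heads(t):  # (B, N, heads * d_head) -> (B, heads, N, d_head)
        return t.reshape(B, N, heads, -1).transpose(1, 2)

    Q, K, V = split_heads(X @ Wq), split_heads(X @ Wk), split_heads(X @ Wv)
    scores = Q @ K.transpose(-1, -2) / Q.shape[-1] ** 0.5
    out = torch.softmax(scores, dim=-1) @ V      # (B, heads, N, d_head)
    out = out.transpose(1, 2).reshape(B, N, -1)  # concatenate the heads
    return out @ Wo


Y_full = attention(X, Wq, Wk, Wv, Wo, heads=n_heads)

# Simulated 2-way tensor parallelism: each rank gets half of the heads.
# Q/K/V weights are col-sharded, the output projection is row-sharded.
half = D // 2
Y_0 = attention(X, Wq[:, :half], Wk[:, :half], Wv[:, :half], Wo[:half, :], heads=n_heads // 2)
Y_1 = attention(X, Wq[:, half:], Wk[:, half:], Wv[:, half:], Wo[half:, :], heads=n_heads // 2)
Y_sharded = Y_0 + Y_1  # stands in for the all-reduce

assert torch.allclose(Y_full, Y_sharded)
```

The softmax runs per head, so no communication is needed between the two matmuls, just as with the pointwise ReLU in the feed-forward case.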


In [None]:
# Pairwise sharding - col-wise weight sharding -> row-wise weight sharding -> all-reduce
# Assumes 2-way tensor parallel

import torch
import torch.nn as nn
from typing import List
torch.manual_seed(0)

class FeedForward(nn.Module):

    def __init__(self, D, expansion=4):
        super().__init__()
        hidden = expansion * D
        self.fc1_W_DH = torch.randn(D, hidden)
        self.act = nn.ReLU()
        self.fc2_W_HD = torch.randn(hidden, D)

    def all_reduce(self, parts: List[torch.Tensor]):
        # Simulates an all-reduce (sum) across tensor-parallel ranks.
        # In real code each rank would hold only its local partial, e.g.:
        #   out = Z2_BND_local
        #   torch.distributed.all_reduce(out, op=torch.distributed.ReduceOp.SUM)
        # This is a blocking collective by default: it waits until all ranks
        # reach the call and the reduction is finished.
        return parts[0] + parts[1]

    def all_gather(self, parts: List[torch.Tensor]):
        # Simulates an all-gather: column shards of a (B, N, L) activation
        # are concatenated along the last (feature) dimension.
        return torch.cat(parts, dim=-1)

    def forward(self, X_BND):
        Z1_BNH = X_BND @ self.fc1_W_DH # Z1 = X @ A
        H_BNH = self.act(Z1_BNH) # H = ReLU(Z1)
        Z2_BND = H_BNH @ self.fc2_W_HD # Y = H @ B
        return Z2_BND
    
    def forward_pairwise_sharded(self, X_BND):
        """
        weight col-wise sharding -> ReLU -> weight row-wise sharding -> all-reduce
        which is more optimal than:
        weight row-wise sharding -> all-reduce -> ReLU -> weight col-wise sharding -> all-gather
        """
        # tracking shapes
        B, N, D = X_BND.shape
        _, H = self.fc1_W_DH.shape
        assert H % 2 == 0, "H must be divisible by 2"
        L = H // 2  # hidden columns (fc1) / rows (fc2) per shard

        # weight col-wise sharding
        fc1_W_DL_0, fc1_W_DL_1 = self.fc1_W_DH[:, :L], self.fc1_W_DH[:, L:]

        Z1_BNL_0 = X_BND @ fc1_W_DL_0
        Z1_BNL_1 = X_BND @ fc1_W_DL_1

        # ReLU
        H_BNL_0 = self.act(Z1_BNL_0)
        H_BNL_1 = self.act(Z1_BNL_1)

        # weight row-wise sharding
        fc2_W_LD_0, fc2_W_LD_1 = self.fc2_W_HD[:L, :], self.fc2_W_HD[L:, :]
        Z2_BND_0 = H_BNL_0 @ fc2_W_LD_0
        Z2_BND_1 = H_BNL_1 @ fc2_W_LD_1

        # all-reduce
        Z2_BND = self.all_reduce([Z2_BND_0, Z2_BND_1])

        return Z2_BND
        

B, N, D = 32, 128, 768
ff = FeedForward(D=D)
X_BND = torch.randn(B, N, D) # (B,N,D)
Y_BND = ff(X_BND)
Y_BND_pairwise = ff.forward_pairwise_sharded(X_BND)
assert Y_BND.shape == (B, N, D)
assert Y_BND.shape == Y_BND_pairwise.shape
assert torch.allclose(Y_BND, Y_BND_pairwise, rtol=1e-2, atol=1e-4) # FP32 has slight rounding differences so we increase the tolerance here

In [9]:
# Pairwise sharding - col-wise weight sharding -> row-wise weight sharding -> all-reduce
# Assumes 2-way tensor parallel
# Adding a backward function now, without biases for simplicity

import torch
import torch.nn as nn
from typing import List
torch.manual_seed(0)

class FeedForward(nn.Module):

    def __init__(self, D, expansion=4):
        super().__init__()
        hidden = expansion * D
        self.fc1_W_DH = torch.randn(D, hidden)
        self.act = nn.ReLU() # using ReLU for easier derivation
        self.fc2_W_HD = torch.randn(hidden, D)

        self.cache = None
        self.grads = {}

    def all_reduce(self, parts: List[torch.Tensor]):
        # Simulates an all-reduce (sum) across tensor-parallel ranks.
        # In real code each rank would hold only its local partial, e.g.:
        #   out = Z2_BND_local
        #   torch.distributed.all_reduce(out, op=torch.distributed.ReduceOp.SUM)
        # This is a blocking collective by default: it waits until all ranks
        # reach the call and the reduction is finished.
        return sum(parts)

    def all_gather(self, parts: List[torch.Tensor]):
        # Simulates an all-gather: column shards of a (B, N, L) activation
        # are concatenated along the last (feature) dimension.
        return torch.cat(parts, dim=-1)

    def forward(self, X_BND):
        Z1_BNH = X_BND @ self.fc1_W_DH # Z1 = X @ A
        H_BNH = self.act(Z1_BNH) # H = ReLU(Z1)
        Z2_BND = H_BNH @ self.fc2_W_HD # Y = H @ B

        self.cache = (X_BND, Z1_BNH, H_BNH)
        return Z2_BND
    
    def backward(self, dY_BND):
        """
        Assume dY_BND = dL/dY_BND, the gradient of the (differentiable) loss w.r.t. the layer output.
        """
        X_BND, Z1_BNH, H1_BNH = self.cache

        dW2_HD = torch.einsum("bnd,bnh->hd", dY_BND, H1_BNH)
        dH1_BNH = torch.einsum("bnd,hd->bnh", dY_BND, self.fc2_W_HD)

        # derivative of ReLU
        dZ1_BNH = dH1_BNH * (Z1_BNH > 0).to(dH1_BNH.dtype)

        dW1_DH = torch.einsum("bnh,bnd->dh", dZ1_BNH, X_BND)
        dX_BND = torch.einsum("bnh,dh->bnd", dZ1_BNH, self.fc1_W_DH)

        self.grads = {
            "dW1_DH": dW1_DH,
            "dW2_HD": dW2_HD,
        }
        
        return dX_BND

    def forward_pairwise_sharded(self, X_BND):
        """
        weight col-wise sharding -> ReLU -> weight row-wise sharding -> all-reduce
        which is more optimal than:
        weight row-wise sharding -> all-reduce -> ReLU -> weight col-wise sharding -> all-gather
        """
        # tracking shapes
        B, N, D = X_BND.shape
        _, H = self.fc1_W_DH.shape
        assert H % 2 == 0, "H must be divisible by 2"
        L = H // 2  # hidden columns (fc1) / rows (fc2) per shard

        # weight col-wise sharding
        fc1_W_DL_0, fc1_W_DL_1 = self.fc1_W_DH[:, :L], self.fc1_W_DH[:, L:]

        Z1_BNL_0 = X_BND @ fc1_W_DL_0
        Z1_BNL_1 = X_BND @ fc1_W_DL_1

        # ReLU
        H_BNL_0 = self.act(Z1_BNL_0)
        H_BNL_1 = self.act(Z1_BNL_1)

        # weight row-wise sharding
        fc2_W_LD_0, fc2_W_LD_1 = self.fc2_W_HD[:L, :], self.fc2_W_HD[L:, :]
        Z2_BND_0 = H_BNL_0 @ fc2_W_LD_0
        Z2_BND_1 = H_BNL_1 @ fc2_W_LD_1

        # all-reduce
        Z2_BND = self.all_reduce([Z2_BND_0, Z2_BND_1])

        self.cache = X_BND, Z1_BNL_0, Z1_BNL_1, H_BNL_0, H_BNL_1

        return Z2_BND
    
    def backward_pairwise_sharded(self, dY_BND):
        """
        Assume dY_BND = dL/dY_BND, the gradient of the (differentiable) loss w.r.t. the layer output.

        Backward gradients (@ is a matmul, * is elementwise):
        dL/dX = ((dL/dY @ B_0^T) * ReLU'(F)) @ A_0^T + ((dL/dY @ B_1^T) * ReLU'(G)) @ A_1^T
        - final operation is an all-reduce because we're summing the per-shard gradients

        where the pairwise sharded forward pass is (X_BND is replicated on both ranks):

        X_BND -> F_BNL = X_BND@A_0_DL -> M_BNL = ReLU(F_BNL) -> K_BND = M_BNL@B_0_LD -> Y_BND = K_BND + H_BND
             |-> G_BNL = X_BND@A_1_DL -> N_BNL = ReLU(G_BNL) -> H_BND = N_BNL@B_1_LD ->|

        Where dim_L = H // 2, A = fc1_W_DH and B = fc2_W_HD
        A_DH is split col-wise into A_0_DL, A_1_DL
        B_HD is split row-wise into B_0_LD, B_1_LD
        """
        X_BND, F_BNL, G_BNL, M_BNL, N_BNL = self.cache

        # define dimension shapes
        B, N, D = dY_BND.shape
        _, H = self.fc1_W_DH.shape
        assert H % 2 == 0, "H must be divisible by 2"
        L = H // 2

        # define the split weight shards
        A_DH = self.fc1_W_DH
        B_HD = self.fc2_W_HD

        A_0_DL, A_1_DL = A_DH[:, :L], A_DH[:, L:]
        B_0_LD, B_1_LD = B_HD[:L, :], B_HD[L:, :]

        # compute gradients for second linear pass per shard
        dB_0_LD = torch.einsum("bnd,bnl->ld", dY_BND, M_BNL)
        dM_BNL = torch.einsum("bnd,ld->bnl", dY_BND, B_0_LD)

        dB_1_LD = torch.einsum("bnd,bnl->ld", dY_BND, N_BNL)
        dN_BNL = torch.einsum("bnd,ld->bnl", dY_BND, B_1_LD)

        # compute gradients for ReLU
        dF_BNL = dM_BNL * (F_BNL > 0).to(dM_BNL.dtype)
        dG_BNL = dN_BNL * (G_BNL > 0).to(dN_BNL.dtype)

        # compute gradients for first linear pass per shard
        dA_0_DL = torch.einsum("bnl,bnd->dl", dF_BNL, X_BND)
        dX_0_BND = torch.einsum("bnl,dl->bnd", dF_BNL, A_0_DL)

        dA_1_DL = torch.einsum("bnl,bnd->dl", dG_BNL, X_BND)
        dX_1_BND = torch.einsum("bnl,dl->bnd", dG_BNL, A_1_DL)

        # all reduce dX_BND
        dX_BND = self.all_reduce([dX_0_BND, dX_1_BND])

        self.grads = {
            # gradients for shards on device 0
            "dB_0_LD": dB_0_LD,
            "dA_0_DL": dA_0_DL,

            # gradients for shards on device 1
            "dB_1_LD": dB_1_LD,
            "dA_1_DL": dA_1_DL,
        }

        return dX_BND
        

B, N, D = 32, 128, 768
ff = FeedForward(D=D)

# test forward pass + forward pairwise sharded are equal
X_BND = torch.randn(B, N, D) # (B,N,D)
Y_BND = ff(X_BND)
Y_BND_pairwise = ff.forward_pairwise_sharded(X_BND)
assert Y_BND.shape == (B, N, D)
assert Y_BND.shape == Y_BND_pairwise.shape
assert torch.allclose(Y_BND, Y_BND_pairwise, rtol=1e-2, atol=1e-4) # FP32 has slight rounding differences so we increase the tolerance here

# test backward pass and backward pairwise sharded are equal
dY_BND = torch.randn(B, N, D) # random upstream gradient; any values work for comparing the two backward implementations
Y_BND = ff(X_BND)
dX_BND = ff.backward(dY_BND)

Y_BND_pairwise = ff.forward_pairwise_sharded(X_BND)
dX_BND_pairwise = ff.backward_pairwise_sharded(dY_BND)

assert dX_BND.shape == (B, N, D)
assert dX_BND_pairwise.shape == (B, N, D)
assert torch.allclose(dX_BND, dX_BND_pairwise, rtol=1e-2, atol=1e-4)
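
The checks above compare the two manual backward passes with each other. As an extra sanity check, the sketch below compares the same einsum formulas against autograd on a small stand-alone example; the tiny shapes and the float64 dtype are assumptions chosen to keep the comparison tight.

```python
import torch

torch.manual_seed(0)
B, N, D, H = 2, 3, 4, 8
X = torch.randn(B, N, D, dtype=torch.float64, requires_grad=True)
W1 = torch.randn(D, H, dtype=torch.float64, requires_grad=True)
W2 = torch.randn(H, D, dtype=torch.float64, requires_grad=True)
dY = torch.randn(B, N, D, dtype=torch.float64)  # stand-in upstream gradient

# Forward pass with autograd tracking, then backprop the upstream gradient.
Z1 = X @ W1
H1 = torch.relu(Z1)
Y = H1 @ W2
Y.backward(dY)

# Manual backward pass using the same einsum formulas as FeedForward.backward.
with torch.no_grad():
    dW2 = torch.einsum("bnd,bnh->hd", dY, H1)
    dH1 = torch.einsum("bnd,hd->bnh", dY, W2)
    dZ1 = dH1 * (Z1 > 0).to(dH1.dtype)  # derivative of ReLU
    dW1 = torch.einsum("bnh,bnd->dh", dZ1, X)
    dX = torch.einsum("bnh,dh->bnd", dZ1, W1)

assert torch.allclose(W1.grad, dW1)
assert torch.allclose(W2.grad, dW2)
assert torch.allclose(X.grad, dX)
```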