# Tensor Parallelism
References:
- Demystifying Tensor Parallelism (https://robotchinwag.com/posts/demystifying-tensor-parallelism/#pairwise-sharding)

Purpose: scale model training beyond a single GPU so that larger models can be trained

Approach: leverage 3D parallelism, specifically tensor parallelism here, to split layer weights across devices

Result: the pooled memory of several devices fits larger models and throughput can increase, but the all-gather and all-reduce operations add a large communication cost

Notes:
- Stopped here: https://robotchinwag.com/posts/demystifying-tensor-parallelism/#example-feed-forward-transformer-layer

In [None]:
# Attempt to implement column-wise sharding on this forward pass

import torch
import torch.nn as nn
from typing import List

class FeedForward(nn.Module):
    """Feed-forward block ``act(X @ W1) @ W2`` with a tensor-parallel demo.

    ``forward`` is the plain single-device computation. ``forward_sharded``
    computes the same result by splitting each matmul into two shards and
    recombining them with ``all_reduce`` / ``all_gather``, which are simulated
    here inside one process (no real communication takes place).

    Shape suffixes used in names: B = batch, N = sequence, D = model dim,
    H = hidden dim (``expansion * D``), K = one shard of D, J = one shard of
    the output D.
    """

    def __init__(self, D, expansion=4):
        """Create the two weight matrices.

        Args:
            D: model (input and output) dimension.
            expansion: hidden dimension is ``expansion * D``.
        """
        super().__init__()
        hidden = expansion * D
        # nn.Parameter registers the weights with the module so they appear in
        # parameters()/state_dict() and follow .to(device).
        self.fc1_W_DH = nn.Parameter(torch.randn(D, hidden))
        self.act = nn.GELU()
        self.fc2_W_HD = nn.Parameter(torch.randn(hidden, D))

    def all_reduce(self, X: List[torch.Tensor]) -> torch.Tensor:
        """Simulate a sum all-reduce over per-shard partial results.

        Args:
            X: non-empty list of same-shaped tensors, one per shard.

        Returns:
            The element-wise sum of all tensors in ``X``.

        Raises:
            ValueError: if ``X`` is empty.
        """
        if not X:
            raise ValueError("all_reduce needs at least one tensor")
        total = X[0]
        for partial in X[1:]:
            total = total + partial
        return total

    def all_gather(self, X: List[torch.Tensor]) -> torch.Tensor:
        """Simulate an all-gather of per-shard output slices.

        Args:
            X: non-empty list of tensors, one per shard, in shard order.

        Returns:
            The tensors concatenated along the last dimension.

        Raises:
            ValueError: if ``X`` is empty.
        """
        if not X:
            raise ValueError("all_gather needs at least one tensor")
        return torch.cat(X, dim=-1)

    def forward(self, X_BND):
        """Unsharded reference: ``act(X @ fc1_W_DH) @ fc2_W_HD``."""
        Z1_BNH = X_BND @ self.fc1_W_DH
        H_BNH = self.act(Z1_BNH)
        Z2_BND = H_BNH @ self.fc2_W_HD
        return Z2_BND

    def forward_sharded(self, X_BND):
        """Two-shard version of ``forward``; matches it up to float rounding.

        The first matmul is sharded row-wise (all-reduce afterwards) and the
        second column-wise (all-gather afterwards).
        """
        # Row-wise sharding of fc1: split the contraction dim D. The input's
        # last dim and the weight's rows must be cut at the same index so that
        # shard i of X multiplies shard i of W.
        split_D = X_BND.shape[-1] // 2
        X_BNK_0, X_BNK_1 = X_BND[..., :split_D], X_BND[..., split_D:]
        fc1_W_KH_0, fc1_W_KH_1 = self.fc1_W_DH[:split_D], self.fc1_W_DH[split_D:]

        # Each shard yields a full-shaped (.., H) partial sum.
        Z1_BNH_0 = X_BNK_0 @ fc1_W_KH_0
        Z1_BNH_1 = X_BNK_1 @ fc1_W_KH_1

        # GELU is non-linear, so it must see the summed pre-activations rather
        # than the per-shard partial sums.
        Z1_BNH = self.all_reduce([Z1_BNH_0, Z1_BNH_1])
        H_BNH = self.act(Z1_BNH)

        # Column-wise sharding of fc2: every shard sees the full activations
        # and produces a disjoint slice of the output columns.
        split_J = self.fc2_W_HD.shape[1] // 2
        fc2_W_HJ_0, fc2_W_HJ_1 = self.fc2_W_HD[:, :split_J], self.fc2_W_HD[:, split_J:]

        Z2_BNJ_0 = H_BNH @ fc2_W_HJ_0
        Z2_BNJ_1 = H_BNH @ fc2_W_HJ_1

        # Stitch the output slices back together along the last dim.
        Z2_BND = self.all_gather([Z2_BNJ_0, Z2_BNJ_1])
        return Z2_BND

# Smoke test: run one random batch through the layer and confirm that the
# output has the same (B, N, D) shape as the input.
B = 32
N = 128
D = 768
ff = FeedForward(D=D)
X_BND = torch.randn((B, N, D))
Y_BND = ff(X_BND)
assert tuple(Y_BND.shape) == (B, N, D)