<a href="https://colab.research.google.com/github/Papa-Panda/Paper_reading/blob/main/Megatron.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
import torch.distributed as dist

def init_distributed_backend(backend="nccl"):
    """Initialize distributed backend."""
    dist.init_process_group(backend=backend)
    local_rank = dist.get_rank()
    torch.cuda.set_device(local_rank)
    return local_rank

class ParallelLinear(nn.Module):
    """Parallel Linear Layer split across two GPUs."""
    def __init__(self, input_size, output_size):
        super(ParallelLinear, self).__init__()
        self.input_size = input_size
        self.output_size = output_size

        # Split weights across two GPUs
        self.weight_1 = nn.Parameter(torch.randn(input_size, output_size // 2).cuda(0))
        self.weight_2 = nn.Parameter(torch.randn(input_size, output_size // 2).cuda(1))

        # Shared bias on GPU 0
        self.bias = nn.Parameter(torch.zeros(output_size).cuda(0))

    def forward(self, x):
        # Send input to respective GPUs
        x1 = x.to(0)
        x2 = x.to(1)

        # Parallel linear projections
        out1 = torch.matmul(x1, self.weight_1)
        out2 = torch.matmul(x2, self.weight_2)

        # Gather the results back on GPU 0
        out = torch.cat([out1, out2], dim=-1).to(0)

        # Add bias
        return out + self.bias

class ParallelMultiHeadAttention(nn.Module):
    """Parallel Multi-Head Self-Attention with split heads."""
    def __init__(self, embed_size, num_heads):
        super(ParallelMultiHeadAttention, self).__init__()
        assert embed_size % num_heads == 0, "Embed size must be divisible by num_heads"

        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads
        self.num_heads_per_gpu = num_heads // 2

        # Linear projections for Q, K, V across two GPUs
        self.qkv_proj_1 = nn.Linear(embed_size, 3 * self.num_heads_per_gpu * self.head_dim).cuda(0)
        self.qkv_proj_2 = nn.Linear(embed_size, 3 * self.num_heads_per_gpu * self.head_dim).cuda(1)

        # Output projection on GPU 0
        self.out_proj = nn.Linear(embed_size, embed_size).cuda(0)

    def forward(self, x):
        batch_size, seq_length, embed_size = x.size()

        # Distribute input across GPUs
        x1 = x.to(0)
        x2 = x.to(1)

        # Compute Q, K, V for each GPU
        qkv_1 = self.qkv_proj_1(x1)
        qkv_2 = self.qkv_proj_2(x2)

        # Split Q, K, V on each GPU
        q1, k1, v1 = torch.chunk(qkv_1, 3, dim=-1)
        q2, k2, v2 = torch.chunk(qkv_2, 3, dim=-1)

        # Attention scores and probabilities
        attn_scores_1 = torch.matmul(q1, k1.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_scores_2 = torch.matmul(q2, k2.transpose(-2, -1)) / (self.head_dim ** 0.5)

        attn_probs_1 = torch.softmax(attn_scores_1, dim=-1)
        attn_probs_2 = torch.softmax(attn_scores_2, dim=-1)

        # Apply attention to values
        attn_output_1 = torch.matmul(attn_probs_1, v1)
        attn_output_2 = torch.matmul(attn_probs_2, v2)

        # Gather outputs on GPU 0
        attn_output = torch.cat([attn_output_1.to(0), attn_output_2.to(0)], dim=-1)

        # Output projection
        output = self.out_proj(attn_output)
        return output

class ParallelTransformerLayer(nn.Module):
    """Full Transformer Layer with Parallel Attention and MLP."""
    def __init__(self, embed_size, num_heads, hidden_size):
        super(ParallelTransformerLayer, self).__init__()
        self.attention = ParallelMultiHeadAttention(embed_size, num_heads)
        self.linear1 = ParallelLinear(embed_size, hidden_size)
        self.activation = nn.ReLU()
        self.linear2 = ParallelLinear(hidden_size, embed_size)

    def forward(self, x):
        # Self-attention + Add & Norm
        attn_output = self.attention(x)
        x = x + attn_output

        # Feedforward network + Add & Norm
        ffn_output = self.linear2(self.activation(self.linear1(x)))
        return x + ffn_output

In [4]:

# def main():
local_rank = init_distributed_backend()

# Sample input: (batch_size, seq_length, embed_size)
input_data = torch.randn(8, 16, 512).cuda(local_rank)

# Create parallel transformer layer
transformer_layer = ParallelTransformerLayer(embed_size=512, num_heads=8, hidden_size=2048)

# Forward pass
output = transformer_layer(input_data)

if local_rank == 0:
    print("Output shape:", output.shape)

# if __name__ == "__main__":
#     main()

ValueError: Error initializing torch.distributed using env:// rendezvous: environment variable RANK expected, but not set