<a href="https://colab.research.google.com/github/Papa-Panda/Paper_reading/blob/main/Gpipe.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.pipeline.sync import Pipe
from torch.distributed import rpc

class TransformerBlock(nn.Module):
    """A minimal transformer block: self-attention followed by a ReLU
    feed-forward network, each wrapped in a residual connection.

    NOTE: there is no layer normalization and no dropout in this block.

    Args:
        embed_size: Model width. ``nn.MultiheadAttention`` requires it to be
            divisible by ``num_heads``.
        num_heads: Number of attention heads.
        hidden_size: Width of the feed-forward hidden layer.
        batch_first: If True, inputs and outputs are ``(batch, seq, embed)``;
            otherwise ``(seq, batch, embed)`` (the ``nn.MultiheadAttention``
            default, kept as the default here for backward compatibility).
            Batch-first layout is what pipeline wrappers that split
            micro-batches along dimension 0 expect.
    """

    def __init__(
        self,
        embed_size: int,
        num_heads: int,
        hidden_size: int,
        batch_first: bool = False,
    ):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_size, num_heads, batch_first=batch_first
        )
        self.linear1 = nn.Linear(embed_size, hidden_size)
        self.activation = nn.ReLU()
        self.linear2 = nn.Linear(hidden_size, embed_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention + feed-forward; output has the same shape as ``x``."""
        # Self-attention with a residual connection. The attention weights
        # are thrown away, so don't ask MultiheadAttention to compute them.
        attn_output, _ = self.attention(x, x, x, need_weights=False)
        x = x + attn_output

        # Position-wise feed-forward network with a residual connection.
        # It acts on the last (embedding) dimension only, so it is
        # independent of the batch/sequence layout.
        ffn_output = self.linear2(self.activation(self.linear1(x)))
        return x + ffn_output

class _SwapBatchSeq(nn.Module):
    """Swap dimensions 0 and 1, converting (batch, seq, ...) <-> (seq, batch, ...)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.transpose(0, 1)


def build_pipeline_model(
    embed_size: int = 512,
    num_heads: int = 8,
    hidden_size: int = 2048,
    chunks: int = 8,
) -> nn.Module:
    """Create a two-stage pipeline model: one TransformerBlock on ``cuda:0``
    and one on ``cuda:1``, wrapped in ``Pipe``.

    The returned module takes **batch-first** input of shape
    ``(batch, seq, embed_size)`` located on ``cuda:0``. ``Pipe`` splits its
    input into ``chunks`` micro-batches along dimension 0, so dimension 0
    must be the batch. Splitting a sequence-first tensor would cut every
    sequence into fragments, and attention would only see a fragment.
    ``TransformerBlock`` runs sequence-first by default, so the first stage
    swaps to ``(seq, batch, embed)`` and the last stage swaps back.

    NOTE: ``torch.distributed.pipeline.sync`` is deprecated and has been
    removed from recent PyTorch releases (superseded by
    ``torch.distributed.pipelining``).

    Args:
        embed_size: Model width (must be divisible by ``num_heads``).
        num_heads: Attention heads per block.
        hidden_size: Feed-forward hidden width per block.
        chunks: Number of micro-batches ``Pipe`` splits each batch into.

    Returns:
        A ``Pipe`` module. Calling it returns an ``RRef``; use
        ``.local_value()`` to obtain the output tensor, which is located on
        ``cuda:1`` with shape ``(batch, seq, embed_size)``.

    Raises:
        RuntimeError: if fewer than two CUDA devices are visible.
    """
    n_gpus = torch.cuda.device_count()
    if n_gpus < 2:
        raise RuntimeError(
            f"build_pipeline_model needs at least 2 CUDA devices, found {n_gpus}"
        )

    # Pipe requires every child of the outer Sequential to keep all of its
    # parameters on a single device. Each stage is therefore its own
    # nn.Sequential. The parameterless transposes live inside a stage,
    # because a bare parameterless child would not be tied to a GPU.
    stage0 = nn.Sequential(
        _SwapBatchSeq(),  # (batch, seq, embed) -> (seq, batch, embed)
        TransformerBlock(embed_size, num_heads, hidden_size),
    ).cuda(0)
    stage1 = nn.Sequential(
        TransformerBlock(embed_size, num_heads, hidden_size),
        _SwapBatchSeq(),  # back to (batch, seq, embed) for re-concatenation on dim 0
    ).cuda(1)

    return Pipe(nn.Sequential(stage0, stage1), chunks=chunks)


def main():
    """Run one forward pass of the two-GPU pipeline on random data and
    print the output shape."""
    import os  # only the entry point needs environment configuration

    # Pipe only needs the RPC framework (a single local worker); no
    # torch.distributed process group is required. init_rpc's default
    # env:// rendezvous reads MASTER_ADDR / MASTER_PORT, so provide
    # defaults that let the script run without a launcher such as torchrun.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    rpc.init_rpc(name="worker", rank=0, world_size=1)
    try:
        model = build_pipeline_model()

        # Batch-first sample input: (batch_size=8, sequence_length=16,
        # embed_size=512), placed on the first stage's device.
        input_data = torch.randn(8, 16, 512).cuda(0)

        # Pipe.forward returns an RRef; local_value() yields the tensor.
        output = model(input_data).local_value()
        print("Output shape:", output.shape)
    finally:
        # Always tear down the RPC agent, even if the forward pass fails.
        rpc.shutdown()


if __name__ == "__main__":
    main()
