<a href="https://colab.research.google.com/github/Papa-Panda/Paper_reading/blob/main/Gpipe.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
# !pip install fairscale

In [None]:
import torch
import torch.nn as nn
from fairscale.nn import Pipe
import torch.multiprocessing as mp
import torch.distributed as dist

class TransformerBlock(nn.Module):
    """A basic transformer block."""
    def __init__(self, embed_size, num_heads, hidden_size):
        super(TransformerBlock, self).__init__()
        self.attention = nn.MultiheadAttention(embed_size, num_heads)
        self.linear1 = nn.Linear(embed_size, hidden_size)
        self.activation = nn.ReLU()
        self.linear2 = nn.Linear(hidden_size, embed_size)

    def forward(self, x):
        attn_output, _ = self.attention(x, x, x)
        x = x + attn_output  # Residual connection
        ffn_output = self.activation(self.linear1(x))
        x = self.linear2(ffn_output) + x  # Residual connection
        return x

def build_pipeline_model():
    """Create a pipeline model with two stages."""
    embed_size = 512
    num_heads = 8
    hidden_size = 2048

    # Define two sequential transformer blocks
    layer1 = TransformerBlock(embed_size, num_heads, hidden_size)
    layer2 = TransformerBlock(embed_size, num_heads, hidden_size)

    # Combine into a sequential model
    model = nn.Sequential(layer1, layer2)

    # Use Pipe for pipeline parallelism with 2 chunks
    chunks = 2
    pipeline_model = Pipe(model, chunks=chunks, style=Pipe.MultiProcess, devices=[0, 1])

    return pipeline_model

def run_training(rank, world_size):
    """Initialize distributed environment and run training."""
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)

    # Build and distribute the model
    model = build_pipeline_model()

    # Sample input: (seq_length, batch_size, embed_size)
    input_data = torch.randn(16, 8, 512).cuda(rank)

    # Run a forward pass
    output = model(input_data)

    if rank == 0:
        print("Output shape:", output.shape)

def main():
    world_size = 2  # Simulate two GPUs
    mp.spawn(run_training, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()


W1014 01:36:49.230000 136992724959232 torch/multiprocessing/spawn.py:146] Terminating process 1235 via signal SIGTERM
