# Pipeline Parallelism

## Simple Pipeline

No local SPMD or FSDP.

In [12]:
import os
# Describe an 8-rank, CPU-only topology through environment variables.
# They are assigned before `torch_xla` is imported below
# (presumably the runtime reads them at import/initialisation time; confirm).
os.environ["WORLD_SIZE"] = '8'
os.environ["RANK"] = '0'
# Request as many CPU devices as there are ranks in the emulated world.
os.environ["CPU_NUM_DEVICES"] = os.environ["WORLD_SIZE"]
os.environ["PJRT_DEVICE"] = 'CPU'

import torch_xla
# Last expression of the cell: displays the attributes of every device the
# runtime reports (eight CPU devices in the recorded output).
torch_xla.runtime.global_runtime_device_attributes()

[{'name': 'CPU:0'},
 {'name': 'CPU:1'},
 {'name': 'CPU:2'},
 {'name': 'CPU:3'},
 {'name': 'CPU:4'},
 {'name': 'CPU:5'},
 {'name': 'CPU:6'},
 {'name': 'CPU:7'}]

In [50]:
from typing import Optional

import numpy as np
import torch
from torch import nn

class SimpleLinear(nn.Module):
  """Three-layer MLP whose layers are separate attributes.

  Keeping `layer0`, `relu`, `layer1` and `layer2` as individually named
  sub-modules (rather than wrapping them in `nn.Sequential`) lets a pipeline
  split spec refer to them by name.
  """
  NUM_CLASSES = 3

  def __init__(self, input_dim):
      """Build the layers for inputs with `input_dim` features."""
      super().__init__()
      hidden_dim = input_dim // 2
      # Registration order is kept as layer0 -> relu -> layer1 -> layer2.
      self.layer0 = nn.Linear(input_dim, hidden_dim)
      self.relu = nn.ReLU()
      self.layer1 = nn.Linear(hidden_dim, 3)
      self.layer2 = nn.Linear(3, self.NUM_CLASSES)

  def forward(self, x):
      """Apply layer0, ReLU, layer1 and layer2 in that order and return the result."""
      hidden = self.relu(self.layer0(x))
      return self.layer2(self.layer1(hidden))

## Simple Pipeline - No FSDP or TP

In [47]:
# Process topology, taken from the environment variables WORLD_SIZE and RANK.
world_size, rank = int(os.environ["WORLD_SIZE"]), int(os.environ["RANK"])

class TrainingOptions():
    """Plain container for the hyper-parameters used by this notebook."""

    def __init__(self):
      # Number of samples per (full) batch.
      self.batch_size = 128
      self.num_epochs = 1
      # Learning rate handed to the optimizer.
      self.lr = 0.1
      # Log the loss every `log_steps` steps.
      self.log_steps = 8
      # Feature width of the model input.
      # NOTE(review): 16834 may be a typo for 16384 (2**14) -- confirm.
      self.input_dim = 16834
      self.train_dataset_len = 1024 * 8
      # Used both as the micro-batch count and as the "pp" mesh dimension.
      self.pipeline_chunks = 2

    def __repr__(self):
      """Return `TrainingOptions(name=value, ...)` so printing the options is informative."""
      fields = ", ".join(f"{name}={value!r}" for name, value in vars(self).items())
      return f"{type(self).__name__}({fields})"

# `SplitPoint` and `pipeline` are imported here so that this cell runs on a
# fresh kernel instead of depending on an import made by a later cell.
from torch.distributed.pipelining import SplitPoint, pipeline

opts = TrainingOptions()
device = 'cpu'

model = SimpleLinear(opts.input_dim).to(device)

# Define split points for pipeline parallelism: cut after `layer0` and after
# `layer1`, which yields three stages.
split_spec = {
  "layer0": SplitPoint.END,
  "layer1": SplitPoint.END,
}

# Create a sample input for the pipeline
chunks = opts.pipeline_chunks
batch_size = opts.batch_size
example_input = torch.randn(batch_size, opts.input_dim, device=device)

# Make sure that program is full-graph capturable:
# torch.export.export(model, (example_input,))

# Create the pipeline. Only the traced `pipe` is built here for inspection;
# the per-rank stage and schedule are left as placeholders.
pipe = pipeline(model, mb_args=(example_input,), split_spec=split_spec)
# stage = ...
#schedule = ScheduleGPipe(stage, chunks)

### Inspecting a Pipeline

Note that this pipeline contains all the individual stages of interest.
Given that it was formed from `torch.export`, it seems plausible that we can
get each sub-module as a separately traced StableHLO program, to be stitched
together.

In [49]:
print("E2E Pipeline - Contains all submodules and a coordinating function.")
print(pipe)

print("Individual stages are fx.GraphModules")
for idx in range(pipe.num_stages):
  # `print_readable` prints on its own by default and also returns the text.
  # Ask only for the text so each stage is shown once, under its header.
  stage_source = pipe.get_stage_module(idx).print_readable(print_output=False)
  print(f"Stage {idx}:\n{stage_source}")

E2E Pipeline - Contains all submodules and a coordinating function.
GraphModule(
  (submod_0): GraphModule(
    (layer0): InterpreterModule()
  )
  (submod_1): GraphModule(
    (relu): InterpreterModule()
    (layer1): InterpreterModule()
  )
  (submod_2): GraphModule(
    (layer2): InterpreterModule()
  )
)



def forward(self, x):
    submod_0 = self.submod_0(x);  x = None
    submod_1 = self.submod_1(submod_0);  submod_0 = None
    submod_2 = self.submod_2(submod_1);  submod_1 = None
    return (submod_2,)
    
# To see more debug info, please use `graph_module.print_readable()`
Individual stages are fx.GraphModules
class GraphModule(torch.nn.Module):
    def forward(self, x):
        x: "f32[128, 16834]"; 
    
        x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        # No stacktrace found for following nodes
        layer0: "f32[128, 8417]" = self.layer0(x);  x = None
        return layer0
        
    class layer0(torch.nn.Module):
        def forward(self, x: "f

# [WIP] Training Loop with Pipelining

Note that pipelining relies on `torch.export` and requires a fully functional
model.

In [None]:
import os
import sys
from typing import Optional

import numpy as np
import torch
from torch import nn
import torch.optim as optim

import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend
import torch_xla.distributed.parallel_loader as pl

import torch.distributed as dist
from torch.distributed.pipelining import ScheduleGPipe, SplitPoint, pipeline, PipelineStage
from torch.distributed.tensor import init_device_mesh, Shard, distribute_tensor
import torch_xla.distributed.spmd as xs

def train(training_options):
  """[WIP] Train `SimpleLinear` with GPipe pipeline parallelism on XLA devices.

  The function builds a global and a per-rank local SPMD mesh, initialises the
  "xla" process group, splits the model after `layer0` and `layer1`, and runs a
  `ScheduleGPipe` schedule in which this process owns the stage whose index
  equals its rank.

  Args:
    training_options: object exposing `pipeline_chunks`, `input_dim`,
      `batch_size`, `lr`, `num_epochs` and `log_steps`.

  Known gaps (left unchanged):
    * In the recorded run the notebook stops with
      `AssertionError: Number of device IDs (4) must match the global number
      of devices (8)`, apparently while building the local mesh.
    * `data_generator` is not defined in the visible part of the notebook.
    * The stage index is the rank, so this assumes the number of ranks equals
      the number of pipeline stages -- TODO confirm against the launch setup.
  """
  print(training_options)
  # Torchrun is needed for Pipeline Parallelism by default. Generally, we
  # don't need it for SPMD, and we could rely on `process_index` and
  # `addressable_runtime_device_count` from PjRT runtime. However, it would
  # be needed if we have multiple SPMD worlds within the same physical machine.
  # Hence, we retain the requirement, and can relax it later on.
  rank = xr.process_index() # int(os.environ["RANK"])
  # Read the world size here rather than relying on a notebook-level global
  # that only exists if an earlier cell happened to run.
  world_size = int(os.environ["WORLD_SIZE"])
  last_rank = world_size - 1
  print("Addressable:", xr.addressable_runtime_device_count())
  chunks = training_options.pipeline_chunks

  # Use XLA device
  device = xm.xla_device()

  print(f"Rank {rank} using device {device}")
  num_devices = xr.global_runtime_device_count()

  # -----
  # (Preferred) Leverage the DTensor/DeviceMesh variants for a more seamless
  # user interface with submeshes.
  # -----
  # global_mesh = init_device_mesh("xla", (chunks, num_devices, 1),
  #                                mesh_dim_names=("pp", "data", "model"))
  # local_mesh = global_mesh["data", "model"]

  # -----
  # Alternatively:
  # -----
  num_local_devices = xr.addressable_runtime_device_count() // chunks

  # Global submesh
  global_mesh_shape = (chunks, num_local_devices, 1)
  global_mesh = xs.Mesh(np.arange(num_devices), global_mesh_shape, ("pp", "data", "model"))

  # Local submesh: the contiguous slice of device ids owned by this rank.
  # NOTE(review): the recorded run fails with "Number of device IDs (4) must
  # match the global number of devices (8)", which matches this 4-device mesh.
  device_id_start = rank * num_local_devices
  local_device_ids = np.arange(device_id_start, device_id_start + num_local_devices)
  local_mesh_shape = global_mesh_shape[1:]
  local_mesh = xs.Mesh(local_device_ids, local_mesh_shape, ("data", "model"))
  # -----

  # Initialize process group
  dist.init_process_group(
      backend="xla",
      init_method="xla://",
      rank=rank,
      world_size=world_size
  )

  torch.manual_seed(42)
  model = SimpleLinear(training_options.input_dim).to(device)

  # Shard the model weights as needed:
  # parallelize_model(model, local_mesh)

  # Define split points for pipeline parallelism
  split_spec = {
    "layer0": SplitPoint.END,
    "layer1": SplitPoint.END,
  }

  # Create a sample input for the pipeline. `pipeline` expects its example
  # arguments in micro-batch form, i.e. one of the `chunks` slices of a batch.
  batch_size = training_options.batch_size
  example_microbatch = torch.randn(
      batch_size // chunks, training_options.input_dim, device=device)

  # Create the pipeline and respective stage for the rank. The micro-batch
  # count belongs to the schedule, not to `pipeline`.
  pipe = pipeline(model, mb_args=(example_microbatch,), split_spec=split_spec)
  stage = pipe.build_stage(rank, device)

  # The schedule owns the loss so that it can drive the backward pass through
  # every stage, not only through the last one.
  loss_fn = nn.CrossEntropyLoss()
  schedule = ScheduleGPipe(stage, chunks, loss_fn=loss_fn)

  # Training loop. Each rank optimises only the parameters of its own stage.
  losses = []
  optimizer = optim.SGD(stage.submod.parameters(), lr=training_options.lr)

  for epoch in range(training_options.num_epochs):
    # NOTE(review): `data_generator` must be provided elsewhere in the notebook.
    for step, (data, target) in enumerate(data_generator()):
      optimizer.zero_grad()
      if rank == 0:
        # The first stage feeds the full batch; the schedule splits it.
        xs.mark_sharding(data, local_mesh, ('data', 'model'))
        # or distribute_tensor(data, local_mesh, [Shard(0), Shard(1)])
        schedule.step(data)
      elif rank == last_rank:
        # The last stage supplies the targets and collects one loss per
        # micro-batch.
        microbatch_losses = []
        schedule.step(target=target, losses=microbatch_losses)
        loss = torch.stack(microbatch_losses).mean()
        losses.append(loss.clone().detach())
        if step % training_options.log_steps == 0:
          print(f"Epoch {epoch} step {step} loss {loss}")
      else:
        schedule.step()
      # Every rank applies the gradients accumulated for its stage.
      optimizer.step()
      xm.mark_step()

In [11]:
# Switch the torch_xla runtime into SPMD mode before `train` is called below.
# NOTE(review): this is a process-wide setting; confirm it is meant to run
# once per kernel session.
xr.use_spmd()

class TrainingOptions():
    """Plain container for the hyper-parameters passed to `train`."""

    def __init__(self):
      # Number of samples per (full) batch.
      self.batch_size = 128
      self.num_epochs = 1
      # Learning rate handed to the optimizer.
      self.lr = 0.1
      # Log the loss every `log_steps` steps.
      self.log_steps = 8
      # Feature width of the model input.
      # NOTE(review): 16834 may be a typo for 16384 (2**14) -- confirm.
      self.input_dim = 16834
      self.train_dataset_len = 1024 * 8
      # Used both as the micro-batch count and as the "pp" mesh dimension.
      self.pipeline_chunks = 2

    def __repr__(self):
      """Return `TrainingOptions(name=value, ...)` so printing the options is informative."""
      fields = ", ".join(f"{name}={value!r}" for name, value in vars(self).items())
      return f"{type(self).__name__}({fields})"

print('Start training loop...')
try:
  train(TrainingOptions())
finally:
  # Always release the process group, even when `train` raises, so the cell
  # can be re-run. The guard covers failures that happen before the group was
  # created, where destroying it would raise and mask the original error.
  if dist.is_initialized():
    dist.destroy_process_group()


Start training loop...
<__main__.TrainingOptions object at 0x7fcaec0f8a50>
Addressable: 8
Rank 0 using device xla:0




AssertionError: Number of device IDs (4) must match the global number of devices (8)