<a href="https://colab.research.google.com/github/JackCaoG/torch-xla-examples/blob/main/spmd_data_parallel/data_parallel_spmd.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install torch~=2.4.0 torch_xla[tpu]~=2.4.0 -f https://storage.googleapis.com/libtpu-releases/index.html

Looking in links: https://storage.googleapis.com/libtpu-releases/index.html


In [3]:
from torch_xla import runtime as xr
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.spmd as xs

import numpy as np

import time
import itertools

import torch
import torch_xla
import torchvision
import torch.optim as optim
import torch.nn as nn

  warn(


In [4]:
# Show the installed torch and torch_xla versions, one per line.
print(torch.__version__, torch_xla.__version__, sep='\n')

2.4.0+cu121
2.4.0


In [5]:
# Runtime configuration; this cell runs before any model, loader or device
# handle is created by the cells below.
# Put the XLA runtime into SPMD mode.
xr.use_spmd()
# Turn on torch_xla's experimental eager mode. The training step itself is
# wrapped with `torch_xla.experimental.compile` in TrainResNetBase.
torch_xla.experimental.eager_mode(True)

In [6]:
# copied from https://github.com/pytorch/xla/blob/master/examples/train_resnet_base.py
class TrainResNetBase:
  """ResNet-50 training harness that runs on an XLA device with fake data.

  The constructor builds a synthetic-data loader, wraps it in an
  `MpDeviceLoader` bound to `torch_xla.device()`, and creates the model, an
  SGD optimizer, a cross-entropy loss and a compiled version of `step_fn`.
  Subclasses may overwrite `batch_size` and `train_device_loader` after
  calling `super().__init__()`.
  """

  def __init__(self):
    """Set the hyper-parameters and create loader, model, optimizer and loss."""
    self.img_dim = 224
    self.batch_size = 128
    # `train_loop_fn` stops after this many batches per epoch.
    self.num_steps = 200
    self.num_epochs = 1
    self.train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.

    # No real dataset is read: every batch is the same pair of all-zero
    # tensors (images and int64 labels).
    fake_images = torch.zeros(self.batch_size, 3, self.img_dim, self.img_dim)
    fake_labels = torch.zeros(self.batch_size, dtype=torch.int64)
    batches_per_epoch = (
        self.train_dataset_len // self.batch_size // xr.world_size())
    fake_loader = xu.SampleGenerator(
        data=(fake_images, fake_labels), sample_count=batches_per_epoch)

    self.device = torch_xla.device()
    # MpDeviceLoader wraps the host-side loader for the XLA device.
    self.train_device_loader = pl.MpDeviceLoader(fake_loader, self.device)

    resnet = torchvision.models.resnet50()
    self.model = resnet.to(self.device)
    self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
    self.loss_fn = nn.CrossEntropyLoss()
    # The training loop calls this compiled wrapper rather than `step_fn`.
    self.compiled_step_fn = torch_xla.experimental.compile(self.step_fn)

  def _train_update(self, step, loss, tracker, epoch):
    """Print one progress line with the loss and the tracker's current rate."""
    progress_line = (
        f'epoch: {epoch}, step: {step}, loss: {loss}, rate: {tracker.rate()}')
    print(progress_line)

  def run_optimizer(self):
    """Apply one optimizer update; kept separate so it can be overridden."""
    self.optimizer.step()

  def step_fn(self, data, target):
    """Run one training step and return its loss.

    Clears the gradients, runs the forward pass, computes the loss,
    back-propagates and then calls `run_optimizer`.
    """
    self.optimizer.zero_grad()
    loss = self.loss_fn(self.model(data), target)
    loss.backward()
    self.run_optimizer()
    return loss

  def train_loop_fn(self, loader, epoch):
    """Train on at most `num_steps` batches taken from `loader`.

    Every 10th step (starting at step 0) `_train_update` is registered as a
    step closure so that progress is reported.
    """
    tracker = xm.RateTracker()
    self.model.train()
    capped_batches = itertools.islice(loader, self.num_steps)
    for step, batch in enumerate(capped_batches):
      data, target = batch
      loss = self.compiled_step_fn(data, target)
      tracker.add(self.batch_size)
      if step % 10:
        continue
      xm.add_step_closure(
          self._train_update, args=(step, loss, tracker, epoch))

  def start_training(self):
    """Run all epochs, printing a timestamp before and after each one."""
    timestamp_format = '%l:%M%p %Z on %b %d, %Y'
    for epoch in range(1, self.num_epochs + 1):
      xm.master_print('Epoch {} train begin {}'.format(
          epoch, time.strftime(timestamp_format)))
      self.train_loop_fn(self.train_device_loader, epoch)
      xm.master_print('Epoch {} train end {}'.format(
          epoch, time.strftime(timestamp_format)))
    # Called once, after the last epoch has finished.
    xm.wait_device_ops()

In [7]:
# copied from https://github.com/pytorch/xla/blob/master/examples/data_parallel/train_resnet_spmd_data_parallel.py
class TrainResNetXLASpmdDDP(TrainResNetBase):
  """SPMD data-parallel variant of `TrainResNetBase`.

  After the base constructor has run, the batch size is multiplied by the
  number of runtime devices and `train_device_loader` is replaced by one
  whose inputs are sharded along the batch dimension of a 1-D device mesh.
  """

  def __init__(self):
    """Build the base trainer, then install the batch-sharded data loader."""
    super().__init__()

    # Shard along batch dimension only: all runtime devices are laid out on a
    # single mesh axis named `data`.
    device_count = xr.global_runtime_device_count()
    data_mesh = xs.Mesh(np.arange(device_count), (device_count,), ('data',))

    # scale the batch size with num_devices since there will be only one
    # process that handles all runtime devices.
    self.batch_size *= device_count

    global_images = torch.zeros(self.batch_size, 3, self.img_dim,
                                self.img_dim)
    global_labels = torch.zeros(self.batch_size, dtype=torch.int64)
    spmd_loader = xu.SampleGenerator(
        data=(global_images, global_labels),
        sample_count=self.train_dataset_len // self.batch_size)

    # Shard the input's batch dimension along the `data` axis, no sharding
    # along other dimensions.
    # NOTE(review): the spec has four entries but the label tensor is 1-D --
    # confirm how MpDeviceLoader treats tensors whose rank differs.
    batch_sharding = xs.ShardingSpec(data_mesh, ('data', None, None, None))
    self.train_device_loader = pl.MpDeviceLoader(
        spmd_loader, self.device, input_sharding=batch_sharding)

In [7]:
# Optional profiling. The snippet below is wrapped in a triple-quoted string,
# so it is not executed; to profile, remove the surrounding quotes and run the
# cell. Because the string is the cell's last expression, Jupyter echoes its
# repr as the cell output.
# Walkthrough video: https://youtu.be/40jYVhQHGEA

'''
import torch_xla.debug.profiler as xp

profile_port = 9012
profile_logdir = "/tmp/profile/"
duration_ms = 30000
server = xp.start_server(profile_port)
# Ideally you want to start the profile tracing after the initial compilation, for example
# at step 5.
xp.trace_detached(f'localhost:{profile_port}', profile_logdir, duration_ms=duration_ms)
'''

'\nimport torch_xla.debug.profiler as xp\n\nprofile_port = 9012\nprofile_logdir = "/tmp/profile/"\nduration_ms = 30000\nserver = xp.start_server(profile_port)\n# Ideally you want to start the profile tracing after the initial compilation, for example\n# at step 5.\nxp.trace_detached(f\'localhost:{profile_port}\', profile_logdir, duration_ms=duration_ms)\n'

In [8]:
# Build the SPMD data-parallel trainer (model, optimizer and sharded loader
# are created in its constructor), then run the configured training epochs.
spmd_ddp = TrainResNetXLASpmdDDP()
spmd_ddp.start_training()

Epoch 1 train begin  1:46AM UTC on Aug 22, 2024
epoch: 1, step: 0, loss: 6.922364234924316, rate: 65.64515751167285
epoch: 1, step: 10, loss: 6.912381649017334, rate: 367.0752644456346
epoch: 1, step: 20, loss: 6.902382850646973, rate: 1988.1437959961952
epoch: 1, step: 30, loss: 6.892395496368408, rate: 2664.7267835674106
epoch: 1, step: 40, loss: 6.88240909576416, rate: 2903.893582644753
epoch: 1, step: 50, loss: 6.872413158416748, rate: 3031.3210411878727
epoch: 1, step: 60, loss: 6.862427234649658, rate: 3044.8946917197063
epoch: 1, step: 70, loss: 6.852447986602783, rate: 3088.5627040900636
epoch: 1, step: 80, loss: 6.842465877532959, rate: 3071.746773442035
epoch: 1, step: 90, loss: 6.832467079162598, rate: 3089.9122107526928
epoch: 1, step: 100, loss: 6.822479248046875, rate: 3086.942309530932
epoch: 1, step: 110, loss: 6.812491416931152, rate: 3103.156164329642
epoch: 1, step: 120, loss: 6.802496910095215, rate: 3093.308634856223
epoch: 1, step: 130, loss: 6.792511463165283, ra