<a href="https://colab.research.google.com/github/JackCaoG/torch-xla-examples/blob/main/inference_basic.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
# These must be set before `torch_xla` is imported; enable them only while
# debugging. More debugging tips: https://youtu.be/LK3A3vjo-KQ?si=_7tH7pvFIDJH7NWM
os.environ.update({
    'PT_XLA_DEBUG': "1",
    'PT_XLA_DEBUG_FILE': "/tmp/pt_xla_debug.txt",
})

import torch
import torch_xla

import numpy as np
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met


# Record the exact torch / torch_xla builds this notebook was run with.
for _module in (torch, torch_xla):
  print(_module.__version__)

2.3.0+cpu
2.3.0+libtpu


In [2]:
device = torch_xla.device()

# Seed PyTorch's default RNG so the model weights and the dummy input below
# (both created before being moved to `device`) are identical on every run.
torch.manual_seed(0)

# dummy inference example, you can replace linear with your model
linear = nn.Linear(10, 20).to(device)
input = torch.randn(1, 10).to(device)

In [3]:
# Clear the metrics before the execution if you only care about the metrics
# of the current execution.
met.clear_all()

def single_step_linear(input, model=None):
  """Run one inference step and trigger its execution on the XLA device.

  Args:
    input: model input (the callers in this notebook pass tensors that were
      already moved to `device`).
    model: module to run. Defaults to the global `linear` defined above, so
      existing calls `single_step_linear(input)` behave exactly as before.

  Returns:
    The model output, after `xm.mark_step()` has been called to trigger the
    execution of this step.
  """
  if model is None:
    model = linear
  # no_grad is optional, but it is good practice for inference: no backward
  # pass follows, so autograd bookkeeping is unnecessary.
  with torch.no_grad():
    res = model(input)
  # trigger the execution (aim for exactly one execution per step)
  xm.mark_step()
  return res

res = single_step_linear(input)
print(res)

tensor([[-0.9603, -0.2537, -1.2938,  0.2158,  0.4935, -0.2579,  0.2074, -0.3205,
          0.8562, -0.5789,  0.9714, -0.4794, -0.3721,  0.4322,  0.1533, -0.3180,
         -0.1307, -0.4394, -0.5698, -0.8652]], device='xla:0')


In [4]:
# Checklist for efficient inference (the same applies to training):
# 1. Each step runs exactly one execution, which requires
#    1.1 no ops falling back to CPU, and
#    1.2 no dynamic shapes, in either the input shapes or the model code.


# For 1, look at the `ExecuteTime` and `CompileTime` metrics.
# You can also inspect `/tmp/pt_xla_debug.txt` (set via PT_XLA_DEBUG_FILE in
# the first cell) and make sure a single step produces only one `Execution` entry.
report = met.short_metrics_report()
print(report)

Metric: CompileTime
  TotalSamples: 1
  Accumulator: 324ms435.056us
  Percentiles: 1%=324ms435.056us; 5%=324ms435.056us; 10%=324ms435.056us; 20%=324ms435.056us; 50%=324ms435.056us; 80%=324ms435.056us; 90%=324ms435.056us; 95%=324ms435.056us; 99%=324ms435.056us
Metric: ExecuteTime
  TotalSamples: 1
  Accumulator: 007ms849.159us
  Percentiles: 1%=007ms849.159us; 5%=007ms849.159us; 10%=007ms849.159us; 20%=007ms849.159us; 50%=007ms849.159us; 80%=007ms849.159us; 90%=007ms849.159us; 95%=007ms849.159us; 99%=007ms849.159us
Metric: TransferFromDeviceTime
  TotalSamples: 1
  Accumulator: 004ms551.870us
  Percentiles: 1%=004ms551.870us; 5%=004ms551.870us; 10%=004ms551.870us; 20%=004ms551.870us; 50%=004ms551.870us; 80%=004ms551.870us; 90%=004ms551.870us; 95%=004ms551.870us; 99%=004ms551.870us
Counter: MarkStep
  Value: 1



In [5]:
# For 1.1, check that there are no `aten::` metrics: each one marks an op that
# fell back to CPU. `torch.nonzero` (whose output shape depends on the data)
# is used here as an example and is expected to show up as such a metric.
met.clear_all()
fallback_tensor = torch.nonzero(input)
xm.mark_step()

# wait for all async ops to finish
xm.wait_device_ops()
# If you see any `aten::` metrics, open a bug on the PyTorch/XLA GitHub and
# we will try to lower the op.
print(met.short_metrics_report())

# List the fallback counters explicitly instead of scanning the report by eye.
fallback_ops = [name for name in met.counter_names() if name.startswith('aten::')]
print(f"CPU fallback ops: {fallback_ops}")

Metric: CompileTime
  TotalSamples: 1
  Accumulator: 021ms654.651us
  Percentiles: 1%=021ms654.651us; 5%=021ms654.651us; 10%=021ms654.651us; 20%=021ms654.651us; 50%=021ms654.651us; 80%=021ms654.651us; 90%=021ms654.651us; 95%=021ms654.651us; 99%=021ms654.651us
Metric: ExecuteTime
  TotalSamples: 1
  Accumulator: 002ms673.154us
  Percentiles: 1%=002ms673.154us; 5%=002ms673.154us; 10%=002ms673.154us; 20%=002ms673.154us; 50%=002ms673.154us; 80%=002ms673.154us; 90%=002ms673.154us; 95%=002ms673.154us; 99%=002ms673.154us
Metric: TransferToDeviceTime
  TotalSamples: 1
  Accumulator: 058.402us
  Percentiles: 1%=058.402us; 5%=058.402us; 10%=058.402us; 20%=058.402us; 50%=058.402us; 80%=058.402us; 90%=058.402us; 95%=058.402us; 99%=058.402us
Metric: TransferFromDeviceTime
  TotalSamples: 1
  Accumulator: 569.034us
  Percentiles: 1%=569.034us; 5%=569.034us; 10%=569.034us; 20%=569.034us; 50%=569.034us; 80%=569.034us; 90%=569.034us; 95%=569.034us; 99%=569.034us
Counter: MarkStep
  Value: 1
Counter: at

In [5]:
# For 1.2, run the model for several steps and make sure it does not recompile:
# `CompileTime` should report at most one sample, and the remaining steps
# should be counted under `CachedCompile`.
met.clear_all()

num_steps = 10
for _step in range(num_steps):
  input = torch.randn(1, 10).to(device)
  res = single_step_linear(input)

# block until every queued asynchronous device op has finished
xm.wait_device_ops()
report = met.short_metrics_report()
print(report)

Counter: CachedCompile
  Value: 9
Metric: CompileTime
  TotalSamples: 1
  Accumulator: 026ms647.611us
  Percentiles: 1%=026ms647.611us; 5%=026ms647.611us; 10%=026ms647.611us; 20%=026ms647.611us; 50%=026ms647.611us; 80%=026ms647.611us; 90%=026ms647.611us; 95%=026ms647.611us; 99%=026ms647.611us
Metric: ExecuteTime
  TotalSamples: 10
  Accumulator: 030ms336.955us
  ValueRate: 02s196ms079.568us / second
  Rate: 723.896 / second
  Percentiles: 1%=433.567us; 5%=433.567us; 10%=502.263us; 20%=752.633us; 50%=001ms462.685us; 80%=008ms936.260us; 90%=009ms745.577us; 95%=009ms745.577us; 99%=009ms745.577us
Metric: TransferToDeviceTime
  TotalSamples: 10
  Accumulator: 506.849us
  ValueRate: 011ms641.075us / second
  Rate: 209.946 / second
  Percentiles: 1%=034.891us; 5%=034.891us; 10%=037.006us; 20%=037.275us; 50%=048.435us; 80%=054.783us; 90%=104.311us; 95%=104.311us; 99%=104.311us
Counter: MarkStep
  Value: 10



In [None]:
# If your model recompiles on every execution even though the input shape is
# fixed, it most likely has some dynamism in it. Follow
# https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#common-debugging-environment-variables-combinations
# to dump the IR or HLO, then compare the dumps across runs to see where the dynamism is coming from.

In [6]:
# Lastly, compile the model with `torch.compile` (the "openxla" backend) to
# speed up inference.
compiled_linear = torch.compile(linear, backend="openxla")

met.clear_all()

with torch.no_grad():
  # no explicit `xm.mark_step()` is needed when the model is compiled
  res = compiled_linear(input)

# block until every queued asynchronous device op has finished
xm.wait_device_ops()
report = met.short_metrics_report()
print(report)

Metric: CompileTime
  TotalSamples: 1
  Accumulator: 026ms185.527us
  Percentiles: 1%=026ms185.527us; 5%=026ms185.527us; 10%=026ms185.527us; 20%=026ms185.527us; 50%=026ms185.527us; 80%=026ms185.527us; 90%=026ms185.527us; 95%=026ms185.527us; 99%=026ms185.527us
Metric: ExecuteTime
  TotalSamples: 1
  Accumulator: 006ms484.284us
  Percentiles: 1%=006ms484.284us; 5%=006ms484.284us; 10%=006ms484.284us; 20%=006ms484.284us; 50%=006ms484.284us; 80%=006ms484.284us; 90%=006ms484.284us; 95%=006ms484.284us; 99%=006ms484.284us
Counter: MarkStep
  Value: 1



In [7]:
# now try to run inference on a slightly more complicated model
import torchvision

# `.eval()` switches layers such as BatchNorm to inference behavior; without it
# the model runs in training mode (batch statistics, running-stat updates on
# every forward). `Module.eval()` returns the module itself.
model = torchvision.models.resnet18().to(device).eval()
compiled_model = torch.compile(model, backend="openxla")

# [Batch, Channel, dim, dim]
resnet_input = torch.randn(64, 3, 224, 224).to(device)

In [11]:
met.clear_all()
# Inference only: disable autograd, as done for `single_step_linear` and
# `compiled_linear` above.
with torch.no_grad():
  for _ in range(10):
    res = compiled_model(resnet_input)

# wait for all async ops to finish
xm.wait_device_ops()
print(met.short_metrics_report())

Metric: ExecuteTime
  TotalSamples: 10
  Accumulator: 01s451ms058.179us
  ValueRate: 03s193ms215.897us / second
  Rate: 22.0061 / second
  Percentiles: 1%=053ms988.449us; 5%=053ms988.449us; 10%=101ms992.691us; 20%=149ms625.897us; 50%=150ms026.367us; 80%=197ms511.857us; 90%=201ms579.853us; 95%=201ms579.853us; 99%=201ms579.853us

