Torch-MIGraphX

Torch-MIGraphX integrates MIGraphX, AMD's graph inference engine, with the PyTorch ecosystem. It provides an mgx_module object that can be invoked like any other torch module but runs inference through the MIGraphX engine internally.

This library currently supports two lowering paths:

  1. FX Tracing: uses the tracing API provided by torch.fx.
  2. Dynamo Backend: importing torch_migraphx automatically registers the "migraphx" backend, which can then be passed to torch.compile.

Getting Started

Docker

The simplest and recommended way to get started is to use the provided Dockerfile. Build the image with:

./build_image.sh

Start a container with:

sudo docker run -it --network=host --device=/dev/kfd --device=/dev/dri --group-add=video --ipc=host --cap-add=SYS_PTRACE --security-opt seccomp=unconfined torch_migraphx

The default Dockerfile builds on top of the nightly PyTorch container and installs MIGraphX and torch_migraphx from their latest source. For other build configurations, refer to the docker directory.
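If you prefer to launch the container from a script, the sketch below rebuilds the docker run command above as an argument list. The helper name docker_run_args is hypothetical (it is not part of torch_migraphx), and the per-flag comments are general descriptions of the Docker options rather than statements from this project.

```python
import shlex


# Hypothetical helper (not part of torch_migraphx): returns the same
# `docker run` command shown above as a list of arguments.
def docker_run_args(image="torch_migraphx", use_sudo=True):
    args = ["sudo"] if use_sudo else []
    args += [
        "docker", "run", "-it",
        "--network=host",                        # use the host's network stack
        "--device=/dev/kfd",                     # ROCm compute interface (KFD driver)
        "--device=/dev/dri",                     # GPU render nodes
        "--group-add=video",                     # group access to the GPU device files
        "--ipc=host",                            # share the host IPC namespace / shared memory
        "--cap-add=SYS_PTRACE",                  # allow debuggers and profilers to attach
        "--security-opt", "seccomp=unconfined",  # disable the default seccomp profile
        image,
    ]
    return args


if __name__ == "__main__":
    cmd = docker_run_args()
    # Print a shell-safe version; pass `cmd` to subprocess.run(cmd, check=True)
    # to actually start the container.
    print(shlex.join(cmd))
```

Keeping the command as a list avoids shell quoting issues when it is handed to subprocess.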

Install From Source

Install Pre-reqs:

Build and install from source:

git clone https://github.com/ROCmSoftwarePlatform/torch_migraphx.git
cd ./torch_migraphx/py
export TORCH_CMAKE_PATH=$(python -c "import torch; print(torch.utils.cmake_prefix_path)")
pip install .
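After installation, a quick way to confirm that the packages are visible to your Python interpreter is a check like the one below. It is a sketch: check_install is a hypothetical helper, and it assumes the MIGraphX Python bindings are importable as migraphx. It only locates modules with importlib.util.find_spec and does not import them, so it will not catch runtime problems such as a missing GPU driver.

```python
import importlib.util

# Modules a working setup is expected to provide. "migraphx" is assumed to be
# the module name of the MIGraphX Python bindings.
REQUIRED_MODULES = ("torch", "migraphx", "torch_migraphx")


def check_install(modules=REQUIRED_MODULES):
    """Return {module_name: found} without actually importing the modules."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, found in check_install().items():
        print(f"{name:15s} {'found' if found else 'MISSING'}")
```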

Example Usage

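import torch
import torch_migraphx  # importing registers the "migraphx" torch.compile backend

# torch_model: a torch.nn.Module; sample_inputs: a list of example input tensors

# FX Tracing
mgx_model = torch_migraphx.fx.lower_to_mgx(torch_model, sample_inputs)

# Dynamo Backend
compiled_model = torch.compile(torch_model, backend="migraphx")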
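When the same script must also run on machines without torch_migraphx, one option is to choose the backend at runtime. The sketch below uses a hypothetical helper, select_compile_backend, that attempts the import (which is what registers the "migraphx" backend) and otherwise falls back to "inductor", torch.compile's default backend.

```python
# Hypothetical helper (not part of torch_migraphx): pick the "migraphx"
# backend when torch_migraphx can be imported, else use a fallback backend.
def select_compile_backend(fallback="inductor"):
    try:
        import torch_migraphx  # noqa: F401  (the import registers "migraphx")
    except ImportError:
        return fallback
    return "migraphx"


# Usage (requires PyTorch):
#   compiled_model = torch.compile(torch_model, backend=select_compile_backend())
```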

Lower resnet50 using FX Tracing

import torch
import torchvision
import torch_migraphx

# eval() puts BatchNorm/Dropout layers in inference mode before lowering
resnet = torchvision.models.resnet50().eval()
sample_input = torch.randn(2, 3, 64, 64)
resnet_mgx = torch_migraphx.fx.lower_to_mgx(resnet, [sample_input])
result = resnet_mgx(sample_input)

Lower densenet using torch.compile

import torch
import torchvision
import torch_migraphx

# On ROCm builds of PyTorch, the "cuda" device refers to the AMD GPU
densenet = torchvision.models.densenet161().eval().cuda()
sample_input = torch.randn(2, 3, 512, 512).cuda()
densenet_mgx = torch.compile(densenet, backend="migraphx")
result = densenet_mgx(sample_input)
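When measuring the speedup from either path, keep in mind that torch.compile compiles lazily on the first call and that GPU work is launched asynchronously. The standard-library sketch below (benchmark is a hypothetical helper, not part of torch_migraphx) excludes warm-up calls from the timing and accepts an optional synchronization callback.

```python
import statistics
import time


def benchmark(fn, *args, warmup=3, iters=10, sync=None):
    """Return the median latency of fn(*args) in milliseconds.

    warmup: untimed calls; with torch.compile the first call triggers
            compilation, so it should not be measured.
    sync:   optional callable run after every call (e.g. torch.cuda.synchronize),
            so queued GPU work is included in the measured time.
    """
    for _ in range(warmup):
        fn(*args)
        if sync is not None:
            sync()
    times_ms = []
    for _ in range(iters):
        start = time.perf_counter()
        fn(*args)
        if sync is not None:
            sync()
        times_ms.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(times_ms)
```

For example, benchmark(densenet_mgx, sample_input, sync=torch.cuda.synchronize) would time the compiled densenet from the example above. The median is used because it is less sensitive to occasional slow iterations than the mean.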

For more examples, refer to the examples directory.