From 338aac5e67efbf12d502457b5732c320c0be28e0 Mon Sep 17 00:00:00 2001
From: Paul Berg <9824244+Pangoraw@users.noreply.github.com>
Date: Thu, 30 Jan 2025 09:38:53 +0100
Subject: [PATCH 1/2] Detect TPU using PCI devices

Add `has_tpu()`, which scans /sys/bus/pci/devices for a Google TPU by
PCI vendor/device id, and use it in `XLA.__init__` in place of the
check for /usr/lib/libtpu.so.

---
 src/Devices.jl | 49 +++++++++++++++++++++++++++++++++++++++++++++++++
 src/XLA.jl     |  2 +-
 2 files changed, 50 insertions(+), 1 deletion(-)

diff --git a/src/Devices.jl b/src/Devices.jl
index 5eaa998eff..ea2443386c 100644
--- a/src/Devices.jl
+++ b/src/Devices.jl
@@ -23,3 +23,52 @@ function addressable_devices(client::XLA.Client=XLA.default_backend[])
     ndevices = XLA.ClientNumAddressableDevices(client)
     return [XLA.ClientGetAddressableDevice(client, i - 1) for i in 1:ndevices]
 end
+
+const _GOOGLE_PCI_VENDOR_ID = "0x1ae0"
+const _TPU_PCI_DEVICE_IDS = (
+    # TPU v2, v3
+    "0x0027",
+    # No public name (plc)
+    "0x0056",
+    # TPU v4
+    "0x005e",
+    # TPU v5p
+    "0x0062",
+    # TPU v5e
+    "0x0063",
+    # TPU v6e
+    "0x006f",
+)
+
+"""
+    has_tpu()
+
+Return `true` if a Google TPU is found among this machine's PCI devices.
+
+Scans `/sys/bus/pci/devices/` for an entry whose `vendor` file matches
+`_GOOGLE_PCI_VENDOR_ID` and whose `device` file is in `_TPU_PCI_DEVICE_IDS`.
+Returns `false` on non-Linux systems, when that directory does not exist, or
+when reading it fails (the failure is logged as a warning at most once).
+"""
+function has_tpu()
+    Sys.islinux() || return false
+
+    devices_dir = "/sys/bus/pci/devices/"
+    isdir(devices_dir) || return false
+
+    try
+        for path in readdir(devices_dir; join=true, sort=false)
+            # Strip surrounding whitespace (e.g. a trailing newline) before comparing.
+            if strip(read(joinpath(path, "vendor"), String)) == _GOOGLE_PCI_VENDOR_ID &&
+                strip(read(joinpath(path, "device"), String)) in _TPU_PCI_DEVICE_IDS
+                return true
+            end
+        end
+    catch ex
+        @warn "failed to query PCI device information" maxlog = 1 exception = (
+            ex, catch_backtrace()
+        )
+    end
+
+    return false
+end
diff --git a/src/XLA.jl b/src/XLA.jl
index 224a1deeb8..0dba6b3c75 100644
--- a/src/XLA.jl
+++ b/src/XLA.jl
@@ -148,7 +148,7 @@ function __init__()
     end
 
     @static if !Sys.isapple()
-        if isfile("/usr/lib/libtpu.so")
+        if Reactant.has_tpu()
             dataset_dir = @get_scratch!("libtpu")
             if !isfile(dataset_dir * "/libtpu.so")
                 Downloads.download(

From dd39fdf124ba17e69abc446399ba0253d027a0eb Mon Sep 17 00:00:00 2001
From: Paul Berg <9824244+Pangoraw@users.noreply.github.com>
Date: Thu, 30 Jan 2025 11:47:23 +0100
Subject: [PATCH 2/2] add source

Add a comment above the PCI id constants linking to the referenced
lines of JAX's hardware_utils.py.

---
 src/Devices.jl | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/Devices.jl b/src/Devices.jl
index ea2443386c..7438f7d0ac 100644
--- a/src/Devices.jl
+++ b/src/Devices.jl
@@ -24,6 +24,7 @@ function addressable_devices(client::XLA.Client=XLA.default_backend[])
     return [XLA.ClientGetAddressableDevice(client, i - 1) for i in 1:ndevices]
 end
 
+# https://github.com/jax-ml/jax/blob/152099ee0ef31119f16f4c2dac50d84fcb1575ef/jax/_src/hardware_utils.py#L19-L55
 const _GOOGLE_PCI_VENDOR_ID = "0x1ae0"
 const _TPU_PCI_DEVICE_IDS = (
     # TPU v2, v3