In [None]:
import importlib
import os

# Choose how GPU hardware is presented as devices.  This is set ahead of
# the torch / lightning imports that follow in this cell.
#   "COMPOSITE": single device per GPU card.
#   "FLAT":      single device per GPU stack; two devices per GPU card.
os.environ.update(ZE_FLAT_DEVICE_HIERARCHY="FLAT")

import lightning_xpu
import lightning
import torch

In [None]:
# Device types to probe.
device_types = ["cpu", "cuda", "mps", "xpu", "fictional_device"]

# Report, for each device type in alphabetical order, the devices torch sees.
print("Devices seen by torch:")
for device_type in sorted(device_types):
    # Look up the back-end module torch.<device_type>; None when it is absent.
    try:
        backend = importlib.import_module(f"torch.{device_type}")
    except ModuleNotFoundError:
        backend = None
    # A missing module, or one without device_count(), means zero devices.
    if hasattr(backend, "device_count"):
        n_device = backend.device_count()
    else:
        n_device = 0
    # Build the "<type>:<index>" label of every device of this type.
    devices = []
    for idx in range(n_device):
        devices.append(f"{device_type}:{idx}")
    print(f"    {device_type}: {devices}")

In [None]:
from lightning.pytorch.accelerators import AcceleratorRegistry
# Report, for each accelerator name registered with lightning (in
# alphabetical order), the parallel devices that accelerator offers.
print("Devices seen by lightning:")
for device_type in sorted(AcceleratorRegistry.available_accelerators()):
    # A registry lookup that fails with ModuleNotFoundError is treated the
    # same as an accelerator that is not available.
    try:
        accelerator = AcceleratorRegistry.get(device_type)
    except ModuleNotFoundError:
        accelerator = None
    # Only query devices when the accelerator reports itself as available.
    if hasattr(accelerator, "is_available") and accelerator.is_available():
        n_auto = accelerator.auto_device_count()
        devices = accelerator.get_parallel_devices(n_auto)
    else:
        devices = []
    print(f"    {device_type}: {devices}")

In [None]:
import importlib
import random
import time
# Define device types to be considered.
device_types = ["cpu", "cuda", "mps", "xpu", "fictional_device"]
# Number of times to attempt matrix multiplication.
n_attempt = 3
# Print information about available device types.
for device_type in device_types:
    # Determine number of devices of each type.
    try:
        device_module = importlib.import_module(f"torch.{device_type}")
    except ModuleNotFoundError:
        device_module = None
    if hasattr(device_module, "is_available") and device_module.is_available():
        n_device = device_module.device_count()
    else:
        n_device = 0
    print(f"\nDevice type: {device_type}")
    print(f"Number of devices: {n_device}")
    # Test matrix-multiplication time for all devices of current type,
    # considering devices in random order.
    indices = list(range(n_device))
    random.shuffle(indices)
    i_dim = 0
    while n_device:
        dim = 2**i_dim
        i_dim += 1
        i_attempt = 0
        print()
        while i_attempt < n_attempt:
            i_attempt += 1
            for i_device in indices:
                device_name = f"{device_type}:{i_device}"
                if dim > 1024 and "cpu" == device_type:
                    n_device = 0
                    i_attempt = n_attempt + 1
                    break
                t0 = time.time()
                try:
                    x=torch.randn((dim, dim), device=torch.device(device_name))
                    y=torch.randn((dim, dim), device=torch.device(device_name))
                    z=torch.matmul(x,y)
                except RuntimeError:
                    n_device =0
                t1 = time.time()
                if n_device:
                    print(f"{device_name}: order = {dim}; "
                            f"attempt ={i_attempt : 3d}; "
                            f"time ={(t1 - t0) * 1.e6 : 8.1f} microseconds")
                else:
                    print(f"{device_type}: order = {dim}; out of memory")
                    i_attempt = n_attempt + 1
                    break