In [25]:
import mlx.core as mx
import torch

In [26]:

def set_mlx_device(use_gpu=False):
    """
    Set MLX's default device for all subsequent array operations.

    Parameters:
        use_gpu (bool): If True, make ``mx.gpu`` the default device;
            otherwise make ``mx.cpu`` the default.

    Returns:
        The device MLX reports as default after the change, i.e. the value
        of ``mx.default_device()`` (printed, e.g., as ``Device(gpu, 0)``).
    """
    mx.set_default_device(mx.gpu if use_gpu else mx.cpu)

    # Query MLX rather than echoing the request, so the message and the
    # return value reflect what MLX actually has as its default.
    selected = mx.default_device()
    print(f"Using MLX device: {selected}")
    return selected

In [27]:
# Make the GPU the default MLX device for every array created below.
set_mlx_device(True);

Using MLX device: Device(gpu, 0)


In [30]:
MLX_SEED = 0  # fixed seed so the sampled value is reproducible on Restart & Run All
mx.random.seed(MLX_SEED)
A = mx.random.normal()  # single sample with default arguments (a scalar, per the displayed output)

In [31]:
# Bare last expression: Jupyter renders the MLX array's repr as this cell's output.
A

array(0.253527, dtype=float32)

In [9]:

# NOTE(review): `get_device` is not defined in any visible cell (the
# "Metal Performance Shaders (MPS) available..." line in the output is
# presumably printed by it, from a cell no longer present). Restore its
# definition above this cell so Restart & Run All works.
use_gpu = True
N_ROWS, N_COLS = 1000, 500  # shape of the random test matrix

device = get_device(use_gpu=use_gpu)
print(f"Using device: {device}")

# Bind `A` before the try block. If tensor creation itself raises, the
# fallback must not touch an unbound name or the stale MLX array `A`
# left behind by an earlier cell.
A = None
try:
    A = torch.randn(N_ROWS, N_COLS, device=device)  # Generate random tensor
    print("Matrix generated successfully.")

    # Perform QR decomposition
    Q, R = torch.linalg.qr(A)
    print("QR decomposition successful.")
except RuntimeError as e:
    print(f"An error occurred during computation: {e}")

    # Fall back to CPU if the accelerator path raised. Comparing `.type`
    # works whether `get_device` returns a string or a torch.device.
    if torch.device(device).type != "cpu":
        print("Retrying on CPU...")
        if A is None:
            # Creation failed on the accelerator: generate the data on CPU.
            A = torch.randn(N_ROWS, N_COLS, device="cpu")
        else:
            A = A.to("cpu")  # Move data to CPU
        Q, R = torch.linalg.qr(A)
        print("QR decomposition successful on CPU.")

Metal Performance Shaders (MPS) available. Using Metal backend.
Using device: mps
Matrix generated successfully.
Retrying on CPU...
QR decomposition successful on CPU.


In [None]:

def main(use_gpu=False, n_rows=1000, n_cols=500):
    """Run a QR decomposition of a random matrix, retrying on CPU on failure.

    Parameters:
        use_gpu (bool): Forwarded to ``get_device`` to request an accelerator.
        n_rows (int): Number of rows of the random test matrix.
        n_cols (int): Number of columns of the random test matrix.

    Returns:
        tuple[torch.Tensor, torch.Tensor] | None: ``(Q, R)`` from the reduced
        QR decomposition, or ``None`` if the computation raised a
        ``RuntimeError`` on a CPU device (there is nothing left to fall back to).

    Raises:
        RuntimeError: if the CPU retry itself fails.
    """
    # NOTE(review): `get_device` must be defined in an earlier cell; it is not
    # visible in this notebook.
    device = get_device(use_gpu=use_gpu)
    print(f"Using device: {device}")

    # Bound up front so the fallback never reads an unbound local when
    # tensor creation itself is what raised.
    A = None
    try:
        A = torch.randn(n_rows, n_cols, device=device)  # Generate random tensor
        print("Matrix generated successfully.")

        # torch.linalg.qr replaces the deprecated torch.qr; its default
        # mode="reduced" matches torch.qr's default some=True.
        Q, R = torch.linalg.qr(A)
        print("QR decomposition successful.")
        return Q, R
    except RuntimeError as e:
        print(f"An error occurred during computation: {e}")

        # `.type` comparison works for both string and torch.device values.
        if torch.device(device).type == "cpu":
            return None

        print("Retrying on CPU...")
        if A is None:
            # Creation failed on the accelerator: generate the data on CPU.
            A = torch.randn(n_rows, n_cols, device="cpu")
        else:
            A = A.to("cpu")  # Move data to CPU
        Q, R = torch.linalg.qr(A)
        print("QR decomposition successful on CPU.")
        return Q, R

if __name__ == "__main__":
    # In a Jupyter kernel __name__ is normally "__main__", so this runs when
    # the cell executes; pass False instead to stay on the CPU.
    main(True)