In [1]:
import torch

# Pick the compute device: the Apple-silicon GPU (MPS) when it is usable at
# runtime, otherwise the CPU.
# torch.has_mps is deprecated in favour of torch.backends.mps.is_built() and
# does not tell us whether an MPS device can actually be used, so query
# is_available() instead.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
device

device(type='mps')

In [2]:
# Check that MPS can be used right now. Per the PyTorch MPS notes, this needs
# macOS 12.3+ and an MPS-capable device (in addition to an MPS-enabled build).
print(torch.backends.mps.is_available())

True


In [3]:
# Check that the current PyTorch installation was built with MPS support enabled.
print(torch.backends.mps.is_built())

True


In [9]:
# Test case: fit the cubic a + b*x + c*x^2 + d*x^3 to sin(x) on [-pi, pi]
# using hand-written gradients and plain gradient descent.
import math

# Run on the Apple-silicon GPU when MPS is usable; otherwise fall back to the
# CPU so the cell still runs on machines without MPS.
dtype = torch.float
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Create input and target data: 2000 evenly spaced points and their sine
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# The powers of x never change inside the loop, so compute them once
x_sq = x ** 2
x_cu = x ** 3

# Randomly initialize weights (0-dim tensors)
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(2000):
    # Forward pass: compute predicted y
    y_pred = a + b * x + c * x_sq + d * x_cu

    # Compute and print the loss every 100 steps. .item() copies the value to
    # the host, so it is only called when the value is actually reported; the
    # last iteration (t == 1999) reports, so `loss` ends as the final loss.
    if t % 100 == 99:
        loss = (y_pred - y).pow(2).sum().item()
        print(t, loss)

    # Backprop to compute gradients of a, b, c, d with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x_sq).sum()
    grad_d = (grad_y_pred * x_cu).sum()

    # Update weights using gradient descent
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d


print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')

99 1462.271240234375
199 998.8912963867188
299 684.0104370117188
399 469.80023193359375
499 323.911376953125
599 224.44036865234375
699 156.5406494140625
799 110.13844299316406
899 78.3910140991211
999 56.64533996582031
1099 41.73309326171875
1199 31.49541473388672
1299 24.458988189697266
1399 19.617340087890625
1499 16.282190322875977
1599 13.982295036315918
1699 12.394553184509277
1799 11.29733657836914
1899 10.538280487060547
1999 10.012651443481445
Result: y = 0.03131430596113205 + 0.8393440246582031 x + -0.005402237642556429 x^2 + -0.09085585922002792 x^3
