In [1]:
import torch

# Record which PyTorch release this notebook was executed with.
torch_version = torch.__version__
print(torch_version)

2.1.0


In [2]:
# Is MPS even available? macOS 12.3+
print(torch.backends.mps.is_available())

# Was the current version of PyTorch built with MPS activated?
print(torch.backends.mps.is_built())

True
True


In [3]:
import math

dtype = torch.float
device = torch.device("mps")

# Create random input and output data
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# Randomly initialize weights
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(2000):
    # Forward pass: compute predicted y
    y_pred = a + b * x + c * x ** 2 + d * x ** 3

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum().item()
    if t % 100 == 99:
        print(t, loss)

    # Backprop to compute gradients of a, b, c, d with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x ** 2).sum()
    grad_d = (grad_y_pred * x ** 3).sum()

    # Update weights using gradient descent
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d
 

print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')

99 9651.31640625
199 6384.984375
299 4225.1044921875
399 2796.868896484375
499 1852.439453125
599 1227.92822265625
699 814.9655151367188
799 541.8905029296875
899 361.3175354003906
999 241.91204833984375
1099 162.95391845703125
1199 110.74215698242188
1299 76.21650695800781
1399 53.385963439941406
1499 38.28899383544922
1599 28.305944442749023
1699 21.704483032226562
1799 17.339126586914062
1899 14.452454566955566
1999 12.54362678527832
Result: y = -0.001412279438227415 + 0.7974143028259277 x + 0.00024364236742258072 x^2 + -0.08489169925451279 x^3
