In [2]:
import torch
import math
# Query the MPS backend once up front, then report both flags.
# is_available(): whether the MPS device can be used here; per the original
# note this implies the current macOS version is at least 12.3.
mps_available = torch.backends.mps.is_available()
# is_built(): whether the current PyTorch installation was built with MPS
# activated.
mps_built = torch.backends.mps.is_built()
print(mps_available)
print(mps_built)


True
True


In [4]:
# Fit the cubic y ~= a + b*x + c*x^2 + d*x^3 to sin(x) on [-pi, pi] using
# hand-derived gradients and plain gradient descent (no autograd).
dtype = torch.float
# Prefer the MPS backend, but fall back to the CPU so this cell still runs
# on machines where MPS is not available instead of raising.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Seed the generator so the random initial weights, and therefore the printed
# losses and final coefficients, are reproducible after a kernel restart.
torch.manual_seed(0)

# Create input and target data: 2000 evenly spaced points and their sine
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# The powers of x never change between iterations, so compute them once
# instead of four times per step.
x_squared = x ** 2
x_cubed = x ** 3

# Randomly initialize weights (each is a 0-dim tensor)
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)

learning_rate = 1e-6
num_steps = 3000   # total gradient-descent iterations
log_every = 100    # print the loss on the last step of every such window
for t in range(num_steps):
    # Forward pass: compute predicted y
    y_pred = a + b * x + c * x_squared + d * x_cubed

    # Compute and print loss (sum of squared errors, as a Python float)
    loss = (y_pred - y).pow(2).sum().item()
    if t % log_every == log_every - 1:
        print(t, loss)

    # Backprop to compute gradients of a, b, c, d with respect to loss.
    # d(loss)/d(y_pred) = 2 * (y_pred - y); the chain rule then multiplies by
    # the matching power of x and sums over all samples.
    grad_y_pred = 2.0 * (y_pred - y)
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x_squared).sum()
    grad_d = (grad_y_pred * x_cubed).sum()

    # Update weights in place using gradient descent
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d


print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')


99 1130.861328125
199 750.8001098632812
299 499.476806640625
399 333.2823181152344
499 223.38153076171875
599 150.70616149902344
699 102.64713287353516
799 70.86636352539062
899 49.85017776489258
999 35.952354431152344
1099 26.761768341064453
1199 20.68408966064453
1299 16.664897918701172
1399 14.006948471069336
1499 12.249263763427734
1599 11.086902618408203
1699 10.31822395324707
1799 9.809871673583984
1899 9.47369384765625
1999 9.251360893249512
2099 9.10432243347168
2199 9.007089614868164
2299 8.94277572631836
2399 8.90024185180664
2499 8.872112274169922
2599 8.853507995605469
2699 8.841203689575195
2799 8.833064079284668
2899 8.827682495117188
2999 8.824122428894043
Result: y = 0.00018627170356921852 + 0.8541829586029053 x + -3.213549280189909e-05 x^2 + -0.09296654909849167 x^3
