In [12]:
import torch
import math
import time

In [3]:
# Report MPS readiness, one flag per line:
#   1st line - is_available(): whether the MPS backend can be used here
#              (needs macOS 12.3+)
#   2nd line - is_built(): whether this PyTorch installation was built
#              with MPS activated
print(
    torch.backends.mps.is_available(),
    torch.backends.mps.is_built(),
    sep="\n",
)

True
True


In [11]:
# Fit a cubic polynomial  y = a + b x + c x^2 + d x^3  to sin(x) with
# hand-written gradient descent (no autograd), and time the training loop.
dtype = torch.float
# Prefer the Apple-GPU (MPS) backend, but fall back to the CPU so the cell
# still runs on machines where MPS is unavailable.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Training data: sin(x) sampled on 2000 evenly spaced points in [-pi, pi]
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# The powers of x never change, so compute them once instead of on every step
x2 = x ** 2
x3 = x ** 3

# Randomly initialize weights (0-dim tensors, one per polynomial coefficient)
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)

learning_rate = 1e-6
# perf_counter() is the clock intended for measuring short durations
start = time.perf_counter()
for t in range(2000):
    # Forward pass: compute predicted y
    y_pred = a + b * x + c * x2 + d * x3
    residual = y_pred - y

    # Compute and print the loss (sum of squared errors). `.item()` copies the
    # value to the host, so only do it on the steps that are actually reported.
    if t % 100 == 99:
        loss = residual.pow(2).sum().item()
        print(t, loss)

    # Backprop to compute gradients of a, b, c, d with respect to loss
    grad_y_pred = 2.0 * residual
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x2).sum()
    grad_d = (grad_y_pred * x3).sum()

    # Update weights using gradient descent
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d

# Read the fitted coefficients back to the host *before* stopping the timer:
# the copy has to wait for the pending weight updates, so work that is still
# queued on the device is included in the measured time.
a_fit, b_fit, c_fit, d_fit = (w.item() for w in (a, b, c, d))
end = time.perf_counter()
print(f'Result: y = {a_fit} + {b_fit} x + {c_fit} x^2 + {d_fit} x^3')
print(f'Elapsed time: {end-start} seconds')


99 2658.3876953125
199 1882.1866455078125
299 1333.450927734375
399 945.4953002929688
499 671.1956787109375
599 477.2445983886719
699 340.09814453125
799 243.11497497558594
899 174.53004455566406
999 126.02570343017578
1099 91.72138977050781
1199 67.45895385742188
1299 50.29827880859375
1399 38.16021728515625
1499 29.57447052001953
1599 23.501258850097656
1699 19.205175399780273
1799 16.166133880615234
1899 14.016264915466309
1999 12.495391845703125
Result: y = -0.06407849490642548 + 0.8603017926216125 x + 0.01105460710823536 x^2 + -0.09383691847324371 x^3
Elapsed time: 2.494673013687134 seconds
