In [6]:
# Smoke-test the Metal Performance Shaders (MPS) backend: if PyTorch reports it
# as usable, allocate a one-element tensor on it and print that tensor;
# otherwise report that no MPS device was found.
import torch

if not torch.backends.mps.is_available():
    print("MPS device not found.")
else:
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print(x)


tensor([1.], device='mps:0')


In [5]:
# Fit y = a + b*x + c*x^2 + d*x^3 to sin(x) on [-pi, pi] with plain gradient
# descent, computing the gradients of the sum-of-squares loss by hand (no autograd).
import torch
import math

# --- configuration -----------------------------------------------------------
# float32 keeps the cell runnable on MPS, which has no float64 support.
dtype = torch.float
learning_rate = 1e-6
n_steps = 2000
log_every = 100  # with n_steps=2000 the last step (1999) is a logging step
torch.manual_seed(0)  # make the random weight initialization reproducible

# Use the current accelerator (CUDA, MPS, ...) when there is one, else the CPU.
# `torch.accelerator` only exists in newer PyTorch releases, so fall back to
# explicit CUDA/MPS checks instead of failing with AttributeError on older ones.
if hasattr(torch, "accelerator"):
    device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
elif torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

# Create the input grid (deterministic, evenly spaced) and the target values
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# The powers of x never change, so compute them once instead of on every step
x2 = x ** 2
x3 = x ** 3

# Randomly initialize weights (0-dim tensors on the chosen device)
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)

for t in range(n_steps):
    # Forward pass: compute predicted y and its error against the target
    y_pred = a + b * x + c * x2 + d * x3
    residual = y_pred - y

    # Sum-of-squares loss. Only materialized on logging steps, because .item()
    # copies the value back to the host (a synchronization point on GPU/MPS).
    if t % log_every == log_every - 1:
        loss = residual.pow(2).sum().item()
        print(t, loss)

    # Backprop by hand: dL/dy_pred = 2 * residual, then chain rule per term
    grad_y_pred = 2.0 * residual
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x2).sum()
    grad_d = (grad_y_pred * x3).sum()

    # Update weights in place using gradient descent
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d


print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')

Using mps device
99 831.2708129882812
199 558.5472412109375
299 376.4941711425781
399 254.89291381835938
499 173.62200927734375
599 119.27133178710938
699 82.90028381347656
799 58.544334411621094
899 42.22264099121094
999 31.276939392089844
1099 23.9307861328125
1199 18.996496200561523
1299 15.67951774597168
1399 13.447775840759277
1499 11.944893836975098
1599 10.931893348693848
1699 10.248442649841309
1799 9.786897659301758
1899 9.474883079528809
1999 9.263748168945312
Result: y = -0.014117185957729816 + 0.8408061861991882 x + 0.0024354492779821157 x^2 + -0.09106381982564926 x^3
