In [1]:
import torch

In [2]:
import numpy as np

In [3]:
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

In [4]:
print(device)

cpu


In [40]:
x = torch.tensor(1, requires_grad=True, dtype=torch.float64)
print(x)

y = 2*x + 1

y.backward()
print(x.grad)

tensor(1., dtype=torch.float64, requires_grad=True)
tensor(2., dtype=torch.float64)


In [43]:
x = torch.ones(4, requires_grad=True)

In [45]:
# wrong
for epoch in range(3):
  y = (2*x).mean()

  y.backward()
  print(x.grad)

tensor([0.5000, 0.5000, 0.5000, 0.5000])
tensor([1., 1., 1., 1.])
tensor([1.5000, 1.5000, 1.5000, 1.5000])


In [46]:
# right
x.grad.zero_()

for epoch in range(3):
  y = (2*x).mean()

  y.backward()
  print(x.grad)

  x.grad.zero_()

tensor([0.5000, 0.5000, 0.5000, 0.5000])
tensor([0.5000, 0.5000, 0.5000, 0.5000])
tensor([0.5000, 0.5000, 0.5000, 0.5000])


In [27]:
x.requires_grad_(False)

x_no_grad = x.detach()

with torch.no_grad():
  pass

In [67]:
x = torch.tensor(1.0)
y = torch.tensor(2.0)

w = torch.tensor(1.0, requires_grad=True)

for epoch in range(100):

  # forward pass
  y_hat = w * x
  loss = (y_hat - y)**2

  # backward pass
  loss.backward()

  # SGD
  with torch.no_grad():
    w -= 0.1 * w.grad

  w.grad.zero_()

  print(f'epoch:{epoch}, loss:{loss}, y_hat:{y_hat}, dl/dw:{w.grad}')

print("weight:",w.item())

epoch:0, loss:1.0, y_hat:1.0, dl/dw:0.0
epoch:1, loss:0.6399999260902405, y_hat:1.2000000476837158, dl/dw:0.0
epoch:2, loss:0.40959998965263367, y_hat:1.3600000143051147, dl/dw:0.0
epoch:3, loss:0.26214396953582764, y_hat:1.4880000352859497, dl/dw:0.0
epoch:4, loss:0.16777217388153076, y_hat:1.590399980545044, dl/dw:0.0
epoch:5, loss:0.1073741763830185, y_hat:1.672320008277893, dl/dw:0.0
epoch:6, loss:0.06871946156024933, y_hat:1.7378560304641724, dl/dw:0.0
epoch:7, loss:0.0439804345369339, y_hat:1.7902848720550537, dl/dw:0.0
epoch:8, loss:0.028147462755441666, y_hat:1.8322279453277588, dl/dw:0.0
epoch:9, loss:0.018014369532465935, y_hat:1.865782380104065, dl/dw:0.0
epoch:10, loss:0.01152919139713049, y_hat:1.8926259279251099, dl/dw:0.0
epoch:11, loss:0.007378678303211927, y_hat:1.9141007661819458, dl/dw:0.0
epoch:12, loss:0.004722354002296925, y_hat:1.9312806129455566, dl/dw:0.0
epoch:13, loss:0.0030223065987229347, y_hat:1.9450244903564453, dl/dw:0.0
epoch:14, loss:0.0019342721207067