In [26]:
import torch

# Fix the RNG so w and b (and therefore every gradient printed below) are reproducible
# under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

x = torch.ones(5)  # input tensor
y = torch.zeros(3)  # expected output
# w and b are the trainable parameters, so autograd must record operations on them.
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
# Single linear layer: logits z = x @ w + b, scored against y with BCE-with-logits (scalar loss).
z = torch.matmul(x, w) + b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

## ✔️ 연산 그래프(Computational graph)

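위에서 구성한 신경망에서 w와 b는 최적화해야 할 매개변수이며, 손실(loss)에 대한 w와 b의 변화도를 계산해 이들을 업데이트하면서 loss를 줄여나간다. 이를 위해 w와 b는 `requires_grad=True`로 생성하며, 연산 결과로 만들어진 텐서(z, loss)는 자신을 만든 역전파 함수에 대한 참조를 `grad_fn` 속성에 저장한다.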

In [27]:
# Tensors produced from operations on requires_grad tensors keep a reference to the
# backward function that created them in `.grad_fn`.
grad_fn_z = z.grad_fn
grad_fn_loss = loss.grad_fn
print(f"Gradient function for z = {grad_fn_z}")
print(f"Gradient function for loss = {grad_fn_loss}")

Gradient function for z = <AddBackward0 object at 0x7fd56a0a38e0>
Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward0 object at 0x7fd56a0a3a30>


#### 변화도(Gradient) 계산하기

In [28]:
# Back-propagate from the scalar loss; this populates w.grad and b.grad.
loss.backward()
for param in (w, b):
    print(param.grad)

tensor([[0.0674, 0.1067, 0.3081],
        [0.0674, 0.1067, 0.3081],
        [0.0674, 0.1067, 0.3081],
        [0.0674, 0.1067, 0.3081],
        [0.0674, 0.1067, 0.3081]])
tensor([0.0674, 0.1067, 0.3081])


In [29]:
# Built normally, the result is tracked by autograd because w and b require grad.
z = x @ w + b
print(z.requires_grad)

# The same expression evaluated inside torch.no_grad() is not tracked.
with torch.no_grad():
    z = x @ w + b
print(z.requires_grad)

True
False


## ✔️ 변화도 추적 멈추기

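위의 `torch.no_grad()` 블록이나 아래의 `detach()`로 변화도 추적을 멈출 수 있다. 추적을 멈추면 순전파(forward) 연산만 필요할 때 연산 속도가 상승하고 메모리 사용도 줄어든다. 또한 신경망의 일부 매개변수를 학습하지 않도록 고정(frozen)할 때도 사용한다.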

In [30]:
z = x @ w + b
# detach() returns a tensor that is cut off from the autograd graph.
z_det = z.detach()
print(z_det.requires_grad)

False


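## ✔️ 선택적으로 읽기: 텐서 변화도와 야코비안 곱

출력이 스칼라가 아닌 텐서인 경우에는 변화도 대신 야코비안 곱 $v^T \cdot J$를 계산하며, 이때 `backward()`에 벡터 $v$를 인자로 전달한다. 같은 그래프에 대해 `backward()`를 여러 번 호출하면 변화도가 `.grad`에 누적되므로, 새로 계산하기 전에 `.grad.zero_()`로 0으로 초기화해야 한다.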


In [31]:
inp = torch.eye(5, requires_grad=True)
out = (inp + 1) ** 2
# `out` is not a scalar, so backward() needs a vector v; it computes v^T . J into inp.grad.
v = torch.ones_like(inp)

# Each backward() call ADDS into inp.grad, so the second call shows doubled values.
out.backward(v, retain_graph=True)
print(f"First call\n{inp.grad}")
out.backward(v, retain_graph=True)
print(f"\nSecond call\n{inp.grad}")

# Reset the accumulated gradient; the next call reproduces the first result.
inp.grad.zero_()
out.backward(v, retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")

First call
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.],
        [2., 2., 2., 2., 4.]])

Second call
tensor([[8., 4., 4., 4., 4.],
        [4., 8., 4., 4., 4.],
        [4., 4., 8., 4., 4.],
        [4., 4., 4., 8., 4.],
        [4., 4., 4., 4., 8.]])

Call after zeroing gradients
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.],
        [2., 2., 2., 2., 4.]])
