In [1]:
# Render matplotlib figures inline in the notebook output.
# NOTE(review): no plotting is visible in this section of the notebook; drop this magic if nothing plots.
%matplotlib inline

In [13]:
import torch

# Seed the global RNG so w and b (and every value printed by later cells) are
# reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

# Minimal one-layer model: input x has 5 features and target y has 3 outputs.
# w (5x3) and b (3) are the parameters; requires_grad=True makes autograd track them.
x = torch.ones(5)
y = torch.zeros(3)
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
# Forward pass: logits z, then a binary cross-entropy loss computed directly on logits.
z = torch.matmul(x, w)+b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

In [4]:
# Every tensor produced by a tracked operation records, in .grad_fn, the
# backward node autograd will use for it (e.g. AddBackward0 for z = ... + b).
for caption, node in (("Gradient function for z =", z), ("Gradient function for loss =", loss)):
    print(caption, node.grad_fn)

Gradient function for z = <AddBackward0 object at 0x7f254be66bd0>
Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward0 object at 0x7f254be66b90>


In [5]:
# Backpropagate from the scalar loss: computes the partial derivatives
# d(loss)/dw and d(loss)/db and stores them in each leaf tensor's .grad.
loss.backward()
for leaf in (w, b):
    print(leaf.grad)

tensor([[0.0592, 0.0555, 0.3308],
        [0.0592, 0.0555, 0.3308],
        [0.0592, 0.0555, 0.3308],
        [0.0592, 0.0555, 0.3308],
        [0.0592, 0.0555, 0.3308]])
tensor([0.0592, 0.0555, 0.3308])


In [18]:
# Inference with an already-trained model (no gradients needed), i.e. when only
# the network's forward pass is used.
# Outside any context manager the result is tracked by autograd...
z = x @ w + b
print(z.requires_grad)

# ...whereas operations run under torch.no_grad() are not recorded.
with torch.no_grad():
    z = x @ w + b
print(z.requires_grad)

True
False


In [17]:
# Alternatively: build the tracked result, then take a detached view of it,
# which is excluded from gradient tracking.
z = x @ w + b
z_det = z.detach()
print(z_det.requires_grad)

False


In [20]:
# Jacobian *product* demo.
# out = (inp + 1)**2 is element-wise, so d out[i,j] / d inp[i,j] = 2 * (inp[i,j] + 1).
# Because out is not a scalar, backward() needs an explicit "gradient" argument v.
# Autograd then computes the vector-Jacobian product v^T . J, not the Jacobian
# matrix J itself. With v = ones, inp.grad becomes 2 * (inp + 1).
inp = torch.eye(5, requires_grad=True)
out = (inp+1).pow(2)

# retain_graph=True keeps the graph alive so backward() can be called again below.
out.backward(torch.ones_like(inp), retain_graph=True)
print("First call\n", inp.grad)
first_call_grad = inp.grad.clone()  # snapshot; inp.grad is mutated by later calls

out.backward(torch.ones_like(inp), retain_graph=True)
print("\nSecond call\n", inp.grad)
# An identical second call is added onto the existing .grad, doubling it.
assert torch.equal(inp.grad, 2 * first_call_grad)

# backward() accumulates (sums) into .grad instead of overwriting it, so .grad
# must be reset to zero before computing a fresh gradient.
inp.grad.zero_()
out.backward(torch.ones_like(inp), retain_graph=True)
print("\nCall after zeroing gradient\n", inp.grad)
# After zeroing, the result matches the very first call again.
assert torch.equal(inp.grad, first_call_grad)

First call
 tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.],
        [2., 2., 2., 2., 4.]])

Second call
 tensor([[8., 4., 4., 4., 4.],
        [4., 8., 4., 4., 4.],
        [4., 4., 8., 4., 4.],
        [4., 4., 4., 8., 4.],
        [4., 4., 4., 4., 8.]])

Call after zeroing gradient
 tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.],
        [2., 2., 2., 2., 4.]])
