### .detach() 

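- When training a network, we may want to keep some of the parameters fixed and update only the rest,
- or train only one branch of the network without letting its gradients affect the main network. In these cases we use detach() to cut off backpropagation along those branches.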
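A minimal sketch of the second case, assuming a hypothetical toy model with a shared trunk, a main head and an auxiliary branch (the names trunk, main_head and aux_head are made up for illustration). The auxiliary loss is computed on features.detach(), so its backward pass stops at the branch and the trunk receives no gradient from it. For the first case, freezing parameters, calling requires_grad_(False) on them is the more common tool.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model: a shared trunk feeding a main head and an auxiliary branch.
trunk = nn.Linear(4, 8)
main_head = nn.Linear(8, 1)
aux_head = nn.Linear(8, 1)

x = torch.randn(2, 4)
features = trunk(x)

# The auxiliary branch only sees detached features, so its loss
# cannot send gradients back into the trunk.
aux_loss = aux_head(features.detach()).pow(2).mean()
aux_loss.backward()

trunk_grad_after_aux = trunk.weight.grad
print(trunk_grad_after_aux)          # None: the trunk got no gradient from the branch
print(aux_head.weight.grad is None)  # False: the branch itself is still trained

# The main loss backpropagates into the trunk as usual.
main_loss = main_head(features).pow(2).mean()
main_loss.backward()
print(trunk.weight.grad is None)     # False
```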


1. tensor.detach() creates a tensor that shares storage with tensor but does not require grad. tensor.clone() creates a copy of tensor that keeps the original tensor's requires_grad setting.
You should use detach() when you want to remove a tensor from the computation graph, and clone() when you want to copy the tensor while keeping the copy part of the computation graph it came from.

2. tensor.data returns a new tensor that shares storage with tensor. However, it always has requires_grad=False (even if the original tensor had requires_grad=True).
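A quick check of the properties listed above; a sketch, assuming a recent PyTorch where data_ptr() returns the address of a tensor's first element, so tensors sharing storage at the same offset report the same pointer. Only the clone keeps a grad_fn, and gradients flow back through it to a.

```python
import torch

a = torch.tensor([1., 2., 3.], requires_grad=True)
out = a.sigmoid()

d = out.detach()  # shares storage, requires_grad=False, no grad_fn
c = out.clone()   # new storage, stays in the graph
v = out.data      # shares storage, requires_grad=False

print(d.requires_grad, d.grad_fn)           # False None
print(c.requires_grad, c.grad_fn is None)   # True False
print(v.requires_grad)                      # False

# detach() and .data point at the same memory; clone() allocates new memory.
print(d.data_ptr() == out.data_ptr())       # True
print(v.data_ptr() == out.data_ptr())       # True
print(c.data_ptr() == out.data_ptr())       # False

# Gradients flow back through the clone to a.
c.sum().backward()
print(a.grad)
```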


http://www.bnikolic.co.uk/blog/pytorch-detach.html#:~:targetText=The%20detach()%20method%20constructs,visualised%20using%20the%20torchviz%20package.

https://www.cnblogs.com/wanghui-garcia/p/10677071.html

In [5]:
import torch

a = torch.tensor([1,2,3.], requires_grad=True)
print(a)
print(a.grad)
print('\n')

out = a.sigmoid()
print(out)
print('\n')

print(out.sum().backward())
print(a.grad)

tensor([1., 2., 3.], requires_grad=True)
None


tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)


None
tensor([0.1966, 0.1050, 0.0452])
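The printed gradient matches the closed form of the sigmoid derivative, which can be written entirely in terms of the output. As far as I can tell, PyTorch's SigmoidBackward uses this form and reads the saved output rather than the input, which is why in-place changes to out matter for backward(): if out is zeroed, every gradient becomes 0 · (1 − 0) = 0.

$$
\frac{d\sigma(x)}{dx} = \sigma(x)\,\bigl(1-\sigma(x)\bigr)
$$

$$
0.7311 \times 0.2689 \approx 0.1966,\qquad 0.8808 \times 0.1192 \approx 0.1050,\qquad 0.9526 \times 0.0474 \approx 0.0452
$$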


In [34]:
import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)

# Call detach(); c has requires_grad=False
c = out.detach()
print(c)

# c has not been modified, so backward() is unaffected
out.sum().backward()
print(a.grad)

None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0.1966, 0.1050, 0.0452])


In [18]:
import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)

# Call detach(); c has requires_grad=False
c = out.detach()
print(c)

# Backpropagate through the newly created tensor c
# The difference between c and out: c does not require grad, while out does
# So calling sum() and then backward() on c raises an error
c.sum().backward()
print(a.grad)

None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [43]:
import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)

# Call detach(); c has requires_grad=False
c = out.detach()
print(c)
c.zero_()  # modify c with an in-place operation


# The change to c also changes out, since they share storage
print(c)
print(out)


# c has now been modified in place and autograd notices, so backward() raises an error
out.sum().backward()
print(a.grad)

None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0., 0., 0.])
tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)
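RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [3]], which is output 0 of SigmoidBackward, is at version 1; expected version 0 instead.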


### .data

If the operations above use .data instead, the result is different:

The difference is that in-place changes made through .data are not tracked by autograd, so backward() raises no error but silently returns a wrong gradient.
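One way to see why the two behave differently is autograd's version counter, which PyTorch bumps on every in-place operation and checks when a saved tensor is used in backward(). The sketch below reads it through the internal _version attribute; that attribute is an assumption on my part about implementation details and may change between releases. detach() shares the counter with out, so the check fires; .data gets a fresh counter, so the zeroed values are used without complaint.

```python
import torch

a = torch.tensor([1., 2., 3.], requires_grad=True)

# Case 1: in-place change through detach()
out1 = a.sigmoid()
out1.detach().zero_()
print(out1._version)  # 1: detach() shares the version counter with out1

detach_raised = False
try:
    out1.sum().backward()
except RuntimeError:
    detach_raised = True  # autograd saw that the saved output was modified
print(detach_raised)      # True

# Case 2: in-place change through .data
out2 = a.sigmoid()
out2.data.zero_()
print(out2._version)  # 0: .data has its own counter, so the change is invisible
out2.sum().backward()
print(a.grad)         # all zeros: silently wrong
```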

In [23]:
import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)

c = out.data
print(c)
c.zero_()  # modify c with an in-place operation

# The change to c also changes out
print(c)
print(out) 

# Changes made through .data are not tracked by autograd, so backward() raises no error but returns a wrong gradient
out.sum().backward()
print(a.grad)

None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0., 0., 0.])
tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)
tensor([0., 0., 0.])


In [24]:
import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid().sum()  # but if sum() is applied here rather than just before backward(), the result is correct
print(out)


c = out.data
print(c)
c.zero_()  # modify c with an in-place operation

# The change to c also changes out
print(c)
print(out)

# sum() is no longer applied here
out.backward()
print(a.grad)

None
tensor(2.5644, grad_fn=<SumBackward0>)
tensor(2.5644)
tensor(0.)
tensor(0., grad_fn=<SumBackward0>)
tensor([0.1966, 0.1050, 0.0452])


### Summary

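1. In-place changes (such as .zero_()) made through either .detach() or .data also change the values of out, since both share its storage. The difference is what backward() does afterwards: with .detach(), autograd detects the change and out.sum().backward() raises a RuntimeError; with .data, the change goes unnoticed, so backward() runs but returns a wrong gradient.

2. If the sum is taken before the in-place change (out = a.sigmoid().sum()), out.backward() runs and gives the correct gradient whether the change is made through .detach() or .data. The modified tensor is then the output of sum(), which sum's backward does not need, while the sigmoid output saved for the gradient is left untouched.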
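If you actually need to edit a detached value, one common pattern (not the only one) is out.detach().clone(): the copy has its own storage, so in-place changes reach neither out nor the values saved for backward. A minimal sketch:

```python
import torch

a = torch.tensor([1., 2., 3.], requires_grad=True)
out = a.sigmoid()

# detach() + clone(): an untracked copy with its own storage,
# so in-place edits cannot reach out or its autograd history.
scratch = out.detach().clone()
scratch.zero_()

print(out)     # values unchanged
out.sum().backward()
print(a.grad)  # the correct sigmoid gradient
```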

In [71]:
import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
out = a.sigmoid()
print(out)

# c1 = out.detach()
# c1.zero_()
# # c1 = torch.tensor([1, 2, 3.])
# print(out)

c2 = out.data
c2.zero_()
# c2 = torch.tensor([1, 2, 3.])
print(out)

out.sum().backward()
print(a.grad)

tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)
tensor([0., 0., 0.])


In [74]:
import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
out = a.sigmoid().sum()
print(out)

c1 = out.detach()
c1.zero_()
# c1 = torch.tensor([1, 2, 3.])
print(out)

# c2 = out.data
# c2.zero_()
# # c2 = torch.tensor([1, 2, 3.])
# print(out)

out.backward()
print(a.grad)

tensor(2.5644, grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
tensor([0.1966, 0.1050, 0.0452])
