In [1]:
# Tutorial of getting information from backwards

## define some functions
some backward has attributes `_saved*`, which save the most important information of calculation.

In [2]:
import torch
import torch.nn as nn
def print_grad_fn(out):
    grad = out.grad_fn
    print(grad)
    print([attr for attr in dir(grad) if attr[:6]=="_saved"])
    print([sub_g for sub_g in grad.next_functions])


## Convolutions

In [3]:
# convlution 1d
print("="*10, "convolution 1d start", "="*10)
conv = nn.Conv1d(3, 4, 3, 1, 1)
out = conv(torch.randn(1, 3, 4))
print_grad_fn(out)
print("="*10, "convolution 1d end", "="*10)
# convolution 2d
print("="*10, "convolution 2d start", "="*10)
conv = nn.Conv2d(3, 4, 3, 1, 1)
out = conv(torch.randn(1, 3, 4, 4))
print_grad_fn(out)
print("="*10, "convolution 2d end", "="*10)
# convolution 3d
print("="*10, "convolution 3d start", "="*10)
conv = nn.Conv3d(3, 4, 3, 1, 1)
out = conv(torch.randn(1, 3, 4, 4, 4))
print_grad_fn(out)
print("="*10, "convolution 3d end", "="*10)
# convolution with groups
conv = nn.Conv2d(3, 3, 3, 1, 1, groups=3)
out = conv(torch.randn(1, 3, 4, 4))
print_grad_fn(out)
print("="*10, "convolution with groups end", "="*10)
# convolution with dilation
print("="*10, "convolution with dilation start", "="*10)
conv = nn.Conv2d(3, 4, 3, 1, 1, dilation=2)
out = conv(torch.randn(1, 3, 4, 4))
print_grad_fn(out)
print("="*10, "convolution with dilation end", "="*10)

<ConvolutionBackward0 object at 0x106d20be0>
['_saved_bias_sym_sizes_opt', '_saved_dilation', '_saved_groups', '_saved_input', '_saved_output_padding', '_saved_padding', '_saved_stride', '_saved_transposed', '_saved_weight']
[(None, 0), (<AccumulateGrad object at 0x106d23df0>, 0), (<AccumulateGrad object at 0x106d238b0>, 0)]
<ConvolutionBackward0 object at 0x106d23d90>
['_saved_bias_sym_sizes_opt', '_saved_dilation', '_saved_groups', '_saved_input', '_saved_output_padding', '_saved_padding', '_saved_stride', '_saved_transposed', '_saved_weight']
[(None, 0), (<AccumulateGrad object at 0x106d23df0>, 0), (<AccumulateGrad object at 0x106d238b0>, 0)]
<ConvolutionBackward0 object at 0x106d20be0>
['_saved_bias_sym_sizes_opt', '_saved_dilation', '_saved_groups', '_saved_input', '_saved_output_padding', '_saved_padding', '_saved_stride', '_saved_transposed', '_saved_weight']
[(None, 0), (<AccumulateGrad object at 0x106d23df0>, 0), (<AccumulateGrad object at 0x106d238b0>, 0)]
<ConvolutionBackwar

## fc

In [4]:
# Linear
print("="*10, "Linear start", "="*10)
fc = nn.Linear(3, 4)
out = fc(torch.randn(1, 3, requires_grad=True))
print_grad_fn(out)
print("="*10, "Linear end", "="*10)
# Linear with input more than 2 dims
print("="*10, "Linear with input more than 2 dims start", "="*10)
out = fc(torch.randn(1, 4, 4, 3, requires_grad=True))
print_grad_fn(out)
print("="*10, "Linear with input more than 2 dims end", "="*10)

<AddmmBackward0 object at 0x106d20be0>
['_saved_alpha', '_saved_beta', '_saved_mat1', '_saved_mat1_sym_sizes', '_saved_mat1_sym_strides', '_saved_mat2', '_saved_mat2_sym_sizes', '_saved_mat2_sym_strides']
[(<AccumulateGrad object at 0x141cc5f90>, 0), (<AccumulateGrad object at 0x141cc5ff0>, 0), (<TBackward0 object at 0x141cc5990>, 0)]
<AddBackward0 object at 0x106d20be0>
['_saved_alpha']
[(<UnsafeViewBackward0 object at 0x141cc5f90>, 0), (<AccumulateGrad object at 0x141cc5ff0>, 0)]


## Normalization

In [5]:
# batch norm
print("="*10, "batch norm start", "="*10)
bn = nn.BatchNorm2d(3)
out = bn(torch.randn(1, 3, 4, 4))
print_grad_fn(out)
print("="*10, "batch norm end", "="*10)
# # layer norm
print("="*10, "layer norm start", "="*10)
ln = nn.LayerNorm([3, 4, 4])
out = ln(torch.randn(1, 3, 4, 4))
print_grad_fn(out)
print("="*10, "layer norm end", "="*10)
# group norm
print("="*10, "group norm start", "="*10)
gn = nn.GroupNorm(3, 3)
out = gn(torch.randn(1, 3, 4, 4))
print_grad_fn(out)
print("="*10, "group norm end", "="*10)

<NativeBatchNormBackward0 object at 0x106d23700>
['_saved_eps', '_saved_input', '_saved_result1', '_saved_result2', '_saved_running_mean', '_saved_running_var', '_saved_training', '_saved_weight']
[(None, 0), (<AccumulateGrad object at 0x106d238b0>, 0), (<AccumulateGrad object at 0x106d23df0>, 0)]
<NativeLayerNormBackward0 object at 0x106d23550>
['_saved_bias', '_saved_input', '_saved_normalized_shape', '_saved_result1', '_saved_result2', '_saved_weight']
[(None, 0), (<AccumulateGrad object at 0x106d238b0>, 0), (<AccumulateGrad object at 0x106d23df0>, 0)]
<NativeGroupNormBackward0 object at 0x106d23700>
['_saved_C', '_saved_HxW', '_saved_N', '_saved_eps', '_saved_group', '_saved_input', '_saved_result1', '_saved_result2', '_saved_weight']
[(None, 0), (<AccumulateGrad object at 0x106d23550>, 0), (<AccumulateGrad object at 0x106d23df0>, 0)]


## add, sub, mul, div, matmul

In [6]:
a = nn.Parameter(torch.randn(1, 3, 4, 4))
b = nn.Parameter(torch.randn(1, 3, 4, 4))
# add
print("="*10, "add start", "="*10)
out = a + b
print_grad_fn(out)
print("="*10, "add end", "="*10)
# sub
print("="*10, "sub start", "="*10)
out = a - b
print_grad_fn(out)
print("="*10, "sub end", "="*10)
# mul
print("="*10, "mul start", "="*10)
out = a * b
print_grad_fn(out)
print("="*10, "mul end", "="*10)
# div
print("="*10, "div start", "="*10)
out = a / b
print_grad_fn(out)
print("="*10, "div end", "="*10)
# matmul
print("="*10, "matmul start", "="*10)
out = torch.matmul(a, b)
print_grad_fn(out)
print("="*10, "matmul end", "="*10)
# bmmbackward
print("="*10, "bmmbackward start", "="*10)
a = torch.randn(1, 3, 4, 4, requires_grad=True)
b = torch.randn(1, 3, 4, 5, requires_grad=True)
out = a.matmul(b)
print_grad_fn(out)
print("="*10, "bmmbackward end", "="*10)
# mmbackward
print("="*10, "mmbackward start", "="*10)
a = torch.randn(4, 4, requires_grad=True)
b = torch.randn(4, 5, requires_grad=True)
out = a.matmul(b)
print_grad_fn(out)
print("="*10, "mmbackward end", "="*10)


<AddBackward0 object at 0x106de9930>
['_saved_alpha']
[(<AccumulateGrad object at 0x106dea110>, 0), (<AccumulateGrad object at 0x106dea9e0>, 0)]
<SubBackward0 object at 0x106de9930>
['_saved_alpha']
[(<AccumulateGrad object at 0x106dea1d0>, 0), (<AccumulateGrad object at 0x106de9180>, 0)]
<MulBackward0 object at 0x106d238b0>
['_saved_other', '_saved_self']
[(<AccumulateGrad object at 0x106de9930>, 0), (<AccumulateGrad object at 0x106dea4d0>, 0)]
<DivBackward0 object at 0x106d238b0>
['_saved_other', '_saved_self']
[(<AccumulateGrad object at 0x106de9930>, 0), (<AccumulateGrad object at 0x106dea4d0>, 0)]
<UnsafeViewBackward0 object at 0x106d238b0>
['_saved_self_sym_sizes']
[(<BmmBackward0 object at 0x106de9930>, 0)]
<UnsafeViewBackward0 object at 0x106d238b0>
['_saved_self_sym_sizes']
[(<BmmBackward0 object at 0x106de9930>, 0)]
<MmBackward0 object at 0x106d238b0>
['_saved_mat2', '_saved_mat2_sym_sizes', '_saved_mat2_sym_strides', '_saved_self', '_saved_self_sym_sizes', '_saved_self_sym_s

## concat, split

In [7]:
a = nn.Parameter(torch.randn(1, 4, 4, 4), requires_grad=True)
b = nn.Parameter(torch.randn(1, 6, 4, 4), requires_grad=True)
# concat
print("="*10, "concat start", "="*10)
out = torch.cat([a, b], dim=1)
print_grad_fn(out)
print("="*10, "concat end", "="*10)
# split
print("="*10, "split start", "="*10)
# out = torch.split(a, 2, dim=1)
out = torch.chunk(a, 2, dim=1)
out = out[0]
print_grad_fn(out)
print("="*10, "split end", "="*10)

<CatBackward0 object at 0x106d23d90>
['_saved_dim']
[(<AccumulateGrad object at 0x106d238b0>, 0), (<AccumulateGrad object at 0x106de9780>, 0)]
<SplitBackward0 object at 0x106d23760>
['_saved_dim', '_saved_self_sym_sizes', '_saved_split_size']
[(<AccumulateGrad object at 0x106d23550>, 0)]


## flatten, reshape, view, unsafeview, clone

In [8]:
a = torch.randn(1, 3, 4, 4, requires_grad=True)
# flatten
print("="*10, "flatten start", "="*10)
flat = nn.Flatten()
out = flat(a)
print_grad_fn(out)
print("="*10, "flatten end", "="*10)
# reshape
print("="*10, "reshape start", "="*10)
out = a.reshape(1, 3, 16)
print_grad_fn(out)
print("="*10, "reshape end", "="*10)
# view
print("="*10, "view start", "="*10)
out = a.view(1, 3, 16)
print_grad_fn(out)
print("="*10, "view end", "="*10)
# unsafeview
print("="*10, "unsafeview start", "="*10)
out = a.view(1, -1, 16)
print_grad_fn(out)
print("="*10, "unsafeview end", "="*10)
# clone 
print("="*10, "clone start", "="*10)
out = a.clone()
print_grad_fn(out)
print("="*10, "clone end", "="*10)

<ReshapeAliasBackward0 object at 0x106de9bd0>
['_saved_self_sym_sizes']
[(<AccumulateGrad object at 0x106de9720>, 0)]
<ReshapeAliasBackward0 object at 0x106de9bd0>
['_saved_self_sym_sizes']
[(<AccumulateGrad object at 0x106de9720>, 0)]
<ViewBackward0 object at 0x106de9bd0>
['_saved_self_sym_sizes']
[(<AccumulateGrad object at 0x106de9720>, 0)]
<ViewBackward0 object at 0x106de9bd0>
['_saved_self_sym_sizes']
[(<AccumulateGrad object at 0x106de9720>, 0)]
<CloneBackward0 object at 0x1419575b0>
[]
[(<AccumulateGrad object at 0x106de9bd0>, 0)]


## permute, expansion, squeeze, unsqueeze, transpose, einops

In [10]:
a = torch.randn(1, 2, 3, 4, requires_grad=True)
# permute
print("="*10, "permute start", "="*10)
out = a.permute(0, 2, 1, 3)
print_grad_fn(out)
print("="*10, "permute end", "="*10)
# squeeze
print("="*10, "squeeze start", "="*10)
out = a.squeeze()
print_grad_fn(out)
print("="*10, "squeeze end", "="*10)
# unsqueeze
print("="*10, "unsqueeze start", "="*10)
out = a.unsqueeze(0)
print_grad_fn(out)
print("="*10, "unsqueeze end", "="*10)
# transpose
print("="*10, "transpose start", "="*10)
out = a.transpose(0, 1)
print_grad_fn(out)
print("="*10, "transpose end", "="*10)
# einops b c h w -> b c (h w)
print("="*10, "einops start", "="*10)
from einops import rearrange, reduce, repeat
a = torch.randn(1, 3, 4, 4, requires_grad=True)
out = rearrange(a, 'b c h w -> b c (h w)')
print_grad_fn(out)
print("="*10, "einops end", "="*10)
# slice backward single
print("="*10, "slice backward single start", "="*10)
a = torch.randn(1, 3, 10, 32, requires_grad=True)
out = a[:, :16, :, :]
print_grad_fn(out)
print("="*10, "slice backward single end", "="*10)
# expansion
a = torch.randn(1, 3, 1, 1, requires_grad=True)
print("="*10, "expansion start", "="*10)
out = a.expand(1, 3, 4, 4)
print_grad_fn(out)
print("="*10, "expansion end", "="*10)

<PermuteBackward0 object at 0x141ddff10>
['_saved_dims']
[(<AccumulateGrad object at 0x141ddfe80>, 0)]
<SqueezeBackward0 object at 0x141ddff10>
['_saved_self_sym_sizes']
[(<AccumulateGrad object at 0x141ddffd0>, 0)]
<UnsqueezeBackward0 object at 0x141ddff10>
['_saved_dim']
[(<AccumulateGrad object at 0x141ddfeb0>, 0)]
<TransposeBackward0 object at 0x106d23d90>
['_saved_dim0', '_saved_dim1']
[(<AccumulateGrad object at 0x141ddff10>, 0)]
<ReshapeAliasBackward0 object at 0x106d23d90>
['_saved_self_sym_sizes']
[(<PermuteBackward0 object at 0x141dc31f0>, 0)]
<SliceBackward0 object at 0x141dc6bc0>
['_saved_dim', '_saved_end', '_saved_self_sym_sizes', '_saved_start', '_saved_step']
[(<SliceBackward0 object at 0x141dc31f0>, 0)]
<ExpandBackward0 object at 0x106d23d90>
['_saved_self_sym_sizes']
[(<AccumulateGrad object at 0x141dc31f0>, 0)]


## activations

In [None]:
a = torch.randn(1, 3, 4, 4, requires_grad=True)
# relu
print("="*10, "relu start", "="*10)
out = torch.relu(a)
print_grad_fn(out)
print("="*10, "relu end", "="*10)
# silu
print("="*10, "silu start", "="*10)
out = torch.nn.functional.silu(a)
print_grad_fn(out)
print("="*10, "silu end", "="*10)
# gelu
print("="*10, "gelu start", "="*10)
out = torch.nn.functional.gelu(a)
print_grad_fn(out)
print("="*10, "gelu end", "="*10)
# hardswish
print("="*10, "hardswish start", "="*10)
out = torch.nn.functional.hardswish(a)
print_grad_fn(out)
print("="*10, "hardswish end", "="*10)
# sigmoid
print("="*10, "sigmoid start", "="*10)
out = torch.sigmoid(a)
print_grad_fn(out)
print("="*10, "sigmoid end", "="*10)
# tanh
print("="*10, "tanh start", "="*10)
out = torch.tanh(a)
print_grad_fn(out)
print("="*10, "tanh end", "="*10)
# softmax
print("="*10, "softmax start", "="*10)
out = torch.softmax(a, dim=1)
print_grad_fn(out)
print("="*10, "softmax end", "="*10)
# logsoftmax
print("="*10, "logsoftmax start", "="*10)
out = torch.log_softmax(a, dim=1)
print_grad_fn(out)
print("="*10, "logsoftmax end", "="*10)

<ReluBackward0 object at 0x10b51a050>
['_saved_result']
[(<AccumulateGrad object at 0x168a0fbe0>, 0)]
torch.Size([1, 3, 4, 4]) torch.Size([1, 3, 4, 4])
<SiluBackward0 object at 0x10b51a050>
['_saved_self']
[(<AccumulateGrad object at 0x109cc7f40>, 0)]
torch.Size([1, 3, 4, 4]) torch.Size([1, 3, 4, 4])
<GeluBackward0 object at 0x10b51a050>
['_saved_approximate', '_saved_self']
[(<AccumulateGrad object at 0x168a0fbe0>, 0)]
torch.Size([1, 3, 4, 4]) torch.Size([1, 3, 4, 4])
none
<HardswishBackward0 object at 0x10b70b340>
['_saved_self']
[(<AccumulateGrad object at 0x168a0fbe0>, 0)]
torch.Size([1, 3, 4, 4]) torch.Size([1, 3, 4, 4])
<SigmoidBackward0 object at 0x10b70b340>
['_saved_result']
[(<AccumulateGrad object at 0x168a0fbe0>, 0)]
torch.Size([1, 3, 4, 4]) torch.Size([1, 3, 4, 4])
<TanhBackward0 object at 0x10b70b340>
['_saved_result']
[(<AccumulateGrad object at 0x168a0fbe0>, 0)]
torch.Size([1, 3, 4, 4]) torch.Size([1, 3, 4, 4])
<SoftmaxBackward0 object at 0x10b70b340>
['_saved_dim', '_s

## poolings

In [None]:
a = torch.randn(1, 3, 4, 4, requires_grad=True)
# AdaptiveAvgPool
print("="*10, "AdaptiveAvgPool start", "="*10)
out = torch.nn.AdaptiveAvgPool2d((1, 1))(a)
print_grad_fn(out)
print("="*10, "AdaptiveAvgPool end", "="*10)
# MaxPool
print("="*10, "MaxPool start", "="*10)
out = torch.nn.MaxPool2d((2, 2))(a)
print_grad_fn(out)
print("="*10, "MaxPool end", "="*10)
# AvgPool
print("="*10, "AvgPool start", "="*10)
out = torch.nn.AvgPool2d((2, 2))(a)
print_grad_fn(out)
print("="*10, "AvgPool end", "="*10)

torch.Size([1, 3, 1, 1])
<MeanBackward1 object at 0x11ba1b9d0>
['_saved_dim', '_saved_keepdim', '_saved_self', '_saved_self_sym_sizes']
[(<AccumulateGrad object at 0x11ba1b700>, 0)]
(1, 3, 4, 4)
True
torch.Size([1, 3, 2, 2])
<MaxPool2DWithIndicesBackward0 object at 0x11ba1b9d0>
['_saved_ceil_mode', '_saved_dilation', '_saved_kernel_size', '_saved_padding', '_saved_result1', '_saved_self', '_saved_stride']
[(<AccumulateGrad object at 0x11ba1ba60>, 0)]
torch.Size([1, 3, 4, 4])
torch.Size([1, 3, 2, 2])
<AvgPool2DBackward0 object at 0x11ba1b700>
['_saved_ceil_mode', '_saved_count_include_pad', '_saved_divisor_override', '_saved_kernel_size', '_saved_padding', '_saved_self', '_saved_stride']
[(<AccumulateGrad object at 0x11ba1ba60>, 0)]
torch.Size([1, 3, 4, 4])


## upsample

In [None]:
a = torch.randn(1, 3, 32, 32, requires_grad=True)
# upsample bilinear
print("="*10, "upsample bilinear start", "="*10)
out = torch.nn.functional.interpolate(a, scale_factor=2, mode='bilinear', align_corners=True)
print_grad_fn(out)
print("="*10, "upsample bilinear end", "="*10)
# upsample nearest
print("="*10, "upsample nearest start", "="*10)
out = torch.nn.functional.interpolate(a, scale_factor=2, mode='nearest')
print_grad_fn(out)
print("="*10, "upsample nearest end", "="*10)
# upsample bicubic
print("="*10, "upsample bicubic start", "="*10)
out = torch.nn.functional.interpolate(a, scale_factor=2, mode='bicubic', align_corners=True)
print_grad_fn(out)
print("="*10, "upsample bicubic end", "="*10)

torch.Size([1, 3, 64, 64])
<UpsampleBilinear2DBackward0 object at 0x280bdabf0>
['_saved_align_corners', '_saved_output_size', '_saved_scales_h', '_saved_scales_w', '_saved_self_sym_sizes']
[(<AccumulateGrad object at 0x281007460>, 0)]
torch.Size([1, 3, 64, 64])
<UpsampleNearest2DBackward0 object at 0x280bdabf0>
['_saved_output_size', '_saved_scales_h', '_saved_scales_w', '_saved_self_sym_sizes']
[(<AccumulateGrad object at 0x11ba18ca0>, 0)]
torch.Size([1, 3, 64, 64])
<UpsampleBicubic2DBackward0 object at 0x11ba18be0>
['_saved_align_corners', '_saved_output_size', '_saved_scales_h', '_saved_scales_w', '_saved_self_sym_sizes']
[(<AccumulateGrad object at 0x11ba1b760>, 0)]
