This notebook verifies whether `Conv3d` reduces to `Conv1d` when `kernel_width` and `kernel_height` both equal 1.

If so, we will use only `Conv3d` in our code, because it is more flexible. Setting `kernel_width` and `kernel_height` to 1 ignores time and pitch context. Setting them to more than 1 takes that context into account, giving a larger receptive field.

In [36]:
import torch
import torch.nn as nn

# Fix the global RNG so the input tensor (and every later random draw, such as
# layer initialisation and the shared weights) is reproducible across runs.
torch.manual_seed(0)

batch_size = 16
in_channels = 2
width = 128
height = 128
track_num = 5
# Layout: [batch_size, track_num, in_channels, width, height]
input_tensor = torch.randn([batch_size, track_num, in_channels, width, height])

In [37]:
# A Conv3d whose kernel spans all `track_num` tracks along the depth axis and a
# single cell along the other two axes, plus the Conv1d it should match. Both
# map `in_channels` -> `in_channels * track_num` output channels.
conv3d = nn.Conv3d(in_channels, in_channels * track_num, (track_num, 1, 1))
conv1d = nn.Conv1d(in_channels, in_channels * track_num, track_num)

weight_3d = conv3d.weight
weight_1d = conv1d.weight
bias_3d = conv3d.bias
bias_1d = conv1d.bias

# Conv3d weight is (out, in, kD, 1, 1) and Conv1d weight is (out, in, k),
# e.g. (10, 2, 5, 1, 1) vs (10, 2, 5) here; the biases are both (out,).
print(weight_3d.shape, weight_1d.shape)
print(bias_3d.shape, bias_1d.shape)

weight_init = torch.randn(weight_3d.shape)
bias_init = torch.randn(bias_3d.shape)

# Load the same values into both layers, each into its OWN storage.
# Assigning through `.data` made conv1d's weight a reshape-view of conv3d's
# weight and the two biases the very same tensor, so the layers silently
# shared memory. `copy_` under `no_grad` is the supported way to overwrite
# parameters in place and keeps the layers independent.
with torch.no_grad():
    conv3d.weight.copy_(weight_init)
    conv3d.bias.copy_(bias_init)
    conv1d.weight.copy_(weight_init.reshape(weight_1d.shape))
    conv1d.bias.copy_(bias_init.reshape(bias_1d.shape))

torch.Size([10, 2, 5, 1, 1]) torch.Size([10, 2, 5])
torch.Size([10]) torch.Size([10])


In [38]:
# Conv1d path: fold every (width, height, batch) position into the batch axis,
# so each 1-D sequence runs over the `track_num` tracks of one position:
# [batch, track, chan, W, H] -> [W, H, batch, chan, track] -> [W*H*batch, chan, track]
input_1d = (
    input_tensor
    .permute(3, 4, 0, 2, 1)
    .reshape(-1, in_channels, track_num)
)

# The kernel covers all tracks, so the length axis collapses to 1. Unfold the
# folded batch back to [W, H, batch, C_out], then reorder to
# [batch, C_out, W, H] so it lines up with the Conv3d output.
output_1d = (
    conv1d(input_1d)
    .reshape(width, height, batch_size, in_channels * track_num)
    .permute(2, 3, 0, 1)
)
print(output_1d.shape)

torch.Size([16, 10, 128, 128])


In [42]:
# Conv3d path: treat the tracks as the depth axis by swapping dims 1 and 2:
# [batch, track, chan, W, H] -> [batch, chan, track, W, H]
input_3d = input_tensor.transpose(1, 2)

# The (track_num, 1, 1) kernel collapses depth to 1. The reshape drops that axis,
# giving [batch, in_channels*track_num, W, H].
output_3d = conv3d(input_3d).reshape(
    batch_size, in_channels * track_num, width, height
)
print(output_3d.shape)

torch.Size([16, 10, 128, 128])


In [43]:
# Compare with the same tolerances as torch.allclose's defaults (rtol=1e-5,
# atol=1e-8), but via torch.testing.assert_close:
# - it does not broadcast, so a shape mismatch fails instead of passing silently;
# - it also checks dtype and device;
# - it reports the largest mismatch on failure;
# - unlike a bare `assert`, it is not stripped under `python -O`.
torch.testing.assert_close(output_1d, output_3d, rtol=1e-5, atol=1e-8)