In [3]:
import torch
import torch.nn as nn

In [21]:
class DecoderBlock(nn.Module):
	def __init__(self, in_channels, out_channels, expansion=3, do_up_sampling=True):
		"""
		Decoder block module.

		Args:
			in_channels (int): Number of input channels.
			out_channels (int): Number of output channels.
			expansion (int, optional): Expansion factor. Default is 3.
		"""
		super(DecoderBlock, self).__init__()
		self.relu = nn.ReLU(inplace=True)

		self.cnn1 = nn.Conv2d(in_channels, in_channels *
							  expansion, kernel_size=1, stride=1)
		self.bnn1 = nn.BatchNorm2d(in_channels*expansion)

		# nearest neighbor x2
		self.do_up_sampling = do_up_sampling
		self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

		# DW conv/ c_in*exp x 5 x 5 x c_in*exp
		self.cnn2 = nn.Conv2d(in_channels*expansion, in_channels *
							  expansion, kernel_size=5, padding=2, stride=1)
		self.bnn2 = nn.BatchNorm2d(in_channels*expansion)

		self.cnn3 = nn.Conv2d(in_channels*expansion,
							  out_channels, kernel_size=1, stride=1)
		self.bnn3 = nn.BatchNorm2d(out_channels)
		
		self.skip_cnn = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)

	def forward(self, x):
		"""
		Forward pass through the decoder block.

		Args:
			x (torch.Tensor): Input tensor.

		Returns:
			torch.Tensor: Output tensor.
		"""
		print("DecoderBlock")
		print("input shape: ", x.shape)
		temp_x = x
		x = self.cnn1(x)
		x = self.bnn1(x)
		x = self.relu(x)
  
		print("cnn1 shape: ", x.shape)
   
		print("upsample shape: ", x.shape)
			
		x = self.cnn2(x)
		x = self.bnn2(x)
		x = self.relu(x)
  
		print("cnn2 shape: ", x.shape)
		
		x = self.cnn3(x)
		x = self.bnn3(x)
		
		("cnn3 shape: ", x.shape)
  
		# adding skip connection
		temp_x = self.skip_cnn(temp_x)

		print("skip shape: ", temp_x.shape)
  
		x = x + temp_x
  
		print("output shape: ", x.shape)
  
		if self.do_up_sampling:
			x = self.upsample(x)
  
		return x

In [22]:
model_x = DecoderBlock(512, 256, 3)

input = torch.randn(1, 512, 16, 16)
output = model_x(input)

assert output.shape == (1, 256, 32, 32)

DecoderBlock
input shape:  torch.Size([1, 512, 16, 16])
cnn1 shape:  torch.Size([1, 1536, 16, 16])
upsample shape:  torch.Size([1, 1536, 16, 16])
cnn2 shape:  torch.Size([1, 1536, 16, 16])
skip shape:  torch.Size([1, 256, 16, 16])
output shape:  torch.Size([1, 256, 16, 16])


In [24]:
model_x = DecoderBlock(32*4, 32, 3, do_up_sampling=False)

input = torch.randn(1, 32*4, 160, 160)
output = model_x(input)

assert output.shape == (1, 32, 160, 160)

DecoderBlock
input shape:  torch.Size([1, 128, 160, 160])
cnn1 shape:  torch.Size([1, 384, 160, 160])
upsample shape:  torch.Size([1, 384, 160, 160])
cnn2 shape:  torch.Size([1, 384, 160, 160])
skip shape:  torch.Size([1, 32, 160, 160])
output shape:  torch.Size([1, 32, 160, 160])


In [25]:
model_y = DecoderBlock(4 * 32, 32, 3, do_up_sampling=False)

input = torch.randn(1, 4 * 32, 160, 160)
output = model_y(input)

assert output.shape == (1, 32, 160, 160)

DecoderBlock
input shape:  torch.Size([1, 128, 160, 160])
cnn1 shape:  torch.Size([1, 384, 160, 160])
upsample shape:  torch.Size([1, 384, 160, 160])


cnn2 shape:  torch.Size([1, 384, 160, 160])
skip shape:  torch.Size([1, 32, 160, 160])
output shape:  torch.Size([1, 32, 160, 160])
