In [2]:
import torch
import torch.nn as nn

In [13]:
# Define the model: DoubleConv (two conv-BN-ReLU stages) and Funil (an upsampling decoder built from them)

class DoubleConv(nn.Module):
    """Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    Both convolutions use kernel size 3, stride 1 and padding 1, so the
    spatial size is preserved; channels go in_channels -> out_channels ->
    out_channels. The convolutions carry no bias because each one is
    immediately followed by BatchNorm, which has its own learnable shift.

    Args:
        in_channels: channel count of the incoming tensor.
        out_channels: channel count produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def conv_bn_relu(c_in, c_out):
            # One stage: bias-free 3x3 conv, then BatchNorm, then in-place ReLU.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Kept as a single flat Sequential named `conv` so state_dict keys
        # (conv.0.*, conv.1.*, conv.3.*, conv.4.*) stay the same.
        self.conv = nn.Sequential(
            *conv_bn_relu(in_channels, out_channels),
            *conv_bn_relu(out_channels, out_channels),
        )

    def forward(self, x):
        """Run ``x`` through both stages (presumably an (N, C, H, W) batch — TODO confirm)."""
        return self.conv(x)

class Funil(nn.Module):
    """Upsampling "funnel" decoder made of transposed convs and DoubleConv blocks.

    For each consecutive pair ``(features[i], features[i+1])`` two modules are
    appended to ``self.ups``:

    * a 2x2, stride-2 ``ConvTranspose2d`` (features[i] -> features[i+1]
      channels) that doubles the spatial height and width, and
    * a ``DoubleConv(features[i+1], features[i+1])`` refinement block.

    A final 1x1 convolution maps ``features[-1]`` channels back to
    ``features[0]``. For example, ``features=[3, 16, 64]`` turns an input of
    shape (N, 3, H, W) into (N, 3, 4H, 4W).

    Args:
        features: sequence of channel counts, starting with the input channel
            count. Must contain at least one entry; with a single entry there
            are no upsampling stages and only the 1x1 conv is applied.

    Raises:
        ValueError: if ``features`` is empty.
    """

    def __init__(self, features):
        super().__init__()
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel count")

        # Interleaved [up_0, block_0, up_1, block_1, ...]; order and attribute
        # name are kept so existing state_dicts still load.
        self.ups = nn.ModuleList()
        for c_in, c_out in zip(features, features[1:]):
            self.ups.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2))
            self.ups.append(DoubleConv(c_out, c_out))

        # 1x1 projection back to the input channel count.
        self.final_conv = nn.Conv2d(features[-1], features[0], kernel_size=1)

    def forward(self, x):
        """Apply every upsampling stage in order, then the 1x1 projection."""
        for layer in self.ups:
            x = layer(x)
        return self.final_conv(x)

In [14]:
# Channels 3 -> 16 -> 64: two stride-2 upsampling stages, then a 1x1 conv back to 3 channels.
model = Funil(features=[3, 16, 64])

In [15]:
# Smoke test: push a random batch through the decoder and check the output shape.
# With features=[3, 16, 64] the two stride-2 transposed convs upsample
# 64 -> 128 -> 256, and final_conv maps back to 3 channels.
# NOTE: the model is still in training mode here, so BatchNorm running
# statistics are updated by this forward pass.
x = torch.rand(10, 3, 64, 64)
with torch.no_grad():  # shape check only; no autograd graph needed
    y = model(x)

assert y.shape == (10, 3, 256, 256), f"unexpected output shape {tuple(y.shape)}"
y.shape

torch.Size([10, 3, 256, 256])