In [1]:
import math

def conv_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1, ceil_mode=False):
    """
    Utility function for computing output of convolutions
    takes a tuple of (h,w) and returns a tuple of (h,w)
    """
    
    if type(h_w) is not tuple:
        h_w = (h_w, h_w)
    
    if type(kernel_size) is not tuple:
        kernel_size = (kernel_size, kernel_size)
    
    if type(stride) is not tuple:
        stride = (stride, stride)
    
    if type(pad) is not tuple:
        pad = (pad, pad)

    if ceil_mode:
        h = math.ceil((h_w[0] + (2 * pad[0]) - (dilation * (kernel_size[0] - 1)) - 1)/ stride[0] + 1)
        w = math.ceil((h_w[1] + (2 * pad[1]) - (dilation * (kernel_size[1] - 1)) - 1)/ stride[1] + 1)
    else:  
        h = (h_w[0] + (2 * pad[0]) - (dilation * (kernel_size[0] - 1)) - 1)// stride[0] + 1
        w = (h_w[1] + (2 * pad[1]) - (dilation * (kernel_size[1] - 1)) - 1)// stride[1] + 1
    
    return h, w

def convtransp_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):
    """
    Utility function for computing output of transposed convolutions
    takes a tuple of (h,w) and returns a tuple of (h,w)
    """
    
    if type(h_w) is not tuple:
        h_w = (h_w, h_w)
    
    if type(kernel_size) is not tuple:
        kernel_size = (kernel_size, kernel_size)
    
    if type(stride) is not tuple:
        stride = (stride, stride)
    
    if type(pad) is not tuple:
        pad = (pad, pad)
        
    h = (h_w[0] - 1) * stride[0] - 2 * pad[0] + kernel_size[0] + pad[0]
    w = (h_w[1] - 1) * stride[1] - 2 * pad[1] + kernel_size[1] + pad[1]
    
    return h, w

In [4]:
"""
        self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100)
        self.relu1_1 = nn.ReLU(inplace=True)
        self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)
        self.relu1_2 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/2
        

        # conv2
        self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
        self.relu2_1 = nn.ReLU(inplace=True)
        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
        self.relu2_2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/4

        self.score_fr = nn.Conv2d(128, n_class, 1)
        self.upscore = nn.ConvTranspose2d(n_class, n_class, 64, stride=32,
                                          bias=False)

        # conv3
        self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
        self.relu3_1 = nn.ReLU(inplace=True)
        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
        self.relu3_2 = nn.ReLU(inplace=True)
        self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
        self.relu3_3 = nn.ReLU(inplace=True)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/8
"""

h, w = conv_output_shape((10000,10000), kernel_size=3, pad =96) #conv1_1
print(h,",", w)
h, w = conv_output_shape((h,w), kernel_size=3, pad =1) #conv1_2
print(h,",", w)
h, w = conv_output_shape((h,w), kernel_size=1, stride=2, ceil_mode=True) #pool1
print(h,",", w)
h, w = conv_output_shape((h,w), kernel_size=3, pad =1) #conv2_1 
print(h,",", w)
h, w = conv_output_shape((h,w), kernel_size=3, pad =1) #conv2_2
print(h,",", w)
h, w = conv_output_shape((h,w), kernel_size=2, stride=2, ceil_mode=True) #pool2
print(h,",", w)
h, w = conv_output_shape((h,w), kernel_size=3, pad =1) #conv3_1 
print(h,",", w)
h, w = conv_output_shape((h,w), kernel_size=3, pad =1) #conv3_2
print(h,",", w)
h, w = conv_output_shape((h,w), kernel_size=3, pad =1) #conv3_3
print(h,",", w)
h, w = conv_output_shape((h,w), kernel_size=2, stride=2, ceil_mode=True) #pool2
print(h,",", w)


# self.score = nn.Conv2d(256, n_class, 1)
# # self.upscore = nn.ConvTranspose2d(n_class, n_class, 64, stride=32, bias=False)
#        self.upscore8 = nn.ConvTranspose2d(n_class, n_class, 16, stride=8, bias=False)
print()
h, w = conv_output_shape((h,w), kernel_size=1) #score
print(h,",", w)
h, w = convtransp_output_shape((h,w), kernel_size=16, stride=8) #upscore8
print(h,",", w)


10190 , 10190
10190 , 10190
5096 , 5096
5096 , 5096
5096 , 5096
2548 , 2548
2548 , 2548
2548 , 2548
2548 , 2548
1274 , 1274

1274 , 1274
10200 , 10200


In [3]:
# 10190 , 10190
# 10190 , 10190
# 5096 , 5096
# 5096 , 5096
# 5096 , 5096
# 2548 , 2548
# 2548 , 2548
# 2548 , 2548
# 2548 , 2548
# 1274 , 1274

# 1274 - 10000/8 = 24
# 24/2 = 12
