In [25]:
import itertools

import torch
import torchvision.models as models

In [26]:
# Load ImageNet-pretrained VGG16. This may take a few minutes on first run.
# torchvision >= 0.13 deprecates `pretrained=True` in favour of the `weights=`
# enum; older releases lack the enum, so fall back to the legacy flag there.
try:
    vgg = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
except AttributeError:  # torchvision < 0.13: no `VGG16_Weights`
    vgg = models.vgg16(pretrained=True)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /home/aishutin/.cache/torch/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [05:10<00:00, 1.78MB/s] 


In [27]:
# Bare trailing expression: Jupyter renders the module's repr (the layer stack,
# including the indexed `features` Sequential) as this cell's output.
vgg

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [40]:
# Search-space bounds used by choose_parameters_in_ConvTranspose2d_space when the
# caller does not supply explicit candidate values.
ConvTranspose2d_DEFAULT_KERNEL_DIM_SIZE = 5   # kernel side lengths tried: 1..5
ConvTranspose2d_DEFAULT_MAX_STRIDE = 4        # strides tried: 1..4
ConvTranspose2d_DEFAULT_MAX_DILATION = 4      # dilations tried: 1..4
ConvTranspose2d_DEFAULT_MAX_INP_PADDING = 4   # (input) paddings tried: 0..4
ConvTranspose2d_DEFAULT_MAX_OUT_PADDING = 4   # output paddings tried: 0..4


def calc_parameters(lay):
    """Return the total number of scalar parameters held by module ``lay``.

    Sums ``numel()`` over ``lay.parameters()`` (a 0-dim tensor counts as 1), so
    frozen parameters are counted too. A module without parameters gives 0.
    """
    return sum(param.numel() for param in lay.parameters())


def _conv_transpose_out_len(size, kernel, stride, padding, dilation, output_padding):
    """Output length of ConvTranspose2d along one spatial axis (shape formula from the PyTorch docs)."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


def _axis_pairs(values, symmetric):
    """Return (height, width) value pairs in nested-loop order; only equal pairs if ``symmetric``."""
    values = tuple(values)  # materialise: a generator could otherwise be iterated only once
    return [(h, w) for h, w in itertools.product(values, repeat=2) if h == w or not symmetric]


def choose_parameters_in_ConvTranspose2d_space(input_shape, output_shape, kernel_dims=None, symmetric=True,
                                               strides=None, dilations=None, paddings=None,
                                               output_paddings=None):
    """Brute-force ``torch.nn.ConvTranspose2d`` configurations that map one shape to another.

    Args:
        input_shape: ``(batch, C_in, H_in, W_in)``; the batch size is ignored.
        output_shape: ``(batch, C_out, H_out, W_out)``; the batch size is ignored.
        kernel_dims: candidate kernel side lengths
            (default ``1..ConvTranspose2d_DEFAULT_KERNEL_DIM_SIZE``). Any iterable,
            including a generator, is accepted.
        symmetric: if True, only equal height/width values are tried for every
            hyper-parameter.
        strides, dilations, paddings, output_paddings: optional candidate values per
            axis; default to the ranges implied by the ``ConvTranspose2d_DEFAULT_*``
            constants.

    Returns:
        A list of keyword-argument dicts for ``torch.nn.ConvTranspose2d`` whose layers
        turn a ``(1, C_in, H_in, W_in)`` tensor into ``(1, C_out, H_out, W_out)``,
        sorted by ascending parameter count (ties keep enumeration order).
        Empty if no configuration fits.
    """
    _, cin, hin, win = input_shape
    _, cout, hout, wout = output_shape
    target_shape = (1, cout, hout, wout)

    if kernel_dims is None:
        kernel_dims = range(1, 1 + ConvTranspose2d_DEFAULT_KERNEL_DIM_SIZE)
    if strides is None:
        strides = range(1, ConvTranspose2d_DEFAULT_MAX_STRIDE + 1)
    if dilations is None:
        dilations = range(1, 1 + ConvTranspose2d_DEFAULT_MAX_DILATION)
    if paddings is None:
        paddings = range(0, ConvTranspose2d_DEFAULT_MAX_INP_PADDING + 1)
    if output_paddings is None:
        output_paddings = range(0, ConvTranspose2d_DEFAULT_MAX_OUT_PADDING + 1)

    # Same enumeration order as nested loops stride > dilation > kernel > padding > output_padding.
    search_space = itertools.product(
        _axis_pairs(strides, symmetric),
        _axis_pairs(dilations, symmetric),
        _axis_pairs(kernel_dims, symmetric),
        _axis_pairs(paddings, symmetric),
        _axis_pairs(output_paddings, symmetric),
    )

    input_ex = torch.zeros((1, cin, hin, win))  # values are irrelevant, only the shape matters
    configurations = []
    with torch.no_grad():  # shape probing only; no autograd graph needed
        for stride, dilation, kernel, padding, out_padding in search_space:
            # Cheap analytic pre-filter so only shape-matching candidates build a layer.
            if (_conv_transpose_out_len(hin, kernel[0], stride[0], padding[0], dilation[0], out_padding[0]) != hout
                    or _conv_transpose_out_len(win, kernel[1], stride[1], padding[1], dilation[1],
                                               out_padding[1]) != wout):
                continue
            params = {'in_channels': cin,
                      'out_channels': cout,
                      'kernel_size': kernel,
                      'stride': stride,
                      'dilation': dilation,
                      'padding': padding,
                      'output_padding': out_padding}
            try:
                lay = torch.nn.ConvTranspose2d(**params)
                # The real forward pass stays authoritative: PyTorch may reject the
                # combination (e.g. output_padding too large for stride/dilation).
                if tuple(lay(input_ex).shape) != target_shape:
                    continue
            except (ValueError, RuntimeError):
                continue
            configurations.append((calc_parameters(lay), params))

    configurations.sort(key=lambda item: item[0])  # stable: ties keep enumeration order
    return [params for _, params in configurations]

In [43]:
def choose_parameters_in_Linear_space(input_shape, output_shape):
    """Return candidate ``torch.nn.Linear`` configurations for a shape mapping.

    Only the trailing sizes of the two shapes are used: they become
    ``in_features`` and ``out_features``. The result is a one-element list of
    keyword-argument dicts, matching the list-of-candidates return style of the
    ConvTranspose2d search above.
    """
    in_features = input_shape[-1]
    out_features = output_shape[-1]
    candidate = dict(in_features=in_features, out_features=out_features)
    return [candidate]

In [42]:
# Sanity check: search for ConvTranspose2d layers that map this conv's output
# shape back to its input shape, and print each candidate configuration.
# NOTE(review): `input` shadows the Python builtin; the name is kept because
# later cells may rely on it.
input = torch.randn((1, 3, 32, 32))
conv = torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
output = conv(input)
inverse_candidates = choose_parameters_in_ConvTranspose2d_space(output.shape, input.shape)
for candidate in inverse_candidates:
    print(candidate)

{'in_channels': 64, 'out_channels': 3, 'kernel_size': (1, 1), 'stride': (1, 1), 'dilation': (1, 1), 'padding': (0, 0), 'output_padding': (0, 0)}
{'in_channels': 64, 'out_channels': 3, 'kernel_size': (1, 1), 'stride': (1, 1), 'dilation': (2, 2), 'padding': (0, 0), 'output_padding': (0, 0)}
{'in_channels': 64, 'out_channels': 3, 'kernel_size': (1, 1), 'stride': (1, 1), 'dilation': (3, 3), 'padding': (0, 0), 'output_padding': (0, 0)}
{'in_channels': 64, 'out_channels': 3, 'kernel_size': (1, 1), 'stride': (1, 1), 'dilation': (3, 3), 'padding': (1, 1), 'output_padding': (2, 2)}
{'in_channels': 64, 'out_channels': 3, 'kernel_size': (1, 1), 'stride': (1, 1), 'dilation': (4, 4), 'padding': (0, 0), 'output_padding': (0, 0)}
{'in_channels': 64, 'out_channels': 3, 'kernel_size': (1, 1), 'stride': (1, 1), 'dilation': (4, 4), 'padding': (1, 1), 'output_padding': (2, 2)}
{'in_channels': 64, 'out_channels': 3, 'kernel_size': (2, 2), 'stride': (1, 1), 'dilation': (2, 2), 'padding': (1, 1), 'output_pad