In [381]:
import torchvision.datasets as datasets
import torch
from torch import nn
import ml_collections
from tqdm import tqdm

# Download dataset
train_data = datasets.Flowers102(root='./flower-102/train', split='train', download=True)
val_data = datasets.Flowers102(root='./flower-102/val', split='val', download=True)
test_data = datasets.Flowers102(root='./flower-102/test', split='test', download=True)

Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/102flowers.tgz to flower-102\train\flowers-102\102flowers.tgz


100%|██████████████████████████████████████████████████████████████| 344862509/344862509 [00:31<00:00, 10974774.16it/s]


Extracting flower-102\train\flowers-102\102flowers.tgz to flower-102\train\flowers-102
Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/imagelabels.mat to flower-102\train\flowers-102\imagelabels.mat


100%|████████████████████████████████████████████████████████████████████████████| 502/502 [00:00<00:00, 505653.36it/s]


Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/setid.mat to flower-102\train\flowers-102\setid.mat


100%|██████████████████████████████████████████████████████████████████████| 14989/14989 [00:00<00:00, 15025913.64it/s]


Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/102flowers.tgz to flower-102\val\flowers-102\102flowers.tgz


100%|██████████████████████████████████████████████████████████████| 344862509/344862509 [00:31<00:00, 10800863.62it/s]


Extracting flower-102\val\flowers-102\102flowers.tgz to flower-102\val\flowers-102
Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/imagelabels.mat to flower-102\val\flowers-102\imagelabels.mat


100%|████████████████████████████████████████████████████████████████████████████| 502/502 [00:00<00:00, 506261.27it/s]


Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/setid.mat to flower-102\val\flowers-102\setid.mat


100%|██████████████████████████████████████████████████████████████████████| 14989/14989 [00:00<00:00, 15054698.91it/s]


Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/102flowers.tgz to flower-102\test\flowers-102\102flowers.tgz


100%|██████████████████████████████████████████████████████████████| 344862509/344862509 [00:29<00:00, 11641807.12it/s]


Extracting flower-102\test\flowers-102\102flowers.tgz to flower-102\test\flowers-102
Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/imagelabels.mat to flower-102\test\flowers-102\imagelabels.mat


100%|████████████████████████████████████████████████████████████████████████████| 502/502 [00:00<00:00, 506261.27it/s]


Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/setid.mat to flower-102\test\flowers-102\setid.mat


100%|████████████████████████████████████████████████████████████████████████████████████| 14989/14989 [00:00<?, ?it/s]


### Model Implementations:

In [369]:
# List of available pretrained resnets from pytorch
class Pretrains():
    resnet_versions = [
        'resnet18',
        'resnet34',
        'resnet50',
        'resnet101',
        'resnet152'
    ]
    vgg_versions = [
        'vgg11',
        'vgg11_bn',
        'vgg13',
        'vgg13_bn',
        'vgg16',
        'vgg16_bn',
        'vgg19',
        'vgg19_bn'
    ]

class PretrainBackbone(nn.Module):
    def __init__(self, config):
        super(ResNetBackbone, self).__init__()
        
        # Load pretrained ResNet/VGG backbone
        if config.pretrain in Pretrains.resnet_versions or config.resnet_version in Pretrains.vgg_versions:
            model = torch.hub.load('pytorch/vision:v0.10.0', config.pretrain, pretrained=True)
        else:
            raise ValueError('Invalid ResNet/VGG Version. Please select from: ' 
                             + ', '.join(Pretrains.resnet_versions + Pretrains.vgg_versions))
        
        # Segments out only the backbone layers as list, unpacks, and load into nn.Sequential
        backbone_layers = list(model.children())[:-1]
        self.backbone = nn.Sequential(*backbone_layers)
        
    def forward(self, x):
        x = self.backbone(x)
        return x

class ActivationFunction(nn.Module):
    def __init__(self, config):
        super(ActivationFunction, self).__init__()
        match config.type:
            case 'LeakyReLU':
                self.activation_func = nn.LeakyReLU(
                    config.negative_slope,
                    inplace = True
                )
            case 'ReLU':
                self.activation_func = nn.ReLU(inplace = True)
            case 'Softmax':
                self.activation_func = nn.Softmax(dim = self.dim)
            case _:
                raise ValueError('Invalid activation function or not implemented')
    def forward(self, x):
        return self.activation_func(x)

class Conv2dBlock(nn.Module):
    def __init__(self, config):
        super(Conv2dBlock, self).__init__()
        modules = []
        if config.layer_num < 1:
            raise ValueError('Number of layers cannot be less than 1')
        for layer_idx in range(config.layer_num):
            # Conv2d
            modules.append(nn.Conv2d(
                config.in_channels if not layer_idx else config.out_channels,
                config.out_channels,
                kernel_size = 3,
                padding = 1
            ))
            
            # Batch Normalization
            if config.use_batchnorm:
                modules.append(nn.BatchNorm2d(config.out_channels))
                
            # Activation function, skip this step if skip_last_activation is True
            if config.skip_last_activation and layer_idx == config.layer_num - 1:
                break   
            modules.append(ActivationFunction(config.activation_func))
        self.sequential = nn.Sequential(*modules)
        
    def forward(self, x):
        return self.sequential(x)

# Creates a mirrored Conv2dBlock
class RevConv2dBlock(nn.Module):
    def __init__(self, conv2d_block):
        super(RevConv2dBlock, self).__init__()
        
        # Reverses module from conv2d_block
        modules = list(conv2d_block.sequential)
        modules.reverse()
        module_iterator = iter(range(len(modules)))
        for idx in module_iterator:
            if isinstance(modules[idx], torch.nn.modules.batchnorm.BatchNorm2d):
                
                # Switch order of batch and conv2d
                modules[idx], modules[idx + 1] = modules[idx + 1], modules[idx]
                
                # Swap conv2d with convtranspose2d
                modules[idx] = nn.ConvTranspose2d(
                    modules[idx].out_channels,
                    modules[idx].in_channels,
                    kernel_size = modules[idx].kernel_size,
                    stride = modules[idx].stride,
                    padding = modules[idx].padding
                )
                
                modules[idx + 1] = nn.BatchNorm2d(modules[idx].out_channels)
                
                # Skip next index
                next(module_iterator)
            
        if isinstance(modules[0], ActivationFunction):
            activation_func = modules.pop(0)
            modules.append(activation_func)
            
        self.sequential = nn.Sequential(*modules)
        
    def forward(self, x):
        return self.sequential(x)
    
class VGGBackboneBlock(nn.Module):
    def __init__(self, config):
        super(VGGBackboneBlock, self).__init__()
        config.skip_last_activation = False
        
        # Conv2d
        self.conv2d_block = Conv2dBlock(config)
        
        # Maxpool
        self.maxpool = nn.MaxPool2d(
            kernel_size=config.compression_ratio, 
            stride=config.compression_ratio
        )
    
    def forward(self, x):
        out = self.conv2d_block(x)
        out = self.maxpool(out)
        return out
    
    def get_reverse(self):
        return RevVGGBackconeBlock(self)
    
class RevVGGBackconeBlock(nn.Module):
    def __init__(self, vgg_backbone_block):
        super(VGGBackboneBlock).__init__()
        # To be implemented
    def forward(self, x):
        # To be implemented
        return x
    
class ResidualBlock(nn.Module):
    def __init__(self, config):
        super(ResidualBlock, self).__init__()
        
        # Main Conv2d block
        main_block_config = config
        main_block_config.layer_num = config.main_layer_num
        main_block_config.skip_last_activation = True
        self.main_block = Conv2dBlock(main_block_config)
        
        # Shortcut Conv2d block, we leave self.shortcut_block as undefined if shortcut layer depth = 0
        if config.shortcut_layer_num:
            shortcut_block_config = config
            shortcut_block_config.layer_num = config.shortcut_layer_num
            shortcut_block_config.skip_last_activation = True
            self.shortcut_block = Conv2dBlock(shortcut_block_config)
            
        self.activation_func = ActivationFunction(config.activation_func)
        
        # Optional maxpooling layer if compression_ratio is set
        if hasattr(config, 'compression_ratio'):
            self.maxpool = nn.MaxPool2d(
                kernel_size=config.compression_ratio, 
                stride=config.compression_ratio
        )
    
    def forward(self, x):
        out = self.main_block(x)
        if hasattr(self, 'shortcut_block'):
            out += self.shortcut_block(x)
        else:
            out += x
            
        out = self.activation_func(out)
            
        if hasattr(self, 'maxpool'):
            out = self.maxpool(out)
            
        return out

    def get_reverse(self):
        # Get reversed version
        return RevResidualBlock(self)

class RevResidualBlock(nn.Module):
    def __init__(self, residual_block):
        super(RevResidualBlock, self).__init__()
        self.main_block = RevConv2dBlock(residual_block.main_block)
        
        if hasattr(residual_block, 'shortcut_block'):
            self.shortcut_block = RevConv2dBlock(residual_block.shortcut_block)
            
        if hasattr(residual_block, 'maxpool'):
            self.upsample = nn.Upsample(scale_factor=residual_block.maxpool.stride)
            
    def forward(self, x):
        if hasattr(self, 'upsample'):
            x = self.upsample(x)
        else:
            x = x
            
        out = self.main_block(x)
        
        if hasattr(self, 'shortcut_block'):
            out += self.shortcut_block(x)
        else:
            out += x
        return out
    
class Encoder(nn.Module):
    def __init__(self, config):
        super(Encoder, self).__init__()
        modules = []
        match config.type:
            case 'residual_blocks':
                for idx, block_feature in enumerate(config.features):
                    if not idx:
                        in_channels = config.in_channels
                        out_channels = block_feature
                    else:
                        in_channels = config.features[idx - 1]
                        out_channels = block_feature

                    block_config = ml_collections.ConfigDict({
                        'main_layer_num': config.main_layer_num,
                        'shortcut_layer_num': config.shortcut_layer_num,
                        'in_channels': in_channels,
                        'out_channels': out_channels,
                        'use_batchnorm': config.use_batchnorm,
                        'activation_func': config.activation_func,
                    })
                    if hasattr(config, 'compression_ratio'):
                        block_config.compression_ratio = config.compression_ratio

                    modules.append(ResidualBlock(block_config))
            case 'vgg_backbone_blocks':
                # To be implemented
                raise NotImplementedError('To be implemented')
                
        self.sequential = nn.Sequential(*modules)
    def forward(self, x):
        return self.sequential(x)
        
class Decoder(nn.Module):
    def __init__(self, arg):
        super(Decoder, self).__init__()
        # Initialize by mirroring encoder
        if isinstance(arg, Encoder):
            encoder = arg
            modules = list(encoder.sequential)
            modules.reverse()
            
            for idx in range(len(modules)):
                modules[idx] = modules[idx].get_reverse()
            self.sequential = nn.Sequential(*modules)
        # Initialize by config (not implemented since we are using mirrored encoder/decoder)
        else:
            raise NotImplementedError('This decoder class is only implemented to be initialized by mirroring an encoder class')
    def forward(self, x):
        return self.sequential(x)
    
class AutoEncoder(nn.Module):
    def __init__(self, config):
        super(AutoEncoder, self).__init__()
        # Encoder
        encoder_config = config.encoder_config
        encoder_config.in_channels = config.in_channels
        self.encoder = Encoder(encoder_config)
        
        # Check for bottleneck input size by passing dummy input to encoder
        dummy_input = torch.randn(1, config.in_channels, config.in_dimension[0], config.in_dimension[1])
        out = self.encoder.forward(dummy_input)
        out_dimension = list(out.size())
        in_bottleneck = out_dimension[1] * out_dimension[2] * out_dimension[3]
        
        # Bottleneck
        self.bottleneck = nn.Linear(in_bottleneck, config.bottleneck_width)
        
        # Decoder
        self.decoder = Decoder(self.encoder)
        
    def forward(self, x):
        out = self.encoder(x)
        encoder_out_shape = out.size() 
        flatten = out.view(out.size(0), -1)
        out = self.bottleneck(flatten)
        reshaped = out.view(out.size()[0], out.size()[1], 1, 1)
        out = self.decoder(reshaped)
        
        return out

### Test:

In [370]:
config_dict = {
    'in_dimension': (224, 224),
    'in_channels': 3,
    'encoder_config': {
        'type': 'residual_blocks',
        'compression_ratio': 2,
        'features': [64, 128, 256, 512, 512, 512],
        'main_layer_num': 3,
        'shortcut_layer_num': 1,
        'use_batchnorm': True,
        'activation_func': {
            'type': 'LeakyReLU',
            'negative_slope': 0.1
        },
    },
    'decoder_config': {
        'mirror_encoder': True
    },
    'bottleneck_width': 512
}
test_config = ml_collections.ConfigDict(config_dict)

autoencoder = AutoEncoder(test_config)
autoencoder.forward(torch.randn(1, 3, 224, 224))

tensor([[[[ 0.5516, -0.2068, -0.4680,  ..., -0.8531, -0.4450,  0.1232],
          [ 0.8301,  0.7162,  0.5984,  ..., -1.5186, -0.2426, -0.2329],
          [ 0.8712,  1.0302,  1.0262,  ..., -0.5684,  0.7226,  0.5514],
          ...,
          [ 1.0588,  0.4694,  0.0622,  ...,  1.3015,  1.8702,  1.1376],
          [ 1.1240,  0.0624, -0.1162,  ...,  1.1003,  1.2729,  1.0246],
          [ 1.2998,  0.0777,  0.1123,  ...,  0.6273,  0.9946,  0.2985]],

         [[ 0.0466,  0.8038,  0.8825,  ...,  0.8605,  0.5045,  0.8763],
          [ 0.4250, -0.4446, -0.6750,  ...,  0.5111, -0.1032,  0.2369],
          [-0.1207, -0.9937, -0.6350,  ...,  0.9737, -0.2861,  0.0639],
          ...,
          [ 1.1805,  1.4365,  1.4348,  ..., -0.0631, -0.1165,  0.6880],
          [ 0.2076,  0.7495,  0.7422,  ...,  0.1660, -0.2923,  0.7185],
          [-0.1095,  0.4830,  0.3474,  ..., -0.5860, -0.7834,  0.4724]],

         [[-0.7530, -0.1011, -0.5831,  ..., -1.1542, -1.1337, -0.6581],
          [ 0.0807,  0.0728, -

In [383]:
train_data

Dataset Flowers102
    Number of datapoints: 1020
    Root location: ./flower-102/train
    split=train

### Hyperparameters

In [441]:
import copy

# All hyperparameters to be tuned for the autoencoder network
configs_dict = {
    'in_dimension': [(224, 224)],
    'in_channels': [3],
    'encoder_config': {
        'type': ['residual_blocks'],
        'compression_ratio': [2],
        'features': [
            [64, 128, 256, 512, 512, 512],
            [64, 128, 256, 512, 512],
            [64, 128, 256],
            [32, 64, 128, 256, 512, 512, 512],
            [32, 64, 128, 256, 512, 512],
            [32, 64, 128, 256, 512],
            [16, 32, 64, 128, 256, 512, 512],
            [16, 32, 64, 128, 256, 512],
            [16, 32, 64, 128, 256],
        ],
        'main_layer_num': [2, 1],
        'shortcut_layer_num': [1, 0],
        'use_batchnorm': [True, False],
        'activation_func':[
            {
                'type': 'LeakyReLU',
                'negative_slope': 0.1
            },
            {
                'type': 'LeakyReLU',
                'negative_slope': 0.2
            },
            {
                'type': 'ReLU',
            }
        ] 
            
    },
    'decoder_config': {
        'mirror_encoder': [True]
    },
    'bottleneck_width': [256, 512, 1024, 2048]
}

# Parse ConfigDict with hyperparameters to be tuned and output list of all ConfigDicts to be tested
def generate_configs(configs):
    config_list = [ml_collections.ConfigDict()]
    for key in configs:
        new_config_list = []
        
        current_key_configs = []
        if isinstance(configs[key], list):
            current_key_configs = configs[key]
        elif isinstance(configs[key], ml_collections.config_dict.config_dict.ConfigDict):
            current_key_configs = generate_configs(configs[key])
        else:
            raise TypeError(configs[key] + ' is neither a list nor an ml_collections.ConfigDict object')
#         print(key, current_key_configs)
        for key_config in current_key_configs:
            for prev_config in config_list:
                prev_config_copy = copy.deepcopy(prev_config)
                prev_config_copy[key] = key_config
                new_config_list.append(prev_config_copy)
        config_list = new_config_list
    return config_list

config_list = generate_configs(configs)

In [447]:
print(len(config_list))
config_list[0]

864


bottleneck_width: 256
decoder_config:
  mirror_encoder: true
encoder_config:
  activation_func:
    negative_slope: 0.1
    type: LeakyReLU
  compression_ratio: 2
  features:
  - 64
  - 128
  - 256
  - 512
  - 512
  - 512
  in_channels: 3
  main_layer_num: 2
  shortcut_layer_num: 1
  type: residual_blocks
  use_batchnorm: true
in_channels: 3
in_dimension: !!python/tuple
- 224
- 224

### WIP Stuffs