In [7]:
import torch
from torch import nn
from collections import OrderedDict
import torchvision.models as models

In [4]:
class BatchNorm1dNoBias(nn.BatchNorm1d):
    """``nn.BatchNorm1d`` whose bias (shift) parameter is excluded from training.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.BatchNorm1d``. After construction the bias parameter keeps whatever
    value the parent initialised it to, but has ``requires_grad = False`` so
    optimizers that filter on ``requires_grad`` will not update it.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # With affine=False the parent registers no bias at all (self.bias is
        # None); there is nothing to freeze, so do not fail on attribute access.
        if self.bias is not None:
            self.bias.requires_grad = False

In [8]:
class EncodeProject(nn.Module):
    """Convolutional encoder followed by a two-layer projection head.

    Args:
        hparams: mapping read with ``hparams[...]``. Required keys:
            ``'arch'`` -- ``'ResNet50'`` (2048-d encoder output) or
            ``'resnet18'`` (512-d encoder output);
            ``'data'`` -- the encoder is built with ``cifar_head=True`` when
            this equals ``'cifar'``.
            For ``'ResNet50'`` the whole mapping is also passed on to the
            encoder constructor.

    Raises:
        NotImplementedError: if ``hparams['arch']`` is neither of the two
            supported values.

    Attributes set here: ``convnet`` (the encoder), ``encoder_dim``,
    ``proj_dim`` (128) and ``projection`` (fc1 -> bn1 -> relu1 -> fc2 -> bn2).
    """

    def __init__(self, hparams):
        super().__init__()

        arch = hparams['arch']
        if arch == 'ResNet50':
            cifar_head = (hparams['data'] == 'cifar')
            self.convnet = ResNet50(cifar_head=cifar_head, hparams=hparams)
            self.encoder_dim = 2048
        elif arch == 'resnet18':
            # hparams is indexed as a mapping everywhere else in this class;
            # attribute access (hparams.data) raised AttributeError for a dict.
            cifar_head = (hparams['data'] == 'cifar')
            self.convnet = ResNet18(cifar_head=cifar_head)
            self.encoder_dim = 512
        else:
            raise NotImplementedError(f'unsupported arch: {arch!r}')

        # Only parameters that will receive gradients are counted.
        num_params = sum(p.numel() for p in self.convnet.parameters() if p.requires_grad)

        print(f'======> Encoder: output dim {self.encoder_dim} | {num_params/1e6:.3f}M parameters')

        self.proj_dim = 128
        # The linear layers carry no bias: each is immediately followed by a
        # batch-norm layer.
        projection_layers = [
            ('fc1', nn.Linear(self.encoder_dim, self.encoder_dim, bias=False)),
            ('bn1', nn.BatchNorm1d(self.encoder_dim)),
            ('relu1', nn.ReLU()),
            ('fc2', nn.Linear(self.encoder_dim, self.proj_dim, bias=False)),
            ('bn2', BatchNorm1dNoBias(self.proj_dim)),
        ]

        self.projection = nn.Sequential(OrderedDict(projection_layers))

    def forward(self, x, out='z'):
        """Encode ``x`` and, unless ``out == 'h'``, project the encoding.

        Args:
            x: input passed straight to ``self.convnet``.
            out: ``'h'`` returns the encoder output; any other value (default
                ``'z'``) returns the projection head's output.
        """
        h = self.convnet(x)
        if out == 'h':
            return h
        return self.projection(h)

In [31]:
class EncodePP(nn.Module):
    """ResNet50 encoder plus projection head, built without any hparams.

    The encoder is created with ``cifar_head=False`` and ``hparams=None`` and
    is treated as producing 2048 features; the head maps them to 128
    (fc1 -> bn1 -> relu1 -> fc2 -> bn2).
    """

    def __init__(self):
        super().__init__()
        self.convnet = ResNet50(cifar_head=False, hparams=None)
        self.encoder_dim = 2048

        # Count only the encoder parameters that require gradients.
        trainable_sizes = [p.numel() for p in self.convnet.parameters() if p.requires_grad]
        num_params = sum(trainable_sizes)
        print(f'======> Encoder: output dim {self.encoder_dim} | {num_params/1e6:.3f}M parameters')

        self.proj_dim = 128
        # Assemble the head one named layer at a time; insertion order is the
        # execution order inside nn.Sequential.
        head = OrderedDict()
        head['fc1'] = nn.Linear(self.encoder_dim, self.encoder_dim, bias=False)
        head['bn1'] = nn.BatchNorm1d(self.encoder_dim)
        head['relu1'] = nn.ReLU()
        head['fc2'] = nn.Linear(self.encoder_dim, self.proj_dim, bias=False)
        head['bn2'] = BatchNorm1dNoBias(self.proj_dim)
        self.projection = nn.Sequential(head)

    def forward(self, x, out='z'):
        """Return the encoder output when ``out == 'h'``, else its projection."""
        h = self.convnet(x)
        return h if out == 'h' else self.projection(h)

In [19]:
class Flatten(nn.Module):
    """Module wrapper around ``torch.flatten`` with a configurable start dim.

    Args:
        dim: first dimension to flatten; everything from ``dim`` through the
            last dimension is merged. With the default ``-1`` only the last
            dimension is selected, so inputs with at least one dimension keep
            their shape.
    """

    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, feat):
        """Flatten ``feat`` from ``self.dim`` onwards."""
        start = self.dim
        return torch.flatten(feat, start_dim=start)


class ResNetEncoder(models.resnet.ResNet):
    """TorchVision ResNet used as a feature extractor.

    ``forward`` stops after average pooling and flattening, so no final
    fully-connected layer is applied to the features. Nothing built by the
    parent constructor is deleted here; the classifier is only bypassed.

    Args:
        block: residual block class, forwarded to the TorchVision ResNet.
        layers: per-stage block counts, forwarded to the TorchVision ResNet.
        cifar_head: when true, the stem is replaced by a 3x3, stride-1
            convolution (plus fresh norm and ReLU) and the max-pool step is
            skipped in ``forward``.
        hparams: stored on the instance as ``self.hparams``; not read here.
    """

    def __init__(self, block, layers, cifar_head=False, hparams=None):
        super().__init__(block, layers)
        self.cifar_head = cifar_head
        if cifar_head:
            # Swap in a stem without spatial downsampling.
            self.conv1 = nn.Conv2d(
                3, 64, kernel_size=3, stride=1, padding=1, bias=False
            )
            self.bn1 = self._norm_layer(64)
            self.relu = nn.ReLU(inplace=True)
        self.hparams = hparams

        print('** Using avgpool **')

    def forward(self, x):
        """Return the pooled, flattened features for input batch ``x``."""
        # Stem: conv -> norm -> ReLU, then max-pool unless the CIFAR stem is used.
        feats = self.relu(self.bn1(self.conv1(x)))
        if not self.cifar_head:
            feats = self.maxpool(feats)

        # The four residual stages, in order.
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)

        pooled = self.avgpool(feats)
        # Keep dim 0 and merge everything after it.
        return torch.flatten(pooled, 1)

class ResNet18(ResNetEncoder):
    """ResNetEncoder configured with ``BasicBlock`` and ``[2, 2, 2, 2]`` layers.

    Args:
        cifar_head: forwarded to ``ResNetEncoder``; defaults to ``True`` here.
    """

    def __init__(self, cifar_head=True):
        block_type = models.resnet.BasicBlock
        stage_sizes = [2, 2, 2, 2]
        super().__init__(block_type, stage_sizes, cifar_head=cifar_head)


class ResNet50(ResNetEncoder):
    """ResNetEncoder configured with ``Bottleneck`` and ``[3, 4, 6, 3]`` layers.

    Args:
        cifar_head: forwarded to ``ResNetEncoder``; defaults to ``True`` here.
        hparams: forwarded to ``ResNetEncoder`` unchanged.
    """

    def __init__(self, cifar_head=True, hparams=None):
        block_type = models.resnet.Bottleneck
        stage_sizes = [3, 4, 6, 3]
        super().__init__(
            block_type,
            stage_sizes,
            cifar_head=cifar_head,
            hparams=hparams,
        )

In [22]:
# Bare ResNet50 encoder (no projection head) with the regular, non-CIFAR stem,
# kept around for the inspection cells that follow.
convnet = ResNet50(cifar_head=False, hparams=None)

** Using avgpool **


In [37]:
# NOTE(review): this import lives mid-notebook; any later cell calling
# `summary` depends on this cell having been run first.
from torchsummary import summary
# Layer-by-layer summary of the bare encoder for input size (3, 330, 330)
# -- presumably channels x height x width; verbose=0 presumably suppresses the
# library's own printing so only the returned summary is displayed (TODO confirm).
summary(convnet, (3,330,330), verbose=0)

Layer (type:depth-idx)                   Output Shape              Param #
├─Conv2d: 1-1                            [-1, 64, 165, 165]        9,408
├─BatchNorm2d: 1-2                       [-1, 64, 165, 165]        128
├─ReLU: 1-3                              [-1, 64, 165, 165]        --
├─MaxPool2d: 1-4                         [-1, 64, 83, 83]          --
├─Sequential: 1-5                        [-1, 256, 83, 83]         --
|    └─Bottleneck: 2-1                   [-1, 256, 83, 83]         --
|    |    └─Conv2d: 3-1                  [-1, 64, 83, 83]          4,096
|    |    └─BatchNorm2d: 3-2             [-1, 64, 83, 83]          128
|    |    └─ReLU: 3-3                    [-1, 64, 83, 83]          --
|    |    └─Conv2d: 3-4                  [-1, 64, 83, 83]          36,864
|    |    └─BatchNorm2d: 3-5             [-1, 64, 83, 83]          128
|    |    └─ReLU: 3-6                    [-1, 64, 83, 83]          --
|    |    └─Conv2d: 3-7                  [-1, 256, 83, 83]         16,38

In [30]:
# Print every direct child module of the bare encoder, one repr per child.
for module in convnet.children():
    print(module)

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
ReLU(inplace=True)
MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
Sequential(
  (0): Bottleneck(
    (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, t

In [32]:
# Full model for inspection: ResNet50 encoder plus projection head.
# EncodePP takes no arguments.
withhead = EncodePP()

** Using avgpool **


In [33]:
# Layer-by-layer summary of the full model (encoder + projection head) for
# input size (3, 330, 330); relies on `summary` imported in an earlier cell.
summary(withhead, (3,330,330), verbose=0)

Layer (type:depth-idx)                   Output Shape              Param #
├─ResNet50: 1-1                          [-1, 2048]                --
|    └─Conv2d: 2-1                       [-1, 64, 165, 165]        9,408
|    └─BatchNorm2d: 2-2                  [-1, 64, 165, 165]        128
|    └─ReLU: 2-3                         [-1, 64, 165, 165]        --
|    └─MaxPool2d: 2-4                    [-1, 64, 83, 83]          --
|    └─Sequential: 2-5                   [-1, 256, 83, 83]         --
|    |    └─Bottleneck: 3-1              [-1, 256, 83, 83]         75,008
|    |    └─Bottleneck: 3-2              [-1, 256, 83, 83]         70,400
|    |    └─Bottleneck: 3-3              [-1, 256, 83, 83]         70,400
|    └─Sequential: 2-6                   [-1, 512, 42, 42]         --
|    |    └─Bottleneck: 3-4              [-1, 512, 42, 42]         379,392
|    |    └─Bottleneck: 3-5              [-1, 512, 42, 42]         280,064
|    |    └─Bottleneck: 3-6              [-1, 512, 42, 42] 

In [35]:
# Print every direct child module of the full model (encoder, then head).
for module in withhead.children():
    print(module)

ResNet50(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1

In [39]:
# Lookup table keyed by an integer (presumably the ResNet depth -- TODO confirm).
# Each entry records the block type as a plain string (not a class) under
# 'block' and a list of four integers under 'layers'.
model_params = {
    depth: {'block': block_name, 'layers': stage_sizes}
    for depth, block_name, stage_sizes in (
        (18, 'ResidualBlock', [2, 2, 2, 2]),
        (34, 'ResidualBlock', [3, 4, 6, 3]),
        (50, 'BottleneckBlock', [3, 4, 6, 3]),
        (101, 'BottleneckBlock', [3, 4, 23, 3]),
        (152, 'BottleneckBlock', [3, 8, 36, 3]),
        (200, 'BottleneckBlock', [3, 24, 36, 3]),
    )
}

In [40]:
# Display the entry stored under key 18.
model_params[18]

{'block': 'ResidualBlock', 'layers': [2, 2, 2, 2]}

In [41]:
# Keep the entry stored under key 50 for further lookups.
params = model_params[50]

In [42]:
# Block type recorded for the selected entry -- a string name, not a class.
params['block']

'BottleneckBlock'