In [1]:
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import dataset


class GlobalDescriptor(nn.Module):
    """Pool a [B, C, H, W] feature map into a [B, C] global descriptor.

    The pooling family is chosen by ``p``:

    * ``p == 1``   -- spatial mean pooling;
    * ``p == inf`` -- spatial max pooling;
    * otherwise    -- generalized-mean (GeM) pooling with exponent ``p``;
      activations are clamped to at least 1e-6 before exponentiation so that
      fractional roots stay real.
    """

    def __init__(self, p=1):
        super().__init__()
        self.p = p

    def forward(self, x):
        assert x.dim() == 4, 'the input tensor of GlobalDescriptor must be the shape of [B, C, H, W]'
        p = self.p
        if p == 1:
            return x.mean(dim=[-1, -2])
        if p == float('inf'):
            pooled = F.adaptive_max_pool2d(x, output_size=(1, 1))
            return torch.flatten(pooled, start_dim=1)
        # GeM: (average over the full H x W window of clamp(x)^p)^(1/p).
        window = (x.size(-2), x.size(-1))
        powered = x.clamp(min=1e-6).pow(p)
        pooled = F.avg_pool2d(powered, window).pow(1. / p)
        return pooled.squeeze(-1).squeeze(-1)

    def extra_repr(self):
        return f"p={self.p}"


class L2Norm(nn.Module):
    """L2-normalise each row of a [B, C] tensor (thin module wrapper around F.normalize)."""

    def forward(self, x):
        assert x.dim() == 2, 'the input tensor of L2Norm must be the shape of [B, C]'
        return F.normalize(x, dim=-1, p=2)


class CGD(nn.Module):
    """Combination of multiple Global Descriptors (CGD) head.

    Each character of ``gd_config`` adds one branch: ``'S'`` = mean pooling
    (p=1), ``'M'`` = max pooling (p=inf), ``'G'`` = GeM pooling (p=3).  Every
    branch pools the input map, projects it with a bias-free linear layer and
    L2-normalises it.  The branch outputs are concatenated and L2-normalised
    again.  ``feature_dim`` is split as evenly as possible across branches;
    the last branch absorbs the remainder.

    The first branch's raw pooled vector also feeds an auxiliary
    BatchNorm1d + Linear classifier.

    Args:
        gd_config: non-empty string over {'S', 'M', 'G'}.
        feature_dim: width of the concatenated output descriptor.
        num_classes: number of logits from the auxiliary classifier.
        in_dim: channel count of the input map; defaults to ``feature_dim``
            (which ties input width to output width, as before).

    Forward input is [B, in_dim, H, W]; returns ``(descriptor, classes)`` with
    shapes [B, feature_dim] and [B, num_classes].

    Raises:
        ValueError: if ``gd_config`` is empty, contains an unknown letter, or
            has more branches than ``feature_dim`` units.
    """

    # Pooling exponent for each descriptor letter.
    _EXPONENTS = {'S': 1, 'M': float('inf'), 'G': 3}

    def __init__(self, gd_config='MG', feature_dim=1024, num_classes=200, in_dim=None):
        super(CGD, self).__init__()
        if not gd_config:
            raise ValueError('gd_config must contain at least one descriptor letter')
        unknown = sorted(set(gd_config) - set(self._EXPONENTS))
        if unknown:
            # Reject typos instead of silently treating them as GeM.
            raise ValueError(
                f"unknown global descriptor(s) {unknown} in gd_config={gd_config!r}; "
                f"expected letters from {sorted(self._EXPONENTS)}")
        if feature_dim < len(gd_config):
            raise ValueError(
                f'feature_dim={feature_dim} is too small for {len(gd_config)} descriptors')
        if in_dim is None:
            in_dim = feature_dim

        self.n = len(gd_config)
        # Even split; the remainder goes to the last branch so widths sum to feature_dim.
        k = feature_dim // self.n
        branch_dims = [k] * self.n
        branch_dims[-1] += feature_dim - k * self.n

        global_descriptors = []
        main_modules = []
        for gd, dim in zip(gd_config, branch_dims):
            global_descriptors.append(GlobalDescriptor(p=self._EXPONENTS[gd]))
            main_modules.append(nn.Sequential(nn.Linear(in_dim, dim, bias=False), L2Norm()))
        self.global_descriptors = nn.ModuleList(global_descriptors)
        # Misspelt attribute name kept on purpose: it is part of the state_dict keys,
        # so renaming it would break loading existing checkpoints.
        self.main_moduels = nn.ModuleList(main_modules)

        self.auxiliary_module = nn.Sequential(
            nn.BatchNorm1d(in_dim),
            nn.Linear(in_dim, num_classes, bias=True))

    def forward(self, x):
        gds = []
        classes = None
        branches = zip(self.global_descriptors, self.main_moduels)
        for i, (pool, project) in enumerate(branches):
            gd = pool(x)
            if i == 0:
                # Auxiliary logits come from the first branch's un-projected descriptor.
                classes = self.auxiliary_module(gd)
            gds.append(project(gd))
        global_descriptor = F.normalize(torch.cat(gds, dim=-1), dim=-1)
        return global_descriptor, classes


class CGD2(nn.Module):
    """CGD variant that fuses the pooled descriptors with one linear layer.

    Each letter of ``gd_config`` selects a pooling branch ('S' -> mean, p=1;
    'M' -> max, p=inf; any other letter -> GeM, p=3), followed by L2
    normalisation.  The n branch outputs are concatenated into
    [B, n * feature_dim], mapped back to ``feature_dim`` by a bias-free linear
    layer, and L2-normalised.  Because of that linear layer, the input map must
    have exactly ``feature_dim`` channels.
    """

    def __init__(self, gd_config='SG', feature_dim=1024):
        super(CGD2, self).__init__()
        exponent = {'S': 1, 'M': float('inf')}
        self.global_descriptors = nn.ModuleList([
            nn.Sequential(GlobalDescriptor(p=exponent.get(letter, 3)), L2Norm())
            for letter in gd_config
        ])
        self.linear = nn.Linear(feature_dim * len(gd_config), feature_dim, bias=False)

    def forward(self, x):
        pooled = [branch(x) for branch in self.global_descriptors]
        fused = self.linear(torch.cat(pooled, dim=-1))
        return F.normalize(fused, dim=-1)


class MultiAtrous(nn.Module):
    """Multi-rate atrous (dilated) convolution block with a global-context branch.

    There is one 3x3 dilated conv per entry of ``dilation_rates``.  Padding
    equals dilation, so the spatial size is preserved.  There is also a
    global-average-pool -> 1x1 conv -> ReLU branch whose 1x1 output is
    bilinearly resized to the input's spatial size.  Every branch emits
    ``out_channel // 4`` channels, and the branches are concatenated on the
    channel axis.

    Args:
        in_channel: channels of the input map.
        out_channel: total output width with the default three rates (each
            branch contributes ``out_channel // 4`` channels).
        size: nominal spatial side of the input.  Accepted for backward
            compatibility only: the context branch is resized to the actual
            input size at run time, so any (also non-square) input works.
        dilation_rates: dilation (and padding) of each atrous branch.
    """

    def __init__(self, in_channel, out_channel, size, dilation_rates=(3, 6, 9)):
        super().__init__()
        branch_channel = out_channel // 4
        self.dilated_convs = [
            nn.Conv2d(in_channel, branch_channel,
                      kernel_size=3, dilation=rate, padding=rate)
            for rate in dilation_rates
        ]
        self.gap_branch = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channel, branch_channel, kernel_size=1),
            nn.ReLU(),
        )
        # The GAP branch is kept as the last entry so state_dict keys are unchanged.
        self.dilated_convs.append(self.gap_branch)
        self.dilated_convs = nn.ModuleList(self.dilated_convs)

    def forward(self, x):
        *atrous_branches, gap_branch = self.dilated_convs
        local_feat = [branch(x) for branch in atrous_branches]
        context = gap_branch(x)  # [B, out_channel // 4, 1, 1]
        # Same bilinear resize nn.Upsample(mode='bilinear') performs, but to the real input size.
        local_feat.append(
            F.interpolate(context, size=x.shape[-2:], mode='bilinear', align_corners=False))
        return torch.cat(local_feat, dim=1)


class DolgLocalBranch(nn.Module):
    """DOLG-style local branch: multi-atrous features re-weighted by learned attention.

    Pipeline: MultiAtrous -> 1x1 conv -> ReLU -> 1x1 conv (no bias) -> BatchNorm.
    The BatchNorm output is used twice:

    * L2-normalised across channels to form the local features;
    * passed through ReLU -> 1x1 conv -> Softplus to form positive
      per-position, per-channel attention weights.

    The branch returns their element-wise product, [B, out_channel, H', W'].
    """

    def __init__(self, in_channel, out_channel, hidden_channel=2048, image_size=512):
        super().__init__()
        atrous_size = int(image_size / 8)
        self.multi_atrous = MultiAtrous(in_channel, hidden_channel, size=atrous_size)
        self.conv1x1_1 = nn.Conv2d(hidden_channel, out_channel, kernel_size=1)
        self.conv1x1_2 = nn.Conv2d(out_channel, out_channel, kernel_size=1, bias=False)
        self.conv1x1_3 = nn.Conv2d(out_channel, out_channel, kernel_size=1)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(out_channel)
        self.softplus = nn.Softplus()

    def forward(self, x):
        hidden = self.relu(self.conv1x1_1(self.multi_atrous(x)))
        feat = self.bn(self.conv1x1_2(hidden))
        # Attention is derived from the BN output *before* it is L2-normalised.
        weights = self.softplus(self.conv1x1_3(self.relu(feat)))
        return F.normalize(feat, p=2, dim=1) * weights


class OrthogonalFusion(nn.Module):
    """Concatenate a global vector with the component of a local map orthogonal to it.

    At every spatial position, the projection of the local feature onto the
    global vector is subtracted.  The global vector is then broadcast over H x W
    and stacked in front of that orthogonal residual.

    Inputs: ``local_feat`` [B, C, H, W] and ``global_feat`` [B, C].
    Output: [B, 2C, H, W].  There is no epsilon in the division, so an all-zero
    global vector yields non-finite values.
    """

    def forward(self, local_feat, global_feat):
        norm = torch.norm(global_feat, p=2, dim=1)
        flat_local = torch.flatten(local_feat, start_dim=2)        # [B, C, HW]
        dots = torch.bmm(global_feat.unsqueeze(1), flat_local)     # [B, 1, HW]
        outer = torch.bmm(global_feat.unsqueeze(2), dots)          # [B, C, HW]
        projection = outer.view(local_feat.size()) / (norm * norm).view(-1, 1, 1, 1)
        orthogonal_comp = local_feat - projection
        tiled_global = global_feat[:, :, None, None].expand(orthogonal_comp.size())
        return torch.cat([tiled_global, orthogonal_comp], dim=1)

class CGDolgNet4(nn.Module):
    """DOLG-style embedding network with a CGD global head on backbone stage 4.

    The timm backbone returns two maps (``out_indices=(2, 4)``):

    * the shallower one goes through ``DolgLocalBranch``, giving
      [B, hidden_dim, H, W];
    * the deepest one goes through ``CGD``, whose descriptor is projected to
      ``hidden_dim`` by ``fc_1``.

    The two are orthogonally fused, average-pooled and projected to
    ``output_dim`` by ``fc_2``.

    Args:
        model_name: timm model name; ``'resnet101'`` is mapped to
            ``'gluon_resnet101_v1b'``.
        pretrained: load timm's pretrained weights.
        input_dim: number of input image channels.
        hidden_dim: width of the local branch and of the projected global vector.
        output_dim: width of the final embedding.
        image_size: input side length, forwarded to ``DolgLocalBranch``.
        gd_config: CGD descriptor letters.

    ``forward`` returns ``(feat, classes)``: the [B, output_dim] embedding and
    CGD's auxiliary classification logits.
    """

    def __init__(self,
                 model_name='resnet50', 
                 pretrained=True, 
                 input_dim=3, 
                 hidden_dim=1024, 
                 output_dim=512, 
                 image_size=224,
                 gd_config='MG'):
        super().__init__()

        if model_name == 'resnet101':
            model_name = 'gluon_resnet101_v1b'

        self.model = timm.create_model(
            model_name,
            pretrained=pretrained,
            features_only=True,
            in_chans=input_dim,
            out_indices=(2, 4)
        )
        # Channel widths of the two tapped stages (512 and 2048 for resnet50).
        local_channels, global_channels = self.model.feature_info.channels()

        self.orthogonal_fusion = OrthogonalFusion()
        self.local_branch = DolgLocalBranch(local_channels, hidden_dim, 2048, image_size)
        self.gap = nn.AdaptiveAvgPool2d(1)
        # CGD uses feature_dim as both its input width and its output width, so it must
        # equal the stage-4 channel count; fc_1 then consumes that same width.  With
        # CGD's default of 1024, a 2048-channel stage-4 map could not pass through.
        self.cgd = CGD(gd_config, feature_dim=global_channels)
        self.fc_1 = nn.Linear(global_channels, hidden_dim)
        self.fc_2 = nn.Linear(int(2*hidden_dim), output_dim)

    def forward(self, x):
        local_map, deep_map = self.model(x)

        local_feat = self.local_branch(local_map)        # [B, hidden_dim, H, W]
        global_feat, classes = self.cgd(deep_map)        # [B, global_channels]
        global_feat = self.fc_1(global_feat)             # [B, hidden_dim]
        feat = self.orthogonal_fusion(local_feat, global_feat)
        # flatten (not squeeze) so a batch of one keeps its batch dimension.
        feat = torch.flatten(self.gap(feat), start_dim=1)
        feat = self.fc_2(feat)

        return feat, classes


class CGDolgNet3(nn.Module):
    """DOLG-style embedding network with a CGD global head on backbone stage 3.

    The timm backbone returns two maps (``out_indices=(2, 3)``):

    * the shallower one goes through ``DolgLocalBranch``, giving
      [B, hidden_dim, H, W];
    * the deeper one goes through ``CGD``.

    The CGD descriptor is fused with the local map directly (``fc_1`` is
    constructed but not applied in ``forward``).  The fused map is then
    average-pooled and projected to ``output_dim`` by ``fc_2``.

    ``forward`` returns ``(feat, classes)``: the [B, output_dim] embedding and
    CGD's auxiliary classification logits.

    NOTE(review): because ``fc_1`` is skipped, the CGD output width (the stage-3
    channel count, 1024 for resnet50) must equal ``hidden_dim`` for the fusion's
    batched matmul to succeed.  The default ``hidden_dim=1024`` satisfies this.
    """

    def __init__(self,
                 model_name='resnet50', 
                 pretrained=True, 
                 input_dim=3, 
                 hidden_dim=1024, 
                 output_dim=512, 
                 image_size=224,
                 gd_config='MG'):
        super().__init__()

        if model_name == 'resnet101':
            model_name = 'gluon_resnet101_v1b'

        self.model = timm.create_model(
            model_name,
            pretrained=pretrained,
            features_only=True,
            in_chans=input_dim,
            out_indices=(2, 3)
        )
        # Channel widths of the two tapped stages (512 and 1024 for resnet50).
        local_channels, global_channels = self.model.feature_info.channels()

        self.orthogonal_fusion = OrthogonalFusion()
        self.local_branch = DolgLocalBranch(local_channels, hidden_dim, 2048, image_size)
        self.gap = nn.AdaptiveAvgPool2d(1)
        # CGD's feature_dim is also its expected input width.
        self.cgd = CGD(gd_config, feature_dim=global_channels)
        self.fc_1 = nn.Linear(global_channels, hidden_dim)  # defined but unused in forward
        self.fc_2 = nn.Linear(int(2*hidden_dim), output_dim)

    def forward(self, x):
        local_map, deep_map = self.model(x)

        local_feat = self.local_branch(local_map)   # [B, hidden_dim, H, W]
        global_feat, classes = self.cgd(deep_map)   # [B, global_channels]

        feat = self.orthogonal_fusion(local_feat, global_feat)
        # flatten (not squeeze) so a batch of one keeps its batch dimension.
        feat = torch.flatten(self.gap(feat), start_dim=1)
        feat = self.fc_2(feat)

        return feat, classes




class CGDolgNet(nn.Module):
    """DOLG-style embedding network with a CGD global head on backbone stage 3.

    The timm backbone returns two maps (``out_indices=(2, 3)``):

    * the shallower one goes through ``DolgLocalBranch``;
    * the deeper one goes through ``CGD``, whose descriptor is projected to
      ``hidden_dim`` by ``fc_1``.

    The two are orthogonally fused, average-pooled and projected to
    ``output_dim`` by ``fc_2``.  ``forward`` returns only the [B, output_dim]
    embedding; CGD's auxiliary logits are discarded.
    """

    def __init__(self,
                 model_name='resnet50', 
                 pretrained=True, 
                 input_dim=3, 
                 hidden_dim=1024, 
                 output_dim=512, 
                 image_size=224,
                 gd_config='SG'):
        super().__init__()

        if model_name == 'resnet101':
            model_name = 'gluon_resnet101_v1b'

        self.model = timm.create_model(
            model_name,
            pretrained=pretrained,
            features_only=True,
            in_chans=input_dim,
            out_indices=(2, 3)
        )
        # Channel widths of the two tapped stages (512 and 1024 for resnet50).
        local_channels, global_channels = self.model.feature_info.channels()

        self.orthogonal_fusion = OrthogonalFusion()
        self.local_branch = DolgLocalBranch(local_channels, hidden_dim, 2048, image_size)
        self.gap = nn.AdaptiveAvgPool2d(1)
        # CGD's feature_dim is also its expected input width.
        self.cgd = CGD(gd_config, feature_dim=global_channels)
        self.fc_1 = nn.Linear(global_channels, hidden_dim)
        self.fc_2 = nn.Linear(int(2*hidden_dim), output_dim)

    def forward(self, x):
        local_map, deep_map = self.model(x)

        local_feat = self.local_branch(local_map)   # [B, hidden_dim, H, W]
        # CGD returns (descriptor, logits); calling .squeeze() on that tuple used to raise.
        global_feat, _ = self.cgd(deep_map)
        global_feat = self.fc_1(global_feat)        # [B, hidden_dim]

        feat = self.orthogonal_fusion(local_feat, global_feat)
        # flatten (not squeeze) so a batch of one keeps its batch dimension.
        feat = torch.flatten(self.gap(feat), start_dim=1)
        feat = self.fc_2(feat)

        return feat

class CGDolgNet2(nn.Module):
    """DOLG-style embedding network using the ``CGD2`` global head on backbone stage 3.

    The timm backbone returns two maps (``out_indices=(2, 3)``):

    * the shallower one goes through ``DolgLocalBranch``;
    * the deeper one goes through ``CGD2``, whose descriptor is projected to
      ``hidden_dim`` by ``fc_1``.

    The two are orthogonally fused, average-pooled and projected to
    ``output_dim`` by ``fc_2``.  ``forward`` returns the [B, output_dim]
    embedding.
    """

    def __init__(self,
                 model_name='resnet50', 
                 pretrained=True, 
                 input_dim=3, 
                 hidden_dim=1024, 
                 output_dim=512, 
                 image_size=224,
                 gd_config='SG'):
        super().__init__()

        if model_name == 'resnet101':
            model_name = 'gluon_resnet101_v1b'

        self.model = timm.create_model(
            model_name,
            pretrained=pretrained,
            features_only=True,
            in_chans=input_dim,
            out_indices=(2, 3)
        )
        # Channel widths of the two tapped stages (512 and 1024 for resnet50).
        local_channels, global_channels = self.model.feature_info.channels()

        self.orthogonal_fusion = OrthogonalFusion()
        self.local_branch = DolgLocalBranch(local_channels, hidden_dim, 2048, image_size)
        self.gap = nn.AdaptiveAvgPool2d(1)
        # CGD2 requires its input channel count to equal feature_dim.
        self.cgd = CGD2(gd_config, feature_dim=global_channels)
        self.fc_1 = nn.Linear(global_channels, hidden_dim)
        self.fc_2 = nn.Linear(int(2*hidden_dim), output_dim)

    def forward(self, x):
        local_map, deep_map = self.model(x)

        local_feat = self.local_branch(local_map)           # [B, hidden_dim, H, W]
        # CGD2 already returns [B, C]; squeezing would drop the batch dim when B == 1.
        global_feat = self.fc_1(self.cgd(deep_map))         # [B, hidden_dim]

        feat = self.orthogonal_fusion(local_feat, global_feat)
        feat = torch.flatten(self.gap(feat), start_dim=1)
        feat = self.fc_2(feat)

        return feat


In [2]:
# Training split of the 'cub' dataset with the repo's train-time (non-Inception) transform.
trn_dataset = dataset.load(
    name='cub',
    root='/root/youngkim/image-retrieval/Proxy-Anchor-CVPR2020/data',
    mode='train',
    transform=dataset.utils.make_transform(is_train=True, is_inception=False),
)

In [3]:
# Shuffled training loader; drop_last keeps every batch at exactly 120 samples.
dl_tr = torch.utils.data.DataLoader(
    trn_dataset,
    batch_size=120,
    shuffle=True,
    drop_last=True,
    num_workers=4,
    pin_memory=True,
)

In [4]:
# Grab a single batch for shape inspection.
pbar = iter(dl_tr)
# Use the builtin next(): the Python-2-style `.next()` alias is not available on
# DataLoader iterators in newer PyTorch releases.
x, y = next(pbar)

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Batch-norm epsilon used by BatchNormRelu and ContrastiveHead below.
BATCH_NORM_EPSILON = 1e-5
# NOTE(review): not referenced by any code visible here.  BatchNormRelu passes no momentum,
# so nn.BatchNorm2d's default momentum=0.1 applies (see the printed repr), i.e. a running-average
# decay of 1 - 0.1 = 0.9.
BATCH_NORM_DECAY = 0.9  # == pytorch's default value as well


class BatchNormRelu(nn.Sequential):
    """BatchNorm2d (eps=BATCH_NORM_EPSILON) followed by ReLU, or by Identity when ``relu`` is False."""

    def __init__(self, num_channels, relu=True):
        norm = nn.BatchNorm2d(num_channels, eps=BATCH_NORM_EPSILON)
        activation = nn.ReLU() if relu else nn.Identity()
        super().__init__(norm, activation)


def conv(in_channels, out_channels, kernel_size=3, stride=1, bias=False):
    """Build an ``nn.Conv2d`` padded by ``(kernel_size - 1) // 2`` on each side.

    For odd kernel sizes with stride 1 this preserves the spatial size.
    """
    half_kernel = (kernel_size - 1) // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=half_kernel,
        bias=bias,
    )


class SelectiveKernel(nn.Module):
    """Selective-kernel 3x3 convolution.

    A single conv produces ``2 * out_channels`` maps, which are split into two
    branches.  Their sum is average-pooled spatially and fed through a small
    1x1-conv bottleneck (width ``max(int(out_channels * sk_ratio), min_dim)``)
    that emits two sets of logits.  A softmax over the branch axis then mixes
    the two branches per channel.
    """

    def __init__(self, in_channels, out_channels, stride, sk_ratio, min_dim=32):
        super().__init__()
        assert sk_ratio > 0.0
        doubled = 2 * out_channels
        self.main_conv = nn.Sequential(
            conv(in_channels, doubled, stride=stride),
            BatchNormRelu(doubled),
        )
        mid_dim = max(int(out_channels * sk_ratio), min_dim)
        self.mixing_conv = nn.Sequential(
            conv(out_channels, mid_dim, kernel_size=1),
            BatchNormRelu(mid_dim),
            conv(mid_dim, doubled, kernel_size=1),
        )

    def forward(self, x):
        branches = torch.stack(torch.chunk(self.main_conv(x), 2, dim=1), dim=0)  # 2, B, C, H, W
        summary = branches.sum(dim=0).mean(dim=[2, 3], keepdim=True)             # B, C, 1, 1
        logits = self.mixing_conv(summary)
        logits = torch.stack(torch.chunk(logits, 2, dim=1), dim=0)               # 2, B, C, 1, 1
        weights = F.softmax(logits, dim=0)
        return (branches * weights).sum(dim=0)


class Projection(nn.Module):
    """Shortcut projection for a bottleneck block, followed by BatchNorm (no ReLU).

    Without selective kernels this is a strided 1x1 conv.  With
    ``sk_ratio > 0`` it instead zero-pads one row/column on the bottom/right,
    average-pools 2x2 with the given stride, then applies an unstrided 1x1
    conv.
    """

    def __init__(self, in_channels, out_channels, stride, sk_ratio=0):
        super().__init__()
        if sk_ratio > 0:
            pad_bottom_right = nn.ZeroPad2d((0, 1, 0, 1))
            pool = nn.AvgPool2d(kernel_size=2, stride=stride, padding=0)
            self.shortcut = nn.Sequential(
                pad_bottom_right,
                pool,
                conv(in_channels, out_channels, kernel_size=1),
            )
        else:
            self.shortcut = conv(in_channels, out_channels, kernel_size=1, stride=stride)
        self.bn = BatchNormRelu(out_channels, relu=False)

    def forward(self, x):
        projected = self.shortcut(x)
        return self.bn(projected)


class BottleneckBlock(nn.Module):
    """ResNet bottleneck: 1x1 reduce -> 3x3 (or selective kernel) -> 1x1 expand (x4).

    The residual path is either a ``Projection`` (when ``use_projection``) or
    the identity.  ReLU is applied after the residual sum.
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, stride, sk_ratio=0, use_projection=False):
        super().__init__()
        wide = out_channels * 4
        self.projection = (Projection(in_channels, wide, stride, sk_ratio)
                           if use_projection else nn.Identity())
        layers = [conv(in_channels, out_channels, kernel_size=1), BatchNormRelu(out_channels)]
        if sk_ratio > 0:
            layers += [SelectiveKernel(out_channels, out_channels, stride, sk_ratio)]
        else:
            layers += [conv(out_channels, out_channels, stride=stride), BatchNormRelu(out_channels)]
        layers += [conv(out_channels, wide, kernel_size=1), BatchNormRelu(wide, relu=False)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return F.relu(self.projection(x) + self.net(x))


class Blocks(nn.Module):
    """A ResNet stage: one projecting bottleneck followed by ``num_blocks - 1`` identity ones.

    Only the first block applies ``stride``.  ``channels_out`` is
    ``out_channels * BottleneckBlock.expansion``.
    """

    def __init__(self, num_blocks, in_channels, out_channels, stride, sk_ratio=0):
        super().__init__()
        self.channels_out = out_channels * BottleneckBlock.expansion
        head = BottleneckBlock(in_channels, out_channels, stride, sk_ratio, True)
        tail = [BottleneckBlock(self.channels_out, out_channels, 1, sk_ratio)
                for _ in range(num_blocks - 1)]
        self.blocks = nn.ModuleList([head, *tail])

    def forward(self, x):
        out = x
        for block in self.blocks:
            out = block(out)
        return out


class Stem(nn.Sequential):
    """ResNet input stem producing ``64 * width_multiplier`` channels at 1/4 resolution.

    With selective kernels (``sk_ratio > 0``) it uses three stacked 3x3 convs,
    the first one strided.  Otherwise it uses a single strided 7x7 conv.  Both
    variants end with BatchNorm+ReLU and a strided 3x3 max-pool.
    """

    def __init__(self, sk_ratio, width_multiplier):
        half = 64 * width_multiplier // 2
        if sk_ratio > 0:
            layers = [
                conv(3, half, stride=2), BatchNormRelu(half),
                conv(half, half), BatchNormRelu(half),
                conv(half, half * 2),
            ]
        else:
            layers = [conv(3, half * 2, kernel_size=7, stride=2)]
        layers += [
            BatchNormRelu(half * 2),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        super().__init__(*layers)


class ResNet(nn.Module):
    """ResNet trunk: a ``Stem`` followed by four bottleneck stages.

    ``layers`` gives the number of blocks per stage.  Stage widths are
    64/128/256/512 times ``width_multiplier`` (before the x4 bottleneck
    expansion), and stages 2-4 are strided by 2.

    ``forward`` returns ``(x_2, x_3, h)``:

    * the output of stage 2;
    * the output of stage 3;
    * the spatial mean of stage 4's output, [B, channels_out].
    """

    def __init__(self, layers, width_multiplier, sk_ratio):
        super().__init__()
        stages = [Stem(sk_ratio, width_multiplier)]
        channels_in = 64 * width_multiplier
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))  # (base width, stride)
        for idx, (base_width, stride) in enumerate(stage_specs):
            stage = Blocks(layers[idx], channels_in, base_width * width_multiplier, stride, sk_ratio)
            stages.append(stage)
            channels_in = stage.channels_out
        self.channels_out = channels_in
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        stage2 = self.net[:3](x)  # stem + stages 1-2
        stage3 = self.net[3](stage2)
        pooled = self.net[4:](stage3).mean(dim=[2, 3])
        return stage2, stage3, pooled


class ContrastiveHead(nn.Module):
    """MLP projection head of ``num_layers`` bias-free Linear + BatchNorm1d layers.

    Hidden layers keep width ``channels_in`` and are followed by ReLU.  The
    last layer maps to ``out_dim``, has no ReLU, and its BatchNorm bias is
    explicitly zero-initialised.
    """

    def __init__(self, channels_in, out_dim=128, num_layers=3):
        super().__init__()
        self.layers = nn.ModuleList()
        for idx in range(num_layers):
            is_last = idx == num_layers - 1
            width = out_dim if is_last else channels_in
            self.layers.append(nn.Linear(channels_in, width, bias=False))
            norm = nn.BatchNorm1d(width, eps=BATCH_NORM_EPSILON, affine=True)
            if is_last:
                nn.init.zeros_(norm.bias)
            self.layers.append(norm)
            if not is_last:
                self.layers.append(nn.ReLU())

    def forward(self, x):
        out = x
        for layer in self.layers:
            out = layer(out)
        return out


def get_resnet(depth=50, width_multiplier=1, sk_ratio=0):  # sk_ratio=0.0625 is recommended
    """Build ``(ResNet, ContrastiveHead)`` for a supported depth (50/101/152/200).

    Raises KeyError for any other depth.
    """
    blocks_per_stage = {
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
        200: [3, 24, 36, 3],
    }[depth]
    backbone = ResNet(blocks_per_stage, width_multiplier, sk_ratio)
    head = ContrastiveHead(backbone.channels_out)
    return backbone, head


def name_to_params(checkpoint):
    """Infer ``(depth, width_multiplier, sk_ratio)`` from a checkpoint file name.

    Only the file name is inspected (directory components are ignored), so a
    parent folder such as ``r101_runs/`` cannot shadow the file's own encoding.

    Recognised tokens:

    * depth: ``r50_`` / ``r101_`` / ``r152_``;
    * width: ``_1x_`` / ``_2x_`` / ``_3x_``;
    * ``_sk1`` enables selective kernels (ratio 0.0625, otherwise 0).

    Args:
        checkpoint: path or file name, e.g. ``.../r50_1x_sk1.pth``.

    Returns:
        ``(depth, width, sk_ratio)``, ready for ``get_resnet(*...)``.

    Raises:
        NotImplementedError: if no depth or width token is found.
    """
    import os

    name = os.path.basename(checkpoint)
    sk_ratio = 0.0625 if '_sk1' in name else 0

    for token, depth in (('r50_', 50), ('r101_', 101), ('r152_', 152)):
        if token in name:
            break
    else:
        raise NotImplementedError(f'cannot infer ResNet depth from checkpoint name {name!r}')

    for token, width in (('_1x_', 1), ('_2x_', 2), ('_3x_', 3)):
        if token in name:
            break
    else:
        raise NotImplementedError(f'cannot infer width multiplier from checkpoint name {name!r}')

    return depth, width, sk_ratio

In [6]:
# Parse depth / width multiplier / selective-kernel ratio from the checkpoint file name;
# the resulting tuple is shown as the cell output.
checkpoint = '/root/youngkim/image-retrieval/SimCLRv2-Pytorch/r50_1x_sk1.pth'
name_to_params(checkpoint)

(50, 1, 0.0625)

In [7]:
# Build the ResNet matching the parsed config (the ContrastiveHead is discarded) and display
# its repr.  NOTE: this cell does not load the checkpoint's weights, so r50 is randomly initialised.
r50, _ =  get_resnet(*name_to_params(checkpoint))
r50

ResNet(
  (net): Sequential(
    (0): Stem(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNormRelu(
        (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU()
      )
      (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (3): BatchNormRelu(
        (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU()
      )
      (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (5): BatchNormRelu(
        (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU()
      )
      (6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    )
    (1): Blocks(
      (blocks): ModuleList(
        (0): BottleneckBlock(
          (projection): Projection(
            (shortcut): Sequential(
          

In [8]:
# Forward one batch only to inspect feature-map shapes.  no_grad avoids building and keeping
# an autograd graph for the whole 120-image batch.
# NOTE(review): x is presumably already [B, 3, H, W]; .squeeze() would also drop B if B == 1.
r50.cuda()
with torch.no_grad():
    x_2, x_3, h = r50(x.squeeze().cuda())

In [9]:
# x_2 / x_3: outputs of ResNet stages 2 and 3; h: spatially averaged stage-4 output.
x_2.shape, x_3.shape, h.shape

(torch.Size([120, 512, 28, 28]),
 torch.Size([120, 1024, 14, 14]),
 torch.Size([120, 2048]))

In [21]:
# timm ResNet-50 feature extractor returning the stage-2 and stage-3 maps.
model = timm.create_model(
    'resnet50',
    features_only=True,
    out_indices=(2, 3),
    pretrained=True,
    in_chans=3,
)

In [22]:
# Same shape check for the timm backbone; no_grad skips autograd bookkeeping for the batch.
model.cuda()
with torch.no_grad():
    r_timm = model(x.squeeze().cuda())

In [23]:
# timm's stage-2 / stage-3 maps (out_indices=(2, 3)), for comparison with x_2 / x_3.
r_timm[0].shape, r_timm[1].shape

(torch.Size([120, 512, 28, 28]), torch.Size([120, 1024, 14, 14]))

In [27]:
# Same backbone built without an explicit out_indices; display layer4 to check its
# channel width and stride.
model_ = timm.create_model(
            'resnet50',
            pretrained=True,
            features_only=True,
            in_chans=3,
        )
model_.layer4

Sequential(
  (0): Bottleneck(
    (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act2): ReLU(inplace=True)
    (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act3): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): Bottleneck(
    (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, mome