In [93]:
import torchvision.models as torch_models
from torch import nn
import torch

In [94]:
class Special_Adapter_v1(nn.Module):
    """Bottleneck adapter: Conv2d -> BN -> ReLU -> ConvTranspose2d -> BN -> ReLU.

    Both convolutions use the same ``kernel_size`` with default stride and
    padding, so the transposed convolution undoes the spatial shrink of the
    first one and the output has the same shape as the input.

    Args:
        in_planes: channel count of the input (and of the output).
        mid_planes: channel count between the two convolutions.
        kernel_size: kernel size shared by both convolutions.
        use_alpha: when True, a learnable scalar (initialised to 0.02)
            multiplies the output.
        conv_group: ``groups`` argument forwarded to both convolutions.
    """

    def __init__(self, in_planes:int, mid_planes:int, kernel_size:int, use_alpha=True, conv_group=1):
        super().__init__()
        self.in_planes, self.mid_planes = in_planes, mid_planes

        # Down path: in_planes -> mid_planes.
        self.conv = nn.Conv2d(
            in_planes, mid_planes, kernel_size=kernel_size, groups=conv_group
        )
        self.bn1 = nn.BatchNorm2d(mid_planes)
        # One in-place ReLU instance is reused after both batch norms.
        self.relu = nn.ReLU(inplace=True)
        # Up path: mid_planes -> in_planes.
        self.convTransposed = nn.ConvTranspose2d(
            mid_planes, in_planes, kernel_size=kernel_size, groups=conv_group
        )
        self.bn2 = nn.BatchNorm2d(in_planes)

        self.use_alpha = use_alpha
        if self.use_alpha:
            # The `alpha` attribute only exists when use_alpha is truthy.
            self.alpha = nn.Parameter(torch.ones(1) * 0.02)
            print('Apply alpha!')

    def forward(self, x):
        """Run the adapter on ``x``; if ``x`` is a tuple only ``x[0]`` is used."""
        feats = x[0] if isinstance(x, tuple) else x

        feats = self.relu(self.bn1(self.conv(feats)))
        feats = self.relu(self.bn2(self.convTransposed(feats)))

        if not self.use_alpha:
            return feats
        return feats * self.alpha

In [95]:
# Build a torchvision ResNet-18 with its default constructor arguments; it is
# used below as the backbone that the adapter network wraps.
feature_extractor = torch_models.resnet18()

In [96]:
# Print the names of the backbone's direct children, in registration order.
# These are the names an adapter-aware forward pass can match against when
# deciding where to apply adapters.
for child_name, _child in feature_extractor.named_children():
    print(child_name)

conv1
bn1
relu
maxpool
layer1
layer2
layer3
layer4
avgpool
fc


In [97]:
class CNN_Adapter_Net_CIL_V2(nn.Module):
    """Wrap a backbone and add a residual adapter in front of selected layers.

    The backbone's direct children are executed one after another. For every
    child whose name is in ``layer_names`` the input is first replaced by
    ``x + adapter(x)`` (using the most recently added adapter for that layer)
    and then passed to the child itself.

    Adapters are created lazily inside ``forward``: whenever a layer holds
    fewer adapters than ``len(self.task_sizes)``, one new
    ``Special_Adapter_v1`` is appended, sized from the channel dimension of the
    incoming tensor.

    Args:
        fe: backbone module. NOTE: it is modified in place; its ``fc``
            attribute is replaced by ``nn.Identity()``.

    Attributes:
        layer_names: names of the backbone children that receive adapters.
        task_sizes: its length determines how many adapters each layer should
            hold; only the last adapter of each list is used in ``forward``.
        <layer>_adapters: one ``nn.ModuleList`` per entry of ``layer_names``.
    """

    def __init__(self,fe):
        super().__init__()
        self.feature_extractor = fe
        self.feature_extractor.fc = nn.Identity()
        self.layer_names = ['layer1','layer2','layer3','layer4']
        for layer_id in self.layer_names:
            self.register_module(self._adapter_id(layer_id), nn.ModuleList([]))
        self.task_sizes = [2]

    @staticmethod
    def _adapter_id(layer_name):
        """Return the attribute name of the adapter list for ``layer_name``."""
        return layer_name.replace('.', '_') + '_adapters'

    def forward(self,x):
        """Run the backbone children in order, applying adapters before the listed layers.

        The children are called directly, so any functional step that the
        backbone's own ``forward`` performs between children is not applied
        here.
        """
        for name, module in self.feature_extractor.named_children():
            print(f'{name}')  # debug trace of the execution order
            if name in self.layer_names:
                adapters = getattr(self, self._adapter_id(name))
                # Unpacking also enforces a 4-D (N, C, H, W) input at adapted layers.
                _, channels, _, _ = x.shape
                if len(adapters) < len(self.task_sizes):
                    print(f'Append new adapters')
                    # Create the adapter on the same device as the activations
                    # instead of assuming CUDA, so CPU inputs work too.
                    adapter = Special_Adapter_v1(channels, channels, 3).to(x.device)
                    # Freshly built modules default to training mode; follow the
                    # wrapper's mode so eval() is respected by the BatchNorm layers.
                    adapter.train(self.training)
                    # NOTE: parameters created here are unknown to any optimizer
                    # constructed before this forward pass.
                    adapters.append(adapter)

                x = x + adapters[-1](x)

            x = module(x)
        return x

In [98]:
# Wrap the backbone in the adapter network and move it to the GPU (requires CUDA).
# Note: the constructor replaces feature_extractor.fc with nn.Identity() in place.
net = CNN_Adapter_Net_CIL_V2(feature_extractor).cuda()

In [99]:
# Random 4-D input of shape (1, 3, 224, 224), created on the GPU, used to
# smoke-test the forward pass.
x = torch.rand(1, 3, 224, 224).cuda()
x.shape


torch.Size([1, 3, 224, 224])

In [100]:
# First forward pass: no adapters exist yet, so one is created for each of
# layer1..layer4 (see the "Append new adapters" lines in the printed trace).
res = net(x)

conv1
bn1
relu
maxpool
layer1
Append new adapters
Apply alpha!
layer2
Append new adapters
Apply alpha!
layer3
Append new adapters
Apply alpha!
layer4
Append new adapters
Apply alpha!
avgpool
fc


In [101]:
# Shape of the forward output. forward() calls the backbone's children one by
# one and fc is nn.Identity(), so the avgpool output comes back un-flattened.
res.shape

torch.Size([1, 512, 1, 1])

In [102]:
# Inspect the adapter list registered for layer1 after the first forward pass.
net.layer1_adapters

ModuleList(
  (0): Special_Adapter_v1(
    (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (convTransposed): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)