In [1]:
import re

In [2]:
import torch
import torch.nn as nn
import torchvision.models as models

In [3]:
from transformers import DistilBertTokenizer, DistilBertModel

In [4]:
import torchinfo

In [8]:

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
bert_model = DistilBertModel.from_pretrained("distilbert-base-uncased")
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt')
output = bert_model(**encoded_input)

In [9]:
def insert_module(model, indices, modules):
    indices = indices if isinstance(indices, list) else [indices]
    modules = modules if isinstance(modules, list) else [modules]
    assert len(indices) == len(modules)

    layers_name = [name for name, _ in model.named_modules()][1:]
    for index, module in zip(indices, modules):
        layer_name = re.sub(r'(.)(\d)', r'[\2]', layers_name[index])
        exec("model.{name} = nn.Sequential(model.{name}, module)".format(name = layer_name))

In [10]:
class sparse_part(nn.Module):

    def __init__(self, in_dim, out_dim):

        super().__init__()

        self.lin = nn.Linear(in_features=in_dim, out_features=out_dim)
        self.laynorm = nn.LayerNorm(out_dim)
        self.act = nn.ReLU6()

    def forward(self, input):

        x = self.lin(input)
        x = self.laynorm(self.act(x))

        return x

In [195]:
class sparse_module(nn.Module):
    
    def __init__(self, layer_size: int, internal_size: int):

        super().__init__()

        self.layer = sparse_part(layer_size, internal_size)

        self.bottleneck = sparse_part(internal_size, internal_size//8)

        self.output = nn.Linear(internal_size//8, layer_size)

    def forward(self, input):

        x = self.layer(input)
        x = self.bottleneck(x)
        x = nn.functional.leaky_relu(self.output(x))

        return x, torch.linalg.vector_norm(x, ord = 1, dim = 2)

In [206]:
class ModifiedModel(nn.Module):
    def __init__(self, pretrained_model: nn.Module, custom_module: nn.Module, insert_after_layer: int, 
                 debug: bool = False):
        """
        Parameters:
            pretrained_model: torch.nn.Module any pretrained model from HF (it may not work for any actually. Moreover, it may work for 
            the only specified here for now, and may require more suffisticated additional tricks to make it work for another
            custom_module: torch.nn.Module The reqired module to add
            insert_after_layer: int The number of block to insert after
            debug: bool Show outputs of some layers during forward pass
        """
        
        super(ModifiedModel, self).__init__()
        self.debug = debug
        

        self.pretrained_model = pretrained_model

        self.embedding = list(self.pretrained_model.children())[0]

        self.arr = nn.ModuleList([])

        pattern = re.compile("transformer.layer.\d+$")
        
        for i in bert_model.named_modules():
            if re.match(pattern, i[0]):
                self.arr.append(i[1])
        
        self.custom_module = custom_module
        self.insert_place = insert_after_layer

    def forward(self, x: dict[torch.tensor]) -> torch.tensor:
        """
        Parameters:
            x: dict("input_ids": torch.tensor, "attention_mask": torch.tensor)
        
        Returns:
            The output of encoder layers, without any linear classification layer at the end for now.
        """
        
        enc = self.embedding(x["input_ids"])
        if self.debug:
            print(enc)

        for module in self.arr[:self.insert_place]:
            enc = module(**{"x": enc, "attn_mask": x["attention_mask"]})[0]
        
        enc, l1_norm = self.custom_module(enc)
        if self.debug:
            print(enc.size(), l1_norm, l1_norm.size())

        for module in self.arr[self.insert_place:]:
            enc = module(**{"x": enc, "attn_mask": x["attention_mask"]})[0]

        
        return enc

In [201]:
custom_module = sparse_module(768, 96) 

insert_after_layer = 2

modified_model = ModifiedModel(pretrained_model=bert_model, 
                               custom_module=custom_module, insert_after_layer=insert_after_layer,
                                debug = False)

In [202]:
layers_name = [name for name, _ in modified_model.named_modules()][1:]

In [203]:
list(modified_model.arr)

[TransformerBlock(
   (attention): MultiHeadSelfAttention(
     (dropout): Dropout(p=0.1, inplace=False)
     (q_lin): Linear(in_features=768, out_features=768, bias=True)
     (k_lin): Linear(in_features=768, out_features=768, bias=True)
     (v_lin): Linear(in_features=768, out_features=768, bias=True)
     (out_lin): Linear(in_features=768, out_features=768, bias=True)
   )
   (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
   (ffn): FFN(
     (dropout): Dropout(p=0.1, inplace=False)
     (lin1): Linear(in_features=768, out_features=3072, bias=True)
     (lin2): Linear(in_features=3072, out_features=768, bias=True)
     (activation): GELUActivation()
   )
   (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
 ),
 TransformerBlock(
   (attention): MultiHeadSelfAttention(
     (dropout): Dropout(p=0.1, inplace=False)
     (q_lin): Linear(in_features=768, out_features=768, bias=True)
     (k_lin): Linear(in_features=768, out_features=768, 

In [204]:
output = modified_model(encoded_input)

In [205]:
output

tensor([[[-0.4618, -0.2675,  0.2585,  ..., -0.2905, -0.2198,  0.1461],
         [-0.4584, -0.2628,  0.2569,  ..., -0.2902, -0.2140,  0.1538],
         [-0.4588, -0.2500,  0.2531,  ..., -0.2852, -0.2159,  0.1659],
         ...,
         [-0.4604, -0.2616,  0.2504,  ..., -0.2881, -0.2142,  0.1562],
         [-0.4467, -0.2545,  0.2575,  ..., -0.2914, -0.2174,  0.1486],
         [-0.4562, -0.2612,  0.2558,  ..., -0.2924, -0.2115,  0.1512]]],
       grad_fn=<NativeLayerNormBackward0>)

### another possible approach for basic architectures

In [5]:
import torchvision.models as models

In [70]:
class ModifiedModel(nn.Module):
    def __init__(self, pretrained_model: nn.Module, custom_module: nn.Module, insert_after_layer: int, 
                 debug: bool = False):
        super(ModifiedModel, self).__init__()
        self.debug = debug
        
        self.pretrained_model = pretrained_model
        
        self.features = nn.Sequential(*list(self.pretrained_model.children())[:insert_after_layer])
        
        self.remaining_layers = nn.Sequential(*list(self.pretrained_model.children())[insert_after_layer:-1])

        self.linear = list(self.pretrained_model.children())[-1]
        
        self.custom_module = custom_module

    def forward(self, x):
        
        x = self.features(x)
        if self.debug:
            print(x.size())
        
        
        x = self.custom_module(x)
        if self.debug:
            print(x.size())
        
        
        x = self.remaining_layers(x)
        if self.debug:
            print(x.size())

        x = torch.squeeze(x)

        x = self.linear(x)
        
        return x
        


pretrained_resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)


custom_module = nn.Conv2d(64, 64, kernel_size=3, padding=1) 

insert_after_layer = 3  

modified_model = ModifiedModel(pretrained_model=pretrained_resnet, 
                               custom_module=custom_module, insert_after_layer=insert_after_layer,
                                debug = True)


input_data = torch.randn(1, 3, 512, 512) 
output = pretrained_resnet(input_data)
print(output.shape)  

torch.Size([1, 1000])


In [58]:
torchinfo.summary(pretrained_resnet, (1, 3, 512, 512), device="cpu")

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   [1, 1000]                 --
├─Conv2d: 1-1                            [1, 64, 256, 256]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 256, 256]         128
├─ReLU: 1-3                              [1, 64, 256, 256]         --
├─MaxPool2d: 1-4                         [1, 64, 128, 128]         --
├─Sequential: 1-5                        [1, 64, 128, 128]         --
│    └─BasicBlock: 2-1                   [1, 64, 128, 128]         --
│    │    └─Conv2d: 3-1                  [1, 64, 128, 128]         36,864
│    │    └─BatchNorm2d: 3-2             [1, 64, 128, 128]         128
│    │    └─ReLU: 3-3                    [1, 64, 128, 128]         --
│    │    └─Conv2d: 3-4                  [1, 64, 128, 128]         36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 128, 128]         128
│    │    └─ReLU: 3-6                    [1, 64, 128, 128]         --
│

In [59]:
torchinfo.summary(modified_model, (1, 3, 512, 512))

torch.Size([1, 64, 256, 256])
torch.Size([1, 64, 256, 256])
torch.Size([1, 512, 1, 1])


Layer (type:depth-idx)                        Output Shape              Param #
ModifiedModel                                 [1000]                    --
├─Sequential: 1-1                             [1, 64, 256, 256]         --
│    └─Conv2d: 2-1                            [1, 64, 256, 256]         9,408
│    └─BatchNorm2d: 2-2                       [1, 64, 256, 256]         128
│    └─ReLU: 2-3                              [1, 64, 256, 256]         --
├─Conv2d: 1-2                                 [1, 64, 256, 256]         36,928
├─Sequential: 1-3                             [1, 512, 1, 1]            --
│    └─MaxPool2d: 2-4                         [1, 64, 128, 128]         --
│    └─Sequential: 2-5                        [1, 64, 128, 128]         --
│    │    └─BasicBlock: 3-1                   [1, 64, 128, 128]         73,984
│    │    └─BasicBlock: 3-2                   [1, 64, 128, 128]         73,984
│    └─Sequential: 2-6                        [1, 128, 64, 64]          --
│   