In [2]:
import re

In [3]:
import torch
import torch.nn as nn
import torchvision.models as models

In [4]:
from transformers import DistilBertTokenizer, DistilBertModel

In [35]:
import torchinfo

In [7]:
import torchvision.models as models

In [70]:
class ModifiedModel(nn.Module):
    def __init__(self, pretrained_model: nn.Module, custom_module: nn.Module, insert_after_layer: int, 
                 debug: bool = False):
        super(ModifiedModel, self).__init__()
        self.debug = debug
        
        self.pretrained_model = pretrained_model
        
        self.features = nn.Sequential(*list(self.pretrained_model.children())[:insert_after_layer])
        
        self.remaining_layers = nn.Sequential(*list(self.pretrained_model.children())[insert_after_layer:-1])

        self.linear = list(self.pretrained_model.children())[-1]
        
        self.custom_module = custom_module

    def forward(self, x):
        
        x = self.features(x)
        if self.debug:
            print(x.size())
        
        
        x = self.custom_module(x)
        if self.debug:
            print(x.size())
        
        
        x = self.remaining_layers(x)
        if self.debug:
            print(x.size())

        x = torch.squeeze(x)

        x = self.linear(x)
        
        return x
        


pretrained_resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)


custom_module = nn.Conv2d(64, 64, kernel_size=3, padding=1) 

insert_after_layer = 3  

modified_model = ModifiedModel(pretrained_model=pretrained_resnet, 
                               custom_module=custom_module, insert_after_layer=insert_after_layer,
                                debug = True)


input_data = torch.randn(1, 3, 512, 512) 
output = pretrained_resnet(input_data)
print(output.shape)  

torch.Size([1, 1000])


In [58]:
torchinfo.summary(pretrained_resnet, (1, 3, 512, 512), device="cpu")

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   [1, 1000]                 --
├─Conv2d: 1-1                            [1, 64, 256, 256]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 256, 256]         128
├─ReLU: 1-3                              [1, 64, 256, 256]         --
├─MaxPool2d: 1-4                         [1, 64, 128, 128]         --
├─Sequential: 1-5                        [1, 64, 128, 128]         --
│    └─BasicBlock: 2-1                   [1, 64, 128, 128]         --
│    │    └─Conv2d: 3-1                  [1, 64, 128, 128]         36,864
│    │    └─BatchNorm2d: 3-2             [1, 64, 128, 128]         128
│    │    └─ReLU: 3-3                    [1, 64, 128, 128]         --
│    │    └─Conv2d: 3-4                  [1, 64, 128, 128]         36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 128, 128]         128
│    │    └─ReLU: 3-6                    [1, 64, 128, 128]         --
│

In [59]:
torchinfo.summary(modified_model, (1, 3, 512, 512))

torch.Size([1, 64, 256, 256])
torch.Size([1, 64, 256, 256])
torch.Size([1, 512, 1, 1])


Layer (type:depth-idx)                        Output Shape              Param #
ModifiedModel                                 [1000]                    --
├─Sequential: 1-1                             [1, 64, 256, 256]         --
│    └─Conv2d: 2-1                            [1, 64, 256, 256]         9,408
│    └─BatchNorm2d: 2-2                       [1, 64, 256, 256]         128
│    └─ReLU: 2-3                              [1, 64, 256, 256]         --
├─Conv2d: 1-2                                 [1, 64, 256, 256]         36,928
├─Sequential: 1-3                             [1, 512, 1, 1]            --
│    └─MaxPool2d: 2-4                         [1, 64, 128, 128]         --
│    └─Sequential: 2-5                        [1, 64, 128, 128]         --
│    │    └─BasicBlock: 3-1                   [1, 64, 128, 128]         73,984
│    │    └─BasicBlock: 3-2                   [1, 64, 128, 128]         73,984
│    └─Sequential: 2-6                        [1, 128, 64, 64]          --
│   

In [87]:

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
bert_model = DistilBertModel.from_pretrained("distilbert-base-uncased")
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt')
output = bert_model(**encoded_input)

In [61]:
def insert_module(model, indices, modules):
    indices = indices if isinstance(indices, list) else [indices]
    modules = modules if isinstance(modules, list) else [modules]
    assert len(indices) == len(modules)

    layers_name = [name for name, _ in model.named_modules()][1:]
    for index, module in zip(indices, modules):
        layer_name = re.sub(r'(.)(\d)', r'[\2]', layers_name[index])
        exec("model.{name} = nn.Sequential(model.{name}, module)".format(name = layer_name))

In [64]:
layers_name = [name for name, _ in model.named_modules()][1:]

In [65]:
layers_name

['embeddings',
 'embeddings.word_embeddings',
 'embeddings.position_embeddings',
 'embeddings.LayerNorm',
 'embeddings.dropout',
 'transformer',
 'transformer.layer',
 'transformer.layer.0',
 'transformer.layer.0.attention',
 'transformer.layer.0.attention.dropout',
 'transformer.layer.0.attention.q_lin',
 'transformer.layer.0.attention.k_lin',
 'transformer.layer.0.attention.v_lin',
 'transformer.layer.0.attention.out_lin',
 'transformer.layer.0.sa_layer_norm',
 'transformer.layer.0.ffn',
 'transformer.layer.0.ffn.dropout',
 'transformer.layer.0.ffn.lin1',
 'transformer.layer.0.ffn.lin2',
 'transformer.layer.0.ffn.activation',
 'transformer.layer.0.output_layer_norm',
 'transformer.layer.1',
 'transformer.layer.1.attention',
 'transformer.layer.1.attention.dropout',
 'transformer.layer.1.attention.q_lin',
 'transformer.layer.1.attention.k_lin',
 'transformer.layer.1.attention.v_lin',
 'transformer.layer.1.attention.out_lin',
 'transformer.layer.1.sa_layer_norm',
 'transformer.layer.1.

In [66]:
class sparse_part(nn.Module):

    def __init__(self, in_dim, out_dim):

        super().__init__()

        self.lin = nn.Linear(in_features=in_dim, out_features=out_dim)
        self.laynorm = nn.LayerNorm(out_dim)
        self.act = nn.ReLU6()

    def forward(self, input):

        x = self.lin(input)
        x = self.laynorm(self.act(x))

        return x

In [67]:
class sparse_module(nn.Module):
    
    def __init__(self, layer_size: int, internal_size: int):

        super().__init__()

        self.layer = sparse_part(layer_size, internal_size)

        self.bootleneck = sparse_part(internal_size, internal_size//8)

        self.output = nn.Linear(internal_size//8, layer_size)

    def forward(self, input):

        x = self.layer(input)
        x = self.bottleneck(self.lin2(input))
        x = nn.functional.leaky_relu(self.output(x))

        return x, nn.functional.l1_loss(x)

In [132]:
class ModifiedModel(nn.Module):
    def __init__(self, pretrained_model: nn.Module, custom_module: nn.Module, insert_after_layer: int, 
                 debug: bool = False):
        super(ModifiedModel, self).__init__()
        self.debug = debug
        

        self.pretrained_model = pretrained_model

        self.embedding = list(self.pretrained_model.children())[0]
        

        self.features = nn.Sequential(*list(self.pretrained_model.children())[1:insert_after_layer])
        

        self.remaining_layers = nn.Sequential(*list(self.pretrained_model.children())[insert_after_layer:])
        
        self.custom_module = custom_module

    def forward(self, x):

        print(self.embedding)
        
        enc = self.embedding(x["input_ids"])
        if self.debug:
            print(x, enc.size())
            
        x = self.features({"hidden_states": enc, "attention_mask": x["attention_mask"]})
        if self.debug:
            print(x.size())
        

        x = self.custom_module(x)
        if self.debug:
            print(x.size())

        x = self.remaining_layers(x)
        if self.debug:
            print(x.size())
        
        return x

In [133]:
custom_module = sparse_module(768, 96) 

insert_after_layer = 33

modified_model = ModifiedModel(pretrained_model=bert_model, 
                               custom_module=custom_module, insert_after_layer=insert_after_layer,
                                debug = True)

In [194]:
list(modified_model.pretrained_model.named_modules())[16]

('transformer.layer.0.ffn',
 FFN(
   (dropout): Dropout(p=0.1, inplace=False)
   (lin1): Linear(in_features=768, out_features=3072, bias=True)
   (lin2): Linear(in_features=3072, out_features=768, bias=True)
   (activation): GELUActivation()
 ))

In [134]:
output = modified_model(encoded_input)

Embeddings(
  (word_embeddings): Embedding(30522, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)
{'input_ids': tensor([[ 101, 5672, 2033, 2011, 2151, 3793, 2017, 1005, 1040, 2066, 1012,  102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])} torch.Size([1, 12, 768])


TypeError: 'NoneType' object is not subscriptable

In [36]:
insert_module(model.pretrained, 33, sparse_module(768, 96))

In [29]:
encoded_input["input_ids"].shape

torch.Size([1, 12])

In [30]:
encoded_input

{'input_ids': tensor([[ 101, 5672, 2033, 2011, 2151, 3793, 2017, 1005, 1040, 2066, 1012,  102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [105]:
list(encoded_input.values())

[tensor([[ 101, 5672, 2033, 2011, 2151, 3793, 2017, 1005, 1040, 2066, 1012,  102]]),
 tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])]

In [77]:
[val.shape for val in  encoded_input.values()]

[torch.Size([1, 12]), torch.Size([1, 12])]

In [81]:
model

DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0-5): 6 x TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Li