In [424]:
import torch
import torchvision
import torch.nn as nn
import math
from torchinfo import summary

In [425]:
class backbone(nn.Module):
    """Frozen ResNet-50 trunk followed by two trainable projection convolutions.

    The trunk keeps every child of torchvision's ResNet-50 except the last
    three (in torchvision's ResNet these are ``layer4``, ``avgpool`` and
    ``fc``), so it emits a 1024-channel feature map. Its parameters are frozen.

    Query and target images share the trunk but use separate projections:
    a 1x1 convolution for the query and an unpadded 3x3 convolution for the
    target (which trims one pixel from every border of the feature map).
    Both projections output 256 channels, which are flattened into a token
    sequence of shape ``(batch, tokens, 256)``.
    """
    def __init__(self):
        super(backbone, self).__init__()
        model = torchvision.models.resnet50(pretrained=True)
        # Keep everything up to and including the third residual stage.
        self.features = nn.Sequential(*list(model.children())[:-3])
        for param in self.features.parameters():
            param.requires_grad = False
        # Freezing the weights is not enough for BatchNorm: in training mode it
        # would still normalise with batch statistics and overwrite its running
        # averages. Keep the frozen trunk in inference mode from the start.
        self.features.eval()
        self.query_conv=nn.Conv2d(1024,256, 1)
        self.target_conv=nn.Conv2d(1024,256, 3)
        self.flatten = nn.Flatten(start_dim=2,end_dim=3)

    def train(self, mode: bool = True):
        """Switch train/eval mode while always keeping the frozen trunk in eval mode.

        Returns:
            This module, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.features.eval()
        return self

    def _tokens(self, image, projection):
        """Run the trunk and ``projection`` on ``image``; return ``(batch, tokens, channels)``."""
        feature_map = projection(self.features(image))
        # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C)
        return self.flatten(feature_map).permute(0, 2, 1)

    def forward(self, query,target):
        """Embed a query batch and a target batch into token sequences.

        Args:
            query: image batch fed through the trunk and the 1x1 projection.
            target: image batch fed through the trunk and the 3x3 projection.

        Returns:
            Tuple ``(query_tokens, target_tokens)``, each of shape
            ``(batch, tokens, 256)``.
        """
        query_tokens = self._tokens(query, self.query_conv)
        target_tokens = self._tokens(target, self.target_conv)
        return query_tokens, target_tokens


In [426]:
# Plain torchvision ResNet-50 built with pretrained=True; none of the cells shown here read it again.
res50_model = torchvision.models.resnet50(pretrained=True)
# Instance of the custom backbone class defined above; a later cell calls it on example inputs.
res50_ = backbone()

In [427]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model = backbone().to(device)
# Summarise the instance that was just placed on `device` (the original cell
# summarised the unrelated `res50_` object and left `model` unused), and tell
# torchinfo explicitly which device to build its dummy inputs on.
summary(model=model,
        input_size=[(1,3,128,128),(1,3,1000,600)],
        device=device)

cpu


Layer (type:depth-idx)                        Output Shape              Param #
backbone                                      [1, 64, 256]              --
├─Sequential: 1-1                             [1, 1024, 8, 8]           --
│    └─Conv2d: 2-1                            [1, 64, 64, 64]           (9,408)
│    └─BatchNorm2d: 2-2                       [1, 64, 64, 64]           (128)
│    └─ReLU: 2-3                              [1, 64, 64, 64]           --
│    └─MaxPool2d: 2-4                         [1, 64, 32, 32]           --
│    └─Sequential: 2-5                        [1, 256, 32, 32]          --
│    │    └─Bottleneck: 3-1                   [1, 256, 32, 32]          (75,008)
│    │    └─Bottleneck: 3-2                   [1, 256, 32, 32]          (70,400)
│    │    └─Bottleneck: 3-3                   [1, 256, 32, 32]          (70,400)
│    └─Sequential: 2-6                        [1, 512, 16, 16]          --
│    │    └─Bottleneck: 3-4                   [1, 512, 16, 16]       

In [428]:
# Random stand-ins for a batch of two query images and two target images.
query_images = torch.rand(2,3,128,128)
target_images = torch.rand(2,3,1000,600)
# Keep the raw images under their own names so this cell can be re-run:
# previously the token outputs overwrote the image inputs, and a second run
# would have fed tokens back into the backbone.
# `query` / `target` still hold the backbone token sequences for later cells.
query,target= res50_(query_images,target_images)


In [429]:
# Disabled experiment: flattening away the batch dimension. Left commented out;
# the hard-coded sizes do not match the batched tensors produced above.
# query=query.reshape(64,256)
# target=target.reshape(2196,256)
# Display the shape of the query token tensor returned by the backbone.
query.shape

torch.Size([2, 64, 256])

In [430]:
# Learnable positional embedding for the query tokens: one vector per token,
# with a leading dimension of 1 so it broadcasts across the batch.
position_embedding_query = nn.Parameter(
    data=torch.rand(1, query.size(1), query.size(2)),
    requires_grad=True,
)

In [431]:
# Learnable positional embedding for the target tokens: one vector per token,
# with a leading dimension of 1 so it broadcasts across the batch.
position_embedding_target = nn.Parameter(
    data=torch.rand(1, target.size(1), target.size(2)),
    requires_grad=True,
)

In [432]:
# Add the positional embeddings to the backbone tokens (broadcast over the batch).
query_pe = torch.add(query, position_embedding_query)
target_pe = torch.add(target, position_embedding_target)

In [433]:
class MultiheadSelfAttentionBlock(nn.Module):
    """Layer-normalised multi-head attention block ("MSA block" for short).

    Despite the name, ``forward`` takes separate query, key and value
    tensors, so the block can attend from one sequence to another. A single
    ``LayerNorm`` is shared by all three inputs.

    Args:
        embedding_dim: size of the last dimension of q, k and v.
        num_heads: number of attention heads.
        attn_dropout: dropout probability applied inside the attention layer.
    """
    def __init__(self,
                 embedding_dim:int=256,
                 num_heads:int=8,
                 attn_dropout:float=0):
        super().__init__()

        # One normalisation layer, reused for q, k and v.
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

        # batch_first=True: batched inputs are laid out as (batch, sequence, embedding).
        self.multihead_attn = nn.MultiheadAttention(embed_dim=embedding_dim,
                                                    num_heads=num_heads,
                                                    dropout=attn_dropout,
                                                    batch_first=True)

    def forward(self,q,k,v):
        """Normalise q, k and v, then return the attention output.

        Args:
            q: query embeddings.
            k: key embeddings.
            v: value embeddings.

        Returns:
            The attention output; the attention weights are not computed.
        """
        q = self.layer_norm(q)
        k = self.layer_norm(k)
        v = self.layer_norm(v)
        attn_output, _ = self.multihead_attn(query=q,
                                             key=k,
                                             value=v,
                                             need_weights=False)
        return attn_output

In [434]:
# Build one attention block with the sizes used throughout this notebook.
multihead_self_attention_block = MultiheadSelfAttentionBlock(num_heads=8,
                                                             embedding_dim=256)

# Query tokens attend over the target tokens: queries and keys carry the
# positional embedding, the values are the plain target tokens.
patched_image_through_msa_block = multihead_self_attention_block(q=query_pe,
                                                                 k=target_pe,
                                                                 v=target)
print(f"Input shape of MSA block: {query.shape}")
print(f"Output shape MSA block: {patched_image_through_msa_block.shape}")

Input shape of MSA block: torch.Size([2, 64, 256])
Output shape MSA block: torch.Size([2, 64, 256])


In [435]:
# Layer-by-layer summary of the attention block for one 64 x 256 query sequence
# and 2196 x 256 key/value sequences; the sizes are given without a batch dimension.
summary(model=multihead_self_attention_block,
         input_size=[(64, 256),(2196 ,256 ),(2196,256)])

Layer (type:depth-idx)                   Output Shape              Param #
MultiheadSelfAttentionBlock              [64, 256]                 --
├─LayerNorm: 1-1                         [64, 256]                 512
├─LayerNorm: 1-2                         [2196, 256]               (recursive)
├─LayerNorm: 1-3                         [2196, 256]               (recursive)
├─MultiheadAttention: 1-4                [64, 256]                 263,168
Total params: 263,680
Trainable params: 263,680
Non-trainable params: 0
Total mult-adds (M): 2.28
Input size (MB): 4.56
Forward/backward pass size (MB): 0.13
Params size (MB): 0.00
Estimated Total Size (MB): 4.70

In [436]:
class MLPBlock(nn.Module):
    """Layer-normalised two-layer perceptron block ("MLP block" for short).

    Applies ``LayerNorm``, expands to ``mlp_size``, applies ReLU and dropout,
    projects back to ``embedding_dim`` and applies dropout again.

    Args:
        embedding_dim: size of the last dimension of the input and output.
        mlp_size: width of the hidden layer.
        dropout: dropout probability used after each linear layer.
    """
    def __init__(self,
                 embedding_dim:int=256,
                 mlp_size:int=1024,
                 dropout:float=0.1):
        super().__init__()

        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

        self.mlp = nn.Sequential(
            nn.Linear(in_features=embedding_dim,
                      out_features=mlp_size),
            # ReLU is what this block uses; the earlier inline comment quoted
            # a GELU non-linearity, which did not match the code.
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=mlp_size,
                      out_features=embedding_dim), # back to embedding_dim
            nn.Dropout(p=dropout)
        )

    def forward(self, x):
        """Return the MLP output for ``x``; the last dimension keeps its size."""
        x = self.layer_norm(x)
        x = self.mlp(x)
        return x

In [437]:
# Build an MLP block with the sizes used throughout this notebook.
mlp_block = MLPBlock(embedding_dim=256,
                     mlp_size=1024,
                     dropout=0.1)

# Pass the output of the attention block through the MLP block.
patched_image_through_mlp_block = mlp_block(patched_image_through_msa_block)
# The "input" line now reports the tensor that was fed in; it previously
# printed the output tensor's shape twice.
print(f"Input shape of MLP block: {patched_image_through_msa_block.shape}")
print(f"Output shape MLP block: {patched_image_through_mlp_block.shape}")

Input shape of MLP block: torch.Size([2, 64, 256])
Output shape MLP block: torch.Size([2, 64, 256])


In [438]:
# Layer-by-layer summary of the MLP block for a single 64 x 256 input.
summary(model=mlp_block,
         input_size=[(64, 256)])

Layer (type:depth-idx)                   Output Shape              Param #
MLPBlock                                 [64, 256]                 --
├─LayerNorm: 1-1                         [64, 256]                 512
├─Sequential: 1-2                        [64, 256]                 --
│    └─Linear: 2-1                       [64, 1024]                263,168
│    └─ReLU: 2-2                         [64, 1024]                --
│    └─Dropout: 2-3                      [64, 1024]                --
│    └─Linear: 2-4                       [64, 256]                 262,400
│    └─Dropout: 2-5                      [64, 256]                 --
Total params: 526,080
Trainable params: 526,080
Non-trainable params: 0
Total mult-adds (M): 33.67
Input size (MB): 0.07
Forward/backward pass size (MB): 0.79
Params size (MB): 2.10
Estimated Total Size (MB): 2.96

In [439]:
class TransformerEncoderBlock(nn.Module):
    """Attention block followed by an MLP block, each with a residual connection.

    The attention step takes separate query, key and value tensors; only the
    query stream is carried through the residual connections and returned.

    Args:
        embedding_dim: size of the last dimension of q, k and v.
        num_heads: number of attention heads.
        mlp_size: hidden width of the MLP block.
        mlp_dropout: dropout probability inside the MLP block.
        attn_dropout: dropout probability inside the attention layer.
    """
    def __init__(self,
                 embedding_dim:int=256,
                 num_heads:int=8,
                 mlp_size:int=1024,
                 mlp_dropout:float=0.1,
                 attn_dropout:float=0):
        super().__init__()

        self.msa_block = MultiheadSelfAttentionBlock(embedding_dim=embedding_dim,
                                                     num_heads=num_heads,
                                                     attn_dropout=attn_dropout)

        self.mlp_block =  MLPBlock(embedding_dim=embedding_dim,
                                   mlp_size=mlp_size,
                                   dropout=mlp_dropout)

    def forward(self, q,k,v):
        """Update ``q`` by attending over ``k``/``v`` and then applying the MLP.

        Returns:
            The updated query tensor, with the same shape as ``q``.
        """
        # Residual connection around the attention block.
        q =  self.msa_block(q,k,v) + q

        # Residual connection around the MLP block.
        q = self.mlp_block(q) + q

        return q

In [440]:
# Create an instance of TransformerEncoderBlock with its default hyperparameters.
transformer_encoder_block = TransformerEncoderBlock()


# Run torchinfo on the block. The call is not the last expression of the cell,
# so its return value is not displayed; only the shape below is shown.
summary(model=transformer_encoder_block,
          input_size=[(64, 256),(2196 ,256 ),(2196,256)])
# Shape of the block's output for the example query/target tokens.
transformer_encoder_block(query_pe,target_pe,target).shape

torch.Size([2, 64, 256])

In [441]:
class one_shot(nn.Module):
    """Query/target model: shared backbone plus repeated mutual attention.

    Both images are turned into token sequences by ``backbone``. Then, for
    ``num_transformer_layers`` rounds, the query tokens attend over the target
    tokens and the target tokens attend over the query tokens. One
    ``TransformerEncoderBlock`` is reused for every round and both directions.
    Learnable position embeddings are re-added before each round to the
    tensors used as attention queries and keys; values carry no position
    embedding.

    Args:
        query_size: side length of the (square) query image; fixes the number
            of query position embeddings.
        target_width: first spatial size of the target image.
        target_height: second spatial size of the target image. Only the
            product of the two derived token counts is used, so the
            width/height naming does not affect the result.
        in_channels: accepted for interface compatibility; not used.
        num_transformer_layers: number of mutual-attention rounds.
        embedding_dim: token embedding size.
        mlp_size: hidden width of the MLP block.
        num_heads: number of attention heads.
        attn_dropout: dropout probability inside the attention layer.
        mlp_dropout: dropout probability inside the MLP block.
        embedding_dropout: probability of ``self.embedding_dropout``; that
            layer is created but not applied in ``forward``.
    """
    def __init__(self,
                 query_size:int=128,
                 target_width:int=1000,
                 target_height:int=600,
                 in_channels:int=3,
                 num_transformer_layers:int=4,
                 embedding_dim:int=256,
                 mlp_size:int=1024,
                 num_heads:int=8,
                 attn_dropout:float=0,
                 mlp_dropout:float=0.1,
                 embedding_dropout:float=0.1):
        super().__init__()

        # Previously ignored: forward() hard-coded 1 + 3 rounds, which is what
        # the default of 4 reproduces.
        self.num_transformer_layers = num_transformer_layers

        self.backbone=backbone()

        # The position-embedding lengths are fixed at construction time, so
        # inputs must produce exactly this many tokens. The "-2" accounts for
        # the unpadded 3x3 target projection in the backbone.
        query_tokens = math.ceil(query_size/16)*math.ceil(query_size/16)
        target_tokens = math.ceil(target_width/16-2)*math.ceil(target_height/16-2)
        self.position_embedding_query = nn.Parameter(data=torch.randn(1, query_tokens, embedding_dim),
                                                requires_grad=True)
        self.position_embedding_target = nn.Parameter(data=torch.randn(1, target_tokens, embedding_dim),
                                                requires_grad=True)

        # NOTE(review): kept for compatibility, but forward() never applies it.
        self.embedding_dropout = nn.Dropout(p=embedding_dropout)

        # A single block whose weights are shared by all rounds and both
        # attention directions. attn_dropout is now forwarded; it used to be
        # silently dropped.
        self.transformer_encoder = TransformerEncoderBlock(embedding_dim=embedding_dim,
                                                           num_heads=num_heads,
                                                           mlp_size=mlp_size,
                                                           mlp_dropout=mlp_dropout,
                                                           attn_dropout=attn_dropout)

    def forward(self, query,target):
        """Return the refined ``(query_tokens, target_tokens)`` pair.

        Args:
            query: query image batch.
            target: target image batch.

        Returns:
            Tuple ``(q, t)`` of token tensors after the final round.
        """
        q, t = self.backbone(query,target)

        for _ in range(self.num_transformer_layers):
            q_pe = self.position_embedding_query + q
            t_pe = self.position_embedding_target + t
            # Both updates read the previous round's tokens. The earlier loop
            # overwrote q first, so the target update paired keys from the
            # previous round with values from the current one, unlike its
            # first round. The right-hand side is evaluated left to right
            # before either name is rebound.
            q, t = (self.transformer_encoder(q_pe, t_pe, t),
                    self.transformer_encoder(t_pe, q_pe, q))

        return q,t

In [444]:
# Build the full model with its default hyperparameters. This rebinds `model`,
# which an earlier cell assigned to a bare backbone instance.
model=one_shot()
# Summarise it for one 3-channel 128x128 query and one 3-channel 1000x600 target.
summary(model=model,
         input_size=[(1,3,128,128),(1,3,1000,600)])

Layer (type:depth-idx)                             Output Shape              Param #
one_shot                                           [1, 64, 256]              578,560
├─backbone: 1-1                                    [1, 64, 256]              --
│    └─Sequential: 2-1                             [1, 1024, 8, 8]           --
│    │    └─Conv2d: 3-1                            [1, 64, 64, 64]           (9,408)
│    │    └─BatchNorm2d: 3-2                       [1, 64, 64, 64]           (128)
│    │    └─ReLU: 3-3                              [1, 64, 64, 64]           --
│    │    └─MaxPool2d: 3-4                         [1, 64, 32, 32]           --
│    │    └─Sequential: 3-5                        [1, 256, 32, 32]          (215,808)
│    │    └─Sequential: 3-6                        [1, 512, 16, 16]          (1,219,584)
│    │    └─Sequential: 3-7                        [1, 1024, 8, 8]           (7,098,368)
│    └─Conv2d: 2-2                                 [1, 256, 8, 8]            