In [6]:
import torch
import torch.nn as nn
from torch.utils.data import dataset
from typing import List, Tuple
import numpy as np
# from torch import variable 

In [7]:
# Pick the best available accelerator instead of assuming Apple MPS, so the
# notebook also runs on CUDA or CPU-only machines. Preference: mps > cuda > cpu.
_mps_backend = getattr(torch.backends, "mps", None)  # absent on older torch builds
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [40]:
class ConvBN(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU6 stacked in a single ``nn.Sequential``.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution (and normalised by BN).
        kernel_size: convolution kernel size.
        stride: convolution stride.
        bais: (sic) whether the convolution has a bias term; off by default.
        padding: zero padding applied by the convolution.
    """
    def __init__(self,in_channels,out_channels, kernel_size, stride, bais=False,padding=0):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                         stride=stride, padding=padding, bias=bais)
        norm = nn.BatchNorm2d(out_channels)
        act = nn.ReLU6(inplace=True)
        # Sub-modules are exposed as model[0], model[1], model[2].
        self.model = nn.Sequential(conv, norm, act)

    def forward(self,x):
        """Run ``x`` through conv, batch-norm and ReLU6 in that order."""
        out = self.model(x)
        return out
    
class Inverted_Residual(nn.Module):
    """Inverted residual block: 1x1 expand -> depthwise conv -> 1x1 project.

    When ``stride == 1`` the block adds a shortcut from its input to the
    projected output. The shortcut is the identity if ``in_channels ==
    out_channels`` and otherwise a 1x1 convolution that is created once here,
    so it is registered as a parameter, trained with the rest of the block and
    gives the same result on every forward pass.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        expansion_factor: multiplier for the hidden (expanded) channel count.
        kernel_size: kernel size of the depthwise convolution.
        stride: stride of the depthwise convolution; must be 1 or 2.

    Raises:
        ValueError: if ``stride`` is neither 1 nor 2.
    """
    def __init__(self, in_channels, out_channels, expansion_factor=6, kernel_size=3, stride=2 ):
        super().__init__()
        if stride != 1 and stride != 2:
            raise ValueError("Stride should be 1 or 2")

        # Callers may hand in numpy integers (e.g. rows of a config array);
        # normalise them to plain Python ints before building layers.
        in_channels = int(in_channels)
        out_channels = int(out_channels)
        stride = int(stride)
        hidden_channels = in_channels * int(expansion_factor)

        self.inchannels = in_channels
        self.outchannels = out_channels

        self.convop = nn.Sequential(
            # 1x1 expansion
            nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU6(inplace=True),

            # depthwise conv: one group per channel
            nn.Conv2d(hidden_channels, hidden_channels,
                      kernel_size, stride, padding=1,
                      groups=hidden_channels, bias=False),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU6(inplace=True),

            # 1x1 linear projection (no activation afterwards)
            nn.Conv2d(hidden_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels)).to(device)

        self.is_residual = stride == 1
        if not self.is_residual:
            self.shortcut = None
        elif in_channels == out_channels:
            # Shapes already match: plain skip connection.
            self.shortcut = nn.Identity()
        else:
            # Channel counts differ: learned 1x1 projection, built once.
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1).to(device)

    def forward(self, x):
        """Return ``convop(x)``, plus the shortcut of ``x`` when ``stride == 1``."""
        out = self.convop(x)
        if self.is_residual:
            return self.shortcut(x) + out
        return out
        
class MobileNetv2(nn.Module):
    """MobileNetV2-style backbone assembled from ``ConvBN`` and ``Inverted_Residual``.

    ``structure`` holds one row per stage with columns:
        t - expansion factor passed to every block of the stage
        c - output channels of the stage
        n - number of blocks in the stage
        s - stride of the first block of the stage (remaining blocks use 1)

    The network is: stem ``ConvBN`` (3 -> 32, stride 2), all stages in order,
    then a 1x1 ``ConvBN`` to 1280 channels. The assembled ``nn.Sequential`` is
    moved to the module-level ``device``.
    """
                        #  t, c, n, s
    structure = np.array([[1,16 ,1,1],
                          [6,24 ,2,2],
                          [6,32 ,3,2],
                          [6,64 ,4,2],
                          [6,96 ,3,1],
                          [6,160,3,2],
                          [6,320,1,1]
                          ])
    def __init__(self):
        super().__init__()
        stem_channels = 32
        # padding=1 makes the stride-2 3x3 stem halve even input sizes exactly
        # (e.g. 640 -> 320), matching the reference MobileNetV2 stem.
        layer = [ConvBN(3, stem_channels, 3, 2, padding=1)]

        in_ch = stem_channels
        # .tolist() yields plain Python ints, so layers never see numpy scalars.
        for t, c, n, s in self.structure.tolist():
            for block_idx in range(n):
                layer.append(
                    Inverted_Residual(in_channels=in_ch,
                                      out_channels=c,
                                      expansion_factor=t,
                                      # only the first block of a stage may downsample
                                      stride=s if block_idx == 0 else 1)
                )
                in_ch = c

        # Final 1x1 conv; its input width follows the last stage of `structure`.
        layer.append(ConvBN(in_ch, 1280, 1, 1))
        self.model = nn.Sequential(*layer).to(device)

    def forward(self,x):
        """Run ``x`` through the stem, all stages and the final 1x1 conv."""
        return self.model(x)

class Features(nn.Module):
    """Capture the outputs of selected sub-modules of ``model`` with forward hooks.

    Args:
        model: the module to run and to attach hooks to.
        layers: names of sub-modules, as reported by ``model.named_modules()``,
            whose outputs should be recorded.

    Raises:
        KeyError: if any entry of ``layers`` is not a sub-module name of
            ``model``. No hook is attached in that case.
    """
    def __init__(self, model, layers:list):
        super().__init__()
        self.model = model
        self.layers = layers
        self.feature = {}
        self.hooks = []

        # Build the name -> module lookup once rather than once per layer.
        modules = dict(model.named_modules())
        # Validate every name before registering anything so a bad name
        # cannot leave a partial set of hooks attached to the model.
        missing = [name for name in layers if name not in modules]
        if missing:
            raise KeyError(f"Unknown layer name(s): {missing}")

        for layer in layers:
            hook = modules[layer].register_forward_hook(self.get_features(layer))
            self.hooks.append(hook)

    def get_features(self,layer_name):
        """Return a forward hook that stores the module output under ``layer_name``."""
        def hook (model, input, output):
            self.feature[layer_name] = output
        return hook

    def extract(self,x):
        """Run ``model(x)`` and return the captured outputs keyed by layer name.

        Keys appear in the order the hooked modules executed. A new dict is
        returned on every call, so results from earlier calls are not
        overwritten by later ones.
        """
        self.feature = {}  # drop anything captured by a previous forward pass
        _ = self.model(x)
        return dict(self.feature)

    def remove(self):
        """Detach all hooks from the model; safe to call more than once."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()
            
class FPNetwork(nn.Module):
    """Top-down feature pyramid: 1x1 lateral convs merged from deepest to shallowest.

    Args:
        in_channels: maps each feature key to its channel count; one 1x1
            lateral convolution is created per key. Defaults to
            ``{'3': 24, '6': 32, '13': 96, '18': 1280}``.
        out_channels: channel count of every pyramid output.
    """
    def __init__(self,
                 in_channels:dict = None,
                 out_channels:int = 256):
        super().__init__()
        if in_channels is None:
            # Built per instance so no dict object is shared between instances.
            in_channels = {'3':24,
                           '6':32,
                           '13':96,
                           '18':1280}
        # features should be in ascending order: ex [3,6,13,18] 3 stands for 3rd layer etc
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest').to(device)
        self.outchannels = out_channels
        self.convs = nn.ModuleDict()
        for key, channels in in_channels.items():
            self.convs[key] = nn.Conv2d(in_channels=channels,
                                        out_channels=self.outchannels,
                                        kernel_size=1).to(device)

    def forward(self, features:dict):
        """Merge ``features`` top-down and return one tensor per key.

        ``features`` is walked from its last key to its first. The last entry
        only gets its lateral conv; every earlier entry is its lateral conv
        plus the previous (deeper) output resized to its spatial size. The
        returned dict is ordered deepest-first.

        All state is kept in locals so no activations (or their autograd
        graph) stay referenced by the module between calls.
        """
        output = {}
        deeper = None
        for key in reversed(list(features.keys())):
            merged = self.convs[key](features[key])
            if deeper is not None:
                # Resize to the lateral's exact size. For an exact 2x ratio this
                # equals `self.upsample`; it also works for other ratios.
                merged = merged + nn.functional.interpolate(
                    deeper,
                    size=merged.shape[-2:],
                    mode=self.upsample.mode,
                    align_corners=self.upsample.align_corners)
            output[key] = merged
            deeper = merged

        return output
        

class classificationhead(nn.Module):
    """Per-anchor classification head: four ConvBN layers, then a plain 3x3 conv.

    The final layer is a bare ``nn.Conv2d`` (no batch-norm, no ReLU6) so the
    logits fed to the sigmoid are unbounded. With a ReLU6 in front of the
    sigmoid the logits would be clamped to [0, 6] and no score could ever
    fall below 0.5.

    Args:
        channels: channel count of the input feature map and of the tower.
        num_anchors: number of anchors predicted per spatial location.
        num_of_classes: number of class scores per anchor.
    """
    def __init__(self, channels, num_anchors, num_of_classes):
        super().__init__()
        self.channels = channels
        self.anchors = num_anchors
        self.num_of_classes = num_of_classes
        self.sigmoid = nn.Sigmoid()

        tower = [
            ConvBN(in_channels=channels, out_channels=channels, kernel_size=3, stride=1, padding=1)
            for _ in range(4)
        ]
        self.model = nn.Sequential(
            *tower,
            # raw logits: one channel per (anchor, class) pair
            nn.Conv2d(in_channels=self.channels,
                      out_channels=self.anchors * self.num_of_classes,
                      kernel_size=3, stride=1, padding=1),
        ).to(device)  # single move covers every layer above

    def forward(self,x):
        """Return sigmoid scores of shape (batch, anchors * H * W, num_of_classes)."""
        x = self.model(x)
        batch, _, height, width = x.shape
        # channels are laid out as (anchor, class); split them apart
        x = x.view(batch, self.anchors, self.num_of_classes, height, width)
        # move the class axis last: (batch, anchors, H, W, classes)
        x = x.permute(0, 1, 3, 4, 2).contiguous()
        x = x.reshape((batch, self.anchors * height * width, self.num_of_classes))
        return self.sigmoid(x)
    
class bboxhead(nn.Module):
    """Per-anchor box regression head: four ConvBN layers, then a plain 3x3 conv.

    The final layer is a bare ``nn.Conv2d`` (no batch-norm, no ReLU6) so the
    four regression values per anchor are unbounded. A trailing ReLU6 would
    clamp them to [0, 6] and make negative offsets impossible.

    Args:
        channels: channel count of the input feature map and of the tower.
        num_anchors: number of anchors predicted per spatial location.
    """
    def __init__(self, channels, num_anchors):
        super().__init__()
        self.channels = channels
        self.anchors = num_anchors

        tower = [
            ConvBN(in_channels=channels, out_channels=channels, kernel_size=3, stride=1, padding=1)
            for _ in range(4)
        ]
        self.model = nn.Sequential(
            *tower,
            # raw regression outputs: 4 values per anchor
            nn.Conv2d(in_channels=self.channels,
                      out_channels=self.anchors * 4,
                      kernel_size=3, stride=1, padding=1),
        ).to(device)  # single move covers every layer above

    def forward(self,x):
        """Return box values of shape (batch, anchors * H * W, 4)."""
        x = self.model(x)
        batch, _, height, width = x.shape
        # channels are laid out as (anchor, coordinate); split them apart
        x = x.view(batch, self.anchors, 4, height, width)
        # move the coordinate axis last: (batch, anchors, H, W, 4)
        x = x.permute(0, 1, 3, 4, 2).contiguous()
        return x.reshape((batch, self.anchors * height * width, 4))

In [9]:
# Instantiate the from-scratch MobileNetv2 defined above. Its constructor
# already moves the layers to `device`; the extra .to(device) is a no-op safeguard.
modelv2 = MobileNetv2().to(device)
# modelv2.eval()

In [10]:
# Load torchvision's pretrained MobileNetV2 through torch.hub (pinned to v0.10.0);
# the hub cache is reused when present, as the output below shows.
model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
# Keep only the `features` sub-module and move it to `device`.
# Note: this rebinds `model`, so the full hub model is no longer reachable afterwards.
model = model.features.to(device)

Using cache found in /Users/vipulagarwal/.cache/torch/hub/pytorch_vision_v0.10.0


In [11]:
# Show every sub-module of the backbone keyed by name; these names
# (e.g. '3', '6', '13', '18') are what Features looks up when attaching hooks.
dict(model.named_modules())

{'': Sequential(
   (0): ConvBNActivation(
     (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
     (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (2): ReLU6(inplace=True)
   )
   (1): InvertedResidual(
     (conv): Sequential(
       (0): ConvBNActivation(
         (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
         (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
         (2): ReLU6(inplace=True)
       )
       (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
       (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     )
   )
   (2): InvertedResidual(
     (conv): Sequential(
       (0): ConvBNActivation(
         (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
         (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=T

In [41]:
# Random stand-in for one 640x640 RGB image, created directly on `device`.
p=torch.rand(1,3,640,640,device=device)

# Backbone layers to tap, listed shallow -> deep for readability (the order of
# the captured dict is set by execution order of the hooks, not by this list).
extractor = Features(model,['3','6','13','18'])
try:
    features = extractor.extract(p)
finally:
    # Detach the hooks again: otherwise every re-run of this cell would stack
    # another set of forward hooks onto `model`.
    extractor.remove()
topdown = FPNetwork(out_channels=256)
newfeatures = topdown(features)
classifier = classificationhead(channels=256, num_anchors= 12, num_of_classes= 1)
bboxregression = bboxhead(channels= 256 , num_anchors= 12)
# output = {}
# for key in list(newfeatures.keys()):
#     temp = {}
#     temp["bbox"] = bboxregression(newfeatures[key])
#     temp["cls"] = classifier(newfeatures[key])
#     output[key] = temp

In [42]:
# Show the classification head's sub-modules keyed by name.
# NOTE(review): the saved output below lists conv1..conv4 attributes, which the
# current class definition does not have — it looks stale; re-run to refresh.
dict(classifier.named_modules())

{'': classificationhead(
   (sigmoid): Sigmoid()
   (conv1): ConvBN(
     (model): Sequential(
       (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (2): ReLU6(inplace=True)
     )
   )
   (conv2): ConvBN(
     (model): Sequential(
       (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (2): ReLU6(inplace=True)
     )
   )
   (conv3): ConvBN(
     (model): Sequential(
       (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (2): ReLU6(inplace=True)
     )
   )
   (conv4): ConvBN(
     (model): Sequential(
       (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), 

In [28]:
# Check the shape of the pyramid output stored under key '18'.
newfeatures['18'].shape

torch.Size([1, 256, 20, 20])

In [346]:
# Scratch check: insert a length-1 axis in the middle of a random (3, 4) tensor
# and expand it to (3, 5, 4); expand() returns a view, no data is copied.
a = torch.rand(3, 4)[:, None, :].expand(-1, 5, -1)


In [17]:
# Scratch check: count down over the indices 3, 2, 1, 0.
for i in reversed(range(4)):
    print(i)

3
2
1
0
