In [1]:
import torch
import torch.nn as nn
from torchvision.models import resnet152

In [26]:
# Instantiate torchvision's ResNet-152 with pretrained=False (no pretrained weights requested).
model = resnet152(pretrained=False)



In [27]:
def _deconv_block(in_channels, kernel_size, stride, padding):
        return nn.Sequential(
                    nn.ConvTranspose2d(in_channels, in_channels,kernel_size,
                                       stride, padding),
                    nn.BatchNorm2d(in_channels),
                    nn.ReLU(),
                    nn.Conv2d(in_channels, in_channels//2, kernel_size=3, stride=1, padding=1),
                    nn.BatchNorm2d(in_channels//2),
                    nn.ReLU(),
                    nn.Conv2d(in_channels//2, in_channels//2, kernel_size=3, stride=1, padding=1),
                    nn.BatchNorm2d(in_channels//2),
                    nn.ReLU(),
                    nn.Conv2d(in_channels//2, in_channels//2, kernel_size=3, stride=1, padding=1),
                    nn.BatchNorm2d(in_channels//2),
                    nn.ReLU(),
                    nn.Conv2d(in_channels//2, in_channels//4, kernel_size=3, stride=1, padding=1),
                    nn.BatchNorm2d(in_channels//4),
                    nn.ReLU()
                    )

In [28]:
import torch.nn as nn

def flatten_model(model):
    """Split a model into top / mid / pooling parts and build a deconv head.

    The direct children of ``model`` are flattened into a single list: the
    first four children as they are, then the sub-modules of children 4..7,
    then the last two children. That list is cut at fixed indices.

    Args:
        model: An ``nn.Module`` with at least eight direct children whose
            flattened list has at least 55 entries.
            NOTE(review): the cut points look tuned to torchvision's
            ResNet-152 layout -- confirm before passing another backbone.

    Returns:
        A 4-tuple ``(resnet_top, resnet_mid, avg_pool2d, deconv)``:
        ``resnet_top`` is an ``nn.Sequential`` of flattened entries 0..37,
        ``resnet_mid`` is an ``nn.ModuleList`` of entries 38..53,
        ``avg_pool2d`` is entry 54, and ``deconv`` is a newly constructed
        ``nn.Sequential`` of four ``_deconv_block`` stages.

    Raises:
        ValueError: If ``model`` is not an ``nn.Module``.
        IndexError: If ``model`` has too few children for the fixed indices.
    """
    if not isinstance(model, nn.Module):
        raise ValueError("The provided model must be a PyTorch model (nn.Module).")

    # Cut points into the flattened module list.
    top_end = 38   # entries [0, 38) form the sequential "top"
    mid_end = 54   # entries [38, 54) form the "mid" list; entry 54 is the pool

    # Materialise the children once instead of rebuilding the list per access.
    children = list(model.children())

    flattened = children[:4]
    for i in range(4, 8):
        # Children 4..7 are containers; splice in their sub-modules.
        flattened += list(children[i].children())
    flattened += children[-2:]

    resnet_top = nn.Sequential(*flattened[:top_end])
    resnet_mid = nn.ModuleList(flattened[top_end:mid_end])
    avg_pool2d = flattened[mid_end]

    # Each stage outputs in_channels // 4 channels, which is the next stage's
    # in_channels (256 -> 64 -> 16 -> 4).
    deconv = nn.Sequential(
        _deconv_block(in_channels=256, kernel_size=3, stride=2, padding=1),
        _deconv_block(in_channels=64, kernel_size=3, stride=2, padding=[2, 1]),
        _deconv_block(in_channels=16, kernel_size=3, stride=2, padding=[2, 1]),
        _deconv_block(in_channels=4, kernel_size=[3, 4], stride=1, padding=2)
    )

    return resnet_top, resnet_mid, avg_pool2d, deconv



In [29]:
# Smoke-test flatten_model on the ResNet-152 instance.
# NOTE(review): this rebinds `model` from an nn.Module to the returned 4-tuple,
# so re-running this cell without first rebuilding the ResNet makes
# flatten_model raise its ValueError.
model = flatten_model(model)

In [30]:
# Number of components in the tuple returned by flatten_model.
len(model)

4

In [31]:
# Display the repr of every component in the tuple.
model

(Sequential(
   (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
   (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (2): ReLU(inplace=True)
   (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
   (4): Bottleneck(
     (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
     (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu): ReLU(inplace=True)
     (downsample): Sequential(
       (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
       (1): BatchNorm2d(

In [33]:
# AuxConv block definition. Intended usage inside the full model:
# aux_1024 = [AuxConv(in_channels=1024, c_tag=16, p=p, downsample=True) for _ in range(13)]
# aux_2048 = [AuxConv(in_channels=2048, c_tag=16, p=p) for _ in range(3)]
# self.aux_modules = nn.ModuleList(aux_1024 + aux_2048)

class AuxConv(nn.Module):
    """Auxiliary branch: a (3, 1) conv then a (1, 3) conv, optionally downsampled.

    Each convolution is unpadded, uses stride 1 and is followed by ``ReLU`` and
    ``Dropout(p)``. With ``downsample`` truthy, a 1x1 convolution with stride 2
    is appended after that stack.
    """

    def __init__(self, in_channels, c_tag, stride=1, p=0, downsample=False):
        """Assemble the branch.

        Args:
            in_channels: Channels of the incoming feature map.
            c_tag: Channels produced by every convolution in the branch.
            stride: Accepted but not used by this implementation.
            p: Dropout probability applied after each activation.
            downsample: When truthy, append the stride-2 1x1 convolution.
                The attribute ``self.downsample`` holds this flag when it is
                falsy and the downsampling ``nn.Sequential`` otherwise.
        """
        super().__init__()
        self.downsample = downsample

        # Two asymmetric convolutions, each with its activation and dropout.
        stack = []
        for conv_in, kernel in ((in_channels, (3, 1)), (c_tag, (1, 3))):
            stack += [
                nn.Conv2d(conv_in, c_tag, kernel_size=kernel, stride=1),
                nn.ReLU(),
                nn.Dropout(p),
            ]
        self.block = nn.Sequential(*stack)

        if downsample:
            reducer = nn.Sequential(
                nn.Conv2d(c_tag, c_tag, kernel_size=1, stride=2),
            )
            # The reducer is registered under `downsample` and is also the
            # second stage of `block`, so it appears under both names.
            self.downsample = reducer
            self.block = nn.Sequential(self.block, reducer)

    def forward(self, x):
        """Run ``x`` through the branch and return the result."""
        out = self.block(x)
        return out

In [35]:
# Build the auxiliary conv branches: 13 taking 1024-channel inputs (with the
# stride-2 downsample step) and 3 taking 2048-channel inputs.
aux_1024 = [AuxConv(in_channels=1024, c_tag=16, p=0.5, downsample=True) for _ in range(13)]
aux_2048 = [AuxConv(in_channels=2048, c_tag=16, p=0.5) for _ in range(3)]
aux_modules = nn.ModuleList(aux_1024 + aux_2048)

# Append the aux modules to the 4-tuple returned by flatten_model. Slicing to
# the first four entries keeps this cell idempotent: re-running it replaces the
# previously appended aux modules instead of stacking another 16 on top.
model = model[:4] + tuple(aux_modules)


In [36]:
# Display the tuple again after the AuxConv modules were appended to it.
model

(Sequential(
   (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
   (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (2): ReLU(inplace=True)
   (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
   (4): Bottleneck(
     (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
     (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu): ReLU(inplace=True)
     (downsample): Sequential(
       (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
       (1): BatchNorm2d(

In [37]:
class DEPTH(nn.Module):
    """Depth model built around a ResNet-152 backbone.

    NOTE(review): only the backbone construction and weight loading are
    visible in this cell. The local ``resnet`` is not stored on ``self`` and
    no ``forward`` is defined, so the class looks unfinished.
    """

    def __init__(self, wts=None, freeze=True, p=0):
        """Create the backbone and optionally load saved weights into it.

        Args:
            wts: Optional checkpoint (anything ``torch.load`` accepts) holding
                a ResNet-152 state dict whose ``fc`` layer is
                ``Linear(2048, 800)``. A falsy value skips loading.
            freeze: Not used by the code in this cell.
            p: Not used by the code in this cell.
        """
        super(DEPTH, self).__init__()
        resnet = resnet152(pretrained=False)
        if wts:
            # Replace the classifier head so its shape matches the checkpoint.
            resnet.fc = nn.Linear(2048, 800)  # output layer
            # Deserialise onto the CPU so a checkpoint saved from a GPU also
            # loads on CPU-only hosts; load_state_dict copies the values into
            # the model's own parameters either way.
            # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
            state_dict = torch.load(wts, map_location="cpu")
            resnet.load_state_dict(state_dict)

In [38]:
# Check how many input features the stock ResNet-152 classifier head (`fc`) takes.
resnet = resnet152(pretrained=False)
resnet.fc.in_features



2048

In [39]:
# Linear layer with 2048 inputs and 25 * 32 = 800 outputs.
nn.Linear(2048, 25 * 32)

Linear(in_features=2048, out_features=800, bias=True)