In [None]:
# default_exp model

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
# export 
import torch
import torch.nn as nn
import torchvision
import copy
from .utils import get_model

class HANet(nn.Module):
    class FCN_model(nn.Module):
        def __init__(self, n_classes=4):
            super(HANet.HANet_model, self).__init__()
            self.color_trunk = torchvision.models.resnet101(pretrained=True)
            del self.color_trunk.fc, self.color_trunk.avgpool, self.color_trunk.layer4
            self.depth_trunk = copy.deepcopy(self.color_trunk)
            self.conv1 = nn.Conv2d(2048, 512, 1)
            self.conv2 = nn.Conv2d(512, 128, 1)
            self.conv3 = nn.Conv2d(128, n_classes, 1)
        def forward(self, color, depth):
            # Color
            color_feat_1 = self.color_trunk.conv1(color) # 3 -> 64
            color_feat_1 = self.color_trunk.bn1(color_feat_1)
            color_feat_1 = self.color_trunk.relu(color_feat_1)
            color_feat_1 = self.color_trunk.maxpool(color_feat_1) 
            color_feat_2 = self.color_trunk.layer1(color_feat_1) # 64 -> 256
            color_feat_3 = self.color_trunk.layer2(color_feat_2) # 256 -> 512
            color_feat_4 = self.color_trunk.layer3(color_feat_3) # 512 -> 1024
            # Depth
            depth_feat_1 = self.depth_trunk.conv1(depth) # 3 -> 64
            depth_feat_1 = self.depth_trunk.bn1(depth_feat_1)
            depth_feat_1 = self.depth_trunk.relu(depth_feat_1)
            depth_feat_1 = self.depth_trunk.maxpool(depth_feat_1) 
            depth_feat_2 = self.depth_trunk.layer1(depth_feat_1) # 64 -> 256
            depth_feat_3 = self.depth_trunk.layer2(depth_feat_2) # 256 -> 512
            depth_feat_4 = self.depth_trunk.layer3(depth_feat_3) # 512 -> 1024
            # Concatenate
            feat = torch.cat([color_feat_4, depth_feat_4], dim=1) # 2048
            feat_1 = self.conv1(feat)
            feat_2 = self.conv2(feat_1)
            feat_3 = self.conv3(feat_2)
            return nn.Upsample(scale_factor=2, mode="bilinear")(feat_3)
        
    def __init__(self, pretrained=False, n_class=4):
        super(HANet, self).__init__()
        if pretrained == True:
            self.net = self.FCN_model(4)
            model_path = get_model()
            self.net.load_state_dict(torch.load(model_path))
            print('Load pretrained complete')
        else:
            self.net = self.HANet_model(n_classes=n_class)
            
    def forward(self, Color, Depth):
        output = self.net(Color, Depth)
        
        return output
            

# Load Model

In [8]:
model = HANet()

# Show model structure

In [9]:
print(model)

HANet(
  (color_trunk): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
        