Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/DIVA-DIA/DeepDIVA-Private
Browse files Browse the repository at this point in the history
…into develop
  • Loading branch information
Renthal committed Jun 5, 2019
2 parents bffb34c + 20f241c commit 8ed3fbb
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 77 deletions.
1 change: 0 additions & 1 deletion models/semantic_segmentation/SegNet.py
Expand Up @@ -31,7 +31,6 @@ def segnet(output_channels=8, resume=None, **kwargs):
class SegNet(nn.Module):
def __init__(self, output_channels, pretrained=False, **kwargs):
super(SegNet, self).__init__()
# TODO: make different functions for different VGG models
vgg = vgg19_bn(pretrained=pretrained, **kwargs)

num_classes = output_channels
Expand Down
92 changes: 16 additions & 76 deletions models/semantic_segmentation/deeplabv3_resnet.py
Expand Up @@ -5,20 +5,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch.utils.model_zoo as model_zoo

import os
import logging

model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}

from models.image_classification.ResNet import resnet18, resnet34, resnet50, resnet101, resnet152

def make_layer(block, in_channels, channels, num_blocks, stride=1, dilation=1):
strides = [stride] + [1]*(num_blocks - 1) # (stride == 2, num_blocks == 4 --> strides == [2, 1, 1, 1])
Expand Down Expand Up @@ -109,37 +96,18 @@ def __init__(self, num_layers, pretrained=False):
super(ResNet_Bottleneck_OS16, self).__init__()

if num_layers == 50:
resnet = models.resnet50()

if pretrained:
# load pretrained model:
resnet.load_state_dict(torch.load("../pretrained_models/resnet/resnet50-19c8e357.pth"))
# remove fully connected layer, avg pool and layer5:
self.resnet = nn.Sequential(*list(resnet.children())[:-3])

print ("pretrained resnet, 50")
resnet = resnet50(pretrained)

elif num_layers == 101:
resnet = models.resnet101()
if pretrained:
# load pretrained model:
resnet.load_state_dict(torch.load("../pretrained_models/resnet/resnet101-5d3b4d8f.pth"))
# remove fully connected layer, avg pool and layer5:
self.resnet = nn.Sequential(*list(resnet.children())[:-3])

print ("pretrained resnet, 101")
resnet = resnet101(pretrained)

elif num_layers == 152:
resnet = models.resnet152()
if pretrained:
# load pretrained model:
resnet.load_state_dict(torch.load("../pretrained_models/resnet/resnet152-b121ed2d.pth"))
# remove fully connected layer, avg pool and layer5:
self.resnet = nn.Sequential(*list(resnet.children())[:-3])

print ("pretrained resnet, 152")
resnet = resnet152(pretrained)
else:
raise Exception("num_layers must be in {50, 101, 152}!")

# remove fully connected layer, avg pool and layer5:
self.resnet = nn.Sequential(*list(resnet.children())[:-3])
self.layer5 = make_layer(Bottleneck, in_channels=4*256, channels=512, num_blocks=3, stride=1, dilation=2)

def forward(self, x):
Expand All @@ -158,34 +126,19 @@ def __init__(self, num_layers, pretrained=False):
super(ResNet_BasicBlock_OS16, self).__init__()

if num_layers == 18:
resnet = models.resnet18()
resnet = resnet18(pretrained)
num_blocks = 2

if pretrained:
try:
resnet.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False)
except Exception as exp:
logging.warning(exp)

# remove fully connected layer, avg pool and layer5:
self.resnet = nn.Sequential(*list(resnet.children())[:-3])

elif num_layers == 34:
resnet = models.resnet34()

if pretrained:
try:
resnet.load_state_dict(model_zoo.load_url(model_urls['resnet34']), strict=False)
except Exception as exp:
logging.warning(exp)
resnet = resnet34(pretrained)

# remove fully connected layer, avg pool and layer5:
self.resnet = nn.Sequential(*list(resnet.children())[:-3])
num_blocks = 3

else:
raise Exception("num_layers must be in {18, 34}!")

# remove fully connected layer, avg pool and layer5:
self.resnet = nn.Sequential(*list(resnet.children())[:-3])
self.layer5 = make_layer(BasicBlock, in_channels=256, channels=512, num_blocks=num_blocks, stride=1, dilation=2)

def forward(self, x):
Expand All @@ -198,39 +151,26 @@ def forward(self, x):

return output


class ResNet_BasicBlock_OS8(nn.Module):
def __init__(self, num_layers, pretrained=False):
super(ResNet_BasicBlock_OS8, self).__init__()

if num_layers == 18:
resnet = models.resnet18()
resnet = resnet18(pretrained)
num_blocks_layer_4 = 2
num_blocks_layer_5 = 2

if pretrained:
try:
resnet.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False)
except Exception as exp:
logging.warning(exp)

# remove fully connected layer, avg pool, layer4 and layer5:
self.resnet = nn.Sequential(*list(resnet.children())[:-4])

elif num_layers == 34:
resnet = models.resnet34()
resnet = resnet34(pretrained)
num_blocks_layer_4 = 6
num_blocks_layer_5 = 3

if pretrained:
try:
resnet.load_state_dict(model_zoo.load_url(model_urls['resnet34']), strict=False)
except Exception as exp:
logging.warning(exp)

else:
raise Exception("num_layers must be in {18, 34}!")

# remove fully connected layer, avg pool, layer4 and layer5:
self.resnet = nn.Sequential(*list(resnet.children())[:-4])

self.layer4 = make_layer(BasicBlock, in_channels=128, channels=256, num_blocks=num_blocks_layer_4, stride=1, dilation=2)
self.layer5 = make_layer(BasicBlock, in_channels=256, channels=512, num_blocks=num_blocks_layer_5, stride=1, dilation=4)

Expand Down

0 comments on commit 8ed3fbb

Please sign in to comment.