In [8]:
import math
import traceback

import torch
from torch import nn
import torch.nn.functional as F

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import traceback

In [3]:
#conv block
class Conv_Block(nn.Module):
    """Conv2d followed by an optional BatchNorm2d and a ReLU activation.

    Args:
        in_channel: number of input channels of the convolution.
        out_channel: number of output channels of the convolution.
        kernel_size: convolution kernel size.
        bn: if True, insert ``BatchNorm2d`` between the conv and the ReLU.
        padding: zero-padding for the conv. The default 0 keeps the original
            behaviour (each spatial dim shrinks by ``kernel_size - 1``).
    """
    def __init__(self, in_channel, out_channel, kernel_size = 3, bn = True, padding=0):
        super(Conv_Block, self).__init__()
        #TODO: Add negation and scale/shift part of C.ReLu part.
        layers = [nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)]
        if bn:
            layers.append(nn.BatchNorm2d(out_channel))
        # Both paths end in ReLU, matching VGG.make_layers; the bn=True path
        # used to return conv+BN with no activation at all.
        layers.append(nn.ReLU(inplace=True))

        self.layer = nn.Sequential(*layers)

    def forward(self, x):
        """Apply conv -> [BN] -> ReLU to an image batch ``x`` of shape (N, in_channel, H, W)."""
        return self.layer(x)



In [35]:
from torch.utils.model_zoo import load_url as load_state_dict_from_url

# Checkpoint URLs for the plain and batch-norm VGG16 variants; vgg16() picks
# one of these when a pretrained backbone is requested.
model_urls = dict(
    vgg16='https://download.pytorch.org/models/vgg16-397923af.pth',
    vgg16_bn='https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
)

# VGG16 layer spec consumed by VGG.make_layers: an int is the output channel
# count of a 3x3 conv, 'M' is a 2x2 / stride-2 max pool. One stage per row.
cfg = [
    64, 64, 'M',
    128, 128, 'M',
    256, 256, 256, 'M',
    512, 512, 512, 'M',
    512, 512, 512, 'M',
]

class VGG(nn.Module):
    """Convolutional VGG backbone (no classifier head) returning pooled feature maps.

    Args:
        in_channels: number of channels of the input images.
        cfg: layer spec. An int ``v`` adds a 3x3 conv (padding 1) with ``v``
            output channels, an optional BatchNorm2d and a ReLU. ``'M'`` adds a
            2x2 / stride-2 max pool.
        batch_norm: insert BatchNorm2d between every conv and its ReLU.
        init_weights: run ``_initialize_weights`` after the layers are built.
    """

    def __init__(self, in_channels, cfg, batch_norm = True, init_weights=True):
        super().__init__()
        self.features = self.make_layers(in_channels, cfg, batch_norm)
        if init_weights:
            self._initialize_weights()

    def make_layers(self, in_channels, cfg, batch_norm):
        """Translate ``cfg`` into an ``nn.Sequential`` of conv/[BN]/ReLU and pool layers."""
        modules = []
        channels = in_channels
        for spec in cfg:
            if spec == 'M':
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            modules.append(nn.Conv2d(channels, spec, kernel_size=3, padding=1))
            if batch_norm:
                modules.append(nn.BatchNorm2d(spec))
            modules.append(nn.ReLU(inplace=True))
            # The next conv consumes this conv's output channels.
            channels = spec
        return nn.Sequential(*modules)

    def _initialize_weights(self):
        """Initialise weights with Kaiming (conv), 1/0 (BatchNorm) and N(0, 0.01) (linear); biases are zeroed."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                # make_layers never adds Linear layers; kept for completeness.
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Run the stack and return the output of every max pool except the first.

        With the default VGG16 ``cfg`` (five pools) that is four maps at
        strides 4, 8, 16 and 32, ordered shallow to deep.
        """
        pooled = []
        for layer in self.features:
            x = layer(x)
            if isinstance(layer, nn.MaxPool2d):
                pooled.append(x)
        return pooled[1:]


def vgg16(in_channels, pretrained, **kwargs):
    '''
    Build the VGG16 feature extractor (conv/pool stack only, no classifier).

    Args:
        in_channels: number of input image channels.
            NOTE(review): the downloaded checkpoints presumably expect 3 input
            channels; other values will likely fail the strict load below.
        pretrained: if True, download the checkpoint from ``model_urls``, strip
            its ``classifier`` entries and load it (random init is skipped).
        **kwargs: forwarded to ``VGG`` (e.g. ``batch_norm``, ``init_weights``).

    Returns:
        VGG: the backbone module.
    '''
    if pretrained:
        # Weights are about to be overwritten by the checkpoint.
        kwargs['init_weights'] = False

    model = VGG(in_channels, cfg, **kwargs)

    if pretrained:
        # Default mirrors VGG.__init__ (batch_norm=True); indexing kwargs
        # directly raised KeyError when the caller omitted batch_norm.
        use_bn = kwargs.get('batch_norm', True)
        url = model_urls['vgg16_bn'] if use_bn else model_urls['vgg16']

        state_dict = load_state_dict_from_url(url, progress = True)
        state_dict = remove_classifier_keys(state_dict)
        model.load_state_dict(state_dict)

    return model


def remove_classifier_keys(state_dict, pattern="classifier"):
    '''
    Return a copy of ``state_dict`` without the entries whose key contains ``pattern``.

    Used by vgg16() to drop the classifier head from a checkpoint before
    loading it into the features-only VGG module. The input mapping is not
    modified. The result comes from ``state_dict.copy()``, so it keeps the
    input's mapping type and key order.

    Args:
        state_dict: mapping of parameter name -> tensor.
        pattern: substring identifying keys to drop (default ``"classifier"``).
    '''
    new_state_dict = state_dict.copy()
    # Iterate the original and mutate only the copy; deleting a key we just
    # saw cannot fail, so no error handling is needed.
    for key in state_dict:
        if pattern in key:
            del new_state_dict[key]
    return new_state_dict

In [36]:
# Pretrained batch-norm VGG16 backbone on its own, used below to inspect feature-map shapes.
model = vgg16(in_channels=3, pretrained=True, batch_norm=True)

In [115]:
#feature merging branch
class Merger(nn.Module):
    """Top-down feature-merging branch.

    Takes the four backbone maps ``x[0..3]`` (shallow to deep). Starting from
    ``x[3]``, it repeatedly resizes the running map to the next shallower map,
    concatenates the two along channels, and applies a 1x1 conv and a 3x3 conv
    (each followed by BatchNorm + ReLU). A final 3x3 conv + BN + ReLU yields a
    32-channel map at the spatial size of ``x[0]``.

    The conv input widths (1024, 384, 192) require
    ch(x[3]) + ch(x[2]) = 1024, 128 + ch(x[1]) = 384 and 64 + ch(x[0]) = 192.
    For example, the VGG16 backbone above gives 128/256/512/512 channels.
    """

    def __init__(self):
        super(Merger, self).__init__()

        # Stage 1: fuse x[3] with x[2].
        self.conv1 = nn.Conv2d(1024, 128, 1)
        self.bn1 = nn.BatchNorm2d(128)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(128, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.relu2 = nn.ReLU()

        # Stage 2: fuse with x[1].
        self.conv3 = nn.Conv2d(384, 64, 1)
        self.bn3 = nn.BatchNorm2d(64)
        self.relu3 = nn.ReLU()
        self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.relu4 = nn.ReLU()

        # Stage 3: fuse with x[0].
        self.conv5 = nn.Conv2d(192, 32, 1)
        self.bn5 = nn.BatchNorm2d(32)
        self.relu5 = nn.ReLU()
        self.conv6 = nn.Conv2d(32, 32, 3, padding=1)
        self.bn6 = nn.BatchNorm2d(32)
        self.relu6 = nn.ReLU()

        # Final 3x3 conv on the merged map.
        self.conv7 = nn.Conv2d(32, 32, 3, padding=1)
        self.bn7 = nn.BatchNorm2d(32)
        self.relu7 = nn.ReLU()

        #init weight by He init
        self._initialize_weights()

    @staticmethod
    def _upsample_concat(deep, skip):
        """Bilinearly resize ``deep`` to ``skip``'s spatial size, then concat ``(deep, skip)`` on channels."""
        # Resizing to the skip's exact size (rather than scale_factor=2) keeps
        # torch.cat valid when stride-2 pooling floored an odd size. With
        # align_corners=True this equals 2x upsampling whenever the sizes
        # differ by exactly 2x.
        upsampled = F.interpolate(deep, size=skip.shape[2:], mode="bilinear", align_corners=True)
        return torch.cat((upsampled, skip), 1)

    def forward(self, x):
        """Merge the backbone maps ``x[0..3]`` into one 32-channel map at ``x[0]``'s resolution."""
        y = self._upsample_concat(x[3], x[2])
        y = self.relu1(self.bn1(self.conv1(y)))
        y = self.relu2(self.bn2(self.conv2(y)))

        y = self._upsample_concat(y, x[1])
        y = self.relu3(self.bn3(self.conv3(y)))
        y = self.relu4(self.bn4(self.conv4(y)))

        y = self._upsample_concat(y, x[0])
        y = self.relu5(self.bn5(self.conv5(y)))
        y = self.relu6(self.bn6(self.conv6(y)))

        y = self.relu7(self.bn7(self.conv7(y)))

        return y

    def _initialize_weights(self):
        """Kaiming-normal conv weights with zero bias; BatchNorm scale 1 and shift 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


In [116]:
class Output(nn.Module):
    """Three 1x1-conv prediction heads on a 32-channel feature map.

    Args:
        scope: multiplier applied to the sigmoid geometry head, so each of the
            four ``loc`` channels lies in [0, scope].
    """
    def __init__(self, scope=512):
        super(Output, self).__init__()
        self.conv1 = nn.Conv2d(32, 1, 1)
        self.sigmoid1 = nn.Sigmoid()
        self.conv2 = nn.Conv2d(32, 4, 1)
        self.sigmoid2 = nn.Sigmoid()
        self.conv3 = nn.Conv2d(32, 1, 1)
        self.sigmoid3 = nn.Sigmoid()
        self.scope = scope
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal weights and zero bias for every conv head."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return ``(score, loc, angle)`` for a feature map ``x`` with 32 channels.

        score: 1 channel in [0, 1].
        loc:   4 channels in [0, scope].
        angle: 1 channel in [-pi/2, pi/2] (sigmoid shifted by 0.5, scaled by pi).
        """
        score = self.sigmoid1(self.conv1(x))
        loc   = self.sigmoid2(self.conv2(x)) * self.scope
        # Requires `import math` at notebook level; it was missing before.
        angle = (self.sigmoid3(self.conv3(x)) - 0.5) * math.pi
        return score, loc, angle

In [117]:

class EAST(nn.Module):
    """EAST model: VGG16 backbone -> Merger -> Output heads.

    Args:
        in_channels: input image channels, passed to ``vgg16``.
        pretrained: whether ``vgg16`` loads downloaded backbone weights.
        **kwargs: forwarded to ``vgg16`` (and on to ``VGG``), e.g. ``batch_norm``.
    """
    def __init__(self, in_channels, pretrained=True, **kwargs):
        super().__init__()
        # Submodules are created in this order; keep it so the random init is reproducible.
        self.feature_extractor = vgg16(in_channels, pretrained, **kwargs)
        self.merger = Merger()
        self.output = Output()

    def forward(self, x):
        """Return the ``(score, loc, angle)`` tuple produced by the Output heads for images ``x``."""
        return self.output(self.merger(self.feature_extractor(x)))

In [44]:
test_input = torch.randn(1, 3, 512, 512)
# Shape check only: eval() stops BatchNorm from updating the pretrained running
# statistics with random data, and no_grad() skips building an autograd graph.
model.eval()
with torch.no_grad():
    out = model(test_input)
for feature_map in out:
    print(feature_map.shape)

torch.Size([1, 128, 128, 128])
torch.Size([1, 256, 64, 64])
torch.Size([1, 512, 32, 32])
torch.Size([1, 512, 16, 16])


[None, None, None, None]

In [118]:
# Full model with a pretrained batch-norm VGG16 backbone.
east = EAST(in_channels=3, pretrained=True, batch_norm=True)

In [119]:
# Forward pass for a shape check; no_grad() avoids building an autograd graph.
# NOTE(review): `east` is still in training mode, so BatchNorm running stats
# are updated by this random input. Call east.eval() first if that matters.
with torch.no_grad():
    out = east(test_input)

In [120]:
# score, loc and angle map shapes
for head_output in out:
    print(head_output.shape)

torch.Size([1, 1, 128, 128])
torch.Size([1, 4, 128, 128])
torch.Size([1, 1, 128, 128])


[None, None, None]

In [95]:
# Shape of the angle map (third output of `east`). This cell used to read
# `p[2]`, but `p` is only assigned further down the notebook.
out[2].shape

torch.Size([1, 1, 128, 128])

In [46]:
# Reproduce the Merger's first fusion step by hand: upsample the deepest
# backbone map 2x and concatenate it with the next-deepest one.
# `x` is not defined anywhere at notebook top level, so it is rebuilt here
# from `model` and `test_input`, making the cell re-runnable on a fresh kernel.
model.eval()
with torch.no_grad():
    vgg_features = model(test_input)
    x = F.interpolate(vgg_features[3], scale_factor=2, mode="bilinear", align_corners=True)
    p = torch.cat([vgg_features[2], x], 1)

In [47]:
# Channels of the hand-built stage-1 concatenation (should match Merger.conv1's 1024 inputs).
p.size()

torch.Size([1, 1024, 32, 32])