In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch.optim as optim
import os

In [2]:
class regr_loss(nn.Module):
    """Smooth-L1 regression loss evaluated on positive anchors only.

    Only the first batch element is used (``target[0]`` / ``inp[0]``).
    Column 0 of ``target[0]`` holds the anchor label and the remaining
    columns hold the regression targets; rows whose label equals 1 are
    the ones that contribute to the loss.

    Args:
        device: device the returned loss tensor is moved to.
    """

    def __init__(self, device):
        super().__init__()
        self.device = device

    def forward(self, inp, target):
        """Return the mean (over positive anchors) of the summed smooth-L1 error.

        Args:
            inp: predictions; ``inp[0]`` is indexed by anchor along dim 0.
            target: ground truth; ``target[0, :, 0]`` are labels and
                ``target[0, :, 1:]`` are the regression targets.

        Returns:
            A scalar tensor on ``self.device``. When no anchor has label 1
            a constant ``0.0`` tensor is returned instead.
        """
        labels = target[0, :, 0]
        offsets = target[0, :, 1:]

        # Row indices of the positive anchors (label == 1).
        positive_idx = (labels == 1).nonzero(as_tuple=True)[0]

        abs_err = torch.abs(offsets[positive_idx] - inp[0][positive_idx])

        # Smooth-L1 with threshold 1: quadratic below it, linear above it.
        # The 0/1 masks select which branch applies element-wise.
        quad_mask = (abs_err < 1.0).float()
        lin_mask = 1 - quad_mask
        penalty = quad_mask * 0.5 * (abs_err ** 2) + lin_mask * (abs_err - 0.5)

        per_anchor = torch.sum(penalty, 1)

        if per_anchor.numel() == 0:
            # No positive anchors in this sample: nothing to regress.
            return torch.tensor(0.0).to(self.device)

        return torch.mean(per_anchor).to(self.device)

In [3]:
class cls_loss(nn.Module):
    """Cross-entropy classification loss that ignores anchors labelled -1.

    Only the first batch element is used. Labels are read from
    ``target[0][0]`` and predictions from ``inp[0]``; every position whose
    label is not -1 contributes to the loss.

    Args:
        device: device the returned loss tensor is moved to.
    """

    def __init__(self,device):
        super(cls_loss, self).__init__()

        self.device = device

    def forward(self, inp, target):
        """Return the mean cross-entropy over the non-ignored anchors.

        Args:
            inp: class scores; ``inp[0]`` is indexed by anchor along dim 0
                with the class scores along the last dim.
            target: labels; ``target[0][0]`` is a 1-D tensor of per-anchor
                labels where -1 marks an anchor to ignore.

        Returns:
            A scalar tensor on ``self.device``. When every anchor is
            ignored a constant ``0.0`` tensor is returned.
        """
        y_true = target[0][0]
        cls_keep = (y_true != -1).nonzero(as_tuple=False)[:, 0]

        # The guard must run *before* the loss is computed: a mean-reduced
        # loss is always a 0-dim tensor (numel() == 1), and on an empty
        # selection its value is NaN, so checking the result cannot detect
        # the "nothing to classify" case.
        if cls_keep.numel() == 0:
            return torch.tensor(0.0).to(self.device)

        cls_true = y_true[cls_keep].long()
        cls_pred = inp[0][cls_keep]

        # cross_entropy == nll_loss(log_softmax(...)), mean-reduced by default.
        loss = F.cross_entropy(cls_pred, cls_true)

        return loss.to(self.device)

In [4]:
class conv_block(nn.Module):
    """A 2-D convolution with optional batch normalisation and ReLU.

    The BatchNorm2d and ReLU sub-modules are always constructed (so they
    always appear in the module tree / state dict); the ``bn`` and
    ``activation`` flags only decide whether they are applied in ``forward``.

    Args:
        inp: number of input channels.
        out: number of output channels.
        kernel: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
        activation: apply ReLU after the convolution (and norm) when True.
        bn: apply batch normalisation after the convolution when True.
        bias: whether the convolution has a bias term.
    """

    def __init__(self, inp, out, kernel, stride=1, padding=0, activation=True, bn=False, bias=True):
        super().__init__()

        # Stage switches consulted by forward().
        self.rel = activation
        self.bn = bn

        self.conv = nn.Conv2d(
            in_channels=inp,
            out_channels=out,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
            bias=bias,
        )
        self.norm = nn.BatchNorm2d(num_features=out, eps=1e-5, momentum=0.01, affine=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run conv -> (batch norm) -> (ReLU) and return the result."""
        features = self.conv(x)

        # Explicit comparison against True is kept on purpose so that the
        # flags behave exactly as before for non-bool values.
        if self.bn == True:
            features = self.norm(features)

        return self.relu(features) if self.rel == True else features

In [5]:
class CTPN(nn.Module):
    """Connectionist Text Proposal Network built on VGG16 features.

    Pipeline: VGG16 feature layers -> 3x3 conv -> bidirectional LSTM run
    along each feature-map row -> 1x1 conv -> three 1x1 prediction heads
    (classification, regression, refinement), each producing
    ``NUM_ANCHORS`` predictions per spatial location.

    Args:
        pretrained: forwarded to ``torchvision.models.vgg16``; set to False
            to build the backbone without loading pretrained weights
            (e.g. when a checkpoint is loaded afterwards).
    """

    # Number of anchors predicted at every feature-map location.
    NUM_ANCHORS = 10

    def __init__(self, pretrained=True):
        super().__init__()

        base_model=models.vgg16(pretrained=pretrained)
        # All feature layers except the last one (block5_conv3 output).
        layers=list(base_model.features)[:-1]
        self.base_layers=nn.Sequential(*layers)

        self.conv1=conv_block(512, 512, 3, 1, 1) #region-proposal network layer
        # 128 hidden units per direction -> 256 output channels.
        self.recurrent=nn.LSTM(512,128, bidirectional=True, batch_first=True)
        self.conv2=conv_block(256, 512, 1, 1)

        k=self.NUM_ANCHORS
        # Prediction heads are linear: class scores go through a softmax in
        # the loss, and regression/refinement targets can be negative, so a
        # ReLU here would make negative outputs unreachable.
        self.cls=conv_block(512, k*2, 1, 1, activation=False)
        self.regr=conv_block(512, k*2, 1, 1, activation=False)
        self.refine=conv_block(512, k, 1, 1, activation=False)

    def forward(self, x):
        """Compute per-anchor predictions for an image batch.

        Args:
            x: input batch fed to the VGG16 feature layers.

        Returns:
            Tuple ``(cls, regr, refine)`` with shapes
            ``(B, H*W*NUM_ANCHORS, 2)``, ``(B, H*W*NUM_ANCHORS, 2)`` and
            ``(B, H*W*NUM_ANCHORS, 1)``, where ``H`` and ``W`` are the
            spatial dimensions of the feature map.
        """
        x=self.base_layers(x)

        # rpn
        x=self.conv1(x)

        # (B, C, H, W) -> (B*H, W, C): every feature-map row becomes one
        # sequence of length W for the LSTM.
        batch, channels, height, width = x.size()
        rows=x.permute(0,2,3,1).contiguous()
        rows=rows.view(batch*height, width, channels)

        rows,_=self.recurrent(rows)

        # (B*H, W, 256) -> (B, 256, H, W)
        x=rows.view(batch, height, width, 256)
        x=x.permute(0,3,1,2).contiguous()
        x=self.conv2(x)

        cls=self.cls(x)
        regr=self.regr(x)
        refine=self.refine(x)

        # Channels-last so the per-anchor values of one location are adjacent.
        cls=cls.permute(0,2,3,1).contiguous()
        regr=regr.permute(0,2,3,1).contiguous()
        refine=refine.permute(0,2,3,1).contiguous()

        k=self.NUM_ANCHORS
        cls=cls.view(cls.size(0), cls.size(1)*cls.size(2)*k, 2)
        regr=regr.view(regr.size(0), regr.size(1)*regr.size(2)*k, 2)
        refine=refine.view(refine.size(0), refine.size(1)*refine.size(2)*k, 1)

        return cls, regr, refine