In [1]:
import torch
from torch import nn, optim
from torch.nn import functional as F
import torchvision
from torchvision import models
from torchsummary import summary

In [8]:
class Decoder(nn.Module):
    def __init__(self, c_in, scale):
        super(Decoder, self).__init__()

        assert scale in [1, 2, 4, 8]

        if scale >= 1:
            self.conv1 = Conv2dBN(c_in, c_in, 3, padding=1)
        if scale >= 4:
            self.conv2 = Conv2dBN(c_in, c_in, 3, padding=1)
        if scale >= 8:
            self.conv3 = Conv2dBN(c_in, c_in, 3, padding=1)

        self.scale = scale

    def forward(self, x):
        if self.scale >= 1:
            x = self.conv1(x)
            if self.scale == 1:
                return x

        if self.scale >= 2:
            x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)

        if self.scale >= 4:
            x = self.conv2(x)
            x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)

        if self.scale >= 8:
            x = self.conv3(x)
            x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)

        return x

In [9]:
class FSModule(nn.Module):
    def __init__(self, cv, cu):
        super(FSModule, self).__init__()

        self.conv1 = Conv2dBN(cv, cu, 1)
        self.conv2 = Conv2dBN(cv, cu, 1)

    def forward(self, v, u):
        x = self.conv1(v)
        r = torch.mul(x, u)
        k = self.conv2(v)
        z = k / (1 + torch.exp(-r))
        return z

In [10]:
class Conv2dBN(nn.Module):
    def __init__(self, c_in, c_out, filter_size, stride=1, padding=0, **kwargs):
        super(Conv2dBN, self).__init__()
        self.conv = nn.Conv2d(c_in, c_out, filter_size, stride=stride, padding=padding, **kwargs)
        self.bn = nn.BatchNorm2d(c_out)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

In [17]:
def calc_miou(pred_batch, label_batch, num_classes, ignore_index):
    '''
    :param pred_batch: [b,h,w]
    :param label_batch: [b,h,w]
    :param num_classes: scalar
    :param ignore_index: scalar
    :return:
    '''
    miou_sum, miou_count = 0, 0
    for batch_idx in range(label_batch.shape[0]):
        pred, label = pred_batch[batch_idx].flatten(0), label_batch[batch_idx].flatten(0)

        mask = label != ignore_index
        pred, label = pred[mask], label[mask]

        pred_one_hot = F.one_hot(pred, num_classes)
        label_one_hot = F.one_hot(label, num_classes)

        intersection = torch.sum(pred_one_hot * label_one_hot)
        union = torch.sum(pred_one_hot) + torch.sum(label_one_hot) - intersection + 1e-6

        miou_sum += intersection / union
        miou_count += 1
    return miou_sum / miou_count

In [201]:
class FarSegNet(nn.Module):
    def __init__(self, num_classes = 1, num_feature = 256, pretrained = False, ignore_index = 500):
        super(FarSegNet, self).__init__()
        
        self.num_classes = num_classes
        self.num_feature = num_feature
        self.ignore_index = ignore_index # ignore losses of this class
        self.EPS = 1e-5
        self.current_step = 0
        self.annealing_step = 2000
        self.focal_factor = 4
        self.focal_z = 1.0
        
        self.backbone = torchvision.models.resnet50(pretrained=True)
        self.backbone_layer_c2 = nn.Sequential(*list(self.backbone.children())[:5])
        self.backbone_layer_c3 = list(self.backbone.children())[5]
        self.backbone_layer_c4 = list(self.backbone.children())[6]
        self.backbone_layer_c5 = list(self.backbone.children())[7]
        
        self.conv_c6 = nn.Conv2d(2048, num_feature, 1)
        self.conv_c5 = nn.Conv2d(2048, num_feature, 1)
        self.conv_c4 = nn.Conv2d(1024, num_feature, 1)
        self.conv_c3 = nn.Conv2d(512, num_feature, 1)
        self.conv_c2 = nn.Conv2d(256, num_feature, 1)

        self.fs5 = FSModule(num_feature, num_feature)
        self.fs4 = FSModule(num_feature, num_feature)
        self.fs3 = FSModule(num_feature, num_feature)
        self.fs2 = FSModule(num_feature, num_feature)

        self.up5 = Decoder(num_feature, 8)
        self.up4 = Decoder(num_feature, 4)
        self.up3 = Decoder(num_feature, 2)
        self.up2 = Decoder(num_feature, 1)

        self.classify = nn.Conv2d(num_feature, num_classes, 3, padding=1)
    
    def forward(self, x, label = None):
        c2 = self.backbone_layer_c2(x)
        c3 = self.backbone_layer_c3(c2)
        c4 = self.backbone_layer_c4(c3)
        c5 = self.backbone_layer_c5(c4)
        c6 = F.adaptive_avg_pool2d(c5, (1, 1))
        u = self.conv_c6(c6)
        
        p5 = self.conv_c5(c5)
        p4 = (self.conv_c4(c4) + F.interpolate(p5, scale_factor = 2)) / 2.
        p3 = (self.conv_c3(c3) + F.interpolate(p4, scale_factor = 2)) / 2.
        p2 = (self.conv_c2(c2) + F.interpolate(p3, scale_factor = 2)) / 2.
        
        z5 = self.fs5(p5, u)
        z4 = self.fs4(p4, u)
        z3 = self.fs3(p3, u)
        z2 = self.fs2(p2, u)

        o5 = self.up5(z5)
        o4 = self.up4(z4)
        o3 = self.up3(z3)
        o2 = self.up2(z2)
        
        x = (o5 + o4 + o3 + o2) / 4.
        x = F.interpolate(x, scale_factor=4, mode="bilinear", align_corners=True)
        logit = self.classify(x)
        
        print("c2\t", c2.shape)
        print("c3\t", c3.shape)
        print("c4\t", c4.shape)
        print("c5\t", c5.shape)
        print("c6\t", c6.shape)
        print("u\t", u.shape)
        print('---------------------------------------------')
        print("p5\t", p5.shape)
        print("p4\t", p4.shape)
        print("p3\t", p3.shape)
        print("p2\t", p2.shape)
        print('---------------------------------------------')
        print("z5\t", z5.shape)
        print("z4\t", z4.shape)
        print("z3\t", z3.shape)
        print("z2\t", z2.shape)
        print('---------------------------------------------')
        print("o5\t", o5.shape)
        print("o4\t", o4.shape)
        print("o3\t", o3.shape)
        print("o2\t", o2.shape)
        print('---------------------------------------------')
        print("x\t", x.shape)
        print("logit\t", logit.shape)
        
        if self.training:
            return self._get_loss(logit, label), self._get_miou(logit, label)
#         else:
#             score_map = torch.softmax(logit, dim=1)
#             score_map = score_map.permute(0, 2, 3, 1)
#             pred = torch.argmax(score_map, dim=3)
#             pred = torch.unsqueeze(pred, dim=3)
#             return pred, score_map
        
    def _get_loss(self, logit, label):
        
        logit = logit.permute(0, 2, 3, 1).flatten()
        label = label.flatten()
        mask = label != self.ignore_index
        logit, label = logit[mask], label[mask]
        loss = nn.BCEWithLogitsLoss()(logit, label)

        probs = torch.logit(logit)
        p = ((1-label) + (-1)**(1+label)*probs).squeeze()

        z = torch.pow(1.0 - p, self.focal_factor)
        z = self.focal_z * z

        if self.current_step < self.annealing_step:
            z = z + (1 - z) * (1 - self.current_step / self.annealing_step)
        self.current_step += 1

        loss = z * loss
        avg_loss = torch.mean(loss) / (torch.mean(mask.type(torch.float32)) + self.EPS)
        return avg_loss

    def _get_miou(self, logit, label):
        pred = torch.argmax(logit, dim=1).squeeze_(dim=1)
        return calc_miou(pred, label, self.num_classes, self.ignore_index)

In [202]:
net = FarSegNet().cuda()

In [203]:
x = torch.rand((1, 3, 256 ,256)).cuda()
label = torch.ones((1, 1, 256 ,256)).cuda()
               
net(x, label)

c2	 torch.Size([1, 256, 64, 64])
c3	 torch.Size([1, 512, 32, 32])
c4	 torch.Size([1, 1024, 16, 16])
c5	 torch.Size([1, 2048, 8, 8])
c6	 torch.Size([1, 2048, 1, 1])
u	 torch.Size([1, 256, 1, 1])
---------------------------------------------
p5	 torch.Size([1, 256, 8, 8])
p4	 torch.Size([1, 256, 16, 16])
p3	 torch.Size([1, 256, 32, 32])
p2	 torch.Size([1, 256, 64, 64])
---------------------------------------------
z5	 torch.Size([1, 256, 8, 8])
z4	 torch.Size([1, 256, 16, 16])
z3	 torch.Size([1, 256, 32, 32])
z2	 torch.Size([1, 256, 64, 64])
---------------------------------------------
o5	 torch.Size([1, 256, 64, 64])
o4	 torch.Size([1, 256, 64, 64])
o3	 torch.Size([1, 256, 64, 64])
o2	 torch.Size([1, 256, 64, 64])
---------------------------------------------
x	 torch.Size([1, 256, 256, 256])
logit	 torch.Size([1, 1, 256, 256])


RuntimeError: one_hot is only applicable to index tensor.

In [168]:
a = True

In [170]:
float(a)

1.0

In [40]:
summary(net, (3, 256, 256))

c2	 torch.Size([2, 256, 64, 64])
c3	 torch.Size([2, 512, 32, 32])
c4	 torch.Size([2, 1024, 16, 16])
c5	 torch.Size([2, 2048, 8, 8])
c6	 torch.Size([2, 2048, 1, 1])
u	 torch.Size([2, 256, 1, 1])
---------------------------------------------
p5	 torch.Size([2, 256, 8, 8])
p4	 torch.Size([2, 256, 16, 16])
p3	 torch.Size([2, 256, 32, 32])
p2	 torch.Size([2, 256, 64, 64])
---------------------------------------------
z5	 torch.Size([2, 256, 8, 8])
z4	 torch.Size([2, 256, 16, 16])
z3	 torch.Size([2, 256, 32, 32])
z2	 torch.Size([2, 256, 64, 64])
---------------------------------------------
o5	 torch.Size([2, 256, 64, 64])
o4	 torch.Size([2, 256, 64, 64])
o3	 torch.Size([2, 256, 64, 64])
o2	 torch.Size([2, 256, 64, 64])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]           9,408
            Conv2d-2         [-1, 64, 128, 128]           9,408
       BatchNorm2d-