# **DEPENDENCIES**

In [None]:
# a = []
# while(1):
#   a.append("1")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Alias for the normalisation layer used by BasicBlock, Bottleneck and segmenthead.
BatchNorm2d = nn.BatchNorm2d
# Momentum passed to the batch-norm layers created through the alias above.
bn_mom = 0.1
# Forwarded as `align_corners` to the bilinear F.interpolate calls that reference it
# (PagFM hard-codes False instead of reading this flag).
algc = False

In [None]:

class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + batch-norm stages.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels (``expansion`` is 1).
        stride: Stride of the first convolution; the second is always stride 1.
        downsample: Optional module applied to the input to build the shortcut.
            When ``None`` the input itself is used as the shortcut.
        no_relu: If true, the sum of main path and shortcut is returned
            without the final ReLU.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=False):
        super().__init__()
        self.conv1 = nn.Conv2d(
            inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = BatchNorm2d(planes, momentum=bn_mom)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = BatchNorm2d(planes, momentum=bn_mom)
        self.downsample = downsample
        self.stride = stride
        self.no_relu = no_relu

    def forward(self, x):
        """Return ``bn2(conv2(relu(bn1(conv1(x))))) + shortcut``, optionally ReLU-ed."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        # In-place accumulate, exactly like the reference implementation.
        out += shortcut

        return out if self.no_relu else self.relu(out)

class Bottleneck(nn.Module):
    """Residual bottleneck: 1x1 conv, 3x3 conv (strided), 1x1 conv.

    The last 1x1 convolution widens the features to ``planes * expansion``
    channels, so the block outputs twice ``planes`` channels.

    Args:
        inplanes: Number of input channels.
        planes: Width of the two inner convolutions.
        stride: Stride of the middle 3x3 convolution.
        downsample: Optional module applied to the input to build the shortcut.
            When ``None`` the input itself is used as the shortcut.
        no_relu: If true (the default for this block), the residual sum is
            returned without a final ReLU.
    """

    expansion = 2

    def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=True):
        super().__init__()
        out_planes = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes, momentum=bn_mom)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = BatchNorm2d(planes, momentum=bn_mom)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(out_planes, momentum=bn_mom)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.no_relu = no_relu

    def forward(self, x):
        """Run the three conv stages and add the (optionally projected) input."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        # In-place accumulate, exactly like the reference implementation.
        out += shortcut

        return out if self.no_relu else self.relu(out)

class segmenthead(nn.Module):
    """Prediction head: BN-ReLU-Conv3x3 followed by BN-ReLU-Conv1x1.

    Args:
        inplanes: Number of input channels.
        interplanes: Number of channels between the two convolutions.
        outplanes: Number of output channels (one per predicted map).
        scale_factor: Optional multiplier for the spatial size. When given,
            the output is bilinearly resized to ``scale_factor`` times the
            feature-map height and width; when ``None`` no resizing happens.
    """

    def __init__(self, inplanes, interplanes, outplanes, scale_factor=None):
        super().__init__()
        self.bn1 = BatchNorm2d(inplanes, momentum=bn_mom)
        self.conv1 = nn.Conv2d(
            inplanes, interplanes, kernel_size=3, padding=1, bias=False
        )
        self.bn2 = BatchNorm2d(interplanes, momentum=bn_mom)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            interplanes, outplanes, kernel_size=1, padding=0, bias=True
        )
        self.scale_factor = scale_factor

    def forward(self, x):
        """Return the head output, resized only if ``scale_factor`` was set."""
        hidden = self.conv1(self.relu(self.bn1(x)))
        out = self.conv2(self.relu(self.bn2(hidden)))

        if self.scale_factor is None:
            return out

        # Both convolutions preserve the spatial size, so `hidden` and `out`
        # share the height/width that gets scaled here.
        target_size = [
            hidden.shape[-2] * self.scale_factor,
            hidden.shape[-1] * self.scale_factor,
        ]
        return F.interpolate(out, size=target_size,
                             mode='bilinear', align_corners=algc)

class DAPPM(nn.Module):
    """Pyramid pooling module whose scales are fused one after another.

    Four pooled views of the input (average pooling with kernels 5/9/17 and a
    global average pool) plus an unpooled view are each projected to
    ``branch_planes`` channels. Every pooled view is upsampled to the input
    size, added to the previous stage's result and refined by a 3x3
    convolution. The five results are concatenated, compressed to
    ``outplanes`` channels and added to a 1x1-projected shortcut of the input.

    Args:
        inplanes: Number of input channels.
        branch_planes: Number of channels of every pyramid branch.
        outplanes: Number of output channels.
        BatchNorm: Normalisation layer class used throughout the module.
    """

    def __init__(self, inplanes, branch_planes, outplanes, BatchNorm=nn.BatchNorm2d):
        super().__init__()
        bn_mom = 0.1

        def bn_relu_conv(in_ch, out_ch, kernel_size=1, padding=0, pool=None):
            # [pool] -> BN -> ReLU -> Conv; layer order fixes the state_dict keys.
            layers = [] if pool is None else [pool]
            layers.extend([
                BatchNorm(in_ch, momentum=bn_mom),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size,
                          padding=padding, bias=False),
            ])
            return nn.Sequential(*layers)

        # Pooled branches, from the finest to the global context.
        self.scale1 = bn_relu_conv(
            inplanes, branch_planes,
            pool=nn.AvgPool2d(kernel_size=5, stride=2, padding=2))
        self.scale2 = bn_relu_conv(
            inplanes, branch_planes,
            pool=nn.AvgPool2d(kernel_size=9, stride=4, padding=4))
        self.scale3 = bn_relu_conv(
            inplanes, branch_planes,
            pool=nn.AvgPool2d(kernel_size=17, stride=8, padding=8))
        self.scale4 = bn_relu_conv(
            inplanes, branch_planes,
            pool=nn.AdaptiveAvgPool2d((1, 1)))
        # Unpooled branch.
        self.scale0 = bn_relu_conv(inplanes, branch_planes)

        # One 3x3 refinement per pooled branch.
        self.process1 = bn_relu_conv(branch_planes, branch_planes, kernel_size=3, padding=1)
        self.process2 = bn_relu_conv(branch_planes, branch_planes, kernel_size=3, padding=1)
        self.process3 = bn_relu_conv(branch_planes, branch_planes, kernel_size=3, padding=1)
        self.process4 = bn_relu_conv(branch_planes, branch_planes, kernel_size=3, padding=1)

        self.compression = bn_relu_conv(branch_planes * 5, outplanes)
        self.shortcut = bn_relu_conv(inplanes, outplanes)

    def forward(self, x):
        """Return the fused multi-scale context at the input's spatial size."""
        target_size = [x.shape[-2], x.shape[-1]]
        stages = (
            (self.scale1, self.process1),
            (self.scale2, self.process2),
            (self.scale3, self.process3),
            (self.scale4, self.process4),
        )

        features = [self.scale0(x)]
        for pool_branch, refine in stages:
            upsampled = F.interpolate(pool_branch(x), size=target_size,
                                      mode='bilinear', align_corners=algc)
            # Each stage builds on the result of the previous one.
            features.append(refine(upsampled + features[-1]))

        return self.compression(torch.cat(features, 1)) + self.shortcut(x)

class PAPPM(nn.Module):
    """Pyramid pooling module whose scales are refined in one grouped conv.

    Four pooled views of the input (average pooling with kernels 5/9/17 and a
    global average pool) are projected to ``branch_planes`` channels,
    upsampled to the input size and each added to the unpooled projection.
    The four sums are concatenated and refined together by a single 3x3
    convolution with ``groups=4``. The result is concatenated with the
    unpooled projection, compressed to ``outplanes`` channels and added to a
    1x1-projected shortcut of the input.

    Args:
        inplanes: Number of input channels.
        branch_planes: Number of channels of every pyramid branch.
        outplanes: Number of output channels.
        BatchNorm: Normalisation layer class used throughout the module.
    """

    def __init__(self, inplanes, branch_planes, outplanes, BatchNorm=nn.BatchNorm2d):
        super().__init__()
        bn_mom = 0.1

        def pooled_branch(pool):
            # pool -> BN -> ReLU -> 1x1 conv; layer order fixes the state_dict keys.
            return nn.Sequential(
                pool,
                BatchNorm(inplanes, momentum=bn_mom),
                nn.ReLU(inplace=True),
                nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
            )

        def bn_relu_pointwise(in_ch, out_ch):
            return nn.Sequential(
                BatchNorm(in_ch, momentum=bn_mom),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
            )

        self.scale1 = pooled_branch(nn.AvgPool2d(kernel_size=5, stride=2, padding=2))
        self.scale2 = pooled_branch(nn.AvgPool2d(kernel_size=9, stride=4, padding=4))
        self.scale3 = pooled_branch(nn.AvgPool2d(kernel_size=17, stride=8, padding=8))
        self.scale4 = pooled_branch(nn.AdaptiveAvgPool2d((1, 1)))

        self.scale0 = bn_relu_pointwise(inplanes, branch_planes)

        # One grouped 3x3 conv refines the four scales at once (one group per scale).
        grouped_width = branch_planes * 4
        self.scale_process = nn.Sequential(
            BatchNorm(grouped_width, momentum=bn_mom),
            nn.ReLU(inplace=True),
            nn.Conv2d(grouped_width, grouped_width, kernel_size=3,
                      padding=1, groups=4, bias=False),
        )

        self.compression = bn_relu_pointwise(branch_planes * 5, outplanes)
        self.shortcut = bn_relu_pointwise(inplanes, outplanes)

    def forward(self, x):
        """Return the fused multi-scale context at the input's spatial size."""
        target_size = [x.shape[-2], x.shape[-1]]
        base = self.scale0(x)

        pooled = []
        for branch in (self.scale1, self.scale2, self.scale3, self.scale4):
            upsampled = F.interpolate(branch(x), size=target_size,
                                      mode='bilinear', align_corners=algc)
            pooled.append(upsampled + base)

        refined = self.scale_process(torch.cat(pooled, 1))

        return self.compression(torch.cat([base, refined], 1)) + self.shortcut(x)


class PagFM(nn.Module):
    """Blend two feature maps with a learned, per-pixel sigmoid gate.

    Both inputs are projected to ``mid_channels`` by a 1x1 conv + batch norm.
    Their element-wise product is turned into a gate in (0, 1): either a
    single map obtained by summing over channels, or (``with_channel=True``)
    an ``in_channels``-wide map produced by another 1x1 conv + batch norm.
    The output is ``(1 - gate) * x + gate * y`` with ``y`` bilinearly resized
    to the spatial size of ``x``.

    Args:
        in_channels: Number of channels of both inputs.
        mid_channels: Number of channels of the projected embeddings.
        after_relu: If true, an in-place ReLU is applied to both inputs first.
        with_channel: If true, the gate has one value per channel instead of
            one value per pixel.
        BatchNorm: Normalisation layer class.
    """

    def __init__(self, in_channels, mid_channels, after_relu=False, with_channel=False, BatchNorm=nn.BatchNorm2d):
        super().__init__()
        self.with_channel = with_channel
        self.after_relu = after_relu

        def conv_bn(src, dst):
            return nn.Sequential(
                nn.Conv2d(src, dst, kernel_size=1, bias=False),
                BatchNorm(dst),
            )

        self.f_x = conv_bn(in_channels, mid_channels)
        self.f_y = conv_bn(in_channels, mid_channels)
        # Optional submodules are only registered when their flag is set.
        if with_channel:
            self.up = conv_bn(mid_channels, in_channels)
        if after_relu:
            self.relu = nn.ReLU(inplace=True)

    def forward(self, x, y):
        """Return ``x`` and ``y`` mixed by the gate, at the resolution of ``x``."""
        target_size = [x.size(2), x.size(3)]

        if self.after_relu:
            y = self.relu(y)
            x = self.relu(x)

        y_embed = F.interpolate(self.f_y(y), size=target_size,
                                mode='bilinear', align_corners=False)
        x_embed = self.f_x(x)
        affinity = x_embed * y_embed

        if self.with_channel:
            gate = torch.sigmoid(self.up(affinity))
        else:
            # Collapse channels to one similarity value per pixel.
            gate = torch.sigmoid(torch.sum(affinity, dim=1).unsqueeze(1))

        y_resized = F.interpolate(y, size=target_size,
                                  mode='bilinear', align_corners=False)
        return (1 - gate) * x + gate * y_resized

class Light_Bag(nn.Module):
    """Fuse three feature maps with an attention gate and two 1x1 convolutions.

    ``sigmoid(d)`` gates how ``p`` and ``i`` are mixed: one path receives
    ``(1 - gate) * i + p`` and the other ``i + gate * p``. Each path goes
    through its own 1x1 conv + batch norm and the two results are summed.

    Args:
        in_channels: Number of channels of the inputs.
        out_channels: Number of output channels.
        BatchNorm: Normalisation layer class.
    """

    def __init__(self, in_channels, out_channels, BatchNorm=nn.BatchNorm2d):
        super().__init__()

        def conv_bn():
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                BatchNorm(out_channels),
            )

        self.conv_p = conv_bn()
        self.conv_i = conv_bn()

    def forward(self, p, i, d):
        """Return the gated fusion of ``p`` and ``i``; ``d`` supplies the gate logits."""
        gate = torch.sigmoid(d)

        from_p = self.conv_p((1 - gate) * i + p)
        from_i = self.conv_i(i + gate * p)

        return from_p + from_i


class DDFMv2(nn.Module):
    """Gated fusion of three feature maps with pre-activated 1x1 convolutions.

    ``sigmoid(d)`` gates how ``p`` and ``i`` are mixed: one path receives
    ``(1 - gate) * i + p`` and the other ``i + gate * p``. Each path runs
    through BN-ReLU-Conv1x1-BN and the two results are summed.

    Args:
        in_channels: Number of channels of the inputs.
        out_channels: Number of output channels.
        BatchNorm: Normalisation layer class.
    """

    def __init__(self, in_channels, out_channels, BatchNorm=nn.BatchNorm2d):
        super().__init__()

        def bn_relu_conv_bn():
            return nn.Sequential(
                BatchNorm(in_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                BatchNorm(out_channels),
            )

        self.conv_p = bn_relu_conv_bn()
        self.conv_i = bn_relu_conv_bn()

    def forward(self, p, i, d):
        """Return the gated fusion of ``p`` and ``i``; ``d`` supplies the gate logits."""
        gate = torch.sigmoid(d)

        from_p = self.conv_p((1 - gate) * i + p)
        from_i = self.conv_i(i + gate * p)

        return from_p + from_i

class Bag(nn.Module):
    """Gate-weighted blend of two feature maps followed by BN-ReLU-Conv3x3.

    ``sigmoid(d)`` weights ``p`` and its complement weights ``i``; the blend
    is then normalised, rectified and convolved.

    Args:
        in_channels: Number of channels of the inputs.
        out_channels: Number of output channels.
        BatchNorm: Normalisation layer class.
    """

    def __init__(self, in_channels, out_channels, BatchNorm=nn.BatchNorm2d):
        super().__init__()
        stages = [
            BatchNorm(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=3, padding=1, bias=False),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, p, i, d):
        """Return ``conv(gate * p + (1 - gate) * i)`` with ``gate = sigmoid(d)``."""
        gate = torch.sigmoid(d)
        blended = gate * p + (1 - gate) * i
        return self.conv(blended)



In [None]:


# if __name__ == '__main__':


#     x = torch.rand(4, 64, 32, 64).cuda()
#     y = torch.rand(4, 64, 32, 64).cuda()
#     z = torch.rand(4, 64, 32, 64).cuda()
#     net = PagFM(64, 16, with_channel=True).cuda()

#     out = net(x,y)

# **PIDNET**

In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import time
# from .model_utils import BasicBlock, Bottleneck, segmenthead, DAPPM, PAPPM, PagFM, Bag, Light_Bag
import logging

# Same three settings (and same values) as declared in the dependencies cell.
# Alias for the normalisation layer used when building PIDNet.
BatchNorm2d = nn.BatchNorm2d
# Momentum passed to the batch-norm layers PIDNet creates directly.
bn_mom = 0.1
# Forwarded as `align_corners` to the bilinear F.interpolate calls in PIDNet.forward.
algc = False



class PIDNet(nn.Module):
    """Three-branch segmentation network built from the blocks in this file.

    After a shared stem (``conv1``, ``layer1``, ``layer2``) the network splits
    into three branches:

    * the I branch (``layer3``..``layer5`` and ``spp``) keeps downsampling;
    * the P branch (``layer3_``..``layer5_``) stays at 1/8 of the input size
      and is fed from the I branch through ``compression3/4`` and ``pag3/4``;
    * the D branch (``layer3_d``..``layer5_d``) also stays at 1/8 of the
      input size and is fed from the I branch through ``diff3/4``.

    The three branch outputs are merged by ``dfm`` and classified by
    ``final_layer``. No prediction is upsampled here: every output has the
    spatial size of the 1/8-resolution feature maps.

    Args:
        m: Number of blocks in ``layer1``, ``layer2``, ``layer3_`` and
            ``layer4_``. ``m == 2`` additionally selects the narrower D
            branch together with ``PAPPM`` and ``Light_Bag``; any other value
            selects the wider D branch with ``DAPPM`` and ``Bag``.
        n: Number of blocks in ``layer3`` and ``layer4``.
        num_classes: Number of channels of the segmentation output.
        planes: Base channel width; all layer widths are multiples of it.
        ppm_planes: Branch width of the pyramid pooling module.
        head_planes: Hidden width of the segmentation heads.
        augment: If true, two auxiliary heads are created and ``forward``
            returns their outputs as well.
    """

    def __init__(self, m=2, n=3, num_classes=151, planes=64, ppm_planes=96, head_planes=128, augment=True):
        super(PIDNet, self).__init__()
        self.augment = augment

        # I Branch
        # Stem: two stride-2 convolutions, i.e. 1/4 of the input resolution.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, planes, kernel_size=3, stride=2, padding=1),
            BatchNorm2d(planes, momentum=bn_mom),
            nn.ReLU(inplace=True),
            nn.Conv2d(planes, planes, kernel_size=3, stride=2, padding=1),
            BatchNorm2d(planes, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )

        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(BasicBlock, planes, planes, m)
        self.layer2 = self._make_layer(BasicBlock, planes, planes * 2, m, stride=2)
        self.layer3 = self._make_layer(BasicBlock, planes * 2, planes * 4, n, stride=2)
        self.layer4 = self._make_layer(BasicBlock, planes * 4, planes * 8, n, stride=2)
        self.layer5 = self._make_layer(Bottleneck, planes * 8, planes * 8, 2, stride=2)

        # P Branch
        # compression3/4 bring I-branch features down to the P-branch width
        # before pag3/4 blend them into the P branch.
        self.compression3 = nn.Sequential(
            nn.Conv2d(planes * 4, planes * 2, kernel_size=1, bias=False),
            BatchNorm2d(planes * 2, momentum=bn_mom),
        )

        self.compression4 = nn.Sequential(
            nn.Conv2d(planes * 8, planes * 2, kernel_size=1, bias=False),
            BatchNorm2d(planes * 2, momentum=bn_mom),
        )
        self.pag3 = PagFM(planes * 2, planes)
        self.pag4 = PagFM(planes * 2, planes)

        self.layer3_ = self._make_layer(BasicBlock, planes * 2, planes * 2, m)
        self.layer4_ = self._make_layer(BasicBlock, planes * 2, planes * 2, m)
        self.layer5_ = self._make_layer(Bottleneck, planes * 2, planes * 2, 1)

        # D Branch
        if m == 2:
            # Narrower D branch: `planes` channels after layer3_d, widened to
            # planes * 2 by the Bottleneck in layer4_d.
            self.layer3_d = self._make_single_layer(BasicBlock, planes * 2, planes)
            self.layer4_d = self._make_layer(Bottleneck, planes, planes, 1)
            self.diff3 = nn.Sequential(
                nn.Conv2d(planes * 4, planes, kernel_size=3, padding=1, bias=False),
                BatchNorm2d(planes, momentum=bn_mom),
            )
            self.diff4 = nn.Sequential(
                nn.Conv2d(planes * 8, planes * 2, kernel_size=3, padding=1, bias=False),
                BatchNorm2d(planes * 2, momentum=bn_mom),
            )
            self.spp = PAPPM(planes * 16, ppm_planes, planes * 4)
            self.dfm = Light_Bag(planes * 4, planes * 4)
        else:
            self.layer3_d = self._make_single_layer(BasicBlock, planes * 2, planes * 2)
            self.layer4_d = self._make_single_layer(BasicBlock, planes * 2, planes * 2)
            self.diff3 = nn.Sequential(
                nn.Conv2d(planes * 4, planes * 2, kernel_size=3, padding=1, bias=False),
                BatchNorm2d(planes * 2, momentum=bn_mom),
            )
            self.diff4 = nn.Sequential(
                nn.Conv2d(planes * 8, planes * 2, kernel_size=3, padding=1, bias=False),
                BatchNorm2d(planes * 2, momentum=bn_mom),
            )
            self.spp = DAPPM(planes * 16, ppm_planes, planes * 4)
            self.dfm = Bag(planes * 4, planes * 4)

        self.layer5_d = self._make_layer(Bottleneck, planes * 2, planes * 2, 1)

        # Prediction Head
        if self.augment:
            # Auxiliary heads: class prediction from the P branch and a
            # single-channel map from the D branch.
            self.seghead_p = segmenthead(planes * 2, head_planes, num_classes)
            self.seghead_d = segmenthead(planes * 2, planes, 1)

        self.final_layer = segmenthead(planes * 4, head_planes, num_classes)

        # Weight initialisation. The loop variable is deliberately not called
        # `m`, so the constructor argument of that name is not rebound.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        """Stack ``blocks`` instances of ``block`` into an ``nn.Sequential``.

        Only the first block is strided and receives a 1x1 conv + batch-norm
        projection shortcut (when the stride or channel count changes). When
        more than one block is requested, the last one is built with
        ``no_relu=True`` and the ones in between with ``no_relu=False``.
        """
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=bn_mom),
            )

        layers = []
        layers.append(block(inplanes, planes, stride, downsample))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            if i == (blocks - 1):
                layers.append(block(inplanes, planes, stride=1, no_relu=True))
            else:
                layers.append(block(inplanes, planes, stride=1, no_relu=False))

        return nn.Sequential(*layers)

    def _make_single_layer(self, block, inplanes, planes, stride=1):
        """Build one ``block`` with ``no_relu=True`` and, if needed, a projection shortcut."""
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=bn_mom),
            )

        layer = block(inplanes, planes, stride, downsample, no_relu=True)

        return layer

    def forward(self, x):
        """Run the network on an image batch with 3 input channels.

        Returns:
            ``[x_extra_p, x_, x_extra_d]`` when ``augment`` is true (auxiliary
            P-branch prediction, main prediction, auxiliary D-branch map),
            otherwise only the main prediction ``x_``. All outputs have the
            spatial size of the 1/8-resolution feature maps.
        """
        x = self.conv1(x)
        x = self.layer1(x)
        x = self.relu(self.layer2(self.relu(x)))

        # Resolution shared by the P and D branches. It is read from the
        # feature map instead of being computed as `input // 8`: the strided
        # convolutions round up, so the floor division disagreed with the
        # real size (and the additions below failed) whenever the input
        # height or width was not a multiple of 8. For multiples of 8 both
        # give the same value.
        height_output, width_output = x.shape[-2], x.shape[-1]

        x_ = self.layer3_(x)      # P branch
        x_d = self.layer3_d(x)    # D branch

        x = self.relu(self.layer3(x))
        x_ = self.pag3(x_, self.compression3(x))
        x_d = x_d + F.interpolate(
                        self.diff3(x),
                        size=[height_output, width_output],
                        mode='bilinear', align_corners=algc)
        if self.augment:
            temp_p = x_

        x = self.relu(self.layer4(x))
        x_ = self.layer4_(self.relu(x_))
        x_d = self.layer4_d(self.relu(x_d))

        x_ = self.pag4(x_, self.compression4(x))
        x_d = x_d + F.interpolate(
                        self.diff4(x),
                        size=[height_output, width_output],
                        mode='bilinear', align_corners=algc)
        if self.augment:
            temp_d = x_d

        x_ = self.layer5_(self.relu(x_))
        x_d = self.layer5_d(self.relu(x_d))
        # Bring the deepest I-branch context back to the P/D resolution.
        x = F.interpolate(
                        self.spp(self.layer5(x)),
                        size=[height_output, width_output],
                        mode='bilinear', align_corners=algc)

        x_ = self.final_layer(self.dfm(x_, x, x_d))

        if self.augment:
            x_extra_p = self.seghead_p(temp_p)
            x_extra_d = self.seghead_d(temp_d)
            return [x_extra_p, x_, x_extra_d]
        else:
            return x_

def get_seg_model(cfg, imgnet_pretrained):
    """Build a training-mode PIDNet (``augment=True``) and load pretrained weights.

    The architecture is chosen by substring: a name containing ``'s'`` gives
    the small variant, otherwise one containing ``'m'`` gives the medium
    variant, otherwise the large variant is built.

    Weights are always read from ``cfg.MODEL.PRETRAINED``. Only tensors whose
    (possibly renamed) key exists in the model with an identical shape are
    copied; everything else keeps its freshly initialised value.

    Args:
        cfg: Configuration object providing ``MODEL.NAME``,
            ``MODEL.PRETRAINED`` and ``DATASET.NUM_CLASSES``.
        imgnet_pretrained: If truthy, the checkpoint must contain a
            ``'state_dict'`` entry whose keys are used as they are. Otherwise
            ``'state_dict'`` is optional and the first six characters of
            every key are dropped (presumably a ``'model.'`` prefix - confirm
            against the checkpoints in use).

    Returns:
        The constructed ``PIDNet`` with the matching weights loaded.
    """
    model_name = cfg.MODEL.NAME
    if 's' in model_name:
        variant = dict(m=2, n=3, planes=32, ppm_planes=96, head_planes=128)
    elif 'm' in model_name:
        variant = dict(m=2, n=3, planes=64, ppm_planes=96, head_planes=128)
    else:
        variant = dict(m=3, n=4, planes=64, ppm_planes=112, head_planes=256)
    model = PIDNet(num_classes=cfg.DATASET.NUM_CLASSES, augment=True, **variant)

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(cfg.MODEL.PRETRAINED, map_location='cpu')
    if imgnet_pretrained:
        weights = checkpoint['state_dict']
        prefix_len = 0
    else:
        weights = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
        prefix_len = 6

    model_dict = model.state_dict()
    usable = {}
    for key, tensor in weights.items():
        target = key[prefix_len:]
        if target in model_dict and tensor.shape == model_dict[target].shape:
            usable[target] = tensor

    logging.info('Attention!!!')
    logging.info('Loaded {} parameters!'.format(len(usable)))
    logging.info('Over!!!')

    model_dict.update(usable)
    model.load_state_dict(model_dict, strict = False)

    return model

def get_pred_model(name, num_classes):
    """Build an inference-mode PIDNet (``augment=False``) without loading weights.

    The architecture is chosen by substring: a name containing ``'s'`` gives
    the small variant, otherwise one containing ``'m'`` gives the medium
    variant, otherwise the large variant is built.

    Args:
        name: Model name used to pick the variant.
        num_classes: Number of channels of the segmentation output.

    Returns:
        A freshly initialised ``PIDNet``.
    """
    if 's' in name:
        variant = dict(m=2, n=3, planes=32, ppm_planes=96, head_planes=128)
    elif 'm' in name:
        variant = dict(m=2, n=3, planes=64, ppm_planes=96, head_planes=128)
    else:
        variant = dict(m=3, n=4, planes=64, ppm_planes=112, head_planes=256)

    return PIDNet(num_classes=num_classes, augment=False, **variant)


In [None]:
# Define the model configuration.
# These values match the 'm' branch of get_pred_model (m=2, n=3, planes=64,
# ppm_planes=96, head_planes=128).
m = 2
n = 3
num_classes = 151
planes = 64
ppm_planes = 96
head_planes = 128
# With augment=False the auxiliary heads (seghead_p / seghead_d) are not
# created and forward() returns only the main prediction.
augment = False

# Create an instance of the model
model = PIDNet(m=m, n=n, num_classes=num_classes, planes=planes, ppm_planes=ppm_planes, head_planes=head_planes, augment=augment)

# Print the model architecture
print(model)


PIDNet(
  (conv1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1,

# **DATASET**

In [None]:
# Colab-only: mount Google Drive at the mount point 'data'.
from google.colab import drive
drive.mount('data')

Mounted at data


In [None]:
import os
import os
import requests
import hashlib
from tqdm import tqdm
import zipfile

In [None]:
# some helper functions to download the dataset
# this code comes mainly from gluoncv.utils
def check_sha1(filename, sha1_hash):
    """Check whether the sha1 hash of the file content matches the expected hash.

    The comparison is case-insensitive and only covers the common prefix of
    the two digests, so an abbreviated expected hash is accepted. As a
    consequence an empty expected hash matches any file.

    Parameters
    ----------
    filename : str
        Path to the file.
    sha1_hash : str
        Expected sha1 hash in hexadecimal digits (upper or lower case).
    Returns
    -------
    bool
        Whether the file content matches the expected hash.
    """
    sha1 = hashlib.sha1()
    with open(filename, 'rb') as f:
        # Hash in 1 MiB chunks so large archives are never held in memory.
        for chunk in iter(lambda: f.read(1048576), b''):
            sha1.update(chunk)

    actual = sha1.hexdigest()
    # hexdigest() is lowercase; normalise the expected value so an uppercase
    # hash is not rejected by mistake.
    expected = sha1_hash.lower()
    length = min(len(actual), len(expected))
    return actual[:length] == expected[:length]

def download(url, path=None, overwrite=False, sha1_hash=None):
    """Download a given URL.

    The file is fetched when ``overwrite`` is set, when the destination does
    not exist yet, or when ``sha1_hash`` is given and the existing file does
    not match it. Data is streamed into a temporary ``<destination>.part``
    file that is renamed to the destination only after the transfer
    finished, so an interrupted download never leaves a truncated file that
    a later call would mistake for a complete one.

    Parameters
    ----------
    url : str
        URL to download
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with same name as in url.
    overwrite : bool, optional
        Whether to overwrite destination file if already exists.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
        but doesn't match.
    Returns
    -------
    str
        The file path of the downloaded file.
    Raises
    ------
    RuntimeError
        If the server does not answer with HTTP status 200.
    UserWarning
        If ``sha1_hash`` is given and the downloaded content does not match.
        The downloaded file is left in place in that case.
    """
    url_name = url.split('/')[-1]
    if path is None:
        fname = url_name
    else:
        path = os.path.expanduser(path)
        # An existing directory receives the URL's file name; anything else
        # is taken as the full destination path.
        fname = os.path.join(path, url_name) if os.path.isdir(path) else path

    needs_download = (
        overwrite
        or not os.path.exists(fname)
        or (sha1_hash and not check_sha1(fname, sha1_hash))
    )
    if not needs_download:
        return fname

    dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
    os.makedirs(dirname, exist_ok=True)

    print('Downloading %s from %s...'%(fname, url))
    partial_name = fname + '.part'
    try:
        # The context manager releases the streamed connection on every
        # path, including the error ones.
        with requests.get(url, stream=True) as r:
            if r.status_code != 200:
                raise RuntimeError("Failed downloading url %s"%url)
            total_length = r.headers.get('content-length')
            chunks = r.iter_content(chunk_size=1024)
            if total_length is not None:
                # Size known: show a progress bar counted in 1 KB chunks.
                total_length = int(total_length)
                chunks = tqdm(chunks,
                              total=int(total_length / 1024. + 0.5),
                              unit='KB', unit_scale=False, dynamic_ncols=True)
            with open(partial_name, 'wb') as f:
                for chunk in chunks:
                    if chunk: # filter out keep-alive new chunks
                        f.write(chunk)
        os.replace(partial_name, fname)
    finally:
        # After a successful rename there is nothing left to clean up; after
        # a failure this removes the incomplete data.
        if os.path.exists(partial_name):
            os.remove(partial_name)

    if sha1_hash and not check_sha1(fname, sha1_hash):
        raise UserWarning('File {} is downloaded but the content hash does not match. ' \
                          'The repo may be outdated or download may be incomplete. ' \
                          'If the "repo_url" is overridden, consider switching to ' \
                          'the default repo.'.format(fname))

    return fname

def download_ade(path, overwrite=False):
    """Download the ADE20K archives and extract them under ``path``.

    Parameters
    ----------
    path : str
        Directory that receives the extracted data. The zip archives
        themselves are kept in ``<path>/downloads``.
    overwrite : bool, optional
        Whether to overwrite destination file if already exists.
    """
    # (url, expected SHA-1) of every archive to fetch.
    archives = [
        ('http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip',
         '219e1696abb36c8ba3a3afe7fb2f4b4606a897c7'),
        ('http://data.csail.mit.edu/places/ADEchallenge/release_test.zip',
         'e05747892219d10e9243933371a497e905a4860c'),
    ]
    download_dir = os.path.join(path, 'downloads')
    # One makedirs call creates `path`, any missing parents and the downloads
    # folder; exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(download_dir, exist_ok=True)
    for url, checksum in archives:
        filename = download(url, path=download_dir, overwrite=overwrite, sha1_hash=checksum)
        # Unpack next to the downloads folder, directly under `path`.
        with zipfile.ZipFile(filename, "r") as zip_ref:
            zip_ref.extractall(path=path)

In [None]:
# Base directory handed to download_ade(); the archives are extracted here.
root = "/content/"
# Split sub-folders. The trailing slash matters: these strings are
# concatenated directly into glob patterns further down.
training_data, val_data = "training/", "validation/"
# Folder that contains the image files of both splits.
dataset_path = f"{root}ADEChallengeData2016/images/"

In [None]:
# Absolute folders of the training and validation images, in that order.
training_data_path, val_data_path = (
    os.path.join(dataset_path, split_dir)
    for split_dir in (training_data, val_data)
)

In [None]:
# Fetch both ADE20K archives into <root>/downloads and unzip them under root.
# With overwrite=False, `download` reuses an archive that already exists and
# matches its SHA-1 instead of downloading it again.
download_ade(root, overwrite=False)

Downloading /content/downloads/ADEChallengeData2016.zip from http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip...


944710KB [00:16, 57045.89KB/s]                            


Downloading /content/downloads/release_test.zip from http://data.csail.mit.edu/places/ADEchallenge/release_test.zip...


100%|██████████| 206856/206856 [00:03<00:00, 53901.51KB/s]


In [None]:
# RGB input -> 3 channels.
N_CHANNELS = 3
# 150 scene-parsing classes plus the extra `not labeled` class.
N_CLASSES = 150 + 1
# Nominal image side length.
# NOTE(review): the transforms in this notebook resize to 512x512 -- confirm
# whether IMG_SIZE is still used anywhere.
IMG_SIZE = 128

In [None]:
from glob import glob
# Count the .jpg files of each split (training first, then validation).
TRAINSET_SIZE, VALSET_SIZE = (
    len(glob(f"{dataset_path}{split_dir}*.jpg"))
    for split_dir in (training_data, val_data)
)
print(f"The Training Dataset contains {TRAINSET_SIZE} images.")
print(f"The Validation Dataset contains {VALSET_SIZE} images.")

The Training Dataset contains 20210 images.
The Validation Dataset contains 2000 images.


In [None]:
import os
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset


# # Define a custom PyTorch Dataset
# class CustomDataset(Dataset):
#     def __init__(self, data_path, transform=None, target_size=(512, 512)):
#         self.data_path = data_path
#         self.transform = transform
#         self.target_size = target_size
#         self.file_list = [file for file in os.listdir(data_path) if file.endswith(".jpg")]

#     def __len__(self):
#         return len(self.file_list)

#     def __getitem__(self, idx):
#         img_path = os.path.join(self.data_path, self.file_list[idx])
#         image = Image.open(img_path)
#         image = np.array(image)

#         mask_path = img_path.replace("images", "annotations").replace("jpg", "png")
#         mask = Image.open(mask_path)
#         mask = np.array(mask)
#         mask = np.where(mask == 255, 0, mask)

#         if self.transform:
#             image = self.transform(image)
#             mask = self.transform(mask)

#         return {"image": image, "segmentation_mask": mask}


# # Create dataset instances
# train_dataset = CustomDataset(training_data_path)  # Add transform parameter if needed
# val_dataset = CustomDataset(val_data_path)  # Add transform parameter if needed

# # Now you have train_dataset and val_dataset containing the image and segmentation mask pairs
# # You can access them using train_dataset[idx] or val_dataset[idx]

# # Example usage
# for data in train_dataset:
#     image = data['image']
#     mask = data['segmentation_mask']
#     # Use the image and mask as needed in your training loop


In [None]:
# import matplotlib.pyplot as plt

# # Define a function to visualize a subset of data
# def visualize_dataset(dataset, num_samples=5):
#     fig, axes = plt.subplots(2, num_samples, figsize=(15, 5))
#     fig.suptitle("Visualizing Subset of Data", fontsize=16)

#     for i in range(num_samples):
#         sample = dataset[i]
#         image = sample['image']
#         mask = sample['segmentation_mask']

#         axes[0, i].imshow(image)
#         axes[0, i].set_title(f"Sample {i + 1}\nImage")
#         axes[0, i].axis('off')

#         axes[1, i].imshow(mask, cmap='jet', vmin=0, vmax=150)  # Assuming there are 150 classes
#         axes[1, i].set_title(f"Sample {i + 1}\nMask")
#         axes[1, i].axis('off')

#     plt.tight_layout()
#     plt.show()

# # Visualize a random subset of the training dataset
# visualize_dataset(train_dataset, num_samples=5)

# # Visualize a random subset of the validation dataset
# visualize_dataset(val_dataset, num_samples=5)


In [None]:
# Samples per mini-batch for the DataLoaders.
batch_size = 16
# Step size handed to the Adam optimizer.
learning_rate = 1e-4
# Number of passes the training loop makes over the training loader.
num_epochs = 10


In [None]:
# import torchvision.transforms as transforms

# # Define data transformations
# transform = transforms.Compose([
#     transforms.Resize((256, 256)),  # Resize to a specific size
#     transforms.RandomHorizontalFlip(),  # Randomly flip horizontally
#     transforms.ToTensor(),  # Convert PIL image to PyTorch tensor
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize the image
# ])

# # Create dataset instances with the defined transformations
# train_dataset = CustomDataset(training_data_path, transform=transform)
# val_dataset = CustomDataset(val_data_path, transform=transform)


In [None]:
# import torchvision.transforms as transforms

# class CustomDataset(Dataset):
#     def __init__(self, data_path, transform=None, target_size=(512, 512)):
#         self.data_path = data_path
#         self.transform = transform
#         self.target_size = target_size
#         self.file_list = [file for file in os.listdir(data_path) if file.endswith(".jpg")]

#     def __len__(self):
#         return len(self.file_list)

#     def __getitem__(self, idx):
#         img_path = os.path.join(self.data_path, self.file_list[idx])
#         image = Image.open(img_path).convert("RGB")  # Convert to RGB format
#         mask_path = img_path.replace("images", "annotations").replace("jpg", "png")
#         mask = Image.open(mask_path)
#         mask = np.array(mask)
#         mask = np.where(mask == 255, 0, mask)

#         if self.transform:
#             image = self.transform(image)
#             mask = self.transform(mask)

#         return {"image": image, "segmentation_mask": mask}

# # Define data transformations
# data_transform = transforms.Compose([
#     transforms.Resize((512, 512)),  # Resize to the desired target size
#     transforms.ToTensor(),
#     # Add more transformations as needed
# ])

# # Create dataset instances with transformations
# train_dataset = CustomDataset(training_data_path, transform=data_transform)
# val_dataset = CustomDataset(val_data_path, transform=data_transform)


In [None]:
import torchvision.transforms as transforms

class CustomDataset(Dataset):
    """Image / segmentation-mask pairs read from an ADE20K-style folder layout.

    Every ``*.jpg`` file in ``data_path`` is one sample. The matching mask is
    found by replacing ``images`` with ``annotations`` in the image path and
    swapping the file extension for ``.png``.

    Parameters
    ----------
    data_path : str
        Directory containing the ``.jpg`` images.
    transform : callable, optional
        Applied to the RGB PIL image. When given, the mask is additionally
        passed through :meth:`mask_transform`; when ``None`` the image stays a
        PIL image and the mask a numpy array.
    target_size : tuple of int, optional
        Size the mask is resized to (forwarded to ``transforms.Resize``); it
        should match the size produced by ``transform``.
    """

    def __init__(self, data_path, transform=None, target_size=(512, 512)):
        self.data_path = data_path
        self.transform = transform
        self.target_size = target_size
        # os.listdir returns entries in arbitrary order; sorting makes an
        # index (and therefore any Subset of indices) refer to the same file
        # on every run and machine.
        self.file_list = sorted(
            file for file in os.listdir(data_path) if file.endswith(".jpg")
        )

    def __len__(self):
        """Return the number of ``.jpg`` images found in ``data_path``."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Return ``{"image": ..., "segmentation_mask": ...}`` for sample ``idx``."""
        img_path = os.path.join(self.data_path, self.file_list[idx])
        # Swap only the real extension; a plain replace("jpg", "png") would
        # also rewrite any "jpg" that happens to appear elsewhere in the path.
        stem, _ = os.path.splitext(img_path)
        mask_path = stem.replace("images", "annotations") + ".png"

        # convert()/np.array() copy the pixel data, so the files can be closed
        # right away instead of lingering until garbage collection.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file)
        # Fold label 255 into class 0.
        mask = np.where(mask == 255, 0, mask)

        if self.transform:
            image = self.transform(image)
            mask = self.mask_transform(mask)

        return {"image": image, "segmentation_mask": mask}

    def mask_transform(self, mask):
        """Resize a label mask and return it as a float32 ``(H, W)`` tensor.

        The returned values are the original integer class ids. Two things
        are deliberately avoided here:

        * ``transforms.ToTensor`` -- it rescales uint8 data to ``[0, 1]``, so a
          later ``.long()`` would collapse every label to class 0.
        * the default (bilinear) resize -- it blends neighbouring labels into
          ids that belong to neither region.

        The dtype stays float32 (as before) so callers can keep passing the
        mask to ``F.interpolate`` and cast with ``.long()`` afterwards.
        """
        mask_pil = Image.fromarray(mask)  # Convert numpy mask to PIL Image
        resize = transforms.Resize(
            self.target_size,
            interpolation=transforms.InterpolationMode.NEAREST,
        )
        resized = resize(mask_pil)
        return torch.from_numpy(np.array(resized, dtype=np.float32))

# Image pipeline shared by both splits: fixed 512x512 resize, then conversion
# of the PIL image to a tensor.
data_transform = transforms.Compose(
    [transforms.Resize((512, 512)), transforms.ToTensor()]
)

# One dataset per split (training first, then validation), same pipeline.
train_dataset, val_dataset = (
    CustomDataset(split_path, transform=data_transform)
    for split_path in (training_data_path, val_data_path)
)


***full set, not subset***

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Mini-batch loaders over the complete datasets; only training is shuffled.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

# Cross-entropy loss, and Adam over all parameters of the existing `model`.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

**subset**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
# Loss / optimizer pair for the subset run: cross-entropy and Adam over all
# parameters of `model`.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
from torch.utils.data import DataLoader, Subset
# Restrict training/validation to a prefix of each split. The sizes are
# clamped to the dataset length so every index stays valid even when a split
# holds fewer samples than requested.
train_subset_indices = list(range(min(2000, len(train_dataset))))  # first 2000 training samples
val_subset_indices = list(range(min(200, len(val_dataset))))       # first 200 validation samples

# Create subsets of the datasets using the chosen indices
train_subset = Subset(train_dataset, train_subset_indices)
val_subset = Subset(val_dataset, val_subset_indices)

# Define dataloaders for the subsets (these replace the full-set loaders)
batch_size = 16
train_dataloader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

In [None]:
def _resize_targets(targets, spatial_size):
    """Resize ``(N, H, W)`` label masks to ``spatial_size`` and return int64 labels."""
    # Nearest-neighbour keeps every value an existing class id. interpolate()
    # needs a channel dimension and a floating dtype, hence unsqueeze/float.
    resized = torch.nn.functional.interpolate(
        targets.unsqueeze(1).float(), size=spatial_size, mode='nearest'
    )
    return resized.squeeze(1).long()


# Training loop: one optimisation pass, a checkpoint, then a validation pass
# per epoch.
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for batch in train_dataloader:
        inputs, targets = batch['image'], batch['segmentation_mask']
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)

        # The masks are resized to the spatial size of the model output
        # before the loss is computed.
        loss = criterion(outputs, _resize_targets(targets, outputs.size()[2:]))
        print(loss)

        # Backpropagation and optimization
        loss.backward()
        optimizer.step()
        # .item() detaches the value so the autograd graph is not kept alive.
        train_loss += loss.item()
    # Report the mean over the epoch rather than only the last batch's loss;
    # max(..., 1) keeps an empty loader from dividing by zero.
    train_loss /= max(len(train_dataloader), 1)

    checkpoint_path = f'checkpoint_epoch_{epoch}.pth'
    torch.save(model.state_dict(), checkpoint_path)
    print(f'Saved checkpoint: {checkpoint_path}')

    # Validation loop
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_dataloader:
            inputs, targets = batch['image'], batch['segmentation_mask']
            outputs = model(inputs)
            batch_loss = criterion(outputs, _resize_targets(targets, outputs.size()[2:]))
            val_loss += batch_loss.item()
    # Mean per-batch loss, so the figure is comparable across loader sizes.
    val_loss /= max(len(val_dataloader), 1)

    print(f"Epoch [{epoch+1}/{num_epochs}] - Training Loss: {train_loss:.4f} - Validation Loss: {val_loss:.4f}")


tensor(6.0739, grad_fn=<NllLoss2DBackward0>)
tensor(5.9750, grad_fn=<NllLoss2DBackward0>)
tensor(5.8979, grad_fn=<NllLoss2DBackward0>)
tensor(5.8282, grad_fn=<NllLoss2DBackward0>)
tensor(5.7387, grad_fn=<NllLoss2DBackward0>)
tensor(5.6837, grad_fn=<NllLoss2DBackward0>)
tensor(5.5420, grad_fn=<NllLoss2DBackward0>)
tensor(5.4265, grad_fn=<NllLoss2DBackward0>)
tensor(5.2609, grad_fn=<NllLoss2DBackward0>)
tensor(5.0673, grad_fn=<NllLoss2DBackward0>)
tensor(4.9814, grad_fn=<NllLoss2DBackward0>)
tensor(4.9101, grad_fn=<NllLoss2DBackward0>)
tensor(4.7099, grad_fn=<NllLoss2DBackward0>)
tensor(4.5705, grad_fn=<NllLoss2DBackward0>)
tensor(4.5783, grad_fn=<NllLoss2DBackward0>)
tensor(4.5384, grad_fn=<NllLoss2DBackward0>)
tensor(4.3340, grad_fn=<NllLoss2DBackward0>)
tensor(4.3416, grad_fn=<NllLoss2DBackward0>)
tensor(4.5431, grad_fn=<NllLoss2DBackward0>)
tensor(4.1228, grad_fn=<NllLoss2DBackward0>)
tensor(4.0376, grad_fn=<NllLoss2DBackward0>)
tensor(4.0645, grad_fn=<NllLoss2DBackward0>)
tensor(3.9

In [None]:
# Manual re-save of the current weights. This cell is not inside the epoch
# loop: `epoch` is whatever value the training-loop cell left behind, so the
# file written is the checkpoint of that epoch.
checkpoint_path = f'checkpoint_epoch_{epoch}.pth'
torch.save(obj=model.state_dict(), f=checkpoint_path)
print(f'Saved checkpoint: {checkpoint_path}')




In [None]:

# Checkpoint to restore -- point this at the file you actually want to load.
loaded_checkpoint_path = 'checkpoint_epoch_0.pth'
# Read the saved state dict from disk, then copy it into the existing model.
restored_state = torch.load(loaded_checkpoint_path)
model.load_state_dict(restored_state)
# Switch the model to evaluation mode before running inference.
model.eval()

In [None]:
import torchvision.transforms as transforms
from PIL import Image

# Load the test image and force 3-channel RGB, as CustomDataset does for the
# training images (a grayscale or RGBA file would otherwise have the wrong
# number of channels).
test_image_path = '/content/data/MyDrive/CNN_assignment/image_people.jpg'
input_image = Image.open(test_image_path).convert("RGB")

# Inference preprocessing mirrors the training `data_transform`: resize to
# 512x512 and convert to a tensor. The training pipeline applies no
# mean/std normalization, so none is applied here either -- normalizing only
# at inference would feed the model inputs from a different value range than
# it was trained on.
preprocess = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])

input_tensor = preprocess(input_image).unsqueeze(0)  # Add batch dimension


In [None]:
# Forward pass without gradient tracking.
with torch.no_grad():
    predictions = model(input_tensor)

# argmax over dim 1 picks the highest-scoring class per position
# (assumes predictions is (batch, num_classes, height, width) -- TODO confirm);
# squeeze() then drops the size-1 dimensions before conversion to numpy.
predicted_labels = torch.argmax(predictions, dim=1).squeeze().cpu().numpy()


In [None]:
import matplotlib.pyplot as plt

# Plot the original image in the first slot of a 1x3 grid
plt.figure(figsize=(10, 5))
plt.subplot(1, 3, 1)
plt.imshow(input_image)
plt.title('Original Image')

# Plot the predicted segmentation mask in the second slot; vmin/vmax pin the
# colour scale to the full label range so colours are comparable across images.
# NOTE(review): `num_classes` is not defined in this part of the notebook
# (only `N_CLASSES` is) -- confirm an earlier cell provides it.
plt.subplot(1, 3, 2)
plt.imshow(predicted_labels, cmap='jet', vmin=0, vmax=num_classes-1)
plt.title('Predicted Segmentation')

# The third slot is reserved for the ground-truth mask; enable the lines below
# when one is available.
# ground_truth_path = 'path_to_ground_truth_mask.png'
# ground_truth_mask = Image.open(ground_truth_path)
# plt.subplot(1, 3, 3)
# plt.imshow(ground_truth_mask, cmap='jet', vmin=0, vmax=num_classes-1)
# plt.title('Ground Truth')

plt.show()
