In [None]:
%pip install scikit-learn-extra



In [None]:
# Imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn_extra.cluster import KMedoids
from scipy import ndimage
import tensorflow as tf
import os
import os.path
import random
import shutil
import hashlib
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
import torch.utils.data as data
from torch.utils.model_zoo import tqdm
import tarfile
import zipfile
from torch import Tensor
from torchvision import datasets
from torchvision.datasets import SBDataset
from torchvision import transforms
from torch.utils.data import DataLoader
from PIL import Image
from tqdm.auto import tqdm
from collections import OrderedDict
from typing import Any, List, Tuple, Optional, Callable, Dict, TypeVar, Iterable
from vision import *
from utils import *

# Seed every RNG the notebook relies on: PyTorch (weight init), NumPy, and Python's
# `random` (crop / click sampling). Python's seed call is last so the cell shows no output.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

**Steps to do:**

*   Preprocessing Pipeline
  *   Part of Dataset init()
     *   Ignore images whose height or width is smaller than filter_size (they cannot hold a filter_size x filter_size crop)
     *   For images with multiple labels, pick one and let the others be part of the background
  *   Part of Transform or Function
     *   Take a filter_size x filter_size crop by resampling crops until the complete object instance is inside the crop
     *   Upscale image to 480x480
     *   Use k-medoids to generate a random number of foreground & background clicks
     *   Generate interaction maps
*   Model
  *   DenseNet-121 Encoder
  *   Decoder from BRS paper
  *   Add ASPP module between Encoder & Decoder
  *   Add Semantic Supervision block to Encoder during pretraining
  *   Add LIP (Local Importance-based Pooling) for pooling

In [None]:
# Tensor conversion for an (image, target) pair; can be passed as the Dataset's `transforms` callable.
def PIL_to_tensor(img, target):
  """Apply ``transforms.ToTensor()`` to both ``img`` and ``target`` and return them as a pair."""
  to_tensor = transforms.ToTensor()
  return to_tensor(img), to_tensor(target)

In [None]:
def download_extract(url: str, root: str, filename: str, md5: str) -> None:
    """Download ``url`` to ``root/filename`` and unpack that tar archive into ``root``.

    ``md5`` is forwarded to ``download_url`` (a helper from the star-imported modules;
    presumably it verifies the file and skips re-downloading it — TODO confirm in utils.py).
    """
    download_url(url, root, filename, md5)
    archive_path = os.path.join(root, filename)
    with tarfile.open(archive_path, "r") as archive:
        archive.extractall(path=root)

# Remember to upload vision.py and utils.py

In [None]:
class CustomSBDataset(VisionDataset):
    """`Semantic Boundaries Dataset <http://home.bharathh.info/pubs/codes/SBD/download.html>`_
    with the preprocessing used for interactive segmentation.

    Masks are read from the instance annotations (``inst/*.mat``). ``__getitem__`` returns an
    image tensor and a single-object mask tensor, both center-cropped to a square and resized
    to 480x480.

    Args:
        root (string): Root directory of the Semantic Boundaries Dataset
        image_set (string, optional): Select the image_set to use, ``train``, ``val`` or ``train_noval``.
            Image set ``train_noval`` excludes VOC 2012 val images.
        mode (string, optional): Select target type. Possible values 'boundaries' or 'segmentation'.
            In case of 'boundaries', the target is an array of shape `[num_classes, H, W]`,
            where `num_classes=20`. NOTE(review): the tensor post-processing in ``__getitem__``
            is written for the single-channel 'segmentation' mask only.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. The archive is not unpacked again once all dataset
            folders/files are already present in root.
        img_filter_size (int, optional): Side length of the square crop taken by ``random_crop``
            and the minimum height/width kept by ``filter_images``. Default: 200.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version. Input sample is PIL image and target is a numpy array
            if `mode='boundaries'` or PIL image if `mode='segmentation'`.
    """

    url = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz"
    md5 = "82b4d87ceb2ed10f6038a1cba92111cb"
    filename = "benchmark.tgz"

    voc_train_url = "http://home.bharathh.info/pubs/codes/SBD/train_noval.txt"
    voc_split_filename = "train_noval.txt"
    voc_split_md5 = "79bff800c5f0b1ec6b21080a3c066722"

    def __init__(
            self,
            root: str,
            image_set: str = "train",
            mode: str = "segmentation",
            download: bool = False,
            img_filter_size: int = 200,
            transforms: Optional[Callable] = None,
    ) -> None:

        try:
            from scipy.io import loadmat
            self._loadmat = loadmat
        except ImportError:
            raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: "
                               "pip install scipy")

        super(CustomSBDataset, self).__init__(root, transforms)
        self.image_set = verify_str_arg(image_set, "image_set",
                                        ("train", "val", "train_noval"))
        self.mode = verify_str_arg(mode, "mode", ("segmentation", "boundaries"))
        self.num_classes = 20
        self.img_filter_size = img_filter_size

        sbd_root = self.root
        image_dir = os.path.join(sbd_root, 'img')
        mask_dir = os.path.join(sbd_root, 'inst')

        if download:
            self.download_dataset(sbd_root)

        if not os.path.isdir(sbd_root):
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        self.load_filenames(sbd_root, image_dir, mask_dir, image_set)

        self._get_target = self._get_segmentation_target \
            if self.mode == "segmentation" else self._get_boundaries_target

        # filter_images() opens every image file, so it is not run here by default;
        # call it explicitly to drop images smaller than img_filter_size.

    def _get_segmentation_target(self, filepath: str) -> Image.Image:
        """Load the instance-id mask stored in the ``.mat`` file as a PIL image."""
        mat = self._loadmat(filepath)
        return Image.fromarray(mat['GTinst'][0]['Segmentation'][0])

    def _get_boundaries_target(self, filepath: str) -> np.ndarray:
        """Stack the per-class boundary maps of the ``.mat`` file into a ``[num_classes, H, W]`` array."""
        mat = self._loadmat(filepath)
        return np.concatenate([np.expand_dims(mat['GTinst'][0]['Boundaries'][0][i][0].toarray(), axis=0)
                               for i in range(self.num_classes)], axis=0)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(image, mask)`` tensors for sample ``index``.

        Only instance id 1 is kept as foreground (see
        ``naive_ignore_multiple_object_instances``). Image and mask are center-cropped to a
        square and resized to 480x480.
        """
        # Close the file handle right away; convert() returns an independent image.
        with Image.open(self.images[index]) as raw_img:
            img = raw_img.convert('RGB')
        target = self._get_target(self.masks[index])

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        img = transforms.ToTensor()(img)
        # ToTensor divides 8-bit images by 255; undo that and round so instance ids are exact integers.
        target = (transforms.ToTensor()(target) * 255).round()
        target = self.naive_ignore_multiple_object_instances(target)
        img, target = self.center_crop(img, target)
        img = transforms.Resize((480, 480))(img)
        # Nearest-neighbour keeps the mask binary; the default bilinear resize would create
        # fractional values along object edges.
        target = transforms.Resize((480, 480), interpolation=Image.NEAREST)(target)

        return img, target

    def __len__(self) -> int:
        return len(self.images)

    def extra_repr(self) -> str:
        lines = ["Image set: {image_set}", "Mode: {mode}"]
        return '\n'.join(lines).format(**self.__dict__)

    def download_dataset(self, sbd_root) -> None:
        """Download/unpack SBD into ``sbd_root`` and fetch the ``train_noval`` split file.

        The archive is only extracted when some expected entry is missing from ``sbd_root``,
        and only the missing entries are moved. Moving onto an entry that already exists makes
        ``shutil.move`` raise (e.g. when a second split is created with ``download=True``).
        """
        entries = ["cls", "img", "inst", "train.txt", "val.txt"]
        missing = [f for f in entries if not os.path.exists(os.path.join(sbd_root, f))]
        if missing:
            download_extract(self.url, self.root, self.filename, self.md5)
            extracted_ds_root = os.path.join(self.root, "benchmark_RELEASE", "dataset")
            for f in missing:
                shutil.move(os.path.join(extracted_ds_root, f), sbd_root)
        download_url(self.voc_train_url, sbd_root, self.voc_split_filename,
                     self.voc_split_md5)

    def load_filenames(self, sbd_root, image_dir, mask_dir, image_set) -> None:
        """Read ``<image_set>.txt`` and build the parallel ``self.images`` / ``self.masks`` path lists."""
        split_f = os.path.join(sbd_root, image_set.rstrip('\n') + '.txt')

        with open(split_f, "r") as fh:
            file_names = [x.strip() for x in fh.readlines()]

        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".mat") for x in file_names]
        assert (len(self.images) == len(self.masks))

    def filter_images(self):
        """Remove samples whose image width OR height is below ``img_filter_size``.

        Such images cannot contain an ``img_filter_size`` x ``img_filter_size`` crop.
        """
        keep = []
        for idx, path in enumerate(self.images):
            with Image.open(path) as img:
                width, height = img.size
            if width >= self.img_filter_size and height >= self.img_filter_size:
                keep.append(idx)

        self.images = [self.images[i] for i in keep]
        self.masks = [self.masks[i] for i in keep]

    # For targets that have multiple object instances, pick label 1 as foreground and let the others be part of the background
    def naive_ignore_multiple_object_instances(self, target):
        cond = torch.eq(target, torch.ones_like(target))
        target = torch.where(cond, target, torch.zeros_like(target))
        return target

    # Apply CenterCrop to image to make it a square based on its smaller dimension
    def center_crop(self, image, target):
        min_len = min(image.shape[1], image.shape[2])
        image = transforms.CenterCrop(min_len)(image)
        target = transforms.CenterCrop(min_len)(target)
        return image, target

    def random_crop(self, image, target):
        """Take the same random ``img_filter_size`` square crop from ``image`` and ``target``.

        The object's bounding box comes from the nonzero pixels of ``target[0]``. Along each
        axis where the box fits in the crop, the crop start is sampled uniformly among positions
        that contain the whole box and start inside the image. Where the box is larger than the
        crop, the crop is centred on the box. Starts are clamped so the crop never begins past
        ``size - img_filter_size``.

        Raises:
            ValueError: if ``target[0]`` has no foreground pixel.
        """
        crop = self.img_filter_size
        height, width = target.shape[-2], target.shape[-1]

        foreground = torch.nonzero(target[0])
        if foreground.numel() == 0:
            raise ValueError("random_crop needs a target with at least one foreground pixel")
        top, left = foreground.min(dim=0).values.tolist()
        bottom, right = foreground.max(dim=0).values.tolist()

        x = self._sample_crop_start(left, right, width, crop)
        y = self._sample_crop_start(top, bottom, height, crop)

        # Apply same random crop to both image and target
        image = transforms.functional.crop(image, y, x, crop, crop)
        target = transforms.functional.crop(target, y, x, crop, crop)
        return image, target

    @staticmethod
    def _sample_crop_start(lo, hi, size, crop):
        """Start index of a ``crop``-long window on an axis of length ``size`` that covers
        pixels ``lo..hi`` (inclusive) whenever they fit in ``crop``."""
        max_start = max(size - crop, 0)
        if hi - lo + 1 > crop:
            # Object larger than the crop: centre the window on it.
            start = (lo + hi) // 2 - crop // 2
            return min(max(start, 0), max_start)
        # Any start in [hi - crop + 1, lo] covers lo..hi; intersect with the valid image range.
        return random.randint(max(0, hi - crop + 1), min(lo, max_start))

In [None]:
# Build the SBD training split; download=True fetches and unpacks the archive into root on first use.
sbd = CustomSBDataset(".", image_set="train", mode="segmentation", download=True)

Downloading https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz to ./benchmark.tgz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Downloading http://home.bharathh.info/pubs/codes/SBD/train_noval.txt to ./train_noval.txt


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

In [None]:
# Validation split. The training cell above already downloaded and unpacked the whole
# SBD archive (images, masks and both split files) into root, so nothing is downloaded here.
sbd_val = CustomSBDataset(root=".", image_set="val", mode="segmentation", download=False)

Using downloaded and verified file: ./benchmark.tgz


Error: ignored

In [None]:
# Check img and target shape: both should be 480x480 spatially after preprocessing.
img, target = sbd[0]
img.shape, target.shape

In [None]:
# Display the first `images` samples: RGB image (left) and its single-object mask (right).
images = 10
fig, axes = plt.subplots(nrows=images, ncols=2, figsize=(10, 5 * images))

for i in range(images):
    img, target = sbd[i]
    # The dataset returns (C, H, W) tensors; imshow needs (H, W, C) or (H, W).
    axes[i, 0].imshow(img.permute(1, 2, 0).numpy())
    axes[i, 0].set_title(f"Sample {i}: image")
    axes[i, 1].imshow(target[0].numpy(), cmap="gray")
    axes[i, 1].set_title(f"Sample {i}: mask")
    for ax in axes[i]:
        ax.axis("off")

plt.show()

In [None]:
# Sample user clicks using target
def gen_clicks(target):
  """Simulate the clicks of a user annotating the object in ``target``.

  Draws between 1 and 10 positive (foreground) clicks and between 0 and 10 negative
  (background) clicks, using Python's ``random`` module.

  Args:
    target: mask tensor; only ``target[0]`` is used, and its nonzero pixels are foreground.

  Returns:
    (pos_clicks, neg_clicks): float tensors of shape (num_pos, 2) and (num_neg, 2) holding
    (row, col) pixel coordinates.
  """
  num_pos = random.randint(1, 10)
  num_neg = random.randint(0, 10)

  pos_clicks = gen_pos_clicks(target, num_pos)
  neg_clicks = gen_neg_clicks(target, num_neg)

  return pos_clicks, neg_clicks


def _sample_spaced_points(candidates, n, dstep, max_tries):
  """Draw ``n`` rows of ``candidates`` at random, each at least ``dstep`` (Euclidean) from those already drawn.

  Args:
    candidates: (k, 2) integer tensor of (row, col) pixel coordinates.
    n: number of clicks to draw.
    dstep: minimum distance between any two returned clicks.
    max_tries: random draws allowed per click before giving up.

  Returns:
    float tensor of shape (n, 2).

  Raises:
    ValueError: if clicks are requested but ``candidates`` is empty.
    RuntimeError: if no suitably spaced candidate is found within ``max_tries`` draws.
  """
  clicks = torch.zeros((n, 2))
  if n == 0:
    return clicks
  if len(candidates) == 0:
    raise ValueError("no candidate pixels to place clicks on")
  candidates = candidates.float()
  for i in range(n):
    for _ in range(max_tries):
      # randrange excludes len(candidates); randint(0, len) would sometimes index past the end.
      point = candidates[random.randrange(len(candidates))]
      # Compare only with the i clicks placed so far; the remaining rows are zero placeholders.
      if i == 0 or torch.norm(clicks[:i] - point, dim=1).min() >= dstep:
        clicks[i] = point
        break
    else:
      raise RuntimeError(f"could not place click {i + 1} of {n} at least {dstep} px from the others")
  return clicks


# Generate n +ve clicks by randomly sampling points in the foreground
def gen_pos_clicks(target, n, dstep=5, dmargin=5, max_tries=1000):
  """Sample ``n`` positive clicks inside the object of ``target``.

  Candidates are foreground pixels at least ``dmargin`` px from the background. If the object
  is too thin to have any, every foreground pixel is a candidate. Clicks are pairwise at
  least ``dstep`` px apart.

  Args:
    target: mask tensor; nonzero pixels of ``target[0]`` are foreground.
    n: number of clicks.
    dstep: minimum distance between clicks.
    dmargin: minimum distance from the background.
    max_tries: random draws allowed per click.

  Returns:
    float tensor of shape (n, 2) with (row, col) coordinates.
  """
  distances = ndimage.distance_transform_edt(target[0].numpy())
  points = torch.nonzero(torch.from_numpy((distances >= dmargin) & (distances > 0)))
  if len(points) == 0:
    # Object too thin to keep a dmargin gap from the background: use any foreground pixel.
    points = torch.nonzero(torch.from_numpy(distances > 0))
  return _sample_spaced_points(points, n, dstep, max_tries)


# Generate n -ve clicks by randomly sampling points in the background
def gen_neg_clicks(target, n, dstep=5, dmargin_min=5, dmargin_max=20, max_tries=1000):
  """Sample ``n`` negative clicks in the background band around the object of ``target``.

  Candidates are background pixels whose distance to the object lies in
  [``dmargin_min``, ``dmargin_max``]. Clicks are pairwise at least ``dstep`` px apart.

  Args:
    target: mask tensor; nonzero pixels of ``target[0]`` are foreground.
    n: number of clicks.
    dstep: minimum distance between clicks.
    dmargin_min, dmargin_max: allowed distance range from the object.
    max_tries: random draws allowed per click.

  Returns:
    float tensor of shape (n, 2) with (row, col) coordinates.

  Raises:
    ValueError: if clicks are requested but no background pixel lies in the band.
  """
  target_inv = torch.logical_not(target[0])
  distances = ndimage.distance_transform_edt(target_inv.numpy())
  in_band = (distances >= dmargin_min) & (distances <= dmargin_max) & (distances > 0)
  points = torch.nonzero(torch.from_numpy(in_band))
  return _sample_spaced_points(points, n, dstep, max_tries)

In [None]:
# Convert clicks to interaction maps
def convert_clicks(pos_clicks, neg_clicks, size=(480, 480), max_dist=255):
  """Turn click coordinates into truncated Euclidean distance maps.

  Each map pixel holds the distance to the nearest click of that type, capped at
  ``max_dist``. A click type with no clicks yields a map filled with ``max_dist``.

  Args:
    pos_clicks, neg_clicks: (k, 2) sequences of (row, col) coordinates (e.g. from gen_pos_clicks).
    size: (height, width) of the maps; defaults to the 480x480 dataset resolution.
    max_dist: cap applied to the distances.

  Returns:
    (pos_map, neg_map): float64 tensors of shape (1, height, width).
  """
  return _clicks_to_map(pos_clicks, size, max_dist), _clicks_to_map(neg_clicks, size, max_dist)


def _clicks_to_map(clicks, size, max_dist):
  """Truncated distance map to the nearest of ``clicks`` (all ``max_dist`` when there are none)."""
  seeds = np.ones(size)
  for click in clicks:
    seeds[int(click[0]), int(click[1])] = 0
  if seeds.all():
    # No click: distance_transform_edt would have no zero pixel to measure from,
    # so every pixel is treated as maximally far away.
    dist = np.full(size, float(max_dist))
  else:
    dist = np.minimum(ndimage.distance_transform_edt(seeds), max_dist)
  return torch.from_numpy(dist).unsqueeze(0)

In [None]:
# Sanity-check click sampling: overlay sampled clicks on one mask.
img, target = sbd[6]
pos_clicks = gen_pos_clicks(target, 5)
neg_clicks = gen_neg_clicks(target, 3)

fig, ax = plt.subplots()
ax.imshow(target[0].numpy(), cmap="gray")
# Clicks are stored as (row, col); scatter takes (x=col, y=row).
ax.scatter(pos_clicks[:, 1], pos_clicks[:, 0], c="lime", label="positive clicks")
ax.scatter(neg_clicks[:, 1], neg_clicks[:, 0], c="red", label="negative clicks")
ax.set_title("Sampled clicks on sample 6")
ax.legend()
plt.show()

In [None]:
# Visualize the interaction maps (distance to the nearest click, capped at 255) for the clicks above.
pos_map, neg_map = convert_clicks(pos_clicks, neg_clicks)
fig, axes = plt.subplots(nrows=1, ncols=2)
for ax, interaction_map, title in zip(axes, (pos_map, neg_map), ("Positive map", "Negative map")):
    ax.imshow(interaction_map[0].numpy(), cmap='gray', interpolation='bicubic')
    ax.set_title(title)
plt.show()

In [None]:
"""
DenseNet Encoder Network
"""

class _DenseLayer(nn.Module):
    """Single DenseNet-BC layer.

    BN -> ReLU -> 1x1 conv to ``bn_size * growth_rate`` channels (the bottleneck), then
    BN -> ReLU -> 3x3 conv producing ``growth_rate`` new feature maps, followed by dropout
    when ``drop_rate > 0``. The input is the list of all preceding feature maps (or a single
    tensor), concatenated along dim 1 before the bottleneck.
    """
    def __init__(
        self,
        num_input_features: int,
        growth_rate: int,
        bn_size: int,
        drop_rate: float,
        memory_efficient: bool = False
    ) -> None:
        super(_DenseLayer, self).__init__()
        # Sub-modules are declared with annotations and registered by name via add_module;
        # these names (norm1, conv1, ...) are what _load_state_dict maps checkpoint keys onto.
        self.norm1: nn.BatchNorm2d
        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
        self.relu1: nn.ReLU
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.conv1: nn.Conv2d
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
                                           growth_rate, kernel_size=1, stride=1,
                                           bias=False))
        self.norm2: nn.BatchNorm2d
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
        self.relu2: nn.ReLU
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.conv2: nn.Conv2d
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1,
                                           bias=False))
        self.drop_rate = float(drop_rate)
        self.memory_efficient = memory_efficient

    def bn_function(self, inputs: List[Tensor]) -> Tensor:
        """Concatenate ``inputs`` along dim 1 and apply norm1 -> relu1 -> conv1 (the bottleneck)."""
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))  # noqa: T484
        return bottleneck_output

    # todo: rewrite when torchscript supports any
    # (i.e. the builtin any(); this loop is its TorchScript-friendly equivalent)
    def any_requires_grad(self, input: List[Tensor]) -> bool:
        """Return True if at least one tensor in ``input`` requires grad."""
        for tensor in input:
            if tensor.requires_grad:
                return True
        return False

    @torch.jit.unused  # noqa: T484
    def call_checkpoint_bottleneck(self, input: List[Tensor]) -> Tensor:
        """Run ``bn_function`` through ``torch.utils.checkpoint`` (the memory-efficient path)."""
        # checkpoint takes tensors as positional args, so the list is unpacked here and
        # re-packed into a tuple by the closure.
        def closure(*inputs):
            return self.bn_function(inputs)

        return cp.checkpoint(closure, *input)

    # TorchScript overload declarations: forward accepts a List[Tensor] or a single Tensor.
    @torch.jit._overload_method  # noqa: F811
    def forward(self, input: List[Tensor]) -> Tensor:
        pass

    @torch.jit._overload_method  # noqa: F811
    def forward(self, input: Tensor) -> Tensor:
        pass

    # torchscript does not yet support *args, so we overload method
    # allowing it to take either a List[Tensor] or single Tensor
    def forward(self, input: Tensor) -> Tensor:  # noqa: F811
        # Normalise to a list of previous feature maps.
        if isinstance(input, Tensor):
            prev_features = [input]
        else:
            prev_features = input

        # Checkpointing only applies when gradients are needed; it is rejected under scripting.
        if self.memory_efficient and self.any_requires_grad(prev_features):
            if torch.jit.is_scripting():
                raise Exception("Memory Efficient not supported in JIT")

            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
        else:
            bottleneck_output = self.bn_function(prev_features)

        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate,
                                     training=self.training)
        return new_features


class _DenseBlock(nn.ModuleDict):
    """``num_layers`` densely connected ``_DenseLayer`` modules.

    Layer k (1-based, registered as ``denselayer<k>``) receives the block input plus the
    outputs of all earlier layers. The block returns all of them concatenated along dim 1,
    i.e. ``num_input_features + num_layers * growth_rate`` channels.
    """
    _version = 2

    def __init__(
        self,
        num_layers: int,
        num_input_features: int,
        bn_size: int,
        growth_rate: int,
        drop_rate: float,
        memory_efficient: bool = False
    ) -> None:
        super(_DenseBlock, self).__init__()
        for layer_no in range(1, num_layers + 1):
            # Every earlier layer contributed growth_rate channels to this layer's input.
            in_features = num_input_features + (layer_no - 1) * growth_rate
            self.add_module(
                'denselayer%d' % layer_no,
                _DenseLayer(
                    in_features,
                    growth_rate=growth_rate,
                    bn_size=bn_size,
                    drop_rate=drop_rate,
                    memory_efficient=memory_efficient,
                ),
            )

    def forward(self, init_features: Tensor) -> Tensor:
        feature_maps = [init_features]
        for _name, dense_layer in self.items():
            feature_maps.append(dense_layer(feature_maps))
        return torch.cat(feature_maps, 1)


class _Transition(nn.Sequential):
    def __init__(self, num_input_features: int, num_output_features: int) -> None:
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))


class DenseNet(nn.Module):
    """Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.

    Used here as an encoder: ``forward`` returns the classifier logits together with the
    output of every dense block, so a decoder can consume multi-scale features.

    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (tuple of ints) - how many layers in each pooling block; any number of
          blocks is supported (a transition layer is inserted between consecutive blocks)
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
        in_channels (int) - number of input channels of the first convolution (e.g. RGB plus
          interaction maps). Default: 3. A checkpoint trained with 3 input channels will not
          load into ``conv0`` with any other value.
    """

    def __init__(
        self,
        growth_rate: int = 32,
        block_config: Tuple[int, ...] = (6, 12, 24, 16),
        num_init_features: int = 64,
        bn_size: int = 4,
        drop_rate: float = 0,
        num_classes: int = 1000,
        memory_efficient: bool = False,
        in_channels: int = 3
    ) -> None:

        super(DenseNet, self).__init__()

        # First convolution
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(in_channels, num_init_features, kernel_size=7, stride=2,
                                padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock, with a channel-halving transition between consecutive blocks
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient
            )
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
        """Run the encoder.

        Returns:
            (logits, dense_out): ``logits`` has shape (batch, num_classes). ``dense_out`` holds
            the output of every dense block, in order, taken before the following transition
            (or the final norm5).
        """
        dense_out: List[Tensor] = []
        # self.features runs conv0..pool0, then denseblock/transition pairs, then norm5.
        # Iterating it supports any number of blocks instead of hard-coding four.
        for name, module in self.features.named_children():
            x = module(x)
            if name.startswith('denseblock'):
                dense_out.append(x)

        # x is the norm5 output here; the in-place ReLU does not touch dense_out's tensors.
        out = F.relu(x, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.classifier(out)

        return out, dense_out


def _load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None:
    # '.'s are no longer allowed in module names, but previous _DenseLayer
    # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
    # They are also in the checkpoints in model_urls. This pattern is used
    # to find such keys.
    pattern = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

    state_dict = load_state_dict_from_url(model_url, progress=progress)
    for key in list(state_dict.keys()):
        res = pattern.match(key)
        if res:
            new_key = res.group(1) + res.group(2)
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
    model.load_state_dict(state_dict)


def _densenet(
    arch: str,
    growth_rate: int,
    block_config: Tuple[int, int, int, int],
    num_init_features: int,
    pretrained: bool,
    progress: bool,
    **kwargs: Any
) -> DenseNet:
    """Build a ``DenseNet``; when ``pretrained`` is set, load the checkpoint at ``model_urls[arch]``.

    NOTE(review): ``model_urls`` is not defined in the visible notebook; presumably it comes
    from one of the star imports — confirm before calling with ``pretrained=True``.
    """
    net = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
    if not pretrained:
        return net
    _load_state_dict(net, model_urls[arch], progress)
    return net


def densenet121(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
    """Densenet-121 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.

    Growth rate 32, blocks of (6, 12, 24, 16) layers, 64 initial features.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
    """
    growth_rate, block_config, num_init_features = 32, (6, 12, 24, 16), 64
    return _densenet('densenet121', growth_rate, block_config, num_init_features,
                     pretrained, progress, **kwargs)

In [None]:
"""
Decoder Network
"""

# Module-level weight registry for the Decoder below; starts empty here.
# NOTE(review): presumably filled/consumed by Decoder's __conv / __batch_normalization
# helpers (defined outside this view) — confirm before relying on it.
_weights_dict = dict()

class Decoder(nn.Module):

  def __init__(self):
    super(Decoder, self).__init__()

    self.conv_SE_1_32_1 = self.__conv(2, name='conv_SE_1/32_1', in_channels=1024, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
    self.conv_SE_1_32_2 = self.__conv(2, name='conv_SE_1/32_2', in_channels=64, out_channels=1024, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
    self.conv_1_32_1d = self.__conv(2, name='conv_1/32_1d', in_channels=1024, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1, bias=False)
    self.prelu_1_32_1d = nn.PReLU(num_parameters=512)
    self.bn_1_32_1d = self.__batch_normalization(2, 'bn_1/32_1d', num_features=512, eps=9.999999747378752e-06, momentum=0.0)
    self.conv_1_32_2d = self.__conv(2, name='conv_1/32_2d', in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1, bias=False)
    self.prelu_1_32_2d = nn.PReLU(num_parameters=512)
    self.bn_1_32_2d = self.__batch_normalization(2, 'bn_1/32_2d', num_features=512, eps=9.999999747378752e-06, momentum=0.0)
    self.conv_1_32_3d = self.__conv(2, name='conv_1/32_3d', in_channels=512, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
    self.prelu_1_32_3d = nn.PReLU(num_parameters=256)
    self.bn_1_32_3d = self.__batch_normalization(2, 'bn_1/32_3d', num_features=256, eps=9.999999747378752e-06, momentum=0.0)
    self.deconv_1_16d = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=2, stride=2, bias=False)
    self.prelu_1_16d = nn.PReLU(num_parameters=256)
    self.bn_1_16d = self.__batch_normalization(2, 'bn_1/16d', num_features=256, eps=9.999999747378752e-06, momentum=0.0)
    self.conv_SE_1_16_1 = self.__conv(2, name='conv_SE_1/16_1', in_channels=1024, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
    self.conv_SE_1_16_2 = self.__conv(2, name='conv_SE_1/16_2', in_channels=64, out_channels=1024, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
    self.conv_SE_1_16 = self.__conv(2, name='conv_SE_1/16', in_channels=1024, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
    self.prelu_SE_1_16 = nn.PReLU(num_parameters=256)
    self.bn_SE_1_16 = self.__batch_normalization(2, 'bn_SE_1/16d', num_features=256, eps=9.999999747378752e-06, momentum=0.0)

    self.conv_1_16_1d = self.__conv(2, name='conv_1/16_1d', in_channels=512, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1, bias=False)
    self.prelu_1_16_1d = nn.PReLU(num_parameters=256)
    self.bn_1_16_1d = self.__batch_normalization(2, 'bn_1/16_1d', num_features=256, eps=9.999999747378752e-06, momentum=0.0)
    self.conv_1_16_2d = self.__conv(2, name='conv_1/16_2d', in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1, bias=False)
    self.prelu_1_16_2d = nn.PReLU(num_parameters=256)
    self.bn_1_16_2d = self.__batch_normalization(2, 'bn_1/16_2d', num_features=256, eps=9.999999747378752e-06, momentum=0.0)
    self.conv_1_16_3d = self.__conv(2, name='conv_1/16_3d', in_channels=256, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
    self.prelu_1_16_3d = nn.PReLU(num_parameters=128)
    self.bn_1_16_3d = self.__batch_normalization(2, 'bn_1/16_3d', num_features=128, eps=9.999999747378752e-06, momentum=0.0)
    self.deconv_1_8d = nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=2, stride=2, bias=False)
    self.prelu_1_8d = nn.PReLU(num_parameters=128)
    self.bn_1_8d = self.__batch_normalization(2, 'bn_1/8d', num_features=128, eps=9.999999747378752e-06, momentum=0.0)
    self.conv_SE_1_8_1 = self.__conv(2, name='conv_SE_1/8_1', in_channels=512, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
    self.conv_SE_1_8_2 = self.__conv(2, name='conv_SE_1/8_1', in_channels=32, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
    self.conv_SE_1_8 = self.__conv(2, name='conv_SE_1/8', in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
    self.prelu_SE_1_8 = nn.PReLU(num_parameters=128)
    self.bn_SE_1_8 = self.__batch_normalization(2, 'bn_SE_1/8d', num_features=128, eps=9.999999747378752e-06, momentum=0.0)

    self.conv_1_8_1d = self.__conv(2, name='conv_1/8_1d', in_channels=256, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1, bias=False)
    self.prelu_1_8_1d = nn.PReLU(num_parameters=128)
    self.bn_1_8_1d = self.__batch_normalization(2, 'bn_1/8_1d', num_features=128, eps=9.999999747378752e-06, momentum=0.0)
    self.conv_1_8_2d = self.__conv(2, name='conv_1/8_2d', in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1, bias=False)
    self.prelu_1_8_2d = nn.PReLU(num_parameters=128)
    self.bn_1_8_2d = self.__batch_normalization(2, 'bn_1/8_2d', num_features=128, eps=9.999999747378752e-06, momentum=0.0)
    self.conv_1_8_3d = self.__conv(2, name='conv_1/8_3d', in_channels=128, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
    self.prelu_1_8_3d = nn.PReLU(num_parameters=64)
    self.bn_1_8_3d = self.__batch_normalization(2, 'bn_1/8_3d', num_features=64, eps=9.999999747378752e-06, momentum=0.0)
    self.deconv_1_4d = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=2, stride=2, bias=False)
    self.prelu_1_4d = nn.PReLU(num_parameters=64)
    self.bn_1_4d = self.__batch_normalization(2, 'bn_1/4d', num_features=64, eps=9.999999747378752e-06, momentum=0.0)
    self.conv_SE_1_4_1 = self.__conv(2, name='conv_SE_1/4_1', in_channels=256, out_channels=16, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
    self.conv_SE_1_4_2 = self.__conv(2, name='conv_SE_1/4_2', in_channels=16, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
    self.conv_SE_1_4 = self.__conv(2, name='conv_SE_1/4', in_channels=256, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
    self.prelu_SE_1_4 = nn.PReLU(num_parameters=64)
    self.bn_SE_1_4 = self.__batch_normalization(2, 'bn_1/4d', num_features=64, eps=9.999999747378752e-06, momentum=0.0)

    self.conv_1_4_1d = self.__conv(2, name='conv_1/4_1d', in_channels=128, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1, bias=False)
    self.prelu_1_4_1d = nn.PReLU(num_parameters=64)
    self.bn_1_4_1d = self.__batch_normalization(2, 'bn_1/4_1d', num_features=64, eps=9.999999747378752e-06, momentum=0.0)
    self.conv_1_4_2d = self.__conv(2, name='conv_1/4_2d', in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1, bias=False)
    self.prelu_1_4_2d = nn.PReLU(num_parameters=64)
    self.bn_1_4_2d = self.__batch_normalization(2, 'bn_1/4_2d', num_features=64, eps=9.999999747378752e-06, momentum=0.0)
    self.conv_1_4_3d = self.__conv(2, name='conv_1/4_3d', in_channels=64, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
    self.prelu_1_4_3d = nn.PReLU(num_parameters=32)
    self.bn_1_4_3d = self.__batch_normalization(2, 'bn_1/4_3d', num_features=32, eps=9.999999747378752e-06, momentum=0.0)

    self.pred_1_4 = self.__conv(2, name='pred_1/4', in_channels=32, out_channels=1, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
    self.pred_step_1 = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=4, stride=4, bias=False)

    self.conv_atrous1_1 = self.__conv(2, name='conv_atrous1_1', in_channels=6, out_channels=32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1, bias=False)
    self.prelu_atrous1_1 = nn.PReLU(num_parameters=32)
    self.bn_atrous1_1 = self.__batch_normalization(2, 'bn_atrous1_1', num_features=32, eps=9.999999747378752e-06, momentum=0.0)
    self.conv_atrous1_2 = self.__conv(2, name='conv_atrous1_2', in_channels=32, out_channels=32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1, bias=False)
    self.prelu_atrous1_2 = nn.PReLU(num_parameters=32)
    self.bn_atrous1_2 = self.__batch_normalization(2, 'bn_atrous1_2', num_features=32, eps=9.999999747378752e-06, momentum=0.0)
    self.conv_atrous1_3 = self.__conv(2, name='conv_atrous1_3', in_channels=32, out_channels=16, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
    self.prelu_atrous1_3 = nn.PReLU(num_parameters=16)
    self.bn_atrous1_3 = self.__batch_normalization(2, 'bn_atrous1_3', num_features=16, eps=9.999999747378752e-06, momentum=0.0)

    self.conv_atrous2_1 = self.__conv(2, name='conv_atrous2_1', in_channels=6, out_channels=32, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), groups=1, bias=False)
    self.prelu_atrous2_1 = nn.PReLU(num_parameters=32)
    self.bn_atrous2_1 = self.__batch_normalization(2, 'bn_atrous2_1', num_features=32, eps=9.999999747378752e-06, momentum=0.0)
    self.conv_atrous2_2 = self.__conv(2, name='conv_atrous2_2', in_channels=32, out_channels=32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1, bias=False)
    self.prelu_atrous2_2 = nn.PReLU(num_parameters=32)
    self.bn_atrous2_2 = self.__batch_normalization(2, 'bn_atrous2_2', num_features=32, eps=9.999999747378752e-06, momentum=0.0)
    self.conv_atrous2_3 = self.__conv(2, name='conv_atrous2_3', in_channels=32, out_channels=16, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
    self.prelu_atrous2_3 = nn.PReLU(num_parameters=16)
    self.bn_atrous2_3 = self.__batch_normalization(2, 'bn_atrous2_3', num_features=16, eps=9.999999747378752e-06, momentum=0.0)

    self.conv_atrous3_1 = self.__conv(2, name='conv_atrous3_1', in_channels=6, out_channels=32, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3), groups=1, bias=False)
    self.prelu_atrous3_1 = nn.PReLU(num_parameters=32)
    self.bn_atrous3_1 = self.__batch_normalization(2, 'bn_atrous3_1', num_features=32, eps=9.999999747378752e-06, momentum=0.0)
    self.conv_atrous3_2 = self.__conv(2, name='conv_atrous3_2', in_channels=32, out_channels=32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1, bias=False)
    self.prelu_atrous3_2 = nn.PReLU(num_parameters=32)
    self.bn_atrous3_2 = self.__batch_normalization(2, 'bn_atrous3_2', num_features=32, eps=9.999999747378752e-06, momentum=0.0)
    self.conv_atrous3_3 = self.__conv(2, name='conv_atrous3_3', in_channels=32, out_channels=16, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
    self.prelu_atrous3_3 = nn.PReLU(num_parameters=16)
    self.bn_atrous3_3 = self.__batch_normalization(2, 'bn_atrous3_3', num_features=16, eps=9.999999747378752e-06, momentum=0.0)
    
    self.conv_s2_down = self.__conv(2, name='conv_s2_down', in_channels=48, out_channels=3, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
    self.conv_s2_up = self.__conv(2, name='conv_s2_up', in_channels=3, out_channels=48, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)
    self.conv_p1_1 = self.__conv(2, name='conv_p1_1', in_channels=48, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1, bias=False)
    self.prelu_p1_1 = nn.PReLU(num_parameters=16)
    self.bn_p1_1 = self.__batch_normalization(2, 'bn_p1_1', num_features=16, eps=9.999999747378752e-06, momentum=0.0)
    self.pred_step_2 = self.__conv(2, name='pred_step_2', in_channels=16, out_channels=1, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)

  def forward(self, dense_out):
    concat_input = dense_out[4]
    concat_5_16 = dense_out[3]
    concat_4_24 = dense_out[2]
    concat_3_12 = dense_out[1]
    concat_2_6 = dense_out[0]

    ################ BLOCK 1 ################      
    pool_SE_1_32    = F.avg_pool2d(concat_5_16, kernel_size=(15, 15), stride=(1, 1), padding=(0,), ceil_mode=False, count_include_pad=False)
    conv_SE_1_32_1  = self.conv_SE_1_32_1(pool_SE_1_32)
    relu_SE_1_32_1  = F.relu(conv_SE_1_32_1)
    conv_SE_1_32_2  = self.conv_SE_1_32_2(relu_SE_1_32_1)
    sigm_SE_1_32    = torch.sigmoid(conv_SE_1_32_2)
    reshape_SE_1_32 = torch.reshape(input = sigm_SE_1_32, shape = (1,1024,1,1))
    scale_SE_1_32   = concat_5_16 * reshape_SE_1_32

    conv_1_32_1d    = self.conv_1_32_1d(scale_SE_1_32)
    #prelu_1_32_1d   = F.prelu(conv_1_32_1d, torch.from_numpy(_weights_dict['prelu_1/32_1d']['weights']))
    prelu_1_32_1d   = self.prelu_1_32_1d(conv_1_32_1d)
    bn_1_32_1d      = self.bn_1_32_1d(prelu_1_32_1d)

    conv_1_32_2d    = self.conv_1_32_2d(bn_1_32_1d)
    #prelu_1_32_2d   = F.prelu(conv_1_32_2d, torch.from_numpy(_weights_dict['prelu_1/32_2d']['weights']))
    prelu_1_32_2d   = self.prelu_1_32_2d(conv_1_32_2d)
    bn_1_32_2d      = self.bn_1_32_2d(prelu_1_32_2d)

    conv_1_32_3d    = self.conv_1_32_3d(bn_1_32_2d)
    #prelu_1_32_3d   = F.prelu(conv_1_32_3d, torch.from_numpy(_weights_dict['prelu_1/32_3d']['weights']))
    prelu_1_32_3d   = self.prelu_1_32_3d(conv_1_32_3d)
    bn_1_32_3d      = self.bn_1_32_3d(prelu_1_32_3d)

    deconv_1_16d	  = self.deconv_1_16d(bn_1_32_3d)
    #prelu_1_16d     = F.prelu(deconv_1_16d, torch.from_numpy(_weights_dict['prelu_1/16d']['weights']))
    prelu_1_16d     = self.prelu_1_16d(deconv_1_16d)
    bn_1_16d        = self.bn_1_16d(prelu_1_16d)

    pool_SE_1_16    = F.avg_pool2d(concat_4_24, kernel_size=(30, 30), stride=(1, 1), padding=(0,), ceil_mode=False, count_include_pad=False)
    conv_SE_1_16_1  = self.conv_SE_1_16_1(pool_SE_1_16)
    relu_SE_1_16_1  = F.relu(conv_SE_1_16_1)
    conv_SE_1_16_2  = self.conv_SE_1_16_2(relu_SE_1_16_1)
    sigm_SE_1_16    = torch.sigmoid(conv_SE_1_16_2)
    reshape_SE_1_16 = torch.reshape(input = sigm_SE_1_16, shape = (1,1024,1,1))
    scale_SE_1_16   = concat_4_24 * reshape_SE_1_16
    conv_SE_1_16    = self.conv_SE_1_16(scale_SE_1_16)
    #prelu_SE_1_16    = F.prelu(conv_SE_1_8, torch.from_numpy(_weights_dict['prelu_SE_1/16']['weights']))
    prelu_SE_1_16   = self.prelu_SE_1_16(conv_SE_1_16)
    bn_SE_1_16      = self.bn_SE_1_16(prelu_SE_1_16)
    #########################################

    ################ BLOCK 2 ################
    concat_1_16d    = torch.cat((bn_1_16d, bn_SE_1_16,), 1)

    conv_1_16_1d    = self.conv_1_16_1d(concat_1_16d)
    #prelu_1_16_1d   = F.prelu(conv_1_16_1d, torch.from_numpy(_weights_dict['prelu_1/16_1d']['weights']))
    prelu_1_16_1d   = self.prelu_1_16_1d(conv_1_16_1d)
    bn_1_16_1d      = self.bn_1_16_1d(prelu_1_16_1d)

    conv_1_16_2d    = self.conv_1_16_2d(bn_1_16_1d)
    #prelu_1_16_2d   = F.prelu(conv_1_16_2d, torch.from_numpy(_weights_dict['prelu_1/16_2d']['weights']))
    prelu_1_16_2d   = self.prelu_1_16_2d(conv_1_16_2d)
    bn_1_16_2d      = self.bn_1_16_2d(prelu_1_16_2d)

    conv_1_16_3d    = self.conv_1_16_3d(bn_1_16_2d)
    #prelu_1_16_3d   = F.prelu(conv_1_16_3d, torch.from_numpy(_weights_dict['prelu_1/16_3d']['weights']))
    prelu_1_16_3d   = self.prelu_1_16_3d(conv_1_16_3d)
    bn_1_16_3d      = self.bn_1_16_3d(prelu_1_16_3d)

    deconv_1_8d		= self.deconv_1_8d(bn_1_16_3d)
    #prelu_1_8d      = F.prelu(deconv_1_8d, torch.from_numpy(_weights_dict['prelu_1/8d']['weights']))
    prelu_1_8d      = self.prelu_1_8d(deconv_1_8d)
    bn_1_8d         = self.bn_1_8d(prelu_1_8d)

    pool_SE_1_8     = F.avg_pool2d(concat_3_12, kernel_size=(60, 60), stride=(1, 1), padding=(0,), ceil_mode=False, count_include_pad=False)
    conv_SE_1_8_1   = self.conv_SE_1_8_1(pool_SE_1_8)        
    relu_SE_1_8_1   = F.relu(conv_SE_1_8_1)
    conv_SE_1_8_2   = self.conv_SE_1_8_2(relu_SE_1_8_1)
    sigm_SE_1_8     = torch.sigmoid(conv_SE_1_8_2)
    reshape_SE_1_8  = torch.reshape(input = sigm_SE_1_8, shape = (1,512,1,1))
    scale_SE_1_8    = concat_3_12 * reshape_SE_1_8
    conv_SE_1_8     = self.conv_SE_1_8(scale_SE_1_8)
    #prelu_SE_1_8    = F.prelu(conv_SE_1_8, torch.from_numpy(_weights_dict['prelu_SE_1/8']['weights']))
    prelu_SE_1_8    = self.prelu_SE_1_8(conv_SE_1_8)
    bn_SE_1_8       = self.bn_SE_1_8(prelu_SE_1_8)
    #########################################

    ################ BLOCK 3 ################																
    concat_1_8d     = torch.cat((bn_1_8d, bn_SE_1_8,), 1)

    conv_1_8_1d     = self.conv_1_8_1d(concat_1_8d)
    #prelu_1_8_1d    = F.prelu(conv_1_8_1d, torch.from_numpy(_weights_dict['prelu_1/8_1d']['weights']))
    prelu_1_8_1d    = self.prelu_1_8_1d(conv_1_8_1d)
    bn_1_8_1d       = self.bn_1_8_1d(prelu_1_8_1d)

    conv_1_8_2d     = self.conv_1_8_2d(bn_1_8_1d)
    #prelu_1_8_2d    = F.prelu(conv_1_8_2d, torch.from_numpy(_weights_dict['prelu_1/8_2d']['weights']))
    prelu_1_8_2d    = self.prelu_1_8_2d(conv_1_8_2d)
    bn_1_8_2d       = self.bn_1_8_2d(prelu_1_8_2d)

    conv_1_8_3d     = self.conv_1_8_3d(bn_1_8_2d)
    #prelu_1_8_3d    = F.prelu(conv_1_8_3d, torch.from_numpy(_weights_dict['prelu_1/8_3d']['weights']))
    prelu_1_8_3d    = self.prelu_1_8_3d(conv_1_8_3d)
    bn_1_8_3d       = self.bn_1_8_3d(prelu_1_8_3d)

    deconv_1_4d		= self.deconv_1_4d(bn_1_8_3d)
    #prelu_1_4d      = F.prelu(deconv_1_4d, torch.from_numpy(_weights_dict['prelu_1/4d']['weights']))
    prelu_1_4d      = self.prelu_1_4d(deconv_1_4d)
    bn_1_4d         = self.bn_1_4d(prelu_1_4d)

    pool_SE_1_4     = F.avg_pool2d(concat_2_6, kernel_size=(120, 120), stride=(1, 1), padding=(0,), ceil_mode=False, count_include_pad=False)
    conv_SE_1_4_1   = self.conv_SE_1_4_1(pool_SE_1_4)
    relu_SE_1_4_1   = F.relu(conv_SE_1_4_1)
    conv_SE_1_4_2   = self.conv_SE_1_4_2(relu_SE_1_4_1)
    sigm_SE_1_4     = torch.sigmoid(conv_SE_1_4_2)
    reshape_SE_1_4  = torch.reshape(input = sigm_SE_1_4, shape = (1,256,1,1))
    scale_SE_1_4    = concat_2_6 * reshape_SE_1_4
    conv_SE_1_4     = self.conv_SE_1_4(scale_SE_1_4)
    #prelu_SE_1_4    = F.prelu(conv_SE_1_4, torch.from_numpy(_weights_dict['prelu_SE_1/4']['weights']))
    prelu_SE_1_4    = self.prelu_SE_1_4(conv_SE_1_4)
    bn_SE_1_4       = self.bn_SE_1_4(prelu_SE_1_4)
    #########################################

    ################ BLOCK 4 ################														        
    concat_1_4d     = torch.cat((bn_1_4d, bn_SE_1_4,), 1)

    conv_1_4_1d     = self.conv_1_4_1d(concat_1_4d)
    #prelu_1_4_1d    = F.prelu(conv_1_4_1d, torch.from_numpy(_weights_dict['prelu_1/4_1d']['weights']))
    prelu_1_4_1d    = self.prelu_1_4_1d(conv_1_4_1d)
    bn_1_4_1d       = self.bn_1_4_1d(prelu_1_4_1d)

    conv_1_4_2d     = self.conv_1_4_2d(bn_1_4_1d)
    #prelu_1_4_2d    = F.prelu(conv_1_4_2d, torch.from_numpy(_weights_dict['prelu_1/4_2d']['weights']))
    prelu_1_4_2d    = self.prelu_1_4_2d(conv_1_4_2d)
    bn_1_4_2d       = self.bn_1_4_2d(prelu_1_4_2d)

    conv_1_4_3d     = self.conv_1_4_3d(bn_1_4_2d)
    #prelu_1_4_3d    = F.prelu(conv_1_4_3d, torch.from_numpy(_weights_dict['prelu_1/4_3d']['weights']))
    prelu_1_4_3d    = self.prelu_1_4_3d(conv_1_4_3d)
    bn_1_4_3d       = self.bn_1_4_3d(prelu_1_4_3d)
    #########################################

    ################ PREDICTION AT 1/4 ################
    pred_1_4        = self.pred_1_4(bn_1_4_3d)

    ################ UNSAMPLE THE PREDICTION FROM 1/4 TO 1/1 ################
    pred_step_1 	= self.pred_step_1(pred_1_4)
    sigp_step_1     = torch.sigmoid(pred_step_1)


    ################ SECONDARY NETWORK (FINE DECODER) STARTS HERE ################
    concat_step_1   = torch.cat((concat_input, sigp_step_1,), 1)

    ################ ATROUS POOLING BLOCK 1 ################
    conv_atrous1_1  = self.conv_atrous1_1(concat_step_1)
    #prelu_atrous1_1 = F.prelu(conv_atrous1_1, torch.from_numpy(_weights_dict['prelu_atrous1_1']['weights']))
    prelu_atrous1_1 = self.prelu_atrous1_1(conv_atrous1_1)
    bn_atrous1_1    = self.bn_atrous1_1(prelu_atrous1_1)

    conv_atrous1_2  = self.conv_atrous1_2(bn_atrous1_1)
    #prelu_atrous1_2 = F.prelu(conv_atrous1_2, torch.from_numpy(_weights_dict['prelu_atrous1_2']['weights']))
    prelu_atrous1_2 = self.prelu_atrous1_2(conv_atrous1_2)
    bn_atrous1_2    = self.bn_atrous1_2(prelu_atrous1_2)

    conv_atrous1_3  = self.conv_atrous1_3(bn_atrous1_2)
    #prelu_atrous1_3 = F.prelu(conv_atrous1_3, torch.from_numpy(_weights_dict['prelu_atrous1_3']['weights']))
    prelu_atrous1_3 = self.prelu_atrous1_3(conv_atrous1_3)
    bn_atrous1_3    = self.bn_atrous1_3(prelu_atrous1_3)
    ########################################################

    ################ ATROUS POOLING BLOCK 2 ################
    conv_atrous2_1  = self.conv_atrous2_1(concat_step_1)
    #prelu_atrous2_1 = F.prelu(conv_atrous2_1, torch.from_numpy(_weights_dict['prelu_atrous2_1']['weights']))
    prelu_atrous2_1 = self.prelu_atrous2_1(conv_atrous2_1)
    bn_atrous2_1    = self.bn_atrous2_1(prelu_atrous2_1)

    conv_atrous2_2  = self.conv_atrous2_2(bn_atrous2_1)
    #prelu_atrous2_2 = F.prelu(conv_atrous2_2, torch.from_numpy(_weights_dict['prelu_atrous2_2']['weights']))
    prelu_atrous2_2 = self.prelu_atrous2_2(conv_atrous2_2)
    bn_atrous2_2    = self.bn_atrous2_2(prelu_atrous2_2)

    conv_atrous2_3  = self.conv_atrous2_3(bn_atrous2_2)
    #prelu_atrous2_3 = F.prelu(conv_atrous2_3, torch.from_numpy(_weights_dict['prelu_atrous2_3']['weights']))
    prelu_atrous2_3 = self.prelu_atrous2_3(conv_atrous2_3)
    bn_atrous2_3    = self.bn_atrous2_3(prelu_atrous2_3)
    ########################################################

    ################ ATROUS POOLING BLOCK 3 ################
    conv_atrous3_1  = self.conv_atrous3_1(concat_step_1)
    #prelu_atrous3_1 = F.prelu(conv_atrous3_1, torch.from_numpy(_weights_dict['prelu_atrous3_1']['weights']))
    prelu_atrous3_1 = self.prelu_atrous3_1(conv_atrous3_1)
    bn_atrous3_1    = self.bn_atrous3_1(prelu_atrous3_1)

    conv_atrous3_2  = self.conv_atrous3_2(bn_atrous3_1)
    #prelu_atrous3_2 = F.prelu(conv_atrous3_2, torch.from_numpy(_weights_dict['prelu_atrous3_2']['weights']))
    prelu_atrous3_2 = self.prelu_atrous3_2(conv_atrous3_2)
    bn_atrous3_2    = self.bn_atrous3_2(prelu_atrous3_2)

    conv_atrous3_3  = self.conv_atrous3_3(bn_atrous3_2)
    #prelu_atrous3_3 = F.prelu(conv_atrous3_3, torch.from_numpy(_weights_dict['prelu_atrous3_3']['weights']))
    prelu_atrous3_3 = self.prelu_atrous3_3(conv_atrous3_3)
    bn_atrous3_3    = self.bn_atrous3_3(prelu_atrous3_3)
    ########################################################

    ################ CONCAT + SQUEEZ & EXCITATION ################
    concat_step_2   = torch.cat((bn_atrous1_3, bn_atrous2_3, bn_atrous3_3,), 1)
    gpool_s2        = F.avg_pool2d(concat_step_2, kernel_size=(480, 480), stride=(1, 1), padding=(0,), ceil_mode=False, count_include_pad=False)
    conv_s2_down    = self.conv_s2_down(gpool_s2)
    relu_s2_down    = F.relu(conv_s2_down)
    conv_s2_up      = self.conv_s2_up(relu_s2_down)
    sig_s2_up       = torch.sigmoid(conv_s2_up)
    resh_s2         = torch.reshape(input = sig_s2_up, shape = (1,48,1,1))
    scale_s2        = concat_step_2 * resh_s2
    ##############################################################

    ################ PREDICTION ################
    conv_p1_1       = self.conv_p1_1(scale_s2)
    #prelu_p1_1      = F.prelu(conv_p1_1, torch.from_numpy(_weights_dict['prelu_p1_1']['weights']))
    prelu_p1_1      = self.prelu_p1_1(conv_p1_1)
    bn_p1_1         = self.bn_p1_1(prelu_p1_1)
    pred_step_2     = self.pred_step_2(bn_p1_1)
    ############################################

    ################ PREDICTION ################
    sig_pred        = torch.sigmoid(pred_step_2)
    ############################################
    return sig_pred

    ################ SECONDARY NETWORK (FINE DECORDER) ENDS HERE ################

  @staticmethod
  def __batch_normalization(dim, name, **kwargs):
    if   dim == 0 or dim == 1:  layer = nn.BatchNorm1d(**kwargs)
    elif dim == 2:  layer = nn.BatchNorm2d(**kwargs)
    elif dim == 3:  layer = nn.BatchNorm3d(**kwargs)
    else:           raise NotImplementedError()
    """
    if 'scale' in _weights_dict[name]:
        layer.state_dict()['weight'].copy_(torch.from_numpy(_weights_dict[name]['scale']))
    else:
        layer.weight.data.fill_(1)

    if 'bias' in _weights_dict[name]:
        layer.state_dict()['bias'].copy_(torch.from_numpy(_weights_dict[name]['bias']))
    else:
        layer.bias.data.fill_(0)

    layer.state_dict()['running_mean'].copy_(torch.from_numpy(_weights_dict[name]['mean']))
    layer.state_dict()['running_var'].copy_(torch.from_numpy(_weights_dict[name]['var']))
    """
    return layer

  @staticmethod
  def __conv(dim, name, **kwargs):
    if   dim == 1:  layer = nn.Conv1d(**kwargs)
    elif dim == 2:  layer = nn.Conv2d(**kwargs)
    elif dim == 3:  layer = nn.Conv3d(**kwargs)
    else:           raise NotImplementedError()

    #layer.state_dict()['weight'].copy_(torch.from_numpy(_weights_dict[name]['weights']))
    #if 'bias' in _weights_dict[name]:
        #layer.state_dict()['bias'].copy_(torch.from_numpy(_weights_dict[name]['bias']))
    return layer

In [None]:
# Build the DenseNet-121 encoder with random weights (no ImageNet download)
# and cast all of its floating-point parameters and buffers to float64.
encoder = densenet121(pretrained=False).double()

In [None]:
# Build the decoder and cast it to float64 so its dtype matches the encoder's.
decoder = Decoder().to(dtype=torch.float64)

In [None]:
# Test Encoder: build one network input from an SBD sample (image, then the
# positive and negative click maps, stacked along the channel axis) and run it
# through the encoder.
# NOTE: `input` shadows the builtin; the name is kept in case later cells use it.
img, target = sbd[4]
pos_clicks, neg_clicks = gen_clicks(target)
pos_map, neg_map = convert_clicks(pos_clicks, neg_clicks)
input = torch.cat((img, pos_map, neg_map), 0)
input = input.unsqueeze(0).double()  # add a batch dim; match the float64 encoder

# Inference-only smoke test, so don't build an autograd graph.
with torch.no_grad():
    encoder_out, dense_out = encoder(input)

# The decoder reads the raw input as the last entry of dense_out.
dense_out.append(input)

In [None]:
# Test Decoder: decode the encoder features (with the raw input appended last)
# into a single-channel sigmoid map. Inference only, so no autograd graph.
with torch.no_grad():
    decoder_out = decoder(dense_out)
print(decoder_out.shape)

In [None]:
# Visualize Decoder output: plot the first sample's sigmoid map on a fixed
# [0, 1] grey scale. This avoids the uint8 quantisation of ToPILImage and
# stops matplotlib from auto-stretching the contrast of an untrained output.
fig, ax = plt.subplots()
ax.imshow(decoder_out[0, 0].detach().cpu().numpy(), cmap="gray", vmin=0.0, vmax=1.0)
ax.set_title("Decoder output (sigmoid)")
ax.axis("off")
plt.show()