Importing all necessary libraries

In [None]:

# Basic data manipulations
import pandas as pd
import numpy as np

# Handling images
from PIL import Image
import matplotlib.pyplot as plt

# Handling paths

import time

# Pytorch essentials
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision.datasets import ImageFolder


# Pytorch essentials for datasets.
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import Dataset, DataLoader

# Pytorch way of data augmentation.
import torchvision
from torchvision import datasets, models, transforms, utils
from torchvision.transforms import v2

import cv2
import os
from glob import glob
from tqdm import tqdm
import shutil
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import confusion_matrix , accuracy_score, classification_report
import seaborn as sns

import albumentations as A
from albumentations.pytorch import ToTensorV2
! pip install segmentation-models-pytorch
import segmentation_models_pytorch as smp


ISIC 2016 Dataloader

In [None]:

# Load the CSV listing for the ISIC 2016 training split and preview its first rows.
# The 'Image_Id' column of this table is read by a later cell to build file paths.
train_img_df = pd.read_csv('/kaggle/input/isic-2016-dataset/train_ISIC_2016.csv')
train_img_df.head()


In [None]:

# Load the CSV listing for the ISIC 2016 validation split and preview its first rows.
# The 'Image_Id' column of this table is read by a later cell to build file paths.
val_img_df = pd.read_csv('/kaggle/input/isic-2016-dataset/val_ISIC_2016.csv')
val_img_df.head()


In [None]:

# Load the CSV listing for the ISIC 2016 test split and preview its first rows.
# The 'Image_Id' column of this table is read by a later cell to build file paths.
test_img_df = pd.read_csv('/kaggle/input/isic-2016-dataset/test_ISIC_2016.csv')
test_img_df.head()


In [None]:

# Build image / ground-truth mask paths for the ISIC 2016 training split.
train_img = train_img_df['Image_Id'].tolist()

# Image file names from the listing are appended verbatim to the image directory.
train_img_paths = [
    "/kaggle/input/isic-2016-dataset/ISIC_2016_dataset/ISIC_2016_dataset/ISBI2016_ISIC_Part1_Training_Data/" + name
    for name in train_img
]
# Masks reuse the image stem plus "_Segmentation.png". os.path.splitext removes whatever
# extension the listing uses instead of assuming it is exactly four characters long.
train_mask_paths = [
    "/kaggle/input/isic-2016-dataset/ISIC_2016_dataset/ISIC_2016_dataset/ISBI2016_ISIC_Part1_Training_GroundTruth/"
    + os.path.splitext(name)[0] + "_Segmentation.png"
    for name in train_img
]
print(len(train_img_paths))
print(len(train_mask_paths))

# One row per sample: column "images" holds the image path, "masks" the matching mask path.
train_df = pd.DataFrame({"images": train_img_paths, "masks": train_mask_paths})
train_df.head()


In [None]:

# Build image / ground-truth mask paths for the ISIC 2016 validation split.
# The validation files are read from the *Training* data folders (as in the original cell).
val_img = val_img_df['Image_Id'].tolist()

# Image file names from the listing are appended verbatim to the image directory.
val_img_paths = [
    "/kaggle/input/isic-2016-dataset/ISIC_2016_dataset/ISIC_2016_dataset/ISBI2016_ISIC_Part1_Training_Data/" + name
    for name in val_img
]
# Masks reuse the image stem plus "_Segmentation.png". os.path.splitext removes whatever
# extension the listing uses instead of assuming it is exactly four characters long.
val_mask_paths = [
    "/kaggle/input/isic-2016-dataset/ISIC_2016_dataset/ISIC_2016_dataset/ISBI2016_ISIC_Part1_Training_GroundTruth/"
    + os.path.splitext(name)[0] + "_Segmentation.png"
    for name in val_img
]
print(len(val_img_paths))
print(len(val_mask_paths))

# One row per sample: column "images" holds the image path, "masks" the matching mask path.
val_df = pd.DataFrame({"images": val_img_paths, "masks": val_mask_paths})
val_df.head()


In [None]:

# Build image / ground-truth mask paths for the ISIC 2016 test split.
test_img = test_img_df['Image_Id'].tolist()

# Image file names from the listing are appended verbatim to the image directory.
test_img_paths = [
    "/kaggle/input/isic-2016-dataset/ISIC_2016_dataset/ISIC_2016_dataset/ISBI2016_ISIC_Part1_Test_Data/" + name
    for name in test_img
]
# Masks reuse the image stem plus "_Segmentation.png". os.path.splitext removes whatever
# extension the listing uses instead of assuming it is exactly four characters long.
test_mask_paths = [
    "/kaggle/input/isic-2016-dataset/ISIC_2016_dataset/ISIC_2016_dataset/ISBI2016_ISIC_Part1_Test_GroundTruth/"
    + os.path.splitext(name)[0] + "_Segmentation.png"
    for name in test_img
]
print(len(test_img_paths))
print(len(test_mask_paths))

# One row per sample: column "images" holds the image path, "masks" the matching mask path.
test_df = pd.DataFrame({"images": test_img_paths, "masks": test_mask_paths})
test_df.head()


CPM-17 Dataloader

In [None]:
# Collect training image (*.bmp) and mask (*.png) paths. Both lists are sorted, so row i of the
# DataFrame pairs the i-th image with the i-th mask (assumes matching file-name order — TODO confirm).
# NOTE(review): the section heading says CPM-17 but these paths point at
# Cervical_Segmentation_dataset under /kaggle/working — confirm which dataset is intended.
# This cell overwrites any `train_df` built by an earlier dataset section.
train_img_paths = sorted(glob('/kaggle/working/Cervical_Segmentation_dataset/train/images/*.bmp'))
train_mask_paths = sorted(glob('/kaggle/working/Cervical_Segmentation_dataset/train/masks/*.png'))
train_df = pd.DataFrame({"images":train_img_paths,"masks":train_mask_paths})
train_df.head()


In [None]:
# Collect validation image (*.bmp) and mask (*.png) paths. Both lists are sorted, so row i of the
# DataFrame pairs the i-th image with the i-th mask (assumes matching file-name order — TODO confirm).
# Note the list variables here are named `valid_*` while the other dataset sections use `val_*`.
# This cell overwrites any `val_df` built by an earlier dataset section.
valid_img_paths = sorted(glob('/kaggle/working/Cervical_Segmentation_dataset/val/images/*.bmp'))
valid_mask_paths = sorted(glob('/kaggle/working/Cervical_Segmentation_dataset/val/masks/*.png'))
val_df = pd.DataFrame({"images":valid_img_paths,"masks":valid_mask_paths})
val_df.head()


In [None]:
# Collect test image (*.bmp) and mask (*.png) paths. Both lists are sorted, so row i of the
# DataFrame pairs the i-th image with the i-th mask (assumes matching file-name order — TODO confirm).
# This cell overwrites any `test_df` built by an earlier dataset section.
test_img_paths = sorted(glob('/kaggle/working/Cervical_Segmentation_dataset/test/images/*.bmp'))
test_mask_paths = sorted(glob('/kaggle/working/Cervical_Segmentation_dataset/test/masks/*.png'))
test_df = pd.DataFrame({"images":test_img_paths,"masks":test_mask_paths})
test_df.head()


MonuSeg Dataloader

In [None]:

# Collect MonuSeg training image and mask paths (*.png). Both lists are sorted, so row i of the
# DataFrame pairs the i-th image with the i-th mask (assumes matching file-name order — TODO confirm).
# This cell overwrites any `train_df` built by an earlier dataset section.
train_img_paths = sorted(glob('/kaggle/input/monuseg/MonuSeg/Training/Images/*.png'))
train_mask_paths = sorted(glob('/kaggle/input/monuseg/MonuSeg/Training/Masks/*.png'))
train_df = pd.DataFrame({"images":train_img_paths,"masks":train_mask_paths})
train_df.head()


In [None]:

# Collect MonuSeg validation image and mask paths (*.png). Both lists are sorted, so row i of the
# DataFrame pairs the i-th image with the i-th mask (assumes matching file-name order — TODO confirm).
# This cell overwrites any `val_df` built by an earlier dataset section.
val_img_paths = sorted(glob('/kaggle/input/monuseg/MonuSeg/Val/Images/*.png'))
val_mask_paths = sorted(glob('/kaggle/input/monuseg/MonuSeg/Val/Masks/*.png'))
val_df = pd.DataFrame({"images":val_img_paths,"masks":val_mask_paths})
val_df.head()


In [None]:


# Collect MonuSeg test image and mask paths (*.png). Both lists are sorted, so row i of the
# DataFrame pairs the i-th image with the i-th mask (assumes matching file-name order — TODO confirm).
# This cell overwrites any `test_df` built by an earlier dataset section.
test_img_paths = sorted(glob('/kaggle/input/monuseg/MonuSeg/Test/Images/*.png'))
test_mask_paths = sorted(glob('/kaggle/input/monuseg/MonuSeg/Test/Masks/*.png'))
test_df = pd.DataFrame({"images":test_img_paths,"masks":test_mask_paths})
test_df.head()


Cervical_Cancer_Segmentation_Dataset Loader

In [None]:

# Collect cervical-cancer training image and mask paths (*.bmp). Both lists are sorted, so row i of
# the DataFrame pairs the i-th image with the i-th mask (assumes matching file-name order — TODO confirm).
# This cell overwrites any `train_df` built by an earlier dataset section.
train_img_paths = sorted(glob('/kaggle/input/cervical-cancer-segmentation/CervicalCancer/train/images/*.bmp'))
train_mask_paths = sorted(glob('/kaggle/input/cervical-cancer-segmentation/CervicalCancer/train/masks/*.bmp'))
train_df = pd.DataFrame({"images":train_img_paths,"masks":train_mask_paths})
train_df.head()


In [None]:

# Collect cervical-cancer validation image and mask paths (*.bmp). Both lists are sorted, so row i of
# the DataFrame pairs the i-th image with the i-th mask (assumes matching file-name order — TODO confirm).
# This cell overwrites any `val_df` built by an earlier dataset section.
val_img_paths = sorted(glob('/kaggle/input/cervical-cancer-segmentation/CervicalCancer/val/images/*.bmp'))
val_mask_paths = sorted(glob('/kaggle/input/cervical-cancer-segmentation/CervicalCancer/val/masks/*.bmp'))
val_df = pd.DataFrame({"images":val_img_paths,"masks":val_mask_paths})
val_df.head()


In [None]:

# Collect cervical-cancer test image and mask paths (*.bmp). Both lists are sorted, so row i of
# the DataFrame pairs the i-th image with the i-th mask (assumes matching file-name order — TODO confirm).
# This cell overwrites any `test_df` built by an earlier dataset section.
test_img_paths = sorted(glob('/kaggle/input/cervical-cancer-segmentation/CervicalCancer/test/images/*.bmp'))
test_mask_paths = sorted(glob('/kaggle/input/cervical-cancer-segmentation/CervicalCancer/test/masks/*.bmp'))
test_df = pd.DataFrame({"images":test_img_paths,"masks":test_mask_paths})
test_df.head()


Showing some images and their corresponding Ground Truth Masks

In [None]:

# Preview a few random training samples: each image is followed by its ground-truth mask.
show_imgs = 4
# Row *positions* (0 .. len-1) of the samples to display, drawn without repetition.
idx = np.random.choice(len(train_df), show_imgs, replace=False)
# Two panels per sample (image + mask), laid out four panels per row.
fig, axes = plt.subplots(show_imgs*2//4, 4, figsize=(15, 8))
axes = axes.flatten()
for i, ax in enumerate(axes):
    new_i = i//2  # panels 2k and 2k+1 both belong to sample k
    # idx holds positions, so index positionally (.iloc); label-based .loc only
    # happens to work when the DataFrame still has its default 0..n-1 index.
    row = train_df.iloc[idx[new_i]]
    if i % 2 == 0:
        full_path = row['images']
        basename = os.path.basename(full_path)
    else:
        full_path = row['masks']
        basename = os.path.basename(full_path) + ' -mask'
    ax.imshow(plt.imread(full_path))
    ax.set_title(basename)
    ax.set_axis_off()
    

Image resizing and dataset Loading

In [None]:

# Training-time augmentation pipeline (albumentations): resize to 512x512, then random flips,
# random 90-degree rotations and a small shift/scale/rotate jitter. Normalisation and tensor
# conversion are left commented out here; the dataset class does its own scaling and conversion.
train_transforms = A.Compose([
    A.Resize(512, 512),
    # NOTE(review): the crop size equals the resize size above, so this crop looks like a no-op — confirm.
    A.RandomCrop(height=512, width=512, always_apply=True),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.01, scale_limit=(-0.04,0.04), rotate_limit=(-15,15), p=0.5),
    # A.Normalize(p=1.0),
    # ToTensorV2(),
])

# Evaluation-time pipeline: deterministic resize only, no augmentation.
test_transforms = A.Compose([
    A.Resize(512, 512),
    # ToTensorV2(),
])


import torch
import cv2
import numpy as np
import albumentations as A

class MyDataset(torch.utils.data.Dataset):
    """Segmentation dataset that reads image/mask pairs listed in a DataFrame.

    Args:
        dataframe: table with an ``images`` column and a ``masks`` column, each holding a file path.
        transforms_: optional albumentations-style callable, invoked as
            ``transforms_(image=..., mask=...)`` and expected to return a mapping with
            ``'image'`` and ``'mask'`` entries. When ``None`` the image and mask are used as read.

    Each item is a dict ``{'x': image, 'y': mask}``: ``x`` is a float tensor in channel-first
    layout with pixel values divided by 255, ``y`` is a float tensor holding the binarised
    mask (0 where the grey level is below 127, 1 elsewhere).

    Raises:
        FileNotFoundError: from ``__getitem__`` when OpenCV cannot read the image or the mask.
    """

    def __init__(self, dataframe, transforms_=None):
        self.df = dataframe
        # Joint image/mask augmentation; may be None (no augmentation).
        self.transforms_ = transforms_
        # Normalisation transform kept as an attribute; it is NOT applied in __getitem__.
        self.pre_normalize = v2.Compose([
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.resize = [512, 512]
        self.class_size = 2

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        img_path, mask_path = row['images'], row['masks']

        # cv2.imread returns None instead of raising when a file is missing or unreadable;
        # fail here with the offending path rather than with an opaque error further down.
        bgr = cv2.imread(img_path)
        if bgr is None:
            raise FileNotFoundError(f"Could not read image file: {img_path}")
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask file: {mask_path}")
        # Binarise: grey levels below 127 become background (0), everything else foreground (1).
        mask = np.where(mask < 127, 0, 1).astype(np.int16)

        # The constructor allows transforms_=None, so only augment when a pipeline was supplied.
        if self.transforms_ is not None:
            aug = self.transforms_(image=img, mask=mask)
            img, mask = aug['image'], aug['mask']

        img = img / 255  # scale pixel values by 1/255
        # HWC -> CHW for PyTorch.
        img = torch.tensor(img, dtype=torch.float).permute(2, 0, 1)
        # The mask keeps its 2-D (H, W) layout; no channel axis is added here.
        target = torch.tensor(mask, dtype=torch.float)

        return {'x': img, 'y': target}


Creating the Training and Validation DataLoaders

In [None]:

# Select the GPU only when CUDA is really usable. `is_available` must be *called*: the bare
# function object is always truthy, which would select "cuda" even on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training data gets the augmenting pipeline, validation data the deterministic one.
train_dataset = MyDataset(train_df, train_transforms)
val_dataset = MyDataset(val_df, test_transforms)

BATCH_SIZE = 4
# Only the training loader shuffles; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
print(f'len train: {len(train_df)}')
print(f'len val: {len(val_df)}')


Definition of Customized Convolutional and ReLU layers and Attention Modules

In [None]:

import torch
import torch.nn as nn

try:
    from inplace_abn import InPlaceABN
except ImportError:
    InPlaceABN = None


class Conv2dReLU(nn.Sequential):
    """Convolution followed by an optional normalisation layer and a ReLU.

    Args:
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution (and of the normalisation layer).
        kernel_size: convolution kernel size.
        padding: convolution padding (default 0).
        stride: convolution stride (default 1).
        use_batchnorm: truthy -> ``nn.BatchNorm2d``; ``"inplace"`` -> ``InPlaceABN`` (which also
            provides the activation, so the ReLU slot becomes ``nn.Identity``); falsy -> no
            normalisation. The convolution has a bias only when this is falsy.

    Raises:
        RuntimeError: ``use_batchnorm == "inplace"`` but ``inplace_abn`` could not be imported.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        use_batchnorm=True,
    ):
        wants_inplace_abn = use_batchnorm == "inplace"
        if wants_inplace_abn and InPlaceABN is None:
            raise RuntimeError(
                "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
                + "To install see: https://github.com/mapillary/inplace_abn"
            )

        # A following normalisation layer makes the conv bias redundant.
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )

        if wants_inplace_abn:
            # InPlaceABN fuses normalisation and activation, so no separate ReLU is needed.
            norm = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
            act = nn.Identity()
        else:
            norm = nn.BatchNorm2d(out_channels) if use_batchnorm else nn.Identity()
            act = nn.ReLU(inplace=True)

        super(Conv2dReLU, self).__init__(conv, norm, act)

class ECA(nn.Module):
    """Channel gate: global average pool -> grouped Conv1d -> sigmoid -> rescale the input.

    Args:
        in_channels: number of channels of the input feature map.
        kernel_size: Conv1d kernel width; an even value is bumped to the next odd number.

    The pooled vector is handed to the Conv1d as ``(B, C, 1)`` with ``groups=in_channels``, so
    every channel is filtered on its own over a length-1 sequence.

    NOTE(review): a later cell in this notebook defines another class named ``ECA``; whichever
    cell ran last is the one other code picks up.
    """

    def __init__(self, in_channels, kernel_size=3):
        super(ECA, self).__init__()
        self.in_channels = in_channels
        # Force an odd kernel width so the padding below keeps the sequence length unchanged.
        self.kernel_size = kernel_size + 1 if kernel_size % 2 == 0 else kernel_size

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(
            in_channels,
            in_channels,
            kernel_size=self.kernel_size,
            padding=self.kernel_size // 2,
            groups=in_channels,
            bias=False,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled per channel by the learned gate; ``x`` must be 4-D (B, C, H, W)."""
        n, c, _, _ = x.size()

        # Squeeze: one average per (sample, channel).
        pooled = self.avg_pool(x).view(n, c)

        # Per-channel Conv1d over a trailing length-1 axis, then squash to (0, 1).
        gate = self.conv(pooled.unsqueeze(2)).view(n, c, 1)
        gate = self.sigmoid(gate)

        # Broadcast the gate over the spatial dimensions.
        return x * gate.view(n, c, 1, 1)

class BitPlaneAttention(nn.Module):
    """Spatial/channel gate computed from the binary bit planes of the input.

    Args:
        in_channels: number of channels of the input feature map.
        bit_planes: how many low-order bits of the integer-cast input are extracted.

    The input is cast to int, split into ``bit_planes`` binary maps, mixed by a 1x1 convolution
    into ``in_channels`` maps, squashed with a sigmoid and multiplied onto the input.

    NOTE(review): a later cell in this notebook defines another class named
    ``BitPlaneAttention``; whichever cell ran last is the one other code picks up.
    """

    def __init__(self, in_channels, bit_planes=8):
        super(BitPlaneAttention, self).__init__()

        self.in_channels = in_channels
        self.bit_planes = bit_planes

        # 1x1 convolution that mixes all bit planes of all channels back down to in_channels maps.
        self.attention_conv = nn.Conv2d(in_channels * bit_planes, in_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def bit_plane_decomposition(self, x):
        """Return the binary bit planes of ``x.int()`` stacked on a new axis 1."""
        as_int = x.int()
        masks = [1 << bit for bit in range(self.bit_planes)]
        # (value & mask) // mask is 1 where the bit is set and 0 elsewhere.
        planes = [((as_int & mask) // mask).float() for mask in masks]
        return torch.stack(planes, dim=1)  # (B, bit_planes, C, H, W)

    def forward(self, x):
        """Return ``x`` multiplied element-wise by the bit-plane attention map."""
        n, _, h, w = x.size()

        # (B, bit_planes, C, H, W) -> (B, bit_planes * C, H, W)
        stacked = self.bit_plane_decomposition(x).view(n, -1, h, w)

        # (B, C, H, W) gate with values in (0, 1).
        gate = self.sigmoid(self.attention_conv(stacked))

        return x * gate

class ArgMax(nn.Module):
    """Module wrapper around ``torch.argmax``.

    Args:
        dim: dimension to reduce; ``None`` (the default) takes the argmax of the flattened input.
    """

    def __init__(self, dim=None):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Return the indices of the maximum values of ``x`` along ``self.dim``."""
        return x.argmax(dim=self.dim)


class Clamp(nn.Module):
    """Module wrapper around ``torch.clamp``.

    Args:
        min: lower bound (default 0).
        max: upper bound (default 1).
    """

    def __init__(self, min=0, max=1):
        super().__init__()
        self.min = min
        self.max = max

    def forward(self, x):
        """Return ``x`` with every value limited to ``[self.min, self.max]``."""
        return x.clamp(self.min, self.max)


class Activation(nn.Module):
    """Build an activation layer from a name (or a callable factory) and apply it in ``forward``.

    Args:
        name: ``None``/``"identity"``, ``"sigmoid"``, ``"softmax2d"`` (softmax over dim 1),
            ``"softmax"``, ``"logsoftmax"``, ``"tanh"``, ``"argmax"``, ``"argmax2d"`` (argmax over
            dim 1), ``"clamp"``, or any callable that returns the layer when called with ``params``.
        **params: forwarded to the chosen layer's constructor where the branch accepts them.

    Raises:
        ValueError: ``name`` is neither a known string nor callable.
    """

    def __init__(self, name, **params):
        super().__init__()
        self.activation = self._build(name, params)

    @staticmethod
    def _build(name, params):
        """Map ``name`` to a freshly constructed activation layer."""
        if name is None or name == "identity":
            return nn.Identity(**params)
        if name == "sigmoid":
            return nn.Sigmoid()
        if name == "softmax2d":
            return nn.Softmax(dim=1, **params)
        if name == "softmax":
            return nn.Softmax(**params)
        if name == "logsoftmax":
            return nn.LogSoftmax(**params)
        if name == "tanh":
            return nn.Tanh()
        if name == "argmax":
            return ArgMax(**params)
        if name == "argmax2d":
            return ArgMax(dim=1, **params)
        if name == "clamp":
            return Clamp(**params)
        # Checked last so that the string names above always win.
        if callable(name):
            return name(**params)
        raise ValueError(
            f"Activation should be callable/sigmoid/softmax/logsoftmax/tanh/"
            f"argmax/argmax2d/clamp/None; got {name}"
        )

    def forward(self, x):
        """Apply the configured activation to ``x``."""
        return self.activation(x)


class Attention(nn.Module):
    """Select an attention module by name and apply it in ``forward``.

    Args:
        name: ``None`` (identity, no attention), ``"eca"`` (``ECA``) or ``"bp"``
            (``BitPlaneAttention``).
        **params: forwarded to the selected module's constructor.

    Raises:
        ValueError: ``name`` is not one of the supported values.
    """

    def __init__(self, name, **params):
        super().__init__()

        # Reject unknown names up front; everything below is a supported choice.
        if name is not None and name not in ("eca", "bp"):
            raise ValueError("Attention {} is not implemented".format(name))

        if name is None:
            self.attention = nn.Identity(**params)
        elif name == "eca":
            self.attention = ECA(**params)
        else:
            self.attention = BitPlaneAttention(**params)

    def forward(self, x):
        """Apply the selected attention module to ``x``."""
        return self.attention(x)

Definition of Segmentation Head

In [None]:
import torch.nn as nn
#####from .modules import Activation


class SegmentationHead(nn.Sequential):
    """Final prediction head: convolution -> optional bilinear upsampling -> activation.

    Args:
        in_channels: channels of the decoder output.
        out_channels: channels of the produced mask.
        kernel_size: convolution kernel size; padding is ``kernel_size // 2``.
        activation: name/callable passed to ``Activation`` (``None`` means identity).
        upsampling: scale factor; upsampling is applied only when it is greater than 1.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1
    ):
        projection = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2
        )
        if upsampling > 1:
            resize = nn.UpsamplingBilinear2d(scale_factor=upsampling)
        else:
            resize = nn.Identity()
        super().__init__(projection, resize, Activation(activation))


class ClassificationHead(nn.Sequential):
    """Auxiliary classifier: global pooling -> flatten -> optional dropout -> linear -> activation.

    Args:
        in_channels: channels of the feature map fed to the head.
        classes: number of output classes.
        pooling: ``"avg"`` or ``"max"`` global pooling.
        dropout: dropout probability; a falsy value disables dropout.
        activation: name/callable passed to ``Activation`` (``None`` means identity).

    Raises:
        ValueError: ``pooling`` is not ``"max"`` or ``"avg"``.
    """

    def __init__(
        self, in_channels, classes, pooling="avg", dropout=0.2, activation=None
    ):
        if pooling not in ("max", "avg"):
            raise ValueError(
                "Pooling should be one of ('max', 'avg'), got {}.".format(pooling)
            )

        if pooling == "avg":
            pool_layer = nn.AdaptiveAvgPool2d(1)
        else:
            pool_layer = nn.AdaptiveMaxPool2d(1)

        if dropout:
            drop_layer = nn.Dropout(p=dropout, inplace=True)
        else:
            drop_layer = nn.Identity()

        super().__init__(
            pool_layer,
            nn.Flatten(),
            drop_layer,
            nn.Linear(in_channels, classes, bias=True),
            Activation(activation),
        )

Definition of Segmentation Model Abstract Architecture: Encoder, Decoder and Segmentation Head

In [None]:

import torch
from segmentation_models_pytorch.base import initialization as init
#####from . import initialization as init
#####from .hub_mixin import SMPHubMixin
from segmentation_models_pytorch.base.hub_mixin import SMPHubMixin

class SegmentationModel(torch.nn.Module, SMPHubMixin):
    """Base class that chains an encoder, a decoder and the prediction head(s).

    The methods below read ``self.encoder``, ``self.decoder``, ``self.segmentation_head`` and
    ``self.classification_head`` (which may be ``None``); this class does not create them.
    """

    def initialize(self):
        """Initialise the weights of the decoder and of every head that is present."""
        init.initialize_decoder(self.decoder)
        init.initialize_head(self.segmentation_head)
        aux_head = self.classification_head
        if aux_head is not None:
            init.initialize_head(aux_head)

    def check_input_shape(self, x):
        """Raise ``RuntimeError`` unless the last two dims of ``x`` are multiples of the encoder's output stride."""
        height, width = x.shape[-2:]
        stride = self.encoder.output_stride
        if height % stride == 0 and width % stride == 0:
            return

        # Round each side up to the next multiple of the stride for the hint in the message.
        padded_h = -(-height // stride) * stride
        padded_w = -(-width // stride) * stride
        raise RuntimeError(
            f"Wrong input shape height={height}, width={width}. Expected image height and width "
            f"divisible by {stride}. Consider pad your images to shape ({padded_h}, {padded_w})."
        )

    def forward(self, x):
        """Sequentially pass `x` through the model's encoder, decoder and heads.

        Returns the masks, or ``(masks, labels)`` when a classification head is present.
        """
        self.check_input_shape(x)

        features = self.encoder(x)
        masks = self.segmentation_head(self.decoder(*features))

        if self.classification_head is None:
            return masks

        # The classification head works on the deepest encoder feature map.
        return masks, self.classification_head(features[-1])

    @torch.no_grad()
    def predict(self, x):
        """Inference method. Switch model to `eval` mode, call `.forward(x)` with `torch.no_grad()`

        Args:
            x: 4D torch tensor with shape (batch_size, channels, height, width)

        Return:
            prediction: 4D torch tensor with shape (batch_size, classes, height, width)

        """
        if self.training:
            self.eval()

        return self.forward(x)

Definition of Decoder Block and Center Block

In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F

#####from segmentation_models_pytorch.base import modules as md

class DecoderBlock(nn.Module):
    """One decoder stage: 2x nearest upsampling, optional skip fusion, two 3x3 conv blocks.

    Args:
        in_channels: channels of the tensor coming from the previous (deeper) stage.
        skip_channels: channels of the encoder skip tensor (0 when the stage has no skip).
        out_channels: channels produced by the stage.
        use_batchnorm: forwarded to ``Conv2dReLU``.
        attention_type: forwarded to ``Attention`` (``None`` disables attention).
    """

    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        use_batchnorm=True,
        attention_type=None,
    ):
        super().__init__()
        fused_channels = in_channels + skip_channels
        conv_kwargs = dict(kernel_size=3, padding=1, use_batchnorm=use_batchnorm)

        self.conv1 = Conv2dReLU(fused_channels, out_channels, **conv_kwargs)
        # attention1 acts on the concatenated (upsampled + skip) tensor, before conv1.
        self.attention1 = Attention(attention_type, in_channels=fused_channels)
        self.conv2 = Conv2dReLU(out_channels, out_channels, **conv_kwargs)
        self.attention2 = Attention(attention_type, in_channels=out_channels)

    def forward(self, x, skip=None):
        """Upsample ``x``, fuse it with ``skip`` when given, and refine it with the two conv blocks."""
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if skip is not None:
            # attention1 is only applied when a skip tensor was concatenated.
            x = self.attention1(torch.cat([x, skip], dim=1))
        return self.attention2(self.conv2(self.conv1(x)))


class CenterBlock(nn.Sequential):
    """Two stacked 3x3 ``Conv2dReLU`` blocks (padding 1).

    Args:
        in_channels: input channels of the first block.
        out_channels: output channels of both blocks.
        use_batchnorm: forwarded to ``Conv2dReLU``.
    """

    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        shared = dict(kernel_size=3, padding=1, use_batchnorm=use_batchnorm)
        super().__init__(
            Conv2dReLU(in_channels, out_channels, **shared),
            Conv2dReLU(out_channels, out_channels, **shared),
        )


class UnetDecoder(nn.Module):
    """U-Net decoder: a chain of ``DecoderBlock`` stages fed by the encoder feature maps.

    Args:
        encoder_channels: channel counts of the encoder outputs, shallowest first; the first
            entry is dropped here (see the comment in ``__init__``).
        decoder_channels: output channels of each decoder stage; its length must equal ``n_blocks``.
        n_blocks: number of decoder stages.
        use_batchnorm: forwarded to every conv block.
        attention_type: forwarded to every ``DecoderBlock``.
        center: when true, a ``CenterBlock`` is applied to the deepest feature map first.

    Raises:
        ValueError: ``n_blocks`` differs from ``len(decoder_channels)``.
    """

    def __init__(
        self,
        encoder_channels,
        decoder_channels,
        n_blocks=5,
        use_batchnorm=True,
        attention_type=None,
        center=False,
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
            )

        # Drop the first entry (skip with the same spatial resolution as the input) and
        # reverse so that index 0 is the deepest encoder output.
        encoder_channels = encoder_channels[1:][::-1]

        # Per-stage channel bookkeeping: each stage consumes the previous stage's output plus
        # the matching skip; the last stage has no skip (0 channels).
        head_channels = encoder_channels[0]
        in_channels = [head_channels] + list(decoder_channels[:-1])
        skip_channels = list(encoder_channels[1:]) + [0]
        out_channels = decoder_channels

        if center:
            self.center = CenterBlock(
                head_channels, head_channels, use_batchnorm=use_batchnorm
            )
        else:
            self.center = nn.Identity()

        stages = []
        for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels):
            stages.append(
                DecoderBlock(
                    in_ch,
                    skip_ch,
                    out_ch,
                    use_batchnorm=use_batchnorm,
                    attention_type=attention_type,
                )
            )
        self.blocks = nn.ModuleList(stages)

    def forward(self, *features):
        """Decode the encoder ``features`` (shallowest first) into the final decoder feature map."""
        # Same re-ordering as in __init__: drop the first feature, deepest one first.
        ordered = features[1:][::-1]

        x = self.center(ordered[0])
        skips = ordered[1:]

        for depth, block in enumerate(self.blocks):
            # Stages beyond the available skips run without a skip connection.
            x = block(x, skips[depth] if depth < len(skips) else None)

        return x
        

Installing Pytorch Wavelets for using Wavelet Transforms

In [None]:
!rm -rf pytorch_wavelets


In [None]:
!pip install git+https://github.com/fbcotter/pytorch_wavelets


Definition of Wavelet Transform (HAAR)

In [None]:

import torch
import torch.nn.functional as F
from pytorch_wavelets import DWTForward, DWTInverse
def Wavelet_Transform(image):
    """Return the first output of a one-level Haar DWT of ``image``, resized back to the input size.

    Args:
        image: 4-D tensor; dims 2 and 3 are used as the target height and width.

    Returns:
        The first element returned by ``DWTForward`` (named ``Yl`` in the original code —
        presumably the low-pass band; confirm against pytorch_wavelets), bilinearly
        interpolated to ``image``'s height and width.
    """
    # A fresh transform is built on every call and moved to the input's device.
    dwt = DWTForward(J=1, wave='haar', mode='zero').to(image.device)

    # The second output (detail coefficients) is not used.
    coarse, _ = dwt(image)

    target_size = (image.shape[2], image.shape[3])
    return F.interpolate(coarse, size=target_size, mode='bilinear', align_corners=False)
    

Definition of the Encoder (MobileNetV2) with Fused Attention (FA), a fusion of the ECA and Bit-Plane (BA) attention modules

In [None]:

import torch
import torch.nn as nn
import numpy as np
import functools
import torch.utils.model_zoo as model_zoo
from pretrainedmodels.models.inceptionv4 import InceptionV4, pretrained_settings


class EncoderMixin:
    """Mixin that adds encoder bookkeeping to a backbone.

    Provides:
    - the channel counts of the returned feature maps (``out_channels``),
    - the effective output stride (``output_stride``),
    - patching of the first convolution for a different number of input channels,
    - conversion of late stages from strided to dilated convolutions.

    The host class is expected to define ``_out_channels``, ``_depth`` and ``get_stages()``.
    """

    _output_stride = 32

    @property
    def out_channels(self):
        """Channel counts of the feature maps produced for the configured depth."""
        last = self._depth + 1
        return self._out_channels[:last]

    @property
    def output_stride(self):
        """Smaller of the configured stride and ``2 ** depth``."""
        return min(self._output_stride, 2 ** self._depth)

    def set_in_channels(self, in_channels, pretrained=True):
        """Adapt the first convolution to ``in_channels`` inputs; a no-op for 3 channels."""
        if in_channels == 3:
            return

        self._in_channels = in_channels
        if self._out_channels[0] == 3:
            # The first "feature" is the input itself, so its channel count follows the input.
            self._out_channels = tuple([in_channels] + list(self._out_channels)[1:])

        patch_first_conv(self, new_in_channels=in_channels, pretrained=pretrained)

    def make_dilated(self, output_stride):
        """Replace strides with dilation in the late stages to reach ``output_stride`` (16 or 8)."""
        # (stage index, dilation rate) pairs to patch for each supported stride.
        if output_stride == 16:
            plan = [(5, 2)]
        elif output_stride == 8:
            plan = [(4, 2), (5, 4)]
        else:
            raise ValueError("Output stride must be 16 or 8.")

        self._output_stride = output_stride

        stages = self.get_stages()
        for stage_idx, dilation_rate in plan:
            replace_strides_with_dilation(stages[stage_idx], dilation_rate)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.nn as nn
class BitPlaneAttention(nn.Module):
    """Bit Plane Attention Block.

    As implemented here this is a channel gate: global average pooling, a 1x1 convolution that
    reduces the channels by a factor of 4, a 1x1 convolution that restores them (no non-linearity
    in between), and a sigmoid whose output rescales the input. No bit-plane decomposition is
    performed in this class.

    NOTE(review): an earlier cell in this notebook defines another class with the same name; this
    definition replaces it once this cell has run.

    Args:
        in_channels: number of channels of the input feature map.
    """

    def __init__(self, in_channels):
        super(BitPlaneAttention, self).__init__()
        reduced = in_channels // 4
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, reduced, kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(reduced, in_channels, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled per channel by the learned gate."""
        squeezed = self.avg_pool(x)
        gate = self.sigmoid(self.conv2(self.conv1(squeezed)))
        # The gate is one value per (sample, channel); expand_as spreads it over H x W.
        return x * gate.expand_as(x)

class ECA(nn.Module):
    """Efficient Channel Attention Block.

    A single-channel Conv1d slides across the channel axis of the globally averaged features;
    its sigmoid output rescales the input per channel.

    Args:
        in_channels: number of channels of the input feature map (used to size the kernel).
        gamma: divisor in the kernel-size formula.
        b: offset in the kernel-size formula.

    The kernel size is ``int(abs(in_channels / gamma + b))``, bumped to the next odd number.
    NOTE(review): this grows linearly with the channel count (e.g. 641 for 1280 channels with the
    defaults), whereas the ECA-Net paper derives it from log2(C) — confirm this is intended before
    changing it, since the kernel size fixes the shape of the stored weights.

    NOTE(review): an earlier cell in this notebook defines another class named ``ECA``; this
    definition replaces it once this cell has run.
    """

    def __init__(self, in_channels, gamma=2, b=1):
        super(ECA, self).__init__()
        self.gamma = gamma
        self.b = b

        k = int(abs((in_channels / self.gamma) + self.b))
        if k % 2 == 0:
            k += 1  # keep the kernel odd so the padding below preserves the length

        self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled per channel by the learned gate."""
        pooled = x.mean(dim=(2, 3), keepdim=True)    # (N, C, 1, 1)
        as_sequence = pooled.squeeze(-1).permute(0, 2, 1)    # (N, C, 1) -> (N, 1, C)
        mixed = self.conv(as_sequence)
        gate = self.sigmoid(mixed.permute(0, 2, 1).unsqueeze(-1))    # back to (N, C, 1, 1)
        return x * gate.expand_as(x)

class FusedAttention(nn.Module):
    """Fused Attention Block: the element-wise mean of the BitPlane and ECA attention outputs.

    Args:
        in_channels: number of channels of the input feature map, passed to both branches.
    """

    def __init__(self, in_channels):
        super(FusedAttention, self).__init__()
        self.bitplane_attention = BitPlaneAttention(in_channels)
        self.eca_attention = ECA(in_channels)

    def forward(self, x):
        """Run both attention branches on ``x`` and average their outputs."""
        bitplane_out = self.bitplane_attention(x)
        eca_out = self.eca_attention(x)
        return (bitplane_out + eca_out) / 2

class MobileNetV2Encoder(torchvision.models.MobileNetV2, EncoderMixin):
    """MobileNetV2 backbone used as a segmentation encoder with fused attention.

    Args:
        out_channels: channel counts of the returned feature maps, input first
            (e.g. ``(3, 16, 24, 32, 96, 1280)``).
        depth: number of backbone stages to run; ``depth + 1`` feature maps are returned.
        **kwargs: forwarded to ``torchvision.models.MobileNetV2``.

    One ``FusedAttention`` block is created per entry of ``out_channels[1:]``; the torchvision
    classifier head is removed.
    """

    def __init__(self, out_channels, depth=5, **kwargs):
        super().__init__(**kwargs)
        self._depth = depth
        self._out_channels = out_channels
        self._in_channels = 3
        self.attention_blocks = nn.ModuleList([FusedAttention(c) for c in self._out_channels[1:]])

        # Only feature maps are needed; drop the classification head.
        del self.classifier

    def get_stages(self):
        """Return the backbone split into six stages; stage 0 is an identity on the input."""
        return [
            nn.Identity(),
            self.features[:2],
            self.features[2:4],
            self.features[4:7],
            self.features[7:14],
            self.features[14:],
        ]

    def forward(self, x):
        """Return the list of ``depth + 1`` feature maps for input ``x``.

        Stage 0 is the input multiplied element-wise by ``Wavelet_Transform`` of itself; every
        later stage output is passed through its ``FusedAttention`` block.
        """
        # Build the stage list once; it does not change inside the loop.
        stages = self.get_stages()

        features = []
        for i in range(self._depth + 1):
            x = stages[i](x)

            if i == 0:
                # Input level: modulate the image with its wavelet-derived texture map.
                x = torch.mul(x, Wavelet_Transform(x))
            else:
                # Backbone stages: refine with the fused attention block of that stage.
                x = self.attention_blocks[i - 1](x)

            features.append(x)
        return features

    def load_state_dict(self, state_dict, **kwargs):
        """Load ``state_dict`` after dropping classification-head entries.

        The head was deleted in ``__init__``, so its weights would otherwise be reported as
        unexpected keys. torchvision's MobileNetV2 stores them as ``classifier.1.*``; the
        ``last_linear.*`` names handled by the original code are still removed as well.
        Note that the keys are popped from the mapping passed in (same as before).

        Returns:
            Whatever ``torch.nn.Module.load_state_dict`` returns (previously discarded).
        """
        for head_key in (
            "classifier.1.bias",
            "classifier.1.weight",
            "last_linear.bias",
            "last_linear.weight",
        ):
            state_dict.pop(head_key, None)
        return super().load_state_dict(state_dict, **kwargs)

# Utility Functions
def patch_first_conv(model, new_in_channels, default_in_channels=3, pretrained=True):
    """Change the input-channel count of the first matching convolution of ``model`` in place.

    The first ``nn.Conv2d`` whose ``in_channels`` equals ``default_in_channels`` is patched:

    - ``pretrained=False``: the weight is re-created with the new shape and re-initialised;
    - ``new_in_channels == 1``: the existing weight is summed over its input-channel axis;
    - otherwise: the existing input-channel slices are repeated cyclically and rescaled by
      ``default_in_channels / new_in_channels``.

    Args:
        model: module to search (searched with ``model.modules()``).
        new_in_channels: desired number of input channels.
        default_in_channels: input-channel count identifying the convolution to patch.
        pretrained: whether existing weights should be reused (see above).

    Raises:
        ValueError: no ``nn.Conv2d`` with ``default_in_channels`` inputs exists in ``model``.
    """
    for module in model.modules():
        if isinstance(module, nn.Conv2d) and module.in_channels == default_in_channels:
            break
    else:
        # Without this guard the loop variable would still point at the last visited module,
        # and an unrelated layer would be patched (or fail with an unrelated error).
        raise ValueError(
            f"No nn.Conv2d with in_channels={default_in_channels} found; cannot patch the first convolution."
        )

    weight = module.weight.detach()
    module.in_channels = new_in_channels
    new_shape = (module.out_channels, new_in_channels // module.groups, *module.kernel_size)

    if not pretrained:
        module.weight = nn.parameter.Parameter(torch.Tensor(*new_shape))
        module.reset_parameters()
    elif new_in_channels == 1:
        module.weight = nn.parameter.Parameter(weight.sum(1, keepdim=True))
    else:
        new_weight = torch.Tensor(*new_shape)
        for i in range(new_in_channels):
            # Cycle through the original input-channel slices.
            new_weight[:, i] = weight[:, i % default_in_channels]
        # Rescale by default/new channel count to compensate for the changed number of inputs.
        new_weight *= (default_in_channels / new_in_channels)
        module.weight = nn.parameter.Parameter(new_weight)


def replace_strides_with_dilation(module, dilation_rate):
    """Turn every ``nn.Conv2d`` inside ``module`` into a stride-1 dilated convolution, in place.

    Args:
        module: module whose sub-modules are patched.
        dilation_rate: dilation applied along both spatial dimensions; the padding is set to
            ``(kernel // 2) * dilation_rate`` per dimension.
    """
    convs = (layer for layer in module.modules() if isinstance(layer, nn.Conv2d))
    for conv in convs:
        k_h, k_w = conv.kernel_size
        conv.stride = (1, 1)
        conv.dilation = (dilation_rate, dilation_rate)
        conv.padding = (k_h // 2 * dilation_rate, k_w // 2 * dilation_rate)

        # Layers that carry their own `static_padding` attribute get it neutralised,
        # since the padding is now handled by the convolution itself.
        if hasattr(conv, "static_padding"):
            conv.static_padding = nn.Identity()


# Preprocessing
def preprocess_input(x, mean=None, std=None, input_space="RGB", input_range=None, **kwargs):
    """Convert an image array to the colour order, value range and statistics a backbone expects.

    Args:
        x: image array with the channels on the last axis.
        mean: per-channel mean to subtract, or ``None`` to skip.
        std: per-channel standard deviation to divide by, or ``None`` to skip.
        input_space: ``"BGR"`` reverses the channel axis; any other value leaves it untouched.
        input_range: when given with an upper bound of 1 and ``x`` contains values above 1,
            ``x`` is divided by 255.
        **kwargs: accepted and ignored.

    Returns:
        A new array; the caller's ``x`` is never modified. Floating-point inputs keep their
        dtype; other dtypes (e.g. ``uint8``) are promoted to ``float32`` before the mean/std
        arithmetic, which previously failed for such inputs because it ran in place.
    """
    if input_space == "BGR":
        x = x[..., ::-1].copy()

    if input_range is not None and x.max() > 1 and input_range[1] == 1:
        x = x / 255.0

    if mean is not None or std is not None:
        # Integer images cannot hold the normalised values; switch to float first.
        if not np.issubdtype(x.dtype, np.floating):
            x = x.astype(np.float32)

    if mean is not None:
        # Out-of-place on purpose: `x -= ...` would silently overwrite the caller's array.
        x = x - np.asarray(mean, dtype=x.dtype)

    if std is not None:
        x = x / np.asarray(std, dtype=x.dtype)

    return x


# Encoder retrieval functions
# Registry of available encoders, keyed by encoder name. Each entry holds:
#   "encoder"             - the class to instantiate,
#   "pretrained_settings" - per weight-set preprocessing statistics and the checkpoint URL,
#   "params"              - constructor keyword arguments for the encoder class.
encoders = {
    "mobilenet_v2": {
        "encoder": MobileNetV2Encoder,
        "pretrained_settings": {
            "imagenet": {
                "mean": [0.485, 0.456, 0.406],
                "std": [0.229, 0.224, 0.225],
                "url": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
                "input_space": "RGB",
                "input_range": [0, 1],
            }
        },
        # Channel counts of the feature maps returned by the encoder, input (3) first.
        "params": {"out_channels": (3, 16, 24, 32, 96, 1280)},
    }
}


def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
    """Instantiate a registered encoder and optionally load pretrained backbone weights.

    Args:
        name: key into the ``encoders`` registry.
        in_channels: number of input channels; values other than 3 patch the first convolution.
        depth: number of encoder stages, forwarded to the encoder constructor.
        weights: key into the entry's ``pretrained_settings`` (e.g. ``"imagenet"``) or a falsy
            value for random initialisation.
        output_stride: 32 keeps the backbone as is; other values are passed to ``make_dilated``.
        **kwargs: accepted and ignored.

    Returns:
        The configured encoder instance.

    Raises:
        KeyError: ``name`` is not registered, or ``weights`` is not available for that encoder.
    """
    if name not in encoders:
        raise KeyError(f"Unsupported encoder `{name}`, supported: {list(encoders.keys())}")

    entry = encoders[name]
    # Work on a copy so the shared registry entry is not modified by this call.
    params = dict(entry["params"], depth=depth)
    encoder = entry["encoder"](**params)

    if weights:
        available = entry["pretrained_settings"]
        if weights not in available:
            raise KeyError(
                f"Wrong pretrained weights `{weights}` for encoder `{name}`. "
                f"Available options are: {list(available.keys())}"
            )
        settings = available[weights]

        # Start from the encoder's own (freshly initialised) state and overwrite only the
        # entries that also exist in the checkpoint. Parameters that are absent from the
        # checkpoint therefore keep their initial values.
        pretrained_dict = model_zoo.load_url(settings["url"])
        model_dict = encoder.state_dict()

        # Keep only checkpoint entries the encoder actually has; the "se_blocks" name filter
        # is retained from the original code.
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and "se_blocks" not in k}
        model_dict.update(pretrained_dict)

        encoder.load_state_dict(model_dict)

    encoder.set_in_channels(in_channels, pretrained=weights is not None)
    if output_stride != 32:
        encoder.make_dilated(output_stride)

    return encoder



def get_preprocessing_params(encoder_name, pretrained="imagenet"):
    """Return the preprocessing settings registered for an encoder's weights.

    Args:
        encoder_name: key into ``encoders``.
        pretrained: key into that encoder's ``pretrained_settings``.

    Returns:
        Dict with ``input_space``, ``input_range``, ``mean`` and ``std``.
        List values are copies, so callers may modify them freely.

    Raises:
        KeyError: if ``encoder_name`` is not registered.
        ValueError: if ``pretrained`` is not a known weight set.
    """
    available = encoders[encoder_name]["pretrained_settings"]
    # An unknown key used to yield None here and fail later with an opaque
    # AttributeError; report the valid options instead.
    if pretrained not in available:
        raise ValueError(f"Available pretrained options {list(available.keys())}")
    settings = available[pretrained]
    return {
        "input_space": settings.get("input_space", "RGB"),
        "input_range": list(settings.get("input_range", [0, 1])),
        "mean": list(settings["mean"]),
        "std": list(settings["std"]),
    }


def get_preprocessing_fn(encoder_name, pretrained="imagenet"):
    """Return ``preprocess_input`` with the encoder's settings pre-bound.

    The settings come from ``get_preprocessing_params(encoder_name, pretrained)``
    and are bound as keyword arguments via ``functools.partial``.
    """
    return functools.partial(
        preprocess_input,
        **get_preprocessing_params(encoder_name, pretrained),
    )


Definition of UNet structure

In [None]:

from typing import Optional, Union, List
"""
from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import (
    SegmentationModel,
    SegmentationHead,
    ClassificationHead,
)
"""
#####from .decoder import UnetDecoder

class Unet(SegmentationModel):
    """U-Net segmentation model: an encoder, a ``UnetDecoder`` and a segmentation head.

    The encoder is built with ``get_encoder``; its ``out_channels`` are handed
    to the decoder, and the decoder's last channel count feeds a
    ``SegmentationHead``. An optional ``ClassificationHead`` is attached to
    the encoder's last output channels.

    Args:
        encoder_name: Name passed to ``get_encoder``. It must be a key of the
            ``encoders`` registry, otherwise ``get_encoder`` raises ``KeyError``.
        encoder_depth: Passed to ``get_encoder`` as ``depth`` and to the
            decoder as ``n_blocks``. Default is 5.
        encoder_weights: Passed to ``get_encoder`` as ``weights``; ``None``
            skips loading pretrained weights.
        decoder_use_batchnorm: Passed to the decoder as ``use_batchnorm``.
        decoder_channels: Channel counts passed to the decoder; the last entry
            is the input channel count of the segmentation head.
        decoder_attention_type: Passed to the decoder as ``attention_type``.
        in_channels: Number of input channels, default is 3.
        classes: Number of output channels of the segmentation head.
        activation: Passed to the segmentation head as ``activation``.
        aux_params: When not ``None``, keyword arguments for a
            ``ClassificationHead`` built on the encoder's last output channels;
            when ``None`` no classification head is created.

    Reference:
        U-Net: https://arxiv.org/abs/1505.04597
    """

    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_use_batchnorm: bool = True,
        decoder_channels: List[int] = (256, 128, 64, 32, 16),
        decoder_attention_type: Optional[str] = None,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, callable]] = None,
        aux_params: Optional[dict] = None,
    ):
        super().__init__()

        # Sub-modules are assigned in a fixed order: encoder, decoder,
        # segmentation head, then the optional classification head.
        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
        )

        # Only encoders whose name starts with "vgg" get a center block.
        uses_center_block = encoder_name.startswith("vgg")
        self.decoder = UnetDecoder(
            encoder_channels=self.encoder.out_channels,
            decoder_channels=decoder_channels,
            n_blocks=encoder_depth,
            use_batchnorm=decoder_use_batchnorm,
            center=uses_center_block,
            attention_type=decoder_attention_type,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=decoder_channels[-1],
            out_channels=classes,
            activation=activation,
            kernel_size=3,
        )

        self.classification_head = (
            ClassificationHead(in_channels=self.encoder.out_channels[-1], **aux_params)
            if aux_params is not None
            else None
        )

        self.name = f"u-{encoder_name}"
        self.initialize()


In [None]:
!pip install torchinfo

In [None]:
!pip install torchsummary

In [None]:
import ssl
# Point the ssl module's private HTTPS context factory at
# ssl.create_default_context.
# NOTE(review): this looks like it is already the standard-library default, so
# the assignment is presumably a no-op - confirm it is still needed for the
# pretrained-weight download.
ssl._create_default_https_context = ssl.create_default_context


Displaying the architecture of our proposed model

In [None]:
import torch
from torchinfo import summary

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class_size = 1
model = Unet(
    encoder_name="mobilenet_v2",        # must be a key of the `encoders` registry
    encoder_weights="imagenet",         # use `imagenet` pre-trained weights for encoder initialization
    #decoder_attention_type="Defattn",  # optional decoder attention, currently disabled
    in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=class_size                      # model output channels (number of classes in your dataset)
    #activation="softmax"           # no activation: the model outputs raw logits
).to(device)

# Print a layer-by-layer summary for a single 3-channel 512x512 input.
input_size = (1, 3, 512, 512)  # Batch size of 1, 3 channels, 512x512 image
summary(model, input_size=input_size, device=device)


Definition of the evaluation metric (Dice score) and the training and validation loops

In [None]:

# Dice score implementation
def dice_score(pred, target, smooth=1e-6):
    """Dice coefficient between two tensors after thresholding both at 0.5.

    Args:
        pred: tensor of scores; values above 0.5 count as foreground.
        target: tensor of labels; values above 0.5 count as foreground.
        smooth: constant added to numerator and denominator.

    Returns:
        A scalar tensor, computed over all elements at once.
    """
    pred_mask = pred > 0.5
    target_mask = target > 0.5
    overlap = (pred_mask & target_mask).sum().float()
    total = pred_mask.sum() + target_mask.sum()
    return (2. * overlap + smooth) / (total + smooth)


def train(dataloader, model, loss_fn, optimizer, lr_scheduler):
    """Run one training epoch.

    Args:
        dataloader: yields dict batches with inputs under ``'x'`` and targets
            under ``'y'``.
        model: network to optimise; switched to train mode.
        loss_fn: callable ``loss_fn(pred, y)`` applied to the raw model output.
        optimizer: optimizer stepped once per batch.
        lr_scheduler: scheduler stepped once per batch.

    Returns:
        Tuple ``(loss, dice, iou)``, each averaged over the number of batches.
    """
    num_batches = len(dataloader)
    model.train()
    epoch_loss = 0
    epoch_iou_score = 0
    epoch_dice_score = 0

    for batch in dataloader:
        x, y = batch['x'].to(device), batch['y'].to(device)

        optimizer.zero_grad()
        pred = model(x)
        loss = loss_fn(pred, y)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        # The metrics below are for logging only, so compute them without
        # autograd tracking.
        with torch.no_grad():
            prob = torch.sigmoid(pred).squeeze(dim=1)
            y_bin = y.round().long()

            # Dice score
            epoch_dice_score += dice_score(prob, y_bin).item()

            # IoU score
            tp, fp, fn, tn = smp.metrics.get_stats(prob, y_bin, mode='binary', threshold=0.5)
            epoch_iou_score += smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro").item()

        # NOTE(review): the scheduler advances once per batch, not per epoch,
        # so a StepLR `step_size` counts batches here - confirm that is intended.
        lr_scheduler.step()

    return epoch_loss / num_batches, epoch_dice_score / num_batches, epoch_iou_score / num_batches

def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        dataloader: yields dict batches with inputs under ``'x'`` and targets
            under ``'y'``.
        model: network to evaluate; switched to eval mode.
        loss_fn: callable ``loss_fn(pred, y)`` applied to the raw model output.

    Returns:
        Tuple ``(loss, dice, iou)``, each averaged over the number of batches.
    """
    dataset_size = len(dataloader.dataset)  # looked up as before; not used below
    n_batches = len(dataloader)
    model.eval()

    loss_total = 0
    iou_total = 0
    dice_total = 0

    with torch.no_grad():
        for batch in dataloader:
            inputs = batch['x'].to(device)
            targets = batch['y'].to(device)

            logits = model(inputs)
            loss_total += loss_fn(logits, targets).item()

            # Probabilities without the channel dim, and integer targets.
            probs = torch.sigmoid(logits).squeeze(dim=1)
            targets = targets.round().long()

            dice_total += dice_score(probs, targets).item()

            stats = smp.metrics.get_stats(probs, targets, mode='binary', threshold=0.5)
            iou_total += smp.metrics.iou_score(*stats, reduction="micro").item()

    return loss_total / n_batches, dice_total / n_batches, iou_total / n_batches


Definition of loss function

In [None]:

def loss_fn(y_Pred, y_True, smooth=1e-6):
    """Binary Dice loss between model output and ground truth.

    A new ``smp.losses.DiceLoss(mode="binary")`` is created on every call.
    ``smooth`` is accepted but not forwarded to it.
    """
    return smp.losses.DiceLoss(mode="binary")(y_Pred, y_True)
    

Model Training

In [None]:

EPOCHS = 100
# Per-epoch history of losses and metrics; read by the plotting cell.
logs = {
    'train_loss': [], 'val_loss': [],
    'train_iou_score': [], 'val_iou_score': [],
    'train_dice_score': [], 'val_dice_score': []
}

# Create the checkpoint directory if needed (no error when it already exists).
os.makedirs("checkpoints", exist_ok=True)

learning_rate = 0.001
optimizer = torch.optim.NAdam(model.parameters(), lr=learning_rate)
# NOTE(review): train() steps the scheduler once per batch, so `step_size`
# counts batches rather than epochs - confirm that is intended.
step_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size = 100, gamma=0.1)

# Early stopping: stop after `patience` consecutive epochs without a new best
# validation Dice score.
patience = 5
counter = 0
best_dice_score=-1

model.to(device)
for epoch in tqdm(range(EPOCHS)):
    train_loss, train_dice_score, train_iou_score = train(train_loader, model, loss_fn, optimizer, step_lr_scheduler)
    val_loss, val_dice_score, val_iou_score = test(val_loader, model, loss_fn)

    logs['train_loss'].append(train_loss)
    logs['val_loss'].append(val_loss)
    logs['train_dice_score'].append(train_dice_score)
    logs['val_dice_score'].append(val_dice_score)
    logs['train_iou_score'].append(train_iou_score)
    logs['val_iou_score'].append(val_iou_score)

    print(f'EPOCH: {str(epoch+1).zfill(3)} | '
          f'train_loss: {train_loss:.4f}, val_loss: {val_loss:.4f} | '
          f'train_dice_score: {train_dice_score:.3f}, val_dice_score: {val_dice_score:.3f} | '
          f'train_iou_score: {train_iou_score:.3f}, val_iou_score: {val_iou_score:.3f} | '
          f'lr: {optimizer.param_groups[0]["lr"]}')

    # Always keep the most recent weights.
    torch.save(model.state_dict(), "checkpoints/last.pth")

    # Keep a separate copy of the weights with the best validation Dice score.
    if val_dice_score > best_dice_score:
        counter = 0
        best_dice_score = val_dice_score
        torch.save(model.state_dict(), "checkpoints/best.pth")
    else:
        counter += 1

    # Early stopping
    if counter >= patience:
        print("Early stopping!")
        break


Displaying the loss curve and Dice Score curve

In [None]:

# Side-by-side training curves: loss on the left, Dice score on the right.
# (`logs` also holds the IoU history under 'train_iou_score' / 'val_iou_score'.)
fig, (loss_ax, dice_ax) = plt.subplots(1, 2, figsize=(15, 5))

loss_ax.plot(logs['train_loss'], label='Train_Loss')
loss_ax.plot(logs['val_loss'], label='Validation_Loss')
loss_ax.set_title('Train_Loss & Validation_Loss', fontsize=20)
loss_ax.legend()

dice_ax.plot(logs['train_dice_score'], label='Train_Dice_Score')
dice_ax.plot(logs['val_dice_score'], label='Validation_Dice_Score')
dice_ax.set_title('Train_Dice_score & Validation_Dice_score', fontsize=20)
dice_ax.legend()


In [None]:

# Test-time pipeline: only resize image and mask to 512x512.
# Tensor conversion happens inside TestDataset, so ToTensorV2 stays disabled.
test_transforms = A.Compose([
    A.Resize(512, 512),
    # ToTensorV2(),
])


In [None]:

class TestDataset(torch.utils.data.Dataset):
    """Image/mask dataset for evaluation.

    Args:
        dataframe: frame whose ``'images'`` and ``'masks'`` columns hold the
            file paths passed to ``cv2.imread``.
        transforms_: optional callable invoked as
            ``transforms_(image=img, mask=mask)`` and returning a mapping with
            ``'image'`` and ``'mask'``. ``None`` leaves both unchanged.

    Each item is a dict with:
        ``'x'``: float tensor, channels first, pixel values divided by 255.
        ``'y'``: long tensor, 1 where the mask is >= 127 and 0 elsewhere.
        ``'img_view'`` / ``'mask_view'``: copies of the transformed image and
            mask, taken before scaling and binarisation.
    """

    def __init__(self, dataframe,transforms_=None):
        self.df = dataframe
        self.transforms_ = transforms_
        # NOTE: pre_normalize, resize and class_size are not used by
        # __getitem__; they are kept so existing attribute access still works.
        self.pre_normalize = v2.Compose([
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.resize = [512, 512]
        self.class_size = 2

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]

        # cv2.imread returns None instead of raising when a file is missing or
        # unreadable; fail here with the offending path.
        img = cv2.imread(row['images'])
        if img is None:
            raise FileNotFoundError(f"Could not read image: {row['images']}")
        mask = cv2.imread(row['masks'], cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {row['masks']}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # transforms_ defaults to None, so only apply it when one was given.
        if self.transforms_ is not None:
            aug = self.transforms_(image=img, mask=mask)
            img, mask = aug['image'], aug['mask']

        img_view = np.copy(img)
        img = img/255
        # img = self.pre_normalize(img)
        img = torch.tensor(img, dtype=torch.float).permute(2, 0, 1)
        mask_view = np.copy(mask)
        mask = np.where(mask<127, 0, 1).astype(np.int16)
        target = torch.tensor(mask, dtype=torch.long)
        sample = {'x': img, 'y': target, 'img_view':img_view, 'mask_view':mask_view}
        return sample

test_dataset = TestDataset(test_df, test_transforms)
test_loader = DataLoader(test_dataset, batch_size=4)


Evaluation of our proposed model on the Test Dataset

In [None]:

# Restore the best checkpoint. `map_location` remaps the stored tensors onto
# the current device, so a checkpoint saved on GPU also loads on a CPU-only
# runtime.
model.load_state_dict(torch.load('checkpoints/best.pth', map_location=device))
model.to(device)

def get_metrics(model, dataloader, threshold):
    """Evaluate ``model`` on ``dataloader`` at one probability threshold.

    For every batch the model output is squeezed on dim 1, passed through a
    sigmoid and compared with the rounded targets via
    ``smp.metrics.get_stats``. Micro-reduced IoU, accuracy, F1, recall and
    precision plus a Dice score are computed per batch and then averaged over
    the batches, so every batch has equal weight regardless of its size.

    Args:
        model: network to evaluate; switched to eval mode.
        dataloader: yields dict batches with inputs under ``'x'`` and targets
            under ``'y'``.
        threshold: probability cut-off passed to ``get_stats``.

    Returns:
        Dict with keys ``'iou'``, ``'pre'``, ``'fi'`` (F1), ``'re'``, ``'acc'``
        and ``'dice'``, each rounded to 3 decimals.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    iou_total = precision_total = f1_total = 0.0
    recall_total = acc_total = dice_total = 0.0
    batches = 0
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            x, y = batch['x'].to(device), batch['y'].to(device)  # move data to the model's device
            pred = model(x)
            pred = pred.squeeze(dim=1)
            pred = torch.sigmoid(pred)
            y = y.round().long()

            # Confusion-matrix counts at the requested threshold
            tp, fp, fn, tn = smp.metrics.get_stats(pred, y, mode='binary', threshold=threshold)

            iou_total += smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro").item()
            acc_total += smp.metrics.accuracy(tp, fp, fn, tn, reduction="micro").item()
            f1_total += smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro").item()
            recall_total += smp.metrics.recall(tp, fp, fn, tn, reduction="micro").item()
            precision_total += smp.metrics.precision(tp, fp, fn, tn, reduction="micro").item()

            # Manual Dice score. A batch with no positives in either the
            # prediction or the target would give 0/0 (NaN); score it 1.0,
            # the smp metrics' default `zero_division` value.
            tp_sum, fp_sum, fn_sum = tp.sum(), fp.sum(), fn.sum()
            denominator = 2 * tp_sum + fp_sum + fn_sum
            if denominator > 0:
                dice_total += ((2 * tp_sum) / denominator).item()
            else:
                dice_total += 1.0

            batches += 1

    if batches == 0:
        raise ValueError("get_metrics: dataloader yielded no batches")

    # Average each metric over all batches
    sample = {
        'iou': round(iou_total / batches, 3),
        'pre': round(precision_total / batches, 3),
        'fi': round(f1_total / batches, 3),
        're': round(recall_total / batches, 3),
        'acc': round(acc_total / batches, 3),
        'dice': round(dice_total / batches, 3)
    }
    return sample

# Evaluate the model for different thresholds
threshold_list = [0.3, 0.4, 0.5, 0.6, 0.7]
for threshold in threshold_list:
    sample = get_metrics(model, test_loader, threshold)
    print(f"Threshold: {threshold:.2f} \
    IoU Score: {sample['iou']:.3f} \
    Precision: {sample['pre']:.3f} \
    F1 Score: {sample['fi']:.3f} \
    Recall: {sample['re']:.3f} \
    Accuracy: {sample['acc']:.3f} \
    Dice Score: {sample['dice']:.3f}")


Displaying some input images and their corresponding predicted masks

In [None]:
# Restore the best checkpoint on the active device and switch to inference
# mode, so this cell does not depend on an earlier cell having called eval().
model.load_state_dict(torch.load("checkpoints/best.pth", map_location=device))
model.to(device)
model.eval()

# Never request more samples than the dataset holds (sampling is without
# replacement).
show_imgs = min(8, len(test_dataset))

random_list = np.random.choice(len(test_dataset), show_imgs, replace=False)

for i in range(show_imgs):
    idx = random_list[i]
    sample = test_dataset[idx]

    # Inference only: no autograd graph, and use the configured device rather
    # than a hard-coded 'cuda'.
    with torch.no_grad():
        pred = model(sample['x'].to(device, dtype=torch.float32).unsqueeze(0))
        pred = torch.sigmoid(pred).squeeze(0).squeeze(0)
    pred = pred.cpu().numpy()
    pred = np.where(pred<0.5, 0, 1).astype(np.int16)
    pred_img = Image.fromarray(np.uint8(pred), 'L')

    img_view = sample['img_view']
    img_view = Image.fromarray(img_view, 'RGB')

    mask_view = sample['mask_view']
    mask_view = Image.fromarray(mask_view, 'L')

    # Input, predicted mask and ground-truth mask side by side
    f, axarr = plt.subplots(1, 3)
    axarr[0].imshow(img_view)
    axarr[0].set_title('Input')
    axarr[0].axis('off')
    axarr[1].imshow(pred_img)
    axarr[1].set_title('pred')
    axarr[1].axis('off')
    axarr[2].imshow(mask_view)
    axarr[2].set_title('gt')
    axarr[2].axis('off')
    plt.show()
    

Visualization for input image and its predicted mask and ground truth mask

In [None]:
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms

# Load the best checkpoint onto the active device and set evaluation mode
model.load_state_dict(torch.load("checkpoints/best.pth", map_location=device))
model.to(device)
model.eval()

# Define the path to the test image and ground truth mask
image_path = "/kaggle/input/isic-2016-dataset/ISIC_2016_dataset/ISIC_2016_dataset/ISBI2016_ISIC_Part1_Test_Data/ISIC_0010073.jpg"
mask_path = "/kaggle/input/isic-2016-dataset/ISIC_2016_dataset/ISIC_2016_dataset/ISBI2016_ISIC_Part1_Test_GroundTruth/ISIC_0010073_Segmentation.png"

# Preprocessing pipeline. TestDataset feeds the model images scaled to [0, 1]
# with no mean/std normalisation, so the same scaling is used here; an extra
# Normalize step would give the model inputs in a different range.
transform = transforms.Compose([
    transforms.Resize((512, 512)),  # Ensure input size matches model's requirements
    transforms.ToTensor(),         # Convert PIL Image to a float tensor in [0, 1]
])

# Preprocess the test image
image_test = Image.open(image_path).convert("RGB")
image_tensor = transform(image_test).unsqueeze(0).to(device)  # Add batch dimension and move to the device

# Predict the segmentation mask
with torch.no_grad():
    pred = model(image_tensor)  # Pass the image through the model
    pred = torch.sigmoid(pred).squeeze(0).squeeze(0).cpu().numpy()  # Sigmoid, drop batch/channel dims, move to CPU
    pred_binary = np.where(pred < 0.5, 0, 1).astype(np.uint8)  # Binarize the prediction

# Load and preprocess the ground truth mask
mask_view = Image.open(mask_path).resize((512, 512))  # Resize to match predicted mask size
mask_view = np.array(mask_view)  # Convert to numpy array

# Convert prediction to image format for visualization
pred_img = Image.fromarray(pred_binary * 255, 'L')  # Scale binary mask to [0, 255] for display

# Visualization
f, axarr = plt.subplots(1, 3, figsize=(15, 5))
axarr[0].imshow(image_test)
axarr[0].set_title('Input Image')
axarr[0].axis('off')

axarr[1].imshow(pred_img, cmap='gray')
axarr[1].set_title('Predicted Mask')
axarr[1].axis('off')

axarr[2].imshow(mask_view, cmap='gray')
axarr[2].set_title('Ground Truth Mask')
axarr[2].axis('off')

plt.tight_layout()
plt.show()


Visualizing a decoder activation heatmap

In [None]:

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms

# Step 1: restore the best checkpoint on the active device and switch to eval mode
model.load_state_dict(torch.load('checkpoints/best.pth', map_location=device))
model.to(device)
model.eval()

# Step 2: Load and preprocess the input image
def load_image(image_path):
    """Load an image as a (1, 3, 512, 512) float tensor with values in [0, 1].

    TestDataset feeds the model images scaled to [0, 1] with no mean/std
    normalisation, so the same scaling is used here.
    """
    image = Image.open(image_path).convert("RGB")  # Ensure it's RGB

    preprocess = transforms.Compose([
        transforms.Resize((512, 512)),  # Resize to match model input
        transforms.ToTensor(),  # Convert to a float tensor in [0, 1]
    ])

    return preprocess(image).unsqueeze(0)  # Add batch dimension

# NOTE(review): this path points at a MoNuSeg image while the rest of the
# notebook uses ISIC 2016 - confirm this is the intended input.
input_image_path = '/kaggle/input/monuseg/MonuSeg/Test/Images/TCGA-GL-6846-01A-01-BS1.png'  # Specify the path to your image
input_tensor = load_image(input_image_path)
input_tensor = input_tensor.to(device)  # same device as the model, not a hard-coded 'cuda'


# Step 3: capture a decoder activation with a forward hook.
# This is a plain activation map: no gradients are involved, so it is not
# Grad-CAM.
def get_activation(layer):
    """Return a forward hook that appends the hooked module's output to ``activation``."""
    def hook(model, input, output):
        activation.append(output)
    return hook

# Layer whose output is visualised
target_layer = model.decoder.blocks[4].conv2[2]  # Modify this to your architecture

# Register the hook
activation = []
hook = target_layer.register_forward_hook(get_activation(target_layer))

# Forward pass; always unregister the hook, even if the forward pass raises,
# so it cannot stay attached to the model.
try:
    with torch.no_grad():
        output = model(input_tensor)
finally:
    hook.remove()

# Activation captured by the hook
activation_map = activation[0]

# Average across dim 1 to get a single map, then keep positive responses only
activation_map = activation_map.mean(dim=1, keepdim=True)
activation_map = F.relu(activation_map)

# Min-max normalise for display; a constant map would divide by zero, so show
# it as all zeros instead.
value_range = activation_map.max() - activation_map.min()
if value_range > 0:
    activation_map = (activation_map - activation_map.min()) / value_range
else:
    activation_map = torch.zeros_like(activation_map)
activation_map = activation_map.squeeze().cpu().numpy()  # Convert to numpy for plotting

# Plot the activation map
plt.imshow(activation_map, cmap='inferno')  # or any other colormap
plt.axis('off')  # Turn off axis
plt.show()
