**Library Imports**

In [None]:
#installations
!pip install torchinfo
!pip install torchsummary

# Basic data manipulations
import pandas as pd
import numpy as np

# Handling images
from PIL import Image
import matplotlib.pyplot as plt

# Handling paths

import time

# Pytorch essentials
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision.datasets import ImageFolder


# Pytorch essentials for datasets.
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import Dataset, DataLoader

# Pytorch way of data augmentation.
import torchvision
from torchvision import datasets, models, transforms, utils
from torchvision.transforms import v2

import cv2
import os
from glob import glob
from tqdm import tqdm
import shutil
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import confusion_matrix , accuracy_score, classification_report
import seaborn as sns


import albumentations as A
from albumentations.pytorch import ToTensorV2
! pip install segmentation-models-pytorch
import segmentation_models_pytorch as smp


**Dataset Definitions**

In [None]:
# Pair each training image with its mask. Both path lists are sorted so the i-th
# image lines up with the i-th mask (assumes image and mask filenames sort alike — confirm).
# NOTE: 'train_image_path' / 'train_mask_path' are placeholders; use real glob patterns.
train_img_paths = sorted(glob('train_image_path'))
train_mask_paths = sorted(glob('train_mask_path'))
train_df = pd.DataFrame.from_dict(
    {"images": train_img_paths, "masks": train_mask_paths}
)
train_df.head()

In [None]:
# Pair each validation image with its mask (sorted so the i-th entries correspond —
# assumes matching filename ordering; confirm).
# NOTE: 'valid_image_path' / 'valid_mask_path' are placeholders; use real glob patterns.
val_img_paths = sorted(glob('valid_image_path'))
val_mask_paths = sorted(glob('valid_mask_path'))
val_df = pd.DataFrame.from_dict(
    {"images": val_img_paths, "masks": val_mask_paths}
)
val_df.head()

In [None]:
# Pair each test image with its mask (sorted so the i-th entries correspond —
# assumes matching filename ordering; confirm).
# NOTE: 'test_image_path' / 'test_mask_path' are placeholders; use real glob patterns.
test_img_paths = sorted(glob('test_image_path'))
test_mask_paths = sorted(glob('test_mask_path'))
test_df = pd.DataFrame.from_dict(
    {"images": test_img_paths, "masks": test_mask_paths}
)
test_df.head()

In [None]:
import math

# Show `show_imgs` random training samples as (image, mask) pairs, two pairs per row.
show_imgs = 4
n_cols = 4                                   # 2 pairs x (image, mask) per row
n_rows = math.ceil(show_imgs * 2 / n_cols)   # round up so odd counts still fit
idx = np.random.choice(len(train_df), show_imgs, replace=False)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 8), squeeze=False)
axes = axes.flatten()
for i, ax in enumerate(axes):
    ax.set_axis_off()
    if i >= 2 * show_imgs:
        continue  # spare panel in the last row when show_imgs is odd
    # iloc: `idx` holds row positions, not index labels.
    row = train_df.iloc[idx[i // 2]]
    if i % 2 == 0:
        full_path = row['images']
        basename = os.path.basename(full_path)
    else:
        full_path = row['masks']
        basename = os.path.basename(full_path) + ' -mask'
    # cmap only affects single-channel arrays (e.g. grayscale masks); RGB is unaffected.
    ax.imshow(plt.imread(full_path), cmap='gray')
    ax.set_title(basename)
fig.tight_layout()
plt.show()

**Data Preprocessing**

In [None]:
IMG_SIZE = 512  # side length (pixels) every image/mask pair is resized to

# Training augmentation, applied jointly to image and mask by albumentations.
train_transforms = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    # After the resize above this crop spans the whole image, so it is currently a
    # no-op; kept so a larger Resize can later be paired with a real random crop.
    # `p=1.0` replaces the deprecated `always_apply=True`.
    A.RandomCrop(height=IMG_SIZE, width=IMG_SIZE, p=1.0),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    # NOTE(review): recent albumentations releases deprecate ShiftScaleRotate in
    # favour of A.Affine — confirm against the installed version before upgrading.
    A.ShiftScaleRotate(shift_limit=0.01, scale_limit=(-0.04, 0.04), rotate_limit=(-15, 15), p=0.5),
])

# Validation: deterministic resize only. Scaling to [0, 1] and tensor conversion
# are done in the dataset's __getitem__, not here.
val_transforms = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
])

class MyDataset(torch.utils.data.Dataset):
    """Binary-segmentation dataset reading image/mask pairs listed in a DataFrame.

    Args:
        dataframe: table with an ``'images'`` and a ``'masks'`` column of file paths.
        transforms_: optional albumentations-style callable invoked as
            ``transforms_(image=..., mask=...)`` that returns a dict with ``'image'``
            and ``'mask'``. ``None`` skips augmentation.

    Each item is ``{'x': image, 'y': mask}``. ``image`` is a float tensor in RGB order,
    scaled to [0, 1], with shape (3, H, W) — assuming ``transforms_`` returns an HWC
    numpy image. ``mask`` is a float tensor of shape (H, W) holding 0 where the
    grayscale mask pixel is < 127 and 1 elsewhere.

    Raises:
        FileNotFoundError: from ``__getitem__`` when an image or mask cannot be read.
    """

    def __init__(self, dataframe, transforms_=None):
        self.df = dataframe
        # Joint image/mask augmentation; None means "use images as loaded".
        self.transforms_ = transforms_
        # ImageNet normalisation kept for experimentation; not applied in __getitem__.
        self.pre_normalize = v2.Compose([
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # Currently unused by this class; kept for callers that may read them.
        self.resize = [512, 512]
        self.class_size = 2

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _read(path, flags):
        # cv2.imread returns None instead of raising on a missing/unreadable file.
        array = cv2.imread(path, flags)
        if array is None:
            raise FileNotFoundError(f"Could not read image file: {path}")
        return array

    def __getitem__(self, index):
        row = self.df.iloc[index]
        img = cv2.cvtColor(self._read(row['images'], cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
        mask = self._read(row['masks'], cv2.IMREAD_GRAYSCALE)
        # Binarise: anything at or above mid-gray counts as foreground.
        mask = np.where(mask < 127, 0, 1).astype(np.int16)

        if self.transforms_ is not None:
            aug = self.transforms_(image=img, mask=mask)
            img, mask = aug['image'], aug['mask']

        # HWC uint8 -> CHW float in [0, 1].
        img = torch.tensor(img / 255, dtype=torch.float).permute(2, 0, 1)
        # Mask stays 2-D (H, W); add a channel dim downstream if the loss needs one.
        target = torch.tensor(mask, dtype=torch.float)
        return {'x': img, 'y': target}


In [None]:
# is_available must be *called*: the bare function object is always truthy, which
# would select "cuda" even on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_dataset = MyDataset(train_df, train_transforms)
val_dataset = MyDataset(val_df, val_transforms)

BATCH_SIZE = 4
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
print(f'len train: {len(train_df)}')
print(f'len val: {len(val_df)}')
print(f'len test: {len(test_df)}')

**Attention Module and Activation Function Definitions**

In [None]:

import torch
import torch.nn as nn

try:
    from inplace_abn import InPlaceABN
except ImportError:
    InPlaceABN = None


class Conv2dReLU(nn.Sequential):
    """Conv2d followed by an optional normalisation layer and a ReLU.

    ``use_batchnorm`` selects the normalisation:
      * ``"inplace"`` -> InPlaceABN with a fused leaky-ReLU (slope 0); the trailing
        ReLU is then replaced by Identity. Requires the ``inplace_abn`` package.
      * any other truthy value -> ``nn.BatchNorm2d``.
      * falsy -> no normalisation (``nn.Identity``).
    The convolution carries a bias only when ``use_batchnorm`` is falsy.

    Raises:
        RuntimeError: ``use_batchnorm == "inplace"`` but inplace_abn is not installed.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        use_batchnorm=True,
    ):
        wants_inplace_abn = use_batchnorm == "inplace"
        if wants_inplace_abn and InPlaceABN is None:
            raise RuntimeError(
                "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
                + "To install see: https://github.com/mapillary/inplace_abn"
            )

        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not use_batchnorm,
        )

        if wants_inplace_abn:
            norm = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
            act = nn.Identity()
        elif use_batchnorm:
            norm = nn.BatchNorm2d(out_channels)
            act = nn.ReLU(inplace=True)
        else:
            norm = nn.Identity()
            act = nn.ReLU(inplace=True)

        super().__init__(conv, norm, act)

class ChannelAttention(nn.Module):
    """Per-channel gate computed from global average- and max-pooled descriptors.

    Both descriptors go through the shared bottleneck ``fc1``. The two results are
    summed, passed through a ReLU and then ``fc2``, and a sigmoid rescales each
    channel of the input.

    Args:
        in_channels: channel count C of the (N, C, H, W) input.
        reduction: bottleneck factor; the hidden width is ``in_channels // reduction``.
    """

    def __init__(self, in_channels, reduction=16):
        super(ChannelAttention, self).__init__()
        hidden = in_channels // reduction
        self.fc1 = nn.Linear(in_channels, hidden, bias=False)
        self.fc2 = nn.Linear(hidden, in_channels, bias=False)

    def forward(self, x):
        n, c, _, _ = x.size()
        avg_desc = x.mean(dim=[2, 3])
        # Max over height, then over width -> (N, C).
        max_desc = x.max(dim=2).values.max(dim=2).values
        hidden = F.relu(self.fc1(avg_desc) + self.fc1(max_desc))
        gate = torch.sigmoid(self.fc2(hidden))
        return x * gate.view(n, c, 1, 1)

class SpatialAttention(nn.Module):
    """Per-pixel gate computed from the channel-wise mean and max maps.

    The two single-channel maps are stacked, convolved down to one channel with a
    ``kernel_size`` conv (padding ``kernel_size // 2``), and squashed by a sigmoid.
    The result multiplies the input.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x):
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True).values
        gate = self.conv1(torch.cat([mean_map, max_map], dim=1)).sigmoid()
        return x * gate

class GatedAxialAttentionBlock(nn.Module):
    """Sum of channel (cSE), spatial (sSE) and gated self-attention branches.

    Output is ``x * cSE(x) + x * sSE(x) + attn(x) * gate(x)``, with the same shape as ``x``.

    NOTE(review): despite the name, the attention branch is full (non-axial)
    self-attention over all H*W positions, so the attention map is [B, H*W, H*W].
    Memory therefore grows as O((H*W)^2). The logits are not scaled by 1/sqrt(d),
    and the "gate" is a raw 1x1 conv with no sigmoid. These choices are kept as-is
    so outputs match previously trained weights.

    Args:
        in_channels: input (and output) channel count; must be divisible by ``groups``.
        reduction: bottleneck factor for the cSE branch (hidden = in_channels // reduction).
        groups: group count for the 1x1 query/key/value convolutions (default 8).
        kernel_size: kernel of the sSE conv (default 7); odd sizes keep H and W unchanged.

    Raises:
        ValueError: if ``in_channels`` is not divisible by ``groups``.
    """

    def __init__(self, in_channels, reduction=16, groups=8, kernel_size=7):
        if in_channels % groups != 0:
            raise ValueError(f"in_channels ({in_channels}) must be divisible by groups ({groups})")

        reduced_channels = in_channels // reduction

        super().__init__()

        # Channel-wise squeeze-and-excitation (cSE): gate of shape [B, C, 1, 1].
        self.cSE = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, reduced_channels, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduced_channels, in_channels, 1),
            nn.Sigmoid(),
        )

        # Spatial squeeze-and-excitation (sSE): gate of shape [B, 1, H, W].
        self.sSE = nn.Sequential(
            nn.Conv2d(in_channels, 1, kernel_size=kernel_size, padding=kernel_size // 2),
            nn.Sigmoid(),
        )

        # Query / key / value projections keep the channel count.
        self.q_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1, groups=groups)
        self.k_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1, groups=groups)
        self.v_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1, groups=groups)

        # Multiplicative gate on the attention output (no sigmoid; see class note).
        self.gate = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def forward(self, x):
        B, C, H, W = x.size()

        x_c = x * self.cSE(x)
        x_s = x * self.sSE(x)

        # Flatten spatial positions: [B, C, N] with N = H * W.
        q_flat = self.q_conv(x).view(B, -1, H * W)
        k_flat = self.k_conv(x).view(B, -1, H * W)
        v_flat = self.v_conv(x).view(B, -1, H * W)

        # attn_map[b, i, j] = softmax_j(<q_i, k_j>); shape [B, N, N].
        attn_map = torch.softmax(torch.bmm(q_flat.transpose(1, 2), k_flat), dim=-1)
        # out[b, :, i] = sum_j v[b, :, j] * attn_map[b, i, j]
        out = torch.bmm(v_flat, attn_map.transpose(1, 2)).view(B, C, H, W)

        gated_out = out * self.gate(x)
        return x_c + x_s + gated_out

class LEQCA_Block(nn.Module):
    """Squeeze-and-excitation style channel gate.

    Pipeline: global average pool -> 1x1 conv (C -> C // reduction, no bias) -> ReLU
    -> 1x1 conv (back to C, no bias) -> sigmoid. The resulting [N, C, 1, 1] weights
    rescale the input channel-wise.
    """

    def __init__(self, in_channels, reduction=16):
        super(LEQCA_Block, self).__init__()
        self.in_channels = in_channels
        self.reduction = reduction
        hidden = in_channels // reduction

        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(in_channels, hidden, kernel_size=1, bias=False)
        self.fc2 = nn.Conv2d(hidden, in_channels, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        _, _, _, _ = x.size()  # enforce a 4-D (N, C, H, W) input
        squeezed = self.global_pool(x)
        weights = self.sigmoid(self.fc2(F.relu(self.fc1(squeezed))))
        return x * weights.expand_as(x)


class ComprehensiveAttention(nn.Module):
    """Channel attention followed by spatial attention, applied in sequence."""

    def __init__(self, in_channels, reduction=16, kernel_size=7):
        super(ComprehensiveAttention, self).__init__()
        self.channel_attention = ChannelAttention(in_channels, reduction)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        return self.spatial_attention(self.channel_attention(x))

class SCSEModule(nn.Module):
    """Concurrent spatial and channel squeeze-and-excitation.

    Returns ``x * cSE(x) + x * sSE(x)``. ``cSE`` is a pooled 1x1-conv bottleneck gate
    of shape [N, C, 1, 1]; ``sSE`` is a 1x1-conv gate of shape [N, 1, H, W].
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        hidden = in_channels // reduction
        self.cSE = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, hidden, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, in_channels, 1),
            nn.Sigmoid(),
        )
        self.sSE = nn.Sequential(
            nn.Conv2d(in_channels, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        channel_gate = self.cSE(x)
        spatial_gate = self.sSE(x)
        return x * channel_gate + x * spatial_gate



class ArgMax(nn.Module):
    """Module wrapper for ``argmax`` along ``dim`` (``None`` = over the flattened tensor)."""

    def __init__(self, dim=None):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return x.argmax(dim=self.dim)


class Clamp(nn.Module):
    """Module wrapper that clamps values into ``[min, max]`` (default ``[0, 1]``)."""

    def __init__(self, min=0, max=1):
        super().__init__()
        self.min = min
        self.max = max

    def forward(self, x):
        return x.clamp(self.min, self.max)


class Activation(nn.Module):
    """Build an output activation from a name or a callable.

    Args:
        name: ``None``/"identity", "sigmoid", "softmax2d" (softmax over dim 1),
            "softmax", "logsoftmax", "tanh", "argmax", "argmax2d" (argmax over dim 1),
            "clamp", or a callable that is invoked as ``name(**params)``.
        **params: forwarded to the selected module. "sigmoid" and "tanh" ignore them.

    Raises:
        ValueError: for any other value of ``name``.
    """

    def __init__(self, name, **params):
        super().__init__()

        if name is None:
            self.activation = nn.Identity(**params)
            return

        builder = None
        if isinstance(name, str):
            builders = {
                "identity": lambda: nn.Identity(**params),
                "sigmoid": lambda: nn.Sigmoid(),
                "softmax2d": lambda: nn.Softmax(dim=1, **params),
                "softmax": lambda: nn.Softmax(**params),
                "logsoftmax": lambda: nn.LogSoftmax(**params),
                "tanh": lambda: nn.Tanh(),
                "argmax": lambda: ArgMax(**params),
                "argmax2d": lambda: ArgMax(dim=1, **params),
                "clamp": lambda: Clamp(**params),
            }
            builder = builders.get(name)
        elif callable(name):
            builder = lambda: name(**params)

        if builder is None:
            raise ValueError(
                f"Activation should be callable/sigmoid/softmax/logsoftmax/tanh/"
                f"argmax/argmax2d/clamp/None; got {name}"
            )
        self.activation = builder()

    def forward(self, x):
        return self.activation(x)


class Attention(nn.Module):
    """Select an attention module by name.

    Args:
        name: ``None`` (identity), "scse", "comattn", "sa", "ca", "leqca" or "gaa".
        in_channels: channel count of the input. Forwarded as ``in_channels`` to every
            channel-aware module ("scse", "comattn", "ca", "leqca", "gaa"). Not passed
            to ``None`` or "sa", whose modules take no channel argument.
        out_channels: accepted for backward compatibility; unused, because every
            module here preserves the channel count.
        **params: extra keyword arguments forwarded to the selected module.

    Raises:
        ValueError: unknown ``name``, or "gaa" without ``in_channels``.
    """

    def __init__(self, name, in_channels=None, out_channels=None, **params):
        super().__init__()

        # `in_channels` is an explicit parameter, so it is not part of **params and
        # must be re-injected for modules whose constructors require it.
        channel_params = dict(params)
        if in_channels is not None:
            channel_params["in_channels"] = in_channels

        if name is None:
            self.attention = nn.Identity(**params)
        elif name == "scse":
            self.attention = SCSEModule(**channel_params)
        elif name == "comattn":
            self.attention = ComprehensiveAttention(**channel_params)
        elif name == "sa":
            self.attention = SpatialAttention(**params)
        elif name == "ca":
            self.attention = ChannelAttention(**channel_params)
        elif name == "leqca":
            self.attention = LEQCA_Block(**channel_params)
        elif name == "gaa":
            if in_channels is None:
                raise ValueError("in_channels must be specified for axial attention")
            # GatedAxialAttentionBlock also requires in_channels to be divisible by its
            # `groups` (default 8); pass groups=... through **params if needed.
            self.attention = GatedAxialAttentionBlock(**channel_params)
        else:
            raise ValueError("Attention {} is not implemented".format(name))

    def forward(self, x):
        return self.attention(x)


In [None]:
import torch.nn as nn
#####from .modules import Activation


class SegmentationHead(nn.Sequential):
    """Final "same"-padded conv, optional bilinear upsampling, then an output activation.

    ``upsampling`` > 1 adds ``nn.UpsamplingBilinear2d(scale_factor=upsampling)``;
    otherwise that slot is an Identity. ``activation`` is resolved via ``Activation``.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1
    ):
        layers = [
            nn.Conv2d(
                in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2
            ),
            nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity(),
            Activation(activation),
        ]
        super().__init__(*layers)


class ClassificationHead(nn.Sequential):
    """Global pool -> flatten -> optional dropout -> linear -> activation.

    Args:
        in_channels: channels of the incoming feature map.
        classes: number of output logits.
        pooling: "avg" or "max" global pooling.
        dropout: dropout probability; a falsy value skips dropout.
        activation: resolved via ``Activation``.

    Raises:
        ValueError: if ``pooling`` is not "max" or "avg".
    """

    def __init__(
        self, in_channels, classes, pooling="avg", dropout=0.2, activation=None
    ):
        if pooling not in ("max", "avg"):
            raise ValueError(
                "Pooling should be one of ('max', 'avg'), got {}.".format(pooling)
            )
        pool = nn.AdaptiveMaxPool2d(1) if pooling == "max" else nn.AdaptiveAvgPool2d(1)
        drop = nn.Identity()
        if dropout:
            drop = nn.Dropout(p=dropout, inplace=True)
        super().__init__(
            pool,
            nn.Flatten(),
            drop,
            nn.Linear(in_channels, classes, bias=True),
            Activation(activation),
        )

In [None]:
import torch
from segmentation_models_pytorch.base import initialization as init
#####from . import initialization as init
#####from .hub_mixin import SMPHubMixin
from segmentation_models_pytorch.base.hub_mixin import SMPHubMixin

class SegmentationModel(torch.nn.Module, SMPHubMixin):
    """Base class wiring an encoder, a decoder and segmentation/classification heads.

    Subclasses must set ``encoder``, ``decoder``, ``segmentation_head`` and
    ``classification_head`` (the latter may be ``None``).
    """

    def initialize(self):
        """Initialise the decoder and every head that is present."""
        init.initialize_decoder(self.decoder)
        heads = [self.segmentation_head]
        if self.classification_head is not None:
            heads.append(self.classification_head)
        for head in heads:
            init.initialize_head(head)

    def check_input_shape(self, x):
        """Raise RuntimeError unless H and W of ``x`` are multiples of the encoder's output stride."""
        h, w = x.shape[-2:]
        output_stride = self.encoder.output_stride
        if h % output_stride == 0 and w % output_stride == 0:
            return

        # Round each misaligned side up to the next multiple for the padding hint.
        new_h = h if h % output_stride == 0 else (h // output_stride + 1) * output_stride
        new_w = w if w % output_stride == 0 else (w // output_stride + 1) * output_stride
        raise RuntimeError(
            f"Wrong input shape height={h}, width={w}. Expected image height and width "
            f"divisible by {output_stride}. Consider pad your images to shape ({new_h}, {new_w})."
        )

    def forward(self, x):
        """Sequentially pass `x` trough model`s encoder, decoder and heads"""

        self.check_input_shape(x)

        features = self.encoder(x)
        masks = self.segmentation_head(self.decoder(*features))

        if self.classification_head is None:
            return masks
        # The classification head sees only the deepest encoder feature map.
        return masks, self.classification_head(features[-1])

    @torch.no_grad()
    def predict(self, x):
        """Inference method. Switch model to `eval` mode, call `.forward(x)` with `torch.no_grad()`

        Args:
            x: 4D torch tensor with shape (batch_size, channels, height, width)

        Return:
            prediction: 4D torch tensor with shape (batch_size, classes, height, width)

        """
        if self.training:
            self.eval()
        return self.forward(x)

**Decoder Block Definition and Unet Decoder**

In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F

#####from segmentation_models_pytorch.base import modules as md

class DecoderBlock(nn.Module):
    """One U-Net decoder stage: 2x nearest upsample, optional skip fusion, two convs.

    When a skip tensor is given it is concatenated on the channel axis and passed
    through ``attention1`` before the convolutions. ``attention2`` is applied to the
    block output.
    """

    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        use_batchnorm=True,
        attention_type=None,
    ):
        super().__init__()
        fused_channels = in_channels + skip_channels
        self.conv1 = Conv2dReLU(
            fused_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.attention1 = Attention(attention_type, in_channels=fused_channels)
        self.conv2 = Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.attention2 = Attention(attention_type, in_channels=out_channels)

    def forward(self, x, skip=None):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if skip is not None:
            x = self.attention1(torch.cat([x, skip], dim=1))
        return self.attention2(self.conv2(self.conv1(x)))


class CenterBlock(nn.Sequential):
    """Two stacked 3x3 Conv2dReLU layers applied to the deepest encoder feature."""

    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        layers = [
            Conv2dReLU(
                channels_in,
                out_channels,
                kernel_size=3,
                padding=1,
                use_batchnorm=use_batchnorm,
            )
            for channels_in in (in_channels, out_channels)
        ]
        super().__init__(*layers)


class UnetDecoder(nn.Module):
    """U-Net decoder built from ``DecoderBlock`` stages.

    Args:
        encoder_channels: channels of every encoder output, shallowest first
            (index 0 is the full-resolution input and is dropped).
        decoder_channels: output channels of each decoder stage.
        n_blocks: expected number of decoder stages; must equal ``len(decoder_channels)``.
        use_batchnorm / attention_type: forwarded to every ``DecoderBlock``.
        center: if true, insert a ``CenterBlock`` on the deepest feature.

    Raises:
        ValueError: if ``n_blocks != len(decoder_channels)``.
    """

    def __init__(
        self,
        encoder_channels,
        decoder_channels,
        n_blocks=5,
        use_batchnorm=True,
        attention_type=None,
        center=False,
    ):
        super().__init__()

        if len(decoder_channels) != n_blocks:
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
            )

        # Drop the full-resolution entry, then order deepest-first.
        encoder_channels = encoder_channels[1:][::-1]

        head_channels = encoder_channels[0]
        in_channels = [head_channels, *decoder_channels[:-1]]
        skip_channels = [*encoder_channels[1:], 0]  # last stage has no skip
        out_channels = decoder_channels

        self.center = (
            CenterBlock(head_channels, head_channels, use_batchnorm=use_batchnorm)
            if center
            else nn.Identity()
        )

        self.blocks = nn.ModuleList(
            DecoderBlock(
                in_ch,
                skip_ch,
                out_ch,
                use_batchnorm=use_batchnorm,
                attention_type=attention_type,
            )
            for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
        )

    def forward(self, *features):
        # Same reordering as in __init__: drop full-res input, deepest first.
        features = features[1:][::-1]
        head, skips = features[0], features[1:]

        x = self.center(head)
        for i, block in enumerate(self.blocks):
            x = block(x, skips[i] if i < len(skips) else None)
        return x

**Encoder with Fused Attention on an InceptionV4 Backbone**

In [None]:
import torch
import torch.nn as nn
import numpy as np
import functools
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from pretrainedmodels.models.inceptionv4 import InceptionV4, pretrained_settings


# ECA Block (Efficient Channel Attention)
class ECA(nn.Module):
    """Channel gate in the spirit of ECA-Net (Efficient Channel Attention).

    NOTE(review): the depthwise Conv1d is applied to a tensor of shape (B, C, 1),
    i.e. a length-1 sequence per channel. With zero padding, only the centre tap
    of each kernel ever touches data, so the gate reduces to
    ``sigmoid(w_c * mean_c)`` per channel with no cross-channel interaction. The
    reference ECA instead convolves along the channel axis. The layout is kept so
    parameter shapes stay compatible with existing checkpoints.

    Args:
        in_channels: channel count C of the (B, C, H, W) input.
        kernel_size: Conv1d kernel size; even values are bumped to the next odd one.
    """

    def __init__(self, in_channels, kernel_size=3):
        super(ECA, self).__init__()
        self.in_channels = in_channels
        # Odd kernel -> symmetric "same" padding.
        self.kernel_size = kernel_size + 1 if kernel_size % 2 == 0 else kernel_size

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(
            in_channels,
            in_channels,
            kernel_size=self.kernel_size,
            padding=self.kernel_size // 2,
            groups=in_channels,
            bias=False,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        pooled = self.avg_pool(x).view(b, c, 1)  # (B, C, 1) for Conv1d
        gate = self.sigmoid(self.conv(pooled))
        return x * gate.view(b, c, 1, 1)



# LEQCA Block (Local Equivariant and Quality Channel Attention)
class LEQCA(nn.Module):
    """Coarse spatially-varying channel gate.

    The input is average-pooled to a 2x2 grid and passed through a 1x1-conv
    bottleneck (C -> C // 16 -> C) with ReLU and sigmoid. The gate is then
    bilinearly upsampled back to H x W and multiplied into the input.
    """

    def __init__(self, in_channels):
        super(LEQCA, self).__init__()
        hidden = in_channels // 16
        self.local_pool = nn.AdaptiveAvgPool2d(2)
        self.fc1 = nn.Conv2d(in_channels, hidden, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(hidden, in_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        coarse_gate = self.sigmoid(self.fc2(self.relu(self.fc1(self.local_pool(x)))))
        full_gate = F.interpolate(coarse_gate, size=x.shape[2:], mode='bilinear', align_corners=False)
        return x * full_gate



# Combined Ensemble Attention Block (ECA + LEQCA)
class EnsembleAttention(nn.Module):
    """Average of the ECA-gated and LEQCA-gated versions of the same input."""

    def __init__(self, in_channels):
        super(EnsembleAttention, self).__init__()
        self.eca_block = ECA(in_channels)
        self.leqca_block = LEQCA(in_channels)

    def forward(self, x):
        return (self.eca_block(x) + self.leqca_block(x)) / 2



# Encoder Mixin
class EncoderMixin:
    """Helpers shared by encoders.

    The host class is expected to define ``_out_channels`` (per-stage channel counts,
    input first), ``_depth`` (number of stages used) and ``get_stages()``.
    """

    _output_stride = 32

    @property
    def out_channels(self):
        """Channel counts of the ``_depth + 1`` feature maps (input included)."""
        depth = self._depth
        return self._out_channels[: depth + 1]

    @property
    def output_stride(self):
        """Total downsampling factor: 2**depth, capped by the configured stride."""
        return min(2 ** self._depth, self._output_stride)

    def set_in_channels(self, in_channels, pretrained=True):
        """Adapt the first convolution to ``in_channels`` inputs (no-op for 3)."""
        if in_channels != 3:
            self._in_channels = in_channels
            if self._out_channels[0] == 3:
                self._out_channels = (in_channels,) + tuple(self._out_channels[1:])
            patch_first_conv(self, new_in_channels=in_channels, pretrained=pretrained)

    def make_dilated(self, output_stride):
        """Replace strides by dilation in the last stage(s) to reach stride 16 or 8."""
        if output_stride == 16:
            plan = ((5, 2),)
        elif output_stride == 8:
            plan = ((4, 2), (5, 4))
        else:
            raise ValueError("Output stride must be 16 or 8.")

        self._output_stride = output_stride
        stages = self.get_stages()
        for stage_idx, dilation_rate in plan:
            replace_strides_with_dilation(stages[stage_idx], dilation_rate)


# InceptionV4 Encoder with ECA + LEQCA Ensemble Attention
class InceptionV4Encoder(InceptionV4, EncoderMixin):
    """InceptionV4 feature extractor with an EnsembleAttention block after each stage.

    Args:
        stage_idxs: four split points into ``self.features`` that define stages 1-5.
        out_channels: channel count of each returned feature map, input first.
        depth: number of stages to run (``forward`` returns ``depth + 1`` tensors).
        **kwargs: forwarded to ``InceptionV4.__init__``.
    """

    def __init__(self, stage_idxs, out_channels, depth=5, **kwargs):
        super().__init__(**kwargs)
        self._stage_idxs = stage_idxs
        self._out_channels = out_channels
        self._depth = depth
        self._in_channels = 3
        # One attention block per stage 1..depth (stage 0 is the raw input).
        self.attention_blocks = nn.ModuleList(
            [EnsembleAttention(out_channels[i]) for i in range(1, depth + 1)]
        )
        # Force padding=1 on every 3x3 conv and every max-pool (presumably so each
        # stage's output lines up with the decoder's 2x upsampling — confirm).
        # The attention blocks contain no 3x3 Conv2d or MaxPool2d, so they are untouched.
        for m in self.modules():
            if isinstance(m, nn.Conv2d) and m.kernel_size == (3, 3):
                m.padding = (1, 1)
            if isinstance(m, nn.MaxPool2d):
                m.padding = (1, 1)
        # The classifier is not used for feature extraction.
        del self.last_linear

    def get_stages(self):
        """Defines the stages based on the stage indexes provided."""
        return [
            nn.Identity(),  # First stage is a placeholder (no processing)
            self.features[:self._stage_idxs[0]],  # First block of layers
            self.features[self._stage_idxs[0]:self._stage_idxs[1]],  # Second block of layers
            self.features[self._stage_idxs[1]:self._stage_idxs[2]],  # Third block of layers
            self.features[self._stage_idxs[2]:self._stage_idxs[3]],  # Fourth block of layers
            self.features[self._stage_idxs[3]:],  # Fifth block of layers
        ]

    def forward(self, x):
        """Return the list of ``depth + 1`` feature maps (input first)."""
        # Build the stage list once per call instead of once per stage.
        stages = self.get_stages()
        features = []
        for i in range(self._depth + 1):
            x = stages[i](x)
            if i > 0:
                x = self.attention_blocks[i - 1](x)  # ECA + LEQCA ensemble attention
            features.append(x)
        return features

    def load_state_dict(self, state_dict, **kwargs):
        """Load weights, ignoring the removed classifier; returns the missing/unexpected keys.

        Note: ``last_linear.*`` entries are popped from ``state_dict`` in place.
        """
        state_dict.pop("last_linear.bias", None)
        state_dict.pop("last_linear.weight", None)
        return super().load_state_dict(state_dict, **kwargs)


# Utility Functions
def patch_first_conv(model, new_in_channels, default_in_channels=3, pretrained=True):
    """Change the input-channel count of the first matching Conv2d in ``model``.

    The first ``nn.Conv2d`` (in ``model.modules()`` order) whose ``in_channels`` equals
    ``default_in_channels`` is modified in place:
      * ``pretrained=False``: fresh weights of the new shape, re-initialised with
        ``reset_parameters()``.
      * ``new_in_channels == 1``: the old weights summed over the input-channel axis.
      * otherwise: old input channels repeated cyclically, scaled by
        ``default_in_channels / new_in_channels``.

    Raises:
        ValueError: if no Conv2d with ``in_channels == default_in_channels`` exists.
    """
    module = next(
        (
            m
            for m in model.modules()
            if isinstance(m, nn.Conv2d) and m.in_channels == default_in_channels
        ),
        None,
    )
    if module is None:
        # Without this guard the leaked loop variable would patch an unrelated module.
        raise ValueError(
            f"No nn.Conv2d with in_channels={default_in_channels} found to patch."
        )

    weight = module.weight.detach()
    module.in_channels = new_in_channels
    new_shape = (module.out_channels, new_in_channels // module.groups, *module.kernel_size)

    if not pretrained:
        # new_empty keeps the original dtype/device; reset_parameters fills it.
        module.weight = nn.parameter.Parameter(weight.new_empty(new_shape))
        module.reset_parameters()
    elif new_in_channels == 1:
        module.weight = nn.parameter.Parameter(weight.sum(1, keepdim=True))
    else:
        source_idx = [i % default_in_channels for i in range(new_in_channels)]
        new_weight = weight[:, source_idx] * (default_in_channels / new_in_channels)
        module.weight = nn.parameter.Parameter(new_weight)


def replace_strides_with_dilation(module, dilation_rate):
    """Turn every Conv2d under ``module`` into a stride-1 dilated conv, in place.

    Padding is set to ``(k // 2) * dilation_rate`` per axis to keep spatial size.
    """
    convs = (m for m in module.modules() if isinstance(m, nn.Conv2d))
    for conv in convs:
        kh, kw = conv.kernel_size
        conv.stride = (1, 1)
        conv.dilation = (dilation_rate, dilation_rate)
        conv.padding = (dilation_rate * (kh // 2), dilation_rate * (kw // 2))
        # Some conv variants carry a separate `static_padding` step; neutralise it
        # because padding is now set explicitly above.
        if hasattr(conv, "static_padding"):
            conv.static_padding = nn.Identity()


# Preprocessing Functions
def preprocess_input(x, mean=None, std=None, input_space="RGB", input_range=None, **kwargs):
    """Normalise an HWC image array for a pretrained encoder.

    Args:
        x: image array with channels on the last axis.
        mean, std: per-channel statistics subtracted / divided (skipped when None).
        input_space: "BGR" reverses the channel axis; anything else leaves it as is.
        input_range: if given with an upper bound of 1 and ``x.max() > 1``,
            ``x`` is first divided by 255.
        **kwargs: ignored.

    Returns:
        A new array; the caller's ``x`` is never modified in place. Integer inputs
        are promoted as numpy arithmetic dictates.
    """
    if input_space == "BGR":
        x = x[..., ::-1].copy()
    if input_range is not None and x.max() > 1 and input_range[1] == 1:
        x = x / 255.0
    # Out-of-place arithmetic: in-place `-=` / `/=` would mutate the caller's array
    # and fail with a casting error for integer inputs.
    if mean is not None:
        x = x - np.array(mean)
    if std is not None:
        x = x / np.array(std)
    return x


# Encoder Retrieval Functions
# Registry of supported encoders, keyed by name. For each entry:
#   "encoder":             class that get_encoder instantiates with **params
#   "pretrained_settings": per-weights settings from pretrainedmodels; get_encoder reads
#                          "url", get_preprocessing_params reads "input_space",
#                          "input_range", "mean" and "std"
#   "params":              constructor keyword arguments for the encoder class
inceptionv4_encoders = {
    "inceptionv4": {
        "encoder": InceptionV4Encoder,
        "pretrained_settings": pretrained_settings["inceptionv4"],
        "params": {
            # Split points in `self.features` between encoder stages (see get_stages).
            "stage_idxs": (3, 5, 9, 15),
            # Channels of each returned feature map; index 0 is the 3-channel input.
            "out_channels": (3, 64, 192, 384, 1024, 1536),
            # Forwarded to InceptionV4's constructor; presumably matches the 1001-class
            # (ImageNet + background) checkpoint layout — confirm.
            "num_classes": 1001,
        },
    }
}


def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
    """Build an attention-augmented encoder from the ``inceptionv4_encoders`` registry.

    Args:
        name: registry key (currently only "inceptionv4").
        in_channels: input channels; values other than 3 patch the first conv.
        depth: number of encoder stages to use.
        weights: key into the encoder's pretrained settings (e.g. "imagenet") or None.
            Attention-block weights are never loaded; they keep their fresh init.
        output_stride: 32 (default), 16 or 8; lower values dilate the last stage(s).
        **kwargs: accepted for API compatibility; ignored.

    Raises:
        KeyError: unknown ``name`` or unknown ``weights``.
    """
    if name not in inceptionv4_encoders:
        raise KeyError(f"Unsupported encoder `{name}`, supported: {list(inceptionv4_encoders.keys())}")

    entry = inceptionv4_encoders[name]
    # Copy: updating the registry's params dict in place would leak `depth` into it.
    params = dict(entry["params"], depth=depth)
    encoder = entry["encoder"](**params)

    if weights:
        available = entry["pretrained_settings"]
        if weights not in available:
            raise KeyError(
                f"Wrong pretrained weights `{weights}` for encoder `{name}`. "
                f"Available options are: {list(available.keys())}"
            )
        pretrained_dict = model_zoo.load_url(available[weights]["url"])
        model_dict = encoder.state_dict()

        # Keep only backbone tensors that exist in this model; attention blocks are new.
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if k in model_dict and "attention_blocks" not in k
        }
        model_dict.update(pretrained_dict)
        encoder.load_state_dict(model_dict)

    encoder.set_in_channels(in_channels, pretrained=weights is not None)
    if output_stride != 32:
        encoder.make_dilated(output_stride=output_stride)

    return encoder


# Preprocessing functions
def get_preprocessing_params(encoder_name, pretrained="imagenet"):
    """Return the input-normalisation settings of an encoder's pretrained weights.

    The result is a dict with ``input_space``, ``input_range``, ``mean`` and ``std``, suitable
    for ``preprocess_input(x, **params)``. ``input_space`` defaults to ``"RGB"`` and
    ``input_range`` to ``[0, 1]`` when the settings omit them.

    Raises:
        KeyError: If ``encoder_name`` is not in ``inceptionv4_encoders``.
        ValueError: If ``pretrained`` is not one of the encoder's pretrained settings.
    """
    all_settings = inceptionv4_encoders[encoder_name]["pretrained_settings"]
    # Fail with a clear message instead of an opaque AttributeError on ``None.get``.
    if pretrained not in all_settings:
        raise ValueError(f"Available pretrained options {list(all_settings.keys())}")
    settings = all_settings[pretrained]
    return {
        "input_space": settings.get("input_space", "RGB"),
        "input_range": list(settings.get("input_range", [0, 1])),
        "mean": list(settings["mean"]),
        "std": list(settings["std"]),
    }


def get_preprocessing_fn(encoder_name, pretrained="imagenet"):
    """Return ``preprocess_input`` pre-bound to an encoder's normalisation settings.

    The returned callable takes an image array and applies the input space, input range,
    mean and std that ``get_preprocessing_params`` reports for ``encoder_name`` and
    ``pretrained``.
    """
    bound_kwargs = get_preprocessing_params(encoder_name, pretrained=pretrained)
    return functools.partial(preprocess_input, **bound_kwargs)


In [None]:
from typing import Optional, Union, List
"""
from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import (
    SegmentationModel,
    SegmentationHead,
    ClassificationHead,
)
"""
#####from .decoder import UnetDecoder

class Unet(SegmentationModel):
    """Unet_ is a fully convolutional neural network for image semantic segmentation.

    It consists of an *encoder* and a *decoder* connected with *skip connections*. The encoder
    extracts features at several spatial resolutions (the skip connections), which the decoder
    uses to produce an accurate segmentation mask. Decoder blocks are fused with the skip
    connections by *concatenation*.

    Args:
        encoder_name: Name of the classification model used as the encoder (a.k.a. backbone)
            to extract features of different spatial resolution. In this notebook
            ``get_encoder`` only registers ``"inceptionv4"``, which is therefore the default.
        encoder_depth: Number of encoder stages used, in range [3, 5]. Each stage produces
            features two times smaller in spatial dimensions than the previous one (e.g. one
            stage gives [(N, C, H, W), (N, C, H // 2, W // 2)], and so on). Default is 5.
        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training
            on ImageNet) or another key of the encoder's pretrained settings.
        decoder_channels: List of integers which specify **in_channels** parameter for
            convolutions used in the decoder. Its length should equal **encoder_depth**.
        decoder_use_batchnorm: If **True**, a BatchNorm2d layer is used between the Conv2D and
            Activation layers. If **"inplace"**, InplaceABN is used, which decreases memory
            consumption. Available options are **True, False, "inplace"**.
        decoder_attention_type: Attention module used in the decoder. Available options are
            **None** and **scse** (https://arxiv.org/abs/1808.08127).
        in_channels: Number of input channels for the model, default is 3 (RGB images).
        classes: Number of classes for the output mask (i.e. number of output mask channels).
        activation: Activation applied after the final convolution layer. Available options
            are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
            **callable** and **None**. Default is **None**.
        aux_params: Dictionary with parameters of the auxiliary output (classification head).
            The auxiliary output is built on top of the encoder if **aux_params** is not
            **None** (default is None). Supported params:
                - classes (int): A number of classes
                - pooling (str): One of "max", "avg". Default is "avg"
                - dropout (float): Dropout factor in [0, 1)
                - activation (str): An activation function to apply "sigmoid"/"softmax"
                    (could be **None** to return logits)

    Returns:
        ``torch.nn.Module``: Unet

    .. _Unet:
        https://arxiv.org/abs/1505.04597

    """

    def __init__(
        self,
        encoder_name: str = "inceptionv4",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_use_batchnorm: bool = True,
        decoder_channels: List[int] = (256, 128, 64, 32, 16),
        decoder_attention_type: Optional[str] = None,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, callable]] = None,
        aux_params: Optional[dict] = None,
    ):
        super().__init__()

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
        )

        self.decoder = UnetDecoder(
            encoder_channels=self.encoder.out_channels,
            decoder_channels=decoder_channels,
            n_blocks=encoder_depth,
            use_batchnorm=decoder_use_batchnorm,
            # ``center`` is only enabled for VGG backbones (never the case for inceptionv4).
            center=encoder_name.startswith("vgg"),
            attention_type=decoder_attention_type,
        )

        # Maps the last decoder feature map to ``classes`` output channels.
        self.segmentation_head = SegmentationHead(
            in_channels=decoder_channels[-1],
            out_channels=classes,
            activation=activation,
            kernel_size=3,
        )

        # Optional classification head on top of the deepest encoder features.
        if aux_params is not None:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params
            )
        else:
            self.classification_head = None

        self.name = f"u-{encoder_name}"
        self.initialize()


**Training**

In [None]:
import torch
from torchinfo import summary
#from model import Unet  # Assuming your model is saved in model.py

# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Binary segmentation: one output channel of logits (the sigmoid is applied later).
class_size = 1
model = Unet(
    encoder_name="inceptionv4",   # the only backbone registered by this notebook's get_encoder
    encoder_weights="imagenet",   # initialise the encoder from ImageNet-pretrained weights
    in_channels=3,                # RGB input
    classes=class_size,           # number of output mask channels
)
model = model.to(device)

# Layer-by-layer summary for a single 3x512x512 image.
input_size = (1, 3, 512, 512)
summary(model, input_size=input_size, device=device)

In [None]:
# Dice score implementation
def dice_score(pred, target, smooth=1e-6):
    """Dice coefficient between two tensors binarised at 0.5.

    Computed over all elements at once (not per sample). ``smooth`` keeps the ratio
    defined when both masks are empty, in which case the score is 1.

    Returns:
        A 0-dim floating-point tensor.
    """
    pred_mask = pred > 0.5
    target_mask = target > 0.5
    overlap = (pred_mask & target_mask).sum().float()
    total = pred_mask.sum() + target_mask.sum()
    return (overlap * 2. + smooth) / (total + smooth)

def train(dataloader, model, loss_fn, optimizer, lr_scheduler):
    """Run one training epoch and return ``(mean_loss, mean_dice, mean_iou)`` over batches.

    Each batch is a dict with ``'x'`` (inputs) and ``'y'`` (targets); both are moved to the
    module-level ``device``. The loss is computed on the raw model output. Dice
    (``dice_score``) and micro-averaged IoU (``smp.metrics``) are computed on the sigmoid
    output thresholded at 0.5.

    Note: ``lr_scheduler.step()`` is called once per *batch*. For example, a ``StepLR`` with
    ``step_size=100`` decays the learning rate every 100 batches, not every 100 epochs.
    """
    num_batches = len(dataloader)
    model.train()
    epoch_loss = 0
    epoch_iou_score = 0
    epoch_dice_score = 0

    for batch in dataloader:
        x, y = batch['x'].to(device), batch['y'].to(device)

        optimizer.zero_grad()
        logits = model(x)
        loss = loss_fn(logits, y)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        # Metrics need no gradients; detaching avoids building an autograd graph that
        # would otherwise stay alive until the next iteration.
        with torch.no_grad():
            probs = torch.sigmoid(logits.detach()).squeeze(dim=1)
            y_int = y.round().long()

            epoch_dice_score += dice_score(probs, y_int).item()

            tp, fp, fn, tn = smp.metrics.get_stats(probs, y_int, mode='binary', threshold=0.5)
            epoch_iou_score += smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro").item()

        lr_scheduler.step()

    return epoch_loss / num_batches, epoch_dice_score / num_batches, epoch_iou_score / num_batches

def test(dataloader, model, loss_fn):
    """Evaluate ``model`` over ``dataloader`` in eval mode with gradients disabled.

    Each batch is a dict with ``'x'`` (inputs) and ``'y'`` (targets), both moved to the
    module-level ``device``. The loss is computed on the raw model output. Dice
    (``dice_score``) and micro-averaged IoU (``smp.metrics``) are computed on the sigmoid
    output thresholded at 0.5.

    Returns:
        ``(mean_loss, mean_dice, mean_iou)``, each averaged over batches.
    """
    batch_count = len(dataloader)
    model.eval()
    loss_total = dice_total = iou_total = 0

    with torch.no_grad():
        for batch in dataloader:
            inputs, targets = batch['x'].to(device), batch['y'].to(device)

            logits = model(inputs)
            loss_total += loss_fn(logits, targets).item()

            probs = torch.sigmoid(logits).squeeze(dim=1)
            hard_targets = targets.round().long()

            dice_total += dice_score(probs, hard_targets).item()

            stats = smp.metrics.get_stats(probs, hard_targets, mode='binary', threshold=0.5)
            iou_total += smp.metrics.iou_score(*stats, reduction="micro").item()

    return loss_total / batch_count, dice_total / batch_count, iou_total / batch_count


In [None]:
EPOCHS = 50
logs = {
    'train_loss': [], 'val_loss': [],
    'train_iou_score': [], 'val_iou_score': [],
    'train_dice_score': [], 'val_dice_score': []
}

# Idempotent directory creation: no check-then-create race and no error on re-runs.
os.makedirs("checkpoints", exist_ok=True)

# Dice loss in binary mode; smp.losses.FocalLoss(mode="binary") is a drop-in alternative.
loss_fn = smp.losses.DiceLoss(mode="binary")

learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# NOTE: ``train`` steps this scheduler once per batch, so the LR is multiplied by 0.1
# every 100 batches.
step_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

# Early stopping: stop after `patience` consecutive epochs without a lower validation loss.
patience = 5
counter = 0
best_loss = np.inf

model.to(device)
for epoch in tqdm(range(EPOCHS)):
    train_loss, train_dice_score, train_iou_score = train(train_loader, model, loss_fn, optimizer, step_lr_scheduler)
    val_loss, val_dice_score, val_iou_score = test(val_loader, model, loss_fn)

    logs['train_loss'].append(train_loss)
    logs['val_loss'].append(val_loss)
    logs['train_dice_score'].append(train_dice_score)
    logs['val_dice_score'].append(val_dice_score)
    logs['train_iou_score'].append(train_iou_score)
    logs['val_iou_score'].append(val_iou_score)

    print(f'EPOCH: {str(epoch+1).zfill(3)} | '
          f'train_loss: {train_loss:.4f}, val_loss: {val_loss:.4f} | '
          f'train_dice_score: {train_dice_score:.3f}, val_dice_score: {val_dice_score:.3f} | '
          f'train_iou_score: {train_iou_score:.3f}, val_iou_score: {val_iou_score:.3f} | '
          f'lr: {optimizer.param_groups[0]["lr"]}')

    # Always keep the latest weights; keep a separate copy of the best-so-far weights.
    torch.save(model.state_dict(), "checkpoints/last.pth")
    if val_loss < best_loss:
        counter = 0
        best_loss = val_loss
        torch.save(model.state_dict(), "checkpoints/best.pth")
    else:
        counter += 1

    if counter >= patience:
        print("Early stopping!")
        break


In [None]:
# Per-epoch training/validation curves: loss on the left, IoU on the right.
fig, (loss_ax, iou_ax) = plt.subplots(1, 2, figsize=(15, 5))

loss_ax.plot(logs['train_loss'], label='Train_Loss')
loss_ax.plot(logs['val_loss'], label='Validation_Loss')
loss_ax.set_title('Train_Loss & Validation_Loss', fontsize=20)
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()

iou_ax.plot(logs['train_iou_score'], label='Train_Iou_Score')
iou_ax.plot(logs['val_iou_score'], label='Validation_Iou_Score')
iou_ax.set_title('Train_Iou_score & Validation_Iou_score', fontsize=20)
iou_ax.set_xlabel('Epoch')
iou_ax.set_ylabel('IoU')
iou_ax.legend()

plt.show()

**Testing**

In [None]:
# Test-time pipeline: only a resize to 512x512. Conversion to a tensor and scaling to [0, 1]
# are done by hand in TestDataset.__getitem__, so no ToTensorV2 step is included here.
test_transforms = A.Compose(
    [
        A.Resize(height=512, width=512),
    ]
)


In [None]:
class TestDataset(torch.utils.data.Dataset):
    """Test-time dataset yielding model inputs, binary targets and raw copies for plotting.

    Each item is a dict:
        'x'         float tensor (3, H, W): the RGB image divided by 255
        'y'         long tensor (H, W): 1 where the mask pixel is >= 127, else 0
        'img_view'  the RGB image array after ``transforms_`` (for display)
        'mask_view' the grayscale mask array after ``transforms_`` (for display)

    Args:
        dataframe: DataFrame with ``'images'`` and ``'masks'`` file-path columns.
        transforms_: Optional albumentations-style callable, invoked as
            ``transforms_(image=..., mask=...)``. Skipped when None.
    """

    def __init__(self, dataframe, transforms_=None):
        self.df = dataframe
        self.transforms_ = transforms_
        # Not applied in __getitem__ (inputs are only divided by 255); kept for reference.
        self.pre_normalize = v2.Compose([
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.resize = [512, 512]  # informational only; resizing is done by ``transforms_``
        self.class_size = 2  # NOTE(review): unused here; the model is built with one class

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _read(path, flags):
        """Read an image with OpenCV, raising instead of returning None for bad paths."""
        image = cv2.imread(path, flags)
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {path}")
        return image

    def __getitem__(self, index):
        row = self.df.iloc[index]
        img = cv2.cvtColor(self._read(row['images'], cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
        mask = self._read(row['masks'], cv2.IMREAD_GRAYSCALE)
        if self.transforms_ is not None:
            aug = self.transforms_(image=img, mask=mask)
            img, mask = aug['image'], aug['mask']
        img_view = np.copy(img)
        img = img / 255
        # img = self.pre_normalize(img)
        img = torch.tensor(img, dtype=torch.float).permute(2, 0, 1)  # HWC -> CHW
        mask_view = np.copy(mask)
        # Binarise the 8-bit mask at 127.
        mask = np.where(mask < 127, 0, 1).astype(np.int16)
        target = torch.tensor(mask, dtype=torch.long)
        sample = {'x': img, 'y': target, 'img_view': img_view, 'mask_view': mask_view}
        return sample

# Test split, batched in file order (DataLoader does not shuffle by default).
test_dataset = TestDataset(dataframe=test_df, transforms_=test_transforms)
test_loader = DataLoader(dataset=test_dataset, batch_size=4)


In [None]:
# Load the best weights saved by the training loop. ``map_location`` lets a GPU-trained
# checkpoint be loaded on a CPU-only machine.
model.load_state_dict(torch.load('checkpoints/best.pth', map_location=device))
model.to(device)

def get_metrics(model, dataloader, threshold):
    """Evaluate ``model`` on ``dataloader`` at a given probability threshold.

    For each batch, the sigmoid of the model output is binarised at ``threshold`` by
    ``smp.metrics.get_stats``. IoU, precision, F1, recall and accuracy are then
    micro-averaged over the batch's pixels, and Dice is computed as
    2TP / (2TP + FP + FN). The per-batch values are averaged over batches and rounded to
    3 decimals.

    Args:
        model: Segmentation model; its output is squeezed on dim 1 (presumably one logit
            channel per pixel).
        dataloader: Yields dicts with ``'x'`` (inputs) and ``'y'`` (targets).
        threshold: Probability threshold used to binarise predictions.

    Returns:
        dict with keys ``'iou'``, ``'pre'`` (precision), ``'fi'`` (F1 score), ``'re'``
        (recall), ``'acc'`` (accuracy) and ``'dice'``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    # Key order matches the returned dict.
    metric_fns = {
        "iou": smp.metrics.iou_score,
        "pre": smp.metrics.precision,
        "fi": smp.metrics.f1_score,
        "re": smp.metrics.recall,
        "acc": smp.metrics.accuracy,
    }
    totals = dict.fromkeys([*metric_fns, "dice"], 0.0)
    batches = 0

    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            x, y = batch['x'].to(device), batch['y'].to(device)
            probs = torch.sigmoid(model(x).squeeze(dim=1))
            y = y.round().long()

            tp, fp, fn, tn = smp.metrics.get_stats(probs, y, mode='binary', threshold=threshold)
            for key, metric in metric_fns.items():
                totals[key] += metric(tp, fp, fn, tn, reduction="micro").item()

            # Dice = 2TP / (2TP + FP + FN). When neither prediction nor target has any
            # foreground this is 0/0; count it as a perfect match (1.0) instead of letting
            # NaN poison the average.
            tp_sum = tp.sum().item()
            denom = 2 * tp_sum + fp.sum().item() + fn.sum().item()
            totals["dice"] += (2 * tp_sum / denom) if denom else 1.0
            batches += 1

    if batches == 0:
        raise ValueError("get_metrics: dataloader yielded no batches")
    return {key: round(total / batches, 3) for key, total in totals.items()}

# Sweep the probability threshold and report every metric at each value.
threshold_list = [0.3, 0.4, 0.5, 0.6, 0.7]
for threshold in threshold_list:
    sample = get_metrics(model, test_loader, threshold)
    # Fields are separated by five spaces.
    print(
        f"Threshold: {threshold:.2f}     "
        f"IoU Score: {sample['iou']:.3f}     "
        f"Precision: {sample['pre']:.3f}     "
        f"F1 Score: {sample['fi']:.3f}     "
        f"Recall: {sample['re']:.3f}     "
        f"Accuracy: {sample['acc']:.3f}     "
        f"Dice Score: {sample['dice']:.3f}"
    )


In [None]:
# Show a few random test samples: input image, thresholded prediction and ground truth.
model.load_state_dict(torch.load("checkpoints/best.pth", map_location=device))
model.to(device)
model.eval()  # inference mode (e.g. BatchNorm uses its running statistics)
show_imgs = 2
random_list = np.random.choice(len(test_dataset), show_imgs, replace=False)

for idx in random_list:
    sample = test_dataset[idx]
    # Use the configured device rather than a hard-coded 'cuda' so this also runs on CPU;
    # no gradients are needed for visualisation.
    with torch.no_grad():
        logits = model(sample['x'].to(device, dtype=torch.float32).unsqueeze(0))
    pred = torch.sigmoid(logits).squeeze(0).squeeze(0).cpu().numpy()
    pred = np.where(pred < 0.5, 0, 1).astype(np.int16)
    pred_img = Image.fromarray(np.uint8(pred), 'L')

    img_view = Image.fromarray(sample['img_view'], 'RGB')
    mask_view = Image.fromarray(sample['mask_view'], 'L')

    fig, axarr = plt.subplots(1, 3)
    for ax, image, title in zip(axarr, (img_view, pred_img, mask_view), ('Input', 'pred', 'gt')):
        ax.imshow(image)
        ax.set_title(title)
    plt.show()

In [None]:
#GradCam visualisation
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms

# Step 1: Load the trained weights into the model defined above. The training loop saves
# them to 'checkpoints/best.pth'; ``map_location`` allows loading a GPU-trained checkpoint
# on a CPU-only machine.
model.load_state_dict(torch.load('checkpoints/best.pth', map_location=device))
model.eval()

# Step 2: Load and preprocess the input image
def load_image(image_path):
    """Load an image file as a normalised ``(1, 3, 512, 512)`` float tensor.

    The image is converted to RGB, resized to 512x512, scaled to [0, 1] by ``ToTensor`` and
    normalised with the ImageNet mean/std. A leading batch dimension is added.

    NOTE(review): TestDataset only divides by 255 without ImageNet normalisation; confirm
    which preprocessing the model was trained with.
    """
    pipeline = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    rgb_image = Image.open(image_path).convert("RGB")
    return pipeline(rgb_image).unsqueeze(0)

# Load your input image (replace with your actual image path)
input_image_path = 'image_path'  # Specify the path to your image
input_tensor = load_image(input_image_path)
# Move the input to the model's device (chosen when the model was built) instead of
# hard-coding 'cuda', so the cell also works on CPU.
input_tensor = input_tensor.to(device)

# Step 3: Set up the hook and perform Grad-CAM
def get_activation(layer):
    """Return a forward hook that appends each module output to the global ``activation`` list.

    ``layer`` is accepted for call-site readability but is not used by the hook.
    """
    # ``list.append`` returns None, so the hook leaves the module's output unchanged.
    return lambda module, inputs, output: activation.append(output)

# Choose the layer whose activations are visualised.
# NOTE(review): ``model.decoder`` replaces the undefined placeholder ``_layer_name_``; point
# this at any sub-module of ``model``. The code below assumes the layer returns a single
# (N, C, H, W) tensor (TODO confirm for the chosen layer).
target_layer = model.decoder

activation = []
hook = target_layer.register_forward_hook(get_activation(target_layer))
try:
    with torch.no_grad():
        output = model(input_tensor)
finally:
    # Always detach the hook, even if the forward pass fails, so it cannot fire on later calls.
    hook.remove()

activation_map = activation[0]

# This is a channel-mean activation map, not gradient-weighted Grad-CAM: no gradients are used.
activation_map = activation_map.mean(dim=1, keepdim=True)  # [N, C, H, W] -> [N, 1, H, W]
activation_map = F.relu(activation_map)

# Min-max normalise to [0, 1]; the epsilon avoids 0/0 = NaN for a constant map.
value_range = activation_map.max() - activation_map.min()
activation_map = (activation_map - activation_map.min()) / (value_range + 1e-8)
activation_map = activation_map.squeeze().cpu().numpy()

plt.imshow(activation_map, cmap='inferno')
plt.axis('off')
plt.title('Channel-mean activation map')
plt.show()

**Acknowledgement:** Parts of the code in this notebook were adapted from the Segmentation Models PyTorch repository by Pavel Iakubovskii (Iakubovskii, 2019): [https://github.com/qubvel/segmentation_models.pytorch](https://github.com/qubvel/segmentation_models.pytorch).