[![Fixel Algorithms](https://fixelalgorithms.co/images/CCExt.png)](https://fixelalgorithms.gitlab.io)

# Deep Learning Methods

## Deep Learning - Image to Image - Image Segmentation with U-Net

> Notebook by:
> - Royi Avital RoyiAvital@fixelalgorithms.com

## Revision History

| Version | Date       | User        |Content / Changes                                                   |
|---------|------------|-------------|--------------------------------------------------------------------|
| 1.0.002 | 02/02/2026 | Royi Avital | Expanded the information on the Inverted Residual Block            |
| 1.0.001 | 01/02/2026 | Royi Avital | Simplified the classification head                                 |
| 1.0.000 | 21/01/2026 | Royi Avital | First version                                                      |

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/FixelAlgorithmsTeam/FixelCourses/blob/master/AIProgram/2024_02/0099DeepLearningObjectDetection.ipynb)

In [None]:
# Import Packages

# General Tools
import numpy as np
import scipy as sp
import pandas as pd

# Image Processing and Computer Vision
import skimage as ski

# Machine Learning
from sklearn.model_selection import train_test_split

# Deep Learning
import torch
import torch.nn            as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter

import torchinfo

from torchmetrics.functional.segmentation import mean_iou
from torchmetrics.functional.classification import multiclass_f1_score

import torchvista

import torchvision
from torchvision.io import decode_image
from torchvision.transforms import v2 as TorchVisionTrns

# Miscellaneous
import os
from platform import python_version
import random
import shutil
from zipfile import ZipFile

# Typing
from typing import Callable, Dict, Generator, List, Optional, Self, Set, Tuple, Union
from numpy.typing import NDArray
from torch import Tensor

# Visualization
import matplotlib.pyplot as plt

# Jupyter
from IPython import get_ipython

## Notations

* <font color='red'>(**?**)</font> Question to answer interactively.
* <font color='blue'>(**!**)</font> Simple task to add code for the notebook.
* <font color='green'>(**@**)</font> Optional / Extra self practice.
* <font color='brown'>(**#**)</font> Note / Useful resource / Food for thought.

Code Notations:

```python
someVar    = 2; #<! Notation for a variable
vVector    = np.random.rand(4) #<! Notation for 1D array
mMatrix    = np.random.rand(4, 3) #<! Notation for 2D array
tTensor    = np.random.rand(4, 3, 2, 3) #<! Notation for nD array (Tensor)
tuTuple    = (1, 2, 3) #<! Notation for a tuple
lList      = [1, 2, 3] #<! Notation for a list
dDict      = {1: 3, 2: 2, 3: 1} #<! Notation for a dictionary
oObj       = MyClass() #<! Notation for an object
dfData     = pd.DataFrame() #<! Notation for a data frame
dsData     = pd.Series() #<! Notation for a series
hObj       = plt.Axes() #<! Notation for an object / handler / function handler
```

### Code Exercise

 - Single line fill

```python
valToFill = ???
```

 - Multi Line to Fill (At least one)

```python
# You need to start writing
?????
```

 - Section to Fill

```python
#===========================Fill This===========================#
# 1. Explanation about what to do.
# !! Remarks to follow / take under consideration.
mX = ???

?????
#===============================================================#
```

In [None]:
# Configuration
# %matplotlib inline

seedNum = 512
np.random.seed(seedNum)
random.seed(seedNum)

# Matplotlib default color palette
lMatPltLibclr = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
# sns.set_theme() #<! Apply SeaBorn theme

runInGoogleColab = 'google.colab' in str(get_ipython())

# Improve performance by benchmarking
torch.backends.cudnn.benchmark = True

# Reproducibility (Per PyTorch Version on the same device)
# torch.manual_seed(seedNum)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark     = False #<! Makes things slower

In [None]:
# Constants

FIG_SIZE_DEF    = (8, 8)
ELM_SIZE_DEF    = 50
CLASS_COLOR     = ('b', 'r')
EDGE_COLOR      = 'k'
MARKER_SIZE_DEF = 10
LINE_WIDTH_DEF  = 2

PROJECT_NAME     = 'FixelCourses'
DATA_FOLDER_NAME = 'DataSets'
BASE_FOLDER_PATH = os.getcwd()[:(len(os.getcwd()) - (os.getcwd()[::-1].lower().find(PROJECT_NAME.lower()[::-1])))]
DATA_FOLDER_PATH = os.path.join(BASE_FOLDER_PATH, DATA_FOLDER_NAME)

TENSOR_BOARD_BASE   = 'TB'

In [None]:
# Download Auxiliary Modules for Google Colab
if runInGoogleColab:
    !wget https://raw.githubusercontent.com/FixelAlgorithmsTeam/FixelCourses/master/AIProgram/2024_02/DataManipulation.py
    !wget https://raw.githubusercontent.com/FixelAlgorithmsTeam/FixelCourses/master/AIProgram/2024_02/DataVisualization.py
    !wget https://raw.githubusercontent.com/FixelAlgorithmsTeam/FixelCourses/master/AIProgram/2024_02/DeepLearningPyTorch.py

In [None]:
# Courses Packages

from DataManipulation import DownloadKaggleDataset
from DataVisualization import PlotLabelsHistogram
from DeepLearningPyTorch import GenDataLoaders, GetBatch, TrainModel

In [None]:
# General Auxiliary Functions

def TensorImageNumpy( tZ: Tensor ) -> NDArray:
    """
    Converts a PyTorch Tensor to a Numpy Array.
    """
    mZ = tZ.squeeze()
    mX = mZ.detach().cpu().numpy()

    return mX

def TensorImgNumpy( tI: Tensor ) -> NDArray:
    """
    Converts a PyTorch Tensor Image to a Numpy Array Image.
    """
    
    mI = TensorImageNumpy(tI.permute(1, 2, 0))

    return mI

## Image to Image Models

_Image to Image_ models (Also called `Pix2Pix` / _Image to Image Translation_) are models which transform the input image into an output image.  
There are many applications of such models:

 - Feature Extraction    
   Extract edges, corners, Mask, etc...
 - Color Adjustment  
   Apply RGB to Gray / Gray to RGB transformations.  
   Adjust White Balance / Tonal Curve.
 - Styling  
   Style transfer, style application.
 - Modality Transformation  
   RGB to IR / RGB to SAR and vice versa.  
   Image to Map.
 - Deconvolution  
   Super Resolution, Deblurring, Denoising.

The applications drove many developments in the Deep Learning field:

 - Architectures  
   U-Net, ViT.
 - Training Approach / Loss  
   Variational Auto Encoder (VAE), Generative Adversarial Network (GAN), Diffusion Models.

### Image Segmentation

_Image Segmentation_ is one of the most popular applications in the _Image to Image_ paradigm.  
It is composed of 3 types of segmentation:

| Type     	| Example                              	| Properties                                                                     |
|----------	|--------------------------------------	|------------------------------------------------------------------------------- |
| Semantic 	| ![](https://i.imgur.com/cnerHbN.png) 	| Classify each pixel by the class it represents                                 |
| Instance 	| ![](https://i.imgur.com/qbAoMk3.png) 	| Classify pixels of countable objects, each object by its own label             |
| Panoptic 	| ![](https://i.imgur.com/J99Rzpu.png) 	| Classify instances of countable objects and pixels of uncountable objects 	 |
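
The difference between the semantic and the instance representations can be illustrated on a toy mask. The sketch below is an illustration only: the toy array and the use of `scipy.ndimage.label()` (Connected components) are assumptions and not part of the notebook's pipeline.

```python
import numpy as np
from scipy import ndimage

# Toy semantic mask: 0 = background, 1 = a single object class (Two separate objects)
mSemantic = np.array([
    [0, 1, 1, 0, 0, 0],
    [0, 1, 1, 0, 0, 0],
    [0, 0, 0, 0, 1, 1],
    [0, 0, 0, 0, 1, 1],
])

# Instance mask: Each connected component of the class gets its own label
mInstance, numObjects = ndimage.label(mSemantic == 1)

print(f'Semantic labels: {np.unique(mSemantic)}')
print(f'Instance labels: {np.unique(mInstance)}, Number of objects: {numObjects}')
```

In the semantic mask both objects share the label of their class, while in the instance mask each object is countable by its own label.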

This notebook demonstrates:
 - Working with real world data.
 - Augmentation of images and masks.
 - Building a model for _Image Segmentation_ based on U-Net.
 - Training a model with a composed objective.
 - Loading optimal weights for inference.

</br>

* <font color='brown'>(**#**)</font> The U-Net model was originally created in the context of medical imaging: [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597).

In [None]:
# Parameters

# Data
kggleUser       = 'girish17019'
kaggleDataset   = 'mobile-phone-defect-segmentation-dataset'
tmpFileName     = 'TMP.zip'
dCls            = {0: 'None', 1: 'Scratch', 2: 'Stain', 3: 'Oil'} #<! Defect Classes
lCls            = list(dCls.keys())
numCls          = len(lCls) #<! Number of classes
tuImgSize       = (320, 576) #<! Multiples of 32 and similar aspect ratio to original images (1080, 1920)

# Model
numFiltersBase = 16
weightSeg      = 3.0
weightCls      = 1.0
segThr         = 0.5

# Training
trainSampleRatio = 0.9 #<! Ratio of training images to total images
valSampleRatio   = 1 - trainSampleRatio #<! Ratio of validation images to total images

batchSize   = 8
numWorkers  = 0 #<! Number of workers
numEpochs   = 35

# Optimizer
ηOpt        = 1e-4        #<! Optimizer learning rate (Has no effect if using scheduler)
tuβ         = (0.9, 0.99) #<! Betas for ADAM like Optimizer
weightDecay = 5e-5        #<! Optimizer weight decay
ηSch        = 7.5e-5      #<! Scheduler learning rate

# Visualization
numImg = 3

## Generate / Load Data

The data is the [Mobile Phone Screen Surface Defect Segmentation Dataset (MSD)](https://github.com/jianzhang96/MSD).  

* <font color='brown'>(**#**)</font> The data is downloaded from the Kaggle version of the dataset: [Kaggle - Mobile Phone Screen Surface Defect Segmentation Dataset](https://www.kaggle.com/datasets/girish17019/mobile-phone-defect-segmentation-dataset).

In [None]:
# Download and Parse the Dataset

DownloadKaggleDataset(kggleUser, kaggleDataset, tmpFileName)

In [None]:
# Extract Files
# Will create:
# - `FixelCourses/DataSets/MSD/Images` - Contains all images.
# - `FixelCourses/DataSets/MSD/Masks` - Contains all masks.
# Files are: `<Class>_<ImgIdxCls>.png`

datasetFolderPath     = os.path.join(DATA_FOLDER_PATH, 'MSD')
imgDatasetFolderPath  = os.path.join(datasetFolderPath, 'Images')
maskDatasetFolderPath = os.path.join(datasetFolderPath, 'Masks')

# Delete existing folders if any
if os.path.exists(datasetFolderPath):
    shutil.rmtree(datasetFolderPath)

os.makedirs(imgDatasetFolderPath, exist_ok = True)
os.makedirs(maskDatasetFolderPath, exist_ok = True)

with ZipFile(tmpFileName, 'r') as zipFile:
    lFiles = zipFile.namelist()
    for file in lFiles:
        fileName = os.path.basename(file)
        # Make the class explicit in the name
        fileName = fileName.replace('Scr', 'Scratch')
        fileName = fileName.replace('Sta', 'Stain')
        if file.startswith('oil/') or file.startswith('scratch/') or file.startswith('stain/'):
            # Using `extract()` keeps the folder structure
            with open(os.path.join(imgDatasetFolderPath, fileName), 'wb') as hFile:
                hFile.write(zipFile.read(file))
        elif file.startswith('good/'):
            # Do not have masks, create an empty mask
            fileName = 'None_' + fileName
            with open(os.path.join(imgDatasetFolderPath, fileName), 'wb') as hFile:
                hFile.write(zipFile.read(file))
            
            mI = ski.io.imread(os.path.join(imgDatasetFolderPath, fileName))
            mI = np.zeros_like(mI)
            ski.io.imsave(os.path.join(maskDatasetFolderPath, fileName), mI, check_contrast = False)
        elif file.startswith('ground_truth_1/') or file.startswith('ground_truth_2/'):
            with open(os.path.join(maskDatasetFolderPath, fileName), 'wb') as hFile:
                hFile.write(zipFile.read(file))

* <font color='red'>(**?**)</font> Go through files using the OS's image viewer. Specifically the mask images. What can you say about the classes per image?

<!-- Each image mask contains only a single class. Hence predicting the mask class can be done in a global manner and not in a per pixel manner. -->

### DataSet

Generate a `DataSet` class as a loader of the data.

* <font color='brown'>(**#**)</font> Since each image contains a single class, the mask can be represented as a binary mask for the defect area and a global class label by a classification head.
* <font color='brown'>(**#**)</font> There are images with no defects. Hence a `None` class should be added.
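
The decomposition of the mask into a binary support and a global label can be sketched as follows. The toy mask is hypothetical; the point is only that, under the single class per image assumption, the pair (Binary mask, Class label) recovers the class index mask exactly.

```python
import torch

# Hypothetical class index mask (H x W) of an image with a single defect class (2)
mM = torch.tensor([
    [0, 0, 0, 0],
    [0, 2, 2, 0],
    [0, 2, 0, 0],
])

# Decomposition: Binary support of the defect and a global class label
mB     = (mM > 0).to(torch.long) #<! Binary mask (H x W)
clsLbl = int(mM.max())           #<! Global label (0 if there is no defect)

# Composition: The class index mask is recovered from the 2 parts
mR = clsLbl * mB

print(f'Class label: {clsLbl}, Recovered: {torch.equal(mR, mM)}')
```

An image with no defects yields an all zeros binary mask and the label `0`, which matches the `None` class.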

In [None]:
# The Dataset Class

class MSDDataset(Dataset):
    def __init__( self, imgFolderPath: str, maskFolderPath: str, dCls: Dict[int, str], /, *, hTrns: Optional[Callable] = None, lImgFormats: List = ['jpg', 'jpeg', 'png'] ) -> None:
        """
        Mobile Phone Defect Segmentation Dataset.

        Parameters
        ----------
        imgFolderPath : str
            Path to the folder containing images.
        maskFolderPath : str
            Path to the folder containing masks.
        dCls : Dict[int, str]
            Maps the class index to the class name (The prefix of the file name).
        hTrns : Optional[Callable], optional
            Transform to be applied on a sample, by default None. Must be TorchVision v2 transforms compatible.
        lImgFormats : List, optional
            Image file extensions (Lower case) to include, by default ['jpg', 'jpeg', 'png'].
        """
        super().__init__()

        lImgFiles = os.listdir(imgFolderPath)
        lImgFiles = [ff for ff in lImgFiles if (ff.split('.')[-1].lower() in lImgFormats) and (os.path.isfile(os.path.join(maskFolderPath, ff.split('.')[0] + '.png')))]
        lMskFiles = [ff.split('.')[0] + '.png' for ff in lImgFiles] 
        numFiles  = len(lImgFiles)

        dClsIdx = {v: k for k, v in dCls.items()}

        self._imgFolderPath  = imgFolderPath
        self._maskFolderPath = maskFolderPath
        self._dCls           = dCls
        self._dClsIdx        = dClsIdx
        self._hTrns          = hTrns

        # Must be after setting `self._dCls`
        lImgCls   = [self._ParseCls(ff) for ff in lImgFiles]

        self._lImgFiles = lImgFiles
        self._lMskFiles = lMskFiles
        self._lImgCls   = lImgCls
        self._numFiles  = numFiles
    
    def __len__( self ) -> int:
        """
        Returns the number of samples in the dataset.
        """
        
        return self._numFiles
    
    def __getitem__( self, idx: int ) -> Tuple[Tensor, Tensor]:
        """
        Returns the sample at index `idx`.

        Parameters
        ----------
        idx : int
            Index of the sample to be fetched.

        Returns
        -------
        Tuple[Tensor, Tensor]
            A tuple containing the image tensor and the class index mask tensor.
        """
        imgFile = self._lImgFiles[idx]
        mskFile = self._lMskFiles[idx]

        imgPath = os.path.join(self._imgFolderPath, imgFile)
        mskPath = os.path.join(self._maskFolderPath, mskFile)

        # TorchVision's `decode_image()` returns a tensor (C x H x W)
        tI = decode_image(imgPath, mode = 'RGB') #<! Guarantees 3 channels
        tM = decode_image(mskPath, mode = 'RGB') #<! 3 channels, reduced to a single channel below
        clsLbl = self._lImgCls[idx]

        # Binary Mask where defect area is 1, else 0
        tB = (tM.sum(dim = 0) > 0).to(torch.long) #<! (H x W)
        tM = clsLbl * tB #<! (H x W)

        tI = torchvision.tv_tensors.Image(tI, dtype = torch.uint8)
        tM = torchvision.tv_tensors.Mask(tM, dtype = torch.long)

        if self._hTrns:
            tI, tM = self._hTrns(tI, tM)

        return tI, tM
    
    def _ParseCls( self, fileName: str ) -> int:
        """
        Parses the class label from the file name.

        Parameters
        ----------
        fileName : str
            File name of the image.

        Returns
        -------
        int
            Class label.
        """
        clsStr = fileName.split('_')[0]
        clsLbl = self._dClsIdx[clsStr]

        return clsLbl
    
    def SetTransforms( self, hTrns: Callable ) -> None:
        """
        Sets the transforms to be applied on a sample.

        Parameters
        ----------
        hTrns : Callable
            Transform to be applied on a sample. Must be TorchVision v2 transforms compatible.
        """
        self._hTrns = hTrns
    
    def GetClasses( self ) -> Dict[int, str]:
        """
        Returns the class dictionary.

        Returns
        -------
        Dict[int, str]
            Class dictionary.
        """
        
        return self._dCls.copy()
    
    def GetLabels( self ) -> List[int]:
        """
        Returns the list of labels.

        Returns
        -------
        List[int]
            List of labels.
        """
        
        return self._lImgCls.copy()

In [None]:
# The Dataset

dsData     = MSDDataset(imgDatasetFolderPath, maskDatasetFolderPath, dCls, hTrns = None)
numSamples = len(dsData)

print(f'Number of samples in the dataset: {numSamples}')

### Plot the Data

In [None]:
# Plot Random Samples from the Dataset

rndIdx = random.randint(0, numSamples - 1)
tI, tM = dsData[rndIdx]

tI = TensorImgNumpy(tI)
tM = TensorImageNumpy(tM)

hF, vHa = plt.subplots(1, 2, figsize = (12, 6))
vHa = vHa.flat

hA = vHa[0]
hA.imshow(tI)
hA.set_title('Image')
hA.axis('off');

hA = vHa[1]
hA.imshow(tM, vmin = 0, vmax = numCls - 1, cmap = 'jet')
hA.set_title(f'Mask, Class: {dCls[tM.max()]}')
hA.axis('off');

In [None]:
# Plot Labels Histogram

lLablels = dsData.GetLabels()

hF, hA = plt.subplots(figsize = FIG_SIZE_DEF)
hA = PlotLabelsHistogram(lLablels, hA, lClass = list(dCls.values()))

* <font color='brown'>(**#**)</font> In most cases _Image Segmentation_ deals with imbalanced classification.  
  One of the effective ways to deal with it is using the Focal Loss in the binary case ([Extension to Multi Class](https://discuss.pytorch.org/t/61289)) or approaches like [Class-Balanced Loss Based on Effective Number of Samples](https://arxiv.org/abs/1901.05555).
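
A minimal sketch of the binary Focal Loss on logits is given below. The function name `BinaryFocalLoss()` and the default values `paramGamma = 2.0` and `paramAlpha = 0.25` are assumptions for illustration; this loss is not the one used later in the notebook.

```python
import torch
import torch.nn.functional as F

def BinaryFocalLoss( tZ: torch.Tensor, tY: torch.Tensor, paramGamma: float = 2.0, paramAlpha: float = 0.25 ) -> torch.Tensor:
    """
    Binary Focal Loss on logits `tZ` with targets `tY` in {0, 1} (Hypothetical helper).
    """
    tBce = F.binary_cross_entropy_with_logits(tZ, tY, reduction = 'none') #<! -log(p_t) per element
    tP   = torch.sigmoid(tZ)
    tPt  = tY * tP + (1.0 - tY) * (1.0 - tP)                 #<! Probability of the true class
    tAt  = tY * paramAlpha + (1.0 - tY) * (1.0 - paramAlpha) #<! Class balancing weight

    # Down weights easy samples (p_t close to 1) by the factor (1 - p_t)^γ
    return (tAt * ((1.0 - tPt) ** paramGamma) * tBce).mean()

tZ = torch.tensor([4.0, -4.0, 0.1, -0.1]) #<! Logits: 2 easy samples, 2 hard samples
tY = torch.tensor([1.0, 0.0, 1.0, 0.0])   #<! Targets

lossBce   = F.binary_cross_entropy_with_logits(tZ, tY)
lossFocal = BinaryFocalLoss(tZ, tY)

print(f'BCE Loss: {lossBce.item():0.4f}, Focal Loss: {lossFocal.item():0.4f}')
```

Setting `paramGamma = 0` reduces the loss to a class weighted Binary Cross Entropy. TorchVision also offers `torchvision.ops.sigmoid_focal_loss()` which follows the same formulation.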

In [None]:
# Loader Transform

oTrns = TorchVisionTrns.Compose([
    # TorchVisionTrns.ToImage(),
    TorchVisionTrns.Resize(tuImgSize),
    TorchVisionTrns.ToDtype(torch.float, scale = True),
    TorchVisionTrns.RandomChoice([
            TorchVisionTrns.RandomGrayscale(p = 1.0),
            TorchVisionTrns.RandomHorizontalFlip(p = 1.0),
            TorchVisionTrns.RandomVerticalFlip(p = 1.0),
            TorchVisionTrns.RandomRotation(degrees = 10),
            TorchVisionTrns.RGB(), #<! Identity for RGB Images
        ], p = [0.15, 0.15, 0.15, 0.15, 0.40]),
])

dsData.SetTransforms(oTrns)

* <font color='blue'>(**!**)</font> Add color related augmentation.

In [None]:
# Plot Random Samples from the Dataset

rndIdx = random.randint(0, numSamples - 1)
tI, tM = dsData[rndIdx]

tI = TensorImgNumpy(tI)
tM = TensorImageNumpy(tM)

hF, vHa = plt.subplots(1, 2, figsize = (12, 6))
vHa = vHa.flat

hA = vHa[0]
hA.imshow(tI)
hA.set_title('Image')
hA.axis('off');

hA = vHa[1]
hA.imshow(tM, vmin = 0, vmax = numCls - 1, cmap = 'jet')
hA.set_title(f'Mask, Class: {dCls[tM.max()]}')
hA.axis('off');

In [None]:
# Create Training and Validation Datasets

vIdxTrain, vIdxVal = train_test_split(np.arange(numSamples), test_size = valSampleRatio, train_size = trainSampleRatio, random_state = seedNum, shuffle = True, stratify = lLablels)

dsTrain = torch.utils.data.Subset(dsData, vIdxTrain)
dsVal   = torch.utils.data.Subset(dsData, vIdxVal)

print(f'The training data set contains  : {len(dsTrain):4d} samples.')
print(f'The validation data set contains: {len(dsVal):4d} samples.')


In [None]:
# Data Loader

dlTrain, dlVal = GenDataLoaders(dsTrain, dsVal, batchSize, numWorkers = numWorkers, persWork = False)


In [None]:
# Element of the Data Set / Data Sample

tX, tM = dsTrain[0]

print(f'The features shape: {tX.shape}')
print(f'The features type : {tX.dtype}')
print(f'The labels shape  : {tM.shape}')
print(f'The labels type   : {tM.dtype}')

In [None]:
# Element of the Dataloader

tX, tM = GetBatch(dlTrain)

print(f'The batch of features shape: {tX.shape}')
print(f'The batch of features type : {tX.dtype}')
print(f'The batch of labels shape  : {tM.shape}')
print(f'The batch of labels type   : {tM.dtype}')


## The Model

The U-Net Models rely heavily on the Transposed Convolution operator.



### The Transposed Convolution Operator

The _Transposed Convolution Operator_ can be used as a learned upscaling operator.  
The layer was crucial to achieve high quality upscaling for 1D and 2D signals.  
As opposed to interpolation, the effect is not pre defined but learned from data.

![](https://i.imgur.com/a9KBYKC.png)
<!-- ![](https://i.postimg.cc/Qxjjwrhz/Diagrams-Transposed-Convolution.png) -->

* <font color='brown'>(**#**)</font> The Transposed Convolution operator is given by the [`ConvTranspose2d`](https://docs.pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html) layer.
* <font color='brown'>(**#**)</font> There are 2 other views of the Transposed Convolution operator:
  - If the Convolution Operator can be represented by a matrix $\boldsymbol{W}$ such that: $\boldsymbol{z} = \boldsymbol{W} \boldsymbol{x}$ then the Transposed Convolution is given by $\boldsymbol{x} = \boldsymbol{W}^{\top} \boldsymbol{z}$.
  - The Transposed Convolution can be achieved by 2 steps (Similar to Signal Processing like approach):
    - Upsample (Insert zeros) between the data samples.
    - Apply _convolution_ as a _Low Pass Filter_ (LPF).
* <font color='brown'>(**#**)</font> In depth discussions on convolution layers: [Distill - Augustus Odena, Vincent Dumoulin - Deconvolution and Checkerboard Artifacts](https://distill.pub/2016/deconv-checkerboard/), [StackExchange Cross Validated - The Equivalence of Upsample Layer and Transposed Convolution Layer](https://stats.stackexchange.com/questions/252810).
* <font color='brown'>(**#**)</font> Convolution Layer Visualization: [GitHub - Convolution Arithmetic](https://github.com/vdumoulin/conv_arithmetic), [Vincent Dumoulin, Francesco Visin - A guide to Convolution Arithmetic for Deep Learning](https://arxiv.org/abs/1603.07285).
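
The matrix view above can be checked numerically by the adjoint (Inner product) test: $\langle \boldsymbol{W} \boldsymbol{x}, \boldsymbol{v} \rangle = \langle \boldsymbol{x}, \boldsymbol{W}^{\top} \boldsymbol{v} \rangle$. The sketch below assumes a single channel, a random `3 x 3` kernel and the same stride / padding settings used for the `ConvTranspose2d` layers of the model; the tensor sizes are arbitrary choices.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

tX = torch.randn(1, 1, 8, 8, dtype = torch.float64) #<! Input of the convolution
tW = torch.randn(1, 1, 3, 3, dtype = torch.float64) #<! Kernel shared by both operators

# Convolution with stride 2: (8 x 8) -> (4 x 4)
tZ = F.conv2d(tX, tW, stride = 2, padding = 1)
# Arbitrary tensor in the output space of the convolution
tV = torch.randn_like(tZ)
# Transposed convolution with the same kernel: (4 x 4) -> (8 x 8)
tU = F.conv_transpose2d(tV, tW, stride = 2, padding = 1, output_padding = 1)

# Adjoint test: <W x, v> == <x, W^T v>
innerProd001 = torch.sum(tZ * tV)
innerProd002 = torch.sum(tX * tU)

print(f'Output shape of the transposed convolution: {tuple(tU.shape)}')
print(f'Adjoint test passed: {torch.allclose(innerProd001, innerProd002)}')
```

The transposed convolution is the adjoint of the convolution and not its inverse: it restores the shape of the input, not its values.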

### Inverted Residual Block

The concept of the Inverted Residual Block was introduced in MobileNetV2 ([MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381), see [MobileNet](https://en.wikipedia.org/wiki/MobileNet)) and later adopted by EfficientNet ([EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946)).

While the classic _Residual Block_ has the bottleneck in the middle: Wide -> Narrow -> Wide, the _Inverted Residual Block_ starts narrow: Narrow -> Wide -> Narrow.

![](https://i.imgur.com/g1IhUaH.png)
<!-- ![](https://i.postimg.cc/cH1zXvKx/Diagrams-Inverted-Residual-Block.png) -->

The _Inverted Residual Block_ also uses Depthwise Convolution to reduce the number of parameters and removes the last activation layer (Linear Bottleneck) to avoid losing information in the narrow representation.
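
The saving in the number of parameters by the Depthwise Convolution can be verified by counting. The sketch below compares a standard `3 x 3` convolution with its depthwise separable counterpart; the channel counts (`64` and `128`) and the helper `CountParams()` are hypothetical choices for illustration.

```python
import torch.nn as nn

def CountParams( oLayer: nn.Module ) -> int:
    # Number of parameters of the layer (Hypothetical helper)
    return sum(p.numel() for p in oLayer.parameters())

numChnlIn, numChnlOut = 64, 128 #<! Hypothetical channel counts

# Standard convolution: Spatial filtering and channel mixing at once
oConvStd = nn.Conv2d(numChnlIn, numChnlOut, 3, padding = 1, bias = False)
# Depthwise separable convolution: Spatial filtering per channel, then 1x1 channel mixing
oConvSep = nn.Sequential(
    nn.Conv2d(numChnlIn, numChnlIn, 3, padding = 1, groups = numChnlIn, bias = False), #<! Depthwise
    nn.Conv2d(numChnlIn, numChnlOut, 1, bias = False),                                 #<! Pointwise
)

numParamsStd = CountParams(oConvStd) #<! 3 * 3 * 64 * 128
numParamsSep = CountParams(oConvSep) #<! 3 * 3 * 64 + 64 * 128

print(f'Standard: {numParamsStd}, Separable: {numParamsSep}, Ratio: {numParamsStd / numParamsSep:0.2f}')
```

For these sizes the separable version uses `8768` parameters instead of `73728`, which is what makes the wide expansion in the middle of the block affordable.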

In [MobileNet v4](https://arxiv.org/abs/2404.10518) it was generalized by the _Universal Inverted Residual Block_ (UIB):

<img src="https://i.imgur.com/ARDgVh2.png" width=720 >
<!-- <img src="https://i.postimg.cc/sXppQmfP/image.png" width=720 > -->

In [None]:
# Inverted Residual Block

class InvertedResidualBlock(nn.Module):
    """
    Modern block structure: Expand -> Depth Wise Convolution -> Project.
    Includes a skip connection if input and output shapes match.
    """
    def __init__(self, numChnlIn: int, numChnlOut: int, expFctr: int = 4, strideSize: int = 1):
        super().__init__()
        
        self.strideSize = strideSize
        self.enableSkip = (strideSize == 1 and numChnlIn == numChnlOut)
        hiddenDim       = numChnlIn * expFctr

        self.oBlock = nn.Sequential(
            # Expansion of Channels (Using 1x1 Convolution)
            nn.Conv2d(numChnlIn, hiddenDim, 1, bias = False),
            nn.BatchNorm2d(hiddenDim),
            nn.SiLU(),
            
            # Depth Wise Convolution (3x3)
            nn.Conv2d(hiddenDim, hiddenDim, 3, stride = strideSize, padding = 1, groups = hiddenDim, bias = False),
            nn.BatchNorm2d(hiddenDim),
            nn.SiLU(),
            
            # Projection (Using 1x1 Convolution) - Linear bottleneck (No activation at end)
            nn.Conv2d(hiddenDim, numChnlOut, 1, bias = False),
            nn.BatchNorm2d(numChnlOut),
        )

    def forward(self, tX: Tensor) -> Tensor:
        
        if self.enableSkip:
            return tX + self.oBlock(tX)
        else:
            return self.oBlock(tX)

In [None]:
# Depthwise Separable Convolution Block

class DepthwiseSeparableConv(nn.Module):
    """
    The building block of efficient networks.
    Splits a standard convolution into:
    1. Depthwise: Spatial filtering (lightweight)
    2. Pointwise: Channel mixing (1x1 conv)
    """
    def __init__(self, numChnlIn: int, numChnlOut: int, strideSize: int = 1):
        super().__init__()
        
        self.strideSize = strideSize
        
        # Depthwise Convolution (3x3) (Each kernel per channel)
        self.oBlock001 = nn.Sequential(
            nn.Conv2d(numChnlIn, numChnlIn, kernel_size = 3, padding = 1,  stride = strideSize, groups = numChnlIn, bias = False),
            nn.BatchNorm2d(numChnlIn),
            nn.SiLU(), #<! Modern activation (Swish)
        )
        # Pointwise Convolution (1x1) Projection over channels
        self.oBlock002 = nn.Sequential(
            nn.Conv2d(numChnlIn, numChnlOut, kernel_size = 1, bias = False),
            nn.BatchNorm2d(numChnlOut),
            nn.SiLU(), #<! Modern activation (Swish)
        )

    def forward(self, tX: Tensor) -> Tensor:
        
        tX = self.oBlock001(tX)
        tX = self.oBlock002(tX)
        
        return tX

### Optimizing the Model for the Case

Since in the data above there is no case of 2 different non background classes in the same image, the task can be separated into 2 simpler tasks:
 1. Segment the support of the defect.
 2. Classify the type of the defect.

In order to achieve it, the model has 2 heads:
 - Per Pixel Binary Classification Head  
   Per pixel calculates the probability of defect.
 - Global Multi Class Classification Head  
   Estimate the probability of the class of the defect.

Both heads output the _logits_ of the probabilities.
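
At inference the logits of the 2 heads can be composed into a class index mask. The following is a sketch under assumptions: the tensors are random stand-ins for the model outputs, their shapes follow the heads described above, and the threshold value mirrors the `segThr` parameter.

```python
import torch

torch.manual_seed(0)

segThr = 0.5 #<! Probability threshold for the mask (Assumed value)

# Hypothetical model outputs for a batch of 2 images with 4 classes
tO = torch.randn(2, 1, 8, 8) #<! Mask logits (B x 1 x H x W)
tY = torch.randn(2, 4)       #<! Class logits (B x numCls)

# Per pixel head: Sigmoid maps the logits to probabilities, then threshold
tB = (torch.sigmoid(tO) > segThr).to(torch.long).squeeze(1) #<! Binary mask (B x H x W)
# Global head: The arg max of the logits equals the arg max of the SoftMax probabilities
vCls = torch.argmax(tY, dim = 1) #<! Predicted class per image (B,)

# Class index mask: The global class painted on the detected support
tMPred = vCls[:, None, None] * tB #<! (B x H x W)

print(f'Predicted classes: {vCls.tolist()}, Mask shape: {tuple(tMPred.shape)}')
```

With `segThr = 0.5` the threshold on the probability is equivalent to checking the sign of the logit.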

In [None]:
# UNet: Encoder and Decoder Blocks

class µSegmentor(nn.Module):
    def __init__(self, numChnlIn: int, numCls: int, numFiltersBase: int = 32, *, useConvTrns: bool = False):
        super().__init__()

        # Encoder (Downsampling)
        # Standard Feature Extractor block
        self.oFeatExt = nn.Sequential(
            nn.Conv2d(numChnlIn, numFiltersBase, 3, padding = 1, stride = 2, bias = False),
            nn.BatchNorm2d(numFiltersBase),
            nn.SiLU()
        )

        # Encoder Stages (using Inverted Residuals)
        # Note: Use `stride = 2` in the first block of a stage to downsample
        self.oEnc001 = InvertedResidualBlock(numFiltersBase    , numFiltersBase * 2, strideSize = 2) #<! H/4
        self.oEnc002 = InvertedResidualBlock(numFiltersBase * 2, numFiltersBase * 4, strideSize = 2) #<! H/8
        self.oEnc003 = InvertedResidualBlock(numFiltersBase * 4, numFiltersBase * 8, strideSize = 2) #<! H/16

        # Embedding / Bottleneck
        self.oEmbed = InvertedResidualBlock(numFiltersBase * 8, numFiltersBase * 16, strideSize = 1) # H/16

        # Decoder (Upsampling)
        # Use simple Bilinear Upsampling + 1x1 Conv to reduce channels / Transposed Convolution
        
        # Process output of `oEnc002`
        if useConvTrns:
            self.oUp003  = nn.Sequential(
                nn.ConvTranspose2d(numFiltersBase * 16, numFiltersBase * 16, kernel_size = 3, stride = 2, padding = 1, output_padding = 1, bias = False),
                nn.BatchNorm2d(numFiltersBase * 16),
                nn.SiLU(),
            )
        else:
            self.oUp003  = nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False)

        self.oDec003 = nn.Sequential(
            nn.Conv2d(numFiltersBase * 16 + numFiltersBase * 4, numFiltersBase * 8, 1),
            InvertedResidualBlock(numFiltersBase * 8, numFiltersBase * 8),
        )

        # Process output of `oEnc001`
        if useConvTrns:
            self.oUp002  = nn.Sequential(
                nn.ConvTranspose2d(numFiltersBase * 8, numFiltersBase * 8, kernel_size = 3, stride = 2, padding = 1, output_padding = 1, bias = False),
                nn.BatchNorm2d(numFiltersBase * 8),
                nn.SiLU(),
            )
        else:
            self.oUp002 = nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False)
        self.oDec002 = nn.Sequential(
            nn.Conv2d(numFiltersBase * 8 + numFiltersBase * 2, numFiltersBase * 4, 1),
            InvertedResidualBlock(numFiltersBase * 4, numFiltersBase * 4),
        )

        # Process output of `oFeatExt`
        if useConvTrns:
            self.oUp001  = nn.Sequential(
                nn.ConvTranspose2d(numFiltersBase * 4, numFiltersBase * 4, kernel_size = 3, stride = 2, padding = 1, output_padding = 1, bias = False),
                nn.BatchNorm2d(numFiltersBase * 4),
                nn.SiLU(),
            )
        else:
            self.oUp001 = nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False)
        self.oDec001 = nn.Sequential(
            nn.Conv2d(numFiltersBase * 4 + numFiltersBase, numFiltersBase * 2, 1),
            InvertedResidualBlock(numFiltersBase * 2, numFiltersBase * 2),
        )
        
        # Segmentation Mask Head
        # Final upsample to restore original resolution
        if useConvTrns:
            self.oHeadMask = nn.Sequential(
                nn.ConvTranspose2d(numFiltersBase * 2, numFiltersBase * 2, kernel_size = 3, stride = 2, padding = 1, output_padding = 1, bias = False),
                nn.BatchNorm2d(numFiltersBase * 2),
                nn.SiLU(),
                nn.Conv2d(numFiltersBase * 2, 1, 1), #<! Classify any class
            )
        else:
            self.oHeadMask = nn.Sequential(
                nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False),
                # nn.Conv2d(numFiltersBase * 2, numCls, 1), #<! Classify per class
                nn.Conv2d(numFiltersBase * 2, 1, 1), #<! Classify any class
            )

        # Classification Head
        self.oHeadCls = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(numFiltersBase * 16, numCls), #<! Matches the channels of the embedding (`oEmbed`)
        )

    def forward( self, tX: Tensor ) -> Tuple[Tensor, Tensor]:
        # tX: [B, 3, H, W]
        # H and W must be even numbers
        # If H / W are divisible by 16 no interpolation will be needed during skip connections
        
        # Encoder
        tX0 = self.oFeatExt(tX) #<! H/2
        tX1 = self.oEnc001(tX0) #<! H/4
        tX2 = self.oEnc002(tX1) #<! H/8
        tX3 = self.oEnc003(tX2) #<! H/16
        
        # Low Dim Embedding / Bottleneck
        tEm = self.oEmbed(tX3) #<! H/16
        
        # Decoder
        # D3: Upsample Bridge, concat with `oEnc002`
        tD3 = self.oUp003(tEm)
        # Handle potential size mismatch if H / W aren't multiples of 16
        if tD3.shape[2:] != tX2.shape[2:]:
            tD3 = F.interpolate(tD3, size = tX2.shape[2:], mode = 'bilinear', align_corners = False)
        tD3 = torch.cat([tD3, tX2], dim = 1)
        tD3 = self.oDec003(tD3)
        
        # D2: Upsample D3, concat with `oEnc001`
        tD2 = self.oUp002(tD3)
        if tD2.shape[2:] != tX1.shape[2:]:
            tD2 = F.interpolate(tD2, size = tX1.shape[2:], mode = 'bilinear', align_corners = False)
        tD2 = torch.cat([tD2, tX1], dim = 1)
        tD2 = self.oDec002(tD2)
        
        # D1: Upsample D2, concat with `oFeatExt`
        tD1 = self.oUp001(tD2)
        if tD1.shape[2:] != tX0.shape[2:]:
            tD1 = F.interpolate(tD1, size = tX0.shape[2:], mode = 'bilinear', align_corners = False)
        tD1 = torch.cat([tD1, tX0], dim = 1) # Concatenating with H/2 features
        tD1 = self.oDec001(tD1)
        
        # Final Output
        tO = self.oHeadMask(tD1) #<! Segmentation Mask Head
        tY = self.oHeadCls(tEm)  #<! Classification Head
            
        return tO, tY

In [None]:
# Model

oModel = µSegmentor(numChnlIn = tX.shape[1], numCls = numCls, numFiltersBase = numFiltersBase, useConvTrns = False) #<! Use `useConvTrns = False` for better run time

In [None]:
# Run device

runDevice = torch.device('cuda:0' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu')) #<! The 1st CUDA device

In [None]:
# Model Summary

torchinfo.summary(oModel, tX.shape, col_names = ['input_size', 'output_size', 'num_params', 'kernel_size', 'trainable'], device = runDevice, row_settings = ['depth', 'var_names'])

In [None]:
# Model Architecture

torchvista.trace_model(oModel, tX.to(runDevice))

## Train the Model



In [None]:
# Loss Class

class SegmentationClsLoss(nn.Module):
    def __init__( self, weightSeg: float = 1.0, weightCls: float = 1.0 ) -> None:
        """
        Combined Segmentation and Classification Loss.

        Parameters
        ----------
        weightSeg : float, optional
            Weight for the segmentation loss, by default 1.0
        weightCls : float, optional
            Weight for the classification loss, by default 1.0
        """
        super().__init__()
        
        self.weightSeg = weightSeg
        self.weightCls = weightCls
        
        self.oSegLoss = nn.BCEWithLogitsLoss()
        self.oClsLoss = nn.CrossEntropyLoss()
    
    def forward( self, tuZ: Tuple[Tensor, Tensor], tMTgt: Tensor ) -> Tensor:
        """
        Computes the combined loss.

        Parameters
        ----------
        tuZ : Tuple[Tensor, Tensor]
            Model output: predicted mask logits (B x 1 x H x W) and predicted class logits (B x numCls).
        tMTgt : Tensor
            True label mask (B x H x W): 0 for background, the class index on object pixels.

        Returns
        -------
        Tensor
            Combined loss.
        """

        tM   = tuZ[0]
        mCls = tuZ[1]
        tM   = tM.squeeze(1)

        # Extract class from the mask
        vClsTgt = tMTgt.view(tMTgt.shape[0], -1).amax(dim = 1) #<! (B,)
        # Extract the binary mask
        tMTgt   = tMTgt.not_equal(0).to(torch.float32) #<! (B, H, W)
        
        segLoss = self.oSegLoss(tM, tMTgt)
        clsLoss = self.oClsLoss(mCls, vClsTgt)
        
        valLoss = self.weightSeg * segLoss + self.weightCls * clsLoss
        
        return valLoss
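
The loss derives both targets from a single label mask. The sketch below shows this on a toy batch and adds a sanity check: for a model with all zero logits the two terms equal $\ln 2$ and $\ln \left( \text{numCls} \right)$. The mask encoding (0 for background, class index on the object) is inferred from the code above and `numCls = 3` is an assumption of the example.

```python
import math
import torch
import torch.nn as nn

numCls = 3 #<! Assumed number of classes for this toy example

# Toy label masks (B, H, W): 0 is background, k > 0 is the object class
tMTgt = torch.zeros(2, 4, 4, dtype = torch.long)
tMTgt[0, 1:3, 1:3] = 2 #<! Sample 0: object of class 2
tMTgt[1, 0:2, 0:1] = 1 #<! Sample 1: object of class 1

vClsTgt = tMTgt.view(tMTgt.shape[0], -1).amax(dim = 1) #<! (B,), class per image
tBTgt   = tMTgt.not_equal(0).to(torch.float32)         #<! (B, H, W), binary mask

# Uninformed model: all logits are zero
tM   = torch.zeros(2, 4, 4)
mCls = torch.zeros(2, numCls)

segLoss = nn.BCEWithLogitsLoss()(tM, tBTgt)    #<! ln(2)
clsLoss = nn.CrossEntropyLoss()(mCls, vClsTgt) #<! ln(numCls)
```

These values are a useful reference for the first epochs: a training loss which stays near `math.log(2) + math.log(numCls)` (for unit weights) suggests the model has not learned anything yet.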

In [None]:
# Score Class

class SegmentationClsScore(nn.Module):
    def __init__( self, segThr: float, weightSeg: float = 1.0, weightCls: float = 1.0 ) -> None:
        """
        Combined Segmentation (Mean IoU) and Classification (Macro F1) Score.

        """
        super().__init__()
        
        self.segThr    = segThr
        # Score Weights are normalized so the score is in [0, 1]
        self.weightSeg = weightSeg / (weightSeg + weightCls)
        self.weightCls = weightCls / (weightSeg + weightCls)
    
    def forward( self, tuZ: Tuple[Tensor, Tensor], tMTgt: Tensor ) -> Tensor:
        """
        Computes the combined score.

        Parameters
        ----------
        tuZ : Tuple[Tensor, Tensor]
            Model output: predicted mask logits (B x 1 x H x W) and predicted class logits (B x numCls).
        tMTgt : Tensor
            True label mask (B x H x W): 0 for background, the class index on object pixels.

        Returns
        -------
        Tensor
            Weighted sum of the segmentation score (mean IoU) and the classification score (macro F1).
        """

        tM   = tuZ[0]
        mCls = tuZ[1]
        tB   = (torch.sigmoid(tM.squeeze(1)) > self.segThr).to(torch.long) #<! (B, H, W)
        vCls = mCls.argmax(dim = 1) #<! (B,), No need for SoftMax as argmax is invariant to monotonic transforms

        # Extract class from the mask
        vClsTgt = tMTgt.view(tMTgt.shape[0], -1).amax(dim = 1) #<! (B,)
        # Extract the binary mask
        tBTgt   = tMTgt.not_equal(0).to(torch.long) #<! (B, H, W)

        segScore = mean_iou(tB, tBTgt, num_classes = 2, include_background = True, per_class = False, input_format = 'index').mean() #<! Mean IoU over the batch
        clsScore = multiclass_f1_score(vCls, vClsTgt, num_classes = mCls.shape[1], average = 'macro', top_k = 1, multidim_average = 'global', ignore_index = None, validate_args = False, zero_division = 0) #<! Setting Macro so each class has the same weight, number of classes is taken from the logits
        
        valScore = self.weightSeg * segScore + self.weightCls * clsScore
        
        return valScore
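
To make the segmentation score concrete, the sketch below computes a two class (background and foreground) mean IoU by hand. It is meant to mirror the intent of the `mean_iou()` call with `include_background = True`. The helper `BinaryMeanIoU()` is hypothetical, and its handling of a class missing from both masks (IoU of 0) is an assumption which may differ from `torchmetrics`.

```python
import torch

def BinaryMeanIoU( tB: torch.Tensor, tBTgt: torch.Tensor ) -> torch.Tensor:
    # tB, tBTgt: (B, H, W) binary masks with values in {0, 1}
    lIoU = []
    for valCls in (0, 1):
        tP = (tB == valCls)
        tT = (tBTgt == valCls)
        vInter = (tP & tT).sum(dim = (1, 2)).to(torch.float32)
        vUnion = (tP | tT).sum(dim = (1, 2)).to(torch.float32)
        lIoU.append(vInter / vUnion.clamp(min = 1)) #<! Avoid division by zero
    return torch.stack(lIoU, dim = 1).mean(dim = 1) #<! (B,), mean over the 2 classes

tB    = torch.zeros(1, 4, 4, dtype = torch.long)
tBTgt = torch.zeros(1, 4, 4, dtype = torch.long)
tB[0, 0:2, 0:2]    = 1 #<! Prediction: 4 foreground pixels
tBTgt[0, 0:2, 0:3] = 1 #<! Target: 6 foreground pixels

vIoU = BinaryMeanIoU(tB, tBTgt) #<! Foreground: 4 / 6, Background: 10 / 12, Mean: 0.75
```

Since the background is included, a prediction of an empty mask still gets a non zero score on images with small objects.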

In [None]:
# Loss and Score

hL = SegmentationClsLoss(weightSeg, weightCls)
hS = SegmentationClsScore(segThr = segThr, weightSeg = weightSeg, weightCls = weightCls)
hL = hL.to(runDevice)
hS = hS.to(runDevice)

In [None]:
# Optimizer Related

oOpt = torch.optim.AdamW(oModel.parameters(), lr = ηOpt, betas = tuβ, weight_decay = weightDecay) #<! Define optimizer
oSch = torch.optim.lr_scheduler.OneCycleLR(oOpt, max_lr = ηSch, total_steps = numEpochs)
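
With `total_steps = numEpochs` the scheduler is expected to be stepped once per epoch. The sketch below traces the learning rate of `OneCycleLR` under this assumption. The tiny model and the hyperparameter values are placeholders chosen for illustration only.

```python
import torch

numEpochs = 10 #<! Placeholder value

oLin = torch.nn.Linear(2, 1) #<! Placeholder model
oOpt = torch.optim.AdamW(oLin.parameters(), lr = 1e-3)
oSch = torch.optim.lr_scheduler.OneCycleLR(oOpt, max_lr = 1e-2, total_steps = numEpochs)

lLr = []
for epochIdx in range(numEpochs):
    lLr.append(oSch.get_last_lr()[0]) #<! Learning rate used during this epoch
    valLoss = oLin(torch.randn(4, 2)).sum()
    oOpt.zero_grad()
    valLoss.backward()
    oOpt.step()
    oSch.step() #<! A single step per epoch
```

The rate starts below `max_lr`, rises to `max_lr` during the warm up phase and then decays. The `lr` given to the optimizer is overridden by the scheduler.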

In [None]:
# Training Loop

oModel = oModel.to(runDevice)
_, lTrainLoss, lTrainScore, lValLoss, lValScore, lLearnRate = TrainModel(oModel, dlTrain, dlVal, oOpt, numEpochs, hL, hS, oSch = oSch)

In [None]:
# Plot Training Phase

hF, vHa = plt.subplots(nrows = 1, ncols = 3, figsize = (18, 5))
vHa = np.ravel(vHa)

hA = vHa[0]
hA.plot(lTrainLoss, lw = 2, label = 'Train')
hA.plot(lValLoss, lw = 2, label = 'Validation')
hA.set_title('Loss')
hA.set_xlabel('Epoch')
hA.set_ylabel('Loss')
hA.legend()

hA = vHa[1]
hA.plot(lTrainScore, lw = 2, label = 'Train')
hA.plot(lValScore, lw = 2, label = 'Validation')
hA.set_title('Score')
hA.set_xlabel('Epoch')
hA.set_ylabel('Score')
hA.legend()

hA = vHa[2]
hA.plot(lLearnRate, lw = 2)
hA.set_title('Learn Rate Scheduler')
hA.set_xlabel('Epoch')
hA.set_ylabel('Learn Rate');

In [None]:
# Load the Best Model

oModel.load_state_dict(torch.load('BestModel.pt')['Model'])

In [None]:
# Process Samples through the Trained Model

rndIdx = random.randint(0, len(dsVal) - 1)
tX, tY = dsVal[rndIdx]

oModel.eval()
with torch.inference_mode():
    tXb = tX.unsqueeze(0).to(runDevice) #<! Add batch dimension
    tuZb = oModel(tXb)
    tMb = tuZb[0].squeeze(0).cpu() #<! Remove batch dimension
    tYb = tuZb[1].squeeze(0).cpu()

    tMHat = torch.sigmoid(tMb) > segThr
    tMHat = tMHat.to(torch.long)
    valY  = tYb.argmax(dim = 0).item()

tI    = TensorImgNumpy(tX)
tM    = TensorImageNumpy(tY)
tMHat = TensorImageNumpy(tMHat)

hF, vHa = plt.subplots(1, 3, figsize = (12, 6))
vHa = vHa.flat

hA = vHa[0]
hA.imshow(tI)
hA.set_title('Image')
hA.axis('off');

hA = vHa[1]
hA.imshow(tM, vmin = 0, vmax = numCls - 1, cmap = 'jet')
hA.set_title(f'Mask, Class: {dCls[tM.max()]}')
hA.axis('off');

hA = vHa[2]
hA.imshow(tMHat, vmin = 0, vmax = 1, cmap = 'gray')
hA.set_title(f'Predicted Mask, Class: {dCls[valY]}')
hA.axis('off');

* <font color='red'>(**?**)</font> Which post processing operations could be applied to improve the results?
* <font color='green'>(**@**)</font> Add full analysis of the model performance on the validation set: Confusion Matrix for the classification head, Precision-Recall Curve for the segmentation head, etc.
* <font color='green'>(**@**)</font> Plot the best / worst cases in terms of IoU. Examine the effect of the _segmentation threshold_.
* <font color='green'>(**@**)</font> Improve the loss function based on [Loss functions for image segmentation](https://github.com/JunMa11/SegLossOdyssey) and [Losses Used in Segmentation Task](https://github.com/Nacriema/Loss-Functions-For-Semantic-Segmentation).
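
As a possible starting point for the loss function task, the following is a minimal sketch of a soft Dice loss for the binary mask head. The class `SoftDiceLoss` and its `smoothVal` parameter are hypothetical names. This is one common formulation and not necessarily the one used in the linked repositories.

```python
import torch
import torch.nn as nn

class SoftDiceLoss(nn.Module):
    def __init__( self, smoothVal: float = 1.0 ) -> None:
        super().__init__()
        self.smoothVal = smoothVal #<! Stabilizes the ratio for empty masks

    def forward( self, tM: torch.Tensor, tMTgt: torch.Tensor ) -> torch.Tensor:
        # tM: (B, H, W) mask logits, tMTgt: (B, H, W) binary mask (float)
        tP = torch.sigmoid(tM).flatten(start_dim = 1) #<! (B, H * W) probabilities
        tT = tMTgt.flatten(start_dim = 1)             #<! (B, H * W)
        vInter = (tP * tT).sum(dim = 1)
        vDen   = tP.sum(dim = 1) + tT.sum(dim = 1)
        vDice  = (2.0 * vInter + self.smoothVal) / (vDen + self.smoothVal)
        return 1.0 - vDice.mean() #<! Low for good overlap, close to 1 for no overlap

tMTgt = torch.zeros(1, 4, 4)
tMTgt[0, 1:3, 1:3] = 1.0

oDice    = SoftDiceLoss()
lossGood = oDice(20.0 * (2.0 * tMTgt - 1.0), tMTgt)  #<! Confident and correct logits
lossBad  = oDice(-20.0 * (2.0 * tMTgt - 1.0), tMTgt) #<! Confident and wrong logits
```

Such a term could be added to the `nn.BCEWithLogitsLoss` term with its own weight, since the Dice loss is less sensitive to the imbalance between object and background pixels.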