<div align='center'>

## Image Processing

### Image Segmentation

#### Version 1.0, 12th Dec, 2024
</div>

This notebook focuses on segmenting lung regions affected by COVID-19 in computed tomography (CT) scans. Segmentation highlights critical features such as ground-glass opacities and consolidations, which are vital for diagnosing COVID-19. Accurate identification of these regions is essential for assessing the severity of infection and guiding treatment plans.

In this tutorial, we will:
1. Explore the COVID-19 CT dataset.
2. Preprocess the data for segmentation tasks.
3. Build a segmentation model using a U-Net architecture.
4. Train the model and evaluate its performance.

#### Step 1: Explore the dataset

The dataset used in this tutorial consists of 2D axial lung CT slices and their segmentation masks: 100 images from more than 40 COVID-19 patients, originally sourced from publicly available JPG images and subsequently converted for this study.

The segmentation masks are one-hot encoded with four channels:

1) Ground-glass opacities (GGO), channel 0
2) Consolidations, channel 1
3) Other (normal) lung tissue, channel 2
4) Background, channel 3

Each image is 512 × 512 pixels (you can confirm this with `images_medseg.shape`). The masks mark the regions of interest within the lungs: infected areas (ground-glass opacities and consolidations) and normal lung tissue. We start by exploring the dataset.

In [None]:
import os

# Function to concatenate file chunks
def concatenate_files(output_file, chunk_prefix):
    """
    Concatenates all file chunks with a given prefix into a single output file.

    Parameters:
    - output_file (str): The name of the output file to create.
    - chunk_prefix (str): The common prefix of all the file chunks to merge.
    """
    with open(output_file, 'wb') as outfile:
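        # Collect the chunk files that start with the prefix. sorted() orders names as strings,
        # so the order is only correct if the suffixes sort that way (e.g. zero-padded: _00, _01, ..., _10).
        chunk_files = sorted([f for f in os.listdir('.') if f.startswith(chunk_prefix)])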
        
        for chunk_file in chunk_files:
            with open(chunk_file, 'rb') as infile:
                outfile.write(infile.read())
            print(f"Added {chunk_file} to {output_file}")
    print(f"File {output_file} created successfully!")

# Concatenate the chunks for images_medseg.npy
concatenate_files('images_medseg.npy', 'images_part_')

# Concatenate the chunks for masks_medseg.npy
concatenate_files('masks_medseg.npy', 'masks_part_')

In [2]:
import numpy as np
import os
data_path = './'     # change this if the .npy files are stored elsewhere
images_medseg = np.load(os.path.join(data_path, 'images_medseg.npy'))
masks_medseg = np.load(os.path.join(data_path, 'masks_medseg.npy'))

**Question-1:** What are the data types (dtypes) of the images and masks? You can use `print(variable.dtype)` to check.
In general, which dtype do we need for training, and why?
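A small helper like the one below prints the dtype, shape and value range of an array, which is a quick way to start on this question. The helper name `describe_array` is ours, not part of the notebook, and the cell runs on a tiny synthetic array so it stands alone; in the notebook you would pass `images_medseg` and `masks_medseg` instead.

```python
import numpy as np

def describe_array(name, arr):
    """Print and return a one-line summary of an array's dtype, shape and value range."""
    summary = f"{name}: dtype={arr.dtype}, shape={arr.shape}, min={arr.min()}, max={arr.max()}"
    print(summary)
    return summary

# Synthetic stand-in for a batch of single-channel CT slices stored in HU, shape (N, H, W, 1)
fake_images = np.array([[[-1000.0], [40.0]], [[200.0], [-600.0]]], dtype=np.float64)[None]
describe_array("fake_images", fake_images)

# torch.nn layers create float32 parameters by default, so casting the data to float32
# before building tensors avoids dtype mismatches in the forward pass.
fake_images_f32 = fake_images.astype(np.float32)
describe_array("fake_images_f32", fake_images_f32)
```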

Next, we visualize the data.

In [4]:
import matplotlib.pyplot as plt

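def visualize(image_batch, mask_batch=None, pred_batch=None, num_samples=8, hot_encode=True):
    """Plot images in the top row and, if masks are given, one row per mask channel.

    - image_batch: (N, H, W, 1) array when hot_encode is True (the channel axis is
      dropped for display), otherwise (N, H, W) or (N, H, W, 3).
    - mask_batch: (N, H, W, C) one-hot masks; channel j is drawn in row j + 1.
    - pred_batch: optional torch tensor of per-class probabilities in (N, H, W, C)
      layout (permute model outputs from (N, C, H, W) first). Predictions are drawn
      in red, ground truth in green, and their overlap appears yellow.
    """
    num_classes = mask_batch.shape[-1] if mask_batch is not None else 0
    fig, ax = plt.subplots(num_classes + 1, num_samples, figsize=(num_samples * 2, (num_classes + 1) * 2))

    for i in range(num_samples):
        # With no mask rows, plt.subplots returns a 1D array of axes
        ax_image = ax[0, i] if num_classes > 0 else ax[i]
        if hot_encode:
            ax_image.imshow(image_batch[i,:,:,0], cmap='Greys')
        else:
            ax_image.imshow(image_batch[i,:,:])
        ax_image.set_xticks([])
        ax_image.set_yticks([])

        if mask_batch is not None:
            for j in range(num_classes):
                if pred_batch is None:
                    mask_to_show = mask_batch[i,:,:,j]
                else:
                    # RGB overlay: red = prediction thresholded at 0.5, green = ground truth
                    mask_to_show = np.zeros(shape=(*mask_batch.shape[1:-1], 3))
                    mask_to_show[..., 0] = pred_batch[i,:,:,j].cpu().numpy() > 0.5
                    mask_to_show[..., 1] = mask_batch[i,:,:,j]
                ax[j + 1, i].imshow(mask_to_show, vmin=0, vmax=1)
                ax[j + 1, i].set_xticks([])
                ax[j + 1, i].set_yticks([])

    plt.tight_layout()
    plt.show()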


In [None]:
visualize(images_medseg[10:], masks_medseg[10:])

#### Step 2: Data preprocessing

**Question-2:** What is the value range of a typical 8-bit RGB/greyscale image? What is the range for a CT scan?

**Answer-2:** For 8-bit RGB/greyscale images, each channel takes integer values in [0, 255].
CT scans are different. CT images use the Hounsfield Unit (HU), a standardized measure of tissue density derived from a tissue's linear attenuation coefficient relative to that of water: water is assigned 0 HU and air -1000 HU. Dense structures such as bone have high positive values (e.g., +1000 HU), while less dense materials such as fat have negative values (e.g., -100 HU).
The HU scale allows tissue densities to be interpreted consistently across CT scans, which makes it crucial for medical diagnosis and research. Values in a human CT scan range from approximately -1000 to 2000 HU.
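Written out, the HU definition is $HU = 1000 \times (\mu - \mu_{water}) / (\mu_{water} - \mu_{air})$, where $\mu$ is the linear attenuation coefficient. The sketch below implements that formula to show why water maps to 0 HU and air to -1000 HU. The function name is ours and `MU_WATER` is a placeholder (real coefficients depend on the X-ray beam energy); the arrays in this dataset already appear to be in HU, as the histogram below suggests, so this conversion is for illustration only.

```python
def attenuation_to_hu(mu, mu_water, mu_air=0.0):
    """Convert a linear attenuation coefficient to Hounsfield Units.

    HU = 1000 * (mu - mu_water) / (mu_water - mu_air)
    """
    return 1000.0 * (mu - mu_water) / (mu_water - mu_air)

MU_WATER = 0.2  # placeholder value in 1/cm; the real value depends on beam energy

print(attenuation_to_hu(MU_WATER, MU_WATER))  # water -> 0.0
print(attenuation_to_hu(0.0, MU_WATER))       # air (mu ~ 0) -> -1000.0
print(attenuation_to_hu(0.4, MU_WATER))       # twice water's attenuation -> 1000.0
```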

Use the following function to plot the histogram of the image at index 10 (the 11th image).

In [None]:
def visualize_histogram(image):
    plt.figure(figsize=(10, 6))
    plt.hist(image.flatten(), bins=100, color='blue', alpha=0.7)
    plt.title('CT Image Histogram')
    plt.xlabel('Intensity (HU)')
    plt.ylabel('Frequency')
    plt.grid(True)
    plt.show()
visualize_histogram(images_medseg[10])

#### Window shift and normalization
From the histogram above, you should find that most pixels lie in [-1000, -700] HU (air and aerated lung) and [-200, 200] HU (soft tissue). In computed tomography (CT), a window transform adjusts the display settings to focus on a specific range of Hounsfield Units, and hence on tissues of a particular density. It maps the chosen HU range onto the full grey scale, which enhances the visibility of the tissues or structures within that range.

**Key Concepts of Window Transform**

1) Window Width (WW): Determines the range of HU values displayed in the image. Pixels with HU values outside this range are set to the minimum (black) or maximum (white) grayscale values. A narrow window width enhances contrast but may lose detail in areas outside the range.
2) Window Level (WL): Specifies the center of the HU range (or window). Adjusting the window level shifts the range of displayed HU values, allowing focus on different tissue types.

**Common Window Settings in CT**

- Lung window: WW ~ 1500 HU, WL ~ -600 HU
- Bone window: WW ~ 2000 HU, WL ~ 500 HU
- Soft-tissue window: WW ~ 400 HU, WL ~ 40 HU
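The displayed range follows directly from the two parameters: a window with level WL and width WW shows HU values in [WL - WW/2, WL + WW/2], and everything outside is clipped to black or white. The sketch below (the names `WINDOW_PRESETS` and `window_bounds` are illustrative) computes those bounds for the three presets; for the lung window this gives [-1350, 150] HU, which covers both aerated lung and soft tissue.

```python
# (window level, window width) in HU for the presets listed above
WINDOW_PRESETS = {
    "lung": (-600, 1500),
    "bone": (500, 2000),
    "soft_tissue": (40, 400),
}

def window_bounds(level, width):
    """Return the (min, max) HU values displayed by a window: level -/+ width / 2."""
    return level - width / 2, level + width / 2

for name, (level, width) in WINDOW_PRESETS.items():
    low, high = window_bounds(level, width)
    print(f"{name}: HU in [{low}, {high}] is mapped to the full grey scale")
```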

**Question-3**: Complete the following code to apply a lung window (WW = 1500 HU, WL = -600 HU) and rescale the image to [0, 1].

In [23]:
def window_shift_and_normalize(ct_image, window_center, window_width):
    """
    Apply window shift and normalization to a CT image.

    Parameters:
    - ct_image (numpy array): The input CT image (HU values).
    - window_center (int): The center of the window (e.g., 40 for soft tissue).
    - window_width (int): The width of the window (e.g., 400 for soft tissue).

    Returns:
    - normalized_image (numpy array): The processed image with values normalized to [0, 1].
    """
    # Step 1: Calculate the window bounds (replace the placeholders)
    window_min = ... # Lower bound of the window
    window_max = ... # Upper bound of the window

    # Step 2: Clip the image to the window range (replace the placeholder)
    shifted_image = np.clip(..., ..., ...)

    # Step 3: Normalize the image to [0, 1] (replace the placeholder)
    normalized_image = (... - ...) / (... - ...)

    return normalized_image

Apply your function to the image at index 10 and display the result:

In [None]:
image_10th, masks_10th = images_medseg[10], masks_medseg[10]
image_trans = window_shift_and_normalize(image_10th, window_center=-600, window_width=1500)
plt.figure()
plt.imshow(image_trans * 255, cmap='Greys')
plt.show()

#### Visualize the segmentation mask
To make sure everything is correct, merge the one-hot mask into a single segmentation (label) mask, in which each pixel holds its class index, and visualize it.

**Question-4:** How do you convert the one-hot mask into a segmentation mask?

```
seg_mask10 =
plt.figure()
plt.imshow(seg_mask10)
plt.show()
```


You should get something like this:

In [None]:
seg_mask10 = np.argmax(masks_10th, axis=-1)  # (H, W, 4) one-hot -> (H, W) class indices 0-3
plt.figure()
plt.imshow(seg_mask10)
plt.show()
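Before training, it can be worth checking that the masks really are one-hot (exactly one channel set at every pixel) and that converting to labels and back is lossless. The sketch below does this with NumPy on a tiny synthetic mask; the helper names `is_one_hot` and `labels_to_one_hot` are hypothetical, and in the notebook you would pass `masks_medseg` (channel-last, four channels) instead.

```python
import numpy as np

def is_one_hot(masks, atol=1e-6):
    """True if every value is 0 or 1 and the channels of every pixel sum to 1."""
    binary = np.all(np.isclose(masks, 0, atol=atol) | np.isclose(masks, 1, atol=atol))
    sums_to_one = np.allclose(masks.sum(axis=-1), 1.0, atol=atol)
    return bool(binary and sums_to_one)

def labels_to_one_hot(labels, n_classes):
    """Inverse of argmax: (..., H, W) integer labels -> (..., H, W, n_classes) one-hot."""
    return np.eye(n_classes, dtype=np.float32)[labels]

# Tiny 2x2 label map using all four classes (0 = GGO, ..., 3 = background)
labels = np.array([[0, 1], [2, 3]])
one_hot = labels_to_one_hot(labels, n_classes=4)
print(one_hot.shape)                                        # (2, 2, 4)
print(is_one_hot(one_hot))                                  # True
print(np.array_equal(np.argmax(one_hot, axis=-1), labels))  # True: the round trip is lossless
```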

#### Step 3: Build your Dataloader
Now it's time to build your own PyTorch data pipeline. PyTorch's `DataLoader` is a utility for loading and preprocessing data efficiently for training or inference. It combines a dataset (a `Dataset` object that returns one sample at a time) with a sampler, so you can iterate over the data in batches, shuffle it, and use multiple worker processes to speed up loading. In practice, you write a `Dataset` subclass and wrap it in a `DataLoader`.
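To see what the `DataLoader` adds on top of a `Dataset`, the sketch below wraps random tensors in `torch.utils.data.TensorDataset` and iterates over them in shuffled batches. The sizes (ten single-channel 64×64 images with 4-channel masks) are arbitrary stand-ins for the CT data; note that batches come out in `(N, C, H, W)` layout, which is what the networks in Step 4 expect.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
images = torch.randn(10, 1, 64, 64)  # 10 fake single-channel slices, (N, C, H, W)
masks = torch.zeros(10, 4, 64, 64)   # 10 fake one-hot masks with 4 classes
masks[:, 3] = 1.0                    # mark every pixel as "background"

dataset = TensorDataset(images, masks)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

batch_shapes = []
for image_batch, mask_batch in loader:
    batch_shapes.append((tuple(image_batch.shape), tuple(mask_batch.shape)))
    print(image_batch.shape, mask_batch.shape)

# 10 samples with batch_size=4 give batches of 4, 4 and 2 (drop_last=False by default)
print(len(loader))  # 3
```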

**Question-5:**
Build your own PyTorch dataset class to load the data, based on the following example:
```python
from torch.utils.data import DataLoader, Dataset

class Covid19_Dataloader(Dataset):
    def __init__(self, images, masks, transform=None):
        self.images = images
        self.masks = masks
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        mask = self.masks[idx]
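        if self.transform:
            # Caution: a random transform called separately on image and mask draws
            # different random parameters for each, so the pair can end up misaligned.
            image = self.transform(image)
            mask = self.transform(mask)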
        return image, mask
```

Note that you need to:
1. add a window-adjustment function to the `Covid19_Dataloader` class to apply the window shift;
2. add random flipping for data augmentation with a flip probability of 0.45, applying the same flip to each image and its mask.

A reference implementation of `Covid19_Dataloader` for Question-5 is given below. `random_flip` draws one random decision per axis and applies it to both the image and its mask, so the two stay aligned.

In [48]:
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

class Covid19_Dataloader(Dataset):
    def __init__(self, images, masks, flip_prob=0.5):
        self.images = images            # (N, H, W, 1) CT slices in HU
        self.masks = masks              # (N, H, W, 4) one-hot masks
        self.flip_prob = flip_prob      # probability of each (horizontal / vertical) flip
        self.to_Tensor = transforms.ToTensor()  # (H, W, C) array -> (C, H, W) tensor

    def __len__(self):
        return len(self.images)

    def window_shift_and_normalize(self, ct_image, window_center, window_width):
        # Clip to [center - width/2, center + width/2], then rescale to [0, 1]
        window_min = window_center - (window_width / 2)
        window_max = window_center + (window_width / 2)

        shifted_image = np.clip(ct_image, window_min, window_max)

        normalized_image = (shifted_image - window_min) / (window_max - window_min)

        return normalized_image

    def random_flip(self, image, mask):
        """Randomly flip both image and mask horizontally and/or vertically."""
        if random.random() < self.flip_prob:
            # Horizontal flip (width is dim 2 of a (C, H, W) tensor)
            image = torch.flip(image, dims=[2])
            mask = torch.flip(mask, dims=[2])
        if random.random() < self.flip_prob:
            # Vertical flip (height is dim 1)
            image = torch.flip(image, dims=[1])
            mask = torch.flip(mask, dims=[1])
        return image, mask

    def __getitem__(self, idx):
        image = self.images[idx]
        image = self.window_shift_and_normalize(image, window_center=-600, window_width=1500)  # lung window
        mask = self.masks[idx]
        # ToTensor only rescales uint8 arrays, so float32 values are kept as-is
        image_T, mask_T = self.to_Tensor(image.astype(np.float32)), self.to_Tensor(mask.astype(np.float32))
        image, mask = self.random_flip(image_T, mask_T)
        return image, mask


#### Step 4: Build your neural network
Now you are ready to build a neural network for training. Below are a simple fully convolutional network (FCN), a UNet and an Attention UNet as examples; the training code uses the UNet. You are encouraged to build your own.

**FCN**

In [None]:
class FCN(nn.Module):
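    def __init__(self, in_channels=1, num_classes=4):
        super(FCN, self).__init__()

        # Encoder: three conv + 2x2 max-pool stages reduce H and W by a factor of 8
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),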
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2)
        )

        self.decoder = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, num_classes, kernel_size=3, padding=1)
        )

        self.upsample = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        x = self.upsample(x)
        return x

**UNet**

In [None]:
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, input):
        return self.conv(input)


class UNet(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(UNet, self).__init__()

        self.conv1 = DoubleConv(in_ch, 32)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(32, 64)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(64, 128)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(128, 256)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(256, 512)
        self.up6 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv6 = DoubleConv(512, 256)
        self.up7 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv7 = DoubleConv(256, 128)
        self.up8 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv8 = DoubleConv(128, 64)
        self.up9 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.conv9 = DoubleConv(64, 32)
        self.conv10 = nn.Conv2d(32, out_ch, 1)

    def forward(self, x):
        # Encoder: each stage doubles the channels and halves H and W
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)  # bottleneck
        # Decoder: upsample, concatenate the matching encoder features (skip connection), refine
        up_6 = self.up6(c5)
        merge6 = torch.cat([up_6, c4], dim=1)
        c6 = self.conv6(merge6)
        up_7 = self.up7(c6)
        merge7 = torch.cat([up_7, c3], dim=1)
        c7 = self.conv7(merge7)
        up_8 = self.up8(c7)
        merge8 = torch.cat([up_8, c2], dim=1)
        c8 = self.conv8(merge8)
        up_9 = self.up9(c8)
        merge9 = torch.cat([up_9, c1], dim=1)
        c9 = self.conv9(merge9)
        # Return raw logits: nn.CrossEntropyLoss (used for training below) applies
        # log-softmax internally, so no final sigmoid is needed here.
        return self.conv10(c9)
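This UNet halves the spatial size four times and then concatenates each upsampled map with the encoder map at the same level, so the two must have the same height and width. With the size-preserving 3×3 convolutions (padding=1) used here, that holds only when the input size is a multiple of 2^4 = 16. The pure-Python sketch below (a hypothetical helper, not part of the notebook) traces the sizes: 512 works, whereas 520 shrinks to 65 and then 32, and upsampling 32 gives 64 ≠ 65, so `torch.cat` would fail at the deepest skip connection.

```python
def unet_skip_sizes(size, depth=4):
    """Pair each encoder size with the size of the upsampled map it is concatenated with.

    Assumes MaxPool2d(2) (floors odd sizes), ConvTranspose2d(kernel_size=2, stride=2)
    (doubles sizes) and convolutions that preserve the size (padding=1).
    """
    encoder = [size]
    for _ in range(depth):
        encoder.append(encoder[-1] // 2)  # MaxPool2d(2) rounds down
    # Decoder stages run from the deepest level upwards
    return [(encoder[level], encoder[level + 1] * 2) for level in reversed(range(depth))]

def unet_input_ok(size, depth=4):
    """True if every skip connection sees matching spatial sizes."""
    return all(skip == up for skip, up in unet_skip_sizes(size, depth))

print(unet_skip_sizes(520))  # [(65, 64), (130, 130), (260, 260), (520, 520)]
print(unet_input_ok(512), unet_input_ok(520))  # True False
```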

#### The Attention UNet
![Attention gate schematic](./attention.png)
Schematic of the attention gate (AG). Input features $x^l$ are scaled with attention coefficients ($\alpha$) computed in the AG. Spatial regions are selected by analysing both the activations and the contextual information provided by the gating signal ($g$), which is collected from a coarser scale. Grid resampling of the attention coefficients is done using trilinear interpolation in the original 3D formulation (bilinear in 2D).


The attention gate is formulated as:
$$
q_{\text{att}}^l = \psi^T \left( \sigma_1 \left( W_x^T x_i^l + W_g^T g_i + b_g \right) \right) + b_\psi,
$$

$$
\alpha_i^l = \sigma_2 \left( q_{\text{att}}^l \left( x_i^l, g_i; \Theta_{\text{att}} \right) \right),
$$
where $\sigma_1$ is the ReLU activation and $\sigma_2$ is the sigmoid activation:
$$
\sigma_2(x_{i,c}) = \frac{1}{1 + \exp(-x_{i,c})}
$$

The AG is characterized by a set of parameters $\Theta_{\text{att}}$ containing the linear transformations $W_x \in \mathbb{R}^{F_l \times F_{\text{int}}}$, $W_g \in \mathbb{R}^{F_g \times F_{\text{int}}}$, $\psi \in \mathbb{R}^{F_{\text{int}} \times 1}$ and the bias terms $b_\psi \in \mathbb{R}$, $b_g \in \mathbb{R}^{F_{\text{int}}}$. In the original 3D formulation the linear transformations are computed as channel-wise $1 \times 1 \times 1$ convolutions of the input tensors; in this 2D network they become $1 \times 1$ convolutions.
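To connect the equations to code, here is a minimal 2D sketch of the additive attention gate in which each linear transformation is a 1×1 convolution. It is deliberately stripped down (the class name `AdditiveAttentionGate` is ours, and unlike the `Attention_block` template below it has no batch normalization), and it assumes `g` has already been brought to the spatial size of `x`, which is how the Attention UNet below calls its gates.

```python
import torch
import torch.nn as nn

class AdditiveAttentionGate(nn.Module):
    """alpha = sigmoid(psi^T relu(W_x^T x + W_g^T g + b_g) + b_psi); output = x * alpha."""

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        self.W_x = nn.Conv2d(F_l, F_int, kernel_size=1, bias=False)  # W_x^T x (no bias in the formula)
        self.W_g = nn.Conv2d(F_g, F_int, kernel_size=1, bias=True)   # W_g^T g + b_g
        self.psi = nn.Conv2d(F_int, 1, kernel_size=1, bias=True)     # psi^T (.) + b_psi

    def forward(self, g, x):
        q_att = self.psi(torch.relu(self.W_x(x) + self.W_g(g)))  # sigma_1 = ReLU
        alpha = torch.sigmoid(q_att)                              # sigma_2 = sigmoid, shape (N, 1, H, W)
        return x * alpha, alpha                                   # alpha broadcasts over the channels of x

torch.manual_seed(0)
gate = AdditiveAttentionGate(F_g=8, F_l=4, F_int=2)
g = torch.randn(2, 8, 16, 16)  # gating signal, already at the same H, W as x
x = torch.randn(2, 4, 16, 16)  # skip-connection features x^l
out, alpha = gate(g, x)
print(out.shape, alpha.shape)  # torch.Size([2, 4, 16, 16]) torch.Size([2, 1, 16, 16])
```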

**Question-6**: write the attention block
```python
class Attention_block(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        # 1x1 conv for signal g
        g1 =
        # 1x1 conv for signal x^l
        x1 =
        # element-wise sum + ReLU (the formula adds W_x^T x and W_g^T g; it does not concatenate them)
        psi =
        # get attention map
        psi =
        return x * psi
```

Using the attention block, we can now build the Attention UNet:

In [None]:
class conv_block(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(conv_block,self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )
    def forward(self,x):
        x = self.conv(x)
        return x

class up_conv(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(up_conv,self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self,x):
        x = self.up(x)
        return x


class AttentionUNet(nn.Module):
    def __init__(self, img_ch=1, output_ch=4):  # single-channel CT input, 4 mask classes
        super(AttentionUNet, self).__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)

        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        # encoding path
        x1 = self.Conv1(x)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        x4 = self.Att5(g=d5, x=x4)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4, x=x3)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3, x=x2)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2, x=x1)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

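        d1 = self.Conv_1x1(d2)
        # Return raw logits so the model can be trained with nn.CrossEntropyLoss, which
        # applies log-softmax internally; apply self.sigmoid only with a BCE-style loss.
        return d1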

#### Training
With the data loader and the network ready, it's time to train the model.

In [58]:
from sklearn.model_selection import train_test_split

model = UNet(in_ch=1, out_ch=4)  # 1 input channel (CT slice), 4 output classes
X_train, X_val, y_train, y_val = train_test_split(images_medseg, masks_medseg, test_size=0.2, random_state=42)

train_dataset = Covid19_Dataloader(X_train, y_train, flip_prob=0.45)
val_dataset = Covid19_Dataloader(X_val, y_val, flip_prob=0.0)  # no augmentation for validation
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

# The masks are one-hot floats of shape (N, 4, H, W); CrossEntropyLoss accepts such
# class-probability targets (PyTorch >= 1.10) and expects raw logits as input.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0005)

# Move the model to the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
epochs = 30
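The training cell passes the one-hot float masks directly to `nn.CrossEntropyLoss`. Recent PyTorch versions (1.10 and later) accept class-probability targets with the same shape as the logits, which for one-hot targets should be equivalent to passing the integer class indices you would get from `argmax`. The quick check below on random tensors illustrates that equivalence; it also reflects that the loss takes raw logits, since it applies log-softmax itself.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(2, 4, 8, 8)         # (N, C, H, W) raw network outputs
labels = torch.randint(0, 4, (2, 8, 8))  # (N, H, W) integer class map
# (N, H, W) -> (N, H, W, C) one-hot -> (N, C, H, W) float, like the dataset's masks
one_hot = nn.functional.one_hot(labels, num_classes=4).permute(0, 3, 1, 2).float()

criterion = nn.CrossEntropyLoss()
loss_from_probs = criterion(logits, one_hot)    # class-probability (one-hot) targets
loss_from_indices = criterion(logits, labels)   # class-index targets
print(torch.allclose(loss_from_probs, loss_from_indices))  # True
```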

In [None]:
train_losses = []
for epoch in range(epochs):
    model.train()
    train_loss = 0
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    average_loss = train_loss / len(train_loader)
    train_losses.append(average_loss)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {average_loss:.4f}")

plt.figure(figsize=(10, 6))
plt.plot(range(1, epochs + 1), train_losses, marker='o', linestyle='-')
plt.title('Training Loss Over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)
plt.show()

#### Step 5: Evaluation
Training is complete, but the training loss alone does not tell us how well the model segments images it has not seen.

Intersection over Union (IoU), also known as the Jaccard Index, is a metric commonly used to evaluate the performance of image segmentation models. It measures the overlap between the predicted segmentation mask and the ground truth mask.
IoU is particularly useful in assessing how well a model predicts regions of interest, such as objects or tissues, in an image.

**IoU Calculation Formula**
The IoU is calculated as the ratio of the intersection to the union of the predicted and ground truth masks:

$$
IoU = \frac{\text{Intersection}}{\text{Union}}
$$

Where:
- **Intersection**: The area of overlap between the predicted mask and the ground truth mask.
- **Union**: The total area covered by the predicted mask and the ground truth mask (including the overlap).

This can also be expressed mathematically as:

$$
IoU = \frac{|P \cap T|}{|P \cup T|} = \frac{|P \cap T|}{|P| + |T| - |P \cap T|}
$$

Where:
- $P$: predicted mask.
- $T$: ground truth mask.
- $|P \cap T|$: number of pixels in the intersection of $P$ and $T$.
- $|P \cup T|$: number of pixels in the union of $P$ and $T$.

**Example Use Case in Binary Segmentation**
For binary segmentation tasks, the predicted mask is usually a binary image obtained by applying a threshold to the model's output. The ground truth mask is also a binary image. The IoU is then calculated as:

1. Compute the intersection: Sum of all pixels where both the predicted and ground truth masks have value 1.
2. Compute the union: Sum of all pixels where either the predicted or ground truth mask has value 1.
3. Divide the intersection by the union.
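The three steps above can be sketched in NumPy as follows. The function name `binary_iou` and the toy arrays are illustrative; the model output is thresholded at 0.5 first, as described above. Here the prediction and the ground truth each cover 3 pixels and share 2, so the union is 3 + 3 - 2 = 4 and the IoU is 2 / 4 = 0.5 (the tiny `smooth` term only guards against division by zero).

```python
import numpy as np

def binary_iou(pred, target, smooth=1e-10):
    """IoU of two binary masks: |P ∩ T| / |P ∪ T|."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    intersection = np.logical_and(pred, target).sum()  # step 1: pixels that are 1 in both
    union = np.logical_or(pred, target).sum()          # step 2: pixels that are 1 in either
    return (intersection + smooth) / (union + smooth)  # step 3: ratio

probs = np.array([[0.9, 0.7, 0.2], [0.1, 0.8, 0.4]])  # toy model output
pred = probs > 0.5                                     # threshold -> [[1, 1, 0], [0, 1, 0]]
target = np.array([[1, 0, 0], [0, 1, 1]])
print(binary_iou(pred, target))  # approximately 0.5
```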



**Question-7:** Write a function that computes the mean IoU (mIoU) over classes for a multi-class prediction.
```python
def mIoU(pred_mask, mask, smooth=1e-10, n_classes=4):
    with torch.no_grad():
        # 1. use softmax as the activation function to convert pred logits to probs

        # 2. calculate the iou score for each class and average
        iou_per_class = []
        for clas in range(0, n_classes):
            # 3. calculate the IoU score of the current class
            intersect =
            union =

            iou = /
            iou_per_class.append(iou)

        return np.nanmean(iou_per_class)
```

In [60]:
import torch.nn.functional as F

def mIoU(pred_mask, mask, smooth=1e-10, n_classes=4):
    """Mean IoU over classes.

    pred_mask: (N, C, H, W) logits; mask: (N, H, W) integer class labels.
    Classes absent from the ground truth are recorded as NaN and ignored by np.nanmean.
    """
    with torch.no_grad():
        pred_mask = F.softmax(pred_mask, dim=1)     # softmax does not change the argmax
        pred_mask = torch.argmax(pred_mask, dim=1)  # (N, H, W) predicted class per pixel
        pred_mask = pred_mask.contiguous().view(-1)
        mask = mask.contiguous().view(-1)

        iou_per_class = []
        for clas in range(0, n_classes):  # loop over classes
            pred_is_class = pred_mask == clas
            label_is_class = mask == clas

            if label_is_class.long().sum().item() == 0:  # class not present in the labels
                iou_per_class.append(np.nan)
            else:
                intersect = torch.logical_and(pred_is_class, label_is_class).sum().float().item()
                union = torch.logical_or(pred_is_class, label_is_class).sum().float().item()

                iou = (intersect + smooth) / (union + smooth)
                iou_per_class.append(iou)
        return np.nanmean(iou_per_class)

In [None]:
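# This loop continues training the model from the previous cell; re-create `model`
# and `optimizer` first if you want to train from scratch with validation.
val_ious = []
train_losses = []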
for epoch in range(epochs):
    # Training Loop
    model.train()
    train_loss = 0
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    train_losses.append(train_loss / len(train_loader))

    # Validation Loop
    model.eval()
    iou_scores = []
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            iou = mIoU(outputs, torch.argmax(masks,dim=1))
            iou_scores.append(iou)

    average_iou = np.mean(iou_scores)
    val_ious.append(average_iou)

    print(f"Epoch {epoch+1}/{epochs}, Loss: {train_losses[-1]:.4f}, IoU: {average_iou:.4f}")

#### Step 6: Improve the performance
How can we improve the performance further? Some options:
1. Use more powerful architectures (e.g., UNet++ or the Attention UNet above).
2. Use different loss functions (e.g., a Dice coefficient loss).
3. Use more data (e.g., the Radiopaedia images from the Kaggle COVID-19 segmentation competition: https://www.kaggle.com/competitions/covid-segmentation/data?select=images_radiopedia.npy).
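As a starting point for the second option, a soft Dice loss for multi-class segmentation could look like the sketch below. It is one common formulation (Dice = 2|P ∩ T| / (|P| + |T|), computed per class on softmax probabilities and then averaged), not a method prescribed by this notebook; the class name `SoftDiceLoss` and the `smooth` default are our choices. It expects raw logits of shape `(N, C, H, W)` and one-hot targets of the same shape, like the masks returned by `Covid19_Dataloader`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SoftDiceLoss(nn.Module):
    """1 - mean soft Dice over classes, computed from softmax probabilities."""

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, one_hot_target):
        probs = F.softmax(logits, dim=1)  # (N, C, H, W)
        dims = (0, 2, 3)                  # sum over batch and pixels, keep the class axis
        intersection = (probs * one_hot_target).sum(dims)
        cardinality = probs.sum(dims) + one_hot_target.sum(dims)
        dice_per_class = (2 * intersection + self.smooth) / (cardinality + self.smooth)
        return 1 - dice_per_class.mean()

torch.manual_seed(0)
labels = torch.randint(0, 4, (2, 16, 16))
target = F.one_hot(labels, num_classes=4).permute(0, 3, 1, 2).float()

dice_loss = SoftDiceLoss()
confident_logits = target * 20.0  # near-perfect prediction
random_logits = torch.randn(2, 4, 16, 16)
print(dice_loss(confident_logits, target).item())  # close to 0
print(dice_loss(random_logits, target).item())     # noticeably larger
```

One option is to add it to the existing criterion, e.g. `loss = criterion(outputs, masks) + dice_loss(outputs, masks)`, since both take logits and one-hot masks.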