# Train UNET

In [1]:
!pip install scikit-learn



## Third Pary Code

In [2]:
# The model is copied from previous_work/models/unet/UNet.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class UNetDoubleConvBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        self.double_conv_block = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
            ),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
            ),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.double_conv_block(x)


class UNetDownwardLayer(torch.nn.Module):
    def __init__(self, *, in_channels, out_channels) -> None:
        super().__init__()
        self.conv_block = UNetDoubleConvBlock(
            in_channels=in_channels,
            out_channels=out_channels,
        )
        self.down_sample = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        out_forward = self.conv_block(x)
        out_downward = self.down_sample(out_forward)
        return out_forward, out_downward


class UNetUpwardLayer(torch.nn.Module):
    def __init__(self, *, num_features) -> None:
        super().__init__()
        self.conv_block = UNetDoubleConvBlock(
            in_channels=num_features * 2,
            out_channels=num_features,
        )
        self.up_sample_conv = nn.ConvTranspose2d(
            in_channels=num_features * 2,
            out_channels=num_features,
            kernel_size=2,
            stride=2,
        )

    def forward(self, x_from_lower_layer, x_from_encoder_forward):
        x_upsampled = self.up_sample_conv(x_from_lower_layer)
        x = torch.cat(
            (x_upsampled, x_from_encoder_forward),
            dim=1,
        )
        return self.conv_block(x)


class UNet(torch.nn.Module):
    def __init__(
        self, *, in_size=(512, 512), in_channels=1, out_channels=1, init_features=64
    ) -> None:
        super().__init__()
        self.encode_down_layer1 = UNetDownwardLayer(
            in_channels=in_channels,
            out_channels=init_features,
        )
        self.encode_down_layer2 = UNetDownwardLayer(
            in_channels=init_features,
            out_channels=init_features * 2,
        )
        self.encode_down_layer3 = UNetDownwardLayer(
            in_channels=init_features * 2,
            out_channels=init_features * 4,
        )
        self.encode_down_layer4 = UNetDownwardLayer(
            in_channels=init_features * 4,
            out_channels=init_features * 8,
        )
        self.bottom_layer = UNetDoubleConvBlock(
            in_channels=init_features * 8,
            out_channels=init_features * 16,
        )
        self.decode_up_layer4 = UNetUpwardLayer(num_features=init_features * 8)
        self.decode_up_layer3 = UNetUpwardLayer(num_features=init_features * 4)
        self.decode_up_layer2 = UNetUpwardLayer(num_features=init_features * 2)
        self.decode_up_layer1 = UNetUpwardLayer(num_features=init_features)
        self.output_conv = nn.Conv2d(
            in_channels=init_features,
            out_channels=out_channels,
            kernel_size=1,
        )

        self.out_channels = out_channels
        if out_channels == 1:
            self.out_layer_func = nn.Sigmoid()
        else:
            self.out_layer_func = nn.Identity()

    def forward(self, x):
        x_encode_forward1, x_encode_downward1 = self.encode_down_layer1(x)
        x_encode_forward2, x_encode_downward2 = self.encode_down_layer2(
            x_encode_downward1
        )
        x_encode_forward3, x_encode_downward3 = self.encode_down_layer3(
            x_encode_downward2
        )
        x_encode_forward4, x_encode_downward4 = self.encode_down_layer4(
            x_encode_downward3
        )
        x_out_bottom = self.bottom_layer(x_encode_downward4)
        x_decode_upward4 = self.decode_up_layer4(x_out_bottom, x_encode_forward4)
        x_decode_upward3 = self.decode_up_layer3(x_decode_upward4, x_encode_forward3)
        x_decode_upward2 = self.decode_up_layer2(x_decode_upward3, x_encode_forward2)
        x_out_decode = self.decode_up_layer1(x_decode_upward2, x_encode_forward1)
        x_out_result = self.out_layer_func(self.output_conv(x_out_decode))
        return x_out_result

In [3]:
# The following code is taken from previous_work/train_unet.py
import random
import json
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.transforms import transforms, InterpolationMode
import torchvision.transforms.functional as ttf
from PIL import Image

class SegmentationDataset(Dataset):
    def __init__(self, dataset_dir: str):
        self.dataset_dir = dataset_dir
        with open(os.path.join(dataset_dir, "image_names.json"), "r") as f:
            self.image_names = json.load(f)

    def __getitem__(self, index) -> tuple[Image.Image, Image.Image]:
        image_name = self.image_names[index]
        image = Image.open(os.path.join(self.dataset_dir, "xrays", image_name)).convert("L")
        mask = Image.open(os.path.join(self.dataset_dir, "masks", image_name))

        return image, mask

    def __len__(self) -> int:
        return len(self.image_names)


class Preload(Dataset):
    """
    wrap a dataset to preload all items eagerly
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.data = []
        for i in range(len(dataset)):
            self.data.append(dataset[i])

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.dataset)


class TransformedDataset(Dataset):
    def __init__(
        self,
        dataset: Dataset,
        flip: float = None,
        crop: float = None,
        rotate: list = None,
    ):
        self.dataset = dataset
        self.flip = flip
        self.crop = crop
        self.rotate = rotate

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
        image, mask = self.dataset[index]
        image, mask = TransformedDataset.data_transform(image, mask, self.flip, self.crop, self.rotate)
        return image, mask

    def __len__(self) -> int:
        return len(self.dataset)

    @staticmethod
    def data_transform(
        image: Image.Image, mask: Image.Image = None, flip: float = None, crop: float = None, rotate: list = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        convert PIL Image to torch Tensor and do some augmentation
        @param image: PIL Image
        @param mask: PIL Image
        @param flip: float, 0.0 ~ 1.0, probability of flip
        @param crop: float, 0.0 ~ 1.0, probability of crop
        @param rotate: list, [min_angle, max_angle], in degree
        """
        dummy_mask = mask if mask is not None else Image.new("L", image.size)
        # resize
        image = image.resize((256, 256), Image.BILINEAR)
        dummy_mask = dummy_mask.resize((256, 256), Image.NEAREST)

        # to tensor
        image = ttf.to_tensor(image)  # shape(1, 256, 256)
        dummy_mask = torch.from_numpy(np.array(dummy_mask)).long().unsqueeze(0)  # shape(1, 256, 256)

        # normalize
        image = ttf.normalize(image, [0.458], [0.173])

        # flip
        if flip is not None and random.random() < flip:
            image = ttf.hflip(image)
            dummy_mask = ttf.hflip(dummy_mask)

        # crop
        if crop is not None and random.random() < crop:
            size = random.randint(128, 225)
            i, j, h, w = transforms.RandomCrop.get_params(image, output_size=(size, size))
            image = ttf.crop(image, i, j, h, w)
            dummy_mask = ttf.crop(dummy_mask, i, j, h, w)

            # resize
            image = ttf.resize(image, (256, 256), InterpolationMode.BILINEAR, antialias=True)
            dummy_mask = ttf.resize(dummy_mask, (256, 256), InterpolationMode.NEAREST, antialias=True)

        # rotate
        if rotate is not None and random.random() < 0.1:
            angle = random.randint(rotate[0], rotate[1])
            image = ttf.rotate(image, angle)
            dummy_mask = ttf.rotate(dummy_mask, angle)

        dummy_mask = dummy_mask.squeeze(0)
        return image, dummy_mask

## Own Code

In [4]:
import os
import sys
sys.path.append(os.path.join(os.getcwd(), 'previous_work'))

import torch
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader, random_split

from sklearn.metrics import average_precision_score

# Set your parameters
dataset_dir = os.path.join(os.getcwd(), 'dentex_dataset/segmentation/enumeration32')
train_ratio = 0.8 
batch_size = 32 
seed = 42

# Load and transform dataset
dataset = Preload(SegmentationDataset(dataset_dir))
dataset = TransformedDataset(dataset, flip=0.1, crop=0.1, rotate=[-10, 10])

# Split dataset into training and validation sets
train_size = int(len(dataset) * train_ratio)
validation_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset, [train_size, validation_size], generator=torch.Generator().manual_seed(seed)
)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
model = UNet()  # Define your Unet model

In [5]:
def train_unet(model, dataloader, num_epochs=10):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Initialize criterion, optimizer
    model = model.to(device)
    criterion = nn.BCEWithLogitsLoss()  # Binary Cross Entropy Loss for binary segmentation
    optimizer = torch.optim.Adam(model.parameters())
    
    for epoch in range(num_epochs):
        model.train()
        
        running_loss = 0.0
        all_labels = []
        all_predictions = []

        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device).float()  # Convert labels to float

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels.unsqueeze(1))

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)

            # Store labels and predictions for AP and mAP calculation
            all_labels.append(labels.detach().cpu().numpy())
            all_predictions.append(outputs.detach().cpu().numpy())

        epoch_loss = running_loss / len(dataset)
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}')

        # Calculate AP and mAP
        all_labels = np.concatenate(all_labels).reshape(-1, 1)  # Reshape all_labels to 2D array
        all_predictions = np.concatenate(all_predictions).reshape(-1, 1)  # Reshape all_predictions to 2D array
        AP = average_precision_score(all_labels, all_predictions)
        mAP = AP.mean()
        print(f'AP: {AP}, mAP: {mAP}')

In [6]:
train_unet(model, train_loader)

Epoch 1/10, Loss: -1.356386776226176
AP: 0.654060233541386, mAP: 0.654060233541386
Epoch 2/10, Loss: -1.800070093257194
AP: 0.6079126306182415, mAP: 0.6079126306182415
Epoch 3/10, Loss: -1.8083518826247014
AP: 0.6093541867953793, mAP: 0.6093541867953793
Epoch 4/10, Loss: -1.7950889315890965
AP: 0.6126072678348209, mAP: 0.6126072678348209
Epoch 5/10, Loss: -1.7957886923750868
AP: 0.612290863889115, mAP: 0.612290863889115
Epoch 6/10, Loss: -1.8032530222781449
AP: 0.6119226733580009, mAP: 0.6119226733580009
Epoch 7/10, Loss: -1.8433690116232502
AP: 0.6098901879759592, mAP: 0.6098901879759592
Epoch 8/10, Loss: -1.8754253364887898
AP: 0.6068890911472004, mAP: 0.6068890911472004
Epoch 9/10, Loss: -1.890907586937071
AP: 0.6063768459250577, mAP: 0.6063768459250577
Epoch 10/10, Loss: -1.8106625647951
AP: 0.6152941931367462, mAP: 0.6152941931367462
