## Import

In [None]:
import os
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import numpy as np


 ## Dataset

In [None]:
class XRayDataset(Dataset):
    """Paired grayscale image / segmentation-mask dataset backed by two folders.

    Every file in ``image_dir`` whose name ends with ``.png`` is one sample; its
    mask is the file with the same name inside ``mask_dir``.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the masks (same file names as the images).
        transform: Optional callable applied to each loaded image.
        mask_transform: Optional callable applied to each loaded mask. When
            omitted, ``transform`` is reused for the mask, which is what the
            class did before this parameter existed.
    """

    def __init__(self, image_dir, mask_dir, transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = transform if mask_transform is None else mask_transform
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.images = sorted(f for f in os.listdir(image_dir) if f.endswith('.png'))

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as an ``(image, mask)`` pair."""
        filename = self.images[idx]
        img_path = os.path.join(self.image_dir, filename)
        mask_path = os.path.join(self.mask_dir, filename)

        # convert() yields a fully loaded copy, so the underlying file handles
        # can be released right away instead of waiting for garbage collection.
        with Image.open(img_path) as image_file:
            image = image_file.convert("L")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        if self.transform is not None:
            image = self.transform(image)
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)

        return image, mask


def get_transforms():
    """Build the shared preprocessing pipeline.

    Returns:
        A ``transforms.Compose`` that resizes to 256x256, converts to a tensor
        and then normalizes the single channel with mean 0.5 and std 0.5.
    """
    target_size = (256, 256)
    steps = [
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
    return transforms.Compose(steps)

## Plot Image

In [None]:
def plot_images_and_masks(images, masks, filenames):
    """Show each image next to its mask, one sample per row.

    Args:
        images: Sequence of array-likes exposing ``.squeeze()``; one per row.
        masks: Sequence of array-likes exposing ``.squeeze()``; parallel to ``images``.
        filenames: Sequence of ``(image_filename, mask_filename)`` pairs used
            as subplot titles; parallel to ``images``.
    """
    num_images = len(images)
    if num_images == 0:
        # A zero-row subplot grid is rejected by matplotlib and there is
        # nothing to draw anyway.
        return

    # squeeze=False keeps `axes` two-dimensional even for a single row, so the
    # axes[row, col] indexing below also works when only one sample is passed.
    _, axes = plt.subplots(
        num_images, 2, figsize=(10, 5 * num_images), squeeze=False
    )

    for row, image in enumerate(images):
        mask = masks[row]
        image_filename, mask_filename = filenames[row]

        axes[row, 0].imshow(image.squeeze(), cmap='gray')
        axes[row, 0].set_title(f'Image: {image_filename}')
        axes[row, 0].axis('off')

        axes[row, 1].imshow(mask.squeeze(), cmap='gray')
        axes[row, 1].set_title(f'Mask: {mask_filename}')
        axes[row, 1].axis('off')

    plt.tight_layout()
    plt.show()

## Test Dataset

In [None]:
def test_dataset(image_dir='data/train/image', mask_dir='data/train/mask', num_samples=2):
    """Smoke-test ``XRayDataset`` by loading and plotting its first samples.

    Args:
        image_dir: Folder with the input images.
        mask_dir: Folder with the matching masks.
        num_samples: Maximum number of samples to display.
    """
    # Reuse the shared pipeline instead of keeping a second copy of it here.
    dataset = XRayDataset(image_dir, mask_dir, get_transforms())

    # Never index past the end of a dataset smaller than `num_samples`.
    count = min(num_samples, len(dataset))
    if count <= 0:
        print(f'No samples to display from {image_dir}')
        return

    images, masks, filenames = [], [], []

    for i in range(count):
        image, mask = dataset[i]
        # Image and mask share one file name (see XRayDataset.__getitem__).
        filename = dataset.images[i]
        print(mask.shape)

        images.append(image.numpy())
        masks.append(mask.numpy())
        filenames.append((filename, filename))

    plot_images_and_masks(images, masks, filenames)

# Run the dataset smoke test as soon as this cell executes.
test_dataset()

## Model - UNET

In [None]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the second convolution.
        mid_channels: Channels between the two convolutions; a falsy value
            (``None`` or ``0``) means "use ``out_channels``".
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        first_stage = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
        ]
        second_stage = [
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Layer order (and therefore the state_dict indices 0, 1, 3, 4) is
        # conv, bn, relu, conv, bn, relu.
        self.double_conv = nn.Sequential(*first_stage, *second_stage)

    def forward(self, x):
        """Apply both conv/BN/ReLU stages; padding=1 keeps the spatial size."""
        features = self.double_conv(x)
        return features


class Down(nn.Module):
    """Encoder stage: 2x2 max pooling followed by a ``DoubleConv``.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool ``x`` with a 2x2 window, then run the double convolution."""
        downsampled = self.maxpool_conv(x)
        return downsampled


class Up(nn.Module):
    """Decoder stage: upsample, pad to the skip tensor, concatenate, double conv.

    Args:
        in_channels: Channel count of the concatenated tensor fed to the
            ``DoubleConv`` (skip connection plus upsampled features).
        out_channels: Channels produced by this stage.
        bilinear: ``True`` uses parameter-free bilinear upsampling; ``False``
            uses a transposed convolution that halves the channel count.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2

        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # Bilinear upsampling leaves the channel count untouched, so the
            # reduction happens in the DoubleConv via its mid_channels.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it with skip tensor ``x2`` and fuse both."""
        upsampled = self.up(x1)

        # Dims 2 and 3 are the two spatial axes of the 4-D feature maps.
        gap_y = x2.size(2) - upsampled.size(2)
        gap_x = x2.size(3) - upsampled.size(3)
        before_x, before_y = gap_x // 2, gap_y // 2

        # F.pad takes the last dimension first: (left, right, top, bottom).
        upsampled = F.pad(
            upsampled,
            [before_x, gap_x - before_x, before_y, gap_y - before_y],
        )
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """Output head: a single 1x1 convolution.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the produced map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project ``x`` to ``out_channels`` channels, pixel by pixel."""
        projected = self.conv(x)
        return projected
""" Full assembly of the parts to form the complete network """

class UNet(nn.Module):
    """U-Net built from ``DoubleConv``, ``Down``, ``Up`` and ``OutConv``.

    The encoder runs ``inc`` plus four ``Down`` stages, the decoder four ``Up``
    stages that each consume the matching encoder output as a skip connection,
    and ``outc`` maps the final 64 channels to ``n_classes``.

    Args:
        n_channels: Channels of the input image.
        n_classes: Channels of the returned logits.
        bilinear: Forwarded to every ``Up`` stage. When ``True`` the bottleneck
            and decoder widths are halved (``factor = 2``).
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        # Switched on by use_checkpointing(). A plain bool, so the state_dict
        # layout is unaffected.
        self._checkpointing = False

        factor = 2 if bilinear else 1

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def _run(self, module, *inputs):
        """Call ``module`` directly, or via activation checkpointing if enabled."""
        if getattr(self, '_checkpointing', False) and torch.is_grad_enabled():
            # Local import: keeps this class self-contained.
            from torch.utils.checkpoint import checkpoint
            return checkpoint(module, *inputs, use_reentrant=False)
        return module(*inputs)

    def forward(self, x):
        """Return raw (un-activated) logits for input batch ``x``."""
        x1 = self._run(self.inc, x)
        x2 = self._run(self.down1, x1)
        x3 = self._run(self.down2, x2)
        x4 = self._run(self.down3, x3)
        x5 = self._run(self.down4, x4)

        x = self._run(self.up1, x5, x4)
        x = self._run(self.up2, x, x3)
        x = self._run(self.up3, x, x2)
        x = self._run(self.up4, x, x1)

        logits = self._run(self.outc, x)
        return logits

    def use_checkpointing(self):
        """Enable activation checkpointing for every stage of the network.

        The previous implementation called ``torch.utils.checkpoint(...)``,
        which is a module rather than a function and therefore raised
        ``TypeError``. Stages are now wrapped at call time with
        ``torch.utils.checkpoint.checkpoint`` whenever gradients are enabled.

        NOTE(review): checkpointed stages are re-executed during backward, so
        the BatchNorm layers presumably see each training batch twice -
        confirm this is acceptable for the running statistics.
        """
        self._checkpointing = True

In [None]:
# """ Parts of the U-Net model """

# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# class DoubleConv(nn.Module):
#     """(convolution => [BN] => ReLU) * 2"""

#     def __init__(self, in_channels, out_channels):
#         super().__init__()
#         self.double_conv = nn.Sequential(
#             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True)
#         )

#     def forward(self, x):
#         return self.double_conv(x)


# class Down(nn.Module):
#     """Downscaling with maxpool then double conv"""

#     def __init__(self, in_channels, out_channels):
#         super().__init__()
#         self.maxpool_conv = nn.Sequential(
#             nn.MaxPool2d(2),
#             DoubleConv(in_channels, out_channels)
#         )

#     def forward(self, x):
#         return self.maxpool_conv(x)


# class Up(nn.Module):
#     """Upscaling then double conv"""

#     def __init__(self, in_channels, out_channels, bilinear=True):
#         super().__init__()

#         # if bilinear, use the normal convolutions to reduce the number of channels
#         if bilinear:
#             self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
#         else:
#             self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

#         self.conv = DoubleConv(in_channels, out_channels)

#     def forward(self, x1, x2):
#         x1 = self.up(x1)
#         # input is CHW
#         diffY = torch.tensor([x2.size()[2] - x1.size()[2]])
#         diffX = torch.tensor([x2.size()[3] - x1.size()[3]])

#         x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
#                         diffY // 2, diffY - diffY // 2])
#         # if you have padding issues, see
#         # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
#         # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
#         x = torch.cat([x2, x1], dim=1)
#         return self.conv(x)


# class OutConv(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super(OutConv, self).__init__()
#         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

#     def forward(self, x):
#         return self.conv(x)


# """ Full assembly of the parts to form the complete network """

# class UNet(nn.Module):
#     def __init__(self, n_channels, n_classes, bilinear=True):
#         super(UNet, self).__init__()
#         self.n_channels = n_channels
#         self.n_classes = n_classes
#         self.bilinear = bilinear

#         inter_channel = 16

#         self.inc = DoubleConv(n_channels, inter_channel)
#         self.down1 = Down(inter_channel, inter_channel*2)
#         self.down2 = Down(inter_channel*2, inter_channel*4)
#         self.down3 = Down(inter_channel*4, inter_channel*8)
#         self.down4 = Down(inter_channel*8, inter_channel*8)
#         self.up1 = Up(inter_channel*16, inter_channel*4, bilinear)
#         self.up2 = Up(inter_channel*8, inter_channel*2, bilinear)
#         self.up3 = Up(inter_channel*4, inter_channel, bilinear)
#         self.up4 = Up(inter_channel*2, inter_channel, bilinear)
#         self.outc = OutConv(inter_channel, n_classes)

#     def forward(self, x):
#         x1 = self.inc(x)
#         x2 = self.down1(x1)
#         x3 = self.down2(x2)
#         x4 = self.down3(x3)
#         x5 = self.down4(x4) # 1/16
#         x = self.up1(x5, x4)
#         x = self.up2(x, x3)
#         x = self.up3(x, x2)
#         x = self.up4(x, x1)
#         logits = self.outc(x)
#         return logits



# if __name__ == "__main__":
#     unet = UNet(n_channels=1, n_classes=2)
#     aa = torch.ones((2, 1, 128, 128))
#     bb = unet(aa)
#     print (bb.shape)


In [None]:
# def test_unet():
#     # Step 1: Instantiate the model
#     model = UNET(1,1)

#     # Step 2: Generate a sample input
#     sample_input = torch.randn(1, 1, 256, 256)  # Batch size of 1, 1 channel, 256x256 image

#     # Step 3: Forward pass
#     output = model(sample_input)

#     # Step 4: Check the output
#     print(f"Input shape: {sample_input.shape}")
#     print(f"Output shape: {output.shape}")
#     assert output.shape == sample_input.shape, "The output shape is incorrect!"

#     # Optionally visualize the input and output
#     input_image = sample_input.squeeze().detach().numpy()
#     output_image = output.squeeze().detach().numpy()

#     fig, axes = plt.subplots(1, 2, figsize=(12, 6))
#     axes[0].imshow(input_image, cmap='gray')
#     axes[0].set_title('Input Image')
#     axes[0].axis('off')

#     axes[1].imshow(output_image, cmap='gray')
#     axes[1].set_title('Output Image')
#     axes[1].axis('off')

#     plt.tight_layout()
#     plt.show()

# test_unet()


In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed over the whole batch at once.

    ``loss = 1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``

    Args:
        smooth: Constant added to numerator and denominator so the ratio is
            defined when both tensors are all zeros.
        from_logits: When ``True`` a sigmoid is applied to ``inputs`` first,
            so raw network logits can be passed in directly. The default
            ``False`` uses ``inputs`` as given.
    """

    def __init__(self, smooth=1e-5, from_logits=False):
        super().__init__()
        self.smooth = smooth
        self.from_logits = from_logits

    def forward(self, inputs, targets):
        """Return ``1 - dice`` as a scalar tensor."""
        if self.from_logits:
            inputs = torch.sigmoid(inputs)

        intersection = torch.sum(inputs * targets)
        total = torch.sum(inputs) + torch.sum(targets)
        dice_coeff = (2. * intersection + self.smooth) / (total + self.smooth)

        return 1. - dice_coeff

# def dice_coefficient(inputs, targets):
#     smooth = 1e-5
#     intersection = torch.sum(inputs * targets)
#     dice = (2. * intersection + smooth) / (torch.sum(inputs) + torch.sum(targets) + smooth)
#     return dice

# class DiceLoss(nn.Module):
#     def __init__(self):
#         super(DiceLoss, self).__init__()

#     def forward(self, inputs, targets):
#         dice = dice_coefficient(inputs, targets)
#         return 1. - dice


## Train Model

In [None]:
def _binarize_masks(masks):
    """Turn transformed masks into {0., 1.} float targets.

    The pipeline from ``get_transforms`` ends with
    ``Normalize(mean=[0.5], std=[0.5])``, so a mask value ``v`` produced by
    ``ToTensor`` reaches this point as ``2 * v - 1``. Thresholding at 0 splits
    foreground from background at the midpoint of the pre-normalization range.
    """
    return (masks > 0).float()


def _match_spatial_size(outputs, targets):
    """Resize ``outputs`` to the height/width of ``targets`` if they differ."""
    if outputs.shape[2:] == targets.shape[2:]:
        return outputs
    return F.interpolate(outputs, size=targets.shape[2:], mode='bilinear', align_corners=False)


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over ``loader``; return the mean batch loss."""
    model.train()
    running_loss = 0.0

    for images, masks in loader:
        images = images.to(device)
        targets = _binarize_masks(masks.to(device))

        optimizer.zero_grad()
        outputs = _match_spatial_size(model(images), targets)

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(loader)


def _evaluate(model, loader, criterion, device):
    """Return ``(mean batch loss, mean batch IoU)`` over ``loader``, no gradients."""
    model.eval()
    loss_sum = 0.0
    iou_sum = 0.0

    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device)
            targets = _binarize_masks(masks.to(device))

            outputs = _match_spatial_size(model(images), targets)

            loss_sum += criterion(outputs, targets).item()
            # compute_mean_iou applies the sigmoid itself, so it receives logits.
            iou_sum += compute_mean_iou(outputs, targets)

    return loss_sum / len(loader), iou_sum / len(loader)


def _plot_training_curves(train_losses, valid_losses, valid_ious):
    """Plot the loss curves (left) and the validation IoU curve (right)."""
    epoch_axis = range(1, len(train_losses) + 1)

    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(epoch_axis, train_losses, marker='o', label='Training Loss')
    plt.plot(epoch_axis, valid_losses, marker='o', label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    plt.subplot(1, 2, 2)
    plt.plot(epoch_axis, valid_ious, marker='o', label='Validation IoU')
    plt.title('Validation IoU')
    plt.xlabel('Epoch')
    plt.ylabel('Mean IoU')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.show()


def train_model(train_image_dir, train_mask_dir, valid_image_dir, valid_mask_dir, epochs=10, batch_size=20, lr=1e-5):
    """Train a 1-channel U-Net with BCE-with-logits and save its weights.

    After every epoch the mean training loss, mean validation loss and mean
    validation IoU are printed. When training ends the weights are written to
    ``unet_model.pth`` and the three curves are plotted.

    Args:
        train_image_dir: Folder with the training images.
        train_mask_dir: Folder with the training masks.
        valid_image_dir: Folder with the validation images.
        valid_mask_dir: Folder with the validation masks.
        epochs: Number of passes over the training set.
        batch_size: Batch size for both loaders.
        lr: Adam learning rate.

    Raises:
        ValueError: If the training or validation folder yields no samples.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device.type)

    train_dataset = XRayDataset(train_image_dir, train_mask_dir, transform=get_transforms())
    valid_dataset = XRayDataset(valid_image_dir, valid_mask_dir, transform=get_transforms())
    if len(train_dataset) == 0 or len(valid_dataset) == 0:
        # Fail early with a clear message instead of dividing by a zero
        # batch count further down.
        raise ValueError('Training and validation folders must each contain at least one .png image')

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Validation order does not matter for the metrics, so it is not shuffled.
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

    model = UNet(1, 1).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.BCEWithLogitsLoss()

    train_loss_values = []
    valid_loss_values = []
    valid_iou_values = []

    for epoch in range(epochs):
        train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        train_loss_values.append(train_loss)
        print(f'Epoch {epoch + 1}/{epochs}, Train Loss: {train_loss}')

        valid_loss, valid_iou = _evaluate(model, valid_loader, criterion, device)
        valid_loss_values.append(valid_loss)
        valid_iou_values.append(valid_iou)

        print(f'Epoch {epoch + 1}/{epochs}, Validation Loss: {valid_loss}, Validation IoU: {valid_iou}')
        print("=====================================================================")

    torch.save(model.state_dict(), 'unet_model.pth')

    _plot_training_curves(train_loss_values, valid_loss_values, valid_iou_values)

def compute_mean_iou(outputs, masks):
    """Compute one IoU value over an entire batch of logits.

    Args:
        outputs: Raw logits; a sigmoid is applied and the result thresholded
            at 0.5 to obtain the predicted mask.
        masks: Ground-truth tensor; every non-zero element counts as
            foreground in the logical operations below.

    Returns:
        The intersection-over-union as a Python float. Intersection and union
        are summed over all elements of the batch.
    """
    probabilities = torch.sigmoid(outputs)
    predicted = (probabilities > 0.5).float()

    overlap = torch.logical_and(predicted, masks).sum().float()
    # The tiny epsilon keeps the ratio finite when the union is empty.
    combined = torch.logical_or(predicted, masks).sum().float() + 1e-10

    ratio = overlap / combined
    return ratio.mean().item()

# Example usage: train on the default data folders as soon as this cell runs.
# train_model prints per-epoch metrics and writes 'unet_model.pth' when done.
print("Start training")
train_model('data/train/image', 'data/train/mask', 'data/valid/image', 'data/valid/mask')
print("Finished Training")

## Test Model Prediction

In [None]:
def test_model_prediction(model, image_path, mask_path, transform):
    """Predict the mask for one image, plot it next to the ground truth, return IoU.

    Args:
        model: Segmentation network returning raw logits (as ``UNet`` does).
        image_path: Path of the input image.
        mask_path: Path of the ground-truth mask.
        transform: Callable applied to both the image and the mask.

    Returns:
        The IoU between the predicted and the ground-truth mask, as computed
        by ``compute_iou``.
    """
    model.eval()

    # convert() returns a loaded copy, so the files can be closed immediately.
    with Image.open(image_path) as image_file:
        image = image_file.convert('L')
    with Image.open(mask_path) as mask_file:
        mask = mask_file.convert('L')

    # Feed the input on whatever device the model's weights live on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    image_tensor = transform(image).unsqueeze(0).to(device)
    mask_tensor = transform(mask).unsqueeze(0)

    with torch.no_grad():
        logits = model(image_tensor)

    # The network emits logits, so they must pass through a sigmoid before the
    # 0.5 probability threshold (same rule as compute_mean_iou).
    predicted_mask = (torch.sigmoid(logits) > 0.5).float()

    image_np = image_tensor.squeeze().cpu().numpy()
    mask_np = mask_tensor.squeeze().cpu().numpy()
    predicted_mask_np = predicted_mask.squeeze().cpu().numpy()

    # NOTE(review): compute_iou thresholds the *transformed* mask at 0.5; if
    # `transform` normalizes, that cut-off is in normalized units - confirm.
    iou = compute_iou(predicted_mask_np, mask_np)

    # Visualize input image, ground truth mask, and predicted mask
    _, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(image_np, cmap='gray')
    axes[0].set_title('Input Image')
    axes[0].axis('off')

    axes[1].imshow(mask_np)
    axes[1].set_title('Ground Truth Mask')
    axes[1].axis('off')

    axes[2].imshow(predicted_mask_np)
    axes[2].set_title(f'Predicted Mask (IoU: {iou:.4f})')
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()

    return iou

def compute_iou(predicted_mask, mask):
    """Intersection-over-union of two arrays after thresholding both at 0.5.

    Args:
        predicted_mask: Array of predicted values.
        mask: Array of ground-truth values.

    Returns:
        ``intersection / (union + 1e-10)``; the epsilon makes the result 0
        instead of a division error when both masks are empty.
    """
    pred_fg = predicted_mask > 0.5
    true_fg = mask > 0.5

    overlap = np.logical_and(pred_fg, true_fg).sum()
    combined = np.logical_or(pred_fg, true_fg).sum()

    return overlap / (combined + 1e-10)

def main(image_path='data/test/image/covid_1579.png',
         mask_path='data/test/mask/covid_1579.png',
         weights_path='unet_model.pth'):
    """Load saved U-Net weights and evaluate them on one image/mask pair.

    Args:
        image_path: Path of the test image.
        mask_path: Path of its ground-truth mask.
        weights_path: Path of the ``state_dict`` file written by training.
    """
    # Obtain the transformation pipeline
    transform = get_transforms()

    # The model is built on the CPU here, so map the checkpoint to the CPU as
    # well; otherwise weights saved from a CUDA device cannot be loaded on a
    # machine without one.
    model = UNet(n_channels=1, n_classes=1)
    state_dict = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()

    # Test the model on a single image and mask
    iou = test_model_prediction(model, image_path, mask_path, transform)
    print(f"IOU for the image: {iou:.4f}")


# Run the single-image evaluation as soon as this cell executes.
main()