<a href="https://colab.research.google.com/github/Lexuanthangutc/Courses/blob/main/NB_01_UNet_DogCatDataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1. Packages

Documentation
- [TorchMetrics](https://lightning.ai/docs/torchmetrics/stable/)
- [Segmentation Models PyTorch](https://smp.readthedocs.io/en/latest/)
- [Albumentations](https://albumentations.ai/docs/examples/pytorch_classification/)


In [None]:
!pip3 install torchmetrics
!pip3 install segmentation-models-pytorch
!pip3 install albumentations

In [None]:
# Download the dog and cat segmentation dataset oxford
!wget https://thor.robots.ox.ac.uk/~vgg/data/pets/images.tar.gz
!wget https://thor.robots.ox.ac.uk/~vgg/data/pets/annotations.tar.gz

In [None]:
# extract the file
!tar -xf annotations.tar.gz
!tar -xf images.tar.gz

In [None]:
import os
import random
from glob import glob  # list files matching a path pattern

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchmetrics
from torchmetrics import Dice, JaccardIndex
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2  # np.array -> torch.tensor
from tqdm import tqdm

# 2. Read and Understand the Data

In [None]:

# image_path = "/content/images/Abyssinian_1.jpg"
# image = cv2.imread(image_path)
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# plt.subplot(1,2,1)
# plt.imshow(image)
# plt.show()
# print(image.shape)

# mask_path = "/content/annotations/trimaps/Abyssinian_1.png"
# mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
# plt.subplot(1,2,2)
# plt.imshow(mask)
# plt.show()
# print(mask)
# print(mask.shape)


In [None]:
class DogCatDataset(Dataset):
    """Oxford-IIIT Pet images paired with binary foreground/background masks.

    Args:
        root_dir: directory containing ``images/<stem>.jpg`` and
            ``annotations/trimaps/<stem>.png``.
        txt_file: split file; the first whitespace-separated token of each
            non-empty, non-``#`` line is taken as the image stem.
        transform: optional Albumentations pipeline, called as
            ``transform(image=..., mask=...)``.

    ``__getitem__`` returns ``(image, mask)``: RGB image and 0/1 mask, either as
    NumPy arrays (no transform) or whatever the transform produces.
    """

    def __init__(self, root_dir, txt_file, transform=None):
        super().__init__()
        self.root_dir = root_dir
        self.txt_file = txt_file
        self.transform = transform
        self.img_path_lst = []
        with open(self.txt_file) as file_in:
            for line in file_in:
                # split() (no argument) handles spaces, tabs and the trailing newline.
                fields = line.split()
                # Skip blank lines and '#' comment lines so they never become bogus samples.
                if not fields or fields[0].startswith("#"):
                    continue
                self.img_path_lst.append(fields[0])

    def __len__(self):
        return len(self.img_path_lst)

    def __getitem__(self, idx):
        stem = self.img_path_lst[idx]
        image_path = os.path.join(self.root_dir, "images", "{}.jpg".format(stem))
        mask_path = os.path.join(self.root_dir, "annotations", "trimaps", "{}.png".format(stem))

        # cv2.imread returns None (instead of raising) for missing/unreadable files.
        image = cv2.imread(image_path)
        if image is None:
            raise OSError("Could not read image file (missing or unsupported format): {}".format(image_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise OSError("Could not read mask file (missing or unsupported format): {}".format(mask_path))

        # Trimap -> binary mask:
        #   foreground (1)      -> 1
        #   background (2)      -> 0
        #   not classified (3)  -> 1
        mask[mask == 2] = 0
        mask[mask == 3] = 1

        if self.transform is not None:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']

        return image, mask


In [None]:
trainsize = 256
# Standard ImageNet per-channel mean/std, shared by both pipelines.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Training: resize + photometric/geometric augmentation, then normalize and convert to tensors.
train_transform = A.Compose([
    A.Resize(width=trainsize, height=trainsize),
    A.HorizontalFlip(),
    A.RandomBrightnessContrast(),
    A.Blur(),
    A.Sharpen(),
    A.RGBShift(),
    # A.Cutout was deprecated in favour of CoarseDropout and is gone from recent
    # Albumentations releases (AttributeError). Default hole sizes differ between
    # versions - tune the arguments if the augmentation is too weak/strong.
    A.CoarseDropout(),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD, max_pixel_value=255.0),
    ToTensorV2(),
])

# Evaluation: deterministic resize + the same normalization, no augmentation.
test_transform = A.Compose([
    A.Resize(width=trainsize, height=trainsize),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD, max_pixel_value=255.0),
    ToTensorV2(),
])

In [None]:
# Paths assume the archives were extracted into Colab's /content directory.
DATA_ROOT = "/content"
SPLIT_DIR = os.path.join(DATA_ROOT, "annotations")

train_dataset = DogCatDataset(
    DATA_ROOT,
    os.path.join(SPLIT_DIR, "trainval.txt"),
    transform=train_transform,
)
test_dataset = DogCatDataset(
    DATA_ROOT,
    os.path.join(SPLIT_DIR, "test.txt"),
    transform=test_transform,
)
# Optional sanity check:
# image, mask = train_dataset[10]
# print(image.shape, mask.shape, mask.unique())

## Un-normalize images for display

In [None]:
def tensor_to_np(tensor):
    """Convert a (possibly GPU, possibly grad-tracking) tensor to a NumPy array.

    The tensor is detached from autograd and moved to the CPU first; when it
    already lives on the CPU the returned array may share its memory.
    """
    cpu_tensor = tensor.detach().to("cpu")
    return cpu_tensor.numpy()

def np_to_tensor(array):
    """Copy ``array`` into a new float32 PyTorch tensor (no memory sharing)."""
    as_tensor = torch.tensor(array)  # torch.tensor always copies its input
    return as_tensor.to(torch.float32)

def inverse_norm(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo per-channel normalization so an image can be displayed.

    ``A.Normalize`` maps ``x -> (x - mean) / std`` (after dividing by
    ``max_pixel_value``); this computes the inverse ``x * std + mean``
    channel-wise. The defaults are the mean/std used by the notebook's
    transforms.

    Args:
        image: normalized tensor with channels on the third-from-last axis,
            i.e. ``(C, H, W)`` or batched ``(B, C, H, W)``; any device, may
            require grad.
        mean: per-channel means used for the forward normalization.
        std: per-channel standard deviations used for the forward normalization.

    Returns:
        A float32 CPU tensor with the same shape as ``image``.
    """
    image = image.detach().cpu().float()
    # view(-1, 1, 1) broadcasts the statistics over H and W (and any leading batch axis).
    mean_t = torch.tensor(mean, dtype=image.dtype).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=image.dtype).view(-1, 1, 1)
    return image * std_t + mean_t
# inv_img_tensor = inverse_norm(image)

In [None]:
# plt.subplot(1,2,1)
# plt.imshow(inv_img_tensor.permute(1,2,0))
# plt.subplot(1,2,2)
# plt.imshow(mask)
# plt.show()

# 3. Create model

In [None]:
def unet_block(in_channels, out_channels):
    """Two 3x3 convolutions (stride 1, padding 1), each followed by ReLU.

    Padding 1 keeps the spatial size unchanged; only the channel count goes
    from ``in_channels`` to ``out_channels``.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, 1, 1),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, 3, 1, 1),
        nn.ReLU(),
    )


class UNet(nn.Module):
    """Four-level U-Net: encoder, bottleneck, and decoder with skip connections.

    Args:
        num_classes: output channels (raw logits) of the final 1x1 convolution.
        in_channels: channels of the input image (default 3).
        base_channels: width of the first level; each deeper level doubles it
            (default 64 -> 64/128/256/512 with a 1024-channel bottleneck).

    For an input of shape ``(B, in_channels, H, W)`` the output has shape
    ``(B, num_classes, H, W)``. H and W need not be divisible by 16 (but must be
    at least 16 so the four max-pools do not reach size 0): when pooling floored
    an odd size, the upsampled decoder features are resized to match their
    skip connection before concatenation.
    """

    def __init__(self, num_classes, in_channels=3, base_channels=64):
        super().__init__()
        self.num_classes = num_classes
        c1, c2, c3, c4, c5 = (base_channels * 2 ** i for i in range(5))

        self.downsample = nn.MaxPool2d(2)
        # align_corners=False is what bilinear Upsample uses when left unset; made explicit.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

        self.block_down1 = unet_block(in_channels, c1)
        self.block_down2 = unet_block(c1, c2)
        self.block_down3 = unet_block(c2, c3)
        self.block_down4 = unet_block(c3, c4)

        self.block_neck = unet_block(c4, c5)

        # Each decoder block sees [skip connection, upsampled deeper features].
        self.block_up1 = unet_block(c5 + c4, c4)
        self.block_up2 = unet_block(c3 + c4, c3)
        self.block_up3 = unet_block(c2 + c3, c2)
        self.block_up4 = unet_block(c1 + c2, c1)

        self.conv_cls = nn.Conv2d(c1, self.num_classes, 1)

    def _up_and_merge(self, x, skip):
        """Upsample ``x`` to ``skip``'s spatial size and concatenate as ``[skip, x]``."""
        up = self.upsample(x)
        if up.shape[-2:] != skip.shape[-2:]:
            # MaxPool2d floors odd sizes, so plain 2x upsampling undershoots; resize exactly.
            up = nn.functional.interpolate(
                x, size=skip.shape[-2:], mode='bilinear', align_corners=False
            )
        return torch.cat([skip, up], dim=1)

    def forward(self, x):
        # Encoder: keep each level's features for the skip connections.
        x1 = self.block_down1(x)
        x2 = self.block_down2(self.downsample(x1))
        x3 = self.block_down3(self.downsample(x2))
        x4 = self.block_down4(self.downsample(x3))

        x = self.block_neck(self.downsample(x4))

        # Decoder: upsample, merge with the matching skip, refine.
        x = self.block_up1(self._up_and_merge(x, x4))
        x = self.block_up2(self._up_and_merge(x, x3))
        x = self.block_up3(self._up_and_merge(x, x2))
        x = self.block_up4(self._up_and_merge(x, x1))

        return self.conv_cls(x)

# model = UNet(1)
# x = torch.rand(4, 3, trainsize, trainsize)
# print("Input shape =", x.shape)
# y = model(x)
# print("Out shape =", y.shape)

In [None]:
class AverageMetric:
    """Running, sample-weighted mean of a scalar such as a loss or score.

    Attributes:
        val: the most recent value passed to ``update``.
        sum: weighted sum of all values since the last ``reset``.
        count: total weight (number of samples) since the last ``reset``.
        avg: ``sum / count``; 0 before the first ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard everything accumulated so far."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` samples and refresh ``avg``."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count


In [None]:
#device
# Reproducibility: seeds model initialisation and DataLoader shuffling (best effort).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

#device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 8

# os.cpu_count() can return None; fall back to loading in the main process.
n_workers = os.cpu_count() or 0
print("number of workers=", n_workers)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=n_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=n_workers)

#model: a single output channel holding the per-pixel foreground logit
model = UNet(1).to(device)

#loss: BCEWithLogitsLoss applies the sigmoid itself, so the model outputs raw logits
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 30

#metrics on 0/1 masks, all computed for the foreground class so they are comparable.
# Binary F1 = 2TP / (2TP + FP + FN), i.e. the foreground Dice coefficient.
# (Dice(num_classes=2, average="macro") would also average in the background class,
# unlike the binary IoU; for task="binary" num_classes/average are not used.)
dice_fn = torchmetrics.F1Score(task="binary").to(device)
iou_fn = torchmetrics.JaccardIndex(task="binary").to(device)
acc_fn = torchmetrics.Accuracy(task="binary").to(device)

# running averages reported per epoch
acc_metric = AverageMetric()
dice_metric = AverageMetric()
iou_metric = AverageMetric()
train_loss_metric = AverageMetric()
train_loss_metric = AverageMetric()

In [None]:
for epoch in range(num_epochs):
    # Each epoch reports its own averages.
    acc_metric.reset()
    dice_metric.reset()
    iou_metric.reset()
    train_loss_metric.reset()

    model.train()
    for x, y in tqdm(train_loader, desc="epoch {}".format(epoch)):
        optimizer.zero_grad()
        n = x.shape[0]
        x = x.to(device).float()
        y = y.to(device).float()
        # squeeze(1) drops only the channel axis: (B, 1, H, W) -> (B, H, W).
        # A bare squeeze() would also drop the batch axis when the last batch has B == 1.
        yhat = model(x).squeeze(1)

        loss = criterion(yhat, y)
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            yhat_mask = yhat.sigmoid().round().long()  # logits -> {0, 1} mask
            y_long = y.long()
            dice_score = dice_fn(yhat_mask, y_long)
            iou_score = iou_fn(yhat_mask, y_long)
            accuracy = acc_fn(yhat_mask, y_long)

            # Per-batch values, weighted by batch size in the running averages.
            acc_metric.update(accuracy.item(), n)
            dice_metric.update(dice_score.item(), n)
            iou_metric.update(iou_score.item(), n)
            train_loss_metric.update(loss.item(), n)

    print("Epoch {}: train_loss = {}, accuracy = {}, iou_score = {}, dice_score = {}".format(
        epoch, train_loss_metric.avg, acc_metric.avg, iou_metric.avg, dice_metric.avg
    ))

In [None]:
# Persist only the learned weights (state_dict), not the pickled module object.
checkpoint_path = os.path.join("/content", "model_last.pth")
torch.save(model.state_dict(), checkpoint_path)

In [None]:
# Evaluate on the held-out split: per-batch Dice/IoU, averaged with batch-size weights.
model.eval()

test_iou_metric = AverageMetric()
test_dice_metric = AverageMetric()
with torch.no_grad():  # inference only: no gradients, so nothing for the optimizer to clear
    for x, y in tqdm(test_loader, desc="test"):
        n = x.shape[0]
        x = x.to(device).float()
        y = y.to(device).long()
        # squeeze(1): (B, 1, H, W) -> (B, H, W) without dropping the batch axis when B == 1.
        yhat = model(x).squeeze(1)
        yhat_mask = yhat.sigmoid().round().long()  # logits -> {0, 1} mask
        dice_score = dice_fn(yhat_mask, y)
        iou_score = iou_fn(yhat_mask, y)
        test_dice_metric.update(dice_score.item(), n)
        test_iou_metric.update(iou_score.item(), n)

print("TEST: IoU = {}, dice = {}".format(test_iou_metric.avg, test_dice_metric.avg))

In [None]:
# Show one random test sample: input image, ground-truth mask and predicted mask.
model.eval()
# Sample from the whole test split (randint(0, 100) only ever used the first 101 items).
idx = random.randrange(len(test_dataset))
with torch.no_grad():
    x, y = test_dataset[idx]  # x: normalized (C, H, W) image tensor, y: mask
    x = x.to(device).float().unsqueeze(0)  # add batch axis -> (1, C, H, W)

    yhat = model(x).squeeze()  # (1, 1, H, W) -> (H, W)
    yhat_mask = yhat.sigmoid().round().long()  # logits -> {0, 1} mask

    inv_img_tensor = inverse_norm(x.squeeze(0))

panels = [
    # clamp: float-rounding after un-normalizing can leave values just outside [0, 1]
    (inv_img_tensor.permute(1, 2, 0).cpu().clamp(0, 1), "Input image"),
    (y.cpu(), "Ground-truth mask"),
    (yhat_mask.cpu(), "Predicted mask"),
]
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, (img, title) in zip(axes, panels):
    ax.imshow(img)
    ax.set_title(title)
    ax.axis("off")
fig.suptitle("Test sample {}".format(idx))
plt.show()