In [None]:
# !pip install -U git+https://github.com/albu/albumentations --no-cache-dir

In [1]:
import torch 
import torchvision
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from matplotlib import pyplot as plt
from PIL import Image
import glob2
import numpy as np
import albumentations as A 
import copy
from albumentations.pytorch import ToTensor
import cv2
from tqdm import tqdm
import os
import subprocess

In [5]:
# from google.colab import drive
# drive.mount('/content/drive')

# ! tar -xf /content/drive/MyDrive/Colab\ Notebooks/Unet/CelebAMask-HQ-mask-anno.tar 
# ! tar -xf /content/drive/MyDrive/Colab\ Notebooks/Unet/CelebA-HQ-img.tar

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## HyperParameter

In [2]:
train_test_split = 0.9
lr =  0.01
batch_size = 16
epochs = 20
images_path = '/content/CelebA-HQ-img/'
anno_path = '/content/CelebAMask-HQ-mask-anno/'
device = 'cuda'

## Data Cleaning

In [None]:
# Data Clearning 
def delete_imgs_for_no_anno(images_path,anno_path):
  image_list =  glob2.glob(images_path+"/*jpg")
  for i,image in enumerate(image_list):
    # image_name = str(i) +".jpg"
    # image = images_path + "/" + image_name
    # print(image)
    # print(i)
    id =  int(image.split("/")[-1].split(".")[0])

    l_eye_name = ''.join([str(id).rjust(5, '0'), '_', 'l_eye', '.png'])
    r_eye_name = ''.join([str(id).rjust(5, '0'), '_', 'r_eye', '.png'])
    u_lip_name = ''.join([str(id).rjust(5, '0'), '_', 'u_lip', '.png'])
    l_lip_name = ''.join([str(id).rjust(5, '0'), '_', 'l_lip', '.png'])   

    anno_sub_folder = int( id / 2000)

    l_eye_path = anno_path +  "/" + str(anno_sub_folder) + "/" + l_eye_name
    r_eye_path = anno_path + "/"  + str(anno_sub_folder) + "/" + r_eye_name
    u_lip_path = anno_path + "/"  + str(anno_sub_folder) + "/" + u_lip_name
    l_lip_path = anno_path + "/"  + str(anno_sub_folder) + "/" + l_lip_name 

    if not (os.path.isfile(l_eye_path) and os.path.isfile(r_eye_path) and os.path.isfile(u_lip_path) and os.path.isfile(l_lip_path) ):
      print("deleting file",image)
      cmd = 'rm -rf ' + image
      # try:
      #   os.remove(image)
      # except:
      #   a=1
      subprocess.run( cmd ,shell=True )

In [None]:
# delete_imgs_for_no_anno(images_path,anno_path)

## Dataset / DataLoader

In [3]:
from torch.utils.data import Dataset, DataLoader

class mask_dataset(Dataset):
  def __init__(self, images_path, anno_path, transform, data_type, train_test_split, image_list):
    self.images_path = images_path
    self.anno_path = anno_path
    self.transform = transform
    self.imagelist = image_list
    self.data_type = data_type
    self.train_test_split = train_test_split

  def combine_mask_images(self, img1_path, img2_path):
    img1_mask = cv2.imread(img1_path ,cv2.IMREAD_GRAYSCALE)
    img2_mask = cv2.imread(img2_path ,cv2.IMREAD_GRAYSCALE)
    mask = img2_mask > 0.5
    comb_mask = img1_mask.copy()
    comb_mask[mask] = img2_mask[mask]
    return comb_mask


  def __getitem__(self,id):
    idx = int(self.imagelist[id].split("/")[-1].split(".")[0])

    if (self.data_type == 'test'):
      id = id + int( train_test_split * len(self.imagelist) )
      idx = int(self.imagelist[id].split("/")[-1].split(".")[0])

    image_name = str(idx) +".jpg"
    l_eye_name = ''.join([str(idx).rjust(5, '0'), '_', 'l_eye', '.png'])
    r_eye_name = ''.join([str(idx).rjust(5, '0'), '_', 'r_eye', '.png'])
    u_lip_name = ''.join([str(idx).rjust(5, '0'), '_', 'u_lip', '.png'])
    l_lip_name = ''.join([str(idx).rjust(5, '0'), '_', 'l_lip', '.png'])

    anno_sub_folder = int( idx / 2000)

    image_path = self.images_path + "/" + image_name
    # print("image_path",image_path)
    l_eye_path = self.anno_path +  "/" + str(anno_sub_folder) + "/" + l_eye_name
    r_eye_path = self.anno_path + "/"  + str(anno_sub_folder) + "/" + r_eye_name
    u_lip_path = self.anno_path + "/"  + str(anno_sub_folder) + "/" + u_lip_name
    l_lip_path = self.anno_path + "/"  + str(anno_sub_folder) + "/" + l_lip_name

    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (512, 512), interpolation = cv2.INTER_NEAREST)

    # l_eye_mask = cv2.imread(l_eye_path,cv2.IMREAD_GRAYSCALE)
    # r_eye_mask = cv2.imread(r_eye_path,cv2.IMREAD_GRAYSCALE)
    # u_lip_mask = cv2.imread(u_lip_path,cv2.IMREAD_GRAYSCALE)
    # l_lip_mask = cv2.imread(l_lip_path,cv2.IMREAD_GRAYSCALE)
    # masks = [l_eye_mask,r_eye_mask,u_lip_mask,l_lip_mask]

    # print("image_path:",image_path)
    # print("l_eye_path:",l_eye_path)
    eye_mask = self.combine_mask_images(l_eye_path,r_eye_path)
    lip_mask = self.combine_mask_images(u_lip_path,l_lip_path)
    masks = [eye_mask,lip_mask]

    transformed = self.transform(image=image, masks=masks) # one face, several  masks
    transformed_image = transformed['image']
    transformed_masks = transformed['masks']
    transformed_masks = torch.stack([ torch.from_numpy(transformed_masks[0]), torch.from_numpy(transformed_masks[1]) ]) # (2,512,512)
    # transformed_masks = torch.stack([ torch.from_numpy(transformed_masks[0]), torch.from_numpy(transformed_masks[1]),
    #                                  torch.from_numpy(transformed_masks[2]), torch.from_numpy(transformed_masks[3])]) # (4,512,512)
    transformed_masks = transformed_masks / 255.0
    return transformed_image,transformed_masks

  def __len__(self):
    total_image = len(self.imagelist)
    if (self.data_type == 'train'):
      num_dataset = int(self.train_test_split * total_image) 
    else:
      num_dataset = int ((1.0-self.train_test_split)*total_image)
    return num_dataset

In [4]:
train_transform = A.Compose(
    [
        # A.Resize(256, 256),
        A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=15, p=0.2),
        # A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.2),
        ToTensor(),
    ]
)

val_transform = A.Compose(
    [
        A.Resize(512, 512),
        ToTensor(),
    ]
)





In [None]:
# image_list = glob2.glob(images_path+"/*jpg")

# import pickle

# output = open('image_list.pkl', 'wb')
# pickle.dump(image_list, output)
# output.close()

In [5]:
import pickle
file = open('image_list.pkl', 'rb')
image_list = pickle.load(file)
file.close()


In [6]:
train_dataset = mask_dataset(images_path, anno_path, train_transform,'train', train_test_split, image_list)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

val_dataset = mask_dataset(images_path, anno_path, val_transform,'test', train_test_split, image_list)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)


In [11]:
def visualize_augmentations(dataset, idx=0, samples=5):
    dataset = copy.deepcopy(dataset)
    figure, ax = plt.subplots(nrows=samples, ncols=3, figsize=(18, 84))
    for i in range(samples):
        tensor_image, tensor_masks = dataset[idx+i]
        ax[i, 0].imshow(tensor_image.permute(1, 2, 0)) # (3,512,512) --> (512,512,3)
        ax[i, 2].imshow(tensor_masks[0].squeeze(), interpolation="nearest")  # (1,512,512) to (512,512)
        ax[i, 1].imshow(tensor_masks[1].squeeze(), interpolation="nearest") 
        # ax[i, 3].imshow(tensor_masks[2].squeeze(), interpolation="nearest")  # (1,512,512) to (512,512)
        # ax[i, 4].imshow(tensor_masks[3].squeeze(), interpolation="nearest") 
        ax[i, 0].set_title("Augmented image")
        ax[i, 2].set_title("Augmented eye")
        ax[i, 1].set_title("Augmented lips")
        # ax[i, 3].set_title("Augmented upper lips")
        # ax[i, 4].set_title("Augmented lower lips")
        ax[i, 0].set_axis_off()
        ax[i, 1].set_axis_off()
    plt.tight_layout()
    plt.show()

In [None]:
# visualize_augmentations(train_dataset,idx=0,samples=20)
# Refer README.md for visualize_augmentations output

## UNet Model

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)


    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


# class OutConv(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super(OutConv, self).__init__()
#         self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
#         self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=1)

#     def forward(self, x):
#         return self.conv2(self.conv1(x))

class light_capicity_Conv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(light_capicity_Conv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

In [8]:
class high_capicity_conv(nn.Module):
  def __init__(self,in_channels,out_channels):
    super(high_capicity_conv, self).__init__()
    self.capicity_conv = nn.Sequential(
            DoubleConv(in_channels, in_channels),
            DoubleConv(in_channels, in_channels),
            DoubleConv(in_channels, in_channels)
        )
    self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

  def forward(self, x):
    x = self.capicity_conv(x)
    x = self.conv(x)
    return x

In [9]:
import torch.nn.functional as F

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 16)
        self.down1 = Down(16, 32)
        self.down2 = Down(32, 64)
        self.down3 = Down(64, 128)
        factor = 2 if bilinear else 1
        self.down4 = Down(128, 256 // factor)
        self.up1 = Up(256, 128 // factor, bilinear)
        self.up2 = Up(128, 64 // factor, bilinear)
        self.up3 = Up(64, 32 // factor, bilinear)
        self.up4 = Up(32, 16, bilinear)
        # self.outc = OutConv(16, n_classes)
        self.eyes = high_capicity_conv(16,1)
        self.lips = high_capicity_conv(16,1)


    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        eyes = self.eyes(x)
        lips = self.lips(x)
        return eyes,lips
        # logits = self.outc(x)
        # return logits

In [10]:
unet = UNet(n_channels=3, n_classes=2)
unet.to(device)

from torchsummary import summary
summary(unet, (3, 512, 512))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 16, 512, 512]             448
       BatchNorm2d-2         [-1, 16, 512, 512]              32
              ReLU-3         [-1, 16, 512, 512]               0
            Conv2d-4         [-1, 16, 512, 512]           2,320
       BatchNorm2d-5         [-1, 16, 512, 512]              32
              ReLU-6         [-1, 16, 512, 512]               0
        DoubleConv-7         [-1, 16, 512, 512]               0
         MaxPool2d-8         [-1, 16, 256, 256]               0
            Conv2d-9         [-1, 32, 256, 256]           4,640
      BatchNorm2d-10         [-1, 32, 256, 256]              64
             ReLU-11         [-1, 32, 256, 256]               0
           Conv2d-12         [-1, 32, 256, 256]           9,248
      BatchNorm2d-13         [-1, 32, 256, 256]              64
             ReLU-14         [-1, 32, 2

## Training Model

In [10]:
def train(train_loader, model, criterion, optimizer, epoch, device, weights):
    model.train()
    stream = tqdm(train_loader)
    for i, (images, target) in enumerate(stream, start=1):
        images = images.to(device)
        eyes,lips = model(images)
        target = target.to(device, dtype=torch.float32)
        target_channel_first = target.permute(1,0,2,3) 

        loss_eyes  = criterion(eyes.squeeze(), target_channel_first[0])
        loss_lips = criterion(lips.squeeze(), target_channel_first[1])

        loss = loss_eyes * weights[0] + loss_lips * weights[1] 
        loss = loss / (sum(weights))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        stream.set_description(
            "Epoch: {epoch}. Train.  Loss:{loss}    ".format(epoch=epoch,loss={loss.item()})
        )

def validate(val_loader, model, criterion, epoch, device):
    model.eval()
    stream = tqdm(val_loader)
    with torch.no_grad():
        for i, (images, target) in enumerate(stream, start=1):
            images = images.to(device)
            eyes,lips = model(images)
            target = target.to(device, dtype=torch.float32)
            target_channel_first = target.permute(1,0,2,3) 

            loss_eyes  = criterion(eyes.squeeze(), target_channel_first[0])
            loss_lips = criterion(lips.squeeze(), target_channel_first[1])

            loss = loss_eyes * 1 + loss_lips * 1
            loss = loss / 2
            # metric_monitor.update("Loss", loss.item())
            stream.set_description(
                "Epoch: {epoch}. Validation. Val Loss:{loss} ".format(epoch=epoch,loss={loss.item()})
            )

In [11]:
class SoftDiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(SoftDiceLoss, self).__init__()

    def forward(self, logits, targets):
        smooth = 1
        num = targets.size(0)
        """
       I am assuming the model does not have sigmoid layer in the end. if that is the case, change torch.sigmoid(logits) to simply logits
        """
        probs = torch.sigmoid(logits)
        # print("probs max",torch.max(probs))
        # print("probs min",torch.min(probs))
        m1 = probs.view(num, -1)
        m2 = targets.view(num, -1)
        intersection = (m1 * m2)

        score = 2. * (intersection.sum(1) + smooth) / (m1.sum(1) + m2.sum(1) + smooth)
        score = 1 - score.sum() / num
        return score

In [14]:
criterion = SoftDiceLoss().to(device) 
optimizer = torch.optim.SGD(unet.parameters(), lr=lr)
weights = [3.,1.]

for epoch in range(1, epochs + 1):
  train(train_dataloader, unet, criterion, optimizer, epoch, device, weights)
  torch.save(unet,"/content/drive/MyDrive/CelebAMask/unet_epoch"+str(epoch)+".pt")
  if(epoch==5):
    validate(val_dataloader, unet, criterion, epoch, device)

Epoch: 1. Train.  Loss:{0.9591552019119263}    : 100%|██████████| 1631/1631 [58:25<00:00,  2.15s/it]


FileNotFoundError: ignored

In [16]:
torch.save(unet,"/content/drive/MyDrive/Colab Notebooks/Unet/unet_epoch"+str(epoch)+".pt")

In [17]:
criterion = SoftDiceLoss().to(device) 
optimizer = torch.optim.SGD(unet.parameters(), lr=lr)
weights = [3.,1.]

for epoch in range(2, epochs + 1):
  train(train_dataloader, unet, criterion, optimizer, epoch, device, weights)
  torch.save(unet,"/content/drive/MyDrive/Colab\ Notebooks/Unet/unet_epoch"+str(epoch)+".pt")
  if(epoch==5):
    validate(val_dataloader, unet, criterion, epoch, device)

Epoch: 2. Train.  Loss:{0.8400422930717468}    : 100%|██████████| 1631/1631 [57:40<00:00,  2.12s/it]


FileNotFoundError: ignored

In [19]:
criterion = SoftDiceLoss().to(device) 
optimizer = torch.optim.SGD(unet.parameters(), lr=lr)
weights = [3.,1.]

for epoch in range(3, epochs + 1):
  train(train_dataloader, unet, criterion, optimizer, epoch, device, weights)
  torch.save(unet,"/content/drive/MyDrive/Colab Notebooks/Unet/unet_epoch"+str(epoch)+".pt")
  if(epoch==5):
    validate(val_dataloader, unet, criterion, epoch, device)

Epoch: 3. Train.  Loss:{0.2073861062526703}    : 100%|██████████| 1631/1631 [57:44<00:00,  2.12s/it]
Epoch: 4. Train.  Loss:{0.11479796469211578}    : 100%|██████████| 1631/1631 [57:37<00:00,  2.12s/it]
Epoch: 5. Train.  Loss:{0.08407671749591827}    : 100%|██████████| 1631/1631 [57:19<00:00,  2.11s/it]
  0%|          | 0/182 [00:00<?, ?it/s]

TypeError: ignored

In [14]:
validate(val_dataloader, unet, criterion, epoch, device)

Epoch: 5. Validation. Val Loss:{0.11279505491256714} : 100%|██████████| 182/182 [02:34<00:00,  1.18it/s]


## Evaluation metrics

In [11]:
class Validation_SoftDiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(Validation_SoftDiceLoss, self).__init__()

    def forward(self, logits, targets):
        smooth = 1
        num = targets.size(0)
        """
       I am assuming the model does not have sigmoid layer in the end. if that is the case, change torch.sigmoid(logits) to simply logits
        """
        probs = torch.sigmoid(logits)
        probs[probs>0.5] = 1.0
        probs[probs<=0.5] = 0.0
        # print("probs shapre",probs.shape)
        # print("target shape",targets.shape)
        m1 = probs.view(num, -1)
        m2 = targets.view(num, -1)
        intersection = (m1 * m2)

        score = 2. * (intersection.sum(1) + smooth) / (m1.sum(1) + m2.sum(1) + smooth)
        dice_score = score.sum() / num
        return dice_score

def validate(val_loader, model, criterion, epoch, device):
    model.eval()
    stream = tqdm(val_loader)
    total_dice_score = 0
    with torch.no_grad():
        for i, (images, target) in enumerate(stream, start=1):
            images = images.to(device)
            target = target.to(device, dtype=torch.float32)
            eyes,lips = model(images)
            target_channel_first = target.permute(1,0,2,3)

            dice_score_eyes  = criterion(eyes.squeeze(), target_channel_first[0])
            dice_score_lips = criterion(lips.squeeze(), target_channel_first[1])
            dice_score = dice_score_eyes * 1  + dice_score_lips * 1 
            dice_score = dice_score / 2

            # loss = criterion(output, target)
            total_dice_score += dice_score
            # metric_monitor.update("Loss", loss.item())
            stream.set_description(
                "Epoch: {epoch}. Validation. Val dice_score:{dice_score} ".format(epoch=epoch,dice_score={dice_score.item()})
            )
        avg_dice_score = total_dice_score/(i+1)
    return avg_dice_score

In [15]:
epoch = 1
criterion  = Validation_SoftDiceLoss().to(device)
dice_score = validate(val_dataloader, unet, criterion, epoch, device)
print("dice_score on test dataset:", dice_score)

Epoch: 1. Validation. Val dice_score:{0.8940052390098572} : 100%|██████████| 182/182 [02:26<00:00,  1.24it/s]

dice_score on test dataset: tensor(0.8884, device='cuda:0')





In [17]:
print("dice_score on test dataset:", dice_score.item())

dice_score on test dataset: 0.8884453177452087
