In [1]:
import numpy as np
from random import randint
import torch.nn as nn
from torch.utils.data import Dataset
import torch
import torchvision as vs
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from torch import optim

from os import listdir
from os.path import splitext
from glob import glob

import cv2
from utils.nn_block import DualConv, DownConv, UpConv, OutputConv
from utils.image_aug import flip, add_gaussian_noise, add_uniform_noise, change_brightness, normalization2

from torchsummary import summary
import logging
from tqdm import tqdm

In [2]:
import os

# Root of the resized MIDV-500 data. Override it with the MIDV500_DATASET_DIR
# environment variable instead of editing the notebook.
DATA_ROOT = os.environ.get('MIDV500_DATASET_DIR', '/data/Data/midv500_data/dataset/')

# BasicDataset builds file paths by plain string concatenation, so every directory
# constant must end with a path separator. join(..., '') guarantees the trailing one.
IMAGES_PATH = os.path.join(DATA_ROOT, 'images_resized', '')
MASKS_PATH = os.path.join(DATA_ROOT, 'masks_resized', '')
MODEL_CHECKPOINT_PATH = os.path.join(DATA_ROOT, 'checkpoints', '')

In [3]:
class UNet(nn.Module):
    """U-Net with a 4-stage encoder and a 4-stage decoder joined by skip connections.

    Encoder widths are 64 -> 128 -> 256 -> 512 -> 1024 channels. The decoder mirrors
    them back down to 64, and ``op_conv`` maps those 64 channels to ``n_classes``
    output maps. ``train_model`` treats the outputs as logits (BCEWithLogitsLoss,
    sigmoid at validation time).

    The building blocks (DualConv, DownConv, UpConv, OutputConv) come from
    ``utils.nn_block``. UpConv presumably upsamples and merges the given skip tensor;
    confirm against that module.

    Args:
        n_channels: number of channels in the input image (3 for RGB).
        n_classes: number of output channels.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.n_channels, self.n_classes = n_channels, n_classes

        # Registration order is kept stable: it fixes parameter-init RNG
        # consumption and the layer order shown by torchsummary.
        self.inp = DualConv(n_channels, 64)

        # Encoder: each stage doubles the channel count.
        self.down_conv_1 = DownConv(64, 128)
        self.down_conv_2 = DownConv(128, 256)
        self.down_conv_3 = DownConv(256, 512)
        self.down_conv_4 = DownConv(512, 1024)

        # Decoder: each stage halves the channel count.
        self.up_conv_1 = UpConv(1024, 512)
        self.up_conv_2 = UpConv(512, 256)
        self.up_conv_3 = UpConv(256, 128)
        self.up_conv_4 = UpConv(128, 64)

        self.op_conv = OutputConv(64, n_classes)

    def forward(self, x):
        # Keep every encoder output; the decoder consumes them in reverse order.
        skip_64 = self.inp(x)
        skip_128 = self.down_conv_1(skip_64)
        skip_256 = self.down_conv_2(skip_128)
        skip_512 = self.down_conv_3(skip_256)
        bottleneck = self.down_conv_4(skip_512)

        decoded = self.up_conv_1(bottleneck, skip_512)
        decoded = self.up_conv_2(decoded, skip_256)
        decoded = self.up_conv_3(decoded, skip_128)
        decoded = self.up_conv_4(decoded, skip_64)

        return self.op_conv(decoded)

In [4]:
class BasicDataset(Dataset):
    """Image/mask pairs for binary segmentation, randomly augmented on every load.

    Every non-hidden file in ``imgs_dir`` defines a sample ID (its name without the
    extension). The matching mask is looked up as ``masks_dir + ID + mask_suffix + '.*'``.
    Both directory strings are concatenated directly with file names, so they are
    expected to end with a path separator.

    Each item is a dict:
        'image': FloatTensor in CHW layout (RGB), after augmentation and
                 ``normalization2(max=1, min=0)``.
        'mask':  FloatTensor in CHW layout with a single channel, binarized to 0/1.
    """

    def __init__(self, imgs_dir, masks_dir, mask_suffix=''):
        self.imgs_dir = imgs_dir
        self.masks_dir = masks_dir
        self.mask_suffix = mask_suffix
        # Nominal input size. Not used by this class: images are loaded at their on-disk size.
        self.height = 480
        self.width = 360

        # Sorted so the ID order (and therefore any seeded random_split) does not
        # depend on the filesystem's directory-listing order.
        self.ids = sorted(splitext(file)[0] for file in listdir(imgs_dir)
                          if not file.startswith('.'))
        logging.info(f'Creating dataset with {len(self.ids)} examples')

    def __len__(self):
        return len(self.ids)

    def preprocess(self, img, mask):
        """Randomly augment an (img, mask) pair and convert both from HWC to CHW.

        The same random flip is applied to image and mask. Noise and brightness
        changes are applied to the image only. ``mask`` is expected to hold 0/255
        values (as produced by the threshold in ``__getitem__``) and is rescaled to 0/1.
        """
        # flip {0: vertical, 1: horizontal, 2: both, 3: none}
        flip_num = randint(0, 3)
        img = flip(img, flip_num)
        mask = flip(mask, flip_num)

        # Noise type: 1 -> Gaussian noise, 0 -> uniform noise (equal probability).
        if randint(0, 1):
            gaus_sd, gaus_mean = randint(0, 20), 0
            img = add_gaussian_noise(img, gaus_mean, gaus_sd)
        else:
            l_bound, u_bound = randint(-20, 0), randint(0, 20)
            img = add_uniform_noise(img, l_bound, u_bound)

        # Brightness shift in [-20, 20].
        pix_add = randint(-20, 20)
        img = change_brightness(img, pix_add)

        # Rescale the image to the [min, max] = [0, 1] range requested from normalization2.
        img = normalization2(img, max=1, min=0)

        # 0/255 mask -> 0/1 mask.
        mask = mask / 255

        # Grayscale masks are HW; give them a channel axis so both arrays are HWC.
        if len(mask.shape) == 2:
            mask = np.expand_dims(mask, axis=2)

        # HWC to CHW
        img = img.transpose((2, 0, 1))
        mask = mask.transpose((2, 0, 1))

        return img, mask

    def __getitem__(self, i):
        from glob import escape  # local: the notebook's import cell only imports glob.glob

        idx = self.ids[i]
        # Escape the literal prefix so IDs/dirs containing '[', '*' or '?' match literally.
        mask_file = glob(escape(self.masks_dir + idx + self.mask_suffix) + '.*')
        img_file = glob(escape(self.imgs_dir + idx) + '.*')

        assert len(mask_file) == 1, \
            f'Either no mask or multiple masks found for the ID {idx}: {mask_file}'
        assert len(img_file) == 1, \
            f'Either no image or multiple images found for the ID {idx}: {img_file}'

        mask = cv2.imread(mask_file[0])
        img = cv2.imread(img_file[0])
        # cv2.imread returns None (instead of raising) for unreadable/corrupt files.
        if mask is None:
            raise IOError(f'Could not read mask file for the ID {idx}: {mask_file[0]}')
        if img is None:
            raise IOError(f'Could not read image file for the ID {idx}: {img_file[0]}')

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
        # Binarize: pixels > 100 become 255, everything else 0.
        mask = cv2.threshold(mask, 100, 255, cv2.THRESH_BINARY)[1]

        img, mask = self.preprocess(img, mask)

        return {
            'image': torch.from_numpy(img).type(torch.FloatTensor),
            'mask': torch.from_numpy(mask).type(torch.FloatTensor)
        }

In [5]:
# Stand-alone dataset instance for inspecting samples; train_model builds its own from the paths.
dataset = BasicDataset(imgs_dir=IMAGES_PATH, masks_dir=MASKS_PATH)

In [9]:
def _evaluate(model, loader, criterion, device):
    """Return the sample-weighted mean of ``criterion`` over ``loader`` (NaN if it is empty).

    Puts the model in eval mode (BatchNorm uses its running statistics and is not
    updated) and disables gradients. The caller must switch back to train mode.
    """
    model.eval()
    total_loss, n_samples = 0.0, 0
    with torch.no_grad():
        for batch in loader:
            imgs = batch['image'].to(device=device, dtype=torch.float32)
            true_masks = batch['mask'].to(device=device, dtype=torch.float32)
            masks_pred = model(imgs)
            # criterion averages over the batch; re-weight so a short last batch counts correctly.
            total_loss += criterion(masks_pred, true_masks).item() * imgs.shape[0]
            n_samples += imgs.shape[0]
    return total_loss / n_samples if n_samples else float('nan')


def train_model(model, device, img_dir, mask_dir, checkpoint_dir, epochs=20, lr=0.001, val_split=0.20,
                batch_size=1, num_workers=8, save_every=10):
    """Train ``model`` for binary segmentation on the image/mask pairs in ``img_dir``/``mask_dir``.

    The dataset is split randomly into train/validation subsets. Each epoch makes one
    pass over the training subset (BCE-with-logits loss, Adam, gradient values clipped
    to 0.1) and then computes the mean validation loss with the model in eval mode.
    Losses are logged to TensorBoard in ``checkpoint_dir``. A checkpoint
    ``<epoch>_model.pth`` is written there every ``save_every`` epochs and after the
    last epoch.

    Args:
        model: network producing one logit map per image (``n_classes == 1``).
        device: torch device to train on.
        img_dir, mask_dir: directories passed to ``BasicDataset``.
        checkpoint_dir: directory for TensorBoard events and checkpoints (created if missing).
        epochs: number of epochs; epochs 1..epochs inclusive are run.
        lr: Adam learning rate.
        val_split: fraction of samples held out for validation.
        batch_size: mini-batch size for both loaders.
        num_workers: DataLoader worker processes.
        save_every: checkpoint interval in epochs (<= 0 saves only after the last epoch).

    Returns:
        (training_loss, validation_loss): per-epoch mean losses.
    """
    import os  # local: the notebook's import cell only imports listdir/splitext from os

    dataset = BasicDataset(img_dir, mask_dir)
    val_samples = int(len(dataset) * val_split)
    train_samples = len(dataset) - val_samples
    # NOTE(review): both subsets share one BasicDataset, whose preprocess() always applies
    # random flips/noise/brightness, so validation samples are augmented as well.
    train, val = random_split(dataset, [train_samples, val_samples])
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    # Validation runs in eval mode, so BatchNorm does not need full batches: keep every sample.
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    os.makedirs(checkpoint_dir, exist_ok=True)
    # torch's keyword is `log_dir`. `comment` would be ignored once log_dir is given, so it is omitted.
    writer = SummaryWriter(log_dir=checkpoint_dir)
    global_step = 0

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-8)
    criterion = nn.BCEWithLogitsLoss()

    training_loss = []
    validation_loss = []

    try:
        for epoch in range(1, epochs + 1):
            model.train()
            running_loss, n_batches = 0.0, 0

            with tqdm(total=train_samples, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
                for batch in train_loader:
                    imgs = batch['image'].to(device=device, dtype=torch.float32)
                    # BCEWithLogitsLoss requires float targets shaped like the logits.
                    true_masks = batch['mask'].to(device=device, dtype=torch.float32)

                    masks_pred = model(imgs)
                    loss = criterion(masks_pred, true_masks)

                    optimizer.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_value_(model.parameters(), 0.1)
                    optimizer.step()

                    # Running mean kept incrementally instead of re-summing a list every step.
                    running_loss += loss.item()
                    n_batches += 1
                    mean_loss = running_loss / n_batches
                    writer.add_scalar('Loss/train', mean_loss, global_step)
                    pbar.set_postfix(loss=mean_loss)
                    pbar.update(imgs.shape[0])
                    global_step += 1

                epoch_loss = running_loss / n_batches if n_batches else float('nan')
                avg_val_loss = _evaluate(model, val_loader, criterion, device)
                writer.add_scalar('Loss/val', avg_val_loss, global_step)
                pbar.set_postfix(loss=epoch_loss, val_loss=avg_val_loss)

            training_loss.append(epoch_loss)
            validation_loss.append(avg_val_loss)

            if (save_every > 0 and epoch % save_every == 0) or epoch == epochs:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': training_loss,
                    'val_loss': validation_loss,
                    'global_step': global_step
                }, os.path.join(checkpoint_dir, f'{epoch}_model.pth'))
    finally:
        # Flush TensorBoard events even if training is interrupted.
        writer.close()

    return training_loss, validation_loss

In [7]:
# Build the network and print a layer summary for a 3 x 480 x 360 input.
# Fall back to CPU so this cell also runs on machines without a GPU.
summary_device = 'cuda' if torch.cuda.is_available() else 'cpu'
unet = UNet(n_channels=3, n_classes=1)
summary(unet.to(summary_device), (3, 480, 360), device=summary_device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 480, 360]           1,792
       BatchNorm2d-2         [-1, 64, 480, 360]             128
              ReLU-3         [-1, 64, 480, 360]               0
            Conv2d-4         [-1, 64, 480, 360]          36,928
       BatchNorm2d-5         [-1, 64, 480, 360]             128
              ReLU-6         [-1, 64, 480, 360]               0
          DualConv-7         [-1, 64, 480, 360]               0
         MaxPool2d-8         [-1, 64, 240, 180]               0
            Conv2d-9        [-1, 128, 240, 180]          73,856
      BatchNorm2d-10        [-1, 128, 240, 180]             256
             ReLU-11        [-1, 128, 240, 180]               0
           Conv2d-12        [-1, 128, 240, 180]         147,584
      BatchNorm2d-13        [-1, 128, 240, 180]             256
             ReLU-14        [-1, 128, 2

In [8]:
# Prefer the GPU when one is visible to torch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Device:", device)
unet = unet.to(device=device)

Device: cuda


In [None]:
# Trailing ';' keeps any return value of train_model from being echoed below the cell.
train_model(
    model=unet,
    device=device,
    img_dir=IMAGES_PATH,
    mask_dir=MASKS_PATH,
    checkpoint_dir=MODEL_CHECKPOINT_PATH,
);

Epoch 1/20:   6%|▌         | 981/16800 [03:14<53:09,  4.96img/s, loss=0.219]