In [1]:
import numpy as np
import torch.nn as nn
from torch.utils.data import Dataset
import torch
import torchvision as vs
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from torch import optim

from os import listdir
from os.path import splitext
from glob import glob

from PIL import Image
from utils.nn_block import DualConv, DownConv, UpConv, OutputConv

from torchsummary import summary
import logging
from tqdm import tqdm

In [2]:
class UNet(nn.Module):
    """Encoder/decoder network assembled from the project's conv blocks.

    An input ``DualConv`` is followed by four ``DownConv`` stages
    (64 -> 128 -> 256 -> 512 -> 1024), four ``UpConv`` stages that walk the
    same widths back down, and a final ``OutputConv`` producing ``n_classes``
    channels. The blocks themselves live in ``utils.nn_block``.

    Args:
        n_channels: Channel count handed to the first ``DualConv``.
        n_classes: Channel count handed to the final ``OutputConv``.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()

        # Exposed as attributes so callers can inspect the configuration.
        self.n_channels, self.n_classes = n_channels, n_classes

        # Input stage.
        self.inp = DualConv(n_channels, 64)

        # Contracting path.
        self.down_conv_1 = DownConv(64, 128)
        self.down_conv_2 = DownConv(128, 256)
        self.down_conv_3 = DownConv(256, 512)
        self.down_conv_4 = DownConv(512, 1024)

        # Expanding path (mirrors the contracting widths).
        self.up_conv_1 = UpConv(1024, 512)
        self.up_conv_2 = UpConv(512, 256)
        self.up_conv_3 = UpConv(256, 128)
        self.up_conv_4 = UpConv(128, 64)

        # Output stage.
        self.op_conv = OutputConv(64, n_classes)

    def forward(self, x):
        """Run ``x`` through the input, down, up and output stages.

        Each ``UpConv`` receives the running decoder tensor together with the
        encoder tensor produced at the mirrored stage (presumably used as a
        skip connection inside ``UpConv`` -- see ``utils.nn_block``).
        """
        encoder_stages = (
            self.down_conv_1,
            self.down_conv_2,
            self.down_conv_3,
            self.down_conv_4,
        )
        decoder_stages = (
            self.up_conv_1,
            self.up_conv_2,
            self.up_conv_3,
            self.up_conv_4,
        )

        # Stack of encoder outputs; the deepest one ends up on top.
        features = [self.inp(x)]
        for down in encoder_stages:
            features.append(down(features[-1]))

        # Start decoding from the deepest feature map and pair every up-stage
        # with the next-shallower encoder output.
        decoded = features.pop()
        for up in decoder_stages:
            decoded = up(decoded, features.pop())

        return self.op_conv(decoded)

In [3]:
class BasicDataset(Dataset):
    """Paired image/mask dataset read from two directories.

    Sample IDs are the extension-less names of the non-hidden entries of
    ``imgs_dir``. For each ID the image is located with the glob pattern
    ``imgs_dir + id + '.*'`` and the mask with
    ``masks_dir + id + mask_suffix + '.*'``. The directory strings are
    concatenated as-is, so they must end with a path separator.

    Args:
        imgs_dir: Directory (with trailing separator) holding the images.
        masks_dir: Directory (with trailing separator) holding the masks.
        mask_suffix: String appended to the ID when looking up the mask file.
        img_size: ``(width, height)`` every image and mask is resized to.
            Defaults to ``(360, 480)``.
    """

    def __init__(self, imgs_dir, masks_dir, mask_suffix='', img_size=(360, 480)):
        self.imgs_dir = imgs_dir
        self.masks_dir = masks_dir
        self.mask_suffix = mask_suffix
        self.img_size = tuple(img_size)

        # Hidden entries (names starting with '.') are not samples.
        self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
                    if not file.startswith('.')]
        logging.info(f'Creating dataset with {len(self.ids)} examples')

    def __len__(self):
        """Return the number of sample IDs found in ``imgs_dir``."""
        return len(self.ids)

    @classmethod
    def preprocess(cls, pil_img, newW, newH):
        """Resize an image and convert it to a channel-first array.

        Args:
            pil_img: Image object providing ``resize((width, height))`` and
                convertible with ``np.array``.
            newW: Target width in pixels.
            newH: Target height in pixels.

        Returns:
            ``numpy.ndarray`` laid out as (channels, height, width). Values
            are divided by 255 only when the array maximum exceeds 1;
            otherwise the array is returned unscaled.
        """
        img_nd = np.array(pil_img.resize((newW, newH)))

        # Single-channel images come back as HxW; add an explicit channel axis.
        if img_nd.ndim == 2:
            img_nd = np.expand_dims(img_nd, axis=2)

        # HWC to CHW
        img_trans = img_nd.transpose((2, 0, 1))
        if img_trans.max() > 1:
            img_trans = img_trans / 255

        return img_trans

    def __getitem__(self, i):
        """Load, resize and tensorise the ``i``-th image/mask pair.

        Returns:
            dict with float tensors under the keys ``'image'`` and ``'mask'``.

        Raises:
            AssertionError: if the image or mask glob does not match exactly
                one file, or if image and mask differ in size on disk.
        """
        idx = self.ids[i]
        mask_file = glob(self.masks_dir + idx + self.mask_suffix + '.*')
        img_file = glob(self.imgs_dir + idx + '.*')

        assert len(mask_file) == 1, \
            f'Either no mask or multiple masks found for the ID {idx}: {mask_file}'
        assert len(img_file) == 1, \
            f'Either no image or multiple images found for the ID {idx}: {img_file}'

        width, height = self.img_size

        # Context managers release the underlying file handles as soon as the
        # pixel data has been copied into numpy arrays.
        with Image.open(mask_file[0]) as mask, Image.open(img_file[0]) as img:
            assert img.size == mask.size, \
                f'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}'

            img_arr = self.preprocess(img, width, height)
            mask_arr = self.preprocess(mask, width, height)

        return {
            'image': torch.from_numpy(img_arr).type(torch.FloatTensor),
            'mask': torch.from_numpy(mask_arr).type(torch.FloatTensor)
        }

In [4]:
# Dataset locations. The trailing '/' matters: BasicDataset builds file paths by
# plain string concatenation of directory + sample ID.
IMAGES_PATH = '/data/Data/midv500_data/dataset/images/'
MASKS_PATH = '/data/Data/midv500_data/dataset/masks/'
# NOTE(review): presumably the target directory for saved checkpoints; it is not
# referenced by any code visible in this notebook -- confirm before relying on it.
MODEL_CHECKPOINT_PATH = '/data/Data/midv500_data/dataset/checkpoints/'

In [5]:
# BasicDataset's third positional parameter is mask_suffix, a string appended to
# the sample ID when globbing for the mask file; it is left at its default ('').
dataset = BasicDataset(IMAGES_PATH, MASKS_PATH)

In [6]:
def train_model(model, device, img_dir, mask_dir, epochs=5, lr=0.001, val_split=0.20, batch_size=1):
    """Train ``model`` on the image/mask pairs found in ``img_dir``/``mask_dir``.

    The dataset is split randomly into a training and a validation subset.
    The model is optimised with Adam against ``BCEWithLogitsLoss``; gradients
    are clipped element-wise to [-0.1, 0.1]. The running mean training loss is
    written to TensorBoard under ``Loss/train`` at every step and the mean
    validation loss under ``Loss/val`` once per epoch.

    Args:
        model: Network to train. Must expose ``n_channels`` and ``n_classes``.
        device: ``torch.device`` the batches are moved to.
        img_dir: Directory of input images, passed to ``BasicDataset``.
        mask_dir: Directory of masks, passed to ``BasicDataset``.
        epochs: Number of passes over the training subset.
        lr: Adam learning rate.
        val_split: Fraction of the dataset held out for validation.
        batch_size: Batch size for both loaders.

    Returns:
        None. ``model`` is updated in place.

    Raises:
        AssertionError: if a batch's channel count differs from
            ``model.n_channels``.
    """
    dataset = BasicDataset(img_dir, mask_dir)
    val_samples = int(len(dataset) * val_split)
    train_samples = len(dataset) - val_samples
    train, val = random_split(dataset, [train_samples, val_samples])
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=8)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, drop_last=True)

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-8)
    criterion = nn.BCEWithLogitsLoss()
    # NOTE(review): BCEWithLogitsLoss is fed float targets in the n_classes == 1
    # case; the torch.long branch looks unsupported by this criterion as
    # written -- confirm before training with n_classes != 1.
    mask_type = torch.float32 if model.n_classes == 1 else torch.long

    writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}')
    global_step = 0

    # try/finally guarantees the event file is flushed and closed even when
    # training is interrupted or raises.
    try:
        for epoch in range(epochs):
            model.train()

            # Running statistics for this epoch's mean training loss.
            loss_total = 0.0
            batches_seen = 0
            train_loss = float('nan')

            with tqdm(total=train_samples, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
                for batch in train_loader:
                    imgs = batch['image']
                    true_masks = batch['mask']
                    assert imgs.shape[1] == model.n_channels, \
                        f'Network has been defined with {model.n_channels} input channels, ' \
                        f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                        'the images are loaded correctly.'

                    imgs = imgs.to(device=device, dtype=torch.float32)
                    true_masks = true_masks.to(device=device, dtype=mask_type)

                    masks_pred = model(imgs)
                    loss = criterion(masks_pred, true_masks)

                    loss_total += loss.item()
                    batches_seen += 1
                    train_loss = loss_total / batches_seen
                    writer.add_scalar('Loss/train', train_loss, global_step)
                    pbar.set_postfix(**{'loss: ': train_loss})

                    optimizer.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_value_(model.parameters(), 0.1)
                    optimizer.step()

                    pbar.update(imgs.shape[0])
                    global_step += 1

                # Validation pass: evaluation mode and no autograd bookkeeping.
                # model.train() at the top of the next epoch restores training mode.
                model.eval()
                val_total = 0.0
                with torch.no_grad():
                    for val_batch in val_loader:
                        imgs, true_masks = val_batch['image'], val_batch['mask']
                        imgs = imgs.to(device=device, dtype=torch.float32)
                        true_masks = true_masks.to(device=device, dtype=mask_type)

                        # Score the predictions for THIS validation batch and
                        # accumulate a plain float rather than a tensor.
                        val_pred = model(imgs)
                        val_total += criterion(val_pred, true_masks).item()

                # The validation loader can be empty (val_split == 0, or fewer
                # held-out samples than batch_size with drop_last=True).
                if len(val_loader) > 0:
                    val_loss = val_total / len(val_loader)
                    writer.add_scalar('Loss/val', val_loss, global_step)
                    pbar.set_postfix(**{'loss: ': train_loss, 'val_loss: ': val_loss})
    finally:
        writer.close()

In [7]:
# Build the network for 3 input channels and n_classes=1.
unet = UNet(n_channels=3, n_classes=1)
# Disabled sanity check: torchsummary layer table for a (3, 480, 360) input.
# Note that it would move the model to CUDA via unet.cuda().
# summary(unet.cuda(), (3, 480, 360))

In [8]:
# Prefer the GPU when one is visible to PyTorch; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Device:", device)

# Move the model's parameters and buffers to the selected device.
unet = unet.to(device=device)

Device: cuda


In [10]:
# Kick off training with the default epochs, learning rate, split and batch size.
train_model(
    model=unet,
    device=device,
    img_dir=IMAGES_PATH,
    mask_dir=MASKS_PATH,
)