In [None]:
%matplotlib inline

#### Requirements

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

import argparse
import logging
import os
import sys

# Change this path to wherever you installed the Pytorch-UNet module
sys.path.append('C:/Users/Groh/Documents/GitHub/Pytorch-UNet')
os.environ['KMP_DUPLICATE_LIB_OK']='True'

import numpy as np
import torch
import torch.nn as nn
from torch import optim
from tqdm import tqdm

import cv2
from PIL import Image

from eval import eval_net
from unet import UNet

from torch.utils.tensorboard import SummaryWriter
from utils.dataset import BasicDataset
from torch.utils.data import DataLoader, random_split

#### Assign absolute paths of image and mask folders

In [None]:
# Change this root to wherever the Pytorch-UNet repository (including its data) lives.
# The image, mask and checkpoint folders are all derived from it, so there is a
# single place to edit when the notebook is moved to another machine.
unet_root = 'C:/Users/Groh/Documents/GitHub/Pytorch-UNet'

# The trailing slashes are intentional: other cells build file names by plain
# string concatenation (e.g. dir_mask + name).
dir_img = unet_root + '/data/images/'
dir_mask = unet_root + '/data/masks/'
dir_checkpoint = unet_root + '/checkpoints/'

#### Prepare masks for multi instance segmentation

In [None]:
# Gray value found in the annotated masks -> categorical class index.
# Background stays 0. Add further instruments here as they get annotated.
gray_to_class = {
    52: 1,   # Atraum. Pinzette
    113: 2,  # Nadelhalter
}

# Lookup table over all 8-bit gray values: identity everywhere except for the
# annotated values above. Being the identity on the class indices themselves makes
# this cell safe to re-run on masks that were already converted.
class_lut = np.arange(256, dtype=np.uint8)
for gray_value, class_index in gray_to_class.items():
    class_lut[gray_value] = class_index

# List the image folder exactly once and keep regular files only, so that
# sub-directories can neither shift the indexing nor be opened as images.
image_names = sorted(f for f in os.listdir(dir_img) if os.path.isfile(os.path.join(dir_img, f)))
num_files = len(image_names)

for name in image_names:
    file = os.path.join(dir_mask, name)

    # If a mask is missing, create a new, empty (all-background) mask
    if not os.path.isfile(file):
        with Image.open(os.path.join(dir_img, name)) as img:
            width, height = img.size

        img_new = Image.new('L', (width, height))
        img_new.save(file, "PNG")

    # Load the mask as an 8-bit gray-scale array; the file handle is released
    # before the same path is written again below
    with Image.open(file) as mask_img:
        mask = np.array(mask_img.convert('L'))

    # Replace annotated gray values by their class index and write the result
    # back to disk (previously the remapped pixels were never saved)
    mask = class_lut[mask]
    Image.fromarray(mask).save(file)

    ### Only for binary segmentation
    #img_grayscale = Image.open(file)
    #thresh = 10
    #fn = lambda x : 255 if x > thresh else 0
    #r = img_grayscale.convert('L').point(fn, mode='1')
    #r.save(file)
    

#### Train network

In [None]:
def train_net(net,
              device,
              epochs=5,
              batch_size=1,
              lr=0.001,
              val_percent=0.1,
              save_cp=True,
              img_scale=0.5):
    """Train ``net`` on the image/mask pairs found in ``dir_img`` and ``dir_mask``.

    The dataset is split randomly into a training and a validation part. The
    network is optimised with RMSprop, the learning rate is reduced by
    ``ReduceLROnPlateau`` based on the validation score, and losses, weights,
    gradients and example images are logged to TensorBoard.

    Args:
        net: Network to train. Must expose ``n_channels`` and ``n_classes``.
        device: ``torch.device`` on which batches are placed.
        epochs: Number of passes over the training split.
        batch_size: Batch size for both data loaders.
        lr: Initial learning rate.
        val_percent: Fraction of the dataset used for validation.
        save_cp: If true, save the state dict to ``dir_checkpoint`` after every epoch.
        img_scale: Scaling factor forwarded to ``BasicDataset``.

    Returns:
        None. The network is updated in place.
    """
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True, drop_last=True)

    # Validate roughly ten times per epoch. Clamp to at least 1: for small datasets
    # (n_train < 10 * batch_size) the integer division yields 0 and the modulo in
    # the training loop would raise ZeroDivisionError.
    eval_interval = max(1, n_train // (10 * batch_size))

    # Create the checkpoint folder up front so that a path/permission problem shows
    # up before any training time is spent instead of being silently swallowed.
    if save_cp and not os.path.isdir(dir_checkpoint):
        os.makedirs(dir_checkpoint, exist_ok=True)
        logging.info('Created checkpoint directory')

    writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    global_step = 0

    # try/finally guarantees the TensorBoard writer is closed even when
    # training is aborted (e.g. by a KeyboardInterrupt from the notebook).
    try:
        logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

        optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
        # Multi-class: the validation score is a loss (lower is better).
        # Single class: it is a Dice coefficient (higher is better).
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)

        # Use (categorical) cross entropy as loss function, if more than one class is used
        if net.n_classes > 1:
            criterion = nn.CrossEntropyLoss()
        # Otherwise use sigmoid + binary cross entropy, fused in a single loss
        else:
            criterion = nn.BCEWithLogitsLoss()

        for epoch in range(epochs):
            net.train()

            epoch_loss = 0
            with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
                for batch in train_loader:
                    imgs = batch['image']
                    true_masks = batch['mask']
                    assert imgs.shape[1] == net.n_channels, \
                        f'Network has been defined with {net.n_channels} input channels, ' \
                        f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                        'the images are loaded correctly.'

                    imgs = imgs.to(device=device, dtype=torch.float32)
                    # BCEWithLogitsLoss needs float targets, CrossEntropyLoss class indices
                    mask_type = torch.float32 if net.n_classes == 1 else torch.long
                    true_masks = true_masks.to(device=device, dtype=mask_type)

                    masks_pred = net(imgs)
                    loss = criterion(masks_pred, true_masks)
                    epoch_loss += loss.item()
                    writer.add_scalar('Loss/train', loss.item(), global_step)

                    pbar.set_postfix(**{'loss (batch)': loss.item()})

                    optimizer.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_value_(net.parameters(), 0.1)
                    optimizer.step()

                    pbar.update(imgs.shape[0])
                    global_step += 1
                    if global_step % eval_interval == 0:
                        for tag, value in net.named_parameters():
                            tag = tag.replace('.', '/')
                            writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
                            # Parameters that did not receive a gradient (e.g. frozen
                            # layers) have grad == None and are skipped.
                            if value.grad is not None:
                                writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)
                        val_score = eval_net(net, val_loader, device)
                        # NOTE(review): eval_net is defined outside this notebook and may
                        # leave the network in eval mode; switch back defensively
                        # (a no-op if it already restored training mode).
                        net.train()
                        scheduler.step(val_score)
                        writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)

                        if net.n_classes > 1:
                            logging.info('Validation cross entropy: {}'.format(val_score))
                            writer.add_scalar('Loss/test', val_score, global_step)
                        else:
                            logging.info('Validation Dice Coeff: {}'.format(val_score))
                            writer.add_scalar('Dice/test', val_score, global_step)

                        writer.add_images('images', imgs, global_step)
                        if net.n_classes == 1:
                            writer.add_images('masks/true', true_masks, global_step)
                            writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)

            # Save checkpoint after every epoch
            if save_cp:
                torch.save(net.state_dict(),
                           os.path.join(dir_checkpoint, f'CP_epoch{epoch + 1}.pth'))
                logging.info(f'Checkpoint {epoch + 1} saved !')
    finally:
        writer.close()


#### Configure training parameters and device

In [None]:
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

# Train on the GPU whenever torch reports CUDA as available, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
#device = torch.device('cpu')
logging.info(f'Using device {device}')

# Adapt the network to your data here:
#   n_channels - 3 for RGB images, fewer than 3 for grayscale images
#   n_classes  - number of probabilities predicted per pixel
#       * one class plus background -> n_classes=1
#       * two classes               -> n_classes=1
#       * N > 2 classes             -> n_classes=N
# NOTE(review): the mask-preparation cell writes class indices 1 and 2 into the
# masks; together with the background that looks like three classes - confirm
# that n_classes=1 is really intended for this data.
net = UNet(n_channels=3, n_classes=1, bilinear=True)
logging.info('Network:\n'
             '\t{} input channels\n'
             '\t{} output channels (classes)\n'
             '\t{} upscaling'.format(net.n_channels,
                                     net.n_classes,
                                     "Bilinear" if net.bilinear else "Transposed conv"))
net.to(device=device)

#### Call the training function

In [None]:
try:
    train_net(net=net,
              device=device)
except KeyboardInterrupt:
    # Persist the current weights so that an interrupted run is not lost
    torch.save(net.state_dict(), 'INTERRUPTED.pth')
    logging.info('Saved interrupt')
    # Deliberately no sys.exit()/os._exit() here: inside a notebook os._exit(0)
    # terminates the kernel process and discards every variable (including the
    # trained net). Swallowing the interrupt simply ends this cell.