# Packages

In [2]:
import time, os, random
from datetime import datetime
import tqdm
from glob import glob
from collections import OrderedDict

import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms as T
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchsummary import summary

import albumentations as A
import segmentation_models_pytorch as smp
import urllib.request

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('torch ver.', torch.__version__)

torch ver. 1.11.0


# Parameters

In [10]:
IMAGE_PATH = '../input/original/' # Path to original image dir
MASK_PATH = '../input/mask/' # Path to mask image dir
PROXY_PIN = '../PIN.txt' # [userID]:[passward]@[proxy server adrress]:[port number]
N_CLASSES = 2 # Number of classes including background, i.e. N_CLASSES=2 for binary segmentation)
BATCH_SIZE = 1
SEED = 19

# Utilities

In [11]:
# Random seed
def seed_everything(seed=SEED):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
seed_everything()

# Proxy setting (to download encorder weigths of pretrained model) 
def set_proxy(pin=PROXY_PIN):
    with open(pin, 'r') as f:
        proxy_pass = f.read()
        proxies = {'http': 'http://' + proxy_pass, 'https': 'http://' + proxy_pass}
    proxy = urllib.request.ProxyHandler(proxies)
    opener = urllib.request.build_opener(proxy)
    urllib.request.install_opener(opener)
set_proxy()

# Data

In [12]:
names = []
for dirname in glob(IMAGE_PATH + '*.jpg'):
    filename = dirname.split('\\')[-1].split('.')[0]
    names.append(filename)
df = pd.DataFrame({'id': names}, index=np.arange(0, len(names)))
df

Unnamed: 0,id


In [13]:
# Split data
X_train_, X_test = train_test_split(df['id'].values, test_size=0.1, random_state=19)
X_train, X_val = train_test_split(X_train_, test_size=0.2, random_state=19)

print('Train Size   : ', len(X_train))
print('Val Size     : ', len(X_val))
print('Test Size    : ', len(X_test))

ValueError: With n_samples=0, test_size=0.1 and train_size=None, the resulting train set will be empty. Adjust any of the aforementioned parameters.

In [None]:
# Show sample data
samples = [0, 1, 2]
fig, ax = plt.subplots(1, len(samples), figsize=(20, 5))
for i in samples:
    img = cv2.imread(IMAGE_PATH + df['id'][i] + '.jpg')
    mask = cv2.imread(MASK_PATH + df['id'][i] + '_mask.png', cv2.IMREAD_GRAYSCALE)
    # print('Image Size', np.asarray(img).shape)
    # print('Mask Size', np.asarray(mask).shape)
    ax[i].imshow(img)
    ax[i].imshow(mask*(255/np.max(mask)), cmap='jet', alpha=0.3)
    ax[i].axis('off')
    ax[i].set_title(df['id'][i])
fig.suptitle('Original image with mas')
plt.show()

# Dataset

In [None]:
class ImgMaskDataset(Dataset):
    
    def __init__(self, img_path, mask_path, X, mean, std, transform=None, patch=False):
        self.img_path = img_path
        self.mask_path = mask_path
        self.X = X
        self.transform = transform
        self.patches = patch
        self.mean = mean
        self.std = std
        
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        img = cv2.imread(self.img_path + self.X[idx] + '.jpg')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.mask_path + self.X[idx] + '_mask.png', cv2.IMREAD_GRAYSCALE)
        
        if self.transform is not None:
            aug = self.transform(image=img, mask=mask)
            img = Image.fromarray(aug['image'])
            mask = aug['mask']
        
        if self.transform is None:
            img = Image.fromarray(img)
        
        t = T.Compose([T.ToTensor(), T.Normalize(self.mean, self.std)])
        img = t(img)
        mask = torch.from_numpy(mask).long()
        
        if self.patches:
            img, mask = self.tiles(img, mask)
            
        return img, mask
    
    def tiles(self, img, mask):
        img_patches = img.unfold(1, 512, 512).unfold(2, 768, 768) 
        img_patches  = img_patches.contiguous().view(3,-1, 512, 768)
        img_patches = img_patches.permute(1,0,2,3)
        mask_patches = mask.unfold(0, 512, 512).unfold(1, 768, 768)
        mask_patches = mask_patches.contiguous().view(-1, 512, 768)
        return img_patches, mask_patches
        

In [None]:
mean =[0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Transformer
t_train = A.Compose([
    A.Resize(704, 1056, interpolation=cv2.INTER_NEAREST),
    A.HorizontalFlip(),
    A.VerticalFlip(),
    A.GridDistortion(p=0.2),
    A.RandomBrightnessContrast((0, 0.5),(0, 0.5)),
    A.GaussNoise()
    ])

t_val = A.Compose([
    A.Resize(704, 1056, interpolation=cv2.INTER_NEAREST),
    A.HorizontalFlip(),
    A.GridDistortion(p=0.2)
    ])

# Dataset
train_set = ImgMaskDataset(IMAGE_PATH, MASK_PATH, X_train, mean, std, t_train, patch=False)
val_set = ImgMaskDataset(IMAGE_PATH, MASK_PATH, X_val, mean, std, t_val, patch=False)

# Dataloader
train_loader = DataLoader(train_set, BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, BATCH_SIZE, shuffle=True)

# Model

In [None]:
model = smp.Unet('efficientnet-b2', encoder_weights='imagenet', classes=N_CLASSES, activation=None, encoder_depth=5, decoder_channels=[256, 128, 64, 32, 16])
# model = smp.Unet('resnet34', encoder_weights='imagenet', classes=N_CLASSES, activation=None, encoder_depth=5, decoder_channels=[256, 128, 64, 32, 16])
# model = smp.Unet('mobilenet_v2', encoder_weights=None, classes=N_CLASSES, activation=None, encoder_depth=5, decoder_channels=[256, 128, 64, 32, 16])

In [None]:
print('Total params: ', sum(p.numel() for p in model.parameters()))
print('Trainable params:', sum(p.numel() for p in model.parameters() if p.requires_grad))

# Train

In [None]:
# Metrics
def pixel_accuracy(output, mask):
    with torch.no_grad():
        output = torch.argmax(F.softmax(output, dim=1), dim=1)
        correct = torch.eq(output, mask).int()
        accuracy = float(correct.sum()) / float(correct.numel())
    return accuracy

# IoU averaged over classes 
def meanIoU(pred_mask, mask, smooth=1e-10, n_classes=N_CLASSES):
    with torch.no_grad():
        pred_mask = F.softmax(pred_mask, dim=1)
        pred_mask = torch.argmax(pred_mask, dim=1)
        pred_mask = pred_mask.contiguous().view(-1)
        mask = mask.contiguous().view(-1)
        iou_per_class = []
        for c in range(n_classes):
            pred_label = (pred_mask==c)
            true_label = (mask==c)

            if true_label.long().sum().item() == 0: # no exists label
                iou_per_class.append(np.nan)
            else:
                intersect = torch.logical_and(pred_label, true_label).sum().float().item()
                union = torch.logical_or(pred_label, true_label).sum().float().item()
                iou = (intersect + smooth) / (union + smooth)
                iou_per_class.append(iou)
        
        return np.nanmean(iou_per_class)

# Dice coeffient average over classes 
def meanDice(pred_mask, mask, smooth=1e-10, n_classes=N_CLASSES):
    with torch.no_grad():
        pred_mask = F.softmax(pred_mask, dim=1)
        pred_mask = torch.argmax(pred_mask, dim=1)
        pred_mask = pred_mask.contiguous().view(-1)
        mask = mask.contiguous().view(-1)
        dice_per_class = []
        for c in range(n_classes):
            pred_label = (pred_mask==c)
            true_label = (mask==c)

            if true_label.long().sum().item() == 0: # no exists label
                dice_per_class.append(np.nan)
            else:
                intersect = torch.logical_and(pred_label, true_label).sum().float().item()
                left = torch.sum(pred_label)
                right = torch.sum(true_label)
                dice = (2. * intersect + smooth) / (left + right + smooth)
                dice_per_class.append(dice)

        return np.nanmean(dice_per_class)


In [None]:
# Learning rate
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

# Train
def fit(epochs, model, train_loader, val_loader, criterion, optimizer, scheduler, patch=False):
    run_date = str(datetime.now().strftime("%Y%m%d-%H%M%S"))                 
    model_path = f'../model/model_{run_date}.pt' # trained model saved in
    torch.cuda.empty_cache() # memory freeing
    train_losses, val_losses  = [], []
    train_iou, val_iou = [], []
    train_dice, val_dice = [], []
    train_acc, val_acc = [], []
    lrs = []
    min_loss = np.inf
    decrease = 0
    not_improve = 0
    model.to(device)
    
    # Training
    fit_start = time.time()
    # # Loop per epochs
    for e in range(epochs):
        model.train()
        epoch_start = time.time()
        train_loss = 0
        iou_score = 0
        dice_score = 0
        accuracy = 0
        # # Loop per batch
        with tqdm(train_loader) as pbar:
            for i, data in enumerate(pbar):
                pbar.set_description('[Epoch {:d} train]'.format(e + 1))
                # Load input
                image_tiles, mask_tiles = data
                if patch:
                    b, n_tiles, c, h, w = image_tiles.size()
                    image_tiles = image_tiles.view(-1, c, h, w)
                    mask_tiles = mask_tiles.view(-1, h, w)
                image = image_tiles.to(device)
                mask = mask_tiles.to(device)
                # Forward
                output = model(image)
                loss = criterion(output, mask)
                accuracy += pixel_accuracy(output, mask) 
                iou_score += meanIoU(output, mask)
                dice_score += meanDice(output, mask)
                # Backward
                loss.backward() 
                optimizer.step() # update weigth
                optimizer.zero_grad() # reset gradient
                lrs.append(get_lr(optimizer))
                scheduler.step()
                train_loss += loss.item()
                pbar.set_postfix(OrderedDict(dice=dice_score/(i+1)))
        
        # Validation
        model.eval()
        val_loss = 0
        val_accuracy = 0
        val_iou_score = 0
        val_dice_score = 0
        with torch.no_grad():
            with tqdm(val_loader) as pbar:
                for i, data in enumerate(pbar):
                    pbar.set_description('[Epoch {:d} valid]'.format(e + 1))
                    # Load input
                    image_tiles, mask_tiles = data
                    if patch:
                        b, n_tiles, c, h, w = image_tiles.size()
                        image_tiles = image_tiles.view(-1, c, h, w)
                        mask_tiles = mask_tiles.view(-1, h, w)
                    image = image_tiles.to(device)
                    mask = mask_tiles.to(device)
                    # Forward
                    output = model(image)
                    loss = criterion(output, mask)
                    val_loss += loss.item()
                    val_iou_score += meanIoU(output, mask)
                    val_dice_score += meanDice(output, mask)
                    val_accuracy += pixel_accuracy(output, mask)
                    pbar.set_postfix(OrderedDict(dice=val_dice_score/(i+1)))
        
        # Metrics averaged in batch
        train_losses.append(train_loss/len(train_loader))
        val_losses.append(val_loss/len(val_loader))
        
        # Save model if loss updated
        if min_loss > (val_loss / len(val_loader)):
            # print('Loss decreasing...{:.3f} >> {:.3f}'.format(min_loss, (val_loss/len(val_loader))))
            min_loss = val_loss / len(val_loader)
            best_dice = val_dice_score / len(val_loader)
            decrease += 1
            not_improve = 0
            if decrease >= 3:
                # print('Saving model...')
                torch.save(model, model_path)
                
        # Early stopping if loss not updated 3 times in succession
        else:
            not_improve += 1
            print(f'Loss Not decrease for {not_improve} time')
            if not_improve == 3:
                print('Stop training since loss is not decreased for 3 times in succession')
                break
            
        # Score
        train_iou.append(iou_score / len(train_loader))
        val_iou.append(val_iou_score / len(val_loader))
        train_dice.append(dice_score / len(train_loader))
        val_dice.append(val_dice_score / len(val_loader))
        train_acc.append(accuracy / len(train_loader))
        val_acc.append(val_accuracy / len(val_loader))
        print('Epoch:{}/{} |'.format(e+1, epochs),
        'Train Loss: {:.3f} |'.format(train_loss/len(train_loader)),
        'Val Loss: {:.3f} |'.format(val_loss/len(val_loader)),
        'Train Dice: {:.3f} |'.format(dice_score/len(train_loader)),
        'Val Dice: {:.3f} |'.format(val_dice_score/len(val_loader)),
        'Train Acc: {:.3f} |'.format(accuracy/len(train_loader)),
        'Val Acc: {:.3f} |'.format(val_accuracy/len(val_loader)),
#        'Train mIoU: {:.3f} |'.format(iou_score/len(train_loader)),
#        'Val mIoU: {:.3f} |'.format(val_iou_score/len(val_loader)),
        'Time: {:.2f} min.'.format((time.time()-epoch_start)/60))
    
    history = {'train_loss': train_losses, 'val_loss': val_losses,
               'train_miou': train_iou, 'val_miou': val_iou,
               'train_mdice': train_dice, 'val_mdice': val_dice,
               'train_acc' : train_acc, 'val_acc': val_acc,
               'lrs': lrs}
    print('Total time: {:.2f} min.' .format((time.time() - fit_start)/60))
    best_model_path = model_path.split('.pt')[0] + '_dice-{:.3f}'.format(best_dice) + '.pt'
    os.rename(model_path, best_model_path)
    if decrease<3:
        print('Not model saved since training was not proceeding.')
    else:
        print(f'Model saved in {best_model_path}'

    return history, best_model_path


In [None]:
max_lr = 1e-3
epochs = 20
weight_decay = 1e-4
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=epochs, steps_per_epoch=len(train_loader))
history, model_path = fit(epochs, model, train_loader, val_loader, criterion, optimizer, scheduler)

# Result

In [None]:
def plot_loss(history):
    plt.plot( history['train_loss'], label='train', marker='o')
    plt.plot(history['val_loss'], label='val', marker='o')
    plt.title('Loss per epoch'); plt.ylabel('loss');
    plt.xlabel('epoch')
    plt.legend(), plt.grid()
    plt.show()


def plot_iou(history):
    plt.plot(history['train_miou'], label='train_mIoU', marker='*')
    plt.plot(history['val_miou'], label='val_mIoU',  marker='*')
    plt.title('Score per epoch'); plt.ylabel('mean IoU')
    plt.xlabel('epoch')
    plt.legend(), plt.grid()
    plt.show()

def plot_dice(history):
    plt.plot(history['train_mdice'], label='train_mdice', marker='*')
    plt.plot(history['val_mdice'], label='val_mdice',  marker='*')
    plt.title('Score per epoch'); plt.ylabel('mean dice')
    plt.xlabel('epoch')
    plt.legend(), plt.grid()
    plt.show()

def plot_acc(history):
    plt.plot(history['train_acc'], label='train_accuracy', marker='*')
    plt.plot(history['val_acc'], label='val_accuracy',  marker='*')
    plt.title('Accuracy per epoch'); plt.ylabel('Accuracy')
    plt.xlabel('epoch')
    plt.legend(), plt.grid()
    plt.show()

plot_loss(history)
plot_dice(history)
plot_acc(history)

# Evaluation

In [None]:
# Test dataset
class TestDataset(Dataset):

    def __init__(self, img_path, mask_path, X, transform=None):
        self.img_path = img_path
        self.mask_path = mask_path
        self.X = X
        self.transform = transform
    
    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        img = cv2.imread(self.img_path + self.X[idx] + '.jpg')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.mask_path + self.X[idx] + '_mask.png', cv2.IMREAD_GRAYSCALE)

        if self.transform is not None:
            aug = self.transform(image=img, mask=mask)
            img = Image.fromarray(aug['image'])
            mask = aug['mask']

        if self.transform is None:
            img = Image.fromarray(img)
        
        mask = torch.from_numpy(mask).long()

        return img, mask

In [None]:
# Score
def predict_dice(model, image, mask, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
    model.eval()
    model.to(device)
    t = T.Compose([T.ToTensor(), T.Normalize(mean, std)])
    image = t(image)
    image = image.to(device)
    mask = mask.to(device)
    with torch.no_grad():
        image = image.unsqueeze(0)
        mask = mask.unsqueeze(0)
        output = model(image)
        score = meanDice(output, mask)
        masked = torch.argmax(output, dim=1)
        masked = masked.cpu().squeeze(0)
    return masked, score

def predict_acc(model, image, mask, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
    model.eval()
    model.to(device)
    t = T.Compose([T.ToTensor(), T.Normalize(mean, std)])
    image = t(image)
    image=image.to(device)
    mask = mask.to(device)
    with torch.no_grad():
        image = image.unsqueeze(0)
        mask = mask.unsqueeze(0)
        output = model(image)
        acc = pixel_accuracy(output, mask)
        masked = torch.argmax(output, dim=1)
        masked = masked.cpu().squeeze(0)
    return masked, acc

In [None]:
idx = 0
# Load test data
t_test = A.Resize(768, 1152, interpolation=cv2.INTER_NEAREST)
test_set = TestDataset(IMAGE_PATH, MASK_PATH, X_test, transform=t_test)
image, mask = test_set[idx]

# Load model
model = torch.load(model_path)
# predict
pred_mask, score = predict_dice(model, image, mask)

In [None]:
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1,4, figsize=(20,10))
ax1.imshow(image)
ax1.set_title(f'Original: {X_test[idx]}')

ax2.imshow(image)
ax2.imshow(mask, alpha=0.3, cmap='jet')
ax2.set_title('Origina with Ground truth')
ax2.set_axis_off()

ax3.imshow(mask, cmap='jet')
ax3.set_title('Ground truth')
ax3.set_axis_off()

ax4.imshow(pred_mask, cmap='jet')
ax4.set_title('Predict (Dice coff: {:.3f})'.format(score))
ax4.set_axis_off()

In [None]:
def dice_score(model, test_set):
    score_dice = []
    for i in tqdm(range(len(test_set))):
        img, mask = test_set[i]
        pred_mask, score = predict_dice(model, img, mask)
        score_dice.append(score)
    return score_dice

def acc_score(model, test_set):
    accuracy = []
    for i in tqdm(range(len(test_set))):
        img, mask = test_set[i]
        pred_mask, acc = predict_acc(model, img, mask)
        accuracy.append(acc)
    return accuracy

## Average score for all test set

In [None]:
dice = dice_score(model, test_set)
print('Test Set mean dice {:.3f}'.format(np.mean(dice)))
acc = acc_score(model, test_set)
print('Test set mean accuracy {:.3f}'.format(np.mean(acc)))