## Loading data

In [None]:
# Uncomment to install. Pin the segmentation_models_pytorch version the models were trained with:
# the checkpoints loaded below are whole pickled models and only load under that version.
#!pip3 install segmentation_models_pytorch

In [None]:
import os
# Expose only GPU 0 to this process; must be set before any library initializes CUDA to take effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp

from os import path
from matplotlib import gridspec

In [None]:
# Show the installed smp version: the pickled models can only be loaded under the version they were created with.
print(smp.__version__)

In [None]:
# Input images are read from ROOT_DIR/<DATASET_NAME>/<DATASET_PART>/images/
ROOT_DIR = 'root_dir'
DATASET_NAME = 'baseline'
DATASET_PART = 'test'

# images, labels (optional)
DIR_IMAGES = path.join(ROOT_DIR, DATASET_NAME, DATASET_PART, 'images/')
DIR_LABELS = None  # ground-truth mask directory (files named like the images), or None when unavailable

print('in:', DIR_IMAGES)

In [None]:
# Keyword-based visualization: visualize(title_one=image1, title_two=image2, ...)
def visualize(**images):
    """Plot all given images side by side in a single row.

    Each keyword becomes its subplot's title, with underscores replaced by
    spaces and the result title-cased (``predicted_mask`` -> "Predicted Mask").
    """
    count = len(images)
    plt.figure(figsize=(16, 10))
    for position, (key, img) in enumerate(images.items(), start=1):
        plt.subplot(1, count, position)
        plt.title(key.replace('_', ' ').title())
        plt.imshow(img)
    plt.show()

# Positional visualization: visualize_grid(image1, image2, ...), `cols` images per row
def visualize_grid(*images, cols=4):
    """Show the given images in a grid without axes.

    Args:
        *images: images accepted by ``imshow``, laid out row by row.
        cols: number of images per row (keyword-only, default 4).

    Nothing is drawn when no images are given; ``GridSpec`` cannot be built
    with zero rows.
    """
    n = len(images)
    if n == 0:
        return
    rows = int(np.ceil(n / cols))
    gs = gridspec.GridSpec(rows, cols)
    fig = plt.figure(figsize=(16, 4 * rows))
    for i, image in enumerate(images):
        ax = fig.add_subplot(gs[i])
        ax.imshow(image)
        ax.axis('off')
    # tight_layout only arranges axes that already exist, so it must run after they are added
    fig.tight_layout()
    plt.show()

## Dataloader

In [None]:
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

In [None]:
class Dataset(BaseDataset):
    """Image dataset with optional per-class binary masks.

    Args:
        images_dir: directory whose files are the input images (one sample per file).
        masks_dir: directory with mask files named exactly like the images, or
            None when no ground truth is available (``__getitem__`` then returns
            ``mask=None``).
        classes: names from ``CLASSES`` to extract from the masks; case-insensitive.
            Defaults to all of ``CLASSES``.
        augmentation, preprocessing: optional albumentations-style callables,
            called with keyword ``image`` (plus ``mask`` when masks exist) and
            returning a dict with the same keys. A falsy transform is skipped.
    """

    # The position of a name in this list is the pixel value that marks it in the mask files.
    CLASSES = ['background', 'wound']

    def __init__(self, images_dir, masks_dir, classes=None, augmentation=None, preprocessing=None):
        # sorted(): os.listdir order is arbitrary; a fixed order keeps sample indices reproducible
        self.ids = sorted(os.listdir(images_dir))
        self.images_fps = [path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = None if masks_dir is None else [path.join(masks_dir, image_id) for image_id in self.ids]

        # convert str names to class values on masks
        if classes is None:
            classes = self.CLASSES
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        """Return ``(image, mask)`` for sample ``i``; ``mask`` is None without masks_dir.

        Raises:
            FileNotFoundError: if the image or mask file cannot be read by OpenCV.
        """
        image = cv.imread(self.images_fps[i])
        if image is None:  # cv.imread signals failure by returning None, not by raising
            raise FileNotFoundError(f'could not read image: {self.images_fps[i]}')
        image = cv.cvtColor(image, cv.COLOR_BGR2RGB)

        mask = None
        if self.masks_fps is not None:
            raw_mask = cv.imread(self.masks_fps[i], 0)  # 0 = read as single-channel grayscale
            if raw_mask is None:
                raise FileNotFoundError(f'could not read mask: {self.masks_fps[i]}')
            # one binary channel per requested class value, stacked on the last axis
            mask = np.stack([(raw_mask == v) for v in self.class_values], axis=-1).astype('float')

        image, mask = self._apply(self.augmentation, image, mask)
        image, mask = self._apply(self.preprocessing, image, mask)
        return image, mask

    @staticmethod
    def _apply(transform, image, mask):
        """Run an albumentations-style transform on the image and, if present, the mask."""
        if not transform:
            return image, mask
        if mask is None:
            # only pass `mask` when there is one; `mask == None` on an array is not a usable test
            return transform(image=image)['image'], None
        sample = transform(image=image, mask=mask)
        return sample['image'], sample['mask']

    def __len__(self):
        return len(self.ids)

In [None]:
# Inspect one raw (untransformed) sample
pred_dataset = Dataset(DIR_IMAGES, DIR_LABELS, classes=['wound'])
image, mask = pred_dataset[0]

if mask is not None:
    visualize(image=image, wound_mask=mask.squeeze())
else:
    print('No mask for image available')
    visualize(image=image)

# Record the input size; height and width must both be multiples of 32
IMAGE_HEIGHT, IMAGE_WIDTH, _ = image.shape
assert IMAGE_HEIGHT % 32 == 0 and IMAGE_WIDTH % 32 == 0, 'image height/width must be divisible by 32'

## Predict and visualize

In [None]:
import torch
import numpy as np
import os
import albumentations as albu
import segmentation_models_pytorch as smp
from PIL import Image

In [None]:
def get_validation_augmentation():
    """Return the inference-time augmentation pipeline, which is deliberately empty.

    Images are used at their native size (required to be divisible by 32).
    A padding transform such as ``albu.PadIfNeeded`` could be added here if
    inputs are not already sized as required.
    """
    no_transforms = []
    return albu.Compose(no_transforms)


def to_tensor(x, **kwargs):
    """Reorder an HWC array to CHW and cast it to float32.

    Extra keyword arguments are accepted and ignored so the function can be
    used as an ``albu.Lambda`` callback.
    """
    chw = np.transpose(x, (2, 0, 1))
    return chw.astype(np.float32)


def get_preprocessing(preprocessing_fn):
    """Build the inference preprocessing pipeline for images.

    Args:
        preprocessing_fn: the encoder-specific preprocessing callable (in this
            notebook obtained from ``smp.encoders.get_preprocessing_fn``).

    Returns:
        An ``albu.Compose`` that first applies ``preprocessing_fn`` and then
        ``to_tensor`` (HWC -> CHW float32). Only image transforms are registered.
    """
    steps = [albu.Lambda(image=preprocessing_fn), albu.Lambda(image=to_tensor)]
    return albu.Compose(steps)

In [None]:
# Root of the trained runs: one sub-directory per run name, holding its checkpoint files
DIR_MODELS = 'models_dir'
# List the runs whose names contain 'tt' (IPython shell escape; interpolates DIR_MODELS)
!ls $DIR_MODELS | grep tt

In [None]:
# Training runs whose checkpoints form the ensemble. The runs of one ensemble
# differ only in their "sameconf1" ... "sameconf5" suffix.
#
# Other ensembles (substitute the template below to use one of them):
#   id15, currently best performing base model ensemble:
#     'training___baseline___fpn_se_resnext101_32x4d_imagenet_sigmoid___adam_lr1e-04_lrd15e-05_lrd21e-05___medium___bs24_e150_ed1100_ed2135___sameconf{k}'
#   id22 (test id1):
#     'training__baseline+tsynq95id15_4k__fpn_se_resnext101_32x4d_imagenet_sigmoid__adam_lr1e-04_lrd11e-05__medium__bs24_e150_ed1120__sameconf{k}'

# id30 (test id2)
MODEL_NAMES = [
    f'training___baseline+ttsynq95id15_2k___fpn_se_resnext101_32x4d_imagenet_sigmoid___adam_lr1e-04_lrd11e-05___medium___bs24_e150_ed1120___sameconf{k}'
    for k in range(1, 6)
]

In [None]:
ENCODER = 'se_resnext101_32x4d'
WEIGHTS = 'imagenet'
CLASSES = ['wound']
DEVICE = 'cuda'
# VISUALIZE is set together with the other inference switches in the prediction cell

# Base output directory: output_dir/<dataset>/<part>/<first run name>__5x5_ensemble.
# The prediction cell appends the inference parameters (threshold, TTA) to this name.
DIR_PREDS = path.join(
    'output_dir',
    DATASET_NAME, DATASET_PART, MODEL_NAMES[0] + '__5x5_ensemble'
)
print('out:', DIR_PREDS)

In [None]:
# Ensemble members: for every run in MODEL_NAMES, its checkpoints
# last_model-0.pth ... last_model-{epochs-1}.pth ("polyak last model ensemble").
#
# IMPORTANT: the checkpoints are whole pickled models, so the segmentation_models_pytorch
# version they were created with must be installed to load them via torch.load.
# Unpickling can execute arbitrary code - only load checkpoints from trusted sources.
# NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True, which rejects
# whole-model pickles; on such versions weights_only=False must be passed explicitly.


def load_model(run_name, checkpoint):
    """Load checkpoint file `checkpoint` of training run `run_name` onto DEVICE in eval mode.

    map_location=DEVICE remaps tensors saved on another device (e.g. a GPU index that
    is not visible under CUDA_VISIBLE_DEVICES) onto the device used for inference.
    """
    model = torch.load(path.join(DIR_MODELS, run_name, checkpoint), map_location=DEVICE)
    return model.eval()


epochs = 5  # checkpoints loaded per run
models = [load_model(mn, f'last_model-{e}.pth') for mn in MODEL_NAMES for e in range(epochs)]

# Alternative: regular last-model ensemble (one checkpoint per run)
#models = [load_model(mn, 'last_model.pth') for mn in MODEL_NAMES]

# Alternative: hand-composed ensemble, e.g.
#models = [
#    load_model(MODEL_NAMES[0], 'best_model_dice.pth'),
#    load_model(MODEL_NAMES[1], 'best_model_dice.pth'),
#    load_model(MODEL_NAMES[2], 'best_model_dice_fold3.pth'),
#    load_model(MODEL_NAMES[3], 'best_model_dice_fold4.pth'),
#    load_model(MODEL_NAMES[4], 'best_model_dice_fold5.pth'),
#]

In [None]:
# smp's preprocessing function for the chosen encoder and pretrained weights
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, WEIGHTS)

# Untransformed dataset, used only to display the original images
pred_dataset_vis = Dataset(images_dir=DIR_IMAGES, masks_dir=DIR_LABELS, classes=CLASSES)

# Dataset fed to the models: augmentation + preprocessing (CHW float32 output)
pred_dataset = Dataset(
    images_dir=DIR_IMAGES,
    masks_dir=DIR_LABELS,
    classes=CLASSES,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)

In [None]:
SAMPLES   = range(0, len(pred_dataset))
#SAMPLES   = range(0, 10)
VISUALIZE = False
WRITE     = True
THRESHOLD = 0.5
TTA       = False

# Output directory = DIR_PREDS + inference parameters (threshold in percent, TTA flag).
# Computed unconditionally so the write step below can always use it, and created with
# exist_ok=True so re-running the cell does not fail on an already existing directory.
# round() rather than int(): int(0.29 * 100) == 28 because of floating-point error.
out = DIR_PREDS + f'__t{round(THRESHOLD * 100)}_tta{1 if TTA else 0}'
if WRITE:
    print('out:', out)
    os.makedirs(out, exist_ok=True)


def predict_probability(image_tensor):
    """Return the ensemble-averaged probability map for one preprocessed image.

    image_tensor: a single CHW image with a batch axis added, i.e. (1, C, H, W), on DEVICE.
    Every model in `models` is applied via `predict`. With TTA enabled, each model is
    also run on the image flipped along H, W and both, and its output is flipped back
    before averaging. Returns a CPU tensor with leading singleton axes squeezed away.
    """
    flip_dims = [(2,), (3,), (2, 3)] if TTA else []
    with torch.no_grad():  # inference only, no gradients needed
        probs = [m.predict(image_tensor).squeeze(0).cpu() for m in models]
        for flip_dim in flip_dims:
            flipped = torch.flip(image_tensor, flip_dim)
            probs.extend(torch.flip(m.predict(flipped), flip_dim).squeeze(0).cpu() for m in models)
    return torch.stack(probs, 0).mean(0).squeeze(0)


for i in SAMPLES:
    image_name = pred_dataset.ids[i]
    image, gt_mask = pred_dataset[i]
    image_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)

    pr_prob = predict_probability(image_tensor)

    # threshold-based prob map cut
    pr_mask = (pr_prob > THRESHOLD).numpy().astype(np.uint8)  # 0 = background, 1 = wound

    if VISUALIZE:
        # the untransformed image is only needed for display, so load it only here
        image_vis = pred_dataset_vis[i][0].astype('uint8')
        if gt_mask is None:
            visualize(image=image_vis, probability_map=pr_prob, predicted_mask=pr_mask)
        else:
            gt_mask = gt_mask.squeeze()
            visualize(image=image_vis, ground_truth_mask=gt_mask, predicted_mask=pr_mask)

    if WRITE:
        # NOTE(review): the mask reuses the input file name and extension; a lossy format
        # such as .jpg would corrupt the 0/1 values - confirm inputs are lossless (e.g. .png).
        image_path = path.join(out, image_name)
        im = Image.fromarray(pr_mask).convert("L")  # grayscale
        im.save(image_path)
        print('Saved:', image_path)
    