In [None]:
import os
import random
import csv
import torch
import numpy as np
import cv2
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
from torch.utils.data import DataLoader
import torch.nn as nn
from collections import defaultdict
import torchvision
import albumentations as albu
from albumentations.pytorch.transforms import ToTensor
import torch.nn.functional as F
from sklearn.model_selection import StratifiedKFold
from torch.utils.data.sampler import Sampler
from Losses import dice_metric
from PIL import Image
# tqdm.auto selects the notebook widget bar under Jupyter and the plain text
# bar elsewhere. The previous hand-rolled check misspelled
# `__class__`/`__name__`, so it always raised and the bare `except` silently
# fell back to the text bar.
from tqdm.auto import tqdm

In [None]:
# Evaluation configuration. ENCODER and MODEL_PATH are placeholders that must
# be replaced with a real smp encoder name and checkpoint path before running.
IMG_SIZE = 1600                            # side length images/masks are resized to
CLASSES = 1                                # output channels of the Unet head
ENCODER_WEIGHTS = 'imagenet'               # passed to smp.Unet(encoder_weights=...)
ENCODER = 'type of model encodr'           # placeholder: smp encoder name
MODEL_PATH = 'Path to the desired model'   # placeholder: file holding a state_dict

In [None]:
class Dataset():
    """Paired image/mask segmentation dataset listed by a CSV of file names.

    Args:
        PATH_TO_CSV_FILE: CSV read with ``index_col=0``; its column ``'0'``
            lists the file names. Each name is looked up in both
            ``images_dir`` and ``masks_dir``.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=..., mask=...)`` and expected to return a dict
            with ``'image'`` and ``'mask'`` keys.
        images_dir, masks_dir: directories holding the images and masks. The
            defaults are the original placeholders and must be replaced.
    """

    def __init__(self, PATH_TO_CSV_FILE, transform=None,
                 images_dir='PATH_TO_IMAGES', masks_dir='PATH_TO_MASKS'):
        self.transform = transform

        names = list(pd.read_csv(PATH_TO_CSV_FILE, index_col=0)['0'])
        # os.path.join supplies the separator that plain '+' concatenation
        # silently dropped when the directory had no trailing slash.
        self.imgs = [os.path.join(images_dir, str(name)) for name in names]
        self.masks = [os.path.join(masks_dir, str(name)) for name in names]

    def __getitem__(self, idx):
        """Load sample ``idx``, resize it to IMG_SIZE x IMG_SIZE and transform it.

        Returns:
            dict with ``'image'`` and ``'mask'``. Without a transform these are
            the resized numpy arrays (RGB image, single-channel mask); with a
            transform they are whatever the transform produces.
        """
        size = (IMG_SIZE, IMG_SIZE)

        # `with` closes the file handle that PIL would otherwise keep open.
        with Image.open(self.imgs[idx]) as img:
            image = np.array(img.convert("RGB"))
        image = cv2.resize(image, size, interpolation=cv2.INTER_AREA)

        # convert("L") accepts grayscale, palette, 1-bit and colour masks.
        # The previous cv2.COLOR_BGR2GRAY raised on single-channel masks and
        # weighted PIL's RGB channels as if they were BGR.
        with Image.open(self.masks[idx]) as msk:
            mask = np.array(msk.convert("L"))
        mask = cv2.resize(mask, size, interpolation=cv2.INTER_AREA)

        if self.transform:
            sample = self.transform(image=image, mask=mask)
            image = sample['image']
            mask = sample['mask']

        return {
            'image': image,
            'mask': mask,
        }

    def __len__(self):
        """Number of (image, mask) pairs listed in the CSV."""
        return len(self.imgs)

In [None]:
# Dataset: the only transform applied is albumentations' ToTensor.
TFMS = albu.Compose([
    ToTensor(),
])
dataset = Dataset(PATH_TO_CSV_FILE='test_imgs.csv', transform=TFMS)
# sanity check: decode ONE sample (index 20, or the last sample for smaller
# test sets) instead of loading and resizing the same files twice
sample_idx = min(20, len(dataset) - 1)
sample = dataset[sample_idx]
image, mask = sample['image'], sample['mask']
image.shape, mask.shape

In [None]:
# Dataloader: one sample per batch. The evaluation metrics are averages, so
# shuffling adds nothing and is disabled to keep runs reproducible.
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)
# sanity check: pull a SINGLE batch so images and masks come from the same
# sample (two separate next(iter(...)) calls gave two different batches)
batch = next(iter(dataloader))
images, masks = batch['image'], batch['mask']
images.shape, masks.shape

In [None]:
class AverageMeter:
    """Running weighted average of a scalar.

    Attributes (all 0 until the first update):
        val: most recent value passed to ``update``.
        sum: weighted total of every value seen.
        count: total weight seen.
        avg: ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget every value recorded so far."""
        self.count = 0
        self.sum = 0
        self.avg = 0
        self.val = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        
def acc_metric(input, target, threshold=0.5):
    """Pixel accuracy of thresholded predictions against a binarised target.

    Args:
        input: predicted probabilities whose first dimension is the batch
            (e.g. the B x 1 x H x W sigmoid output used in ``evaluate``).
        target: ground-truth mask with the same number of elements per sample
            (B x 1 x H x W or B x H x W); values > 0.5 count as foreground,
            matching ``dice_metric``.
        threshold: probability cut-off for a positive prediction.

    Returns:
        0-dim float tensor on ``input``'s device: fraction of pixels where the
        prediction and the target agree.
    """
    batch_size = input.shape[0]
    # Flatten per sample so (B,1,H,W) vs (B,H,W) never broadcasts to B x B,
    # and stay on the caller's device instead of hard-coding 'cuda'.
    pred = (input > threshold).reshape(batch_size, -1)
    truth = (target > 0.5).reshape(batch_size, -1)
    return (pred == truth).float().mean()
    
def iou_metric(outputs: torch.Tensor, labels: torch.Tensor, threshold: float = 0.5):
    """Thresholded-IoU score averaged over the batch.

    Each sample's prediction is binarised at ``threshold`` and its label at
    0.5, and their IoU is computed. The IoU is then mapped to (approximately)
    the fraction of the thresholds 0.50, 0.55, ..., 0.95 it exceeds. Empty
    prediction vs empty label scores 1.

    Args:
        outputs: probabilities, first dimension is the batch
            (e.g. B x 1 x H x W sigmoid output).
        labels: ground truth with the same number of elements per sample.
        threshold: probability cut-off for a positive prediction.

    Returns:
        0-dim float tensor: mean thresholded score over the batch.
    """
    SMOOTH = 1e-6
    batch_size = outputs.shape[0]
    # Binarise before the set operations: casting probabilities to int
    # truncated every value below 1.0 to 0. Flatten per sample so any
    # B x 1 x H x W / B x H x W layout reduces over the right axes.
    pred = (outputs > threshold).reshape(batch_size, -1)
    truth = (labels > 0.5).reshape(batch_size, -1)

    intersection = (pred & truth).float().sum(1)  # zero if either mask is empty
    union = (pred | truth).float().sum(1)         # zero only if both are empty

    iou = (intersection + SMOOTH) / (union + SMOOTH)  # smoothing avoids 0/0

    # Equivalent to counting how many IoU thresholds 0.50..0.95 are cleared.
    thresholded = torch.clamp(20 * (iou - 0.5), 0, 10).ceil() / 10

    return thresholded.mean()
    
def dice_metric(probability, truth, threshold=0.5, reduction='none'):
    """Per-sample Dice score of thresholded probabilities against a mask.

    Note: this definition shadows the ``dice_metric`` imported from ``Losses``.

    Args:
        probability: predictions, first dimension is the batch; binarised at
            ``threshold``.
        truth: ground truth with the same number of elements per sample;
            binarised at 0.5.
        threshold: probability cut-off for a positive prediction.
        reduction: ``'none'`` (default) returns the per-sample scores,
            ``'mean'`` / ``'sum'`` reduce them to a 0-dim tensor.

    Returns:
        For ``'none'``, a (batch, 1) tensor. Samples with a non-empty truth
        come first, then samples with an empty truth (NOT batch order). An
        empty-truth sample scores 1 if its prediction is also empty, else 0.

    Raises:
        ValueError: for an unknown ``reduction``.
    """
    batch_size = len(truth)
    with torch.no_grad():
        # reshape (not view) also works on non-contiguous inputs
        probability = probability.reshape(batch_size, -1)
        truth = truth.reshape(batch_size, -1)
        assert probability.shape == truth.shape

        p = (probability > threshold).float()
        t = (truth > 0.5).float()

        t_sum = t.sum(-1)
        p_sum = p.sum(-1)
        neg_index = torch.nonzero(t_sum == 0)
        pos_index = torch.nonzero(t_sum >= 1)

        dice_neg = (p_sum == 0).float()
        # Only read at pos_index, where the denominator is >= 1, so the
        # 0/0 values produced for empty-truth rows are discarded.
        dice_pos = 2 * (p * t).sum(-1) / (p + t).sum(-1)

        dice = torch.cat([dice_pos[pos_index], dice_neg[neg_index]])

    if reduction == 'none':
        return dice
    if reduction == 'mean':
        return dice.mean()
    if reduction == 'sum':
        return dice.sum()
    raise ValueError(f"unknown reduction: {reduction!r}")

def evaluate(valid_loader, model, device='cuda', metric=dice_metric):
    """Average ``metric`` of ``model`` over every batch of ``valid_loader``.

    Args:
        valid_loader: iterable of dicts with tensor values under at least the
            ``'image'`` and ``'mask'`` keys.
        model: network returning logits; a sigmoid is applied here.
        device: device the model and each batch are moved to.
        metric: callable ``(probabilities, mask) -> tensor`` whose mean is
            taken per batch.

    Returns:
        Sample-weighted average of the per-batch metric means (0 for an
        empty loader).
    """
    meter = AverageMeter()
    model = model.to(device)
    model.eval()
    progress = tqdm(valid_loader, total=len(valid_loader))
    with torch.no_grad():
        for batch in progress:
            batch = {key: value.to(device) for key, value in batch.items()}
            probs = torch.sigmoid(model(batch['image']))
            score = metric(probs, batch['mask']).cpu()
            # Weight by the real batch size: the last batch can be smaller than
            # valid_loader.batch_size, which is also None with a batch_sampler.
            meter.update(score.mean().item(), batch['image'].size(0))
            progress.set_postfix(metric_score=meter.avg)
    return meter.avg

In [None]:
model = None  # drop any previously loaded model before building a new one
torch.cuda.empty_cache()
model = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=CLASSES,
    activation=None,  # raw logits; evaluate() applies the sigmoid itself
)
# map_location='cpu' lets a checkpoint saved on GPU load on any machine;
# evaluate() then moves the model to its target device.
model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))
dice = evaluate(dataloader, model, metric=dice_metric)
acc = evaluate(dataloader, model, metric=acc_metric)
iou = evaluate(dataloader, model, metric=iou_metric)

# Plain '%' -- the previous "\%" was an invalid escape that printed a backslash.
print(f"Iou : {np.round(iou * 100, 2)}%,  Dice score :{np.round(dice * 100, 2)}%, "
      f"Accuracy : {np.round(acc * 100, 2)}%")

In [None]:
# Qualitative check: input image, ground-truth mask and thresholded prediction
# side by side for every test sample.
device = next(model.parameters()).device
model.eval()
for i in range(len(dataset)):
    sample = dataset[i]
    # add a batch dimension instead of hard-coding a 1x3x1600x1600 reshape
    batch_image = sample['image'].unsqueeze(0).to(device)
    with torch.no_grad():  # inference only: don't build an autograd graph
        # The model emits logits (activation=None); use the same sigmoid +
        # 0.5 threshold as evaluate() rather than thresholding logits.
        prob = torch.sigmoid(model(batch_image)).cpu().squeeze()
    pred = (prob > 0.5).float().numpy()
    truth = (sample['mask'].squeeze() > 0.5).float().numpy()

    fig, axarr = plt.subplots(1, 3, figsize=(24, 8))
    axarr[0].imshow(sample['image'].permute(1, 2, 0).numpy())
    axarr[0].set_title('image')
    axarr[1].imshow(truth, cmap='gray', vmin=0, vmax=1)
    axarr[1].set_title('ground truth')
    axarr[2].imshow(pred, cmap='gray', vmin=0, vmax=1)
    axarr[2].set_title('prediction')
    for ax in axarr:
        ax.axis('off')

    plt.show()
    plt.close(fig)  # avoid accumulating figures across the loop