In [1]:
import copy
import os
import random
import warnings

import numpy as np
import skimage as ski
import skimage.morphology as mp
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import pylab as plt
from skimage import img_as_float
from skimage import io, feature,filters
from skimage.color import rgb2hsv,rgb2gray,hsv2rgb
from skimage.filters.edges import convolve
from skimage.morphology import disk
from skimage.morphology import flood_fill
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision.transforms import ToTensor

warnings.simplefilter("ignore")

In [47]:
def readFile(path):
    """Load the image stored at ``path``.

    Thin wrapper around ``skimage.io.imread``; dtype and channel layout are
    whatever ``imread`` yields for that file.
    """
    return io.imread(path)
def convert2gray(img):
    """Return a float grayscale version of ``img`` (values via ``img_as_float``).

    Colour images go through ``rgb2gray``.  2-D inputs are already
    single-channel and are only converted to float: scikit-image < 0.19
    returned 2-D inputs from ``rgb2gray`` unchanged, while from 0.19 a 2-D
    array is treated as a 1-D image with 3 channels (failing unless its last
    axis has size 3).  Handling 2-D input here keeps the old result on all
    versions.
    """
    if np.ndim(img) == 2:
        return img_as_float(img)
    return img_as_float(rgb2gray(img))
def preProcess(img, sigma=5, gamma=0.4):
    """Grayscale, blur and gamma-correct ``img`` before edge detection.

    Parameters
    ----------
    img : image accepted by ``convert2gray``.
    sigma : standard deviation of the Gaussian blur (default 5, as before).
    gamma : exponent of the power-law correction ``img ** gamma`` (default 0.4).

    Returns the float image after blurring and gamma correction.
    """
    gray = convert2gray(img)
    blurred = filters.gaussian(gray, sigma=sigma)
    # intensities are in [0, 1] here, so gamma < 1 lifts the low values
    return blurred ** gamma
def process(img, percentile=80):
    """Binary edge map of ``img``: 1.0 where the Sobel response is above a percentile.

    The Sobel magnitude is min-max normalised to [0, 1] and thresholded at its
    ``percentile``-th percentile (default 80, as before).  A constant input
    has no edges and returns all zeros explicitly, instead of dividing by zero
    and producing NaNs.

    Returns a float array of 0.0 / 1.0 with the shape of ``img``.
    """
    edges = filters.sobel(img)
    lowest = np.min(edges)
    span = np.max(edges) - lowest
    if span == 0:
        return np.zeros_like(edges, dtype=float)
    # (x - min) / (max - min) already lies in [0, 1]; no clipping is needed
    edges = (edges - lowest) / span
    threshold = np.percentile(edges, percentile)
    # all values are >= 0, so "> threshold" also implies "> 0"
    return (edges > threshold) * 1.0
def postProcess(img, mask):
    """Dilate then erode the binary edge map and keep only pixels inside the eroded mask.

    ``img`` is dilated with a radius-10 disk and then eroded with a radius-15
    disk.  ``mask`` is converted to float grayscale and eroded 20 times with
    the default footprint, shrinking it (presumably to drop responses along
    the mask border -- TODO confirm).  The result is multiplied by it.

    Returns the masked float map.
    """
    # footprint passed positionally: the ``selem=`` keyword is deprecated since
    # scikit-image 0.19 (renamed ``footprint``); positional works on all versions
    cleaned = mp.erosion(mp.dilation(img, disk(10)), disk(15))
    shrunk = convert2gray(mask)
    for _ in range(20):
        shrunk = mp.erosion(shrunk)
    return cleaned * shrunk
def createResult(base, img):
    """Set the first three channels of ``base`` to 255 wherever ``img == 1``.

    ``base`` is an ndarray indexed (row, col, channel) and is modified in
    place.  ``img`` is indexed (row, col) with the same row/col extent.  A
    vectorised boolean-mask assignment replaces the per-pixel Python loop.

    Returns ``base``.
    """
    hit = np.asarray(img) == 1
    base[hit, :3] = 255
    return base
def wholeProcess(base, mask):
    """Run preProcess -> process -> postProcess on ``base`` and overlay the result.

    Returns ``(overlay, segmentation)``.  ``segmentation`` is the
    post-processed map.  ``overlay`` is ``base`` with the detected pixels set
    to 255 (``createResult`` writes into ``base``).
    """
    segmentation = postProcess(process(preProcess(base)), mask)
    overlay = createResult(base, segmentation)
    return overlay, segmentation
def getFileNames(path):
    """List the entry names of directory ``path``, in ``os.listdir`` order."""
    return [entry for entry in os.listdir(path)]
def takeFewExamples(amount):
    """Randomly choose ``amount`` examples and return their (image, mask, manual) paths.

    Path construction is delegated to ``takeAll`` so the directory layout is
    defined in one place.  Examples are drawn without replacement with
    ``random.sample``, which selects the same positions as sampling the file
    list directly.

    Returns three parallel lists.  Raises ValueError if ``amount`` exceeds
    the number of available examples.
    """
    imagesPaths, masksPaths, manualPaths = takeAll()
    chosen = random.sample(range(len(imagesPaths)), amount)
    return ([imagesPaths[k] for k in chosen],
            [masksPaths[k] for k in chosen],
            [manualPaths[k] for k in chosen])
def takeAll(root=None):
    """Return (image, mask, manual) path lists for every .jpg in ``<root>/images``.

    For an image ``<name>.jpg`` the companions are ``<root>/mask/<name>_mask.tif``
    and ``<root>/manual1/<name>.tif``.  ``root`` defaults to the current
    working directory.  Names are sorted so the order is the same on every
    platform.

    Returns three parallel lists of paths.
    """
    if root is None:
        root = os.getcwd()
    imageDir = os.path.join(root, "images")
    # only .jpg entries are images; anything else (e.g. Thumbs.db) would yield bogus paths
    files = sorted(f for f in getFileNames(imageDir)
                   if os.path.splitext(f)[1].lower() == ".jpg")
    names = [os.path.splitext(f)[0] for f in files]
    # os.path.join instead of hard-coded "\\" so the paths also work off Windows
    imagesPaths = [os.path.join(imageDir, f) for f in files]
    masksPaths = [os.path.join(root, "mask", n + "_mask.tif") for n in names]
    manualPaths = [os.path.join(root, "manual1", n + ".tif") for n in names]
    return imagesPaths, masksPaths, manualPaths
def showImages(images):
    """Display every image of ``images`` in its own figure, in order."""
    for picture in images:
        showImg(picture)

def showImg(img):
    """Draw ``img`` with ``plt.imshow`` on a new 20x10-inch figure."""
    plt.figure(figsize=(20, 10))
    plt.imshow(img)
def countStatistics(image, manual1, masks):
    """Pixel-wise confusion statistics of a segmentation, restricted to a mask.

    Only pixels where ``masks == 1`` are considered.  Among them a pixel is:
    - TP when ``image`` and ``manual1`` are both 1,
    - TN when both are 0,
    - FN when ``image`` is 0 and ``manual1`` is 1,
    - FP when ``image`` is 1 and ``manual1`` is 0.
    Pixels with any other value still count toward the total used for accuracy.

    Parameters
    ----------
    image : 2-D array-like -- predicted segmentation.
    manual1 : 2-D array-like -- reference segmentation, same shape as ``image``.
    masks : 2-D array-like -- region of interest, same shape as ``image``.

    Returns
    -------
    (accuracy, sensitivity, precision, specificity) as floats.  A ratio whose
    denominator is zero is returned as ``nan`` instead of raising
    ZeroDivisionError.
    """
    pred = np.asarray(image)
    truth = np.asarray(manual1)
    roi = np.asarray(masks) == 1

    tp = int(np.count_nonzero(roi & (pred == 1) & (truth == 1)))
    tn = int(np.count_nonzero(roi & (pred == 0) & (truth == 0)))
    fn = int(np.count_nonzero(roi & (pred == 0) & (truth == 1)))
    fp = int(np.count_nonzero(roi & (pred == 1) & (truth == 0)))
    total = int(np.count_nonzero(roi))

    def ratio(num, den):
        # an empty class (e.g. no positive reference pixels) leaves the metric undefined
        return num / den if den else float("nan")

    accuracy = ratio(tp + tn, total)
    sensitivity = ratio(tp, tp + fn)
    precision = ratio(tp, tp + fp)
    specificity = ratio(tn, tn + fp)
    return accuracy, sensitivity, precision, specificity
def printStatistics(tup):
    """Print an (accuracy, sensitivity, precision, specificity) tuple, one metric per line."""
    prefixes = ("Accuracy: ", "Sensitivity: ", "Precision: ", "Specificity: ")
    for position, prefix in enumerate(prefixes):
        print(prefix + str(tup[position]))
                    
def divideImg(size=5, amount=200):
    """Sample labelled ``size`` x ``size`` patches from every image returned by ``takeAll``.

    For each image, random centre pixels inside the mask are drawn until
    ``amount`` positive patches and ``amount`` negative patches are collected.
    A patch is positive when its centre is set in the manual segmentation.
    Patches are cut from the ``img_as_float`` image.

    Returns
    -------
    allImgs : list of arrays -- all positive patches followed by all negative ones.
    labels : float ndarray -- 1.0 for positive, 0.0 for negative, aligned with allImgs.

    Note: an image's loop only ends once both quotas are filled, so an image
    with fewer than ``amount`` usable pixels of either class never finishes.
    """
    imagesPaths, masksPaths, manualPaths = takeAll()
    half = size // 2
    allPositive = []
    allNegative = []
    for imagePath, maskPath, manualPath in zip(imagesPaths, masksPaths, manualPaths):
        img = img_as_float(readFile(imagePath))
        mask = convert2gray(readFile(maskPath))
        manual = convert2gray(readFile(manualPath))
        positive = []
        negative = []
        while len(positive) < amount or len(negative) < amount:
            x = random.randint(half, len(img) - half - 1)
            y = random.randint(half, len(img[0]) - half - 1)
            # threshold instead of "== 1": grayscale conversion may not give exactly 1.0
            if mask[x][y] <= 0.5:
                continue
            patch = img[x - half:x + half + 1, y - half:y + half + 1]
            if manual[x][y] > 0.5:
                if len(positive) < amount:
                    positive.append(patch)
            elif len(negative) < amount:
                negative.append(patch)
        allPositive += positive
        allNegative += negative
    # built once after the loop; also defined when no images were found
    labels = np.concatenate([np.ones(len(allPositive)), np.zeros(len(allNegative))])
    allImgs = allPositive + allNegative
    return allImgs, labels

def prepareDataset(x, y):
    """Wrap patches and labels in a TensorDataset and create a 5-fold splitter.

    ``x`` is a sequence of equally shaped (H, W, C) patches and ``y`` the
    matching labels.  The KFold shuffles with a fixed seed (1), so the folds
    are reproducible.

    Returns (kfold, data).  ``data`` holds a float32 tensor of shape
    (N, C, H, W) and a float32 label tensor of shape (N,).
    """
    # keyword arguments: scikit-learn >= 1.0 rejects shuffle/random_state passed positionally
    kfold = KFold(n_splits=5, shuffle=True, random_state=1)
    # stack into one ndarray first; torch.Tensor(list_of_ndarrays) is extremely slow
    patches = torch.from_numpy(np.stack(x).astype(np.float32))
    targets = torch.from_numpy(np.asarray(y, dtype=np.float32))
    # permute NHWC -> NCHW, the layout nn.Conv2d expects
    data = TensorDataset(patches.permute(0, 3, 1, 2), targets)
    return kfold, data
def prepareLoaders(trainData, testData, batchSize):
    """Build the training and evaluation DataLoaders.

    The training loader shuffles and yields batches of ``batchSize``.  The
    evaluation loader yields all of ``testData`` as one batch.
    """
    return (
        DataLoader(dataset=trainData, batch_size=batchSize, shuffle=True),
        DataLoader(dataset=testData, batch_size=len(testData)),
    )
def prepareModel(patchSize=5):
    """Return (as a list) the CNN layers for classifying ``patchSize`` x ``patchSize`` RGB patches.

    The layers are:
    - two 3x3 'same' convolutions (3 -> 12 -> 24 channels), each followed by ReLU,
    - a 2x2 max-pool,
    - a flatten,
    - a linear layer producing 2 class logits.
    The linear input size is derived from the pooled spatial size
    ``patchSize // 2`` instead of being hard-coded.  For the default 5x5
    patches it is 24 * 2 * 2 = 96, as before.
    """
    pooled = patchSize // 2   # MaxPool2d(2) floors the 'same'-padded size
    return [
        nn.Conv2d(3, 12, 3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(12, 24, 3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
        nn.Flatten(),
        nn.Linear(24 * pooled * pooled, 2),
    ]
    
def compute_acc(logits, expected):
    """Fraction of rows whose arg-max logit equals ``expected``, as a float32 tensor."""
    hits = torch.eq(torch.argmax(logits, dim=1), expected)
    return hits.float().mean()
def train(kfold, data, layers):
    """Train one model per cross-validation fold with early stopping.

    For every (train, validation) split from ``kfold.split(data)``:
    - An independent copy of ``layers`` (wrapped in ``nn.Sequential``) is
      trained with Adam and cross-entropy on batches of 200.
    - After each epoch, accuracy is measured on the whole validation fold.
    - Training stops after 5 epochs without improvement (at most 10000 epochs).
    - The model is restored to its best epoch.
    ``layers`` itself is never trained; every fold starts from its weights.

    Returns (accs, models): per fold, the best validation accuracy (0-d
    tensor) and the corresponding model.
    """
    max_epoch = 10000
    no_improvement = 5
    batchSize = 200
    models = []
    accs = []
    # iterate the splitter itself: next(kfold.split(...)) on each pass restarts
    # the generator and would return the first fold every time
    for trainIdx, testIdx in kfold.split(data):
        # fresh copy per fold; reusing the same layer objects would share (and
        # keep training) one set of weights across all folds and all models
        model = nn.Sequential(*copy.deepcopy(layers))
        cost = torch.nn.CrossEntropyLoss()
        opt = optim.Adam(model.parameters())
        trainLoader, testLoader = prepareLoaders(TensorDataset(*data[trainIdx]),
                                                 TensorDataset(*data[testIdx]),
                                                 batchSize)
        best_model = None
        best_acc = None
        best_epoch = None

        for n_epoch in range(max_epoch):
            model.train()
            for X_batch, y_batch in trainLoader:
                opt.zero_grad()
                logits = model(X_batch)
                loss = cost(logits, y_batch.long())
                loss.backward()
                opt.step()
            model.eval()
            with torch.no_grad():   # evaluation only: no autograd graph needed
                X, y = next(iter(testLoader))
                acc = compute_acc(model(X), y)
            if best_acc is None or acc > best_acc:
                print("New best epoch ", n_epoch, "acc", acc)
                best_acc = acc
                # state_dict() holds live references to the parameters; copy it,
                # otherwise later epochs overwrite the "best" weights in place
                best_model = copy.deepcopy(model.state_dict())
                best_epoch = n_epoch
            if best_epoch + no_improvement <= n_epoch:
                print("No improvement for", no_improvement, "epochs")
                break

        model.load_state_dict(best_model)
        models.append(model)
        accs.append(best_acc)
    return accs, models
def chooseBestModel(models, accs):
    """Return the entry of ``models`` whose accuracy in ``accs`` is highest (first one on ties)."""
    bestIndex = max(range(len(accs)), key=lambda k: accs[k])
    return models[bestIndex]
def getTestFragment(half=2):
    """Cut a fragment from one random image and build a patch for every fragment pixel.

    The fragment spans 30%..40% of the image height and of its width.  For
    each fragment pixel, in row-major order, a (2*half+1)-square patch
    centred on it is cut from the ``img_as_float`` image.  This puts the
    patches on the same [0, 1] scale as the training patches of
    ``divideImg``; ``half=2`` matches its default size of 5.  The fragment is
    displayed.

    Returns (small, img, man):
    - small: the list of patches,
    - img: the fragment of the image as read (a view into it),
    - man: the same crop of the grayscale manual segmentation.
    """
    imagesPaths, _, manualPaths = takeFewExamples(1)

    def window(length):
        # fragment bounds along one axis: 30%..40%
        return 30 * length // 100, 40 * length // 100

    for imagePath, manualPath in zip(imagesPaths, manualPaths):
        base = readFile(imagePath)
        manual = convert2gray(readFile(manualPath))
        # the model was trained on img_as_float patches, not raw 0..255 values
        baseFloat = img_as_float(base)
        rowStart, rowEnd = window(len(base))
        colStart, colEnd = window(len(base[0]))
        small = [baseFloat[i - half:i + half + 1, j - half:j + half + 1]
                 for i in range(rowStart, rowEnd)
                 for j in range(colStart, colEnd)]
        img = base[rowStart:rowEnd, colStart:colEnd]
        manRowStart, manRowEnd = window(len(manual))
        manColStart, manColEnd = window(len(manual[0]))
        man = manual[manRowStart:manRowEnd, manColStart:manColEnd]
        showImg(img)
        return small, img, man
def checkModel(model, small, img, man):
    """Classify every fragment pixel, compare with ``man``, and show the overlay.

    ``small`` holds one (H, W, C) patch per pixel of ``img`` in row-major
    order, as built by ``getTestFragment``.  ``man`` is the reference
    segmentation with the same height/width as ``img``.  Pixels predicted as
    class 1 get their first three channels set to 255 in ``img`` (in place)
    before it is shown.

    Prints, in order: the predictions, ``man``, the pixel accuracy, and the
    number of patches.
    """
    rows, cols = len(img), len(img[0])
    patches = torch.Tensor(np.asarray(small)).permute(0, 3, 1, 2)   # NHWC -> NCHW
    model.eval()
    with torch.no_grad():
        binY = model(patches).argmax(dim=1)
    print(binY)
    print(man)
    # flat index of pixel (i, j) is i*cols + j, so reshape with the fragment width
    pred = binY.reshape(rows, cols).cpu().numpy()
    correct = int(np.count_nonzero(pred == np.asarray(man)))
    createResult(img, pred)
    showImg(img)
    # pixel accuracy: matching pixels over all rows*cols pixels
    print(correct / (rows * cols))
    print(len(small))

In [3]:
def normalMethod():
    """Run the classical pipeline on one random example, show it and print its metrics.

    Displays three images: the original, the copy with detected pixels
    painted, and the binary result.  Then prints the statistics against the
    manual segmentation, counted inside the mask.
    """
    for imagePath, maskPath, manualPath in zip(*takeFewExamples(1)):
        original = readFile(imagePath)
        maskImg = readFile(maskPath)
        reference = readFile(manualPath)
        overlay, segmentation = wholeProcess(original.copy(), maskImg.copy())
        showImages([original, overlay, segmentation])
        stats = countStatistics(segmentation, convert2gray(reference), convert2gray(maskImg))
        printStatistics(stats)


    

In [None]:
# Classical (non-learned) pipeline on one randomly chosen image: shows the results and prints its statistics.
normalMethod()

In [None]:
# Sample labelled patches from every image (divideImg defaults: 5x5 patches,
# 200 positive + 200 negative per image) and wrap them for 5-fold cross-validation.
allImgs,labels=divideImg()
kfold,data=prepareDataset(allImgs,labels)

In [39]:
# Build the CNN layer list and train one model per fold;
# accs holds each fold's best validation accuracy, models the matching models.
layers=prepareModel()
accs,models=train(kfold,data,layers)

New best epoch  0 acc tensor(0.6325)
New best epoch  1 acc tensor(0.6622)
New best epoch  3 acc tensor(0.6644)
New best epoch  4 acc tensor(0.6717)
New best epoch  5 acc tensor(0.6731)
No improvement for 5 epochs


In [40]:
# Keep the fold model with the highest validation accuracy.
model=chooseBestModel(models,accs)

In [None]:
# Cut per-pixel patches for a fragment (30%..40% of height and width) of one random image; also displays the fragment.
small,img,man=getTestFragment()

In [None]:
# Classify every fragment pixel with the chosen model and compare it with the manual segmentation.
checkModel(model,np.array(small),img,man)