In [None]:
from __future__ import division, print_function
import demics as dmc
import cv2
import numpy as np
from matplotlib import pyplot as plt
% matplotlib notebook
import logging
import sys
logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format='%(asctime)s %(name)s [%(levelname)s]:%(message)s',
        datefmt='%Y-%m-%d %H:%M:%S')

In [None]:
# Section indices of the training and test images.
s_train = 1
s_test = 2
# Files are named by a zero-padded 4-digit section index, e.g. section_0001.tif.
train_file = "../data/section_{:04d}.tif".format(s_train)
test_file = "../data/section_{:04d}.tif".format(s_test)
# The literal "{}" is left in place as a slot for the class name; it is
# filled later with coords_file.format(<class name>).
coords_file = "../data/section_{:04d}.{{}}.txt".format(s_train)
# Side length passed to dmc.extract_patches.
patchsize = 30

In [None]:
# prepare training data for L=2 different classes
# For each class name, read its coordinate file (coords_file with the class
# name substituted), cut patches around those coordinates from the training
# image, and label every patch with the class's index in `labels`.
I = cv2.imread(train_file, cv2.IMREAD_GRAYSCALE)
if I is None:
    # cv2.imread signals a missing/unreadable file by returning None rather
    # than raising; fail here instead of deep inside extract_patches.
    raise IOError("could not read training image: %s" % train_file)
labels = ['bg', 'vessels']
training_patches = []
training_labels = []
for class_idx, class_name in enumerate(labels):
    # ndmin=2 keeps a one-line coordinate file as a single row instead of
    # collapsing it to a 1-D vector (assumes one coordinate per row).
    coords = np.loadtxt(coords_file.format(class_name), ndmin=2)
    patches_l = dmc.extract_patches(I, coords, patchsize)
    training_patches.extend(patches_l)
    training_labels.extend([class_idx] * len(patches_l))

In [None]:
# see the vessel patches, laid out column by column in a grid with 5 rows
vessel_idx = labels.index("vessels")
vesselpatches = [p for p, lab in zip(training_patches, training_labels)
                 if lab == vessel_idx]
numpatches = len(vesselpatches)
nrows = 5
# Ceil division gives exactly enough columns; keep at least one column.
ncols = max(1, -(-numpatches // nrows))
# squeeze=False always returns a 2-D axes array. Without it, a single
# column (fewer than 6 patches) yields a 1-D array and axs[r, c] fails.
f, axs = plt.subplots(nrows, ncols, sharex=True, sharey=True, squeeze=False)
for i, p in enumerate(vesselpatches):
    axs[i % nrows, i // nrows].imshow(p, cmap='gray')
# Hide axes on every slot, including the unused ones in the last column.
for ax in axs.flat:
    ax.axis('off')
plt.show()

In [None]:
# Training: train the classifier
# Hyper-parameter grid over C / gamma / kernel: one sub-grid for a linear
# kernel and one for an RBF kernel. grid_search is configured first
# (3-fold CV, refit on the best setting); presumably the search itself runs
# inside train() -- TODO confirm against dmc.TextureClassifier.
classifier = dmc.TextureClassifier()
linear_grid = {'C': 10.**np.arange(-3, 4), 'kernel': ['linear']}
rbf_grid = {'C': 10.**np.arange(-3, 4),
            'gamma': 10.**np.arange(-4, 3),
            'kernel': ['rbf']}
param_grid = [linear_grid, rbf_grid]
classifier.grid_search(param_grid, n_jobs=-1, cv=3, refit=True)
X_train = np.array(training_patches)
y_train = np.array(training_labels)
classifier.train(X_train, y_train, num_augmentations=5)

## Optionally save the trained classifier (including its parameters) and reload it

In [None]:
# Persist the trained classifier so later sessions can skip training;
# overwrite=True presumably replaces an existing file of the same name.
classifier.save(overwrite=True, filename="trained_classifier.p")

In [None]:
# Replace the in-memory classifier with one previously written by save().
# NOTE(review): the ".p" extension suggests a pickle file; if so, only load
# files from a trusted source -- confirm how TextureClassifier.load reads it.
classifier_file = "trained_classifier.p"
classifier = dmc.TextureClassifier.load(classifier_file)

## Extract image patches on a regular grid and predict their labels ...

In [None]:
# collect patches for prediction: place a regular grid over the test image
# and cut one patch (side `patchsize`) around each grid point.
I = cv2.imread(test_file, cv2.IMREAD_GRAYSCALE)
if I is None:
    # cv2.imread returns None instead of raising on a missing/unreadable file.
    raise IOError("could not read test image: %s" % test_file)
grid_step = 20  # presumably the grid spacing in pixels -- TODO confirm
# patchsize//2 is presumably a border margin so patches stay inside the image.
gridcoords = dmc.grid_coordinates(I.shape, grid_step, patchsize // 2)
patches = dmc.extract_patches(I, gridcoords, patchsize)

In [None]:
# prediction: predict labels for N image patches
labels, scores = classifier.predict(patches)
# labels is an Nx1 vector, with elements 0 <= e <=L-1; scores is an NxL array

In [None]:
# Plot detected vessels
detections = np.array(gridcoords)[np.where(labels==1)[0]]
plt.figure()
plt.imshow(I, 'gray')
plt.plot(detections[:,0], detections[:,1], "r.")
plt.show()

## ... or detect multi-scale features with detect()

In [None]:
# Load the test image again so this alternative path runs on its own.
I = cv2.imread(test_file, cv2.IMREAD_GRAYSCALE)
if I is None:
    # cv2.imread returns None instead of raising on a missing/unreadable file.
    raise IOError("could not read test image: %s" % test_file)
# The second argument `1` presumably selects the class to detect
# (index of "vessels") -- TODO confirm against classifier.detect.
# The result is used below as a DataFrame with x, y and size columns.
detections = classifier.detect(I, 1, gridsize=20)

In [None]:
# Plot detected vessels: red dots at each (x, y) detection over the test image.
fig, ax = plt.subplots()
ax.imshow(I, cmap='gray')
ax.plot(detections.x, detections.y, "r.")
plt.show()

In [None]:
### optional: Drop coordinates with only one detection (in just one scale)
# keep=False marks *every* member of a repeated (x, y) group, so the mask
# keeps all rows whose position occurs more than once. Note this overwrites
# `detections`; the clustering cells below then work on the filtered set.
seen_multiple = detections.duplicated(subset=["x", "y"], keep=False)
detections = detections[seen_multiple]
fig, ax = plt.subplots()
ax.imshow(I, cmap='gray')
ax.plot(detections.x, detections.y, "r.")
plt.show()

## Cluster detections with MeanShift

In [None]:
from sklearn.cluster import MeanShift
# Kernel bandwidth actually used for clustering. An earlier estimate of
# 180/5 (= 36) was computed here but never passed to MeanShift; the variable
# now holds the value that is really used.
bandwidth = 65
# cluster_all=False: points outside every kernel are labelled -1 (orphans)
# instead of being forced into the nearest cluster.
# NOTE(review): "size" is clustered together with x/y, so its scale relative
# to pixel coordinates affects the result -- confirm units.
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True, cluster_all=False)
ms.fit(detections[["x", "y", "size"]])

In [None]:
# Plot detected vessels (red) and MeanShift cluster centers (blue).
# Columns 0 and 1 of cluster_centers_ are x and y, matching the column
# order ["x", "y", "size"] used in fit().
fig, ax = plt.subplots()
ax.imshow(I, cmap='gray')
ax.plot(detections.x, detections.y, "r.")
centers = ms.cluster_centers_
ax.plot(centers[:, 0], centers[:, 1], "b.")
plt.show()

## Cluster detections with DBSCAN

In [None]:
from sklearn.cluster import DBSCAN
# Cluster detections on x/y only. The old note "17 / 5" presumably records an
# alternative eps / min_samples setting.
dbscan = DBSCAN(eps=23, min_samples=5)
db_labels = dbscan.fit_predict(detections[["x", "y"]])
xs = detections.x.values
ys = detections.y.values
# One center per cluster (mean x, mean y of its members), in ascending label
# order. Label -1 is DBSCAN noise and is skipped. reshape(-1, 2) keeps a
# (0, 2) array when every point is noise, so XY_center[:, 0] still works.
XY_center = np.array([[xs[db_labels == l].mean(), ys[db_labels == l].mean()]
                      for l in np.unique(db_labels)
                      if l != -1], dtype=float).reshape(-1, 2)

In [None]:
# Plot detected vessels (red) and DBSCAN cluster centers (blue).
fig, ax = plt.subplots()
ax.imshow(I, cmap='gray')
ax.plot(detections.x, detections.y, "r.")
center_x, center_y = XY_center[:, 0], XY_center[:, 1]
ax.plot(center_x, center_y, "b.")
plt.show()