In [None]:
## This notebook identifies prototype clusters by running K-Means over deep (CNN) features

In [None]:
# import packages
import os, cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
## packages for clustering
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import *

In [None]:
# load data (image, ground labels, CNN features)
## load shadow-free images
img_path = '../output/images/shadow_free.jpg'
img = cv2.imread(img_path)
## cv2.imread does not raise on a missing/unreadable file; it returns None,
## which would otherwise surface later as an obscure cvtColor error
if img is None:
    raise FileNotFoundError('Could not read image: ' + img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) ## convert the color channel to RGB
## load LFDP ground-based labels (first CSV column is used as the index)
df = pd.read_csv('../data/labels/LFDP_labels.csv', index_col=0)
## load CNN features (file has no header row; one row per feature vector)
feat_cnn = pd.read_csv('../output/features/feat_cnn.csv', header=None).to_numpy()

In [None]:
# filter the ground labels
THRESH_DIAM = 20 ## diameter threshold
## combine all row criteria into a single boolean mask
keep = (
    (df['ALIVE'] == 'A')             ## remove dead trees
    & (df['DIAM'] > THRESH_DIAM)     ## remove small trees
    & (df['pix_1'] < 9600)           ## remove boundary points
    & (df['pix_2'] < 15000)
)
## apply the mask and renumber the surviving rows 0..n-1
df = df[keep].reset_index(drop=True)

In [None]:
# KMeans clustering over the features
n_c = 25 ## number of clusters
lab_path = '../output/KMeans/labels_' + str(n_c) + '.csv'
cent_path = '../output/KMeans/centers_' + str(n_c) + '.csv'
## create the output folder up front so a long clustering run is not lost
## to a missing directory when the results are written
os.makedirs(os.path.dirname(lab_path), exist_ok=True)
time_start = datetime.now()
print('Start:', time_start)
## transform data (standardize each feature column)
X = StandardScaler().fit_transform(feat_cnn)
## run full-batch KMeans over the whole feature set
## NOTE(review): n_init is left at the library default, which differs between
## scikit-learn versions -- pin it if results must match across versions
kmeans = KMeans(n_clusters=n_c, random_state=2020, max_iter=2000, tol=0)
kmeans = kmeans.fit(X)
## save labels and centers as .csv (the index column is written as well)
labels = kmeans.predict(X)
labels_pd = pd.DataFrame(labels)
labels_pd.to_csv(lab_path)
centers = kmeans.cluster_centers_
centers_pd = pd.DataFrame(centers)
centers_pd.to_csv(cent_path)
print('Time:', datetime.now() - time_start)

In [None]:
## read the per-patch cluster labels back from disk as a flat 1-D array
n_c = 25
labels_file = '../output/KMeans/labels_' + str(n_c) + '.csv'
labels_frame = pd.read_csv(labels_file, index_col=0)
labels = labels_frame.to_numpy().flatten()

In [None]:
# calculate species relevance for each cluster
## species relevance = number of labels / number of image patches
## patch side length, inferred from image area divided by the number of feature rows
rs = int(np.sqrt(img.shape[0] * img.shape[1] / feat_cnn.shape[0]))
nrow = int(img.shape[0] / rs)
ncol = int(img.shape[1] / rs)
SPECIES = 'PREMON'
sp_dict = {'PREMON': ['PREMON', 'ROYBOR'], 'CECSCH': ['CECSCH'], 'MANBID': ['MANBID']}
sp_list = sp_dict[SPECIES]
## number of patches assigned to each cluster (counted in numpy, not element by element)
size_cluster = np.array([np.count_nonzero(labels == i) for i in range(n_c)])
## map every target-species tree to the row-major index of the patch containing it;
## positional arrays are used so the result does not depend on df's index labels
is_target = df['SPECIES'].isin(sp_list).to_numpy()
loc_1 = df['pix_2'].to_numpy()[is_target] // rs ## patch row
loc_2 = df['pix_1'].to_numpy()[is_target] // rs ## patch column
location = (loc_1 * ncol + loc_2).astype(np.int64)
## count target trees per cluster; np.add.at accumulates repeated cluster ids,
## which a plain fancy-indexed "+= 1" would not
num_sp = np.zeros(n_c)
np.add.at(num_sp, labels[location], 1)
## calculate relevance
ratio_sp = num_sp / size_cluster
## stable sort keeps tied clusters in ascending id order
order = np.argsort(ratio_sp, kind='stable')
index_sort = order.astype(str) ## string ids so the bars are drawn as categories
ratio_sp_sorted = ratio_sp[order]
## make sure the output folders exist before writing
fig_path = '../output/figs/relevance_' + SPECIES + '.png'
rel_path = '../output/relevance/relevance_' + SPECIES + '_' + str(THRESH_DIAM) + '_' + str(n_c) + '.csv'
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
os.makedirs(os.path.dirname(rel_path), exist_ok=True)
## relevance per cluster, sorted ascending
plt.figure(figsize=(8, 3))
plt.bar(index_sort, ratio_sp_sorted)
plt.xlabel('Cluster ID', fontsize=12)
plt.ylabel('Target Relevance', fontsize=12)
plt.savefig(fig_path, bbox_inches='tight', dpi=300)
plt.show()
## cluster sizes, in the same cluster order as the relevance plot
plt.figure(figsize=(8, 3))
plt.bar(index_sort, size_cluster[order], color='green')
plt.xlabel('Cluster ID', fontsize=12)
plt.ylabel('Number of Patches', fontsize=12)
plt.show()
print('In total, species relevance is', num_sp.sum()/size_cluster.sum())
## save the per-cluster relevance as a .csv
pd.DataFrame(ratio_sp).to_csv(rel_path)

In [None]:
# visualize cluster results
rs = 100 ## patch side length in pixels
## choose clusters
cluster_p = 8 ## cluster shown on the left (palm)
cluster_n = 18 ## cluster shown on the right (non-palm)
num_col = 96 ## patches per image row -- NOTE(review): presumably img.shape[1] // rs; confirm


def _cluster_mosaic(image, patch_labels, cluster_id, patch_size, patches_per_row, grid=5, border=2):
    """Tile the first grid*grid patches of one cluster into a square mosaic.

    Parameters
    ----------
    image : array indexed as [row, col, channel] from which patches are cut.
    patch_labels : 1-D array of cluster ids, one per patch, in row-major patch order.
    cluster_id : cluster whose patches are shown.
    patch_size : side length of a patch in pixels.
    patches_per_row : number of patches along one image row (used to decode the flat index).
    grid : number of tiles per mosaic side.
    border : width in pixels of the white frame drawn around each tile.

    Returns
    -------
    uint8 array of shape (grid*patch_size, grid*patch_size, 3). Tiles for which the
    cluster has no patch are left black instead of raising an IndexError.
    """
    mosaic = np.zeros([grid * patch_size, grid * patch_size, 3], dtype=np.uint8)
    ## look the members up once instead of once per tile
    members = np.flatnonzero(patch_labels == cluster_id)[:grid * grid]
    for tile, flat_loc in enumerate(members):
        ## source patch position in the full image
        src_r, src_c = divmod(int(flat_loc), patches_per_row)
        patch = image[src_r * patch_size:(src_r + 1) * patch_size,
                      src_c * patch_size:(src_c + 1) * patch_size].copy()
        ## white frame so neighbouring tiles can be told apart
        patch[:border] = 255
        patch[-border:] = 255
        patch[:, :border] = 255
        patch[:, -border:] = 255
        ## destination tile position in the mosaic
        dst_r, dst_c = divmod(tile, grid)
        r_0 = dst_r * patch_size
        c_0 = dst_c * patch_size
        ## use the patch's own shape so a truncated edge patch still fits
        mosaic[r_0:r_0 + patch.shape[0], c_0:c_0 + patch.shape[1]] = patch
    return mosaic


## compare images visually
img_p = _cluster_mosaic(img, labels, cluster_p, rs, num_col)
img_n = _cluster_mosaic(img, labels, cluster_n, rs, num_col)
plt.figure(figsize=(12,12))
for position, mosaic in ((121, img_p), (122, img_n)):
    plt.subplot(position)
    plt.imshow(mosaic)
    plt.xticks([])
    plt.yticks([])
plt.show()