Examples of point- and feature-based clustering approaches, specifically the distance-thresholding and mean-shift clustering algorithms.

In [1]:
import math
import operator
import sys

import numpy as np
import torch
from torch import exp, sqrt

In [2]:
# Definition of threshold method

def threshold_cluster(data, threshold):
  """Greedily label points whose pairwise Euclidean distance is below `threshold`.

  Points are visited in order. Each point that has no label yet becomes the
  seed of a new cluster, and every point closer than `threshold` to that seed
  receives the seed's label -- including points an earlier seed had already
  claimed, so later seeds can re-label earlier assignments.

  Args:
    data: 2-D tensor with one point per row. It may live on any device; the
      distance matrix is computed wherever `data` already is, so no GPU is
      required.
    threshold: distance strictly below which two points count as neighbours.

  Returns:
    1-D `np.int64` array with one cluster label per row; labels start at 1.
  """
  data = data.detach()
  if not data.is_floating_point():
    # torch.cdist only accepts floating-point input
    data = data.float()

  # Boolean neighbourhood matrix; replaces the dense ones/zeros tensors that
  # were previously allocated just to build a 0/1 mask.
  is_neighbor = (torch.cdist(data, data) < threshold).cpu().numpy()

  n_points = is_neighbor.shape[0]
  cluster_assignments = np.zeros(n_points, dtype=np.int64)
  next_label = 1

  for i in range(n_points):
    if cluster_assignments[i] == 0:
      cluster_assignments[i] = next_label
      next_label += 1

      # Every neighbour of the seed takes the seed's label, overwriting any
      # label it may already carry.
      cluster_assignments[is_neighbor[i]] = cluster_assignments[i]

  return cluster_assignments

In [3]:
# Definition of mean shift method

def distance_batch(a, b):
    """Euclidean distance from every row of `b` to every row of `a`.

    For 2-D inputs, `a` is broadcast as (1, n, d) against `b` as (m, 1, d),
    so entry [i, j] of the (m, n) result is the distance between b[i] and a[j].
    """
    diff = a.unsqueeze(0) - b.unsqueeze(1)
    squared_dist = diff.pow(2).sum(dim=2)
    return sqrt(squared_dist)

def gaussian(dist, bandwidth):
  """Evaluate a zero-mean Gaussian density with standard deviation `bandwidth` at `dist`."""
  scaled = dist / bandwidth
  normalizer = bandwidth * math.sqrt(2 * math.pi)
  return exp(-0.5 * scaled ** 2) / normalizer

# sourced from https://colab.research.google.com/github/sotte/pytorch_tutorial/blob/master/notebooks/mean_shift_clustering.ipynb#scrollTo=g0rJs_0BeVSB
# Assigns a likely cluster center for each input point
def meanshift_torch(data, batch_size=500, window_size=0.5, steps=10):
    """Shift every point towards a local density mode with batched mean shift.

    In each step, every point is replaced by the Gaussian-weighted mean of all
    points, using `window_size` as the kernel bandwidth. Points are updated one
    batch at a time and in place, so later batches within a step already see
    the shifted positions of earlier batches.

    Args:
        data: 2-D tensor with one point per row. It is never modified.
        batch_size: number of points shifted per inner iteration.
        window_size: bandwidth passed to `gaussian`.
        steps: number of full passes over the data.

    Returns:
        Tensor of the same shape as `data` holding the shifted points. It is
        on the GPU when CUDA is available, otherwise on the device of `data`.
    """
    X = data.cuda() if torch.cuda.is_available() else data
    if not X.is_floating_point():
        # weighted means are fractional; writing them into an integer tensor
        # would silently truncate them
        X = X.float()
    if X is data:
        # `.cuda()` / `.float()` return the very same tensor when no
        # conversion is needed, so the in-place writes below would overwrite
        # the caller's data. Work on a private copy instead.
        X = data.clone()

    n = len(X)

    for _ in range(steps):
        for i in range(0, n, batch_size):
            s = slice(i, min(n, i + batch_size))
            # weight[b, j]: kernel weight of point j for the b-th point of the batch
            weight = gaussian(distance_batch(X, X[s]), window_size)
            num = (weight[:, :, None] * X).sum(dim=1)
            X[s] = num / weight.sum(1)[:, None]

    return X

# Applies mean shift clustering to predict cluster centers, then uses 
# thresholding to return cluster assignments for each point
def meanshift_cluster(data, batch_size=10, threshold=1, window_size=1, steps=15):
  """Cluster `data` by running mean shift and then thresholding the shifted points.

  `batch_size`, `window_size` and `steps` are forwarded to `meanshift_torch`;
  `threshold` is forwarded to `threshold_cluster`. Returns whatever
  `threshold_cluster` returns for the shifted points.
  """
  shifted_points = meanshift_torch(data, batch_size, window_size, steps)
  return threshold_cluster(shifted_points, threshold)


In [4]:
# Demonstrate the mean shift helpers on a handful of 4-D points

# Five sample points to cluster
test = [
    [1, 2, 3, 4],
    [1.1, 2.2, 3.3, 4],
    [7, 7, 7, 8],
    [7.1, 7.1, 7.5, 8],
    [12, 12, 12, 11],
]

# Shifted positions of the points after mean shift
raw_centers = meanshift_torch(torch.tensor(test))
print("Raw cluster: ", raw_centers)

# Integer cluster label for each of the points
example_assignments = meanshift_cluster(torch.tensor(test))
print("Cluster Assign", example_assignments)

Raw cluster:  tensor([[ 1.0500,  2.1000,  3.1500,  4.0000],
        [ 1.0500,  2.1000,  3.1500,  4.0000],
        [ 7.0500,  7.0500,  7.2500,  8.0000],
        [ 7.0500,  7.0500,  7.2500,  8.0000],
        [12.0000, 12.0000, 12.0000, 11.0000]], device='cuda:0')
Cluster Assign [1 1 2 2 3]


In [5]:
# Apply each clustering to the DGCNN output for our trained model
# We first load the model and a test sample

# Make the project packages importable; the project imports below only
# resolve after these path changes.
sys.path.append("/work/murph186/repos")
sys.path.append("/work/murph186/repos/TreePartNet/")

from SorghumPartNet.models.nn_models import SorghumPartNetInstance
from SorghumPartNet.train_and_inference.predict_and_visualize import load_test_data

# Configuration
# NOTE(review): these are absolute, machine-specific paths; adjust them when
# running elsewhere.
model_name = "SorghumPartNetInstance"
version = 0
model_checkpoint_path = f"/space/ariyanzarei/sorghum_segmentation/models/model_checkpoints/{model_name}/lightning_logs/version_{version}/checkpoints/epoch=9-step=9379.ckpt"
test_dataset_path = "/space/ariyanzarei/sorghum_segmentation/dataset/2022-03-10/sorghum__labeled_test.hdf5"
test_index = 2

# NOTE(review): the train/eval mode of the loaded model is not set explicitly.
# If the network contains dropout or batch-norm layers, consider calling
# model.eval() before inference -- confirm against how the checkpoint was
# evaluated.
model = SorghumPartNetInstance.load_from_checkpoint(model_checkpoint_path).cuda()

test_points, _, _, plant_index, leaf_index = load_test_data(test_dataset_path, test_index)

print(f"Model input shape: {test_points.shape}")

# Inference only: skip autograd bookkeeping instead of detaching afterwards.
# The sample is given a leading batch dimension for the model and the
# size-1 dimensions are squeezed out of the result again.
with torch.no_grad():
    pred_instance_features = model(torch.unsqueeze(test_points, dim=0).cuda())
pred_instance_features = torch.squeeze(pred_instance_features)

print("Feature Vectors Shape: ", pred_instance_features.shape)


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
Model input shape: torch.Size([8000, 3])
Feature Vectors Shape:  torch.Size([8000, 256])


In [6]:
# Cluster the predicted feature vectors with plain distance thresholding

thresh_cluster_assign = threshold_cluster(pred_instance_features, 5)
thresh_cluster_labels, cluster_counts = np.unique(
    thresh_cluster_assign, return_counts=True
)
print("Thresh clusters:", thresh_cluster_labels, cluster_counts)

# Keep only the clusters that contain more than `count_threshold` points
count_threshold = 10
is_large_enough = cluster_counts > count_threshold
thresh_filtered_clusters = thresh_cluster_labels[is_large_enough]
print("Total found: ", thresh_filtered_clusters, thresh_filtered_clusters.shape)

Thresh clusters: [  1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18
  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35  36
  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53  54
  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71  72
  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89  90
  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107 108
 109] [5501    4   93    6   13    7    8    6   66  109  130    3   16   59
   72   25    4   72   24   32   38   42    5    4   47   15  106   60
   69  109    4   72   70   19   43   19   80   26   17    2    5   12
   16   75    8    7   36   52   18   24    8   12   30    4   93   29
    4   10    1   22   27    5    2   60   16   10    8   16   27    3
    1    5    2    1   11    2    1    4    3   27    7   16    2   15
   10    3   24    3    6    1    6   25    1    2    4    1   15   10
    8    3    4   12    1    3    2    7  

In [7]:
# Application of mean shift: cluster the predicted feature vectors
meanshift_cluster_assign = meanshift_cluster(
    pred_instance_features,
    threshold=0.6,
    window_size=1.5,
    steps=40,
)
meanshift_cluster_labels, cluster_counts = np.unique(
    meanshift_cluster_assign, return_counts=True
)
print("Meanshift clusters:", meanshift_cluster_labels, cluster_counts)

# Keep only the clusters that contain more than `count_threshold` points
count_threshold = 20
is_large_enough = cluster_counts > count_threshold
meanshift_filtered_clusters = meanshift_cluster_labels[is_large_enough]
print("Total found: ", meanshift_filtered_clusters, meanshift_filtered_clusters.shape)

Meanshift clusters: [ 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39] [5502  113  184  256  168   57   95  196  220   16  115   82   72   17
   81   70  115   50   18  109   33  155  129   43   26    4   29    4
    5    6   16    2    1    4    1    1    1    3    1]
Total found:  [ 1  2  3  4  5  6  7  8  9 11 12 13 15 16 17 18 20 21 22 23 24 25 27] (23,)


In [8]:
# Now let's visualize the clusters found from the model's feature vectors, so the result of mean shift can be compared to the existing clustering approach

import k3d

def plot_results(cluster_assign, filtered_clusters, points=None):
  """Show a k3d point-cloud plot coloured by cluster.

  Clusters listed in `filtered_clusters` are drawn in their own colour (the
  palette repeats once it runs out); the points of all remaining clusters are
  merged into a single light-grey "other" group.

  Args:
    cluster_assign: array with one cluster label per point.
    filtered_clusters: labels of the clusters to highlight.
    points: point positions indexed like `cluster_assign`. Defaults to the
      notebook-level `test_points`.
  """
  if points is None:
    points = test_points

  colors = [0xe41a1c,0x377eb8,0x4daf4a,0x984ea3,0xff7f00,0xffff33,0xa65628,0xf781bf,0x999999]

  plot = k3d.plot(name='points')
  filtered_list = []
  for i, c in enumerate(np.unique(cluster_assign)):
    cluster_points = points[cluster_assign == c]

    if c in filtered_clusters:
      color = colors[i % len(colors)]
      plot += k3d.points(positions=cluster_points, point_size=0.01, color=color, name=f"class {c}")
    else:
      filtered_list.append(cluster_points)

  # np.concatenate raises ValueError on an empty list, which happens when
  # every cluster is highlighted; only draw the "other" group if it exists.
  if filtered_list:
    filtered_points = np.concatenate(filtered_list)
    plot += k3d.points(positions=filtered_points, point_size=0.01, color=0xe0e0e0, name="other")

  plot.display()



In [9]:
# Show the clusters found by distance thresholding
plot_results(
    cluster_assign=thresh_cluster_assign,
    filtered_clusters=thresh_filtered_clusters,
)

Output()

In [10]:
# Show the clusters found by mean shift
plot_results(
    cluster_assign=meanshift_cluster_assign,
    filtered_clusters=meanshift_filtered_clusters,
)

Output()

In [11]:
# Here we attempt to compare relevant metrics between the clusters derived
# from the threshold method and the mean-shift method

def get_test_labels(leaf_index, plant_index):
  """Return the distinct (leaf, plant) index pairs of a sample as a list of tuples."""
  leaf_ids = leaf_index.detach().cpu().numpy()
  plant_ids = plant_index.detach().cpu().numpy()
  unique_pairs = set(zip(leaf_ids, plant_ids))
  return list(unique_pairs)

# This is needed because leaf indices can be duplicated with multiple plants in a single
# sample
label_lookup = get_test_labels(leaf_index, plant_index)

# Map each (leaf, plant) pair to its position in label_lookup once, instead of
# running a linear list.index() scan for every point. The resulting labels are
# the same positions list.index() would return, since the pairs are unique.
label_of_pair = {pair: label for label, pair in enumerate(label_lookup)}
leaf_ids = leaf_index.detach().cpu().numpy()
plant_ids = plant_index.detach().cpu().numpy()
test_labels = [label_of_pair[pair] for pair in zip(leaf_ids, plant_ids)]


from sklearn.metrics.cluster import rand_score, adjusted_rand_score, contingency_matrix

# Both scores below are *adjusted* Rand scores
threshold_rand_score = adjusted_rand_score(test_labels, thresh_cluster_assign)
meanshift_rand_score = adjusted_rand_score(test_labels, meanshift_cluster_assign)

print("Threshold Rand Score: ", threshold_rand_score)
print("Meanshift Rand Score: ", meanshift_rand_score)

print(contingency_matrix(test_labels, meanshift_cluster_assign))

Threshold Rand Score:  0.9888163742917845
Meanshift Rand Score:  0.9864319659188563
[[ 0  0  1 ...  0  0  0]
 [ 0  0 82 ...  0  0  0]
 [ 0  0  0 ...  0  0  0]
 ...
 [ 0  0  0 ...  0  0  0]
 [ 0  0  0 ...  0  0  0]
 [ 0  0  0 ...  0  0  1]]
