In [1]:
import matplotlib.pyplot as plt
import matplotlib.pylab as plt
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import euclidean_distances, cosine_distances
from pprint import pprint
from collections import Counter
from random import randrange
import numpy as np
import pandas as pd
import math

In [2]:
def evaluation_metrics(pred_labels, true_labels=None):
    """Print a summary of a clustering, plus purity and Gini index when ground truth is given.

    Parameters
    ----------
    pred_labels : sequence
        Predicted cluster label of every point. The string 'Noise' marks
        points that were not assigned to any cluster.
    true_labels : sequence, optional
        Ground-truth class of every point, indexed in step with
        ``pred_labels``. When omitted, only the label counts are printed.

    Returns
    -------
    None
        All results are written to stdout.

    Notes
    -----
    * Purity is the sum of every cluster's majority-class count divided by the
      total number of points (noise included), so noise points lower it.
    * The Gini index is the unweighted mean, over non-noise clusters, of
      ``1 - sum(p_class ** 2)``.
    * The printed "No. of clusters" is the number of distinct predicted
      labels, so it includes 'Noise' when that label is present.
    * Empty input prints purity and Gini index as 0.0 instead of raising
      ZeroDivisionError.
    """
    if true_labels is not None:
        n_points = len(pred_labels)

        # Group the true classes by the cluster each point was assigned to.
        members = {}
        for idx, cluster in enumerate(pred_labels):
            members.setdefault(cluster, []).append(true_labels[idx])

        # Noise is not a cluster: it contributes nothing to either metric's numerator.
        members.pop('Noise', None)
        n_clusters = len(members)

        # Per-cluster histogram of true classes
        class_counts = {cluster: Counter(classes) for cluster, classes in members.items()}

        # Calculate purity (guarded so that an empty labelling does not divide by zero)
        majority_total = sum(max(counts.values()) for counts in class_counts.values())
        purity = majority_total / n_points if n_points else 0.0

        # Calculate gini index
        gini_index = 0
        for counts in class_counts.values():
            cluster_size = sum(counts.values())   # computed once per cluster, not once per class
            squared_shares = 0
            for count in counts.values():
                squared_shares += (count / cluster_size) ** 2
            gini_index += 1 - squared_shares

        gini_index /= n_clusters if n_clusters != 0 else 1

        # Final result
        print('Purity -', round(purity, 4), 'Gini Index -', round(gini_index, 4))

    label_counts = Counter(pred_labels)
    print('No. of clusters -', len(label_counts))
    print(label_counts, '\n')

In [3]:
# k - index of current data point in data
# e - epsilon
def find_neighbors(k, e, distance_matrix):
    """Return the indices of every point lying within distance ``e`` of point ``k``.

    Row ``k`` of ``distance_matrix`` is read as the distances from point ``k``
    to each point. The point ``k`` itself is never part of the result, and the
    indices are returned as a list in ascending order.
    """
    distances_from_k = distance_matrix[k]
    return [
        idx
        for idx, dist in enumerate(distances_from_k)
        if idx != k and dist <= e
    ]

In [4]:
# e - epsilon
# min_pts - min points
def dbscan(data, e, min_pts, labels=None):
    """Cluster ``data`` with DBSCAN and print the evaluation of the result.

    Parameters
    ----------
    data : array-like or sparse matrix with a ``.shape`` attribute
        One row per point; pairwise Euclidean distances are computed over
        all rows up front.
    e : float
        Epsilon - the neighbourhood radius.
    min_pts : int
        Minimum number of neighbours (the point itself is NOT counted, see
        ``find_neighbors``) for a point to be treated as a core point.
    labels : sequence, optional
        Ground-truth labels forwarded to ``evaluation_metrics``.

    Returns
    -------
    None
        The parameters and the evaluation are printed to stdout.

    Notes
    -----
    Cluster ids are positive integers assigned in discovery order; points in
    no cluster are labelled with the string 'Noise'.
    """
    distance_matrix = euclidean_distances(data)
    n_points = data.shape[0]

    # None marks a point that has not been labelled yet
    clusters = [None] * n_points

    c = 0   # Cluster label
    for i in range(n_points):

        # Skip if already assigned a cluster (or already marked as noise)
        if clusters[i] is not None:
            continue

        S = find_neighbors(i, e, distance_matrix)

        # Density check - label Noise if no. of neighbors less than min_pts
        if len(S) < min_pts:
            clusters[i] = 'Noise'
            continue

        # Next cluster label
        c = c + 1

        # Add point to the new cluster
        clusters[i] = c

        # Every index that has ever been placed in the seed list. Re-processing
        # a point is always a no-op (it is labelled on its first visit), so each
        # index is queued at most once; this keeps S bounded by the number of
        # points instead of growing with every neighbourhood that is expanded.
        queued = set(S)
        queued.add(i)

        # Process every point in the seed list; S grows while it is being iterated
        for j in S:
            j = int(j)

            # Change noise point to border point (it is not expanded further)
            if clusters[j] == 'Noise':
                clusters[j] = c
                continue

            # Skip if already assigned a cluster
            if clusters[j] is not None:
                continue

            # Add neighbor to the current cluster
            clusters[j] = c

            # Get neighbors
            N = find_neighbors(j, e, distance_matrix)

            # Density check - extend the seed list if j has at least min_pts neighbors
            if len(N) >= min_pts:
                for k in N:
                    k = int(k)
                    if k not in queued:
                        queued.add(k)
                        S.append(k)

    # Evaluate results
    print('epsilon -', e, 'min_pts -', min_pts)
    evaluation_metrics(clusters, labels)

In [5]:
# Load the 20 Newsgroups dataset (subset='all')
ng_all = fetch_20newsgroups(subset='all')

# Raw documents, and for each document the name of its target newsgroup
ng_data = ng_all.data
ng_labels = [
    ng_all.target_names[ng_all.target[doc_idx]]
    for doc_idx in range(len(ng_data))
]

# Sanity check: both counts should match
print(len(ng_data))
print(len(ng_labels))

18846
18846


In [6]:
# Turn each document into a TF-IDF vector, ignoring English stop words
tfidf = TfidfVectorizer(stop_words='english')
vect_ng_all = tfidf.fit_transform(ng_data)

# (number of documents, vocabulary size)
print(vect_ng_all.shape)

(18846, 173451)


In [7]:
# Parameter sweep: for each min_pts value run DBSCAN with every epsilon,
# then print a divider to separate the groups in the output
for p in (1, 3, 5):
    for e in (1, 2):
        dbscan(vect_ng_all, e, p, ng_labels)
    print('------------------------------')

epsilon - 1 min_pts - 1
Purity - 0.5425 Gini Index - 0.0351
No. of clusters - 2749
Counter({'Noise': 7869, 33: 173, 5: 165, 15: 157, 123: 133, 176: 101, 2: 85, 106: 73, 43: 65, 234: 59, 4: 55, 3: 53, 136: 50, 179: 35, 13: 33, 461: 33, 662: 33, 40: 32, 67: 32, 433: 32, 52: 31, 273: 31, 304: 28, 373: 26, 516: 26, 555: 26, 220: 25, 258: 25, 25: 24, 57: 24, 95: 24, 240: 23, 370: 23, 376: 23, 609: 23, 83: 22, 339: 22, 277: 21, 851: 21, 572: 20, 624: 20, 771: 20, 6: 19, 274: 19, 291: 19, 295: 19, 319: 19, 463: 19, 773: 19, 921: 19, 243: 18, 355: 18, 412: 18, 510: 18, 828: 18, 19: 17, 169: 17, 219: 17, 246: 17, 270: 17, 278: 17, 398: 17, 496: 17, 521: 17, 568: 17, 762: 17, 998: 17, 93: 16, 99: 16, 196: 16, 213: 16, 730: 16, 31: 15, 186: 15, 188: 15, 394: 15, 514: 15, 524: 15, 718: 15, 859: 15, 44: 14, 205: 14, 214: 14, 430: 14, 827: 14, 847: 14, 1056: 14, 1268: 14, 1372: 14, 35: 13, 78: 13, 111: 13, 122: 13, 236: 13, 459: 13, 634: 13, 883: 13, 1081: 13, 8: 12, 70: 12, 142: 12, 206: 12, 247: 1

epsilon - 2 min_pts - 1
Purity - 0.053 Gini Index - 0.9495
No. of clusters - 1
Counter({1: 18846}) 

------------------------------
epsilon - 1 min_pts - 3
Purity - 0.2743 Gini Index - 0.0787
No. of clusters - 682
Counter({'Noise': 13153, 3: 103, 47: 102, 59: 87, 2: 82, 104: 80, 73: 72, 10: 62, 23: 62, 5: 56, 62: 52, 39: 50, 35: 48, 57: 43, 78: 34, 60: 32, 18: 30, 191: 29, 155: 28, 105: 27, 227: 27, 76: 25, 234: 25, 6: 24, 156: 24, 95: 23, 32: 22, 113: 22, 29: 21, 217: 21, 404: 21, 94: 20, 96: 20, 257: 20, 223: 20, 99: 19, 101: 19, 117: 19, 120: 19, 136: 19, 146: 19, 224: 19, 238: 18, 128: 18, 82: 18, 423: 18, 180: 18, 242: 18, 25: 17, 44: 17, 228: 17, 127: 17, 173: 17, 169: 17, 259: 17, 301: 17, 19: 16, 359: 16, 488: 16, 75: 16, 115: 16, 158: 16, 306: 16, 291: 15, 34: 15, 55: 15, 147: 15, 425: 15, 135: 15, 161: 15, 212: 15, 125: 14, 253: 14, 313: 14, 383: 14, 561: 13, 151: 13, 31: 13, 318: 13, 174: 13, 182: 13, 245: 13, 314: 13, 410: 12, 49: 12, 74: 12, 88: 12, 263: 12, 123: 12, 145: 

------------------------------
