# Interpretable Clustering with Latent Dirichlet Allocation

In [1]:
import numpy as np
import pandas as pd
import pyLDAvis.gensim
from gensim import corpora
from gensim.models import LdaModel
from sklearn.metrics import normalized_mutual_info_score as NMI
pyLDAvis.enable_notebook()

scipy.sparse.sparsetools is a private module for scipy.sparse, and should not be used.
  _deprecated()


In this tutorial,

I will show that the combination of LDA(Latent Dirichlet Allocation) model + pyLDAVis

can be powerful method for interpretable clustering.

## 1. Load Data

In [2]:
data = pd.read_csv('iris.csv')
dummy_data = data.copy()

In [3]:
class_var = 'species'
input_vars = [var for var in data.columns if var != class_var]

In [4]:
data[class_var].unique()

array(['setosa', 'versicolor', 'virginica'], dtype=object)

#### Perform clustering with #cluster = 3

In [5]:
n_cluster = 3

## 2. Discretize Continuous Variables

In [6]:
for var in input_vars:
    dummy_data[var] = pd.cut(data[var], bins=3)

In [7]:
dummy_data.head(3)

Unnamed: 0,sepal_length,sepal_width,petal_length,petal_width,species
0,"(4.296, 5.5]","(2.8, 3.6]","(0.994, 2.967]","(0.0976, 0.9]",setosa
1,"(4.296, 5.5]","(2.8, 3.6]","(0.994, 2.967]","(0.0976, 0.9]",setosa
2,"(4.296, 5.5]","(2.8, 3.6]","(0.994, 2.967]","(0.0976, 0.9]",setosa


## 3. Preprocess data to be input format of gensim LDA model

#### Convert Each Row to Token List

In [8]:
rows = []
for _, row in dummy_data.iterrows():
    rows.append(['{}_{}'.format(k,v) for k,v in row.items() if k != class_var])

In [9]:
rows[0]

['sepal_length_(4.296, 5.5]',
 'sepal_width_(2.8, 3.6]',
 'petal_length_(0.994, 2.967]',
 'petal_width_(0.0976, 0.9]']

#### Prepare dictionary for values of variables & Convert Structured Data into Corpus Format

In [10]:
dictionary = corpora.Dictionary(rows)
row_var_matrix = [dictionary.doc2bow(row) for row in rows]

## 4. Perform clustering with LDA model & Interpreting Clusters

In [11]:
lda = LdaModel(row_var_matrix, num_topics=n_cluster, id2word = dictionary, passes=50)

In [12]:
clustering_result = pyLDAvis.gensim.prepare(lda, row_var_matrix, dictionary)

#### With PyLdaVis, we can explore each cluster interactively!

In [13]:
clustering_result

#### To enhance interpretability, merge adjacent intervals

In [14]:
def merge_adjacent_intervals(intervals_for_var):
    
    def get_begin_and_end_points(interval):
        return [float(x) for x in interval.split('_(')[1].replace(']','').split(', ')]
    
    var_name = intervals_for_var[0].split('_(')[0]
    intervals_for_var = sorted([get_begin_and_end_points(i) for i in intervals_for_var])
    
    prev_num_intervals = len(intervals_for_var)
    
    while(True):
        
        for idx, interval in enumerate(intervals_for_var):
            if idx < len(intervals_for_var) - 1:
                curr_interval = intervals_for_var[idx]
                next_interval = intervals_for_var[idx+1]
                if curr_interval[1] == next_interval[0]:
                    merged_interval = [curr_interval[0], next_interval[1]]
                    intervals_for_var[idx+1] = merged_interval
                    intervals_for_var.remove(curr_interval)
                
        curr_num_intervals = len(intervals_for_var)
        if prev_num_intervals > curr_num_intervals:
            prev_num_intervals = curr_num_intervals
        else:
            break
    
    intervals_for_var = ['{}_({}, {}]'.format(var_name, i[0], i[1]) for i in intervals_for_var]
    return intervals_for_var

def postprocess_intervals(var_names, intervals):
    intervals_by_var = []
    for var in var_names:
        intervals_for_var = [x for x in intervals if var == x[:len(var)]]
        if len(intervals_for_var) >= 2:
            intervals_for_var = merge_adjacent_intervals(intervals_for_var)
        intervals_by_var.append(intervals_for_var)
    return intervals_by_var

#### Now, we can describe clusters in terms of intervals of variables

In [15]:
cutoff = 0.1 # P(value of variable | cluster) (Same as P(word | topic) in LDA Model)
for i in range(n_cluster):
    variables = postprocess_intervals(input_vars,
                                       [dictionary.id2token[x[0]] for x in lda.get_topic_terms(i) if x[1] > cutoff])
    print('{} th cluster'.format(i), variables)
    print('')

0 th cluster [['sepal_length_(5.5, 6.7]'], ['sepal_width_(1.998, 2.8]'], ['petal_length_(2.967, 4.933]'], ['petal_width_(0.9, 1.7]']]

1 th cluster [['sepal_length_(4.296, 5.5]'], ['sepal_width_(2.8, 3.6]'], ['petal_length_(0.994, 2.967]'], ['petal_width_(0.0976, 0.9]']]

2 th cluster [['sepal_length_(5.5, 7.9]'], ['sepal_width_(2.8, 3.6]'], ['petal_length_(4.933, 6.9]'], ['petal_width_(1.7, 2.5]']]



## 5. Clustering performance evaluation

#### Get cluster(=topic) label of each row

In [16]:
def get_cluster_label(lda, row):
    return np.argmax([x[1] for x in lda.get_document_topics(row)])

In [17]:
cluster_labels = [get_cluster_label(lda, row) for row in row_var_matrix]
NMI(data.species, cluster_labels)

0.8243868199386184

#### Compare NMI with K-means & Agglomerative clustering

In [18]:
from sklearn.cluster import AgglomerativeClustering, KMeans

In [19]:
kmeans = KMeans(n_cluster)
cluster_labels_kmeans = kmeans.fit_predict(data.drop(class_var, axis=1))
print(NMI(data.species, cluster_labels_kmeans))

0.758205727819


In [20]:
agg = AgglomerativeClustering(n_cluster)
cluster_labels_agg = agg.fit_predict(data.drop(class_var, axis=1))
print(NMI(data.species, cluster_labels_agg))

0.770140990573


#### Try with scaled input

In [21]:
scaled_data = data.drop(class_var, axis=1)
scaled_data = (scaled_data - scaled_data.mean(axis=0)) / scaled_data.std()

In [22]:
cluster_labels_kmeans_scaled = kmeans.fit_predict(scaled_data)
print(NMI(data.species, cluster_labels_kmeans_scaled))

0.652558221252


In [23]:
cluster_labels_agg_scaled = agg.fit_predict(data.drop(class_var, axis=1))
print(NMI(data.species, cluster_labels_agg_scaled))

0.770140990573


Clustering performance of proposed method outperforms K-means clustering,

and is able to provide explanation for each cluster at the same time.