In [None]:
# adapted from blogpost:
# https://shuaiw.github.io/2016/12/22/topic-modeling-and-tsne-visualzation.html

In [1]:
import os
import argparse
import time
from sklearn.decomposition import LatentDirichletAllocation
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.datasets import fetch_20newsgroups
from sklearn.manifold import TSNE
import bokeh.plotting as bp
from bokeh.plotting import save
from bokeh.models import HoverTool

In [26]:
from bokeh.io import output_notebook, push_notebook, show
output_notebook()

In [2]:
from glob import glob

In [3]:
import json

## Load the article text

In [4]:
articles = {}
for path in glob('articles/*.json'):
    with open(path, 'r') as f:
        data = json.load(f)
        articles[data['title'].lstrip(':')] = data['paragraph_text']

In [84]:
len(articles)

2975

In [83]:
print articles['Pelican (band)'][:310]

Pelican is a post-metal quartet from Chicago, Illinois. Established in 2000, the band stems from their native post-metal scene and is known for their atmospheric and, with no vocalist in the group, entirely instrumental style. They have released four studio albums and four EPs and gained television exposure.



In [6]:
ids = articles.keys()
texts = articles.values()

## Set our LDA Parameters

In [None]:
n_topics = 20
max_iter = 10
n_top_words = 5

In [7]:
cvectorizer = CountVectorizer(min_df=3, stop_words='english')
cvz = cvectorizer.fit_transform(texts)

lda_model = LatentDirichletAllocation(n_topics=n_topics, max_iter=max_iter)
X_topics = lda_model.fit_transform(cvz)



In [85]:
X_topics[0]

array([ 0.00053763,  0.00053763,  0.26137394,  0.00053763,  0.00053763,
        0.00053763,  0.03298554,  0.00053763,  0.00053763,  0.00053763,
        0.42949392,  0.00053763,  0.00053763,  0.00053763,  0.02792428,
        0.00053763,  0.00053763,  0.05652962,  0.00053763,  0.18416581])

## Sample some words for each topic

In [31]:
threshold = 0.0
_idx = np.amax(X_topics, axis=1) > threshold
_topics = X_topics[_idx]

num_example = len(_topics)

# find the most probable topic for each news
_lda_keys = []
for i in xrange(_topics.shape[0]):
    _lda_keys += _topics[i].argmax(),

topic_summaries = []
topic_word = lda_model.components_
vocab = cvectorizer.get_feature_names()
for i, topic_dist in enumerate(topic_word):
    topic_words = np.array(vocab)[np.argsort(topic_dist)][:-(n_top_words+1):-1]
    summary = ' '.join(topic_words)
    topic_summaries.append(summary)
    print i, summary

0 school university students college campus
1 team league club season football
2 soviet al israel iran saudi
3 philippines philippine manila archers peat
4 aleppo istanbul orient turkish lithuania
5 german brazilian brazil bundesliga rio
6 church canada canadian st catholic
7 air aircraft squadron force wing
8 court black aclu law supreme
9 police service district officers department
10 party government members national united
11 epp jacksonville thai thailand sdp
12 army division war forces regiment
13 la spanish del spain scottish
14 texas alabama state yard southern
15 band album music released group
16 company services business new million
17 metal armenian thrash death armenia
18 station radio new channel fm
19 india pakistan indian wales cardiff


## Use T-SNE to reduce the dimensionality of our LDA vectors

In [13]:
tsne_model = TSNE(n_components=2, verbose=1, random_state=0, angle=.99, init='pca')
tsne_lda = tsne_model.fit_transform(_topics[:num_example])

[t-SNE] Computing pairwise distances...
[t-SNE] Computing 91 nearest neighbors...
[t-SNE] Computed conditional probabilities for sample 1000 / 2975
[t-SNE] Computed conditional probabilities for sample 2000 / 2975
[t-SNE] Computed conditional probabilities for sample 2975 / 2975
[t-SNE] Mean sigma: 0.029860
[t-SNE] KL divergence after 100 iterations with early exaggeration: 1.182378
[t-SNE] Error after 325 iterations: 1.182378


In [66]:
def get_spaced_colors(n):
    max_value = 16581375 #255**3
    interval = int(max_value / n)
    colors = [hex(I)[2:].zfill(6) for I in range(0, max_value, interval)]
    return [(int(i[:2], 16), int(i[2:4], 16), int(i[4:], 16)) for i in colors]

colormap = []
for rgb in get_spaced_colors(n_topics):
    colormap.append('#' + ''.join('{0:02x}'.format(part) for part in rgb))
colormap = np.array(colormap)

## Make a pretty picture!

In [71]:
title = "t-SNE visualization of LDA model"

plot_lda = bp.figure(plot_width=1000, plot_height=700,
                   title=title,
                   tools="pan,wheel_zoom,box_zoom,reset,hover,previewsave",
                   x_axis_type=None, y_axis_type=None, min_border=1)

plot_lda.scatter(x=tsne_lda[:, 0], y=tsne_lda[:, 1],
               color=colormap[_lda_keys][:num_example],
               source=bp.ColumnDataSource({
                 "content": ids[:num_example],
                 "topic_key": _lda_keys[:num_example]
                 }))

# randomly choose a news (in a topic) coordinate as the crucial words coordinate
topic_coord = np.empty((X_topics.shape[1], 2)) * np.nan
for topic_num in _lda_keys:
    if not np.isnan(topic_coord).any():
        break
    topic_coord[topic_num] = tsne_lda[_lda_keys.index(topic_num)]

# plot crucial words
for i in xrange(X_topics.shape[1]):
    if not np.isnan(topic_coord[i,:]).any():
        plot_lda.text(topic_coord[i, 0], topic_coord[i, 1], [topic_summaries[i]])

# hover tools
hover = plot_lda.select(dict(type=HoverTool))
hover.tooltips = {"content": "@content - topic: @topic_key"}

try:
    handle = show(plot_lda, notebook_handle=True)
except:
    pass

Supplying a user-defined data source AND iterable values to glyph methods is deprecated.

See https://github.com/bokeh/bokeh/issues/2056 for more information.

  warn(message)
Supplying a user-defined data source AND iterable values to glyph methods is deprecated.

See https://github.com/bokeh/bokeh/issues/2056 for more information.

  warn(message)
Supplying a user-defined data source AND iterable values to glyph methods is deprecated.

See https://github.com/bokeh/bokeh/issues/2056 for more information.

  warn(message)
Supplying a user-defined data source AND iterable values to glyph methods is deprecated.

See https://github.com/bokeh/bokeh/issues/2056 for more information.

  warn(message)
