In [1]:
!rm nmf.cpython*
!ln -s ../installer/nmf.cpython-36m-x86_64-linux-gnu.so  nmf.cpython-36m-x86_64-linux-gnu.so


In [None]:
import nmf

In [3]:
import pandas as pd;
import numpy as np;
import scipy as sp;
import sklearn;
import sys;
from nltk.corpus import stopwords;
import nltk;
from gensim.models import ldamodel
import gensim.corpora;
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer;
#from sklearn.decomposition import NMF;
from sklearn.preprocessing import normalize;
import pickle;

### Read in the headlines (only the first 10,000 are modelled)

In [59]:
# Read the ABC news headlines, skipping malformed rows.
# `error_bad_lines` was deprecated in pandas 1.3 and removed in 2.0 in favour of
# `on_bad_lines="skip"`; use whichever keyword the installed pandas understands.
_pandas_version = tuple(int(part) for part in pd.__version__.split(".")[:2])
_skip_bad_lines = ({"on_bad_lines": "skip"} if _pandas_version >= (1, 3)
                   else {"error_bad_lines": False})
data = pd.read_csv('archive/abcnews-date-text.csv', **_skip_bad_lines)
print(data.shape)

# We only need the headline_text column, and only the first N_HEADLINES rows.
# .iloc + .copy() yields an independent frame, so later column assignments on
# data_text do not raise SettingWithCopyWarning.
N_HEADLINES = 10000
data_text = data[['headline_text']].iloc[:N_HEADLINES].copy()

(1186018, 2)


In [60]:
# (rows, columns) of the subset that will be modelled
(len(data_text.index), len(data_text.columns))

(10000, 1)

### Preprocess the headlines (stopword removal, bag-of-words counts, TF-IDF)

In [61]:
# English stopword list from NLTK, kept as a set for O(1) membership tests.
# On a fresh environment the corpus may not be installed yet (NLTK raises
# LookupError); fetch it once and retry, so Restart & Run All works.
try:
    _english_stopwords = stopwords.words('english')
except LookupError:
    nltk.download('stopwords', quiet=True)
    _english_stopwords = stopwords.words('english')
STOPWORDS = set(_english_stopwords)
def remove_stopwords(text):
    """Return ``text`` with every whitespace-separated token found in STOPWORDS dropped.

    ``text`` is converted with ``str()`` first, so non-string cells (e.g. NaN)
    are handled as their string form. Kept tokens are re-joined with single spaces.
    """
    tokens = str(text).split()
    kept = (token for token in tokens if token not in STOPWORDS)
    return " ".join(kept)

# Strip stopwords from every headline. `assign` returns a new frame instead of
# writing through .loc into what may be a slice of `data`, which avoids
# SettingWithCopyWarning. The function is passed directly; no lambda is needed.
data_text = data_text.assign(
    headline_text=data_text["headline_text"].apply(remove_stopwords)
)

In [63]:
# Word-level bag-of-words counts, restricted to (at most) the 5000 most frequent terms.
vectorizer = CountVectorizer(max_features=5000, analyzer='word')
x_counts = vectorizer.fit_transform(
    data_text["headline_text"].to_numpy()
)

In [64]:
# Re-weight the raw counts by inverse document frequency (no IDF smoothing).
transformer = TfidfTransformer(smooth_idf=False)
x_tfidf = transformer.fit_transform(x_counts)

In [65]:
# Scale each document row so its TF-IDF weights sum to 1 (L1 norm along axis 1).
xtfidf_norm = normalize(X=x_tfidf, axis=1, norm='l1')

In [66]:
# Densify the sparse TF-IDF matrix; the nmf extension is called with a dense array below.
# TODO: add a csr reader in python to deal with extremely big data
# (a dense n_docs x n_features float matrix grows quickly with the corpus size).
# The guard makes the cell safe to re-run: after the first run xtfidf_norm is
# already a dense ndarray, which has no .toarray().
if hasattr(xtfidf_norm, "toarray"):
    xtfidf_norm = xtfidf_norm.toarray()

### Perform the NMF training to compute W and H

In [67]:
%%time 
# Factorize the dense, row-normalized TF-IDF matrix with the compiled `nmf`
# extension, returning the document-topic matrix W and topic-term matrix H.
# The captured output (W: 100000 elements = 10000 docs x 10, H: 50000 = 10 x 5000
# terms) is consistent with the second argument being the number of topics.
# NOTE(review): the meaning of 1000000, 42 and 0 is not visible here —
# presumably iteration count, random seed and a flag; confirm against nmf.play.
W, H = nmf.play(xtfidf_norm, 10, 1000000, 42, 0 )

Conversion of X into memoryview
Creating W array with 100000 elements
Creating H array with 50000 elements
Calling C
CPU times: user 8min 41s, sys: 5.1 s, total: 8min 46s
Wall time: 8min 38s


### Show the top 20 words of each of the 10 topics

In [68]:
def get_nmf_topics(W, H, vectorizer, num_topics, n_top_words):
    """Tabulate the highest-weighted words of each NMF topic.

    Parameters
    ----------
    W : array-like
        Document-topic matrix. Not used; kept so existing calls stay valid.
    H : numpy.ndarray
        Topic-term matrix; row ``i`` holds the term weights of topic ``i``.
    vectorizer : fitted scikit-learn vectorizer
        Maps column indices of ``H`` back to vocabulary words.
    num_topics : int
        Number of leading rows of ``H`` to report.
    n_top_words : int
        Number of words to list per topic, highest weight first.

    Returns
    -------
    pandas.DataFrame
        One column per topic, labelled ``'Topic # 01'``, ``'Topic # 02'``, ...,
        each listing that topic's top words in descending weight order.
    """
    # Reverse-map word ids to words. scikit-learn >= 1.0 provides
    # get_feature_names_out; the older get_feature_names was removed in 1.2.
    if hasattr(vectorizer, "get_feature_names_out"):
        feat_names = vectorizer.get_feature_names_out()
    else:
        feat_names = vectorizer.get_feature_names()

    word_dict = {}
    for i in range(num_topics):
        # argsort is ascending, so the reversed slice picks the n_top_words
        # largest weights of row i (topic i), largest first.
        words_ids = H[i, :].argsort()[:-n_top_words - 1:-1]
        word_dict['Topic # {:02d}'.format(i + 1)] = [feat_names[key] for key in words_ids]

    return pd.DataFrame(word_dict)

# Top 20 words for each of the 10 learned topics.
get_nmf_topics(W, H, vectorizer, num_topics=10, n_top_words=20)

Unnamed: 0,Topic # 01,Topic # 02,Topic # 03,Topic # 04,Topic # 05,Topic # 06,Topic # 07,Topic # 08,Topic # 09,Topic # 10
0,us,man,police,world,baghdad,govt,iraq,war,council,new
1,troops,charged,probe,cup,explosions,nsw,says,anti,water,resolution
2,iraqi,court,death,australia,rock,vic,un,protest,rain,un
3,turkey,murder,search,win,blasts,fire,bush,protesters,plan,ceo
4,forces,stabbing,investigate,takes,iraqi,qld,pm,howard,restrictions,zealand
5,korea,face,missing,final,raids,urged,missiles,protests,security,launched
6,killed,hospital,fatal,claims,missing,sa,howard,rally,may,high
7,military,dies,car,england,coalition,hospital,set,pm,funds,home
8,north,charge,crash,championship,tanks,wa,blair,students,farmers,work
9,fire,jailed,station,miss,forces,nt,resolution,march,boost,president


### Inspect example headlines for individual topics

In [80]:
# Topic 1 ("Topic # 01", column 0 of W): print the index of every headline whose
# weight exceeds a hand-picked threshold (it could instead be derived from W[:, 0].max()).
for i in np.flatnonzero(W[:, 0] > 0.10):
    print(i)

428
440
1446
4789
5719
6376
6565
6785
7017
7874
7927
9015
9016
9439
9604
9623


In [102]:
# Original (stopwords intact) text of a headline flagged above.
data["headline_text"].iloc[5719]

'turkey reconsidering disallowing us troops'

In [98]:
# Topic 4 ("Topic # 04", column 3 of W): indices of headlines weighted above 0.15.
for i in np.flatnonzero(W[:, 3] > 0.15):
    print(i)

404
2932
4136
4608
6609
6762
6813
7263
9873


In [103]:
# Original text of a headline flagged for topic 4.
data["headline_text"].iloc[404]

'socceroos creep up world rankings'

In [104]:
# Another topic-4 headline, shown with its stopwords intact.
data["headline_text"].iloc[2932]

'suspension stomps on flavells world cup bid'

In [88]:
# Topic 5 ("Topic # 05", column 4 of W): indices of headlines weighted above 0.15.
for i in np.flatnonzero(W[:, 4] > 0.15):
    print(i)

1902
5723
6188
6415
6456
6505
6602
6614
6622
6714
6863
7095
7513
7628
7629
7963
9519
9621
9622


In [105]:
# Original text of a headline flagged for topic 5.
data["headline_text"].iloc[7095]

'explosions heard in baghdad'