In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import os
import pickle

# !pip3 install corextopic
from corextopic import corextopic as ct
from corextopic import vis_topic as vt
import scipy.sparse as ss

from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer

# import pyLDAvis
# import pyLDAvis.gensim_models
# pyLDAvis.enable_notebook()

# import gensim
# from gensim.utils import simple_preprocess
# from gensim.parsing.preprocessing import STOPWORDS
# from gensim.corpora import Dictionary
# def simple_tokenize(text):
#   return [token for token in simple_preprocess(text) if token not in STOPWORDS]

# from nltk.stem import WordNetLemmatizer, SnowballStemmer
# from nltk.stem.porter import *

def flatten_list(l):
  return [item for sublist in l for item in sublist]

def describe_training_documents(list_of_docs):
  print('There are',len(list_of_docs),'documents.')
  document_lengths = list(map(lambda x: len(x.split()),list_of_docs))
  percentile_50 = int(np.percentile(document_lengths,50))
  percentile_95 = int(np.percentile(document_lengths,95))
  print('95% of the documents are below:',percentile_95,'words.')
  plt.axvline(percentile_50, lw=1, color='g')
  plt.axvline(percentile_95, lw=1, color='r', linestyle='--')
  _ = plt.hist(document_lengths, bins=50, range=(0,percentile_95+100))
  print('Solid green line indicates median, dotted red line indicates 95 percentile. Outliers may be cropped.')


## Anchored Correlation Explanation:Topic Modeling with Minimal Domain Knowledge

In [2]:
# Get 20 newsgroups data
newsgroups = fetch_20newsgroups(subset='train', remove=('headers', 'footers', 'quotes'))
documents_train = list(np.load('data/train.npy')) # historical materials 4451 documents
documents_train2 = list(np.load('data/train2.npy'))   # census bureau 4226 documents
df_occsc = pd.read_csv('data/OCC_pairs.csv').rename(columns={'OCC_DES':'Full Occupation'})
assert(df_occsc['Full Occupation'].nunique() == len(df_occsc))
occ_list = list(set(list(df_occsc['Full Occupation'])))

In [3]:
with open("nyt_text_modified.txt",'r') as f:
    nyt_text2 = f.readlines()
describe_training_documents(nyt_text2)

There are 5878 documents.
95% of the documents are below: 67 words.
Solid green line indicates median, dotted red line indicates 95 percentile. Outliers may be cropped.


In [4]:
# include both training and testing dataset into the vectorizer
# but fit the model with the training dataset
# corex model requires them to be the same shape 
documents = []
documents.extend(documents_train)
documents.extend(documents_train2)
documents.extend(nyt_text2)

document_total = documents[:]
document_total.extend(occ_list)

print("In the dataset there are", len(document_total), "textual documents")
print("And this is the first one:\n", documents[0])

In the dataset there are 14991 textual documents
And this is the first one:
 Skip to main content Search UPLOAD SIGN UP | LOG IN BOOKS VIDEO AUDIO SOFTWARE IMAGESABOUT BLOG PROJECTS HELP DONATE  CONTACT JOBS VOLUNTEER PEOPLE Search Metadata Search text contents Search TV news captions Search radio transcripts Search archived websitesAdvanced SearchSign up for freeLog inFull text of "The practical cabinet maker and furniture designer's assistant, with essays on history of furniture, taste in design, color and materials, with full explanation of the canons of good taste in furniture .."See other formats^ 


The topic model assumes input is in the form of a doc-word matrix, where rows are documents and columns are binary counts. We'll vectorize the dataset, take the top 10,000 words, and convert it to a sparse matrix to save on memory usage. Note, we use binary count vectors as input to the CorEx topic model.

### Transform data into a sparse matrix

In [5]:
vectorizer = CountVectorizer(stop_words='english', max_features=10000, binary=True)
doc_word = vectorizer.fit_transform(document_total)
doc_word = ss.csr_matrix(doc_word)

doc_word.shape # n_docs x m_words

(14991, 10000)

In [6]:
# Get words that label the columns (needed to extract readable topics and make anchoring easier)
words = list(np.asarray(vectorizer.get_feature_names()))

not_digit_inds = [ind for ind,word in enumerate(words) if not word.isdigit()]
doc_word = doc_word[:,not_digit_inds]
words = [word for ind,word in enumerate(words) if not word.isdigit()]

doc_word.shape

(14991, 9222)

In [7]:
len(words)

9222

In [8]:
# Train the CorEx topic model with 50 topics
topic_model = ct.Corex(n_hidden=50, words=words, max_iter=200, verbose=True, seed=1)
topic_model.fit(doc_word[:len(documents)], words=words) 

corex, rep size: 50
word counts [5. 7. 4. ... 4. 4. 4.]
[0.013 0.014 0.013 0.013 0.015 0.01  0.013 0.013 0.01  0.011 0.013 0.015 0.016 0.013 0.012 0.012 0.012 0.015 0.013 0.009 0.011 0.012 0.013 0.011 0.023 0.014 0.014 0.012 0.014 0.013 0.011 0.014 0.011
 0.014 0.015 0.012 0.015 0.015 0.011 0.012 0.015 0.014 0.014 0.014 0.012 0.011 0.013 0.018 0.011 0.013]
[0.281 0.427 0.288 0.266 0.483 0.265 0.264 0.399 0.255 0.259 0.314 0.436 0.408 0.297 0.352 0.309 0.274 0.392 0.248 0.241 0.458 0.307 0.471 0.284 0.878 0.314 0.38  0.316 0.353 0.412 0.289 0.437 0.288
 0.31  0.422 0.293 0.345 0.385 0.253 0.255 0.423 0.373 0.307 0.308 0.287 0.245 0.377 0.561 0.288 0.31 ]
[0.66  1.579 1.044 0.78  2.11  0.648 0.57  1.526 0.537 0.68  1.023 1.515 1.301 0.718 1.61  0.994 0.964 1.316 0.607 0.626 1.981 1.068 2.134 0.9   2.714 1.069 1.505 1.101 1.343 1.949 0.715 1.321 1.015
 0.901 1.436 0.664 0.879 1.272 0.658 0.769 1.368 1.204 0.756 0.914 0.702 0.535 1.197 1.835 1.006 0.991]
[0.208 0.964 0.305 0.318 0.544 0.31

[0.031 0.341 0.024 0.148 0.152 0.211 0.038 0.034 0.193 0.43  0.152 0.357 0.073 0.063 0.214 0.385 0.097 0.336 0.017 0.015 0.348 0.058 0.143 0.108 0.046 0.605 0.401 0.674 0.063 1.266 0.018 0.283 0.013
 0.021 0.088 0.007 0.209 0.007 0.234 0.006 1.278 0.227 0.1   0.045 0.329 0.009 0.133 0.14  0.051 0.024]
[0.031 0.341 0.025 0.148 0.153 0.214 0.038 0.034 0.197 0.442 0.151 0.359 0.073 0.063 0.215 0.479 0.097 0.34  0.017 0.015 0.346 0.059 0.143 0.108 0.047 0.609 0.403 0.675 0.061 1.248 0.019 0.298 0.015
 0.022 0.088 0.021 0.206 0.021 0.246 0.015 1.281 0.23  0.1   0.046 0.331 0.024 0.132 0.141 0.052 0.024]
[0.032 0.343 0.025 0.148 0.157 0.219 0.038 0.034 0.197 0.445 0.152 0.359 0.074 0.063 0.204 0.5   0.098 0.342 0.016 0.015 0.348 0.059 0.135 0.109 0.047 0.611 0.4   0.675 0.057 1.238 0.018 0.319 0.017
 0.021 0.092 0.007 0.206 0.007 0.25  0.005 1.274 0.235 0.101 0.047 0.333 0.009 0.133 0.141 0.052 0.023]
[0.032 0.347 0.025 0.15  0.159 0.22  0.038 0.034 0.197 0.447 0.156 0.361 0.074 0.063 0.195 

[0.033 0.365 0.028 0.159 0.176 0.23  0.041 0.037 0.235 0.498 0.192 0.383 0.071 0.069 0.178 0.5   0.105 0.363 0.038 0.018 0.379 0.067 0.145 0.147 0.054 0.641 0.403 0.705 0.077 1.438 0.034 0.389 0.057
 0.017 0.111 0.009 0.213 0.008 0.291 0.009 1.324 0.262 0.116 0.06  0.364 0.011 0.13  0.152 0.061 0.023]
[0.033 0.356 0.028 0.16  0.176 0.231 0.041 0.037 0.236 0.497 0.193 0.384 0.071 0.069 0.179 0.5   0.106 0.361 0.039 0.018 0.379 0.067 0.145 0.15  0.054 0.641 0.403 0.705 0.089 1.438 0.035 0.391 0.058
 0.017 0.111 0.009 0.217 0.008 0.291 0.009 1.323 0.262 0.116 0.06  0.364 0.011 0.129 0.152 0.061 0.023]
[0.033 0.356 0.027 0.161 0.177 0.231 0.041 0.037 0.235 0.497 0.194 0.384 0.072 0.069 0.179 0.5   0.106 0.361 0.043 0.018 0.38  0.067 0.145 0.153 0.054 0.639 0.403 0.705 0.098 1.439 0.035 0.392 0.06
 0.017 0.112 0.009 0.217 0.008 0.291 0.009 1.322 0.262 0.116 0.06  0.366 0.011 0.129 0.153 0.062 0.023]
[0.034 0.377 0.027 0.161 0.177 0.231 0.041 0.037 0.235 0.497 0.195 0.386 0.072 0.069 0.179 0

[0.037 0.475 0.028 0.176 0.179 0.25  0.041 0.039 0.234 0.513 0.201 0.407 0.054 0.069 0.178 0.496 0.109 0.351 0.085 0.023 0.393 0.069 0.145 0.269 0.056 0.647 0.401 0.719 0.172 1.46  0.04  0.418 0.102
 0.017 0.11  0.01  0.22  0.008 0.291 0.013 1.291 0.259 0.118 0.066 0.37  0.011 0.123 0.151 0.068 0.022]
[0.037 0.474 0.028 0.175 0.179 0.252 0.041 0.039 0.234 0.513 0.201 0.407 0.054 0.07  0.178 0.496 0.109 0.354 0.085 0.023 0.394 0.069 0.145 0.271 0.056 0.647 0.401 0.72  0.172 1.456 0.04  0.418 0.102
 0.017 0.11  0.01  0.22  0.008 0.291 0.013 1.29  0.259 0.118 0.066 0.37  0.011 0.128 0.149 0.068 0.022]
[0.037 0.476 0.028 0.175 0.179 0.253 0.041 0.039 0.234 0.513 0.201 0.407 0.054 0.07  0.178 0.496 0.109 0.354 0.086 0.023 0.394 0.069 0.145 0.272 0.056 0.647 0.401 0.72  0.172 1.454 0.041 0.418 0.102
 0.017 0.11  0.01  0.22  0.008 0.292 0.013 1.29  0.259 0.118 0.066 0.37  0.011 0.131 0.149 0.069 0.022]
[0.037 0.485 0.028 0.175 0.18  0.254 0.041 0.039 0.234 0.514 0.201 0.407 0.054 0.07  0.178 

[0.038 0.561 0.029 0.179 0.18  0.257 0.041 0.042 0.234 0.529 0.202 0.417 0.057 0.072 0.182 0.496 0.111 0.354 0.111 0.026 0.403 0.069 0.147 0.275 0.058 0.615 0.41  0.721 0.172 1.458 0.046 0.423 0.107
 0.016 0.114 0.011 0.224 0.008 0.293 0.014 1.285 0.264 0.119 0.068 0.369 0.012 0.131 0.151 0.072 0.024]
[0.038 0.561 0.029 0.179 0.18  0.257 0.041 0.042 0.234 0.529 0.202 0.417 0.057 0.072 0.182 0.496 0.11  0.354 0.111 0.026 0.403 0.069 0.15  0.275 0.058 0.615 0.41  0.721 0.172 1.458 0.047 0.423 0.107
 0.016 0.114 0.011 0.224 0.008 0.293 0.014 1.285 0.264 0.119 0.069 0.369 0.012 0.131 0.151 0.072 0.024]
[0.038 0.561 0.029 0.179 0.182 0.257 0.041 0.042 0.234 0.529 0.202 0.417 0.057 0.072 0.182 0.496 0.11  0.354 0.111 0.026 0.403 0.069 0.151 0.275 0.058 0.615 0.41  0.721 0.172 1.458 0.047 0.423 0.107
 0.016 0.114 0.011 0.224 0.008 0.293 0.014 1.285 0.264 0.119 0.069 0.369 0.012 0.131 0.151 0.073 0.025]
[0.038 0.56  0.029 0.178 0.183 0.257 0.041 0.042 0.234 0.529 0.202 0.417 0.057 0.072 0.182 

[0.038 0.557 0.028 0.181 0.183 0.258 0.041 0.041 0.242 0.532 0.205 0.436 0.058 0.073 0.182 0.496 0.112 0.353 0.113 0.028 0.411 0.069 0.155 0.274 0.06  0.617 0.409 0.723 0.172 1.462 0.05  0.422 0.107
 0.016 0.115 0.011 0.225 0.008 0.293 0.017 1.298 0.265 0.121 0.069 0.37  0.012 0.124 0.152 0.078 0.025]
[0.038 0.557 0.028 0.181 0.183 0.258 0.041 0.041 0.242 0.532 0.205 0.438 0.059 0.073 0.182 0.496 0.112 0.353 0.113 0.028 0.411 0.069 0.155 0.274 0.06  0.617 0.409 0.723 0.172 1.462 0.05  0.422 0.107
 0.015 0.115 0.011 0.225 0.008 0.293 0.017 1.298 0.265 0.121 0.069 0.37  0.012 0.125 0.152 0.078 0.025]
[0.038 0.557 0.028 0.181 0.183 0.258 0.041 0.041 0.242 0.533 0.205 0.438 0.059 0.073 0.182 0.496 0.112 0.353 0.113 0.028 0.411 0.069 0.155 0.274 0.06  0.617 0.41  0.723 0.172 1.462 0.05  0.422 0.107
 0.016 0.115 0.011 0.225 0.008 0.293 0.017 1.298 0.265 0.121 0.069 0.37  0.012 0.125 0.152 0.078 0.025]
[0.038 0.558 0.028 0.181 0.183 0.258 0.041 0.041 0.242 0.533 0.205 0.438 0.059 0.073 0.182 

[0.039 0.559 0.029 0.182 0.183 0.258 0.043 0.043 0.243 0.533 0.207 0.439 0.06  0.073 0.182 0.496 0.112 0.354 0.114 0.028 0.415 0.07  0.155 0.275 0.06  0.619 0.41  0.726 0.172 1.466 0.058 0.422 0.105
 0.016 0.115 0.012 0.224 0.008 0.293 0.019 1.298 0.265 0.121 0.068 0.371 0.012 0.125 0.151 0.079 0.025]
[0.039 0.559 0.029 0.182 0.183 0.258 0.043 0.043 0.243 0.533 0.207 0.439 0.06  0.073 0.182 0.496 0.112 0.354 0.114 0.028 0.415 0.07  0.155 0.275 0.06  0.619 0.41  0.726 0.172 1.466 0.058 0.422 0.105
 0.016 0.115 0.012 0.224 0.008 0.293 0.019 1.298 0.265 0.121 0.068 0.371 0.012 0.125 0.151 0.079 0.025]
[0.039 0.559 0.029 0.182 0.183 0.258 0.043 0.043 0.243 0.533 0.207 0.439 0.06  0.073 0.182 0.496 0.112 0.354 0.114 0.028 0.415 0.07  0.155 0.275 0.06  0.619 0.41  0.726 0.172 1.466 0.058 0.422 0.105
 0.016 0.115 0.012 0.224 0.008 0.293 0.019 1.298 0.265 0.121 0.068 0.371 0.012 0.125 0.151 0.079 0.025]
[0.039 0.559 0.029 0.182 0.183 0.258 0.043 0.043 0.243 0.533 0.207 0.439 0.06  0.073 0.182 

[0.039 0.56  0.029 0.182 0.184 0.259 0.044 0.042 0.241 0.533 0.207 0.439 0.06  0.073 0.182 0.503 0.111 0.353 0.117 0.028 0.431 0.072 0.155 0.275 0.061 0.619 0.41  0.726 0.172 1.465 0.059 0.423 0.105
 0.016 0.115 0.012 0.224 0.009 0.293 0.023 1.299 0.267 0.133 0.068 0.37  0.012 0.127 0.151 0.077 0.026]
[0.039 0.559 0.029 0.182 0.184 0.259 0.044 0.042 0.241 0.533 0.207 0.44  0.06  0.074 0.182 0.503 0.111 0.353 0.117 0.028 0.433 0.073 0.155 0.275 0.061 0.619 0.41  0.726 0.172 1.465 0.059 0.423 0.105
 0.016 0.115 0.012 0.224 0.009 0.293 0.023 1.299 0.267 0.133 0.068 0.37  0.012 0.127 0.151 0.076 0.026]
[0.038 0.559 0.029 0.182 0.184 0.259 0.044 0.042 0.241 0.533 0.207 0.441 0.06  0.074 0.182 0.503 0.111 0.353 0.117 0.028 0.433 0.073 0.155 0.275 0.061 0.619 0.41  0.726 0.172 1.465 0.059 0.423 0.105
 0.016 0.115 0.012 0.224 0.009 0.293 0.023 1.299 0.268 0.133 0.068 0.37  0.012 0.127 0.151 0.076 0.026]
[0.038 0.559 0.029 0.182 0.184 0.259 0.044 0.042 0.241 0.533 0.207 0.441 0.06  0.074 0.182 

<corextopic.corextopic.Corex at 0x1a2a13eeb8>

In [9]:
# Print a single topic from CorEx topic model
topic_model.get_topics(topic=0, n_words=10)

[('products', 0.15927212257342793, 1.0),
 ('industry', 0.1532496058484271, 1.0),
 ('value', 0.15260675538381344, 1.0),
 ('establishments', 0.14943021247390673, 1.0),
 ('total', 0.13742491094420908, 1.0),
 ('cent', 0.11353625541728218, 1.0),
 ('table', 0.10781344655704732, 1.0),
 ('reported', 0.10692745508750423, 1.0),
 ('statistics', 0.08780133184558349, 1.0),
 ('manufacture', 0.07388261926017761, 1.0)]

In [10]:
# Print all topics from the CorEx topic model
topics = topic_model.get_topics()

In [11]:
for topic_n,topic in enumerate(topics):
    # w: word, mi: mutual information, s: sign
    topic = [(w,mi,s) if s > 0 else ('~'+w,mi,s) for w,mi,s in topic]
    # Unpack the info about the topic
    topic_words,mis,signs = zip(*topic)
    # Print topic
    topic_str = str(topic_n)+': '+', '.join(topic_words)
    print(topic_str)

0: products, industry, value, establishments, total, cent, table, reported, statistics, manufacture
1: water, wood, piece, dry, inch, surface, add, glue, cut, mold
2: govt, repts, com, urges, pres, ct, natl, gen, conf, amer
3: furniture, style, chairs, design, decoration, carved, designs, examples, ornament, gothic
4: fig, ends, legs, pieces, rails, concrete, width, square, chair, placed
5: union, workers, local, members, international, locals, organization, trade, employers, unions
6: earners, wage, number, average, employed, prevailing, employees, nearest, representative, hours
7: today, yesterday, american, president, national, association, committee, street, washington, mr
8: work, time, best, quite, possible, working, present, true, matter, does
9: bo, lesson, vegetables, girls, cooking, teacher, food, foods, cooked, minutes
10: mills, goods, horsepower, primary, rolling, woolen, cotton, worsted, silk, purchased
11: age, occupations, gainful, sex, years, persons, gainfully, female

In [12]:
results = topic_model.predict(doc_word[len(documents):])
results.shape

(436, 50)

In [13]:
def get_predict_result(results):
    pairs=[]
    for i in range(len(results)):
        for j in range(len(results[i])):
            if results[i][j]==True:
                pairs.append([i,j])
    return pairs

def count_topics(pairs):
    available={}
    for a,b in pairs:
        if b not in available.keys():
            available[b]=1
        else:
            available[b]+=1
    return available 

def get_topic_content(pairs,topic):
    result=[]
    for a,b in pairs:
        if b == topic:
            result.append(occ_list[a])
    return result 
    

### Total Corelation and Model Selection

In [14]:
topic_model.tc 

12.396312012950295

In [15]:
topic_model.tcs.shape

(50,)

In [16]:
print(np.sum(topic_model.tcs))
print(topic_model.tc)

12.396312012950295
12.396312012950295


In [17]:
#Selecting number of topics:Choosing from the data visualised
plt.figure(figsize=(10,5))
plt.bar(range(topic_model.tcs.shape[0]), topic_model.tcs, color='#4e79a7', width=0.5)
plt.xlabel('Topic', fontsize=16)
plt.ylabel('Total Correlation (nats)', fontsize=16);
plt.savefig('Distribution of TCs for each topic', dpi=600)

### Pointwise Document TC

In [18]:
topic_model.log_z.shape # n_docs x k_topics

(14555, 50)

In [19]:
print(np.mean(topic_model.log_z, axis=0)) #The pointwise total correlations in log_z represent the correlations within an individual document explained by a particular topic. These correlations have been used to measure how "surprising" documents are with respect to given topics
print(topic_model.tcs)

[1.465 1.299 0.727 0.619 0.559 0.534 0.503 0.446 0.441 0.423 0.41  0.37  0.353 0.293 0.275 0.268 0.259 0.241 0.224 0.207 0.185 0.182 0.182 0.172 0.155 0.151 0.132 0.127 0.117 0.115 0.111 0.105 0.074
 0.073 0.07  0.068 0.061 0.06  0.06  0.044 0.042 0.038 0.029 0.028 0.026 0.023 0.016 0.012 0.012 0.009]
[1.465 1.299 0.727 0.619 0.559 0.534 0.503 0.446 0.441 0.423 0.41  0.37  0.353 0.293 0.275 0.268 0.259 0.241 0.224 0.207 0.185 0.182 0.182 0.172 0.155 0.151 0.132 0.127 0.117 0.115 0.111 0.105 0.074
 0.073 0.07  0.068 0.061 0.06  0.06  0.044 0.042 0.038 0.029 0.028 0.026 0.023 0.016 0.012 0.012 0.009]


### Introducing Anchoring in the semi-supervised topic mode

CorEx is a discriminative model, whereas LDA is a generative model. This means that while LDA outputs a probability distribution over each document, CorEx instead estimates the probability a document belongs to a topic given that document's words. As a result, the probabilities across topics for a given document do not have to add up to 1. The estimated probabilities of topics for each document can be accessed through log_p_y_given_x or p_y_given_x.

Hierarchical Topic Models
The labels attribute gives the binary topic expressions for each document and each topic. We can use this output as input to another CorEx topic model to get latent representations of the topics themselves. This yields a hierarchical CorEx topic model. Like the first layer of the topic model, one can determine the number of latent variables to add in higher layers through examination of the topic TCs.

Anchored CorEx is an extension of CorEx that allows the "anchoring" of words to topics. When anchoring a word to a topic, CorEx is trying to maximize the mutual information between that word and the anchored topic. So, anchoring provides a way to guide the topic model towards specific subsets of words that the user would like to explore.

1. Anchoring a single set of words to a single topic. This can help promote a topic that did not naturally emerge when running an unsupervised instance of the CorEx topic model. For example, one might anchor words like "snow," "cold," and "avalanche" to a topic if one suspects there should be a snow avalanche topic within a set of disaster relief articles.

2. Anchoring single sets of words to multiple topics. This can help find different aspects of a topic that may be discussed in several different contexts. For example, one might anchor "protest" to three topics and "riot" to three other topics to understand different framings that arise from tweets about political protests.

3. Anchoring different sets of words to multiple topics. This can help enforce topic separability if there appear to be chimera topics. For example, one might anchor "mountain," "Bernese," and "dog" to one topic and "mountain," "rocky," and "colorado" to another topic to help separate topics that merge discussion of Bernese Mountain Dogs and the Rocky Mountains.

In [20]:
# to automatically generate anchor words: for each label in a data set, 
# we find the words that have the highest mutual information with the label.
# we took a very simple to automatically generate the anchor words to create a semi-supervised model

anchor_words=[]
for n,topic in enumerate(topics):
    topic_words,_,_ = zip(*topic)
    anchor_words.append(list(topic_words[:3]))

anchor_words

[['products', 'industry', 'value'],
 ['water', 'wood', 'piece'],
 ['govt', 'repts', 'com'],
 ['furniture', 'style', 'chairs'],
 ['fig', 'ends', 'legs'],
 ['union', 'workers', 'local'],
 ['earners', 'wage', 'number'],
 ['today', 'yesterday', 'american'],
 ['work', 'time', 'best'],
 ['bo', 'lesson', 'vegetables'],
 ['mills', 'goods', 'horsepower'],
 ['age', 'occupations', 'gainful'],
 ['precinct', 'patrolmen', 'annum'],
 ['york', 'new', 'pennsylvania'],
 ['market', 'electric', 'exchange'],
 ['great', 'century', 'people'],
 ['institute', 'technology', 'avenue'],
 ['north', 'south', 'carolina'],
 ['art', 'exhibition', 'museum'],
 ['used', 'iron', 'steel'],
 ['law', 'court', 'war'],
 ['dr', 'school', 'university'],
 ['individual', 'operations', 'avoid'],
 ['classification', 'according', 'classified'],
 ['proportion', 'females', 'tho'],
 ['good', 'make', 'appearance'],
 ['employment', 'minimum', 'maximum'],
 ['mind', 'borne', 'shop'],
 ['circulation', 'apparatus', 'newspapers'],
 ['states', 

In [21]:
len(anchor_words)

50

In [22]:
# Anchor 'nasa' and 'space' to first topic, 'sports' and 'stadium' to second topic, so on...
#anchor_words = [['industry', 'manufacture','worker'], ['professional','skilled','technology'], ['politics', 'government'], ['domestic','service']]

anchored_topic_model = ct.Corex(n_hidden=50, seed=2) # note different seed, does it matter?
anchored_topic_model.fit(doc_word[:len(documents)], words=words, anchors=anchor_words, anchor_strength=6) # anchor_strength pretty high

doc_word.shape[1]



9222

In [23]:
for n in range(len(anchor_words)):
    topic_words,_,_ = zip(*anchored_topic_model.get_topics(topic=n))
    print('{}: '.format(n) + ','.join(topic_words))

0: industry,products,value,manufacture,materials,added,branches,primarily,valued,duplication
1: wood,water,piece,dry,add,boil,brush,stain,surface,hot
2: govt,com,repts,pres,urges,natl,conf,ussr,ct,gen
3: furniture,style,chairs,design,decoration,carved,designs,examples,ivory,gothic
4: fig,ends,legs,inch,mold,sides,cut,inches,pieces,cast
5: workers,union,local,members,international,locals,trade,shops,employers,strike
6: number,wage,earners,employed,average,employees,prevailing,gives,hours,december
7: today,yesterday,american,announced,ap,afternoon,died,aug,company,kindred
8: work,time,best,doing,performed,overtime,paid,complete,begin,carvers
9: bo,lesson,vegetables,cooking,girls,teacher,food,milk,foods,cooked
10: mills,goods,horsepower,primary,cotton,rolling,worsted,woolen,silk,purchased
11: age,occupations,gainful,sex,engaged,years,persons,occupational,gainfully,status
12: precinct,annum,patrolmen,assignments,duty,division,appointed,18th,patrolman,effect
13: new,york,pennsylvania,jersey

In [24]:
print(np.sum(anchored_topic_model.tcs))
print(anchored_topic_model.tc)

55.25746662529172
55.25746662529172


In [25]:
len(documents), 13025

(14555, 13025)

In [28]:
doc_word.shape

(14991, 9222)

In [29]:
results=topic_model.predict(doc_word[13025:])
results.shape

(1966, 50)

In [30]:
pairs=get_predict_result(results)
print(count_topics(pairs))
print(get_topic_content(pairs,17))

{7: 899, 15: 151, 25: 104, 36: 115, 2: 361, 20: 429, 48: 172, 21: 318, 30: 273, 16: 319, 41: 193, 47: 198, 40: 154, 33: 60, 43: 117, 29: 60, 49: 207, 26: 108, 37: 77, 44: 39, 34: 355, 8: 147, 39: 105, 46: 74, 35: 35, 38: 117, 13: 127, 18: 198, 5: 101, 3: 33, 32: 46, 24: 8, 42: 24, 27: 34, 17: 37, 45: 20, 1: 3, 4: 1, 9: 14, 0: 18, 22: 3, 28: 30, 31: 9, 19: 38, 10: 32, 14: 24, 6: 1, 11: 62, 23: 3, 12: 4}


IndexError: list index out of range

In [None]:
def get_topic_list(pairs):
    result={}
    for a,b in pairs:
        if b not in result.keys():
            result[b]=[a]
        else:
            result[b].append(a)
    return result

def save_topic(pairs):
    result_dic=get_topic_list(pairs)
    with open("train_result.txt",'w') as f:
        for m,n in result_dic.items():
            topic_words,_,_ = zip(*anchored_topic_model.get_topics(topic=m))
            title=str(m)+":"+(','.join(topic_words))
            f.write(title+'\n')
            for file in n:
                occ_title=occ_list[file]
                f.write(occ_title+'\n')
            f.write("\n \n \n")

save_topic(pairs)