In [None]:
import os
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk.tokenize.treebank import TreebankWordDetokenizer
import spacy
import gensim
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from collections import defaultdict
import time
import pickle
import pandas as pd
import numpy as np
from collections import Counter
import sys
import csv
import matplotlib.pyplot as plt
import pandas as pd
import mpu
from tmtoolkit.topicmod.evaluate import metric_coherence_gensim
import itertools

# Global NumPy configuration for this run: seed the legacy global RNG so any
# np.random draws are reproducible, and print arrays in full rather than
# abbreviating them with "...".
np.random.seed(5)
np.set_printoptions(threshold=sys.maxsize)

# --------------------------------------------------------------------------------
# train lda
# --------------------------------------------------------------------------------
load_model = True      # True: unpickle a saved (vectorizer, lda_model) pair; False: fit new ones
save_to_model = False  # True: pickle the (vectorizer, lda_model) pair to model_name later on
model_name = 'saved_model/model_0.pkl'
topic_number = 5       # n_components of the LDA model (only used when fitting a new model)
max_df = 0.8           # CountVectorizer: ignore terms present in more than 80% of documents
min_df = 0.01          # CountVectorizer: ignore terms present in fewer than 1% of documents

# NOTE(review): presumably maps username -> list of tweet strings; confirm against the producer.
user_tweets_dic = mpu.io.read('user_tweets_dic.pickle')

# Flatten every user's list into one corpus. Keys are visited in sorted order so
# row positions are reproducible across runs; chain.from_iterable avoids the
# quadratic cost of repeatedly rebuilding the list with `list + list`.
train_data = list(itertools.chain.from_iterable(
    user_tweets_dic[key] for key in sorted(user_tweets_dic.keys())))

if load_model:
    # pickle.load can execute arbitrary code: only load model files you produced yourself.
    with open(model_name, 'rb') as f:
        vectorizer, lda_model = pickle.load(f)
    data_vectorized = vectorizer.transform(train_data)
    lda_output = lda_model.transform(data_vectorized)
else:
    vectorizer = CountVectorizer(max_df=max_df, min_df=min_df)
    data_vectorized = vectorizer.fit_transform(train_data)

    lda_model = LatentDirichletAllocation(n_components=topic_number, learning_method='online',
                                          random_state=100, max_iter=100)
    # lda_output holds one row of topic weights per document in train_data.
    lda_output = lda_model.fit_transform(data_vectorized)


# --------------------------------------------------------------------------------
# evaluate lda
# --------------------------------------------------------------------------------

# Perplexity of the model on the same document-term matrix it was fitted/applied to.
score_perplexity = lda_model.perplexity(data_vectorized)
print("perplexity:", score_perplexity)

# get_feature_names() was removed in scikit-learn 1.2 in favour of
# get_feature_names_out(); support both. dtype=str keeps a plain unicode array
# regardless of which method supplied the names.
if hasattr(vectorizer, 'get_feature_names_out'):
    feature_names = vectorizer.get_feature_names_out()
else:
    feature_names = vectorizer.get_feature_names()
vocab = np.array(feature_names, dtype=str)
topic_word_distrib = np.array(lda_model.components_)
# data_vectorized already equals vectorizer.transform(train_data) (both the load and
# the fit branch build it from train_data), so reuse it instead of re-tokenising.
dtm = data_vectorized

score_coherence = metric_coherence_gensim(measure='u_mass', topic_word_distrib=topic_word_distrib,
                                          vocab=vocab, dtm=dtm, return_mean=True)
print("coherence:", score_coherence)

# --------------------------------------------------------------------------------
# get topic scores for lda and save model if needed
# --------------------------------------------------------------------------------

if save_to_model:
    # Saved as a (vectorizer, lda_model) tuple: the same shape the load_model branch unpickles.
    model_dir = os.path.dirname(model_name)
    if model_dir:
        # open(..., 'wb') does not create missing parent directories (e.g. 'saved_model/').
        os.makedirs(model_dir, exist_ok=True)
    with open(model_name, 'wb') as f:
        pickle.dump((vectorizer, lda_model), f)

# Show top n keywords for each topic
def show_topics(vectorizer, lda_model, verbose=True, n_words=20):
    """Return the ``n_words`` highest-weighted vocabulary terms for every topic.

    Parameters
    ----------
    vectorizer : fitted scikit-learn text vectorizer (e.g. ``CountVectorizer``)
        Supplies the vocabulary; its column order must match ``lda_model.components_``.
    lda_model : fitted model exposing ``components_``
        ``components_`` holds one row of term weights per topic.
    verbose : bool
        If True, print each topic's keywords, then the number of distinct
        keywords across all topics.
    n_words : int
        Number of keywords kept per topic.

    Returns
    -------
    list of numpy.ndarray
        One array of keywords per topic, ordered from highest to lowest weight.
    """
    # get_feature_names() was removed in scikit-learn 1.2 in favour of
    # get_feature_names_out(); support both.
    if hasattr(vectorizer, 'get_feature_names_out'):
        feature_names = vectorizer.get_feature_names_out()
    else:
        feature_names = vectorizer.get_feature_names()
    keywords = np.array(feature_names, dtype=str)

    topic_keywords = []
    for topic_weights in lda_model.components_:
        # argsort of the negated weights gives indices in descending-weight order.
        top_keyword_locs = (-topic_weights).argsort()[:n_words]
        topic_keywords.append(keywords.take(top_keyword_locs))

    if verbose:
        unique_keywords = set()
        for i, topic_kw in enumerate(topic_keywords):
            # tolist() yields plain str, so the printed list is not cluttered
            # with NumPy scalar reprs.
            words = topic_kw.tolist()
            print("Topic " + str(i), words)
            unique_keywords.update(words)
        # Count of distinct keywords across all topics (low values mean heavy overlap).
        print(len(unique_keywords))
    return topic_keywords

show_topics(vectorizer, lda_model)
# For each document (row of lda_output), take the column of its largest topic
# weight, then report how many documents fall under each topic.
dominant_topic = lda_output.argmax(axis=1)
print(Counter(dominant_topic))



# --------------------------------------------------------------------------------
# get user topic distribution
# --------------------------------------------------------------------------------
# print('user', 'total_tweets', 'topic_0', 'topic_1', 'topic_2', 'topic_3', 'topic_4')
# usernames = []
# user_topic_count = {}
# for key in sorted(user_tweets_dic.keys()):
# 	# train_data = train_data + user_tweets_dic[key]
# 	usernames = usernames + [key] * len(user_tweets_dic[key])
# 	user_topic_count[key] = {0:0, 1:0, 2:0, 3:0, 4:0}

# for i in range(0, len(usernames)):
# 	user_topic_count[usernames[i]][dominant_topic[i]] += 1


# for key in user_topic_count.keys():
# 	tmp = user_topic_count[key]
# 	s = float(tmp[0] + tmp[1] + tmp[2]+ tmp[3] + tmp[4])
# 	print(key, s, tmp[0]/s, tmp[1]/s, tmp[2]/s, tmp[3]/s, tmp[4]/s)

