In [None]:
from sklearn.decomposition import NMF
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.preprocessing import normalize
from scipy.sparse import dok_matrix
from stop_words import get_stop_words
import numpy as np
import json

In [None]:
# Load the raw tweet dump.
# Decode explicitly as UTF-8 (the standard JSON interchange encoding) so that
# reading does not depend on the platform's locale default (e.g. cp949/cp1252
# on Windows), which breaks on non-ASCII tweet text.
with open("twit_new.json", "r", encoding="utf-8") as f:
    data = json.load(f)

In [None]:
# Collect the unique tweet bodies and the raw (unfiltered) vocabulary.
# Only tweets with MORE than 3 whitespace-separated tokens (i.e. 4 or more)
# are kept.
# NOTE(review): the original comment said "3 or more words", but the code has
# always used `> 3`; that behavior is preserved here. Switch to `>= 3` if
# 3-word tweets were meant to be included.
twitTexts = set()
voca = set()
for twit in data:
    tokens = twit['body'].split()
    if len(tokens) > 3:
        voca.update(tokens)
        # Re-joining collapses runs of whitespace to single spaces, so tweets
        # that differ only in spacing are de-duplicated by the set.
        twitTexts.add(" ".join(tokens))

# Release the raw records; only the cleaned texts are used from here on.
del data
# Sort rather than list(): set iteration order for str depends on the
# per-process hash seed, so list(set) yields a different row order (and
# different document indices downstream) after every kernel restart.
twitTexts = sorted(twitTexts)

In [None]:
# Remove stopwords and very short tokens from the vocabulary.
stopwords = set(get_stop_words('en'))
stopwords.update(['via', 'will', 'just'])
# Keep only tokens of length >= 3 that are not stopwords.
# Tokens were never lowercased, so the stopword check is case-insensitive;
# otherwise "The" / "Via" would survive even though "the" / "via" are listed.
# (Assumes the stop_words package list is lowercase, like the extra entries.)
voca = {v for v in voca if len(v) > 2 and v.lower() not in stopwords}
# Sort for a reproducible column order; set order varies between runs.
voca = sorted(voca)
voca_id = {w: i for i, w in enumerate(voca)}  # word -> column index

In [None]:
# Allocate an empty sparse term-document matrix:
# one row per tweet, one column per vocabulary term.
n_docs, n_terms = len(twitTexts), len(voca)
tdm = dok_matrix((n_docs, n_terms), dtype=np.float32)
print(tdm.shape)

In [None]:
# Fill the term-document matrix with raw term counts.
for i, twit in enumerate(twitTexts):
    for word in twit.split():
        col = voca_id.get(word)
        # Tokens absent from the vocabulary (stopwords and words shorter than
        # 3 characters) were filtered out earlier; skip them explicitly
        # instead of using a bare `except:`, which would also hide real errors
        # such as an out-of-range index or a KeyboardInterrupt.
        if col is None:
            continue
        tdm[i, col] += 1

In [None]:
# L2-normalize each document (row) of the term-document matrix.
# The keyword arguments below are sklearn's defaults, spelled out for clarity.
tdm_ = normalize(tdm, norm='l2', axis=1)

In [None]:
# NMF on the L2-normalized counts.
# W: per-document topic weights; H (= nmf.components_): per-topic term weights.
K = 10  # number of topics
nmf = NMF(init='nndsvd', n_components=K)
W = nmf.fit_transform(X=tdm_)
H = nmf.components_

In [None]:
# Print the 20 highest-weighted vocabulary terms for each topic.
for k in range(K):
    print(f"{k}th topic")
    ranked = H[k].argsort()[::-1]
    print("".join(voca[j] + " " for j in ranked[:20]), end="")
    print("\n")

In [None]:
# Re-weight the raw counts with TF-IDF as an alternative input to NMF.
# The arguments spelled out below are TfidfTransformer's defaults.
tfidf = TfidfTransformer(norm='l2', use_idf=True, smooth_idf=True, sublinear_tf=False)
tdm_ = tfidf.fit_transform(tdm)

In [None]:
# NMF on the TF-IDF matrix (same settings as the normalized-count run).
K = 10  # number of topics
nmf = NMF(init='nndsvd', n_components=K)
W = nmf.fit_transform(X=tdm_)
H = nmf.components_

In [None]:
# Print the 20 highest-weighted vocabulary terms for each TF-IDF topic.
for k in range(K):
    print(f"{k}th topic")
    ranked = H[k].argsort()[::-1]
    print("".join(voca[j] + " " for j in ranked[:20]), end="")
    print("\n")

In [None]:
# Show the 5 tweets with the largest weight for each topic.
W_ = W.T  # rows: topics, columns: documents
for k in range(K):
    print(f"{k}th topic")
    top_docs = W_[k].argsort()[::-1][:5]
    for doc_idx in top_docs:
        print(twitTexts[doc_idx])
    print()