In [1]:
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import LatentDirichletAllocation

# Restrict the corpus to 8 newsgroups: motorcycles, baseball, graphics, X Windows,
# Middle-East politics, Christianity, electronics and medicine.
# (Note: the list uses sci.electronics, not sci.space.)
cats = ['rec.motorcycles', 'rec.sport.baseball', 'comp.graphics', 'comp.windows.x',
        'talk.politics.mideast', 'soc.religion.christian', 'sci.electronics', 'sci.med']

# Load only the categories listed in `cats` by passing them to fetch_20newsgroups(categories=...).
# headers/footers/quotes are stripped so topics are learned from message bodies, not metadata.
# Note: despite its name, news_df is a scikit-learn Bunch (with .data, .filenames, ...),
# not a pandas DataFrame.
news_df = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'),
                             categories=cats, random_state=0)

# LDA works on raw term counts, so use CountVectorizer (not TF-IDF).
# Vocabulary: English stop words removed, unigrams + bigrams, terms that appear in at least
# 2 documents and in at most 95% of documents, capped at the 1000 most frequent terms.
count_vect = CountVectorizer(max_df=0.95, max_features=1000, min_df=2,
                             stop_words='english', ngram_range=(1, 2))
feat_vect = count_vect.fit_transform(news_df.data)
print('CountVectorizer Shape:', feat_vect.shape)

CountVectorizer Shape: (7862, 1000)


In [2]:
# Fit one LDA topic per selected newsgroup. The count is derived from `cats` (currently 8)
# instead of being hard-coded; random_state is fixed so the topics are reproducible.
lda = LatentDirichletAllocation(n_components=len(cats), random_state=0)
lda.fit(feat_vect)

LatentDirichletAllocation(n_components=8, random_state=0)

In [3]:
# Topic-word weight matrix: one row per topic, one column per vocabulary term.
# The values are unnormalised weights (see output), not probabilities.
components = lda.components_
print(components.shape)
components

(8, 1000)


array([[3.60992018e+01, 1.35626798e+02, 2.15751867e+01, ...,
        3.02911688e+01, 8.66830093e+01, 6.79285199e+01],
       [1.25199920e-01, 1.44401815e+01, 1.25045596e-01, ...,
        1.81506995e+02, 1.25097844e-01, 9.39593286e+01],
       [3.34762663e+02, 1.25176265e-01, 1.46743299e+02, ...,
        1.25105772e-01, 3.63689741e+01, 1.25025218e-01],
       ...,
       [3.60204965e+01, 2.08640688e+01, 4.29606813e+00, ...,
        1.45056650e+01, 8.33854413e+00, 1.55690009e+01],
       [1.25128711e-01, 1.25247756e-01, 1.25005143e-01, ...,
        9.17278769e+01, 1.25177668e-01, 3.74575887e+01],
       [5.49258690e+01, 4.47009532e+00, 9.88524814e+00, ...,
        4.87048440e+01, 1.25034678e-01, 1.25074632e-01]])

In [4]:
def display_topics(model, feature_names, no_top_words):
    """Print the highest-weighted words of every topic in a fitted topic model.

    Parameters
    ----------
    model : fitted estimator exposing ``components_`` (one row per topic)
        e.g. a fitted ``LatentDirichletAllocation``.
    feature_names : indexable sequence of str
        Vocabulary whose positions match the columns of ``model.components_``.
    no_top_words : int
        Number of words to print for each topic.

    For every topic this prints a ``Topic # <index>`` line, then the top words
    (highest weight first) joined by single spaces.
    """
    for idx, weights in enumerate(model.components_):
        print('Topic #', idx)
        # argsort() sorts ascending; reversing it yields word indexes from the
        # largest weight down, of which we keep the first no_top_words.
        best = weights.argsort()[::-1][:no_top_words]
        print(' '.join(feature_names[j] for j in best))

# Vocabulary of the fitted CountVectorizer, in the same column order as feat_vect and
# lda.components_. get_feature_names() was deprecated in scikit-learn 1.0 and removed in
# 1.2, so prefer get_feature_names_out() and fall back only on older versions.
if hasattr(count_vect, 'get_feature_names_out'):
    feature_names = count_vect.get_feature_names_out()
else:
    feature_names = count_vect.get_feature_names()

# Print the 15 most strongly associated words for each topic. The selected newsgroups are
# motorcycles, baseball, graphics, X Windows, Middle-East politics, Christianity,
# electronics and medicine.
N_TOP_WORDS = 15
display_topics(lda, feature_names, N_TOP_WORDS)


Topic # 0
year 10 game medical health team 12 20 disease cancer 1993 games years patients good
Topic # 1
don just like know people said think time ve didn right going say ll way
Topic # 2
image file jpeg program gif images output format files color entry 00 use bit 03
Topic # 3
like know don think use does just good time book read information people used post
Topic # 4
armenian israel armenians jews turkish people israeli jewish government war dos dos turkey arab armenia 000
Topic # 5
edu com available graphics ftp data pub motif mail widget software mit information version sun
Topic # 6
god people jesus church believe christ does christian say think christians bible faith sin life
Topic # 7
use dos thanks windows using window does display help like problem server need know run


In [5]:
# Infer each document's topic mixture: one row per document, one column per topic.
doc_topics = lda.transform(feat_vect)
# Sanity check that rows line up with the vectorised documents and columns with the topics.
assert doc_topics.shape == (feat_vect.shape[0], lda.n_components)
print(doc_topics.shape)

(7862, 8)


In [6]:
# Display the per-document topic mixture returned by lda.transform(); numpy elides the
# middle rows/columns of this large array in the rendered output.
doc_topics

array([[0.01389701, 0.01394362, 0.01389104, ..., 0.01389205, 0.01393501,
        0.43424401],
       [0.27750436, 0.18151826, 0.0021208 , ..., 0.00212102, 0.00212113,
        0.00212125],
       [0.00544459, 0.22166575, 0.00544539, ..., 0.00544168, 0.00544182,
        0.74567512],
       ...,
       [0.35721917, 0.56773159, 0.01250428, ..., 0.01250179, 0.0125106 ,
        0.01251549],
       [0.00962015, 0.00962299, 0.00962142, ..., 0.00962064, 0.00961908,
        0.6037673 ],
       [0.06250258, 0.06251314, 0.06250191, ..., 0.06252018, 0.56240126,
        0.06251295]])

In [9]:
def get_filename_list(newsdata):
    """Build a '<parent dir>.<file name>' label for every path in ``newsdata.filenames``.

    In the 20-newsgroups corpus the parent directory is the newsgroup name, so a
    path like ``.../sci.med/59422`` becomes ``'sci.med.59422'``.

    Parameters
    ----------
    newsdata : object with a ``filenames`` iterable of path strings
        In this notebook, the Bunch returned by ``fetch_20newsgroups``.

    Returns
    -------
    list of str
        One label per filename, in the same order as ``newsdata.filenames``.
    """
    from pathlib import Path

    # Use pathlib instead of split('/') so native Windows paths (backslash
    # separators) are split correctly as well.
    return ['.'.join(Path(fname).parts[-2:]) for fname in newsdata.filenames]

# Row labels ('<newsgroup>.<file name>') for the document-topic table, in document order.
filename_ls = get_filename_list(newsdata=news_df)

In [11]:
import pandas as pd

# One column per topic. The count is derived from doc_topics so it stays in sync with the
# fitted model's n_components instead of being hard-coded to 8.
topic_nms = ['Topic #' + str(i) for i in range(doc_topics.shape[1])]
# Rows are labelled '<newsgroup>.<file name>', so each document's true newsgroup can be
# compared against its inferred topic weights.
doc_topic_df = pd.DataFrame(data=doc_topics, columns=topic_nms, index=filename_ls)
doc_topic_df

Unnamed: 0,Topic #0,Topic #1,Topic #2,Topic #3,Topic #4,Topic #5,Topic #6,Topic #7
soc.religion.christian.20630,0.013897,0.013944,0.013891,0.482218,0.013979,0.013892,0.013935,0.434244
sci.med.59422,0.277504,0.181518,0.002121,0.530372,0.002121,0.002121,0.002121,0.002121
comp.graphics.38765,0.005445,0.221666,0.005445,0.005445,0.005440,0.005442,0.005442,0.745675
comp.graphics.38810,0.005439,0.005441,0.005449,0.578959,0.005440,0.388387,0.005442,0.005442
sci.med.59449,0.006584,0.552000,0.006587,0.408485,0.006585,0.006585,0.006588,0.006585
...,...,...,...,...,...,...,...,...
comp.windows.x.68298,0.008935,0.008942,0.008951,0.157459,0.008932,0.201530,0.008936,0.596316
soc.religion.christian.20723,0.002360,0.208372,0.002359,0.002361,0.002360,0.002359,0.777467,0.002361
rec.sport.baseball.102656,0.357219,0.567732,0.012504,0.012512,0.012505,0.012502,0.012511,0.012515
sci.electronics.53606,0.009620,0.009623,0.009621,0.338511,0.009618,0.009621,0.009619,0.603767
