In [7]:
import os
import re
import json
import sys
import argparse
import logging

from tqdm import tqdm
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio

from gensim.models import CoherenceModel
from gensim.models.wrappers import LdaMallet

from IPython.display import display, Markdown

from ldamodel import LdaModel

# Configure logging once for the whole notebook: records go to stdout with a
# "timestamp LEVEL: message" layout, at INFO level and above.
logging.basicConfig(stream=sys.stdout, format='%(asctime)s %(levelname)s: %(message)s',
                    level=logging.INFO, datefmt='%Y-%m-%d %H:%M:%S')
# `logging.getLogger()` with no name returns the root logger.
logger = logging.getLogger()
# Raise the threshold of the 'gensim' logger so only WARNING and above get through.
gensim_logger = logging.getLogger('gensim')
gensim_logger.setLevel(logging.WARNING)

# Default plotly template applied to every figure created after this point.
pio.templates.default = "plotly_white"

# Root directory handed to `load_corpus`; it is scanned one sub-directory at a time.
DATASET_DIR = "data/COVID-19-Tweets-geo"

# Output directories for exported figures (HTML/SVG and the paper PDF respectively).
# NOTE(review): all three paths are relative, so results depend on the kernel's
# working directory — presumably the repository root; confirm before running.
FIG_SAVE_DIR = "docs/lda_vis"
PAPER_FIG_DIR = 'paper/coling2020/Figures'


In [3]:
# load data

def load_corpus(dataset_dir):
    """Collect the ``candidates`` field of every annotated tweet under a directory.

    ``dataset_dir`` is expected to contain one sub-directory per period (the
    sub-directory name is only used as the progress-bar label). Inside each
    sub-directory, files named exactly
    ``coronavirus-tweet-annotated-2020-DD-DD-DD.jsonl`` (``D`` = digit) are
    read as UTF-8 JSON Lines; every other entry is ignored.

    Args:
        dataset_dir: Path of the dataset root directory.

    Returns:
        A list with one element per tweet record, each being that record's
        ``'candidates'`` value, in sorted (sub-directory, filename) order.

    Raises:
        OSError: If ``dataset_dir`` or a matching file cannot be read.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``'candidates'`` key.
    """
    # The dot before "jsonl" is escaped and the pattern is applied with
    # fullmatch(), so look-alike names such as "….jsonl.bak" are not parsed.
    filename_pattern = re.compile(
        r'coronavirus-tweet-annotated-2020-\d\d-\d\d-\d\d\.jsonl')

    corpus = []
    # os.listdir() returns entries in arbitrary order; sorting makes the
    # resulting corpus order reproducible across machines and runs.
    for month_dir in sorted(os.listdir(dataset_dir)):
        month_path = os.path.join(dataset_dir, month_dir)
        if not os.path.isdir(month_path):
            continue
        data_files = sorted(os.listdir(month_path))
        for filename in tqdm(data_files, total=len(data_files), desc=month_dir):
            if not filename_pattern.fullmatch(filename):
                continue
            path = os.path.join(month_path, filename)
            # Explicit encoding: do not depend on the platform default when
            # decoding tweet text.
            with open(path, encoding='utf-8') as f:
                for line in f:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    if not line.strip():
                        continue
                    tweet = json.loads(line)
                    corpus.append(tweet['candidates'])
    return corpus

# Read the whole dataset into memory once; the module-level `corpus` is what
# the coherence computation further down passes as `texts`.
logger.info(f'Loading data from {DATASET_DIR}')

corpus = load_corpus(DATASET_DIR)

# One corpus entry is appended per tweet record, so len() is the tweet count.
logger.info(f'{len(corpus)} tweets loaded')

2020-07-01 18:16:12 INFO: Loading data from data/COVID-19-Tweets-geo


2020-01: 100%|██████████| 968/968 [00:00<00:00, 2360.36it/s]
2020-02: 100%|██████████| 2624/2624 [00:00<00:00, 2712.91it/s]
2020-03: 100%|██████████| 2972/2972 [00:01<00:00, 1735.52it/s]
2020-04: 100%|██████████| 2880/2880 [00:01<00:00, 1756.37it/s]

2020-07-01 18:16:16 INFO: 498852 tweets loaded





In [4]:
# Topic counts to evaluate; one saved model is loaded per entry.
TOPIC_NUMS = [5, 10, 15, 20, 25, 30, 50, 100, 150, 200]
# Passed to CoherenceModel as `topn`.
TOPN = 20

# Path template for the saved models; `{}` is replaced by the topic count.
mallet_lda_model = "dump/topics-{}-mallet-iter2000-6.2/lda.model"

# One record per topic count: {'num_topics': ..., 'mallet': ...}.
scores = []

pbar = tqdm(TOPIC_NUMS, total=len(TOPIC_NUMS))

for topic_num in pbar:
    # Show which model is currently being scored in the progress bar.
    pbar.set_description(f'topic_num={topic_num}')

    # Load the saved model for this topic count.
    mallet_model_path = mallet_lda_model.format(topic_num)
    mallet_model = LdaMallet.load(mallet_model_path)

    # Build the coherence estimator over the loaded corpus, then score it.
    coherence_model = CoherenceModel(
        model=mallet_model,
        texts=corpus,
        coherence='c_v',
        topn=TOPN,
    )
    mallet_score = coherence_model.get_coherence()

    scores.append(dict(num_topics=topic_num, mallet=mallet_score))


topic_num=200: 100%|██████████| 10/10 [03:45<00:00, 22.51s/it]


In [5]:
# Turn the per-topic-count records into a table (columns: num_topics, mallet).
df_score = pd.DataFrame(scores)

# Render the table as Markdown without the integer index column.
# pandas >= 1.1 exposes this as `index=False` and rejects the old tabulate
# keyword `showindex` with a ValueError. Older pandas forwards keywords
# straight to tabulate, which raises TypeError for the unknown `index`,
# so fall back to `showindex` there.
try:
    score_table_md = df_score.to_markdown(index=False)
except TypeError:
    score_table_md = df_score.to_markdown(showindex=False)

display(Markdown(score_table_md))

|   num_topics |   mallet |
|-------------:|---------:|
|            5 | 0.349906 |
|           10 | 0.42745  |
|           15 | 0.468312 |
|           20 | 0.514389 |
|           25 | 0.52172  |
|           30 | 0.533517 |
|           50 | 0.547147 |
|          100 | 0.542604 |
|          150 | 0.528016 |
|          200 | 0.504825 |

In [11]:
# Plot coherence against the number of topics and export the figure twice:
# a titled version for the docs site and a tightly-cropped one for the paper.

# The export calls below do not create missing parent directories, so make
# sure both targets exist (a fresh checkout may not have them).
os.makedirs(FIG_SAVE_DIR, exist_ok=True)
os.makedirs(PAPER_FIG_DIR, exist_ok=True)

fig = go.Figure()
fig.add_trace(go.Scatter(x=df_score['num_topics'],
                         y=df_score['mallet'],
                         mode='lines+markers',
                         name='score'))
# NOTE(review): the y-axis label uses full-width parentheses (U+FF08/U+FF09);
# left unchanged here, but confirm they render correctly in the exported
# SVG/PDF with the fonts available to the static-image engine.
fig.update_layout(title="Topic Coherence Score vs. Number of Topics",
                  xaxis_title="number of topics",
                  yaxis_title="Coherence（Cv）",
                  width=300, height=180)
# Docs version: interactive HTML plus a static SVG, both with the title.
fig.write_html(os.path.join(FIG_SAVE_DIR, 'topics_coherence.html'))
fig.write_image(os.path.join(FIG_SAVE_DIR, 'topics_coherence.svg'))

# Paper version: the same figure object is modified in place *after* the docs
# export — title removed, margins tightened and the y-range fixed.
fig.update_layout(title="",
                  margin={
                    'l': 20,
                    'r': 0,
                    'b': 0,
                    't': 5,
                    'pad': 5
                  })
fig.update_yaxes(range=[0.3, 0.6])

fig.write_image(os.path.join(PAPER_FIG_DIR, 'topics_coherence.pdf'))
# The figure shown inline is therefore the paper variant.
fig.show()