Python implementation of a causal topic modeling paper.

This program implements the following paper:
<blockquote>
    <p>Hyun Duk Kim, Malu Castellanos, Meichun Hsu, ChengXiang Zhai, Thomas Rietz, and Daniel Diermeier. 2013. Mining causal topics in text data: Iterative topic modeling with time series feedback. In Proceedings of the 22nd ACM international conference on information & knowledge management (CIKM 2013). ACM, New York, NY, USA, 885-890. DOI=10.1145/2505515.2505612</p>
</blockquote>

In [1]:
# Import libraries
import numpy as np
import pandas as pd
import statsmodels.api as sm
from statsmodels.tsa.stattools import grangercausalitytests
from statsmodels.tsa.api import VAR

In [2]:
# Import data: each CSV has a "Date" column, which becomes a DatetimeIndex
pres_market = pd.read_csv("./data/PRES00_WTA.csv", skipinitialspace=True,
                          index_col="Date", parse_dates=True)
AAMRQ = pd.read_csv("./data/AAMRQ.csv", index_col="Date", parse_dates=True)
AAPL = pd.read_csv("./data/AAPL.csv", index_col="Date", parse_dates=True)

In [3]:
pres_market

Date,Contract,Units,$Volume,LowPrice,HighPrice,AvgPrice,LastPrice
2000-05-01,Dem,224,112.043,0.490,0.550,0.500,0.550
2000-05-01,Reform,2,0.067,0.019,0.048,0.034,0.019
2000-05-01,Rep,116,57.95,0.488,0.501,0.500,0.500
2000-05-02,Dem,87,44.369,0.501,0.522,0.510,0.508
2000-05-02,Reform,50,0.196,0.003,0.005,0.004,0.003
...,...,...,...,...,...,...,...
2000-11-09,Reform,2065,2.062,0.000,0.001,0.001,0.000
2000-11-09,Rep,10055,542.973,0.025,0.109,0.054,0.050
2000-11-10,Dem,3454,3328.02,0.950,0.980,0.964,0.969
2000-11-10,Reform,23,0.02,0.000,0.001,0.001,0.000


In [4]:
# Follow standard practice in the field and use the "normalized" price of one candidate
# as a forecast probability of the election outcome: (Gore price) / (Gore price + Bush price).
# The two series align on the Date index; the Reform contract is not used.
gore_price = pres_market.loc[pres_market["Contract"] == "Dem", "AvgPrice"]
bush_price = pres_market.loc[pres_market["Contract"] == "Rep", "AvgPrice"]
pres_market_forcprob = gore_price / (gore_price + bush_price)
pres_market_forcprob

Date
2000-05-01    0.500000
2000-05-02    0.507463
2000-05-03    0.508492
2000-05-04    0.510490
2000-05-05    0.519115
                ...   
2000-11-06    0.270378
2000-11-07    0.330986
2000-11-08    0.806452
2000-11-09    0.945838
2000-11-10    0.958250
Name: AvgPrice, Length: 192, dtype: float64
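The same normalization can be sketched on a few made-up rows laid out like `PRES00_WTA.csv` (one row per date and contract). This variant pivots to one column per contract instead of masking twice; the toy prices are invented for illustration, and the pivot is an alternative to, not a copy of, the cell above.

```python
import pandas as pd

# Made-up quotes in the same long layout as PRES00_WTA.csv: one row per (Date, Contract)
quotes = pd.DataFrame({
    "Date": pd.to_datetime(["2000-05-01"] * 3 + ["2000-05-02"] * 3),
    "Contract": ["Dem", "Reform", "Rep"] * 2,
    "AvgPrice": [0.50, 0.03, 0.50, 0.60, 0.01, 0.40],
})

# One column per contract, indexed by date
wide = quotes.pivot(index="Date", columns="Contract", values="AvgPrice")
# Normalized Gore price; the Reform contract plays no part
gore_prob = wide["Dem"] / (wide["Dem"] + wide["Rep"])
print(gore_prob.round(3).tolist())  # [0.5, 0.6]
```

Because pandas aligns the division on the index, any date that has a Dem quote but no Rep quote (or vice versa) yields NaN rather than raising an error.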

In [5]:
AAMRQ

Date,Open,High,Low,Close,Adj Close,Volume
2000-07-03,26.63,26.63,26.00,26.13,26.13,483100
2000-07-05,27.25,28.88,27.06,28.38,28.38,1840000
2000-07-06,28.44,29.56,27.81,29.00,29.00,1820000
2000-07-07,29.81,29.94,29.13,29.13,29.13,1150000
2000-07-10,29.75,30.13,29.19,30.00,30.00,711800
...,...,...,...,...,...,...
2001-12-24,21.72,21.73,20.77,21.19,21.19,1350000
2001-12-26,21.37,21.74,21.18,21.57,21.57,938900
2001-12-27,21.35,21.79,21.20,21.50,21.50,1190000
2001-12-28,21.60,22.19,21.55,22.00,22.00,853000


In [6]:
AAMRQ_close = AAMRQ["Close"]

In [7]:
AAPL

Date,Open,High,Low,Close,Adj Close,Volume
2000-07-03,0.930804,0.969866,0.930804,0.952009,0.821251,70828800
2000-07-05,0.950893,0.985491,0.906250,0.921875,0.795256,265216000
2000-07-06,0.937500,0.945313,0.886161,0.925223,0.798145,309545600
2000-07-07,0.939174,0.978795,0.930804,0.972098,0.838581,263603200
2000-07-10,0.965960,1.040179,0.959821,1.020089,0.879981,397796000
...,...,...,...,...,...,...
2001-12-21,0.375179,0.384643,0.371429,0.375000,0.323494,256334400
2001-12-24,0.373214,0.383036,0.373214,0.381429,0.329040,50629600
2001-12-26,0.381250,0.398214,0.377500,0.383750,0.331042,146400800
2001-12-27,0.385357,0.397321,0.385357,0.394107,0.339977,191508800


In [8]:
AAPL_close = AAPL["Close"]

# Granger Test

"Granger tests...measur[e] statistical significance at different time lags using auto regression to identify causal relationships. Let $y_{t}$ and $x_{t}$ be two time series. To see if $x_{t}$ 'Granger causes' $y_{t}$ with maximum $p$ time lag, run the following regression:

$$
y_{t} = a_{0} + a_{1}y_{t-1} + \dots + a_{p}y_{t-p} + b_{1}x_{t-1} + \dots + b_{p}x_{t-p}
$$

Then, use F-tests to evaluate the significance of the lagged $x$ terms. The coefficients of lagged $x$ terms estimate the impact of $x$ on $y$. We average the $x$ term coefficients, $\frac{\sum_{i=1}^{p}b_{i}}{|b|}$, as an impact value."
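The regression and impact value above can be sketched with NumPy alone: fit a restricted model (constant plus $p$ lags of $y$) and an unrestricted model (adding $p$ lags of $x$) by least squares, compare their residual sums of squares with an F statistic, and average the $b_i$. The helper name `granger_f_and_impact` is hypothetical (it is not part of statsmodels or the ITMTF code), and the synthetic data are made up; this should correspond to the sum-of-squares F-test that `grangercausalitytests` reports, but treat it as an illustration rather than a drop-in replacement.

```python
import numpy as np

def granger_f_and_impact(y, x, p):
    """Hypothetical helper: does x Granger-cause y with maximum lag p?

    Returns the F statistic for H0: b_1 = ... = b_p = 0 and the impact value,
    i.e. the average of the fitted b_i coefficients.
    """
    y = np.asarray(y, dtype=float)
    x = np.asarray(x, dtype=float)
    n = len(y)
    target = y[p:]                                    # y_t for t = p .. n-1
    # Column i-1 holds the series lagged by i steps, aligned with target
    y_lags = np.column_stack([y[p - i:n - i] for i in range(1, p + 1)])
    x_lags = np.column_stack([x[p - i:n - i] for i in range(1, p + 1)])
    const = np.ones((n - p, 1))
    X_r = np.hstack([const, y_lags])                  # a_0, a_1..a_p
    X_u = np.hstack([const, y_lags, x_lags])          # a_0, a_1..a_p, b_1..b_p
    beta_r, *_ = np.linalg.lstsq(X_r, target, rcond=None)
    beta_u, *_ = np.linalg.lstsq(X_u, target, rcond=None)
    rss_r = np.sum((target - X_r @ beta_r) ** 2)
    rss_u = np.sum((target - X_u @ beta_u) ** 2)
    df_resid = (n - p) - X_u.shape[1]
    f_stat = ((rss_r - rss_u) / p) / (rss_u / df_resid)
    b = beta_u[1 + p:]                                # coefficients on lagged x
    impact = b.mean()                                 # sum(b_i) / |b|
    return f_stat, impact

# Synthetic check: y follows x with a one-step delay, plus a little noise
rng = np.random.default_rng(0)
x = rng.normal(size=2000)
y = np.zeros(2000)
y[1:] = 0.8 * x[:-1] + 0.1 * rng.normal(size=1999)
f_stat, impact = granger_f_and_impact(y, x, p=2)
print(f"F = {f_stat:.1f}, impact = {impact:.3f}")
```

Here $b_1$ should come out near 0.8 and $b_2$ near 0, so the averaged impact value lands near 0.4: averaging dilutes a single strong lag, which is worth keeping in mind when comparing impact values across different $p$.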

In [9]:
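# Align the two closing-price series on date (outer join)
close = pd.concat([AAMRQ["Close"], AAPL["Close"]], axis=1, keys=["AAMRQ", "AAPL"])
# Smooth with a centered 3-day moving average (2 points suffice at the edges)
close = close.rolling(3, center=True, min_periods=2).mean()
# Day-over-day changes; drop the first row, which diff() leaves as NaN
close = close.diff()[1:]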
close

Date,AAMRQ,AAPL
2000-07-05,0.581667,-0.003906
2000-07-06,1.000000,0.006696
2000-07-07,0.540000,0.032738
2000-07-10,0.146667,0.030506
2000-07-11,0.500000,0.026414
...,...,...
2001-12-24,-0.026667,0.004881
2001-12-26,-0.130000,0.006369
2001-12-27,0.270000,0.006369
2001-12-28,0.243333,0.004524
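The smoothing and differencing steps, and one way NaN can survive them, can be traced on a tiny made-up pair of series. The dates and prices below are invented; the point is that the outer join from `pd.concat` leaves a gap, `min_periods=2` fills interior gaps but not a gap next to the end of the series, and `dropna()` removes what remains.

```python
import numpy as np
import pandas as pd

dates = pd.to_datetime(["2000-07-03", "2000-07-05", "2000-07-06", "2000-07-07"])
a = pd.Series([1.0, 2.0, 4.0, 8.0], index=dates)
# Made-up second series with no row for 2000-07-06
b = pd.Series([10.0, 20.0, 40.0], index=dates[[0, 1, 3]])

close = pd.concat([a, b], axis=1, keys=["A", "B"])        # outer join: B is NaN on 07-06
smooth = close.rolling(3, center=True, min_periods=2).mean()
# Edge windows average 2 points; B's last window has only 1 valid value, so it is NaN
changes = smooth.diff()                                   # first row is NaN by construction
print(changes)
clean = changes.dropna()                                  # rows safe to pass to a Granger test
print(clean)
```

Column A smooths to 1.5, 7/3, 14/3 and 6, so its changes are 5/6, 7/3 and 4/3; column B keeps a NaN on the last date, which is why only two rows survive `dropna()`.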


In [10]:
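# Does the second column (AAPL) Granger-cause the first column (AAMRQ) at lags 1 through 5?
# grangercausalitytests rejects NaN input, so drop any rows that still contain NaN
# (for example, dates that appear in only one of the two price series).
gc_res = grangercausalitytests(close.dropna(), maxlag=5)

In [11]:
# p-value of the F-test on the lagged AAPL terms, one entry per lag (1..5)
p_vals = [gc_res[lag][0]["params_ftest"][1] for lag in sorted(gc_res)]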

In [None]:
p_vals

In [None]:
np.argmin(p_vals)  # 0-based position; the lag it refers to is this value + 1

In [None]:
np.argmax(np.subtract(1, p_vals))  # same position as np.argmin(p_vals)

In [None]:
sig = np.subtract(1, p_vals)  # 1 - p for each lag, used as a significance score

In [None]:
sig

In [None]:
gc_res[1][1][0].summary()
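The dictionary returned by `grangercausalitytests` is keyed by lag starting at 1, while `p_vals` and `sig` are 0-based lists; also, `gc_res[lag][1]` holds the restricted and unrestricted OLS fits, so `gc_res[1][1][0]` is the restricted model (own lags only) at lag 1. A small sketch of picking the best lag from made-up p-values, to make the off-by-one explicit:

```python
import numpy as np

# Made-up p-values for lags 1..5, in the order the notebook builds p_vals
p_vals = [0.20, 0.03, 0.04, 0.30, 0.50]

sig = np.subtract(1, p_vals)          # 1 - p for each lag
best_pos = int(np.argmax(sig))        # same position as np.argmin(p_vals)
best_lag = best_pos + 1               # list positions start at 0, lags at 1
print(best_lag, round(float(sig[best_pos]), 2))  # 2 0.97
```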

# Import NYTAC

Do not run the code below. It performed the original data-cleaning steps, and running it again would delete the NYTAC files in storage, so it has been disabled (wrapped in a string literal) for safety.

In [12]:
"""
import os
import shutil
import tarfile
import xml.etree.ElementTree as ET

import nltk

tars = []
for root, dirs, files in os.walk("./data/nyt_corpus/data"):
    if dirs:
        delete = [x for x in dirs if x not in ['2000', '2001', '2002', '2003']]
        dirs[:] = [x for x in dirs if x in ['2000', '2001', '2002', '2003']]
        for name in delete:
            subdir = os.path.join(root, name)
            with os.scandir(subdir) as it:
                for entry in it:
                    os.remove(entry)
            os.rmdir(subdir)
    if files:
        if os.path.basename(root) == '2003':
            delete = [x for x in files if x not in ['01.tgz', '02.tgz', '03.tgz']]
            files[:] = [x for x in files if x in ['01.tgz', '02.tgz', '03.tgz']]
            for name in delete:
                os.remove(os.path.join(root, name))
        for file in files:
            tars.append(os.path.join(root, file))

for file_path in tars:
    tar = tarfile.open(file_path)
    tar.extractall(path=os.path.dirname(file_path))
    tar.close()
    os.remove(file_path)

# collect articles for 2000 Presidential Election
with os.scandir("./data/nyt_corpus/data/2000") as it:
    for entry in it:
        if os.path.basename(entry) in ['05', '06', '07', '08', '09', '10']:
            shutil.copytree(entry, os.path.join("./data/nyt_corpus/data/election/2000", os.path.basename(entry)))

# collect articles for Stock Time Series, AAMRQ vs. AAPL
with os.scandir("./data/nyt_corpus/data/2000") as it:
    for entry in it:
        if os.path.basename(entry) in ['07', '08', '09', '10', '11', '12']:
            shutil.copytree(entry, os.path.join("./data/nyt_corpus/data/stock/2000", os.path.basename(entry)))
with os.scandir("./data/nyt_corpus/data/2001") as it:
    for entry in it:
        shutil.copytree(entry, os.path.join("./data/nyt_corpus/data/stock/2001", os.path.basename(entry)))

# collect articles for Iraq War
with os.scandir("./data/nyt_corpus/data/2002") as it:
    for entry in it:
        shutil.copytree(entry, os.path.join("./data/nyt_corpus/data/war/2002", os.path.basename(entry)))
with os.scandir("./data/nyt_corpus/data/2003") as it:
    for entry in it:
        if os.path.basename(entry) in ['01', '02', '03']:
            shutil.copytree(entry, os.path.join("./data/nyt_corpus/data/war/2003", os.path.basename(entry)))

# remove unused directories
for year in ['2000', '2001', '2002', '2003']:
    shutil.rmtree(os.path.join("./data/nyt_corpus/data", os.path.basename(year)))

# initialize list of documents to delete
delete = []

# delete documents that do not contain "Bush" and "Gore" or do not contain document bodies
for root, dirs, files in os.walk("./data/nyt_corpus/data/election"):
    if files:
        for name in files:
            tree = ET.parse(os.path.join(root, name))
            tree_root = tree.getroot()
            element = tree_root.find('./body/body.content/block[@class="full_text"]')
            if element is not None:
                keep = 0
                Bush = 0
                Gore = 0
                for para in element.findall('p'):
                    para_list = nltk.word_tokenize(para.text or "")
                    if 'Bush' in para_list:
                        Bush = 1
                    if 'Gore' in para_list:
                        Gore = 1
                    keep = Bush * Gore
                if not keep:
                    delete.append(os.path.join(root, name))
            else:
                delete.append(os.path.join(root, name))

# delete documents that do not contain document bodies
for root, dirs, files in os.walk("./data/nyt_corpus/data/stock"):
    if files:
        for name in files:
            tree = ET.parse(os.path.join(root, name))
            tree_root = tree.getroot()
            element = tree_root.find('./body/body.content/block[@class="full_text"]')
            if element is None or len(element) == 0:
                delete.append(os.path.join(root, name))

# delete documents that do not contain "Iraq" or do not contain document bodies
for root, dirs, files in os.walk("./data/nyt_corpus/data/war"):
    if files:
        for name in files:
            tree = ET.parse(os.path.join(root, name))
            tree_root = tree.getroot()
            element = tree_root.find('./body/body.content/block[@class="full_text"]')
            if element is not None:
                keep = 0
                for para in element.findall('p'):
                    para_list = nltk.word_tokenize(para.text or "")
                    if 'Iraq' in para_list:
                        keep = 1
                if not keep:
                    delete.append(os.path.join(root, name))
            else:
                delete.append(os.path.join(root, name))

# delete the unneeded documents
for name in delete:
    os.remove(name)

# delete empty directories
for root, dirs, files in os.walk("./data/nyt_corpus/data"):
    if not dirs and not files:
        os.rmdir(root)

# consolidate xml files into text files
# one text file contains the documents from the date in the file name
# documents are stored one document per line
for root, dirs, files in os.walk("./data/nyt_corpus/data"):
    if files:
        yyyy = os.path.basename(os.path.dirname(os.path.dirname(root)))
        mm = os.path.basename(os.path.dirname(root))
        dd = os.path.basename(root)
        file_name = yyyy + "-" + mm + "-" + dd + ".txt"
        base_folder = os.path.basename(os.path.dirname(os.path.dirname(os.path.dirname(root))))
        directory = os.path.join("./data/nyt_corpus/data", base_folder)
        with open(os.path.join(directory, file_name), "w") as f:
            for name in files:
                tree = ET.parse(os.path.join(root, name))
                tree_root = tree.getroot()
                element = tree_root.find('./body/body.content/block[@class="full_text"]')
                # para.text is None for empty <p> elements
                paragraphs = [para.text or "" for para in element.findall('p')]
                f.write(" ".join(paragraphs) + "\n")

# remove unused directories
for subdir in ["election/2000", "stock/2000", "stock/2001", "war/2002", "war/2003"]:
    shutil.rmtree(os.path.join("./data/nyt_corpus/data", subdir))
"""

# Set Up Corpus

In [10]:
from causal_topic_mining import ITMTF

[nltk_data] Downloading package punkt to /home/ubuntu/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /home/ubuntu/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package brown to /home/ubuntu/nltk_data...
[nltk_data]   Package brown is already up-to-date!
[nltk_data] Downloading package names to /home/ubuntu/nltk_data...
[nltk_data]   Package names is already up-to-date!
[nltk_data] Downloading package wordnet to /home/ubuntu/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/ubuntu/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package universal_tagset to
[nltk_data]     /home/ubuntu/nltk_data...
[nltk_data]   Package universal_tagset is already up-to-date!


In [11]:
itmtf = ITMTF("./data/nyt_corpus/data/test", pres_market_forcprob)
itmtf.build_corpus()
itmtf.build_vocabulary()
itmtf.process(number_of_topics = 30, max_plsa_iter = 1, epsilon = 0.001, mu = 1000, itmtf_iter = 5)

Initial PLSA: 100%|██████████| 1/1 [01:43<00:00, 103.22s/it]
  pos_rows = len(wc[wc["Topic"] == topic][wc["Impact_Value"] >= 0])
  neg_rows = len(wc[wc["Topic"] == topic][wc["Impact_Value"] < 0])
  wc = wc.drop(wc[wc["Topic"] == topic][wc["Impact_Value"] >= 0].index)

PLSA at end of ITMTF iter.:   0%|          | 0/1 [00:00<?, ?it/s]
PLSA at end of ITMTF iter.: 100%|██████████| 1/1 [01:42<00:00, 102.01s/it]
  pos_rows = len(wc[wc["Topic"] == topic][wc["Impact_Value"] >= 0])
  neg_rows = len(wc[wc["Topic"] == topic][wc["Impact_Value"] < 0])
ITMTF Loop:  20%|██        | 1/5 [04:31<18:07, 272.00s/it]
PLSA at end of ITMTF iter.:   0%|          | 0/1 [00:00<?, ?it/s]
PLSA at end of ITMTF iter.: 100%|██████████| 1/1 [01:42<00:00, 102.96s/it]
ITMTF Loop:  40%|████      | 2/5 [09:09<13:40, 273.62s/it]
PLSA at end of ITMTF iter.:   0%|          | 0/1 [00:00<?, ?it/s]
PLSA at end of ITMTF iter.: 100%|██████████| 1/1 [01:40<00:00, 100.91s/it]
ITMTF Loop:  60%|██████    | 3/5 [13:
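The `mu` argument to `process()` is not documented in this notebook. In the ITMTF paper, each causal topic's words are split by the sign of their impact value into new topic priors that are fed back into PLSA; a common way to apply such a prior is to add `mu` pseudo-counts, spread according to the prior, in the M-step. The sketch below assumes that formulation and assumes word probabilities within a prior are proportional to significance; both helper names are hypothetical and are not the `causal_topic_mining` API.

```python
import numpy as np

def split_topic_prior(word_ids, impacts, significance, vocab_size):
    """Hypothetical helper: turn one causal topic into up to two prior word distributions.

    Words with impact >= 0 go to the positive prior, the rest to the negative prior;
    within each, probability is proportional to significance (an assumption).
    """
    priors = []
    for mask in (impacts >= 0, impacts < 0):
        row = np.zeros(vocab_size)
        row[word_ids[mask]] = significance[mask]
        if row.sum() > 0:
            priors.append(row / row.sum())
    return priors

def m_step_with_prior(expected_counts, prior, mu):
    """PLSA M-step for p(w | topic) with mu pseudo-counts spread according to prior.

    expected_counts, prior: arrays of shape (n_topics, vocab_size); a prior row of
    zeros leaves that topic unconstrained (plain maximum-likelihood update).
    """
    numer = expected_counts + mu * prior
    return numer / numer.sum(axis=1, keepdims=True)

pos, neg = split_topic_prior(np.array([0, 2, 3]), np.array([0.5, -0.2, 0.1]),
                             np.array([0.9, 0.95, 0.6]), vocab_size=5)
counts = np.array([[3.0, 1.0, 0.0], [1.0, 1.0, 2.0]])
prior = np.array([[0.0, 0.0, 1.0], [0.0, 0.0, 0.0]])
updated = m_step_with_prior(counts, prior, mu=4)
print(updated)
```

With `mu=4`, the first topic's update moves from the data-only estimate [0.75, 0.25, 0] to [0.375, 0.125, 0.5]; a large `mu` such as the 1000 used above would pull topics much more strongly toward their priors.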

In [12]:
itmtf.average_entropy

[-0.16827301978900808, -0.16587823563530482, -0.185860533573547, nan, nan]

In [13]:
itmtf.average_topic_purity

[83.1726980210992, 83.41217643646952, 81.41394664264529, nan, nan]
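Each recorded purity equals 100 + 100 × the matching average entropy (for example, 100 + 100 × (-0.16827...) = 83.17269...). The sketch below computes both for a single topic, assuming, as we read the paper, that the entropy is the sum of p·log p over the fractions of positive-impact and negative-impact words, with "positive" meaning `Impact_Value >= 0` as in the warning lines above. The natural-log base is an assumption, and the function name is hypothetical.

```python
import math

def topic_purity(impacts):
    """Hypothetical helper: (entropy, purity) of one topic from its words' impact values.

    entropy = sum of p * log(p) over the positive and negative word fractions (<= 0);
    purity  = 100 + 100 * entropy, so a one-sided topic scores 100.
    """
    pos = sum(1 for v in impacts if v >= 0)
    neg = len(impacts) - pos
    entropy = 0.0
    for count in (pos, neg):
        p = count / len(impacts)
        if p > 0:
            entropy += p * math.log(p)   # natural log is an assumption
    return entropy, 100 + 100 * entropy

print(topic_purity([0.2, 0.5, 0.1]))   # (0.0, 100.0)
print(topic_purity([1.0, -1.0]))       # evenly split topic: lowest purity
```

The per-iteration lists above presumably average these per-topic values over the topics being tracked.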

In [14]:
itmtf.average_causality_confidence

[0.9482155299452735,
 0.9429624014181818,
 0.9338831798020352,
 0.9194746035568545,
 0.9014388100851503]