In [1]:
import pandas as pd 
# Read the processed speeches CSV and keep a random subset of 50,000 rows;
# random_state pins which rows are drawn so a re-run gives the same sample.
df = pd.read_csv(
    '../data/us_congressional_record/us_congress_speeches_processed.csv'
).sample(n=50000, random_state=42)
df

In [2]:
import sys
sys.path.append('../gtm/')
from corpus import GTMCorpus
from gtm import GTM

# Wrap the sampled data frame in a GTMCorpus. The prevalence and the content
# arguments each receive a formula string; here both formulas use `party`
# and `C(speech_year)`.
prevalence_formula = "~ party + C(speech_year)"
content_formula = "~ party + C(speech_year)"
train_dataset = GTMCorpus(df, prevalence=prevalence_formula, content=content_formula)

# Show the dimensions of the prevalence covariate matrix held by the corpus.
train_dataset.M_prevalence_covariates.shape

In [3]:
# Display the prevalence covariate matrix itself (the previous cell showed only its shape).
train_dataset.M_prevalence_covariates

In [4]:
# Display the names attached to the prevalence covariates
# (presumably one per column of M_prevalence_covariates — TODO confirm).
train_dataset.prevalence_colnames

In [5]:
# Train the model
# All hyperparameters are gathered in one mapping and unpacked into the GTM
# call, so they can be read (and tweaked) in a single place.
gtm_options = dict(
    n_topics=50,
    doc_topic_prior='dirichlet',  # the alternative noted by the author: logistic_normal
    alpha=0.02,
    prevalence_covariates_regularization=0.1,
    update_prior=True,
    encoder_hidden_layers=[],     # structure of the encoder neural net
    decoder_hidden_layers=[300],  # structure of the decoder neural net
    num_epochs=100,
    print_every=10000,
    log_every=1,
    w_prior=None,
    batch_size=250,
)

tm = GTM(train_dataset, **gtm_options)

In [17]:
# Assess the quality of the learned word embeddings
# Top 8 closest words to a specific word
#
# Each word is represented by its slice of the 'dec_1' decoder weight. The
# query word is scored against every word by dot product, the scores are
# passed through a softmax, and the 8 highest-scoring tokens are shown.
# NOTE(review): the query word itself is not excluded from the ranking.

import torch
import torch.nn.functional as F

specific_word = 'patriot'

# Index of the query word in the corpus vocabulary: stop at the first match
# and fail with an explicit message instead of an opaque IndexError.
word_id = next(
    (i for i, w in enumerate(train_dataset.vocab) if w == specific_word),
    None,
)
if word_id is None:
    raise ValueError(f"{specific_word!r} is not in train_dataset.vocab")

# Put the network in evaluation mode before touching its weights.
tm.AutoEncoder.eval()

# Pure inspection: no autograd graph is needed.
with torch.no_grad():
    # Transposed weight of the 'dec_1' decoder layer
    # (assumes the weight is 2-D — TODO confirm).
    words = tm.AutoEncoder.decoder['dec_1'].weight.T

    # Dot product of the query word's vector with every word vector.
    logit = torch.matmul(words.T[word_id], words)

    # `dim` is given explicitly: calling softmax without it is deprecated and
    # warns. For the 1-D `logit` the implicit choice was dimension 0 as well.
    beta = F.softmax(logit, dim=0)

    indices = torch.topk(beta, 8).indices.cpu().tolist()

# NOTE(review): indices come from train_dataset.vocab positions but are decoded
# with tm.id2token — assumes both share the same ordering; verify.
[tm.id2token[idx] for idx in indices]

In [18]:
# Run estimate_effect on the training corpus with n_samples=10 and no topic
# restriction (topic_ids=None presumably means "all topics" — TODO confirm).
# The result is kept in `dfc`, which the next cell filters on its 'topic' column.
dfc = tm.estimate_effect(train_dataset, n_samples=10, topic_ids=None)

In [30]:
# Keep only the rows of the effect table whose 'topic' value equals 49.
topic_mask = dfc['topic'] == 49
dfc[topic_mask]

In [None]:
import statsmodels.api as sm
# Ordinary least squares of one topic's share (column 43 of the document-topic
# output) on the prevalence covariates.
Y = tm.get_doc_topic_distribution(train_dataset)
X = train_dataset.M_prevalence_covariates
covs = train_dataset.prevalence_colnames

model = sm.OLS(endog=Y[:, 43], exog=X)
results = model.fit()

# Two-row table: covariate names on the first row, fitted coefficients on the second.
pd.DataFrame(data=[covs, results.params])

In [None]:
# Show what get_top_docs returns for topic 47 on the training corpus
# (presumably the documents most associated with that topic — TODO confirm).
tm.get_top_docs(train_dataset, topic_id = 47)

In [None]:
import pyLDAvis

# Ask the model for its data in the keyword layout pyLDAvis.prepare accepts,
# build the visualisation without re-sorting the topics, and render it inline.
ldavis_format = tm.get_ldavis_data_format(train_dataset)
gtm_vis_data = pyLDAvis.prepare(sort_topics=False, **ldavis_format)
pyLDAvis.display(gtm_vis_data)

In [None]:
# Draw the word cloud for topic 18.
tm.plot_wordcloud(topic_id = 18)