In [1]:
num_docs = 50000
num_ideal_points = 1
vocab_size = 100
num_covs = 1
min_words = 200
max_words = 300

In [None]:
import numpy as np
from tqdm import tqdm

def generate_documents(
    num_docs, 
    num_ideal_points, 
    vocab_size, 
    num_covs,
    lambda_, 
    sigma,
    min_words=50, 
    max_words=100, 
    random_seed=42
):

    if random_seed is not None:
        np.random.seed(random_seed)

    if num_covs > 0:
        M_prevalence_covariates = np.zeros((num_docs, num_covs + 1), dtype=int)
        M_prevalence_covariates[:, 0] = 1
        for i in range(num_covs):
            M_prevalence_covariates[:, i + 1] = np.random.randint(2, size=num_docs)     
            
        mean = np.dot(M_prevalence_covariates, lambda_)
        samples = []
        for m in mean:
            sample = np.random.multivariate_normal(m, sigma)
            samples.append(sample)
        true_ideal_points = np.array(samples)

    else:
        M_prevalence_covariates = None
        true_ideal_points = np.random.normal(size=(num_docs, num_ideal_points))
        
    ideal_point_word_matrix = np.random.rand(num_ideal_points, vocab_size)
    
    documents = []

    for i in tqdm(range(num_docs)):
        doc_length = np.random.randint(min_words, max_words+1)
        doc_ideal_points = true_ideal_points[i]

        doc_words = []
        for _ in range(doc_length):
            word_probs = np.dot(doc_ideal_points, ideal_point_word_matrix)
            word_probs = np.exp(word_probs) / np.sum(np.exp(word_probs))  
            word_index = np.random.choice(vocab_size, p=word_probs.flatten())
            doc_words.append('word_' + str(word_index))

        doc_words = ' '.join(doc_words)
        documents.append(doc_words)

    return true_ideal_points, M_prevalence_covariates, documents

if num_covs > 0:
    np.random.seed(1)
    lambda_ = np.array(
        [[0],[1]]
    )
    sqrt_sigma = np.random.rand(num_ideal_points, num_ideal_points)
    sigma = sqrt_sigma * sqrt_sigma.T
    
    true_ideal_points, M_prevalence_covariates, documents = generate_documents(
        num_docs, num_ideal_points, vocab_size, num_covs, lambda_, sigma, min_words, max_words
    )
else:
    true_ideal_points, M_prevalence_covariates, documents = generate_documents(
        num_docs, num_ideal_points, vocab_size, num_covs = 0, lambda_ = None, sigma = None, min_words = min_words, max_words = max_words
    )

 13%|████████████████▏                                                                                                              | 6372/50000 [01:22<08:59, 80.87it/s]

In [None]:
import pandas as pd

df = pd.DataFrame({"doc_clean": documents, "doc": documents})
if num_covs > 0:
    df['cov_0'] = 1
    df['cov_1'] = M_prevalence_covariates[:,1]
df

In [None]:
import sys
sys.path.append('../IdealPointNN/')
from corpus import IdealPointCorpus
from ideal_point_model import IdealPointNN

if num_covs > 0:
    train_dataset = IdealPointCorpus(
        df,
        prevalence = "~ cov_0 + cov_1 - 1",
    )
    m = IdealPointNN(
        train_dataset, 
        n_dims=1,
        update_prior=True,
        encoder_hidden_layers=[],
        decoder_hidden_layers=[],
        log_every_n_epochs = 1,
        print_every_n_batches = 100
    )
else:
    train_dataset = IdealPointCorpus(
        df
    )
    m = IdealPointNN(
        train_dataset, 
        n_dims=1,
        update_prior=False,
        encoder_hidden_layers=[],
        decoder_hidden_layers=[],
        log_every_n_epochs = 1,
        print_every_n_batches = 10
    )

In [None]:
import numpy as np
import matplotlib.pyplot as plt

doc_dims = m.get_doc_dims(train_dataset)

x = doc_dims
y = true_ideal_points

plt.scatter(x, y, label='Data points', s=1)
coefficients = np.polyfit(x.flatten(), y.flatten(), 1)
fit = np.poly1d(coefficients)
plt.plot(x, fit(x), color='red', label='Linear Fit')
plt.xlabel('Estimates')
plt.ylabel('True Value')
plt.title('True vs. Estimated Ideal Points')
plt.legend()
plt.show()

In [None]:
if num_covs > 0:
    print('True lambda: {}'.format(lambda_))
    print('Estimated lambda: {}'.format(m.prior.lambda_))