In [None]:
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints
from pyro.infer import SVI, JitTrace_ELBO
from pyro.optim import Adam
import pandas as pd
import torch
from tqdm import tqdm
import us
import seaborn as sns
import matplotlib.pyplot as plt 
import numpy as np
from scipy.optimize import root
pyro.enable_validation()

## Load and graph data

In [None]:
# Pew survey respondents: keep each respondent's state and self-reported
# ideology, resolve state names with `us.states.lookup`, drop rows with any
# missing value (presumably including states the lookup could not resolve),
# and finally store states as postal abbreviations.
voter_survey = (
    pd.read_stata('../data/election_results/pew_research_center_june_elect_wknd_data.dta',
                  columns=['state', 'ideo'])
    .assign(state=lambda frame: frame['state'].map(us.states.lookup))
    .dropna()
    .assign(state=lambda frame: frame['state'].map(lambda st: st.abbr))
)
voter_survey

In [None]:
# 2008 presidential results by state. The separator is a regex (a comma with
# optional surrounding whitespace), which only pandas' python parser supports;
# requesting that engine explicitly avoids the ParserWarning about falling back.
# State names are resolved the same way as for the survey: rows whose name
# `us.states.lookup` cannot resolve are dropped rather than crashing on `.abbr`.
election_results = (
    pd.read_csv('../data/election_results/2008ElectionResult.csv',
                usecols=['state', 'vote_Obama_pct'], sep=r'\s*,\s*', engine='python')
    .assign(state=lambda frame: frame['state'].map(us.states.lookup))
    .dropna(subset=['state'])
    .assign(state=lambda frame: frame['state'].map(lambda st: st.abbr))
)
election_results

In [None]:
# Per-state respondent counts. Both counts come from ONE groupby so that a
# state in which nobody answered 'very liberal' keeps very_liberal == 0; an
# inner merge of two separately grouped counts can silently drop such states,
# which are exactly the small-sample, extreme raw estimates worth modelling.
# observed=True keeps only states that actually occur (relevant if 'state' is
# categorical; ignored otherwise).
voter_survey_stat = (
    voter_survey
    .assign(very_liberal=voter_survey['ideo'] == 'very liberal')
    .groupby('state', observed=True)
    .agg(total=('ideo', 'size'), very_liberal=('very_liberal', 'sum'))
)
voter_survey_stat['very_liberal_pct'] = 100 * voter_survey_stat['very_liberal'] / voter_survey_stat['total']
# Rounded for display only; downstream estimates should use the raw counts.
voter_survey_stat = voter_survey_stat.round(1)

In [None]:
# Pair each state's election result with its survey statistics; an inner join
# keeps only states present in both tables.
combined = election_results.merge(voter_survey_stat, how='inner', on='state')
combined

In [None]:
# Raw survey share of 'very liberal' respondents vs. Obama's 2008 vote share,
# one labelled point per state. x/y are passed by keyword: seaborn >= 0.12 no
# longer accepts them positionally.
fig, ax = plt.subplots()
sns.scatterplot(x='very_liberal_pct', y='vote_Obama_pct', data=combined, ax=ax)
# Iterate positionally so labelling does not depend on `combined` having a
# 0..n-1 RangeIndex.
for x_pct, y_pct, state in zip(combined['very_liberal_pct'],
                               combined['vote_Obama_pct'],
                               combined['state']):
    ax.text(x_pct + 0.01, y_pct, state, horizontalalignment='left',
            size='medium', color='black', weight='semibold')
ax.set_xlim(0, 12.)
ax.set_ylim(30, 70)
ax.set_xlabel('very liberal (%)')
ax.set_ylabel('Obama vote (%)')
ax.set_title('Raw survey estimate vs. 2008 Obama vote by state')
plt.show()

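## Calculate prior parameters

Fit a Beta prior for each state's share of "very liberal" respondents by the method of moments: choose `conc1` and `conc0` so that the Beta's mean and variance match the mean and variance of the observed state-level shares.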

In [None]:
# Per-state counts as float32 tensors. Pyro's Binomial-family distributions
# expect floating-point counts matching the dtype of their probabilities, and
# torch.tensor always copies, so these tensors cannot alias the DataFrame's
# memory (torch.from_numpy on a column view would).
total = torch.tensor(combined['total'].to_numpy(dtype=np.float32))
very_liberal = torch.tensor(combined['very_liberal'].to_numpy(dtype=np.float32))
(total, very_liberal)

In [None]:
# Mean and sample variance (ddof=1) of the per-state 'very liberal'
# proportions. Computed from the exact counts rather than from
# `very_liberal_pct`, which was rounded to one decimal for display.
very_liberal_prop = combined['very_liberal'] / combined['total']
very_liberal_mean = very_liberal_prop.mean()
very_liberal_var = very_liberal_prop.var()

In [None]:
# Method-of-moments Beta(conc1, conc0) prior. A Beta with mean m has variance
# m(1 - m) / (conc1 + conc0 + 1), so the total concentration is
# m(1 - m) / v - 1, split in proportion m : (1 - m).
# NOTE: v also contains binomial sampling noise, so this tends to make the
# prior wider (less concentrated) than the true between-state spread.
concentration_total = very_liberal_mean * (1 - very_liberal_mean) / very_liberal_var - 1
assert concentration_total > 0, (
    'observed variance is too large for a Beta with this mean; '
    'method-of-moments concentrations would be non-positive')
conc1 = very_liberal_mean * concentration_total
conc0 = (1 - very_liberal_mean) * concentration_total
(conc1, conc0)

## Inference

In [None]:
def model(total, very_liberal):
    """Hierarchical beta-binomial model of per-state 'very liberal' counts.

    Each state's proportion ``theta`` is drawn from a Beta(conc1, conc0) prior
    whose concentrations come from the module-level method-of-moments fit, and
    the observed number of 'very liberal' respondents is Binomial(total, theta).

    Args:
        total: respondents per state; assumed 1-D, one entry per state.
        very_liberal: 'very liberal' respondents per state, aligned with
            ``total``; conditioned on as the observed data.
    """
    with pyro.plate('data', total.size(0)):
        # The scalar prior is broadcast across states by the plate.
        theta = pyro.sample('theta', dist.Beta(float(conc1), float(conc0)))
        # Cast counts to theta's floating dtype so integer tensors work too.
        pyro.sample('obs',
                    dist.Binomial(total_count=total.to(theta.dtype), probs=theta),
                    obs=very_liberal.to(theta.dtype))

In [None]:
def guide(total, very_liberal):
    """Mean-field variational family: an independent Beta per state.

    Both Beta parameters are learnable, constrained positive, and initialized
    at the prior concentrations ``conc1``/``conc0``. Under a conjugate Beta
    prior with a Binomial likelihood the exact posterior of each state's
    proportion is itself a Beta, so this family can represent it.

    Args:
        total: respondents per state; only its length (number of states) is used.
        very_liberal: unused; accepted so the signature matches the model's.
    """
    num_states = total.size(0)
    alpha_q = pyro.param('alpha_q', torch.full((num_states,), float(conc1)),
                         constraint=constraints.positive)
    beta_q = pyro.param('beta_q', torch.full((num_states,), float(conc0)),
                        constraint=constraints.positive)
    with pyro.plate('data', num_states):
        pyro.sample('theta', dist.Beta(alpha_q, beta_q))

In [None]:
# Start from an empty parameter store and a fixed seed so re-running this cell
# (or Restart & Run All) reproduces the same fit instead of resuming from
# parameters left over from an earlier run.
pyro.clear_param_store()
pyro.set_rng_seed(0)
svi = SVI(model, guide, Adam({'lr': 0.1}), JitTrace_ELBO(num_particles=3))
losses = []  # kept for inspecting convergence after training
pbar = tqdm(range(1000))
for _ in pbar:
    loss = svi.step(total, very_liberal)
    losses.append(loss)
    pbar.set_description(f"Loss: {loss:.2f}")

In [None]:
# Posterior mean of each state's 'very liberal' proportion. With the
# Beta(conc1, conc0) prior held fixed and a Binomial likelihood, the posterior
# for a state with y of n respondents 'very liberal' is
# Beta(conc1 + y, conc0 + n - y), whose mean is computed in closed form below.
# A fitted per-state variational Beta can be checked against these values.
# Returned as a plain array (ordered like `combined`'s rows) so later column
# assignment is positional rather than index-aligned.
posterior_mean = ((conc1 + combined['very_liberal'])
                  / (conc1 + conc0 + combined['total'])).to_numpy()
posterior_mean

In [None]:
# Add the posterior estimate as a percentage, then round every numeric column
# to one decimal for display.
combined = combined.assign(very_liberal_posterior_pct=posterior_mean * 100.0).round(1)
combined

In [None]:
# Posterior share of 'very liberal' respondents vs. Obama's 2008 vote share,
# one labelled point per state. x/y are passed by keyword: seaborn >= 0.12 no
# longer accepts them positionally.
fig, ax = plt.subplots()
sns.scatterplot(x='very_liberal_posterior_pct', y='vote_Obama_pct', data=combined, ax=ax)
# Iterate positionally so labelling does not depend on `combined` having a
# 0..n-1 RangeIndex.
for x_pct, y_pct, state in zip(combined['very_liberal_posterior_pct'],
                               combined['vote_Obama_pct'],
                               combined['state']):
    ax.text(x_pct + 0.01, y_pct, state, horizontalalignment='left',
            size='medium', color='black', weight='semibold')
ax.set_xlim(0, 12.)
ax.set_ylim(30, 70)
ax.set_xlabel('very liberal posterior (%)')
ax.set_ylabel('Obama vote (%)')
ax.set_title('Posterior estimate vs. 2008 Obama vote by state')
plt.show()