In [3]:
import numpy as np
import pandas as pd
import dill as pickle
from dpgmm_vi import variational_inference
import os

In [4]:
# Decade bins used as the second-level keys of the embedding dicts (emb_dict[term][year]).
# Note the final bin is labelled '2010_2020', not '2010_2019'.
year_list = [
    '1900_1909', '1910_1919', '1920_1929', '1930_1939',
    '1940_1949', '1950_1959', '1960_1969', '1970_1979',
    '1980_1989', '1990_1999', '2000_2009', '2010_2020',
]

In [5]:
# Country terms to cluster; each is a first-level key of the country embedding dict.
country_list = [
    'china', 'north korea', 'south korea',
    'canada', 'united kingdom', 'germany',
]

In [6]:
# Concept terms to cluster; each is a first-level key of the concept embedding dict.
concept_list = [
    'autocracy',
    'autocratic',
    'dictator',
    'dictatorship',
    'authoritarianism',
    'democracy',
]

In [7]:
def dpgmm_clusters(data, np_seed=23, tf_seed=100, alpha=1.0, T=100, n_iter=10):
    """Fit a Dirichlet-process GMM to one 2-D data matrix with variational inference.

    Parameters
    ----------
    data : array-like
        2-D array with one row per data point. It is cast to float32 and given a
        leading batch axis of size 1 before being passed to ``variational_inference``.
    np_seed : int
        Seed for NumPy's global RNG, applied before inference so that any
        NumPy-based randomness is reproducible. (This was previously accepted but ignored.)
    tf_seed : int
        Forwarded to ``variational_inference`` as ``tf_seed``.
    alpha, T, n_iter
        Forwarded unchanged to ``variational_inference``. ``T`` is presumably the
        truncation level (maximum number of components); confirm against dpgmm_vi.

    Returns
    -------
    inferred_latents
        The object returned by ``variational_inference``.
    assignments : np.ndarray
        For every row of ``data`` that is not entirely zero, the argmax of
        ``zeta`` over its last axis (component index).
    bdf : pd.DataFrame
        ``zeta`` for the single batch (all rows, including all-zero ones). Columns
        that are zero for every row are dropped; remaining exact zeros become NaN.
    """
    # Seed NumPy's global RNG so repeated runs with the same np_seed agree.
    np.random.seed(np_seed)

    batch_number = 0
    batched = np.asarray(data, dtype=np.float32)[np.newaxis, :, :]
    # All-zero rows are treated as padding and excluded from the assignments.
    nonzero_datapoints = np.flatnonzero(~np.all(batched[batch_number] == 0, axis=1))

    inferred_latents = variational_inference(batched, alpha=alpha, T=T, n_iter=n_iter, tf_seed=tf_seed, get_elbo=False)
    zeta = inferred_latents.zeta[batch_number]
    assignments = np.argmax(zeta[nonzero_datapoints, :], axis=1)

    # Keep only components that receive non-zero responsibility somewhere, and
    # mark the remaining exact zeros as missing.
    bdf = pd.DataFrame(zeta)
    bdf = bdf.loc[:, ~(bdf == 0).all()].replace(0, np.nan)

    # TODO: Can 'inferred_latents.nu' replace the cluster mean??

    return inferred_latents, assignments, bdf

In [8]:
def plot_dpgmm(inferred_latents, data, lim=10):
    """Scatter the data coloured by DPGMM assignment and draw one ellipse per component.

    Only the first two data dimensions are plotted.

    Parameters
    ----------
    inferred_latents
        Result of ``variational_inference`` (e.g. the first value returned by
        ``dpgmm_clusters``). Reads ``zeta``, ``nu``, ``a``, ``b``, ``lambda_1`` and
        ``lambda_2``, each indexed with a leading batch axis; batch 0 is used.
    data : array-like
        2-D array with one row per data point; all-zero rows are skipped.
    lim : float, optional
        Both axes are limited to ``[-lim, lim]`` (default 10, the previously
        hard-coded window).
    """
    # Imported locally so the function works on a fresh kernel without relying
    # on a module-level matplotlib import.
    import matplotlib.pyplot as plt
    from matplotlib.patches import Ellipse

    batch_number = 0
    batched = np.asarray(data, dtype=np.float32)[np.newaxis, :, :]
    nonzero_datapoints = np.flatnonzero(~np.all(batched[batch_number] == 0, axis=1))
    assignments = np.argmax(inferred_latents.zeta[batch_number, nonzero_datapoints, :], axis=1)

    points = batched[batch_number, nonzero_datapoints]
    means = inferred_latents.nu[batch_number]
    # Number of mixture components, taken from the fitted means. The original
    # used an undefined global `T` and a hard-coded range(100).
    n_components = means.shape[0]

    ax = plt.gca()
    ax.scatter(points[:, 0], points[:, 1], c=assignments, marker='x')
    # Small horizontal jitter keeps coincident component means distinguishable.
    ax.scatter(means[:, 0] + 0.01 * np.random.randn(n_components), means[:, 1], marker='o', s=30, color='r')

    # Ellipse diameter = 2 * sqrt(b / a) per component and dimension (intended as
    # the expected standard deviation; presumably a, b are Gamma shape/rate of a
    # precision; confirm against dpgmm_vi).
    diameters = (2 * np.sqrt(1. / np.divide(inferred_latents.a, inferred_latents.b)))[batch_number]

    # Marginal cluster probabilities (used as ellipse transparency) from the
    # expected stick-breaking weights: pi_k = beta_k * prod_{j<k} (1 - beta_j),
    # accumulated in log space. The 1e-30 terms guard against log(0).
    l1 = inferred_latents.lambda_1[batch_number, :]
    l2 = inferred_latents.lambda_2[batch_number, :]
    beta_means = l1 / (l1 + l2)
    log_remaining = np.concatenate(([0], np.cumsum(np.log(1 - beta_means + 1e-30)[:-1])))
    weights = np.exp(np.log(beta_means + 1e-30) + log_remaining)
    weights /= weights.sum()

    for k in range(n_components):
        ellipse = Ellipse((means[k, 0], means[k, 1]), diameters[k, 0], diameters[k, 1])
        ax.add_artist(ellipse)
        ellipse.set_alpha(weights[k])

    ax.set_xlim([-lim, lim])
    ax.set_ylim([-lim, lim])
    ax.set_aspect('equal', adjustable='box')
    plt.title('Variational distributions')
    plt.show()

In [11]:
def run_dpgmm(emb_dict, term_list, year_list, term_type):
    """Fit a DPGMM for every (term, year) pair and write the results to disk.

    For each pair that has embeddings, two files are written:

    * ``./crs_cluster_csvs/bgmm/<term_type>/<term>/<term>_<year>.csv``: the
      ``bdf`` frame returned by ``dpgmm_clusters``.
    * ``./crs_bgmm_latents/<term_type>/<term>/<term>_<year>_latents.pkl``: a
      pickle of ``{'inferred_latents', 'assignments', 'min_size'}``, where
      ``min_size`` is the size of the smallest assigned cluster.

    Parameters
    ----------
    emb_dict : dict
        Nested mapping; ``emb_dict[term][year]`` must be a sequence accepted by
        ``np.vstack`` that stacks into the 2-D data matrix.
    term_list, year_list : iterable of str
        Terms and year bins to process.
    term_type : str
        Sub-directory name used in both output paths (e.g. ``'countries'``).

    Returns
    -------
    dict
        ``{term: {year: min_cluster_size}}`` for every pair that was fitted and
        saved successfully. Pairs reported as empty or failed are absent.
    """
    min_size_dict = {}
    for term in term_list:
        min_size_dict[term] = {}
        for year in year_list:
            try:
                data = np.vstack(emb_dict[term][year])
            except (KeyError, ValueError):
                # KeyError: no entry for this term/year. ValueError: np.vstack
                # got nothing (or nothing stackable) to concatenate.
                print(term + ' ' + year + ' ' + 'empty.')
                continue

            try:
                inferred_latents, assignments, df = dpgmm_clusters(data)
                _, counts = np.unique(assignments, return_counts=True)

                outdir = './crs_cluster_csvs/bgmm/' + term_type + '/' + term
                print(outdir)
                os.makedirs(outdir, exist_ok=True)
                df.to_csv(os.path.join(outdir, term + '_' + year + '.csv'), index=False)

                outdir2 = './crs_bgmm_latents/' + term_type + '/' + term
                os.makedirs(outdir2, exist_ok=True)
                # np.min raises on an empty `counts` (no non-zero rows); that is
                # reported as a failure below, after the CSV has been written.
                min_size = np.min(counts)
                latents_dict = {'inferred_latents': inferred_latents, 'assignments': assignments, 'min_size': min_size}
                with open(os.path.join(outdir2, term + '_' + year + '_latents.pkl'), 'wb') as handle:
                    pickle.dump(latents_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
                min_size_dict[term][year] = min_size

                # Drop references to the large fit results before the next fit.
                del latents_dict, inferred_latents, assignments, df
            except Exception as exc:
                # Report the actual error and keep going. Unlike a bare except,
                # this lets KeyboardInterrupt / SystemExit stop the run.
                print(term + ' ' + year + ' ' + 'failed: ' + repr(exc))
    return min_size_dict

In [12]:
# LOAD COUNTRY EMBEDS
# `pickle` is the dill module here (see the imports cell). Unpickling can execute
# arbitrary code, so only load files this project produced.
with open('./crs_embeds/country_embeds.pkl', 'rb') as embeds_file:
    country_embeddings = pickle.load(embeds_file)

In [13]:
# Cluster each country's embeddings per year bin; outputs go under the 'countries' sub-directories.
run_dpgmm(emb_dict=country_embeddings, term_list=country_list, year_list=year_list, term_type='countries')

Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.while_loop(c, b, vars, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars))
./crs_cluster_csvs/bgmm/countries/china
./crs_cluster_csvs/bgmm/countries/china
./crs_cluster_csvs/bgmm/countries/china
./crs_cluster_csvs/bgmm/countries/china
./crs_cluster_csvs/bgmm/countries/china
./crs_cluster_csvs/bgmm/countries/china
./crs_cluster_csvs/bgmm/countries/china
./crs_cluster_csvs/bgmm/countries/china
./crs_cluster_csvs/bgmm/countries/china
./crs_cluster_csvs/bgmm/countries/china
./crs_cluster_csvs/bgmm/countries/china
./crs_cluster_csvs/bgmm/countries/china
north korea 1900_1909 empty.
north korea 1910_1919 empty.
north korea 1920_1929 empty.
north korea 1930_1939 empty.
./crs_cluster_csvs/bgmm/countries/north korea
./crs_cluster_csvs/bgmm/countries/north korea
./crs_cluster_csvs/bgmm/countries/north korea
./crs_cluster_csvs/b

In [14]:
# Release the country embeddings before the concept embeddings are loaded below.
del country_embeddings

In [15]:
# LOAD CONCEPT EMBEDS
# `pickle` is the dill module here (see the imports cell). Unpickling can execute
# arbitrary code, so only load files this project produced.
with open('./crs_embeds/concept_embeds.pkl', 'rb') as embeds_file:
    concept_embeddings = pickle.load(embeds_file)

In [16]:
# Cluster each concept's embeddings per year bin; outputs go under the 'concepts' sub-directories.
run_dpgmm(emb_dict=concept_embeddings, term_list=concept_list, year_list=year_list, term_type='concepts')

./crs_cluster_csvs/bgmm/concepts/autocracy
./crs_cluster_csvs/bgmm/concepts/autocracy
./crs_cluster_csvs/bgmm/concepts/autocracy
./crs_cluster_csvs/bgmm/concepts/autocracy
./crs_cluster_csvs/bgmm/concepts/autocracy
./crs_cluster_csvs/bgmm/concepts/autocracy
./crs_cluster_csvs/bgmm/concepts/autocracy
./crs_cluster_csvs/bgmm/concepts/autocracy
./crs_cluster_csvs/bgmm/concepts/autocracy
./crs_cluster_csvs/bgmm/concepts/autocracy
./crs_cluster_csvs/bgmm/concepts/autocracy
./crs_cluster_csvs/bgmm/concepts/autocracy
./crs_cluster_csvs/bgmm/concepts/autocratic
./crs_cluster_csvs/bgmm/concepts/autocratic
./crs_cluster_csvs/bgmm/concepts/autocratic
./crs_cluster_csvs/bgmm/concepts/autocratic
./crs_cluster_csvs/bgmm/concepts/autocratic
./crs_cluster_csvs/bgmm/concepts/autocratic
./crs_cluster_csvs/bgmm/concepts/autocratic
./crs_cluster_csvs/bgmm/concepts/autocratic
./crs_cluster_csvs/bgmm/concepts/autocratic
./crs_cluster_csvs/bgmm/concepts/autocratic
./crs_cluster_csvs/bgmm/concepts/autocratic


In [17]:
# Release the concept embeddings now that the concept clustering run has finished.
del concept_embeddings