### Training a corpus-wide RMN with tfidf embeddings

In [1]:
import os
import sys
import pandas as pd

In [2]:
sys.path.append("/home/rocassius/w266_final/scripts/assembly")
sys.path.append("/home/rocassius/w266_final/scripts/modeling")

In [3]:
from document import load_documents
from constant import DOC_PRAYER_PATH, MIN_SESSION, MAX_SESSION, DOC_ALL_PATH
from subject import subject_keywords

sessions = list(range(MIN_SESSION, MAX_SESSION+1))

In [4]:
from helper import *
from rmn import *
from rmn_data_generator import RMN_DataGenerator
from rmn_analyzer import RMN_Analyzer

In [5]:
# load embedding tools
prayer_tools_path = "/home/rocassius/gen-data/tools/prayer_tools"
metadata_dict = load_pickled_object(os.path.join(prayer_tools_path, "metadata_dict"))
tokenizer_dict = load_pickled_object(os.path.join(prayer_tools_path, "tokenizer_dict"))
embedding_matrix = load_pickled_object(os.path.join(prayer_tools_path, "idf_embedding_matrix"))
global_embedding_matrix = load_pickled_object(os.path.join(prayer_tools_path, "embedding_matrix_wg"))
global_tokenizer_dict = load_pickled_object(os.path.join(prayer_tools_path, "tokenizer_dict_wg"))

In [6]:
docs_df = load_documents([66], DOC_PRAYER_PATH)

In [7]:
data_df = docs_df.sample(2347)

In [8]:
docs_df.shape

(35475, 10)

In [9]:
local_models_path = "/home/rocassius/gen-data/models"

In [10]:
rmn = RigidRMN()
rmn.load_rmn("SuaveRanger", local_models_path)
rmn.infer_embedding_matrix = global_embedding_matrix
rmn.infer_tokenizer_dict = global_tokenizer_dict

Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
Instructions for updating:
If using Keras pass *_constraint arguments to layers.


In [53]:
#=======#====================#
#=*= RMN Analyzer =*=#
#====================#

# Class for analyzing an RMN

import numpy as np
import pandas as pd
from analysis import *

# variable constants
SUB = 'subject'
SPEAK = 'speakerid'
PARTY = 'party'
SESS = 'session'
# party constants
R = 'R'
D = 'D'
# metric constants
JS = 'js'
HH = 'hh'
EN = 'entr'
N_REC = 'n_records'
N_NAN = 'n_nan_preds'
TP = 'topic_use'


class RMN_Analyzer(object):
    """Class for Analyzing an RMN with respect to a dataset
    """
    
    def __init__(self, rmn, df):
        """
        Args:
        - rmn: (RMN) the RMN to be used for analysis
        - df : (DataFrame) the dataframe to analyze
        """
        
        'Initialization'
        self.rmn = rmn
        self.df = df.reset_index(drop=True)
        self.topic_preds = None
        self.y_preds = None
        
    @property
    def index(self):
        return self.df.index
         
        
    def predict_topics(self, use_generator=True):
        """Computes the topic predictions for all observations
        """
        self.topic_preds = self.rmn.predict_topics(self.df, use_generator)
        
    
    def predict_y(self, use_generator=True):
        """Computes the sentence vector predictions for all observations
        """
        self.y_preds = self.rmn.predict_y(self.df, use_generator)
        
        
    def sample_indices(self, indices, n):
        """Returns a SRR of the indices provided
        """
        return np.random.choice(indices, n, replace=True)

    
    def bool_subset(self, col, value):
        """
        Returns a boolean vector for each observation in the
        dataframe indicating whether it meets the col == value condition
        """
        assert col in self.df.columns
        return self.df[col] == value
    
    
    def bool_index(self, conditions):
        """
        Returns a boolean vector for each observation in the
        dataframe indicating whether it meets all conditions
        
        Args:
        - conditions: (dict) dictionary of conditions
        
        Returns: 
        - pandas series of booleans indicating where all 
          of the conditions hold
        """
        # initialize bool index
        bool_index = (pd.Series(True)
                      .repeat(self.index.shape[0])
                      .reset_index(drop=True))
        
        for col, val in conditions.items():
            bool_index = bool_index & self.bool_subset(col, val)
            
        return bool_index
    
    
    def cond_index(self, conditions):
        """Returns indices of records meeting the conditions
        """
        return self.index[self.bool_index(conditions)]
    
    
    def n_records(self, conditions={}):
        """Returns the number of records meetings the conditions
        """
        return len(self.cond_index(conditions))
    
    
    def n_nan_preds(self, conditions={}):
        """Returns the number of records which have nan predictions
        """
        cond_index = self.cond_index(conditions)
        return np.isnan(self.topic_preds[cond_index].sum(axis=-1)).sum()
    
    
    def compute_JS(self, index_A, index_B, base=2):
        """
        Computes the mean pair-wise JS divergence and associated CI
        between indices in index_A and indices in index_B
        """
        p_A = self.topic_preds[index_A]
        p_B = self.topic_preds[index_B]
        js_list = [jensenshannon(p, q, base) for p, q in zip(p_A, p_B)]
        
        return mean_CI(js_list)
        
        
    def compute_HH(self, index):
        """
        Computes the mean HH index and associated CI between
        indices in index_A and indices in index_B
        """
        p = self.topic_preds[index]
        hh_list = [hh_index(q) for q in p]
        
        return mean_CI(hh_list)
    
    
    def topic_use_RD_js(self, conditions={}):
        """Returns the JS divergence of the R and D topic use distributions
        """
        R_topic_use = self.topic_use({**conditions, **{PARTY: R}})
        D_topic_use = self.topic_use({**conditions, **{PARTY: D}})
        
        return jensenshannon(R_topic_use, D_topic_use)
    
    
    def topic_use_hh(self, conditions={}):
        """Returns the HH-index of the RD topic use distributions
        """
        return hh_index(self.topic_use(conditions))
    
    
    def inter_party_js(self, conditions, n):
        """
        Returns the estimated inter party JS divergence and a CI.
        
        Computes the inter party JS divergence between 
        Republicans and Democrats on a given subject
        
        Args:
        - subject: (str) subject to examine
        - n      : (int) sample size
        
        Returns: a numpy array of length 3, where
        - 0 is the mean divergence point estimate:
        - 1 is the lower bound of a 95% CI
        - 2 is the upper bound of a 95% CI
        """
        # ensure that the topic predictions exist
        if self.topic_preds is None:
            self.predict_topics()
        
        # find R and D indicies on the subject
        index_R = self.cond_index({**conditions, **{PARTY: R}})
        index_D = self.cond_index({**conditions, **{PARTY: D}})
        
        # return None if indices are insufficient
        if len(index_R)==0 or len(index_D)==0:
            return None
        
        # sample 
        samp_index_R = self.sample_indices(index_R, n)
        samp_index_D = self.sample_indices(index_D, n)
    
        return self.compute_JS(samp_index_R, samp_index_D)
    
    
    def group_js(self, conditions, n):
        """
        Returns the estimated mean JS divergence and a CI
        
        Estimates the average JS divergence between any two documents of
        a group defined by the conditions. A document by speaker _i_ is 
        never compared to another document by speaker _i_.
        
        
        Args:
        - conditions: (dict) dictionary of conditions
        - n         : (int) sample size
        
        Returns: a numpy array of length 3, where index...
        - 0 is the mean divergence point estimate:
        - 1 is the lower bound of a 95% CI
        - 2 is the upper bound of a 95% CI
        """
        # ensure that the topic predictions exist
        if self.topic_preds is None:
            self.predict_topics()
        
        # find indicies of party on the subject
        cond_index = self.cond_index(conditions)
        
        # Return none if there are fewer than 2 speakers
        if self.df.loc[cond_index][SPEAK].nunique() < 2:
            return None
        
        # Sample index pairs
        index_AB = []
        while len(index_AB) < n:
            a_b = self.sample_indices(cond_index, n=2)
            # include samples whose speakers are different
            if self.df.loc[a_b][SPEAK].nunique() == 2:
                index_AB.append(a_b)
        
        index_AB = np.asarray(index_AB)
        assert index_AB.shape == (n, 2)
        
        # get indices for each group
        index_A, index_B = index_AB[:,0], index_AB[:,1]
        
        return self.compute_JS(index_A, index_B)
    
    
    def group_hh(self, conditions={}, n=None):
        """
        Returns the estimated mean HH index and a CI
        
        Estimates the average Herfindahl–Hirschman Index 
        of all records meetings the conditons.
        
        Args:
        - subject: (str) subject to examine
        - party  : (str) party of interest
        - n      : (int) sample size
        
        Returns: a numpy array of length 3, where index...
        - 0 is the mean index point estimate:
        - 1 is the lower bound of a 95% CI
        - 2 is the upper bound of a 95% CI
        """
        # ensure that the topic predictions exist
        if self.topic_preds is None:
            self.predict_topics()
        
        # indicies meeting the conditions
        cond_index = self.cond_index(conditions)
        
        # return None if indices are insufficient
        if len(cond_index)==0:
            return None
        
        if n is None:
            return self.compute_HH(cond_index)
        else:
            samp_index = self.sample_indices(cond_index, n)
            return self.compute_HH(samp_index)
        
        
    def analyze_subset(self, conditions, n):
        """
        Returns a dictionary of analysis metrics for the subset 
        of records defined by the conditions.
        
        Note: It is recommended conditions be on subject
        
        Args:
        - conditions: (dict) dictionary of conditions
        - n         : (int) sample size for estimation of metrics
        
        for the entire dataset and for each subject the following are computed:
        - n_records, n_records_R
        - n_records_D
        - js
        - js_R
        - js_D
        - js_RD
        - hh
        - hh_R
        - hh_D
        
        Returns: a dictionary of metrics
        """
        # R and D added conditions
        conditions_R = {**conditions, **{PARTY: R}}
        conditions_D = {**conditions, **{PARTY: D}}
        
        # annotation tags
        _R = '_' + R
        _D = '_' + D
        _RD = _R + D
        _TP = '_' + TP
        
        metrics = {
            # n records in data
            N_REC:    self.n_records(conditions),
            N_REC+_R: self.n_records(conditions_R),
            N_REC+_D: self.n_records(conditions_D),
            N_NAN+_R: self.n_nan_preds(conditions_R),
            N_NAN+_D: self.n_nan_preds(conditions_D),
            # JS divergence data
            JS:     self.group_js(conditions, n),
            JS+_R:  self.group_js(conditions_R, n),
            JS+_D:  self.group_js(conditions_D, n),
            JS+_RD: self.inter_party_js(conditions, n),
            # HH index data
            HH:    self.group_hh(conditions, n),
            HH+_R: self.group_hh(conditions_R, n),
            HH+_D: self.group_hh(conditions_D, n),
            # Topic Use Metrics
            HH+_TP:    self.topic_use_hh(conditions),
            HH+_TP+_R: self.topic_use_hh(conditions_R),
            HH+_TP+_D: self.topic_use_hh(conditions_D),
            JS+_TP:    self.topic_use_RD_js(conditions),
        }
        
        return metrics
    
        
    def analyze(self, subjects, n):
        """
        Returns a dictionary of analysis metrics at the subject level
        and at the session level (assuming self.df is the data of a
        single session).
        
        Args:
        - subjects: (array-like) list of subjects
        - n       : (int) sample size for estimation of metrics
        
        Returns: a dictionary of metrics
        """
        # analyze entire session dataset
        dataset_metrics = self.analyze_subset(conditions={}, n=n)
        
        # analyze by subject
        subject_metrics = {}
        for s in subjects:
            subject_metrics[s] = self.analyze_subset({SUB: s}, n)
        
        metrics = {'dataset' : dataset_metrics, 
                   'subjects': subject_metrics}
        
        return metrics
    
    
    def shannon_entropy(self, conditions={}):
        """Returns the Shannon Entropy of topic predictions meeting conditions
        """
        # ensure that the topic predictions exist
        if self.topic_preds is None:
            self.predict_topics()
        
        return shannon_entropy(self.topic_preds[self.cond_index(conditions)])
    
    
    def mean_entropy(self, conditions={}):
        """Returns the mean entropy of topic predictions meeting condiditons
        """
        return np.nanmean(self.shannon_entropy(conditions))
        
    
    def first_topic_counts(self, conditions={}):
        """
        Returns a leaderboard of topics and how many times they 
        are the primary topic associated with a document.
        """
        if self.topic_preds is None:
            self.predict_topics()
           
        cond_index = self.cond_index(conditions)
        topic_counts = pd.Series(np.argmax(self.topic_preds[cond_index], axis=-1)).value_counts()
        
        return topic_counts
    
    
    def topic_use(self, conditions={}):
        """
        Returns a leaderboard of topics based on the percentage of 
        total weight given to them in all of the documents
        """
        cond_index = self.cond_index(conditions)
        topic_sums = pd.Series(np.nansum(self.topic_preds[cond_index], axis=0))
        topic_use = topic_sums.sort_values(ascending=False) / topic_sums.sum()
        
        return topic_use
    
    
    def primary_topics(self, conditions={}, k=5):
        """Returns top k most prominent topics for documents
        """
        cond_index = self.cond_index(conditions)
        primary_topics = np.flip(np.argsort(self.topic_preds[cond_index]))[:,:k]
        
        return primary_topics
    

In [54]:
analyzer = RMN_Analyzer(rmn, data_df)

In [55]:
analyzer.predict_topics()



In [56]:
# nn = rmn.inspect_topics([1,2,3])

In [57]:
analyzer.topic_preds[0,2] = np.nan
analyzer.topic_preds[2,44] = np.nan
analyzer.topic_preds[944,44] = np.nan
analyzer.topic_preds[944,1] = np.nan
analyzer.topic_preds[900,1] = np.nan
analyzer.topic_preds[200,5] = np.nan
analyzer.topic_preds[245,5] = np.nan

In [61]:
analyzer.analyze_subset(conditions={'subject':'immigration'}, n=10000)

{'n_records': 35,
 'n_records_R': 21,
 'n_records_D': 14,
 'n_nan_preds_R': 0,
 'n_nan_preds_D': 0,
 'js': {'mean': 0.9070792027697758,
  'lower': 0.9050941816396604,
  'upper': 0.9090642238998908},
 'js_R': {'mean': 0.9075859307599854,
  'lower': 0.9055981708919111,
  'upper': 0.9095736906280597},
 'js_D': {'mean': 0.9077800437708343,
  'lower': 0.905785994976378,
  'upper': 0.9097740925652904},
 'js_RD': {'mean': 0.9027487772244022,
  'lower': 0.9007128726488136,
  'upper': 0.9047846817999907},
 'hh': {'mean': 0.41508641839027405,
  'lower': 0.41015806765690727,
  'upper': 0.4200147644640477},
 'hh_R': {'mean': 0.384602814912796,
  'lower': 0.3805210227801077,
  'upper': 0.38868461184514835},
 'hh_D': {'mean': 0.4594230055809021,
  'lower': 0.45356966759467127,
  'upper': 0.465276305891037},
 'hh_topic_use': 0.03803202,
 'hh_topic_use_R': 0.04300683,
 'hh_topic_use_D': 0.05627721,
 'js_topic_use': 0.10515827823173647}

In [64]:
analyzer.analyze_subset(conditions={}, n=10000)

{'n_records': 2347,
 'n_records_R': 1382,
 'n_records_D': 959,
 'n_nan_preds_R': 3,
 'n_nan_preds_D': 3,
 'js': {'mean': 0.8971323229504494,
  'lower': 0.8947437974957849,
  'upper': 0.8995208484051139},
 'js_R': {'mean': 0.8992934233777085,
  'lower': 0.8968493795483355,
  'upper': 0.9017374672070818},
 'js_D': {'mean': 0.8942565858023946,
  'lower': 0.8917554643188438,
  'upper': 0.896757707285946},
 'js_RD': {'mean': 0.8957261537283714,
  'lower': 0.8932374861091186,
  'upper': 0.8982148213476242},
 'hh': {'mean': 0.36306512355804443,
  'lower': 0.3583856193909551,
  'upper': 0.36774460049679786},
 'hh_R': {'mean': 0.36951836943626404,
  'lower': 0.36474630403541963,
  'upper': 0.37429049701362194},
 'hh_D': {'mean': 0.35585999488830566,
  'lower': 0.3512379868835622,
  'upper': 0.36048197847332086},
 'hh_topic_use': 0.023277294,
 'hh_topic_use_R': 0.023165543,
 'hh_topic_use_D': 0.023929607,
 'js_topic_use': 0.026178849637814673}

In [24]:
analyzer.mean_entropy({'subject':'trade'})

2.250809

In [28]:
analyzer.n_nan_preds()

6

In [29]:
analyzer.topic_use()

5     0.039877
14    0.037185
8     0.034378
1     0.033188
15    0.033075
45    0.030230
0     0.029634
44    0.029541
2     0.028929
11    0.028855
28    0.027529
34    0.026888
38    0.024861
24    0.023889
33    0.023650
49    0.022767
9     0.022456
25    0.021869
31    0.021347
12    0.021307
46    0.020716
17    0.020505
3     0.020373
37    0.020315
32    0.019827
26    0.019779
39    0.019517
40    0.018955
47    0.017728
20    0.017339
6     0.017262
22    0.017211
36    0.017140
7     0.017084
21    0.015750
4     0.015567
30    0.015072
35    0.014653
18    0.013638
13    0.013507
10    0.011416
48    0.011272
29    0.011057
27    0.009990
42    0.009428
41    0.009170
23    0.007621
16    0.006873
19    0.006244
43    0.003536
dtype: float32

In [30]:
analyzer.first_topic_counts()

14    104
5     100
1      96
8      89
15     89
11     79
45     72
0      68
34     68
44     67
24     63
2      58
9      57
38     54
28     53
25     52
33     52
37     52
26     51
31     51
39     49
46     49
32     47
12     47
49     47
17     47
3      46
47     43
40     43
20     41
22     41
7      40
6      39
21     37
36     36
30     33
35     33
4      31
10     31
48     27
13     27
29     23
41     23
18     19
42     19
23     15
27     13
19     11
16     10
43      5
dtype: int64

In [31]:
analyzer.primary_topics()

array([[20, 45, 48, 26, 35],
       [42, 44, 20, 22,  4],
       [37, 10, 17, 39, 14],
       ...,
       [44, 36, 28, 31,  2],
       [39, 33,  8, 14, 34],
       [ 2,  5, 26,  7, 12]])

In [32]:
analyzer.analyze_subset(conditions={}, n =100)

{'n_records': 2347,
 'n_records_R': 1382,
 'n_records_D': 959,
 'js': {'mean': 0.8929648605424753,
  'lower': 0.8622081332272088,
  'upper': 0.9237215878577419},
 'js_R': {'mean': 0.8918299181276494,
  'lower': 0.8616113843204268,
  'upper': 0.9220484519348716},
 'js_D': {'mean': 0.8911728389681157,
  'lower': 0.8629463573982648,
  'upper': 0.9193993205379667},
 'js_RD': {'mean': 0.8984762036662614,
  'lower': 0.86969344624948,
  'upper': 0.9272589610830427},
 'hh': {'mean': 0.3929731249809265,
  'lower': 0.3409234876789598,
  'upper': 0.44502267406801893},
 'hh_R': {'mean': 0.3422833979129791,
  'lower': 0.3015297350782954,
  'upper': 0.38303707460574277},
 'hh_D': {'mean': 0.412963330745697,
  'lower': 0.361676728024675,
  'upper': 0.4642500619147285}}

In [33]:
analyzer.topic_preds[0]

array([4.5464397e-03, 5.9141582e-03,           nan, 3.3592423e-03,
       4.1585718e-03, 4.6430629e-02, 4.1875597e-03, 1.9036815e-02,
       9.0436758e-03, 7.0622428e-03, 8.2960709e-05, 3.4936439e-04,
       1.9014528e-02, 1.0048404e-02, 8.9517375e-03, 3.3230886e-03,
       6.1772432e-04, 4.3901094e-04, 9.3819568e-04, 3.9070428e-04,
       1.1351042e-03, 1.2289655e-03, 8.3495874e-04, 1.4976026e-04,
       6.4706677e-03, 1.0228754e-02, 2.6048886e-02, 3.7255071e-03,
       1.2589485e-02, 3.1999624e-04, 9.3538733e-04, 4.6058921e-03,
       1.9303185e-04, 4.7712303e-03, 1.5027004e-02, 8.6307497e-04,
       1.9048254e-03, 6.8330963e-05, 1.1928496e-03, 3.1921931e-04,
       1.1857332e-02, 1.4633985e-03, 9.5607882e-04, 2.4262549e-04,
       5.8718733e-03, 1.1712279e-02, 4.0127989e-03, 2.5536839e-03,
       3.2589883e-03, 4.3265073e-04], dtype=float32)

In [34]:
analyzer.topic_preds[1].round(3)

array([0.001, 0.   , 0.001, 0.   , 0.   , 0.001, 0.   , 0.004, 0.018,
       0.   , 0.001, 0.   , 0.002, 0.   , 0.014, 0.   , 0.   , 0.   ,
       0.002, 0.001, 0.   , 0.004, 0.   , 0.   , 0.003, 0.004, 0.001,
       0.001, 0.002, 0.   , 0.   , 0.   , 0.001, 0.15 , 0.011, 0.   ,
       0.001, 0.001, 0.004, 0.76 , 0.001, 0.   , 0.   , 0.   , 0.   ,
       0.   , 0.001, 0.005, 0.   , 0.001], dtype=float32)

In [35]:
analyzer.compute_JS(index_A = [1, 0, 2, 55], index_B = [0, 1, 2, 3])

{'mean': 0.22723024165165603,
 'lower': -0.495917801341507,
 'upper': 0.9503782846448191}

In [36]:
analyzer.topic_use_RD_js()

0.02473030078269329

In [37]:
analyzer.analyze_subset(conditions={}, n=200)

{'n_records': 2347,
 'n_records_R': 1382,
 'n_records_D': 959,
 'js': {'mean': 0.9029303002275355,
  'lower': 0.8860208753153576,
  'upper': 0.9198397251397135},
 'js_R': {'mean': 0.9040589519613792,
  'lower': 0.8913284254532201,
  'upper': 0.9167894784695385},
 'js_D': {'mean': 0.9081885013289568,
  'lower': 0.8957064633162953,
  'upper': 0.9206705393416182},
 'js_RD': {'mean': 0.9084532652511502,
  'lower': 0.8944668450011393,
  'upper': 0.922439685501161},
 'hh': {'mean': 0.4090844988822937,
  'lower': 0.37325187809312105,
  'upper': 0.44491703528147303},
 'hh_R': {'mean': 0.4018021523952484,
  'lower': 0.3672640527670878,
  'upper': 0.43634022505230724},
 'hh_D': {'mean': 0.35837921500205994,
  'lower': 0.3235949378712012,
  'upper': 0.39316348460783224}}

In [288]:
analyzer.group_hh

{'mean': nan, 'lower': nan, 'upper': nan}

In [38]:
analyzer.n_nan_preds()

6