# Using Naive Bayes to Predict What State You're From (If You Live in India)

In [1]:
%matplotlib inline
from math import prod
import numpy as np
import pandas as pd
import requests
import ipywidgets as wg
from matplotlib import pyplot as plt
from matplotlib.ticker import PercentFormatter

# Root of the IHSN India Census tables API; endpoint paths are appended to it.
base_url = ('http://digital-library.census.ihsn.org'
            '/index.php/api/tables/')
# pandas helper for slicing MultiIndexed objects
idx = pd.IndexSlice

## Using the India Census API

I keep some supporting metadata in a CSV file (`data/PC11_TV_DIR.csv`). The India Census API uses numeric codes for many of its fields. For some fields, the code labels are returned alongside each request. Others (state, district, sub-district, and town/village) are common to every table, so their labels are looked up from this metadata instead.

In [2]:
# Create maps for state, district, 
# subdistrict, and town-village
# 0 means ALL
sdsdtv = pd.read_csv('data/PC11_TV_DIR.csv')

# States map
states = sdsdtv.loc[ (sdsdtv['District Code'] == 0) \
                    & (sdsdtv['Sub District Code'] == 0) \
                    & (sdsdtv['Town-Village Code'] == 0) ]
states = states[[
    'State Code', 
    'Town-Village Name'
]]
states = states.set_index('State Code')
states = states['Town-Village Name']
states.at[0] = 'ALL'
states = states.sort_index()
states = states.rename('State Name')

# Districts map
districts = sdsdtv.loc[ (sdsdtv['Sub District Code'] == 0) \
                    & (sdsdtv['Town-Village Code'] == 0) ]
districts = districts[[ 'District Code', 'Town-Village Name' ]]
districts = districts.drop_duplicates(subset=['District Code'])
districts = districts.set_index('District Code')
districts = districts['Town-Village Name']
districts.at[0] = 'All'
districts = districts.sort_index()
districts = districts.rename('District Name')

# Sub Districts map
sub_districts = sdsdtv.loc[ (sdsdtv['Town-Village Code'] == 0) ]
sub_districts = sub_districts[[
    'Sub District Code',
    'Town-Village Name'
]]
sub_districts = sub_districts.drop_duplicates(subset=['Sub District Code'])
sub_districts = sub_districts.set_index('Sub District Code')
sub_districts = sub_districts['Town-Village Name']
sub_districts.at[0] = 'All'
sub_districts = sub_districts.sort_index()
sub_districts = sub_districts.rename('Sub District Name')

# Town/Villages map
town_villages = sdsdtv[[ 'Town-Village Code', 'Town-Village Name' ]]
town_villages = town_villages.drop_duplicates(subset=['Town-Village Code'])
town_villages = town_villages.set_index('Town-Village Code')
town_villages = town_villages['Town-Village Name']
town_villages.at[0] = 'All'
town_villages = town_villages.sort_index()
town_villages = town_villages.rename('Sub District Name')

I created a simple interface to pull just the relevant data from the API. That way I don't have to download and store a large dataset just to work with a small part of it.

In [3]:
def get_dataset(table, **kwargs):
    """
    Call the India Census API with args
    to retrieve filtered data.

    Code-valued columns in the response (the table's own features plus
    state/district/subdistrict/town) are replaced by their labels.

    :param table:    Census API table to query (e.g. 'PC11_C01')
    :param **kwargs: Query parameters to send to the API. ``year``
                     (default '2011') is consumed here and placed in the
                     URL path rather than the query string.

    :returns: data retrieved from API as pandas DataFrame
    :raises requests.HTTPError: if either API request returns an error status
    :raises KeyError: if the response contains a code with no known label
    """
    from urllib.parse import urlencode

    timeout_s = 60  # seconds; stops a stalled request from hanging the kernel
    # base_url ends with '/', and the paths below start with one
    api_root = base_url.rstrip('/')

    # Common features in all tables, labelled from the local directory CSV
    common_features = {
        'state': states,
        'district': districts,
        'subdistrict': sub_districts,
        'town': town_villages
    }

    # Get info specific to dataset (features)
    year = kwargs.pop('year', '2011')
    info_resp = requests.get(f'{api_root}/info/{year}/{table}', timeout=timeout_s)
    info_resp.raise_for_status()
    features = info_resp.json()['result']['result_']['features']
    feature_map = {
        feature['feature_name']: pd.Series(
            data=[c['label'] for c in feature['code_list']],
            index=[c['code'] for c in feature['code_list']])
        for feature in features }
    feature_map |= common_features

    # Get data. Values are URL-encoded, but commas are left literal so
    # multi-value args like fields='state,value' go over the wire unchanged.
    data_resp = requests.get(f'{api_root}/data/{year}/{table}',
                             params=urlencode(kwargs, safe=','),
                             timeout=timeout_s)
    data_resp.raise_for_status()

    # Map codes to features in dataset
    df = pd.DataFrame(data_resp.json()['data'])
    for feature, labels in feature_map.items():
        if feature in df:
            # Bind `labels` now; .at raises KeyError on an unknown code
            df[feature] = df[feature].map(lambda code, labels=labels: labels.at[code])

    return df

## Getting Prior Probabilities

Get the total population of India and the population of each state.

In [10]:
# Total population of India (state code 0 = ALL)
total_population = get_dataset('PC11_C01',
                               state='0',
                               urbrur='0',
                               geo_level='0',
                               sex='0',
                               religion='0',
                               fields='value')['value'].at[0]
print('Total Population:', total_population)

# Population of each state / union territory (codes 1-35)
state_populations = (
    get_dataset('PC11_C01',
                state='1-35',
                urbrur='0',
                geo_level='1',
                sex='0',
                religion='0',
                fields='state,value')
    .set_index('state')['value']
    # These are head counts, not priors; priors are derived from them later
    .rename('population')
)
state_populations

Total Population: 1210854977


state
JAMMU & KASHMIR               12541302
HIMACHAL PRADESH               6864602
PUNJAB                        27743338
CHANDIGARH                     1055450
UTTARAKHAND                   10086292
HARYANA                       25351462
NCT OF DELHI                  16787941
RAJASTHAN                     68548437
UTTAR PRADESH                199812341
BIHAR                        104099452
SIKKIM                          610577
ARUNACHAL PRADESH              1383727
NAGALAND                       1978502
MANIPUR                        2855794
MIZORAM                        1097206
TRIPURA                        3673917
MEGHALAYA                      2966889
ASSAM                         31205576
WEST BENGAL                   91276115
JHARKHAND                     32988134
ODISHA                        41974218
CHHATTISGARH                  25545198
MADHYA PRADESH                72626809
GUJARAT                       60439692
DAMAN & DIU                     243247
DADRA & NAGAR HAVEL

Compute each state's population as a fraction of the total population. This is our prior probability, $P(\text{state})$.

In [11]:
# Prior P(state): each state's share of India's total population
state_priors = (state_populations / total_population).rename('state_prior')
state_priors

state
JAMMU & KASHMIR              0.010357
HIMACHAL PRADESH             0.005669
PUNJAB                       0.022912
CHANDIGARH                   0.000872
UTTARAKHAND                  0.008330
HARYANA                      0.020937
NCT OF DELHI                 0.013865
RAJASTHAN                    0.056612
UTTAR PRADESH                0.165018
BIHAR                        0.085972
SIKKIM                       0.000504
ARUNACHAL PRADESH            0.001143
NAGALAND                     0.001634
MANIPUR                      0.002358
MIZORAM                      0.000906
TRIPURA                      0.003034
MEGHALAYA                    0.002450
ASSAM                        0.025772
WEST BENGAL                  0.075382
JHARKHAND                    0.027244
ODISHA                       0.034665
CHHATTISGARH                 0.021097
MADHYA PRADESH               0.059980
GUJARAT                      0.049915
DAMAN & DIU                  0.000201
DADRA & NAGAR HAVELI         0.000284
MAHARA

## Computing Likelihoods for Given Evidence Fields

### Odds based on Religion

Get the religious makeup of the country as a whole, as well as of each state.

In [21]:
# Evidence P(religion): each religion's share of India's total population
religion_evidence = (
    get_dataset('PC11_C01',
                state='0',
                sex='0',
                urbrur='0',
                religion='1-6',
                fields='religion,value')
    .set_index('religion')['value']
    # Don't reuse the index name 'religion' as the Series name:
    # Series.reset_index() raises ValueError on that collision.
    .rename('religion_evidence')
    / total_population
)
religion_evidence

religion
Hindu        0.797996
Muslim       0.142251
Christian    0.022975
Sikh         0.017205
Buddhist     0.006973
Jain         0.003677
Name: religion, dtype: float64

In [20]:
# Likelihood P(religion | state): each religion's share of a state's population
religion_likelihood_given_state = get_dataset(
    'PC11_C01',
    sex='0',
    urbrur='0',
    religion='1-6',
    state='1-35',
    district='0',
    fields='religion,state,value',
    limit=6 * 35,  # 6 religions x 35 states/UTs = every row
).set_index(['religion', 'state'])['value']
# Each (religion, state) count is divided by that state's population
religion_likelihood_given_state /= state_populations
religion_likelihood_given_state

religion  state                    
Hindu     JAMMU & KASHMIR              0.284394
          HIMACHAL PRADESH             0.951660
          PUNJAB                       0.384890
          CHANDIGARH                   0.807782
          UTTARAKHAND                  0.829704
                                         ...   
Jain      LAKSHADWEEP                  0.000171
          KERALA                       0.000134
          TAMIL NADU                   0.001237
          PUDUCHERRY                   0.001122
          ANDAMAN & NICOBAR ISLANDS    0.000081
Name: value, Length: 210, dtype: float64

## Building and Testing Model

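The naive Bayes model. The posterior probability of each state, given the evidence, is

$$P(\text{state} \mid \text{religion}) = P(\text{state}) \cdot \frac{P(\text{religion} \mid \text{state})}{P(\text{religion})}$$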

In [15]:
def prediction_model(religion):
    """
    Naive Bayes posterior over states given the user's religion.

    P(state | evidence) is proportional to
    P(state) * prod_i P(evidence_i | state) / P(evidence_i)

    The product is renormalised to sum to 1. With several evidence
    features, the product of marginal evidences is not the joint
    evidence, so the raw product is generally not a proper distribution.
    With the single religion feature the raw sum is already close to 1.

    :param religion: religion label, e.g. 'Hindu' (an index value of
                     religion_evidence)
    :returns: Series of posterior probabilities indexed by state
    :raises KeyError: if ``religion`` is not present in the census data
    """
    unnormalised = state_priors * prod([
        religion_likelihood_given_state[religion] / religion_evidence[religion],
    ])
    return unnormalised / unnormalised.sum()

Here's an interactive tool that predicts which state you're most likely from, given your religion.

In [16]:
n_results = 5

# Offer every religion the census queries returned (this includes
# Buddhist, which a hand-written list had left out), in census order
religion_select = wg.Dropdown(options=list(religion_evidence.index),
                              value='Hindu',
                              description='Religion:',
                              disabled=False)
form = wg.VBox([ religion_select ])
graph_output = wg.Output()

def display_results(evt=None):
    """
    Redraw the bar chart of the ``n_results`` most likely states
    for the currently selected religion.

    :param evt: change notification passed by ``observe`` (unused)
    """
    state_posterior = prediction_model(religion=religion_select.value)
    # nlargest skips NaN, unlike slicing the tail of sort_values(), which
    # sorts NaN last; re-sort ascending so barh draws the largest on top
    top_n = state_posterior.nlargest(n_results).sort_values()
    graph_output.clear_output()
    with graph_output:
        xticks = PercentFormatter(xmax=1, decimals=0)
        ax = top_n.plot.barh(xlabel='')
        ax.xaxis.set_major_formatter(xticks)
        ax.set_title(f'Most likely states for: {religion_select.value}')
        plt.show();

display(wg.HBox([ form, graph_output ]))
display_results()
religion_select.observe(display_results, 'value')

HBox(children=(VBox(children=(Dropdown(description='Religion:', options=('Hindu', 'Muslim', 'Christian', 'Sikh…