In [7]:
#@title
# How to update the data:
# 1. You may need to copy this colab so you have your own version.
# 2. Update the table name constants below to have the latest data's suffix.
# 3. Update the date variables below to be the last case date included in the data.
# 4. The scatterplot max/min below in chart settings may need to be updated for more cases.
# 5. There are a few checks for the county_fips_mapping that we created due to issues with the CDC's.
#    Instructions are at https://docs.google.com/spreadsheets/d/1AVSSge7BpkbNL4PfumUZpL7hokMLjKUojtamQjNW6f0/edit?resourcekey=0-Abdprx3fy_pXikSCDV2hxw#gid=967935006.
# 6. Many/all of the tables and text are not auto-updated. If you want to do a full update of
#    the paper, including text and tables, much of that is done in commented-out PrintSummaryStats() statements.

import pandas as pd
import altair as alt
from vega_datasets import data

from google.colab import auth
auth.authenticate_user()

# Display options: render floats with thousands separators and two decimals,
# and hide the three-dot action menu on Altair/Vega charts.
pd.options.display.float_format = '{:,.2f}'.format
alt.renderers.set_embed_options(actions=False)

# BigQuery project and the backtick-quoted CDC table inside it.
PROJECT_ID = 'msm-secure-data-1b'
CDC_TABLE = '`%s.ndunlap_secure.cdc_restricted_access_20210430`' % (PROJECT_ID,)

# The same cut-off date twice: as a chart label and as a BigQuery DATE
# expression that is interpolated into the NYT query.
DATE_DISPLAY_NAME = 'Apr 15'
DATE = 'DATE(2021, 04, 15)'

# Upper bound of both scatterplot axes, chosen to better handle outliers
# (CA, Los Angeles).
TOTAL_CASES_SCALE_MAX = 70000

# Chart dimensions.
SCATTER_HEIGHT = SCATTER_WIDTH = MAP_HEIGHT = 300
MAP_WIDTH = 450

# TopoJSON features used as the map geometry.
# NOTE(review): the "#" suffix presumably makes the counties source URL differ
# from the states one so the two are treated as separate datasets -- confirm.
US_STATES_TOPO = alt.topo_feature(data.us_10m.url, 'states')
US_COUNTIES_TOPO = alt.topo_feature(data.us_10m.url + "#", 'counties')

# Territories, by postal abbreviation and by the names used in the NYT data.
TERRITORIES = ('PR', 'GU', 'VI', 'MP', 'AS')
NYT_TERRITORIES = ('Puerto Rico', 'Guam', 'Virgin Islands', 'Northern Mariana Islands', 'American Samoa')

# Postal abbreviation -> state FIPS code. Each key appears exactly once: the
# previous literal listed 'AS', 'GU', 'PR' and 'VI' twice, and only the later
# values (60, 66, 72, 78) ever took effect, so those are the ones kept here.
# 'NYC' is an alias that maps onto New York's code.
STATES_TO_FIPS = {
    'AL': 1, 'AK': 2, 'AZ': 4, 'AR': 5, 'AS': 60, 'CA': 6, 'CO': 8, 'CT': 9,
    'DC': 11, 'DE': 10, 'FL': 12, 'GA': 13, 'GU': 66, 'HI': 15, 'ID': 16,
    'IL': 17, 'IN': 18, 'IA': 19, 'KS': 20, 'KY': 21, 'LA': 22, 'ME': 23,
    'MD': 24, 'MA': 25, 'MI': 26, 'MN': 27, 'MS': 28, 'MO': 29, 'MT': 30,
    'NE': 31, 'NV': 32, 'NH': 33, 'NJ': 34, 'NM': 35, 'NY': 36, 'NC': 37,
    'ND': 38, 'OH': 39, 'OK': 40, 'OR': 41, 'PA': 42, 'PR': 72, 'RI': 44,
    'SC': 45, 'SD': 46, 'TN': 47, 'TX': 48, 'UT': 49, 'VT': 50, 'VA': 51,
    'VI': 78, 'WA': 53, 'WV': 54, 'WI': 55, 'WY': 56, 'MP': 69, 'NYC': 36,
}

# State FIPS code -> postal abbreviation. The 'NYC' alias is skipped so that
# code 36 resolves to 'NY'; a plain inversion let the later 'NYC' entry win.
FIPS_TO_STATES = {
    fips: abbreviation
    for abbreviation, fips in STATES_TO_FIPS.items()
    if abbreviation != 'NYC'
}

# Full state name -> postal abbreviation (50 states plus DC).
STATES_TO_ABBREVIATIONS = {
    'Alabama': 'AL', 'Alaska': 'AK', 'Arizona': 'AZ', 'Arkansas': 'AR',
    'California': 'CA', 'Colorado': 'CO', 'Connecticut': 'CT',
    'Delaware': 'DE', 'District of Columbia': 'DC', 'Florida': 'FL',
    'Georgia': 'GA', 'Hawaii': 'HI', 'Idaho': 'ID', 'Illinois': 'IL',
    'Indiana': 'IN', 'Iowa': 'IA', 'Kansas': 'KS', 'Kentucky': 'KY',
    'Louisiana': 'LA', 'Maine': 'ME', 'Maryland': 'MD',
    'Massachusetts': 'MA', 'Michigan': 'MI', 'Minnesota': 'MN',
    'Mississippi': 'MS', 'Missouri': 'MO', 'Montana': 'MT', 'Nebraska': 'NE',
    'Nevada': 'NV', 'New Hampshire': 'NH', 'New Jersey': 'NJ',
    'New Mexico': 'NM', 'New York': 'NY', 'North Carolina': 'NC',
    'North Dakota': 'ND', 'Ohio': 'OH', 'Oklahoma': 'OK', 'Oregon': 'OR',
    'Pennsylvania': 'PA', 'Rhode Island': 'RI', 'South Carolina': 'SC',
    'South Dakota': 'SD', 'Tennessee': 'TN', 'Texas': 'TX', 'Utah': 'UT',
    'Vermont': 'VT', 'Virginia': 'VA', 'Washington': 'WA',
    'West Virginia': 'WV', 'Wisconsin': 'WI', 'Wyoming': 'WY',
}

In [2]:
#@title
# Template for the per-state NYT query. The single %s placeholder receives a
# BigQuery date expression. Note that the NYT `deaths` column is aliased to
# `nyt_cases`, so downstream "cases" columns in this notebook hold deaths.
NYT_STATES_QUERY_STR = '''
SELECT
  state_name,
  state_fips_code,
  deaths as nyt_cases,
FROM `bigquery-public-data.covid19_nyt.us_states`
WHERE
  date = %s AND
  state_fips_code IS NOT NULL
'''

# The NYT query pinned to the notebook's cut-off date.
NYT_STATES_QUERY = NYT_STATES_QUERY_STR % (DATE,)

# Number of CDC records with death_yn = 'Yes' per state of residence. The
# count is named `cdc_cases` even though it counts deaths.
CDC_STATES_QUERY = '''
SELECT
  res_state,
  COUNT(*) as cdc_cases
FROM
  %s
WHERE
  death_yn = 'Yes'
GROUP BY
   res_state
''' % (CDC_TABLE,)

In [3]:
#@title
def FieldAnalysis(project_id, table, field_list, title, calculate_race_ethnicity=False):
  """Charts, for each field, how many death records have a usable value.

  One BigQuery query per field groups the records with death_yn = 'Yes' by
  the field's value. The counts are then split into four buckets:
    * Missing:    'Missing', NULL, '', 'OTH' or 'nul'
    * Unknown:    'Unknown'
    * Suppressed: 'NA'
    * Known:      every other value

  Args:
    project_id: Google Cloud project the queries are run under.
    table: Table reference interpolated into each query's FROM clause.
    field_list: Column names to analyze.
    title: Title of the returned chart.
    calculate_race_ethnicity: If True, the 'race_ethnicity_combined' field is
      derived from the `race` and `ethnicity` columns (race for
      "Non-Hispanic/Latino" records, ethnicity otherwise) instead of being
      read as a column.

  Returns:
    An Altair stacked bar chart with one bar per field, showing the share of
    records in each bucket.
  """
  field_display_name = {
      'cdc_case_earliest_dt': 'CDC earliest case date',
      'current_status': 'Case status',
      'res_state': 'State',
      'res_county': 'County',
      'sex': 'Sex',
      'age_group': 'Age',
      'race_ethnicity_combined': 'Race/Ethnicity',
      'race': 'Race',
      'ethnicity': 'Ethnicity',
      'case_month': 'Case month'
  }
  # Values that all count as "Missing"; 'Null' is the label given to the
  # NULL group below.
  missing_labels = ['Missing', 'Null', '', 'OTH', 'nul']
  # Tooltip counts are shown in thousands.
  chart_denominator = 1000

  field_unknowns_query = ('''
    SELECT
      %s,
      count(*) as cases
    FROM
      %s
    WHERE
      death_yn = 'Yes'
    GROUP BY
      %s
    ''')
  # Same death filter as the default query: without it this branch counted
  # every record while the chart labels the numbers as deaths.
  race_ethnicity_query = ('''
    SELECT
      CASE ethnicity = "Non-Hispanic/Latino"
        WHEN true THEN race
        ELSE ethnicity
      END as %s,
      count(*) as cases
    FROM
      %s
    WHERE
      death_yn = 'Yes'
    GROUP BY
      %s
    ''')

  field_series = []
  value_series = []
  percent_series = []
  cases_series = []

  for field in field_list:
    if calculate_race_ethnicity and field == 'race_ethnicity_combined':
      query_template = race_ethnicity_query
    else:
      query_template = field_unknowns_query
    query = query_template % (field, table, field)

    field_unknowns_df = pd.io.gbq.read_gbq(query, project_id=project_id)
    field_unknowns_df = field_unknowns_df.set_index(field)
    field_unknowns_df.index = field_unknowns_df.index.fillna('Null')

    # Summing over a membership mask handles absent labels (sum is 0) and
    # duplicated labels (e.g. a NULL group next to a literal 'Null') alike.
    values = field_unknowns_df.index
    cases = field_unknowns_df.cases
    total = float(cases.sum())
    missing = float(cases[values.isin(missing_labels)].sum())
    unknown = float(cases[values.isin(['Unknown'])].sum())
    suppressed = float(cases[values.isin(['NA'])].sum())
    known = total - (missing + unknown + suppressed)

    # A field with no rows has undefined shares; NaN avoids dividing by zero.
    share_denominator = total if total else float('nan')
    # The bucket order must be the same in all three lists below.
    buckets = [known, suppressed, unknown, missing]
    field_series.extend([field_display_name.get(field, field)] * 4)
    value_series.extend(['Known', 'Suppressed', 'Unknown', 'Missing'])
    percent_series.extend([count / share_denominator for count in buckets])
    cases_series.extend([count / chart_denominator for count in buckets])

  # Built once after the loop, so it also exists when field_list is empty.
  bars = pd.DataFrame.from_dict({'field': field_series,
                                 'value': value_series,
                                 'percent': percent_series,
                                 'cases': cases_series})
  return alt.Chart(bars).mark_bar().encode(
      x=alt.X('percent', axis=alt.Axis(format='%'), title=''),
      y=alt.Y('field', sort='x', title='Field'),
      color=alt.Color('value', scale=alt.Scale(scheme='category20'), title='Value'),
      order=alt.Order('field:N'),
      tooltip=[
                  alt.Tooltip('field:N', title='Field'),
                  alt.Tooltip('value:N', title='Value'),
                  alt.Tooltip('percent:Q', format=',.0%', title='Percent'),
                  alt.Tooltip('cases:Q', format=',.2f', title='Deaths in group (thousands)'),
      ]
  ).properties(title=title)

def CreateNYTStateDataframe(query):
  """Runs `query` and returns the NYT state rows indexed by integer FIPS code.

  Rows whose `state_name` is listed in NYT_TERRITORIES are dropped.

  Args:
    query: SQL whose result has `state_name` and `state_fips_code` columns.

  Returns:
    A DataFrame indexed by `state_fips_code` (converted to int).
  """
  nyt_states_df = pd.io.gbq.read_gbq(query, project_id=PROJECT_ID)
  is_territory = nyt_states_df.state_name.isin(NYT_TERRITORIES)
  # Derive a new frame instead of assigning a column into a filtered slice,
  # which is the pattern pandas flags with SettingWithCopyWarning.
  return (
      nyt_states_df[~is_territory]
      .assign(state_fips_code=lambda df: df.state_fips_code.astype(int))
      .set_index('state_fips_code')
  )

def CreateCDCStateDataframe(query):
  """Runs `query` and returns the CDC state rows indexed by integer FIPS code.

  Rows whose `res_state` is 'Unknown', 'NA', 'Missing' or 'OCONUS' are
  dropped, `res_state` is renamed to `state`, and each abbreviation is
  translated to its FIPS code through STATES_TO_FIPS.

  Args:
    query: SQL whose result has a `res_state` column.

  Returns:
    A DataFrame indexed by `state_fips_code` (int).

  Raises:
    ValueError: If a remaining `res_state` value has no entry in
      STATES_TO_FIPS.
  """
  states_df = pd.io.gbq.read_gbq(query, project_id=PROJECT_ID)
  non_state_values = ('Unknown', 'NA', 'Missing', 'OCONUS')
  states_df = (
      states_df[~states_df.res_state.isin(non_state_values)]
      .rename(columns={'res_state': 'state'})
  )

  fips_codes = states_df.state.map(STATES_TO_FIPS)
  unmapped = states_df.state[fips_codes.isna()]
  if not unmapped.empty:
    # Name the offending values instead of failing inside astype(int).
    raise ValueError(
        'No FIPS code known for res_state value(s): %s'
        % sorted(unmapped.astype(str).unique()))

  # NOTE(review): 'NY' and 'NYC' both map to 36, so if the data contains both,
  # the returned index has two rows for 36 and a join on it will repeat the
  # matching row -- confirm whether 'NYC' occurs in the current table.
  return (
      states_df
      .assign(state_fips_code=fips_codes.astype(int))
      .set_index('state_fips_code')
  )

def CreateScatterPlot(
    chart_df, fields_dict, title, scale_max, height, width, geo, metric_type):
  """Builds an interactive scatterplot of two counts, colored by a metric.

  Args:
    chart_df: DataFrame holding the columns named in `fields_dict` plus a
      `state` column (or `state_county` when `geo` is 'county').
    fields_dict: Dict with the keys 'x', 'y' and 'percent'; each value is a
      dict with the column 'name', a number 'format' and a display 'title'.
    title: Chart title.
    scale_max: Upper bound of both axes (the lower bound is 0).
    height: Chart height.
    width: Chart width.
    geo: 'county' for county-level data; any other value means state-level.
    metric_type: 'ratio' (color domain 0-2, diagonal reference line) or
      'percent' (color domain 0-1, reference rules at 0.5).

  Returns:
    The layered, interactive Altair chart (points plus reference lines).

  Raises:
    ValueError: If `metric_type` is neither 'ratio' nor 'percent'.
  """
  if metric_type == 'ratio':
    scale_scheme = 'blueorange'
    scale_reverse = True
    scale_domain = [0, 2]
    legend_format = '.1f'
    axis_format = ',.0f'
  elif metric_type == 'percent':
    scale_scheme = 'redyellowblue'
    scale_reverse = False
    scale_domain = [0, 1]
    legend_format = '.0%'
    axis_format = '.0%'
  else:
    # Fail up front instead of with an UnboundLocalError further down.
    raise ValueError(
        "metric_type must be 'ratio' or 'percent', got %r" % (metric_type,))

  geo_field = 'state'
  geo_field_display_name = 'State'
  if geo == 'county':
    geo_field = 'state_county'
    geo_field_display_name = 'County'

  tooltips = [alt.Tooltip(geo_field + ':N', title=geo_field_display_name)]
  for field in ('y', 'x', 'percent'):
    tooltips.append(alt.Tooltip(
        fields_dict[field]['name'] + ':Q',
        format=fields_dict[field]['format'],
        title=fields_dict[field]['title'],
    ))
  plot = alt.Chart(chart_df).mark_circle(size=60).encode(
      alt.X(fields_dict['x']['name'] + ':Q', axis=alt.Axis(title=fields_dict['x']['title'], format=axis_format),
          scale=alt.Scale(domain=(0, scale_max))
      ),
      alt.Y(fields_dict['y']['name'] + ':Q', axis=alt.Axis(title=fields_dict['y']['title'], format=axis_format),
          scale=alt.Scale(domain=(0, scale_max))
      ),
      color=alt.Color(fields_dict['percent']['name'],
                      type='quantitative',
                      scale=alt.Scale(scheme=scale_scheme,
                                      reverse=scale_reverse,
                                      domain=scale_domain,
                                      clamp=True),
                      legend=alt.Legend(format=legend_format),
                      title=metric_type.capitalize()),
      tooltip=tooltips,
  ).properties(
      height=height,
      width=width,
  )
  # The former bare `plot.interactive()` call was dropped: its return value
  # was discarded, so it had no effect on the chart that is returned.

  if metric_type == 'ratio':
    # Diagonal where both counts are equal.
    line = pd.DataFrame({
        'x': [0, scale_max],
        'y': [0, scale_max],
    })
    line_plot = alt.Chart(line).mark_line(color='black').encode(
        x='x',
        y='y',
    )
  else:
    # Horizontal and vertical rules at 0.5.
    line_plot = (
        alt.Chart(pd.DataFrame({'x': [.5]})).mark_rule().encode(y='x') +
        alt.Chart(pd.DataFrame({'y': [.5]})).mark_rule().encode(x='y')
    )
  # Add interactive for concatenating due to https://github.com/altair-viz/altair/issues/2010.
  scatter = (plot + line_plot).properties(
      title=title,
      height=height,
      width=width,
  ).interactive()
  return scatter

def CreateMap(
    chart_df, fields_dict, title, scale_max, height, width, geo, metric_type):
  """Builds a US choropleth map colored by the 'percent' field.

  Args:
    chart_df: DataFrame holding the columns named in `fields_dict`, a `state`
      column and a `state_fips_code` column (or `state_county` and
      `county_fips` when `geo` is 'county').
    fields_dict: Dict with the keys 'x', 'y' and 'percent'; each value is a
      dict with the column 'name', a number 'format' and a display 'title'.
    title: Chart title.
    scale_max: Unused; kept so the signature matches CreateScatterPlot.
    height: Chart height.
    width: Chart width.
    geo: 'county' for a county-level map; any other value means state-level.
    metric_type: 'ratio' (color domain 0-2) or 'percent' (color domain 0-1).

  Returns:
    A layered Altair chart: a silver base map of the states, the colored
    shapes, and white state outlines on top.

  Raises:
    ValueError: If `metric_type` is neither 'ratio' nor 'percent'.
  """
  if metric_type == 'ratio':
    scale_scheme = 'blueorange'
    scale_reverse = True
    scale_domain = [0, 2]
    legend_format = '.1f'
  elif metric_type == 'percent':
    scale_scheme = 'redyellowblue'
    scale_reverse = False
    scale_domain = [0, 1]
    legend_format = '.0%'
  else:
    # Fail up front instead of with an UnboundLocalError further down.
    raise ValueError(
        "metric_type must be 'ratio' or 'percent', got %r" % (metric_type,))

  geo_field = 'state'
  geo_field_display_name = 'State'
  fips_code = 'state_fips_code'
  topo_feature = US_STATES_TOPO
  if geo == 'county':
    geo_field = 'state_county'
    geo_field_display_name = 'County'
    fips_code = 'county_fips'
    topo_feature = US_COUNTIES_TOPO

  highlight = alt.selection_single(on='mouseover', fields=['id', fips_code], empty='none')
  tooltips = [alt.Tooltip(geo_field + ':N', title=geo_field_display_name)]
  for field in ('y', 'x', 'percent'):
    tooltips.append(alt.Tooltip(
        fields_dict[field]['name'] + ':Q',
        format=fields_dict[field]['format'],
        title=fields_dict[field]['title'],
    ))

  # Columns copied from chart_df onto each shape by the lookup below.
  field_names = [geo_field]
  field_names.extend([fields_dict[field]['name'] for field in fields_dict])
  plot = alt.Chart(topo_feature).mark_geoshape(
      stroke='white',
      strokeOpacity=.2,
      strokeWidth=1
  ).project(
      type='albersUsa'
  ).transform_lookup(
      # Match each shape's `id` against the FIPS column of chart_df.
      lookup='id',
      from_=alt.LookupData(chart_df, fips_code, field_names)
  ).encode(
      alt.Color(fields_dict['percent']['name'],
                type='quantitative',
                legend=alt.Legend(format=legend_format),
                scale=alt.Scale(scheme=scale_scheme,
                                reverse=scale_reverse,
                                domain=scale_domain,
                                clamp=True,
                                ),
                title=metric_type.capitalize()),
      tooltip=tooltips
  ).add_selection(
      highlight,
  )

  states_outline = alt.Chart(US_STATES_TOPO).mark_geoshape(
      stroke='white', strokeWidth=1.5, fillOpacity=0, fill='white'
  ).project(
      type='albersUsa'
  )

  states_fill = alt.Chart(US_STATES_TOPO).mark_geoshape(
      fill='silver',
      stroke='white'
  ).project('albersUsa')

  layered_map = alt.layer(states_fill, plot, states_outline).properties(
      height=height,
      width=width,
      title=title,
  )
  return layered_map

def CreateScatterPlotAndMap(
    chart_df, fields_dict, title, total_cases_scale_max, scatter_height, scatter_width, map_width, geo, metric_type):
  """Returns the scatterplot and the map for `chart_df` side by side.

  Both charts receive the same data, field description, title, scale maximum
  and height (`scatter_height`); only their widths differ.
  """
  # Leading positional arguments that both chart builders take unchanged.
  shared_args = (chart_df, fields_dict, title, total_cases_scale_max, scatter_height)
  scatter_chart = CreateScatterPlot(*shared_args, scatter_width, geo, metric_type)
  map_chart = CreateMap(*shared_args, map_width, geo, metric_type)

  combined = scatter_chart | map_chart
  combined = combined.configure_view(strokeWidth=0)
  combined = combined.configure_mark(stroke='grey')
  # Legend gradient is 50 shorter than the chart height.
  return combined.configure_legend(gradientLength=scatter_height - 50)

def PrintSummaryStats(chart_df, field='percent'):
  """Prints how the values of `chart_df[field]` are spread around 1.0.

  Printed lines: the number and share of rows within [0.85, 1.15] and within
  [0.50, 1.50], the number of rows below 0.5 and above 1.5, and the output of
  `describe()` for the column.

  Args:
    chart_df: DataFrame containing the column `field`.
    field: Name of the column to summarize.
  """
  values = chart_df[field]
  total = len(chart_df)

  def share(count):
    # An empty frame has no meaningful share; NaN avoids ZeroDivisionError.
    return round(count / total, 2) if total else float('nan')

  # Series.between is inclusive on both ends, like the >= / <= pair it replaces.
  within_15 = int(values.between(.85, 1.15).sum())
  within_50 = int(values.between(.50, 1.50).sum())
  print('between +/-15%: ', within_15, share(within_15))
  print('between +/-50%: ', within_50, share(within_50))
  print('< than .50: ', int((values < .5).sum()))
  print('> than 1.50: ', int((values > 1.5).sum()))
  print(values.describe())

In [4]:
#@title
# Per-state death counts from each source, both indexed by state FIPS code.
cdc_states_df = CreateCDCStateDataframe(CDC_STATES_QUERY)
nyt_states_df = CreateNYTStateDataframe(NYT_STATES_QUERY)

# The left join keeps every NYT state; the fillna fixes the states that are
# missing from the CDC data by giving them zeros.
cdc_nyt_states_df = (
    nyt_states_df
    .join(cdc_states_df, on="state_fips_code", how='left', lsuffix='', rsuffix='_right')
    .reset_index()
    .fillna(0)
)

# Re-derive the abbreviation from the NYT state name so that the filled rows
# get a proper label as well.
cdc_nyt_states_df['state'] = cdc_nyt_states_df['state_name']
cdc_nyt_states_df = cdc_nyt_states_df.replace(
    to_replace={'state': STATES_TO_ABBREVIATIONS})

# Ratio of CDC to NYT deaths, rounded to four decimals.
cdc_nyt_states_df['percent'] = (
    cdc_nyt_states_df['cdc_cases'] / cdc_nyt_states_df['nyt_cases']
).round(4)

In [5]:
#@title
# Column name, number format and display title for each charted field. The
# "cases" columns hold death counts.
cdc_nyt_state_fields_dict = dict(
    x=dict(name='nyt_cases', format=',', title='NYT deaths'),
    y=dict(name='cdc_cases', format=',', title='CDC deaths'),
    percent=dict(name='percent', format='.2f', title='Ratio of CDC to NYT'),
)
cdc_nyt_state_title = 'Ratio of CDC to NYT Deaths by State up to %s' % (DATE_DISPLAY_NAME,)

# State-level scatterplot and map, colored by the CDC/NYT ratio.
CreateScatterPlotAndMap(
    cdc_nyt_states_df,
    cdc_nyt_state_fields_dict,
    cdc_nyt_state_title,
    TOTAL_CASES_SCALE_MAX,
    SCATTER_HEIGHT,
    SCATTER_WIDTH,
    MAP_WIDTH,
    'state',
    'ratio',
).display()
#PrintSummaryStats(cdc_nyt_states_df)

In [8]:
# States ordered from the lowest to the highest CDC-to-NYT ratio.
cdc_nyt_states_df.sort_values('percent')

Unnamed: 0,state_fips_code,state_name,nyt_cases,state,cdc_cases,percent
7,15,Hawaii,470,HI,0.0,0.0
27,31,Nebraska,2327,NE,0.0,0.0
24,24,Maryland,8512,MD,0.0,0.0
5,48,Texas,49604,TX,8.0,0.0
44,46,South Dakota,1949,SD,1.0,0.0
37,35,New Mexico,3999,NM,3.0,0.0
47,54,West Virginia,2772,WV,3.0,0.0
26,29,Missouri,9106,MO,39.0,0.0
21,10,Delaware,1595,DE,29.0,0.02
18,56,Wyoming,703,WY,21.0,0.03
