When people are infected with COVID-19, even if they have no symptoms, they can still shed the virus in their feces and saliva, and it eventually enters the wastewater system. This allows wastewater surveillance to serve as an early warning that COVID-19 is about to spread in the community. When the virus concentration in the wastewater starts to rise, the health department can take early action to prevent the spread of COVID-19.

In [4]:
import pandas as pd
import numpy as np
from statistics import mean

import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go

from sklearn.metrics import mean_squared_error
from sklearn.linear_model import LinearRegression

from datetime import timedelta

In [None]:
# Fetch the two input CSVs from Google Drive; per the download log below they
# are saved as us-counties.csv and wastewater_by_county.csv.
!gdown --id 1d-6NzhezzAioyh8Tb0jy988eVbUGcNmY
!gdown --id 1nvHZoi4DYXnVLbAxs0XRwCcTW11_Ooj-

Downloading...
From: https://drive.google.com/uc?id=1d-6NzhezzAioyh8Tb0jy988eVbUGcNmY
To: /content/us-counties.csv
100% 105M/105M [00:00<00:00, 226MB/s] 
Downloading...
From: https://drive.google.com/uc?id=1nvHZoi4DYXnVLbAxs0XRwCcTW11_Ooj-
To: /content/wastewater_by_county.csv
100% 260k/260k [00:00<00:00, 109MB/s]


In [None]:
# Raw inputs, loaded once; every function below reads these module-level frames
# and derives its own working copies from them.
water_by_county_import = pd.read_csv('wastewater_by_county.csv')
cases_by_county_import = pd.read_csv('us-counties.csv')

In [None]:
# Full state / territory name -> two-letter postal abbreviation.
# Used to translate the `state` column of the case data into the abbreviations
# that the wastewater data uses.
state_dict = {
    'Alabama': 'AL',
    'Alaska': 'AK',
    'Arizona': 'AZ',
    'Arkansas': 'AR',
    'California': 'CA',
    'Colorado': 'CO',
    'Connecticut': 'CT',
    'Delaware': 'DE',
    'District of Columbia': 'DC',
    'Florida': 'FL',
    'Georgia': 'GA',
    'Guam': 'GU',
    'Hawaii': 'HI',
    'Idaho': 'ID',
    'Illinois': 'IL',
    'Indiana': 'IN',
    'Iowa': 'IA',
    'Kansas': 'KS',
    'Kentucky': 'KY',
    'Louisiana': 'LA',
    'Maine': 'ME',
    'Maryland': 'MD',
    'Massachusetts': 'MA',
    'Michigan': 'MI',
    'Minnesota': 'MN',
    'Mississippi': 'MS',
    'Missouri': 'MO',
    'Montana': 'MT',
    'Nebraska': 'NE',
    'Nevada': 'NV',
    'New Hampshire': 'NH',
    'New Jersey': 'NJ',
    'New Mexico': 'NM',
    'New York': 'NY',
    'North Carolina': 'NC',
    'North Dakota': 'ND',
    'Northern Mariana Islands': 'MP',
    'Ohio': 'OH',
    'Oklahoma': 'OK',
    'Oregon': 'OR',
    'Pennsylvania': 'PA',
    'Puerto Rico': 'PR',
    'Rhode Island': 'RI',
    'South Carolina': 'SC',
    'South Dakota': 'SD',
    'Tennessee': 'TN',
    'Texas': 'TX',
    'Utah': 'UT',
    'Vermont': 'VT',
    'Virgin Islands': 'VI',
    'Virginia': 'VA',
    'Washington': 'WA',
    'West Virginia': 'WV',
    'Wisconsin': 'WI',
    'Wyoming': 'WY',
}

### **By State**

In [None]:
def time_series_plot_by_state(state_abr):
  """Plot wastewater virus concentration against daily new cases for one state.

  Reads the module-level ``water_by_county_import``, ``cases_by_county_import``
  and ``state_dict``. Weeks for which the state has no wastewater reading are
  filled with the mean over all rows for that week.

  Args:
    state_abr: State abbreviation as it appears in the ``state`` column of
      ``water_by_county_import`` (e.g. ``'MA'``).

  Returns:
    ``None`` after showing the figure, or a message string when the
    abbreviation does not occur in the wastewater data.
  """
  state_abr = str(state_abr)
  if state_abr not in water_by_county_import.state.unique().tolist():
    return("Please select another state (in abbreviation)")

  # drop() returns a new frame, so the imported data is never modified.
  # Row 0 is discarded (2020-01-01, per the original note).
  water_by_county = water_by_county_import.drop([0]).rename(
      columns={'sampling_week': 'date',
               'effective_concentration_rolling_average': 'concentration'})
  water_USA = (water_by_county[['date', 'concentration']]
               .groupby(by=['date'], as_index=False).mean()
               .sort_values(by='date'))
  in_state = water_by_county['state'] == state_abr
  water_one_state = (water_by_county.loc[in_state, ['date', 'concentration']]
                     .groupby(by=['date'], as_index=False).mean()
                     .sort_values(by='date'))

  # imputation using average concentration: borrow the all-rows weekly mean
  # for every week the state itself has no sample
  missing_weeks = ~water_USA['date'].isin(water_one_state['date'])
  water_one_state = pd.concat([water_one_state, water_USA[missing_weeks]])

  # turn weekly to daily data (forward-fill each weekly value), then move the
  # dates 3 days earlier
  # NOTE(review): the 3-day offset presumably re-centres the weekly value - confirm
  water_one_state['date'] = pd.to_datetime(water_one_state.date, format='%Y/%m/%d')
  water_one_state = water_one_state.set_index('date').resample('D').ffill().reset_index()
  water_one_state['date'] = (water_one_state['date'] - timedelta(3)).astype('str')

  # cases: keep rows with a FIPS code, restrict to the state, sum per day
  cases_by_county = cases_by_county_import.loc[cases_by_county_import['fips'].notnull(), :]
  cases_one_state = cases_by_county[
      cases_by_county['state'].map(state_dict) == state_abr].reset_index(drop=True)
  # Fall back to the abbreviation for the title when the state has no case rows.
  state = cases_one_state['state'].iloc[0] if not cases_one_state.empty else state_abr
  cases_one_state = (cases_one_state[['date', 'cases']]
                     .groupby(by=['date'], as_index=False).sum()
                     .sort_values(by='date'))

  # Daily new cases from the cumulative count. Negative differences are
  # clipped to 0 on this column only; masking the whole row (as before) also
  # overwrote `date`, so those days silently dropped out of the merge below.
  cases_one_state['increased_cases'] = (
      cases_one_state['cases'].diff().fillna(0).clip(lower=0).astype(int))
  cases_one_state = cases_one_state[['date', 'increased_cases']]

  # merge two data frames
  water_cases_one_state = water_one_state.merge(cases_one_state, on='date', how='inner')

  # time series plot
  fig = make_subplots(specs=[[{"secondary_y": True}]])
  fig.add_trace(
      go.Scatter(x=water_cases_one_state['date'], y=water_cases_one_state['concentration'],
                 name='Concentration Rolling Average', marker_color='blue'),
      secondary_y=False
  )
  fig.add_trace(
      go.Scatter(x=water_cases_one_state['date'], y=water_cases_one_state['increased_cases'],
                 name='Increased Cases'),
      secondary_y=True
  )

  # Hide calendar days that have no observation, but only when there are many
  # of them (threshold of 100 kept from the original).
  if not water_cases_one_state.empty:
    observed = pd.to_datetime(water_cases_one_state['date'])
    observed_days = {d.to_pydatetime() for d in observed}
    all_days = pd.date_range(start=observed.iloc[0], end=observed.iloc[-1], freq='D')
    dt_breaks = [d.to_pydatetime() for d in all_days
                 if d.to_pydatetime() not in observed_days]
    if len(dt_breaks) > 100:
      fig.update_xaxes(rangebreaks=[dict(values=dt_breaks)])

  fig.update_layout(title_text='%s' %state, title_x=0.3)
  fig.update_xaxes(title_text='Date')
  fig.update_yaxes(title_text='Copies / mL of sewage', secondary_y=False)
  fig.update_yaxes(title_text='Cases', secondary_y=True)

  fig.show()
 

In [None]:
# Example: wastewater concentration vs. daily new cases for the abbreviation 'MA'.
time_series_plot_by_state('MA')

### **Whole USA**

In [None]:
def time_series_plot_USA():
  """Show mean wastewater concentration and daily new cases on one figure.

  Reads the module-level ``water_by_county_import`` and
  ``cases_by_county_import``; concentration goes on the primary y-axis and
  the day-over-day change in total cases on the secondary y-axis.
  Returns ``None``.
  """
  column_names = {'sampling_week': 'date',
                  'effective_concentration_rolling_average': 'concentration'}

  # Weekly mean concentration over all rows (row 0 is discarded: 2020-01-01).
  weekly = (
      water_by_county_import.drop([0])
      .rename(columns=column_names)[['date', 'concentration']]
      .groupby(by=['date'], as_index=False)
      .mean()
      .sort_values(by='date')
  )

  # Weekly -> daily by forward fill, then move every date 3 days earlier and
  # go back to string dates so they can be joined with the case data.
  weekly['date'] = pd.to_datetime(weekly['date'], format='%Y/%m/%d')
  daily = weekly.set_index('date').resample('D').ffill().reset_index()
  daily['date'] = (pd.to_datetime(daily['date']) - timedelta(3)).astype('str')

  # Total cases per day over rows that carry a FIPS code, then daily change.
  has_fips = cases_by_county_import['fips'].notnull()
  totals = (
      cases_by_county_import.loc[has_fips, ['date', 'cases']]
      .groupby(by=['date'], as_index=False)
      .sum()
      .sort_values(by='date')
  )
  totals['increased_cases'] = totals['cases'].diff().fillna(0)

  combined = totals.merge(daily, on='date')

  # Two traces sharing the x-axis, one per y-axis.
  concentration_trace = go.Scatter(
      x=combined['date'],
      y=combined['concentration'],
      name='Concentration Rolling Average',
      marker_color='blue',
  )
  cases_trace = go.Scatter(
      x=combined['date'],
      y=combined['increased_cases'],
      mode='lines',
      name='Increased Cases',
  )

  fig = make_subplots(specs=[[{"secondary_y": True}]])
  fig.add_trace(concentration_trace, secondary_y=False)
  fig.add_trace(cases_trace, secondary_y=True)

  fig.update_layout(title_text='USA', title_x=0.4)
  fig.update_xaxes(title_text='Date')
  fig.update_yaxes(title_text='Copies / mL of sewage', secondary_y=False)
  fig.update_yaxes(title_text='Cases', secondary_y=True)

  fig.show()

In [None]:
# Render the combined figure over all rows of both data sets.
time_series_plot_USA()

### **Linear Model**

From the time series plot, an obvious time lag between daily increased cases and the virus concentration rolling average can be observed. Instead of using the unshifted concentration rolling average as the variable to predict the time series of increased cases, we shifted the wastewater data N days backward (N = 5, 6, ..., 14 in the code below) and fitted one linear model per value of N, with daily increased cases as the target. The lowest MSE occurred when N = 11, so we used this model to predict daily cases and plotted the results. Except for the periods when a lot of people are infected, a simple linear model with shifted wastewater data as the only variable can roughly predict daily increased cases.

In [None]:
def find_optimal_lag(min_lag=5, max_lag=14):
  """Return the wastewater-to-cases lag (in days) with the lowest MSE.

  For every lag in ``min_lag..max_lag`` (inclusive) the daily wastewater
  series is moved ``lag`` days later, joined with daily new cases on date,
  and a one-variable linear regression (concentration -> new cases) is
  fitted. The lag whose fit has the smallest mean squared error wins; ties
  go to the smallest lag.

  Reads the module-level ``water_by_county_import`` and
  ``cases_by_county_import``.

  Args:
    min_lag: Smallest lag to try, in days. Defaults to 5.
    max_lag: Largest lag to try, in days (inclusive). Defaults to 14.

  Returns:
    The best lag as an ``int``.

  Raises:
    ValueError: If ``min_lag`` is greater than ``max_lag``.
  """
  if min_lag > max_lag:
    raise ValueError('min_lag must not be greater than max_lag')

  # Row 0 is discarded (2020-01-01, per the original note); drop() returns a
  # new frame, so the imported data is never modified.
  water_by_county = water_by_county_import.drop([0]).rename(
      columns={'sampling_week': 'date',
               'effective_concentration_rolling_average': 'concentration'})
  water_USA = (water_by_county[['date', 'concentration']]
               .groupby(by=['date'], as_index=False).mean()
               .sort_values(by='date'))

  # turn weekly to daily data, then move the dates 3 days earlier.
  # The column stays datetime here so the per-lag shift below needs no
  # str -> datetime round trip (the old unit-less astype('datetime64') is
  # rejected by pandas 2.x).
  water_USA['date'] = pd.to_datetime(water_USA['date'], format='%Y/%m/%d')
  water_USA = water_USA.set_index('date').resample('D').ffill().reset_index()
  water_USA['date'] = water_USA['date'] - timedelta(3)

  # cases: total per day over rows with a FIPS code, then daily change
  cases_by_county = cases_by_county_import.loc[cases_by_county_import['fips'].notnull(), :]
  cases_USA = (cases_by_county[['date', 'cases']]
               .groupby(by=['date'], as_index=False).sum()
               .sort_values(by='date'))
  cases_USA['increased_cases'] = cases_USA['cases'].diff().fillna(0)

  best_lag = None
  best_mse = None
  for lag in range(min_lag, max_lag + 1):
    # Shift the wastewater dates `lag` days later; string dates for the join.
    shifted = water_USA.copy(deep=True)
    shifted['date'] = (shifted['date'] + pd.Timedelta(days=lag)).astype('str')

    # merge two data frames
    merged = cases_USA.merge(shifted, on='date')

    # LM, fitted and scored on everything after the first 80 merged rows
    # NOTE(review): the original name `After_may` suggests these 80 rows are
    # the period up to May - confirm against the data
    after_may = merged[80:]
    X = after_may['concentration'].to_numpy().reshape(-1, 1)
    y = after_may['increased_cases'].to_numpy().reshape(-1, 1)
    reg = LinearRegression().fit(X, y)

    mse = mean_squared_error(y.flatten(), reg.predict(X).flatten())
    # Strict '<' keeps the first (smallest) lag when several share the minimum.
    if best_mse is None or mse < best_mse:
      best_lag, best_mse = lag, mse

  return best_lag

In [None]:
# Search for the lag with the lowest MSE; the bare `lag` on the last line makes
# the notebook echo the result.
lag = find_optimal_lag()
lag

11

In [None]:
def time_series_plot_prediction(lag):
  """Plot predicted against true daily new cases for a given wastewater lag.

  The daily wastewater series is moved ``lag`` days later and joined with
  daily new cases on date. A one-variable linear regression
  (concentration -> new cases) is fitted on everything after the first 80
  merged rows and then used to predict over the whole merged range.

  Reads the module-level ``water_by_county_import`` and
  ``cases_by_county_import``.

  Args:
    lag: Number of days by which the wastewater dates are moved later.

  Returns:
    ``None``; the figure is shown as a side effect.
  """
  # Row 0 is discarded (2020-01-01, per the original note); drop() returns a
  # new frame, so the imported data is never modified.
  water_by_county = water_by_county_import.drop([0]).rename(
      columns={'sampling_week': 'date',
               'effective_concentration_rolling_average': 'concentration'})
  water_USA = (water_by_county[['date', 'concentration']]
               .groupby(by=['date'], as_index=False).mean()
               .sort_values(by='date'))

  # turn weekly to daily data, move the dates 3 days earlier, then apply the
  # lag. The column is already datetime after resampling, so no further
  # conversion is needed (the old unit-less astype('datetime64') is rejected
  # by pandas 2.x).
  water_USA['date'] = pd.to_datetime(water_USA['date'], format='%Y/%m/%d')
  water_USA = water_USA.set_index('date').resample('D').ffill().reset_index()
  water_USA['date'] = water_USA['date'] - timedelta(3) + pd.Timedelta(days=lag)
  water_USA['date'] = water_USA['date'].astype('str')

  # cases: total per day over rows with a FIPS code, then daily change
  cases_by_county = cases_by_county_import.loc[cases_by_county_import['fips'].notnull(), :]
  cases_USA = (cases_by_county[['date', 'cases']]
               .groupby(by=['date'], as_index=False).sum()
               .sort_values(by='date'))
  cases_USA['increased_cases'] = cases_USA['cases'].diff().fillna(0)

  # merge two data frames
  cases_and_water_USA = cases_USA.merge(water_USA, on='date')

  # LM, fitted on everything after the first 80 merged rows
  # NOTE(review): the name `After_may` suggests these 80 rows are the period
  # up to May - confirm against the data
  After_may = cases_and_water_USA[80:]
  X = After_may['concentration'].to_numpy().reshape(-1, 1)
  y = After_may['increased_cases'].to_numpy().reshape(-1, 1)
  reg = LinearRegression().fit(X, y)

  # prediction on whole dataset (including the rows left out of the fit)
  all_X = cases_and_water_USA['concentration'].to_numpy().reshape(-1, 1)
  cases_predict_USA = pd.DataFrame({
      'date': cases_and_water_USA['date'].to_numpy(),
      'Predicted Cases': reg.predict(all_X).flatten(),
      "True Cases": cases_and_water_USA['increased_cases'].to_numpy(),
  })

  # time series plot
  fig = make_subplots()
  fig.add_trace(
      go.Scatter(x=cases_predict_USA['date'], y=cases_predict_USA['Predicted Cases'],
                 name='Predicted Cases', marker_color='green')
  )
  fig.add_trace(
      go.Scatter(x=cases_predict_USA['date'], y=cases_predict_USA['True Cases'],
                 mode='lines', name='True Cases')
  )

  fig.update_layout(title_text='Predicted Cases and True Cases', title_x=0.4)
  fig.update_xaxes(title_text='Date')
  fig.update_yaxes(title_text='Cases')
  fig.show()


In [None]:
# 11 is the lag that find_optimal_lag() returned in the cell output above.
time_series_plot_prediction(11)