In [1]:
import pandas as pd
from tqdm.auto import tqdm
from datetime import datetime, timedelta
from itertools import product
from datetime import date
from tqdm import tqdm

In [2]:
# Data notes:
# - In the CSV export, year/week are the columns ['Current MMWR Year', 'MMWR WEEK'], the number of
#   new cases is "Current week", and "Label" is the disease.
# - The API response has the equivalent columns: states, year, week, label, m1
#   (m1 = number of new cases in the current week).
# - "states" (the reporting area) contains more than states: regions, totals, territories,
#   US residents and non-US residents.
# - Idea: build an outlier-detection model, like the time-series work I did for Live action.
#   I think it's feasible.
# - Source: https://dev.socrata.com/foundry/data.cdc.gov/x9gk-5huc
# - The data downloaded from the API is better formatted than the CSV; weeks with no data are NaN.
# - In short: every week (of each year) that is missing from the data is inserted and forward-filled,
#   and a flag marks those inserted rows so they can be ignored later (e.g. for a cumulative count).

In [3]:
RAW_NNDSS_PATH = "../data/raw/NNDSS.pkl"
# NOTE(review): read_pickle can execute arbitrary code — only load pickles you produced yourself.
df = pd.read_pickle(RAW_NNDSS_PATH)

In [4]:
# Keep reporting area, MMWR year/week, disease label, the current-week count (m1)
# and location1, which the next cell uses to drop non-state rows.
keep_cols = ['states', 'year', 'week', 'label', 'm1', 'location1']
df = df.loc[:, keep_cols]
df.head()

Unnamed: 0,states,year,week,label,m1,location1
0,US RESIDENTS,2022,1,Anthrax,,
1,NEW ENGLAND,2022,1,Anthrax,,
2,CONNECTICUT,2022,1,Anthrax,,CONNECTICUT
3,MAINE,2022,1,Anthrax,,MAINE
4,MASSACHUSETTS,2022,1,Anthrax,,MASSACHUSETTS


In [5]:
# Rows without a location1 are aggregates (regions, US totals), not individual
# reporting areas; drop them, then give the remaining columns their working names.
df = (
    df[df['location1'].notna()]
    .drop(columns='location1')
    .rename(columns={'states': 'state', 'm1': 'new_cases'})
)

In [6]:
df.head(n=5)

Unnamed: 0,state,year,week,label,new_cases
2,CONNECTICUT,2022,1,Anthrax,
3,MAINE,2022,1,Anthrax,
4,MASSACHUSETTS,2022,1,Anthrax,
5,NEW HAMPSHIRE,2022,1,Anthrax,
6,RHODE ISLAND,2022,1,Anthrax,


In [7]:
# Provisional weekly date (Monday of '%W' week `week`) used here only to order each
# series chronologically; the 'date' column is recomputed after gap filling.
# item_id is one key per reporting-area/disease series, e.g. "ALABAMA_Anthrax".
df = df.assign(
    date=lambda d: pd.to_datetime(
        d['year'].astype(str) + '-' + d['week'].astype(str) + '-1',
        format='%Y-%W-%w',
    ),
    item_id=lambda d: d['state'] + '_' + d['label'],
).sort_values(['item_id', 'date'])

In [8]:
def _impute_series(cases):
    """Fill NaNs within one item's series: carry forward, then backward, then 0 if nothing was reported."""
    return cases.ffill().bfill().fillna(0)

# NOTE(review): per the notes above, NaN means "no data" for that week; carrying a
# neighbouring count into it may overstate cases if NaN really means zero — confirm.
df['new_cases'] = df.groupby('item_id')['new_cases'].transform(_impute_series)

In [9]:
# Counts and calendar keys as plain integers.
df = df.astype({'new_cases': int, 'week': int, 'year': int})

In [10]:
# Only confirmed diseases are modelled, so drop every "Probable" series.
is_probable = df['label'].str.contains("Probable")
df = df[~is_probable]

In [11]:
df.shape[0]

672486

In [12]:
def get_weeks_in_year(year):
    """Return how many ISO-8601 weeks (52 or 53) ``year`` has.

    December 28 always lies in the last ISO week of its year, so its ISO week
    number equals the number of weeks in that year.
    """
    _, iso_week, _ = date(year, 12, 28).isocalendar()
    return iso_week

def fill_weekly_gaps(df):
    """Insert a row for every week missing between each item's first and last reported week.

    For every ``item_id`` a complete (year, week) grid is built from the item's first
    reported week to its last. The original rows are left-merged onto that grid, and
    ``new_cases`` on inserted rows is filled from the same item's nearest earlier row
    (or, failing that, its nearest later row).

    Parameters
    ----------
    df : pd.DataFrame
        Long-format data with at least ``item_id``, ``year``, ``week`` and ``new_cases``.

    Returns
    -------
    pd.DataFrame
        One row per grid point, grouped by item_id and in calendar order within each
        item, with all columns of ``df`` plus a boolean ``filled_value`` that is True
        for inserted rows. Columns other than the keys and ``new_cases`` are NaN on
        inserted rows.
    """
    keys = ['item_id', 'year', 'week']
    all_combinations = []
    # Iterating groups touches each row once; filtering df per item_id scanned the
    # whole frame for every item.
    for item_id, item_df in tqdm(df.groupby('item_id', sort=False), desc='Filling gaps'):
        first_year = int(item_df['year'].min())
        last_year = int(item_df['year'].max())
        first_week = int(item_df.loc[item_df['year'] == first_year, 'week'].min())
        last_week = int(item_df.loc[item_df['year'] == last_year, 'week'].max())
        # Highest week actually reported per year. It is used as a floor for the year's
        # last week, so a reported week beyond get_weeks_in_year() (ISO weeks, while the
        # data uses MMWR weeks) is kept instead of being dropped by the left merge.
        max_week_by_year = item_df.groupby('year')['week'].max()

        # Walk the item's years in calendar order so the grid (and the ffill below) is
        # chronological, and so a year with no rows at all still gets its weeks.
        for year in range(first_year, last_year + 1):
            week_start = first_week if year == first_year else 1
            if year == last_year:
                week_end = last_week
            else:
                week_end = max(get_weeks_in_year(year), int(max_week_by_year.get(year, 0)))
            for week in range(week_start, week_end + 1):
                all_combinations.append({'item_id': item_id, 'year': year, 'week': week})

    # Explicit columns keep the merge keys present even if no combinations were produced.
    all_combinations_df = pd.DataFrame(all_combinations, columns=keys)

    # Merge the generated grid with the original rows; _merge marks grid-only rows.
    df_merged = pd.merge(all_combinations_df, df, on=keys, how='left', indicator=True)

    # Both passes are grouped, so an inserted row is only ever filled from its own
    # item_id and never from a neighbouring series.
    df_merged['new_cases'] = df_merged.groupby('item_id')['new_cases'].ffill()
    df_merged['new_cases'] = df_merged.groupby('item_id')['new_cases'].bfill()

    # Rows that exist only in the generated grid were inserted here.
    df_merged['filled_value'] = df_merged['_merge'] == 'left_only'
    return df_merged.drop(columns=['_merge'])

In [13]:
df = df.pipe(fill_weekly_gaps)

Filling gaps: 100%|██████████| 6840/6840 [10:17<00:00, 11.08it/s]


In [14]:
df['item_id'].nunique()

6840

In [15]:
def year_week_to_date(year, week):
    """Return the Monday (a ``datetime`` at midnight) that starts ISO-style week ``week`` of ``year``.

    Week 1 is the week containing the year's first Thursday, so its Monday may fall
    in the previous December. ``week`` is not range-checked: larger values simply keep
    counting forward in 7-day steps.
    """
    jan_first = datetime(year, 1, 1)
    weekday = jan_first.weekday()  # Monday == 0 ... Sunday == 6
    # Jan 1 on Mon-Thu: week 1 starts on the Monday on or before it.
    # Jan 1 on Fri-Sun: week 1 starts on the following Monday.
    shift_days = 7 - weekday if weekday > 3 else -weekday
    return jan_first + timedelta(days=shift_days, weeks=week - 1)

# Replace the provisional 'date' with the Monday that starts each (year, week).
# year_week_to_date(y, w) is week 1's Monday plus (w - 1) weeks, so week 1's Monday
# is computed once per distinct year and the week offset is added column-wise.
week_one_monday = {y: year_week_to_date(int(y), 1) for y in df['year'].unique()}
df['date'] = df['year'].map(week_one_monday) + pd.to_timedelta((df['week'].astype(int) - 1) * 7, unit='D')


In [16]:
# After the merge new_cases is float (it held NaN before filling); store it as int again.
df = df.astype({'new_cases': int})

In [17]:
# Inserted (gap-filled) rows have no state/label; copy them from the same item_id.
# ffill and bfill both run inside each item_id group, so a value can never be
# borrowed from a neighbouring series.
df['state'] = df.groupby('item_id')['state'].transform(lambda s: s.ffill().bfill())
df['label'] = df.groupby('item_id')['label'].transform(lambda s: s.ffill().bfill())

In [18]:
# Long-format frame fed to the forecaster: one row per (item_id, week).
df_mod = df.loc[:, ['item_id', 'date', 'label', 'new_cases']]

In [19]:
df_mod.head(n=5)

Unnamed: 0,item_id,date,label,new_cases
0,ALABAMA_Anthrax,2022-01-03,Anthrax,0
1,ALABAMA_Anthrax,2022-01-10,Anthrax,0
2,ALABAMA_Anthrax,2022-01-17,Anthrax,0
3,ALABAMA_Anthrax,2022-01-24,Anthrax,0
4,ALABAMA_Anthrax,2022-01-31,Anthrax,0


In [21]:
# Time-based train/test split: hold out the final TEST_WEEKS weeks of data.
# TEST_WEEKS must equal the estimator's prediction_length so the forecast horizon
# covers exactly the held-out weeks.
TEST_WEEKS = 10
cut_off_date = df_mod['date'].max() - pd.Timedelta(weeks=TEST_WEEKS)
print(cut_off_date)
# Weeks up to and including the cut-off are training data; later weeks form the test set.
train = df_mod[df_mod['date'] <= cut_off_date]
test = df_mod[df_mod['date'] > cut_off_date]


2023-12-11 00:00:00


In [22]:
from gluonts.dataset.pandas import PandasDataset
from gluonts.dataset.split import split
from gluonts.torch import DeepAREstimator
from gluonts.dataset.common import ListDataset
from gluonts.torch.distributions.negative_binomial import NegativeBinomialOutput



In [23]:
import random

import numpy as np
import torch

# Training and sample-based prediction are stochastic; seed the Python, NumPy and
# torch RNGs (presumably all used by GluonTS/Lightning) so runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# NOTE(review): freq='W' is pandas' Sunday-anchored week while `date` holds Mondays;
# each Monday falls inside its own weekly period, but confirm this is intended.
train_ds = PandasDataset.from_long_dataframe(train, target='new_cases', item_id='item_id',
                                             timestamp='date', freq='W')
# NOTE(review): test_ds is not used below — forecasts are made from train_ds, whose
# series end at the cut-off, so the horizon lines up with `test`.
test_ds = PandasDataset.from_long_dataframe(test, target='new_cases', item_id='item_id',
                                            timestamp='date', freq='W')

# prediction_length must equal the number of held-out weeks in `test`.
estimator = DeepAREstimator(
    prediction_length=10,
    freq="W",
    # Weekly case counts are non-negative integers, hence a negative-binomial likelihood.
    distr_output=NegativeBinomialOutput(),
    trainer_kwargs={"max_epochs": 100}
)

model = estimator.train(train_ds)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/usr/local/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
/usr/local/lib/python3.11/site-packages/lightning/pytorch/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.

  | Name  | Type        | Params | In sizes                                                      | Out sizes   
--------------------------------------

Training: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 50: 'train_loss' reached 0.54609 (best 0.54609), saving model to '/workspaces/NNDSS/notebooks/lightning_logs/version_20/checkpoints/epoch=0-step=50.ckpt' as top 1
Epoch 1, global step 100: 'train_loss' reached 0.48108 (best 0.48108), saving model to '/workspaces/NNDSS/notebooks/lightning_logs/version_20/checkpoints/epoch=1-step=100.ckpt' as top 1
Epoch 2, global step 150: 'train_loss' reached 0.45975 (best 0.45975), saving model to '/workspaces/NNDSS/notebooks/lightning_logs/version_20/checkpoints/epoch=2-step=150.ckpt' as top 1
Epoch 3, global step 200: 'train_loss' reached 0.44517 (best 0.44517), saving model to '/workspaces/NNDSS/notebooks/lightning_logs/version_20/checkpoints/epoch=3-step=200.ckpt' as top 1
Epoch 4, global step 250: 'train_loss' was not in top 1
Epoch 5, global step 300: 'train_loss' was not in top 1
Epoch 6, global step 350: 'train_loss' was not in top 1
Epoch 7, global step 400: 'train_loss' reached 0.39256 (best 0.39256), saving model to '/w

In [24]:
# One forecast per training series, covering the weeks right after each series ends.
preds = [forecast for forecast in model.predict(train_ds)]

In [25]:
from pandas.tseries.offsets import Week

# Weekly dates of the held-out window, starting at its first week.
num_periods = 10
start_date = test['date'].min()
prediction_dates = pd.date_range(start_date, periods=num_periods, freq='7D')
prediction_dates

DatetimeIndex(['2023-12-18', '2023-12-25', '2024-01-01', '2024-01-08',
               '2024-01-15', '2024-01-22', '2024-01-29', '2024-02-05',
               '2024-02-12', '2024-02-19'],
              dtype='datetime64[ns]', freq='7D')

In [57]:
from pandas import Timestamp

# Collect the forecasts into one long frame: one row per (item_id, forecast week).
all_preds = []
for forecast in preds:
    # Take the dates from each forecast's own horizon instead of one global range: a
    # series that stops before the cut-off is forecast from its own last week, so its
    # horizon does not start at test.date.min().
    # NOTE(review): forecast.index is a weekly PeriodIndex (freq 'W', i.e. weeks ending
    # Sunday); how='start' gives each period's Monday, matching the Mondays in `date`.
    forecast_dates = forecast.index.to_timestamp(how='start')
    all_preds.append(pd.DataFrame({
        'date': forecast_dates,
        'item_id': forecast.item_id,
        'pred_mean': forecast.mean,
        # 1st / 99th percentile forecasts, i.e. a 98% prediction interval.
        'pred_lower': forecast.quantile(0.01),
        'pred_upper': forecast.quantile(0.99),
    }))

all_preds_df = pd.concat(all_preds, ignore_index=True)


In [58]:
from sklearn.metrics import mean_absolute_error, mean_squared_error
import numpy as np

In [59]:
# Give the test rows a fresh 0..n-1 index, discarding the labels inherited from df_mod.
test = test.set_axis(pd.RangeIndex(len(test)), axis=0)

In [60]:
# Inner join: keep only (item_id, date) pairs that have both a forecast and an actual.
df_evaluation = all_preds_df.merge(test, on=['item_id', 'date'])
df_evaluation.head()

Unnamed: 0,date,item_id,pred_mean,pred_lower,pred_upper,label,new_cases
0,2023-12-18,ALABAMA_Anthrax,0.0,0.0,0.0,Anthrax,0
1,2023-12-25,ALABAMA_Anthrax,0.0,0.0,0.0,Anthrax,0
2,2024-01-01,ALABAMA_Anthrax,0.0,0.0,0.0,Anthrax,0
3,2024-01-08,ALABAMA_Anthrax,0.0,0.0,0.0,Anthrax,0
4,2024-01-15,ALABAMA_Anthrax,0.0,0.0,0.0,Anthrax,0


In [61]:
# Point-forecast errors of the predicted mean against the actual weekly counts.
actual = df_evaluation['new_cases']
predicted = df_evaluation['pred_mean']

mae = mean_absolute_error(actual, predicted)
rmse = np.sqrt(mean_squared_error(actual, predicted))
# Zero actuals would divide by zero; they are turned into NaN and so skipped by the mean.
nonzero_actual = actual.replace(0, np.nan)
mape = np.mean(np.abs((actual - predicted) / nonzero_actual)) * 100

print(f"Mean Absolute Error (MAE): {mae}")
print(f"Root Mean Squared Error (RMSE): {rmse}")
print(f"Mean Absolute Percentage Error (MAPE): {mape}%")

Mean Absolute Error (MAE): 1.1776180634702538
Root Mean Squared Error (RMSE): 22.08968304621623
Mean Absolute Percentage Error (MAPE): 44.806015800156835%


In [62]:
# Restrict the metrics to series whose most frequent weekly count in the evaluation
# window is non-zero (i.e. drop series that are mostly zeros), then recompute them.
mode_new_cases = (
    df_evaluation
    .groupby('item_id')['new_cases']
    .apply(lambda counts: counts.mode().iloc[0])
)
item_ids_to_keep = mode_new_cases[mode_new_cases != 0].index
filtered_df = df_evaluation[df_evaluation['item_id'].isin(item_ids_to_keep)]

actual = filtered_df['new_cases']
predicted = filtered_df['pred_mean']
mae = mean_absolute_error(actual, predicted)
rmse = np.sqrt(mean_squared_error(actual, predicted))
# Zero actuals become NaN so they drop out of the MAPE mean instead of dividing by zero.
mape = np.mean(np.abs((actual - predicted) / actual.replace(0, np.nan))) * 100

print(f"Mean Absolute Error (MAE): {mae}")
print(f"Root Mean Squared Error (RMSE): {rmse}")
print(f"Mean Absolute Percentage Error (MAPE): {mape}%")

Mean Absolute Error (MAE): 4.377667148987247
Root Mean Squared Error (RMSE): 42.59009647106484
Mean Absolute Percentage Error (MAPE): 44.80601580015683%


In [63]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

def plot_forecasts_plotly(df, all_preds_df, num_charts=4, random_state=None):
    """Plot actual weekly cases and forecasts for a random sample of item_ids in a 2-column grid.

    Parameters
    ----------
    df : pd.DataFrame
        Actuals with ``item_id``, ``date``, ``new_cases`` and boolean ``filled_value``
        (gap-filled points are drawn green, observed points blue).
    all_preds_df : pd.DataFrame
        Forecasts with ``item_id``, ``date``, ``pred_mean``, ``pred_lower`` and
        ``pred_upper``; the band between the bounds is shaded red.
    num_charts : int
        Number of item_ids to sample; capped at the number of distinct item_ids.
        Nothing is plotted when it is 0 or negative.
    random_state : int, optional
        Seed for the sample so the same series can be re-plotted. ``None`` keeps
        using NumPy's global random state.
    """
    unique_ids = df['item_id'].unique()
    # Sampling without replacement cannot draw more ids than exist.
    num_charts = min(num_charts, len(unique_ids))
    if num_charts <= 0:
        return
    rng = np.random if random_state is None else np.random.default_rng(random_state)
    item_ids = rng.choice(unique_ids, size=num_charts, replace=False)

    num_rows = num_charts // 2 + num_charts % 2
    fig = make_subplots(rows=num_rows, cols=2, subplot_titles=list(item_ids))

    for i, item_id in enumerate(item_ids, start=1):
        original_filtered = df[df['item_id'] == item_id]
        predictions_filtered = all_preds_df[all_preds_df['item_id'] == item_id]

        # Fill the grid left-to-right, top-to-bottom: odd i -> column 1, even i -> column 2.
        row = (i - 1) // 2 + 1
        col = 1 if i % 2 else 2

        # Actual values: blue line, markers green where the value was gap-filled.
        fig.add_trace(go.Scatter(x=original_filtered['date'], y=original_filtered['new_cases'],
                                 mode='lines+markers', name=f'Actual {item_id}',
                                 legendgroup=f"group{i}", showlegend=False,
                                 line=dict(color='blue'),
                                 marker=dict(color=original_filtered['filled_value'].map({True: 'green', False: 'blue'}),
                                             size=2)),
                      row=row, col=col)

        # Predicted mean.
        fig.add_trace(go.Scatter(x=predictions_filtered['date'], y=predictions_filtered['pred_mean'],
                                 mode='lines', name=f'Predicted {item_id}',
                                 legendgroup=f"group{i}", showlegend=False, line=dict(color='red')),
                      row=row, col=col)

        # Prediction interval: an invisible lower line, then the upper line filled down to it.
        fig.add_trace(go.Scatter(x=predictions_filtered['date'], y=predictions_filtered['pred_lower'],
                                 mode='lines', name=f'Lower {item_id}',
                                 legendgroup=f"group{i}", showlegend=False, line=dict(width=0)),
                      row=row, col=col)

        fig.add_trace(go.Scatter(x=predictions_filtered['date'], y=predictions_filtered['pred_upper'],
                                 mode='lines', name=f'Upper {item_id}', fill='tonexty',
                                 legendgroup=f"group{i}", showlegend=False, line=dict(width=0), fillcolor='rgba(255, 0, 0, 0.3)'),
                      row=row, col=col)

    fig.update_layout(template="plotly_dark", height=300*num_rows, title_text="Forecasts of New Cases", showlegend=False)
    fig.show()


In [64]:
plot_forecasts_plotly(df=df, all_preds_df=all_preds_df, num_charts=8)

In [65]:
# Mean weekly cases per series; list the high-volume ones (more than 50 per week on average).
df_avg = df.groupby(['item_id', 'state', 'label'], as_index=False)['new_cases'].mean()
df_avg.loc[df_avg['new_cases'] > 50, 'item_id'].unique()

array(['ALABAMA_Chlamydia trachomatis infection', 'ALABAMA_Gonorrhea',
       'ALABAMA_Hepatitis C, chronic, Confirmed',
       'ARIZONA_Chlamydia trachomatis infection',
       'ARIZONA_Coccidioidomycosis',
       'ARIZONA_Coccidioidomycosis, Confirmed',
       'ARIZONA_Coccidioidomycosis, total', 'ARIZONA_Gonorrhea',
       'ARKANSAS_Chlamydia trachomatis infection', 'ARKANSAS_Gonorrhea',
       'CALIFORNIA_Campylobacteriosis',
       'CALIFORNIA_Chlamydia trachomatis infection',
       'CALIFORNIA_Gonorrhea', 'COLORADO_Chlamydia trachomatis infection',
       'COLORADO_Gonorrhea', 'DELAWARE_Chlamydia trachomatis infection',
       'DISTRICT OF COLUMBIA_Chlamydia trachomatis infection',
       'FLORIDA_Chlamydia trachomatis infection', 'FLORIDA_Gonorrhea',
       'GEORGIA_Chlamydia trachomatis infection', 'GEORGIA_Gonorrhea',
       'IDAHO_Chlamydia trachomatis infection',
       'ILLINOIS_Chlamydia trachomatis infection',
       'INDIANA_Chlamydia trachomatis infection', 'INDIANA_Go

In [66]:
def plot_forecasts_plotly_by_item_id(df, all_preds_df, item_ids, num_charts=None, random_state=None):
    """Plot actual weekly cases and forecasts for the given item_ids in a 2-column grid.

    Parameters
    ----------
    df : pd.DataFrame
        Actuals with ``item_id``, ``date``, ``new_cases`` and boolean ``filled_value``
        (gap-filled points are drawn green, observed points blue).
    all_preds_df : pd.DataFrame
        Forecasts with ``item_id``, ``date``, ``pred_mean``, ``pred_lower`` and
        ``pred_upper``; the band between the bounds is shaded red.
    item_ids : sequence of str
        Candidate series to plot.
    num_charts : int, optional
        Plot at most this many; when it is smaller than ``len(item_ids)`` a random
        subset is drawn. ``None`` plots every id in ``item_ids``.
    random_state : int, optional
        Seed for that random subset. ``None`` keeps using NumPy's global random state.
    """
    # If num_charts is not specified, plot for all given item_ids
    if num_charts is None:
        num_charts = len(item_ids)
    else:
        num_charts = min(num_charts, len(item_ids))
    if num_charts <= 0:
        return

    if len(item_ids) > num_charts:
        rng = np.random if random_state is None else np.random.default_rng(random_state)
        item_ids_to_plot = rng.choice(item_ids, size=num_charts, replace=False)
    else:
        item_ids_to_plot = item_ids

    num_rows = num_charts // 2 + num_charts % 2
    fig = make_subplots(rows=num_rows, cols=2, subplot_titles=list(item_ids_to_plot))

    for i, item_id in enumerate(item_ids_to_plot, start=1):
        original_filtered = df[df['item_id'] == item_id]
        predictions_filtered = all_preds_df[all_preds_df['item_id'] == item_id]

        # Fill the grid left-to-right, top-to-bottom: odd i -> column 1, even i -> column 2.
        row = (i - 1) // 2 + 1
        col = 1 if i % 2 else 2

        # Actual values: blue line, markers green where the value was gap-filled.
        fig.add_trace(go.Scatter(x=original_filtered['date'], y=original_filtered['new_cases'],
                                 mode='lines+markers', name=f'Actual {item_id}',
                                 legendgroup=f"group{i}", showlegend=False,
                                 line=dict(color='blue'),
                                 marker=dict(color=original_filtered['filled_value'].map({True: 'green', False: 'blue'}),
                                             size=4)),
                      row=row, col=col)

        # Predicted mean.
        fig.add_trace(go.Scatter(x=predictions_filtered['date'], y=predictions_filtered['pred_mean'],
                                 mode='lines', name=f'Predicted {item_id}',
                                 legendgroup=f"group{i}", showlegend=False, line=dict(color='red')),
                      row=row, col=col)

        # Prediction interval: an invisible lower line, then the upper line filled down to it.
        fig.add_trace(go.Scatter(x=predictions_filtered['date'], y=predictions_filtered['pred_lower'],
                                 mode='lines', name=f'Lower {item_id}',
                                 legendgroup=f"group{i}", showlegend=False, line=dict(width=0)),
                      row=row, col=col)

        fig.add_trace(go.Scatter(x=predictions_filtered['date'], y=predictions_filtered['pred_upper'],
                                 mode='lines', name=f'Upper {item_id}', fill='tonexty',
                                 legendgroup=f"group{i}", showlegend=False, line=dict(width=0), fillcolor='rgba(255, 0, 0, 0.3)'),
                      row=row, col=col)

    fig.update_layout(template="plotly_dark", height=300*num_rows, title_text="Forecasts of New Cases", showlegend=True)
    fig.show()

In [67]:
# Hand-picked high-volume series (a subset of the > 50 cases/week list above) to inspect.
item_ids = [
    'ARKANSAS_Chlamydia trachomatis infection',
    'ARKANSAS_Gonorrhea',
    'CALIFORNIA_Campylobacteriosis',
    'CALIFORNIA_Chlamydia trachomatis infection',
    'CALIFORNIA_Gonorrhea',
    'COLORADO_Chlamydia trachomatis infection',
    'COLORADO_Gonorrhea',
    'DELAWARE_Chlamydia trachomatis infection',
    'FLORIDA_Chlamydia trachomatis infection',
    'FLORIDA_Gonorrhea',
    'GEORGIA_Chlamydia trachomatis infection',
    'GEORGIA_Gonorrhea',
    'IDAHO_Chlamydia trachomatis infection',
    'ILLINOIS_Chlamydia trachomatis infection',
]

In [68]:
plot_forecasts_plotly_by_item_id(df, all_preds_df, item_ids)

In [139]:
# NOTE(review): duplicate of the previous plotting cell (executed out of order).
plot_forecasts_plotly_by_item_id(df=df, all_preds_df=all_preds_df, item_ids=item_ids)

In [69]:
# Weeks whose actual count exceeded the upper (99th percentile) forecast bound — candidate outliers.
df_evaluation.loc[df_evaluation['new_cases'].gt(df_evaluation['pred_upper'])]

Unnamed: 0,date,item_id,pred_mean,pred_lower,pred_upper,label,new_cases
3409,2024-02-05,ARIZONA_Giardiasis,2.750000,0.0,9.0,Giardiasis,10
3412,2023-12-18,ARIZONA_Gonorrhea,32.990002,11.0,58.0,Gonorrhea,104
3413,2023-12-25,ARIZONA_Gonorrhea,39.540001,14.0,79.0,Gonorrhea,104
3419,2024-02-05,ARIZONA_Gonorrhea,33.849998,6.0,71.0,Gonorrhea,91
3584,2024-01-15,"ARIZONA_Invasive pneumococcal disease, all age...",17.020000,4.0,36.0,"Invasive pneumococcal disease, all ages, Confi...",42
...,...,...,...,...,...,...,...
54419,2023-12-25,"VIRGINIA_Hepatitis, A, acute",0.930000,0.0,3.0,"Hepatitis, A, acute",5
55871,2023-12-25,WASHINGTON_Tuberculosis,1.150000,0.0,4.0,Tuberculosis,6
55872,2024-01-01,WASHINGTON_Tuberculosis,1.130000,0.0,4.0,Tuberculosis,6
55874,2024-01-15,WASHINGTON_Tuberculosis,1.040000,0.0,4.0,Tuberculosis,6
