Generate results in various formats from one model

In [None]:
import utils
import pandas as pd
import numpy as np
import requests as rq
import datetime as dt
import torch
import json
import neptune

from matplotlib.ticker import MultipleLocator
from matplotlib.dates import DayLocator, AutoDateLocator, ConciseDateFormatter
%matplotlib inline

# Directory that holds the source CSV named in the checkpoint config.
DATA_DIR = "data"
# Root directory of the per-experiment checkpoint files.
EXPERIMENTS_DIR = "experiments"
# Torch device the input tensors are moved to before calling the model.
DEVICE = "cpu"
# Neptune project the checkpoint artifact is logged to.
NEPTUNE_PRJ = "indiacovidseva/covid-net"

In [None]:
# Experiment id and checkpoint file to load; the same two names are reused
# at the end of the notebook to build the artifact path uploaded to Neptune.
experiment_id = "0001_test"
checkpoint = "latest-e100.pt"

# `cp` is the loaded checkpoint; later cells read its 'config' entry
# (DS.SRC, DS.FEATURES, DS.IP_SEQ_LEN, DS.OP_SEQ_LEN, IP_FEATURES, OP_FEATURES).
# NOTE(review): presumably utils.load_model resolves the file under
# EXPERIMENTS_DIR/<experiment_id>/<checkpoint> — confirm in utils.
model, cp = utils.load_model(experiment_id, checkpoint)

In [None]:
# Read only the columns the notebook needs from the dataset named in the
# checkpoint config, parsing 'date' into datetimes.
cols = [
    'location',
    'date',
    'total_cases',
    'new_cases',
    'total_deaths',
    'new_deaths',
    'population',
]
dates = ['date']
df = utils.fix_anomalies(
    pd.read_csv(
        f"{DATA_DIR}/{cp['config']['DS']['SRC']}",
        usecols=cols,
        parse_dates=dates,
    )
)
# Show one random row as a quick sanity check.
df.sample()

### Predict from OWID data

In [None]:
# Forecast one country from the OWID data and plot actual vs. predicted values
# for every input feature.
c = "India"
n_days_prediction = 200

ip_features = cp['config']['IP_FEATURES']
op_features = cp['config']['OP_FEATURES']
IP_SEQ_LEN = cp['config']['DS']['IP_SEQ_LEN']
OP_SEQ_LEN = cp['config']['DS']['OP_SEQ_LEN']

# The forecast can only be rolled forward when the model's outputs can be fed
# back as its next inputs; otherwise restrict it to a single output window.
feed_back = ip_features == op_features
if not feed_back:
    print("WARNING: Input features and output features are different. Cannot predict more than", OP_SEQ_LEN, "days.")
    n_days_prediction = OP_SEQ_LEN

# Rows of the selected country from the day it reached 100 total cases.
country_rows = df.loc[(df.location == c) & (df.total_cases >= 100)]
# Features are scaled to "per 1000 inhabitants" before being fed to the model.
pop_fct = df.loc[df.location == c, 'population'].iloc[0] / 1000
# 7-day centred rolling mean of the configured feature columns.
test_data = np.array(
    country_rows[cp['config']['DS']['FEATURES']].rolling(7, center=True, min_periods=1).mean() / pop_fct,
    dtype=np.float32,
)

in_data = test_data[-IP_SEQ_LEN:, ip_features]
out_data = np.ndarray(shape=(0, len(op_features)), dtype=np.float32)
for _ in range(n_days_prediction // OP_SEQ_LEN):
    ip = torch.tensor(in_data, dtype=torch.float32).to(DEVICE)
    pred = model.predict(ip.view(1, IP_SEQ_LEN, len(ip_features))).view(OP_SEQ_LEN, len(op_features))
    pred_np = pred.cpu().numpy()
    out_data = np.append(out_data, pred_np, axis=0)
    if feed_back:
        # Slide the input window: append the new prediction and keep the last
        # IP_SEQ_LEN rows. Slicing after the append also works when
        # OP_SEQ_LEN >= IP_SEQ_LEN (a "-0:" slice would keep the whole window).
        in_data = np.append(in_data, pred_np, axis=0)[-IP_SEQ_LEN:]

start_date = country_rows['date'].iloc[0]
for o in ip_features:
    # NOTE(review): out_data has one column per OP_FEATURES entry but is indexed
    # with the input feature id `o`; this lines up only when output column j
    # holds feature j — confirm against the model config.
    orig_df = pd.concat(
        [
            pd.DataFrame({'actual': test_data[:, o] * pop_fct}),
            pd.DataFrame({'predicted': out_data[:, o] * pop_fct}),
        ],
        ignore_index=True,
        sort=False,
    )
    # print(orig_df['predicted'].dropna().astype('int').to_csv(sep='|', index=False))
    orig_df['total'] = (orig_df['actual'].fillna(0) + orig_df['predicted'].fillna(0)).cumsum()

    # One calendar day per row, starting at the first day of the actual data.
    orig_df['Date'] = pd.date_range(start_date, periods=len(orig_df))
    ax = orig_df.plot(
        x='Date',
        y=['actual', 'predicted'],
        title=c + ' ' + cp['config']['DS']['FEATURES'][o],
        figsize=(10,6),
        grid=True
    )
    mn_l = DayLocator()
    ax.xaxis.set_minor_locator(mn_l)
    mj_l = AutoDateLocator()
    mj_f = ConciseDateFormatter(mj_l, show_offset=False)
    ax.xaxis.set_major_formatter(mj_f)
    # orig_df['total'] = orig_df['total'].astype('int')
    # orig_df['predicted'] = orig_df['predicted'].fillna(0).astype('int')
    # print(orig_df.tail(n_days_prediction))

    # arrow
    # peakx = 172
    # peak = orig_df.iloc[peakx]
    # peak_desc = peak['Date'].strftime("%d-%b") + "\n" + str(int(peak['predicted']))
    # _ = ax.annotate(
    #     peak_desc, 
    #     xy=(peak['Date'] - dt.timedelta(days=1), peak['predicted']),
    #     xytext=(peak['Date'] - dt.timedelta(days=45), peak['predicted'] * .9),
    #     arrowprops={},
    #     bbox={'facecolor':'white'}
    # )

    # _ = ax.axvline(x=peak['Date'], linewidth=1, color='r')

### Statewise predictions (covid19india)

In [None]:
# Download the state-wise time series and flatten it into one row per
# (state, date) holding the cumulative counts.
r = rq.get('https://api.covid19india.org/v3/min/timeseries.min.json', timeout=30)
# Fail here with a clear HTTP error instead of a confusing JSON decode error.
r.raise_for_status()
ts = r.json()

data = []
for state, by_date in ts.items():
    for date, entry in by_date.items():
        # An entry without a 'total' section is treated as all zeros, the same
        # way individual missing counters already are.
        ttl = entry.get('total', {})
        data.append((
            state,
            date,
            ttl.get('confirmed', 0),
            ttl.get('deceased', 0),
            ttl.get('recovered', 0),
            ttl.get('tested', 0),
        ))

states_df = pd.DataFrame(data, columns=['state', 'date', 'confirmed', 'deceased', 'recovered', 'tested'])
states_df['date'] = pd.to_datetime(states_df['date'])
# Earliest date present for any state; used later as the common start date
# when zero-filling each state's timeline.
first_case_date = states_df['date'].min()

In [None]:
# State code -> display name and approximate population.
# "popn" is used to scale case counts to per-1000-inhabitants.
# Source: http://www.populationu.com/india-population
STT_INFO = {
    'AN' : {"name": "Andaman & Nicobar Islands", "popn": 450000},
    'AP' : {"name": "Andhra Pradesh", "popn": 54000000},
    # Was 30000000 (identical to Punjab), roughly 20x the real figure of
    # about 1.4M in the 2011 census. TODO: confirm the exact estimate
    # against the source above.
    'AR' : {"name": "Arunachal Pradesh", "popn": 1500000},
    'AS' : {"name": "Assam", "popn": 35000000},
    'BR' : {"name": "Bihar", "popn": 123000000},
    'CH' : {"name": "Chandigarh", "popn": 1200000},
    'CT' : {"name": "Chhattisgarh", "popn": 29000000},
    'DL' : {"name": "Delhi", "popn": 19500000},
    'DN' : {"name": "Dadra & Nagar Haveli and Daman & Diu", "popn": 700000},
    'GA' : {"name": "Goa", "popn": 1580000},
    'GJ' : {"name": "Gujarat", "popn": 65000000},
    'HP' : {"name": "Himachal Pradesh", "popn": 7400000},
    'HR' : {"name": "Haryana", "popn": 28000000},
    'JH' : {"name": "Jharkhand", "popn": 38000000},
    'JK' : {"name": "Jammu & Kashmir", "popn": 13600000},
    'KA' : {"name": "Karnataka", "popn": 67000000},
    'KL' : {"name": "Kerala", "popn": 36000000},
    'LA' : {"name": "Ladakh", "popn": 325000},
    'MH' : {"name": "Maharashtra", "popn": 122000000},
    'ML' : {"name": "Meghalaya", "popn": 3400000},
    'MN' : {"name": "Manipur", "popn": 3000000},
    'MP' : {"name": "Madhya Pradesh", "popn": 84000000},
    'MZ' : {"name": "Mizoram", "popn": 1200000},
    'NL' : {"name": "Nagaland", "popn": 2200000},
    'OR' : {"name": "Odisha", "popn": 46000000},
    'PB' : {"name": "Punjab", "popn": 30000000},
    'PY' : {"name": "Puducherry", "popn": 1500000},
    'RJ' : {"name": "Rajasthan", "popn": 80000000},
    'TG' : {"name": "Telangana", "popn": 39000000},
    'TN' : {"name": "Tamil Nadu", "popn": 77000000},
    'TR' : {"name": "Tripura", "popn": 4100000},
    'UP' : {"name": "Uttar Pradesh", "popn": 235000000},
    'UT' : {"name": "Uttarakhand", "popn": 11000000},
    'WB' : {"name": "West Bengal", "popn": 98000000},
#     'SK' : {"name": "Sikkim", "popn": 681000},
#     'UN' : {"name": "Unassigned", "popn": 40000000}, #avg pop
#     'LD' : {"name": "Lakshadweep", "popn": 75000}
}

# uncomment for India
# STT_INFO = {
#     'TT' : {"name": "India", "popn": 1387155000}
# }

#### Dummy state data: fruit country

In [None]:
# Synthetic data for testing: every "state" gets the same series, starting at
# `total` on 1 March 2020 and growing by 3% per day until now.
# NOTE: running this cell REPLACES both STT_INFO and states_df. The generated
# frame only has 'state', 'date' and 'total' columns, while the prediction
# cell reads 'confirmed', 'deceased', 'recovered' and 'tested'.
#
# SET 1 - 10 states
# STT_INFO = {
#     'A': {"name": "Apple", "popn": 10000000},
#     'B': {"name": "Berry", "popn": 10000000},
#     'C': {"name": "Cherry", "popn": 10000000},
#     'D': {"name": "Dates", "popn": 10000000},
#     'E': {"name": "Elderberry", "popn": 10000000},
#     'F': {"name": "Fig", "popn": 10000000},
#     'G': {"name": "Grape", "popn": 10000000},
#     'H': {"name": "Honeysuckle", "popn": 10000000},
#     'I': {"name": "Icaco", "popn": 10000000},
#     'J': {"name": "Jujube", "popn": 10000000},
# }
# total = 100
# SET 2 - 1 agg state
STT_INFO = {'Z': {"name": "FruitCountry1000x", "popn": 10000000}}
total = 1000

# Column-oriented accumulator: one entry per (day, state).
r = {'state': [], 'date': [], 'total': []}

start_date = dt.datetime(year=2020, month=3, day=1)
end_date = dt.datetime.now()
while not start_date > end_date:
    n_states = len(STT_INFO)
    r['state'].extend(STT_INFO)
    r['date'].extend([start_date] * n_states)
    r['total'].extend([total] * n_states)
    # Compound daily growth; kept iterative so the float values match exactly.
    total *= 1.03
    start_date += dt.timedelta(days=1)

states_df = pd.DataFrame(r)
states_df['date'] = pd.to_datetime(states_df['date'])
states_df.tail()

#### Predict

In [None]:
def expand(df):
    '''Fill missing dates in an irregular timeline.

    Returns a new frame with one row per calendar day between the earliest and
    latest value of the 'date' column. Days that were missing repeat the values
    of the closest earlier row. 'date' is the first column of the result.

    The input frame is not modified. Rows may be given in any date order.
    An empty frame is returned as an (empty) copy.
    '''
    if df.empty:
        return df.copy()

    # drop() returns a new frame, so re-indexing it leaves the caller's
    # frame untouched.
    dated = df.drop(columns=['date'])
    dated.index = pd.DatetimeIndex(df['date'])
    # Forward-fill ("pad") reindexing requires a monotonic index.
    dated = dated.sort_index()

    idx = pd.date_range(dated.index.min(), dated.index.max())
    return dated.reindex(idx, method='pad').rename_axis('date').reset_index()

def prefill(df, min_date):
    '''Fill zeros from min_date to df.date.min().

    `df` must hold rows of exactly one state. The result has one row per
    calendar day from `min_date` to the latest date in `df`, with 'date' as
    the first column. In the rows that are added, 'state' is set to that
    state and every other column is set to 0; NaN values already present in
    those columns are filled the same way.

    The input frame is not modified.
    '''
    assert len(df.state.unique()) == 1, "prefill expects rows of a single state"
    s = df['state'].iloc[0]
    idx = pd.date_range(min_date, df['date'].max())

    # drop() returns a new frame, so re-indexing it leaves the caller's
    # frame untouched.
    dated = df.drop(columns=['date'])
    dated.index = pd.DatetimeIndex(df['date'])
    out = dated.reindex(idx).rename_axis('date').reset_index()

    # Zero-fill every value column, not just 'total', so frames that carry
    # other counters do not keep NaN in the added rows.
    fill = {col: (s if col == 'state' else 0) for col in out.columns if col != 'date'}
    return out.fillna(fill)

In [None]:
# Input / output sequence lengths from the checkpoint config.
IP_SEQ_LEN = cp['config']['DS']['IP_SEQ_LEN']
OP_SEQ_LEN = cp['config']['DS']['OP_SEQ_LEN']

# Feature column that is plotted and exported -- 0:confirmed, 1:deaths
plot_feature = 0
# how many days of data to skip
prediction_offset = 1
# number of days for prediction; must be a whole number of output windows
n_days_prediction = 200
assert n_days_prediction % OP_SEQ_LEN == 0

# Length of the gap-free national ('TT') timeline.
n_days_data = len(expand(states_df[states_df['state'] == 'TT']))

# number of days for plotting agg curve i.e. prediction + actual data
agg_days = n_days_data - prediction_offset + n_days_prediction
# Running sum of every state's daily series (actual followed by predicted).
states_agg = np.zeros(agg_days)

# Forecast every state in STT_INFO, add it to the aggregate series, collect the
# per-day numbers for the API export and draw all states on one shared chart.
ax = None
api = {}
ip_features = cp['config']['IP_FEATURES']
op_features = cp['config']['OP_FEATURES']
n_steps = n_days_prediction // OP_SEQ_LEN
count_cols = ['confirmed', 'deceased', 'recovered', 'tested']
daily_cols = (
    ('new_cases', 'confirmed'),
    ('new_deaths', 'deceased'),
    ('new_recovered', 'recovered'),
    ('new_tests', 'tested'),
)
for state in STT_INFO:
    pop_fct = STT_INFO[state]["popn"] / 1000

    # Skip the most recent rows (today's data is reported incomplete by the
    # API). Slicing by explicit length also works for prediction_offset == 0,
    # where "[:-0]" would return no rows at all.
    state_rows = states_df.loc[states_df['state'] == state]
    state_df = state_rows.iloc[:len(state_rows) - prediction_offset]
    if state_df.empty:
        print(state, "no data")
        continue

    state_df = prefill(expand(state_df), first_case_date)
    # Dates before the state's first report carry no counts; treat them as
    # zero so the differences and integer casts below stay finite.
    state_df[count_cols] = state_df[count_cols].fillna(0)
    # Daily values are the day-over-day difference of the cumulative counts.
    for new_col, cum_col in daily_cols:
        state_df[new_col] = state_df[cum_col] - state_df[cum_col].shift(1).fillna(0)
    # 7-day centred rolling mean, scaled to per-1000 inhabitants.
    test_data = np.array(state_df[cp['config']['DS']['FEATURES']].rolling(7, center=True, min_periods=1).mean() / pop_fct, dtype=np.float32)

    in_data = test_data[-IP_SEQ_LEN:, ip_features]
    out_data = np.ndarray(shape=(0, len(op_features)), dtype=np.float32)
    failed = False
    for _ in range(n_steps):
        ip = torch.tensor(in_data, dtype=torch.float32).to(DEVICE)
        try:
            pred = model.predict(ip.view(-1, IP_SEQ_LEN, len(ip_features))).view(OP_SEQ_LEN, len(op_features))
        except Exception as e:
            # Report and skip this state rather than carrying on with a stale
            # (or undefined) prediction from an earlier step or state.
            print(state, e)
            failed = True
            break
        pred_np = pred.cpu().numpy()
        # Slide the input window: append, then keep the last IP_SEQ_LEN rows.
        # Slicing after the append also works when OP_SEQ_LEN >= IP_SEQ_LEN.
        in_data = np.append(in_data, pred_np, axis=0)[-IP_SEQ_LEN:]
        out_data = np.append(out_data, pred_np, axis=0)
    if failed:
        continue

    sn = STT_INFO[state]['name']
    # Builtin int replaces the np.int alias removed in NumPy 1.24;
    # pd.concat replaces DataFrame.append removed in pandas 2.0.
    orig_df = pd.concat(
        [
            pd.DataFrame({'actual': np.array(test_data[:, plot_feature] * pop_fct, dtype=int)}),
            pd.DataFrame({'predicted': np.array(out_data[:, plot_feature] * pop_fct, dtype=int)}),
        ],
        ignore_index=True,
        sort=False,
    )
    # print(orig_df[['predicted']].dropna().to_csv(sep='|'))
    # Single daily series: actual values followed by predicted values.
    orig_df[sn] = orig_df['actual'].fillna(0) + orig_df['predicted'].fillna(0)
    orig_df['total'] = orig_df[sn].cumsum()

    states_agg += np.array(orig_df[sn][-agg_days:].fillna(0))

    # generate date col for orig_df from state_df (one calendar day per row)
    start_date = state_df['date'].iloc[0]
    orig_df['Date'] = pd.date_range(pd.Timestamp(start_date).normalize(), periods=len(orig_df))
#     if orig_df[sn].max() < 10000: # or orig_df[sn].max() < 5000:
#         continue

    # print state, cumulative, peak
    peak = orig_df.loc[orig_df[sn].idxmax()]
    print(sn, "|", peak['Date'].strftime("%b %d"), "|", int(peak[sn]), "|", int(orig_df['total'].iloc[-1]))

    # export data for API
    # Deaths are a fixed 2.8% of daily cases; recoveries are the cases of 14
    # days earlier minus the deaths of 7 days earlier.
    orig_df['daily_deaths'] = orig_df[sn] * 0.028
    orig_df['daily_recovered'] = orig_df[sn].shift(14, fill_value=0) - orig_df['daily_deaths'].shift(7, fill_value=0)
    orig_df['daily_active'] = orig_df[sn] - orig_df['daily_recovered'] - orig_df['daily_deaths']

    api[state] = {}
    for idx, row in orig_df[-agg_days:].iterrows():
        row_date = row['Date'].strftime("%Y-%m-%d")
        api[state][row_date] = {
            "delta": {
                "confirmed": int(row[sn]),
                "deceased": int(row['daily_deaths']),
                "recovered": int(row['daily_recovered']),
                "active": int(row['daily_active'])
            }
        }

    # plot state chart (all states share one axes via ax=ax)
    ax = orig_df.plot(
        x='Date',
        y=[sn],
        title='Daily Cases',
        figsize=(15,10),
        grid=True,
        ax=ax,
        lw=3
    )
    mn_l = DayLocator()
    ax.xaxis.set_minor_locator(mn_l)
    mj_l = AutoDateLocator()
    mj_f = ConciseDateFormatter(mj_l, show_offset=False)
    ax.xaxis.set_major_formatter(mj_f)

# plot aggregate chart
cum_df = pd.DataFrame({
    'states_agg': states_agg
})
# states_agg was accumulated from the LAST agg_days rows of each state's
# orig_df, so its final entry falls on orig_df's final date. The first entry is
# therefore (len - 1) days earlier; subtracting the full length would shift
# every aggregate date one day early relative to the per-state rows.
last_date = orig_df['Date'].iloc[-1].to_pydatetime()
start_date = last_date - dt.timedelta(days=len(cum_df) - 1)
cum_df['Date'] = pd.date_range(start_date, periods=len(cum_df))
ax = cum_df.plot(
    x='Date',
    y=['states_agg'],
    title='Aggregate daily cases',
    figsize=(15,10),
    grid=True,
    lw=3
)
# Minor tick per day; concise, automatically spaced major date labels.
mn_l = DayLocator()
ax.xaxis.set_minor_locator(mn_l)
mj_l = AutoDateLocator()
mj_f = ConciseDateFormatter(mj_l, show_offset=False)
ax.xaxis.set_major_formatter(mj_f)

# plot peak in agg
# Locate the peak from the data instead of a hand-picked row number, which
# raises IndexError on shorter series and goes stale when the data changes.
# Assign a fixed row index here to annotate a different day.
peakx = int(cum_df['states_agg'].values.argmax())
peak = cum_df.iloc[peakx]
peak_desc = peak['Date'].strftime("%d-%b") + "\n" + str(int(peak['states_agg']))
# Label box placed 45 days to the right of the peak, slightly below it, with
# an arrow pointing back at the peak.
_ = ax.annotate(
    peak_desc, 
    xy=(peak['Date'] + dt.timedelta(days=1), peak['states_agg']),
    xytext=(peak['Date'] + dt.timedelta(days=45), peak['states_agg'] * .9),
    arrowprops={},
    bbox={'facecolor':'white'}
)
_ = ax.axvline(x=peak['Date'], linewidth=1, color='r')

#### Export JSON for API

In [None]:
# aggregate predictions
# Sum the per-state daily numbers into a national 'TT' entry, keyed by date.
# NOTE: any existing api['TT'] entry is discarded first.
api['TT'] = {}
for state, series in api.items():
    if state == 'TT':
        continue
    for date, entry in series.items():
        agg = api['TT'].setdefault(date, {'delta':{}, 'total':{}})
        for k in ['delta']: #'total'
            for field in ('confirmed', 'deceased', 'recovered', 'active'):
                agg[k][field] = agg[k].get(field, 0) + entry[k][field]

# export
with open("predictions.json", "w") as f:
    json.dump(api, f, sort_keys=True)

#### Export data for video player

In [None]:
# aggregate predictions
# Rebuild api['TT'] in the compact video-player schema: per date, the summed
# confirmed / deceased / recovered / active values under the keys c / d / r / a.
# NOTE: this replaces whatever api['TT'] held before.
api['TT'] = {}
for state, series in api.items():
    if state == 'TT':
        continue
    for date, entry in series.items():
        day = api['TT'].setdefault(date, {})
        for short, field in (('c', 'confirmed'), ('d', 'deceased'), ('r', 'recovered'), ('a', 'active')):
            day[short] = day.get(short, 0) + entry['delta'][field]

# read previous and export
# The run is filed under the last data date actually used for the prediction.
k = (states_df.date.max().to_pydatetime() - dt.timedelta(days=prediction_offset)).strftime("%Y-%m-%d")
try:
    with open("vp.json", "r") as f:
        out = json.load(f)
except FileNotFoundError:
    # First run: no history yet.
    out = {}
except (OSError, ValueError) as e:
    # Unreadable or corrupt history: still start over, but say so, because
    # the file is overwritten below.
    print("WARNING: could not read vp.json, starting a new history:", e)
    out = {}

# Update the history before opening the file for writing, so a failure here
# cannot leave a truncated vp.json behind.
out[k] = {'TT': api['TT']}
with open("vp.json", "w") as f:
    json.dump(out, f, sort_keys=True)

#### CSV export video player output

In [None]:
# One row per date with the c / d / r / a columns of the run exported above;
# only the 'c' column is written to a CSV named after the run key.
df_csv = pd.DataFrame(out[k]['TT']).T
df_csv['c'].to_csv(f"vp_{k}.csv")

#### Upload model to Neptune

In [None]:
# Attach the checkpoint file to the Neptune experiment whose id is stored in
# the checkpoint config.
neptune_prj = neptune.init(NEPTUNE_PRJ)
neptune_exp = neptune_prj.get_experiments(id=cp['config']['NEPTUNE_ID'])[0]
neptune_exp.log_artifact(f"{EXPERIMENTS_DIR}/{experiment_id}/{checkpoint}")