In [25]:
from data_utils import *

In [26]:
def quantile_score(q, y, alpha):
    """Return the quantile (pinball) score of a predicted quantile, scaled by 2.

    Args:
        q: Predicted value for the quantile.
        y: Value the prediction is compared against.
        alpha: Quantile level belonging to ``q``.

    Returns:
        ``2 * (1{y < q} - alpha) * (q - y)``.
    """
    # 1 when y lies strictly below the predicted quantile, otherwise 0.
    indicator = int(y < q)
    weight = indicator - alpha
    error = q - y
    return 2 * weight * error

def score(prediction, observation, score_type, quantile):
    """Score a single prediction against an observation.

    Args:
        prediction: Predicted value.
        observation: Value to compare the prediction with.
        score_type: ``"mean"`` (squared error), ``"median"`` (absolute error)
            or ``"quantile"`` (quantile score).
        quantile: Quantile level; only used when ``score_type`` is ``"quantile"``.

    Returns:
        The score selected by ``score_type``.

    Raises:
        ValueError: If ``score_type`` is none of the supported values.
    """
    if score_type == "quantile":
        return quantile_score(prediction, observation, quantile)
    if score_type == "median":
        return abs(prediction - observation)
    if score_type == "mean":
        return (prediction - observation) ** 2
    raise ValueError("Invalid type specified")

def compute_wis(df):
    """Compute the mean WIS per model together with its decomposition.

    Only rows with ``type == 'quantile'`` are scored. Each quantile row gets the
    quantile score ``2 * (1{y < q} - alpha) * (q - y)`` twice: once against the
    observed ``target`` (``wis``) and once against the forecast's own median
    (``spread``). The remainder ``wis - spread`` is counted as ``overprediction``
    when the median lies above the target and as ``underprediction`` when it
    lies below. All four columns are then averaged per ``model``.

    The scores are computed with column arithmetic instead of row-wise
    ``DataFrame.apply``: the formulas are unchanged, but no Python-level loop
    over rows is needed, and an input without any quantile rows yields an
    empty result.

    Args:
        df: DataFrame with at least the columns ``type``, ``quantile``,
            ``value``, ``target`` and ``model``. Every other column is used as
            a key to match each quantile row with its median row.

    Returns:
        DataFrame with one row per model and the columns ``model``, ``spread``,
        ``overprediction``, ``underprediction`` and ``wis``.
    """
    quantile_rows = df[df['type'] == 'quantile']

    # The 0.5 quantile is the median forecast; keep it as 'med'.
    medians = (
        quantile_rows[quantile_rows['quantile'] == 0.5]
        .rename(columns={'value': 'med'})
        .drop(columns=['quantile'], errors='ignore')
    )

    # The left merge joins on all shared columns, so every quantile row receives
    # the median of its own forecast (NaN if that forecast has no 0.5 quantile).
    scored = quantile_rows.merge(medians, how='left')

    value = scored['value'].astype(float)
    target = scored['target'].astype(float)
    median = scored['med'].astype(float)
    alpha = scored['quantile'].astype(float)

    # Quantile score against the observation and against the median.
    wis = 2 * ((target < value).astype(int) - alpha) * (value - target)
    spread = 2 * ((median < value).astype(int) - alpha) * (value - median)

    # Whatever is not spread is attributed to over- or underprediction,
    # depending on which side of the target the median lies.
    excess = wis - spread
    scored = scored.assign(
        wis=wis,
        spread=spread,
        overprediction=excess.where(median > target, 0.0),
        underprediction=excess.where(median < target, 0.0),
    )

    # Average each component per model.
    return (
        scored.groupby('model')
        .agg({
            'spread': 'mean',
            'overprediction': 'mean',
            'underprediction': 'mean',
            'wis': 'mean',
        })
        .reset_index()
    )

In [30]:
# Load the forecast submissions that are scored below.
# NOTE(review): load_submissions is not defined in this notebook; presumably it is
# provided by the star import from data_utils — confirm.
df = load_submissions()

In [31]:
# Score every (source, disease, level, location, age_group, horizon) group separately;
# compute_wis returns one row per model within each group.
# reset_index() turns the six group keys into columns. The remaining unnamed index level
# (the row index of each compute_wis result) is emitted as 'level_6' and dropped.
df_wis = df.groupby(['source', 'disease', 'level', 'location', 'age_group', 'horizon'])[df.columns].apply(compute_wis).reset_index().drop(columns='level_6')

In [32]:
# Inspect the WIS table (one row per group and model).
df_wis

Unnamed: 0,source,disease,level,location,age_group,horizon,model,spread,overprediction,underprediction,wis
0,agi,are,age,DE,00-04,-3,KIT-simple_nowcast,8.544097,0.000000,214079.868525,214088.412621
1,agi,are,age,DE,00-04,-2,KIT-simple_nowcast,13.870924,0.000000,222597.781206,222611.652130
2,agi,are,age,DE,00-04,-1,KIT-simple_nowcast,24.168074,0.000000,226179.241533,226203.409607
3,agi,are,age,DE,00-04,0,KIT-simple_nowcast,49.832295,0.000000,228129.065022,228178.897317
4,agi,are,age,DE,00-04,1,KIT-LightGBM,4566.593938,1414.000616,2431.589062,8412.183615
...,...,...,...,...,...,...,...,...,...,...,...
808,survstat,rsv,states,DE-TH,00+,-3,KIT-simple_nowcast,0.000496,0.000000,0.000000,0.000496
809,survstat,rsv,states,DE-TH,00+,-2,KIT-simple_nowcast,0.006824,0.000000,0.034739,0.041563
810,survstat,rsv,states,DE-TH,00+,-1,KIT-simple_nowcast,0.030273,0.000000,0.057072,0.087345
811,survstat,rsv,states,DE-TH,00+,0,KIT-simple_nowcast,0.247643,0.034739,0.032258,0.314640


In [24]:
# Optional export of the WIS table on its own; intentionally left disabled.
# df_wis.to_csv('../data/wis.csv', index=False)

In [44]:
# Display df_wis.
df_wis

Unnamed: 0,source,disease,level,location,age_group,horizon,model,spread,overprediction,underprediction,wis
0,agi,are,age,DE,00-04,-3,KIT-simple_nowcast,8.544097,0.000000,214079.868525,214088.412621
1,agi,are,age,DE,00-04,-2,KIT-simple_nowcast,13.870924,0.000000,222597.781206,222611.652130
2,agi,are,age,DE,00-04,-1,KIT-simple_nowcast,24.168074,0.000000,226179.241533,226203.409607
3,agi,are,age,DE,00-04,0,KIT-simple_nowcast,49.832295,0.000000,228129.065022,228178.897317
4,agi,are,age,DE,00-04,1,KIT-LightGBM,4566.593938,1414.000616,2431.589062,8412.183615
...,...,...,...,...,...,...,...,...,...,...,...
808,survstat,rsv,states,DE-TH,00+,-3,KIT-simple_nowcast,0.000496,0.000000,0.000000,0.000496
809,survstat,rsv,states,DE-TH,00+,-2,KIT-simple_nowcast,0.006824,0.000000,0.034739,0.041563
810,survstat,rsv,states,DE-TH,00+,-1,KIT-simple_nowcast,0.030273,0.000000,0.057072,0.087345
811,survstat,rsv,states,DE-TH,00+,0,KIT-simple_nowcast,0.247643,0.034739,0.032258,0.314640


In [42]:
def compute_coverage(df):
    """Compute the empirical coverage of the 50% and 95% prediction intervals.

    Quantile rows are pivoted to one row per forecast with one column per
    quantile level. A forecast counts as covered when ``target`` lies inside
    the closed interval [0.25, 0.75] (``c50``) or [0.025, 0.975] (``c95``).
    The coverage indicators are then averaged per source, disease, model,
    level, location, age group and horizon.

    Args:
        df: Long-format DataFrame holding the columns used below; for rows with
            ``type == 'quantile'`` it must provide the quantile levels 0.025,
            0.25, 0.75 and 0.975.

    Returns:
        DataFrame with the grouping columns plus ``c50`` and ``c95``.
    """
    forecast_keys = ['source', 'disease', 'level', 'location', 'age_group', 'forecast_date',
                     'target_end_date', 'horizon', 'type', 'model', 'date', 'year', 'week', 'target']
    group_keys = ['source', 'disease', 'model', 'level', 'location', 'age_group', 'horizon']

    # One row per forecast, one column per quantile level.
    quantile_rows = df[df['type'] == 'quantile']
    wide = quantile_rows.pivot(index=forecast_keys, columns='quantile', values='value')
    wide.columns = ['quantile_{}'.format(level) for level in wide.columns]
    wide = wide.reset_index()

    # Interval bounds are inclusive on both sides.
    observed = wide['target']
    wide['c50'] = observed.between(wide['quantile_0.25'], wide['quantile_0.75'])
    wide['c95'] = observed.between(wide['quantile_0.025'], wide['quantile_0.975'])

    # The mean of the boolean indicators is the share of covered forecasts.
    coverage = wide.groupby(group_keys).agg({'c50': 'mean', 'c95': 'mean'})
    return coverage.reset_index()

In [45]:
# Empirical 50% / 95% interval coverage per model and group.
# NOTE: `df` must still hold the raw submissions here; a later cell rebinds `df`
# to the merged score table.
df_coverage = compute_coverage(df)

In [48]:
# Inspect the coverage table.
df_coverage

Unnamed: 0,source,disease,model,level,location,age_group,horizon,c50,c95
0,agi,are,KIT-LightGBM,age,DE,00-04,1,0.590909,0.954545
1,agi,are,KIT-LightGBM,age,DE,00-04,2,0.500000,0.909091
2,agi,are,KIT-LightGBM,age,DE,00-04,3,0.318182,0.909091
3,agi,are,KIT-LightGBM,age,DE,00-04,4,0.363636,0.954545
4,agi,are,KIT-LightGBM,age,DE,05-14,1,0.409091,0.863636
...,...,...,...,...,...,...,...,...,...
808,survstat,rsv,KIT-simple_nowcast,states,DE-TH,00+,-3,1.000000,1.000000
809,survstat,rsv,KIT-simple_nowcast,states,DE-TH,00+,-2,0.952381,0.952381
810,survstat,rsv,KIT-simple_nowcast,states,DE-TH,00+,-1,0.952381,1.000000
811,survstat,rsv,KIT-simple_nowcast,states,DE-TH,00+,0,0.952381,0.952381


In [51]:
# Combine WIS and coverage into one table, matching rows on the shared key columns
# (merge's default inner join keeps only keys present in both tables).
# NOTE(review): this rebinds `df`, which previously held the raw submissions, so the
# cells above that read `df` cannot be re-run after this one without reloading the data.
df = df_wis.merge(df_coverage, on=['source', 'disease', 'model', 'level', 'location', 'age_group', 'horizon'])

In [52]:
# Write the table currently bound to `df` to ../data/scores.csv, without the index column.
df.to_csv('../data/scores.csv', index=False)