In [None]:
# General imports
import random

import matplotlib
import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt

# GluonTS imports
from gluonts.dataset.common import ListDataset
from gluonts.dataset.split import split
from gluonts.torch.model.simple_feedforward import SimpleFeedForwardEstimator

# SimbaML imports
from simba_ml.simulation import distributions, generators
from simba_ml.simulation import kinetic_parameters as kinetic_parameters_module
from simba_ml.simulation import noisers
from simba_ml.simulation import species, system_model


In [None]:
# Date of the first observation in every series handed to GluonTS.
start_date = pd.to_datetime('2020-02-20')
# Number of observed points the models may train on; forecasting starts right after them.
offset = 22

# Forecast horizon and look-back window (in days) of the feed-forward models.
prediction_length = 7
context_length = 7

# Seed the global RNGs so simulations and training are repeatable across runs.
# NOTE(review): assumes SimbaML draws from numpy's global RNG and GluonTS from
# torch's — confirm; full determinism may additionally need Lightning's
# `seed_everything` and deterministic kernels.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
name = "SIR - Covid-19 - Data Augmentation"

# Total population, obtained from Destatis (table 12411-0001):
# https://www-genesis.destatis.de/genesis/online?operation=abruftabelleBearbeiten&levelindex=1&levelid=1676991208921&auswahloperation=abruftabelleAuspraegungAuswaehlen&auswahlverzeichnis=ordnungsstruktur&auswahlziel=werteabruf&code=12411-0001&auswahltext=&werteabruf=Value+retrieval#abreadcrumb
POPULATION = 83166711
# People infected at t=0; they are taken out of the susceptible pool and also
# seed the cumulative infection counter.
INITIAL_INFECTED = 100

# Model compartments, in the order `deriv` unpacks them (S, I, R, C).
# Only the cumulative infection count is part of the simulated output.
specieses = [
    species.Species("Susceptible", distributions.Constant(POPULATION - INITIAL_INFECTED), contained_in_output=False, min_value=0),
    species.Species("Infected", distributions.Constant(INITIAL_INFECTED), contained_in_output=False, min_value=0),
    species.Species("Recovered", distributions.Constant(0), contained_in_output=False, min_value=0),
    species.Species("Cumulative Infected", distributions.Constant(INITIAL_INFECTED), contained_in_output=True, min_value=0),
]

# SIR rates, each drawn from a uniform range; presumably sampled once per
# simulation (ConstantKineticParameter) — confirm against SimbaML.
kinetic_parameters: dict[str, kinetic_parameters_module.KineticParameter] = {
    # Transmission rate: scales the S*I/N incidence term in `deriv`.
    "beta": kinetic_parameters_module.ConstantKineticParameter(distributions.ContinuousUniformDistribution(0.32, 0.35)),
    # Recovery rate: fraction of I moving to R per time step in `deriv`.
    "gamma": kinetic_parameters_module.ConstantKineticParameter(distributions.ContinuousUniformDistribution(0.123, 0.125)),
}

def deriv(_t: float, y: list[float], arguments: dict[str, float]) -> tuple[float, float, float, float]:
    """Right-hand side of the SIR ODE system extended with a cumulative-infections counter.

    Args:
        _t: Current time. Unused: the dynamics do not depend on time explicitly.
        y: Current state ``[S, I, R, C]`` (susceptible, infected, recovered,
            cumulative infected). ``C`` does not feed back into the dynamics;
            it only accumulates new infections. Presumably ordered like the
            ``specieses`` list passed to the system model — confirm.
        arguments: Kinetic parameters; must contain ``"beta"`` and ``"gamma"``.

    Returns:
        Tuple ``(dS/dt, dI/dt, dR/dt, dC/dt)``.
    """
    S, I, R, _ = y
    # Population size; C is excluded because it counts people already in I or R.
    N = S + I + R

    # Incidence: leaves S, enters I and is added to the cumulative counter.
    new_infections = arguments["beta"] * S * I / N
    recoveries = arguments["gamma"] * I

    dS_dt = -new_infections
    dI_dt = new_infections - recoveries
    dR_dt = recoveries
    dC_dt = new_infections
    return dS_dt, dI_dt, dR_dt, dC_dt



# Additive noise drawn from NormalDistribution(0, 42 * 10**3); presumably
# (mean, standard deviation) in cases — confirm against SimbaML.
noiser = noisers.AdditiveNoiser(
    distributions.NormalDistribution(0, 42 * 10**3)
)

# Assemble the SIR system model from the species, kinetic parameters and ODE
# defined above; timestamps is the per-simulation length (Constant(100)).
sm = system_model.SystemModel(
    name,
    specieses,
    kinetic_parameters,
    deriv=deriv,
    noiser=noiser,
    timestamps=distributions.Constant(100),
)
    


In [None]:
# Generate 100 simulated trajectories from the system model.
simulations = generators.TimeSeriesGenerator(sm).generate_signals(n=100)

# New cases per step = first difference of the cumulative infection count
# (the first row of each difference is NaN).
simulations_new_cases = [
    sim_frame.assign(new_cases=sim_frame["Cumulative Infected"].diff())
    for sim_frame in simulations
]

# Keep rows 20..99 of every simulation (this also drops the leading NaN) and
# wrap them as GluonTS entries starting at `start_date`.
# NOTE(review): the reason for discarding the first 20 steps is not stated —
# presumably to skip the early growth phase; confirm.
sim_targets = [
    {"target": sim_frame["new_cases"].iloc[20:100].to_numpy(), "start": start_date}
    for sim_frame in simulations_new_cases
]

In [None]:
# Observed RKI case numbers: rows 50..150 (`.loc` is label-based, so inclusive),
# re-indexed from 0. Row 0 is assumed to fall on `start_date` — TODO confirm
# against the CSV's `day_idx` column.
real_data = (
    pd.read_csv('data/rki_case_numbers_germany.csv')
    .loc[50:150]
    .reset_index(drop=True)
)
real_target = [
    {"target": real_data["new_cases_7d_average"].to_numpy(), "start": start_date}
]


In [None]:

# Augmented training set: the observed series truncated to its first `offset`
# points (the rest is held out for evaluation), followed by all simulations.
target = [
    {"target": real_target[0]["target"][:offset], "start": start_date},
    *sim_targets,
]

In [None]:
# Daily GluonTS datasets: augmented training data and the full observed series.
dataset = ListDataset(target, freq="d")
real_dataset = ListDataset(real_target, freq="d")

# Split the observed series after its first `offset` observations:
# `train_real` holds those, `test_gen` produces the windows that follow.
train_real, test_gen = split(real_dataset, offset=offset)


In [None]:
# Train on the augmented dataset (observed prefix + simulated series).
model = SimpleFeedForwardEstimator(
    prediction_length=prediction_length,
    context_length=context_length,
    trainer_kwargs={"max_epochs": 30},
)
predictor = model.train(dataset)

# Forecast the single held-out window right after the observed prefix.
test_data = test_gen.generate_instances(prediction_length=prediction_length, windows=1)
forecasts_mix = list(predictor.predict(test_data.input))

In [None]:
# Baseline: same architecture trained on the observed prefix only.
del model
model = SimpleFeedForwardEstimator(
    prediction_length=prediction_length,
    context_length=context_length,
    trainer_kwargs={"max_epochs": 30},
    weight_decay=0.01,
)
# NOTE(review): only this baseline sets weight_decay=0.01, while the augmented
# model uses the estimator default. Confirm this is intentional; otherwise the
# comparison mixes data augmentation with a regularisation change.
predictor = model.train(train_real)

test_data = test_gen.generate_instances(prediction_length=prediction_length, windows=1)
forecasts_obs_only = list(predictor.predict(test_data.input))

In [None]:
# Render figures through the PGF backend using pdflatex (needs a LaTeX install).
matplotlib.use("pgf")
matplotlib.rcParams.update(
    {
        "pgf.texsystem": "pdflatex",
        "font.family": "serif",
        "text.usetex": True,
        "pgf.rcfonts": False,
    }
)

In [None]:
# Plot the observed series together with both 7-day forecasts and save figure2.pdf.
#
# Both models were trained on the first `offset` observations (rows
# 0 .. offset-1; see `split(..., offset=offset)`), so their forecasts cover
# rows offset .. offset+prediction_length-1. Each forecast is drawn anchored at
# the last *training* observation (row offset-1). It therefore continues the
# ground-truth curve without shifting the predictions one day to the right or
# anchoring on a held-out value.


def plot_anchored_forecast(forecast, anchor_date: pd.Period, anchor_value: float, color: str) -> None:
    """Plot a forecast prefixed with one observed point so it joins the ground truth.

    Args:
        forecast: GluonTS forecast exposing ``to_sample_forecast``.
        anchor_date: Daily period of the last training observation; becomes the
            start of the plotted forecast.
        anchor_value: Observed value at ``anchor_date``, prepended to every
            sample path.
        color: Matplotlib colour used for the forecast.
    """
    sample_forecast = forecast.to_sample_forecast(num_samples=10000)
    sample_forecast.start_date = anchor_date
    # Prepend the anchor as an extra first column of every sample path.
    anchor_column = np.full((len(sample_forecast.samples), 1), anchor_value)
    sample_forecast.samples = np.concatenate([anchor_column, sample_forecast.samples], axis=1)
    # Draws onto the current axes with 50 % and 85 % intervals.
    sample_forecast.plot(intervals=(0.50, 0.85), color=color)


medium = 11
large = 12

# Font sizes must be set before the axes are created to affect tick labels.
plt.rc('font', size=large)
plt.rc('axes', titlesize=large)
plt.rc('axes', labelsize=large)
plt.rc('xtick', labelsize=medium)
plt.rc('ytick', labelsize=medium)
plt.rc('legend', fontsize=large)
plt.rc('figure', titlesize=large)

augmented_color = "#44AA99"
real_only_color = "#AA4499"

# Work on a datetime copy of the dates instead of mutating `real_data`.
plot_dates = pd.to_datetime(real_data["day_idx"])
cases = real_data["new_cases_7d_average"]

# Last observation the models were trained on.
last_train_row = offset - 1
cutoff_timestamp = plot_dates.iloc[last_train_row]
cutoff_period = cutoff_timestamp.to_period(freq='D')
cutoff_value = cases.iloc[last_train_row]

# Show the series up to (and including) the last forecast day.
n_shown = offset + prediction_length
x_start = plot_dates.iloc[0]
x_end = plot_dates.iloc[n_shown - 1]

fig, ax = plt.subplots(figsize=(6.8, 4.8))

# Ground truth time series
ax.plot(plot_dates.iloc[:n_shown], cases.iloc[:n_shown],
        label="Ground Truth", color="#332288", linewidth=2.5)

# Forecasts of both models (GluonTS draws onto the current axes, i.e. `ax`).
plot_anchored_forecast(forecasts_mix[0], cutoff_period, cutoff_value, augmented_color)
plot_anchored_forecast(forecasts_obs_only[0], cutoff_period, cutoff_value, real_only_color)

# The forecast plots carry no legend labels, so add proxy handles for them.
handles, labels = ax.get_legend_handles_labels()
handles += [
    plt.Line2D([0], [0], color=augmented_color, linewidth=2.5),
    plt.Line2D([0], [0], color=real_only_color, linewidth=2.5),
]
labels += ["Forecast: Synthetically Augmented Data", "Forecast: Only Real Data"]
ax.legend(handles, labels, loc="upper left", fontsize="medium", ncols=1)

# Ticks every 4 days, aligned so that one tick falls on the training cutoff.
all_days = pd.date_range(x_start, x_end, freq="D")
tick_dates = all_days[(all_days - cutoff_timestamp).days % 4 == 0]
ax.set_xticks(tick_dates, [date.strftime('%d\n%b') for date in tick_dates])

ax.set_xlabel("Date (2020)", fontsize="large")
ax.set_ylabel("New Cases (7-day average)", fontsize="large")

# Training cutoff: ymin/ymax are axes fractions, so the defaults (0, 1) span the full height.
ax.axvline(x=cutoff_timestamp, color="black", linestyle="dashed", linewidth=1)
# Label sits 2.5 h left of the line so it does not overlap it.
ax.text(cutoff_timestamp - pd.Timedelta(hours=2.5), 4500, "Training cutoff",
        fontsize=12, color="black", rotation=90, horizontalalignment='right')

# Shade the training period. Local names are used so the global `start_date`
# from the config cell is not overwritten.
ax.axvspan(x_start, cutoff_timestamp, facecolor='grey', alpha=0.14)

ax.set_xlim(x_start, x_end)
ax.set_ylim(0, 12000)

fig.savefig('figure2.pdf', bbox_inches='tight')
plt.close(fig)
