In [3]:
from datetime import datetime as dt
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from gretel_synthetics.timeseries_dgan.config import DGANConfig
from gretel_synthetics.timeseries_dgan.dgan import DGAN

sns.set_theme()

##### Import load profiles and remove incomplete days

In [None]:
# Raw load profiles. Judging from the cells below, the frame holds 'date', 'timestamp',
# 'hour of the day', 'month of the year', 'day off' and one column per profile id (e.g. '16').
# NOTE(review): schema inferred from usage only — confirm against the parquet file.
loadProfiles_df = pd.read_parquet('data/load_profiles.parquet.gzip')

In [None]:
def remove_incomplete_days(df, hours_per_day = 24):
    """Drop every date for which some column has fewer than `hours_per_day` non-null values.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with a 'date' column; every other column is counted per date.
    hours_per_day : int, default 24
        Minimum number of non-null values each column must have for a date to be kept.

    Returns
    -------
    pandas.DataFrame
        A new, independent frame containing only the complete dates (original index kept).
        NOTE(review): dates with *more* than `hours_per_day` rows are kept as well.
    """
    counts_per_date = df.groupby('date').count()
    incomplete_dates = counts_per_date.index[(counts_per_date < hours_per_day).any(axis = 1)]
    # `.copy()` detaches the result from `df`: otherwise pandas tracks it as a slice and
    # emits SettingWithCopyWarning when a column is added to it later.
    return df.loc[~df['date'].isin(incomplete_dates)].copy()

In [None]:
# Keep only dates with a full set of hourly values.
loadProfiles_df = loadProfiles_df.pipe(remove_incomplete_days)

##### Select and visualize profile

In [None]:
# Column name (a string) of the load profile to work with.
profile = "16"

In [None]:
def plot_profile(df, profile, x = 'date', custom_title = None, ylabel = 'Consumed energy [Wh]'):
    """Plot column `profile` of `df` against column `x` as a single line.

    Parameters
    ----------
    df : pandas.DataFrame
        Data containing the columns `x` and `profile`.
    profile : str
        Column plotted on the y axis.
    x : str, default 'date'
        Column plotted on the x axis; its capitalized name becomes the x label.
    custom_title : str, optional
        Title to use instead of the default ``'Profile <profile>'`` (ignored if empty).
    ylabel : str, default 'Consumed energy [Wh]'
        Y-axis label.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, already closed via ``plt.close`` so pyplot does not auto-show it;
        display it by returning it from a cell.
    """
    fig, ax = plt.subplots(figsize = (12, 4), facecolor = 'w')
    ax.plot(df[x], df[profile])
    ax.set_title(custom_title or f'Profile {profile}', fontsize = 16)
    ax.set_xlabel(x.capitalize(), fontsize = 14)
    ax.set_ylabel(ylabel, fontsize = 14)
    # tick_params also applies to tick labels created later (unlike plt.xticks/yticks).
    ax.tick_params(axis = 'both', labelsize = 12.5)
    plt.close(fig)
    return fig


def plot_subsequence(df, profile, date):
    """Plot `profile` against 'hour of the day' for the single day `date`.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'date' (compared against a ``datetime.date``), 'hour of the day'
        and the `profile` column.
    profile : str
        Column to plot.
    date : str
        Day to plot, formatted 'YYYY-MM-DD'.

    Returns
    -------
    matplotlib.figure.Figure
        Figure produced by ``plot_profile``.

    Raises
    ------
    ValueError
        If `date` is not 'YYYY-MM-DD' or no row of `df` falls on that day.
    """
    day = dt.strptime(date, '%Y-%m-%d').date()
    day_df = df[df['date'] == day]
    # Fail with a clear message instead of an IndexError from `unique()[0]` below.
    if day_df.empty:
        raise ValueError(f'No rows found for date {date}')
    fig = plot_profile(day_df, profile, 'hour of the day', f"Profile {profile}\n{day_df['date'].unique()[0]}")
    return fig

In [None]:
# Whole time series of the selected profile (x axis defaults to 'date').
plot_profile(df = loadProfiles_df, profile = profile)

In [None]:
# Zoom in on one day of the selected profile.
plot_subsequence(df = loadProfiles_df, profile = profile, date = '2021-06-03')

##### Create and train model (can be skipped if model already exists)

In [None]:
# DGAN hyper-parameters. Sequences are 24 steps long, matching the 24 hourly rows per date
# required by remove_incomplete_days.
learningRate = 2e-5  # shared by the generator and both discriminators
dganParams = dict(
    max_sequence_len = 24,
    sample_len = 1,
    feature_noise_dim = 32,
    feature_num_layers = 1,
    feature_num_units = 100,
    apply_feature_scaling = True,
    apply_example_scaling = False,
    generator_learning_rate = learningRate,
    discriminator_learning_rate = learningRate,
    attribute_discriminator_learning_rate = learningRate,
    batch_size = 100,
    epochs = 10000,  # NOTE(review): long run — expect training to be slow
    attribute_loss_coef = 10,
)
config = DGANConfig(**dganParams)

In [None]:
# Build an untrained DGAN from `config`; it is trained in the cells below.
model = DGAN(config)

In [None]:
# Map each calendar date to a consecutive integer (in order of first appearance); it is
# used as the DGAN example id, so each date becomes one training sequence.
dateIndex_dict = {item: idx for idx, item in enumerate(loadProfiles_df['date'].unique())}
# `assign` returns a new frame instead of writing a column into a possibly-filtered one,
# which avoids pandas' SettingWithCopyWarning and keeps this cell re-runnable.
loadProfiles_df = loadProfiles_df.assign(**{'date index': loadProfiles_df['date'].map(dateIndex_dict)})

In [None]:
# NOTE(review): this re-assigns `profile`, overriding the value chosen in the
# "Select and visualize profile" section — keep both in sync.
profile = "16"
# Per-sequence attribute columns, declared discrete below.
attributes = ["month of the year", "day off"]

# Long format: each 'date index' is one example, with 'timestamp' as its time column.
model.train_dataframe(
    loadProfiles_df,
    example_id_column = "date index",
    time_column = "timestamp",
    feature_columns = [profile],
    attribute_columns = attributes,
    discrete_columns = attributes,
    df_style = "long",
)

In [None]:
# Ensure the target directory exists first; the save call is not assumed to create it.
Path('models').mkdir(parents = True, exist_ok = True)
model.save('models/model.DGAN')

##### Import existing model

In [None]:
# NOTE(review): the model file name does not record which profile it was trained on;
# `profile` must name that same column, or the plots below will not find it — TODO confirm
# when loading an older model.
profile = "16"
attributes = ["month of the year", "day off"]

model = DGAN.load("models/model.DGAN")

##### Create and visualize synthetic data

In [None]:
# Number of synthetic sequences to sample; with the training setup above each one is a day.
nSyntheticDays = 500
syntheticProfiles_df = model.generate_dataframe(nSyntheticDays)

In [None]:
from math import ceil
from matplotlib.lines import Line2D

# Small-multiples grid: one panel per ('month of the year', 'day off') combination present
# in the synthetic data, overlaying synthetic (red) on real (grey) daily profiles.
ncols = 4
plotCount = syntheticProfiles_df.groupby(attributes).ngroups
# At least one row, so plt.subplots never receives nrows = 0.
nrows = max(1, ceil(plotCount/ncols))
# squeeze = False always yields a 2-D array of Axes, whatever the grid size.
fig, axes = plt.subplots(nrows = nrows, ncols = ncols, figsize = (5*ncols, 4*nrows), facecolor = 'w', squeeze = False)
axes_list = axes.reshape(-1)
axesCount = 0
# Sorted so panels follow calendar order rather than order of first appearance.
for month in syntheticProfiles_df['month of the year'].drop_duplicates().sort_values():
    for day_off in syntheticProfiles_df['day off'].drop_duplicates().sort_values():
        tempSynth_df = syntheticProfiles_df.query("`month of the year` == @month & `day off` == @day_off").copy()
        if tempSynth_df.empty:
            print(f'Missing: month: {month} | day off: {day_off}')
            continue
        ax = axes_list[axesCount]
        # Hour derived from the generated timestamps; the +1 presumably aligns it with the
        # real 'hour of the day' column — TODO confirm that column is 1-based.
        tempSynth_df['hour of the day'] = tempSynth_df['timestamp'].dt.hour + 1
        # One row per sequence, one column per hour; .T draws one line per sequence.
        tempSynth_df = tempSynth_df.pivot_table(values = profile, index = 'date index', columns = 'hour of the day')
        title = f'month: {month} | day off: {day_off} | count: {len(tempSynth_df)}'
        tempSynth_df.T.plot(color = 'red', title = title, alpha = 0.5, legend = False, ax = ax)
        tempReal_df = loadProfiles_df.query("`month of the year` == @month & `day off` == @day_off")
        # The real data may lack this combination; pandas cannot plot an empty frame.
        if not tempReal_df.empty:
            tempReal_df = tempReal_df.pivot_table(values = profile, index = 'date', columns = 'hour of the day')
            tempReal_df.T.plot(legend = False, color = 'grey', alpha = 0.5, ax = ax)
        # Per-panel labels are replaced by the shared figure labels below.
        ax.set(xlabel = None, ylabel = None)
        axesCount += 1
fig.tight_layout()
# Hide unused panels at the end of the grid.
for idx in range(axesCount, nrows*ncols):
    axes_list[idx].axis('off')
fig.text(0.5, -0.01, 'Hour of the day', ha = 'center', fontsize = 18)
fig.text(-0.01, 0.5, 'Consumed energy [Wh]', va = 'center', rotation = 'vertical', fontsize = 18)
legendElements = [
    Line2D([0], [0], marker = 'o', color = 'w', label = 'Synthetic', markerfacecolor = 'red', markersize = 15),
    Line2D([0], [0], marker = 'o', color = 'w', label = 'Real', markerfacecolor = 'grey', markersize = 15)
]
axes_list[0].legend(handles = legendElements, loc = 2);

In [None]:
# savefig does not create missing directories, so make sure `results/` exists.
Path('results').mkdir(parents = True, exist_ok = True)
fig.savefig(f'results/results_profile_{profile}.png', bbox_inches = 'tight')