In [None]:
import numpy as np
import pandas as pd
import pathlib
import random
import time
import math
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import init
from dateutil import parser
from pathlib import Path
import json 
import shutil
import logging
# Process-specific logger: its name is the current PID rendered as text.
# It is not referenced by any of the cells visible in this notebook.
logger = logging.getLogger(f"{os.getpid()}")

from wattile.data_reading import read_dataset_from_file
from wattile.buildings_processing import correct_predictor_columns, correct_timestamps, resample_or_rolling_stats, timelag_predictors, timelag_predictors_target, roll_predictors_target, input_data_split
from wattile.time_processing import add_processed_time_columns
from wattile.models import ModelFactory
from wattile.entry_point import init_logging, create_input_dataframe, run_model
# Repository root, taken as two directory levels above the resolved working directory.
# NOTE(review): assumes the kernel is started from a folder two levels below the repo
# root (e.g. <root>/notebooks/<subdir>) — confirm before moving this notebook.
PROJECT_DIRECTORY = Path.cwd().resolve().parent.parent

# reading configs

In [None]:
"""
For this example, we will be using the default configs.
Check out the docs for an explaination of each config.
"""
##################################################################################
# choose the configs file to use as an input
##################################################################################
# main configs file
with open(PROJECT_DIRECTORY / "wattile" / "configs" / "configs.json", "r") as f:
    configs = json.load(f)
##################################################################################
# code testing configs file
# with open(PROJECT_DIRECTORY / "tests" / "fixtures" / "test_configs.json", "r") as f:
#     configs = json.load(f)
##################################################################################

exp_dir = PROJECT_DIRECTORY / "notebooks" / "exp_dir"
if exp_dir.exists():
    shutil.rmtree(exp_dir)
exp_dir.mkdir()

configs["data_input"]["exp_dir"] = str(PROJECT_DIRECTORY / exp_dir)
configs["data_input"]["data_dir"] = str(PROJECT_DIRECTORY / "tests" / "data" / "Synthetic Site")

configs

In [None]:
# Override the settings that resample_or_rolling_stats (called below) reads from configs:
# turn on feature statistics with a 5-minute window and use 1-minute bins whose
# closed side and label are both "right".
configs["data_processing"]["feat_stats"].update(active=True, window_width="5min")
configs["data_processing"]["resample"].update(
    bin_interval="1min",
    bin_closed="right",
    bin_label="right",
)

### read data

In [None]:
# Load the rolling-stats fixture: the first CSV column becomes the index,
# which is then converted to datetimes.
filepath = str(PROJECT_DIRECTORY / "tests" / "fixtures" / "rolling_stats_input_w_target.csv")

data = pd.read_csv(filepath, index_col=0)
data.index = pd.to_datetime(data.index)

# Independent copy of the unprocessed input with "_raw" appended to every column name;
# it is plotted against the processed frame further down.
data_raw = data.copy().add_suffix("_raw")

data_raw

### process data

In [None]:
# Rebuild the unprocessed input from `data_raw` (every column there got a "_raw" suffix
# when it was loaded) rather than feeding `data` back in. Otherwise re-running this cell
# would apply the statistics a second time to an already-processed frame.
data_input = data_raw.rename(columns=lambda col: col[: -len("_raw")])

# The remaining preprocessing steps are left disabled in this example
# (presumably to isolate the statistics step — confirm). If re-enabled they
# operate on `data_input`, so their output still feeds the statistics below.

# # assert we have the correct columns and order them
# data_input = correct_predictor_columns(configs, data_input)

# # sort and trim data specified time period
# data_input = correct_timestamps(configs, data_input)

# # Add time-based features
# data_input = add_processed_time_columns(data_input, configs)

# Add statistics features
data = resample_or_rolling_stats(data_input, configs)

data

### plot data

In [None]:
import plotly.graph_objects as go
from plotly.colors import n_colors
from plotly.validators.scatter.marker import SymbolValidator

In [None]:
##########################################################################################
def _most_specific_label(column, labels):
    """Return the longest entry of ``labels`` that is a substring of ``column``.

    Derived (processed) columns are matched back to raw columns by substring. Picking
    the longest match gives the same answer for color and symbol even when one raw
    label is contained in another (e.g. "temp" and "temp_2").

    Returns None when no label matches.
    """
    matches = [label for label in labels if label in column]
    return max(matches, key=len, default=None)


##########################################################################################
# One color per raw column, sampled from a yellow-to-grey gradient and keyed by the
# column name without its "_raw" suffix so processed columns can reuse it.
# n_colors computes its step as (high - low) / (n - 1), so always request at least two
# colors and trim afterwards (avoids a ZeroDivisionError with a single column).
n_raw_cols = data_raw.shape[1]
list_color = n_colors(
    "rgb(237,198,17)", "rgb(105,122,130)", max(n_raw_cols, 2), colortype="rgb"
)[:n_raw_cols]

dict_color = {
    col.split("_raw")[0]: color for col, color in zip(data_raw.columns, list_color)
}

##########################################################################################
# SymbolValidator().values stores symbol names at every third position starting at
# index 2; keep only the names.
raw_symbols = SymbolValidator().values
symbols = raw_symbols[2::3]
# Use every 4th name starting at offset 2, cycling if there are more columns than symbols.
marker_symbols = symbols[2::4]

##########################################################################################
fig = go.Figure()

##########################################################################################
# Raw data: markers only, one color/symbol per column.
dict_symbol = {}
for i_col, col in enumerate(data_raw.columns):
    symbol = marker_symbols[i_col % len(marker_symbols)]
    dict_symbol[col.split("_raw")[0]] = symbol

    fig.add_trace(go.Scatter(
        x=data_raw.index,
        y=data_raw[col],
        name=col,
        mode="markers",
        marker=dict(
            symbol=symbol,
            size=12,
            line=dict(
                width=2,
            ),
            color=list_color[i_col],
        ),
    ))

##########################################################################################
# Processed data: lines + smaller markers, styled like the raw column they derive from.
for col in data.columns:
    label = _most_specific_label(col, dict_color)

    fig.add_trace(go.Scatter(
        x=data.index,
        y=data[col],
        name=col,
        mode="markers+lines",
        line=dict(
            width=1,
            # Unmatched columns fall back to plotly's default color instead of
            # silently inheriting the previous trace's color.
            color=dict_color.get(label),
        ),
        marker=dict(
            symbol=dict_symbol.get(label, "circle"),
            size=8,
        ),
    ))

##########################################################################################
fig.update_layout(
    width=800,
    height=500,
    margin=dict(
        l=0,
        r=0,
        t=0,
        b=150,
    ),
    legend=dict(
        orientation="h",
        yanchor="top",
        y=-0.15,
        xanchor="center",
        x=0.5,
        font=dict(
            size=10,
            color="black",
        ),
    )
)

##########################################################################################
fig.update_xaxes(
    dtick=60 * 1000,  # one tick/grid line per minute; date axes take dtick in milliseconds
    showgrid=True,
    gridwidth=2,
)

##########################################################################################
fig.show()