In [1]:
import pandas as pd
import glob
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import tqdm
import matplotlib.pyplot as plt

In [2]:
# Merge the individual customer CSV files of all three year folders into one frame.
#
# Outputs:
#   ./model_input_data/2010-2013_combined.csv   - every column, sorted by Customer/datetime
#   ./model_input_data/2010-2013_processed.csv  - only Customer, datetime, GG, NL
dfs = []
for folder in ["2010-2011", "2011-2012", "2012-2013"]:
    # glob returns matches in arbitrary order; sort so the read order is reproducible
    for file in sorted(glob.glob(f"./processed_data/{folder}/*.csv")):
        dfs.append(pd.read_csv(file))

# pd.concat on an empty list fails with an opaque message; fail early and say why.
if not dfs:
    raise ValueError(
        "No CSV files found under ./processed_data/<year-folder>/ - nothing to merge"
    )

df_agg = pd.concat(dfs, ignore_index=True)
df_agg = df_agg.sort_values(by=["Customer", "datetime"])
df_agg.to_csv("./model_input_data/2010-2013_combined.csv", index=False)

# Keep relevant columns
df_agg = df_agg[["Customer", "datetime", "GG", "NL"]]

df_agg.to_csv("./model_input_data/2010-2013_processed.csv", index=False)

In [3]:
# Independent (deep) copy: changes made to df do not propagate back to df_agg
df = df_agg.copy(deep=True)

In [4]:
# Split data into 4 communities of 75 customers and build one feature table per community.
#
# For each community this cell writes:
#   ./model_input_data/{community}.csv      - rows of the member customers
#   ./model_input_data/{community}_agg.csv  - per-timestamp sums over the members
#   ./flmodel_data/{community}.csv          - summed GG/NL plus lag and calendar features
communities = {
    "community_1": range(1, 76),
    "community_2": range(76, 151),
    "community_3": range(151, 226),
    "community_4": range(226, 301),
}

HALF_HOURS_PER_DAY = 48  # rows per day, assuming one reading every half hour
# Lag feature column -> lag length in days (1 day, 7 days, 30 days).
NL_LAG_DAYS = {"NL_t-24": 1, "NL_t-24*7": 7, "NL_t-24*30": 30}

categorical_columns = ["Month", "Season"]
# Explicit category lists: every community gets the same one-hot columns
# (Month_1..Month_12, Season_1..Season_4) in the same order, even if some month
# happens to be absent from its data.
CATEGORY_LEVELS = [list(range(1, 13)), [1, 2, 3, 4]]

for community, cid in communities.items():
    print(f"Processing {community} with customer IDs: {cid}")

    # Select customers in the current community
    df_community = df[df["Customer"].isin(cid)]
    df_community.to_csv(f"./model_input_data/{community}.csv", index=False)

    # Sum the members per timestamp. "Customer" is summed too, so the *_agg.csv
    # file carries the sum of the member IDs; the column is dropped right after.
    df_community = df_community.groupby("datetime")[["Customer", "GG", "NL"]].sum().reset_index()
    df_community.to_csv(f"./model_input_data/{community}_agg.csv", index=False)
    df_community = df_community.drop(columns=["Customer"])

    df_community["datetime"] = pd.to_datetime(df_community["datetime"])  # Ensure datetime is in correct format

    # The lag features below are positional shifts, so make the chronological
    # order explicit instead of relying on how the text timestamps were sorted.
    df_community = df_community.sort_values("datetime", kind="stable").reset_index(drop=True)

    # Lagged NL values. A shift by N rows equals N half-hours only if the series
    # has no missing timestamps - NOTE(review): confirm the input has no gaps.
    for lag_column, lag_days in NL_LAG_DAYS.items():
        df_community[lag_column] = df_community["NL"].shift(HALF_HOURS_PER_DAY * lag_days)

    # Calendar features
    df_community["Month"] = df_community["datetime"].dt.month
    df_community["Season"] = (df_community["Month"] % 12) // 3 + 1  # 1: Dec-Feb, 2: Mar-May, 3: Jun-Aug, 4: Sep-Nov
    df_community["is_weekend"] = (df_community["datetime"].dt.weekday >= 5).astype(int)  # 1 if Saturday or Sunday

    # Restrict to readings whose hour is 5..20 inclusive (a 20:30 row is kept).
    # Done after the shifts so the lags are computed on the full series.
    df_community = df_community[df_community["datetime"].dt.hour.between(5, 20)]

    # Backfill missing values.
    # NOTE(review): the leading rows of each lag column have no history, so bfill
    # copies the first available (later) value backwards - e.g. the first 30 days
    # of NL_t-24*30 become one constant. Kept to stay compatible with the existing
    # output files; consider dropping those rows instead.
    df_community = df_community.bfill()

    # One-hot encode the categorical columns
    encoder = OneHotEncoder(categories=CATEGORY_LEVELS, sparse_output=False)
    encoded_array = encoder.fit_transform(df_community[categorical_columns])

    # The hour filter left gaps in the index; reset it so the positional
    # encoded rows line up with the frame when concatenated column-wise.
    df_community = df_community.reset_index(drop=True)
    df_encoded = pd.DataFrame(encoded_array, columns=encoder.get_feature_names_out(categorical_columns))

    df_community = pd.concat([df_community, df_encoded], axis=1)
    df_community = df_community.drop(columns=categorical_columns)

    df_community.to_csv(f"./flmodel_data/{community}.csv", index=False)

Processing community_1 with customer IDs: range(1, 76)
Processing community_2 with customer IDs: range(76, 151)
Processing community_3 with customer IDs: range(151, 226)
Processing community_4 with customer IDs: range(226, 301)


In [5]:
# Cross-check the month -> season-number formula against the expected season
# names (December-February is labelled summer here).
month_to_name = {
    12: "summer", 1: "summer", 2: "summer",
    3: "autumn", 4: "autumn", 5: "autumn",
    6: "winter", 7: "winter", 8: "winter",
}

for i in range(1, 13):
    season_number = (i % 12) // 3 + 1
    # Months not listed above (9, 10, 11) fall through to spring
    season_name = month_to_name.get(i, "spring")
    print(f"{i} {season_name} {season_number}")

1 summer 1
2 summer 1
3 autumn 2
4 autumn 2
5 autumn 2
6 winter 3
7 winter 3
8 winter 3
9 spring 4
10 spring 4
11 spring 4
12 summer 1


In [6]:
# Reload the community_4 feature file and show its first July rows
df_verify = pd.read_csv("./flmodel_data/community_4.csv").assign(
    datetime=lambda frame: pd.to_datetime(frame["datetime"])  # text -> datetime
)
is_july = df_verify["datetime"].dt.month.eq(7)
df_test = df_verify.loc[is_july]
df_test.head()

Unnamed: 0,datetime,GG,NL,NL_t-24,NL_t-24*7,NL_t-24*30,is_weekend,Month_1,Month_2,Month_3,...,Month_7,Month_8,Month_9,Month_10,Month_11,Month_12,Season_1,Season_2,Season_3,Season_4
0,2010-07-01 05:00:00,0.006,28.416,28.416,28.416,28.416,0,0.0,0.0,0.0,...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
1,2010-07-01 05:30:00,0.0,32.65,28.416,28.416,28.416,0,0.0,0.0,0.0,...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
2,2010-07-01 06:00:00,0.006,40.277,28.416,28.416,28.416,0,0.0,0.0,0.0,...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
3,2010-07-01 06:30:00,0.013,44.912,28.416,28.416,28.416,0,0.0,0.0,0.0,...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
4,2010-07-01 07:00:00,0.006,48.57,28.416,28.416,28.416,0,0.0,0.0,0.0,...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0


In [7]:
# Show the first October rows of the reloaded feature table
in_october = df_verify["datetime"].dt.month.eq(10)
df_test = df_verify.loc[in_october]
df_test.head()

Unnamed: 0,datetime,GG,NL,NL_t-24,NL_t-24*7,NL_t-24*30,is_weekend,Month_1,Month_2,Month_3,...,Month_7,Month_8,Month_9,Month_10,Month_11,Month_12,Season_1,Season_2,Season_3,Season_4
2944,2010-10-01 05:00:00,0.006,18.593,18.156,16.521,18.809,0,0.0,0.0,0.0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0
2945,2010-10-01 05:30:00,0.001,19.985,19.403,21.422,21.165,0,0.0,0.0,0.0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0
2946,2010-10-01 06:00:00,0.245,22.401,24.892,23.789,23.657,0,0.0,0.0,0.0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0
2947,2010-10-01 06:30:00,1.205,27.744,24.874,29.103,35.841,0,0.0,0.0,0.0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0
2948,2010-10-01 07:00:00,2.235,22.844,19.381,32.034,39.889,0,0.0,0.0,0.0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0
