In [None]:
import pandas as pd

ny_path = "/home/fuest/EnData/data/pecanstreet/15minute_data_newyork.csv"
austin_path = "/home/fuest/EnData/data/pecanstreet/15minute_data_austin.csv"
cali_path = "/home/fuest/EnData/data/pecanstreet/15minute_data_california.csv"

ny_data = pd.read_csv(ny_path)
austin_data = pd.read_csv(austin_path)
cali_data = pd.read_csv(cali_path)

ny_user_ids = ny_data.dataid.unique()
austin_user_ids = austin_data.dataid.unique()
cali_user_ids = cali_data.dataid.unique()
austin_user_ids

In [None]:
a = train_dataset.dataset.data
a[(a.month==7) & (a.weekday==0)]

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import itertools
from data_utils.dataset import PecanStreetDataset

def plot_grid_profile(df, month, weekday):
    # Filter the DataFrame based on the specified month and weekday
    filtered_df = df[(df['month'] == month) & (df['weekday'] == weekday)]
    
    # Check if there are any rows after filtering
    if filtered_df.empty:
        print(f"No data available for month {month} and weekday {weekday}.")
        return

    # Convert the 'grid' column to a list of arrays
    grid_values = filtered_df['grid'].apply(np.array).values

    # Calculate the average grid values for each 15-minute interval across all filtered rows
    averaged_grid = np.mean(np.vstack(grid_values), axis=0)

    # Generate 96 timestamps for a day at 15-minute intervals
    timestamps = pd.date_range(start='00:00', end='23:45', freq='15T').strftime('%H:%M')

    # Ensure that the lengths match
    if len(averaged_grid) != len(timestamps):
        raise ValueError(f"Length of averaged_grid ({len(averaged_grid)}) does not match length of timestamps ({len(timestamps)}).")

    # Plot the averaged grid profile
    plt.figure(figsize=(12, 6))
    plt.plot(timestamps, averaged_grid, marker='o')
    plt.title(f'Grid Profile for Month {month} and Weekday {weekday}')
    plt.xlabel('Time of Day')
    plt.ylabel('Grid Values')
    plt.xticks(rotation=45)
    plt.grid(True)
    plt.tight_layout()
    plt.show()

df = PecanStreetDataset(normalize=False).data
plot_grid_profile(df, month=5, weekday=3)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from data_utils.dataset import PecanStreetDataset

def plot_grid_profile(df, month, weekday, dataid=None, ax=None):
    filtered_df = df[(df['month'] == month) & (df['weekday'] == weekday)]
    if dataid is not None:
        filtered_df = filtered_df[filtered_df['dataid'] == dataid]
    
    if filtered_df.empty:
        print(f"No data available for month {month} and weekday {weekday} with dataid {dataid}.")
        return

    grid_values = filtered_df['grid'].apply(np.array).values
    averaged_grid = np.mean(np.vstack(grid_values), axis=0)
    timestamps = pd.date_range(start='00:00', end='23:45', freq='15T').strftime('%H:%M')

    if len(averaged_grid) != len(timestamps):
        raise ValueError(f"Length of averaged_grid ({len(averaged_grid)}) does not match length of timestamps ({len(timestamps)}).")
    
    if ax is None:
        plt.figure(figsize=(12, 6))
        plt.plot(timestamps, averaged_grid, marker='o')
        title = f'Grid Profile for Month {month} and Weekday {weekday}'
        if dataid is not None:
            title += f' (DataID: {dataid})'
        plt.title(title)
        plt.xlabel('Time of Day')
        plt.ylabel('Grid Values')
        plt.xticks(rotation=45)
        plt.grid(True)
        plt.tight_layout()
        plt.show()
    else:
        ax.plot(timestamps, averaged_grid, marker='o', label=f'DataID: {dataid}')

def plot_all_users_grid_profile(df, month, weekday):
    unique_dataids = df['dataid'].unique()
    fig, ax = plt.subplots(figsize=(12, 6))
    for dataid in unique_dataids:
        plot_grid_profile(df, month, weekday, dataid, ax=ax)
    
    title = f'Grid Profile for Month {month} and Weekday {weekday} for All Users'
    plt.title(title)
    plt.xlabel('Time of Day')
    plt.ylabel('Grid Values')
    plt.xticks(rotation=45)
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()

df = PecanStreetDataset(normalize=True).data
plot_all_users_grid_profile(df, month=11, weekday=3)

In [None]:
def plot_daily_usage_profiles(df, dataid, month, weekday):
    # Filter the DataFrame based on the specified user id, month, and weekday
    filtered_df = df[(df['dataid'] == dataid) & (df['month'] == month) & (df['weekday'] == weekday)]
    
    # Check if there are any rows after filtering
    if filtered_df.empty:
        print(f"No data available for user {dataid}, month {month}, and weekday {weekday}.")
        return

    grid_values = filtered_df['grid'].apply(np.array).values
    timestamps = pd.date_range(start='00:00', end='23:45', freq='15T').strftime('%H:%M')

    plt.figure(figsize=(12, 6))
    for i, daily_grid in enumerate(grid_values):
        plt.plot(timestamps, daily_grid, marker='o', label=f'Day {i+1}')

    title = f'Daily Usage Profiles for User {dataid}, Month {month}, Weekday {weekday}'
    plt.title(title)
    plt.xlabel('Time of Day')
    plt.ylabel('Grid Values')
    plt.xticks(rotation=45)
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()

df = PecanStreetDataset(normalize=False).data
plot_daily_usage_profiles(df, dataid=3687, month=5, weekday=0)

In [None]:
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt

from data_utils.dataset import PecanStreetDataset, prepare_dataloader, split_dataset
from generator.acgan import ACGAN  

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data = PecanStreetDataset(normalize=True, user_id=27)
train_dataset, val_dataset = split_dataset(data)
model = ACGAN(
    input_dim=1,
    noise_dim=512,
    embedding_dim=512,
    output_dim=96,
    learning_rate=1e-4,
    weight_path="runs/",
)
model.train(train_dataset, val_dataset, batch_size=32, num_epoch=100)

def generate_and_plot_series(model, day_labels, month_labels, data, month, weekday):
    series1 = model.generate([day_labels, month_labels]).squeeze()
    series2 = model.generate([day_labels, month_labels]).squeeze()
    series3 = model.generate([day_labels, month_labels]).squeeze()

    timestamps = pd.date_range(start='00:00', periods=96, freq='15T').strftime('%H:%M')

    filtered_data = data[(data['month'] == month) & (data['weekday'] == weekday)]
    real_profiles = filtered_data.sample(3)['grid'].values
    real_profile1 = np.array(real_profiles[0])
    real_profile2 = np.array(real_profiles[1])
    real_profile3 = np.array(real_profiles[2])

    # Plot all series on the same plot
    plt.figure(figsize=(15, 6))
    plt.plot(timestamps, series1, label='Generated Profile 1')
    plt.plot(timestamps, series2, label='Generated Profile 2')
    plt.plot(timestamps, series3, label='Generated Profile 3')
    plt.plot(timestamps, real_profile1, label='Real Profile 1', linestyle='--')
    plt.plot(timestamps, real_profile2, label='Real Profile 2', linestyle='--')
    plt.plot(timestamps, real_profile3, label='Real Profile 3', linestyle='--')
    plt.xlabel('Time')
    plt.ylabel('Value')
    plt.title('Generated and Real Time Series Profiles')
    plt.xticks(rotation=45)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

day_labels = torch.tensor([6]).to(device)
month_labels = torch.tensor([5]).to(device)

generate_and_plot_series(model, day_labels, month_labels, data.data, 5, 6)

In [10]:
import pandas as pd
import torch
import numpy as np
import matplotlib.pyplot as plt

from generator.acgan import ACGAN  
from data_utils.dataset import PecanStreetDataset, split_dataset

def plot_range_with_values(df, colname, values_to_compare, month, weekday):
    values_to_compare = values_to_compare.cpu().numpy()
    filtered_df = df[(df['month'] == month) & (df['weekday'] == weekday)]
    array_data = np.array(filtered_df[colname].to_list())
    min_values = np.min(array_data, axis=0)
    max_values = np.max(array_data, axis=0)
    timestamps = pd.date_range(start='00:00', end='23:45', freq='15T').strftime('%H:%M')
    
    plt.figure(figsize=(15, 7))
    plt.fill_between(timestamps, min_values, max_values, color='gray', alpha=0.5, label='Range of values')
    plt.plot(timestamps, values_to_compare, color='blue', marker='o', label='Values to Compare')
    
    plt.title('Range of Values and Comparison')
    plt.xlabel('Time of Day')
    plt.ylabel('Values')
    plt.xticks(rotation=45)
    plt.legend()
    plt.tight_layout()
    plt.show()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data = PecanStreetDataset(normalize=True, user_id=1642, include_generation=False)
dataset = data.data

train_dataset, val_dataset = split_dataset(data)
model = ACGAN(
    input_dim=1,
    noise_dim=512,
    embedding_dim=512,
    window_length=96,
    learning_rate=1e-4,
    weight_path="runs/",
)
model.train(train_dataset, val_dataset, batch_size=32, num_epoch=100)

Epoch 1: 100%|██████████| 9/9 [00:00<00:00, 15.39it/s]


Epoch [1/100], Mean MMD Loss: [0.01061761]


Epoch 2: 100%|██████████| 9/9 [00:00<00:00, 24.66it/s]


Epoch [2/100], Mean MMD Loss: [0.00929548]


Epoch 3: 100%|██████████| 9/9 [00:00<00:00, 22.59it/s]


Epoch [3/100], Mean MMD Loss: [0.00981539]


Epoch 4: 100%|██████████| 9/9 [00:00<00:00, 23.34it/s]


Epoch [4/100], Mean MMD Loss: [0.00733485]


Epoch 5: 100%|██████████| 9/9 [00:00<00:00, 30.98it/s]


Epoch [5/100], Mean MMD Loss: [0.00623002]


Epoch 6: 100%|██████████| 9/9 [00:00<00:00, 29.67it/s]


Epoch [6/100], Mean MMD Loss: [0.00410437]


Epoch 7: 100%|██████████| 9/9 [00:00<00:00, 26.08it/s]


Epoch [7/100], Mean MMD Loss: [0.00435637]


Epoch 8: 100%|██████████| 9/9 [00:00<00:00, 24.05it/s]


Epoch [8/100], Mean MMD Loss: [0.00383525]


Epoch 9: 100%|██████████| 9/9 [00:00<00:00, 25.41it/s]


Epoch [9/100], Mean MMD Loss: [0.00331037]


Epoch 10: 100%|██████████| 9/9 [00:00<00:00, 24.43it/s]


Epoch [10/100], Mean MMD Loss: [0.00375408]


Epoch 11: 100%|██████████| 9/9 [00:00<00:00, 24.13it/s]


Epoch [11/100], Mean MMD Loss: [0.00424258]


Epoch 12: 100%|██████████| 9/9 [00:00<00:00, 25.11it/s]


Epoch [12/100], Mean MMD Loss: [0.00517236]


Epoch 13: 100%|██████████| 9/9 [00:00<00:00, 22.28it/s]


Epoch [13/100], Mean MMD Loss: [0.00586487]


Epoch 14: 100%|██████████| 9/9 [00:00<00:00, 26.74it/s]


Epoch [14/100], Mean MMD Loss: [0.00451508]


Epoch 15: 100%|██████████| 9/9 [00:00<00:00, 22.66it/s]


Epoch [15/100], Mean MMD Loss: [0.00511303]


Epoch 16: 100%|██████████| 9/9 [00:00<00:00, 25.56it/s]


Epoch [16/100], Mean MMD Loss: [0.00498966]


Epoch 17: 100%|██████████| 9/9 [00:00<00:00, 22.45it/s]


Epoch [17/100], Mean MMD Loss: [0.00547681]


Epoch 18: 100%|██████████| 9/9 [00:00<00:00, 23.10it/s]


Epoch [18/100], Mean MMD Loss: [0.00538026]


Epoch 19: 100%|██████████| 9/9 [00:00<00:00, 24.01it/s]


Epoch [19/100], Mean MMD Loss: [0.0042599]


Epoch 20: 100%|██████████| 9/9 [00:00<00:00, 25.10it/s]


Epoch [20/100], Mean MMD Loss: [0.00451676]


Epoch 21: 100%|██████████| 9/9 [00:00<00:00, 26.12it/s]


Epoch [21/100], Mean MMD Loss: [0.00568049]


Epoch 22: 100%|██████████| 9/9 [00:00<00:00, 26.09it/s]


Epoch [22/100], Mean MMD Loss: [0.00544778]


Epoch 23: 100%|██████████| 9/9 [00:00<00:00, 22.92it/s]


Epoch [23/100], Mean MMD Loss: [0.00559735]


Epoch 24: 100%|██████████| 9/9 [00:00<00:00, 21.64it/s]


Epoch [24/100], Mean MMD Loss: [0.00610596]


Epoch 25: 100%|██████████| 9/9 [00:00<00:00, 26.99it/s]


Epoch [25/100], Mean MMD Loss: [0.00776769]


Epoch 26: 100%|██████████| 9/9 [00:00<00:00, 21.22it/s]


Epoch [26/100], Mean MMD Loss: [0.0065436]


Epoch 27: 100%|██████████| 9/9 [00:00<00:00, 23.61it/s]


Epoch [27/100], Mean MMD Loss: [0.0073522]


Epoch 28: 100%|██████████| 9/9 [00:00<00:00, 24.20it/s]


Epoch [28/100], Mean MMD Loss: [0.00704362]


Epoch 29: 100%|██████████| 9/9 [00:00<00:00, 27.30it/s]


Epoch [29/100], Mean MMD Loss: [0.00694742]


Epoch 30: 100%|██████████| 9/9 [00:00<00:00, 23.22it/s]


Epoch [30/100], Mean MMD Loss: [0.00905716]


Epoch 31: 100%|██████████| 9/9 [00:00<00:00, 27.93it/s]


Epoch [31/100], Mean MMD Loss: [0.00777249]


Epoch 32: 100%|██████████| 9/9 [00:00<00:00, 26.12it/s]


Epoch [32/100], Mean MMD Loss: [0.00762611]


Epoch 33: 100%|██████████| 9/9 [00:00<00:00, 24.40it/s]


Epoch [33/100], Mean MMD Loss: [0.00865978]


Epoch 34: 100%|██████████| 9/9 [00:00<00:00, 26.04it/s]


Epoch [34/100], Mean MMD Loss: [0.00839454]


Epoch 35: 100%|██████████| 9/9 [00:00<00:00, 25.28it/s]


Epoch [35/100], Mean MMD Loss: [0.01007142]


Epoch 36: 100%|██████████| 9/9 [00:00<00:00, 27.04it/s]


Epoch [36/100], Mean MMD Loss: [0.00600806]


Epoch 37: 100%|██████████| 9/9 [00:00<00:00, 27.93it/s]


Epoch [37/100], Mean MMD Loss: [0.00618953]


Epoch 38: 100%|██████████| 9/9 [00:00<00:00, 23.10it/s]


Epoch [38/100], Mean MMD Loss: [0.00629919]


Epoch 39: 100%|██████████| 9/9 [00:00<00:00, 23.17it/s]


Epoch [39/100], Mean MMD Loss: [0.00536442]


Epoch 40: 100%|██████████| 9/9 [00:00<00:00, 27.84it/s]


Epoch [40/100], Mean MMD Loss: [0.00603999]


Epoch 41: 100%|██████████| 9/9 [00:00<00:00, 22.99it/s]


Epoch [41/100], Mean MMD Loss: [0.00658552]


Epoch 42: 100%|██████████| 9/9 [00:00<00:00, 26.94it/s]


Epoch [42/100], Mean MMD Loss: [0.00672677]


Epoch 43: 100%|██████████| 9/9 [00:00<00:00, 25.61it/s]


Epoch [43/100], Mean MMD Loss: [0.00566329]


Epoch 44: 100%|██████████| 9/9 [00:00<00:00, 24.10it/s]


Epoch [44/100], Mean MMD Loss: [0.00540354]


Epoch 45: 100%|██████████| 9/9 [00:00<00:00, 21.97it/s]


Epoch [45/100], Mean MMD Loss: [0.0044028]


Epoch 46: 100%|██████████| 9/9 [00:00<00:00, 26.63it/s]


Epoch [46/100], Mean MMD Loss: [0.00534307]


Epoch 47: 100%|██████████| 9/9 [00:00<00:00, 24.13it/s]


Epoch [47/100], Mean MMD Loss: [0.0051079]


Epoch 48: 100%|██████████| 9/9 [00:00<00:00, 26.17it/s]


Epoch [48/100], Mean MMD Loss: [0.00489045]


Epoch 49: 100%|██████████| 9/9 [00:00<00:00, 26.68it/s]


Epoch [49/100], Mean MMD Loss: [0.00531018]


Epoch 50: 100%|██████████| 9/9 [00:00<00:00, 26.46it/s]


Epoch [50/100], Mean MMD Loss: [0.00489121]


Epoch 51: 100%|██████████| 9/9 [00:00<00:00, 27.56it/s]


Epoch [51/100], Mean MMD Loss: [0.00532022]


Epoch 52: 100%|██████████| 9/9 [00:00<00:00, 23.47it/s]


Epoch [52/100], Mean MMD Loss: [0.00687033]


Epoch 53: 100%|██████████| 9/9 [00:00<00:00, 28.12it/s]


Epoch [53/100], Mean MMD Loss: [0.00526939]


Epoch 54: 100%|██████████| 9/9 [00:00<00:00, 27.86it/s]


Epoch [54/100], Mean MMD Loss: [0.00495866]


Epoch 55: 100%|██████████| 9/9 [00:00<00:00, 28.63it/s]


Epoch [55/100], Mean MMD Loss: [0.00883999]


Epoch 56: 100%|██████████| 9/9 [00:00<00:00, 22.60it/s]


Epoch [56/100], Mean MMD Loss: [0.00707196]


Epoch 57: 100%|██████████| 9/9 [00:00<00:00, 25.22it/s]


Epoch [57/100], Mean MMD Loss: [0.00714725]


Epoch 58: 100%|██████████| 9/9 [00:00<00:00, 26.67it/s]


Epoch [58/100], Mean MMD Loss: [0.00592751]


Epoch 59: 100%|██████████| 9/9 [00:00<00:00, 28.12it/s]


Epoch [59/100], Mean MMD Loss: [0.00719201]


Epoch 60: 100%|██████████| 9/9 [00:00<00:00, 28.82it/s]


Epoch [60/100], Mean MMD Loss: [0.00741924]


Epoch 61: 100%|██████████| 9/9 [00:00<00:00, 29.16it/s]


Epoch [61/100], Mean MMD Loss: [0.00620422]


Epoch 62: 100%|██████████| 9/9 [00:00<00:00, 27.07it/s]


Epoch [62/100], Mean MMD Loss: [0.0067472]


Epoch 63: 100%|██████████| 9/9 [00:00<00:00, 28.16it/s]


Epoch [63/100], Mean MMD Loss: [0.00663745]


Epoch 64: 100%|██████████| 9/9 [00:00<00:00, 29.33it/s]


Epoch [64/100], Mean MMD Loss: [0.00874144]


Epoch 65: 100%|██████████| 9/9 [00:00<00:00, 22.63it/s]


Epoch [65/100], Mean MMD Loss: [0.00645038]


Epoch 66: 100%|██████████| 9/9 [00:00<00:00, 24.16it/s]


Epoch [66/100], Mean MMD Loss: [0.00591833]


Epoch 67: 100%|██████████| 9/9 [00:00<00:00, 27.31it/s]


Epoch [67/100], Mean MMD Loss: [0.00626528]


Epoch 68: 100%|██████████| 9/9 [00:00<00:00, 25.77it/s]


Epoch [68/100], Mean MMD Loss: [0.00641864]


Epoch 69: 100%|██████████| 9/9 [00:00<00:00, 24.92it/s]


Epoch [69/100], Mean MMD Loss: [0.00662393]


Epoch 70: 100%|██████████| 9/9 [00:00<00:00, 22.40it/s]


Epoch [70/100], Mean MMD Loss: [0.00566709]


Epoch 71: 100%|██████████| 9/9 [00:00<00:00, 24.74it/s]


Epoch [71/100], Mean MMD Loss: [0.00732072]


Epoch 72: 100%|██████████| 9/9 [00:00<00:00, 21.64it/s]


Epoch [72/100], Mean MMD Loss: [0.0068618]


Epoch 73: 100%|██████████| 9/9 [00:00<00:00, 27.05it/s]


Epoch [73/100], Mean MMD Loss: [0.00660202]


Epoch 74: 100%|██████████| 9/9 [00:00<00:00, 26.44it/s]


Epoch [74/100], Mean MMD Loss: [0.00675334]


Epoch 75: 100%|██████████| 9/9 [00:00<00:00, 23.26it/s]


Epoch [75/100], Mean MMD Loss: [0.00626514]


Epoch 76: 100%|██████████| 9/9 [00:00<00:00, 24.93it/s]


Epoch [76/100], Mean MMD Loss: [0.00989704]


Epoch 77: 100%|██████████| 9/9 [00:00<00:00, 24.60it/s]


Epoch [77/100], Mean MMD Loss: [0.00581392]


Epoch 78: 100%|██████████| 9/9 [00:00<00:00, 26.59it/s]


Epoch [78/100], Mean MMD Loss: [0.0071997]


Epoch 79: 100%|██████████| 9/9 [00:00<00:00, 27.01it/s]


Epoch [79/100], Mean MMD Loss: [0.00747968]


Epoch 80: 100%|██████████| 9/9 [00:00<00:00, 28.82it/s]


Epoch [80/100], Mean MMD Loss: [0.00572364]


Epoch 81: 100%|██████████| 9/9 [00:00<00:00, 26.07it/s]


Epoch [81/100], Mean MMD Loss: [0.00720813]


Epoch 82: 100%|██████████| 9/9 [00:00<00:00, 26.81it/s]


Epoch [82/100], Mean MMD Loss: [0.0066032]


Epoch 83: 100%|██████████| 9/9 [00:00<00:00, 24.73it/s]


Epoch [83/100], Mean MMD Loss: [0.00619401]


Epoch 84: 100%|██████████| 9/9 [00:00<00:00, 23.96it/s]


Epoch [84/100], Mean MMD Loss: [0.00573358]


Epoch 85: 100%|██████████| 9/9 [00:00<00:00, 25.83it/s]


Epoch [85/100], Mean MMD Loss: [0.0070096]


Epoch 86: 100%|██████████| 9/9 [00:00<00:00, 26.77it/s]


Epoch [86/100], Mean MMD Loss: [0.0057658]


Epoch 87: 100%|██████████| 9/9 [00:00<00:00, 27.20it/s]


Epoch [87/100], Mean MMD Loss: [0.0066122]


Epoch 88: 100%|██████████| 9/9 [00:00<00:00, 26.19it/s]


Epoch [88/100], Mean MMD Loss: [0.00782433]


Epoch 89: 100%|██████████| 9/9 [00:00<00:00, 26.15it/s]


Epoch [89/100], Mean MMD Loss: [0.00606645]


Epoch 90: 100%|██████████| 9/9 [00:00<00:00, 26.96it/s]


Epoch [90/100], Mean MMD Loss: [0.00756442]


Epoch 91: 100%|██████████| 9/9 [00:00<00:00, 26.45it/s]


Epoch [91/100], Mean MMD Loss: [0.00694006]


Epoch 92: 100%|██████████| 9/9 [00:00<00:00, 29.44it/s]


Epoch [92/100], Mean MMD Loss: [0.00651608]


Epoch 93: 100%|██████████| 9/9 [00:00<00:00, 23.28it/s]


Epoch [93/100], Mean MMD Loss: [0.00575438]


Epoch 94: 100%|██████████| 9/9 [00:00<00:00, 24.88it/s]


Epoch [94/100], Mean MMD Loss: [0.00574916]


Epoch 95: 100%|██████████| 9/9 [00:00<00:00, 26.17it/s]


Epoch [95/100], Mean MMD Loss: [0.00678719]


Epoch 96: 100%|██████████| 9/9 [00:00<00:00, 25.43it/s]


Epoch [96/100], Mean MMD Loss: [0.00607477]


Epoch 97: 100%|██████████| 9/9 [00:00<00:00, 26.84it/s]


Epoch [97/100], Mean MMD Loss: [0.00669148]


Epoch 98: 100%|██████████| 9/9 [00:00<00:00, 25.27it/s]


Epoch [98/100], Mean MMD Loss: [0.00744984]


Epoch 99: 100%|██████████| 9/9 [00:00<00:00, 25.43it/s]


Epoch [99/100], Mean MMD Loss: [0.00731015]


Epoch 100: 100%|██████████| 9/9 [00:00<00:00, 24.16it/s]

Epoch [100/100], Mean MMD Loss: [0.00510622]





In [56]:
import numpy as np

july_data = dataset[dataset.month == 7].copy()
real_ts = np.concatenate(july_data["grid"].to_list())

In [61]:
import torch

weekdays = [2, 3, 4, 5, 6, 0, 1]
weeks = np.arange(0, 5)
month_label = torch.tensor([6]).to(device)

syn_ts = []

for week in weeks:
    for weekday in weekdays:
        day_label = torch.tensor([weekday]).to(device)
        gen_ts = model.generate([day_label, month_label]).squeeze().cpu().numpy()
        syn_ts.append(gen_ts)

day_29 = model.generate([torch.tensor([2]).to(device), torch.tensor([6]).to(device)]).squeeze().cpu().numpy()
day_30 = model.generate([torch.tensor([3]).to(device), torch.tensor([6]).to(device)]).squeeze().cpu().numpy()
day_31 = model.generate([torch.tensor([4]).to(device), torch.tensor([6]).to(device)]).squeeze().cpu().numpy()
syn_ts += day_29
syn_ts += day_30
syn_ts += day_31
syn_ts = np.concatenate(syn_ts)
syn_ts

array([1.5093107, 1.9871105, 1.9769833, ..., 1.9281092, 2.0487525,
       2.7130651], dtype=float32)

In [62]:
len(syn_ts)

3360

In [63]:
len(real_ts)

2976