In [32]:
import os
from dataclasses import dataclass
import pandas as pd
import numpy as np
from sqlalchemy import create_engine
from config import (
    countries,
    dap_bidding_zones,
    interconnections,
    interconnections_edge_matrix,
)
from tqdm import tqdm
from dotenv import load_dotenv

import math
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv
from torch_geometric.data import Data
from torch.nn import BatchNorm1d

In [33]:
load_dotenv()
engine = create_engine(os.getenv("SQLALCHEMY_DATABASE_URI"))

In [34]:
flow_df = pd.read_sql_table("flow_32", engine)
flow_df = flow_df.set_index("DateTime")
flow_df.fillna(0, inplace=True)

In [35]:
dap_df = pd.DataFrame()
for country_id in countries.keys():
    dap_df[country_id] = pd.read_sql_table(f"{country_id}_dap", engine).set_index(
        "DateTime"
    )
dap_df.index = pd.to_datetime(dap_df.index)
dap_df.ffill(inplace=True)
dap_df.fillna(0, inplace=True)

In [36]:
load_df = pd.DataFrame()
for country_id in countries.keys():
    load_df[country_id] = pd.read_sql_table(f"{country_id}_load", engine).set_index(
        "DateTime"
    )
load_df.ffill(inplace=True)
# Fille NaN with mean of the column
load_df.fillna(load_df.mean(), inplace=True)

In [37]:
biomass_df = pd.DataFrame()
fossil_brown_coal_df = pd.DataFrame()
fossil_coal_derived_gas_df = pd.DataFrame()
fossil_gas_df = pd.DataFrame()
fossil_hard_coal_df = pd.DataFrame()
fossil_oil_df = pd.DataFrame()
hydro_pumped_storage_df = pd.DataFrame()
hydro_run_of_river_and_poundage_df = pd.DataFrame()
hydro_water_reservoir_df = pd.DataFrame()
nuclear_df = pd.DataFrame()
other_df = pd.DataFrame()
other_renewable_df = pd.DataFrame()
solar_df = pd.DataFrame()
waste_df = pd.DataFrame()
wind_offshore_df = pd.DataFrame()
wind_onshore_df = pd.DataFrame()
geothermal_df = pd.DataFrame()
fossil_peat_df = pd.DataFrame()

gen_types = [
    "Biomass",
    "Fossil Brown coal/Lignite",
    "Fossil Coal-derived gas",
    "Fossil Gas",
    "Fossil Hard coal",
    "Fossil Oil",
    "Hydro Pumped Storage",
    "Hydro Run-of-river and poundage",
    "Hydro Water Reservoir",
    "Nuclear",
    "Other",
    "Other renewable",
    "Solar",
    "Waste",
    "Wind Offshore",
    "Wind Onshore",
    "Geothermal",
    "Fossil Peat",
]

for country_id in countries.keys():
    this_cty_gen_df = pd.read_sql_table(f"{country_id}_gen", engine).set_index(
        "DateTime"
    )
    biomass_df[country_id] = this_cty_gen_df["Biomass"]
    fossil_brown_coal_df[country_id] = this_cty_gen_df["Fossil Brown coal/Lignite"]
    fossil_coal_derived_gas_df[country_id] = this_cty_gen_df["Fossil Coal-derived gas"]
    fossil_gas_df[country_id] = this_cty_gen_df["Fossil Gas"]
    fossil_hard_coal_df[country_id] = this_cty_gen_df["Fossil Hard coal"]
    fossil_oil_df[country_id] = this_cty_gen_df["Fossil Oil"]
    hydro_pumped_storage_df[country_id] = this_cty_gen_df["Hydro Pumped Storage"]
    hydro_run_of_river_and_poundage_df[country_id] = this_cty_gen_df[
        "Hydro Run-of-river and poundage"
    ]
    hydro_water_reservoir_df[country_id] = this_cty_gen_df["Hydro Water Reservoir"]
    nuclear_df[country_id] = this_cty_gen_df["Nuclear"]
    other_df[country_id] = this_cty_gen_df["Other"]
    other_renewable_df[country_id] = this_cty_gen_df["Other renewable"]
    solar_df[country_id] = this_cty_gen_df["Solar"]
    waste_df[country_id] = this_cty_gen_df["Waste"]
    wind_offshore_df[country_id] = this_cty_gen_df["Wind Offshore"]
    wind_onshore_df[country_id] = this_cty_gen_df["Wind Onshore"]
    geothermal_df[country_id] = this_cty_gen_df["Geothermal"]
    fossil_peat_df[country_id] = this_cty_gen_df["Fossil Peat"]

biomass_df.fillna(0, inplace=True)
fossil_brown_coal_df.fillna(0, inplace=True)
fossil_coal_derived_gas_df.fillna(0, inplace=True)
fossil_gas_df.fillna(0, inplace=True)
fossil_hard_coal_df.fillna(0, inplace=True)
fossil_oil_df.fillna(0, inplace=True)
hydro_pumped_storage_df.fillna(0, inplace=True)
hydro_run_of_river_and_poundage_df.fillna(0, inplace=True)
hydro_water_reservoir_df.fillna(0, inplace=True)
nuclear_df.fillna(0, inplace=True)
other_df.fillna(0, inplace=True)
other_renewable_df.fillna(0, inplace=True)
solar_df.fillna(0, inplace=True)
waste_df.fillna(0, inplace=True)
wind_offshore_df.fillna(0, inplace=True)
wind_onshore_df.fillna(0, inplace=True)
geothermal_df.fillna(0, inplace=True)
fossil_peat_df.fillna(0, inplace=True)

In [38]:
datetime_intersect = (
    flow_df.index.intersection(dap_df.index)
    .intersection(load_df.index)
    .intersection(biomass_df.index)
    .intersection(fossil_brown_coal_df.index)
    .intersection(fossil_coal_derived_gas_df.index)
    .intersection(fossil_gas_df.index)
    .intersection(fossil_hard_coal_df.index)
    .intersection(fossil_oil_df.index)
    .intersection(hydro_pumped_storage_df.index)
    .intersection(hydro_run_of_river_and_poundage_df.index)
    .intersection(hydro_water_reservoir_df.index)
    .intersection(nuclear_df.index)
    .intersection(other_df.index)
    .intersection(other_renewable_df.index)
    .intersection(solar_df.index)
    .intersection(waste_df.index)
    .intersection(wind_offshore_df.index)
    .intersection(wind_onshore_df.index)
    .intersection(geothermal_df.index)
    .intersection(fossil_peat_df.index)
)
print(len(datetime_intersect))
print(min(datetime_intersect), max(datetime_intersect))
# Check if datetime_intersect is monotonically increasing
assert all(
    datetime_intersect[i] < datetime_intersect[i + 1]
    for i in range(len(datetime_intersect) - 1)
)

43729
2015-01-04 23:00:00 2019-12-31 23:00:00


In [39]:
# Create temporal features based on datetime_intersect
temporal_hour_df = pd.DataFrame(index=datetime_intersect)
temporal_dow_df = pd.DataFrame(index=datetime_intersect)
temporal_month_df = pd.DataFrame(index=datetime_intersect)
temporal_doy_df = pd.DataFrame(index=datetime_intersect)
for country_id in countries.keys():
    temporal_hour_df[country_id] = datetime_intersect.hour
    temporal_dow_df[country_id] = datetime_intersect.dayofweek
    temporal_month_df[country_id] = datetime_intersect.month
    temporal_doy_df[country_id] = datetime_intersect.dayofyear

In [40]:
flow_df = flow_df.loc[datetime_intersect]
dap_df = dap_df.loc[datetime_intersect]
load_df = load_df.loc[datetime_intersect]
biomass_df = biomass_df.loc[datetime_intersect]
fossil_brown_coal_df = fossil_brown_coal_df.loc[datetime_intersect]
fossil_coal_derived_gas_df = fossil_coal_derived_gas_df.loc[datetime_intersect]
fossil_gas_df = fossil_gas_df.loc[datetime_intersect]
fossil_hard_coal_df = fossil_hard_coal_df.loc[datetime_intersect]
fossil_oil_df = fossil_oil_df.loc[datetime_intersect]
hydro_pumped_storage_df = hydro_pumped_storage_df.loc[datetime_intersect]
hydro_run_of_river_and_poundage_df = hydro_run_of_river_and_poundage_df.loc[
    datetime_intersect
]
hydro_water_reservoir_df = hydro_water_reservoir_df.loc[datetime_intersect]
nuclear_df = nuclear_df.loc[datetime_intersect]
other_df = other_df.loc[datetime_intersect]
other_renewable_df = other_renewable_df.loc[datetime_intersect]
solar_df = solar_df.loc[datetime_intersect]
waste_df = waste_df.loc[datetime_intersect]
wind_offshore_df = wind_offshore_df.loc[datetime_intersect]
wind_onshore_df = wind_onshore_df.loc[datetime_intersect]
geothermal_df = geothermal_df.loc[datetime_intersect]
fossil_peat_df = fossil_peat_df.loc[datetime_intersect]

In [41]:
edges = np.array(interconnections_edge_matrix)
print(edges.shape)
# Map edge names to indices
edge_names = np.unique(edges)
edge_map = {edge: i for i, edge in enumerate(edge_names)}
edge_indices = np.array([edge_map[edge] for edge in edges.flatten()]).reshape(
    edges.shape
)
# Repeat edge indices for each datetime
edge_indices = np.repeat(
    edge_indices[np.newaxis, :, :],
    len(datetime_intersect),
    axis=0,
)
print(edge_indices.shape)
n_edges = edges.shape[1]

(2, 32)
(43729, 2, 32)


In [42]:
# Edge labels (flow) of shape (n_datetime, n_edges, 1)
edge_labels = np.array(flow_df)
# print(edge_labels.shape)
edge_labels = np.reshape(
    edge_labels, (len(datetime_intersect), edge_labels.shape[1], 1)
)
print(edge_labels.shape)
edge_attributes = np.copy(edge_labels)
print(edge_attributes.shape)

(43729, 32, 1)
(43729, 32, 1)


In [43]:
# Node features (dap, load) of shape (n_datetime, n_nodes, n_node_features)
node_features = np.stack(
    [
        dap_df.to_numpy(),
        load_df.to_numpy(),
        biomass_df.to_numpy(),
        fossil_brown_coal_df.to_numpy(),
        fossil_coal_derived_gas_df.to_numpy(),
        fossil_gas_df.to_numpy(),
        fossil_hard_coal_df.to_numpy(),
        fossil_oil_df.to_numpy(),
        hydro_pumped_storage_df.to_numpy(),
        hydro_run_of_river_and_poundage_df.to_numpy(),
        hydro_water_reservoir_df.to_numpy(),
        nuclear_df.to_numpy(),
        other_df.to_numpy(),
        other_renewable_df.to_numpy(),
        solar_df.to_numpy(),
        waste_df.to_numpy(),
        wind_offshore_df.to_numpy(),
        wind_onshore_df.to_numpy(),
        geothermal_df.to_numpy(),
        fossil_peat_df.to_numpy(),
        temporal_hour_df.to_numpy(),
        temporal_dow_df.to_numpy(),
        temporal_month_df.to_numpy(),
        temporal_doy_df.to_numpy(),
    ],
    axis=-1,
)
print(node_features.shape)
print(node_features[0, 0, :])
# print(node_features)
n_nodes = node_features.shape[1]

(43729, 10, 24)
[3.656000e+01 1.003953e+04 2.311300e+02 0.000000e+00 0.000000e+00
 1.791710e+03 4.368500e+02 0.000000e+00 0.000000e+00 3.596000e+01
 0.000000e+00 3.904350e+03 6.016400e+02 0.000000e+00 0.000000e+00
 2.690100e+02 4.937000e+01 2.415900e+02 0.000000e+00 0.000000e+00
 2.300000e+01 6.000000e+00 1.000000e+00 4.000000e+00]


In [44]:
assert (
    len(datetime_intersect)
    == edge_indices.shape[0]
    == edge_labels.shape[0]
    == edge_attributes.shape[0]
    == node_features.shape[0]
)

In [45]:
# Print a snapshot of the graph data
idx = 0
print(datetime_intersect[idx])
print(edge_indices[idx])
print(edge_labels[idx])
print(edge_attributes[idx])
print(node_features[idx])

2015-01-04 23:00:00
[[0 0 0 0 3 3 3 4 4 4 4 2 2 2 2 2 5 6 6 7 7 7 7 7 8 8 1 1 9 9 9 9]
 [4 6 7 9 2 7 8 0 2 1 9 3 4 6 7 1 9 0 2 0 3 2 8 9 3 7 4 2 0 4 5 7]]
[[   0.  ]
 [   0.  ]
 [   0.  ]
 [   0.  ]
 [   0.  ]
 [   0.  ]
 [1315.79]
 [  52.  ]
 [ 617.  ]
 [ 279.  ]
 [1433.  ]
 [   0.  ]
 [   0.  ]
 [   0.  ]
 [3205.  ]
 [   0.  ]
 [   0.  ]
 [   0.  ]
 [   0.  ]
 [2106.86]
 [   0.  ]
 [   0.  ]
 [   0.  ]
 [ 964.  ]
 [   0.  ]
 [ 704.  ]
 [   0.  ]
 [   0.  ]
 [   0.  ]
 [   0.  ]
 [ 169.19]
 [   0.  ]]
[[   0.  ]
 [   0.  ]
 [   0.  ]
 [   0.  ]
 [   0.  ]
 [   0.  ]
 [1315.79]
 [  52.  ]
 [ 617.  ]
 [ 279.  ]
 [1433.  ]
 [   0.  ]
 [   0.  ]
 [   0.  ]
 [3205.  ]
 [   0.  ]
 [   0.  ]
 [   0.  ]
 [   0.  ]
 [2106.86]
 [   0.  ]
 [   0.  ]
 [   0.  ]
 [ 964.  ]
 [   0.  ]
 [ 704.  ]
 [   0.  ]
 [   0.  ]
 [   0.  ]
 [   0.  ]
 [ 169.19]
 [   0.  ]]
[[3.65600000e+01 1.00395300e+04 2.31130000e+02 0.00000000e+00
  0.00000000e+00 1.79171000e+03 4.36850000e+02 0.00000000e+00
  0.00000000e+0

In [46]:
# https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GATv2Conv.html
class GNNEncoder(nn.Module):
    def __init__(
        self, hidden_channels, num_heads_GAT, dropout_p_GAT, edge_dim_GAT, momentum_GAT
    ):
        super().__init__()
        self.gat = GATv2Conv(
            (-1, -1),
            hidden_channels,
            add_self_loops=False,
            heads=num_heads_GAT,
            edge_dim=edge_dim_GAT,
        )
        self.norm = BatchNorm1d(
            hidden_channels,
            momentum=momentum_GAT,
            affine=False,
            track_running_stats=False,
        )
        self.dropout = nn.Dropout(dropout_p_GAT)

    def forward(self, x, edge_indices, edge_attrs):
        x = self.dropout(x)
        x = self.norm(x)
        nodes_embedds = self.gat(x, edge_indices, edge_attrs)
        nodes_embedds = F.leaky_relu(nodes_embedds, negative_slope=0.1)
        return nodes_embedds

In [47]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x: Tensor) -> Tensor:
        x = x + self.pe[: x.size(0)]
        return self.dropout(x)

In [48]:
class Transformer(nn.Module):
    def __init__(
        self,
        dim_model,
        num_heads_TR,
        num_encoder_layers_TR,
        num_decoder_layers_TR,
        dropout_p_TR,
    ):
        super().__init__()
        self.pos_encoder = PositionalEncoding(dim_model)
        self.transformer = nn.Transformer(
            d_model=dim_model,
            nhead=num_heads_TR,
            num_decoder_layers=num_encoder_layers_TR,
            num_encoder_layers=num_decoder_layers_TR,
            dropout=dropout_p_TR,
        )

    def forward(self, src, trg):
        src = self.pos_encoder(src)
        trg = self.pos_encoder(trg)
        temporal_node_embeddings = self.transformer(src, trg)
        return temporal_node_embeddings

In [49]:
class EdgeDecoder(nn.Module):
    def __init__(self, hidden_channels, num_heads_GAT, num_edges, num_nodes):
        super().__init__()
        self.lin1 = nn.Linear(
            num_nodes * hidden_channels * num_heads_GAT, hidden_channels
        )
        self.lin2 = nn.Linear(hidden_channels, num_edges)

    def forward(self, x):
        # Flatten the tensor
        x = torch.flatten(x)
        x = self.lin1(x)
        x = F.leaky_relu(x, negative_slope=0.1)
        x = self.lin2(x)
        return x.view(-1)

In [50]:
class Model(nn.Module):
    def __init__(
        self,
        hidden_channels,
        num_heads_GAT,
        dropout_p_GAT,
        edge_dim_GAT,
        momentum_GAT,
        dim_model,
        num_heads_TR,
        num_encoder_layers_TR,
        num_decoder_layers_TR,
        dropout_p_TR,
        n_edges,
        n_nodes,
    ):
        super().__init__()
        self.encoder = GNNEncoder(
            hidden_channels, num_heads_GAT, dropout_p_GAT, edge_dim_GAT, momentum_GAT
        )  # node embedding with GAT
        self.transformer = Transformer(
            dim_model,
            num_heads_TR,
            num_encoder_layers_TR,
            num_decoder_layers_TR,
            dropout_p_TR,
        )
        self.decoder = EdgeDecoder(hidden_channels, num_heads_GAT, n_edges, n_nodes)

    def forward(self, x, edge_indices, edge_attrs):
        src_embedds = []
        for i in range(x.shape[0]):
            src_embedds.append(self.encoder(x[i], edge_indices[i], edge_attrs[i]))
        src_embedds = torch.stack(src_embedds)
        trg_embedds = src_embedds[-1].unsqueeze(0)
        temporal_node_embedds = self.transformer(src_embedds, trg_embedds)
        temporal_node_embedds = temporal_node_embedds.squeeze(0)
        edge_predictions = self.decoder(temporal_node_embedds)
        return edge_predictions

In [51]:
def train(model, data, window_size, num_epochs, lr):
    model = model.to(device)
    data = [d.to(device) for d in data]
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    for epoch in range(num_epochs):
        loss_sum = None
        for m in tqdm(range(len(data) - window_size)):
            # for m in range(len(data) - window_size):
            optimizer.zero_grad()
            x = torch.stack([data[m + i].x for i in range(window_size)])
            edge_indices = torch.stack(
                [data[m + i].edge_index for i in range(window_size)]
            )
            edge_attrs = torch.stack(
                [data[m + i].edge_attr for i in range(window_size)]
            )
            y = data[m + window_size].y
            y_pred = model(x, edge_indices, edge_attrs)
            y_pred = y_pred.view((-1, 1))
            loss = criterion(y_pred, y)
            # print(f"Epoch {epoch}, Loss {loss.item()}")
            if loss_sum is None:
                loss_sum = loss
            else:
                loss_sum += loss
            # if m % 24 * 7 == 0 or m == len(data) - window_size - 1:
            if m == len(data) - window_size - 1 or m % 24 == 0:
                diff = y.squeeze() - y_pred
                diff = diff.detach().cpu().numpy()
                print(f"Epoch {epoch}, m={m}", diff.mean())
                loss_sum.backward()
                optimizer.step()
                optimizer.zero_grad()
                loss_sum = None

    return model

In [52]:
print(node_features.shape)
print(edge_indices.shape)
print(edge_attributes.shape)
print(edge_labels.shape)
snapshots = []
for i in range(len(datetime_intersect)):
    x = torch.tensor(node_features[i], dtype=torch.float)
    edge_index = torch.tensor(edge_indices[i], dtype=torch.long)
    edge_attr = torch.tensor(edge_attributes[i], dtype=torch.float)
    y = torch.tensor(edge_labels[i], dtype=torch.float)
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
    snapshots.append(data)
print(len(snapshots))

(43729, 10, 24)
(43729, 2, 32)
(43729, 32, 1)
(43729, 32, 1)
43729


In [53]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model(
    hidden_channels=64,
    num_heads_GAT=4,
    dropout_p_GAT=0.1,
    edge_dim_GAT=1,  # edge attributes
    momentum_GAT=0.1,
    dim_model=64 * 4,  # hidden_channels * num_heads_GAT
    num_heads_TR=4,
    num_encoder_layers_TR=6,
    num_decoder_layers_TR=6,
    dropout_p_TR=0.1,
    n_edges=n_edges,
    n_nodes=n_nodes,
)
train(model, snapshots[:8760], window_size=24, num_epochs=10000, lr=0.001)

 79%|███████▉  | 6887/8736 [06:14<01:11, 26.01it/s]

Epoch 2, m=6888 100.416885


 79%|███████▉  | 6912/8736 [06:15<01:07, 27.13it/s]

Epoch 2, m=6912 100.10389


 79%|███████▉  | 6933/8736 [06:16<01:11, 25.32it/s]

Epoch 2, m=6936 120.12009


 80%|███████▉  | 6959/8736 [06:18<01:05, 26.96it/s]

Epoch 2, m=6960 123.95025


 80%|███████▉  | 6984/8736 [06:19<01:04, 26.99it/s]

Epoch 2, m=6984 71.968834


 80%|████████  | 7006/8736 [06:20<01:05, 26.43it/s]

Epoch 2, m=7008 32.50405


 80%|████████  | 7028/8736 [06:21<01:04, 26.29it/s]

Epoch 2, m=7032 29.336586


 81%|████████  | 7054/8736 [06:23<01:00, 27.70it/s]

Epoch 2, m=7056 86.65198


 81%|████████  | 7079/8736 [06:24<00:57, 28.83it/s]

Epoch 2, m=7080 79.34653


 81%|████████▏ | 7102/8736 [06:25<00:58, 27.73it/s]

Epoch 2, m=7104 63.893677


 82%|████████▏ | 7128/8736 [06:26<00:58, 27.72it/s]

Epoch 2, m=7128 44.49556


 82%|████████▏ | 7150/8736 [06:28<01:00, 26.35it/s]

Epoch 2, m=7152 -14.090698


 82%|████████▏ | 7175/8736 [06:29<00:54, 28.85it/s]

Epoch 2, m=7176 -36.775593


 82%|████████▏ | 7200/8736 [06:30<00:56, 27.15it/s]

Epoch 2, m=7200 11.46801


 83%|████████▎ | 7223/8736 [06:32<00:55, 27.13it/s]

Epoch 2, m=7224 13.108147


 83%|████████▎ | 7245/8736 [06:33<00:56, 26.39it/s]

Epoch 2, m=7248 -44.84626


 83%|████████▎ | 7272/8736 [06:34<00:48, 29.89it/s]

Epoch 2, m=7272 8.252632


 83%|████████▎ | 7292/8736 [06:35<00:56, 25.72it/s]

Epoch 2, m=7296 10.76194


 84%|████████▍ | 7319/8736 [06:37<00:48, 29.50it/s]

Epoch 2, m=7320 13.940407


 84%|████████▍ | 7341/8736 [06:38<00:50, 27.42it/s]

Epoch 2, m=7344 47.720024


 84%|████████▍ | 7366/8736 [06:39<00:47, 29.13it/s]

Epoch 2, m=7368 77.78702


 85%|████████▍ | 7388/8736 [06:40<00:50, 26.81it/s]

Epoch 2, m=7392 46.038704


 85%|████████▍ | 7414/8736 [06:42<00:47, 28.08it/s]

Epoch 2, m=7416 67.27581


 85%|████████▌ | 7436/8736 [06:43<00:50, 25.99it/s]

Epoch 2, m=7440 50.644737


 85%|████████▌ | 7464/8736 [06:44<00:42, 30.07it/s]

Epoch 2, m=7464 88.66801


 86%|████████▌ | 7487/8736 [06:45<00:45, 27.27it/s]

Epoch 2, m=7488 68.92307


 86%|████████▌ | 7508/8736 [06:47<00:47, 25.73it/s]

Epoch 2, m=7512 -14.243683


 86%|████████▋ | 7535/8736 [06:48<00:41, 28.75it/s]

Epoch 2, m=7536 11.956139


 87%|████████▋ | 7559/8736 [06:49<00:41, 28.33it/s]

Epoch 2, m=7560 31.199799


 87%|████████▋ | 7584/8736 [06:51<00:40, 28.12it/s]

Epoch 2, m=7584 0.18450928


 87%|████████▋ | 7606/8736 [06:52<00:43, 26.19it/s]

Epoch 2, m=7608 63.447243


 87%|████████▋ | 7631/8736 [06:53<00:38, 28.80it/s]

Epoch 2, m=7632 52.120003


 88%|████████▊ | 7655/8736 [06:54<00:40, 26.75it/s]

Epoch 2, m=7656 32.447327


 88%|████████▊ | 7678/8736 [06:56<00:51, 20.72it/s]

Epoch 2, m=7680 72.0876


 88%|████████▊ | 7701/8736 [06:57<00:40, 25.33it/s]

Epoch 2, m=7704 -0.2614212


 88%|████████▊ | 7728/8736 [06:59<00:36, 27.96it/s]

Epoch 2, m=7728 22.244537


 89%|████████▊ | 7752/8736 [07:00<00:36, 26.73it/s]

Epoch 2, m=7752 53.522507


 89%|████████▉ | 7776/8736 [07:01<00:35, 27.38it/s]

Epoch 2, m=7776 77.80987


 89%|████████▉ | 7800/8736 [07:03<00:35, 26.35it/s]

Epoch 2, m=7800 -12.732674


 90%|████████▉ | 7824/8736 [07:04<00:33, 27.25it/s]

Epoch 2, m=7824 -6.175812


 90%|████████▉ | 7848/8736 [07:05<00:32, 27.65it/s]

Epoch 2, m=7848 3.5869064


 90%|█████████ | 7868/8736 [07:07<00:34, 24.92it/s]

Epoch 2, m=7872 -28.469704


 90%|█████████ | 7894/8736 [07:08<00:29, 28.83it/s]

Epoch 2, m=7896 29.592178


 91%|█████████ | 7918/8736 [07:09<00:30, 27.18it/s]

Epoch 2, m=7920 29.84816


 91%|█████████ | 7942/8736 [07:11<00:28, 27.88it/s]

Epoch 2, m=7944 55.257668


 91%|█████████ | 7967/8736 [07:12<00:26, 28.52it/s]

Epoch 2, m=7968 6.3739624


 91%|█████████▏| 7989/8736 [07:13<00:29, 25.37it/s]

Epoch 2, m=7992 -44.752308


 92%|█████████▏| 8014/8736 [07:15<00:25, 28.52it/s]

Epoch 2, m=8016 -19.169884


 92%|█████████▏| 8036/8736 [07:16<00:25, 27.02it/s]

Epoch 2, m=8040 -7.886627


 92%|█████████▏| 8063/8736 [07:17<00:22, 29.47it/s]

Epoch 2, m=8064 10.520851


 93%|█████████▎| 8087/8736 [07:18<00:23, 27.86it/s]

Epoch 2, m=8088 -11.483971


 93%|█████████▎| 8110/8736 [07:20<00:23, 26.90it/s]

Epoch 2, m=8112 61.64615


 93%|█████████▎| 8136/8736 [07:21<00:20, 29.26it/s]

Epoch 2, m=8136 4.6289062


 93%|█████████▎| 8160/8736 [07:22<00:20, 27.52it/s]

Epoch 2, m=8160 58.203056


 94%|█████████▎| 8184/8736 [07:24<00:20, 27.01it/s]

Epoch 2, m=8184 91.41126


 94%|█████████▍| 8208/8736 [07:25<00:20, 25.69it/s]

Epoch 2, m=8208 36.698547


 94%|█████████▍| 8229/8736 [07:26<00:20, 24.93it/s]

Epoch 2, m=8232 45.60649


 94%|█████████▍| 8253/8736 [07:28<00:17, 27.06it/s]

Epoch 2, m=8256 42.39205


 95%|█████████▍| 8277/8736 [07:29<00:16, 27.75it/s]

Epoch 2, m=8280 57.558952


 95%|█████████▌| 8302/8736 [07:30<00:15, 28.06it/s]

Epoch 2, m=8304 60.375443


 95%|█████████▌| 8325/8736 [07:32<00:15, 27.09it/s]

Epoch 2, m=8328 24.52745


 96%|█████████▌| 8350/8736 [07:33<00:13, 28.26it/s]

Epoch 2, m=8352 -10.269745


 96%|█████████▌| 8374/8736 [07:34<00:12, 28.01it/s]

Epoch 2, m=8376 12.146034


 96%|█████████▌| 8398/8736 [07:35<00:12, 27.64it/s]

Epoch 2, m=8400 89.200294


 96%|█████████▋| 8422/8736 [07:37<00:11, 28.12it/s]

Epoch 2, m=8424 4.5527496


 97%|█████████▋| 8444/8736 [07:38<00:10, 26.73it/s]

Epoch 2, m=8448 -60.74202


 97%|█████████▋| 8471/8736 [07:39<00:08, 29.48it/s]

Epoch 2, m=8472 -39.377525


 97%|█████████▋| 8493/8736 [07:41<00:09, 26.43it/s]

Epoch 2, m=8496 -58.253902


 97%|█████████▋| 8517/8736 [07:42<00:07, 27.80it/s]

Epoch 2, m=8520 -54.771973


 98%|█████████▊| 8542/8736 [07:43<00:07, 26.63it/s]

Epoch 2, m=8544 -86.54287


 98%|█████████▊| 8567/8736 [07:44<00:05, 28.45it/s]

Epoch 2, m=8568 -30.285995


 98%|█████████▊| 8589/8736 [07:46<00:05, 26.46it/s]

Epoch 2, m=8592 -70.02002


 99%|█████████▊| 8615/8736 [07:47<00:04, 28.73it/s]

Epoch 2, m=8616 -13.80336


 99%|█████████▉| 8639/8736 [07:48<00:03, 26.98it/s]

Epoch 2, m=8640 -33.540504


 99%|█████████▉| 8660/8736 [07:50<00:02, 25.78it/s]

Epoch 2, m=8664 -44.267174


 99%|█████████▉| 8688/8736 [07:51<00:01, 27.18it/s]

Epoch 2, m=8688 -24.544235


100%|█████████▉| 8712/8736 [07:52<00:00, 25.97it/s]

Epoch 2, m=8712 22.151596


100%|█████████▉| 8734/8736 [07:54<00:00, 26.82it/s]

Epoch 2, m=8735 38.80275


100%|██████████| 8736/8736 [07:54<00:00, 18.40it/s]
  0%|          | 4/8736 [00:00<04:38, 31.38it/s]

Epoch 3, m=0 -164.19379


  0%|          | 24/8736 [00:00<03:33, 40.78it/s]

Epoch 3, m=24 -168.0294


  1%|          | 46/8736 [00:01<05:17, 27.37it/s]

Epoch 3, m=48 -141.3866


  1%|          | 69/8736 [00:03<05:17, 27.33it/s]

Epoch 3, m=72 -35.355392


  1%|          | 93/8736 [00:04<05:07, 28.07it/s]

Epoch 3, m=96 -37.03753


  1%|▏         | 117/8736 [00:05<05:08, 27.95it/s]

Epoch 3, m=120 -67.276184


  2%|▏         | 142/8736 [00:06<05:12, 27.50it/s]

Epoch 3, m=144 -109.88258


  2%|▏         | 168/8736 [00:08<04:51, 29.40it/s]

Epoch 3, m=168 -42.748024


  2%|▏         | 189/8736 [00:09<05:36, 25.43it/s]

Epoch 3, m=192 -9.396194


  2%|▏         | 215/8736 [00:10<04:58, 28.56it/s]

Epoch 3, m=216 -28.413307


  3%|▎         | 238/8736 [00:12<05:17, 26.80it/s]

Epoch 3, m=240 -42.446457


  3%|▎         | 262/8736 [00:13<05:03, 27.96it/s]

Epoch 3, m=264 -31.014633


  3%|▎         | 288/8736 [00:14<04:48, 29.32it/s]

Epoch 3, m=288 -24.979256


  4%|▎         | 310/8736 [00:15<05:11, 27.03it/s]

Epoch 3, m=312 -20.396805


  4%|▍         | 332/8736 [00:17<05:10, 27.08it/s]

Epoch 3, m=336 -19.259705


  4%|▍         | 357/8736 [00:18<05:12, 26.80it/s]

Epoch 3, m=360 -89.04425


  4%|▍         | 381/8736 [00:19<05:39, 24.58it/s]

Epoch 3, m=384 -19.188042


  5%|▍         | 406/8736 [00:21<05:05, 27.27it/s]

Epoch 3, m=408 1.0183563


  5%|▍         | 430/8736 [00:22<04:56, 28.06it/s]

Epoch 3, m=432 -16.532204


  5%|▌         | 454/8736 [00:23<05:00, 27.55it/s]

Epoch 3, m=456 74.15493


  5%|▌         | 479/8736 [00:25<05:59, 22.96it/s]

Epoch 3, m=480 -9.040466


  6%|▌         | 503/8736 [00:26<05:03, 27.14it/s]

Epoch 3, m=504 72.9059


  6%|▌         | 527/8736 [00:27<05:02, 27.16it/s]

Epoch 3, m=528 14.229912


  6%|▋         | 548/8736 [00:29<05:25, 25.16it/s]

Epoch 3, m=552 11.357437


  7%|▋         | 573/8736 [00:30<05:18, 25.62it/s]

Epoch 3, m=576 57.212635


  7%|▋         | 597/8736 [00:31<04:58, 27.26it/s]

Epoch 3, m=600 29.048042


  7%|▋         | 622/8736 [00:33<04:51, 27.81it/s]

Epoch 3, m=624 98.651855


  7%|▋         | 646/8736 [00:34<04:47, 28.14it/s]

Epoch 3, m=648 152.85284


  8%|▊         | 670/8736 [00:35<04:53, 27.46it/s]

Epoch 3, m=672 176.97675


  8%|▊         | 695/8736 [00:37<04:44, 28.30it/s]

Epoch 3, m=696 179.02748


  8%|▊         | 717/8736 [00:38<05:01, 26.57it/s]

Epoch 3, m=720 176.6073


  9%|▊         | 743/8736 [00:39<04:31, 29.40it/s]

Epoch 3, m=744 162.82294


  9%|▉         | 765/8736 [00:40<04:56, 26.89it/s]

Epoch 3, m=768 68.06915


  9%|▉         | 791/8736 [00:42<04:39, 28.44it/s]

Epoch 3, m=792 128.07588


  9%|▉         | 816/8736 [00:43<04:38, 28.39it/s]

Epoch 3, m=816 121.73525


 10%|▉         | 837/8736 [00:44<05:18, 24.83it/s]

Epoch 3, m=840 129.14432


 10%|▉         | 863/8736 [00:45<04:28, 29.27it/s]

Epoch 3, m=864 146.98383


 10%|█         | 885/8736 [00:47<04:52, 26.83it/s]

Epoch 3, m=888 157.43636


 10%|█         | 909/8736 [00:48<04:51, 26.90it/s]

Epoch 3, m=912 124.30865


 11%|█         | 935/8736 [00:49<04:31, 28.70it/s]

Epoch 3, m=936 88.65631


 11%|█         | 959/8736 [00:51<05:01, 25.79it/s]

Epoch 3, m=960 101.14924


 11%|█▏        | 983/8736 [00:52<04:36, 27.99it/s]

Epoch 3, m=984 65.774345
