This notebook converts tabular data to graph data for pytorch geometric to ingest.

The final graph is a static homogeneous directed graph with temporal signals.


In [209]:
import math
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv
from torch_geometric.data import Data
from torch.nn import BatchNorm1d

import pandas as pd
import numpy as np

from tqdm import tqdm

In [210]:
flow_df = pd.read_csv("data/flow/_combined.csv")
dap_df = pd.read_csv("data/dap/_combined.csv")
flow_df.set_index("datetime", inplace=True)
dap_df.set_index("datetime", inplace=True)

datetime_intersect = flow_df.index.intersection(dap_df.index)
flow_df = flow_df.loc[datetime_intersect]
dap_df = dap_df.loc[datetime_intersect]

# Remove columns that contains "UK" or "IE"
flow_df = flow_df.loc[
    :, ~flow_df.columns.str.contains("UK") & ~flow_df.columns.str.contains("IE")
]
dap_df = dap_df.loc[
    :, ~dap_df.columns.str.contains("UK") & ~dap_df.columns.str.contains("IE")
]

# print(flow_df.isnull().sum())
# print(dap_df.isnull().sum())
assert not flow_df.isnull().values.any()
assert not dap_df.isnull().values.any()
dap_df = dap_df.reindex(sorted(dap_df.columns), axis=1)

In [211]:
# Normalize
# is this good enough?
flow_df = (flow_df - flow_df.mean()) / flow_df.std()
dap_df = (dap_df - dap_df.mean()) / dap_df.std()

# print(flow_df.head())
# print(dap_df.head())

                             BE->DE    BE->FR    BE->LU    BE->NL    DE->BE  \
datetime                                                                      
2023-04-01 00:00:00+00:00 -0.891149 -0.470883 -0.094801  2.616577 -0.541081   
2023-04-01 01:00:00+00:00 -0.944479 -0.470883  0.750963  3.598008  2.448207   
2023-04-01 02:00:00+00:00 -0.944479 -0.470883  1.302386  2.919569  1.663773   
2023-04-01 03:00:00+00:00 -0.944479 -0.470883  0.654091  1.892031  0.741433   
2023-04-01 04:00:00+00:00 -0.944479 -0.470883 -0.780355  0.269487  0.050287   

                             DE->DK    DE->AT    DE->CH    DE->CZ    DE->FR  \
datetime                                                                      
2023-04-01 00:00:00+00:00 -0.375745  0.765371  0.384475  0.277879 -0.400487   
2023-04-01 01:00:00+00:00 -0.375745  0.680750  0.178445  0.445536 -0.400487   
2023-04-01 02:00:00+00:00 -0.375745  0.769439  0.375343  0.440351 -0.400487   
2023-04-01 03:00:00+00:00 -0.370642  0.873589  0.41

In [212]:
# Static edges of shape (2, n_edges)
interconnectors = flow_df.columns[1:]
exporters = []
importers = []
for ic in interconnectors:
    exporter, importer = ic.split("->")
    exporters.append(exporter)
    importers.append(importer)

exporters = np.array(exporters)
importers = np.array(importers)

edges = np.vstack([exporters, importers])
print(edges.shape)
print(edges[:, :10])
n_edges = edges.shape[1]

(2, 45)
[['BE' 'BE' 'BE' 'DE' 'DE' 'DE' 'DE' 'DE' 'DE' 'DE']
 ['FR' 'LU' 'NL' 'BE' 'DK' 'AT' 'CH' 'CZ' 'FR' 'LU']]


In [213]:
# Map edge names to indices
edge_names = np.unique(edges)
edge_map = {edge: i for i, edge in enumerate(edge_names)}
edge_indices = np.array([edge_map[edge] for edge in edges.flatten()]).reshape(
    edges.shape
)
# Repeat edge indices for each datetime
edge_indices = np.repeat(
    edge_indices[np.newaxis, :, :],
    len(datetime_intersect),
    axis=0,
)
print(edge_indices.shape)

(5136, 2, 45)


In [214]:
# Edge labels (flow) of shape (n_datetimes, n_edges, 1)
edge_labels = np.array(flow_df[interconnectors])
edge_labels = np.reshape(edge_labels, (edge_labels.shape[0], edge_labels.shape[1], 1))
print(edge_labels.shape)

(5136, 45, 1)


In [215]:
# Edge attributes (capacity, etc.) of shape (n_datetimes, n_edges, n_attributes)
# copy the edge labels to the edge attributes
edge_attributes = np.copy(edge_labels)
# hard code the edge attributes to be ones for now
edge_attributes = np.ones(edge_labels.shape)
print(edge_attributes.shape)
# print(edge_attributes[:, :, 0])

(5136, 45, 1)


In [216]:
# Node features (dap)
node_features = np.array(node_features)
node_features = np.reshape(
    node_features, (node_features.shape[0], node_features.shape[1], 1)
)
print(node_features.shape)
n_nodes = node_features.shape[1]

(5136, 15, 1)


In [217]:
assert (
    len(datetime_intersect)
    == edge_indices.shape[0]
    == edge_attributes.shape[0]
    == edge_labels.shape[0]
    == node_features.shape[0]
)
# Print a snapshot of the shape of the graph data
i = 256
print("Edge indices:", edge_indices[i].shape)
print("Edge attributes:", edge_attributes[i].shape)
print("Edge labels:", edge_labels[i].shape)
print("Node features:", node_features[i].shape)

Edge indices: (2, 45)
Edge attributes: (45, 1)
Edge labels: (45, 1)
Node features: (15, 1)


In [218]:
n_snapshots = len(datetime_intersect)
snapshots = []
for i in range(n_snapshots):
    data = Data(
        x=torch.tensor(node_features[i], dtype=torch.float),
        edge_index=torch.tensor(edge_indices[i], dtype=torch.long),
        edge_attr=torch.tensor(edge_attributes[i], dtype=torch.float),
        y=torch.tensor(edge_labels[i], dtype=torch.float),
    )
    snapshots.append(data)
print(snapshots[0])

Data(x=[15, 1], edge_index=[2, 45], edge_attr=[45, 1], y=[45, 1])


In [219]:
# https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GATv2Conv.html
class GNNEncoder(nn.Module):
    def __init__(
        self, hidden_channels, num_heads_GAT, dropout_p_GAT, edge_dim_GAT, momentum_GAT
    ):
        super().__init__()
        self.gat = GATv2Conv(
            (-1, -1),
            hidden_channels,
            add_self_loops=False,
            heads=num_heads_GAT,
            edge_dim=edge_dim_GAT,
        )
        self.norm = BatchNorm1d(
            hidden_channels,
            momentum=momentum_GAT,
            affine=False,
            track_running_stats=False,
        )
        self.dropout = nn.Dropout(dropout_p_GAT)

    def forward(self, x, edge_indices, edge_attrs):
        x = self.dropout(x)
        x = self.norm(x)
        nodes_embedds = self.gat(x, edge_indices, edge_attrs)
        nodes_embedds = F.leaky_relu(nodes_embedds, negative_slope=0.1)
        return nodes_embedds

In [220]:
# Test GNNEncoder
hidden_channels = 8
num_heads_GAT = 2
dropout_p_GAT = 0.1
edge_dim_GAT = 1  # edge attributes
momentum_GAT = 0.1
encoder = GNNEncoder(
    hidden_channels, num_heads_GAT, dropout_p_GAT, edge_dim_GAT, momentum_GAT
)

# Test forward pass
i = 0
x = snapshots[i].x
edge_indices = snapshots[i].edge_index
edge_attrs = snapshots[i].edge_attr

nodes_embedds = encoder(x, edge_indices, edge_attrs)
print(nodes_embedds.shape)
# print(nodes_embedds)

torch.Size([15, 16])


In [221]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x: Tensor) -> Tensor:
        x = x + self.pe[: x.size(0)]
        return self.dropout(x)


class Transformer(nn.Module):
    """
    Transformer-based module for creating temporal node embeddings.

    Args:
        dim_model (int): The dimension of the model's hidden states.
        num_heads_TR (int): The number of attention heads.
        num_encoder_layers_TR (int): The number of encoder layers.
        num_decoder_layers_TR (int): The number of decoder layers.
        dropout_p_TR (float): Dropout probability.
    """
    def __init__(
        self,
        dim_model,
        num_heads_TR,
        num_encoder_layers_TR,
        num_decoder_layers_TR,
        dropout_p_TR,
    ):
        super().__init__()
        self.pos_encoder = PositionalEncoding(dim_model)
        self.transformer = nn.Transformer(
            d_model=dim_model,
            nhead=num_heads_TR,
            num_decoder_layers=num_encoder_layers_TR,
            num_encoder_layers=num_decoder_layers_TR,
            dropout=dropout_p_TR,
        )

    def forward(self, src, trg):
        src = self.pos_encoder(src)
        trg = self.pos_encoder(trg)
        temporal_node_embeddings = self.transformer(src, trg)
        return temporal_node_embeddings

In [222]:
# Test Transformer
num_heads_TR = 2
num_encoder_layers_TR = 2
num_decoder_layers_TR = 2
dropout_p_TR = 0.1
transformer = Transformer(
    hidden_channels * num_heads_GAT,
    num_heads_TR,
    num_encoder_layers_TR,
    num_decoder_layers_TR,
    dropout_p_TR,
)

# Test forward pass
seq_len = 24
src_embedds = []
for i in range(seq_len):
    snapshot = snapshots[i]
    src_embedds.append(encoder(snapshot.x, snapshot.edge_index, snapshot.edge_attr))
src_embedds = torch.stack(src_embedds)
trg_embedds = src_embedds[-1].unsqueeze(0)
print(src_embedds.shape)
print(trg_embedds.shape)

temporal_node_embedds = transformer(src_embedds, trg_embedds)
temporal_node_embedds = temporal_node_embedds.squeeze(0)
print(temporal_node_embedds.shape)

torch.Size([24, 15, 16])
torch.Size([1, 15, 16])
torch.Size([15, 16])




In [223]:
class EdgeDecoder(nn.Module):
    def __init__(self, hidden_channels, num_heads_GAT, num_edges, num_nodes):
        super().__init__()
        self.lin1 = nn.Linear(
            num_nodes * hidden_channels * num_heads_GAT, hidden_channels
        )
        self.lin2 = nn.Linear(hidden_channels, num_edges)

    def forward(self, x):
        # Flatten the tensor
        x = torch.flatten(x)
        x = self.lin1(x)
        x = F.leaky_relu(x, negative_slope=0.1)
        x = self.lin2(x)
        return x.view(-1)

In [224]:
# Test EdgeDecoder
print(hidden_channels, num_heads_GAT, n_edges)
decoder = EdgeDecoder(hidden_channels, num_heads_GAT, n_edges, n_nodes)

# Test forward pass
edge_predictions = decoder(temporal_node_embedds)
print(edge_predictions.shape)

8 2 45
torch.Size([45])


In [225]:
class Model(nn.Module):
    def __init__(
        self,
        hidden_channels,
        num_heads_GAT,
        dropout_p_GAT,
        edge_dim_GAT,
        momentum_GAT,
        dim_model,
        num_heads_TR,
        num_encoder_layers_TR,
        num_decoder_layers_TR,
        dropout_p_TR,
        num_edges,
    ):
        super().__init__()
        self.encoder = GNNEncoder(
            hidden_channels, num_heads_GAT, dropout_p_GAT, edge_dim_GAT, momentum_GAT
        )
        self.transformer = Transformer(
            dim_model,
            num_heads_TR,
            num_encoder_layers_TR,
            num_decoder_layers_TR,
            dropout_p_TR,
        )
        self.decoder = EdgeDecoder(hidden_channels, num_heads_GAT, num_edges, n_nodes)

    def forward(self, x, edge_indices, edge_attrs):
        src_embedds = []
        for i in range(x.shape[0]):
            src_embedds.append(self.encoder(x[i], edge_indices[i], edge_attrs[i]))
        src_embedds = torch.stack(src_embedds)
        trg_embedds = src_embedds[-1].unsqueeze(0)
        temporal_node_embedds = self.transformer(src_embedds, trg_embedds)
        temporal_node_embedds = temporal_node_embedds.squeeze(0)
        edge_predictions = self.decoder(temporal_node_embedds)
        return edge_predictions

In [226]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model(
    hidden_channels,
    num_heads_GAT,
    dropout_p_GAT,
    edge_dim_GAT,
    momentum_GAT,
    hidden_channels * num_heads_GAT,
    num_heads_TR,
    num_encoder_layers_TR,
    num_decoder_layers_TR,
    dropout_p_TR,
    n_edges,
)
model = model.to(device)

In [227]:
n_epochs = 10
for epoch in range(n_epochs):
    window = 24 * 7
    for m in tqdm(range(window, len(snapshots))):
        history = snapshots[m - window : m]
        y = snapshots[m].y.view(-1)

        optimizer = torch.optim.Adam(model.parameters(), lr=0.00005)
        optimizer.zero_grad()

        x = [data.x for data in history]
        edge_indices = [data.edge_index for data in history]
        edge_attrs = [data.edge_attr for data in history]

        x = torch.stack(x)
        edge_indices = torch.stack(edge_indices)
        edge_attrs = torch.stack(edge_attrs)
        y = y.to(device)

        x = x.to(device)
        edge_indices = edge_indices.to(device)
        edge_attrs = edge_attrs.to(device)

        edge_predictions = model(x, edge_indices, edge_attrs)

        loss = F.mse_loss(edge_predictions, y)
        loss.backward()
        optimizer.step()

        if (m - window) % 100 == 0:
            print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

  0%|          | 1/4968 [00:00<19:03,  4.35it/s]

Epoch 1, Loss: 1.2342640161514282


  2%|▏         | 102/4968 [00:21<16:20,  4.96it/s]

Epoch 1, Loss: 0.8182895183563232


  4%|▍         | 201/4968 [00:41<15:57,  4.98it/s]

Epoch 1, Loss: 0.8691268563270569


  6%|▌         | 302/4968 [01:01<15:42,  4.95it/s]

Epoch 1, Loss: 0.8302033543586731


  8%|▊         | 402/4968 [01:22<15:19,  4.96it/s]

Epoch 1, Loss: 0.5905413627624512


 10%|█         | 502/4968 [01:42<15:07,  4.92it/s]

Epoch 1, Loss: 0.5877428650856018


 12%|█▏        | 602/4968 [02:03<14:44,  4.94it/s]

Epoch 1, Loss: 0.5930600166320801


 14%|█▍        | 702/4968 [02:23<14:20,  4.96it/s]

Epoch 1, Loss: 0.7041400074958801


 16%|█▌        | 802/4968 [02:44<14:04,  4.93it/s]

Epoch 1, Loss: 1.4931166172027588


 18%|█▊        | 902/4968 [03:04<13:45,  4.93it/s]

Epoch 1, Loss: 0.6004807353019714


 20%|██        | 1002/4968 [03:24<13:20,  4.96it/s]

Epoch 1, Loss: 0.3341328501701355


 22%|██▏       | 1102/4968 [03:45<13:16,  4.86it/s]

Epoch 1, Loss: 0.5179334282875061


 24%|██▍       | 1202/4968 [04:06<12:50,  4.89it/s]

Epoch 1, Loss: 0.6790197491645813


 26%|██▌       | 1302/4968 [04:26<12:48,  4.77it/s]

Epoch 1, Loss: 0.5504294037818909


 28%|██▊       | 1402/4968 [04:47<12:12,  4.87it/s]

Epoch 1, Loss: 0.7634257078170776


 30%|███       | 1502/4968 [05:08<11:52,  4.87it/s]

Epoch 1, Loss: 1.6626137495040894


 32%|███▏      | 1602/4968 [05:28<11:36,  4.83it/s]

Epoch 1, Loss: 0.5621646046638489


 34%|███▍      | 1701/4968 [05:49<11:03,  4.92it/s]

Epoch 1, Loss: 1.1738470792770386


 36%|███▋      | 1802/4968 [06:10<10:53,  4.84it/s]

Epoch 1, Loss: 0.5344696044921875


 38%|███▊      | 1901/4968 [06:31<10:31,  4.85it/s]

Epoch 1, Loss: 0.5312365889549255


 40%|████      | 2001/4968 [06:52<10:07,  4.88it/s]

Epoch 1, Loss: 0.6128056049346924


 42%|████▏     | 2102/4968 [07:13<10:03,  4.75it/s]

Epoch 1, Loss: 3.5386993885040283


 44%|████▍     | 2201/4968 [07:34<10:13,  4.51it/s]

Epoch 1, Loss: 0.5528126955032349


 46%|████▋     | 2301/4968 [07:55<09:00,  4.94it/s]

Epoch 1, Loss: 0.535499095916748


 48%|████▊     | 2402/4968 [08:16<08:54,  4.80it/s]

Epoch 1, Loss: 0.6852515339851379


 50%|█████     | 2502/4968 [08:37<08:30,  4.83it/s]

Epoch 1, Loss: 0.883474588394165


 52%|█████▏    | 2602/4968 [08:58<08:02,  4.91it/s]

Epoch 1, Loss: 0.4097172021865845


 54%|█████▍    | 2702/4968 [09:19<07:54,  4.78it/s]

Epoch 1, Loss: 1.7587591409683228


 56%|█████▋    | 2801/4968 [09:40<07:23,  4.89it/s]

Epoch 1, Loss: 0.46981650590896606


 58%|█████▊    | 2902/4968 [10:01<07:05,  4.85it/s]

Epoch 1, Loss: 2.3054120540618896


 60%|██████    | 3002/4968 [10:21<06:49,  4.80it/s]

Epoch 1, Loss: 0.8392885327339172


 62%|██████▏   | 3102/4968 [10:42<06:28,  4.80it/s]

Epoch 1, Loss: 1.013670802116394


 64%|██████▍   | 3202/4968 [11:03<06:05,  4.84it/s]

Epoch 1, Loss: 0.45886194705963135


 66%|██████▋   | 3302/4968 [11:24<05:56,  4.67it/s]

Epoch 1, Loss: 0.9480093717575073


 68%|██████▊   | 3401/4968 [11:45<05:26,  4.80it/s]

Epoch 1, Loss: 0.5693511366844177


 70%|███████   | 3501/4968 [12:06<05:08,  4.76it/s]

Epoch 1, Loss: 0.6645352840423584


 73%|███████▎  | 3602/4968 [12:27<04:48,  4.73it/s]

Epoch 1, Loss: 0.5557405948638916


 75%|███████▍  | 3702/4968 [12:48<04:16,  4.93it/s]

Epoch 1, Loss: 0.653459370136261


 77%|███████▋  | 3802/4968 [13:09<03:56,  4.94it/s]

Epoch 1, Loss: 3.704559326171875


 79%|███████▊  | 3902/4968 [13:30<03:39,  4.85it/s]

Epoch 1, Loss: 0.807232677936554


 81%|████████  | 4001/4968 [13:50<03:16,  4.93it/s]

Epoch 1, Loss: 2.6219663619995117


 83%|████████▎ | 4102/4968 [14:12<02:55,  4.93it/s]

Epoch 1, Loss: 0.6866225004196167


 85%|████████▍ | 4202/4968 [14:32<02:35,  4.94it/s]

Epoch 1, Loss: 0.9220677018165588


 87%|████████▋ | 4302/4968 [14:53<02:16,  4.87it/s]

Epoch 1, Loss: 0.4691568613052368


 89%|████████▊ | 4401/4968 [15:14<02:00,  4.72it/s]

Epoch 1, Loss: 2.0498337745666504


 91%|█████████ | 4501/4968 [15:35<01:36,  4.84it/s]

Epoch 1, Loss: 0.7209762930870056


 93%|█████████▎| 4602/4968 [15:57<01:16,  4.77it/s]

Epoch 1, Loss: 0.6259484887123108


 95%|█████████▍| 4701/4968 [16:18<00:56,  4.75it/s]

Epoch 1, Loss: 0.6757475137710571


 97%|█████████▋| 4801/4968 [16:38<00:36,  4.59it/s]

Epoch 1, Loss: 0.6350441575050354


 99%|█████████▊| 4901/4968 [16:59<00:13,  5.00it/s]

Epoch 1, Loss: 4.083783149719238


100%|██████████| 4968/4968 [17:13<00:00,  4.81it/s]
  0%|          | 2/4968 [00:00<16:22,  5.05it/s]

Epoch 2, Loss: 1.0201770067214966


  2%|▏         | 102/4968 [00:21<16:15,  4.99it/s]

Epoch 2, Loss: 0.79212886095047


  4%|▍         | 202/4968 [00:41<16:24,  4.84it/s]

Epoch 2, Loss: 0.8328093886375427


  6%|▌         | 302/4968 [01:02<15:51,  4.90it/s]

Epoch 2, Loss: 0.7953106760978699


  8%|▊         | 401/4968 [01:22<15:33,  4.89it/s]

Epoch 2, Loss: 0.5379353165626526


 10%|█         | 501/4968 [01:43<15:08,  4.92it/s]

Epoch 2, Loss: 0.5466597080230713


 12%|█▏        | 601/4968 [02:04<14:53,  4.89it/s]

Epoch 2, Loss: 0.5423016548156738


 14%|█▍        | 702/4968 [02:25<14:45,  4.82it/s]

Epoch 2, Loss: 0.5826029777526855


 16%|█▌        | 801/4968 [02:45<14:13,  4.88it/s]

Epoch 2, Loss: 1.495030403137207


 18%|█▊        | 902/4968 [03:06<14:29,  4.67it/s]

Epoch 2, Loss: 0.6069523692131042


 20%|██        | 1001/4968 [03:27<14:23,  4.60it/s]

Epoch 2, Loss: 0.383100301027298


 22%|██▏       | 1102/4968 [03:48<13:08,  4.90it/s]

Epoch 2, Loss: 0.5267592072486877


 24%|██▍       | 1201/4968 [04:08<13:20,  4.71it/s]

Epoch 2, Loss: 0.672534704208374


 26%|██▌       | 1301/4968 [04:30<13:53,  4.40it/s]

Epoch 2, Loss: 0.5034573078155518


 28%|██▊       | 1401/4968 [04:51<12:07,  4.90it/s]

Epoch 2, Loss: 0.7131563425064087


 30%|███       | 1502/4968 [05:11<11:21,  5.09it/s]

Epoch 2, Loss: 1.8295085430145264


 32%|███▏      | 1601/4968 [05:31<11:13,  5.00it/s]

Epoch 2, Loss: 0.5404986143112183


 34%|███▍      | 1702/4968 [05:51<10:58,  4.96it/s]

Epoch 2, Loss: 1.1308797597885132


 36%|███▋      | 1802/4968 [06:11<10:33,  5.00it/s]

Epoch 2, Loss: 0.4906623065471649


 38%|███▊      | 1902/4968 [06:31<11:19,  4.51it/s]

Epoch 2, Loss: 0.4814990162849426


 40%|████      | 2002/4968 [06:52<09:50,  5.02it/s]

Epoch 2, Loss: 0.6353176832199097


 42%|████▏     | 2102/4968 [07:12<09:47,  4.88it/s]

Epoch 2, Loss: 3.388422727584839


 44%|████▍     | 2201/4968 [07:33<09:54,  4.66it/s]

Epoch 2, Loss: 0.549014151096344


 46%|████▋     | 2302/4968 [07:54<08:59,  4.94it/s]

Epoch 2, Loss: 0.44392502307891846


 48%|████▊     | 2401/4968 [08:15<09:22,  4.56it/s]

Epoch 2, Loss: 0.5943266153335571


 50%|█████     | 2502/4968 [08:36<08:21,  4.92it/s]

Epoch 2, Loss: 0.6889333128929138


 52%|█████▏    | 2602/4968 [08:57<07:39,  5.15it/s]

Epoch 2, Loss: 0.3956809341907501


 54%|█████▍    | 2701/4968 [09:17<07:48,  4.84it/s]

Epoch 2, Loss: 1.866665005683899


 56%|█████▋    | 2801/4968 [09:39<07:21,  4.91it/s]

Epoch 2, Loss: 0.4642089605331421


 58%|█████▊    | 2902/4968 [10:00<07:24,  4.65it/s]

Epoch 2, Loss: 2.2739977836608887


 59%|█████▉    | 2941/4968 [10:08<06:59,  4.83it/s]


KeyboardInterrupt: 