In [1]:
import pandas as pd
import numpy as np
import os
from src.data.process_dataset import load_csv_dataset, create_grid, create_grid_ids, correlation_adjacency_matrix, features_targets_and_externals, Dataset
from src.data.encode_externals import encode_times
from src.models.models import ExternalLSTM, GraphModel, CustomTemporalSignal, Encoder, Decoder, STGNNModel
from torch_geometric_temporal.signal import StaticGraphTemporalSignal, temporal_signal_split
import dill
from torch_geometric_temporal.nn.recurrent import DCRNN
import torch
import torch.nn.functional as F

In [2]:
# Location of the pickled dataset object that the rest of the notebook works on.
# NOTE(review): this is an absolute, machine-specific path; consider making it
# relative to a configurable data directory so the notebook runs elsewhere.
DATASET_PATH = "/Users/theisferre/Documents/SPECIALE/Thesis/data/processed/202106-citibike-tripdata.pkl"

# Open the file in a context manager so the handle is closed even if unpickling
# raises (the previous version left the handle open for the kernel's lifetime).
# SECURITY: dill.load can execute arbitrary code embedded in the file; only load
# pickles that come from a trusted source.
with open(DATASET_PATH, "rb") as open_file:
    data = dill.load(open_file)

# Split the loaded dataset into its train and test parts.
train_dataset, test_dataset = Dataset.train_test_split(data)

In [32]:
# Inspect the shape of the dataset's X array.
# presumably the axes are (time steps, nodes, features per node) — TODO confirm
data.X.shape

(179, 69, 1)

In [3]:
# Derive the graph dimensions from the loaded data instead of repeating literals.
# For this dataset data.X.shape is (179, 69, 1), so these evaluate to the
# previously hard-coded 69 nodes and 1 input feature per node.
# NOTE(review): assumes data.X is laid out as (time steps, nodes, features) — confirm.
num_nodes = data.X.shape[1]
node_in_features = data.X.shape[2]

# Graph sub-model operating on the node features.
graph_model = GraphModel(
    node_in_features=node_in_features,
    num_nodes=num_nodes,
    node_out_features=8,
)

# One ExternalLSTM per external signal; its input width is the last dimension
# of the corresponding array on the dataset.
weather_model = ExternalLSTM(data.weather_information.shape[-1], num_nodes=num_nodes)
time_model = ExternalLSTM(data.time_encoding.shape[-1], num_nodes=num_nodes)

In [4]:
# Gather a short preview window of training samples for smoke-testing the
# sub-models. The window covers indices 0..6 (at most seven samples), or the
# whole training set when it is shorter than that.
preview_len = min(len(train_dataset), 7)

data_graph, data_weather, data_time = [], [], []
for sample_idx in range(preview_len):
    # Each training item unpacks into a graph snapshot and two external inputs.
    graph_snapshot, weather_features, time_features = train_dataset[sample_idx]
    data_graph.append(graph_snapshot)
    data_weather.append(weather_features)
    data_time.append(time_features)


In [5]:
cell_state_graph, hidden_state_graph = graph_model(data_graph)
cell_state_weather, hidden_state_weather = weather_model(data_weather)
cell_state_time, hidden_state_weather = time_model(data_time)

In [6]:
print(f"graph model shapes: {cell_state_graph.shape}, {hidden_state_graph.shape}")
print(f"weather model shapes: {cell_state_weather.shape}, {hidden_state_weather.shape}")
print(f"time model shapes: {cell_state_time.shape}, {hidden_state_weather.shape}")

graph model shapes: torch.Size([1, 552]), torch.Size([1, 552])
weather model shapes: torch.Size([1, 552]), torch.Size([1, 552])
time model shapes: torch.Size([1, 552]), torch.Size([1, 552])


In [7]:
# Hyperparameters shared by the encoder and the decoder. They are defined once
# so the two halves of the model cannot drift apart.
# For this dataset data.X.shape is (179, 69, 1), so the derived values equal the
# previously hard-coded 69 nodes and 1 input feature per node.
# NOTE(review): assumes data.X is laid out as (time steps, nodes, features) — confirm.
num_nodes = data.X.shape[1]
node_in_features = data.X.shape[2]
node_out_features = 8

# Encoder consuming the graph sequence plus the time and weather inputs; the
# widths of the external inputs come from the last dimension of each array.
encoder = Encoder(
    node_in_features=node_in_features,
    num_nodes=num_nodes,
    node_out_features=node_out_features,
    time_features=data.time_encoding.shape[-1],
    weather_features=data.weather_information.shape[-1],
    hidden_size=64,
)

# Decoder built with the same node count and per-node output width as the encoder.
decoder = Decoder(
    node_out_features=node_out_features,
    num_nodes=num_nodes,
)

# Full model wrapping the encoder/decoder pair.
model = STGNNModel(encoder, decoder)

In [8]:
# Run the encoder on the current graph / weather / time sequences and keep both
# returned values for inspection and for the manual decoder call.
# NOTE(review): the (cell, hidden) unpacking order is assumed from the variable
# names — confirm against Encoder.forward.
cell_state_fused, hidden_state_fused = encoder(data_graph, data_weather, data_time)

In [23]:
# Train the model to predict the node values at step i from the num_history
# samples that precede it. One optimizer step is taken per target sample, and
# the summed loss of each epoch is printed.
num_history = 4
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for EPOCH in range(50):
    epoch_loss = 0

    for i in range(num_history, len(train_dataset)):
        # Build the input window from samples i - num_history .. i - 1, oldest first.
        data_graph = []
        data_weather = []
        data_time = []
        for step in range(i - num_history, i):
            graph_data, weather, time_enc = train_dataset[step]
            data_graph.append(graph_data)
            data_weather.append(weather)
            data_time.append(time_enc)

        optimizer.zero_grad()
        out, (hidden, cell) = model(data_graph, data_weather, data_time)

        # Flatten the target sample's node values into a single row; -1 lets the
        # row width follow the data instead of a hard-coded node count.
        target = train_dataset[i][0].x.reshape(1, -1)
        # MSELoss is documented as criterion(input, target): prediction first.
        loss = criterion(out, target)

        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()

    # Sum (not mean) of the per-sample losses over the epoch.
    print(epoch_loss)


38902347.695373535
36283846.99133301
34044505.380371094
32113777.197265625
30434422.378173828
28929526.017578125
27564996.71484375
26336232.771972656
25229010.177734375
24217284.365234375
23295184.458984375
22448109.55078125
21662522.75390625
20944823.512695312
20315854.52734375
19799935.904296875
19381208.607910156
18903001.575683594
18412164.77783203
18112621.92626953
19119536.341796875
20108726.66845703
20182546.497070312
20061253.306152344
19764143.69580078
19386045.11376953
19041985.525878906
18799719.36376953
18444769.14013672
18053726.045898438
17684221.60986328
17381299.950195312
17109179.846679688
16814313.466796875
16538619.699707031
16285750.161132812
16118160.215332031
15928487.703125
15712604.963378906
15504022.148925781
15306247.554199219
15135506.352539062
14985343.233886719
14820935.336914062
14666248.395019531
14521105.134765625
14384724.334960938
14257646.829589844
14138338.965332031
14025142.470214844


In [28]:
# Display the `.x` values of training sample `i`, reshaped into rows of 69, for
# comparison with the prediction in `out`.
# NOTE: `i` is whatever index the training loop left behind, so this cell only
# works after that loop has run.
train_dataset[i][0].x.reshape(-1, 69)

tensor([[3.1470e+03, 3.7510e+03, 3.7200e+02, 3.1000e+02, 4.4900e+02, 5.8200e+02,
         2.2180e+03, 4.3140e+03, 4.5500e+02, 5.3000e+02, 1.2800e+02, 7.1600e+02,
         1.5800e+03, 9.7000e+01, 3.1660e+03, 7.9800e+02, 9.2600e+02, 7.5500e+02,
         1.6200e+02, 2.0400e+02, 1.2580e+03, 2.0800e+02, 8.6300e+02, 6.2000e+01,
         1.0470e+03, 1.9400e+02, 1.2650e+03, 5.3000e+01, 4.1500e+02, 1.5690e+03,
         2.6500e+03, 3.2700e+02, 8.0400e+02, 7.1000e+01, 5.6000e+02, 2.3400e+02,
         1.3500e+02, 8.2500e+02, 5.4000e+01, 4.2400e+02, 3.3000e+01, 3.6000e+01,
         1.0200e+02, 8.4300e+02, 2.1900e+02, 6.9000e+01, 4.3800e+02, 2.2000e+01,
         2.0000e+01, 1.1000e+02, 1.0500e+02, 1.2900e+02, 6.3000e+01, 2.2000e+01,
         3.1600e+02, 2.2200e+02, 6.8100e+02, 1.2200e+02, 1.1600e+02, 7.7000e+01,
         4.9000e+01, 1.1800e+02, 2.9000e+01, 6.6000e+01, 2.7400e+02, 1.4000e+01,
         1.4000e+01, 2.0000e+00, 3.0000e+00]])

In [24]:
# Display the current value of `out`.
# NOTE: `out` is leftover state from whichever cell assigned it last (the
# training loop or the manual decoder call), so the result depends on run order.
out

tensor([[9.3937e+02, 1.1260e+03, 1.2480e+02, 1.5172e+02, 2.0407e+02, 2.6598e+02,
         7.6112e+02, 1.1229e+03, 1.6712e+02, 2.2028e+02, 6.0180e+01, 3.4024e+02,
         7.0483e+02, 5.0670e+01, 8.8920e+02, 3.7179e+02, 4.1900e+02, 3.3010e+02,
         7.0750e+01, 8.5250e+01, 5.1337e+02, 1.1423e+02, 4.0493e+02, 2.7798e+01,
         5.4362e+02, 8.6623e+01, 5.1305e+02, 3.1851e+01, 1.8452e+02, 6.3607e+02,
         9.2732e+02, 1.6285e+02, 3.7495e+02, 2.7667e+01, 2.3447e+02, 1.0955e+02,
         4.5538e+01, 2.8532e+02, 2.6618e+01, 2.0892e+02, 1.8396e+01, 1.6028e+01,
         4.1424e+01, 3.4911e+02, 1.1317e+02, 4.2365e+01, 1.7412e+02, 1.2909e+01,
         7.3460e+00, 5.8989e+01, 5.6845e+01, 5.3196e+01, 2.1942e+01, 1.3973e+01,
         1.4238e+02, 8.6082e+01, 2.4314e+02, 5.1823e+01, 3.9697e+01, 4.4273e+01,
         2.4008e+01, 4.6208e+01, 1.4531e+01, 3.3986e+01, 1.1814e+02, 4.5745e+00,
         8.7140e+00, 1.8612e+00, 5.4499e-01]], grad_fn=<AddmmBackward>)

In [10]:
# Print the shape of the first value returned by the encoder call.
print(cell_state_fused.shape)

torch.Size([1, 552])


In [15]:
# Manually run the decoder for one step: the input is the last graph snapshot's
# `.x` reshaped to (1, 69), together with the encoder's two outputs, each given
# an extra leading dimension via unsqueeze(0).
# NOTE(review): this overwrites `out` and reads the current `data_graph`, both of
# which the training loop also assigns — the result depends on cell run order.
# NOTE(review): the hidden-then-cell argument order is assumed to match
# Decoder.forward — confirm.
out, (hidden_state, cell_state) = decoder(data_graph[-1].x.reshape(1, 69), hidden_state_fused.unsqueeze(0), cell_state_fused.unsqueeze(0))