# T-GCN model

We start with the baseline T-GCN model, which combines a GCN with a GRU, and run it to show how prediction works.


Main idea of T-GCN:
On one hand, the graph convolutional network is used to capture the topological structure of the urban road network to obtain the spatial dependence. On the other hand, the gated recurrent unit is used to capture the dynamic variation of traffic information on the roads to obtain the temporal dependence.
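
To make the combination concrete, here is a minimal NumPy sketch of a single T-GCN step. It assumes the standard GCN normalization $\hat{A} = D^{-1/2}(A + I)D^{-1/2}$ and a GRU whose gates read graph-convolved inputs. Biases are omitted, and the names `normalized_adjacency`, `tgcn_cell`, and the toy graph are illustrative assumptions, not taken from any reference implementation.

```python
import numpy as np

def normalized_adjacency(A):
    """Return D^{-1/2} (A + I) D^{-1/2} for a symmetric adjacency matrix A."""
    A_tilde = A + np.eye(A.shape[0])          # add self-loops
    d_inv_sqrt = 1.0 / np.sqrt(A_tilde.sum(axis=1))
    return A_tilde * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def tgcn_cell(A_hat, x, h, W_u, W_r, W_c):
    """One recurrent step: spatial mixing via A_hat, temporal gating via a GRU."""
    xh = np.concatenate([x, h], axis=1)       # [N, F + H]
    u = sigmoid(A_hat @ xh @ W_u)             # update gate
    r = sigmoid(A_hat @ xh @ W_r)             # reset gate
    xrh = np.concatenate([x, r * h], axis=1)
    c = np.tanh(A_hat @ xrh @ W_c)            # candidate hidden state
    return u * h + (1.0 - u) * c              # new hidden state, [N, H]

# Toy example (hypothetical data): 3-node path graph, 2 features, 4 hidden units
rng = np.random.default_rng(0)
A = np.array([[0.0, 1.0, 0.0],
              [1.0, 0.0, 1.0],
              [0.0, 1.0, 0.0]])
A_hat = normalized_adjacency(A)
n_feat, n_hidden = 2, 4
W_u, W_r, W_c = (rng.normal(scale=0.1, size=(n_feat + n_hidden, n_hidden))
                 for _ in range(3))

h = np.zeros((3, n_hidden))
for t in range(5):                            # 5 time steps of random node features
    x_t = rng.normal(size=(3, n_feat))
    h = tgcn_cell(A_hat, x_t, h, W_u, W_r, W_c)
print(h.shape)
```

The hidden state `h` carries the temporal memory from step to step, while every gate sees its neighbors' signals through `A_hat`.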

### Metrics for evaluating predictions
R2 (the coefficient of determination) and Var (the explained variance score) measure how well the prediction result represents the actual data. For both, values closer to 1 indicate a better fit.
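
As a rough illustration, the two scores can be computed as below. This is a sketch using the usual definitions, $R^2 = 1 - \sum (y - \hat{y})^2 / \sum (y - \bar{y})^2$ and $\mathrm{Var} = 1 - \mathrm{Var}(y - \hat{y}) / \mathrm{Var}(y)$. The function names and the toy arrays are assumptions made for the example.

```python
import numpy as np

def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - residual sum of squares / total sum of squares."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

def explained_variance(y_true, y_pred):
    """Explained variance score: 1 - Var(residuals) / Var(targets)."""
    return 1.0 - np.var(y_true - y_pred) / np.var(y_true)

# Hypothetical predictions that are off by a constant bias of +1
y_true = np.array([3.0, 5.0, 7.0, 9.0])
y_pred = y_true + 1.0

r2 = r2_score(y_true, y_pred)
ev = explained_variance(y_true, y_pred)
print(r2, ev)
```

With a constant bias the explained variance stays at 1 while R2 drops to about 0.8, which shows that the two scores are related but not interchangeable.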


### Plotting the dataset

In [None]:
import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric.utils import to_networkx
from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader  # Dataset loader used below

# Load the dataset
loader = ChickenpoxDatasetLoader()
dataset = loader.get_dataset()

# Extract target values (number of chickenpox cases) over time
targets = [snapshot.y.numpy() for snapshot in dataset]

# Plot the target values over time
plt.figure(figsize=(10, 6))
for i in range(targets[0].shape[0]):  # Iterate over each node
    plt.plot([t[i] for t in targets], label=f'Node {i}')
plt.title("Chickenpox Cases Over Time")
plt.xlabel("Time Step")
plt.ylabel("Number of Cases")
plt.legend()
plt.show()

# Visualize the graph structure for a specific snapshot (e.g., the first snapshot)
snapshot = dataset[0]

# Convert the PyTorch Geometric graph to a NetworkX graph
G = to_networkx(snapshot, to_undirected=True)

# Plot the graph
plt.figure(figsize=(8, 6))
pos = nx.spring_layout(G)  # Layout for visualization
nx.draw(G, pos, with_labels=True, node_color='lightblue', edge_color='gray', node_size=500, font_size=10)
plt.title("Graph Structure at Time Step 0")
plt.show()

In [None]:

import torch
import torch.nn.functional as F
from tqdm import tqdm  # Progress bar for the training loop

from torch_geometric_temporal.nn.recurrent import TGCN  # Temporal Graph Convolutional Network
from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader  # Dataset loader for chickenpox data
from torch_geometric_temporal.signal import temporal_signal_split  # Utility to split temporal signals


# Load the Chickenpox dataset
loader = ChickenpoxDatasetLoader()
dataset = loader.get_dataset()  # Get the dataset

# Split the dataset into training and testing sets
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)

# Define a recurrent graph convolutional model built on the T-GCN layer
class RecurrentGCN(torch.nn.Module):
    '''
    First we apply the GCN,
    then we apply the GRU.
    '''

    def __init__(self, node_features):
        super(RecurrentGCN, self).__init__()
        # TGCN layer: Temporal Graph Convolutional Network
        self.recurrent = TGCN(node_features, 32)  # Input features, 32 hidden units
        # Linear layer to map hidden state to output
        self.linear = torch.nn.Linear(32, 1)  # 32 hidden units to 1 output unit

    def forward(self, x, edge_index, edge_weight, prev_hidden_state):
        # Apply the TGCN layer
        h = self.recurrent(x, edge_index, edge_weight, prev_hidden_state)
        # Apply ReLU activation
        y = F.relu(h)
        # Apply the linear layer to get the final output
        y = self.linear(y)
        return y, h  # Return output and hidden state

# Initialize the model with 4 input node features
model = RecurrentGCN(node_features=4)

# Define the optimizer (Adam) with a learning rate of 0.01
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Set the model to training mode
model.train()

# Training loop
for epoch in tqdm(range(150)):  # 150 epochs
    cost = 0  # Initialize cost for the epoch
    hidden_state = None  # Initialize hidden state
    for time, snapshot in enumerate(train_dataset):  # Iterate over each snapshot in the training dataset
        # Forward pass: compute predictions and hidden state
        y_hat, hidden_state = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, hidden_state)
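        # Compute Mean Squared Error (MSE) loss; squeeze y_hat from [N, 1] to [N]
        # so it does not broadcast against snapshot.y into an [N, N] matrix
        cost = cost + torch.mean((y_hat.squeeze(-1) - snapshot.y) ** 2)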
    # Average the cost over all snapshots
    cost = cost / (time + 1)
    # Backpropagation
    cost.backward()
    # Update model parameters
    optimizer.step()
    # Clear gradients
    optimizer.zero_grad()

# Set the model to evaluation mode
model.eval()

# Testing loop
cost = 0  # Initialize cost for testing
hidden_state = None  # Initialize hidden state
for time, snapshot in enumerate(test_dataset):  # Iterate over each snapshot in the testing dataset
    # Forward pass: compute predictions and hidden state
    y_hat, hidden_state = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, hidden_state)
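    # Compute Mean Squared Error (MSE) loss; squeeze y_hat from [N, 1] to [N]
    # so it does not broadcast against snapshot.y into an [N, N] matrix
    cost = cost + torch.mean((y_hat.squeeze(-1) - snapshot.y) ** 2)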

# Average the cost over all snapshots
cost = cost / (time + 1)

# Convert cost to a Python float
cost = cost.item()

# Print the Mean Squared Error (MSE) for the test dataset
print("MSE: {:.4f}".format(cost))

## DFDGCN method

### Building the model: Fourier Transform

Time-shift problem: the article states that the time-shift problem complicates the modeling of spatial dependencies in traffic data. To address this issue, it proposes the Fourier Transform, which converts traffic data into the frequency domain.

Mathematical justification: let $f(t)$ be the traffic data captured by sensors at a specific intersection. If the traffic is delayed by time $t_0$, the data $f(t - t_0)$ will be captured at the next intersection.

According to the definition of the Fourier Transform:
$$F(\omega) = \int_{-\infty}^{\infty} f(t) e^{-j\omega t} dt$$

$$F_{t_0}(\omega) = \int_{-\infty}^{\infty} f(t-t_0) e^{-j\omega t} dt$$
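
Substituting $u = t - t_0$ in the second integral gives $F_{t_0}(\omega) = e^{-j\omega t_0} F(\omega)$, so $|F_{t_0}(\omega)| = |F(\omega)|$: a delay changes only the phase, not the magnitude spectrum. The short NumPy check below illustrates this on a discrete signal. It assumes a circular shift (`np.roll`), for which the identity holds exactly. For a real delayed signal observed through a finite window it holds only approximately. The signal and the delay are made up for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
f = rng.normal(size=12)            # hypothetical 12-step traffic window
t0 = 3                             # delay in time steps
f_shifted = np.roll(f, t0)         # f(t - t0) with wrap-around

F = np.fft.rfft(f)
F_shifted = np.fft.rfft(f_shifted)

# Discrete angular frequencies of the rfft bins
omega = 2 * np.pi * np.arange(F.shape[0]) / f.shape[0]
phase = np.exp(-1j * omega * t0)

print(np.allclose(F_shifted, phase * F))          # shift = phase factor
print(np.allclose(np.abs(F_shifted), np.abs(F)))  # magnitudes match
```

This is the reason the model code further down takes `torch.abs` of the `torch.fft.rfft` output before building the frequency-domain graph.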

## Experiments

 Datasets: Experiments are conducted on four real-world datasets with tens of thousands of time steps and hundreds of sensors. The dataset statistics are presented in Table 1.

## Basic Features and Metrics

  Baselines and Metrics: Classical methods such as HI, GWNet, DCRNN, AGCRN, STGCN, MTGNN, DGCRN are chosen as baselines. Evaluation metrics include Mean Absolute Error (MAE), Root Mean Square Error (RMSE), and Mean Absolute Percentage Error (MAPE).
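
For reference, the three metrics can be sketched in NumPy as follows. The masking of near-zero targets in MAPE is an assumption (a common convention for traffic data, where a reading of zero would otherwise cause a division by zero), and the sample values are invented.

```python
import numpy as np

def mae(y_true, y_pred):
    """Mean Absolute Error."""
    return float(np.mean(np.abs(y_true - y_pred)))

def rmse(y_true, y_pred):
    """Root Mean Square Error."""
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))

def mape(y_true, y_pred, eps=1e-8):
    """Mean Absolute Percentage Error over targets with |y| > eps (as a fraction)."""
    mask = np.abs(y_true) > eps
    yt, yp = y_true[mask], y_pred[mask]
    return float(np.mean(np.abs((yt - yp) / yt)))

# Hypothetical speeds; the last target is zero and is skipped by MAPE
y_true = np.array([60.0, 50.0, 40.0, 0.0])
y_pred = np.array([57.0, 55.0, 40.0, 2.0])

print(mae(y_true, y_pred), rmse(y_true, y_pred), mape(y_true, y_pred))
```

Here MAE is 2.5, RMSE is the square root of 9.5, and MAPE is 0.05 (5%) over the three non-zero targets.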

  Experiment Settings: The datasets are divided into training, validation, and test sets in a 7:1:2 ratio. Traffic data is predicted for 12 time steps using historical data of length 12. The embedding sizes after Fourier Transform and identity are 10, and the time embeddings $T_t^W$ and $T_t^D$ are 12. The embedding size after 1D convolution is 30.
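
The split and windowing described above can be sketched as follows. This is an illustrative NumPy version, assuming a chronological (unshuffled) split and a stride of one step between windows. The helper names and the toy series are not taken from the article.

```python
import numpy as np

def chronological_split(series, ratios=(7, 1, 2)):
    """Split a [T, ...] array into train/val/test blocks in time order."""
    n = series.shape[0]
    total = sum(ratios)
    n_train = n * ratios[0] // total
    n_val = n * ratios[1] // total
    return (series[:n_train],
            series[n_train:n_train + n_val],
            series[n_train + n_val:])

def sliding_windows(series, history=12, horizon=12):
    """Build (input, target) pairs: `history` past steps -> `horizon` future steps."""
    n = series.shape[0] - history - horizon + 1
    x = np.stack([series[i:i + history] for i in range(n)])
    y = np.stack([series[i + history:i + history + horizon] for i in range(n)])
    return x, y

# Hypothetical series: 100 time steps, 1 sensor
series = np.arange(100, dtype=float).reshape(100, 1)
train, val, test = chronological_split(series)
x_train, y_train = sliding_windows(train)
print(train.shape, val.shape, test.shape, x_train.shape, y_train.shape)
```

Splitting before windowing keeps every window inside a single subset, so no target from the validation or test period leaks into training.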

  Experiment Results: DFDGCN shows better results compared to the baselines on all datasets. Ablation analysis confirms the effectiveness of the frequency graph in modeling dynamic spatial dependencies.

In [None]:
import h5py
import pickle
import folium
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F



## Variation of the GCN model

This variant of the model embeds additional properties:
- structural
- temporal (time of day, day of week)
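
The temporal properties are typically fed to the model as integer indices into embedding tables. The model code later in this notebook uses a table with 288 rows for time of day (5-minute slots) and one with 7 rows for day of week. A minimal sketch of deriving such indices from timestamps is shown below. The helper name and the Monday = 0 convention are assumptions for the example.

```python
from datetime import datetime

STEPS_PER_DAY = 288  # 24 hours * 12 five-minute slots per hour

def time_indices(ts):
    """Return (time-of-day slot in [0, 287], day-of-week in [0, 6], Monday = 0)."""
    time_of_day = (ts.hour * 60 + ts.minute) // 5
    day_of_week = ts.weekday()
    return time_of_day, day_of_week

print(time_indices(datetime(2017, 1, 1, 0, 0)))    # midnight on a Sunday
print(time_indices(datetime(2017, 1, 2, 8, 30)))   # Monday morning
print(time_indices(datetime(2017, 1, 2, 23, 55)))  # last slot of the day
```

Each index then selects one learnable row of the corresponding embedding table, giving the network a compact description of when a measurement was taken.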



# Load the data


We load the data from the BasicTS repository:
https://github.com/GestaltCogTeam/BasicTS/tree/master/datasets

In [None]:
class convt(nn.Module):
    def __init__(self):
        super(convt, self).__init__()

    def forward(self, x, w):
        x = torch.einsum('bne, ek->bnk', (x, w))
        return x.contiguous()

class nconv(nn.Module):
    def __init__(self):
        super(nconv, self).__init__()

    def forward(self, x, A, dims):
        if dims == 2:
            x = torch.einsum('ncvl,vw->ncwl', (x, A))
        elif dims == 3:
            x = torch.einsum('ncvl,nvw->ncwl', (x, A))
        else:
            raise NotImplementedError('DFDGCN not implemented for A of dimension ' + str(dims))
        return x.contiguous()

class linear(nn.Module):
    """Linear layer."""

    def __init__(self, c_in, c_out):
        super(linear, self).__init__()
        self.mlp = torch.nn.Conv2d(c_in, c_out, kernel_size=(
            1, 1), padding=(0, 0), stride=(1, 1), bias=True)

    def forward(self, x):
        return self.mlp(x)

class gcn(nn.Module):
    """Graph convolution network."""

    def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
        super(gcn, self).__init__()
        self.nconv = nconv()

        self.c_in = c_in
        c_in = (order * (support_len + 1) + 1) * self.c_in
        self.mlp = linear(c_in, c_out)
        self.dropout = dropout
        self.order = order

    def forward(self, x, support):

        out = [x]
        for a in support:
            x1 = self.nconv(x, a.to(x.device), a.dim())
            out.append(x1)

            for k in range(2, self.order + 1):
                x2 = self.nconv(x1, a.to(x1.device), a.dim())
                out.append(x2)
                x1 = x2
        h = torch.cat(out, dim=1)
        h = self.mlp(h)
        h = F.dropout(h, self.dropout, training=self.training)
        return h

def dy_mask_graph(adj, k):
    M = []
    for i in range(adj.size(0)):
        adp = adj[i]
        mask = torch.zeros( adj.size(1),adj.size(2)).to(adj.device)
        mask = mask.fill_(float("0"))
        s1, t1 = (adp + torch.rand_like(adp) * 0.01).topk(k, 1)
        mask = mask.scatter_(1, t1, s1.fill_(1))
        M.append(mask)
    mask = torch.stack(M,dim=0)
    adj = adj * mask
    return adj

def cat(x1,x2):
    M = []
    for i in range(x1.size(0)):
        x = x1[i]
        new_x = torch.cat([x,x2],dim=1)
        M.append(new_x)
    result = torch.stack(M,dim=0)
    return result


class DFDGCN(nn.Module):

    def __init__(self, num_nodes, dropout=0.3, supports=None,
                    gcn_bool=True, addaptadj=True, aptinit=None,
                    in_dim=2, out_dim=12, residual_channels=32,
                    dilation_channels=32, skip_channels=256, end_channels=512,
                    kernel_size=2, blocks=4, layers=2, a=1, seq_len=12, affine=True, fft_emb=10, identity_emb=10, hidden_emb=30, subgraph=20):
        super(DFDGCN, self).__init__()
        self.dropout = dropout
        self.blocks = blocks
        self.layers = layers
        self.gcn_bool = gcn_bool
        self.addaptadj = addaptadj
        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        self.bn = nn.ModuleList()
        self.gconv = nn.ModuleList()
        self.seq_len = seq_len
        self.a = a

        self.start_conv = nn.Conv2d(in_channels=in_dim,
                                    out_channels=residual_channels,
                                    kernel_size=(1, 1))

        self.supports = supports
        self.emb = fft_emb
        self.subgraph_size = subgraph
        self.identity_emb = identity_emb
        self.hidden_emb = hidden_emb
        self.fft_len = round(seq_len//2) + 1
        self.Ex1 = nn.Parameter(torch.randn(self.fft_len, self.emb), requires_grad=True)
        self.Wd = nn.Parameter(torch.randn(num_nodes,self.emb + self.identity_emb + self.seq_len * 2, self.hidden_emb), requires_grad=True)
        self.Wxabs = nn.Parameter(torch.randn(self.hidden_emb, self.hidden_emb), requires_grad=True)

        self.mlp = linear(residual_channels * 4,residual_channels)
        self.layersnorm = torch.nn.LayerNorm(normalized_shape=[num_nodes,self.hidden_emb], eps=1e-08,elementwise_affine=affine)
        self.convt = convt()

        self.node1 = nn.Parameter(
            torch.randn(num_nodes, self.identity_emb), requires_grad=True)
        self.drop = nn.Dropout(p=dropout)

        self.T_i_D_emb = nn.Parameter(
            torch.empty(288, self.seq_len))
        self.D_i_W_emb = nn.Parameter(
            torch.empty(7, self.seq_len))

        receptive_field = 1
        self.reset_parameter()
        self.supports_len = 0
        if not addaptadj:
            self.supports_len -= 1
        if supports is not None:
            self.supports_len += len(supports)
        if gcn_bool and addaptadj:
            if aptinit is None:
                if supports is None:
                    self.supports = []
                self.nodevec1 = nn.Parameter(
                    torch.randn(num_nodes, self.emb), requires_grad=True)
                self.nodevec2 = nn.Parameter(
                    torch.randn(self.emb, num_nodes), requires_grad=True)
                self.supports_len += 1
            else:
                if supports is None:
                    self.supports = []
                m, p, n = torch.svd(aptinit)
                initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5))
                initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t())
                self.nodevec1 = nn.Parameter(initemb1, requires_grad=True)
                self.nodevec2 = nn.Parameter(initemb2, requires_grad=True)
                self.supports_len += 1

        for b in range(blocks):
            additional_scope = kernel_size - 1
            new_dilation = 1
            for i in range(layers):
                # dilated convolutions
                self.filter_convs.append(nn.Conv2d(in_channels=residual_channels,
                                                   out_channels=dilation_channels,
                                                   kernel_size=(1, kernel_size), dilation=new_dilation))

                self.gate_convs.append(nn.Conv2d(in_channels=residual_channels,
                                                 out_channels=dilation_channels,
                                                 kernel_size=(1, kernel_size), dilation=new_dilation))

                # 1x1 convolution for residual connection
                self.residual_convs.append(nn.Conv2d(in_channels=dilation_channels,
                                                     out_channels=residual_channels,
                                                     kernel_size=(1, 1)))

                # 1x1 convolution for skip connection
                self.skip_convs.append(nn.Conv2d(in_channels=dilation_channels,
                                                 out_channels=skip_channels,
                                                 kernel_size=(1, 1)))
                self.bn.append(nn.BatchNorm2d(residual_channels))
                new_dilation *= 2
                receptive_field += additional_scope
                additional_scope *= 2
                if self.gcn_bool:
                    self.gconv.append(
                        gcn(dilation_channels, residual_channels, dropout, support_len=self.supports_len))
        self.end_conv_1 = nn.Conv2d(in_channels=skip_channels,
                                    out_channels=end_channels,
                                    kernel_size=(1, 1),
                                    bias=True)

        self.end_conv_2 = nn.Conv2d(in_channels=end_channels,
                                    out_channels=out_dim,
                                    kernel_size=(1, 1),
                                    bias=True)

        self.receptive_field = receptive_field

    def reset_parameter(self):
        nn.init.xavier_uniform_(self.T_i_D_emb)
        nn.init.xavier_uniform_(self.D_i_W_emb)


    def forward(self, history_data: torch.Tensor) -> torch.Tensor:
        """Feedforward function of DFDGCN; Based on Graph WaveNet

        Args:
            history_data (torch.Tensor): shape [B, L, N, C]

        Graphs:
            predefined graphs: two graphs; [2, N, N] : Pre-given graph structure, including in-degree and out-degree graphs

            self-adaptive graph: [N, N] : Self-Adaptively constructed graphs with two learnable parameters
                torch.mm(self.nodevec1, self.nodevec2)
                    nodevec: [N, Emb]

            dynamic frequency domain graph: [B, N, N] : Data-driven graphs constructed with frequency domain information from traffic data
                traffic_data : [B, N, L]
                frequency domain information : [B, N, L/2.round + 1] ------Embedding ------[B, N, Emb2]
                Identity embedding : learnable parameter [N, Emb3]
                Time embedding : Week and Day : [N, 7] [N, 24(hour) * 12 (60min / 5min due to sampling)] ------Embedding ------ [N, 2 * Emb4]
                Concat frequency domain information + Identity embedding + Time embedding ------Embedding , Activating, Normalization and Dropout
                Conv1d to get adjacency matrix

        Returns:
            torch.Tensor: [B, L, N, 1]
        """
        #num_feat = model_args["num_feat"]
        input = history_data.transpose(1, 3).contiguous()[:,:,:,:]

        data = history_data

        in_len = input.size(3)
        if in_len < self.receptive_field:
            x = nn.functional.pad(
                input, (self.receptive_field-in_len, 0, 0, 0))
        else:
            x = input
        x = self.start_conv(x)

        skip = 0
        if self.gcn_bool and self.addaptadj and self.supports is not None:


            gwadp = F.softmax(
                F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)

            new_supports = self.supports + [gwadp] # pretrained graph in DCRNN and self-adaptive graph in GWNet

            # Construction of dynamic frequency domain graph
            xn1 = input[:, 0, :, -self.seq_len:]

            # T_D = self.T_i_D_emb[(data[:, :, :, 1] * 288).type(torch.LongTensor)][:, -1, :, :]
            T_D = self.T_i_D_emb[(data[:, :, :, 1]).type(torch.LongTensor)][:, -1, :, :]
            D_W = self.D_i_W_emb[(data[:, :, :, 1 + 1]).type(torch.LongTensor)][:, -1, :, :]

            xn1 = torch.fft.rfft(xn1, dim=-1)
            xn1 = torch.abs(xn1)

            xn1 = torch.nn.functional.normalize(xn1, p=2.0, dim=1, eps=1e-12, out=None)
            xn1 = torch.nn.functional.normalize(xn1, p=2.0, dim=2, eps=1e-12, out=None) * self.a

            xn1 = torch.matmul(xn1, self.Ex1)
            xn1k = cat(xn1, self.node1)
            x_n1 = torch.cat([xn1k, T_D, D_W], dim=2)
            x1 = torch.bmm(x_n1.permute(1,0,2),self.Wd).permute(1,0,2)
            x1 = torch.relu(x1)
            x1k = self.layersnorm(x1)
            x1k = self.drop(x1k)
            adp = self.convt(x1k, self.Wxabs)
            adj = torch.bmm(adp, x1.permute(0, 2, 1))
            adp = torch.relu(adj)
            adp = dy_mask_graph(adp, self.subgraph_size)
            adp = F.softmax(adp, dim=2)
            new_supports = new_supports + [adp]

        # WaveNet layers
        for i in range(self.blocks * self.layers):

            # dilated convolution
            residual = x
            filter = self.filter_convs[i](residual)
            filter = torch.tanh(filter)
            gate = self.gate_convs[i](residual)
            gate = torch.sigmoid(gate)
            x = filter * gate

            # parametrized skip connection
            s = x
            s = self.skip_convs[i](s)
            try:
                skip = skip[:, :, :,  -s.size(3):]
            except:
                skip = 0
            skip = s + skip

            if self.gcn_bool and self.supports is not None:
                if self.addaptadj:
                    x = self.gconv[i](x, new_supports)

                else:
                    x = self.gconv[i](x, self.supports)
            else:
                x = self.residual_convs[i](x)
            x = x + residual[:, :, :, -x.size(3):]

            x = self.bn[i](x)

        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)
        return x

# model = DFDGCN(num_nodes=pems_bay_adj.shape[0], supports=supports, in_dim=combined_data.shape[3])

# model(combined_data).shape
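
The sparsification step in `dy_mask_graph`, followed by the row-wise softmax in `forward`, can be illustrated on a single small matrix. The NumPy sketch below is a simplified version: it works on one `[N, N]` matrix instead of a batch and omits the small random noise that `dy_mask_graph` adds to break ties. The helper names and the example matrix are assumptions.

```python
import numpy as np

def topk_mask(adj, k):
    """Keep the k largest entries in every row of adj and zero out the rest."""
    idx = np.argsort(-adj, axis=1)[:, :k]      # column indices of the top-k per row
    mask = np.zeros_like(adj)
    np.put_along_axis(mask, idx, 1.0, axis=1)
    return adj * mask

def row_softmax(adj):
    """Numerically stable softmax over each row."""
    e = np.exp(adj - adj.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

# Hypothetical similarity scores between 3 nodes
adj = np.array([[0.9, 0.1, 0.5],
                [0.2, 0.8, 0.3],
                [0.4, 0.6, 0.7]])
sparse = topk_mask(adj, k=2)
weights = row_softmax(sparse)
print(sparse)
print(weights.sum(axis=1))
```

Note that entries zeroed by the mask still receive a positive weight after the softmax, because exp(0) = 1. The mask therefore flattens weak connections rather than removing them entirely.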

In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.optim as optim

from tqdm.notebook import tqdm

# Data and parameters
combined_data = combined_data_full[:, 288:2016+288, :, :]
B, L, N, C = combined_data.shape  # [1, 2016, 325, 14]
batch_size = 16
train_ratio = 0.7
val_ratio = 0.2
test_ratio = 0.1
seq_len = 12  # Number of input time steps
pred_len = 1  # Number of time steps to predict

# Split the data into train, val, and test sets
num_samples = combined_data.shape[1]
train_size = int(num_samples * train_ratio)
val_size = int(num_samples * val_ratio)
test_size = num_samples - train_size - val_size

train_data = combined_data[:, :train_size, :, :]
val_data = combined_data[:, train_size:train_size + val_size, :, :]
test_data = combined_data[:, train_size + val_size:, :, :]


# Custom Dataset
class TrafficDataset(Dataset):
    def __init__(self, data, seq_len, pred_len):
        super().__init__()
        self.data = data.squeeze(0)  # Drop the batch dimension; shape [L, N, C]
        self.seq_len = seq_len
        self.pred_len = pred_len

    def __len__(self):
        return self.data.shape[0] - self.seq_len - self.pred_len + 1

    def __getitem__(self, idx):
        x = self.data[idx:idx + self.seq_len, :, :]  # Input sequence
        y = self.data[idx + self.seq_len:idx + self.seq_len + self.pred_len, :, 0]  # Target speed
        return x, y


# Create the DataLoaders
train_dataset = TrafficDataset(train_data, seq_len, pred_len)
val_dataset = TrafficDataset(val_data, seq_len, pred_len)
test_dataset = TrafficDataset(test_data, seq_len, pred_len)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Select the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Fall back to CPU when CUDA is unavailable
print(device)

# Build the model, loss, and optimizer
supports = [torch.tensor(pems_bay_adj.to_numpy(), dtype=torch.float32)]
model = DFDGCN(num_nodes=N, supports=supports, in_dim=C, out_dim=pred_len).to(device)  # Move the model to the selected device
criterion = nn.MSELoss()  # Loss function
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

cpu


In [None]:
# Training loop
def train_model(model, train_loader, val_loader, epochs):
    best_val_loss = float('inf')

    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0.0
        train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} - Training", leave=False)

        for x, y in train_loader_tqdm:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            output = model(x).squeeze(-1)  # Model prediction
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * x.size(0)

            # Show the current batch loss
            train_loader_tqdm.set_postfix({"Batch Loss": loss.item()})

        train_loss /= len(train_loader.dataset)

        # Validation
        model.eval()
        val_loss = 0.0
        val_loader_tqdm = tqdm(val_loader, desc=f"Epoch {epoch + 1}/{epochs} - Validation", leave=False)

        with torch.no_grad():
            for x, y in val_loader_tqdm:
                x, y = x.to(device), y.to(device)
                output = model(x).squeeze(-1)
                loss = criterion(output, y)
                val_loss += loss.item() * x.size(0)

                # Show the current batch loss
                val_loader_tqdm.set_postfix({"Batch Loss": loss.item()})

        val_loss /= len(val_loader.dataset)

        # Save the best model
        # tqdm.write(f"Epoch {epoch + 1}/{epochs}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "best_model.pth")

# Testing
def test_model(model, test_loader):
    # Load the best checkpoint onto the current device
    model.load_state_dict(torch.load("best_model.pth", map_location=device))
    model.eval()
    test_loss = 0.0
    test_loader_tqdm = tqdm(test_loader, desc="Testing")

    with torch.no_grad():
        for x, y in test_loader_tqdm:
            x, y = x.to(device), y.to(device)
            output = model(x).squeeze(-1)
            loss = criterion(output, y)
            test_loss += loss.item() * x.size(0)

            # Show the current batch loss
            test_loader_tqdm.set_postfix({"Batch Loss": loss.item()})

    test_loss /= len(test_loader.dataset)
    tqdm.write(f"Test Loss: {test_loss:.4f}")

train_model(model, train_loader, val_loader, epochs=1)
test_model(model, test_loader)

Test Loss: 7.5325


In [None]:
import torch
from torch.utils.data import DataLoader, Dataset

import torch.nn as nn
import torch.optim as optim

from tqdm.notebook import tqdm

# Data and parameters
combined_data = combined_data_full[:, 288:2016+288, :, :3]
B, L, N, C = combined_data.shape  # [1, 2016, 325, 3]
batch_size = 32
train_ratio = 0.7
val_ratio = 0.2
test_ratio = 0.1
seq_len = 12  # Number of input time steps
pred_len = 1  # Number of time steps to predict

# Split the data into train, val, and test sets
num_samples = combined_data.shape[1]
train_size = int(num_samples * train_ratio)
val_size = int(num_samples * val_ratio)
test_size = num_samples - train_size - val_size

train_data = combined_data[:, :train_size, :, :]
val_data = combined_data[:, train_size:train_size + val_size, :, :]
test_data = combined_data[:, train_size + val_size:, :, :]


# Custom Dataset
class TrafficDataset(Dataset):
    def __init__(self, data, seq_len, pred_len):
        super().__init__()
        self.data = data.squeeze(0)  # Drop the batch dimension; shape [L, N, C]
        self.seq_len = seq_len
        self.pred_len = pred_len

    def __len__(self):
        return self.data.shape[0] - self.seq_len - self.pred_len + 1

    def __getitem__(self, idx):
        x = self.data[idx:idx + self.seq_len, :, :]  # Input sequence
        y = self.data[idx + self.seq_len:idx + self.seq_len + self.pred_len, :, 0]  # Target speed
        return x, y


# Create the DataLoaders
train_dataset = TrafficDataset(train_data, seq_len, pred_len)
val_dataset = TrafficDataset(val_data, seq_len, pred_len)
test_dataset = TrafficDataset(test_data, seq_len, pred_len)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Select the device
device = torch.device("cuda")

# Build the model, loss, and optimizer
supports = [torch.tensor(pems_bay_adj.to_numpy(), dtype=torch.float32)]
model = DFDGCN(num_nodes=N, supports=supports, in_dim=C, out_dim=pred_len).to(device)
criterion = nn.MSELoss()  # Loss function
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

In [None]:
train_model(model, train_loader, val_loader, epochs=1)
test_model(model, test_loader)

Test Loss: 5.0635
