In [1]:
# basic
import os
import warnings
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
# pre processing
from sklearn import preprocessing as pre
# NN
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import MSELoss
from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool
# val and plot
from sklearn.metrics import (
    mean_absolute_error,
    mean_squared_error,
    r2_score,
    median_absolute_error,
    explained_variance_score,
    mean_absolute_percentage_error
)
from loguru import logger as log
# plot
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
from torch_geometric.data import DataLoader
from torch_geometric.data import Batch
from torch_geometric.data import Data
from torchviz import make_dot
from tqdm.notebook import trange  # opcional, pra barra de progresso

from torch_geometric.nn import SAGEConv, LEConv, GlobalAttention
from torch_geometric.data import Batch

warnings.filterwarnings("ignore")

In [2]:
# set seed
SEED = 1345
def seed_everything(seed):
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
seed_everything(SEED)
plt.style.use('ggplot')
pd.set_option('display.float_format', '{:.16f}'.format)

In [3]:
def load_dataset(fpath):
    # Load the StaticGraphTemporalSignal object from the file
    with open(fpath, 'rb') as f:
        loaded_temporal_signal = pickle.load(f)
    return loaded_temporal_signal

In [6]:
c = 51

In [7]:
train_dataset = load_dataset(f'dataset_train_{c}_time_sts.pkl')
test_dataset = load_dataset(f'dataset_test_{c}_time_sts.pkl')

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, GlobalAttention, JumpingKnowledge
# opcional: GraphNorm funciona bem quando há batchs grandes
# from torch_geometric.nn import GraphNorm

class GraphSAGEForecast(nn.Module):
    def __init__(self, in_channels, hidden_channels, horizon=100, out_dims=3,
                 num_layers=3, dropout=0.2):
        super().__init__()
        self.horizon = horizon
        self.out_dims = out_dims
        self.dropout_p = dropout

        # ----- Encoder GNN -----
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()

        self.convs.append(SAGEConv(in_channels, hidden_channels))
        self.norms.append(nn.LayerNorm(hidden_channels))  # GraphNorm se preferir

        for _ in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
            self.norms.append(nn.LayerNorm(hidden_channels))

        # Jumping Knowledge para combinar representações de todas as camadas
        self.jk = JumpingKnowledge(mode='cat')
        self.proj_jk = nn.Linear(hidden_channels * num_layers, hidden_channels)

        # Gating mais expressivo para a atenção global
        self.pool = GlobalAttention(
            gate_nn=nn.Sequential(
                nn.Linear(hidden_channels, hidden_channels // 2),
                nn.GELU(),
                nn.Linear(hidden_channels // 2, 1)
            )
        )

        self.dropout = nn.Dropout(dropout)

        # ----- Decoder p/ horizonte -----
        self.decoder = nn.Sequential(
            nn.Linear(hidden_channels, 2 * hidden_channels),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(2 * hidden_channels, self.out_dims * self.horizon)
        )

        # init um pouco melhor p/ lineares
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x, edge_index, batch=None):
        if batch is None:
            # se vier um único grafo, crie um batch “fake”
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        xs = []
        h = x
        for conv, norm in zip(self.convs, self.norms):
            h_res = h
            h = conv(h, edge_index)
            h = norm(h)
            h = F.gelu(h)
            h = self.dropout(h)
            # pequeno residual ajuda estabilidade (não mudar dims!)
            h = h + h_res if h.shape == h_res.shape else h
            xs.append(h)

        h = self.jk(xs)          # [num_nodes_total, hidden * num_layers]
        h = F.gelu(self.proj_jk(h))
        g = self.pool(h, batch)   # [B, hidden]

        out = self.decoder(g)     # [B, 3*horizon]
        out = out.view(g.size(0), self.out_dims, self.horizon)  # [B, 3, H]
        return out

In [9]:
device = 'cuda:2'
device

'cuda:2'

In [10]:
model = GraphSAGEForecast(
    in_channels=100,
    hidden_channels=128,
    horizon=100,      # 100 passos à frente
    out_dims=1,       # 3 dimensões no espaço de fase
    num_layers=3,
    dropout=0.2
).to(device)

In [11]:
def trajectory_loss(y_pred, y_true, lam_smooth=0.1):
    # garante batch dimension quando vier 2D
    if y_pred.dim() == 2:
        y_pred = y_pred.unsqueeze(0)   # [1, 3, H]
    if y_true.dim() == 2:
        y_true = y_true.unsqueeze(0)   # [1, 3, H]

    # MSE básico
    base = F.mse_loss(y_pred, y_true)

    # termo de suavidade temporal (diferenças finitas)
    dp = y_pred[:, :, 1:] - y_pred[:, :, :-1]
    dt = y_true[:, :, 1:] - y_true[:, :, :-1]
    smooth = F.mse_loss(dp, dt)

    return base + lam_smooth * smooth

In [12]:
# model = GraphTemporalModel(
#     node_features=100,  # conforme seu x=[12, 100]
#     horizon=100,
#     num_targets=3
# ).to(device)

# optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
max_grad_norm = 1.0

In [14]:
train_dataset.targets = [
    t.detach().cpu().numpy() if isinstance(t, torch.Tensor) else t
    for t in train_dataset.targets
]

In [15]:
training_loss = []
# for epoch in tqdm(range(100)):
#     model.train()
#     cost = 0
#     h, c = None, None
#     for time, snapshot in enumerate(train_dataset): # faz o treino em cada bath temporal
#         snapshot.to(device)
#         y_hat = model(snapshot.x, snapshot.edge_index)
#         #print(f"y_hat: {y_hat.shape} snapshot: {snapshot.y.shape}")
#         cost = cost + torch.mean((y_hat-snapshot.y)**2)
        
nb_epocas = 120
for epoch in range(nb_epocas):
    model.train()
    train_loss = 0
    for snapshot in train_dataset:
        snapshot = snapshot.to(device)
        optimizer.zero_grad()
        y_hat = model(snapshot.x, snapshot.edge_index)
        #loss = torch.mean((y_hat - snapshot.y)**2)
        loss = trajectory_loss(y_hat, snapshot.y, lam_smooth=0.1)
        #loss = F.mse_loss(y_hat, snapshot.y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        training_loss.append(train_loss)
    if (epoch+1) % 30 == 0:
        print(f"Epoch {epoch}, Train Loss: {train_loss/train_dataset.snapshot_count:.4f}")

Epoch 29, Train Loss: 0.7500
Epoch 59, Train Loss: 0.7418
Epoch 89, Train Loss: 0.7354
Epoch 119, Train Loss: 0.7334


In [16]:
def test(model, test_dataset, device):
    model.eval()  # Modo de avaliação (desativa dropout, batchnorm, etc.)
    test_loss = 0
    predictions = []
    ground_truths = []

    with torch.no_grad():  # Desativa cálculo de gradientes (economiza memória)
        for time, snapshot in enumerate(test_dataset):
            snapshot = snapshot.to(device)
            y_hat = model(snapshot.x, snapshot.edge_index)  # Forward pass
            loss = torch.mean((y_hat - snapshot.y)**2)  # MSE
            test_loss += loss.item()

            # Guarda previsões e valores reais para métricas adicionais
            predictions.append(y_hat.cpu().numpy())
            ground_truths.append(snapshot.y.cpu().numpy())

    # Calcula a média do erro sobre todos os snapshots
    test_loss /= test_dataset.snapshot_count
    print(f"\nTest Loss (MSE): {test_loss:.4f}")

    # Converte listas para arrays numpy (opcional, útil para análise)
    predictions = np.stack(predictions)
    ground_truths = np.stack(ground_truths)

    return test_loss, predictions, ground_truths

In [18]:
test_dataset.targets = [
    t.detach().cpu().numpy() if isinstance(t, torch.Tensor) else t
    for t in test_dataset.targets
]

In [19]:
test_loss, y_pred, y_true = test(model, test_dataset, device)


Test Loss (MSE): 0.3396


In [20]:
y_pred_src = y_pred.copy()
y_true_src = y_true.copy()

In [23]:
y_pred_src.shape, y_true_src.shape

((300, 1, 1, 100), (300, 51, 100))

In [21]:
idx = 0

ytrues = []
ypreds = []

for i in range(3):
    print(f"shape idx: {idx}")
    y_true = y_true_src[idx,2,:].tolist()
    y_pred = y_pred_src[idx,2,:].tolist()
    
    ytrues.extend(y_true)
    ypreds.extend(y_pred)
    
    idx += 100
    if idx == 100:
        idx -= 1
    

shape idx: 0


IndexError: index 2 is out of bounds for axis 1 with size 1