In [None]:
!pip install torch-geometric


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader, Subset
from torch_geometric.nn import GATConv, TransformerConv, GCNConv
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import matplotlib.pyplot as plt
from torch_geometric.data import Data, Batch
from sklearn.model_selection import KFold
import copy

In [None]:
# Single seed for the notebook: torch (weight init, dropout, DataLoader
# shuffling) and NumPy's legacy global RNG.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Target device for models and tensors: the GPU when PyTorch can see one, else the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
def load_edges(edge_file, node_names, device=None):
    """Load graph edges from a CSV file as (edge_index, edge_weight) tensors.

    Two layouts are accepted:

    * Adjacency matrix: the first CSV column holds row labels, and both the
      row labels and the column headers contain every name in ``node_names``.
      Each entry ``mat[i, j]`` that is neither (close to) zero nor NaN becomes
      a directed edge ``i -> j`` weighted by that entry.
    * Edge list: columns ``source``, ``target`` and ``weight``, where
      ``source``/``target`` are names from ``node_names``.

    Args:
        edge_file: path of the CSV file. The file may be read twice, so pass a
            path rather than a consumable buffer.
        node_names: ordered node names; the name at position ``i`` becomes
            node index ``i``.
        device: device for the returned tensors; defaults to the notebook's
            global ``DEVICE``.

    Returns:
        ``(edge_index, edge_weight)``: a ``(2, E)`` long tensor and an ``(E,)``
        float tensor.

    Raises:
        ValueError: if the file matches neither layout.
        KeyError: if an edge list names a node that is not in ``node_names``.
    """
    if device is None:
        device = DEVICE
    node_names = list(node_names)

    df = pd.read_csv(edge_file, index_col=0)
    # Adjacency layout requires BOTH rows and columns to cover every node;
    # otherwise df.loc below would fail with an unhelpful KeyError.
    if set(df.columns) >= set(node_names) and set(df.index) >= set(node_names):
        mat = df.loc[node_names, node_names].values.astype(float)
        # Empty cells (NaN) mean "no coefficient". np.isclose(NaN, 0) is False,
        # so without masking them they would become NaN-weighted edges.
        has_edge = ~np.isnan(mat) & ~np.isclose(mat, 0.0)
        src, tgt = np.nonzero(has_edge)
        edge_index = torch.tensor(np.vstack((src, tgt)), dtype=torch.long)
        edge_weight = torch.tensor(mat[src, tgt], dtype=torch.float)
        return edge_index.to(device), edge_weight.to(device)

    # Edge-list layout: re-read without using the first column as the index.
    df2 = pd.read_csv(edge_file)
    if not {"source", "target", "weight"} <= set(df2.columns):
        raise ValueError("Formato de edges não reconhecido. Forneça adjacency matrix (csv) ou edge list com colunas source,target,weight.")

    name_to_idx = {n: i for i, n in enumerate(node_names)}
    src = [name_to_idx[s] for s in df2["source"].values]
    tgt = [name_to_idx[t] for t in df2["target"].values]
    edge_index = torch.tensor([src, tgt], dtype=torch.long)
    edge_weight = torch.tensor(df2["weight"].values.astype(float), dtype=torch.float)
    return edge_index.to(device), edge_weight.to(device)


In [None]:
class MultiPacienteTemporalDataset(Dataset):
    """Sliding-window forecasting samples drawn from several per-patient series.

    Each value of ``data_dict`` is expected to be a tensor indexed by time
    along its first axis (the notebook passes ``[T, N_nodes]`` tensors). For
    every patient, each run of ``seq_len`` consecutive rows is paired with the
    ``horizon`` rows that immediately follow it. Windows never span two
    patients. Samples are ordered by patient (dict order), then by start row.
    """

    def __init__(self, data_dict, seq_len=5, horizon=1):
        window = seq_len + horizon
        self.samples = [
            (series[start:start + seq_len], series[start + seq_len:start + window])
            for series in data_dict.values()
            for start in range(len(series) - window + 1)
        ]
        self.seq_len = seq_len
        self.horizon = horizon

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Slices are stored as-is and only cast to float32 when accessed.
        inputs, target = self.samples[idx]
        return inputs.float(), target.float()

In [None]:
def train_epoch(model, train_loader, criterion, optimizer):
    """Run one optimisation pass over ``train_loader`` and return the mean batch loss.

    Each ``(X, y)`` batch is expected as ``[batch, seq_len, num_nodes]`` /
    ``[batch, horizon, num_nodes]``. Only the last horizon step of ``y`` is
    used as the target.
    """
    model.train()
    batch_losses = []
    for inputs, targets in train_loader:
        # Add a trailing feature axis: [batch, seq_len, num_nodes, 1].
        inputs = inputs.unsqueeze(-1).to(DEVICE)
        # Keep only the final horizon step: [batch, num_nodes, 1].
        targets = targets[:, -1, :].unsqueeze(-1).to(DEVICE)

        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    return sum(batch_losses) / len(train_loader)

def validate_epoch(model, val_loader, criterion):
    """Return the mean per-batch loss on ``val_loader`` without updating weights.

    Inputs get a trailing feature axis of size 1. The target is the last
    horizon step of ``y``, reshaped to ``[batch, num_nodes, 1]``.
    """
    model.eval()
    running = 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            prediction = model(inputs.unsqueeze(-1).to(DEVICE))
            running += criterion(prediction, targets[:, -1, :].unsqueeze(-1).to(DEVICE)).item()
    return running / len(val_loader)

def train_model(model, train_loader, val_loader, num_epochs=100, patience=10, printepochs=True, lr=0.001):
    """Train ``model`` with Adam + MSE loss and early stopping on validation loss.

    Args:
        model: module to train; it is updated in place.
        train_loader: loader consumed by ``train_epoch``.
        val_loader: loader consumed by ``validate_epoch``.
        num_epochs: maximum number of epochs.
        patience: stop after this many consecutive epochs without a strict
            improvement in validation loss.
        printepochs: if True, print both losses every 10 epochs.
        lr: Adam learning rate (previously hard-coded to 0.001).

    Returns:
        ``(model, train_losses, val_losses)``. If any epoch improved on the
        initial ``inf``, ``model`` is restored to the weights of the epoch with
        the lowest validation loss.
    """
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_losses, val_losses = [], []
    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None

    for epoch in range(num_epochs):
        train_loss = train_epoch(model, train_loader, criterion, optimizer)
        val_loss = validate_epoch(model, val_loader, criterion)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        if printepochs and epoch % 10 == 0:
            print(f'Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # Deep copy: state_dict() returns references to the live tensors.
            best_model_state = copy.deepcopy(model.state_dict())
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f'Early stopping at epoch {epoch}')
            break

    # The question is "was a checkpoint taken?", not "is the dict non-empty?".
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model, train_losses, val_losses

def train_with_cross_validation(model_class, dataset, node_names, scalers=None,
                                n_splits=5, test_size=0.2,
                                num_epochs=100, patience=10, batch_size=32,
                                seed=42):
    """Cross-validate a model while keeping a fixed held-out test set.

    A random ``test_size`` fraction of the samples (at least one) is set aside
    as the test set. K-fold cross-validation runs on the remaining samples;
    each fold trains a fresh deep copy of ``model_class`` with early stopping.
    The fold model with the lowest best validation loss is then evaluated on
    the test set via ``evaluate_model``.

    Args:
        model_class: an already-constructed model *instance* used as a
            template. It is deep-copied for every fold and never trained itself.
        dataset: indexable dataset yielding ``(X, y)`` pairs.
        node_names: node names, forwarded to ``evaluate_model``.
        scalers: optional per-node scalers, forwarded to ``evaluate_model``.
        n_splits: requested number of folds. It is reduced to the number of
            train/validation samples if larger; below 2 the function falls
            back to a single holdout run.
        test_size: fraction of samples reserved for the test set.
        num_epochs: forwarded to ``train_model``.
        patience: forwarded to ``train_model``.
        batch_size: DataLoader batch size.
        seed: seed for the test-split shuffle and for ``KFold``.

    Returns:
        ``(fold_results, best_model, test_loader)``.

    Raises:
        ValueError: if ``dataset`` has fewer than 2 samples, which leaves no
            room for both a test sample and a training sample.

    NOTE(review): with overlapping sliding windows (as built by
    MultiPacienteTemporalDataset), near-identical windows of the same patient
    can land in both training and test sets, so test scores are likely
    optimistic. Consider splitting by patient instead.
    """
    n_total = len(dataset)
    if n_total < 2:
        raise ValueError(f"Dataset precisa de pelo menos 2 amostras (teste + treino); recebeu {n_total}.")

    # Fixed test split (at least one sample).
    n_test = max(1, int(test_size * n_total))
    n_train_val = n_total - n_test

    # A private RNG gives the same permutation as np.random.seed(seed) +
    # np.random.shuffle, without resetting the global NumPy RNG as a side effect.
    indices = list(range(n_total))
    np.random.RandomState(seed).shuffle(indices)

    test_idx = indices[:n_test]
    train_val_idx = indices[n_test:]

    test_subset = Subset(dataset, test_idx)
    test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False)

    # KFold cannot have more folds than samples.
    n_splits = min(n_splits, n_train_val)

    if n_splits < 2:
        print("Aviso: dados insuficientes para validação cruzada. Usando holdout simples.")
        train_subset = Subset(dataset, train_val_idx)
        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
        # No separate validation data: early stopping monitors the training
        # subset itself (optimistic, but keeps the pipeline running).
        val_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=False)

        model = copy.deepcopy(model_class).to(DEVICE)
        best_model, train_losses, val_losses = train_model(
            model, train_loader, val_loader, num_epochs, patience
        )

        fold_results = [{
            'fold': 0,
            'train_losses': train_losses,
            'val_losses': val_losses,
            'best_val_loss': min(val_losses) if val_losses else np.inf,
            'final_val_loss': val_losses[-1] if val_losses else np.inf
        }]
    else:
        # K-fold over the train/validation portion only.
        kfold = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
        criterion = nn.MSELoss()

        fold_results = []
        cv_models = []

        for fold, (train_pos, val_pos) in enumerate(kfold.split(train_val_idx)):
            print(f"\nFold {fold + 1}/{n_splits}")
            print("-" * 40)

            # KFold yields positions within train_val_idx; map them back to dataset indices.
            train_subset = Subset(dataset, [train_val_idx[i] for i in train_pos])
            val_subset = Subset(dataset, [train_val_idx[i] for i in val_pos])

            train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
            val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

            model = copy.deepcopy(model_class).to(DEVICE)
            trained_model, train_losses, val_losses = train_model(
                model, train_loader, val_loader, num_epochs, patience
            )

            # Validation loss of the restored (best-epoch) weights.
            val_loss = validate_epoch(trained_model, val_loader, criterion)

            fold_results.append({
                'fold': fold,
                'train_losses': train_losses,
                'val_losses': val_losses,
                'best_val_loss': min(val_losses) if val_losses else np.inf,
                'final_val_loss': val_loss
            })
            cv_models.append(trained_model)

        best_fold = int(np.argmin([res['best_val_loss'] for res in fold_results]))
        best_model = cv_models[best_fold]

    print(f"\n{'='*50}")
    print(f"Tamanho do dataset: {n_total} amostras")
    print(f"Amostras de treino/validação: {n_train_val}")
    print(f"Amostras de teste: {n_test}")
    print(f"Número de folds: {n_splits}")
    print(f"{'='*50}")

    evaluate_model(best_model, test_loader, node_names, scalers)

    return fold_results, best_model, test_loader

def evaluate_model(model, test_loader, node_names, scalers=None):
    """Score ``model`` on ``test_loader``, print the metrics and return them.

    Predictions and targets use the last horizon step only. When ``scalers``
    is given (a mapping from node name to a fitted scaler with
    ``inverse_transform``), both are mapped back to the original scale before
    scoring.

    Args:
        model: trained model returning ``[batch, num_nodes, 1]``.
        test_loader: loader of ``(X, y)`` test batches.
        node_names: node names in column order.
        scalers: optional per-node scalers for inverse-transforming values.

    Returns:
        dict with:
        ``'mae'``, ``'rmse'``, ``'r2'``: overall metrics (scikit-learn's default
        uniform average over nodes);
        ``'per_node'``: ``{name: {'mae': ..., 'r2': ...}}``;
        ``'y_true'``, ``'y_pred'``: arrays of shape ``[n_samples, num_nodes]``.
    """
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for X, y in test_loader:
            X = X.unsqueeze(-1).to(DEVICE)
            y_true.append(y[:, -1, :].cpu().numpy())
            y_pred.append(model(X).cpu().numpy().squeeze(-1))

    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)

    # Optionally undo the per-node normalisation before computing metrics.
    if scalers is not None:
        for i, name in enumerate(node_names):
            scaler = scalers[name]
            y_true[:, i] = scaler.inverse_transform(y_true[:, i].reshape(-1, 1)).squeeze()
            y_pred[:, i] = scaler.inverse_transform(y_pred[:, i].reshape(-1, 1)).squeeze()

    mae = mean_absolute_error(y_true, y_pred)
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    r2 = r2_score(y_true, y_pred)

    print(f"\nTEST - MAE: {mae:.4f} - RMSE: {rmse:.4f} - R²: {r2:.4f}")
    per_node = {}
    for i, name in enumerate(node_names):
        mae_i = mean_absolute_error(y_true[:, i], y_pred[:, i])
        r2_i = r2_score(y_true[:, i], y_pred[:, i])
        print(f"{name:15s} MAE: {mae_i:.4f}   R²: {r2_i:.4f}")
        per_node[name] = {'mae': mae_i, 'r2': r2_i}

    return {
        'mae': mae,
        'rmse': rmse,
        'r2': r2,
        'per_node': per_node,
        'y_true': y_true,
        'y_pred': y_pred,
    }



In [None]:


def normalize_patients(df, sintomas):
    """Min-max scale each symptom column to [0, 1] across all patients at once.

    One ``MinMaxScaler`` is fitted per column on the whole ``df`` (every
    patient and every time point), so the scaling is global.

    NOTE(review): because the scalers see every row, including rows later
    used for testing, min/max statistics leak from test into training.

    Args:
        df: DataFrame containing the columns listed in ``sintomas``.
        sintomas: names of the columns to scale.

    Returns:
        ``(df_norm, scalers)``: a scaled copy of ``df`` (other columns are
        untouched) and a dict mapping each column name to its fitted scaler,
        for inverse transforms later.
    """
    df_norm = df.copy()
    scalers = {}
    for column in sintomas:
        fitted = MinMaxScaler()
        df_norm[column] = fitted.fit_transform(df[[column]].values)
        scalers[column] = fitted
    return df_norm, scalers

In [None]:
def impute_timeseries(df_symptoms, fill_value=0.0):
    """Impute missing values in one patient's symptom time series.

    Steps:
    1. Drop the first column whose name contains ``"Time"`` or ``"mid"``
       (if any), treating it as a time stamp rather than a symptom.
    2. Turn the string markers ``"NA"``, ``"NaN"``, ``"nan"`` and ``""`` into
       NaN and cast everything to float.
    3. Linearly interpolate each column along the rows (time). With
       ``limit_direction='both'``, leading/trailing gaps take the nearest
       observed value of that column.
    4. Anything still missing is filled with the column mean. A column with no
       observations at all has a NaN mean, so it is finally filled with
       ``fill_value`` and the result never contains NaN.

    Args:
        df_symptoms: DataFrame of symptom columns, one row per time step.
        fill_value: value for columns that have no observation at all.

    Returns:
        ``np.ndarray`` of shape ``(T, n_symptoms)`` with float values.
    """
    # str(c) so non-string column labels do not raise TypeError.
    time_cols = [c for c in df_symptoms.columns if "Time" in str(c) or "mid" in str(c)]
    Xvals = df_symptoms.drop(columns=[time_cols[0]]) if time_cols else df_symptoms

    Xvals = Xvals.replace(["NA", "NaN", "nan", ""], np.nan).astype(float)

    Xvals = Xvals.interpolate(method='linear', limit_direction='both', axis=0)

    # The column mean is NaN for all-missing columns; fall back to fill_value
    # so NaN never reaches the model.
    Xvals = Xvals.fillna(Xvals.mean()).fillna(fill_value)

    return Xvals.values

In [None]:
# Exploratory load of the ESM data with a simple imputation.
# NOTE(review): the next cell re-reads the file and reassigns `df`, so this
# cell's result is superseded.
df = pd.read_csv('DataS1.TXT', sep=',')

SINTOMAS = ['cheerful', 'pleasant_event', 'worry', 'fearful', 'sad', 'relaxed']

column_map = {
        'subjno': 'subject',
        'dayno': 'day',
        'beepno': 'beep',
        'informat04': 'therapy',  # 0=control, 1=therapy
        'st_period': 'period',    # 0=baseline, 1=post-baseline
        'opgewkt_': 'cheerful',
        'onplplez': 'pleasant_event',
        'pieker': 'worry',
        'angstig_': 'fearful',
        'somber__': 'sad',
        'ontspann': 'relaxed',
        'neur': 'neuroticism'
}

df = df.rename(columns=column_map)

# Interpolate within each subject only. A plain column-wise interpolate would
# draw lines between the last beep of one subject and the first of the next.
df[SINTOMAS] = df.groupby('subject')[SINTOMAS].transform(lambda s: s.interpolate(method="linear"))

# Remaining gaps (e.g. leading NaNs per subject) get the symptom's global mean.
# Restricting .mean() to SINTOMAS avoids errors on non-numeric columns
# (pandas >= 2.0 raises TypeError for them).
df[SINTOMAS] = df[SINTOMAS].fillna(df[SINTOMAS].mean())

In [None]:
df = pd.read_csv('DataS1.TXT', sep=',')

SINTOMAS = ['cheerful', 'pleasant_event', 'worry', 'fearful', 'sad', 'relaxed']

column_map = {
        'subjno': 'subject',
        'dayno': 'day',
        'beepno': 'beep',
        'informat04': 'therapy',  # 0=control, 1=therapy
        'st_period': 'period',    # 0=baseline, 1=post-baseline
        'opgewkt_': 'cheerful',
        'onplplez': 'pleasant_event',
        'pieker': 'worry',
        'angstig_': 'fearful',
        'somber__': 'sad',
        'ontspann': 'relaxed',
        'neur': 'neuroticism'
}

df = df.rename(columns=column_map)

# Treat each subject's post-baseline period as a separate pseudo-patient
# (subject id * 1000), so sliding windows never straddle the two periods.
# NOTE(review): assumes no real subject id equals another id * 1000.
df.loc[df['period'] == 1, 'subject'] = df.loc[df['period'] == 1, 'subject'] * 1000

# Interpolate within each (pseudo-)patient only. A plain column-wise
# interpolate would bridge the gap between consecutive subjects.
df[SINTOMAS] = df.groupby('subject')[SINTOMAS].transform(lambda s: s.interpolate(method="linear"))

df_norm, scalers = normalize_patients(df, SINTOMAS)

# One [T, n_symptoms] tensor per patient. impute_timeseries fills any gaps
# left after interpolation (e.g. leading NaNs).
# NOTE(review): rows are assumed to be in time order within each subject.
data_dict = {}
for pid, df_sub in df_norm.groupby("subject"):
    seq = torch.tensor(impute_timeseries(df_sub[SINTOMAS]))
    data_dict[pid] = seq

edge_index, edge_weight = load_edges('coef_matrix.csv', SINTOMAS)

# Define lookback window
lookback_window = 5
horizon = 1

dataset = MultiPacienteTemporalDataset(data_dict, seq_len=lookback_window, horizon=horizon)

In [None]:
class SpatialTemporalCGNModel(nn.Module):
    """Stacked GCN over the symptom graph at every time step, then a per-node GRU.

    Input ``X`` is ``[batch, seq_len, num_nodes, 1]``; the output is
    ``[batch, num_nodes, 1]`` (one prediction per node).

    Args:
        num_nodes: number of graph nodes.
        hidden_channels: GCN and GRU hidden width.
        num_gcn_layers: number of stacked ``GCNConv`` layers (at least one
            layer is always built).
        edge_index: ``[2, E]`` edge index of the single graph shared by all
            samples and time steps.
        edge_weight: optional per-edge weights passed to ``GCNConv``.
            NOTE(review): GCNConv normalises with ``deg ** -0.5``; negative
            weights can drive a degree below zero and yield NaN. Confirm the
            coefficient matrix is suitable.
        device: stored on the instance; the module itself never reads it.
        dropout: dropout applied between GCN layers.
    """

    def __init__(self, num_nodes, hidden_channels, num_gcn_layers, edge_index,
                 edge_weight=None, device='cpu', dropout=0.1):
        super().__init__()
        self.num_nodes = num_nodes
        self.edge_index = edge_index
        self.edge_weight = edge_weight
        self.device = device
        self.num_gcn_layers = num_gcn_layers

        # First layer maps 1 -> hidden_channels; the rest keep hidden_channels.
        self.gcn_layers = nn.ModuleList()
        self.gcn_layers.append(GCNConv(1, hidden_channels))
        for _ in range(num_gcn_layers - 1):
            self.gcn_layers.append(GCNConv(hidden_channels, hidden_channels))

        self.dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_channels, hidden_channels, batch_first=True)
        self.fc = nn.Linear(hidden_channels, 1)

    def _batched_graph(self, batch_size, num_nodes):
        """Tile the shared graph ``batch_size`` times, offsetting node ids per sample.

        Gives the same edge_index / edge weights that ``Batch.from_data_list``
        builds for ``batch_size`` copies of the graph, without creating
        ``Data`` objects.
        """
        edge_index = torch.cat(
            [self.edge_index + b * num_nodes for b in range(batch_size)], dim=1
        )
        if self.edge_weight is None:
            return edge_index, None
        return edge_index, torch.cat([self.edge_weight] * batch_size, dim=0)

    def forward(self, X):
        batch_size, seq_len, num_nodes, _ = X.shape
        # The graph is identical for every sample and time step: tile it once.
        edge_index, edge_weight = self._batched_graph(batch_size, num_nodes)

        gcn_outs = []
        for t in range(seq_len):
            # Samples stacked along the node axis: [batch * num_nodes, 1].
            h = X[:, t, :, :].reshape(batch_size * num_nodes, 1)
            for i, gcn_layer in enumerate(self.gcn_layers):
                h = torch.relu(gcn_layer(h, edge_index, edge_weight))
                # Dropout between layers only, not after the last one.
                if i < len(self.gcn_layers) - 1:
                    h = self.dropout(h)
            gcn_outs.append(h.view(batch_size, num_nodes, -1))

        gcn_seq = torch.stack(gcn_outs, dim=1)  # [B, T, N, H]
        # Run the GRU independently for each node: [B * N, T, H].
        gru_input = gcn_seq.permute(0, 2, 1, 3).reshape(batch_size * num_nodes, seq_len, -1)
        gru_out, _ = self.gru(gru_input)
        out = self.fc(gru_out[:, -1, :])
        return out.view(batch_size, num_nodes, 1)

In [None]:
# Build the GCN + GRU model and evaluate it with k-fold CV plus a fixed test split.
# `model` is only a template: train_with_cross_validation deep-copies it per fold.
model = SpatialTemporalCGNModel(
    num_nodes=len(SINTOMAS), # one graph node per symptom
    hidden_channels=8, # Example hidden size
    edge_index=edge_index,
    edge_weight=edge_weight,
    num_gcn_layers=8,
    device= DEVICE
).to(DEVICE)

n_samples = len(dataset)
# At most 5 folds and at most one fold per 2 samples. train_with_cross_validation
# clamps again against the train/val size and falls back to holdout below 2 folds.
n_splits = min(5, n_samples // 2)  # automatic adjustment based on dataset size

# NOTE: rebinds fold_results / best_model / test_loader; the later model cells
# overwrite these same names.
fold_results, best_model, test_loader = train_with_cross_validation(
    model_class=model,
    dataset=dataset,
    node_names=SINTOMAS,
    scalers=scalers,
    n_splits=n_splits,  # adjusted automatically above
    test_size=0.2,
    num_epochs=100,
    patience=10,
    batch_size=32
)

In [None]:
class TemporalAttention(nn.Module):
    """Scaled dot-product self-attention over the time axis, applied per node.

    Input ``x`` is ``[B, T, N, F]`` with ``F == embed_dim``. Each node's
    length-T sequence attends over itself, and the attended sequence is
    averaged over time, giving ``[B, N, F]``.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        _, _, _, feat_dim = x.shape  # also enforces a 4-D input
        per_node = x.transpose(1, 2)  # [B, N, T, F]: time becomes the sequence axis

        q = self.query(per_node)
        k = self.key(per_node)
        v = self.value(per_node)

        # [B, N, T, T] weights, scaled by sqrt(F) as in standard attention.
        weights = self.softmax(q @ k.transpose(-1, -2) / np.sqrt(feat_dim))
        return (weights @ v).mean(dim=2)


In [None]:
import torch.nn.functional as Fnn


class AttentiveTemporalGATGNN(nn.Module):
    """GAT layers at every time step, combined with temporal attention and a GRU.

    Input ``X_seq`` is ``[B, T, N, in_channels]``; the output is
    ``[B, N, out_channels]``. The per-node temporal-attention summary is added
    to a graph-level GRU state (the GRU reads all nodes' features concatenated
    per time step), then a linear head produces the predictions.

    Args:
        num_nodes: number of graph nodes N.
        in_channels: input features per node.
        hidden_channels: width of the GAT layers, attention and GRU.
        out_channels: predicted values per node.
        num_gat_layers: number of stacked ``GATConv`` layers (at least one
            layer is always built).
        edge_index: ``[2, E]`` edge index of the graph shared by all samples.
        edge_attr: optional edge features, either ``[E]`` (one scalar per
            edge) or ``[E, D]``.
        dropout: GAT attention dropout and dropout between layers.
    """

    def __init__(self, num_nodes, in_channels=1, hidden_channels=32, out_channels=1,
                 num_gat_layers=2,
                 edge_index=None, edge_attr=None, dropout=0.1):
        super().__init__()
        self.edge_index = edge_index
        self.edge_attr = edge_attr
        self.num_nodes = num_nodes
        self.hidden_channels = hidden_channels
        self.num_gat_layers = num_gat_layers

        # forward() accepts a 1-D edge_attr (it unsqueezes it), so the edge
        # feature size must handle that case too instead of calling .size(1).
        if edge_attr is None:
            edge_dim = None
        else:
            edge_dim = edge_attr.size(1) if edge_attr.dim() > 1 else 1

        self.gat_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        for layer in range(max(1, num_gat_layers)):
            self.gat_layers.append(
                GATConv(
                    in_channels=in_channels if layer == 0 else hidden_channels,
                    out_channels=hidden_channels,
                    heads=2,
                    concat=False,
                    edge_dim=edge_dim,
                    dropout=dropout
                )
            )
            self.norm_layers.append(nn.LayerNorm(hidden_channels))

        self.dropout = nn.Dropout(dropout)
        self.temporal_attn = TemporalAttention(hidden_channels)
        self.gru = nn.GRU(num_nodes * hidden_channels, hidden_channels, batch_first=True)
        self.linear = nn.Linear(hidden_channels, out_channels)

    def _batched_graph(self, batch_size):
        """Tile the shared graph once per sample, offsetting node ids by ``num_nodes``."""
        edge_index = torch.cat(
            [self.edge_index + i * self.num_nodes for i in range(batch_size)], dim=1
        )
        if self.edge_attr is None:
            return edge_index, None
        edge_attr = torch.cat([self.edge_attr] * batch_size, dim=0)
        if edge_attr.dim() == 1:
            edge_attr = edge_attr.unsqueeze(-1)  # GATConv expects [E, edge_dim]
        return edge_index, edge_attr

    def forward(self, X_seq):
        B, T, N, n_feat = X_seq.size()
        # The graph does not change over time: build the batched version once.
        batched_edge_index, batched_edge_attr = self._batched_graph(B)
        last = len(self.gat_layers) - 1

        gat_outs = []
        for t in range(T):
            x_t = X_seq[:, t, :, :].reshape(B * N, n_feat)
            for i, (gat_layer, norm_layer) in enumerate(zip(self.gat_layers, self.norm_layers)):
                x_t = gat_layer(x_t, batched_edge_index, edge_attr=batched_edge_attr)
                if i < last:
                    # Norm + ELU + dropout between layers; the last layer gets ELU only.
                    x_t = self.dropout(Fnn.elu(norm_layer(x_t)))
                else:
                    x_t = Fnn.elu(x_t)
            gat_outs.append(x_t.view(B, N, -1))

        H = torch.stack(gat_outs, dim=1)   # [B, T, N, hidden]
        attn_out = self.temporal_attn(H)   # [B, N, hidden]

        # Graph-level GRU over all nodes' features concatenated per time step.
        H_seq = H.view(B, T, self.num_nodes * self.hidden_channels)
        _, h_n = self.gru(H_seq)           # h_n: [1, B, hidden]
        h_n = h_n.squeeze(0)

        combined = attn_out + h_n.unsqueeze(1).repeat(1, N, 1)
        return self.linear(combined)       # [B, N, out_channels]

In [None]:
# GATConv takes edge features as [E, edge_dim]; give each edge its scalar weight as one feature.
edge_attr = edge_weight.unsqueeze(-1)

# `model` is only a template: train_with_cross_validation deep-copies it per fold.
model = AttentiveTemporalGATGNN(
    num_nodes=len(SINTOMAS), # one graph node per symptom
    in_channels=1, # Feature dimension is 1
    hidden_channels=8, # Example hidden size
    out_channels=1, # Predicting one value per node
    num_gat_layers=4,
    edge_index=edge_index,
    edge_attr=edge_attr
).to(DEVICE)

n_samples = len(dataset)
# At most 5 folds and at most one fold per 2 samples; clamped again inside
# train_with_cross_validation.
n_splits = min(5, n_samples // 2)  # automatic adjustment based on dataset size

# NOTE: overwrites the fold_results / best_model / test_loader of the previous model cell.
fold_results, best_model, test_loader = train_with_cross_validation(
    model_class=model,
    dataset=dataset,
    node_names=SINTOMAS,
    scalers=scalers,
    n_splits=n_splits,  # adjusted automatically above
    test_size=0.2,
    num_epochs=100,
    patience=10,
    batch_size=32
)

In [None]:
import torch.nn.functional as Fnn

class AttentiveTemporalTransformerGNN(nn.Module):
    """TransformerConv layers at every time step, plus temporal attention and a GRU.

    Same pipeline as the GAT variant, but with ``TransformerConv`` message
    passing. Input ``X_seq`` is ``[B, T, N, in_channels]``; the output is
    ``[B, N, out_channels]``.

    Args:
        num_nodes: number of graph nodes N.
        in_channels: input features per node.
        hidden_channels: width of the graph layers, attention and GRU.
        out_channels: predicted values per node.
        num_gnn_layers: number of stacked ``TransformerConv`` layers (at least
            one layer is always built).
        edge_index: ``[2, E]`` edge index of the graph shared by all samples.
        edge_attr: optional edge features, ``[E]`` or ``[E, D]``. When omitted,
            the layers are built without edge features.
        dropout: attention dropout inside each layer and dropout between layers.
    """

    def __init__(self, num_nodes, in_channels=1, hidden_channels=32, out_channels=1,
                 num_gnn_layers=2,
                 edge_index=None, edge_attr=None, dropout=0.1):
        super().__init__()
        self.edge_index = edge_index
        self.edge_attr = edge_attr
        self.num_nodes = num_nodes
        self.hidden_channels = hidden_channels
        self.num_gnn_layers = num_gnn_layers

        # Derive the edge feature size instead of hard-coding 1. A layer built
        # with an edge_dim expects edge_attr at call time, so edge_attr=None
        # must produce layers without edge features.
        if edge_attr is None:
            edge_dim = None
        else:
            edge_dim = edge_attr.size(1) if edge_attr.dim() > 1 else 1

        self.gnn_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()  # LayerNorm between layers for stability
        for layer in range(max(1, num_gnn_layers)):
            self.gnn_layers.append(
                TransformerConv(
                    in_channels=in_channels if layer == 0 else hidden_channels,
                    out_channels=hidden_channels,
                    heads=2,
                    concat=False,
                    dropout=dropout,
                    edge_dim=edge_dim,
                    beta=True
                )
            )
            self.norm_layers.append(nn.LayerNorm(hidden_channels))

        self.dropout = nn.Dropout(dropout)
        self.temporal_attn = TemporalAttention(hidden_channels)
        self.gru = nn.GRU(num_nodes * hidden_channels, hidden_channels, batch_first=True)
        self.linear = nn.Linear(hidden_channels, out_channels)

    def forward(self, X_seq):
        B, T, N, n_feat = X_seq.size()
        # The graph does not change over time: build the batched version once.
        batched_edge_index, batched_edge_attr = self._create_batched_graph(B)
        last = len(self.gnn_layers) - 1

        gnn_outs = []
        for t in range(T):
            h = X_seq[:, t, :, :].reshape(B * N, n_feat)
            for i, (gnn_layer, norm_layer) in enumerate(zip(self.gnn_layers, self.norm_layers)):
                h = gnn_layer(h, batched_edge_index, edge_attr=batched_edge_attr)
                if i < last:
                    # Norm + ELU + dropout between layers; the last layer gets ELU only.
                    h = self.dropout(Fnn.elu(norm_layer(h)))
                else:
                    h = Fnn.elu(h)
            gnn_outs.append(h.view(B, N, -1))

        H = torch.stack(gnn_outs, dim=1)   # [B, T, N, hidden]
        attn_out = self.temporal_attn(H)   # [B, N, hidden]

        # Graph-level GRU over all nodes' features concatenated per time step.
        H_seq = H.view(B, T, self.num_nodes * self.hidden_channels)
        _, h_n = self.gru(H_seq)           # h_n: [1, B, hidden]
        h_n = h_n.squeeze(0)

        combined = attn_out + h_n.unsqueeze(1).repeat(1, N, 1)
        return self.linear(combined)       # [B, N, out_channels]

    def _create_batched_graph(self, batch_size):
        """Tile the shared graph once per sample, offsetting node ids by ``num_nodes``.

        Returns ``(edge_index, edge_attr)``; ``edge_attr`` is ``None`` when the
        model has no edge features, and 1-D features are reshaped to ``[E, 1]``.
        """
        edge_index = torch.cat(
            [self.edge_index + i * self.num_nodes for i in range(batch_size)], dim=1
        )
        if self.edge_attr is None:
            return edge_index, None
        edge_attr = torch.cat([self.edge_attr] * batch_size, dim=0)
        if edge_attr.dim() == 1:
            edge_attr = edge_attr.unsqueeze(-1)
        return edge_index, edge_attr

In [None]:
# TransformerConv takes edge features as [E, edge_dim]; give each edge its scalar weight as one feature.
edge_attr = edge_weight.unsqueeze(-1)

# `model` is only a template: train_with_cross_validation deep-copies it per fold.
model = AttentiveTemporalTransformerGNN(
    num_nodes=len(SINTOMAS), # one graph node per symptom
    in_channels=1, # Feature dimension is 1
    hidden_channels=8, # Example hidden size
    out_channels=1, # Predicting one value per node
    edge_index=edge_index,
    edge_attr=edge_attr,
    num_gnn_layers=4,
).to(DEVICE)

n_samples = len(dataset)
# At most 5 folds and at most one fold per 2 samples; clamped again inside
# train_with_cross_validation.
n_splits = min(5, n_samples // 2)  # automatic adjustment based on dataset size

# NOTE: overwrites the fold_results / best_model / test_loader of the previous model cells.
fold_results, best_model, test_loader = train_with_cross_validation(
    model_class=model,
    dataset=dataset,
    node_names=SINTOMAS,
    scalers=scalers,
    n_splits=n_splits,  # adjusted automatically above
    test_size=0.2,
    num_epochs=100,
    patience=10,
    batch_size=32
)