# Zona de importaciones

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


# Graph transformer layer

In [None]:
class GraphTransformerLayer(nn.Module):
    """Post-norm Transformer encoder layer whose self-attention can be restricted by a mask.

    One forward pass applies multi-head self-attention over the N nodes, then
    dropout + residual + LayerNorm, then a two-layer ReLU feed-forward network,
    then dropout + residual + LayerNorm.

    Args:
        d_model: Feature size of every node embedding.
        n_heads: Number of attention heads (``nn.MultiheadAttention`` requires
            it to divide ``d_model``).
        d_ff: Hidden size of the feed-forward network.
        dropout: Dropout probability used inside attention and on both
            residual branches.
    """

    def __init__(self, d_model, n_heads, d_ff=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=False,  # this module is fed (L, B, D) tensors
        )
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, adj_mask=None):
        """Run one attention + feed-forward step over the nodes.

        Args:
            x: Tensor of shape (B, N, d_model).
            adj_mask: Optional attention mask of shape (N, N), shared by the
                whole batch (fixed structural graph). As an additive float mask
                it holds 0 where node i may attend to node j and a large negative
                value (e.g. -1e9) where it may not. Rows that block every column
                with -inf (or an all-True boolean row) can produce NaN after the
                softmax, so prefer a finite large negative value.

        Returns:
            Tensor of shape (B, N, d_model).
        """
        # The attention module was built with batch_first=False: go to (N, B, D).
        x_t = x.transpose(0, 1)

        # The attention weights are never used; need_weights=False avoids
        # materializing and head-averaging them on every call.
        attn_output, _ = self.self_attn(
            x_t, x_t, x_t, attn_mask=adj_mask, need_weights=False
        )
        x_t = self.norm1(x_t + self.dropout(attn_output))

        # Position-wise feed-forward network with residual connection.
        ff = self.linear2(F.relu(self.linear1(x_t)))
        x_t = self.norm2(x_t + self.dropout(ff))

        # Back to (B, N, D).
        return x_t.transpose(0, 1)


# Encoder de los nodos

In [None]:
class NodeEncoder(nn.Module):
    """Embed every node at one time step as the sum of three linear projections.

    The dynamic node features, the static node features and the frame
    embedding (one per sample, shared by all of its nodes) are each projected
    to ``d_model`` and then added together.
    """

    def __init__(self, d_dyn, d_static, d_img, d_model):
        super().__init__()
        self.lin_x = nn.Linear(d_dyn, d_model)
        self.lin_e = nn.Linear(d_static, d_model)
        self.lin_I = nn.Linear(d_img, d_model)

    def forward(self, X_t, E, I_t):
        """Return the initial node embeddings H0 of shape (B, N, d_model).

        Args:
            X_t: (B, N, d_dyn) dynamic features at time t.
            E: (B, N, d_static) static node features (anatomical, type).
            I_t: (B, d_img) frame embedding at time t.
        """
        h_x = self.lin_x(X_t)  # (B, N, d_model)
        h_e = self.lin_e(E)    # (B, N, d_model)
        # Project the frame embedding once per sample and let broadcasting
        # spread it over the N nodes, instead of running lin_I on N copies.
        h_I = self.lin_I(I_t).unsqueeze(1)  # (B, 1, d_model)
        return h_x + h_e + h_I


# Graph transformer recurrente

In [None]:
class RecurrentGraphTransformer(nn.Module):
    """Recurrent graph transformer that predicts next-step node values.

    At every step t the nodes are encoded from (X_t, E, I_t) and the node
    states of step t-1 are added. The result passes through ``n_layers``
    GraphTransformerLayer blocks, whose attention is optionally restricted to
    the graph's edges. Finally an MLP decoder maps every node state to
    ``d_out`` values used as the prediction of X_{t+1}.
    """

    def __init__(
        self,
        d_dyn,        # dim of the dynamic features X
        d_static,     # dim of the static features E
        d_img,        # dim of the frame embeddings I
        d_model=256,
        n_heads=4,
        n_layers=3,
        d_ff=512,
        dropout=0.1,
        d_out=1,      # outputs per node (1 = a single activation logit)
    ):
        super().__init__()

        self.d_model = d_model
        # Node encoder: X + E + I -> d_model
        self.encoder = NodeEncoder(d_dyn, d_static, d_img, d_model)

        # Stack of graph transformer layers
        self.layers = nn.ModuleList([
            GraphTransformerLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        # Per-node prediction head
        self.decoder = nn.Sequential(
            nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, d_out)
        )

    @staticmethod
    def _build_attn_mask(adj, device):
        """Turn a 0/1 adjacency matrix (N, N) into an additive float attention mask.

        Edges get 0 and non-edges -1e9. A node without any edge would get a row
        that is -1e9 everywhere, and softmax turns such a row into *uniform*
        attention over every node, the opposite of "no neighbours". Those
        isolated nodes are therefore allowed to attend to themselves only.
        Rows with at least one edge are kept exactly as given.
        """
        allowed = adj.to(device) != 0
        isolated = ~allowed.any(dim=-1)  # (N,)
        eye = torch.eye(allowed.shape[-1], dtype=torch.bool, device=device)
        allowed = allowed | (eye & isolated.unsqueeze(-1))
        return (~allowed).float() * -1e9

    def forward(self, X, E, I, adj=None):
        """Unroll the model over time and predict X_{t+1} for t = 0..T-2.

        Args:
            X: (B, T, N, d_dyn) trajectory of dynamic features.
            E: (B, N, d_static) static node features.
            I: (B, T, d_img) frame embeddings per time step.
            adj: Optional (N, N) 0/1 adjacency matrix. With None, every node
                may attend to every node.

        Returns:
            Tensor (B, T-1, N, d_out) of raw decoder outputs. When T < 2 an
            empty (B, 0, N, d_out) tensor is returned.
        """
        B, T, N, _ = X.shape

        attn_mask = None if adj is None else self._build_attn_mask(adj, X.device)

        preds = []
        H_prev = None  # no node states exist before the first step
        for t in range(T - 1):
            H_t = self.encoder(X[:, t], E, I[:, t])  # (B, N, d_model)
            if H_prev is not None:
                H_t = H_t + H_prev  # recurrence: add the previous step's states

            for layer in self.layers:
                H_t = layer(H_t, adj_mask=attn_mask)

            preds.append(self.decoder(H_t))  # (B, N, d_out)
            H_prev = H_t

        if not preds:
            # T < 2: nothing to predict; keep the documented output shape.
            return X.new_zeros(B, 0, N, self.decoder[-1].out_features)
        return torch.stack(preds, dim=1)  # (B, T-1, N, d_out)


# Función de entrenamiento

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

def train_recurrent_graph_transformer(model, X, E, I, adj, num_epochs=50, lr=1e-3,
    batch_size=None, loss_fn=None, target_feature_idx=0, verbose=True):
    """
    Train a RecurrentGraphTransformer to predict X_{t+1} from X_t, E and I_t.

    Parameters
    ----------
    model : nn.Module
        RecurrentGraphTransformer instance (any module called as
        ``model(X, E, I, adj=adj)`` returning (b, T-1, N, d_out) works).
    X : torch.Tensor
        Tensor of shape (B, T, N, d_dyn) with the dynamic feature history.
    E : torch.Tensor
        Tensor of shape (B, N, d_static) with the static features of each node.
    I : torch.Tensor
        Tensor of shape (B, T, d_img) with frame embeddings per time step.
    adj : torch.Tensor or None
        Tensor of shape (N, N) with the 0/1 adjacency matrix, or None for
        unrestricted attention.
    num_epochs : int
        Number of training epochs.
    lr : float
        Learning rate of the Adam optimizer.
    batch_size : int or None
        Batch size (must be positive). If None, all B sequences form one batch.
    loss_fn : callable or None
        Loss function; if None, BCEWithLogitsLoss is used. The reported epoch
        loss assumes it returns a mean over elements, as the default does.
    target_feature_idx : int
        Index of the X feature used as the target for X_{t+1}, e.g. the
        channel holding the binary activation when d_dyn > 1.
    verbose : bool
        If True, print the loss of every epoch.

    Returns
    -------
    losses : list[float]
        Mean loss of each epoch. Batches are weighted by their number of
        target elements so a smaller final batch is not over-weighted.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive.
    """
    device = next(model.parameters()).device
    X = X.to(device)
    E = E.to(device)
    I = I.to(device)
    if adj is not None:
        adj = adj.to(device)

    B = X.shape[0]

    if batch_size is None:
        batch_size = B
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    if loss_fn is None:
        loss_fn = nn.BCEWithLogitsLoss()

    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Slice (not int index) keeps the trailing channel axis: (b, T-1, N, 1).
    target_slice = slice(target_feature_idx, target_feature_idx + 1)

    losses = []

    for epoch in range(num_epochs):
        model.train()
        weighted_loss = 0.0
        num_elements = 0

        perm = torch.randperm(B, device=device)

        for start in range(0, B, batch_size):
            idx = perm[start:start + batch_size]
            X_b = X[idx]      # (b, T, N, d_dyn)
            E_b = E[idx]      # (b, N, d_static)
            I_b = I[idx]      # (b, T, d_img)

            optimizer.zero_grad()

            # Predictions of the next step for t = 0..T-2
            preds = model(X_b, E_b, I_b, adj=adj)  # (b, T-1, N, d_out)
            target = X_b[:, 1:, :, target_slice]   # (b, T-1, N, 1)
            loss = loss_fn(preds, target)
            loss.backward()
            optimizer.step()

            weighted_loss += loss.item() * target.numel()
            num_elements += target.numel()

        epoch_loss = weighted_loss / max(1, num_elements)
        losses.append(epoch_loss)
        if verbose:
            print(f"Epoch {epoch+1}/{num_epochs} - loss: {epoch_loss:.4f}")
    return losses


In [None]:
import torch
import torch.nn as nn

def evaluate_recurrent_graph_transformer(model, X, E, I, adj, batch_size=None, loss_fn=None,
    target_feature_idx=0, threshold=0.5, verbose=True):
    """
    Evalúa un RecurrentGraphTransformer en un conjunto de validación/test.
    """
    device = next(model.parameters()).device

    X = X.to(device)
    E = E.to(device)
    I = I.to(device)
    adj = adj.to(device)

    B, T, N, d_dyn = X.shape
    if batch_size is None:
        batch_size = B
    if loss_fn is None:
        loss_fn = nn.BCEWithLogitsLoss()

    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_elements = 0
    num_batches = 0
    with torch.no_grad():
        for start in range(0, B, batch_size):
            end = min(start + batch_size, B)
            idx = slice(start, end)

            X_b = X[idx]
            E_b = E[idx]
            I_b = I[idx]
            preds = model(X_b, E_b, I_b, adj=adj)
            target = X_b[:, 1:, :, target_feature_idx:target_feature_idx+1]
            loss = loss_fn(preds, target)
            total_loss += loss.item()
            logits = preds
            probs = torch.sigmoid(logits)
            pred_bin = (probs >= threshold).float()
            correct = (pred_bin == target).sum().item()
            elements = target.numel()
            total_correct += correct
            total_elements += elements
            num_batches += 1

    avg_loss = total_loss / max(1, num_batches)
    accuracy = total_correct / max(1, total_elements)

    if verbose:
        print(f"[Eval] loss: {avg_loss:.4f}, accuracy: {accuracy:.4f}")

    return {"loss": avg_loss, "accuracy": accuracy}


# Cómo debería usarse

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Input sizes come from the last axis of each training tensor; the remaining
# hyperparameters are fixed for this example.
model = RecurrentGraphTransformer(
    d_dyn=X_train.shape[-1],
    d_static=E_train.shape[-1],
    d_img=I_train.shape[-1],
    d_model=256,
    n_heads=4,
    n_layers=3,
    d_ff=512,
    dropout=0.1,
    d_out=1,
)
model = model.to(device)

# Settings shared by training and evaluation: batches of 8 sequences and
# feature channel 0 as the prediction target.
common_kwargs = {"batch_size": 8, "target_feature_idx": 0}

# Training on the training split.
train_losses = train_recurrent_graph_transformer(
    model, X_train, E_train, I_train, adj,
    num_epochs=50, lr=1e-3, **common_kwargs,
)

# Evaluation on the validation split with a 0.5 decision threshold.
metrics_val = evaluate_recurrent_graph_transformer(
    model, X_val, E_val, I_val, adj,
    threshold=0.5, **common_kwargs,
)

print("Métricas validación:", metrics_val)
