# **Entrenamiento del Modelo GNN**

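Pipeline de entrenamiento del modelo GNN de ordenamiento de nodos: el modelo aprende, por imitación, a reproducir el orden en que DSATUR selecciona los nodos de un grafo.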

In [None]:
%run 00_setup.ipynb

In [None]:
%run 03_gnn.ipynb

## **1. FUNCIONES DE ENTRENAMIENTO**

In [None]:
def dsatur_ordering(G):
    """Return the order in which a DSATUR-style greedy colouring visits nodes.

    At every step the uncoloured node with the highest saturation (number of
    distinct colours already used by its coloured neighbours) is selected,
    with the node degree as tie-breaker. The selected node receives the
    smallest non-negative colour not used by its neighbours.

    Args:
        G: graph object exposing ``nodes()``, ``neighbors(v)`` and
            ``degree(v)``.

    Returns:
        List with every node of ``G`` exactly once, in selection order.
        Remaining ties are resolved by the iteration order of the internal
        set of uncoloured nodes.
    """
    assigned = {}
    visit_order = []
    uncolored = set(G.nodes())

    def colors_around(node):
        # Distinct colours already present in the neighbourhood of `node`.
        return {assigned[nb] for nb in G.neighbors(node) if nb in assigned}

    while uncolored:
        saturation = {node: len(colors_around(node)) for node in uncolored}

        chosen = max(
            uncolored,
            key=lambda node: (saturation[node], G.degree(node)),
        )
        visit_order.append(chosen)

        taken = colors_around(chosen)
        # The smallest free colour can never exceed len(taken), so this
        # bounded scan always finds it.
        assigned[chosen] = next(
            c for c in range(len(taken) + 1) if c not in taken
        )
        uncolored.remove(chosen)

    return visit_order


def _graph_to_pyg_data(G):
    """Convert a graph into a ``Data`` object placed on ``device``.

    The graph is relabelled to consecutive integers ``0..n-1`` (sorted label
    order) when its labels are not already that set. The single node feature
    is the standardised degree, and every edge is stored in both directions.

    Args:
        G: input graph (expected to be a NetworkX graph, since it is passed
            to ``nx.convert_node_labels_to_integers``).

    Returns:
        Tuple ``(data, G)`` where ``data`` holds ``x`` (one column per node),
        ``edge_index`` (two rows) and ``num_nodes``, and ``G`` is the graph
        with integer labels that was actually encoded.
    """
    num_nodes = G.number_of_nodes()
    labels_are_consecutive = set(G.nodes()) == set(range(num_nodes))
    if not labels_are_consecutive:
        G = nx.convert_node_labels_to_integers(G, ordering='sorted')

    # Node feature: degree standardised to zero mean / unit variance; the
    # epsilon keeps the division finite when all degrees are equal.
    degrees = np.array(
        [G.degree(node) for node in range(num_nodes)], dtype=np.float32
    )
    standardized = (degrees - degrees.mean()) / (degrees.std() + 1e-8)
    x = torch.from_numpy(standardized).view(-1, 1).float()

    edge_array = np.array(list(G.edges()), dtype=np.int64)
    if edge_array.size:
        forward = torch.from_numpy(edge_array.T).long()
        # Append the reversed copy so each edge appears in both directions.
        edge_index = torch.cat([forward, forward.flip(0)], dim=1)
    else:
        # Edgeless graph: keep the expected (2, 0) shape.
        edge_index = torch.zeros((2, 0), dtype=torch.long)

    return Data(x=x, edge_index=edge_index, num_nodes=num_nodes).to(device), G


def _dsatur_target_scores(G_int):
    """Build per-node regression targets from the DSATUR visiting order.

    The node visited first gets score 1.0 and the node visited last gets 0.0
    (linearly spaced in between), so sorting nodes by descending score
    reproduces the order returned by ``dsatur_ordering``.

    Args:
        G_int: graph whose nodes are labelled ``0..n-1``; the labels are used
            directly as positions in the returned tensor.

    Returns:
        1-D ``float32`` tensor of length ``n`` on ``device``.
    """
    order = dsatur_ordering(G_int)
    n = G_int.number_of_nodes()

    # Fill the ranks in a plain Python list and create the tensor in one go.
    # Writing element by element into a tensor that already lives on `device`
    # issues one tiny device operation per node.
    positions = [0.0] * n
    for pos, node in enumerate(order):
        positions[int(node)] = float(pos)
    rank = torch.tensor(positions, dtype=torch.float32, device=device)

    # Guard the single-node graph against a division by zero.
    denom = float(n - 1) if n > 1 else 1.0
    return 1.0 - (rank / denom)


def train_gnn_dsatur_imitation(model, graphs, optimizer, epochs=50):
    """Train ``model`` to regress DSATUR-derived node scores (imitation).

    For every graph the model's per-node scores are fitted with an MSE loss
    against the targets from ``_dsatur_target_scores``. One optimizer step is
    taken per graph. Every 10 epochs the average loss and the average number
    of colours obtained by ``greedy_coloring_gnn`` with the model's ordering
    are printed.

    Args:
        model: callable as ``model(x, edge_index)``; switched to train mode.
        graphs: sized collection of graphs accepted by ``_graph_to_pyg_data``.
        optimizer: optimizer bound to the parameters of ``model``.
        epochs: number of passes over ``graphs``.

    Returns:
        Average loss over ``graphs`` in the last epoch, or NaN when there is
        nothing to average (``graphs`` is empty or ``epochs`` <= 0).
    """
    loss_fn = torch.nn.MSELoss()
    model.train()

    # The graph encoding and the DSATUR targets do not depend on the model
    # parameters, so compute them once instead of once per epoch (the DSATUR
    # pass is pure Python and dominates the per-step cost otherwise).
    prepared = []
    for G in graphs:
        data_train, G_int = _graph_to_pyg_data(G)
        prepared.append((data_train, _dsatur_target_scores(G_int)))

    # Previously an empty `graphs` raised ZeroDivisionError and `epochs=0`
    # hit an unbound `total_loss`; report "no measurement" instead.
    avg_loss = float("nan")
    if not prepared:
        return avg_loss

    for epoch in range(epochs):
        total_loss = 0.0
        total_colors = 0.0

        for data_train, target in prepared:
            optimizer.zero_grad()
            # NOTE(review): assumes the model returns one score per node with
            # the same shape as `target` (1-D); MSELoss would broadcast
            # otherwise - confirm against the model definition.
            scores = model(data_train.x, data_train.edge_index)
            loss = loss_fn(scores, target)
            loss.backward()
            optimizer.step()

            # Monitoring only: colour count achieved with the ordering implied
            # by the scores computed before this step's parameter update.
            with torch.no_grad():
                ordering = torch.argsort(scores, descending=True).tolist()
                num_colors = greedy_coloring_gnn(data_train.edge_index, data_train.num_nodes, ordering)

            total_loss += float(loss.item())
            total_colors += float(num_colors)

        avg_loss = total_loss / len(prepared)
        if epoch % 10 == 0:
            avg_colors = total_colors / len(prepared)
            print(f"Epoch {epoch:03d} | loss={avg_loss:.4f} | colores_greedy(orden_GNN)={avg_colors:.2f}")

    return avg_loss

## **2. PREPARACIÓN DE DATOS**

In [None]:
# Training set: ten Erdős–Rényi random graphs with identical parameters.
n = 120   # nodes per graph
p = 0.06  # edge probability

graphs = []
for _ in range(10):
    # NOTE(review): seeds come from the global `random` state, so runs are
    # reproducible only if that state is seeded beforehand - confirm in the
    # setup notebook.
    graph_seed = random.randint(0, 10_000)
    G_train = nx.erdos_renyi_graph(n=n, p=p, seed=graph_seed)

    # Make sure node labels are exactly 0..n-1 before storing the graph.
    labels_ok = set(G_train.nodes()) == set(range(G_train.number_of_nodes()))
    if not labels_ok:
        G_train = nx.convert_node_labels_to_integers(G_train, ordering='sorted')
    graphs.append(G_train)

print(f"Grafos preparados: {len(graphs)}")

## **3. CONFIGURACIÓN DEL OPTIMIZADOR**

In [None]:
# Adam over all parameters of `model` (defined by an earlier notebook run);
# keep the learning rate in the message below in sync with the value used here.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
print("Optimizador configurado: Adam (lr=0.01)")

## **4. ENTRENAMIENTO**

In [None]:
# Run the DSATUR-imitation training for 60 epochs and report the average
# loss of the final epoch, framed by separator lines.
print("Iniciando entrenamiento (imitación de DSATUR)...")
print("=" * 60)
avg_loss = train_gnn_dsatur_imitation(
    model,
    graphs,
    optimizer,
    epochs=60,
)
print("=" * 60)
print(f"Entrenamiento completado. Loss final promedio: {avg_loss:.4f}")