In [1]:
from graph_utils import generate_graph, graph_to_jraph
from gcn import GCN, TrainState
import optax
from hamiltonian_cycle import train, post_process
import flax.linen as nn
from plot_utils import draw_cycle
import networkx as nx

In [2]:
from typing import Any


from jax._src.basearray import Array
from networkx.classes.graph import Graph


from _src.gcn import TrainState


def run_model(
    side_len: int,
    n_epochs: int,
    lr: float,
    hidden_size: int,
    num_convolutions: int = 2,
    num_layers: int = 1,
    patience: int = 250,
    tol: float = 1e-3,
    use_TSP: bool = False,
    graph_type: str = "chess",
) -> tuple[TrainState, Graph, dict[int, Array | Any]]:
    """Generate a graph, fit a GCN to it with ``train`` and return the result.

    Args:
        side_len: The graph is requested from ``generate_graph`` with
            ``side_len * side_len`` nodes.
        n_epochs: Maximum number of epochs, forwarded to ``train``.
        lr: Learning rate of the Adam optimizer.
        hidden_size: First positional argument of ``GCN`` (hidden width).
        num_convolutions: Forwarded to ``GCN`` as ``num_convolutions``.
        num_layers: Forwarded to ``GCN`` as ``num_layers``.
        patience: Early-stopping patience, forwarded to ``train``.
        tol: Early-stopping tolerance, forwarded to ``train``.
        use_TSP: Forwarded to ``train`` (selects the objective; see
            ``hamiltonian_cycle.train``).
        graph_type: Forwarded to ``generate_graph``; this notebook uses
            ``"chess"`` and ``"grid"``.

    Returns:
        ``(state, nx_graph, pos)``: the trained state (first element of
        ``train``'s result), the networkx graph, and the ``pos`` mapping
        returned by ``generate_graph``.
    """
    graph, node_positions = generate_graph(side_len * side_len, graph_type=graph_type)
    graph_jraph = graph_to_jraph(graph, node_positions)
    # The GCN's second positional argument is the actual node count of the graph.
    num_nodes = graph.number_of_nodes()

    adam = optax.adam(learning_rate=lr)

    gcn_options = dict(
        output_activation=nn.softmax,
        num_layers=num_layers,
        dropout_rate=0.0,  # dropout disabled
        num_convolutions=num_convolutions,
    )
    model = GCN(hidden_size, num_nodes, nn.leaky_relu, **gcn_options)

    train_options = dict(tol=tol, patience=patience, use_TSP=use_TSP)
    trained_state, _unused = train(graph_jraph, model, adam, n_epochs, **train_options)

    return trained_state, graph, node_positions


def evaluate_and_draw(
    state: TrainState,
    nx_graph: nx.Graph,
    pos: dict,
    side_len: int = 4,
    graph_type: str = "chess",
    use_TSP: bool = False,
    **kwargs,
) -> None:
    """Run the trained model on ``nx_graph`` and plot the decoded cycle.

    Args:
        state: Trained state; inference is ``state.apply_fn(state.params, ...)``.
        nx_graph: Graph to evaluate on (converted with ``graph_to_jraph``).
        pos: Position mapping for ``nx_graph``, as returned by ``generate_graph``.
        side_len: Only used to build the saved figure's file name.
        graph_type: Only used to build the saved figure's file name.
        use_TSP: Selects the ``"tsp"`` / ``"ham"`` suffix of the file name.
        **kwargs: Ignored. This lets callers forward the whole ``run_model``
            configuration dict unchanged.
    """
    jraph_graph = graph_to_jraph(nx_graph, pos)
    pred_graph = state.apply_fn(state.params, jraph_graph, training=False)
    pred_cycle = post_process(pred_graph.nodes)

    # Computed outside the f-string: nesting "..." inside a "..." f-string is a
    # SyntaxError before Python 3.12 (PEP 701).
    objective = "tsp" if use_TSP else "ham"
    # NOTE(review): the "tsp_" prefix is used for both objectives; only the
    # suffix distinguishes TSP from Hamiltonian-cycle runs.
    save_name = f"tsp_{side_len}_{graph_type}_{objective}"

    draw_cycle(
        nx_graph,
        pos,
        pred_cycle,
        rounding=True,
        title="Predicted Cycle",
        save_name=save_name,
    )


def vary_grid_and_TSP(kwargs: dict) -> None:
    """Train and plot every combination of objective and graph type.

    Calls ``run_model`` and then ``evaluate_and_draw`` for each ``use_TSP`` in
    ``[False, True]`` (outer loop) and each ``graph_type`` in
    ``["chess", "grid"]`` (inner loop). Those two keys of ``kwargs`` are
    overridden for each run.

    Args:
        kwargs: Base keyword arguments for ``run_model``, also forwarded to
            ``evaluate_and_draw``. The dict is not modified; each run gets its
            own copy.
    """
    for use_TSP in [False, True]:
        for graph_type in ["chess", "grid"]:
            # Build a per-run copy instead of writing into ``kwargs``. Mutating
            # the caller's dict would leak the last combination back into the
            # notebook's state.
            run_kwargs = {**kwargs, "use_TSP": use_TSP, "graph_type": graph_type}

            state, nx_graph, pos = run_model(**run_kwargs)
            evaluate_and_draw(state, nx_graph, pos, **run_kwargs)

In [3]:
# Base configuration for the 4x4 sweep. ``use_TSP`` and ``graph_type`` are
# overridden for every run by ``vary_grid_and_TSP``; the values here are only
# placeholders.
kwargs = dict(
    side_len=4,
    n_epochs=10_000,
    lr=0.01,
    hidden_size=32,
    num_convolutions=1,
    num_layers=1,
    use_TSP=False,
    graph_type="chess",
)

vary_grid_and_TSP(kwargs)

Training:  27%|██▋       | 2729/10000 [00:03<00:08, 860.86epoch/s, loss=0.0008, patience=100.0%]


Early stopping at epoch 2751
Final loss: 0.0008
Best loss: 0.0008


In [None]:
# Repeat the sweep on an 8x8 board. This updates the shared ``kwargs`` dict
# from the previous cell in place.
kwargs.update(side_len=8)
vary_grid_and_TSP(kwargs)