In [1]:
import networkx as nx
import numpy as np
import random
import os
import glob
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from pyadigraph import Adigraph

In [2]:
class RandomWeightedGraph(nx.Graph):
    """Random bipartite graph whose edges carry random integer weights.

    The topology comes from ``nx.bipartite.random_graph`` on two node sets of
    ``nodes`` vertices each.  The graph is built from the generated *edge
    list*, so vertices without any edge are not kept.  Every edge receives a
    ``'weight'`` attribute drawn with ``np.random.randint(low, high)``, i.e.
    from the half-open integer range ``[low, high)``.
    """

    def __init__(self, nodes: int = 0, p: float = 0.5,
                 low: int = 10, high: int = 100):
        """Generate the random weighted bipartite graph.

        Args:
            nodes: size of each of the two bipartite node sets.  Defaults to 0
                (an empty graph) so that networkx methods which re-create the
                graph class without arguments (e.g. ``copy()``) keep working.
            p: probability that each cross-set edge is present.
            low: inclusive lower bound for edge weights.
            high: exclusive upper bound for edge weights.
        """
        super().__init__(nx.bipartite.random_graph(nodes, nodes, p).edges())
        for u, v in self.edges():
            self.edges[u, v]['weight'] = np.random.randint(low, high)

    def weight(self, edge):
        """Return the ``'weight'`` attribute of ``edge`` (a ``(u, v)`` pair)."""
        return self.edges[edge]['weight']

In [16]:
class MinimumSpanningTree(dict):
    """Base class for spanning-tree builders, with optional figure export.

    The instance is itself a dict; subclasses fill it with the tree they
    build.  When ``describe`` is true, a node layout is computed so that
    intermediate steps can be exported as LaTeX figures through ``Adigraph``.
    """

    # Directory whose files are deleted before a described run.
    # NOTE(review): figures are written to _CHAPTERS_DIR, not here — confirm
    # which directory is really meant to be cleared.
    _IMAGES_DIR = '../images/prim'
    # Directory receiving one numbered .tex file per _plot_graphs call.
    _CHAPTERS_DIR = '../chapters/prim'

    def __init__(self, G: nx.Graph, root: int, describe: bool = False):
        """Store the graph and root; prepare plotting state when describing.

        Args:
            G: graph to span.  Stored both as ``_original_G`` (used for
                plotting) and ``_G`` (used by the algorithm).
            root: node the tree is grown from.
            describe: when true, clear ``_IMAGES_DIR`` and compute the layout
                used by ``_plot_graphs``.  A non-describing instance touches
                no files.
        """
        super().__init__()
        self._original_G, self._G, self._root = G, G, root
        self.counter = 0  # index of the next .tex file written by _plot_graphs
        if describe:
            self._clear_images()
            self._setup_describe()

    def _clear_images(self):
        """Delete the regular files left in ``_IMAGES_DIR`` by a previous run."""
        for path in glob.glob(os.path.join(self._IMAGES_DIR, '*')):
            if os.path.isfile(path):  # the glob may also match sub-directories
                os.remove(path)

    def _setup_describe(self):
        """Compute the node layout and padded plotting bounds."""
        self._pos = nx.spring_layout(self._G, iterations=10000)
        gray, self._dark_gray = '#eeeeee', '#555555'
        # NOTE(review): _node_colors, _edge_colors and the _min/_max bounds are
        # not read by any code shown here — confirm they are still needed.
        self._node_colors = [gray for n in self._G]
        self._edge_colors = [gray for e in self._G.edges]
        positions = np.array(list(self._pos.values()))
        # Bounding box of the layout, padded by 0.1 on every side.
        self._min_x, self._min_y = np.min(positions, axis=0) - 0.1
        self._max_x, self._max_y = np.max(positions, axis=0) + 0.1

    def _plot_graphs(self, graphs: list, title: str):
        """Export a row of panels over the original graph as a .tex figure.

        Args:
            graphs: ``(G, colors, caption)`` triples, one per panel.  The
                nodes and edges of ``G`` are highlighted on top of the
                original graph.  ``colors`` is either one colour string,
                passed as ``vertices_color_fallback``, or any other object
                (dict or list), passed as ``vertices_color``.
            title: caption of the whole figure.

        Writes ``<_CHAPTERS_DIR>/<counter>.tex`` and increments ``counter``.
        """
        figure = Adigraph(
            layout=self._pos,
            weights=nx.get_edge_attributes(self._original_G, 'weight'),
            style="-",
            row_size=4,
            caption=title,
            directed=False,
            edges_color_fallback="gray!80",
            vertices_color_fallback="gray")

        for highlighted, colors, caption in graphs:
            # One colour string styles every vertex; anything else is
            # forwarded as per-vertex colours.
            color_option = ("vertices_color_fallback"
                            if isinstance(colors, str) else "vertices_color")
            figure.add_graph(
                self._original_G,
                caption=caption,
                edges_color={e: "black!90" for e in highlighted.edges},
                edges_width={e: 1 for e in highlighted.edges},
                vertices_width={v: 1 for v in highlighted.nodes},
                **{color_option: colors})

        figure.save("{d}/{c}.tex".format(d=self._CHAPTERS_DIR, c=self.counter))
        self.counter += 1

    def _describe_start(self, root: int):
        """Plot the initial graph next to the one-node tree rooted at ``root``."""
        tree = nx.Graph()
        tree.add_node(root)
        # NOTE(review): 'r' looks like a matplotlib colour code, and this is a
        # list, while GraphicalPrimTree passes TikZ colours in a dict — confirm
        # Adigraph accepts this form.
        colors = ['r' for n in self._original_G]
        self._plot_graphs([(self._original_G, colors, "Initial Graph"),
                           (tree, colors, "Initial Tree")],
                          "Initial Conditions")

In [17]:
class PrimTree(MinimumSpanningTree):
    """Minimum spanning tree built with Prim's algorithm.

    After ``run()``, the instance (a dict) maps every tree node to the node it
    was attached from; the root maps to itself.
    """

    def __init__(self, G: nx.Graph, root: int, describe: bool = False):
        super().__init__(G, root, describe=describe)

    def run(self):
        """Grow the tree from ``self._root`` over ``self._G``.

        Edge weights are read from the ``'weight'`` edge attribute.  If the
        graph is disconnected, growth stops once every remaining tentative
        cost is infinite, so the tree covers only the root's component.

        Raises:
            KeyError: if the root is not a node of the graph.
        """
        if self._root not in self._G:
            raise KeyError(
                "root {r!r} is not a node of the graph".format(r=self._root))
        # Tentative attachment cost of every node not yet in the tree.
        self._costs = {v: 0 if v == self._root else np.inf for v in self._G}
        self._pred = {self._root: self._root}
        while self._costs and not self._is_spanning():
            if min(self._costs.values()) == np.inf:
                break  # the remaining nodes are unreachable from the root
            self.iteration()

    def iteration(self):
        """Move the cheapest frontier node into the tree and relax its edges.

        Returns:
            The node just added.  ``self._neighbors`` lists the neighbours
            whose tentative cost this step lowered.
        """
        node = min(self._costs, key=self._costs.get)
        self[node] = self._pred[node]
        del self._costs[node]
        self._neighbors = []
        for v in self._G.neighbors(node):
            if v not in self._costs:
                continue  # already part of the tree
            weight = self._G.edges[node, v]['weight']
            if weight < self._costs[v]:
                self._pred[v] = node
                self._costs[v] = weight
                self._neighbors.append(v)
        return node

    def _is_spanning(self):
        """Determine if tree is currently spanning the graph G."""
        # Nodes touched by the tree: every key plus every predecessor value.
        return len(set(self).union(self.values())) == len(self._G)

In [18]:
class GraphicalPrimTree(PrimTree):
    """PrimTree that exports a two-panel LaTeX figure after every iteration."""

    def __init__(self, G: nx.Graph, root: int):
        super().__init__(G, root, describe=True)

    def run(self):
        """Run Prim's algorithm, numbering exported iterations from 0."""
        self._i = 0
        super().run()

    def iteration(self):
        """Perform one Prim step, then plot the frontier and the current tree.

        Returns:
            The node added by this step, as ``PrimTree.iteration`` does.
        """
        node = super().iteration()
        frontier, frontier_colors = self._create_iteration_graph(
            self._costs, self._neighbors, node)
        # The root is stored as its own predecessor; skip that self-pair so
        # the tree panel highlights only real tree edges.
        tree = nx.Graph()
        tree.add_nodes_from(self)
        tree.add_edges_from((u, v) for u, v in self.items() if u != v)
        self._plot_graphs([
            (frontier, frontier_colors, " "),
            (tree, {n: 'red!90' for n in self}, " "),
        ], "Iteration {i}: adding node {node}.".format(i=self._i, node=node))
        self._i += 1
        return node

    def _create_iteration_graph(self, costs: dict, neighbors: list,
                                node: int) -> tuple:
        """Build the frontier panel for one iteration.

        Args:
            costs: tentative costs of nodes not yet in the tree (only the keys
                are used).
            neighbors: nodes whose cost was just lowered.
            node: node just added to the tree.

        Returns:
            ``(graph, colors)``: an edgeless graph on ``costs``, ``node`` and
            ``neighbors``, and a colour per vertex.  Updated neighbours are
            green, the added node is magenta, and the rest are cyan.
        """
        iteration = nx.Graph()
        iteration.add_nodes_from([*costs, node, *neighbors])
        colors = {}
        for v in iteration:
            if v in neighbors:
                colors[v] = 'green!90'
            elif v == node:
                colors[v] = 'magenta!90'
            else:
                colors[v] = 'cyan!90'
        return iteration, colors

In [19]:
# Seed the NumPy and stdlib global RNGs for reproducibility.  NumPy drives the
# edge weights (np.random.randint); the stdlib RNG presumably drives the
# networkx graph generator, which is called without an explicit seed.
seed = 66
np.random.seed(seed)
random.seed(seed)

# Random weighted bipartite graph on two sets of 4 candidate nodes.
G = RandomWeightedGraph(4)
# `root` is where Prim's algorithm starts.
# NOTE(review): `end` is not used by any cell shown here.
root, end = 4, 3

In [20]:
# Grow the tree from `root`, exporting one .tex figure per Prim iteration;
# keep a reference so the resulting tree can be inspected afterwards.
prim_tree = GraphicalPrimTree(G, root)
prim_tree.run()