# Helpers

In [None]:
from collections import deque
import networkx as nx
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
class DisjointSet:
    """Union-find structure with path compression and union by rank.

    Attributes:
        parent: maps each element to its parent; a root maps to itself.
        rank: maps each element to its rank, used to decide which root
            absorbs the other during a union.
    """

    def __init__(self, elements=None):
        """Create one singleton set per item of ``elements`` (if any)."""
        self.parent = {}
        self.rank = {}
        for item in elements or ():
            self.parent[item] = item
            self.rank[item] = 0

    def find(self, x):
        """Return the representative of the set containing ``x``.

        Every element visited on the way to the root is re-pointed directly
        at the root. Raises ``KeyError`` if ``x`` was never registered.
        """
        # First pass: walk up to the root.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]

        # Second pass: compress the path just walked.
        node = x
        while self.parent[node] != node:
            above = self.parent[node]
            self.parent[node] = root
            node = above
        return root

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y`` (no-op if already merged)."""
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return

        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
            return

        # root_x has rank >= root_y: it becomes the parent, and its rank
        # grows only when both ranks were equal.
        self.parent[root_y] = root_x
        if self.rank[root_x] == self.rank[root_y]:
            self.rank[root_x] += 1

    def is_connected(self, x, y):
        """Return True when ``x`` and ``y`` share the same representative."""
        return self.find(x) == self.find(y)

In [None]:
def BFS_edges(G, source):
    """Lazily yield the tree edges of a breadth-first search from ``source``.

    ``G`` only needs a ``neighbors(u)`` method returning an iterable.
    Each yielded pair is ``(parent, child)`` where ``child`` is seen for the
    first time while scanning ``parent``'s neighbours.
    """
    seen = {source}
    frontier = deque((source,))
    while frontier:
        parent = frontier.popleft()
        for child in G.neighbors(parent):
            if child in seen:
                continue
            seen.add(child)
            frontier.append(child)
            yield parent, child

In [None]:
class Graph:
    """Undirected graph stored as adjacency sets plus an edge-label map.

    Attributes:
        adj_list: maps each vertex to the set of its neighbours.
        edge_labels: maps ``frozenset({u, v})`` to the label (weight) of
            the edge between ``u`` and ``v``; ``None`` means unlabeled.
    """

    def __init__(self):
        self.adj_list = {}
        self.edge_labels = {}

    @staticmethod
    def _label_sort_key(edge):
        """Sort key for ``(u, v, label)`` triples; a missing label counts as 1."""
        label = edge[2]
        return 1 if label is None else label

    def add_vertex(self, v):
        """Register ``v`` with an empty neighbourhood unless already present."""
        self.adj_list.setdefault(v, set())

    def add_edge(self, u, v, label=None):
        """Insert the undirected edge ``{u, v}`` and record its label."""
        for endpoint in (u, v):
            self.add_vertex(endpoint)
        self.adj_list[u].add(v)
        self.adj_list[v].add(u)
        self.edge_labels[frozenset((u, v))] = label

    def nodes(self):
        """Return the vertices as a list, in insertion order."""
        return list(self.adj_list)

    def neighbors(self, v):
        """Return the neighbour set of ``v`` (an empty set if ``v`` is unknown)."""
        try:
            return self.adj_list[v]
        except KeyError:
            return set()

    def min_degree(self):
        """Return the smallest vertex degree, or 0 for a graph without vertices."""
        return min(map(len, self.adj_list.values()), default=0)

    def empty_copy(self):
        """Return a new ``Graph`` with the same vertices and no edges."""
        clone = Graph()
        for vertex in self.adj_list:
            clone.add_vertex(vertex)
        return clone

    def copy(self):
        """Return a new ``Graph`` with the same vertices, edges and labels."""
        clone = self.empty_copy()
        for u, v, label in self.edge_iterator(labels=True):
            clone.add_edge(u, v, label)
        return clone

    def edge_iterator(self, labels=True):
        """Yield every edge once, as ``(u, v, label)`` or ``(u, v)``.

        The orientation of each pair is the one met first while walking the
        adjacency map.
        """
        emitted = set()
        for u, nbrs in self.adj_list.items():
            for v in nbrs:
                key = frozenset((u, v))
                if key in emitted:
                    continue
                emitted.add(key)
                yield (u, v, self.edge_labels.get(key)) if labels else (u, v)

    def order(self):
        """Return the number of vertices."""
        return len(self.adj_list)

    def size(self):
        """Return the number of edges (half the sum of the degrees)."""
        return sum(map(len, self.adj_list.values())) // 2

    def delete_edge(self, u, v):
        """Remove the edge ``{u, v}`` and its label; missing edges are ignored."""
        for tail, head in ((u, v), (v, u)):
            nbrs = self.adj_list.get(tail)
            if nbrs is not None:
                nbrs.discard(head)
        self.edge_labels.pop(frozenset((u, v)), None)

    def order_by_weight(self, labels=True):
        """Return the edges sorted by label, with or without the label itself."""
        ranked = sorted(self.edge_iterator(labels=True), key=self._label_sort_key)
        if labels:
            return ranked
        return [(u, v) for u, v, _ in ranked]

    def edges(self, sort=False, labels=False):
        """Return the edges as a list.

        With ``sort`` set, labeled triples are ordered by label and plain
        pairs are ordered by their endpoints.
        """
        listing = list(self.edge_iterator(labels=labels))
        if sort:
            if labels:
                listing.sort(key=self._label_sort_key)
            else:
                listing.sort(key=lambda pair: (pair[0], pair[1]))
        return listing

    def set_edge_label(self, u, v, label):
        """Store ``label`` for the pair ``{u, v}`` (adjacency is not touched)."""
        key = frozenset((u, v))
        self.edge_labels[key] = label

    def edge_label(self, u, v):
        """Return the label stored for ``{u, v}``, or ``None`` if there is none."""
        key = frozenset((u, v))
        return self.edge_labels.get(key)

In [None]:
import matplotlib.colors as mcolors


def plotting_MSTs(G,trees):
    """Draw each tree's edges in its own colour on top of the nodes of ``G``.

    Args:
        G: networkx graph whose spring layout positions every node.
        trees: iterable of networkx graphs whose edges are drawn using the
            layout computed for ``G``.

    Colours come from matplotlib's CSS4 named-colour table, in table order.
    The palette wraps around, so passing more trees than there are named
    colours reuses colours instead of raising ``IndexError``.
    """
    pos = nx.spring_layout(G)

    # Names of all CSS4 colours known to matplotlib.
    palette = list(mcolors.CSS4_COLORS)

    for index, tree in enumerate(trees):
        nx.draw_networkx_edges(
            tree,
            pos,
            edge_color=palette[index % len(palette)],
            width=0.5,
            alpha=1,
        )
    nx.draw_networkx_nodes(G, pos, alpha=1, node_size=6, node_color='m')
    plt.show()

In [None]:
def to_networkx_graph(g):
    """Convert a custom ``Graph`` into a ``networkx.Graph``.

    All vertices are copied. An edge whose label is not ``None`` gets that
    label as its ``weight`` attribute; unlabeled edges get no attributes.
    """
    nx_g = nx.Graph()
    nx_g.add_nodes_from(g.nodes())
    for u, v, label in g.edge_iterator(labels=True):
        # Only labeled edges carry a weight attribute.
        attrs = {} if label is None else {"weight": label}
        nx_g.add_edge(u, v, **attrs)
    return nx_g

In [None]:
def from_networkx_graph(nx_graph):
    """Convert a networkx graph into the custom ``Graph`` class.

    Every node is copied (including isolated ones). Each edge's ``weight``
    attribute becomes its label, or ``None`` when the attribute is absent.
    """
    converted = Graph()
    for vertex in nx_graph.nodes():
        converted.add_vertex(vertex)
    for u, v, attrs in nx_graph.edges(data=True):
        converted.add_edge(u, v, attrs.get('weight'))
    return converted

In [None]:
def run_dataset(dataset, algorithm, n_runs):
    """Run ``algorithm`` on every graph of ``dataset`` for k = 1 .. ``n_runs``.

    Args:
        dataset: re-iterable collection of networkx graphs.
        algorithm: callable invoked as ``algorithm(graph, k=k, weights=True)``
            where ``graph`` is the custom ``Graph`` built from each dataset
            entry.
        n_runs: largest value of ``k`` to try; values 1, 2, ..., ``n_runs``
            are used in that order.

    Returns:
        A list with one entry per (k, graph) pair, in iteration order, each
        being whatever ``algorithm`` returned. The list is empty when
        ``n_runs`` is less than 1.

    Exceptions raised by ``algorithm`` are not caught.

    Note:
        The ``return`` previously sat inside the outer loop, so only k = 1
        was ever executed (and ``None`` was returned for ``n_runs == 0``).
    """
    res = []
    for k in range(1, n_runs + 1):
        for nx_graph in dataset:
            graph = from_networkx_graph(nx_graph)
            res.append(algorithm(graph, k=k, weights=True))
    return res

# Kruskal

In [None]:
def kruskal(edges, n):
    """Greedily extract one spanning forest from ``edges``.

    Edges are scanned in the given order (callers pass them sorted by
    weight). An edge joining two different components is moved into the
    forest; every other edge is kept for later passes.

    Args:
        edges: mutable list of ``(u, v, weight)`` triples. It is updated in
            place so that only the unused edges remain, in their original
            relative order.
        n: number of vertices. The components are tracked with
            ``DisjointSet(range(n))``, so vertices must be ``0 .. n-1``.

    Returns:
        ``(tree, edges)`` where ``tree`` is a ``Graph`` holding the selected
        edges and ``edges`` is the same list object that was passed in, now
        holding only the leftovers.
    """
    components = DisjointSet(elements=range(n))
    tree = Graph()
    leftover = []

    # Single pass; the old version called list.remove() inside the scan,
    # which made the whole routine quadratic in the number of edges.
    for edge in edges:
        u, v, weight = edge
        if components.find(u) == components.find(v):
            leftover.append(edge)
        else:
            components.union(u, v)
            tree.add_edge(u, v, weight)

    # Keep the historical in-place contract for callers holding the list.
    edges[:] = leftover
    return tree, edges

def modified_kruskal(G,  k, weights=True):
    """Collect up to ``k`` spanning trees by running ``kruskal`` repeatedly.

    Each pass consumes the edges it selects, so later passes only see the
    edges left over by earlier ones. The loop stops early at the first pass
    whose result does not have ``n - 1`` edges. ``weights`` is accepted for
    signature compatibility with the other algorithm and is not used.

    Returns:
        List of the trees found, in the order they were produced.
    """
    vertex_count = len(G.adj_list)
    remaining = G.order_by_weight()
    forest = []

    while len(forest) < k:
        candidate, remaining = kruskal(remaining, vertex_count)
        if len(candidate.edges()) != vertex_count - 1:
            # Not spanning: the leftover edges cannot yield another tree.
            break
        forest.append(candidate)

    return forest


# Roskind-Tarjan

In [None]:
def edge_disjoint_spanning_trees(G, k, weights=False):
    """Try to build ``k`` spanning trees of ``G`` that share no edge.

    The notebook files this routine under the "Roskind-Tarjan" heading.
    Edges are processed in the order given by ``G.order_by_weight``. For
    each edge a labelling search over the ``k`` forests looks for an
    augmenting sequence; when one is found, edges are shifted between
    forests so that the new edge fits, otherwise the two endpoints are
    merged in ``partition[0]`` and later edges between them are skipped.

    Args:
        G: graph object providing ``min_degree``, ``nodes``, ``order``,
            ``edge_iterator``, ``empty_copy`` and ``order_by_weight``.
        k: number of spanning trees requested.
        weights: when true, edge labels are carried over to the returned
            trees; when false, every edge label in the result is ``None``.

    Returns:
        A list of ``k`` graphs, each with ``G.order() - 1`` edges.

    Raises:
        ValueError: if ``k > 1 + G.min_degree() // 2`` (checked up front),
            or if fewer than ``k`` of the built forests end up with
            ``G.order() - 1`` edges.
    """
    # print(f"Starting edge_disjoint_spanning_trees with k={k}")
    if k > 1 + G.min_degree() // 2:
        raise ValueError("this graph does not contain the required number of trees/arborescences")

    # One DisjointSet per index 0..k tracking which vertices are joined.
    # Index 0 is the "global clump" partition; 1..k follow the forests.
    partition = [DisjointSet(G.nodes()) for _ in range(k + 1)]

    # Maps each edge to the index of the forest that currently holds it
    # (0 means "not placed in any of the forests 1..k yet").
    edge_index = {frozenset(e): 0 for e in G.edge_iterator(labels=False)}

    # Edge-less copy of the graph, used as the starting point of every forest.
    G_spanning = G.empty_copy()
    F = [G_spanning.copy() for _ in range(k + 1)]

    if weights:
        edges_ordered = list(G.order_by_weight(labels=True))
        # print("Edges ordered (with weights):")
        # print(edges_ordered)
        edge_weights = {frozenset((u, v)): w for u, v, w in edges_ordered}
    else:
        edges_ordered = list(G.order_by_weight(labels=False))
        edge_weights = {frozenset((u, v)): None for u, v in edges_ordered}

    # NOTE(review): idx is never read below.
    for idx, edge in enumerate(edges_ordered, 1):
        if weights:
            x, y, w = edge
        else:
            x, y = edge
            w = None

        # If both endpoints are already joined in partition[0] (the global clump), skip the edge.
        if partition[0].find(x) == partition[0].find(y):
            continue

        # edge_label[f] = the edge whose processing caused f to be labelled.
        edge_label = {}
        # The queue is a plain list scanned with two cursors (never popped).
        queue = [(x, y)]
        queue_begin = 0
        queue_end = 1

        # p[i] will hold the BFS predecessors in forest F[i], rooted at x.
        p = [{x: x} for _ in range(k + 1)]
        for i in range(1, k + 1):
            for u, v in BFS_edges(F[i], x):
                p[i][v] = u

        augmenting_sequence_found = False

        while queue_begin < queue_end:
            e = queue[queue_begin]
            queue_begin += 1
            fe = frozenset(e)
            # Pick forest i from edge_index (cycling through 1..k).
            i = (edge_index[fe] % k) + 1
            v, w_ = e

            if partition[i].find(v) != partition[i].find(w_):
                # e joins two different components of forest i.
                augmenting_sequence_found = True
                break
            else:
                # Choose the endpoint from which to climb towards x:
                # w_ when v is x or v's predecessor edge is already
                # labelled, otherwise v.
                if v == x or (v in p[i] and frozenset((v, p[i][v])) in edge_label):
                    u = w_
                else:
                    u = v
                # Collect the unlabelled predecessor edges on the way up.
                edges_to_label = []
                while u != x and (u in p[i] and frozenset((u, p[i][u])) not in edge_label):
                    edges_to_label.append((u, p[i][u]))
                    u = p[i][u]

                # Label and enqueue them, last collected first.
                while edges_to_label:
                    ep = edges_to_label.pop()
                    edge_label[frozenset(ep)] = fe
                    queue.append(ep)
                    queue_end += 1

        if augmenting_sequence_found:
            # Update the partition of forest i with the edge that ended the search.
            partition[i].union(v, w_)

            # Augmentation loop: moves edges between forests step by step,
            # following the edge_label chain.
            while fe in edge_label:
                old_forest = edge_index[fe]
                u, vv = tuple(fe)
                # Remove the edge from its old forest.
                F[old_forest].delete_edge(u, vv)
                # Add the edge to forest i with the matching weight.
                F[i].add_edge(u, vv, edge_weights[frozenset((u, vv))])
                # Incremental update: join u and vv in the partition of forest i.
                partition[i].union(u, vv)
                # Rotate the variables for the next step of the augmenting sequence.
                e, edge_index[fe], i = edge_label[fe], i, edge_index[fe]
                fe = frozenset(e)

            # e is now the edge at the end of the label chain (the one
            # with no edge_label entry); place it in forest i.
            u, vv = tuple(e)
            F[i].add_edge(u, vv, edge_weights[frozenset(e)])
            edge_index[frozenset(e)] = i
            partition[i].union(u, vv)
        else:
            # No forest can take the edge: merge its endpoints in the clump.
            partition[0].union(x, y)

    # Keep only the forests that have as many edges as a spanning tree.
    res = [F[i] for i in range(1, k + 1) if F[i].size() == G.order() - 1]

    if len(res) != k:
        raise ValueError("There is no solution!")

    # Re-apply the labels from edge_weights to every returned tree.
    # NOTE(review): add_edge above already stored these labels, so this
    # looks redundant — confirm before removing.
    for f in res:
        for u, v, label in f.edges(labels=True):
            f.set_edge_label(u, v, edge_weights[frozenset((u, v))])

    return res

# Test Code

In [None]:
# Load the pickled table of graphs; its 'graph' column is the dataset used below.
# NOTE(review): unpickling can execute arbitrary code — only load files from a trusted source.
df = pd.read_pickle("dataset-grafos-trab-final.pkl")
dataset_total = df['graph']

In [None]:
# Benchmark run_dataset on the whole dataset with edge_disjoint_spanning_trees and n_runs=4 (-r 100: 100 repeats).
%timeit -r 100 res = run_dataset(dataset_total, edge_disjoint_spanning_trees, 4)

# Results:
* Tempo total de execução (100 repetições):
  * Kruskal: 734 ms ± 161 ms per loop (mean ± std. dev. of 100 runs)