In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import copy
import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from typing import List
from pprint import pprint
from collections import defaultdict

# Sugiyama implementation

## Data

In [3]:
# Alternative sample graph (disabled). Uncomment both lines to use it instead
# of the active data set below.
# node_link_list = [(0, 6), (6, 1), (1, 7), (7, 2), (7, 3), (3, 8), (8, 4), (3, 9), (9, 5)]
# rank_node_dict = {0: [0], 1: [6], 2: [1], 3: [7], 4: [2, 3], 5: [8, 9], 6: [4, 5]}

# Active sample graph: each tuple links two node ids.
node_link_list = [
    (0, 1), (1, 2), (3, 4), (4, 5), (6, 0),
    (0, 5), (3, 0), (3, 2), (0, 7), (7, 8),
    (8, 0), (1, 8), (5, 9), (9, 6), (10, 5),
]

# Rank index -> node ids placed on that rank, in their initial order.
rank_node_dict = {
    0: [0],
    1: [1, 6, 5, 3, 7, 8],
    2: [2, 9, 4, 10],
}


In [4]:
# Build an undirected adjacency map: every link is recorded under both of its
# endpoints, in the order the links appear in node_link_list.
node_connections_dict = {}
for node, connection in node_link_list:
    node_connections_dict.setdefault(node, []).append(connection)
    node_connections_dict.setdefault(connection, []).append(node)

# Swap the plain lists for numpy arrays so they can be compared and indexed
# element-wise later on.
node_connections_dict = dict(
    (node_id, np.array(linked)) for node_id, linked in node_connections_dict.items()
)
rank_node_dict = dict(
    (rank_id, np.array(members)) for rank_id, members in rank_node_dict.items()
)

print("Node connections dict:", node_connections_dict)
print("Rank node dict:", rank_node_dict)


Node connections dict: {0: array([1, 6, 5, 3, 7, 8]), 1: array([0, 2, 8]), 2: array([1, 3]), 3: array([4, 0, 2]), 4: array([3, 5]), 5: array([ 4,  0,  9, 10]), 6: array([0, 9]), 7: array([0, 8]), 8: array([7, 0, 1]), 9: array([5, 6]), 10: array([5])}
Rank node dict: {0: array([0]), 1: array([1, 6, 5, 3, 7, 8]), 2: array([ 2,  9,  4, 10])}


In [5]:
# Number of sweep iterations handed to layout(); each iteration runs one
# forward pass and one backward pass over all ranks.
max_iterations = 20

In [6]:
def median_pos(rank_node_dict: dict, rank_neighbors:int, neighbors: np.ndarray):
    """Return the median index of ``neighbors`` within one rank.

    Args:
        rank_node_dict: Mapping from rank to the array of node ids on that rank.
        rank_neighbors: Rank whose node order is used as the reference.
        neighbors: Node ids to look up; ids that are not on that rank are ignored.

    Returns:
        The median of the found indices, truncated to an integer, or ``-1``
        when none of ``neighbors`` is on the requested rank.
    """
    layer = rank_node_dict[rank_neighbors]

    # One index array per neighbor; it is empty when the neighbor is not on this rank.
    matches = [np.nonzero(layer == neighbor)[0] for neighbor in neighbors]
    found = np.array([match[0] for match in matches if match.size > 0])

    if found.size == 0:
        return -1

    return np.median(found).astype(int)

# med = median_pos(rank_node_dict, 2, node_connections_dict[3])
# print(med)

In [7]:
def get_new_positions(
    nodes: np.ndarray,
    rank_node_dict: dict,
    node_connections_dict: dict,
    rank: int,
    forward: bool,
):
    """Compute a target position for every node of one rank.

    Each node's target is the median position of its neighbors on the adjacent
    rank (``rank + 1`` when ``forward`` is true, ``rank - 1`` otherwise), as
    reported by ``median_pos``. A node keeps its current index when any of the
    following holds:

    * the adjacent rank does not exist in ``rank_node_dict``,
    * the node has no entry in ``node_connections_dict`` (an unlinked node),
    * none of its neighbors is on the adjacent rank (``median_pos`` gives -1).

    Args:
        nodes: Node ids of the rank being processed, in their current order.
        rank_node_dict: Mapping from rank to the array of node ids on that rank.
        node_connections_dict: Mapping from node id to its neighbor ids.
        rank: Rank that ``nodes`` belongs to.
        forward: Selects which adjacent rank is used as the reference.

    Returns:
        Integer array with one target position per entry of ``nodes``.
    """
    # The reference rank is the same for every node, so resolve it once.
    neighbor_rank = rank + 1 if forward else rank - 1
    has_neighbor_rank = neighbor_rank in rank_node_dict

    new_positions = []
    for cur_pos, node in enumerate(nodes):
        # .get() instead of [] so a node without any link does not raise KeyError.
        neighbors = node_connections_dict.get(node)
        if not has_neighbor_rank or neighbors is None:
            new_positions.append(cur_pos)
            continue

        new_position = median_pos(rank_node_dict, neighbor_rank, neighbors)
        # -1 is median_pos's "no neighbor on that rank" marker.
        new_positions.append(cur_pos if new_position == -1 else new_position)

    # Explicit integer dtype so an empty rank does not produce a float array.
    return np.array(new_positions, dtype=int)


def layout(rank_node_dict: dict, node_connections_dict: dict, max_iterations: int):
    """Reorder the nodes of every rank with repeated median sweeps.

    Each iteration performs two sweeps over all ranks:

    * a forward sweep, visiting ranks in ascending order with ``forward=True``
      (``get_new_positions`` then uses ``rank + 1`` as the reference rank);
    * a backward sweep, visiting ranks in descending order with
      ``forward=False`` (reference rank ``rank - 1``).

    For every visited rank the nodes are re-sorted by the target positions
    returned by ``get_new_positions``. The sort is stable, so nodes with equal
    target positions keep their current relative order.

    Args:
        rank_node_dict: Mapping from rank to a numpy array of node ids. Its
            entries are reassigned in place; pass a copy to keep the original.
        node_connections_dict: Mapping from node id to its neighbor ids.
        max_iterations: Number of forward+backward iterations to run.

    Returns:
        The same ``rank_node_dict`` object, with the reordered arrays.
    """
    # Only the values are reassigned below, so the key order can be computed once.
    ascending_ranks = sorted(rank_node_dict.keys())
    sweeps = (
        (ascending_ranks, True),  # forward pass
        (ascending_ranks[::-1], False),  # backward pass
    )

    for _ in tqdm.tqdm(range(max_iterations)):
        for ranks, forward in sweeps:
            for rank in ranks:
                nodes = rank_node_dict[rank]
                new_positions = get_new_positions(
                    nodes, rank_node_dict, node_connections_dict, rank, forward=forward
                )

                # A stable sort is required: ties are common (e.g. every node
                # whose only neighbor is the same single node), and the default
                # argsort kind may reorder tied nodes arbitrarily.
                order = np.argsort(new_positions, kind="stable")
                rank_node_dict[rank] = nodes[order]

    return rank_node_dict


In [8]:
# Run the ordering on deep copies: layout() reassigns the entries of the rank
# dict it receives, so copying keeps the original rank_node_dict unchanged and
# lets this cell be re-run from the same starting order.
new_rank_node_dict = layout(
    rank_node_dict=copy.deepcopy(rank_node_dict),
    node_connections_dict=copy.deepcopy(node_connections_dict),
    max_iterations=max_iterations,
)
print("New rank node dict:", new_rank_node_dict)

100%|██████████| 20/20 [00:00<00:00, 3046.41it/s]

New rank node dict: {0: array([0]), 1: array([1, 6, 3, 5, 7, 8]), 2: array([ 2,  9,  4, 10])}



