In [64]:
import networkx as nx
import numpy as np
from dataclasses import dataclass, field
from scipy.spatial import cKDTree
from scipy.spatial import distance_matrix
from tqdm.notebook import trange, tqdm
from glob import glob
import pandas as pd
from collections import defaultdict
import os
import json

In [3]:
@dataclass
class ResultSet:
    """
    Helper class to keep all results from a single experiment together.

    ``all_points`` and ``station_indices`` are derived in ``__post_init__``:
    ``all_points`` holds the retained (non-zero) nucleus coordinates followed
    by the station coordinates, and ``station_indices`` is the index range of
    the station rows inside ``all_points``.
    """
    path: str                      # path of the .npz file the arrays came from
    all_trails: np.ndarray
    seed: np.ndarray
    nuclei: np.ndarray             # reshaped to (-1, 2) coordinate rows in __post_init__
    stations: np.ndarray           # coordinate rows appended after the nuclei
    map_with_stations: np.ndarray
    start_pos: np.ndarray
    all_points: np.ndarray = field(init=False)
    station_indices: range = field(init=False)

    def __post_init__(self):
        points = self.nuclei.reshape(-1, 2)
        # Rows that are exactly (0, 0) are discarded; presumably these are
        # padding rather than real nuclei — TODO confirm with the simulation.
        points = points[~np.all(points == 0, axis=1)]
        self.all_points = np.vstack([points, self.stations])
        # Stations occupy every row after the retained nuclei. Deriving the
        # start from len(points) stays correct even when `stations` is a single
        # flat (2,) pair, for which len(self.stations) would be 2, not 1.
        self.station_indices = range(len(points), self.all_points.shape[0])

    def __str__(self):
        return f"Experiment with start pos {self.start_pos}, seed {self.seed}"

In [4]:
def create_nuclei_graph(results: ResultSet, k: int = 10) -> nx.Graph:
    """
    Creates a graph from all slime nuclei (and stations), where every node is
    connected to its `k` closest neighbours.

    Nodes are integer row indices into ``results.all_points``; each edge
    weight is the Euclidean distance between its two endpoints.

    :param results: experiment whose ``all_points`` are connected
    :param k: number of nearest neighbours linked to each point (default 10)
    :return: undirected ``nx.Graph`` of k-nearest-neighbour edges
    """
    points = results.all_points
    n_points = points.shape[0]

    G = nx.Graph()
    if n_points < 2:
        # Nothing to connect; keep a lone point as an isolated node.
        G.add_nodes_from(range(n_points))
        return G

    # Ask for k+1 hits because a point is normally its own nearest hit, but
    # never for more hits than there are points: otherwise cKDTree pads the
    # result with the out-of-range index n_points.
    n_query = min(k + 1, n_points)
    _, idxs = cKDTree(points).query(points, k=n_query)

    src = np.repeat(np.arange(n_points), n_query)
    dst = np.asarray(idxs).reshape(-1)  # query(k=1) returns a 1-D array; flatten either way
    # Drop self-matches wherever they sit in a row. With duplicate points the
    # point itself is not guaranteed to be in column 0, and keeping it would
    # create self-loops.
    keep = src != dst
    src, dst = src[keep], dst[keep]

    # Edge weights are Euclidean distances between the endpoints.
    weights = np.linalg.norm(points[src] - points[dst], axis=1)

    G.add_weighted_edges_from((int(i), int(j), float(w)) for i, j, w in zip(src, dst, weights))
    return G

In [5]:
def bresenham_line(x0, y0, x1, y1):
    """
    Generate the integer pixels of the segment (x0, y0) -> (x1, y1), endpoints
    included, using Bresenham's algorithm. Coordinates are rounded to the
    nearest integer first.
    """
    xa, ya = int(round(x0)), int(round(y0))
    xb, yb = int(round(x1)), int(round(y1))
    span_x, span_y = abs(xb - xa), abs(yb - ya)
    step_x = 1 if xa < xb else -1
    step_y = 1 if ya < yb else -1
    cx, cy = xa, ya

    if span_x > span_y:
        # Shallow line: x advances every step, y only when the error underflows.
        acc = span_x / 2.0
        while cx != xb:
            yield cx, cy
            acc -= span_y
            if acc < 0:
                cy += step_y
                acc += span_x
            cx += step_x
    else:
        # Steep (or degenerate) line: y is the driving axis.
        acc = span_y / 2.0
        while cy != yb:
            yield cx, cy
            acc -= span_x
            if acc < 0:
                cx += step_x
                acc += span_y
            cy += step_y

    # The final endpoint is emitted after either loop finishes.
    yield cx, cy

def prune_edges_by_map(G, result_set, max_water_crossings=2):
    """
    Prunes illegal edges by counting the water pixels (value 0 in
    ``result_set.map_with_stations``) on the rasterised line between the two
    endpoints of each edge.

    Every node of `G` is kept, even when all of its edges are pruned, so a
    later shortest-path query on such a node reports "no path" instead of
    failing on a missing node.

    :param G: graph whose nodes index ``result_set.all_points`` and whose
        edges carry a 'weight' attribute
    :param result_set: experiment providing ``all_points`` and ``map_with_stations``
    :param max_water_crossings: largest number of water pixels an edge may
        cross and still be kept
    :return: new ``nx.Graph`` with all nodes of G and the surviving edges,
        carrying their original weights
    """
    terrain = result_set.map_with_stations
    n_rows, n_cols = terrain.shape[0], terrain.shape[1]

    pruned_graph = nx.Graph()
    pruned_graph.add_nodes_from(G.nodes())
    for i, j, weight in G.edges(data='weight'):
        x0, y0 = result_set.all_points[i]
        x1, y1 = result_set.all_points[j]
        # The first point coordinate indexes the map's first axis.
        # Pixels falling outside the map are ignored rather than counted as water.
        water_crossings = sum(
            terrain[x, y] == 0
            for x, y in bresenham_line(x0, y0, x1, y1)
            if 0 <= x < n_rows and 0 <= y < n_cols
        )
        if water_crossings <= max_water_crossings:
            pruned_graph.add_edge(i, j, weight=weight)
    return pruned_graph

In [306]:
def find_shortest_paths(G, stations):
    """
    Finds weighted shortest paths between all ordered pairs of distinct stations on G.

    A station is looked up in G under its own label first and under
    ``str(station)`` second. This allows the same integer station indices to be
    used both on graphs with integer nodes (as built by ``create_nuclei_graph``)
    and on graphs with string nodes (as read back by
    ``nx.read_weighted_edgelist``). Stations that are not in G at all, and
    pairs without a connecting path, are left out of the results.

    :param G: weighted graph (edge attribute 'weight')
    :param stations: indexable sequence of station identifiers
    :return: ``(paths, distances)``. ``paths[source]`` is a list of
        ``(length, path, target)`` tuples in station order, where ``path``
        uses G's node labels. ``distances[source][target]`` is the path length.
        Both are keyed by the station identifiers exactly as given.
    """
    paths = defaultdict(list)      # list of shortest paths per source
    distances = defaultdict(dict)  # (source, target)-dictionary with distances

    def to_node(station):
        # Prefer the label as given; fall back to its string form; None if absent.
        if station in G:
            return station
        label = str(station)
        return label if label in G else None

    nodes = [to_node(station) for station in stations]

    for i, source in enumerate(stations):
        source_node = nodes[i]
        if source_node is None:
            continue
        # A single Dijkstra run per source yields every reachable target at once,
        # instead of one run per (source, target) pair.
        lengths, routes = nx.single_source_dijkstra(G, source=source_node, weight='weight')
        for j, target in enumerate(stations):
            target_node = nodes[j]
            if i == j or target_node is None or target_node not in lengths:
                continue
            paths[source].append((lengths[target_node], routes[target_node], target))
            distances[source][target] = lengths[target_node]

    return paths, distances

In [7]:
def build_refined_station_network(G, results, proximities):
    """
    For every proximity p, build a graph in which each station is joined to
    the p stations nearest to it by shortest-path length on G.

    The joining paths are copied edge by edge from G, with G's weights, and
    node labels are converted to strings. Side effects:
    - each graph is written to ``<results.path>.p<p>.weighted.edgelist``;
    - the station-to-station distances are written to
      ``<results.path>.distances.json``.

    Returns ``(graphs, distances)``, with one graph per entry of `proximities`.
    """
    graphs = [nx.Graph() for _ in proximities]

    all_paths, distances = find_shortest_paths(G, results.station_indices)

    for source, candidates in all_paths.items():
        # Nearest targets first; each (length, path, target) tuple is ranked by length.
        ranked = sorted(candidates, key=lambda entry: entry[0])

        for graph, proximity in zip(graphs, proximities):
            for _length, path, _target in ranked[:proximity]:
                graph.add_weighted_edges_from(
                    (str(u), str(v), G[u][v]['weight'])
                    for u, v in zip(path, path[1:])
                )

    for graph, proximity in zip(graphs, proximities):
        nx.write_weighted_edgelist(graph, results.path + f'.p{proximity}.weighted.edgelist')

    with open(results.path + '.distances.json', 'w') as f:
        json.dump(distances, f)

    return graphs, distances

In [302]:
def simplify_graph(G, results):
    """
    Simplifies a graph G by contracting non-station nodes of degree 2.

    A degree-2 node ``n`` with neighbours ``left`` and ``right`` is replaced by
    a single ``left``-``right`` edge. That edge's weight is the sum of the two
    removed edge weights, so path lengths through ``n`` are preserved. The
    node is kept when ``left`` and ``right`` are already directly connected,
    because a simple ``nx.Graph`` cannot hold a second, parallel edge.

    :param G: weighted graph whose node labels convert to station indices via ``int``
    :param results: experiment providing ``station_indices``; stations are never removed
    :return: simplified copy of G (G itself is not modified)
    """
    G_contracted = G.copy()
    stations = results.station_indices

    # Contracting a node leaves every other node's degree unchanged: each
    # neighbour swaps its edge to `node` for the new left-right edge. So the
    # candidate set can only shrink, and can be computed once, in a
    # deterministic order, instead of being rebuilt after every contraction.
    candidates = [n for n, d in G_contracted.degree() if d == 2 and int(n) not in stations]

    for node in candidates:
        neighbours = list(G_contracted.neighbors(node))
        if len(neighbours) != 2:
            # A lone self-loop also has degree 2; there is nothing to contract.
            continue
        left, right = neighbours
        if G_contracted.has_edge(left, right):  # Edge already exists, cannot overwrite it
            continue
        combined_weight = G_contracted[node][left]['weight'] + G_contracted[node][right]['weight']
        G_contracted.add_edge(left, right, weight=combined_weight)
        G_contracted.remove_node(node)

    return G_contracted

In [378]:
def calculate_network_cost(G):
    """
    Returns the cost of the whole network: the sum of every edge's 'weight' attribute.
    """
    total_weight = G.size(weight='weight')
    return total_weight

def calculate_mean_travel_time(G, stations):
    """
    Calculates the mean travel time on the network.

    Each ordered pair of distinct stations contributes its shortest-path
    length on G. A pair without a path is charged the total network cost
    (``calculate_network_cost(G)``) instead. The per-source averages are then
    averaged again.

    :param G: weighted graph (edge attribute 'weight')
    :param stations: station identifiers, passed through to ``find_shortest_paths``
    :return: mean travel time as a numpy float; nan (with a RuntimeWarning)
        when there are fewer than two stations
    """
    _, distances = find_shortest_paths(G, stations)
    unreachable_cost = calculate_network_cost(G)

    per_source = []
    for origin in stations:
        reachable = distances.get(origin, {})
        trip_costs = [reachable.get(dest, unreachable_cost) for dest in stations if dest != origin]
        per_source.append(np.mean(trip_costs))
    return np.mean(per_source)
    
def calculate_network_vulnerability(G, ref_travel_time, stations):
    """
    Calculates the mean vulnerability of the network over single-edge removals.

    For each edge, the vulnerability is the absolute change of the mean travel
    time when that edge is removed. If the removal disconnects the graph, the
    total weight of the intact graph is added on top.

    NOTE(review): ``calculate_mean_travel_time`` charges unreachable pairs the
    cost of the graph it is given. After a removal, that cost is lower than
    for the intact graph used for `ref_travel_time`, so the two baselines
    differ slightly — confirm this is intended.

    :param G: weighted graph (edge attribute 'weight'); not modified
    :param ref_travel_time: mean travel time of the intact graph
    :param stations: station identifiers, passed through to ``calculate_mean_travel_time``
    :return: mean vulnerability over all edges of G
    """
    disconnect_penalty = calculate_network_cost(G)

    def edge_vulnerability(edge):
        # Work on a copy so G itself is never modified.
        reduced = G.copy()
        reduced.remove_edge(*edge)
        impact = np.abs(ref_travel_time - calculate_mean_travel_time(reduced, stations))
        if not nx.is_connected(reduced):
            impact += disconnect_penalty
        return impact

    return np.mean([edge_vulnerability(edge) for edge in tqdm(G.edges())])

In [384]:
def process_file(path: str) -> ResultSet:
    """
    Load one experiment ``.npz`` archive into a ResultSet.

    The archive must contain the arrays 'all_trails', 'seed', 'nuclei',
    'stations', 'map_with_stations' and 'start_pos'; a missing one raises KeyError.
    """
    keys = ('all_trails', 'seed', 'nuclei', 'stations', 'map_with_stations', 'start_pos')
    with np.load(path) as archive:
        # Indexing an NpzFile reads the member fully into memory, so the arrays
        # stay valid after the archive is closed.
        arrays = {key: archive[key] for key in keys}
    return ResultSet(path=path, **arrays)

def _load_or_build_station_graphs(path, result_set, proximities):
    """Return one station graph per proximity: from the on-disk cache if complete, else freshly built."""
    edgelist_paths = [path + f'.p{p}.weighted.edgelist' for p in proximities]
    cache_files = [path + '.distances.json', *edgelist_paths]
    if all(os.path.exists(f) for f in cache_files):
        return [nx.read_weighted_edgelist(f) for f in edgelist_paths]

    # The nuclei mesh and its pruning are only needed when the cache is unusable.
    full_graph = create_nuclei_graph(result_set)
    pruned_graph = prune_edges_by_map(full_graph, result_set)
    graphs, _distances = build_refined_station_network(pruned_graph, result_set, proximities)
    return graphs


def _evaluate_station_graph(graph, result_set):
    """Simplify `graph` and return (total_cost, mean_travel_time, vulnerability, is_connected)."""
    simplified = simplify_graph(graph, result_set)
    stations = result_set.station_indices
    total_cost = calculate_network_cost(simplified)
    mean_travel_time = calculate_mean_travel_time(simplified, stations)
    vulnerability = calculate_network_vulnerability(simplified, mean_travel_time, stations)
    # nx.is_connected raises on a graph without nodes; report that case as not connected.
    is_connected = simplified.number_of_nodes() > 0 and nx.is_connected(simplified)
    return total_cost, mean_travel_time, vulnerability, is_connected


def process_folder(dir_path: str, proximities):
    """
    Process a folder containing experiment result files.

    For every ``*.npz`` file in `dir_path` (in sorted order), the station
    graph for each proximity is either loaded from the cache files written
    next to it or built from scratch. Each graph is then simplified and
    scored. One row per (file, proximity) is written to ``results.csv`` in
    `dir_path`.

    :param dir_path: folder holding the experiment ``.npz`` files (trailing slash optional)
    :param proximities: sequence of proximities, i.e. how many nearest
        stations each station is linked to
    :return: the results DataFrame that was written to disk
    """
    files = sorted(glob(os.path.join(dir_path, "*.npz")))

    rows = []
    for path in tqdm(files):
        result_set = process_file(path)
        graphs = _load_or_build_station_graphs(path, result_set, proximities)

        for proximity, graph in zip(proximities, graphs):
            total_cost, mean_travel_time, network_vulnerability, is_connected = \
                _evaluate_station_graph(graph, result_set)

            print(total_cost, mean_travel_time, network_vulnerability)

            rows.append((proximity, result_set.start_pos[0], result_set.start_pos[1], result_set.seed,
                         total_cost, mean_travel_time, network_vulnerability, is_connected))

    df = pd.DataFrame(rows, columns=["proximity", "start_pos_x", "start_pos_y", "seed",
                                     "total_cost", "mean_travel_time", "vulnerability", "is_connected"])
    df.to_csv(os.path.join(dir_path, "results.csv"))
    return df

In [385]:
# Proximities to evaluate: each station is linked to its 1..5 nearest stations.
proximities = range(1, 6)

# Process both experiment batches with the same proximities.
for experiment_dir in ("../experiment_outputs_different_starts/",
                       "../experiment_outputs_same_starts/"):
    process_folder(experiment_dir, proximities)

  0%|          | 0/26 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

2063.4483423387715 1887.6422112139878 2119.272318576484


  0%|          | 0/71 [00:00<?, ?it/s]

4460.449312732917 599.6157267425207 1494.3068709621775


  0%|          | 0/137 [00:00<?, ?it/s]

6518.087761402611 456.2612466010431 534.7827257915848


  0%|          | 0/178 [00:00<?, ?it/s]

KeyboardInterrupt: 