In [1]:
import networkx as nx
import numpy as np
from dataclasses import dataclass, field
from scipy.spatial import cKDTree
from scipy.spatial import distance_matrix
from tqdm.notebook import trange
from glob import glob
import pandas as pd

In [2]:
@dataclass
class ResultSet:
    """
    Helper class to keep all results from a single experiment together.

    ``all_points`` is not passed to the constructor; ``__post_init__`` derives it:
    ``nuclei`` is reshaped into (x, y) rows, rows that are exactly (0, 0) are
    dropped, and the ``stations`` rows are stacked after them. Station rows are
    therefore always the last ``len(stations)`` rows of ``all_points``.
    """

    all_trails: np.ndarray
    seed: np.ndarray
    nuclei: np.ndarray  # any shape whose size is even; reshaped to (-1, 2) pairs
    stations: np.ndarray  # assumed (n_stations, 2) coordinates — TODO confirm
    map_with_stations: np.ndarray
    start_pos: np.ndarray
    all_points: np.ndarray = field(init=False)  # non-zero nuclei followed by stations

    def __post_init__(self):
        points = self.nuclei.reshape(-1, 2)
        # All-zero rows are presumably unused nucleus slots, not real nuclei.
        points = points[~np.all(points == 0, axis=1)]
        self.all_points = np.vstack([points, self.stations])

    def __str__(self):
        return f"Experiment with start pos {self.start_pos}, seed {self.seed}"

In [12]:
def create_nuclei_graph(results: ResultSet, k: int = 10) -> nx.Graph:
    """
    Creates a proximity graph over all slime nuclei and stations.

    Every row of ``results.all_points`` becomes a node (its row index), connected
    to its ``k`` closest other points. Edges are undirected and weighted by the
    Euclidean distance between their endpoints.

    Parameters
    ----------
    results : ResultSet
        Experiment whose ``all_points`` are connected.
    k : int, optional
        Neighbours per point (default 10). It is clamped to ``N - 1``, so small
        point sets do not make the KD-tree return out-of-range padding indices.

    Returns
    -------
    nx.Graph
        Weighted proximity graph. It is empty when fewer than two points exist.
    """
    points = results.all_points
    n_points = len(points)
    G = nx.Graph()

    n_neighbours = min(k, n_points - 1)
    if n_neighbours < 1:
        return G

    tree = cKDTree(points)
    # Ask for one extra neighbour because a point is normally its own nearest hit.
    _, idxs = tree.query(points, k=n_neighbours + 1)

    # idxs has shape (N, n_neighbours + 1); columns 1: are the neighbours.
    src = np.repeat(np.arange(n_points), n_neighbours)
    dst = idxs[:, 1:].reshape(-1)
    # With duplicate coordinates, column 0 may be the duplicate rather than the
    # point itself, which pushes the point into its own neighbour list.
    # Drop those self-loops.
    keep = src != dst
    src, dst = src[keep], dst[keep]

    # Edge weights are Euclidean distances, shape (n_edges,).
    weights = np.linalg.norm(points[src] - points[dst], axis=1)

    G.add_weighted_edges_from(
        (int(i), int(j), float(w)) for i, j, w in zip(src, dst, weights)
    )
    return G

In [13]:
def bresenham_line(x0, y0, x1, y1):
    """
    Generate the integer pixels on the segment from (x0, y0) to (x1, y1).

    Both endpoints are rounded to the nearest integer first. Uses Bresenham's
    algorithm and yields both endpoints, starting at the first.
    """
    xa, ya, xb, yb = (int(round(v)) for v in (x0, y0, x1, y1))
    span_x = abs(xb - xa)
    span_y = abs(yb - ya)
    step_x = 1 if xa < xb else -1
    step_y = 1 if ya < yb else -1
    cx, cy = xa, ya

    if span_x > span_y:
        # Shallow line: x advances every step, y only when the error underflows.
        error = span_x / 2.0
        while cx != xb:
            yield cx, cy
            error -= span_y
            if error < 0:
                cy += step_y
                error += span_x
            cx += step_x
    else:
        # Steep (or degenerate) line: y advances every step.
        error = span_y / 2.0
        while cy != yb:
            yield cx, cy
            error -= span_x
            if error < 0:
                cx += step_x
                error += span_y
            cy += step_y

    # The final endpoint is emitted after either loop finishes.
    yield cx, cy

def prune_edges_by_map(G, result_set, max_water_crossings=2):
    """
    Return a copy of ``G`` without edges whose straight segment crosses too much water.

    Each edge (i, j) is rasterised from ``result_set.all_points[i]`` to
    ``result_set.all_points[j]`` with ``bresenham_line``. The pixels of
    ``result_set.map_with_stations`` equal to 0 (water) are counted; pixels
    outside the map are ignored. The edge is kept when the count is at most
    ``max_water_crossings``. Note that this is a number of water *pixels*, not
    of separate water bodies.

    NOTE(review): points are used as (row, col) = (x, y) indices into the map;
    confirm this matches the simulation's coordinate convention.

    Parameters
    ----------
    G : nx.Graph
        Graph whose nodes index ``result_set.all_points`` and whose edges carry 'weight'.
    result_set : ResultSet
        Provides the point coordinates and the 2-D map.
    max_water_crossings : int, optional
        Maximum number of water pixels tolerated on an edge (default 2).

    Returns
    -------
    nx.Graph
        New graph containing every node of ``G``, including nodes left without
        edges, plus the surviving edges with their original weights.
    """
    grid = result_set.map_with_stations
    n_rows, n_cols = grid.shape[0], grid.shape[1]

    pruned_graph = nx.Graph()
    # Keep every node, even ones whose edges are all pruned. Otherwise a station
    # can vanish, and a shortest-path query from it raises NodeNotFound.
    pruned_graph.add_nodes_from(G.nodes)

    for i, j, weight in G.edges(data='weight'):
        x0, y0 = result_set.all_points[i]
        x1, y1 = result_set.all_points[j]
        water_pixels = sum(
            1
            for x, y in bresenham_line(x0, y0, x1, y1)
            if 0 <= x < n_rows and 0 <= y < n_cols and grid[x, y] == 0
        )
        if water_pixels <= max_water_crossings:
            pruned_graph.add_edge(i, j, weight=weight)
    return pruned_graph

In [14]:
def build_refined_station_network(G, results, proximities):
    """
    Build refined station networks for several proximity levels at once.

    For each station and each ``p`` in ``proximities``, the station is linked to
    its ``p`` nearest other stations. "Nearest" means shortest weighted path
    through ``G`` (the mesh graph). Stations are the last ``len(results.stations)``
    rows of ``results.all_points``.

    Parameters
    ----------
    G : nx.Graph
        Mesh graph with 'weight' edge attributes; nodes index ``results.all_points``.
    results : ResultSet
        Supplies the point and station counts.
    proximities : iterable of int
        Numbers of nearest stations to connect, e.g. ``range(1, 6)``.

    Returns
    -------
    list of (int, nx.Graph, set)
        One tuple per proximity ``p``, holding:
        - ``p`` itself;
        - the station graph, whose edge weights are shortest-path lengths;
        - the set of mesh edges ``(min(u, v), max(u, v))`` used by those paths.
        Stations that cannot reach any other station get no edges.
    """
    # Indices of station sources in all_points (they are last in the array).
    n_points = results.all_points.shape[0]
    station_indices = np.arange(n_points - len(results.stations), n_points)

    output = [(p, nx.Graph(), set()) for p in proximities]

    for i in trange(len(station_indices)):
        source = station_indices[i]
        if source not in G:
            # Isolated station: Dijkstra would raise NodeNotFound, and it can
            # reach no other station anyway.
            continue

        # One full search gives shortest paths to every reachable node. This
        # replaces one targeted search per (source, target) pair.
        lengths, node_paths = nx.single_source_dijkstra(G, source=source, weight='weight')
        candidates = [
            (lengths[target], node_paths[target], target)
            for j, target in enumerate(station_indices)
            if j != i and target in lengths
        ]
        # Stable sort: equal lengths keep station order.
        candidates.sort(key=lambda c: c[0])

        for p, p_graph, path_edges in output:
            for length, path, target in candidates[:p]:
                if p_graph.has_edge(source, target):
                    continue
                p_graph.add_edge(source, target, weight=length)
                path_edges.update(
                    (min(a, b), max(a, b)) for a, b in zip(path[:-1], path[1:])
                )

    return output

In [6]:
def calculate_network_cost(G, path_edges):
    """
    Calculates the cost of the total network.

    The cost is the sum of the 'weight' attribute of each edge in
    ``path_edges`` that is present in ``G``. Edges missing from ``G``
    contribute nothing, and an empty collection costs 0.0.
    """
    cost = 0.0
    for edge in path_edges:
        u, v = edge[0], edge[1]
        if not G.has_edge(u, v):
            continue
        cost += G[u][v]['weight']
    return cost

def calculate_mean_travel_time(G):
    """
    Calculates the mean travel time on the network.

    An edge's travel time is its 'weight' attribute. The result is the total
    edge weight divided by the number of edges.

    Returns
    -------
    float
        Mean edge weight, or NaN when ``G`` has no edges (instead of raising
        ``ZeroDivisionError``).
    """
    n_edges = G.size()
    if n_edges == 0:
        return float('nan')
    return G.size(weight='weight') / n_edges

def calculate_network_vulnerability(G, ref_travel_time):
    """
    Calculates the mean single-edge-failure vulnerability of the network.

    Each edge is removed in turn from a copy of ``G``:
    - If the copy stays connected, the edge's vulnerability is
      ``|ref_travel_time - mean edge weight of the copy|``.
    - If removing the edge disconnects the graph, the vulnerability is the
      weight of the removed edge.

    Parameters
    ----------
    G : nx.Graph
        Weighted graph ('weight' edge attribute). It is not modified.
    ref_travel_time : float
        Reference mean travel time, typically ``calculate_mean_travel_time(G)``.

    Returns
    -------
    float
        Mean vulnerability over all edges, or NaN when ``G`` has no edges.
    """
    if G.number_of_edges() == 0:
        return float('nan')

    vulnerabilities = []
    for u, v, weight in G.edges(data='weight'):
        # Work on a copy so the caller's graph is never modified.
        G_copy = G.copy()
        G_copy.remove_edge(u, v)
        if nx.is_connected(G_copy):
            # The mean is computed only in the connected case. Removing the sole
            # edge of a graph leaves zero edges, and the mean would divide by zero.
            vuln = np.abs(ref_travel_time - calculate_mean_travel_time(G_copy))
        else:
            vuln = weight
        vulnerabilities.append(vuln)
    return np.mean(vulnerabilities)

In [15]:
def process_file(path: str) -> ResultSet:
    """
    Load a single ``.npz`` result file into a ``ResultSet``.

    The archive must contain the arrays 'all_trails', 'seed', 'nuclei',
    'stations', 'map_with_stations' and 'start_pos'. They are read eagerly
    while the file is open, so the returned object does not depend on the
    file afterwards.
    """
    keys = ("all_trails", "seed", "nuclei", "stations", "map_with_stations", "start_pos")
    with np.load(path) as archive:
        arrays = {key: archive[key] for key in keys}
    return ResultSet(**arrays)

def process_folder(dir_path: str, proximities):
    """
    Process a folder containing experiment result files.

    For every ``*.npz`` file in ``dir_path`` the function:
    1. builds the nuclei proximity graph;
    2. prunes edges that cross water;
    3. builds one refined station network per proximity.

    It then records cost, mean travel time, vulnerability and connectivity for
    each network. Files are processed in sorted order so the CSV rows are
    reproducible across runs and platforms (``glob`` order is arbitrary).

    The metrics are written to ``<dir_path>/results.csv``. Nothing is returned.
    """
    files = sorted(glob(dir_path + "/*.npz"))

    rows = []
    for path in files:
        result_set = process_file(path)
        full_graph = create_nuclei_graph(result_set)
        pruned_graph = prune_edges_by_map(full_graph, result_set)

        for proximity, refined_graph, path_edges in build_refined_station_network(pruned_graph, result_set, proximities):
            total_cost = calculate_network_cost(pruned_graph, path_edges)
            mean_travel_time = calculate_mean_travel_time(refined_graph)
            network_vulnerability = calculate_network_vulnerability(refined_graph, mean_travel_time)
            # nx.is_connected raises on a graph with no nodes; report that case as not connected.
            is_connected = refined_graph.number_of_nodes() > 0 and nx.is_connected(refined_graph)

            rows.append((proximity, result_set.start_pos[0], result_set.start_pos[1], result_set.seed,
                         total_cost, mean_travel_time, network_vulnerability, is_connected))

    df = pd.DataFrame(rows, columns=["proximity", "start_pos_x", "start_pos_y", "seed",
                                     "total_cost", "mean_travel_time", "vulnerability", "is_connected"])
    df.to_csv(dir_path + "/results.csv")

In [11]:
# Numbers of nearest stations each station is linked to: p in {1, 2, 3, 4, 5}.
proximities = range(1, 6)

# Evaluate both experiment batches; each folder gets its own results.csv.
for experiment_dir in ("../experiment_outputs_different_starts",
                       "../experiment_outputs_same_starts"):
    process_folder(experiment_dir, proximities)