In [1]:
import polars as pl
import numpy as np

# Load the data
df = pl.read_csv('./connectome_graph.csv')

# Extract arrays
source_nodes = df[df.columns[0]].to_numpy().astype(np.int64)
target_nodes = df[df.columns[1]].to_numpy().astype(np.int64)
edge_weights = df[df.columns[2]].to_numpy().astype(np.int64)


In [2]:
# Get unique node IDs and map to indices
unique_nodes = np.unique(np.concatenate((source_nodes, target_nodes)))
node_id_to_index = {node_id: idx for idx, node_id in enumerate(unique_nodes)}
index_to_node_id = {idx: node_id for node_id, idx in node_id_to_index.items()}

# Map node IDs to indices in edge lists
source_indices = np.array([node_id_to_index[node_id] for node_id in source_nodes])
target_indices = np.array([node_id_to_index[node_id] for node_id in target_nodes])

In [3]:
import jax.numpy as jnp
import jax
from jax import random

# Convert to JAX arrays
source_indices = jnp.array(source_indices)
target_indices = jnp.array(target_indices)
edge_weights = jnp.array(edge_weights)

# Compute maximum edge weight
total_edge_weight = jnp.sum(edge_weights)

# Normalize edge weights
edge_weights = edge_weights / total_edge_weight


In [4]:
num_nodes = len(unique_nodes)
key = random.PRNGKey(0)
positions = random.uniform(key, shape=(num_nodes,))

In [5]:

sorted_indices = jnp.argsort(positions)

# Create a mapping from node index to order in the sequence
node_order = jnp.zeros(num_nodes, dtype=int)
node_order = node_order.at[sorted_indices].set(jnp.arange(num_nodes))

edge_directions = node_order[target_indices] - node_order[source_indices]

forward_edges = edge_directions > 0

total_forward_weight_initial = jnp.sum(edge_weights * forward_edges)

total_edge_weight = jnp.sum(edge_weights)
original_total_edge_weights = total_edge_weight
# Compute the percentage of forward edge weight
percentage_forward_initial = 100 * float(total_forward_weight_initial) / (total_edge_weight)

# Print the results
print(f"Total Forward Edge Weight (Initial): {total_forward_weight_initial}")
print(f"Percentage of Forward Edge Weight (Initial): {percentage_forward_initial:.2f}%")

Total Forward Edge Weight (Initial): 0.5002518892288208
Percentage of Forward Edge Weight (Initial): 50.03%


In [6]:
def normalize_positions(positions):
    # Normalize positions to have zero mean and unit variance
    mean = jnp.mean(positions)
    std = jnp.std(positions) + 1e-8  # Add epsilon to avoid division by zero
    positions = (positions - mean) / std
    return positions

@jax.jit
def objective_function(positions, source_indices, target_indices, edge_weights, epoch):
    # Get positions of source and target nodes
    pos_source = positions[source_indices]
    pos_target = positions[target_indices]
    
    delta = pos_target - pos_source
    
    beta = 2.0  
    sigmoid = safe_sigmoid(beta * delta)
    # Compute the weighted sum
    total_forward_weight = jnp.sum(edge_weights * sigmoid)
    # if epoch % 10 == 0:
    #     jax.debug.print("Sig{y} | {z} | {q}", y=jnp.sum(sigmoid), z=jnp.sum(pos_source), q=jnp.sum(pos_target))

    return -total_forward_weight

def safe_sigmoid(x):

    return jnp.where(
        x >= 0,
        1 / (1 + jnp.exp(-x)),
        jnp.exp(x) / (1 + jnp.exp(x))
    )


In [7]:
import jax
import optax

# Create the gradient function
objective_grad = jax.grad(objective_function)

# Define the optimizer
optimizer = optax.adam(learning_rate=0.005)

# Initialize optimizer state
opt_state = optimizer.init(positions)

In [8]:
def calculate_metric(positions):
    # Get final positions
    final_positions = positions

    # Sort node indices based on positions
    sorted_indices = jnp.argsort(final_positions)

    # Map back to node IDs
    # ordered_node_ids = [index_to_node_id[int(idx)] for idx in sorted_indices]

    # Create a mapping from node index to order in the final sequence
    node_order = jnp.zeros(num_nodes)
    node_order = node_order.at[sorted_indices].set(jnp.arange(num_nodes))

    # Compute the direction of each edge in the final ordering
    edge_directions = node_order[target_indices] - node_order[source_indices]

    # Edges pointing forward have positive edge_directions
    forward_edges = edge_directions > 0

    # Compute the total forward edge weight
    total_forward_weight = jnp.sum(edge_weights * forward_edges)
    total_edge_weight = jnp.sum(edge_weights)

    print(f"Total Forward Edge Weight: {total_forward_weight}")
    print(f"Percentage of Forward Edge Weight: {100 * float(total_forward_weight) / float(total_edge_weight):.2f}%")
    return 100 * float(total_forward_weight) / float(total_edge_weight)

In [9]:
from tqdm import tqdm
num_epochs = 1000

for epoch in range(num_epochs):
    # Compute gradients
    loss, grads = jax.value_and_grad(objective_function)(positions, source_indices, target_indices, edge_weights, epoch)
    
    # Update positions
    updates, opt_state = optimizer.update(grads, opt_state)
    positions = optax.apply_updates(positions, updates)
    positions = normalize_positions(positions)
    # Optional: Print progress every 100 epochs
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss: {-loss}")
        calculate_metric(positions)


Epoch 0, Loss: 0.5001700520515442
Total Forward Edge Weight: 0.5031319856643677
Percentage of Forward Edge Weight: 50.31%
Epoch 100, Loss: 0.5667829513549805
Total Forward Edge Weight: 0.5832952857017517
Percentage of Forward Edge Weight: 58.33%
Epoch 200, Loss: 0.6246671080589294
Total Forward Edge Weight: 0.6540396809577942
Percentage of Forward Edge Weight: 65.40%
Epoch 300, Loss: 0.6644023656845093
Total Forward Edge Weight: 0.708562970161438
Percentage of Forward Edge Weight: 70.86%
Epoch 400, Loss: 0.6842442750930786
Total Forward Edge Weight: 0.7384061813354492
Percentage of Forward Edge Weight: 73.84%
Epoch 500, Loss: 0.6930047273635864
Total Forward Edge Weight: 0.7542293667793274
Percentage of Forward Edge Weight: 75.42%
Epoch 600, Loss: 0.6966937780380249
Total Forward Edge Weight: 0.7612147331237793
Percentage of Forward Edge Weight: 76.12%
Epoch 700, Loss: 0.69835364818573
Total Forward Edge Weight: 0.7648435235023499
Percentage of Forward Edge Weight: 76.48%
Epoch 800, Lo

# Simulated Annealing

In [10]:
import jax
import jax.numpy as jnp
from jax import random, jit

def compute_total_forward_weight(positions, source_indices, target_indices, edge_weights_normalized):
    node_order = jnp.argsort(positions)
    node_ranks = jnp.zeros_like(node_order)
    node_ranks = node_ranks.at[node_order].set(jnp.arange(len(node_order)))
    edge_directions = node_ranks[target_indices] - node_ranks[source_indices]
    forward_edges = edge_directions > 0
    total_forward_weight = jnp.sum(edge_weights_normalized * forward_edges)
    return total_forward_weight

# Function to swap two nodes (JAX-compatible)
def swap_positions(positions, i, j):
    positions = positions.at[i].set(positions[j])
    positions = positions.at[j].set(positions[i])
    return positions

# Simulated annealing step function (JAX-compatible)
@jit
def simulated_annealing_step(i, state):
    key, temp, current_positions, current_weight, best_positions, best_weight, source_indices, target_indices, edge_weights_normalized = state

    # Generate two random indices to swap
    key, subkey = random.split(key)
    num_nodes = len(current_positions)
    i, j = random.choice(subkey, num_nodes, shape=(2,), replace=False)  # Corrected: use `num_nodes` instead of `current_positions.shape`

    new_positions = swap_positions(current_positions, i, j)

    # Compute the new forward weight
    new_weight = compute_total_forward_weight(new_positions, source_indices, target_indices, edge_weights_normalized)
    delta_weight = new_weight - current_weight

    # Compute the acceptance probability
    accept = jnp.exp(delta_weight / temp)
    should_accept = (delta_weight > 0) | (random.uniform(subkey) < accept)

    current_positions = jnp.where(should_accept, new_positions, current_positions)
    current_weight = jnp.where(should_accept, new_weight, current_weight)

    # Update the best solution if the new solution is better
    best_positions = jnp.where(new_weight > best_weight, new_positions, best_positions)
    best_weight = jnp.where(new_weight > best_weight, new_weight, best_weight)

    temp = temp * 0.995

    return key, temp, current_positions, current_weight, best_positions, best_weight, source_indices, target_indices, edge_weights_normalized

# Simulated annealing loop (JAX-compatible)
@jit
def simulated_annealing(positions, source_indices, target_indices, edge_weights_normalized, initial_temp=1.0, final_temp=0.001, max_iter=10000):
    key = random.PRNGKey(0)

    # Initial conditions
    current_positions = positions.copy()
    current_weight = compute_total_forward_weight(current_positions, source_indices, target_indices, edge_weights_normalized)
    best_positions = current_positions.copy()
    best_weight = current_weight
    temp = initial_temp

    # Simulated annealing loop
    def body_fn(i, state):
        return simulated_annealing_step(i, state)

    final_state = jax.lax.fori_loop(0, max_iter, body_fn, (key, temp, current_positions, current_weight, best_positions, best_weight, source_indices, target_indices, edge_weights_normalized))

    _, _, _, _, best_positions, best_weight, _, _, _ = final_state
    return best_positions, best_weight


In [13]:
initial_positions = random.uniform(random.PRNGKey(42), shape=(len(positions),))  # Random initial positions
best_positions, best_weight = simulated_annealing(positions, source_indices, target_indices, edge_weights, max_iter=100)

print("Best forward edge weight:", best_weight, "Best positions:", calculate_metric(best_positions))


Total Forward Edge Weight: 0.7682592272758484
Percentage of Forward Edge Weight: 76.83%
Best forward edge weight: 0.7682592 Best positions: None


In [14]:
best_positions

Array([ 0.24743298, -1.042535  ,  0.22388445, ..., -0.55253047,
        1.9726529 , -1.7673079 ], dtype=float32)