In [1]:
import polars as pl
import numpy as np

# Read the connectome edge list; its first three columns are taken positionally
# as (source node ID, target node ID, edge weight).
df = pl.read_csv('./connectome_graph.csv')

source_col, target_col, weight_col = df.columns[0], df.columns[1], df.columns[2]

# Materialize each column as an int64 NumPy array
source_nodes, target_nodes, edge_weights = (
    df.get_column(col).to_numpy().astype(np.int64)
    for col in (source_col, target_col, weight_col)
)


In [2]:
# Map each raw node ID to a dense index 0..N-1 (and back).
# np.unique returns the IDs sorted, so a node's index is its rank among all IDs.
unique_nodes = np.unique(np.concatenate((source_nodes, target_nodes)))
node_id_to_index = {node_id: idx for idx, node_id in enumerate(unique_nodes)}
index_to_node_id = dict(enumerate(unique_nodes))

# Vectorized lookup instead of a per-edge Python dict lookup: every endpoint ID is
# present in the sorted, duplicate-free unique_nodes, so searchsorted returns
# exactly its index (and yields an integer array even for an empty edge list).
source_indices = np.searchsorted(unique_nodes, source_nodes)
target_indices = np.searchsorted(unique_nodes, target_nodes)

# Cheap invariant check: indices map back to the original IDs
assert np.array_equal(unique_nodes[source_indices], source_nodes)
assert np.array_equal(unique_nodes[target_indices], target_nodes)

In [3]:
import jax.numpy as jnp
import jax
from jax import random

# Convert edge endpoints to JAX arrays
source_indices = jnp.array(source_indices)
target_indices = jnp.array(target_indices)

# Total edge weight, computed on the host in float64 *before* handing the weights
# to JAX: with 64-bit mode disabled (JAX's default), jnp.array would truncate the
# int64 weights to int32, and summing them in int32 can overflow.
edge_weights_host = np.asarray(edge_weights, dtype=np.float64)
total_edge_weight = edge_weights_host.sum()
if not total_edge_weight > 0:
    raise ValueError("edge weights must have a positive total to be normalized")

# Normalize so the weights sum to 1 (each weight becomes a fraction of the total)
edge_weights = jnp.asarray(edge_weights_host / total_edge_weight)


In [4]:
# Random initial layout: one scalar coordinate per node, drawn uniformly from
# [0, 1) with a fixed PRNG seed so the starting point is reproducible.
key = random.PRNGKey(0)
num_nodes = unique_nodes.shape[0]
positions = random.uniform(key, (num_nodes,))

In [5]:

# Rank each node in the random initial ordering: node_order[i] is node i's
# position in the sequence once nodes are sorted by coordinate.
sorted_indices = jnp.argsort(positions)
node_order = jnp.zeros(num_nodes, dtype=int).at[sorted_indices].set(jnp.arange(num_nodes))

# An edge is "forward" when its target is ranked strictly after its source.
edge_directions = node_order[target_indices] - node_order[source_indices]
forward_edges = edge_directions > 0

# Baseline totals (the weights are already normalized, so the total is ~1)
total_edge_weight = jnp.sum(edge_weights)
original_total_edge_weights = total_edge_weight
total_forward_weight_initial = jnp.sum(edge_weights * forward_edges)

# Share of the total weight carried by forward edges, as a percentage
percentage_forward_initial = 100 * float(total_forward_weight_initial) / total_edge_weight

print(f"Total Forward Edge Weight (Initial): {total_forward_weight_initial}")
print(f"Percentage of Forward Edge Weight (Initial): {percentage_forward_initial:.2f}%")

Total Forward Edge Weight (Initial): 0.5002518892288208
Percentage of Forward Edge Weight (Initial): 50.03%


In [6]:
def normalize_positions(positions):
    """Standardize a position vector to zero mean and (approximately) unit variance.

    A small epsilon (1e-8) is added to the standard deviation so that a constant
    position vector maps to all zeros instead of dividing by zero.
    """
    centered = positions - jnp.mean(positions)
    return centered / (jnp.std(positions) + 1e-8)

@jax.jit
def objective_function(positions, source_indices, target_indices, edge_weights, epoch, beta=10.0):
    """Negated, smoothed total weight of "forward" edges for a 1-D node layout.

    Each edge contributes ``weight * sigmoid(beta * (pos[target] - pos[source]))``,
    a differentiable stand-in for the indicator "target is placed after source".
    Minimizing the returned value therefore maximizes the soft forward edge weight.

    Args:
        positions: 1-D array with one scalar position per node.
        source_indices, target_indices: per-edge node indices into ``positions``.
        edge_weights: per-edge weights.
        epoch: unused; kept so existing callers keep working. Under ``jax.jit`` it
            is a traced value, so Python control flow on it (e.g.
            ``if epoch % 10 == 0``) raises TracerBoolConversionError.
        beta: sharpness of the sigmoid surrogate (default 10.0, the previous
            hard-coded value). Larger values approach the hard indicator but
            concentrate the gradient near ``delta == 0``.

    Returns:
        Scalar ``-sum(edge_weights * sigmoid(beta * delta))``.
    """
    # Signed displacement along the line for every edge
    delta = positions[target_indices] - positions[source_indices]
    soft_forward = safe_sigmoid(beta * delta)
    return -jnp.sum(edge_weights * soft_forward)

def safe_sigmoid(x):
    """Numerically stable logistic sigmoid ``1 / (1 + exp(-x))`` that is safe under autodiff.

    A hand-rolled two-branch form,
    ``jnp.where(x >= 0, 1/(1+exp(-x)), exp(x)/(1+exp(x)))``, is NOT gradient-safe.
    ``jnp.where`` evaluates both branches for every element. For |x| above ~88
    (float32) the discarded branch overflows ``exp`` to inf. ``where`` hides that
    in the forward pass, but reverse-mode autodiff multiplies the zero cotangent
    by inf and yields NaN gradients.

    ``jax.nn.sigmoid`` is stable in the forward pass and differentiates as
    ``sigmoid(x) * (1 - sigmoid(x))``, so gradients stay finite for any input.
    """
    return jax.nn.sigmoid(x)


In [7]:
import jax
import optax

# Adam over the node positions with a fixed learning rate of 0.01; its state
# is initialized from the current positions array.
optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(positions)

# Gradient of objective_function w.r.t. its first argument (positions).
# NOTE(review): the training loop builds its own jax.value_and_grad, so this
# handle appears unused in this notebook.
objective_grad = jax.grad(objective_function)

In [8]:
def calculate_metric(positions, sources=None, targets=None, weights=None):
    """Report how much edge weight points "forward" in the ordering induced by ``positions``.

    Nodes are ranked by ascending position. An edge counts as forward when its
    target's rank is strictly greater than its source's rank (self-loops never count).

    Args:
        positions: 1-D array with one scalar position per node.
        sources, targets: per-edge node indices. Default to the notebook-level
            ``source_indices`` / ``target_indices``, looked up at call time.
        weights: per-edge weights. Defaults to the notebook-level ``edge_weights``.

    Returns:
        The percentage (0-100) of total edge weight carried by forward edges.
        The total forward weight and this percentage are also printed.
    """
    if sources is None:
        sources = source_indices
    if targets is None:
        targets = target_indices
    if weights is None:
        weights = edge_weights

    n = positions.shape[0]
    sorted_indices = jnp.argsort(positions)
    # Integer ranks (matching the initial-metric cell): jnp.zeros' default float32
    # dtype cannot represent every rank exactly once there are more than 2**24 nodes.
    node_order = jnp.zeros(n, dtype=int).at[sorted_indices].set(jnp.arange(n))

    forward_edges = node_order[targets] > node_order[sources]
    total_forward_weight = jnp.sum(weights * forward_edges)
    total_edge_weight = jnp.sum(weights)
    percentage = 100 * float(total_forward_weight) / float(total_edge_weight)

    print(f"Total Forward Edge Weight: {total_forward_weight}")
    print(f"Percentage of Forward Edge Weight: {percentage:.2f}%")
    return percentage

In [9]:
from tqdm import tqdm
num_epochs = 10000
log_every = 100  # epochs between progress reports


@jax.jit
def train_step(positions, opt_state, source_indices, target_indices, edge_weights, epoch):
    """One Adam step on the soft forward-weight objective, then re-standardize positions.

    Builds the gradient transform once and runs the loss/grad, optimizer update
    and normalization as a single compiled step. The edge arrays are passed as
    arguments rather than closed over, so they are not baked into the program
    as constants.

    Returns:
        (loss, new_positions, new_opt_state). ``loss`` is objective_function's
        value (the negated soft forward weight) at the incoming positions.
    """
    loss, grads = jax.value_and_grad(objective_function)(
        positions, source_indices, target_indices, edge_weights, epoch
    )
    updates, opt_state = optimizer.update(grads, opt_state)
    positions = normalize_positions(optax.apply_updates(positions, updates))
    return loss, positions, opt_state


for epoch in tqdm(range(num_epochs)):
    loss, positions, opt_state = train_step(
        positions, opt_state, source_indices, target_indices, edge_weights, epoch
    )
    if epoch % log_every == 0:
        # -loss is the soft forward edge weight being maximized
        print(f"Epoch {epoch}, Loss: {-loss}")
        calculate_metric(positions)


  0%|          | 0/10000 [00:00<?, ?it/s]


TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function objective_function at /var/folders/lh/qs4dckx57_x56y2s1_8blvm80000gq/T/ipykernel_85127/1686088588.py:8 for jit. This concrete value was not available in Python because it depends on the value of the argument epoch.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError