In [73]:
#
# Emre Alca
# University of Pennsylvania
# Created on Sat Nov 22 2025
#

In [None]:
import numpy as np
import trimesh
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (needed for 3D)
# from tqdm import tqdm
import sys
import time
from IPython.display import HTML

%matplotlib widget

from src import spindle_state as ss
import pickle
import os


Need to implement:

Basic functionality for short timescales:
- [x] spindle state
- [x] pulling forces of a given spindle state
- [x] pushing forces of a given spindle state
- [x] equations of motion
- [x] time evolution

Turnover regimes
- [ ] Markovian catastrophe and nucleation 
- [ ] gradient descent
    - [x] cost function
    - [x] sampling spatial nucleation distribution
    - [x] sampling spatial catastrophe distribution
    - [x] sampling length nucleation distribution
    - [x] sampling length catastrophe distribution
    - [ ] spindle update
    - [ ] stochastic gradient descent loop



In [75]:
# Six lattice sites: the unit vectors along +x, -x, +y, -y, +z, -z.
test_spindle_lattice = np.array(
    [[1, 0, 0], [-1, 0, 0],
     [0, 1, 0], [0, -1, 0],
     [0, 0, 1], [0, 0, -1]]
)

# Expected microtubule vectors for the test lattice above.
expected_mt_vecs = np.array(
    [[0.5, 0.0, 0.0], [-1.5, 0.0, 0.0],
     [-0.5, 1.0, 0.0], [-0.5, -1.0, 0.0],
     [-0.5, 0.0, 1.0], [-0.5, 0.0, -1.0]]
)

# One entry per lattice site.
test_spindle_state = np.array([1, 1, 3, 3, 1, 1])

# MTOC starts at the origin.
test_spindle = ss.Spindle(
    np.zeros(3, dtype=int),
    test_spindle_state,
    test_spindle_lattice,
    timestep_size=0.01,
)

In [83]:
# Scratch check: the candidate cost-delta threshold discussed in the notes.
-0.0001

-0.0001

In [77]:

# Test the gradient-descent spindle update.
#
# Starting from an unstable MTOC position, every call to
# `gradient_descent_spindle_update` must:
#   * change the spindle state,
#   * change the MTOC position,
#   * not increase the cost (up to COST_TOLERANCE).
#
# NOTE: this cell mutates `test_spindle` in place, so re-running it continues
# from wherever the previous run left off.

N_UPDATE_STEPS = 100
# Slack for floating-point noise in calc_cost. An explicit tolerance replaces
# comparing costs rounded to 6 decimals. Rounding can reject a ~1e-7 increase
# that happens to cross a rounding boundary, and can accept larger ones.
COST_TOLERANCE = 1e-6

# Argument semantics are defined in src.spindle_state.
test_spindle.add_microtubules([1, 0])
test_spindle.set_mtoc_pos(np.array([0, 0.5, 0]))

# Number of attempts each update needed (a rough "difficulty" measure).
update_attempts = []

for step in range(N_UPDATE_STEPS):
    old_spindle_state = np.copy(test_spindle.spindle_state)
    old_mtoc_pos = np.copy(test_spindle.mtoc_pos)
    old_cost = test_spindle.calc_cost()

    attempts = test_spindle.gradient_descent_spindle_update()
    update_attempts.append(attempts)

    new_spindle_state = np.copy(test_spindle.spindle_state)
    new_mtoc_pos = np.copy(test_spindle.mtoc_pos)
    new_cost = test_spindle.calc_cost()

    # np.array_equal also handles shape changes, where `==` would fail or be ambiguous.
    assert not np.array_equal(new_spindle_state, old_spindle_state), (
        f"step {step}: spindle state did not change"
    )
    assert not np.array_equal(new_mtoc_pos, old_mtoc_pos), (
        f"step {step}: MTOC position did not change"
    )
    assert new_cost <= old_cost + COST_TOLERANCE, (
        f"step {step}: cost increased from {old_cost} to {new_cost}"
    )

I now have a working spindle update: I can take a spindle state and modify it so that the cost function does not increase (the test above asserts this at every update, up to a small numerical tolerance).

The issue right now is that no turnover is forced: the system can find a stable position and stay there. How can I force turnover? Perhaps by setting the threshold for accepting a spindle modification to a cost delta of -0.0001 rather than 0?

The other question is how I should save the data from a simulation.

Data I need to save:
- time
- spindle state
- mtoc position

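I also want to keep track of when spindle updates occur and how 'difficult' they are, by tracking the number of attempts each update needs.
I suppose … *(TODO: unfinished note. The saved trajectory entries loaded below already include a `num_update_attempts` field.)*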

In [78]:
# Round-trip check: `as_dict()` must report every parameter and
# hyperparameter with the same value as the live Spindle object.
spindle_dict = test_spindle.as_dict()

# -- parameters (arrays) --
# np.array_equal returns False on a shape mismatch instead of raising or
# broadcasting, which `(a == b).all()` can do.
for attr_name in ('mtoc_pos', 'spindle_state', 'lattice_sites'):
    assert np.array_equal(spindle_dict[attr_name], getattr(test_spindle, attr_name)), (
        f"as_dict() mismatch for {attr_name!r}"
    )

# -- hyperparameters (scalars) --
for attr_name in (
    'f_pull_0',
    'rigidity',
    'friction_coefficient',
    'growth_rate',
    'stall_force',
    'drag_factor',
    'boundary_radius',
    'timestep_size',
    'max_total_mt_length',
    'mt_len_cost_punishment_degree',
    'cytoplasmic_catastrophe_rate',
):
    assert spindle_dict[attr_name] == getattr(test_spindle, attr_name), (
        f"as_dict() mismatch for {attr_name!r}"
    )

In [79]:
# Load a saved simulation and extract the MTOC trajectory.
#
# The data directory defaults to the author's local checkout. Override it with
# the SPINDLE_DATA_DIR environment variable to run this elsewhere.
# Earlier runs in the same directory:
#   self_2026-01-15_15-16-05.pkl, self_2026-01-16_08-13-21.pkl
DATA_DIR = os.environ.get(
    'SPINDLE_DATA_DIR',
    '/Users/emrealca/Documents/Penn/flatiron-microtubules/simulations/data',
)
file_path = os.path.join(DATA_DIR, 'self_2026-01-16_08-31-04.pkl')

PLOT_STRIDE = 10  # keep every 10th snapshot for the animation

# SECURITY: pickle.load can execute arbitrary code. Only load files you trust.
with open(file_path, "rb") as f:
    data = pickle.load(f)

trajectory = data['trajectory']

# Trajectory keys are the snapshot times (in insertion order). Each snapshot
# stores the MTOC position under 'mtoc_pos'.
ts = list(trajectory.keys())

# reshape(-1, 3) keeps the (N, 3) layout even when the trajectory is empty.
mtoc_positions = np.array(
    [trajectory[key]['mtoc_pos'] for key in ts], dtype=float
).reshape(-1, 3)
xs, ys, zs = mtoc_positions.T

# Columns: t, x, y, z
reformatted_trajectory = np.column_stack([ts, mtoc_positions])

traj_to_plot = reformatted_trajectory[::PLOT_STRIDE]

In [80]:
# Peek at the first saved snapshot (dict keys iterate in insertion order).
first_t = list(trajectory)[0]
trajectory[first_t]

{'spindle_state': array([1.000, 1.000, 3.000, 3.000, 1.000, 1.000]),
 'mtoc_pos': array([0.500, 0.500, 0.500]),
 'boundary_violated': False,
 'cost': np.float64(0.7500178201697102),
 'num_update_attempts': 6}

In [None]:
# ---- Animate the MTOC trajectory ----
# Each frame draws the trail up to frame i, a marker at the current MTOC
# position, and the time / cost / spindle state of that saved snapshot.

# ---- Extract columns ----
t = traj_to_plot[:, 0]
x = traj_to_plot[:, 1]
y = traj_to_plot[:, 2]
z = traj_to_plot[:, 3]

# Fixed axis limits so the view does not rescale during the animation.
AXIS_LIMITS = (-1, 1)

# ---- Set up figure ----
fig = plt.figure(figsize=(6, 5))
ax = fig.add_subplot(111, projection='3d')
ax.set_xlim(*AXIS_LIMITS)
ax.set_ylim(*AXIS_LIMITS)
ax.set_zlim(*AXIS_LIMITS)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
ax.set_title('MTOC Trajectory')

# static components
target = ax.scatter([0], [0], [0], s=40, color='green', label='origin')

# Line for the trail and a scatter for the head
(line,) = ax.plot([], [], [], lw=2, color='tab:blue')
head = ax.scatter([], [], [], s=40, color='crimson', label='mtoc')

# Text overlays in the top-left corner (axes-fraction coordinates).
spindle_state_text = ax.text2D(0.02, 0.95, '', transform=ax.transAxes)
cost_text = ax.text2D(0.02, 0.90, '', transform=ax.transAxes)
time_text = ax.text2D(0.02, 0.85, '', transform=ax.transAxes)

ax.legend()

def init():
    """Reset every animated artist (trail, head, all three text overlays)."""
    line.set_data([], [])
    line.set_3d_properties([])
    # Scatter3D has no public setter for 3D offsets; _offsets3d is the usual workaround.
    head._offsets3d = ([], [], [])
    time_text.set_text('')
    cost_text.set_text('')
    spindle_state_text.set_text('')
    return line, head, time_text, cost_text, spindle_state_text

def update(i):
    """Draw frame ``i``: trail up to i, head at i, and snapshot i's metadata."""
    line.set_data(x[:i+1], y[:i+1])
    line.set_3d_properties(z[:i+1])
    head._offsets3d = (np.array([x[i]]), np.array([y[i]]), np.array([z[i]]))

    # t[i] was built from the trajectory keys, so it indexes the saved snapshot.
    snapshot = trajectory[t[i]]
    time_text.set_text(f't = {t[i]:.2f}')
    cost_text.set_text(f'cost = {np.round(snapshot["cost"], 3)}')
    spindle_state_text.set_text(f'spindle state: {snapshot["spindle_state"].astype(int)}')
    return line, head, time_text, cost_text, spindle_state_text

# `interval` is the delay between frames in milliseconds. Matplotlib timers use
# whole milliseconds and truncate fractional values, so the previous 0.001 was
# never honored. 1 ms is the fastest meaningful setting.
# For roughly real-time playback (if t is in seconds) you could use instead:
#   interval_ms = np.mean(np.diff(t)) * 1000
ani = FuncAnimation(fig, update, frames=len(t), init_func=init,
                    interval=1, blit=False, repeat=True)

plt.tight_layout()
plt.show()
