In [None]:
import numpy as np
import MDAnalysis as mda
from MDAnalysis.analysis import distances
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from collections import defaultdict
from scipy import stats
import pandas as pd

import sys, os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../../')))
from md_styler import MDStyler


In [None]:
# ---------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------
inner_msd_path   = "../figures/diffusion/inner_again_msd_data.txt"
inner_drift_path = "../figures/diffusion/inner_again_drift.txt"  # NOTE(review): not read in this cell

# ---------------------------------------------------------------------
# Load data
# ---------------------------------------------------------------------
data = np.loadtxt(inner_msd_path)
time_ps   = data[:, 0]   # ps
# Column 3 is the file's z-component MSD (Å²). Per the author's convention this
# corresponds to the physical Y direction — NOTE(review): confirm.
msd_z_A2  = data[:, 3]


# ---------------------------------------------------------------------
# Unit conversion: ps -> μs, Å² -> nm²
# ---------------------------------------------------------------------
time_us      = time_ps / 1e6
msd_z_nm2    = msd_z_A2 * 0.01  # Å² -> nm²
N            = len(time_us)

# ---------------------------------------------------------------------
# Main fit window: fixed 0.2–0.8 (in fraction of trajectory)
# ---------------------------------------------------------------------
main_start_frac = 0.2
main_end_frac   = 0.8

main_start_idx = int(main_start_frac * N)
main_end_idx   = int(main_end_frac * N)

# A line fit needs at least two points; fail loudly instead of inside linregress.
if main_end_idx - main_start_idx < 2:
    raise ValueError(
        f"Main fit window [{main_start_frac}, {main_end_frac}] contains "
        f"{main_end_idx - main_start_idx} points (N = {N}); need at least 2."
    )

slope_main, intercept_main, r_value_main, p_value_main, std_err_main = stats.linregress(
    time_us[main_start_idx:main_end_idx],
    msd_z_nm2[main_start_idx:main_end_idx]
)

# 1D Einstein relation: MSD = 2 D t  ->  D = slope / 2  (nm²/μs)
D_main = slope_main / 2.0
r2_main = r_value_main**2

# Fitted line for the main window. Use the regression intercept so the plotted
# line is exactly the least-squares fit whose slope and R² are reported;
# anchoring it on a single data point would shift it by that point's residual.
fit_line_main = slope_main * time_us + intercept_main

# ---------------------------------------------------------------------
# Window sampling around the main window to estimate sensitivity
# ---------------------------------------------------------------------
# Grid of start/end fractions bracketing the main [0.2, 0.8] window:
#   start fractions: 0.10, 0.20, 0.30, 0.40, 0.50
#   end fractions:   0.75, 0.8125, 0.875, 0.9375, 1.00
start_fracs = np.linspace(0.1, 0.5, 5)
end_fracs   = np.linspace(0.75, 1, 5)

D_samples = []
fit_lines = []  # (i0, i1, full-length fit line) per window, for plotting faint sampled fits

min_points = 20  # require at least this many points per fit

for s_frac in start_fracs:
    for e_frac in end_fracs:
        if e_frac <= s_frac:
            continue

        i0 = int(s_frac * N)
        i1 = int(e_frac * N)

        if i1 - i0 < min_points:
            continue

        # Linear fit for this window
        slope, intercept, r_val, p_val, std_err = stats.linregress(
            time_us[i0:i1],
            msd_z_nm2[i0:i1]
        )
        D = slope / 2.0  # 1D Einstein relation, same convention as D_main
        D_samples.append(D)

        # Least-squares line for this window (regression intercept, not a
        # data-point anchor), stored over the full time axis; the plot slices [i0:i1].
        local_fit = slope * time_us + intercept
        fit_lines.append((i0, i1, local_fit))

D_samples = np.array(D_samples)

if len(D_samples) == 0:
    print("⚠️ No valid sampled windows found; only main fit will be shown.")
    D_mean = D_main
    D_std  = 0.0
else:
    D_mean = np.mean(D_samples)
    D_std  = np.std(D_samples)  # population spread (ddof=0) of D across sampled windows

print(f"Main window  : [{main_start_frac:.2f}, {main_end_frac:.2f}]")
print(f"D_main       = {D_main:.3e} nm²/μs")
print(f"D_mean(win)  = {D_mean:.3e} nm²/μs")
print(f"σ_window(D)  = {D_std:.3e} nm²/μs")
print(f"R²(main fit) = {r2_main:.3f}")

# ---------------------------------------------------------------------
# Plot using MDStyler
# ---------------------------------------------------------------------
sty = MDStyler().apply()
# colors: data, main fit, sampled fits
c_data, c_main_fit, c_sample = sty.get_palette(3)

# 1) Figure
fig, ax = sty.fig_horizontal()

# 2) Plot MSD data
ax.plot(
    time_us,
    msd_z_nm2,
    color=c_data,
    **sty.as_aa(label="Inner MSD$_z$")
)

# 3) Plot all sampled fit lines faintly, each only over its own window
for i0, i1, local_fit in fit_lines:
    ax.plot(
        time_us[i0:i1],
        local_fit[i0:i1],
        color=c_sample,
        linewidth=1.0,
        alpha=0.15,
        linestyle=sty.ls_cg  # MDStyler line style for faint sampled fits
    )

# 4) Plot the main fit line (bold) over the main window only
ax.plot(
    time_us[main_start_idx:main_end_idx],
    fit_line_main[main_start_idx:main_end_idx],
    color=c_main_fit,
    **sty.as_cg(label="Linear fit")
)

# 5) Labels, title
ax.set_xlabel(r"Time ($\mu$s)")
ax.set_ylabel(r"MSD$_z$ (nm$^2$)")

# Value and window-sensitivity uncertainty share the unit: "(D ± σ) nm²/μs"
ax.set_title(
    r"Bundle D$_z$ = ("
    + f"{D_main:.2e}"
    + r" $\pm$ "
    + f"{D_std:.1e}"
    + r") nm$^2$/$\mu$s"
)

ax.legend(frameon=False)

plt.show()

In [15]:

# Configuration
CUTOFF = 5.0  # Angstroms; an O at or below this Ca–O distance counts as bound
FRAME_RATE = 20  # ps per frame (a timestep, despite the name)
O_PER_MICELLE = 440  # O1+O2 atoms per micelle; O selection assumed ordered micelle by micelle

# Load trajectory
u = mda.Universe('../SOAP/cores.gro', '../SOAP/cores.xtc')
print(f"Loaded trajectory with {len(u.trajectory)} frames")
print(f"Total time: {len(u.trajectory) * FRAME_RATE / 1000:.2f} ns")

# Every time axis below is built from the hard-coded FRAME_RATE; warn if the
# trajectory file reports a different spacing (MDAnalysis time unit: ps).
if not np.isclose(u.trajectory.dt, FRAME_RATE):
    print(f"⚠️ FRAME_RATE = {FRAME_RATE} ps but trajectory reports dt = {u.trajectory.dt} ps")

# Select atom groups
ca_atoms = u.select_atoms('name Ca')
o_atoms = u.select_atoms('name O1 O2')
n_ca = len(ca_atoms)
n_o = len(o_atoms)

print(f"\nSystem composition:")
print(f"  Ca atoms: {n_ca}")
print(f"  O atoms: {n_o}")
print(f"  Micelles: {n_o // O_PER_MICELLE}")

# get_micelle_id() assigns O atoms to micelles in blocks of O_PER_MICELLE,
# so a remainder means that mapping does not fit this system.
if n_o % O_PER_MICELLE != 0:
    print(f"⚠️ {n_o} O atoms is not a multiple of O_PER_MICELLE = {O_PER_MICELLE}; "
          "micelle IDs will be inconsistent.")

def get_micelle_id(o_index):
    """Map an oxygen's index within ``o_atoms`` to its micelle ID.

    Oxygens are grouped into consecutive blocks of ``O_PER_MICELLE``: block 0 is
    micelle 0, block 1 is micelle 1, and so on. NOTE(review): this presumes the
    O1/O2 selection is ordered micelle by micelle — confirm against the topology.
    """
    micelle_id, _offset_in_micelle = divmod(o_index, O_PER_MICELLE)
    return micelle_id

def compute_coordination(ca_pos, o_pos, box):
    """For each Ca position, find the O atoms within ``CUTOFF``.

    Args:
        ca_pos: Ca coordinates, one row per Ca.
        o_pos: O coordinates, one row per O.
        box: box dimensions passed through to ``distance_array`` for periodic images.

    Returns:
        A list with one entry per Ca: an integer array of column indices into
        ``o_pos`` whose distance to that Ca is <= ``CUTOFF``.
    """
    # Full (n_ca, n_o) matrix of minimum-image distances
    pair_dist = distances.distance_array(ca_pos, o_pos, box=box)
    return [np.flatnonzero(pair_dist[row] <= CUTOFF) for row in range(len(ca_pos))]

def analyze_bridges(coordination_list):
    """Return indices of Ca ions whose bound O atoms belong to two or more micelles.

    Args:
        coordination_list: one entry per Ca, each a sequence of bound O indices
            (as produced by ``compute_coordination``).

    Returns:
        List of Ca indices (positions in ``coordination_list``) that bridge micelles.
    """
    return [
        ca_idx
        for ca_idx, bound in enumerate(coordination_list)
        # A single bond cannot bridge, so skip the micelle lookup for those Ca
        if len(bound) >= 2
        and len({get_micelle_id(o_idx) for o_idx in bound}) >= 2
    ]

Loaded trajectory with 5000 frames
Total time: 100.00 ns

System composition:
  Ca atoms: 3300
  O atoms: 6600
  Micelles: 15


# Analyse trajectory: per-frame Ca–O coordination and inter-micelle Ca bridges

In [16]:
print("\nAnalyzing trajectory...")
coordination_history = []  # per frame: list (one entry per Ca) of bound-O index arrays
coordination_counts = []   # per frame: [n_0bond, n_1bond, n_2bond, n_3plus]
bridge_counts = []         # per frame: number of Ca bound to O from >= 2 micelles

for ts in u.trajectory:
    frame_coord = compute_coordination(ca_atoms.positions, o_atoms.positions, ts.dimensions)
    coordination_history.append(frame_coord)

    # Bin every Ca by bond count; slot 3 collects all Ca with 3 or more bonds
    frame_counts = [0] * 4
    for bound_o in frame_coord:
        frame_counts[min(len(bound_o), 3)] += 1
    coordination_counts.append(frame_counts)

    bridge_counts.append(len(analyze_bridges(frame_coord)))

coordination_counts = np.array(coordination_counts)
bridge_counts = np.array(bridge_counts)
times = np.arange(len(u.trajectory)) * FRAME_RATE / 1000  # in ns

print(f"Analysis complete for {len(coordination_history)} frames")



Analyzing trajectory...
Analysis complete for 5000 frames


In [None]:

# ---------------------------------------------------------------
# 1. Apply styler
# ---------------------------------------------------------------
sty = MDStyler().apply()

# ---------------------------------------------------------------
# 2. Per frame, count bridging Ca (bound O from >= 2 micelles),
#    split into the exactly-2-bond and the 3+-bond categories
# ---------------------------------------------------------------
n_frames = len(coordination_history)

bridge_2 = np.zeros(n_frames, dtype=int)   # bridging Ca with exactly 2 bonds
bridge_3p = np.zeros(n_frames, dtype=int)  # bridging Ca with 3+ bonds

for frame_idx, frame_coord in enumerate(coordination_history):
    # Bond counts of only those Ca in this frame that bridge micelles
    bridging_bond_counts = [
        len(bound)
        for bound in frame_coord
        if len(bound) >= 2 and len({get_micelle_id(o_idx) for o_idx in bound}) >= 2
    ]
    n_two_bond = sum(1 for n in bridging_bond_counts if n == 2)
    bridge_2[frame_idx] = n_two_bond
    bridge_3p[frame_idx] = len(bridging_bond_counts) - n_two_bond

# ---------------------------------------------------------------
# 3. Sanity check: the split must add back up to the bridge_counts
#    array from the analysis loop, frame by frame
# ---------------------------------------------------------------
total_bridges_from_split = bridge_2 + bridge_3p
if np.array_equal(total_bridges_from_split, bridge_counts):
    print("✅ Bridge split matches original `bridge_counts` exactly.")
else:
    print("⚠️ WARNING: total bridges from 2/3+ split "
          "do NOT match `bridge_counts`.")
    print("   max abs diff:", np.abs(total_bridges_from_split - bridge_counts).max())

# ---------------------------------------------------------------
# 4. Averages and std devs over frames
# ---------------------------------------------------------------
avg_counts = coordination_counts.mean(axis=0)   # order: [0, 1, 2, 3+]
std_counts = coordination_counts.std(axis=0)    # frame-to-frame spread (ddof=0)

avg_bridge_2 = bridge_2.mean()
avg_bridge_3p = bridge_3p.mean()

# Non-bridging share of the 2- and 3+-bond categories, clipped at 0 against float noise
avg_nonbridge_2 = max(0.0, avg_counts[2] - avg_bridge_2)
avg_nonbridge_3p = max(0.0, avg_counts[3] - avg_bridge_3p)

In [None]:
from matplotlib.patches import Patch
# ---------------------------------------------------------------
# 5. Build *normalised* stacked bar heights (percent of all Ca)
# ---------------------------------------------------------------
categories = ['0 bonds\n(Free)', '1 bond', '2 bonds', '3+ bonds']
x = np.arange(len(categories))

# Normalise by total Ca and express as a percentage
avg_counts_frac = avg_counts / n_ca * 100.0
std_counts_frac = std_counts / n_ca * 100.0

avg_bridge_2_frac = avg_bridge_2 / n_ca * 100.0
avg_bridge_3p_frac = avg_bridge_3p / n_ca * 100.0

avg_nonbridge_2_frac = avg_counts_frac[2] - avg_bridge_2_frac
avg_nonbridge_3p_frac = avg_counts_frac[3] - avg_bridge_3p_frac

# Clip tiny negatives from float noise
avg_nonbridge_2_frac = max(0.0, avg_nonbridge_2_frac)
avg_nonbridge_3p_frac = max(0.0, avg_nonbridge_3p_frac)

heights_main = np.array([
    avg_counts_frac[0],       # 0 bonds
    avg_counts_frac[1],       # 1 bond
    avg_nonbridge_2_frac,     # 2 bonds non-bridge
    avg_nonbridge_3p_frac,    # 3+ bonds non-bridge
])

heights_bridge = np.array([
    0.0,
    0.0,
    avg_bridge_2_frac,        # 2 bonds bridging
    avg_bridge_3p_frac,       # 3+ bonds bridging
])

# Top of each stacked bar = whole category (non-bridging + bridging segment)
heights_total = heights_main + heights_bridge

# ---------------------------------------------------------------
# 6. Colors from MDStyler (matte VMD-like)
# ---------------------------------------------------------------
c_free = sty.get_color("black")
c_1 = sty.get_color("gray")
c_2_non = sty.get_color("cyan")
c_3p_non = sty.get_color("box")
c_bridge = sty.get_color("green")

colors_main = [c_free, c_1, c_2_non, c_3p_non]
colors_bridge = [ "none", "none", c_bridge, c_bridge ]

# ---------------------------------------------------------------
# 7. Create figure
# ---------------------------------------------------------------
fig, ax = sty.fig_horizontal()

bars_main = ax.bar(
    x,
    heights_main,
    color=colors_main,
    alpha=0.9,
    edgecolor="none",
    zorder=2,
)

bars_bridge = ax.bar(
    x,
    heights_bridge,
    bottom=heights_main,
    color=colors_bridge,
    alpha=0.9,
    edgecolor="none",
    zorder=3,
)

# std_counts_frac is the frame-to-frame spread of the *whole* category, so the
# error bar belongs at the top of the stacked bar, not at the top of the
# non-bridging segment (which would put it mid-stack for the 2/3+ bars).
ax.errorbar(
    x,
    heights_total,
    yerr=std_counts_frac,
    fmt="none",
    ecolor="k",
    capsize=3,
    zorder=4,
)

# ---------------------------------------------------------------
# 8. Labels
# ---------------------------------------------------------------
ax.set_xticks(x)
ax.set_xticklabels(categories)
ax.set_ylabel('Fraction of Ca ions (%)')
ax.set_title('Average Ca Coordination Distribution')
ax.set_ylim(0, 100)  # full 0–100% scale

ax.grid(True, axis='y', alpha=0.3, zorder=0)

# ---------------------------------------------------------------
# 9. Percentage labels
# ---------------------------------------------------------------
# Category totals, placed just above the error-bar cap so they don't overlap it
for i, (bar, total_frac) in enumerate(zip(bars_main, avg_counts_frac)):
    label_y = heights_total[i] + std_counts_frac[i]
    ax.text(
        bar.get_x() + bar.get_width()/2.,
        label_y,
        f'{total_frac:.1f}%',
        ha='center',
        va='bottom',
        fontsize=sty.base_fontsize * 0.9,
    )

# Inside the 2- and 3+-bond bars: share of that category that bridges micelles
for i, (bridge_val, total_frac) in enumerate(zip(heights_bridge, avg_counts_frac)):
    if bridge_val > 0 and total_frac > 0:
        frac = bridge_val / total_frac * 100.0
        ax.text(
            x[i],
            heights_main[i] + bridge_val * 0.5,
            f'{frac:.0f}%',
            ha='center',
            va='center',
            fontsize=sty.base_fontsize * 0.75,
            color="white",
        )

# Legend
legend_elements = [
    Patch(facecolor=c_bridge, label='Fibre Bridges'),
]
ax.legend(handles=legend_elements, frameon=False, loc='upper left')
plt.show()
