In [19]:

import warnings
# Suppress the warning
warnings.filterwarnings("ignore", category=UserWarning, message="Failed to guess the mass for the following atom types")
from collections import Counter
import glob
import os
import MDAnalysis as mda
from tqdm.notebook import tqdm
import itertools
from MDAnalysis.analysis.base import (AnalysisBase,
                                      AnalysisFromFunction,
                                      analysis_class)
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
import pickle as pkl
import re
import seaborn as sns
import shutil
from scipy import ndimage

# import sparse

# parallelising mda
import multiprocessing as mp
from multiprocessing import Pool
from functools import partial

# MDA
from MDAnalysis.analysis.leaflet import LeafletFinder
from MDAnalysis.analysis import distances

from MDAnalysis.analysis import align
from MDAnalysis.analysis.rms import rmsd
from MDAnalysis.analysis.rdf import InterRDF
import MDAnalysis.analysis.msd as msd
import os
import logging
from scipy.interpolate import Rbf
import matplotlib.pyplot as plt
from tqdm import tqdm
import plotly.graph_objects as go

from collections import defaultdict # Used in calculate_scd_profile
from scipy.stats import binned_statistic_2d # Used in calculate_order_heatmap

In [3]:
import MDAnalysis as mda
import numpy as np
from tqdm import tqdm
from typing import Tuple, Optional

def calculate_xz_density_map(
    universe: mda.Universe,
    selection: str = 'name PO*',
    grid_size: int = 100,
    start_frame: int = 0,
    end_frame: Optional[int] = None,
    stride: int = 1,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Computes a 2D number density map in the XZ plane, averaged over frames.

    Atom positions are binned onto a fixed 2D grid with ``np.histogram2d``; the
    accumulated counts are divided by the number of analyzed frames and by the
    area of one grid cell to give a number density (atoms/Å²).

    Note:
        The grid is fixed: bin edges and the cell area are taken from the box
        dimensions of the LAST frame of the trajectory (``trajectory[-1]``),
        regardless of ``start_frame``/``end_frame``. This assumes the box does
        not fluctuate significantly. Atoms whose coordinates fall outside
        ``[0, Lx] x [0, Lz]`` (e.g. unwrapped coordinates) are not counted.

    Args:
        universe (mda.Universe):
            MDAnalysis Universe containing the system and trajectory.
        selection (str):
            Atom selection string for the groups to be included in the density map.
        grid_size (int):
            The number of bins to use for each axis (X and Z). Must be >= 1.
        start_frame (int):
            The first frame for analysis. Defaults to the beginning.
        end_frame (Optional[int]):
            The final frame for analysis. Defaults to the end of the trajectory.
        stride (int):
            The step size between frames to analyze.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]:
        - density_map (np.ndarray):
            A 2D float array (grid_size x grid_size) of the number density in
            atoms/Å². Axis 0 is X, axis 1 is Z.
        - x_edges (np.ndarray):
            The array of bin edges along the X-axis (length grid_size + 1).
        - z_edges (np.ndarray):
            The array of bin edges along the Z-axis (length grid_size + 1).

    Raises:
        ValueError: If ``grid_size`` is not positive, the selection is empty,
            the frame range is empty, or the box dimensions are missing/zero.
    """
    if grid_size < 1:
        raise ValueError("grid_size must be a positive integer.")

    atom_group = universe.select_atoms(selection)
    if not atom_group:
        raise ValueError(f"Selection '{selection}' resulted in an empty AtomGroup.")

    # Prepare trajectory slice and total number of frames for averaging
    trajectory_slice = universe.trajectory[start_frame:end_frame:stride]
    n_frames_analyzed = len(trajectory_slice)

    if n_frames_analyzed == 0:
        raise ValueError("The specified frame range resulted in zero frames to analyze.")

    # Define bin edges based on the final trajectory frame's dimensions.
    # The values are copied into plain floats immediately so that later frame
    # changes cannot alter them.
    final_dims = universe.trajectory[-1].dimensions
    if final_dims is None:
        raise ValueError("Universe dimensions are not defined.")
    Lx, Lz = float(final_dims[0]), float(final_dims[2])
    if Lx <= 0 or Lz <= 0:
        raise ValueError("Box dimensions along X and Z must be positive.")
    x_edges = np.linspace(0, Lx, grid_size + 1)
    z_edges = np.linspace(0, Lz, grid_size + 1)

    # np.histogram2d returns float64 counts. An integer accumulator would make
    # the in-place '+=' fail NumPy's same-kind casting rule, so accumulate in
    # float64 (exact for any realistic count).
    total_counts = np.zeros((grid_size, grid_size), dtype=np.float64)

    for ts in tqdm(trajectory_slice, desc="Calculating density map"):
        positions = atom_group.positions
        counts, _, _ = np.histogram2d(
            positions[:, 0], positions[:, 2], bins=[x_edges, z_edges]
        )
        total_counts += counts

    # Area of a single grid cell in Å²
    area_per_bin = (Lx / grid_size) * (Lz / grid_size)

    # Normalize the accumulated counts by the number of frames and bin area
    density_map = total_counts / (n_frames_analyzed * area_per_bin)

    return density_map, x_edges, z_edges

In [4]:
import MDAnalysis as mda
import numpy as np
from tqdm import tqdm
from typing import Tuple, Optional

def count_lipids_by_midplane(
    universe: mda.Universe,
    lipid_selection: str = "name PO*",
    start_frame: int = 0,
    end_frame: Optional[int] = None,
    stride: int = 1,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Splits headgroup atoms into upper/lower leaflets around a per-frame Z midplane.

    For every analyzed frame the midplane is the mean Z-coordinate of the
    selected atoms. Atoms strictly above it count as "upper"; atoms on or below
    it count as "lower". This is cheap, but it is a flat-membrane approximation
    and can misassign lipids in strongly curved bilayers.

    Args:
        universe (mda.Universe):
            The Universe holding topology and trajectory.
        lipid_selection (str):
            Selection string for the headgroup atoms
            (e.g., 'resname POPG and name PO4'). Defaults to "name PO*".
        start_frame (int):
            First frame index to analyze. Defaults to 0.
        end_frame (Optional[int]):
            Frame index to stop at (exclusive, slice semantics). None means
            run to the end of the trajectory.
        stride (int):
            Step between analyzed frames. Defaults to 1.

    Returns:
        Tuple[np.ndarray, np.ndarray]:
        - upper_counts (np.ndarray):
            Integer count of selected atoms above the midplane, one per frame.
        - lower_counts (np.ndarray):
            Integer count of selected atoms on/below the midplane, one per frame.

    Raises:
        ValueError: If the selection matches no atoms.
    """
    headgroups = universe.select_atoms(lipid_selection)
    if not headgroups:
        raise ValueError(f"Selection '{lipid_selection}' resulted in an empty AtomGroup.")

    frames = universe.trajectory[start_frame:end_frame:stride]
    n_frames_analyzed = len(frames)

    # One integer slot per analyzed frame; frames that get skipped stay at 0.
    upper_counts = np.zeros(n_frames_analyzed, dtype=int)
    lower_counts = np.zeros(n_frames_analyzed, dtype=int)

    progress = tqdm(frames, desc="Counting leaflets")
    for frame_index, _ in enumerate(progress):
        heights = headgroups.positions[:, 2]

        # Defensive guard: nothing to count if no coordinates are available.
        if heights.size:
            midplane = heights.mean()
            upper_counts[frame_index] = np.count_nonzero(heights > midplane)
            lower_counts[frame_index] = np.count_nonzero(heights <= midplane)

    return upper_counts, lower_counts

In [5]:
import MDAnalysis as mda
import numpy as np
from tqdm import tqdm
from sklearn.cluster import DBSCAN
from typing import Tuple, Optional, List

def _find_leaflets_in_frame(
    atom_group: mda.AtomGroup,
    eps: float,
    min_samples: int,
    min_cluster_size_fraction: float,
) -> Tuple[float, float]:
    """
    Clusters the current-frame coordinates with DBSCAN and counts residues per leaflet.

    The two largest clusters that each hold at least
    ``min_cluster_size_fraction`` of the atoms are taken as the leaflets; the
    one with the higher mean Z is reported as "upper".

    Args:
        atom_group: Atoms whose current positions are clustered in 3D.
        eps: DBSCAN neighbourhood radius.
        min_samples: DBSCAN minimum neighbourhood population.
        min_cluster_size_fraction: Minimum share of ``atom_group`` a cluster
            must contain to be considered a leaflet.

    Returns:
        (upper_count, lower_count) as floats: the number of unique resids in
        each leaflet, or (nan, nan) when there are too few atoms or fewer than
        two sufficiently large clusters.
    """
    n_total = atom_group.n_atoms

    # DBSCAN is not meaningful with this few points.
    if n_total <= min_samples:
        return np.nan, np.nan

    labels = DBSCAN(eps=eps, min_samples=min_samples).fit(atom_group.positions).labels_

    # Cluster sizes, ignoring the noise label (-1).
    cluster_ids, cluster_sizes = np.unique(labels[labels != -1], return_counts=True)
    size_cutoff = n_total * min_cluster_size_fraction
    candidates = [
        (size, cluster_id)
        for cluster_id, size in zip(cluster_ids, cluster_sizes)
        if size >= size_cutoff
    ]

    # Fewer than two big clusters: the frame cannot be split into two leaflets.
    if len(candidates) < 2:
        return np.nan, np.nan

    # Stable descending sort by size; ties keep ascending label order.
    ranked = sorted(candidates, key=lambda entry: entry[0], reverse=True)
    (_, largest_id), (_, runner_up_id) = ranked[:2]
    largest = atom_group[labels == largest_id]
    runner_up = atom_group[labels == runner_up_id]

    # The cluster sitting higher along Z is the upper leaflet.
    if largest.positions[:, 2].mean() > runner_up.positions[:, 2].mean():
        upper_leaflet, lower_leaflet = largest, runner_up
    else:
        upper_leaflet, lower_leaflet = runner_up, largest

    # Count residues, not atoms, so multi-atom selections are not overcounted.
    n_upper = np.unique(upper_leaflet.resids).size
    n_lower = np.unique(lower_leaflet.resids).size
    return float(n_upper), float(n_lower)


def calculate_leaflet_counts_dbscan(
    universe: mda.Universe,
    lipid_selection: str = "name PO*",
    eps: float = 10.0,
    min_samples: int = 5,
    start_frame: int = 0,
    end_frame: Optional[int] = None,
    stride: int = 1,
    show_progress: bool = True,
    min_cluster_size_fraction: float = 0.1,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Per-frame leaflet populations from 3D DBSCAN clustering of headgroup atoms.

    Each analyzed frame is handed to ``_find_leaflets_in_frame``, which picks
    the two largest sufficiently populated clusters and labels them upper/lower
    by mean Z. Frames without two such clusters are reported as np.nan.

    Args:
        universe (mda.Universe): The Universe holding topology and trajectory.
        lipid_selection (str): Selection for headgroup atoms (e.g., 'name PO4').
        eps (float): DBSCAN neighbourhood radius (Å). System dependent.
        min_samples (int): DBSCAN minimum neighbourhood population. System dependent.
        start_frame (int): First frame index to analyze.
        end_frame (Optional[int]): Frame index to stop at; None runs to the end.
        stride (int): Step between analyzed frames.
        show_progress (bool): Show a tqdm progress bar when True.
        min_cluster_size_fraction (float): Fraction (0.0 to 1.0) of the selected
            atoms a cluster needs in order to qualify as a leaflet.

    Returns:
        Tuple[np.ndarray, np.ndarray]:
        - upper_counts (np.ndarray): Unique residues in the upper leaflet, per frame.
        - lower_counts (np.ndarray): Unique residues in the lower leaflet, per frame.
          Both arrays are float and hold np.nan for ambiguous frames.

    Raises:
        ValueError: If the selection matches no atoms.
    """
    headgroups = universe.select_atoms(lipid_selection)
    if not headgroups:
        raise ValueError(f"Selection '{lipid_selection}' resulted in an empty AtomGroup.")

    frames = universe.trajectory[start_frame:end_frame:stride]
    n_frames = len(frames)

    # NaN marks "no answer yet"; ambiguous frames keep that value.
    upper_counts = np.full(n_frames, np.nan)
    lower_counts = np.full(n_frames, np.nan)

    progress = tqdm(
        enumerate(frames),
        total=n_frames,
        disable=not show_progress,
        desc="DBSCAN Leaflet Analysis"
    )
    for frame_index, _ in progress:
        n_upper, n_lower = _find_leaflets_in_frame(
            headgroups, eps, min_samples, min_cluster_size_fraction
        )
        upper_counts[frame_index] = n_upper
        lower_counts[frame_index] = n_lower

    return upper_counts, lower_counts

In [11]:
import MDAnalysis as mda
import numpy as np
from tqdm import tqdm
from typing import Tuple, Optional

def calculate_lipid_tilt_angle(
    universe: mda.Universe,
    resid: int,
    atom1_name: str,
    atom2_name: str,
    resname: Optional[str] = None,
    membrane_normal: np.ndarray = np.array([0, 0, 1]),
    start_frame: int = 0,
    end_frame: Optional[int] = None,
    stride: int = 1,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Tracks the tilt of one lipid relative to the membrane normal over time.

    The lipid director is the vector from the first atom matching
    ``atom1_name`` to the first atom matching ``atom2_name`` inside the chosen
    residue. The tilt is the angle between that director and
    ``membrane_normal``.

    Args:
        universe (mda.Universe): The Universe holding topology and trajectory.
        resid (int): Residue ID of the lipid to follow.
        atom1_name (str): Atom name at the start of the director.
        atom2_name (str): Atom name at the end of the director.
        resname (Optional[str]): Optional residue name used to narrow the
            selection. Defaults to None (resid only).
        membrane_normal (np.ndarray): Reference normal vector; need not be
            unit length. Defaults to the Z-axis [0, 0, 1].
        start_frame (int): First frame index to analyze.
        end_frame (Optional[int]): Frame index to stop at; None runs to the end.
        stride (int): Step between analyzed frames.

    Returns:
        Tuple[np.ndarray, np.ndarray]:
        - times (np.ndarray): ``ts.time`` of each analyzed frame.
        - tilt_angles (np.ndarray): Tilt angle in degrees for each frame. A
          zero-length director is reported as 0.

    Raises:
        ValueError: If the lipid or either atom cannot be found, or if
            ``membrane_normal`` has zero length.
    """
    # Locate the lipid, optionally constrained by residue name.
    selection_string = f"resid {resid}"
    if resname:
        selection_string = f"{selection_string} and resname {resname}"

    lipid_ag = universe.select_atoms(selection_string)
    if not lipid_ag:
        raise ValueError(f"Lipid not found with selection: '{selection_string}'")

    # Both end points of the director must exist inside that lipid.
    atom1 = lipid_ag.select_atoms(f"name {atom1_name}")
    atom2 = lipid_ag.select_atoms(f"name {atom2_name}")
    if not (atom1 and atom2):
        found_names = list(np.unique(lipid_ag.names))
        raise ValueError(
            f"Could not find atoms '{atom1_name}' or '{atom2_name}' in the selected lipid. "
            f"Available atom names: {found_names}"
        )

    frames = universe.trajectory[start_frame:end_frame:stride]
    n_frames = len(frames)
    times = np.zeros(n_frames)
    tilt_angles = np.zeros(n_frames)

    # The length of the reference normal is constant, so compute it once.
    norm_membrane_normal = np.linalg.norm(membrane_normal)
    if norm_membrane_normal == 0:
        raise ValueError("Membrane normal vector cannot be a zero vector.")

    progress = tqdm(frames, desc=f"Tracking tilt for resid {resid}")
    for frame_index, ts in enumerate(progress):
        times[frame_index] = ts.time

        # .positions is re-read each frame, so it reflects the current coordinates.
        director = atom2.positions[0] - atom1.positions[0]
        director_length = np.linalg.norm(director)

        if director_length == 0:
            # Undefined angle (coincident atoms); reported as 0 degrees.
            tilt_angles[frame_index] = np.rad2deg(0.0)
            continue

        cosine = np.dot(director, membrane_normal) / (director_length * norm_membrane_normal)
        # Clip guards arccos against round-off pushing |cosine| slightly above 1.
        tilt_angles[frame_index] = np.rad2deg(np.arccos(np.clip(cosine, -1.0, 1.0)))

    return times, tilt_angles


# --- 1. Configuration ---
# All your settings are now clear arguments for the function.
#FLIPPING_LIPID_RESID = 946
#RESIDUE_NAME = "POPG"
#ATOM_1 = "C3A"  # Start of vector (tail)
#ATOM_2 = "PO4"  # End of vector (head)

# Assume your Universe 'para_38_rep1' is already loaded
# u = para_38_rep1 

# --- 2. Run Analysis ---
# Call the function with your parameters. It handles all the setup,
# looping, and error checking internally.
#try:
#    times, tilt_angles = calculate_lipid_tilt_angle(
#        universe=u,
#        resid=FLIPPING_LIPID_RESID,
#        resname=RESIDUE_NAME,
#        atom1_name=ATOM_1,
#        atom2_name=ATOM_2,
#    )
    
#    print("Analysis finished successfully.")
#    # You can now proceed to plot `times` vs. `tilt_angles`
    
#except ValueError as e:
#    print(f"An error occurred: {e}")

In [None]:
import numpy as np
import MDAnalysis as mda
from MDAnalysis.analysis import contacts
import multiprocessing
from tqdm.notebook import tqdm

def _worker_contact_calc(frames, top_path, traj_path, sel_a_str, sel_b_str, radius):
    """
    Private worker function for multiprocessing: counts contacts on a subset of frames.

    Each worker process opens its own Universe from the file paths (a Universe
    is not shared between processes), so only paths and selection strings are
    passed in.

    Args:
        frames: Iterable of integer frame indices to process.
        top_path: Path to the topology file.
        traj_path: Path to the trajectory file.
        sel_a_str: Selection string for the first atom group.
        sel_b_str: Selection string for the second atom group.
        radius: Contact cutoff distance; a pair is a contact when its distance
            is within this value (as decided by ``contacts.contact_matrix``).

    Returns:
        list: One ``[time, n_contacts]`` entry per processed frame, in the
        order of ``frames``.

    Raises:
        ValueError: If either selection matches no atoms.
    """
    u = mda.Universe(top_path, traj_path)
    group_a = u.select_atoms(sel_a_str)
    group_b = u.select_atoms(sel_b_str)

    # Fail loudly rather than returning a silent all-zero time series.
    if not group_a:
        raise ValueError(f"Selection '{sel_a_str}' resulted in an empty AtomGroup.")
    if not group_b:
        raise ValueError(f"Selection '{sel_b_str}' resulted in an empty AtomGroup.")

    timeseries_chunk = []
    for frame_idx in frames:
        ts = u.trajectory[int(frame_idx)]
        # Pass the frame's box so distances use the minimum-image convention;
        # otherwise pairs in contact across a periodic boundary are missed.
        # When the frame has no box (dimensions is None) this is a plain distance.
        dist = contacts.distance_array(
            group_a.positions, group_b.positions, box=ts.dimensions
        )
        n_contacts = int(contacts.contact_matrix(dist, radius).sum())
        timeseries_chunk.append([ts.time, n_contacts])
    return timeseries_chunk

def parallel_contacts_within_cutoff(u, sel_a_str, sel_b_str, radius=5.0, step=1, n_cores=None):
    """
    Calculates the number of contacts between two atom groups over time in parallel.

    The analyzed frame indices are split into contiguous chunks, one per worker
    process. Each worker re-opens the Universe from ``u.filename`` and
    ``u.trajectory.filename`` and runs ``_worker_contact_calc`` on its chunk.

    Note:
        The Universe must be backed by files on disk. On platforms where
        multiprocessing starts workers with "spawn", a worker function defined
        only inside a notebook cell may not be importable by the workers;
        moving the functions into a .py module avoids this.

    Args:
        u (MDAnalysis.Universe): The Universe object containing the trajectory.
        sel_a_str (str): The selection string for the first group of atoms (e.g., 'resname DEN').
        sel_b_str (str): The selection string for the second group of atoms (e.g., 'resid 1-642 and name P*').
        radius (float, optional): The cutoff distance in Angstroms to define a contact. Defaults to 5.0.
        step (int, optional): The step size for trajectory iteration. 1 iterates over all
                              frames, N iterates over every Nth frame. Defaults to 1.
        n_cores (int, optional): The number of CPU cores to use. If None, it defaults to
                                 all available cores minus one (at least 1). Never more
                                 workers than frames are started.

    Returns:
        numpy.ndarray: A 2D float array of shape (n_frames, 2), where the columns are
                       [Time (ps), Number of Contacts], sorted by the time column.
                       Shape (0, 2) when there are no frames to analyze.

    Raises:
        ValueError: If ``step`` or ``n_cores`` is smaller than 1, or if the
            Universe has no topology/trajectory file path to hand to workers.
    """
    if step < 1:
        raise ValueError("step must be at least 1.")

    if n_cores is None:
        n_cores = max(1, multiprocessing.cpu_count() - 1)
    elif n_cores < 1:
        raise ValueError("n_cores must be at least 1.")

    top_path = u.filename
    traj_path = getattr(u.trajectory, "filename", None)
    if top_path is None or traj_path is None:
        raise ValueError(
            "The Universe must be loaded from topology and trajectory files "
            "so that worker processes can re-open it."
        )

    all_frame_indices = np.arange(0, u.trajectory.n_frames, step)
    if all_frame_indices.size == 0:
        return np.empty((0, 2))

    # Never start more workers than frames: an empty chunk would still pay the
    # cost of opening a Universe and contribute nothing.
    n_workers = min(n_cores, all_frame_indices.size)
    frame_chunks = np.array_split(all_frame_indices, n_workers)

    # Pass file paths and selection strings, not the Universe itself.
    worker_args = [
        (chunk, top_path, traj_path, sel_a_str, sel_b_str, radius)
        for chunk in frame_chunks
    ]

    print(f"Starting parallel contact calculation on {n_workers} cores...")
    with multiprocessing.Pool(processes=n_workers) as pool:
        # Submit everything up front, then collect in order. Unlike wrapping
        # pool.starmap (which returns only once all chunks are done), this
        # lets the progress bar advance as chunks finish.
        pending = [pool.apply_async(_worker_contact_calc, args) for args in worker_args]
        results_list = [job.get() for job in tqdm(pending, total=len(pending))]

    # Collect, flatten, and sort the results by time
    timeseries = [item for sublist in results_list for item in sublist]
    timeseries_arr = np.array(timeseries, dtype=float).reshape(-1, 2)
    return timeseries_arr[np.argsort(timeseries_arr[:, 0], kind="stable")]

# Example usage:
# Make sure to replace 'system.tpr' and 'system.xtc' with your actual
# topology and trajectory files.
# Note: The Universe object 'u' should be created with your specific files.
# Uncomment the following lines to run the analysis:
#Load your universe
#u = mda.Universe("system.tpr", "system.xtc")

# Define selection strings
#dendrimer_sel = "resname DEN"
#outer_leaflet_sel = "resid 2-642 and name P* NH*"

# Run the analysis using 4 cores
#contact_data = parallel_contacts_within_cutoff(u, dendrimer_sel, outer_leaflet_sel, radius=5.0, n_cores=4)

#print(contact_data)

# Note: the array returned by parallel_contacts_within_cutoff is sorted by
# its time column (column 0).

In [20]:
import MDAnalysis as mda
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import griddata
from scipy.ndimage import gaussian_filter
from mpl_toolkits.axes_grid1 import make_axes_locatable
from tqdm.notebook import tqdm # Use tqdm.notebook for a nice progress bar in Jupyter
import MDAnalysis as mda
import numpy as np
from scipy.interpolate import griddata
from scipy.ndimage import gaussian_filter
from tqdm.notebook import tqdm
from typing import Tuple

def calculate_mean_curvature(
    universe: mda.Universe,
    selection_string: str,
    traj_slice: slice = slice(-101, None),
    grid_size: float = 1.0,
    rbf_function: str = 'cubic'
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Calculates the time-averaged membrane mean curvature from atom positions.

    For each frame a height field z(x, y) is interpolated from the selected
    atoms onto a fixed 2D grid, smoothed with a Gaussian filter, and the mean
    curvature

        H = ((1 + z_y^2) z_xx - 2 z_x z_y z_xy + (1 + z_x^2) z_yy)
            / (2 (1 + z_x^2 + z_y^2)^(3/2))

    is evaluated with finite differences. H is then averaged per grid cell over
    the frames in which that cell has a finite value. Periodic boundaries in X
    and Y are handled by interpolating from a 3x3 tiling of the atom positions.

    Note:
        The grid extent comes from the box of the frame that is current when
        the function is called; the periodic images are shifted by each
        frame's own box lengths. Only the box lengths are used, so an
        orthorhombic box is assumed.

    Args:
        universe (mda.Universe): The MDAnalysis Universe object.
        selection_string (str): Selection for atoms to define the surface (e.g., 'name PO4').
        traj_slice (slice): Slice object for the trajectory frames to analyze.
        grid_size (float): Grid spacing in Angstroms (must be > 0). The Gaussian
            smoothing width is one grid cell, so it scales with this value.
        rbf_function (str): Interpolation method passed to
            ``scipy.interpolate.griddata``: 'linear', 'nearest' or 'cubic'.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]:
        - mean_curvature (np.ndarray): 2D array of mean curvature (Å⁻¹). Cells
          that never received a finite value are np.nan.
        - grid_x (np.ndarray): The X-coordinates of the grid.
        - grid_y (np.ndarray): The Y-coordinates of the grid.

    Raises:
        ValueError: If the selection is empty, the box is undefined,
            ``grid_size`` is not positive, the slice has no frames, or no
            curvature value could be computed.
    """
    selected_atoms = universe.select_atoms(selection_string)
    if not selected_atoms:
        raise ValueError(f"No atoms found with selection: '{selection_string}'")

    box = universe.dimensions
    if box is None:
        raise ValueError("Universe dimensions are not defined.")
    if grid_size <= 0:
        raise ValueError("grid_size must be positive.")
    # The box holds six values (three lengths, three angles), so take the two
    # lengths explicitly; copy to floats so later frame changes cannot alter them.
    Lx, Ly = float(box[0]), float(box[1])

    grid_x, grid_y = np.mgrid[0:Lx:grid_size, 0:Ly:grid_size]
    curvature_sum = np.zeros(grid_x.shape, dtype=float)
    sample_count = np.zeros(grid_x.shape, dtype=int)

    traj_to_analyze = universe.trajectory[traj_slice]
    if len(traj_to_analyze) == 0:
        raise ValueError("No frames found in the specified trajectory slice.")

    # Unit offsets of the 3x3 periodic images (including the central cell).
    tile_units = np.array(
        [(ix, iy) for ix in (-1, 0, 1) for iy in (-1, 0, 1)], dtype=float
    )

    for ts in tqdm(traj_to_analyze, desc="Calculating curvature"):
        positions = selected_atoms.positions.astype(float)

        # Shift the images by this frame's box lengths when available.
        frame_box = ts.dimensions
        if frame_box is None:
            frame_lx, frame_ly = Lx, Ly
        else:
            frame_lx, frame_ly = float(frame_box[0]), float(frame_box[1])
        shifts = tile_units * np.array([frame_lx, frame_ly])

        # Build the tiled point set in one vectorized step (tile-major order).
        xy_tiled = (shifts[:, None, :] + positions[None, :, :2]).reshape(-1, 2)
        z_tiled = np.tile(positions[:, 2], len(shifts))

        grid_z = griddata(xy_tiled, z_tiled, (grid_x, grid_y), method=rbf_function)

        # Skip frames where interpolation fails everywhere
        if np.all(np.isnan(grid_z)):
            continue

        # Smooth the surface, then differentiate with the real grid spacing so
        # that derivatives are per Angstrom, not per grid index.
        grid_z_filtered = gaussian_filter(grid_z, sigma=1.0, mode='wrap')
        dzdx, dzdy = np.gradient(grid_z_filtered, grid_size, grid_size)
        d2zdx2 = np.gradient(dzdx, grid_size, axis=0)
        d2zdy2 = np.gradient(dzdy, grid_size, axis=1)
        d2zdxdy = np.gradient(dzdx, grid_size, axis=1)

        # Mean curvature (H). The denominator is >= 2 wherever it is finite,
        # so no division-by-zero guard is needed; NaNs simply propagate.
        numerator = (1 + dzdy**2) * d2zdx2 - 2 * dzdx * dzdy * d2zdxdy + (1 + dzdx**2) * d2zdy2
        denominator = 2 * (1 + dzdx**2 + dzdy**2)**1.5
        H = numerator / denominator

        # Accumulate per cell so cells that are NaN in some frames are averaged
        # only over the frames where they are defined.
        finite = np.isfinite(H)
        curvature_sum[finite] += H[finite]
        sample_count[finite] += 1

    covered = sample_count > 0
    if not covered.any():
        raise ValueError("Curvature could not be calculated for any frames.")

    mean_curvature = np.full(grid_x.shape, np.nan)
    mean_curvature[covered] = curvature_sum[covered] / sample_count[covered]
    return mean_curvature, grid_x, grid_y


import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

def plot_curvature_map(
    curvature_grid: np.ndarray,
    grid_x: np.ndarray,
    grid_y: np.ndarray,
    vmin: float = -0.03,
    vmax: float = 0.03,
    cmap: str = 'RdBu_r'
):
    """
    Draws a mean-curvature grid as a colour-mapped image with a colorbar.

    Args:
        curvature_grid (np.ndarray): 2D mean-curvature values; it is transposed
            before display so that its first axis runs along the plot's X-axis.
        grid_x (np.ndarray): X-coordinates of the grid (only min/max are used).
        grid_y (np.ndarray): Y-coordinates of the grid (only min/max are used).
        vmin (float): Lower limit of the colour scale.
        vmax (float): Upper limit of the colour scale.
        cmap (str): Name of the matplotlib colormap.

    Returns:
        tuple: The matplotlib Figure and Axes objects (fig, ax).
    """
    fig, ax = plt.subplots(figsize=(7, 6))

    # Image bounds in data coordinates: [left, right, bottom, top].
    bounds = [grid_x.min(), grid_x.max(), grid_y.min(), grid_y.max()]
    image_options = dict(
        extent=bounds,
        origin='lower',
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        interpolation='bicubic',
    )
    image = ax.imshow(curvature_grid.T, **image_options)

    ax.set_xlabel('X-axis (Å)')
    ax.set_ylabel('Y-axis (Å)')
    ax.set_aspect('equal', adjustable='box')

    # Give the colorbar its own axes so it matches the image height.
    colorbar_axes = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.1)
    colorbar = fig.colorbar(image, cax=colorbar_axes)
    colorbar.set_label('Mean Curvature (Å⁻¹)')

    fig.tight_layout()
    return fig, ax


# --- Step 1: Load your Universe ---
# u = mda.Universe("path/to/topology.tpr", "path/to/trajectory.xtc")

# --- Step 2: Calculate the mean curvature ---
# This step performs the analysis and returns the raw data.
#try:
#    mean_curv_data, x_grid, y_grid = calculate_mean_curvature(
#        universe=u,
#        selection_string='resid 2-641 and name PO4', # e.g., upper leaflet phosphates
#        traj_slice=slice(-101, None) # Last 101 frames
#    )
#    
#    # --- Step 3: Plot the results ---
#    # Pass the data to the plotting function.
#    fig, ax = plot_curvature_map(
#        mean_curv_data,
#        x_grid,
#        y_grid,
#        vmin=-0.04, # Custom color range
#        vmax=0.04
#    )
#    ax.set_title("Upper Leaflet Mean Curvature")
#    plt.show()

#except ValueError as e:
#    print(f"An error occurred: {e}")

In [14]:
import MDAnalysis as mda
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
# A library mapping Martini lipid residue names to the list of their bonds.
# Each bond is an (atom1_name, atom2_name) tuple. Keys starting with '_' are
# reusable building blocks (headgroups, tails); the full lipid entries below
# are built by concatenating one head/core block with one tail block.
MARTINI_LIPID_BONDS = {
    # Base components
    '_PC_HEAD': [('NC3', 'PO4'), ('PO4', 'GL1'), ('GL1', 'GL2')],
    '_PE_HEAD': [('NH3', 'PO4'), ('PO4', 'GL1'), ('GL1', 'GL2')],
    '_PG_HEAD': [('GL0', 'PO4'), ('PO4', 'GL1'), ('GL1', 'GL2')],
    '_CDL_CORE': [
        ('GL0', 'PO41'), ('GL0', 'PO42'),
        ('PO41', 'GL11'), ('PO42', 'GL12'),
        ('GL11', 'GL21'), ('GL12', 'GL22'),
    ],
    '_LPS_CHAINS': [
        ('GL2', 'C1A'), ('C1A', 'C2A'), ('C2A', 'C3A'),
        # NOTE(review): this was ('C1B', 'C3B'), which skips C2B and breaks the
        # linear pattern used by every other chain here; treated as a typo.
        # Confirm against the force-field topology.
        ('GL1', 'C1B'), ('C1B', 'C2B'), ('C2B', 'C3B'),
        ('GL3', 'C1D'), ('C1D', 'C2D'), ('C2D', 'C3D'),
        ('GL4', 'C1C'), ('C1C', 'C2C'), ('C2C', 'C3C'),
        ('GL6', 'C1E'), ('C1E', 'C2E'),
        ('GL8', 'C1F'), ('C1F', 'C2F'),
    ],
    # Tail definitions
    '_TAIL_DAPC': [
        ('GL1', 'D1A'), ('D1A', 'D2A'), ('D2A', 'D3A'), ('D3A', 'D4A'), ('D4A', 'C5A'),
        ('GL2', 'D1B'), ('D1B', 'D2B'), ('D2B', 'D3B'), ('D3B', 'D4B'), ('D4B', 'C5B'),
    ],
    '_TAIL_DL': [
        ('GL1', 'C1A'), ('C1A', 'C2A'), ('C2A', 'C3A'),
        ('GL2', 'C1B'), ('C1B', 'C2B'), ('C2B', 'C3B'),
    ],
    '_TAIL_DO': [
        ('GL1', 'C1A'), ('C1A', 'D2A'), ('D2A', 'C3A'), ('C3A', 'C4A'),
        ('GL2', 'C1B'), ('C1B', 'D2B'), ('D2B', 'C3B'), ('C3B', 'C4B'),
    ],
    '_TAIL_DP': [
        ('GL1', 'C1A'), ('C1A', 'C2A'), ('C2A', 'C3A'), ('C3A', 'C4A'),
        ('GL2', 'C1B'), ('C1B', 'C2B'), ('C2B', 'C3B'), ('C3B', 'C4B'),
    ],
    '_TAIL_PO': [
        ('GL1', 'C1A'), ('C1A', 'D2A'), ('D2A', 'C3A'), ('C3A', 'C4A'),
        ('GL2', 'C1B'), ('C1B', 'C2B'), ('C2B', 'C3B'), ('C3B', 'C4B'),
    ],
    '_TAIL_CDL2': [
        ('GL11', 'C1A1'), ('C1A1', 'C2A1'), ('C2A1', 'D3A1'), ('D3A1', 'C4A1'), ('C4A1', 'C5A1'),
        ('GL21', 'C1B1'), ('C1B1', 'C2B1'), ('C2B1', 'D3B1'), ('D3B1', 'C4B1'), ('C4B1', 'C5B1'),
        ('GL12', 'C1A2'), ('C1A2', 'C2A2'), ('C2A2', 'D3A2'), ('D3A2', 'C4A2'), ('C4A2', 'C5A2'),
        ('GL22', 'C1B2'), ('C1B2', 'C2B2'), ('C2B2', 'D3B2'), ('D3B2', 'C4B2'), ('C4B2', 'C5B2'),
    ],
}
# Combine base components to form full lipid definitions (head/core + tail).
# List concatenation creates a new list for every lipid.
MARTINI_LIPID_BONDS.update({
    lipid: MARTINI_LIPID_BONDS[head] + MARTINI_LIPID_BONDS[tail]
    for lipid, (head, tail) in {
        'DAPC': ('_PC_HEAD', '_TAIL_DAPC'),
        'DLPC': ('_PC_HEAD', '_TAIL_DL'),
        'DOPC': ('_PC_HEAD', '_TAIL_DO'),
        'DPPC': ('_PC_HEAD', '_TAIL_DP'),
        'POPC': ('_PC_HEAD', '_TAIL_PO'),
        'DAPE': ('_PE_HEAD', '_TAIL_DAPC'),
        'DLPE': ('_PE_HEAD', '_TAIL_DL'),
        'DOPE': ('_PE_HEAD', '_TAIL_DO'),
        'DPPE': ('_PE_HEAD', '_TAIL_DP'),
        'POPE': ('_PE_HEAD', '_TAIL_PO'),
        'POPG': ('_PG_HEAD', '_TAIL_PO'),
        'DOPG': ('_PG_HEAD', '_TAIL_DO'),
        'DLPG': ('_PG_HEAD', '_TAIL_DL'),
        'CDL2': ('_CDL_CORE', '_TAIL_CDL2'),
    }.items()
})
# RAMP and REMP use the chain bonds only; each gets its own copy so that
# editing one entry cannot silently change the other or the building block.
MARTINI_LIPID_BONDS.update({
    'RAMP': list(MARTINI_LIPID_BONDS['_LPS_CHAINS']),
    'REMP': list(MARTINI_LIPID_BONDS['_LPS_CHAINS']),
})


def calculate_scd_profile(u: mda.Universe, lipid_types: list, traj_slice: slice = slice(-101, None)):
    """
    Calculates the bond order parameter (S_CD) for each bond in specified lipid types.

    For every bond (atom1, atom2) listed in MARTINI_LIPID_BONDS for a residue's
    resname, the order parameter 0.5 * (3 * cos(theta)**2 - 1) is computed, where
    theta is the angle between the atom1 -> atom2 vector and the fixed axis
    (0, 0, 1). Values are pooled over all analyzed frames and over all selected
    residues that contain the bond, keyed only by the "atom1-atom2" label (so the
    same label in two lipid types is averaged together).

    Note:
        Bond vectors are plain coordinate differences; no minimum-image
        correction is applied, so molecules split across a periodic boundary
        will contribute distorted values.

    Args:
        u (mda.Universe): The MDAnalysis Universe object.
        lipid_types (list): Residue names to analyze (e.g., ['POPE', 'POPG']).
        traj_slice (slice): The slice of trajectory frames to analyze.

    Returns:
        pd.DataFrame: One row per bond label with columns 'bond' and 'S_CD'
        (the mean order parameter). Empty if no bond could be evaluated.
    """
    results = defaultdict(list)
    bilayer_normal = np.array([0, 0, 1])

    selection_str = " or ".join([f"resname {name}" for name in lipid_types])
    lipid_residues = u.select_atoms(selection_str).residues

    # Resolve every bond to a pair of atom indices ONCE. The topology does not
    # change between frames, so repeating the name lookup for each bond of each
    # residue on each frame is loop-invariant work.
    pair_indices = defaultdict(lambda: ([], []))
    for res in lipid_residues:
        bond_list = MARTINI_LIPID_BONDS.get(res.resname, [])
        if not bond_list:
            continue
        # First atom carrying each name within this residue.
        first_index_by_name = {}
        for atom_name, atom_index in zip(res.atoms.names, res.atoms.indices):
            first_index_by_name.setdefault(atom_name, atom_index)
        for atom1_name, atom2_name in bond_list:
            index1 = first_index_by_name.get(atom1_name)
            index2 = first_index_by_name.get(atom2_name)
            if index1 is None or index2 is None:
                # This atom pair doesn't exist in this specific residue, skip
                continue
            starts, ends = pair_indices[f"{atom1_name}-{atom2_name}"]
            starts.append(index1)
            ends.append(index2)

    # One pair of aligned AtomGroups per bond label; their positions follow the
    # currently loaded frame.
    bond_groups = {
        label: (u.atoms[np.asarray(starts, dtype=int)], u.atoms[np.asarray(ends, dtype=int)])
        for label, (starts, ends) in pair_indices.items()
    }

    frames = u.trajectory[traj_slice]
    print(f"Analyzing {len(frames)} frames for {len(lipid_residues)} lipids...")
    for ts in tqdm(frames, desc="Processing Frames"):
        for label, (start_atoms, end_atoms) in bond_groups.items():
            vectors = end_atoms.positions - start_atoms.positions
            norms = np.linalg.norm(vectors, axis=1)
            nonzero = norms != 0  # zero-length vectors have no direction
            if not nonzero.any():
                continue
            unit_vectors = vectors[nonzero] / norms[nonzero][:, None]
            cos_theta = unit_vectors @ bilayer_normal
            results[label].append(0.5 * (3 * cos_theta**2 - 1))

    # Average the results
    avg_scd = {bond: np.mean(np.concatenate(chunks)) for bond, chunks in results.items()}
    profile_df = pd.DataFrame(list(avg_scd.items()), columns=['bond', 'S_CD'])

    return profile_df

# Assuming 'u' is your loaded Universe
# my_lipids_for_profile = ['POPE', 'POPG'] 
# scd_profile = calculate_scd_profile(u, lipid_types=my_lipids_for_profile)
# print(scd_profile)

In [21]:
import MDAnalysis as mda
import numpy as np
import pandas as pd
from scipy.stats import binned_statistic_2d
from tqdm.notebook import tqdm
from typing import List, Dict

# Lookup table required by calculate_order_heatmap: for each lipid residue name,
# the (start_bead, end_bead) name pairs that define its whole-tail vectors.
# Two-tailed lipids share one pattern, GL1 -> C<n>A and GL2 -> C<n>B, where <n> is
# the number of the last tail bead; each entry gets its own list object.
MARTINI_WHOLE_TAIL_VECTORS = {
    resname: [('GL1', f'C{last_bead}A'), ('GL2', f'C{last_bead}B')]
    for resname, last_bead in (
        # Phosphatidylcholines (PCs)
        ('DAPC', 5), ('DLPC', 3), ('DOPC', 4), ('DPPC', 4), ('POPC', 4),
        # Phosphatidylethanolamines (PEs)
        ('DAPE', 5), ('DLPE', 3), ('DOPE', 4), ('DPPE', 4), ('POPE', 4),
        # Phosphatidylglycerols (PGs)
        ('POPG', 4), ('DOPG', 4), ('DLPG', 3),
    )
}
# Cardiolipin has four tails and its own bead naming, so it is listed explicitly.
MARTINI_WHOLE_TAIL_VECTORS['CDL2'] = [
    ('GL11', 'C5A1'), ('GL21', 'C5B1'), ('GL12', 'C5A2'), ('GL22', 'C5B2'),
]

def calculate_order_heatmap(
    u: mda.Universe,
    lipid_types: List[str],
    grid_size: int = 50,
    traj_slice: slice = slice(-101, None)
) -> Dict:
    """
    Calculates a whole-tail lipid order parameter (S) and its 2D spatial map.

    For every tail vector listed in MARTINI_WHOLE_TAIL_VECTORS for a residue,
    S = 0.5 * (3 * cos(theta)**2 - 1) is computed, where theta is the angle
    between the start -> end bead vector and the fixed axis (0, 0, 1). Each value
    is placed at the X/Y of the residue's center of mass and the values are
    averaged per bin over all analyzed frames.

    Note:
        - Tail vectors are plain coordinate differences (no minimum-image
          correction is applied).
        - The bin edges span 0..box length of the LAST analyzed frame; samples
          falling outside that range are not binned.
        - Bins that received no sample are reported as 0.0, not NaN.

    Args:
        u (mda.Universe): The MDAnalysis Universe object.
        lipid_types (List[str]): List of lipid residue names (e.g., ['POPE', 'CDL2']).
            Names missing from MARTINI_WHOLE_TAIL_VECTORS are ignored.
        grid_size (int): Number of bins for the 2D spatial grid.
        traj_slice (slice): The slice of trajectory frames to analyze.

    Returns:
        Dict: A dictionary containing the results, or None (after printing a
              warning) if no order parameter could be computed:
              - 'grid' (np.ndarray): 2D array of the spatially averaged order
                parameter, transposed so that rows index Y and columns index X.
              - 'x_edges' (np.ndarray): Bin edges for the x-axis.
              - 'y_edges' (np.ndarray): Bin edges for the y-axis.
              - 'averages' (dict): Mean order parameter for each unique vector
                type, keyed as "RESNAME:atom1-atom2".

    Raises:
        ValueError: If none of `lipid_types` is in MARTINI_WHOLE_TAIL_VECTORS.
    """
    vector_definitions = {
        lipid: MARTINI_WHOLE_TAIL_VECTORS.get(lipid)
        for lipid in lipid_types if lipid in MARTINI_WHOLE_TAIL_VECTORS
    }
    if not any(vector_definitions.values()):
        raise ValueError("None of the lipid types were found in the library.")

    all_results = []
    bilayer_normal = np.array([0, 0, 1])

    # Build a single selection for all relevant lipids
    selection_str = " or ".join([f"resname {res}" for res in vector_definitions])
    lipids = u.select_atoms(selection_str)

    # Resolve the tail-vector atoms ONCE per residue. The topology is the same on
    # every frame, so repeating the name lookup inside the frame loop is
    # loop-invariant work. Atom objects report the position of whichever frame is
    # currently loaded.
    residue_vectors = []
    for res in lipids.residues:
        vectors_to_calc = vector_definitions.get(res.resname)
        if not vectors_to_calc:
            continue

        first_atom_by_name = {}
        for atom in res.atoms:
            first_atom_by_name.setdefault(atom.name, atom)

        resolved = []
        for atom1_name, atom2_name in vectors_to_calc:
            atom1 = first_atom_by_name.get(atom1_name)
            atom2 = first_atom_by_name.get(atom2_name)
            if atom1 is None or atom2 is None:
                continue  # Skip if this specific atom pair isn't in this residue
            resolved.append((atom1, atom2, f"{res.resname}:{atom1_name}-{atom2_name}"))

        if resolved:
            residue_vectors.append((res.atoms, resolved))

    traj_to_analyze = u.trajectory[traj_slice]
    print(f"Analyzing {len(traj_to_analyze)} frames...")

    for ts in tqdm(traj_to_analyze, desc="Processing for Heatmap"):
        box_dims = ts.dimensions
        for residue_atoms, resolved in residue_vectors:
            pos = None  # residue COM, computed lazily and shared by all its vectors
            for atom1, atom2, label in resolved:
                vector = atom2.position - atom1.position
                norm = np.linalg.norm(vector)
                if norm == 0:
                    continue

                cos_theta = np.dot(vector / norm, bilayer_normal)
                order_param = 0.5 * (3 * cos_theta**2 - 1)

                # Use COM of the residue for spatial positioning
                if pos is None:
                    pos = residue_atoms.center_of_mass(pbc=True)
                all_results.append({
                    'X': pos[0], 'Y': pos[1], 'S': order_param,
                    'vector': label
                })

    if not all_results:
        print("Warning: No valid data was generated. Check selections and frame range.")
        return None

    df = pd.DataFrame(all_results)
    # box_dims holds the box of the last frame visited by the loop above.
    x_bins = np.linspace(0, box_dims[0], grid_size + 1)
    y_bins = np.linspace(0, box_dims[1], grid_size + 1)

    # Create the 2D grid for the heatmap
    statistic, x_edge, y_edge, _ = binned_statistic_2d(
        df['X'], df['Y'], df['S'], statistic='mean', bins=[x_bins, y_bins]
    )

    # Calculate the overall average S for each vector type
    average_s_by_vector = df.groupby('vector')['S'].mean().to_dict()

    print("Analysis complete.")
    return {
        'grid': np.nan_to_num(statistic.T), # Transpose for plotting and fill NaNs
        'x_edges': x_edge,
        'y_edges': y_edge,
        'averages': average_s_by_vector
    }
    
    
# --- Step 1: Load your simulation ---
# u = mda.Universe("path/to/topology.tpr", "path/to/trajectory.xtc")

# --- Step 2: Define lipids to analyze ---
# my_lipids = ['POPE', 'POPG', 'CDL2'] 

# --- Step 3: Run the analysis ---
# results = calculate_order_heatmap(u, lipid_types=my_lipids)

# --- Step 4: Inspect and plot the results ---
# if results:
#     print("\nAverage Order Parameters:")
#     for vector, avg_s in results['averages'].items():
#         print(f"  {vector}: {avg_s:.3f}")

#     # (Insert plotting code here as needed)

In [17]:
import MDAnalysis as mda
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import Rbf
from mpl_toolkits.axes_grid1 import make_axes_locatable
from tqdm.notebook import tqdm

def _fit_periodic_surface(positions, box_x, box_y):
    """Fit a thin-plate RBF surface z(x, y) to points replicated over the 3x3 periodic images in X/Y."""
    x, y, z = positions[:, 0], positions[:, 1], positions[:, 2]
    # Same tiling order as nested loops: X shift outer, Y shift inner.
    shifts = [(shift_x, shift_y) for shift_x in (-box_x, 0, box_x) for shift_y in (-box_y, 0, box_y)]
    x_tiled = np.concatenate([x + shift_x for shift_x, _ in shifts])
    y_tiled = np.concatenate([y + shift_y for _, shift_y in shifts])
    z_tiled = np.tile(z, len(shifts))
    return Rbf(x_tiled, y_tiled, z_tiled, function='thin_plate', epsilon=1.0)


def calculate_membrane_thickness(
    u: mda.Universe,
    phos_sel: str = "name PO4*",
    traj_slice: slice = slice(-101, None),
    grid_points: int = 100
):
    """
    Calculates the time-averaged membrane thickness using RBF interpolation.

    For each frame in the trajectory slice, this function:
    1. Selects phosphate atoms and splits them into upper/lower leaflets at the
       median Z of the selection.
    2. Handles periodic boundaries by creating a 3x3 tiled system of coordinates,
       using the box lengths of that frame.
    3. Fits a smooth surface to each leaflet using Radial Basis Functions (RBF).
    4. Calculates the thickness as the difference between the two surfaces on a grid
       spanning 0..box length of that frame in X and Y.

    The final result is the average thickness map over the frames that were
    actually processed. Frames with fewer than 4 atoms in either leaflet are
    skipped (a message is printed) and do not count towards the average.

    Args:
        u (mda.Universe): The MDAnalysis Universe object.
        phos_sel (str): Atom selection string for the phosphate headgroups.
        traj_slice (slice): The slice of trajectory frames to analyze.
        grid_points (int): Number of grid points along each dimension for the map.

    Returns:
        dict: A dictionary containing the results, or None (after printing a
              warning) if no frame could be processed:
              - 'grid' (np.ndarray): The 2D array of the time-averaged thickness.
                Rows index Y and columns index X, so it can be passed to
                imshow(origin='lower') without transposing.
              - 'x_edges' (np.ndarray): X coordinates of the grid points
                (`grid_points` values from 0 to the mean box length in X over
                the processed frames; despite the key name these are point
                coordinates, not bin edges).
              - 'y_edges' (np.ndarray): Y coordinates of the grid points, as above.
              - 'average_thickness' (float): The overall mean thickness value.

    Raises:
        ValueError: If `phos_sel` selects no atoms.
    """
    phos_atoms = u.select_atoms(phos_sel)
    if not phos_atoms:
        raise ValueError(f"No phosphate atoms found with selection: '{phos_sel}'")

    # Initialize accumulators for time-averaging
    traj_to_analyze = u.trajectory[traj_slice]
    num_frames = len(traj_to_analyze)
    thickness_accumulator = np.zeros((grid_points, grid_points))
    box_accumulator = np.zeros(2)
    frames_used = 0  # only frames that actually contribute to the accumulator

    print(f"Analyzing {num_frames} frames...")
    for ts in tqdm(traj_to_analyze, desc="Calculating thickness"):
        # Box lengths of THIS frame (the box loaded before the loop may belong
        # to an unrelated frame).
        Lx, Ly = float(ts.dimensions[0]), float(ts.dimensions[1])

        # Split into leaflets for the current frame
        pos = phos_atoms.positions
        median_z = np.median(pos[:, 2])
        lower_leaflet_pos = pos[pos[:, 2] < median_z]
        upper_leaflet_pos = pos[pos[:, 2] >= median_z]

        if len(lower_leaflet_pos) < 4 or len(upper_leaflet_pos) < 4:
            print(f"Skipping frame {ts.frame}: insufficient atoms in a leaflet.")
            continue

        # --- Fit surfaces with PBC handling ---
        upper_surface = _fit_periodic_surface(upper_leaflet_pos, Lx, Ly)
        lower_surface = _fit_periodic_surface(lower_leaflet_pos, Lx, Ly)

        # Create grid based on box dimensions (default 'xy' meshgrid indexing:
        # rows follow y_range, columns follow x_range)
        x_range = np.linspace(0, Lx, grid_points)
        y_range = np.linspace(0, Ly, grid_points)
        x_grid, y_grid = np.meshgrid(x_range, y_range)

        # Evaluate surfaces and calculate thickness for this frame
        z_upper = upper_surface(x_grid, y_grid)
        z_lower = lower_surface(x_grid, y_grid)

        thickness_accumulator += z_upper - z_lower
        box_accumulator += (Lx, Ly)
        frames_used += 1

    if frames_used == 0:
        # Empty slice or every frame skipped: there is nothing to average.
        print("Warning: No valid data was generated. Check selections and frame range.")
        return None

    # Finalize the average over the frames that contributed
    mean_thickness_grid = thickness_accumulator / frames_used
    overall_avg_thickness = np.mean(mean_thickness_grid)
    mean_Lx, mean_Ly = box_accumulator / frames_used

    print("Analysis complete.")
    return {
        'grid': mean_thickness_grid,
        'x_edges': np.linspace(0, mean_Lx, grid_points),
        'y_edges': np.linspace(0, mean_Ly, grid_points),
        'average_thickness': overall_avg_thickness
    }
    
    
# # --- Step 1: Load your simulation data ---
# # Replace the file paths with your own.
# u = mda.Universe("path/to/topology.tpr", "path/to/trajectory.xtc")

# # --- Step 2: Run the analysis ---
# # Analyze the last 101 frames using the default phosphate selection.
# thickness_results = calculate_membrane_thickness(u)

# # --- Step 3: Inspect and plot the results ---
# if thickness_results:
#     # Print the overall average thickness
#     avg_thick = thickness_results['average_thickness']
#     print(f"\nOverall average membrane thickness: {avg_thick:.2f} Å")

#     # Create the heatmap plot
#     fig, ax = plt.subplots(figsize=(7, 6))
#     im = ax.imshow(
#         thickness_results['grid'], # Rows already index Y and columns index X (meshgrid 'xy'), so no transpose is needed
#         origin='lower',
#         cmap='magma',
#         extent=[
#             thickness_results['x_edges'][0], thickness_results['x_edges'][-1],
#             thickness_results['y_edges'][0], thickness_results['y_edges'][-1]
#         ]
#     )
#     ax.set_xlabel("X (Å)")
#     ax.set_ylabel("Y (Å)")
#     ax.set_title("Time-Averaged Membrane Thickness")
#     ax.set_aspect('equal', adjustable='box')
    
#     # Add a colorbar
#     divider = make_axes_locatable(ax)
#     cax = divider.append_axes("right", size="5%", pad=0.1)
#     cbar = fig.colorbar(im, cax=cax)
#     cbar.set_label("Thickness (Å)")
    
#     plt.tight_layout()
#     plt.show()