In [2]:
import numpy as np
import MDAnalysis as mda
from gridData import Grid

from MDAnalysis.analysis.distances import distance_array

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


import seaborn as sns
import glob
import scipy
import pandas as pd

from sklearn.cross_decomposition import PLSRegression
from OPLS_MD import OPLS, OPLS_PLS, PLS
from sklearn.model_selection import train_test_split

import nglview as nv

from sklearn.mixture import GaussianMixture, BayesianGaussianMixture
from sklearn.decomposition import PCA
from sklearn.cluster import DBSCAN
import itertools
import imageio


from scipy.linalg import lstsq




# Density map calculation functions

In [27]:

def check_centering(universe, stride=10, max_shift=30):
    """Warn when the protein centre of mass drifts a lot over the trajectory.

    Every ``stride``-th frame (the final frame included when it falls on the
    stride) contributes its protein centre of mass. If the largest pairwise
    distance between these centres exceeds ``max_shift`` Å, a warning is
    printed.

    Parameters
    ----------
    universe : MDAnalysis.Universe
        Universe to scan; its frame pointer is moved while iterating.
    stride : int, default 10
        Step between sampled frames.
    max_shift : float, default 30
        Largest tolerated centre-of-mass displacement in Å.

    Returns
    -------
    None
        The result is only reported through ``print``.
    """
    protein = universe.select_atoms("protein")
    # ``[::stride]`` instead of ``[0:-1:stride]``: -1 as stop silently skipped the last frame.
    com_positions = [protein.center_of_mass() for _ in universe.trajectory[::stride]]
    # With a single sampled frame there is no drift to measure (pdist would be empty).
    if len(com_positions) < 2:
        return
    biggest = scipy.spatial.distance.pdist(com_positions).max()
    if biggest > max_shift:
        print(f"Protein center of mass moved over {max_shift} Å during simulation ({np.round(biggest, 2)} Å). Are you sure the protein is centered?")
    

In [28]:

def calculate_box_size(universe):
    """Return the edge length a cubic density grid needs to enclose every atom.

    The per-axis minimum and maximum atom coordinates are tracked over the
    whole trajectory. The largest of the three resulting extents is returned.
    """
    system = universe.select_atoms("all")
    lower = system.positions.min(axis=0)
    upper = system.positions.max(axis=0)

    for _ in universe.trajectory:
        frame_coords = system.positions
        frame_lower = frame_coords.min(axis=0)
        frame_upper = frame_coords.max(axis=0)
        lower = np.where(frame_lower < lower, frame_lower, lower)
        upper = np.where(frame_upper > upper, frame_upper, upper)

    return max(list(upper - lower))



In [29]:

def make_grid(dimension, n_grids):
    """Return the x, y and z grid-point coordinates of an origin-centred cube.

    The cube has edge length ``dimension`` and ``n_grids`` evenly spaced points
    per axis, end points included. Each axis comes back as its own
    ``(n_grids, 1)`` column array, ready for ``scipy.spatial.distance.cdist``.
    """
    half_edge = dimension * .5
    return tuple(np.linspace(-half_edge, half_edge, n_grids).reshape(-1, 1) for _ in range(3))
    


In [30]:
def map_to_grid_points(coordinates, x_vals, y_vals, z_vals):
    """Snap every coordinate to the index of its nearest grid point, axis by axis.

    ``coordinates`` is indexed as ``(n_points, 3)``. ``x_vals``, ``y_vals`` and
    ``z_vals`` are ``(n, 1)`` column arrays of grid positions. Returns a tuple
    of three integer index arrays. Ties go to the lower index, and values
    outside a grid snap to its nearest edge.
    """
    nearest = []
    for axis, grid in enumerate((x_vals, y_vals, z_vals)):
        column = coordinates[:, axis].reshape(-1, 1)
        gaps = scipy.spatial.distance.cdist(column, grid)
        nearest.append(np.argmin(gaps, axis=1))
    return tuple(nearest)
    

In [31]:

def snippet_density_maps(universe, selection, n_steps, dimension, n_grids, output, use_com=False):
    """Build one occupancy map per consecutive block ("snippet") of ``n_steps`` frames.

    Snippet boundaries are ``ends = np.arange(0, n_frames, n_steps)``. Snippet
    ``i`` covers frames ``[ends[i], ends[i+1])``, and frames from ``ends[-1]``
    to the end of the trajectory are not used. This is the same chunking as
    ``snippet_function_values`` and ``snippet_coordinates``, so their outputs
    line up row by row.

    For each frame the selected atom positions, or per-residue centres of mass
    when ``use_com`` is True, are shifted by the protein centre of mass. Each
    point is snapped to the nearest point of an origin-centred cubic grid
    (edge ``dimension``, ``n_grids`` points per axis) and counted. Points
    outside the cube are assigned to the nearest edge grid point. A snippet's
    counts are divided by ``n_points_per_frame * n_steps``.

    Parameters
    ----------
    universe : MDAnalysis.Universe
    selection : str
        MDAnalysis selection string for the atoms to map.
    n_steps : int
        Frames per snippet.
    dimension : float
        Edge length of the cubic grid (same units as the coordinates).
    n_grids : int
        Grid points per axis.
    output : str
        Path passed to ``np.save``.
    use_com : bool, default False
        Map per-residue centres of mass instead of individual atoms.

    Returns
    -------
    numpy.ndarray, shape (n_snippets, n_grids**3)
        Flattened maps, also written to ``output``.

    Raises
    ------
    ValueError
        If the trajectory is too short to produce a single snippet.
    """
    ends = np.arange(0, universe.trajectory.n_frames, n_steps)
    n_snippets = ends.shape[0] - 1
    if n_snippets < 1:
        raise ValueError(
            f"Trajectory with {universe.trajectory.n_frames} frames is too short "
            f"for snippets of {n_steps} frames."
        )
    print(f"Will calculate {n_snippets} density maps...")

    protein = universe.select_atoms("protein")
    atoms = universe.select_atoms(selection)
    x_grids, y_grids, z_grids = make_grid(dimension, n_grids)
    maps = []

    for i in range(n_snippets):
        print(f"At snippet {i + 1} / {n_snippets}", end="\r")
        snippet = np.zeros((n_grids, n_grids, n_grids))
        for ts in universe.trajectory[ends[i]:ends[i + 1]]:
            com = protein.center_of_mass()
            pos = atoms.center_of_mass(compound="residues") if use_com else atoms.positions
            voxels = map_to_grid_points(pos - com, x_grids, y_grids, z_grids)
            # np.add.at counts every point, even when several share a voxel
            # (a fancy-indexed += would count repeated voxels only once).
            np.add.at(snippet, voxels, 1)
        snippet /= pos.shape[0] * n_steps
        maps.append(snippet.reshape(1, n_grids**3))

    maps = np.concatenate(maps, axis=0)
    np.save(output, maps)
    return maps



In [32]:
def plot_heatmap(maps, vmax):
    """Show the mean of flattened cubic density maps as two projected heatmaps.

    ``maps`` is ``(n_maps, n**3)``, as produced by ``snippet_density_maps``.
    The mean map is reshaped to ``(n, n, n)`` and summed along its third axis
    separately for each half of that axis:

    - the left panel sums indices ``>= n // 2``;
    - the right panel sums indices ``< n // 2``.

    Both panels use the ``hot`` colour map on ``[0, vmax]``.

    Raises
    ------
    ValueError
        If ``maps.shape[1]`` is not a perfect cube.
    """
    n_voxels = maps.shape[1]
    # Round the cube root rather than ceil a float power: n**(1/3) can land a hair
    # off the integer, and ceil then overshoots.
    edge = int(round(np.cbrt(n_voxels)))
    if edge ** 3 != n_voxels:
        raise ValueError(f"Map size {n_voxels} is not a perfect cube.")
    full = np.mean(maps, axis=0).reshape(edge, edge, edge)
    half = edge // 2
    lower = np.sum(full[:, :, half:], axis=2)
    upper = np.sum(full[:, :, :half], axis=2)

    fig, ax = plt.subplots(1, 2, figsize=(6, 3))
    ax[0].imshow(lower, cmap='hot', interpolation='nearest', vmin=0, vmax=vmax)
    ax[1].imshow(upper, cmap='hot', interpolation='nearest', vmin=0, vmax=vmax)
    plt.tight_layout()
    plt.show()


# Activation-state metrics (`calculate_tm36`, `calculate_A100`) and per-snippet averaging

In [33]:
def calculate_tm36(universe, resid_a=100, resid_b=241):
    """Return the C-alpha distance between two residues in the current frame.

    The defaults (resids 100 and 241) reproduce the notebook's original
    measurement. The function name suggests a TM3–TM6 distance; the residue
    numbers are system-specific.

    Parameters
    ----------
    universe : MDAnalysis.Universe
    resid_a, resid_b : int
        Residue numbers whose CA atoms are measured.

    Returns
    -------
    float
        The distance computed by ``distance_array``.

    Raises
    ------
    ValueError
        If either selection does not contain exactly one CA atom. With several
        chains, ``[0][0]`` would otherwise silently use the first pair.
    """
    ca_a = universe.select_atoms(f"resid {resid_a} and name CA")
    ca_b = universe.select_atoms(f"resid {resid_b} and name CA")
    for resid, group in ((resid_a, ca_a), (resid_b, ca_b)):
        if group.n_atoms != 1:
            raise ValueError(f"Expected exactly one CA atom for resid {resid}, found {group.n_atoms}.")
    return distance_array(ca_a.positions, ca_b.positions)[0][0]

In [34]:
def calculate_A100(universe,
                   pairs=((23, 297), (48, 87), (92, 119), (196, 241), (265, 277)),
                   coefs=(-14.43, -7.62, 9.11, -6.32, -5.22),
                   intercept=278.88):
    """Return the A100 index for the current frame.

    A100 = intercept + sum_k coefs[k] * d_k, where d_k is the Euclidean
    distance between the CA atoms of the k-th residue pair in ``pairs``. The
    defaults are the pairs and linear-combination coefficients originally
    hard-coded here.

    Parameters
    ----------
    universe : MDAnalysis.Universe
    pairs : sequence of (int, int)
        Residue-number pairs whose CA–CA distances enter the sum.
    coefs : sequence of float
        One coefficient per pair.
    intercept : float
        Constant added at the end.

    Returns
    -------
    float

    Raises
    ------
    ValueError
        If ``pairs`` and ``coefs`` differ in length, or a residue does not
        select exactly one CA atom. Several atoms would otherwise be flattened
        into one long vector, giving a meaningless distance.
    """
    if len(pairs) != len(coefs):
        raise ValueError(f"Got {len(pairs)} residue pairs but {len(coefs)} coefficients.")
    a100 = 0
    for (resid1, resid2), coef in zip(pairs, coefs):
        p1 = universe.select_atoms(f"name CA and resid {resid1}")
        p2 = universe.select_atoms(f"name CA and resid {resid2}")
        for resid, group in ((resid1, p1), (resid2, p2)):
            if group.n_atoms != 1:
                raise ValueError(f"Expected exactly one CA atom for resid {resid}, found {group.n_atoms}.")
        R = scipy.spatial.distance.euclidean(p1.positions.reshape(-1), p2.positions.reshape(-1))
        a100 += coef * R
    a100 += intercept
    return a100

In [35]:
def snippet_coordinates(universe, selection, n_steps, output):
    """Average the selection's coordinates over consecutive blocks of ``n_steps`` frames.

    Block boundaries are ``np.arange(0, n_frames, n_steps)``. Block ``k``
    spans frames ``[bounds[k], bounds[k+1])``, and frames after the last
    boundary are ignored. Each block yields one row holding the mean flattened
    ``(x1, y1, z1, x2, ...)`` coordinates of the selected atoms. The stacked
    ``(n_blocks, 3 * n_atoms)`` array is saved to ``output`` with ``np.save``
    and returned.
    """
    group = universe.select_atoms(selection)
    bounds = np.arange(0, universe.trajectory.n_frames, n_steps)
    n_blocks = bounds.shape[0] - 1
    row_width = group.n_atoms * 3

    block_means = []
    for k, (start, stop) in enumerate(zip(bounds[:-1], bounds[1:])):
        print(f"At snippet {k + 1} / {n_blocks}", end="\r")
        frame_rows = [group.positions.reshape(1, row_width)
                      for _ in universe.trajectory[start:stop:1]]
        block_means.append(np.mean(frame_rows, axis=0))

    pos = np.concatenate(block_means, axis=0)
    np.save(output, pos)
    return pos


In [36]:

def snippet_function_values(universe, function, n_steps, output):
    """Average ``function(universe)`` over consecutive blocks of ``n_steps`` frames.

    Block boundaries are ``np.arange(0, n_frames, n_steps)``. Block ``k``
    spans frames ``[bounds[k], bounds[k+1])``, and frames after the last
    boundary are ignored. ``function`` is evaluated once per frame while the
    trajectory sits on that frame. The per-block means are saved to
    ``output`` with ``np.save`` and returned as a 1-D array.
    """
    bounds = np.arange(0, universe.trajectory.n_frames, n_steps)
    n_blocks = bounds.shape[0] - 1
    print(f"Will calculate {n_blocks} function values...")

    block_means = []
    for k, (start, stop) in enumerate(zip(bounds[:-1], bounds[1:])):
        print(f"At snippet {k + 1} / {n_blocks}", end="\r")
        per_frame = [function(universe) for _ in universe.trajectory[start:stop:1]]
        block_means.append(np.mean(per_frame))

    function_values = np.array(block_means)
    np.save(output, function_values)
    return function_values


# Binding site analysis

In [37]:
def find_sites(density_maps, percentile, plot=False):
    """Find clusters of high-density voxels in the mean of flattened cubic density maps.

    Parameters
    ----------
    density_maps : numpy.ndarray, shape (n_maps, n**3)
        Flattened cubic maps, e.g. from ``snippet_density_maps``.
    percentile : float
        Voxels whose mean density is not below this percentile of all voxels,
        and is non-zero, are kept.
    plot : bool, default False
        Show side and top 3-D scatter views of the kept voxels, coloured by
        cluster label.

    Returns
    -------
    list of numpy.ndarray
        One ``(n_voxels_in_site, 3)`` array of voxel indices per DBSCAN
        cluster, in ascending label order. Noise points (label -1) are
        dropped. The list is empty when no voxel passes the threshold.

    Raises
    ------
    ValueError
        If ``density_maps.shape[1]`` is not a perfect cube.
    """
    n_voxels = density_maps.shape[1]
    # Round the cube root: n**(1/3) often lands just below the integer
    # (64**(1/3) == 3.999...). int() then truncated the plot limits by one voxel.
    edge = int(round(np.cbrt(n_voxels)))
    if edge ** 3 != n_voxels:
        raise ValueError(f"Map size {n_voxels} is not a perfect cube.")
    density = np.mean(density_maps, axis=0).reshape(edge, edge, edge)

    threshold = np.percentile(density, percentile)
    high = np.argwhere(~(density < threshold) & (density != 0))
    if high.shape[0] == 0:
        return []

    # eps just above sqrt(3): voxels sharing a face, edge or corner are neighbours.
    labels = DBSCAN(eps=np.sqrt(3) + .1, min_samples=3).fit(high).labels_

    if plot:
        fig = plt.figure(figsize=(10, 5))
        side = fig.add_subplot(121, projection='3d')
        top = fig.add_subplot(122, projection='3d')
        for axis, title, (elev, azim) in ((side, "Side view", (0, 0)), (top, "Top view", (90, 0))):
            axis.view_init(elev, azim)
            axis.scatter(high[:, 0], high[:, 1], high[:, 2], c=labels)
            axis.set_title(title)
            axis.set_xlim(0, edge)
            axis.set_ylim(0, edge)
            axis.set_zlim(0, edge)
        plt.show()

    return [high[labels == label] for label in sorted(set(labels)) if label != -1]


In [38]:

def check_map_occupations(density_map, sites):
    """Return the mean density of ``density_map`` over the voxels of each site.

    ``density_map`` is indexed as a 3-D array. Each element of ``sites`` is an
    integer array with one ``(i, j, k)`` voxel index per row, the format
    returned by ``find_sites``. One mean is returned per site, in input order.
    """
    def _site_mean(site):
        i, j, k = site[:, 0], site[:, 1], site[:, 2]
        return np.mean(density_map[i, j, k])

    return [_site_mean(site) for site in sites]

In [39]:
def calculate_resids_of_site(universe, site_grids, dimension, n_grids, use_resids=False):
    """Return the protein residues closest to the voxels of one site.

    Each voxel index ``(i, j, k)`` is turned into Cartesian coordinates with
    ``make_grid(dimension, n_grids)``. It is then matched to the nearest
    per-residue centre of mass of the protein, taken relative to the protein
    centre of mass. Coordinates come from the universe's current frame, and
    the universe is not modified.

    Parameters
    ----------
    universe : MDAnalysis.Universe
    site_grids : array-like of int, shape (n_voxels, 3)
        Voxel indices of one site, e.g. an element of ``find_sites(...)``.
    dimension : float
        Edge length of the cubic grid.
    n_grids : int
        Grid points per axis.
    use_resids : bool, default False
        False (original behaviour) returns 0-based positions within
        ``protein.residues``. These equal resid numbers only for a particular
        numbering, so verify before treating them as resids. True returns
        ``protein.residues.resids``.

    Returns
    -------
    set
        Nearest-residue indices (or resids), one entry per distinct residue.
    """
    protein = universe.select_atoms("protein")
    x_grids, y_grids, z_grids = make_grid(dimension, n_grids)
    voxels = np.asarray(site_grids, dtype=int).reshape(-1, 3)
    if voxels.shape[0] == 0:
        return set()

    # The centre of mass is linear in the positions, so subtracting the protein COM
    # from the residue COMs equals centring first. Unlike the old in-place
    # ``protein.positions -= ...``, this does not overwrite the universe's coordinates.
    residue_coms = protein.center_of_mass(compound="residues") - protein.center_of_mass()
    points = np.column_stack([x_grids[voxels[:, 0], 0],
                              y_grids[voxels[:, 1], 0],
                              z_grids[voxels[:, 2], 0]])
    nearest = np.argmin(distance_array(points, residue_coms), axis=1)
    if use_resids:
        return set(protein.residues.resids[nearest])
    return set(nearest)
        

In [40]:
def write_dx(maps, origin, delta, output):
    """Average flattened cubic density maps and export the result with gridData.

    Parameters
    ----------
    maps : numpy.ndarray, shape (n_maps, n**3)
        Flattened cubic maps; their mean is reshaped to ``(n, n, n)``.
    origin, delta
        Passed unchanged to ``gridData.Grid``.
    output : str or path-like
        File name handed to ``Grid.export`` as given. The old ``f"./{output}"``
        prefix turned an absolute path such as ``/tmp/x.dx`` into ``.//tmp/x.dx``.

    Raises
    ------
    ValueError
        If ``maps.shape[1]`` is not a perfect cube.
    """
    n_voxels = maps.shape[1]
    # Round the cube root instead of ceil-ing a float power that may overshoot.
    edge = int(round(np.cbrt(n_voxels)))
    if edge ** 3 != n_voxels:
        raise ValueError(f"Map size {n_voxels} is not a perfect cube.")
    density = np.mean(maps, axis=0).reshape(edge, edge, edge)
    g = Grid(density, origin=origin, delta=delta)
    g.export(str(output))

# Minimum distance from each selected protein atom to the lipids

In [41]:

def min_dists(protein_positions, lipid_positions):
    """Return, for every protein position, the distance to its closest lipid position.

    The result is a ``(1, n_protein)`` row vector, so per-frame results can be
    stacked with ``np.concatenate(..., axis=0)``.
    """
    pairwise = distance_array(protein_positions, lipid_positions)
    closest = pairwise.min(axis=1)
    return closest[np.newaxis, :]


In [42]:

def minimum_distance_map(universe, protein_select, lipid_select, n_steps):
    """Stack per-frame minimum protein–lipid distances, one row per sampled frame.

    Frames come from ``universe.trajectory[0:-1:n_steps]``: every
    ``n_steps``-th frame starting at 0, never including the final frame. Each
    row holds, for every atom in ``protein_select``, its distance to the
    nearest atom in ``lipid_select`` (see ``min_dists``).
    """
    protein_atoms = universe.select_atoms(protein_select)
    lipid_atoms = universe.select_atoms(lipid_select)

    def _rows():
        for ts in universe.trajectory[0:-1:n_steps]:
            print(f"At frame {ts.frame}", end="\r")
            yield min_dists(protein_atoms.positions, lipid_atoms.positions)

    return np.concatenate(list(_rows()), axis=0)
    

# Cross contact map

In [43]:
def cross_contact(universe, protein_sel, lipid_sel, use_protein_coms=False, use_lipid_coms=False):
    """Return all protein–lipid pairwise distances for every frame.

    Each frame contributes one row: the ``(n_protein, n_lipid)`` distance
    matrix flattened in row-major order. With ``use_protein_coms`` or
    ``use_lipid_coms``, the corresponding side is represented by per-residue
    centres of mass instead of atoms. Every frame of the trajectory is visited.
    """
    protein = universe.select_atoms(protein_sel)
    lipids = universe.select_atoms(lipid_sel)

    def _coords(group, per_residue):
        if per_residue:
            return group.center_of_mass(compound="residues")
        return group.positions

    rows = []
    for ts in universe.trajectory:
        prot_pos = _coords(protein, use_protein_coms)
        lipid_pos = _coords(lipids, use_lipid_coms)
        print(f"Currently on frame {ts.frame}", end="\r")
        rows.append(scipy.spatial.distance.cdist(prot_pos, lipid_pos).reshape(1, -1))
    return np.concatenate(rows, axis=0)

# Clustering helper functions

In [44]:
def plot_bics(data, max_n, n_repeats=5, random_state=None):
    """Plot Gaussian-mixture BIC against the number of mixture components.

    For each ``n_components`` in ``2 .. max_n - 1`` a ``GaussianMixture`` is
    fitted to ``data`` and its BIC on ``data`` recorded. The sweep is repeated
    ``n_repeats`` times, one line per repeat, to show the spread caused by
    random initialisation.

    Parameters
    ----------
    data : array-like, shape (n_samples, n_features)
    max_n : int
        Exclusive upper bound on the number of components.
    n_repeats : int, default 5
        Number of independent sweeps (previously hard-coded to 5).
    random_state : int or None, default None
        When given, seeds one ``numpy.random.RandomState`` shared by all fits.
        The figure is then reproducible while each fit still gets a different
        initialisation. ``None`` keeps the unseeded behaviour.
    """
    rng = None if random_state is None else np.random.RandomState(random_state)
    component_counts = list(range(2, max_n))
    for _ in range(n_repeats):
        bics = [GaussianMixture(n_components=n, random_state=rng).fit(data).bic(data)
                for n in component_counts]
        plt.plot(component_counts, bics)
    plt.xlabel("Number of components")
    plt.ylabel("BIC")
    plt.show()

In [45]:
def reorder_clusters(data, fvals, predictions):
    """Relabel clusters 0..k-1 in ascending order of their mean function value.

    ``predictions`` holds one cluster label per point and ``fvals`` the
    matching function values. The cluster with the lowest mean ``fvals``
    becomes 0, the next becomes 1, and so on. ``data`` is only used for the
    number of points. Returns a float array of length ``data.shape[0]``.
    """
    cluster_means = []
    for label in set(predictions):
        members = np.where(predictions == label)[0]
        cluster_means.append((label, np.mean(fvals[members])))
    cluster_means.sort(key=lambda pair: pair[1])

    relabelled = np.zeros(data.shape[0])
    for new_label, (old_label, _) in enumerate(cluster_means):
        relabelled[np.where(predictions == old_label)[0]] = new_label
    return relabelled

In [24]:
def check_robustness(n_models, subset_size, data, fvals, n_components, random_state=None):
    """Report how consistently points are clustered by GMMs fitted on random subsets.

    ``n_models`` Gaussian mixtures with ``n_components`` components are each
    fitted on a random subset of ``data``. The subset is the test part of
    ``train_test_split(..., test_size=subset_size)``, so ``subset_size`` is
    the fraction (or count) actually used for fitting. Each model labels every
    point, and labels are made comparable across models with
    ``reorder_clusters`` (ordered by mean ``fvals``). For every point, the
    share of models agreeing on its most frequent label is computed. The
    mean, standard deviation and minimum of these shares are printed.

    Parameters
    ----------
    n_models : int
    subset_size : float or int
    data : numpy.ndarray, shape (n_points, n_features)
    fvals : numpy.ndarray, shape (n_points,)
    n_components : int
    random_state : int or None, default None
        When given, seeds one ``RandomState`` used for all splits and fits, so
        the report is reproducible. ``None`` keeps the unseeded behaviour.
    """
    rng = None if random_state is None else np.random.RandomState(random_state)
    clusterings = np.zeros((n_models, data.shape[0]))

    for n in range(n_models):
        _, X_subset, _, _ = train_test_split(data, fvals, test_size=subset_size, random_state=rng)
        model = GaussianMixture(n_components=n_components, random_state=rng).fit(X_subset)
        clusterings[n, :] = reorder_clusters(data, fvals, model.predict(data))

    # fractions[p, i]: share of models that put point i in (reordered) cluster p.
    fractions = np.stack([(clusterings == p).mean(axis=0) for p in range(n_components)])
    highest_clustering_fractions = fractions.max(axis=0)

    print(f"Mean fraction: {np.mean(highest_clustering_fractions)}, fraction STD {np.std(highest_clustering_fractions)} worst fraction: {np.amin(highest_clustering_fractions)}")
        
    
    
        

# Cylindrical coordinate mapping

In [5]:

def make_cylindrical_grid(R_max, z_min, z_max, n_R, n_theta, n_z):
    """Return evenly spaced radial, angular and axial grid coordinates.

    - ``R`` runs from 0 to ``R_max`` (``n_R`` points).
    - ``theta`` runs from 0 to ``2*pi`` (``n_theta`` points, both ends included).
    - ``z`` runs from ``z_min`` to ``z_max`` (``n_z`` points).

    The three 1-D arrays are returned in the order ``(R, theta, z)``.
    """
    axis_specs = {
        "R": (0, R_max, n_R),
        "theta": (0, 2 * np.pi, n_theta),
        "z": (z_min, z_max, n_z),
    }
    return tuple(np.linspace(*axis_specs[name]) for name in ("R", "theta", "z"))
    

In [48]:

def cartesian_to_cylindrical(X, Y):
    """Convert Cartesian ``(X, Y)`` coordinates to polar ``(R, theta)``.

    Parameters
    ----------
    X, Y : array-like
        Coordinates; both are flattened to 1-D and must have the same length.

    Returns
    -------
    R : numpy.ndarray
        ``sqrt(X**2 + Y**2)``.
    thetas : numpy.ndarray of float64
        Angle measured counter-clockwise from +X, mapped into the range 0 to
        ``2*pi``. Points on the axes get their true angle: the previous
        quadrant-by-quadrant version left any point with ``X == 0`` or
        ``Y == 0`` at 0. The origin maps to 0.
    """
    X = np.asarray(X).reshape(-1)
    Y = np.asarray(Y).reshape(-1)
    R = np.hypot(X, Y)
    thetas = np.mod(np.arctan2(Y.astype(np.float64), X.astype(np.float64)), 2 * np.pi)
    # arctan2(-0.0, -0.0) is -pi; pin the angle of the origin to 0 as before.
    thetas[R == 0] = 0.0
    return R, thetas



In [65]:
def find_mean_vector(universe,
                     top="1 62 73 143 167 264 276",
                     bottom="28 39 102 118 193 241 296"):
    """Return the unit vector pointing from the "bottom" CA atoms to the "top" CA atoms.

    The vector is the mean of ``top_CA - bottom_CA`` over the CA atoms of the
    two space-separated resid lists, normalised to length 1. The defaults are
    the residue numbers originally hard-coded here, presumably one end of each
    transmembrane helix (confirm for other systems).

    Parameters
    ----------
    universe : MDAnalysis.Universe
        Positions are read from the current frame.
    top, bottom : str
        Space-separated resid lists used in ``"resid ... and name CA"`` selections.

    Returns
    -------
    numpy.ndarray, shape (3,)

    Raises
    ------
    ValueError
        If the selections are empty or differ in size (the difference would
        otherwise fail or broadcast silently), or if the mean vector has zero
        length.
    """
    bottom_pos = universe.select_atoms(f"resid {bottom} and name CA").positions
    top_pos = universe.select_atoms(f"resid {top} and name CA").positions
    if top_pos.shape != bottom_pos.shape or top_pos.shape[0] == 0:
        raise ValueError(
            f"Top and bottom selections must be non-empty and equal in size "
            f"(got {top_pos.shape[0]} and {bottom_pos.shape[0]} CA atoms)."
        )
    mean_vector = (top_pos - bottom_pos).mean(axis=0)
    length = np.linalg.norm(mean_vector)
    if length == 0:
        raise ValueError("Mean top-bottom vector has zero length.")
    return mean_vector / length
    

In [67]:
def find_principal_axis(positions):
    """Return the first principal component of ``positions`` as a direction vector.

    The sign is flipped when needed so that the x-component is not negative,
    which makes the reported direction stable between calls.
    """
    leading = PCA(n_components=1).fit(positions).components_[0]
    return -leading if leading[0] < 0 else leading



In [20]:
def calculate_rotation_matrix(principal_axis):
    """Return the 3x3 rotation matrix that maps ``principal_axis`` onto +z.

    Uses Rodrigues' formula R = I + [v]x + [v]x^2 / (1 + c), with
    v = a x z and c = a . z for the normalised axis a
    (https://math.stackexchange.com/questions/180418). Apply as ``R @ p``.

    Parameters
    ----------
    principal_axis : array-like, shape (3,)
        Direction to align. It is normalised here; the formula is only a
        rotation for unit vectors.

    Returns
    -------
    numpy.ndarray, shape (3, 3)

    Raises
    ------
    ValueError
        If ``principal_axis`` has zero length.
    """
    axis = np.asarray(principal_axis, dtype=float).reshape(3)
    norm = np.linalg.norm(axis)
    if norm == 0:
        raise ValueError("Cannot align a zero-length vector with the z axis.")
    axis = axis / norm

    z_vec = np.array([0.0, 0.0, 1.0])
    cos = np.dot(axis, z_vec)
    # For an axis along -z, 1 + cos is 0 and the formula divides by zero;
    # a 180-degree rotation about x maps -z onto +z instead.
    if 1.0 + cos < 1e-12:
        return np.diag([1.0, -1.0, -1.0])

    cross = np.cross(axis, z_vec)
    skew = np.array([[0.0, -cross[2], cross[1]],
                     [cross[2], 0.0, -cross[0]],
                     [-cross[1], cross[0], 0.0]])
    return np.eye(3) + skew + skew @ skew / (1.0 + cos)

In [68]:

def cylindrical_density_map(universe, selection, R_max, z_min,
                            z_max, n_R, n_theta, n_z, output,
                            use_com=False, skip=1):
    """Compute per-frame occupancy maps of ``selection`` on a cylindrical grid around the protein.

    For every ``skip``-th frame, starting at frame 0:

    1. compute the rotation that aligns ``find_mean_vector(universe)`` with +z;
    2. take the selected atom positions (or per-residue centres of mass when
       ``use_com`` is True) relative to the protein centre of mass, and rotate
       them into that frame;
    3. convert each point to ``(R, theta, z)`` and snap it to the nearest grid
       point.

    Points whose nearest radial grid point is the extra outermost one (R near
    or beyond ``R_max``) are discarded. ``z`` values outside
    ``[z_min, z_max]`` are assigned to the nearest edge layer. The positions
    stored in ``universe`` are not modified.

    Parameters
    ----------
    universe : MDAnalysis.Universe
    selection : str
        MDAnalysis selection string for the atoms to map.
    R_max, z_min, z_max : float
        Radial extent and axial range of the grid.
    n_R, n_theta, n_z : int
        Number of bins along R, theta and z.
    output : str
        Path passed to ``np.save``.
    use_com : bool, default False
        Map per-residue centres of mass instead of atoms.
    skip : int, default 1
        Frame stride.

    Returns
    -------
    numpy.ndarray, shape (n_frames_used, n_R * n_theta * n_z)
        Counts divided by the number of points in the first processed frame;
        also saved to ``output``.
    """
    protein = universe.select_atoms("protein")
    atoms = universe.select_atoms(selection)
    # n_R + 1 radial points: the point with index n_R acts as the "out of range" bin dropped below.
    R_grids, theta_grids, z_grids = make_cylindrical_grid(R_max=R_max, z_min=z_min,
                                                          z_max=z_max, n_R=n_R + 1,
                                                          n_theta=n_theta, n_z=n_z)
    R_grids = R_grids.reshape(-1, 1)
    theta_grids = theta_grids.reshape(-1, 1)
    z_grids = z_grids.reshape(-1, 1)

    maps = []
    normalizer = None
    for ts in universe.trajectory[::skip]:
        print(f"At frame {ts.frame} / {universe.trajectory.n_frames}", end="\r")

        rotation_matrix = calculate_rotation_matrix(find_mean_vector(universe))
        com = protein.center_of_mass()
        pos = atoms.center_of_mass(compound="residues") if use_com else atoms.positions
        # Rotating the centred points equals rotating the whole system and then centring.
        # It touches only the selected points and never writes back into the universe.
        pos = (pos - com) @ rotation_matrix.T
        if normalizer is None:
            normalizer = pos.shape[0]

        R, theta = cartesian_to_cylindrical(pos[:, 0], pos[:, 1])
        cylindrical_pos = np.column_stack([R, theta, pos[:, 2]])
        R_closest, theta_closest, z_closest = map_to_grid_points(cylindrical_pos, R_grids,
                                                                 theta_grids, z_grids)
        keep = R_closest != n_R
        frame_map = np.zeros((n_R, n_theta, n_z))
        # np.add.at counts every point, even when several fall in the same bin.
        np.add.at(frame_map, (R_closest[keep], theta_closest[keep], z_closest[keep]), 1)
        maps.append(frame_map.reshape(1, n_R * n_theta * n_z))

    maps = np.concatenate(maps, axis=0)
    maps /= normalizer
    np.save(output, maps)
    return maps
    
    


In [50]:

def protein_plane(protein_coordinates, n_atoms=10, plot=False):
    """Fit a plane ``z = a*x + b*y + c`` through the atoms closest to ``z = 0``.

    The ``n_atoms`` rows with the smallest ``|z|`` are selected, and the plane
    is fitted to them by least squares (``scipy.linalg.lstsq``).

    Parameters
    ----------
    protein_coordinates : numpy.ndarray, shape (n, 3)
    n_atoms : int, default 10
        Number of atoms used for the fit; must be between 3 and ``n``.
    plot : bool, default False
        If True, draw the selected atoms and the fitted plane on a 3-D axis
        with fixed limits (x, y in [-50, 50], z in [-10, 10]).

    Returns
    -------
    numpy.ndarray, shape (3,)
        Plane coefficients ``(a, b, c)``.

    Raises
    ------
    ValueError
        If ``n_atoms`` is outside ``[3, n]``.
    """
    coords = np.asarray(protein_coordinates)
    n_total = coords.shape[0]
    if not 3 <= n_atoms <= n_total:
        raise ValueError(f"n_atoms must be between 3 and {n_total}, got {n_atoms}.")
    # kth = n_atoms - 1 still puts the n_atoms smallest |z| first. Unlike kth = n_atoms,
    # it stays valid when every atom is requested.
    closest_to_0_z = np.argpartition(np.abs(coords[:, 2]), n_atoms - 1)[:n_atoms]
    selected = coords[closest_to_0_z]
    A = np.column_stack([selected[:, 0], selected[:, 1], np.ones(n_atoms)])
    B = selected[:, 2]
    fit = lstsq(A, B)[0]

    if plot:
        fig = plt.figure()
        ax = fig.add_subplot(projection='3d')
        ax.scatter(selected[:, 0], selected[:, 1], selected[:, 2])
        ax.set_xlim(-50, 50)
        ax.set_ylim(-50, 50)
        ax.set_zlim(-10, 10)
        X, Y = np.meshgrid(np.arange(-50, 50), np.arange(-50, 50))
        Z = fit[0] * X + fit[1] * Y + fit[2]
        ax.plot_wireframe(X, Y, Z, color='k')
        ax.view_init(0, 10)

    return fit
    


In [69]:
def plot_radial_density(maps, R_max, gro, shape=None):
    """Plot the mean cylindrical density, split into two halves along z, on polar axes.

    ``maps`` is ``(n_maps, n_R * n_theta * n_z)``, as produced by
    ``cylindrical_density_map``. The mean map is split at the first z-layer
    where the cumulative density exceeds half of the total. Layers before the
    split are summed into the "Lower Leaflet" panel and the rest into the
    "Upper Leaflet" panel, with a shared colour scale. The totals of both
    halves are printed.

    CA atoms of two hard-coded resid lists are read from ``gro`` and
    transformed the same way as in ``cylindrical_density_map`` (rotated onto
    ``find_mean_vector``, centred on the protein). They are drawn as numbered
    circles.

    Parameters
    ----------
    maps : numpy.ndarray
    R_max : float
        Radial extent used when the maps were built.
    gro : str
        Structure file loaded with ``mda.Universe`` for the markers.
    shape : tuple of int, optional
        ``(n_R, n_theta, n_z)``. Defaults to a cube, which only works when the
        three grid sizes are equal.
    """
    if shape is None:
        edge = int(round(np.cbrt(maps.shape[1])))
        shape = (edge, edge, edge)
    mean_map = np.mean(maps, axis=0).reshape(shape)

    # First z-layer count at which the cumulative density passes half of the total.
    # An all-zero map never passes, so fall back to the middle layer; the old loop
    # left ``midway`` undefined in that case.
    cumulative = np.cumsum(mean_map.sum(axis=(0, 1)))
    above_half = cumulative > 0.5 * mean_map.sum()
    midway = int(np.argmax(above_half)) + 1 if above_half.any() else mean_map.shape[2] // 2

    lower = mean_map[:, :, :midway].sum(axis=2).T
    upper = mean_map[:, :, midway:].sum(axis=2).T
    vmax = np.max([upper.max(), lower.max()])

    rad = np.linspace(0, R_max, mean_map.shape[0])
    azm = np.linspace(0, 2 * np.pi, mean_map.shape[1])
    r, th = np.meshgrid(rad, azm)

    print(f"Density in lower leaflet: {lower.sum()}")
    print(f"Density in upper leaflet: {upper.sum()}")

    univ = mda.Universe(gro)
    rotation_matrix = calculate_rotation_matrix(find_mean_vector(univ))
    univ.atoms.positions = (rotation_matrix @ univ.atoms.positions.T).T
    univ.atoms.positions -= univ.select_atoms("protein").center_of_mass()
    upper_ends = univ.select_atoms("resid 1 62 73 143 167 267 275 and name CA").positions
    lower_ends = univ.select_atoms("resid 30 37 104 116 197 240 297 and name CA").positions

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(6, 3), subplot_kw={"projection": "polar"})
    for ax, layers, ends, title in ((ax1, lower, lower_ends, "Lower Leaflet"),
                                    (ax2, upper, upper_ends, "Upper Leaflet")):
        ax.pcolormesh(th, r, layers, cmap="Reds", vmin=0, vmax=vmax)
        ax.grid()
        ax.set_title(title)
        ax.set_rticks([])
        ax.set_thetagrids([], [])
        R_end, theta_end = cartesian_to_cylindrical(ends[:, 0], ends[:, 1])
        ax.scatter(theta_end, R_end, marker="o", facecolors="None", edgecolors="k", s=200)
        # Number each panel's markers from its own count (previously the upper count was used for both).
        for k in range(R_end.shape[0]):
            ax.text(theta_end[k], R_end[k], str(k + 1), ha="center", va="center", color="k")

    plt.tight_layout()
    plt.show()

In [76]:
def plot_cluster_densities(predictions, maps, fvals, R_max, gro=None, shape=None):
    """Plot each cluster's mean cylindrical density as a pair of polar panels.

    There is one row per distinct label in ``predictions``, in sorted label
    order. For each cluster the mean map is split along z at the first layer
    where the cumulative density exceeds half of the cluster's total. An
    all-zero map splits at the middle layer. The layers before the split are
    summed into the left "Lower Leaflet" panel and the rest into the right
    "Upper Leaflet" panel; both panels of a row share one colour scale.

    Parameters
    ----------
    predictions : array-like, shape (n_maps,)
        Cluster label per map.
    maps : numpy.ndarray, shape (n_maps, n_R * n_theta * n_z)
    fvals : array-like
        Not used by the plot; kept so existing calls keep working.
    R_max : float
        Radial extent used when the maps were built.
    gro : str, optional
        Structure file used to draw numbered CA markers for two hard-coded
        resid lists. The previous version silently read a global ``gro``.
        When omitted, no markers are drawn.
    shape : tuple of int, optional
        ``(n_R, n_theta, n_z)``. Defaults to a cube, which only works when the
        three grid sizes are equal.
    """
    predictions = np.asarray(predictions)
    labels = sorted(set(predictions))
    if shape is None:
        edge = int(round(np.cbrt(maps.shape[1])))
        shape = (edge, edge, edge)

    rad = np.linspace(0, R_max, shape[0])
    azm = np.linspace(0, 2 * np.pi, shape[1])
    r, th = np.meshgrid(rad, azm)

    # Markers depend only on the structure file, so compute them once rather than per cluster.
    markers = None
    if gro is not None:
        univ = mda.Universe(gro)
        rotation_matrix = calculate_rotation_matrix(find_mean_vector(univ))
        univ.atoms.positions = (rotation_matrix @ univ.atoms.positions.T).T
        univ.atoms.positions -= univ.select_atoms("protein").center_of_mass()
        lower_ends = univ.select_atoms("resid 30 37 104 116 197 240 297 and name CA").positions
        upper_ends = univ.select_atoms("resid 1 62 73 143 167 267 275 and name CA").positions
        markers = (cartesian_to_cylindrical(lower_ends[:, 0], lower_ends[:, 1]),
                   cartesian_to_cylindrical(upper_ends[:, 0], upper_ends[:, 1]))

    fig, ax = plt.subplots(len(labels), 2, figsize=(6, 3 * len(labels)),
                           subplot_kw={"projection": "polar"})
    ax = np.ravel(ax)

    for ind, label in enumerate(labels):
        mean_map = np.mean(maps[predictions == label], axis=0).reshape(shape)

        cumulative = np.cumsum(mean_map.sum(axis=(0, 1)))
        above_half = cumulative > 0.5 * mean_map.sum()
        midway = int(np.argmax(above_half)) + 1 if above_half.any() else mean_map.shape[2] // 2

        lower = mean_map[:, :, :midway].sum(axis=2).T
        upper = mean_map[:, :, midway:].sum(axis=2).T
        vmax = np.max([upper.max(), lower.max()])

        for col, (layers, leaflet) in enumerate(((lower, "Lower Leaflet"), (upper, "Upper Leaflet"))):
            panel = ax[2 * ind + col]
            panel.pcolormesh(th, r, layers, cmap="Reds", vmin=0, vmax=vmax)
            panel.grid()
            panel.set_title(f"Cluster {ind + 1}\n{leaflet}")
            panel.set_rticks([])
            panel.set_thetagrids([], [])
            if markers is not None:
                R_end, theta_end = markers[col]
                panel.scatter(theta_end, R_end, marker="o", facecolors="None", edgecolors="k", s=200)
                for k in range(R_end.shape[0]):
                    panel.text(theta_end[k], R_end[k], str(k + 1), ha="center", va="center", color="k")

    plt.tight_layout()
    plt.show()
    
    

SyntaxError: invalid syntax (1923135047.py, line 69)

In [9]:
def load_dataset(filename, skip):
    """Memory-map a saved ``.npy`` array and keep every ``skip``-th row.

    1-D arrays are first turned into a single column, so the result is always
    2-D. The returned object is a read-only view of the memory-mapped file.
    """
    array = np.load(filename, mmap_mode='r')
    if array.ndim == 1:
        array = array.reshape(-1, 1)
    return array[::skip, :]