In [None]:
import py3Dmol
import MDAnalysis as mda
import plumed
import os
import numpy as np
import matplotlib.pyplot as plt

%cd /root/output files

# Function Definitions

In [None]:
def parse_gro_atom(line):
    """Split one fixed-width atom record of a .gro file into its fields.

    The record is read purely by column position:
    residue number [0:5], residue name [5:10], atom name [10:15],
    atom number [15:20], then x, y, z in three 8-character fields starting
    at column 20. Anything beyond column 44 is ignored.

    Returns
    -------
    tuple
        (resnr, resname, atom, atomnr, position) where ``position`` is a
        length-3 NumPy float array.

    Raises
    ------
    ValueError
        If a numeric column cannot be converted.
    """
    resnr, resname = int(line[0:5]), line[5:10].strip()
    atom, atomnr = line[10:15].strip(), int(line[15:20])

    # Three consecutive 8-character coordinate columns.
    coords = [float(line[col:col + 8]) for col in (20, 28, 36)]

    return resnr, resname, atom, atomnr, np.array(coords)


def add_nacl_crystal_to_gro(gro_path, out_path, crystal_size = 3, lattice_const_nm = 0.564, cutoff_nm = 0.25):
    """Place a cubic block of Na/Cl ions at the centre of a .gro box.

    The input file is read, every residue that has at least one atom closer
    than ``cutoff_nm`` to any inserted ion is removed, the ions are appended
    after the surviving atoms (all NA first, then all CL), residues and atoms
    are renumbered from 1 and the result is written to ``out_path``.
    Only the columns returned by ``parse_gro_atom`` are written back, so any
    velocity columns in the input are not carried over.

    Parameters
    ----------
    gro_path : str
        Input .gro file.
    out_path : str
        Output .gro file (overwritten).
    crystal_size : int
        Number of cells along each axis; ``crystal_size**3`` NA and the same
        number of CL ions are generated.
    lattice_const_nm : float
        Edge length of one cell, in the coordinate units of the file.
    cutoff_nm : float
        Minimum allowed distance between an existing atom and an inserted ion.

    Raises
    ------
    ValueError
        If the file has fewer than three non-blank lines (title, atom count,
        box), or if a record cannot be parsed.

    Notes
    -----
    NA sits on the cell corners and CL is displaced by half a cell edge along
    all three axes, i.e. a body-centred arrangement.
    NOTE(review): rock-salt NaCl consists of two interpenetrating FCC
    sub-lattices; confirm that this simpler arrangement is the intended seed.
    The overlap test uses plain Cartesian distances (no periodic images).
    """
    # .gro number columns are 5 characters wide; larger values are written
    # modulo 100000 so the fixed-width layout is never broken.
    gro_wrap = 100000

    with open(gro_path, "r") as f:
        lines = f.readlines()

    # Ignore trailing blank lines so that the last line really is the box record.
    while lines and not lines[-1].strip():
        lines.pop()
    if len(lines) < 3:
        raise ValueError(f"{gro_path} is not a complete .gro file (title, atom count and box line are required)")

    header = lines[0].rstrip("\n")
    natoms_declared = int(lines[1].strip())
    atom_lines = lines[2:-1]
    box_line = lines[-1].strip()

    if natoms_declared != len(atom_lines):
        print(f"Warning: {gro_path} declares {natoms_declared} atoms but holds {len(atom_lines)} atom records")

    box = np.array([float(x) for x in box_line.split()[:3]])
    center_pt = box/2.0

    # ---------- Parse atoms ----------

    # A residue is a run of consecutive records sharing (resnr, resname).
    # Tracking runs instead of the raw key keeps residues apart even when the
    # 5-digit residue number has wrapped around in a large input file.
    atoms = []
    residue_ids = []     # per atom: index into residue_keys
    residue_keys = []    # per residue: (resnr, resname) as read from the file
    for line in atom_lines:
        resnr, resname, atomname, atomnr, pos = parse_gro_atom(line)
        key = (resnr, resname)
        if not residue_keys or residue_keys[-1] != key:
            residue_keys.append(key)
        residue_ids.append(len(residue_keys) - 1)
        atoms.append({"resname": resname, "atomname": atomname, "pos": pos})

    print(f"Read {len(atoms)} atoms")

    # ---------- Build NaCl crystal ----------

    cell_indices = np.array(
        [(i, j, k)
         for i in range(crystal_size)
         for j in range(crystal_size)
         for k in range(crystal_size)],
        dtype = float,
    ).reshape(-1, 3)  # reshape keeps a (0, 3) array when crystal_size is 0

    na_positions = cell_indices*lattice_const_nm
    cl_positions = na_positions + 0.5*lattice_const_nm

    # Shift the block so that a cube of edge crystal_size*lattice_const_nm
    # is centred on the box centre.
    extent = crystal_size*lattice_const_nm
    origin = center_pt - 0.5*extent
    crystal_positions = np.vstack([na_positions, cl_positions]) + origin
    crystal_types = ["NA"]*len(na_positions) + ["CL"]*len(cl_positions)

    print(f"Crystal atoms: {len(crystal_positions)}")

    # ---------- Remove overlapping residues ----------

    removed = set()
    if atoms and len(crystal_positions):
        atom_pos = np.array([atom["pos"] for atom in atoms])
        cutoff_sq = cutoff_nm**2

        members = {}
        for atom_idx, rid in enumerate(residue_ids):
            members.setdefault(rid, []).append(atom_idx)

        for rid, idxs in members.items():
            # Squared distance of every atom of this residue to every ion.
            diff = atom_pos[idxs][:, None, :] - crystal_positions[None, :, :]
            if np.min(np.sum(diff*diff, axis = 2)) < cutoff_sq:
                removed.add(rid)
                print(f"Removing residue {residue_keys[rid]}")

    # ---------- Renumber atoms & residues ----------

    records = []         # (resnr, resname, atomname, pos) in output order
    new_resnr = 0
    last_rid = None
    for rid, atom in zip(residue_ids, atoms):
        if rid in removed:
            continue
        if rid != last_rid:
            new_resnr += 1
            last_rid = rid
        records.append((new_resnr, atom["resname"], atom["atomname"], atom["pos"]))

    # Every inserted ion is its own residue, named after its type.
    for pos, typ in zip(crystal_positions, crystal_types):
        new_resnr += 1
        records.append((new_resnr, typ, typ, pos))

    # ---------- Write output .gro ----------

    with open(out_path, "w") as f:
        f.write(header + "\n")
        f.write(f"{len(records)}\n")
        for atomnr, (resnr, resname, atomname, pos) in enumerate(records, start = 1):
            f.write(
                f"{resnr % gro_wrap:5d}"
                f"{resname:<5s}"
                f"{atomname:>5s}"
                f"{atomnr % gro_wrap:5d}"
                f"{pos[0]:8.3f}"
                f"{pos[1]:8.3f}"
                f"{pos[2]:8.3f}"
                "\n"
            )
        f.write(box_line + "\n")

def update_topology_file(input_top, gro_file):
    """Rewrite the ``[ molecules ]`` section of a topology from a .gro file.

    Atom records of ``gro_file`` are counted by residue name (SOL, NA, CL).
    Everything in ``input_top`` up to and including the ``[ molecules ]``
    header is kept; everything after it is replaced by three lines giving the
    number of SOL molecules (SOL atoms divided by 3), NA and CL.
    ``input_top`` is overwritten in place.

    Parameters
    ----------
    input_top : str
        Topology file to update.
    gro_file : str
        Structure file whose contents the topology must describe.

    Raises
    ------
    ValueError
        If the number of SOL atoms is not a multiple of 3, or if the topology
        has no ``[ molecules ]`` section.
    """
    import re

    # ---------- Count atoms in the new .gro file ----------

    with open(gro_file, "r") as f:
        gro_lines = f.readlines()

    while gro_lines and not gro_lines[-1].strip():
        gro_lines.pop()

    # Only the atom records are inspected: the first two lines (title, atom
    # count) and the last line (box) are not atoms.
    counts = {"SOL": 0, "NA": 0, "CL": 0}
    for line in gro_lines[2:-1]:
        resname = line[5:10].strip()
        if resname in counts:
            counts[resname] += 1

    n_na = counts["NA"]
    n_cl = counts["CL"]

    # Molecule counts in a topology must be integers; three atoms per SOL.
    n_sol, leftover = divmod(counts["SOL"], 3)
    if leftover:
        raise ValueError(
            f"{gro_file} contains {counts['SOL']} SOL atoms, which is not a multiple of 3"
        )

    print(f"Detected in {gro_file}: SOL={n_sol}, NA={n_na}, CL={n_cl}")

    # ---------- Read original topology file ----------

    with open(input_top, "r") as f:
        lines = f.readlines()

    # ---------- Locate [ molecules ] section ----------

    # Accept any spacing inside the brackets and a trailing ';' comment.
    section_header = re.compile(r"^\[\s*molecules\s*\]$", re.IGNORECASE)

    molecules_index = None
    for i, line in enumerate(lines):
        if section_header.match(line.split(";", 1)[0].strip()):
            molecules_index = i
            break

    if molecules_index is None:
        raise ValueError("topol.top does not contain a [ molecules ] section!")

    # ---------- Keep everything above [ molecules ] and replace rest ----------

    new_lines = lines[:molecules_index + 1]
    if not new_lines[-1].endswith("\n"):
        # Header was the last line of the file without a final newline.
        new_lines[-1] += "\n"
    new_lines.append(f"SOL    {n_sol}\n")
    new_lines.append(f"NA     {n_na}\n")
    new_lines.append(f"CL     {n_cl}\n")

    # ---------- Write updated topology ----------

    with open(input_top, "w") as f:
        f.writelines(new_lines)

    print("topol.top successfully updated!")

# Initialization

In [None]:
# Fill a box given as "6 6 6" with solvent taken from spc216.gro and write it to water_box.gro.
!gmx solvate -cs spc216.gro -o water_box.gro -box 6 6 6
# Generate processed.gro and topol.top from the water box, selecting the oplsaa force field and the spce water model.
!gmx pdb2gmx -f water_box.gro -o processed.gro -p topol.top -ff oplsaa -water spce

In [None]:
# add_nacl_crystal_to_gro("water_box.gro", "seed.gro", crystal_size = 5, lattice_const_nm = 0.188, cutoff_nm = 0.15)
# update_topology_file("topol.top", "seed.gro")

In [None]:
# Parameter text for the grompp run that only has to produce ions.tpr:
# a single steepest-descent step with PME electrostatics and no constraints.
ions_mdp = """; ions.mdp - used for ion addition
integrator    = steep
nsteps        = 1
emtol         = 10000
cutoff-scheme = Verlet
coulombtype   = PME
rcoulomb      = 1.0
rvdw          = 1.0
pbc           = xyz
constraints   = none
"""

# Persist the parameters in the working directory (overwrites an existing ions.mdp).
with open("ions.mdp", "w") as f:
    f.write(ions_mdp)
print("Created ions.mdp")

In [None]:
# Preprocess processed.gro with ions.mdp to obtain ions.tpr, the run input handed to genion below.
!gmx grompp -f ions.mdp -c processed.gro -p topol.top -o ions.tpr
# "SOL" is piped to genion's group prompt; NA/CL ions are requested with -conc 8 and -neutral,
# the structure goes to solv_ions.gro and topol.top is passed via -p so it can be updated.
# NOTE(review): -rmin 0 sets the minimum ion-ion distance option to zero - confirm this is intended.
!echo SOL | gmx genion -s ions.tpr -o solv_ions.gro -p topol.top -pname NA -nname CL -conc 8 -neutral -rmin 0

# Visualization

In [None]:
# Load the solvated, ion-containing structure as raw text for the viewer.
with open("solv_ions.gro", "r") as f:
    pdb = f.read()

view = py3Dmol.view(width = 800, height = 800)
view.addModel(pdb, "gro")
view.setProjection("orthographic")

# Sphere style per residue name, applied in this order:
# large red Na, large lime Cl, small blue water atoms.
sphere_styles = [
    ("NA",  {"color": "red",  "radius": 0.5}),
    ("CL",  {"color": "lime", "radius": 0.5}),
    ("SOL", {"color": "blue", "radius": 0.1}),
]
for resn, sphere in sphere_styles:
    view.setStyle({"resn": resn}, {"sphere": sphere})

view.zoomTo()
view.show()

# Energy Minimization

In [None]:
# Parameter text for the energy-minimization stage: steepest descent for at most
# 200000 steps or until emtol = 1000 is reached, with h-bonds constraints
# and PME electrostatics.
em_mdp = """; Energy minimization
integrator    = steep
emtol         = 1000
nsteps        = 200000
emstep        = 0.01
constraints   = h-bonds
lincs-order   = 4
lincs-iter    = 2
cutoff-scheme = Verlet
coulombtype   = PME
rcoulomb      = 1.0
rvdw          = 1.0
pbc           = xyz
"""

# Persist the parameters as em.mdp (overwrites an existing file).
with open("em.mdp", "w") as f:
    f.write(em_mdp)

In [None]:
# ---------- Remove two lines from topol.top by position ----------
# Lines with 0-based index 25 and 30 are dropped.
# NOTE(review): this is position-based and not idempotent - running the cell a
# second time deletes two further lines. Confirm which directives these
# indices are meant to remove and consider matching on their text instead.
with open("topol.top", "r") as f:
    lines = f.readlines()

lines = [line for i, line in enumerate(lines) if i not in (25, 30)]

with open("topol.top", "w") as f:
    f.writelines(lines)

# ---------- Pin the first water molecule to (3, 3, 3) ----------
# The first three atom records of solv_ions.gro (file lines 3-5) are replaced
# by a water whose oxygen sits at 3.000 3.000 3.000, the midpoint of the
# "-box 6 6 6" cell; the PLUMED input written later uses atom 1 as its origin.
with open("solv_ions.gro", "r") as f:
    lines = f.readlines()

# .gro records are fixed-width (5-char residue number, 5-char residue name,
# 5-char atom name, 5-char atom number, 8-char coordinates), so the residue
# number must be right-aligned with spaces; a leading tab would shift every
# following column.
lines[2:5] = [
    "    1SOL     OW    1   3.000   3.000   3.000\n",
    "    1SOL    HW1    2   2.907   2.998   3.037\n",
    "    1SOL    HW2    3   3.001   2.961   2.908\n"
]

with open("solv_ions.gro", "w") as f:
    f.writelines(lines)

In [None]:
# Build em.tpr from em.mdp and the patched solv_ions.gro, then run it; outputs are named em.* via -deffnm.
# NOTE(review): -maxwarn 5 lets grompp continue past up to five warnings - check its output for what is being suppressed.
!gmx grompp -f em.mdp -c solv_ions.gro -r solv_ions.gro -p topol.top -o em.tpr -maxwarn 5
!gmx mdrun -deffnm em

# NVT Equilibration

In [None]:
# Parameter text for the constant-volume stage: 5000 md steps with dt=0.002,
# V-rescale temperature coupling of the whole System to ref-t=350, no pressure
# coupling (pcoupl=no). Compressed coordinates are written every 20 steps.
nvt_mdp = """; NVT
integrator=md
dt=0.002
nsteps=5000
tcoupl=V-rescale
tc-grps=System
tau-t=0.1
ref-t=350
pcoupl=no
constraints=h-bonds
cutoff-scheme=Verlet
coulombtype=PME
rcoulomb=1.0
rvdw=1.0
pbc=xyz
nstxout-compressed=20
nstenergy=1000
nstlog=1000
"""

# Persist the parameters as nvt.mdp (overwrites an existing file).
with open("nvt.mdp", "w") as f:
    f.write(nvt_mdp)

In [None]:
# Build nvt.tpr starting from the minimized structure em.gro, then run it; outputs are named nvt.* via -deffnm.
# NOTE(review): -maxwarn 5 lets grompp continue past up to five warnings - check its output for what is being suppressed.
!gmx grompp -f nvt.mdp -c em.gro -r solv_ions.gro -p topol.top -o nvt.tpr -maxwarn 5
!gmx mdrun -deffnm nvt

# NPT Equilibration

In [None]:
# Parameter text for the constant-pressure stage: same integrator, step count and
# thermostat settings as the NVT file, plus isotropic Parrinello-Rahman pressure
# coupling to ref-p=1.0 and DispCorr=EnerPres.
npt_mdp = """; NPT
integrator=md
dt=0.002
nsteps=5000
tcoupl=V-rescale
tc-grps=System
tau-t=0.1
ref-t=350
pcoupl=Parrinello-Rahman
pcoupltype=isotropic
tau-p=5.0
ref-p=1.0
compressibility=4.5e-5
constraints=h-bonds
cutoff-scheme=Verlet
coulombtype=PME
rcoulomb=1.0
rvdw=1.0
pbc=xyz
DispCorr=EnerPres
nstxout-compressed=20
nstenergy=1000
nstlog=1000
"""

# Persist the parameters as npt.mdp (overwrites an existing file).
with open("npt.mdp", "w") as f:
    f.write(npt_mdp)

In [None]:
# Build npt.tpr starting from the NVT end structure nvt.gro, then run it; outputs are named npt.* via -deffnm.
# NOTE(review): -maxwarn 5 lets grompp continue past up to five warnings - check its output for what is being suppressed.
!gmx grompp -f npt.mdp -c nvt.gro -r solv_ions.gro -p topol.top -o npt.tpr -maxwarn 5
!gmx mdrun -deffnm npt

# Production Run

In [None]:
# Parameter text for the production run: 500000 md steps with dt=0.002 under the
# same thermostat/barostat settings as the NPT file. Compressed coordinates are
# written every 500 steps, i.e. 1001 frames including step 0; the analysis cells
# convert frame numbers to ns by multiplying with 0.001.
prod_mdp = """; Production MD parameters
integrator=md
dt=0.002
nsteps=500000
tcoupl=V-rescale
tc-grps=System
tau-t=0.1
ref-t=350
pcoupl=Parrinello-Rahman
pcoupltype=isotropic
tau-p=5.0
ref-p=1.0
compressibility=4.5e-5
constraints=h-bonds
lincs-order=4
lincs-iter=2
cutoff-scheme=Verlet
coulombtype=PME
rcoulomb=1.0
rvdw=1.0
pbc=xyz
DispCorr=EnerPres
nstxout-compressed=500
nstenergy=5000
nstlog=5000
"""

# Persist the parameters as prod.mdp (overwrites an existing file).
with open("prod.mdp", "w") as f:
    f.write(prod_mdp)

In [None]:
# Atom serial numbers (1-based) of every Na and Cl ion in the equilibrated structure.
na_indices = []
cl_indices = []

with open("npt.gro", "r") as f:
    lines = f.readlines()

# The serial is taken from the record's position in the file (first atom
# record = 1). Columns [0:5] of a .gro record hold the *residue* number, which
# differs from the atom serial as soon as multi-atom residues such as water
# precede the ions, and the atom-number column itself wraps after 99999.
for serial, line in enumerate(lines[2:-1], start = 1):  # skip title, atom count and box line
    resname = line[5:10].strip()
    if resname == "NA":
        na_indices.append(serial)
    elif resname == "CL":
        cl_indices.append(serial)

if not na_indices or not cl_indices:
    raise ValueError(
        f"npt.gro must contain both NA and CL residues (found {len(na_indices)} NA, {len(cl_indices)} CL)"
    )

na_start = min(na_indices)
na_end   = max(na_indices)
cl_start = min(cl_indices)
cl_end   = max(cl_indices)

# The PLUMED input is generated with range(start, end + 1) per species, so
# each species has to occupy one unbroken block of atom serials.
if len(na_indices) != na_end - na_start + 1:
    raise ValueError("NA atoms are not contiguous in npt.gro; range-based PLUMED labels would hit other atoms")
if len(cl_indices) != cl_end - cl_start + 1:
    raise ValueError("CL atoms are not contiguous in npt.gro; range-based PLUMED labels would hit other atoms")

# CONSTANT actions written verbatim at the top of plumed.dat; the labels are
# referenced by the MATHEVAL lines generated further down.
parameters = [
    "r_IN:   CONSTANT VALUE=1.0",
    "r_F:    CONSTANT VALUE=2.0",
    "alpha:  CONSTANT VALUE=0.1",
    "sigma:  CONSTANT VALUE=0.05",
    "K_NA:   CONSTANT VALUE=1.0",
    "K_CL:   CONSTANT VALUE=1.0",
    "c0_NA:  CONSTANT VALUE=0.02",
    "c0_CL:  CONSTANT VALUE=0.02",
    "Vshell: CONSTANT VALUE=10.0"
]

# Maximum number of per-ion terms summed in a single MATHEVAL line.
chunk_size = 100

def generate_selection_chunks(atom_prefix, start, end, chunk_size):
    """Build the PLUMED lines that weight and sum ions of one species.

    For every atom number in ``start..end`` (inclusive) one MATHEVAL line
    labelled ``f_<prefix>_<i>`` is produced; it takes ``d_<prefix>_<i>`` plus
    the constants r_IN, r_F and alpha as arguments. The ``f_`` labels are then
    summed in groups of at most ``chunk_size`` by lines labelled
    ``c_<prefix>_chunk<k>`` (k counted from 1).

    Returns
    -------
    (lines, chunk_labels)
        ``lines`` holds all per-atom lines followed by all chunk lines, without
        trailing newlines; ``chunk_labels`` lists the chunk labels in order.
    """
    # One weighting line per atom.
    plumed_lines = [
        f"f_{atom_prefix}_{atom_id}: MATHEVAL ARG=d_{atom_prefix}_{atom_id},r_IN,r_F,alpha "
        f"VAR=w,x,y,z FUNC=1/(1+exp(-(w-x)/z))*1/(1+exp((w-y)/z)) PERIODIC=NO"
        for atom_id in range(start, end + 1)
    ]

    # One summation line per group of consecutive atoms.
    chunk_labels = []
    for chunk_no, first in enumerate(range(start, end + 1, chunk_size), start = 1):
        last_exclusive = min(first + chunk_size, end + 1)
        members = [f"f_{atom_prefix}_{atom_id}" for atom_id in range(first, last_exclusive)]
        placeholders = [f"x{pos}" for pos in range(len(members))]

        label = f"c_{atom_prefix}_chunk{chunk_no}"
        chunk_labels.append(label)

        arg_part = ",".join(members)
        var_part = ",".join(placeholders)
        func_part = "+".join(placeholders)
        plumed_lines.append(
            f"{label}: MATHEVAL ARG={arg_part} VAR={var_part} FUNC={func_part} PERIODIC=NO"
        )

    return plumed_lines, chunk_labels

# Write the complete PLUMED input. Sections appear in dependency order:
# constants, reference point, distances, weights and sums, per-ion
# potentials, and finally the actions that consume the potentials.
with open("plumed.dat", "w") as f:

    f.write("# --- PARAMETERS ---\n")
    for line in parameters:
        f.write(line + "\n")
    f.write("\n")

    # Atom 1 is the reference point every ion distance is measured from.
    f.write("origin: COM ATOMS=1\n\n")

    f.write("# --- Na distances ---\n")
    for i in range(na_start, na_end+1):
        f.write(f"d_na_{i}: DISTANCE ATOMS=origin,{i} NOPBC\n")
    f.write("\n# --- Cl distances ---\n")
    for i in range(cl_start, cl_end+1):
        f.write(f"d_cl_{i}: DISTANCE ATOMS=origin,{i} NOPBC\n")
    f.write("\n")

    na_lines, na_chunks = generate_selection_chunks('na', na_start, na_end, chunk_size)
    cl_lines, cl_chunks = generate_selection_chunks('cl', cl_start, cl_end, chunk_size)
    for line in na_lines + cl_lines:
        f.write(line + "\n")

    # Add the chunk sums into one total per species and divide it by Vshell.
    na_vars = [f"x{i}" for i in range(len(na_chunks))]
    cl_vars = [f"x{i}" for i in range(len(cl_chunks))]
    f.write(f"c_na_total: MATHEVAL ARG={','.join(na_chunks)} VAR={','.join(na_vars)} FUNC={'+'.join(na_vars)} PERIODIC=NO\n")
    f.write("rho_na_total: MATHEVAL ARG=c_na_total,Vshell VAR=x,y FUNC=x/y PERIODIC=NO\n")
    f.write(f"c_cl_total: MATHEVAL ARG={','.join(cl_chunks)} VAR={','.join(cl_vars)} FUNC={'+'.join(cl_vars)} PERIODIC=NO\n")
    f.write("rho_cl_total: MATHEVAL ARG=c_cl_total,Vshell VAR=x,y FUNC=x/y PERIODIC=NO\n\n")

    f.write("# --- Per-ion potentials ---\n")
    # The variable f (the per-ion weight) is passed in but does not appear in FUNC.
    for i in range(na_start, na_end+1):
        f.write(
            f"pot_na_{i}: MATHEVAL ARG=d_na_{i},K_NA,f_na_{i},rho_na_total,c0_NA,r_F,sigma "
            f"VAR=r,k,f,c,c0,w,z FUNC=-0.5*k*(c-c0)*tanh((r-w)/(2*z)) PERIODIC=NO\n"
        )
    for i in range(cl_start, cl_end+1):
        f.write(
            f"pot_cl_{i}: MATHEVAL ARG=d_cl_{i},K_CL,f_cl_{i},rho_cl_total,c0_CL,r_F,sigma "
            f"VAR=r,k,f,c,c0,w,z FUNC=-0.5*k*(c-c0)*tanh((r-w)/(2*z)) PERIODIC=NO\n"
        )

    f.write("\n# --- Apply radial forces ---\n")
    # NOTE(review): MOVINGRESTRAINT is emitted without any STEP/AT/KAPPA
    # keywords. Check the PLUMED manual; if the pot_* values are meant to act
    # directly as bias energies, BIASVALUE ARG=pot_* may be the intended action.
    for i in range(na_start, na_end+1):
        f.write(f"radial_force_na_{i}: MOVINGRESTRAINT ARG=d_na_{i},pot_na_{i}\n")
    for i in range(cl_start, cl_end+1):
        f.write(f"radial_force_cl_{i}: MOVINGRESTRAINT ARG=d_cl_{i},pot_cl_{i}\n")

# Report the file that was actually written above.
print("PLUMED input with per-ion potentials based on total concentration generated in 'plumed.dat'")

In [None]:
# Build prod.tpr starting from the NPT end structure npt.gro, then run it with plumed.dat passed via -plumed; outputs are named prod.*.
# NOTE(review): -maxwarn 5 lets grompp continue past up to five warnings - check its output for what is being suppressed.
!gmx grompp -f prod.mdp -c npt.gro -r solv_ions.gro -p topol.top -o prod.tpr -maxwarn 5
!gmx mdrun -deffnm prod -plumed plumed.dat

# Analysis

In [None]:
from MDAnalysis.analysis.rdf import InterRDF

# Frame window and histogram settings for the Na-Cl radial distribution function.
start_frame, stop_frame = 900, 1000
r_range = (0.0, 10.0)
nbins = 500

stage = "prod"

# Production run in the current working directory.
u1 = mda.Universe("prod.tpr", "prod.xtc")

na1, cl1 = (u1.select_atoms(selection) for selection in ("name NA", "name CL"))


print(len(u1.trajectory))

# run() hands back the analysis object, so construction and execution can be chained.
rdf_calc1 = InterRDF(na1, cl1, range = r_range, nbins = nbins).run(start = start_frame, stop = stop_frame)

rdf_results = rdf_calc1.results
r1, g_r1 = rdf_results.bins, rdf_results.rdf

plt.plot(np.array(r1), g_r1)
plt.xlabel("r (Å)")
plt.ylabel("g(r)")
plt.ylim(0, 3)
plt.grid()
plt.show()

In [None]:
from MDAnalysis.analysis.rdf import InterRDF

# Frame window and histogram settings shared by all three runs.
start_frame, stop_frame = 900, 1000
r_range = (0.0, 10.0)
nbins = 500

stage = "prod"

# (legend label, directory holding <stage>.tpr / <stage>.xtc) for each run compared.
runs = [
    ("No Seeding",            "/root/No Seeding, 8M/"),
    ("5x5x5 cell (0.564 nm)", "/root/5x5x5 cell (0.564 nm), 7M/"),
    ("5x5x5 cell (0.188 nm)", "/root/5x5x5 cell (0.188 nm), 7M/"),
]

u1, u2, u3 = [
    mda.Universe(folder + stage + ".tpr", folder + stage + ".xtc")
    for label, folder in runs
]

na1, na2, na3 = [u.select_atoms("name NA") for u in (u1, u2, u3)]
cl1, cl2, cl3 = [u.select_atoms("name CL") for u in (u1, u2, u3)]

print(len(u1.trajectory))

# One Na-Cl RDF per run, evaluated over the same frame window.
rdf_calc1, rdf_calc2, rdf_calc3 = [
    InterRDF(na, cl, range = r_range, nbins = nbins).run(start = start_frame, stop = stop_frame)
    for na, cl in ((na1, cl1), (na2, cl2), (na3, cl3))
]

r1, r2, r3 = [calc.results.bins for calc in (rdf_calc1, rdf_calc2, rdf_calc3)]
g_r1, g_r2, g_r3 = [calc.results.rdf for calc in (rdf_calc1, rdf_calc2, rdf_calc3)]

for (label, folder), r, g_r in zip(runs, (r1, r2, r3), (g_r1, g_r2, g_r3)):
    plt.plot(np.array(r), g_r, label = label)

plt.xlabel("r (Å)")
plt.ylabel("g(r)")
# Frame numbers are converted to ns with a factor of 0.001 per frame.
plt.title("Time Averaging: " + str(start_frame*0.001) + " to " + str(stop_frame*0.001) + " ns")
plt.ylim(0, 3)
plt.grid()
plt.legend()
plt.show()

In [None]:
# Frame window over which the radial Cl density is averaged.
start_frame, stop_frame = 100, 200

stage = "prod"

# (legend label, directory holding <stage>.tpr / <stage>.xtc) for each run compared.
runs = [
    ("No Seeding",            "/root/No Seeding, 8M/"),
    ("5x5x5 cell (0.564 nm)", "/root/5x5x5 cell (0.564 nm), 7M/"),
    ("5x5x5 cell (0.188 nm)", "/root/5x5x5 cell (0.188 nm), 7M/"),
]

u1, u2, u3 = [
    mda.Universe(folder + stage + ".tpr", folder + stage + ".xtc")
    for label, folder in runs
]

# The centre of mass of the Na selection is the reference point; Cl ions are histogrammed around it.
center_sel1, center_sel2, center_sel3 = [u.select_atoms("name NA") for u in (u1, u2, u3)]
ions1, ions2, ions3 = [u.select_atoms("resname CL") for u in (u1, u2, u3)]

bins = np.linspace(0, 30, 60)
totals = [np.zeros(len(bins) - 1) for run in runs]
frame_counts = [0 for run in runs]

systems = (
    (u1, center_sel1, ions1),
    (u2, center_sel2, ions2),
    (u3, center_sel3, ions3),
)
for which, (universe, center_sel, ions) in enumerate(systems):
    for ts in universe.trajectory[start_frame:stop_frame]:
        center = center_sel.center_of_mass()
        # Plain Cartesian distance of every Cl ion from the reference point.
        d = np.linalg.norm(ions.positions - center, axis = 1)
        counts = np.histogram(d, bins = bins)[0]
        totals[which] += counts
        frame_counts[which] += 1

counts_total1, counts_total2, counts_total3 = totals
n_frames1, n_frames2, n_frames3 = frame_counts

# Volume of each spherical shell between consecutive bin edges.
volumes = (4/3)*np.pi*(bins[1:]**3 - bins[:-1]**3)

counts_avg1 = counts_total1/n_frames1
counts_avg2 = counts_total2/n_frames2
counts_avg3 = counts_total3/n_frames3
density1, density2, density3 = [
    counts_avg/volumes for counts_avg in (counts_avg1, counts_avg2, counts_avg3)
]

# Curves are drawn against the left edge of each bin.
for (label, folder), density in zip(runs, (density1, density2, density3)):
    plt.plot(bins[:-1], density, label = label)

plt.xlabel("r (Å)")
plt.ylabel("Ion Density (1/Å³)")
plt.title("Time Averaging: " + str(start_frame*0.001) + " to " + str(stop_frame*0.001) + " ns")
plt.grid()
plt.legend()
plt.show()