# RISM - A Gentle Introduction

## Preamble

The code we've written so far:

In [10]:
import numpy as np
from scipy.fftpack import dstn, idstn
from scipy.special import erf
import matplotlib.pyplot as plt

# Constants
ec = 1.602176565e-19 # A*s
eps_0 = 8.854187817620e-12 # A^2*s^4/kg/m^3
Na = 6.02214129e23 # 1 / mol
kc = 1.0 / 4.0 / np.pi/ eps_0 # J*m/A^2/s^2
kc_a = kc * 1e10 # J*angstrom/A^2/s^2
kc_akj = kc_a * 1e-3 # kJ*angstrom/A^2/s^2
kc_akjmol = kc_akj * Na # kJ*angstrom/A^2/s^2/mol
lb = ec * ec * kc_akjmol # kJ*angstrom/mol

r = 15.0
N = 100
dr = r / N
dk = 2.0 * np.pi / (2.0 * N * dr)

r_grid = np.arange(0.5, N, 1.0) * dr
k_grid = np.arange(0.5, N, 1.0) * dk

def dist_matrix(coords):
    ns = len(coords)
    dist = np.zeros((ns, ns))
    for i, j in np.ndindex((ns, ns)):
        dist[i, j] = np.linalg.norm(coords[i] - coords[j])
    return dist

def wk(coords, multiplicity, k):
    dist = dist_matrix(coords)
    out = np.zeros((N, ns, ns))
    for i, j in np.ndindex((ns, ns)):
        dists = []
        mult_j = int(multiplicity[j, j])
        dists.append(dist[i, j])
        if mult_j > 1:
            for mj in range(1, mult_j):
                dists.append(dist[i, j + mj])
        for dist_ij in dists:
            if dist_ij < 0.0:
                out[:, i, j] += np.zeros(N)
            elif dist_ij == 0.0:
                out[:, i, j] += np.ones(N)
            else:
                out[:, i, j] += np.sin(k * dist_ij) / (k * dist_ij)

        out[:, i, j] /= mult_j

    return out


def LJ(epsilon, sigma, r):
    return 4.0 * epsilon * ( np.power( (sigma / r), 12) - np.power( (sigma / r), 6) )

def Coulomb(q, r):
    return lb * q / r

def Ng_real(q, r):
    return lb * q * erf(r) / r

def Ng_fourier(q, k):
    return 4.0 * np.pi * lb * q * np.exp(-np.power(k, 2.0) / 4.0) / np.power(k, 2.0)

def lorentz_berthelot(eps1, eps2, sig1, sig2):
    return np.sqrt(eps1 * eps2), 0.5 * (sig1 + sig2)

def energy(params, r):
    out = np.zeros((N, ns, ns))
    for i, j in np.ndindex((ns, ns)):
        eps, sig = lorentz_berthelot(
            params[i][0], params[j][0], params[i][1], params[j][1]
        )
        q = params[i][2] * params[j][2]
        out[:, i, j] = LJ(eps, sig, r) + Coulomb(q, r)

    return out

def renorm(params, r, k):
    out_r = np.zeros((N, ns, ns))
    out_k = np.zeros((N, ns, ns))
    for i, j in np.ndindex((ns, ns)):
        q = params[i][2] * params[j][2]
        out_r[:, i, j] = Ng_real(q, r)
        out_k[:, i, j] = Ng_fourier(q, k)

    return out_r, out_k
    
def uv_energy(solute_params, solvent_params, r):
    nsu = len(solute_params)
    nsv = len(solvent_params)
    out = np.zeros((N, nsu, nsv))
    for i, j in np.ndindex((nsu, nsv)):
        eps, sig = lorentz_berthelot(
            solute_params[i][0], solvent_params[j][0], solute_params[i][1], solvent_params[j][1]
        )
        q = solute_params[i][2] * solvent_params[j][2]
        out[:, i, j] = LJ(eps, sig, r) + Coulomb(q, r)

    return out

def uv_renorm(solute_params, solvent_params, r, k):
    nsu = len(solute_params)
    nsv = len(solvent_params)
    out_r = np.zeros((N, nsu, nsv))
    out_k = np.zeros((N, nsu, nsv))
    for i, j in np.ndindex((nsu, nsv)):
        q = solute_params[i][2] * solvent_params[j][2]
        out_r[:, i, j] = Ng_real(q, r)
        out_k[:, i, j] = Ng_fourier(q, k)

    return out_r, out_k

def HNC(beta, ur, tr):
    return np.exp(-beta * ur + tr) - 1.0 - tr

def PY(beta, ur, tr):
    return np.exp(-beta * ur) * (1.0 + tr) - 1.0 - tr

def fbt(fr, r, k, dr):
    return 2.0 * np.pi * dr * dstn(fr * r[:, np.newaxis, np.newaxis], type=4, axes=[0]) / k[:, np.newaxis, np.newaxis]

def ifbt(fk, r, k, dk):
    return dk / 4.0 / np.pi / np.pi * idstn(fk * k[:, np.newaxis, np.newaxis], type=4, axes=[0]) / r[:, np.newaxis, np.newaxis]

def RISM(cr, vk_lr, w, n, p, r, k, dr, dk):
    N = r.shape[0]
    
    I = np.eye(w.shape[1])
    tk = np.zeros_like(cr)
    # Transform c(r) to c(k)
    ck = fbt(cr, r, k, dr)

    ck -= vk_lr

    # Solve the RISM equation for t(k)
    for l in np.arange(0, N):
        A = np.linalg.inv(I - p @ w[l] @ (n @ ck[l] @ n))
        B = w[l] @ (n @ ck[l] @ n) @ w[l]
        tk[l] = A @ B - ck[l]
    
    tk -= vk_lr
    
    # Transform t(k) to t(r)
    tr = ifbt(tk, r, k, dk)
    return tr

def picard_iteration(tolerance, max_step, alpha, initial_tr, beta, ur_sr, uk_lr,  wk, rho, mult):
    # Set up our loop variable tr
    tr = initial_tr
    
    # Define our counter
    step = 0
    
    while True:
        # Store previous t(r)
        tr_prev = tr

        # t(r) -> c(r) from closure
        cr = HNC(beta, ur_sr, tr)

        # c(r) -> F(t(r)) from RISM equation
        f_tr = RISM(cr, beta * uk_lr, wk, mult, rho, r_grid, k_grid, dr, dk)

        # Use Picard iteration to get the new solution
        # Since we set alpha = 1.0 above, this is 
        # essentially direct iteration (tr_new = f_tr).
        tr_new = alpha * f_tr + (1.0 - alpha) * tr_prev

        # Check the absolute value of the difference between the current and previous solution
        # We use .max() because we want to check that all values are below the tolerance, and
        # if the maximum value of the difference array is below, then so must the other values be.
        diff = np.sqrt(np.power((tr_new - tr_prev), 2).sum() * dr)

        # Set t(r) to the new t(r) we just computed
        tr = tr_new

        # Print the iteration step every 100 steps and difference
        if step % 100 == 0:
            print("Iteration: {step} Diff: {diff:.2e}".format(step=step, diff=diff))

        # Test if the difference is below our tolerance
        if diff < tolerance:
            print("Final Iteration: {step} Diff: {diff:.2e}".format(step=step, diff=diff))
            break
        
        # If we reach max_step, raise an error!
        if step == max_step:
            print("Reached max steps!")
            break

        if np.any(np.isnan(tr)):
            print("Diverged at iteration: {step}".format(step=step))
            break

        # Update our step counter. We have a max_step because we don't want this calculation to go 
        # on forever and freeze our PC.
        step += 1
    
    # Return solved t(r)
    return tr

def uv_RISM(cr, vk_lr, wu, wv, hvv, n, p, r, k, dr, dk):
    N = r.shape[0]
    
    tk = np.zeros_like(cr)
    # Transform c(r) to c(k)
    ck = fbt(cr, r, k, dr)

    ck -= vk_lr

    # Solve the RISM equation for t(k)
    for l in np.arange(0, N):
        tk[l] = wu[l] @ (n @ ck[l] @ n) @ (wv[l] + p @ hvv[l]) - ck[l]
    
    tk -= vk_lr
    
    # Transform t(k) to t(r)
    tr = ifbt(tk, r, k, dk)
    return tr

def uv_picard_iteration(tolerance, max_step, alpha, initial_tr, beta, ur_sr, uk_lr,  wk, rho, mult, wv, hvv):
    # Set up our loop variable tr
    tr = initial_tr
    
    # Define our counter
    step = 0
    
    while True:
        # Store previous t(r)
        tr_prev = tr

        # t(r) -> c(r) from closure
        cr = HNC(beta, ur_sr, tr)

        # c(r) -> F(t(r)) from RISM equation
        f_tr = uv_RISM(cr, beta * uk_lr, wk, wv, hvv, mult, rho, r_grid, k_grid, dr, dk)

        # Use Picard iteration to get the new solution
        # Since we set alpha = 1.0 above, this is 
        # essentially direct iteration (tr_new = f_tr).
        tr_new = alpha * f_tr + (1.0 - alpha) * tr_prev

        # Check the absolute value of the difference between the current and previous solution
        # We use .max() because we want to check that all values are below the tolerance, and
        # if the maximum value of the difference array is below, then so must the other values be.
        diff = np.sqrt(np.power((tr_new - tr_prev), 2).sum() * dr)

        # Set t(r) to the new t(r) we just computed
        tr = tr_new

        # Print the iteration step every 100 steps and difference
        if step % 100 == 0:
            print("Iteration: {step} Diff: {diff:.2e}".format(step=step, diff=diff))

        # Test if the difference is below our tolerance
        if diff < tolerance:
            print("Final Iteration: {step} Diff: {diff:.2e}".format(step=step, diff=diff))
            break
        
        # If we reach max_step, raise an error!
        if step == max_step:
            print("Reached max steps!")
            break

        if np.any(np.isnan(tr)):
            print("Diverged at iteration: {step}".format(step=step))
            break

        # Update our step counter. We have a max_step because we don't want this calculation to go 
        # on forever and freeze our PC.
        step += 1
    
    # Return solved t(r)
    return tr

The problems we had defined before:

In [11]:
T = 298.15 # Kelvin
kB = 8.314462618e-3 # kJ / mol / K
beta = 1.0 / T / kB # Thermodynamic Beta

# Solvent Details

ns = 2

ow_eps = 78.15 * kB
ow_sigma = 3.16572
ow_charge = -0.8476

hw_eps = 7.815 * kB
hw_sigma = 1.16572
hw_charge = 0.4238

dens = 0.0334
rho = np.diag([dens, dens])

water_mult = np.diag([1.0, 2.0])
water_params = [ [ow_eps, ow_sigma, ow_charge], [hw_eps, hw_sigma, hw_charge] ]
water_coords = [ np.array([0.0, 0.0, 0.0]), np.array([1.0, 0.0, 0.0]), np.array([-0.333314, 0.942816, 0.0])]

water_ur = energy(water_params, r_grid)
water_ur_lr, water_uk_lr = renorm(water_params, r_grid, k_grid)
water_ur_sr = water_ur - water_ur_lr

water_wk = wk(water_coords, water_mult, k_grid)

# Solute Details

ns = 2

c_eps = 55.05221691240736 * kB
c_sigma = 1.6998347542253702
c_charge = -0.1088

h_eps = 7.900546687705287 * kB
h_sigma = 1.324766393630111
h_charge = 0.026699999999999998


methane_mult = np.diag([1.0, 4.0])
methane_params = [ [c_eps, c_sigma, c_charge], [h_eps, h_sigma, h_charge] ]
methane_coords = [ 
    np.array([3.537, 1.423, 0.0]), 
    np.array([4.089, 2.224, 0.496]), 
    np.array([4.222, 0.611, -0.254]),
    np.array([2.759, 1.049, 0.669]),
    np.array([3.077, 1.81, -0.912]),
]

methane_ur = uv_energy(methane_params, water_params, r_grid)
methane_ur_lr, methane_uk_lr = uv_renorm(methane_params, water_params, r_grid, k_grid)
methane_ur_sr = methane_ur - methane_ur_lr

methane_wk = wk(methane_coords, methane_mult, k_grid)

# Solver Parameters

tolerance = 1e-7
max_step = 10000
alpha = 0.7
initial_tr = np.zeros((N, ns, ns))
lambdas = 1

# Solvent-Solvent (vv) problem

for ilam in range(1, lambdas+1):
    lam = ilam / lambdas
    print(lam)
    tr = picard_iteration(tolerance, max_step, alpha, initial_tr, beta, lam * water_ur_sr, lam * water_uk_lr, water_wk, rho, water_mult)
    initial_tr = tr

water_cr = HNC(beta, water_ur_sr, tr) - beta * water_ur_lr
water_tr = tr + beta * water_ur_lr
water_hr = water_tr + water_cr

# Solute-Solvent (uv) problem

alpha = 0.3 # Changing to a stiffer damped iteration because not as well-posed as vv problem.
initial_tr = np.zeros((N, ns, ns))

for ilam in range(1, lambdas+1):
    lam = ilam / lambdas
    print(lam)
    tr = uv_picard_iteration(tolerance, max_step, alpha, initial_tr, beta, lam * methane_ur_sr, lam * methane_uk_lr, methane_wk, rho, water_mult, water_wk, water_hr)
    initial_tr = tr


methane_cr = HNC(beta, methane_ur_sr, tr) - beta * methane_ur_lr
methane_tr = tr + beta * methane_ur_lr
methane_hr = methane_tr + methane_cr

1.0
Iteration: 0 Diff: 7.95e+00
Iteration: 100 Diff: 3.57e-01
Iteration: 200 Diff: 1.31e-01
Iteration: 300 Diff: 3.93e-02
Iteration: 400 Diff: 1.08e-02
Iteration: 500 Diff: 2.89e-03
Iteration: 600 Diff: 7.69e-04
Iteration: 700 Diff: 2.04e-04
Iteration: 800 Diff: 5.42e-05
Iteration: 900 Diff: 1.44e-05
Iteration: 1000 Diff: 3.82e-06
Iteration: 1100 Diff: 1.01e-06
Iteration: 1200 Diff: 2.69e-07
Final Iteration: 1275 Diff: 9.95e-08
1.0
Iteration: 0 Diff: 4.68e+00
Iteration: 100 Diff: 4.57e-02
Iteration: 200 Diff: 2.65e-02
Iteration: 300 Diff: 1.56e-02
Iteration: 400 Diff: 9.32e-03
Iteration: 500 Diff: 5.58e-03
Iteration: 600 Diff: 3.35e-03
Iteration: 700 Diff: 2.01e-03
Iteration: 800 Diff: 1.21e-03
Iteration: 900 Diff: 7.27e-04
Iteration: 1000 Diff: 4.38e-04
Iteration: 1100 Diff: 2.63e-04
Iteration: 1200 Diff: 1.59e-04
Iteration: 1300 Diff: 9.55e-05
Iteration: 1400 Diff: 5.75e-05
Iteration: 1500 Diff: 3.46e-05
Iteration: 1600 Diff: 2.08e-05
Iteration: 1700 Diff: 1.25e-05
Iteration: 1800 Di

## Thermodynamics

### Solvation Free Energies (SFEs)
