In [1]:
import torch
import math

## Calculate the position of a lone pair relative to three host atoms

This kind of lone pair definition corresponds to "3 hosts - relative (val1>0)" in `charmm/source/dynamc/lonepair.F90`. Let $r_i$ be the position of the lone pair to be determined, and let $r_j$, $r_k$ and $r_l$ be the positions of the three host atoms. Then $r_i$ can be written as $r_j + v_1 + v_2 + v_3$, with $v_1$, $v_2$ and $v_3$ defined as follows:

<center><img src="relative.png" width="20%" height="20%"></center>

The three force-field parameters $d$, $\alpha$ and $\theta$ are, respectively, the distance between the lone pair $i$ and atom $j$, the angle $i$-$j$-$k$, and the dihedral $i$-$j$-$k$-$l$. Note that dihedral angles are usually taken as positive when they are defined counterclockwise; since the order $i$-$j$-$k$-$l$ in the figure above is clockwise, $-\theta$ is used. Vector $v_1$ can be calculated as
$$
v_1 = d \cos{\alpha} \frac{r_{jk}}{|r_{jk}|}
$$
where $r_{jk}$ is the vector from $r_j$ to $r_k$, i.e. $r_k - r_j$. To calculate $v_2$ and $v_3$, the unit normal vector of the plane $j$-$k$-$l$ must first be determined:
$$
\hat{n} = \frac{r_{jk}\times r_{kl}}{|r_{jk}\times r_{kl}|}
$$
and then $v_3$ is
$$
v_3 = \hat{n}d\sin\alpha\sin(-\theta)=-\hat{n}d\sin\alpha\sin\theta.
$$
To compute $v_2$, we need a unit vector $\hat{p}$ that is perpendicular to $r_{jk}$ and lies in the plane $j$-$k$-$l$, which is
$$
\hat{p} = \frac{\hat{n}\times r_{jk}}{|\hat{n}\times r_{jk}|}
$$
and then $v_2$ is
$$
v_2 = \hat{p}d\sin\alpha\cos(-\theta)=\hat{p}d\sin\alpha\cos\theta
$$
The position of the lone pair, $r_i$, is then obtained by combining the equations above:
$$
\begin{split}
r_i &= r_j + v_1 + v_2 +v_3\\
&=r_j + d\cos\alpha \frac{r_{jk}}{|r_{jk}|}+
d\sin\alpha\cos\theta\frac{r_{jk}\times r_{kl}\times r_{jk}}{|r_{jk}\times r_{kl}\times r_{jk}|}-d\sin\alpha\sin\theta\frac{r_{jk}\times r_{kl}}{|r_{jk}\times r_{kl}|}
\end{split}
$$
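As a quick numerical check of the closed-form expression, the NumPy sketch below builds $r_i$ from $v_1$, $v_2$ and $v_3$ and then measures the resulting distance $|r_{ji}|$, angle $i$-$j$-$k$ and dihedral $i$-$j$-$k$-$l$. NumPy is used instead of PyTorch only to keep the example light, and `lone_pair_position`, `bond_angle` and `dihedral` are hypothetical helper names. The dihedral helper uses the common `atan2` formula; under that convention the measured dihedral should come back as $\theta$ itself.

```python
import numpy as np


def lone_pair_position(rj, rk, rl, d, alpha_deg, theta_deg):
    """r_i = r_j + v1 + v2 + v3 for the 'relative' lone pair (angles in degrees)."""
    alpha, theta = np.radians(alpha_deg), np.radians(theta_deg)
    rjk = rk - rj
    rkl = rl - rk
    n = np.cross(rjk, rkl)
    n_hat = n / np.linalg.norm(n)            # unit normal of plane j-k-l
    u_hat = rjk / np.linalg.norm(rjk)        # unit vector along r_jk
    p = np.cross(n_hat, rjk)
    p_hat = p / np.linalg.norm(p)            # in plane j-k-l, perpendicular to r_jk
    v1 = d * np.cos(alpha) * u_hat
    v2 = d * np.sin(alpha) * np.cos(theta) * p_hat
    v3 = -d * np.sin(alpha) * np.sin(theta) * n_hat
    return rj + v1 + v2 + v3


def bond_angle(a, b, c):
    """Angle a-b-c in degrees."""
    ba, bc = a - b, c - b
    cos_abc = np.dot(ba, bc) / (np.linalg.norm(ba) * np.linalg.norm(bc))
    return np.degrees(np.arccos(np.clip(cos_abc, -1.0, 1.0)))


def dihedral(a, b, c, d):
    """Dihedral a-b-c-d in degrees, using the atan2 formula."""
    b1, b2, b3 = b - a, c - b, d - c
    n1, n2 = np.cross(b1, b2), np.cross(b2, b3)
    y = np.dot(np.cross(n1, n2), b2 / np.linalg.norm(b2))
    x = np.dot(n1, n2)
    return np.degrees(np.arctan2(y, x))


rj = np.array([-8.861, -2.813, 4.867])
rk = np.array([-9.529, -1.236, 4.782])
rl = np.array([-9.858, -0.695, 3.543])
ri = lone_pair_position(rj, rk, rl, 1.0, 60.0, 50.0)
print(np.linalg.norm(ri - rj), bond_angle(ri, rj, rk), dihedral(ri, rj, rk, rl))
```

The three printed values should be (close to) $d = 1.0$, $\alpha = 60°$ and $\theta = 50°$.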

### OpenMM's implementation

OpenMM supports the relative lone pair by defining virtual sites (see http://docs.openmm.org/latest/userguide/theory/05_other_features.html#virtual-sites and https://github.com/openmm/openmm/blob/9e4b6ba5945c829f7450fa6ada1d57b3156b1f54/devtools/forcefield-scripts/processCharmmForceField.py#L320-L333). The weights of the virtual sites are pre-calculated from $d$, $\alpha$ and $\theta$.
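The pre-computation idea can be illustrated without OpenMM itself. In the hypothetical NumPy sketch below (not OpenMM's code, and not necessarily its exact weight convention), the trigonometric functions are evaluated once to give fixed local coordinates $(d\cos\alpha,\ d\sin\alpha\cos\theta,\ -d\sin\alpha\sin\theta)$ in the orthonormal frame $(r_{jk}/|r_{jk}|,\ \hat{p},\ \hat{n})$; each step then only needs cross products and normalizations. This has the same origin/axes/local-position structure as OpenMM's `LocalCoordinatesSite`, but the function names here are made up for illustration.

```python
import numpy as np


def make_local_coordinates(d, alpha_deg, theta_deg):
    """Pre-compute the lone-pair position in the local (u, p, n) frame, once."""
    alpha, theta = np.radians(alpha_deg), np.radians(theta_deg)
    return np.array([d * np.cos(alpha),
                     d * np.sin(alpha) * np.cos(theta),
                     -d * np.sin(alpha) * np.sin(theta)])


def place_from_local(rj, rk, rl, local):
    """Per-step placement: build the frame from the hosts, no trigonometry needed."""
    x = rk - rj
    x /= np.linalg.norm(x)                   # u_hat, along r_jk
    z = np.cross(rk - rj, rl - rk)
    z /= np.linalg.norm(z)                   # n_hat, normal of plane j-k-l
    y = np.cross(z, x)                       # p_hat = n_hat x u_hat, already unit length
    return rj + local[0] * x + local[1] * y + local[2] * z


rj = np.array([-8.861, -2.813, 4.867])
rk = np.array([-9.529, -1.236, 4.782])
rl = np.array([-9.858, -0.695, 3.543])
local = make_local_coordinates(1.0, 60.0, 50.0)   # e.g. once at setup
ri = place_from_local(rj, rk, rl, local)          # e.g. every step
print(local, ri)
```

Because $\hat{n}$ and $r_{jk}/|r_{jk}|$ are orthogonal unit vectors, their cross product $\hat{p}$ needs no further normalization, so this sketch uses only two normalizations per step.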

### NAMD's implementation

NAMD computes the lone pair positions directly from the equations above, which may be slower than OpenMM's approach because the sines and cosines of the angle and the dihedral are recomputed at every step. The code is in the `distance>=0` branch of `HomePatch::reposition_relative_lonepair` in HomePatch.C; the `distance<0` branch corresponds to the bisector kind of lone pair, which will be discussed later. The following code should give the same result as NAMD's implementation:

In [2]:
def reposition_lone_pair_relative(rj, rk, rl, distance, angle, dihedral):
    """Place a 'relative' lone pair: r_i = r_j + v1 + v2 + v3 (angle and dihedral in degrees)."""
    sin_theta = math.sin(math.radians(dihedral))
    cos_theta = math.cos(math.radians(dihedral))
    sin_alpha = math.sin(math.radians(angle))
    cos_alpha = math.cos(math.radians(angle))
    rjk = rk - rj
    rkl = rl - rk
    plane = torch.linalg.cross(rjk, rkl)          # normal of plane j-k-l (unnormalized n)
    in_plane = torch.linalg.cross(plane, rjk)     # (r_jk x r_kl) x r_jk, along p_hat
    dsina = distance * sin_alpha
    dcosa = distance * cos_alpha
    # r_j + v1 + v2 + v3, with v3 = -d sin(alpha) sin(theta) n_hat
    ri = rj + dcosa * (rjk / rjk.norm()) + \
         dsina * cos_theta * (in_plane / in_plane.norm()) - \
         dsina * sin_theta * (plane / plane.norm())
    return ri

### Example

#### Test PDB file content
```
REMARK original generated coordinate pdb file
ATOM      1  CL  UNL     1      -8.861  -2.813   4.867  1.00  0.00      HETACl
ATOM      2  O   UNL     1     -10.724   1.125   2.290  1.00  0.00      HETA O
ATOM      3  C1  UNL     1     -10.277   0.770   5.887  1.00  0.00      HETA C
ATOM      4  C2  UNL     1      -9.735  -0.512   5.954  1.00  0.00      HETA C
ATOM      5  C3  UNL     1     -10.612   1.324   4.654  1.00  0.00      HETA C
ATOM      6  C4  UNL     1      -9.858  -0.695   3.543  1.00  0.00      HETA C
ATOM      7  C5  UNL     1      -9.529  -1.236   4.782  1.00  0.00      HETA C
ATOM      8  C6  UNL     1     -10.397   0.583   3.497  1.00  0.00      HETA C
ATOM      9  H1  UNL     1     -10.006   1.731   2.029  1.00  0.00      HETA H
ATOM     10  H2  UNL     1     -10.369   1.301   6.818  1.00  0.00      HETA H
ATOM     11  H3  UNL     1      -9.473  -0.956   6.901  1.00  0.00      HETA H
ATOM     12  H4  UNL     1     -10.831   2.363   4.471  1.00  0.00      HETA H
ATOM     13  H5  UNL     1      -9.839  -0.863   2.479  1.00  0.00      HETA H
ATOM     14  LP1 UNL     1      -8.222  -4.322   4.949  1.00  0.00      HETADU
END
```

#### NUMLP section in the PSF file
```
         1         4 !NUMLP NUMLPH
         3        1 F          1.6400   180.0000     0.0000
        14         1         7         6
```
The first field of the first line is the number of lone pairs; the second field is used by neither NAMD nor OpenMM (https://github.com/openmm/openmm/blob/9e4b6ba5945c829f7450fa6ada1d57b3156b1f54/wrappers/python/openmm/app/charmmpsffile.py#L383-L385).
The first field of the second line is the number of lone pair hosts, the second field is the index of the lone pair item, and the third field must be "F" for both NAMD and OpenMM. The remaining three fields correspond to $d$, $\alpha$ and $\theta$ discussed above.
The third line contains the atom indices of the lone pair and of atoms $j$, $k$ and $l$.
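A minimal parser for this section might look like the sketch below. The helper `parse_numlp` and the `RelativeLonePair` record are hypothetical. Reading the second field of each parameter line as a 1-based pointer into one flat list of "lone pair followed by its hosts" indices, and treating all host lines as a single stream of integers, are assumptions inferred from the description above; only unweighted (`F`) entries are accepted.

```python
from dataclasses import dataclass


@dataclass
class RelativeLonePair:
    lp: int              # 1-based atom index of the lone pair (i)
    hosts: tuple         # 1-based atom indices of j, k and l
    distance: float      # d
    angle: float         # alpha, in degrees
    dihedral: float      # theta, in degrees


def parse_numlp(text):
    """Parse a NUMLP/NUMLPH section (hypothetical helper, unweighted 'F' entries only)."""
    lines = [line for line in text.strip().splitlines() if line.strip()]
    n_lp = int(lines[0].split()[0])              # NUMLP; NUMLPH is not needed here
    params = [line.split() for line in lines[1:1 + n_lp]]
    # Assumption: all remaining lines form one flat stream of atom indices
    flat = [int(tok) for line in lines[1 + n_lp:] for tok in line.split()]
    lone_pairs = []
    for n_host, pointer, flag, d, alpha, theta in params:
        if flag != "F":
            raise ValueError("only unweighted ('F') lone pairs are handled here")
        start = int(pointer) - 1                       # assumed 1-based pointer
        atoms = flat[start:start + int(n_host) + 1]    # lone pair followed by its hosts
        lone_pairs.append(RelativeLonePair(atoms[0], tuple(atoms[1:]),
                                           float(d), float(alpha), float(theta)))
    return lone_pairs


numlp_block = """
         1         4 !NUMLP NUMLPH
         3        1 F          1.6400   180.0000     0.0000
        14         1         7         6
"""
print(parse_numlp(numlp_block))
```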

In [3]:
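# Some test data: the force on the lone pair (fi) and the positions of
# atoms 1 (j), 7 (k) and 6 (l) from the PDB file above
fi = torch.tensor([1.0, 2.0, 4.0])
rj = torch.tensor([-8.861,  -2.813,   4.867], requires_grad=True)
rk = torch.tensor([-9.529,  -1.236,   4.782], requires_grad=True)
rl = torch.tensor([-9.858,  -0.695,   3.543], requires_grad=True)
# The PSF specifies d = 1.64; d = 1.0 is used here for testing
distance = 1.0
angle = 180.0
dihedral = 0.0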

In [4]:
print(reposition_lone_pair_relative(rj, rk, rl, distance, angle, dihedral))

tensor([-8.4714, -3.7327,  4.9166], grad_fn=<SubBackward0>)
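The test cell above uses $d = 1.0$ rather than the PSF value. As a sanity check of the construction (re-implemented in NumPy; the helper name is hypothetical), plugging in the PSF parameters $d = 1.64$, $\alpha = 180°$ and $\theta = 0°$ for CL (atom 1, $j$), C5 (atom 7, $k$) and C4 (atom 6, $l$) should reproduce the LP1 coordinates from the PDB file to roughly the 0.001 Å precision of the PDB format.

```python
import numpy as np


def lone_pair_position(rj, rk, rl, d, alpha_deg, theta_deg):
    """Same construction as reposition_lone_pair_relative, in NumPy."""
    alpha, theta = np.radians(alpha_deg), np.radians(theta_deg)
    rjk, rkl = rk - rj, rl - rk
    n = np.cross(rjk, rkl)
    p = np.cross(n, rjk)
    return (rj + d * np.cos(alpha) * rjk / np.linalg.norm(rjk)
            + d * np.sin(alpha) * np.cos(theta) * p / np.linalg.norm(p)
            - d * np.sin(alpha) * np.sin(theta) * n / np.linalg.norm(n))


cl = np.array([-8.861, -2.813, 4.867])    # atom 1, j
c5 = np.array([-9.529, -1.236, 4.782])    # atom 7, k
c4 = np.array([-9.858, -0.695, 3.543])    # atom 6, l
lp1_pdb = np.array([-8.222, -4.322, 4.949])

lp1 = lone_pair_position(cl, c5, c4, 1.64, 180.0, 0.0)
print(lp1, np.linalg.norm(lp1 - lp1_pdb))
```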


### Redistribute the force on a lone pair to the three host atoms (by gradients)

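The lone pair is represented as a massless virtual site, so it cannot be integrated like the other atoms; instead, the force acting on it has to be transferred to its host atoms. Since the lone pair position is a function of the host positions, the chain rule gives the forces on the host atoms that come from the lone pair: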
$$
\begin{cases}
F_j &= -\nabla_{r_j} V = -\nabla_{r_j} r_i \nabla_{r_i} V =  \nabla_{r_j} r_i F_i\\
F_k &= -\nabla_{r_k} V = -\nabla_{r_k} r_i \nabla_{r_i} V =  \nabla_{r_k} r_i F_i\\
F_l &= -\nabla_{r_l} V = -\nabla_{r_l} r_i \nabla_{r_i} V =  \nabla_{r_l} r_i F_i
\end{cases}
$$
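where $F_i = -\nabla_{r_i} V$ is the force acting on the lone pair, and $\nabla_{r_j} r_i$ is the Jacobian matrix arranged so that each row holds the derivatives with respect to one component of $r_j$: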
$$
\begin{bmatrix}
\dfrac{\partial r_{i}^x}{\partial r_{j}^x} & \dfrac{\partial r_{i}^y}{\partial r_{j}^x} & \dfrac{\partial r_{i}^z}{\partial r_{j}^x}\\
\dfrac{\partial r_{i}^x}{\partial r_{j}^y} & \dfrac{\partial r_{i}^y}{\partial r_{j}^y} & \dfrac{\partial r_{i}^z}{\partial r_{j}^y}\\
\dfrac{\partial r_{i}^x}{\partial r_{j}^z} & \dfrac{\partial r_{i}^y}{\partial r_{j}^z} & \dfrac{\partial r_{i}^z}{\partial r_{j}^z}
\end{bmatrix},
$$
where $r_{i}^x$ is the $x$ component of vector $r_i$. $\nabla_{r_k} r_i$ and $\nabla_{r_l} r_i$ are constructed in the same way.
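The chain rule can also be evaluated numerically, which gives a reference for the analytic derivation that follows. The NumPy sketch below (hypothetical helper names; central finite differences with step `h` stand in for the analytic Jacobians) builds each $\nabla_{r_a} r_i$ row by row in the layout of the matrix above and multiplies it by $F_i$. With the parameters used in the tests further down ($d = 1.0$, $\alpha = 60°$, $\theta = 50°$), the result is expected to match the gradient and autograd outputs to about three decimals, and the three host forces should add up to $F_i$.

```python
import numpy as np


def lone_pair_position(rj, rk, rl, d, alpha_deg, theta_deg):
    """Same construction as reposition_lone_pair_relative, in NumPy."""
    alpha, theta = np.radians(alpha_deg), np.radians(theta_deg)
    rjk, rkl = rk - rj, rl - rk
    n = np.cross(rjk, rkl)
    p = np.cross(n, rjk)
    return (rj + d * np.cos(alpha) * rjk / np.linalg.norm(rjk)
            + d * np.sin(alpha) * np.cos(theta) * p / np.linalg.norm(p)
            - d * np.sin(alpha) * np.sin(theta) * n / np.linalg.norm(n))


def host_forces_fd(fi, rj, rk, rl, params, h=1e-6):
    """F_a = (grad_{r_a} r_i) F_i for a = j, k, l, with finite-difference Jacobians."""
    hosts = [rj, rk, rl]
    forces = []
    for a in range(3):
        jac = np.zeros((3, 3))   # jac[m, n] = d r_i^n / d r_a^m, as in the matrix above
        for m in range(3):
            plus = [r.copy() for r in hosts]
            minus = [r.copy() for r in hosts]
            plus[a][m] += h
            minus[a][m] -= h
            jac[m] = (lone_pair_position(*plus, *params)
                      - lone_pair_position(*minus, *params)) / (2 * h)
        forces.append(jac @ fi)
    return forces


fi = np.array([1.0, 2.0, 4.0])
rj = np.array([-8.861, -2.813, 4.867])
rk = np.array([-9.529, -1.236, 4.782])
rl = np.array([-9.858, -0.695, 3.543])
fj, fk, fl = host_forces_fd(fi, rj, rk, rl, (1.0, 60.0, 50.0))
print(fj, fk, fl)
print(fj + fk + fl)   # should equal fi
```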

The actual calculation of the derivatives is cumbersome. Below is my draft of the derivation, without further explanation:

#### Derivative of $\frac{r_{jk}}{|r_{jk}|}$

$$
\begin{split}
\frac{\partial}{\partial r_j^x} \frac{r_{jk}^x}{|r_{jk}|}&=
\frac{-|r_{jk}|+(r_k^x-r_j^x)\frac{(r_k^x-r_j^x)}{|r_{jk}|}}{|r_{jk}|^2}\\
&=\frac{-|r_{jk}|^2+r_{jk}^x r_{jk}^x}{|r_{jk}|^3}
\end{split}
$$

$$
\begin{split}
\frac{\partial}{\partial r_j^x} \frac{r_{jk}^y}{|r_{jk}|}
&=\frac{r_{jk}^yr_{jk}^x}{|r_{jk}|^3}
\end{split}
$$

$$
\begin{split}
\frac{\partial}{\partial r_j^x} \frac{r_{jk}^z}{|r_{jk}|}
&=\frac{r_{jk}^zr_{jk}^x}{|r_{jk}|^3}
\end{split}
$$

$$
\nabla_{r_j^x} \frac{r_{jk}}{|r_{jk}|}=\left[\frac{-|r_{jk}|^2+r_{jk}^x r_{jk}^x}{|r_{jk}|^3},\ \frac{r_{jk}^yr_{jk}^x}{|r_{jk}|^3},\ \frac{r_{jk}^zr_{jk}^x}{|r_{jk}|^3}\right]
$$

The other two gradients, namely $\nabla_{r_j^y} \frac{r_{jk}}{|r_{jk}|}$ and $\nabla_{r_j^z} \frac{r_{jk}}{|r_{jk}|}$, can be derived in a similar manner.
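Collecting all nine entries, the unit-vector Jacobian can be written compactly as $\left(r_{jk}r_{jk}^\top/|r_{jk}|^2 - I\right)/|r_{jk}|$, which is what `dujk_drj` computes in the gradient code further down. A small NumPy check against central finite differences (the helper name is hypothetical):

```python
import numpy as np


def d_unit_d_rj(rjk):
    """M[m, n] = d(r_jk^n / |r_jk|) / d r_j^m (a symmetric matrix)."""
    norm = np.linalg.norm(rjk)
    return np.outer(rjk, rjk) / norm**3 - np.eye(3) / norm


rj = np.array([-8.861, -2.813, 4.867])
rk = np.array([-9.529, -1.236, 4.782])
analytic = d_unit_d_rj(rk - rj)

# Central finite differences with respect to each component of r_j
h = 1e-6
numeric = np.zeros((3, 3))
for m in range(3):
    e = np.zeros(3)
    e[m] = h
    up = (rk - (rj + e)) / np.linalg.norm(rk - (rj + e))
    down = (rk - (rj - e)) / np.linalg.norm(rk - (rj - e))
    numeric[m] = (up - down) / (2 * h)
print(np.max(np.abs(analytic - numeric)))
```

Because $r_{jk}$ depends on $r_k$ with the opposite sign, the Jacobian with respect to $r_k$ is simply the negative of this matrix.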

#### Derivative of $\frac{r_{jk}\times r_{kl}\times r_{jk}}{|r_{jk}\times r_{kl}\times r_{jk}|}$

We use the quotient rule for this term. For the numerator, the vector triple product identity $\vec{a}\times(\vec{b}\times\vec{c}) = (\vec{a}\cdot\vec{c})\vec{b}-(\vec{a}\cdot\vec{b})\vec{c}$ gives

$$
\begin{split}
\frac{d}{dt}(r_{jk}\times r_{kl}\times r_{jk})&=\frac{d}{dt}\left(
(r_{jk}\cdot r_{jk})r_{kl}-(r_{jk}\cdot r_{kl})r_{jk}
\right)\\
&=\left(2r_{jk}\cdot\frac{dr_{jk}}{dt}\right)r_{kl}+(r_{jk}\cdot r_{jk})\frac{dr_{kl}}{dt}-\\
&\quad\left(\frac{dr_{jk}}{dt}\cdot r_{kl}+\frac{dr_{kl}}{dt}\cdot r_{jk}\right)r_{jk}-(r_{jk}\cdot r_{kl})\frac{dr_{jk}}{dt}
\end{split}
$$

Taking $t$ to be each coordinate in turn gives, for the $x$ components of the host positions (the $y$ and $z$ cases are analogous):

$$
\begin{dcases}
\frac{\partial}{\partial r_j^x}(r_{jk}\times r_{kl}\times r_{jk})^x&=
-2r_{jk}^x r_{kl}^x+r_{kl}^xr_{jk}^x+(r_{jk}\cdot r_{kl})\\
\frac{\partial}{\partial r_j^x}(r_{jk}\times r_{kl}\times r_{jk})^y&=-2r_{jk}^x r_{kl}^y + r_{kl}^x r_{jk}^y\\
\frac{\partial}{\partial r_j^x}(r_{jk}\times r_{kl}\times r_{jk})^z&=-2r_{jk}^x r_{kl}^z + r_{kl}^x r_{jk}^z
\end{dcases}
$$

$$
\begin{dcases}
\frac{\partial}{\partial r_k^x}(r_{jk}\times r_{kl}\times r_{jk})^x&=
2r_{jk}^x r_{kl}^x - (r_{jk}\cdot r_{jk})-(r_{kl}^x-r_{jk}^x)r_{jk}^x-(r_{jk}\cdot r_{kl})\\
\frac{\partial}{\partial r_k^x}(r_{jk}\times r_{kl}\times r_{jk})^y&=
2r_{jk}^x r_{kl}^y-(r_{kl}^x-r_{jk}^x)r_{jk}^y\\
\frac{\partial}{\partial r_k^x}(r_{jk}\times r_{kl}\times r_{jk})^z&=
2r_{jk}^x r_{kl}^z-(r_{kl}^x-r_{jk}^x)r_{jk}^z
\end{dcases}
$$

$$
\begin{dcases}
\frac{\partial}{\partial r_l^x}(r_{jk}\times r_{kl}\times r_{jk})^x&=
(r_{jk}\cdot r_{jk})-r_{jk}^x r_{jk}^x\\
\frac{\partial}{\partial r_l^x}(r_{jk}\times r_{kl}\times r_{jk})^y&=
-r_{jk}^x r_{jk}^y\\
\frac{\partial}{\partial r_l^x}(r_{jk}\times r_{kl}\times r_{jk})^z&=
-r_{jk}^x r_{jk}^z
\end{dcases}
$$
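Written as matrices whose row $m$ holds the derivative with respect to the $m$-th component, the three cases above become $-2r_{jk}r_{kl}^\top + r_{kl}r_{jk}^\top + (r_{jk}\cdot r_{kl})I$ for $r_j$, $2r_{jk}r_{kl}^\top - (r_{kl}-r_{jk})r_{jk}^\top - (r_{jk}\cdot r_{jk}+r_{jk}\cdot r_{kl})I$ for $r_k$, and $(r_{jk}\cdot r_{jk})I - r_{jk}r_{jk}^\top$ for $r_l$. The NumPy sketch below (hypothetical helper names) checks them against finite differences and confirms that they sum to zero, since the triple product depends only on coordinate differences.

```python
import numpy as np


def triple_jacobians(rjk, rkl):
    """Row m of each matrix is d/d r_host^m of (r_jk x r_kl) x r_jk, for host = j, k, l."""
    a, b = rjk, rkl
    eye = np.eye(3)
    d_rj = -2 * np.outer(a, b) + np.outer(b, a) + np.dot(a, b) * eye
    d_rk = (2 * np.outer(a, b) - np.outer(b - a, a)
            - (np.dot(a, a) + np.dot(a, b)) * eye)
    d_rl = np.dot(a, a) * eye - np.outer(a, a)
    return d_rj, d_rk, d_rl


def triple(rj, rk, rl):
    rjk, rkl = rk - rj, rl - rk
    return np.cross(np.cross(rjk, rkl), rjk)


rj = np.array([-8.861, -2.813, 4.867])
rk = np.array([-9.529, -1.236, 4.782])
rl = np.array([-9.858, -0.695, 3.543])
analytic = triple_jacobians(rk - rj, rl - rk)

h = 1e-6
errors = []
for host, jac in enumerate(analytic):
    numeric = np.zeros((3, 3))
    for m in range(3):
        plus = [rj.copy(), rk.copy(), rl.copy()]
        minus = [rj.copy(), rk.copy(), rl.copy()]
        plus[host][m] += h
        minus[host][m] -= h
        numeric[m] = (triple(*plus) - triple(*minus)) / (2 * h)
    errors.append(np.max(np.abs(jac - numeric)))
print(errors)
```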

The derivatives of the numerator can be computed first and then reused for the denominator, since

$$
\begin{split}
\frac{d}{dt}|r_{jk}\times r_{kl}\times r_{jk}|&=
\frac{1}{|r_{jk}\times r_{kl}\times r_{jk}|}\left((r_{jk}\times r_{kl}\times r_{jk})\cdot\frac{d}{dt}(r_{jk}\times r_{kl}\times r_{jk})\right)
\end{split}
$$
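Combining the quotient rule with the expression above gives a compact rule for any normalized vector: if row $m$ of a matrix $J$ holds $\partial u/\partial x^m$, then row $m$ of the Jacobian of $u/|u|$ is the corresponding row of $\left(J - (J\hat{u})\hat{u}^\top\right)/|u|$. This is the pattern behind `dr2_drj` and `dr3_drj` in the gradient code further down. The NumPy sketch below checks the rule on a made-up affine map $u(x) = Ax + c$; the numbers and helper names are illustrative only.

```python
import numpy as np


def normalized_jacobian(jac, u):
    """Given rows jac[m] = du/dx^m, return rows d(u/|u|)/dx^m (quotient rule)."""
    norm = np.linalg.norm(u)
    u_hat = u / norm
    # d(u/|u|) = du/|u| - u (u . du)/|u|^3 = (du - u_hat (u_hat . du)) / |u|
    return (jac - np.outer(jac @ u_hat, u_hat)) / norm


# Illustrative affine map u(x) = A x + c, so du/dx^m is column m of A
A = np.array([[2.0, 0.5, -1.0],
              [0.3, 1.5, 0.2],
              [-0.4, 0.1, 1.0]])
c = np.array([0.5, -1.0, 2.0])
x0 = np.array([1.0, 2.0, 3.0])


def unit_u(x):
    u = A @ x + c
    return u / np.linalg.norm(u)


analytic = normalized_jacobian(A.T, A @ x0 + c)
h = 1e-6
numeric = np.array([(unit_u(x0 + h * e) - unit_u(x0 - h * e)) / (2 * h)
                    for e in np.eye(3)])
print(np.max(np.abs(analytic - numeric)))
```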

#### Derivative of $\frac{r_{jk}\times r_{kl}}{|r_{jk}\times r_{kl}|}$


$$
r_{jk}\times r_{kl}=\left[
r_{jk}^y r_{kl}^z-r_{jk}^z r_{kl}^y, 
r_{jk}^z r_{kl}^x-r_{jk}^x r_{kl}^z,
r_{jk}^x r_{kl}^y-r_{jk}^y r_{kl}^x
\right]
$$

$$
\begin{dcases}
\frac{\partial}{\partial r_j^x}\left(r_{jk}\times r_{kl}\right)&=
\left[0,r_{kl}^z, -r_{kl}^y\right]\\
\frac{\partial}{\partial r_j^y}\left(r_{jk}\times r_{kl}\right)&=
\left[-r_{kl}^z,0,r_{kl}^x\right]\\
\frac{\partial}{\partial r_j^z}\left(r_{jk}\times r_{kl}\right)&=
\left[r_{kl}^y,-r_{kl}^x,0\right]
\end{dcases}
$$

$$
\begin{dcases}
\frac{\partial}{\partial r_l^x}\left(r_{jk}\times r_{kl}\right)&=
\left[0,r_{jk}^z, -r_{jk}^y\right]\\
\frac{\partial}{\partial r_l^y}\left(r_{jk}\times r_{kl}\right)&=
\left[-r_{jk}^z,0,r_{jk}^x\right]\\
\frac{\partial}{\partial r_l^z}\left(r_{jk}\times r_{kl}\right)&=
\left[r_{jk}^y,-r_{jk}^x,0\right]
\end{dcases}
$$

$$
\begin{split}
\frac{d}{dt}|r_{jk}\times r_{kl}|&=
\frac{1}{|r_{jk}\times r_{kl}|}\left((r_{jk}\times r_{kl})\cdot\frac{d}{dt}(r_{jk}\times r_{kl})\right)
\end{split}
$$
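The two sets of skew-symmetric rows above can be generated with `np.cross` and the identity matrix: row $m$ of the Jacobian with respect to $r_j$ is $-(e_m\times r_{kl})$, and with respect to $r_l$ it is $r_{jk}\times e_m$, where $e_m$ is the $m$-th unit vector. The Jacobian with respect to $r_k$ then follows from translation invariance, which is how the gradient code below obtains `fk`. A NumPy sketch that checks this against the explicit matrices and against finite differences:

```python
import numpy as np

rj = np.array([-8.861, -2.813, 4.867])
rk = np.array([-9.529, -1.236, 4.782])
rl = np.array([-9.858, -0.695, 3.543])
rjk, rkl = rk - rj, rl - rk

# Row m: d(r_jk x r_kl)/d r_j^m = -(e_m x r_kl) and d(r_jk x r_kl)/d r_l^m = r_jk x e_m
d_cross_drj = -np.cross(np.eye(3), rkl)
d_cross_drl = np.cross(rjk, np.eye(3))
# r_jk x r_kl depends only on coordinate differences, so the three Jacobians sum to zero
d_cross_drk = -(d_cross_drj + d_cross_drl)

# The matrices written out explicitly in the equations above
explicit_drj = np.array([[0.0, rkl[2], -rkl[1]],
                         [-rkl[2], 0.0, rkl[0]],
                         [rkl[1], -rkl[0], 0.0]])
explicit_drl = np.array([[0.0, rjk[2], -rjk[1]],
                         [-rjk[2], 0.0, rjk[0]],
                         [rjk[1], -rjk[0], 0.0]])

# Finite-difference check of the r_k Jacobian
h = 1e-6
numeric_drk = np.array([
    (np.cross((rk + h * e) - rj, rl - (rk + h * e))
     - np.cross((rk - h * e) - rj, rl - (rk - h * e))) / (2 * h)
    for e in np.eye(3)])
print(np.allclose(d_cross_drj, explicit_drj),
      np.allclose(d_cross_drl, explicit_drl),
      np.allclose(d_cross_drk, numeric_drk, atol=1e-6))
```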


In [5]:
def redistribute_relative_force_grad(fi, rj, rk, rl, distance, angle, dihedral):
    # Chain rule with hand-derived Jacobians: F_host = (grad_{r_host} r_i) F_i.
    # r_i = r_j + v1 + v2 + v3, and v1 + v2 + v3 depends only on r_jk and r_kl,
    # so grad_{r_j} r_i starts from the identity and the three Jacobians of
    # v1 + v2 + v3 sum to zero (the r_k contributions are obtained from that).
    fj = fi.detach().clone()
    fk = torch.zeros(3)
    fl = torch.zeros(3)
    dsina = distance * math.sin(math.radians(angle))
    dcosa = distance * math.cos(math.radians(angle))
    dsinacost = dsina * math.cos(math.radians(dihedral))
    dsinasint = dsina * math.sin(math.radians(dihedral))
    # Positions are detached: all derivatives below are computed by hand
    rjk = rk.detach() - rj.detach()
    rkl = rl.detach() - rk.detach()
    plane = torch.linalg.cross(rjk, rkl)            # r_jk x r_kl
    plane_norm = torch.linalg.norm(plane)
    plane_normed = plane / plane_norm               # n_hat
    triple = torch.linalg.cross(plane, rjk)         # (r_jk x r_kl) x r_jk
    triple_norm = torch.linalg.norm(triple)
    triple_normed = triple / triple_norm            # p_hat
    inv_triple_norm = 1.0 / triple_norm
    rjk_outer_rkl = torch.outer(rjk, rkl)
    rjk_outer_rjk = torch.outer(rjk, rjk)
    rjk_norm = torch.sqrt(torch.trace(rjk_outer_rjk))
    inv_rjk_norm = 1.0 / rjk_norm
    inv_rjk_norm3 = inv_rjk_norm * inv_rjk_norm * inv_rjk_norm
    eye = torch.eye(3)

    # v1 = d cos(alpha) r_jk/|r_jk|; d(unit r_jk)/d r_k = -d(unit r_jk)/d r_j
    dujk_drj = inv_rjk_norm3 * rjk_outer_rjk - inv_rjk_norm * eye
    fr1 = dcosa * torch.matmul(dujk_drj, fi)
    fj += fr1
    fk -= fr1

    # v2 = d sin(alpha) cos(theta) p_hat: Jacobians of the triple product,
    # followed by the quotient rule for the normalization
    rjk_dot_rkl = torch.trace(rjk_outer_rkl)
    dtri_drj = -2 * rjk_outer_rkl + rjk_outer_rkl.T + rjk_dot_rkl * eye
    fact_rj = torch.matmul(dtri_drj, triple_normed)
    dr2_drj = inv_triple_norm * (dtri_drj - torch.outer(fact_rj, triple_normed))
    fr2_rj = dsinacost * torch.matmul(dr2_drj, fi)
    fj += fr2_rj

    rjk_dot_rjk = torch.trace(rjk_outer_rjk)
    dtri_drl = rjk_dot_rjk * eye - rjk_outer_rjk
    fact_rl = torch.matmul(dtri_drl, triple_normed)
    dr2_drl = inv_triple_norm * (dtri_drl - torch.outer(fact_rl, triple_normed))
    fr2_rl = dsinacost * torch.matmul(dr2_drl, fi)
    fl += fr2_rl
    fk -= fr2_rj + fr2_rl

    # v3 = -d sin(alpha) sin(theta) n_hat: skew-symmetric Jacobians of r_jk x r_kl,
    # followed by the quotient rule for the normalization
    zero = torch.zeros(1)
    dcross_drj = torch.vstack([torch.hstack([zero, rkl[2], -rkl[1]]),
                               torch.hstack([-rkl[2], zero, rkl[0]]),
                               torch.hstack([rkl[1], -rkl[0], zero])])
    inv_p_norm = 1.0 / plane_norm
    dcross_norm_drj = torch.matmul(dcross_drj, plane_normed)
    dr3_drj = inv_p_norm * (dcross_drj - torch.outer(dcross_norm_drj, plane_normed))
    fr3_rj = -dsinasint * torch.matmul(dr3_drj, fi)
    fj += fr3_rj

    dcross_drl = torch.vstack([torch.hstack([zero, rjk[2], -rjk[1]]),
                               torch.hstack([-rjk[2], zero, rjk[0]]),
                               torch.hstack([rjk[1], -rjk[0], zero])])
    dcross_norm_drl = torch.matmul(dcross_drl, plane_normed)
    dr3_drl = inv_p_norm * (dcross_drl - torch.outer(dcross_norm_drl, plane_normed))
    fr3_rl = -dsinasint * torch.matmul(dr3_drl, fi)
    fl += fr3_rl
    fk -= fr3_rj + fr3_rl

    # The lone pair keeps no force after redistribution
    fi_new = torch.zeros(3)
    return fi_new, fj, fk, fl


#### Test the function

For this test, the distance, angle and dihedral are changed to 1.0 angstrom, 60.0 degrees and 50.0 degrees, respectively, to cover a more general case (with $\alpha = 180°$, $\sin\alpha = 0$ and the $v_2$ and $v_3$ terms vanish).

In [6]:
new_fi, fj, fk, fl = redistribute_relative_force_grad(fi, rj, rk, rl, 1.0, 60.0, 50.0)
print((new_fi, fj, fk, fl))

(tensor([0., 0., 0.]), tensor([2.0716, 2.3615, 2.2850]), tensor([-3.7683, -1.4918,  1.9375]), tensor([ 2.6968,  1.1303, -0.2225]))


#### Check if the total force and torque are conserved

In [7]:
# Check that the total torque and the total force are conserved
torque_ri = torch.linalg.cross(reposition_lone_pair_relative(rj, rk, rl, 1.0, 60.0, 50.0), fi)
total_torque = torch.linalg.cross(rj, fj) + torch.linalg.cross(rk, fk) + torch.linalg.cross(rl, fl)
print(f'Torque at i = {torque_ri};\ntotal torque after redistributing the force = {total_torque}')
print(f'fi = {fi};\ntotal force after redistributing fi = {fj + fk + fl}')

Torque at i = tensor([-17.0321,  38.1332, -14.8085], grad_fn=<LinalgCrossBackward0>);
total torque after redistributing the force = tensor([-17.0321,  38.1332, -14.8085], grad_fn=<AddBackward0>)
fi = tensor([1., 2., 4.]);
total force after redistributing fi = tensor([1.0000, 2.0000, 4.0000])


#### Compare the output above with PyTorch autograd

In [8]:
def redistribute_relative_force_autograd(fi, rj, rk, rl, distance, angle, dihedral):
    # Reset gradients accumulated by earlier calls
    rj.grad = None
    rk.grad = None
    rl.grad = None
    rii = reposition_lone_pair_relative(rj, rk, rl, distance, angle, dihedral)
    # d(fi . r_i)/d r_host = (grad_{r_host} r_i) fi, i.e. the redistributed force
    total_f = torch.dot(fi, rii)
    total_f.backward()
    return torch.zeros(3), rj.grad, rk.grad, rl.grad

print(redistribute_relative_force_autograd(fi, rj, rk, rl, 1.0, 60.0, 50.0))

(tensor([0., 0., 0.]), tensor([2.0716, 2.3615, 2.2850]), tensor([-3.7683, -1.4918,  1.9375]), tensor([ 2.6968,  1.1303, -0.2225]))


### Redistribute the force on a lone pair to the three host atoms (by projecting the force)

This approach comes from reverse engineering the CHARMM and NAMD implementations (NAMD's is translated from CHARMM's Fortran code), with help from Bernard Brooks.

To redistribute the force $f_i$ on lone pair $i$, we decompose $f_i$ into three components: (a) the force along $r_{ji}$, (b) the force perpendicular to the plane $ijk$, and (c) the remaining force, which lies in the plane $ijk$ and is perpendicular to $r_{ji}$. Each component is transferred so that, respectively, the distance $d$, the dihedral angle $\theta$, and the bond angle $\alpha$ stay unchanged. The following calculations frequently use the "inverse" of the cross product, which is described on [the Wikipedia page](https://en.wikipedia.org/wiki/Cross_product#Cross_product_inverse).
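The three-way split can be sketched in NumPy as follows. The helper names are hypothetical, and the sketch assumes $\alpha$ is neither 0° nor 180°; otherwise $r_{ji}\times r_{jk}$ vanishes and the torsional direction is undefined. Using the test geometry with $d = 1.0$, $\alpha = 60°$ and $\theta = 50°$, it checks that the radial, torsional and angular components are mutually perpendicular and add back up to $f_i$.

```python
import numpy as np


def lone_pair_position(rj, rk, rl, d, alpha_deg, theta_deg):
    """Same construction as reposition_lone_pair_relative, in NumPy."""
    alpha, theta = np.radians(alpha_deg), np.radians(theta_deg)
    rjk, rkl = rk - rj, rl - rk
    n = np.cross(rjk, rkl)
    p = np.cross(n, rjk)
    return (rj + d * np.cos(alpha) * rjk / np.linalg.norm(rjk)
            + d * np.sin(alpha) * np.cos(theta) * p / np.linalg.norm(p)
            - d * np.sin(alpha) * np.sin(theta) * n / np.linalg.norm(n))


def split_force(fi, ri, rj, rk):
    """Split fi into radial, torsional and angular parts that add up to fi."""
    rji, rjk = ri - rj, rk - rj
    u_ji = rji / np.linalg.norm(rji)
    f_radial = np.dot(fi, u_ji) * u_ji            # (a) along r_ji
    normal = np.cross(rji, rjk)
    normal /= np.linalg.norm(normal)              # unit normal of plane ijk
    f_torsion = np.dot(fi, normal) * normal       # (b) perpendicular to plane ijk
    f_angle = fi - f_radial - f_torsion           # (c) in plane ijk, perpendicular to r_ji
    return f_radial, f_torsion, f_angle


fi = np.array([1.0, 2.0, 4.0])
rj = np.array([-8.861, -2.813, 4.867])
rk = np.array([-9.529, -1.236, 4.782])
rl = np.array([-9.858, -0.695, 3.543])
ri = lone_pair_position(rj, rk, rl, 1.0, 60.0, 50.0)
f_radial, f_torsion, f_angle = split_force(fi, ri, rj, rk)
print(f_radial, f_torsion, f_angle)
```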

#### The force along $r_{ji}$ (radial force)

We only need to project $f_i$ onto $r_{ji}$ and add the result to the force on $j$, $f_j$ (initialized to $(0,0,0)$):

$$
f_{r_{ij}} = \left(f_{i} \cdot \frac{r_{ji}}{|r_{ji}|}\right)\frac{r_{ji}}{|r_{ji}|}
$$

$$
f_j \mathrel{+}= f_{r_{ij}}
$$

#### The force perpendicular to the plane $ijk$ (torsional force)

The component perpendicular to the plane $ijk$, $f_{p}$, moves $i$ along the direction of the dihedral angle. To keep the dihedral angle unchanged, we need a force acting on $l$ that (i) is perpendicular to the plane $jkl$, and (ii) produces the same torque as the torsional force on $i$.

$$
f_{p} = \frac{f_i \cdot \left(r_{ji} \times r_{jk}\right)}{|\left(r_{ji} \times r_{jk}\right)|}\frac{ \left(r_{ji} \times r_{jk}\right)}{|\left(r_{ji} \times r_{jk}\right)|}
$$

The lever arm of this force is $v_2 + v_3$, the component of $r_{ji}$ perpendicular to $r_{jk}$, so the torque is

$$
\tau_p = (v_2 + v_3) \times f_{p}
$$

The force on $l$, $f_l'$, must produce the same torque with the lever arm $h_{l}$, where $h_{l}$ is the projection of $r_{jl}$ onto the unit vector of $v_2$:

$$
h_l = \frac{r_{jl}\cdot v_2}{|v_2|}\frac{v_2}{|v_2|}
$$

$$
h_l\times f_l' = \tau_p
$$

The solutions of the equation above are
$$
f_l' = \frac{\tau_p\times h_l}{|h_l|^2}+\lambda h_l
$$
where $\lambda$ is an arbitrary constant. We choose $\lambda=0$ because we want the force on $l$ to be perpendicular to the plane $jkl$; in particular $f_l'\cdot h_l = 0$, so
$$
f_l \mathrel{+}= \frac{\tau_p\times h_l}{|h_l|^2}
$$
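The cross-product "inverse" used here, and again below for $f_j'$ and $f_{a_k}$, can be sketched as follows: if $a\times x = b$ with $b\perp a$, every solution has the form $x=(b\times a)/|a|^2+\lambda a$, and $\lambda = 0$ gives the solution perpendicular to $a$, which is also the one with the smallest norm. The helper name and the test vectors are hypothetical.

```python
import numpy as np


def cross_inverse(a, b):
    """Minimum-norm x with a x x = b; a solution exists only if b is perpendicular to a."""
    a_sq = np.dot(a, a)
    if abs(np.dot(a, b)) > 1e-9 * np.sqrt(a_sq) * np.linalg.norm(b):
        raise ValueError("a x x = b has no solution unless b is perpendicular to a")
    return np.cross(b, a) / a_sq


a = np.array([0.3, -1.2, 2.0])
x_true = np.array([1.0, 0.5, -0.7])
b = np.cross(a, x_true)            # b is perpendicular to a by construction
x = cross_inverse(a, b)
print(np.cross(a, x), b)           # the two vectors agree
print(np.dot(x, a))                # ~0: the lambda = 0 solution is perpendicular to a
```

The recovered `x` differs from `x_true` only by a multiple of `a`, which is exactly the $\lambda a$ freedom in the general solution.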

We still need to project the remaining force, $f_p - f_l'$, onto atoms $j$ and $k$. Let the forces transferred to $j$ and $k$ be $f_j'$ and $f_k'$, respectively. Requiring that (i) the total force on the three host atoms equals the force on $i$, and (ii) the total torque is conserved after the projection, gives

$$
\begin{dcases}
f_j'+f_k' &= f_p - f_l'\\
r_j\times f_j' + r_k\times f_k' + r_l\times f_l' &= r_i \times f_p
\end{dcases}
$$

Subtracting $r_k \times$ (first equation) from the second equation gives $r_{kj}\times f_j' = r_{ki}\times f_p - r_{kl}\times f_l'$, whose solutions are

$$
f_j' = \frac{(r_{ki} \times f_p - r_{kl}\times f_l')\times r_{kj}}{|r_{kj}|^2} + t\, r_{kj}
$$

For the same reason, we choose $t = 0$ so that the force on $j$ is perpendicular to $r_{jk}$:

$$
f_j \mathrel{+}= \frac{(r_{ki} \times f_p - r_{kl}\times f_l')\times r_{kj}}{|r_{kj}|^2}
$$

$$
f_k \mathrel{+}= f_p - f_l' - \frac{(r_{ki} \times f_p - r_{kl}\times f_l')\times r_{kj}}{|r_{kj}|^2}
$$

#### The force on the plane $ijk$ but perpendicular to $r_{ji}$ (angular force)

The remaining force, $f_{a_i}$, must be transferred without changing the angle $\angle ijk$. As with the dihedral angle, we move the torque of $f_{a_i}$ about $j$ from $i$ to $k$:

$$
\begin{dcases}
f_{a_i} &= f_i - f_p - f_{r_{ij}}\\
r_{ji}\times f_{a_i} &= r_{jk} \times f_{a_k}
\end{dcases}
$$

Again, we expect $f_{a_k}$ to be perpendicular to $r_{jk}$, so

$$
f_{a_k} = \frac{\left(r_{ji}\times f_{a_i}\right)\times r_{jk}}{|r_{jk}|^2}
$$

$$
f_k \mathrel{+}= f_{a_k}
$$

The rest of $f_{a_i}$ is added to $j$:
$$
f_j \mathrel{+}= f_{a_i} - f_{a_k}
$$

In [9]:
def redistribute_relative_force_proj(fi: torch.Tensor, rj: torch.Tensor, rk: torch.Tensor, rl: torch.Tensor,
                                     distance: float, angle: float, dihedral: float):
    # Work on detached copies: this method does not need autograd
    rj, rk, rl = rj.detach(), rk.detach(), rl.detach()
    fj = torch.zeros(fi.shape)
    fk = torch.zeros(fi.shape)
    fl = torch.zeros(fi.shape)
    sin_alpha = math.sin(math.radians(angle))
    cos_alpha = math.cos(math.radians(angle))
    sin_theta = math.sin(math.radians(dihedral))
    cos_theta = math.cos(math.radians(dihedral))
    rjk = rk - rj
    rkl = rl - rk
    plane = torch.linalg.cross(rjk, rkl)
    dsina = distance * sin_alpha
    dcosa = distance * cos_alpha
    v1 = dcosa * (rjk / torch.linalg.norm(rjk))
    p_dir = torch.linalg.cross(plane, rjk)
    v2_normed = p_dir / torch.linalg.norm(p_dir)   # p_hat
    v2 = dsina * cos_theta * v2_normed
    v3 = -dsina * sin_theta * (plane / torch.linalg.norm(plane))
    rji = v1 + v2 + v3
    ri = rj + rji

    # (a) Radial force: project fi onto r_ji and give it to j (keeps the distance)
    frij = torch.zeros(3)
    uji = torch.zeros(3)
    rij_norm = torch.linalg.norm(rji)
    if rij_norm > 0:
        uji = rji / rij_norm
        frij = torch.dot(fi, uji) * uji
        fj += frij

    # (b) Torsional force: component of fi normal to the plane ijk (keeps the dihedral)
    normal_ijk = torch.linalg.cross(uji, rjk)
    normal_ijk_norm = torch.linalg.norm(normal_ijk)
    if normal_ijk_norm > 0:
        normal_ijk = normal_ijk / normal_ijk_norm
    fp = torch.dot(fi, normal_ijk) * normal_ijk
    # Lever arm of fp: the height of triangle ijk from the line jk to i
    hijk_ri = v2 + v3
    torque_p = torch.linalg.cross(hijk_ri, fp)
    # Lever arm for l: the height of triangle ljk from the line jk to l
    h_l = torch.dot(rkl, v2_normed) * v2_normed
    h_l_norm = torch.linalg.norm(h_l)
    # Force on l that reproduces torque_p (the lambda = 0 solution)
    fpl = torch.zeros(3)
    if h_l_norm > 0:
        fpl = torch.linalg.cross(torque_p, h_l) / (h_l_norm * h_l_norm)
        fl += fpl
    # Split fp - fpl between j and k, conserving force and torque (note r_kj = -r_jk)
    rjk_norm = torch.linalg.norm(rjk)
    if rjk_norm > 0:
        tmp = torch.linalg.cross(torch.linalg.cross(ri - rk, fp) - torch.linalg.cross(rkl, fpl), -rjk) / (rjk_norm * rjk_norm)
        fj += tmp
        fk += fp - fpl - tmp

    # (c) Angular force: the rest of fi; move its torque about j from i to k (keeps alpha)
    fai = fi - fp - frij
    torque_j = torch.linalg.cross(rji, fai)
    if rjk_norm > 0:
        fak = torch.linalg.cross(torque_j, rjk) / (rjk_norm * rjk_norm)
        fk += fak
        fj += fai - fak

    # The lone pair keeps no force after redistribution
    return torch.zeros(3), fj, fk, fl



#### Test the force projection and compare the result with respect to the gradient method

These two results should be the same, although I don't know how to prove it.

In [10]:
print(redistribute_relative_force_proj(fi, rj, rk, rl, 1.0, 60.0, 50.0))
print(redistribute_relative_force_autograd(fi, rj, rk, rl, 1.0, 60.0, 50.0))

(tensor([0., 0., 0.]), tensor([2.0716, 2.3615, 2.2850]), tensor([-3.7683, -1.4918,  1.9375]), tensor([ 2.6968,  1.1303, -0.2225]))
(tensor([0., 0., 0.]), tensor([2.0716, 2.3615, 2.2850]), tensor([-3.7683, -1.4918,  1.9375]), tensor([ 2.6968,  1.1303, -0.2225]))
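As numerical evidence (not a proof), the NumPy sketch below re-implements the projection method, compares it with central-difference chain-rule forces on 100 random, non-degenerate geometries, and reports the largest deviation. One plausible reason for the agreement: both sets of host forces conserve the total force and torque, put no force on $k$ along $r_{jk}$, and put no in-plane force on $l$, and for a non-degenerate geometry these nine linear conditions appear to determine the host forces uniquely. All helper names, the random ranges and the degeneracy thresholds are assumptions made for this sketch.

```python
import numpy as np


def lone_pair_position(rj, rk, rl, d, alpha_deg, theta_deg):
    alpha, theta = np.radians(alpha_deg), np.radians(theta_deg)
    rjk, rkl = rk - rj, rl - rk
    n = np.cross(rjk, rkl)
    p = np.cross(n, rjk)
    return (rj + d * np.cos(alpha) * rjk / np.linalg.norm(rjk)
            + d * np.sin(alpha) * np.cos(theta) * p / np.linalg.norm(p)
            - d * np.sin(alpha) * np.sin(theta) * n / np.linalg.norm(n))


def forces_fd(fi, rj, rk, rl, params, h=1e-6):
    """Chain-rule host forces grad_{r_host}(fi . r_i), via central differences."""
    hosts = [rj, rk, rl]
    out = []
    for a in range(3):
        f = np.zeros(3)
        for m in range(3):
            plus = [r.copy() for r in hosts]
            minus = [r.copy() for r in hosts]
            plus[a][m] += h
            minus[a][m] -= h
            f[m] = fi @ (lone_pair_position(*plus, *params)
                         - lone_pair_position(*minus, *params)) / (2 * h)
        out.append(f)
    return out


def forces_proj(fi, rj, rk, rl, d, alpha_deg, theta_deg):
    """Radial / torsional / angular projection (non-degenerate geometry assumed)."""
    alpha, theta = np.radians(alpha_deg), np.radians(theta_deg)
    rjk, rkl = rk - rj, rl - rk
    n_hat = np.cross(rjk, rkl)
    n_hat /= np.linalg.norm(n_hat)
    p_hat = np.cross(n_hat, rjk)
    p_hat /= np.linalg.norm(p_hat)
    h_i = d * np.sin(alpha) * (np.cos(theta) * p_hat - np.sin(theta) * n_hat)  # v2 + v3
    rji = d * np.cos(alpha) * rjk / np.linalg.norm(rjk) + h_i
    ri = rj + rji
    # (a) radial part goes to j
    u_ji = rji / np.linalg.norm(rji)
    f_r = np.dot(fi, u_ji) * u_ji
    fj = f_r.copy()
    # (b) torsional part: move its torque to l, then split the rest between j and k
    m_hat = np.cross(u_ji, rjk)
    m_hat /= np.linalg.norm(m_hat)                 # unit normal of plane ijk
    f_p = np.dot(fi, m_hat) * m_hat
    h_l = np.dot(rkl, p_hat) * p_hat
    f_l = np.cross(np.cross(h_i, f_p), h_l) / np.dot(h_l, h_l)
    rkj = -rjk
    f_jp = np.cross(np.cross(ri - rk, f_p) - np.cross(rkl, f_l), rkj) / np.dot(rkj, rkj)
    fj += f_jp
    fk = f_p - f_l - f_jp
    # (c) angular part: move its torque about j from i to k
    f_a = fi - f_p - f_r
    f_ak = np.cross(np.cross(rji, f_a), rjk) / np.dot(rjk, rjk)
    fk += f_ak
    fj += f_a - f_ak
    return fj, fk, f_l


rng = np.random.default_rng(42)
worst, tested = 0.0, 0
while tested < 100:
    rj, rk, rl = rng.normal(scale=1.5, size=(3, 3))
    rjk, rkl = rk - rj, rl - rk
    sin_jkl = np.linalg.norm(np.cross(rjk, rkl)) / (np.linalg.norm(rjk) * np.linalg.norm(rkl))
    if min(np.linalg.norm(rjk), np.linalg.norm(rkl)) < 0.5 or sin_jkl < 0.3:
        continue                                   # skip nearly degenerate host geometries
    params = (rng.uniform(0.5, 2.0), rng.uniform(20.0, 160.0), rng.uniform(-180.0, 180.0))
    fi = rng.normal(size=3)
    diff = (np.array(forces_proj(fi, rj, rk, rl, *params))
            - np.array(forces_fd(fi, rj, rk, rl, params)))
    worst = max(worst, np.abs(diff).max())
    tested += 1
print(f"largest deviation over {tested} geometries: {worst:.2e}")
```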
