In [None]:
import numpy as np
import os
import sys
import subprocess
import matplotlib.pyplot as plt

from multiprocessing import Pool
# Prepend the GROMACS bin directory so `gmx` resolves, and point PLUMED at its kernel library.
os.environ["PATH"] = ":".join(["/usr/local/gromacs/bin", os.environ["PATH"]])
os.environ.update(PLUMED_KERNEL="/Users/wuzhiyou/plumed/lib/libplumedKernel.dylib")

def _plumed_input(phi, psi, kappa, colvar_file):
    """Return PLUMED input text restraining (phi, psi) harmonically and printing them to colvar_file."""
    return f"""
MOLINFO STRUCTURE=./ala2.pdb

phi: TORSION ATOMS=@phi-2
psi: TORSION ATOMS=@psi-2
restraint: RESTRAINT ARG=phi,psi AT={phi},{psi} KAPPA={kappa},{kappa}

PRINT ARG=phi,psi,restraint.bias FILE={colvar_file} STRIDE=100
"""


def restrained_simulation(grid_point, kappa_initial=100.0, kappa=1000.0):
    """Run the two-stage restrained GROMACS/PLUMED simulation for one grid point.

    Stage 1 runs ``gmx mdrun`` from ``nvt.tpr`` with a restraint of strength
    ``kappa_initial`` (output prefix ``<dir>/initial``). Its final structure
    ``<dir>/initial.gro`` is turned into ``<dir>/restrained.tpr`` by ``gmx grompp``.
    Stage 2 then runs with the stiffer restraint ``kappa`` and writes
    ``<dir>/COLVAR_restrained``.

    Args:
        grid_point: ``(phi_deg, psi_deg, phi, psi)``. The degree values only name
            the output directory ``phi_<phi_deg>_psi_<psi_deg>``; the radian values
            are the restraint centres.
        kappa_initial: PLUMED ``KAPPA`` for stage 1 (default 100.0).
        kappa: PLUMED ``KAPPA`` for stage 2 (default 1000.0). Whatever converts
            ``COLVAR_restrained`` into a mean force must use the same value.

    Raises:
        subprocess.CalledProcessError: if any GROMACS command exits non-zero, so a
            failed run is not silently followed by reading missing/stale output.
    """
    phi_deg, psi_deg, phi, psi = grid_point
    dir_name = f"phi_{phi_deg}_psi_{psi_deg}"
    os.makedirs(dir_name, exist_ok=True)

    inputs = {
        "plumed.dat": _plumed_input(phi, psi, kappa, f"{dir_name}/COLVAR_restrained"),
        "plumed_initial.dat": _plumed_input(phi, psi, kappa_initial, f"{dir_name}/COLVAR_initial"),
    }
    for name, text in inputs.items():
        with open(f"{dir_name}/{name}", "w", encoding="utf-8") as f:
            f.write(text)

    # Argument lists (no shell) so paths are passed verbatim; check=True stops the
    # pipeline at the first failing step instead of continuing with missing files.
    commands = [
        ["gmx", "mdrun", "-s", "nvt.tpr", "-plumed", f"{dir_name}/plumed_initial.dat",
         "-ntmpi", "1", "-deffnm", f"{dir_name}/initial", "-v"],
        ["gmx", "grompp", "-f", "nvt.mdp", "-c", f"{dir_name}/initial.gro",
         "-r", f"{dir_name}/initial.gro", "-p", "topol.top", "-o", f"{dir_name}/restrained.tpr"],
        ["gmx", "mdrun", "-s", f"{dir_name}/restrained.tpr", "-plumed", f"{dir_name}/plumed.dat",
         "-ntmpi", "1", "-deffnm", f"{dir_name}/restrained", "-v"],
    ]
    for cmd in commands:
        subprocess.run(cmd, check=True)

# Create grid


def plot_colvar(filename, figname):
    """Scatter-plot (phi, psi) from a COLVAR file, coloured by the restraint bias, and save it.

    Args:
        filename: whitespace-separated COLVAR file ('#' comment lines skipped) whose
            columns 1, 2 and 3 hold phi, psi (radians) and the restraint bias.
        figname: path passed to ``plt.savefig``.

    The figure is closed after saving.
    """
    # ndmin=2 keeps a single-sample file two-dimensional so column indexing works.
    colvar = np.loadtxt(filename, comments="#", ndmin=2)
    phi_deg = np.degrees(colvar[:, 1])
    psi_deg = np.degrees(colvar[:, 2])
    plt.figure(figsize=(10, 6))
    plt.scatter(phi_deg, psi_deg, c=colvar[:, 3], alpha=0.5)
    plt.colorbar(label='Restraint Bias')
    # The angles were converted to degrees above, so the axes are labelled in degrees.
    plt.xlabel('phi (deg)')
    plt.ylabel('psi (deg)')
    plt.title('COLVAR Analysis')
    plt.savefig(figname)
    plt.close()

def _wrap_angle(delta):
    """Map angle differences (radians) onto the periodic interval (-pi, pi]."""
    return np.pi - np.mod(np.pi - delta, 2 * np.pi)


def FES_gradient(grid_point, kappa=1000.0):
    """Estimate the free-energy gradient at one grid point from its restrained run.

    For a harmonic restraint 0.5 * kappa * (x - x0)^2, the gradient at the centre
    is estimated as -kappa * <x - x0>. The average runs over the samples in
    ``phi_<phi_deg>_psi_<psi_deg>/COLVAR_restrained``, and each deviation is
    wrapped to (-pi, pi] so that it respects periodicity.

    Args:
        grid_point: ``(phi_deg, psi_deg, phi, psi)``. The degree values locate the
            run directory; the radian values are the restraint centres.
        kappa: force constant of the restrained run (default 1000.0). It must equal
            the KAPPA the run was performed with.

    Returns:
        ``(gradient_phi, gradient_psi)`` as NumPy floats.
    """
    phi_deg, psi_deg, phi, psi = grid_point
    dir_name = f"phi_{phi_deg}_psi_{psi_deg}"
    # Columns 1 and 2 are phi and psi (column 0 is presumably PLUMED's time column).
    # ndmin=2 keeps a single-sample file two-dimensional.
    colvar = np.loadtxt(f"{dir_name}/COLVAR_restrained", comments="#", ndmin=2)
    dphi = _wrap_angle(colvar[:, 1] - phi)
    dpsi = _wrap_angle(colvar[:, 2] - psi)

    gradient_phi = -kappa * np.mean(dphi)
    gradient_psi = -kappa * np.mean(dpsi)
    return gradient_phi, gradient_psi


import os
import glob

def delete_backupfile(directory="."):
    """Delete files matching ``bck.*.PLUMED.OUT`` in *directory* (default: current directory).

    Each deletion, and each failure, is reported on stdout. Failures do not stop
    the loop: this is best-effort cleanup, so an undeletable file is reported,
    not raised.
    """
    backup_pattern = os.path.join(directory, "bck.*.PLUMED.OUT")

    for file_path in glob.glob(backup_pattern):
        try:
            os.remove(file_path)
            print(f"Deleted: {file_path}")
        except OSError as e:
            print(f"Error deleting {file_path}: {e}")

    print("Done deleting backup files.")

# `step`: grid spacing (presumably in degrees — confirm); `deg_to_rad`: degrees -> radians factor.
step, deg_to_rad = 10, np.pi / 180.0



In [None]:
import pandas as pd
import numpy as np
device = 'cpu'

# CSV of mean-force data; it must contain the columns phi, psi, grad_phi, grad_psi.
file_path = '/Users/wuzhiyou/Code/StringNET/data_standard_FF6/fes_gradients.csv'
gradients_df = pd.read_csv(file_path)

print("First few rows of the mean force data:")
print(gradients_df.head())
# Keep only the four columns the GP consumes: position (phi, psi) and gradient (grad_phi, grad_psi).
gradients_df = gradients_df[['phi', 'psi', 'grad_phi', 'grad_psi']]
data = np.array(gradients_df)
print(data.shape)

In [None]:
import numpy as np
import torch


dtype = torch.float64  # We work in double precision

class GPWithGradients:
    """Gaussian process for a periodic 2-D function F(phi, psi) conditioned on its gradients.

    Prior covariance (periodic kernel):

        k((phi, psi), (phi', psi')) = sigma_f^2 * exp(-(2 - cos(phi - phi') - cos(psi - psi')) / l^2)

    The training data are observations of (dF/dphi, dF/dpsi) at N points.
    ``predict`` returns the posterior mean of F itself. ``sigma_f`` and ``l`` are
    scalar ``torch.nn.Parameter``s. They are not constrained to be positive but
    only enter squared. ``optimize_hyperparameters`` tunes them by maximizing the
    log marginal likelihood.
    """

    def __init__(self, train_data: torch.Tensor):
        """
        Args:
            train_data: Torch tensor of shape (N, 4) with columns
                        [phi, psi, grad_phi, grad_psi] (angles in radians).
                        The hyperparameters and all internal matrices use its
                        dtype and device.
        """
        self.phi_train = train_data[:, 0].clone()
        self.psi_train = train_data[:, 1].clone()
        self.grad_phi_train = train_data[:, 2].clone()
        self.grad_psi_train = train_data[:, 3].clone()
        self.N = self.phi_train.shape[0]
        self.alpha = None  # cached (K + noise*I)^{-1} d used by predict
        self.L = None  # Cholesky factor behind self.alpha
        self._alpha_key = None  # (noise, sigma_f, l) that self.alpha was computed with
        factory = {"dtype": train_data.dtype, "device": train_data.device}
        self.sigma_f = torch.nn.Parameter(torch.tensor(1.0, **factory))
        self.l = torch.nn.Parameter(torch.tensor(1.0, **factory))

    def __call__(self, test_input: torch.Tensor, noise=1e-6):
        """Makes the instance callable by delegating to the predict method."""
        return self.predict(test_input, noise)

    def kernel(self, phi1, psi1, phi2, psi2):
        """
        Computes the periodic kernel between two sets of inputs.
            k((phi, psi), (phi', psi')) = sigma_f^2 * exp( -(2 - cos(phi-phi') - cos(psi-psi')) / l^2 )
        Inputs:
            phi1, psi1: tensors of shape [n1]
            phi2, psi2: tensors of shape [n2]
        Returns:
            Tensor of shape [n1, n2]
        """
        diff_phi = phi1.unsqueeze(1) - phi2.unsqueeze(0)  # shape: [n1, n2]
        diff_psi = psi1.unsqueeze(1) - psi2.unsqueeze(0)  # shape: [n1, n2]
        expr = 2 - torch.cos(diff_phi) - torch.cos(diff_psi)
        return (self.sigma_f ** 2) * torch.exp(-expr / (self.l ** 2))

    def compute_K_grad(self):
        """
        Computes the (2N x 2N) covariance matrix of the gradient observations,
        i.e. the mixed second derivatives of the kernel.

        For training points i and j:
          K_phi_phi = k * [ cos(phi_i - phi_j)/(l^2) - sin(phi_i - phi_j)^2/(l^4) ]
          K_psi_psi = k * [ cos(psi_i - psi_j)/(l^2) - sin(psi_i - psi_j)^2/(l^4) ]
          K_phi_psi = - k * sin(phi_i - phi_j)*sin(psi_i - psi_j)/(l^4)
        K_phi_psi is symmetric, so it serves as both off-diagonal blocks.
        """
        diff_phi = self.phi_train.unsqueeze(1) - self.phi_train.unsqueeze(0)  # shape: [N, N]
        diff_psi = self.psi_train.unsqueeze(1) - self.psi_train.unsqueeze(0)  # shape: [N, N]

        k_val = self.kernel(self.phi_train, self.psi_train,
                            self.phi_train, self.psi_train)  # shape: [N, N]

        K_phi_phi = k_val * (torch.cos(diff_phi) / (self.l ** 2) -
                             (torch.sin(diff_phi) ** 2) / (self.l ** 4))
        K_psi_psi = k_val * (torch.cos(diff_psi) / (self.l ** 2) -
                             (torch.sin(diff_psi) ** 2) / (self.l ** 4))
        K_phi_psi = - k_val * (torch.sin(diff_phi) * torch.sin(diff_psi)) / (self.l ** 4)

        top = torch.cat([K_phi_phi, K_phi_psi], dim=1)     # shape: [N, 2N]
        bottom = torch.cat([K_phi_psi, K_psi_psi], dim=1)  # shape: [N, 2N]
        return torch.cat([top, bottom], dim=0)             # shape: [2N, 2N]

    def _observations(self):
        """Return the stacked gradient observations d = [grad_phi; grad_psi], shape (2N, 1)."""
        return torch.cat([self.grad_phi_train, self.grad_psi_train], dim=0).view(2 * self.N, 1)

    def _cholesky(self, noise):
        """Return the lower Cholesky factor of K_grad + noise * I."""
        eye = torch.eye(2 * self.N, dtype=self.phi_train.dtype, device=self.phi_train.device)
        return torch.linalg.cholesky(self.compute_K_grad() + noise * eye)

    def log_marginal_likelihood(self, noise=1e-6):
        """
        Computes the log marginal likelihood (LML) of the gradient observations:

            log p(d) = -1/2 d^T K^{-1} d - 1/2 log|K| - (2N/2) log(2 pi),

        with d the (2N x 1) vector [grad_phi; grad_psi] and K = K_grad + noise * I.
        Returns a 0-dim tensor that is differentiable w.r.t. sigma_f and l.
        """
        d = self._observations()
        L = self._cholesky(noise)
        alpha = torch.cholesky_solve(d, L)
        logdetK = 2 * torch.sum(torch.log(torch.diag(L)))
        log_2pi = float(np.log(2.0 * np.pi))  # plain float: torch.log rejects Python scalars
        ll = -0.5 * (d.t() @ alpha) - 0.5 * logdetK - (2 * self.N / 2) * log_2pi
        return ll.squeeze()

    def optimize_hyperparameters(self, n_iter=2000, lr=0.01):
        """
        Optimizes the hyperparameters (sigma_f and l) with Adam by maximizing the
        log marginal likelihood. The loss is printed every 100 iterations. The cached
        predictive weights are cleared afterwards, because they belong to the old
        hyperparameters.
        """
        optimizer = torch.optim.Adam([self.sigma_f, self.l], lr=lr)

        for i in range(n_iter):
            optimizer.zero_grad()
            loss = -self.log_marginal_likelihood()  # maximize LML <==> minimize negative LML
            loss.backward()
            optimizer.step()

            if i % 100 == 0:
                print(f"Iteration {i:4d}, Loss: {loss.item():.5f}, sigma_f: {self.sigma_f.item():.5f}, l: {self.l.item():.5f}")

        self.alpha = None
        self.L = None
        self._alpha_key = None

    def _refresh_alpha(self, noise):
        """(Re)compute the cached alpha = (K + noise*I)^{-1} d if noise or hyperparameters changed."""
        key = (noise, self.sigma_f.item(), self.l.item())
        if self.alpha is not None and self._alpha_key == key:
            return
        L = self._cholesky(noise)
        self.alpha = torch.cholesky_solve(self._observations(), L).detach()
        self.L = L.detach()
        self._alpha_key = key

    def predict(self, test_input: torch.Tensor, noise=1e-6):
        """
        Given a batch of test inputs as a torch.Tensor of shape (M, 2) with columns [phi, psi],
        predicts the function value at each test point using the predictive mean

            mu = k_*^T K^{-1} d,

        where k_* is the cross-covariance between the test function value and the training
        gradient observations. The result stays differentiable w.r.t. ``test_input``.

        Returns:
            A torch.Tensor of shape (M, 1) with the predicted function values.

        Raises:
            ValueError: if ``test_input`` is not a tensor of shape (M, 2).
        """
        if not torch.is_tensor(test_input):
            raise ValueError("Test input must be a torch.Tensor.")
        if test_input.dim() != 2 or test_input.shape[1] != 2:
            raise ValueError(f"Test input must have shape [N, 2], got {test_input.shape}")

        self._refresh_alpha(noise)

        phi_star = test_input[:, 0]  # shape: (M,)
        psi_star = test_input[:, 1]  # shape: (M,)

        k_val = self.kernel(phi_star, psi_star, self.phi_train, self.psi_train)  # (M, N)

        diff_phi = phi_star.unsqueeze(1) - self.phi_train.unsqueeze(0)  # shape: (M, N)
        diff_psi = psi_star.unsqueeze(1) - self.psi_train.unsqueeze(0)  # shape: (M, N)

        # Cov(F(x*), dF/dphi_j) = dk/dphi_j = k * sin(phi* - phi_j) / l^2 (same for psi).
        k_grad_phi = (torch.sin(diff_phi) / (self.l ** 2)) * k_val  # shape: (M, N)
        k_grad_psi = (torch.sin(diff_psi) / (self.l ** 2)) * k_val  # shape: (M, N)

        k_star = torch.cat([k_grad_phi, k_grad_psi], dim=1)  # shape: (M, 2N)
        return k_star @ self.alpha  # shape: (M, 1)


FES = GPWithGradients(torch.tensor(data, dtype=torch.float64, device=device))
# FES.optimize_hyperparameters(n_iter=2000, lr=0.01)

# Periodicity check: shifting phi and/or psi by 2*pi must not change the prediction.
for shift_phi, shift_psi in ((0.0, 0.0), (2 * np.pi, 0.0), (0.0, 2 * np.pi), (2 * np.pi, 2 * np.pi)):
    test_input = torch.tensor(np.array([[1.0 + shift_phi, 2.0 + shift_psi]], dtype=np.float64), device=device)
    pred_value = FES(test_input)
    print(f"Predicted function value F({test_input[0,0]}, {test_input[0,1]}) = {pred_value}")

In [None]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import torch

def visualize_model(model, device, title="Potential Surface", resolution=100, vcenter=50):
    """
    Draw a filled-contour map of the surface predicted by ``model`` over [-pi, pi]^2.

    A ``resolution x resolution`` grid of raw angles (phi, psi) is passed to
    ``model`` as a float64 tensor of shape (resolution**2, 2). The prediction is
    reshaped to the grid and shifted so that its minimum is zero. Only a 2-D
    contour plot is drawn, into a new figure. The figure is neither shown nor
    saved here.

    Args:
        model: callable mapping an (M, 2) tensor to M predicted values.
        device: torch device for the grid.
        title: plot title.
        resolution: grid points per axis.
        vcenter: value placed at the centre of the colormap. If it does not lie
            strictly between the minimum and maximum of the shifted surface, a
            linear colour scale is used instead, because TwoSlopeNorm would raise.

    Returns:
        (phi_np, psi_np, V_pred): meshgrid arrays (indexing='ij') and the shifted
        predictions, each of shape (resolution, resolution).
    """
    with torch.no_grad():
        phi_vals = torch.linspace(-np.pi, np.pi, resolution, device=device, dtype=torch.float64)
        psi_vals = torch.linspace(-np.pi, np.pi, resolution, device=device, dtype=torch.float64)
        phi_mesh, psi_mesh = torch.meshgrid(phi_vals, psi_vals, indexing='ij')
        coords = torch.stack([phi_mesh.flatten(), psi_mesh.flatten()], dim=1).double()
        print("Coordinates shape:", coords.shape)

        V_pred = model(coords).reshape(phi_mesh.shape).cpu().numpy()
        V_pred = V_pred - np.min(V_pred)

    phi_np = phi_mesh.cpu().numpy()
    psi_np = psi_mesh.cpu().numpy()

    plt.figure(figsize=(12, 10))
    vmin, vmax = float(V_pred.min()), float(V_pred.max())
    if vmin < vcenter < vmax:
        norm = mpl.colors.TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)
    else:
        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

    cp = plt.contourf(phi_np, psi_np, V_pred, levels=100, cmap=mpl.cm.jet, norm=norm)
    plt.colorbar(cp)

    plt.xlabel('$x_1$')
    plt.ylabel('$x_2$')
    plt.title(title)

    return phi_np, psi_np, V_pred

# Full-domain contour map of the GP free-energy surface.
phi_np, psi_np, V_pred = visualize_model(model=FES, device=device, resolution=200)
plt.show()

In [None]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import torch

def visualize_model_zoomin(model, device, traj = None, traj_ref = None, title="Potential Surface", resolution=100,
                    xlim=(-1, 1), ylim=(-1, 1), scale_factor=5, vcenter=60, save_path=None):
    """
    Contour-plot the surface predicted by ``model`` over ``xlim x ylim`` and show it.

    The prediction on a ``resolution x resolution`` grid is shifted so that its
    minimum is zero. Optional trajectories are overlaid: ``traj`` as red dots and
    ``traj_ref`` as black dots (column 0 on x, column 1 on y). The figure size is
    the domain extent times ``scale_factor`` (height additionally times 0.8).

    Args:
        vcenter: value placed at the colormap centre. Falls back to a linear
            scale when it is not strictly inside the data range, because
            TwoSlopeNorm would raise.
        save_path: if given, the figure is saved there before ``plt.show()``. Some
            backends close the figure on show, so saving afterwards may produce an
            empty file.

    Returns:
        (phi_np, psi_np, V_pred), each of shape (resolution, resolution).
    """
    with torch.no_grad():
        phi_vals = torch.linspace(xlim[0], xlim[1], resolution, device=device, dtype=torch.float64)
        psi_vals = torch.linspace(ylim[0], ylim[1], resolution, device=device, dtype=torch.float64)
        phi_mesh, psi_mesh = torch.meshgrid(phi_vals, psi_vals, indexing='ij')

        coords = torch.stack([phi_mesh.flatten(), psi_mesh.flatten()], dim=1).double()
        print("Coordinates shape:", coords.shape)

        V_pred = model(coords).reshape(phi_mesh.shape).cpu().numpy()
        V_pred = V_pred - np.min(V_pred)

    phi_np = phi_mesh.cpu().numpy()
    psi_np = psi_mesh.cpu().numpy()

    domain_width = xlim[1] - xlim[0]
    domain_height = ylim[1] - ylim[0]

    plt.figure(figsize=(domain_width * scale_factor, domain_height * scale_factor * 0.8))
    vmin, vmax = float(V_pred.min()), float(V_pred.max())
    if vmin < vcenter < vmax:
        norm = mpl.colors.TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)
    else:
        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

    cp = plt.contourf(phi_np, psi_np, V_pred, levels=200, cmap=mpl.cm.jet, norm=norm)
    if traj is not None:
        plt.plot(traj[:, 0], traj[:, 1], "r.", markersize=5)
    if traj_ref is not None:
        plt.plot(traj_ref[:, 0], traj_ref[:, 1], "k.", markersize=5)
    plt.colorbar(cp)

    plt.xlabel('$x_1$')
    plt.ylabel('$x_2$')
    plt.title(title)

    # Zoom in to the specified region.
    plt.xlim(xlim)
    plt.ylim(ylim)
    plt.gca().set_aspect('equal', adjustable='box')
    if save_path is not None:
        plt.savefig(save_path)
    plt.show()

    return phi_np, psi_np, V_pred

# Same full [-pi, pi]^2 domain, drawn with the zoom-in helper at unit scale.
phi_np, psi_np, V_pred = visualize_model_zoomin(
    FES,
    device=device,
    resolution=200,
    xlim=(-np.pi, np.pi),
    ylim=(-np.pi, np.pi),
    scale_factor=1,
)


In [None]:
import torch
import torch.nn as nn
from torch.distributions.normal import Normal
import numpy as np
import math
import matplotlib.pyplot as plt
from torch.autograd import Variable
import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
import pickle
from io import StringIO
import sys
import copy
from torch.autograd import grad
for K in range(3,4):
    # Named (phi, psi) states; presumably radians, matching the [-pi, pi] box below.
    C5 = [-2.54247629,  2.76015096]
    C7eq = [-1.39111302,  0.99116045]
    C7ax = [ 1.05515806, -0.68824516]#
    # Fixed end points of the path (x(0) and x(1) of the network).
    MEPpass_start_point = C7eq#[-1.36744612,  0.98019218]#[-1.36744612,  0.98019218]
    MEPpass_end_point = C7ax#[-2.55186883,  2.86101161]#[ 1.0910803,  -0.89521033]
    # Number of s samples drawn per training batch.
    N = 20
    #-------------------------------------------------
    mycase = "36_amber_vacuum"
    casenum = f"beta0{K}_C7eqC7ax_N{N}"

    import os
    base_path = '/Users/wuzhiyou/Code/StringNET/on_the_fly/Data/'

    # Full path to the per-case output folder (created if missing).
    folder_path = os.path.join(base_path, casenum)

    os.makedirs(folder_path, exist_ok=True)

    all_path = folder_path

    # Seed every RNG in use and force deterministic cuDNN for reproducibility.
    seed = 42
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Pickle output paths inside folder_path.
    # NOTE(review): "coslosss" (three s) looks like a typo, but downstream readers may rely on it.
    path_s = os.path.join(folder_path, f"{mycase}_s_{casenum}.pkl")
    path_path = os.path.join(folder_path, f"{mycase}_path_{casenum}.pkl")
    path_cos = os.path.join(folder_path, f"{mycase}_cos_{casenum}.pkl")
    path_energy = os.path.join(folder_path, f"{mycase}_energy_{casenum}.pkl")
    path_force = os.path.join(folder_path, f"{mycase}_force_{casenum}.pkl")
    path_cosloss = os.path.join(folder_path, f"{mycase}_coslosss_{casenum}.pkl")
    path_lmax = os.path.join(folder_path, f"{mycase}_lmax_{casenum}.pkl")
    path_lg = os.path.join(folder_path, f"{mycase}_lg_{casenum}.pkl")
    #-------------------------------------------------
    print(path_s)

    plot_use_all = []

    torch.set_printoptions(precision=10)

    # NOTE(review): all_path has no trailing separator, so these names become
    # "<base_path><casenum>model_MEP_...pth". They sit next to folder_path, not
    # inside it, unlike the os.path.join paths above. Confirm this is intended
    # before relying on these locations.
    model_Parameters_name = all_path + "model_MEP_" + mycase[0:2] + "_" + casenum + "pretrain" + ".pth"  # Path
    model_name = all_path + "model_MEP_" + mycase + "_" + casenum + ".pth"
    model_init = all_path + "model_MEP_" + "dw" + "_" + "2" + ".pth"

    # When True, model weights are loaded from a checkpoint further below.
    load_model = False
    batches = 50  # iterations
    beta = 0.1*K
    print(f"beta = {beta}")
    # Initial value only; the training loop reassigns learning_rate per iteration.
    learning_rate = 2e-4
    dimension = 2  

    # Domain bounds for (phi, psi).
    x_lb = -math.pi
    x_ub = math.pi
    y_lb = -math.pi
    y_ub = math.pi

    # Loss weights; the training loop overwrites alpha1..alpha3 on every iteration.
    alpha1 = 0
    alpha2 = 2
    alpha3 = 0.5

    nu = 0
    ww = 10
    lambda__ = 0

    def flatten_tensor_list(tensor_list):
        """Concatenate the row-major flattening of every tensor in ``tensor_list`` into one 1-D tensor."""
        flat_parts = [tensor.flatten() for tensor in tensor_list]
        return torch.cat(flat_parts)

    def unflatten_vector(vec, tensor_list):
        """
        Split the 1-D tensor ``vec`` into consecutive views shaped like the tensors in ``tensor_list``.

        Inverse of flatten_tensor_list. Entries of ``vec`` beyond the total size are ignored.
        """
        pieces = []
        start = 0
        for template in tensor_list:
            stop = start + template.numel()
            pieces.append(vec[start:stop].view_as(template))
            start = stop
        return pieces

    # ------------------------------
    #  Conjugate Gradient Solver
    # ------------------------------
    def conjugate_gradient(Gv_func, b, tol=1e-6, max_iter=50):
        """
        Solve G v = b for v with the Conjugate Gradient (CG) method.

        CG assumes G is symmetric positive (semi-)definite. G is never formed; only
        matrix-vector products are used.

        Args:
            Gv_func: callable returning G @ v for any flat vector v.
            b: right-hand side (1-D tensor).
            tol: stop once the residual norm ||b - G x|| falls below tol.
            max_iter: maximum number of iterations; 0 returns the zero vector.

        Returns:
            The approximate solution x (same shape/dtype as b). One status line
            "k = <last iteration index>, control = <residual norm>" is printed.
        """
        x = torch.zeros_like(b)
        r = b - Gv_func(x)
        p = r.clone()
        rsold = torch.dot(r, r)
        # Defaults so the final report is defined even if the loop body never runs.
        i = 0
        control = torch.sqrt(rsold)
        for i in range(max_iter):
            Ap = Gv_func(p)
            # 1e-12 guards against division by zero when p^T G p vanishes.
            alpha = rsold / (torch.dot(p, Ap) + 1e-12)
            x = x + alpha * p
            r = r - alpha * Ap
            rsnew = torch.dot(r, r)
            control = torch.sqrt(rsnew)
            if control < tol:
                break
            p = r + (rsnew / rsold) * p
            rsold = rsnew
        print(f"k = {i}, control = {control}")
        return x

    def potential_tensor(q):
        """Evaluate the GP free-energy surrogate ``FES`` at the points ``q`` (tensor of shape (M, 2))."""
        return FES(q)

    def gradient_tensor(q):
        """
        Return dV/dq of the surrogate potential at the points ``q``.

        ``q`` must require grad. The graph is kept (create_graph=True), so the
        result can be differentiated again.
        """
        energy = potential_tensor(q)
        (dV_dq,) = torch.autograd.grad(
            energy,
            q,
            grad_outputs=torch.ones_like(energy),
            create_graph=True,
        )
        return dV_dq
    def gradient_md(q):
        """
        Estimate the free-energy gradient at each row of ``q`` by restrained MD.

        For every point (phi, psi) in radians, this runs ``restrained_simulation``,
        removes PLUMED backup files with ``delete_backupfile`` and converts the
        restrained run into a mean force with ``FES_gradient``. Afterwards every
        ``phi_*`` run directory in the current working directory is removed.

        Args:
            q: tensor of shape (M, 2) with columns (phi, psi) in radians.

        Returns:
            float64 tensor of shape (M, 2) holding (dF/dphi, dF/dpsi), on ``device``.
        """
        import shutil  # local import: only needed for the run-directory clean-up

        q_np = q.cpu().detach().numpy()
        phi_deg_recovered = q_np[:, 0] / deg_to_rad
        psi_deg_recovered = q_np[:, 1] / deg_to_rad
        # Rows of (phi_deg, psi_deg, phi, psi): the grid-point layout the MD helpers unpack.
        grid = np.column_stack((phi_deg_recovered, psi_deg_recovered, q_np))
        mean_forces = []
        for point in grid:
            restrained_simulation(point)
            delete_backupfile()
            mean_forces.append(FES_gradient(point))
        # reshape keeps the (M, 2) layout even when q has no rows.
        forces = np.asarray(mean_forces, dtype=np.float64).reshape(-1, 2)
        out = torch.tensor(forces, dtype=torch.float64).to(device)
        # Best-effort removal of the per-point run directories, limited to the
        # working directory (the previous shell `find` searched recursively).
        for run_dir in glob.glob("phi_*"):
            if os.path.isdir(run_dir):
                shutil.rmtree(run_dir, ignore_errors=True)
        return out

    def fig_cos_V_force(i, s, cos, g, x_pred_list, norm_F_perp):
        """
        Plot five diagnostics along the path against ``s`` and save them at selected iterations.

        Panels: cos, 1 - cos^2, energy, force (row norms of ``g``) and ``norm_F_perp``.
        In the first two panels, points with 1 - cos^2 > 0.9 are drawn in red and the
        rest in blue. When ``i`` is a multiple of 10000 or one of 1, 500, 1000, 2000,
        5000, the figure is saved as EPS in ``folder_path``.

        Args:
            i: iteration number.
            s, cos, norm_F_perp: array-likes indexed by the same boolean mask
                (presumably 1-D NumPy arrays of equal length — confirm).
            g: 2-D array; its row norms are plotted as the force.
            x_pred_list: currently unused (see NOTE below).
        """
        # NOTE(review): the energy panel evaluates the loop-level ``x_pred`` rather than
        # the ``x_pred_list`` argument. This is kept as-is because the caller is not
        # visible here; confirm which one is intended.
        v = potential_tensor(x_pred).cpu().detach().numpy()
        force = np.sqrt(np.sum(g * g, axis=1))
        sin_square = 1 - cos**2

        mask = sin_square > 0.9  # entries where force and tangent are far from parallel
        s_high, s_low = s[mask], s[~mask]
        cos_high, cos_low = cos[mask], cos[~mask]
        sin_square_high, sin_square_low = sin_square[mask], sin_square[~mask]

        num_subplots = 5
        subplot_size = 4  # inches per (square) panel
        _, axs = plt.subplots(1, num_subplots, figsize=(subplot_size * num_subplots, subplot_size))

        # Subplot 1: Cosine
        axs[0].plot(s_low, cos_low, 'b.')
        axs[0].plot(s_high, cos_high, 'r.')
        axs[0].set_ylim(-1.1, 1.1)
        axs[0].set_title("$cos$")
        axs[0].set_box_aspect(1)
        # Subplot 2: 1 - cos^2
        axs[1].plot(s_low, sin_square_low, 'b.')
        axs[1].plot(s_high, sin_square_high, 'r.')
        axs[1].set_ylim(-0.1, 1.1)
        axs[1].set_title("l3 $(1-cos^2)$")
        axs[1].set_box_aspect(1)
        # Subplot 3: Energy
        axs[2].plot(s, v, 'b.')
        axs[2].set_title("$Energy$")
        axs[2].set_box_aspect(1)
        # Subplot 4: Force
        axs[3].plot(s, force, 'b.')
        axs[3].set_title("$Force$")
        axs[3].set_box_aspect(1)
        # Subplot 5: norm of F_perp. Raw string: "\p" is not a valid escape (SyntaxWarning on 3.12+).
        axs[4].plot(s, norm_F_perp, 'b.')
        axs[4].set_title(r"norm of $F^{\perp}$")
        axs[4].set_box_aspect(1)
        plt.tight_layout()

        if i % 10000 == 0 or i in (1, 500, 1000, 2000, 5000):
            file_path = os.path.join(folder_path, "Dot_plot_" + "Iter" + str(i) + "_" + mycase + "_" + casenum + ".eps")
            plt.savefig(file_path)

    def fig_loss_batch(loss, loss_1, loss_3, loss_ref, loss_1_ref, loss_3_ref, loss_EL, loss_EL_ref):
        """
        Show four loss histories side by side against the batch index, then call plt.show().

        Each panel draws the main series in red (#B41830) and the matching ``_ref``
        series in blue (#2218B4). The first three panels (loss, loss_1, loss_3) use a
        linear y axis; the Euler-Lagrange panel uses a log y axis.
        """
        _, axes = plt.subplots(1, 4, figsize=(20, 4))
        panels = [
            ("plot", loss, loss_ref, "Loss", "Loss vs Loss_ref"),
            ("plot", loss_1, loss_1_ref, "Loss_1", "Loss_1 vs Loss_1_ref"),
            ("plot", loss_3, loss_3_ref, "Loss_3", "Loss_3 vs Loss_3_ref"),
            ("semilogy", loss_EL, loss_EL_ref, "Loss_Euler_Lagrangian", "Loss_EL vs Loss_EL_ref"),
        ]
        for ax, (method, series, reference, ylabel, title) in zip(axes, panels):
            draw = getattr(ax, method)
            draw(series, color="#B41830", linestyle='-')
            draw(reference, color="#2218B4", linestyle='-')
            ax.set_xlabel("Batches")
            ax.set_ylabel(ylabel)
            ax.set_title(title)

        plt.tight_layout()
        plt.show()

    def fig_cos(plt_batch, cos_batch, lmax_batch, lg_batch):
        """
        Plot three histories against ``plt_batch`` in one row of panels.

        Left: ``cos_batch`` on a log y axis; middle: ``lmax_batch``; right: ``lg_batch``.
        The figure is created but neither shown nor saved.
        """
        fig = plt.figure(figsize=(15, 4))
        fig.subplots_adjust(hspace=0.4, wspace=0.4)

        plt.subplot(1, 3, 1)
        plt.semilogy(plt_batch, cos_batch, 'b-')
        plt.xlabel("Batches")
        # Raw string: "\i" is not a valid Python escape (SyntaxWarning on 3.12+); same text.
        plt.ylabel(r"$\int 1-Cos^2(force,tangent) ds$")

        plt.subplot(1, 3, 2)
        plt.plot(plt_batch, lmax_batch, 'b-')
        plt.xlabel("Batches")
        plt.ylabel("$lmax$")

        plt.subplot(1, 3, 3)
        plt.plot(plt_batch, lg_batch, 'b-')
        plt.xlabel("Batches")
        plt.ylabel("$l_g$")


    # Tabulate the surrogate free energy on a 51 x 51 grid over
    # [x_lb, x_ub] x [y_lb, y_ub]. Z_new holds it shifted so its minimum is 0.
    x = np.linspace(x_lb, x_ub, num=51, endpoint=True)
    y = np.linspace(y_lb, y_ub, num=51, endpoint=True)
    X, Y = np.meshgrid(x, y)
    X_new = X.reshape(-1, 1)
    Y_new = Y.reshape(-1, 1)
    XY = np.hstack((X_new, Y_new))
    # Pad with zero columns up to `dimension` (no-op when dimension == 2).
    for i in range(dimension - 2):
        XY = np.hstack((XY, np.zeros(X_new.shape)))

    XY_tensor = torch.from_numpy(XY)
    X_list = []
    for i in range(dimension):
        X_list.append(XY_tensor[:, i:i+1].to(device))
        print(XY_tensor[:, i:i+1].to(device).shape)
    # Reassemble the columns into a single (51*51, dimension) tensor.
    X_list = torch.hstack(X_list)
    Z = potential_tensor(X_list)
    # Reshape back to the meshgrid layout of X / Y.
    Z_new = Z.reshape(X.shape).cpu().detach().numpy()
    Z_new = Z_new - np.min(Z_new)
    def fig_countour(k, x_plot, x_plot_ref, xlim, ylim, scale_factor=1.3):
        """
        Draw the FES contour over xlim x ylim with the current path (x_plot) and the
        reference path (x_plot_ref) overlaid. At iterations k that are multiples of
        10000 or equal to 1, 500, 1000, 2000 or 5000, save it as EPS in folder_path.

        NOTE(review): visualize_model_zoomin calls plt.show() before returning, so the
        plt.savefig below runs after the figure has been shown. With backends that
        close figures on show (e.g. the Jupyter inline backend) this may write an
        empty figure. Verify the saved EPS files.
        """
        visualize_model_zoomin(FES, device=device, traj = x_plot, traj_ref = x_plot_ref, resolution=200,
                                            xlim=xlim, ylim=ylim, scale_factor=scale_factor)

        if k % 10000 == 0 or k == 1 or k == 500 or k == 1000 or k == 2000 or k == 5000:
            file_path = os.path.join(folder_path, "Contour_" + "Iter" + str(k) + "_" + mycase + "_" + casenum + ".eps")
            plt.savefig(file_path)
        

    class ResNetBlock(nn.Module):
        """Residual block computing x + tanh(linear2(tanh(linear1(x)))) at a fixed width."""

        def __init__(self, hidden_size):
            super().__init__()
            self.linear1 = nn.Linear(hidden_size, hidden_size)
            self.activation = nn.Tanh()
            self.linear2 = nn.Linear(hidden_size, hidden_size)

        def forward(self, x):
            hidden = self.activation(self.linear1(x))
            return self.activation(self.linear2(hidden)) + x
        
    class NeuralNetwork(nn.Module):
        """
        Neural parametrization of a path x(s), s in [0, 1], with pinned end points.

        A Tanh MLP g maps the scalar s to a ``dimension``-vector, and

            x(s) = s * (1 - s) * g(s) + (1 - s) * startpoint + s * endpoint,

        so x(0) = startpoint and x(1) = endpoint for any weights.
        """

        def __init__(self, st, ed):
            super(NeuralNetwork, self).__init__()
            self.flatten = nn.Flatten()
            self.linear_tanh_stack = nn.Sequential(
                nn.Linear(1, 50),
                nn.Tanh(),
                nn.Linear(50, 50),
                nn.Tanh(),
                nn.Linear(50, 50),
                nn.Tanh(),
                nn.Linear(50, 50),
                nn.Tanh(),
                nn.Linear(50, dimension)
            )
            # End points as constant (1, len(st)) tensors. They are plain attributes,
            # not parameters or buffers, so they are not trained, not stored in
            # state_dict(), and not moved by model.to(); hence the explicit
            # .to(device). A plain tensor already has requires_grad=False, so the
            # deprecated Variable wrapper is unnecessary.
            self.startpoint = torch.from_numpy(np.array([st])).to(device)
            self.endpoint = torch.from_numpy(np.array([ed])).to(device)

        def forward(self, s):
            """Return path points x(s), shape (N, dimension), for s of shape (N, 1)."""
            s = self.flatten(s)
            x_pred = self.linear_tanh_stack(s)
            return s * (1 - s) * x_pred + (1 - s) * self.startpoint + s * self.endpoint

        def gradient(self, s):
            """
            Return dx/ds, one column per output component.

            ``s`` must have requires_grad=True. The graph is kept so the result can
            be differentiated again.
            """
            return self.partial_s(s, self.forward(s))

        def hessian(self, s):
            """
            Return d^2x/ds^2, one column per output component.

            This is the per-component second derivative along the path, not a full
            Hessian matrix. ``s`` must have requires_grad=True.
            """
            return self.partial_s(s, self.gradient(s))

        def partial_s(self, s, input):
            """
            Differentiate every column of ``input`` with respect to ``s``.

            Args:
                s: tensor that ``input`` was computed from; must require grad.
                input: 2-D tensor whose column j is differentiated.

            Returns:
                ``torch.hstack`` of d input[:, j] / d s for all columns j; shape
                (N, D) for s of shape (N, 1). create_graph/retain_graph are set so
                the result can be differentiated again.
            """
            columns = [
                torch.autograd.grad(
                    outputs=input[:, j],
                    inputs=s,
                    grad_outputs=torch.ones_like(input[:, j]),
                    create_graph=True,
                    retain_graph=True,
                )[0]
                for j in range(input.shape[1])
            ]
            return torch.hstack(columns)
    def compute_jacobian_list(model, s, theta_list):
        """
        Return one parameter Jacobian per sample: J_i = d model(s_i) / d theta.

        The model is evaluated one row of ``s`` at a time. Each J_i has shape
        (output_dim, P). output_dim is the model's output width and P is the total
        number of scalar entries in ``theta_list``, flattened in list order.
        """
        n_params = sum(p.numel() for p in theta_list)  # P, the column count of every J_i
        output_dim = model(s[:1]).shape[1]
        jacobians = []
        for row in range(s.shape[0]):
            out_row = model(s[row:row + 1])  # shape (1, output_dim)
            per_output = []
            for j in range(output_dim):
                partials = torch.autograd.grad(
                    outputs=out_row[:, j],
                    inputs=theta_list,
                    grad_outputs=torch.ones_like(out_row[:, j]),
                    retain_graph=True,
                    create_graph=False,
                )
                per_output.append(torch.cat([g.reshape(-1) for g in partials]))
            jacobians.append(torch.stack(per_output, dim=0))
        return jacobians

    def natural_gradient_operator(v_flat, jacobian_list, damping=0.0):
        """
        Apply G = (1/N) * sum_i J_i^T J_i + damping * I to the flat vector ``v_flat``.

        Each J_i in ``jacobian_list`` has shape (output_dim, P) and ``v_flat`` has
        shape (P,). N is the number of Jacobians.
        """
        n_points = len(jacobian_list)
        accumulated = sum((J.t() @ (J @ v_flat) for J in jacobian_list), torch.zeros_like(v_flat))
        return accumulated / n_points + damping * v_flat

    def weights_init_zero(m):
        """Zero the weight and (if present) bias of an nn.Linear module; other modules are left untouched."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.zeros_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    print(f"Using {device} device to train")
    device = torch.device('cpu')

        
    model = NeuralNetwork(MEPpass_start_point, MEPpass_end_point).to(device).double()
    if load_model == True:
        state_dict = torch.load(f"/Users/wuzhiyou/Code/Server/StringNET/Data/dataset3/beta0{K}_C7eqC7ax_ref2_N100/model_MEP_36_amber_vacuum_beta0{K}_C7eqC7ax_ref2_N100_60000.pth", map_location=device)
        model.load_state_dict(state_dict)
    model_ref = NeuralNetwork(MEPpass_start_point, MEPpass_end_point).to(device).double()
    model_ref.load_state_dict(state_dict)


    save_data = False
    output_results = True

    loss_batch = []
    plt_batch = []
    cos_batch = []

    lg_batch = []
    lmax_batch = []

    index_seq = []
    loss_seq = []
    loss1_seq = []
    loss2_seq = []
    loss3_seq = []
    loss4_seq = []
    loss7_seq = []
    lossimf_seq = []
    loss_EL_seq = []

    loss_ref_seq = []
    loss1_ref_seq = []
    loss2_ref_seq = []
    loss3_ref_seq = []
    loss_EL_ref_seq = []

    lr_seq = []

    data = {}
    generated_samples = None


    optimizer = torch.optim.Adam(list(model.parameters()), lr = 1e-4)
    optimizer_ref = torch.optim.Adam(list(model_ref.parameters()), lr = 1e-4)

    for i in range(1, batches + 1):
        #---------------n1-------------------
        if i < 20000:
            alpha1 = 10
            alpha2 = 10
            alpha3 = 1
        else:
            alpha1 = 0.1
            alpha2 = 10
            alpha3 = 1

        if i < 80:
            learning_rate = 5e-4
        elif 80 <= i < 50000:
            learning_rate = 2e-4
        elif 50000 <= i <= 200000:
            learning_rate = 2e-4
        else:
            learning_rate = 1e-4
        
        # learning_rate = 2e-5
        lr_seq.append(learning_rate)

    

        s_np = np.random.uniform(0.001,0.999,(N,1))
            
        s = torch.tensor(s_np, requires_grad=True, dtype = torch.float64).to(device)




        x_pred_ref = model_ref(s)
        #-----------------loss_1-----------------
        v_ref = potential_tensor(x_pred_ref)
        
        partial_s_x_ref = model_ref.gradient(s)
        partial_ss_x_ref = model_ref.hessian(s)
        
        norm_partial_s_phi_square_ref = torch.sum(partial_s_x_ref * partial_s_x_ref, axis=1, keepdim=True)
        norm_partial_s_phi_ref = norm_partial_s_phi_square_ref**0.5
        g_ref = gradient_tensor(x_pred_ref)
        norm_g_square_ref = torch.sum(g_ref * g_ref, axis=1, keepdim=True)
        norm_g_ref = norm_g_square_ref**0.5

        loss_1_ref = 1 / beta * torch.log(torch.mean(torch.exp(v_ref * beta) * norm_partial_s_phi_ref))
        
        
        cos_ref = torch.sum(partial_s_x_ref * g_ref, axis=1, keepdim=True) / (norm_g_ref * norm_partial_s_phi_ref)
        cos2_ref = cos_ref**2

        

        loss_2_ref = torch.mean(1 - cos2_ref)
        #-----------------loss_4-----------------
        F_ref = g_ref
        inner_product_ref = torch.sum(partial_s_x_ref * g_ref, axis=1, keepdim=True)
        F_perpendicular_ref = F_ref - (inner_product_ref / (norm_partial_s_phi_square_ref)) * partial_s_x_ref
        norm_F_perp_square_ref = torch.sum(F_perpendicular_ref * F_perpendicular_ref, axis=1, keepdim=True)
        norm_F_perp_ref = norm_F_perp_square_ref**0.5
        loss_4_ref = torch.mean(norm_F_perp_square_ref)

        one_minus_cos2_ref = 1 - cos2_ref
        one_minus_cos2_flat_ref = one_minus_cos2_ref.view(-1)

        # Start with the initial threshold
        threshold = 0.9
        filtered_values = torch.tensor([])  # Initialize an empty tensor for filtered values
        loss_g_ref = torch.mean(norm_partial_s_phi_ref*norm_g_ref)


        tau_ref = partial_s_x_ref/norm_partial_s_phi_ref
        dtau_ds_ref = model.partial_s(s, tau_ref)
        dI_dphi_ref = F_perpendicular_ref - 1/beta*dtau_ds_ref/norm_partial_s_phi_ref
        EL_ref = torch.sum(dI_dphi_ref * dI_dphi_ref, axis=1, keepdim=True)
        loss_Euler_Lag_ref = torch.mean(EL_ref)

    
        cons_ref = torch.sum(partial_s_x_ref * partial_ss_x_ref, axis=1, keepdim=True)
        loss_3_ref = torch.mean(cons_ref**2)

        loss_ref = loss_1_ref + 0.1*loss_3_ref
        optimizer_ref.zero_grad()
        loss_here_ref = 0.1*loss_3_ref
        loss_here_ref.backward(retain_graph=True)  # retain_graph if you'll compute additional grads

        


        ds = 1.0 / N  # Integration weight
        theta_list_ref = list(model_ref.parameters())
        custom_gradients_ref = torch.autograd.grad(
            outputs=x_pred_ref,
            inputs=theta_list_ref,
            grad_outputs=dI_dphi_ref,      # This is δI/δφ(s)
            retain_graph=True,         # Retain the graph if you need further grad computations
            create_graph=False         # No need for higher-order derivatives in this example
        )
        custom_gradients_ref = [grad * ds for grad in custom_gradients_ref]
        with torch.no_grad():
            for param, custom_grad in zip(theta_list_ref, custom_gradients_ref):
                if param.grad is None:
                    param.grad = custom_grad.clone()
                else:
                    param.grad.add_(custom_grad)  # Gradient descent uses grad; the optimizer subtracts it.
        optimizer_ref.step()



        

        x_pred = model(s)
        x_pred_list = []
        for n in range(dimension):
            x_pred_list.append(x_pred[:, n:n+1])
        #-----------------loss_1-----------------
        v = potential_tensor(x_pred)
        
        partial_s_x = model.gradient(s)
        partial_ss_x = model.hessian(s)
        
        norm_partial_s_phi_square = torch.sum(partial_s_x * partial_s_x, axis=1, keepdim=True)
        norm_partial_s_phi = norm_partial_s_phi_square**0.5

        loss_1 = 1 / beta * torch.log(torch.mean(torch.exp(v * beta) * norm_partial_s_phi))
        
        g = gradient_md(x_pred)
        # else:
        #     g = gradient_tensor(x_pred)
        
        cos = torch.sum(partial_s_x * g, axis=1, keepdim=True) / (
            torch.sum(g * g, axis=1, keepdim=True)**0.5 * torch.sum(partial_s_x * partial_s_x, axis=1, keepdim=True)**0.5)
        cos2 = cos**2

        norm_g_square = torch.sum(g * g, axis=1, keepdim=True)
        norm_g = norm_g_square**0.5

        loss_2 = torch.mean(1 - cos2)
        #-----------------loss_4-----------------
        F = g
        inner_product = torch.sum(partial_s_x * g, axis=1, keepdim=True)
        F_perpendicular = F - (inner_product / (norm_partial_s_phi_square)) * partial_s_x
        norm_F_perp_square = torch.sum(F_perpendicular * F_perpendicular, axis=1, keepdim=True)
        norm_F_perp = norm_F_perp_square**0.5
        loss_4 = torch.mean(norm_F_perp_square)

        one_minus_cos2 = 1 - cos2
        one_minus_cos2_flat = one_minus_cos2.view(-1)

        # Start with the initial threshold
        threshold = 0.9
        filtered_values = torch.tensor([])  # Initialize an empty tensor for filtered values
        loss_g = torch.mean(norm_partial_s_phi*norm_g)

        tau = partial_s_x/norm_partial_s_phi
        dtau_ds = model.partial_s(s, tau)
        dI_dphi = F_perpendicular - 1/beta*dtau_ds/norm_partial_s_phi

        EL = torch.sum(dI_dphi * dI_dphi, axis=1, keepdim=True)
        loss_Euler_Lag = torch.mean(EL)
    
        cons = torch.sum(partial_s_x * partial_ss_x, axis=1, keepdim=True)
        loss_3 = torch.mean(cons**2)

        loss = loss_1 + 0.1*loss_3
        optimizer.zero_grad()
        loss_here = 0.1*loss_3
        loss_here.backward(retain_graph=True)  # retain_graph if you'll compute additional grads

        


        ds = 1.0 / N  # Integration weight
        theta_list = list(model.parameters())
        custom_gradients = torch.autograd.grad(
            outputs=x_pred,
            inputs=theta_list,
            grad_outputs=dI_dphi,      # This is δI/δφ(s)
            retain_graph=True,         # Retain the graph if you need further grad computations
            create_graph=False         # No need for higher-order derivatives in this example
        )

        custom_gradients = [grad * ds for grad in custom_gradients]
        with torch.no_grad():
            for param, custom_grad in zip(theta_list, custom_gradients):
                if param.grad is None:
                    param.grad = custom_grad.clone()
                else:
                    param.grad.add_(custom_grad)  # Gradient descent uses grad; the optimizer subtracts it.
        optimizer.step()



        index_seq.append(i)
        loss_seq.append(loss.detach().cpu().numpy())
        loss1_seq.append(loss_1.detach().cpu().numpy())
        loss2_seq.append(loss_2.detach().cpu().numpy())
        loss3_seq.append(loss_3.detach().cpu().numpy())
        loss4_seq.append(loss_4.detach().cpu().numpy())
        loss_EL_seq.append(loss_Euler_Lag.detach().cpu().numpy())

        loss_ref_seq.append(loss_ref.detach().cpu().numpy())
        loss1_ref_seq.append(loss_1_ref.detach().cpu().numpy())
        # loss2_seq.append(loss_2.detach().cpu().numpy())
        loss3_ref_seq.append(loss_3_ref.detach().cpu().numpy())
        loss_EL_ref_seq.append(loss_Euler_Lag_ref.detach().cpu().numpy())
    

        
        

        if i % 1  == 0:
            print(f'batches: {i + 1}')
            print(f'loss: {loss.detach().cpu().numpy()}')


        if i % 100 == 0 or i == 1 or i == 500 or i == 1000 or i == 2000 or i == 5000 or i > 0:
        
            fig_cos(plt_batch, cos_batch, lmax_batch, lg_batch)
            fig_countour(i, x_pred.cpu().detach().numpy(), x_pred_ref.cpu().detach().numpy(), xlim=(-2, 2), ylim=(-2, 2))
            fig_cos_V_force(i, s.cpu().detach().numpy(), cos.cpu().detach().numpy(), g.cpu().detach().numpy(), x_pred, norm_F_perp.cpu().detach().numpy())
            fig_loss_batch(np.array(loss_seq),np.array(loss1_seq),np.array(loss3_seq),
                        np.array(loss_ref_seq),np.array(loss1_ref_seq),np.array(loss3_ref_seq),
                        np.array(loss_EL_seq),np.array(loss_EL_ref_seq))
            
            plt.show()

            model_name = os.path.join(folder_path, f"model_MEP_{mycase}_{casenum}_{i}.pth")
            # Save the PyTorch model state
            torch.save(model.state_dict(), model_name)
            print(f"Saved PyTorch Model State to {model_name}")

            if i % 5000 == 0 or i == 1 or i == 100 or i == 200 or i == 500 or i == 1000 or i == 2000 or i == 5000 or i > 0:
                data["loss_seq"] = loss_seq
                data["loss1_seq"] = loss1_seq
                data["loss2_seq"] = loss2_seq
                data["loss3_seq"] = loss3_seq
                data["loss4_seq"] = loss4_seq
                data["loss_seq"] = loss_seq
                data["loss1_seq"] = loss1_seq
                data["loss2_seq"] = loss2_seq
                data["loss3_seq"] = loss3_seq
                data["loss4_seq"] = loss4_seq
                data["loss_EL_seq"] = loss_EL_seq

                data["loss_ref_seq"] = loss_ref_seq
                data["loss1_ref_seq"] = loss1_ref_seq
                # data["loss2_seq"] = loss2_seq
                data["loss3_ref_seq"] = loss3_ref_seq
                # data["loss4_seq"] = loss4_seq
                data["loss_EL_ref_seq"] = loss_EL_ref_seq
            
                data["loss7_seq"] = loss7_seq
                data["lossimf_seq"] = lossimf_seq
                data["lr_seq"] = lr_seq
                data["lr_seq"] = lr_seq
                if hasattr(model, 'module'):
                    data[f'model_state_dict_{i}'] = copy.deepcopy(model.module.state_dict())
                    data[f'model_ref_state_dict_{i}'] = copy.deepcopy(model_ref.module.state_dict())
                else:
                    data[f'model_state_dict_{i}'] = copy.deepcopy(model.state_dict())
                    data[f'model_ref_state_dict_{i}'] = copy.deepcopy(model_ref.state_dict())
                data[f"x_pred_iter_{i}"] = x_pred.cpu().detach().numpy()
                data[f"x_pred_ref_iter_{i}"] = x_pred_ref.cpu().detach().numpy()
                data[f"dI_dphi_{i}"] = dI_dphi.cpu().detach().numpy()
                # Store additional variables
                data[f"s_iter_{i}"] = s.cpu().detach().numpy()
                data[f"cos_iter_{i}"] = cos.cpu().detach().numpy()
                data[f"g_iter_{i}"] = g.cpu().detach().numpy()
                # data[f"x_pred_list_iter_{i}"] = x_pred_list
                data[f"v_iter_{i}"] = potential_tensor(x_pred).cpu().detach().numpy()
                data[f"norm_F_perp_iter_{i}"] = norm_F_perp.cpu().detach().numpy()
                data[f"EL_{i}"] = EL.cpu().detach().numpy()
                data[f"EL_ref{i}"] = EL_ref.cpu().detach().numpy()
                if save_data == True:
                    
                    filename = os.path.join(folder_path, mycase + "_" + casenum + '.data')
                    with open(filename, 'wb') as file:
                        pickle.dump(data, file)
                    print(f"Data saved successfully in {filename}")
