In [None]:
import os
import numpy as np
import math
import copy 

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import pyemma
from pyemma.plots import plot_free_energy
from pyemma import coordinates
from pyemma import msm
import torch
from torch import nn, optim, autograd
from torch.nn import functional as F

import mdshare
import time
import logging

import yaml

from cfg.parsing import parse_train_args, save_yaml_file
from utils.common import *
from utils.training import *
from utils.visualize import *

from torch.utils.tensorboard import SummaryWriter

from models.ema import ExponentialMovingAverage
from models.RC_DiffFlow import RC_DiffFlow

import warnings
warnings.filterwarnings("ignore")


args = parse_train_args()
# Notebook-specific overrides of the parsed training arguments.
args.sample_stepsize = 1
args.alpha = 1
args.latent_dim = 2
args.component_num = 50
args.gamma = 0.9
args.train_portion = 0.7
args.system = 'ad'
args.tica_dim = 24
args.n_epochs_embedding = 10
args.n_epochs_latent = 10
args.n_epochs_joint = 10
print(args)

# Record parameters under <log_dir>/<run_name>.
run_dir = os.path.join(args.log_dir, args.run_name)
# Create run_dir itself (not only its parent) so the YAML file can actually be
# written there, and record the parameters on every run, not only when the
# parent directory happened to be missing.
os.makedirs(run_dir, exist_ok=True)
yaml_file_name = os.path.join(run_dir, 'model_parameters.yml')
save_yaml_file(yaml_file_name, args.__dict__)

import datetime
now = datetime.datetime.now().strftime("%Y_%m_%d %H:%M")
# writer = SummaryWriter(log_dir=os.path.join(run_dir, now))

# Define the root logger: one console handler and one file handler ('log.log').
logger = logging.getLogger()
logger.setLevel('INFO')
# Re-running this cell would otherwise stack duplicate handlers and print every
# log line several times; remove (and close) only handlers this cell added before.
for old_handler in [h for h in logger.handlers if getattr(h, '_added_by_notebook', False)]:
    logger.removeHandler(old_handler)
    old_handler.close()
BASIC_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
formatter = logging.Formatter(BASIC_FORMAT, DATE_FORMAT)
chlr = logging.StreamHandler()
chlr.setFormatter(formatter)
# Handler that writes to a file in the working directory; the author noted
# os.path.join(run_dir, now, 'log.log') as an alternative location.
fhlr = logging.FileHandler('log.log', mode='w')
fhlr.setFormatter(formatter)
for new_handler in (chlr, fhlr):
    new_handler._added_by_notebook = True
    logger.addHandler(new_handler)
logging.getLogger('matplotlib.font_manager').setLevel(logging.ERROR)

# Prefer GPU 3, but fall back to the highest-index visible GPU when fewer GPUs
# exist (an invalid ordinal would crash at model.to(device)), and to the CPU
# when CUDA is unavailable.
PREFERRED_GPU = 3
if torch.cuda.is_available():
    device = torch.device(f'cuda:{min(PREFERRED_GPU, torch.cuda.device_count() - 1)}')
else:
    device = torch.device('cpu')


In [None]:
def visualize_embedding(traj_RE, label):
    """
    Visualize the results of embedding, one panel per state label.

    Five panels are drawn, for labels 0-4. Every panel shows all frames in gray
    as background and overlays in color the frames whose label equals the
    panel index. All panels share the axis limits of the full embedding.

    Args:
    - traj_RE (numpy.ndarray): The embedded trajectory, shape (n_frames, k) with
      k >= 2; only the first two columns are plotted.
    - label (numpy.ndarray): Per-frame labels, shape (n_frames,) or (n_frames, 1).

    Raises:
    - ValueError: if the number of labels differs from the number of frames.
    """
    re = np.asarray(traj_RE)[:, :2]
    # Flatten so both 1-D and column-vector labels work; stacking a 1-D label
    # next to a 2-D array with np.hstack would raise.
    labels = np.asarray(label).reshape(-1)
    if labels.shape[0] != re.shape[0]:
        raise ValueError(
            f'label has {labels.shape[0]} entries but traj_RE has {re.shape[0]} frames'
        )

    colors = ['b', 'g', 'r', 'c', 'orange', 'pink']
    x_lim = (re[:, 0].min(), re[:, 0].max())
    y_lim = (re[:, 1].min(), re[:, 1].max())

    plt.figure(figsize=(12, 7))
    gs = gridspec.GridSpec(2, 6)
    gs.update(wspace=0.8)
    # Three panels on the top row, two centred panels on the bottom row.
    panel_slots = [gs[0, :2], gs[0, 2:4], gs[0, 4:6], gs[1, 1:3], gs[1, 3:5]]
    for index, slot in enumerate(panel_slots):
        ax = plt.subplot(slot)
        selected = labels == index
        ax.set_xlim(*x_lim)
        ax.set_ylim(*y_lim)
        ax.scatter(re[:, 0], re[:, 1], c='gray', s=3)
        ax.scatter(re[selected, 0], re[selected, 1], c=colors[index], s=3)
        ax.set_title(f'label {index}')

# Training: embedding, latent model, and joint stages

In [None]:
# Checkpoint path prefix. The system name comes from args.system so checkpoints of
# different systems (e.g. 'ad' vs 'penta') never overwrite each other; for
# args.system == 'ad' the prefix is unchanged.
prefix = (
    f'ckpts/diffflow-{args.system}-samplestepsize{args.sample_stepsize}'
    f'_embedding{args.embeding_type}_train_portion{args.train_portion}'
)

In [None]:
# Build the trajectory and data loaders for the selected system. Only the
# alanine-dipeptide ('ad') loader also returns dihedrals and per-frame labels,
# which are plotted right away as a sanity check.
if args.system != 'ad':
    from utils.datasets import construct_dataset_penta as construct_dataset
    traj, train_loader, val_loader = construct_dataset(args)
else:
    from utils.datasets import construct_dataset_ad as construct_dataset
    traj, dihedral, label, train_loader, val_loader = construct_dataset(args)
    visualize_dihedral(dihedral, label, args.visual_results, 0)

model = RC_DiffFlow(args)
model = model.to(device)

In [None]:
# Three training stages, each checkpointed separately:
#   v1 - embedding (train_embedding with zero_potential=True)
#   v2 - latent model (train_diffusion), after Prior.re_init on the current embedding
#   v3 - both jointly (train_joint_12)
# With mode == 'train' every stage trains and saves; any other value (e.g. 'load')
# restores the previously saved checkpoints instead.
mode = 'train'


def _train_or_load(train_stage, ckpt_path):
    """
    Train one stage and save its checkpoint, or restore that checkpoint.

    Args:
    - train_stage (callable): zero-argument function running the stage's training.
    - ckpt_path (str): checkpoint file for this stage.

    Reads the notebook globals `mode`, `model` and `device`.
    """
    if mode == 'train':
        train_stage()
        ckpt_dir = os.path.dirname(ckpt_path)
        if ckpt_dir:
            # torch.save does not create missing directories.
            os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(model.state_dict(), ckpt_path)
    else:
        # map_location remaps tensors saved on another device (e.g. 'cuda:3')
        # onto this session's device, so checkpoints also load on CPU-only hosts.
        model.load_state_dict(torch.load(ckpt_path, map_location=device))


# Stage 1: train embedding.
ckpt = prefix + '_v1.ckpt'
_train_or_load(
    lambda: train_embedding(model, train_loader, args, device, logger, zero_potential=True),
    ckpt,
)

# Stage 2: re-init the prior from the stage-1 latent coordinates, then train latent.
rc_latent = embedding_traj(model.embedding_model, traj, device, args.latent_dim)
model.Prior.re_init(rc_latent[:, :args.latent_dim], args)
ckpt = prefix + '_v2.ckpt'
_train_or_load(lambda: train_diffusion(model, train_loader, args, device, logger), ckpt)

# Stage 3: train both jointly.
ckpt = prefix + '_v3.ckpt'
_train_or_load(lambda: train_joint_12(model, train_loader, args, device, logger), ckpt)

# Latent coordinates of the full trajectory under the final model.
rc_latent = embedding_traj(model.embedding_model, traj, device, args.latent_dim)

In [None]:
# Free-energy surface of the trajectory in the first two latent coordinates.
fe_plot = plot_free_energy(rc_latent[:, 0], rc_latent[:, 1])
# pyemma returns (fig, ax) in legacy mode and (fig, ax, misc) otherwise; ax is second in both.
fe_ax = fe_plot[1]
fe_ax.set_xlabel('latent coordinate 0')
fe_ax.set_ylabel('latent coordinate 1')
fe_ax.set_title('Free energy in latent space');

In [None]:
def visualize_potential(X1, X2, tmp_traj_V, force, save_path, experiment_idx,
                        levels=None, sample_freq=5):
    """
    Draw a filled-contour map of a scalar field on a 2-D grid with a force quiver.

    Args:
    - X1, X2 (numpy.ndarray): Meshgrid arrays of identical 2-D shape.
    - tmp_traj_V (numpy.ndarray): Scalar values on the grid, same shape as X1.
    - force (array-like): Vector field on the grid, shape X1.shape + (2,);
      component 0 is drawn along X1, component 1 along X2.
    - save_path (str): Currently unused (nothing is saved); kept for call compatibility.
    - experiment_idx (int): Currently unused; kept for call compatibility.
    - levels (array-like, optional): Contour levels. Defaults to 50 evenly
      spaced levels in [0, 0.08]; values outside that range are left unfilled.
    - sample_freq (int): Draw every `sample_freq`-th arrow along each grid axis
      (default 5). Values below 1 are treated as 1.
    """
    if levels is None:
        levels = np.linspace(0, 0.08, 50)
    step = max(int(sample_freq), 1)

    plt.figure(figsize=(7, 6))
    # Levels are given once via `levels=`; the former extra positional `100`
    # was ignored by matplotlib whenever `levels=` was also passed.
    plt.contourf(X1, X2, tmp_traj_V, levels=levels, cmap='jet')
    plt.colorbar()
    plt.quiver(X1[::step, ::step], X2[::step, ::step],
               force[::step, ::step, 0], force[::step, ::step, 1], color='white')

In [None]:
# Evaluate the learned prior on a regular grid spanning the range of rc_latent
# and plot exp(-V) with the returned force field.
n_grid = 100  # grid points per latent axis; every shape below derives from it
tmp_traj_RE = np.column_stack([
    np.linspace(rc_latent[:, dim].min(), rc_latent[:, dim].max(), n_grid)
    for dim in range(2)
])

x1 = tmp_traj_RE[:, 0]
x2 = tmp_traj_RE[:, 1]
X1, X2 = np.meshgrid(x1, x2)

# Flatten the grid into an (n_grid * n_grid, 2) batch of latent points.
XX1 = X1.reshape((-1, 1))
XX2 = X2.reshape((-1, 1))
X_new = torch.from_numpy(np.hstack([XX1, XX2])).float()

tmp_traj_V, force = calc_potential_traj(model.Prior.prior1, X_new, device)
# Reshape with the grid's own shape instead of a hard-coded 100 so n_grid can change safely.
visualize_potential(X1, X2, np.exp(- tmp_traj_V.reshape(X1.shape)), force.reshape(*X1.shape, 2), args.visual_results, 0)
# plt.savefig('visual/density.png', bbox_inches='tight', pad_inches=0.1)