In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

import numpy as np
import matplotlib.pyplot as plt
import torch 
from torch import nn
torch.set_default_dtype(torch.float64)
import torch.nn.functional as func

import copy
from tqdm.notebook import tqdm
import time
import random
import seaborn as sns
sns.set_theme()

from sklearn.neighbors import NearestNeighbors
from scipy.interpolate import griddata
from scipy.spatial.distance import pdist, squareform
from sklearn.metrics.pairwise import euclidean_distances
from IPython.display import clear_output

In [None]:
# Render all figure text through LaTeX (requires a TeX installation on the machine).
plt.rcParams.update({'text.usetex': True})

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from Example1 import Mueller_System
from utils    import LangevinIntegrator,rL2,Model,Model_cpu,Solver,Metadynamics,Metadynamics_Extend

In [None]:
# Test system from Example1. EPS is passed as `eps` to the FEM solver,
# the Langevin integrator and metadynamics throughout this notebook.
SYS = Mueller_System()
EPS = 10
LI  = LangevinIntegrator(dim=SYS.dim)

In [None]:
def get_q_NN(X,model_name=-1):
    """Committor values predicted by the global network `model`, flattened to 1D.

    If `model_name` differs from the sentinel -1, weights are first loaded from
    it through `model.load_weights`, which mutates the global `model`.
    """
    # NOTE(review): elsewhere this notebook restores weights with
    # model.load_state_dict(torch.load(...)); confirm Model_cpu has load_weights.
    should_load = model_name != -1
    if should_load:
        model.load_weights(model_name)
    q = model.get_q(X)
    return q.reshape(-1)
def replace_lines(file_name, line_num, text):
    """Overwrite selected lines of a text file in place.

    Parameters
    ----------
    file_name : path of the file to edit.
    line_num : sequence of 0-based line indices to replace.
    text : replacement strings paired positionally with `line_num`; each should
        carry its own trailing newline. Entries beyond len(line_num) are ignored.

    Raises
    ------
    IndexError if an index is out of range or `text` is shorter than `line_num`;
    in that case the file is left untouched (it is only rewritten afterwards).
    """
    # Context managers close both handles deterministically; the original left
    # the read handle for the garbage collector to close.
    with open(file_name, 'r') as src:
        lines = src.readlines()
    for k, idx in enumerate(line_num):
        lines[idx] = text[k]
    with open(file_name, 'w') as out:
        out.writelines(lines)
def get_q_FEM(X,m=200,mb=50,file_name='MU_2D.edp'):
    """Reference committor at the points X, computed by the FreeFem++ solver.

    The function first copies the template script `file_name` to a uniquely
    named file. It then overwrites lines 0, 10, 22 and 23 of the copy to set
    `eps` (= EPS), the integers `m`/`mb`, and the input/output file names.
    Finally it runs FreeFem++ on the copy. Only the first two columns of X are
    written out for the solver.

    Parameters
    ----------
    X : array whose first two columns are the evaluation points.
    m, mb : ints substituted into the template line `int m=...,mb=...;`
        (presumably mesh resolutions -- TODO confirm against the .edp file).
    file_name : path of the FreeFem++ template script.

    Returns
    -------
    1D array holding the first len(X) values read back from the solver output.

    Solver stdout goes to FEM.log; stderr is echoed to the notebook. The
    temporary script, input and output files are removed even on failure.
    """
    import shutil
    import subprocess
    import sys

    eps        = EPS
    # Integer upper bound for randint.
    random_id  = np.random.randint(int(1e8))
    stem, ext  = os.path.splitext(file_name)   # robust to extensions of any length
    script     = stem + str(random_id) + ext
    x_file     = "X" + str(random_id)
    q_file     = "q" + str(random_id)
    try:
        shutil.copy(file_name, script)
        np.savetxt(x_file, X[:,:2])
        replace_lines(script,[0,10,22,23],["real eps="+str(eps)+";\n","int m=%d,mb=%d;\n"%(m,mb),
                                           "ifstream fin(\""+x_file+"\");\n",
                                           "ofstream fout(\""+q_file+"\");\n"])
        # Argument list, no shell: paths are passed verbatim (no quoting/injection issues).
        with open("FEM.log", "w") as log:
            proc = subprocess.run(["FreeFem++", script], stdout=log,
                                  stderr=subprocess.PIPE, text=True, errors="replace")
        if proc.stderr:
            # Keep diagnostics visible, as `!FreeFem++ ... > FEM.log` did.
            print(proc.stderr, end="", file=sys.stderr)
        q = np.loadtxt(q_file).reshape(-1)
    finally:
        for path in (x_file, script, q_file):
            if os.path.exists(path):
                os.remove(path)
    return q[:len(X)] # may output len(X)+1 values

# Generate sample data

In [None]:
def _sample_box_then_reject(center, n, w, in_region):
    """Rejection-sample up to n points accepted by `in_region`.

    The function draws 10*n candidates. Coordinates 0 and 1 are uniform on
    [center-SYS.r, center+SYS.r], and the remaining SYS.dim-2 coordinates are
    uniform on [-w, w]. It returns the first n accepted rows, which may be
    fewer than n if too few candidates pass.
    """
    n_try = 10*n
    cand = np.random.uniform(center[0]-SYS.r, center[0]+SYS.r, (n_try, SYS.dim))
    cand[:,1] = np.random.uniform(center[1]-SYS.r, center[1]+SYS.r, n_try)
    cand[:,2:] = np.random.uniform(-w, w, (n_try, SYS.dim-2))
    accepted = cand[in_region(cand)]
    return accepted[:n]

def get_X_A(n,w=2*SYS.sigma*np.sqrt(EPS)):
    """Up to n points inside set A (SYS.IsInA); extra coordinates uniform on [-w, w]."""
    return _sample_box_then_reject(SYS.A, n, w, SYS.IsInA)

def get_X_B(n,w=2*SYS.sigma*np.sqrt(EPS)):
    """Up to n points inside set B (SYS.IsInB); extra coordinates uniform on [-w, w]."""
    return _sample_box_then_reject(SYS.B, n, w, SYS.IsInB)
def get_min_V_2D():
    """Minimum of SYS.get_V_2D over a fixed 1000 x 1000 grid on [-0.8,-0.3] x [1.2,1.7].

    The window is hard-coded; presumably it brackets the deepest well of the
    2D potential -- TODO confirm. get_uniform_data uses it as the reference level.
    """
    grid_x, grid_y = np.meshgrid(np.linspace(-.8, -.3, 1000),
                                 np.linspace(1.2, 1.7, 1000))
    points = np.stack([grid_x, grid_y], axis=-1).reshape(-1, 2)
    return SYS.get_V_2D(points).min()
def get_uniform_data(N,Vbar=130):
    """Up to N points uniform on the plotting box, kept only where the 2D potential is low.

    Coordinates 0 and 1 are uniform on SYS.xrange x SYS.yrange. Any further
    coordinates are uniform on +/- 2*SYS.sigma*sqrt(EPS). A point is kept when
    SYS.get_V_2D exceeds get_min_V_2D() by less than Vbar. The function draws
    10*N candidates, so it may return fewer than N rows.
    """
    n_try   = N*10
    sigma   = SYS.sigma*np.sqrt(EPS)
    samples = np.random.uniform(-2*sigma, 2*sigma, (n_try, SYS.dim))
    samples[:,0] = np.random.uniform(SYS.xrange[0], SYS.xrange[1], n_try)
    samples[:,1] = np.random.uniform(SYS.yrange[0], SYS.yrange[1], n_try)
    excess  = SYS.get_V_2D(samples[:,:2]) - get_min_V_2D()
    return samples[excess < Vbar][:N]

In [None]:
# Reference data: up to 1e5 points with V_2D within 130 of its minimum,
# labelled with the FEM committor from get_q_FEM.
X_u = get_uniform_data(int(1e5))
q_u = get_q_FEM(X_u)

def Error_Model(model):
    """rL2 (from utils; presumably relative L2) of model.get_q against q_u on X_u."""
    return rL2(q_u, model.get_q(X_u))

# Transition-region subset: reference committor strictly between 0.3 and 0.7.
mask = abs(q_u - 0.5) < 0.2
X_u2 = X_u[mask]
q_u2 = q_u[mask]

def Error_Model2(model):
    """Same metric as Error_Model, restricted to the transition-region subset."""
    return rL2(q_u2, model.get_q(X_u2))

print(X_u.shape, X_u2.shape)

# Samples inside the sets A and B, used for the boundary terms.
X_A = get_X_A(5000)
X_B = get_X_B(5000)
print(X_A.shape, X_B.shape)

def E_AB(model):
    """Boundary error: RMS of q on X_A (target 0) plus RMS of 1-q on X_B (target 1)."""
    q_on_A = model.get_q(X_A)
    q_on_B = model.get_q(X_B)
    return np.sqrt(np.mean(q_on_A**2)) + np.sqrt(np.mean((1 - q_on_B)**2))

# Set up the model

In [None]:
# Committor network and its trainer. The meaning of n, q0 and q1 is defined in utils
# and is not visible here.
n = 10
model = Model_cpu(input_dim=SYS.dim, num_hidden=2, hidden_dim=50, n=n)
solver = Solver(model, q0=-5, q1=5)

In [None]:
def mask_fn(X, Vbar=130):
    """Boolean mask of rows of X whose potential lies within Vbar of the lowest value in X.

    The reference level is the minimum over X itself (e.g. the plotting grid),
    not a global minimum. SYS.get_V is evaluated once; the original evaluated it twice.
    """
    V = SYS.get_V(X)
    return (V - V.min()) < Vbar
def ShowQandSampledData(q_fn,mask_fn=mask_fn,xrange=SYS.xrange,yrange=SYS.yrange,dim=SYS.dim,nx=100,ny=100,
                states=None,titles=None,q_list=np.linspace(.1,.9,9),fig_name=None):
    """Plot level sets of one or more committor functions side by side on the (x1, x2) plane.

    Parameters
    ----------
    q_fn : list of callables, one panel each; each maps the (nx*ny, dim) grid to nx*ny values.
    mask_fn : callable returning a boolean mask over the grid; masked-out points
        are set to NaN in both q and the dashed potential contours.
    xrange, yrange, nx, ny : plotting window and grid resolution.
    dim : full dimension; coordinates 2..dim-1 of the grid are fixed at 0.
    states : optional per-panel lists of point arrays to scatter (may be shorter than q_fn).
    titles : optional per-panel titles (may be shorter than q_fn).
    q_list : contour levels drawn for q.
    fig_name : if given, the figure is also saved there at 300 dpi.
    """
    states = [] if states is None else states
    titles = [] if titles is None else titles
    xx     = np.linspace(xrange[0],xrange[1],nx)
    yy     = np.linspace(yrange[0],yrange[1],ny)
    XX,YY  = np.meshgrid(xx,yy)
    x_list = np.concatenate([XX[:,:,None],YY[:,:,None]],axis=-1).reshape(-1,2)
    x_list = np.hstack([x_list,np.zeros(dtype=np.float64,shape=(x_list.shape[0],dim-2))])
    mask   = mask_fn(x_list)
    V      = SYS.get_V(x_list); V[~mask] = np.nan
    num    = len(q_fn)
    # squeeze=False keeps a 2D axes array even for a single panel, so indexing always works.
    fig,ax = plt.subplots(1,num,figsize=(4.5*num,4),constrained_layout=True,squeeze=False)
    ax     = ax[0]
    for k in range(num):
        # Float copy so NaN-masking never writes into an array owned by q_fn.
        q  = np.array(q_fn[k](x_list),dtype=np.float64).reshape(-1)
        q[~mask] = np.nan
        c1 = ax[k].contour(XX,YY,q.reshape(XX.shape),q_list,colors='black',linestyles='solid',linewidths=2)
        ax[k].clabel(c1, inline=1, fontsize=10)
        if k < len(states):
            for d in states[k]:
                ax[k].scatter(d[:,0],d[:,1],s=1)

        ax[k].contour(XX,YY,V.reshape(XX.shape),5,colors='grey',linestyles='dashed',linewidths=.7)
        # Sets A and B drawn as black disks of radius SYS.r.
        ax[k].add_artist(plt.Circle(SYS.A, SYS.r, color='k'))
        ax[k].add_artist(plt.Circle(SYS.B, SYS.r, color='k'))
        ax[k].text(SYS.A[0]-0.1, SYS.A[1]+0.15, '$A$', fontsize=20)
        ax[k].text(SYS.B[0]+0.1, SYS.B[1]-0.20, '$B$', fontsize=20)
        ax[k].set_xlabel('$x_1$', fontsize=18)
        if k==0: ax[k].set_ylabel('$x_2$', fontsize=18, rotation=0)
        ax[k].tick_params(axis="both", labelsize=10)
        ax[k].set_xlim(xrange)
        ax[k].set_ylim(yrange)
        if k < len(titles): ax[k].set_title(titles[k],fontsize=20)

    if fig_name is not None: fig.savefig(fig_name,dpi=300)
    plt.show()

In [None]:
# Pre-training on the boundary sets only: the first dataset slot is empty and
# c1=0 (presumably the weight of that interior term). Test data equals training data here.
data_train = [[],torch.tensor(X_A),torch.tensor(X_B)]
data_test  = [[],torch.tensor(X_A),torch.tensor(X_B)]
# Plain float LR: torch.optim.Adam documents that a tensor LR is not supported by all implementations.
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)

# Stop once the boundary error E_AB (q~0 on A, q~1 on B) drops below 1e-2.
def terminal_condition(model): return E_AB(model)<1e-2
solver.train_model(data_train=data_train,data_test=data_test,c1=0,c2=1,batch_size=5000,
                   optimizer=optimizer,n_steps=int(1e4+1),n_show_loss=200,
                   error_model1=Error_Model,error_model2=E_AB,
                   terminal_condition=terminal_condition,use_tqdm=True)

print(solver.q0)
ShowQandSampledData([model.get_q,model.get_r],titles=[r'$q$',r'$r$'])
# torch.save does not create directories; make sure the target exists on a fresh checkout.
os.makedirs("saved_models", exist_ok=True)
torch.save(model.state_dict(), "saved_models/par0")

# Metadynamics

In [None]:
def show_distr(X,ax):
    """Scatter the first two coordinates of X on `ax`, clipped to the SYS plotting box.

    The title reports X.shape so the sample count is visible at a glance.
    """
    ax.scatter(X[:,0],X[:,1],s=1)
    ax.set_xlim(SYS.xrange); ax.set_ylim(SYS.yrange)
    ax.set_xlabel(r'$x_1$')
    # rotation=0 (horizontal) as in ShowQandSampledData; the previous value 1 tilted the label by 1 degree.
    ax.set_ylabel(r'$x_2$',rotation=0)
    ax.set_title('shape:'+str(X.shape))
def down_sample(X):
    """Return the rows of X that lie in neither set A nor set B.

    Despite its name, this is a deterministic filter (SYS.IsInA / SYS.IsInB),
    not random subsampling.
    """
    in_A = SYS.IsInA(X)
    in_B = SYS.IsInB(X)
    keep = ~(in_A | in_B)
    return X[keep]
def get_train_test(X,coef,X_A,X_B,ratio=0.7):
    """Randomly split three datasets into train/test tensors.

    The datasets are [X | coef], X_A and X_B, where coef is appended to X as a
    last column. Each one is shuffled with np.random.permutation, and the first
    int(len*ratio) rows go to training. Only the weighted-sample tensors get
    requires_grad=True.

    Returns (data_train, data_test): two lists of three tensors, in the order above.
    """
    weighted = np.hstack([X, coef.reshape(-1,1)])
    train_part, test_part = [], []
    for arr, grad in ((weighted, True), (X_A, False), (X_B, False)):
        order = np.random.permutation(len(arr))
        cut = int(len(arr)*ratio)
        train_part.append(torch.tensor(arr[order[:cut]], requires_grad=grad))
        test_part.append(torch.tensor(arr[order[cut:]], requires_grad=grad))
    return train_part, test_part

In [None]:
# Fresh network, trainer and metadynamics object for the iterative stage. The network
# weights are later restored from saved_models/par0 before the loop.
n = 10
model = Model_cpu(input_dim=SYS.dim, num_hidden=2, hidden_dim=50, n=n)
solver = Solver(model, q0=-5, q1=5)
meta = Metadynamics(model=model, h=2, w=.003)  # h, w: presumably bias height/width -- TODO confirm in utils

In [None]:
# Run label; checkpoints go to saved_models/NAME/ and diagnostics to meta_data/NAME/.
NAME = 'ABC51'
# makedirs also creates missing parents and is a no-op when the directory exists
# (os.mkdir failed on a fresh checkout without saved_models/ or meta_data/).
os.makedirs('saved_models/'+NAME+'/', exist_ok=True)
os.makedirs('meta_data/'+NAME+'/', exist_ok=True)

In [None]:
#########     training details  ##########################
c1 = 1;  c2 = 1; learning_rate = 1e-4;                   #
##########################################################

# Start from the boundary-fitted weights saved by the pre-training cell (round 0).
model.load_state_dict(torch.load("saved_models/par0"))
ShowQandSampledData([model.get_q,model.get_r],titles=[r'$q$',r'$r$'])
torch.save(meta,'saved_models/'+NAME+'/meta%d'%0)
torch.save(model.state_dict(), 'saved_models/'+NAME+'/par%d'%0)

K = 10  # number of sample-and-retrain rounds
for k in range(1,K+1):

    # 1) Fresh metadynamics run started from A (meta was built with model=model).
    meta.re_init()
    meta.perform(dV=SYS.get_dV,x=SYS.A.reshape(1,-1),dt=1e-5,eps=EPS,
                 N=int(1e6),N_add=500,show_freq=.5,use_tqdm=False,show_distr=show_distr,
                 fig_name='meta_data/'+NAME+'/ite%d'%k)
    meta.show_meta(show_distr)

    # 2) Biased Langevin sampling from 50 copies of A with drift -SYS.get_dV - meta.get_dV
    #    (presumably potential gradient plus accumulated bias gradient); drop points in A or B.
    def get_f(X): return -SYS.get_dV(X) - meta.get_dV(X)
    X0  = np.repeat(SYS.A.reshape(1,-1),50,axis=0)
    X   = LI.get_data(X0,get_f,eps=EPS,dt=1e-5,m=100,T0=1,T=2,use_tqdm=False)
    X   = down_sample(X)
    fig,ax = plt.subplots(1,1,figsize=(3,3))
    show_distr(X,ax)
    plt.show()

    # 3) Reweighting factors exp(V_add/EPS): subtracting the max avoids overflow,
    #    and the constant factor cancels when normalizing to mean 1.
    V_add = meta.get_V(X)
    coef  = np.exp(1/EPS*(V_add-V_add.max()))
    coef  = coef/coef.mean()

    # 4) Retrain on the weighted samples plus the boundary sets.
    data_train,data_test = get_train_test(X,coef,X_A,X_B,ratio=.9)
    for i in range(len(data_train)): print(data_train[i].shape,data_test[i].shape)
    # Plain float LR: torch.optim.Adam documents that a tensor LR is not supported by all implementations.
    optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)
    solver.train_model(data_train=data_train,data_test=data_test,c1=c1,c2=c2,batch_size=5000,
                       optimizer=optimizer,n_steps=int(5000+1),n_show_loss=1000,use_tqdm=False,
                       error_model1=Error_Model,error_model2=Error_Model2)
    ShowQandSampledData([get_q_FEM,model.get_q],states=[[],[X]],
                        titles=['Ref.','Sampling Scheme I'])

    torch.save(meta,'saved_models/'+NAME+'/meta%d'%k)
    torch.save([X,coef],'meta_data/'+NAME+'/data%d'%k)
    torch.save(model.state_dict(), 'saved_models/'+NAME+'/par%d'%k)