### Imports 

In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Silence tensorflow alerts
import tensorflow as tf

In [None]:
import os
import sys

# Locate the repository's shared "software/" folder and add it to the import path.
# NOTE(review): the primary guess assumes sys.path[0] is the directory containing
# this notebook, and that this directory is exactly 22 characters below the
# repository root (e.g. a "demo/<name>" sub-folder). Confirm if the layout changes.
path_to_software_folder = sys.path[0][:-22] + 'software/'
if not os.path.isdir(path_to_software_folder):
    # The fixed-length slice only works from the original demo folder. Otherwise,
    # walk up from the notebook directory and use the first "software" folder found.
    _search_dir = os.path.abspath(sys.path[0])
    while True:
        _candidate = os.path.join(_search_dir, 'software')
        if os.path.isdir(_candidate):
            path_to_software_folder = os.path.join(_candidate, '')  # keep the trailing separator
            break
        _parent_dir = os.path.dirname(_search_dir)
        if _parent_dir == _search_dir:  # reached the filesystem root: keep the original guess
            break
        _search_dir = _parent_dir
# If neither approach finds the folder, set "path_to_software_folder" manually.
sys.path.append(path_to_software_folder)

# Project modules. The star imports are order-sensitive: a later module shadows
# any identically named object exported by an earlier one.
from DoubleWell import DoubleWell
from ProcessingNets import *
from Network_Active import *
from FlowRES_MCMC_Active import *
from Plotting import *

# Define the model

In [None]:
import numpy as np  # explicit import instead of relying on the star imports to provide `np`

# Presumably the fixed (x, y) starting position of every trajectory.
start_point = [-1.0, 0]
duration = 1.6          # total duration of a trajectory
traj_len = 32           # number of time steps per trajectory
D_t = 0.15              # presumably the translational diffusion coefficient
mu = 0.1                # presumably the mobility scaling the potential gradient
v = 1.0                 # presumably the self-propulsion speed
D_r = 2.5               # presumably the rotational diffusion coefficient
dt = duration/traj_len  # time step

# Per-step coefficients passed to CreateFlowNet.
tran_coeff = np.sqrt(2*D_t*dt)  # noise amplitude sqrt(2 * D_t * dt)
rot_coeff = np.sqrt(2*D_r*dt)   # noise amplitude sqrt(2 * D_r * dt)
act_coeff = v*dt                # v * dt
pot_coeff = mu*dt               # mu * dt

In [None]:
# Parameters of the double-well potential, consumed by DoubleWell.
# NOTE(review): the meaning of each key is defined inside DoubleWell (not shown here).
params = dict(
    x_0=1.0,
    k_x=2.0,
    k_y=10.0,
    k_BH=15,      # also printed, in units of k_B T_eff, as the label of the final histogram plot
    target=7.5,
)

Double_Well = DoubleWell(params=params)
Double_Well.plot_energy_surface_DW()  # draw the potential energy surface

# Create a network
#### Runtime ~6 min

In [None]:
# Hyper-parameters of the (presumably WaveNet-style) network inside each affine layer.
# For demo performance, this is smaller than in the main paper (32 vs 64 filters).
Affine_Wave_params = dict(
    num_filters=32,
    kernel_size=3,
    num_dilated_conv_layers=3,
)

# Flow architecture.
Num_Layers = 2            # number of scales
FLOWS_per_LAYER = 5       # number of complete flow steps per flow
Pos_flows_per_Flow = 2    # number of position flow steps per complete flow step

Ang_flows_per_Flow = 0    # leave this at 0
CCs = 1                   # leave this at 1
Pos_CC_per_Flow = 1       # number of affine cross-coupling layers per complete flow step
Ang_CC_per_Flow = 0       # leave this at 0; a non-zero value lets positions affect angles

Net = CreateFlowNet(
    Num_Layers, FLOWS_per_LAYER,
    Pos_flows_per_Flow, Ang_flows_per_Flow,
    CCs, Pos_CC_per_Flow, Ang_CC_per_Flow,
    Affine_WaveNet=Wave_unit,
    Affine_WaveParams=Affine_Wave_params,
    potential_grad=Double_Well.batch_gradient,
    start=start_point,
    tran_coeff=tran_coeff,
    rot_coeff=rot_coeff,
    act_coeff=act_coeff,
    pot_coeff=pot_coeff,
    max_len=traj_len,
    dim=3,
)

# Explore Using FlowRES
#### Runtime ~90 min

In [None]:
# Sampling budget. For demo performance, both are reduced relative to the main paper:
# 3e4 chains (vs 5e4) and 100 iterations (vs 150).
num_chains = 30000
num_iterations = 100

# Define the FlowRES framework for this system, then compile the network so it is ready to train.
FlowRES_Active_DoubleWell = FlowRes_MCMC(Net, num_chains, Chain_Initialiser_Active, Double_Well)
FlowRES_Active_DoubleWell.Compile_ML_Model(batch_size=512, lr=5e-5)

# Run the exploration and keep the first returned element: the collected histograms.
# NOTE(review): presumably these bins must match those of the pre-computed validation
# histogram, otherwise the later JSD comparison is meaningless. Confirm.
all_FlowRES_hists = FlowRES_Active_DoubleWell.Explore(
    iterations=num_iterations,
    return_hists=True,
    hist_bins=np.linspace(-2, 2, 200),
)[0]

# Compare FlowRES to Direct Integration

In [None]:
import os
import sys

# Load the pre-computed Direct Integration data for comparison. The data is assumed
# to live in a "Direct_Integration_Data" folder inside the sys.path[0] directory.
# os.path.join is used instead of "+ '/...'" so that an empty sys.path[0] gives a
# path relative to the working directory rather than one at the filesystem root.
Direct_Integration_Data_dir = os.path.join(sys.path[0], 'Direct_Integration_Data')

# Histogram generated from 10000 direct integration paths.
Validation_histogram = np.load(os.path.join(Direct_Integration_Data_dir, 'Validation_histogram.npy'))
# Histogram generated from 50000 direct integration paths.
Direct_integration_histogram = np.load(os.path.join(Direct_Integration_Data_dir, 'Direct_Integration_histogram.npy'))
# JSD against the validation ensemble vs. number of proposals, for direct integration.
Direct_Integration_JSD_vs_Proposals = np.load(os.path.join(Direct_Integration_Data_dir, 'Direct_Integration_JSD_vs_Proposals.npy'))

## Plot JSDs Against the Validation Ensemble

In [None]:
# Histogram bin edges used by the histogram plot; presumably these must match the
# hist_bins passed to Explore.
bins = np.linspace(-2,2,200)

# Jensen-Shannon divergence of each stored FlowRES histogram from the validation
# histogram. Sizing from the data, rather than assuming num_iterations + 1
# histograms, keeps proposal counts aligned with the JSDs. It also avoids an
# IndexError or spurious trailing zero-JSD entries.
all_FlowRES_JSDs = np.array(
    [JSD(FlowRES_hist, Validation_histogram) for FlowRES_hist in all_FlowRES_hists],
    dtype=float,
)

# Number of proposed paths associated with entry i: i * num_chains.
Num_Paths_Proposed_FlowRES = num_chains * np.arange(len(all_FlowRES_JSDs))
FlowRES_JSD_vs_Proposals = np.stack([Num_Paths_Proposed_FlowRES, all_FlowRES_JSDs])

# Final FlowRES histogram, made explicit for the histogram comparison plot.
FlowRES_histogram = all_FlowRES_hists[-1]

In [None]:
# Compare the FlowRES and direct-integration JSD-vs-proposals curves.
compare_JSDS(
    FlowRES_JSD_vs_Proposals,
    Direct_Integration_JSD_vs_Proposals,
)

## FlowRES and Direct Integration Histograms

In [None]:
# Plot the final FlowRES histogram against the validation histogram on the same bins.
# The last histogram is taken explicitly from all_FlowRES_hists instead of relying on
# a loop variable leaked from another cell. The text argument (presumably a plot
# label) shows k_BH in units of k_B T_eff.
Compare_Hists(all_FlowRES_hists[-1], Validation_histogram, bins, Double_Well,
              text=str(params['k_BH']) + r'$\ k_{\rm B} T_{\rm eff}$')