In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os

import torch

In [None]:
import sys

# Make the project packages under ../src (systems, samplers, tools, ...) importable.
# The entry is relative, so it resolves against the notebook's working directory.
sys.path += ['../src']

In [None]:
from systems.LJ import lennard_jones
from systems.dynamic_prior import dynamic_prior

from samplers.metropolis_MC import metropolis_monte_carlo

In [None]:
# Use the first GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed the RNGs so reruns are reproducible: torch (presumably used by the MC sampler
# and the network initialisation) and the legacy NumPy global RNG used by
# np.random.choice in the resampling cells.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

dimensions = 2
n_particles = 32
cutin = 0.8  # forwarded as `cutin` to both lennard_jones systems

# Source (prior) system: WCA-truncated potential in a hypercubic box of side box_length_source.
T_source = 2
beta_source = 1/T_source
box_length_source = 6.6
rho_source = n_particles/(box_length_source)**(dimensions)
WCA = lennard_jones(n_particles=n_particles, dimensions=dimensions, rho=rho_source, device=device, cutin=cutin, cutoff="wca")
box_length_pr = WCA.box_length

# Target (posterior) system: Lennard-Jones at the target state point.
# (An alternative target state point used previously: rho = 0.70408163, T = 0.60816327.)
T_target = 1
beta_target = 1/T_target
box_length_target = 6.6
rho_target = n_particles/(box_length_target)**(dimensions)
LJ = lennard_jones(n_particles=n_particles, dimensions=dimensions, rho=rho_target, device=device, cutin=cutin)
box_length_sys = LJ.box_length

# Isotropic rescaling factor mapping source-box coordinates onto the target box:
# the box side scales as rho**(-1/d), so L_target / L_source = (rho_source/rho_target)**(1/d).
scale = (rho_source/rho_target)**(1/dimensions)

print(f"rho_source = {rho_source}, T_source = {T_source}")
print(f"rho_target = {rho_target}, T_target = {T_target}")
print(f"s = {scale}")

In [None]:
from tools.util import generate_output_directory

# Run identifier: particle count, the mapping (WCA -> LJ) and both thermodynamic state points.
source_state = f"rho_{rho_source:.2g}_T{T_source:.2g}"
target_state = f"rho_{rho_target:.2g}_T{T_target:.2g}"
run_id = f"NVT_N{n_particles:03d}_WCA2LJ_{source_state}_to_{target_state}_main"
output_dir = generate_output_directory(run_id)

In [None]:
# Both systems use identical Metropolis settings; only the sampled system differs.
_mcmc_settings = dict(step_size=0.2, n_equilibration=5000, n_cycles=1000, transform=True)
MCMC_pr = metropolis_monte_carlo(system=WCA, **_mcmc_settings)
MCMC_sy = metropolis_monte_carlo(system=LJ, **_mcmc_settings)

In [None]:
load_data_pr = True
load_data_sy = True

wca_train_filepath = f"./data/N{WCA.n_particles:03d}/{WCA.name}/rho_{rho_source:.02g}_T_{T_source:.02g}_train.pt"
wca_sample_filepath = f"./data/N{WCA.n_particles:03d}/{WCA.name}/rho_{rho_source:.02g}_T_{T_source:.02g}_sample.pt"

n_samples_pr = 100000
n_samples_sy = 100000


def _load_or_generate_datasets(label, sampler, beta, n_samples, train_filepath, sample_filepath, load):
    """Return the ``(train, sample)`` configuration datasets of one system.

    If ``load`` is true, both tensors are read from disk onto ``device``.
    Otherwise they are generated with ``sampler.sample_space(n_samples, beta)``
    and saved to the given paths; missing parent directories are created.

    Args:
        label: system name used in the progress messages ("WCA" or "LJ").
        sampler: Metropolis sampler whose ``sample_space`` returns
            ``(configurations, _, acceptance)``.
        beta: inverse temperature of the production runs.
        n_samples: number of configurations per generated dataset.
        train_filepath, sample_filepath: where the datasets are read from / written to.
        load: read existing files instead of sampling.
    """
    print()
    if load:
        print(f"Loading {label} Training Datasets")
        # NOTE(review): torch.load unpickles its input; only load trusted, locally generated files.
        train = torch.load(train_filepath, map_location=device)
        print(f"{label} Train Dataset: {train_filepath}")
        sample = torch.load(sample_filepath, map_location=device)
        print(f"{label} Sample Dataset: {sample_filepath}")
        return train, sample

    print(f"Generating {label} Training Datasets")
    # Preliminary run at 0.2*beta (5x the temperature) whose output is discarded;
    # presumably this decorrelates the chain from its initial configuration.
    sampler.sample_space(n_samples, 0.2*beta)
    # Clear the sampler's equilibration flag before sampling at the production beta.
    sampler.equilibrated = False
    train, _, acc = sampler.sample_space(n_samples, beta)
    print(f"{label} Train Dataset: acc = {acc.item()}")
    sample, _, acc = sampler.sample_space(n_samples, beta)
    print(f"{label} Sample Dataset: acc = {acc.item()}")

    # torch.save does not create missing directories.
    for filepath in (train_filepath, sample_filepath):
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
    torch.save(train, train_filepath)
    print(f"{label} Train Dataset: {train_filepath}")
    torch.save(sample, sample_filepath)
    print(f"{label} Sample Dataset: {sample_filepath}")
    return train, sample


wca_train, wca_sample = _load_or_generate_datasets(
    "WCA", MCMC_pr, beta_source, n_samples_pr, wca_train_filepath, wca_sample_filepath, load_data_pr
)

lj_train_filepath = f"./data/N{LJ.n_particles:03d}/{LJ.name}/rho_{rho_target:.02g}_T_{T_target:.02g}_train.pt"
lj_sample_filepath = f"./data/N{LJ.n_particles:03d}/{LJ.name}/rho_{rho_target:.02g}_T_{T_target:.02g}_sample.pt"

lj_train, lj_sample = _load_or_generate_datasets(
    "LJ", MCMC_sy, beta_target, n_samples_sy, lj_train_filepath, lj_sample_filepath, load_data_sy
)

# NumPy copies reshaped to (n_configs, n_particles, dimensions) for plotting.
wca_train_cpu = wca_train.view(-1, n_particles, dimensions).cpu().numpy()
wca_sample_cpu = wca_sample.view(-1, n_particles, dimensions).cpu().numpy()
lj_train_cpu = lj_train.view(-1, n_particles, dimensions).cpu().numpy()
lj_sample_cpu = lj_sample.view(-1, n_particles, dimensions).cpu().numpy()

# Reference energies of every configuration in each dataset.
wca_energy_train_cpu = WCA.energy(wca_train).squeeze().cpu().numpy()
lj_energy_train_cpu = LJ.energy(lj_train).squeeze().cpu().numpy()
wca_energy_sample_cpu = WCA.energy(wca_sample).squeeze().cpu().numpy()
lj_energy_sample_cpu = LJ.energy(lj_sample).squeeze().cpu().numpy()

print()
print(f"Prior train size: {wca_train.shape[0]}")
print(f"Prior sample size: {wca_sample.shape[0]}")
print(f"Posterior train size: {lj_train.shape[0]}")
print(f"Posterior sample size: {lj_sample.shape[0]}")

In [None]:
cm = 0.393701  # inches per centimetre (matplotlib figure sizes are in inches)
fig_size = (10 * cm, 10 * cm)
fig, ax = plt.subplots(1, 1, figsize=fig_size, dpi=100)

# Overlay every 50th training configuration of each system; the very low alpha
# turns the scatter into a particle-density map.
stride = 50
for confs, system_name in ((wca_train_cpu, "WCA"), (lj_train_cpu, "LJ")):
    ax.scatter(confs[::stride, :, 0], confs[::stride, :, 1], alpha=0.005, label=system_name)

plt.savefig(os.path.join(output_dir, "configurations.png"))

In [None]:
fig_size = (10 * 0.393701,  7.5 * 0.393701)
fig, ax = plt.subplots(1, 1, figsize = fig_size, dpi = 100)

# Every 10th energy of each reference dataset. Train and sample sets get distinct
# labels (previously both WCA and both LJ histograms shared one label and no
# legend was drawn, so the curves could not be told apart).
energy_sets = (
    (wca_energy_train_cpu, "Reference WCA train"),
    (wca_energy_sample_cpu, "Reference WCA sample"),
    (lj_energy_train_cpu, "Reference LJ train"),
    (lj_energy_sample_cpu, "Reference LJ sample"),
)
for energies, label in energy_sets:
    ax.hist(energies[::10], bins=40, density=True, alpha=0.5, label=label)
ax.set_xlabel(r"$U(x)$")
ax.set_ylabel(r"$P(U)$")
ax.legend(frameon=False)

plt.savefig(os.path.join(output_dir, "energies.png"))

In [None]:
# Replace the configured target box length by the x-component stored on the LJ
# system; the RDF cutoff for the target system is derived from it.
box_length_target = float(box_length_sys[0])

In [None]:
from tools.observables import rdf

n_bins = 100  # RDF histogram bins (this name is re-bound to the spline bin count further down)
cutoff_pr = box_length_source/2
cutoff_sys = box_length_target/2

# rdf presumably returns (r grid, g(r)); the grid depends on the cutoff, so keep
# one grid per system instead of overwriting a single RDF_r with every call.
RDF_r_pr, RDF_wca_train = rdf(wca_train, n_particles=n_particles, dimensions=dimensions, box_length=box_length_pr, cutoff=cutoff_pr, n_bins=n_bins)
RDF_r_sys, RDF_lj_train = rdf(lj_train, n_particles=n_particles, dimensions=dimensions, box_length=box_length_sys, cutoff=cutoff_sys, n_bins=n_bins)
_, RDF_wca_sample = rdf(wca_sample, n_particles=n_particles, dimensions=dimensions, box_length=box_length_pr, cutoff=cutoff_pr, n_bins=n_bins)
_, RDF_lj_sample = rdf(lj_sample, n_particles=n_particles, dimensions=dimensions, box_length=box_length_sys, cutoff=cutoff_sys, n_bins=n_bins)

# RDF_r keeps its previous meaning (the target-system grid), used when plotting.
RDF_r = RDF_r_sys
if not np.isclose(cutoff_pr, cutoff_sys):
    # With different cutoffs, plotting the WCA curves against RDF_r would silently be wrong.
    print(f"Warning: RDF cutoffs differ (source {cutoff_pr}, target {cutoff_sys}); "
          "plot the WCA curves against RDF_r_pr.")

In [None]:
fig_size = (10 * 0.393701,  7.5 * 0.393701)
fig, ax = plt.subplots(1, 1, figsize = fig_size, dpi = 100)

# All curves are drawn on RDF_r, the target-system grid from the RDF cell.
rdf_curves = (
    (RDF_wca_train, r"WCA train"),
    (RDF_wca_sample, r"WCA sample"),
    (RDF_lj_train, r"LJ train"),
    (RDF_lj_sample, r"LJ sample"),
)
for g_of_r, label in rdf_curves:
    ax.plot(RDF_r, g_of_r, label=label)
ax.set_xlabel(r"$r$")
ax.set_ylabel(r"$g(r)$")
ax.legend(frameon=False)
plt.savefig(os.path.join(output_dir, "rdfs.png"))

In [None]:
# Wrap the WCA system in a dynamic_prior initialised from the WCA training
# configurations and using MCMC_pr as its sampler.
# NOTE: this re-binds `WCA`, so every later use of WCA refers to the wrapper.
# Re-running this cell would wrap the wrapper again; restart from the
# system-definition cell instead.
WCA = dynamic_prior(system=WCA, sampler=MCMC_pr, init_confs=wca_train,
                    n_cached=90000, test_fraction=0.1)

## Two-sided Circular Spline Flow Equivariant Transformer Generator

### Parameters definition

In [None]:
# n_blocks is only consumed by get_targets in the mask cell.
# n_bins is the spline bin count for RQS_coupling_block; note that it re-binds the
# RDF histogram bin count used earlier.
n_blocks, n_bins = 1, 16

### Mask definition

In [None]:
from tools.util import get_targets

# Targets for the mask definition, built from the dimensionality and n_blocks.
targets = get_targets(
    dimensions,
    n_blocks,
)

In [None]:
# Show the mask targets for inspection.
print(str(targets))

### Normalizing-flow (NF) block list

In [None]:
from normalizing_flow.equivariant_transformer import RQS_coupling_block
from normalizing_flow.circular_shift import circular_shift

# Number of repeated flow blocks. Each block is
#   circular_shift, RQS((0,)), RQS((1,)), circular_shift, RQS((1,)), RQS((0,))
# i.e. both orderings of the coordinate coupling pair, each preceded by a shift.
# This is independent of `n_blocks`, which is only passed to get_targets.
# 8 reproduces the previously hand-written list; 16 gives the larger variant
# that used to be commented out.
n_flow_blocks = 8

# Presumably n_particles-1 because the remove_origin transformation drops one
# particle's coordinates — TODO confirm.
n_flow_particles = n_particles - 1

# Layers are constructed in exactly the same order as the explicit list they
# replace, so parameter initialisation consumes the RNG identically.
block_list = []
for _ in range(n_flow_blocks):
    for coupling_order in ((0, 1), (1, 0)):
        block_list.append(circular_shift(n_flow_particles, dimensions, device))
        for axis in coupling_order:
            block_list.append(
                RQS_coupling_block((axis,), n_flow_particles, dimensions, device, n_bins)
            )

### Transformation layer definitions

In [None]:
from transformations.normalization import normalize_box
from transformations.remove_origin import remove_origin

_layer_kwargs = dict(n_particles=n_particles, dimensions=dimensions, device=device)

# One box-normalisation layer per side, since each system has its own box lengths.
norm_box_pr = normalize_box(box_length=box_length_pr, **_layer_kwargs)
norm_box_sys = normalize_box(box_length=box_length_sys, **_layer_kwargs)

# A single origin-removal layer; the same instance is used on both sides of the flow.
rm_origin = remove_origin(**_layer_kwargs)

## Flow definition

In [None]:
from normalizing_flow.flow_assembler import flow_assembler

# Two-sided flow between WCA (prior side) and LJ (posterior side). Each side gets
# its own box normalisation plus the shared origin-removal layer.
flow = flow_assembler(
    prior=WCA,
    posterior=LJ,
    device=device,
    blocks=block_list,
    prior_sided_transformation_layers=[norm_box_pr, rm_origin],
    post_sided_transformation_layers=[norm_box_sys, rm_origin],
).to(device)

n_trainable = sum(param.numel() for param in flow.parameters() if param.requires_grad)
print(f"Flow parameters: {n_trainable}")

# Training the Flow

## Dataset definition

In [None]:
from normalizing_flow.dataset import PBCDataset

# Training data: the LJ training configurations, labelled with their LJ energies.
# test_fraction=0.1 (presumably a held-out split); no shuffling; transform and
# augmentation enabled.
train_dataset = PBCDataset(
    flow,
    data_tensor=lj_train,
    test_fraction=0.1,
    beta_source=beta_source,
    beta_target=beta_target,
    shuffle_data=False,
    transform=True,
    augment=True,
    energy_labels=LJ.energy(lj_train),
)

## Training hyperparameters

### General parameters

In [None]:
n_epochs = 10
batch_size = 512
n_dump = 1   # forwarded to Trainer.training_routine
n_save = 5   # forwarded to Trainer.training_routine

# Floor division: a trailing partial batch is not counted here.
steps_per_epoch = len(train_dataset) // batch_size
total_steps = n_epochs * steps_per_epoch
print(f"Total number of optimization steps: {total_steps}")

### Training directions

In [None]:
# Weights passed to the trainer for the x->z and z->x directions. A direction with
# weight 0 is also skipped in the training-curve plots and the run log.
w_xz, w_zx = 1, 1

### Set optimizer and scheduler

In [None]:
from normalizing_flow.network_trainer import Trainer

flow_trainer = Trainer(flow)

# Adam over the trainable parameters only.
trainable_params = [param for param in flow.parameters() if param.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=1e-4)

# MultiStepLR multiplies the learning rate by gamma (default 0.1) at each milestone,
# counted in scheduler.step() calls. These milestones correspond to epochs 5 and 7
# if the trainer steps the scheduler once per batch — TODO confirm in Trainer.
lr_milestones = [5*steps_per_epoch, 7*steps_per_epoch]
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_milestones)

## Start Training

In [None]:
import datetime

# Wall-clock start time; written to run_details.out after training.
train_start_time = datetime.datetime.now()
print("Training the network:\n")
metrics = flow_trainer.training_routine(
    train_dataset,
    beta_source=beta_source, beta_target=beta_target,
    w_xz=w_xz, w_zx=w_zx,
    batch_size=batch_size, n_epochs=n_epochs,
    n_dump=n_dump, n_save=n_save, save_dir=output_dir,
    optimizer=optimizer, scheduler=scheduler,
)

### Training metrics

In [None]:
fig_size = (50 * 0.393701, 10 * 0.393701)
fig, ax = plt.subplots(1, 3, figsize = fig_size, dpi = 600)

# Column meaning of `metrics`, as inferred from the labels below (TODO confirm
# against Trainer.training_routine): 0 epoch, 1/3 training/validation NLL,
# 2/5 training/validation KLD, 4/6 RESS A->B / B->A.
# Training curves are only drawn for directions that were actually trained.

if w_xz > 0:
    ax[0].plot(metrics[:,0], metrics[:,1], label="training", color="C0")
ax[0].plot(metrics[:,0], metrics[:,3], label="validation", color="C1")
ax[0].set_xlabel("epochs")
ax[0].set_ylabel("NLL loss")
ax[0].legend(frameon=False)

if w_zx > 0:
    ax[1].plot(metrics[:,0], metrics[:,2], label="training", color="C0")
ax[1].plot(metrics[:,0], metrics[:,5], label="validation", color="C1")
ax[1].set_xlabel("epochs")
ax[1].set_ylabel("KLD loss")
ax[1].legend(frameon=False)

ax[2].plot(metrics[:,0], metrics[:,4], label=r"$\text{A}\to \text{B}$", color="C2")
ax[2].plot(metrics[:,0], metrics[:,6], label=r"$\text{B}\to \text{A}$", color="C3")
ax[2].set_xlabel("epochs")
ax[2].set_ylabel("RESS")
ax[2].set_yscale("log")
ax[2].set_ylim(None,1)
ax[2].legend(frameon=False)

plt.savefig(os.path.join(output_dir, "metrics.png"))

In [None]:
# Plain-text provenance record for this run, stored next to the other outputs.
with open(os.path.join(output_dir, "run_details.out"), "w", encoding="utf-8") as f:

    f.write(f"Run ID: {run_id}\n")
    f.write(f"Training started on: {train_start_time}\n")
    f.write(f"Training finished on: {datetime.datetime.now()}\n\n")
    f.write(f"z: {flow.prior.name} -> x: {flow.posterior.name}\n")
    f.write(f"Tz: {T_source} -> Tx: {T_target}\n")
    # The densities define the state points as much as the temperatures do
    # (the run ID only carries them rounded to 2 significant digits).
    f.write(f"rhoz: {rho_source} -> rhox: {rho_target}\n")
    f.write("\n")
    f.write(f"source training data: {wca_train_filepath}\n")
    f.write(f"target training data: {lj_train_filepath}\n")
    f.write(f"source sample data: {wca_sample_filepath}\n")
    f.write(f"target sample data: {lj_sample_filepath}\n")
    f.write("\n")
    f.write(f"elements in source training data: {wca_train.shape[0]}\n")
    f.write(f"elements in target training data: {lj_train.shape[0]}\n")
    f.write(f"elements in source sample data: {wca_sample.shape[0]}\n")
    f.write(f"elements in target sample data: {lj_sample.shape[0]}\n")
    f.write("\n")
    f.write(f"batch size: {batch_size}\n")
    f.write(f"epochs: {n_epochs}\n")
    if w_xz > 0:
        f.write(f"Training x->z: w_xz = {w_xz}\n")
    if w_zx > 0:
        f.write(f"Training z->x: w_zx = {w_zx}\n")
    f.write("\n")
    f.write("Flow architecture:\n")
    f.write(str(flow) + "\n\n")

# Generating configurations with the flow

In [None]:
# Baseline "identity" map: rescale configurations isotropically by `scale`
# (source box -> target box and back) and evaluate the other system's energy.
# Note that WCA here is the dynamic_prior wrapper created earlier.
wca_in_target_box = scale*wca_sample
lj_in_source_box = lj_sample/scale
WCA2LJ_energy_identity = LJ.energy(wca_in_target_box).cpu().numpy()
LJ2WCA_energy_identity = WCA.energy(lj_in_source_box).cpu().numpy()

In [None]:
from tools.util import ress

# No gradients are needed for evaluation.
with torch.no_grad():

    flow.eval()

    # Map every 10th WCA sample configuration through the flow (z -> x).
    z = wca_sample[::10]
    x, logJ_zx = flow.F_zx(z)

    # Energy of the mapped configurations in the target (LJ) system.
    WCA2LJ_energy_transformed = (LJ.energy(x)).cpu().numpy()

    # Log importance weights of the mapped samples w.r.t. the target Boltzmann density:
    #   log w = -beta_B U_B(x) + beta_A U_A(z) + log|det J_zx|   (up to a constant)
    log_prob_zx = -beta_target*flow.posterior.energy(x)
    log_prob_z = -beta_source*flow.prior.energy(z)
    log_w = (log_prob_zx - log_prob_z + logJ_zx).squeeze(-1)
    ress_zx = ress(log_w)

    print(f"RESS zx = {ress_zx}")

    # Importance resampling (with replacement) to approximate unbiased target samples.
    # Shift by the max log-weight for stability, then exponentiate and normalise in
    # float64 so the probabilities given to np.random.choice sum to 1 precisely.
    x_cpu = x.view(-1, n_particles, dimensions).cpu().numpy()
    w = (log_w - torch.max(log_w)).double().exp().cpu().numpy()
    N = x_cpu.shape[0]
    indx = np.random.choice(N, replace=True, size=N, p=w/np.sum(w))
    x_resampled = x_cpu[indx]

In [None]:
with torch.no_grad():

    # Baseline: the rescaling identity map applied to the same WCA subset as the flow.
    id_x = wca_sample[::10]

    # Target energy of the rescaled configurations.
    id_energy_x = flow.posterior.energy(scale*id_x)

    # The rescaling's log-Jacobian is a constant and is omitted; this assumes
    # `ress` is invariant to constant shifts of the log-weights (usual normalised ESS).
    id_log_prob_zx = -beta_target*id_energy_x
    # Evaluate on id_x itself rather than on `z` left over from the previous cell,
    # so this cell no longer depends on hidden kernel state.
    id_log_prob_z = -beta_source*flow.prior.energy(id_x)
    id_log_w = (id_log_prob_zx - id_log_prob_z).squeeze(-1)
    id_ress_zx = ress(id_log_w)

    print(f"id RESS zx = {id_ress_zx}")

In [None]:
fig_size = (20 * 0.393701, 20 * 0.393701)
fig, ax = plt.subplots(2, 2, figsize = fig_size, dpi = 600)

# (panel, configurations, title, box lengths):
# top row  = WCA input data and its image under the flow,
# bottom   = resampled flow output and the LJ reference.
panels = [
    (ax[0][0], wca_sample_cpu[::10], r'$x_{\text{A}} \sim \rho_\text{A}(x_{\text{A}})$', box_length_pr),
    (ax[0][1], x_cpu, r'$x_{\text{B}} = F(x_{\text{A}})$', box_length_sys),
    (ax[1][0], x_resampled, r'$x_{\text{B}} = \bar{F}(x_{\text{A}})$', box_length_sys),
    (ax[1][1], lj_sample_cpu[::10], r'Reference B', box_length_sys),
]
for panel, confs, title, box in panels:
    panel.scatter(confs[:, :, 0], confs[:, :, 1], alpha=0.100, s=0.25)
    panel.set_title(title)
    # Whole box plus a 10% margin, centred on the origin.
    half_x = box[0].item()*(1+0.1)/2
    half_y = box[1].item()*(1+0.1)/2
    panel.set_xlim(-half_x, half_x)
    panel.set_ylim(-half_y, half_y)

plt.savefig(os.path.join(output_dir, "WCA2LJ_confs.png"))

In [None]:
fig_size = (15 * 0.393701, 10 * 0.393701)
fig, ax = plt.subplots(1, figsize = fig_size, dpi = 600)

# Target-system energies: LJ reference vs. rescaled-identity and flow-mapped WCA samples.
target_energy_sets = (
    (lj_energy_sample_cpu, "Reference"),
    (WCA2LJ_energy_identity, "Identity"),
    (WCA2LJ_energy_transformed, "Transformed"),
)
for energies, label in target_energy_sets:
    ax.hist(energies, bins=100, density=True, alpha=0.5, label=label)
ax.set_xlabel(r"$U(x)$")
ax.set_ylabel(r"$P(U)$")
ax.set_title("Energy of Target System")

ax.legend(frameon=False)
plt.savefig(os.path.join(output_dir, "WCA2LJ_ener.png"))

In [None]:
from tools.util import ress

# No gradients are needed for evaluation.
with torch.no_grad():

    flow.eval()

    # Map every 10th LJ sample configuration back through the inverse flow (x -> z).
    x = lj_sample[::10]
    z, logJ_xz = flow.F_xz(x)

    # Energy of the mapped configurations in the source system.
    LJ2WCA_energy_transformed = (WCA.energy(z)).cpu().numpy()

    # Log importance weights of the mapped samples w.r.t. the source Boltzmann density:
    #   log w = -beta_A U_A(z) + beta_B U_B(x) + log|det J_xz|   (up to a constant)
    log_prob_xz = -beta_source*flow.prior.energy(z)
    log_prob_x = -beta_target*flow.posterior.energy(x)
    log_w = (log_prob_xz - log_prob_x + logJ_xz).squeeze(-1)
    ress_xz = ress(log_w)

    print(f"RESS xz = {ress_xz}")

    # Importance resampling (with replacement) to approximate unbiased source samples.
    # Shift by the max log-weight for stability, then exponentiate and normalise in
    # float64 so the probabilities given to np.random.choice sum to 1 precisely.
    z_cpu = z.view(-1, n_particles, dimensions).cpu().numpy()
    w = (log_w - torch.max(log_w)).double().exp().cpu().numpy()
    N = z_cpu.shape[0]
    indx = np.random.choice(N, replace=True, size=N, p=w/np.sum(w))
    z_resampled = z_cpu[indx]

In [None]:
# No gradients are needed for evaluation.
with torch.no_grad():

    flow.eval()

    # Baseline: the rescaling identity map applied to the same LJ subset as the flow.
    id_z = lj_sample[::10]

    # Source energy of the rescaled configurations.
    id_energy_z = flow.prior.energy(id_z/scale)

    # The rescaling's log-Jacobian is a constant and is omitted; this assumes
    # `ress` is invariant to constant shifts of the log-weights (usual normalised ESS).
    id_log_prob_xz = -beta_source*id_energy_z
    # Evaluate on id_z itself rather than on `x` left over from the previous cell,
    # so this cell no longer depends on hidden kernel state.
    id_log_prob_x = -beta_target*flow.posterior.energy(id_z)
    id_log_w = (id_log_prob_xz - id_log_prob_x).squeeze(-1)
    id_ress_xz = ress(id_log_w)

    print(f"id RESS xz = {id_ress_xz}")

In [None]:
fig_size = (20 * 0.393701, 20 * 0.393701)
fig, ax = plt.subplots(2, 2, figsize = fig_size, dpi = 600)

# (panel, configurations, title, box lengths):
# top row  = LJ input data and its image under the inverse flow,
# bottom   = resampled inverse-flow output and the WCA reference.
panels = [
    (ax[0][0], lj_sample_cpu[::10], r'$x_{\text{B}} \sim \rho_\text{B}(x_{\text{B}})$', box_length_sys),
    (ax[0][1], z_cpu, r'$x_{\text{A}} = F^{-1}(x_{\text{B}})$', box_length_pr),
    # The argument of the inverse map is the configuration x_B.
    (ax[1][0], z_resampled, r'$x_{\text{A}} = \bar{F}^{-1}(x_{\text{B}})$', box_length_pr),
    (ax[1][1], wca_sample_cpu[::10], r'Reference A', box_length_pr),
]
for panel, confs, title, box in panels:
    panel.scatter(confs[:, :, 0], confs[:, :, 1], alpha=0.100, s=0.25)
    panel.set_title(title)
    # Whole box plus a 10% margin, centred on the origin: x from box[0], y from box[1]
    # for every panel (the Reference A panel previously took its y-limit from box[0]).
    half_x = box[0].item()*(1+0.1)/2
    half_y = box[1].item()*(1+0.1)/2
    panel.set_xlim(-half_x, half_x)
    panel.set_ylim(-half_y, half_y)

plt.savefig(os.path.join(output_dir, "LJ2WCA_confs.png"))

In [None]:
fig_size = (15 * 0.393701, 10 * 0.393701)
fig, ax = plt.subplots(1, figsize = fig_size, dpi = 600)

# Source-system energies: WCA reference vs. rescaled-identity and inverse-flow-mapped LJ samples.
source_energy_sets = (
    (wca_energy_sample_cpu, "Reference"),
    (LJ2WCA_energy_identity, "Identity"),
    (LJ2WCA_energy_transformed, "Transformed"),
)
for energies, label in source_energy_sets:
    ax.hist(energies, bins=100, density=True, alpha=0.5, label=label)
ax.set_xlabel(r"$U(x)$")
ax.set_ylabel(r"$P(U)$")
ax.set_title("Energy of Source System")

ax.legend(frameon=False)
plt.savefig(os.path.join(output_dir, "LJ2WCA_ener.png"))