In [None]:
# Re-import edited Python modules automatically before each cell runs, so changes to
# locally checked-out packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.insert(0, "/home/ninarell/OneDrive/WF_GAN_FOR_GLASSES/B_GEN/bgflow")
import bgflow
sys.path.insert(0, "/home/ninarell/OneDrive/WF_GAN_FOR_GLASSES/B_GEN/anode")
import anode

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from bgflow.utils import (assert_numpy, length_ppp, 
                          remove_mean, IndexBatchIterator, LossReporter, as_numpy, compute_distances, distance_vectors, distances_from_vectors, length_ppp
)
from bgflow import (GaussianMCMCSampler, DiffEqFlow, BoltzmannGenerator, Energy, Sampler, 
                    MultiDoubleWellPotential, MeanFreeNormalDistribution, KernelDynamics)

from glob import glob

In [None]:
# Alternative data source (currently disabled): configurations read from LAMMPS dump files.
#fnames = glob('/home/ninarell/OneDrive/WF_GAN_FOR_GLASSES/LJ_CRYSTAL/T_0.700_box/dumplin/dump.npt_nose_T1.0_P0.*.lammpstrj')
#coordinates = np.array([np.loadtxt(f, skiprows=9)[:,2:5] for f in fnames])
#coordinates = coordinates.reshape(len(coordinates), dim_ics)

# Seed torch and numpy so the stochastic cells below (sampling, training)
# reproduce on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

temperature = 1.0
side = 5.63  # box half-width (coordinates live in [-side, side]); other values tried: 3.98, 2.52, 2.18, 1.78
n_particles = 10  # was len(coordinates[0]) when loading from file
spacial_dim = 2
dim_ics = n_particles * spacial_dim  # length of one flattened configuration

In [None]:
from  bgflow.distribution.energy import LennardJonesPotentialPPP
from  bgflow.distribution.energy import LennardJonesPotential
# Position of the Lennard-Jones minimum for sigma = 1.
rm = 2 ** (1.0 / 6.0)
# Periodic LJ target energy for n_particles in the flattened dim_ics space.
target = LennardJonesPotentialPPP(
    dim=dim_ics,
    n_particles=n_particles,
    side=side,
    oscillator=False,
    rm=rm,
    two_event_dims=False,
)
# Non-periodic alternative:
#target = LennardJonesPotential(dim = dim_ics, n_particles = n_particles,oscillator = False, rm=rm, two_event_dims=False)

In [None]:
#def plot_energy(coordinates, target):
#    xs = torch.Tensor(coordinates)
#    #xs = xs.view(-1,10,3)
#    energy = target.energy(xs).detach().numpy()
#    x=np.arange(1,len(energy)+1)

#    fig = plt.figure(figsize=(12, 4))
#    plt.subplot(1, 2, 1)
#    plt.plot(x, energy)

#    plt.subplot(1, 2, 2)
#    counts, bins = np.histogram(energy, density=True)
#    plt.yscale("log")
#    plt.stairs(counts, bins)

In [None]:
def plot_energy(coordinates, target):
    """Plot target energies of one or more sample sets: trace vs. sample index and histogram.

    Parameters
    ----------
    coordinates : list
        Sample sets (torch tensors or array-likes) in the layout accepted by
        ``target.energy`` — presumably (n_samples, dim_ics); TODO confirm for other shapes.
    target : object
        Anything with an ``energy(x)`` method returning a tensor.

    Returns
    -------
    matplotlib.figure.Figure
        Left panel: energy per sample; right panel: normalized histogram on a log y-scale.
    """
    energies = []
    for coord in coordinates:
        # Generator samples carry autograd history; detach them so plotting does not
        # keep a graph alive. Array-likes are converted with the default float dtype,
        # matching what torch.Tensor(...) produced.
        if isinstance(coord, torch.Tensor):
            xs = coord.detach().to(torch.get_default_dtype())
        else:
            xs = torch.as_tensor(coord, dtype=torch.get_default_dtype())
        energies.append(target.energy(xs).detach().cpu().numpy().reshape(-1))

    # One figure with two panels (no stray empty figure).
    fig, (ax_trace, ax_hist) = plt.subplots(1, 2, figsize=(12, 4))

    for energy in energies:
        ax_trace.plot(np.arange(1, len(energy) + 1), energy)
    ax_trace.set_xlabel("sample index")
    ax_trace.set_ylabel("energy")

    ax_hist.set_yscale("log")
    for energy in energies:
        counts, bins = np.histogram(energy, density=True)
        ax_hist.stairs(counts, bins)
    ax_hist.set_xlabel("energy")
    ax_hist.set_ylabel("density")

    return fig


In [None]:
def plot_distance_histograms(samples, data, data_prior, n_particles, n_dimensions, d, e, side, log_w,
                             reweight=False):
    """Overlay pairwise-distance histograms of training, prior and BG samples.

    Parameters
    ----------
    samples, data, data_prior :
        Generated, training and prior configurations, passed to ``compute_distances``.
    n_particles, n_dimensions : int
        System size forwarded to ``compute_distances``.
    d, e :
        Reference radial grid / density. Not drawn at the moment; kept for interface compatibility.
    side : float
        Box half-width; the x-axis is limited to ``[0, side*side]``.
    log_w : array-like
        Log importance weights of ``samples``; only used when ``reweight=True``.
    reweight : bool, optional
        If True, additionally draw the importance-reweighted BG distance histogram.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, ax = plt.subplots(figsize=(16, 9))

    distances_x = as_numpy(compute_distances(samples, n_particles, n_dimensions))
    dists_data = as_numpy(compute_distances(data, n_particles, n_dimensions))
    dists_data_prior = as_numpy(compute_distances(data_prior, n_particles, n_dimensions))

    hist_kwargs = dict(bins=50, density=True, histtype="step", linewidth=4)
    #ax.plot(d, e, label="Groundtruth", linewidth=4, alpha=0.9)
    ax.hist(dists_data.reshape(-1), label="training samples", alpha=0.5, **hist_kwargs)
    ax.hist(dists_data_prior.reshape(-1), label="prior samples", alpha=0.5, **hist_kwargs)
    ax.hist(distances_x.reshape(-1), label="bg samples", alpha=0.7, **hist_kwargs)

    if reweight:
        log_w_flat = np.asarray(log_w, dtype=float).reshape(-1)
        # Shift by the max before exponentiating: the normalised weights are unchanged,
        # but np.exp can no longer overflow to inf (which would turn them into NaN).
        w = np.exp(log_w_flat - log_w_flat.max())
        # Every pair distance of a configuration inherits that configuration's weight;
        # assumes distances_x is (n_samples, n_pairs), flattened row by row above.
        pair_w = np.repeat(w / w.sum(), distances_x.shape[1])
        ax.hist(distances_x.reshape(-1), label="reweighted bg samples", alpha=0.7,
                weights=pair_w, **hist_kwargs)

    ax.set_xlim(0, side * side)
    ax.legend(fontsize=35)
    # The histograms are over pairwise distances, not energies.
    ax.set_xlabel("pairwise distance r", fontsize=45)
    ax.tick_params(axis="both", labelsize=45)

    return fig

In [None]:
def plot_scatter_with_limits_and_lines(data, side, limit_factor, n_shown=3, spatial_dim=2):
    """Scatter particle positions on the current axes and outline the box [-side, side]^2.

    Parameters
    ----------
    data : torch.Tensor or array-like
        Flattened configurations, one row per sample, coordinates grouped per particle
        (x0, y0, x1, y1, ...). Tensors are detached before plotting.
    side : float
        Box half-width; a dashed square is drawn at +/- side.
    limit_factor : float
        Axis limits are +/- side * limit_factor.
    n_shown : int, optional
        Number of leading particles to scatter (default 3, as before).
    spatial_dim : int, optional
        Coordinates per particle; only the first two of each particle are plotted.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn on.
    """
    if isinstance(data, torch.Tensor):
        coords = data.detach().cpu().numpy()
    else:
        coords = np.asarray(data)

    ax = plt.gca()
    limit = side * limit_factor
    ax.set_xlim(-limit, limit)
    ax.set_ylim(-limit, limit)

    for particle in range(n_shown):
        x_col = particle * spatial_dim
        ax.scatter(coords[:, x_col], coords[:, x_col + 1])

    # Closed dashed outline of the periodic box.
    box_x = [-side, side, side, -side, -side]
    box_y = [-side, -side, side, side, -side]
    ax.plot(box_x, box_y, linestyle="--", color="black")
    return ax

In [None]:
def plot_scatter_with_limits_and_lines_ppp(data, side, limit_factor):
    """Box scatter plot for periodic data; delegates to plot_scatter_with_limits_and_lines.

    No wrapping is done here. The callers in this notebook wrap the data
    (e.g. with apply_ppp) before passing it in.
    """
    plot_scatter_with_limits_and_lines(data=data, side=side, limit_factor=limit_factor)

In [None]:
def apply_ppp(data, n_parts=None, n_dims=None, box_side=None, max_iter=1000):
    """Centre configurations and wrap them into the periodic box until every |coordinate| <= side.

    Alternates ``length_ppp`` (periodic wrap) and ``remove_mean`` (re-centring), because
    re-centring can push coordinates back outside the box.

    Parameters
    ----------
    data : torch.Tensor
        Configurations accepted by ``remove_mean`` / ``length_ppp``.
    n_parts, n_dims, box_side : optional
        Override the notebook-level ``n_particles``, ``spacial_dim`` and ``side``
        (looked up at call time when left as None).
    max_iter : int, optional
        Maximum number of wrap/centre rounds before giving up.

    Returns
    -------
    torch.Tensor
        The centred, wrapped configurations (a new tensor; assign the result).

    Raises
    ------
    RuntimeError
        If coordinates are still outside the box after ``max_iter`` rounds.
    """
    n_parts = n_particles if n_parts is None else n_parts
    n_dims = spacial_dim if n_dims is None else n_dims
    box_side = side if box_side is None else box_side

    data = remove_mean(data, n_parts, n_dims)
    n_rounds = 0
    while torch.any(data.abs() > box_side):
        if n_rounds >= max_iter:
            raise RuntimeError(
                f"apply_ppp: coordinates still outside [-{box_side}, {box_side}] after {max_iter} rounds"
            )
        data = length_ppp(data, box_side)
        data = remove_mean(data, n_parts, n_dims)
        n_rounds += 1
    return data

In [None]:
def learning_rate(lr, epoch, it, tau, xi_0=0.9, xi_f=0.1):
    """Learning-rate schedule: exponential decay across epochs, cosine annealing within an epoch.

        lr_t = lr * xi_0**epoch * ((1 - xi_f) * (1 + cos(pi * it / tau)) / 2 + xi_f)

    Parameters
    ----------
    lr : float
        Base learning rate.
    epoch : int
        Epoch index; the amplitude is multiplied by ``xi_0`` once per epoch.
    it : int
        Iteration index within the epoch.
    tau : float
        Cosine half-period; at ``it == tau`` the factor reaches its floor ``xi_f``. Must be non-zero.
    xi_0 : float, optional
        Per-epoch decay factor (default 0.9).
    xi_f : float, optional
        Floor of the within-epoch factor, as a fraction of the epoch amplitude (default 0.1).

    Returns
    -------
    float
    """
    amplitude = lr * np.power(xi_0, epoch)
    damping = 1 + np.cos(np.pi * it / tau)  # 2 at it = 0, 0 at it = tau
    return amplitude * ((1 - xi_f) * damping / 2 + xi_f)

In [None]:
# define a MCMC sampler to sample from the target energy
#Box constraint
def constraint(x):
    """MCMC box constraint: periodically wrap ``x`` with length_ppp using the notebook's ``side``."""
    wrapped = length_ppp(x, side)
    return wrapped

#init_state = torch.Tensor([-0.5,-0.5,-0.5,0.5,0.5,-0.5,0.5,0.5])
#init_state = torch.Tensor([-0.5,0, 0.5,0, 0, 0.5, 0, -0.5])
#init_state = torch.Tensor([-0.5,0,  0.5,0, 0,0.866, 0,-0.866]) #, 0, 0.5, 0, -0.5])
# Hand-placed starting configuration: 10 particles given as consecutive (x, y) pairs.
init_state = torch.Tensor([-0.5,0, 0.5,0, 0,0.866, 0,-0.866, -1.5,0, 1.5,0, 0,1.866, 0,-1.866, 1.866,1.866, -1.866,-1.866])
# The commented-out states above have 4 particles; fail early if the start state
# does not match the system size set in the config cell.
assert init_state.numel() == dim_ics, (
    f"init_state has {init_state.numel()} coordinates, expected dim_ics={dim_ics}"
)
mc_step = 0.2  # passed to the sampler as noise_std
mcsampler = GaussianMCMCSampler(target, init_state=init_state, temperature=temperature, box_constraint=constraint, noise_std=mc_step)
# Draw and discard 10000 samples; presumably this advances the chain as burn-in — confirm in bgflow.
void = mcsampler.sample(10000)

In [None]:
n_data = 4096  # other sizes tried: 512, 2048, 8192, 16384
data = mcsampler.sample(n_data)
# apply_ppp rebinds its local variable and returns the centred / wrapped tensor,
# so the result must be assigned back.
data = apply_ppp(data)
data = data.reshape(-1, dim_ics)

In [None]:
# Energy trace and histogram of the MCMC training data.
fig = plot_energy([data], target)

In [None]:
# Positions of the first particles of the training samples, with the box outline.
fig = plot_scatter_with_limits_and_lines(data, side, 2)

In [None]:
### now set up a prior

from bgflow import NormalDistribution, TruncatedNormalDistribution, MeanFreeNormalDistribution, CircularNormalDistribution

# Gaussian prior over the dim_ics = n_particles * spacial_dim coordinates.
# NOTE(review): "MeanFree" presumably means zero centre of mass (matching the
# translation-invariant target); confirm in the bgflow docs.
prior =  MeanFreeNormalDistribution(dim_ics, n_particles, std=1.,two_event_dims=False) #.cuda()

In [None]:
# Draw 1000 prior samples and compare their target energies with the training data.
data_prior = prior.sample(1000, temperature=temperature)
#data_prior = prior.sample(1000)
#data_prior = data_prior.view(-1, 10, 3)
plot_energy([data_prior,data], target)

In [None]:
# Close all open figures to free memory. The plotting helpers called above are not
# guaranteed to return a Figure, so `fig` may be None here, and plt.close(None)
# would only close whichever figure happens to be current.
plt.close("all")

In [None]:
# Positions of the first particles of the prior samples, with the box outline.
plot_scatter_with_limits_and_lines(data_prior, side, 2)

In [None]:
# set up the equivariant kernel dynamics

n_dimensions = spacial_dim
d_max = 8
# Kernel centres over [0, d_max] (presumably in pairwise distance). torch.linspace is
# already ascending, so no sort is needed; note that Tensor.sort() is not in-place.
mus = torch.linspace(0, d_max, 50) #.cuda()
gammas = 0.3 * torch.ones(len(mus)) #.cuda()

# Kernel centres over the flow time interval [0, 1].
mus_time = torch.linspace(0, 1, 10) #.cuda()
gammas_time = 0.3 * torch.ones(len(mus_time)) #.cuda()


kdyn = KernelDynamics(n_particles, n_dimensions, mus, gammas, optimize_d_gammas=True, optimize_t_gammas=True,
                      mus_time=mus_time, gammas_time=gammas_time, periodic = True, side = side) #.cuda()


In [None]:
# Flow whose dynamics are the kernel dynamics defined above (ODE-based, per bgflow's DiffEqFlow).
flow = DiffEqFlow(dynamics = kdyn)

In [None]:
# having a flow and a prior, we can now define a Boltzmann Generator
# (the target energy is what bg.kldiv and the importance log-weights are computed against)

bg = BoltzmannGenerator(prior, flow, target) #.cuda()

In [None]:
# Draw samples from the still-untrained generator together with their latent
# points and log-Jacobians, then compute the log importance weights.
n_samples = 2000
samples, latent, dlogp = bg.sample(n_samples, with_latent=True, with_dlogp=True, temperature=temperature)
log_w = as_numpy(bg.log_weights_given_latent(samples, latent, dlogp, temperature=temperature))

In [None]:
def lennard_jones_energy_torch(r, eps=1.0, rm=rm):
    """Pair Lennard-Jones energy in the (rm, eps) form.

    u(r) = eps * ((rm / r)**12 - 2 * (rm / r)**6), which takes its minimum -eps at r = rm.
    The default ``rm`` is the notebook-level ``rm`` captured when this cell runs.
    """
    ratio = rm / r
    return eps * (ratio ** 12 - 2 * ratio ** 6)
# Reference radial density of a single LJ pair at `temperature`:
#   p(r) ∝ exp(-u(r) / T) * r^(D - 1),  with D = dim_ics // n_particles
d = torch.linspace(0, 5, 1000).view(-1, 1) + 1e-6
u = torch.exp(-lennard_jones_energy_torch(d) / temperature) * d.abs() ** (dim_ics // n_particles - 1)
# Riemann-sum normalisation with the actual grid spacing (linspace has len(d) - 1 intervals).
dx = d[1, 0] - d[0, 0]
Z = (u * dx).sum()
e = u / Z


In [None]:
#fig = plot_distance_histograms(apply_ppp(samples), data, data_prior, n_particles, n_dimensions, d, e, side, log_w)

In [None]:
# use DTO in the training process
flow._use_checkpoints = True

# Anode options
options={
    "Nt": 3,
    "method": "RK4"
}
flow._kwargs = options

In [None]:
# initial training with likelihood maximization on data set

n_kl_samples = 64
n_batch = 64
batch_iter = IndexBatchIterator(len(data), n_batch)

lr = 8e-3
# Cosine half-period of the learning-rate schedule = number of mini-batches per epoch,
# so the within-epoch annealing approaches its floor at the end of every epoch.
# Taken from the iterator so it stays right if len(data) changes or n_batch does not divide it.
tau = len(batch_iter)
optim = torch.optim.Adam(bg.parameters(), lr=lr, weight_decay=lr/50)

n_epochs = 2
n_report_steps = 4

# mixing parameter: weight of the NLL term per epoch; (1 - lambda) weights the KL term
lambdas = torch.linspace(1., 0.1, n_epochs) #.cuda()

reporter = LossReporter("NLL", "KLL")

In [None]:
for epoch, lamb in enumerate(lambdas):
    # lamb weights the NLL gradient, (1 - lamb) the KL gradient
    #(1 - np.power(-float(epoch)/float(n_epochs+1), 4))

    for it, idxs in enumerate(batch_iter):
        batch = data[idxs] #.cuda()

        # Evaluate the schedule once per iteration; weight decay tracks lr / 50,
        # as when the optimizer was created.
        current_lr = learning_rate(lr, epoch, it, tau)   #lr * np.exp(-10*float(epoch)/float(n_epochs))
        for g in optim.param_groups:
            g['lr'] = current_lr
            g['weight_decay'] = current_lr / 50

        optim.zero_grad()

        # negative log-likelihood of the batch is equal to the energy of the BG
        nll = bg.energy(batch, temperature=temperature).mean()
        # aggregate weighted gradient
        (lamb * nll).backward()

        # kl divergence to the target
        kll = bg.kldiv(n_kl_samples, temperature=temperature).mean()

        # aggregate weighted gradient
        ((1. - lamb) * kll).backward()

        reporter.report(nll, kll)

        optim.step()

        if it % n_report_steps == 0:
            # lamb is a 0-d tensor; convert it so the progress line shows a plain number
            print("\repoch: {0}, iter: {1}/{2}, lambda: {3:.3f}, NLL: {4:.4}, KLL: {5:.4}".format(
                    epoch,
                    it,
                    len(batch_iter),
                    float(lamb),
                    *reporter.recent(1).ravel()
                ), end="")

        # (disabled) per-iteration snapshot of energies / distances of resampled BG samples
        #n_samples = 2000
        #samples, latent, dlogp = bg.sample(n_samples, with_latent=True, with_dlogp=True, temperature=temperature)
        #log_w = as_numpy(bg.log_weights_given_latent(samples, latent, dlogp))
        #repeat_counts = (len(samples)* np.exp(log_w)/np.sum(np.exp(log_w))).astype(int)
        #samples = samples.view(-1,n_particles, n_dimensions)
        #replicated_samples = np.repeat(samples.detach().cpu().numpy(), repeat_counts, axis=0)
        #replicated_samples = replicated_samples.reshape((replicated_samples.shape[0], -1))
        #fig = plot_energy([replicated_samples,data], target)
        #fig = plot_distance_histograms(samples, data, data_prior, n_particles, n_dimensions, d, e, side, log_w)
        #filename = '/home/ninarell/Desktop/FIG_ENERGY/fig_'+str(epoch)+".png"
        #fig.savefig(filename, dpi=fig.dpi, format='png')
        # plt.close(fig)

In [None]:
# Plot the NLL / KLL values recorded by the reporter during training.
reporter.plot()

In [None]:
# Large sample from the trained generator plus log importance weights, for evaluation.
n_samples = 20000
samples, latent, dlogp = bg.sample(n_samples, with_latent=True, with_dlogp=True, temperature=temperature)
log_w = as_numpy(bg.log_weights_given_latent(samples, latent, dlogp, temperature=temperature))

In [None]:
# Resampling counts proportional to the normalised importance weights.
# Shifting log_w by its max leaves the normalised weights unchanged but keeps
# np.exp from overflowing to inf (inf / inf -> NaN -> garbage integer counts).
w_rel = np.exp(log_w - np.max(log_w))
repeat_counts = (len(samples) * w_rel / np.sum(w_rel)).astype(int)
samples = samples.view(-1, n_particles, n_dimensions)

In [None]:
# Percentage of generated samples kept by resampling (repeat count > 0).
np.where(repeat_counts>0)[0].size/n_samples*100

In [None]:
# Number of distinct generated samples kept by resampling.
np.where(repeat_counts>0)[0].size

In [None]:
# Kish effective sample size as a percentage of n_samples:
#   ESS = (sum_i w_i)^2 / sum_i w_i^2,   w_i = exp(log_w_i)
# (E[w^2] / E[w]^2 / n_samples equals 1 / ESS, not the ESS fraction.)
# ESS is invariant to a common scale of the weights, so log_w is shifted by its
# max first to keep np.exp from overflowing.
w_rel = np.exp(log_w - np.max(log_w))
np.sum(w_rel) ** 2 / np.sum(w_rel ** 2) / n_samples * 100

In [None]:
# Importance-resampled set: each generated sample repeated repeat_counts times,
# then flattened back to one row per configuration.
replicated_samples = np.repeat(samples.detach().cpu().numpy(), repeat_counts, axis=0)
replicated_samples = replicated_samples.reshape(len(replicated_samples), -1)

In [None]:
# Pairwise-distance histograms: training vs prior vs (centred, box-wrapped) BG samples.
fig = plot_distance_histograms(apply_ppp(samples), data, data_prior, n_particles, n_dimensions, d, e, side, log_w)

In [None]:
# Positions of raw (not resampled) BG samples after centring/wrapping into the box.
#plot_scatter_with_limits_and_lines_ppp(length_ppp(torch.Tensor(replicated_samples),side), side, 2)
#plot_scatter_with_limits_and_lines_ppp(apply_ppp(torch.Tensor(replicated_samples)), side, 2)
plot_scatter_with_limits_and_lines_ppp(apply_ppp(torch.Tensor(samples.view(-1, dim_ics).detach().numpy())), side, 2)
#plot_scatter_with_limits_and_lines_ppp(torch.Tensor(replicated_samples), side, 2)

In [None]:
# Same view for the importance-resampled samples.
plot_scatter_with_limits_and_lines_ppp(apply_ppp(torch.Tensor(replicated_samples)), side, 2)

In [None]:
# Energies of raw BG samples vs training data.
plot_energy([samples.view(-1, dim_ics),data], target)

In [None]:
# Energies of importance-resampled BG samples vs training data.
plot_energy([replicated_samples,data], target)