In [None]:
import numpy as np
import torch
import torch.nn as nn
import FrEIA.framework as Ff
import FrEIA.modules as Fm
from matplotlib import pyplot as pp
import math
from tqdm import tqdm_notebook
%matplotlib inline

In [None]:
# Runtime configuration.

# Seed both random number generators so that Restart & Run All reproduces the
# same Monte Carlo trajectories, mini-batches and latent samples.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# `use_cuda` is the user's preference (set it to False to force the CPU).
# The GPU is only selected when CUDA is actually available, so the notebook
# still runs on CPU-only machines instead of failing at the first `.to(device)`.
use_cuda = True
device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")

In [None]:
class DoubleWell():
    """A simple double well potential in 2-D.

    The first coordinate feels the quartic double well
    ``a4 * x**4 - a2 * x**2 + a1 * x``; the second coordinate feels a
    harmonic oscillator ``0.5 * k * y**2``.

    ``energy`` and ``forces`` accept either a single point (any sequence of
    length 2) or a batch of points of shape ``(..., 2)``; the coordinate
    axis is always the last one.
    """
    def __init__(self, a1=1.0, a2=6.0, a4=1.0, k=1.0):
        self.a1 = a1
        self.a2 = a2
        self.a4 = a4
        self.k = k

    def energy(self, x):
        """Return the potential energy at ``x``.

        Parameters
        ----------
        x: array-like, shape (..., 2)
            One point or a batch of points.

        Returns
        -------
        A scalar for a single point, otherwise an array of shape ``x.shape[:-1]``.
        """
        x = np.asarray(x)
        x0, x1 = x[..., 0], x[..., 1]
        dimer_energy = (self.a4 * x0**4 -
                        self.a2 * x0**2 +
                        self.a1 * x0)
        oscillator_energy = 0.5 * self.k * x1**2
        return dimer_energy + oscillator_energy

    def forces(self, x):
        """Return the gradient of the energy with respect to ``x``.

        Note that, despite the name, this is ``+dE/dx`` (the derivative of
        the expression in ``energy``), i.e. the negative of the physical force.

        Parameters
        ----------
        x: array-like, shape (..., 2)
            One point or a batch of points.

        Returns
        -------
        np.ndarray with the same shape as ``x``.
        """
        x = np.asarray(x)
        x0, x1 = x[..., 0], x[..., 1]
        dimer_force = (4 * self.a4 * x0**3 -
                       2 * self.a2 * x0 +
                       self.a1)
        oscillator_force = self.k * x1
        # Stack along the last axis so single points give shape (2,) and
        # batches keep their leading dimensions.
        return np.stack([dimer_force, oscillator_force], axis=-1)

    def get_dimer_energies(self):
        """Evaluate the energy along the first coordinate.

        The second coordinate is held at 0. Nothing is plotted here; the
        caller is expected to plot the returned arrays.

        Returns
        -------
        x_grid: np.ndarray, shape (200,)
            Evenly spaced values of the first coordinate in [-3, 3].
        energies: np.ndarray, shape (200,)
            The energy at ``(x, 0)`` for each grid value.
        """
        x_grid = np.linspace(-3, 3, num=200)
        points = np.stack([x_grid, np.zeros_like(x_grid)], axis=-1)
        return x_grid, self.energy(points)

In [None]:
# Energy profile of the double well along x (y held at 0).
profile = DoubleWell().get_dimer_energies()
x_grid, energies = profile
pp.plot(x_grid, energies);

In [None]:
class Metropolis():
    def __init__(self, model, x0, noise=0.1, burnin=0, stride=1):
        """Metropolis Monte Carlo sampler with Gaussian proposals.

        Parameters
        ----------
        model: Energy model
            Any object exposing ``energy(x)``, which returns the
            energy of configuration ``x``.
        x0: np.array
            Configuration the walk starts from.
        noise: float
            Scale applied to the standard-normal proposal displacement.
        burnin: int
            Steps with index <= burnin are never recorded.
        stride: int
            Only steps whose index is a multiple of stride are recorded.

        """
        # Mutable sampler state; reset() below gives it its real values.
        self.x = None
        self.energy = None
        self.step = None
        self.traj_ = None
        self.etraj_ = None
        # Fixed settings.
        self.model, self.noise = model, noise
        self.burnin, self.stride = burnin, stride
        self.reset(x0)

    def reset(self, x0):
        """Restart the sampler from ``x0``.

        The step counter goes back to zero and all recorded
        configurations and energies are dropped.

        Parameters
        ----------
        x0: np.ndarray
            Configuration to restart from.

        """
        self.x = x0
        self.energy = self.model.energy(x0)
        self.step = 0
        self.traj_, self.etraj_ = [], []

    def run(self, n_steps):
        """Advance the walk by ``n_steps`` trial moves, recording as configured."""
        for _ in range(n_steps):
            self.step += 1
            self._run_trial()
            self._log_results()

    def _log_results(self):
        """Record the current state if past burn-in and on a stride boundary."""
        # Short-circuit keeps the stride test from running during burn-in.
        if self.step <= self.burnin or self.step % self.stride != 0:
            return
        self.traj_.append(self.x)
        self.etraj_.append(self.energy)

    def _run_trial(self):
        """Propose one Gaussian move and accept it with the Metropolis rule."""
        n_dim = self.x.shape[0]
        candidate = self.x + self.noise * np.random.randn(n_dim)
        candidate_energy = self.model.energy(candidate)
        # Downhill moves are always taken; the uniform random number is only
        # drawn for uphill moves, which are accepted with probability
        # exp(-(E_new - E_old)).
        if (candidate_energy <= self.energy
                or np.random.rand() < math.exp(self.energy - candidate_energy)):
            self.x, self.energy = candidate, candidate_energy

    @property
    def traj(self):
        """Recorded configurations as a new array."""
        return np.array(self.traj_)

    @property
    def etraj(self):
        """Recorded energies as a new array."""
        return np.array(self.etraj_)

In [None]:
# Settings shared by both Metropolis runs.
ener_model = DoubleWell()
nsteps = 20_000
x0_left = np.array([-1.7, 0.0])
x0_right = np.array([1.7, 0.0])

# First run: start in the left well and record every second step.
sampler = Metropolis(model=ener_model, x0=x0_left, stride=2)
sampler.run(n_steps=nsteps)
traj_left = sampler.traj

# Second run: reuse the same sampler, restarted from the right well.
sampler.reset(x0=x0_right)
sampler.run(n_steps=nsteps)
traj_right = sampler.traj

In [None]:
# Time series of the x-coordinate for each run (left-well start first).
for run in (traj_left, traj_right):
    pp.plot(run[:, 0])

In [None]:
# Stack both runs into a single array and shuffle its rows in place.
combined_traj = np.concatenate((traj_left, traj_right), axis=0)
np.random.shuffle(combined_traj)

# Hold out the first tenth of the shuffled frames for validation;
# everything after that is used for training.
n_validation = len(combined_traj) // 10
traj_valid, traj_train = combined_traj[:n_validation], combined_traj[n_validation:]

Here's a scatter plot of the combined trajectory.

In [None]:
# The transpose unpacks into the x column and the y column.
pp.scatter(*combined_traj.T, marker=".", alpha=0.1);

## Invertible Network

FrEIA needs a function to create the non-linear layers for the invertible modules. This function takes
the number of inputs and outputs as parameters and should construct a network.

In this case, we use three Linear–ReLU layers with a configurable number of hidden units (64 in the network built below), followed by a final linear layer whose weights and biases start at zero.

In [None]:
class CreateFC:
    """Factory for the fully connected sub-networks used inside the coupling blocks.

    Calling an instance with ``(c_in, c_out)`` returns a fresh
    ``nn.Sequential`` of three Linear+ReLU stages of width ``n_hidden``
    followed by a Linear output layer.
    """
    def __init__(self, n_hidden):
        self.n_hidden = n_hidden

    def __call__(self, c_in, c_out):
        widths = [c_in, self.n_hidden, self.n_hidden, self.n_hidden, c_out]
        # Build all four linear layers first, then initialise them, so the
        # random number stream is consumed in a fixed order.
        layers = [nn.Linear(n_in, n_out)
                  for n_in, n_out in zip(widths[:-1], widths[1:])]
        hidden, head = layers[:-1], layers[-1]

        # Kaiming initialisation suits the ReLU activations that follow
        # each hidden layer (biases keep their default initialisation).
        for layer in hidden:
            torch.nn.init.kaiming_uniform_(layer.weight)

        # A zeroed output layer makes the sub-network output exactly zero,
        # which gives the identity transform as our starting point.
        torch.nn.init.zeros_(head.weight)
        torch.nn.init.zeros_(head.bias)

        modules = []
        for layer in hidden:
            modules.extend([layer, nn.ReLU()])
        modules.append(head)
        return nn.Sequential(*modules)

Now we build our invertible network. This is very flexible. We currently use 8 GLOW coupling layers, each followed by a random permutation of the two coordinates.

In [None]:
# Create our reversible network: an input node, then n_glow_layers
# repetitions of (GLOW coupling block -> random permutation), then an
# output node.
n_glow_layers = 8
nodes = [Ff.InputNode(2, name="input")]

for i in range(n_glow_layers):
    coupling_args = {"subnet_constructor": CreateFC(64), "clamp": 2}
    node1 = Ff.Node(nodes[-1], Fm.GLOWCouplingBlock, coupling_args, name=f"glow_{i}")
    # Each permutation gets its own seed.
    node2 = Ff.Node(node1, Fm.PermuteRandom, {"seed": i}, name=f"permute_{i}")
    nodes += [node1, node2]

nodes += [Ff.OutputNode(nodes[-1], name="output")]
net = Ff.ReversibleGraphNet(nodes, verbose=False)

## Training

In [None]:
# Put the network and both data splits on the selected device;
# the data is stored as float32 tensors.
net = net.to(device)
traj_train, traj_valid = (
    torch.as_tensor(split, dtype=torch.float32, device=device)
    for split in (traj_train, traj_valid)
)

In [None]:
# Maximum-likelihood training ("training by example"): push recorded
# configurations through the network and make their latent images look like
# samples from a standard normal distribution.
n_epoch = 2000  # Number of optimisation steps (one random batch each). This is more than necessary
n_batch = 256  # This is the number of data points per batch

I = np.arange(traj_train.shape[0])  # A list of indices into the training set

# Build the optimizer.
learning_rate = 1e-4
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
# The scheduler is stepped on the validation loss (see below), i.e. once
# every 10 optimisation steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, verbose=True)

# We'll keep track of training and validation losses.
# Both lists get one entry every 10 optimisation steps.
losses = []
val_losses = []


with tqdm_notebook(range(n_epoch)) as t:
    for epoch in t:
        net.train()  # Tell pytorch that we want to train the network.

        # choose a random batch of samples (drawn with replacement)
        index_batch = np.random.choice(I, n_batch, replace=True)
        x_batch = traj_train[index_batch, :]

        # pass it through the network
        z = net(x_batch)

        # compute the loss: Gaussian penalty on the latent coordinates minus
        # half the mean log-Jacobian.
        # NOTE(review): log_jacobian(run_forward=False) presumably reuses the
        # Jacobian of the forward pass just above, so it must stay directly
        # after the net(...) call — confirm against the FrEIA version in use.
        loss = 0.5 * torch.mean(z**2) - torch.mean(net.log_jacobian(run_forward=False)) / 2.0

        # take a gradient step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Every 10th step: evaluate on the held-out set, log both losses and
        # let the scheduler react to the validation loss.
        if epoch % 10 == 0:
            net.eval()
            with torch.no_grad():  # No gradients because we're validating.
                z_val = net(traj_valid)
                loss_val = 0.5 * torch.mean(z_val**2) - torch.mean(net.log_jacobian(run_forward=False)) / 2
            losses.append(loss.item())
            val_losses.append(loss_val.item())
            scheduler.step(loss_val.item())
            t.set_postfix(loss=loss.item(), val_loss=loss_val.item())

## Evaluation

The training and validation losses have both plateaued. The validation loss is about the same as the training loss, indicating that we're not over-fitting.

In [None]:
# Training loss first, validation loss second.
for history in (losses, val_losses):
    pp.plot(history)

Generate a set of structures from the network and plot them. These random latent samples map to structures that are similar to the input (see above). The one difference is that there is a band of structures connecting the two wells. I think this comes from the region where the two meta-stable regions butt against each other (see below). Since we're training by example, what matters is that examples map to high-probability regions of the Gaussian. However, this says nothing about the reverse—that is, where high-probability latent samples map to in structure-space.

In [None]:
# Draw one latent sample per trajectory frame and map it back to
# configuration space with the reverse pass of the network.
n_samples = combined_traj.shape[0]
with torch.no_grad():
    z_output = torch.normal(0, 1, (n_samples, 2), device=device)
    x_output = net(z_output, rev=True).cpu()
pp.scatter(x_output[:, 0], x_output[:, 1], marker=".", alpha=0.1);

Show how the training data maps into the latent space.

In [None]:
# Map each training trajectory into latent space. These tensors are only
# used for plotting, so run the forward passes under no_grad: otherwise the
# autograd graph for every frame stays alive for the rest of the session.
with torch.no_grad():
    z_left = net(torch.as_tensor(traj_left, dtype=torch.float32, device=device))
    z_right = net(torch.as_tensor(traj_right, dtype=torch.float32, device=device))

In [None]:
def plot_trans(rev, lim=6, n_grid=51):
    """Draw how the network deforms a regular square grid.

    Straight vertical and horizontal lines spanning ``[-lim, lim]`` are
    pushed through the module-level ``net`` and the resulting curves are
    drawn in grey on a new 9x9 inch figure. The view is fixed to
    ``[-3, 3]`` on both axes.

    Side effect: ``net`` is switched to evaluation mode.

    Parameters
    ----------
    rev: bool
        Passed through as ``net(points, rev=rev)``; selects the direction
        in which the grid is transformed.
    lim: float
        Half-width of the square grid before transformation.
    n_grid: int
        Number of grid lines drawn in each direction.
    """
    net.eval()
    n_dens = 5000  # number of points sampled along each grid line
    pp.figure(figsize=(9, 9))

    # One set of grid-line positions and one set of along-the-line sample
    # positions serve both directions, since the grid is square.
    coarse = torch.linspace(-lim, lim, n_grid, device=device)
    fine = torch.linspace(-lim, lim, n_dens, device=device)

    def _draw_line(xs, ys):
        # Transform one grid line and plot its image.
        points = torch.zeros((n_dens, 2), device=device)
        points[:, 0] = xs
        points[:, 1] = ys
        with torch.no_grad():
            trans = net(points, rev=rev)
        trans = trans.cpu().detach().numpy()
        pp.plot(trans[:, 0], trans[:, 1], color="grey", linewidth=0.5)

    # Lines of constant x first, then lines of constant y.
    for x in coarse:
        _draw_line(x, fine)
    for y in coarse:
        _draw_line(fine, y)

    pp.xlim(-3, 3)
    pp.ylim(-3, 3)

In [None]:
# Reverse-transformed grid with both training trajectories on top.
plot_trans(rev=True, n_grid=31)
for run in (traj_left, traj_right):
    pp.scatter(run[:, 0], run[:, 1], marker='.', alpha=0.2)

In [None]:
# Forward-transformed grid with the latent images of both trajectories on top.
plot_trans(rev=False, n_grid=51)
z_left_cpu = z_left.detach().cpu().numpy()
z_right_cpu = z_right.detach().cpu().numpy()
for latent in (z_left_cpu, z_right_cpu):
    pp.scatter(latent[:, 0], latent[:, 1], marker='.', alpha=0.1)

In [None]:
# Generate 50 samples per trajectory frame from the network.
n_samples = combined_traj.shape[0] * 50
with torch.no_grad():
    z_output = torch.normal(0, 1, (n_samples, 2), device=device)
    x_output = net(z_output, rev=True)

z_output = z_output.detach().cpu().numpy()
x_output = x_output.detach().cpu().numpy()

# Histogram the x-coordinate, turn the bin edges into bin centres and drop
# empty bins (their log would be -inf).
counts, edges = np.histogram(x_output[:, 0], bins=50, range=(-3, 3))
edges = (edges[1:] + edges[:-1]) / 2
keep_ind = np.nonzero(counts > 0)
edges, counts = edges[keep_ind], counts[keep_ind]
free_energy = -np.log(counts)

# Reference: the analytical dimer energy.
x_grid, energies = ener_model.get_dimer_energies()
pp.plot(x_grid, energies);

# Shift the estimate so that its minimum coincides with the reference minimum.
offset = np.min(energies) - np.min(free_energy)
pp.scatter(edges, free_energy + offset, color="C1")

pp.ylim(-12, 5);

## KL Training

In [None]:
class GetEnergy(torch.autograd.Function):
    """Autograd bridge to an energy model that works on numpy arrays.

    The forward pass calls ``energy_model.energy(x)`` once per row of the
    input and returns the energies as a ``(n_batch, 1)`` tensor. The
    backward pass calls ``energy_model.forces(x)`` once per row and treats
    the result as the gradient of the energy with respect to ``x``.

    The output and the returned gradient live on the same device and have
    the same dtype as the input tensor.
    """

    @staticmethod
    def forward(ctx, input, energy_model):
        """Evaluate the energy of every row of ``input``.

        Parameters
        ----------
        input: torch.Tensor, shape (n_batch, n_dim)
            Configurations to evaluate.
        energy_model: object
            Must provide ``energy(x)`` and ``forces(x)`` for a single
            numpy configuration ``x``.

        Returns
        -------
        torch.Tensor, shape (n_batch, 1)
        """
        ctx.save_for_backward(input)
        ctx.energy_model = energy_model

        # The energy model works on numpy arrays, so evaluate on the cpu.
        coords = input.detach().cpu().numpy()
        energies = np.empty((coords.shape[0], 1), dtype=coords.dtype)
        for i, x in enumerate(coords):
            energies[i, :] = energy_model.energy(x)

        # Return on the caller's device rather than a module-level one.
        return torch.from_numpy(energies).to(device=input.device)

    @staticmethod
    def backward(ctx, grad_output):
        """Chain ``grad_output`` with the per-sample energy gradient.

        Returns the gradient with respect to ``input`` and ``None`` for the
        (non-tensor) ``energy_model`` argument.
        """
        input, = ctx.saved_tensors
        energy_model = ctx.energy_model

        coords = input.detach().cpu().numpy()
        # Pre-allocated so that an empty batch yields an empty gradient.
        grads = np.zeros(coords.shape, dtype=coords.dtype)
        for i, x in enumerate(coords):
            # `forces` is used as dE/dx here (no sign flip is applied).
            grads[i, :] = energy_model.forces(x)
        grad_energy = torch.from_numpy(grads).to(device=input.device)

        # grad_output has the forward output's shape (n_batch, 1), so it
        # broadcasts over the coordinate axis.
        grad_input = grad_output.to(device=input.device, dtype=input.dtype) * grad_energy
        return grad_input, None

get_energy = GetEnergy.apply

In [None]:
# Combined training: the example-based (maximum-likelihood) loss on recorded
# configurations plus an energy-based (KL) loss on configurations generated
# from latent samples.
n_epoch = 1000  # number of optimisation steps (one batch of each kind per step)
n_batch = 1024  # number of samples per batch

# Build the optimizer.
learning_rate = 1e-4
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, verbose=True)


# We'll keep track of the combined training loss and the KL validation loss.
losses = []
val_losses = []

# Create a validation set: a fixed batch of latent samples.
z_val = torch.normal(0, 1, size=(n_batch, 2), device=device)

with tqdm_notebook(range(n_epoch)) as t:
    for epoch in t:
        net.train()  # Tell pytorch that we want to train the network.
        # choose a random batch of training configurations (with replacement);
        # indices are drawn directly from the training-set size so this cell
        # does not rely on an index array left over from an earlier cell.
        index_batch = np.random.choice(traj_train.shape[0], n_batch, replace=True)
        x_batch = traj_train[index_batch, :]

        # pass it through the network
        z = net(x_batch)

        # compute the example-based loss
        loss1 = 0.5 * torch.mean(z**2) - torch.mean(net.log_jacobian(run_forward=False)) / 2

        # generate another batch of samples, this time in latent space
        z_batch = torch.normal(0, 1, size=(n_batch, 2), device=device)

        # run it back through the network
        x = net(z_batch, rev=True)

        # compute the energy-based loss.
        # get_energy returns shape (n_batch, 1). If log_jacobian returns a
        # 1-D (n_batch,) tensor, subtracting the two directly broadcasts to
        # an (n_batch, n_batch) matrix before the mean; flattening both
        # keeps the subtraction per-sample and leaves the mean unchanged.
        # NOTE(review): the Jacobian term is halved but the energy term is
        # not, whereas loss1 halves both of its terms — confirm that this
        # relative weighting is intended.
        energy_gen = get_energy(x, ener_model).reshape(-1)
        log_jac_gen = net.log_jacobian(run_forward=False, rev=True).reshape(-1)
        loss2 = torch.mean(energy_gen - log_jac_gen / 2)

        loss = loss1 + loss2

        # take a gradient step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Every 10th step: evaluate the energy-based loss on the fixed latent
        # validation batch, log, and let the scheduler react to it.
        if epoch % 10 == 0:
            net.eval()
            with torch.no_grad():  # No gradients because we're validating.
                x_val = net(z_val, rev=True)
                energy_val = get_energy(x_val, ener_model).reshape(-1)
                log_jac_val = net.log_jacobian(run_forward=False, rev=True).reshape(-1)
                loss_val = torch.mean(energy_val - log_jac_val / 2)
            losses.append(loss.item())
            val_losses.append(loss_val.item())
            scheduler.step(loss_val.item())
            t.set_postfix(loss=loss.item(), val_loss=loss_val.item())

In [None]:
# Training loss first, validation loss second.
for history in (losses, val_losses):
    pp.plot(history)

In [None]:
# Draw one latent sample per trajectory frame and map it to configuration space.
n_samples = combined_traj.shape[0]
with torch.no_grad():
    z_output = torch.normal(0, 1, (n_samples, 2), device=device)
    x_output = net(z_output, rev=True)
z_output = z_output.detach().cpu().numpy()
x_output = x_output.detach().cpu().numpy()

# Split the generated samples by the sign of their x-coordinate.
x_coord = x_output[:, 0]
ind_left = np.nonzero(x_coord < 0)
ind_right = np.nonzero(x_coord >= 0)

pp.figure(figsize=(7, 7))
for ind in (ind_left, ind_right):
    pp.scatter(x_output[ind, 0], x_output[ind, 1], marker=".", alpha=0.1)
pp.xlim(-3, 3)
pp.ylim(-3, 3);

In [None]:
# Reverse-transformed grid with the generated configurations on top,
# coloured by which side of x = 0 they fall on.
plot_trans(rev=True, n_grid=31)
for ind in (ind_left, ind_right):
    pp.scatter(x_output[ind, 0], x_output[ind, 1], marker='.', alpha=0.1)

In [None]:
# Forward-transformed grid with the latent samples on top, coloured by the
# side of x = 0 on which their generated configuration falls.
plot_trans(rev=False, n_grid=51)
for ind in (ind_left, ind_right):
    pp.scatter(z_output[ind, 0], z_output[ind, 1], marker='.', alpha=0.1)

In [None]:
# Generate 50 samples per trajectory frame from the network.
n_samples = combined_traj.shape[0] * 50
with torch.no_grad():
    z_output = torch.normal(0, 1, (n_samples, 2), device=device)
    x_output = net(z_output, rev=True)

z_output = z_output.detach().cpu().numpy()
x_output = x_output.detach().cpu().numpy()

# Histogram the x-coordinate, turn the bin edges into bin centres and drop
# empty bins (their log would be -inf).
counts, edges = np.histogram(x_output[:, 0], bins=50, range=(-3, 3))
edges = (edges[1:] + edges[:-1]) / 2
keep_ind = np.nonzero(counts > 0)
edges, counts = edges[keep_ind], counts[keep_ind]
free_energy = -np.log(counts)

# Reference: the analytical dimer energy.
x_grid, energies = ener_model.get_dimer_energies()
pp.plot(x_grid, energies);

# Shift the estimate so that its minimum coincides with the reference minimum.
offset = np.min(energies) - np.min(free_energy)
pp.scatter(edges, free_energy + offset, color="C1")

pp.ylim(-12, 5);