In [None]:
from google.colab import drive
drive.mount('/content/gdrive')
%cd gdrive/My Drive/short_run_demo_folder/short_run_demo

In [None]:
# ! git clone https://github.com/EricMFischer/short_run_demo.git
# ! git pull

In [None]:
# utils.py

# plotting helpers for training diagnostics, and the 2D toy dataset

import torch as t
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import os


##################
# ## PLOTTING ## #
##################

# plot diagnostics for learning
def plot_diagnostics(batch, en_diffs, grad_mags, exp_dir, fontsize=10):
    """Save a 2x2 figure of learning diagnostics to ``exp_dir/diagnosis_plot.pdf``.

    Panels: (1) energy difference d_{s_t} per batch, (2) average Langevin gradient
    magnitude r_{s_t} per batch, (3) cross-correlation of the two series and
    (4) their auto-correlations. The correlation panels use at most the last 2000
    records and are skipped when fewer than two records are available.

    Args:
        batch: index of the most recent batch; records ``0..batch`` (inclusive) are used.
        en_diffs: torch tensor of recorded energy differences, indexed by batch.
        grad_mags: torch tensor of recorded gradient magnitudes, indexed by batch.
        exp_dir: directory the PDF is written into.
        fontsize: font size for titles and axis labels.
    """

    def plot_en_diff_and_grad_mag(fig):
        # energy difference
        ax = fig.add_subplot(221)
        ax.plot(en_diffs[0:(batch + 1)].detach().cpu().numpy())
        ax.axhline(y=0, ls='--', c='k')
        ax.set_title('Energy Difference', fontsize=fontsize)
        ax.set_xlabel('batch', fontsize=fontsize)
        ax.set_ylabel('$d_{s_t}$', fontsize=fontsize)
        # mean langevin gradient
        ax = fig.add_subplot(222)
        ax.plot(grad_mags[0:(batch + 1)].detach().cpu().numpy())
        ax.set_title('Average Langevin Gradient Magnitude', fontsize=fontsize)
        ax.set_xlabel('batch', fontsize=fontsize)
        ax.set_ylabel('$r_{s_t}$', fontsize=fontsize)

    def plot_crosscorr_and_autocorr(fig, t_gap_max=2000, max_lag=15, b_w=0.35):
        t_init = max(0, batch + 1 - t_gap_max)
        t_end = batch + 1
        t_gap = t_end - t_init
        if t_gap < 2:
            # normalisation below divides by (t_gap - 1) and std() needs two values
            return
        max_lag = min(max_lag, t_gap - 1)
        en_window = en_diffs[t_init:t_end].detach()
        grad_window = grad_mags[t_init:t_end].detach()
        # rescale energy diffs to unit mean square but leave uncentered
        en_rescale = (en_window / t.sqrt(t.sum(en_window * en_window) / (t_gap - 1))).cpu().numpy()
        # normalize gradient magnitudes
        grad_rescale = ((grad_window - t.mean(grad_window)) / t.std(grad_window)).cpu().numpy()
        # cross-correlation and auto-correlations; a 'full' correlation of two length-t_gap
        # series has length 2 * t_gap - 1 with zero lag at index t_gap - 1
        cross_corr = np.correlate(en_rescale, grad_rescale, 'full') / (t_gap - 1)
        en_acorr = np.correlate(en_rescale, en_rescale, 'full') / (t_gap - 1)
        grad_acorr = np.correlate(grad_rescale, grad_rescale, 'full') / (t_gap - 1)
        zero_lag = t_gap - 1
        # x values for plotting
        x_corr = np.linspace(-max_lag, max_lag, 2 * max_lag + 1)
        x_acorr = np.linspace(0, max_lag, max_lag + 1)

        # plot cross-correlation for lags -max_lag..max_lag
        ax = fig.add_subplot(223)
        ax.bar(x_corr, cross_corr[(zero_lag - max_lag):(zero_lag + max_lag + 1)])
        ax.axhline(y=0, ls='--', c='k')
        ax.set_title('Cross Correlation of Energy Difference\nand Gradient Magnitude', fontsize=fontsize)
        ax.set_xlabel('lag', fontsize=fontsize)
        ax.set_ylabel('correlation', fontsize=fontsize)
        # plot auto-correlation for lags 0..max_lag
        ax = fig.add_subplot(224)
        ax.bar(x_acorr - b_w / 2, en_acorr[zero_lag:(zero_lag + max_lag + 1)], b_w, label='en. diff. $d_{s_t}$')
        ax.bar(x_acorr + b_w / 2, grad_acorr[zero_lag:(zero_lag + max_lag + 1)], b_w, label='grad. mag. $r_{s_t}$')
        ax.axhline(y=0, ls='--', c='k')
        ax.set_title('Auto-Correlation of Energy Difference\nand Gradient Magnitude', fontsize=fontsize)
        ax.set_xlabel('lag', fontsize=fontsize)
        ax.set_ylabel('correlation', fontsize=fontsize)
        ax.legend(loc='upper right', fontsize=fontsize - 4)

    # small tick labels for this figure only (rc_context restores the global settings afterwards)
    with matplotlib.rc_context({'xtick.labelsize': 6, 'ytick.labelsize': 6}):
        fig = plt.figure()
        try:
            plot_en_diff_and_grad_mag(fig)
            plot_crosscorr_and_autocorr(fig)
            fig.subplots_adjust(hspace=0.6, wspace=0.6)
            fig.savefig(os.path.join(exp_dir, 'diagnosis_plot.pdf'), format='pdf')
        finally:
            # always release the figure, even if plotting/saving fails
            plt.close(fig)


#####################
# ## TOY DATASET ## #
#####################

class ToyDataset:
    """Two-dimensional toy distributions with a known ground-truth density.

    ``toy_type`` selects the distribution:
      * ``'gmm'``   -- equal-weight mixture of ``toy_groups`` isotropic Gaussians with std
        ``toy_sd`` whose means are evenly spaced on a circle of radius ``toy_radius``.
      * ``'gmm2'``  -- the same mean layout, but each mode has its own anisotropic
        covariance (see ``_gmm2_covariances``).
      * ``'rings'`` -- ``toy_groups`` concentric rings of radius ``toy_radius * (k + 1)``
        with radial std ``toy_sd``.

    Samples have shape ``(num_samples, 2, 1, 1)``. The ground-truth density is tabulated on
    a ``viz_res x viz_res`` grid over ``[-plot_val_max, plot_val_max]^2``; every grid in this
    class uses row index = y and column index = x.
    """

    # TODO(author): toy_groups could probably be set to 3; should toy_sd and toy_radius be hard-coded?
    def __init__(self, toy_type='gmm', toy_groups=8, toy_sd=0.15, toy_radius=1, viz_res=500, kde_bw=0.05):
        # import helper functions
        from scipy.stats import gaussian_kde
        from scipy.stats import multivariate_normal
        self.gaussian_kde = gaussian_kde
        self.mvn = multivariate_normal

        if toy_type not in ('gmm', 'gmm2', 'rings'):
            raise RuntimeError('Invalid option for toy_type (use "gmm", "gmm2", or "rings")')

        # toy dataset parameters
        self.toy_type = toy_type
        self.toy_groups = toy_groups
        self.toy_sd = toy_sd
        self.toy_radius = toy_radius
        self.weights = np.ones(toy_groups) / toy_groups
        if toy_type in ('gmm', 'gmm2'):
            # mode k sits at angle 2*pi*k/toy_groups on the circle of radius toy_radius
            angles = 2 * np.pi * np.linspace(0, (toy_groups - 1) / toy_groups, toy_groups)
            means_x = np.cos(angles).reshape(toy_groups, 1, 1, 1)
            means_y = np.sin(angles).reshape(toy_groups, 1, 1, 1)
            self.means = toy_radius * np.concatenate((means_x, means_y), axis=1)
            # one 2x2 covariance per mode, shape (toy_groups, 2, 2)
            if toy_type == 'gmm':
                self.covs = np.tile((toy_sd ** 2) * np.eye(2), (toy_groups, 1, 1))
            else:
                self.covs = self._gmm2_covariances(angles, toy_sd)
        else:
            self.means = None
            self.covs = None

        # ground truth density
        def true_density(x):
            """Density at ``(x[0], x[1])``; accepts scalars (-> scalar) or equally-shaped arrays."""
            x_coord, y_coord = np.broadcast_arrays(np.asarray(x[0], dtype=float), np.asarray(x[1], dtype=float))
            points = np.stack((x_coord.ravel(), y_coord.ravel()), axis=1)  # (N, 2)
            density = np.zeros(points.shape[0])
            if self.toy_type == 'rings':
                radius = np.sqrt(points[:, 0] ** 2 + points[:, 1] ** 2)
                for k in range(toy_groups):
                    ring_radius = self.toy_radius * (k + 1)
                    ring_pdf = np.atleast_1d(self.mvn.pdf(radius, mean=ring_radius, cov=self.toy_sd ** 2))
                    # spread the 1D radial density over the ring's circumference
                    density += self.weights[k] * ring_pdf / (2 * np.pi * ring_radius)
            else:
                for k in range(toy_groups):
                    # atleast_1d: scipy squeezes a single-point result to a scalar
                    density += self.weights[k] * np.atleast_1d(
                        self.mvn.pdf(points, mean=self.means[k].squeeze(), cov=self.covs[k]))
            # [()] turns a 0-d result back into a scalar and leaves arrays untouched
            return density.reshape(x_coord.shape)[()]
        self.true_density = true_density

        # viz parameters
        self.viz_res = viz_res
        self.kde_bw = kde_bw
        if toy_type == 'rings':
            self.plot_val_max = toy_groups * toy_radius + 4 * toy_sd
        elif toy_type == 'gmm2':
            # widest principal std over all modes, so elongated modes stay inside the plot
            max_sd = np.sqrt(np.linalg.eigvalsh(self.covs).max())
            self.plot_val_max = toy_radius + 4 * max_sd
        else:
            self.plot_val_max = toy_radius + 4 * toy_sd

        # ground-truth density on the plotting grid; meshgrid's default 'xy' indexing gives
        # z[i, j] = density(x=xy_plot[j], y=xy_plot[i]), matching the EBM and KDE grids
        self.xy_plot = np.linspace(-self.plot_val_max, self.plot_val_max, self.viz_res)
        grid_x, grid_y = np.meshgrid(self.xy_plot, self.xy_plot)
        self.z_true_density = self.true_density([grid_x, grid_y])

    @staticmethod
    def _gmm2_covariances(angles, toy_sd):
        """Distinct per-mode covariances for 'gmm2', shape (len(angles), 2, 2).

        Mode k is scaled by 0.5, 1.0 or 1.5 (cycling with k) and elongated 2:1 along the
        tangent of the circle at its mean: cov_k = R_k diag(s_r^2, s_t^2) R_k^T with radial
        std s_r = 0.5 * scale_k * toy_sd, tangential std s_t = scale_k * toy_sd, and R_k the
        rotation by the mode's angle.
        """
        covs = np.zeros((len(angles), 2, 2))
        for k, angle in enumerate(angles):
            scale = 0.5 + 0.5 * (k % 3)
            stds = scale * toy_sd * np.array([0.5, 1.0])  # (radial, tangential)
            rot = np.array([[np.cos(angle), -np.sin(angle)],
                            [np.sin(angle), np.cos(angle)]])
            cov = rot @ np.diag(stds ** 2) @ rot.T
            covs[k] = 0.5 * (cov + cov.T)  # remove round-off asymmetry
        return covs

    def sample_toy_data(self, num_samples):
        """Draw ``num_samples`` points as a float array of shape (num_samples, 2, 1, 1).

        Component sizes are multinomial with the (uniform) mixture weights; the returned
        points are grouped by component, not shuffled.
        """
        groups = [np.zeros(0).reshape(0, 2, 1, 1)]
        sample_group_sz = np.random.multinomial(num_samples, self.weights)
        if self.toy_type == 'gmm':
            for i in range(self.toy_groups):
                noise = self.toy_sd * np.random.randn(2 * sample_group_sz[i]).reshape(-1, 2, 1, 1)
                groups.append(self.means[i] + noise)
        elif self.toy_type == 'gmm2':
            for i in range(self.toy_groups):
                # correlated Gaussian noise via the Cholesky factor: cov(z @ L^T) = L L^T = cov_i
                chol = np.linalg.cholesky(self.covs[i])
                noise = np.random.randn(sample_group_sz[i], 2) @ chol.T
                groups.append(self.means[i] + noise.reshape(-1, 2, 1, 1))
        elif self.toy_type == 'rings':
            for i in range(self.toy_groups):
                sample_radii = self.toy_radius * (i + 1) + self.toy_sd * np.random.randn(sample_group_sz[i])
                sample_thetas = 2 * np.pi * np.random.random(sample_group_sz[i])
                ring_xy = np.stack((sample_radii * np.cos(sample_thetas), sample_radii * np.sin(sample_thetas)), axis=1)
                groups.append(ring_xy.reshape(-1, 2, 1, 1))
        else:
            raise RuntimeError('Invalid option for toy_type ("gmm", "gmm2", or "rings")')

        return np.concatenate(groups, axis=0)

    def _learned_density(self, f, epsilon):
        """Density exp(-E) of EBM ``f`` normalized on the plotting grid (row = y, column = x)."""
        xy_plot_torch = t.Tensor(self.xy_plot).view(-1, 1, 1, 1).to(next(f.parameters()).device)
        # energies of the learned landscape; row i holds y = xy_plot[i] for every x
        z_learned_energy = np.zeros([self.viz_res, self.viz_res])
        with t.no_grad():
            for i in range(len(self.xy_plot)):
                y_vals = float(self.xy_plot[i]) * t.ones_like(xy_plot_torch)
                vals = t.cat((xy_plot_torch, y_vals), 1)
                z_learned_energy[i] = f(vals).cpu().numpy()
        # rescale energies to correspond to the groundtruth temperature
        if epsilon > 0:
            z_learned_energy *= 2 / (epsilon ** 2)
        # subtract the minimum energy before exponentiating: normalisation cancels the shift,
        # and it keeps exp() from overflowing (inf/inf = nan) for large rescaled energies
        z_unnormalized = np.exp(-(z_learned_energy - z_learned_energy.min()))
        bin_area = (self.xy_plot[1] - self.xy_plot[0]) ** 2
        return z_unnormalized / (bin_area * np.sum(z_unnormalized))

    def _kde_density(self, x_s_t):
        """Gaussian KDE of 2D samples ``x_s_t`` evaluated on the plotting grid (row = y, column = x)."""
        samples = x_s_t.detach().reshape(-1, 2).cpu().numpy()  # assumes 2 coordinates per sample
        density_estimate = self.gaussian_kde(samples.T, bw_method=self.kde_bw)
        grid_x, grid_y = np.meshgrid(self.xy_plot, self.xy_plot)
        grid_points = np.vstack((grid_x.ravel(), grid_y.ravel()))  # (2, viz_res**2)
        return density_estimate(grid_points).reshape(self.viz_res, self.viz_res)

    def plot_toy_density(self, plot_truth=False, f=None, epsilon=0.0, x_s_t=None, save_path='toy.pdf'):
        """Save a PDF comparing densities (top row) and log-densities (bottom row).

        One column per requested panel, in this order:
          * ground truth, if ``plot_truth``;
          * the EBM density exp(-f) normalized on the grid, if ``f`` is given (energies are
            first multiplied by 2 / epsilon**2 when ``epsilon > 0``);
          * a Gaussian KDE of the short-run samples ``x_s_t``, if given.

        Args:
            plot_truth: include the ground-truth column.
            f: energy network mapping (N, 2, 1, 1) inputs to N energies, or None.
            epsilon: Langevin noise std used to rescale learned energies.
            x_s_t: torch tensor of 2D samples (presumably (N, 2, 1, 1)), or None.
            save_path: output PDF path.
        """
        # (top title, bottom title, density grid) per column
        panels = []
        if plot_truth:
            panels.append(('True density', 'True log-density', self.z_true_density))
        if f is not None:
            panels.append(('EBM density', 'EBM log-density', self._learned_density(f, epsilon)))
        if x_s_t is not None:
            panels.append(('Short-run KDE', 'Short-run log-KDE', self._kde_density(x_s_t)))

        num_plots = len(panels)
        fig = plt.figure()
        try:
            for col, (top_title, bottom_title, density) in enumerate(panels, start=1):
                rows = ((col, top_title, density),
                        (col + num_plots, bottom_title, np.log(density + 1e-10)))
                for plot_ind, title, image in rows:
                    ax = fig.add_subplot(2, num_plots, plot_ind)
                    ax.set_title(title)
                    # origin='lower' puts row 0 (smallest y) at the bottom, so y points up
                    ax.imshow(image, cmap='viridis', origin='lower')
                    ax.axis('off')
            fig.tight_layout()
            fig.savefig(save_path, bbox_inches='tight', format='pdf')
        finally:
            plt.close(fig)


In [None]:
# nets.py

import torch as t
import torch.nn as nn
import torch.nn.functional as F


#################################
# ## TOY NETWORK FOR 2D DATA ## #
#################################

class ToyNet(nn.Module):
    """Energy network for 2D toy data.

    A stack of 1x1 convolutions (a per-location MLP) with LeakyReLU activations, mapping
    channel widths dim -> n_f -> 2*n_f -> 2*n_f -> 2*n_f -> 1. In this notebook inputs are
    shaped (batch, dim, 1, 1), giving one scalar energy per sample.
    """

    def __init__(self, dim=2, n_f=32, leak=0.05):
        super().__init__()
        widths = (dim, n_f, 2 * n_f, 2 * n_f, 2 * n_f)
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            # arguments: in channels, out channels, kernel size, stride, padding
            layers.extend([nn.Conv2d(c_in, c_out, 1, 1, 0), nn.LeakyReLU(leak)])
        layers.append(nn.Conv2d(widths[-1], 1, 1, 1, 0))
        # attribute name ``f`` and this layer order define the checkpoint keys f.0 ... f.8
        self.f = nn.Sequential(*layers)

    def forward(self, x):
        energy = self.f(x)
        # drop singleton dims: a (batch, 1, 1, 1) output becomes shape (batch,)
        return energy.squeeze()

In [None]:
# train_toy.py

#############################################
# ## TRAIN EBM USING 2D TOY DISTRIBUTION ## #
#############################################
# Measures and plots the diagnostic values d_{s_t} and r_t

# GOAL: Implement another toy density for learning 2D DeepFRAME models.
# Modify ToyDataset class in utils to implement another density.
# Density should be a GMM groundtruth density with several modes that have
# different covariance matrices.
# This file (train_toy.py) should run without changes once new dataset class has
# been made.

import torch as t
import json
import os

# directory for experiment results
EXP_DIR = './short_run_demo/out_toy/toy_config_4/'
# json file with experiment config
CONFIG_FILE = './short_run_demo/config_locker/toy_config.json'


#######################
# ## INITIAL SETUP ## #
#######################

# read the experiment configuration into ``config``
with open(CONFIG_FILE) as config_fh:
    config = json.load(config_fh)

# refuse to reuse an existing experiment folder so earlier results are never overwritten
if os.path.exists(EXP_DIR):
    raise RuntimeError('Experiment folder "{}" already exists. Please use a different "EXP_DIR".'.format(EXP_DIR))
os.makedirs(EXP_DIR)
for subdir in ('checkpoints', 'landscape', 'plots', 'code'):
    os.mkdir(EXP_DIR + subdir)

# save copy of code in the experiment folder
def save_code():
    """Copy the training script, network/utility modules and the config into ``EXP_DIR/code/``.

    Keeps a snapshot of the exact files used for this experiment next to its results.
    """
    import shutil

    def save_file(file_name):
        # copyfile opens and closes both files itself, so no file handles are leaked
        shutil.copyfile(file_name, os.path.join(EXP_DIR, 'code', os.path.basename(file_name)))

    for file in ['./short_run_demo/train_toy.py', './short_run_demo/nets.py', './short_run_demo/utils.py', CONFIG_FILE]:
        save_file(file)
save_code()

# seed every RNG the run draws from -- torch (CPU and CUDA) and numpy, which
# ToyDataset.sample_toy_data uses -- then pick the device
import numpy as np

t.manual_seed(config['seed'])
np.random.seed(config['seed'])
if t.cuda.is_available():
    print('t.cuda is available')
    t.cuda.manual_seed_all(config['seed'])
device = t.device('cuda' if t.cuda.is_available() else 'cpu')


########################
# ## TRAINING SETUP # ##
########################

print('Setting up network and optimizer...')
# energy network, looked up by name from the config
net_bank = {'toy': ToyNet}
f = net_bank[config['net_type']]().to(device)

# optimizer; for SGD with Langevin noise, scale both learning-rate bounds by epsilon^2 / 2 so
# tuning is invariant to the noise level (mirrors the 2 / epsilon^2 loss scaling in training)
optim_bank = {'adam': t.optim.Adam, 'sgd': t.optim.SGD}
if config['optimizer_type'] == 'sgd' and config['epsilon'] > 0:
    lr_scale = (config['epsilon'] ** 2) / 2
    for lr_key in ('lr_init', 'lr_min'):
        config[lr_key] *= lr_scale
optim = optim_bank[config['optimizer_type']](f.parameters(), lr=config['lr_init'])

print('Processing data...')
# toy dataset for which true samples can be obtained
q = ToyDataset(toy_type=config['toy_type'], toy_groups=config['toy_groups'], toy_sd=config['toy_sd'],
               toy_radius=config['toy_radius'], viz_res=config['viz_res'], kde_bw=config['kde_bw'])

# persistent MCMC states drawn uniformly from [-1, 1)^2;
# sample_s_t() reads them when init_type == 'persistent'
s_t_0 = 2 * t.rand([config['s_t_0_size'], 2, 1, 1]).to(device) - 1


################################
# ## FUNCTIONS FOR SAMPLING ## #
################################

# sample batch from given array of states
def sample_state_set(state_set, batch_size=config['batch_size']):
    """Pick up to ``batch_size`` distinct random rows of ``state_set``.

    Returns:
        (states, indices): the selected rows and the row indices they came from.
    """
    rand_inds = t.randperm(state_set.shape[0])[0:batch_size]
    return state_set[rand_inds], rand_inds


# sample positive states from toy 2d distribution q
def sample_q(batch_size=config['batch_size']):
    """Draw ``batch_size`` samples from the toy dataset ``q`` as a float tensor on ``device``."""
    return t.Tensor(q.sample_toy_data(batch_size)).to(device)


# initialize and update states with langevin dynamics to obtain samples from finite-step MCMC distribution s_t
def sample_s_t(batch_size, L=config['num_mcmc_steps'], init_type=config['init_type'], update_s_t_0=True):
    """Run ``L`` Langevin steps on the energy ``f`` and return the short-run samples.

    Args:
        batch_size: number of chains (with 'persistent', at most ``s_t_0.shape[0]`` are available).
        L: number of Langevin updates.
        init_type: chain initialisation -- "persistent", "data", "uniform", or "gaussian".
        update_s_t_0: with ``init_type == 'persistent'``, write the final states back into ``s_t_0``.

    Returns:
        (x_s_t, r_s_t): the detached final states, and the per-chain gradient norm of ``f``
        averaged over chains and steps (Section 3.2); 0 when ``L == 0``.

    Raises:
        RuntimeError: for an unknown ``init_type``.
    """
    # get initial mcmc states for langevin updates ("persistent", "data", "uniform", or "gaussian")
    def sample_s_t_0():
        if init_type == 'persistent':
            return sample_state_set(s_t_0, batch_size)
        elif init_type == 'data':
            return sample_q(batch_size), None
        elif init_type == 'uniform':
            return config['noise_init_factor'] * (2 * t.rand([batch_size, 2, 1, 1]) - 1).to(device), None
        elif init_type == 'gaussian':
            return config['noise_init_factor'] * t.randn([batch_size, 2, 1, 1]).to(device), None
        else:
            raise RuntimeError('Invalid method for "init_type" (use "persistent", "data", "uniform", or "gaussian")')

    # initialize MCMC samples
    x_s_t_0, s_t_0_inds = sample_s_t_0()

    # leaf tensor whose input-gradient drives the updates (replaces deprecated autograd.Variable)
    x_s_t = x_s_t_0.clone().requires_grad_(True)
    r_s_t = t.zeros(1).to(device)  # variable r_s_t (Section 3.2) to record average gradient magnitude
    for _ in range(L):
        f_prime = t.autograd.grad(f(x_s_t).sum(), [x_s_t])[0]
        with t.no_grad():
            # step down the energy gradient and add Gaussian noise with std epsilon
            x_s_t += -f_prime + config['epsilon'] * t.randn_like(x_s_t)
        r_s_t += f_prime.view(f_prime.shape[0], -1).norm(dim=1).mean()

    if init_type == 'persistent' and update_s_t_0:
        # write the updated chains back into the persistent state bank
        s_t_0[s_t_0_inds] = x_s_t.detach()

    # average over steps; max(L, 1) avoids 0/0 = nan when no steps are taken
    return x_s_t.detach(), r_s_t.squeeze() / max(L, 1)


#######################
# ## TRAINING LOOP ## #
#######################

# containers for diagnostic records (see Section 3), one entry per training iteration
d_s_t_record = t.zeros(config['num_train_iters']).to(device)  # energy difference between positive and negative samples
r_s_t_record = t.zeros(config['num_train_iters']).to(device)  # average state gradient magnitude along Langevin path

print('Training has started.')
for i in range(config['num_train_iters']):
    step = i + 1  # 1-based iteration number used for logging and file names

    # positive samples from the data distribution q; negative samples from short-run MCMC s_t
    x_q = sample_q()
    x_s_t, r_s_t = sample_s_t(batch_size=config['batch_size'])

    # ML computational loss d_s_t (Section 3): mean data energy minus mean short-run energy
    d_s_t = f(x_q).mean() - f(x_s_t).mean()
    if config['epsilon'] > 0:
        # rescale the loss to match the Langevin noise level epsilon
        d_s_t *= 2 / (config['epsilon'] ** 2)

    # stochastic-gradient ML update of the network weights
    optim.zero_grad()
    d_s_t.backward()
    optim.step()

    # record diagnostics
    d_s_t_record[i] = d_s_t.detach().data
    r_s_t_record[i] = r_s_t

    # multiplicative learning-rate decay, never below lr_min
    for lr_gp in optim.param_groups:
        lr_gp['lr'] = max(config['lr_min'], lr_gp['lr'] * config['lr_decay'])

    # log on the first iteration and then every log_info_freq iterations
    is_log_step = step == 1 or step % config['log_info_freq'] == 0
    if is_log_step:
        print('{:>6d}   d_s_t={:>14.9f}   r_s_t={:>14.9f}'.format(step, d_s_t.detach().data, r_s_t))
        t.save(f.state_dict(), EXP_DIR + 'checkpoints/' + 'net_{:>06d}.pth'.format(step))
        if step > 1:
            # with a single record the correlation normalisation would divide by zero
            plot_diagnostics(i, d_s_t_record, r_s_t_record, EXP_DIR + 'plots/')

    # visualize density and log-density for groundtruth, learned energy, and short-run distributions
    if step % config['log_viz_freq'] == 0:
        print('{:>6}   Visualizing true density, learned density, and short-run KDE.'.format(step))
        x_kde = sample_s_t(batch_size=config['batch_size_kde'], update_s_t_0=False)[0]
        viz_path = EXP_DIR + 'landscape/' + 'toy_viz_{:>06d}.pdf'.format(step)
        q.plot_toy_density(True, f, config['epsilon'], x_kde, viz_path)
        print('{:>6}   Visualizations saved.'.format(step))

In [None]:
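NOTE: Before running the train_toy.py cell above:

1) copy the updated utils.py and train_toy.py code from this notebook into the local copies of those files in VSCode, and

2) copy the updated files (utils.py, train_toy.py, toy_config.json) from VSCode to Google Drive. The training cell reads toy_config.json from the Drive folder, and save_code() snapshots the Drive copies of these files into the experiment folder, so stale Drive files would be recorded alongside the results.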