# Dataset

In [1]:
import torch
import random
import sklearn
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sympy.combinatorics.graycode import GrayCode
# Pin the default CUDA device to GPU 7 only when it exists, so the notebook
# still runs on CPU-only / smaller machines (the training cell below already
# falls back to 'cpu' when CUDA is unavailable).
if torch.cuda.is_available() and torch.cuda.device_count() > 7:
    torch.cuda.set_device(7)

def _sample_gaussian_mixture(rng, centers, std, batch_size):
    """Draw ``batch_size`` points from an equal-weight isotropic Gaussian mixture.

    Each point is ``rng.randn(2) * std`` plus a uniformly chosen center; the
    result is cast to float32 and divided by 1.414.
    """
    dataset = np.empty((batch_size, 2))
    for i in range(batch_size):
        point = rng.randn(2) * std
        point += centers[rng.randint(len(centers))]
        dataset[i] = point
    dataset = dataset.astype("float32")
    dataset /= 1.414
    return dataset


def inf_train_gen(data, rng=None, batch_size=200):
    """Draw one batch of 2-D points from a named toy distribution.

    Args:
        data: Dataset name: "swissroll", "circles", "moons", "8gaussians",
            "25gaussians", "pinwheel", "2spirals", "checkerboard" or "rings".
        rng: ``np.random.RandomState`` used for *all* randomness in every
            branch, so a seeded generator yields a reproducible batch. A fresh
            unseeded generator is created when omitted.
        batch_size: Number of points to return (every branch returns exactly
            this many rows).

    Returns:
        Array of shape ``(batch_size, 2)``; the float dtype (float32 or
        float64) depends on the dataset.

    Raises:
        NotImplementedError: If ``data`` is not a known dataset name.
    """
    if rng is None:
        rng = np.random.RandomState()

    # sklearn.datasets is imported explicitly: a bare ``import sklearn`` does
    # not guarantee the ``datasets`` submodule is loaded.
    if data == "swissroll":
        from sklearn.datasets import make_swiss_roll
        points = make_swiss_roll(n_samples=batch_size, noise=1.0, random_state=rng)[0]
        points = points.astype("float32")[:, [0, 2]]
        points /= 5
        return points

    if data == "circles":
        from sklearn.datasets import make_circles
        points = make_circles(n_samples=batch_size, factor=.5, noise=0.08, random_state=rng)[0]
        points = points.astype("float32")
        points *= 3
        return points

    if data == "moons":
        from sklearn.datasets import make_moons
        points = make_moons(n_samples=batch_size, noise=0.1, random_state=rng)[0]
        points = points.astype("float32")
        return points * 2 + np.array([-1, -0.2])

    if data == "8gaussians":
        scale = 4.
        s2 = 1. / np.sqrt(2)
        unit_centers = [(1, 0), (-1, 0), (0, 1), (0, -1),
                        (s2, s2), (s2, -s2), (-s2, s2), (-s2, -s2)]
        centers = [(scale * cx, scale * cy) for cx, cy in unit_centers]
        return _sample_gaussian_mixture(rng, centers, 0.5, batch_size)

    if data == "25gaussians":
        scale = 4.
        unit_centers = [(x / 2, y / 2) for x in range(-2, 3) for y in range(-2, 3)]
        centers = [(scale * cx, scale * cy) for cx, cy in unit_centers]
        return _sample_gaussian_mixture(rng, centers, 0.2, batch_size)

    if data == "pinwheel":
        radial_std = 0.3
        tangential_std = 0.1
        num_classes = 5
        rate = 0.25
        rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False)

        features = rng.randn(batch_size, 2) * np.array([radial_std, tangential_std])
        features[:, 0] += 1.
        # Split batch_size as evenly as possible over the arms (the first
        # batch_size % num_classes arms get one extra point) so the batch is
        # never truncated to a multiple of num_classes.
        counts = np.full(num_classes, batch_size // num_classes)
        counts[:batch_size % num_classes] += 1
        labels = np.repeat(np.arange(num_classes), counts)

        angles = rads[labels] + rate * np.exp(features[:, 0])
        rotations = np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)])
        rotations = np.reshape(rotations.T, (-1, 2, 2))
        return 2 * rng.permutation(np.einsum("ti,tij->tj", features, rotations))

    if data == "2spirals":
        # Generate ceil(batch_size / 2) points per spiral, then trim, so odd
        # batch sizes are honoured.
        half = (batch_size + 1) // 2
        n = np.sqrt(rng.rand(half, 1)) * 540 * (2 * np.pi) / 360
        d1x = -np.cos(n) * n + rng.rand(half, 1) * 0.5
        d1y = np.sin(n) * n + rng.rand(half, 1) * 0.5
        x = np.vstack((np.hstack((d1x, d1y)), np.hstack((-d1x, -d1y)))) / 3
        x = x[:batch_size]
        x += rng.randn(*x.shape) * 0.1
        return x

    if data == "checkerboard":
        x1 = rng.rand(batch_size) * 4 - 2
        x2_ = rng.rand(batch_size) - rng.randint(0, 2, batch_size) * 2
        x2 = x2_ + (np.floor(x1) % 2)
        return np.concatenate([x1[:, None], x2[:, None]], 1) * 2

    if data == "rings":
        from scipy.stats import truncnorm

        num_rings = 4
        ring_spacing = 0.8
        ring_sd = 0.05
        group_sizes = rng.multinomial(batch_size, np.ones(num_rings) / num_rings)

        groups = []
        for i, size in enumerate(group_sizes):
            radius = ring_spacing * (i + 1)
            # Lower truncation bound is at radius 0, so radii stay non-negative.
            radius_rv = truncnorm(a=-radius / ring_sd, b=np.inf, loc=radius, scale=ring_sd)
            radii = radius_rv.rvs(size=size, random_state=rng)
            thetas = 2 * np.pi * rng.random_sample(size)
            groups.append(np.stack([radii * np.cos(thetas), radii * np.sin(thetas)], axis=1))
        return np.concatenate(groups, axis=0)

    raise NotImplementedError(f"Unknown toy dataset: {data!r}")

class OnlineToyDataset(object):
    """Endless source of 2-D toy points plus the scales used to discretise them.

    ``f_scale`` and ``int_scale`` are estimated once, at construction, from a
    5000-point batch generated with ``np.random.RandomState(1)``; regular
    batches come from a separate unseeded generator.
    """

    def __init__(self, data_name, vocab_size, discrete_dim):
        self.data_name = data_name
        self.vocab_size = vocab_size
        self.dim = vocab_size  # same value as vocab_size; used as the digit base below
        self.discrete_dim = discrete_dim
        self.rng = np.random.RandomState()

        probe = inf_train_gen(self.data_name, np.random.RandomState(1), 5000)
        self.f_scale = np.max(np.abs(probe)) + 1
        # Each coordinate gets discrete_dim // 2 digits, i.e. dim ** (discrete_dim // 2)
        # integer levels; dividing by (f_scale + 1) keeps the quantised range
        # [0, f_scale * int_scale] below that count.
        self.int_scale = self.dim ** (discrete_dim // 2) / (self.f_scale + 1)

    def gen_batch(self, batch_size):
        """Return one batch of ``batch_size`` points drawn with this instance's RNG."""
        return inf_train_gen(self.data_name, self.rng, batch_size)

    def data_gen(self, batch_size):
        """Yield batches of ``batch_size`` points forever."""
        while True:
            yield self.gen_batch(batch_size)

def float2base(samples, discrete_dim, f_scale, int_scale, vocab_size):
    """Quantise 2-D float points into fixed-width base-``vocab_size`` digit rows.

    Each coordinate ``v`` becomes the integer level
    ``int(abs((v + f_scale) / 2 * int_scale))``, written as
    ``discrete_dim // 2`` digits (most significant first). Each output row is
    the x digits followed by the y digits.

    Args:
        samples: Array-like of shape ``(n, 2)``.
        discrete_dim: Total digits per row, split evenly between x and y.
        f_scale: Offset added before scaling.
        int_scale: Multiplier applied after halving.
        vocab_size: Digit base (any integer >= 2).

    Returns:
        Integer array of shape ``(n, 2 * (discrete_dim // 2))``.

    Raises:
        ValueError: If ``samples`` is not ``(n, 2)``, contains non-finite
            values, or a level needs more than ``discrete_dim // 2`` digits
            (the point lies outside the representable range).
    """
    samples = np.asarray(samples)
    width = discrete_dim // 2
    if samples.size == 0:
        return np.zeros((0, 2 * width), dtype=int)
    if samples.ndim != 2 or samples.shape[1] != 2:
        raise ValueError(f"float2base expects samples of shape (n, 2), got {samples.shape}")

    scaled = (samples + f_scale) / 2 * int_scale
    if not np.all(np.isfinite(scaled)):
        raise ValueError("float2base: samples contain non-finite values")
    # abs + truncating cast matches int(abs(x)) on each coordinate.
    rest = np.abs(scaled).astype(np.int64)

    digits = np.empty(rest.shape + (width,), dtype=int)
    for pos in range(width - 1, -1, -1):  # fill least significant digit last-position first
        digits[..., pos] = rest % vocab_size
        rest //= vocab_size
    if np.any(rest):
        raise ValueError(
            f"float2base: a coordinate needs more than {width} base-{vocab_size} "
            "digits; sample lies outside the representable range")

    # (n, 2, width) -> (n, 2 * width): x digits first, then y digits.
    return digits.reshape(digits.shape[0], 2 * width)

def base2float(samples, discrete_dim, f_scale, int_scale, vocab_size):
    """Decode base-``vocab_size`` digit rows back into 2-D float points.

    The first ``discrete_dim // 2`` digits of a row (most significant first)
    form the x level and the remaining digits form the y level. Each level
    ``k`` maps to ``k / int_scale * 2 - f_scale``.

    Args:
        samples: Array-like of shape ``(n, d)`` holding digits in
            ``[0, vocab_size)``. Integer or float dtype is accepted; floats
            are truncated toward zero.
        discrete_dim: Total digits per row; ``discrete_dim // 2`` go to x.
        f_scale: Offset subtracted after rescaling.
        int_scale: Divisor applied to each level.
        vocab_size: Digit base.

    Returns:
        Float array of shape ``(n, 2)``.

    Raises:
        ValueError: If any digit is outside ``[0, vocab_size)``.
    """
    digits = np.asarray(samples)
    if digits.size == 0:
        return np.zeros((0, 2))
    digits = digits.astype(np.int64)
    if np.any((digits < 0) | (digits >= vocab_size)):
        raise ValueError(f"base2float: digits must lie in [0, {vocab_size})")
    half = discrete_dim // 2

    def to_level(cols):
        # Place values vocab_size**(k-1), ..., vocab_size**0 (most significant first).
        weights = vocab_size ** np.arange(cols.shape[1] - 1, -1, -1, dtype=np.int64)
        return cols @ weights

    x = to_level(digits[:, :half]) / int_scale * 2. - f_scale
    y = to_level(digits[:, half:]) / int_scale * 2. - f_scale
    return np.stack([x, y], axis=1)

def compress(x, discrete_dim, vocab_size):
    """Return ``int(abs(x))`` written in base ``vocab_size``.

    The string is left-padded with '0' to ``discrete_dim // 2`` characters.
    Values needing more digits are returned unpadded and untruncated.
    """
    width = discrete_dim // 2
    digits = np.base_repr(int(abs(x)), base=vocab_size)
    return digits.rjust(width, "0")

def recover(bx, vocab_size):
    """Parse the digit string ``bx`` as an integer in base ``vocab_size``."""
    return int(bx, base=vocab_size)

def get_batch_data(db, batch_size):
    """Sample ``batch_size`` toy points from ``db`` and encode them with ``float2base``.

    The digit count and base are read from ``db.discrete_dim`` and
    ``db.vocab_size`` (falling back to 16 and 5 if ``db`` lacks them). This
    keeps the encoding consistent with ``db.int_scale``, which
    ``OnlineToyDataset`` derives from those same two values.

    Returns:
        Integer digit array with one row per sampled point.
    """
    discrete_dim = getattr(db, "discrete_dim", 16)
    vocab_size = getattr(db, "vocab_size", 5)
    points = db.gen_batch(batch_size)
    return float2base(points, discrete_dim, db.f_scale, db.int_scale, vocab_size)

def plot_heat(score_func, db, device):
    """Display the model's normalised density on a 100x100 grid over ``[-f_scale, f_scale]^2``.

    Grid points are encoded with ``float2base`` and scored by ``score_func``,
    whose output is treated as an energy. Probabilities are
    ``softmax(-energy)`` over the whole grid.

    The forward pass runs under ``torch.no_grad()``. ``score_func`` is put in
    eval mode for it and its previous train/eval mode is restored afterwards.
    """
    w = 100
    axis = np.linspace(-db.f_scale, db.f_scale, w)
    xx, yy = np.meshgrid(axis, axis)
    grid = np.stack([xx.ravel(), yy.ravel()], axis=-1)  # (w * w, 2)
    heat_samples = float2base(grid, db.discrete_dim, db.f_scale, db.int_scale, db.vocab_size)
    heat_samples = torch.from_numpy(np.float32(heat_samples)).to(device)

    was_training = score_func.training
    score_func.eval()
    try:
        with torch.no_grad():
            energy = score_func(heat_samples).view(1, -1)
            heat_score = F.softmax(-energy, dim=-1)
    finally:
        score_func.train(was_training)

    a = heat_score.view(w, w).cpu().numpy()
    # Row 0 of the meshgrid is y = -f_scale; flip so +y is at the top of the image.
    a = np.flip(a, axis=0)
    plt.imshow(a)
    plt.axis('equal')
    plt.axis('off')
    plt.show()
    plt.close()


# Models and Training Algorithms

In [2]:
import torch.nn as nn
import torch.distributions as dists
from torch.distributions.bernoulli import Bernoulli
from torch.distributions.categorical import Categorical

class Swish(nn.Module):
    """Activation ``x * sigmoid(beta * x)`` with a learnable scalar ``beta`` (initialised to 1.0)."""

    def __init__(self):
        super().__init__()
        self.beta = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        gate = torch.sigmoid(self.beta * x)
        return x * gate

class MLP(nn.Module):
    """Stack of Linear layers separated by ``Swish`` activations.

    Args:
        input_dim: Width of the input features.
        hidden_dims: List of layer widths, or a dash-separated string such as
            ``"256-256-1"``. The last entry is the output width. No activation
            follows the final Linear layer.
    """

    def __init__(self, input_dim, hidden_dims):
        super().__init__()
        self.input_dim = input_dim

        if isinstance(hidden_dims, str):
            hidden_dims = [int(tok) for tok in hidden_dims.split("-")]
        assert len(hidden_dims)
        widths = [input_dim] + hidden_dims
        self.output_size = widths[-1]

        n_linear = len(widths) - 1
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            if idx < n_linear - 1:  # activation between layers, not after the last one
                layers.append(Swish())

        self.main = nn.Sequential(*layers)

    def forward(self, z):
        return self.main(z)

class MLPScore(nn.Module):
    """``MLP`` wrapper that casts its input to float32 and returns the raw network output."""

    def __init__(self, input_dim, hidden_dims):
        super().__init__()
        self.mlp = MLP(input_dim, hidden_dims)

    def forward(self, z):
        as_float = z.float()
        return self.mlp(as_float)

class EBM(nn.Module):
    """Energy-based model over categorical sequences: ``p(x) = exp(-f(x)) / Z``.

    Every position is embedded with a shared Linear layer applied to its
    one-hot vector. The embeddings are flattened and passed to ``net``, whose
    output is the energy ``f(x)``.

    Args:
        net: Module mapping ``[bs, dim * emb_dim]`` to energies, presumably of
            shape ``[bs]`` or ``[bs, 1]``.
        emb_dim: Embedding width per position.
        vocab_size: Number of categories per position.
    """

    def __init__(self, net, emb_dim=4, vocab_size=5):
        super().__init__()
        self.net = net
        self.emb_dim = emb_dim
        self.vocab_size = vocab_size
        # Linear (not nn.Embedding): it consumes one-hot vectors, so inputs that
        # are already [bs, dim, vocab_size] encodings can be embedded directly.
        self.embedding = nn.Linear(self.vocab_size, self.emb_dim)

    def forward(self, x):
        """Return the energy ``f(x)`` per sample.

        x: ``[bs, dim]`` category indices (cast to long) or an already-encoded
        ``[bs, dim, vocab_size]`` tensor.
        """
        if x.dim() == 2:
            x = F.one_hot(x.long(), num_classes=self.vocab_size).float()
        emb = self.embedding(x).view(x.shape[0], -1)

        energy = self.net(emb)
        # Drop only a trailing singleton output dim; a bare squeeze() would also
        # remove the batch dim when bs == 1 and return a 0-d tensor.
        if energy.dim() > 1 and energy.shape[-1] == 1:
            energy = energy.squeeze(-1)
        return energy


# Loss Function

In [3]:
def perturb_cat_grid(samples, num_classes, m_particles=32):
    """Build ``m_particles`` perturbed copies of every sample.

    For each row of ``samples`` (shape ``[bs, C]``) one column is chosen at
    random. Every copy replaces that column with an independent draw
    ``int(u * num_classes[col])``, ``u ~ U[0, 1)``; all other columns are
    kept. ``num_classes`` holds the class count of each of the ``C`` columns.

    Returns:
        Tensor of shape ``[bs, m_particles, C]``.
    """
    dev = samples.device
    batch, n_cols = samples.shape

    n_pert = 1  # number of resampled columns per row
    assert n_pert == 1, 'Only one perturbed column is supported'

    # Random column order per row; the last column in that order is resampled.
    order = torch.rand(batch, n_cols, device=dev).argsort(dim=-1)
    inverse_order = order.argsort(dim=-1)

    kept = samples.gather(-1, order[:, :-n_pert])
    kept = kept.unsqueeze(1).expand(batch, m_particles, -1)

    classes_of_pert_col = num_classes.gather(0, order[:, -n_pert])

    u = torch.rand(batch, m_particles, device=dev)
    new_vals = torch.einsum('bm, b -> bm', u, classes_of_pert_col)
    new_vals = new_vals.int().unsqueeze(-1)

    # Kept columns + resampled column are in permuted order; undo the permutation.
    shuffled = torch.cat([kept, new_vals], dim=-1)
    restore_idx = inverse_order.unsqueeze(1).expand(batch, m_particles, -1)
    return shuffled.gather(-1, restore_idx)

def ed_categorical(energy_net, samples, K=5, dim=16, epsilon=1., m_particles=32, w_stable=1.):
    """Energy-discrepancy loss for categorical data.

    Negatives come from ``perturb_cat_grid``. For each sample, one coordinate
    is chosen at random, and each of ``m_particles`` copies resamples it
    uniformly from ``K`` classes.

    The loss is the batch mean of ``logsumexp_j(E(x) - E(x'_j))``. When
    ``w_stable != 0``, an extra ``log(w_stable)`` entry is included inside the
    logsumexp.

    ``dim`` and ``epsilon`` are accepted but unused; the number of coordinates
    is taken from ``samples.shape``.
    """
    batch, n_coords = samples.shape

    classes = torch.tensor([K] * n_coords, device=samples.device)
    negatives = perturb_cat_grid(samples, classes, m_particles)  # [batch, m_particles, n_coords]

    pos_energy = energy_net(samples).view(batch, 1)
    neg_energy = energy_net(negatives.view(-1, n_coords)).view(batch, -1)
    diffs = pos_energy - neg_energy

    if w_stable != 0:
        stabiliser = torch.ones_like(diffs[:, :1]) * np.log(w_stable)
        diffs = torch.cat([diffs, stabiliser], dim=-1)

    return diffs.logsumexp(dim=-1).mean()

# Training Main Loop

In [4]:
import copy
from tqdm import tqdm
import torch.optim as optim

def training_main_loop(energy_func, db, device, n_epochs=1000, iters_per_epoch=100,
                       batch_size=128, lr=1e-4, ema_decay=0.999, max_grad_norm=5):
    """Train ``energy_func`` with ``ed_categorical`` and return an EMA copy of it.

    Each iteration does the following:
    - draws ``batch_size`` encoded samples with ``get_batch_data``;
    - takes one Adam step on the loss, with gradients clipped to
      ``max_grad_norm``;
    - updates the EMA copy: ``ema = ema_decay * ema + (1 - ema_decay) * param``.

    Defaults: 1000 epochs x 100 iterations, batch 128, lr 1e-4, decay 0.999,
    clip 5.

    Returns:
        The EMA model (a deep copy of ``energy_func`` moved to ``device``).
    """
    ema_energy_func = copy.deepcopy(energy_func).to(device)
    opt_energy = optim.Adam(energy_func.parameters(), lr=lr)

    for epoch in range(n_epochs):
        pbar = tqdm(range(iters_per_epoch))
        for _ in pbar:
            samples = get_batch_data(db, batch_size)
            samples = torch.from_numpy(np.float32(samples)).to(device)

            loss = ed_categorical(energy_func, samples, K=db.vocab_size, dim=db.discrete_dim)

            opt_energy.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(energy_func.parameters(), max_norm=max_grad_norm)
            opt_energy.step()

            # In-place EMA update (no new tensors allocated per step).
            with torch.no_grad():
                for p, ema_p in zip(energy_func.parameters(), ema_energy_func.parameters()):
                    ema_p.mul_(ema_decay).add_(p, alpha=1. - ema_decay)

            pbar.set_description('epoch: %d, loss: %.4f' % (epoch, loss.item()))

    return ema_energy_func

# Training Discrete EBM using Grid Perturbation

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

emb_dim = 4
vocab_size = 5
discrete_dim = 16
db = OnlineToyDataset('pinwheel', vocab_size, discrete_dim)
net = MLPScore(discrete_dim * emb_dim, [256] * 3 + [1]).to(device)
# Reuse the settings above rather than repeating literals, so changing
# emb_dim / vocab_size keeps the model consistent with the data encoding.
energy_func = EBM(net, emb_dim=emb_dim, vocab_size=vocab_size).to(device)
ema_energy_func = training_main_loop(energy_func, db, device)
# Plots the raw trained model; pass ema_energy_func instead to view the EMA weights.
plot_heat(energy_func, db, device)

epoch: 0, loss: 3.4812: 100%|██████████| 100/100 [00:01<00:00, 54.41it/s]
epoch: 1, loss: 3.4539: 100%|██████████| 100/100 [00:01<00:00, 79.28it/s]
epoch: 2, loss: 3.4354: 100%|██████████| 100/100 [00:01<00:00, 99.63it/s]
epoch: 3, loss: 3.4552: 100%|██████████| 100/100 [00:01<00:00, 93.74it/s]
epoch: 4, loss: 3.4621: 100%|██████████| 100/100 [00:01<00:00, 81.09it/s]
epoch: 5, loss: 3.4033: 100%|██████████| 100/100 [00:01<00:00, 87.50it/s]
epoch: 6, loss: 3.4388: 100%|██████████| 100/100 [00:01<00:00, 88.81it/s]
epoch: 7, loss: 3.3874: 100%|██████████| 100/100 [00:01<00:00, 73.23it/s]
epoch: 8, loss: 3.4208: 100%|██████████| 100/100 [00:01<00:00, 80.81it/s]
epoch: 9, loss: 3.4277: 100%|██████████| 100/100 [00:01<00:00, 85.97it/s]
epoch: 10, loss: 3.3848:  10%|█         | 10/100 [00:00<00:00, 95.81it/s]