In [None]:
# prerequisites
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image, make_grid

# Mini-batch size shared by the train and test loaders below.
bs = 256
# MNIST Dataset
# The train split is downloaded into ./mnist_data/ if missing; the test split is
# opened with download=False, so it relies on the files fetched by the line above.
train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transforms.ToTensor(), download=False)

# Data Loader (Input Pipeline)
# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)

In [130]:
import numpy as np
import torch
from scipy import special as sp


def GVar(x):
    """Move ``x`` to the GPU when CUDA is available, otherwise to the CPU.

    Unconditionally requesting ``"cuda"`` raises on machines without a GPU,
    so the target device is chosen at call time.

    :param x: any object exposing ``.to(device)`` (tensor or module).
    :return: the result of ``x.to(...)`` on the selected device.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"
    return x.to(target)

class vMF(torch.nn.Module):
    """von Mises-Fisher latent distribution with a fixed concentration ``kappa``.

    A linear layer maps the hidden code to a direction ``mu`` (normalised to
    unit length). Because ``kappa`` is a constant, the KL divergence is a
    constant computed once at construction time. All random draws are created
    on the device of the tensors they are combined with, so the module works
    on CPU and GPU alike.
    """

    def __init__(self, hid_dim, lat_dim, kappa=1):
        """
        von Mises-Fisher distribution class with batch support and manual tuning kappa value.
        Implementation follows description of my paper and Guu's.

        :param hid_dim: size of the hidden code fed to ``func_mu``.
        :param lat_dim: dimensionality of the latent direction.
        :param kappa: fixed concentration parameter.
        :raises ValueError: if the KLD constant evaluates to NaN.
        """
        super().__init__()
        self.hid_dim = hid_dim
        self.lat_dim = lat_dim
        self.kappa = kappa
        self.func_mu = torch.nn.Linear(hid_dim, lat_dim)

        # Constant KLD as a 1-element float tensor. It is created on the CPU
        # and moved to the module's device lazily in compute_KLD.
        self.kld = torch.from_numpy(vMF._vmf_kld(kappa, lat_dim)).float()
        print('KLD: {}'.format(self.kld[0].item()))

    def estimate_param(self, latent_code):
        """Project ``latent_code`` to a unit-norm mean direction.

        :param latent_code: tensor of shape [batch, hid_dim].
        :return: dict with
            ``kappa`` (the fixed concentration),
            ``norm`` (ones shaped like ``mu``),
            ``redundant_norm`` (squared deviation of the raw norm from 1, [batch, 1]),
            ``mu`` (row-normalised direction, [batch, lat_dim]).
        """
        raw_mu = self.func_mu(latent_code)
        raw_norm = torch.norm(raw_mu, p=2, dim=1, keepdim=True)
        redundant_norm = torch.sum(torch.pow(raw_norm - 1, 2), dim=1, keepdim=True)
        return {
            'kappa': self.kappa,
            'norm': torch.ones_like(raw_mu),
            'redundant_norm': redundant_norm,
            'mu': raw_mu / raw_norm,
        }

    def compute_KLD(self, tup, batch_sz):
        """Return the constant KLD broadcast to ``batch_sz`` entries.

        ``tup`` is accepted for interface parity with other distributions and
        is not read.
        """
        target = self.func_mu.weight.device
        if self.kld.device != target:
            # Follow the module if it was moved after construction.
            self.kld = self.kld.to(target)
        return self.kld.expand(batch_sz)

    @staticmethod
    def _vmf_kld(k, d):
        """KL divergence constant for concentration ``k`` in ``d`` dimensions.

        :return: numpy array of shape (1,).
        :raises ValueError: if the expression evaluates to NaN (e.g. Bessel overflow).
        """
        tmp = (k * ((sp.iv(d / 2.0 + 1.0, k) + sp.iv(d / 2.0, k) * d / (2.0 * k)) / sp.iv(d / 2.0, k) - d / (2.0 * k)) \
               + d * np.log(k) / 2.0 - np.log(sp.iv(d / 2.0, k)) \
               - sp.loggamma(d / 2 + 1) - d * np.log(2) / 2).real
        if np.isnan(tmp):
            # Raise instead of exit(): exit() would tear down the whole kernel.
            raise ValueError('vMF KLD is NaN for kappa={}, dim={}'.format(k, d))
        return np.array([tmp])

    @staticmethod
    def _vmf_kld_davidson(k, d):
        """
        This should be the correct KLD.
        Empirically we find that _vmf_kld (as in the Guu paper) only deviates a little (<2%) in most cases we use.

        :return: numpy array of shape (1,).
        :raises ValueError: if the expression evaluates to NaN.
        """
        # np.log (not torch.log) because k is a plain scalar here.
        tmp = k * sp.iv(d / 2, k) / sp.iv(d / 2 - 1, k) + (d / 2 - 1) * np.log(k) - np.log(
            sp.iv(d / 2 - 1, k)) + np.log(np.pi) * d / 2 + np.log(2) - sp.loggamma(d / 2).real - (d / 2) * np.log(
            2 * np.pi)
        if np.isnan(tmp):
            raise ValueError('vMF KLD (Davidson) is NaN for kappa={}, dim={}'.format(k, d))
        return np.array([tmp])

    def build_bow_rep(self, lat_code, n_sample):
        """Estimate parameters, compute the KLD and draw ``n_sample`` samples.

        :param lat_code: tensor of shape [batch, hid_dim].
        :param n_sample: number of draws per batch element.
        :return: ``(param_dict, kld, samples)`` where ``samples`` has shape
            [n_sample, batch, lat_dim].
        """
        batch_sz = lat_code.size()[0]
        tup = self.estimate_param(latent_code=lat_code)
        mu, norm, kappa = tup['mu'], tup['norm'], tup['kappa']

        kld = self.compute_KLD(tup, batch_sz)
        if n_sample == 1:
            return tup, kld, self.sample_cell(mu, norm, kappa)
        draws = [self.sample_cell(mu, norm, kappa) for _ in range(n_sample)]
        return tup, kld, torch.cat(draws, dim=0)

    def sample_cell(self, mu, norm, kappa):
        """Draw one vMF sample per row of ``mu``.

        :param mu: tensor [batch, lat_dim]; re-normalised to unit rows here.
        :param norm: unused; kept for interface compatibility.
        :param kappa: concentration (number or 0-dim tensor).
        :return: tensor of shape [1, batch, lat_dim] on ``mu``'s device.
        """
        batch_sz, lat_dim = mu.size()
        mu = mu / torch.norm(mu, p=2, dim=1, keepdim=True)

        # w is the component of the sample along mu, one scalar per row.
        w = self._sample_weight_batch(kappa, lat_dim, batch_sz)
        w_var = w.to(device=mu.device, dtype=mu.dtype).unsqueeze(1).expand(batch_sz, lat_dim)

        # v is a unit vector orthogonal to mu; sqrt(1 - w^2) keeps the sample on the sphere.
        v = self._sample_ortho_batch(mu, lat_dim)
        scale_factr = torch.sqrt(1 - torch.pow(w_var, 2))
        sampled_vec = v * scale_factr + mu * w_var

        return sampled_vec.unsqueeze(0)

    def _sample_weight_batch(self, kappa, dim, batch_sz=1):
        """Draw ``batch_sz`` independent weights; returns a CPU float tensor [batch_sz]."""
        draws = [self._sample_weight(kappa, dim) for _ in range(batch_sz)]
        return torch.tensor(draws, dtype=torch.float32)

    def _sample_weight(self, kappa, dim):
        """Rejection sampling scheme for sampling distance from center on
        surface of the sphere.

        :param kappa: concentration; coerced to ``float`` so 0-dim tensors work.
        :param dim: latent dimensionality (the sphere is S^{dim-1}).
        :return: Python/numpy float ``w``.
        """
        kappa = float(kappa)
        dim = dim - 1  # since S^{n-1}
        b = dim / (np.sqrt(4. * kappa ** 2 + dim ** 2) + 2 * kappa)  # b= 1/(sqrt(4.* kdiv**2 + 1) + 2 * kdiv)
        x = (1. - b) / (1. + b)
        c = kappa * x + dim * np.log(1 - x ** 2)  # dim * (kdiv *x + np.log(1-x**2))

        while True:
            z = np.random.beta(dim / 2., dim / 2.)  # concentrates towards 0.5 as d-> inf
            w = (1. - (1. + b) * z) / (1. - (1. - b) * z)
            u = np.random.uniform(low=0, high=1)
            # thresh is dim *(kdiv * (w-x) + log(1-x*w) -log(1-x**2))
            if kappa * w + dim * np.log(1. - x * w) - c >= np.log(u):
                return w

    def _sample_ortho_batch(self, mu, dim):
        """Sample, per row, a unit vector orthogonal to ``mu``.

        The projection assumes each row of ``mu`` has unit norm (``sample_cell``
        normalises before calling).

        :param mu: Variable, [batch size, latent dim]
        :param dim: scalar, = latent dim
        :return: tensor [batch size, latent dim] on ``mu``'s device.
        """
        _batch_sz, _lat_dim = mu.size()
        assert _lat_dim == dim

        v = torch.randn(_batch_sz, dim, device=mu.device, dtype=mu.dtype)

        # Row-wise <mu, v>, kept as [batch, 1] so it broadcasts over the latent axis.
        rescale_val = torch.sum(mu * v, dim=1, keepdim=True)
        ortho = v - mu * rescale_val
        return ortho / torch.norm(ortho, p=2, dim=1, keepdim=True)

    def _sample_orthonormal_to(self, mu, dim):
        """Sample point on sphere orthogonal to mu.

        :param mu: 1-D tensor of length ``dim``; need not be unit-norm.
        :param dim: length of ``mu``.
        :return: unit-norm 1-D tensor orthogonal to ``mu``.
        """
        v = torch.randn(dim, device=mu.device, dtype=mu.dtype)

        # Projection of v onto mu is mu * <mu, v> / <mu, mu>.
        proj_mu_v = mu * (mu.dot(v) / mu.dot(mu))
        ortho = v - proj_mu_v
        return ortho / torch.norm(ortho)


import torch.nn as nn
import torch
import numpy as np

# Preferred compute device: CUDA when a GPU is visible to torch, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Gauss(nn.Module):
    """Diagonal Gaussian latent distribution with learned mean and scale.

    Note on naming: the head called ``logvar`` is used as a log *standard
    deviation*. Samples are ``mean + exp(logvar) * eps`` and the KLD uses
    ``exp(2 * logvar)`` as the variance, so the two are consistent.

    Noise is created on the device of the module's own parameters, so the class
    does not depend on any module-level ``device`` variable.
    """

    def __init__(self, hid_dim, lat_dim):
        """
        :param hid_dim: size of the hidden code fed to both linear heads.
        :param lat_dim: dimensionality of the latent variable.
        """
        super().__init__()
        self.hid_dim = hid_dim
        self.lat_dim = lat_dim
        self.func_mean = torch.nn.Linear(hid_dim, lat_dim)
        self.func_logvar = torch.nn.Linear(hid_dim, lat_dim)

    def estimate_param(self, latent_code):
        """Return ``{'mean': ..., 'logvar': ...}``, each of shape [batch, lat_dim]."""
        return {
            'mean': self.func_mean(latent_code),
            'logvar': self.func_logvar(latent_code),
        }

    def compute_KLD(self, tup):
        """KL divergence to the standard normal, summed over the latent axis.

        :param tup: dict produced by :meth:`estimate_param`.
        :return: tensor of shape [batch].
        """
        mean = tup['mean']
        logvar = tup['logvar']

        kld = -0.5 * torch.sum(1 - torch.mul(mean, mean) +
                               2 * logvar - torch.exp(2 * logvar), dim=1)
        return kld

    def sample_cell(self, batch_size):
        """Draw standard-normal noise of shape [1, batch_size, lat_dim].

        The noise is allocated on the device (and in the dtype) of
        ``func_mean``'s weight so it can be combined with the module's outputs.
        """
        ref = self.func_mean.weight
        eps = torch.randn(batch_size, self.lat_dim, device=ref.device, dtype=ref.dtype)
        return eps.unsqueeze(0)

    def build_bow_rep(self, lat_code, n_sample):
        """Estimate parameters, compute the KLD and draw ``n_sample`` samples.

        :param lat_code: tensor of shape [batch, hid_dim].
        :param n_sample: number of reparameterised draws.
        :return: ``(param_dict, kld, samples)`` where ``samples`` has shape
            [n_sample, batch, lat_dim].
        """
        batch_sz = lat_code.size()[0]
        tup = self.estimate_param(latent_code=lat_code)
        mean = tup['mean']
        logvar = tup['logvar']

        kld = self.compute_KLD(tup)
        std = torch.exp(logvar)

        draws = []
        for _ in range(n_sample):
            # Match the noise to the parameters in every branch, not only for n_sample == 1.
            eps = self.sample_cell(batch_size=batch_sz).to(device=mean.device, dtype=mean.dtype)
            draws.append(torch.mul(std, eps) + mean)

        if n_sample == 1:
            return tup, kld, draws[0]
        return tup, kld, torch.cat(draws, dim=0)

    def get_aux_loss_term(self, tup):
        """Return a zero auxiliary loss (float64 tensor of shape [1]); ``tup`` is not read."""
        return torch.from_numpy(np.zeros([1]))


In [None]:
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
class VAE(nn.Module):
    """MLP variational auto-encoder with a pluggable latent distribution.

    The encoder produces two ``z_dim``-sized heads that are concatenated and
    handed to ``self.dist`` (``vMF`` or ``Gauss``), which owns the actual
    distribution parameters, the KL term and the sampling.
    """

    def __init__(self, x_dim, h_dim1, h_dim2, z_dim, prior="gauss", hparam=None):
        """
        :param x_dim: flattened input size (784 for MNIST).
        :param h_dim1: width of the first/last hidden layer.
        :param h_dim2: width of the second/second-to-last hidden layer.
        :param z_dim: latent dimensionality.
        :param prior: ``"gauss"`` or ``"vmf"``.
        :param hparam: for ``"vmf"``, the concentration kappa (defaults to 1);
            ignored for ``"gauss"``.
        :raises ValueError: if ``prior`` is not one of the supported names.
        """
        super(VAE, self).__init__()
        if prior not in ("gauss", "vmf"):
            # Fail here rather than later with a missing `self.dist` attribute.
            raise ValueError("prior must be 'gauss' or 'vmf', got {!r}".format(prior))

        self.x_dim = x_dim
        self.z_dim = z_dim

        # encoder part
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)
        self.fc32 = nn.Linear(h_dim2, z_dim)
        # decoder part
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

        self.prior = prior
        if prior == "vmf":
            self.hparam = 1 if hparam is None else hparam
            self.dist = vMF(z_dim * 2, z_dim, kappa=self.hparam)
        else:
            self.dist = Gauss(z_dim * 2, z_dim)
            self.hparam = None

        # Latent sample from the most recent forward pass (read by log_latents).
        self.sampled_z = None

    def _device(self):
        """Device of the model's parameters."""
        return next(self.parameters()).device

    def encoder(self, x):
        """Map a flattened input to the two encoder heads, each [batch, z_dim]."""
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h)

    def div_loss(self, z1, z2):
        """Sum over the batch of the KL term computed by ``self.dist``."""
        linear_x_2 = torch.cat([z1, z2], dim=-1)
        tup = self.dist.estimate_param(latent_code=linear_x_2)
        if self.prior == "gauss":
            kld = self.dist.compute_KLD(tup)
        else:
            kld = self.dist.compute_KLD(tup, z1.size()[0])
        return kld.sum()

    def sampling(self, mu, log_var):
        """Draw one latent sample per batch row; returns [batch, z_dim]."""
        linear_x_2 = torch.cat([mu, log_var], dim=-1)
        _, _, vecs = self.dist.build_bow_rep(linear_x_2, 1)
        return vecs.view(mu.size()[0], -1)

    def log_latents(self, z1, z2):
        """Save a 2x2 figure of encoder-head and sampled-latent distributions.

        Writes ``./samples/latent_<prior>_<hparam>.png`` (directory is created
        if missing).

        :raises RuntimeError: if no forward pass has populated ``sampled_z`` yet.
        """
        import os

        if self.sampled_z is None:
            raise RuntimeError("log_latents requires a prior forward pass (sampled_z is None)")

        os.makedirs("./samples", exist_ok=True)
        sampled = self.sampled_z.cpu().detach()

        fig, axs = plt.subplots(2, 2)
        axs[0, 0].hist(z1.cpu().detach().numpy())
        axs[0, 0].set_title('z1 histogram')
        axs[0, 1].hist(z2.cpu().detach().numpy())
        axs[0, 1].set_title('z2 histogram')
        axs[1, 0].plot(sampled.numpy())
        axs[1, 0].set_title('sampled_z')
        axs[1, 1].hist(sampled.flatten().numpy())
        axs[1, 1].set_title('sampled z histogram')
        fig.tight_layout()

        fig.savefig("./samples/latent_" + self.prior + f"_{self.hparam or ''}" + ".png")
        plt.close(fig)

    def generate_sample(self):
        """Decode random latents and save them under ``./samples/``.

        * ``gauss``: 64 standard-normal latents decoded into one 8x8-style grid.
        * ``vmf``: 25 random directions sampled at five kappa values in
          [0.1, 15]; one labelled 5x5 grid per kappa, pasted side by side.

        Uses this instance (not a global model) and the device its parameters
        live on. Image shapes assume 28x28 single-channel outputs.
        """
        import os

        os.makedirs('./samples', exist_ok=True)
        out_path = './samples/sample_' + self.prior + f"_{self.hparam or ''}" + '.png'
        device = self._device()

        # No gradients are needed for visualisation, and .numpy() below requires it.
        with torch.no_grad():
            if self.prior == "gauss":
                z = torch.randn(64, self.z_dim, device=device)
                sample = self.decoder(z)

                save_image(sample.view(64, 1, 28, 28), out_path)
                print("sample generated")
            elif self.prior == "vmf":
                mu = torch.rand(25, self.z_dim, device=device)
                mu = mu / torch.norm(mu, p=2, dim=1, keepdim=True)

                # The label font is optional: fall back to PIL's built-in font if the file is absent.
                try:
                    font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", size=20)
                except OSError:
                    font = ImageFont.load_default()

                images = []
                for kappa in torch.linspace(0.1, 15, steps=5).tolist():
                    z = self.dist.sample_cell(mu, None, kappa)
                    sample = self.decoder(z)

                    grid = make_grid(sample.view(25, 1, 28, 28), nrow=5)
                    # Add 0.5 after unnormalizing to [0, 255] to round to the nearest integer
                    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
                    _im = Image.fromarray(ndarr)
                    # 40 extra pixels below the grid hold the kappa caption.
                    im = Image.new('RGB', (_im.width, _im.height + 40), (255, 255, 255))
                    im.paste(_im, (0, 0))
                    draw = ImageDraw.Draw(im)
                    draw.text((20, _im.height + 10), f"Kappa: {kappa:.2f}", (0, 0, 0), font)
                    images.append(im)

                # Lay the per-kappa panels out horizontally with a 10px gutter.
                panel_w = images[0].size[0] + 10
                image_draw = Image.new('RGB', (len(images) * panel_w, images[0].size[1]), (255, 255, 255))
                for idx, image in enumerate(images):
                    image_draw.paste(image, (idx * panel_w, 0))
                image_draw.save(out_path)

    def decoder(self, z):
        """Map latents to reconstructions in (0, 1) via a sigmoid output layer."""
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        return torch.sigmoid(self.fc6(h))

    def forward(self, x):
        """Encode, sample and decode.

        :return: ``(reconstruction, head1, head2)``; the sampled latent is also
            cached on ``self.sampled_z``.
        """
        z1, z2 = self.encoder(x.view(-1, self.x_dim))
        z = self.sampling(z1, z2)
        self.sampled_z = z
        return self.decoder(z), z1, z2

# build model
# MNIST-sized VAE (784 inputs) with a 4-dimensional latent and the vMF prior;
# hparam is forwarded to vMF as its fixed concentration kappa.
vae = VAE(x_dim=784, h_dim1= 512, h_dim2=256, z_dim=4, prior="vmf", hparam=50)
# Move the parameters to the GPU when one is present; otherwise they stay on the CPU.
if torch.cuda.is_available():
    vae.cuda()

In [None]:
# Bare expression: the notebook shows the module's repr as this cell's output.
vae

In [157]:
# Adam over every parameter of the VAE (including the latent distribution's layers), lr = 1e-3.
optimizer = optim.Adam(vae.parameters(), lr=1e-3)
def loss_function(recon_x, x, mu, log_var):
    """Return reconstruction error + KL divergence, both summed over the batch.

    :param recon_x: decoder output, compared against ``x`` flattened to 784 columns.
    :param x: input batch.
    :param mu: first encoder head, forwarded to ``vae.div_loss``.
    :param log_var: second encoder head, forwarded to ``vae.div_loss``.
    :return: scalar tensor ``BCE + KLD``.
    """
    flat_target = x.view(-1, 784)
    reconstruction = F.binary_cross_entropy(recon_x, flat_target, reduction='sum')
    divergence = vae.div_loss(mu, log_var)
    return reconstruction + divergence

In [158]:
def train(epoch):
    """Run one training epoch over ``train_loader`` and print progress.

    Batches are moved to whatever device the model's parameters live on, so
    the loop works on both CPU and GPU.

    :param epoch: epoch number, used only in the log lines.
    """
    vae.train()
    model_device = next(vae.parameters()).device
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(model_device)
        optimizer.zero_grad()

        recon_batch, mu, log_var = vae(data)
        loss = loss_function(recon_batch, data, mu, log_var)

        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        # Log every 100th batch; the printed loss is per-sample for that batch.
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item() / len(data)))
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))

In [159]:
def test():
    """Evaluate on ``test_loader``, then save latent plots and decoded samples.

    Batches are moved to the model's own device. The latent plot uses the
    encoder heads of the last evaluated batch and is skipped when the loader
    yields no batches (previously this raised ``NameError``).
    """
    vae.eval()
    model_device = next(vae.parameters()).device
    test_loss = 0
    last_heads = None
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(model_device)
            recon, mu, log_var = vae(data)

            # sum up batch loss
            test_loss += loss_function(recon, data, mu, log_var).item()
            last_heads = (mu, log_var)
        if last_heads is not None:
            vae.log_latents(*last_heads)
        vae.generate_sample()

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [None]:
# Train for 50 epochs (numbered 1..50). Each epoch is followed by a full test pass,
# which also rewrites the latent plot and sample image under ./samples/.
for epoch in range(1, 51):
    train(epoch)
    test()