In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class BayesNetDataset(Dataset):
    """Dataset over a stack of matrices read from ``./data/dag_mat_child3Domains.npz``.

    The loaded array is indexed as ``x[idx, :, :]``, so it must be 3-dimensional;
    every item is returned as a float32 tensor.

    NOTE(review): the train and test splits currently read the *same* file, so
    "test" batches come from the training data. Confirm whether a separate
    held-out file was intended.

    :param bool is_train: selects the train (truthy) or test (falsy) split.
    :raises ValueError: if the file is an ``.npz`` archive that does not hold
        exactly one array (there would be no way to know which one to use).
    """

    # Both splits point at the same file today; kept as a table so a dedicated
    # test file can be swapped in without touching the loading logic.
    _DATA_PATHS = {
        True: './data/dag_mat_child3Domains.npz',
        False: './data/dag_mat_child3Domains.npz',
    }

    def __init__(self, is_train=True):
        loaded = np.load(self._DATA_PATHS[bool(is_train)])
        if hasattr(loaded, 'files'):
            # A genuine .npz archive loads as a lazy NpzFile mapping rather than
            # an ndarray; it has no ``.shape`` and cannot be sliced. Pull out
            # the array it contains and release the underlying file handle.
            with loaded as archive:
                names = list(archive.files)
                if len(names) != 1:
                    raise ValueError(
                        'Expected exactly one array in the .npz archive, found {}: {}'.format(
                            len(names), names))
                array = archive[names[0]]
        else:
            array = loaded
        self.x = array

    def __len__(self):
        """Number of samples, i.e. the size of the first axis."""
        return self.x.shape[0]

    def __getitem__(self, idx):
        """Return sample(s) ``idx`` as a float32 tensor."""
        if torch.is_tensor(idx):
            # NumPy indexing below expects plain Python ints / lists
            idx = idx.tolist()
        return torch.from_numpy(self.x[idx,:,:]).float()

def get_bayesnet_dataloaders(length, batch_size):
    """Build the shuffled train and test loaders over ``BayesNetDataset``.

    :param length: accepted for interface compatibility; not used here.
    :param int batch_size: batch size handed to both ``DataLoader`` objects.
    :return: ``(train_loader, test_loader)`` tuple.
    """
    # Train split first, then test split; both are shuffled
    return tuple(
        DataLoader(BayesNetDataset(is_train=split_flag), batch_size=batch_size, shuffle=True)
        for split_flag in (True, False)
    )

In [None]:
import numpy as np
import pickle
import torch
from torch.autograd import Variable
from torch.nn import ModuleList
import matplotlib.pyplot as plt
from IPython import display
import time
import pprint

from dataloaders import *
from models import Generator, Discriminator
from train import Trainer
from eval import *


def get_modified_network(network, dist, seed=None):
    """Return a copy of ``network`` with ``dist`` edges removed and ``dist`` new edges added.

    ``network`` is a list of lists; ``j in network[i]`` (with ``j != i``) is
    treated as an edge between positions ``i`` and ``j``. Each round removes one
    edge that is present in the *input* network and adds one edge that is absent
    from it, so the result differs from the input by exactly ``dist`` removals
    and ``dist`` additions. The input list and its rows are never mutated.

    :param list network: list of lists of node indices.
    :param int dist: number of remove/add rounds.
    :param seed: optional seed for NumPy's global RNG. When ``None`` (default)
        the RNG is seeded from the current time, as before.
    :return: the modified network (new outer list; changed rows are new lists).
    :raises ValueError: if the input has fewer than ``dist`` removable edges or
        fewer than ``dist`` free slots, in which case the sampling loops could
        never terminate.
    """
    np.random.seed(int(time.time()) if seed is None else seed)
    old_network = network
    network = network.copy()
    length = len(network)

    # Count candidates up front: rejection sampling below would spin forever
    # if there were not enough edges to remove or empty slots to fill.
    pairs = [(i, j) for i in range(length) for j in range(length) if i != j]
    removable = sum(1 for i, j in pairs if j in old_network[i])
    addable = len(pairs) - removable
    if dist > removable or dist > addable:
        raise ValueError(
            'Cannot modify {} edges: only {} removable and {} addable.'.format(
                dist, removable, addable))

    for _ in range(dist):
        while True:
            i = np.random.choice(length)
            j = np.random.choice(length)
            # Only original edges may be removed; otherwise a round could undo
            # an edge added by an earlier round and shrink the edit distance.
            if i!=j and j in network[i] and j in old_network[i]:
                network[i] = [k for k in network[i] if k != j]
                break
        while True:
            i = np.random.choice(length)
            j = np.random.choice(length)
            # Never re-add an edge of the original network
            if i!=j and j not in network[i] and j not in old_network[i]:
                network[i] = network[i] + [j]
                break
    return network


def get_random_network(length, p):
    """Draw a random network over ``length`` nodes.

    Row ``i`` of the result starts with ``i`` itself, followed by every other
    index ``j`` for which an independent uniform draw fell below ``p``.
    NumPy's global RNG is re-seeded from the current time on every call.

    :param int length: number of nodes (rows).
    :param float p: inclusion probability for each ordered pair ``(i, j)``, ``i != j``.
    :return: list of ``length`` lists of node indices.
    """
    np.random.seed(int(time.time()))
    # One uniform draw per ordered pair with node != other, in row-major order
    return [
        [node] + [
            other
            for other in range(length)
            if other != node and np.random.uniform() < p
        ]
        for node in range(length)
    ]
    
# Candidate network structures, keyed by name in NETWORK_DICT below.
# Each network is a list of index groups; main() builds one Discriminator per
# group, sized by the group's length.

# Reference structure over 11 nodes: entry i starts with i itself, followed by
# the other nodes grouped with it (presumably its neighbours/parents in the
# Bayesian network -- TODO confirm against the data source).
TRUE_NETWORK = [[0, 7, 8], [1, 0, 7, 8], [2], [3, 2, 4], [4, 2], [5, 1, 7], [6, 5, 7], [7, 8], [8], [9, 7, 8], [10, 7, 8]]
# No structure: a single group containing every node index.
WITHOUT_NETWORK = [list(range(len(TRUE_NETWORK)))]
# TRUE_NETWORK after 3 remove-one-edge / add-one-edge rounds (random, time-seeded).
MODIFIED_NETWORK = get_modified_network(TRUE_NETWORK, 3)
# Random structure with the same edge density as TRUE_NETWORK: it has 17
# non-self entries out of 11 * 10 = 110 ordered node pairs.
RANDOM_NETWORK = get_random_network(len(TRUE_NETWORK), 17/110)
# Lookup used by main() to resolve a network_type string to its structure.
NETWORK_DICT = {
    'true': TRUE_NETWORK,
    'without': WITHOUT_NETWORK,
    'modified': MODIFIED_NETWORK,
    'random': RANDOM_NETWORK,
}


def main(mode, network_types, expid, epochs, lr, batch_size, pretrain_epochs, ntrials):
    """Train and evaluate one GAN per requested network structure, then plot and save results.

    For every entry of ``network_types`` a generator and one discriminator per
    index group of the network are trained with ``Trainer``. The generator is
    then evaluated ``ntrials`` times with ``energy_statistics`` and
    ``discriminative_score`` on batches from the test loader. Loss curves are
    plotted, and losses, networks and results are pickled under
    ``./results/<expdesc>/``.

    :param str mode: divergence mode forwarded to ``Trainer``.
    :param network_types: iterable of keys of ``NETWORK_DICT``.
    :param expid: experiment identifier, appended to the output folder name.
    :param int epochs: number of adversarial training epochs.
    :param float lr: Adam learning rate for generator and discriminators.
    :param int batch_size: batch size for both data loaders.
    :param int pretrain_epochs: number of critic-only pretraining epochs.
    :param int ntrials: number of evaluation repetitions per network type.
    """
    import os  # local import: this cell's import block does not provide ``os``

    length = len(TRUE_NETWORK)
    use_cuda = torch.cuda.is_available()
    # Computed up front (it does not depend on the network type) so the output
    # path is defined even when ``network_types`` is empty.
    expdesc = "{}-e{}lr{}bs{}pe{}nt{}-{}".format(mode, epochs, int(10000*lr), batch_size, pretrain_epochs, ntrials, expid)
    network_dict = {}
    result_dict = {}
    losses_dict = {}
    for network_type in network_types:
        # Get network
        network = NETWORK_DICT[network_type]

        # Obtain models: one discriminator per index group of the network
        generator = Generator(length=length)
        discriminator = ModuleList([Discriminator(corrlength=len(n)) for n in network])

        # Obtain dataloader
        data_loader, test_data_loader = get_bayesnet_dataloaders(length=length, batch_size=batch_size)

        # Initialize optimizers
        G_optimizer = torch.optim.Adam(generator.parameters(), lr=lr, betas=(.9, .99))
        D_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(.9, .99))
        # gamma=1.0 keeps the learning rate constant; the scheduler is a placeholder
        G_scheduler = torch.optim.lr_scheduler.StepLR(G_optimizer, step_size=50, gamma=1.0)

        # Set up trainer
        trainer = Trainer(mode, network, network_type, generator, discriminator, G_optimizer, D_optimizer, G_scheduler, use_cuda=use_cuda)

        # Train model
        trainer.train(data_loader, test_data_loader, expdesc, epochs=epochs, pretrain_epochs=pretrain_epochs)

        # Get result
        display.clear_output(wait=True)
        generator.eval()
        result_dict[network_type] = {}
        ess = []
        dss = []
        for trial in range(ntrials):
            # Eight test batches form the reference sample of this trial
            it_data = iter(test_data_loader)
            sampled_data = [next(it_data).data.numpy() for i in range(8)]
            sampled_data = np.concatenate(sampled_data, axis=0).argmax(2)
            fixed_latents = Variable(generator.sample_latent(sampled_data.shape[0]))
            # Follow the device the trainer moved the generator to; an
            # unconditional .cuda() would crash on CPU-only machines.
            if use_cuda:
                fixed_latents = fixed_latents.cuda()
            # 0.1 equals the final tau reached at the end of Trainer.train
            generated = generator(fixed_latents, 0.1).detach().cpu().data.numpy().argmax(2)
            ess.append(energy_statistics(sampled_data, generated))
            dss.append(discriminative_score(sampled_data, generated, diagnose=False))
            print("Evaluating... Progress {:.2f}%".format((trial+1)/ntrials*100), end='\r')
        ess = np.array(ess)
        dss = np.array(dss)
        # Stored as (mean, standard error of the mean) over the trials
        result_dict[network_type]['energy_statistics'] = (ess.mean(), ess.std()/np.sqrt(ntrials))
        result_dict[network_type]['discriminative_score'] = (dss.mean(), dss.std()/np.sqrt(ntrials))
        losses_dict[network_type] = {}
        losses_dict[network_type]['wasserstein_loss'] = trainer.losses['D']
        losses_dict[network_type]['energy_statistics'] = trainer.losses['energy_statistics']
        network_dict[network_type] = network

    # Summary figure: critic estimate (top) and energy statistics (bottom)
    display.clear_output(wait=True)
    fig = plt.figure(figsize=(18, 16))
    gs0 = fig.add_gridspec(2, 1)
    ax1 = fig.add_subplot(gs0[0])
    ax2 = fig.add_subplot(gs0[1])
    for network_type in network_types:
        loss = losses_dict[network_type]['wasserstein_loss']
        ax1.plot(loss, label=network_type.capitalize())
    ax1.set_xlabel("Iterations", fontsize=18)
    ax1.set_ylabel("Wasserstein Estimations", fontsize=18)
    ax1.legend(fontsize=18)
    for network_type in network_types:
        ax2.plot(losses_dict[network_type]['energy_statistics'], label=network_type.capitalize())
    ax2.set_xlabel("Iterations", fontsize=18)
    ax2.set_ylabel("Energy Statistics", fontsize=18)
    ax2.legend(fontsize=18)

    # Persist figure and raw numbers next to the per-network trainer outputs
    newpath = './results/{}/'.format(expdesc)
    os.makedirs(newpath, exist_ok=True)
    fig.savefig(newpath+'Losses.png')
    with open(newpath+'Losses.pickle', 'wb') as handle:
        pickle.dump(losses_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(newpath+'Networks.pickle', 'wb') as handle:
        pickle.dump(network_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(newpath+'Results.pickle', 'wb') as handle:
        pickle.dump(result_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    pprint.pprint(result_dict, width=200)

In [None]:
import os
import imageio
import numpy as np
import torch
import torch.nn as nn
from torch.nn import ModuleList
from torchvision.utils import make_grid
from torch.autograd import Variable
from torch.autograd import grad as torch_grad
import matplotlib.pyplot as plt
from IPython import display
import pickle
from skimage import img_as_ubyte
import time
from scipy.optimize import linear_sum_assignment
from sklearn.decomposition import PCA

from models import Generator, Discriminator
from eval import energy_statistics


class Trainer():
    """Adversarial training loop for one generator against a group of critics.

    ``network`` is a list of index groups; critic ``D[k]`` only sees the slice
    ``data[:, network[k], :]`` and the per-group outputs are summed into one
    score. ``mode`` selects the objective: ``'W'`` (identity activations plus a
    gradient penalty) or one of the f-divergence variants ``'JS'``, ``'SH'``,
    ``'KL'``, ``'TV'``.

    Recorded histories live in ``self.losses`` (keys ``'G'``, ``'D'``,
    ``'gradient_norm'``, ``'d_real'``, ``'d_generated'``, ``'energy_statistics'``).
    """

    def __init__(self, mode, network, expname, generator, discriminator, gen_optimizer, dis_optimizer, gen_scheduler,
                 gp_weight=10, critic_iterations=5, print_every=50, use_cuda=True):
        """Store models/optimizers and move the models to the GPU when ``use_cuda`` is set.

        :param str mode: one of ``'W'``, ``'JS'``, ``'SH'``, ``'KL'``, ``'TV'``.
        :param network: list of index groups, one per critic in ``discriminator``.
        :param str expname: name used in plot titles and output file names.
        :param generator: generator module; must provide ``sample_latent(n)``.
        :param discriminator: indexable collection of critics (``D[k]``).
        :param gen_optimizer: optimizer for the generator.
        :param dis_optimizer: optimizer for the critics.
        :param gen_scheduler: scheduler stepped once per training epoch.
        :param gp_weight: multiplier of the gradient penalty (mode ``'W'`` only).
        :param int critic_iterations: the generator is updated every this many steps.
        :param int print_every: visualisation interval in steps.
        :param bool use_cuda: whether models and batches are moved to the GPU.
        """
        assert mode in ['W', 'JS', 'SH', 'KL', 'TV'], 'Invalid mode.'
        self.mode = mode
        self.network = network
        self.expname = expname
        self.G = generator
        self.G_opt = gen_optimizer
        self.G_sch = gen_scheduler
        self.D = discriminator
        self.D_opt = dis_optimizer
        self.losses = {'G': [], 'D': [], 'gradient_norm': [], 'd_real': [], 'd_generated': [], 'energy_statistics': []}
        self.num_steps = 0
        self.use_cuda = use_cuda
        self.gp_weight = gp_weight
        self.critic_iterations = critic_iterations
        self.print_every = print_every
        # 2-D projection used only for the progress scatter plots
        self.pca = PCA(n_components=2)
        if self.use_cuda:
            self.G.cuda()
            self.D.cuda()

    def _critic_train_iteration(self, data, latent_samples, tau, pretrain=False):
        """Run one optimisation step of the critics.

        :param data: batch of real samples.
        :param latent_samples: latent batch fed to the generator.
        :param tau: second argument of the generator call (presumably a
            sampling temperature -- confirm against ``Generator``).
        :param bool pretrain: when true, nothing is appended to the loss history.
        """
        # Get generated data; detached so no gradient reaches the generator
        generated_data = self.G(latent_samples, tau).detach()

        # Sum the per-group critic outputs on real and generated data
        d_real, d_generated = 0, 0
        for idx, n in enumerate(self.network):
            d_real += self.D[idx](data[:,n,:])
            d_generated += self.D[idx](generated_data[:,n,:])

        # Create total loss and optimize
        self.D_opt.zero_grad()
        d_loss = self._fdiv_activation(d_real, d_generated)

        # Record loss (before the gradient penalty is added; 'D' stores the
        # negated loss, i.e. the divergence estimate)
        if not pretrain:
            self.losses['d_real'].append(d_real.mean().data.item())
            self.losses['d_generated'].append(d_generated.mean().data.item())
            self.losses['D'].append(-d_loss.data.item())

        if self.mode == 'W':
            # Get gradient penalty
            gradient_penalty = self._gradient_penalty(data, generated_data)
            d_loss += gradient_penalty

        # Optimization
        d_loss.backward()
        self.D_opt.step()

    def _generator_train_iteration(self, data, latent_samples, tau):
        """Run one optimisation step of the generator.

        ``data`` is accepted for symmetry with the critic step but not used.
        """
        # Get generated data
        generated_data = self.G(latent_samples, tau)

        # Calculate loss and optimize
        d_generated = 0
        for idx, n in enumerate(self.network):
            d_generated += self.D[idx](generated_data[:,n,:])

        self.G_opt.zero_grad()
        # The "real" term is fed zeros, so it only contributes a constant
        g_loss = -self._fdiv_activation(torch.zeros_like(d_generated), d_generated)

        # Record loss
        self.losses['G'].append(g_loss.data.item())

        # Optimization
        g_loss.backward()
        self.G_opt.step()

    def _fdiv_activation(self, d_real, d_generated):
        """Return the critic loss ``mean(fs(gf(d_generated)) - gf(d_real))`` for ``self.mode``.

        ``gf`` is the output activation applied to raw critic scores and ``fs``
        the function applied on top of it for generated samples; both are the
        identity in mode ``'W'``.
        """
        # F-divergence Functions
        def gf(v):
            if self.mode == 'SH':
                return 1-torch.exp(-v)
            elif self.mode == 'KL':
                return v
            elif self.mode == 'JS':
                return np.log(2)-torch.log(1+torch.exp(-v))
            elif self.mode == 'TV':
                return torch.tanh(v)/2
            elif self.mode == 'W':
                return v
        def fs(t):
            if self.mode == 'SH':
                return t/(1-t)
            elif self.mode == 'KL':
                return torch.exp(t-1)
            elif self.mode == 'JS':
                return -torch.log(2-torch.exp(t))
            elif self.mode == 'TV':
                return t
            elif self.mode == 'W':
                return t
        return (fs(gf(d_generated)) - gf(d_real)).mean()


    def _gradient_penalty(self, real_data, generated_data):
        """Return ``gp_weight * mean((||grad|| - 1)^2)`` on random real/generated interpolates.

        Also appends the mean gradient norm to ``self.losses['gradient_norm']``.
        """
        batch_size = real_data.size()[0]

        # Calculate interpolation: one mixing coefficient per sample,
        # broadcast over the two remaining axes
        alpha = torch.rand(batch_size, 1, 1)
        alpha = alpha.expand_as(real_data)
        if self.use_cuda:
            alpha = alpha.cuda()
        interpolated = alpha * real_data.data + (1 - alpha) * generated_data.data
        interpolated = Variable(interpolated, requires_grad=True)
        if self.use_cuda:
            interpolated = interpolated.cuda()

        # Calculate critic score of interpolated examples
        prob_interpolated = 0
        for idx, n in enumerate(self.network):
            prob_interpolated += self.D[idx](interpolated[:,n,:])

        # Calculate gradients of the score with respect to examples
        gradients = torch_grad(outputs=prob_interpolated, inputs=interpolated,
                               grad_outputs=torch.ones(prob_interpolated.size()).cuda()
                               if self.use_cuda else torch.ones(prob_interpolated.size()),
                               create_graph=True, retain_graph=True)[0]

        # Gradients have the same 3-D shape as the input batch,
        # so flatten to easily take norm per example in batch
        gradients = gradients.view(batch_size, -1)
        self.losses['gradient_norm'].append(gradients.norm(2, dim=1).mean().data.item())

        # Derivatives of the gradient close to 0 can cause problems because of
        # the square root, so manually calculate norm and add epsilon
        gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

        # Return gradient penalty
        return self.gp_weight * ((gradients_norm - 1) ** 2).mean()

    def _train_epoch(self, data_loader):
        """Run one adversarial epoch: a critic step per batch, a generator step
        every ``critic_iterations`` steps, and a progress plot every
        ``print_every`` steps. Steps the generator scheduler once at the end.
        """
        for i, data in enumerate(data_loader):
            self.num_steps += 1

            batch_size = data.size()[0]
            data = Variable(data)
            if self.use_cuda:
                data = data.cuda()
            latent_samples = Variable(self.G.sample_latent(batch_size))
            if self.use_cuda:
                latent_samples = latent_samples.cuda()

            self._critic_train_iteration(data, latent_samples, self.tau)
            # Only update generator every |critic_iterations| iterations
            if self.num_steps % self.critic_iterations == 0:
                self._generator_train_iteration(data, latent_samples, self.tau)

            # Visualization
            if self.num_steps % self.print_every == 0 and self.num_steps > self.critic_iterations:
                display.clear_output(wait=True)
                self.fig = plt.figure(figsize=(21, 9))
                self.fig.suptitle("Network Type: {}, Epoch: {}/{}".format(self.expname.capitalize(), self.epoch, self.epochs), fontsize=24)
                # Left column: two stacked loss plots; right column: PCA scatter
                gs0 = self.fig.add_gridspec(1, 2)
                gs1 = gs0[0].subgridspec(2, 1)
                ax1 = self.fig.add_subplot(gs1[0])
                ax2 = self.fig.add_subplot(gs1[1])
                ax3 = self.fig.add_subplot(gs0[1])

                ax1.plot(self.losses['D'])
                ax1.set_xlabel('Iterations', fontsize=18)
                ax1.set_ylabel('Discriminator Loss', fontsize=18)

                # Generate from the fixed latents so successive plots are comparable
                self.generated_data = self.G(self.fixed_latents, self.tau).detach().cpu().data.numpy().argmax(2)
                self.losses['energy_statistics'].append(energy_statistics(self.sampled_data, self.generated_data))
                ax2.plot(self.print_every*np.arange(len(self.losses['energy_statistics'])), self.losses['energy_statistics'])
                ax2.set_xlabel('Iterations', fontsize=18)
                ax2.set_ylabel('Energy Statistics', fontsize=18)

                generated_dots = self.pca.transform(self.generated_data)
                ax3.scatter(self.true_dots[:,0], self.true_dots[:,1])
                ax3.scatter(generated_dots[:,0], generated_dots[:,1])
                plt.show()
                self.training_progress_images.append(img_as_ubyte(self._draw_img_grid(generated_dots)))
        self.G_sch.step()

    def _pretrain_epoch(self, data_loader):
        """Run one critic-only epoch with ``tau=1``; nothing is recorded or plotted."""
        for i, data in enumerate(data_loader):
            batch_size = data.size()[0]
            data = Variable(data)
            if self.use_cuda:
                data = data.cuda()
            latent_samples = Variable(self.G.sample_latent(batch_size))
            if self.use_cuda:
                latent_samples = latent_samples.cuda()
            self._critic_train_iteration(data, latent_samples, tau=1, pretrain=True)

    def _draw_img_grid(self, generated_dots):
        """Render the real-vs-generated PCA scatter off-screen and return it as
        an ``(height, width, 3)`` uint8 array (one GIF frame).
        """
        fig = plt.figure(figsize=(5, 5))
        plt.scatter(self.true_dots[:,0], self.true_dots[:,1])
        plt.scatter(generated_dots[:,0], generated_dots[:,1])
        plt.title('Epoch {}'.format(self.epoch), fontsize=12)
        fig.canvas.draw()       # draw the canvas, cache the renderer
        # NOTE(review): tostring_rgb is reported as deprecated in recent
        # Matplotlib releases -- confirm against the pinned version.
        image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')
        image  = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
        plt.close()
        return image

    def train(self, data_loader, test_data_loader, expdesc, epochs, pretrain_epochs):
        """Pretrain the critics, run adversarial training and save all artefacts.

        Outputs are written to ``./results/<expdesc>/``: ``<expname>.pickle``
        (loss history), ``<expname>_G.pt`` / ``<expname>_D.pt`` (state dicts)
        and, when at least one progress plot was produced, ``<expname>.gif``
        and ``<expname>.png``.

        :param data_loader: iterable of training batches.
        :param test_data_loader: iterable yielding at least 8 batches; they form
            the fixed reference sample for the plots and energy statistics.
        :param str expdesc: name of the output sub-folder.
        :param int epochs: number of adversarial epochs.
        :param int pretrain_epochs: number of critic-only epochs run first.
        """
        # Fix latents to see how generation improves during training
        it_data = iter(test_data_loader)
        sampled_data = [next(it_data).data.numpy() for i in range(8)]
        self.sampled_data = np.concatenate(sampled_data, axis=0).argmax(2)
        self.fixed_latents = Variable(self.G.sample_latent(self.sampled_data.shape[0]))
        if self.use_cuda:
            self.fixed_latents = self.fixed_latents.cuda()
        self.pca.fit(self.sampled_data)
        self.true_dots = self.pca.transform(self.sampled_data)
        self.training_progress_images = []

        display.clear_output(wait=True)
        for epoch in range(pretrain_epochs):
            self._pretrain_epoch(data_loader)
            print("Pretraining... Progress {:.2f}%".format((epoch+1)/pretrain_epochs*100), end='\r')

        self.epochs = epochs
        for epoch in range(epochs):
            self.epoch = epoch + 1
            # Linear annealing: tau goes from just below 1 down to 0.1 in the last epoch
            self.tau = 1 - 0.9*(self.epoch/self.epochs)
            self._train_epoch(data_loader)

        self.expdesc = expdesc
        newpath = './results/{}/'.format(self.expdesc)
        if not os.path.exists(newpath):
            os.makedirs(newpath)
        # Short runs (fewer than print_every steps) never draw a progress plot:
        # there are then no GIF frames and no figure. Skip those two artefacts
        # instead of failing before the losses and weights are written.
        if self.training_progress_images:
            imageio.mimsave(newpath+'{}.gif'.format(self.expname), self.training_progress_images,
                            format='GIF', duration=10.0 / len(self.training_progress_images))
        with open(newpath+'{}.pickle'.format(self.expname), 'wb') as handle:
            pickle.dump(self.losses, handle, protocol=pickle.HIGHEST_PROTOCOL)
        last_fig = getattr(self, 'fig', None)
        if last_fig is not None:
            last_fig.savefig(newpath+'{}.png'.format(self.expname))
        torch.save(self.G.state_dict(), newpath+'{}.pt'.format(self.expname+'_G'))
        torch.save(self.D.state_dict(), newpath+'{}.pt'.format(self.expname+'_D'))

In [None]:
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator


def scatter_plot(real_ebd, fake_ebd, save_path=None):
    '''Overlay the real and fake 2-D embeddings in a single scatter plot.

    Both inputs must be 2-D arrays with the same number of rows and exactly
    two columns (x and y coordinate).

    :param (ndarray, float64) real_ebd: 2-dim embedding of the real samples.
    :param (ndarray, float64) fake_ebd: 2-dim embedding of the fake samples.
    :param (str) save_path: where to write the figure; nothing is saved when None.
    '''
    # Validation (same checks, same order, expressed as "no rule is violated")
    assert not (
        len(real_ebd.shape) != 2
        or len(fake_ebd.shape) != 2
        or real_ebd.shape[0] != fake_ebd.shape[0]
        or real_ebd.shape[1] != 2
        or fake_ebd.shape[1] != 2
    ), 'Invalid real_ebd and fake_ebd.'
    assert isinstance(save_path, (str, type(None))), 'Invalid save_path.'

    # Drawing
    tick_size = 16
    fig = plt.figure(figsize=(10, 10))
    fig.tight_layout(pad=0)
    for points, tag in ((real_ebd, 'Real'), (fake_ebd, 'Fake')):
        plt.scatter(points[:,0], points[:,1], label=tag, alpha=1)
    plt.legend(loc='upper left', fontsize=30)
    plt.xticks(fontsize=tick_size)
    plt.yticks(fontsize=tick_size)
    # Keep the y-axis offset/exponent label the same size as the tick labels
    plt.gca().yaxis.get_offset_text().set_size(tick_size)

    if save_path is None:
        return
    plt.savefig(save_path)