# GRAE tests
-> Implemented manually here because importing the package fails due to dependency problems

In [1]:
# Imports
import os
import shutil
import copy
from six.moves import cPickle as pickle #for performance

import warnings  # Ignore sklearn future warning
warnings.simplefilter(action='ignore', category=FutureWarning)

import numpy as np
import pandas as pd
import torch
import math
import scipy
import torch.nn.functional as F
import torch.nn as nn
from sklearn.manifold import TSNE
from sklearn.decomposition import FastICA, PCA
from sklearn.feature_selection import mutual_info_regression
from scipy.stats import pearsonr, spearmanr

import time

from sklearn.model_selection import train_test_split
from skbio.stats.distance import mantel

# Datasets import
from torch.utils.data import Dataset
from torchvision import transforms
import torchvision.datasets as torch_datasets
from sklearn import datasets
from sklearn.metrics import mean_squared_error
from scipy.spatial.distance import pdist, squareform
from scipy import ndimage
from scipy.stats import pearsonr
import scprep


import urllib
from scipy.io import loadmat

# Model imports
import umap
import phate

# Plots
import matplotlib.pyplot as plt

ModuleNotFoundError: No module named 'skbio'

In [None]:
# Experiment
# Experiment parameters
RUNS = 10  # Number of training runs per (model, dataset) pair
FIT_RATIO = .8  # Train split ratio passed to the dataset constructors
DATASETS = ['Embryoid']
MODELS = ['AE','GRAE_10', 'GRAE_50', 'GRAE_100', 'Umap_t']

# Model parameters (defaults of the AE constructor)
BATCH = 200
LR = .0001
WEIGHT_DECAY = 1
EPOCHS = 200

# Dataset parameters
FIT_DEFAULT = .8  # Ratio of data to use for training
SAMPLE = 10000  # Number of points to sample from synthetic manifolds
EB_COMPONENTS = 500  # Number of principal components to use of the EB differentiation dataset

# Seeds and base path for data
BASEPATH = '/yunity/arusty/Graph-Manifold-Alignment/Results/Grae/data'
SEED = 7512183 # Used for data generation and splits

# 20 Random states for different training runs. Add more if you need more runs
RANDOM_STATES = [36087, 63286, 52270, 10387, 40556, 52487, 26512, 28571, 33380,
                9369, 28478,  4624, 29114, 41915,  6467,  4216, 16025, 34823,
                29854, 23853]

# Set seed for both torch and numpy
np.random.seed(SEED)
torch.manual_seed(SEED)

# Create directory for data. makedirs also creates missing parent directories
# and does not fail when the directory already exists.
os.makedirs(BASEPATH, exist_ok=True)

# Create directory for results as a sibling of the data directory. This is
# the same path the previous `BASEPATH[:-4] + "results"` slice produced, but
# it no longer depends on BASEPATH ending in a four-character folder name.
RESULTS_PATH = os.path.join(os.path.dirname(BASEPATH), 'results')
os.makedirs(RESULTS_PATH, exist_ok=True)

# Utils

In [None]:
# Utils
# Results Logger
class Book():
  """Validated log of metric rows, one row per (model, dataset, run, split).

  Rows are stored as plain lists in ``self.log``; ``get_df`` turns them into
  a DataFrame whose columns are the four signature fields followed by the
  declared metrics, in declaration order.
  """

  def __init__(self, datasets, models, metrics):
    self.col = ['model', 'dataset', 'run', 'split'] + metrics
    self.log = list()
    self.models, self.datasets = models, datasets
    self.splits = ('train', 'test')
    self.metrics = metrics

  def add_entry(self, model, dataset, run, split, **kwargs):
    """Validate one result and append it to the log.

    ``kwargs`` must contain exactly the declared metrics; they are stored in
    the declared order regardless of the order they were passed in.
    """
    self.check(model, dataset, split)
    self.check_metrics(kwargs)

    row = [model, dataset, run, split]
    row.extend(kwargs[name] for name in self.metrics)

    if len(row) != len(self.col):
      raise Exception('Entry size is wrong.')

    self.log.append(row)

  def check(self, model, dataset, split):
    """Raise if model, dataset or split is not one of the declared values."""
    candidates = (
        (model, self.models, 'model'),
        (dataset, self.datasets, 'dataset'),
        (split, self.splits, 'split'),
    )
    for value, allowed, label in candidates:
      if value not in allowed:
        raise Exception(f'Invalid {label} name.')

  def check_metrics(self, kwargs):
    """Raise unless ``kwargs`` has the declared number of metrics, all known."""
    if len(kwargs) != len(self.metrics):
      raise Exception('Wrong number of metrics.')

    for name in kwargs:
      if name not in self.metrics:
        raise Exception(f'Trying to add undeclared metric {name}')

  def get_df(self):
    """Return the log as a DataFrame with columns ``self.col``."""
    return pd.DataFrame.from_records(self.log, columns=self.col)


def refine_df(df, df_metrics):
  """Average results over runs and add a rank column next to each metric.

  ``df`` is grouped by (split, dataset, model), averaged, stripped of the
  ``run`` column and rounded to 4 decimals. For every metric in
  ``df_metrics`` a ``<metric>_rank`` column is inserted right after it,
  ranking models within each (split, dataset) pair. ``reconstruction`` and
  metrics whose name starts with ``mrre_`` rank lower-is-better; all other
  metrics rank higher-is-better. Ties share the minimum rank.
  """
  summary = (df.groupby(['split', 'dataset', 'model'])
               .mean()
               .drop(columns=['run'])
               .round(4))

  for metric in df_metrics:
    lower_is_better = (metric == 'reconstruction'
                       or metric.split('_')[0] == 'mrre')

    position = summary.columns.get_loc(metric) + 1
    ranks = summary.groupby(['split', 'dataset'])[metric].rank(
        method='min', ascending=lower_is_better)
    summary.insert(loc=position, column=f'{metric}_rank', value=ranks)

  return summary


def save_dict(di_, filename_):
    """Pickle ``di_`` into the file ``filename_``, overwriting it if present."""
    with open(filename_, mode='wb') as sink:
        pickle.dump(di_, sink)

def load_dict(filename_):
    """Return the object unpickled from the file ``filename_``.

    NOTE: unpickling executes code embedded in the file, so only load files
    produced by a trusted source (e.g. ``save_dict`` in this notebook).
    """
    with open(filename_, mode='rb') as source:
        return pickle.load(source)

def slice_3D(X, Y, idx, axis, p=1):
  """Split (X, Y) into a remainder and a held-out slice.

  Rows listed in ``idx`` are candidates for the slice; each candidate is
  kept with probability ``p`` (drawn from numpy's global RNG). Returns
  ``(X_rest, Y_rest, X_slice, Y_slice)``.
  """
  # The selected column is not used afterwards; the lookup is kept so an
  # invalid `axis` still raises exactly as before.
  _ = X[:, axis]

  n_points = X.shape[0]

  # 1 for rows that belong to the requested slice, 0 elsewhere
  selected = np.zeros(shape=n_points)
  selected[idx] = 1

  # Random thinning of the slice; with p=1 every candidate is kept
  keep = np.random.choice(a=[False, True], size=(n_points,), p=[1-p, p])

  held_out = np.logical_and(selected, keep)
  remaining = np.logical_not(held_out)

  return X[remaining], Y[remaining], X[held_out], Y[held_out]

def make_holes(x, y, n=12, eps_range=(.2, .5), seed=SEED):
  """Cut ``n`` ball-shaped holes out of a point cloud to build a test split.

  ``n`` rows of ``x`` are drawn without replacement as hole centers; every
  point closer (Euclidean ``pdist``) than a per-hole radius, drawn uniformly
  from ``eps_range``, goes to the test split.

  Args:
    x: 2-D array of points, one row per sample.
    y: per-sample targets, indexed the same way as ``x``.
    n: number of holes.
    eps_range: (low, high) bounds for the hole radii.
    seed: seed for numpy's global RNG. Previously this argument was ignored
      and the module-level SEED was always used; the default is still SEED,
      so default calls are unchanged.

  Returns:
    ``(x_train, y_train, x_test, y_test)``.

  Note: this reseeds numpy's global RNG as a side effect.
  """
  np.random.seed(seed)

  hole_idx = np.random.choice(x.shape[0], size=n, replace=False)
  d = squareform(pdist(x))
  eps_list = np.random.uniform(eps_range[0], eps_range[1], n)

  # Indices falling inside each hole (the center is always included since
  # its distance to itself is 0)
  in_hole = [np.flatnonzero(d[center] < eps)
             for center, eps in zip(hole_idx, eps_list)]

  if in_hole:
    test_idx = np.unique(np.concatenate(in_hole))
  else:
    # n == 0: no holes, everything stays in the training split
    test_idx = np.array([], dtype=int)

  train_idx = np.full(fill_value=True, shape=x.shape[0])
  train_idx[test_idx] = False

  return x[train_idx], y[train_idx], x[test_idx], y[test_idx]


def plot_3D(x, y, z, c):
  """Show a 3-D scatter plot of (x, y, z) colored by ``c`` (jet colormap).

  The figure is created empty and given a single 3-D axes. The previous
  version called ``plt.subplots`` first, which put an extra 2-D axes on the
  same figure as the 3-D one.
  """
  fig = plt.figure(figsize=(10, 10))
  a1 = fig.add_subplot(111, projection='3d')
  a1.scatter(x, y, z, c=c, cmap='jet', s=2)
  a1.view_init(elev=10, azim=90)

  plt.show()

def plot_3D_grey(x_train, y_train, z_train, x_test, y_test, z_test, c):
  """Show train points in translucent grey and test points colored by ``c``.

  The figure is created empty and given a single 3-D axes. The previous
  version called ``plt.subplots`` first, which put an extra 2-D axes on the
  same figure as the 3-D one.
  """
  fig = plt.figure(figsize=(10, 10))
  a1 = fig.add_subplot(111, projection='3d')
  a1.scatter(x_train, y_train, z_train, color='grey', s=2, alpha=.5)
  a1.scatter(x_test, y_test, z_test, c=c, cmap='jet', s=2)
  a1.view_init(elev=10, azim=90)

  plt.show()

# Models

# AE 

In [None]:
# Vanilla AE
# Run on the first CUDA device when one is available, otherwise on the CPU
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

class ConcatDataset(torch.utils.data.Dataset):
    """Zip several indexable datasets: item ``i`` is the tuple of their ``i``-th items.

    The length is that of the shortest wrapped dataset.
    """

    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        items = [source[i] for source in self.datasets]
        return tuple(items)

    def __len__(self):
        return min(map(len, self.datasets))



# AE building blocks
class Encoder_MLP(nn.Module):
    """MLP encoder: three ReLU hidden layers followed by a linear map to ``z_dim``."""

    def __init__(self, input_dim, hidden_dim1, hidden_dim2, hidden_dim3, z_dim):
        super().__init__()
        # Attribute names and creation order are kept: they define the
        # state_dict keys and the order in which weights are initialised.
        self.linear = nn.Linear(input_dim, hidden_dim1)
        self.linear2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.linear3 = nn.Linear(hidden_dim2, hidden_dim3)
        self.mu = nn.Linear(hidden_dim3, z_dim)

    def forward(self, x):
        """Return the latent code for ``x`` (no activation on the output)."""
        activation = x
        for layer in (self.linear, self.linear2, self.linear3):
            activation = F.relu(layer(activation))
        return self.mu(activation)

class Decoder_MLP(nn.Module):
    """MLP decoder: three ReLU hidden layers and a linear output layer.

    When ``sigmoid_act`` is set, a sigmoid is applied to the output.
    """

    def __init__(self, z_dim, hidden_dim1, hidden_dim2, hidden_dim3, output_dim, sigmoid_act = False):
        super().__init__()
        # Attribute names and creation order are kept: they define the
        # state_dict keys and the order in which weights are initialised.
        self.linear = nn.Linear(z_dim, hidden_dim1)
        self.linear2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.linear3 = nn.Linear(hidden_dim2, hidden_dim3)
        self.out = nn.Linear(hidden_dim3, output_dim)
        self.relu = nn.ReLU()  # not used by forward(); kept as a public attribute
        self.sigmoid_act = sigmoid_act

    def forward(self, x):
        """Decode latent codes ``x`` back to the output space."""
        activation = x
        for layer in (self.linear, self.linear2, self.linear3):
            activation = F.relu(layer(activation))

        raw_output = self.out(activation)
        if self.sigmoid_act == False:
            return raw_output
        return F.sigmoid(raw_output)

class AE_MLP(nn.Module):
    """Chains an encoder and a decoder into a single module."""

    def __init__(self, enc, dec):
        super().__init__()
        self.enc, self.dec = enc, dec

    def forward(self, x):
        """Return ``(reconstruction, latent)`` for the input ``x``."""
        code = self.enc(x)
        return self.dec(code), code


# AE main class
class AE():
    """Autoencoder class with sklearn interface.

    Wraps an MLP encoder (input -> 800 -> 400 -> 200 -> 2) and the mirrored
    decoder, trained with Adam on an MSE reconstruction loss.

    Args:
        input_size: dimension of the input (and reconstructed) vectors.
        random_state: seed passed to ``torch.manual_seed`` at the start of
            ``fit``.
        track_rec: when True, the reconstruction MSE on the training data is
            appended to ``self.loss`` after every epoch.
        AE_wrapper: module class called as ``AE_wrapper(encoder, decoder,
            **kwargs)``.
        batch_size, lr, weight_decay, epochs: training hyper-parameters.
        reduction: ``reduction`` argument of ``nn.MSELoss``.

    NOTE(review): the network weights are created in ``__init__``, before
    ``fit`` seeds torch, so initial weights depend on the global RNG state at
    construction time rather than on ``random_state``.
    """
    def __init__(self, input_size, random_state=SEED, track_rec=False,
                 AE_wrapper=AE_MLP, batch_size=BATCH, lr=LR,
                 weight_decay=WEIGHT_DECAY, reduction='sum', epochs=EPOCHS, **kwargs):
        layer_1 = 800
        layer_2 = 400
        layer_3 = 200
        self.lr = lr
        self.epochs = epochs
        self.batch_size = batch_size
        self.weight_decay = weight_decay
        self.encoder = Encoder_MLP(input_size, layer_1, layer_2, layer_3, 2)
        self.decoder = Decoder_MLP(2, layer_3, layer_2, layer_1, input_size)
        self.model = AE_wrapper(self.encoder, self.decoder, **kwargs)
        self.model = self.model.float().to(device)

        self.criterion = nn.MSELoss(reduction=reduction)

        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.lr,
                                          weight_decay=self.weight_decay)
        self.loss = list()
        self.track_rec = track_rec
        self.random_state = random_state

    def fit(self, x):
        """Train the autoencoder on dataset ``x``.

        ``x`` is iterated through a DataLoader and must yield
        ``(data, target)`` pairs. When ``track_rec`` is set it must also
        provide ``numpy()`` returning a ``(data, targets)`` pair.
        """
        # Reproducibility
        torch.manual_seed(self.random_state)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        loader = torch.utils.data.DataLoader(x, batch_size=self.batch_size,
                                             shuffle=True)

        # The numpy view of the training data does not change between
        # epochs, so fetch it once instead of once per epoch.
        if self.track_rec:
            x_np, _ = x.numpy()

        for epoch in range(self.epochs):
            # transform()/inverse_transform() below switch the model to eval
            # mode; switch back at the start of every epoch so training never
            # silently continues in eval mode.
            self.model.train()

            for data, _ in loader:
                data = data.to(device)

                self.optimizer.zero_grad()
                x_hat, _ = self.model(data)
                loss = self.criterion(data, x_hat)
                loss.backward()
                self.optimizer.step()

            if self.track_rec:
                x_hat = self.inverse_transform(self.transform(x))
                self.loss.append(mean_squared_error(x_np, x_hat))

    def transform(self, x):
        """Encode dataset ``x`` and return the latent codes as one numpy array."""
        self.model.eval()
        loader = torch.utils.data.DataLoader(x, batch_size=self.batch_size,
                                             shuffle=False)
        # No gradients are needed for inference; skipping the autograd graph
        # saves memory and leaves the returned values unchanged.
        with torch.no_grad():
            z = [self.encoder(batch.to(device)).cpu().numpy()
                 for batch, _ in loader]

        return np.concatenate(z)

    def fit_transform(self, x):
        """Fit on ``x`` and return its latent codes."""
        self.fit(x)
        return self.transform(x)

    def inverse_transform(self, z):
        """Decode latent codes ``z`` and return the reconstructions as a numpy array.

        NOTE(review): relies on ``NumpyDataset``, which is not defined in the
        visible part of this file - confirm it is available at run time.
        """
        self.model.eval()
        z = NumpyDataset(z)
        loader = torch.utils.data.DataLoader(z, batch_size=self.batch_size,
                                             shuffle=False)
        with torch.no_grad():
            x_hat = [self.decoder(batch.to(device)).cpu().numpy()
                     for batch in loader]

        return np.concatenate(x_hat)


class ManifoldLoss(nn.Module):
    """Summed-MSE reconstruction loss plus ``lam`` times a summed-MSE embedding loss.

    The last computed loss is kept in ``self.loss`` so that ``backward`` can
    be called on the criterion itself.
    """

    def __init__(self, lam):
      super().__init__()
      self.lam = lam
      self.MSE = nn.MSELoss(reduction='sum')
      self.loss = None

    def forward(self, x, y, z, emb):
      """Return ``MSE(x, y) + lam * MSE(z, emb)`` and remember it."""
      reconstruction_term = self.MSE(x, y)
      embedding_term = self.MSE(z, emb)
      self.loss = reconstruction_term + self.lam * embedding_term
      return self.loss

    def backward(self):
      """Backpropagate the most recently computed loss."""
      self.loss.backward()

    def decay_lam(self, factor):
      """Multiply the embedding weight ``lam`` by ``factor``."""
      self.lam *= factor



NameError: name 'torch' is not defined

# GRAE definition

In [None]:

class ManifoldNet(AE):
  """Base class for GRAE.

  An autoencoder whose latent codes are pulled towards a target embedding of
  the training data. The target is produced by ``embedder`` (a class exposing
  ``fit_transform``) unless one was supplied through ``set_embedding``.

  Args:
    input_size: dimension of the input vectors.
    embedder: embedding class, called as
      ``embedder(**kwargs, random_state=random_state)``.
    random_state: seed for torch and for the embedder.
    track_rec: when True, the reconstruction MSE on the training data is
      appended to ``self.loss`` after every epoch.
    lam: weight of the embedding term in the loss.
    lam_decay: factor applied to ``lam`` after every epoch.
    **kwargs: extra keyword arguments forwarded to ``embedder``.
  """
  def __init__(self, input_size, embedder, random_state, track_rec=False, lam=10, lam_decay=1, **kwargs):
      super().__init__(input_size, random_state, track_rec)
      self.criterion = ManifoldLoss(lam)
      self.emb = None
      self.targets = None
      self.precomputed = False
      self.embedder_args = kwargs
      self.embedder = embedder
      self.lam_decay = lam_decay

  def fit(self, x):
    """Compute (or reuse) the target embedding, then train on dataset ``x``.

    ``x`` must provide ``numpy()`` returning ``(data, targets)`` and yield
    ``(data, target)`` pairs when indexed.
    """
    # Reproducibility
    torch.manual_seed(self.random_state)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    x_np, self.targets = x.numpy()

    # Compute embedding target
    if not self.precomputed:
      embedder_m = self.embedder(**self.embedder_args,
                                 random_state=self.random_state)

      emb = embedder_m.fit_transform(x_np)

      # Normalize
      emb = scipy.stats.zscore(emb)

      self.emb = emb

    # Loader to iterate over both data batches and target embedding batches
    data = ConcatDataset(torch.from_numpy(self.emb).float(), x)

    loader = torch.utils.data.DataLoader(data,
                                         batch_size=self.batch_size,
                                         shuffle=True)

    for epoch in range(self.epochs):
      # transform()/inverse_transform() below switch the model to eval mode;
      # switch back at the start of every epoch so training never silently
      # continues in eval mode.
      self.model.train()

      for embedding, batch in loader:
          x_in = batch[0].to(device)
          embedding = embedding.to(device)

          self.optimizer.zero_grad()
          x_hat, z_mu = self.model(x_in)

          # The criterion stores the loss; backward() is called on it
          self.criterion(x_in, x_hat, z_mu, embedding)
          self.criterion.backward()
          self.optimizer.step()

      self.criterion.decay_lam(self.lam_decay)

      if self.track_rec:
        x_hat = self.inverse_transform(self.transform(x))
        self.loss.append(mean_squared_error(x_np, x_hat))

  def plot_latent(self):
    """Scatter-plot the target embedding colored by the training targets."""
    plt.scatter(*self.emb.T, c=self.targets, cmap="jet")
    plt.show()

  def set_embedding(self, emb):
    """Use ``emb`` as the target embedding instead of running the embedder.

    The array is used as given: unlike a computed embedding it is not
    z-scored by ``fit``.
    """
    self.emb = emb
    self.precomputed = True



# Variants of GRAE
class GRAE(ManifoldNet):
  """Vanilla GRAE: ManifoldNet with a PHATE target embedding."""
  def __init__(self, input_size, random_state=SEED, track_rec=False,
               lam=10, lam_decay=1, t='auto', knn=20, n_landmark=2000,
               mds='metric'):
    # Keyword arguments forwarded to the phate.PHATE constructor
    phate_options = dict(t=t,
                         knn=knn,
                         n_jobs=-1,
                         verbose=0,
                         n_landmark=n_landmark,
                         mds=mds)

    super().__init__(input_size=input_size,
                     random_state=random_state,
                     embedder=phate.PHATE,
                     track_rec=track_rec,
                     lam=lam,
                     lam_decay=lam_decay,
                     **phate_options)


class GRAE_UMAP(ManifoldNet):
  """UMAP GRAE, as presented in the supplement.

  ``n_neighbors`` is forwarded to the ``umap.UMAP`` constructor.
  """
  def __init__(self, input_size, random_state=SEED,
               track_rec=False, lam=10, lam_decay=1, n_neighbors=15):
    # lam_decay used to be accepted but never forwarded, so the base class
    # always fell back to its default of 1 (no decay).
    super().__init__(input_size=input_size,
                      random_state=random_state,
                      embedder=umap.UMAP,
                      track_rec=track_rec,
                      lam=lam,
                      lam_decay=lam_decay,
                      n_neighbors=n_neighbors)


class GRAE_TSNE(ManifoldNet):
  """t-SNE GRAE, as presented in the supplement.

  ``perplexity`` is forwarded to the ``TSNE`` constructor.
  """
  def __init__(self, input_size, random_state=SEED, track_rec=False,
               lam=10, lam_decay=1, perplexity=30):
    # lam_decay used to be accepted but never forwarded, so the base class
    # always fell back to its default of 1 (no decay).
    super().__init__(input_size=input_size,
                      random_state=random_state,
                      embedder=TSNE,
                      track_rec=track_rec,
                      lam=lam,
                      lam_decay=lam_decay,
                      n_jobs=-1,
                      verbose=0,
                      perplexity=perplexity)

# Metrics

In [None]:
# Unsupervised Metrics
# Code adapted from the Topological autoencoders paper
class MeasureCalculator():
    """Unsupervised quality measures comparing data space X with latent space Z.

    Pairwise Euclidean distance matrices, the ``k_max`` nearest neighbours
    and the full rank matrices of both spaces are computed once at
    construction; the individual measures read those caches.

    Args:
        X: data-space samples, one row per sample.
        Z: latent-space samples, row-aligned with ``X``.
        X_hat: reconstruction of ``X``, or None when the model has none.
        k_max: largest neighbourhood size that will be requested.

    NOTE(review): the ``measures`` registry is commented out, so the three
    ``compute_*`` methods below fail on ``self.measures`` unless a registry
    is attached elsewhere. Use ``get_metrics`` instead.
    """
    # measures = MeasureRegistrator()

    def __init__(self, X, Z, X_hat, k_max=20):
        self.k_max = k_max
        self.X = X
        self.X_hat = X_hat
        self.pairwise_X = squareform(pdist(X))
        self.pairwise_Z = squareform(pdist(Z))
        self.neighbours_X, self.ranks_X = \
            self._neighbours_and_ranks(self.pairwise_X, k_max)
        self.neighbours_Z, self.ranks_Z = \
            self._neighbours_and_ranks(self.pairwise_Z, k_max)

    @staticmethod
    def _neighbours_and_ranks(distances, k):
        """
        Inputs:
        - distances,        distance matrix [n times n],
        - k,                number of nearest neighbours to consider
        Returns:
        - neighbourhood,    contains the sample indices (from 0 to n-1) of kth nearest neighbor of current sample [n times k]
        - ranks,            contains the rank of each sample to each sample [n times n], whereas entry (i,j) gives the rank that sample j has to i (the how many 'closest' neighbour j is to i)
        """
        # Warning: this is only the ordering of neighbours that we need to
        # extract neighbourhoods below. The ranking comes later!
        indices = np.argsort(distances, axis=-1, kind='stable')

        # Extract neighbourhoods. Column 0 is skipped: it is the smallest
        # distance in the row, i.e. normally the sample itself.
        neighbourhood = indices[:, 1:k+1]

        # Convert this into ranks (finally)
        ranks = indices.argsort(axis=-1, kind='stable')

        return neighbourhood, ranks

    def get_X_neighbours_and_ranks(self, k):
        """Return the ``k`` nearest neighbours and the full rank matrix in X."""
        return self.neighbours_X[:, :k], self.ranks_X

    def get_Z_neighbours_and_ranks(self, k):
        """Return the ``k`` nearest neighbours and the full rank matrix in Z."""
        return self.neighbours_Z[:, :k], self.ranks_Z

    def compute_k_independent_measures(self):
        return {key: fn(self) for key, fn in
                self.measures.get_k_independent_measures().items()}

    def compute_k_dependent_measures(self, k):
        return {key: fn(self, k) for key, fn in
                self.measures.get_k_dependent_measures().items()}

    def compute_measures_for_ks(self, ks):
        return {
            key: np.array([fn(self, k) for k in ks])
            for key, fn in self.measures.get_k_dependent_measures().items()
        }

    # @measures.register(False)
    def stress(self):
        """Square root of the squared distance differences, normalised by the squared Z distances."""
        sum_of_squared_differences = \
            np.square(self.pairwise_X - self.pairwise_Z).sum()
        sum_of_squares = np.square(self.pairwise_Z).sum()

        return np.sqrt(sum_of_squared_differences / sum_of_squares)

    # @measures.register(False)
    def rmse(self):
        """Root mean squared difference between the two distance matrices."""
        n = self.pairwise_X.shape[0]
        sum_of_squared_differences = np.square(
            self.pairwise_X - self.pairwise_Z).sum()
        return np.sqrt(sum_of_squared_differences / n**2)

    @staticmethod
    def _trustworthiness(X_neighbourhood, X_ranks, Z_neighbourhood,
                         Z_ranks, n, k):
        '''
        Calculates the trustworthiness measure between the data space `X`
        and the latent space `Z`, given a neighbourhood parameter `k` for
        defining the extent of neighbourhoods.
        '''

        result = 0.0

        # Calculate number of neighbours that are in the $k$-neighbourhood
        # of the latent space but not in the $k$-neighbourhood of the data
        # space.
        for row in range(X_ranks.shape[0]):
            missing_neighbours = np.setdiff1d(
                Z_neighbourhood[row],
                X_neighbourhood[row]
            )

            for neighbour in missing_neighbours:
                result += (X_ranks[row, neighbour] - k)

        return 1 - 2 / (n * k * (2 * n - 3 * k - 1) ) * result

    # @measures.register(True)
    def trustworthiness(self, k):
        """Trustworthiness of the latent ``k``-neighbourhoods with respect to X."""
        X_neighbourhood, X_ranks = self.get_X_neighbours_and_ranks(k)
        Z_neighbourhood, Z_ranks = self.get_Z_neighbours_and_ranks(k)
        n = self.pairwise_X.shape[0]
        return self._trustworthiness(X_neighbourhood, X_ranks, Z_neighbourhood,
                                     Z_ranks, n, k)

    # @measures.register(True)
    def continuity(self, k):
        '''
        Calculates the continuity measure between the data space `X` and the
        latent space `Z`, given a neighbourhood parameter `k` for setting up
        the extent of neighbourhoods.

        This is just the 'flipped' variant of the 'trustworthiness' measure.
        '''

        X_neighbourhood, X_ranks = self.get_X_neighbours_and_ranks(k)
        Z_neighbourhood, Z_ranks = self.get_Z_neighbours_and_ranks(k)
        n = self.pairwise_X.shape[0]
        # Notice that the parameters have to be flipped here.
        return self._trustworthiness(Z_neighbourhood, Z_ranks, X_neighbourhood,
                                     X_ranks, n, k)

    # @measures.register(True)
    def neighbourhood_loss(self, k):
        '''
        Calculates the neighbourhood loss quality measure between the data
        space `X` and the latent space `Z` for some neighbourhood size $k$
        that has to be pre-defined.
        '''

        X_neighbourhood, _ = self.get_X_neighbours_and_ranks(k)
        Z_neighbourhood, _ = self.get_Z_neighbours_and_ranks(k)

        result = 0.0
        n = self.pairwise_X.shape[0]

        for row in range(n):
            shared_neighbours = np.intersect1d(
                X_neighbourhood[row],
                Z_neighbourhood[row],
                assume_unique=True
            )

            result += len(shared_neighbours) / k

        return 1.0 - result / n

    # @measures.register(True)
    def rank_correlation(self, k):
        '''
        Calculates the spearman rank correlation of the data
        space `X` with respect to the latent space `Z`, subject to its $k$
        nearest neighbours.
        '''

        X_neighbourhood, X_ranks = self.get_X_neighbours_and_ranks(k)
        Z_neighbourhood, Z_ranks = self.get_Z_neighbours_and_ranks(k)

        n = self.pairwise_X.shape[0]
        # Gather, for every sample, the X- and Z-ranks of its X-neighbours
        gathered_ranks_x = []
        gathered_ranks_z = []
        for row in range(n):
            # we go from X to Z here:
            for neighbour in X_neighbourhood[row]:
                gathered_ranks_x.append(X_ranks[row, neighbour])
                gathered_ranks_z.append(Z_ranks[row, neighbour])
        rs_x = np.array(gathered_ranks_x)
        rs_z = np.array(gathered_ranks_z)
        coeff, _ = spearmanr(rs_x, rs_z)

        return coeff

    # @measures.register(True)
    def mrre(self, k):
        '''
        Calculates the mean relative rank error quality metric of the data
        space `X` with respect to the latent space `Z`, subject to its $k$
        nearest neighbours.

        Only the Z -> X component (relative quality of the neighbours in
        `Z`) is returned. The X -> Z component used to be computed as well
        but was never returned, so that dead loop has been removed.
        '''

        _, X_ranks = self.get_X_neighbours_and_ranks(k)
        Z_neighbourhood, Z_ranks = self.get_Z_neighbours_and_ranks(k)

        n = self.pairwise_X.shape[0]

        # Component going from the latent space to the data space, i.e.
        # the relative quality of neighbours in `Z`.
        mrre_ZX = 0.0
        for row in range(n):
            for neighbour in Z_neighbourhood[row]:
                rx = X_ranks[row, neighbour]
                rz = Z_ranks[row, neighbour]

                mrre_ZX += abs(rx - rz) / rz

        # Normalisation constant
        C = n * sum([abs(2*j - n - 1) / j for j in range(1, k+1)])
        return mrre_ZX / C

    def _global_densities(self, sigma):
        """Return the normalised Gaussian-kernel densities of X and Z.

        Both distance matrices are first scaled by their maximum; each
        density vector sums to 1.
        """
        densities = []
        for pairwise in (self.pairwise_X, self.pairwise_Z):
            scaled = pairwise / pairwise.max()
            density = np.sum(np.exp(-(scaled ** 2) / sigma), axis=-1)
            density /= density.sum(axis=-1)
            densities.append(density)
        return densities

    # @measures.register(False)
    def density_global(self, sigma=0.1):
        """Sum of absolute differences between the X and Z densities."""
        density_x, density_z = self._global_densities(sigma)
        return np.abs(density_x - density_z).sum()

    # @measures.register(False)
    def density_kl_global(self, sigma=0.1):
        """KL divergence from the Z density to the X density."""
        density_x, density_z = self._global_densities(sigma)
        return (density_x * (np.log(density_x) - np.log(density_z))).sum()

    # @measures.register(False)
    def density_kl_global_10(self):
        return self.density_kl_global(10.)

    # @measures.register(False)
    def density_kl_global_1(self):
        return self.density_kl_global(1.)

    # @measures.register(False)
    def density_kl_global_01(self):
        return self.density_kl_global(0.1)

    # @measures.register(False)
    def density_kl_global_001(self):
        return self.density_kl_global(0.01)

    # @measures.register(False)
    def density_kl_global_0001(self):
        return self.density_kl_global(0.001)

    def reconstruction(self):
        """Return the MSE between X and X_hat, or None when X_hat is None."""
        # `is None` rather than `== None`: comparing an ndarray with `==`
        # is elementwise and cannot be used as an `if` condition.
        if self.X_hat is None:
            return None

        return mean_squared_error(self.X, self.X_hat)

    def get_metrics(self, metrics):
        """Evaluate the named metrics and return them as ``{name: value}``.

        A name is resolved in this order:

        1. a name without an underscore, or one matching a method exactly
           (e.g. ``stress``, ``density_kl_global_01``), is called without
           arguments;
        2. otherwise the text after the last underscore is parsed as the
           neighbourhood size, e.g. ``rank_correlation_5`` calls
           ``rank_correlation(k=5)``.

        Raises:
            ValueError: if the suffix of a name is not an integer.
            AttributeError: if the resolved method does not exist.
        """
        results = dict()

        for metric in metrics:
            name, separator, suffix = metric.rpartition('_')

            if not separator or callable(getattr(self, metric, None)):
                results[metric] = getattr(self, metric)()
                continue

            try:
                k = int(suffix)
            except ValueError:
                raise ValueError('Invalid string metric.') from None

            results[metric] = getattr(self, name)(k=k)

        return results


# Fit Model

In [None]:
# Build dict with various model parameters

# Neighborhood parameters of manifold learners
# Per-dataset neighborhood settings for each manifold learner
PHATE_knn = {
    'Faces': {'knn': 5},
    'RotatedDigits': {'knn': 5},
    'ribbons': {'knn': 20},
    'Embryoid': {'knn': 5},
}

UMAP_n_neighbors = {
    'Faces': {'n_neighbors': 15},
    'RotatedDigits': {'n_neighbors': 15},
    'ribbons': {'n_neighbors': 20},
    'Embryoid': {'n_neighbors': 15},
}

TSNE_perp = {
    'Faces': {'perplexity': 10},
    'RotatedDigits': {'perplexity': 10},
    'ribbons': {'perplexity': 30},
    'Embryoid': {'perplexity': 10},
}

# t parameter for PHATE
t = 'auto'

ds_names = dataset_constructors.keys()


# Input size. Should be passed to AE and GRAE to adjust
# input and output layers accordingly
size_dict = {
    'RotatedDigits': {'input_size': 784},
    'ribbons': {'input_size': 3},
    'Faces': {'input_size': 4096},
    'Embryoid': {'input_size': EB_COMPONENTS},
}

# Per-dataset dicts holding both the neighborhood parameter and the input
# size. New dicts are built, so the dicts above are left untouched.
PHATE_knn_size = {name: {**options, **size_dict[name]}
                  for name, options in PHATE_knn.items()}

UMAP_n_neighbors_size = {name: {**options, **size_dict[name]}
                         for name, options in UMAP_n_neighbors.items()}

TSNE_perp_size = {name: {**options, **size_dict[name]}
                  for name, options in TSNE_perp.items()}


# As placeholder when no parameters are needed for dataset specific inits
empty_dict = {name: {} for name in ds_names}


# Parameter variable
def _model_spec(constructor, *, phate_cache, numpy, train_only,
                init_default, init_dataset):
  """Build one entry of ``params``.

  Keys: ``constructor`` (model class), ``phate_cache``, ``numpy`` (model is
  fed numpy arrays instead of dataset objects), ``train_only`` (no
  transform / inverse_transform on test data), ``init_default`` (constructor
  kwargs for all datasets), ``init_dataset`` (per-dataset constructor
  kwargs), ``FIT_DEFAULT`` (fit kwargs for all datasets) and ``fit_dataset``
  (per-dataset fit kwargs). No model currently needs fit kwargs, so the last
  two are always empty.
  """
  return dict(
      constructor=constructor,
      phate_cache=phate_cache,
      numpy=numpy,
      train_only=train_only,
      init_default=init_default,
      init_dataset=init_dataset,
      FIT_DEFAULT={},
      fit_dataset=empty_dict,
  )


params = {
    # Vanilla Umap with no transforms, use Umap_t instead
    'Umap': _model_spec(umap.UMAP, phate_cache=False, numpy=True,
                        train_only=True, init_default=dict(),
                        init_dataset=empty_dict),

    'Umap_t': _model_spec(umap.UMAP, phate_cache=False, numpy=True,
                          train_only=False, init_default=dict(),
                          init_dataset=UMAP_n_neighbors),

    # vanilla PHATE
    'PHATE': _model_spec(phate.PHATE, phate_cache=True, numpy=True,
                         train_only=True,
                         init_default=dict(verbose=0, n_jobs=-1, t=t),
                         init_dataset=PHATE_knn),

    'TSNE': _model_spec(TSNE, phate_cache=False, numpy=True,
                        train_only=True,
                        init_default=dict(verbose=0, n_jobs=-1),
                        init_dataset=empty_dict),

    # This is GRAE.
    'GRAE': _model_spec(GRAE, phate_cache=True, numpy=False,
                        train_only=False, init_default=dict(lam=1, t=t),
                        init_dataset=PHATE_knn_size),

    'GRAE_UMAP': _model_spec(GRAE_UMAP, phate_cache=False, numpy=False,
                             train_only=False, init_default=dict(lam=1),
                             init_dataset=UMAP_n_neighbors_size),

    'GRAE_TSNE': _model_spec(GRAE_TSNE, phate_cache=False, numpy=False,
                             train_only=False, init_default=dict(lam=1),
                             init_dataset=TSNE_perp_size),

    'AE': _model_spec(AE, phate_cache=False, numpy=False,
                      train_only=False, init_default={},
                      init_dataset=size_dict),
}

# TopoAE is optional: register it only when the class has been defined.
tae = False

try:
  # Only this name lookup is probed. Catching NameError specifically keeps
  # any other failure visible instead of reporting it as "not defined".
  TopoAE
except NameError:
  print('TopoAE is not defined. If you wish to use them, please refer to the Topological autoencoders subsection.')
else:
  params['TopoAE'] = dict(
      constructor=TopoAE,
      phate_cache=False,
      numpy=False,
      train_only=False,
      init_default={},
      init_dataset=size_dict,
      FIT_DEFAULT={},
      fit_dataset=empty_dict
  )
  tae = True


# Add variants of PHATE-Net, UMAP-Net and TSNE-Net
# Register '<base>_<lam>' variants of every lambda-weighted model.
# TopoAE variants are added only when TopoAE was registered.
_lam_bases = ['GRAE', 'GRAE_UMAP', 'GRAE_TSNE'] + (['TopoAE'] if tae else [])

for n in (1, 10, 50, 100, 200, 1000, 10000):
  for base in _lam_bases:
    # Deep copy so changing lam never touches the base entry
    variant = copy.deepcopy(params[base])
    variant['init_default']['lam'] = n
    params[f'{base}_{n}'] = variant

In [None]:
# Fit models
# All embeddings are saved under the embeddings folder

# Remove previous embedding folder if any
# !rm -rf 'embeddings'

# Start from a clean 'embeddings' folder
if os.path.exists('embeddings'):
    shutil.rmtree('embeddings')

os.mkdir('embeddings')

# Experiment loop: one pickle per (model, dataset, run)
for model in MODELS:
  print(f'Training {model}...')
  spec = params[model]

  os.mkdir(os.path.join('embeddings', model))

  for dataset in DATASETS:
    target = os.path.join('embeddings', model, dataset)

    os.mkdir(target)

    print(f'   On {dataset}...')
    # Training loop
    for j in range(RUNS):
      print(f'       Run {j + 1}...')

      # Fetch and split dataset
      data_train = dataset_constructors[dataset](split="train",
                                                 split_ratio=FIT_RATIO,
                                                 seed=SEED)
      data_test = dataset_constructors[dataset](split="test",
                                                split_ratio=FIT_RATIO,
                                                seed=SEED)

      data_train_np, y_train = data_train.numpy()
      data_test_np, y_test = data_test.numpy()

      # Some models take numpy arrays rather than dataset objects
      if spec['numpy']:
        data_train, data_test = data_train_np, data_test_np

      # New model, seeded with this run's random state
      m = spec['constructor'](
          **spec['init_default'],
          **spec['init_dataset'][dataset],
          random_state=RANDOM_STATES[j])

      # Benchmark fit time
      fit_start = time.time()
      z_train = m.fit_transform(data_train,
                                **spec['FIT_DEFAULT'],
                                **spec['fit_dataset'][dataset])
      fit_time = time.time() - fit_start

      if spec['train_only']:
        # T-SNE and PHATE do not have inverse transforms
        inv_train, inv_test, rec_train, rec_test = None, None, None, None
        z_test, transform_time = None, None
      else:
        # Benchmark transform time
        transform_start = time.time()
        z_test = m.transform(data_test)
        transform_time = time.time() - transform_start

        inv_train = m.inverse_transform(z_train)
        inv_test = m.inverse_transform(z_test)

        rec_train = mean_squared_error(data_train_np, inv_train)
        rec_test = mean_squared_error(data_test_np, inv_test)

      # Save embeddings; train-only models store None for the test fields
      obj = dict(z_train=z_train, z_test=z_test,
                 rec_train=rec_train, rec_test=rec_test,
                 fit_time=fit_time, transform_time=transform_time,
                 dataset_seed=SEED, run_seed=RANDOM_STATES[j])

      save_dict(obj, os.path.join(target, f'run_{j + 1}.pkl'))




# Plot Embeddings

In [None]:
# Prettier names for models and datasets

# Display names for models that carry no lambda suffix
model_name = dict(Umap_t= 'UMAP',
                  Umap='UMAP',
                  diffusion_net='Diffusion Nets',
                  PHATE='PHATE',
                  TSNE='t-SNE',
                  AE='Autoencoder')

# Display-name prefixes for the lambda-weighted families; the lambda value
# and a closing parenthesis are appended below.
base_name = dict(GRAE='GRAE (', GRAE_TSNE='GRAE t-SNE (',
                 GRAE_UMAP='GRAE UMAP (', TopoAE='TAE (')

for m in MODELS:
  if m not in model_name:
    # Split on the LAST underscore so that 'GRAE_UMAP_10' gives base
    # 'GRAE_UMAP' and lam '10'. Splitting on the first underscore gave base
    # 'GRAE' and lam 'UMAP', which mislabelled the UMAP / t-SNE variants.
    base, lam = m.rsplit('_', 1)
    model_name[m] = base_name[base] + f'{lam})'



ds_name = dict(ribbons='Swiss Roll',
               Faces='Faces',
               Embryoid='Embryoid',
               RotatedDigits='Rotated Digits')

In [None]:
# Plot embeddings
PLOT_RUN = 1  # Which run's embeddings to plot


titles = [model_name[m] for m in MODELS]
n_d = len(DATASETS)
n_m = len(MODELS)
fig, ax = plt.subplots(n_d, n_m, figsize=(n_m * 3.5, n_d * 3.5))

# plt.subplots returns a bare Axes for a 1x1 grid and a 1-D array for a
# single row or column. Use a 2-D view so every case, including one model
# on one dataset, is indexed the same way.
axes = np.asarray(ax).reshape(n_d, n_m)

for j, model in enumerate(MODELS):
  for i, dataset in enumerate(DATASETS):
    file_path = os.path.join('embeddings', model, dataset, f'run_{PLOT_RUN}.pkl')

    if os.path.exists(file_path):
      # Retrieve datasets for coloring
      data = load_dict(file_path)
      X_train = dataset_constructors[dataset](split='train',
                                              seed=data['dataset_seed'])
      X_test = dataset_constructors[dataset](split='test',
                                             seed=data['dataset_seed'])
      _, y_train = X_train.numpy()
      _, y_test = X_test.numpy()
      z_train, z_test = data['z_train'], data['z_test']
    else:
      # Filler if plot is not found
      z_train, z_test = np.array([[0, 0]]), np.array([[0, 0]])
      y_train = np.array([1])
      y_test = np.array([1])

    ax_i = axes[i, j]

    # Training points as a faint grey backdrop
    ax_i.scatter(*z_train.T, s=1.5, alpha=.2, color='grey')

    # Train-only models are saved with z_test=None: nothing to overlay
    if z_test is not None:
      ax_i.scatter(*z_test.T, c=y_test, s=15, cmap='jet')

    if i == 0:
      ax_i.set_title(f'{titles[j]}', fontsize=20, color='black')
    ax_i.set_xticks([])
    ax_i.set_yticks([])


# savefig does not create missing directories
os.makedirs('results', exist_ok=True)
plt.savefig(os.path.join('results', 'plot.png'))
plt.show()