## **README**
  
[GVAE] H. Hosoya, “Group-based Learning of Disentangled Representations with Generalizability for Novel Contents,” in Proceedings of the Twenty-Eighth International Joint Conference on Artificial Intelligence, Macao, China, 2019, pp. 2506–2513, doi: 10.24963/ijcai.2019/348.

 
[MLVAE] D. Bouchacourt, R. Tomioka, and S. Nowozin, “Multi-Level Variational Autoencoder: Learning Disentangled Representations from Grouped Observations,” arXiv:1705.08841 [cs, stat], May 2017, Accessed: Feb. 19, 2021. [Online]. Available: http://arxiv.org/abs/1705.08841.
  

# Initialize

### Import Python packages and initialize GPUs

In [None]:
from IPython import display
import os, sys, time, glob, io, pprint, re, shutil
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances
import matplotlib.pyplot as plt
import h5py
import pandas as pd
from datetime import datetime
import tensorflow as tf
tfk = tf.keras
tfkl = tfk.layers
tfkltd = tf.keras.layers.TimeDistributed
from tensorflow.keras.utils import to_categorical
from absl import app, flags
from IPython.display import clear_output
clear_output()

### Check requirements

In [None]:
# Fail fast on anything other than Python 3.
if sys.version_info[0] != 3:
    sys.exit("Python 3 required")

# Enumerate GPUs in PCI bus order so that the device index below is stable.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
# Default to GPU 7, but keep a CUDA_VISIBLE_DEVICES value that is already set
# in the environment (e.g. by a job scheduler or by the user) instead of
# silently overriding it.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "7")

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Allocate GPU memory on demand rather than grabbing all of it up front.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs", len(logical_gpus), "Logical GPUs\n\n")
    except RuntimeError as e:
        # Best effort: report the problem and carry on with default settings.
        print(e)

# Start from a clean Keras state, then build the strategy that every
# distributed cell below relies on.
tfk.backend.clear_session()
strategy = tf.distribute.MirroredStrategy()
print('num devices = %d'%strategy.num_replicas_in_sync)

### Set parameters

In [None]:
# Bag layout shared by the bagging cell and the optimization setup below.
NUM_INSTANCES = 21 # number of instances in a bag
NUM_CLASS = 10     # mnist has 10 classes

### Flatten MNIST data 28*28 to 784

In [None]:
##### TMNIST dataset #####
# Column 1 of the CSV is read as the label and columns 2.. as the pixel values
# (column 0 is presumably the sample name -- confirm against the CSV header).
df = pd.read_csv('TMNIST_Data.csv')
tmnist_table = df.to_numpy()
Tmnist_labels = tmnist_table[:, 1].astype('float32')
Tmnist_images = tmnist_table[:, 2:].astype('float32')
Tmnist_images.shape

In [None]:
# Fetch the MNIST train/test split through the Keras datasets API.
(trn_images, trn_labels), (tst_images, tst_labels) = tfk.datasets.mnist.load_data()
##### from N*H*W images to N*(H*W) float32 rows scaled by 1/255 #####
def rescale_and_flatten_images(images):
    """Flatten each image to one row and divide the pixel values by 255.

    The row length is taken from the trailing dimensions of ``images``, so
    image sizes other than 28x28 work as well.

    Args:
        images: array whose first axis indexes the images, e.g. ``[N, 28, 28]``.

    Returns:
        ``float32`` array of shape ``[N, prod(images.shape[1:])]``.
    """
    num_images = images.shape[0]
    # An explicit product (rather than -1) keeps an empty batch reshapeable.
    pixels_per_image = int(np.prod(images.shape[1:]))
    flat = images.reshape((num_images, pixels_per_image)) / 255.
    return flat.astype('float32')
# Keep the first 10000 samples of each split. Slicing before rescaling yields
# the same rows as slicing afterwards.
trn_images, trn_labels = rescale_and_flatten_images(trn_images[:10000]), trn_labels[:10000]
tst_images, tst_labels = rescale_and_flatten_images(tst_images[:10000]), tst_labels[:10000]
# Report the resulting shapes: train images/labels, then test images/labels.
for subset in (trn_images, trn_labels, tst_images, tst_labels):
    print(subset.shape)

In [None]:
# trn_images = np.concatenate([trn_images,Tmnist_images[0:10000]],axis=0)
# trn_labels = np.concatenate([trn_labels,Tmnist_labels[0:10000]],axis=0)
# tst_images = np.concatenate([tst_images,Tmnist_images[10000:20000]],axis=0)
# tst_labels = np.concatenate([tst_labels,Tmnist_labels[10000:20000]],axis=0)
# print(trn_images.shape)
# print(trn_labels.shape)
# print(tst_images.shape)
# print(tst_labels.shape)

In [None]:
# Shuffle each split in place. Images and labels share one permutation so the
# pairs stay aligned. The permutation length follows the actual number of
# samples: the previous hard-coded 20000 only matched the (commented-out)
# MNIST+TMNIST concatenation and was out of bounds otherwise, and the indexed
# results were never assigned back.
indices = np.random.permutation(trn_images.shape[0])
trn_images = trn_images[indices]
trn_labels = trn_labels[indices]
indices = np.random.permutation(tst_images.shape[0])
tst_images = tst_images[indices]
tst_labels = tst_labels[indices]

### Sort the training data by label (0-9)

In [None]:
# Reorder the training set so that samples with the same label are contiguous.
sort_idx = np.argsort(trn_labels)
trn_labels, trn_images = trn_labels[sort_idx], trn_images[sort_idx, :]
print(trn_labels.shape)
print(trn_images.shape)
# Samples per class, and the running total marking where each class ends.
class_count = np.unique(trn_labels, return_counts=True)[1]
class_cumsum = class_count.cumsum()
print('class count: ', class_count)
print('class cumsum: ', class_cumsum)

### Group MNIST instances into bags

In [None]:
# Collect, class by class, the indices of the samples that fill whole bags;
# the remainder of each class (fewer than NUM_INSTANCES samples) is dropped.
index_chunks = [np.array([])]
for kk in range(NUM_CLASS):
    # first index of class kk inside the label-sorted arrays
    class_start = class_cumsum[kk - 1] if kk > 0 else 0
    # largest multiple of NUM_INSTANCES that fits in class kk
    usable = NUM_INSTANCES * (class_count[kk] // NUM_INSTANCES)
    index_chunks.append(class_start + np.arange(usable))
bag_idx = np.ix_(np.hstack(index_chunks).astype('int'))
# Because the data is sorted by label and each class is trimmed to a multiple
# of NUM_INSTANCES, every row produced by the reshape holds a single class.
bagged_trn_labels = trn_labels[bag_idx].reshape((-1, NUM_INSTANCES))
bagged_trn_images = trn_images[bag_idx, :].reshape((-1, NUM_INSTANCES, 28*28))
print(bagged_trn_labels.shape)
print(bagged_trn_images.shape)

### Create dataset for distributed training

In [None]:
# Each replica processes BATCH_SIZE_PER_REPLICA bags per step.
BATCH_SIZE_PER_REPLICA = 32
GLOBAL_BATCH_SIZE = strategy.num_replicas_in_sync * BATCH_SIZE_PER_REPLICA
# The shuffle buffer spans every bag.
BUFFER_SIZE = bagged_trn_images.shape[0]
trn_dataset = tf.data.Dataset.from_tensor_slices(bagged_trn_images)
trn_dataset = trn_dataset.shuffle(BUFFER_SIZE)
trn_dataset = trn_dataset.batch(GLOBAL_BATCH_SIZE)
trn_dataset

In [None]:
# Hand the batched dataset to the strategy so that it can be iterated in the
# distributed training loop.
trn_dist_dataset = strategy.experimental_distribute_dataset(trn_dataset)
trn_dist_dataset

# Define GVAE Model

### Layers

In [None]:
class StyleEncoder(tfkl.Layer):
    """Encode the style of every instance in a bag.

    Maps [?, instances, data_dim] -> [?, instances, style_dim]. Only a
    deterministic code is produced (there is no log-variance branch), and
    dropout with rate ``mask_rate`` is applied to it.

    Args:
        style_dim: size of the per-instance style code.
        mask_rate: dropout rate applied to the style code.
        name: Keras layer name.
        **kwargs: forwarded to ``tfkl.Layer``.
    """
    def __init__(self, style_dim, mask_rate, name='style_encoder', **kwargs):
        super().__init__(name=name, **kwargs)
        self.mask_rate = mask_rate
        self.style_dim = style_dim
        self.enc_per_instance = tfk.Sequential(
                [   tfkl.Dense(512),
                    tfkl.LeakyReLU(),
                    tfkl.Dense(256),
                    tfkl.LeakyReLU(),
                    tfkl.Dense(style_dim),
                    ],
                    name="enc_per_instance",)
        self.enc_mean = tfk.Sequential(
                [   tfkl.Dense(style_dim)
                    ],
                    name="z_mean",)

    def call(self, inputs):
        """Return the (dropout-masked) style codes, [?, instances, style_dim]."""
        ninstances = inputs.shape[1]
        data_dim = inputs.shape[2]
        # Keras Reshape never touches the batch axis, so the tensor stays
        # [?, instances, data_dim]; the Dense layers act on the last axis and
        # therefore encode every instance independently.
        inputs =  tfkl.Reshape((-1, data_dim), name='flatten_instances')(inputs)
        common = self.enc_per_instance(inputs)
        z_mean = self.enc_mean(common)
        # pin the static shape to [?, instances, style_dim]
        z_mean = tfkl.Reshape((ninstances, self.style_dim))(z_mean)
        # Dropout only masks entries when Keras runs the layer in training mode.
        z_mean = tfkl.Dropout(self.mask_rate)(z_mean)
        return z_mean
    
class ContentEncoder(tfkl.Layer):
    """Encode the content shared by all instances of a bag.

    Every instance goes through the same encoder and the resulting codes are
    averaged over the instance axis (the GVAE-style "average a common encoder
    across instances" construction), so the output does not depend on the
    order of the instances inside a bag.

    Maps [?, instances, data_dim] -> [?, 1, content_dim]. Only a mean code is
    produced; there is no log-variance branch.

    Args:
        content_dim: size of the content code.
        name: Keras layer name. The default used to be the copy-pasted
            'style_encoder', which clashed with ``StyleEncoder``. Pass
            ``name='style_encoder'`` explicitly if a name-based weight file
            written with the old default has to be loaded.
        **kwargs: forwarded to ``tfkl.Layer``.
    """
    def __init__(self, content_dim, name='content_encoder', **kwargs):
        super().__init__(name=name, **kwargs)
        self.content_dim = content_dim
        self.enc_per_instance = tfk.Sequential(
                [   tfkl.Dense(512),
                    tfkl.LeakyReLU(),
                    tfkl.Dense(256),
                    tfkl.LeakyReLU(),
                    tfkl.Dense(content_dim),
                    ],
                    name="enc_per_instance",
                )
        self.enc_mean = tfk.Sequential(
                [   tfkl.Dense(content_dim)
                    ],
                    name="z_mean",
                )

    def call(self, inputs):
        """Return the bag-level content code, [?, 1, content_dim]."""
        ninstances = inputs.shape[1]
        data_dim = inputs.shape[2]
        # Keras Reshape never touches the batch axis, so the tensor stays
        # [?, instances, data_dim]; the Dense layers act on the last axis and
        # therefore encode every instance independently.
        inputs =  tfkl.Reshape((-1, data_dim), name='flatten_instances')(inputs)
        common = self.enc_per_instance(inputs)
        z_mean = self.enc_mean(common)
        # pin the static shape to [?, instances, content_dim]
        z_mean = tfkl.Reshape((ninstances, self.content_dim))(z_mean)
        # average over the instance axis -> [?, content_dim]
        z_mean = tfkl.GlobalAveragePooling1D()(z_mean)
        # [?, content_dim] -> [?, 1, content_dim]
        z_mean = tfkl.Reshape((1, self.content_dim))(z_mean)
        return z_mean
    
class Decoder(tfkl.Layer):
    """Decode every per-instance latent vector back to data space.

    The number of instances is read from the input at call time.

    Maps [?, instances, latent_dim] -> [?, instances, data_dim]
    """
    def __init__(self, data_dim, name='decoder', **kwargs):
        super().__init__(name=name, **kwargs)
        self.data_dim = data_dim
        # MLP applied to the last axis; the sigmoid keeps outputs in (0, 1).
        per_instance_layers = [
            tfkl.Dense(256),
            tfkl.LeakyReLU(),
            tfkl.Dense(512),
            tfkl.LeakyReLU(),
            tfkl.Dense(data_dim, activation='sigmoid'),
        ]
        self.dec_per_instance = tfk.Sequential(
            per_instance_layers, name="dec_per_instance")

    def call(self, z):
        # static sizes of the instance and latent axes
        n_instances, n_latent = z.shape[1], z.shape[2]
        # Keras Reshape keeps the batch axis, so this leaves z as
        # [?, instances, latent_dim]
        latent = tfkl.Reshape((-1, n_latent))(z)
        # the Dense stack works on the last axis: [?, instances, data_dim]
        decoded = self.dec_per_instance(latent)
        # pin the static output shape
        return tfkl.Reshape((n_instances, self.data_dim))(decoded)
    
class SYMAE(tfk.Model):
    """Autoencoder over bags of instances.

    Each instance gets its own style code, the whole bag shares one content
    code, and the decoder reconstructs every instance from the concatenation
    of the two. The model is deterministic: no latent sampling and no KL term
    are computed here.

    [Input]
        X ~ [?, instances, data_dim]

    Args:
        data_dim: size of one flattened instance.
        style_dim: size of the per-instance style code.
        content_dim: size of the per-bag content code.
        mask_rate: dropout rate applied to the style code.
        nsamp: accepted for interface compatibility; not used by this class.
        name: Keras model name.
        dec_var_model: accepted for interface compatibility; not used by this
            class.
        **kwargs: forwarded to ``tfk.Model``.
    """
    def __init__(self, data_dim, style_dim, content_dim, mask_rate,
            nsamp=1, name='GVAE', dec_var_model ='trainable', **kwargs):
        super().__init__(name=name, **kwargs)
        self.mask_rate = mask_rate
        self.data_dim = data_dim
        self.style_dim = style_dim
        self.content_dim = content_dim
        self.latent_dim = style_dim + content_dim

        # build the encoders and the decoder
        self.style_encoder = StyleEncoder(style_dim, mask_rate)
        self.content_encoder = ContentEncoder(content_dim)
        self.decoder = Decoder(data_dim)

    def call(self, inputs):
        """Reconstruct ``inputs`` and measure the reconstruction error.

        Returns:
            Tuple ``(se_per_ex, dec_mean)``: the squared error summed over
            instances and features, shape [?], and the reconstruction,
            shape [?, instances, data_dim].
        """
        dec_mean = self.encode_decode(inputs)
        # squared error, reduced first over features and then over instances
        se = tf.square(inputs - dec_mean) # [?, instances, data_dim]
        se_per_instance = tf.reduce_sum(se, axis=-1) # [?, instances]
        se_per_ex = tf.reduce_sum(se_per_instance, axis=-1) #[?, ]
        return se_per_ex, dec_mean

    def encode_decode(self, inputs):
        """Return the reconstruction xhat = dec(enc(x)), [?, instances, data_dim]."""
        # style encoding: [?, instances, data_dim] -> [?, instances, style_dim]
        sty_mean = self.style_encoder(inputs)
        # content encoding: [?, instances, data_dim] -> [?, 1, content_dim]
        cnt_mean = self.content_encoder(inputs)
        return self.style_decode(cnt_mean, sty_mean)

    def content_encode(self, X):
        """Return the content code of each bag, [?, 1, content_dim]."""
        return self.content_encoder(X)

    def style_encode(self, X):
        """Return the style code of each instance, [?, instances, style_dim]."""
        return self.style_encoder(X)

    def style_decode(self, cnt_mean, sty_mean):
        """Decode a content code combined with a set of style codes.

        Args:
            cnt_mean: content code, [?, 1, content_dim].
            sty_mean: style codes, [?, instances, style_dim].

        Returns:
            Decoded instances, [?, instances, data_dim].
        """
        ninstances = sty_mean.shape[1]
        # replicate along instance dimension -> [?, instances, content_dim]
        cnt_mean_replicate = tfkl.UpSampling1D(ninstances)(cnt_mean)
        # concatenate:  [?, instances, latent_dim]
        z = tfkl.Concatenate(axis=-1)([sty_mean, cnt_mean_replicate])
        # decode: [?, instances, data_dim]
        return self.decoder(z)

# Optimization

### Set optimization parameters

To start from the weights stored in `./checkpoint`, set `load_weights = True`;
set `load_weights = False` to train from scratch.

In [None]:
# The model, optimizer and metrics are all created inside the strategy scope.
with strategy.scope():
    # --- hyper-parameters ---
    load_weights = True
    style_dim, content_dim = 20, 20
    data_dim = 28*28
    mask_rate = 0.5
    initial_learning_rate = 2e-4
    ninstances = NUM_INSTANCES

    # --- model, optionally restored from './checkpoint' ---
    symae = SYMAE(data_dim, style_dim, content_dim, mask_rate)
    if load_weights:
        symae.load_weights('./checkpoint')

    # --- optimizer: Adam with a staircase exponential decay of the step size ---
    lr_schedule = tfk.optimizers.schedules.ExponentialDecay(
                initial_learning_rate,
                decay_steps=2000,
                decay_rate=0.96,
                staircase=True)
    opt = tf.optimizers.Adam(learning_rate=lr_schedule)

    # --- running means reported during training ---
    trn_loss_metric, trn_mse_metric, trn_sty_KL_metric, trn_cnt_KL_metric = (
        tfk.metrics.Mean() for _ in range(4))
clear_output()

### Set each training step

In [None]:
@tf.function
def train_step(inputs):
    """Run one optimization step on a single replica's batch.

    The per-example reconstruction error is averaged with respect to
    GLOBAL_BATCH_SIZE, recorded in ``trn_loss_metric``, and used to update
    the weights of ``symae`` through ``opt``.

    Returns:
        The scaled loss of this replica.
    """
    with tf.GradientTape() as tape:
        outputs = symae(inputs, training=True)
        mse_per_ex = outputs[0]
        loss = tf.nn.compute_average_loss(mse_per_ex, global_batch_size=GLOBAL_BATCH_SIZE)
    trn_loss_metric(mse_per_ex)
    # Read the weight list only after the forward pass above has run.
    variables = symae.trainable_weights
    gradients = tape.gradient(loss, variables)
    opt.apply_gradients(zip(gradients, variables))
    return loss

@tf.function
def distributed_train_step(dataset_inputs):
    """Run ``train_step`` on every replica and return the combined loss.

    Args:
        dataset_inputs: one element of the distributed training dataset.

    Returns:
        The per-replica losses summed over replicas. ``train_step`` already
        divides by GLOBAL_BATCH_SIZE, so the sum is the loss averaged over
        the global batch. (Previously the per-replica losses were computed
        and then discarded, and the function returned ``None``.)
    """
    per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

# Train and Plot

In [None]:
# Iterate over epochs. Every 5th epoch the notebook output is cleared and two
# diagnostic figures are drawn from the last batch of the epoch.
num_epochs = 1000
for epoch in range(num_epochs):
    if epoch % 5 == 0:
        display.clear_output(wait=True)
    print('===================', flush=True)
    print('Start of epoch %d' % (epoch,))
    start_time = time.perf_counter()

    for trn_batch in trn_dist_dataset:
        distributed_train_step(trn_batch)

    if epoch % 5 == 0:
        print('=== TRN ===')
        # The distributed iterator yields a plain tensor with one replica and
        # a per-replica container with several; experimental_local_results
        # handles both, and the first replica's share is used for plotting.
        trn_batch = strategy.experimental_local_results(trn_batch)[0].numpy()

        # --- reconstruction: originals (top row) vs. decoded (bottom row) ---
        pred = symae.encode_decode(trn_batch)
        pred = pred.numpy()
        plt.figure(figsize=(12,4), facecolor='w')
        num_instance_plot = min([5, NUM_INSTANCES])
        for kk in np.arange(num_instance_plot):
            plt.subplot(2, num_instance_plot, kk+1)
            plt.imshow(trn_batch[0,kk,:].reshape(28,28))
            plt.xticks([])
            plt.yticks([])
            if kk == 0:
                plt.ylabel('Exact')
            plt.subplot(2, num_instance_plot, num_instance_plot+kk+1)
            plt.imshow(pred[0,kk,:].reshape(28,28))
            plt.xticks([])
            plt.yticks([])
            if kk == 0:
                plt.ylabel('Pred')
        fig_training = plt.gcf()
        plt.show()

        # --- redatuming: content of bag jj decoded with the styles of bag 0 ---
        nsplot = 5 # number of style plots
        content_code = symae.content_encode(trn_batch).numpy()
        style_code = symae.style_encode(trn_batch).numpy()
        plt.figure(figsize=(10,10), facecolor='w')
        for jj in np.arange(nsplot):
            styled_mean = symae.style_decode(content_code[[jj],:,:], style_code[[0],:,:])
            styled_mean = styled_mean.numpy()
            for kk in np.arange(nsplot):
                if jj == 0:
                    # plot the style of each instance inside the 0th bag
                    plt.subplot(nsplot+1, nsplot+1, kk+2)
                    plt.imshow(trn_batch[0,kk,:].reshape(28,28))
                    plt.xticks([]); plt.yticks([]);
                    if kk==0:
                        plt.ylabel('Style')
                else:
                    if kk==0:
                        # plot an instance to show the "content"
                        plt.subplot(nsplot+1, nsplot+1, jj*(nsplot+1)+kk+1)
                        plt.imshow(trn_batch[jj,0,:].reshape(28,28))
                        plt.xticks([]); plt.yticks([]);
                        if jj == 1:
                            plt.title('Content')
                    # style the jth bag with styles from the 0th bag
                    plt.subplot(nsplot+1, nsplot+1, jj*(nsplot+1)+kk+2)
                    plt.imshow(styled_mean[0,kk,:].reshape(28,28))
                    plt.xticks([]); plt.yticks([]);
        fig_redatuming = plt.gcf()
        plt.show()

    # Report the running mean of the epoch, then reset it for the next epoch.
    print('mean loss = %.3f' % trn_loss_metric.result().numpy())
    trn_loss_metric.reset_states()
    print('epoch running time = %.2fs' % (time.perf_counter()-start_time))

# Save weights

In [None]:
# Store the trained weights under './checkpoint', the same path the
# optimization cell loads from when load_weights is True.
symae.save_weights('./checkpoint')

# Plot figures in our paper

In [None]:
# Redatuming figure for the paper: row jj shows the content of bag jj decoded
# with the styles of the instances of bag 0; the result is saved as a PDF.
nsplot = 5 # number of style plots
# After the training loop trn_batch is usually still a TensorFlow value (it
# is only converted to numpy on plotting epochs), and tensors have no
# .reshape method. Normalize it to a numpy array here; this is a no-op when
# it already is one.
trn_batch = np.asarray(strategy.experimental_local_results(trn_batch)[0])
content_code = symae.content_encode(trn_batch).numpy()
style_code = symae.style_encode(trn_batch).numpy()
plt.figure(figsize=(10,10), facecolor='w')
for jj in np.arange(nsplot):
    styled_mean = symae.style_decode(content_code[[jj],:,:], style_code[[0],:,:])
    styled_mean = styled_mean.numpy()
    for kk in np.arange(nsplot):
        if jj == 0:
            # plot the style of each instance inside the 0th bag
            plt.subplot(nsplot+1, nsplot+1, kk+2)
            plt.imshow(trn_batch[0,kk,:].reshape(28,28), cmap='Greys')
            plt.xticks([]); plt.yticks([]);
            if kk==0:
                plt.ylabel('Style',fontsize=15)
        else:
            if kk==0:
                # plot an instance to show the "content"
                plt.subplot(nsplot+1, nsplot+1, jj*(nsplot+1)+kk+1)
                plt.imshow(trn_batch[jj,0,:].reshape(28,28), cmap='Greys')
                plt.xticks([]); plt.yticks([]);
                if jj == 1:
                    plt.title('Content',fontsize=15)
            # style the jth bag with styles from the 0th bag
            plt.subplot(nsplot+1, nsplot+1, jj*(nsplot+1)+kk+2)
            plt.imshow(styled_mean[0,kk,:].reshape(28,28), cmap='Greys')
            plt.xticks([]); plt.yticks([]);
fig_redatuming = plt.gcf()
plt.savefig("Mnist_style.pdf")
plt.show()
