# Sample run of DCtrVAE on the CelebA dataset

In [1]:
import os
# Step up one directory so the relative paths used below ("./data", "./models",
# "./results") resolve. The move is relative to the current working directory,
# so executing this cell again climbs one more level.
os.chdir(os.pardir)

In [2]:
import reptrvae
import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt

Using TensorFlow backend.


In [3]:
# Column names in `adata.obs` used throughout the notebook.
condition_key, label_key = "condition", "labels"

# (label, condition) combination that is removed from the training/validation
# data further down.
target_labels, target_conditions = [-1], [1]

In [12]:
# Read the CelebA AnnData file and display its summary.
celeba_path = "./data/celeba/celeba_Smiling_64x64_50000.h5ad"
adata = sc.read(celeba_path)
adata

AnnData object with n_obs × n_vars = 50000 × 12288 
    obs: 'labels', 'condition'

In [13]:
# Scale the image values into [0, 1]. The division happens in place, so it is
# guarded: executing this cell a second time would otherwise divide the
# already-scaled data by 255 again.
if adata.X.max() > 1.0:
    adata.X /= 255.0

In [14]:
# Report the value range of the data matrix as a (min, max) pair.
pixels = adata.X
(pixels.min(), pixels.max())

(0.0, 1.0)

In [15]:
# Split the samples into training and validation sets (0.90 is the split
# argument handed to the helper) and show the shape of each part.
train_adata, valid_adata = reptrvae.utils.train_test_split(
    adata,
    0.90,
)
tuple(split.shape for split in (train_adata, valid_adata))

((45000, 12288), (5000, 12288))

In [16]:
def _without_target_group(split_adata):
    """Return `split_adata` minus rows whose condition AND label are both targets."""
    obs = split_adata.obs
    in_target_condition = obs[condition_key].isin(target_conditions)
    in_target_label = obs[label_key].isin(target_labels)
    # De Morgan form of ~(in_target_condition & in_target_label).
    return split_adata[~in_target_condition | ~in_target_label]


net_train_adata = _without_target_group(train_adata)
net_valid_adata = _without_target_group(valid_adata)

In [17]:
# Shapes of the filtered training and validation sets.
tuple(subset.shape for subset in (net_train_adata, net_valid_adata))

((30217, 12288), (3309, 12288))

In [18]:
# Number of filtered training samples per (condition, label) combination.
group_columns = [condition_key, label_key]
net_train_adata.obs.groupby(group_columns).size()

condition  labels
-1         -1        11255
            1        10802
 1          1         8160
dtype: int64

In [20]:
# Image shape handed to the network as `x_dimension`.
input_shape = (64, 64, 3)

# Count the distinct condition values left in the filtered training set.
train_conditions = net_train_adata.obs[condition_key].unique()
n_conditions = len(train_conditions)

In [22]:
# Build the DCtrVAE network for images of shape `input_shape`.
# The model directory previously pointed at "thick_thin_mnist", a path left
# over from a different dataset. Use a CelebA-specific directory so this
# notebook does not share a model folder with another experiment
# (presumably `network.train(..., save=True)` writes there — confirm).
network = reptrvae.models.DCtrVAE(x_dimension=input_shape,
                                  z_dimension=50,
                                  mmd_dimension=64,
                                  alpha=1e-3,
                                  beta=10,
                                  gamma=0.0,
                                  eta=10.,
                                  model_path="./models/DCtrVAE/celeba/",
                                  dropout_rate=0.3,
                                  arch_style=1,
                                  n_conditions=n_conditions,
                                  )

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
encoder_labels (InputLayer)     (None, 2)            0                                            
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 128)          384         encoder_labels[0][0]             
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 4096)         528384      dense_1[0][0]                    
__________________________________________________________________________________________________
data (InputLayer)    

In [23]:
# Map each raw condition value to its own integer code. The previous mapping
# sent both -1 and 1 to code 0, which made the source and target conditions
# indistinguishable: encoder and decoder labels built from it were always
# identical, although the network is created with two conditions.
condition_encoder = {-1: 0, 1: 1}

In [None]:
# Training options, collected in one place and unpacked into the call.
train_options = dict(
    n_epochs=10000,
    batch_size=512,
    early_stop_limit=250,
    lr_reducer=200,
    shuffle=True,
    save=True,
    verbose=2,
)
network.train(net_train_adata, net_valid_adata, condition_encoder, condition_key,
              **train_options)

In [24]:
# Keep every sample whose label is one of `target_labels`, then show the subset.
has_target_label = adata.obs[label_key].isin(target_labels)
women_adata = adata[has_target_label]
women_adata

View of AnnData object with n_obs × n_vars = 28956 × 12288 
    obs: 'labels', 'condition'

In [25]:
# Distinct condition values present in `women_adata`, as a plain list.
women_conditions = women_adata.obs[condition_key]
women_condition_adata = women_conditions.unique().tolist()
women_condition_adata

[1, -1]

# Out-of-Sample results visualization

In [26]:
# Create the results directory and its "put_smile" subfolder; existing
# directories are left alone.
path_to_save = "./results/CelebA/"
for output_dir in (path_to_save, os.path.join(path_to_save, "put_smile/")):
    os.makedirs(output_dir, exist_ok=True)

In [32]:
k = 5  # rows (source image / prediction pairs) per figure

# The pool of condition == -1 ("unhappy") source images and its encoder label
# do not change between iterations, so build them once instead of on every
# pass of the inner loop.
unhappy_women_adata = women_adata[women_adata.obs[condition_key] == -1]
encoder_label = np.zeros((1, )) + condition_encoder[-1]

for i in range(10):
    for target_condition in target_conditions:
        # The decoder label depends only on the target condition.
        decoder_label = np.zeros((1, )) + condition_encoder[target_condition]

        plt.close("all")
        fig, ax = plt.subplots(k, 2, figsize=(k, 15))
        # Column headers only need to be set once, on the top row.
        ax[0, 0].set_title("Un happy")
        ax[0, 1].set_title("Happy!")

        for j in range(k):
            # Draw one random source image from the pool.
            woman_idx = np.random.choice(unhappy_women_adata.shape[0], 1)[0]
            sample_unhappy_woman_adata = unhappy_women_adata[woman_idx]
            sample_unhappy_woman = np.reshape(sample_unhappy_woman_adata.X, (64, 64, 3))

            # Ask the network to re-decode the sample under the target condition.
            pred_adata = network.predict(sample_unhappy_woman_adata, encoder_label, decoder_label)
            pred_image = pred_adata.X.reshape((64, 64, 3))

            # Left column: original image; right column: prediction.
            ax[j, 0].imshow(sample_unhappy_woman)
            ax[j, 0].axis('off')

            ax[j, 1].imshow(pred_image)
            ax[j, 1].axis('off')

        # NOTE: the file name only encodes `i`, so with more than one entry in
        # `target_conditions` later figures would overwrite earlier ones.
        plt.savefig(os.path.join(path_to_save, f"put_smile/sample_images_{i}.pdf"), dpi=200)
        plt.close('all')

In [41]:
# Compute the neighbourhood graph of `adata`, then its UMAP embedding.
for step in (sc.pp.neighbors, sc.tl.umap):
    step(adata)

         Falling back to preprocessing with `sc.pp.pca` and default params.


In [43]:
# Set the dpi in scanpy's figure parameters.
FIGURE_DPI = 200
sc.set_figure_params(dpi=FIGURE_DPI)

In [None]:
# UMAP of `adata`, coloured once by condition and once by label.
umap_colors = [condition_key, label_key]
sc.pl.umap(adata, color=umap_colors, wspace=0.3)

In [35]:
# Encode each sample's condition with `condition_encoder`; only the first of
# the two returned values is used. The cell displays the shape of the result.
label_encoder_kwargs = dict(condition_key=condition_key,
                            label_encoder=condition_encoder)
encoder_labels, _ = reptrvae.utils.label_encoder(adata, **label_encoder_kwargs)
encoder_labels.shape

(50000, 1)

In [36]:
# Pass every sample, with its encoded condition label, through
# `network.to_latent` and display the returned object.
latent_adata = network.to_latent(
    adata,
    encoder_labels=encoder_labels,
)
latent_adata

AnnData object with n_obs × n_vars = 50000 × 50 
    obs: 'labels', 'condition'

In [57]:
# Compute the neighbourhood graph of the latent representation, then its UMAP.
for step in (sc.pp.neighbors, sc.tl.umap):
    step(latent_adata)

In [None]:
# UMAP of the latent representation, coloured once by condition and once by label.
latent_umap_colors = [condition_key, label_key]
sc.pl.umap(latent_adata, color=latent_umap_colors, wspace=0.3)

In [37]:
# Pass every sample through `network.to_mmd_layer` and display the result.
# NOTE(review): the meaning of feed_fake=-1 is not visible here — confirm
# against the model implementation.
mmd_adata = network.to_mmd_layer(
    adata,
    encoder_labels,
    feed_fake=-1,
)
mmd_adata

AnnData object with n_obs × n_vars = 50000 × 64 
    obs: 'labels', 'condition'

In [61]:
# Compute the neighbourhood graph of the MMD-layer output, then its UMAP.
for step in (sc.pp.neighbors, sc.tl.umap):
    step(mmd_adata)

         Falling back to preprocessing with `sc.pp.pca` and default params.


In [None]:
# UMAP of the MMD-layer output, coloured once by condition and once by label.
mmd_umap_colors = [condition_key, label_key]
sc.pl.umap(mmd_adata, color=mmd_umap_colors, wspace=0.3)