<a href="https://colab.research.google.com/github/Vlasovets/MB-GAN/blob/master/MBGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
# Clone the MB-GAN repository: it provides the `model`, `utils` and
# `mbgan_train_demo` modules and the ./data directory used below.
!git clone https://github.com/Vlasovets/MB-GAN.git

Cloning into 'MB-GAN'...
remote: Enumerating objects: 173, done.[K
remote: Counting objects: 100% (43/43), done.[K
remote: Compressing objects: 100% (43/43), done.[K
remote: Total 173 (delta 18), reused 8 (delta 0), pack-reused 130[K
Receiving objects: 100% (173/173), 250.12 MiB | 33.19 MiB/s, done.
Resolving deltas: 100% (57/57), done.
Checking out files: 100% (66/66), done.


In [1]:
# Work from the repository root so the local modules and ./data paths resolve.
%cd MB-GAN/

/content/MB-GAN


In [2]:
import pickle
import pandas as pd
import os
from model import *
from utils import *
from mbgan_train_demo import *
from functools import partial

from keras.layers import Input, Layer
from keras.models import Sequential, Model

In [3]:
def build_generator(input_shape, output_units, n_channels=512):
    """Build the generator network.

    Maps a latent vector of shape ``input_shape`` through three hidden blocks
    (Dense(n_channels) -> BatchNormalization(momentum=0.8) -> ReLU) and a final
    Dense(output_units) with softmax, so every generated sample is a
    non-negative vector summing to 1.

    Returns a functional ``Model`` wrapping the ``Sequential`` stack.
    """
    stack = Sequential()

    # First hidden block. NOTE(review): this Dense already applies relu, so the
    # block computes relu(BN(relu(W x + b))), unlike the two blocks below.
    # Kept unchanged to preserve the architecture.
    stack.add(Dense(n_channels, activation="relu", input_shape=input_shape))
    stack.add(BatchNormalization(momentum=0.8))
    stack.add(Activation("relu"))

    # Remaining hidden blocks: linear Dense -> BN -> relu.
    for _ in range(2):
        stack.add(Dense(n_channels))
        stack.add(BatchNormalization(momentum=0.8))
        stack.add(Activation("relu"))

    # Output layer: softmax over the output_units outputs.
    stack.add(Dense(output_units))
    stack.add(Activation("softmax"))

    latent = Input(shape=input_shape)
    return Model(latent, stack(latent))

In [4]:
def build_critic(input_shape, n_channels=256, dropout_rate=0.25, tf_matrix=None, t_pow=1000.):
    """Build the critic network.

    The input passes through ``PhyloTransform(tf_matrix)`` and then a log
    scaling ``log(1 + t_pow * x) / log(1 + t_pow)``. After that come three
    Dense(n_channels) -> LeakyReLU(0.2) -> Dropout(dropout_rate) blocks and a
    final linear Dense(1) that yields an unbounded score.

    NOTE(review): ``tf_matrix`` defaults to None; callers are presumably
    expected to pass the phylogenetic transform matrix. Confirm how
    PhyloTransform handles None.

    Returns a functional ``Model`` mapping ``Input(shape=input_shape)`` to the score.
    """
    def log_scale(x):
        # Log transform that compresses large values; maps 0 -> 0 and 1 -> 1.
        return K.log(1 + x * t_pow) / K.log(1 + t_pow)

    stack = Sequential()
    stack.add(PhyloTransform(tf_matrix, input_shape=input_shape))
    stack.add(Lambda(log_scale))

    for _ in range(3):
        stack.add(Dense(n_channels))
        stack.add(LeakyReLU(alpha=0.2))
        stack.add(Dropout(dropout_rate))

    stack.add(Dense(1))

    sample = Input(shape=input_shape)
    return Model(sample, stack(sample))

In [5]:
class RandomWeightedAverage(Layer):
    """Random per-sample interpolation between two tensors (gradient penalty).

    Called on ``[a, b]``, it returns ``alpha * a + (1 - alpha) * b``. One
    ``alpha ~ U(0, 1)`` is drawn per sample (first axis) and broadcast over
    the remaining axes.
    """

    def call(self, inputs):
        # A plain ``Layer`` never invokes ``_merge_function``; only Keras'
        # private ``_Merge`` base does that. The default ``Layer.call`` returns
        # ``inputs`` unchanged, so dispatch explicitly.
        return self._merge_function(inputs)

    def compute_output_shape(self, input_shape):
        # The result has the shape of either input tensor.
        return input_shape[0]

    def _merge_function(self, inputs):
        first, second = inputs[0], inputs[1]
        batch_size = K.shape(first)[0]
        # alpha must have the same rank as the inputs, so it broadcasts one
        # value per sample. A fixed (batch, 1, 1, 1) only suits 4-D tensors and
        # mis-broadcasts 2-D (batch, features) inputs.
        alpha_shape = (batch_size,) + (1,) * (K.ndim(first) - 1)
        alpha = K.random_uniform(alpha_shape)
        return (alpha * first) + ((1 - alpha) * second)

In [6]:
FILE = "./data/raw_data.pkl"

# Case/control tables plus the taxa list from the repository's pickle.
# NOTE(review): presumably unpickles FILE; pickle can run arbitrary code, so only load trusted files.
data_o_case, data_o_ctrl, taxa_list = load_sample_pickle_data(FILE)

adj_matrix, taxa_indices = expand_phylo(taxa_list)

# Dense transform matrix: one row per taxon, one column per expanded index.
dense_shape = (len(taxa_list), len(taxa_indices))
tf_matrix = adjmatrix_to_dense(adj_matrix, shape=dense_shape)

In [None]:
# Model hyper-parameters. ``ntaxa`` is the critic's input width, which goes
# straight into PhyloTransform(tf_matrix). It is derived from the loaded taxa
# list (the row count of tf_matrix) instead of being hard-coded to 719, so it
# cannot drift from the data.
model_config = {
    'ntaxa': len(taxa_list),
    'latent_dim': 100,
    'generator': {'n_channels': 512},
    'critic': {'n_channels': 256, 'dropout_rate': 0.25,
               'tf_matrix': tf_matrix, 't_pow': 1000.},
}

# Optimizer settings: (name, extra kwargs) plus learning rate.
# The critic's loss_weights apply to [real Wasserstein, fake Wasserstein, gradient penalty].
train_config = {
    'generator': {'optimizer': ('RMSprop', {}), 'lr': 0.00005},
    'critic': {'loss_weights': [1, 1, 10],
               'optimizer': ('RMSprop', {}), 'lr': 0.00005},
}

# Subsample from the real data

In [1]:
# Draw a mini-batch of 32 case samples (row indices sampled with replacement).
real = data_o_case[np.random.randint(low=0, high=data_o_case.shape[0], size=32)]
real.shape

NameError: ignored

# Sample latent noise vectors for the generator

In [None]:
# Standard-normal latent vectors: 32 rows of length latent_dim.
noise = np.random.normal(loc=0, scale=1, size=(32, model_config['latent_dim']))
noise.shape

# Generate fake samples from the latent noise input

In [None]:
# Symbolic latent-noise input (length latent_dim) for the generator.
z = Input(shape=(model_config['latent_dim'],))
z.shape

In [None]:
# Forward the generator hyper-parameters (e.g. n_channels) from model_config,
# so this model follows the configuration rather than the function defaults.
generator = build_generator((model_config['latent_dim'],), model_config['ntaxa'],
                            **model_config['generator'])

In [None]:
# The critic needs tf_matrix and t_pow (plus n_channels / dropout_rate) from
# model_config. Without them, build_critic falls back to tf_matrix=None for
# its PhyloTransform layer.
critic = build_critic((model_config['ntaxa'],), **model_config['critic'])

In [None]:
# Symbolic fake samples: the generator applied to the latent input `z`.
fake_sample = generator(z)
fake_sample.shape

In [None]:
# Critic score for the generated (fake) samples.
fake = critic(fake_sample)

# Score real samples with the critic

In [None]:
# Symbolic input for a batch of real samples, and the critic's score on it.
real_sample = Input(shape=(model_config['ntaxa'],))
valid = critic(real_sample)

# Compute a random weighted average of real and fake samples (input to the gradient penalty)

In [None]:
# Random interpolates between real and fake samples, scored by the critic;
# these feed the gradient-penalty term.
# NOTE(review): this only interpolates if RandomWeightedAverage implements
# `call` (a plain Layer subclass otherwise returns its inputs). Confirm.
interpolated_sample = RandomWeightedAverage()([real_sample, fake_sample])
validity_interpolated = critic(interpolated_sample)

# Build the gradient-penalty loss

In [None]:
# Bind the interpolated samples into the gradient-penalty loss. Keras uses a
# loss function's __name__, and functools.partial objects have none, so set one.
partial_gp_loss = partial(gradient_penalty_loss, averaged_samples=interpolated_sample)
partial_gp_loss.__name__ = 'gradient_penalty'

# Construct critic computational graph

In [None]:
# Critic training must only update the critic. Freeze the generator so the
# critic optimizer cannot reach the generator weights through the fake branch.
# construct_generator_graph sets generator.trainable back to True.
generator.trainable = False
critic.trainable = True

# A functional Model needs both inputs and outputs. Outputs are ordered to
# match the losses given to compile(): [score on real, score on fake,
# score on interpolates (gradient penalty)].
critic_graph = Model(inputs=[real_sample, z],
                     outputs=[valid, fake, validity_interpolated])

In [None]:
# Standard Keras optimizer (RMSprop, Adam, ...) built from train_config. Any
# extra optimizer kwargs (second tuple element) are forwarded, the same way
# construct_generator_graph does it.
critic_optimizer_spec = train_config['critic']['optimizer']
optimizer = get_optimizer(critic_optimizer_spec[0],
                          lr=train_config['critic']['lr'],
                          **critic_optimizer_spec[1])


In [None]:
# Relative weights of the three critic losses: [real, fake, gradient penalty].
loss_weights = train_config['critic']['loss_weights']

In [None]:
# One loss per critic-graph output: Wasserstein on real, Wasserstein on fake,
# and the gradient penalty on the interpolates.
critic_losses = [wasserstein_loss, wasserstein_loss, partial_gp_loss]
critic_graph.compile(
    loss=critic_losses,
    optimizer=optimizer,
    loss_weights=loss_weights,
)

In [None]:
def construct_generator_graph(self):
    """Construct computational graph for generator.

    Sets ``self.generator_graph`` to ``critic(generator(z))`` with the critic
    frozen. The graph is compiled with ``wasserstein_loss`` and the optimizer
    described by ``self.train_config['generator']``.

    NOTE(review): this is a free function taking ``self`` (apparently copied
    from a class). It is not called anywhere in the visible notebook.
    """
    # Freeze the critic's layers while training the generator.
    self.critic.trainable = False
    self.generator.trainable = True

    # Generate samples and score them with the (frozen) critic.
    latent = Input(shape=(self.latent_dim,))
    score = self.critic(self.generator(latent))
    self.generator_graph = Model(latent, score)

    gen_cfg = self.train_config['generator']
    opt_name, opt_kwargs = gen_cfg['optimizer'][0], gen_cfg['optimizer'][1]
    self.generator_graph.compile(
        loss=wasserstein_loss,
        optimizer=get_optimizer(opt_name, lr=gen_cfg['lr'], **opt_kwargs),
    )
    

In [None]:
# Layer-by-layer summary of the critic built above.
critic.summary()

In [None]:
# Layer-by-layer summary of the generator built above.
generator.summary()

In [None]:
NAME = "mbgan_case"
EXP_DIR = "NielsenHB_2014_stool"  # NOTE(review): not used in the visible notebook

# Instantiate the packaged MBGAN with the same configs, then inspect its attributes.
mbgan = MBGAN(NAME, model_config, train_config)
vars(mbgan)

In [None]:
# NOTE(review): disabled draft of the training loop; nothing in this cell runs.
# If it is revived:
#   * `valid`, `fake` and `dummy` below are NumPy label arrays (-1, +1, 0).
#     Their names shadow the critic output tensors `valid` / `fake` defined in
#     earlier cells.
#   * the generator update still refers to `self.` attributes and must be
#     rewritten against notebook variables.
#   * the batch size 32 is hard-coded in several places instead of `batch_size`.
# batch_size=32
# n_critic=5
# n_generator=1 
# save_interval=50
# save_fn=None
# experiment_dir="mbgan_train"
# verbose=0

# valid = -np.ones((32, 1))
# fake =  np.ones((32, 1))
# dummy = np.zeros((32, 1))

# for epoch in range(1, 5):
#             for _ in range(n_critic):
#                 # Randomly select a batch of samples to train the critic
#                 real = data_o_case[np.random.randint(0, data_o_case.shape[0], 32)]
#                 noise = np.random.normal(0, 1, (32, model_config['latent_dim']))
#                 d_loss = critic_graph.train_on_batch([real, noise], [valid, fake, dummy])
            
#             # for _ in range(n_generator):
#             #     #  Update the generator
#             #     noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
#             #     g_loss = self.generator_graph.train_on_batch(noise, valid)

#             # # Plot the progress
#             # log_info = [
#             #     "iter={:d}".format(epoch), 
#             #     "[D loss={:.6f}, w_loss_real={:.6f}, w_loss_fake={:.6f}, gp_loss={:.6f}]".format(*d_loss),
#             #     "[G loss={:.6f}]".format(g_loss),
#             # ]
#             # print("{} {} {}".format(*log_info))

# generator_graph = Model(z, valid)

# d_loss = critic_graph.train_on_batch([real, noise], [valid, fake, dummy])
# critic_graph.train_on_batch([real, noise], [valid, fake, dummy])