In [1]:
import tensorflow as tf
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import imageio
from skimage import img_as_ubyte, io
import keras
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Activation, LeakyReLU, concatenate, Embedding, Dense,Input,Reshape,Dropout,LeakyReLU,Flatten,BatchNormalization,Conv2D,Conv2DTranspose,Lambda
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical
# from scipy.stats import norm
# Run TensorFlow in TF1-style graph mode for the rest of the notebook.
# NOTE(review): presumably required because the custom vae_loss closes over the encoder's
# symbolic mu / log_var tensors; confirm before re-enabling eager execution.
tf.compat.v1.disable_eager_execution()

In [None]:
# Show which GPU (if any) is attached to this runtime.
!nvidia-smi

In [None]:
!unzip /content/mparticles.zip -d /content/

In [None]:
!pip install split-folders

In [None]:
import splitfolders  # the `split-folders` package installed above

# Copy every class folder of /content/mparticles into train / val / test sub-trees
# under /content/mparticles-split, using a 70% / 15% / 15% split with a fixed seed.
splitfolders.ratio(
    '/content/mparticles',
    output='/content/mparticles-split',
    ratio=(.7, .15, .15),
    seed=1337,
    group_prefix=None,
)

In [3]:
categories = ['CS', 'MC', 'SS']
data_directory = '/content/mparticles-split/train/'

# Accumulates [image, class_index] pairs; filled by create_training_data().
training_data = []


def create_training_data(directory=None, img_size=(48, 48)):
    """Load every image under ``<directory>/<category>/`` into ``training_data``.

    Each file is read as grayscale, resized to ``img_size`` and appended to the
    module-level ``training_data`` list as ``[image, class_index]``, where
    ``class_index`` is the position of its folder name in ``categories``.

    Args:
        directory: root folder with one sub-folder per category; defaults to
            ``data_directory``.
        img_size: target size passed to ``cv2.resize`` as ``(width, height)``.

    Returns:
        List of file paths that OpenCV could not decode and that were skipped.
    """
    root = data_directory if directory is None else directory
    skipped = []
    for class_num, category in enumerate(categories):
        path = os.path.join(root, category)
        # sorted() makes the load order independent of the filesystem's listing order
        for img_name in sorted(os.listdir(path)):
            img_path = os.path.join(path, img_name)
            img_array = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
            if img_array is None:
                # cv2.imread returns None (it does not raise) for unreadable / non-image files
                skipped.append(img_path)
                continue
            training_data.append([cv2.resize(img_array, img_size), class_num])
    if skipped:
        print(f'Skipped {len(skipped)} unreadable file(s), e.g. {skipped[0]}')
    return skipped

In [None]:
# create_training_data() appends to the global list, so empty it first; otherwise
# re-running this cell would duplicate every sample.
training_data.clear()
create_training_data()
print(len(training_data))

In [5]:
import random

# Shuffle so classes are interleaved; a fixed seed makes the order (and the run) reproducible.
random.seed(1337)
random.shuffle(training_data)

# Images -> (N, 48, 48, 1) float32 scaled to [0, 1]; float32 avoids a float64 copy
# that Keras would cast down anyway.
x_train = np.array([features for features, _ in training_data], dtype=np.float32).reshape(-1, 48, 48, 1) / 255.0

# One-hot labels; num_classes pins the width to len(categories) even if some class
# (or the whole dataset) is empty, instead of inferring it from the largest label seen.
y_train = to_categorical(
    np.array([label for _, label in training_data]).reshape(-1),
    num_classes=len(categories),
)

In [None]:
# Shape of the one-hot label matrix: (n_samples, n_classes).
np.shape(y_train)

In [7]:
# Mini-batch size and dimensionality of the latent code z.
batch_size, latent_dim = 50, 32

In [8]:
def sampling(args):
    """Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I).

    Args:
        args: pair ``(mu, log_var)`` of same-shaped ``(batch, latent_dim)`` tensors;
            ``log_var`` is the log-variance, so sigma = exp(log_var / 2).

    Returns:
        A sample with the same shape as ``mu``.
    """
    mu, log_var = args
    # Take the batch size from the runtime shape of mu rather than the fixed global
    # batch_size, so a final partial batch (dataset size not divisible by the batch
    # size) or a predict() call with another batch size still broadcasts correctly.
    batch = K.shape(mu)[0]
    dim = K.int_shape(mu)[1]
    eps = K.random_normal(shape=(batch, dim), mean=0., stddev=1.0)
    return mu + K.exp(log_var / 2.) * eps

In [9]:
# Encoder graph: image x plus one-hot condition -> Gaussian posterior parameters (mu, log_var).
x = Input(shape=(48, 48, 1))
condition = Input(shape=(y_train.shape[1],))

# Project the condition onto a 48x48x1 plane and stack it with the image as a second channel.
label_dense = Dense(48 * 48 * 1)(condition)
label_plane = Reshape((48, 48, 1))(label_dense)
cx = concatenate([x, label_plane])

# Five Conv -> LeakyReLU -> Dropout stages, given as (filters, strides);
# every stride-2 stage halves the spatial size.
for n_filters, stride in [(64, 2), (64, 1), (128, 2), (128, 1), (256, 2)]:
    cx = Conv2D(filters=n_filters, kernel_size=(3, 3), strides=stride, padding='same')(cx)
    cx = LeakyReLU(alpha=0.2)(cx)
    cx = Dropout(0.4)(cx)

flat = Flatten()(cx)
x_encoded = Dense(256)(flat)
x_encoded = LeakyReLU(alpha=0.2)(x_encoded)

# Two linear heads: posterior mean and log-variance of the latent code.
mu = Dense(latent_dim, activation='linear')(x_encoded)
log_var = Dense(latent_dim, activation='linear')(x_encoded)

In [10]:
# Draw z via the sampling Lambda, then append the condition so downstream sees [z, condition].
z = Lambda(sampling, output_shape=(latent_dim,))([mu, log_var])
z_cond = concatenate([z, condition])

encoder = Model(inputs=[x, condition], outputs=z_cond)

In [None]:
# Print the encoder's layer-by-layer architecture and parameter counts.
encoder.summary()

In [12]:
# Decoder graph: concatenated [z, condition] vector -> 48x48x1 image.
di = Input(shape=(z_cond.shape[1],))

# Dense projection reshaped into a 6x6x256 feature map.
dx = Dense(6 * 6 * 256)(di)
dx = LeakyReLU(alpha=0.2)(dx)
dx = Reshape((6, 6, 256))(dx)

# Stride-2 transposed convs upsample 6 -> 12 -> 24 -> 48, interleaved with stride-1
# convs; every stage is followed by BatchNorm and LeakyReLU.
decoder_stages = [
    (Conv2DTranspose, 256, 2),
    (Conv2D, 128, 1),
    (Conv2DTranspose, 128, 2),
    (Conv2D, 64, 1),
    (Conv2DTranspose, 64, 2),
]
for layer_cls, n_filters, stride in decoder_stages:
    dx = layer_cls(filters=n_filters, kernel_size=(4, 4), strides=stride, padding='same')(dx)
    dx = BatchNormalization(momentum=0.8)(dx)
    dx = LeakyReLU(alpha=0.2)(dx)

# Single output channel squashed to [0, 1] by a sigmoid.
y = Conv2D(filters=1, kernel_size=(7, 7), padding='same', activation='sigmoid')(dx)

decoder = Model(di, y)

In [None]:
# Print the decoder's layer-by-layer architecture and parameter counts.
decoder.summary()

In [None]:
def vae_loss(true, pred):
    """Negative ELBO of the CVAE: reconstruction BCE summed over pixels plus KL term.

    Args:
        true: target images (the fit cell passes x_train as the target).
        pred: decoder output, sigmoid-activated, same shape as ``true``.

    Returns:
        Scalar loss averaged over the batch.

    The KL term reads the encoder's ``mu`` / ``log_var`` tensors from the enclosing
    scope rather than from the arguments. NOTE(review): this closure over symbolic
    tensors presumably depends on eager execution being disabled; confirm before changing.
    """
    # BCE over the flattened batch is the mean per-pixel loss; * 48 * 48 turns it into
    # the per-image sum of pixel losses. Uses tf.keras (not standalone keras) to match
    # the tf.keras layers and backend used everywhere else.
    reconstruction_loss = tf.keras.losses.binary_crossentropy(K.flatten(true), K.flatten(pred)) * 48 * 48
    # Closed-form KL( N(mu, exp(log_var)) || N(0, I) ), one value per sample.
    kl_loss = -0.5 * K.sum(1 + log_var - K.square(mu) - K.exp(log_var), axis=-1)
    return K.mean(reconstruction_loss + kl_loss)


# End-to-end CVAE: (image, condition) -> encoder -> [z, condition] -> decoder -> image.
cvae_outputs = decoder(encoder([x, condition]))
cvae = Model([x, condition], cvae_outputs, name='cvae')

cvae.compile(optimizer='adam', loss=vae_loss)

cvae.summary()

In [None]:
# Train the CVAE to reconstruct x_train from (x_train, y_train). Use the shared
# batch_size constant instead of a second hard-coded 50 so training and predict agree.
history = cvae.fit([x_train, y_train], x_train,
                   epochs=10,
                   batch_size=batch_size,
                   verbose=1)

# Persist the three models as HDF5. NOTE(review): reloading models that contain the
# sampling Lambda / custom vae_loss will likely need custom_objects — verify.
decoder.save('/content/decoder-particles.h5')
encoder.save('/content/encoder-particles.h5')
cvae.save('/content/cvae-particles.h5')



In [None]:
# Encode the training set and scatter the first two latent dimensions, coloured by class.
# The encoder outputs [z, condition] with z *sampled* from the posterior, so the points
# include reparameterization noise rather than the posterior means.
latent_space = encoder.predict([x_train, y_train], batch_size=batch_size)
fig, ax = plt.subplots(figsize=(6, 6))
points = ax.scatter(latent_space[:, 0], latent_space[:, 1], c=np.argmax(y_train, axis=1))
fig.colorbar(points, ax=ax, label='class index')
ax.set(title='CVAE latent space (z[0] vs z[1])', xlabel='z[0]', ylabel='z[1]')
plt.show()

In [17]:
# One-hot condition (shape (1, n_classes)) selecting which class the decoder should generate.
target_class = 'MC'
labels = to_categorical([categories.index(target_class)], num_classes=len(categories))

In [None]:
# Decode 10 random latent codes under the `labels` condition and save them as 8-bit PNGs.
output_dir = '/content/imgs'
os.makedirs(output_dir, exist_ok=True)  # imwrite does not create missing directories
for i in range(10):
    z_sample = np.random.normal(0, 1, latent_dim).reshape(1, -1)
    x_decoded = decoder.predict(np.column_stack([z_sample, labels]))
    # The decoder ends in a sigmoid, so outputs lie in [0, 1]. Convert explicitly to uint8:
    # handing imageio a float array outside [0, 1] (e.g. values * 255) makes it min-max
    # rescale each image instead of keeping absolute intensities.
    image = (np.clip(x_decoded[0, :, :, 0], 0.0, 1.0) * 255).round().astype(np.uint8)
    imageio.imwrite(os.path.join(output_dir, f'generated-cvae-{i}.png'), image)