In [13]:
import tensorflow as tf
%matplotlib notebook
%matplotlib inline
import os
import time
import numpy as np
import glob
import matplotlib.pyplot as plt
import PIL
import imageio
import cv2
import datetime
import random
from tensorflow.keras.layers import Dense, GRU, TimeDistributed

from IPython import display

latent_dim = 50

In [7]:
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.compat.v1.InteractiveSession(config=config)

In [8]:
from convolutional_vae import CVAE
cvae = CVAE()
cvae.load_weights('models/cvae_lat50.chkpt')

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fecfd802090>

In [9]:
def VideoGenerator(video_path):
    while True:
        vidcap = cv2.VideoCapture(video_path)
        success, img = vidcap.read()

        while success:
            yield cv2.resize(img, (128, 128)).astype(np.float32) / 255
            success, img = vidcap.read()

train_gen = VideoGenerator('data/train.mp4')
test_gen = VideoGenerator('data/test.mp4')


def batches_generator(batch_size, frames_generator, buffer_size=3000):
    buffer = []
    while True:
        while len(buffer) < buffer_size:
            buffer.append(next(frames_generator))
        random.shuffle(buffer)
        batch, buffer = buffer[:batch_size], buffer[batch_size:]
        yield np.array(batch)

In [10]:
class MDN_RNN(tf.keras.Model):
    def __init__(self, n_mixtures, out_dims):
        super(MDN, self).__init__()
        
        self.n_mixtures = n_mixtures
        self.out_dims = out_dims
        self.fc1 = TimeDistributed(Dense(128, activation='tanh'))
        self.gru = GRU(128)
        self.fc2 = Dense(3 * n_mixtures * out_dims, activation='tanh')

    def predict_distribution(self, x):
        ''' 
            x.shape = [batch_size, seq_length, latent_dim]
        '''
        x = self.fc1(x)
        x = self.gru(x)
        x = self.fc2(x)
        mean, log_std, alpha = tf.split(x, 3, axis=1)
        
        mean = tf.reshape(mean, [-1, self.n_mixtures])
        log_std = tf.reshape(log_std, [-1, self.n_mixtures])
        alpha = tf.reshape(alpha, [-1, self.n_mixtures])
        alpha = tf.nn.softmax(alpha, axis=1)

        return mean, log_std, alpha
    
    def sample(self, x, temperature=1.0, verbose=False):
        # Predicting distribution
        mean, log_std, alpha = self.predict_distribution(x)
        std = tf.exp(log_std) * np.sqrt(temperature)
#         print(alpha)
#         alpha = alpha / temperature
#         alpha -= tf.reduce_max(alpha)
#         alpha = tf.nn.softmax(alpha, axis=-1)
#         print(alpha)
        
        
        # Picking component
        rnd = tf.random.uniform(shape=[x.shape[0] * self.out_dims, 1], maxval=1)
        rnd = tf.repeat(rnd, self.n_mixtures, axis=-1)
        pdf = tf.cumsum(alpha, axis=-1)
        component_idx = np.argmax(pdf > rnd, axis=-1)
        
        # Gathering std and mean
        idx_flattened = tf.range(0, mean.shape[0]) * mean.shape[1] + component_idx
        component_std = tf.gather(tf.reshape(std, [-1]), idx_flattened)
        component_mean = tf.gather(tf.reshape(mean, [-1]), idx_flattened)

        # Sampling
        samples = tf.random.normal(shape=[x.shape[0] * self.out_dims])
        samples = samples * component_std + component_mean
        samples = tf.reshape(samples, [-1, self.out_dims])
        
        return samples
    
    def call(self, x):
        return self.sample(x)