In [None]:
# Install ffmpeg for video encoding. `-y` pre-answers apt's confirmation
# prompt, which cannot be answered interactively from a notebook cell.
!apt-get install -y ffmpeg

In [1]:
import argparse
import os
import pickle
import re
import glob

import numpy as np
import PIL.Image
from PIL import Image
from cv2 import VideoWriter, VideoWriter_fourcc, imread

import dnnlib
import dnnlib.tflib as tflib

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
def generate_images(arrs, network_pkl, truncation_psi=1.0,noise_mode='const', outdir='out', save=True, seed=1):
    """
    Generates one image per entry of `arrs` with the generator stored in `network_pkl`.

    Args:
        arrs: sized iterable of generator inputs; each element is passed
            unchanged as the first argument of ``G.run``.
        network_pkl: path or URL of the network pickle (opened through
            ``dnnlib.util.open_url``).
        truncation_psi: forwarded to ``G.run``.
        noise_mode: accepted for interface compatibility; not used by this function.
        outdir: directory that is created if missing; when ``save`` is true the
            frames are written there as ``0000.png``, ``0001.png``, ...
        save: whether to write each generated image to ``outdir``.
        seed: seed used to fill the synthesis noise variables, so every frame
            is rendered with the same noise.

    Returns:
        A list with one uint8 image per entry of ``arrs``. The output transform
        requests ``nchw_to_nhwc``, so each image is laid out as
        [height, width, channel].
    """

    tflib.init_tf()

    print('Loading networks from "%s"...' % network_pkl)
    # SECURITY: pickle.load executes arbitrary code contained in the file.
    # Only load network pickles from sources you trust.
    with dnnlib.util.open_url(network_pkl) as fp:
        _G, _D, G = pickle.load(fp)

    os.makedirs(outdir, exist_ok=True)
    imgs = []

    G_kwargs = {
        'output_transform': dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
        'randomize_noise': False,
        'truncation_psi': truncation_psi
    }

    noise_vars = [var for name, var in G.components.synthesis.vars.items() if name.startswith('noise')]
    label = np.zeros([1] + G.input_shapes[1][1:])

    # The noise depends only on `seed`, so it is identical for every frame:
    # fill the noise variables once instead of on every loop iteration.
    rnd = np.random.RandomState(seed)
    tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars}) # [height, width]

    total = len(arrs) - 1
    for idx, w in enumerate(arrs):
        # NOTE(review): `w` is handed to G.run as-is and only the first image of
        # the returned minibatch is kept - confirm the expected per-element shape.
        images = G.run(w, label, **G_kwargs) # [minibatch, height, width, channel]
        if save:
            PIL.Image.fromarray(images[0], 'RGB').save(f'{outdir}/{idx:04d}.png')

        print(f'Generated {idx}/{total}')
        imgs.append(images[0])
    return imgs
#----------------------------------------------------------------------------

def generate_transition_latent(arr1, arr2, frames=60):
    '''
    Generates the latent vectors on the straight line from arr1 towards arr2.

    Step k of the result equals ``arr1 + k * (arr2 - arr1) / frames`` for
    k = 0 .. frames-1, so the first step is arr1 itself and the final array
    does not include the last vector (arr2).

    Args:
        arr1: start latent with a leading axis of length 1, e.g. shape
            (1, 18, 512) or (1, 512).
        arr2: end latent, same shape as arr1.
        frames: number of steps to generate.

    Returns:
        Float array of shape ``(frames,) + arr1.shape[1:]``.

    Raises:
        AssertionError: if the shapes differ or the leading axis is not of length 1.
    '''
    arr1 = np.asarray(arr1)
    arr2 = np.asarray(arr2)

    assert arr1.shape == arr2.shape, 'arr1 and arr2 must have the same shape'
    assert arr1.ndim >= 1 and arr1.shape[0] == 1, 'latents must have a leading axis of length 1'

    delta = np.divide(np.subtract(arr2, arr1), frames)

    # One multiplier per frame, shaped (frames, 1, ..., 1) so that it broadcasts
    # against the (1, ...) latents. Using delta's dtype keeps the arithmetic in
    # the same precision a per-frame loop would use.
    steps = np.arange(frames, dtype=delta.dtype).reshape((frames,) + (1,) * (arr1.ndim - 1))

    return (arr1 + steps * delta).astype(float)

def transition_latent_from_key(arr):
    '''
    Builds one continuous latent sequence from a list of key latents.

    Args:
        arr: sequence of ``(latent, frames)`` tuples. All latents must share
            one shape. ``frames`` is the number of steps generated for the
            transition from that key to the next one; the frame count of the
            last tuple is not used.

    Returns:
        Array stacking, in key order, the transition between each pair of
        consecutive keys followed by the last key latent itself.

    Raises:
        IndexError: if ``arr`` is empty.
        AssertionError: if an element is not a 2-tuple or the latent shapes differ.
    '''
    latent_size = arr[0][0].shape

    for element in arr:
        assert type(element) is tuple
        assert len(element) == 2
        assert element[0].shape == latent_size

    segments = []
    for i in range(len(arr) - 1):
        start, frames = arr[i]
        end = arr[i + 1][0]
        segments.append(generate_transition_latent(start, end, frames=frames))

    # Each transition stops one step short of its end key, so the last key is
    # appended explicitly to close the sequence.
    segments.append(arr[-1][0])

    # Segments are stacked in key order (earliest transition first) with a
    # single vstack rather than prepending to a growing array.
    return np.vstack(segments)

#----------------------------------------------------------------------------

def gif_from_folder(imgs_path: str, gif_path='gif/gif.gif', fps=30):
    """
    Assembles the PNG files matched by `imgs_path` into a looping animated GIF.

    Args:
        imgs_path: glob pattern that must end in '/*.png'; matching files are
            used as frames in sorted filename order.
        gif_path: output file, must end in '.gif'. Its parent directory is
            created if it does not exist.
        fps: playback speed in frames per second.

    Raises:
        AssertionError: if either path has the wrong suffix.
        ValueError: if no file matches `imgs_path`.
    """
    assert imgs_path.endswith('/*.png')
    assert gif_path.endswith('.gif')

    print("Generating gif from images in " + imgs_path)

    filenames = sorted(glob.glob(imgs_path))
    if not filenames:
        raise ValueError(f'No images found matching {imgs_path!r}')

    out_dir = os.path.dirname(gif_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    frames = []
    try:
        for filename in filenames:
            frames.append(Image.open(filename))

        first, *rest = frames
        # Pillow's `duration` is the display time of each frame in
        # milliseconds, not the length of the whole animation.
        first.save(fp=gif_path, format='GIF', append_images=rest,
                   save_all=True, duration=1000 / fps, loop=0)
    finally:
        # Release the file handles held by the opened images.
        for frame in frames:
            frame.close()

    print("Done")

    
# TODO: mp4 export is unreliable (it only works some of the time); needs investigation.
def mp4_from_folder(imgs_path: str, mp4_path='video.mp4', fps=30):
    """
    Encodes the PNG files matched by `imgs_path` into an MP4 video.

    Args:
        imgs_path: glob pattern that must end in '/*.png'; matching files are
            written as frames in sorted filename order.
        mp4_path: output file, must end in '.mp4'.
        fps: frame rate of the resulting video.

    Raises:
        AssertionError: if either path has the wrong suffix.
        ValueError: if no file matches `imgs_path` or the first frame cannot be read.
    """
    assert imgs_path.endswith('/*.png')
    assert mp4_path.endswith('.mp4')

    filenames = sorted(glob.glob(imgs_path))
    if not filenames:
        raise ValueError(f'No images found matching {imgs_path!r}')

    first = imread(filenames[0])
    if first is None:
        raise ValueError(f'Could not read image {filenames[0]!r}')

    # Image arrays are indexed (height, width, ...), whereas VideoWriter takes
    # its frame size as (width, height). Passing them unswapped only works for
    # square frames.
    height, width = first.shape[:2]
    frameSize = (width, height)

    #out = VideoWriter(mp4_path,VideoWriter_fourcc(*'h264'), fps, frameSize)
    out = VideoWriter(mp4_path,VideoWriter_fourcc(*'MP4V'), fps, frameSize)

    try:
        for filename in filenames:
            img = imread(filename)
            out.write(img)
            print(filename)
    finally:
        # Always finalize the container, even if a frame fails to encode.
        out.release()

In [11]:
# Demo: render an interpolation between two random latents with a trained snapshot.
from matplotlib import pyplot as plt

# Network snapshot handed to generate_images.
network_pkl="out/00001-time-tf-res1024-auto1-resumecustom-freezed0/network-snapshot-000240.pkl"
# Two standard-normal arrays of shape (1, 18, 512), seeded for reproducibility.
a = np.random.RandomState(1).randn(1, 18,512)
b = np.random.RandomState(2).randn(1, 18,512)

# 60 interpolation steps starting at `a` and moving towards `b` (`b` itself is not included).
arrs = generate_transition_latent(a, b, frames=60)

# Renders one image per step; with the default save=True the frames are also written to test_out/.
imgs = generate_images(arrs, network_pkl, truncation_psi=0.7, outdir='test_out')

# Disabled inline preview of the frames. Because this string is the last
# expression of the cell, it is echoed as the cell's output.
"""
for img in imgs:
    plt.imshow(img)
    plt.show()
"""

Loading networks from "out/00001-time-tf-res1024-auto1-resumecustom-freezed0/network-snapshot-000240.pkl"...
Generated 0/59
Generated 1/59
Generated 2/59
Generated 3/59
Generated 4/59
Generated 5/59
Generated 6/59
Generated 7/59
Generated 8/59
Generated 9/59
Generated 10/59
Generated 11/59
Generated 12/59
Generated 13/59
Generated 14/59
Generated 15/59
Generated 16/59
Generated 17/59
Generated 18/59
Generated 19/59
Generated 20/59
Generated 21/59
Generated 22/59
Generated 23/59
Generated 24/59
Generated 25/59
Generated 26/59
Generated 27/59
Generated 28/59
Generated 29/59
Generated 30/59
Generated 31/59
Generated 32/59
Generated 33/59
Generated 34/59
Generated 35/59
Generated 36/59
Generated 37/59
Generated 38/59
Generated 39/59
Generated 40/59
Generated 41/59
Generated 42/59
Generated 43/59
Generated 44/59
Generated 45/59
Generated 46/59
Generated 47/59
Generated 48/59
Generated 49/59
Generated 50/59
Generated 51/59
Generated 52/59
Generated 53/59
Generated 54/59
Generated 55/59
Gener

'\nfor img in imgs:\n    plt.imshow(img)\n    plt.show()\n'

In [3]:
# Assemble the frames in test_out/ into a video (written to the default path, video.mp4).
# The GIF export below is kept disabled for reference.
#gif_from_folder('test_out/*.png', 'gif.gif', fps=5)
mp4_from_folder('test_out/*.png')

test_out/0000.png
test_out/0001.png
test_out/0002.png
test_out/0003.png
test_out/0004.png
test_out/0005.png
test_out/0006.png
test_out/0007.png
test_out/0008.png
test_out/0009.png
test_out/0010.png
test_out/0011.png
test_out/0012.png
test_out/0013.png
test_out/0014.png
test_out/0015.png
test_out/0016.png
test_out/0017.png
test_out/0018.png
test_out/0019.png
test_out/0020.png
test_out/0021.png
test_out/0022.png
test_out/0023.png
test_out/0024.png
test_out/0025.png
test_out/0026.png
test_out/0027.png
test_out/0028.png
test_out/0029.png
test_out/0030.png
test_out/0031.png
test_out/0032.png
test_out/0033.png
test_out/0034.png
test_out/0035.png
test_out/0036.png
test_out/0037.png
test_out/0038.png
test_out/0039.png
test_out/0040.png
test_out/0041.png
test_out/0042.png
test_out/0043.png
test_out/0044.png
test_out/0045.png
test_out/0046.png
test_out/0047.png
test_out/0048.png
test_out/0049.png
test_out/0050.png
test_out/0051.png
test_out/0052.png
test_out/0053.png
test_out/0054.png
test_out/0