In [None]:
""" live (realtime) latent space interpolations of trained models """

import argparse
import torch as th
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from MSG_GAN.GAN import Generator, Discriminator
from generate_multi_scale_samples import progressive_upscaling
from torchvision.utils import make_grid
from math import ceil, sqrt
from scipy.ndimage import gaussian_filter

from matplotlib import animation, rc
from IPython.display import HTML

In [None]:
# Device on which the latent tensors of this demo are created; its string form
# is also used as `map_location` when the checkpoints are loaded.
# It is currently pinned to the CPU. The disabled line below would instead
# pick CUDA whenever it is available.
#device = th.device("cuda" if th.cuda.is_available() else "cpu")
device = th.device("cpu")


In [None]:
def adjust_dynamic_range(data, drange_in=(-1, 1), drange_out=(0, 1)):
    """
    adjust the dynamic colour range of the given input data

    Values are mapped linearly so that drange_in[0] -> drange_out[0] and
    drange_in[1] -> drange_out[1]; the result is then clamped to drange_out.
    When both ranges are equal no rescaling is done and the data is only
    clamped.
    :param data: input image data (torch tensor)
    :param drange_in: original range of input, as (low, high)
    :param drange_out: required range of output, as (low, high)
    :return: img => colour range adjusted images, clamped to drange_out
    """
    if drange_in != drange_out:
        scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / (
                np.float32(drange_in[1]) - np.float32(drange_in[0]))
        bias = (np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale)
        data = data * scale + bias
    # clamp to the requested output range rather than a hard-coded [0, 1], so
    # that ranges such as (0, 255) or (-1, 1) are not squashed into [0, 1]
    return th.clamp(data, min=drange_out[0], max=drange_out[1])

In [None]:
def get_image(gen, point):
    """
    obtain an All-resolution grid of images from the given point
    :param gen: the generator object; calling it on a latent batch must yield
                a sequence of image tensors (one per output scale)
    :param point: random latent point for generation (with a leading batch
                  dimension of size 1, which is squeezed away below)
    :return: img => generated image grid as a numpy array with the channel
             axis moved to the last position
    """
    # inference only: run the generator without autograd bookkeeping instead of
    # building a graph for every frame; detach() is kept as a cheap guard so
    # the later .numpy() call can never hit a tensor that requires grad
    with th.no_grad():
        images = [output.detach() for output in gen(point)][1:]  # first output is skipped

    # map the generator outputs from (-1, 1) to (0, 1) for display
    images = [adjust_dynamic_range(image) for image in images]
    images = progressive_upscaling(images)

    # drop the batch dimension so make_grid receives a list of single images
    images = [image.squeeze(dim=0) for image in images]
    image = make_grid(
        images,
        nrow=int(ceil(sqrt(len(images))))  # roughly square grid layout
    )
    return image.cpu().numpy().transpose(1, 2, 0)

In [None]:
class Object(object):
    """Empty class used as a plain attribute container."""


# Settings of the demo, collected on a single container object so that they
# can be read as `args.<name>` further down.
args = Object()
vars(args).update(
    generator_file='models/005/GAN_GEN_007000.pth',
    discriminator_file='models/005/GAN_DIS_007000.pth',
    depth=5,  # 5 == 64x64
    latent_size=512,
    num_points=1,
    transition_points=120,
    smoothing=1.0,
)

In [None]:
# load the model for the demo
# NOTE: th.load unpickles the checkpoint files, so only load checkpoints that
# come from a trusted source.
# The weights are first mapped onto `device`; the modules are then moved to the
# GPU only when CUDA is actually available, so this cell also runs on CPU-only
# machines while behaving exactly as before on machines with a GPU.
use_cuda = th.cuda.is_available()

gen = th.nn.DataParallel(
    Generator(
        depth=args.depth,
        latent_size=args.latent_size))
gen.load_state_dict(th.load(args.generator_file, map_location=str(device)))
if use_cuda:
    gen.cuda()

# NOTE(review): gpu_parallelize=True is kept as in the original because the
# checkpoint layout may depend on it; confirm that the discriminator also runs
# when no GPU is present.
dis = Discriminator(depth=args.depth, feature_size=args.latent_size, gpu_parallelize=True)
dis.load_state_dict(th.load(args.discriminator_file, map_location=str(device)))
if use_cuda:
    dis.cuda()

In [None]:
# generate the set of points:
# total_frames is the product of the number of latent points and the number of
# transition frames per point. Only the disabled line below reads it in this cell.
total_frames = args.num_points * args.transition_points
# Disabled alternative: one independent random latent per frame.
#all_latents = th.randn(total_frames, args.latent_size).to(device)

def normalize(vec):
    """
    rescale the given vector(s) along the last dimension
    :param vec: tensor whose last dimension holds the latent components
    :return: tensor of the same shape whose L2 norm along the last dimension
             equals sqrt(args.latent_size)
    """
    unit_vec = vec / vec.norm(dim=-1, keepdim=True)
    target_norm = sqrt(args.latent_size)
    return unit_vec * target_norm

# Scale factor applied to the normalized random offset that is added to the
# start latent in order to obtain the end latent (despite its name it is not a
# statistical variance of anything computed here).
variance = 10.0

# Disabled experiment kept for reference: start from normalize(ones) and end
# at normalize(ones with the components [dim1:dim2] set to 0), using
# dim1 = 10 and dim2 = 511, both expanded to (transition_points, latent_size).

# End points of the interpolation, both rescaled by normalize().
start_latent = normalize(th.randn(args.latent_size).to(device))
end_latent = normalize(start_latent + variance * normalize(th.randn(args.latent_size)).to(device))

# Blend weight per frame: 0 for the first frame, 1 for the last frame,
# repeated across the latent dimension -> (transition_points, latent_size).
linear_ramp = (
    th.linspace(0, 1, args.transition_points)
    .unsqueeze(1)
    .expand(args.transition_points, args.latent_size)
    .to(device)
)

# Linear interpolation from start_latent (frame 0, weight 0) to end_latent
# (last frame, weight 1). The weights must be applied this way round so that
# all_latents[0] really is the start latent that `start_point` refers to.
all_latents = start_latent * (1.0 - linear_ramp) + end_latent * linear_ramp

# Disabled post-processing kept for reference:
#   * re-normalize every frame:  all_latents = normalize(all_latents)
#   * smooth along the frame axis with
#     gaussian_filter(all_latents.cpu(),
#                     [args.smoothing * args.transition_points, 0], mode="wrap")
#     followed by the same re-normalization.

# The first frame is drawn immediately; the remaining frames drive the animation.
start_point = th.unsqueeze(all_latents[0], dim=0)
points = all_latents[1:]

fig, ax = plt.subplots()
ax.axis("off")
shower = ax.imshow(get_image(gen, start_point))

# Short summary instead of dumping the whole latent tensor.
print('all_latents size: {}'.format(all_latents.size()))
print(start_point[0][0:8])
print(points[-1][0:8])

# Discriminator scores of the first and the last frame; this is inference
# only, so no autograd graph is needed.
with th.no_grad():
    start_img = gen(start_point)
    score1 = dis(start_img)
    score2 = dis(gen(th.unsqueeze(points[-1], dim=0)))
print('start_img: {} {}'.format(len(start_img),start_img[-1].size()))
print('score: {} -> {}'.format(score1[0], score2[0]))

def init():
    """
    initial draw callback of the animation; nothing is reset here
    :return: 1-tuple holding the image artist `shower`
    """
    artists = (shower,)
    return artists

def update(point):
    """
    per-frame callback of the animation
    :param point: latent vector of the current frame (without batch dimension)
    :return: 1-tuple holding the updated image artist `shower`
    """
    latent_batch = th.unsqueeze(point, dim=0)  # add the batch dimension
    frame = get_image(gen, latent_batch)
    shower.set_data(frame)
    return (shower,)

# Build the animation: `update` is called once per entry of `points`, with a
# delay of 20 ms between frames; blit=True redraws only the returned artists.
anim = FuncAnimation(
    fig,
    func=update,
    frames=points,
    init_func=init,
    interval=20,
    blit=True,
)
# Render as an interactive JavaScript/HTML player; this must stay the last
# expression of the cell so that the notebook displays it.
HTML(anim.to_jshtml())