In [None]:
import argparse
import os
import sys
import time

import imageio
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from tqdm.notebook import tqdm

sys.path.append("../StyleCLIP_modular")
from style_clip import Imagine, create_text_path

In [None]:
# Declare the options and their defaults with argparse. Nothing is read from
# the command line: an explicit empty argument list is parsed below, so only
# the defaults declared here end up in `args`.
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", default=32, type=int)
parser.add_argument("--gradient_accumulate_every", default=1, type=int)
parser.add_argument("--save_every", default=1, type=int)
parser.add_argument("--epochs", default=1, type=int)
parser.add_argument("--story_start_words", default=5, type=int)
parser.add_argument("--story_words_per_epoch", default=5, type=int)
parser.add_argument("--style", default="../stylegan2-ada-pytorch/VisionaryArt.pkl", type=str, choices=["faces (ffhq config-f)", "../stylegan2-ada-pytorch/VisionaryArt.pkl"])
parser.add_argument("--lr_schedule", default=0, type=int)
parser.add_argument("--start_image_steps", default=1000, type=int)
parser.add_argument("--iterations", default=100, type=int)
# parse_args expects a list of argument strings; pass an empty list (rather
# than a dict) to get a plain dict of the defaults.
args = vars(parser.parse_args([]))

# Notebook-specific overrides of the defaults above, plus extra options that
# have no command-line flag. All of these are later forwarded to Imagine(**args).
args.update(
    opt_all_layers=1,
    lr_schedule=1,
    noise_opt=0,
    reg_noise=0,
    seed=1,
    model_type="vqgan",
    iterations=200,
    save_every=1,
    start_img_loss_weight=0.0,
    batch_size=16,
    lr=0.1,
    neg_text='incoherent, confusing, cropped, watermarks',
)

# Earlier experiments, kept for reference (no `run` function is defined in the
# visible part of this notebook):
#run(img="base_images/aicpa_logo_black.jpg", start_image_path="base_images/stance.jpg", args=args)
#run(img="base_images/aicpa_logo_black.jpg", start_image_path="base_images/earth.jpg", args=args)
#run(img="base_images/earth.jpg", start_image_path="base_images/aicpa_logo_black.jpg", args=args)

In [None]:
# Which image parameterisation to use; this value is also consulted by later
# cells when extracting and transforming latents.
net = "conv" # conv, vqgan
args["sideX"] = 720
args["sideY"] = 540
args["start_image_steps"] = 10
args["iterations"] = 100

if net == "vqgan":
    args["model_type"] = "vqgan"
    args["lr"] = 0.1
elif net == "conv":
    args["model_type"] = "conv"
    args["act_func"] = "gelu"
    args["stride"] = 1
    args["num_layers"] = 5
    # This cell previously assigned num_channels twice (64, then 3); only the
    # last assignment ever took effect, so 3 is the value kept here.
    # NOTE(review): confirm that 3 rather than 64 is the intended setting.
    args["num_channels"] = 3
    args["downsample"] = True
    args["norm"] = "layer"
    # Alternative that scales with resolution:
    # 0.005 * (args["sideX"] * args["sideY"] / 480 / 480)
    args["lr"] = 0.005
else:
    # Fail here instead of building a model with a half-configured `args`.
    raise ValueError(f"Unknown net {net!r}; expected 'vqgan' or 'conv'")


In [None]:
# Create the Imagine instance shared by every following cell. Progress saving,
# folder opening, video saving and verbose output are switched off explicitly;
# every other setting is taken from `args`.
imagine = Imagine(save_progress=False, open_folder=False, save_video=False, verbose=False, **args)

In [None]:
# Base images, keyed by a short name; one latent per image is stored in
# latent_dict under the same key.
path_dict = {
    "logo_black": "base_images/aicpa_logo_black.jpg",
    "logo_purple": "base_images/aicpa_logo_purple.jpg",
    "earth": "base_images/earth.jpg",
    "stance": "base_images/stance.jpg",
}
latent_dict = {}

for key, path in path_dict.items():
    print(key)
    if net != "vqgan":
        # Point imagine at the image, reset it and call prime_image(), then
        # read the model's latents (presumably prime_image fits the model to
        # the start image -- TODO confirm).
        imagine.start_image_path = path
        imagine.reset()
        img = imagine.prime_image()
        latents = imagine.model.model.latents.detach().cpu()
        # Clear the start-image settings again before the next image.
        imagine.start_image_path = None
        imagine.start_image = None
    else:
        # Encode the resized image directly with the VQGAN encoder.
        img = Image.open(path).resize((512, 512))
        # ToTensor output is rescaled with *2 - 1 before encoding.
        x = T.ToTensor()(img).unsqueeze(0).to(imagine.device).mul(2).sub(1)
        # `vqgan` stays a notebook-level name; the translate/zoom cell reuses it.
        vqgan = imagine.model.model.model
        z, _, [_, _, indices] = vqgan.encode(x)
        latents = z
    latent_dict[key] = latents

In [None]:
# Text prompts, keyed by a short name; one optimised latent per prompt is
# stored in latent_text_dict under the same key.
text_dict = {"pride": "LGBTQA pride.",
             "rainbows": "Rainbows", 
             "rainbow_painting": "A painting of a rainbow.",
             "night": "A starry night.",
             "apocalypse": "Apocalypse",
             "psych": "A psychedelic experience",
             "death": "Death"
            }

imagine.iterations = 500
imagine.verbose = True
latent_text_dict = {}
try:
    for key, text in text_dict.items():
        print("Optimizing for ", key)
        # Start from a fresh state, set the prompt as the CLIP target and run
        # the optimisation by calling the Imagine instance.
        imagine.reset()
        imagine.set_clip_encoding(text=text)
        imagine()
        latents = imagine.model.model.get_latent(device="cpu")
        latent_text_dict[key] = latents
finally:
    # Switch verbose output off again even if an optimisation run fails.
    imagine.verbose = False


In [None]:
# Tensor -> PIL image converter; gen() uses it to turn decoded tensors into
# images that the notebook can display.
to_pil = T.ToPILImage()

def minmax(a):
    """Rescale `a` linearly so its smallest value maps to 0 and its largest to 1.

    Args:
        a: Array-like object providing `.min()`, `.max()` and elementwise
            arithmetic (for example a torch tensor or a numpy array).

    Returns:
        An object shaped like `a` with values in [0, 1]. When every element of
        `a` is equal, all zeros are returned instead of the NaNs that a plain
        0 / 0 division would produce.
    """
    lo = a.min()
    span = a.max() - lo
    if span == 0:
        # Constant input: there is no range to normalise over. Multiplying by
        # 0.0 keeps the shape and gives a floating-point result like the
        # division below would.
        return (a - lo) * 0.0
    return (a - lo) / span

def decode(imagine, latent):
    """Render `latent` with imagine's model and leave the model's latent unchanged.

    The model's current latent is read with `get_latent()`, `latent` is
    installed through `imagine.set_latent`, the model is called with
    `return_loss=False`, and the previous latent is installed again afterwards.

    Args:
        imagine: Object exposing `model.model` (callable, with `get_latent()`)
            and `set_latent(latent)`.
        latent: The latent to render, in whatever form `set_latent` accepts.

    Returns:
        The model output, detached from autograd and moved to the CPU.
    """
    model = imagine.model.model
    orig_latents = model.get_latent()
    imagine.set_latent(latent)
    try:
        # The result is detached right away, so there is no point in
        # recording an autograd graph for this forward pass.
        with torch.no_grad():
            image = model(return_loss=False)
        return image.detach().cpu()
    finally:
        # Put the previous latent back even when the forward pass raises, so a
        # failed preview does not leave the model in the temporary state.
        imagine.set_latent(orig_latents)

def gen(imagine, latent):
    """Decode `latent` with `decode` and return the result as a PIL image.

    Args:
        imagine: Passed through to `decode`.
        latent: The latent to render.

    Returns:
        The PIL image produced by `to_pil` from the decoded tensor, after
        removing dimension 0 when it has size 1.
    """
    image = decode(imagine, latent).squeeze(0)
    # Limit values to [0, 1] before the conversion, as gen_imgs does for video
    # frames: out-of-range floats would otherwise wrap around when they are
    # scaled and cast to bytes.
    return to_pil(image.clamp(0, 1))

In [None]:
# Preview the image rendered from the latent stored under "apocalypse".
img = gen(imagine, latent_text_dict["apocalypse"])
img

In [None]:
# Preview the image rendered from the latent stored under "pride".
img = gen(imagine, latent_text_dict["pride"])
img

In [None]:
# Preview the image rendered from the latent stored under "night".
img = gen(imagine, latent_text_dict["night"])
img

In [None]:
# Preview the image rendered from the latent stored under "rainbow_painting".
img = gen(imagine, latent_text_dict["rainbow_painting"])
img

In [None]:
# Preview the image rendered from the latent stored under `key`; `arr` holds
# the same image as a numpy array for optional range checks.
# numpy is imported in the import cell at the top of the notebook.
key = "rainbows"
#key = list(latent_text_dict.keys())[0]
img = gen(imagine, latent_text_dict[key])
arr = np.array(img)
#print(arr.min(), arr.max())
img

In [None]:
# Preview the image rendered from the latent stored under `key` and print the
# smallest and largest pixel value of the resulting image.
# numpy is imported in the import cell at the top of the notebook.
key = "psych"
#key = list(latent_text_dict.keys())[0]
img = gen(imagine, latent_text_dict[key])
arr = np.array(img)
print(arr.min(), arr.max())
img

In [None]:
# Show the available latent names, one dictionary per line: image latents
# first, then text latents.
print(latent_dict.keys(), latent_text_dict.keys(), sep="\n")

In [None]:
# Name of the latent to start from: a key of latent_dict
# (earth, logo_black, logo_purple, stance) or a key of latent_text_dict.
prompt = "apocalypse"

mode = "translate_opt" # transition, translate_opt
latent = latent_dict[prompt] if prompt in latent_dict else latent_text_dict[prompt]

if mode == "transition":
    # Linear interpolation between two image latents.
    start = "earth"
    end = "logo_black"
    steps = 100
    start_latent, end_latent = latent_dict[start], latent_dict[end]
    # Evenly spaced blend ratios between 0 and 1; each ratio yields one
    # interpolated latent.
    linspace = torch.linspace(0, 1, steps)
    latent_transition = [(1 - ratio) * start_latent + ratio * end_latent for ratio in linspace]
elif mode == "translate_opt":
    # Repeatedly transform the current image (shift / rotate / zoom / shear)
    # and then optimise it for a few steps towards the CLIP target, storing
    # the latent after every round.
    #
    # How to make a loop (not implemented in this cell):
    # Looping technique is roughly the same as
    # https://twitter.com/genekogan/status/918513720481009666
    #
    # Notes how to have this whole thing looping:
    #     Gist is to regenerate each frame in the loop N times, initializing it
    #     from a mixture of the previous (t-1) and next (t+1) frame,
    #     gradually interpolating from 100% t-1 to 100% t+1. A bit clunky, but it works.

    total_steps = 500
    pixel_step_size_x = 0
    pixel_step_size_y = 0
    opt_steps = 5
    zoom_factor = 1.02 #0.9
    #zoom_factor = 1#1.02
    angle = 0
    shear = 0

    # Encodes the settings above; the export cell uses it in the video file name.
    translate_settings = f"net-{net}_steps{total_steps}_move{pixel_step_size_x}-{pixel_step_size_y}_rot{angle}_zoom{zoom_factor}_optfor{opt_steps}"

    # Set up the starting latent and the goal. For a text latent the goal is
    # the full prompt it was optimised for (text_dict), not the short
    # dictionary key; for any other name the name itself is used as the text.
    imagine.set_clip_encoding(text=text_dict.get(prompt, prompt))#, img=path_dict[prompt], encoding=None)
    imagine.set_latent(latent)
    latent_transition = [latent]

    # Distinct loop-variable names: the outer and inner loop used to share `_`.
    for _frame_idx in tqdm(range(total_steps)):
        # recreate img
        if net == "vqgan":
            img = decode(imagine, latent)
        elif net == "conv":
            img, params = imagine.model.model.get_latent(device="cpu")
        # transform it
        transformed = T.functional.affine(img, angle=angle, 
                                          translate=(pixel_step_size_x, pixel_step_size_y), 
                                          scale=zoom_factor, 
                                          shear=shear,
                                          interpolation=InterpolationMode.BILINEAR)

        if net == "vqgan":
            # encode it again (`vqgan` comes from the image-latent cell)
            latent, _, [_, _, indices] = vqgan.encode(transformed.to(imagine.device).mul(2).sub(1))
        elif net == "conv":
            latent = [transformed, params]

        # set latent in imagine properly such that it can be optimized by the optimizer
        imagine.set_latent(latent)

        # NOTE(review): the meaning of the two train_step arguments is not
        # visible in this notebook -- confirm against style_clip.
        for _opt_idx in range(opt_steps):
            imagine.train_step(0, 0)

        # get new latents
        latent = imagine.model.model.get_latent() #.latents.detach().cpu()
        # store latents
        latent_transition.append(latent)
else:
    # Without this, an unknown mode would leave latent_transition undefined
    # (or stale from an earlier run) for the cells below.
    raise ValueError(f"Unknown mode {mode!r}; expected 'transition' or 'translate_opt'")

In [None]:
# generate images from latents
def gen_imgs(imagine, latents):
    """Render every latent in `latents` to a uint8 image tensor.

    Each latent is installed with `imagine.set_latent` and the model is called
    with `return_loss=False`. The output is detached, moved to the CPU, has
    dimension 0 squeezed, is permuted with (1, 2, 0), clamped to [0, 1] and
    scaled to 0..255 (presumably the model output is (1, C, H, W), making each
    frame (H, W, C) -- TODO confirm).

    Note that the model's previous latent is not restored: after the call the
    model holds the last latent of `latents`.

    Args:
        imagine: Object exposing `model.model` and `set_latent(latent)`.
        latents: Iterable of latents to render, in order.

    Returns:
        A list with one `torch.uint8` tensor per latent (empty for no latents).
    """
    images = []
    # Frames are detached immediately, so no autograd graph is needed.
    with torch.no_grad():
        # Iterate over the argument, not over the global latent_transition.
        for latent in tqdm(latents):
            model = imagine.model.model
            imagine.set_latent(latent)
            image = model(return_loss=False)
            image = image.detach().cpu().squeeze(0).permute(1, 2, 0).clamp(0, 1) * 255
            # Values are already within 0..255 here, so the cast needs no
            # further clamping.
            images.append(image.type(torch.uint8))
    return images

# Render the collected animation latents (not `latents`, which still holds the
# single latent left over from the text-optimisation cell).
images = gen_imgs(imagine, latent_transition)

In [None]:
# Write the rendered frames as an mp4 into a folder named after the mode.
folder = mode
os.makedirs(folder, exist_ok=True)

# UTC timestamp made only of digits, "-" and "_": no colons or
# locale-dependent text, so the file name is valid on every platform and
# sorts chronologically.
time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.gmtime())
if mode == "transition":
    video_name = f"{start}_to_{end}_{steps}.mp4"
else:
    video_name = f"{prompt}_{translate_settings}.mp4"
path = os.path.join(folder, time_str + "_" + video_name)
imageio.mimwrite(path, images)

In [None]:
# Number of latents collected for the animation.
len(latent_transition)

In [None]:
# Length of the first collected latent.
len(latent_transition[0])

In [None]:
# Shape of the first element of the first collected latent.
latent_transition[0][0].shape

In [None]:
# First element of the first collected latent (compare with the last one below).
latent_transition[0][0]

In [None]:
# First element of the last collected latent, to check that it differs from
# the first collected latent.
latent_transition[-1][0]