In [20]:
import torch

from IPython.display import IFrame
from torch import nn, tensor
import numpy as np
import matplotlib.pyplot as plt

# for generate latent walk
from tqdm import tqdm
import torchvision
from torchvision import transforms
import os
from PIL import Image
from datetime import datetime

PyTorch 1.10.0+cu111
Cuda Availability:  True


In [None]:
!pip install kornia==0.5.4
!pip install lightweight_gan==0.20.2;

In [3]:
# Sanity check pytorch successful input
print("PyTorch", torch.__version__)
print("Cuda Availability: ", torch.cuda.is_available())

# Mount Google Drive
from google.colab import drive
drive.mount('/content/table', force_remount=True)

Mounted at /content/table


In [12]:
from lightweight_gan import lightweight_gan, Trainer
from pathlib import Path

# Replace these parameters with your own case
ckpt_path = "/content/table/MyDrive/GenerativeGame/StyleGAN/model_9.pt"
latent_walk_dir = Path("/content/table/MyDrive/GenerativeGame/StyleGAN/ButterflyGAN")

# Please do not change the following parameters unless you modified them:
data = '/content/table/MyDrive/ML_Table/butterfly'
results_dir = '/content/table/MyDrive/GenerativeGame/StyleGAN/'
models_dir = '/content/table/MyDrive/GenerativeGame/StyleGAN/'
name = 'ButterflyGAN'
new = False
load_from = -1
image_size = 512
optimizer = 'adam'
fmap_max = 512
transparent = True
greyscale = False
batch_size = 10
gradient_accumulate_every = 4
num_train_steps = 10000
learning_rate = 2e-4
save_every = 1000
evaluate_every = 1000
generate = True
generate_types = ['default', 'ema']
generate_interpolation = True
aug_test = False
aug_prob=None
aug_types=[]
dataset_aug_prob=0.
attn_res_layers = [32]
freq_chan_attn = False
disc_output_size = 1
dual_contrast_loss = False
antialias = False
interpolation_num_steps = 100
save_frames = False
num_image_tiles = None
num_workers = None
multi_gpus = False
calculate_fid_every = None
calculate_fid_num_images = 12800
clear_fid_cache = False
seed = 42
amp = False
show_progress = False

In [13]:
def cast_list(el):
    return el if isinstance(el, list) else [el]
    
model = Trainer(name = name,
    results_dir = results_dir, 
    models_dir = models_dir, 
    batch_size = batch_size,
    gradient_accumulate_every = gradient_accumulate_every,\
    attn_res_layers = cast_list(attn_res_layers),\
    freq_chan_attn = freq_chan_attn,\
    disc_output_size = disc_output_size,\
    dual_contrast_loss = dual_contrast_loss,\
    antialias = antialias,\
    image_size = image_size,\
    num_image_tiles = num_image_tiles,\
    optimizer = optimizer,\
    num_workers = num_workers,\
    fmap_max = fmap_max,\
    transparent = transparent,\
    greyscale = greyscale,\
    lr = learning_rate,\
    save_every = save_every,\
    evaluate_every = evaluate_every,\
    aug_prob = aug_prob,\
    aug_types = cast_list(aug_types),\
    dataset_aug_prob = dataset_aug_prob,\
    calculate_fid_every = calculate_fid_every,\
    calculate_fid_num_images = calculate_fid_num_images,\
    clear_fid_cache = clear_fid_cache,\
    amp = amp)

model.load_config()
LLG = model.GAN
load_data = torch.load(ckpt_path)

LLG.load_state_dict(load_data["GAN"])

<All keys matched successfully>

In [14]:
print(model.transparent)
LLG.eval()

True


LightweightGAN(
  (G): Generator(
    (initial_conv): Sequential(
      (0): ConvTranspose2d(256, 512, kernel_size=(4, 4), stride=(1, 1))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): GLU(dim=1)
    )
    (layers): ModuleList(
      (0): ModuleList(
        (0): Sequential(
          (0): Upsample(scale_factor=2.0, mode=nearest)
          (1): Identity()
          (2): Conv2d(256, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (4): GLU(dim=1)
        )
        (1): None
        (2): None
      )
      (1): ModuleList(
        (0): Sequential(
          (0): Upsample(scale_factor=2.0, mode=nearest)
          (1): Identity()
          (2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (4):

In [23]:
name = str(datetime.now())
num_image_tiles = 1
num_steps = 100
save_frames = True

num_rows = num_image_tiles

latent_dim = LLG.latent_dim
image_size = LLG.image_size

# latents and noise

global r
global g
global b
global pil_image

latents_low = torch.randn(num_rows ** 2, latent_dim).cuda(model.rank)
latents_high = torch.randn(num_rows ** 2, latent_dim).cuda(model.rank)

ratios = torch.linspace(0., 8., num_steps)

frames = []
heights = []
for ratio in tqdm(ratios):
    interp_latents = lightweight_gan.slerp(ratio, latents_low, latents_high)
    generated_images = model.generate_(model.GAN.GE, interp_latents).cpu()
    # print(generated_images.shape)
    # print(interp_latents.shape)
    images_grid = torchvision.utils.make_grid(generated_images, nrow = num_rows)
    # plt.imshow(images_grid.squeeze(0).reshape(256,256,4))
    pil_image = transforms.ToPILImage()(images_grid)
    print(pil_image.size)
    
    if model.transparent:
        img_array = np.array(pil_image)
        pil_image = Image.fromarray(img_array[:,:,:3])
        displacement = Image.fromarray(img_array[:,:,3])
        heights.append(displacement)

    frames.append(pil_image)

print(type(model.results_dir))
if not os.path.exists(latent_walk_dir):
    os.mkdir(latent_walk_dir)

frames[0].save(str(latent_walk_dir / f'{str(name)}.gif'), save_all=True, append_images=frames[1:], duration=80, loop=0, optimize=True)

if save_frames:
    folder_path = (latent_walk_dir / f'{str(name)}')
    folder_path.mkdir(parents=True, exist_ok=True)
    for ind, frame in enumerate(frames):
        frame.save(str(folder_path / f'rgb_{str(ind)}.jpg'))
    if model.transparent:
        for ind, height in enumerate(heights):
            height.save(str(folder_path / f'height_{str(ind)}.png'))

  8%|▊         | 8/100 [00:00<00:02, 36.06it/s]

(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)


 18%|█▊        | 18/100 [00:00<00:01, 42.50it/s]

(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)


 23%|██▎       | 23/100 [00:00<00:01, 43.63it/s]

(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)


 28%|██▊       | 28/100 [00:00<00:01, 44.07it/s]

(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)


 42%|████▏     | 42/100 [00:01<00:01, 37.02it/s]

(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)


 47%|████▋     | 47/100 [00:01<00:01, 38.43it/s]

(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)


 57%|█████▋    | 57/100 [00:01<00:01, 41.33it/s]

(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)


 66%|██████▌   | 66/100 [00:01<00:01, 33.56it/s]

(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)


 76%|███████▌  | 76/100 [00:02<00:00, 39.21it/s]

(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)


 86%|████████▌ | 86/100 [00:02<00:00, 41.83it/s]

(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)


 96%|█████████▌| 96/100 [00:02<00:00, 44.12it/s]

(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)
(512, 512)


100%|██████████| 100/100 [00:02<00:00, 39.25it/s]


(512, 512)
(512, 512)
(512, 512)
<class 'pathlib.PosixPath'>
