In [None]:
import gc
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from torch.optim import Adam
from tqdm import tqdm

from bendable_gan import BendedGenerator
from bending_modules import BendingConvModule, BendingConvModule_XY, \
    BendingCPPN, BendingDiffSort, BendingDiffSort_XY, ConcatenatedModules
from clip import TextPrompt, NCELoss
from losses import compute_diversity_loss
from utils import generate_image, generate_image_from_seed, image_grid

%load_ext autoreload
%autoreload 2

# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on a machine without CUDA instead of failing on the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Fixed seeds 0..15, reused by every sampling cell so the grids are comparable.
seeds = list(range(16))

In [None]:
# Baseline: the pretrained generator loaded without any bending module.
vanillagen = BendedGenerator.from_pretrained(
    "ceyda/butterfly_cropped_uniq1K_512"
)

In [None]:
# Render the first 16 fixed seeds with the baseline generator and export
# each image to its own PDF, named after the seed.
exampleimgs = []
for k in range(16):
    exampleimgs.append(generate_image_from_seed(vanillagen, seed=seeds[k]))
for seed, img in zip(seeds, exampleimgs):
    img.save(f"butterfly_vanilla_{seed}.pdf")

In [None]:
# Show the baseline samples as a grid; the bare expression is the cell output.
# NOTE(review): the two 4s are presumably rows x columns — confirm in utils.image_grid.
image_grid(exampleimgs, 4, 4)

## CLIP Loss (+ diversity losses)

In [None]:
# Create a new bending module to optimize with the CLIP loss.

# One entry per candidate bending position; `bending_idx` selects which one
# sizes the bending module.
# NOTE(review): presumably the generator's per-layer channel widths — confirm.
numchans = [1024, 1024, 512, 256, 128, 64, 6]

bending_idx = 5

bendingmod_clip = BendingConvModule(numchans[bending_idx],
                                    act_fn='sin')

bend_generator_clip = BendedGenerator.from_pretrained("ceyda/butterfly_cropped_uniq1K_512",
                                                 bending_module=bendingmod_clip,
                                                 bending_idx=bending_idx,
                                                 train_bending=True)
bend_generator_clip = bend_generator_clip.to(device)

# Target prompt for this run. Prompts tried previously in this cell:
#   'Low-poly rendering of Benjamin Franklin going to Venice'
#   'Dinosaur tiffany lamp'
tgt_text = 'A tree painted by Cezanne'
text_prompt = TextPrompt(tgt_text, device=device)
nce_loss = NCELoss(tgt_text, device=device, temperature=0.001)

In [None]:
# Optimize the bending module's parameters against the text objective.
torch.cuda.empty_cache()

# Seed every RNG in play so the run is repeatable.
random.seed(23456)
np.random.seed(54321)
torch.manual_seed(12345)

batch_size = 32
n_iter = 1000

# Objective switches: `div_loss_clip` uses text_prompt's (loss, diversity)
# pair instead of the NCE loss; `div_loss` adds compute_diversity_loss on top.
div_loss = False
div_weight = 6.
div_loss_clip = True
div_clip_weight = 6.

# Only the bending module's parameters are handed to the optimizer.
opt = Adam(bendingmod_clip.parameters(), 1e-3)

loss_log = []

for i in tqdm(range(n_iter)):

    noise_input = torch.randn(batch_size,
                              bend_generator_clip.latent_dim,
                              device=device)

    out, b_in, _ = bend_generator_clip(noise_input, return_inout=True)
    out = out.clamp_(0., 1.)

    # Build the total loss out-of-place so no tensor returned by the loss
    # modules is mutated.
    if div_loss_clip:
        loss, clip_div = text_prompt(out, diversity=True)
        loss = loss + div_clip_weight * clip_div
    else:
        loss = nce_loss(out)
    if div_loss:
        loss = loss + div_weight * compute_diversity_loss(out, b_in)

    # Log a plain Python float rather than a 0-d numpy array.
    loss_log.append(loss.item())

    # Standard update. The former `torch.no_grad()` wrapper was misleading:
    # backward() is unaffected by it and step() disables grad on its own.
    opt.zero_grad()
    loss.backward()
    opt.step()

# Plot against the number of logged steps so an interrupted run still plots.
plt.plot(range(len(loss_log)), loss_log)
plt.xlabel('iteration')
plt.ylabel('loss');

In [None]:
# Render the fixed seeds with the CLIP-bent generator, export one PDF per
# seed (prompt and bending index in the file name), then show the grid.
exampleimgs = []
for k in range(16):
    exampleimgs.append(generate_image_from_seed(bend_generator_clip, seed=seeds[k]))

ts = "_".join(tgt_text.split(" "))
for seed, img in zip(seeds, exampleimgs):
    img.save(f"butterfly_clip_{ts}_bendindex_{bending_idx}_{seed}.pdf")
image_grid(exampleimgs, 4, 4)

In [None]:
# Same export as the preceding cell: sample the fixed seeds, write one PDF
# per seed (overwriting files of the same name), and display the grid.
exampleimgs = []
for k in range(16):
    exampleimgs.append(generate_image_from_seed(bend_generator_clip, seed=seeds[k]))

ts = "_".join(tgt_text.split(" "))
for seed, img in zip(seeds, exampleimgs):
    img.save(f"butterfly_clip_{ts}_bendindex_{bending_idx}_{seed}.pdf")
image_grid(exampleimgs, 4, 4)

In [None]:
# Sample the fixed seeds from the CLIP-bent generator, save each image as a
# PDF keyed by prompt / bending index / seed, and display the grid.
exampleimgs = []
for k in range(16):
    exampleimgs.append(generate_image_from_seed(bend_generator_clip, seed=seeds[k]))

ts = "_".join(tgt_text.split(" "))
for seed, img in zip(seeds, exampleimgs):
    img.save(f"butterfly_clip_{ts}_bendindex_{bending_idx}_{seed}.pdf")
image_grid(exampleimgs, 4, 4)

In [None]:
# Sample the fixed seeds from the CLIP-bent generator, save each image as a
# PDF keyed by prompt / bending index / seed, and display the grid.
exampleimgs = []
for k in range(16):
    exampleimgs.append(generate_image_from_seed(bend_generator_clip, seed=seeds[k]))

ts = "_".join(tgt_text.split(" "))
for seed, img in zip(seeds, exampleimgs):
    img.save(f"butterfly_clip_{ts}_bendindex_{bending_idx}_{seed}.pdf")
image_grid(exampleimgs, 4, 4)

In [None]:
# Display-only variant: sample the fixed seeds and show the grid without saving.
exampleimgs = []
for k in range(16):
    exampleimgs.append(generate_image_from_seed(bend_generator_clip, seed=seeds[k]))
image_grid(exampleimgs, 4, 4)

In [None]:
# Drop every name that keeps this experiment's GPU memory alive.
# Besides the modules and loss objects, the optimizer still references the
# parameters, and the tensors left over from the last loop iteration
# (`out`, `b_in`, `_`, `loss`, `clip_div`) are still bound — while they are,
# empty_cache() cannot hand their memory back.
# pop() rather than `del` keeps the cell re-runnable and tolerant of names
# that were never created (e.g. `clip_div` when that branch was off).
stale_names = ("bendingmod_clip", "bend_generator_clip", "text_prompt", "nce_loss",
               "noise_input", "out", "b_in", "_", "loss", "clip_div", "opt")
for stale in stale_names:
    globals().pop(stale, None)

gc.collect()
torch.cuda.empty_cache()

## Convolutional with coordinates

In [None]:
# Create a new coordinate-aware bending module to optimize with the NCE loss.
# `device` comes from the configuration cell at the top of the notebook.

# One entry per candidate bending position; `bending_idx` selects the pair
# passed to the bending module.
# NOTE(review): presumably per-layer channel widths and spatial sizes of the
# generator — confirm.
numchans = [1024, 1024, 512, 256, 128, 64, 6]
inputsizes = [8, 16, 32, 64, 128, 256, 512]

bending_idx = 5

bendingmod_clip = BendingConvModule_XY(numchans[bending_idx],
                                       inputsizes[bending_idx])

bend_generator_clip = BendedGenerator.from_pretrained("ceyda/butterfly_cropped_uniq1K_512",
                                                 bending_module=bendingmod_clip,
                                                 bending_idx=bending_idx,
                                                 train_bending=True)
bend_generator_clip = bend_generator_clip.to(device)

# Target prompt for this run. Previously tried here:
#   'A gang of biker pumpkins painted by Jan van Eyck'
tgt_text = 'A tree painted by Cezanne'
# Only the NCE loss is used in this section; no TextPrompt is built.
nce_loss = NCELoss(tgt_text, temperature=0.01, device=device)

In [None]:
# Optimize the coordinate-aware bending module against the NCE loss.
torch.cuda.empty_cache()

# Seed all RNGs, matching the other training cells in this notebook.
random.seed(23456)
np.random.seed(54321)
torch.manual_seed(12345)

batch_size = 16
n_iter = 2000

# Only the bending module's parameters are handed to the optimizer.
opt = Adam(bendingmod_clip.parameters(), 1e-3)

loss_log = []

for i in tqdm(range(n_iter)):

    noise_input = torch.randn(batch_size,
                              bend_generator_clip.latent_dim,
                              device=device)

    out = bend_generator_clip(noise_input)
    out = out.clamp_(0., 1.)

    loss = nce_loss(out)
    # Log a plain Python float rather than a 0-d numpy array.
    loss_log.append(loss.item())

    # Standard update. The former `torch.no_grad()` wrapper was misleading:
    # backward() is unaffected by it and step() disables grad on its own.
    opt.zero_grad()
    loss.backward()
    opt.step()

# Plot against the number of logged steps so an interrupted run still plots.
plt.plot(range(len(loss_log)), loss_log)
plt.xlabel('iteration')
plt.ylabel('loss');

In [None]:
# Sample the fixed seeds from the coordinate-bent generator and show the grid.
exampleimgs = []
for k in range(16):
    exampleimgs.append(generate_image_from_seed(bend_generator_clip, seed=seeds[k]))
image_grid(exampleimgs, 4, 4)

In [None]:
# Drop every name that keeps this experiment's GPU memory alive.
# The optimizer still references the parameters, and `out` / `loss` from the
# last loop iteration are still bound, so they are released together with the
# modules. pop() rather than `del` keeps the cell re-runnable.
# (`gc` is already imported at the top of the notebook.)
stale_names = ("bendingmod_clip", "bend_generator_clip", "nce_loss",
               "noise_input", "out", "loss", "opt")
for stale in stale_names:
    globals().pop(stale, None)

gc.collect()
torch.cuda.empty_cache()

## Differentiable sorting

In [None]:
# Bending stack for the differentiable-sorting experiment: a ReLU conv module
# and a row-permuting diff-sort module, combined with the sort module listed first.
numchans = [1024, 1024, 512, 256, 128, 64, 6]
inputsizes = [8, 16, 32, 64, 128, 256, 512]

bending_idx = 3

# Construction order (conv first, then sort) is kept as in the original cell.
bendingmod_clip = BendingConvModule(numchans[bending_idx], act_fn='relu')
bendsorting_clip = BendingDiffSort_XY(
    numchans[bending_idx],
    inputsizes[bending_idx],
    perm_rows=True,
    perm_cols=False,
)
combined_bendmodule = ConcatenatedModules([bendsorting_clip, bendingmod_clip])

In [None]:
# Wrap the pretrained generator around the combined (sort + conv) bending
# module and build the text objectives for this section.
bend_generator_sort = BendedGenerator.from_pretrained("ceyda/butterfly_cropped_uniq1K_512",
                                                 bending_module=combined_bendmodule,
                                                 bending_idx=bending_idx,
                                                 train_bending=True)
bend_generator_sort = bend_generator_sort.to(device)

# Target prompt for this run. Previously tried here:
#   'A tree painted by Cezanne'
tgt_text = 'Low-poly rendering of Benjamin Franklin going to Venice'
text_prompt = TextPrompt(tgt_text, device=device)
nce_loss = NCELoss(tgt_text, device=device, temperature=0.01)

In [None]:
# Optimize the combined (sort + conv) bending module against the text objective.
torch.cuda.empty_cache()

# Seed every RNG in play so the run is repeatable.
random.seed(23456)
np.random.seed(54321)
torch.manual_seed(12345)
#torch.use_deterministic_algorithms(True, warn_only=True)

batch_size = 16
n_iter = 1000

# Objective switches: `div_loss_clip` uses text_prompt's (loss, diversity)
# pair instead of the NCE loss; `div_loss` adds compute_diversity_loss on top.
div_loss = False
div_weight = 6.
div_loss_clip = False
div_clip_weight = 6.

opt = Adam(combined_bendmodule.parameters(), 1e-4)

loss_log = []

for i in tqdm(range(n_iter)):

    noise_input = torch.randn(batch_size,
                              bend_generator_sort.latent_dim,
                              device=device)

    out, b_in, _ = bend_generator_sort(noise_input, return_inout=True)
    out = out.clamp_(0., 1.)

    # Build the total loss out-of-place so no tensor returned by the loss
    # modules is mutated.
    if div_loss_clip:
        loss, clip_div = text_prompt(out, diversity=True)
        loss = loss + div_clip_weight * clip_div
    else:
        loss = nce_loss(out)
    if div_loss:
        loss = loss + div_weight * compute_diversity_loss(out, b_in)

    # Log a plain Python float rather than a 0-d numpy array.
    loss_log.append(loss.item())

    # Standard update. The former `torch.no_grad()` wrapper was misleading:
    # backward() is unaffected by it and step() disables grad on its own.
    opt.zero_grad()
    loss.backward()
    opt.step()

# Plot against the number of logged steps so an interrupted run still plots.
plt.plot(range(len(loss_log)), loss_log)
plt.xlabel('iteration')
plt.ylabel('loss');

In [None]:
# Sample the fixed seeds from the diff-sort generator, export one PDF per
# seed, and show the grid.
exampleimgs = []
for k in range(16):
    exampleimgs.append(generate_image_from_seed(bend_generator_sort, seed=seeds[k]))
# NOTE(review): this first grid is not the cell's last expression, so its
# return value is discarded — likely redundant with the final call; confirm.
image_grid(exampleimgs, 4, 4)
ts = "_".join(tgt_text.split(" "))
for seed, img in zip(seeds, exampleimgs):
    img.save(f"butterfly_diffsort_clip_{ts}_bendindex_{bending_idx}_{seed}.pdf")
image_grid(exampleimgs, 4, 4)

In [None]:
# Re-sample the fixed seeds and re-export the PDFs; `ts` (the prompt with
# spaces replaced by underscores) is reused from the earlier export cell.
exampleimgs = []
for k in range(16):
    exampleimgs.append(generate_image_from_seed(bend_generator_sort, seed=seeds[k]))
# NOTE(review): return value of this first call is discarded — likely redundant.
image_grid(exampleimgs, 4, 4)
for seed, img in zip(seeds, exampleimgs):
    img.save(f"butterfly_diffsort_clip_{ts}_bendindex_{bending_idx}_{seed}.pdf")
image_grid(exampleimgs, 4, 4)

In [None]:
# Re-sample the fixed seeds and re-export the PDFs; `ts` (the prompt with
# spaces replaced by underscores) is reused from the earlier export cell.
exampleimgs = []
for k in range(16):
    exampleimgs.append(generate_image_from_seed(bend_generator_sort, seed=seeds[k]))
# NOTE(review): return value of this first call is discarded — likely redundant.
image_grid(exampleimgs, 4, 4)
for seed, img in zip(seeds, exampleimgs):
    img.save(f"butterfly_diffsort_clip_{ts}_bendindex_{bending_idx}_{seed}.pdf")
image_grid(exampleimgs, 4, 4)

In [None]:
# Display-only: sample the fixed seeds from the diff-sort generator.
exampleimgs = []
for k in range(16):
    exampleimgs.append(generate_image_from_seed(bend_generator_sort, seed=seeds[k]))
image_grid(exampleimgs, 4, 4)

In [None]:
# Display-only: sample the fixed seeds from the diff-sort generator.
exampleimgs = []
for k in range(16):
    exampleimgs.append(generate_image_from_seed(bend_generator_sort, seed=seeds[k]))
image_grid(exampleimgs, 4, 4)

In [None]:
# Display-only: sample the fixed seeds from the diff-sort generator.
exampleimgs = []
for k in range(16):
    exampleimgs.append(generate_image_from_seed(bend_generator_sort, seed=seeds[k]))
image_grid(exampleimgs, 4, 4)

In [None]:
del combined_bendmodule, bend_generator_sort, text_prompt, nce_loss, noise_input#, b_in, b_out
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
exampleimgs = [generate_image(bend_generator_sort) for _ in range(16)]
image_grid(exampleimgs, 4, 4)

In [None]:
exampleimgs = [generate_image(bend_generator_sort) for _ in range(16)]
image_grid(exampleimgs, 4, 4)

In [None]:
exampleimgs = [generate_image(bend_generator_sort) for _ in range(16)]
image_grid(exampleimgs, 4, 4)

In [None]:
exampleimgs = [generate_image(bend_generator_sort) for _ in range(16)]
image_grid(exampleimgs, 4, 4)

In [None]:
exampleimgs = [generate_image(bend_generator_sort) for _ in range(16)]
image_grid(exampleimgs, 4, 4)

In [None]:
exampleimgs = [generate_image(bend_generator_sort) for _ in range(16)]
image_grid(exampleimgs, 4, 4)

## Control: does differentiable sorting do better than a random permutation?

In [None]:
# Control experiment: a plain ReLU conv bending module plus a fixed random
# permutation in place of the learned differentiable sort.

numchans = [1024, 1024, 512, 256, 128, 64, 6]
inputsizes = [8, 16, 32, 64, 128, 256, 512]

bending_idx = 1

bendingmod_clip = BendingConvModule(numchans[bending_idx],
                                    act_fn='relu')
# Draw the permutation from a dedicated, seeded generator so the control is
# reproducible across kernel restarts and independent of the global RNG state.
perm_h = torch.randperm(inputsizes[bending_idx],
                        generator=torch.Generator().manual_seed(12345))

In [None]:
# Display the sampled permutation (bare expression = cell output).
perm_h

In [None]:
# Control generator: the same pretrained checkpoint, bent only by the conv module.
# The name `bend_generator_sort` is reused from the diff-sort section.
bend_generator_sort = BendedGenerator.from_pretrained(
    "ceyda/butterfly_cropped_uniq1K_512",
    bending_module=bendingmod_clip,
    bending_idx=bending_idx,
    train_bending=True,
)
bend_generator_sort = bend_generator_sort.to(device)

# Text objectives for the control run.
tgt_text = 'Peaches in a greek temple, 8-bit art'
text_prompt = TextPrompt(tgt_text, device=device)
nce_loss = NCELoss(tgt_text, device=device, temperature=0.1)

In [None]:
# Optimize the control bending module; the fixed permutation `perm_h` is
# passed to the generator on every forward pass.
torch.cuda.empty_cache()

# Seed all RNGs, matching the other training cells, so the control run is
# repeatable and comparable with the diff-sort run.
random.seed(23456)
np.random.seed(54321)
torch.manual_seed(12345)

batch_size = 16
n_iter = 1000

# NOTE(review): `div_loss` / `div_weight` are set but never read in this loop.
div_loss = False
div_weight = 6.
div_loss_clip = False
div_clip_weight = 6.

opt = Adam(bendingmod_clip.parameters(), 1e-4)

loss_log = []

for i in tqdm(range(n_iter)):

    noise_input = torch.randn(batch_size,
                              bend_generator_sort.latent_dim,
                              device=device)

    out = bend_generator_sort(noise_input, perm_h=perm_h)
    out = out.clamp_(0., 1.)

    # Build the total loss out-of-place so no tensor returned by the loss
    # modules is mutated.
    if div_loss_clip:
        loss, clip_div = text_prompt(out, diversity=True)
        loss = loss + div_clip_weight * clip_div
    else:
        loss = nce_loss(out)

    # Log a plain Python float rather than a 0-d numpy array.
    loss_log.append(loss.item())

    # Standard update. The former `torch.no_grad()` wrapper was misleading:
    # backward() is unaffected by it and step() disables grad on its own.
    opt.zero_grad()
    loss.backward()
    opt.step()

# Plot against the number of logged steps so an interrupted run still plots.
plt.plot(range(len(loss_log)), loss_log)
plt.xlabel('iteration')
plt.ylabel('loss');

In [None]:
# Draw 16 samples from the control generator and show them as a grid.
# NOTE(review): training passed `perm_h=perm_h` to the generator, but no
# permutation is supplied here — confirm generate_image applies it, otherwise
# these samples are not permuted the way the training batches were.
exampleimgs = []
for _ in range(16):
    exampleimgs.append(generate_image(bend_generator_sort))
image_grid(exampleimgs, 4, 4)

In [None]:
# 32x32 integer coordinate grids; with 'xy' indexing, x changes along
# columns and y along rows.
x, y = torch.meshgrid(*(torch.arange(32) for _ in range(2)), indexing='xy')

In [None]:
# Visualize the x-coordinate grid as a matrix plot.
plt.matshow(x.numpy())

In [None]:
# Visualize the y-coordinate grid as a matrix plot.
plt.matshow(y.numpy())

In [None]:
# Element-wise sine of the x-coordinate grid.
sinx = x.sin()

In [None]:
# Visualize sin(x) over the grid as a matrix plot.
plt.matshow(sinx.numpy())

In [None]:
# Element-wise cosine of the y-coordinate grid, shown as a matrix plot.
# NOTE(review): despite the name, `siny` holds cos(y) — confirm whether sin was intended.
siny = y.cos()
plt.matshow(siny.numpy())