Skip to content

Commit

Permalink
make colab happy
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeqiang-Lai committed Jun 25, 2023
1 parent 93a22d3 commit e053524
Showing 1 changed file with 319 additions and 0 deletions.
319 changes: 319 additions & 0 deletions gradio_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,319 @@
import os
import gradio as gr
import torch
import numpy as np
import imageio
from PIL import Image
import uuid

from draggan import utils
from draggan.draggan import drag_gan
from draggan import draggan as draggan

device = 'cuda'


SIZE_TO_CLICK_SIZE = {
1024: 8,
512: 5,
256: 2
}

CKPT_SIZE = {
'stylegan2/stylegan2-ffhq-config-f.pkl': 1024,
'stylegan2/stylegan2-cat-config-f.pkl': 256,
'stylegan2/stylegan2-church-config-f.pkl': 256,
'stylegan2/stylegan2-horse-config-f.pkl': 256,
'ada/ffhq.pkl': 1024,
'ada/afhqcat.pkl': 512,
'ada/afhqdog.pkl': 512,
'ada/afhqwild.pkl': 512,
'ada/brecahad.pkl': 512,
'ada/metfaces.pkl': 512,
'human/stylegan_human_v2_512.pkl': 512,
'human/stylegan_human_v2_1024.pkl': 1024,
'self_distill/bicycles_256_pytorch.pkl': 256,
'self_distill/dogs_1024_pytorch.pkl': 1024,
'self_distill/elephants_512_pytorch.pkl': 512,
'self_distill/giraffes_512_pytorch.pkl': 512,
'self_distill/horses_256_pytorch.pkl': 256,
'self_distill/lions_512_pytorch.pkl': 512,
'self_distill/parrots_512_pytorch.pkl': 512,
}

DEFAULT_CKPT = 'ada/afhqcat.pkl'


def to_image(tensor):
tensor = tensor.squeeze(0).permute(1, 2, 0)
arr = tensor.detach().cpu().numpy()
arr = (arr - arr.min()) / (arr.max() - arr.min())
arr = arr * 255
return arr.astype('uint8')


def add_points_to_image(image, points, size=5):
image = utils.draw_handle_target_points(image, points['handle'], points['target'], size)
return image


def on_click(image, target_point, points, size, evt: gr.SelectData):
if target_point:
points['target'].append([evt.index[1], evt.index[0]])
image = add_points_to_image(image, points, size=SIZE_TO_CLICK_SIZE[size])
return image, not target_point
points['handle'].append([evt.index[1], evt.index[0]])
image = add_points_to_image(image, points, size=SIZE_TO_CLICK_SIZE[size])
return image, not target_point


def on_drag(model, points, max_iters, state, size, mask, lr_box):
if len(points['handle']) == 0:
raise gr.Error('You must select at least one handle point and target point.')
if len(points['handle']) != len(points['target']):
raise gr.Error('You have uncompleted handle points, try to selct a target point or undo the handle point.')
max_iters = int(max_iters)
W = state['W']

handle_points = [torch.tensor(p, device=device).float() for p in points['handle']]
target_points = [torch.tensor(p, device=device).float() for p in points['target']]

if mask.get('mask') is not None:
mask = Image.fromarray(mask['mask']).convert('L')
mask = np.array(mask) == 255

mask = torch.from_numpy(mask).float().to(device)
mask = mask.unsqueeze(0).unsqueeze(0)
else:
mask = None

step = 0
for image, W, handle_points in drag_gan(W, model['G'],
handle_points, target_points, mask,
max_iters=max_iters, lr=lr_box):
points['handle'] = [p.cpu().numpy().astype('int') for p in handle_points]
image = add_points_to_image(image, points, size=SIZE_TO_CLICK_SIZE[size])

state['history'].append(image)
step += 1
yield image, state, step


def on_reset(points, image, state):
return {'target': [], 'handle': []}, state['img'], False


def on_undo(points, image, state, size):
image = state['img']

if len(points['target']) < len(points['handle']):
points['handle'] = points['handle'][:-1]
else:
points['handle'] = points['handle'][:-1]
points['target'] = points['target'][:-1]

image = add_points_to_image(image, points, size=SIZE_TO_CLICK_SIZE[size])
return points, image, False


def on_change_model(selected, model):
size = CKPT_SIZE[selected]

G = draggan.load_model(utils.get_path(selected), device=device)
model = {'G': G}
W = draggan.generate_W(
G,
seed=int(1),
device=device,
truncation_psi=0.8,
truncation_cutoff=8,
)
img, _ = draggan.generate_image(W, G, device=device)

state = {
'W': W,
'img': img,
'history': []
}

return model, state, img, img, size


def on_new_image(model, seed):
G = model['G']
W = draggan.generate_W(
G,
seed=int(seed),
device=device,
truncation_psi=0.8,
truncation_cutoff=8,
)
img, _ = draggan.generate_image(W, G, device=device)

state = {
'W': W,
'img': img,
'history': []
}

points = {'target': [], 'handle': []}
target_point = False
return img, img, state, points, target_point


def on_max_iter_change(max_iters):
return gr.update(maximum=max_iters)


def on_save_files(image, state):
os.makedirs('draggan_tmp', exist_ok=True)
image_name = f'draggan_tmp/image_{uuid.uuid4()}.png'
video_name = f'draggan_tmp/video_{uuid.uuid4()}.mp4'
imageio.imsave(image_name, image)
imageio.mimsave(video_name, state['history'])
return [image_name, video_name]


def on_show_save():
return gr.update(visible=True)


def on_image_change(model, image_size, image):
image = Image.fromarray(image)
result = inverse_image(
model.g_ema,
image,
image_size=image_size
)
result['history'] = []
image = to_image(result['sample'])
points = {'target': [], 'handle': []}
target_point = False
return image, image, result, points, target_point


def on_mask_change(mask):
return mask['image']


def on_select_mask_tab(state):
img = to_image(state['sample'])
return img


def main():
torch.cuda.manual_seed(25)

with gr.Blocks() as demo:
gr.Markdown(
"""
# DragGAN
Unofficial implementation of [Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold](https://vcai.mpi-inf.mpg.de/projects/DragGAN/)
[Our Implementation](https://github.com/Zeqiang-Lai/DragGAN) | [Official Implementation](https://github.com/XingangPan/DragGAN) (Not released yet)
## Tutorial
1. (Opklional) Draw a mask indicate the movable region.
2. Setup a least one pair of handle point and target point.
3. Click "Drag it".
## Hints
- Handle points (Blue): the point you want to drag.
- Target points (Red): the destination you want to drag towards to.
## Primary Support of Custom Image.
- We now support dragging user uploaded image by GAN inversion.
- **Please upload your image at `Setup Handle Points` pannel.** Upload it from `Draw a Mask` would cause errors for now.
- Due to the limitation of GAN inversion,
- You might wait roughly 1 minute to see the GAN version of the uploaded image.
- The shown image might be slightly difference from the uploaded one.
- It could also fail to invert the uploaded image and generate very poor results.
- Idealy, you should choose the closest model of the uploaded image. For example, choose `stylegan2-ffhq-config-f.pkl` for human face. `stylegan2-cat-config-f.pkl` for cat.
> Please fire an issue if you have encounted any problem. Also don't forgot to give a star to the [Official Repo](https://github.com/XingangPan/DragGAN), [our project](https://github.com/Zeqiang-Lai/DragGAN) could not exist without it.
""",
)
G = draggan.load_model(utils.get_path(DEFAULT_CKPT), device=device)
model = gr.State({'G': G})
W = draggan.generate_W(
G,
seed=int(1),
device=device,
truncation_psi=0.8,
truncation_cutoff=8,
)
img, F0 = draggan.generate_image(W, G, device=device)

state = gr.State({
'W': W,
'img': img,
'history': []
})
points = gr.State({'target': [], 'handle': []})
size = gr.State(CKPT_SIZE[DEFAULT_CKPT])
target_point = gr.State(False)

with gr.Row():
with gr.Column(scale=0.3):
with gr.Accordion("Model"):
model_dropdown = gr.Dropdown(choices=list(CKPT_SIZE.keys()), value=DEFAULT_CKPT,
label='StyleGAN2 model')
seed = gr.Number(value=1, label='Seed', precision=0)
new_btn = gr.Button('New Image')
with gr.Accordion('Drag'):
with gr.Row():
lr_box = gr.Number(value=2e-3, label='Learning Rate')
max_iters = gr.Slider(1, 500, 20, step=1, label='Max Iterations')

with gr.Row():
with gr.Column(min_width=100):
reset_btn = gr.Button('Reset All')
with gr.Column(min_width=100):
undo_btn = gr.Button('Undo Last')
with gr.Row():
btn = gr.Button('Drag it', variant='primary')

with gr.Accordion('Save', visible=False) as save_panel:
files = gr.Files(value=[])

progress = gr.Slider(value=0, maximum=20, label='Progress', interactive=False)

with gr.Column():
with gr.Tabs():
with gr.Tab('Setup Handle Points', id='input'):
image = gr.Image(img).style(height=512, width=512)
with gr.Tab('Draw a Mask', id='mask') as masktab:
mask = gr.ImageMask(img, label='Mask').style(height=512, width=512)

image.select(on_click, [image, target_point, points, size], [image, target_point])
image.upload(on_image_change, [model, size, image], [image, mask, state, points, target_point])
mask.upload(on_mask_change, [mask], [image])
btn.click(on_drag, inputs=[model, points, max_iters, state, size, mask, lr_box], outputs=[image, state, progress]).then(
on_show_save, outputs=save_panel).then(
on_save_files, inputs=[image, state], outputs=[files]
)
reset_btn.click(on_reset, inputs=[points, image, state], outputs=[points, image, target_point])
undo_btn.click(on_undo, inputs=[points, image, state, size], outputs=[points, image, target_point])
model_dropdown.change(on_change_model, inputs=[model_dropdown, model], outputs=[model, state, image, mask, size])
new_btn.click(on_new_image, inputs=[model, seed], outputs=[image, mask, state, points, target_point])
max_iters.change(on_max_iter_change, inputs=max_iters, outputs=progress)
masktab.select(lambda: gr.update(value=None), outputs=[mask]).then(on_select_mask_tab, inputs=[state], outputs=[mask])
return demo


if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--device', default='cuda')
parser.add_argument('--share', action='store_true')
parser.add_argument('-p', '--port', default=None)
parser.add_argument('--ip', default=None)
args = parser.parse_args()
device = args.device
demo = main()
print('Successfully loaded, starting gradio demo')
demo.queue(concurrency_count=1, max_size=20).launch(share=args.share, server_name=args.ip, server_port=args.port)

0 comments on commit e053524

Please sign in to comment.