<a href="https://colab.research.google.com/github/16A0/til/blob/master/interactive_masking.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Latent Composition Interactive Masking

Demonstrates using a masked encoder to investigate image priors in GANs.

Related Colab Notebooks:
- [Interactive Composition Demo](https://colab.research.google.com/drive/1j7Bz9vdVnxzOgokawA39hCJZLTmVDq6_?usp=sharing): Interface to compose multiple images using masked encoder.
- [Finetune and Edit](https://colab.research.google.com/drive/1zpD_UYqiGqjzftYxHQPy4sxOQTWV_QY9?usp=sharing): For real images, finetune the encoder towards a specific image for better reconstruction. Further composition can be done in real time.


Source: [chail/latent-composition](https://github.com/chail/latent-composition)

## Download code, models, and set up

In [None]:
# Fetch the latent-composition repository into ./latent-composition (shell escape).
! git clone https://github.com/chail/latent-composition.git

Cloning into 'latent-composition'...
remote: Enumerating objects: 209, done.[K
remote: Counting objects: 100% (209/209), done.[K
remote: Compressing objects: 100% (168/168), done.[K
remote: Total 209 (delta 81), reused 134 (delta 34), pack-reused 0[K
Receiving objects: 100% (209/209), 7.01 MiB | 25.01 MiB/s, done.
Resolving deltas: 100% (81/81), done.


In [None]:
import os

# Work from inside the cloned repository so that relative paths used by later
# cells (e.g. 'pretrained_models/...', 'img/car0.png') resolve.
# Skip the chdir when the kernel is already inside the checkout, so that
# re-running this cell does not fail with FileNotFoundError.
repo_dir = 'latent-composition'
if os.path.basename(os.getcwd()) != repo_dir:
    os.chdir(repo_dir)

In [None]:
# download pretrained models and unzip them (this may take a few minutes)
print("Downloading pretrained models...")
! gdown --id 1vSEH2XMIG1XzQl3JLZwUKm_kdomelUqm # try download from gdrive first
if not os.path.isfile('pretrained_models.zip'):
  # try download from csail (will be slower than gdrive)
  print("Gdrive download failed, trying backup copy...")
  ! wget http://people.csail.mit.edu/lrchai/projects/latent-composition/pretrained_models.zip

assert(os.path.isfile('pretrained_models.zip')),"pretrained_models.zip not found!"
print("Finished downloading. Unpacking models...")
! unzip pretrained_models.zip
print("Done!")

Downloading pretrained models...
Downloading...
From: https://drive.google.com/uc?id=1vSEH2XMIG1XzQl3JLZwUKm_kdomelUqm
To: /content/latent-composition/pretrained_models.zip
3.20GB [00:49, 65.0MB/s]
Finished downloading. Unpacking models...
Archive:  pretrained_models.zip
   creating: pretrained_models/
   creating: pretrained_models/sgan_encoders/
   creating: pretrained_models/sgan_encoders/horse_RGBM/
  inflating: pretrained_models/sgan_encoders/horse_RGBM/optE.yml  
  inflating: pretrained_models/sgan_encoders/horse_RGBM/model_initial.pth.tar  
  inflating: pretrained_models/sgan_encoders/horse_RGBM/model_final.pth  
   creating: pretrained_models/sgan_encoders/car_RGBM/
  inflating: pretrained_models/sgan_encoders/car_RGBM/model_final.pth  
  inflating: pretrained_models/sgan_encoders/car_RGBM/model_initial.pth.tar  
  inflating: pretrained_models/sgan_encoders/car_RGBM/optE.yml  
   creating: pretrained_models/sgan_encoders/ffhq_reals_RGBM/
  inflating: pretrained_models/sgan_enco

In [None]:
# required for stylegan models
! pip install ninja

Collecting ninja
[?25l  Downloading https://files.pythonhosted.org/packages/1d/de/393468f2a37fc2c1dc3a06afc37775e27fde2d16845424141d4da62c686d/ninja-1.10.0.post2-py3-none-manylinux1_x86_64.whl (107kB)
[K     |███                             | 10kB 17.5MB/s eta 0:00:01[K     |██████                          | 20kB 24.9MB/s eta 0:00:01[K     |█████████▏                      | 30kB 27.3MB/s eta 0:00:01[K     |████████████▏                   | 40kB 18.7MB/s eta 0:00:01[K     |███████████████▎                | 51kB 10.7MB/s eta 0:00:01[K     |██████████████████▎             | 61kB 12.3MB/s eta 0:00:01[K     |█████████████████████▍          | 71kB 10.9MB/s eta 0:00:01[K     |████████████████████████▍       | 81kB 10.8MB/s eta 0:00:01[K     |███████████████████████████▍    | 92kB 11.6MB/s eta 0:00:01[K     |██████████████████████████████▌ | 102kB 10.2MB/s eta 0:00:01[K     |████████████████████████████████| 112kB 10.2MB/s 
[?25hInstalling collected packages: ninja
Succ

In [None]:
import torch
import numpy as np
from utils import show, renormalize, pbar
from utils import util, paintwidget, labwidget, imutil
from networks import networks
from PIL import Image
import os
from torchvision import transforms
import time

In [None]:
 # Point PyTorch's extension build cache at /tmp/torch_cpp/.
 # NOTE(review): presumably this must run before the stylegan networks are loaded - confirm.
 os.environ['TORCH_EXTENSIONS_DIR'] = '/tmp/torch_cpp/' # needed for stylegan to run

In [None]:
# Later cells move tensors to the GPU with .cuda(), so stop early with a clear
# message when no CUDA device is visible.
assert torch.cuda.is_available(), (
    "CUDA is not available: switch to a GPU runtime "
    "(in Colab: Runtime > Change runtime type > GPU) and re-run."
)

# Load Networks




In [None]:
#@title Select a model from the dropdown menu

# Colab form field: each option is "<model_type> <domain>" separated by a space.
dropdown = 'stylegan car' #@param ["stylegan car", "stylegan church", "stylegan horse", "stylegan ffhq", "proggan celebahq", "proggan church", "proggan livingroom"]

# Split the selection into its two whitespace-separated parts.
model_type, domain = dropdown.split()
# Build the networks for the chosen model/domain pair.
nets = networks.define_nets(model_type, domain)
# Side length used later to resize real images and painted masks
# (presumably the generator's output resolution - confirm against networks.define_nets).
outdim = nets.setting['outdim']

Using halfsize?: False
Input channels: 4
Using default checkpoint path: pretrained_models/sgan_encoders/car_RGBM/model_final.pth


# Sample an image, and reencode it

In [None]:
#@title Select the input image.

#@markdown Select a random seed to generate an image.
random_seed = 32 #@param {type:"slider", min:0, max:100, step:1}

#@markdown Uncheck this box if you want to use a real image instead.
use_g_sample = True #@param {type:"boolean"}


if use_g_sample:
    # use a gan image as source: sample one latent with the chosen seed and render it
    with torch.no_grad():
        source_z = nets.sample_zs(1, seed=random_seed)
        source_im = nets.zs2image(source_z)
else:
    # use a real image as source
    if domain != "car":
        print("!! WARNING !!: The default image is a car, please use the stylegan_car model or change im_path.")
    im_path = 'img/car0.png' # 'img/car1.png'
    transform = transforms.Compose([
                    transforms.Resize(outdim),
                    transforms.CenterCrop(outdim),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ])
    # Force three channels: an RGBA, palette or grayscale file would not match
    # the 3-channel Normalize above. The context manager closes the file handle.
    with Image.open(im_path) as pil_im:
        source_im = transform(pil_im.convert('RGB'))[None].cuda()

# Both branches leave a batch of one image in source_im; preview it at 256x256.
show(['Source Image', renormalize.as_image(source_im[0]).resize((256, 256), Image.LANCZOS)])

In [None]:
# Encode the source image and regenerate it, then preview the result at 256x256.
with torch.no_grad():
    out = nets.invert(source_im)
    inverted_preview = renormalize.as_image(out[0]).resize((256, 256), Image.LANCZOS)
    show(['Inverted Image', inverted_preview])

# Visualize network priors
Drag your mouse over the left panel to paint a mask; the GAN reconstruction of the masked image will appear in the right panel.

In [None]:
# State shared with the mask callback.
counter = 0                  # incremented each time the callback is invoked
prev_time = time.time()      # timestamp of the last reconstruction that was run
# Minimum number of seconds between two reconstructions. A smaller value
# refreshes the preview more often, but Colab may be too slow to keep up.
update_freq = 0.5
mask_list = []               # every mask that was reconstructed (CPU tensors)
reconstruction_list = []     # the matching reconstructions, as images

# Left panel: paintable view of the source image.
src_painter = paintwidget.PaintWidget(
    oneshot=False,
    width=256,
    height=256,
    brushsize=20,
    save_sequence=False,
    track_move=True,
)
src_painter.image = renormalize.as_url(source_im[0], size=256)

# Right panel: starts out showing an all-zero placeholder image.
img_url = renormalize.as_url(torch.zeros(3, 256, 256))
img_html = '<img src="%s"/>' % img_url
output_div = labwidget.Div(img_html)

def probe_changed(c):
    """Reconstruct the masked source image whenever the painted mask changes.

    Registered as the listener for the painter's 'mask_buffer' property.
    Invocations arriving less than `update_freq` seconds after the last
    processed one are counted but otherwise ignored. When processed, the
    current mask is appended to `mask_list`, the reconstruction is shown in
    `output_div` and also appended to `reconstruction_list`.

    Args:
        c: event payload supplied by the widget; not used.
    """
    global counter, prev_time
    counter += 1
    now = time.time()
    if now - prev_time >= update_freq:
        prev_time = time.time()

        # Decode the painted mask at the network resolution and add a batch dimension.
        painted = renormalize.from_url(src_painter.mask_buffer, target='pt', size=(outdim, outdim))
        mask_rgb = painted.cuda()[None]  # 1x3xHxW
        with torch.no_grad():
            mask = mask_rgb[:, [0], :, :]  # keep only the first channel: 1x1xHxW
            mask_list.append(mask.cpu())
            reconstruction = nets.invert(source_im * mask, mask)

        # Refresh the output panel and record the result.
        output_div.innerHTML = '<img src="%s"/>' % renormalize.as_url(reconstruction[0], size=256)
        reconstruction_list.append(renormalize.as_image(reconstruction[0]))
    
# Run probe_changed every time the painter's mask buffer changes.
src_painter.on('mask_buffer', probe_changed)

# Queue the painter first and the output panel second, then render both.
for panel in (src_painter, output_div):
    show.a([panel], cols=2)

show.flush()

In [None]:
# Show the most recent painted mask applied to the source image, next to the
# reconstruction the network produced for it.
if not mask_list or not reconstruction_list:
    # Nothing has been painted yet, so there is no "latest" entry to index.
    print("No mask has been drawn yet - paint on the source image above, then re-run this cell.")
else:
    latest_masked = (mask_list[-1] * source_im.cpu())[0]
    # Image.LANCZOS is the same filter formerly exposed as Image.ANTIALIAS
    # (removed in Pillow 10) and matches the constant used in the earlier cells.
    show.a(['Masked Input', renormalize.as_image(latest_masked).resize((256, 256), Image.LANCZOS)])
    show.a(['Reconstruction', reconstruction_list[-1].resize((256, 256), Image.LANCZOS)])
    show.flush()