# Localized Semantic Editing of StyleGAN outputs

Introduced in the paper:<br>
> Edo Collins, Raja Bala, Bob Price and Sabine Süsstrunk. _Editing in Style: Uncovering the Local Semantics of GANs_.  IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2020.

This demo illustrates a simple and effective method for making local, semantically-aware edits to a _target_ GAN output image. This is accomplished by borrowing styles from a _source_ image, also a GAN output.

The method neither requires supervision from an external model nor involves complex spatial morphing operations. Instead, it relies on the emergent disentanglement of semantic objects that is learned by StyleGAN during its training, which we detect using Spherical _k_-means.

The implementation below relies on PyTorch and requires downloading additional parameter files found here: https://drive.google.com/open?id=1GYzEzOCaI8FUS6JHdt6g9UfNTmpO08Tt

In [1]:
%load_ext autoreload
%autoreload 2
import torch
from stylegan2 import Generator                            # StyleGAN model
from stylegan2_output import GANOutputs              # Data structure to hold GAN outputs
import ptutils                                      # Helper tensor functions
import visutils                                     # Visualization functions
from style2_interpolator import StyleInterpolator    # The 'sequential' style-interpolator (Eq. 5)
import cielab                                       # Helper functions for CIELAB color-space

# Make CUDA device index 1 the current device for the rest of the notebook.
# NOTE(review): the index is hard-coded; presumably this fails on a machine
# with a single GPU (only index 0) — confirm and adjust before running.
torch.cuda.set_device(1)

Load the appropriate StyleGAN model

In [18]:
# Model / checkpoint selection used by the following cells.
dataset_name = 'ffhq'
config = 'E'
# Passed to Generator as channel_multiplier: 1 for config 'E', 2 otherwise.
channel = 1 if config == 'E' else 2
root_dir = '../karras_ckpt' # See comment above regarding additional files

# Per-dataset truncation value and generator output size.
if dataset_name == 'cat':
    truncation = 0.5
    size = 256
elif dataset_name == 'ffhq':
    truncation = 0.7
    size = 1024
else:
    # Without this branch an unsupported name would leave `truncation` and
    # `size` undefined (or, in a live notebook, silently stale from an
    # earlier run), and the failure would only surface in a later cell.
    raise ValueError(
        "Unsupported dataset_name {!r}; expected 'cat' or 'ffhq'".format(dataset_name)
    )


In [19]:
# Build the generator, restore the 'g_ema' weights from the checkpoint file
# '<root_dir>/<dataset_name><config>.pt', then switch to eval mode on the GPU.
G = Generator(size, 512, 8, channel_multiplier=channel)
G.load_state_dict(torch.load(f'{root_dir}/{dataset_name}{config}.pt')['g_ema'])
G = G.eval().cuda()

In [20]:
# Read the same checkpoint file again and keep its 'g_ema' entry for inspection.
state_dict = torch.load(f'{root_dir}/{dataset_name}{config}.pt')['g_ema']

style.1.weight
style.1.bias
style.2.weight
style.2.bias
style.3.weight
style.3.bias
style.4.weight
style.4.bias
style.5.weight
style.5.bias
style.6.weight
style.6.bias
style.7.weight
style.7.bias
style.8.weight
style.8.bias
input.input
conv1.conv.weight
conv1.conv.modulation.weight
conv1.conv.modulation.bias
conv1.noise.weight
conv1.activate.bias
to_rgb1.bias
to_rgb1.conv.weight
to_rgb1.conv.modulation.weight
to_rgb1.conv.modulation.bias
convs.0.conv.weight
convs.0.conv.blur.kernel
convs.0.conv.modulation.weight
convs.0.conv.modulation.bias
convs.0.noise.weight
convs.0.activate.bias
convs.1.conv.weight
convs.1.conv.modulation.weight
convs.1.conv.modulation.bias
convs.1.noise.weight
convs.1.activate.bias
convs.2.conv.weight
convs.2.conv.blur.kernel
convs.2.conv.modulation.weight
convs.2.conv.modulation.bias
convs.2.noise.weight
convs.2.activate.bias
convs.3.conv.weight
convs.3.conv.modulation.weight
convs.3.conv.modulation.bias
convs.3.noise.weight
convs.3.activate.bias
convs.4.conv.wei

Load the pre-computed spherical k-means clusters, and provide them to the style interpolator

In [31]:
import pickle

# Load the pre-computed catalog and hand it to the style interpolator.
# The file name is hard-coded to 'catalog_ffhqE.pkl'; it is not derived from
# dataset_name / config. See comment above regarding additional files.
# NOTE: pickle.load can execute arbitrary code contained in the file — only
# load catalogs obtained from a trusted source.
with open('catalog_ffhqE.pkl', 'rb') as catalog_file:
    # The context manager closes the file handle even if unpickling fails.
    catalog = pickle.load(catalog_file)
# Alternative catalog kept for reference:
# catalog = torch.load('catalogs/stylegan1_FFHQ.pkl')
si_wf = StyleInterpolator(catalog, bias=False)
# Bare expression: shown as the notebook cell output.
si_wf._catalog_labels

AttributeError: 'FactorCatalog' object has no attribute 'M'

In [30]:
# Inspect the catalog's 'M_hoyer' and 'M' attributes: first their lengths,
# then the shape of the first element of each.
catalog_attrs = vars(catalog)
for attr_name in ('M_hoyer', 'M'):
    print(len(catalog_attrs[attr_name]))
print(catalog_attrs['M_hoyer'][0].shape)
print(catalog_attrs['M'][0].size())

18
18
(512,)
torch.Size([8, 512])


In [8]:
# Bare expression: display the factorization object stored on the catalog
# as the notebook cell output.
catalog._factorization
# catalog.annotations

MiniBatchSphericalKMeans(batch_size=200, compute_labels=True, init='k-means++',
                         init_size=None, max_iter=100, max_no_improvement=10,
                         n_clusters=16, n_init=3, random_state=10,
                         reassignment_ratio=0.01, tol=0.0, verbose=0)

Generate some examples

In [23]:
batch = 5

# Create the GAN outputs container used for the examples below.
# NOTE(review): the arguments of GANOutputs.from_seed are presumably
# (sample count or explicit sample indices, random seed) — confirm against
# stylegan2_output.GANOutputs.
if dataset_name == 'cat':
    gs = GANOutputs.from_seed((0,33,3,19,34), 6813)
elif dataset_name == 'ffhq':
    gs = GANOutputs.from_seed(5, 2001)

In [24]:
# Bare expression: display the size of gs.z (the input fed to the generator
# in the next cell) as the notebook cell output.
gs.z.size()

torch.Size([5, 512])

In [25]:
# Run the generator on gs.z, store the outputs on `gs`, split them into a
# target (gs1) and references (gs2), and show both.

# Argument passed to G.mean_latent below (presumably the number of latents
# averaged — confirm against stylegan2.Generator).
truncation_mean=4096
with torch.no_grad():
    # NOTE(review): mean_latent is only referenced by the commented-out
    # truncation arguments on the next line, so it is currently unused.
    mean_latent = G.mean_latent(truncation_mean)
    rgb, gs.ys, _ = G(gs.z.cuda(), return_latents=True)#, truncation=1, truncation_latent=mean_latent)
    # Clamp to [-1, 1], then rescale to [0, 1].
    rgb = (rgb.clamp(-1, 1) + 1) / 2
    rgb = rgb.cpu()
    gs.rgb = ptutils.MultiResolutionStore(rgb)
    
    
    # The first sample becomes the edit target; the rest are the references.
    # NOTE(review): the recorded run of this cell failed with
    # "'GANOutputs' object has no attribute 'rgb'" — presumably slicing does
    # not carry over the `rgb` attribute assigned above; confirm against
    # GANOutputs.__getitem__.
    gs1 = gs[:1]
    gs2 = gs[1:]

# Resolution requested from the stores, and the slice [i, i+n) of samples shown.
res=256
i, n = 0,4
print(gs1.rgb.get(res)[i:i+n].size())
# permute(0,2,3,1) moves dimension 1 to the last position before display.
visutils.show(gs1.rgb.get(res)[i:i+n].permute(0,2,3,1).cpu(), title='Target')
visutils.show(gs2.rgb.get(res)[i:i+n].permute(0,2,3,1).cpu(), title='References')

AttributeError: 'GANOutputs' object has no attribute 'rgb'

Transfer object styles from references to target

In [11]:
# Edited outputs keyed by part label; filled in by the loop at the end of this cell.
part_gs = {}
# Number of entries in gs.ys; get_epsilons below builds one epsilon per entry.
print(len(gs.ys))
def get_epsilons(epsilon, low_res_epsilon=0, num_layers=None, num_low_res=4):
    """Build a list of epsilon values, one per entry of ``gs.ys``.

    The first ``num_low_res`` entries are set to ``low_res_epsilon`` and all
    remaining entries to ``epsilon``.

    Args:
        epsilon: Value used for every entry after the first ``num_low_res``.
        low_res_epsilon: Value used for the first ``num_low_res`` entries.
        num_layers: Length of the returned list. Defaults to ``len(gs.ys)``
            read from the notebook-global ``gs`` (the original behavior).
        num_low_res: How many leading entries receive ``low_res_epsilon``.
            Clamped to ``[0, num_layers]`` so that short lists no longer
            raise ``IndexError``.

    Returns:
        A new list of length ``num_layers``.
    """
    if num_layers is None:
        num_layers = len(gs.ys)
    # Clamp so that fewer than `num_low_res` layers is handled gracefully.
    num_low_res = max(0, min(num_low_res, num_layers))
    return [low_res_epsilon] * num_low_res + [epsilon] * (num_layers - num_low_res)

# Per-part settings: label -> (rho, epsilons). The transfer loop below unpacks
# each value in that order and forwards both to si_wf.interpolate_ys.
if dataset_name == 'ffhq':
    parts_thresholds = {
        'eyes': (0.1, get_epsilons(50, 5)),
        'nose': (0.1, get_epsilons(30, 5)),
        'mouth': (0.1, get_epsilons(50, 5)),
    }
elif dataset_name == 'bedrooms':
    parts_thresholds = {
        'bed': (0.01, get_epsilons(120)),
        'pillow': (0.05, get_epsilons(100)),
        'window': (0.05, get_epsilons(100)),
    }
else:
    # Earlier cells also accept 'cat', but no part settings exist for it here.
    # Fail with a clear message instead of a NameError in the loop below (or,
    # in a live notebook, silently reusing thresholds from a previous run).
    raise ValueError(
        "No part thresholds defined for dataset_name {!r}; "
        "expected 'ffhq' or 'bedrooms'".format(dataset_name)
    )

# For every part, interpolate the target's styles (gs1.ys) towards the
# references' styles (gs2.ys), render the result, and store it in part_gs.
for label, (rho, epsilon) in parts_thresholds.items():
    key = label
    # Register the container first so part_gs holds an entry for this part
    # even if a later step raises.
    edited = part_gs[key] = GANOutputs()
    edited.ys = si_wf.interpolate_ys(gs1.ys, gs2.ys, label, rho, epsilon)
    with torch.no_grad():
        rgb = G.ys_to_rgb(edited.ys)
        # Clamp to [-1, 1], rescale to [0, 1], then move to the CPU.
        rgb = ((rgb.clamp(-1, 1) + 1) / 2).cpu()
        edited.rgb = ptutils.MultiResolutionStore(rgb)

26
torch.Size([1, 1, 512, 1, 1]) torch.Size([4, 1, 512, 1, 1]) torch.Size([512])
torch.Size([1, 1, 512, 1, 1]) torch.Size([4, 1, 512, 1, 1]) torch.Size([512])
torch.Size([1, 1, 512, 1, 1]) torch.Size([4, 1, 512, 1, 1]) torch.Size([512])
torch.Size([1, 1, 512, 1, 1]) torch.Size([4, 1, 512, 1, 1]) torch.Size([512])
torch.Size([1, 1, 512, 1, 1]) torch.Size([4, 1, 512, 1, 1]) torch.Size([512])
torch.Size([1, 1, 512, 1, 1]) torch.Size([4, 1, 512, 1, 1]) torch.Size([512])
torch.Size([1, 1, 512, 1, 1]) torch.Size([4, 1, 512, 1, 1]) torch.Size([512])
torch.Size([1, 1, 512, 1, 1]) torch.Size([4, 1, 512, 1, 1]) torch.Size([512])
torch.Size([1, 1, 512, 1, 1]) torch.Size([4, 1, 512, 1, 1]) torch.Size([256])


RuntimeError: The size of tensor a (256) must match the size of tensor b (512) at non-singleton dimension 1

View the results

In [None]:
# Show target, references and each part's edited output at resolution `res`.
res = 256
target_rgb = gs1.rgb.get(res)
reference_rgb = gs2.rgb.get(res)
edited_rgb = {part: outputs.rgb.get(res) for part, outputs in part_gs.items()}
# The trailing semicolon keeps the notebook from echoing the return value.
visutils.part_grid(target_rgb, reference_rgb, edited_rgb);

View the MSE in CIELAB color-space, between the edited output and the target image

In [None]:
res = 256


def normalize(x):
    """Divide x by its own maximum."""
    return x/x.max()


# Per-part squared error (via cielab.squared_error) between each edited
# output and the target, normalized for display.
target_rgb = gs1.rgb.get(res)
reference_rgb = gs2.rgb.get(res)
error_maps = {}
for part, outputs in part_gs.items():
    error_maps[part] = normalize(cielab.squared_error(outputs.rgb.get(res), gs1.rgb.get(res)))
# The trailing semicolon keeps the notebook from echoing the return value.
visutils.part_grid(target_rgb, reference_rgb, error_maps);