Projecting a real photo: there are some strange circles in the image #35

Open
hgh086 opened this issue Oct 19, 2021 · 1 comment

Comments

hgh086 commented Oct 19, 2021

When projecting a real photo, there are some strange circles in the image, like below:
[attached images: proj, proj2]
What causes this? How can it be avoided?
Thanks

hgh086 commented Oct 19, 2021

The code is here:

"""Project given image to the latent space of pretrained network pickle."""

import copy
import os
from time import perf_counter

import click
import imageio
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F

import dnnlib
import legacy
from metrics import metric_utils

def project(
    G,
    target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
    *,
    num_steps = 1000,
    w_avg_samples = 10000,
    initial_learning_rate = 0.1, # org 0.1
    initial_noise_factor = 0.05,
    lr_rampdown_length = 0.25,
    lr_rampup_length = 0.05,
    noise_ramp_length = 0.75,
    regularize_noise_weight = 1e5,
    verbose = False,
    device: torch.device,
    label: torch.Tensor,
    noise_mode = "const"
):
    assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)

    def logprint(*args):
        if verbose:
            print(*args)

    G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore

    # Compute w stats.
    logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
    z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None)  # [N, L, C]
    w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)     # [N, 1, C]
    w_avg = np.mean(w_samples, axis=0, keepdims=True)                    # [1, 1, C]
    w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
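    # NOTE: initial_noise_factor / noise_ramp_length are unused in this version;
    # the reference projector perturbs w_opt each step with Gaussian noise scaled
    # by w_std and these factors, which has been removed here.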

    # Load VGG16 feature detector.
    #url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
    #with dnnlib.util.open_url(url) as f:
    #    vgg16 = torch.jit.load(f).eval().to(device)
    vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
    vgg16 = metric_utils.get_feature_detector(vgg16_url, device=device)
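    # get_feature_detector downloads the detector pickle (caching it locally)
    # and returns it as a torch module on the given device.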

    # Features for target image.
    target_images = target.unsqueeze(0).to(device).to(torch.float32)
    if target_images.shape[2] > 256:
        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
    target_features = vgg16(target_images, resize_images=False, return_lpips=True)
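    # With return_lpips=True the detector returns LPIPS-style feature stacks,
    # so the squared distance computed below acts as a perceptual loss.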

    w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
    w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)
    optimizer = torch.optim.Adam([w_opt], betas=(0.9, 0.999), lr=initial_learning_rate)
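    # NOTE: only w_opt is optimized; the reference projector also optimizes the
    # per-layer noise buffers (the 'noise_const' tensors in G.synthesis) alongside it.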

    for step in range(num_steps):
        # Learning rate schedule.
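        # Linear warm-up over the first lr_rampup_length fraction of steps,
        # cosine ramp-down over the last lr_rampdown_length fraction.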
        t = step / num_steps
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Synth images from opt_w.
        synth_images = G(w_opt[0], label, noise_mode=noise_mode)
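        # NOTE: G.forward treats its first argument as a z latent and runs it
        # through the mapping network; the reference projector instead calls
        # G.synthesis(ws) with w_opt broadcast across all G.mapping.num_ws layers.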

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        synth_images = (synth_images + 1) * (255/2)
        if synth_images.shape[2] > 256:
            synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

        # Features for synth images.
        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
        dist = (target_features - synth_features).square().sum()
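        # NOTE: regularize_noise_weight is unused here; the reference projector
        # adds a noise-regularization term at this point (penalizing spatial
        # correlation in the noise buffers), i.e. loss = dist + reg_loss * regularize_noise_weight.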

        loss = dist

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        logprint(f'step {step+1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')

        # Save projected W for each optimization step.
        w_out[step] = w_opt.detach()[0]

    return w_out

#----------------------------------------------------------------------------

@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--target', 'target_fname', help='Target image file to project to', required=True, metavar='FILE')
@click.option('--num-steps', help='Number of optimization steps', type=int, default=1000, show_default=True)
@click.option('--seed', help='Random seed', type=int, default=303, show_default=True)
@click.option('--save-video', help='Save an mp4 video of optimization progress', type=bool, default=True, show_default=True)
@click.option('--outdir', help='Where to save the output images', required=True, metavar='DIR')
def run_projection(
    network_pkl: str,
    target_fname: str,
    outdir: str,
    save_video: bool,
    seed: int,
    num_steps: int
):
"""Project given image to the latent space of pretrained network pickle.

Examples:

\b
python projector.py --outdir=out --target=~/mytargetimg.png \\
    --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
"""
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Load networks.
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].to(device) # type: ignore

    # Load target image: center-crop to square, then resize to G's resolution.
    target_pil = PIL.Image.open(target_fname).convert('RGB')
    w, h = target_pil.size
    s = min(w, h)
    target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
    target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
    target_uint8 = np.array(target_pil, dtype=np.uint8)

    # Conditioning label (all zeros).
    label = torch.zeros([1, G.c_dim], device=device)

    # noise_mode: "random", "none", or "const"
    noise_mode = "const"

    # Optimize projection.
    start_time = perf_counter()
    projected_w_steps = project(
        G,
        target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device), # pylint: disable=not-callable
        num_steps=num_steps,
        device=device,
        verbose=True,
        label=label,
        noise_mode=noise_mode
    )
    print(f'Elapsed: {(perf_counter()-start_time):.1f} s')

    # Render debug output: optional video and projected image and W vector.
    os.makedirs(outdir, exist_ok=True)
    if save_video:
        video = imageio.get_writer(f'{outdir}/proj.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
        print(f'Saving optimization progress video "{outdir}/proj.mp4"')
        for projected_w in projected_w_steps:
            synth_image = G(projected_w, label, truncation_psi=1, noise_mode=noise_mode)
            synth_image = (synth_image + 1) * (255/2)
            synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
            video.append_data(np.concatenate([target_uint8, synth_image], axis=1))
        video.close()

    # Save final projected frame and W vector.
    target_pil.save(f'{outdir}/target.png')
    projected_w = projected_w_steps[-1]
    synth_image = G(projected_w, label, noise_mode=noise_mode)
    synth_image = (synth_image + 1) * (255/2)
    synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
    PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/proj.png')
    np.savez(f'{outdir}/projected_w.npz', w=projected_w.unsqueeze(0).cpu().numpy())
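    # NOTE: this saves a single w of shape [1, 1, C]; the reference projector
    # saves the per-layer broadcast of shape [1, G.mapping.num_ws, C].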

#----------------------------------------------------------------------------

if name == "main":
run_projection() # pylint: disable=no-value-for-parameter

#----------------------------------------------------------------------------
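For comparison, the reference projector in stylegan2-ada-pytorch renders from the optimized latent via G.synthesis, broadcasting the single w vector to every layer, rather than feeding it back through G.forward (which re-runs the mapping network on it as if it were a z latent). A minimal sketch, reusing G and w_opt from the code above:

ws = w_opt.repeat([1, G.mapping.num_ws, 1])         # [1, num_ws, C]: same w at every layer
synth_images = G.synthesis(ws, noise_mode='const')  # synthesize directly from W, skipping G.mapping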
