<a href="https://colab.research.google.com/github/Andre6o6/stylegan-editing/blob/master/StyleGAN_encode_images.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
!git clone https://github.com/genforce/interfacegan.git

Cloning into 'interfacegan'...
remote: Enumerating objects: 223, done.[K
Receiving objects:   0% (1/223)   Receiving objects:   1% (3/223)   Receiving objects:   2% (5/223)   Receiving objects:   3% (7/223)   Receiving objects:   4% (9/223)   Receiving objects:   5% (12/223)   Receiving objects:   6% (14/223)   Receiving objects:   7% (16/223)   Receiving objects:   8% (18/223)   Receiving objects:   9% (21/223)   Receiving objects:  10% (23/223)   Receiving objects:  11% (25/223)   Receiving objects:  12% (27/223)   Receiving objects:  13% (29/223)   Receiving objects:  14% (32/223)   Receiving objects:  15% (34/223)   Receiving objects:  16% (36/223)   Receiving objects:  17% (38/223)   Receiving objects:  18% (41/223)   Receiving objects:  19% (43/223)   Receiving objects:  20% (45/223)   Receiving objects:  21% (47/223)   Receiving objects:  22% (50/223)   Receiving objects:  23% (52/223)   Receiving objects:  24% (54/223), 92.01 KiB | 168.00 KiB/s   Recei

In [0]:
!gdown https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ
#!mv x.pkl interfacegan/models/pretrain/

Permission denied: https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ
Maybe you need to change permission over 'Anyone with the link'?


In [0]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm

# Use the first CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [0]:
import numpy as np
import cv2
from PIL import Image

def load_image(path):
    """Read an image file and return it as a channels-first array.

    Args:
        path: Location of the image file, in any form accepted by ``PIL.Image.open``.

    Returns:
        A ``uint8`` NumPy array laid out as (3, H, W).
    """
    # The context manager closes the underlying file handle once the pixels are copied.
    with Image.open(path) as source:
        # Force three channels so grayscale / RGBA / palette files also yield (H, W, 3).
        image = np.array(source.convert("RGB"))
    return np.transpose(image, (2, 0, 1))  # HxWxC -> CxHxW

def save_image(image, save_path):
    """Write a channels-first image array to disk.

    Args:
        image: Array laid out as (C, H, W); values are cast to ``uint8`` before saving.
        save_path: Destination handed to ``PIL.Image.Image.save``.
    """
    channels_last = np.transpose(image, (1, 2, 0))  # CxHxW -> HxWxC
    Image.fromarray(channels_last.astype(np.uint8)).save(save_path)

In [0]:
from torchvision.models import vgg16

def denormalize(synthesized_image, min_value=-1, max_value=1):
    """Map a tensor from ``[min_value, max_value]`` onto the 0..255 pixel scale.

    The mapping is an affine rescale followed by a +0.5 shift and a clamp to
    ``[0, 255]``. Only differentiable torch ops are used, so gradients flow back
    to the input (the clamp contributes zero gradient where it saturates).

    Args:
        synthesized_image: Tensor with values nominally in ``[min_value, max_value]``.
        min_value: Lower end of the input range.
        max_value: Upper end of the input range.

    Returns:
        Tensor of the same shape with values in ``[0, 255]``.
    """
    value_range = max_value - min_value
    rescaled = 255. * (synthesized_image - min_value) / value_range
    # The half-unit shift presumably makes a later truncating uint8 cast round
    # to nearest -- TODO confirm.
    return torch.clamp(rescaled + 0.5, min=0, max=255)

class VGGFeatureExtractor(nn.Module):
    """Truncated, pretrained VGG16 used as a fixed feature extractor.

    Args:
        vgg_layer: Number of leading modules of ``vgg16().features`` to keep.
    """

    def __init__(self, vgg_layer=12):
        super().__init__()
        # Side length that inputs are average-pooled to before entering VGG.
        self.image_size = 256
        # Per-channel statistics, shaped (C, 1, 1) so they broadcast over H and W.
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        self.mean = torch.tensor(channel_mean).to(device).view(-1, 1, 1)
        self.std = torch.tensor(channel_std).to(device).view(-1, 1, 1)

        truncated = vgg16(pretrained=True).features[:vgg_layer]
        self.vgg16 = truncated.to(device).eval()

    def forward(self, image):
        """Return VGG features for ``image``.

        The input is divided by 255, pooled to ``image_size`` and standardised
        with ``mean`` / ``std`` before being passed through the network.
        """
        pooled = F.adaptive_avg_pool2d(image / 255., self.image_size)
        standardised = (pooled - self.mean) / self.std
        return self.vgg16(standardised)

class LatentOptimizer(nn.Module):
    """Chains a synthesis network with a VGG feature extractor.

    Args:
        synthesizer: Module that turns dlatents into an image; it is moved to
            ``device`` and put in eval mode.
        vgg_layer: Forwarded to ``VGGFeatureExtractor``.
    """

    def __init__(self, synthesizer, vgg_layer=9):
        super().__init__()
        self.synthesizer = synthesizer.to(device).eval()
        self.feature_extractor = VGGFeatureExtractor(vgg_layer)

    def forward(self, dlatents):
        """Synthesise an image from ``dlatents`` and extract its VGG features.

        Returns:
            ``(features, generated_image)`` where ``generated_image`` is the
            synthesizer output after ``denormalize``.
        """
        rendered = denormalize(self.synthesizer(dlatents))
        return self.feature_extractor(rendered), rendered

In [0]:
class LatentLoss(torch.nn.Module):
    """Weighted sum of feature, pixel and latent-regularisation losses.

    Only the feature term is mandatory; the pixel and dlatent terms are added
    when both of their respective tensors are supplied.
    """

    def __init__(self):
        super().__init__()
        self.l1_loss = nn.L1Loss()
        self.log_cosh_loss = LogCoshLoss()
        self.l2_loss = nn.MSELoss()

        # Weights of the individual terms.
        self.vgg_loss_coef = 0.4
        self.pixel_loss_coef = 1.5
        self.l1_penalty = 0.3

    def forward(
        self,
        real_features, generated_features,
        real_image=None, generated_image=None,
        average_dlatents=None, dlatents=None,
    ):
        """Return the combined loss (``0`` when no term is active)."""
        terms = []

        # Mean-squared error between the VGG16 feature maps.
        if self.vgg_loss_coef != 0:
            terms.append(
                self.vgg_loss_coef * self.l2_loss(real_features, generated_features)
            )

        # Log-cosh loss on image pixels.
        have_images = not (real_image is None or generated_image is None)
        if have_images:
            terms.append(
                self.pixel_loss_coef * self.log_cosh_loss(real_image, generated_image)
            )

        # Dlatent loss: L1 distance to the average dlatents, scaled by 512.
        have_dlatents = not (average_dlatents is None or dlatents is None)
        if have_dlatents:
            terms.append(
                self.l1_penalty * 512 * self.l1_loss(average_dlatents, dlatents)
            )

        # sum() starts from the int 0, matching an accumulator initialised to 0.
        return sum(terms)

class LogCoshLoss(torch.nn.Module):
    """Mean log-cosh of the element-wise difference between two tensors.

    Evaluating ``log(cosh(x))`` directly overflows to ``inf`` in float32 once
    ``|x|`` is roughly 89 or more, because ``cosh`` itself overflows. The
    mathematically identical form

        log(cosh(x)) = |x| + log(1 + exp(-2|x|)) - log(2)

    is used instead; it stays finite for arbitrarily large differences.
    """

    # Natural logarithm of 2.
    _LN2 = 0.6931471805599453

    def __init__(self):
        super().__init__()

    def forward(self, true, pred):
        """Return the scalar mean of ``log(cosh(true - pred))``.

        Args:
            true: Target tensor.
            pred: Prediction tensor, broadcastable against ``true``.
        """
        magnitude = torch.abs(true - pred)
        # softplus(z) = log(1 + exp(z)); its argument is never positive here,
        # so the exponential cannot overflow.
        stable = magnitude + torch.nn.functional.softplus(-2.0 * magnitude) - self._LN2
        return torch.mean(stable)

In [0]:
# Destination handed to np.save() for the optimised dlatents.
# NOTE(review): np.save appends ".npy" when the name lacks it, so "latents/"
# would resolve to "latents/.npy" -- confirm the intended file name.
dlatent_path = "latents/"
# Checkpoint loaded into InitialLatentPredictor when predict_initial_approximation is True.
latent_predictor_path = ""
# Path of the reference image handed to load_image(); left empty here.
image_path = ""
# True: start from the latents predicted by InitialLatentPredictor; False: start from zeros.
predict_initial_approximation = False

In [0]:
from interfacegan.models.stylegan_generator import StyleGANGenerator

In [0]:
# Take the `synthesis` sub-module of the "stylegan_ffhq" generator and wrap it
# together with a truncated VGG16 (first 9 feature modules).
synthesizer = StyleGANGenerator("stylegan_ffhq").model.synthesis
latent_optimizer = LatentOptimizer(synthesizer, vgg_layer=9)

# Freeze every parameter of the wrapped networks. They are not handed to the
# optimizer anyway; switching requires_grad off additionally stops autograd
# from computing gradients for them during backward().
for param in latent_optimizer.parameters():
    param.requires_grad_(False)

# Load the reference image as a (1, C, H, W) tensor on `device`.
reference_image = load_image(image_path)
reference_image = torch.from_numpy(reference_image).unsqueeze(0).to(device)

# Target VGG features, detached so they act as constants in the loss.
reference_features = latent_optimizer.feature_extractor(reference_image).detach()
reference_image = reference_image.detach()

In [0]:
# Pick the starting point of the latent optimisation.
# NOTE(review): InitialLatentPredictor is defined in a later cell of this
# notebook, so this branch only works after that cell has been run.
if predict_initial_approximation:
    image_to_latent = InitialLatentPredictor().to(device)
    image_to_latent.load_state_dict(torch.load(latent_predictor_path))
    image_to_latent.eval()

    # NOTE(review): reference_image is the raw tensor built from load_image(),
    # whereas the training cells feed the predictor Resize(256) + Normalize'd
    # float tensors -- confirm the input is preprocessed consistently.
    with torch.no_grad():
        initial_latents = image_to_latent(reference_image)
    # Computed under no_grad, so the result can be made a trainable leaf directly.
    initial_latents = initial_latents.to(device).requires_grad_(True)
else:
    # Start from an all-zero (1, 18, 512) latent tensor.
    initial_latents = torch.zeros((1,18,512)).to(device).requires_grad_(True)

In [0]:
# Optimise the latent tensor so that the VGG features of the synthesised image
# match those of the reference image. Only `initial_latents` is updated.
criterion = LatentLoss()
optimizer = torch.optim.Adam([initial_latents], lr=0.025)

n_iters = 100
progress_bar = tqdm(range(n_iters))
for step in progress_bar:
    optimizer.zero_grad()

    generated_image_features, _ = latent_optimizer(initial_latents)

    # Arguments follow LatentLoss.forward(real_features, generated_features);
    # only the (symmetric) feature term is active here.
    loss = criterion(reference_features, generated_image_features)
    loss.backward()
    optimizer.step()
    progress_bar.set_description("{}/{}: Loss = {}".format(step+1, n_iters, loss.item()))

optimized_dlatents = initial_latents.detach().cpu().numpy()

# np.save does not create missing directories, so make sure the parent exists
# rather than failing after the whole optimisation has run.
# NOTE(review): np.save appends ".npy" to names lacking it, so a path ending
# in "/" produces a file literally named ".npy" inside that directory.
os.makedirs(os.path.dirname(dlatent_path) or ".", exist_ok=True)
np.save(dlatent_path, optimized_dlatents)

In [0]:
#avg_dlatents = StyleGANGenerator("stylegan_ffhq").model.truncation.w_avg



# Initial approximation prediction

In [0]:
import torch
from torchvision.models import resnet50
from PIL import Image
import numpy as np

class InitialLatentPredictor(torch.nn.Module):
    """Regress an (18, 512) latent tensor from an image with a ResNet-50 trunk.

    Args:
        image_size: Side length of the square input images. It determines the
            input width of the first dense layer. Defaults to 256, which
            reproduces the original fixed architecture, so both
            ``InitialLatentPredictor()`` and ``InitialLatentPredictor(256)``
            build the same network.
    """

    def __init__(self, image_size=256):
        super().__init__()
        self.image_size = image_size
        # NOTE: constructed but never applied in forward(); left untouched so
        # previously trained weights keep producing the same outputs.
        self.activation = torch.nn.ELU()

        # 3, image_size, image_size ->
        trunk = list(resnet50(pretrained=True).children())[:-2]
        self.resnet = torch.nn.Sequential(*trunk)
        # -> 2048, s, s  where s = ceil(image_size / 32)  (8 for 256 inputs)
        feature_side = -(-image_size // 32)
        self.conv2d = torch.nn.Conv2d(2048, 256, kernel_size=1)
        self.flatten = torch.nn.Flatten()
        self.dense1 = torch.nn.Linear(256 * feature_side * feature_side, 512)
        self.dense2 = torch.nn.Linear(512, (18 * 512))

    def forward(self, image):
        """Return predicted latents of shape (batch, 18, 512) for ``image``."""
        x = self.resnet(image)
        x = self.conv2d(x)
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dense2(x)
        x = x.view((-1, 18, 512))
        return x

class ImageLatentDataset(torch.utils.data.Dataset):
    """Dataset pairing image files with their dlatents by position.

    Args:
        filenames: Sequence of image paths.
        dlatents: Indexable collection; entry ``i`` belongs to ``filenames[i]``.
        transforms: Optional callable applied to the opened PIL image.
    """

    def __init__(self, filenames, dlatents, transforms = None):
        self.filenames, self.dlatents = filenames, dlatents
        self.transforms = transforms

    def __len__(self):
        """Number of samples, taken from the filename list."""
        return len(self.filenames)

    def __getitem__(self, index):
        """Return ``(image, dlatent)`` for position ``index``."""
        filename, dlatent = self.filenames[index], self.dlatents[index]

        image = Image.open(filename)
        if not self.transforms:
            return image, dlatent
        return self.transforms(image), dlatent

In [0]:
from InterFaceGAN.models.stylegan_generator import StyleGANGenerator
from torchvision import transforms
import matplotlib.pyplot as plt
import torch
from glob import glob
from tqdm import tqdm_notebook as tqdm
import numpy as np

In [0]:
# Preprocessing applied to every image before it reaches the predictor.
# NOTE(review): torchvision's Resize with a single int is documented to scale
# only the shorter side -- this presumably relies on the images being square; confirm.
augments = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Passed to InitialLatentPredictor(...) in the cells below.
image_size = 256

# Images are paired with dlatents purely by position, so the sorted filename
# order must match the row order of wp.npy -- presumably it does; verify.
directory = "../StyleGan/InterFaceGAN/test15/"
filenames = sorted(glob(directory + "*.jpg"))

# First 48000 samples for training, the remainder for validation.
train_filenames = filenames[0:48000]
validation_filenames = filenames[48000:]

dlatents = np.load(directory + "wp.npy")

# Same split point as the filenames so pairs stay aligned.
train_dlatents = dlatents[0:48000]
validation_dlatents = dlatents[48000:]

train_dataset = ImageLatentDataset(train_filenames, train_dlatents, transforms=augments)
validation_dataset = ImageLatentDataset(validation_filenames, validation_dlatents, transforms=augments)

# NOTE(review): neither loader shuffles; batches follow the sorted file order.
train_generator = torch.utils.data.DataLoader(train_dataset, batch_size=32)
validation_generator = torch.utils.data.DataLoader(validation_dataset, batch_size=32)

In [0]:
# Predictor on the GPU, Adam with its default hyper-parameters over all of the
# predictor's parameters, and log-cosh loss between predicted and target latents.
# NOTE(review): image_size is passed positionally -- confirm that
# InitialLatentPredictor's constructor accepts it.
# NOTE(review): .cuda() is unconditional here, unlike the `device` fallback used
# in the earlier cells.
image_to_latent = InitialLatentPredictor(image_size).cuda()
optimizer = torch.optim.Adam(image_to_latent.parameters())
criterion = LogCoshLoss()

In [0]:
# Train the latent predictor and report running / validation losses.
# Training and validation keep separate step counters so that the running
# training loss is never divided by the number of validation batches.
epochs = 20
validation_loss = 0.0

progress_bar = tqdm(range(epochs))
for epoch in progress_bar:
    running_loss = 0.0
    train_steps = 0

    image_to_latent.train()
    for train_steps, (images, latents) in enumerate(train_generator, 1):
        optimizer.zero_grad()

        images, latents = images.cuda(), latents.cuda()
        pred_latents = image_to_latent(images)
        loss = criterion(pred_latents, latents)
        loss.backward()

        optimizer.step()

        running_loss += loss.item()
        # validation_loss shown here is the value from the previous epoch.
        progress_bar.set_description("Step: {0}, Loss: {1:4f}, Validation Loss: {2:4f}".format(train_steps, running_loss / train_steps, validation_loss))

    validation_total = 0.0
    validation_steps = 0

    image_to_latent.eval()
    with torch.no_grad():
        for validation_steps, (images, latents) in enumerate(validation_generator, 1):
            images, latents = images.cuda(), latents.cuda()
            pred_latents = image_to_latent(images)
            loss = criterion(pred_latents, latents)

            validation_total += loss.item()

    # max(..., 1) keeps the averages defined when a loader yields no batches.
    validation_loss = validation_total / max(validation_steps, 1)
    progress_bar.set_description("Step: {0}, Loss: {1:4f}, Validation Loss: {2:4f}".format(train_steps, running_loss / max(train_steps, 1), validation_loss))

In [0]:
# Persist only the weights (state_dict); reloading requires constructing the
# model first and then calling load_state_dict().
torch.save(image_to_latent.state_dict(), "./image_to_latent.pt")

In [0]:
# Rebuild the predictor and restore the weights saved to image_to_latent.pt,
# then switch to eval mode for inference.
# NOTE(review): image_size is passed positionally -- confirm that
# InitialLatentPredictor's constructor accepts it.
image_to_latent = InitialLatentPredictor(image_size).cuda()
image_to_latent.load_state_dict(torch.load("image_to_latent.pt"))
image_to_latent.eval()
# Emits an empty line; presumably placed last so the cell does not echo the
# module repr returned by eval() -- confirm.
print()

In [0]:
def normalized_to_normal_image(image):
    """Undo mean/std normalisation and return a displayable uint8 image.

    The input is left untouched: all arithmetic is out-of-place, so a tensor
    that already lives on the CPU is not modified through shared storage.

    Args:
        image: Tensor of shape (1, 3, H, W) normalised with the mean / std
            used below.

    Returns:
        ``uint8`` NumPy array of shape (H, W, 3) with values in [0, 255].
    """
    mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1).float()
    std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1).float()

    restored = (image.detach().cpu() * std + mean) * 255
    # Clamp so rounding error or out-of-range inputs cannot wrap around in the
    # uint8 cast below.
    restored = restored.clamp(0, 255)

    pixels = restored.numpy()[0]
    pixels = np.transpose(pixels, (1, 2, 0))  # CxHxW -> HxWxC
    return pixels.astype(np.uint8)


# Take a few validation images, predict their latents and synthesise the
# corresponding StyleGAN images for a side-by-side comparison.
num_test_images = 5
images = [validation_dataset[i][0].unsqueeze(0).cuda() for i in range(num_test_images)]
normal_images = list(map(normalized_to_normal_image, images))

synthesizer = StyleGANGenerator("stylegan_ffhq").model.synthesis


def post_process(image):
    """Convert one synthesised batch to a uint8 (C, H, W) NumPy image."""
    return denormalize(image).detach().cpu().numpy().astype(np.uint8)[0]


# Pure inference: disabling autograd avoids keeping a graph for every
# synthesised image. Results are materialised as lists (not lazy `map`
# objects) so they remain usable after this cell has run.
with torch.no_grad():
    pred_dlatents = [image_to_latent(image) for image in images]
    pred_images = [post_process(synthesizer(dlatent)) for dlatent in pred_dlatents]

# CxHxW -> HxWxC for matplotlib.
pred_images = [np.transpose(image, (1, 2, 0)) for image in pred_images]

In [0]:
# Two-row comparison grid: reference images on top, images generated from the
# predicted latents underneath.
figure = plt.figure(figsize=(25,10))
columns = len(normal_images)
rows = 2

axis = []

panel_rows = (
    ("Reference Image", normal_images),
    ("Generated With Predicted Latents", pred_images),
)
for row_index, (title, row_images) in enumerate(panel_rows):
    for column_index in range(columns):
        # Subplot positions are 1-based and run row by row.
        position = row_index * columns + column_index + 1
        axis.append(figure.add_subplot(rows, columns, position))
        axis[-1].set_title(title)
        plt.imshow(row_images[column_index])

plt.show()