# Imports

In [None]:
!git clone https://github.com/NVlabs/stylegan2-ada-pytorch.git

In [None]:
!pip install facenet-pytorch --force-reinstall --no-cache-dir ninja

In [None]:
import sys
sys.path.insert(0, "/content/stylegan2-ada-pytorch")

In [None]:
import os
import urllib.request

# Create a folder for models
os.makedirs('models', exist_ok=True)

# Download the stylegan2-ada-pytorch FFHQ model (resolution 1024x1024).
# This is hosted by NVIDIA.
# - The download is skipped when the pickle is already present, so re-running
#   the cell does not fetch the large file again.
# - urlretrieve raises on HTTP/network errors instead of silently continuing,
#   so "Download complete." is only printed when the file is actually there.
# - Data goes to a ".part" file first and is renamed atomically, so an
#   interrupted download never leaves a truncated models/ffhq.pkl behind.
FFHQ_URL = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl'
FFHQ_PATH = 'models/ffhq.pkl'

if not os.path.exists(FFHQ_PATH):
    _partial_path = FFHQ_PATH + '.part'
    urllib.request.urlretrieve(FFHQ_URL, _partial_path)
    os.replace(_partial_path, FFHQ_PATH)

print("Download complete.")

In [None]:
import torch
import pickle
import copy
import dnnlib
import legacy # From the cloned repo

import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
from facenet_pytorch import InceptionResnetV1
from torchvision import transforms
from torchvision import models
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial.distance import cosine
from tqdm import tqdm

In [None]:
# Pick the compute device once; every later cell reads this global.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

In [None]:
# Load the precomputed dictionary of {filename: embedding_vector}.
# NOTE(review): pickle.load can execute arbitrary code; only load a trusted file.
with open("embeddings.pkl", mode="rb") as pkl_file:
    embeddings = pickle.load(pkl_file)

In [None]:
class StyleGANGenerator(torch.nn.Module):
    """Frozen wrapper around the ``G_ema`` generator of a stylegan2-ada-pytorch pickle.

    ``forward`` runs only the synthesis network, so a W+ latent code can be
    optimized while every generator weight stays fixed.

    Args:
        network_pkl: path or URL of the network pickle, opened with
            ``dnnlib.util.open_url`` and parsed by ``legacy.load_network_pkl``.
    """

    def __init__(self, network_pkl):
        super().__init__()
        print(f'Loading network from "{network_pkl}"...')

        with dnnlib.util.open_url(network_pkl) as f:
            # Keep only the EMA generator and move it to the notebook-global `device`.
            self.G = legacy.load_network_pkl(f)['G_ema'].to(device)

        # Lock the weights (we never train the generator itself)
        self.G.eval()
        self.G.requires_grad_(False)

        # Store useful constants
        self.w_dim = self.G.w_dim  # Usually 512
        self.num_ws = self.G.mapping.num_ws  # Usually 18 for 1024x1024
        print(f'Loaded network! (w_dim: {self.w_dim}, num_ws: {self.num_ws})')

    def forward(self, w_plus_vector):
        """Synthesize images from a W+ latent code.

        Args:
            w_plus_vector: tensor of shape (batch, num_ws, w_dim), e.g.
                (B, 18, 512) for the 1024x1024 FFHQ model.

        Returns:
            Image tensor of shape (batch, 3, H, W), e.g. (B, 3, 1024, 1024).
            Values are roughly in [-1, 1] but are NOT clamped here.
        """
        # A W+ code already holds one w vector per synthesis layer, i.e. the
        # (batch, num_ws, w_dim) layout G.synthesis takes, so no reshaping is needed.
        # noise_mode='const' keeps the per-pixel noise fixed instead of resampling
        # it on every call, so the same code always yields the same image.
        return self.G.synthesis(w_plus_vector, noise_mode='const')

    def get_mean_w(self, n_samples=4096):
        """Estimate the average latent code by sampling and averaging.

        Draws ``n_samples`` random z vectors, maps them through ``G.mapping``
        and averages over the sample axis. Starting the optimization from this
        "mean face" is much faster and easier than from a random code.

        Args:
            n_samples: number of z samples to average.

        Returns:
            Tensor with a leading batch dimension of 1 (``keepdim=True``), on ``device``.
        """
        # Pure inference: no autograd graph is needed for the estimate.
        with torch.no_grad():
            z = torch.randn(n_samples, self.G.z_dim, device=device)
            w = self.G.mapping(z, None)  # Convert z to w (no conditioning label)
            w_avg = w.mean(0, keepdim=True)
        return w_avg

# Build the frozen FFHQ generator from the downloaded pickle.
generator = StyleGANGenerator(network_pkl='models/ffhq.pkl')
print("Generator Loaded Successfully!")

In [None]:
# Render the "mean face": the image the generator produces from the average latent code.
w_mean = generator.get_mean_w()

with torch.no_grad():
    generated_img_tensor = generator(w_mean)

# Generator output is (B, 3, H, W) in about [-1, 1]: clamp, shift to [0, 1],
# then keep the first (only) image of the batch on the CPU.
vis_img = generated_img_tensor.clamp(-1, 1).add(1).div(2.0)[0].cpu()

# matplotlib wants channels last: (H, W, 3)
plt.imshow(vis_img.permute(1, 2, 0).numpy())
plt.axis('off')
plt.title("The Average Person (Mean Face)")
plt.show()

In [None]:
# FaceNet (InceptionResnetV1 pretrained on VGGFace2) is the identity encoder.
# Its weights are frozen. The attack loop backpropagates through it only to get
# gradients w.r.t. the latent code. With trainable weights, every backward()
# would also compute and accumulate unused per-parameter grads (nothing zeroes them).
model = InceptionResnetV1(pretrained='vggface2').eval().to(device)
model.requires_grad_(False)

# Resize to FaceNet's 160x160 input and convert the PIL image to a float tensor in [0, 1].
transform = transforms.Compose([
    transforms.Resize((160, 160)),
    transforms.ToTensor(),
])

In [None]:
def _load_face(path):
    """Open *path* as RGB; return (PIL image, 1-image batch tensor on `device`)."""
    rgb_image = Image.open(path).convert("RGB")
    batch = transform(rgb_image).unsqueeze(0).to(device)
    return rgb_image, batch


img_00001, x_00001 = _load_face("00001.jpg")
img_00002, x_00002 = _load_face("00002.jpg")

In [None]:
# Reference FaceNet embeddings of the two real faces (as 1-D numpy arrays).
# `transform` yields [0, 1]; `x*2-1` rescales to [-1, 1], the range of the
# generator output that the attack loop feeds to FaceNet.
# No gradients are needed, so skip building (and then discarding) an autograd graph.
with torch.no_grad():
    emb_00001 = model(x_00001 * 2 - 1).detach().cpu().numpy()[0]
    emb_00002 = model(x_00002 * 2 - 1).detach().cpu().numpy()[0]

In [None]:
# Attack targets: the embedding of 00001.jpg as a float32 batch of one on `device`,
# plus its preprocessed [0, 1] image for perceptual comparison.
target_embedding = torch.tensor(emb_00001, dtype=torch.float32, device=device).unsqueeze(0)
target_image = x_00001

In [None]:
class VGGPerceptualLoss(nn.Module):
    """Perceptual loss: MSE between frozen VGG16 feature maps of two image batches.

    Both inputs are normalized with the ImageNet mean/std and passed through
    ``vgg16.features[:16]`` (up to and including ReLU3_3). The loss is the mean
    squared error between the two resulting feature maps.

    Args:
        resize: stored as ``self.resize`` but NOT used by ``forward``; inputs
            are compared at whatever resolution they arrive in.
    """

    def __init__(self, resize=True):
        super().__init__()

        # Load VGG16 (ImageNet weights) and keep only its convolutional trunk.
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features

        # Slicing up to layer 16 (ReLU3_3) is standard.
        self.blocks = nn.Sequential(*list(vgg.children())[:16]).eval()

        # Freeze the model weights
        self.blocks.requires_grad_(False)

        # ImageNet normalization constants, shaped to broadcast over (B, 3, H, W).
        # They are registered as non-persistent buffers, so `module.to(...)` moves
        # them along with the VGG weights. Plain tensors would stay pinned to
        # whatever device they were created on. Non-persistent keeps them out of
        # the state_dict.
        self.register_buffer(
            'mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            'std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False
        )
        self.resize = resize

    def forward(self, generated_img, target_img):
        """Return the scalar feature-space MSE between two image batches.

        Args:
            generated_img: image batch (B, 3, H, W), assumed to be in [0, 1].
            target_img: image batch of matching shape, assumed to be in [0, 1].
        """
        gen_norm = (generated_img - self.mean) / self.std
        target_norm = (target_img - self.mean) / self.std

        # Extract features
        gen_features = self.blocks(gen_norm)
        target_features = self.blocks(target_norm)

        # L2 loss between the feature maps
        return torch.nn.functional.mse_loss(gen_features, target_features)

perceptual_criterion = VGGPerceptualLoss(resize=True).to(device)  # frozen VGG16 feature-MSE on the compute device

In [None]:
def display_loss_graph(loss_list, log_scale=False, *, ylabel="Loss", title="Loss per Step"):
    """Plot a per-step curve with matplotlib and show it.

    Args:
        loss_list: sequence of scalar values, one per optimization step.
        log_scale: if True, use a logarithmic y-axis.
        ylabel: y-axis label (keyword-only; defaults to the original "Loss").
        title: plot title (keyword-only; defaults to the original "Loss per Step").
            Lets callers label non-loss curves (e.g. a perceptual distance) correctly.
    """
    plt.plot(loss_list)
    if log_scale:
        plt.yscale('log')
    plt.xlabel("Step")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.show()

In [None]:
def save_and_display_image(image, filename):
    """Save a generator-range image tensor to *filename* and display it inline.

    Args:
        image: image tensor of shape (3, H, W), or (1, 3, H, W) (a leading batch
            dim of 1 is squeezed), with values expected in about [-1, 1].
        filename: output path; PIL picks the format from the extension.
    """
    # Clamp to [-1, 1] and map to [0, 1]. Applying torch.tanh here squashed
    # already-bounded generator output (tanh(1) ~= 0.76), washing out contrast.
    # Clamping matches how the other visualization cells prepare images.
    final_image = image.detach().cpu().squeeze(0).clamp(-1, 1)
    final_image = (final_image * 0.5) + 0.5

    final_image = transforms.ToPILImage()(final_image)
    final_image.save(filename)
    display(final_image)

### Setup

In [None]:
iterations = 1000

# A. Latent code to optimize ("the input"), initialized at the mean W code,
#    the most stable starting point. Shape: (1, 18, 512).
w_avg = generator.get_mean_w()
latent_code = w_avg.detach().clone().to(device).requires_grad_(True)

# B. Optimizer over the latent code only, never the image or the generator.
#    W-space learning rates are usually higher (0.01 to 0.1) than pixel-space
#    ones. LinearLR decays the LR from 1.0x to 0.5x over `iterations` steps.
optimizer = optim.Adam([latent_code], lr=0.05)
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=1.0,
    end_factor=0.5,
    total_iters=iterations,
)

# C. Losses. CosineEmbeddingLoss needs a target label of 1 ("similar").
#    NOTE(review): cosine_loss/target_label are not used by the attack loop as written.
mse_loss = torch.nn.MSELoss()
cosine_loss = torch.nn.CosineEmbeddingLoss()
target_label = torch.tensor([1], device=device)

reg_loss_weight = 1e-3

### Attack loop

In [None]:
loss_list = []
perceptual_list = []
min_perc = np.inf
best_latent_code = None

for i in tqdm(range(iterations)):
    optimizer.zero_grad()

    # Generate the image from the current latent code
    generated_image_1024 = generator(latent_code)

    # Resize for FaceNet (160x160)
    generated_image_160 = F.interpolate(generated_image_1024, size=(160, 160), mode='bilinear', align_corners=False)

    # Get Embedding
    current_embedding = model(generated_image_160)

    # Identity loss: MSE between the current and the target FaceNet embedding
    loss_mse = mse_loss(current_embedding, target_embedding)

    # Penalize if the code gets too far from the average face (prevents "weird" artifacts)
    loss_reg = torch.mean((latent_code - w_avg) ** 2)

    # Total Loss
    total_loss = loss_mse + (reg_loss_weight * loss_reg)

    # Perceptual distance to the real target image. It is only monitored, never
    # optimized, so compute it without autograd rather than building a VGG graph
    # that would be thrown away every step.
    with torch.no_grad():
        perceptual_input = (generated_image_160 * 0.5) + 0.5
        perceptual_with_real = perceptual_criterion(perceptual_input, target_image)
    perc_value = perceptual_with_real.item()

    # Record the best code BEFORE optimizer.step(): perc_value was measured on
    # the image produced by the current latent_code. Snapshotting after the step
    # would store the next code, which was never scored.
    if perc_value < min_perc:
        min_perc = perc_value
        best_latent_code = latent_code.detach().clone()

    total_loss.backward()
    optimizer.step()
    scheduler.step()

    if i == 0 or (i + 1) % 100 == 0:
        print(f"Step [{i+1}/{iterations}], Loss: {total_loss.item():.6f}, Perc: {perc_value:.6f}")
        # Optional: display the full-resolution image periodically
        # viz_img = (generated_image_1024.detach().clamp(-1, 1) + 1) / 2.0
        # display(transforms.ToPILImage()(viz_img[0].cpu()))

    loss_list.append(total_loss.item())
    perceptual_list.append(perc_value)

# NOTE: these come from the last forward pass, i.e. the code BEFORE the final
# optimizer.step(); latent_code itself has been updated once more since.
final_image = generated_image_160.detach().cpu().squeeze(0)
final_embedding = current_embedding.detach().cpu().squeeze(0)
print("Inversion Complete.")

In [None]:
# Render the final (last-stepped) latent code at full resolution.
with torch.no_grad():
    final_high_res = generator(latent_code)
    # Map [-1, 1] generator output to [0, 1] for display
    final_img = final_high_res.clamp(-1, 1).add(1).div(2.0)

    # Perceptual distance recorded at the last optimization step
    print(perceptual_list[-1])

    channels_last = final_img[0].permute(1, 2, 0)
    plt.imshow(channels_last.cpu().numpy())
    plt.axis('off')
    plt.show()

In [None]:
# Render the latent code that achieved the lowest perceptual distance.
with torch.no_grad():
    final_high_res = generator(best_latent_code)
    # Map [-1, 1] generator output to [0, 1] for display
    final_img = final_high_res.clamp(-1, 1).add(1).div(2.0)

    # Best (lowest) perceptual distance seen during optimization
    print(min_perc)

    channels_last = final_img[0].permute(1, 2, 0)
    plt.imshow(channels_last.cpu().numpy())
    plt.axis('off')
    plt.show()

In [None]:
display_loss_graph(loss_list)  # total (identity + regularization) loss per step

In [None]:
display_loss_graph(perceptual_list, log_scale=False)  # VGG perceptual distance to the real target per step

In [None]:
# Save the 160x160 image from the last forward pass of the attack loop.
save_and_display_image(image=final_image, filename="exp11_gmi.png")

In [None]:
# Cosine distance (1 - cosine similarity) between the target embedding and the
# embedding of the final generated image; smaller means more similar identity.
cosine_dist = cosine(u=emb_00001, v=final_embedding)
cosine_dist

In [None]:
# Show the real target face (00001.jpg) for comparison; as the cell's last
# expression, the notebook renders the PIL image inline.
img_00001