# Installations and Imports

In [None]:
!pip install facenet-pytorch lpips torchmetrics --force-reinstall --no-cache-dir

In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
import pickle
from facenet_pytorch import InceptionResnetV1
from torchvision import transforms
from torchvision import models
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import math
from tqdm import tqdm

from torchmetrics.image import StructuralSimilarityIndexMeasure, PeakSignalNoiseRatio
import lpips

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed the RNGs so the random initializations and jitter offsets are repeatable
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device

In [None]:
# Pretrained face-embedding network, used in eval mode as a fixed feature extractor
model = InceptionResnetV1(pretrained='vggface2').eval().to(device)
# Only input images are optimized in this notebook; freezing the weights stops
# backward() from computing and accumulating unused parameter gradients
model.requires_grad_(False)

# Resize to 160x160 and convert to a float tensor in [0, 1]
transform = transforms.Compose([
    transforms.Resize((160,160)),
    transforms.ToTensor()
])

# EDA

In [None]:
# load dictionary of {filename: embedding_vector}
with open("embeddings.pkl", "rb") as f:
    embeddings = pickle.load(f)

# access one entry
print(f"Number of samples: {len(embeddings)}")
print(f"File names: {list(embeddings.keys())[:5]}")
vec = embeddings["00001.jpg"]
emb_target = torch.tensor(vec).unsqueeze(0).to(device)  # shape [1,512]
print(f"embedding shape: {emb_target.shape}")

In [None]:
img_00001 = Image.open("00001.jpg").convert("RGB")
x_00001 = transform(img_00001).unsqueeze(0).to(device)
img_00002 = Image.open("00002.jpg").convert("RGB")
x_00002 = transform(img_00002).unsqueeze(0).to(device)

In [None]:
img_00001

In [None]:
img_00002

In [None]:
img_00001.size, x_00001.shape

In [None]:
emb_00001 = model(x_00001*2-1).detach().cpu()
emb_00002 = model(x_00002*2-1).detach().cpu()

In [None]:
random_data = np.random.randint(
    0, 256,
    size=(256, 256, 3),
    dtype=np.uint8
)
random_img = Image.fromarray(random_data, 'RGB')
random_img

In [None]:
random_image_tensor = transform(random_img).unsqueeze(0).to(device)

In [None]:
random_image_emb = model(random_image_tensor*2-1).detach().cpu()

In [None]:
random_image_emb.shape, emb_00001.shape

In [None]:
cos_sim = nn.functional.cosine_similarity(random_image_emb, emb_00001).item()
cos_sim

Because the random image has no relationship with the face embedding, the cosine similarity is close to 0.

In [None]:
cos_sim = nn.functional.cosine_similarity(emb_00001, emb_00002).item()
cos_sim

Because the two face embeddings come from intentionally different faces, their cosine similarity is closer to -1, which means the embeddings point in closer-to-opposite directions.

# General Functions

## Initialization Functions

In [None]:
def get_initial_image_tensor_uniform(size=160):
    """Return a trainable (1, 3, size, size) tensor sampled uniformly from [0, 1)."""
    shape = (1, 3, size, size)
    return torch.rand(*shape, device=device).requires_grad_(True)

In [None]:
def get_initial_image_tensor_gaussian(size=160):
    """Return a trainable (1, 3, size, size) tensor of standard-normal noise."""
    shape = (1, 3, size, size)
    return torch.randn(*shape, device=device).requires_grad_(True)

In [None]:
def get_initial_image_tensor_constant_and_gaussian(size=160, color_value=0.0):
    """Return a trainable (1, 3, size, size) tensor: a constant plus small Gaussian noise.

    Args:
        size (int): Height and width of the image.
        color_value (float): Constant every element starts from.
    """
    base = torch.full((1, 3, size, size), fill_value=color_value, device=device)
    # Small noise (std 0.01) breaks the symmetry of a perfectly flat image
    noise = torch.randn_like(base) * 0.01
    return (base + noise).requires_grad_(True)

## Visualization Functions

In [None]:
def display_grid_graphs(metrics_dict, n_cols=2, steps_log=None, log_scale_keys=None, figsize=None):
    """
    Plots multiple graphs in a grid, one subplot per metric.

    Args:
        metrics_dict (dict): Dictionary where Key is the Title and Value is the list of data.
        n_cols (int): Number of columns in the grid.
        steps_log (list): X-axis values shared by all metrics. If None, each metric is
            plotted against its own indices 0..len(values)-1.
        log_scale_keys (list): List of keys from metrics_dict that should be plotted in log scale.
        figsize (tuple): Optional custom size (width, height). If None, calculates automatically.
    """
    if not metrics_dict:
        # Nothing to draw; also avoids asking matplotlib for a zero-row grid
        return

    if log_scale_keys is None:
        log_scale_keys = []

    # Calculate Grid Dimensions
    n = len(metrics_dict)
    n_rows = math.ceil(n / n_cols)

    # Auto-calculate figure size if not provided
    if figsize is None:
        figsize = (4 * n_cols, 3 * n_rows)

    # squeeze=False always returns a 2-D array of Axes, so a single metric or a
    # single row needs no special handling
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    # Plot Data
    for ax, (label, values) in zip(axes, metrics_dict.items()):
        x_values = steps_log if steps_log is not None else range(len(values))
        ax.plot(x_values, values)

        ax.set_title(f"{label} per Step")
        ax.set_xlabel("Step")
        ax.set_ylabel(label)
        ax.grid(True, alpha=0.3)

        if label in log_scale_keys:
            ax.set_yscale('log')

    # Hide empty subplots (if n is not a perfect multiple of n_cols)
    for ax in axes[n:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
def save_and_display_image(image, filename):
    """Squash a raw image tensor with tanh, rescale it to [0, 1], then save and display it.

    Args:
        image (Tensor): Image tensor; it is detached, moved to CPU, and has dim 0 squeezed.
        filename (str): Path the resulting PIL image is saved to.
    """
    squashed = torch.tanh(image.detach().cpu().squeeze(0))
    # tanh output lies in [-1, 1]; shift and scale it into [0, 1] for ToPILImage
    pil_image = transforms.ToPILImage()((squashed * 0.5) + 0.5)
    pil_image.save(filename)
    display(pil_image)

## Evaluation Functions

In [None]:
# Kept on `device` like the other metrics so it can score tensors that live there
lpips_metric = lpips.LPIPS(net='vgg').to(device)
# < 0.25 high similarity
# > 0.7 different images


def get_lpips_dist(img_a, img_b):
    """Return the LPIPS distance between two image batches whose values are in [0, 1]."""
    # normalize=True makes LPIPS rescale [0, 1] inputs to the [-1, 1] range it works in
    return lpips_metric(img_a, img_b, normalize=True)


psnr_metric = PeakSignalNoiseRatio(data_range=1.0).to(device)
# > 30 dB: High quality (hard to distinguish difference).
# 20-30 dB: Acceptable quality.
# < 20 dB: Poor quality (very noisy).

ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
# 1.0: Identical images.
# > 0.9: Very structurally similar.
In [None]:
def evaluate_and_log(i, iterations, current_img, target_img, history, freq=20):
    """
    Appends LPIPS, PSNR and SSIM for the current step to the history lists.

    Metrics are computed only every `freq` steps and on the last iteration;
    every other call returns without touching `history`.

    Args:
        i (int): Current iteration.
        iterations (int): Total iterations.
        current_img (Tensor): The normalized image (output of tanh, [-1, 1]).
        target_img (Tensor): The target image ([0, 1]).
        history (tuple): (lpips_list, psnr_list, ssim_list, steps), mutated in-place.
        freq (int): Log frequency.
    """
    is_log_step = (i % freq == 0) or (i == iterations - 1)
    if not is_log_step:
        return

    lpips_list, psnr_list, ssim_list, steps = history

    with torch.no_grad():
        # Map [-1, 1] -> [0, 1]; clamping removes tiny float overshoots such as -0.0001 or 1.0001
        generated_01 = ((current_img * 0.5) + 0.5).clamp(0, 1)
        target_01 = target_img.clamp(0, 1)

        lpips_list.append(get_lpips_dist(generated_01, target_01).item())
        psnr_list.append(psnr_metric(generated_01, target_01).item())
        ssim_list.append(ssim_metric(generated_01, target_01).item())
        steps.append(i)

# Defining the target image

In [None]:
target_embedding = emb_00001.to(device)
target_image = x_00001

# Optimization Attempts

## Attempt 1: uniform initialization

In [None]:
experiment_name = "exp1_uniform_init"
image = get_initial_image_tensor_uniform()
iterations = 200
optimizer = optim.Adam([image], lr=0.01)
loss_fn = nn.MSELoss()

In [None]:
loss_list = []
cosine_similarity_list = []

lpips_list = []
psnr_list = []
ssim_list = []
steps = []
history_lists = (lpips_list, psnr_list, ssim_list, steps)

# Progress is printed ten times per run; max() avoids a zero modulus when iterations < 10
log_every = max(1, iterations // 10)

for i in tqdm(range(iterations)):
    optimizer.zero_grad()

    # The optimized tensor is unconstrained; tanh squashes it into [-1, 1]
    normalized_image = torch.tanh(image)
    current_embedding = model(normalized_image)
    loss = loss_fn(current_embedding, target_embedding)
    cos_sim = nn.functional.cosine_similarity(current_embedding, target_embedding).item()

    evaluate_and_log(i, iterations, normalized_image, target_image, history_lists, freq=20)

    # Record this step before the early-stopping check so the stopping step is plotted too
    loss_list.append(loss.item())
    cosine_similarity_list.append(cos_sim)

    if cos_sim > 0.95:
        print(f" Early stopping at iteration: {i}! Cosine Similarity: {cos_sim:.6f}")
        break

    loss.backward()
    optimizer.step()

    if i == 0 or (i + 1) % log_every == 0:
        print(f" Step [{i+1}/{iterations}], Loss: {loss.item():.6f}, cos sim: {cos_sim:.6f}")

# Keep the raw (pre-tanh) tensor: save_and_display_image applies tanh and the [0, 1]
# rescale itself, so passing an already-normalized image would transform it twice
final_image = image.detach().cpu().squeeze(0)
final_embedding = current_embedding.detach().cpu().squeeze(0)

In [None]:
display_grid_graphs({
    "Loss": loss_list,
    "Cosine Similarity": cosine_similarity_list
}, n_cols=3)

In [None]:
display_grid_graphs({
    "LPIPS": lpips_list,
    "PSNR": psnr_list,
    "SSIM": ssim_list
}, n_cols=3, steps_log=steps)

In [None]:

save_and_display_image(final_image, f"{experiment_name}.png")

In [None]:
# final_embedding is a squeezed CPU tensor while target_embedding lives on `device`;
# compare them on the CPU with matching [1, 512] shapes
final_cos_sim = nn.functional.cosine_similarity(
    final_embedding.unsqueeze(0), target_embedding.detach().cpu()
).item()
print(f"Final cosine similarity: {final_cos_sim}")

## Attempt 2: gaussian initialization

In [None]:
experiment_name = "exp2_gaussian_init"
image = get_initial_image_tensor_gaussian()
iterations = 200
optimizer = optim.Adam([image], lr=0.01)
loss_fn = nn.MSELoss()

In [None]:
loss_list = []
cosine_similarity_list = []

lpips_list = []
psnr_list = []
ssim_list = []
steps = []
history_lists = (lpips_list, psnr_list, ssim_list, steps)

# Progress is printed ten times per run; max() avoids a zero modulus when iterations < 10
log_every = max(1, iterations // 10)

for i in tqdm(range(iterations)):
    optimizer.zero_grad()

    # The optimized tensor is unconstrained; tanh squashes it into [-1, 1]
    normalized_image = torch.tanh(image)
    current_embedding = model(normalized_image)
    loss = loss_fn(current_embedding, target_embedding)
    cos_sim = nn.functional.cosine_similarity(current_embedding, target_embedding).item()

    evaluate_and_log(i, iterations, normalized_image, target_image, history_lists, freq=20)

    # Record this step before the early-stopping check so the stopping step is plotted too
    loss_list.append(loss.item())
    cosine_similarity_list.append(cos_sim)

    if cos_sim > 0.95:
        print(f" Early stopping at iteration: {i}! Cosine Similarity: {cos_sim:.6f}")
        break

    loss.backward()
    optimizer.step()

    if i == 0 or (i + 1) % log_every == 0:
        print(f" Step [{i+1}/{iterations}], Loss: {loss.item():.6f}, cos sim: {cos_sim:.6f}")

# Raw (pre-tanh) tensor; save_and_display_image applies tanh and the [0, 1] rescale
final_image = image.detach().cpu().squeeze(0)
final_embedding = current_embedding.detach().cpu().squeeze(0)

In [None]:
display_grid_graphs({
    "Loss": loss_list,
    "Cosine Similarity": cosine_similarity_list
}, n_cols=3)

In [None]:
display_grid_graphs({
    "LPIPS": lpips_list,
    "PSNR": psnr_list,
    "SSIM": ssim_list
}, n_cols=3, steps_log=steps)

In [None]:
save_and_display_image(final_image, f"{experiment_name}.png")

## Attempt 3: constant initialization and a bit of gaussian noise

In [None]:
experiment_name = "exp3_small_gaussian_init"
image = get_initial_image_tensor_constant_and_gaussian()
iterations = 200
optimizer = optim.Adam([image], lr=0.01)
loss_fn = nn.MSELoss()

In [None]:
loss_list = []
cosine_similarity_list = []

lpips_list = []
psnr_list = []
ssim_list = []
steps = []
history_lists = (lpips_list, psnr_list, ssim_list, steps)

# Progress is printed ten times per run; max() avoids a zero modulus when iterations < 10
log_every = max(1, iterations // 10)

for i in tqdm(range(iterations)):
    optimizer.zero_grad()

    # The optimized tensor is unconstrained; tanh squashes it into [-1, 1]
    normalized_image = torch.tanh(image)
    current_embedding = model(normalized_image)
    loss = loss_fn(current_embedding, target_embedding)
    cos_sim = nn.functional.cosine_similarity(current_embedding, target_embedding).item()

    evaluate_and_log(i, iterations, normalized_image, target_image, history_lists, freq=20)

    # Record this step before the early-stopping check so the stopping step is plotted too
    loss_list.append(loss.item())
    cosine_similarity_list.append(cos_sim)

    if cos_sim > 0.95:
        print(f" Early stopping at iteration: {i}! Cosine Similarity: {cos_sim:.6f}")
        break

    loss.backward()
    optimizer.step()

    if i == 0 or (i + 1) % log_every == 0:
        print(f" Step [{i+1}/{iterations}], Loss: {loss.item():.6f}, cos sim: {cos_sim:.6f}")

# Raw (pre-tanh) tensor; save_and_display_image applies tanh and the [0, 1] rescale
final_image = image.detach().cpu().squeeze(0)
final_embedding = current_embedding.detach().cpu().squeeze(0)

In [None]:
display_grid_graphs({
    "Loss": loss_list,
    "Cosine Similarity": cosine_similarity_list
}, n_cols=3)

In [None]:
display_grid_graphs({
    "LPIPS": lpips_list,
    "PSNR": psnr_list,
    "SSIM": ssim_list
}, n_cols=3, steps_log=steps)

In [None]:
save_and_display_image(final_image, f"{experiment_name}.png")

## Attempt 4: adding Total Variation to the loss function

In [None]:
def get_tv_loss(img_tensor):
    """
    Computes the anisotropic Total Variation of an image batch.

    Sums the absolute differences between every pair of horizontally adjacent
    and vertically adjacent pixels.
    Expected input shape: (Batch, Channels, Height, Width)
    """
    # Neighbours along the width axis (adjacent columns)
    column_steps = img_tensor[:, :, :, 1:] - img_tensor[:, :, :, :-1]
    # Neighbours along the height axis (adjacent rows)
    row_steps = img_tensor[:, :, 1:, :] - img_tensor[:, :, :-1, :]

    return column_steps.abs().sum() + row_steps.abs().sum()

In [None]:
experiment_name = "exp4_small_gausian_init_and_tv_weight_4e-07"
image = get_initial_image_tensor_constant_and_gaussian()
iterations = 1000
tv_weight = 4e-7
optimizer = optim.Adam([image], lr=0.01)
loss_fn = nn.MSELoss()

In [None]:
loss_list = []
cosine_similarity_list = []

lpips_list = []
psnr_list = []
ssim_list = []
steps = []
history_lists = (lpips_list, psnr_list, ssim_list, steps)

# Progress is printed ten times per run; max() avoids a zero modulus when iterations < 10
log_every = max(1, iterations // 10)

for i in tqdm(range(iterations)):
    optimizer.zero_grad()

    # The optimized tensor is unconstrained; tanh squashes it into [-1, 1]
    normalized_image = torch.tanh(image)
    current_embedding = model(normalized_image)

    loss_mse = loss_fn(current_embedding, target_embedding)
    # Total variation is measured on the raw (pre-tanh) tensor
    loss_tv = get_tv_loss(image)
    loss = loss_mse + (tv_weight * loss_tv)
    cos_sim = nn.functional.cosine_similarity(current_embedding, target_embedding).item()

    evaluate_and_log(i, iterations, normalized_image, target_image, history_lists, freq=20)

    # Record this step before the early-stopping check so the stopping step is plotted too
    loss_list.append(loss.item())
    cosine_similarity_list.append(cos_sim)

    if cos_sim > 0.95:
        print(f" Early stopping at iteration: {i}! Cosine Similarity: {cos_sim:.6f}")
        break

    loss.backward()
    optimizer.step()

    if i == 0 or (i + 1) % log_every == 0:
        print(f" Step [{i+1}/{iterations}], Loss: {loss.item():.6f}, cos sim: {cos_sim:.6f}")

# Raw (pre-tanh) tensor; save_and_display_image applies tanh and the [0, 1] rescale
final_image = image.detach().cpu().squeeze(0)
final_embedding = current_embedding.detach().cpu().squeeze(0)

In [None]:
display_grid_graphs({
    "Loss": loss_list,
    "Cosine Similarity": cosine_similarity_list
}, n_cols=3)

In [None]:
display_grid_graphs({
    "LPIPS": lpips_list,
    "PSNR": psnr_list,
    "SSIM": ssim_list
}, n_cols=3, steps_log=steps)

In [None]:
save_and_display_image(final_image, f"{experiment_name}.png")

## Attempt 5: gaussian initialization with tv loss

In [None]:
experiment_name = "exp5_gaussian_init_and_tv_weight_1e-06"
image = get_initial_image_tensor_gaussian()
iterations = 1000
tv_weight = 1e-6
optimizer = optim.Adam([image], lr=0.01)
loss_fn = nn.MSELoss()

In [None]:
loss_list = []
cosine_similarity_list = []

lpips_list = []
psnr_list = []
ssim_list = []
steps = []
history_lists = (lpips_list, psnr_list, ssim_list, steps)

# Progress is printed ten times per run; max() avoids a zero modulus when iterations < 10
log_every = max(1, iterations // 10)

for i in tqdm(range(iterations)):
    optimizer.zero_grad()

    # The optimized tensor is unconstrained; tanh squashes it into [-1, 1]
    normalized_image = torch.tanh(image)
    current_embedding = model(normalized_image)

    loss_mse = loss_fn(current_embedding, target_embedding)
    # Total variation is measured on the raw (pre-tanh) tensor
    loss_tv = get_tv_loss(image)
    loss = loss_mse + (tv_weight * loss_tv)
    cos_sim = nn.functional.cosine_similarity(current_embedding, target_embedding).item()

    evaluate_and_log(i, iterations, normalized_image, target_image, history_lists, freq=20)

    # Record this step before the early-stopping check so the stopping step is plotted too
    loss_list.append(loss.item())
    cosine_similarity_list.append(cos_sim)

    if cos_sim > 0.95:
        print(f" Early stopping at iteration: {i}! Cosine Similarity: {cos_sim:.6f}")
        break

    loss.backward()
    optimizer.step()

    if i == 0 or (i + 1) % log_every == 0:
        print(f" Step [{i+1}/{iterations}], Loss: {loss.item():.6f}, cos sim: {cos_sim:.6f}")

# Raw (pre-tanh) tensor; save_and_display_image applies tanh and the [0, 1] rescale
final_image = image.detach().cpu().squeeze(0)
final_embedding = current_embedding.detach().cpu().squeeze(0)

In [None]:
display_grid_graphs({
    "Loss": loss_list,
    "Cosine Similarity": cosine_similarity_list
}, n_cols=3)

In [None]:
display_grid_graphs({
    "LPIPS": lpips_list,
    "PSNR": psnr_list,
    "SSIM": ssim_list
}, n_cols=3, steps_log=steps)

In [None]:
save_and_display_image(final_image, f"{experiment_name}.png")

## Attempt 6: man face initialization

In [None]:
image_pil = Image.open("man_face.png").convert("RGB")
face_01 = transform(image_pil).unsqueeze(0).to(device)
# The loop feeds tanh(image) to the model, so the optimized variable lives in pre-tanh
# space. Map the [0, 1] pixels to [-1, 1] and invert tanh so the starting point really is
# the man face; clamping keeps atanh finite for pure black / pure white pixels.
image = torch.atanh((face_01 * 2 - 1).clamp(-1 + 1e-3, 1 - 1e-3)).requires_grad_(True)

In [None]:
image_pil

In [None]:
experiment_name = "exp6_man_face_init_and_tv_weight_1e-07_a"
iterations = 1000
tv_weight = 1e-7
optimizer = optim.Adam([image], lr=0.01)
loss_fn = nn.MSELoss()

In [None]:
loss_list = []
cosine_similarity_list = []

lpips_list = []
psnr_list = []
ssim_list = []
steps = []
history_lists = (lpips_list, psnr_list, ssim_list, steps)

# Progress is printed ten times per run; max() avoids a zero modulus when iterations < 10
log_every = max(1, iterations // 10)

for i in tqdm(range(iterations)):
    optimizer.zero_grad()

    # The optimized tensor is unconstrained; tanh squashes it into [-1, 1]
    normalized_image = torch.tanh(image)
    current_embedding = model(normalized_image)

    loss_mse = loss_fn(current_embedding, target_embedding)
    # Total variation is measured on the raw (pre-tanh) tensor
    loss_tv = get_tv_loss(image)
    loss = loss_mse + (tv_weight * loss_tv)
    cos_sim = nn.functional.cosine_similarity(current_embedding, target_embedding).item()

    evaluate_and_log(i, iterations, normalized_image, target_image, history_lists, freq=20)

    # Record this step before the early-stopping check so the stopping step is plotted too
    loss_list.append(loss.item())
    cosine_similarity_list.append(cos_sim)

    if cos_sim > 0.95:
        print(f" Early stopping at iteration: {i}! Cosine Similarity: {cos_sim:.6f}")
        break

    loss.backward()
    optimizer.step()

    if i == 0 or (i + 1) % log_every == 0:
        print(f" Step [{i+1}/{iterations}], Loss: {loss.item():.6f}, cos sim: {cos_sim:.6f}")

# Raw (pre-tanh) tensor; save_and_display_image applies tanh and the [0, 1] rescale
final_image = image.detach().cpu().squeeze(0)
final_embedding = current_embedding.detach().cpu().squeeze(0)

In [None]:
display_grid_graphs({
    "Loss": loss_list,
    "Cosine Similarity": cosine_similarity_list
}, n_cols=3)

In [None]:
display_grid_graphs({
    "LPIPS": lpips_list,
    "PSNR": psnr_list,
    "SSIM": ssim_list
}, n_cols=3, steps_log=steps)

In [None]:
save_and_display_image(final_image, f"{experiment_name}.png")

## Attempt 7: adding image jittering

In [None]:
class RandomJitter(nn.Module):
    """Randomly translates an image batch by up to `lim` pixels in each direction.

    The input is reflect-padded by `lim` pixels on every side and a window of the
    original size is cropped at a random offset, so the output shape equals the
    input shape.

    Args:
        lim (int): Maximum shift in pixels along each axis. Values <= 0 disable jitter.
    """

    def __init__(self, lim=30):
        super().__init__()
        self.lim = lim

    def forward(self, img_tensor):
        """Return a randomly shifted view of `img_tensor`, shape (B, C, H, W) preserved."""
        if self.lim <= 0:
            # No padding means there is nothing to shift within
            return img_tensor

        _, _, height, width = img_tensor.shape
        padded = nn.functional.pad(img_tensor, (self.lim, self.lim, self.lim, self.lim), mode='reflect')

        # Valid crop offsets are 0..2*lim inclusive; randint's upper bound is exclusive,
        # hence the +1 (offset == lim reproduces the unshifted image)
        sx = torch.randint(0, 2 * self.lim + 1, (1,)).item()
        sy = torch.randint(0, 2 * self.lim + 1, (1,)).item()

        return padded[:, :, sy:sy + height, sx:sx + width]

In [None]:
# Name kept in sync with tv_weight below (it previously said 1e-07 while the weight was 4e-7)
experiment_name = "exp7_tv_weight_4e-07_and_jittering"
image = get_initial_image_tensor_constant_and_gaussian()
iterations = 1000
tv_weight = 4e-7
jitter = RandomJitter(lim=30)
optimizer = optim.Adam([image], lr=0.01)
loss_fn = nn.MSELoss()

In [None]:
loss_list = []
cosine_similarity_list = []

lpips_list = []
psnr_list = []
ssim_list = []
steps = []
history_lists = (lpips_list, psnr_list, ssim_list, steps)

# Progress is printed ten times per run; max() avoids a zero modulus when iterations < 10
log_every = max(1, iterations // 10)

for i in tqdm(range(iterations)):
    optimizer.zero_grad()

    # The optimized tensor is unconstrained; tanh squashes it into [-1, 1]
    normalized_image = torch.tanh(image)
    # The embedding is computed on a randomly shifted copy of the image
    jittered_image = jitter(normalized_image)
    current_embedding = model(jittered_image)

    loss_mse = loss_fn(current_embedding, target_embedding)
    # Total variation is measured on the raw (pre-tanh) tensor
    loss_tv = get_tv_loss(image)
    loss = loss_mse + (tv_weight * loss_tv)
    cos_sim = nn.functional.cosine_similarity(current_embedding, target_embedding).item()

    # Image-quality metrics use the un-jittered image
    evaluate_and_log(i, iterations, normalized_image, target_image, history_lists, freq=20)

    # Record this step before the early-stopping check so the stopping step is plotted too
    loss_list.append(loss.item())
    cosine_similarity_list.append(cos_sim)

    if cos_sim > 0.95:
        print(f" Early stopping at iteration: {i}! Cosine Similarity: {cos_sim:.6f}")
        break

    loss.backward()
    optimizer.step()

    if i == 0 or (i + 1) % log_every == 0:
        print(f" Step [{i+1}/{iterations}], Loss: {loss.item():.6f}, cos sim: {cos_sim:.6f}")

# Raw (pre-tanh) tensor; save_and_display_image applies tanh and the [0, 1] rescale
final_image = image.detach().cpu().squeeze(0)
final_embedding = current_embedding.detach().cpu().squeeze(0)

In [None]:
display_grid_graphs({
    "Loss": loss_list,
    "Cosine Similarity": cosine_similarity_list
}, n_cols=3)

In [None]:
display_grid_graphs({
    "LPIPS": lpips_list,
    "PSNR": psnr_list,
    "SSIM": ssim_list
}, n_cols=3, steps_log=steps)

In [None]:
save_and_display_image(final_image, f"{experiment_name}.png")

## Attempt 8: Decaying tv weight

In [None]:
experiment_name = "exp8_decaying_tv_weight_1e-06_to_0_and_jittering"
image = get_initial_image_tensor_constant_and_gaussian()
iterations = 1000
# Step-wise decaying TV weight: four equal stages derived from `iterations`, so the
# schedule always has exactly one entry per step (250 steps each for 1000 iterations)
tv_weight_stages = [1e-6, 1e-7, 1e-8, 0]
stage_len = math.ceil(iterations / len(tv_weight_stages))
tv_weight_per_step = [w for w in tv_weight_stages for _ in range(stage_len)][:iterations]
jitter = RandomJitter(lim=30)
optimizer = optim.Adam([image], lr=0.01)
loss_fn = nn.MSELoss()

In [None]:
loss_list = []
cosine_similarity_list = []

lpips_list = []
psnr_list = []
ssim_list = []
steps = []
history_lists = (lpips_list, psnr_list, ssim_list, steps)

# Progress is printed ten times per run; max() avoids a zero modulus when iterations < 10
log_every = max(1, iterations // 10)

for i in tqdm(range(iterations)):
    optimizer.zero_grad()

    # The optimized tensor is unconstrained; tanh squashes it into [-1, 1]
    normalized_image = torch.tanh(image)
    # The embedding is computed on a randomly shifted copy of the image
    jittered_image = jitter(normalized_image)
    current_embedding = model(jittered_image)

    loss_mse = loss_fn(current_embedding, target_embedding)
    # Total variation is measured on the raw (pre-tanh) tensor
    loss_tv = get_tv_loss(image)
    # Reuse the last scheduled weight if the run is longer than the schedule
    tv_weight = tv_weight_per_step[min(i, len(tv_weight_per_step) - 1)]
    loss = loss_mse + (tv_weight * loss_tv)
    cos_sim = nn.functional.cosine_similarity(current_embedding, target_embedding).item()

    # Image-quality metrics use the un-jittered image
    evaluate_and_log(i, iterations, normalized_image, target_image, history_lists, freq=20)

    # Record this step before the early-stopping check so the stopping step is plotted too
    loss_list.append(loss.item())
    cosine_similarity_list.append(cos_sim)

    if cos_sim > 0.95:
        print(f" Early stopping at iteration: {i}! Cosine Similarity: {cos_sim:.6f}")
        break

    loss.backward()
    optimizer.step()

    if i == 0 or (i + 1) % log_every == 0:
        print(f" Step [{i+1}/{iterations}], Loss: {loss.item():.6f}, cos sim: {cos_sim:.6f}")

# Raw (pre-tanh) tensor; save_and_display_image applies tanh and the [0, 1] rescale
final_image = image.detach().cpu().squeeze(0)
final_embedding = current_embedding.detach().cpu().squeeze(0)

In [None]:
display_grid_graphs({
    "Loss": loss_list,
    "Cosine Similarity": cosine_similarity_list
}, n_cols=3)

In [None]:
display_grid_graphs({
    "LPIPS": lpips_list,
    "PSNR": psnr_list,
    "SSIM": ssim_list
}, n_cols=3, steps_log=steps)

In [None]:
save_and_display_image(final_image, f"{experiment_name}.png")

## Attempt 9: adding perceptual loss

In [None]:
class VGGPerceptualDist(nn.Module):
    """Perceptual distance: MSE between VGG16 feature maps of two images.

    Args:
        resize (bool): Stored on the instance; `forward` does not currently use it.
    """

    def __init__(self, resize=True):
        super().__init__()

        # Load VGG16
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features

        # Slicing up to layer 16 (ReLU3_3) is standard.
        self.blocks = nn.Sequential(*list(vgg.children())[:16]).eval()

        # Freeze the model weights
        for param in self.blocks.parameters():
            param.requires_grad = False

        # VGG specific normalization. Registered as non-persistent buffers so they follow
        # the module through .to(device) without adding keys to the state_dict.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False)
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False)
        self.resize = resize

    def forward(self, generated_img, target_img):
        """Return the feature-space MSE between two image batches.

        Both inputs are assumed to be in the [0, 1] range, shape (B, 3, H, W).
        """
        gen_norm = (generated_img - self.mean) / self.std
        target_norm = (target_img - self.mean) / self.std

        # Extract features
        gen_features = self.blocks(gen_norm)
        target_features = self.blocks(target_norm)

        # Calculate L2 loss between the feature maps
        return torch.nn.functional.mse_loss(gen_features, target_features)

get_perc_dist = VGGPerceptualDist().to(device)

In [None]:
man_face_pil = Image.open("man_face.png").convert("RGB")
# Transform the image loaded on the line above rather than `image_pil` left over from an earlier cell
man_face = transform(man_face_pil).unsqueeze(0).to(device)

In [None]:
experiment_name = "exp9_tv_weight_4e-07_and_jittering_and_perceptual_with_man_face"
image = get_initial_image_tensor_constant_and_gaussian()
iterations = 1000
tv_weight = 4e-7
jitter = RandomJitter(lim=30)
perceptual_target_image = man_face
perceptual_weight = 1e-3
optimizer = optim.Adam([image], lr=0.01)
loss_fn = nn.MSELoss()

In [None]:
loss_list = []
cosine_similarity_list = []

lpips_list = []
psnr_list = []
ssim_list = []
steps = []
history_lists = (lpips_list, psnr_list, ssim_list, steps)

# Progress is printed ten times per run; max() avoids a zero modulus when iterations < 10
log_every = max(1, iterations // 10)

for i in tqdm(range(iterations)):
    optimizer.zero_grad()

    # The optimized tensor is unconstrained; tanh squashes it into [-1, 1]
    normalized_image = torch.tanh(image)
    # The embedding is computed on a randomly shifted copy of the image
    jittered_image = jitter(normalized_image)
    current_embedding = model(jittered_image)

    loss_mse = loss_fn(current_embedding, target_embedding)
    # Total variation is measured on the raw (pre-tanh) tensor
    loss_tv = get_tv_loss(image)
    # The perceptual distance assumes [0, 1] images, so give it the rescaled tanh output
    # instead of the raw pre-tanh tensor
    loss_perceptual = get_perc_dist((normalized_image * 0.5) + 0.5, perceptual_target_image)
    loss = loss_mse + (tv_weight * loss_tv) + (perceptual_weight * loss_perceptual)
    cos_sim = nn.functional.cosine_similarity(current_embedding, target_embedding).item()

    # Image-quality metrics use the un-jittered image
    evaluate_and_log(i, iterations, normalized_image, target_image, history_lists, freq=20)

    # Record this step before the early-stopping check so the stopping step is plotted too
    loss_list.append(loss.item())
    cosine_similarity_list.append(cos_sim)

    if cos_sim > 0.95:
        print(f" Early stopping at iteration: {i}! Cosine Similarity: {cos_sim:.6f}")
        break

    loss.backward()
    optimizer.step()

    if i == 0 or (i + 1) % log_every == 0:
        print(f" Step [{i+1}/{iterations}], Loss: {loss.item():.6f}, cos sim: {cos_sim:.6f}")

# Raw (pre-tanh) tensor; save_and_display_image applies tanh and the [0, 1] rescale
final_image = image.detach().cpu().squeeze(0)
final_embedding = current_embedding.detach().cpu().squeeze(0)

In [None]:
display_grid_graphs({
    "Loss": loss_list,
    "Cosine Similarity": cosine_similarity_list
}, n_cols=3)

In [None]:
display_grid_graphs({
    "LPIPS": lpips_list,
    "PSNR": psnr_list,
    "SSIM": ssim_list
}, n_cols=3, steps_log=steps)

In [None]:
save_and_display_image(final_image, f"{experiment_name}.png")

## Attempt 10: adding perceptual style loss

In [None]:
class VGGStyleLoss(nn.Module):
    """Style distance: summed MSE between Gram matrices of VGG19 features at four depths."""

    def __init__(self):
        super().__init__()
        # Load VGG
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
        layers = list(vgg.children())

        # List of sub-models ending at different depths (each starts from the input)
        self.slice1 = nn.Sequential(*layers[:2])   # Low level (colors)
        self.slice2 = nn.Sequential(*layers[:7])   # Textures
        self.slice3 = nn.Sequential(*layers[:12])  # Shapes
        self.slice4 = nn.Sequential(*layers[:21])  # Deep features

        # Fixed feature extractor: inference mode, frozen weights
        self.eval()
        for p in self.parameters():
            p.requires_grad = False

        # VGG normalization constants. Registered as non-persistent buffers so they follow
        # the module through .to(device) without adding keys to the state_dict.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False)
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False)

    @staticmethod
    def gram_matrix(input_tensor):
        """Return the normalized Gram matrix of a (B, C, H, W) feature map.

        The map is flattened to (B*C, H*W), so for B > 1 the result also contains
        cross-sample products.
        """
        batch, channels, height, width = input_tensor.size()

        # reshape (unlike view) also accepts non-contiguous feature maps
        features = input_tensor.reshape(batch * channels, height * width)

        # Calculate the dot product
        G = torch.mm(features, features.t())

        # Normalize by the number of elements to keep values small
        return G.div(batch * channels * height * width)

    def forward(self, generated_img, guide_img):
        """Return the style loss between two image batches (normalized with ImageNet mean/std)."""
        gen = (generated_img - self.mean) / self.std
        guide = (guide_img - self.mean) / self.std

        loss = 0
        # Pass through each slice
        for slice_net in [self.slice1, self.slice2, self.slice3, self.slice4]:
            gen_feat = slice_net(gen)
            guide_feat = slice_net(guide)

            # Compare Gram Matrices
            gen_gram = VGGStyleLoss.gram_matrix(gen_feat)
            guide_gram = VGGStyleLoss.gram_matrix(guide_feat)

            loss += torch.nn.functional.mse_loss(gen_gram, guide_gram)

        return loss

In [None]:
man_face_pil = Image.open("man_face.png").convert("RGB")
# Transform the image loaded on the line above rather than `image_pil` left over from an earlier cell
man_face = transform(man_face_pil).unsqueeze(0).to(device)

In [None]:
img_00002

In [None]:
experiment_name = "exp10_tv_weight_4e-07_and_jittering_and_style_with_man_face"
# Optimize a detached copy of image 00002 so the optimizer does not overwrite the
# x_00002 tensor from the EDA section in place.
# NOTE(review): the loop treats `image` as a pre-tanh tensor, so this start is not exactly
# img_00002 after tanh; consider atanh(2x - 1) if an exact start is wanted.
image = x_00002.detach().clone().requires_grad_(True)
iterations = 1000
tv_weight = 4e-7
# NOTE(review): `jitter` and `perceptual_criterion` are created here but the loop below does not use them
jitter = RandomJitter(lim=30)
# The class defined above is VGGPerceptualDist (VGGPerceptualLoss does not exist in this notebook)
perceptual_criterion = VGGPerceptualDist().to(device)
style_criterion = VGGStyleLoss().to(device)
perceptual_target_image = man_face
perceptual_weight = 1e2
optimizer = optim.Adam([image], lr=0.01)
loss_fn = nn.MSELoss()

In [None]:
loss_list = []
cosine_similarity_list = []

lpips_list = []
psnr_list = []
ssim_list = []
steps = []
history_lists = (lpips_list, psnr_list, ssim_list, steps)

# Progress is printed ten times per run; max() avoids a zero modulus when iterations < 10
log_every = max(1, iterations // 10)

for i in tqdm(range(iterations)):
    optimizer.zero_grad()

    # The optimized tensor is unconstrained; tanh squashes it into [-1, 1]
    normalized_image = torch.tanh(image)
    current_embedding = model(normalized_image)

    loss_mse = loss_fn(current_embedding, target_embedding)
    # Total variation is measured on the raw (pre-tanh) tensor
    loss_tv = get_tv_loss(image)
    # Score the image in the same [0, 1] range as the guide image (which comes from
    # ToTensor), not the raw pre-tanh tensor
    loss_style = style_criterion((normalized_image * 0.5) + 0.5, perceptual_target_image)
    loss = loss_mse + (tv_weight * loss_tv) + (perceptual_weight * loss_style)
    cos_sim = nn.functional.cosine_similarity(current_embedding, target_embedding).item()

    evaluate_and_log(i, iterations, normalized_image, target_image, history_lists, freq=20)

    # Record this step before the early-stopping check so the stopping step is plotted too
    loss_list.append(loss.item())
    cosine_similarity_list.append(cos_sim)

    if cos_sim > 0.95:
        print(f" Early stopping at iteration: {i}! Cosine Similarity: {cos_sim:.6f}")
        break

    loss.backward()
    optimizer.step()

    if i == 0 or (i + 1) % log_every == 0:
        print(f" Step [{i+1}/{iterations}], Loss: {loss.item():.6f}, cos sim: {cos_sim:.6f}")

# Raw (pre-tanh) tensor; save_and_display_image applies tanh and the [0, 1] rescale
final_image = image.detach().cpu().squeeze(0)
final_embedding = current_embedding.detach().cpu().squeeze(0)

In [None]:
display_grid_graphs({
    "Loss": loss_list,
    "Cosine Similarity": cosine_similarity_list
}, n_cols=3)

In [None]:
display_grid_graphs({
    "LPIPS": lpips_list,
    "PSNR": psnr_list,
    "SSIM": ssim_list
}, n_cols=3, steps_log=steps)

In [None]:
save_and_display_image(final_image, f"{experiment_name}.png")