# Imports

In [None]:
!pip install facenet-pytorch --force-reinstall --no-cache-dir

In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
import pickle
from facenet_pytorch import InceptionResnetV1
from torchvision import transforms
from torchvision import models
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial.distance import cosine
from tqdm import tqdm

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

In [None]:
# load dictionary of {filename: embedding_vector}
# NOTE(security): pickle.load can execute arbitrary code while unpickling, so
# only open an "embeddings.pkl" that comes from a trusted source.
with open("embeddings.pkl", "rb") as f:
    embeddings = pickle.load(f)

# EDA

In [None]:
# Quick look at the loaded dictionary, then pull out one entry
print(len(embeddings))              # how many samples were loaded
print(list(embeddings)[:5])         # first filenames (iterating a dict yields its keys)
vec = embeddings["00001.jpg"]       # numpy array shape (512,)

In [None]:
# convert to torch tensor when used
# Move to the detected `device` instead of calling .cuda() directly, so this
# cell also runs on machines without a GPU.
emb_target = torch.tensor(vec).unsqueeze(0).to(device)  # shape [1,512]
print(emb_target.shape)

In [None]:
# Pretrained face-embedding network, used only as a fixed feature extractor.
# requires_grad_(False) freezes its weights: the experiments below optimise
# the input image only, so weight gradients would just cost time and memory
# (and would keep accumulating, because only the image's optimizer is zeroed).
model = InceptionResnetV1(pretrained='vggface2').eval().to(device).requires_grad_(False)

# Preprocessing for PIL images: resize to 160x160, then convert to a tensor.
transform = transforms.Compose([
    transforms.Resize((160, 160)),
    transforms.ToTensor(),
])

In [None]:
# Open both reference photos as RGB, then turn each into a one-image batch
# on the active device.
img_00001, img_00002 = (
    Image.open(path).convert("RGB") for path in ("00001.jpg", "00002.jpg")
)
x_00001, x_00002 = (
    transform(photo).unsqueeze(0).to(device) for photo in (img_00001, img_00002)
)

In [None]:
# Pure inference: no_grad avoids recording an autograd graph for these passes.
# The transformed images are rescaled with *2-1 before being fed to the model.
with torch.no_grad():
    emb_00001 = model(x_00001 * 2 - 1).cpu().numpy()[0]
    emb_00002 = model(x_00002 * 2 - 1).cpu().numpy()[0]

In [None]:
img_00001

In [None]:
img_00002

In [None]:
img_00001.size, x_00001.shape

In [None]:
# 256x256 RGB image whose every channel value is drawn uniformly from 0..255
random_data = np.random.randint(low=0, high=256, size=(256, 256, 3), dtype=np.uint8)
random_img = Image.fromarray(random_data, mode='RGB')

In [None]:
random_img

In [None]:
random_image_tensor = transform(random_img).unsqueeze(0).to(device)

In [None]:
# Pure inference: skip autograd bookkeeping for this forward pass.
with torch.no_grad():
    random_image_emb = model(random_image_tensor * 2 - 1).cpu().numpy()[0]

In [None]:
random_image_emb.shape, emb_00001.shape

In [None]:
cosine_dist = cosine(random_image_emb, emb_00001)
cosine_dist

Because the random-noise image is unrelated to the face, the cosine distance between their embeddings is close to 1 (the two embeddings are nearly orthogonal).

In [None]:
cosine_dist = cosine(emb_00001, emb_00002)
cosine_dist

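Because the two faces are intentionally different, the cosine distance between their embeddings is closer to 2, i.e. the embeddings point in nearly opposite directions (cosine distance ranges from 0 for identical directions to 2 for opposite ones).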

# General functions

In [None]:
def get_initial_image_tensor_uniform(size=160):
    """Return a trainable (1, 3, size, size) tensor of uniform [0, 1) noise.

    The tensor is created on the notebook-level `device` and has
    requires_grad enabled so an optimizer can update it directly.
    """
    shape = (1, 3, size, size)
    return torch.rand(shape, device=device).requires_grad_(True)

In [None]:
def get_initial_image_tensor_gaussian(size=160):
    """Return a trainable (1, 3, size, size) tensor of standard-normal noise.

    The tensor is created on the notebook-level `device` and has
    requires_grad enabled so an optimizer can update it directly.
    """
    shape = (1, 3, size, size)
    return torch.randn(shape, device=device).requires_grad_(True)

In [None]:
def get_initial_image_tensor_constant_and_gaussian(size=160, color_value=0.0, noise_std=0.01):
    """Return a trainable constant-colour image with a little Gaussian noise.

    Args:
        size: Height and width of the square image.
        color_value: Value every pixel starts from. Integers are accepted.
        noise_std: Standard deviation of the Gaussian noise added on top.

    Returns:
        A (1, 3, size, size) floating-point tensor on the notebook-level
        `device`, with requires_grad enabled.
    """
    # Pin a floating dtype: torch.full infers the dtype from fill_value, so an
    # integer colour (e.g. 1) would give an integer tensor that randn_like
    # cannot sample for and that cannot require gradients.
    tensor = torch.full(
        (1, 3, size, size),
        fill_value=color_value,
        dtype=torch.get_default_dtype(),
        device=device,
    )
    # Add a tiny bit of noise to break symmetry (helps gradients start moving)
    tensor = tensor + (torch.randn_like(tensor) * noise_std)
    tensor.requires_grad_(True)
    return tensor

In [None]:
def display_loss_graph(loss_list, log_scale=False):
    """Plot the per-step loss values and show the figure.

    Args:
        loss_list: Sequence of loss values, one per optimisation step.
        log_scale: When true, draw the y axis on a logarithmic scale.
    """
    plt.title("Loss per Step")
    plt.xlabel("Step")
    plt.ylabel("Loss")
    plt.plot(loss_list)
    if log_scale:
        plt.yscale('log')
    plt.show()

In [None]:
def save_and_display_image(image, filename):
    """Render an optimised image tensor, save it to `filename` and show it inline.

    The optimisation loops in this notebook feed the embedding model
    ``tanh(image * 2 - 1)``. The same squashing is applied here (previously
    plain ``tanh(image)`` was used), so the saved picture is exactly the one
    the model saw, mapped from [-1, 1] back to [0, 1].

    Args:
        image: Raw (pre-tanh) optimised tensor; a leading batch dimension of
            size 1 is squeezed away.
        filename: Path the rendered image is written to.
    """
    model_input = torch.tanh(image.detach().cpu().squeeze(0) * 2 - 1)
    pixels = (model_input * 0.5) + 0.5

    rendered = transforms.ToPILImage()(pixels)
    rendered.save(filename)
    display(rendered)

In [None]:
target_embedding = torch.tensor(emb_00001, dtype=torch.float32).unsqueeze(0).to(device)

### Attempt 1: uniform initialization

In [None]:
image = get_initial_image_tensor_uniform()
num_steps = 200
optimizer = optim.Adam([image], lr=0.01)
loss_fn = nn.MSELoss()

In [None]:
loss_list = []
for i in tqdm(range(num_steps)):
    optimizer.zero_grad()

    normalized_image = torch.tanh(image * 2 - 1)
    current_embedding = model(normalized_image)
    loss = loss_fn(current_embedding, target_embedding)

    loss.backward()
    optimizer.step()

    if (i + 1) % 100 == 0:
        print(f"Step [{i+1}/{num_steps}], Loss: {loss.item():.6f}")

    loss_list.append(loss.item())

final_image = image.detach().cpu().squeeze(0)
final_embedding = current_embedding.detach().cpu().squeeze(0)

In [None]:
display_loss_graph(loss_list=loss_list)

In [None]:
save_and_display_image(final_image, "exp1_uniform_init.png")

In [None]:
cosine_dist = cosine(emb_00001, final_embedding)
cosine_dist

### Attempt 2: gaussian initialization

In [None]:
image = get_initial_image_tensor_gaussian()
num_steps = 200
optimizer = optim.Adam([image], lr=0.01)
loss_fn = nn.MSELoss()

In [None]:
loss_list = []
for i in tqdm(range(num_steps)):
    optimizer.zero_grad()

    normalized_image = torch.tanh(image * 2 - 1)
    current_embedding = model(normalized_image)
    loss = loss_fn(current_embedding, target_embedding)

    loss.backward()
    optimizer.step()

    if (i + 1) % 100 == 0:
        print(f"Step [{i+1}/{num_steps}], Loss: {loss.item():.6f}")

    loss_list.append(loss.item())

final_image = image.detach().cpu().squeeze(0)
final_embedding = current_embedding.detach().cpu().squeeze(0)

In [None]:
display_loss_graph(loss_list=loss_list)

In [None]:
save_and_display_image(final_image, "exp2_gaussian_init.png")

In [None]:
cosine_dist = cosine(emb_00001, final_embedding)
cosine_dist

### Attempt 3: constant initialization and a bit of gaussian noise

In [None]:
image = get_initial_image_tensor_constant_and_gaussian()
num_steps = 200
optimizer = optim.Adam([image], lr=0.01)
loss_fn = nn.MSELoss()

In [None]:
loss_list = []
for i in tqdm(range(num_steps)):
    optimizer.zero_grad()

    normalized_image_input = torch.tanh(image * 2 - 1)
    current_embedding = model(normalized_image_input)
    loss = loss_fn(current_embedding, target_embedding)
    loss.backward()
    optimizer.step()

    if (i + 1) % 100 == 0:
        print(f"Step [{i+1}/{num_steps}], Loss: {loss.item():.6f}")

    loss_list.append(loss.item())

final_image = image.detach().cpu().squeeze(0)
final_embedding = current_embedding.detach().cpu().squeeze(0)

In [None]:
display_loss_graph(loss_list=loss_list)

In [None]:
save_and_display_image(final_image, "exp3_small_gaussian_init.png")

In [None]:
cosine_dist = cosine(emb_00001, final_embedding)
cosine_dist

### Attempt 4: Adding Total Variation to the loss function

In [None]:
def get_tv_loss(img_tensor):
    """
    Total Variation loss: the summed absolute difference between every pixel
    and its right-hand and lower neighbour.
    Expected input shape: (Batch, Channels, Height, Width)
    """
    # Neighbouring columns (differences along the last axis)
    column_delta = img_tensor[:, :, :, 1:] - img_tensor[:, :, :, :-1]

    # Neighbouring rows (differences along the second-to-last axis)
    row_delta = img_tensor[:, :, 1:, :] - img_tensor[:, :, :-1, :]

    # The sign of each difference is irrelevant once the magnitude is taken
    return column_delta.abs().sum() + row_delta.abs().sum()

In [None]:
image = get_initial_image_tensor_constant_and_gaussian()
num_steps = 1000
tv_weight = 4e-7
optimizer = optim.Adam([image], lr=0.01)
loss_fn = nn.MSELoss()

In [None]:
loss_list = []

for i in tqdm(range(num_steps)):
    # Clear old gradients
    optimizer.zero_grad()

    normalized_image_input = torch.tanh(image * 2 - 1)
    current_embedding = model(normalized_image_input)

    loss_mse = loss_fn(current_embedding, target_embedding)
    loss_tv = get_tv_loss(image)
    total_loss = loss_mse + (tv_weight * loss_tv)

    total_loss.backward()

    optimizer.step()

    if (i + 1) % 100 == 0:
        print(f"Step [{i+1}/{num_steps}], Loss: {total_loss.item():.6f}")

    loss_list.append(total_loss.item())

final_image = image.detach().cpu().squeeze(0)
final_embedding = current_embedding.detach().cpu().squeeze(0)

In [None]:
display_loss_graph(loss_list=loss_list)

In [None]:
save_and_display_image(final_image, "exp4_small_gausian_init_and_tv_weight_4e-07.png")

In [None]:
cosine_dist = cosine(emb_00001, final_embedding)
cosine_dist

### Attempt 5: gaussian initialization with tv loss

In [None]:
image = get_initial_image_tensor_gaussian()
num_steps = 1000
tv_weight = 1e-6
optimizer = optim.Adam([image], lr=0.01)
loss_fn = nn.MSELoss()

In [None]:
loss_list = []

for i in tqdm(range(num_steps)):
    # Clear old gradients
    optimizer.zero_grad()

    normalized_image_input = torch.tanh(image * 2 - 1)
    current_embedding = model(normalized_image_input)

    loss_mse = loss_fn(current_embedding, target_embedding)
    loss_tv = get_tv_loss(image)
    total_loss = loss_mse + (tv_weight * loss_tv)

    total_loss.backward()

    optimizer.step()

    if (i + 1) % 100 == 0:
        print(f"Step [{i+1}/{num_steps}], Loss: {total_loss.item():.6f}")

    loss_list.append(total_loss.item())

final_image = image.detach().cpu().squeeze(0)
final_embedding = current_embedding.detach().cpu().squeeze(0)

In [None]:
display_loss_graph(loss_list=loss_list)

In [None]:
save_and_display_image(final_image, "exp5_gaussian_init_and_tv_weight_1e-06.png")

In [None]:
cosine_dist = cosine(emb_00001, final_embedding)
cosine_dist

### Attempt 6: man face initialization

In [None]:
# Start the optimisation from a real photo instead of noise.
# NOTE(review): the optimisation loop embeds tanh(image * 2 - 1), not
# image * 2 - 1, so the model's first input is a tanh-compressed version of
# this photo rather than the photo itself; confirm that is intended.
image_pil = Image.open("man_face.png").convert("RGB")
image = transform(image_pil).unsqueeze(0).to(device).requires_grad_(True)

In [None]:
image_pil

In [None]:
target_embedding = torch.tensor(emb_00001, dtype=torch.float32).unsqueeze(0).to(device)

In [None]:
num_steps = 1000
tv_weight = 1e-7
optimizer = optim.Adam([image], lr=0.01)
loss_fn = nn.MSELoss()

In [None]:
loss_list = []

for i in tqdm(range(num_steps)):
    # Clear old gradients
    optimizer.zero_grad()

    normalized_image_input = torch.tanh(image * 2 - 1)
    current_embedding = model(normalized_image_input)

    loss_mse = loss_fn(current_embedding, target_embedding)
    loss_tv = get_tv_loss(image)
    total_loss = loss_mse + (tv_weight * loss_tv)

    total_loss.backward()

    optimizer.step()

    if (i + 1) % 100 == 0:
        print(f"Step [{i+1}/{num_steps}], Loss: {total_loss.item():.6f}")

    loss_list.append(total_loss.item())

final_image = image.detach().cpu().squeeze(0)
final_embedding = current_embedding.detach().cpu().squeeze(0)

In [None]:
display_loss_graph(loss_list=loss_list)

In [None]:
save_and_display_image(final_image, "exp6_man_face_init_and_tv_weight_1e-07_a.png")

In [None]:
cosine_dist = cosine(emb_00001, final_embedding)
cosine_dist

### Attempt 7: adding image jittering

In [None]:
class RandomJitter(nn.Module):
    """Randomly translate an image batch by up to `lim` pixels along x and y.

    The input is reflect-padded by `lim` on every side and a window of the
    original height and width is cropped at a random offset, so the output
    keeps the input's shape. One offset pair is drawn per call and shared by
    the whole batch.

    Args:
        lim: Maximum shift in pixels in each direction. 0 disables jittering.
            Reflect padding requires `lim` to be smaller than the image's
            height and width.
    """

    def __init__(self, lim=30):
        super().__init__()
        self.lim = lim

    def forward(self, img_tensor):
        """Return a randomly shifted view of `img_tensor` (B, C, H, W)."""
        _, _, H, W = img_tensor.shape
        if self.lim == 0:
            # Nothing to shift; also avoids sampling from an empty range.
            return img_tensor

        padded = nn.functional.pad(img_tensor, (self.lim, self.lim, self.lim, self.lim), mode='reflect')

        # randint's upper bound is exclusive: +1 makes offsets 0..2*lim all
        # reachable, i.e. shifts from -lim to +lim symmetrically.
        sx = torch.randint(0, 2 * self.lim + 1, (1,)).item()
        sy = torch.randint(0, 2 * self.lim + 1, (1,)).item()

        return padded[:, :, sy:sy+H, sx:sx+W]

In [None]:
image = get_initial_image_tensor_constant_and_gaussian()
num_steps = 1000
tv_weight = 4e-7
jitter = RandomJitter(lim=30)
optimizer = optim.Adam([image], lr=0.01)
loss_fn = nn.MSELoss()

In [None]:
loss_list = []

for i in tqdm(range(num_steps)):
    # Clear old gradients
    optimizer.zero_grad()

    normalized_image_input = torch.tanh(image * 2 - 1)
    jittered_image = jitter(normalized_image_input)
    current_embedding = model(jittered_image)

    loss_mse = loss_fn(current_embedding, target_embedding)
    loss_tv = get_tv_loss(image)
    total_loss = loss_mse + (tv_weight * loss_tv)

    total_loss.backward()

    optimizer.step()

    if (i + 1) % 100 == 0:
        print(f"Step [{i+1}/{num_steps}], Loss: {total_loss.item():.6f}")

    loss_list.append(total_loss.item())

final_image = image.detach().cpu().squeeze(0)
final_embedding = current_embedding.detach().cpu().squeeze(0)

In [None]:
display_loss_graph(loss_list=loss_list)

In [None]:
# Filename records the tv_weight this attempt actually uses (4e-7, not 1e-7).
save_and_display_image(final_image, "exp7_tv_weight_4e-07_and_jittering.png")

In [None]:
cosine_dist = cosine(emb_00001, final_embedding)
cosine_dist

### Attempt 8: Decaying tv weight

In [None]:
image = get_initial_image_tensor_constant_and_gaussian()
num_steps = 1000
tv_weight_per_step = [1e-6] * 250 + [1e-7] * 250 + [1e-8] * 250 + [0] * 250
jitter = RandomJitter(lim=30)
optimizer = optim.Adam([image], lr=0.01)
loss_fn = nn.MSELoss()

In [None]:
loss_list = []

for i in tqdm(range(num_steps)):
    # Clear old gradients
    optimizer.zero_grad()

    normalized_image_input = torch.tanh(image * 2 - 1)
    jittered_image = jitter(normalized_image_input)
    current_embedding = model(jittered_image)

    loss_mse = loss_fn(current_embedding, target_embedding)
    loss_tv = get_tv_loss(image)
    tv_weight = tv_weight_per_step[i]
    total_loss = loss_mse + (tv_weight * loss_tv)

    total_loss.backward()

    optimizer.step()

    if (i + 1) % 100 == 0:
        print(f"Step [{i+1}/{num_steps}], Loss: {total_loss.item():.6f}")

    loss_list.append(total_loss.item())

final_image = image.detach().cpu().squeeze(0)
final_embedding = current_embedding.detach().cpu().squeeze(0)

In [None]:
display_loss_graph(loss_list=loss_list)

In [None]:
save_and_display_image(final_image, "exp8_decaying_tv_weight_1e-06_to_0_and_jittering.png")

In [None]:
cosine_dist = cosine(emb_00001, final_embedding)
cosine_dist

### Attempt 9: Adding Perceptual Loss

In [None]:
class VGGPerceptualLoss(nn.Module):
    """Perceptual loss: MSE between frozen VGG16 feature maps of two images.

    Args:
        resize: Stored on the instance but not used by `forward`; kept so
            existing callers that pass it keep working.
    """

    def __init__(self, resize=True):
        super().__init__()

        # Load VGG16
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features

        # Slicing up to layer 16 (ReLU3_3) is standard.
        self.blocks = nn.Sequential(*list(vgg.children())[:16]).eval()

        # Freeze the model weights
        for param in self.blocks.parameters():
            param.requires_grad = False

        # VGG specific normalization. Registered as buffers so the statistics
        # follow the module through .to(...) instead of being pinned to a
        # global device; persistent=False keeps them out of the state_dict.
        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False
        )
        self.resize = resize

    def forward(self, generated_img, target_img):
        """Return the feature-space MSE between the two image batches.

        Both inputs are assumed to be in the [0, 1] range, with 3 channels
        in dimension 1.
        """
        gen_norm = (generated_img - self.mean) / self.std
        target_norm = (target_img - self.mean) / self.std

        # Extract features
        gen_features = self.blocks(gen_norm)
        target_features = self.blocks(target_norm)

        # Calculate L2 loss between the feature maps
        return torch.nn.functional.mse_loss(gen_features, target_features)

In [None]:
man_face_pil = Image.open("man_face.png").convert("RGB")
# Transform the image loaded on the line above (previously this read the
# unrelated `image_pil` variable left over from an earlier cell).
man_face = transform(man_face_pil).unsqueeze(0).to(device)

In [None]:
image = get_initial_image_tensor_constant_and_gaussian()
num_steps = 1000
tv_weight = 4e-7
jitter = RandomJitter(lim=30)
perceptual_criterion = VGGPerceptualLoss().to(device)
perceptual_target_image = man_face
perceptual_weight = 1e-3
optimizer = optim.Adam([image], lr=0.01)
loss_fn = nn.MSELoss()

In [None]:
loss_list = []
perceptual_list = []

for i in tqdm(range(num_steps)):
    # Clear old gradients
    optimizer.zero_grad()

    normalized_image_input = torch.tanh(image * 2 - 1)
    jittered_image = jitter(normalized_image_input)
    current_embedding = model(jittered_image)

    loss_mse = loss_fn(current_embedding, target_embedding)
    loss_tv = get_tv_loss(image)
    loss_perceptual = perceptual_criterion(image, perceptual_target_image)
    total_loss = loss_mse + (tv_weight * loss_tv) + (perceptual_weight * loss_perceptual)

    perceptual_with_real = perceptual_criterion(image, x_00001)

    total_loss.backward()

    optimizer.step()

    if (i + 1) % 100 == 0:
        print(f"Step [{i+1}/{num_steps}], Loss: {total_loss.item():.6f}, Perc: {perceptual_with_real.item():.6f}")

    loss_list.append(total_loss.item())
    perceptual_list.append(perceptual_with_real.item())

final_image = image.detach().cpu().squeeze(0)
final_embedding = current_embedding.detach().cpu().squeeze(0)

In [None]:
display_loss_graph(loss_list=loss_list)

In [None]:
display_loss_graph(loss_list=perceptual_list)

In [None]:
save_and_display_image(final_image, "exp9_tv_weight_4e-07_and_jittering_and_perceptual_with_man_face.png")

In [None]:
cosine_dist = cosine(emb_00001, final_embedding)
cosine_dist

### Attempt 10: Adding Perceptual Style Loss

In [None]:
class VGGStyleLoss(nn.Module):
    """Style loss: MSE between Gram matrices of frozen VGG19 features.

    Gram matrices are compared at four depths of the network and the
    per-depth losses are summed.
    """

    def __init__(self):
        super().__init__()
        # Load VGG
        layers = list(models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.children())

        # List of sub-models ending at different depths
        self.slice1 = nn.Sequential(*layers[:2])   # Low level (colors)
        self.slice2 = nn.Sequential(*layers[:7])   # Textures
        self.slice3 = nn.Sequential(*layers[:12])  # Shapes
        self.slice4 = nn.Sequential(*layers[:21])  # Deep features

        # Freeze model (inference mode, no weight gradients)
        self.eval()
        for p in self.parameters():
            p.requires_grad = False

        # Normalisation statistics as buffers, so they follow the module
        # through .to(...) instead of being pinned to a global device;
        # persistent=False keeps them out of the state_dict.
        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False
        )

    @staticmethod
    def gram_matrix(input_tensor):
        """Return the normalised Gram matrix of a (B, C, H, W) feature map.

        The result has shape (B*C, B*C) and is divided by the total number of
        elements in `input_tensor`.
        """
        batch, channels, height, width = input_tensor.size()

        # Flatten every channel map into a row. reshape (unlike view) also
        # accepts non-contiguous feature maps.
        features = input_tensor.reshape(batch * channels, height * width)

        # Calculate the dot product
        G = torch.mm(features, features.t())

        # Normalize by the number of elements to keep values small
        return G.div(batch * channels * height * width)

    def forward(self, generated_img, guide_img):
        """Return the summed Gram-matrix MSE between the two image batches."""
        gen = (generated_img - self.mean) / self.std
        guide = (guide_img - self.mean) / self.std

        loss = 0
        # Pass through each slice
        for slice_net in [self.slice1, self.slice2, self.slice3, self.slice4]:
            gen_feat = slice_net(gen)
            guide_feat = slice_net(guide)

            # Compare Gram Matrices
            gen_gram = VGGStyleLoss.gram_matrix(gen_feat)
            guide_gram = VGGStyleLoss.gram_matrix(guide_feat)

            loss += torch.nn.functional.mse_loss(gen_gram, guide_gram)

        return loss

In [None]:
man_face_pil = Image.open("man_face.png").convert("RGB")
# Transform the image loaded on the line above (previously this read the
# unrelated `image_pil` variable left over from an earlier cell).
man_face = transform(man_face_pil).unsqueeze(0).to(device)

In [None]:
img_00002

In [None]:
image = x_00002.requires_grad_(True)#get_initial_image_tensor_constant_and_gaussian()
num_steps = 1000
tv_weight = 4e-7
jitter = RandomJitter(lim=30)
perceptual_criterion = VGGPerceptualLoss().to(device)
style_criterion = VGGStyleLoss().to(device)
perceptual_target_image = man_face
perceptual_weight = 1e2
optimizer = optim.Adam([image], lr=0.01)
loss_fn = nn.MSELoss()

In [None]:
loss_list = []
perceptual_list = []

for i in tqdm(range(num_steps)):
    # Clear old gradients
    optimizer.zero_grad()

    normalized_image_input = torch.tanh(image * 2 - 1)
    jittered_image = normalized_image_input#jitter(normalized_image_input)
    current_embedding = model(jittered_image)

    loss_mse = loss_fn(current_embedding, target_embedding)
    loss_tv = get_tv_loss(image)
    loss_style = style_criterion(image, perceptual_target_image)
    total_loss = loss_mse + (tv_weight * loss_tv) + (perceptual_weight * loss_style)
    if i == 0:
        # print the components of total_loss in one line, including the components' weights:
        print(

    perceptual_with_real = perceptual_criterion(image, x_00001)

    total_loss.backward()

    optimizer.step()

    if (i + 1) % 100 == 0:
        print(f"Step [{i+1}/{num_steps}], Loss: {total_loss.item():.6f}, Perc: {perceptual_with_real.item():.6f}")

    loss_list.append(total_loss.item())
    perceptual_list.append(perceptual_with_real.item())

final_image = image.detach().cpu().squeeze(0)
final_embedding = current_embedding.detach().cpu().squeeze(0)

In [None]:
display_loss_graph(loss_list=loss_list)

In [None]:
display_loss_graph(loss_list=perceptual_list)

In [None]:
save_and_display_image(final_image, "exp10_tv_weight_4e-07_and_jittering_and_style_with_man_face.png")

In [None]:
cosine_dist = cosine(emb_00001, final_embedding)
cosine_dist