# Imports

In [6]:
!brew link --overwrite cmake

Linking /usr/local/Cellar/cmake/3.31.0... 15 symlinks created.


In [7]:
%pip install dlib

Collecting dlib
  Using cached dlib-19.24.6.tar.gz (3.4 MB)
  Preparing metadata (setup.py) ... [?25ldone
[?25hBuilding wheels for collected packages: dlib
  Building wheel for dlib (setup.py) ... [?25ldone
[?25h  Created wheel for dlib: filename=dlib-19.24.6-cp311-cp311-macosx_15_0_x86_64.whl size=3486359 sha256=cc5e1573f2b4dcbca26c43997a4759773e645525b8b5a5bf4f077189c456bdd1
  Stored in directory: /Users/adriandaschlein/Library/Caches/pip/wheels/fe/c7/1f/c778b9f7cc6d8d0da4f6697f619f9eb5a49d54d2a2c8267f3c
Successfully built dlib
Installing collected packages: dlib
Successfully installed dlib-19.24.6

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.11 -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import dlib
import cv2

# Data Pre-Processing
encapsulated for now

!FIX remember to implement general solutions alignment helper
## Steps
- Alignment

In [9]:
def align_face(image_path, ref_landmarks):
    detector = dlib.get_frontal_face_detector()
    predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
    
    img = cv2.imread(image_path)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    faces = detector(gray)
    
    if len(faces) == 0:
        raise ValueError("No face detected!")
    
    for face in faces:
        landmarks = predictor(gray, face)
        aligned_face = cv2.warpAffine(
            img, 
            cv2.getAffineTransform(
                np.float32([landmarks.part(i).x, landmarks.part(i).y] for i in [36, 45, 33]),
                ref_landmarks
            ), 
            (128, 128)
        )
    return aligned_face


# Transformation Network
2 Steps no? 
- Lighting loss 128x128
- Content loss | Style loss

In [53]:
class TransformationNetwork(nn.Module):
    def __init__(self):
        super(TransformationNetwork, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=9, stride=1, padding=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        
        self.residual = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=False),  # Replace in-place ReLU
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=False),  # Replace in-place ReLU
        )
        
        self.upsample1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.upsample2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.output = nn.Conv2d(32, 3, kernel_size=9, stride=1, padding=4)
    
    def forward(self, x):
        x = self.conv1(x)
        x = nn.ReLU(inplace=False)(x)
        x = self.conv2(x)
        x = nn.ReLU(inplace=False)(x)
        x = self.conv3(x)
        for _ in range(5):  # 5 residual blocks
            x = x + self.residual(x)
        x = nn.ReLU(inplace=False)(self.upsample1(x))
        x = nn.ReLU(inplace=False)(self.upsample2(x))
        return self.output(x)


# Loss functions

In [52]:
class VGGFeatures(nn.Module):
    def __init__(self, layer_ids):
        super(VGGFeatures, self).__init__()
        vgg = models.vgg19(pretrained=True).features
        self.layers = nn.ModuleList([vgg[i] for i in layer_ids])
        for i, layer in enumerate(self.layers):
            if isinstance(layer, nn.ReLU):
                self.layers[i] = nn.ReLU(inplace=False)  # Replace in-place ReLU

    def forward(self, x):
        features = []
        for layer in self.layers:
            x = layer(x)
            features.append(x)
        return features


def content_loss(generated, content):
    return nn.MSELoss()(generated, content)

def style_loss(generated, styles):
    loss = 0
    for g, s in zip(generated, styles):
        # Ensure tensors are correctly handled
        if isinstance(s, list):
            s = torch.stack(s)  # Convert list of tensors into a single tensor

        # Repeat style features to match the batch size of generated features
        if s.size(0) != g.size(0):
            repeat_times = g.size(0) // s.size(0)  # Calculate how many times to repeat
            remainder = g.size(0) % s.size(0)  # Handle remainder if not divisible
            s = torch.cat([s.repeat(repeat_times, 1, 1, 1), s[:remainder]], dim=0)
        
        # Reshape tensors to [batch_size, num_channels, width*height]
        g = g.view(g.size(0), g.size(1), -1)
        s = s.view(s.size(0), s.size(1), -1)
        
        # Compute Gram matrices
        G = torch.bmm(g, g.transpose(1, 2))  # Compute Gram matrix for generated features
        S = torch.bmm(s, s.transpose(1, 2))  # Compute Gram matrix for style features
        print(f"G shape: {G.shape}, S shape: {S.shape}")  # Debugging shapes
        
        # Compute MSE loss
        loss += nn.MSELoss()(G, S)
        
    return loss







# Dataloader

In [54]:
import os
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

class CelebA256Dataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_files = [f for f in os.listdir(root_dir) if f.endswith('.png')]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.image_files[idx])
        content_image = Image.open(img_name).convert("RGB")
        
        # Example of selecting a random style image
        style_idx = (idx + 1) % len(self.image_files)  # Just for example
        style_name = os.path.join(self.root_dir, self.image_files[style_idx])
        style_image = Image.open(style_name).convert("RGB")
        
        if self.transform:
            content_image = self.transform(content_image)
            style_image = self.transform(style_image)
        
        return content_image, style_image


# Define transformations
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # Normalize to [-1, 1]
])

# Initialize dataset and dataloader
dataset = CelebA256Dataset(root_dir='../data/celebAhq256small', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)


In [55]:
# data loader test
print(dataloader)


<torch.utils.data.dataloader.DataLoader object at 0x1519b1650>


# Train

In [56]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize model and optimizer
model = TransformationNetwork().to(device)
vgg = VGGFeatures(layer_ids=[3, 8, 15]).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 3

torch.autograd.set_detect_anomaly(True)

# Training loop
for epoch in range(num_epochs):
    for content_img, style_imgs in dataloader:
        content_img = content_img.to(device)
        style_imgs = style_imgs.to(device)
        
        optimizer.zero_grad()
        
        # Generate output and compute features
        output = model(content_img)
        content_features = vgg(content_img)
        output_features = vgg(output)
        
        # Extract style features for all style images
        style_features = [vgg(style_img) for style_img in style_imgs]
        
        # Compute losses
        c_loss = content_loss(output_features[2], content_features[2])
        s_loss = style_loss(output_features, style_features)
        
        loss = c_loss + 10 * s_loss
        loss.backward()
        optimizer.step()
        
    print(f"Epoch {epoch}, Loss: {loss.item()}")




G shape: torch.Size([32, 3, 3]), S shape: torch.Size([32, 3, 3])
G shape: torch.Size([32, 3, 3]), S shape: torch.Size([32, 3, 3])
G shape: torch.Size([32, 3, 3]), S shape: torch.Size([32, 3, 3])
G shape: torch.Size([32, 3, 3]), S shape: torch.Size([32, 3, 3])
G shape: torch.Size([32, 3, 3]), S shape: torch.Size([32, 3, 3])
G shape: torch.Size([32, 3, 3]), S shape: torch.Size([32, 3, 3])
G shape: torch.Size([32, 3, 3]), S shape: torch.Size([32, 3, 3])
G shape: torch.Size([32, 3, 3]), S shape: torch.Size([32, 3, 3])
G shape: torch.Size([32, 3, 3]), S shape: torch.Size([32, 3, 3])
G shape: torch.Size([32, 3, 3]), S shape: torch.Size([32, 3, 3])
G shape: torch.Size([32, 3, 3]), S shape: torch.Size([32, 3, 3])
G shape: torch.Size([32, 3, 3]), S shape: torch.Size([32, 3, 3])
G shape: torch.Size([32, 3, 3]), S shape: torch.Size([32, 3, 3])
G shape: torch.Size([32, 3, 3]), S shape: torch.Size([32, 3, 3])
G shape: torch.Size([32, 3, 3]), S shape: torch.Size([32, 3, 3])
G shape: torch.Size([32, 

KeyboardInterrupt: 

# Blend
thought like that?

In [None]:
def blend_images(content_image, generated_face, mask):
    return cv2.seamlessClone(generated_face, content_image, mask, (128, 128), cv2.NORMAL_CLONE)
