In [261]:
import cv2
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F

from dataset import FacialLandmarkDataset

In [262]:
# Resize every sample to the ViT input size and convert it to a CHW tensor.
# keypoint_params makes albumentations move the (x, y) landmarks with the image.
# NOTE(review): ToTensorV2 does not rescale pixel values; FacialLandmarkModel clamps
# its input to [0, 1], so confirm the dataset already yields images in that range.
basic_augmentations = A.Compose(
    [
        A.Resize(224, 224),
        ToTensorV2(),
    ],
    keypoint_params=A.KeypointParams(format='xy'),
)

In [263]:
# AFW subset of the iBUG 300-W (large) release; the path is relative to the notebook's working directory.
dataset = FacialLandmarkDataset(
    root_dir='archive/ibug_300W_large_face_landmark_dataset/afw',
    transform=basic_augmentations,
)

In [264]:
# Split the dataset (80% training, 20% testing).
# A fixed-seed generator makes the split reproducible across kernel restarts, so the
# test set evaluated at the end is the same one that was held out during training.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Create data loaders for both sets (test order is fixed so per-batch factors line up later)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [None]:
# Number of samples in the dataset (displayed as the cell output).
len(dataset)

In [None]:
# Visual sanity check: draw sample 10 together with its (resized) landmarks.
img, landmarks = dataset[10]

# Tensors come out as (C, H, W); matplotlib expects (H, W, C).
img = img.permute(1, 2, 0).numpy() if isinstance(img, torch.Tensor) else img

print(len(landmarks))

img_np = np.array(img)

plt.imshow(img_np)
plt.scatter(landmarks[:, 0], landmarks[:, 1], c='r', marker='x')

In [267]:
def check_validity(tensor, name="tensor"):
    """
    Report non-finite values in a tensor.

    Prints a warning plus the tensor's min/max when any element is NaN or +/-Inf;
    prints nothing otherwise. Always returns None.
    """
    if torch.isfinite(tensor).all():
        return
    print(f"{name} contains NaN or Inf values")
    print("Min:", tensor.min().item(), "Max:", tensor.max().item())

In [268]:
# Report every sample that does not carry exactly 68 landmarks
# (presumably such samples would break stacking landmarks into a batch).
for idx in range(len(dataset)):
    img, landmarks = dataset[idx]
    landmarks_np = landmarks.numpy()

    if landmarks_np.shape[0] == 68:
        continue

    print(len(landmarks))
    print(f"Face at index {idx} does not have exactly 68 landmarks. Found: {landmarks_np.shape[0]}")

In [269]:
def fiducial_focus_augmentation(image_tensor, landmarks, n, epoch, total_epochs):
    """
    Black out a square patch centred on every landmark of every image in a batch.

    image_tensor: (B, C, H, W) image batch.
    landmarks: (B, L, 2) per-image (x, y) pixel coordinates (tensor or array-like).
    n: patch side length, in pixels, at epoch 0.
    epoch, total_epochs: the side shrinks linearly with the epoch, never below 1 px.

    Returns a new (B, C, H, W) tensor on image_tensor's device; the input is not modified.
    Patch pixels are set to 0 in every channel. Patches are clipped to the image, and a
    patch lying entirely outside the image leaves it untouched.
    """
    # Gradually reduce patch size; max(1, ...) guards against total_epochs == 0.
    patch_size = max(1, n - (epoch * n // max(1, total_epochs)))
    half = patch_size // 2

    # One device->host copy for the whole batch (HWC layout for easy slicing).
    images_np = image_tensor.detach().permute(0, 2, 3, 1).cpu().numpy().copy()
    if isinstance(landmarks, torch.Tensor):
        landmarks_np = landmarks.detach().cpu().numpy()
    else:
        landmarks_np = np.asarray(landmarks)
    height, width = images_np.shape[1], images_np.shape[2]

    for image, image_landmarks in zip(images_np, landmarks_np):  # `image` is a view into images_np
        for x, y in image_landmarks:
            # Start at (coord - half) and extend exactly patch_size pixels, so odd sizes
            # (including 1) produce a full patch instead of one pixel short / empty.
            x0 = int(x) - half
            y0 = int(y) - half
            # Clip both ends to [0, size]; clamping the far end at 0 prevents a negative
            # slice bound, which Python would interpret as "counted from the end".
            x_start, x_end = max(0, x0), min(width, max(0, x0 + patch_size))
            y_start, y_end = max(0, y0), min(height, max(0, y0 + patch_size))
            image[y_start:y_end, x_start:x_end] = 0

    # Back to (B, C, H, W) on the caller's device so the model can consume it directly.
    return torch.from_numpy(images_np).permute(0, 3, 1, 2).contiguous().to(image_tensor.device)

In [270]:
# BlurPool Layer for anti-aliasing
class BlurPool(nn.Module):
    """
    Depthwise 3x3 binomial blur with optional striding (anti-aliased downsampling).

    channels: number of input channels (one blur filter per channel).
    stride: convolution stride; stride=2 blurs and roughly halves H and W.
    The input is reflection-padded by 1 pixel, so stride=1 keeps H and W unchanged.
    """

    def __init__(self, channels, stride=1):
        super().__init__()
        self.stride = stride
        self.pad = nn.ReflectionPad2d(1)
        kernel = torch.tensor([[1., 2., 1.],
                               [2., 4., 2.],
                               [1., 2., 1.]])
        kernel = kernel / kernel.sum()
        # A buffer (unlike a plain tensor attribute) follows model.to(device) / .half().
        # persistent=False keeps it out of state_dict, so existing checkpoints still load.
        self.register_buffer(
            "blur_kernel", kernel.expand(channels, 1, 3, 3).contiguous(), persistent=False
        )

    def forward(self, x):
        x = self.pad(x)
        # No-op once the module has been moved/cast; kept as a safety net for mixed use.
        weight = self.blur_kernel.to(device=x.device, dtype=x.dtype)
        return F.conv2d(x, weight, stride=self.stride, groups=x.shape[1])

In [271]:
# Implement FF-Parser Layer to filter high-frequency noise
class FFParser(nn.Module):
    """
    Frequency-domain filter over the last two (spatial) dimensions.

    With the default keep_ratio=None the module computes fft2 followed by ifft2 and keeps
    the real part, i.e. it returns the input unchanged up to floating-point round-off.
    NO frequencies are removed in that mode; it is kept as the default so existing models
    behave exactly as before.

    keep_ratio in (0, 1] enables a real low-pass filter: a frequency component is kept only
    when both its vertical and horizontal |frequency| (cycles/sample as given by
    torch.fft.fftfreq, at most 0.5) are <= 0.5 * keep_ratio. keep_ratio=1.0 keeps everything.

    Raises ValueError if keep_ratio is outside (0, 1].
    """

    def __init__(self, keep_ratio=None):
        super().__init__()
        if keep_ratio is not None and not (0.0 < keep_ratio <= 1.0):
            raise ValueError(f"keep_ratio must be in (0, 1], got {keep_ratio}")
        self.keep_ratio = keep_ratio

    def forward(self, x):
        # Apply FFT to transform to frequency domain
        x_freq = torch.fft.fft2(x, dim=(-2, -1))
        if self.keep_ratio is not None:
            cutoff = 0.5 * self.keep_ratio
            fy = torch.fft.fftfreq(x.shape[-2], device=x.device).abs()
            fx = torch.fft.fftfreq(x.shape[-1], device=x.device).abs()
            # Symmetric in +/- frequency, so a real input keeps a Hermitian spectrum and
            # the inverse transform stays (numerically) real.
            mask = (fy[:, None] <= cutoff) & (fx[None, :] <= cutoff)
            x_freq = torch.where(mask, x_freq, torch.zeros_like(x_freq))
        return torch.fft.ifft2(x_freq, dim=(-2, -1)).real

In [272]:
# Hourglass module with anti-aliasing and noise reduction
class HourglassModule(nn.Module):
    """
    Conv -> BlurPool(stride=2) -> conv -> FFParser stack.

    Maps a (B, 256, H, W) feature map to (B, 128, H', W'), where H' and W' are the
    BlurPool-downsampled sizes (14x14 -> 7x7 in this model).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.anti_alias = BlurPool(128, stride=2)  # blur before striding to reduce aliasing
        self.conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.ff_parser = FFParser()  # frequency-domain stage

    def forward(self, x):
        for stage in (self.conv1, self.anti_alias, self.conv2, self.ff_parser):
            x = stage(x)
        return x

In [273]:
# Facial Landmark Model
class FacialLandmarkModel(nn.Module):
    """
    ViT-B/16 backbone followed by an hourglass head that predicts one heatmap per landmark.

    num_landmarks: number of heatmap channels produced.
    forward(x): clamps x to [0, 1], encodes it with the ViT into a 768-d vector, projects
    that to a (B, 256, 14, 14) grid, refines it with HourglassModule and maps it with a
    1x1 conv to (B, num_landmarks, h, w) heatmaps.
    """

    def __init__(self, num_landmarks):
        super().__init__()
        # num_classes=0 asks timm for feature output; the head is also replaced explicitly.
        self.vit_backbone = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0)
        self.vit_backbone.head = nn.Identity()
        # Lift the 768-d embedding onto a 256-channel 14x14 spatial grid.
        self.reshape = nn.Linear(768, 256 * 14 * 14)
        self.hourglass = HourglassModule()
        self.fc = nn.Conv2d(128, num_landmarks, kernel_size=1)  # 1x1 conv -> per-landmark heatmaps

    def forward(self, x):
        # NOTE(review): the clamp assumes inputs already scaled to [0, 1]; confirm the
        # dataset/transforms produce that range (0-255 images would saturate here).
        clamped = x.clamp(min=0.0, max=1.0)
        embedding = self.vit_backbone(clamped)
        grid = self.reshape(embedding).view(-1, 256, 14, 14)
        return self.fc(self.hourglass(grid))

In [274]:
def local_soft_argmax(heatmap, temperature=1.0):
    """
    Soft-argmax of a single 2-D heatmap.

    heatmap: (height, width) tensor.
    temperature: softmax temperature; smaller values sharpen toward the hard argmax.

    Returns (x, y) as 0-dim tensors: the expected column and row index under
    softmax(heatmap / temperature) taken over all positions. Despite the name, the
    softmax spans the whole heatmap (there is no local window).
    """
    height, width = heatmap.shape
    # reshape (not view) also accepts non-contiguous heatmaps.
    probs = F.softmax(heatmap.reshape(-1) / temperature, dim=0).view(height, width)

    # Build the coordinate grid on the heatmap's device; a CPU grid multiplied by a
    # CUDA heatmap raises a device-mismatch error.
    y_grid, x_grid = torch.meshgrid(
        torch.arange(height, device=heatmap.device, dtype=torch.float32),
        torch.arange(width, device=heatmap.device, dtype=torch.float32),
        indexing='ij',
    )

    x_coord = (x_grid * probs).sum()
    y_coord = (y_grid * probs).sum()
    return x_coord, y_coord


In [275]:
def extract_landmark_coords_from_heatmaps(heatmaps, image_size=(224, 224), temperature=1.0):
    """
    Convert per-landmark heatmaps into (x, y) pixel coordinates via a soft-argmax.

    heatmaps: (batch, num_landmarks, height, width) tensor of raw model outputs.
    image_size: (image_height, image_width) that the coordinates are scaled to.
    temperature: softmax temperature (1.0 reproduces the previous behaviour).

    Returns a float32 tensor of shape (batch, num_landmarks, 2) holding (x, y), on the
    same device as `heatmaps` and differentiable with respect to it.

    A heatmap whose sum is positive is first divided by that sum; one with a non-positive
    sum is used as-is. The result is soft-maxed over all positions, the expected grid
    coordinates are taken, scaled by image_size / heatmap_size and clamped into the image.

    NOTE(review): dividing by the sum before the softmax shrinks the values (to ~1/(h*w)
    on average for positive maps), which makes the softmax nearly uniform and pulls
    predictions toward the heatmap centre — worth revisiting.
    """
    batch_size, num_landmarks, height, width = heatmaps.size()
    flat = heatmaps.reshape(batch_size, num_landmarks, height * width).float()

    # Normalise only heatmaps with a positive sum. Dividing by a "safe" denominator (1 where
    # the sum is <= 0) instead of torch.where(..., flat / sums, flat) keeps gradients finite
    # when a sum is exactly zero.
    sums = flat.sum(dim=-1, keepdim=True)
    safe_sums = torch.where(sums > 0, sums, torch.ones_like(sums))
    flat = flat / safe_sums

    probs = F.softmax(flat / temperature, dim=-1).view(batch_size, num_landmarks, height, width)

    # Grid built on the heatmaps' device so everything stays on one device.
    xs = torch.arange(width, device=heatmaps.device, dtype=probs.dtype)
    ys = torch.arange(height, device=heatmaps.device, dtype=probs.dtype)
    x = (probs * xs.view(1, 1, 1, width)).sum(dim=(-2, -1))
    y = (probs * ys.view(1, 1, height, 1)).sum(dim=(-2, -1))

    # Scale coordinates back to the original image resolution and keep them inside it.
    x = torch.clamp(x * image_size[1] / width, min=0, max=image_size[1] - 1)
    y = torch.clamp(y * image_size[0] / height, min=0, max=image_size[0] - 1)
    return torch.stack((x, y), dim=-1)


In [276]:
# DCCA Loss Function for consistency
class DCCALoss(nn.Module):
    """
    Negative correlation between two batches of feature vectors.

    Despite the name this is not full (Deep) CCA: no canonical directions are solved for.
    On batch-centred inputs H1c, H2c it computes
        corr = tr(H1c^T H2c) / sqrt(tr(H1c^T H1c) * tr(H2c^T H2c))
    (a Pearson-style correlation over all features jointly, bounded by 1 in magnitude by
    Cauchy-Schwarz) and returns -corr, so minimising pushes the two views to agree.

    eps: lower bound for the denominator, so zero-variance inputs (e.g. a batch of one
    sample) give a finite loss of 0 instead of NaN.
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, H1, H2):
        """H1, H2: (batch, features) tensors of the same shape. Returns a scalar tensor."""
        H1 = H1 - H1.mean(dim=0, keepdim=True)
        H2 = H2 - H2.mean(dim=0, keepdim=True)
        # tr(A^T B) == sum(A * B): avoids building (features x features) matrices, and the
        # 1/(n-1) covariance factors cancel, which also removes the n == 1 division by zero.
        cross = (H1 * H2).sum()
        energy = (H1 * H1).sum() * (H2 * H2).sum()
        # Clamp before the sqrt: sqrt has an infinite gradient at 0.
        denom = torch.sqrt(energy.clamp_min(self.eps * self.eps))
        return -cross / denom

In [277]:
# Train on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = FacialLandmarkModel(num_landmarks=68).to(device)

In [None]:
# AdamW with light weight decay. ReduceLROnPlateau multiplies the LR by `factor` after
# `patience` scheduler steps without improvement of the metric passed to scheduler.step().
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=5,
    verbose=True,  # NOTE(review): deprecated in recent PyTorch releases; confirm for the pinned version
)
dcca_loss = DCCALoss().to(device)

In [None]:
num_epochs = 250  # Increase the number of epochs

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, landmarks in train_loader:
        images = images.to(device)
        # Predicted coordinates are float32; cast targets so the MSE sees matching dtypes.
        landmarks = landmarks.to(device=device, dtype=torch.float32)

        # Fiducial focus augmentation: black patches over the landmarks, shrinking from 8 px
        # as training progresses. The helper may build its output on the CPU, so move it back.
        fiducial_focus_images = fiducial_focus_augmentation(images, landmarks, 8, epoch, num_epochs).to(device)

        # Forward pass for original and transformed images
        outputs_original = model(images)
        outputs_transformed = model(fiducial_focus_images)

        # (B, num_landmarks, 2) coordinates; .to(device) guards against helpers returning CPU tensors.
        coords_original = extract_landmark_coords_from_heatmaps(outputs_original).to(device)
        coords_transformed = extract_landmark_coords_from_heatmaps(outputs_transformed).to(device)

        # DCCA consistency loss on flattened (B, num_landmarks * 2) coordinates
        loss_dcca = dcca_loss(coords_original.flatten(start_dim=1), coords_transformed.flatten(start_dim=1))

        # MSE against ground truth in (B, num_landmarks, 2) layout (no hard-coded landmark count)
        ground_truth_loss = F.mse_loss(coords_original, landmarks)

        # NOTE(review): the MSE term is presumably in squared pixels while loss_dcca lies in
        # [-1, 1], so the equal 0.5/0.5 weighting lets the MSE dominate — consider rescaling.
        loss = 0.5 * loss_dcca + 0.5 * ground_truth_loss

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)
    # ReduceLROnPlateau only lowers the LR when it is stepped with the monitored metric;
    # without this call the scheduler never takes effect.
    scheduler.step(avg_loss)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}")

In [None]:
# Persist the trained weights.
# NOTE(review): despite the file name, these are the weights after the final epoch,
# not a best-validation checkpoint.
model_save_path = 'best_model.pth'
torch.save(obj=model.state_dict(), f=model_save_path)
print(f"Model saved to {model_save_path}")

In [256]:
# NME calculation
def calculate_nme(preds, labels, normalization_factor):
    """
    Calculate Normalized Mean Error (NME).

    preds: Tensor of shape (batch_size, num_landmarks, 2) - predicted landmarks
    labels: Tensor of shape (batch_size, num_landmarks, 2) - ground truth landmarks
    normalization_factor: scalar, or per-image tensor of shape (batch_size,) or
        (batch_size, 1) (e.g. inter-ocular distance). Per-image factors are flattened so
        each image is divided by its own factor rather than broadcasting to (B, B).

    Returns a 0-dim tensor: the mean over images of (mean landmark error / factor).
    Raises ValueError if preds is not 3-D.
    """
    if preds.dim() != 3:
        raise ValueError(f"preds must have shape (batch, landmarks, 2), got {tuple(preds.shape)}")

    error = torch.norm(preds - labels, dim=-1)  # Euclidean distance for each landmark, (B, L)
    per_image_error = error.mean(dim=1)  # (B,)

    # Same device/dtype as the errors, so a CPU factor works with CUDA predictions.
    factor = torch.as_tensor(
        normalization_factor, dtype=per_image_error.dtype, device=per_image_error.device
    )
    if factor.dim() > 0 and factor.numel() == per_image_error.numel():
        factor = factor.reshape(per_image_error.shape)

    return (per_image_error / factor).mean()  # Return the average NME across the batch

In [257]:
# Evaluation function with MSE and NME
def evaluate_model(model, test_loader, normalization_factors):
    """
    Evaluate `model` on `test_loader`; print and return (mean MSE, mean NME).

    normalization_factors: sequence indexed by batch position; entry i is the
        normalisation for batch i (scalar, or per-image tensor of shape (B,) or (B, 1)).
        test_loader must yield batches in the same order as when the factors were built.

    Both metrics are averaged per image (each batch weighted by its size), so a short
    final batch does not count as much as a full one. Uses the notebook-level `device`.
    Raises ValueError if the loader yields no samples.
    """
    model.eval()  # Set the model to evaluation mode
    total_loss = 0.0
    total_nme = 0.0
    num_samples = 0
    criterion = nn.MSELoss()  # Using MSE loss for evaluation

    with torch.no_grad():  # No need to compute gradients during evaluation
        for batch_idx, (images, landmarks) in enumerate(test_loader):
            images = images.to(device)
            landmarks = landmarks.to(device=device, dtype=torch.float32)
            batch_size = landmarks.size(0)

            outputs = model(images)
            # The coordinate helper may return CPU tensors; align with the targets.
            coords_pred = extract_landmark_coords_from_heatmaps(outputs).to(landmarks.device)

            # Calculate MSE Loss (batch mean, re-weighted by batch size)
            total_loss += criterion(coords_pred, landmarks).item() * batch_size

            # Calculate NME: move the factor to the prediction device and flatten (B, 1) -> (B,)
            # so each image is divided by its own factor instead of broadcasting to (B, B).
            factor = torch.as_tensor(
                normalization_factors[batch_idx], dtype=coords_pred.dtype, device=coords_pred.device
            )
            if factor.numel() == batch_size:
                factor = factor.reshape(batch_size)
            total_nme += calculate_nme(coords_pred, landmarks, factor).item() * batch_size

            num_samples += batch_size

    if num_samples == 0:
        raise ValueError("test_loader yielded no samples")

    avg_test_loss = total_loss / num_samples
    avg_nme = total_nme / num_samples

    print(f"Test Loss (MSE): {avg_test_loss:.4f}")
    print(f"Test NME: {avg_nme:.4f}")
    return avg_test_loss, avg_nme

In [258]:
# Plot predictions function with better visualization
def plot_predictions(model, test_loader, image_size=(224, 224), num_images=5):
    """
    Show up to `num_images` test images with predicted and ground-truth landmarks.

    Predicted landmarks are drawn as red crosses and ground truth as green circles, one figure per image.
    image_size: (height, width) passed to extract_landmark_coords_from_heatmaps for scaling.
    num_images: maximum number of figures; values <= 0 draw nothing.
    Uses the notebook-level `device` for inference.
    """
    # Check the limit up front so num_images <= 0 draws no figure at all.
    if num_images <= 0:
        return

    model.eval()
    images_shown = 0
    with torch.no_grad():
        for images, landmarks in test_loader:
            outputs = model(images.to(device))
            coords_pred = extract_landmark_coords_from_heatmaps(outputs, image_size).cpu().numpy()
            images_np = images.permute(0, 2, 3, 1).cpu().numpy()  # Convert to HWC format for plotting
            landmarks_np = landmarks.cpu().numpy()

            for image_np, pred, truth in zip(images_np, coords_pred, landmarks_np):
                fig, ax = plt.subplots()
                ax.imshow(image_np)
                ax.scatter(pred[:, 0], pred[:, 1], c='r', marker='x', label='Predicted')
                ax.scatter(truth[:, 0], truth[:, 1], c='g', marker='o', label='Ground Truth')
                ax.set_title("Predicted vs Ground Truth Landmarks")
                ax.legend()
                plt.show()
                plt.close(fig)  # release the figure so repeated calls do not accumulate memory

                images_shown += 1
                if images_shown >= num_images:
                    return  # Stop after showing the specified number of images


In [259]:
# Sample function to get normalization factors (e.g., inter-ocular distance)
def get_normalization_factors(landmarks_batch, left_idx=36, right_idx=45):
    """
    Compute the per-image distance between two landmarks (used as the NME normaliser).

    landmarks_batch: Tensor of shape (batch_size, num_landmarks, 2)
    left_idx, right_idx: landmark indices to measure between; the defaults 36 and 45 are
        presumably the outer eye corners of the 68-point iBUG/300-W scheme.
    Return: Tensor of shape (batch_size,) representing normalization factor for each image
    """
    # Integer (not list) indexing drops the landmark axis: (B, 2) rather than (B, 1, 2),
    # so the result really is (B,) and divides per image in calculate_nme.
    left_eye = landmarks_batch[:, left_idx, :]
    right_eye = landmarks_batch[:, right_idx, :]
    return torch.norm(left_eye - right_eye, dim=-1)  # Inter-ocular distance, (B,)

In [None]:
# One normalisation-factor entry per test batch. test_loader is built with shuffle=False,
# so this pass sees the batches in the same order evaluate_model will.
normalization_factors = [
    get_normalization_factors(batch_landmarks) for _, batch_landmarks in test_loader
]

# Evaluate and plot predictions after training
avg_test_loss, avg_nme = evaluate_model(model, test_loader, normalization_factors)
plot_predictions(model, test_loader, num_images=5)