In [11]:
import os
import cv2
import torch
import numpy as np
from tqdm import tqdm
from model import Dynamic_sparse_alignment_network

from Config.default import _C as cfg

In [12]:

# Preprocessing the image
def preprocess_image(image):
    resized_image = cv2.resize(image, (256, 256))
    normalized_image = resized_image / 255.0
    tensor_image = torch.tensor(normalized_image, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
    return tensor_image

# Extract 68 landmarks using the provided mapping
def filter_68_landmarks(landmarks):
    mapping_indices = [
        0, 2, 5, 4, 7, 9, 12, 14, 16, 18, 20, 22, 25, 27, 29, 31, 32,
        33, 34, 35, 36, 37, 42, 43, 44, 45, 46, 51, 52, 53, 54, 55,
        56, 57, 58, 59, 60, 61, 63, 64, 65, 67, 68, 69, 71, 72, 73,
        75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
        90, 91, 92, 93, 94, 95
    ]
    final_landmarks = landmarks[-1]  # Select the last stage
    return final_landmarks[mapping_indices, :]

# Plot the 68 landmarks on the image
def draw_landmarks(image, landmarks):
    h, w, _ = image.shape
    diagonal = (w ** 2 + h ** 2) ** 0.5
    radius = max(8, min(int(diagonal * 0.01), 16))  # Radius proportional to image diagonal
    thickness = -1  # Filled circle

    for (x, y) in landmarks:
        x = int(x * w)
        y = int(y * h)
        cv2.circle(image, (x, y), radius, (0, 255, 0), thickness)
    return image

# Save landmarks to a .txt file
def save_landmarks_to_txt(landmarks, output_txt_path, image_width, image_height):
    landmarks[:, 0] *= image_width
    landmarks[:, 1] *= image_height
    landmarks = landmarks.flatten().astype(int)
    with open(output_txt_path, 'w') as f:
        f.write(" ".join(map(str, landmarks)))  # Save in a single line

# Process a single image and save results
def process_image(image_path, output_image_dir, output_label_dir, model, device):
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not load image: {image_path}")

    h, w, _ = image.shape
    input_tensor = preprocess_image(image).to(device)

    with torch.inference_mode():
        output_list, _, _, _ = model(input_tensor)
        landmarks = output_list[-1].squeeze().cpu().numpy()
        landmarks_68 = filter_68_landmarks(landmarks)

    # Save the plotted image
    output_image_path = os.path.join(output_image_dir, os.path.basename(image_path))
    plotted_image = draw_landmarks(image.copy(), landmarks_68)
    cv2.imwrite(output_image_path, plotted_image)

    # Save the landmarks in .txt format
    output_label_path = os.path.join(output_label_dir, f"{os.path.splitext(os.path.basename(image_path))[0]}.txt")
    save_landmarks_to_txt(landmarks_68, output_label_path, w, h)

    return True

# Process a directory of images
def process_directory(input_dir, output_dir, model, device):
    # Create output directories
    image_output_dir = os.path.join(output_dir, 'images')
    label_output_dir = os.path.join(output_dir, 'labels')
    os.makedirs(image_output_dir, exist_ok=True)
    os.makedirs(label_output_dir, exist_ok=True)

    images = [f for f in os.listdir(input_dir) if f.lower().endswith(('.jpg', '.png'))]
    processed_count = 0
    failed_images = []

    print("\nProcessing images...")
    for image_name in tqdm(images, desc="Processing"):
        image_path = os.path.join(input_dir, image_name)
        try:
            process_image(image_path, image_output_dir, label_output_dir, model, device)
            processed_count += 1
        except Exception as e:
            print(f"Failed to process {image_name}: {e}")
            failed_images.append(image_name)

    # Summary
    print("\n=== Processing Summary ===")
    print(f"Total images processed: {processed_count}")
    print(f"Total images failed: {len(failed_images)}")
    if failed_images:
        print("Failed images:")
        for img in failed_images:
            print(f"  {img}")


In [13]:
input_directory = '/home/jocareher/Documents/baby_face_72/images'
output_directory = '/home/jocareher/Documents/results_dslpt'

# Load model
model_path = '/home/jocareher/Downloads/DSLPT_WFLW_6_layers.pth'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Dynamic_sparse_alignment_network(num_point=98, d_model=256, trainable=False,
                                            return_interm_layers=False, nhead=8,
                                            feedforward_dim=1024, initial_path='/home/jocareher/Downloads/DSLPT/Config/init_98.npz',
                                            cfg=cfg)
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()

# Process directory
process_directory(input_directory, output_directory, model, device)





  model.load_state_dict(torch.load(model_path, map_location=device))



Processing images...


Processing: 100%|██████████| 311/311 [02:27<00:00,  2.11it/s]


=== Processing Summary ===
Total images processed: 311
Total images failed: 0



