In [1]:
pip install torch==2.4.1+cu121



In [2]:
pip install torchaudio==2.2.2

Collecting torchaudio==2.2.2
  Downloading torchaudio-2.2.2-cp310-cp310-manylinux1_x86_64.whl.metadata (6.4 kB)
Collecting torch==2.2.2 (from torchaudio==2.2.2)
  Downloading torch-2.2.2-cp310-cp310-manylinux1_x86_64.whl.metadata (26 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.2.2->torchaudio==2.2.2)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch==2.2.2->torchaudio==2.2.2)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch==2.2.2->torchaudio==2.2.2)
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch==2.2.2->torchaudio==2.2.2)
  Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch==2.2

In [1]:
conda install pytorch3d

Channels:
 - defaults
 - conda-forge
Platform: linux-64
Collecting package metadata (repodata.json): done
Solving environment: done

# All requested packages already installed.


Note: you may need to restart the kernel to use updated packages.


In [3]:
import os
import numpy as np
import cv2
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from facenet_pytorch import MTCNN
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras,
    PointLights,
    DirectionalLights,
    Materials,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    TexturesVertex
)
from pytorch3d.io import load_obj, save_obj
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss import (
    chamfer_distance,
    mesh_edge_loss,
    mesh_laplacian_smoothing,
    mesh_normal_consistency,
)

In [None]:
# 1. Implement a real 3D face reconstruction model

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 4, 2, 1)
        self.conv2 = nn.Conv2d(32, 64, 4, 2, 1)
        self.conv3 = nn.Conv2d(64, 128, 4, 2, 1)
        self.conv4 = nn.Conv2d(128, 256, 4, 2, 1)
        self.fc = nn.Linear(256 * 16 * 16, 1024)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = torch.relu(self.conv4(x))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

class Decoder(nn.Module):
    def __init__(self, num_vertices):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(1024, 1024)
        self.fc2 = nn.Linear(1024, num_vertices * 3)  # 3D coordinates
        self.fc3 = nn.Linear(1024, num_vertices * 3)  # RGB colors

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        vertices = self.fc2(x).view(-1, 3)
        colors = torch.sigmoid(self.fc3(x)).view(-1, 3)
        return vertices, colors

class FaceReconstructionModel(nn.Module):
    def __init__(self, num_vertices):
        super(FaceReconstructionModel, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder(num_vertices)

    def forward(self, x):
        latent = self.encoder(x)
        vertices, colors = self.decoder(latent)
        return vertices, colors

# Face detection for preprocessing
mtcnn = MTCNN(keep_all=True, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

# Custom dataset
class FaceDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = cv2.imread(self.image_paths[idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Detect face
        boxes, _ = mtcnn.detect(image)
        if boxes is not None:
            box = boxes[0]
            face = image[int(box[1]):int(box[3]), int(box[0]):int(box[2])]
        else:
            face = image  # If no face detected, use the whole image

        face = cv2.resize(face, (256, 256))

        if self.transform:
            face = self.transform(face)

        return face

# Data loading and preprocessing
def load_and_preprocess_images(image_paths):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    dataset = FaceDataset(image_paths, transform=transform)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
    return dataloader

In [None]:
# 2. Define face connectivity
def create_face_mesh(vertices, colors, resolution=100):
    u = np.linspace(0, 1, resolution)
    v = np.linspace(0, 1, resolution)
    u, v = np.meshgrid(u, v)
    u = u.flatten()
    v = v.flatten()

    x = vertices[:, 0].reshape((resolution, resolution))
    y = vertices[:, 1].reshape((resolution, resolution))
    z = vertices[:, 2].reshape((resolution, resolution))

    faces = []
    for i in range(resolution-1):
        for j in range(resolution-1):
            faces.append([i*resolution+j, (i+1)*resolution+j, i*resolution+j+1])
            faces.append([(i+1)*resolution+j, (i+1)*resolution+j+1, i*resolution+j+1])

    return vertices, faces, colors

In [None]:
# 3. Fine-tune rendering parameters
class FaceRenderer:
    def __init__(self, device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
        self.device = device

    def setup_renderer(self, image_size=512):
        R, T = look_at_view_transform(2.7, 0, 0)
        cameras = FoVPerspectiveCameras(device=self.device, R=R, T=T)
        raster_settings = RasterizationSettings(
            image_size=image_size,
            blur_radius=0.0,
            faces_per_pixel=1,
        )
        lights = PointLights(device=self.device, location=[[0.0, 0.0, -3.0]])
        materials = Materials(
            device=self.device,
            specular_color=[[0.2, 0.2, 0.2]],
            shininess=32
        )
        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
            shader=SoftPhongShader(device=self.device, cameras=cameras, lights=lights, materials=materials)
        )
        return renderer

    def render_views(self, vertices, faces, colors, num_views=8):
        renderer = self.setup_renderer()
        vertices = vertices.to(self.device)
        faces = torch.tensor(faces, dtype=torch.int64, device=self.device)
        colors = colors.to(self.device)
        textures = TexturesVertex(colors.unsqueeze(0))
        mesh = Meshes(verts=[vertices], faces=[faces], textures=textures)

        images = []
        for i in range(num_views):
            angle = i * (360 / num_views)
            R, T = look_at_view_transform(2.7, 0, angle)
            cameras = FoVPerspectiveCameras(device=self.device, R=R, T=T)
            image = renderer(mesh, cameras=cameras)
            images.append(image[0, ..., :3].cpu().numpy())

        return images

In [None]:
# 4. Implement error handling and validation
def validate_input(image_paths):
    valid_paths = []
    for path in image_paths:
        if not os.path.exists(path):
            print(f"Warning: Image file not found: {path}")
        elif not path.lower().endswith(('.png', '.jpg', '.jpeg')):
            print(f"Warning: Unsupported file format: {path}")
        else:
            valid_paths.append(path)
    return valid_paths

In [None]:
# 5. Optimize for performance
@torch.no_grad()
def reconstruct_face(model, image):
    model.eval()
    vertices, colors = model(image.unsqueeze(0))
    return vertices.squeeze(0), colors.squeeze(0)

In [None]:

# Main pipeline
def face_reconstruction_pipeline(image_paths, model_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Validate input
    valid_image_paths = validate_input(image_paths)
    if not valid_image_paths:
        raise ValueError("No valid image files provided.")

    # Load and preprocess images
    dataloader = load_and_preprocess_images(valid_image_paths)

    # Load or train the model
    num_vertices = 10000  # Adjust based on your needs
    model = FaceReconstructionModel(num_vertices).to(device)

    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path))
        print("Loaded pre-trained model.")
    else:
        print("Pre-trained model not found. Training a new model...")
        # Here you would implement the training loop
        # For brevity, we'll skip the training implementation
        torch.save(model.state_dict(), model_path)
        print("Model trained and saved.")

    # Reconstruct 3D face
    all_vertices = []
    all_colors = []
    for batch in dataloader:
        image = batch.to(device)
        try:
            vertices, colors = reconstruct_face(model, image)
            all_vertices.append(vertices)
            all_colors.append(colors)
        except RuntimeError as e:
            print(f"Error during face reconstruction: {str(e)}")
            continue

    if not all_vertices:
        raise ValueError("Face reconstruction failed for all images.")

    # Average the results if multiple images were provided
    final_vertices = torch.stack(all_vertices).mean(dim=0)
    final_colors = torch.stack(all_colors).mean(dim=0)

    # Create 3D mesh
    vertices, faces, colors = create_face_mesh(final_vertices.cpu().numpy(), final_colors.cpu().numpy())

    # Render views
    renderer = FaceRenderer(device)
    rendered_images = renderer.render_views(torch.tensor(vertices, dtype=torch.float32),
                                            faces,
                                            torch.tensor(colors, dtype=torch.float32))

    return vertices, faces, colors, rendered_images

# Usage
if __name__ == "__main__":
    image_paths = ["path/to/image1.jpg", "path/to/image2.jpg", "path/to/image3.jpg"]
    model_path = "path/to/face_reconstruction_model.pth"

    try:
        vertices, faces, colors, rendered_images = face_reconstruction_pipeline(image_paths, model_path)

        # Save results
        save_obj("reconstructed_face.obj",
                 torch.tensor(vertices, dtype=torch.float32),
                 torch.tensor(faces, dtype=torch.int64),
                 verts_uvs=None,
                 faces_uvs=None,
                 texture_map=torch.tensor(colors, dtype=torch.float32))

        for i, img in enumerate(rendered_images):
            cv2.imwrite(f"rendered_view_{i}.png", cv2.cvtColor((img * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))

        print("Face reconstruction and rendering complete!")
    except Exception as e:
        print(f"An error occurred: {str(e)}")

Using device: cpu
An error occurred: No valid image files provided.
