# 🧬 Facial AI Platform: DECA + FLAME 2023 Reconstruction

This notebook reconstructs a 3D FLAME mesh and an albedo texture from your photos using **FLAME 2023** (the 2023 FLAME release, with a revised eye region and improved expressions).

**What you need:**
1. 1-3 face photos (front required; left 45° and right 45° optional)
2. FLAME 2023 model files (download from https://flame.is.tue.mpg.de/):
   - `generic_model.pkl`: core FLAME 2023 model (103 MB)
   - `FLAME_masks.pkl`: vertex masks (1.1 MB)
   - `FLAME_texture.npz`: texture space (1.2 GB, optional but recommended)
   - `mediapipe_landmark_embedding.npz`: MediaPipe landmark mapping (3.1 KB)

**What you get:**
- `face_mesh.obj`: 3D mesh in FLAME 2023 topology (5,023 vertices)
- `face_texture.png`: 1024x1024 albedo texture (upsampled from DECA's lower-resolution UV output)
- `face_normal.png`: UV-space normal map for fine skin detail such as wrinkles
- `face_displacement.png`: displacement map for geometric detail
- `face_params.json`: FLAME shape, expression, and pose parameters
- `web/` folder: pre-converted files for the browser app (auto-generated)

Upload these files to your Facial AI Platform web app at https://facial-ai-project.vercel.app
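For orientation: `face_params.json` stores the coefficients of FLAME's linear model. Leaving out the pose-dependent correctives and the skinning step, the unposed mesh is the template $\bar{T}$ (`v_template`, 5,023 x 3) plus weighted sums of the shape basis $\mathbf{S}$ and the expression basis $\mathbf{E}$, with shape parameters $\beta$ and expression parameters $\psi$:

$$
T(\beta, \psi) = \bar{T} + \sum_{n=1}^{|\beta|} \beta_n \mathbf{S}_n + \sum_{n=1}^{|\psi|} \psi_n \mathbf{E}_n
$$

The pose parameters (global rotation and jaw, as axis-angle vectors) are then applied to this mesh by linear blend skinning.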

## Step 1: Setup Environment & Check GPU

In [None]:
# Check GPU availability
!nvidia-smi
import torch
print(f'\nPyTorch: {torch.__version__}')
print(f'CUDA available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB')
else:
    print('⚠️  No GPU detected! Go to Runtime > Change runtime type > GPU')
    print('   DECA reconstruction requires a GPU.')

In [None]:
# Install DECA dependencies
!pip install -q torch torchvision
!pip install -q face-alignment opencv-python-headless scikit-image
!pip install -q pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt241/download.html
!pip install -q chumpy scipy

# Clone DECA
import os
if not os.path.exists('/content/DECA'):
    !git clone https://github.com/yfeng95/DECA.git /content/DECA
    %cd /content/DECA
    !pip install -q -r requirements.txt
else:
    %cd /content/DECA
    print('DECA already cloned')

print('✅ Dependencies installed')

In [None]:
# Download DECA pretrained model
import os
os.makedirs('data', exist_ok=True)

if not os.path.exists('data/deca_model.tar'):
    !gdown 1rp8kdyLPvErw2dTmqtjISRVvQLj6Yzje -O data/deca_model.tar
    print('✅ DECA pretrained model downloaded')
else:
    print('✅ DECA pretrained model already exists')

print('\n📋 Next step: Upload your FLAME 2023 model files')

## Step 2: Upload FLAME 2023 Model Files

Upload the files you downloaded from https://flame.is.tue.mpg.de/:

| File | Required | Description |
|------|----------|-------------|
| `generic_model.pkl` | ✅ Yes | Core FLAME 2023 model (103 MB) |
| `FLAME_masks.pkl` | ✅ Yes | Vertex region masks (1.1 MB) |
| `FLAME_texture.npz` | 🟡 Recommended | Texture space for realistic skin (1.2 GB) |
| `mediapipe_landmark_embedding.npz` | 🟡 Recommended | MediaPipe landmark mapping (3.1 KB) |
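If you want to check the data folder against this table before uploading anything, a small helper is enough. This is a hypothetical sketch (`check_flame_files` and the two lists are not part of DECA). It only compares exact file names, whereas the upload cell below also accepts names with hyphens in place of underscores.

```python
import os

# Hypothetical helper mirroring the "Required" column of the table above
REQUIRED_FLAME_FILES = ['generic_model.pkl', 'FLAME_masks.pkl']
RECOMMENDED_FLAME_FILES = ['FLAME_texture.npz', 'mediapipe_landmark_embedding.npz']


def check_flame_files(flame_dir):
    """Return (missing_required, missing_recommended) for a FLAME data directory."""
    present = set(os.listdir(flame_dir)) if os.path.isdir(flame_dir) else set()
    missing_required = [name for name in REQUIRED_FLAME_FILES if name not in present]
    missing_recommended = [name for name in RECOMMENDED_FLAME_FILES if name not in present]
    return missing_required, missing_recommended
```

For example, `check_flame_files('/content/DECA/data')` returns `([], [])` once everything is in place.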

In [None]:
from google.colab import files
import os
import shutil

# Create FLAME data directory
FLAME_DIR = '/content/DECA/data'
os.makedirs(FLAME_DIR, exist_ok=True)

# Define expected files and their destinations
FLAME_FILES = {
    'generic_model.pkl': {'dest': 'generic_model.pkl', 'required': True, 'desc': 'FLAME 2023 model'},
    'FLAME_masks.pkl': {'dest': 'FLAME_masks.pkl', 'required': True, 'desc': 'Vertex masks'},
    'FLAME_texture.npz': {'dest': 'FLAME_texture.npz', 'required': False, 'desc': 'Texture space'},
    'mediapipe_landmark_embedding.npz': {'dest': 'mediapipe_landmark_embedding.npz', 'required': False, 'desc': 'MediaPipe mapping'},
}

# Check which files already exist
missing_required = []
missing_optional = []
for filename, info in FLAME_FILES.items():
    dest_path = os.path.join(FLAME_DIR, info['dest'])
    if os.path.exists(dest_path):
        size = os.path.getsize(dest_path) / (1024 * 1024)
        print(f'  ✅ {info["desc"]}: {filename} ({size:.1f} MB)')
    elif info['required']:
        missing_required.append(filename)
    else:
        missing_optional.append(filename)

if missing_required or missing_optional:
    if missing_required:
        print(f'\n⚠️  Missing REQUIRED files: {", ".join(missing_required)}')
    if missing_optional:
        print(f'ℹ️  Missing optional files: {", ".join(missing_optional)}')

    print('\n📤 Please upload your FLAME files now:')
    uploaded = files.upload()

    for filename, data in uploaded.items():
        # Match uploaded file to known FLAME files
        matched = False
        for known_name, info in FLAME_FILES.items():
            if filename == known_name or filename.lower().replace('-', '_') == known_name.lower():
                dest_path = os.path.join(FLAME_DIR, info['dest'])
                with open(dest_path, 'wb') as f:
                    f.write(data)
                size = len(data) / (1024 * 1024)
                print(f'  ✅ Saved {info["desc"]}: {dest_path} ({size:.1f} MB)')
                matched = True
                break
        if not matched:
            # Save unrecognized files too (might be renamed versions)
            dest_path = os.path.join(FLAME_DIR, filename)
            with open(dest_path, 'wb') as f:
                f.write(data)
            print(f'  üìÅ Saved: {dest_path} ({len(data)/1024/1024:.1f} MB)')
else:
    print('\n‚úÖ All FLAME files already present!')

# Verify required files
flame_model_path = os.path.join(FLAME_DIR, 'generic_model.pkl')
if not os.path.exists(flame_model_path):
    print('\n❌ ERROR: generic_model.pkl is required! Please re-run this cell and upload it.')
else:
    # Quick validation
    import pickle
    try:
        with open(flame_model_path, 'rb') as f:
            flame_data = pickle.load(f, encoding='latin1')
        v_count = flame_data['v_template'].shape[0] if hasattr(flame_data['v_template'], 'shape') else len(flame_data['v_template'])
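        print(f'\n✅ FLAME 2023 model validated: {v_count} vertices')
        shape_dims = None
        if 'shapedirs' in flame_data:
            shape_dims = flame_data['shapedirs'].shape[-1] if hasattr(flame_data['shapedirs'], 'shape') else None
            print(f'   Blendshape components in shapedirs: {shape_dims if shape_dims is not None else "?"}')
        if 'exprdirs' in flame_data or 'expressionspace' in flame_data:
            print('   Expression parameters: separate basis available')
        elif shape_dims is not None and shape_dims > 300:
            # FLAME 2020-style layout: 300 shape components followed by expression components
            print(f'   Expression parameters: {shape_dims - 300} (stored in shapedirs after 300 shape components)')
    except Exception as e:
        print(f'⚠️  Could not validate FLAME model: {e}')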
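The FLAME 2020 `generic_model.pkl` has no separate expression key: its `shapedirs` array has shape `(5023, 3, 400)`, holding 300 identity-shape components followed by 100 expression components, and DECA's FLAME layer slices it that way. We assume FLAME 2023 keeps the same layout; check `shapedirs.shape` on your own file. A minimal sketch of the split (`split_shapedirs` is a hypothetical helper):

```python
import numpy as np


def split_shapedirs(shapedirs, n_shape_total=300):
    """Split a FLAME (V, 3, K) blendshape array into shape and expression bases.

    Assumes the first `n_shape_total` components are identity shape and the
    remaining ones are expression (the FLAME 2020 packing).
    """
    shapedirs = np.asarray(shapedirs)
    if shapedirs.ndim != 3 or shapedirs.shape[1] != 3:
        raise ValueError(f'expected an array of shape (V, 3, K), got {shapedirs.shape}')
    shape_basis = shapedirs[:, :, :n_shape_total]
    expr_basis = shapedirs[:, :, n_shape_total:]
    return shape_basis, expr_basis
```

Usage: `shape_basis, expr_basis = split_shapedirs(np.array(flame_data['shapedirs']))`. The `np.array` call also converts chumpy arrays.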

## Step 3: Upload Your Face Photos

Upload 1-3 photos:
- **Required:** Front-facing, neutral expression
- **Optional:** Left 45°, Right 45°

Tips:
- Diffuse, even lighting (no harsh shadows)
- Neutral expression, mouth closed
- Hair pulled back from face
- High resolution
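"High resolution" and "even lighting" are easy to sanity-check before running DECA, which crops the face and resizes it to 224x224 for its encoder. Here is a rough sketch. The 512-pixel minimum and the brightness limits are arbitrary assumptions, not DECA requirements, and `check_photo` is a hypothetical helper.

```python
import numpy as np
from PIL import Image


def check_photo(path, min_side=512, dark=60, bright=200):
    """Return a list of warnings for a face photo (thresholds are rough assumptions)."""
    with Image.open(path) as im:
        gray = im.convert('L')  # single-channel luminance for a brightness estimate
    warnings = []
    if min(gray.size) < min_side:
        warnings.append(f'low resolution: {gray.size[0]}x{gray.size[1]}')
    mean = float(np.asarray(gray, dtype=np.float32).mean())
    if mean < dark:
        warnings.append(f'image looks dark (mean {mean:.0f}/255)')
    elif mean > bright:
        warnings.append(f'image looks overexposed (mean {mean:.0f}/255)')
    return warnings
```

A mean-brightness check will not catch harsh side shadows, so still review the photo yourself.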

In [None]:
from google.colab import files
import shutil

# Create input directory
INPUT_DIR = '/content/face_input'
os.makedirs(INPUT_DIR, exist_ok=True)

print('üì∏ Upload your face photos (1-3 images):')
uploaded = files.upload()

for filename, data in uploaded.items():
    dest = os.path.join(INPUT_DIR, filename)
    with open(dest, 'wb') as f:
        f.write(data)
    print(f'  ✅ Saved: {filename} ({len(data)/1024:.0f} KB)')

print(f'\n📁 {len(uploaded)} photo(s) uploaded to {INPUT_DIR}')

## Step 4: Run DECA Reconstruction with FLAME 2023

In [None]:
import sys
sys.path.insert(0, '/content/DECA')

import cv2
import numpy as np
from PIL import Image
import json
import pickle

# Run DECA reconstruction
OUTPUT_DIR = '/content/face_output'
os.makedirs(OUTPUT_DIR, exist_ok=True)

from decalib.deca import DECA
from decalib.utils.config import cfg as deca_cfg
from decalib.datasets import datasets

# Initialize DECA with texture generation enabled
deca_cfg.model.use_tex = True
deca_cfg.rasterizer_type = 'pytorch3d'
deca = DECA(config=deca_cfg, device='cuda')

print('✅ DECA model loaded')
print(f'   FLAME vertices: {deca.flame.v_template.shape[0]}')

# Load vertex masks if available
MASKS_PATH = '/content/DECA/data/FLAME_masks.pkl'
vertex_masks = None
if os.path.exists(MASKS_PATH):
    with open(MASKS_PATH, 'rb') as f:
        vertex_masks = pickle.load(f, encoding='latin1')
    print(f'   Vertex masks: {list(vertex_masks.keys())}')

# Process each photo
input_files = sorted([f for f in os.listdir(INPUT_DIR)
                      if f.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))])
print(f'\nProcessing {len(input_files)} image(s)...')

all_params = []
all_codedicts = []

for i, filename in enumerate(input_files):
    img_path = os.path.join(INPUT_DIR, filename)
    print(f'\n--- Processing {filename} ({i+1}/{len(input_files)}) ---')

    # Load and preprocess
    testdata = datasets.TestData(img_path, iscrop=True, face_detector='fan', sample_step=1)
    if len(testdata) == 0:
        print(f'  ⚠️ No face detected in {filename}, skipping')
        continue

    images = testdata[0]['image'].unsqueeze(0).to('cuda')

    with torch.no_grad():
        codedict = deca.encode(images)
        opdict, visdict = deca.decode(codedict)

    all_codedicts.append(codedict)

    # Extract parameters
    params = {
        'shape': codedict['shape'].cpu().numpy().tolist()[0],
        'exp': codedict['exp'].cpu().numpy().tolist()[0],
        'pose': codedict['pose'].cpu().numpy().tolist()[0],
        'cam': codedict['cam'].cpu().numpy().tolist()[0],
        'light': codedict['light'].cpu().numpy().tolist()[0] if 'light' in codedict else None,
        'tex': codedict['tex'].cpu().numpy().tolist()[0] if 'tex' in codedict else None,
        'detail': codedict['detail'].cpu().numpy().tolist()[0] if 'detail' in codedict else None,
        'source_image': filename
    }
    all_params.append(params)

    # Get mesh info
    vertices = opdict['verts'].cpu().numpy()[0]
    faces = deca.flame.faces_tensor.cpu().numpy()

    print(f'  Mesh: {vertices.shape[0]} vertices, {faces.shape[0]} faces')
    print(f'  Shape params: {len(params["shape"])} dims')
    print(f'  Expression params: {len(params["exp"])} dims')
    print(f'  Pose params: {len(params["pose"])} dims (global rotation + jaw, axis-angle)')

    # Show key shape values
    shape_arr = np.array(params['shape'])
    top_indices = np.argsort(np.abs(shape_arr))[-5:][::-1]
    print(f'  Top shape components: {[(int(idx), f"{shape_arr[idx]:.2f}") for idx in top_indices]}')

print(f'\n✅ Reconstruction complete for {len(all_params)} image(s)')
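DECA's 6-D `pose` holds two axis-angle rotations in radians: the global head rotation first, then the jaw. This is the order in which DECA's FLAME layer concatenates them. To read them as matrices or angles, you can use Rodrigues' formula, as in the minimal sketch below. `axis_angle_to_matrix` and `describe_pose` are hypothetical helpers, and the global-then-jaw ordering is the stated assumption.

```python
import numpy as np


def axis_angle_to_matrix(rvec):
    """Rodrigues' formula: axis-angle vector (radians) -> 3x3 rotation matrix."""
    rvec = np.asarray(rvec, dtype=np.float64)
    theta = np.linalg.norm(rvec)
    if theta < 1e-8:
        return np.eye(3)  # no rotation
    k = rvec / theta  # unit rotation axis
    K = np.array([[0.0, -k[2], k[1]],
                  [k[2], 0.0, -k[0]],
                  [-k[1], k[0], 0.0]])  # cross-product matrix of k
    return np.eye(3) + np.sin(theta) * K + (1.0 - np.cos(theta)) * (K @ K)


def describe_pose(pose6):
    """Split a 6-D DECA pose into global and jaw rotation matrices plus the jaw angle in degrees."""
    pose6 = np.asarray(pose6, dtype=np.float64)
    global_R = axis_angle_to_matrix(pose6[:3])
    jaw_R = axis_angle_to_matrix(pose6[3:6])
    jaw_angle_deg = float(np.degrees(np.linalg.norm(pose6[3:6])))
    return global_R, jaw_R, jaw_angle_deg
```

For example, `describe_pose(all_params[0]['pose'])[2]` gives the jaw rotation angle. For the requested neutral, closed-mouth photos it should be small.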

## Step 5: Export Mesh, Textures & Parameters

In [None]:
# =====================================================================
# EXPORT: Mesh + Textures + Parameters + Web-ready files
# =====================================================================

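# Primary = first successfully reconstructed image. Files are processed in
# alphabetical order, so name the front photo so that it sorts first.
primary_idx = 0
primary_file = all_params[primary_idx]['source_image']

# Re-run decode for the primary image to get full mesh data
testdata = datasets.TestData(
    os.path.join(INPUT_DIR, primary_file),
    iscrop=True, face_detector='fan', sample_step=1
)
images = testdata[0]['image'].unsqueeze(0).to('cuda')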

with torch.no_grad():
    codedict = deca.encode(images)
    opdict, visdict = deca.decode(codedict)

vertices = opdict['verts'].cpu().numpy()[0]
faces = deca.flame.faces_tensor.cpu().numpy()

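# Get UV coordinates. DECA's FLAME layer has no UVs; they are stored on the renderer
# (loaded from head_template.obj). Fall back to no UVs if the attributes are missing.
try:
    uvs = deca.render.raw_uvcoords[0].cpu().numpy()   # (num_uvs, 2), values in [0, 1]
    uv_faces = deca.render.uvfaces[0].cpu().numpy()   # (num_faces, 3) UV indices per triangle
except AttributeError:
    uvs = None
    uv_faces = None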

# =====================================================================
# 1. Export OBJ mesh
# =====================================================================
obj_path = os.path.join(OUTPUT_DIR, 'face_mesh.obj')
mtl_path = os.path.join(OUTPUT_DIR, 'face_mesh.mtl')

with open(obj_path, 'w') as f:
    f.write('# DECA + FLAME 2023 Reconstruction\n')
    f.write(f'# Vertices: {vertices.shape[0]}\n')
    f.write(f'# Faces: {faces.shape[0]}\n')
    f.write(f'# Generated by Facial AI Platform\n')
    f.write(f'mtllib face_mesh.mtl\n')
    f.write(f'usemtl face_material\n\n')

    for v in vertices:
        f.write(f'v {v[0]:.6f} {v[1]:.6f} {v[2]:.6f}\n')

    if uvs is not None:
        for uv in uvs:
            f.write(f'vt {uv[0]:.6f} {uv[1]:.6f}\n')

    # Compute and write vertex normals
    from pytorch3d.structures import Meshes
    mesh_p3d = Meshes(verts=[torch.tensor(vertices).float()],
                      faces=[torch.tensor(faces).long()])
    vnormals = mesh_p3d.verts_normals_packed().numpy()
    for vn in vnormals:
        f.write(f'vn {vn[0]:.6f} {vn[1]:.6f} {vn[2]:.6f}\n')

    for fi, face in enumerate(faces):
        if uvs is not None and uv_faces is not None and fi < len(uv_faces):
            uv_face = uv_faces[fi]
            f.write(f'f {face[0]+1}/{uv_face[0]+1}/{face[0]+1} '
                    f'{face[1]+1}/{uv_face[1]+1}/{face[1]+1} '
                    f'{face[2]+1}/{uv_face[2]+1}/{face[2]+1}\n')
        else:
            f.write(f'f {face[0]+1}//{face[0]+1} {face[1]+1}//{face[1]+1} {face[2]+1}//{face[2]+1}\n')

with open(mtl_path, 'w') as f:
    f.write('newmtl face_material\n')
    f.write('Ka 0.2 0.2 0.2\n')
    f.write('Kd 0.8 0.8 0.8\n')
    f.write('Ks 0.1 0.1 0.1\n')
    f.write('Ns 20.0\n')
    f.write('map_Kd face_texture.png\n')
    f.write('bump face_normal.png\n')
    f.write('disp face_displacement.png\n')

print(f'✅ Mesh: {obj_path} ({vertices.shape[0]} verts, {faces.shape[0]} faces)')

# =====================================================================
# 2. Export albedo texture (1024x1024)
# =====================================================================
tex_path = os.path.join(OUTPUT_DIR, 'face_texture.png')
TEX_SIZE = 1024

# DECA's decode() puts the UV maps in opdict; check visdict too in case a version differs
outputs = {**visdict, **opdict}
texture = None

if outputs.get('uv_texture_gt') is not None:
    # Photo-projected UV texture, RGB in [0, 1]
    texture = outputs['uv_texture_gt'][0].cpu().numpy().transpose(1, 2, 0)
    source = 'photo-projected UV texture'
elif outputs.get('albedo') is not None:
    texture = outputs['albedo'][0].cpu().numpy().transpose(1, 2, 0)
    source = 'albedo'
else:
    # Generate texture from FLAME texture space if available
    TEXTURE_SPACE_PATH = '/content/DECA/data/FLAME_texture.npz'
    if os.path.exists(TEXTURE_SPACE_PATH) and 'tex' in codedict:
        print('  Generating texture from FLAME texture space...')
        tex_space = np.load(TEXTURE_SPACE_PATH)
        if 'mean' in tex_space and 'tex_dir' in tex_space:
            tex_params = codedict['tex'].cpu().numpy()[0]
            mean_tex = tex_space['mean'].reshape(-1)                       # (H*W*3,)
            tex_dir = tex_space['tex_dir'].reshape(mean_tex.shape[0], -1)  # (H*W*3, K)
            n_components = min(len(tex_params), tex_dir.shape[1])
            texture_flat = mean_tex + tex_dir[:, :n_components] @ tex_params[:n_components]
            side = int(round(np.sqrt(texture_flat.size / 3)))  # native size of the space, not TEX_SIZE
            # Values are on a 0-255 scale; channels assumed BGR (DECA's FLAMETex swaps them)
            texture = texture_flat.reshape(side, side, 3)[:, :, ::-1] / 255.0
            source = 'FLAME texture space'
        else:
            print('  ⚠️ Unexpected texture space format, using photo fallback')

if texture is None:
    # Fallback: resized input photo (not UV-aligned, preview only)
    src_img = cv2.imread(os.path.join(INPUT_DIR, all_params[primary_idx]['source_image']))
    texture = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB) / 255.0
    source = 'input photo (fallback)'

texture = (np.clip(texture, 0, 1) * 255).astype(np.uint8)
texture = cv2.resize(texture, (TEX_SIZE, TEX_SIZE), interpolation=cv2.INTER_LANCZOS4)
Image.fromarray(texture).save(tex_path)
print(f'✅ Texture ({source}): {tex_path} ({TEX_SIZE}x{TEX_SIZE})')

# =====================================================================
# 3. Export normal map
# =====================================================================
normal_path = os.path.join(OUTPUT_DIR, 'face_normal.png')
# DECA's decode() puts the UV maps in opdict; check visdict too in case a version differs
outputs = {**visdict, **opdict}

if outputs.get('uv_detail_normals') is not None:
    # UV-space detail normals in [-1, 1], laid out like the OBJ's vt coordinates.
    # 'normal_images' is a screen-space render and would not line up with the UVs.
    normal_img = outputs['uv_detail_normals'][0].cpu().numpy()
    normal_img = ((normal_img.transpose(1, 2, 0) + 1) * 0.5 * 255).clip(0, 255).astype(np.uint8)
    normal_img = cv2.resize(normal_img, (TEX_SIZE, TEX_SIZE), interpolation=cv2.INTER_LANCZOS4)
    Image.fromarray(normal_img).save(normal_path)
    print(f'✅ Normal map: {normal_path}')
else:
    # Flat tangent-space normal map: (0, 0, 1) encoded as RGB (128, 128, 255)
    normal_img = np.full((TEX_SIZE, TEX_SIZE, 3), 128, dtype=np.uint8)
    normal_img[:, :, 2] = 255
    Image.fromarray(normal_img).save(normal_path)
    print(f'✅ Default normal map: {normal_path}')

# =====================================================================
# 4. Export displacement map (if detail code available)
# =====================================================================
disp_path = os.path.join(OUTPUT_DIR, 'face_displacement.png')
# DECA's decode() puts the UV maps in opdict; check visdict too in case a version differs
outputs = {**visdict, **opdict}

if outputs.get('displacement_map') is not None:
    disp = outputs['displacement_map'][0, 0].cpu().numpy()
    # Min-max normalize to 8 bits (the absolute displacement scale is not preserved)
    disp_normalized = ((disp - disp.min()) / (disp.max() - disp.min() + 1e-8) * 255).astype(np.uint8)
    disp_normalized = cv2.resize(disp_normalized, (TEX_SIZE, TEX_SIZE), interpolation=cv2.INTER_LANCZOS4)
    Image.fromarray(disp_normalized).save(disp_path)
    print(f'✅ Displacement map: {disp_path}')
else:
    # Flat displacement: mid-grey means no offset
    Image.fromarray(np.full((TEX_SIZE, TEX_SIZE), 128, dtype=np.uint8)).save(disp_path)
    print(f'✅ Flat displacement map: {disp_path}')

# =====================================================================
# 5. Export FLAME parameters JSON
# =====================================================================
params_path = os.path.join(OUTPUT_DIR, 'face_params.json')

export_data = {
    'flame_version': 'FLAME 2023',
    'reconstruction_method': 'DECA',
    'vertex_count': int(vertices.shape[0]),
    'face_count': int(faces.shape[0]),
    'shape_params': all_params[primary_idx]['shape'],
    'expression_params': all_params[primary_idx]['exp'],
    'pose_params': all_params[primary_idx]['pose'],
    'camera_params': all_params[primary_idx]['cam'],
    'lighting_params': all_params[primary_idx]['light'],
    'texture_params': all_params[primary_idx]['tex'],
    'detail_params': all_params[primary_idx].get('detail'),
    'source_images': [p['source_image'] for p in all_params],
    'all_reconstructions': [{
        'source': p['source_image'],
        'shape': p['shape'],
        'exp': p['exp'],
        'pose': p['pose'],
    } for p in all_params]
}

with open(params_path, 'w') as f:
    json.dump(export_data, f, indent=2)

print(f'✅ Parameters: {params_path}')

# =====================================================================
# Summary
# =====================================================================
print(f'\n{"="*60}')
print(f'üìÅ All exports in {OUTPUT_DIR}:')
for f_name in sorted(os.listdir(OUTPUT_DIR)):
    size = os.path.getsize(os.path.join(OUTPUT_DIR, f_name))
    unit = 'KB' if size < 1024*1024 else 'MB'
    size_val = size/1024 if unit == 'KB' else size/(1024*1024)
    print(f'  {f_name:30s} {size_val:8.1f} {unit}')
print(f'{"="*60}')
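Each `f` line in the OBJ export uses 1-based indices. Corners are written as `v/vt/vn` when UVs exist and as `v//vn` otherwise, and because one normal is written per vertex, the normal index equals the vertex index. Here is the same formatting pulled into a small hypothetical helper that is easy to unit-test:

```python
def obj_face_line(face, uv_face=None):
    """Format one triangle as an OBJ 'f' record using 1-based indices.

    face: three 0-based vertex indices; uv_face: three 0-based UV indices or None.
    Assumes one normal per vertex, so the vn index equals the v index.
    """
    if uv_face is None:
        corners = [f'{v + 1}//{v + 1}' for v in face]
    else:
        corners = [f'{v + 1}/{t + 1}/{v + 1}' for v, t in zip(face, uv_face)]
    return 'f ' + ' '.join(corners)
```

For example, `obj_face_line([0, 1, 2], [5, 6, 7])` yields `f 1/6/1 2/7/2 3/8/3`.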

## Step 5.5: Generate Web-Ready Files

Converts FLAME data to JSON + binary format that the browser app can load directly via `FlameMeshGenerator.loadFLAME()`.
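To sanity-check the layout, you can read the binaries back with NumPy and rebuild an unposed mesh from shape coefficients. This sketch assumes the layout the cell below writes: a float32 template of shape (V, 3), and a shape basis flattened in row-major order from (V, 3, N), with V and N taken from `flame_template.json`. `load_flame_web` and `apply_shape` are hypothetical helpers, not part of the browser app.

```python
import json
import os

import numpy as np


def load_flame_web(web_dir):
    """Load the exported metadata, template vertices, and shape basis."""
    with open(os.path.join(web_dir, 'flame_template.json')) as f:
        meta = json.load(f)
    V, N = meta['vertexCount'], meta['shapeComponents']
    template = np.fromfile(os.path.join(web_dir, meta['files']['vertices']),
                           dtype=np.float32).reshape(V, 3)
    basis = np.fromfile(os.path.join(web_dir, meta['files']['shapeBasis']),
                        dtype=np.float32).reshape(V, 3, N)
    return meta, template, basis


def apply_shape(template, basis, betas):
    """Linear blendshape: template + sum_n betas[n] * basis[:, :, n] (no expression, no pose)."""
    betas = np.asarray(betas, dtype=np.float32)
    n = min(len(betas), basis.shape[2])  # ignore coefficients beyond the exported basis
    return template + basis[:, :, :n] @ betas[:n]
```

Usage: `meta, T, B = load_flame_web(WEB_DIR)`, then `apply_shape(T, B, meta['reconstruction']['shapeParams'])`.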

In [None]:
# =====================================================================
# GENERATE WEB-READY FILES for the browser app
# =====================================================================

import struct
import json
import numpy as np
import pickle

WEB_DIR = os.path.join(OUTPUT_DIR, 'web')
os.makedirs(WEB_DIR, exist_ok=True)

# Load FLAME model data
FLAME_PATH = '/content/DECA/data/generic_model.pkl'
with open(FLAME_PATH, 'rb') as f:
    flame = pickle.load(f, encoding='latin1')

# Convert chumpy arrays to numpy
def to_np(x):
    return np.array(x) if hasattr(x, '__array__') else x

v_template = to_np(flame['v_template']).astype(np.float32)
shapedirs = to_np(flame['shapedirs']).astype(np.float32)
faces_arr = to_np(flame['f']).astype(np.int32)

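# Expression basis: some files use a separate key; FLAME 2020-style files pack the
# expression components into shapedirs after the first 300 shape components
exprdirs = None
for key in ['exprdirs', 'expressionspace', 'expression_dirs']:
    if key in flame:
        exprdirs = to_np(flame[key]).astype(np.float32)
        break
if exprdirs is None and shapedirs.shape[2] > 300:
    exprdirs = shapedirs[:, :, 300:]
    shapedirs = shapedirs[:, :, :300]

V = v_template.shape[0]  # 5023
F = faces_arr.shape[0]    # 9976
# Export as many shape components as DECA estimated (100 by default) so the params match the basis
SHAPE_COMPONENTS = min(len(all_params[primary_idx]['shape']), shapedirs.shape[2])
EXPR_COMPONENTS = min(50, exprdirs.shape[2]) if exprdirs is not None else 0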

print(f'FLAME model loaded:')
print(f'  Vertices: {V}')
print(f'  Faces: {F}')
print(f'  Shape components: {shapedirs.shape[2]} (exporting {SHAPE_COMPONENTS})')
print(f'  Expression components: {EXPR_COMPONENTS}')

# --- 1. Template vertices (Float32) ---
v_template.tofile(os.path.join(WEB_DIR, 'flame_template_vertices.bin'))

# --- 2. Shape basis (Float32): reshape to (V*3, N) then flatten ---
shape_basis = shapedirs[:, :, :SHAPE_COMPONENTS].reshape(-1).astype(np.float32)
shape_basis.tofile(os.path.join(WEB_DIR, 'flame_shape_basis.bin'))

# --- 3. Expression basis (Float32) ---
if exprdirs is not None:
    expr_basis = exprdirs[:, :, :EXPR_COMPONENTS].reshape(-1).astype(np.float32)
    expr_basis.tofile(os.path.join(WEB_DIR, 'flame_expression_basis.bin'))

# --- 4. Face indices (Uint32) ---
faces_arr.astype(np.uint32).tofile(os.path.join(WEB_DIR, 'flame_faces.bin'))

# --- 5. UV coordinates ---
if uvs is not None:
    uvs.astype(np.float32).tofile(os.path.join(WEB_DIR, 'flame_uv.bin'))

# --- 6. Region mapping (up to 61 named clinical regions: vertex masks + position-based subdivision) ---
region_map = {}

# Load vertex masks
MASKS_PATH = '/content/DECA/data/FLAME_masks.pkl'
if os.path.exists(MASKS_PATH):
    with open(MASKS_PATH, 'rb') as f:
        masks = pickle.load(f, encoding='latin1')

    # FLAME masks typically contain: face, left_eyeball, right_eyeball, nose,
    # right_eye_region, forehead, lips, right_ear, left_ear, left_eye_region,
    # neck, scalp, boundary
    flame_masks = {}
    for key, val in masks.items():
        flame_masks[key] = to_np(val).astype(int).tolist()
        print(f'  Mask "{key}": {len(flame_masks[key])} vertices')

    # Helper: filter vertices by position
    def filter_by_pos(indices, x_min=-999, x_max=999, y_min=-999, y_max=999, z_min=-999, z_max=999):
        result = []
        for idx in indices:
            x, y, z = v_template[idx]
            if x_min <= x <= x_max and y_min <= y <= y_max and z_min <= z <= z_max:
                result.append(int(idx))
        return result

    # Compute position statistics for subdivision thresholds
    all_face_indices = flame_masks.get('face', list(range(V)))
    face_verts = v_template[all_face_indices]
    y_min_f, y_max_f = face_verts[:, 1].min(), face_verts[:, 1].max()
    y_mid = (y_min_f + y_max_f) / 2
    x_center = face_verts[:, 0].mean()

    # full_face = all face mask vertices
    region_map['full_face'] = flame_masks.get('face', list(range(V)))

    # --- Forehead subdivisions ---
    fh = flame_masks.get('forehead', [])
    region_map['forehead'] = fh
    region_map['forehead_left'] = filter_by_pos(fh, x_min=0.005)
    region_map['forehead_right'] = filter_by_pos(fh, x_max=-0.005)
    region_map['forehead_center'] = filter_by_pos(fh, x_min=-0.015, x_max=0.015)

    # Brows: top portion of eye regions
    left_eye = flame_masks.get('left_eye_region', [])
    right_eye = flame_masks.get('right_eye_region', [])

    if left_eye:
        le_verts = v_template[left_eye]
        le_y_mid = le_verts[:, 1].mean()
        region_map['brow_left'] = filter_by_pos(left_eye, y_min=le_y_mid)
        region_map['brow_inner_left'] = filter_by_pos(region_map['brow_left'], x_max=0.02)
        region_map['eye_left_upper'] = filter_by_pos(left_eye, y_min=le_y_mid - 0.003, y_max=le_y_mid + 0.01)
        region_map['eye_left_lower'] = filter_by_pos(left_eye, y_max=le_y_mid - 0.003)
        region_map['eye_left_corner_inner'] = filter_by_pos(left_eye, x_max=0.01, y_min=le_y_mid-0.005, y_max=le_y_mid+0.005)
        region_map['eye_left_corner_outer'] = filter_by_pos(left_eye, x_min=0.03, y_min=le_y_mid-0.005, y_max=le_y_mid+0.005)
        region_map['under_eye_left'] = filter_by_pos(left_eye, y_max=le_y_mid - 0.008)
        region_map['tear_trough_left'] = filter_by_pos(region_map.get('under_eye_left', []), x_max=0.02)

    if right_eye:
        re_verts = v_template[right_eye]
        re_y_mid = re_verts[:, 1].mean()
        region_map['brow_right'] = filter_by_pos(right_eye, y_min=re_y_mid)
        region_map['brow_inner_right'] = filter_by_pos(region_map['brow_right'], x_min=-0.02)
        region_map['eye_right_upper'] = filter_by_pos(right_eye, y_min=re_y_mid - 0.003, y_max=re_y_mid + 0.01)
        region_map['eye_right_lower'] = filter_by_pos(right_eye, y_max=re_y_mid - 0.003)
        region_map['eye_right_corner_inner'] = filter_by_pos(right_eye, x_min=-0.01, y_min=re_y_mid-0.005, y_max=re_y_mid+0.005)
        region_map['eye_right_corner_outer'] = filter_by_pos(right_eye, x_max=-0.03, y_min=re_y_mid-0.005, y_max=re_y_mid+0.005)
        region_map['under_eye_right'] = filter_by_pos(right_eye, y_max=re_y_mid - 0.008)
        region_map['tear_trough_right'] = filter_by_pos(region_map.get('under_eye_right', []), x_min=-0.02)

    # --- Nose subdivisions ---
    nose = flame_masks.get('nose', [])
    region_map['nose_bridge'] = filter_by_pos(nose, x_min=-0.01, x_max=0.01, y_min=0)
    region_map['nose_bridge_upper'] = filter_by_pos(region_map['nose_bridge'], y_min=0.01)
    region_map['nose_bridge_lower'] = filter_by_pos(region_map['nose_bridge'], y_max=0.01)
    region_map['nose_tip'] = filter_by_pos(nose, y_max=0.005, z_min=0.03)
    region_map['nose_tip_left'] = filter_by_pos(region_map['nose_tip'], x_min=0.002)
    region_map['nose_tip_right'] = filter_by_pos(region_map['nose_tip'], x_max=-0.002)
    region_map['nostril_left'] = filter_by_pos(nose, x_min=0.01, y_max=0)
    region_map['nostril_right'] = filter_by_pos(nose, x_max=-0.01, y_max=0)
    region_map['nose_dorsum'] = filter_by_pos(nose, x_min=-0.012, x_max=0.012)

    # --- Lips subdivisions ---
    lips = flame_masks.get('lips', [])
    if lips:
        lip_verts = v_template[lips]
        lip_y_mid = lip_verts[:, 1].mean()
        region_map['lip_upper'] = filter_by_pos(lips, y_min=lip_y_mid)
        region_map['lip_upper_left'] = filter_by_pos(region_map['lip_upper'], x_min=0.005)
        region_map['lip_upper_right'] = filter_by_pos(region_map['lip_upper'], x_max=-0.005)
        region_map['lip_upper_center'] = filter_by_pos(region_map['lip_upper'], x_min=-0.008, x_max=0.008)
        region_map['lip_lower'] = filter_by_pos(lips, y_max=lip_y_mid)
        region_map['lip_lower_left'] = filter_by_pos(region_map['lip_lower'], x_min=0.005)
        region_map['lip_lower_right'] = filter_by_pos(region_map['lip_lower'], x_max=-0.005)
        region_map['lip_lower_center'] = filter_by_pos(region_map['lip_lower'], x_min=-0.008, x_max=0.008)
        region_map['lip_corner_left'] = filter_by_pos(lips, x_min=0.02, y_min=lip_y_mid-0.003, y_max=lip_y_mid+0.003)
        region_map['lip_corner_right'] = filter_by_pos(lips, x_max=-0.02, y_min=lip_y_mid-0.003, y_max=lip_y_mid+0.003)

    # --- Cheeks, Jaw, Chin (from face mask, excluding other regions) ---
    specific_indices = set()
    for key in ['forehead', 'left_eye_region', 'right_eye_region', 'nose', 'lips',
                'left_ear', 'right_ear', 'left_eyeball', 'right_eyeball', 'neck', 'scalp']:
        if key in flame_masks:
            specific_indices.update(flame_masks[key])

    remaining_face = [i for i in all_face_indices if i not in specific_indices]

    # Cheeks: lateral, mid-face height
    region_map['cheek_left'] = filter_by_pos(remaining_face, x_min=0.02, y_min=-0.03, y_max=0.02)
    region_map['cheek_right'] = filter_by_pos(remaining_face, x_max=-0.02, y_min=-0.03, y_max=0.02)
    region_map['cheekbone_left'] = filter_by_pos(region_map['cheek_left'], y_min=0)
    region_map['cheekbone_right'] = filter_by_pos(region_map['cheek_right'], y_min=0)
    region_map['cheek_hollow_left'] = filter_by_pos(region_map['cheek_left'], y_max=0)
    region_map['cheek_hollow_right'] = filter_by_pos(region_map['cheek_right'], y_max=0)

    # Nasolabial
    region_map['nasolabial_left'] = filter_by_pos(remaining_face, x_min=0.01, x_max=0.025, y_min=-0.04, y_max=0)
    region_map['nasolabial_right'] = filter_by_pos(remaining_face, x_max=-0.01, x_min=-0.025, y_min=-0.04, y_max=0)

    # Chin
    chin_verts = filter_by_pos(remaining_face, y_max=-0.04, x_min=-0.03, x_max=0.03)
    region_map['chin'] = chin_verts
    region_map['chin_center'] = filter_by_pos(chin_verts, x_min=-0.01, x_max=0.01)
    region_map['chin_left'] = filter_by_pos(chin_verts, x_min=0.01)
    region_map['chin_right'] = filter_by_pos(chin_verts, x_max=-0.01)

    # Jaw
    jaw_verts = filter_by_pos(remaining_face, y_max=-0.02, x_min=0.03)
    region_map['jaw_left'] = jaw_verts
    region_map['jawline_left'] = filter_by_pos(jaw_verts, y_max=-0.04)
    jaw_verts_r = filter_by_pos(remaining_face, y_max=-0.02, x_max=-0.03)
    region_map['jaw_right'] = jaw_verts_r
    region_map['jawline_right'] = filter_by_pos(jaw_verts_r, y_max=-0.04)

    # Temples: above cheeks, lateral to forehead
    region_map['temple_left'] = filter_by_pos(remaining_face, x_min=0.04, y_min=0.01)
    region_map['temple_right'] = filter_by_pos(remaining_face, x_max=-0.04, y_min=0.01)

    # Ears
    region_map['ear_left'] = flame_masks.get('left_ear', [])
    region_map['ear_right'] = flame_masks.get('right_ear', [])

    # Neck
    region_map['neck'] = flame_masks.get('neck', [])

else:
    print('⚠️  No vertex masks file found; exporting only the full_face region')
    # Fallback: without masks there are no sub-regions, so every vertex goes into full_face
    region_map['full_face'] = list(range(V))

# Save regions JSON
with open(os.path.join(WEB_DIR, 'flame_regions.json'), 'w') as f:
    json.dump(region_map, f)

# --- 7. MediaPipe mapping ---
MP_PATH = '/content/DECA/data/mediapipe_landmark_embedding.npz'
if os.path.exists(MP_PATH):
    mp_data = np.load(MP_PATH, allow_pickle=True)
    mp_mapping = {}
    if 'lmk_face_idx' in mp_data:
        mp_mapping['lmk_face_idx'] = mp_data['lmk_face_idx'].tolist()
    if 'lmk_b_coords' in mp_data:
        mp_mapping['lmk_b_coords'] = mp_data['lmk_b_coords'].tolist()
    # Convert to closest vertex indices for simpler browser use
    if 'lmk_face_idx' in mp_data and 'lmk_b_coords' in mp_data:
        closest_vertices = []
        for face_idx, bary in zip(mp_data['lmk_face_idx'], mp_data['lmk_b_coords']):
            face = faces_arr[int(face_idx)]
            # Pick the vertex with highest barycentric weight
            max_bary_idx = np.argmax(bary)
            closest_vertices.append(int(face[max_bary_idx]))
        mp_mapping['closest_vertices'] = closest_vertices
        mp_mapping['landmark_count'] = len(closest_vertices)

    with open(os.path.join(WEB_DIR, 'flame_mediapipe_mapping.json'), 'w') as f:
        json.dump(mp_mapping, f)
    print(f'✅ MediaPipe mapping: {len(mp_mapping.get("closest_vertices", []))} landmarks')

# --- 8. Template metadata JSON ---
template_meta = {
    'vertexCount': int(V),
    'faceCount': int(F),
    'shapeComponents': int(SHAPE_COMPONENTS),
    'expressionComponents': int(EXPR_COMPONENTS),
    'flameVersion': 'FLAME 2023',
    'hasUV': uvs is not None,
    'hasMediaPipe': os.path.exists(MP_PATH),
    'hasTexureSpace': os.path.exists('/content/DECA/data/FLAME_texture.npz'),
    'regionCount': len(region_map),
    'files': {
        'vertices': 'flame_template_vertices.bin',
        'shapeBasis': 'flame_shape_basis.bin',
        'expressionBasis': 'flame_expression_basis.bin' if exprdirs is not None else None,
        'faces': 'flame_faces.bin',
        'uv': 'flame_uv.bin' if uvs is not None else None,
        'regions': 'flame_regions.json',
        'mediapipe': 'flame_mediapipe_mapping.json' if os.path.exists(MP_PATH) else None,
    },
    # Include reconstruction-specific params for this face
    'reconstruction': {
        'shapeParams': all_params[primary_idx]['shape'],
        'expressionParams': all_params[primary_idx]['exp'],
        'poseParams': all_params[primary_idx]['pose'],
    }
}

with open(os.path.join(WEB_DIR, 'flame_template.json'), 'w') as f:
    json.dump(template_meta, f, indent=2)

# Summary
print(f'\n{"="*60}')
print(f'üìÅ Web-ready files in {WEB_DIR}:')
total_size = 0
for f_name in sorted(os.listdir(WEB_DIR)):
    size = os.path.getsize(os.path.join(WEB_DIR, f_name))
    total_size += size
    unit = 'KB' if size < 1024*1024 else 'MB'
    size_val = size/1024 if unit == 'KB' else size/(1024*1024)
    print(f'  {f_name:40s} {size_val:8.1f} {unit}')
print(f'  {"Total:":40s} {total_size/1024/1024:8.1f} MB')
print(f'{"="*60}')
print(f'\n✅ {len(region_map)} clinical regions mapped to FLAME topology')
for name, indices in sorted(region_map.items()):
    if len(indices) > 0:  # works for both lists and NumPy index arrays
        print(f'  {name:30s}: {len(indices):5d} vertices')
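The `closest_vertices` list built above snaps each MediaPipe landmark to the corner of its FLAME triangle with the largest barycentric weight. That is simple for the browser, but it can place a landmark up to one triangle edge away from its true position. When exact positions are needed, they can be interpolated from the same `lmk_face_idx` / `lmk_b_coords` arrays. The NumPy sketch below assumes `vertices` is a (V, 3) array, `faces` a (F, 3) integer array, and the two embedding arrays have shapes (L,) and (L, 3); the function names are hypothetical and not part of DECA or FLAME.

```python
import numpy as np


def interpolate_landmarks(vertices, faces, lmk_face_idx, lmk_b_coords):
    """Exact landmark positions: barycentric blend of each landmark's triangle corners."""
    tri = faces[np.asarray(lmk_face_idx, dtype=np.int64)]              # (L, 3) vertex ids
    corners = vertices[tri]                                            # (L, 3, 3) corner xyz
    weights = np.asarray(lmk_b_coords, dtype=np.float64)[:, :, None]   # (L, 3, 1)
    return (weights * corners).sum(axis=1)                             # (L, 3) positions


def snap_landmarks(faces, lmk_face_idx, lmk_b_coords):
    """The cheaper approximation: keep the heaviest corner of each triangle."""
    tri = faces[np.asarray(lmk_face_idx, dtype=np.int64)]
    heaviest = np.argmax(np.asarray(lmk_b_coords), axis=1)
    return tri[np.arange(len(tri)), heaviest]


# Toy check: one triangle, one landmark weighted towards corner 0
toy_vertices = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
toy_faces = np.array([[0, 1, 2]])
exact = interpolate_landmarks(toy_vertices, toy_faces, [0], [[0.6, 0.3, 0.1]])
snapped = snap_landmarks(toy_faces, [0], [[0.6, 0.3, 0.1]])
```

In this notebook the call would be `interpolate_landmarks(vertices, faces_arr, mp_data['lmk_face_idx'], mp_data['lmk_b_coords'])`, which relies on the reconstructed mesh keeping FLAME's vertex order.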

## Step 6: Preview Reconstruction

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

fig = plt.figure(figsize=(20, 5))

# 1. Source photo
ax1 = fig.add_subplot(141)
src = Image.open(os.path.join(INPUT_DIR, input_files[primary_idx]))
ax1.imshow(src)
ax1.set_title('Source Photo', fontsize=12, fontweight='bold')
ax1.axis('off')

# 2. 3D mesh (front view)
ax2 = fig.add_subplot(142, projection='3d')
ax2.plot_trisurf(vertices[:, 0], vertices[:, 1], vertices[:, 2],
                 triangles=faces, color='#e8b89d', edgecolor='gray',
                 linewidth=0.05, alpha=0.9)
ax2.set_title('3D Mesh (Front)', fontsize=12, fontweight='bold')
ax2.view_init(elev=0, azim=0)
ax2.axis('off')
ax2.set_box_aspect([1, 1.2, 1])

# 3. Texture map
ax3 = fig.add_subplot(143)
tex = Image.open(tex_path)
ax3.imshow(tex)
ax3.set_title('Albedo Texture', fontsize=12, fontweight='bold')
ax3.axis('off')

# 4. 3D mesh with region coloring
ax4 = fig.add_subplot(144, projection='3d')
# Color vertices by region
vertex_colors = np.ones((V, 3)) * 0.85  # default gray

color_palette = {
    'forehead': [0.4, 0.6, 1.0],
    'nose': [1.0, 0.7, 0.3],
    'lips': [1.0, 0.3, 0.4],
    'cheek_left': [0.5, 0.9, 0.5],
    'cheek_right': [0.5, 0.9, 0.5],
    'chin': [0.8, 0.5, 0.9],
    'jaw_left': [0.9, 0.7, 0.5],
    'jaw_right': [0.9, 0.7, 0.5],
    'neck': [0.6, 0.6, 0.6],
}

for region_name, color in color_palette.items():
    if region_name in region_map:
        for idx in region_map[region_name]:
            if idx < V:
                vertex_colors[idx] = color

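# Plot faces with per-face colors (mean of each triangle's three vertex colors)
face_colors = np.mean(vertex_colors[faces], axis=1)
surf4 = ax4.plot_trisurf(vertices[:, 0], vertices[:, 1], vertices[:, 2],
                         triangles=faces, edgecolor='gray', linewidth=0.02, alpha=0.9)
surf4.set_facecolor(face_colors)  # plot_trisurf applies one color, so set per-face colors afterwards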
ax4.set_title(f'Region Map ({len(region_map)} zones)', fontsize=12, fontweight='bold')
ax4.view_init(elev=0, azim=0)
ax4.axis('off')
ax4.set_box_aspect([1, 1.2, 1])

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'preview.png'), dpi=150, bbox_inches='tight')
plt.show()

print(f'✅ Preview saved to {OUTPUT_DIR}/preview.png')
print(f'\nReconstruction stats:')
print(f'  Vertices: {vertices.shape[0]:,}')
print(f'  Faces: {faces.shape[0]:,}')
print(f'  Clinical regions: {len(region_map)}')
print(f'  Shape params: {len(all_params[primary_idx]["shape"])}')
print(f'  Expression params: {len(all_params[primary_idx]["exp"])}')
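The region preview colors each triangle by averaging its three vertex colors, so triangles that straddle two regions get a blended shade. The per-vertex loop can also be written with NumPy fancy indexing. A minimal sketch with hypothetical helper names, assuming `region_map` maps region names to vertex-index sequences and the palette maps names to RGB triples in [0, 1], as in the cell above:

```python
import numpy as np


def region_vertex_colors(num_vertices, region_map, palette, default=(0.85, 0.85, 0.85)):
    """Per-vertex RGB: default gray, overwritten by each palette region in order."""
    colors = np.tile(np.asarray(default, dtype=np.float64), (num_vertices, 1))
    for name, rgb in palette.items():
        idx = np.asarray(region_map.get(name, []), dtype=np.int64)
        idx = idx[(idx >= 0) & (idx < num_vertices)]  # ignore out-of-range indices
        colors[idx] = rgb
    return colors


def face_colors_from_vertices(vertex_colors, faces):
    """Per-triangle RGB as the mean of the triangle's three vertex colors."""
    return vertex_colors[np.asarray(faces)].mean(axis=1)


# Toy mesh: 4 vertices, 2 triangles; only vertices 0-2 belong to 'nose'
toy_faces = np.array([[0, 1, 2], [1, 2, 3]])
vc = region_vertex_colors(4, {'nose': [0, 1, 2]}, {'nose': [1.0, 0.7, 0.3]})
fc = face_colors_from_vertices(vc, toy_faces)
```

Taking the color of the majority vertex instead of the mean would keep region borders crisp; averaging is simpler and matches the preview cell.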

## Step 7: Download Results

Downloads a ZIP containing:
- **Traditional files**: OBJ mesh, textures, normal map, displacement map, parameters JSON
- **Web-ready files**: Binary + JSON files for direct browser loading via `FlameMeshGenerator.loadFLAME()`
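Before copying `web/` into the project, the binary files can be checked against the counts recorded in `flame_template.json`. The byte layout of the `.bin` files is set in an earlier export step that is not shown here, so the sketch below *assumes* `flame_template_vertices.bin` is a flat run of `vertexCount × 3` little-endian float32 values (what a browser `Float32Array` reads on little-endian hardware); change `dtype` if the export used something else. `load_template_vertices` is a hypothetical helper.

```python
import json
import os
import tempfile

import numpy as np


def load_template_vertices(web_dir, dtype='<f4'):
    """Read the vertex .bin named in flame_template.json as (N, 3) and check its size.

    Assumes a flat sequence of little-endian float32 x, y, z values.
    """
    with open(os.path.join(web_dir, 'flame_template.json')) as f:
        meta = json.load(f)
    path = os.path.join(web_dir, meta['files']['vertices'])
    flat = np.fromfile(path, dtype=dtype)
    expected = meta['vertexCount'] * 3
    if flat.size != expected:
        raise ValueError(f'{path}: {flat.size} values, expected {expected}')
    return flat.reshape(-1, 3)


# Usage with a throwaway directory standing in for WEB_DIR
demo_dir = tempfile.mkdtemp()
np.arange(12, dtype='<f4').tofile(os.path.join(demo_dir, 'flame_template_vertices.bin'))
with open(os.path.join(demo_dir, 'flame_template.json'), 'w') as f:
    json.dump({'vertexCount': 4, 'files': {'vertices': 'flame_template_vertices.bin'}}, f)
demo_vertices = load_template_vertices(demo_dir)
```

In the notebook this would be `load_template_vertices(WEB_DIR)`; a size mismatch usually means the dtype assumption does not match the export step.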

Upload the `web/` folder contents to your project's `public/models/flame/web/` directory, and the textures to `public/models/`.
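Once the ZIP is on your machine, the copy step can be scripted with the standard library. `install_reconstruction` is a hypothetical helper, `project_root` is your local checkout of facial-ai-project, and the texture names are assumed to match the `face_texture.png`, `face_normal.png` and `face_displacement.png` outputs listed at the top of this notebook. Other files such as the OBJ or `preview.png` are skipped, and only base names are used as destinations so archive paths cannot escape the target folders.

```python
import os
import shutil
import tempfile
import zipfile

# Assumed texture file names (from the notebook's output list); adjust if yours differ
TEXTURE_NAMES = {'face_texture.png', 'face_normal.png', 'face_displacement.png'}


def install_reconstruction(zip_path, project_root):
    """Copy web/* to public/models/flame/web/ and the textures to public/models/."""
    models_dir = os.path.join(project_root, 'public', 'models')
    web_dir = os.path.join(models_dir, 'flame', 'web')
    os.makedirs(web_dir, exist_ok=True)
    installed = []
    with zipfile.ZipFile(zip_path) as zf:
        for name in zf.namelist():
            base = os.path.basename(name)
            if not base:                 # directory entry
                continue
            if name.startswith('web/'):
                dest = os.path.join(web_dir, base)
            elif name in TEXTURE_NAMES:
                dest = os.path.join(models_dir, base)
            else:
                continue                 # OBJ, params JSON, preview, etc.
            with zf.open(name) as src, open(dest, 'wb') as dst:
                shutil.copyfileobj(src, dst)
            installed.append(dest)
    return installed


# Usage with a throwaway ZIP that mimics the archive built below
tmp = tempfile.mkdtemp()
demo_zip = os.path.join(tmp, 'demo.zip')
with zipfile.ZipFile(demo_zip, 'w') as zf:
    zf.writestr('face_texture.png', b'png')
    zf.writestr('preview.png', b'png')
    zf.writestr('web/flame_faces.bin', b'\x00' * 12)
demo_installed = install_reconstruction(demo_zip, os.path.join(tmp, 'project'))
```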

In [None]:
import zipfile
from google.colab import files  # Colab download helper used at the end of this cell

zip_path = '/content/facial_reconstruction_flame2023.zip'

with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
    # Add main output files
    for f_name in os.listdir(OUTPUT_DIR):
        f_path = os.path.join(OUTPUT_DIR, f_name)
        if os.path.isfile(f_path):
            zf.write(f_path, f_name)

    # Add web-ready files in web/ subdirectory
    web_dir = os.path.join(OUTPUT_DIR, 'web')
    if os.path.exists(web_dir):
        for f_name in os.listdir(web_dir):
            f_path = os.path.join(web_dir, f_name)
            if os.path.isfile(f_path):
                zf.write(f_path, f'web/{f_name}')

zip_size = os.path.getsize(zip_path) / (1024 * 1024)
print(f'üì¶ Created: facial_reconstruction_flame2023.zip ({zip_size:.1f} MB)')
print(f'\nContents:')
with zipfile.ZipFile(zip_path, 'r') as zf:
    for info in zf.infolist():
        size_kb = info.file_size / 1024
        unit = 'KB' if size_kb < 1024 else 'MB'
        size_val = size_kb if unit == 'KB' else size_kb / 1024
        print(f'  {info.filename:40s} {size_val:8.1f} {unit}')

print('\n⬇️  Downloading...')
files.download(zip_path)

print('\n' + '='*60)
print('✅ Done! Next steps:')
print('='*60)
print()
print('1. Unzip the downloaded file')
print('2. Copy web/ folder contents to your project:')
print('   → facial-ai-project/public/models/flame/web/')
print('3. Copy textures to:')
print('   → facial-ai-project/public/models/')
print('4. The app will auto-detect FLAME data via FlameMeshGenerator.loadFLAME()')
print()
print('üåê Your app: https://facial-ai-project.vercel.app')
print('üì¶ GitHub:   https://github.com/OfirVento/facial-ai-project')