In [None]:
import os, math, csv, time, cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch.nn.functional as F
from PIL import Image
import numpy as np
import pandas as pd
import scipy.io as sio
from collections import OrderedDict
import gc

# Run on the GPU when PyTorch can see one; later cells move models/tensors to `device`.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Release memory held by objects from earlier runs in this kernel before loading anything large.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# ============================================================================
# CONFIGURATION
# ============================================================================
# 300W-LP sub-folders and input files (Kaggle input mounts).
SELECTED_FOLDERS = ["IBUG", "IBUG_Flip"]
DATA_ROOT = "/kaggle/input/300w-lp/300W_LP"
TRIPLETS_CSV = "/kaggle/input/multi-view-input/triplets_shorter_IBUG.csv"
WEIGHTS_PATH = "/kaggle/input/multi-view-input/resnet50-11ad3fa6.pth"

# Optimisation settings.
BATCH_SIZE = 32   # larger batch for better GPU utilisation
NUM_EPOCHS = 2
CHECKPOINT_PATH = "resnet50_mvfnet.pth"

# 3DMM model files.
MODEL_SHAPE = "/kaggle/input/multi-view-input/Model_Shape.mat"
MODEL_EXP = "/kaggle/input/multi-view-input/Model_Expression.mat"
DATA_FROM_AUTHOR = "/kaggle/input/multi-view-input/sigma_exp.mat"

# Staged training schedule. Each stage gives its epoch count, learning rate
# and the weight of every loss term (shape / exp / pose / landmark).
TRAINING_SCHEDULE = {
    'stage1': dict(epochs=1, lr=5e-4, shape=1.0, exp=1.0, pose=1.0, landmark=0.0),
    'stage2': dict(epochs=1, lr=1e-4, shape=1.0, exp=1.0, pose=1.0, landmark=0.01),
}

In [None]:
# ============================================================================
# LOAD 3DMM MODELS - MOVE TO GPU IMMEDIATELY
# ============================================================================
model_shape = sio.loadmat(MODEL_SHAPE)
model_exp = sio.loadmat(MODEL_EXP)
data = sio.loadmat(DATA_FROM_AUTHOR)

# NumPy copies, used by the Dataset to normalise labels.
kpt_index = model_shape["keypoints"].flatten().astype(np.int32) - 1  # stored 1-based, shifted to 0-based
shape_std = model_shape["sigma"].flatten()       # shape std
exp_std = model_exp["sigma_exp"].flatten()       # expression std from Model_Expression.mat

# Pose normalisation for [pitch, yaw, roll, tx, ty, tz, scale];
# tx/ty are centred on 112, the centre of a 224 px input.
pose_mean = np.array([0, 0, 0, 112, 112, 0, 0]).astype(np.float32)
pose_std = np.array([
    math.pi / 2.0, math.pi / 2.0, math.pi / 2.0,
    56, 56, 1,
    224.0 / (2 * 180000.0),
]).astype(np.float32)


def _to_device_float(array):
    """Wrap a NumPy array as a float32 tensor on ``device``."""
    return torch.from_numpy(array).float().to(device)


# Tensor copies, moved to the device once; read by decode_params / project_landmarks_batched.
w_shape_t = _to_device_float(model_shape['w'])            # identity basis (A_id)
w_exp_t = _to_device_float(model_exp['w_exp'])            # expression basis (A_exp)
mu_shape_t = _to_device_float(model_shape['mu_shape'])    # mean shape (S_mu)
sigma_shape_t = _to_device_float(model_shape['sigma'])    # shape std
exp_std_t = _to_device_float(model_exp["sigma_exp"])      # expression std from Model_Expression.mat
# NOTE(review): expression std from sigma_exp.mat; decode_params de-normalises
# with exp_std_t instead and nothing visible reads this tensor — confirm which is intended.
sigma_exp_t = _to_device_float(data["sigma_exp"])
kpt_idx_t = torch.from_numpy(kpt_index).long().to(device)
pose_mean_t = _to_device_float(pose_mean)
pose_std_t = _to_device_float(pose_std)

print("3DMM models loaded and moved to GPU")

In [None]:
def convert_300wlp_to_bfm_pose(pose_300wlp, origin_w, origin_h, target_size=224):
    """
    Convert a 300W-LP ``Pose_Para`` vector to the resized (network-input) frame.

    Poses describe a weak-perspective projection ``p = s * R @ X + t`` in the
    pixel coordinates of the original image. Resizing the image to
    ``target_size`` multiplies every projected coordinate by the resize factor,
    so the 2D translation AND the scale must both be rescaled; rotations are
    unchanged.

    Args:
        pose_300wlp: (7,) array [pitch, yaw, roll, tx, ty, tz, scale]; angles in
            radians, tx/ty/scale in original-image pixel units.
        origin_w, origin_h: width and height of the original image.
        target_size: side length of the square network input (default 224).

    Returns:
        (7,) float32 array [pitch, yaw, roll, tx, ty, 0, scale] expressed in the
        ``target_size`` frame. tz is zeroed (weak perspective). For non-square
        originals a single scale cannot match both axes exactly; the geometric
        mean of the two resize factors is used.
    """
    pitch, yaw, roll = pose_300wlp[0:3]
    tx_orig, ty_orig = pose_300wlp[3:5]
    scale_orig = pose_300wlp[6]

    scale_factor_x = target_size / origin_w
    scale_factor_y = target_size / origin_h
    # The projection scale must follow the image resize, otherwise s stays in the
    # original resolution while tx/ty are already in the target frame.
    scale_factor_s = np.sqrt(scale_factor_x * scale_factor_y)

    return np.array([
        pitch, yaw, roll,
        tx_orig * scale_factor_x,
        ty_orig * scale_factor_y,
        0.0,                          # tz: dropped under weak perspective
        scale_orig * scale_factor_s,
    ], dtype=np.float32)

In [None]:
# ============================================================================
# LANDMARK LOADING UTILITIES
# ============================================================================
def load_landmarks_from_folder(data_root, folder, mat_filename, origin_w, origin_h, target_size=224):
    """
    Read ``pts_2d`` from the per-image file under ``<data_root>/landmarks/<folder>``.

    The landmark file name is the image's .mat name with ``.mat`` replaced by
    ``_pts.mat``, e.g.
    /kaggle/input/300w-lp/300W_LP/landmarks/LFPW/LFPW_image_test_0001_0_pts.mat

    Points are rescaled from the original image size to ``target_size``.

    Returns:
        (pts, True) with pts a float32 array on success; (None, False) when the
        file does not exist or cannot be read/scaled (a warning is printed in
        the latter case).
    """
    landmark_path = os.path.join(
        data_root, "landmarks", folder, mat_filename.replace('.mat', '_pts.mat')
    )
    if not os.path.exists(landmark_path):
        return None, False

    try:
        pts_2d = sio.loadmat(landmark_path)["pts_2d"].astype(np.float32)
        for axis, factor in enumerate((target_size / origin_w, target_size / origin_h)):
            pts_2d[:, axis] *= factor
    except Exception as e:
        print(f"Warning: Failed to load landmarks from {landmark_path}: {e}")
        return None, False
    return pts_2d, True

def load_landmarks_from_mat(mat_path, origin_w, origin_h, target_size=224):
    """
    Fallback landmark source: read ``pt2d`` from the image's own .mat file.

    ``pt2d`` is transposed so points are rows, then rescaled from the original
    image size to ``target_size``.

    Returns:
        (pts, True) with pts a float32 array on success; (None, False) after
        printing a warning if anything goes wrong.
    """
    try:
        pt2d = sio.loadmat(mat_path)["pt2d"].astype(np.float32).T
        for axis, factor in enumerate((target_size / origin_w, target_size / origin_h)):
            pt2d[:, axis] *= factor
    except Exception as e:
        print(f"Warning: Failed to load pt2d from {mat_path}: {e}")
        return None, False
    return pt2d, True

In [None]:
# ============================================================================
# DATASET
# ============================================================================
def get_folder_from_filename(filename, folders=None):
    """Return the 300W-LP sub-folder whose name is a prefix of ``filename``.

    Args:
        filename: image or .mat file name (no directory part).
        folders: candidate folder names; defaults to SELECTED_FOLDERS.

    The LONGEST matching prefix wins, so a name starting with "IBUG_Flip"
    resolves to "IBUG_Flip" even though "IBUG" is also a prefix of it and may
    be listed first.

    Raises:
        ValueError: if no candidate folder name is a prefix of ``filename``.
    """
    candidates = SELECTED_FOLDERS if folders is None else folders
    matches = [folder for folder in candidates if filename.startswith(folder)]
    if not matches:
        raise ValueError(f"Cannot find folder for {filename}")
    return max(matches, key=len)

class Triplet300wLP(Dataset):
    """Multi-view (front / left / right) 300W-LP dataset.

    Each CSV row names three image/.mat pairs through the columns
    ``front_img``/``front_mat``, ``left_img``/``left_mat`` and
    ``right_img``/``right_mat``. Only the front view's .mat file supplies the
    regression targets and 2D landmarks; the side views contribute images only.

    Items are ``(input_tensor, label_tensor, landmark_tensor)``:
        input_tensor    -- the three transformed views concatenated along dim 0.
        label_tensor    -- normalised [Shape_Para, Exp_Para, converted Pose_Para]
                           of the front view, as float32.
        landmark_tensor -- front-view landmarks rescaled to ``target_size``;
                           a (68, 2) zero array when none could be loaded.
    """

    def __init__(self, triplets_csv, dataset_path, transform=None, target_size=224, use_landmarks_folder=True):
        self.data = pd.read_csv(triplets_csv)
        self.dataset_path = dataset_path
        self.transform = transform
        self.target_size = target_size
        self.use_landmarks_folder = use_landmarks_folder

        # Label normalisation statistics, taken from the module-level 3DMM globals.
        self.shape_std, self.exp_std = shape_std, exp_std
        self.pose_mean, self.pose_std = pose_mean, pose_std

        # Per-image landmark files live under <dataset_path>/landmarks when available.
        self.landmarks_path = os.path.join(dataset_path, "landmarks")
        self.has_landmarks_folder = os.path.exists(self.landmarks_path)
        if self._reads_landmarks_folder():
            print(f"âœ“ Will use landmarks from: {self.landmarks_path}")
        else:
            print(f"âœ“ Will use pt2d from .mat files")
        print(f"Dataset initialized with {len(self.data)} triplets")

        # filename -> sub-folder for every .jpg/.mat in the selected folders.
        # NOTE(review): __getitem__ resolves folders with get_folder_from_filename
        # and never reads this map.
        self.folder_map = {}
        for folder in SELECTED_FOLDERS:
            folder_path = os.path.join(dataset_path, folder)
            if os.path.exists(folder_path):
                self.folder_map.update(
                    (name, folder)
                    for name in os.listdir(folder_path)
                    if name.endswith((".jpg", ".mat"))
                )

    def _reads_landmarks_folder(self):
        """Truthy when landmarks should be read from the separate landmarks folder."""
        return self.use_landmarks_folder and self.has_landmarks_folder

    def _front_label(self, mat_path, orig_w, orig_h):
        """Normalised [shape, exp, pose] vector read from the front view's .mat file."""
        mat_data = sio.loadmat(mat_path)
        pose = convert_300wlp_to_bfm_pose(
            mat_data["Pose_Para"].flatten(), orig_w, orig_h, self.target_size
        )
        return np.concatenate([
            mat_data["Shape_Para"].flatten() / self.shape_std,
            mat_data["Exp_Para"].flatten() / self.exp_std,
            (pose - self.pose_mean) / self.pose_std,
        ], axis=0)

    def _front_landmarks(self, folder, filename_mat, mat_path, orig_w, orig_h):
        """Front-view landmarks: landmarks folder first (if enabled), then the .mat's pt2d."""
        pts, ok = None, False
        if self._reads_landmarks_folder():
            pts, ok = load_landmarks_from_folder(
                self.dataset_path, folder, filename_mat, orig_w, orig_h, self.target_size
            )
        if not ok:
            pts, ok = load_landmarks_from_mat(mat_path, orig_w, orig_h, self.target_size)
        if ok:
            return pts
        print(f"Warning: Could not load landmarks for {filename_mat}")
        return np.zeros((68, 2), dtype=np.float32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        views = []
        label = landmarks = None

        for view in ("front", "left", "right"):
            filename_img = row[f'{view}_img']
            filename_mat = row[f'{view}_mat']
            folder = get_folder_from_filename(filename_img)

            image = Image.open(os.path.join(self.dataset_path, folder, filename_img)).convert("RGB")
            orig_w, orig_h = image.size
            views.append(self.transform(image) if self.transform else transforms.ToTensor()(image))

            if view == "front":
                mat_path = os.path.join(self.dataset_path, folder, filename_mat)
                label = self._front_label(mat_path, orig_w, orig_h)
                landmarks = self._front_landmarks(folder, filename_mat, mat_path, orig_w, orig_h)

        return (
            torch.cat(views, dim=0),
            torch.from_numpy(label).float(),
            torch.from_numpy(landmarks).float(),
        )

# Every view is resized to the 224x224 network input and normalised with the
# standard ImageNet channel statistics.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_imagenet_mean, std=_imagenet_std),
])

dataset = Triplet300wLP(TRIPLETS_CSV, DATA_ROOT, transform=transform)

# NOTE(review): the training/eval code in this notebook uses train_loader /
# test_loader; this full-dataset `loader` is not consumed there.
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    prefetch_factor=2,
)
print(f"Dataset created: {len(dataset)} samples")

In [None]:
# TRAIN/TEST SPLIT (80/20)
# The seeded generator gives the same partition on every run.
n_samples = len(dataset)
train_size = int(0.8 * n_samples)
test_size = n_samples - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42)
)

# Both loaders share every setting except shuffling (evaluation order stays fixed).
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=True, prefetch_factor=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print(f"Dataset created: {len(dataset)} total samples")
print(f"  Train: {len(train_dataset)} samples")
print(f"  Test:  {len(test_dataset)} samples")

In [None]:
# ============================================================================
# MODEL - OPTIMIZED FOR SPEED
# ============================================================================
class ResNetEncoder(nn.Module):
    """Multi-view ResNet-50 regressor for 3DMM parameters.

    The input stacks three RGB views (front, left, right) along the channel
    axis, i.e. ``x`` is (B, 9, H, W) as produced by ``Triplet300wLP``. All
    three views share one ResNet-50 backbone. Shape and expression are
    regressed from a weighted concatenation of the three view features; pose is
    regressed from the front-view feature only.

    Args:
        feat_dim: accepted for backward compatibility but unused -- the feature
            size is fixed to ResNet-50's 2048 output channels.
        num_shape, num_exp, num_pose: output sizes of the three heads.
        weights_path: optional path to a torchvision ResNet-50 state dict.

    ``forward`` returns a (B, num_shape + num_exp + num_pose) tensor laid out
    as [shape, exp, front pose].
    """

    def __init__(self, feat_dim=512, num_shape=199, num_exp=29, num_pose=7, weights_path=None):
        super().__init__()

        self.num_shape = num_shape
        self.num_exp = num_exp
        self.num_pose = num_pose

        base_model = models.resnet50(weights=None)
        if weights_path is not None:
            # map_location="cpu" lets GPU-saved checkpoints load on CPU-only
            # machines; the whole model is moved to `device` by the caller.
            state_dict = torch.load(weights_path, map_location="cpu")
            base_model.load_state_dict(state_dict)

        # Keep the convolutional trunk; drop ResNet's own avgpool + fc classifier.
        self.backbone = nn.Sequential(*list(base_model.children())[:-2])
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.feat_dim = 2048  # ResNet-50 final channel count (feat_dim arg is ignored)

        # Learnable per-view, per-channel fusion weights; start as a plain average.
        self.w = nn.Parameter(torch.ones(3, self.feat_dim) / 3.0)

        self.fc_shape = nn.Sequential(
            nn.Linear(self.feat_dim * 3, 512),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(512, num_shape)
        )

        self.fc_exp = nn.Sequential(
            nn.Linear(self.feat_dim * 3, 512),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(512, num_exp)
        )

        self.fc_pose = nn.Sequential(
            nn.Linear(self.feat_dim, 256),
            nn.ReLU(True),
            nn.Dropout(0.3),
            nn.Linear(256, num_pose)
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-normal init for the heads; the pose head starts near zero (gain 0.01)."""
        for m in [self.fc_shape, self.fc_exp]:
            for layer in m.modules():
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_normal_(layer.weight)
                    if layer.bias is not None:
                        nn.init.constant_(layer.bias, 0)

        for layer in self.fc_pose.modules():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight, gain=0.01)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Run the shared backbone once over all three views.

        Args:
            x: (B, 9, H, W) -- front/left/right RGB views concatenated on channels.

        Returns:
            (B, num_shape + num_exp + num_pose) tensor [shape, exp, front pose].
        """
        B = x.size(0)

        # Fold the three 3-channel views into the batch dimension -> (3B, 3, H, W)
        # so the backbone runs in a single pass.
        all_views = torch.cat([x[:, 0:3], x[:, 3:6], x[:, 6:9]], dim=0)
        all_feats = self.avgpool(self.backbone(all_views)).flatten(1)  # (3B, 2048)
        feat_a, feat_b, feat_c = all_feats[:B], all_feats[B:2 * B], all_feats[2 * B:]

        # Weighted fusion of the three views.
        feat_fused = torch.cat(
            [feat_a * self.w[0], feat_b * self.w[1], feat_c * self.w[2]], dim=1
        )

        # Shape is bounded to [-3, 3] (normalised units); expression is unbounded.
        shape_params = torch.tanh(self.fc_shape(feat_fused)) * 3.0
        exp_params = self.fc_exp(feat_fused)

        # Only the front-view pose is part of the output, so the pose head is
        # evaluated on the front-view feature alone.
        pose_front = self.fc_pose(feat_a)

        return torch.cat([shape_params, exp_params, pose_front], dim=1)

In [None]:
# ============================================================================
# LANDMARK PROJECTION - OPTIMIZED
# ============================================================================
def _rotation_in_plane(cos_a, sin_a, i, j, fixed):
    """Batch of 3x3 rotations acting in the (i, j) coordinate plane.

    Entries: R[i,i] = R[j,j] = cos_a, R[i,j] = sin_a, R[j,i] = -sin_a,
    R[fixed,fixed] = 1, all others 0. Allocated in the default float dtype on
    the device of ``cos_a``.
    """
    R = torch.zeros(cos_a.shape[0], 3, 3, device=cos_a.device)
    R[:, fixed, fixed] = 1.0
    R[:, i, i] = cos_a
    R[:, j, j] = cos_a
    R[:, i, j] = sin_a
    R[:, j, i] = -sin_a
    return R


def angle_to_rotation_batch(angles):
    """angles: (B, 3) = [pitch (x), yaw (y), roll (z)], return: (B, 3, 3)

    Builds R = R_x @ R_y @ R_z. Pitch AND yaw are negated before use (an
    empirical sign convention chosen to line up with 300W-LP); roll is used
    unchanged. The result is in the default float dtype.
    """
    pitch = -angles[:, 0]
    yaw = -angles[:, 1]
    roll = angles[:, 2]

    rot_x = _rotation_in_plane(torch.cos(pitch), torch.sin(pitch), i=1, j=2, fixed=0)
    rot_y = _rotation_in_plane(torch.cos(yaw), torch.sin(yaw), i=2, j=0, fixed=1)
    rot_z = _rotation_in_plane(torch.cos(roll), torch.sin(roll), i=0, j=1, fixed=2)

    return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))

def decode_params(params, verbose=False):
    """
    Decode normalised parameter vectors into a 3D face and a weak-perspective pose.

    Args:
        params: (B, >=235) tensor laid out as
            [0:199] shape, [199:228] expression, [228:235] pose,
            all in the normalised units used by Triplet300wLP.
        verbose: print range diagnostics (alpha/beta, shape_3d, R, s, tx, ty).
            Off by default: the prints run on every call, flood the training
            log and force a device-to-host sync for CUDA tensors.

    Returns:
        shape_3d: (B, N, 3) vertex positions.
        R: (B, 3, 3) rotation matrices from angle_to_rotation_batch.
        tx, ty, s: (B,) translation and scale, de-normalised with
            pose_mean_t / pose_std_t.
    """
    B = params.shape[0]

    shape_norm = params[:, :199]      # (B, 199)
    exp_norm = params[:, 199:228]     # (B, 29)
    pose_norm = params[:, 228:235]    # (B, 7)

    # De-normalise with the same statistics the dataset used to normalise labels.
    alpha = shape_norm * sigma_shape_t.reshape(1, -1)
    beta = exp_norm * exp_std_t.reshape(1, -1)
    pose = pose_norm * pose_std_t + pose_mean_t

    tx, ty, s = pose[:, 3], pose[:, 4], pose[:, 6]
    R = angle_to_rotation_batch(pose[:, :3])

    # 3D face: mean shape + identity offsets + expression offsets.
    s_comp = alpha @ w_shape_t.T                 # (B, N*3)
    # NOTE(review): the expression contribution is divided by 1000 -- presumably
    # a unit mismatch between Model_Expression.mat and Model_Shape.mat; confirm.
    e_comp = (beta @ w_exp_t.T) / 1000.0         # (B, N*3)
    shape_3d = (mu_shape_t.reshape(1, -1) + s_comp + e_comp).view(B, -1, 3)  # (B, N, 3)

    if verbose:
        print(alpha.max(), alpha.min(), beta.max(), beta.min())
        print("shape_3d max/min:", shape_3d.max(), shape_3d.min())
        print("R nan?", torch.isnan(R).any())
        print("s range:", s.min(), s.max())
        print("tx range:", tx.min(), tx.max())
        print("ty range:", ty.min(), ty.max())

    return shape_3d, R, tx, ty, s


def project_landmarks_batched(params, image_size=224):
    """Project the model keypoints of predicted parameters to 2D image coordinates.

    Args:
        params: (B, >=235) normalised parameter vectors; only the first 235
            columns ([shape, exp, pose]) are decoded.
        image_size: image height used for the final y-axis flip (default 224,
            the network input size).

    Returns:
        (B, K, 2) tensor of (x, y) positions, K = number of keypoint indices
        in kpt_idx_t (68 for the landmark loss).

    Raises:
        RuntimeError: wrapping any error raised while decoding or projecting;
            the original exception is chained as ``__cause__``.
    """
    try:
        shape_3d, R, tx, ty, s = decode_params(params[:, :235])

        kpts3d = shape_3d[:, kpt_idx_t, :]       # (B, K, 3)
        R2 = R[:, :2, :]                         # (B, 2, 3): orthographic projection rows

        # 2D = s * (R[0:2] @ xyz), then translate.
        proj = torch.bmm(kpts3d, R2.transpose(1, 2)) * s[:, None, None]  # (B, K, 2)
        x = proj[..., 0] + tx[:, None]
        # NOTE(review): ty is subtracted before flipping the y axis, i.e.
        # y = image_size - (s*(R X)_y - ty). Reference 300W-LP / 3DDFA-style
        # projection code appears to use image_size - (s*(R X)_y + ty) instead;
        # confirm with the sanity-check cell before trusting the landmark loss.
        y = image_size - (proj[..., 1] - ty[:, None])

        # Built out-of-place to avoid in-place edits on autograd intermediates.
        return torch.stack([x, y], dim=-1)
    except Exception as e:
        raise RuntimeError(f"Landmark projection failed: {e}") from e

In [None]:
# ============================================================================
# LOSS FUNCTION
# ============================================================================
def compute_loss(pred, target, landmarks_gt, lambda_dict):
    """Weighted parameter-regression loss plus optional 2D landmark loss.

    Args:
        pred, target: (B, >=235) normalised [shape(199), exp(29), pose(7)] vectors.
        landmarks_gt: (B, 68, 2) GT landmarks in the projection's pixel frame.
            An all-zero entry is what Triplet300wLP returns when landmarks could
            not be loaded; such samples are excluded from the landmark term
            instead of pulling predictions toward (0, 0).
        lambda_dict: weights keyed 'shape', 'exp', 'pose', 'landmark'. The
            landmark term is only computed when its weight is > 0.

    Returns:
        (total, parts): ``total`` is the differentiable weighted loss tensor;
        ``parts`` maps total/shape/exp/pose/landmark to Python floats.
        Landmark-projection failures are reported and contribute 0
        (best-effort), so training can continue.
    """
    l_s = F.mse_loss(pred[:, :199], target[:, :199])
    l_e = F.mse_loss(pred[:, 199:228], target[:, 199:228])
    l_p = F.mse_loss(pred[:, 228:235], target[:, 228:235])

    total = (
        lambda_dict['shape'] * l_s +
        lambda_dict['exp'] * l_e +
        lambda_dict['pose'] * l_p
    )

    l_lmk = torch.tensor(0.0, device=pred.device)

    if lambda_dict['landmark'] > 0:
        try:
            lmk_pred = project_landmarks_batched(pred)  # (B, 68, 2)
            if lmk_pred.shape != landmarks_gt.shape:
                raise ValueError(
                    f"Shape mismatch: pred {lmk_pred.shape} vs gt {landmarks_gt.shape}"
                )

            # Skip samples whose GT landmarks are the all-zero "missing" placeholder.
            valid = landmarks_gt.flatten(1).abs().sum(dim=1) > 0
            if valid.any():
                diff = lmk_pred[valid] - landmarks_gt[valid]       # (V, 68, 2)
                dist = torch.sqrt((diff ** 2).sum(dim=2) + 1e-8)    # per-point Euclidean distance
                l_lmk = dist.mean()
                total = total + lambda_dict['landmark'] * l_lmk
        except Exception as e:
            print(f"Landmark loss error: {str(e)}")
            l_lmk = torch.tensor(0.0, device=pred.device)

    return total, {
        "total": total.item(),
        "shape": l_s.item(),
        "exp": l_e.item(),
        "pose": l_p.item(),
        "landmark": l_lmk.item(),
    }

In [None]:
import torch
import scipy.io as sio

# Sanity check: build the *ground-truth* parameter vector of one 300W-LP sample
# the same way Triplet300wLP does, project it with project_landmarks_batched and
# compare against the sample's GT 2D landmarks. A large error points at a
# normalisation / projection-convention mismatch rather than at the network.
sample_mat_path = "/kaggle/input/300w-lp/300W_LP/IBUG/IBUG_image_019_1_0.mat"  # replace with your .mat file
pts_2d_mat_path = "/kaggle/input/300w-lp/300W_LP/landmarks/IBUG/IBUG_image_019_1_0_pts.mat"  # if using separate landmarks folder
SAMPLE_W, SAMPLE_H = 450, 450   # original size of the sample image (hard-coded; confirm for other samples)
TARGET_SIZE = 224               # network input size

mat_data = sio.loadmat(sample_mat_path)
shape_para = mat_data["Shape_Para"].flatten()      # (199,)
exp_para = mat_data["Exp_Para"].flatten()          # (29,)
pose_para = convert_300wlp_to_bfm_pose(
    mat_data["Pose_Para"].flatten(), SAMPLE_W, SAMPLE_H, TARGET_SIZE
)

# Normalise exactly like the dataset does.
shape_tensor = torch.from_numpy(shape_para / shape_std).float().unsqueeze(0).to(device)
exp_tensor = torch.from_numpy(exp_para / exp_std).float().unsqueeze(0).to(device)
pose_tensor = torch.from_numpy((pose_para - pose_mean) / pose_std).float().unsqueeze(0).to(device)

params = torch.cat([shape_tensor, exp_tensor, pose_tensor], dim=1)  # (1, 235)

# Project landmarks
proj_landmarks = project_landmarks_batched(params)  # (1, 68, 2)

# GT landmarks as (1, 68, 2). Rescale the x / y *columns* (last axis); indexing
# [:, 0] / [:, 1] on this 3-D tensor would instead rescale landmarks #0 and #1.
pts_2d = torch.from_numpy(
    sio.loadmat(pts_2d_mat_path)['pts_2d'].astype(np.float32)
).unsqueeze(0).to(device)
pts_2d[..., 0] *= TARGET_SIZE / SAMPLE_W
pts_2d[..., 1] *= TARGET_SIZE / SAMPLE_H

# Compare
print("Projected landmarks shape:", proj_landmarks.shape)
print("First 5 projected landmarks:\n", proj_landmarks[0, :5])
print("First 5 ground-truth landmarks:\n", pts_2d[0, :5])

# Mean per-point Euclidean error (pixels in the TARGET_SIZE frame)
l2_error = torch.norm(proj_landmarks - pts_2d, dim=2).mean()
print("Mean L2 error between projected and GT landmarks:", l2_error.item())

In [None]:
# ============================================================================
# TRAINING
# ============================================================================
model = ResNetEncoder(weights_path=WEIGHTS_PATH).to(device)

# Mixed-precision loss scaler (used by train_epoch). Explicitly enabled only when
# CUDA is present instead of relying on GradScaler's own CUDA-availability fallback.
scaler = torch.amp.GradScaler('cuda', enabled=torch.cuda.is_available())

def get_stage_config(epoch, schedule=None):
    """Return the training-stage config dict that applies to ``epoch`` (0-based).

    Stages are taken in the schedule's insertion order; each one covers
    ``stage['epochs']`` consecutive epochs (1 if the key is missing). Epochs
    past the end of the schedule keep using the last stage.

    Args:
        epoch: 0-based epoch index.
        schedule: mapping of stage name -> config; defaults to TRAINING_SCHEDULE.

    Raises:
        ValueError: if the schedule is empty.
    """
    if schedule is None:
        schedule = TRAINING_SCHEDULE
    stages = list(schedule.values())
    if not stages:
        raise ValueError("training schedule is empty")

    stage_end = 0
    for stage in stages:
        stage_end += stage.get('epochs', 1)
        if epoch < stage_end:
            return stage
    return stages[-1]

def train_epoch(model, dataloader, optimizer, lambda_dict, use_amp=True):
    """Run one training epoch and return per-sample mean losses.

    Args:
        model: network mapping an image batch to parameter predictions.
        dataloader: iterable of (imgs, targets, landmarks_gt) batches.
        optimizer: optimizer over ``model``'s parameters.
        lambda_dict: loss weights forwarded to compute_loss.
        use_amp: request mixed precision. It is only applied to CUDA batches,
            because the autocast context and the global ``scaler`` used here are
            CUDA-specific; other batches run in full precision.

    Returns:
        dict with keys total/shape/exp/pose/landmark holding sample-weighted
        means. Batches with a NaN/Inf loss are skipped (no optimizer step, not
        counted in the averages).

    Uses the module globals ``device``, ``scaler`` and ``compute_loss``.
    """
    model.train()

    total_samples = 0
    epoch_losses = {key: 0.0 for key in ("total", "shape", "exp", "pose", "landmark")}

    for batch_idx, (imgs, targets, landmarks_gt) in enumerate(dataloader):
        imgs = imgs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        landmarks_gt = landmarks_gt.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        amp_active = use_amp and imgs.is_cuda
        if amp_active:
            with torch.amp.autocast('cuda'):
                outputs = model(imgs)
                loss, loss_dict = compute_loss(outputs, targets, landmarks_gt, lambda_dict)
        else:
            outputs = model(imgs)
            loss, loss_dict = compute_loss(outputs, targets, landmarks_gt, lambda_dict)

        if not torch.isfinite(loss):
            print(f"WARNING: NaN/Inf loss at batch {batch_idx}, skipping")
            continue

        if amp_active:
            scaler.scale(loss).backward()
            # Unscale before clipping so the norm threshold applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        batch_size = imgs.size(0)
        total_samples += batch_size
        for key in epoch_losses:
            epoch_losses[key] += loss_dict[key] * batch_size

        if (batch_idx + 1) % 100 == 0:
            print(f"  Batch {batch_idx+1}/{len(dataloader)} | "
                  f"Loss: {loss_dict['total']:.4f} "
                  f"(S:{loss_dict['shape']:.4f} "
                  f"E:{loss_dict['exp']:.4f} "
                  f"P:{loss_dict['pose']:.4f} "
                  f"L:{loss_dict['landmark']:.2f})")

    for key in epoch_losses:
        epoch_losses[key] /= max(total_samples, 1)

    return epoch_losses

def eval_epoch(model, dataloader, lambda_dict):
    """Evaluate on test set.

    Puts ``model`` in eval mode, runs every (imgs, targets, landmarks_gt) batch
    without gradients and returns the sample-weighted mean of each loss
    component reported by compute_loss (keys total/shape/exp/pose/landmark).
    Batches whose loss is NaN/Inf are left out of the average.

    Uses the module globals ``device`` and ``compute_loss``.
    """
    model.eval()

    running = dict.fromkeys(("total", "shape", "exp", "pose", "landmark"), 0.0)
    n_seen = 0

    with torch.no_grad():
        for imgs, targets, landmarks_gt in dataloader:
            imgs, targets, landmarks_gt = (
                t.to(device, non_blocking=True) for t in (imgs, targets, landmarks_gt)
            )
            loss, loss_dict = compute_loss(model(imgs), targets, landmarks_gt, lambda_dict)

            if not (torch.isnan(loss) or torch.isinf(loss)):
                batch_size = imgs.size(0)
                n_seen += batch_size
                for key in running:
                    running[key] += loss_dict[key] * batch_size

    denominator = max(n_seen, 1)
    return {key: total / denominator for key, total in running.items()}

# Training loop: each epoch uses the stage config (lr + loss weights) from
# get_stage_config, trains on train_loader, evaluates on test_loader, and writes
# the checkpoint with the lowest test loss to CHECKPOINT_PATH.
print("\n" + "="*60)
print("STARTING STAGED TRAINING")
print("="*60)

best_loss = float('inf')
best_epoch = 0
optimizer = None
current_lr = None

for epoch in range(NUM_EPOCHS):
    stage_config = get_stage_config(epoch)
    # Name of the active stage, for logging only.
    stage_name = next(
        (name for name, cfg in TRAINING_SCHEDULE.items() if cfg is stage_config),
        "custom",
    )

    if current_lr != stage_config['lr']:
        current_lr = stage_config['lr']
        # A fresh Adam is created on every lr change, so its moment estimates reset too.
        optimizer = torch.optim.Adam(model.parameters(), lr=current_lr)
        print(f"\nðŸ”„ Learning rate changed to {current_lr}")

    lambda_dict = {term: stage_config[term] for term in ('shape', 'exp', 'pose', 'landmark')}
    active_terms = [term for term, weight in lambda_dict.items() if weight > 0]

    print(f"\n{'='*60}")
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
    # Derived from the config so the log always matches the weights actually used.
    print(f"Stage '{stage_name}': active losses = {', '.join(active_terms) or 'none'}")
    print(f"Lambda: S={lambda_dict['shape']:.3f} E={lambda_dict['exp']:.3f} "
          f"P={lambda_dict['pose']:.4f} L={lambda_dict['landmark']:.4f}")
    print('='*60)

    t0 = time.time()
    epoch_losses = train_epoch(model, train_loader, optimizer, lambda_dict, use_amp=True)
    test_losses = eval_epoch(model, test_loader, lambda_dict)
    t1 = time.time()

    print(f"\nEpoch {epoch+1} Summary:")
    print(f"  TRAIN - Total: {epoch_losses['total']:.6f}, Shape: {epoch_losses['shape']:.6f}, Exp: {epoch_losses['exp']:.6f}, Pose: {epoch_losses['pose']:.6f}, Landmark: {epoch_losses['landmark']:.4f}")
    print(f"  TEST  - Total: {test_losses['total']:.6f}, Shape: {test_losses['shape']:.6f}, Exp: {test_losses['exp']:.6f}, Pose: {test_losses['pose']:.6f}, Landmark: {test_losses['landmark']:.4f}")
    print(f"  Time: {t1-t0:.1f}s")

    current_loss = test_losses['total']
    if current_loss < best_loss:
        best_loss = current_loss
        best_epoch = epoch + 1

        torch.save({
            "epoch": epoch+1,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "train_loss": epoch_losses['total'],
            "test_loss": test_losses['total'],
            "train_losses": epoch_losses,
            "test_losses": test_losses,
        }, CHECKPOINT_PATH)

        print(f"  âœ“ NEW BEST MODEL! Test Loss: {current_loss:.6f}")
    else:
        print(f"  Best test loss: {best_loss:.6f} (Epoch {best_epoch})")

print("\n" + "="*60)
print("âœ“ Training Complete!")
print(f"âœ“ Best model from Epoch {best_epoch} with loss {best_loss:.6f}")
print("="*60)