In [None]:
import os, math, csv, time
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch.nn.functional as F
from PIL import Image
import numpy as np
import pandas as pd
import scipy.io as sio
from collections import OrderedDict

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Bare expression so the notebook displays the selected device.
device

In [None]:
# Paths and training hyperparameters
SELECTED_FOLDERS = ["LFPW"]  # for now use only one folder to save time
DATA_ROOT = "/kaggle/input/300w-lp/300W_LP"
TRIPLETS_CSV = "/kaggle/input/multi-view-input/triplets_shorter_LFPW.csv"
WEIGHTS_PATH = "/kaggle/input/multi-view-input/resnet50-11ad3fa6.pth"
BATCH_SIZE = 16  # larger values do not run on Kaggle
NUM_EPOCHS = 2
LR = 1e-5  # learning rate passed to the Adam optimizer
CHECKPOINT_PATH = "resnet50_mvfnet.pth"  # where the best checkpoint is written

In [None]:
def load_params_from_mat(mat_path):
    """
    Load a .mat file and return its parameters as one flat float32 vector.

    The vector is the concatenation, in this order, of the flattened
    "Shape_Para", "Exp_Para" and "Pose_Para" entries. A key that is absent
    from the file contributes zero elements.
    """
    contents = sio.loadmat(mat_path)

    # Flatten each block; a missing key falls back to an empty (1, 0) array.
    pieces = [
        np.array(contents.get(key, np.zeros((1, 0)))).reshape(-1)
        for key in ("Shape_Para", "Exp_Para", "Pose_Para")
    ]

    return np.concatenate(pieces).astype(np.float32)

In [None]:
def get_folder_from_filename(filename, folders=None):
    """
    Return the folder name that ``filename`` starts with.

    Args:
        filename: File name to classify.
        folders: Optional iterable of candidate folder names. When omitted,
            the module-level SELECTED_FOLDERS list is used.

    Returns:
        The matching folder name. If several candidates are prefixes of
        ``filename`` the longest one is returned, so the result does not
        depend on list order when one folder name is a prefix of another
        (e.g. "LFPW" and "LFPW_Flip").

    Raises:
        ValueError: If no candidate folder is a prefix of ``filename``.
    """
    candidates = SELECTED_FOLDERS if folders is None else folders
    matches = [folder for folder in candidates if filename.startswith(folder)]
    if not matches:
        raise ValueError(f"Cannot find folder for {filename}")
    # Most specific (longest) prefix wins; ties keep the earliest candidate.
    return max(matches, key=len)

In [None]:
class Triplet300wLP(Dataset):
    """
    Dataset of (front, left, right) image triplets listed in a CSV file.

    The CSV is expected to provide the columns ``front_img``, ``left_img``,
    ``right_img`` and ``front_mat``, ``left_mat``, ``right_mat``. Only the
    front view's .mat file is read for labels.

    Each item is ``(input_tensor, label_tensor)``:
      * input_tensor: the three transformed views concatenated along dim 0
        in the order front, left, right.
      * label_tensor: float tensor made of the front view's "Shape_Para",
        "Exp_Para" and "Pose_Para" (flattened, each divided by its
        normalization constant, concatenated in that order).

    Args:
        triplets_csv: Path to the triplets CSV.
        dataset_path: Root directory that contains the image folders.
        transform: Optional callable applied to each PIL image. When falsy,
            ``transforms.ToTensor()`` is used instead.
        shape_std: Divisor applied to the shape parameters.
        exp_std: Divisor applied to the expression parameters.
        pose_std: Divisor applied to the pose parameters.
    """

    def __init__(self, triplets_csv, dataset_path, transform=None,
                 shape_std=66912.0, exp_std=0.6, pose_std=107.9):
        self.data = pd.read_csv(triplets_csv)
        self.dataset_path = dataset_path
        self.transform = transform

        # Normalization constants. The defaults are the previously hard-coded
        # values; pass different ones when the underlying data changes.
        self.shape_std = shape_std
        self.exp_std = exp_std
        self.pose_std = pose_std

        # Map every .jpg/.mat file name found on disk to its folder.
        # NOTE: __getitem__ does not read this mapping; it is kept as a
        # public attribute for any external users.
        self.folder_map = {}
        for folder in SELECTED_FOLDERS:
            folder_path = os.path.join(dataset_path, folder)
            if not os.path.isdir(folder_path):
                continue
            for f in os.listdir(folder_path):
                if f.endswith(('.jpg', '.mat')):
                    self.folder_map[f] = folder

    def __len__(self):
        """Number of triplets (rows in the CSV)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load the triplet at ``idx`` and return ``(input_tensor, label_tensor)``."""
        row = self.data.iloc[idx]
        # Resolve the image converter once per item instead of once per view.
        to_tensor = self.transform if self.transform else transforms.ToTensor()

        imgs = []
        front_label = None

        for view in ("front", "left", "right"):
            filename_img = row[f'{view}_img']
            filename_mat = row[f'{view}_mat']

            # The .mat file is looked up in the same folder as its image.
            folder = get_folder_from_filename(filename_img)
            img_path = os.path.join(self.dataset_path, folder, filename_img)
            mat_path = os.path.join(self.dataset_path, folder, filename_mat)

            # Context manager closes the underlying file deterministically.
            with Image.open(img_path) as raw_img:
                img = raw_img.convert("RGB")
            imgs.append(to_tensor(img))

            if view == "front":
                mat_data = sio.loadmat(mat_path)
                # Normalize each parameter group, then concatenate.
                front_label = np.concatenate([
                    mat_data["Shape_Para"].flatten() / self.shape_std,
                    mat_data["Exp_Para"].flatten() / self.exp_std,
                    mat_data["Pose_Para"].flatten() / self.pose_std,
                ], axis=0)

        input_tensor = torch.cat(imgs, dim=0)
        label_tensor = torch.from_numpy(front_label).float()
        return input_tensor, label_tensor

In [None]:
# Per-image preprocessing: resize to 224x224 and convert to a tensor.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # NOTE(review): normalization is currently disabled. Pretrained torchvision
    # ResNet weights are usually paired with these mean/std values — confirm
    # that leaving this off is intentional.
    # transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

# Build the triplet dataset and a shuffled loader over it.
dataset = Triplet300wLP(TRIPLETS_CSV, DATA_ROOT, transform=transform)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

In [None]:
class ResNetEncoder(nn.Module):
    """
    Multi-view encoder built on the convolutional part of ResNet-50.

    The three views (front, left, right) share one backbone. Their pooled
    features are scaled by learnable per-view weights, concatenated, and fed
    to two heads that predict shape and expression parameters. A third head
    predicts pose separately for each view from that view's unscaled
    features.

    Args:
        feat_dim: Accepted for backward compatibility but ignored;
            ``self.feat_dim`` is always set to 2048.
        num_shape: Output size of the shape head.
        num_exp: Output size of the expression head.
        num_pose: Output size of the pose head (per view).
        weights_path: Optional path to a ResNet-50 ``state_dict`` file. When
            ``None`` the backbone keeps its random initialization.
    """

    def __init__(self, feat_dim=512, num_shape=199, num_exp=29, num_pose=7, weights_path=None):
        super().__init__()

        self.num_shape = num_shape
        self.num_exp = num_exp
        self.num_pose = num_pose

        # Build ResNet-50 without downloading weights; optionally load a
        # local state_dict.
        base_model = models.resnet50(weights=None)
        if weights_path is not None:
            # map_location="cpu" lets the file load on machines without the
            # device it was saved from; the caller moves the model afterwards.
            # NOTE: torch.load unpickles the file — only load trusted files.
            state_dict = torch.load(weights_path, map_location="cpu")
            base_model.load_state_dict(state_dict)

        # Keep only the convolutional layers (drop the last two children).
        self.backbone = nn.Sequential(*list(base_model.children())[:-2])
        self.avgpool = nn.AdaptiveAvgPool2d(1)  # output: (batch, 2048, 1, 1)
        self.feat_dim = 2048  # ResNet50 output channel

        # Learnable per-view, per-channel fusion weights.
        self.w = nn.Parameter(torch.ones(3, self.feat_dim))

        # Multi-view fusion for shape and expression
        self.fc_shape = nn.Sequential(
            nn.Linear(self.feat_dim * 3, 512),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(512, num_shape)
        )

        self.fc_exp = nn.Sequential(
            nn.Linear(self.feat_dim * 3, 512),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(512, num_exp)
        )

        # View-specific pose prediction
        self.fc_pose = nn.Sequential(
            nn.Linear(self.feat_dim, 256),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(256, num_pose)
        )

        self._init_weights()

    def _init_weights(self):
        """Initialize the Linear layers of the three heads (Xavier normal weights, zero bias)."""
        for m in [self.fc_shape, self.fc_exp, self.fc_pose]:
            for layer in m.modules():
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_normal_(layer.weight)
                    if layer.bias is not None:
                        nn.init.constant_(layer.bias, 0)

    def _encode_view(self, view):
        """Run one 3-channel view through the backbone and pool it to (B, C)."""
        return self.avgpool(self.backbone(view)).flatten(1)

    def forward(self, x):
        """
        x: (B, 9, H, W) -> front|left|right concatenated along channel
        output: (B, num_shape + num_exp + 3 * num_pose), laid out as
                [shape | exp | pose_front | pose_left | pose_right]
        """
        # Split the three views.
        front = x[:, 0:3, :, :]
        left = x[:, 3:6, :, :]
        right = x[:, 6:9, :, :]

        # Extract backbone features once per view. The same unweighted
        # features feed both the fusion heads and the pose head, so the
        # backbone is not run a second time (in train mode this also means
        # any running statistics in the backbone update once per view).
        feat_a = self._encode_view(front)
        feat_b = self._encode_view(left)
        feat_c = self._encode_view(right)

        # Scale each view by its learnable weights and concatenate.
        feat_fused = torch.cat(
            [feat_a * self.w[0], feat_b * self.w[1], feat_c * self.w[2]],
            dim=1,
        )  # (B, 3 * feat_dim)

        # Predict view-invariant parameters
        shape_params = self.fc_shape(feat_fused)
        exp_params = self.fc_exp(feat_fused)

        # Predict view-specific pose parameters from the unweighted features.
        pose_a = self.fc_pose(feat_a)
        pose_b = self.fc_pose(feat_b)
        pose_c = self.fc_pose(feat_c)

        # Concatenate all outputs
        output = torch.cat([shape_params, exp_params, pose_a, pose_b, pose_c], dim=1)

        return output

In [None]:
# Build the encoder from the local ResNet-50 weights and move it to `device`.
model = ResNetEncoder(weights_path=WEIGHTS_PATH).to(device)
# Bare expression so the notebook displays the module structure.
model

In [None]:
def supervised_loss(pred, target, lambda_shape=1.0, lambda_exp=1.0, lambda_pose=1.0):
    """
    Weighted MSE loss over the shape, expression and front-pose columns.

    Column layout used for both tensors:
        shape -> [0:199], exp -> [199:228], front pose -> [228:235]
    Any columns of ``pred`` from 235 onward are not part of the loss.

    Args:
        pred: (B, >=235) model predictions.
        target: (B, 235) ground-truth parameters.
        lambda_shape: Weight of the shape term.
        lambda_exp: Weight of the expression term.
        lambda_pose: Weight of the pose term.

    Returns:
        total_loss: Tensor holding the weighted sum of the three MSE terms.
        loss_dict: Python floats under the keys 'total', 'shape', 'exp'
            and 'pose', intended for logging.
    """
    column_ranges = {
        'shape': slice(None, 199),
        'exp': slice(199, 228),
        'pose': slice(228, 235),
    }

    # One MSE term per parameter group.
    terms = {
        name: F.mse_loss(pred[:, cols], target[:, cols])
        for name, cols in column_ranges.items()
    }

    total_loss = (lambda_shape * terms['shape'] +
                  lambda_exp * terms['exp'] +
                  lambda_pose * terms['pose'])

    loss_dict = {'total': total_loss.item()}
    for name, value in terms.items():
        loss_dict[name] = value.item()

    return total_loss, loss_dict

In [None]:
# Adam over every model parameter (backbone, fusion weights and heads).
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

In [None]:
def train_epoch(model, dataloader, optimizer, device, epoch_num):
    """
    Run one optimization pass over ``dataloader``.

    For every batch: move tensors to ``device``, run the model, compute
    ``supervised_loss``, backpropagate, clip the gradient norm to 1.0 and
    step the optimizer. A progress line is printed every 100 batches.

    Args:
        model: Module to train; switched to train mode here.
        dataloader: Iterable yielding ``(imgs, targets)`` batches.
        optimizer: Optimizer whose parameters belong to ``model``.
        device: Device the batches are moved to.
        epoch_num: Not used inside this function.

    Returns:
        Dict with keys 'total', 'shape', 'exp', 'pose' holding the
        sample-weighted mean of each loss over the epoch.
    """
    model.train()

    # Running sums of (batch loss * batch size) per loss component.
    running = dict.fromkeys(('total', 'shape', 'exp', 'pose'), 0.0)
    seen = 0

    for step, (imgs, targets) in enumerate(dataloader, start=1):
        imgs = imgs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss, loss_dict = supervised_loss(outputs, targets)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Weight by batch size so a smaller final batch is averaged correctly.
        n_in_batch = imgs.size(0)
        seen += n_in_batch
        for name in running:
            running[name] += loss_dict[name] * n_in_batch

        if step % 100 == 0:
            print(f"  Batch {step}/{len(dataloader)} | "
                  f"Loss: {loss_dict['total']:.6f} "
                  f"(shape: {loss_dict['shape']:.6f}, "
                  f"exp: {loss_dict['exp']:.6f}, "
                  f"pose: {loss_dict['pose']:.6f})")

    # Convert the running sums into per-sample averages.
    return {name: summed / seen for name, summed in running.items()}

In [None]:
def check_data_statistics(dataloader):
    """
    Check if your data has reasonable values
    """
    print("\n" + "="*60)
    print("CHECKING DATA STATISTICS")
    print("="*60)
    
    # Get one batch
    imgs, targets = next(iter(dataloader))
    
    print(f"\nInput Images:")
    print(f"  Shape: {imgs.shape}")
    print(f"  Min: {imgs.min().item():.4f}")
    print(f"  Max: {imgs.max().item():.4f}")
    print(f"  Mean: {imgs.mean().item():.4f}")
    print(f"  Std: {imgs.std().item():.4f}")
    
    print(f"\nTarget Parameters:")
    print(f"  Shape: {targets.shape}")
    print(f"  Min: {targets.min().item():.4f}")
    print(f"  Max: {targets.max().item():.4f}")
    print(f"  Mean: {targets.mean().item():.4f}")
    print(f"  Std: {targets.std().item():.4f}")
    
    # Check individual components
    shape_params = targets[:, :199]
    exp_params = targets[:, 199:228]
    pose_params = targets[:, 228:235]
    
    print(f"\n  Shape params - Mean: {shape_params.mean().item():.4f}, Std: {shape_params.std().item():.4f}")
    print(f"  Exp params   - Mean: {exp_params.mean().item():.4f}, Std: {exp_params.std().item():.4f}")
    print(f"  Pose params  - Mean: {pose_params.mean().item():.4f}, Std: {pose_params.std().item():.4f}")
    
    # WARNING: If any std is > 100, you might need to normalize targets!
    if shape_params.std().item() > 100:
        print("\n⚠️  WARNING: Shape parameters have very large values!")
        print("   Consider normalizing target parameters.")
    
    print("="*60 + "\n")

# Sanity-check the first batch from the training loader before training.
check_data_statistics(loader)

In [None]:
# Track best model.
# NOTE: "best" is judged on the epoch's average *training* loss; no separate
# validation set is evaluated in this loop.
best_loss = float('inf')
best_epoch = 0

for epoch in range(NUM_EPOCHS):
    print(f"\n{'='*60}")
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
    print('='*60)
        
    # Time one full pass over the training loader.
    t0 = time.time()
    epoch_losses = train_epoch(model, loader, optimizer, device, epoch+1)
    t1 = time.time()
        
    # Print epoch summary
    print(f"\nEpoch {epoch+1} Summary:")
    print(f"  Total Loss:      {epoch_losses['total']:.6f}")
    print(f"  Shape Loss:      {epoch_losses['shape']:.6f}")
    print(f"  Expression Loss: {epoch_losses['exp']:.6f}")
    print(f"  Pose Loss:       {epoch_losses['pose']:.6f}")
    print(f"  Time: {t1-t0:.1f}s")
        
    # Save checkpoint if this is the best model so far
    current_loss = epoch_losses['total']
    if current_loss < best_loss:
        best_loss = current_loss
        best_epoch = epoch + 1
            
        # Overwrites CHECKPOINT_PATH with model + optimizer state and the
        # loss values of this epoch.
        torch.save({
            "epoch": epoch+1,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": current_loss,
            "loss_details": epoch_losses,
        }, CHECKPOINT_PATH)
            
        print(f"  ✓ NEW BEST MODEL! Loss: {current_loss:.6f} → Saved to {CHECKPOINT_PATH}")
    else:
        print(f"  Best loss so far: {best_loss:.6f} (Epoch {best_epoch})")
    
# NOTE(review): the summary below is printed unconditionally. If no epoch ever
# improved on the initial inf (e.g. the loss was NaN), best_epoch stays 0 and
# no checkpoint was written even though a path is reported.
print("\n" + "="*60)
print("✓ Training Complete!")
print(f"✓ Best model from Epoch {best_epoch} with loss {best_loss:.6f}")
print(f"✓ Saved at: {CHECKPOINT_PATH}")
print("="*60)