In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
import os
import random
import glob

# ==========================================
# 1. DATASET: DEPTH IS NOW LOADED AS 1 CHANNEL (GRAYSCALE)
# ==========================================

class RGBDTripletDataset(Dataset):
    """Triplet dataset of paired RGB / depth face images grouped by person.

    Expected layout on disk::

        root_dir/<person>/dataset_face/<name>.<ext>    # RGB image
        root_dir/<person>/dataset_depth/<name>.<ext>   # depth image

    An RGB file and a depth file form a pair when their file names are equal
    once the extension is removed. Only people with at least two pairs are
    indexed.

    Every item is a randomly drawn (anchor, positive, negative) triplet: the
    index passed to ``__getitem__`` is ignored, and ``__len__`` (the total
    number of indexed pairs) only controls how many triplets make up an epoch.
    """

    # File extensions (lower-case) accepted as images while scanning folders.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir):
        """Scan ``root_dir`` and index the RGB/depth pairs of every person.

        Args:
            root_dir: Folder that contains one sub-folder per person.

        If ``root_dir`` is not a directory a message is printed and the
        dataset is left empty (no exception is raised here).
        """
        self.root_dir = root_dir
        self.rgb_color_aug = transforms.ColorJitter(brightness=0.2, contrast=0.2)

        # Per-modality normalisation applied after conversion to tensors.
        self.norm_rgb = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.norm_depth = transforms.Normalize(mean=[0.5], std=[0.5])
        self.people_data = {}    # person name -> list of (rgb_path, depth_path)
        self.people_names = []   # names of the people present in people_data

        # Names of the two sub-folders expected inside each person folder.
        RGB_FOLDER_NAME = "dataset_face"    # colour images
        DEPTH_FOLDER_NAME = "dataset_depth"  # depth images

        print(f"üîÑ ƒêang qu√©t d·ªØ li·ªáu... (T√¨m folder '{RGB_FOLDER_NAME}' v√† '{DEPTH_FOLDER_NAME}')")

        # isdir (not just exists) so that a plain file is reported as a bad
        # root instead of crashing inside os.listdir below.
        if not os.path.isdir(root_dir):
            print(f"‚ùå Error: Kh√¥ng t√¨m th·∫•y ƒë∆∞·ªùng d·∫´n g·ªëc {root_dir}")
            return

        # Sorted listings make the index independent of the filesystem's
        # arbitrary directory order, so seeded runs are repeatable.
        for person_name in sorted(os.listdir(root_dir)):
            person_path = os.path.join(root_dir, person_name)
            if not os.path.isdir(person_path):
                continue

            rgb_dir = os.path.join(person_path, RGB_FOLDER_NAME)
            depth_dir = os.path.join(person_path, DEPTH_FOLDER_NAME)

            # Both modality folders must exist for this person.
            if not (os.path.isdir(rgb_dir) and os.path.isdir(depth_dir)):
                continue

            # Lookup table: file name without extension -> full depth path.
            depth_files_map = {
                os.path.splitext(f)[0]: os.path.join(depth_dir, f)
                for f in self._list_images(depth_dir)
            }

            # Keep every RGB file that has a depth file with the same stem.
            paired_rgb = []
            for f in self._list_images(rgb_dir):
                name_no_ext = os.path.splitext(f)[0]
                if name_no_ext in depth_files_map:
                    paired_rgb.append((os.path.join(rgb_dir, f), depth_files_map[name_no_ext]))

            # At least two pairs are needed to draw a distinct anchor/positive.
            if len(paired_rgb) > 1:
                self.people_data[person_name] = paired_rgb
                self.people_names.append(person_name)

        print(f"‚úÖ ƒê√£ load: {len(self.people_names)} ng∆∞·ªùi.")
        if len(self.people_names) == 0:
            print("‚ùå V·∫´n ch∆∞a t√¨m th·∫•y ·∫£nh! H√£y ki·ªÉm tra k·ªπ t√™n folder con b√™n trong.")
            print(f"   V√≠ d·ª• mong ƒë·ª£i: {root_dir}/Longvu/{RGB_FOLDER_NAME}/face_0.jpg")

    @classmethod
    def _list_images(cls, folder):
        """Return the sorted image file names (not full paths) found in ``folder``."""
        return sorted(f for f in os.listdir(folder) if f.lower().endswith(cls.IMAGE_EXTENSIONS))

    def __len__(self):
        """Total number of indexed RGB/depth pairs across all people."""
        return sum(len(imgs) for imgs in self.people_data.values())

    def load_tuple(self, pair_paths):
        """Open one (rgb_path, depth_path) pair and return two PIL images.

        The RGB image is converted to mode "RGB" and the depth image to
        single-channel mode "L". If either file cannot be read, the failure is
        printed and a black 224x224 pair is returned instead, so one bad file
        does not abort a whole epoch.
        """
        rgb_path, depth_path = pair_paths

        try:
            # Context managers close the underlying files right away;
            # convert() yields an independent in-memory image.
            with Image.open(rgb_path) as rgb_file:
                rgb_pil = rgb_file.convert("RGB")
            with Image.open(depth_path) as depth_file:
                depth_pil = depth_file.convert("L")  # grayscale
        except Exception as e:
            print(f"‚ùå L·ªói: {rgb_path}")
            # Report the cause and the depth path too: either file may be the culprit.
            print(f"   depth: {depth_path} | {type(e).__name__}: {e}")
            rgb_pil = Image.new('RGB', (224, 224))
            depth_pil = Image.new('L', (224, 224))

        # Untransformed PIL images; augmentation happens in transform_sync.
        return rgb_pil, depth_pil

    def transform_sync(self, rgb_img, depth_img):
        """Augment an RGB/depth pair with identical geometric transforms.

        Resize, horizontal flip and rotation are applied to both images with
        the same random decisions; colour jitter is applied to the RGB image
        only. Returns ``(rgb_tensor, depth_tensor)``, each normalised with its
        own mean/std.
        """
        # 1. Resize both modalities to the same size.
        rgb_img = TF.resize(rgb_img, (224, 224))
        depth_img = TF.resize(depth_img, (224, 224))

        # 2. Random horizontal flip, shared by both images.
        if random.random() > 0.5:
            rgb_img = TF.hflip(rgb_img)
            depth_img = TF.hflip(depth_img)

        # 3. Random rotation by the same angle in [-10, 10] degrees.
        if random.random() > 0.5:
            angle = random.uniform(-10, 10)
            rgb_img = TF.rotate(rgb_img, angle)
            depth_img = TF.rotate(depth_img, angle)

        # 4. Colour jitter on RGB only: it is not geometric, and changing the
        #    depth intensities would alter the depth values themselves.
        rgb_img = self.rgb_color_aug(rgb_img)

        # 5. Tensor conversion and per-modality normalisation.
        rgb_t = self.norm_rgb(TF.to_tensor(rgb_img))
        depth_t = self.norm_depth(TF.to_tensor(depth_img))

        return rgb_t, depth_t

    def __getitem__(self, idx):
        """Draw a random triplet; ``idx`` is ignored.

        Returns:
            dict with keys "anchor", "positive" and "negative", each mapping
            to an ``(rgb_tensor, depth_tensor)`` tuple. Anchor and positive
            come from the same person, negative from a different person.

        Raises:
            RuntimeError: if fewer than two people are indexed, because no
                negative sample can exist.
        """
        if len(self.people_names) < 2:
            # Without this guard the search for a different person never ends.
            raise RuntimeError(
                "RGBDTripletDataset needs at least 2 people with paired images "
                f"to build a triplet, found {len(self.people_names)} in {self.root_dir!r}"
            )

        # Two distinct people: the first is the anchor identity, the second
        # supplies the negative.
        anchor_name, neg_name = random.sample(self.people_names, 2)
        anchor_imgs = self.people_data[anchor_name]

        # Two distinct pairs of the anchor person when possible; fall back to
        # reusing the single pair if the list was populated by hand.
        if len(anchor_imgs) > 1:
            anchor_pair, pos_pair = random.sample(anchor_imgs, 2)
        else:
            anchor_pair = pos_pair = anchor_imgs[0]

        neg_pair = random.choice(self.people_data[neg_name])

        a_rgb_pil, a_dep_pil = self.load_tuple(anchor_pair)
        p_rgb_pil, p_dep_pil = self.load_tuple(pos_pair)
        n_rgb_pil, n_dep_pil = self.load_tuple(neg_pair)

        # Each sample gets its own independently drawn synchronized augmentation.
        a_rgb, a_d = self.transform_sync(a_rgb_pil, a_dep_pil)
        p_rgb, p_d = self.transform_sync(p_rgb_pil, p_dep_pil)
        n_rgb, n_d = self.transform_sync(n_rgb_pil, n_dep_pil)

        return {
            "anchor": (a_rgb, a_d),
            "positive": (p_rgb, p_d),
            "negative": (n_rgb, n_d)
        }

# ==========================================
# 2. MODEL: SMART WEIGHT LOADING
# ==========================================

def smart_load_weights(model, weight_path, device):
    """Load a checkpoint into ``model``, ignoring wrapper prefixes on its keys.

    Leading ``backbone.``, ``module.`` and ``encoder.`` segments (in any order
    or nesting) are stripped from every checkpoint key before loading. Only
    *leading* segments are removed, so a parameter whose real name contains
    one of these tokens further in (e.g. ``head.encoder.weight``) is kept
    intact.

    Loading is non-strict: keys missing from the checkpoint keep the model's
    current values, and checkpoint keys the model does not have are ignored.

    Args:
        model: The ``nn.Module`` to populate.
        weight_path: Path of a file saved with ``torch.save(state_dict, ...)``.
        device: ``map_location`` used when reading the file.

    Returns:
        The object returned by ``model.load_state_dict`` (it exposes
        ``missing_keys`` and ``unexpected_keys``).
    """
    # NOTE(review): torch.load unpickles the file — only use checkpoints from a trusted source.
    state_dict = torch.load(weight_path, map_location=device)

    wrapper_prefixes = ("backbone.", "module.", "encoder.")
    new_state_dict = {}
    for key, value in state_dict.items():
        new_key = key
        # Keep peeling until no wrapper prefix is left, so nested wrappers
        # such as "module.backbone." are handled whatever their order.
        stripped = True
        while stripped:
            stripped = False
            for prefix in wrapper_prefixes:
                if new_key.startswith(prefix):
                    new_key = new_key[len(prefix):]
                    stripped = True
        new_state_dict[new_key] = value

    # strict=False so a mismatching classifier / FC head does not abort the load.
    result = model.load_state_dict(new_state_dict, strict=False)
    missing, unexpected = result
    print(f"üîπ Load {weight_path}: Missing keys (FC layers): {len(missing)}")
    if unexpected:
        # Surface ignored keys: a large count usually means the checkpoint
        # does not belong to this architecture.
        print(f"   Unexpected keys (ignored): {len(unexpected)}")
    return result

class FaceModelTrain(nn.Module):
    """Two-stream RGB + depth embedding network for triplet training.

    * RGB stream: torchvision ResNet18 without its final layer, followed by a
      512 -> 512 linear projector.
    * Depth stream: torchvision EfficientNet-B0 feature extractor whose first
      convolution takes 1 input channel, global average pooling, and a
      1280 -> 512 linear projector.
    * Fusion head: the two L2-normalised 512-d vectors are concatenated
      (1024-d) and mapped to the final 512-d embedding, which is L2-normalised.

    Args:
        rgb_pth: Checkpoint path for the ResNet18 weights.
        depth_pth: Checkpoint path for the EfficientNet-B0 weights.
        device: Device the checkpoints are mapped to and the model is moved to.
        freeze_backbone: When True, gradients are disabled for both backbones.
            NOTE(review): this only sets ``requires_grad``; it does not switch
            the backbones to eval mode.
    """

    def __init__(self, rgb_pth, depth_pth, device, freeze_backbone=False):
        super().__init__()
        self.device = device

        # ---------------- RGB encoder: ResNet18 ----------------
        print(f"üîπ ƒêang load RGB Encoder: ResNet18...")
        resnet = models.resnet18(weights=None)
        smart_load_weights(resnet, rgb_pth, device)

        # Keep every child except the last one (the classification layer).
        resnet_children = list(resnet.children())
        self.rgb_backbone = nn.Sequential(*resnet_children[:-1])
        # ResNet18 features are 512-d (ResNet50 would be 2048-d).
        self.rgb_projector = nn.Linear(512, 512)

        # ---------------- Depth encoder: EfficientNet-B0 ----------------
        print(f"üîπ ƒêang load Depth Encoder: EfficientNet-B0...")
        effnet = models.efficientnet_b0(weights=None)

        # Replace the stem convolution so it accepts 1-channel (grayscale) input.
        # This must happen before loading so the checkpoint shapes line up.
        effnet.features[0][0] = nn.Conv2d(
            in_channels=1,
            out_channels=32,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False,
        )
        smart_load_weights(effnet, depth_pth, device)

        self.depth_features = effnet.features
        self.depth_pool = nn.AdaptiveAvgPool2d(1)
        self.depth_projector = nn.Linear(1280, 512)

        # ---------------- Fusion head ----------------
        # Input is 512 (RGB) + 512 (depth) = 1024 features.
        fusion_layers = [
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 512),
        ]
        self.fusion_head = nn.Sequential(*fusion_layers)

        # Optionally disable gradients on both pretrained backbones.
        if freeze_backbone:
            for backbone in (self.rgb_backbone, self.depth_features):
                for weight in backbone.parameters():
                    weight.requires_grad = False
            print("‚ùÑÔ∏è ƒê√£ ƒë√≥ng bƒÉng Backbone (ResNet18 & EfficientNet).")

        self.to(device)

    def forward_one(self, rgb, depth):
        """Embed one batch of RGB/depth inputs.

        Returns a tensor with one L2-normalised embedding per batch row.
        """
        # RGB stream -> (B, 512), unit length.
        rgb_feat = torch.flatten(self.rgb_backbone(rgb), start_dim=1)
        rgb_vec = F.normalize(self.rgb_projector(rgb_feat), p=2, dim=1)

        # Depth stream -> (B, 1280) -> (B, 512), unit length.
        depth_map = self.depth_features(depth)
        depth_feat = torch.flatten(self.depth_pool(depth_map), start_dim=1)
        depth_vec = F.normalize(self.depth_projector(depth_feat), p=2, dim=1)

        # Fusion -> (B, 1024) -> (B, 512), unit length.
        fused = self.fusion_head(torch.cat((rgb_vec, depth_vec), dim=1))
        return F.normalize(fused, p=2, dim=1)

    def forward(self, a_r, a_d, p_r, p_d, n_r, n_d):
        """Embed an (anchor, positive, negative) triplet.

        Each member is passed through ``forward_one`` separately, in that order.
        """
        members = ((a_r, a_d), (p_r, p_d), (n_r, n_d))
        emb_a, emb_p, emb_n = (self.forward_one(rgb, depth) for rgb, depth in members)
        return emb_a, emb_p, emb_n

# --- Triplet Loss (unchanged) ---
class TripletLoss(nn.Module):
    """Hinge triplet loss on squared Euclidean distances.

    For each row: ``max(0, ||a - p||^2 - ||a - n||^2 + margin)``; the result
    is the mean over the batch.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        """Return the batch-mean triplet loss as a scalar tensor."""
        # Squared L2 distance of each anchor to its positive / negative.
        sq_dist_pos = torch.pow(anchor - positive, 2).sum(dim=1)
        sq_dist_neg = torch.pow(anchor - negative, 2).sum(dim=1)

        # Penalise triplets where the positive is not closer by `margin`.
        violation = sq_dist_pos - sq_dist_neg + self.margin
        return F.relu(violation).mean()


In [2]:
def train(
    batch_size=32,
    lr=0.0001,
    epochs=20,
    data_path="/kaggle/input/people/lfw_processed",
    rgb_pth="/kaggle/input/deeplearn/rgb_encoder_epoch20.pth",
    depth_pth="/kaggle/input/deeplearn-eff/depth_encoder_epoch20.pth",
    unfreeze_epoch=5,
    save_every=5,
    log_every=5,
    num_workers=2,
    seed=None,
):
    """Train the RGB-D fusion model with triplet loss and save its weights.

    Called with no arguments this uses the configuration that was previously
    hard-coded in the function body.

    Args:
        batch_size: Triplets per batch (incomplete last batch is dropped).
        lr: Adam learning rate while the backbones are frozen. After
            unfreezing, a fresh Adam optimiser over all parameters is created
            with ``lr * 0.1``.
        epochs: Number of epochs.
        data_path: Root folder passed to ``RGBDTripletDataset``.
        rgb_pth: ResNet18 checkpoint path.
        depth_pth: EfficientNet-B0 checkpoint path.
        unfreeze_epoch: 0-based epoch index at whose start the backbones are
            unfrozen; ``None`` keeps them frozen for the whole run.
        save_every: Save ``fusion_face_ep{N}.pth`` every this many epochs
            (0 / ``None`` disables periodic checkpoints).
        log_every: Print batch metrics every this many batches
            (0 / ``None`` disables batch logging).
        num_workers: DataLoader worker processes.
        seed: If given, seeds ``random``, NumPy and PyTorch before the dataset
            and model are built.

    Side effects:
        Writes ``train_loss.npy``, ``train_accuracy.npy``,
        ``fusion_face_final.pth`` and the periodic checkpoints to the current
        working directory.

    Raises:
        ValueError: if the dataset holds fewer than two identities, or if it
            is too small to yield a single full batch.

    NOTE(review): ``model.train()`` is called once and the frozen phase only
    disables gradients, so any BatchNorm layers inside the backbones presumably
    keep updating their running statistics while "frozen" — confirm intended.
    """
    # 1. Device / reproducibility
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    train_loss = []
    train_acc = []

    # 2. Data
    dataset = RGBDTripletDataset(data_path)
    if len(dataset.people_names) < 2:
        # A triplet needs a negative from a second identity; fail here rather
        # than inside (or hanging in) a DataLoader worker.
        raise ValueError(
            f"Need at least 2 identities to form triplets, found {len(dataset.people_names)} in {data_path!r}"
        )

    # drop_last=True avoids a trailing 1-sample batch, which BatchNorm cannot
    # process in training mode.
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True
    )
    if len(dataloader) == 0:
        # Previously this surfaced as a ZeroDivisionError at the end of epoch 1.
        raise ValueError(
            f"Dataset has {len(dataset)} samples, fewer than one full batch of {batch_size}"
        )

    # 3. Model (ResNet18 + EfficientNet-B0), backbones frozen at first
    model = FaceModelTrain(
        rgb_pth=rgb_pth,
        depth_pth=depth_pth,
        device=device,
        freeze_backbone=True
    )
    model.train()

    criterion = TripletLoss(margin=1.0)
    # Only the currently trainable parameters (projectors + fusion head).
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)

    # 4. Training loop
    for epoch in range(epochs):
        total_loss = 0
        total_correct = 0  # triplets with dist(a, p) < dist(a, n)
        total_samples = 0

        # Unfreeze the backbones and continue with a 10x smaller learning rate.
        if unfreeze_epoch is not None and epoch == unfreeze_epoch:
            print("üîì Unfreezing Backbones...")
            for param in model.rgb_backbone.parameters():
                param.requires_grad = True
            for param in model.depth_features.parameters():
                param.requires_grad = True
            optimizer = optim.Adam(model.parameters(), lr=lr * 0.1)

        for batch_idx, batch in enumerate(dataloader):
            a_r, a_d = batch["anchor"][0].to(device), batch["anchor"][1].to(device)
            p_r, p_d = batch["positive"][0].to(device), batch["positive"][1].to(device)
            n_r, n_d = batch["negative"][0].to(device), batch["negative"][1].to(device)

            optimizer.zero_grad()

            emb_a, emb_p, emb_n = model(a_r, a_d, p_r, p_d, n_r, n_d)
            loss = criterion(emb_a, emb_p, emb_n)

            loss.backward()
            optimizer.step()

            # Triplet accuracy on the embeddings produced before this step.
            with torch.no_grad():
                dist_pos = (emb_a - emb_p).pow(2).sum(1)
                dist_neg = (emb_a - emb_n).pow(2).sum(1)
                pred_correct = (dist_pos < dist_neg).sum().item()

                total_correct += pred_correct
                total_samples += a_r.size(0)

            total_loss += loss.item()

            if log_every and batch_idx % log_every == 0:
                acc_batch = pred_correct / a_r.size(0)
                print(f"Ep {epoch+1} | Batch {batch_idx} | Loss: {loss.item():.4f} | Acc: {acc_batch:.2%}")

        # Epoch summary (len(dataloader) > 0 is guaranteed by the check above).
        avg_loss = total_loss / len(dataloader)
        avg_acc = total_correct / total_samples if total_samples > 0 else 0
        print(f"===> End Ep {epoch+1} | Avg Loss: {avg_loss:.4f} | Avg Acc: {avg_acc:.2%}")
        train_acc.append(avg_acc)
        train_loss.append(avg_loss)

        if save_every and (epoch + 1) % save_every == 0:
            torch.save(model.state_dict(), f"fusion_face_ep{epoch+1}.pth")

    np.save("train_loss.npy", np.array(train_loss))
    np.save("train_accuracy.npy", np.array(train_acc))
    torch.save(model.state_dict(), "fusion_face_final.pth")
    print("Training Complete!")

In [3]:
# Run the full training loop with the configuration defined by train();
# the log printed below is the captured output of this cell.
train()

Device: cuda
üîÑ ƒêang qu√©t d·ªØ li·ªáu... (T√¨m folder 'dataset_face' v√† 'dataset_depth')
‚úÖ ƒê√£ load: 423 ng∆∞·ªùi.
üîπ ƒêang load RGB Encoder: ResNet18...
üîπ Load /kaggle/input/deeplearn/rgb_encoder_epoch20.pth: Missing keys (FC layers): 2
üîπ ƒêang load Depth Encoder: EfficientNet-B0...
üîπ Load /kaggle/input/deeplearn-eff/depth_encoder_epoch20.pth: Missing keys (FC layers): 2
‚ùÑÔ∏è ƒê√£ ƒë√≥ng bƒÉng Backbone (ResNet18 & EfficientNet).
Ep 1 | Batch 0 | Loss: 0.9039 | Acc: 71.88%
Ep 1 | Batch 5 | Loss: 0.8927 | Acc: 68.75%
Ep 1 | Batch 10 | Loss: 0.7658 | Acc: 68.75%
Ep 1 | Batch 15 | Loss: 0.7631 | Acc: 68.75%
Ep 1 | Batch 20 | Loss: 0.7177 | Acc: 71.88%
Ep 1 | Batch 25 | Loss: 0.6431 | Acc: 68.75%
Ep 1 | Batch 30 | Loss: 0.8421 | Acc: 65.62%
Ep 1 | Batch 35 | Loss: 0.8222 | Acc: 65.62%
Ep 1 | Batch 40 | Loss: 0.8370 | Acc: 50.00%
Ep 1 | Batch 45 | Loss: 0.5648 | Acc: 81.25%
Ep 1 | Batch 50 | Loss: 0.8343 | Acc: 68.75%
Ep 1 | Batch 55 | Loss: 0.7442 | Acc: 75.00%
Ep 1 | 