In [1]:
import os
import random
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF
from torchvision.transforms import RandomResizedCrop
import torchvision.models as models
import torch.nn as nn

class RGBDepthDataset(Dataset):
    """Paired RGB / depth image dataset that yields two augmented views per sample.

    Files in ``rgb_dir`` and ``depth_dir`` are matched by stem (file name with its
    final extension removed); only stems present in both directories are used.
    Each file's real extension is remembered, so the RGB and depth images may use
    different formats (e.g. ``.jpg`` RGB with ``.png`` depth).

    Each item is ``(rgb1, rgb2, depth1, depth2)``: two independently augmented
    views of the same pair, converted with ``TF.to_tensor`` after the RGB image is
    converted to "RGB" mode and the depth image to single-channel "L" mode.

    Args:
        rgb_dir: Directory containing the RGB images.
        depth_dir: Directory containing the depth images.
        image_size: Side length of the square crops returned for each view.
        extensions: File extensions (case-insensitive) treated as images. Other
            files, and hidden files whose name starts with ``.``, are ignored.
    """

    def __init__(self, rgb_dir, depth_dir, image_size=128,
                 extensions=(".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")):
        self.rgb_dir = rgb_dir
        self.depth_dir = depth_dir
        self.image_size = image_size

        allowed = {ext.lower() for ext in extensions}
        # stem -> actual file name, so loading uses each file's real extension.
        self.rgb_files = self._index_images(rgb_dir, allowed)
        self.depth_files = self._index_images(depth_dir, allowed)

        # Stems that appear in both the RGB and the depth directory.
        self.ids = sorted(self.rgb_files.keys() & self.depth_files.keys())

        print("Total matched pairs:", len(self.ids))

    @staticmethod
    def _index_images(directory, allowed_exts):
        """Map each image's stem to its file name for the image files in ``directory``.

        If several files share a stem, the last one in sorted order wins.
        """
        index = {}
        for fname in sorted(os.listdir(directory)):
            if fname.startswith("."):
                continue  # hidden / OS metadata files (e.g. .DS_Store, ._foo.jpg)
            # splitext strips only the last suffix: "img.001.jpg" -> "img.001".
            stem, ext = os.path.splitext(fname)
            if ext.lower() in allowed_exts:
                index[stem] = fname
        return index

    def __len__(self):
        return len(self.ids)

    def augment_pair(self, rgb, depth):
        """Apply the same random geometric augmentation to an RGB/depth pair.

        One crop is sampled from the RGB image's size (area scale 0.8-1.0, aspect
        ratio 0.9-1.1), then applied to both images and resized to
        ``image_size x image_size``. Both images are flipped horizontally together
        with probability 0.5. With probability 0.5 the RGB image alone also gets a
        brightness and contrast jitter (factors drawn from [0.8, 1.2]).

        Returns:
            The augmented ``(rgb, depth)`` pair.
        """
        i, j, h, w = RandomResizedCrop.get_params(
            img=rgb, scale=(0.8, 1.0), ratio=(0.9, 1.1)
        )

        rgb = TF.resized_crop(rgb, i, j, h, w, size=(self.image_size, self.image_size))
        depth = TF.resized_crop(depth, i, j, h, w, size=(self.image_size, self.image_size))

        if random.random() < 0.5:
            rgb = TF.hflip(rgb)
            depth = TF.hflip(depth)

        # Photometric jitter is applied to RGB only; depth values are left as-is.
        if random.random() < 0.5:
            rgb = TF.adjust_brightness(rgb, random.uniform(0.8, 1.2))
            rgb = TF.adjust_contrast(rgb, random.uniform(0.8, 1.2))

        return rgb, depth

    def __getitem__(self, idx):
        """Load pair ``idx`` and return two augmented views as tensors."""
        name = self.ids[idx]
        rgb_path = os.path.join(self.rgb_dir, self.rgb_files[name])
        depth_path = os.path.join(self.depth_dir, self.depth_files[name])

        # `with` closes the underlying file handle; convert() returns a new,
        # fully loaded image that stays valid after the file is closed.
        with Image.open(rgb_path) as img:
            rgb = img.convert("RGB")
        with Image.open(depth_path) as img:
            depth = img.convert("L")

        # Crop parameters are sampled from the RGB image, so the depth map must
        # share its spatial size for the crops of both modalities to stay aligned.
        if depth.size != rgb.size:
            depth = depth.resize(rgb.size, Image.BILINEAR)

        # Two independently augmented views of the same pair.
        rgb1, depth1 = self.augment_pair(rgb, depth)
        rgb2, depth2 = self.augment_pair(rgb, depth)

        rgb1 = TF.to_tensor(rgb1)
        rgb2 = TF.to_tensor(rgb2)
        depth1 = TF.to_tensor(depth1)
        depth2 = TF.to_tensor(depth2)

        return rgb1, rgb2, depth1, depth2


In [2]:
class RGBDepthEncoders(nn.Module):
    def __init__(self):
        super().__init__()
        rgb = models.resnet18(weights=None)
        depth = models.resnet18(weights=None)

        # sửa input depth thành 1 kênh
        depth.conv1 = nn.Conv2d(
            1, 64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False
        )

        rgb.fc = nn.Identity()
        depth.fc = nn.Identity()

        self.rgb_encoder = rgb
        self.depth_encoder = depth

    def forward(self, rgb, depth):
        z_rgb = self.rgb_encoder(rgb)
        z_dep = self.depth_encoder(depth)
        return z_rgb, z_dep


In [3]:
import torch
import torch.nn.functional as F

def simclr_contrastive_loss(z1, z2, temperature=0.1):
    """NT-Xent (SimCLR) contrastive loss between two batches of paired embeddings.

    Row ``i`` of ``z1`` and row ``i`` of ``z2`` form a positive pair. Every other
    row of the concatenated ``[z1; z2]`` batch acts as a negative; each anchor's
    similarity with itself is excluded. Embeddings are L2-normalised, so the
    logits are cosine similarities divided by ``temperature``.

    The loss is computed as a cross-entropy over logits (log-sum-exp), not by
    exponentiating similarities directly. This avoids float overflow, and the
    resulting ``nan``, for small temperatures or low-precision dtypes.

    Args:
        z1: Tensor of shape ``[batch_size, feature_dim]``.
        z2: Tensor with the same shape as ``z1``.
        temperature: Positive softmax temperature.

    Returns:
        Scalar tensor: the loss averaged over all ``2 * batch_size`` anchors.

    Raises:
        ValueError: If ``z1`` and ``z2`` have different shapes.
    """
    if z1.shape != z2.shape:
        raise ValueError(
            f"z1 and z2 must have the same shape, got {tuple(z1.shape)} and {tuple(z2.shape)}"
        )
    batch_size = z1.size(0)

    # Normalise and stack both views: rows [0, B) are z1, rows [B, 2B) are z2.
    z = torch.cat([F.normalize(z1, dim=1), F.normalize(z2, dim=1)], dim=0)  # [2B, D]
    logits = torch.matmul(z, z.T) / temperature  # [2B, 2B]

    # Remove each sample's similarity with itself from the softmax denominator.
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))

    # Anchor i (< B) is paired with i + B, and anchor i + B with i.
    idx = torch.arange(batch_size, device=z.device)
    targets = torch.cat([idx + batch_size, idx])

    return F.cross_entropy(logits, targets)


In [4]:
class RGBDepthTrainer:
    """Optimises an ``RGBDepthEncoders`` model with a multi-view contrastive objective.

    Each ``step`` averages three ``simclr_contrastive_loss`` terms:
      * RGB view 1 vs RGB view 2 (intra-modal),
      * depth view 1 vs depth view 2 (intra-modal),
      * RGB view 1 vs depth view 1 (cross-modal).

    Args:
        device: Device that the model and every batch are moved to.
        lr: Adam learning rate.
        temperature: Temperature passed to every ``simclr_contrastive_loss`` call.
    """

    def __init__(self, device="cuda", lr=1e-4, temperature=0.1):
        self.device = device
        self.temperature = temperature
        self.model = RGBDepthEncoders().to(device)
        self.opt = torch.optim.Adam(self.model.parameters(), lr=lr)

    def step(self, rgb1, rgb2, dep1, dep2):
        """Run one optimisation step on a batch of paired views.

        Args:
            rgb1, rgb2: Two augmented views of the RGB batch.
            dep1, dep2: The corresponding two views of the depth batch.

        Returns:
            The combined loss for this batch as a Python float.
        """
        # Layers whose behaviour differs between train and eval modes must be in
        # training mode here, even if eval() was called earlier.
        self.model.train()

        rgb1, rgb2 = rgb1.to(self.device), rgb2.to(self.device)
        dep1, dep2 = dep1.to(self.device), dep2.to(self.device)

        z_r1, z_d1 = self.model(rgb1, dep1)
        z_r2, z_d2 = self.model(rgb2, dep2)

        t = self.temperature
        loss_rgb = simclr_contrastive_loss(z_r1, z_r2, temperature=t)
        loss_depth = simclr_contrastive_loss(z_d1, z_d2, temperature=t)
        loss_rgb_depth = simclr_contrastive_loss(z_r1, z_d1, temperature=t)

        loss = (loss_rgb + loss_depth + loss_rgb_depth) / 3

        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

        return loss.item()


In [5]:
import os
import torch
from torch.utils.data import DataLoader

# Training script. `RGBDepthDataset` and `RGBDepthTrainer` are defined in the
# cells above.

SEED = 42
NUM_EPOCHS = 20
BATCH_SIZE = 32
NUM_WORKERS = 2
LOG_EVERY = 100   # batches between progress prints
SAVE_EVERY = 5    # epochs between encoder checkpoints

# Seed Python's and torch's RNGs so augmentation, shuffling and weight init are
# repeatable; DataLoader workers derive their own seeds from the torch RNG.
random.seed(SEED)
torch.manual_seed(SEED)

# Fall back to CPU instead of crashing when no GPU is attached.
device = "cuda" if torch.cuda.is_available() else "cpu"

dataset = RGBDepthDataset(
    rgb_dir="/kaggle/input/midasz-face/celeba_rgb_output",
    depth_dir="/kaggle/input/midasz-face/celeba_depth",
    image_size=128
)

loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    drop_last=True,
    pin_memory=(device == "cuda"),  # faster host-to-GPU copies
)
# With drop_last=True a dataset smaller than one batch yields no batches, which
# would otherwise surface later as a ZeroDivisionError in the epoch average.
if len(loader) == 0:
    raise RuntimeError(
        f"No full batches: {len(dataset)} samples with batch_size={BATCH_SIZE} and drop_last=True"
    )

trainer = RGBDepthTrainer(device=device)

save_dir = "/kaggle/working"
os.makedirs(save_dir, exist_ok=True)

for epoch in range(NUM_EPOCHS):
    total_loss = 0.0
    for i, (rgb1, rgb2, dep1, dep2) in enumerate(loader):
        loss = trainer.step(rgb1, rgb2, dep1, dep2)
        total_loss += loss

        # Print a progress line every LOG_EVERY batches.
        if (i + 1) % LOG_EVERY == 0:
            print(f"Epoch {epoch} | Batch {i+1} | Current batch loss: {loss:.4f}")

    avg_loss = total_loss / len(loader)
    print(f"Epoch {epoch} completed — Average loss: {avg_loss:.4f}")

    # Checkpoint both encoders (state dicts only) every SAVE_EVERY epochs.
    if (epoch + 1) % SAVE_EVERY == 0:
        rgb_path = os.path.join(save_dir, f"rgb_encoder_epoch{epoch+1}.pth")
        depth_path = os.path.join(save_dir, f"depth_encoder_epoch{epoch+1}.pth")
        torch.save(trainer.model.rgb_encoder.state_dict(), rgb_path)
        torch.save(trainer.model.depth_encoder.state_dict(), depth_path)
        print(f"Saved RGB encoder to {rgb_path}")
        print(f"Saved Depth encoder to {depth_path}")


Total matched pairs: 50000
Epoch 0 | Batch 100 | Current batch loss: 2.1090
Epoch 0 | Batch 200 | Current batch loss: 1.6554
Epoch 0 | Batch 300 | Current batch loss: 1.5573
Epoch 0 | Batch 400 | Current batch loss: 1.5338
Epoch 0 | Batch 500 | Current batch loss: 1.4564
Epoch 0 | Batch 600 | Current batch loss: 1.2423
Epoch 0 | Batch 700 | Current batch loss: 1.4751
Epoch 0 | Batch 800 | Current batch loss: 1.1288
Epoch 0 | Batch 900 | Current batch loss: 1.2509
Epoch 0 | Batch 1000 | Current batch loss: 1.1142
Epoch 0 | Batch 1100 | Current batch loss: 1.1699
Epoch 0 | Batch 1200 | Current batch loss: 0.9708
Epoch 0 | Batch 1300 | Current batch loss: 0.9642
Epoch 0 | Batch 1400 | Current batch loss: 1.0215
Epoch 0 | Batch 1500 | Current batch loss: 1.0038
Epoch 0 completed — Average loss: 1.3726
Epoch 1 | Batch 100 | Current batch loss: 1.1369
Epoch 1 | Batch 200 | Current batch loss: 0.9516
Epoch 1 | Batch 300 | Current batch loss: 0.8728
Epoch 1 | Batch 400 | Current batch loss: 0.