# Học nhúng ảnh đáy mắt với ArcFace và SupCon
Notebook này hướng dẫn xây dựng pipeline embedding cho ảnh đáy mắt bằng ArcFace Loss và Supervised Contrastive Loss (SupCon).

In [1]:
# Cài đặt thư viện nếu cần
!pip install torchmetrics



In [2]:
import os

import pandas as pd
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import efficientnet_b3

In [3]:
# Dataset for supervised embedding learning
class DRDataset(Dataset):
    """Image dataset indexed by a CSV file.

    Column 0 of the CSV is the image id (``.png`` is appended to form the
    file name inside ``img_dir``); column 1 is the integer class label.

    Args:
        csv_file: path to the CSV index.
        img_dir: directory holding the ``<id>.png`` images.
        transform: optional callable applied to the RGB PIL image.
    """

    def __init__(self, csv_file, img_dir, transform=None):
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx`` of the CSV."""
        # Format instead of ``+`` so ids that pandas parsed as numbers still work.
        img_name = f"{self.data.iloc[idx, 0]}.png"
        label = int(self.data.iloc[idx, 1])
        img_path = os.path.join(self.img_dir, img_name)
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image, so it stays valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [4]:
# Input pipeline: paths, preprocessing (resize -> tensor -> normalize with the
# standard ImageNet mean/std values), dataset and shuffled batch loader.
train_csv = 'aptos2019/train_split.csv'
img_dir = 'aptos2019/train_images'
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
train_dataset = DRDataset(csv_file=train_csv, img_dir=img_dir, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

In [5]:
# Embedding model
class EmbeddingNet(nn.Module):
    """EfficientNet-B3 backbone whose final classifier layer outputs an embedding.

    Args:
        embedding_dim: size of the embedding produced by the backbone.
        num_classes: number of classes for the attached ArcFace head.
        weights: forwarded to ``efficientnet_b3``. ``None`` (default) means
            random initialisation; e.g. ``'IMAGENET1K_V1'`` loads
            ImageNet-pretrained weights for fine-tuning.

    ``forward(x)`` returns the embedding only; ``forward(x, labels)`` returns
    ``(embedding, arcface_logits)``.
    """

    def __init__(self, embedding_dim=128, num_classes=5, weights=None):
        super().__init__()
        self.backbone = efficientnet_b3(weights=weights)
        # Replace the last classifier layer (index 1) with a projection to
        # embedding_dim, keeping its input width.
        self.backbone.classifier[1] = nn.Linear(self.backbone.classifier[1].in_features, embedding_dim)
        # ArcFace is defined in a later cell; the name is only resolved when
        # an EmbeddingNet is instantiated.
        self.arcface = ArcFace(embedding_dim, num_classes)

    def forward(self, x, labels=None):
        emb = self.backbone(x)
        if labels is not None:
            # Margin logits need the ground-truth labels.
            logits = self.arcface(emb, labels)
            return emb, logits
        return emb

## Sử dụng EfficientNet pretrained và fine-tune cho embedding
Bạn có thể sử dụng EfficientNet với trọng số pretrained (ImageNet) để cải thiện chất lượng embedding. Sau đó fine-tune trên dữ liệu của bạn.
- Sử dụng weights='IMAGENET1K_V1' khi khởi tạo backbone.
- Có thể freeze các layer đầu, chỉ train các layer cuối và lớp embedding.
- Áp dụng ArcFace hoặc SupCon như pipeline hiện tại.
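Cách làm: khi khởi tạo backbone trong `EmbeddingNet`, truyền `weights='IMAGENET1K_V1'` cho `efficientnet_b3` thay vì `weights=None`. Transform ở trên đã chuẩn hóa ảnh theo mean/std của ImageNet nên khớp với trọng số pretrained. Để freeze một phần backbone, đặt `requires_grad = False` cho các tham số cần giữ cố định.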

In [6]:
# ArcFace (additive angular margin) head
import math


class ArcFace(nn.Module):
    """Additive angular margin head (ArcFace).

    For each sample, non-target classes get logit ``s * cos(theta_j)`` and the
    ground-truth class gets ``s * cos(theta_y + m)``. Intended to be fed to
    ``nn.CrossEntropyLoss``.

    Args:
        embedding_dim: size of the input embeddings.
        num_classes: number of classes (columns of ``W``).
        s: logit scale.
        m: additive angular margin in radians.
        eps: cosines used for the margin term are clamped to
            ``[-1 + eps, 1 - eps]`` so the gradient stays finite.
    """

    def __init__(self, embedding_dim, num_classes, s=30.0, m=0.50, eps=1e-7):
        super().__init__()
        # One column per class; columns are L2-normalised in forward().
        self.W = nn.Parameter(torch.randn(embedding_dim, num_classes))
        self.s = s
        self.m = m
        self.eps = eps
        # cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m)
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Once theta > pi - m, cos(theta + m) starts increasing again, which
        # would reward a larger angle. Past that threshold fall back to the
        # linear penalty cos(theta) - mm so the target logit stays monotone.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, emb, labels):
        """Return scaled margin logits of shape ``(batch, num_classes)``.

        Args:
            emb: embeddings, ``(batch, embedding_dim)``.
            labels: integer class index per sample, ``(batch,)``.
        """
        emb_norm = nn.functional.normalize(emb, dim=1)
        W_norm = nn.functional.normalize(self.W, dim=0)
        cosine = torch.matmul(emb_norm, W_norm)
        # Clamp strictly inside (-1, 1): the sine below has an infinite
        # derivative at |cos| = 1.
        cos_c = cosine.clamp(-1.0 + self.eps, 1.0 - self.eps)
        sine = torch.sqrt(1.0 - cos_c * cos_c)
        phi = cos_c * self.cos_m - sine * self.sin_m
        phi = torch.where(cos_c > self.th, phi, cos_c - self.mm)
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, labels.view(-1, 1), 1)
        # Margin only on the ground-truth column; other columns keep cos(theta).
        output = cosine * (1 - one_hot) + phi * one_hot
        return output * self.s

In [7]:
# SupCon Loss implementation
class SupConLoss(nn.Module):
    """Supervised contrastive loss with a single view per sample.

    For anchor ``i``, positives are the other samples in the batch with the
    same label; the softmax denominator runs over every sample except ``i``.

    Args:
        temperature: divisor applied to the cosine similarities.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        """Return the scalar loss averaged over all anchors in the batch.

        Args:
            features: embeddings, ``(batch, dim)``; L2-normalised here.
            labels: class label per sample, ``(batch,)``.

        Anchors that have no positive in the batch contribute 0 to the sum
        but are still counted in the mean.
        """
        features = nn.functional.normalize(features, dim=1)
        batch_size = features.shape[0]
        pos_mask = torch.eq(labels.unsqueeze(1), labels.unsqueeze(0)).float()
        not_self = torch.ones_like(pos_mask) - torch.eye(batch_size, device=pos_mask.device)
        pos_mask = pos_mask * not_self
        sim = torch.matmul(features, features.T) / self.temperature
        # Similarities are at most 1/T, and exp(1/T) overflows float32 once
        # T < ~0.011. Subtracting the row max keeps every exponent <= 0;
        # log-softmax is invariant to this shift.
        sim = sim - sim.max(dim=1, keepdim=True).values.detach()
        denom = (torch.exp(sim) * not_self).sum(1, keepdim=True)
        # The denominator is 0 only for a 1-sample batch; clamp keeps log finite.
        log_prob = sim - torch.log(denom.clamp_min(torch.finfo(denom.dtype).tiny))
        mean_log_prob_pos = (pos_mask * log_prob).sum(1) / (pos_mask.sum(1) + 1e-8)
        return -mean_log_prob_pos.mean()

In [8]:
from PIL import Image

In [9]:
# Train the embedding with the ArcFace head (margin logits + cross-entropy)
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = EmbeddingNet(embedding_dim=128, num_classes=5).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
EPOCHS = 50
dataset_size = len(train_loader.dataset)  # constant across epochs
for epoch in range(EPOCHS):
    model.train()
    loss_total = 0.0
    for batch_images, batch_labels in train_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)
        optimizer.zero_grad()
        # Passing labels makes the model also return the ArcFace logits;
        # the raw embedding is not needed for this loss.
        _, margin_logits = model(batch_images, batch_labels)
        batch_loss = criterion(margin_logits, batch_labels)
        batch_loss.backward()
        optimizer.step()
        # The criterion averages over the batch; re-weight by batch size so
        # the epoch figure is a per-sample mean.
        loss_total += batch_loss.item() * batch_images.size(0)
    epoch_loss = loss_total / dataset_size
    print(f'Epoch {epoch+1}/{EPOCHS} - ArcFace Loss: {epoch_loss:.4f}')

Epoch 1/50 - ArcFace Loss: 15.9791
Epoch 2/50 - ArcFace Loss: 12.8158
Epoch 3/50 - ArcFace Loss: 10.7569
Epoch 4/50 - ArcFace Loss: 10.1474
Epoch 5/50 - ArcFace Loss: 9.4278
Epoch 6/50 - ArcFace Loss: 9.1020
Epoch 7/50 - ArcFace Loss: 8.3536
Epoch 8/50 - ArcFace Loss: 8.2111
Epoch 9/50 - ArcFace Loss: 7.8227
Epoch 10/50 - ArcFace Loss: 7.0061
Epoch 11/50 - ArcFace Loss: 6.9451
Epoch 12/50 - ArcFace Loss: 6.6442
Epoch 13/50 - ArcFace Loss: 6.3533
Epoch 14/50 - ArcFace Loss: 6.0204
Epoch 15/50 - ArcFace Loss: 6.1441
Epoch 16/50 - ArcFace Loss: 5.6250
Epoch 17/50 - ArcFace Loss: 5.2926
Epoch 18/50 - ArcFace Loss: 5.2515
Epoch 19/50 - ArcFace Loss: 5.0078
Epoch 20/50 - ArcFace Loss: 5.1216
Epoch 21/50 - ArcFace Loss: 4.8258
Epoch 22/50 - ArcFace Loss: 4.7356
Epoch 23/50 - ArcFace Loss: 4.6181
Epoch 24/50 - ArcFace Loss: 4.9038
Epoch 25/50 - ArcFace Loss: 4.6366
Epoch 26/50 - ArcFace Loss: 4.9343
Epoch 27/50 - ArcFace Loss: 4.1154
Epoch 28/50 - ArcFace Loss: 4.7608
Epoch 29/50 - ArcFace Los

In [11]:
# Save the ArcFace-trained model's state_dict (parameters and buffers only)
torch.save(obj=model.state_dict(), f='arcface_embedding.pth')

In [None]:
# Reload the ArcFace-trained model for later use (e.g. extracting embeddings).
# eval() switches to inference behaviour (affects layers such as dropout /
# batch-norm if present); the training loops call model.train() themselves.
model = EmbeddingNet(embedding_dim=128, num_classes=5).to(device).eval()
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model.load_state_dict(torch.load('arcface_embedding.pth', map_location=device))

In [10]:
# Train the embedding with the supervised contrastive (SupCon) loss
model = EmbeddingNet(embedding_dim=128, num_classes=5).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = SupConLoss(temperature=0.07)
EPOCHS = 50
# NOTE(review): forward is called without labels, so the ArcFace head inside
# EmbeddingNet gets no gradient here; its parameter is simply never updated.
# NOTE(review): SupCon is usually trained on two augmented views per image;
# here each image appears once per batch and the transform has no augmentation.
dataset_size = len(train_loader.dataset)  # constant across epochs
for epoch in range(EPOCHS):
    model.train()
    loss_total = 0.0
    for batch_images, batch_labels in train_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)
        optimizer.zero_grad()
        embeddings = model(batch_images)  # no labels -> embedding only
        batch_loss = criterion(embeddings, batch_labels)
        batch_loss.backward()
        optimizer.step()
        # Re-weight the batch mean by batch size for a per-sample epoch mean.
        loss_total += batch_loss.item() * batch_images.size(0)
    epoch_loss = loss_total / dataset_size
    print(f'Epoch {epoch+1}/{EPOCHS} - SupCon Loss: {epoch_loss:.4f}')

Epoch 1/50 - SupCon Loss: 4.1446
Epoch 2/50 - SupCon Loss: 3.4225
Epoch 3/50 - SupCon Loss: 3.4001
Epoch 4/50 - SupCon Loss: 3.3687
Epoch 5/50 - SupCon Loss: 3.3413
Epoch 6/50 - SupCon Loss: 3.3308
Epoch 7/50 - SupCon Loss: 3.2836
Epoch 8/50 - SupCon Loss: 3.2661
Epoch 9/50 - SupCon Loss: 3.3413
Epoch 10/50 - SupCon Loss: 3.2022
Epoch 11/50 - SupCon Loss: 3.2203
Epoch 12/50 - SupCon Loss: 3.1671
Epoch 13/50 - SupCon Loss: 3.1306
Epoch 14/50 - SupCon Loss: 3.0477
Epoch 15/50 - SupCon Loss: 3.0233
Epoch 16/50 - SupCon Loss: 2.9469
Epoch 17/50 - SupCon Loss: 3.0129
Epoch 18/50 - SupCon Loss: 2.9960
Epoch 19/50 - SupCon Loss: 2.9518
Epoch 20/50 - SupCon Loss: 2.9300
Epoch 21/50 - SupCon Loss: 2.9482
Epoch 22/50 - SupCon Loss: 2.8584
Epoch 23/50 - SupCon Loss: 2.8409
Epoch 24/50 - SupCon Loss: 2.7890
Epoch 25/50 - SupCon Loss: 2.7965
Epoch 26/50 - SupCon Loss: 2.8126
Epoch 27/50 - SupCon Loss: 2.8332
Epoch 28/50 - SupCon Loss: 2.7654
Epoch 29/50 - SupCon Loss: 2.8115
Epoch 30/50 - SupCon Lo

In [12]:
# Save the SupCon-trained model's state_dict (parameters and buffers only)
torch.save(obj=model.state_dict(), f='supcon_embedding.pth')

In [None]:
# Reload the SupCon-trained model for later use (e.g. extracting embeddings).
# eval() switches to inference behaviour (affects layers such as dropout /
# batch-norm if present); the training loops call model.train() themselves.
model = EmbeddingNet(embedding_dim=128, num_classes=5).to(device).eval()
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model.load_state_dict(torch.load('supcon_embedding.pth', map_location=device))