In [None]:
from datasets import load_dataset

# Loading this dataset requires Hugging Face authentication (e.g. `huggingface-cli login`).
ds = load_dataset(path="huyhamhoc/top4000")

In [None]:
import numpy as np

# The training split's `char_name` column as a NumPy array (one entry per record).
np_char_name = np.asarray(ds['train']['char_name'])


In [None]:
# Sorted distinct character names define the label space: name -> index in that order.
unique_characters = np.unique(np_char_name)
character_to_label = {char: i for i, char in enumerate(unique_characters)}
# Sanity check: label of the first training record. Reuse the array built above instead
# of materializing the whole `char_name` column from the dataset a second time.
character_to_label[np_char_name[0]]

In [None]:
# Number of distinct characters (np.unique returns a 1-D array, so size == length).
unique_characters.size

In [None]:
import os
import pandas as pd
import torch
import torch.nn as nn
import timm
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms
import torch.optim.lr_scheduler as lr_scheduler


class AnimeDataset(Dataset):
    """Map-style dataset returning ``(image, label)`` pairs from a record collection.

    Each record of ``dataset`` is read as a mapping with an ``"image"`` field (an
    image object with ``.mode``/``.convert``, or a file path string) and a
    ``"char_name"`` field. The name is mapped to an integer label.

    Args:
        dataset: indexable, sized collection of records (e.g. ``ds['train']``).
        transform: optional callable applied to the RGB image.
        label_map: optional dict mapping character name -> integer label. When omitted,
            the notebook-global ``character_to_label`` is used, as before.
    """

    def __init__(self, dataset, transform=None, label_map=None):
        self.dataset = dataset
        self.transform = transform
        # Prefer an explicit mapping so the class does not silently depend on notebook
        # state; fall back to the global so existing calls keep working unchanged.
        self.character_to_label = character_to_label if label_map is None else label_map

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for record ``idx``.

        A path-string image is opened from disk. Any non-RGB image is converted to RGB.
        The transform, if one was supplied, is applied last.
        """
        item = self.dataset[idx]
        image = item["image"]
        label = self.character_to_label[item["char_name"]]

        if isinstance(image, str):
            image = Image.open(image).convert("RGB")

        # Force 3-channel input (e.g. grayscale or RGBA sources).
        if image.mode != "RGB":
            image = image.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label

    

# Normalization matching facenet-pytorch's `fixed_image_standardization`, i.e.
# (pixel - 127.5) / 128 on 0-255 values. This is the preprocessing the VGGFace2
# InceptionResnetV1 weights loaded later were trained with (ImageNet statistics are not).
# ToTensor() yields values in [0, 1], so the equivalent mean/std are 0.5 and 128/255.
FACENET_MEAN = [0.5, 0.5, 0.5]
FACENET_STD = [128.0 / 255.0] * 3


def add_gaussian_noise(img, std=0.1):
    """Return ``img`` plus element-wise Gaussian noise N(0, std**2) of the same shape."""
    return img + torch.randn_like(img) * std


# Training transforms: geometric and photometric augmentation, occlusion, noise.
transform_train = transforms.Compose([
    # Geometric augmentation
    transforms.RandomRotation(20),              # random rotation within ±20 degrees (head tilt)
    transforms.RandomHorizontalFlip(p=0.5),     # horizontal flip half of the time (face symmetry)
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),  # crop 80-100% of the area, resize to 224x224 (zoom)

    # Color / lighting augmentation
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),  # brightness, contrast, saturation, hue
    transforms.RandomGrayscale(p=0.2),          # grayscale 20% of the time (poor lighting)
    transforms.ToTensor(),

    # Occlusion: erase a random rectangle (RandomErasing works on tensors, hence after ToTensor)
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.2), ratio=(0.3, 3.3), value='random'),

    # Additive Gaussian noise
    transforms.Lambda(add_gaussian_noise),

    # Normalize with the statistics the pretrained FaceNet weights expect
    transforms.Normalize(mean=FACENET_MEAN, std=FACENET_STD),
])

# Validation transforms: deterministic resize plus the same normalization as training.
transform_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=FACENET_MEAN, std=FACENET_STD),
])




In [None]:
# Wrap the training split with training augmentations and inspect one sample as a sanity check.
dataset = AnimeDataset(ds['train'], transform=transform_train)
print(f"Tổng số mẫu: {len(dataset)}")

image, label = dataset[0]
print(f"Loại dữ liệu image: {type(image)}")
print(f"Kích thước image: {image.shape if torch.is_tensor(image) else 'Không phải tensor'}")
print(f"Nhãn: {label}")


In [None]:
import torch
from torch.utils.data import Subset, random_split

# Two views of the same split: augmented for training, deterministic for validation.
# Validation must not see random erasing/noise/flips, or its metrics are not comparable
# across epochs and understate real accuracy.
train_base = AnimeDataset(ds['train'], transform=transform_train)
val_base = AnimeDataset(ds['train'], transform=transform_val)
dataset = train_base  # same object later cells saw under this name before
total_size = len(dataset)

# 80% train / 20% validation
train_ratio = 0.8
train_size = int(train_ratio * total_size)
val_size = total_size - train_size  # remainder goes to validation

# Draw the index split once (seeded for reproducibility) and apply it to both views, so
# train and validation are disjoint while each keeps its own transform.
split_generator = torch.Generator().manual_seed(42)
train_split, val_split = random_split(range(total_size), [train_size, val_size],
                                      generator=split_generator)
train_dataset = Subset(train_base, train_split.indices)
val_dataset = Subset(val_base, val_split.indices)

print(f"Train size: {len(train_dataset)}, Val size: {len(val_dataset)}")


In [None]:
import hashlib
import os
import shutil
import sys
import tempfile

from urllib.request import urlopen, Request

try:
    from tqdm.auto import tqdm  # automatically select proper tqdm submodule if available
except ImportError:
    from tqdm import tqdm


def download_url_to_file(url, dst, hash_prefix=None, progress=True):
    r"""Download object at the given URL to a local path.

    Adapted from ``torch.hub.download_url_to_file``. The payload is streamed into a
    temporary file in the destination directory. It is moved to ``dst`` only after the
    download (and optional hash check) succeeds, so a broken download never overwrites an
    existing file.

    Args:
        url (string): URL of the object to download
        dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file`
        hash_prefix (string, optional): If not None, the SHA256 of the downloaded file
            must start with `hash_prefix`. Default: None
        progress (bool, optional): whether or not to display a progress bar to stderr.
            Default: True

    Raises:
        RuntimeError: if ``hash_prefix`` is given and the SHA256 digest does not match.

    Example:
        >>> download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')
    """
    req = Request(url, headers={"User-Agent": "torch.hub"})
    # The response is a context manager, so the connection is closed on every path,
    # including a failed hash check or a write error.
    with urlopen(req) as u:
        meta = u.info()
        if hasattr(meta, 'getheaders'):  # legacy header API
            content_length = meta.getheaders("Content-Length")
        else:
            content_length = meta.get_all("Content-Length")
        file_size = int(content_length[0]) if content_length else None

        # Save into a temp file next to dst and move it only after the download completes,
        # so a local working checkpoint is never replaced by a broken download.
        dst = os.path.expanduser(dst)
        dst_dir = os.path.dirname(dst)
        f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)

        try:
            sha256 = hashlib.sha256() if hash_prefix is not None else None
            with tqdm(total=file_size, disable=not progress,
                      unit='B', unit_scale=True, unit_divisor=1024) as pbar:
                while True:
                    buffer = u.read(8192)
                    if not buffer:
                        break
                    f.write(buffer)
                    if sha256 is not None:
                        sha256.update(buffer)
                    pbar.update(len(buffer))

            f.close()
            if sha256 is not None:
                digest = sha256.hexdigest()
                if digest[:len(hash_prefix)] != hash_prefix:
                    raise RuntimeError('invalid hash value (expected "{}", got "{}")'
                                       .format(hash_prefix, digest))
            shutil.move(f.name, dst)
        finally:
            f.close()
            if os.path.exists(f.name):
                os.remove(f.name)


import os
import requests
from requests.adapters import HTTPAdapter

import torch
from torch import nn
from torch.nn import functional as F



class BasicConv2d(nn.Module):
    """Conv2d (bias-free) -> BatchNorm2d -> ReLU building block.

    The conv omits its bias; the affine BatchNorm that follows has its own shift term.
    BatchNorm uses eps=0.001 (the value used by the TensorFlow original) and momentum=0.1
    (the PyTorch default).
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_kwargs)
        self.bn = nn.BatchNorm2d(out_planes, eps=0.001, momentum=0.1, affine=True)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


class Block35(nn.Module):
    """Residual Inception block on 256-channel feature maps.

    Three parallel branches (1x1; 1x1->3x3; 1x1->3x3->3x3), each with 32 output channels,
    are concatenated (96 channels) and projected back to 256 channels by a 1x1 conv. The
    result is scaled by ``scale``, added to the input, and passed through ReLU.
    """

    def __init__(self, scale=1.0):
        super().__init__()

        self.scale = scale
        width = 256

        self.branch0 = BasicConv2d(width, 32, kernel_size=1, stride=1)
        self.branch1 = nn.Sequential(
            BasicConv2d(width, 32, kernel_size=1, stride=1),
            BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv2d(width, 32, kernel_size=1, stride=1),
            *(BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1) for _ in range(2)),
        )
        self.conv2d = nn.Conv2d(3 * 32, width, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        branches = (self.branch0, self.branch1, self.branch2)
        mixed = torch.cat([branch(x) for branch in branches], dim=1)
        residual = self.conv2d(mixed)
        return self.relu(residual * self.scale + x)


class Block17(nn.Module):
    """Residual Inception block on 896-channel feature maps.

    Two branches (1x1; 1x1->1x7->7x1), each with 128 channels, are concatenated (256
    channels) and projected back to 896 channels by a 1x1 conv. The result is scaled by
    ``scale``, added to the input, and passed through ReLU.
    """

    def __init__(self, scale=1.0):
        super().__init__()

        self.scale = scale
        channels = 896

        self.branch0 = BasicConv2d(channels, 128, kernel_size=1, stride=1)
        self.branch1 = nn.Sequential(
            BasicConv2d(channels, 128, kernel_size=1, stride=1),
            BasicConv2d(128, 128, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(128, 128, kernel_size=(7, 1), stride=1, padding=(3, 0)),
        )
        self.conv2d = nn.Conv2d(2 * 128, channels, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        concatenated = torch.cat([self.branch0(x), self.branch1(x)], dim=1)
        return self.relu(self.conv2d(concatenated) * self.scale + x)


class Block8(nn.Module):
    """Residual Inception block on 1792-channel feature maps.

    Two branches (1x1; 1x1->1x3->3x1), each with 192 channels, are concatenated (384
    channels) and projected back to 1792 channels by a 1x1 conv. The result is scaled by
    ``scale`` and added to the input. A final ReLU is applied unless ``noReLU`` is True, in
    which case no ``relu`` submodule is created.
    """

    def __init__(self, scale=1.0, noReLU=False):
        super().__init__()

        self.scale = scale
        self.noReLU = noReLU
        channels = 1792

        self.branch0 = BasicConv2d(channels, 192, kernel_size=1, stride=1)
        self.branch1 = nn.Sequential(
            BasicConv2d(channels, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 192, kernel_size=(1, 3), stride=1, padding=(0, 1)),
            BasicConv2d(192, 192, kernel_size=(3, 1), stride=1, padding=(1, 0)),
        )
        self.conv2d = nn.Conv2d(2 * 192, channels, kernel_size=1, stride=1)
        if not noReLU:
            self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        merged = torch.cat([self.branch0(x), self.branch1(x)], dim=1)
        out = self.conv2d(merged) * self.scale + x
        return out if self.noReLU else self.relu(out)


class Mixed_6a(nn.Module):
    """Reduction block: 256 -> 896 channels (384 + 256 + 256) with spatial stride 2.

    Branches: a stride-2 3x3 conv; a 1x1 -> 3x3 -> stride-2 3x3 conv stack; a stride-2
    3x3 max-pool. Their outputs are concatenated along the channel dimension.
    """

    def __init__(self):
        super().__init__()

        self.branch0 = BasicConv2d(256, 384, 3, 2)
        self.branch1 = nn.Sequential(
            BasicConv2d(256, 192, 1, 1),
            BasicConv2d(192, 192, 3, 1, 1),
            BasicConv2d(192, 256, 3, 2),
        )
        self.branch2 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        pieces = tuple(branch(x) for branch in (self.branch0, self.branch1, self.branch2))
        return torch.cat(pieces, dim=1)


class Mixed_7a(nn.Module):
    """Reduction block: 896 -> 1792 channels (384 + 256 + 256 + 896) with spatial stride 2.

    Three conv branches, each starting with a 1x1 conv down to 256 channels, end in a
    stride-2 3x3 conv. A fourth branch is a stride-2 3x3 max-pool. Outputs are concatenated
    along the channel dimension.
    """

    def __init__(self):
        super().__init__()

        def reduce_1x1():
            # 1x1 bottleneck from the 896-channel input; first layer of every conv branch.
            return BasicConv2d(896, 256, kernel_size=1, stride=1)

        self.branch0 = nn.Sequential(reduce_1x1(), BasicConv2d(256, 384, kernel_size=3, stride=2))
        self.branch1 = nn.Sequential(reduce_1x1(), BasicConv2d(256, 256, kernel_size=3, stride=2))
        self.branch2 = nn.Sequential(
            reduce_1x1(),
            BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
            BasicConv2d(256, 256, kernel_size=3, stride=2),
        )
        self.branch3 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        outputs = [branch(x) for branch in branches]
        return torch.cat(outputs, dim=1)


class InceptionResnetV1(nn.Module):
    """Inception Resnet V1 model with optional loading of pretrained weights.

    Model parameters can be loaded based on pretraining on the VGGFace2 or CASIA-Webface
    datasets. Pretrained state_dicts are downloaded on model instantiation if requested and
    cached in the torch cache. Subsequent instantiations use the cache rather than
    redownloading.

    Keyword Arguments:
        pretrained {str} -- Optional pretraining dataset. Either 'vggface2' or 'casia-webface'.
            Any other non-None value raises ValueError. (default: {None})
        classify {bool} -- Whether the model should output classification logits (raw output
            of the final linear layer; no softmax is applied) or L2-normalized feature
            embeddings. (default: {False})
        num_classes {int} -- Number of output classes. If 'pretrained' is set and num_classes not
            equal to that used for the pretrained model, the final linear layer will be randomly
            initialized. (default: {None})
        dropout_prob {float} -- Dropout probability. (default: {0.6})
        device {torch.device} -- Optional device to move the model to after construction.
            (default: {None}, i.e. CPU)

    Raises:
        ValueError -- If 'pretrained' is not None, 'vggface2' or 'casia-webface'.
        Exception -- If 'pretrained' is None, 'classify' is True and 'num_classes' is None.
    """
    def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_prob=0.6, device=None):
        super().__init__()

        # Set simple attributes
        self.pretrained = pretrained
        self.classify = classify
        self.num_classes = num_classes

        # Number of classes of the pretrained checkpoint's `logits` layer; it must exist
        # with this size so the checkpoint's state_dict loads without key/shape mismatches.
        if pretrained == 'vggface2':
            tmp_classes = 8631
        elif pretrained == 'casia-webface':
            tmp_classes = 10575
        elif pretrained is not None:
            # Fail fast with the same error load_weights would give, instead of hitting
            # an unbound `tmp_classes` further down.
            raise ValueError('Pretrained models only exist for "vggface2" and "casia-webface"')
        elif self.classify and self.num_classes is None:
            raise Exception('If "pretrained" is not specified and "classify" is True, "num_classes" must be specified')

        # Define layers
        self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
        self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.maxpool_3a = nn.MaxPool2d(3, stride=2)
        self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
        self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
        self.conv2d_4b = BasicConv2d(192, 256, kernel_size=3, stride=2)
        self.repeat_1 = nn.Sequential(
            Block35(scale=0.17),
            Block35(scale=0.17),
            Block35(scale=0.17),
            Block35(scale=0.17),
            Block35(scale=0.17),
        )
        self.mixed_6a = Mixed_6a()
        self.repeat_2 = nn.Sequential(
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
        )
        self.mixed_7a = Mixed_7a()
        self.repeat_3 = nn.Sequential(
            Block8(scale=0.20),
            Block8(scale=0.20),
            Block8(scale=0.20),
            Block8(scale=0.20),
            Block8(scale=0.20),
        )
        self.block8 = Block8(noReLU=True)
        self.avgpool_1a = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(dropout_prob)
        self.last_linear = nn.Linear(1792, 512, bias=False)
        self.last_bn = nn.BatchNorm1d(512, eps=0.001, momentum=0.1, affine=True)

        if pretrained is not None:
            self.logits = nn.Linear(512, tmp_classes)
            load_weights(self, pretrained)

        # A caller-specified class count replaces the pretrained head with a fresh layer.
        if self.classify and self.num_classes is not None:
            self.logits = nn.Linear(512, self.num_classes)

        self.device = torch.device('cpu')
        if device is not None:
            self.device = device
            self.to(device)

    def forward(self, x):
        """Calculate embeddings or logits given a batch of input image tensors.

        Arguments:
            x {torch.tensor} -- Batch of image tensors representing faces.

        Returns:
            torch.tensor -- (batch, 512) L2-normalized embeddings when ``classify`` is False,
                otherwise the output of ``self.logits`` (unnormalized logits).
        """
        x = self.conv2d_1a(x)
        x = self.conv2d_2a(x)
        x = self.conv2d_2b(x)
        x = self.maxpool_3a(x)
        x = self.conv2d_3b(x)
        x = self.conv2d_4a(x)
        x = self.conv2d_4b(x)
        x = self.repeat_1(x)
        x = self.mixed_6a(x)
        x = self.repeat_2(x)
        x = self.mixed_7a(x)
        x = self.repeat_3(x)
        x = self.block8(x)
        x = self.avgpool_1a(x)
        x = self.dropout(x)
        x = self.last_linear(x.view(x.shape[0], -1))
        x = self.last_bn(x)
        if self.classify:
            x = self.logits(x)
        else:
            x = F.normalize(x, p=2, dim=1)
        return x


def load_weights(mdl, name):
    """Download pretrained state_dict (if not already cached) and load it into a model.

    The checkpoint is cached at ``<torch home>/checkpoints/<file name from the URL>`` and
    reused by later calls instead of being downloaded again.

    Arguments:
        mdl {torch.nn.Module} -- Pytorch model.
        name {str} -- Name of dataset that was used to generate pretrained state_dict.

    Raises:
        ValueError: If 'pretrained' not equal to 'vggface2' or 'casia-webface'.
    """
    if name == 'vggface2':
        path = 'https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180402-114759-vggface2.pt'
    elif name == 'casia-webface':
        path = 'https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180408-102900-casia-webface.pt'
    else:
        raise ValueError('Pretrained models only exist for "vggface2" and "casia-webface"')

    model_dir = os.path.join(get_torch_home(), 'checkpoints')
    os.makedirs(model_dir, exist_ok=True)

    cached_file = os.path.join(model_dir, os.path.basename(path))
    if not os.path.exists(cached_file):
        download_url_to_file(path, cached_file)

    # Load onto CPU so the checkpoint opens on machines without the device it was saved
    # from; load_state_dict then copies the values into mdl's parameters wherever they live.
    # NOTE(review): torch.load unpickles the file; on PyTorch versions that support it,
    # consider weights_only=True for downloaded checkpoints.
    state_dict = torch.load(cached_file, map_location='cpu')
    mdl.load_state_dict(state_dict)


def get_torch_home():
    """Return the torch cache directory with ``~`` expanded.

    Uses ``$TORCH_HOME`` when set. Otherwise falls back to ``$XDG_CACHE_HOME/torch``, or to
    ``~/.cache/torch`` when ``XDG_CACHE_HOME`` is unset.
    """
    cache_root = os.getenv('XDG_CACHE_HOME', '~/.cache')
    fallback = os.path.join(cache_root, 'torch')
    return os.path.expanduser(os.environ.get('TORCH_HOME', fallback))

In [None]:
import torch
import torch.nn as nn
from torch.optim import lr_scheduler


def _run_epoch(model, loader, criterion, device, optimizer=None, topk=5):
    """Run one pass over ``loader``; return (mean batch loss, top-1 %, top-k %).

    With an ``optimizer`` the model is put in train mode and updated after every batch.
    Without one it runs in eval mode with gradients disabled.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum = 0.0
    n_correct = 0
    n_topk_correct = 0
    n_seen = 0

    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            n_seen += labels.size(0)
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()

            # Vectorized top-k hit test: one device sync per batch instead of one per
            # sample. k is capped so topk() does not fail with fewer than k classes.
            k = min(topk, outputs.size(1))
            topk_pred = outputs.topk(k, dim=1).indices
            n_topk_correct += (topk_pred == labels.unsqueeze(1)).any(dim=1).sum().item()

    # Guard against an empty loader instead of dividing by zero.
    mean_loss = loss_sum / max(len(loader), 1)
    denom = max(n_seen, 1)
    return mean_loss, n_correct / denom * 100, n_topk_correct / denom * 100


torch.manual_seed(42)  # reproducible shuffling and classifier-head initialization

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Start from VGGFace2-pretrained weights and swap the classifier head for one sized to
# this dataset's characters.
model = InceptionResnetV1(pretrained='vggface2')
num_classes = len(unique_characters)  # number of distinct characters in the dataset
model.classify = True  # output logits instead of embeddings
model.logits = nn.Linear(model.logits.in_features, num_classes)
# Move to the target device before building the optimizer so it tracks the final params.
model = model.to(device)

num_epochs = 15
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.0005)
# One cosine half-period over the whole run. A T_max shorter than num_epochs would make
# the learning rate climb back up during the final epochs.
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

for epoch in range(num_epochs):
    train_loss, train_accuracy, train_top5_accuracy = _run_epoch(
        model, train_loader, criterion, device, optimizer=optimizer)
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, "
          f"Train Accuracy: {train_accuracy:.2f}%, Train Top-5 Accuracy: {train_top5_accuracy:.2f}%")

    val_loss, val_accuracy, val_top5_accuracy = _run_epoch(
        model, val_loader, criterion, device)
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%, "
          f"Validation Top-5 Accuracy: {val_top5_accuracy:.2f}%")

    scheduler.step()

print("Fine-tuning complete.")
# Switch back to embedding mode so the saved model outputs L2-normalized embeddings.
model.classify = False
# NOTE(review): this pickles the whole module, so loading it requires the class
# definitions above; `torch.save(model.state_dict(), ...)` is the more portable format.
torch.save(model, 'fine_tuned_facenet.pth')
