In [1]:
!pip install ultralytics torchtoolbox bitsandbytes

Collecting ultralytics
  Downloading ultralytics-8.3.203-py3-none-any.whl.metadata (37 kB)
Collecting torchtoolbox
  Downloading torchtoolbox-0.1.8.2-py3-none-any.whl.metadata (15 kB)
Collecting bitsandbytes
  Downloading bitsandbytes-0.47.0-py3-none-manylinux_2_24_x86_64.whl.metadata (11 kB)
Collecting ultralytics-thop>=2.0.0 (from ultralytics)
  Downloading ultralytics_thop-2.0.17-py3-none-any.whl.metadata (14 kB)
Collecting lmdb (from torchtoolbox)
  Downloading lmdb-1.7.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.3 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.8.0->ultralytics)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.8.0->ultralytics)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.8.0->ultralytic

In [2]:
import torch
import torch.nn as nn
from ultralytics import YOLO
from torch.nn import MultiheadAttention
from torchvision import transforms
from PIL import Image
import random
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
from torch.utils.data import Dataset, DataLoader, Sampler, Subset, SubsetRandomSampler
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchtoolbox.tools import mixup_data, mixup_criterion
import bitsandbytes.optim as bnb_optim

# Đảm bảo tính lặp lại của kết quả
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

CHINESE_COUNT = 34  # 34 ký tự Chinese đứng đầu trong CHARACTERS
def is_chinese_idx(idx: int) -> bool:
    return 0 <= idx < CHINESE_COUNT

#--- Các hằng số ---
provinces = ["皖", "沪", "津", "渝", "冀", "晋", "蒙", "辽", "吉", "黑", "苏", "浙", "京", "闽", "赣", "鲁", "豫", "鄂", "湘", "粤", "桂", "琼", "川", "贵", "云", "藏", "陕", "甘", "青", "宁", "新", "警", "学", "O"]
alphabets = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W',
             'X', 'Y', 'Z', 'O']
ads = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X',
       'Y', 'Z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'O']

CHARACTERS = [
    "皖","沪","津","渝","冀","晋","蒙","辽","吉","黑","苏","浙","京","闽","赣","鲁","豫","鄂","湘","粤","桂","琼","川","贵","云","藏","陕","甘","青","宁","新","警","学","O",
    "A","B","C","D","E","F","G","H","J","K","L","M","N","P","Q","R","S","T","U","V","W","X","Y","Z",
    "0","1","2","3","4","5","6","7","8","9"]

SOS_TOKEN = len(CHARACTERS)  # Mã thông báo bắt đầu chuỗi
EOS_TOKEN = len(CHARACTERS) + 1  # Mã thông báo kết thúc chuỗi
PAD_TOKEN = len(CHARACTERS) + 2  # Mã thông báo đệm
NUM_CLASSES = len(CHARACTERS) + 3  # Bao gồm SOS, EOS, PAD
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

#--- Các hàm tiện ích ---
def index_to_char(indices, include_special_tokens=False):
    result = []
    for i in indices:
        i = i.item() if torch.is_tensor(i) else i
        if i == SOS_TOKEN:
            if include_special_tokens: result.append('[SOS]')
        elif i == EOS_TOKEN:
            if include_special_tokens: result.append('[EOS]')
            break
        elif 0 <= i < len(CHARACTERS):
            result.append(CHARACTERS[i])
        else:
            if include_special_tokens or i not in [SOS_TOKEN, EOS_TOKEN, PAD_TOKEN]:
                result.append(f'[UNK_{i}]')
    return ''.join(result)

def char_to_indices(text):
    indices = [SOS_TOKEN]
    for c in text:
        if c in CHARACTERS:
            indices.append(CHARACTERS.index(c))
    indices.append(EOS_TOKEN)
    return torch.tensor(indices, dtype=torch.long)

#--- Lớp bao bọc YOLO với RViT ---
class YoloBackbone(nn.Module):
    def __init__(self, model_path, target_feature_layer_idx=9):
        super().__init__()
        _temp_yolo_instance = YOLO(model_path)
        self.yolo_detection_model = _temp_yolo_instance.model
        self.yolo_detection_model.to(DEVICE)
        self.target_feature_layer_index = target_feature_layer_idx

        for name, param in self.yolo_detection_model.named_parameters():
            param.requires_grad = True
        
        self._hook_handle = None
        self._fmap_out_hook = []
        
        self._register_hook()

    def _hook_fn_extractor(self, module, input_val, output_val):
        if isinstance(output_val, torch.Tensor):
            self._fmap_out_hook.append(output_val)
        elif isinstance(output_val, (list, tuple)):
            for item in output_val:
                if isinstance(item, torch.Tensor):
                    self._fmap_out_hook.append(item)
                    break

    def _register_hook(self):
        # Gắn hook vào lớp mục tiêu của mạng YOLO để trích xuất các đặc trưng.
        layer_to_hook = self.yolo_detection_model.model[self.target_feature_layer_index]
        self._hook_handle = layer_to_hook.register_forward_hook(self._hook_fn_extractor)

    def _remove_hook(self):
        if self._hook_handle:
            self._hook_handle.remove()
            self._hook_handle = None

    def forward(self, x):
        self._fmap_out_hook.clear()
        _ = self.yolo_detection_model(x)
        out_tensor = self._fmap_out_hook[0]
        return out_tensor if out_tensor.dim() == 4 else out_tensor.unsqueeze(0)

#--- RViT đơn giản hóa ---
class RViT(nn.Module):
    def __init__(self, yolo_channels=256, d_model=512, num_patches=1600, n_heads=8, num_encoder_layers=3, dim_feedforward=2048, dropout_rate=0.5):
        super().__init__()
        self.d_model = d_model
        self.proj = nn.Sequential(
            nn.Conv2d(yolo_channels, d_model, kernel_size=3, padding=1),
            nn.BatchNorm2d(d_model),
            nn.ReLU(),
            nn.Dropout2d(dropout_rate if dropout_rate > 0 else 0)
        )
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=dim_feedforward,
            dropout=dropout_rate, batch_first=True, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, d_model))
        self.region_q = nn.Parameter(torch.zeros(1, 1, d_model))
        self.embed = nn.Embedding(NUM_CLASSES, d_model)
        self.gru_num_layers = 1
        self.gru = nn.GRU(d_model, d_model, num_layers=self.gru_num_layers, batch_first=True,
                          dropout=dropout_rate if self.gru_num_layers > 1 else 0)
        self.attn = MultiheadAttention(d_model, num_heads=n_heads, batch_first=True, dropout=dropout_rate)
        self.fc = nn.Sequential(
            nn.Dropout(dropout_rate if dropout_rate > 0 else 0),
            nn.Linear(2 * d_model, NUM_CLASSES)
        )

    def forward(self, fmap, target=None, teach_ratio=0.5, forced_output_length=None):
        b = fmap.size(0)
        x = self.proj(fmap)
        x = x.flatten(2).permute(0, 2, 1)

        current_num_patches = x.size(1)
        expected_pos_embed_len = current_num_patches + 1
        
        if self.pos_embed.size(1) != expected_pos_embed_len:
            if self.pos_embed.size(1) > expected_pos_embed_len:
                pos_embed_to_add = self.pos_embed[:, :expected_pos_embed_len, :]
            else:
                raise ValueError(f"RViT pos_embed second dim {self.pos_embed.size(1)} is smaller than required {expected_pos_embed_len}")
        else:
            pos_embed_to_add = self.pos_embed

        q = self.region_q.expand(b, -1, -1)
        x = torch.cat([q, x], dim=1)
        x = x + pos_embed_to_add

        enc = self.encoder(x)
        region_feat, spatial_feats = enc[:, 0], enc[:, 1:]
        
        if forced_output_length is not None:
            max_gen_len = forced_output_length
        elif target is not None:
            max_gen_len = target.size(1) - 1
        else:
            max_gen_len = MAX_SEQ_LENGTH - 1

        h = region_feat.unsqueeze(0).contiguous()
        current_input_tokens = torch.full((b,), SOS_TOKEN, device=DEVICE, dtype=torch.long)
        outputs_logits = []

        finished_sequences_tracker = None
        if target is None:
            finished_sequences_tracker = torch.zeros(b, dtype=torch.bool, device=DEVICE)

        for t in range(max_gen_len):
            emb = self.embed(current_input_tokens).unsqueeze(1)
            g, h = self.gru(emb, h)
            a, _ = self.attn(g, spatial_feats, spatial_feats)
            comb = torch.cat([g.squeeze(1), a.squeeze(1)], dim=-1)
            logits_step = self.fc(comb)
            outputs_logits.append(logits_step)
            
            # Logic mới đã sửa lỗi
            if target is not None and random.random() < teach_ratio:
                next_input_tokens = target[:, t + 1]
            else:
                next_input_tokens = logits_step.argmax(-1)

            if finished_sequences_tracker is not None:
                # Đảm bảo chúng ta không thay thế các chuỗi đã hoàn thành
                eos_predicted_this_step = (next_input_tokens == EOS_TOKEN)
                finished_sequences_tracker = torch.logical_or(finished_sequences_tracker, eos_predicted_this_step)
                
                mask = finished_sequences_tracker.long()
                eos_token_tensor = torch.tensor(EOS_TOKEN, device=DEVICE, dtype=torch.long).expand_as(next_input_tokens)
                current_input_tokens = mask * eos_token_tensor + (1 - mask) * next_input_tokens
                if finished_sequences_tracker.all():
                    break
            else:
                current_input_tokens = next_input_tokens
        
        return torch.stack(outputs_logits, dim=1)

#--- Mô hình hoàn chỉnh (YOLO_RViT) ---
class YOLO_RViT(nn.Module):
    def __init__(self, yolo_path, yolo_target_feature_layer_idx=9):
        super().__init__()
        self.backbone = YoloBackbone(yolo_path, target_feature_layer_idx=yolo_target_feature_layer_idx)
        dummy_input = torch.randn(1, 3, 640, 640).to(DEVICE)
        
        with torch.no_grad():
            dummy_feats = self.backbone(dummy_input)
        
        yolo_channels = dummy_feats.shape[1]
        h_feat, w_feat = dummy_feats.shape[2], dummy_feats.shape[3]
        num_patches = h_feat * w_feat
        
        self.rvit = RViT(yolo_channels=yolo_channels, num_patches=num_patches).to(DEVICE)

    def forward(self, x, target=None, teach_ratio=0.5, forced_output_length=None):
        x = x.to(DEVICE)
        feats = self.backbone(x)
        return self.rvit(feats, target, teach_ratio, forced_output_length)

    def train(self, mode: bool = True):
        super().train(mode)
        self.rvit.train(mode)
        self.backbone.train(mode)
        return self

    def eval(self):
        super().eval()
        self.rvit.eval()
        self.backbone.eval()
        return self

#--- Tập dữ liệu ---
import albumentations as A
import cv2
import random
from albumentations.pytorch import ToTensorV2

IMG_SIZE = 640
# Nếu dùng YOLO backbone -> False; nếu không thì True
NORMALIZE_FOR_BACKBONE = False

# Letterbox đặt CUỐI CÙNG để tránh lộ vùng đệm kỳ dị khi biến đổi hình học
_LETTERBOX_LAST = [
    A.LongestMaxSize(max_size=IMG_SIZE, interpolation=cv2.INTER_LINEAR, p=1.0),
    A.PadIfNeeded(min_height=IMG_SIZE, min_width=IMG_SIZE,
                  border_mode=cv2.BORDER_CONSTANT, fill=(114,114,114), p=1.0),
    A.ToFloat(max_value=255.0),  # <-- thêm dòng này
]

def _maybe_norm(normalize: bool):
    return [A.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225))] if normalize else []

def make_ccpd_aug(subset: str, img_size=IMG_SIZE, normalize=False):
    # ----- từng chủ đề/subnet -----
    if subset == "base":
        t = [
            A.OneOf([
                A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.15),
                A.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.12, hue=0.03),
            ], p=0.6),
            A.GaussNoise(std_range=(0.005, 0.015), mean_range=(0.0,0.0), p=0.3),
        ]
    elif subset == "db":  # Dark & Bright
        t = [
            A.OneOf([
                A.RandomBrightnessContrast(brightness_limit=0.45, contrast_limit=0.35),
                A.RandomGamma(gamma_limit=(60,160)),
            ], p=0.9),
            A.OneOf([
                A.RandomShadow(shadow_roi=(0.05,0.05,0.95,0.95),
                               num_shadows_limit=(1,1),          # hoặc (1,2) nếu muốn ngẫu nhiên 1–2 bóng
                               shadow_dimension=3,
                               shadow_intensity_range=(0.25,0.5)),
                A.RandomSunFlare(flare_roi=(0,0,1,0.8),
                                 num_flare_circles_range=(2,5),
                                 angle_range=(0.0,1.0),
                                 src_radius=60),
            ], p=0.35),
        ]
    elif subset == "blur":
        t = [
            A.OneOf([
                A.MotionBlur(blur_limit=(5,11)),
                A.GaussianBlur(blur_limit=(3,7)),
                A.GlassBlur(sigma=0.7, max_delta=3, iterations=1),
            ], p=0.9),
        ]
    elif subset == "fn":  # Far & Near
        t = [
            A.OneOf([
                A.Affine(scale=(0.55,0.75), translate_percent=(0.0,0.02),
                         rotate=(0,0), shear=(0,0),
                         interpolation=cv2.INTER_LINEAR, mask_interpolation=cv2.INTER_NEAREST,
                         border_mode=cv2.BORDER_CONSTANT, fill=(114,114,114)),
                A.RandomResizedCrop(size=(img_size, img_size),
                                    scale=(0.80,1.00), ratio=(0.95,1.05),
                                    interpolation=cv2.INTER_LINEAR),
            ], p=0.9),
            A.OneOf([
                A.Downscale(scale_range=(0.60,0.85)),
                A.ImageCompression(quality_range=(40,70)),
            ], p=0.5),
        ]
    elif subset == "rotate":
        t = [
            A.Affine(rotate=(-15,15), scale=(0.95,1.05), translate_percent=(0.0,0.04),
                     shear=(0,0),
                     interpolation=cv2.INTER_LINEAR, mask_interpolation=cv2.INTER_NEAREST,
                     border_mode=cv2.BORDER_CONSTANT, fill=(114,114,114), p=1.0),
        ]
    elif subset == "tilt":
        t = [
            A.Affine(scale=(0.95,1.05), translate_percent=(0.0,0.05),
                     rotate=(-3,3), shear=(-18,18),
                     interpolation=cv2.INTER_LINEAR, mask_interpolation=cv2.INTER_NEAREST,
                     border_mode=cv2.BORDER_CONSTANT, fill=(114,114,114), p=0.8),
            A.Perspective(scale=(0.06,0.14), keep_size=True,
                          border_mode=cv2.BORDER_CONSTANT, fill=(114,114,114),
                          interpolation=cv2.INTER_LINEAR, mask_interpolation=cv2.INTER_NEAREST, p=0.5),
        ]
    elif subset == "weather":
        t = [
            A.OneOf([
                A.RandomFog(fog_coef_range=(0.12,0.28), alpha_coef=0.06),
                A.RandomRain(slant_range=(-10,10), drop_length=14, blur_value=3),
                A.RandomSnow(snow_point_range=(0.10,0.30), brightness_coeff=1.5),
            ], p=0.9),
            A.ISONoise(color_shift=(0.01,0.05), intensity=(0.1,0.35), p=0.4),
        ]
    elif subset == "challenge":
        t = [
            A.OneOf([
                A.Affine(scale=(0.90,1.08), translate_percent=(0.0,0.06),
                         rotate=(-7,7), shear=(-15,15),
                         interpolation=cv2.INTER_LINEAR, mask_interpolation=cv2.INTER_NEAREST,
                         border_mode=cv2.BORDER_CONSTANT, fill=(114,114,114)),
                A.Perspective(scale=(0.08,0.16), keep_size=True,
                              border_mode=cv2.BORDER_CONSTANT, fill=(114,114,114),
                              interpolation=cv2.INTER_LINEAR, mask_interpolation=cv2.INTER_NEAREST),
            ], p=0.7),
            A.OneOf([
                A.RandomBrightnessContrast(brightness_limit=0.35, contrast_limit=0.3),
                A.RandomGamma(gamma_limit=(70,140)),
            ], p=0.8),
            A.OneOf([
                A.MotionBlur(blur_limit=(3,9)),
                A.GaussianBlur(blur_limit=(3,7)),
            ], p=0.5),
            A.CoarseDropout(num_holes_range=(1,2),           # hoặc (1,3)
                            hole_height_range=(0.06,0.12),
                            hole_width_range=(0.06,0.12),
                            fill=0, p=0.12),
        ]
    else:
        raise ValueError(f"Unknown subset: {subset}")

    return A.Compose(
        [*t, *_LETTERBOX_LAST, *_maybe_norm(normalize), ToTensorV2()],
        strict=True
    )


class CCPDAugmenter:
    def __init__(self, img_size=IMG_SIZE, normalize=False, weights=None, fixed_subset=None):
        self.img_size = img_size
        self.normalize = normalize
        # phân phối gần giống “default test” của CCPD; bạn chỉnh tùy ý
        self.weights = weights or {
            "base":0.20, "db":0.35, "blur":0.10, "fn":0.35,
            "rotate":0.20, "tilt":0.20, "weather":0.20, "challenge":0.35
        }
        self.fixed_subset = fixed_subset  # nếu muốn khóa 1 subnet

    def _sample_subset(self):
        if self.fixed_subset is not None:
            return self.fixed_subset
        names, probs = zip(*self.weights.items())
        return random.choices(names, weights=probs, k=1)[0]

    def __call__(self, image=None, **kwargs):
        if image is None:
            image = kwargs.get("image")
        subset = self._sample_subset()
        tf = make_ccpd_aug(subset, img_size=self.img_size, normalize=self.normalize)
        out = tf(image=image)
        # có thể gắn kèm tên subnet để debug:
        # out["subset"] = subset
        return out

# === Drop-in thay thế cho train_tf / val_tf trong code của bạn ===
train_tf = CCPDAugmenter(img_size=IMG_SIZE, normalize=NORMALIZE_FOR_BACKBONE)
val_tf = A.Compose([*_LETTERBOX_LAST, *(_maybe_norm(NORMALIZE_FOR_BACKBONE)), ToTensorV2()],
                   strict=True)

class LicensePlateDataset(Dataset):
    def __init__(self, img_dir, license_dir, max_seq_length=15, is_train=True):
        self.img_dir = img_dir
        self.license_dir = license_dir
        self.max_seq_length = max_seq_length
        self.img_names = [f for f in os.listdir(self.img_dir) if f.endswith(('.jpg', '.png'))]
        self.is_train = is_train
        self.transform = train_tf if is_train else val_tf
        
    def __len__(self):
        return len(self.img_names)
        
    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_names[idx])
        img = Image.open(img_path).convert("RGB")
        # Albumentations nhận ndarray
        img_np = np.array(img)
        img_tensor = self.transform(image=img_np)["image"]
        
        license_filename = os.path.splitext(self.img_names[idx])[0] + ".txt"
        license_path = os.path.join(self.license_dir, license_filename)
        try:
            with open(license_path, 'r', encoding='utf-8') as f:
                license_text = f.read().upper().strip()
        except FileNotFoundError:
            return img_tensor, torch.full((self.max_seq_length,), EOS_TOKEN, dtype=torch.long)
        
        license_indices = char_to_indices(license_text)
        target = torch.full((self.max_seq_length,), PAD_TOKEN, dtype=torch.long)
        
        actual_len = min(len(license_indices), self.max_seq_length)
        target[:actual_len] = license_indices[:actual_len]

        return img_tensor, target

    @staticmethod
    def collate_fn(batch):
        images = torch.stack([item[0] for item in batch])
        targets = torch.stack([item[1] for item in batch])
        return images, targets

#--- Dừng sớm (Early Stopping) ---
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0.0, monitor_metric='val_acc', mode='max', verbose=True):
        self.patience = patience
        self.min_delta = min_delta
        self.monitor_metric = monitor_metric
        self.mode = mode
        self.verbose = verbose

        if self.mode == 'min':
            self.best_metric_val = np.Inf
        else:
            self.best_metric_val = -np.Inf
        self.counter = 0
        self.early_stop = False

    def __call__(self, current_metric_val):
        improved = False
        if self.mode == 'min':
            if current_metric_val < self.best_metric_val - self.min_delta:
                self.best_metric_val = current_metric_val
                improved = True
        else:
            if current_metric_val > self.best_metric_val + self.min_delta:
                self.best_metric_val = current_metric_val
                improved = True
        
        if improved:
            self.counter = 0
            if self.verbose:
                print(f"EarlyStopping: New best {self.monitor_metric}: {self.best_metric_val:.4f}")
        else:
            self.counter += 1

        if self.verbose and self.counter > 0 and not improved:
            print(f"EarlyStopping counter: {self.counter}/{self.patience} (Best {self.monitor_metric}: {self.best_metric_val:.4f})")
        if self.counter >= self.patience:
            self.early_stop = True
            if self.verbose:
                print(f"Early stopping triggered for {self.monitor_metric}.")
        return self.counter

#--- Huấn luyện ---
YOLO_MODEL_PATH = '/kaggle/input/yolov11s/pytorch/default/1/best.pt'
YOLO_TARGET_FEATURE_LAYER_INDEX = 13

# IMG_DIR_TRAIN = "/kaggle/input/aolp-ac-1/AOLP_AC/Images/Train"
# LICENSE_DIR_TRAIN = "/kaggle/input/aolp-ac-1/AOLP_AC/Text/Train"

IMG_DIR_VAL = "/kaggle/input/clpd-dataset/CLPD/image"
LICENSE_DIR_VAL = "/kaggle/input/clpd-dataset/CLPD/text"

# IMG_DIR_TEST = "/kaggle/input/pku-g1/G1/G1_lp_images"
# LICENSE_DIR_TEST = "/kaggle/input/pku-g1/G1/G1_lp_images"

MAX_SEQ_LENGTH = 15
BATCH_SIZE = 8
NUM_WORKERS = 4
LEARNING_RATE = 5e-5
MAX_LR_SCHEDULER = 5e-4
WEIGHT_DECAY = 5e-5
NUM_EPOCHS = 1
ACCUM_STEPS = 2
PATIENCE_EARLY_STOP = 200
TEACH_RATIO_START = 0.7
TEACH_RATIO_END = 0.05
LABEL_SMOOTHING = 0.01
scaler = torch.amp.GradScaler('cuda')
autocast_context = lambda: torch.amp.autocast('cuda')

# train_dataset_full = LicensePlateDataset(img_dir=IMG_DIR_TRAIN, license_dir=LICENSE_DIR_TRAIN, max_seq_length=MAX_SEQ_LENGTH, is_train=True)
val_dataset = LicensePlateDataset(img_dir=IMG_DIR_VAL, license_dir=LICENSE_DIR_VAL, max_seq_length=MAX_SEQ_LENGTH, is_train=False)
# test_dataset = LicensePlateDataset(img_dir=IMG_DIR_TEST, license_dir=LICENSE_DIR_TEST, max_seq_length=MAX_SEQ_LENGTH, is_train=False)

# train_dataloader = DataLoader(train_dataset_full, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, collate_fn=LicensePlateDataset.collate_fn, pin_memory=(DEVICE == 'cuda'), drop_last=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, collate_fn=LicensePlateDataset.collate_fn, pin_memory=(DEVICE == 'cuda'))
# test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, collate_fn=LicensePlateDataset.collate_fn, pin_memory=(DEVICE == 'cuda'))

model = YOLO_RViT(YOLO_MODEL_PATH, yolo_target_feature_layer_idx=YOLO_TARGET_FEATURE_LAYER_INDEX).to(DEVICE)

optimizer = bnb_optim.AdamW8bit(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY)

# scheduler = torch.optim.lr_scheduler.OneCycleLR(
#     optimizer, max_lr=MAX_LR_SCHEDULER,
#     # epochs=NUM_EPOCHS,
#     # steps_per_epoch=(len(train_dataloader) + ACCUM_STEPS - 1) // ACCUM_STEPS,
#     pct_start=0.2,
#     div_factor=(MAX_LR_SCHEDULER / LEARNING_RATE) if MAX_LR_SCHEDULER > LEARNING_RATE else 10.0)

scheduler_type = "OneCycleLR"
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN, label_smoothing=LABEL_SMOOTHING)
early_stopper = EarlyStopping(patience=PATIENCE_EARLY_STOP, min_delta=0.0005, monitor_metric='val_acc', mode='max', verbose=True)

checkpoint = torch.load("/kaggle/input/ccpd_base_good/pytorch/default/1/CCPD_BASE_GOOD.pth", map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

train_loss_values, val_loss_values, test_loss_values = [], [], []
train_acc_values, val_acc_values, val_acc_constrained_values, test_acc_values, test_acc_constrained_values = [], [], [], [], []
epoch_count_list = []
best_val_acc = 0.0

for epoch in range(NUM_EPOCHS):
    epoch_count_list.append(epoch + 1)
    # model.train()
    # train_loss, train_correct, train_total_chars = 0, 0, 0

    teach_ratio = TEACH_RATIO_START - (TEACH_RATIO_START - TEACH_RATIO_END) * (epoch / max(1, NUM_EPOCHS - 1))

    # optimizer.zero_grad()

    # pbar_train = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [TRAIN] LR: {optimizer.param_groups[0]['lr']:.2e} Teach: {teach_ratio:.2f} Scheduler: {scheduler_type}")
    # for batch_idx, (imgs, targets) in enumerate(pbar_train):
    #     imgs, targets = imgs.to(DEVICE, non_blocking=True), targets.to(DEVICE, non_blocking=True)
        
    #     with autocast_context():
    #         outputs = model(imgs, target=targets, teach_ratio=teach_ratio)
    #         flat_outputs = outputs.reshape(-1, NUM_CLASSES)
    #         flat_targets = targets[:, 1:].reshape(-1)
    #         loss = loss_fn(flat_outputs, flat_targets)
    #         loss = loss / ACCUM_STEPS
            
    #     scaler.scale(loss).backward()
        
    #     if (batch_idx + 1) % ACCUM_STEPS == 0 or (batch_idx + 1) == len(train_dataloader):
    #         torch.nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, model.parameters()), max_norm=1.0)
    #         scaler.step(optimizer)
    #         scaler.update()
    #         optimizer.zero_grad()
    #         if scheduler_type == "OneCycleLR":
    #             scheduler.step()

    #     train_loss += loss.item() * ACCUM_STEPS
    #     with torch.no_grad():
    #         preds = outputs.argmax(-1)
    #         true_chars = targets[:, 1:]
    #         for i in range(imgs.size(0)):
    #             pred_seq_list_train = preds[i].tolist()
    #             if EOS_TOKEN in pred_seq_list_train:
    #                 pred_seq_list_train = pred_seq_list_train[:pred_seq_list_train.index(EOS_TOKEN)]
                
    #             true_seq_list_train = true_chars[i].tolist()
    #             true_content_train = [x for x in true_seq_list_train if x not in [EOS_TOKEN, PAD_TOKEN]]
    #             len_true_content_train = len(true_content_train)
                
    #             cmp_len = min(len(pred_seq_list_train), len_true_content_train)
    #             if cmp_len > 0:
    #                 train_correct += (torch.tensor(pred_seq_list_train[:cmp_len]) == torch.tensor(true_content_train[:cmp_len])).sum().item()
                
    #             train_total_chars += len_true_content_train

    #         if batch_idx == 0 and epoch % 1 == 0 and imgs.size(0) > 0:
    #             print("\n--- Ví dụ Batch 0 trong quá trình Huấn luyện ---")
    #             for i in range(min(5, imgs.size(0))):
    #                 pred_batch_i_list = preds[i].tolist()
    #                 pred_example = index_to_char(pred_batch_i_list)
    #                 true_batch_i_list = true_chars[i].tolist()
    #                 true_example = index_to_char(true_batch_i_list)
    #                 print(f"  Dự đoán: '{pred_example}'")
    #                 print(f"  Thực tế: '{true_example}'")
    #             print("-------------------------------")
        
    #     pbar_train.set_postfix(loss=loss.item() * ACCUM_STEPS)

    # avg_train_loss = train_loss / len(train_dataloader) if len(train_dataloader) > 0 else 0
    # avg_train_acc = train_correct / train_total_chars if train_total_chars > 0 else 0
    # train_loss_values.append(avg_train_loss)
    # train_acc_values.append(avg_train_acc)

    # --- Vòng lặp kiểm tra (Validation) ---
    model.eval()
    val_loss, val_correct, val_total_chars, val_exact_match_correct, val_total_sequences = 0, 0, 0, 0, 0
    pbar_val = tqdm(val_dataloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [VAL]")
    with torch.no_grad():
        for imgs, targets in pbar_val:
            imgs, targets = imgs.to(DEVICE, non_blocking=True), targets.to(DEVICE, non_blocking=True)
            with autocast_context():
                outputs = model(imgs, target=None, teach_ratio=0.0)
                out_seq_len_val = outputs.size(1)
                tgt_content_len_val = targets.size(1) - 1
                
                # Cắt bớt hoặc đệm đầu ra để tính toán loss
                if out_seq_len_val > tgt_content_len_val:
                    outputs_for_loss_val = outputs[:, :tgt_content_len_val, :]
                elif out_seq_len_val < tgt_content_len_val:
                    padding_val_val = torch.zeros(outputs.size(0), tgt_content_len_val - out_seq_len_val, NUM_CLASSES, device=DEVICE)
                    padding_val_val[:,:,PAD_TOKEN] = 1
                    outputs_for_loss_val = torch.cat([outputs, padding_val_val], dim=1)
                else:
                    outputs_for_loss_val = outputs

                flat_outputs_val = outputs_for_loss_val.reshape(-1, NUM_CLASSES)
                flat_targets_val = targets[:, 1:].reshape(-1)
                loss = loss_fn(flat_outputs_val, flat_targets_val)
                
            val_loss += loss.item()
            preds_val = outputs.argmax(-1)           # [B, T]
            true_chars_val = targets[:, 1:]          # bỏ SOS ở đầu
            
            for i in range(imgs.size(0)):
                # --- Chuỗi dự đoán (cắt tại EOS, KHÔNG lọc theo giá trị token) ---
                pred_seq_full = preds_val[i].tolist()
                if EOS_TOKEN in pred_seq_full:
                    pred_seq_full = pred_seq_full[:pred_seq_full.index(EOS_TOKEN)]
            
                # --- Chuỗi GT (bỏ EOS/PAD) ---
                true_seq_full = true_chars_val[i].tolist()
                true_content  = [x for x in true_seq_full if x not in [EOS_TOKEN, PAD_TOKEN]]
            
                # Căn theo độ dài khả dụng để so sánh theo vị trí
                cmp_len = min(len(pred_seq_full), len(true_content))
                pred_cmp = pred_seq_full[:cmp_len]
                gt_cmp   = true_content[:cmp_len]
            
                # Chỉ lấy các vị trí mà GT KHÔNG phải Chinese
                non_cn_pos = [t for t, tk in enumerate(gt_cmp) if not is_chinese_idx(tk)]
            
                # --- Character-level accuracy (KHÔNG tính ký tự Chinese) ---
                if non_cn_pos:
                    val_correct += sum(int(pred_cmp[t] == gt_cmp[t]) for t in non_cn_pos)
                    val_total_chars += len(non_cn_pos)
            
                # --- Exact-match trên phần non-Chinese ---
                # Lấy chuỗi con theo đúng các vị trí non-Chinese của GT
                gt_non_cn   = [gt_cmp[t] for t in non_cn_pos]
                pred_non_cn = [pred_cmp[t] for t in non_cn_pos]
            
                # Chỉ tính E2E nếu có ít nhất 1 ký tự non-Chinese để đánh giá
                if gt_non_cn:
                    val_total_sequences += 1
                    if pred_non_cn == gt_non_cn:
                        val_exact_match_correct += 1


            
            pbar_val.set_postfix(loss=loss.item())

    avg_val_loss = val_loss / len(val_dataloader) if len(val_dataloader) > 0 else 0
    avg_val_acc = val_correct / val_total_chars if val_total_chars > 0 else 0
    val_exact_match_acc = val_exact_match_correct / val_total_sequences if val_total_sequences > 0 else 0
    val_loss_values.append(avg_val_loss)
    val_acc_values.append(avg_val_acc)
    val_acc_constrained_values.append(val_exact_match_acc)
    
    # # Chuyển sang ReduceLROnPlateau nếu val_acc chững lại
    # early_stopper_val_result = early_stopper(avg_val_acc)
    # if early_stopper_val_result >= 30 and scheduler_type == "OneCycleLR":
    #     print(f"Độ chính xác kiểm tra chững lại trong 50 epoch. Chuyển sang ReduceLROnPlateau.")
    #     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5, verbose=True)
    #     scheduler_type = "ReduceLROnPlateau"
    
    # # Cập nhật scheduler
    # if scheduler_type == "ReduceLROnPlateau":
    #     scheduler.step(avg_val_acc)

    # if early_stopper.early_stop:
    #     print(f"--> Dừng sớm tại epoch {epoch+1}")
    #     break
    
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS} | LR: {optimizer.param_groups[0]['lr']:.2e} | Teach: {teach_ratio:.2f} | Scheduler: {scheduler_type}")
    # print(f"  Train Loss: {avg_train_loss:.4f} | Train Acc Ký tự: {avg_train_acc:.4f}")
    print(f"  Val Loss: {avg_val_loss:.4f}   | Val Acc Ký tự: {avg_val_acc:.4f}")
    print(f"  Test Acc Khớp chính xác (Toàn bộ chuỗi): {val_exact_match_acc:.4f}")
    print("-" * 70)

    # if avg_val_acc > best_val_acc:
    #     best_val_acc = avg_val_acc
    #     print(f"*** Độ chính xác kiểm tra tốt nhất mới: {best_val_acc:.4f}. Đang lưu best_model.pth ***")
    #     torch.save({
    #         'epoch': epoch,
    #         'model_state_dict': model.state_dict(),
    #         'optimizer_state_dict': optimizer.state_dict(),
    #         'scheduler_state_dict': scheduler.state_dict(),
    #         'val_loss': avg_val_loss,
    #         'val_acc': avg_val_acc,
    #         'val_exact_match_acc': val_exact_match_acc,
    #     }, "best_yolo_rvit_model.pth")
        
# final_epoch_val = epoch if 'epoch' in locals() and epoch is not None else NUM_EPOCHS - 1
# torch.save({
#     'epoch': final_epoch_val,
#     'model_state_dict': model.state_dict(),
#     'optimizer_state_dict': optimizer.state_dict(),
#     'scheduler_state_dict': scheduler.state_dict(),
#     'train_loss_history': train_loss_values,
#     'val_loss_history': val_loss_values,
#     'train_acc_history': train_acc_values,
#     'val_acc_history': val_acc_values,
#     'val_exact_match_acc_history': val_acc_constrained_values,}, "final_yolo_rvit_model.pth")

print("\nQuá trình huấn luyện đã hoàn thành!")
if val_acc_values:
    print(f"Độ chính xác kiểm tra cuối cùng (Cấp độ ký tự): {val_acc_values[-1]:.4f}")
if val_acc_constrained_values:
    print(f"Độ chính xác khớp chính xác cuối cùng: {val_acc_constrained_values[-1]:.4f}")

# # Vẽ đồ thị
# plt.figure(figsize=(18, 12))
# plt.subplot(2, 2, 1)
# plt.plot(epoch_count_list, train_loss_values, label='Train Loss', marker='o', linestyle='-')
# plt.plot(epoch_count_list, val_loss_values, label='Validation Loss', marker='s', linestyle='--')
# plt.title('Loss Curves')
# plt.xlabel('Epochs')
# plt.ylabel('Loss')
# plt.legend()
# plt.grid(True)
# plt.subplot(2, 2, 2)
# plt.plot(epoch_count_list, train_acc_values, label='Train Char Accuracy', marker='o', linestyle='-')
# plt.plot(epoch_count_list, val_acc_values, label='Validation Char Accuracy (Greedy)', marker='s', linestyle='--')
# plt.title('Character Accuracy')
# plt.xlabel('Epochs')
# plt.ylabel('Character Accuracy')
# plt.legend()
# plt.grid(True)
# plt.tight_layout()
# plt.savefig("performance_plots.png")
# plt.show()
# print("Đã lưu đồ thị vào performance_plots.png")

Creating new Ultralytics Settings v0.0.6 file ✅ 
View Ultralytics Settings with 'yolo settings' or at '/root/.config/Ultralytics/settings.json'
Update Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.


Epoch 1/1 [VAL]: 100%|██████████| 150/150 [00:22<00:00,  6.73it/s, loss=0.685]


Epoch 1/1 | LR: 1.05e-09 | Teach: 0.70 | Scheduler: OneCycleLR
  Val Loss: 1.0455   | Val Acc Ký tự: 0.9075
  Test Acc Khớp chính xác (Toàn bộ chuỗi): 0.8083
----------------------------------------------------------------------

Quá trình huấn luyện đã hoàn thành!
Độ chính xác kiểm tra cuối cùng (Cấp độ ký tự): 0.9075
Độ chính xác khớp chính xác cuối cùng: 0.8083





In [3]:
# from ultralytics import YOLO

# def print_yolo_modules(model_path):
#     """
#     Tải một mô hình YOLO và in ra các module (lớp) của nó.
#     :param model_path: Đường dẫn đến file .pt của mô hình YOLO.
#     """
#     try:
#         # Tải một mô hình YOLOv8 đã được huấn luyện sẵn
#         print(f"Đang tải mô hình YOLO từ: {model_path}")
#         yolo_model = YOLO(model_path)
        
#         # Truy cập vào kiến trúc mô hình.
#         # Các module được lưu trong một list Python tại yolo_model.model.model
#         modules = yolo_model.model.model
        
#         print("\n--- Cấu trúc mô hình YOLOv8 ---")
#         # Lặp qua từng module và in thông tin
#         for i, layer in enumerate(modules):
#             # In chỉ số, tên lớp và các tham số của module đó
#             print(f"Layer {i}: {layer.__class__.__name__}")
#             print(f"  - Parameters: {sum(p.numel() for p in layer.parameters()):,} trainable")
#             print(f"  - Details: {layer}")
        
#         print("---------------------------------")
#         print("Kết thúc việc in cấu trúc mô hình.")

#     except Exception as e:
#         print(f"Lỗi: Không thể tải hoặc xử lý mô hình. Vui lòng kiểm tra đường dẫn '{model_path}' và đảm bảo bạn đã cài đặt 'ultralytics' đúng cách.")
#         print(f"Chi tiết lỗi: {e}")

# if __name__ == "__main__":
#     # Thay đổi đường dẫn này nếu bạn muốn kiểm tra một mô hình khác
#     model_path_to_inspect = '/kaggle/input/yolo_v11s_ac/pytorch/default/1/best.pt'
#     print_yolo_modules(model_path_to_inspect)