In [None]:
import os, xml.etree.ElementTree as ET
from collections import Counter

VOC_DIR = "E:\\Pycharm\\Advanced-Reading-on-Computer-Vision\\Datasets\\VOC"
ann_dir = os.path.join(VOC_DIR, "Annotations")

# List the annotation files once so the class scan and the file count printed
# at the end are both computed from the same directory snapshot.
xml_files = [fn for fn in os.listdir(ann_dir) if fn.endswith(".xml")]

cnt = Counter()  # class name -> number of <object> entries carrying that name
bad = 0          # <object> entries whose <name> is missing or blank
for fn in xml_files:
    root = ET.parse(os.path.join(ann_dir, fn)).getroot()
    for obj in root.findall("object"):
        # findtext() yields None when the <name> child is absent, so a malformed
        # object is counted as invalid instead of raising AttributeError.
        name = (obj.findtext("name") or "").strip()
        if not name:
            bad += 1
            continue
        cnt[name] += 1

print("== Class frequency ==")
for k, v in cnt.most_common():
    print(f"{k:20s} : {v}")
print("\nEmpty/invalid objects:", bad)
print("Total XML files:", len(xml_files))


In [None]:
# E:\Pycharm\Advanced-Reading-on-Computer-Vision\config_label_map.py
# Coarse classes the detector is trained on, in label-index order.
FINAL_CLASSES = [
    "person",
    "vehicle",
    "animal",
    "furniture",
    "food_drink",
    "device",
    "sports",
    "signage_decor",
]

# Fine-grained labels that belong to each coarse class. The order of groups and
# of members inside a group fixes the insertion order of MERGE_TO_GROUP.
# NOTE(review): "bench", "fork", "knife" and "spoon" have no entry, so objects
# with those labels are not mapped to any group - confirm this is deliberate.
_GROUP_MEMBERS = {
    "person": ["person"],
    "vehicle": [
        "car", "truck", "bus", "train",
        "motorcycle", "bicycle", "boat", "airplane",
    ],
    "animal": [
        "dog", "cat", "bird", "horse", "sheep",
        "cow", "elephant", "bear", "zebra", "giraffe",
    ],
    "furniture": [
        "chair", "couch", "bed",
        "dining table", "toilet", "sink",
    ],
    "food_drink": [
        "banana", "apple", "sandwich", "orange",
        "broccoli", "carrot", "donut", "cake",
        "pizza", "hot dog", "bottle", "cup",
        "bowl", "wine glass",
    ],
    "device": [
        "tv", "laptop", "mouse", "keyboard", "cell phone",
        "remote", "refrigerator", "microwave", "oven",
        "toaster", "hair drier",
    ],
    "sports": [
        "skis", "snowboard", "sports ball", "skateboard",
        "surfboard", "tennis racket", "frisbee",
        "baseball bat", "baseball glove", "kite",
    ],
    # signage / decor / accessories
    "signage_decor": [
        "traffic light", "stop sign", "parking meter",
        "handbag", "backpack", "umbrella",
        "tie", "suitcase", "book", "clock",
        "vase", "teddy bear", "scissors",
        "toothbrush", "potted plant", "fire hydrant",
    ],
}

# Flat lookup: fine-grained label -> coarse class name.
MERGE_TO_GROUP = {
    fine_label: group
    for group, fine_labels in _GROUP_MEMBERS.items()
    for fine_label in fine_labels
}


In [None]:
import os, xml.etree.ElementTree as ET

# Dataset locations; the ImageSets/Main folder is created if it is missing.
VOC_DIR = r"E:\Pycharm\Advanced-Reading-on-Computer-Vision\Datasets\VOC"
ann_dir = os.path.join(VOC_DIR, "Annotations")
imgsets = os.path.join(VOC_DIR, *("ImageSets", "Main"))
os.makedirs(imgsets, exist_ok=True)

# Lookup tables used by rewrite_xml():
#   COCO2GROUP - private copy of the fine-label -> group mapping
#   TARGET     - group names an object must end up with to be kept
#   GROUP_PASS - names that are already group names and pass through unchanged
COCO2GROUP = {**MERGE_TO_GROUP}
TARGET = {*FINAL_CLASSES}
GROUP_PASS = TARGET.copy()

# Ids (file stems) of annotation files that still contain at least one object.
kept_ids = list()


def rewrite_xml(xml_path):
    """Relabel the objects of one annotation file to the merged group names.

    Each <object> whose <name> maps to a name in TARGET (through COCO2GROUP, or
    because it already is a group name in GROUP_PASS) gets its name replaced by
    the group name; every other <object> is dropped. Kept objects are re-attached
    at the end of the root element.

    Returns:
        True if at least one object was kept; the file is then overwritten in
        place (UTF-8). False if nothing was kept; the file is left untouched.
    """
    tree = ET.parse(xml_path)
    root = tree.getroot()
    original_nodes = root.findall("object")

    survivors = []
    for node in original_nodes:
        label = node.find("name")
        if label is None:
            continue
        raw = (label.text or "").strip()
        # Mapped name if known, the name itself if it is already a group, else None.
        mapped = COCO2GROUP.get(raw, raw if raw in GROUP_PASS else None)
        if mapped in TARGET:
            label.text = mapped
            survivors.append(node)

    # Detach every object, then re-attach only the ones that were kept.
    for node in original_nodes:
        root.remove(node)
    for node in survivors:
        root.append(node)

    if not survivors:
        return False
    tree.write(xml_path, encoding="utf-8")
    return True


# Relabel every annotation file; files left without any object are deleted and
# the ids of the remaining ones are collected in kept_ids.
for fn in (entry for entry in os.listdir(ann_dir) if entry.endswith(".xml")):
    xml_path = os.path.join(ann_dir, fn)
    if rewrite_xml(xml_path):
        kept_ids.append(os.path.splitext(fn)[0])
    else:
        os.remove(xml_path)


def filter_ids(txt_path, keep_set):
    """Rewrite an id list file so that it only contains ids found in keep_set.

    Lines are stripped of surrounding whitespace, blank lines are discarded, and
    the surviving ids are written back one per line in their original order.
    Nothing happens when txt_path does not exist.
    """
    if os.path.exists(txt_path):
        with open(txt_path) as src:
            stripped = (line.strip() for line in src)
            survivors = [entry for entry in stripped if entry and entry in keep_set]
        with open(txt_path, "w") as dst:
            dst.writelines(entry + "\n" for entry in survivors)


# Drop from the train/val split files every id whose annotation was deleted.
keep = set(kept_ids)
for split_file in ("train.txt", "val.txt"):
    filter_ids(os.path.join(imgsets, split_file), keep)

print("Target classes:", sorted(TARGET))


## Huấn luyện lại Faster R-CNN

In [None]:
import os, xml.etree.ElementTree as ET
from typing import List
import numpy as np
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision

VOC_DIR = r"E:\Pycharm\Advanced-Reading-on-Computer-Vision\Datasets\VOC"

# Label index 0 is reserved for the background, so the real classes are
# numbered 1..len(FINAL_CLASSES) in list order.
NUM_CLASSES = 1 + len(FINAL_CLASSES)
CLS_TO_IDX = dict(zip(FINAL_CLASSES, range(1, NUM_CLASSES)))


class VOCDataset(Dataset):
    """Detection dataset over a Pascal-VOC style directory layout.

    Image ids come from ``ImageSets/Main/<image_set>.txt``, images from
    ``JPEGImages/<id>.jpg`` and boxes from ``Annotations/<id>.xml`` under ``root``.
    Every image is resized to ``size`` x ``size`` (the aspect ratio is not kept)
    and the boxes are scaled with the same per-axis factors.

    Args:
        root: dataset root directory.
        image_set: name of the split file (without ``.txt``).
        size: side length in pixels of the square output image.
        augment: when true, apply a random horizontal flip (probability 0.5) to
            samples that have at least one box.
        classes: class names; label ``k + 1`` is assigned to ``classes[k]`` and
            0 is left for the background. Defaults to ``FINAL_CLASSES``.
        apply_normalize: when true (default, the historical behaviour) the image
            tensor is normalised with the ImageNet mean/std.
            NOTE(review): torchvision's detection models document that they
            expect inputs in the 0-1 range and normalise internally - confirm,
            and pass ``apply_normalize=False`` when training a fresh model to
            avoid normalising twice. Existing checkpoints were trained with the
            default, so it is left unchanged here.
    """

    def __init__(self, root: str, image_set="train", size=800, augment=False, classes: List[str] = None,
                 apply_normalize: bool = True):
        self.root = root
        self.img_dir = os.path.join(root, "JPEGImages")
        self.ann_dir = os.path.join(root, "Annotations")
        with open(os.path.join(root, "ImageSets", "Main", f"{image_set}.txt")) as f:
            self.ids = [x.strip() for x in f if x.strip()]
        self.size = size
        self.augment = augment
        self.classes = classes or FINAL_CLASSES
        self.cls_to_idx = {c: i + 1 for i, c in enumerate(self.classes)}
        self.apply_normalize = apply_normalize
        self.normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                          std=[0.229, 0.224, 0.225])

    def __len__(self):
        """Number of image ids in the split."""
        return len(self.ids)

    def __getitem__(self, i):
        """Return ``(image_tensor, target)`` for the i-th id of the split.

        ``target`` holds ``boxes`` (float32, shape (K, 4), x1/y1/x2/y2 in resized
        pixel coordinates), ``labels`` (int64, shape (K,)) and ``image_id``.
        Objects with an unknown class, a missing or malformed box, or a box of
        non-positive width/height are skipped.
        """
        img_id = self.ids[i]
        # Decode inside a context manager so the file handle is released at once.
        with Image.open(os.path.join(self.img_dir, f"{img_id}.jpg")) as raw:
            img = raw.convert("RGB")
        w0, h0 = img.size
        img = img.resize((self.size, self.size))
        sx, sy = self.size / w0, self.size / h0

        boxes, labels = [], []
        root = ET.parse(os.path.join(self.ann_dir, f"{img_id}.xml")).getroot()
        for obj in root.findall("object"):
            # findtext() is None for a missing <name>; such objects are skipped,
            # matching how rewrite_xml() treats them, instead of crashing.
            name = (obj.findtext("name") or "").strip()
            if name not in self.cls_to_idx:
                continue
            bb = obj.find("bndbox")
            if bb is None:
                continue
            try:
                xmin, ymin, xmax, ymax = (
                    float(bb.findtext(tag)) for tag in ("xmin", "ymin", "xmax", "ymax")
                )
            except (TypeError, ValueError):
                # A coordinate tag is missing (None) or not a number.
                continue
            x1, y1 = xmin * sx, ymin * sy
            x2, y2 = xmax * sx, ymax * sy
            if x2 > x1 and y2 > y1:
                boxes.append([x1, y1, x2, y2])
                labels.append(self.cls_to_idx[name])

        x = torchvision.transforms.functional.to_tensor(img)
        if self.augment:
            import random
            if random.random() < 0.5 and boxes:
                x = torchvision.transforms.functional.hflip(x)
                # Mirror the x-coordinates; the old right edge becomes the new left edge.
                for b in boxes:
                    x1, y1, x2, y2 = b
                    b[0], b[2] = self.size - x2, self.size - x1
        if self.apply_normalize:
            x = self.normalize(x)

        target = {
            "boxes": torch.tensor(boxes, dtype=torch.float32) if boxes else torch.zeros((0, 4), dtype=torch.float32),
            "labels": torch.tensor(labels, dtype=torch.int64) if labels else torch.zeros((0,), dtype=torch.int64),
            "image_id": torch.tensor([i]),
        }
        return x, target


def collate_fn(b):
    """Turn a batch of (image, target) pairs into (list of images, list of targets).

    Images are kept in a list rather than stacked into one tensor.
    """
    imgs, tgts = map(list, zip(*b))
    return imgs, tgts


# DataLoader
train_ds = VOCDataset(VOC_DIR, "train", size=800, augment=True, classes=FINAL_CLASSES)
val_ds = VOCDataset(VOC_DIR, "val", size=800, augment=False, classes=FINAL_CLASSES)

# On Windows, DataLoader workers are started with "spawn" and have to re-import
# the dataset class. A class defined in a notebook cell lives in __main__ and
# cannot be re-imported by the workers, so loading fails or hangs there; fall
# back to in-process loading on that platform and keep 2 workers elsewhere.
NUM_WORKERS = 0 if os.name == "nt" else 2

train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=NUM_WORKERS, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=2, shuffle=False, num_workers=NUM_WORKERS, collate_fn=collate_fn)

print("Classes:", FINAL_CLASSES, "| #train:", len(train_ds), "| #val:", len(val_ds))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


In [None]:
import numpy as np
import torch


def box_iou(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise IoU between two sets of ``[x1, y1, x2, y2]`` boxes.

    Args:
        a: array of shape (N, 4).
        b: array of shape (M, 4).

    Returns:
        float32 array of shape (N, M). Rows belonging to a box of ``a`` with
        non-positive area are left at zero.
    """
    N, M = a.shape[0], b.shape[0]
    ious = np.zeros((N, M), dtype=np.float32)
    if N == 0 or M == 0:
        return ious

    # The areas of b do not depend on the row of a, so compute them once.
    barea = np.maximum(0, b[:, 2] - b[:, 0]) * np.maximum(0, b[:, 3] - b[:, 1])
    for i in range(N):
        ax1, ay1, ax2, ay2 = a[i]
        aarea = max(0, ax2 - ax1) * max(0, ay2 - ay1)
        if aarea <= 0:
            continue
        xx1 = np.maximum(ax1, b[:, 0])
        yy1 = np.maximum(ay1, b[:, 1])
        xx2 = np.minimum(ax2, b[:, 2])
        yy2 = np.minimum(ay2, b[:, 3])
        inter = np.maximum(0, xx2 - xx1) * np.maximum(0, yy2 - yy1)
        union = aarea + barea - inter + 1e-8  # epsilon guards against 0/0
        ious[i] = inter / union
    return ious


@torch.no_grad()
def evaluate_ap50(model, loader, iou_th=0.5, score_th=0.05, max_det=100):
    """Evaluate a detector with an 11-point interpolated AP at one IoU threshold.

    Detections of all images and classes are pooled into a single
    precision/recall curve. A detection is a true positive when its best
    overlapping ground-truth box reaches ``iou_th``, has not been claimed by a
    higher-scoring detection of the same image, and carries the same class
    label. The label check is applied whenever both the model output and the
    target provide ``"labels"``; otherwise matching is class-agnostic.

    Args:
        model: detection model; called as ``model(list_of_images)`` and expected
            to return one dict per image with ``"boxes"``, ``"scores"`` and
            optionally ``"labels"``. It is switched to eval mode and left there.
        loader: iterable of ``(images, targets)`` batches; each target is a dict
            with ``"boxes"`` and optionally ``"labels"``.
        iou_th: minimum IoU for a match.
        score_th: detections scoring below this are discarded.
        max_det: at most this many detections (highest scores) are kept per image.

    Returns:
        dict with ``"AP50"``, and ``"Precision"`` / ``"Recall"`` measured over
        all kept detections. All three are 0.0 when there are no detections.
    """
    model.eval()
    device = next(model.parameters()).device

    all_scores, all_tp, all_fp, npos = [], [], [], 0
    for imgs, targets in loader:
        imgs = [im.to(device) for im in imgs]
        outs = model(imgs)
        for out, tgt in zip(outs, targets):
            gt_boxes = tgt["boxes"].cpu().numpy()
            npos += len(gt_boxes)

            boxes = out["boxes"].cpu().numpy()
            scores = out["scores"].cpu().numpy()

            # Indices of the detections to keep: score filter, then the top
            # max_det, finally sorted by descending score.
            sel = np.flatnonzero(scores >= score_th)
            if len(sel) > max_det:
                sel = sel[np.argsort(-scores[sel])[:max_det]]
            sel = sel[np.argsort(-scores[sel])]
            boxes, scores = boxes[sel], scores[sel]

            if len(boxes) and len(gt_boxes):
                ious = box_iou(boxes, gt_boxes)
            else:
                ious = np.zeros((len(boxes), len(gt_boxes)))

            # A detection may only match ground truth of its own class: pairs
            # with different labels get a negative IoU so they can never match.
            if "labels" in out and "labels" in tgt:
                det_labels = out["labels"].cpu().numpy()[sel]
                gt_labels = tgt["labels"].cpu().numpy()
                same_class = det_labels[:, None] == gt_labels[None, :]
                ious = np.where(same_class, ious, -1.0)

            used = np.zeros((len(gt_boxes),), dtype=bool)
            tp = np.zeros((len(boxes),), dtype=np.float32)
            fp = np.zeros((len(boxes),), dtype=np.float32)
            for i in range(len(boxes)):
                if len(gt_boxes) == 0:
                    fp[i] = 1
                    continue
                j = np.argmax(ious[i])
                if ious[i, j] >= iou_th and not used[j]:
                    tp[i] = 1
                    used[j] = True
                else:
                    # Too little overlap, wrong class, or a duplicate detection.
                    fp[i] = 1

            all_scores.extend(scores.tolist())
            all_tp.extend(tp.tolist())
            all_fp.extend(fp.tolist())

    if not all_scores:
        return {"AP50": 0.0, "Precision": 0.0, "Recall": 0.0}

    order = np.argsort(-np.array(all_scores))
    tp = np.array(all_tp)[order]
    fp = np.array(all_fp)[order]
    tp_cum = np.cumsum(tp)
    fp_cum = np.cumsum(fp)

    rec = tp_cum / max(npos, 1)
    prec = tp_cum / np.maximum(tp_cum + fp_cum, 1e-8)

    # VOC2007 11-pt
    ap = 0.0
    for t in np.linspace(0, 1, 11):
        p = prec[rec >= t].max() if np.any(rec >= t) else 0
        ap += p / 11.0

    return {"AP50": float(ap), "Precision": float(prec[-1] if len(prec) else 0),
            "Recall": float(rec[-1] if len(rec) else 0)}


In [None]:
# B∆∞·ªõc 3: Hu·∫•n luy·ªán l·∫°i m√¥ h√¨nh Faster R-CNN
import torchvision
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR
import time
import psutil
import gc

print("=== B·∫ÆT ƒê·∫¶U HU·∫§N LUY·ªÜN M√î H√åNH ===")

# Initialise the Faster R-CNN model with IMAGENET1K_V2 ResNet-50 backbone
# weights and NUM_CLASSES output classes, then move it to the selected device.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(
    weights_backbone=torchvision.models.ResNet50_Weights.IMAGENET1K_V2,
    num_classes=NUM_CLASSES
).to(device)

# Optimizer and scheduler: SGD with momentum; the learning rate is multiplied
# by 0.1 every 3 scheduler steps (the training loop steps it once per epoch).
optimizer = SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)
scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

# Training settings
num_epochs = 5  # Reduced number of epochs so the run is faster on Kaggle
print_freq = 20  # Lowered from 50 to 20 to print progress more often

print(f"Thi·∫øt b·ªã: {device}")
print(f"S·ªë l∆∞·ª£ng classes: {NUM_CLASSES}")
print(f"S·ªë epoch: {num_epochs}")
print(f"S·ªë m·∫´u training: {len(train_ds)}")
print(f"S·ªë m·∫´u validation: {len(val_ds)}")
print(f"Classes: {FINAL_CLASSES}")

best_ap50 = 0.0  # Best validation AP50 seen so far


def get_memory_usage():
    """Report current memory consumption in GiB.

    Returns:
        Tuple ``(gpu_allocated, gpu_reserved, process_rss)``. The two GPU values
        are 0 when CUDA is not available.
    """
    bytes_per_gib = 1024**3

    gpu_memory = gpu_memory_cached = 0
    if torch.cuda.is_available():
        gpu_memory = torch.cuda.memory_allocated() / bytes_per_gib
        gpu_memory_cached = torch.cuda.memory_reserved() / bytes_per_gib

    cpu_memory = psutil.Process().memory_info().rss / bytes_per_gib
    return gpu_memory, gpu_memory_cached, cpu_memory


def print_detailed_step_info(step, total_steps, loss_dict, running_loss, step_count, step_time):
    """Print a progress report for one training step.

    Shows the step counter and duration, the running average loss
    (``running_loss / step_count``) next to the current total loss, the
    individual loss components that are present in ``loss_dict``, the learning
    rate of the first parameter group of the module-level ``optimizer``, memory
    usage, and the step rate.
    """
    current_lr = optimizer.param_groups[0]['lr']
    avg_loss = running_loss / step_count
    gpu_mem, gpu_cached, cpu_mem = get_memory_usage()

    # Plain-float copy of every loss component
    loss_components = {key: value.item() for key, value in loss_dict.items()}

    print(f"  üìä Step [{step+1:4d}/{total_steps}] - {step_time:.2f}s")
    print(f"     üí∞ Total Loss: {avg_loss:.4f} | Current: {sum(loss_dict.values()).item():.4f}")

    # One line per known loss component, in a fixed order, skipping absent ones
    component_prefixes = (
        ('loss_classifier', "     üéØ Classifier: "),
        ('loss_box_reg', "     üì¶ Box Reg: "),
        ('loss_objectness', "     üîç Objectness: "),
        ('loss_rpn_box_reg', "     üé™ RPN Box: "),
    )
    for key, prefix in component_prefixes:
        if key in loss_components:
            print(f"{prefix}{loss_components[key]:.4f}")

    print(f"     üìà Learning Rate: {current_lr:.6f}")
    print(f"     üíæ Memory - GPU: {gpu_mem:.2f}GB | CPU: {cpu_mem:.2f}GB")
    print(f"     ‚è±Ô∏è  Steps/sec: {1/step_time:.2f}")
    print()


# B·∫Øt ƒë·∫ßu hu·∫•n luy·ªán
# Wall-clock reference for the whole training run.
total_start_time = time.time()

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    epoch_start_time = time.time()
    step_count = 0

    print(f"\n{'='*60}")
    print(f"üöÄ EPOCH {epoch+1}/{num_epochs}")
    print(f"{'='*60}")

    for i, (images, targets) in enumerate(train_loader):
        step_start_time = time.time()

        # Move the batch to the device
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Forward pass: called with targets, the model returns a dict of losses
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        # Backward pass
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        running_loss += losses.item()
        step_count += 1
        step_time = time.time() - step_start_time

        # Print a detailed report every print_freq steps
        if i % print_freq == 0:
            print_detailed_step_info(i, len(train_loader), loss_dict,
                                   running_loss, step_count, step_time)

            # Release cached GPU memory and run the garbage collector
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

    # Update the learning rate (once per epoch)
    old_lr = optimizer.param_groups[0]['lr']
    scheduler.step()
    new_lr = optimizer.param_groups[0]['lr']

    # Compute the epoch statistics
    epoch_time = time.time() - epoch_start_time
    avg_epoch_loss = running_loss / len(train_loader)

    print(f"\n{'='*60}")
    print(f"üìã EPOCH {epoch+1} SUMMARY")
    print(f"{'='*60}")
    print(f"‚è∞ Th·ªùi gian: {epoch_time:.2f}s ({epoch_time/60:.1f} ph√∫t)")
    print(f"üìâ Loss trung b√¨nh: {avg_epoch_loss:.4f}")
    print(f"üìà Learning Rate: {old_lr:.6f} ‚Üí {new_lr:.6f}")
    print(f"üî¢ T·ªïng s·ªë steps: {len(train_loader)}")
    print(f"‚ö° T·ªëc ƒë·ªô: {len(train_loader)/epoch_time:.2f} steps/sec")

    # Evaluate on the validation set after every epoch
    print(f"\nüîç ƒêang ƒë√°nh gi√° tr√™n t·∫≠p validation...")
    eval_start_time = time.time()
    metrics = evaluate_ap50(model, val_loader, iou_th=0.5, score_th=0.05, max_det=100)
    eval_time = time.time() - eval_start_time

    print(f"‚úÖ Evaluation ho√†n th√†nh trong {eval_time:.2f}s")
    print(f"üéØ Validation Metrics:")
    print(f"   AP50: {metrics['AP50']:.4f}")
    print(f"   Precision: {metrics['Precision']:.4f}")
    print(f"   Recall: {metrics['Recall']:.4f}")

    # SAVE A CHECKPOINT AFTER EVERY EPOCH
    # 'epoch' stores the number of completed epochs (1-based).
    checkpoint = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'epoch': epoch + 1,
        'classes': FINAL_CLASSES,
        'num_classes': NUM_CLASSES,
        'loss': avg_epoch_loss,
        'metrics': metrics,
        'training_time': epoch_time,
        'learning_rate': new_lr
    }

    # Save the checkpoint of the current epoch
    checkpoint_path = f"ckpt_voc_merged_epoch_{epoch+1}.pth"
    torch.save(checkpoint, checkpoint_path)
    print(f"‚úÖ ƒê√£ l∆∞u checkpoint: {checkpoint_path}")

    # Also keep the best checkpoint (by AP50); the first epoch always qualifies
    if epoch == 0 or metrics['AP50'] > best_ap50:
        best_ap50 = metrics['AP50']
        best_checkpoint_path = "ckpt_voc_merged_best.pth"
        torch.save(checkpoint, best_checkpoint_path)
        print(f"üèÜ ƒê√£ l∆∞u best checkpoint: {best_checkpoint_path} (AP50: {best_ap50:.4f})")

    total_elapsed = time.time() - total_start_time
    print(f"‚è±Ô∏è  T·ªïng th·ªùi gian ƒë√£ train: {total_elapsed/60:.1f} ph√∫t")
    print(f"üìÅ B·∫°n c√≥ th·ªÉ d·ª´ng v√† ti·∫øp t·ª•c t·ª´ epoch {epoch+1}")

# Compute the total training time
total_training_time = time.time() - total_start_time

# Save the final model (this checkpoint carries no 'loss' or 'metrics' entry)
final_checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),
    'epoch': num_epochs,
    'classes': FINAL_CLASSES,
    'num_classes': NUM_CLASSES,
    'total_training_time': total_training_time
}

torch.save(final_checkpoint, "ckpt_voc_merged_finetuned.pth")
print(f"\n{'='*60}")
print("üéâ HO√ÄN TH√ÄNH TRAINING")
print(f"{'='*60}")
print(f"‚è∞ T·ªïng th·ªùi gian training: {total_training_time/60:.1f} ph√∫t")
print(f"üèÜ Best AP50 ƒë·∫°t ƒë∆∞·ª£c: {best_ap50:.4f}")
print(f"üíæ File checkpoint cu·ªëi c√πng: ckpt_voc_merged_finetuned.pth")


In [None]:
# ƒê√°nh gi√° m√¥ h√¨nh ƒë√£ hu·∫•n luy·ªán
print("=== ƒê√ÅNH GI√Å M√î H√åNH ƒê√É HU·∫§N LUY·ªÜN ===")

import torchvision

# Load the final checkpoint and rebuild a model with the same constructor
# arguments as in training, so the saved state dict can be loaded into it.
ckpt = torch.load("ckpt_voc_merged_finetuned.pth", map_location=device)
model_eval = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(
    weights_backbone=torchvision.models.ResNet50_Weights.IMAGENET1K_V2,
    num_classes=NUM_CLASSES
).to(device)
model_eval.load_state_dict(ckpt["model"])

# Final evaluation on the validation loader (evaluate_ap50 switches the model
# to eval mode itself)
final_metrics = evaluate_ap50(model_eval, val_loader, iou_th=0.5, score_th=0.05, max_det=100)

print("K·∫æT QU·∫¢ ƒê√ÅNH GI√Å CU·ªêI C√ôNG:")
print(f"AP50: {final_metrics['AP50']:.4f}")
print(f"Precision: {final_metrics['Precision']:.4f}")
print(f"Recall: {final_metrics['Recall']:.4f}")

# Last expression of the cell: shows the metrics dict as the cell output
final_metrics


In [None]:
# TI·∫æP T·ª§C TRAINING T·ª™ CHECKPOINT (n·∫øu b·ªã d·ª´ng gi·ªØa ch·ª´ng)
import os
import torch
import torchvision

def resume_training_from_checkpoint(checkpoint_path, num_additional_epochs=3):
    """Continue training from a checkpoint written by the training loop.

    Rebuilds the model, optimizer and scheduler, restores their saved state and
    trains for ``num_additional_epochs`` more epochs on the module-level
    ``train_loader``. After every epoch the model is evaluated on ``val_loader``
    with ``evaluate_ap50`` and a per-epoch checkpoint is saved; the best
    checkpoint is updated when AP50 improves, and the last epoch is also saved
    as ``ckpt_voc_merged_finetuned.pth``.

    Args:
        checkpoint_path: path of the checkpoint to resume from.
        num_additional_epochs: number of epochs to train; must be at least 1.

    Returns:
        The trained model, or None when ``checkpoint_path`` does not exist.

    Raises:
        ValueError: if ``num_additional_epochs`` is smaller than 1.
    """
    # Imported here so the function also works in a fresh kernel where the
    # training cell (which holds these imports at module level) was not run.
    import time
    from torch.optim import SGD
    from torch.optim.lr_scheduler import StepLR

    print(f"=== TI·∫æP T·ª§C TRAINING T·ª™ {checkpoint_path} ===")

    # Check that the checkpoint file exists
    if not os.path.exists(checkpoint_path):
        print(f"‚ùå Kh√¥ng t√¨m th·∫•y checkpoint: {checkpoint_path}")
        return

    # Fail fast: with no epoch to run there would be no checkpoint to save at
    # the end, which previously surfaced as an UnboundLocalError after the
    # model had already been loaded.
    if num_additional_epochs < 1:
        raise ValueError(
            f"num_additional_epochs must be >= 1, got {num_additional_epochs}"
        )

    # Load checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)
    print(f"üìÇ ƒê√£ load checkpoint t·ª´ epoch {checkpoint['epoch']}")

    # Re-create the model
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(
        weights_backbone=torchvision.models.ResNet50_Weights.IMAGENET1K_V2,
        num_classes=checkpoint['num_classes']
    ).to(device)

    # Load state dict
    model.load_state_dict(checkpoint['model'])

    # Re-create the optimizer and scheduler
    optimizer = SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)
    scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

    # Restore the optimizer and scheduler state
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])

    start_epoch = checkpoint['epoch']
    # Checkpoints without a 'metrics' entry start from a best AP50 of 0.0.
    best_ap50 = checkpoint.get('metrics', {}).get('AP50', 0.0)

    print(f"üîÑ Ti·∫øp t·ª•c t·ª´ epoch {start_epoch + 1}")
    print(f"üèÜ Best AP50 hi·ªán t·∫°i: {best_ap50:.4f}")

    # Continue training
    for epoch in range(start_epoch, start_epoch + num_additional_epochs):
        model.train()
        running_loss = 0.0
        start_time = time.time()

        print(f"\n--- EPOCH {epoch+1}/{start_epoch + num_additional_epochs} ---")

        for i, (images, targets) in enumerate(train_loader):
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

            running_loss += losses.item()

            if i % 50 == 0:
                avg_loss = running_loss / (i + 1)
                print(f"  Step [{i+1}/{len(train_loader)}], Loss: {avg_loss:.4f}")

        scheduler.step()

        epoch_time = time.time() - start_time
        avg_epoch_loss = running_loss / len(train_loader)

        print(f"  Epoch ho√†n th√†nh trong {epoch_time:.2f}s")
        print(f"  Loss trung b√¨nh: {avg_epoch_loss:.4f}")

        # Evaluate
        print("  ƒêang ƒë√°nh gi√° tr√™n t·∫≠p validation...")
        metrics = evaluate_ap50(model, val_loader, iou_th=0.5, score_th=0.05, max_det=100)
        print(f"  Validation - AP50: {metrics['AP50']:.4f}, "
              f"Precision: {metrics['Precision']:.4f}, Recall: {metrics['Recall']:.4f}")

        # Save the new checkpoint
        new_checkpoint = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'epoch': epoch + 1,
            'classes': checkpoint['classes'],
            'num_classes': checkpoint['num_classes'],
            'loss': avg_epoch_loss,
            'metrics': metrics
        }

        checkpoint_path_new = f"ckpt_voc_merged_epoch_{epoch+1}.pth"
        torch.save(new_checkpoint, checkpoint_path_new)
        print(f"  ‚úÖ ƒê√£ l∆∞u checkpoint: {checkpoint_path_new}")

        # Update the best checkpoint
        if metrics['AP50'] > best_ap50:
            best_ap50 = metrics['AP50']
            best_checkpoint_path = "ckpt_voc_merged_best.pth"
            torch.save(new_checkpoint, best_checkpoint_path)
            print(f"  üèÜ ƒê√£ c·∫≠p nh·∫≠t best checkpoint: {best_checkpoint_path} (AP50: {best_ap50:.4f})")

    # Save the final checkpoint (the one from the last epoch run above)
    final_checkpoint_path = "ckpt_voc_merged_finetuned.pth"
    torch.save(new_checkpoint, final_checkpoint_path)
    print(f"\n‚úÖ Ho√†n th√†nh! L∆∞u checkpoint cu·ªëi: {final_checkpoint_path}")

    return model

# USAGE EXAMPLE:
# To continue from epoch 1, use:
# model_resumed = resume_training_from_checkpoint("ckpt_voc_merged_epoch_1.pth", num_additional_epochs=4)

# To continue from epoch 2, use:
# model_resumed = resume_training_from_checkpoint("ckpt_voc_merged_epoch_2.pth", num_additional_epochs=3)

# Print the usage instructions (which checkpoint to resume from and which
# checkpoint files get written).
print("\nüìã H∆Ø·ªöNG D·∫™N S·ª¨ D·ª§NG:")
print("1. N·∫øu training b·ªã d·ª´ng ·ªü epoch 2, b·∫°n c√≥ file: ckpt_voc_merged_epoch_1.pth")
print("2. ƒê·ªÉ ti·∫øp t·ª•c, ch·∫°y:")
print("   model_resumed = resume_training_from_checkpoint('ckpt_voc_merged_epoch_1.pth', num_additional_epochs=4)")
print("3. C√°c file checkpoint s·∫Ω ƒë∆∞·ª£c l∆∞u:")
print("   - ckpt_voc_merged_epoch_X.pth (sau m·ªói epoch)")
print("   - ckpt_voc_merged_best.pth (model t·ªët nh·∫•t)")
print("   - ckpt_voc_merged_finetuned.pth (model cu·ªëi c√πng)")

