# In [None]:
import os
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import xml.etree.ElementTree as ET
from PIL import Image
import numpy as np
from sklearn.model_selection import train_test_split
import json
from pathlib import Path
import matplotlib.pyplot as plt
import cv2
from tqdm import tqdm
import time
from collections import defaultdict

# Dataset root. Set the DATASET_DIR environment variable to run on another
# machine; the default is the original author's local path.
processed_dir = os.environ.get("DATASET_DIR", r"D:\xla v1\code\tienxuli\archive")
images_dir = os.path.join(processed_dir, "images")  # image files (.jpg / .png)
annotations_dir = os.path.join(processed_dir, "xmls")  # one XML annotation per image

class ObjectDetectionDataset(Dataset):
    """Dataset pairing image files with XML bounding-box annotations.

    Each item is ``(image, target)`` where ``target`` is the dict consumed by
    the torchvision detection model in ``train_one_epoch``:
    ``boxes`` (float32, ``[xmin, ymin, xmax, ymax]`` rows), ``labels`` (int64),
    ``image_id``, ``area`` and ``iscrowd``. Items that fail to load are
    returned as ``None`` so ``collate_fn`` can drop them.
    """

    def __init__(self, images_dir, annotations_dir, image_files, class_to_idx, transforms=None):
        """
        Args:
            images_dir: Directory containing the image files.
            annotations_dir: Directory containing ``<image stem>.xml`` files.
            image_files: Image file names (relative to ``images_dir``) in this split.
            class_to_idx: Mapping from ``<object><name>`` text to label index.
            transforms: Optional callable applied to the RGB PIL image.
        """
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.image_files = image_files
        self.class_to_idx = class_to_idx
        self.transforms = transforms

    def __len__(self):
        return len(self.image_files)

    @staticmethod
    def _xml_name_for(img_name):
        """Return the annotation file name for ``img_name``: same stem, ``.xml``."""
        # splitext handles .png as well as .jpg and only touches the final
        # extension, unlike str.replace('.jpg', '.xml').
        return os.path.splitext(img_name)[0] + ".xml"

    def _parse_annotation(self, xml_path):
        """Read (boxes, labels) lists from one XML file, skipping degenerate boxes.

        Raises KeyError for a class name missing from ``class_to_idx``; the
        caller reports it and drops the sample.
        """
        boxes, labels = [], []
        root = ET.parse(xml_path).getroot()
        for obj in root.findall('object'):
            name = obj.find('name').text
            bbox = obj.find('bndbox')
            xmin, ymin, xmax, ymax = (
                float(bbox.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')
            )
            # Zero/negative-size boxes would break the detector's box losses.
            if xmax > xmin and ymax > ymin:
                boxes.append([xmin, ymin, xmax, ymax])
                labels.append(self.class_to_idx[name])
        return boxes, labels

    def __getitem__(self, idx):
        # Pre-bind so the error report below can never hit an unbound name
        # (e.g. when the index is invalid or the image cannot be opened).
        img_path = None
        xml_path = None
        try:
            img_name = self.image_files[idx]
            img_path = os.path.join(self.images_dir, img_name)

            if not os.path.exists(img_path):
                print(f"Lỗi: File ảnh không tồn tại: {img_path}")
                return None

            # Context manager closes the underlying file; convert() returns a
            # new, fully loaded image.
            with Image.open(img_path) as img:
                image = img.convert("RGB")

            xml_path = os.path.join(self.annotations_dir, self._xml_name_for(img_name))

            if os.path.exists(xml_path):
                boxes, labels = self._parse_annotation(xml_path)
            else:
                print(f"Lỗi: File XML không tồn tại: {xml_path}")
                boxes, labels = [], []

            if boxes:
                boxes = torch.as_tensor(boxes, dtype=torch.float32)
                labels = torch.as_tensor(labels, dtype=torch.int64)
            else:
                # Keep a (0, 4) shape so slicing and the model accept "no objects".
                boxes = torch.zeros((0, 4), dtype=torch.float32)
                labels = torch.zeros((0,), dtype=torch.int64)

            target = {
                "boxes": boxes,
                "labels": labels,
                "image_id": torch.tensor([idx]),
                "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
                "iscrowd": torch.zeros((len(boxes),), dtype=torch.int64),
            }

            if self.transforms:
                image = self.transforms(image)

            return image, target

        except Exception as e:
            # Best-effort loading: report and drop the sample (collate_fn filters None).
            print(f"Lỗi trong __getitem__ với index {idx}: {e}")
            print(f"Image path: {img_path}")
            print(f"XML path: {xml_path}")
            return None

def get_class_mapping(xml_dir=None):
    """Scan annotation XMLs and build the class <-> label-index mappings.

    Args:
        xml_dir: Directory of ``.xml`` annotation files. Defaults to the
            module-level ``annotations_dir``.

    Returns:
        ``(class_to_idx, idx_to_class, class_names)``. ``class_names`` is the
        sorted list of distinct non-empty ``<object><name>`` values. Indices
        start at 1 because torchvision detectors reserve label 0 for background.
    """
    directory = annotations_dir if xml_dir is None else xml_dir
    class_names = set()

    for xml_file in os.listdir(directory):
        if not xml_file.endswith('.xml'):
            continue
        xml_path = os.path.join(directory, xml_file)
        try:
            root = ET.parse(xml_path).getroot()
        except (ET.ParseError, OSError) as e:
            print(f"Lỗi khi parse XML file {xml_path}: {e}")
            continue

        for obj in root.findall('object'):
            name_el = obj.find('name')
            # Skip objects with a missing or empty <name>: a None entry would
            # make sorted() below raise TypeError.
            if name_el is not None and name_el.text:
                class_names.add(name_el.text)

    class_names = sorted(class_names)
    class_to_idx = {name: idx for idx, name in enumerate(class_names, start=1)}
    idx_to_class = {idx: name for name, idx in class_to_idx.items()}

    print(f"Số lượng classes: {len(class_names)}")
    print(f"Classes: {class_names}")

    return class_to_idx, idx_to_class, class_names

def split_data(test_size=0.2, val_size=0.1, random_state=42, image_dir=None):
    """Split image file names into train / validation / test lists.

    Args:
        test_size: Fraction of all images placed in the test split.
        val_size: Fraction of all images placed in the validation split.
        random_state: Seed passed to both ``train_test_split`` calls.
        image_dir: Directory to list; defaults to the module-level ``images_dir``.

    Returns:
        ``(train_files, val_files, test_files)`` lists of file names.
    """
    directory = images_dir if image_dir is None else image_dir
    # os.listdir order is arbitrary and filesystem-dependent. Sorting makes the
    # seeded split reproducible, so a resumed run validates on the same images
    # as before rather than on images it may already have trained on.
    image_files = sorted(
        f for f in os.listdir(directory) if f.endswith(('.jpg', '.png'))
    )

    train_files, temp_files = train_test_split(
        image_files, test_size=test_size + val_size, random_state=random_state
    )

    # Re-express test_size as a fraction of the held-out remainder.
    val_files, test_files = train_test_split(
        temp_files, test_size=test_size / (test_size + val_size), random_state=random_state
    )

    print(f"Train: {len(train_files)} files")
    print(f"Validation: {len(val_files)} files")
    print(f"Test: {len(test_files)} files")

    return train_files, val_files, test_files

def get_model(num_classes):
    """Build a pretrained Faster R-CNN (ResNet-50 FPN) with a dataset-sized box head.

    The new predictor has ``num_classes + 1`` outputs: one per dataset class
    plus torchvision's background class at label 0.

    NOTE: torchvision >= 0.13 deprecates ``pretrained=`` in favour of
    ``weights=``; the old keyword is kept for compatibility with older versions.
    """
    detection = torchvision.models.detection
    detector = detection.fasterrcnn_resnet50_fpn(pretrained=True)

    # Replace the classification/regression head; its input width comes from
    # the existing head so the backbone/ROI features are reused unchanged.
    head_in_features = detector.roi_heads.box_predictor.cls_score.in_features
    detector.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(
        head_in_features, num_classes + 1
    )
    return detector

def collate_fn(batch):
    """Drop failed (``None``) samples and transpose the batch into columns.

    For samples shaped ``(image, target)`` this returns ``(images, targets)``
    tuples; when every sample failed it returns ``(None, None)`` so callers
    can skip the batch.
    """
    kept = list(filter(lambda sample: sample is not None, batch))
    if len(kept) == 0:
        return None, None
    return tuple(zip(*kept))

def train_one_epoch(model, optimizer, data_loader, device, epoch):
    """Run one training epoch and return the mean summed loss per processed batch.

    Args:
        model: Detector that, in train mode, returns a dict of loss tensors
            from ``model(images, targets)``.
        optimizer: Optimizer over the model's trainable parameters.
        data_loader: Yields ``(images, targets)`` batches from ``collate_fn``;
            ``(None, None)`` batches (every sample failed to load) are skipped.
        device: Device the images and target tensors are moved to.
        epoch: Epoch number, used only in the progress-bar label.

    Returns:
        Mean loss over processed batches, or ``float('nan')`` when no batch
        could be processed (instead of raising ZeroDivisionError).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    data_loader_tqdm = tqdm(data_loader, desc=f"Epoch {epoch} (Training)", leave=False)

    for images, targets in data_loader_tqdm:
        if images is None or targets is None:
            # tqdm.write keeps the message from breaking the progress bar.
            tqdm.write("Bỏ qua batch vì có lỗi trong collate_fn")
            continue

        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        loss_value = losses.item()  # a single device->host sync per batch
        total_loss += loss_value
        num_batches += 1

        data_loader_tqdm.set_postfix({"loss": loss_value})

    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches

@torch.no_grad()
def evaluate(model, data_loader, device, class_names, idx_to_class, iou_threshold=0.5):
    """Compute mean Average Precision at ``iou_threshold`` over ``data_loader``.

    Within each image, detections of a class are matched in descending score
    order to the still-unmatched ground-truth box of the same class with the
    highest IoU; a match requires IoU >= ``iou_threshold``. Unmatched
    detections are false positives. Outcomes are pooled across images per
    class, and AP is the area under the all-point-interpolated
    precision/recall curve. The result is the mean AP over classes that have
    at least one ground-truth box.

    Args:
        model: Detector returning per-image dicts with ``boxes``, ``scores``
            and ``labels`` when called in eval mode.
        data_loader: Yields ``(images, targets)`` batches from ``collate_fn``;
            ``(None, None)`` batches are skipped.
        device: Device the images are moved to.
        class_names: Class names; labels ``1..len(class_names)`` are evaluated.
        idx_to_class: Unused; kept for call compatibility.
        iou_threshold: Minimum IoU for a detection to count as a true positive.

    Returns:
        Mean AP as a Python float (0.0 when no class has ground truth).
    """
    model.eval()
    det_scores = defaultdict(list)  # class idx -> score of every detection
    det_is_tp = defaultdict(list)   # class idx -> whether that detection matched
    num_gt = defaultdict(int)       # class idx -> number of ground-truth boxes

    data_loader_tqdm = tqdm(data_loader, desc="Evaluating", leave=False)

    for images, targets in data_loader_tqdm:
        if images is None or targets is None:
            tqdm.write("Bỏ qua batch vì có lỗi trong collate_fn")
            continue

        images = [image.to(device) for image in images]
        outputs = model(images)

        for output, target in zip(outputs, targets):
            _match_image_detections(
                output, target, iou_threshold, det_scores, det_is_tp, num_gt
            )

    # Classes absent from this split have undefined AP; averaging in zeros for
    # them would only dilute the score.
    aps = [
        _average_precision(det_scores[c], det_is_tp[c], num_gt[c])
        for c in range(1, len(class_names) + 1)
        if num_gt[c] > 0
    ]
    return sum(aps) / len(aps) if aps else 0.0


def _match_image_detections(output, target, iou_threshold, det_scores, det_is_tp, num_gt):
    """Match one image's detections to its ground truth and record outcomes per class."""
    pred_boxes = output["boxes"].cpu()
    pred_scores = output["scores"].cpu()
    pred_labels = output["labels"].cpu()
    gt_boxes = target["boxes"].cpu()
    gt_labels = target["labels"].cpu()

    for class_idx in torch.unique(torch.cat([pred_labels, gt_labels])).tolist():
        cls_gt = gt_boxes[gt_labels == class_idx]
        num_gt[class_idx] += len(cls_gt)

        det_mask = pred_labels == class_idx
        if not det_mask.any():
            continue
        cls_scores, order = torch.sort(pred_scores[det_mask], descending=True)
        cls_boxes = pred_boxes[det_mask][order]

        ious = box_iou(cls_boxes, cls_gt) if len(cls_gt) else None
        matched = torch.zeros(len(cls_gt), dtype=torch.bool)
        for d, score in enumerate(cls_scores.tolist()):
            is_tp = False
            if ious is not None:
                # Each ground-truth box can be claimed by one detection only;
                # duplicates of an already-found object are false positives.
                candidate = ious[d].masked_fill(matched, -1.0)
                best_iou, best_gt = candidate.max(dim=0)
                if best_iou.item() >= iou_threshold:
                    matched[best_gt] = True
                    is_tp = True
            det_scores[class_idx].append(score)
            det_is_tp[class_idx].append(is_tp)


def _average_precision(scores, is_tp, n_gt):
    """All-point interpolated AP for one class from pooled detection outcomes."""
    if n_gt == 0 or not scores:
        return 0.0
    ranked = sorted(zip(scores, is_tp), key=lambda pair: pair[0], reverse=True)

    tp = fp = 0
    recalls, precisions = [], []
    for _, hit in ranked:
        if hit:
            tp += 1
        else:
            fp += 1
        recalls.append(tp / n_gt)
        precisions.append(tp / (tp + fp))

    # Precision envelope: running max from the right makes it non-increasing.
    for k in range(len(precisions) - 2, -1, -1):
        precisions[k] = max(precisions[k], precisions[k + 1])

    ap = 0.0
    prev_recall = 0.0
    for recall, precision in zip(recalls, precisions):
        ap += (recall - prev_recall) * precision
        prev_recall = recall
    return ap

def box_iou(boxes1, boxes2):
    """Pairwise IoU between two sets of ``[x1, y1, x2, y2]`` boxes.

    Args:
        boxes1: Tensor of shape (N, 4).
        boxes2: Tensor of shape (M, 4).

    Returns:
        Tensor of shape (N, M) where entry ``[i, j]`` is IoU(boxes1[i], boxes2[j]).
    """
    def _area(boxes):
        return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    # Broadcast (N, 1, 2) against (1, M, 2) to get every pair's overlap corners.
    top_left = torch.maximum(boxes1[:, None, :2], boxes2[None, :, :2])
    bottom_right = torch.minimum(boxes1[:, None, 2:], boxes2[None, :, 2:])

    overlap_w, overlap_h = (bottom_right - top_left).clamp(min=0).unbind(-1)
    intersection = overlap_w * overlap_h

    union = _area(boxes1)[:, None] + _area(boxes2)[None, :] - intersection
    return intersection / union

def save_checkpoint(epoch, model, optimizer, val_loss, filename="checkpoint.pth"):
    """Write training state to ``filename`` with ``torch.save``.

    The saved dict has the keys ``epoch``, ``model_state_dict``,
    ``optimizer_state_dict`` and ``val_loss``, which ``load_checkpoint`` reads.
    """
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        val_loss=val_loss,
    )
    torch.save(state, filename)
    print(f"Checkpoint saved to {filename}")

def load_checkpoint(model, optimizer, filename=r"D:\xla v1\model_output\checkpoint\checkpoint (1).pth"):
    """Restore model and optimizer state from a ``save_checkpoint`` file.

    Args:
        model: Module whose ``state_dict`` is restored in place.
        optimizer: Optimizer whose state is restored in place.
        filename: Checkpoint path.

    Returns:
        ``(start_epoch, val_loss)``: the epoch after the saved one and the
        stored ``val_loss``, or ``(0, float('inf'))`` if ``filename`` is missing.
    """
    if not os.path.isfile(filename):
        print(f"No checkpoint found at {filename}")
        return 0, float('inf')

    print(f"Loading checkpoint {filename}")
    # map_location="cpu" lets a checkpoint written on a GPU load on a CPU-only
    # machine; load_state_dict then copies tensors to the parameters' device.
    checkpoint = torch.load(filename, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    best_val_loss = checkpoint['val_loss']
    print(f"Loaded checkpoint from epoch {start_epoch-1} with val_loss {best_val_loss}")
    return start_epoch, best_val_loss

def main():
    """Train the detector end to end, resuming from a checkpoint when present.

    Steps:
      1. Build the class mapping.
      2. Split the data.
      3. Build the loaders.
      4. Create the model, optimizer and scheduler.
      5. Train with per-epoch validation and checkpointing.
      6. Save the final weights.

    The validation metric is kept under the ``val_loss`` name for checkpoint
    compatibility. It is the mean AP returned by ``evaluate``, so higher is
    better.
    """
    # Published as module globals, presumably for later notebook cells; verify.
    global class_names, idx_to_class

    print("=== BƯỚC 1: Tạo class mapping ===")
    class_to_idx, idx_to_class, class_names = get_class_mapping()

    print("\n=== BƯỚC 2: Chia dữ liệu ===")
    train_files, val_files, test_files = split_data()

    print("\n=== BƯỚC 3: Tạo datasets và dataloaders ===")
    train_transforms = transforms.Compose([
        transforms.ToTensor(),
    ])

    val_transforms = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_dataset = ObjectDetectionDataset(
        images_dir, annotations_dir, train_files, class_to_idx, train_transforms
    )
    val_dataset = ObjectDetectionDataset(
        images_dir, annotations_dir, val_files, class_to_idx, val_transforms
    )

    # Windows starts DataLoader workers with "spawn". Each worker must
    # re-import the dataset class and collate_fn, which fails for objects
    # defined in an interactive notebook cell, so load in the main process there.
    num_workers = 0 if os.name == "nt" else 2
    train_loader = DataLoader(
        train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn, num_workers=num_workers
    )
    val_loader = DataLoader(
        val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn, num_workers=num_workers
    )

    print("\n=== BƯỚC 4: Tạo mô hình ===")
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print(f"Using device: {device}")

    model = get_model(len(class_names))
    model.to(device)

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

    print("\n=== BƯỚC 5: Bắt đầu training ===")
    num_epochs = 18

    start_epoch, best_val_loss = load_checkpoint(model, optimizer)
    if start_epoch > 0:
        # Scheduler state is not checkpointed. Realign its epoch counter so
        # StepLR decays on the same epochs as an uninterrupted run; the current
        # lr itself was restored with the optimizer state.
        lr_scheduler.last_epoch = start_epoch
    else:
        # The metric is mean AP (higher is better). The +inf "no checkpoint"
        # default would make `val_loss > best_val_loss` never true, so no best
        # model would ever be saved.
        best_val_loss = float('-inf')

    if start_epoch >= num_epochs:
        print(f"Checkpoint already covers epoch {start_epoch}; "
              f"nothing to train with num_epochs={num_epochs}.")

    train_losses = []
    val_losses = []
    training_start_time = time.time()

    # Bound up front so the KeyboardInterrupt handler can always save.
    epoch = start_epoch
    val_loss = best_val_loss

    try:
        for epoch in range(start_epoch, num_epochs):
            print(f"\nEpoch {epoch+1}/{num_epochs}")
            print("-" * 30)

            epoch_start_time = time.time()
            train_loss = train_one_epoch(model, optimizer, train_loader, device, epoch+1)
            train_losses.append(train_loss)
            epoch_end_time = time.time()
            print(f"Epoch training time: {(epoch_end_time - epoch_start_time):.2f} seconds")

            epoch_start_time = time.time()
            val_loss = evaluate(model, val_loader, device, class_names, idx_to_class)
            val_losses.append(val_loss)
            epoch_end_time = time.time()

            print(f"Val Loss (Mean AP): {val_loss:.4f}")
            print(f"Epoch validation time: {(epoch_end_time - epoch_start_time):.2f} seconds")

            lr_scheduler.step()

            if val_loss > best_val_loss:
                best_val_loss = val_loss
                save_checkpoint(epoch, model, optimizer, val_loss, filename='best_model.pth')
                print("Saved best model!")

            save_checkpoint(epoch, model, optimizer, val_loss)

    except KeyboardInterrupt:
        print("\nTraining interrupted. Saving checkpoint...")
        save_checkpoint(epoch, model, optimizer, val_loss, filename="interrupted_checkpoint.pth")
        print("Checkpoint saved. Exiting.")
        # Return rather than exit(): exit() stops the interpreter (or the
        # notebook kernel), not just this run.
        return

    training_end_time = time.time()
    total_training_time = training_end_time - training_start_time
    print(f"Total training time: {total_training_time:.2f} seconds")

    print("\n=== BƯỚC 6: Lưu kết quả ===")
    torch.save(model.state_dict(), 'final_model.pth')
    print("Lưu mô hình thành công")

# Entry point: run the whole pipeline only when this file is executed as the
# main module, not when it is imported.
if __name__ == "__main__":
    main()

# Captured output of the cell above:
# === BƯỚC 1: Tạo class mapping ===
# Số lượng classes: 76
#
# === BƯỚC 2: Chia dữ liệu ===
# Train: 22355 files
# Validation: 3193 files
# Test: 6388 files
#
# === BƯỚC 3: Tạo datasets và dataloaders ===
#
# === BƯỚC 4: Tạo mô hình ===
# Using device: cuda
#
#
#
#
#
# === BƯỚC 5: Bắt đầu training ===
# Loading checkpoint D:\xla v1\model_output\checkpoint\checkpoint (1).pth
# Loaded checkpoint from epoch 19 with val_loss 0.02996216141588121
# Total training time: 0.00 seconds
#
# === BƯỚC 6: Lưu kết quả ===
# Lưu mô hình thành công
