In [6]:
import os
import xml.etree.ElementTree as ET
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from ultralytics import YOLO

# Định nghĩa đường dẫn
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))
DATA_DIR = os.path.join(PROJECT_ROOT, 'data')
RAW_DIR = os.path.join(DATA_DIR, 'raw')
TRAIN_DIR = os.path.join(RAW_DIR, 'train')
VALID_DIR = os.path.join(RAW_DIR, 'valid')
MODELS_DIR = os.path.join(PROJECT_ROOT, 'models')

# Tạo thư mục models nếu chưa tồn tại
os.makedirs(MODELS_DIR, exist_ok=True)

# Hàm đọc annotation từ file XML
def read_xml_annotation(xml_file):
    tree = ET.parse(xml_file)
    root = tree.getroot()
    
    boxes = []
    labels = []
    
    for obj in root.findall('object'):
        label = obj.find('name').text
        bbox = obj.find('bndbox')
        xmin = int(bbox.find('xmin').text)
        ymin = int(bbox.find('ymin').text)
        xmax = int(bbox.find('xmax').text)
        ymax = int(bbox.find('ymax').text)
        
        boxes.append([xmin, ymin, xmax, ymax])
        labels.append(int(label))
    
    return boxes, labels

# Dataset class
class NumberDetectionDataset(Dataset):
    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.img_files = [f for f in os.listdir(img_dir) if f.endswith('.jpg')]
        
    def __len__(self):
        return len(self.img_files)
    
    def __getitem__(self, idx):
        img_name = self.img_files[idx]
        img_path = os.path.join(self.img_dir, img_name)
        xml_path = os.path.join(self.img_dir, img_name.replace('.jpg', '.xml'))
        
        image = Image.open(img_path).convert('RGB')
        boxes, labels = read_xml_annotation(xml_path)
        
        if self.transform:
            image = self.transform(image)
        
        return image, {'boxes': torch.tensor(boxes), 'labels': torch.tensor(labels)}

# Lightning Module
class NumberDetectionModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = YOLO('yolov8n.yaml')
        
    def training_step(self, batch, batch_idx):
        images, targets = batch
        loss = self.model(images, targets)
        self.log('train_loss', loss)
        return loss
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=1e-3)

# Chuẩn bị dữ liệu
transform = transforms.Compose([
    transforms.Resize((640, 640)),
    transforms.ToTensor(),
])

train_dataset = NumberDetectionDataset(TRAIN_DIR, transform=transform)
valid_dataset = NumberDetectionDataset(VALID_DIR, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=16)

# Huấn luyện mô hình
model = NumberDetectionModule()

checkpoint_callback = ModelCheckpoint(
    dirpath=MODELS_DIR,
    filename='number_detection_model',
    save_top_k=1,
    monitor='val_loss'
)

trainer = pl.Trainer(
    max_epochs=50,
    callbacks=[checkpoint_callback],
    gpus=1 if torch.cuda.is_available() else 0
)

trainer.fit(model, train_loader, valid_loader)

print(f"Mô hình đã được lưu tại: {checkpoint_callback.best_model_path}")

TypeError: Trainer.__init__() got an unexpected keyword argument 'gpus'