In [1]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

# Đường dẫn dataset
DATASET_DIR = 'datasets'
IMAGES_DIR = os.path.join(DATASET_DIR, 'ExDark')
ANNOTATIONS_DIR = os.path.join(DATASET_DIR, 'ExDark_Annno')

In [2]:
# Hàm đọc annotations từ cấu trúc thư mục phân cấp
def read_annotations(annotations_dir, images_dir):
    annotations = {}
    
    # Duyệt qua tất cả các thư mục class trong ExDark_Annno
    if not os.path.exists(annotations_dir):
        print(f"Annotations directory not found: {annotations_dir}")
        return {}
    
    for class_folder in os.listdir(annotations_dir):
        class_anno_path = os.path.join(annotations_dir, class_folder)
        if not os.path.isdir(class_anno_path):
            continue
            
        # Duyệt qua thư mục con (có thể có thêm 1 lớp thư mục class)
        for subfolder in os.listdir(class_anno_path):
            subfolder_path = os.path.join(class_anno_path, subfolder)
            if os.path.isdir(subfolder_path):
                # Nếu có thêm 1 lớp thư mục con
                annotation_files_path = subfolder_path
                # print(f"Found subfolder for class {class_folder}: {subfolder_path}")
            else:
                # Nếu file annotation nằm trực tiếp trong thư mục class
                annotation_files_path = class_anno_path
                # print(f"Using class folder for annotations: {class_anno_path}")
                break
        
        # Đọc các file annotation
        if os.path.exists(annotation_files_path):
            for filename in os.listdir(annotation_files_path):
                if filename.endswith('.txt'):
                    # Tìm file ảnh tương ứng
                    img_name_base = filename.replace('.txt', '')
                    
                    # Tìm file ảnh trong thư mục class tương ứng
                    img_class_path = os.path.join(images_dir, class_folder)
                    img_path = None
                    
                    if os.path.exists(img_class_path):
                        for img_file in os.listdir(img_class_path):
                            if img_file.startswith(img_name_base):
                                img_path = os.path.join(class_folder, img_file)
                                # print(f"Found image for annotation {filename}: {img_path}")
                                break
                    
                    if img_path is None:
                        continue
                    
                    # Đọc annotations từ file txt
                    anno_file_path = os.path.join(annotation_files_path, filename)
                    with open(anno_file_path, 'r') as f:
                        lines = f.readlines()
                        objs = []
                        
                        for line in lines:
                            line = line.strip()
                            if line.startswith('%') or not line:  # Bỏ qua comment và dòng trống
                                continue
                                
                            parts = line.split()
                            if len(parts) < 7:  # Ít nhất cần có label, width, height, xmin, ymin, xmax, ymax
                                continue
                            
                            try:
                                label = parts[0]
                                width = float(parts[1])
                                height = float(parts[2])
                                xmin = float(parts[3])
                                ymin = float(parts[4])
                                xmax = float(parts[5])
                                ymax = float(parts[6])
                                
                                # Chuyển đổi sang format [x, y, width, height] cho bbox
                                bbox_width = abs(xmax - xmin)
                                bbox_height = abs(ymax - ymin)
                                bbox = [xmin, ymin, bbox_width, bbox_height]
                                
                                objs.append({
                                    'label': label,
                                    'bbox': bbox,
                                    'img_width': width,
                                    'img_height': height
                                })
                                
                                # print(f"Parsed object: {label}, bbox: {bbox}, img_size: ({width}, {height})")
                            except ValueError:
                                print(f"Error parsing line in {filename}: {line}")
                                continue
                        
                        if objs:  # Chỉ thêm vào nếu có objects
                            annotations[img_path] = objs
    
    print(f"Loaded annotations for {len(annotations)} images")
    return annotations


In [3]:
# Hàm tạo label map từ annotations
def create_label_map(annotations):
    label2idx = {}
    idx = 0
    for objs in annotations.values():
        for obj in objs:
            label = obj['label']
            if label not in label2idx:
                label2idx[label] = idx
                idx += 1
    return label2idx

In [4]:
# Dataset cho object detection
class ExDarkDataset(Dataset):
    def __init__(self, img_dir, annotations, label2idx, transform=None):
        self.img_dir = img_dir
        self.annotations = annotations
        self.label2idx = label2idx
        self.transform = transform
        self.img_files = list(annotations.keys())
        
        print(f"Dataset initialized with {len(self.img_files)} images")
        if len(self.img_files) > 0:
            print(f"Sample image paths: {self.img_files[:10]}")

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        img_file = self.img_files[idx]
        img_full_path = os.path.join(self.img_dir, img_file)
        
        # Kiểm tra file có tồn tại không
        if not os.path.exists(img_full_path):
            print(f"Image not found: {img_full_path}")
            # Thử tìm với các extension khác
            base_path = os.path.splitext(img_full_path)[0]
            for ext in ['.jpg', '.png', '.jpeg']:
                if os.path.exists(base_path + ext):
                    img_full_path = base_path + ext
                    break
        
        try:
            image = Image.open(img_full_path).convert("RGB")
        except Exception as e:
            print(f"Error loading image {img_full_path}: {e}")
            # Tạo ảnh trắng thay thế
            image = Image.new('RGB', (224, 224), color='white')

        objs = self.annotations[img_file]
        boxes = []
        labels = []
        
        for obj in objs:
            bbox = obj['bbox']
            # Đảm bảo bbox có định dạng [xmin, ymin, xmax, ymax] cho Faster R-CNN
            xmin, ymin, width, height = bbox
            xmax = xmin + width
            ymax = ymin + height
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(self.label2idx[obj['label']])

        # Chuyển đổi sang tensor
        boxes = torch.tensor(boxes, dtype=torch.float32)
        labels = torch.tensor(labels, dtype=torch.int64)
        
        # Tạo target dict
        target = {}
        target['boxes'] = boxes
        target['labels'] = labels
        target['image_id'] = torch.tensor([idx])

        if self.transform:
            image = self.transform(image)

        return image, target

# Transform ảnh
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Đọc annotations và tạo label map
print("Loading annotations...")
annotations = read_annotations(ANNOTATIONS_DIR, IMAGES_DIR)
print(annotations)

print("Creating label map...")
label2idx = create_label_map(annotations)
print(f"Label mapping: {label2idx}")

Loading annotations...
Loaded annotations for 7361 images
{'Bicycle\\2015_00001.png': [{'label': 'Bicycle', 'bbox': [271.0, 193.0, 271.0, 193.0], 'img_width': 204.0, 'img_height': 28.0}], 'Bicycle\\2015_00002.png': [{'label': 'Bicycle', 'bbox': [79.0, 109.0, 79.0, 109.0], 'img_width': 136.0, 'img_height': 190.0}, {'label': 'Bicycle', 'bbox': [63.0, 131.0, 63.0, 131.0], 'img_width': 219.0, 'img_height': 172.0}, {'label': 'Bicycle', 'bbox': [76.0, 124.0, 76.0, 124.0], 'img_width': 277.0, 'img_height': 188.0}, {'label': 'Bicycle', 'bbox': [57.0, 81.0, 57.0, 81.0], 'img_width': 348.0, 'img_height': 183.0}, {'label': 'Car', 'bbox': [33.0, 26.0, 33.0, 26.0], 'img_width': 316.0, 'img_height': 171.0}, {'label': 'Car', 'bbox': [34.0, 24.0, 34.0, 24.0], 'img_width': 395.0, 'img_height': 175.0}], 'Bicycle\\2015_00003.png': [{'label': 'Bicycle', 'bbox': [211.0, 246.0, 211.0, 246.0], 'img_width': 287.0, 'img_height': 101.0}, {'label': 'Bus', 'bbox': [35.0, 21.0, 35.0, 21.0], 'img_width': 3.0, 'img_

In [5]:
# Tạo dataset
print("Creating dataset...")
dataset = ExDarkDataset(
    img_dir=IMAGES_DIR,
    annotations=annotations,
    label2idx=label2idx,
    transform=train_transform
)
print(f"Dataset size: {len(dataset)}")
try:
    sample_img, sample_target = dataset[0]
    print(f"Sample image shape: {sample_img.shape}")
    print(f"Sample target: {sample_target}")
except Exception as e:
    print(f"Error loading sample: {e}")

# Chia dataset train/val/test
def split_dataset(dataset, train_ratio=0.7, val_ratio=0.2):
    total = len(dataset)
    train_len = int(total * train_ratio)
    val_len = int(total * val_ratio)
    test_len = total - train_len - val_len
    return random_split(dataset, [train_len, val_len, test_len])

train_set, val_set, test_set = split_dataset(dataset)
print(f"Train: {len(train_set)}, Val: {len(val_set)}, Test: {len(test_set)}")

# Custom collate function cho object detection
def collate_fn(batch):
    return tuple(zip(*batch))

train_loader = DataLoader(train_set, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_set, batch_size=4, shuffle=False, collate_fn=collate_fn)

Creating dataset...
Dataset initialized with 7361 images
Sample image paths: ['Bicycle\\2015_00001.png', 'Bicycle\\2015_00002.png', 'Bicycle\\2015_00003.png', 'Bicycle\\2015_00004.jpg', 'Bicycle\\2015_00005.jpg', 'Bicycle\\2015_00006.jpg', 'Bicycle\\2015_00007.jpg', 'Bicycle\\2015_00008.jpg', 'Bicycle\\2015_00009.jpg', 'Bicycle\\2015_00010.jpg']
Dataset size: 7361
Sample image shape: torch.Size([3, 224, 224])
Sample target: {'boxes': tensor([[271., 193., 542., 386.]]), 'labels': tensor([0]), 'image_id': tensor([0])}
Train: 5152, Val: 1472, Test: 737


In [6]:
# Sử dụng model Faster R-CNN cho object detection
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

def get_detection_model(num_classes):
    model = fasterrcnn_resnet50_fpn(pretrained=True)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model

num_classes = len(label2idx) + 1  # +1 cho background
print(f"Number of classes: {num_classes}")
        
model = get_detection_model(num_classes)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
model.to(device)

# Hàm train model
def train_model(model, train_loader, val_loader, epochs=5, lr=1e-4):
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=lr)
            
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        num_batches = 0
                
        print(f"\nEpoch {epoch+1}/{epochs}")
        loop = tqdm(train_loader, desc=f"Training")
            
        for images, targets in loop:
            try:
                images = list(img.to(device) for img in images)
                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

                loss_dict = model(images, targets)
                losses = sum(loss for loss in loss_dict.values())

                optimizer.zero_grad()
                losses.backward()
                optimizer.step()

                train_loss += losses.item()
                num_batches += 1
                loop.set_postfix(loss=train_loss / num_batches)
                        
            except Exception as e:
                print(f"Error in batch: {e}")
                continue
                
        if num_batches > 0:
            avg_loss = train_loss / num_batches
            print(f"Epoch {epoch+1} average loss: {avg_loss:.4f}")

# Lưu model
def save_model(model, path='models/object_detection_model.pth'):
    torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")

print("Starting training...")
train_model(model, train_loader, val_loader, epochs=2)
save_model(model)

Number of classes: 13




Using device: cuda
Starting training...

Epoch 1/2


Training:   0%|          | 4/1288 [00:09<49:09,  2.30s/it, loss=33.7]


KeyboardInterrupt: 

In [None]:
# Hàm evaluate model
def evaluate(net: nn.Module, data: DataLoader) -> float:
    """
    Evaluates the neural network on the given data.
    @param net: the neural network to evaluate
    @param data: the data to evaluate on
    @return: the accuracy of the neural network on the given data
    """
    net.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for images, targets in data:
            images = list(img.to(device) for img in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            outputs = net(images)
            
            for i, output in enumerate(outputs):
                pred_boxes = output['boxes']
                pred_labels = output['labels']
                target_boxes = targets[i]['boxes']
                target_labels = targets[i]['labels']
                
                # So sánh số lượng bounding boxes
                if len(pred_boxes) == len(target_boxes):
                    total += 1
                    if (pred_labels == target_labels).all():
                        correct += 1
    accuracy = correct / total if total > 0 else 0
    print(f"Evaluation accuracy: {accuracy:.4f}")
    return accuracy

# Đánh giá model trên tập validation
train_loader = DataLoader(train_set, batch_size=4, shuffle=False, collate_fn=collate_fn)
print(train_loader)

In [None]:
print("Evaluating model on validation set...")
evaluate(model, train_loader)