# 1. Import Libraries

In [58]:
import numpy as np
import os 
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import random
import glob
from tqdm import tqdm
import json
import datetime
from PIL import Image
from skimage import measure
from pycocotools.coco import COCO
import albumentations as A
from albumentations.pytorch import ToTensorV2


import cv2 
import torch 
from torch.utils.data import Dataset, DataLoader 
from torchvision import transforms 

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def set_seed(seed):
    """Seed every random generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch's
    default generator, and, when CUDA is available, the generators of all
    CUDA devices.

    Args:
        seed: integer seed shared by all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

set_seed(2002)


In [3]:
print("Using device:", device)

Using device: cpu


In [4]:
PATH = "../../public/data"

# Per-class image folders (names kept as-is because later cells use them).
PATH_GOOD, PATH_OIL, SCRATCH_PATH, STAIN_PATH = (
    os.path.join(PATH, folder) for folder in ("good", "oil", "scratch", "stain")
)

# Ground-truth mask folders.
GROUND_TRUTH_PATH1, GROUND_TRUTH_PATH2 = (
    os.path.join(PATH, f"ground_truth_{n}") for n in (1, 2)
)

## 1.1 Utils

In [5]:
def read_image(image_path):
    """Load an image from disk as an RGB array.

    OpenCV decodes images in BGR channel order; the result is converted to
    RGB so it can be passed straight to ``plt.imshow``.

    Args:
        image_path: path of the image file to read.

    Returns:
        The decoded image with its channels in RGB order.

    Raises:
        OSError: if OpenCV cannot read or decode the file. ``cv2.imread``
            reports this by returning ``None`` instead of raising, which would
            otherwise surface later as an opaque ``cv2.error`` from
            ``cv2.cvtColor``.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise OSError(f"Could not read image (missing or unsupported format): {image_path}")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

def show_images(images, titles=None, cols=5, figsize=(15, 10)):
    """Display ``images`` in a grid of ``cols`` columns, axes hidden.

    Args:
        images: sequence of images accepted by ``plt.imshow``.
        titles: optional per-image titles, in the same order as ``images``;
            ignored when falsy.
        cols: number of grid columns; the row count is the smallest number
            of rows that fits every image.
        figsize: figure size in inches.
    """
    n_rows = (len(images) + cols - 1) // cols  # ceiling division
    plt.figure(figsize=figsize)
    for position, picture in enumerate(images, start=1):
        plt.subplot(n_rows, cols, position)
        plt.imshow(picture)
        if titles:
            plt.title(titles[position - 1])
        plt.axis('off')
    plt.tight_layout()
    plt.show()
    
# show_images([read_image(os.path.join(PATH_GOOD, img)) for img in os.listdir(PATH_GOOD)[:10]], titles=os.listdir(PATH_GOOD)[:10])

# 2. Data pipeline

In [6]:
# glob returns files in filesystem directory order, which is arbitrary; sort so
# the list (and anything derived from it) is stable across runs and machines.
# NOTE(review): this pattern also matches the mask PNGs inside ground_truth_*.
img_paths = sorted(glob.glob(os.path.join(PATH, "*", "*.png")))
    

In [8]:
def _read_image_size(image_path):
    """Return ``(width, height)`` of the image at ``image_path``.

    Returns ``None`` (after printing a warning) when PIL raises ``IOError``
    while opening it, so callers can skip non-image or corrupt files.
    """
    try:
        with Image.open(image_path) as img:
            return img.size
    except IOError:
        print(f"Warning: Could not read image {image_path}. Skipping.")
        return None


def _find_mask_path(root_dir, ground_truth_folders, image_filename):
    """Return the path of the mask for ``image_filename``, or ``None``.

    The mask must share the image's base name, with a ``.png`` extension.
    ``ground_truth_folders`` are searched in order and the first existing
    file wins.
    """
    mask_filename = os.path.splitext(image_filename)[0] + '.png'
    for gt_folder in ground_truth_folders:
        candidate = os.path.join(root_dir, gt_folder, mask_filename)
        if os.path.exists(candidate):
            return candidate
    return None


def _polygon_area(xs, ys):
    """Area enclosed by the polygon with vertices ``(xs[i], ys[i])`` (shoelace formula)."""
    return 0.5 * abs(np.dot(xs, np.roll(ys, -1)) - np.dot(ys, np.roll(xs, -1)))


def _mask_to_annotations(mask_path, image_id, category_id, first_annotation_id):
    """Convert one mask file into a list of COCO polygon annotations.

    Every contour found in the mask becomes its own annotation. Contours with
    fewer than 3 points are dropped. Annotation ids are assigned
    consecutively, starting at ``first_annotation_id``.
    """
    # The context manager closes the mask file; the converted copy stays in memory.
    with Image.open(mask_path) as mask_image:
        mask_np = np.array(mask_image.convert('L'))

    annotations = []
    # Level 0.5 on the 0-255 grayscale mask traces the boundary between zero
    # and non-zero pixels.
    for contour in measure.find_contours(mask_np, 0.5):
        contour = np.flip(contour, axis=1)  # (row, col) -> (x, y)
        segmentation = contour.ravel().tolist()  # [x1, y1, x2, y2, ...]
        if len(segmentation) < 6:  # a polygon needs at least 3 points
            continue

        xs, ys = contour[:, 0], contour[:, 1]
        x_min, x_max = np.min(xs), np.max(xs)
        y_min, y_max = np.min(ys), np.max(ys)
        annotations.append({
            "id": first_annotation_id + len(annotations),
            "image_id": image_id,
            "category_id": category_id,
            "segmentation": [segmentation],  # COCO expects a list of polygons
            # COCO's "area" is the area of the segment, not of its bounding
            # box; for thin defects such as scratches the two differ a lot.
            "area": float(_polygon_area(xs, ys)),
            "bbox": [int(x_min), int(y_min), int(x_max - x_min), int(y_max - y_min)],
            "iscrowd": 0
        })
    return annotations


def create_coco_annotations(root_dir, output_file):
    """Scan the dataset folders and write a COCO-format annotation file.

    Images in the ``scratch``, ``stain`` and ``oil`` sub-folders of
    ``root_dir`` get one polygon annotation per contour of their mask. The
    mask is looked up by file name in ``ground_truth_1``, then
    ``ground_truth_2``. Images in ``good`` are registered without
    annotations. Folders are processed in that order and files are visited
    in sorted order, so image and annotation ids are reproducible across
    runs and machines.

    Args:
        root_dir: dataset root containing the image and ground-truth folders.
        output_file: path of the JSON file to write.
    """
    # --- 1. COCO file skeleton ---
    today = datetime.date.today()
    info = {
        "description": "Phone Defect Dataset",
        "url": "",
        "version": "1.0",
        "year": today.year,
        "contributor": "Your Name",
        "date_created": today.isoformat()
    }

    licenses = [{"url": "", "id": 0, "name": "License"}]

    # Defect classes; id 0 is left free for the background.
    categories = [
        {"id": 1, "name": "scratch", "supercategory": "defect"},
        {"id": 2, "name": "stain", "supercategory": "defect"},
        {"id": 3, "name": "oil", "supercategory": "defect"}
    ]
    category_map = {cat['name']: cat['id'] for cat in categories}

    coco_output = {
        "info": info,
        "licenses": licenses,
        "categories": categories,
        "images": [],
        "annotations": []
    }

    image_id_counter = 1
    annotation_id_counter = 1

    # --- 2. Defect folders: register each image and convert its mask ---
    defect_folders = ["scratch", "stain", "oil"]
    ground_truth_folders = ["ground_truth_1", "ground_truth_2"]

    for category_name in defect_folders:
        category_id = category_map[category_name]
        image_folder = os.path.join(root_dir, category_name)
        if not os.path.isdir(image_folder):
            continue

        print(f"Processing folder: {category_name}")
        for image_filename in tqdm(sorted(os.listdir(image_folder))):
            size = _read_image_size(os.path.join(image_folder, image_filename))
            if size is None:
                continue
            width, height = size

            coco_output["images"].append({
                "id": image_id_counter,
                "file_name": os.path.join(category_name, image_filename),
                "width": width,
                "height": height
            })

            # NOTE(review): masks are matched by file name only, regardless of
            # category. If two defect folders contain the same file name they
            # will share one mask; confirm file names are unique across folders.
            mask_path = _find_mask_path(root_dir, ground_truth_folders, image_filename)
            if mask_path:
                new_annotations = _mask_to_annotations(
                    mask_path, image_id_counter, category_id, annotation_id_counter
                )
                coco_output["annotations"].extend(new_annotations)
                annotation_id_counter += len(new_annotations)

            image_id_counter += 1

    # --- 3. "good" folder: register images only, no annotations ---
    good_folder = os.path.join(root_dir, "good")
    if os.path.isdir(good_folder):
        print("Processing folder: good")
        for image_filename in tqdm(sorted(os.listdir(good_folder))):
            size = _read_image_size(os.path.join(good_folder, image_filename))
            if size is None:
                continue
            width, height = size

            coco_output["images"].append({
                "id": image_id_counter,
                "file_name": os.path.join("good", image_filename),
                "width": width,
                "height": height
            })
            image_id_counter += 1

    # --- 4. Write the JSON file ---
    with open(output_file, 'w') as f:
        json.dump(coco_output, f, indent=4)

    print(f"\nSuccessfully created COCO annotation file at: {output_file}")

# Dataset root, and the COCO JSON output file written at that root.
dataset_root_directory, output_json_file = PATH, os.path.join(PATH, "annotations.json")

# create_coco_annotations(dataset_root_directory, output_json_file)

In [59]:
def train_transform():
    """Build the Albumentations pipeline applied to training samples.

    The pipeline resizes to 512x512 and applies light augmentation: a random
    horizontal flip, a rotation of up to 20 degrees, and a brightness/contrast
    jitter. It then normalizes with a fixed per-channel mean/std (the values
    commonly used with ImageNet-pretrained backbones) and converts the result
    to a tensor.
    """
    resize = A.Resize(height=512, width=512)
    augmentations = [
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=20, p=0.5),
        A.RandomBrightnessContrast(p=0.2),
    ]
    finalize = [
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
    return A.Compose([resize, *augmentations, *finalize])

def valid_transform():
    """Build the deterministic Albumentations pipeline for validation samples.

    The pipeline resizes to 512x512, normalizes with the same per-channel
    mean/std as the training pipeline, and converts to a tensor. No random
    augmentation is applied.
    """
    steps = [A.Resize(height=512, width=512)]
    steps.append(A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)))
    steps.append(ToTensorV2())
    return A.Compose(steps)

def train_valid_split(X, test_size = 0.2, shuffle=False, seed=None):
    """Split a sequence into a training part and a validation part.

    The last ``int(len(X) * test_size)`` items, taken after an optional
    seeded shuffle, form the validation part; the rest form the training part.

    Args:
        X: sliceable sequence (e.g. a list of image ids).
        test_size: fraction of items, in ``[0, 1]``, put in the validation part.
        shuffle: if True, shuffle a copy of ``X`` before splitting so the split
            does not simply follow the input order. ``X`` itself is not modified.
        seed: seed for the shuffle's private RNG; ignored when ``shuffle`` is False.

    Returns:
        ``(X_train, X_valid)``.

    Raises:
        ValueError: if ``test_size`` is outside ``[0, 1]``.
    """
    if not 0 <= test_size <= 1:
        raise ValueError(f"test_size must be in [0, 1], got {test_size}")
    if shuffle:
        X = list(X)
        random.Random(seed).shuffle(X)
    n_valid = int(len(X) * test_size)
    # Slice at an explicit index. With n_valid == 0, the form X[:-0] / X[-0:]
    # would yield an empty train set and put every item in validation.
    split_at = len(X) - n_valid
    return X[:split_at], X[split_at:]

def collate_fn(batch):
    """Transpose a batch of samples into one tuple per field.

    ``[(img1, mask1), (img2, mask2)]`` becomes ``((img1, img2), (mask1, mask2))``.
    Nothing is stacked, so samples of different sizes can share a batch.
    """
    return tuple(tuple(field) for field in zip(*batch))
    

In [68]:
class Img_Segmentation_Dataset(Dataset):
    """COCO-backed semantic-segmentation dataset.

    Each item is an ``(image, mask)`` pair. In the mask, each pixel holds the
    ``category_id`` of the annotation covering it, or 0 where no annotation
    does. When annotations overlap, the one processed last wins.
    """

    def __init__(self, root_dir, annotation_file, transforms=None, image_ids=None):
        """
        Args:
            root_dir: directory that the COCO images' ``file_name`` entries
                are relative to.
            annotation_file: path to a COCO annotation JSON file, or an
                already-loaded ``COCO`` object (anything exposing ``loadAnns``),
                which avoids re-parsing the same file for every split.
            transforms: optional callable invoked as
                ``transforms(image=..., mask=...)`` that returns a dict with
                ``'image'`` and ``'mask'`` keys (e.g. an Albumentations pipeline).
            image_ids: optional subset of image ids to expose. ``None`` means
                every image in the annotation file; an empty sequence means an
                empty dataset.
        """
        self.root_dir = root_dir
        self.transforms = transforms

        if hasattr(annotation_file, "loadAnns"):
            self.coco = annotation_file
        else:
            self.coco = COCO(annotation_file)

        # Test against None, not truthiness: an empty id list must give an
        # empty dataset, not silently fall back to every image in the file.
        if image_ids is None:
            self.ids = sorted(self.coco.imgs.keys())
        else:
            self.ids = list(image_ids)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair for the ``idx``-th image id.

        Without ``transforms``, ``image`` is an RGB NumPy array and ``mask``
        is a uint8 array sized from the COCO image record's height/width.
        With ``transforms``, both are whatever the pipeline returns.
        """
        img_id = self.ids[idx]
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.root_dir, img_info['file_name'])

        # The context manager closes the file handle; a DataLoader iterating
        # many epochs would otherwise rely on garbage collection to close it.
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))

        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        annotations = self.coco.loadAnns(ann_ids)

        mask = np.zeros((img_info["height"], img_info["width"]), dtype=np.uint8)
        for ann in annotations:
            mask[self.coco.annToMask(ann) == 1] = ann["category_id"]

        if self.transforms:
            augmented = self.transforms(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, mask

In [69]:
# Paths are defined before use and reused below instead of repeating literals.
DATASET_ROOT = "../../public/data/"  # root folder containing the image folders
ANNOTATION_FILE = "../../public/data/annotations.json"

coco = COCO(ANNOTATION_FILE)
all_image_ids = list(sorted(coco.imgs.keys()))

# Shuffle with a fixed seed before splitting. If annotations.json was produced by
# create_coco_annotations above, ids are assigned folder by folder (scratch,
# stain, oil, then good). Taking the last 20% of the sorted ids would then fill
# the validation set mostly or only with defect-free "good" images.
SPLIT_SEED = 2002
shuffled_ids = list(all_image_ids)
random.Random(SPLIT_SEED).shuffle(shuffled_ids)

# Split the shuffled ids into train and validation.
train_ids, valid_ids = train_valid_split(shuffled_ids, test_size=0.2)

# Build one Dataset per split from its id list.
train_dataset = Img_Segmentation_Dataset(DATASET_ROOT, ANNOTATION_FILE, transforms=train_transform(), image_ids=train_ids)
valid_dataset = Img_Segmentation_Dataset(DATASET_ROOT, ANNOTATION_FILE, transforms=valid_transform(), image_ids=valid_ids)

# DataLoaders: shuffle only the training split.
train_data_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_data_loader = DataLoader(valid_dataset, batch_size=8, shuffle=False)

loading annotations into memory...
Done (t=0.49s)
creating index...
index created!
loading annotations into memory...
Done (t=0.46s)
creating index...
index created!
loading annotations into memory...
Done (t=0.47s)
creating index...
index created!
