In [35]:
import os
import json
import torch
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.path import Path
from torchvision.transforms import v2 as T
import numpy as np

In [36]:

class CustomDataset:
    """Map-style dataset pairing each image under a folder with its JSON annotation.

    Images and JSON files are discovered recursively and matched by file name
    without extension. Indexing returns ``(image, target)`` where ``target`` is
    the dictionary built by :meth:`create_detr_target`.

    Args:
        folder_path: Root directory that is searched recursively.
        transforms: Optional callable applied to the RGB PIL image only; the
            target is not passed through it.
    """

    # Image extensions accepted during discovery (compared case-insensitively).
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    # Maps the Korean category names found in the JSON to English labels.
    LABEL_MAPPING = {'화방': "buds", '줄기': "stem", '잎': "leaf", '열매': "fruit"}

    # Fixed image height used when normalizing boxes (original note: fixed
    # value used for the "rose" test).
    # NOTE(review): width is read from the JSON but height is not - confirm
    # that this is still intended.
    FIXED_HEIGHT = 1960

    def __init__(self, folder_path, transforms=None):
        self.folder_path = folder_path
        self.transforms = transforms
        self.data_pairs = self._load_data_pairs()

    def _load_data_pairs(self):
        """Walk ``folder_path`` and return a list of ``(image_path, json_path)``.

        Images that have no JSON file with the same base name are reported
        with a printed message and skipped.
        """
        image_files = []
        json_files = {}
        data_pairs = []

        # Recursively collect image files and JSON files.
        for root, _, files in os.walk(self.folder_path):
            for file in files:
                lowered = file.lower()
                full_path = os.path.join(root, file)
                if lowered.endswith(self.IMAGE_EXTENSIONS):
                    image_files.append(full_path)
                elif lowered.endswith(".json"):
                    # Keyed by base name only: same-named JSON files in
                    # different sub-folders overwrite each other.
                    json_files[os.path.splitext(file)[0]] = full_path

        # os.walk order depends on the filesystem; sort so that an index
        # always refers to the same sample across runs and machines.
        image_files.sort()

        # Match every image with its JSON file, showing a progress bar.
        for image_file_path in tqdm(image_files, desc="Matching image and JSON files"):
            image_name = os.path.splitext(os.path.basename(image_file_path))[0]

            json_file_path = json_files.get(image_name)
            if json_file_path is not None:
                data_pairs.append((image_file_path, json_file_path))
            else:
                print(f"JSON file not found for image: {image_file_path}")

        return data_pairs

    def create_detr_target(self, json_data):
        """Build the target dictionary for one annotation file.

        Args:
            json_data: Parsed annotation JSON. The keys read here are
                ``'annotations'`` (each with ``'bbox'``, ``'category_id'``,
                ``'area'``, ``'isCrowd'``), ``'categories'`` (each with
                ``'id'`` and ``'name'``) and ``'images'[0]['width']``.

        Returns:
            dict with
              * ``'boxes'``: float32 tensor of shape ``(N, 4)`` holding
                ``[x_min, y_min, box_width, box_height]`` divided by the image
                width / height (``(0, 4)`` when there are no annotations),
              * ``'labels'``: list of N English label strings (``"unknown"``
                for categories missing from ``LABEL_MAPPING``),
              * ``'area'``: float32 tensor of N areas,
              * ``'iscrowd'``: int64 tensor of N crowd flags.

        Raises:
            KeyError: If an expected key is missing or an annotation refers to
                a category id that is not listed in ``'categories'``.
        """
        annotations = json_data['annotations']
        categories = {cat['id']: cat['name'] for cat in json_data['categories']}

        height = self.FIXED_HEIGHT
        width = float(json_data['images'][0]['width'])

        target = {
            'boxes': [],
            'labels': [],
            'area': [],
            'iscrowd': [],
        }

        for ann in annotations:
            bbox = ann['bbox']
            obj_name = categories[ann['category_id']]

            # Translate the Korean category name into its English label.
            label = self.LABEL_MAPPING.get(obj_name, "unknown")

            # Box conversion and normalization. The bbox indices are remapped
            # (bbox[1] -> x, bbox[0] + bbox[2] -> flipped y, bbox[3] -> width,
            # bbox[2] -> height).
            # NOTE(review): this looks like compensation for annotations
            # stored in a rotated frame - confirm against the annotation spec.
            x_min = bbox[1] / width
            y_min = (height - (bbox[0] + bbox[2])) / height
            box_width = bbox[3] / width
            box_height = bbox[2] / height

            target['boxes'].append([x_min, y_min, box_width, box_height])
            target['area'].append(ann['area'])
            target['iscrowd'].append(ann['isCrowd'])
            target['labels'].append(label)

        # Convert the numeric lists to tensors. The reshape keeps the boxes
        # two-dimensional, (0, 4), even when an image has no annotations.
        target['boxes'] = torch.tensor(target['boxes'], dtype=torch.float32).reshape(-1, 4)
        target['area'] = torch.tensor(target['area'], dtype=torch.float32)
        target['iscrowd'] = torch.tensor(target['iscrowd'], dtype=torch.int64)

        return target

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, target)``."""
        img_path, json_path = self.data_pairs[index]

        # Explicit UTF-8: the annotations contain Korean text and the platform
        # default encoding is not UTF-8 everywhere.
        with open(json_path, 'r', encoding='utf-8') as f:
            json_data = json.load(f)

        # convert() returns a fully loaded copy, so the file handle can be
        # released right away instead of waiting for garbage collection.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")

        if self.transforms is not None:
            img = self.transforms(img)

        target = self.create_detr_target(json_data)

        return img, target

    def __len__(self):
        """Return the number of matched image / JSON pairs."""
        return len(self.data_pairs)

In [37]:
# Dataset and transform definitions

# Seed for the train / test split so that the same samples are held out on
# every run (for a given dataset ordering).
SPLIT_SEED = 0
# Requested number of held-out samples.
NUM_TEST_SAMPLES = 100

# Shared preprocessing: PIL image -> tensor, then ImageNet mean / std scaling.
# NOTE(review): torchvision documents v2.ToTensor as deprecated, and the final
# ConvertImageDtype looks redundant after it - confirm against the installed
# torchvision version before changing the pipeline.
normalize = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        T.ConvertImageDtype(torch.float32)
    ])


folder_path = "./Dataset"
transform_train = T.Compose([
    normalize
])
transform_test =  T.Compose([
    normalize
])
dataset = CustomDataset(folder_path, transforms=transform_train)
dataset_test = CustomDataset(folder_path, transforms=transform_test)

# Train / test split: shuffle the indices once with a seeded generator and
# hold out the tail. The hold-out is capped at half of the data so that a small
# dataset never ends up with an empty training set.
indices = torch.randperm(
    len(dataset), generator=torch.Generator().manual_seed(SPLIT_SEED)
).tolist()
num_test = min(NUM_TEST_SAMPLES, len(indices) // 2)
# Slice with an explicit position: `indices[:-0]` would be empty.
split_at = len(indices) - num_test
dataset = torch.utils.data.Subset(dataset, indices[:split_at])
dataset_test = torch.utils.data.Subset(dataset_test, indices[split_at:])

def collate_fn(batch):
    """Regroup a batch of per-sample tuples into one tuple per field.

    A batch ``[(img0, tgt0), (img1, tgt1)]`` becomes
    ``((img0, img1), (tgt0, tgt1))``; nothing is stacked, so samples of
    different sizes can share a batch.
    """
    per_field = zip(*batch)
    return tuple(per_field)

# Training loader: one sample per batch, shuffled.
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=1, shuffle=True, collate_fn=collate_fn
)
# Test loader: one sample per batch, fixed order.
data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, shuffle=False, collate_fn=collate_fn
)

Matching image and JSON files: 100%|██████████| 411/411 [00:00<00:00, 206178.56it/s]

Dataset size: 411
target: {'boxes': tensor([1287.7600,  648.5200, 2196.7100, 1641.3600]), 'labels': tensor(3)}
Type of target: <class 'dict'>
img: tensor([[[0.8784, 0.9176, 0.9569,  ..., 0.1647, 0.1647, 0.1725],
         [0.8941, 0.9176, 0.9412,  ..., 0.1608, 0.1647, 0.1686],
         [0.9137, 0.9412, 0.9373,  ..., 0.1647, 0.1569, 0.1608],
         ...,
         [0.7176, 0.4078, 0.3490,  ..., 0.4863, 0.3882, 0.2863],
         [0.7176, 0.4118, 0.3490,  ..., 0.4863, 0.3843, 0.2824],
         [0.7176, 0.4196, 0.3490,  ..., 0.4863, 0.3804, 0.2824]],

        [[0.8902, 0.9294, 0.9686,  ..., 0.2588, 0.2667, 0.2706],
         [0.9098, 0.9294, 0.9529,  ..., 0.2588, 0.2667, 0.2706],
         [0.9294, 0.9529, 0.9451,  ..., 0.2627, 0.2588, 0.2627],
         ...,
         [0.7412, 0.4588, 0.3882,  ..., 0.4941, 0.4078, 0.3216],
         [0.7412, 0.4549, 0.3882,  ..., 0.4941, 0.4000, 0.3098],
         [0.7412, 0.4627, 0.3882,  ..., 0.4980, 0.3961, 0.3020]],

        [[0.9255, 0.9608, 0.9922,  ..., 0




In [38]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np

def plot_detection(img, target, mean=None, std=None):
    """Show an image with its normalized bounding boxes and labels drawn on top.

    Args:
        img: Channel-first image ``(C, H, W)`` as a torch tensor or NumPy
            array.
        target: Mapping with ``'boxes'`` (rows of ``[x_min, y_min, box_width,
            box_height]`` expressed as fractions of the image width / height)
            and ``'labels'`` (one label per box).
        mean: Optional per-channel mean. When given together with ``std`` the
            image is de-normalized (``img * std + mean``) before display.
        std: Optional per-channel standard deviation, see ``mean``.
    """
    fig, ax = plt.subplots(1)

    # Convert explicitly instead of relying on np.transpose coercing a tensor,
    # then move the channels last, which is the layout imshow expects.
    if torch.is_tensor(img):
        img = img.detach().cpu().numpy()
    img = np.transpose(np.asarray(img), (1, 2, 0))

    # Undo a mean / std normalization when the caller provides the statistics.
    if mean is not None and std is not None:
        img = (
            img.astype(np.float32) * np.asarray(std, dtype=np.float32)
            + np.asarray(mean, dtype=np.float32)
        )

    # imshow clips float RGB data to [0, 1] anyway; doing it here makes that
    # explicit and avoids the clipping warning.
    if np.issubdtype(img.dtype, np.floating):
        img = np.clip(img, 0.0, 1.0)

    ax.imshow(img)

    boxes = target['boxes']
    if torch.is_tensor(boxes):
        boxes = boxes.detach().cpu().numpy()
    labels = target['labels']

    # Image height and width, used to scale the normalized boxes to pixels.
    height, width = img.shape[0], img.shape[1]

    for box, label in zip(boxes, labels):
        x_min, y_min, box_width, box_height = box
        x_min = x_min * width
        y_min = y_min * height
        box_width = box_width * width
        box_height = box_height * height

        rect = patches.Rectangle((x_min, y_min), box_width, box_height, linewidth=1, edgecolor='r', facecolor='none')
        ax.add_patch(rect)

        # Draw on this figure's axes rather than on pyplot's "current" axes.
        ax.text(x_min, y_min, label, color='white', fontsize=12, bbox=dict(facecolor='red', alpha=0.5))

    plt.show()


Matching image and JSON files: 100%|██████████| 411/411 [00:00<00:00, 206129.25it/s]


In [39]:
# Inspect and visualize the first few dataset samples
NUM_PREVIEW = 3  # number of samples to show
for i in range(min(NUM_PREVIEW, len(dataset))):
    img, target = dataset[i]
    # Print a summary rather than the whole tensor to keep the output readable.
    print(
        f"Image {i}: shape={tuple(img.shape)}, dtype={img.dtype}, "
        f"range=[{float(img.min()):.3f}, {float(img.max()):.3f}]"
    )
    print(f"Target {i}:", target)
    plot_detection(img, target)