In [None]:
import copy
import csv
import os
import warnings
from argparse import ArgumentParser

import torch
import tqdm
import yaml
from torch.utils import data
# glob/json are used below to build the training data list from the individual JSON label files
import glob
import json
import os
from nets import nn
from utils import util
from utils.dataset import Dataset
from torch.utils import data
import numpy as np
import cv2
import random
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import matplotlib.patches as patches
# Prefer the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print("device:", device)

In [None]:

# Hyper-parameters and data locations.
with open('utils/args.yaml', encoding='utf-8', errors='ignore') as f:
    params = yaml.safe_load(f)


label_dir = '../../data/IGNITE/annotations/pdl1/individual/'
image_dir = '../../data/IGNITE/images/pdl1/pdl1/'

# Build parallel lists from the per-image JSON label files:
#   filenames[k] -> path of the image
#   labels[k]    -> [[category_id, bbox[0], bbox[1], bbox[2], bbox[3]], ...]
#                   (bbox values truncated to int)
# Label files whose image does not exist are skipped and reported afterwards.
label_files = sorted(glob.glob(os.path.join(label_dir, '*.json')))
filenames = []
labels = []
missing_images = []
for label_file in label_files:
    # JSON is UTF-8 by specification; don't depend on the platform default encoding.
    with open(label_file, encoding='utf-8') as f:
        label_json = json.load(f)
    img_path = os.path.join(image_dir, label_json['image']['file_name'])
    if not os.path.exists(img_path):
        missing_images.append(img_path)
        continue
    filenames.append(img_path)
    image_labels = []
    for ann in label_json['annotations']:
        bbox = ann['bbox']
        image_labels.append([ann['category_id'],
                             int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])])
    labels.append(image_labels)

if missing_images:
    warnings.warn(f'{len(missing_images)} label file(s) skipped because the image was not found, '
                  f'e.g. {missing_images[0]}')


In [None]:
class custom_dataset(data.Dataset):
    """Detection dataset yielding fixed-size crops with YOLO-normalised boxes.

    Each label entry is ``[class_id, b0, b1, b2, b3]`` in pixels of the original
    image. ``b0``/``b2`` run along the image columns and ``b1``/``b3`` along the
    rows; that is how ``__getitem__`` pairs them with the crop offsets from
    ``load_image``. This is presumably COCO-style ``[x, y, w, h]``; confirm against
    the annotation source.

    ``__getitem__`` returns ``(image, cls, box, idx)``:
      * image: CHW tensor of size ``input_size`` x ``input_size``.
      * cls: ``(n,)`` int64 class ids.
      * box: ``(n, 4)`` float32 ``(x_center, y_center, w, h)`` normalised to [0, 1].
      * idx: ``zeros(n)``. ``collate_fn``-style code offsets these by batch position.
    """

    def __init__(self, filenames, input_size, params, augment, labels=None, image_infos=None):
        """
        Args:
            filenames: image paths.
            input_size: side length of the square crop fed to the model.
            params: hyper-parameter dict; ``flip_ud``/``flip_lr`` probabilities are read.
            augment: enable random flips.
            labels: per-image label lists, parallel to ``filenames``.
            image_infos: optional per-image metadata, parallel to ``filenames``.

        Raises:
            ValueError: ``labels`` and ``filenames`` differ in length.
            NotImplementedError: ``labels`` is None and no ``load_label`` method exists.
        """
        self.params = params
        self.mosaic = augment  # kept for compatibility; not used by __getitem__
        self.augment = augment
        self.input_size = input_size
        if labels is not None:
            if len(labels) != len(filenames):
                raise ValueError(f'filenames and labels differ in length '
                                 f'({len(filenames)} vs {len(labels)})')
            self.labels = labels
            self.filenames = filenames
            self.image_infos = image_infos if image_infos is not None else [None] * len(filenames)
        else:
            # This class defines no load_label(); only a subclass could provide one.
            load_label = getattr(self, 'load_label', None)
            if load_label is None:
                raise NotImplementedError(
                    'custom_dataset needs explicit `labels`: no load_label() is defined')
            loaded = load_label(filenames)
            self.labels = list(loaded.values())
            self.filenames = list(loaded.keys())
            self.image_infos = [None] * len(self.filenames)
        self.n = len(self.filenames)
        self.indices = range(self.n)
        # NOTE(review): constructed but never applied in __getitem__.
        self.albumentations = Albumentations()

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        index = self.indices[index]
        image, (crop_row, crop_col) = self.load_image(index)
        size = self.input_size

        cls, box = [], []
        for entry in self.labels[index]:
            # entry = [class_id, b0, b1, b2, b3]: b0/b2 along columns, b1/b3 along rows.
            class_id, left, top, width, height = entry[0], entry[1], entry[2], entry[3], entry[4]
            # Keep a box only if its top-left corner lies inside the crop window and
            # at least 5 px away from the right/bottom edge.
            if not (crop_col <= left <= crop_col + size - 5
                    and crop_row <= top <= crop_row + size - 5):
                continue
            x0 = left - crop_col
            y0 = top - crop_row
            # Clip to the crop window so the normalised coordinates stay within [0, 1].
            width = min(width, size - x0)
            height = min(height, size - y0)
            cls.append(class_id)
            box.append([(x0 + width / 2) / size, (y0 + height / 2) / size,
                        width / size, height / size])

        cls = np.array(cls, dtype=np.int64)
        box = np.array(box, dtype=np.float32).reshape(-1, 4)  # (0, 4) when nothing survives
        nl = len(box)
        if self.augment:
            # Flip up-down: mirrors the row axis, i.e. y_center.
            if random.random() < self.params['flip_ud']:
                image = np.flipud(image)
                box[:, 1] = 1 - box[:, 1]
            # Flip left-right: mirrors the column axis, i.e. x_center.
            if random.random() < self.params['flip_lr']:
                image = np.fliplr(image)
                box[:, 0] = 1 - box[:, 0]

        # HWC -> CHW. The contiguous copy also removes the negative strides from flips,
        # which torch.from_numpy cannot handle.
        image = np.ascontiguousarray(image.transpose((2, 0, 1)))
        return torch.from_numpy(image), torch.from_numpy(cls), torch.from_numpy(box), torch.zeros(nl)

    def load_image(self, i):
        """Read image ``i`` as RGB and crop/pad it to ``input_size`` x ``input_size``.

        Returns:
            ``(image, (row_offset, col_offset))``: the crop window's offsets in the
            original image (0 along an axis that was padded instead of cropped).

        Raises:
            FileNotFoundError: OpenCV could not read the file.
        """
        path = self.filenames[i]
        image = cv2.imread(path)
        if image is None:  # cv2.imread reports failure by returning None
            raise FileNotFoundError(f'could not read image: {path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # BGR -> RGB
        return self._crop_or_pad(image)

    def _crop_or_pad(self, image):
        """Randomly crop axes longer than ``input_size`` and zero-pad shorter ones.

        Each axis is handled independently. An image larger than ``input_size`` on one
        axis and smaller on the other (e.g. 400x600 with input_size=512) is therefore
        cropped along the long axis and padded along the short one.
        """
        size = self.input_size
        h, w = image.shape[:2]
        row = random.randint(0, h - size) if h > size else 0
        col = random.randint(0, w - size) if w > size else 0
        crop = image[row:row + size, col:col + size]
        out = np.zeros((size, size) + image.shape[2:], dtype=image.dtype)
        out[:crop.shape[0], :crop.shape[1]] = crop
        return out, (row, col)


    
    
class Albumentations:
    """Optional, low-probability photometric augmentation via ``albumentations``.

    If the package cannot be imported, ``transform`` remains ``None``. Calling the
    object then returns its inputs untouched.
    """

    def __init__(self):
        self.transform = None
        try:
            import albumentations as album

            pipeline = [
                album.Blur(p=0.01),
                album.CLAHE(p=0.01),
                album.ToGray(p=0.01),
                album.MedianBlur(p=0.01),
            ]
            bbox_spec = album.BboxParams(format='yolo', label_fields=['class_labels'])
            self.transform = album.Compose(pipeline, bbox_params=bbox_spec)
        except ImportError:  # package not installed, skip
            self.transform = None

    def __call__(self, image, box, cls):
        """Apply the pipeline to ``image``, YOLO-format ``box`` and their ``cls`` labels."""
        if not self.transform:
            return image, box, cls
        result = self.transform(image=image, bboxes=box, class_labels=cls)
        return result['image'], np.array(result['bboxes']), np.array(result['class_labels'])

# Deterministic train/validation split over the (sorted) file list.
split = [0.8, 0.2]
n_train = int(len(filenames) * split[0])
train_dataset = custom_dataset(filenames[:n_train], 512, params, augment=True,
                               labels=labels[:n_train])
val_dataset = custom_dataset(filenames[n_train:], 512, params, augment=False,
                             labels=labels[n_train:])

In [None]:
def collate_fn1(batch):
    """Collate ``(image, cls, box, idx)`` samples into a single training batch.

    Args:
        batch: sequence of per-sample tuples. ``idx`` has one entry per box
            (``custom_dataset`` passes zeros).

    Returns:
        ``(images, targets)``:
          * images: the sample images stacked along a new leading batch dimension.
          * targets: dict with the concatenated ``cls`` and ``box`` tensors, plus
            ``idx``, the batch position of the image each box belongs to.
    """
    samples, cls, box, indices = zip(*batch)

    cls = torch.cat(cls, dim=0)
    box = torch.cat(box, dim=0)
    # Offset each sample's idx by its batch position. This is done out of place on
    # purpose: an in-place `+=` would modify the samples' own tensors, and offsets
    # would accumulate if a dataset ever reused or cached them.
    indices = torch.cat([idx + i for i, idx in enumerate(indices)], dim=0)

    targets = {'cls': cls,
               'box': box,
               'idx': indices}
    return torch.stack(samples, dim=0), targets


# Build the model, optimiser and loss function.
# NOTE(review): the optimiser runs at params['min_lr'] and this notebook attaches no
# LR scheduler, so the learning rate stays at that value; confirm this is intended.
model = nn.yolo_v11_m(len(params['names'])).to(device)
optimizer = torch.optim.SGD(
    util.set_params(model, params['weight_decay']),
    lr=params['min_lr'],
    momentum=params['momentum'],
    nesterov=True,
)
criterion = util.ComputeLoss(model, params)

# The data loaders are created in the next cell.


In [None]:

# Data loaders. NOTE(review): num_workers=4 with a dataset class defined in the
# notebook works with fork-based workers (Linux). Spawn-based platforms may fail
# to pickle it; confirm for the target environment.
loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=8, shuffle=True, num_workers=4, collate_fn=collate_fn1
)

val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=8, shuffle=False, num_workers=4, collate_fn=collate_fn1
)
train_losses = []
val_losses = []
epochs = 200

# fp16 autocast is only enabled on CUDA. Loss scaling keeps small fp16 gradients
# from underflowing to zero, which plain autocast + backward() does not guard against.
use_amp = device.type == 'cuda'
scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

for epoch in range(epochs):
    # Training
    model.train()
    epoch_loss = 0.0
    for images, targets in loader:
        images = images.to(device).float() / 255
        optimizer.zero_grad()
        with torch.amp.autocast(device.type, enabled=use_amp):
            outputs = model(images)
            loss_box, loss_cls, loss_dfl = criterion(outputs, targets)
        loss = loss_box + loss_cls + loss_dfl
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
    train_losses.append(epoch_loss / max(len(loader), 1))
    print(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_losses[-1]:.4f}")

    # Validation
    # NOTE(review): criterion is fed eval()-mode outputs here. Confirm the model
    # returns the same output structure in eval() mode as in train() mode.
    model.eval()
    val_epoch_loss = 0.0
    with torch.no_grad():
        for val_images, val_targets in val_loader:
            val_images = val_images.to(device).float() / 255
            with torch.amp.autocast(device.type, enabled=use_amp):
                val_outputs = model(val_images)
                val_loss_box, val_loss_cls, val_loss_dfl = criterion(val_outputs, val_targets)
            val_loss = val_loss_box + val_loss_cls + val_loss_dfl
            val_epoch_loss += val_loss.item()
    val_losses.append(val_epoch_loss / max(len(val_loader), 1))
    print(f"Epoch {epoch+1}/{epochs} - Val Loss: {val_losses[-1]:.4f}")

    # Ground-truth overlay for the first validation sample. The dataset returns
    # (image, cls, box, idx); box rows are normalised (x_center, y_center, w, h).
    val_img, _, val_box, _ = val_dataset[0]
    img_h, img_w = val_img.shape[1], val_img.shape[2]
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(val_img.permute(1, 2, 0).cpu().numpy())
    for x_center, y_center, w, h in val_box.tolist():
        rect = patches.Rectangle(
            ((x_center - w / 2) * img_w, (y_center - h / 2) * img_h),
            w * img_w, h * img_h,
            linewidth=2, edgecolor='r', facecolor='none',
        )
        ax.add_patch(rect)
    ax.set_title(f'Epoch {epoch+1} 검증 이미지 오버랩')
    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per epoch
