In [39]:
import os
import numpy as np
import pandas as pd

from PIL import Image, ImageFile

import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

# utils.ipynb 파일에서 해당 함수를 불러오기 위한 라이브러리
# jupyter notebook에서 가져올 수 있는 방법 알아보기.

'''
from utils import (
    cells_to_bboxes,
    iou_width_height as iou,
    non_max_suppression as nms,
    plot_image
)
'''

'\nfrom utils import (\n    cells_to_bboxes,\n    iou_width_height as iou,\n    non_max_suppression as nms,\n    plot_image\n)\n'

## 기본 변수 정의

In [24]:
ori_dir = './train/'
img_dir = ori_dir + 'image/'
label_dir = ori_dir + 'label/'

train_csv = ori_dir + 'train.csv'

img_size = 416
S = [13, 26, 52]
C = 4

# anchor list,ipynb 파일에서 값을 조정해주어야 함.
anchors = [
    [(0.1505, 0.1137), (0.2265, 0.4055), (0.2675, 0.4709)],
    [(0.3691, 0.6242), (0.1773, 0.3336), (0.2075, 0.0813)],
    [(0.3107, 0.5441), (0.1714, 0.1769), (0.1512, 0.0604)
]]

classes = [ "AC", "FL", "HC", "HUM" ]

In [3]:
# train.csv 파일 읽어오기
df = pd.read_csv(train_csv)

In [4]:
df.head()

Unnamed: 0,img_png,img_txt
0,20160307_E0000776_I0068889.png,20160307_E0000776_I0068889.txt
1,20160307_E0000776_I0068893.png,20160307_E0000776_I0068893.txt
2,20160307_E0000777_I0069006.png,20160307_E0000777_I0069006.txt
3,20160307_E0000777_I0069020.png,20160307_E0000777_I0069020.txt
4,20160307_E0000778_I0069096.png,20160307_E0000778_I0069096.txt


# cells_to_bboxes

In [42]:
"""
Scales the predictions coming from the model to
be relative to the entire image such that they for example later
can be plotted or.
INPUT:
predictions: tensor of size (N, 3, S, S, num_classes+5)
anchors: the anchors used for the predictions
S: the number of cells the image is divided in on the width (and height)
is_preds: whether the input is predictions or the true bounding boxes
OUTPUT:
converted_bboxes: the converted boxes of sizes (N, num_anchors, S, S, 1+5) with class index,
                  object score, bounding box coordinates
"""

def cells_to_bboxes(predictions, anchors, S, is_preds=True):

    BATCH_SIZE = predictions.shape[0]
    num_anchors = len(anchors)
    box_predictions = predictions[..., 1:5]
    
    if is_preds:
        anchors = anchors.reshape(1, len(anchors), 1, 1, 2)
        box_predictions[..., 0:2] = torch.sigmoid(box_predictions[..., 0:2])
        box_predictions[..., 2:] = torch.exp(box_predictions[..., 2:]) * anchors
        scores = torch.sigmoid(predictions[..., 0:1])
        best_class = torch.argmax(predictions[..., 5:], dim=-1).unsqueeze(-1)
    else:
        scores = predictions[..., 0:1]
        best_class = predictions[..., 5:6]

    cell_indices = (
        torch.arange(S)
        .repeat(predictions.shape[0], 3, S, 1)
        .unsqueeze(-1)
        .to(predictions.device)
    )
    x = 1 / S * (box_predictions[..., 0:1] + cell_indices)
    y = 1 / S * (box_predictions[..., 1:2] + cell_indices.permute(0, 1, 3, 2, 4))
    w_h = 1 / S * box_predictions[..., 2:4]
    converted_bboxes = torch.cat((best_class, scores, x, y, w_h), dim=-1).reshape(BATCH_SIZE, num_anchors * S * S, 6)
    return converted_bboxes.tolist()

# IoU 계산

In [43]:
def iou_width_height(boxes1, boxes2):
    """
    Parameters:
        boxes1 (tensor): width and height of the first bounding boxes
        boxes2 (tensor): width and height of the second bounding boxes
    Returns:
        tensor: Intersection over union of the corresponding boxes
    """
    intersection = torch.min(boxes1[..., 0], boxes2[..., 0]) * torch.min(
        boxes1[..., 1], boxes2[..., 1]
    )
    union = (
        boxes1[..., 0] * boxes1[..., 1] + boxes2[..., 0] * boxes2[..., 1] - intersection
    )
    return intersection / union


# non_max_suppression

In [44]:
def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
    """
    Does Non Max Suppression given bboxes

    Parameters:
        bboxes (list): list of lists containing all bboxes with each bboxes
        specified as [class_pred, prob_score, x1, y1, x2, y2]
        iou_threshold (float): threshold where predicted bboxes is correct
        threshold (float): threshold to remove predicted bboxes (independent of IoU)
        box_format (str): "midpoint" or "corners" used to specify bboxes

    Returns:
        list: bboxes after performing NMS given a specific IoU threshold
    """

    assert type(bboxes) == list

    bboxes = [box for box in bboxes if box[1] > threshold]
    bboxes = sorted(bboxes, key=lambda x: x[1], reverse=True)
    bboxes_after_nms = []

    while bboxes:
        chosen_box = bboxes.pop(0)

        bboxes = [
            box
            for box in bboxes
            if box[0] != chosen_box[0]
            or intersection_over_union(
                torch.tensor(chosen_box[2:]),
                torch.tensor(box[2:]),
                box_format=box_format,
            )
            < iou_threshold
        ]

        bboxes_after_nms.append(chosen_box)

    return bboxes_after_nms

# 이미지에 예측한 bbox를 표시한다.

In [45]:
def plot_image(image, boxes):
    
    cmap = plt.get_cmap("tab20b")
    class_labels = config.COCO_LABELS if config.DATASET=='COCO' else config.PASCAL_CLASSES
    colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
    im = np.array(image)
    height, width, _ = im.shape

    # Create figure and axes
    fig, ax = plt.subplots(1)
    # Display the image
    ax.imshow(im)

    # box[0] is x midpoint, box[2] is width
    # box[1] is y midpoint, box[3] is height

    # Create a Rectangle patch
    for box in boxes:
        assert len(box) == 6, "box should contain class pred, confidence, x, y, width, height"
        class_pred = box[0]
        box = box[2:]
        upper_left_x = box[0] - box[2] / 2
        upper_left_y = box[1] - box[3] / 2
        rect = patches.Rectangle(
            (upper_left_x * width, upper_left_y * height),
            box[2] * width,
            box[3] * height,
            linewidth=2,
            edgecolor=colors[int(class_pred)],
            facecolor="none",
        )
        # Add the patch to the Axes
        ax.add_patch(rect)
        plt.text(
            upper_left_x * width,
            upper_left_y * height,
            s=class_labels[int(class_pred)],
            color="white",
            verticalalignment="top",
            bbox={"color": colors[int(class_pred)], "pad": 0},
        )

    plt.show()

# transform 정의

In [103]:
transform = transforms.Compose([
    transforms.Resize((416, 416)),
    #transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Data 불러오는 class

In [104]:
# 객체 정보를 저장한 txt파일 불러오기
label_path = os.path.join(label_dir, df.iloc[0, 1])
# 이후 bboxes에 (center_x, center_y, w, h, class)와 같이 값이 저장됨.
bboxes = np.roll(np.loadtxt(fname=label_path, delimiter=" ", ndmin=2), 4, axis=1).tolist()

# 이미지 파일 불러오기
img_path = os.path.join(img_dir, df.iloc[0, 0])
image = Image.open(img_path)

augmentations = transform(image)

In [107]:
augmentations[0]

tensor([[-1., -1., -1.,  ..., -1., -1., -1.],
        [-1., -1., -1.,  ..., -1., -1., -1.],
        [-1., -1., -1.,  ..., -1., -1., -1.],
        ...,
        [-1., -1., -1.,  ..., -1., -1., -1.],
        [-1., -1., -1.,  ..., -1., -1., -1.],
        [-1., -1., -1.,  ..., -1., -1., -1.]])

In [90]:
class YOLODataset(Dataset):
    def __init__(
        self,
        csv_file,
        img_dir,
        label_dir,
        anchors,
        image_size=416,
        S=[13, 26, 52],
        C=4,
        transform=None,
    ):
        self.annotations = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.image_size = image_size
        self.transform = transform
        self.S = S
        
        self.anchors = torch.tensor(anchors[0] + anchors[1] + anchors[2])  # for all 3 scales
        self.num_anchors = self.anchors.shape[0]
        self.num_anchors_per_scale = self.num_anchors // 3
        self.C = C
        self.ignore_iou_thresh = 0.5

    def __len__(self):
        return len(self.annotations)

    # index : csv파일 총 개수. __len__에서 return으로 값을 받음.
    def __getitem__(self, index):
        # 객체 정보를 저장한 txt파일 불러오기
        label_path = os.path.join(self.label_dir, self.annotations.iloc[index, 1])
        # 이후 bboxes에 [center_x, center_y, w, h, class]와 같이 불러와 값이 저장됨. -> np.roll()
        bboxes = np.roll(np.loadtxt(fname=label_path, delimiter=" ", ndmin=2), 4, axis=1).tolist()
        
        # 이미지 파일 불러오기
        img_path = os.path.join(self.img_dir, self.annotations.iloc[index, 0])
        #image = np.array(Image.open(img_path)).convert("RGB"))
        image = Image.open(img_path)
        
        if self.transform:
            augmentations = self.transform(image=image, bboxes=bboxes)
            image = augmentations["image"]

        # Below assumes 3 scale predictions (as paper) and same num of anchors per scale
        # target에는 grid cell에서 객체의 위치, 크기, 클래스 정보가 저장이 된다.
        targets = [torch.zeros((self.num_anchors // 3, S, S, 6)) for S in self.S]
        
        for box in bboxes:
            # bounding box와 anchor 사이의 IoU계산한 값 저장. 가장 적절한 anchor 인덱스 찾아냄.
            iou_anchors = iou(torch.tensor(box[2:4]), self.anchors)  
            anchor_indices = iou_anchors.argsort(descending=True, dim=0)
            
            '''
            class_label = df['class']
            x = df['center_x']
            y = df['center_y']
            width = df['w']
            height = df['h']'''
            x, y, width, height, class_label = box
            has_anchor = [False] * 3  # each scale should have one anchor
            
            for anchor_idx in anchor_indices:
                # bounding box가 어떤 크기의 grid cell에 해당하는지 결정.
                scale_idx = anchor_idx // self.num_anchors_per_scale
                anchor_on_scale = anchor_idx % self.num_anchors_per_scale
                S = self.S[scale_idx]
                
                # i, j 계산해서 grid cell 위치 얻기.
                i, j = int(S * y), int(S * x)  # which cell
                anchor_taken = targets[scale_idx][anchor_on_scale, i, j, 0]
                
                # 아직 grid cell에 앵커가 할당되지 않거나, 해당 scale에서 anchor를 사용할 수 없는 경우
                # -> 해당 앵커의 bbox 정보 할당하기
                if not anchor_taken and not has_anchor[scale_idx]:
                    targets[scale_idx][anchor_on_scale, i, j, 0] = 1
                    x_cell, y_cell = S * x - j, S * y - i  # both between [0,1]
                    width_cell, height_cell = (
                        width * S,
                        height * S,
                    )  # can be greater than 1 since it's relative to cell
                    box_coordinates = torch.tensor(
                        [x_cell, y_cell, width_cell, height_cell]
                    )
                    targets[scale_idx][anchor_on_scale, i, j, 1:5] = box_coordinates
                    targets[scale_idx][anchor_on_scale, i, j, 5] = int(class_label)
                    has_anchor[scale_idx] = True
                    
                # 할당은 되었지만, IoU가 설정한 임계값보다 큰 경우
                # -> 해당 앵커에 대한 예측 무시.
                elif not anchor_taken and iou_anchors[anchor_idx] > self.ignore_iou_thresh:
                    targets[scale_idx][anchor_on_scale, i, j, 0] = -1  # ignore prediction
        return image, tuple(targets)

In [91]:
def test():

    dataset = YOLODataset(
        train_csv,
        img_dir,
        label_dir,
        anchors=anchors,
        transform=transform,
    )
    
    scaled_anchors = torch.tensor(anchors) / (
        1 / torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
    )
    
    loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)
    
    for x, y in loader:
        boxes = []

        for i in range(y[0].shape[1]):
            anchor = scaled_anchors[i]
            print(anchor.shape)
            print(y[i].shape)
            boxes += cells_to_bboxes(
                y[i], is_preds=False, S=y[i].shape[2], anchors=anchor
            )[0]
        boxes = nms(boxes, iou_threshold=1, threshold=0.7, box_format="midpoint")
        print(boxes)
        plot_image(x[0].permute(1, 2, 0).to("cpu"), boxes)

In [49]:
if __name__ == "__main__":
    test()

TypeError: __call__() got an unexpected keyword argument 'image'