In [61]:
import os
import sys
import glob
import math
import numpy as np
from PIL import Image
from utils.data_augumentation import Compose, Scale, RandomRotation, RandomMirror, Resize, Normalize_Tensor
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models

In [38]:
# inputファイルとlabelsに存在しているファイルを取得
input_dir, label_dir = 'input','labels'
in_files = glob.glob(input_dir+'/*.jpg')
label_files = glob.glob(label_dir+'/*.png')
in_f = [f.split('/')[-1].split('.')[-2] for f in in_files ]
lb_f = [f.split('/')[-1].split('.')[-2] for f in label_files ]
f_name = set(in_f) & set(lb_f)
print('deta count : ',len(list(f_name)))

input_files, annotation_files = [], []
for i in list(f_name):
    input_files.append(input_dir+'/'+i+'.jpg')
    annotation_files.append(label_dir+'/'+i+'.png')
datafiles = list(zip(input_files, annotation_files))
print(datafiles)

deta count :  284
[('input/DJI_0009_4_1.jpg', 'labels/DJI_0009_4_1.png'), ('input/DJI_0009_4_3.jpg', 'labels/DJI_0009_4_3.png'), ('input/DJI_0009_7_2.jpg', 'labels/DJI_0009_7_2.png'), ('input/DJI_0014_8_4.jpg', 'labels/DJI_0014_8_4.png'), ('input/DJI_0021_5_1.jpg', 'labels/DJI_0021_5_1.png'), ('input/DJI_0010_4_1.jpg', 'labels/DJI_0010_4_1.png'), ('input/DJI_0025_3_2.jpg', 'labels/DJI_0025_3_2.png'), ('input/DJI_0004_2_2.jpg', 'labels/DJI_0004_2_2.png'), ('input/DJI_0011_1_1.jpg', 'labels/DJI_0011_1_1.png'), ('input/DJI_0018_3_7.jpg', 'labels/DJI_0018_3_7.png'), ('input/DJI_0004_3_2.jpg', 'labels/DJI_0004_3_2.png'), ('input/DJI_0008_6_1.jpg', 'labels/DJI_0008_6_1.png'), ('input/DJI_0014_7_7.jpg', 'labels/DJI_0014_7_7.png'), ('input/DJI_0016_5_2.jpg', 'labels/DJI_0016_5_2.png'), ('input/DJI_0022_4_2.jpg', 'labels/DJI_0022_4_2.png'), ('input/DJI_0007_7_8.jpg', 'labels/DJI_0007_7_8.png'), ('input/DJI_0016_5_1.jpg', 'labels/DJI_0016_5_1.png'), ('input/DJI_0010_1_1.jpg', 'labels/DJI_0010_1_

In [32]:
class DataTransform():
    def __init__(self, input_size, color_mean, color_std):
        self.data_transform = {
            'train': Compose([
                Scale(scale=[0.5, 1.5]),  # 画像の拡大
                RandomRotation(angle=[-10, 10]),  # 回転
                RandomMirror(),  # ランダムミラー
                Resize(input_size),  # リサイズ(input_size)
                Normalize_Tensor(color_mean, color_std)  # 色情報の標準化とテンソル化
            ]),
            'val': Compose([
                Resize(input_size),  # リサイズ(input_size)
                Normalize_Tensor(color_mean, color_std)  # 色情報の標準化とテンソル化
            ])
        }

    def __call__(self, phase, img, anno_class_img):
        return self.data_transform[phase](img, anno_class_img)

In [36]:
class VOCDataset(Dataset):
    def __init__(self, img_list, anno_list, phase, transform):
        self.img_list = img_list
        self.anno_list = anno_list
        self.phase = phase
        self.transform = transform

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, index):
        img, anno_class_img = self.pull_item(index)
        return img, anno_class_img

    def pull_item(self, index):
        image_file_path = self.img_list[index]
        img = Image.open(image_file_path)   # [高さ][幅][色RGB]
        anno_file_path = self.anno_list[index]
        anno_class_img = Image.open(anno_file_path)   # [高さ][幅]
        img, anno_class_img = self.transform(self.phase, img, anno_class_img)
        return img, anno_class_img

In [62]:
# 動作確認
num_train = math.floor(len(datafiles)*0.8)
num_test = len(datafiles)-num_train
num_all = num_train + num_test

def split_train_test(data):
    id_all   = np.random.choice(num_all, num_all, replace=False)
    id_test  = id_all[0:num_test]
    id_train = id_all[num_test:num_all]
    test_data  = data[id_test]
    train_data = data[id_train]
    return train_data, test_data
    
train_list, test_list = split_train_test(np.asarray(datafiles))
input_train, annotation_train = train_list[:,0], train_list[:,1]
input_val, annotation_val = test_list[:,0], test_list[:,1]

print('input   :: train: %d , test: %d'%(len(input_train), len(input_val)))
print('annotation  :: train: %d , test: %d'%(len(annotation_train), len(annotation_val)))

# (RGB)の色の平均値と標準偏差
color_mean = (0.485, 0.456, 0.406)
color_std = (0.229, 0.224, 0.225)

# データセット作成
train_dataset = VOCDataset(input_train, annotation_train, phase="train", transform=DataTransform(
    input_size=475, color_mean=color_mean, color_std=color_std))

val_dataset = VOCDataset(input_val, annotation_val, phase="val", transform=DataTransform(
    input_size=475, color_mean=color_mean, color_std=color_std))

# データの取り出し例
print(val_dataset.__getitem__(0)[0].shape)
print(val_dataset.__getitem__(0)[1].shape)
print(val_dataset.__getitem__(0))

input   :: train: 227 , test: 57
annotation  :: train: 227 , test: 57
torch.Size([3, 475, 475])
torch.Size([475, 475])
(tensor([[[-0.2513, -0.2513, -0.2513,  ..., -0.5767, -0.5767, -0.5767],
         [-0.2856, -0.2856, -0.2856,  ..., -0.5767, -0.5767, -0.5767],
         [-0.3198, -0.3198, -0.3198,  ..., -0.5596, -0.5596, -0.5596],
         ...,
         [-1.1418, -1.1418, -1.1418,  ..., -1.0904, -1.0904, -1.0904],
         [-1.1418, -1.1418, -1.1418,  ..., -1.0904, -1.0904, -1.0904],
         [-1.1418, -1.1418, -1.1418,  ..., -1.0904, -1.0904, -1.0904]],

        [[ 0.4503,  0.4503,  0.4503,  ...,  0.1352,  0.1352,  0.1352],
         [ 0.4153,  0.4153,  0.4153,  ...,  0.1352,  0.1352,  0.1352],
         [ 0.3803,  0.3803,  0.3803,  ...,  0.1527,  0.1527,  0.1527],
         ...,
         [-0.9853, -0.9853, -0.9853,  ..., -0.9853, -0.9853, -0.9853],
         [-0.9853, -0.9853, -0.9853,  ..., -0.9853, -0.9853, -0.9853],
         [-0.9853, -0.9853, -0.9853,  ..., -0.9853, -0.9853, -0.9853]

In [65]:
# データローダーの作成

batch_size = 8

train_dataloader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)

val_dataloader = DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False)

# 辞書オブジェクトにまとめる
dataloaders_dict = {"train": train_dataloader, "val": val_dataloader}

# 動作の確認
batch_iterator = iter(dataloaders_dict["val"])  # イタレータに変換
imges, anno_class_imges = next(batch_iterator)  # 1番目の要素を取り出す
print(imges.size())  # torch.Size([8, 3, 475, 475])
print(anno_class_imges.size())  # torch.Size([8, 3, 475, 475])

torch.Size([8, 3, 475, 475])
torch.Size([8, 475, 475])
