In [52]:
import os
import numpy as np
from PIL import Image
from os.path import join, isdir, isfile
import cv2 

import torch
from torch.utils.data import Dataset, DataLoader


In [2]:
# Prefix codes that sample filenames start with; used to group samples
# per varietal.
varietals = [
    'CDY',
    'CFR',
    'CSV',
    'SVB',
    'SYH',
]

In [41]:
# Per-channel mean / std handed to transforms.Normalize.
# NOTE(review): these look like the usual ImageNet statistics — confirm.
IM_NORM_MEAN = [0.485, 0.456, 0.406]
IM_NORM_STD = [0.229, 0.224, 0.225]

# Folder holding the .jpg / .txt / .npz files; the print is a quick sanity
# check that the relative path resolves from the notebook's working directory.
DATA_FOLDER = "../../../wgisd/data"
print(isdir(DATA_FOLDER))

True


In [10]:
# Index the data folder: collect the label / image / mask files and group the
# annotated sample names by varietal prefix.
if not isdir(DATA_FOLDER):
    # Fail with a readable message instead of the bare StopIteration that
    # next(os.walk(...)) raises for a missing folder.
    raise FileNotFoundError("data folder not found: {}".format(DATA_FOLDER))

# Top-level non-directory entries only, sorted so that every list derived
# from them has a reproducible order regardless of the file system.
all_files = sorted(
    f for f in os.listdir(DATA_FOLDER) if not isdir(join(DATA_FOLDER, f))
)
label_files = [f for f in all_files if f.endswith(".txt")]
image_files = [f for f in all_files if f.endswith(".jpg")]
mask_files = [f for f in all_files if f.endswith(".npz")]

# Sample name = label filename without its extension.
filenames = [os.path.splitext(f)[0] for f in label_files]

instances = {v: [] for v in varietals}
for filename in filenames:
    for v in varietals:
        if filename.startswith(v):
            instances[v].append(filename)
    

In [11]:
# Number of annotated samples collected for each varietal prefix.
n_images = dict(zip(instances.keys(), map(len, instances.values())))
print(n_images)

{'CDY': 65, 'CFR': 65, 'CSV': 57, 'SVB': 65, 'SYH': 48}


In [12]:
# Counts of image, label and mask files, in that order.
print(*map(len, (image_files, label_files, mask_files)))

300 300 137


In [77]:
class WGISDMaskedDataset(Dataset):
    """Dataset of WGISD images with their box annotations.

    ``<root>/<split>.txt`` lists one sample name per line. For every name the
    image, label and mask paths are built as ``<root>/data/<name>.jpg``,
    ``<root>/data/<name>.txt`` and ``<root>/data/<name>.npy``.

    Each item is a tuple ``(img, labels, boxes)``:

    * ``img``    -- the RGB image as a numpy array, or whatever ``transform``
      returns for it.
    * ``labels`` -- int64 torch tensor holding each label-file class id plus 1
      (presumably so that 0 stays free for background -- confirm).
    * ``boxes``  -- float32 array of shape (n, 4) with ``[cen_x, cen_y, w, h]``
      rows exactly as stored in the label file, or whatever ``transform``
      returns for it.

    Args:
        root: dataset root containing the split files and the ``data`` folder.
        transform: optional callable ``(img, boxes) -> (img, boxes)``.
        split: ``"train"`` or ``"test"``.
    """

    def __init__(self, root, transform=None, split="train"):
        self.root = root
        self.transform = transform
        assert split in ["train", "test"], "split must be train or test, get {}".format(split)

        # A "<split>_masked.txt" list was used here at some point; the plain
        # "<split>.txt" list is what is read now.
        with open(join(self.root, split+".txt"), "r") as handle:
            names = [line.rstrip() for line in handle]
        # Drop blank lines (e.g. a trailing newline) so they do not turn into
        # bogus "<root>/data/.jpg" entries.
        names = [name for name in names if name]

        self.img_paths = [join(self.root, "data", name+".jpg") for name in names]
        self.label_paths = [join(self.root, "data", name+".txt") for name in names]
        # NOTE(review): the mask files listed elsewhere in this notebook end in
        # ".npz"; confirm that ".npy" is the intended extension. Masks are not
        # loaded by __getitem__.
        self.mask_paths = [join(self.root, "data", name+".npy") for name in names]

    def __len__(self, ):
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(img, labels, boxes)``."""
        # Index with the ``idx`` argument; the previous code read an undefined
        # name ``i`` and only worked when a global ``i`` happened to exist.
        img = np.array(Image.open(self.img_paths[idx]).convert("RGB"))

        labels = []
        boxes = []
        with open(self.label_paths[idx], "r") as handle:
            for line in handle:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines in the annotation file.
                    continue
                cls, cen_x, cen_y, w, h = fields
                labels.append(int(cls) + 1)
                boxes.append([float(cen_x), float(cen_y), float(w), float(h)])

        # Explicit dtype / shape so an image without annotations still yields
        # an integer label tensor and a (0, 4) box array.
        labels = np.array(labels, dtype=np.int64)
        boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)

        if self.transform is not None:
            img, boxes = self.transform(img, boxes)
        return img, torch.from_numpy(labels), boxes
        

In [78]:
import torchvision.transforms as transforms

class Compose(object):
    """Chain image transforms while carrying the bounding boxes along.

    Args:
        transforms: iterable of callables, each taking and returning an image.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, bboxes):
        """Run every transform on ``img`` in order; return ``(img, bboxes)``."""
        # Only the image goes through the pipeline; bboxes are handed back
        # exactly as they were passed in.
        result = img
        for step in self.transforms:
            result = step(result)
        return result, bboxes
    
# NOTE: no resizing step is applied; a Resize((448, 448)) variant was sketched
# here earlier but left disabled.

# Tensor conversion followed by per-channel normalisation. Kept under its own
# name: it used to be assigned to ``base_transform`` and was then overwritten
# on the very next line, so it was never usable.
normalized_transform = Compose([transforms.ToTensor(), transforms.Normalize(mean=IM_NORM_MEAN, std=IM_NORM_STD)])

# Tensor conversion only.
base_transform = Compose([transforms.ToTensor()])


In [79]:
# Default (train) split with no transform, so items come back as raw numpy
# images. To run the transform pipeline instead:
# data = WGISDMaskedDataset(root="../../../wgisd", transform=base_transform)
data = WGISDMaskedDataset(root="../../../wgisd")

# Fetch the first sample for inspection.
for i in (0,):
    item = data[i]
    

In [80]:
img, labels, boxes = item
# (width, height) in pixels, read from the image itself rather than hard-coded
# as (2048, 1365), so the box scaling stays correct for images of any
# resolution. ``img`` is the untransformed numpy image, laid out rows
# (height) first, then columns (width).
sizes = (img.shape[1], img.shape[0])

In [87]:
# Draw every annotated box on a copy of the image and save it for inspection.
draw_img = img.copy()
for box in boxes:
    cx, cy, w, h = box
    # Box values are fractions of the image size: scale the centre/size form
    # to pixel corner coordinates.
    start_point = int((cx-w/2)*sizes[0]), int((cy-h/2)*sizes[1])
    end_point = int((cx+w/2)*sizes[0]), int((cy+h/2)*sizes[1])
    # The array is RGB at this point, so (255, 0, 0) is red.
    draw_img = cv2.rectangle(draw_img, start_point, end_point, (255, 0, 0), 2)
# cv2.imwrite expects BGR channel order; the image was loaded as RGB, so
# convert first, otherwise the saved file has red and blue swapped.
cv2.imwrite("sample.jpg", cv2.cvtColor(draw_img, cv2.COLOR_RGB2BGR))

True

In [82]:
# Channel values of the top-left pixel.
img[0, 0]

array([17, 21,  6], dtype=uint8)

In [45]:
from matplotlib import pyplot as plt
from matplotlib import patches
from matplotlib.patches import Polygon

from skimage.color import label2rgb

import colorsys
import random