In [48]:
import json

# Load the COCO annotation file for the training split. JSON text is UTF-8, while
# open()'s default encoding is locale-dependent (e.g. cp1252 on Windows), so the
# encoding is passed explicitly to avoid mis-decoding non-ASCII file names.
with open('segment guns.v2i.coco-segmentation/train/_annotations.coco.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

In [49]:
import multiprocessing
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image


def read_dataset(path, n=-1, annotations=None):
    """Load images and rasterized binary masks for one COCO-segmentation split.

    Args:
        path: directory containing the image files named in
            ``annotations['images'][*]['file_name']``.
        n: maximum number of images to load; any negative value (default ``-1``)
            loads every image.
        annotations: parsed COCO annotation dict with ``'images'`` and
            ``'annotations'`` lists. Defaults to the module-level ``data`` dict
            loaded from ``_annotations.coco.json``.

    Returns:
        ``(features, labels)``: parallel lists with one entry per image. Each
        feature is the ``ToTensor`` output of the RGB-converted image (a float
        ``(3, H, W)`` tensor in [0, 1]); each label is a ``uint8`` tensor of shape
        ``(image['height'], image['width'])`` that is 1 inside every annotated
        polygon and 0 elsewhere.

    Raises:
        NotImplementedError: if an annotation stores its segmentation as a dict
            (RLE) instead of a list of polygons.
    """
    coco = data if annotations is None else annotations
    images = coco['images'] if n < 0 else coco['images'][:n]

    # Group annotations by image once instead of rescanning the whole list per image.
    annotations_by_image = {}
    for annotation in coco['annotations']:
        annotations_by_image.setdefault(annotation.get('image_id'), []).append(annotation)

    to_tensor = transforms.ToTensor()
    features, labels = [], []
    for image in images:
        # `with` closes the file handle; convert('RGB') guarantees 3 channels even for
        # grayscale / RGBA / palette files (the 3-channel Normalize downstream needs it).
        with Image.open(os.path.join(path, str(image['file_name']))) as img:
            features.append(to_tensor(img.convert('RGB')))

        label = np.zeros((image['height'], image['width']), dtype=np.uint8)
        for annotation in annotations_by_image.get(image['id'], []):
            segmentation = annotation.get('segmentation')
            if isinstance(segmentation, dict):
                raise NotImplementedError(
                    f"RLE segmentation is not supported (annotation id {annotation.get('id')})")
            # A polygon segmentation is a list of flat [x1, y1, x2, y2, ...] lists; an
            # object may consist of several polygons of different lengths, so each one
            # is reshaped to (k, 2) on its own.
            polygons = [np.array(polygon, dtype=np.int32).reshape(-1, 2)
                        for polygon in segmentation or []]
            if polygons:
                cv2.fillPoly(label, polygons, 1)

        labels.append(torch.from_numpy(label))

    return features, labels
# Image directory of the training split, relative to the notebook's working directory
# (the same base directory the annotation JSON is loaded from), instead of a
# machine-specific absolute path.
path = os.path.join('segment guns.v2i.coco-segmentation', 'train')

# A, B = read_dataset(path, 10)

In [50]:
def voc_rand_crop(feature, label, height, width):
    """Cut the same random ``height`` x ``width`` window out of an image and its label.

    The crop box is sampled once from ``feature`` and then applied to both
    tensors, so the returned image and label remain pixel-aligned.
    """
    crop = torchvision.transforms.functional.crop
    top, left, crop_height, crop_width = torchvision.transforms.RandomCrop.get_params(
        feature, (height, width))
    return (crop(feature, top, left, crop_height, crop_width),
            crop(label, top, left, crop_height, crop_width))

In [51]:
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    """Plot images on a ``num_rows`` x ``num_cols`` grid of matplotlib axes.

    Args:
        imgs: iterable of images accepted by ``Axes.imshow`` (e.g. ``(H, W, C)`` or
            ``(H, W)`` arrays, PIL images). Torch tensors are detached and moved to
            the CPU before plotting. Images beyond ``num_rows * num_cols`` are ignored.
        num_rows: number of grid rows.
        num_cols: number of grid columns.
        titles: optional sequence of per-image titles, indexed like ``imgs``.
        scale: width/height of each grid cell in inches.

    Returns:
        Flat numpy array of the grid's ``Axes`` objects.
    """
    figsize = (num_cols * scale, num_rows * scale)
    # squeeze=False always returns a 2-D array of Axes, so flatten() also works for 1x1 grids.
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if isinstance(img, torch.Tensor):
            # .cpu() so CUDA tensors can be converted; other inputs go to imshow as-is.
            img = img.detach().cpu().numpy()
        ax.imshow(img)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes
     
# imgs = []
# n=5
# for _ in range(n):
#     imgs += voc_rand_crop(A[3], B[3].unsqueeze(0) , 420, 600)
# 
# imgs = [img.permute(1, 2, 0) for img in imgs]
# show_images(imgs[::2] + imgs[1::2], 2, n, scale = 10);

In [52]:
class GunsDataset(torch.utils.data.Dataset):
    """Gun-segmentation dataset that serves random fixed-size crops.

    All images and masks are loaded eagerly with ``read_dataset(path)``.
    Image/mask pairs smaller than ``crop_size`` are dropped, images are
    normalized with the ImageNet channel mean/std, and ``__getitem__`` returns a
    new random crop of image and mask on every access.

    Args:
        crop_size: ``(height, width)`` of the crops returned by ``__getitem__``.
        path: image directory passed to ``read_dataset``.
    """

    def __init__(self, crop_size, path):
        self.transform = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.crop_size = crop_size
        features, labels = read_dataset(path)
        # Filter images and masks together so both lists stay index-aligned.
        features, self.labels = self._filter_pairs(features, labels)
        self.features = [self.normalize_image(feature) for feature in features]
        print('read ' + str(len(self.features)) + ' examples')

    def normalize_image(self, img):
        """Return ``img`` rescaled to [0, 1] if needed, then passed through ``self.transform``.

        ``read_dataset`` yields float tensors already in [0, 1] (``ToTensor``);
        dividing those by 255 again would collapse every pixel to ~0, so only
        integer (e.g. ``uint8``) tensors are rescaled.
        """
        if not img.is_floating_point():
            img = img.float() / 255.0
        return self.transform(img)

    def _is_large_enough(self, img):
        """True if the last two dimensions of ``img`` are at least ``crop_size``."""
        height, width = img.shape[-2], img.shape[-1]
        return height >= self.crop_size[0] and width >= self.crop_size[1]

    def filter(self, imgs):
        """Keep the tensors whose spatial size (last two dims) fits ``crop_size``.

        Works for both ``(C, H, W)`` images and ``(H, W)`` masks.
        """
        return [img for img in imgs if self._is_large_enough(img)]

    def _filter_pairs(self, features, labels):
        """Drop image/mask pairs where either tensor is smaller than ``crop_size``."""
        kept = [(feature, label) for feature, label in zip(features, labels)
                if self._is_large_enough(feature) and self._is_large_enough(label)]
        return [feature for feature, _ in kept], [label for _, label in kept]

    def __getitem__(self, idx):
        """Return a random ``crop_size`` crop of image ``idx`` and the matching mask crop."""
        feature, label = voc_rand_crop(self.features[idx], self.labels[idx], *self.crop_size)
        return feature, label

    def __len__(self):
        return len(self.features)
     

In [None]:
# Fixed (height, width) of the random crops, so every example has the same size.
crop_size = 420, 600
voc_train = GunsDataset(crop_size=crop_size, path=path)

In [None]:
batch_size = 32
# DataLoader workers started with 'spawn'/'forkserver' (the default on Windows and
# macOS) re-import __main__ and cannot find classes defined in a notebook cell such
# as GunsDataset, so multi-process loading fails or hangs there. Only use worker
# processes when the effective start method is 'fork'.
start_method = (multiprocessing.get_start_method(allow_none=True)
                or multiprocessing.get_all_start_methods()[0])
num_workers = 4 if start_method == 'fork' else 0
train_iter = torch.utils.data.DataLoader(voc_train, batch_size, shuffle=True,
                                         drop_last=True,
                                         num_workers=num_workers)

# Sanity check: print the shapes of one batch of images and masks.
for X, Y in train_iter:
    print(X.shape)
    print(Y.shape)
    break