In [1]:
# Automatically reload edited project modules (such as charts.common.dataset) before running code,
# so changes to the helper package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Install notebook dependencies into the kernel's environment.
# NOTE(review): versions are unpinned; pin them (e.g. torchvision==X.Y) so re-runs are reproducible.
%pip install -q torchvision pandas

Note: you may need to restart the kernel to use updated packages.


In [21]:
from charts.common.dataset import LabeledImage

import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.io import read_image

import pandas as pd

import os
from pathlib import Path

In [10]:
# Peek at the label files directly: one LabeledImage per matching JSON file, sorted by path.
img_dir = Path('../../generated/drawings')
json_files = sorted(img_dir.glob("img-?????-???.json"))
labeled_images = [LabeledImage(json_file) for json_file in json_files]
labeled_images[:2]

[img-00000-000.json, img-00000-001.json]

In [34]:
class ColorRegressionImageDataset(Dataset):
    """Map-style dataset pairing each rendered drawing with its label image.

    One item per JSON label file in ``img_dir`` matching ``pattern`` (sorted by
    path). Each item is the tuple ``(image, labels_image, name)`` where
    ``image`` is ``LabeledImage.rendered_image``, ``labels_image`` is
    ``LabeledImage.labels_image`` and ``name`` is ``repr()`` of the
    ``LabeledImage`` (e.g. ``'img-00005-042.json'``).

    Args:
        img_dir: Directory holding the label files; ``str`` or path-like.
        transform: Optional callable applied to the rendered image.
        target_transform: Optional callable applied to the label image.
        pattern: Glob pattern selecting the JSON label files.
    """

    def __init__(self, img_dir, transform=None, target_transform=None,
                 pattern="img-?????-???.json"):
        # Coerce so plain strings work too; .glob below requires a Path.
        self.img_dir = Path(img_dir)
        self.transform = transform
        self.target_transform = target_transform

        json_files = sorted(self.img_dir.glob(pattern))
        self.labeled_images = [LabeledImage(json_file) for json_file in json_files]

    def __len__(self):
        return len(self.labeled_images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, labels_image, name)``."""
        labeled_img = self.labeled_images[idx]
        try:
            labeled_img.ensure_images_loaded()
            image = labeled_img.rendered_image
            labels_image = labeled_img.labels_image
        finally:
            # Release even when loading fails, so images are not left attached to the
            # LabeledImage kept in self.labeled_images.
            # NOTE(review): assumes release_images() is safe after a partial load — confirm.
            labeled_img.release_images()
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            labels_image = self.target_transform(labels_image)
        return image, labels_image, repr(labeled_img)

    def __repr__(self):
        # Guard the empty case: indexing [0] / [-1] would raise IndexError.
        if not self.labeled_images:
            return f"0 images in {self.img_dir}"
        return f"{len(self)} images, first is {self.labeled_images[0]}, last is {self.labeled_images[-1]}"

In [35]:
# Split the generated drawings into train and test sets.
# A fixed generator seed keeps the split reproducible across kernel restarts.
SEED = 42
TRAIN_FRACTION = 0.5

dataset = ColorRegressionImageDataset(Path('../../generated/drawings'))
# The path is relative to the notebook's working directory; fail loudly if nothing
# matched instead of silently producing two empty splits.
assert len(dataset) > 0, "no img-?????-???.json files found under ../../generated/drawings"
n_train = int(len(dataset) * TRAIN_FRACTION)
n_test = len(dataset) - n_train
train_dataset, test_dataset = random_split(
    dataset, [n_train, n_test], generator=torch.Generator().manual_seed(SEED)
)

In [44]:
def preview_names(subset, n=10):
    """Return the name field (third tuple item) of the first ``n`` samples of ``subset``.

    Indexes only the samples it needs; ``list(subset)`` would load every image in
    the split just to show a handful of names.
    """
    return [subset[i][2] for i in range(min(n, len(subset)))]

display(preview_names(train_dataset))
display(preview_names(test_dataset))

['img-00005-042.json',
 'img-00006-018.json',
 'img-00008-016.json',
 'img-00000-068.json',
 'img-00000-094.json',
 'img-00002-015.json',
 'img-00000-060.json',
 'img-00005-085.json',
 'img-00009-042.json',
 'img-00001-065.json']

['img-00008-082.json',
 'img-00003-019.json',
 'img-00001-001.json',
 'img-00005-078.json',
 'img-00008-062.json',
 'img-00004-063.json',
 'img-00002-078.json',
 'img-00006-063.json',
 'img-00006-071.json',
 'img-00007-038.json']