Skip to content

Commit

Permalink
feat: added load_percentage parameter to ImageList.from_files to load…
Browse files Browse the repository at this point in the history
… a subset of the given files (#739)

Closes #736 

### Summary of Changes

perf: massively improved RAM and VRAM usage in ImageList.from_files
perf: massively improved runtime of ImageList.from_files by using
threads
feat: added load_percentage parameter to ImageList.from_files to load a
subset of the given files

---------

Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com>
  • Loading branch information
Marsmaennchen221 and megalinter-bot committed May 8, 2024
1 parent 732aa48 commit 0564b52
Show file tree
Hide file tree
Showing 97 changed files with 378 additions and 27 deletions.
9 changes: 6 additions & 3 deletions src/safeds/data/image/containers/_image.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import io
import os.path
import sys
import warnings
from pathlib import Path
Expand Down Expand Up @@ -76,12 +77,14 @@ def from_file(path: str | Path) -> Image:
FileNotFoundError
If the file of the path cannot be found
"""
from PIL.Image import open as pil_image_open
from torchvision.transforms.functional import pil_to_tensor
from torchvision.io import read_image

_init_default_device()

return Image(image_tensor=pil_to_tensor(pil_image_open(path)))
if not os.path.isfile(path):
raise FileNotFoundError(f"No such file or directory: '{path}'")

return Image(image_tensor=read_image(str(path)).to(_get_device()))

@staticmethod
def from_bytes(data: bytes) -> Image:
Expand Down
187 changes: 170 additions & 17 deletions src/safeds/data/image/containers/_image_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
import io
import math
import os
import random
from abc import ABCMeta, abstractmethod
from pathlib import Path
from threading import Thread
from typing import TYPE_CHECKING, Literal, overload

from safeds._config import _init_default_device
from safeds.data.image.containers._image import Image
from safeds.exceptions import OutOfBoundsError, ClosedBound

if TYPE_CHECKING:
from collections.abc import Sequence
Expand Down Expand Up @@ -87,26 +90,61 @@ def from_files(path: str | Path | Sequence[str | Path]) -> ImageList: ...

@staticmethod
@overload
def from_files(path: str | Path | Sequence[str | Path], return_filenames: Literal[False]) -> ImageList: ...
def from_files(path: str | Path | Sequence[str | Path], *, load_percentage: float) -> ImageList: ...

@staticmethod
@overload
def from_files(path: str | Path | Sequence[str | Path], *, return_filenames: Literal[False]) -> ImageList: ...

@staticmethod
@overload
def from_files(
path: str | Path | Sequence[str | Path],
*,
return_filenames: Literal[False],
load_percentage: float,
) -> ImageList: ...

@staticmethod
@overload
def from_files(
path: str | Path | Sequence[str | Path],
*,
return_filenames: Literal[True],
) -> tuple[ImageList, list[str]]: ...

@staticmethod
@overload
def from_files(
path: str | Path | Sequence[str | Path],
*,
return_filenames: Literal[True],
load_percentage: float,
) -> tuple[ImageList, list[str]]: ...

@staticmethod
@overload
def from_files(
path: str | Path | Sequence[str | Path],
*,
return_filenames: bool,
) -> ImageList | tuple[ImageList, list[str]]: ...

@staticmethod
@overload
def from_files(
path: str | Path | Sequence[str | Path],
*,
return_filenames: bool,
load_percentage: float,
) -> ImageList | tuple[ImageList, list[str]]: ...

@staticmethod
def from_files(
path: str | Path | Sequence[str | Path],
*,
return_filenames: bool = False,
load_percentage: float = 1.0,
) -> ImageList | tuple[ImageList, list[str]]:
"""
Create an ImageList from a directory or a list of files.
Expand All @@ -119,6 +157,8 @@ def from_files(
the path to the directory or a list of files
return_filenames:
if True the output will be a tuple which contains a list of the filenames in order of the images
load_percentage:
the percentage of the given data being loaded. If below 1 the files will be shuffled before loading
Returns
-------
Expand All @@ -129,22 +169,26 @@ def from_files(
------
FileNotFoundError
If the directory or one of the files of the path cannot be found
OutOfBoundsError
If load_percentage is not between 0 and 1
"""
from PIL.Image import open as pil_image_open
from torchvision.transforms.v2.functional import pil_to_tensor

_init_default_device()

from safeds.data.image.containers._empty_image_list import _EmptyImageList
from safeds.data.image.containers._multi_size_image_list import _MultiSizeImageList
from safeds.data.image.containers._single_size_image_list import _SingleSizeImageList

if load_percentage < 0 or load_percentage > 1:
raise OutOfBoundsError(
load_percentage, name="load_percentage", lower_bound=ClosedBound(0), upper_bound=ClosedBound(1)
)

if isinstance(path, list) and len(path) == 0:
return _EmptyImageList()

image_tensors = []
file_names = []
fixed_size = True

path_list: list[str | Path]
if isinstance(path, Path | str):
Expand All @@ -155,30 +199,139 @@ def from_files(
p = Path(path_list.pop(0))
if p.is_dir():
path_list += sorted([p / name for name in os.listdir(p)])
else:
image_tensors.append(pil_to_tensor(pil_image_open(p)))
elif p.is_file():
file_names.append(str(p))
if fixed_size and (
image_tensors[0].size(dim=2) != image_tensors[-1].size(dim=2)
or image_tensors[0].size(dim=1) != image_tensors[-1].size(dim=1)
):
fixed_size = False
else:
raise FileNotFoundError(f"No such file or directory: '{path}'")

if len(image_tensors) == 0:
return _EmptyImageList()
if load_percentage < 1:
random.shuffle(file_names)
file_names = file_names[: max(round(len(file_names) * load_percentage), 1) if load_percentage > 0 else 0]

indices = list(range(len(image_tensors)))
num_of_files = len(file_names)

if fixed_size:
image_list = _SingleSizeImageList._create_image_list(image_tensors, indices)
if num_of_files == 0:
return _EmptyImageList()

image_sizes: dict[tuple[int, int], dict[int, list[str]]] = {}
image_indices: dict[tuple[int, int], dict[int, list[int]]] = {}
image_count: dict[tuple[int, int], int] = {}
max_channel = -1

for i, filename in enumerate(file_names):
im = pil_image_open(filename)
im_channel = len(im.getbands())
im_size = (im.width, im.height)
if im_channel > max_channel:
max_channel = im_channel
if im_size not in image_sizes:
image_sizes[im_size] = {im_channel: [filename]}
image_indices[im_size] = {im_channel: [i]}
image_count[im_size] = 1
elif im_channel not in image_sizes[im_size]:
image_sizes[im_size][im_channel] = [filename]
image_indices[im_size][im_channel] = [i]
image_count[im_size] += 1
else:
image_sizes[im_size][im_channel].append(filename)
image_indices[im_size][im_channel].append(i)
image_count[im_size] += 1

num_of_threads = min(math.ceil(num_of_files / 1000), 100)
num_of_files_per_thread = math.ceil(num_of_files / num_of_threads)

single_sized_image_lists = []
thread_packages = []
for size, image_files in image_sizes.items():
im_list, packages = _SingleSizeImageList._create_image_list_from_files(
image_files,
image_count[size],
max_channel,
size[0],
size[1],
image_indices[size],
num_of_files_per_thread,
)
single_sized_image_lists.append(im_list._as_single_size_image_list())
thread_packages += packages
thread_packages.sort(key=lambda x: len(x), reverse=True)

threads: list[ImageList._FromImageThread] = []
for thread_index in range(num_of_threads):
current_thread_workload = 0
current_thread_packages = []
while current_thread_workload < num_of_files_per_thread and len(thread_packages) > 0:
next_package = thread_packages.pop()
current_thread_packages.append(next_package)
current_thread_workload += len(next_package)
if thread_index == num_of_threads - 1 and len(thread_packages) > 0:
current_thread_packages += thread_packages # pragma: no cover
thread = ImageList._FromImageThread(current_thread_packages)
threads.append(thread)
thread.start()
for thread in threads:
thread.join()

if len(single_sized_image_lists) == 1:
image_list: ImageList = single_sized_image_lists[0]
else:
image_list = _MultiSizeImageList._create_image_list(image_tensors, indices)
image_list = _MultiSizeImageList._create_from_single_sized_image_lists(single_sized_image_lists)

if return_filenames:
return image_list, file_names
else:
return image_list

class _FromFileThreadPackage:

def __init__(
self,
im_files: list[str],
im_channel: int,
to_channel: int,
im_width: int,
im_height: int,
tensor: Tensor,
start_index: int,
) -> None:
self._im_files = im_files
self._im_channel = im_channel
self._to_channel = to_channel
self._im_width = im_width
self._im_height = im_height
self._tensor = tensor
self._start_index = start_index

def load_files(self) -> None:
import torch
from torchvision.io import read_image

_init_default_device()

num_of_files = len(self._im_files)
tensor_channel = max(self._im_channel, min(self._to_channel, 3))
for index, im in enumerate(self._im_files):
self._tensor[index + self._start_index, 0:tensor_channel] = read_image(im)
if self._to_channel == 4 and self._im_channel < 4:
torch.full(
(num_of_files, 1, self._im_height, self._im_width),
255,
out=self._tensor[self._start_index : self._start_index + num_of_files, 3:4],
)

def __len__(self) -> int:
return len(self._im_files)

class _FromImageThread(Thread):

def __init__(self, packages: list[ImageList._FromFileThreadPackage]) -> None:
super().__init__()
self._packages = packages

def run(self) -> None:
for pck in self._packages:
pck.load_files()

@abstractmethod
def _clone(self) -> ImageList:
"""
Expand Down
39 changes: 39 additions & 0 deletions src/safeds/data/image/containers/_multi_size_image_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,45 @@ def __init__(self) -> None:
self._image_list_dict: dict[tuple[int, int], ImageList] = {} # {image_size: image_list}
self._indices_to_image_size_dict: dict[int, tuple[int, int]] = {} # {index: image_size}

@staticmethod
def _create_from_single_sized_image_lists(single_size_image_lists: list[_SingleSizeImageList]) -> ImageList:
from safeds.data.image.containers._empty_image_list import _EmptyImageList

if len(single_size_image_lists) == 0:
return _EmptyImageList()
elif len(single_size_image_lists) == 1:
return single_size_image_lists[0]

different_channels: bool = False
max_channel: None | int = None

image_list = _MultiSizeImageList()
for single_size_image_list in single_size_image_lists:
image_size = (single_size_image_list.widths[0], single_size_image_list.heights[0])
image_list._image_list_dict[image_size] = single_size_image_list
image_list._indices_to_image_size_dict.update(
zip(
single_size_image_list._indices_to_tensor_positions.keys(),
[image_size] * len(single_size_image_list),
strict=False,
)
)
if max_channel is None:
max_channel = single_size_image_list.channel
elif max_channel < single_size_image_list.channel:
different_channels = True
max_channel = single_size_image_list.channel
elif max_channel > single_size_image_list.channel:
different_channels = True

if different_channels:
for size in image_list._image_list_dict:
if max_channel is not None and image_list._image_list_dict[size].channel != max_channel:
image_list._image_list_dict[size] = image_list._image_list_dict[size].change_channel(
int(max_channel)
)
return image_list

@staticmethod
def _create_image_list(images: list[Tensor], indices: list[int]) -> ImageList:
from safeds.data.image.containers._empty_image_list import _EmptyImageList
Expand Down

0 comments on commit 0564b52

Please sign in to comment.