70 changes: 64 additions & 6 deletions webloader/loader.py
@@ -15,13 +15,14 @@
import collections

from past.utils import old_div
import atexit

from . import filters, gopen, paths, utils

_big = 1 << 60

try:
    from torch import Tensor as TorchTensor, from_numpy, device as torch_device
except:
class TorchTensor(object): pass

@@ -170,6 +171,7 @@ def __init__(self,
epochs=1,
pipeline=None,
verbose=False,
use_shared_mem=True,
use_tracker=True,
use_torch_mp=False,
processes=1,
Expand All @@ -189,6 +191,7 @@ def __init__(self,
:param epochs: number of epochs to iterate for (Default value = 1)
:param pipeline: pipeline to apply to samples before field extraction (Default value = None)
:param verbose: output extra information (Default value = False)
:param use_shared_mem: whether to use shared memory or not (Default value = True)
        :param use_tracker: ignored (for interface compatibility with MultiWebLoader)
        :param use_torch_mp: ignored (for interface compatibility with MultiWebLoader)
        :param processes: ignored (for interface compatibility with MultiWebLoader)
@@ -219,6 +222,8 @@ def __init__(self,
if use_tracker:
self.tracker = Tracker()
self.verbose = verbose
self.use_shared_mem = use_shared_mem
self.converted = None

def __iter__(self):
"""Iterate over samples."""
@@ -262,6 +267,19 @@ def __iter__(self):
if isinstance(sample, dict):
raise ValueError("expect list for batch_transforms; did you specify fields= for WebLoader?")
sample = transform_with(sample, self.converters)

                if not self.use_shared_mem:
                    if isinstance(sample[0], TorchTensor):
                        # Shared memory is unavailable, so ship tensors through the
                        # queue as numpy arrays (assumes two-field batches).
                        device = sample[0].device
                        sample[0] = sample[0].cpu().numpy()
                        sample[1] = sample[1].cpu().numpy()
                        if self.converted is None:
                            self.converted = True
                            # One-time control message telling the consumer which
                            # device to move the restored tensors back to.
                            yield f"Converted;{device}"
                    if self.converted is None:
                        self.converted = False
                        yield "Not Converted;None"

self.last_sample = sample
total += max(1, self.batch_size)
yield sample
@@ -271,6 +289,7 @@ def __len__(self):
"""Return the length of the dataset (the size argument passed on initialization)."""
return self.size


def make_loader(args, kw, queue, index):
kw["use_tracker"] = False
data = WebLoader(*args, **kw)
@@ -331,17 +350,41 @@ def __init__(self, urls, size, processes=4, use_torch_mp=False, queue_size=10, m
:param **kw: other keyword arguments are passed to WebLoader

"""

        # Determine how much shared memory is free by parsing `df` output.
        stats = os.popen('df').read().splitlines()
        stats = [line.split() for line in stats]
        shared_mem_available = None
        # df header: Filesystem 1K-blocks Used Available Use% Mounted on
        assert stats[0][3] == "Available"
        for line in stats[1:]:
            if line[-1] == "/dev/shm":
                shared_mem_available = int(line[3])  # free space, in 1K blocks
                break
        assert shared_mem_available is not None
        shared_mem_available = shared_mem_available // 1024  # Convert to MB
        print(f"Shared memory available: {shared_mem_available} MB")

# Disable shared memory usage if the amount of available memory is less than 1 GB
use_shared_mem = shared_mem_available >= 1024
if not use_shared_mem:
print("Disabling shared memory usage due to low shared memory!")

assert "epochs" not in kw, kw
kw["epochs"] = 999999999
self.size = size
self.args = (urls, size)
self.kw = kw
self.kw['use_shared_mem'] = use_shared_mem
self.use_torch_mp = use_torch_mp
self.processes = processes
self.queue_size = queue_size
self.multi_pipe = multi_pipes.get(multi_pipe, multi_pipe)
assert self.multi_pipe is None or callable(self.multi_pipe)
self.jobs = None
self.converted = None
self.device = None

        # Register terminate() to run at interpreter exit so workers do not outlive the parent
atexit.register(self.terminate)

def raw_iter(self):
"""Iterate over samples.
@@ -356,7 +399,6 @@ def raw_iter(self):
import multiprocessing as mp
while True:
if self.jobs is None:
#print("starting jobs")
self.queue = mp.Queue(self.queue_size)
self.jobs = [mp.Process(target=make_loader, args=(self.args, self.kw, self.queue, i))
for i in range(self.processes)]
@@ -365,13 +407,29 @@
try:
while True:
sample = self.queue.get()

                    if not self.kw['use_shared_mem']:
                        if isinstance(sample, str):
                            # Control message, not data; both sentinels
                            # ("Converted;<device>" / "Not Converted;None")
                            # contain the word "Converted".
                            assert "Converted" in sample
                            if self.converted is None:
                                converted, device = sample.split(";")
                                self.converted = (converted == "Converted")
                                if self.converted:
                                    self.device = torch_device(device)
                                    print(f"Converting back to tensors (Device: {self.device})!")
                            # Every worker sends its own sentinel; consume them all
                            # rather than yielding them as data.
                            continue
                        else:
                            # Convert numpy arrays back into torch tensors on the
                            # original device.
                            assert self.converted is not None
                            if self.converted:
                                assert self.device is not None
                                sample[0] = from_numpy(sample[0]).to(self.device)
                                sample[1] = from_numpy(sample[1]).to(self.device)

yield sample
            except FileNotFoundError as exn:
                print("Restarting MultiWebLoader jobs")
                self.terminate()

def __iter__(self):
result = self.raw_iter()
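For reference, a minimal self-contained sketch of the handshake these loader.py changes implement: the worker announces the original device once with a "Converted;<device>" string, then streams numpy arrays, and the consumer swallows the sentinel and restores tensors. The names worker_iter and parent_iter are illustrative only, and two-field batches are assumed.

    import torch

    def worker_iter(batches, use_shared_mem=False):
        converted = None
        for sample in batches:
            if not use_shared_mem and isinstance(sample[0], torch.Tensor):
                device = sample[0].device
                sample = [sample[0].cpu().numpy(), sample[1].cpu().numpy()]
                if converted is None:
                    converted = True
                    yield f"Converted;{device}"  # announced exactly once
            yield sample

    def parent_iter(queued):
        device = None
        for sample in queued:
            if isinstance(sample, str):  # control message, never data
                kind, dev = sample.split(";")
                if kind == "Converted":
                    device = torch.device(dev)
                continue
            if device is not None:
                sample = [torch.from_numpy(x).to(device) for x in sample]
            yield sample

    batch = [torch.zeros(2, 3), torch.zeros(2)]
    for out in parent_iter(worker_iter([batch])):
        print(type(out[0]))  # torch.Tensor, restored from a numpy array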
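The df parsing in __init__ is one way to size /dev/shm; a sketch of the same check using only the standard library (assuming a Linux host with /dev/shm mounted) would be:

    import shutil

    def shm_free_mb(path="/dev/shm"):
        # disk_usage returns a (total, used, free) named tuple, in bytes.
        return shutil.disk_usage(path).free // (1024 * 1024)

    # Disable shared memory when less than 1 GB is free, as above.
    use_shared_mem = shm_free_mb() >= 1024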
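The atexit registration makes terminate() run at interpreter exit rather than on object deletion; a small sketch of the pattern with a hypothetical WorkerPool, not part of this codebase:

    import atexit
    import multiprocessing as mp
    import time

    def _work(i):
        time.sleep(60)  # stand-in for a long-running loader process

    class WorkerPool:
        def __init__(self, n=2):
            self.jobs = [mp.Process(target=_work, args=(i,)) for i in range(n)]
            for job in self.jobs:
                job.start()
            # Fires at interpreter shutdown, so workers cannot be orphaned.
            atexit.register(self.terminate)

        def terminate(self):
            for job in self.jobs:
                if job.is_alive():
                    job.terminate()
                    job.join()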
4 changes: 2 additions & 2 deletions webloader/utils.py
@@ -248,14 +248,14 @@ def pildumps(image, format="PNG"):


def autodecode1(data, tname, imagetype="rgb"):
    # Unicode change: if it is already a unicode string, no decoding (byte -> unicode) is required
if isinstance(data, (int, float, unicode)):
return data
if sys.version_info[0] == 2:
        # Then, it has to be a byte string, which in Python 2 is also of type str
assert isinstance(data, (str, buffer)), type(data)
else:
        # In Python 3, it has to be a bytes string at this point; a plain (unicode) string was already handled above
assert isinstance(data, bytes), type(data)
assert isinstance(tname, str), tname
extension = re.sub(r".*\.", "", tname).lower()
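For the Python 3 branch, the type check above amounts to the following dispatch; ensure_decoded is a hypothetical helper, not part of utils.py:

    def ensure_decoded(data, encoding="utf-8"):
        # Numbers and already-decoded strings pass through unchanged.
        if isinstance(data, (int, float, str)):
            return data
        # Anything else must be raw bytes that still need decoding.
        assert isinstance(data, bytes), type(data)
        return data.decode(encoding)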