diff --git a/webloader/loader.py b/webloader/loader.py index 260dcef..18805d1 100644 --- a/webloader/loader.py +++ b/webloader/loader.py @@ -15,13 +15,14 @@ import collections from past.utils import old_div +import atexit from . import filters, gopen, paths, utils _big = 1 << 60 try: - from torch import Tensor as TorchTensor + from torch import Tensor as TorchTensor, from_numpy, device as torch_device except: class TorchTensor(object): pass @@ -170,6 +171,7 @@ def __init__(self, epochs=1, pipeline=None, verbose=False, + use_shared_mem=True, use_tracker=True, use_torch_mp=False, processes=1, @@ -189,6 +191,7 @@ def __init__(self, :param epochs: number of epochs to iterate for (Default value = 1) :param pipeline: pipeline to apply to samples before field extraction (Default value = None) :param verbose: output extra information (Default value = False) + :param use_shared_mem: whether to use shared memory or not (Default value = True) :param use_tracker: ignored (for interface compatiblity with MultiWebLoader) :param use_torch_mp: ignored (for interface compatiblity with MultiWebLoader) :param processes: ignored (for interface compatiblity with MultiWebLoader) @@ -219,6 +222,8 @@ def __init__(self, if use_tracker: self.tracker = Tracker() self.verbose = verbose + self.use_shared_mem = use_shared_mem + self.converted = None def __iter__(self): """Iterate over samples.""" @@ -262,6 +267,19 @@ def __iter__(self): if isinstance(sample, dict): raise ValueError("expect list for batch_transforms; did you specify fields= for WebLoader?") sample = transform_with(sample, self.converters) + + if not self.use_shared_mem: + if isinstance(sample[0], TorchTensor): + device = sample[0].device + sample[0] = sample[0].cpu().numpy() + sample[1] = sample[1].cpu().numpy() + if self.converted is None: + self.converted = True + yield f"Converted;{device}" + if self.converted is None: + self.converted = False + yield "Not Converted;None" + self.last_sample = sample total += max(1, 
self.batch_size) yield sample @@ -271,6 +289,7 @@ def __len__(self): """Return the length of the dataset (the size argument passed on initialization).""" return self.size + def make_loader(args, kw, queue, index): kw["use_tracker"] = False data = WebLoader(*args, **kw) @@ -331,17 +350,41 @@ def __init__(self, urls, size, processes=4, use_torch_mp=False, queue_size=10, m :param **kw: other keyword arguments are passed to WebLoader """ + + stats = os.popen('df').read().splitlines() + stats = [line.split() for line in stats] + shared_mem_available = None + assert stats[0][1] == "1K-blocks" # First line, second column + for line in stats[1:]: + if line[-1] == "/dev/shm": + shared_mem_available = int(line[1]) + break + assert shared_mem_available is not None + shared_mem_available = shared_mem_available // 1024 # Convert to MB + print(f"Shared memory available: {shared_mem_available} MB") + + # Disable shared memory usage if the amount of available memory is less than 1 GB + use_shared_mem = shared_mem_available >= 1024 + if not use_shared_mem: + print("Disabling shared memory usage due to low shared memory!") + assert "epochs" not in kw, kw kw["epochs"] = 999999999 self.size = size self.args = (urls, size) self.kw = kw + self.kw['use_shared_mem'] = use_shared_mem self.use_torch_mp = use_torch_mp self.processes = processes self.queue_size = queue_size self.multi_pipe = multi_pipes.get(multi_pipe, multi_pipe) assert self.multi_pipe is None or callable(self.multi_pipe) self.jobs = None + self.converted = None + self.device = None + + # Register a terminate call on object removal + atexit.register(self.terminate) def raw_iter(self): """Iterate over samples. 
@@ -356,7 +399,6 @@ def raw_iter(self): import multiprocessing as mp while True: if self.jobs is None: - #print("starting jobs") self.queue = mp.Queue(self.queue_size) self.jobs = [mp.Process(target=make_loader, args=(self.args, self.kw, self.queue, i)) for i in range(self.processes)] @@ -365,13 +407,29 @@ def raw_iter(self): try: while True: sample = self.queue.get() + + if not self.kw['use_shared_mem']: + if isinstance(sample, str): + assert "Converted" in sample + if self.converted is None: + converted, device = sample.split(";") + self.converted = True if converted == "Converted" else False + if self.converted: + self.device = torch_device(device) + print(f"Converting back to tensors (Device: {self.device})!") + continue + else: + # Converting back from numpy array to torch tensor + assert self.converted is not None + if self.converted: + assert self.device is not None + sample[0] = from_numpy(sample[0]).to(self.device) + sample[1] = from_numpy(sample[1]).to(self.device) + yield sample except FileNotFoundError as exn: - print("restarting MultiWebLoader jobs") - #print("got exception in mp:", exn) - #print("terminating jobs") + print("Restarting MultiWebLoader jobs") self.terminate() - #print("done terminating") def __iter__(self): result = self.raw_iter() diff --git a/webloader/utils.py b/webloader/utils.py index 09f570d..83820bf 100644 --- a/webloader/utils.py +++ b/webloader/utils.py @@ -248,14 +248,14 @@ def pildumps(image, format="PNG"): def autodecode1(data, tname, imagetype="rgb"): - # Unicode change. If it is alread an unicode string, no decoding (Byte->Unicode req) + # Unicode change. If it is already an unicode string, no decoding (Byte->Unicode req) if isinstance(data, (int, float, unicode)): return data if sys.version_info[0] == 2: # Then, it has to be byte string, which is also of type str assert isinstance(data, (str, buffer)), type(data) else: - # In Python 3, it has to be a bytes string at this point. 
You've checked if it is normal string above (unicode check) + # In Python 3, it has to be a bytes string at this point. The check above already ruled out a normal (unicode) string. assert isinstance(data, bytes), type(data) assert isinstance(tname, str), tname extension = re.sub(r".*\.", "", tname).lower()