9 changes: 9 additions & 0 deletions wsds/ws_dataset.py

@@ -110,6 +110,15 @@ def __init__(
 
         self._register_wsds_links()
 
+    def close(self):
+        """Close all cached shard file handles and linked datasets."""
+        for shard in self._open_shards.values():
+            shard.close()
+        self._open_shards.clear()
+        for ds in self._linked_datasets.values():
+            ds.close()
+        self._linked_datasets.clear()
+
     def enable_filter(self, filter_name: str, filter_df: pl.DataFrame):
         """
         Enabling a filter adds extra columns to the dataset, each column representing a filter.
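For orientation, a minimal sketch of the intended call pattern for the new close(); the import path, the shard key shape, and process() are assumptions taken from this diff, not confirmed API:

    from wsds.ws_dataset import WSDataset  # import path assumed from the file layout

    ds = WSDataset("path/to/dataset", ignore_index=True)
    try:
        for sample in ds.iter_shard(("", "shard-000")):  # key shape mirrors _build_key_iter below
            process(sample["__key__"])                   # process() is a hypothetical consumer
    finally:
        ds.close()  # releases every cached WSShard handle and any linked datasets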
16 changes: 14 additions & 2 deletions wsds/ws_shard.py

@@ -49,9 +49,10 @@ def __init__(self, dataset, fname, shard_ref=None):
 
         try:
             if dataset.disable_memory_map:
-                self.reader = pa.RecordBatchFileReader(pa.OSFile(str(fname)))
+                self._source_file = pa.OSFile(str(fname))
             else:
-                self.reader = pa.RecordBatchFileReader(pa.memory_map(str(fname)))
+                self._source_file = pa.memory_map(str(fname))
+            self.reader = pa.RecordBatchFileReader(self._source_file)
         except FileNotFoundError:
             raise WSShardMissingError(fname) from None
 
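The point of holding the handle in self._source_file is that the shard, not pyarrow's garbage collector, decides when the file is closed or unmapped. A standalone sketch of the underlying pyarrow pattern (file name illustrative):

    import pyarrow as pa

    source = pa.memory_map("shard.arrow")      # or pa.OSFile("shard.arrow") when mmap is disabled
    reader = pa.RecordBatchFileReader(source)
    table = reader.read_all()
    source.close()                             # deterministic release, no waiting on finalizers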
@@ -92,6 +93,17 @@ def get_sample(self, column: str, offset: int) -> typing.Any:
         except Exception as e:
             raise ValueError(f"Failed to decode column {column} in shard {self.fname} (offset {offset}): {e}")
 
+    def close(self):
+        """Close the underlying pyarrow file handle."""
+        self.reader = None
+        self._data = None
+        if hasattr(self, "_source_file") and self._source_file is not None:
+            try:
+                self._source_file.close()
+            except Exception:
+                pass
+        self._source_file = None
+
     def __repr__(self):
         r = f"WSShard({repr(self.fname)})"
         if self._data:
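Since close() nulls the attributes and guards with hasattr, it is safe on a partially constructed or already-closed shard; a double close is a no-op. Illustration, with dataset standing in for any WSDataset (constructor arguments per __init__ above):

    shard = WSShard(dataset, "000.arrow")
    shard.close()
    shard.close()  # no-op: _source_file is already None, so nothing is closed twice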
27 changes: 23 additions & 4 deletions wsds/ws_sink.py

@@ -128,6 +128,10 @@ def close(self):
         )
         assert self._sink is not None, "closing a WSSink that was never written to"
         self._sink.close()
+        # pyarrow RecordBatchFileWriter.close() does NOT release the fd — only GC does.
+        # Drop the reference and force collection so volume.reload() won't be blocked.
+        self._sink = None
+        import gc; gc.collect()
 
     def __enter__(self):
         assert self._sink is None, "WSSink is not re-entrant"
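A standalone sketch of the drop-and-collect pattern the comment relies on, using only public pyarrow IPC calls; whether your pyarrow build actually holds the fd past close() is the assumption this PR encodes:

    import gc
    import pyarrow as pa

    schema = pa.schema([("x", pa.int64())])
    writer = pa.ipc.new_file("out.arrow", schema)
    writer.write_table(pa.table({"x": [1, 2, 3]}))
    writer.close()

    writer = None  # drop the last reference...
    gc.collect()   # ...and collect, so the OS-level handle is really released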
@@ -136,6 +140,13 @@ def __enter__(self):
     def __exit__(self, exc_type, exc_value, traceback):
         if exc_type is None:
             self.close()
+        elif self._sink is not None:
+            try:
+                self._sink.close()
+            except Exception:
+                pass
+            self._sink = None
+            import gc; gc.collect()
 
 
 class AtomicFile:
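With the new elif branch, an exception inside the with block no longer leaks the writer: __exit__ closes the sink best-effort and drops the reference instead of skipping cleanup. A usage sketch; the constructor argument and the write() call are hypothetical:

    with WSSink("out/000.arrow") as sink:  # constructor signature assumed
        sink.write(sample)                 # hypothetical write call
        raise RuntimeError("boom")         # __exit__ still closes and drops _sink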
@@ -172,12 +183,20 @@ def _build_key_iter(fname):
     p = Path(fname)
     parent_dir = p.parent.parent
 
+    ds = None
     try:
         ds = WSDataset(parent_dir, ignore_index=True)
-        it = (s["__key__"] for s in ds.iter_shard(("", p.stem)))
-        first = next(it)
-        return itertools.chain([first], it)
-    except (OSError, WSShardMissingError, StopIteration):
+        keys = [s["__key__"] for s in ds.iter_shard(("", p.stem))]
+        ds.close()
+        if not keys:
+            return None
+        return iter(keys)
+    except (OSError, WSShardMissingError, StopIteration, KeyError):
+        if ds is not None:
+            try:
+                ds.close()
+            except Exception:
+                pass
         return None
 
 
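The _build_key_iter rewrite trades laziness for prompt handle release: materializing the key list lets ds.close() run before anything is returned, whereas the old generator kept the shard open for the iterator's entire lifetime. The same trade-off in miniature (names reused from the diff, helper name hypothetical):

    def keys_then_close(path, stem):
        ds = WSDataset(path, ignore_index=True)
        try:
            return [s["__key__"] for s in ds.iter_shard(("", stem))]  # materialize while open
        finally:
            ds.close()  # nothing lazy outlives the shard handle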