diff --git a/wsds/ws_dataset.py b/wsds/ws_dataset.py
index 35dd5cd..580fb57 100644
--- a/wsds/ws_dataset.py
+++ b/wsds/ws_dataset.py
@@ -110,6 +110,18 @@ def __init__(
 
         self._register_wsds_links()
 
+    def close(self):
+        """Close all cached shard file handles and linked datasets.
+
+        Both caches are emptied afterwards, so a repeated call finds nothing left to close.
+        """
+        for shard in self._open_shards.values():
+            shard.close()
+        self._open_shards.clear()
+        for ds in self._linked_datasets.values():
+            ds.close()
+        self._linked_datasets.clear()
+
     def enable_filter(self, filter_name: str, filter_df: pl.DataFrame):
         """
         Enabling a filter adds extra columns to the dataset, each column representing a filter.
diff --git a/wsds/ws_shard.py b/wsds/ws_shard.py
index b0cf502..39639cc 100644
--- a/wsds/ws_shard.py
+++ b/wsds/ws_shard.py
@@ -49,9 +49,10 @@ def __init__(self, dataset, fname, shard_ref=None):
 
         try:
             if dataset.disable_memory_map:
-                self.reader = pa.RecordBatchFileReader(pa.OSFile(str(fname)))
+                self._source_file = pa.OSFile(str(fname))
             else:
-                self.reader = pa.RecordBatchFileReader(pa.memory_map(str(fname)))
+                self._source_file = pa.memory_map(str(fname))
+            self.reader = pa.RecordBatchFileReader(self._source_file)
         except FileNotFoundError:
             raise WSShardMissingError(fname) from None
 
@@ -92,6 +93,23 @@ def get_sample(self, column: str, offset: int) -> typing.Any:
         except Exception as e:
             raise ValueError(f"Failed to decode column {column} in shard {self.fname} (offset {offset}): {e}")
 
+    def close(self):
+        """Close the underlying pyarrow file handle.
+
+        Also drops the reader and ``_data`` references. A missing or
+        already-cleared ``_source_file`` is tolerated, so repeated calls are harmless.
+        """
+        self.reader = None
+        self._data = None
+        source_file = getattr(self, "_source_file", None)
+        self._source_file = None
+        if source_file is not None:
+            try:
+                source_file.close()
+            except Exception:
+                # Best-effort cleanup: a failed close must not break the caller.
+                pass
+
     def __repr__(self):
         r = f"WSShard({repr(self.fname)})"
         if self._data:
diff --git a/wsds/ws_sink.py b/wsds/ws_sink.py
index e1697bd..ddebbf1 100644
--- a/wsds/ws_sink.py
+++ b/wsds/ws_sink.py
@@ -128,6 +128,11 @@ def close(self):
             )
         assert self._sink is not None, "closing a WSSink that was never written to"
-        self._sink.close()
+        try:
+            self._sink.close()
+        finally:
+            # pyarrow RecordBatchFileWriter.close() does NOT release the fd — only GC does.
+            # Drop the reference and force collection so volume.reload() won't be blocked.
+            self._release_sink()
 
     def __enter__(self):
         assert self._sink is None, "WSSink is not re-entrant"
@@ -136,6 +141,22 @@ def __enter__(self):
     def __exit__(self, exc_type, exc_value, traceback):
         if exc_type is None:
             self.close()
+        elif self._sink is not None:
+            # Error path: close the writer directly rather than via close().
+            # A failure here must not mask the exception already propagating.
+            try:
+                self._sink.close()
+            except Exception:
+                pass
+            finally:
+                self._release_sink()
+
+    def _release_sink(self):
+        """Drop the writer reference and run a GC pass so its fd is released."""
+        import gc
+
+        self._sink = None
+        gc.collect()
 
 
 class AtomicFile:
@@ -172,12 +193,20 @@ def _build_key_iter(fname):
 
     p = Path(fname)
     parent_dir = p.parent.parent
+    ds = None
     try:
         ds = WSDataset(parent_dir, ignore_index=True)
-        it = (s["__key__"] for s in ds.iter_shard(("", p.stem)))
-        first = next(it)
-        return itertools.chain([first], it)
-    except (OSError, WSShardMissingError, StopIteration):
+        # Materialize the keys so the dataset can be closed before returning.
+        keys = [s["__key__"] for s in ds.iter_shard(("", p.stem))]
+    except (OSError, WSShardMissingError, StopIteration, KeyError):
         return None
+    finally:
+        if ds is not None:
+            try:
+                ds.close()
+            except Exception:
+                # Best-effort cleanup: a failed close must not change the result.
+                pass
+    return iter(keys) if keys else None
 
 