Fix circular reference between TdmsFile and TdmsChannel (#201)
Fixes #199. This circular reference meant that when a TdmsFile went out of scope, its memory wasn't freed immediately but only after a garbage collection.
adamreeve committed May 19, 2020
1 parent c4ced51 commit 4a5c9df
Showing 4 changed files with 123 additions and 83 deletions.
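
To illustrate the problem described above, here is a minimal sketch (using hypothetical Parent/Child classes, not nptdms types) of how a reference cycle keeps CPython from reclaiming objects by reference counting alone, so memory is only released once the cyclic garbage collector runs:

import gc
import weakref

class Parent:
    def __init__(self):
        # The child keeps a reference back to its parent, forming a cycle
        self.child = Child(self)

class Child:
    def __init__(self, parent):
        self.parent = parent

gc.disable()  # simulate the window before any GC collection happens
parent = Parent()
parent_ref = weakref.ref(parent)
del parent

# The cycle keeps both objects alive even though nothing else references them
assert parent_ref() is not None

gc.collect()  # only a cycle-detecting collection frees them
assert parent_ref() is None
gc.enable()

Breaking the cycle, as this commit does by no longer storing a TdmsFile reference on each TdmsChannel, lets reference counting free the objects as soon as they go out of scope.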
5 changes: 5 additions & 0 deletions nptdms/common.py
@@ -50,6 +50,11 @@ def is_group(self):
def is_channel(self):
return self.channel is not None

def group_path(self):
""" For channel paths, returns the path of the channel's group as a string
"""
return _components_to_path(self.group, None)

@staticmethod
def from_string(path_string):
components = list(_path_components(path_string))
9 changes: 9 additions & 0 deletions nptdms/reader.py
@@ -60,6 +60,7 @@ def close(self):
def read_metadata(self):
""" Read all metadata and structure information from a TdmsFile
"""
self._ensure_open()

if self._index_file_path is not None:
reading_index_file = True
@@ -103,6 +104,7 @@ def read_raw_data(self):
:returns: A generator that yields RawDataChunk objects
"""
self._ensure_open()
if self._segments is None:
raise RuntimeError(
"Cannot read data unless metadata has first been read")
@@ -121,6 +123,7 @@ def read_raw_data_for_channel(self, channel_path, offset=0, length=None):
Fewer values will be returned if attempting to read beyond the end of the available data.
:returns: A generator that yields RawChannelDataChunk objects
"""
self._ensure_open()
if self._segments is None:
raise RuntimeError("Cannot read data unless metadata has first been read")

@@ -186,6 +189,7 @@ def read_channel_chunk_for_index(self, channel_path, index):
:returns: Tuple of raw channel data chunk and the integer offset to the beginning of the chunk
:rtype: (RawChannelDataChunk, int)
"""
self._ensure_open()
if self._segments is None:
raise RuntimeError("Cannot read data unless metadata has first been read")

@@ -359,6 +363,11 @@ def _build_index(self):
self._segment_channel_offsets = {
path: np.cumsum(segment_count) for (path, segment_count) in segment_num_values.items()}

def _ensure_open(self):
if self._file is None:
raise RuntimeError(
"Cannot read data after the underlying TDMS reader is closed")


def _number_of_segment_values(segment_object, segment):
""" Compute the number of values an object has in a segment
169 changes: 86 additions & 83 deletions nptdms/tdms.py
@@ -127,19 +127,16 @@ def __init__(self, file, raw_timestamps=False, memmap_dir=None, read_metadata_on
self._memmap_dir = memmap_dir
self._raw_timestamps = raw_timestamps
self._groups = OrderedDict()
self._properties = {}
self._properties = OrderedDict()
self._channel_data = {}
self._reader = None
self.data_read = False

reader = TdmsReader(file)
self._reader = TdmsReader(file)
try:
self._read_file(reader, read_metadata_only)
self._read_file(self._reader, read_metadata_only)
finally:
if keep_open:
self._reader = reader
else:
reader.close()
if not keep_open:
self._reader.close()

def groups(self):
"""Returns a list of the groups in this file
@@ -190,10 +187,9 @@ def data_chunks(self):
:rtype: Generator that yields :class:`DataChunk` objects
"""
reader = self._get_reader()
channel_offsets = defaultdict(int)
for chunk in reader.read_raw_data():
self._convert_data_chunk(chunk)
for chunk in self._reader.read_raw_data():
_convert_data_chunk(chunk, self._raw_timestamps)
yield DataChunk(self, chunk, channel_offsets)
for path, data in chunk.channel_data.items():
channel_offsets[path] += len(data)
@@ -232,31 +228,37 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, traceback):
self.close()

def _get_reader(self):
if self._reader is None:
raise RuntimeError(
"Cannot read data after the underlying TDMS reader is closed")
return self._reader

def _read_file(self, tdms_reader, read_metadata_only):
tdms_reader.read_metadata()

# Use object metadata to build group and channel objects
group_properties = OrderedDict()
group_channels = OrderedDict()
object_properties = {
path_string: self._convert_properties(obj.properties)
for path_string, obj in tdms_reader.object_metadata.items()}
try:
self._properties = object_properties['/']
except KeyError:
pass

for (path_string, obj) in tdms_reader.object_metadata.items():
properties = object_properties[path_string]
path = ObjectPath.from_string(path_string)
obj_properties = self._convert_properties(obj.properties)
if path.is_root:
# Root object provides properties for the whole file
self._properties = obj_properties
pass
elif path.is_group:
group_properties[path.group] = obj_properties
group_properties[path.group] = properties
else:
# Object is a channel
try:
channel_group_properties = object_properties[path.group_path()]
except KeyError:
channel_group_properties = OrderedDict()
channel = TdmsChannel(
self, path, obj_properties, obj.data_type,
obj.scaler_data_types, obj.num_values)
path, obj.data_type, obj.scaler_data_types, obj.num_values,
properties, channel_group_properties, self._properties,
tdms_reader, self._raw_timestamps, self._memmap_dir)
if path.group in group_channels:
group_channels[path.group].append(channel)
else:
@@ -306,59 +308,13 @@ def _read_data(self, tdms_reader):

self.data_read = True

def _read_channel_data_chunks(self, channel):
reader = self._get_reader()
for chunk in reader.read_raw_data_for_channel(channel.path):
self._convert_channel_data_chunk(chunk)
yield chunk

def _read_channel_data_chunk_for_index(self, channel, index):
(chunk, offset) = self._get_reader().read_channel_chunk_for_index(channel.path, index)
self._convert_channel_data_chunk(chunk)
return chunk, offset

def _read_channel_data(self, channel, offset=0, length=None):
if offset < 0:
raise ValueError("offset must be non-negative")
if length is not None and length < 0:
raise ValueError("length must be non-negative")
reader = self._get_reader()

with Timer(log, "Allocate space for channel"):
# Allocate space for data
if length is None:
num_values = len(channel) - offset
else:
num_values = min(length, len(channel) - offset)
num_values = max(0, num_values)
channel_data = get_data_receiver(channel, num_values, self._raw_timestamps, self._memmap_dir)

with Timer(log, "Read data for channel"):
# Now actually read all the data
for chunk in reader.read_raw_data_for_channel(channel.path, offset, length):
if chunk.data is not None:
channel_data.append_data(chunk.data)
if chunk.scaler_data is not None:
for scaler_id, scaler_data in chunk.scaler_data.items():
channel_data.append_scaler_data(scaler_id, scaler_data)

return channel_data

def _convert_properties(self, properties):
def convert_prop(val):
if isinstance(val, TdmsTimestamp) and not self._raw_timestamps:
# Convert timestamps to numpy datetime64 if raw timestamps are not requested
return val.as_datetime64()
return val
return {k: convert_prop(v) for (k, v) in properties.items()}

def _convert_data_chunk(self, chunk):
for channel_chunk in chunk.channel_data.values():
self._convert_channel_data_chunk(channel_chunk)

def _convert_channel_data_chunk(self, channel_chunk):
if not self._raw_timestamps and isinstance(channel_chunk.data, TimestampArray):
channel_chunk.data = channel_chunk.data.as_datetime64()
return OrderedDict((k, convert_prop(v)) for (k, v) in properties.items())

def object(self, *path):
"""(Deprecated) Get a TDMS object from the file
@@ -598,14 +554,19 @@ class TdmsChannel(object):
"""

def __init__(
self, tdms_file, path, properties, data_type=None,
scaler_data_types=None, number_values=0):
self._tdms_file = tdms_file
self, path, data_type, scaler_data_types, number_values,
properties, group_properties, file_properties,
tdms_reader, raw_timestamps, memmap_dir):
self._path = path
self.properties = properties
self._length = number_values
self.data_type = data_type
self.scaler_data_types = scaler_data_types
self._group_properties = group_properties
self._file_properties = file_properties
self._reader = tdms_reader
self._raw_timestamps = raw_timestamps
self._memmap_dir = memmap_dir

self._raw_data = None
self._cached_chunk = None
@@ -728,8 +689,8 @@ def data_chunks(self):
:rtype: Generator that yields :class:`ChannelDataChunk` objects
"""
channel_offset = 0
for raw_data_chunk in self._tdms_file._read_channel_data_chunks(self):
yield ChannelDataChunk(self._tdms_file, self, raw_data_chunk, channel_offset)
for raw_data_chunk in self._read_channel_data_chunks():
yield ChannelDataChunk(self, raw_data_chunk, channel_offset)
channel_offset += len(raw_data_chunk)

def read_data(self, offset=0, length=None, scaled=True):
Expand All @@ -745,7 +706,7 @@ def read_data(self, offset=0, length=None, scaled=True):
Set this parameter to False to return raw unscaled data.
For DAQmx data a dictionary of scaler id to raw scaler data will be returned.
"""
raw_data = self._tdms_file._read_channel_data(self, offset, length)
raw_data = self._read_channel_data(offset, length)
if raw_data is None:
dtype = self.dtype if scaled else self._raw_data_dtype()
return np.empty((0,), dtype=dtype)
@@ -891,7 +852,7 @@ def _read_at_index(self, index):
if bounds[0] <= index < bounds[1]:
return self._cached_chunk[index - bounds[0]]

chunk, chunk_offset = self._tdms_file._read_channel_data_chunk_for_index(self, index)
chunk, chunk_offset = self._read_channel_data_chunk_for_index(index)
scaled_chunk = self._scale_data(chunk)
self._cached_chunk = scaled_chunk
self._cached_chunk_bounds = (chunk_offset, chunk_offset + len(scaled_chunk))
@@ -909,10 +870,44 @@ def _scale_data(self, raw_data):

@cached_property
def _scaling(self):
group_properties = self._tdms_file[self._path.group].properties
file_properties = self._tdms_file.properties
return scaling.get_scaling(
self.properties, group_properties, file_properties)
self.properties, self._group_properties, self._file_properties)

def _read_channel_data_chunks(self):
for chunk in self._reader.read_raw_data_for_channel(self.path):
_convert_channel_data_chunk(chunk, self._raw_timestamps)
yield chunk

def _read_channel_data_chunk_for_index(self, index):
(chunk, offset) = self._reader.read_channel_chunk_for_index(self.path, index)
_convert_channel_data_chunk(chunk, self._raw_timestamps)
return chunk, offset

def _read_channel_data(self, offset=0, length=None):
if offset < 0:
raise ValueError("offset must be non-negative")
if length is not None and length < 0:
raise ValueError("length must be non-negative")

with Timer(log, "Allocate space for channel"):
# Allocate space for data
if length is None:
num_values = len(self) - offset
else:
num_values = min(length, len(self) - offset)
num_values = max(0, num_values)
channel_data = get_data_receiver(self, num_values, self._raw_timestamps, self._memmap_dir)

with Timer(log, "Read data for channel"):
# Now actually read all the data
for chunk in self._reader.read_raw_data_for_channel(self.path, offset, length):
if chunk.data is not None:
channel_data.append_data(chunk.data)
if chunk.scaler_data is not None:
for scaler_id, scaler_data in chunk.scaler_data.items():
channel_data.append_scaler_data(scaler_id, scaler_data)

return channel_data

def _set_raw_data(self, data):
self._raw_data = data
@@ -1003,7 +998,6 @@ def __init__(self, tdms_file, group, raw_data_chunk, channel_offsets):
self.name = group.name
self._channels = OrderedDict(
(channel.name, ChannelDataChunk(
tdms_file,
channel,
raw_data_chunk.channel_data.get(channel.path, RawChannelDataChunk.empty()),
channel_offsets[channel.path]))
@@ -1033,9 +1027,8 @@ class ChannelDataChunk(object):
:ivar ~.name: Name of the channel
:ivar ~.offset: Starting index of this chunk of data in the entire channel
"""
def __init__(self, tdms_file, channel, raw_data_chunk, offset):
def __init__(self, channel, raw_data_chunk, offset):
self._path = channel._path
self._tdms_file = tdms_file
self._channel = channel
self.name = channel.name
self.offset = offset
Expand Down Expand Up @@ -1102,3 +1095,13 @@ def _deprecated(name, detail=None):
if detail is not None:
message += " {0}".format(detail)
warnings.warn(message)


def _convert_data_chunk(chunk, raw_timestamps):
for channel_chunk in chunk.channel_data.values():
_convert_channel_data_chunk(channel_chunk, raw_timestamps)


def _convert_channel_data_chunk(channel_chunk, raw_timestamps):
if not raw_timestamps and isinstance(channel_chunk.data, TimestampArray):
channel_chunk.data = channel_chunk.data.as_datetime64()
23 changes: 23 additions & 0 deletions nptdms/test/test_tdms_file.py
@@ -6,6 +6,7 @@
import sys
from shutil import copyfile
import tempfile
import weakref
from hypothesis import (assume, given, example, settings, strategies)
import numpy as np
import pytest
@@ -840,3 +841,25 @@ def test_debug_logging(caplog):

assert "Reading metadata for object /'group'/'channel1' with index header 0x00000014" in caplog.text
assert "Object data type: Int32" in caplog.text


def test_memory_released_when_tdms_file_out_of_scope():
""" Tests that when a TDMS file object goes out of scope,
TDMS channels and their data are also freed.
This ensures there are no circular references between a TDMS file
and its channels, which would mean the GC is needed to free these objects.
"""

test_file, expected_data = scenarios.single_segment_with_one_channel().values
with test_file.get_tempfile() as temp_file:
tdms_data = TdmsFile.read(temp_file.file)
chan = tdms_data['group']['channel1']
chan_ref = weakref.ref(chan)
data_ref = weakref.ref(chan.data)
raw_data_ref = weakref.ref(chan.raw_data)
del tdms_data
del chan

assert raw_data_ref() is None
assert data_ref() is None
assert chan_ref() is None
