Skip to content

Commit

Permalink
Fixes 144 (#146)
Browse files Browse the repository at this point in the history
* Fixes 144

* Supports read/write, warn on pointLog

* Adds tests and fixes bugs that the tests identified

* Updates for comments

* updates tests
  • Loading branch information
jlaura committed Apr 20, 2020
1 parent 4187900 commit 0127ed8
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 96 deletions.
83 changes: 81 additions & 2 deletions plio/io/io_controlnetwork.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from enum import IntEnum
from time import gmtime, strftime
import warnings

import pandas as pd
import numpy as np
Expand All @@ -13,6 +15,7 @@
HEADERSTARTBYTE = 65536
DEFAULTUSERNAME = 'None'


def write_filelist(lst, path="fromlist.lis"):
"""
Writes a filelist to a file so it can be used in ISIS3.
Expand All @@ -29,6 +32,73 @@ def write_filelist(lst, path="fromlist.lis"):
handle.write('\n')
return


class MeasureMessageType(IntEnum):
"""
An enum to mirror the ISIS3 MeasureLogData enum.
"""
GoodnessOfFit = 2
MinimumPixelZScore = 3
MaximumPixelZScore = 4
PixelShift = 5
WholePixelCorrelation = 6
SubPixelCorrelation = 7

class MeasureLog():

def __init__(self, messagetype, value):
"""
A protobuf compliant measure log object.
Parameters
----------
messagetype : int or str
Either the integer or string representation from the MeasureMessageType enum
value : int or float
The value to be stored in the message log
"""
if isinstance(messagetype, int):
# by value
self.messagetype = MeasureMessageType(messagetype)
else:
# by name
self.messagetype = MeasureMessageType[messagetype]

if not isinstance(value, (float, int)):
raise TypeError(f'{value} is not a numeric type')
self.value = value

def __repr__(self):
return f'{self.messagetype.name}: {self.value}'

def to_protobuf(self, version=2):
"""
Return protobuf compliant measure log object representation
of this class.
Returns
-------
log_message : obj
MeasureLogData object suitable to append to a MeasureLog
repeated field.
"""
# I do not see a better way to get to the inner MeasureLogData obj than this
# imports were not working because it looks like these need to instantiate off
# an object
if version == 2:
log_message = cnf.ControlPointFileEntryV0002().Measure().MeasureLogData()
elif version == 5:
log_message = cnp5.ControlPointFileEntryV0005().Measure().MeasureLogData()
log_message.doubleDataValue = self.value
log_message.doubleDataType = self.messagetype
return log_message

@classmethod
def from_protobuf(cls, protobuf):
return cls(protobuf.doubleDataType, protobuf.doubleDataValue)


class IsisControlNetwork(pd.DataFrame):

# normal properties
Expand Down Expand Up @@ -171,7 +241,6 @@ def read(self):
for s in pbuf_header.pointMessageSizes:
cp.ParseFromString(self._handle.read(s))
pt = [getattr(cp, i) for i in self.point_attrs if i != 'measures']

for measure in cp.measures:
meas = pt + [getattr(measure, j) for j in self.measure_attrs]
pts.append(meas)
Expand Down Expand Up @@ -211,6 +280,10 @@ def read(self):
if 'aprioriline' in df.columns:
df['aprioriline'] -= 0.5
df['apriorisample'] -= 0.5

# Munge the MeasureLogData into Python objs
df['measureLog'] = df['measureLog'].apply(lambda x: [MeasureLog.from_protobuf(i) for i in x])

df.header = pvl_header
return df

Expand Down Expand Up @@ -266,6 +339,10 @@ def _set_pid(pointid):
# Un-mangle common attribute names between points and measures
df_attr = self.point_field_map.get(attr, attr)
if df_attr in g.columns:
if df_attr == 'pointLog':
# Currently pointLog is not supported.
warnings.warn('The pointLog field is currently unsupported. Any pointLog data will not be saved.')
continue
# As per protobuf docs for assigning to a repeated field.
if df_attr == 'aprioriCovar' or df_attr == 'adjustedCovar':
arr = g.iloc[0][df_attr]
Expand All @@ -290,8 +367,10 @@ def _set_pid(pointid):
# Un-mangle common attribute names between points and measures
df_attr = self.measure_field_map.get(attr, attr)
if df_attr in g.columns:
if df_attr == 'measureLog':
[getattr(measure_spec, attr).extend([i.to_protobuf()]) for i in m[df_attr]]
# If field is repeated you must extend instead of assign
if cnf._CONTROLPOINTFILEENTRYV0002_MEASURE.fields_by_name[attr].label == 3:
elif cnf._CONTROLPOINTFILEENTRYV0002_MEASURE.fields_by_name[attr].label == 3:
getattr(measure_spec, attr).extend(m[df_attr])
else:
setattr(measure_spec, attr, attrtype(m[df_attr]))
Expand Down
200 changes: 106 additions & 94 deletions plio/io/tests/test_io_controlnetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,97 +31,109 @@ def test_cnet_read(cnet_file):
assert proto_field not in df.columns
assert mangled_field in df.columns

class TestWriteIsisControlNetwork(unittest.TestCase):

@classmethod
def setUpClass(cls):
cls.npts = 5
serial_times = {295: '1971-07-31T01:24:11.754',
296: '1971-07-31T01:24:36.970'}
cls.serials = {i:'APOLLO15/METRIC/{}'.format(j) for i, j in enumerate(serial_times.values())}
columns = ['id', 'pointType', 'serialnumber', 'measureType', 'sample', 'line', 'image_index', 'pointLog', 'measureLog']

data = []
for i in range(cls.npts):
data.append((i, 2, cls.serials[0], 2, 0, 0, 0, [], []))
data.append((i, 2, cls.serials[1], 2, 0, 0, 1, [], []))

df = pd.DataFrame(data, columns=columns)

cls.creation_date = strftime("%Y-%m-%d %H:%M:%S", gmtime())
cls.modified_date = strftime("%Y-%m-%d %H:%M:%S", gmtime())
io_controlnetwork.to_isis(df, 'test.net', mode='wb', targetname='Moon')

cls.header_message_size = 78
cls.point_start_byte = 65614 # 66949

def test_create_buffer_header(self):
npts = 5
serial_times = {295: '1971-07-31T01:24:11.754',
296: '1971-07-31T01:24:36.970'}
serials = {i:'APOLLO15/METRIC/{}'.format(j) for i, j in enumerate(serial_times.values())}
columns = ['id', 'pointType', 'serialnumber', 'measureType', 'sample', 'line', 'image_index']

data = []
for i in range(self.npts):
data.append((i, 2, serials[0], 2, 0, 0, 0))
data.append((i, 2, serials[1], 2, 0, 0, 1))

df = pd.DataFrame(data, columns=columns)

self.creation_date = strftime("%Y-%m-%d %H:%M:%S", gmtime())
self.modified_date = strftime("%Y-%m-%d %H:%M:%S", gmtime())
io_controlnetwork.to_isis(df, 'test.net', mode='wb', targetname='Moon')

self.header_message_size = 78
self.point_start_byte = 65614 # 66949

with open('test.net', 'rb') as f:
f.seek(io_controlnetwork.HEADERSTARTBYTE)
raw_header_message = f.read(self.header_message_size)
header_protocol = cnf.ControlNetFileHeaderV0002()
header_protocol.ParseFromString(raw_header_message)
#Non-repeating
#self.assertEqual('None', header_protocol.networkId)
self.assertEqual('Moon', header_protocol.targetName)
self.assertEqual(io_controlnetwork.DEFAULTUSERNAME,
header_protocol.userName)
self.assertEqual(self.creation_date,
header_protocol.created)
self.assertEqual('None', header_protocol.description)
self.assertEqual(self.modified_date, header_protocol.lastModified)
#Repeating
self.assertEqual([135] * self.npts, header_protocol.pointMessageSizes)

def test_create_point(self):

with open('test.net', 'rb') as f:
f.seek(self.point_start_byte)
for i, length in enumerate([135] * self.npts):
point_protocol = cnf.ControlPointFileEntryV0002()
raw_point = f.read(length)
point_protocol.ParseFromString(raw_point)
self.assertEqual(str(i), point_protocol.id)
self.assertEqual(2, point_protocol.type)
for m in point_protocol.measures:
self.assertTrue(m.serialnumber in self.serials.values())
self.assertEqual(2, m.type)

def test_create_pvl_header(self):
pvl_header = pvl.load('test.net')

npoints = find_in_dict(pvl_header, 'NumberOfPoints')
self.assertEqual(5, npoints)

mpoints = find_in_dict(pvl_header, 'NumberOfMeasures')
self.assertEqual(10, mpoints)

points_bytes = find_in_dict(pvl_header, 'PointsBytes')
self.assertEqual(675, points_bytes)

points_start_byte = find_in_dict(pvl_header, 'PointsStartByte')
self.assertEqual(self.point_start_byte, points_start_byte)

@classmethod
def tearDownClass(cls):
os.remove('test.net')
@pytest.mark.parametrize('messagetype, value', [
(2, 0.5),
(3, 0.5),
(4, -0.25),
(5, 1e6),
(6, 1),
(7, -1e10),
('GoodnessOfFit', 0.5),
('MinimumPixelZScore', 0.25)
])
def test_MeasureLog(messagetype, value):
l = io_controlnetwork.MeasureLog(messagetype, value)
if isinstance(messagetype, int):
assert l.messagetype == io_controlnetwork.MeasureMessageType(messagetype)
elif isinstance(messagetype, str):
assert l.messagetype == io_controlnetwork.MeasureMessageType[messagetype]

assert l.value == value
assert isinstance(l.to_protobuf, object)

def test_log_error():
with pytest.raises(TypeError) as err:
io_controlnetwork.MeasureLog(2, 'foo')

def test_to_protobuf():
value = 1.25
int_dtype = 2
log = io_controlnetwork.MeasureLog(int_dtype, value)
proto = log.to_protobuf()
assert proto.doubleDataType == int_dtype
assert proto.doubleDataValue == value

@pytest.fixture
def cnet_dataframe(tmpdir):
npts = 5
serial_times = {295: '1971-07-31T01:24:11.754',
296: '1971-07-31T01:24:36.970'}
serials = {i:'APOLLO15/METRIC/{}'.format(j) for i, j in enumerate(serial_times.values())}
columns = ['id', 'pointType', 'serialnumber', 'measureType', 'sample', 'line', 'image_index', 'pointLog', 'measureLog']

data = []
for i in range(npts):
data.append((i, 2, serials[0], 2, 0, 0, 0, [], []))
data.append((i, 2, serials[1], 2, 0, 0, 1, [], [io_controlnetwork.MeasureLog(2, 0.5)]))

df = pd.DataFrame(data, columns=columns)

df.creation_date = strftime("%Y-%m-%d %H:%M:%S", gmtime())
df.modified_date = strftime("%Y-%m-%d %H:%M:%S", gmtime())
io_controlnetwork.to_isis(df, tmpdir.join('test.net'), mode='wb', targetname='Moon')

df.header_message_size = 78
df.point_start_byte = 65614 # 66949
df.npts = npts
df.measure_size = 149 # Size of each measure in bytes
df.serials = serials
return df

def test_create_buffer_header(cnet_dataframe, tmpdir):
with open(tmpdir.join('test.net'), 'rb') as f:

f.seek(io_controlnetwork.HEADERSTARTBYTE)
raw_header_message = f.read(cnet_dataframe.header_message_size)
header_protocol = cnf.ControlNetFileHeaderV0002()
header_protocol.ParseFromString(raw_header_message)
#Non-repeating
#self.assertEqual('None', header_protocol.networkId)
assert 'Moon' == header_protocol.targetName
assert io_controlnetwork.DEFAULTUSERNAME == header_protocol.userName
assert cnet_dataframe.creation_date == header_protocol.created
assert 'None' == header_protocol.description
assert cnet_dataframe.modified_date == header_protocol.lastModified
#Repeating
assert [cnet_dataframe.measure_size] * cnet_dataframe.npts == header_protocol.pointMessageSizes

def test_create_point(cnet_dataframe, tmpdir):
with open(tmpdir.join('test.net'), 'rb') as f:
f.seek(cnet_dataframe.point_start_byte)
for i, length in enumerate([cnet_dataframe.measure_size] * cnet_dataframe.npts):
point_protocol = cnf.ControlPointFileEntryV0002()
raw_point = f.read(length)
point_protocol.ParseFromString(raw_point)
assert str(i) == point_protocol.id
assert 2 == point_protocol.type
print(len(point_protocol.measures))
for i, m in enumerate(point_protocol.measures):
assert m.serialnumber in cnet_dataframe.serials.values()
assert 2 == m.type
assert len(m.log) == i

def test_create_pvl_header(cnet_dataframe, tmpdir):
with open(tmpdir.join('test.net'), 'rb') as f:
pvl_header = pvl.load(f)

npoints = find_in_dict(pvl_header, 'NumberOfPoints')
assert 5 == npoints

mpoints = find_in_dict(pvl_header, 'NumberOfMeasures')
assert 10 == mpoints

points_bytes = find_in_dict(pvl_header, 'PointsBytes')
assert 745 == points_bytes

points_start_byte = find_in_dict(pvl_header, 'PointsStartByte')
assert cnet_dataframe.point_start_byte == points_start_byte

0 comments on commit 0127ed8

Please sign in to comment.