Skip to content

Commit

Permalink
DataFrame.add_observation
Browse files Browse the repository at this point in the history
  • Loading branch information
SamStudio8 committed Aug 26, 2014
1 parent 6d83f94 commit da090fb
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 4 deletions.
25 changes: 23 additions & 2 deletions frontier/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,34 @@ def __array_finalize__(self, obj):
# type(obj) is DataFrame
#
# Note that it is here, rather than in the __new__ method,
# that we set the default value for 'info', because this
# that we set any default values, because this
# method sees all creation of default objects - with the
# DataFrame.__new__ constructor, but also with
# arr.view(DataFrame).
self.info = getattr(obj, 'info', None)
self.frontier_labels = getattr(obj, 'frontier_labels', [])
self.frontier_label_index = getattr(obj, 'frontier_labels', {})
# We do not need to return anything

def add_observation(self, observation_list):
    """Return a new DataFrame with ``observation_list`` appended as the last row.

    :param observation_list: sequence holding one value per parameter; its
        length must equal the number of entries in
        ``self.frontier_label_index``.
    :returns: a new ``DataFrame`` built from the stacked data and
        ``self.frontier_labels``. This frame itself is not modified.
    :raises ValueError: if the number of values given does not match the
        number of parameters in the frame. ``ValueError`` is a subclass of
        ``Exception``, so callers catching ``Exception`` are unaffected.
    """
    if len(observation_list) != len(self.frontier_label_index):
        raise ValueError("Number of parameters in frame does not match number of parameters given.")

    # NOTE FUTURE(samstudio8)
    # Somewhat inefficient to return a new DataFrame, perhaps create
    # a wrapper around the DataFrame which can overwrite the frame
    # without removing additional attributes like labels
    return DataFrame(np.vstack((self, observation_list)), self.frontier_labels)

def exclude(self, labels):
    """Return the columns of this frame that correspond to ``labels``.

    Labels missing from ``self.frontier_label_index`` are reported with a
    warning on stdout and otherwise ignored.

    NOTE(review): despite the method's name, the final slice *keeps* the
    columns named in ``labels`` rather than dropping them -- confirm which
    behaviour callers expect before relying on it.

    :param labels: iterable of label names to look up.
    :returns: ``self[:, index_list]``, i.e. every row restricted to the
        column indices found for the known labels, in the order given.
    """
    index_list = []
    for label in labels:
        # Look the label up on this instance; the previous code referenced
        # an undefined name `obj` here and raised NameError on every call.
        if label in self.frontier_label_index:
            index_list.append(self.frontier_label_index[label])
        else:
            print("[WARN] Label %s not in DataFrame" % label)

    return self[:, index_list]

def multiply_by_label(self, multiplier, label):
if label in self.frontier_label_index:
self[:,self.frontier_label_index[label]] *= multiplier
Expand Down
2 changes: 0 additions & 2 deletions frontier/frontier.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,7 @@ def get_data_by_target(self, names, targets):
else:
total += 1

from frame import DataFrame
data_np_array = np.empty([total,len(names)])
data_np_array = DataFrame(data_np_array, names)
targ_np_array = np.empty([total])

counter = 0
Expand Down
45 changes: 45 additions & 0 deletions tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,48 @@ def test_frame_label_duplicate(self):
TEST_PARAMETERS_COPY.append(TEST_PARAMETERS_COPY[0])
data = np.empty([ARBITRARY_ROWS, len(TEST_PARAMETERS_COPY)])
self.assertRaises(Exception, DataFrame, data, TEST_PARAMETERS_COPY)

def test_add_bad_observation(self):
    """add_observation must reject observations of the wrong length."""
    frame = DataFrame(np.zeros([ARBITRARY_ROWS, len(TEST_PARAMETERS)]), TEST_PARAMETERS)

    # Case 1: no values at all
    self.assertRaises(Exception, frame.add_observation, [])

    # Case 2: one value too few (a full observation with its last value dropped)
    too_short = [i + 1 for i, _ in enumerate(TEST_PARAMETERS)]
    too_short.pop()
    self.assertRaises(Exception, frame.add_observation, too_short)

    # Case 3: one value too many (a full observation with its last value repeated)
    too_long = [i + 1 for i, _ in enumerate(TEST_PARAMETERS)]
    too_long.append(too_long[-1])
    self.assertRaises(Exception, frame.add_observation, too_long)

def test_add_observation(self):
    """A correctly sized observation is appended as a new final row.

    Builds an all-zero frame, appends the observation ``1..n`` (one value
    per parameter) and checks both the resulting shape and every cell.
    """
    test_observation = [i + 1 for i, _ in enumerate(TEST_PARAMETERS)]

    data = np.zeros([ARBITRARY_ROWS, len(TEST_PARAMETERS)])
    frame = DataFrame(data, TEST_PARAMETERS).add_observation(test_observation)

    # Ensure observation row was added and number of parameters was unchanged.
    # assertEqual is used because the assertEquals alias is deprecated and
    # no longer exists in recent Python versions.
    self.assertEqual(ARBITRARY_ROWS + 1, np.shape(frame)[0])
    self.assertEqual(len(TEST_PARAMETERS), np.shape(frame)[1])

    # Check observation was added successfully
    for i, row in enumerate(frame):
        for j, value in enumerate(row):
            if i == ARBITRARY_ROWS:
                # The new row holds the observation values 1..n
                self.assertEqual(j + 1, value)
            else:
                # Other rows should contain 0
                self.assertEqual(0, value)


# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()

0 comments on commit da090fb

Please sign in to comment.