Skip to content

Commit

Permalink
Add error message on invalid column
Browse files Browse the repository at this point in the history
Add an error message that is printed when the user supplies an invalid
column to prevent unwanted calculations and user-visible exceptions.

Also add tests to ensure that the error message is printed to stderr.

Closes #1
  • Loading branch information
ExcaliburZero committed May 26, 2017
1 parent 1d9b976 commit 6f2b0d3
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 3 deletions.
8 changes: 6 additions & 2 deletions blendplot/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
A cli application for plotting 3D data in obj format for use in Blender.
"""
import cli.app
import sys
import time

from . import obj_graph
Expand Down Expand Up @@ -70,8 +71,11 @@ def main(input_filename, output_filename, num_rows, columns, spacing, point_size
end = time.time()
output_file.close()

print("Wrote plot file to %s" % output_filename)
print("Plotted %s points in %f seconds" % (points, end - start))
if points is None:
sys.exit(1)
else:
print("Wrote plot file to %s" % output_filename)
print("Plotted %s points in %f seconds" % (points, end - start))

def run():
    """
    Entry point that starts the application by calling ``blendplot.run()``.

    NOTE(review): ``blendplot`` is defined outside the visible part of this
    module; presumably it is the cli.app application object - confirm.
    """
    blendplot.run()
Expand Down
43 changes: 42 additions & 1 deletion blendplot/obj_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
from sklearn import preprocessing
import pandas as pd
import sys

def add_cube_verticies(cube_str, x, y, z, point_size):
"""
Expand Down Expand Up @@ -170,9 +171,20 @@ def plot_file(input_filename, output_file, num_rows, columns, spacing, point_siz
Returns
-------
points : int
the number of points that were plotted
the number of points that were plotted, or None if the plot is unable
to be made
"""
original_data = pd.read_csv(input_filename, nrows = num_rows)

missing = get_missing_columns(original_data, columns, category_column)
if len(missing) > 0:
missing_columns = ", ".join(missing)
valid_columns = ", ".join(list(original_data.columns))
error_msg = "Invalid column(s): %s\n" % missing_columns
error_msg += "Valid columns are: %s" % valid_columns
print(error_msg, file=sys.stderr)
return None

data = pd.DataFrame(original_data, columns = columns).dropna()
data = pd.DataFrame(preprocessing.scale(data), columns = data.columns)

Expand All @@ -196,3 +208,32 @@ def plot_file(input_filename, output_file, num_rows, columns, spacing, point_siz

points = num_rows if num_rows is not None else len(data.index)
return points

def get_missing_columns(data, columns, category_column):
    """
    Returns all of the given columns that are not in the given dataframe.

    Parameters
    ----------
    data : object with a ``columns`` attribute (e.g. a pandas DataFrame)
        the data whose column names are checked against
    columns : Iterable[str]
        the columns to look for
    category_column : str
        the category column to look for, or None if there is no category column
        being used

    Returns
    -------
    missing : List[str]
        a list of the missing columns, in the order they were requested, with
        the category column (when it is missing) listed last
    """
    # A set gives constant-time membership checks for each requested column.
    data_columns = set(data.columns)

    # Copy into a fresh list so that tuples and other iterables are accepted
    # (``tuple + list`` would raise TypeError).
    wanted = list(columns)
    if category_column is not None:
        wanted.append(category_column)

    return [col for col in wanted if col not in data_columns]
70 changes: 70 additions & 0 deletions test/test_obj_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from blendplot.obj_graph import *

import utilities

class TestObjGraph(unittest.TestCase):
def test_add_cube_verticies(self):
cube_str = ""
Expand Down Expand Up @@ -120,6 +122,52 @@ def test_plot_file(self):

self.assertEquals(actual_output, expected_output)

def test_plot_file_invalid_column(self):
    """
    Checks that plot_file reports a single unknown plot column on stderr,
    returns None, and writes nothing to the output file.
    """
    input_filename = "test/resources/data_01.csv"
    output_file = io.StringIO()
    num_rows = None
    invalid_column = "u"
    columns = [invalid_column, "b", "c"]
    spacing = 0.5
    point_size = 0.1
    category_column = None

    # utilities.capture_stderr expects a callable taking one argument.
    def func(_):
        return plot_file(input_filename, output_file, num_rows, columns, spacing, point_size, category_column)

    (actual_return, actual_error) = utilities.capture_stderr(func)
    expected_error = "Invalid column(s): %s\nValid columns are: a, b, c, d, category\n" % invalid_column

    # assertEqual replaces the deprecated assertEquals alias, which was
    # removed in Python 3.12.
    self.assertEqual(actual_error, expected_error)
    self.assertIsNone(actual_return)

    actual_output = output_file.getvalue()
    expected_output = ""

    self.assertEqual(actual_output, expected_output)

def test_plot_file_invalid_column_multiple(self):
    """
    Checks that plot_file lists every unknown plot column on stderr, in the
    order they were requested, returns None, and writes nothing to the
    output file.
    """
    input_filename = "test/resources/data_01.csv"
    output_file = io.StringIO()
    num_rows = None
    invalid_columns = ["u", "A"]
    columns = ["c"] + invalid_columns
    spacing = 0.5
    point_size = 0.1
    category_column = None

    # utilities.capture_stderr expects a callable taking one argument.
    def func(_):
        return plot_file(input_filename, output_file, num_rows, columns, spacing, point_size, category_column)

    (actual_return, actual_error) = utilities.capture_stderr(func)
    expected_error = "Invalid column(s): %s, %s\nValid columns are: a, b, c, d, category\n" % (invalid_columns[0], invalid_columns[1])

    # assertEqual replaces the deprecated assertEquals alias, which was
    # removed in Python 3.12.
    self.assertEqual(actual_error, expected_error)
    self.assertIsNone(actual_return)

    actual_output = output_file.getvalue()
    expected_output = ""

    self.assertEqual(actual_output, expected_output)

def test_plot_file_category(self):
input_filename = "test/resources/data_01.csv"
output_file = io.StringIO()
Expand All @@ -142,6 +190,28 @@ def test_plot_file_category(self):

self.assertEquals(actual_output, expected_output)

def test_plot_file_category_invalid(self):
    """
    Checks that plot_file reports an unknown category column on stderr,
    returns None, and writes nothing to the output file.
    """
    input_filename = "test/resources/data_01.csv"
    output_file = io.StringIO()
    num_rows = 4
    columns = ["a", "b", "c"]
    spacing = 0.5
    point_size = 0.1
    invalid_category_column = "cats"

    # utilities.capture_stderr expects a callable taking one argument.
    def func(_):
        return plot_file(input_filename, output_file, num_rows, columns, spacing, point_size, invalid_category_column)

    (actual_return, actual_error) = utilities.capture_stderr(func)
    expected_error = "Invalid column(s): %s\nValid columns are: a, b, c, d, category\n" % invalid_category_column

    # assertEqual replaces the deprecated assertEquals alias, which was
    # removed in Python 3.12.
    self.assertEqual(actual_error, expected_error)
    self.assertIsNone(actual_return)

    actual_output = output_file.getvalue()
    expected_output = ""

    self.assertEqual(actual_output, expected_output)

@given(
st.text(),
st.floats(allow_nan=False, allow_infinity=False),
Expand Down
14 changes: 14 additions & 0 deletions test/utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import io
import sys

def capture_stderr(func):
    """
    Calls ``func(None)`` while ``sys.stderr`` is temporarily replaced by an
    in-memory text buffer.

    Parameters
    ----------
    func : Callable[[None], Any]
        the function to run; it is invoked with the single argument None

    Returns
    -------
    value : Tuple[Any, str]
        the return value of ``func`` paired with the text written to stderr
        during the call
    """
    saved_stderr = sys.stderr
    sys.stderr = io.StringIO()
    result = None
    try:
        returned = func(None)
        result = (returned, sys.stderr.getvalue())
    finally:
        # Always put the real stream back, even if func raised.
        sys.stderr = saved_stderr

    return result

0 comments on commit 6f2b0d3

Please sign in to comment.