Permalink
Browse files

support column names as pattern keys in constructor of FilteringCSVRe…

…ader
  • Loading branch information...
1 parent ab853b9 commit bd017d34f37014c2c3826bedce8db3da1416593f @JoeGermuska committed Apr 30, 2012
Showing with 60 additions and 11 deletions.
  1. +31 −11 csvkit/grep.py
  2. +29 −0 tests/test_grep.py
View
42 csvkit/grep.py
@@ -1,4 +1,5 @@
#!/usr/bin/env python
+from csvkit.exceptions import ColumnIdentifierError
class FilteringCSVReader(object):
"""
@@ -16,29 +17,35 @@ class FilteringCSVReader(object):
or not the filtering reader yields a prospective row. To test for explicitly blank, use a regular
expression such as "^$" or "^\s*$"
- If patterns is a dictionary, the keys should be integers identifying indices in the input rows. (It might
- be that this would all work with a dictionary iterator and a looser set of keys, but that's not
- officially supported.) If patterns is a sequence, then it is assumed that they will be applied to the
+ If patterns is a dictionary, the keys can be integers identifying indices in the input rows, or, if 'header'
+ is True (as it is by default), they can be strings matching column names in the first row of the reader.
+
+ If patterns is a sequence, then it is assumed that they will be applied to the
equivalently positioned values in the test rows.
- By specifying 'inverse=True', only rows which do not match the patterns will be passed by the filter.
+ By specifying 'inverse=True', only rows which do not match the patterns will be passed by the filter. The header,
+ if there is one, will always be returned regardless of the value for 'inverse'.
"""
+ returned_header = False
+ column_names = None
def __init__(self, reader, patterns, header=True, any_match=False, inverse=False):
super(FilteringCSVReader, self).__init__()
self.reader = reader
- self.patterns = standardize_patterns(patterns)
self.header = header
+ if self.header:
+ self.column_names = reader.next()
self.any_match = any_match
self.inverse = inverse
+ self.patterns = standardize_patterns(self.column_names,patterns)
def __iter__(self):
return self
def next(self):
- if self.header:
- self.header = False
- return self.reader.next()
+ if self.column_names and not self.returned_header:
+ self.returned_header = True
+ return self.column_names
while True:
row = self.reader.next()
@@ -57,14 +64,27 @@ def test_row(self, row):
return not self.inverse # True
-def standardize_patterns(patterns):
+def standardize_patterns(column_names, patterns):
"""
Given patterns in any of the permitted input forms, return a dict whose keys
- are row indices and whose values are functions which return a boolean value whether the value passes.
+ are column indices and whose values are functions which return a boolean value whether the value passes.
+ If patterns is a dictionary and any of its keys are values in column_names, the returned dictionary will
+ have those keys replaced with the integer position of that value in column_names
"""
try:
# Dictionary of patterns
- return dict((k, pattern_as_function(v)) for k, v in patterns.items() if v)
+ patterns = dict((k, pattern_as_function(v)) for k, v in patterns.items() if v)
+ if not column_names: return patterns
+ p2 = {}
+ for k in patterns:
+ if k in column_names:
+ idx = column_names.index(k)
+ if idx in patterns:
+ raise ColumnIdentifierError("Column %s has index %i which already has a pattern." % (k,idx))
+ p2[idx] = patterns[k]
+ else:
+ p2[k] = patterns[k]
+ return p2
except AttributeError:
# Sequence of patterns
return dict((i, pattern_as_function(x)) for i, x in enumerate(patterns))
View
29 tests/test_grep.py
@@ -4,6 +4,7 @@
import re
from csvkit.grep import FilteringCSVReader
+from csvkit.exceptions import ColumnIdentifierError
class TestGrep(unittest.TestCase):
def setUp(self):
@@ -67,3 +68,31 @@ def test_inverse(self):
except StopIteration:
pass
+ def test_column_names_in_patterns(self):
+ fcr = FilteringCSVReader(iter(self.tab2),patterns = {'age': 'only'})
+ self.assertEqual(self.tab2[0],fcr.next())
+ self.assertEqual(self.tab2[2],fcr.next())
+ self.assertEqual(self.tab2[4],fcr.next())
+ try:
+ fcr.next()
+ self.fail("Should be no more rows left.")
+ except StopIteration:
+ pass
+
+ def test_mixed_indices_and_column_names_in_patterns(self):
+ fcr = FilteringCSVReader(iter(self.tab2),patterns = {'age': 'only', 0: '2'})
+ self.assertEqual(self.tab2[0],fcr.next())
+ self.assertEqual(self.tab2[4],fcr.next())
+ try:
+ fcr.next()
+ self.fail("Should be no more rows left.")
+ except StopIteration:
+ pass
+
+ def test_duplicate_column_ids_in_patterns(self):
+ try:
+ fcr = FilteringCSVReader(iter(self.tab2),patterns = {'age': 'only', 1: 'second'})
+ self.fail("Should be an exception.")
+ except ColumnIdentifierError:
+ pass
+

0 comments on commit bd017d3

Please sign in to comment.