Skip to content
Browse files

Extend zero-based column indexing to every place where parse_column_identifiers is called

  • Loading branch information...
1 parent b6f2a2a commit fcfd486649b26880ae5c0dfb3d11ec8c9f137840 @JoeGermuska committed Mar 27, 2012
View
10 csvkit/cli.py
@@ -227,7 +227,7 @@ def print_column_names(self):
output.write('%3i: %s\n' % (i, c))
-def match_column_identifier(column_names, c):
+def match_column_identifier(column_names, c, zero_based=False):
"""
Determine what column a single column id (name or index) matches in a series of column names.
Note that integer values are *always* treated as positional identifiers. If you happen to have
@@ -237,7 +237,9 @@ def match_column_identifier(column_names, c):
return column_names.index(c)
else:
try:
- c = int(c) - 1
+ c = int(c)
+ if not zero_based:
+ c -= 1
# Fail out if neither a column name nor an integer
except:
raise ColumnIdentifierError('Column identifier "%s" is neither an integer, nor a existing column\'s name.' % c)
@@ -269,7 +271,7 @@ def parse_column_identifiers(ids, column_names,zero_based=False):
c = c.strip()
try:
- columns.append(match_column_identifier(column_names, c))
+ columns.append(match_column_identifier(column_names, c, zero_based))
except ColumnIdentifierError:
if ':' in c:
a,b = c.split(':',1)
@@ -292,7 +294,7 @@ def parse_column_identifiers(ids, column_names,zero_based=False):
raise ColumnIdentifierError("Invalid range %s. Ranges must be two integers separated by a - or : character.")
for x in range(a,b):
- columns.append(match_column_identifier(column_names, x))
+ columns.append(match_column_identifier(column_names, x, zero_based))
return columns
View
4 csvkit/table.py
@@ -174,7 +174,7 @@ def row(self, i):
return row_data
@classmethod
- def from_csv(cls, f, name='from_csv_table', snifflimit=None, column_ids=None, blanks_as_nulls=True, **kwargs):
+ def from_csv(cls, f, name='from_csv_table', snifflimit=None, column_ids=None, blanks_as_nulls=True, zero_based=False, **kwargs):
"""
Creates a new Table from a file-like object containing CSV data.
@@ -200,7 +200,7 @@ def from_csv(cls, f, name='from_csv_table', snifflimit=None, column_ids=None, bl
headers = reader.next()
if column_ids:
- column_ids = parse_column_identifiers(column_ids, headers)
+ column_ids = parse_column_identifiers(column_ids, headers, zero_based)
headers = [headers[c] for c in column_ids]
else:
column_ids = range(len(headers))
View
2 csvkit/utilities/csvcut.py
@@ -31,7 +31,7 @@ def main(self):
rows = CSVKitReader(self.args.file, **self.reader_kwargs)
column_names = rows.next()
- column_ids = parse_column_identifiers(self.args.columns, column_names)
+ column_ids = parse_column_identifiers(self.args.columns, column_names, self.args.zero_based)
output = CSVKitWriter(self.output_file, **self.writer_kwargs)
output.writerow([column_names[c] for c in column_ids])
View
2 csvkit/utilities/csvgrep.py
@@ -39,7 +39,7 @@ def main(self):
rows = CSVKitReader(self.args.file, **self.reader_kwargs)
column_names = rows.next()
- column_ids = parse_column_identifiers(self.args.columns, column_names)
+ column_ids = parse_column_identifiers(self.args.columns, column_names, self.args.zero_based)
if self.args.regex:
pattern = re.compile(self.args.regex)
View
2 csvkit/utilities/csvsort.py
@@ -31,7 +31,7 @@ def main(self):
table_name = 'csvsql_table'
tab = table.Table.from_csv(self.args.file, name=table_name, snifflimit=self.args.snifflimit, **self.reader_kwargs)
- column_ids = parse_column_identifiers(self.args.columns, tab.headers())
+ column_ids = parse_column_identifiers(self.args.columns, tab.headers(), self.args.zero_based)
rows = tab.to_rows(serialize_dates=True)
rows.sort(key=lambda r: [r[c] for c in column_ids], reverse=self.args.reverse)
View
2 csvkit/utilities/csvstat.py
@@ -44,7 +44,7 @@ def add_arguments(self):
help='Only output max value length.')
def main(self):
- tab = table.Table.from_csv(self.args.file, snifflimit=self.args.snifflimit, column_ids=self.args.columns, **self.reader_kwargs)
+ tab = table.Table.from_csv(self.args.file, snifflimit=self.args.snifflimit, column_ids=self.args.columns, zero_based=self.args.zero_based, **self.reader_kwargs)
operations = [op for op in OPERATIONS if getattr(self.args, op + '_only')]
View
11 tests/test_cli.py
@@ -10,24 +10,33 @@ def setUp(self):
def test_match_column_identifier_string(self):
self.assertEqual(2, match_column_identifier(self.headers, 'i_work_here'))
+ self.assertEqual(2, match_column_identifier(self.headers, 'i_work_here', zero_based=True))
def test_match_column_identifier_numeric(self):
self.assertEqual(2, match_column_identifier(self.headers, 3))
+ self.assertEqual(3, match_column_identifier(self.headers, 3, zero_based=True))
def test_match_column_which_could_be_integer_name_is_treated_as_positional_id(self):
self.assertEqual(0, match_column_identifier(self.headers, '1'))
+ self.assertEqual(1, match_column_identifier(self.headers, '1', zero_based=True))
def test_parse_column_identifiers(self):
self.assertEqual([2, 0, 1], parse_column_identifiers(' i_work_here, 1,name ', self.headers))
+ self.assertEqual([2, 1, 1], parse_column_identifiers(' i_work_here, 1,name ', self.headers, zero_based=True))
def test_range_notation(self):
self.assertEqual([0,1,2], parse_column_identifiers('1:3', self.headers))
+ self.assertEqual([1,2,3], parse_column_identifiers('1:3', self.headers, zero_based=True))
self.assertEqual([1,2,3], parse_column_identifiers('2-4', self.headers))
+ self.assertEqual([2,3,4], parse_column_identifiers('2-4', self.headers, zero_based=True))
self.assertEqual([0,1,2,3], parse_column_identifiers('1,2:4', self.headers))
+ self.assertEqual([1,2,3,4], parse_column_identifiers('1,2:4', self.headers, zero_based=True))
self.assertEqual([4,2,5], parse_column_identifiers('more-header-values,3,stuff', self.headers))
+ self.assertEqual([4,3,5], parse_column_identifiers('more-header-values,3,stuff', self.headers,zero_based=True))
def test_range_notation_open_ended(self):
self.assertEqual([0,1,2], parse_column_identifiers(':3', self.headers))
target = range(3,len(self.headers) - 1) # protect against devs adding to self.headers
target.insert(0,0)
- self.assertEqual(target, parse_column_identifiers('1,4:', self.headers))
+ self.assertEqual(target, parse_column_identifiers('1,4:', self.headers))
+

0 comments on commit fcfd486

Please sign in to comment.
Something went wrong with that request. Please try again.