[SPARK-5138][SQL] Ensure schema can be inferred from a namedtuple
When attempting to infer the schema of an RDD that contains namedtuples, PySpark fails to identify the records as namedtuples and raises an error.

Example:

```python
from pyspark import SparkContext
from pyspark.sql import SQLContext
from collections import namedtuple
import os

sc = SparkContext()
rdd = sc.textFile(os.path.join(os.getenv('SPARK_HOME'), 'README.md'))
TextLine = namedtuple('TextLine', 'line length')
tuple_rdd = rdd.map(lambda l: TextLine(line=l, length=len(l)))
tuple_rdd.take(5)  # This works

sqlc = SQLContext(sc)

# The following line raises an error
schema_rdd = sqlc.inferSchema(tuple_rdd)
```

The error raised is:
```
  File "/opt/spark-1.2.0-bin-hadoop2.4/python/pyspark/worker.py", line 107, in main
    process()
  File "/opt/spark-1.2.0-bin-hadoop2.4/python/pyspark/worker.py", line 98, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/opt/spark-1.2.0-bin-hadoop2.4/python/pyspark/serializers.py", line 227, in dump_stream
    vs = list(itertools.islice(iterator, batch))
  File "/opt/spark-1.2.0-bin-hadoop2.4/python/pyspark/rdd.py", line 1107, in takeUpToNumLeft
    yield next(iterator)
  File "/opt/spark-1.2.0-bin-hadoop2.4/python/pyspark/sql.py", line 816, in convert_struct
    raise ValueError("unexpected tuple: %s" % obj)
TypeError: not all arguments converted during string formatting
```
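The traceback ends in a `TypeError` rather than the intended `ValueError` because `%`-formatting treats a tuple on its right-hand side as multiple format arguments. A minimal, Spark-free sketch of that behaviour (the `obj` value here is only illustrative):

```python
# Plain-Python illustration: with a tuple on the right of %, each element
# becomes a separate format argument, so a single %s cannot consume a
# two-element tuple and the error message itself fails to format.
obj = ("Apache Spark", 12)  # stands in for a namedtuple record (illustrative only)

try:
    raise ValueError("unexpected tuple: %s" % obj)
except TypeError as exc:
    print(exc)  # not all arguments converted during string formatting

# Wrapping the tuple in str(), as the patch does, formats it as a single value:
print("unexpected tuple: %s" % str(obj))  # unexpected tuple: ('Apache Spark', 12)
```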

Author: Gabe Mulley <gabe@edx.org>

Closes #3978 from mulby/inferschema-namedtuple and squashes the following commits:

98c61cc [Gabe Mulley] Ensure exception message is populated correctly
375d96b [Gabe Mulley] Ensure schema can be inferred from a namedtuple
mulby authored and marmbrus committed Jan 13, 2015
1 parent 5d9fa55 commit 1e42e96
Showing 1 changed file with 14 additions and 4 deletions.
python/pyspark/sql.py (14 additions, 4 deletions)

```diff
@@ -807,14 +807,14 @@ def convert_struct(obj):
             return
 
         if isinstance(obj, tuple):
-            if hasattr(obj, "fields"):
-                d = dict(zip(obj.fields, obj))
-            if hasattr(obj, "__FIELDS__"):
+            if hasattr(obj, "_fields"):
+                d = dict(zip(obj._fields, obj))
+            elif hasattr(obj, "__FIELDS__"):
                 d = dict(zip(obj.__FIELDS__, obj))
             elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
                 d = dict(obj)
             else:
-                raise ValueError("unexpected tuple: %s" % obj)
+                raise ValueError("unexpected tuple: %s" % str(obj))
 
         elif isinstance(obj, dict):
             d = obj
```
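For context (not part of the commit): `collections.namedtuple` classes expose their field names through the documented `_fields` attribute, which is what the new check relies on; plain tuples have no such attribute, and the old `hasattr(obj, "fields")` test never matched a namedtuple. A quick standalone check:

```python
from collections import namedtuple

TextLine = namedtuple('TextLine', 'line length')
record = TextLine(line='hello', length=5)

print(hasattr(record, '_fields'))         # True  -> namedtuples are now detected
print(hasattr(record, 'fields'))          # False -> the old check never fired
print(dict(zip(record._fields, record)))  # {'line': 'hello', 'length': 5}

print(hasattr(('hello', 5), '_fields'))   # False -> plain tuples fall through
```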
```diff
@@ -1327,6 +1327,16 @@ def inferSchema(self, rdd, samplingRatio=None):
         >>> srdd = sqlCtx.inferSchema(nestedRdd2)
         >>> srdd.collect()
         [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])]
+        >>> from collections import namedtuple
+        >>> CustomRow = namedtuple('CustomRow', 'field1 field2')
+        >>> rdd = sc.parallelize(
+        ...     [CustomRow(field1=1, field2="row1"),
+        ...      CustomRow(field1=2, field2="row2"),
+        ...      CustomRow(field1=3, field2="row3")])
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.collect()[0]
+        Row(field1=1, field2=u'row1')
         """
 
         if isinstance(rdd, SchemaRDD):
```
