Skip to content

Commit

Permalink
Extend SparklyCatalog to work with database properties (#72)
Browse files Browse the repository at this point in the history
  • Loading branch information
drudim committed Sep 3, 2019
1 parent f30222f commit ae193e9
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 1 deletion.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
## 2.8.0
* Extend `SparklyCatalog` to work with database properties:
- `spark.catalog_ext.set_database_property`
- `spark.catalog_ext.get_database_property`
- `spark.catalog_ext.get_database_properties`

## 2.7.1
* Allow newer versions of `six` package (avoid dependency hell)

Expand Down
2 changes: 1 addition & 1 deletion sparkly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
assert SparklySession


__version__ = '2.7.1'
__version__ = '2.8.0'
116 changes: 116 additions & 0 deletions sparkly/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#
import uuid

from pyspark.sql import functions as F


class SparklyCatalog(object):
"""A set of tools to interact with HiveMetastore."""
Expand Down Expand Up @@ -203,6 +205,71 @@ def set_table_property(self, table_name, property_name, value):
table_name, property_name, value
))

def get_database_property(self, db_name, property_name, to_type=None):
    """Look up a single database property and optionally cast it.

    Args:
        db_name (str): A database name.
        property_name (str): Name of the property to look up.
        to_type (function): Callable applied to the stored value,
            e.g. `int` or `float`. Falls back to `str` when omitted.

    Returns:
        Any: The converted value, or `None` if the property is not set.
    """
    raw_value = self.get_database_properties(db_name).get(property_name)
    if raw_value is None:
        return None

    converter = to_type or str
    return converter(raw_value)

def get_database_properties(self, db_name):
    """Fetch all properties of a database as a dictionary.

    Runs `DESCRIBE DATABASE EXTENDED` and parses the value of the
    `Properties` row with `read_db_properties_format`.

    Args:
        db_name (str): A database name.

    Returns:
        dict[str,str]: Property names mapped to their values; empty if
            the description has no `Properties` row.
    """
    description = self._spark.sql('DESCRIBE DATABASE EXTENDED {}'.format(db_name))
    properties_row = (
        description
        .where(F.col('database_description_item') == 'Properties')
        .select('database_description_value')
        .first()
    )

    if not properties_row:
        return {}

    raw_properties = properties_row.database_description_value
    return {
        name: value
        for name, value in read_db_properties_format(raw_properties)
    }

def set_database_property(self, db_name, property_name, value):
    """Store a property on a database via `ALTER DATABASE ... SET DBPROPERTIES`.

    Args:
        db_name (str): A database name.
        property_name (str): Name of the property to set.
        value (Any): Property value; it is rendered as a string in the statement.

    Raises:
        ValueError: If the property name contains `,`, `(` or `)`,
            or the stringified value contains `(` or `)`.
    """
    property_name_blacklist = {',', '(', ')'}
    property_value_blacklist = {'(', ')'}

    name_is_clean = property_name_blacklist.isdisjoint(property_name)
    if not name_is_clean:
        raise ValueError(
            'Property name must not contain symbols: {}'.format(property_name_blacklist))

    value_is_clean = property_value_blacklist.isdisjoint(str(value))
    if not value_is_clean:
        raise ValueError(
            'Property value must not contain symbols: {}'.format(property_value_blacklist))

    # NOTE: name and value are interpolated as-is; quotes are not escaped here.
    statement = "ALTER DATABASE {} SET DBPROPERTIES ('{}'='{}')".format(
        db_name, property_name, value,
    )
    self._spark.sql(statement)


def get_db_name(table_name):
"""Get database name from full table name."""
Expand All @@ -217,3 +284,52 @@ def get_table_name(table_name):
"""Get table name from full table name."""
parts = table_name.split('.', 1)
return parts[-1]


def read_db_properties_format(raw_db_properties):
    """Parse the non-standard database properties string into key/value pairs.

    Note:
        Spark/Hive doesn't expose database properties as separate key/values.
        Instead they are rendered in a custom format like:
        ((key_a,value_a), (key_b,value_b))
        Neither keys nor values are escaped, so this is a best-effort parser
        that relies on balanced parentheses and on property names
        not containing a comma.

    Args:
        raw_db_properties (str): The raw properties string.

    Return:
        list[list[str]] - the list of key-value pairs.

    Raises:
        ValueError: If parentheses in the input are not balanced.
    """
    def _top_level_groups(text):
        # Collect the contents of every outermost "(...)" group in `text`.
        groups = []
        depth = 0
        group_start = 0

        for position, char in enumerate(text):
            if char == '(':
                if depth == 0:
                    # Content begins right after the opening parenthesis.
                    group_start = position + 1
                depth += 1
            elif char == ')':
                depth -= 1
                if depth < 0:
                    raise ValueError('Parentheses are not balanced')
                if depth == 0:
                    groups.append(text[group_start:position])

        if depth != 0:
            raise ValueError('Parentheses are not balanced')

        return groups

    outer_groups = _top_level_groups(raw_db_properties)
    if not outer_groups:
        return []

    # Only the first outer group holds the "(key,value)" entries.
    return [entry.split(',', 1) for entry in _top_level_groups(outer_groups[0])]
51 changes: 51 additions & 0 deletions tests/integration/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from sparkly.testing import SparklyGlobalSessionTest
from tests.integration.base import SparklyTestSession
from sparkly.catalog import read_db_properties_format


class TestSparklyCatalog(SparklyGlobalSessionTest):
Expand Down Expand Up @@ -144,3 +145,53 @@ def test_get_table_property_unknown(self):
self.assertIsNone(
self.spark.catalog_ext.get_table_property('test_db.test_table', 'unknown')
)

def test_set_database_property_with_prohibited_symbols(self):
    """Names or values containing prohibited symbols must raise ValueError."""
    invalid_cases = (
        ('broken,key', 'normal_value'),
        ('normal_key', 'broken(value)'),
    )

    for name, value in invalid_cases:
        with self.assertRaises(ValueError):
            self.spark.catalog_ext.set_database_property('test_db', name, value)

def test_get_database_property(self):
    """A stored property is read back as a string, cast on request, or None if unknown."""
    catalog = self.spark.catalog_ext
    catalog.set_database_property('test_db', 'property_a', 'just,a,string')
    catalog.set_database_property('test_db', 'property_b', '123')

    string_value = catalog.get_database_property('test_db', 'property_a')
    self.assertEqual(string_value, 'just,a,string')

    int_value = catalog.get_database_property('test_db', 'property_b', to_type=int)
    self.assertEqual(int_value, 123)

    missing_value = catalog.get_database_property('test_db', 'unknown_prop', to_type=int)
    self.assertIsNone(missing_value)

def test_get_database_properties(self):
    """All properties set on the database come back as a single dict."""
    catalog = self.spark.catalog_ext
    catalog.set_database_property('test_db', 'property_a', 'just,a,string')
    catalog.set_database_property('test_db', 'property_b', '123')

    expected = {
        'property_a': 'just,a,string',
        'property_b': '123',
    }
    self.assertEqual(catalog.get_database_properties('test_db'), expected)

def test_read_db_properties_format_for_typical_input(self):
    """Well-formed property strings are parsed into key/value pairs."""
    cases = (
        ('((a,b), (c,d))', [['a', 'b'], ['c', 'd']]),
        ('((a,b))', [['a', 'b']]),
        ('()', []),
    )

    for raw, expected in cases:
        self.assertEqual(read_db_properties_format(raw), expected)

def test_read_db_properties_format_for_broken_input(self):
    """Strings with unbalanced parentheses are rejected with ValueError."""
    malformed_inputs = (
        '((a, b), (c, d)',
        ')(a, b), (c, d)(',
        ')(',
        ')',
    )

    for raw in malformed_inputs:
        with self.assertRaises(ValueError):
            read_db_properties_format(raw)

0 comments on commit ae193e9

Please sign in to comment.