-
Notifications
You must be signed in to change notification settings - Fork 28k
/
sql_formatter.py
84 lines (73 loc) · 3.24 KB
/
sql_formatter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import string
import typing
from typing import Any, Optional, List, Tuple, Sequence, Mapping
import uuid
from py4j.java_gateway import is_instance_of
if typing.TYPE_CHECKING:
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import lit
class SQLStringFormatter(string.Formatter):
    """
    A standard ``string.Formatter`` that understands PySpark instances as well as
    basic Python objects when substituting ``{...}`` fields into a SQL query string.

    Instances are single-use per SQL query: any temporary views registered while
    formatting must be released with :meth:`clear` before the formatter is reused
    for another query.
    """

    def __init__(self, session: "SparkSession") -> None:
        self._session: "SparkSession" = session
        # (DataFrame, temp-view name) pairs registered while formatting, so the
        # same DataFrame maps to one view and all views can be dropped later.
        # "DataFrame" is quoted because it is imported only under TYPE_CHECKING;
        # the forward reference keeps the annotation safe to evaluate.
        self._temp_views: List[Tuple["DataFrame", str]] = []

    def get_field(self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> Any:
        """
        Resolve *field_name* exactly like the base ``Formatter``, then convert the
        resolved object to its SQL string form via :meth:`_convert_value`.
        """
        obj, first = super().get_field(field_name, args, kwargs)
        return self._convert_value(obj, field_name), first

    def _convert_value(self, val: Any, field_name: str) -> Optional[str]:
        """
        Converts the given value into a SQL string.

        - ``Column``: only plain column references (``UnresolvedAttribute`` or
          ``AttributeReference`` expressions) are accepted; returns the column's SQL.
          Anything else (e.g. an arithmetic expression) raises ``ValueError``.
        - ``DataFrame``: registers a uniquely-named temporary view on first use
          (reusing the existing name on repeat substitutions) and returns that name.
        - ``str``: returns an escaped SQL string literal.
        - Any other value is returned unchanged for default formatting.
        """
        from pyspark import SparkContext
        from pyspark.sql import Column, DataFrame

        if isinstance(val, Column):
            assert SparkContext._gateway is not None

            gw = SparkContext._gateway
            jexpr = val._jc.expr()
            if is_instance_of(
                gw, jexpr, "org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute"
            ) or is_instance_of(
                gw, jexpr, "org.apache.spark.sql.catalyst.expressions.AttributeReference"
            ):
                return jexpr.sql()
            else:
                raise ValueError(
                    "%s in %s should be a plain column reference such as `df.col` "
                    "or `col('column')`" % (val, field_name)
                )
        elif isinstance(val, DataFrame):
            # Reuse the temp view if this exact DataFrame was already substituted.
            for df, n in self._temp_views:
                if df is val:
                    return n
            df_name = "_pyspark_%s" % str(uuid.uuid4()).replace("-", "")
            self._temp_views.append((val, df_name))
            val.createOrReplaceTempView(df_name)
            return df_name
        elif isinstance(val, str):
            return lit(val)._jc.expr().sql()  # for escaped characters.
        else:
            return val

    def clear(self) -> None:
        """
        Drop every temporary view registered by this formatter and reset the
        bookkeeping list, making the formatter safe to reuse.
        """
        for _, n in self._temp_views:
            self._session.catalog.dropTempView(n)
        self._temp_views = []