[DOP-6758] Fix Hive.check() behavior when Hive Metastore is not available #164

Merged 1 commit on Oct 10, 2023
1 change: 1 addition & 0 deletions docs/changelog/next_release/164.bugfix.rst
@@ -0,0 +1 @@
+Fix ``Hive.check()`` behavior when Hive Metastore is not available.
1 change: 1 addition & 0 deletions docs/changelog/next_release/164.improvement.rst
@@ -0,0 +1 @@
+Add a check to all DB and FileDF connections that the Spark session is alive.
15 changes: 14 additions & 1 deletion onetl/connection/db_connection/db_connection/connection.py
@@ -17,7 +17,7 @@
 from logging import getLogger
 from typing import TYPE_CHECKING

-from pydantic import Field
+from pydantic import Field, validator

 from onetl._util.spark import try_import_pyspark
 from onetl.base import BaseDBConnection
@@ -48,6 +48,19 @@
         refs["SparkSession"] = SparkSession
         return refs

+    @validator("spark")
+    def _check_spark_session_alive(cls, spark):
+        # https://stackoverflow.com/a/36044685
+        msg = "Spark session is stopped. Please recreate Spark session."
+        try:
+            if not spark._jsc.sc().isStopped():
+                return spark
+        except Exception as e:
+            # e.g. spark._jsc may already be None, so the attribute access above raises
+            raise ValueError(msg) from e
+
+        raise ValueError(msg)

     def _log_parameters(self):
         log.info("|%s| Using connection parameters:", self.__class__.__name__)
         parameters = self.dict(exclude_none=True, exclude={"spark"})
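For context, a minimal sketch of what this validator catches, not part of the diff (the connection class and parameters are illustrative; any DB connection should behave the same way):

```python
from pyspark.sql import SparkSession

from onetl.connection import Postgres

spark = SparkSession.builder.master("local[1]").appName("demo").getOrCreate()
spark.stop()

# The validator rejects the stopped session at construction time,
# instead of failing later inside read/write calls with an obscure Py4J error:
try:
    Postgres(host="example-host", user="user", password="secret", database="db", spark=spark)
except ValueError as e:
    print(e)  # Spark session is stopped. Please recreate Spark session.
```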
4 changes: 2 additions & 2 deletions onetl/connection/db_connection/hive/connection.py
@@ -146,7 +146,7 @@ class Hive(DBConnection):
     # TODO: remove in v1.0.0
     slots = HiveSlots

-    _CHECK_QUERY: ClassVar[str] = "SELECT 1"
+    _CHECK_QUERY: ClassVar[str] = "SHOW DATABASES"

     @slot
     @classmethod
@@ -207,7 +207,7 @@ def check(self):
         log_lines(log, self._CHECK_QUERY, level=logging.DEBUG)

         try:
-            self._execute_sql(self._CHECK_QUERY)
+            self._execute_sql(self._CHECK_QUERY).limit(1).collect()
             log.info("|%s| Connection is available.", self.__class__.__name__)
         except Exception as e:
             log.exception("|%s| Connection is unavailable", self.__class__.__name__)
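This is the heart of the fix: ``SELECT 1`` can be answered without ever contacting the Hive Metastore, so the old check passed even when the Metastore was down, and building a query alone does not guarantee anything runs. ``SHOW DATABASES`` has to list Metastore contents, and ``.limit(1).collect()`` forces the query to actually execute. A rough sketch of the new behavior (session setup is illustrative and assumes a Hive-enabled Spark build):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.enableHiveSupport().getOrCreate()

df = spark.sql("SHOW DATABASES")  # may not contact the Metastore yet
df.limit(1).collect()             # forces execution: fails fast if the Metastore is unreachable
```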
1 change: 1 addition & 0 deletions onetl/connection/db_connection/mongodb/connection.py
@@ -507,6 +507,7 @@ def write_df_to_target(
         )

         if self._collection_exists(target):
+            # MongoDB connector does not support mode=ignore and mode=error
             if write_options.if_exists == MongoDBCollectionExistBehavior.ERROR:
                 raise ValueError("Operation stopped due to MongoDB.WriteOptions(if_exists='error')")
             elif write_options.if_exists == MongoDBCollectionExistBehavior.IGNORE:
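The new comment documents why onETL emulates these modes itself: the Spark MongoDB connector only exposes append/overwrite-style save modes, so ``error`` and ``ignore`` have to be enforced before the connector is ever called. A hedged usage sketch:

```python
from onetl.connection import MongoDB

# if_exists="error" makes onETL itself raise before any data is written
# when the target collection already exists
options = MongoDB.WriteOptions(if_exists="error")
```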
15 changes: 14 additions & 1 deletion onetl/connection/file_df_connection/spark_file_df_connection.py
@@ -19,7 +19,7 @@
 from logging import getLogger
 from typing import TYPE_CHECKING

-from pydantic import Field
+from pydantic import Field, validator

 from onetl._util.hadoop import get_hadoop_config
 from onetl._util.spark import try_import_pyspark
@@ -182,6 +182,19 @@
         refs["SparkSession"] = SparkSession
         return refs

+    @validator("spark")
+    def _check_spark_session_alive(cls, spark):
+        # https://stackoverflow.com/a/36044685
+        msg = "Spark session is stopped. Please recreate Spark session."
+        try:
+            if not spark._jsc.sc().isStopped():
+                return spark
+        except Exception as e:
+            # e.g. spark._jsc may already be None, so the attribute access above raises
+            raise ValueError(msg) from e
+
+        raise ValueError(msg)


     def _log_parameters(self):
         log.info("|%s| Using connection parameters:", self.__class__.__name__)
         parameters = self.dict(exclude_none=True, exclude={"spark"})
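The validator is duplicated verbatim for FileDF connections, so the construction-time check applies there too; a one-line sketch under the same assumptions as the earlier example (``SparkLocalFS`` is one such connection):

```python
from onetl.connection import SparkLocalFS

spark.stop()
SparkLocalFS(spark=spark)  # raises ValueError: Spark session is stopped. ...
```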
25 changes: 24 additions & 1 deletion tests/fixtures/spark_mock.py
@@ -3,6 +3,23 @@
 import pytest


+@pytest.fixture(
+    scope="function",
+    params=[pytest.param("mock-spark-stopped", marks=[pytest.mark.db_connection, pytest.mark.connection])],
+)
+def spark_stopped():
+    import pyspark
+    from pyspark.sql import SparkSession
+
+    spark = Mock(spec=SparkSession)
+    spark.sparkContext = Mock()
+    spark.sparkContext.appName = "abc"
+    spark.version = pyspark.__version__
+    spark._sc = Mock()
+    spark._sc._gateway = Mock()
+    return spark
+
+
 @pytest.fixture(
     scope="function",
     params=[pytest.param("mock-spark-no-packages", marks=[pytest.mark.db_connection, pytest.mark.connection])],
@@ -15,6 +32,9 @@ def spark_no_packages():
     spark = Mock(spec=SparkSession)
     spark.sparkContext = Mock()
     spark.sparkContext.appName = "abc"
     spark.version = pyspark.__version__
+    spark._jsc = Mock()
+    spark._jsc.sc = Mock()
+    spark._jsc.sc().isStopped = Mock(return_value=False)
     return spark


@@ -29,7 +49,10 @@ def spark_mock():
     spark = Mock(spec=SparkSession)
     spark.sparkContext = Mock()
     spark.sparkContext.appName = "abc"
-    spark.version = pyspark.__version__
+    spark._sc = Mock()
+    spark._sc._gateway = Mock()
+    spark.version = pyspark.__version__
+    spark._jsc = Mock()
+    spark._jsc.sc = Mock()
+    spark._jsc.sc().isStopped = Mock(return_value=False)
     return spark
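Why ``spark_stopped`` trips the new validator while the other fixtures pass, as far as I can tell: it deliberately leaves ``_jsc`` unconfigured, and ``Mock(spec=SparkSession)`` rejects attributes that are not part of the spec, so the validator's ``spark._jsc.sc().isStopped()`` raises and is converted to ``ValueError``. The other fixtures stub the full chain, as in this illustrative snippet:

```python
from unittest.mock import Mock

spark = Mock()
spark._jsc.sc().isStopped = Mock(return_value=False)  # stub the happy path
assert spark._jsc.sc().isStopped() is False           # the validator would accept this mock
```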
12 changes: 12 additions & 0 deletions tests/tests_unit/tests_db_connection_unit/test_clickhouse_unit.py
@@ -33,6 +33,18 @@ def test_clickhouse_missing_package(spark_no_packages):
     )


+def test_clickhouse_spark_stopped(spark_stopped):
+    msg = "Spark session is stopped. Please recreate Spark session."
+    with pytest.raises(ValueError, match=msg):
+        Clickhouse(
+            host="some_host",
+            user="user",
+            database="database",
+            password="passwd",
+            spark=spark_stopped,
+        )
+
+
 def test_clickhouse(spark_mock):
     conn = Clickhouse(host="some_host", user="user", database="database", password="passwd", spark=spark_mock)
12 changes: 12 additions & 0 deletions tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py
@@ -83,6 +83,18 @@ def test_greenplum_missing_package(spark_no_packages):
     )


+def test_greenplum_spark_stopped(spark_stopped):
+    msg = "Spark session is stopped. Please recreate Spark session."
+    with pytest.raises(ValueError, match=msg):
+        Greenplum(
+            host="some_host",
+            user="user",
+            database="database",
+            password="passwd",
+            spark=spark_stopped,
+        )
+
+
 def test_greenplum(spark_mock):
     conn = Greenplum(host="some_host", user="user", database="database", password="passwd", spark=spark_mock)
8 changes: 6 additions & 2 deletions tests/tests_unit/tests_db_connection_unit/test_hive_unit.py
@@ -26,6 +26,12 @@ def test_hive_instance_url(spark_mock):
     assert hive.instance_url == "some-cluster"


+def test_hive_spark_stopped(spark_stopped):
+    msg = "Spark session is stopped. Please recreate Spark session."
+    with pytest.raises(ValueError, match=msg):
+        Hive(cluster="some-cluster", spark=spark_stopped)
+
+
 def test_hive_get_known_clusters_hook(request, spark_mock):
     # no exception
     Hive(cluster="unknown", spark=spark_mock)
@@ -60,8 +66,6 @@ def normalize_cluster_name(cluster: str) -> str:


 def test_hive_known_get_current_cluster_hook(request, spark_mock, mocker):
-    mocker.patch.object(Hive, "_execute_sql", return_value=None)
-
     # no exception
     Hive(cluster="rnd-prod", spark=spark_mock).check()
     Hive(cluster="rnd-dwh", spark=spark_mock).check()
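My reading of why the ``mocker.patch`` had to be removed (not stated in the PR): ``check()`` now chains ``.limit(1).collect()`` on the result of ``_execute_sql``, and a patched return value of ``None`` would raise ``AttributeError: 'NoneType' object has no attribute 'limit'``. The ``Mock``-based ``spark_mock.sql()`` already returns a chainable ``Mock``, so no patch is needed:

```python
from unittest.mock import Mock

result = Mock().sql("SHOW DATABASES")  # attribute access on a Mock returns another Mock
result.limit(1).collect()              # chains fine; no real Spark needed
```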
10 changes: 10 additions & 0 deletions tests/tests_unit/tests_db_connection_unit/test_kafka_unit.py
@@ -70,6 +70,16 @@ def test_kafka_missing_package(spark_no_packages):
     )


+def test_kafka_spark_stopped(spark_stopped):
+    msg = "Spark session is stopped. Please recreate Spark session."
+    with pytest.raises(ValueError, match=msg):
+        Kafka(
+            cluster="some_cluster",
+            addresses=["192.168.1.1"],
+            spark=spark_stopped,
+        )
+
+
 @pytest.mark.parametrize(
     "option, value",
     [
12 changes: 12 additions & 0 deletions tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py
@@ -79,6 +79,18 @@ def test_mongodb_missing_package(spark_no_packages):
     )


+def test_mongodb_spark_stopped(spark_stopped):
+    msg = "Spark session is stopped. Please recreate Spark session."
+    with pytest.raises(ValueError, match=msg):
+        MongoDB(
+            host="host",
+            user="user",
+            password="password",
+            database="database",
+            spark=spark_stopped,
+        )
+
+
 def test_mongodb(spark_mock):
     conn = MongoDB(
         host="host",
12 changes: 12 additions & 0 deletions tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py
@@ -53,6 +53,18 @@ def test_mssql_missing_package(spark_no_packages):
     )


+def test_mssql_spark_stopped(spark_stopped):
+    msg = "Spark session is stopped. Please recreate Spark session."
+    with pytest.raises(ValueError, match=msg):
+        MSSQL(
+            host="some_host",
+            user="user",
+            database="database",
+            password="passwd",
+            spark=spark_stopped,
+        )
+
+
 def test_mssql(spark_mock):
     conn = MSSQL(host="some_host", user="user", database="database", password="passwd", spark=spark_mock)
12 changes: 12 additions & 0 deletions tests/tests_unit/tests_db_connection_unit/test_mysql_unit.py
@@ -33,6 +33,18 @@ def test_mysql_missing_package(spark_no_packages):
     )


+def test_mysql_spark_stopped(spark_stopped):
+    msg = "Spark session is stopped. Please recreate Spark session."
+    with pytest.raises(ValueError, match=msg):
+        MySQL(
+            host="some_host",
+            user="user",
+            database="database",
+            password="passwd",
+            spark=spark_stopped,
+        )
+
+
 def test_mysql(spark_mock):
     conn = MySQL(host="some_host", user="user", database="database", password="passwd", spark=spark_mock)
12 changes: 12 additions & 0 deletions tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py
@@ -53,6 +53,18 @@ def test_oracle_missing_package(spark_no_packages):
     )


+def test_oracle_spark_stopped(spark_stopped):
+    msg = "Spark session is stopped. Please recreate Spark session."
+    with pytest.raises(ValueError, match=msg):
+        Oracle(
+            host="some_host",
+            user="user",
+            sid="sid",
+            password="passwd",
+            spark=spark_stopped,
+        )
+
+
 def test_oracle(spark_mock):
     conn = Oracle(host="some_host", user="user", sid="sid", password="passwd", spark=spark_mock)
12 changes: 12 additions & 0 deletions tests/tests_unit/tests_db_connection_unit/test_postgres_unit.py
@@ -33,6 +33,18 @@ def test_oracle_missing_package(spark_no_packages):
     )


+def test_postgres_spark_stopped(spark_stopped):
+    msg = "Spark session is stopped. Please recreate Spark session."
+    with pytest.raises(ValueError, match=msg):
+        Postgres(
+            host="some_host",
+            user="user",
+            database="database",
+            password="passwd",
+            spark=spark_stopped,
+        )
+
+
 def test_postgres(spark_mock):
     conn = Postgres(host="some_host", user="user", database="database", password="passwd", spark=spark_mock)
12 changes: 12 additions & 0 deletions tests/tests_unit/tests_db_connection_unit/test_teradata_unit.py
@@ -33,6 +33,18 @@ def test_teradata_missing_package(spark_no_packages):
     )


+def test_teradata_spark_stopped(spark_stopped):
+    msg = "Spark session is stopped. Please recreate Spark session."
+    with pytest.raises(ValueError, match=msg):
+        Teradata(
+            host="some_host",
+            user="user",
+            database="database",
+            password="passwd",
+            spark=spark_stopped,
+        )
+
+
 def test_teradata(spark_mock):
     conn = Teradata(host="some_host", user="user", database="database", password="passwd", spark=spark_mock)