Skip to content

Commit

Permalink
Simplify checks for package versions
Browse files Browse the repository at this point in the history
Replaces more complex package version checks with one-liners.
  • Loading branch information
potiuk committed Feb 21, 2024
1 parent 70348de commit 6687cf3
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 44 deletions.
12 changes: 4 additions & 8 deletions airflow/utils/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,14 @@

from __future__ import annotations

from importlib import metadata

def is_pydantic_2_installed() -> bool:
import sys
from packaging import version

from packaging.version import Version

if sys.version_info >= (3, 9):
from importlib.metadata import distribution
else:
from importlib_metadata import distribution
def is_pydantic_2_installed() -> bool:
try:
return Version(distribution("pydantic").version) >= Version("2.0.0")
return version.parse(metadata.version("pydantic")).major == 2
except ImportError:
return False

Expand Down
11 changes: 3 additions & 8 deletions airflow/utils/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
import datetime
import json
import logging
from importlib.metadata import version
from importlib import metadata
from typing import TYPE_CHECKING, Any, Generator, Iterable, overload

from dateutil import relativedelta
from packaging.version import Version, parse as parse_version
from packaging import version
from sqlalchemy import TIMESTAMP, PickleType, event, nullsfirst, tuple_
from sqlalchemy.dialects import mysql
from sqlalchemy.types import JSON, Text, TypeDecorator
Expand Down Expand Up @@ -555,10 +555,5 @@ def get_orm_mapper():
return sqlalchemy.orm.mapper if is_sqlalchemy_v1() else sqlalchemy.orm.Mapper


def _get_lib_major_version(lib_name: str) -> int:
ver: Version = parse_version(version(lib_name))
return ver.major


def is_sqlalchemy_v1() -> bool:
return _get_lib_major_version("sqlalchemy") == 1
return version.parse(metadata.version("sqlalchemy")).major == 1
4 changes: 3 additions & 1 deletion airflow/utils/timezone.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,20 @@
from __future__ import annotations

import datetime as dt
from importlib import metadata
from typing import TYPE_CHECKING, overload

import pendulum
from dateutil.relativedelta import relativedelta
from packaging import version
from pendulum.datetime import DateTime

if TYPE_CHECKING:
from pendulum.tz.timezone import FixedTimezone, Timezone

from airflow.typing_compat import Literal

_PENDULUM3 = pendulum.__version__.startswith("3")
_PENDULUM3 = version.parse(metadata.version("pendulum")).major == 3
# UTC Timezone as a tzinfo instance. Actual value depends on pendulum version:
# - Timezone("UTC") in pendulum 3
# - FixedTimezone(0, "UTC") in pendulum 2
Expand Down
4 changes: 3 additions & 1 deletion tests/serialization/serializers/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import datetime
import decimal
from importlib import metadata
from unittest.mock import patch

import numpy as np
Expand All @@ -26,6 +27,7 @@
import pytest
from dateutil.tz import tzutc
from deltalake import DeltaTable
from packaging import version
from pendulum import DateTime
from pendulum.tz.timezone import FixedTimezone, Timezone

Expand All @@ -38,7 +40,7 @@
else:
from backports.zoneinfo import ZoneInfo

PENDULUM3 = pendulum.__version__.startswith("3")
PENDULUM3 = version.parse(metadata.version("pendulum")).major == 3


class TestSerializers:
Expand Down
32 changes: 6 additions & 26 deletions tests/utils/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from airflow.settings import Session
from airflow.utils.sqlalchemy import (
ExecutorConfigType,
_get_lib_major_version,
ensure_pod_is_valid_after_unpickling,
is_sqlalchemy_v1,
prohibit_commit,
Expand Down Expand Up @@ -317,32 +316,13 @@ def test_result_processor_bad_pickled_obj(self):


@pytest.mark.parametrize(
"version_string, expected_major_version",
"mock_version, expected_result",
[
("1.4.22", 1), # Test 1: "1.4.22" parsed as 1
("10.4.22", 10), # Test 2: "10.4.22" not parsed as 1
("invalid", None), # Test 3: Invalid version string
("3.x.x", None), # Test 4: Malformed version
("1.0.0", True), # Test 1: v1 identified as v1
("2.3.4", False), # Test 2: v2 not identified as v1
],
)
def test_get_lib_major_version(version_string, expected_major_version):
with mock.patch("airflow.utils.sqlalchemy.version") as mock_version:
mock_version.return_value = version_string
if expected_major_version is not None:
assert _get_lib_major_version("dummy_module") == expected_major_version
else:
with pytest.raises(ValueError):
_get_lib_major_version("dummy_module")


@pytest.mark.parametrize(
"major_version, expected_result",
[
(1, True), # Test 1: v1 identified as v1
(2, False), # Test 2: v2 not identified as v1
],
)
def test_is_sqlalchemy_v1(major_version, expected_result):
with mock.patch("airflow.utils.sqlalchemy._get_lib_major_version") as mock_get_major_version:
mock_get_major_version.return_value = major_version
def test_is_sqlalchemy_v1(mock_version, expected_result):
with mock.patch("airflow.utils.sqlalchemy.metadata") as mock_metadata:
mock_metadata.version.return_value = mock_version
assert is_sqlalchemy_v1() == expected_result

0 comments on commit 6687cf3

Please sign in to comment.