Skip to content

Commit

Permalink
Fix xcom arg.py .zip bug (#26636)
Browse files Browse the repository at this point in the history
(cherry picked from commit f219bfb)
  • Loading branch information
rjmcginness authored and jedcunningham committed Sep 27, 2022
1 parent 12bfb57 commit 45a461b
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions airflow/models/xcom_arg.py
Expand Up @@ -31,7 +31,7 @@
from airflow.utils.context import Context
from airflow.utils.edgemodifier import EdgeModifier
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.types import NOTSET
from airflow.utils.types import NOTSET, ArgNotSet

if TYPE_CHECKING:
from airflow.models.dag import DAG
Expand Down Expand Up @@ -322,7 +322,7 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
task_id = self.operator.task_id
result = context["ti"].xcom_pull(task_ids=task_id, key=str(self.key), default=NOTSET, session=session)
if result is not NOTSET:
if not isinstance(result, ArgNotSet):
return result
if self.key == XCOM_RETURN_KEY:
return None
Expand Down Expand Up @@ -437,7 +437,7 @@ def __getitem__(self, index: Any) -> Any:

def __len__(self) -> int:
lengths = (len(v) for v in self.values)
if self.fillvalue is NOTSET:
if isinstance(self.fillvalue, ArgNotSet):
return min(lengths)
return max(lengths)

Expand All @@ -460,13 +460,13 @@ def __repr__(self) -> str:
args_iter = iter(self.args)
first = repr(next(args_iter))
rest = ", ".join(repr(arg) for arg in args_iter)
if self.fillvalue is NOTSET:
if isinstance(self.fillvalue, ArgNotSet):
return f"{first}.zip({rest})"
return f"{first}.zip({rest}, fillvalue={self.fillvalue!r})"

def _serialize(self) -> dict[str, Any]:
args = [serialize_xcom_arg(arg) for arg in self.args]
if self.fillvalue is NOTSET:
if isinstance(self.fillvalue, ArgNotSet):
return {"args": args}
return {"args": args, "fillvalue": self.fillvalue}

Expand All @@ -486,7 +486,7 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
ready_lengths = [length for length in all_lengths if length is not None]
if len(ready_lengths) != len(self.args):
return None # If any of the referenced XComs is not ready, we are not ready either.
if self.fillvalue is NOTSET:
if isinstance(self.fillvalue, ArgNotSet):
return min(ready_lengths)
return max(ready_lengths)

Expand Down

0 comments on commit 45a461b

Please sign in to comment.