diff --git a/airflow/operators/branch.py b/airflow/operators/branch.py index a4653416c47ee..d1502b426425e 100644 --- a/airflow/operators/branch.py +++ b/airflow/operators/branch.py @@ -49,4 +49,6 @@ def choose_branch(self, context: Dict) -> Union[str, Iterable[str]]: raise NotImplementedError def execute(self, context: Dict): - self.skip_all_except(context['ti'], self.choose_branch(context)) + branches_to_execute = self.choose_branch(context) + self.skip_all_except(context['ti'], branches_to_execute) + return branches_to_execute diff --git a/tests/operators/test_branch_operator.py b/tests/operators/test_branch_operator.py index d3725340234de..f54dafe7e2435 100644 --- a/tests/operators/test_branch_operator.py +++ b/tests/operators/test_branch_operator.py @@ -170,3 +170,24 @@ def test_with_skip_in_branch_downstream_dependencies(self): assert ti.state == State.NONE else: raise Exception + + def test_xcom_push(self): + self.branch_op = ChooseBranchOne(task_id='make_choice', dag=self.dag) + + self.branch_1.set_upstream(self.branch_op) + self.branch_2.set_upstream(self.branch_op) + self.dag.clear() + + dr = self.dag.create_dagrun( + run_type=DagRunType.MANUAL, + start_date=timezone.utcnow(), + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + + self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + tis = dr.get_task_instances() + for ti in tis: + if ti.task_id == 'make_choice': + assert ti.xcom_pull(task_ids='make_choice') == 'branch_1'