Skip to content

Commit

Permalink
Remove provide_session decorator from TaskInstancePydantic methods (a…
Browse files Browse the repository at this point in the history
…pache#37853)

If we decorate these methods then the worker will try to create a session.  But there's no reason to do this.  Sessions should be created on the static methods invoked by the API server right?
  • Loading branch information
dstandish authored and abhishekbhakat committed Mar 5, 2024
1 parent 6150ae6 commit 95d084e
Showing 1 changed file with 8 additions and 18 deletions.
26 changes: 8 additions & 18 deletions airflow/serialization/pydantic/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
PlainValidator,
is_pydantic_2_installed,
)
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.xcom import XCOM_RETURN_KEY

if TYPE_CHECKING:
Expand Down Expand Up @@ -148,13 +147,12 @@ def xcom_pull(
"""
return None

@provide_session
def xcom_push(
self,
key: str,
value: Any,
execution_date: datetime | None = None,
session: Session = NEW_SESSION,
session: Session | None = None,
) -> None:
"""
Push an XCom value for this task instance.
Expand All @@ -166,8 +164,7 @@ def xcom_push(
"""
pass

@provide_session
def get_dagrun(self, session: Session = NEW_SESSION) -> DagRunPydantic:
def get_dagrun(self, session: Session | None = None) -> DagRunPydantic:
"""
Return the DagRun for this TaskInstance.
Expand All @@ -190,8 +187,7 @@ def _execute_task(self, context, task_orig):

return _execute_task(task_instance=self, context=context, task_orig=task_orig)

@provide_session
def refresh_from_db(self, session: Session = NEW_SESSION, lock_for_update: bool = False) -> None:
def refresh_from_db(self, session: Session | None = None, lock_for_update: bool = False) -> None:
"""
Refresh the task instance from the database based on the primary key.
Expand Down Expand Up @@ -248,14 +244,13 @@ def is_eligible_to_retry(self):

return _is_eligible_to_retry(task_instance=self)

@provide_session
def handle_failure(
self,
error: None | str | Exception | KeyboardInterrupt,
test_mode: bool | None = None,
context: Context | None = None,
force_fail: bool = False,
session: Session = NEW_SESSION,
session: Session | None = None,
) -> None:
"""
Handle Failure for a task instance.
Expand Down Expand Up @@ -288,7 +283,6 @@ def refresh_from_task(self, task: Operator, pool_override: str | None = None) ->

_refresh_from_task(task_instance=self, task=task, pool_override=pool_override)

@provide_session
def get_previous_dagrun(
self,
state: DagRunState | None = None,
Expand All @@ -304,11 +298,10 @@ def get_previous_dagrun(

return _get_previous_dagrun(task_instance=self, state=state, session=session)

@provide_session
def get_previous_execution_date(
self,
state: DagRunState | None = None,
session: Session = NEW_SESSION,
session: Session | None = None,
) -> pendulum.DateTime | None:
"""
Return the execution date from property previous_ti_success.
Expand Down Expand Up @@ -344,11 +337,10 @@ def get_email_subject_content(

return _get_email_subject_content(task_instance=self, exception=exception, task=task)

@provide_session
def get_previous_ti(
self,
state: DagRunState | None = None,
session: Session = NEW_SESSION,
session: Session | None = None,
) -> TaskInstance | TaskInstancePydantic | None:
"""
Return the task instance for the task that ran before this task instance.
Expand All @@ -360,7 +352,6 @@ def get_previous_ti(

return _get_previous_ti(task_instance=self, state=state, session=session)

@provide_session
def check_and_change_state_before_execution(
self,
verbose: bool = True,
Expand All @@ -374,7 +365,7 @@ def check_and_change_state_before_execution(
job_id: str | None = None,
pool: str | None = None,
external_executor_id: str | None = None,
session: Session = NEW_SESSION,
session: Session | None = None,
) -> bool:
return TaskInstance._check_and_change_state_before_execution(
task_instance=self,
Expand All @@ -393,8 +384,7 @@ def check_and_change_state_before_execution(
session=session,
)

@provide_session
def schedule_downstream_tasks(self, session: Session = NEW_SESSION, max_tis_per_query: int | None = None):
def schedule_downstream_tasks(self, session: Session | None = None, max_tis_per_query: int | None = None):
"""
Schedule downstream tasks of this task instance.
Expand Down

0 comments on commit 95d084e

Please sign in to comment.