Skip to content

Commit

Permalink
[AIRFLOW-1852] Allow hostname to be overridable.
Browse files Browse the repository at this point in the history
This allows the hostname to be overridden through a
configurable callable, to meet the service discovery
requirements of common production deployments.

Closes #3036 from thekashifmalik/hostnames
  • Loading branch information
akatrevorjay authored and Joy Gao committed Feb 21, 2018
1 parent fc26cad commit 6c93460
Show file tree
Hide file tree
Showing 11 changed files with 121 additions and 15 deletions.
3 changes: 2 additions & 1 deletion airflow/bin/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@

from airflow.ti_deps.dep_context import (DepContext, SCHEDULER_DEPS)
from airflow.utils import db as db_utils
from airflow.utils.net import get_hostname
from airflow.utils.log.logging_mixin import (LoggingMixin, redirect_stderr,
redirect_stdout, set_context)
from airflow.www.app import cached_app
Expand Down Expand Up @@ -437,7 +438,7 @@ def run(args, dag=None):

ti.init_run_context(raw=args.raw)

hostname = socket.getfqdn()
hostname = get_hostname()
log.info("Running %s on host %s", ti, hostname)

if args.interactive:
Expand Down
4 changes: 4 additions & 0 deletions airflow/config_templates/default_airflow.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ logging_config_class =
log_format = [%%(asctime)s] {{%%(filename)s:%%(lineno)d}} %%(levelname)s - %%(message)s
simple_log_format = %%(asctime)s %%(levelname)s - %%(message)s

# Override how the hostname is determined by providing the path to a callable
# that returns it, written as "module.path:attribute_name".
# hostname_callable = socket:getfqdn


# Default timezone in case supplied date times are naive
# can be utc (default), system, or any IANA timezone string (e.g. Europe/Amsterdam)
default_timezone = utc
Expand Down
5 changes: 3 additions & 2 deletions airflow/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from airflow.utils.log.logging_mixin import LoggingMixin, set_context, StreamLogWriter
from airflow.utils.state import State
from airflow.utils.configuration import tmp_configuration_copy
from airflow.utils.net import get_hostname

Base = models.Base
ID_LEN = models.ID_LEN
Expand Down Expand Up @@ -99,7 +100,7 @@ def __init__(
executor=executors.GetDefaultExecutor(),
heartrate=conf.getfloat('scheduler', 'JOB_HEARTBEAT_SEC'),
*args, **kwargs):
self.hostname = socket.getfqdn()
self.hostname = get_hostname()
self.executor = executor
self.executor_class = executor.__class__.__name__
self.start_date = timezone.utcnow()
Expand Down Expand Up @@ -2569,7 +2570,7 @@ def heartbeat_callback(self, session=None):
self.task_instance.refresh_from_db()
ti = self.task_instance

fqdn = socket.getfqdn()
fqdn = get_hostname()
same_hostname = fqdn == ti.hostname
same_process = ti.pid == os.getpid()

Expand Down
6 changes: 3 additions & 3 deletions airflow/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import pickle
import re
import signal
import socket
import sys
import textwrap
import traceback
Expand Down Expand Up @@ -84,6 +83,7 @@
from airflow.utils.timeout import timeout
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.weight_rule import WeightRule
from airflow.utils.net import get_hostname
from airflow.utils.log.logging_mixin import LoggingMixin

install_aliases()
Expand Down Expand Up @@ -1363,7 +1363,7 @@ def _check_and_change_state_before_execution(
self.test_mode = test_mode
self.refresh_from_db(session=session, lock_for_update=True)
self.job_id = job_id
self.hostname = socket.getfqdn()
self.hostname = get_hostname()
self.operator = task.__class__.__name__

if not ignore_all_deps and not ignore_ti_state and self.state == State.SUCCESS:
Expand Down Expand Up @@ -1480,7 +1480,7 @@ def _run_raw_task(
self.test_mode = test_mode
self.refresh_from_db(session=session)
self.job_id = job_id
self.hostname = socket.getfqdn()
self.hostname = get_hostname()
self.operator = task.__class__.__name__

context = {}
Expand Down
4 changes: 3 additions & 1 deletion airflow/security/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import socket
import airflow.configuration as conf

from airflow.utils.net import get_hostname

# Pattern to replace with hostname
HOSTNAME_PATTERN = '_HOST'

Expand Down Expand Up @@ -53,7 +55,7 @@ def replace_hostname_pattern(components, host=None):


def get_localhost_name():
return socket.getfqdn()
return get_hostname()


def get_fqdn(hostname_or_ip=None):
Expand Down
40 changes: 40 additions & 0 deletions airflow/utils/net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import importlib
import socket
from airflow.configuration import (conf, AirflowConfigException)


def get_hostname():
    """
    Fetch the hostname using the callable from the config or using
    `socket.getfqdn` as a fallback.

    The callable is read from the ``hostname_callable`` option of the
    ``core`` config section and must be written as
    ``module.path:attribute_name``. It is invoked with no arguments and its
    return value is returned unchanged.

    :return: the value returned by the configured callable, or the result of
        ``socket.getfqdn()`` when the option is missing or empty.
    :raises ValueError: if the configured path is not of the form
        ``module.path:attribute_name``.
    :raises ImportError: if the module part cannot be imported.
    :raises AttributeError: if the module has no attribute with that name.
    """
    # First we attempt to fetch the callable path from the config.
    try:
        callable_path = conf.get('core', 'hostname_callable')
    except AirflowConfigException:
        callable_path = None

    # Then we handle the case when the config is missing or empty. This is the
    # default behavior.
    if not callable_path:
        return socket.getfqdn()

    # The path must contain exactly one ':' separating module and attribute.
    # Check explicitly so a misconfiguration is reported with the offending
    # value rather than an opaque tuple-unpacking error.
    path_parts = callable_path.split(':')
    if len(path_parts) != 2:
        raise ValueError(
            "Invalid hostname_callable {!r}: expected the form "
            "'module.path:attribute_name'".format(callable_path))
    module_path, attr_name = path_parts

    # Since we have a valid callable path, we import and run it next. The
    # local name avoids shadowing the builtin `callable`.
    module = importlib.import_module(module_path)
    hostname_callable = getattr(module, attr_name)
    return hostname_callable()
4 changes: 2 additions & 2 deletions airflow/www/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import socket
import six

from flask import Flask
Expand All @@ -32,6 +31,7 @@
from airflow import jobs
from airflow import settings
from airflow import configuration
from airflow.utils.net import get_hostname

csrf = CSRFProtect()

Expand Down Expand Up @@ -149,7 +149,7 @@ def integrate_plugins():
@app.context_processor
def jinja_globals():
return {
'hostname': socket.getfqdn(),
'hostname': get_hostname(),
}

@app.teardown_appcontext
Expand Down
5 changes: 3 additions & 2 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
from airflow.utils.helpers import alchemy_to_dict
from airflow.utils.dates import infer_time_unit, scale_time_units, parse_execution_date
from airflow.utils.timezone import datetime
from airflow.utils.net import get_hostname
from airflow.www import utils as wwwutils
from airflow.www.forms import DateTimeForm, DateTimeWithNumRunsForm
from airflow.www.validators import GreaterEqualThan
Expand Down Expand Up @@ -647,14 +648,14 @@ def dag_details(self, session=None):
@current_app.errorhandler(404)
def circles(self):
return render_template(
'airflow/circles.html', hostname=socket.getfqdn()), 404
'airflow/circles.html', hostname=get_hostname()), 404

@current_app.errorhandler(500)
def show_traceback(self):
from airflow.utils import asciiart as ascii_
return render_template(
'airflow/traceback.html',
hostname=socket.getfqdn(),
hostname=get_hostname(),
nukular=ascii_.nukular,
info=traceback.format_exc()), 500

Expand Down
5 changes: 3 additions & 2 deletions tests/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from airflow.utils.state import State
from airflow.utils.timeout import timeout
from airflow.utils.dag_processing import SimpleDag, SimpleDagBag, list_py_file_paths
from airflow.utils.net import get_hostname

from mock import Mock, patch
from sqlalchemy.orm.session import make_transient
Expand Down Expand Up @@ -841,7 +842,7 @@ def test_localtaskjob_heartbeat(self, mock_pid):

mock_pid.return_value = 1
ti.state = State.RUNNING
ti.hostname = socket.getfqdn()
ti.hostname = get_hostname()
ti.pid = 1
session.merge(ti)
session.commit()
Expand Down Expand Up @@ -911,7 +912,7 @@ def test_localtaskjob_double_trigger(self):
session=session)
ti = dr.get_task_instance(task_id=task.task_id, session=session)
ti.state = State.RUNNING
ti.hostname = socket.getfqdn()
ti.hostname = get_hostname()
ti.pid = 1
session.commit()

Expand Down
55 changes: 55 additions & 0 deletions tests/utils/test_net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import mock

from airflow.utils import net


def get_hostname():
    """Return a fixed hostname; used as a stand-in hostname callable."""
    fake_hostname = 'awesomehostname'
    return fake_hostname


class GetHostname(unittest.TestCase):
    """Tests for ``net.get_hostname`` with the config object patched out."""

    @mock.patch('airflow.utils.net.socket')
    @mock.patch('airflow.utils.net.conf')
    def test_get_hostname_unset(self, patched_conf, patched_socket):
        """With no callable configured, ``socket.getfqdn`` supplies the name."""
        patched_conf.get = mock.Mock(return_value=None)
        patched_socket.getfqdn = mock.Mock(return_value='first')
        # assertEqual reports both values on failure, unlike assertTrue(a == b).
        self.assertEqual(net.get_hostname(), 'first')

    @mock.patch('airflow.utils.net.conf')
    def test_get_hostname_set(self, patched_conf):
        """A configured ``module:attribute`` path is imported and called."""
        patched_conf.get = mock.Mock(
            return_value='tests.utils.test_net:get_hostname'
        )
        self.assertEqual(net.get_hostname(), 'awesomehostname')

    @mock.patch('airflow.utils.net.conf')
    def test_get_hostname_set_incorrect(self, patched_conf):
        """A path without the ``:attribute`` part raises ``ValueError``."""
        patched_conf.get = mock.Mock(
            return_value='tests.utils.test_net'
        )
        with self.assertRaises(ValueError):
            net.get_hostname()

    @mock.patch('airflow.utils.net.conf')
    def test_get_hostname_set_missing(self, patched_conf):
        """A path naming a missing attribute raises ``AttributeError``."""
        patched_conf.get = mock.Mock(
            return_value='tests.utils.test_net:missing_func'
        )
        with self.assertRaises(AttributeError):
            net.get_hostname()
5 changes: 3 additions & 2 deletions tests/www/api/experimental/test_kerberos_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from airflow import configuration
from airflow.api.auth.backend.kerberos_auth import client_auth
from airflow.utils.net import get_hostname
from airflow.www import app as application


Expand Down Expand Up @@ -57,7 +58,7 @@ def test_trigger_dag(self):
)
self.assertEqual(401, response.status_code)

response.url = 'http://{}'.format(socket.getfqdn())
response.url = 'http://{}'.format(get_hostname())

class Request():
headers = {}
Expand All @@ -72,7 +73,7 @@ class Request():
client_auth.mutual_authentication = 3

# case can influence the results
client_auth.hostname_override = socket.getfqdn()
client_auth.hostname_override = get_hostname()

client_auth.handle_response(response)
self.assertIn('Authorization', response.request.headers)
Expand Down

0 comments on commit 6c93460

Please sign in to comment.