Merge pull request #1757 from davidmarin/google-ssh
SSH tunnel on Dataproc. Fixes #1670
David Marin committed Apr 27, 2018
2 parents d153ca8 + 3e0999d commit d5039d1
Showing 6 changed files with 432 additions and 162 deletions.
215 changes: 214 additions & 1 deletion mrjob/cloud.py
@@ -16,17 +16,27 @@
import logging
import os
import pipes
import socket
import random
import time
from os.path import basename
from signal import SIGKILL
from subprocess import Popen
from subprocess import PIPE

from mrjob.bin import MRJobBinRunner
from mrjob.conf import combine_dicts
from mrjob.py2 import xrange
from mrjob.setup import WorkingDirManager
from mrjob.setup import parse_setup_cmd
from mrjob.util import cmd_line
from mrjob.util import file_ext

log = logging.getLogger(__name__)

# don't try to bind SSH tunnel to more than this many local ports
_MAX_SSH_RETRIES = 20


# map archive file extensions to the command used to unarchive them
_EXT_TO_UNARCHIVE_CMD = {
@@ -62,6 +72,9 @@ class HadoopInTheCloudJobRunner(MRJobBinRunner):
'num_core_instances',
'num_task_instances',
'region',
'ssh_bind_ports',
'ssh_tunnel',
'ssh_tunnel_is_open',
'task_instance_type',
'zone',
}
@@ -95,12 +108,34 @@ def __init__(self, **kwargs):
# we'll create this script later, as needed
self._master_bootstrap_script_path = None

# ssh state

# the process for the SSH tunnel
self._ssh_proc = None

# if this is true, stop trying to launch the SSH tunnel
self._give_up_on_ssh_tunnel = False

# store the (tunneled) URL of the job tracker/resource manager
self._ssh_tunnel_url = None



### Options ###

def _default_opts(self):
return combine_dicts(
super(HadoopInTheCloudJobRunner, self)._default_opts(),
dict(max_mins_idle=_DEFAULT_MAX_MINS_IDLE)
dict(
max_mins_idle=_DEFAULT_MAX_MINS_IDLE,
# don't use a list because it makes it hard to read option
# values when running in verbose mode. See #1284
ssh_bind_ports=xrange(40001, 40841),
ssh_tunnel=False,
ssh_tunnel_is_open=False,
# ssh_bin isn't included here. For example, the Dataproc
# runner launches ssh through the gcloud util
),
)

def _fix_opts(self, opts, source=None):
@@ -361,3 +396,181 @@ def _add_extra_cluster_params(self, params):
params = {k: v for k, v in params.items() if v is not None}

return params

### SSH Tunnel ###

def _ssh_tunnel_args(self, bind_port):
"""Redefine this in your subclass. You will probably want to call
:py:meth:`_ssh_tunnel_opts` somewhere in here.
Should return the list of args used to run the command
to open the SSH tunnel, bound to *bind_port* on your computer,
or ``None`` if it isn't possible to set up an SSH tunnel.
"""
return None

def _ssh_tunnel_config(self):
"""Redefine this in your subclass. Should return a dict with the
following keys:
*localhost*: once we SSH in, is the web interface
    reachable at ``localhost``?
*name*: either ``'job tracker'`` or ``'resource manager'``
*path*: path of main page on web interface (e.g. "/cluster")
*port*: port number of the web interface
"""
raise NotImplementedError

def _launch_ssh_proc(self, args):
"""The command used to create a :py:class:`subprocess.Popen` to
run the SSH tunnel. You usually don't need to redefine this."""
log.debug('> %s' % cmd_line(args))
return Popen(args, stdin=PIPE, stdout=PIPE, stderr=PIPE)

def _ssh_launch_wait_secs(self):
"""Wait this long after launching the SSH process before checking
for failure (default 1 second). You may redefine this."""
return 1.0

def _set_up_ssh_tunnel(self):
"""Call this whenever you think it is possible to SSH to your cluster.
This sets :py:attr:`_ssh_proc`. Does nothing if :mrjob-opt:`ssh_tunnel`
is not set, or there is already a tunnel process running.
"""
# did the user request an SSH tunnel?
if not self._opts['ssh_tunnel']:
return

# no point in trying to launch a nonexistent command twice
if self._give_up_on_ssh_tunnel:
return

# did we already launch the SSH tunnel process? is it still running?
if self._ssh_proc:
self._ssh_proc.poll()
if self._ssh_proc.returncode is None:
return
else:
log.warning(' Oops, ssh subprocess exited with return code'
' %d, restarting...' % self._ssh_proc.returncode)
self._ssh_proc = None

tunnel_config = self._ssh_tunnel_config()

bind_port = None
popen_exception = None
ssh_tunnel_args = []

for bind_port in self._pick_ssh_bind_ports():
ssh_proc = None
ssh_tunnel_args = self._ssh_tunnel_args(bind_port)

# can't launch SSH tunnel right now
if not ssh_tunnel_args:
return

try:
ssh_proc = self._launch_ssh_proc(ssh_tunnel_args)
except OSError as ex:
# e.g. OSError(2, 'File not found')
popen_exception = ex # warning handled below
break

if ssh_proc:
time.sleep(self._ssh_launch_wait_secs())

ssh_proc.poll()
# still running. We are golden
if ssh_proc.returncode is None:
self._ssh_proc = ssh_proc
break
else:
ssh_proc.stdin.close()
ssh_proc.stdout.close()
ssh_proc.stderr.close()

if self._ssh_proc:
if self._opts['ssh_tunnel_is_open']:
bind_host = socket.getfqdn()
else:
bind_host = 'localhost'
self._ssh_tunnel_url = 'http://%s:%d%s' % (
bind_host, bind_port, tunnel_config['path'])
log.info(' Connect to %s at: %s' % (
tunnel_config['name'], self._ssh_tunnel_url))

else:
if popen_exception:
# this only happens if the ssh binary is not present
# or not executable (so tunnel_config and the args to the
# ssh binary don't matter)
log.warning(
" Couldn't open SSH tunnel: %s" % popen_exception)
self._give_up_on_ssh_tunnel = True
return
else:
log.warning(
' Failed to open ssh tunnel to %s' %
tunnel_config['name'])

def _kill_ssh_tunnel(self):
"""Send SIGKILL to SSH tunnel, if it's running."""
if not self._ssh_proc:
return

self._ssh_proc.poll()
if self._ssh_proc.returncode is None:
log.info('Killing our SSH tunnel (pid %d)' %
self._ssh_proc.pid)

self._ssh_proc.stdin.close()
self._ssh_proc.stdout.close()
self._ssh_proc.stderr.close()

try:
os.kill(self._ssh_proc.pid, SIGKILL)
except Exception as e:
log.exception(e)

self._ssh_proc = None
self._ssh_tunnel_url = None

def _ssh_tunnel_opts(self, bind_port):
"""Options to SSH related to setting up a tunnel (rather than
SSHing in). Helper for :py:meth:`_ssh_tunnel_args`.
"""
args = self._ssh_local_tunnel_opt(bind_port) + [
'-N', '-n', '-q',
]
if self._opts['ssh_tunnel_is_open']:
args.extend(['-g', '-4']) # -4: listen on IPv4 only

return args

def _ssh_local_tunnel_opt(self, bind_port):
"""Helper for :py:meth:`_ssh_tunnel_opts`."""
tunnel_config = self._ssh_tunnel_config()

return [
'-L', '%d:%s:%d' % (
bind_port,
self._job_tracker_host(),
tunnel_config['port'],
),
]

def _pick_ssh_bind_ports(self):
"""Pick a list of ports to try binding our SSH tunnel to.
We will try to bind the same port for any given cluster (Issue #67)
"""
# don't perturb the random number generator
random_state = random.getstate()
try:
# seed random port selection on cluster ID
random.seed(self._cluster_id)
num_picks = min(_MAX_SSH_RETRIES,
len(self._opts['ssh_bind_ports']))
return random.sample(self._opts['ssh_bind_ports'], num_picks)
finally:
random.setstate(random_state)
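
Note on the new hooks above: the base class leaves `_job_tracker_host()`, `_ssh_tunnel_config()`, and `_ssh_tunnel_args()` to subclasses. The following is a minimal, hypothetical sketch of how a cloud runner might wire them together; the class name, host pattern, key path, and `_master_public_dns()` helper are illustrative assumptions, not part of mrjob.

```python
# Hypothetical subclass showing how the SSH-tunnel hooks fit together.
from mrjob.cloud import HadoopInTheCloudJobRunner


class ExampleCloudJobRunner(HadoopInTheCloudJobRunner):

    alias = 'example-cloud'  # hypothetical runner alias

    def _job_tracker_host(self):
        # hostname of the resource manager, as seen from the master node
        return '%s-master' % self._cluster_id

    def _ssh_tunnel_config(self):
        # YARN's resource manager serves its web UI on port 8088
        return dict(
            localhost=False,
            name='resource manager',
            path='/cluster',
            port=8088,
        )

    def _ssh_tunnel_args(self, bind_port):
        if not self._cluster_id:
            return None  # nothing to tunnel to yet

        # assume a plain ssh binary with key-based auth to the master node
        return [
            'ssh', '-i', '/path/to/key.pem',
            'hadoop@%s' % self._master_public_dns(),  # hypothetical helper
        ] + self._ssh_tunnel_opts(bind_port)
```

Once these hooks exist, `_set_up_ssh_tunnel()` does the rest: it walks the ports from `_pick_ssh_bind_ports()` (seeded on the cluster ID, so the same cluster tends to get the same local port across runs), launches the SSH process, and logs the tunneled URL.
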
64 changes: 62 additions & 2 deletions mrjob/dataproc.py
@@ -111,6 +111,15 @@

_GCP_CLUSTER_NAME_REGEX = '(?:[a-z](?:[-a-z0-9]{0,53}[a-z0-9])?).'

# on Dataproc, the resource manager is always at 8088. Tunnel to the master
# node's own hostname, not localhost.
_SSH_TUNNEL_CONFIG = dict(
localhost=False,
name='resource manager',
path='/cluster',
port=8088,
)


# convert enum values to strings (e.g. 'RUNNING')

@@ -197,6 +206,7 @@ class DataprocJobRunner(HadoopInTheCloudJobRunner, LogInterpretationMixin):
alias = 'dataproc'

OPT_NAMES = HadoopInTheCloudJobRunner.OPT_NAMES | {
'gcloud_bin',
'project_id',
}

@@ -309,6 +319,7 @@ def _default_opts(self):
check_cluster_every=_DEFAULT_CHECK_CLUSTER_EVERY,
cleanup=['CLUSTER', 'JOB', 'LOCAL_TMP'],
cloud_fs_sync_secs=_DEFAULT_CLOUD_FS_SYNC_SECS,
gcloud_bin=['gcloud'],
image_version=_DEFAULT_IMAGE_VERSION,
instance_type=_DEFAULT_INSTANCE_TYPE,
master_instance_type=_DEFAULT_INSTANCE_TYPE,
@@ -528,6 +539,9 @@ def _create_fs_tmp_bucket(self, bucket_name, location=None):
def cleanup(self, mode=None):
super(DataprocJobRunner, self).cleanup(mode=mode)

# close our SSH tunnel, if any
self._kill_ssh_tunnel()

# stop the cluster if it belongs to us (it may have stopped on its
# own already, but that's fine)
if self._cluster_id and not self._opts['cluster_id']:
@@ -675,6 +689,8 @@ def _launch_cluster(self):

self._wait_for_cluster_ready(self._cluster_id)

self._set_up_ssh_tunnel()

# keep track of when we launched our job
self._dataproc_job_start = time.time()
return self._cluster_id
@@ -819,8 +835,13 @@ def _get_new_driver_output_lines(self, driver_output_uri):
state['log_uri'] = log_uri

log_blob = self.fs._get_blob(log_uri)
# TODO: use start= kwarg once google-cloud-storage 1.9 is out
new_data = log_blob.download_as_string()[state['pos']:]

try:
# TODO: use start= kwarg once google-cloud-storage 1.9 is out
new_data = log_blob.download_as_string()[state['pos']:]
except google.api_core.exceptions.NotFound:
# handle race condition where blob was just created
break

state['buffer'] += new_data
state['pos'] += len(new_data)
@@ -1039,3 +1060,42 @@ def _manifest_download_commands(self):
#('gs://*', 'gsutil cp'),
('*://*', 'hadoop fs -copyToLocal'),
]

### SSH hooks ###

def _job_tracker_host(self):
return '%s-m' % self._cluster_id

def _ssh_tunnel_config(self):
return _SSH_TUNNEL_CONFIG

def _launch_ssh_proc(self, args):
ssh_proc = super(DataprocJobRunner, self)._launch_ssh_proc(args)

# enter an empty passphrase if creating a key for the first time
ssh_proc.stdin.write(b'\n\n')

return ssh_proc

def _ssh_launch_wait_secs(self):
"""Wait 20 seconds because gcloud has to update project metadata
(unless we were going to check the cluster sooner anyway)."""
return min(20.0, self._opts['check_cluster_every'])

def _ssh_tunnel_args(self, bind_port):
if not self._opts['gcloud_bin']:
self._give_up_on_ssh_tunnel = True
return None

if not self._cluster_id:
return

cluster = self._get_cluster(self._cluster_id)
zone = cluster.config.gce_cluster_config.zone_uri.split('/')[-1]

return self._opts['gcloud_bin'] + [
'compute', 'ssh',
'--zone', zone,
self._job_tracker_host(),
'--',
] + self._ssh_tunnel_opts(bind_port)
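
For reference, here is roughly the argument list `_ssh_tunnel_args()` assembles on Dataproc; the cluster ID, zone, and bind port below are made-up values, so treat this as a sketch of the command's shape rather than literal output.

```python
# Approximate value of self._ssh_tunnel_args(40321) for a hypothetical
# cluster 'mrjob-abc123' in zone 'us-west1-a' (all values illustrative):
[
    'gcloud', 'compute', 'ssh',
    '--zone', 'us-west1-a',
    'mrjob-abc123-m',  # _job_tracker_host(): the master node is '<cluster-id>-m'
    '--',
    # _ssh_tunnel_opts(40321): forward local port 40321 to the
    # resource manager (port 8088) on the master node
    '-L', '40321:mrjob-abc123-m:8088',
    '-N', '-n', '-q',
]
```

With :mrjob-opt:`ssh_tunnel_is_open` set, `-g` and `-4` would be appended so the tunnel listens on all IPv4 interfaces rather than just localhost.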
