This repository has been archived by the owner on May 12, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 235
/
helper.py
387 lines (333 loc) · 13.6 KB
/
helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
from contextlib import closing
import errno
import os
import signal
import time
from twitter.common import log
from twitter.common.dirutil import lock_file, safe_mkdir
from twitter.common.quantity import Amount, Time
from twitter.common.recordio import ThriftRecordWriter
from twitter.thermos.common.ckpt import CheckpointDispatcher
from twitter.thermos.common.path import TaskPath
from gen.twitter.thermos.ttypes import (
ProcessState,
ProcessStatus,
RunnerCkpt,
TaskState,
TaskStatus)
import psutil
class TaskKiller(object):
  """
  Interface for terminating a task given its task_id and checkpoint root.

  Thin convenience wrapper over TaskRunnerHelper.kill that fixes the task
  identity at construction time.
  """
  def __init__(self, task_id, checkpoint_root):
    self._task_id = task_id
    self._checkpoint_root = checkpoint_root
  def _terminate(self, terminal_status, force):
    # Delegate to the checkpoint-level kill implementation.
    TaskRunnerHelper.kill(self._task_id, self._checkpoint_root, force=force,
                          terminal_status=terminal_status)
  def kill(self, force=True):
    """Kill the task, recording it with terminal state KILLED."""
    self._terminate(TaskState.KILLED, force)
  def lose(self, force=True):
    """Kill the task, recording it with terminal state LOST."""
    self._terminate(TaskState.LOST, force)
class TaskRunnerHelper(object):
  """
  TaskRunner helper methods that can be operated directly upon checkpoint
  state.  These operations do not require knowledge of the underlying
  task.

  TaskRunnerHelper is sort of a mishmash of "checkpoint-only" operations and
  the "Process Platform" stuff that started to get pulled into process.py.
  This really needs some hard design thought to see if it can be extracted out
  even further.
  """
  class Error(Exception): pass

  # NOTE(review): shadows the Python 3 builtin of the same name; the name is
  # kept because external callers catch TaskRunnerHelper.PermissionError.
  class PermissionError(Error): pass

  # Maximum drift between when the system says a task was forked and when we checkpointed
  # its fork_time (used as a heuristic to determine a forked task is really ours instead of
  # a task with coincidentally the same PID but just wrapped around.)
  MAX_START_TIME_DRIFT = Amount(10, Time.SECONDS)

  @staticmethod
  def get_actual_user():
    """Return the login name for our effective uid, falling back to the
    environment-derived user if the uid has no passwd entry."""
    import getpass, pwd
    try:
      pwd_entry = pwd.getpwuid(os.getuid())
    except KeyError:
      return getpass.getuser()
    return pwd_entry[0]

  @staticmethod
  def process_from_name(task, process_name):
    """Return the process object in `task` whose name is `process_name`, or
    None if the task defines no such process."""
    if task.has_processes():
      for process in task.processes():
        if process.name().get() == process_name:
          return process
    return None

  @classmethod
  def this_is_really_our_pid(cls, process, current_user, start_time):
    """
    A heuristic to make sure that this is likely the pid that we own/forked.  Necessary
    because of pid-space wrapping.  We don't want to go and kill processes we don't own,
    especially if the killer is running as root.

    process: psutil.Process representing the process to check
    current_user: user expected to own the process
    start_time: time at which it's expected the process has started

    Returns True iff the process is owned by `current_user` and started within
    MAX_START_TIME_DRIFT seconds of `start_time`.

    Raises:
      psutil.NoSuchProcess - if the Process supplied no longer exists
    """
    # NOTE(review): relies on the pre-psutil-2.0 attribute API
    # (process.username / process.create_time as plain attributes, not methods).
    if process.username != current_user:
      log.info("Expected pid %s to be ours but the pid user is %s and we're %s" % (
        process.pid, process.username, current_user))
      return False
    if abs(start_time - process.create_time) >= cls.MAX_START_TIME_DRIFT.as_(Time.SECONDS):
      log.info("Expected pid %s start time to be %s but it's %s" % (
        process.pid, start_time, process.create_time))
      return False
    return True

  @classmethod
  def scan_process(cls, state, process_name):
    """
    Given a RunnerState and a process_name, return the following:
      (coordinator pid, process pid, process tree)
      (int or None, int or None, set)

    A pid is None when the corresponding process is gone or no longer
    provably ours (see this_is_really_our_pid).
    """
    process_run = state.processes[process_name][-1]
    process_owner = state.header.user
    coordinator_pid, pid, tree = None, None, set()
    if process_run.coordinator_pid:
      try:
        coordinator_process = psutil.Process(process_run.coordinator_pid)
        if cls.this_is_really_our_pid(coordinator_process, process_owner, process_run.fork_time):
          coordinator_pid = process_run.coordinator_pid
      except psutil.NoSuchProcess:
        log.info(' Coordinator %s [pid: %s] completed.' % (process_run.process,
            process_run.coordinator_pid))
      except psutil.Error as err:
        log.warning(' Error gathering information on pid %s: %s' % (process_run.coordinator_pid,
            err))
    if process_run.pid:
      try:
        process = psutil.Process(process_run.pid)
        if cls.this_is_really_our_pid(process, process_owner, process_run.start_time):
          pid = process.pid
      except psutil.NoSuchProcess:
        log.info(' Process %s [pid: %s] completed.' % (process_run.process, process_run.pid))
      except psutil.Error as err:
        log.warning(' Error gathering information on pid %s: %s' % (process_run.pid, err))
      else:
        # try/else: only collect children when the lookup succeeded and the
        # pid was verified as ours above.
        if pid:
          try:
            tree = set(proc.pid for proc in process.get_children(recursive=True))
          except psutil.Error:
            log.warning(' Error gathering information on children of pid %s' % pid)
    return (coordinator_pid, pid, tree)

  @classmethod
  def scantree(cls, state):
    """
    Scan the process tree associated with the provided task state.

    Returns a dictionary of process name => (coordinator pid, pid, pid children)
    If the coordinator is no longer active, coordinator pid will be None.  If the
    forked process is no longer active, pid will be None and its children will be
    an empty set.
    """
    return dict((process_name, cls.scan_process(state, process_name))
                for process_name in state.processes)

  @classmethod
  def safe_signal(cls, pid, sig=signal.SIGTERM):
    """Send `sig` to `pid`, swallowing 'no such process' and 'permission
    denied' errors and logging anything unexpected."""
    try:
      os.kill(pid, sig)
    except OSError as e:
      if e.errno not in (errno.ESRCH, errno.EPERM):
        log.error('Unexpected error in os.kill: %s' % e)
    except Exception as e:
      log.error('Unexpected error in os.kill: %s' % e)

  @classmethod
  def terminate_pid(cls, pid):
    """Politely ask `pid` to exit (SIGTERM)."""
    cls.safe_signal(pid, signal.SIGTERM)

  @classmethod
  def kill_pid(cls, pid):
    """Forcibly kill `pid` (SIGKILL)."""
    cls.safe_signal(pid, signal.SIGKILL)

  @classmethod
  def kill_group(cls, pgrp):
    """Forcibly kill every process in process group `pgrp` (SIGKILL)."""
    cls.safe_signal(-pgrp, signal.SIGKILL)

  @classmethod
  def _get_process_tuple(cls, state, process_name):
    # Internal: scan_process plus an assertion that the process has >= 1 run.
    assert process_name in state.processes and len(state.processes[process_name]) > 0
    return cls.scan_process(state, process_name)

  @classmethod
  def _get_coordinator_group(cls, state, process_name):
    # Internal: checkpointed coordinator pid of the latest run.
    # NOTE(review): kill_process uses this value as a process group id, which
    # presumes the coordinator leads its own process group — confirm upstream.
    assert process_name in state.processes and len(state.processes[process_name]) > 0
    return state.processes[process_name][-1].coordinator_pid

  @classmethod
  def terminate_process(cls, state, process_name):
    """SIGTERM the latest run of `process_name` if still verifiably ours.
    Returns True iff a signal was sent."""
    log.debug('TaskRunnerHelper.terminate_process(%s)' % process_name)
    _, pid, _ = cls._get_process_tuple(state, process_name)
    if pid:
      log.debug(' => SIGTERM pid %s' % pid)
      cls.terminate_pid(pid)
    return bool(pid)

  @classmethod
  def kill_process(cls, state, process_name):
    """SIGKILL the coordinator process group, the coordinator, the process and
    all of its children for the latest run of `process_name`.  Returns True
    iff anything was signaled."""
    log.debug('TaskRunnerHelper.kill_process(%s)' % process_name)
    coordinator_pgid = cls._get_coordinator_group(state, process_name)
    coordinator_pid, pid, tree = cls._get_process_tuple(state, process_name)
    # This is super dangerous.  TODO(wickman)  Add a heuristic that determines
    # that 1) there are processes that currently belong to this process group
    # and 2) those processes have inherited the coordinator checkpoint filehandle
    # This way we validate that it is in fact the process group we expect.
    if coordinator_pgid:
      log.debug(' => SIGKILL coordinator group %s' % coordinator_pgid)
      cls.kill_group(coordinator_pgid)
    if coordinator_pid:
      log.debug(' => SIGKILL coordinator %s' % coordinator_pid)
      cls.kill_pid(coordinator_pid)
    if pid:
      log.debug(' => SIGKILL pid %s' % pid)
      cls.kill_pid(pid)
    for child in tree:
      log.debug(' => SIGKILL child %s' % child)
      cls.kill_pid(child)
    return bool(coordinator_pid or pid or tree)

  @classmethod
  def kill_runner(cls, state):
    """SIGKILL the runner whose pid is recorded in `state`.

    Returns True if the runner was killed or already gone, False on
    permission denied.  Raises cls.Error on unusable state or if the target
    is this very process.
    """
    log.debug('TaskRunnerHelper.kill_runner()')
    if not state or not state.statuses:
      raise cls.Error('Could not read state!')
    pid = state.statuses[-1].runner_pid
    if pid == os.getpid():
      raise cls.Error('Unwilling to commit seppuku.')
    try:
      os.kill(pid, signal.SIGKILL)
      return True
    except OSError as e:
      if e.errno == errno.EPERM:
        # Permission denied
        return False
      elif e.errno == errno.ESRCH:
        # pid no longer exists
        return True
      raise

  @classmethod
  def open_checkpoint(cls, filename, force=False, state=None):
    """
    Acquire a locked checkpoint stream.

    If the lock is already held and force=True, kill the current lock-holding
    runner and block until the lock is released.  Raises cls.PermissionError
    if the lock could not be acquired.
    """
    safe_mkdir(os.path.dirname(filename))
    fp = lock_file(filename, "a+")
    if fp in (None, False):
      if force:
        log.info('Found existing runner, forcing leadership forfeit.')
        state = state or CheckpointDispatcher.from_file(filename)
        if cls.kill_runner(state):
          log.info('Successfully killed leader.')
          # TODO(wickman)  Blocking may not be the best idea here.  Perhaps block up to
          # a maximum timeout.  But blocking is necessary because os.kill does not immediately
          # release the lock if we're in force mode.
          fp = lock_file(filename, "a+", blocking=True)
      else:
        log.error('Found existing runner, cannot take control.')
    if fp in (None, False):
      raise cls.PermissionError('Could not open locked checkpoint: %s, lock_file = %s' %
        (filename, fp))
    ckpt = ThriftRecordWriter(fp)
    ckpt.set_sync(True)
    return ckpt

  @classmethod
  def kill(cls, task_id, checkpoint_root, force=False,
           terminal_status=TaskState.KILLED, clock=time):
    """
    An implementation of Task killing that doesn't require a fully hydrated TaskRunner object.
    Terminal status must be either KILLED or LOST state.
    """
    if terminal_status not in (TaskState.KILLED, TaskState.LOST):
      # BUG FIX: '%' binds tighter than 'or', so the original
      #   '... %s' % _VALUES_TO_NAMES.get(terminal_status) or terminal_status
      # evaluated as ('... %s' % get(...)) or terminal_status and formatted
      # 'None' into the message on a missed lookup.  Use the dict.get default.
      raise cls.Error('terminal_status must be KILLED or LOST (got %s)' %
                      TaskState._VALUES_TO_NAMES.get(terminal_status, terminal_status))
    pathspec = TaskPath(root=checkpoint_root, task_id=task_id)
    checkpoint = pathspec.getpath('runner_checkpoint')
    state = CheckpointDispatcher.from_file(checkpoint)
    if state is None or state.header is None or state.statuses is None:
      if force:
        log.error('Task has uninitialized TaskState - forcibly finalizing')
        cls.finalize_task(pathspec)
        return
      else:
        log.error('Cannot update states in uninitialized TaskState!')
        return
    ckpt = cls.open_checkpoint(checkpoint, force=force, state=state)

    # Checkpointing helpers local to this kill; they append onto `ckpt`.
    def write_task_state(state):
      update = TaskStatus(state=state, timestamp_ms=int(clock.time() * 1000),
                          runner_pid=os.getpid(), runner_uid=os.getuid())
      ckpt.write(RunnerCkpt(task_status=update))
    def write_process_status(status):
      ckpt.write(RunnerCkpt(process_status=status))

    if cls.is_task_terminal(state.statuses[-1].state):
      log.info('Task is already in terminal state! Finalizing.')
      cls.finalize_task(pathspec)
      return
    with closing(ckpt):
      write_task_state(TaskState.ACTIVE)
      for process, history in state.processes.items():
        process_status = history[-1]
        if not cls.is_process_terminal(process_status.state):
          if cls.kill_process(state, process):
            # Killed it ourselves: record KILLED with SIGKILL's return code.
            write_process_status(ProcessStatus(process=process,
              state=ProcessState.KILLED, seq=process_status.seq + 1, return_code=-9,
              stop_time=clock.time()))
          else:
            # Nothing to kill: it vanished on its own, so mark it LOST
            # (unless it never started, i.e. is still WAITING).
            if process_status.state is not ProcessState.WAITING:
              write_process_status(ProcessStatus(process=process,
                state=ProcessState.LOST, seq=process_status.seq + 1))
      write_task_state(terminal_status)
    cls.finalize_task(pathspec)

  @classmethod
  def reap_children(cls):
    """Reap any terminated child processes without blocking.  Returns the set
    of reaped pids."""
    pids = set()
    while True:
      try:
        pid, status, rusage = os.wait3(os.WNOHANG)
        if pid == 0:
          break
        pids.add(pid)
        log.debug('Detected terminated process: pid=%s, status=%s, rusage=%s' % (
          pid, status, rusage))
      except OSError as e:
        if e.errno != errno.ECHILD:
          # BUG FIX: the message named 'waitpid' but the call above is os.wait3.
          log.warning('Unexpected error when calling wait3: %s' % e)
        break
    return pids

  # States from which no further transitions occur.
  TERMINAL_PROCESS_STATES = frozenset([
    ProcessState.SUCCESS,
    ProcessState.KILLED,
    ProcessState.FAILED,
    ProcessState.LOST])
  TERMINAL_TASK_STATES = frozenset([
    TaskState.SUCCESS,
    TaskState.FAILED,
    TaskState.KILLED,
    TaskState.LOST])

  @classmethod
  def is_process_terminal(cls, process_status):
    """True iff `process_status` is a terminal ProcessState."""
    return process_status in cls.TERMINAL_PROCESS_STATES

  @classmethod
  def is_task_terminal(cls, task_status):
    """True iff `task_status` is a terminal TaskState."""
    return task_status in cls.TERMINAL_TASK_STATES

  @classmethod
  def initialize_task(cls, spec, task):
    """Write serialized `task` to the 'active' task path of `spec` if not
    already present.  Raises cls.Error if a 'finished' record exists."""
    active_task = spec.given(state='active').getpath('task_path')
    finished_task = spec.given(state='finished').getpath('task_path')
    is_active, is_finished = os.path.exists(active_task), os.path.exists(finished_task)
    if is_finished:
      raise cls.Error('Cannot initialize task with "finished" record!')
    if not is_active:
      safe_mkdir(os.path.dirname(active_task))
      with open(active_task, 'w') as fp:
        fp.write(task)

  @classmethod
  def finalize_task(cls, spec):
    """Move the 'active' task record of `spec` to 'finished' and refresh its
    mtime.  Raises cls.Error if the records are in an inconsistent state."""
    active_task = spec.given(state='active').getpath('task_path')
    finished_task = spec.given(state='finished').getpath('task_path')
    is_active, is_finished = os.path.exists(active_task), os.path.exists(finished_task)
    if not is_active:
      raise cls.Error('Cannot finalize task with no "active" record!')
    elif is_finished:
      raise cls.Error('Cannot finalize task with "finished" record!')
    safe_mkdir(os.path.dirname(finished_task))
    os.rename(active_task, finished_task)
    os.utime(finished_task, None)