forked from python-discord/snekbox
-
Notifications
You must be signed in to change notification settings - Fork 0
/
nsjail.py
306 lines (262 loc) · 11.4 KB
/
nsjail.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
import logging
import re
import subprocess
import sys
from collections.abc import Generator
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Iterable, TypeVar
from google.protobuf import text_format
from snekbox import DEBUG, utils
from snekbox.config_pb2 import NsJailConfig
from snekbox.filesystem import Size
from snekbox.memfs import MemFS
from snekbox.process import EvalResult
from snekbox.snekio import FileAttachment
from snekbox.utils.timed import time_limit
__all__ = ("NsJail",)

log = logging.getLogger(__name__)

# Element type shared by the input and output of iter_lstrip.
_T = TypeVar("_T")

# [level][timestamp][PID]? function_signature:line_no? message
#
# Parses a single line of NsJail's log output:
# - `level` is one of I, D, W, E, F; group 2 `(I)` captures only the info level.
# - The `[timestamp]` bracket is matched lazily and discarded.
# - `(?(2)|...)` is a conditional: for info lines (group 2 matched) nothing further
#   is required, so `func` is None; for every other level a "[PID] function:line_no "
#   prefix is mandatory and captured as `func`.
# - An optional single space follows, then the rest of the line is captured as `msg`.
LOG_PATTERN = re.compile(
    r"\[(?P<level>(I)|[DWEF])\]\[.+?\](?(2)|(?P<func>\[\d+\] .+?:\d+ )) ?(?P<msg>.+)"
)
def iter_lstrip(iterable: Iterable[_T]) -> Generator[_T, None, None]:
    """
    Yield the items of `iterable`, skipping any falsy items at the start.

    Only the leading run of falsy items is dropped; once the first truthy item
    has been yielded, every remaining item (falsy or not) is passed through.
    """
    remaining = iter(iterable)
    for candidate in remaining:
        if not candidate:
            # Still inside the leading run of falsy items; discard it.
            continue
        yield candidate
        # The rest of the shared iterator is forwarded untouched.
        yield from remaining
        return
class NsJail:
    """
    Core Snekbox functionality, providing safe execution of Python code.

    See config/snekbox.cfg for the default NsJail configuration.
    """

    def __init__(
        self,
        nsjail_path: str = "/usr/sbin/nsjail",
        config_path: str = "./config/snekbox.cfg",
        max_output_size: int = 1_000_000,
        read_chunk_size: int = 10_000,
        memfs_instance_size: int = 48 * Size.MiB,
        memfs_home: str = "home",
        memfs_output: str = "home",
        files_limit: int | None = 100,
        files_timeout: int | None = 5,
        files_pattern: str = "**/[!_]*",
    ):
        """
        Initialize NsJail.

        The NsJail config at `config_path` is read and parsed immediately (the process
        exits if that fails), and the cgroup version and swap-limit handling are
        determined from it.

        Args:
            nsjail_path: Path to the NsJail binary.
            config_path: Path to the NsJail configuration file.
            max_output_size: Maximum size of the output in bytes.
            read_chunk_size: Size of the read buffer in bytes.
            memfs_instance_size: Size of the tmpfs instance in bytes.
            memfs_home: Name of the mounted home directory.
            memfs_output: Name of the output directory within home,
                can be empty to use home as output.
            files_limit: Maximum number of output files to parse.
            files_timeout: Maximum time in seconds to wait for output files to be read.
            files_pattern: Pattern to match files to attach within the output directory.
        """
        self.nsjail_path = nsjail_path
        self.config_path = config_path
        self.max_output_size = max_output_size
        self.read_chunk_size = read_chunk_size
        self.memfs_instance_size = memfs_instance_size
        self.memfs_home = memfs_home
        self.memfs_output = memfs_output
        self.files_limit = files_limit
        self.files_timeout = files_timeout
        self.files_pattern = files_pattern

        self.config = self._read_config(config_path)
        self.cgroup_version = utils.cgroup.init(self.config)
        self.ignore_swap_limits = utils.swap.should_ignore_limit(self.config, self.cgroup_version)

        log.info(f"Assuming cgroup version {self.cgroup_version}.")

    @staticmethod
    def _read_config(config_path: str) -> NsJailConfig:
        """
        Read the NsJail config at `config_path` and return a protobuf Message object.

        Any failure to find, read, or parse the file is logged as fatal and the
        process exits with status 1.
        """
        config = NsJailConfig()

        try:
            with open(config_path, encoding="utf-8") as f:
                config_text = f.read()
        except FileNotFoundError:
            log.fatal(f"The NsJail config at {config_path!r} could not be found.")
            sys.exit(1)
        except OSError as e:
            log.fatal(f"The NsJail config at {config_path!r} could not be read.", exc_info=e)
            sys.exit(1)

        try:
            text_format.Parse(config_text, config)
        except text_format.ParseError as e:
            log.fatal(f"The NsJail config at {config_path!r} could not be parsed.", exc_info=e)
            sys.exit(1)

        return config

    @staticmethod
    def _parse_log(log_lines: Iterable[str]) -> None:
        """
        Parse and log NsJail's log messages.

        Each line is matched against LOG_PATTERN and re-emitted through this module's
        logger at the corresponding level; fatal messages are logged as errors.
        """
        for line in log_lines:
            match = LOG_PATTERN.fullmatch(line)
            if match is None:
                log.warning(f"Failed to parse log line '{line}'")
                continue

            msg = match["msg"]
            if DEBUG and match["func"]:
                # Prepend PID, function signature, and line number if debugging.
                msg = f"{match['func']}{msg}"

            if match["level"] == "D":
                log.debug(msg)
            elif match["level"] == "I":
                # Outside of debug mode, only pass through "pid=" messages
                # (process exit notices) and skip the rest.
                if DEBUG or msg.startswith("pid="):
                    log.info(msg)
            elif match["level"] == "W":
                log.warning(msg)
            else:
                # Treat fatal as error.
                log.error(msg)

    def _consume_stdout(self, nsjail: subprocess.Popen) -> str:
        """
        Consume STDOUT until EOF, stopping early if the output limit is reached.

        The aim of this function is to limit the size of the output received from
        NsJail to prevent the container from claiming too much memory. If the output
        received from STDOUT goes over `max_output_size`, the NsJail subprocess is
        asked to terminate with a SIGTERM.

        The pipe is drained until EOF rather than only while the process is running,
        so output still buffered in the pipe when NsJail exits is not lost.

        Once the subprocess has exited, either naturally or because it was terminated,
        the output is returned as a single string.

        Raises:
            UnicodeDecodeError: If the output cannot be decoded by the text-mode pipe.
        """
        output_size = 0
        output = []

        # Context manager will wait for process to terminate and close file descriptors.
        with nsjail:
            # read() returns an empty string only at EOF, i.e. once every writer of the
            # pipe has closed it. Polling for process exit instead would drop any data
            # that is still buffered in the pipe when NsJail exits.
            while chars := nsjail.stdout.read(self.read_chunk_size):
                # sys.getsizeof measures the in-memory size of the str object, which is
                # what this limit is meant to bound.
                output_size += sys.getsizeof(chars)
                output.append(chars)

                if output_size > self.max_output_size:
                    # Terminate the NsJail subprocess with SIGTERM.
                    # This in turn reaps and kills children with SIGKILL.
                    log.info("Output exceeded the output limit, sending SIGTERM to NsJail.")
                    nsjail.terminate()
                    break

        return "".join(output)

    def python3(
        self,
        py_args: Iterable[str],
        files: Iterable[FileAttachment] = (),
        nsjail_args: Iterable[str] = (),
    ) -> EvalResult:
        """
        Execute Python 3 code in an isolated environment and return the completed process.

        Args:
            py_args: Arguments to pass to Python.
            files: FileAttachments to write to the sandbox prior to running Python.
            nsjail_args: Overrides for the NsJail configuration.

        Returns:
            An EvalResult holding the NsJail command line, the return code (signal
            terminations normalized to `N + 128`), the combined STDOUT/STDERR output,
            and the files found in the output directory. If writing input files,
            starting NsJail, decoding its output, or parsing attachments fails, the
            return code is None and the output is an error message instead.
        """
        if self.cgroup_version == 2:
            nsjail_args = ("--use_cgroupv2", *nsjail_args)

        if self.ignore_swap_limits:
            nsjail_args = (
                "--cgroup_mem_memsw_max",
                "0",
                "--cgroup_mem_swap_max",
                "-1",
                *nsjail_args,
            )

        with NamedTemporaryFile() as nsj_log, MemFS(
            instance_size=self.memfs_instance_size,
            home=self.memfs_home,
            output=self.memfs_output,
        ) as fs:
            nsjail_args = (
                # Mount `home` with Read/Write access
                "--bindmount",
                f"{fs.home}:home",
                *nsjail_args,
            )

            args = [
                self.nsjail_path,
                "--config",
                self.config_path,
                "--log",
                nsj_log.name,
                *nsjail_args,
                "--",
                self.config.exec_bin.path,
                # Filter out empty strings at start of Python args
                # (causes issues with python cli)
                *iter_lstrip(self.config.exec_bin.arg),
                *iter_lstrip(py_args),
            ]

            # Write provided files if any
            files_written: dict[Path, float] = {}
            for file in files:
                try:
                    f_path = file.save_to(fs.home)
                    # Allow file to be writable
                    f_path.chmod(0o777)
                    # Save the written at time to later check if it was modified
                    files_written[f_path] = f_path.stat().st_mtime
                    log.info(f"Created file at {(fs.home / file.path)!r}.")
                except OSError as e:
                    log.info(f"Failed to create file at {(fs.home / file.path)!r}.", exc_info=e)
                    return EvalResult(
                        args, None, f"{e.__class__.__name__}: Failed to create file '{file.path}'."
                    )

            msg = "Executing code..."
            if DEBUG:
                msg = f"{msg[:-3]} with the arguments {args}."
            log.info(msg)

            try:
                nsjail = subprocess.Popen(
                    args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
                )
            except ValueError:
                return EvalResult(args, None, "ValueError: embedded null byte")

            try:
                output = self._consume_stdout(nsjail)
            except UnicodeDecodeError:
                return EvalResult(args, None, "UnicodeDecodeError: invalid Unicode in output pipe")

            # When you send signal `N` to a subprocess to terminate it using Popen, it
            # will return `-N` as its exit code. As we normally get `N + 128` back, we
            # convert negative exit codes to the `N + 128` form.
            returncode = -nsjail.returncode + 128 if nsjail.returncode < 0 else nsjail.returncode

            # Parse attachments with time limit
            try:
                with time_limit(self.files_timeout):
                    attachments = fs.files_list(
                        limit=self.files_limit,
                        pattern=self.files_pattern,
                        preload_dict=True,
                        exclude_files=files_written,
                        timeout=self.files_timeout,
                    )
                log.info(f"Found {len(attachments)} files.")
            except RecursionError:
                log.info("Recursion error while parsing attachments")
                return EvalResult(
                    args,
                    None,
                    "FileParsingError: Exceeded directory depth limit while parsing attachments",
                )
            except TimeoutError as e:
                log.info(f"Exceeded time limit while parsing attachments: {e}")
                return EvalResult(
                    args, None, "TimeoutError: Exceeded time limit while parsing attachments"
                )
            except Exception as e:
                log.exception(f"Unexpected {type(e).__name__} while parsing attachments", exc_info=e)
                return EvalResult(
                    args, None, "FileParsingError: Unknown error while parsing attachments"
                )

            # Decode leniently: a stray invalid byte in NsJail's log must not turn an
            # otherwise successful run into an unhandled exception.
            log_lines = nsj_log.read().decode("utf-8", errors="replace").splitlines()
            if not log_lines and returncode == 255:
                # NsJail probably failed to parse arguments so log output will still be in stdout
                log_lines = output.splitlines()

            self._parse_log(log_lines)

            log.info(f"nsjail return code: {returncode}")

            return EvalResult(args, returncode, output, files=attachments)