Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion app/api/v2/handlers/payload_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,32 @@
PayloadDeleteRequestSchema


ALLOWED_EXTENSIONS = frozenset([
'.ps1', '.sh', '.py', '.exe', '.elf', '.bat', '.vbs', '.js', '.go', '.c',
'.zip', '.tar', '.gz', '.dll', '.bin', '.yaml', '.yml', '.txt', '.json',
])

# b'<%@ Page' is redundant because b'<%@' already matches it via startswith().
DANGEROUS_MAGIC_BYTES = [
b'<?php', b'<%@', b'<%!',
]


def _validate_payload_file(filename, file_content_start):
"""Validate payload filename extension and magic bytes.
Returns (is_valid, error_message).
"""
if '\x00' in filename:
return False, 'Null byte detected in filename'
ext = os.path.splitext(filename)[1].lower()
if ext and ext not in ALLOWED_EXTENSIONS:
return False, f'File extension not allowed: {ext}'
for magic in DANGEROUS_MAGIC_BYTES:
if file_content_start.startswith(magic):
Comment on lines +36 to +37
return False, f'Dangerous file signature detected'
Comment on lines +36 to +38
return True, ''


class PayloadApi(BaseApi):
def __init__(self, services):
super().__init__(auth_svc=services['auth_svc'])
Expand Down Expand Up @@ -70,9 +96,18 @@ async def post_payloads(self, request: web.Request):
# accessing the file using the prefilled request["form"] dictionary.
file_field: web.FileField = request["form"]["file"]

# Sanitize the file name to prevent directory traversal
# Sanitize the filename first so validation uses the same name that
# will be used for storage, preventing discrepancies if sanitization
# changes the extension or structure.
sanitized_filename = self.sanitize_filename(file_field.filename)

# Validate sanitized filename and magic bytes.
first_bytes = file_field.file.read(16)
file_field.file.seek(0)
is_valid, error_msg = _validate_payload_file(sanitized_filename, first_bytes)
if not is_valid:
raise web.HTTPBadRequest(text=error_msg)

# Generate the file name and path
file_name, file_path = await self.__generate_file_name_and_path(sanitized_filename)

Expand Down
37 changes: 37 additions & 0 deletions tests/security/test_payload_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import unittest

from app.api.v2.handlers.payload_api import _validate_payload_file


class TestPayloadValidation(unittest.TestCase):
def test_valid_extension(self):
ok, _ = _validate_payload_file('test.ps1', b'\x00\x00\x00\x00')
self.assertTrue(ok)

def test_invalid_extension(self):
ok, msg = _validate_payload_file('test.php', b'normal content')
self.assertFalse(ok)
self.assertIn('extension', msg.lower())

def test_dangerous_magic_bytes_php(self):
ok, msg = _validate_payload_file('test.txt', b'<?php echo "hi";')
self.assertFalse(ok)
self.assertIn('dangerous', msg.lower())

def test_dangerous_magic_bytes_jsp(self):
ok, msg = _validate_payload_file('test.txt', b'<%@ page import')
self.assertFalse(ok)

def test_null_byte_in_filename(self):
ok, msg = _validate_payload_file('test\x00.txt', b'safe')
self.assertFalse(ok)
self.assertIn('Null byte', msg)

def test_no_extension_is_allowed(self):
ok, _ = _validate_payload_file('myagent', b'\x7fELF')
self.assertTrue(ok)

def test_redundant_asp_page_signature_still_blocked(self):
"""b'<%@ Page' must still be rejected because b'<%@' prefix matches it."""
ok, _ = _validate_payload_file('test.txt', b'<%@ Page Language')
self.assertFalse(ok)
Loading