Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save method pool data using log service #585

Merged
merged 6 commits into from
May 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 73 additions & 15 deletions apiserver/report/handler/saas_method_pool_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import json
import logging
import random
import socket
import time
import uuid
from hashlib import sha256,sha1

import requests
Expand Down Expand Up @@ -35,6 +37,18 @@

@ReportHandler.register(const.REPORT_VULN_SAAS_POOL)
class SaasMethodPoolHandler(IReportHandler):
def __init__(self):
super(SaasMethodPoolHandler, self).__init__()
self.async_send = settings.config.getboolean('task', 'async_send', fallback=False)
self.async_send_delay = settings.config.getint('task', 'async_send_delay', fallback=2)
self.retryable = settings.config.getboolean('task', 'retryable', fallback=False)

if self.async_send and (ReportHandler.log_service_disabled or ReportHandler.log_service is None):
logger.error('log service disabled or failed to connect, disable async send method pool')
self.async_send = False
else:
self.log_service = ReportHandler.log_service

@staticmethod
def parse_headers(headers_raw):
headers = dict()
Expand Down Expand Up @@ -138,18 +152,55 @@ def save(self):
self.send_to_engine(method_pool_id=method_pool_id,
model='replay')
else:
pool_sign = self.calc_hash()
current_version_agents = self.get_project_agents(self.agent)
with transaction.atomic():
pool_sign = uuid.uuid4().hex
if self.async_send:
try:
update_record, method_pool = self.save_method_call(
pool_sign, current_version_agents)
self.send_to_engine(method_pool_id=method_pool.id,
update_record=update_record)
method_pool = self.to_json(pool_sign)
ok = self.log_service.send(method_pool)
if ok:
self.send_to_engine(self.agent_id, method_pool_sign=pool_sign)
except Exception as e:
logger.info(e, exc_info=True)

logger.error(e, exc_info=True)
else:
current_version_agents = self.get_project_agents(self.agent)
with transaction.atomic():
try:
update_record, method_pool = self.save_method_call(
pool_sign, current_version_agents)
self.send_to_engine(self.agent_id, method_pool_sign=pool_sign,
update_record=update_record)
except Exception as e:
logger.error(e, exc_info=True)

def to_json(self, pool_sign: str):
timestamp = int(time.time())
pool = {
'agent_id': self.agent_id,
'url': self.http_url,
'uri': self.http_uri,
'http_method': self.http_method,
'http_scheme': self.http_scheme,
'http_protocol': self.http_protocol,
'req_header': self.http_req_header,
'req_params': self.http_query_string,
'req_data': self.http_req_data,
'req_header_for_search': utils.build_request_header(req_method=self.http_method,
raw_req_header=self.http_req_header,
uri=self.http_uri,
query_params=self.http_query_string,
http_protocol=self.http_protocol),
'res_header': utils.base64_decode(self.http_res_header),
'res_body': decode_content(get_res_body(self.http_res_body, self.version),
get_content_encoding(self.http_res_header), self.version),
'context_path': self.context_path,
'method_pool': json.dumps(self.method_pool),
'pool_sign': pool_sign,
'clent_ip': self.client_ip,
'create_time': timestamp,
'update_time': timestamp,
'uri_sha1': self.sha1(self.http_uri),
}
return json.dumps(pool)

def save_method_call(self, pool_sign: str,
current_version_agents) -> Tuple[bool, MethodPool]:
Expand Down Expand Up @@ -232,22 +283,29 @@ def save_method_call(self, pool_sign: str,
)
return update_record, method_pool

@staticmethod
def send_to_engine(method_pool_id, update_record=False, model=None):
def send_to_engine(self, agent_id, method_pool_id="", method_pool_sign="", update_record=False, model=None):
try:
if model is None:
logger.info(
f'[+] send method_pool [{method_pool_id}] to engine for {"update" if update_record else "new record"}')
search_vul_from_method_pool.delay(method_pool_id)
search_sink_from_method_pool.delay(method_pool_id)
f'[+] send method_pool [{method_pool_sign}] to engine for {"update" if update_record else "new record"}')
delay = 0
if self.async_send:
delay = self.async_send_delay
kwargs = {
'method_pool_sign': method_pool_sign,
'agent_id': agent_id,
'retryable': self.retryable,
}
search_vul_from_method_pool.apply_async(kwargs=kwargs, countdown=delay)
search_sink_from_method_pool.apply_async(kwargs=kwargs, countdown=delay)
else:
logger.info(
f'[+] send method_pool [{method_pool_id}] to engine for {model if model else ""}'
)
search_vul_from_replay_method_pool.delay(method_pool_id)
#requests.get(url=settings.REPLAY_ENGINE_URL.format(id=method_pool_id))
except Exception as e:
logger.info(f'[-] Failure: send method_pool [{method_pool_id}], Error: {e}')
logger.error(f'[-] Failure: send method_pool [{method_pool_id}{method_pool_sign}], Error: {e}')

def calc_hash(self):
sign_raw = '-'.join(
Expand Down
48 changes: 48 additions & 0 deletions apiserver/report/log_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import logging
import socket

logger = logging.getLogger('dongtai.openapi')


class LogService:
def __init__(self, host, port):
super(LogService, self).__init__()
self.host = host
self.port = port
self.socket = None

def create_socket(self):
if self.socket:
return

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(5)
try:
sock.connect((self.host, self.port))
sock.setblocking(False)
self.socket = sock
return True
except OSError:
logger.error(f'failed to connect log service {self.host}:{self.port}')
self.socket = None
sock.close()
return False

def __del__(self):
if self.socket:
self.socket.close()
self.socket = None

def send(self, message):
try:
if not self.socket:
self.create_socket()
if self.socket:
self.socket.sendall(bytes(message + "\n", encoding='utf-8'), socket.MSG_DONTWAIT)
return True
except Exception as e:
logger.error('failed to send message to log service', exc_info=e)
if self.socket:
self.socket.close()
self.socket = None
return False
16 changes: 16 additions & 0 deletions apiserver/report/report_handler_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@
import logging, requests, json, time
from django.utils.translation import gettext_lazy as _
from AgentServer import settings
from apiserver.report.log_service import LogService
from dongtai.models.agent import IastAgent

logger = logging.getLogger('dongtai.openapi')


class ReportHandler:
HANDLERS = {}
log_service = None
log_service_disabled = False

# 注册handler到当前命名空间,后续进行异步处理数据
@staticmethod
Expand Down Expand Up @@ -60,6 +63,19 @@ def handler(reports, user):
@classmethod
def register(cls, handler_name):
def wrapper(handler):
async_send = settings.config.getboolean('task', 'async_send', fallback=False)
if not async_send:
cls.log_service_disabled = True
if cls.log_service is None and not cls.log_service_disabled:
host = settings.config.get('log_service', 'host')
port = settings.config.getint('log_service', 'port')
if not host or not port:
logger.error('log service must config host and post')
cls.log_service_disabled = True
srv = LogService(host, port)
if srv.create_socket():
cls.log_service = srv

logger.info(
_('Registration report type {} handler {}').format(handler_name, handler.__name__))
if handler_name not in cls.HANDLERS:
Expand Down
10 changes: 10 additions & 0 deletions conf/config.ini.example
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ access_key_secret = ZoEOSi7KfayQ7JalvJVHa37fdZ4XFY
[sca]
base_url = http://52.80.75.225:8000

[task]
retryable = false
max_retries = 3
async_send = false
async_send_delay = 2

[log_service]
host = localhost
port = 8082

[security]
csrf_trust_origins = localhost,.huoxian.cn,.secnium.xyz
secret_key = vbjlvbxfvazjfprywuxgyclmvhtmselddsefxxlcixovmqfpgy
Expand Down
10 changes: 10 additions & 0 deletions conf/config.ini.test
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,16 @@ secret_key = vbjlvbxfvazjfprywuxgyclmvhtmselddsefxxlcixovmqfpgy
[sca]
base_url = http://52.80.75.225:8000

[task]
retryable = false
max_retries = 3
async_send = false
async_send_delay = 2

[log_service]
host = localhost
port = 8082

[other]
domain = http://localhost.domain/
demo_session_cookie_domain = .huoxian.cn
59 changes: 47 additions & 12 deletions core/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@
}


RETRY_INTERVALS = [10, 30, 90]


class RetryableException(Exception):
pass


def queryset_to_iterator(queryset):
"""
将queryset转换为迭代器,解决使用queryset遍历数据导致的一次性加载至内存带来的内存激增问题
Expand Down Expand Up @@ -185,14 +192,22 @@ def search_and_save_sink(engine, method_pool_model, strategy):
method_pool_model.sinks.add(strategy.get('strategy'))


@shared_task(queue='dongtai-method-pool-scan')
def search_vul_from_method_pool(method_pool_id):
@shared_task(bind=True, queue='dongtai-method-pool-scan',
max_retries=settings.config.getint('task', 'max_retries', fallback=3))
def search_vul_from_method_pool(self, method_pool_sign, agent_id, retryable=False):

logger.info(f'漏洞检测开始,方法池 {method_pool_id}')
logger.info(f'漏洞检测开始,方法池 {method_pool_sign}')
try:
method_pool_model = MethodPool.objects.filter(id=method_pool_id).first()
method_pool_model = MethodPool.objects.filter(pool_sign=method_pool_sign, agent_id=agent_id).first()
if method_pool_model is None:
logger.warn(f'漏洞检测终止,方法池 {method_pool_id} 不存在')
if retryable:
if self.request.retries < self.max_retries:
tries = self.request.retries + 1
raise RetryableException(f'漏洞检测方法池 {method_pool_sign} 不存在,重试第 {tries} 次')
else:
logger.error(f'漏洞检测超过最大重试次数 {self.max_retries},方法池 {method_pool_sign} 不存在')
else:
logger.warning(f'漏洞检测终止,方法池 {method_pool_sign} 不存在')
return
check_response_header(method_pool_model)
check_response_content(method_pool_model)
Expand All @@ -206,8 +221,14 @@ def search_vul_from_method_pool(method_pool_id):
if strategy.get('value') in engine.method_pool_signatures:
search_and_save_vul(engine, method_pool_model, method_pool, strategy)
logger.info(f'漏洞检测成功')
except RetryableException as e:
if self.request.retries < self.max_retries:
delay = 5 + pow(3, self.request.retries) * 10
self.retry(exc=e, countdown=delay)
else:
logger.error(f'漏洞检测超过最大重试次数,错误原因:{e}')
except Exception as e:
logger.error(f'漏洞检测出错,方法池 {method_pool_id}. 错误原因:{e}')
logger.error(f'漏洞检测出错,方法池 {method_pool_sign}. 错误原因:{e}')


@shared_task(queue='dongtai-replay-vul-scan')
Expand Down Expand Up @@ -258,25 +279,39 @@ def search_vul_from_strategy(strategy_id):
logger.error(f'漏洞检测出错,错误原因:{e}')


@shared_task(queue='dongtai-search-scan')
def search_sink_from_method_pool(method_pool_id):
@shared_task(bind=True, queue='dongtai-search-scan',
max_retries=settings.config.getint('task', 'max_retries', fallback=3))
def search_sink_from_method_pool(self, method_pool_sign, agent_id, retryable=False):
"""
根据方法池ID搜索方法池中是否匹配到策略库中的sink方法
:param method_pool_id: 方法池ID
:param self: celery task
:param method_pool_sign: 方法池 sign
:param agent_id: Agent ID
:param retryable: 可重试
:return: None
"""
logger.info(f'sink规则扫描开始,方法池ID[{method_pool_id}]')
logger.info(f'sink规则扫描开始,方法池[{method_pool_sign}]')
try:
method_pool_model = MethodPool.objects.filter(id=method_pool_id).first()
method_pool_model = MethodPool.objects.filter(pool_sign=method_pool_sign, agent_id=agent_id).first()
if method_pool_model is None:
logger.warn(f'sink规则扫描终止,方法池 [{method_pool_id}] 不存在')
if retryable:
if self.request.retries < self.max_retries:
tries = self.request.retries + 1
raise RetryableException(f'sink规则扫描方法池 {method_pool_sign} 不存在,重试第 {tries} 次')
else:
logger.error(f'sink规则扫描超过最大重试次数 {self.max_retries},方法池 {method_pool_sign} 不存在')
else:
logger.warn(f'sink规则扫描终止,方法池 [{method_pool_sign}] 不存在')
return
strategies = load_sink_strategy(method_pool_model.agent.user, method_pool_model.agent.language)
engine = VulEngine()

for strategy in strategies:
search_and_save_sink(engine, method_pool_model, strategy)
logger.info(f'sink规则扫描完成')
except RetryableException as e:
delay = 5 + pow(3, self.request.retries) * 10
self.retry(exc=e, countdown=delay)
except Exception as e:
logger.error(f'sink规则扫描出错,错误原因:{e}')

Expand Down
Loading