
Commit

Merge pull request #585 from lostsnow/feature/method-pool-to-log-service
Save method pool data using log service
Bidaya0 committed May 9, 2022
2 parents 8824a06 + 2a360f6 commit 0138e0b
Showing 7 changed files with 216 additions and 36 deletions.
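
In short, reporting now has two persistence paths: the existing synchronous database write inside a transaction, and a new asynchronous path that serializes the method pool to JSON, ships it to a TCP log service, and lets the Celery detection tasks look the record up later by pool_sign and agent_id. A minimal sketch of the new path, assuming the handler wiring shown in the diffs below (save_async is an illustrative name, not part of the commit):

import uuid

def save_async(handler):
    # Random sign instead of a content hash (see saas_method_pool_handler.py below)
    pool_sign = uuid.uuid4().hex
    payload = handler.to_json(pool_sign)           # JSON string built by to_json()
    if handler.log_service.send(payload):          # newline-framed JSON over TCP
        # The detection tasks receive the sign and agent id and query MethodPool
        # themselves, delayed by async_send_delay (presumably to give the log
        # pipeline time to write the row first).
        handler.send_to_engine(handler.agent_id, method_pool_sign=pool_sign)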
88 changes: 73 additions & 15 deletions apiserver/report/handler/saas_method_pool_handler.py
@@ -7,7 +7,9 @@
import json
import logging
import random
import socket
import time
import uuid
from hashlib import sha256, sha1

import requests
@@ -35,6 +37,18 @@

@ReportHandler.register(const.REPORT_VULN_SAAS_POOL)
class SaasMethodPoolHandler(IReportHandler):
def __init__(self):
super(SaasMethodPoolHandler, self).__init__()
self.async_send = settings.config.getboolean('task', 'async_send', fallback=False)
self.async_send_delay = settings.config.getint('task', 'async_send_delay', fallback=2)
self.retryable = settings.config.getboolean('task', 'retryable', fallback=False)

if self.async_send and (ReportHandler.log_service_disabled or ReportHandler.log_service is None):
logger.error('log service is disabled or failed to connect, disabling async method pool send')
self.async_send = False
else:
self.log_service = ReportHandler.log_service

@staticmethod
def parse_headers(headers_raw):
headers = dict()
@@ -138,18 +152,55 @@ def save(self):
self.send_to_engine(method_pool_id=method_pool_id,
model='replay')
else:
pool_sign = self.calc_hash()
current_version_agents = self.get_project_agents(self.agent)
with transaction.atomic():
pool_sign = uuid.uuid4().hex
if self.async_send:
try:
update_record, method_pool = self.save_method_call(
pool_sign, current_version_agents)
self.send_to_engine(method_pool_id=method_pool.id,
update_record=update_record)
method_pool = self.to_json(pool_sign)
ok = self.log_service.send(method_pool)
if ok:
self.send_to_engine(self.agent_id, method_pool_sign=pool_sign)
except Exception as e:
logger.info(e, exc_info=True)

logger.error(e, exc_info=True)
else:
current_version_agents = self.get_project_agents(self.agent)
with transaction.atomic():
try:
update_record, method_pool = self.save_method_call(
pool_sign, current_version_agents)
self.send_to_engine(self.agent_id, method_pool_sign=pool_sign,
update_record=update_record)
except Exception as e:
logger.error(e, exc_info=True)

def to_json(self, pool_sign: str):
timestamp = int(time.time())
pool = {
'agent_id': self.agent_id,
'url': self.http_url,
'uri': self.http_uri,
'http_method': self.http_method,
'http_scheme': self.http_scheme,
'http_protocol': self.http_protocol,
'req_header': self.http_req_header,
'req_params': self.http_query_string,
'req_data': self.http_req_data,
'req_header_for_search': utils.build_request_header(req_method=self.http_method,
raw_req_header=self.http_req_header,
uri=self.http_uri,
query_params=self.http_query_string,
http_protocol=self.http_protocol),
'res_header': utils.base64_decode(self.http_res_header),
'res_body': decode_content(get_res_body(self.http_res_body, self.version),
get_content_encoding(self.http_res_header), self.version),
'context_path': self.context_path,
'method_pool': json.dumps(self.method_pool),
'pool_sign': pool_sign,
'clent_ip': self.client_ip,
'create_time': timestamp,
'update_time': timestamp,
'uri_sha1': self.sha1(self.http_uri),
}
return json.dumps(pool)

def save_method_call(self, pool_sign: str,
current_version_agents) -> Tuple[bool, MethodPool]:
@@ -232,22 +283,29 @@ def save_method_call(self, pool_sign: str,
)
return update_record, method_pool

@staticmethod
def send_to_engine(method_pool_id, update_record=False, model=None):
def send_to_engine(self, agent_id, method_pool_id="", method_pool_sign="", update_record=False, model=None):
try:
if model is None:
logger.info(
f'[+] send method_pool [{method_pool_id}] to engine for {"update" if update_record else "new record"}')
search_vul_from_method_pool.delay(method_pool_id)
search_sink_from_method_pool.delay(method_pool_id)
f'[+] send method_pool [{method_pool_sign}] to engine for {"update" if update_record else "new record"}')
delay = 0
if self.async_send:
delay = self.async_send_delay
kwargs = {
'method_pool_sign': method_pool_sign,
'agent_id': agent_id,
'retryable': self.retryable,
}
search_vul_from_method_pool.apply_async(kwargs=kwargs, countdown=delay)
search_sink_from_method_pool.apply_async(kwargs=kwargs, countdown=delay)
else:
logger.info(
f'[+] send method_pool [{method_pool_id}] to engine for {model if model else ""}'
)
search_vul_from_replay_method_pool.delay(method_pool_id)
#requests.get(url=settings.REPLAY_ENGINE_URL.format(id=method_pool_id))
except Exception as e:
logger.info(f'[-] Failure: send method_pool [{method_pool_id}], Error: {e}')
logger.error(f'[-] Failure: send method_pool [{method_pool_id}{method_pool_sign}], Error: {e}')

def calc_hash(self):
sign_raw = '-'.join(
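For reference, the payload that to_json builds and the log service receives is a single JSON object followed by a newline (LogService appends the "\n"). A hypothetical example of one such line; all values are placeholders, and the field names, including the clent_ip spelling, come from the dict above:

{"agent_id": 1, "url": "http://demo.local/login", "uri": "/login", "http_method": "POST", "http_scheme": "http", "http_protocol": "HTTP/1.1", "req_header": "...", "req_params": "", "req_data": "...", "req_header_for_search": "...", "res_header": "...", "res_body": "...", "context_path": "/", "method_pool": "[...]", "pool_sign": "7f3c0b2d...", "clent_ip": "10.0.0.5", "create_time": 1652054400, "update_time": 1652054400, "uri_sha1": "..."}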
48 changes: 48 additions & 0 deletions apiserver/report/log_service.py
@@ -0,0 +1,48 @@
import logging
import socket

logger = logging.getLogger('dongtai.openapi')


class LogService:
def __init__(self, host, port):
super(LogService, self).__init__()
self.host = host
self.port = port
self.socket = None

def create_socket(self):
if self.socket:
return

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(5)
try:
sock.connect((self.host, self.port))
sock.setblocking(False)
self.socket = sock
return True
except OSError:
logger.error(f'failed to connect log service {self.host}:{self.port}')
self.socket = None
sock.close()
return False

def __del__(self):
if self.socket:
self.socket.close()
self.socket = None

def send(self, message):
try:
if not self.socket:
self.create_socket()
if self.socket:
self.socket.sendall(bytes(message + "\n", encoding='utf-8'), socket.MSG_DONTWAIT)
return True
except Exception as e:
logger.error('failed to send message to log service', exc_info=e)
if self.socket:
self.socket.close()
self.socket = None
return False
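
A minimal usage sketch for this class, assuming a collector is listening on the host and port from the [log_service] config section (values illustrative):

srv = LogService('localhost', 8082)
if srv.create_socket():                      # connects with a 5s timeout, then switches to non-blocking
    srv.send('{"pool_sign": "abc123"}')      # send() appends "\n" and returns False on failure

On any send error the socket is closed and dropped, so the next send() re-runs create_socket() and reconnects.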
16 changes: 16 additions & 0 deletions apiserver/report/report_handler_factory.py
@@ -7,13 +7,16 @@
import logging, requests, json, time
from django.utils.translation import gettext_lazy as _
from AgentServer import settings
from apiserver.report.log_service import LogService
from dongtai.models.agent import IastAgent

logger = logging.getLogger('dongtai.openapi')


class ReportHandler:
HANDLERS = {}
log_service = None
log_service_disabled = False

# Register the handler in the current namespace; the data is processed asynchronously afterwards
@staticmethod
@@ -60,6 +63,19 @@ def handler(reports, user):
@classmethod
def register(cls, handler_name):
def wrapper(handler):
async_send = settings.config.getboolean('task', 'async_send', fallback=False)
if not async_send:
cls.log_service_disabled = True
if cls.log_service is None and not cls.log_service_disabled:
host = settings.config.get('log_service', 'host')
port = settings.config.getint('log_service', 'port')
if not host or not port:
logger.error('log service requires host and port to be configured')
cls.log_service_disabled = True
srv = LogService(host, port)
if srv.create_socket():
cls.log_service = srv

logger.info(
_('Registration report type {} handler {}').format(handler_name, handler.__name__))
if handler_name not in cls.HANDLERS:
10 changes: 10 additions & 0 deletions conf/config.ini.example
@@ -33,6 +33,16 @@ access_key_secret = ZoEOSi7KfayQ7JalvJVHa37fdZ4XFY
[sca]
base_url = http://52.80.75.225:8000

[task]
retryable = false
max_retries = 3
async_send = false
async_send_delay = 2

[log_service]
host = localhost
port = 8082

[security]
csrf_trust_origins = localhost,.huoxian.cn,.secnium.xyz
secret_key = vbjlvbxfvazjfprywuxgyclmvhtmselddsefxxlcixovmqfpgy
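To actually turn the new path on, a deployment would flip async_send and point [log_service] at its collector in conf/config.ini (presumably copied from this example file). A hypothetical configuration, with illustrative host and values:

[task]
retryable = true
max_retries = 3
async_send = true
async_send_delay = 2

[log_service]
host = log-collector.internal
port = 8082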
10 changes: 10 additions & 0 deletions conf/config.ini.test
@@ -38,6 +38,16 @@ secret_key = vbjlvbxfvazjfprywuxgyclmvhtmselddsefxxlcixovmqfpgy
[sca]
base_url = http://52.80.75.225:8000

[task]
retryable = false
max_retries = 3
async_send = false
async_send_delay = 2

[log_service]
host = localhost
port = 8082

[other]
domain = http://localhost.domain/
demo_session_cookie_domain = .huoxian.cn
59 changes: 47 additions & 12 deletions core/tasks.py
@@ -51,6 +51,13 @@
}


RETRY_INTERVALS = [10, 30, 90]


class RetryableException(Exception):
pass


def queryset_to_iterator(queryset):
"""
Convert a queryset into an iterator, to avoid the memory spike caused by loading all rows into memory at once when iterating over the queryset
@@ -185,14 +192,22 @@ def search_and_save_sink(engine, method_pool_model, strategy):
method_pool_model.sinks.add(strategy.get('strategy'))


@shared_task(queue='dongtai-method-pool-scan')
def search_vul_from_method_pool(method_pool_id):
@shared_task(bind=True, queue='dongtai-method-pool-scan',
max_retries=settings.config.getint('task', 'max_retries', fallback=3))
def search_vul_from_method_pool(self, method_pool_sign, agent_id, retryable=False):

logger.info(f'Vulnerability detection started, method pool {method_pool_id}')
logger.info(f'Vulnerability detection started, method pool {method_pool_sign}')
try:
method_pool_model = MethodPool.objects.filter(id=method_pool_id).first()
method_pool_model = MethodPool.objects.filter(pool_sign=method_pool_sign, agent_id=agent_id).first()
if method_pool_model is None:
logger.warn(f'Vulnerability detection aborted, method pool {method_pool_id} does not exist')
if retryable:
if self.request.retries < self.max_retries:
tries = self.request.retries + 1
raise RetryableException(f'Vulnerability detection: method pool {method_pool_sign} does not exist, retry attempt {tries}')
else:
logger.error(f'Vulnerability detection exceeded max retries {self.max_retries}, method pool {method_pool_sign} does not exist')
else:
logger.warning(f'Vulnerability detection aborted, method pool {method_pool_sign} does not exist')
return
check_response_header(method_pool_model)
check_response_content(method_pool_model)
@@ -206,8 +221,14 @@ def search_vul_from_method_pool(method_pool_id):
if strategy.get('value') in engine.method_pool_signatures:
search_and_save_vul(engine, method_pool_model, method_pool, strategy)
logger.info('Vulnerability detection completed')
except RetryableException as e:
if self.request.retries < self.max_retries:
delay = 5 + pow(3, self.request.retries) * 10
self.retry(exc=e, countdown=delay)
else:
logger.error(f'Vulnerability detection exceeded max retries, error: {e}')
except Exception as e:
logger.error(f'Vulnerability detection failed, method pool {method_pool_id}. Error: {e}')
logger.error(f'Vulnerability detection failed, method pool {method_pool_sign}. Error: {e}')


@shared_task(queue='dongtai-replay-vul-scan')
@@ -258,25 +279,39 @@ def search_vul_from_strategy(strategy_id):
logger.error(f'Vulnerability detection failed, error: {e}')


@shared_task(queue='dongtai-search-scan')
def search_sink_from_method_pool(method_pool_id):
@shared_task(bind=True, queue='dongtai-search-scan',
max_retries=settings.config.getint('task', 'max_retries', fallback=3))
def search_sink_from_method_pool(self, method_pool_sign, agent_id, retryable=False):
"""
Search the method pool for methods that match sink rules in the strategy library
:param method_pool_id: method pool ID
:param self: celery task
:param method_pool_sign: method pool sign
:param agent_id: Agent ID
:param retryable: whether the lookup may be retried
:return: None
"""
logger.info(f'Sink rule scan started, method pool ID [{method_pool_id}]')
logger.info(f'Sink rule scan started, method pool [{method_pool_sign}]')
try:
method_pool_model = MethodPool.objects.filter(id=method_pool_id).first()
method_pool_model = MethodPool.objects.filter(pool_sign=method_pool_sign, agent_id=agent_id).first()
if method_pool_model is None:
logger.warn(f'Sink rule scan aborted, method pool [{method_pool_id}] does not exist')
if retryable:
if self.request.retries < self.max_retries:
tries = self.request.retries + 1
raise RetryableException(f'Sink rule scan: method pool {method_pool_sign} does not exist, retry attempt {tries}')
else:
logger.error(f'Sink rule scan exceeded max retries {self.max_retries}, method pool {method_pool_sign} does not exist')
else:
logger.warn(f'Sink rule scan aborted, method pool [{method_pool_sign}] does not exist')
return
strategies = load_sink_strategy(method_pool_model.agent.user, method_pool_model.agent.language)
engine = VulEngine()

for strategy in strategies:
search_and_save_sink(engine, method_pool_model, strategy)
logger.info('Sink rule scan completed')
except RetryableException as e:
delay = 5 + pow(3, self.request.retries) * 10
self.retry(exc=e, countdown=delay)
except Exception as e:
logger.error(f'Sink rule scan failed, error: {e}')

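
Both tasks back off with the same formula when the method pool has not been persisted yet: countdown = 5 + 3**retries * 10 seconds. A small illustrative calculation of the resulting schedule with the default max_retries = 3 (not code from the commit):

for retries in range(3):
    print(retries, 5 + pow(3, retries) * 10)   # retry 0 -> 15s, retry 1 -> 35s, retry 2 -> 95s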
