Skip to content
This repository has been archived by the owner on Oct 10, 2019. It is now read-only.

Commit

Permalink
Merge pull request #47 from automationator/master
Browse files Browse the repository at this point in the history
Adds bulk mode when getting indicators
  • Loading branch information
automationator committed Mar 18, 2019
2 parents 4577d23 + a33dd85 commit d721cff
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 51 deletions.
16 changes: 15 additions & 1 deletion services/web/project/api/routes/indicator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import datetime
import gzip
import json

from dateutil.parser import parse
from flask import current_app, jsonify, request, url_for
from flask import current_app, jsonify, request, Response, url_for
from sqlalchemy import and_, exc

from project import db
Expand Down Expand Up @@ -559,6 +561,18 @@ def read_indicators():
if 'value' in request.args:
filters.add(Indicator.value.like('%{}%'.format(request.args.get('value'))))

# If bulk is enabled, get all of the results and compress them.
if 'bulk' in request.args:
if parse_boolean(request.args.get('bulk')):
data = [indicator.to_dict(bulk=True) for indicator in Indicator.query.filter(*filters)]
data = json.dumps(data).encode('utf-8')

response = Response(status=200, mimetype='application/json')
response.data = gzip.compress(data)
response.headers['Content-Encoding'] = 'gzip'
response.headers['Content-Length'] = len(response.data)
return response

data = Indicator.to_collection_dict(Indicator.query.filter(*filters), 'api.read_indicators', **request.args)
return jsonify(data)

Expand Down
93 changes: 43 additions & 50 deletions services/web/project/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ def to_collection_dict(query, endpoint, **kwargs):
else:
page = 1
if 'per_page' in args:
per_page = min(int(args['per_page'][0]), 100)
per_page = min(int(args['per_page'][0]), 1000)
else:
per_page = 10
per_page = 100

# Now that we have the page and per_page values, remove them
# from the arguments so that the url_for function does not
Expand Down Expand Up @@ -74,38 +74,28 @@ def to_collection_dict(query, endpoint, **kwargs):


indicator_campaign_association = db.Table('indicator_campaign_mapping',
db.Column('indicator_id', db.Integer, db.ForeignKey('indicator.id')),
db.Column('campaign_id', db.Integer, db.ForeignKey('campaign.id'))
)
db.Column('indicator_id', db.Integer, db.ForeignKey('indicator.id'), primary_key=True),
db.Column('campaign_id', db.Integer, db.ForeignKey('campaign.id'), primary_key=True))

indicator_equal_association = db.Table('indicator_equal_mapping',
db.Column('left_id', db.Integer, db.ForeignKey('indicator.id'),
primary_key=True),
db.Column('right_id', db.Integer, db.ForeignKey('indicator.id'),
primary_key=True)
)
db.Column('left_id', db.Integer, db.ForeignKey('indicator.id'), primary_key=True),
db.Column('right_id', db.Integer, db.ForeignKey('indicator.id'), primary_key=True))

indicator_reference_association = db.Table('indicator_reference_mapping',
db.Column('indicator_id', db.Integer, db.ForeignKey('indicator.id')),
db.Column('intel_reference_id', db.Integer,
db.ForeignKey('intel_reference.id'))
)
db.Column('indicator_id', db.Integer, db.ForeignKey('indicator.id'), primary_key=True),
db.Column('intel_reference_id', db.Integer, db.ForeignKey('intel_reference.id'), primary_key=True))

indicator_relationship_association = db.Table('indicator_relationship_mapping',
db.Column('parent_id', db.Integer, db.ForeignKey('indicator.id'),
primary_key=True),
db.Column('child_id', db.Integer, db.ForeignKey('indicator.id'),
primary_key=True)
)
db.Column('parent_id', db.Integer, db.ForeignKey('indicator.id'), primary_key=True),
db.Column('child_id', db.Integer, db.ForeignKey('indicator.id'),primary_key=True))

indicator_tag_association = db.Table('indicator_tag_mapping',
db.Column('indicator_id', db.Integer, db.ForeignKey('indicator.id')),
db.Column('tag_id', db.Integer, db.ForeignKey('tag.id'))
)
db.Column('indicator_id', db.Integer, db.ForeignKey('indicator.id'), primary_key=True),
db.Column('tag_id', db.Integer, db.ForeignKey('tag.id'), primary_key=True))

roles_users_association = db.Table('role_user_mapping',
db.Column('user_id', db.Integer(), db.ForeignKey('user.id')),
db.Column('role_id', db.Integer(), db.ForeignKey('role.id')))
db.Column('user_id', db.Integer(), db.ForeignKey('user.id'), primary_key=True),
db.Column('role_id', db.Integer(), db.ForeignKey('role.id'), primary_key=True))

"""
TABLE CLASS DEFINITIONS
Expand Down Expand Up @@ -206,12 +196,12 @@ class Indicator(PaginatedAPIMixin, db.Model):
children = db.relationship('Indicator', secondary=indicator_relationship_association,
primaryjoin=(indicator_relationship_association.c.parent_id == id),
secondaryjoin=(indicator_relationship_association.c.child_id == id),
backref=db.backref('parent', lazy='select'), lazy='select')
backref=db.backref('parent', lazy='joined'), lazy='joined')

equal = db.relationship('Indicator', secondary=indicator_equal_association,
primaryjoin=(indicator_equal_association.c.left_id == id),
secondaryjoin=(indicator_equal_association.c.right_id == id),
lazy='select')
lazy='joined')

status = db.relationship('IndicatorStatus')
status_id = db.Column(db.Integer, db.ForeignKey('indicator_status.id'), nullable=False)
Expand All @@ -226,34 +216,37 @@ class Indicator(PaginatedAPIMixin, db.Model):
def __str__(self):
return str('{} : {}'.format(self.type, self.value))

def to_dict(self):
children = self.get_children(grandchildren=False)
all_children = self.get_children(grandchildren=True)

equal = self.get_equal(recursive=False)
all_equal = self.get_equal(recursive=True)

def to_dict(self, bulk=False):
data = {
'id': self.id,
'all_children': sorted([i.id for i in all_children]),
'all_equal': sorted([i.id for i in all_equal]),
'campaigns': [c.to_dict() for c in self.campaigns],
'case_sensitive': bool(self.case_sensitive),
'children': sorted([i.id for i in children]),
'confidence': self.confidence.value,
'created_time': self.created_time,
'equal': sorted([i.id for i in equal]),
'impact': self.impact.value,
'modified_time': self.modified_time,
'parent': self.get_parent().id if self.get_parent() else None,
'references': [r.to_dict() for r in self.references],
'status': self.status.value,
'substring': bool(self.substring),
'tags': sorted([t.value for t in self.tags]),
'type': self.type.value,
'user': self.user.username,
'value': self.value
}

if not bulk:
children = self.get_children(grandchildren=False)
all_children = self.get_children(grandchildren=True)

equal = self.get_equal(recursive=False)
all_equal = self.get_equal(recursive=True)

data['all_children'] = sorted([i.id for i in all_children])
data['all_equal'] = sorted([i.id for i in all_equal])
data['campaigns'] = [c.to_dict() for c in self.campaigns]
data['case_sensitive'] = bool(self.case_sensitive)
data['children'] = sorted([i.id for i in children])
data['confidence'] = self.confidence.value
data['created_time'] = self.created_time
data['equal'] = sorted([i.id for i in equal])
data['impact'] = self.impact.value
data['modified_time'] = self.modified_time
data['parent'] = self.get_parent().id if self.get_parent() else None
data['references'] = [r.to_dict() for r in self.references]
data['status'] = self.status.value
data['substring'] = bool(self.substring)
data['tags'] = sorted([t.value for t in self.tags])
data['user'] = self.user.username

return data

def add_child(self, other):
Expand Down Expand Up @@ -462,7 +455,7 @@ class Tag(db.Model):
__tablename__ = 'tag'

id = db.Column(db.Integer, primary_key=True, nullable=False)
value = db.Column(db.String(255), nullable=False)
value = db.Column(db.String(255), nullable=False, index=True)

def __str__(self):
return str(self.value)
Expand Down
8 changes: 8 additions & 0 deletions services/web/project/tests/api/test_indicator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import gzip
import time

from project.tests.conftest import TEST_ANALYST_APIKEY, TEST_INACTIVE_APIKEY, TEST_INVALID_APIKEY
Expand Down Expand Up @@ -787,6 +788,13 @@ def test_read_with_filters(client):

time.sleep(1)

# Filter with bulk mode enabled.
request = client.get('/api/indicators?bulk=true')
response = gzip.decompress(request.data)
response = json.loads(response.decode('utf-8'))
assert request.status_code == 200
assert len(response) == 2

# Filter by case_sensitive
request = client.get('/api/indicators?case_sensitive=true')
response = json.loads(request.data.decode())
Expand Down

0 comments on commit d721cff

Please sign in to comment.