diff --git a/lemur/certificates/service.py b/lemur/certificates/service.py index 7b40861fd3..9337e356f7 100644 --- a/lemur/certificates/service.py +++ b/lemur/certificates/service.py @@ -721,13 +721,15 @@ def query_name(certificate_name, args): def query_common_name(common_name, args): """ - Helper function that queries for not expired certificates by common name (and owner) + Helper function that queries for not expired certificates by common name, + owner and san. Pagination is supported. :param common_name: :param args: :return: """ owner = args.pop("owner") + san = args.pop("san") page = args.pop("page") count = args.pop("count") @@ -747,6 +749,10 @@ def query_common_name(common_name, args): # if common_name is a wildcard ('%'), no need to include it in the query query = query.filter(Certificate.cn.ilike(common_name)) + if san and san != "%": + # if san is a wildcard ('%'), no need to include it in the query + query = query.filter(Certificate.id.in_(like_domain_query(san))) + if paginate: args = {"page": page, "count": count, "sort_by": "id", "sort_dir": "desc"} return database.sort_and_page(query, Certificate, args) diff --git a/lemur/certificates/views.py b/lemur/certificates/views.py index d3e8f39e03..982c84f0fc 100644 --- a/lemur/certificates/views.py +++ b/lemur/certificates/views.py @@ -137,6 +137,7 @@ def get(self): # using non-paginated parser to ensure backward compatibility self.reqparse.add_argument("filter", type=str, location="args") self.reqparse.add_argument("owner", type=str, location="args") + self.reqparse.add_argument("san", type=str, location="args") self.reqparse.add_argument("count", type=int, location="args") self.reqparse.add_argument("page", type=int, location="args") diff --git a/lemur/tests/test_certificates.py b/lemur/tests/test_certificates.py index 69e9330c71..f950ab6a7e 100644 --- a/lemur/tests/test_certificates.py +++ b/lemur/tests/test_certificates.py @@ -1776,10 +1776,12 @@ def test_query_common_name(session): cert_cn1_replaced.cn = cn1 cert_cn1_valid = CertificateFactory() cert_cn1_valid.cn = cn1 + cert_cn1_valid.domains = [Domain(name=cn1)] cert_cn1_valid.owner = "owner1@example.org" cert_cn1_valid.replaces.append(cert_cn1_replaced) cert_cn1_valid2 = CertificateFactory() cert_cn1_valid2.cn = cn1 + cert_cn1_valid2.domains = [Domain(name=cn1)] cert_cn1_valid2.owner = "owner2@example.org" yesterday = arrow.utcnow() + timedelta(days=-1) cert_cn1_expired = CertificateFactory() @@ -1793,34 +1795,63 @@ def test_query_common_name(session): cert_cn2 = CertificateFactory() cert_cn2.cn = cn2 - cn1_valid_certs = query_common_name(cn1, {"owner": "", "page": "", "count": ""}) + cn1_valid_certs = query_common_name(cn1, {"owner": "", "san": "", "page": "", "count": ""}) assert len(cn1_valid_certs) == 2 - cn1_valid_certs_paged = query_common_name(cn1, {"owner": "", "page": 1, "count": 100}) + # since CN is also stored as SAN, count should be the same if filtered using cn1 as SAN + cn1_san_valid_certs = query_common_name('%', {"owner": "", "san": cn1, "page": "", "count": ""}) + assert len(cn1_san_valid_certs) == 2 + + cn1_valid_certs_paged = query_common_name(cn1, {"owner": "", "san": "", "page": 1, "count": 100}) assert cn1_valid_certs_paged["total"] == 2 assert len(cn1_valid_certs_paged["items"]) == 2 - cn1_valid_certs_paged_single = query_common_name(cn1, {"owner": "", "page": 1, "count": 1}) + cn1_valid_certs_paged_single = query_common_name(cn1, {"owner": "", "san": "", "page": 1, "count": 1}) assert cn1_valid_certs_paged_single["total"] == 2 assert len(cn1_valid_certs_paged_single["items"]) == 1 - cn1_owner1_valid_certs = query_common_name(cn1, {"owner": "owner1@example.org", "page": "", "count": ""}) + cn1_owner1_valid_certs = query_common_name(cn1, {"owner": "owner1@example.org", "san": "", "page": "", "count": ""}) assert len(cn1_owner1_valid_certs) == 1 - cn1_owner1_valid_certs_paged = query_common_name(cn1, {"owner": "owner1@example.org", "page": 1, "count": 100}) + cn1_owner1_valid_certs_paged = query_common_name(cn1, {"owner": "owner1@example.org", "san": "", "page": 1, "count": 100}) assert cn1_owner1_valid_certs_paged["total"] == 1 assert len(cn1_owner1_valid_certs_paged["items"]) == 1 - cn1_owner2_valid_certs = query_common_name(cn1, {"owner": "owner2@example.org", "page": "", "count": ""}) + cn1_owner2_valid_certs = query_common_name(cn1, {"owner": "owner2@example.org", "san": "", "page": "", "count": ""}) assert len(cn1_owner2_valid_certs) == 1 - cn1_owner3_valid_certs = query_common_name(cn1, {"owner": "owner3@example.org", "page": "", "count": ""}) + cn1_owner3_valid_certs = query_common_name(cn1, {"owner": "owner3@example.org", "san": "", "page": "", "count": ""}) assert len(cn1_owner3_valid_certs) == 0 - cn2_valid_certs = query_common_name(cn2, {"owner": "", "page": "", "count": ""}) + cn2_valid_certs = query_common_name(cn2, {"owner": "", "san": "", "page": "", "count": ""}) assert len(cn2_valid_certs) == 1 +def test_query_san(session): + from lemur.tests.factories import CertificateFactory + from lemur.certificates.service import query_common_name + + san1 = "testsan1.example.org" + san2 = "testsan2.example.org" + + cert_one_san_valid = CertificateFactory() + cert_one_san_valid.domains = [Domain(name=san1)] + cert_one_san_valid.owner = "owner1@example.org" + + cert_two_san_valid = CertificateFactory() + cert_two_san_valid.domains = [Domain(name=san1), Domain(name=san2)] + cert_two_san_valid.owner = "owner2@example.org" + + san1_valid_certs = query_common_name('%', {"owner": "", "san": san1, "page": "", "count": ""}) + assert len(san1_valid_certs) == 2 + + san1_owner1_valid_certs = query_common_name('%', {"owner": "owner1@example.org", "san": san1, "page": "", "count": ""}) + assert len(san1_owner1_valid_certs) == 1 + + san1_valid_certs = query_common_name('%', {"owner": "", "san": san2, "page": "", "count": ""}) + assert len(san1_valid_certs) == 1 + + def test_reissue_certificate_with_duplicate_destinations_not_allowed(session, logged_in_user, crypto_authority,