|
9 | 9 | from datetime import datetime |
10 | 10 | from httpx import Client |
11 | 11 |
|
12 | | -from domaintools.constants import FEEDS_PRODUCTS_LIST, OutputFormat, HEADER_ACCEPT_KEY_CSV_FORMAT |
| 12 | +from domaintools.constants import RTTF_PRODUCTS_LIST, OutputFormat, HEADER_ACCEPT_KEY_CSV_FORMAT |
13 | 13 | from domaintools.exceptions import ( |
14 | 14 | BadRequestException, |
15 | 15 | InternalServerErrorException, |
@@ -75,45 +75,45 @@ def _wait_time(self): |
75 | 75 |
|
76 | 76 | return wait_for |
77 | 77 |
|
78 | | - def _get_session_params(self): |
79 | | - parameters = deepcopy(self.kwargs) |
80 | | - parameters.pop("output_format", None) |
81 | | - parameters.pop( |
82 | | - "format", None |
83 | | - ) # For some unknownn reasons, even if "format" is not included in the cli params for feeds endpoint, it is being populated thus we need to remove it. Happens only if using CLI. |
| 78 | + def _get_session_params_and_headers(self): |
84 | 79 | headers = {} |
85 | | - if self.kwargs.get("output_format", OutputFormat.JSONL.value) == OutputFormat.CSV.value: |
86 | | - parameters["headers"] = int(bool(self.kwargs.get("headers", False))) |
87 | | - headers["accept"] = HEADER_ACCEPT_KEY_CSV_FORMAT |
88 | | - |
89 | | - header_api_key = parameters.pop("X-Api-Key", None) |
90 | | - if header_api_key: |
91 | | - headers["X-Api-Key"] = header_api_key |
| 80 | + parameters = deepcopy(self.kwargs) |
| 81 | + is_rttf_product = self.product in RTTF_PRODUCTS_LIST |
| 82 | + if is_rttf_product: |
| 83 | + parameters.pop("output_format", None) |
| 84 | + parameters.pop( |
| 85 | + "format", None |
| 86 | + ) # For some unknownn reasons, even if "format" is not included in the cli params for feeds endpoint, it is being populated thus we need to remove it. Happens only if using CLI. |
| 87 | + if self.kwargs.get("output_format", OutputFormat.JSONL.value) == OutputFormat.CSV.value: |
| 88 | + parameters["headers"] = int(bool(self.kwargs.get("headers", False))) |
| 89 | + headers["accept"] = HEADER_ACCEPT_KEY_CSV_FORMAT |
| 90 | + |
| 91 | + if self.api.header_authentication: |
| 92 | + header_key_for_api_key = "X-Api-Key" if is_rttf_product else "X-API-Key" |
| 93 | + headers[header_key_for_api_key] = parameters.pop("api_key", None) |
92 | 94 |
|
93 | 95 | return {"parameters": parameters, "headers": headers} |
94 | 96 |
|
95 | 97 | def _make_request(self): |
96 | 98 |
|
97 | 99 | with Client(verify=self.api.verify_ssl, proxy=self.api.proxy_url, timeout=None) as session: |
| 100 | + session_params_and_headers = self._get_session_params_and_headers() |
| 101 | + headers = session_params_and_headers.get("headers") |
98 | 102 | if self.product in [ |
99 | 103 | "iris-investigate", |
100 | 104 | "iris-enrich", |
101 | 105 | "iris-detect-escalate-domains", |
102 | 106 | ]: |
103 | 107 | post_data = self.kwargs.copy() |
104 | 108 | post_data.update(self.api.extra_request_params) |
105 | | - return session.post(url=self.url, data=post_data) |
| 109 | + return session.post(url=self.url, data=post_data, headers=headers) |
106 | 110 | elif self.product in ["iris-detect-manage-watchlist-domains"]: |
107 | 111 | patch_data = self.kwargs.copy() |
108 | 112 | patch_data.update(self.api.extra_request_params) |
109 | | - return session.patch(url=self.url, json=patch_data) |
110 | | - elif self.product in FEEDS_PRODUCTS_LIST: |
111 | | - session_params = self._get_session_params() |
112 | | - parameters = session_params.get("parameters") |
113 | | - headers = session_params.get("headers") |
114 | | - return session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params) |
| 113 | + return session.patch(url=self.url, json=patch_data, headers=headers) |
115 | 114 | else: |
116 | | - return session.get(url=self.url, params=self.kwargs, **self.api.extra_request_params) |
| 115 | + parameters = session_params_and_headers.get("parameters") |
| 116 | + return session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params) |
117 | 117 |
|
118 | 118 | def _get_results(self): |
119 | 119 | wait_for = self._wait_time() |
@@ -170,7 +170,7 @@ def status(self): |
170 | 170 |
|
171 | 171 | def setStatus(self, code, response=None): |
172 | 172 | self._status = code |
173 | | - if code == 200 or (self.product in FEEDS_PRODUCTS_LIST and code == 206): |
| 173 | + if code == 200 or (self.product in RTTF_PRODUCTS_LIST and code == 206): |
174 | 174 | return |
175 | 175 |
|
176 | 176 | reason = None |
|
0 commit comments