In [None]:
import sys

# Install the third-party packages into the interpreter that runs this kernel
# (`sys.executable`), so they land in the same environment the notebook imports from.
# NOTE(review): versions are unpinned. The code below uses Pydantic v2 APIs
# (`TypeAdapter`, `model_validate_json`) and `datetime.UTC` (Python 3.11+).
!{sys.executable} -m pip install ratelimit requests tqdm pydantic

In [None]:
from __future__ import annotations

import hashlib
import logging
import os
import ssl
import time
from collections import defaultdict
from datetime import datetime, timedelta, UTC
from functools import wraps
from typing import Any, List, Optional

import ratelimit
import requests
import urllib3
from pydantic import BaseModel, Field, TypeAdapter
from requests import Response
from requests.adapters import HTTPAdapter
from tqdm import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm
from urllib3.util.retry import Retry

# Root level ERROR keeps third-party loggers quiet; this notebook's own logger
# is raised to INFO so its progress messages still show.
logging.basicConfig(level=logging.ERROR)
log = logging.getLogger("mdatp")
log.setLevel(logging.INFO)

# Credentials are read from the environment when set, so real secrets never have
# to be pasted into (and committed with) the notebook. The placeholders remain as
# fallbacks. The AZURE_* names follow the Azure SDK's EnvironmentCredential convention.
THREATAPI_KEY = os.environ.get("THREATAPI_KEY", "<your key>")

TENANT_ID = os.environ.get("AZURE_TENANT_ID", "<tenant_id>")
CLIENT_ID = os.environ.get("AZURE_CLIENT_ID", "<client_id>")
CLIENT_SECRET = os.environ.get("AZURE_CLIENT_SECRET", "<client_secret>")

In [None]:
# Models

class ThreatApiIndicator(BaseModel):
    """A single indicator record as delivered by the ThreatAPI feed."""

    # The feed names the observable "indicator"; it is exposed here as `value`.
    value: str = Field(alias="indicator")
    port: int
    last_seen: datetime
    threat: str
    threat_type: str
    # convert_threatapi understands "High" / "Medium" / "Low".
    confidence: str
    tlp: str
    # convert_threatapi understands "ipv4-addr" / "ipv6-addr" / "domain-name".
    type: str

class TiIndicator(BaseModel):
    """Microsoft Graph ``tiIndicator`` resource (only the fields this notebook uses).

    Attributes are snake_case; the camelCase Graph property names are the aliases
    used when validating responses and when dumping with ``by_alias=True``.
    All fields default to None.
    """

    # Observable fields; presumably exactly one is set per indicator.
    domain_name:    Optional[str]       = Field(default=None, alias="domainName")
    network_ipv4:   Optional[str]       = Field(default=None, alias="networkIPv4")
    network_ipv6:   Optional[str]       = Field(default=None, alias="networkIPv6")
    id:             Optional[str]       = Field(default=None)
    external_id:    Optional[str]       = Field(default=None, alias="externalId")
    action:         Optional[str]       = Field(default=None)
    description:    Optional[str]       = Field(default=None)
    target_product: Optional[str]       = Field(default=None, alias="targetProduct")
    threat_type:    Optional[str]       = Field(default=None, alias="threatType")
    tlp_level:      Optional[str]       = Field(default=None, alias="tlpLevel")
    confidence:     Optional[int]       = Field(default=None)
    is_active:      Optional[bool]      = Field(default=None, alias="isActive")
    expiration:     Optional[datetime]  = Field(default=None, alias="expirationDateTime")
    last_reported:  Optional[datetime]  = Field(default=None, alias="lastReportedDateTime")
    tags:           Optional[List[str]] = Field(default=None)

    @property
    def value(self) -> Optional[str]:
        """Primary IOC value.

        Returns the domain name, else the IPv4 address, else the IPv6 address;
        None when none of them is set.
        """

        return self.domain_name or self.network_ipv4 or self.network_ipv6

class ValueResponse(BaseModel):
    """OData collection envelope whose ``value`` holds tiIndicators."""

    # URL of the next result page, when the server paginates.
    next_link: Optional[str] = Field(
        default=None,
        alias="@odata.nextLink",
    )
    value: List[TiIndicator]

class ResultInfo(BaseModel):
    """Response body of the bulk tiIndicators delete actions: one entry per item."""

    class Error(BaseModel):
        """A single result entry (code, subcode and message)."""

        code: int
        subcode: int
        message: str

    value: List[Error]

class MSGraphError(BaseModel):
    """Error document returned by Microsoft Graph for a failed request."""

    class ErrorDetail(BaseModel):
        """The ``error`` object inside a Graph error response."""

        # Graph also sends an ``innerError`` object (date, request-id,
        # client-request-id). It is not modelled, so pydantic's default
        # ``extra="ignore"`` drops it during validation.
        code: str
        message: str

    # HTTP status of the failed response. It is not part of the JSON body and is
    # filled in after parsing (see MSGraphException.from_response).
    status_code: int = 0
    error: ErrorDetail

class MSGraphException(Exception):
    """Raised when a Microsoft Graph request returns a non-success status."""

    def __init__(self, error: MSGraphError) -> None:
        """Initializer

        Args:
            error: The parsed Graph error, with ``status_code`` already filled in.
        """

        # Give Exception a readable message so str(ex) and tracebacks are informative.
        super().__init__(f"HTTP {error.status_code} {error.error.code}: {error.error.message}")
        self.error = error

    @classmethod
    def from_response(cls, resp: Response) -> MSGraphException:
        """Create an exception from a failed response.

        If the body is not a Graph error document (e.g. an HTML page from a proxy
        or an empty body), a synthetic error is built from the status code and raw
        text, so the original HTTP failure is not masked by a parsing error.
        """

        try:
            error = MSGraphError.model_validate_json(resp.content)
        except ValueError:  # pydantic's ValidationError subclasses ValueError
            error = MSGraphError(
                error=MSGraphError.ErrorDetail(code=str(resp.status_code), message=resp.text)
            )
        error.status_code = resp.status_code

        return cls(error)
        

In [None]:
# Fetch the ThreatAPI feed and convert it to Graph tiIndicators

class ThreatApi:
    """Minimal client for the ThreatAPI indicator feed."""

    URL = "https://api.threatanalysis.io/all/1d/json"
    PARAMS = {"confidence": "high", "threat_type": "c2"}
    # Seconds to wait for the feed; requests applies no timeout by default.
    TIMEOUT = 60

    def __init__(self, api_key: str) -> None:
        """Initializer

        Args:
            api_key: Value sent in the ``x-api-key`` header.
        """

        self._api_key = api_key

    def get_indicators(self) -> List[ThreatApiIndicator]:
        """Fetch the feed and return every indicator record it contains.

        The JSON body is a two-level mapping (the inner level presumably keyed by
        port; confirm) whose leaves are lists of records. All levels are
        flattened into one list.

        Raises:
            requests.HTTPError: if the feed answers with a non-2xx status.
            pydantic.ValidationError: if a record does not match ThreatApiIndicator.
        """

        headers = {"x-api-key": self._api_key}
        resp = requests.get(self.URL, params=self.PARAMS, headers=headers, timeout=self.TIMEOUT)
        # Surface auth/quota/server errors here rather than failing later on an
        # unexpected JSON shape.
        resp.raise_for_status()
        data = resp.json()

        records = [
            record
            for by_port in data.values()
            for port_records in by_port.values()
            for record in port_records
        ]

        return TypeAdapter(List[ThreatApiIndicator]).validate_python(records)

# ThreatAPI TLP labels, most restrictive first, mapped onto the TLP 1.0 levels
# used for tiIndicator.tlpLevel.
_TLP_PRECEDENCE = (
    ("RED", "red"),
    ("AMBER+STRICT", "amber"),
    ("AMBER", "amber"),
    ("GREEN", "green"),
    ("CLEAR", "white"),
)

# ThreatAPI confidence labels, highest first, mapped onto a 0-100 score.
_CONFIDENCE_PRECEDENCE = (
    ("high", 90),
    ("medium", 60),
    ("low", 30),
)

# ThreatAPI indicator type -> TiIndicator field alias that receives the value.
_IOC_TYPE_FIELDS = {
    "ipv4-addr": "networkIPv4",
    "ipv6-addr": "networkIPv6",
    "domain-name": "domainName",
}


def _max_tlp(labels: set[str]) -> str:
    """Return the most restrictive TLP 1.0 level among ThreatAPI TLP labels."""

    normalized = {label.upper() for label in labels}
    for label, level in _TLP_PRECEDENCE:
        if label in normalized:
            return level
    # Unknown or missing labels fall back to the least restrictive level.
    return "white"


def _max_confidence(labels: set[str]) -> int:
    """Return the highest confidence score among ThreatAPI confidence labels (0 if none match)."""

    normalized = {label.lower() for label in labels}
    for label, score in _CONFIDENCE_PRECEDENCE:
        if label in normalized:
            return score
    return 0


def convert_threatapi(
    iocs: List[ThreatApiIndicator],
    expiration_delta: timedelta,
) -> List[TiIndicator]:
    """Convert ThreatAPI indicators to Graph tiIndicators.

    Records sharing the same value (e.g. one host seen on several ports) are
    merged into a single tiIndicator. The merged indicator carries:
    - the most restrictive TLP,
    - the highest confidence,
    - the newest ``last_seen``,
    - one description line per source record.
    Label matching is case-insensitive.

    Args:
        iocs: Records as returned by ``ThreatApi.get_indicators``.
        expiration_delta: Added to the merged ``last_seen`` to set ``expirationDateTime``.

    Returns:
        One TiIndicator per distinct indicator value, in first-seen order.

    Raises:
        ValueError: if a record's type has no tiIndicator field mapping.
    """

    # Group the source records by value (dict insertion order = first-seen order).
    groups = defaultdict(list)
    for record in iocs:
        groups[record.value].append(record)

    values = []
    for key, group in groups.items():
        # Every record in a group shares the same value, hence the same type.
        first = group[0]
        field = _IOC_TYPE_FIELDS.get(first.type)
        if field is None:
            raise ValueError(f"Unhandled IOC type: {first.type} ({first.value})")

        # Stable identifier for this value; md5 is used as a plain hash, not for security.
        external_id = hashlib.md5(key.encode("utf-8"), usedforsecurity=False).hexdigest()
        last_seen = max(record.last_seen for record in group)

        description = "IronRadar indicator of compromise.\n"
        description += "Threats: "
        description += "".join(
            f"\n- {record.threat} ({record.threat_type}) on port {record.port}"
            for record in group
        )

        values.append(
            TiIndicator(
                **{field: key},
                action="alert",
                externalId=external_id,
                description=description,
                targetProduct="Microsoft Defender ATP",
                threatType="WatchList",
                tlpLevel=_max_tlp({record.tlp for record in group}),
                confidence=_max_confidence({record.confidence for record in group}),
                isActive=True,
                lastReportedDateTime=last_seen,
                expirationDateTime=last_seen + expiration_delta,
                tags=["ironradar"],
            )
        )

    return values


# How long after an indicator's last sighting it should expire in Defender.
# NOTE(review): expiration is measured from last_seen, and the feed URL requests a
# 1d window. With a 1-hour delta, many indicators may already be expired when they
# are submitted; confirm this delta is intended.
EXPIRATION_DELTA = timedelta(hours=1)

threatapi = ThreatApi(THREATAPI_KEY)
THREAT_IOCS = threatapi.get_indicators()
log.info("len(THREAT_IOCS)=%d", len(THREAT_IOCS))

IOCS = convert_threatapi(THREAT_IOCS, EXPIRATION_DELTA)
log.info("len(IOCS)=%d", len(IOCS))

In [None]:
# Graph API

def sleep_and_retry(func):
    """Return a wrapped function that rescues rate limit exceptions and sleeps instead.

    Two cases are retried after sleeping:
    - ``ratelimit.RateLimitException`` (client-side limit): sleep for the
      period remaining, then retry.
    - ``MSGraphException`` for a throttled request (HTTP 429, or a message
      containing ``statusCode=429``): sleep 5 seconds, then retry.
    Any other exception, including non-throttling Graph errors, propagates
    to the caller.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        while True:
            try:
                return func(*args, **kwargs)
            except MSGraphException as ex:
                throttled = (
                    ex.error.status_code == 429
                    or "statusCode=429" in ex.error.error.message
                )
                if not throttled:
                    # Permanent errors (400/403/404, ...) must not be retried forever.
                    raise
                log.info("Rate limit reached, sleeping for 5")
                time.sleep(5)
            except ratelimit.RateLimitException as ex:
                period = ex.period_remaining
                log.info(f"Rate limit reached, sleeping for {period}")
                time.sleep(period)

    return wrapper

class TiIndicators:
    """MSGraph TiIndicators API

    Thin client over the ``security/tiIndicators`` endpoints of Microsoft Graph
    (``beta``), authenticated with the OAuth2 client-credentials grant. Every Graph
    call goes through ``_request``.
    """

    MS_LOGIN_URL = "https://login.microsoftonline.com/{tenant}/oauth2/token"
    MS_GRAPH_URL = "https://graph.microsoft.com"
    MS_GRAPH_VERSION = "beta"
    # Seconds before an HTTP call is abandoned; requests applies no timeout by default.
    REQUEST_TIMEOUT = 60

    def __init__(
        self,
        tenant_id: str,
        client_id: str,
        client_secret: str,
        proxy: bool = False
    ) -> None:
        """Initializer

        Args:
            tenant_id: Azure AD tenant, substituted into MS_LOGIN_URL.
            client_id: Application (client) ID.
            client_secret: Application secret.
            proxy: When True, mount an HTTPS adapter that allows legacy TLS
                renegotiation (presumably needed behind an intercepting proxy).
        """

        self._tenant_id = tenant_id
        self._client_id = client_id
        self._client_secret = client_secret
        self._session = requests.Session()
        # Serializer for payloads that mix plain containers and pydantic models.
        self._adapter = TypeAdapter(Any)

        self._setup_http_adapter(proxy)
        self._setup_reauthentication()

    def _setup_http_adapter(self, proxy: bool) -> None:
        """Mount the HTTPS adapter used by every request of the session.

        urllib3 retries HTTP 429 responses (only for its default method
        allow-list, which excludes POST and PATCH). With ``proxy``, the adapter
        also enables legacy TLS renegotiation.
        """

        # SSLError: [SSL: UNSAFE_LEGACY_RENEGOTIATION_DISABLED] unsafe legacy renegotiation disabled
        # ssl.OP_LEGACY_SERVER_CONNECT exists only on Python >= 3.12; 0x4 is OpenSSL's value.
        legacy_server_connect = getattr(ssl, "OP_LEGACY_SERVER_CONNECT", 0x4)

        class CustomHTTPAdapter(HTTPAdapter):
            def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
                ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
                ctx.options |= legacy_server_connect
                # Delegate to requests so it still records the pool settings it
                # relies on, and forwards any extra pool kwargs.
                super().init_poolmanager(
                    connections, maxsize, block=block, ssl_context=ctx, **pool_kwargs
                )

        retry = Retry(
            status_forcelist=[429],
            backoff_factor=5,
            total=5,
            # Hand back the last 429 response instead of raising RetryError, so
            # _request turns it into an MSGraphException that sleep_and_retry sees.
            raise_on_status=False,
        )

        adapter_cls = CustomHTTPAdapter if proxy else HTTPAdapter
        self._session.mount("https://", adapter_cls(max_retries=retry))

    def _setup_reauthentication(self) -> None:
        """Hook 401 responses from Graph to re-authenticate and replay the request once."""

        def refresh_token(resp: Response, *args: Any, **kwargs: Any) -> Optional[Response]:
            # 401, {'error': {'code': 'InvalidAuthenticationToken', ...}}
            if resp.status_code != 401:
                return None

            request = resp.request
            # Only Graph calls carry our bearer token. Skipping the login endpoint,
            # and requests that were already replayed, prevents unbounded
            # recursion when the credentials themselves are rejected.
            if not request.url.startswith(self.MS_GRAPH_URL):
                return None
            if getattr(request, "_reauthenticated", False):
                return None

            log.debug("Re-authenticating")
            self.authenticate()

            # Release the rejected response before replaying, as requests' auth handlers do.
            _ = resp.content
            resp.close()

            replay = request.copy()
            replay.headers["Authorization"] = self._session.headers["Authorization"]
            replay._reauthenticated = True
            # Forward the original send options (timeout, verify, proxies, ...).
            return self._session.send(replay, **kwargs)

        self._session.hooks["response"].append(refresh_token)

    def authenticate(self) -> None:
        """Authenticate the internal session with the client-credentials grant.

        Stores the returned access token as the session's bearer token.

        Raises:
            requests.HTTPError: if the token endpoint rejects the request.
        """

        resp = self._session.post(
            self.MS_LOGIN_URL.format(tenant=self._tenant_id),
            data={
                "client_id": self._client_id,
                "client_secret": self._client_secret,
                "resource": self.MS_GRAPH_URL,
                "grant_type": "client_credentials",
            },
            timeout=self.REQUEST_TIMEOUT,
        )
        if not resp.ok:
            # Include the identity provider's error body instead of failing with
            # a bare KeyError on "access_token".
            raise requests.HTTPError(
                f"Authentication failed ({resp.status_code}): {resp.text}",
                response=resp,
            )

        token = resp.json()["access_token"]
        self._session.headers.update({"Authorization": f"Bearer {token}"})

    def list_indicators(self, top: int = 1000, filter: Optional[str] = None) -> List[TiIndicator]:
        """List indicators (first result page only).

        Args:
            top: Page size sent as ``$top``; omitted when falsy.
            filter: OData ``$filter`` expression; omitted when falsy.

        Returns:
            The indicators on the first page. ``@odata.nextLink`` is not followed.
        """

        endpoint = "security/tiIndicators"
        params = {}

        if top:
            params["$top"] = str(top)
        if filter:
            params["$filter"] = filter

        log.debug("tiIndicators.list")
        resp = self._request("GET", endpoint, params=params)
        return ValueResponse.model_validate_json(resp.content).value

    def get_indicator(self, id: str) -> TiIndicator:
        """Get one indicator by its Graph id."""

        endpoint = f"security/tiIndicators/{id}"

        log.debug(f"tiIndicators.get: {id}")
        resp = self._request("GET", endpoint)
        return TiIndicator.model_validate_json(resp.content)

    def create_indicator(self, ioc: TiIndicator) -> TiIndicator:
        """Create one indicator and return the created resource."""

        endpoint = "security/tiIndicators"
        json_data = ioc.model_dump(mode="json", exclude_none=True, by_alias=True)

        log.debug(f"tiIndicators.create: {ioc.value}")
        resp = self._request("POST", endpoint, json=json_data)
        return TiIndicator.model_validate_json(resp.content)

    def update_indicator(self, ioc: TiIndicator) -> TiIndicator:
        """Update the indicator identified by ``ioc.id`` and return the updated resource."""

        endpoint = f"security/tiIndicators/{ioc.id}"
        json_data = ioc.model_dump(mode="json", exclude_none=True, by_alias=True)
        # Ask Graph to send the updated resource back so it can be parsed.
        headers = {"Prefer": "return=representation"}

        log.debug(f"tiIndicators.update: {ioc.value}")
        resp = self._request("PATCH", endpoint, json=json_data, headers=headers)
        return TiIndicator.model_validate_json(resp.content)

    def delete_indicator(self, id: str) -> None:
        """Delete one indicator by its Graph id."""

        endpoint = f"security/tiIndicators/{id}"

        log.debug(f"tiIndicators.delete: {id}")
        self._request("DELETE", endpoint)

    def bulk_create_indicators(self, iocs: List[TiIndicator]) -> List[TiIndicator]:
        """Create several indicators with one ``submitTiIndicators`` call."""

        log.debug("tiIndicators.bulkCreate")
        resp = self._bulk_post("security/tiIndicators/submitTiIndicators", iocs)
        return ValueResponse.model_validate_json(resp.content).value

    def bulk_update_indicators(self, iocs: List[TiIndicator]) -> List[TiIndicator]:
        """Update several indicators with one ``updateTiIndicators`` call."""

        log.debug("tiIndicators.bulkUpdate")
        resp = self._bulk_post("security/tiIndicators/updateTiIndicators", iocs)
        return ValueResponse.model_validate_json(resp.content).value

    def bulk_delete_indicators(self, iocs: List[TiIndicator]) -> List[ResultInfo.Error]:
        """Delete several indicators by their Graph ``id`` with one call."""

        log.debug("tiIndicators.bulkDelete")
        resp = self._bulk_post(
            "security/tiIndicators/deleteTiIndicators",
            [ioc.id for ioc in iocs],
        )
        return ResultInfo.model_validate_json(resp.content).value

    def bulk_delete_indicators_external(self, iocs: List[TiIndicator]) -> List[ResultInfo.Error]:
        """Delete several indicators by their ``externalId`` with one call."""

        log.debug("tiIndicators.bulkDeleteExternal")
        resp = self._bulk_post(
            "security/tiIndicators/deleteTiIndicatorsByExternalId",
            [ioc.external_id for ioc in iocs],
        )
        return ResultInfo.model_validate_json(resp.content).value

    def _bulk_post(self, endpoint: str, values: List[Any]) -> Response:
        """POST ``{"value": values}`` (camelCase aliases, None fields dropped) to a bulk action."""

        json_data = self._adapter.dump_python(
            {"value": values},
            mode="json",
            exclude_none=True,
            by_alias=True,
        )
        return self._request("POST", endpoint, json=json_data)

    @sleep_and_retry
    @ratelimit.limits(calls=50, period=61)
    def _request(self, op: str, endpoint: str, **kwargs: Any) -> Response:
        """Send a request to the Graph API.

        Limited client-side to 50 calls per 61 seconds. Exceeding that raises
        ``ratelimit.RateLimitException``, which ``sleep_and_retry`` handles by
        sleeping and retrying. A default timeout is applied unless the caller
        passes one.

        Raises:
            MSGraphException: for a non-2xx response (subject to sleep_and_retry).
        """

        url = f"{self.MS_GRAPH_URL}/{self.MS_GRAPH_VERSION}/{endpoint}"
        kwargs.setdefault("timeout", self.REQUEST_TIMEOUT)
        resp = self._session.request(op, url, **kwargs)
        if not resp.ok:
            raise MSGraphException.from_response(resp)
        return resp


# `proxy=True` mounts the adapter that allows legacy TLS renegotiation.
ti = TiIndicators(
    tenant_id=TENANT_ID,
    client_id=CLIENT_ID,
    client_secret=CLIENT_SECRET,
    proxy=True,
)
ti.authenticate()


In [None]:
def submit_indicators(iocs: List[TiIndicator]) -> None:
    """Create every given indicator through the module-level ``ti`` client.

    Each indicator is created with its own request, behind a "Creating"
    progress bar.
    """

    progress = tqdm(iocs, desc="Creating")
    for indicator in progress:
        ti.create_indicator(indicator)

def delete_expired_indicators() -> None:
    """Delete every indicator whose expiration is earlier than now (UTC).

    Repeatedly lists the first page of expired indicators through the
    module-level ``ti`` client and deletes them one by one. Stops when no
    expired indicators are listed, or when only ids it has already deleted
    come back.
    """

    # Graph filter wants an ISO-8601 UTC timestamp ending in "Z".
    # datetime.utcnow() is deprecated (Python 3.12); now(UTC) gives the same string.
    cutoff = datetime.now(UTC).isoformat().replace("+00:00", "Z")
    deleted_ids = set()

    while True:
        page = ti.list_indicators(filter=f"ExpirationDateTime lt {cutoff}")
        pending = [ioc for ioc in page if ioc.id not in deleted_ids]
        if not pending:
            if page:
                # The listing may lag behind deletions; stop rather than loop forever.
                log.warning("%d deleted indicators are still listed; stopping", len(page))
            break

        for ioc in tqdm(pending, desc="Deleting"):
            ti.delete_indicator(ioc.id)
            deleted_ids.add(ioc.id)

# Maximum number of converted indicators submitted in this run (presumably kept
# small for test runs). Set to None to submit all of IOCS.
MAX_SUBMIT: Optional[int] = 55

iocs = IOCS[:MAX_SUBMIT]
# Route log output through tqdm so it does not break the progress bars.
with logging_redirect_tqdm():
    submit_indicators(iocs)
    delete_expired_indicators()


