# Prerequisites

In [None]:
from modules.elastic import ArticleSearchQuery
from modules.objects import FullArticle
from modules.config import BaseConfig

import logging
from dotenv import load_dotenv

logger = logging.getLogger("osinter")

# Load variables from a local .env file into the process environment before the
# config object is built (BaseConfig presumably reads its settings from the
# environment — TODO confirm).
load_dotenv()
config_options = BaseConfig()

In [None]:
import gzip
import json
import multiprocessing
import threading
from concurrent.futures import ThreadPoolExecutor

import tiktoken
from openai import (
    APIConnectionError,
    InternalServerError,
    OpenAI,
    OpenAIError,
    RateLimitError,
)
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
from tenacity import (
    RetryError,
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)

# Shared OpenAI API client used for every classification request in this notebook.
openai_client = OpenAI(api_key=config_options.OPENAI_KEY)

In [None]:
# Instruction sent as the first user message of every classification request.
# The same string is embedded verbatim in every fine-tuning example written
# below, so the fine-tuned model was presumably trained on exactly this wording.
# Keep it byte-for-byte unchanged (including the typos "qoutes"/"descripes")
# unless the model is fine-tuned again with the new text.
# Expected model answer: "2" = incident, "1" = not an incident.
instruction_prompt = """You are given a summary of a news article, surrounded by triple qoutes, and your only job is to state if the article describes a specific cyber security incident.
Clarifying the Definition of "Incident": Emphasize that the incident must be explicitly described with clear evidence of a specific event occurring, rather than potential or evaded threats.
THERE MUST BE IMPACT! NO POTENTIAL IMPACT! IT MUST HAVE HAPPENED! Use the CIA triad to decide on impact
You shall only return a number, and nothing else. If the article descripes an incident, you should return a 2, and if not return a 1"""


In [None]:
# Only transient failures are worth retrying: connection problems / timeouts,
# rate limiting (HTTP 429) and server-side errors (HTTP 5xx). Other OpenAI
# errors (bad request, authentication, permission, not found, ...) fail the same
# way on every attempt. With waits of up to an hour between 20 attempts,
# retrying them would block a worker for hours for nothing.
_RETRYABLE_OPENAI_ERRORS = (APIConnectionError, RateLimitError, InternalServerError)


@retry(
    wait=wait_random_exponential(min=1, max=3600),
    stop=stop_after_attempt(20),
    retry=retry_if_exception_type(_RETRYABLE_OPENAI_ERRORS),
    before_sleep=before_sleep_log(logger, logging.DEBUG),
)
def _create_incident_completion(messages: list[ChatCompletionMessageParam]) -> ChatCompletion:
    """Send ``messages`` to the fine-tuned incident classifier, retrying transient errors.

    Defined once at module level so the retry policy is not rebuilt on every call.
    """
    return openai_client.chat.completions.create(
        model="ft:gpt-3.5-turbo-1106:osinter-bertie:incident:8wsHrdRo",
        messages=messages,
        n=1,
        # NOTE(review): temperature=1 samples the answer, so repeated runs on the
        # same article may disagree; 0 would make classification deterministic.
        temperature=1,
        frequency_penalty=0,
        presence_penalty=0,
    )


def query_openai(prompts: list[ChatCompletionMessageParam]) -> str | None:
    """Return the text of the classifier's reply to ``prompts``.

    Args:
        prompts: Chat messages to send to the fine-tuned model.

    Returns:
        The reply's message content. Returns ``None`` when no reply could be
        obtained: either transient errors persisted through every retry, or the
        API rejected the request with a non-retryable error. The content itself
        may also be ``None`` if the API returned a message without text.
    """
    try:
        completion = _create_incident_completion(prompts)
    except RetryError:
        logger.warning("OpenAI request still failing after retries; giving up")
        return None
    except OpenAIError:
        logger.warning("OpenAI request failed with a non-retryable error", exc_info=True)
        return None
    return completion.choices[0].message.content

def is_incident(content: str) -> bool | str:
    """Ask the fine-tuned model whether ``content`` describes a cyber security incident.

    ``content`` is wrapped in triple quotes and sent after ``instruction_prompt``,
    mirroring the layout of the fine-tuning examples.

    Args:
        content: The article summary to classify.

    Returns:
        ``True`` if the model answered 2 (incident), ``False`` if it answered 1
        (not an incident). Any other reply is returned unchanged as a string so
        the caller can inspect it. If no reply could be obtained at all
        (``query_openai`` returned ``None``), an empty string is returned, so
        callers can keep treating every ``str`` result as a failure.
    """
    messages = [
        {"role": "user", "content": instruction_prompt},
        {"role": "user", "content": f'"""{content}"""'},
    ]

    response = query_openai(messages)
    if response is None:
        # int(None) would raise TypeError, which is not a ValueError and would
        # otherwise escape this function.
        return ""

    try:
        number = int(response)
    except ValueError:
        return response

    if number not in (1, 2):
        return response
    return number == 2
    

In [None]:
class Counter:
    """Thread-safe, monotonically increasing counter used to number work items.

    The worker functions in this notebook run on a ``ThreadPoolExecutor``, so a
    plain ``threading.Lock`` is all the synchronisation needed. A
    ``multiprocessing.Manager().Lock()`` would start a separate manager process
    just to hand out a lock. It would also not make ``count`` shared across
    processes, since ``count`` is an ordinary attribute.
    """

    def __init__(self) -> None:
        self.lock = threading.Lock()
        self.count = 0

    def get_count(self) -> int:
        """Increment the counter and return the new value (1 on the first call)."""
        with self.lock:
            self.count += 1
            return self.count

In [None]:
# Show which article index is configured, then load the articles and report how many there are.
index_name = config_options.ELASTICSEARCH_ARTICLE_INDEX
print(index_name)
articles = config_options.es_article_client.query_all_documents()
print(len(articles))

# For validation and finetuning

In [None]:
# Articles that already carry a label in `ml.incident` and have a summary
# (label 2 = incident, 1 = not an incident).
pre_classified_articles = [art for art in articles if art.ml.incident and art.summary]
incident = [art for art in pre_classified_articles if art.ml.incident == 2]
not_incident = [art for art in pre_classified_articles if art.ml.incident == 1]

# Balance the two classes by keeping the same number of examples from each.
# NOTE(review): this same `dataset` feeds both the fine-tuning file and the
# validation run further down, so validation accuracy is measured on training data.
shortest = min(len(incident), len(not_incident))
dataset = incident[:shortest] + not_incident[:shortest]

print(
    len(pre_classified_articles),
    len(incident),
    len(not_incident),
    len(dataset),
    {art.ml.incident for art in articles},
)

## For finetuning

In [None]:
# One chat-format training example per labelled article: the instruction, the
# quoted summary, and the expected answer ("1" or "2") as the assistant reply.
# Uses the loop variable `article` throughout (the old body referred to an
# undefined `i`).
prompts = [
    {
        "messages": [
            {"role": "user", "content": instruction_prompt},
            {"role": "user", "content": f'"""{article.summary}"""'},
            {"role": "assistant", "content": f"{article.ml.incident}"},
        ]
    }
    for article in dataset
]

# JSON Lines: one JSON object per line.
with open("finetuning.jsonl", "w", encoding="utf-8") as f:
    f.writelines(json.dumps(prompt) + "\n" for prompt in prompts)

## For validation

In [None]:
successes = []
failures = []
counter = Counter()


def validate_article(article):
    """Classify one labelled article and record whether the model agreed with its label.

    Appends the article to ``successes`` when ``is_incident`` returns a boolean
    that matches the stored label (``ml.incident == 2`` means incident).
    Otherwise appends ``(article, response)`` to ``failures``; non-boolean
    responses are the unparseable / missing replies.
    """
    count = counter.get_count()
    print(f"Starting {count}")
    response = is_incident(article.summary)

    expected = article.ml.incident == 2
    if isinstance(response, bool) and response == expected:
        successes.append(article)
    else:
        failures.append((article, response))

    print(f"Stopped {count}")


with ThreadPoolExecutor(max_workers=12) as executor:
    # Executor.map only re-raises a worker's exception when its result is
    # retrieved. Drain the iterator so errors surface here instead of the
    # affected articles silently disappearing from both lists.
    list(executor.map(validate_article, dataset))

In [None]:
# Break the failures down. The groups overlap: a formatting failure is also
# counted under the true label of its article.
formatting_fail = [fail for fail in failures if isinstance(fail[1], str)]
incident_fail = [fail for fail in failures if fail[0].ml.incident == 2]
not_incident_fail = [fail for fail in failures if fail[0].ml.incident == 1]

print(len(successes), len(failures), len(formatting_fail), len(incident_fail), len(not_incident_fail))

for article, _response in failures:
    # Use a distinct name here: binding `is_incident` would shadow the
    # classifier function and break every later call to it.
    label = "Incident" if article.ml.incident == 2 else "Not-incident"
    print(f"{label}: {article.summary}\n\n")

# For prediction

In [None]:
# Fresh progress counter, so this cell also works when the validation cell was not run.
counter = Counter()


def process_article(article):
    """Classify one article and store the boolean verdict on it.

    The prediction is written only when ``is_incident`` returns a boolean;
    unparseable or missing replies leave the article unchanged.
    """
    count = counter.get_count()
    print(f"Starting {count}")
    # NOTE: `is_incident` must still be the classifier function defined earlier;
    # make sure no previously-run cell has rebound that name.
    response = is_incident(article.summary)

    if isinstance(response, bool):
        # NOTE(review): the validation cells read labels from `article.ml.incident`,
        # while predictions are stored in `article.ml.classification.incident`;
        # confirm against the FullArticle model that this is the intended field.
        article.ml.classification.incident = response

    print(f"Stopped {count}")


with ThreadPoolExecutor(max_workers=12) as executor:
    # Drain the lazy result iterator so an exception in any worker is re-raised
    # here instead of being silently dropped.
    list(executor.map(process_article, articles))

In [None]:
# Save a gzip-compressed JSON snapshot of every article (JSON-mode model dumps).
with gzip.open("./classified.gz", "wt", encoding="utf-8") as archive:
    payload = [entry.model_dump(mode="json") for entry in articles]
    json.dump(payload, archive)

In [None]:
# Push the articles back through the Elasticsearch article client; the ["ml"]
# argument presumably limits the update to the `ml` field (TODO confirm).
article_client = config_options.es_article_client
article_client.update_documents(articles, ["ml"])