In [1]:
from google.colab import drive

In [2]:
# Mount Google Drive into the Colab VM at /content/drive.
# NOTE(review): no cell shown here reads from or writes to /content/drive
# (the ATT&CK cache is written to the working directory) — drop if unused.
drive.mount(mountpoint="/content/drive")

Mounted at /content/drive


In [11]:
import html
import json
import os
import re

import pandas as pd
import requests
import torch
from IPython.display import HTML, Markdown, display
from transformers import AutoModelForCausalLM, AutoTokenizer

In [12]:
# Column order of the cached ATT&CK technique table.
_ATTACK_COLUMNS = ["technique_id", "technique_name", "description", "tactic_phase"]


def _mitre_external_id(stix_object):
    """Return the ATT&CK ID (e.g. "T1566") from a STIX object's references, or "" if none."""
    # Pick the reference whose source is "mitre-attack" explicitly instead of
    # assuming it is first; an empty reference list must not raise IndexError.
    for reference in stix_object.get("external_references") or []:
        if reference.get("source_name") == "mitre-attack" and reference.get("external_id"):
            return reference["external_id"]
    return ""


def load_attack_data(
    attack_data_path="attack_data.csv",
    url="https://raw.githubusercontent.com/mitre/cti/master/enterprise-attack/enterprise-attack.json",
    timeout=60,
):
    """Load MITRE ATT&CK techniques, downloading the STIX bundle on first use.

    Args:
        attack_data_path: CSV cache. If it exists it is read and returned as-is.
        url: Location of the enterprise ATT&CK STIX bundle (JSON).
        timeout: Seconds to wait for the HTTP download before giving up.

    Returns:
        DataFrame with columns technique_id, technique_name, description and
        tactic_phase (comma-joined kill-chain phase names), one row per STIX
        "attack-pattern" object.

    Raises:
        requests.HTTPError: if the download returns an HTTP error status.
    """
    if os.path.exists(attack_data_path):
        return pd.read_csv(attack_data_path)

    print("Downloading MITRE ATT&CK data...")
    response = requests.get(url, timeout=timeout)
    # Surface HTTP failures (404, rate limiting, ...) clearly rather than as a
    # confusing JSON decode error from an HTML error page.
    response.raise_for_status()
    bundle = response.json()

    techniques = [
        {
            "technique_id": _mitre_external_id(obj),
            "technique_name": obj.get("name", ""),
            "description": obj.get("description", ""),
            "tactic_phase": ", ".join(
                phase.get("phase_name", "") for phase in obj.get("kill_chain_phases") or []
            ),
        }
        for obj in bundle.get("objects", [])
        if obj.get("type") == "attack-pattern"
    ]

    # Explicit columns keep the CSV header even when no techniques were found,
    # so the cache can always be read back by pd.read_csv.
    df = pd.DataFrame(techniques, columns=_ATTACK_COLUMNS)
    df.to_csv(attack_data_path, index=False)
    print(f"Saved {len(techniques)} techniques to {attack_data_path}")
    return df

In [13]:
def load_llm_model(model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=None):
    """Load a causal language model and its tokenizer.

    Args:
        model_name: Hugging Face Hub model id or local path.
        torch_dtype: dtype for the model weights. When None, float16 is used if
            a CUDA device is available and float32 otherwise. Half precision on
            CPU is slow, and some torch versions raise "not implemented for
            'Half'" for CPU ops.

    Returns:
        (model, tokenizer) on success, or (None, None) if loading failed. The
        error is printed rather than raised so the notebook can report it.
    """
    if torch_dtype is None:
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    print(f"Loading model: {model_name}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            device_map="auto",
        )
        print("Model loaded successfully")
        return model, tokenizer
    except Exception as e:  # best-effort: caller checks for (None, None)
        print(f"Error loading model: {e}")
        return None, None

In [14]:
# Instructions sent to the LLM; the text to analyze is appended after them.
_ATTACK_EXTRACTION_INSTRUCTIONS = """You are a cybersecurity expert specializing in MITRE ATT&CK framework. Analyze the following cyber threat intelligence text and extract all ATT&CK tactics and techniques mentioned explicitly or implicitly.

For each identified technique, provide:
1. The ATT&CK Technique ID (e.g., T1566)
2. The technique name (e.g., Phishing)
3. The tactic category (e.g., Initial Access)
4. A brief explanation of why this technique was identified in the text
5. The specific text excerpt that indicates this technique (if available)

Format your response as JSON with the structure:
{
    "techniques": [
        {
            "technique_id": "T1566",
            "technique_name": "Phishing",
            "tactic": "Initial Access",
            "explanation": "The report mentions phishing emails containing malicious attachments",
            "text_evidence": "Attackers sent targeted emails with PDF attachments containing exploits"
        }
    ]
}
Respond with the JSON object only.
"""


def _build_extraction_prompt(text):
    """Return the extraction instructions with the text to analyze appended."""
    return f"{_ATTACK_EXTRACTION_INSTRUCTIONS}\nText to analyze:\n{text}\n"


def _apply_chat_format(prompt, tokenizer):
    """Wrap the prompt in the tokenizer's chat template when it defines one."""
    if getattr(tokenizer, "chat_template", None):
        # NOTE(review): templates that embed BOS may get a second BOS when the
        # rendered string is tokenized below — verify for non-TinyLlama models.
        return tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
    return prompt


def _extract_json_object(response):
    """Return the first JSON object with a "techniques" key found in response.

    Markdown code fences are removed first. The object is re-serialized with
    json.dumps. If no such object exists, the fence-stripped text is returned
    so callers can show the raw model output.
    """
    cleaned = re.sub(r"```(?:json)?", "", response).strip()
    decoder = json.JSONDecoder()
    for brace in re.finditer(r"\{", cleaned):
        try:
            candidate, _ = decoder.raw_decode(cleaned, brace.start())
        except json.JSONDecodeError:
            continue
        if isinstance(candidate, dict) and "techniques" in candidate:
            return json.dumps(candidate)
    return cleaned


def extract_attack_info(text, model, tokenizer, max_new_tokens=1024):
    """Extract ATT&CK tactics and techniques from text using an LLM.

    Args:
        text: Threat-intelligence text to analyze.
        model: Causal LM exposing .device and .generate (e.g. from load_llm_model).
        tokenizer: Matching tokenizer.
        max_new_tokens: Generation budget for the answer.

    Returns:
        A JSON string {"techniques": [...]} when the model produced one,
        otherwise the model's raw answer text.
    """
    prompt = _apply_chat_format(_build_extraction_prompt(text), tokenizer)

    print("Analyzing text with the LLM model...")
    # Place inputs where the (possibly device_map-dispatched) model expects them.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    pad_token_id = getattr(tokenizer, "pad_token_id", None)
    if pad_token_id is None:
        pad_token_id = getattr(tokenizer, "eos_token_id", None)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,  # passes attention_mask along with input_ids
            max_new_tokens=max_new_tokens,
            temperature=0.1,
            do_sample=True,
            top_p=0.95,
            pad_token_id=pad_token_id,
        )

    # generate() returns prompt + continuation. Decode only the continuation:
    # the prompt itself contains an example JSON block that would otherwise be
    # matched and reported as if the model had found it.
    prompt_length = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return _extract_json_object(response)

In [15]:
def _split_into_chunks(document, chunk_size=2000, overlap=200, boundary_window=100):
    """Split text into overlapping chunks of at most chunk_size chars, preferring sentence ends."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    chunks = []
    length = len(document)
    start = 0
    while start < length:
        end = min(start + chunk_size, length)
        if end < length:
            # Pull the cut back to the LAST sentence end (., ! or ? followed by
            # whitespace) inside the final boundary_window chars. Clamping at
            # `start` avoids a negative slice when chunk_size < boundary_window.
            window_start = max(start, end - boundary_window)
            boundaries = list(re.finditer(r"[.!?]\s", document[window_start:end]))
            if boundaries:
                end = window_start + boundaries[-1].end()
        chunks.append(document[start:end])
        if end >= length:
            # Without this break, stepping back by `overlap` from the end of the
            # document would re-emit the tail chunk forever.
            break
        # Step back by `overlap` for shared context, but always make progress.
        next_start = end - max(overlap, 0)
        start = next_start if next_start > start else end
    return chunks


def _merge_technique_results(results):
    """Merge per-chunk JSON results, keeping the first entry for each technique ID."""
    merged = []
    seen_ids = set()
    for raw in results:
        try:
            data = json.loads(raw)
        except (TypeError, json.JSONDecodeError):
            continue  # chunk produced no parseable JSON; skip it
        if not isinstance(data, dict):
            continue
        for technique in data.get("techniques") or []:
            if not isinstance(technique, dict):
                continue
            tech_id = technique.get("technique_id")
            if not tech_id:
                continue
            # Normalize so "t1566 " and "T1566" count as the same technique.
            key = str(tech_id).strip().upper()
            if key not in seen_ids:
                seen_ids.add(key)
                merged.append(technique)
    return merged


def process_large_document(document, model, tokenizer, chunk_size=2000, overlap=200, extract_fn=None):
    """Split a large document into chunks, analyze each, and merge the techniques.

    Args:
        document: Full report text.
        model, tokenizer: Passed through to the extractor.
        chunk_size: Maximum characters per chunk.
        overlap: Characters shared between consecutive chunks, so that
            techniques described across a chunk boundary are not lost.
        extract_fn: Callable (chunk, model, tokenizer) -> JSON string. Defaults
            to extract_attack_info.

    Returns:
        JSON string {"techniques": [...]} deduplicated by technique_id, in
        first-seen order.
    """
    extract = extract_fn if extract_fn is not None else extract_attack_info
    chunks = _split_into_chunks(document, chunk_size, overlap)

    print(f"Processing document in {len(chunks)} chunks...")
    all_results = []
    for index, chunk in enumerate(chunks, start=1):
        print(f"Processing chunk {index}/{len(chunks)}...")
        all_results.append(extract(chunk, model, tokenizer))

    return json.dumps({"techniques": _merge_technique_results(all_results)})

In [16]:
def display_results(json_result):
    """Format and display the extracted techniques in a readable way"""
    try:
        data = json.loads(json_result)

        if "techniques" not in data or not data["techniques"]:
            display(Markdown("### No ATT&CK techniques identified in the text."))
            return

        html = "<div style='background-color:#f8f9fa; padding:15px; border-radius:5px;'>"
        html += "<h2>Extracted ATT&CK Tactics and Techniques</h2>"

        for technique in data["techniques"]:
            html += f"<div style='margin-bottom:20px; padding:10px; border-left:4px solid #007bff;'>"
            html += f"<h3>{technique.get('technique_id', 'Unknown ID')} - {technique.get('technique_name', 'Unknown Technique')}</h3>"
            html += f"<p><strong>Tactic:</strong> {technique.get('tactic', 'Unknown')}</p>"
            html += f"<p><strong>Explanation:</strong> {technique.get('explanation', 'No explanation provided')}</p>"
            html += f"<p><strong>Evidence:</strong><br><em>{technique.get('text_evidence', 'No specific evidence cited')}</em></p>"
            html += "</div>"

        html += "</div>"
        display(HTML(html))

    except json.JSONDecodeError:
        print("Failed to parse results as JSON. Raw output:")
        print(json_result)

In [17]:
def analyze_threat_report(
    threat_report,
    model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    model=None,
    tokenizer=None,
    chunk_size=2000,
):
    """Analyze a threat report, display the extracted ATT&CK techniques, and return them.

    Args:
        threat_report: Report text.
        model_name: Model to load when `model`/`tokenizer` are not supplied.
        model, tokenizer: An already-loaded pair (e.g. from load_llm_model).
            Pass them to avoid reloading the model on every call.
        chunk_size: Reports longer than this many characters are split into
            chunks of this size before analysis.

    Returns:
        The JSON result string, or None if the report is empty or the model
        could not be loaded.
    """
    if not threat_report or not threat_report.strip():
        print("Threat report is empty; nothing to analyze.")
        return None

    if model is None or tokenizer is None:
        model, tokenizer = load_llm_model(model_name)
        if model is None or tokenizer is None:
            print("Failed to load model. Please try a different model.")
            return None

    if len(threat_report) > chunk_size:
        print("Large document detected. Splitting into chunks for processing...")
        result = process_large_document(threat_report, model, tokenizer, chunk_size=chunk_size)
    else:
        result = extract_attack_info(threat_report, model, tokenizer)

    display_results(result)
    return result


In [18]:
# Load the ATT&CK technique catalog (cached to attack_data.csv after the first download).
attack_data = load_attack_data()

# Sample threat report
sample_report = """
APT29, also known as Cozy Bear, has been observed using spear-phishing emails with malicious attachments
to gain initial access to victim networks. After gaining access, the threat actor deployed custom malware
that uses DNS tunneling for command and control communications, effectively bypassing traditional
network security controls. The attackers moved laterally through the network using stolen credentials
and exploited Windows Management Instrumentation (WMI) to execute code remotely. They also modified
registry keys to establish persistence and scheduled tasks to maintain access.
Data was collected and compressed before being exfiltrated through encrypted channels.
"""

# Analyze the report
result = analyze_threat_report(sample_report)

# Small LLMs can hallucinate technique IDs, so cross-check every returned ID
# against the official catalog loaded above.
if result:
    known_ids = set(attack_data["technique_id"].dropna().astype(str))
    try:
        parsed = json.loads(result)
    except json.JSONDecodeError:
        parsed = {}
    extracted = (parsed.get("techniques") or []) if isinstance(parsed, dict) else []
    returned_ids = {
        str(technique["technique_id"]).strip()
        for technique in extracted
        if isinstance(technique, dict) and technique.get("technique_id")
    }
    unknown_ids = sorted(returned_ids - known_ids)
    if unknown_ids:
        print("Warning: technique IDs not found in the MITRE ATT&CK catalog:", ", ".join(unknown_ids))
    else:
        print(f"All {len(returned_ids)} returned technique IDs exist in the ATT&CK catalog.")

Downloading MITRE ATT&CK data...
Saved 799 techniques to attack_data.csv
Loading model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
Model loaded successfully
Analyzing text with the LLM model...
