# Description of the Notebook

This Jupyter Notebook is a comprehensive workflow designed for machine learning tasks, particularly focusing on language modeling and cybersecurity applications. It integrates various tools and frameworks to streamline processes such as data preparation, model fine-tuning, and deployment. Below is an outline of its key components:

1. **Data Preparation**:
    - Extracts and formats datasets from external sources like CAPEC and Mitre CTI.
    - Prepares data for fine-tuning and retrieval-augmented generation (RAG).

2. **MAL Compiler Integration**:
    - Automates the setup and configuration of the MAL compiler.
    - Validates and refines generated Meta Attack Language (MAL) code.

3. **Mitre Data Processing**:
    - Processes STIX objects from the Mitre CTI repository.
    - Converts data into structured formats for use in RAG workflows.

4. **MAL Agent**:
    - Combines RAG and fine-tuned language models to generate valid MAL code.
    - Iteratively refines code based on compiler feedback for accuracy.

5. **LLM Fine-Tuning**:
    - Fine-tunes models like Mistral using LoRA (Low-Rank Adaptation).
    - Includes steps for dataset preparation, training, and saving fine-tuned models.

6. **Model Merging and Deployment**:
    - Merges LoRA adapters into base models for optimized inference.
    - Pushes final models to Hugging Face for seamless deployment.

7. **Inference and Testing**:
    - Provides pipelines for text generation and MAL code generation.
    - Tests fine-tuned models for specific use cases.

This notebook is tailored for researchers and developers working on advanced machine learning projects, with a focus on integrating domain-specific knowledge into language models for cybersecurity and related fields.

In [3]:
# Install dependencies
!pip install gradio langchain_huggingface torch numpy faiss-cpu sentence-transformers transformers peft datasets bitsandbytes accelerate xformers openai huggingface_hub langchain-core langchain-community stix2 torchvision torchaudio google-generativeai




In [7]:
from google.colab import userdata

username = "TP15" # Replace with your GitHub username
repo_name = "MAThesis-MALLM" # Replace with your repository name

pat = userdata.get('Github_MALLM')


repo_url_authenticated = f"https://{pat}@github.com/{username}/{repo_name}.git"


!git clone {repo_url_authenticated}


!ls {repo_name}


MALC_TAR_URL = "https://github.com/mal-lang/malc/releases/download/release%2F0.2.0/malc-0.2.0.linux.amd64.tar.gz"
MALC_TAR_FILE = "malc-0.2.0.linux.amd64.tar.gz"
EXTRACT_DIR = "malc_extracted"

!wget {MALC_TAR_URL}

!mkdir -p {EXTRACT_DIR}

!tar -xvzf {MALC_TAR_FILE} -C {EXTRACT_DIR}

!echo "Content of the DIR '{EXTRACT_DIR}':"
!ls -l {EXTRACT_DIR}

MALC_EXECUTABLE_PATH = f"./{EXTRACT_DIR}/malc-0.2.0.linux.amd64/malc"

!chmod +x {MALC_EXECUTABLE_PATH}


MALC_EXECUTABLE_PATH = "./malc_extracted/malc-0.2.0.linux.amd64/malc"
MAL_SOURCE_FILE = "/content/emailphininglang.mal"

!{MALC_EXECUTABLE_PATH} {MAL_SOURCE_FILE}

Cloning into 'MAThesis-MALLM'...
remote: Enumerating objects: 2108, done.[K
remote: Total 2108 (delta 0), reused 0 (delta 0), pack-reused 2108 (from 1)[K
Receiving objects: 100% (2108/2108), 57.42 MiB | 21.40 MiB/s, done.
Resolving deltas: 100% (621/621), done.
Evaluation   InterFace-Code    MALTool.py	      requirements.txt
finished_FT  LLM-Code	       output_files	      temp_input
HelperData   MAL_Thesis.ipynb  output_Toolevaluation  torch-env
--2025-05-30 19:55:23--  https://github.com/mal-lang/malc/releases/download/release%2F0.2.0/malc-0.2.0.linux.amd64.tar.gz
Resolving github.com (github.com)... 20.27.177.113
Connecting to github.com (github.com)|20.27.177.113|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/385922173/34fa9e88-3a07-45c6-b3cf-1471ad8c9a9e?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=releaseassetproduction%2F20250530%2Fus-east-1%2Fs3%2Faws4_request&X-Am

# RAG Building

## MitreDataprep

In [11]:

import os
import json
from stix2 import FileSystemSource, Filter
from typing import Dict, List, Set, Any, Optional

# --- Configuration ---

LOCAL_CTI_REPO_PATH = '/content/MAThesis-MALLM/HelperData/cti'
OUTPUT_RAG_DIR = "capec_rag_input_data"


DESIRED_CAPEC_TYPES: Set[str] = {
    "attack-pattern",
    "course-of-action",
}


def get_capec_id(stix_object: Dict[str, Any]) -> Optional[str]:
    """
    Extracts the CAPEC ID from a STIX object's external_references.

    Args:
        stix_object: A STIX object (as a dictionary or stix2 object).

    Returns:
        The CAPEC ID (e.g., "CAPEC-66") or None if not found.
    """
    if not hasattr(stix_object, 'external_references'):
        return None

    for ref in stix_object.external_references:
        if ref.get('source_name') == 'capec' and ref.get('external_id'):
            ext_id = ref['external_id']
            if isinstance(ext_id, int):
                 return f"CAPEC-{ext_id}"
            elif isinstance(ext_id, str):
                 return ext_id if ext_id.startswith("CAPEC-") else f"CAPEC-{ext_id}"
    return None

def save_for_rag(data_dict: Dict[str, List[Any]], output_dir: str):
    """
    Saves the extracted STIX objects into JSON files suitable for RAG input.

    Each object type gets its own JSON file containing a list of objects,
    where each object is converted to a standard Python dictionary.

    Args:
        data_dict: The dictionary containing lists of stix2 objects per type
                   (e.g., {'attack-pattern': [obj1, obj2], ...}).
        output_dir: The path to the directory where JSON files will be saved.
    """
    print(f"\nSaving data for RAG input into directory: {output_dir}")
    try:
        # Create the output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)
        print(f"Ensured output directory exists or created it.")
    except OSError as e:
        print(f"  Error creating directory {output_dir}: {e}")
        return

    for obj_type, object_list in data_dict.items():
        if not object_list:
            print(f"  Skipping type '{obj_type}': No objects found.")
            continue

        data_to_save = []
        print(f"  Processing {len(object_list)} objects of type '{obj_type}' for saving...")
        for stix_obj in object_list:
             try:
                 # serialize() gives a string, json.loads() makes it a dict
                 obj_dict = json.loads(stix_obj.serialize())
                 data_to_save.append(obj_dict)
             except Exception as e:
                 print(f"    Warning: Could not serialize object {getattr(stix_obj, 'id', 'N/A')}: {e}")


        if not data_to_save:
             print(f"  Skipping file for '{obj_type}': No objects could be serialized.")
             continue

        # Define the output filename
        file_name = f"{obj_type}_rag_data.json"
        file_path = os.path.join(output_dir, file_name)

        print(f"  Saving {len(data_to_save)} '{obj_type}' objects to {file_path}...")
        try:
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(data_to_save, f, indent=4, ensure_ascii=False)
            print(f"    Successfully saved.")
        except IOError as e:
            print(f"    Error saving file {file_path}: {e}")
        except TypeError as e:
             print(f"    Error during JSON serialization for {file_path}: {e}")


# --- Main Execution ---
if __name__ == "__main__":
    capec_data_path = os.path.join(LOCAL_CTI_REPO_PATH, 'capec', '2.1')

    print(f"Attempting to access CAPEC data in: {capec_data_path}")

    # --- Pre-check: Verify the path exists ---
    if not os.path.isdir(capec_data_path):
         print("\n--- ERROR ---")
         print(f"The specific STIX version directory was not found: {capec_data_path}")
         print(f"Please ensure the repository at '{LOCAL_CTI_REPO_PATH}' is complete and contains the 'capec/2.1/' structure.")
         print("-------------\n")
         exit(1)
    # --- End Pre-check ---

    fs = None
    capec_data: Dict[str, List[Any]] = {obj_type: [] for obj_type in DESIRED_CAPEC_TYPES}

    try:
        print(f"\nInitializing STIX FileSystemSource for directory: {capec_data_path}")
        fs = FileSystemSource(capec_data_path, allow_custom=True)
        print("FileSystemSource initialized successfully.")

        print("\nQuerying for desired object types...")
        for obj_type in DESIRED_CAPEC_TYPES:
            try:
                filt = Filter('type', '=', obj_type)
                objects = fs.query([filt])
                capec_data[obj_type] = objects
                print(f"  Found {len(objects)} objects of type '{obj_type}'")
            except Exception as e:
                 print(f"  Error querying for type '{obj_type}': {e}")

        if capec_data.get("attack-pattern"):
            print("\n--- Example: First CAPEC Attack Pattern ---")
            first_ap = capec_data["attack-pattern"][0]
            capec_id = get_capec_id(first_ap)
            print(f"  STIX ID: {first_ap.id}")
            print(f"  CAPEC ID: {capec_id or 'Not Found'}")
            print(f"  Name: {getattr(first_ap, 'name', 'N/A')}")
            print(f"  Description: {getattr(first_ap, 'description', 'N/A')[:150]}...")
            print(f"  Custom Abstraction: {getattr(first_ap, 'x_capec_abstraction', 'N/A')}")
            prereqs = getattr(first_ap, 'x_capec_prerequisites', [])
            print(f"  Custom Prerequisites count: {len(prereqs)}")
            if prereqs:
                print(f"    - Prerequisite 1: {prereqs[0][:100]}...")

        else:
            print("\nNo CAPEC attack-patterns found or extracted.")

        # --- 4. SAVE THE EXTRACTED DATA ---
        if any(capec_data.values()):
             save_for_rag(capec_data, OUTPUT_RAG_DIR)
        else:
             print("\nNo data loaded, skipping save step.")
        # --- End Save Step ---

    except Exception as e:
        print(f"\nAn error occurred during STIX processing: {e}")

Attempting to access CAPEC data in: /content/MAThesis-MALLM/HelperData/cti/capec/2.1

--- ERROR ---
The specific STIX version directory was not found: /content/MAThesis-MALLM/HelperData/cti/capec/2.1
Please ensure the repository at '/content/MAThesis-MALLM/HelperData/cti' is complete and contains the 'capec/2.1/' structure.
-------------


Initializing STIX FileSystemSource for directory: /content/MAThesis-MALLM/HelperData/cti/capec/2.1

An error occurred during STIX processing: directory path for STIX data does not exist: /content/MAThesis-MALLM/HelperData/cti/capec/2.1


CAPEC Dataprep

In [10]:
import json
import os

def extract_capec_id(external_references):
    if not external_references:
        return None
    for ref in external_references:
        if ref.get("source_name") == "capec" and "external_id" in ref:
            return ref["external_id"]
    return None

def transform_attack_pattern(ap_data):
    if not isinstance(ap_data, dict):
        print(f"Skipping invalid attack pattern data: {ap_data}")
        return None

    embedding_input = f"Name: {ap_data.get('name', 'N/A')}\nDescription: {ap_data.get('description', 'N/A')}"

    metadata = {
        "id": ap_data.get("id"),
        "type": ap_data.get("type"),
        "name": ap_data.get("name"),
        "capec_id": extract_capec_id(ap_data.get("external_references")),
        "abstraction": ap_data.get("x_capec_abstraction"),
        "domains": ap_data.get("x_capec_domains"),
        "status": ap_data.get("x_capec_status"),
        "version": ap_data.get("x_capec_version")
    }
    metadata = {k: v for k, v in metadata.items() if v is not None}

    return {
        "embedding_input": embedding_input,
        "source_type": "CAPEC",
        "metadata": metadata,
        "raw": ap_data
    }

def transform_course_of_action(coa_data):
    if not isinstance(coa_data, dict):
        print(f"Skipping invalid course of action data: {coa_data}")
        return None

    embedding_input = f"Name: {coa_data.get('name', 'N/A')}\nDescription: {coa_data.get('description', 'N/A')}"

    metadata = {
        "id": coa_data.get("id"),
        "type": coa_data.get("type"),
        "name": coa_data.get("name"),
        "version": coa_data.get("x_capec_version")
    }
    metadata = {k: v for k, v in metadata.items() if v is not None}

    return {
        "embedding_input": embedding_input,
        "source_type": "CAPEC "+ coa_data.get("type", ""),
        "metadata": metadata,
        "raw": coa_data
    }

def transform_and_combine_jsonl(attack_pattern_file, course_of_action_file, output_file):

    processed_count = 0
    error_count = 0

    output_dir = os.path.dirname(output_file)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print(f"Created output directory: {output_dir}")

    try:
        with open(output_file, 'w', encoding='utf-8') as outfile:

            # --- Process Attack Patterns ---
            print(f"Processing Attack Patterns from: {attack_pattern_file}")
            try:
                with open(attack_pattern_file, 'r', encoding='utf-8') as infile:
                    content = infile.read() # Read the whole file
                    try:

                        data_list = json.loads(content)
                        if not isinstance(data_list, list):
                            print(f"Error: Expected a JSON list in {attack_pattern_file}, but got {type(data_list)}")
                            error_count += 1
                        else:
                             # Iterate through items in the list
                            for original_data in data_list:
                                try:
                                    transformed_data = transform_attack_pattern(original_data)
                                    if transformed_data:
                                        json.dump(transformed_data, outfile, ensure_ascii=False)
                                        outfile.write('\n')
                                        processed_count += 1
                                    else:

                                        error_count += 1
                                except Exception as e:
                                    print(f"Error processing attack pattern item: {original_data.get('id', 'N/A')}. Error: {e}")
                                    error_count += 1

                    except json.JSONDecodeError as e:
                        print(f"Skipping invalid JSON file: {attack_pattern_file}. Error: {e}")
                        error_count += 1
            except FileNotFoundError:
                print(f"Error: Attack Pattern file not found at {attack_pattern_file}")
                error_count += 1
            except Exception as e:
                 print(f"An unexpected error occurred while processing {attack_pattern_file}: {e}")
                 error_count += 1


            # --- Process Courses of Action ---
            print(f"\nProcessing Courses of Action from: {course_of_action_file}")
            try:
                with open(course_of_action_file, 'r', encoding='utf-8') as infile:
                    content = infile.read()
                    try:
                        data_list = json.loads(content)
                        if not isinstance(data_list, list):
                           print(f"Error: Expected a JSON list in {course_of_action_file}, but got {type(data_list)}")
                           error_count += 1
                        else:
                            for original_data in data_list:
                                try:
                                    transformed_data = transform_course_of_action(original_data)
                                    if transformed_data:
                                        json.dump(transformed_data, outfile, ensure_ascii=False)
                                        outfile.write('\n')
                                        processed_count += 1
                                    else:
                                        error_count += 1
                                except Exception as e:
                                    print(f"Error processing course of action item: {original_data.get('id', 'N/A')}. Error: {e}")
                                    error_count += 1

                    except json.JSONDecodeError as e:
                        print(f"Skipping invalid JSON file: {course_of_action_file}. Error: {e}")
                        error_count += 1
            except FileNotFoundError:
                print(f"Error: Course of Action file not found at {course_of_action_file}")
                error_count += 1
            except Exception as e:
                 print(f"An unexpected error occurred while processing {course_of_action_file}: {e}")
                 error_count += 1

    except IOError as e:
        print(f"Error opening or writing to output file {output_file}: {e}")
        return

    print(f"\nTransformation complete.")
    print(f"Successfully processed and wrote {processed_count} entries to {output_file}")
    if error_count > 0:
        print(f"Encountered {error_count} errors or skipped entries/files.")


# --- MODIFIED Example Usage ---
if __name__ == "__main__":
    print("--- Starting Transformation ---")
    transform_and_combine_jsonl(
        attack_pattern_file="MAThesis-MALLM/LLM-Code/RAG/RAG-DataPrep/capec_rag_input_data/attack-pattern_rag_data.json",
        course_of_action_file="MAThesis-MALLM/LLM-Code/RAG/RAG-DataPrep/capec_rag_input_data/course-of-action_rag_data.json",
        output_file="MAThesis-MALLM/LLM-Code/RAG/RAG-DataPrep/capec_rag_input_data/capec_combined_rag_data.jsonl"
    )
    print("--- Transformation Finished ---")

--- Starting Transformation ---
Processing Attack Patterns from: MAThesis-MALLM/LLM-Code/RAG/RAG-DataPrep/capec_rag_input_data/attack-pattern_rag_data.json

Processing Courses of Action from: MAThesis-MALLM/LLM-Code/RAG/RAG-DataPrep/capec_rag_input_data/course-of-action_rag_data.json

Transformation complete.
Successfully processed and wrote 1499 entries to MAThesis-MALLM/LLM-Code/RAG/RAG-DataPrep/capec_rag_input_data/capec_combined_rag_data.jsonl
--- Transformation Finished ---


Final RAG Creation with CAPEC as input

In [12]:
import os
import json
import sys
import traceback
from langchain_core.documents import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from typing import Dict, List, Any, Optional

# --- Configuration ---

INPUT_JSONL_FILE = "MAThesis-MALLM/LLM-Code/RAG/RAG-DataPrep/capec_rag_input_data/capec_combined_rag_data.jsonl"

FAISS_INDEX_PATH = "MAThesis-MALLM/LLM-Code/RAG/RAG-DataPrep/capec_faiss_index"

EMBEDDING_MODEL_NAME = "all-mpnet-base-v2"

# --- End Configuration ---

def load_docs_from_jsonl(jsonl_file_path: str) -> List[Document]:


    all_docs: List[Document] = []
    print(f"Starting data loading from JSONL file: {jsonl_file_path}")

    if not os.path.exists(jsonl_file_path):
        print(f"Error: Input file not found: {jsonl_file_path}")
        return []

    try:
        with open(jsonl_file_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue

                try:
                    entry = json.loads(line)

                    page_content = entry.get("embedding_input")
                    metadata = entry.get("metadata", {})
                    source_type = entry.get("source_type")

                    if not page_content:
                        print(f"Warning: Skipping line {line_num} due to missing 'embedding_input'.")
                        continue
                    if not metadata:
                         print(f"Warning: Line {line_num} has missing 'metadata'. Using empty metadata.")
                    if source_type:
                        metadata['source_type'] = source_type
                    else:
                        print(f"Warning: Line {line_num} has missing 'source_type'. It won't be added to metadata.")


                    # --- Create LangChain Document ---
                    doc = Document(page_content=page_content, metadata=metadata)
                    all_docs.append(doc)

                except json.JSONDecodeError:
                    print(f"Warning: Skipping invalid JSON on line {line_num}: {line[:100]}...")
                except Exception as e:
                    print(f"Warning: Error processing line {line_num}: {e}. Data: {line[:100]}...")

    except IOError as e:
        print(f"Error reading file {jsonl_file_path}: {e}")
        return []
    except Exception as e:
         print(f"An unexpected error occurred during file processing: {e}")
         return []


    print(f"Finished loading. Total documents prepared: {len(all_docs)}")
    return all_docs


# --- Main Execution ---
if __name__ == "__main__":
    print("--- Starting RAG Phase 1: Indexing Combined MAL/CAPEC Data ---")
    print(f"Input JSONL file: {INPUT_JSONL_FILE}")
    print(f"Vector store persistence directory: {FAISS_INDEX_PATH}")
    print(f"Using embedding model: {EMBEDDING_MODEL_NAME}")

    # 1. Load documents from the single JSONL file
    documents = load_docs_from_jsonl(INPUT_JSONL_FILE)

    if not documents:
        print("\nNo documents were loaded. Please check the input JSONL file exists and contains valid data. Exiting.")
        sys.exit(1) # Use sys.exit for clearer exit status

    # 2. Initialize embedding model
    print(f"\nInitializing embedding model '{EMBEDDING_MODEL_NAME}'...")
    # model_kwargs = {'device': 'cpu'} # Uncomment to force CPU if needed
    encode_kwargs = {'normalize_embeddings': False}
    try:
        embeddings = HuggingFaceEmbeddings(
            model_name=EMBEDDING_MODEL_NAME,
            # model_kwargs=model_kwargs,
            encode_kwargs=encode_kwargs
        )
        print("Embedding model initialized successfully.")
    except Exception as e:
        print(f"Error initializing embedding model: {e}")
        print("Make sure 'sentence-transformers' and potentially 'torch' are installed correctly.")
        sys.exit(1)

    # 3. Create and persist the FAISS vector store
    print(f"\nCreating FAISS index and saving to: {FAISS_INDEX_PATH}")
    try:
        # Calculate embeddings and create FAISS index
        vectorstore = FAISS.from_documents(
            documents=documents,
            embedding=embeddings
        )

        # Save the index and document store locally
        vectorstore.save_local(folder_path=FAISS_INDEX_PATH)

        print(f"\n--- Success! ---")
        print(f"FAISS index created and saved successfully in '{FAISS_INDEX_PATH}'.")
        print(f"Total documents indexed: {len(documents)}")

        # Optional: Simple test query requires loading the index first
        print("\nPerforming a quick test query (loading from disk)...")
        if not os.path.exists(FAISS_INDEX_PATH):
             print(f"  Error: Saved index path '{FAISS_INDEX_PATH}' not found for testing.")
        else:
            try:
                # Load the persisted index for testing
                loaded_vectorstore = FAISS.load_local(
                    FAISS_INDEX_PATH,
                    embeddings,
                    allow_dangerous_deserialization=True # Required by recent LangChain versions
                )
                # Example query - adjust based on your data (MAL or CAPEC)
                test_query = "Hardware supply chain attack description"
                results = loaded_vectorstore.similarity_search(test_query, k=1)

                if results:
                    print(f"  Test query '{test_query}' found result:")
                    # Access metadata directly from the loaded document's metadata dict
                    meta = results[0].metadata
                    source_type = meta.get('source_type', 'N/A')
                    doc_name = meta.get('name', meta.get('mal_type', 'N/A'))
                    capec_id = meta.get('capec_id', 'N/A')
                    print(f"    Source Type: {source_type}")
                    print(f"    Name/Type: {doc_name}")
                    if capec_id != 'N/A':
                         print(f"    CAPEC ID: {capec_id}")
                    print(f"    Content Snippet: {results[0].page_content[:150]}...")
                else:
                    print(f"  Test query '{test_query}' returned no results.")
            except Exception as e:
                print(f"  Error during test query loading/execution: {e}")
                traceback.print_exc()


    except Exception as e:
        print(f"\n--- Error ---")
        print(f"An error occurred during FAISS index creation/saving: {e}")
        traceback.print_exc()

    print("\n--- RAG Phase 1 Finished (Using FAISS) ---")

--- Starting RAG Phase 1: Indexing Combined MAL/CAPEC Data ---
Input JSONL file: MAThesis-MALLM/LLM-Code/RAG/RAG-DataPrep/capec_rag_input_data/capec_combined_rag_data.jsonl
Vector store persistence directory: MAThesis-MALLM/LLM-Code/RAG/RAG-DataPrep/capec_faiss_index
Using embedding model: all-mpnet-base-v2
Starting data loading from JSONL file: MAThesis-MALLM/LLM-Code/RAG/RAG-DataPrep/capec_rag_input_data/capec_combined_rag_data.jsonl
Finished loading. Total documents prepared: 1499

Initializing embedding model 'all-mpnet-base-v2'...


  embeddings = HuggingFaceEmbeddings(


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.4k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Embedding model initialized successfully.

Creating FAISS index and saving to: MAThesis-MALLM/LLM-Code/RAG/RAG-DataPrep/capec_faiss_index

--- Success! ---
FAISS index created and saved successfully in 'MAThesis-MALLM/LLM-Code/RAG/RAG-DataPrep/capec_faiss_index'.
Total documents indexed: 1499

Performing a quick test query (loading from disk)...
  Test query 'Hardware supply chain attack description' found result:
    Source Type: CAPEC
    Name/Type: Malicious Hardware Component Replacement
    CAPEC ID: CAPEC-522
    Content Snippet: Name: Malicious Hardware Component Replacement
Description: An adversary replaces legitimate hardware in the system with faulty counterfeit or tampere...

--- RAG Phase 1 Finished (Using FAISS) ---


# MAL Agent
This section introduces the MAL Agent, a pipeline designed to generate valid Meta Attack Language (MAL) code using a combination of Retrieval-Augmented Generation (RAG) and a fine-tuned language model. The agent integrates context retrieval from a FAISS-based vector store, LLM inference for code generation, and a MAL compiler for validation. The workflow includes iterative refinement of generated code based on compiler feedback to ensure correctness.

In [None]:
import google.generativeai as genai
from google.colab import userdata
from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any
import os
import json

# --- Schema Helper Functions ---

def _inline_refs_recursive(current_obj: Any, root_definitions: Dict[str, Any]) -> Any:
    if isinstance(current_obj, dict):
        if '$ref' in current_obj:
            ref_key = current_obj['$ref']
            if ref_key.startswith('#/$defs/'):
                def_name = ref_key.split('/')[-1]
                if def_name in root_definitions:
                    definition_copy = root_definitions[def_name].copy()
                    return _inline_refs_recursive(definition_copy, root_definitions)
                else:
                    raise ValueError(f"Unresolved reference: {ref_key}. Definition '{def_name}' not found in $defs.")
            else:
                raise ValueError(f"Unsupported $ref format: {ref_key}. Expected format like '#/$defs/ModelName'.")
        new_dict = {}
        for key, value in current_obj.items():
            new_dict[key] = _inline_refs_recursive(value, root_definitions)
        return new_dict
    elif isinstance(current_obj, list):
        return [_inline_refs_recursive(item, root_definitions) for item in current_obj]
    else:
        return current_obj

def get_inlined_schema(pydantic_model_class: type[BaseModel]) -> Dict[str, Any]:
    schema = pydantic_model_class.model_json_schema()
    if '$defs' not in schema or not schema['$defs']:
        if '$defs' in schema: del schema['$defs']
        return schema
    root_definitions = schema['$defs']
    schema_main_part = {k: v for k, v in schema.items() if k != '$defs'}
    return _inline_refs_recursive(schema_main_part, root_definitions)

def remove_key_recursive(obj: Any, key_to_remove: str) -> Any:
    if isinstance(obj, dict):
        new_dict = {}
        for k, v in obj.items():
            if k == key_to_remove: continue
            new_dict[k] = remove_key_recursive(v, key_to_remove)
        return new_dict
    elif isinstance(obj, list):
        return [remove_key_recursive(item, key_to_remove) for item in obj]
    else:
        return obj

def simplify_optional_anyof_recursive(schema_node: Any) -> Any:

    if isinstance(schema_node, dict):
        if "anyOf" in schema_node and isinstance(schema_node["anyOf"], list):
            parts = schema_node["anyOf"]

            if len(parts) == 2:
                null_part_found = False
                non_null_part_schema = None

                for part in parts:
                    if isinstance(part, dict) and part.get("type") == "null":
                        null_part_found = True
                    elif isinstance(part, dict):
                        if non_null_part_schema is None:
                           non_null_part_schema = part
                        else:
                           non_null_part_schema = None
                           break

                if null_part_found and non_null_part_schema is not None:
                    new_node = non_null_part_schema.copy()
                    new_node["nullable"] = True # Add nullable property

                    for key, value in schema_node.items():
                        if key != "anyOf" and key not in new_node:
                            new_node[key] = value

                    return simplify_optional_anyof_recursive(new_node)

        return {key: simplify_optional_anyof_recursive(value) for key, value in schema_node.items()}

    elif isinstance(schema_node, list):
        return [simplify_optional_anyof_recursive(item) for item in schema_node]

    return schema_node

# --- Pydantic Models ---
class Meta(BaseModel):
    language_version: str = Field(description="Version for.")
    language_name: str = Field(description="Ein Name für das spezifische Modell, das erstellt wird.")
    description: Optional[str] = Field(description="Eine kurze Beschreibung des Modells.")

class TTCDistribution(BaseModel):
    type: str = Field(description="Typ der Verteilung (z.B. 'Deterministic', 'Gamma', 'Bernoulli').")
    value: Optional[float] = Field(description="Wert für deterministische Verteilung.")
    shape: Optional[float] = Field(description="Shape-Parameter für Gamma-Verteilung.")
    scale: Optional[float] = Field(description="Scale-Parameter für Gamma-Verteilung.")
    probability: Optional[float] = Field(description="Wahrscheinlichkeit für Bernoulli-Verteilung.")

class ReachesTarget(BaseModel):
    target_attack_step_name: str = Field(description="Name des Ziel-Angriffs-Schritts auf dem Ziel-Asset.")
    target_asset_role: Optional[str] = Field(description="Optionale Rolle des Ziel-Assets, über die navigiert wird (z.B. 'executees').")

class AttackStep(BaseModel):
    attackstep_name: str = Field(description="Name des Angriffs-Schritts (z.B. 'compromise', 'dictionaryCrack').")
    attackstep_type: str = Field(description="Typ des Angriffs-Schritts: 'AND' (alle Voraussetzungen müssen erfüllt sein) oder 'OR' (eine Voraussetzung genügt).")
    ttc_distribution: Optional[TTCDistribution] = Field(description="Optionale TTC Verteilung für diesen Schritt.")
    reaches: Optional[List[ReachesTarget]] = Field(description="Liste der Ziele, die durch diesen Angriffs-Schritt erreichbar werden.")

class Defense(BaseModel):
    defense_name: str = Field(description="Name des Verteidigungsmechanismus (z.B. 'encrypted').")
    reaches_if_false: Optional[List[ReachesTarget]] = Field(description="Ziele, die erreichbar werden, wenn diese Verteidigung in einer Instanz 'false' ist.")

class AssetDefinition(BaseModel):
    asset_name: str = Field(description="Name des Asset-Typs (z.B. 'Machine', 'Software').")
    is_abstract: bool = Field(description="Gibt an, ob dieser Asset-Typ abstrakt ist.")
    extends: Optional[str] = Field(description="Name des Asset-Typs, von dem dieser Typ erbt (falls vorhanden, sonst null).")
    attack_steps: Optional[List[AttackStep]] = Field(description="Liste der Angriffs-Schritte für diesen Asset-Typ.")
    defenses: Optional[List[Defense]] = Field(description="Liste der Verteidigungsmechanismen für diesen Asset-Typ.")

class AssociationDefinition(BaseModel):
    association_name: str = Field(description="Name zur Identifizierung dieser Assoziationsart (z.B. 'Execution', 'Communication').")
    end1_asset_type: str = Field(description="Name des Asset-Typs am einen Ende der Assoziation.")
    end1_role_name: str = Field(description="Rollenname für das erste Ende (z.B. 'executor').")
    end1_multiplicity: str = Field(description="Multiplizität am ersten Ende (z.B. '1', '*', '0..1', '2-*').")
    end2_asset_type: str = Field(description="Name des Asset-Typs am anderen Ende der Assoziation.")
    end2_role_name: str = Field(description="Rollenname für das zweite Ende (z.B. 'executees').")
    end2_multiplicity: str = Field(description="Multiplizität am zweiten Ende.")

class MALModel(BaseModel):
    meta: Meta
    asset_definitions: List[AssetDefinition]
    association_definitions: List[AssociationDefinition]
# --- End of Pydantic Models ---

def generate_mal_model_from_text(prompt_text: str, api_key: Optional[str] = None):
    if api_key:
        genai.configure(api_key=api_key)
    elif not os.getenv("GOOGLE_API_KEY"):
        print("API key not configured. Please set GOOGLE_API_KEY or pass api_key argument.")
        return None

    model = genai.GenerativeModel(model_name='gemini-1.5-flash-latest')

    print(f"Sending prompt to Gemini: \"{prompt_text}\"")

    print("Generating, inlining, and cleaning schema for MALModel...")
    try:
        # Schema processing pipeline:
        # 1. Pydantic generates schema (get_inlined_schema calls model_json_schema)
        # 2. Inline $refs
        processed_schema = get_inlined_schema(MALModel)
        # 3. Remove "title" fields
        processed_schema = remove_key_recursive(processed_schema, "title")
        # 4. Convert "anyOf" for Optionals to "nullable: true"
        final_schema_for_api = simplify_optional_anyof_recursive(processed_schema)

        # For debugging the schema being sent:
        # print("\nFinal schema to be sent to Gemini (after all processing):")
        # print(json.dumps(final_schema_for_api, indent=2))

    except Exception as e:
        print(f"Error preparing schema: {e}")
        return None

    print("Expecting a response structured according to MALModel...")

    try:
        response = model.generate_content(
            contents=prompt_text,
            generation_config=genai.types.GenerationConfig(
                response_mime_type="application/json",
                response_schema=final_schema_for_api, # Use the fully processed schema
            )
        )

        if response.text:
            print("\nSuccessfully received JSON string from Gemini.")
            try:
                parsed_model = MALModel.model_validate_json(response.text)
                print("\nParsed MALModel object (from response.text):")
                print(parsed_model.model_dump_json(indent=2))
                return parsed_model
            except Exception as e:
                print(f"\nError parsing the response from Gemini with Pydantic: {e}")
                print("Raw response text was:")
                print(response.text)
                return None
        else:
            print("\nGemini returned an empty response text.")
            if hasattr(response, 'prompt_feedback') and response.prompt_feedback:
                print(f"Prompt Feedback: {response.prompt_feedback}")
            if hasattr(response, 'candidates') and response.candidates and len(response.candidates) > 0:
                 candidate = response.candidates[0]
                 if hasattr(candidate, 'finish_reason') and candidate.finish_reason:
                     print(f"Finish Reason: {candidate.finish_reason}")
                 if hasattr(candidate, 'safety_ratings') and candidate.safety_ratings:
                     print(f"Safety Ratings: {candidate.safety_ratings}")
            else:
                print("No candidates in response.")
            return None

    except Exception as e:
        print(f"\nAn error occurred during the Gemini API call: {e}")
        if hasattr(e, 'response') and e.response:
            try:
                print(f"Error details from API (if available): {e.response.text}")
            except:
                print(f"Error details from API (if available, non-text): {e.response}")
        return None

if __name__ == "__main__":
    my_api_key = userdata.get('GEM_TOKEN2')

    if not my_api_key:
        print("Error: GOOGLE_API_KEY not found in environment variables.")
        print("Please set your API key to run the Gemini API call.")
    else:
        example_prompt = (
            "Generate a MAL (Meta Attack Language) model definition for a very simple IT system. "
            "The system should include: "
            "1. A 'User' asset that can 'login' (an attack step). "
            "2. A 'Workstation' asset that the User uses, which can be 'exploited' (an attack step). "
            "3. A 'Server' asset that the Workstation connects to, which has a 'dataBreach' attack step. "
            "Include basic metadata: language_version 'SimpleMAL_v0.1', language_name 'MySimpleSystem', description 'A basic example MAL model.'. "
            "Define an association where 'User' (role 'userRef') 'uses' 'Workstation' (role 'usedMachines') with multiplicity '1' to '0..*'. "
            "Define another association where 'Workstation' (role 'clientWorkstations') 'connectsTo' 'Server' (role 'connectedServer') with multiplicity '*' to '1'. "
            "The 'exploited' attack step on 'Workstation' should have a 'reaches' property targeting the 'dataBreach' attack step on the 'Server' asset, using the association role 'connectedServer'. "
            "For the 'login' attack step on 'User', set attackstep_type to 'OR'. For 'exploited' on 'Workstation', attackstep_type 'OR'. For 'dataBreach' on 'Server', attackstep_type 'AND'. "
            "Do not include TTC distributions or defenses."
        )

        print("Attempting to generate MAL model with Gemini...")
        generated_model = generate_mal_model_from_text(example_prompt, api_key=my_api_key)

        if generated_model:
            print("\n--- Successfully generated and parsed MALModel from Gemini ---")
        else:
            print("\n--- Failed to generate or parse MALModel from Gemini ---")

In [16]:
#MAL + RAG
import google.generativeai as genai
from google.colab import userdata
from enum import Enum
from pydantic import BaseModel, Field, model_validator
from typing import Optional, List, Dict, Any
import os
import json
import sys
import traceback
from collections import defaultdict
from langchain_core.documents import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS


INITIAL_PROMPT_TEMPLATE = """
You are a structured language model expert tasked with generating a valid and complete MALModel JSON object. This JSON structure defines a domain-specific threat modeling language built with the Meta Attack Language (MAL). Your job is to take a scenario (e.g., a system, threat, or attack path) and output a JSON object adhering to the following schema.

Each field must be filled with accurate, semantically meaningful data. Below is your generation guide, field by field:

meta
Purpose: High-level metadata for the MAL DSL.

language_version:
A string representing the version of the DSL, e.g., "1.0.0".
Best practice: Use semantic versioning (MAJOR.MINOR.PATCH).

language_name:
A short, unique identifier for this MAL-based DSL.
Example: "industrial.control.language" or "com.example.webapp".

description:
A human-readable summary of the language's domain and purpose.
Example: "A DSL to model and simulate attacks on industrial control systems".

asset_definitions
Purpose: Defines all asset types used in the domain model. Each asset may contain attack steps and defenses.

Each AssetDefinition object includes:
asset_name:
The asset's type (e.g., "Machine", "UserAccount", "Database"). Must be unique.

is_abstract:
true if this is an abstract base type not directly instantiable. Otherwise false.

extends:
Optional name of a parent asset this one inherits from (must match another asset_name). Inherits attack steps and associations.

attack_steps:
List of AttackStep objects (optional, but encouraged). Define what an attacker can do on this asset.

defenses:
List of Defense objects (optional). Represent security controls that can block propagation.

attack_steps (inside each AssetDefinition)
Purpose: Describe what attack actions can be taken on an asset and how they propagate.

Each AttackStep includes:
attackstep_name:
Unique name of the step (e.g., "connect", "compromise", "crack").

attackstep_type:
One of:

"|": OR step (only one parent needed)

"&": AND step (all parents needed)

ttc_distribution:
Optional probabilistic model of time/effort. Use when meaningful.

Example: {"type": "Exponential", "value": 0.01}

Can also use Gamma, Deterministic, Bernoulli, etc.

reaches:
Optional list of ReachesTarget objects. Each defines a propagation to a step on the same or another asset.

defenses (inside each AssetDefinition)
Purpose: Define security mechanisms that, if active, block or alter attack propagation.

Each Defense includes:
defense_name:
Name of the defense (e.g., "patched", "encrypted", "MFA").

reaches_if_false:
Optional list of ReachesTarget objects.
Meaning: if this defense is not present, then these steps become reachable.

ReachesTarget
Purpose: Links an attack step or defense to another step (possibly on another asset).

target_attack_step_name:
The name of the attack step that becomes enabled if this one is successful. Check that you are using the correct attack step name!

target_asset_role:
Optional. Required if the target attack step is on another asset, accessed via a role from an association.

association_definitions
Purpose: Define navigable, bidirectional relationships between assets.

Each AssociationDefinition includes:
association_name:
Unique name for the association (e.g., "Execution", "Storage", "Communication").

end1_asset_type:
Name of the asset type on one side of the association.

end1_role_name:
Role name used from end1_asset_type to access end2_asset_type.
(e.g., "executor" → "executees")

end1_multiplicity:
E.g., "1", "*" (zero or more), "0..1" — matches UML multiplicities.

end2_asset_type:
Asset type on the other end.

end2_role_name:
Role used from end2_asset_type to access end1_asset_type.

end2_multiplicity:
Multiplicity for end2_asset_type.

Structure Dependency Notes
All assets referenced in associations must be defined in asset_definitions.

Do not create Empty Assets!

Do not create an Attacker Asset! All attacks should be in the asset on that the attack are being performed.

Role names in target_asset_role (within reaches) must match the roles defined in associations. Make sure you choose the correct role name out of the Association (not the Role name of the origin Asset) Double Check that!

Inheritance via extends must only refer to existing asset names.

Attack steps referred to in reaches must exist and be uniquely named in their asset context.

Your Task
Given a scenario, generate a valid, complete MALModel JSON object structured exactly as described above.

Your output must contain:

Properly nested JSON structure.

Semantically accurate content based on the scenario.

No missing dependencies or undefined references.

Realistic TTCs and common-sense associations.

You will not include comments or explanations in the output.
You only output the final JSON object.
"""

# --- Schema Helper Functions ---
def _inline_refs_recursive(current_obj: Any, root_definitions: Dict[str, Any]) -> Any:
    if isinstance(current_obj, dict):
        if '$ref' in current_obj:
            ref_key = current_obj['$ref']
            if ref_key.startswith('#/$defs/'):
                def_name = ref_key.split('/')[-1]
                if def_name in root_definitions:
                    definition_copy = root_definitions[def_name].copy()
                    return _inline_refs_recursive(definition_copy, root_definitions)
                else:
                    raise ValueError(f"Unresolved reference: {ref_key}. Definition '{def_name}' not found in $defs.")
            else:
                raise ValueError(f"Unsupported $ref format: {ref_key}. Expected format like '#/$defs/ModelName'.")
        new_dict = {}
        for key, value in current_obj.items():
            new_dict[key] = _inline_refs_recursive(value, root_definitions)
        return new_dict
    elif isinstance(current_obj, list):
        return [_inline_refs_recursive(item, root_definitions) for item in current_obj]
    else:
        return current_obj

def get_inlined_schema(pydantic_model_class: type[BaseModel]) -> Dict[str, Any]:
    schema = pydantic_model_class.model_json_schema()
    if '$defs' not in schema or not schema['$defs']:
        if '$defs' in schema: del schema['$defs']
        return schema
    root_definitions = schema['$defs']
    schema_main_part = {k: v for k, v in schema.items() if k != '$defs'}
    return _inline_refs_recursive(schema_main_part, root_definitions)

def remove_key_recursive(obj: Any, key_to_remove: str) -> Any:
    if isinstance(obj, dict):
        new_dict = {}
        for k, v in obj.items():
            if k == key_to_remove: continue
            new_dict[k] = remove_key_recursive(v, key_to_remove)
        return new_dict
    elif isinstance(obj, list):
        return [remove_key_recursive(item, key_to_remove) for item in obj]
    else:
        return obj

def simplify_optional_anyof_recursive(schema_node: Any) -> Any:
    if isinstance(schema_node, dict):
        if "anyOf" in schema_node and isinstance(schema_node["anyOf"], list):
            parts = schema_node["anyOf"]
            if len(parts) == 2:
                null_part_found = False
                non_null_part_schema = None
                for part in parts:
                    if isinstance(part, dict) and part.get("type") == "null":
                        null_part_found = True
                    elif isinstance(part, dict):
                        if non_null_part_schema is None:
                           non_null_part_schema = part
                        else:
                           non_null_part_schema = None
                           break
                if null_part_found and non_null_part_schema is not None:
                    new_node = non_null_part_schema.copy()
                    new_node["nullable"] = True
                    for key, value in schema_node.items():
                        if key != "anyOf" and key not in new_node:
                            new_node[key] = value
                    return simplify_optional_anyof_recursive(new_node)
        return {key: simplify_optional_anyof_recursive(value) for key, value in schema_node.items()}
    elif isinstance(schema_node, list):
        return [simplify_optional_anyof_recursive(item) for item in schema_node]
    return schema_node

# --- Pydantic Models ---
class Meta(BaseModel):
    language_version: str = Field(
        description="The version identifier of the MAL DSL, useful for validating compatibility of instance models and tools."
    )
    language_name: str = Field(
        description="The unique name assigned to the specific MAL-based DSL being defined."
    )
    description: Optional[str] = Field(
        description="A brief human-readable summary describing the purpose and domain of the language."
    )

class TTCType(str, Enum):
    BERNOULLI = "Bernoulli"
    BINOMIAL = "Binomial"
    EXPONENTIAL = "Exponential"
    GAMMA = "Gamma"
    LOGNORMAL = "LogNormal"
    PARETO = "Pareto"
    TRUNCATED_NORMAL = "TruncatedNormal"
    UNIFORM = "Uniform"


class TTCDistribution(BaseModel):
    type: TTCType = Field(
        description="The type of Time-to-Compromise (TTC) distribution used. Allowed: Bernoulli(p), Binomial(n, p), Exponential(λ), Gamma(k, θ), LogNormal(μ, σ), Pareto(x, α), TruncatedNormal(μ, σ²), Uniform(a, b)."
    )

    # Generic parameters
    p: Optional[float] = Field(description="Probability (p) for Bernoulli and Binomial.")
    n: Optional[int] = Field(description="Number of trials (n) for Binomial.")
    lambda_: Optional[float] = Field(alias="λ", description="Rate (λ) for Exponential.")
    k: Optional[float] = Field(description="Shape parameter (k) for Gamma.")
    theta: Optional[float] = Field(description="Scale parameter (θ) for Gamma.")
    mu: Optional[float] = Field(alias="μ", description="Mean (μ) for LogNormal and TruncatedNormal.")
    sigma: Optional[float] = Field(alias="σ", description="Standard deviation (σ) for LogNormal and TruncatedNormal.")
    x: Optional[float] = Field(description="Minimum value (x) for Pareto.")
    alpha: Optional[float] = Field(description="Shape parameter (α) for Pareto.")
    variance: Optional[float] = Field(alias="σ²", description="Variance (σ²) for TruncatedNormal.")
    a: Optional[float] = Field(description="Lower bound (a) for Uniform.")
    b: Optional[float] = Field(description="Upper bound (b) for Uniform.")

    @model_validator(mode="after")
    def check_required_fields(self) -> "TTCDistribution":
        dist_type = self.type

        required_fields = {
            TTCType.BERNOULLI: ["p"],
            TTCType.BINOMIAL: ["n", "p"],
            TTCType.EXPONENTIAL: ["lambda_"],
            TTCType.GAMMA: ["k", "theta"],
            TTCType.LOGNORMAL: ["mu", "sigma"],
            TTCType.PARETO: ["x", "alpha"],
            TTCType.TRUNCATED_NORMAL: ["mu", "variance"],
            TTCType.UNIFORM: ["a", "b"],
        }

        missing = []
        for field in required_fields.get(dist_type, []):
            if getattr(self, field) is None:
                missing.append(field)

        if missing:
            raise ValueError(f"Missing required fields for {dist_type.value} distribution: {', '.join(missing)}")

        return self

class ReachesTarget(BaseModel):
    target_attack_step_name: str = Field(
        description="The name of the attack step on the target asset that becomes reachable from this step."
    )
    target_asset_role: Optional[str] = Field(
        description="The role name of the target asset through which the target asset is connected (the rolename is defined in the association, Check that you are using the correct role name!); required if the target asset is not local to the same asset. If within the same asset, leave this empty."
    )

class AttackStep(BaseModel):
    attackstep_name: str = Field(
        description="The unique identifier of the attack step (e.g., 'compromise', 'connect')."
    )
    attackstep_type: str = Field(
        description="Logical type of the step: 'AND' if all parent steps must be fulfilled, 'OR' if any one suffices."
    )
    ttc_distribution: Optional[TTCDistribution] = Field(
        description="Optional Time-to-Compromise distribution indicating how long it takes (probabilistically) for the attacker to perform this step."
    )
    reaches: Optional[List[ReachesTarget]] = Field(
        description="A list of steps that this attack step can enable or lead to, either on the same or a related asset."
    )

class Defense(BaseModel):
    defense_name: str = Field(
        description="Name of the defense mechanism (e.g., 'patched', 'encrypted'). It represents a conditional guard for an attack step."
    )
    reaches_if_false: Optional[List[ReachesTarget]] = Field(
        description="Attack steps that become reachable if this defense is *not* implemented (i.e., has value 'false' in the model instance)."
    )

class CategoryDefinition(BaseModel):
    category_name: str = Field(description="The unique name of the category grouping related assets (e.g., 'Hardware', 'Software').")
    description: Optional[str] = Field(description="Optional human-readable description of what kinds of assets belong to this category.")

class AssetDefinition(BaseModel):
    asset_name: str = Field(description="The name of the asset type being defined (e.g., 'Machine', 'Software'). Assets correspond to classifiers in MAL.")
    is_abstract: bool = Field(description="Boolean indicating whether the asset is abstract (cannot be instantiated directly).")
    extends: Optional[str] = Field(description="The name of a parent asset this asset inherits from, if any. Inherits all attack steps and associations of the parent.")
    category: str = Field(description="The name of the category this asset belongs to (must match a CategoryDefinition).")
    attack_steps: Optional[List[AttackStep]] = Field(description="List of attack steps that can be initiated from this asset, forming the nodes of the attack graph.")
    defenses: Optional[List[Defense]] = Field(description="List of defense mechanisms defined on this asset that can mitigate attack steps.")

class AssociationDefinition(BaseModel):
    association_name: str = Field(
        description="The name identifying the type of relationship (e.g., 'Execution', 'Communication'). It can not be the same as an already used Name."
    )
    end1_asset_type: str = Field(
        description="The name of the asset type at the first end of the association."
    )
    end1_role_name: str = Field(
        description="The role name used to navigate from the first asset to the second (e.g., 'executor'). It can not be the same as an already used Name."
    )
    end1_multiplicity: str = Field(
        description="The multiplicity constraint (e.g., '1', '*', '0..1') for the first asset's side of the association."
    )
    end2_asset_type: str = Field(
        description="The name of the asset type at the second end of the association."
    )
    end2_role_name: str = Field(
        description="The role name used to navigate from the second asset to the first (e.g., 'executees'). It can not be the same as an already used Name."
    )
    end2_multiplicity: str = Field(
        description="The multiplicity constraint for the second asset's side of the association."
    )

class MALModel(BaseModel):
    meta: Meta = Field(description="Metadata about the language, including version, name, and description.")
    category_definitions: List[CategoryDefinition] = Field(description="List of all defined asset categories that group asset types.")
    asset_definitions: List[AssetDefinition] = Field(description="The list of all asset types defined in the DSL, including their attack steps, defenses, and inheritance.")
    association_definitions: List[AssociationDefinition] = Field(description="The set of associations that define navigable relationships between assets, using roles and multiplicities.")

# --- RAG Configuration ---
INPUT_JSONL_FILE = "MAThesis-MALLM/LLM-Code/RAG/RAG-DataPrep/capec_rag_input_data/capec_combined_rag_data.jsonl"
FAISS_INDEX_PATH = "MAThesis-MALLM/LLM-Code/RAG/RAG-DataPrep/capec_faiss_index"
EMBEDDING_MODEL_NAME = "all-mpnet-base-v2"
_rag_retriever = None
_rag_embeddings = None

# --- RAG Helper Functions ---

def load_docs_from_jsonl(jsonl_file_path: str) -> List[Document]:
    """
    Loads data from a JSONL file, where each line follows the predefined RAG structure,
    and creates LangChain Documents.
    """
    all_docs: List[Document] = []
    print(f"Starting data load from JSONL file: {jsonl_file_path}")
    if not os.path.exists(jsonl_file_path):
        print(f"Error: Input file not found: {jsonl_file_path}")
        return []
    try:
        with open(jsonl_file_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    entry = json.loads(line)
                    page_content = entry.get("embedding_input")
                    metadata = entry.get("metadata", {})
                    source_type = entry.get("source_type")
                    if not page_content:
                        print(f"Warning: Skipping line {line_num} due to missing 'embedding_input'.")
                        continue
                    if source_type:
                        metadata['source_type'] = source_type
                    else:
                        print(f"Warning: Line {line_num} has missing 'source_type'. Not added to metadata.")
                    doc = Document(page_content=page_content, metadata=metadata)
                    all_docs.append(doc)
                except json.JSONDecodeError:
                    print(f"Warning: Skipping invalid JSON in line {line_num}: {line[:100]}...")
                except Exception as e:
                    print(f"Warning: Error processing line {line_num}: {e}. Data: {line[:100]}...")
    except IOError as e:
        print(f"Error reading file {jsonl_file_path}: {e}")
        return []
    except Exception as e:
        print(f"An unexpected error occurred during file processing: {e}")
        return []
    print(f"Loading completed. Total number of prepared documents: {len(all_docs)}")
    return all_docs


def initialize_embeddings(model_name: str = EMBEDDING_MODEL_NAME):
    """Initializes the embedding model."""
    global _rag_embeddings
    if _rag_embeddings is None:
        print(f"\nInitializing embedding model '{model_name}'...")
        encode_kwargs = {'normalize_embeddings': False}
        try:
            _rag_embeddings = HuggingFaceEmbeddings(
                model_name=model_name,
                encode_kwargs=encode_kwargs
            )
            print("Embedding model initialized successfully.")
        except Exception as e:
            print(f"Error during embedding model initialization: {e}")
            _rag_embeddings = None  # Ensure it remains None on failure
            raise  # Re-raise the error to halt the main program if necessary
    return _rag_embeddings


def get_rag_retriever(faiss_index_path: str = FAISS_INDEX_PATH, force_reload: bool = False) -> Optional[FAISS]:
    """
    Loads the FAISS vector index and returns a retriever.
    Initializes the embedding model if it hasn't been initialized yet.
    """
    global _rag_retriever
    global _rag_embeddings

    if _rag_retriever is not None and not force_reload:
        return _rag_retriever

    try:
        embeddings = initialize_embeddings()
        if embeddings is None:
            print("Could not initialize embedding model. RAG retriever will not be created.")
            return None
    except Exception as e:
        print(f"Error initializing embedding model for RAG: {e}")
        return None

    if not os.path.exists(faiss_index_path):
        print(f"FAISS index not found at: {faiss_index_path}")
        print("Please create the index first using `create_rag_index()`.")
        return None

    print(f"Loading FAISS index from: {faiss_index_path}")
    try:
        _rag_retriever = FAISS.load_local(
            faiss_index_path,
            embeddings,
            allow_dangerous_deserialization=True
        )
        print("FAISS index loaded successfully.")
        return _rag_retriever
    except Exception as e:
        print(f"Error loading FAISS index: {e}")
        traceback.print_exc()
        _rag_retriever = None
        return None


def retrieve_relevant_documents(query: str, k: int = 3) -> List[Document]:
    """
    Retrieves relevant documents from the FAISS index.
    Initializes the retriever if it hasn't been initialized yet.
    """
    retriever = get_rag_retriever()
    if not retriever:
        print("RAG retriever not available. Skipping document retrieval.")
        return []
    try:
        print(f"Searching for relevant documents for query: \"{query[:100]}...\"")
        results = retriever.similarity_search(query, k=k)
        print(f"{len(results)} documents found.")
        return results
    except Exception as e:
        print(f"Error retrieving documents: {e}")
        return []

def format_retrieved_documents_for_prompt(documents: List[Document]) -> str:
    """
    Formats the retrieved documents as context for the LLM prompt.
    """
    if not documents:
        return ""

    context_str = "\n\n--- Relevant context from knowledge base ---\n"
    for i, doc in enumerate(documents):
        context_str += f"\nDocument {i+1}:\n"
        context_str += f"Source: {doc.metadata.get('source_type', 'Unknown')}\n"
        if 'name' in doc.metadata:  # For CAPEC
            context_str += f"Name: {doc.metadata.get('name')}\n"
        if 'capec_id' in doc.metadata:  # For CAPEC
            context_str += f"CAPEC ID: {doc.metadata.get('capec_id')}\n"
        if 'mal_type' in doc.metadata:  # For MAL
            context_str += f"MAL Type: {doc.metadata.get('mal_type')}\n"
        if 'attack_step_name' in doc.metadata:  # For MAL Attack Steps
            context_str += f"Attack Step: {doc.metadata.get('attack_step_name')}\n"

        context_str += f"Content: {doc.page_content}\n"
    context_str += "\n--- End of relevant context ---\n"
    return context_str



def generate_mal_model_from_text(prompt_text: str,
                                 system_prompt: Optional[str] = None,  # NEW PARAMETER
                                 api_key: Optional[str] = None,
                                 use_rag: bool = False,
                                 rag_k: int = 3
                                 ):
    if api_key:
        genai.configure(api_key=api_key)
    elif not os.getenv("GOOGLE_API_KEY"):
        print("API key not configured. Please set GOOGLE_API_KEY or pass api_key argument.")
        return None

    # Model instantiation with optional system_instruction
    model_kwargs = {'model_name': 'gemini-2.5-pro-preview-05-06'}
    if system_prompt:
        model_kwargs['system_instruction'] = system_prompt
        print(f"Using system prompt: \"{system_prompt[:100]}...\"")

    model = genai.GenerativeModel(**model_kwargs)

    final_prompt = prompt_text
    if use_rag:
        print("\nRAG is enabled. Attempting to retrieve relevant documents...")
        relevant_docs = retrieve_relevant_documents(prompt_text, k=rag_k)
        if relevant_docs:
            rag_context = format_retrieved_documents_for_prompt(relevant_docs)
            final_prompt = f"{rag_context}\n\nBased on the above context and your general knowledge base, process the following request:\n\n{prompt_text}"
            print("Prompt has been extended with RAG context.")
        else:
            print("No relevant documents found for RAG or retriever not available. Using original prompt.")

    print(f"Sending prompt to Gemini: \"{final_prompt[:300]}...\"")  # Truncated for readability

    print("Generating, inlining, and cleaning schema for MALModel...")
    try:
        processed_schema = get_inlined_schema(MALModel)

        processed_schema = remove_key_recursive(processed_schema, "title")
        # Simplification should occur after removing 'title',
        # or remove_key_recursive must handle 'title' in anyOf structures intelligently.
        final_schema_for_api = simplify_optional_anyof_recursive(processed_schema)

    except Exception as e:
        print(f"Error while preparing schema: {e}")
        return None

    print("Expecting a response structured according to MALModel...")

    try:
        contents_for_api = [final_prompt]

        response = model.generate_content(
            contents=contents_for_api,  # Uses the possibly extended prompt
            generation_config=genai.types.GenerationConfig(
                response_mime_type="application/json",
                response_schema=final_schema_for_api,
            )
        )

        if response.text:
            print("\nSuccessfully received JSON string from Gemini.")
            try:
                parsed_model = MALModel.model_validate_json(response.text)
                print("\nParsed MALModel object (from response.text):")
                print(parsed_model.model_dump_json(indent=2))
                return parsed_model
            except Exception as e:
                print(f"\nError parsing Gemini response with Pydantic: {e}")
                print("Raw response text was:")
                print(response.text)
                return None
        else:
            print("\nGemini returned an empty response text.")
            if hasattr(response, 'prompt_feedback') and response.prompt_feedback:
                print(f"Prompt Feedback: {response.prompt_feedback}")
            if hasattr(response, 'candidates') and response.candidates and len(response.candidates) > 0:
                candidate = response.candidates[0]
                if hasattr(candidate, 'finish_reason') and candidate.finish_reason:
                    print(f"Finish Reason: {candidate.finish_reason}")
                if hasattr(candidate, 'safety_ratings') and candidate.safety_ratings:
                    print(f"Safety Ratings: {candidate.safety_ratings}")
            else:
                print("No candidates in response.")
            return None

    except Exception as e:
        print(f"\nAn error occurred during the Gemini API call: {e}")
        if hasattr(e, 'response') and e.response:
            try:
                print(f"API error details (if available): {e.response.text}")
            except:
                print(f"API error details (if available, non-text): {e.response}")
        return None


def json_to_mal(json_data):
    lines = []

    meta = json_data.get("meta", {})
    lines.append(f'#id: "{meta.get("language_name", "example.lang")}"')
    lines.append(f'#version: "{meta.get("language_version", "1.0.0")}"\n')

    # Group assets by category
    category_map = defaultdict(list)
    for asset in json_data.get("asset_definitions", []):
        category = asset.get("category", "Uncategorized")
        category_map[category].append(asset)

    # Write each category block
    for category_name, assets in category_map.items():
        lines.append(f"category {category_name} {{")
        for asset in assets:
            abstract_kw = "abstract " if asset.get("is_abstract") else ""
            extends_kw  = f" extends {asset['extends']}" if asset.get("extends") else ""

            lines.append(f"  {abstract_kw}asset {asset['asset_name']}{extends_kw} {{")

            # Attack steps
            for step in asset.get("attack_steps", []):
                prefix = {"OR": "|", "|": "|", "AND": "&", "&": "&"}.get(step["attackstep_type"], "|")
                ttc = ""
                dist = step.get("ttc_distribution")
                if dist:
                    t_type = dist["type"]
                    if t_type == "Gamma":
                        ttc = f" [Gamma({dist.get('k')},{dist.get('theta')})]"
                    elif t_type == "Exponential":
                        ttc = f" [Exponential({dist.get('lambda_')})]"
                    elif t_type == "Bernoulli":
                        ttc = f" [Bernoulli({dist.get('p')})]"
                    elif t_type == "Binomial":
                        ttc = f" [Binomial({dist.get('n')},{dist.get('p')})]"
                    elif t_type == "LogNormal":
                        ttc = f" [LogNormal({dist.get('mu')},{dist.get('sigma')})]"
                    elif t_type == "Pareto":
                        ttc = f" [Pareto({dist.get('x')},{dist.get('alpha')})]"
                    elif t_type == "TruncatedNormal":
                        ttc = f" [TruncatedNormal({dist.get('mu')},{dist.get('variance')})]"
                    elif t_type == "Uniform":
                        ttc = f" [Uniform({dist.get('a')},{dist.get('b')})]"
                    else:
                        ttc = f" [{t_type}]"

                target_line = ""
                if step.get("reaches"):
                    targets = []
                    for target in step["reaches"]:
                        if target.get("target_asset_role"):
                            targets.append(f"{target['target_asset_role']}.{target['target_attack_step_name']}")
                        else:
                            targets.append(f"{target['target_attack_step_name']}")
                    target_line = f" -> {', '.join(targets)}"

                lines.append(f"    {prefix} {step['attackstep_name']}{ttc}{target_line}")

            # Defense steps
            for defense in asset.get("defenses") or []:
                line = f"    # {defense['defense_name']}"
                if defense.get("reaches_if_false"):
                    targets = []
                    for target in defense["reaches_if_false"]:
                        if target.get("target_asset_role"):
                            targets.append(f"{target['target_asset_role']}.{target['target_attack_step_name']}")
                        else:
                            targets.append(target["target_attack_step_name"])
                    line += f" -> {', '.join(targets)}"
                lines.append(line)

            lines.append("  }")
        lines.append("}\n")

    # Associations block
    lines.append("associations {")
    for assoc in json_data.get("association_definitions", []):
        lines.append(
            f"  {assoc['end1_asset_type']} [{assoc['end2_role_name']}] {assoc['end1_multiplicity']} <-- {assoc['association_name']} --> {assoc['end2_multiplicity']} [{assoc['end1_role_name']}] {assoc['end2_asset_type']}"
        )
    lines.append("}")

    return "\n".join(lines)



if __name__ == "__main__":
    my_api_key = userdata.get('GEM_TOKEN2')

    if not my_api_key:
        print("Error: GOOGLE_API_KEY not found in Colab secrets or environment variables.")
    else:
        genai.configure(api_key=my_api_key)

        example_prompt = ( r'''Generate a MAL (Meta Attack Language) model definition for the following Website.


# Cybersecurity Incident Report

**Incident ID:** IR-2025-0427
**Generated on:** 2025-05-15
**Reported by:** Security Operations Center (SOC)
**Category:** Infrastructure Threat Modeling

## System Overview

The affected infrastructure consists of interconnected components modeled in the Infrastructure category:

* **Gateways:** External-facing components that regulate entry to internal machines.
* **Machines:** Compute units capable of authentication and connection activities.
* **Operators:** Human users capable of launching social engineering attacks.
* **Credentials:** Authentication secrets shared between operators and machines.

## Threat Scenario Summary

An external Operator executed a successful phishing campaign against internal personnel, resulting in the extraction of Credentials. The stolen credentials were then used to gain unauthorized login access to several Machines. From there, lateral movement was observed via internal linking mechanisms.

## Attack Path Breakdown

### Phishing Launch

The attacker (Operator) initiated `launchPhishing`, resulting in `phishingHit`.

Based on simulation parameters:

* `phishingHit` has a 10% exponential probability per attempt.

### Credential Compromise

The `phishingHit` led to successful extract of Credentials.

### Machine Access

The attacker used the compromised Credentials to trigger `login` on a Machine.

* Each Machine uses `login` to allow entry.

### Lateral Movement

With entry established, the attacker used `link` capability to move between Machines through the internal network (Gateway → Machine).

## Associations Observed

* Gateways connected to multiple Machines via `LinkAccess`.
* Machines and Operators shared Credentials via `AuthData` relationships.
* No anomalies were detected in the Gateway itself, suggesting direct compromise was avoided.

## Risk Assessment

| Asset      | Risk Level | Exploited Technique            | Observed Impact                |
| :--------- | :--------- | :----------------------------- | :----------------------------- |
| Operator   | Medium     | Phishing (Social Engineering)  | Credential theft               |
| Credentials| High       | Unauthorized extraction        | Privilege escalation           |
| Machine    | High       | Login with stolen credentials  | Data access & persistence      |
| Gateway    | Low        | No direct access observed      | Used for lateral movement      |

## Mitigation Recommendations

* Enforce MFA for all Machine login attempts.
* Conduct phishing awareness training for all Operators.
* Rotate credentials and monitor for abnormal `link` behavior.
* Segment internal networks to restrict Gateway-based lateral movement.
''')

        if os.path.exists(FAISS_INDEX_PATH) and os.path.isdir(FAISS_INDEX_PATH):
            print("\nAttempting to generate MAL model with Gemini (WITH RAG)...")
            generated_model_with_rag = generate_mal_model_from_text(
                example_prompt,
                INITIAL_PROMPT_TEMPLATE,
                api_key=my_api_key,
                use_rag=True,
                rag_k=3
            )

            if generated_model_with_rag:
                print("\n--- Successfully generated and parsed MALModel (with RAG) ---")
                #json.loads(generated_model_with_rag)
                mal_output = json_to_mal(generated_model_with_rag.model_dump())
                print("\n--- Converted MAL syntax ---\n")
                print(mal_output)
                output_file_path = "generated_model.mal"
                with open(output_file_path, "w", encoding="utf-8") as f:
                    f.write(mal_output)

                print(f"\n MAL code saved to: {output_file_path}")
                MALC_EXECUTABLE_PATH = "./malc_extracted/malc-0.2.0.linux.amd64/malc"
                MAL_SOURCE_FILE = "/content/generated_model.mal"

                !{MALC_EXECUTABLE_PATH} {MAL_SOURCE_FILE}
            else:
                print("\n--- Error during generation or parsing of MALModel (with RAG) ---")
        else:
            print("\nSkipping RAG-based generation since the index is not available.")



Attempting to generate MAL model with Gemini (WITH RAG)...
Using system prompt: "
You are a structured language model expert tasked with generating a valid and complete MALModel JSO..."

RAG is enabled. Attempting to retrieve relevant documents...

Initializing embedding model 'all-mpnet-base-v2'...
Embedding model initialized successfully.
Loading FAISS index from: MAThesis-MALLM/LLM-Code/RAG/RAG-DataPrep/capec_faiss_index
FAISS index loaded successfully.
Searching for relevant documents for query: "Generate a MAL (Meta Attack Language) model definition for the following Website.


# Cybersecurity ..."
3 documents found.
Prompt has been extended with RAG context.
Sending prompt to Gemini: "

--- Relevant context from knowledge base ---

Document 1:
Source: CAPEC course-of-action
Name: coa-212-0
Content: Name: coa-212-0
Description: Perform comprehensive threat modeling, a process of identifying, evaluating, and mitigating potential threats to the application. This effort can help rev

Transform to MAL

In [None]:
import json

from collections import defaultdict

def json_to_mal(json_data):
    lines = []

    meta = json_data.get("meta", {})
    lines.append(f'#id: "{meta.get("language_name", "example.lang")}"')
    lines.append(f'#version: "{meta.get("language_version", "1.0.0")}"\n')

    # Group assets by category
    category_map = defaultdict(list)
    for asset in json_data.get("asset_definitions", []):
        category = asset.get("category", "Uncategorized")
        category_map[category].append(asset)

    # Write each category block
    for category_name, assets in category_map.items():
        lines.append(f"category {category_name} {{")
        for asset in assets:
            abstract_kw = "abstract " if asset.get("is_abstract") else ""
            extends_kw  = f" extends {asset['extends']}" if asset.get("extends") else ""

            lines.append(f"  {abstract_kw}asset {asset['asset_name']}{extends_kw} {{")

            # Attack steps
            for step in asset.get("attack_steps", []):
                prefix = {"OR": "|", "|": "|", "AND": "&", "&": "&"}.get(step["attackstep_type"], "|")
                ttc = ""
                dist = step.get("ttc_distribution")
                if dist:
                    t_type = dist["type"]
                    if t_type == "Gamma":
                        ttc = f" [Gamma({dist.get('k')},{dist.get('theta')})]"
                    elif t_type == "Exponential":
                        ttc = f" [Exponential({dist.get('lambda_')})]"
                    elif t_type == "Bernoulli":
                        ttc = f" [Bernoulli({dist.get('p')})]"
                    elif t_type == "Binomial":
                        ttc = f" [Binomial({dist.get('n')},{dist.get('p')})]"
                    elif t_type == "LogNormal":
                        ttc = f" [LogNormal({dist.get('mu')},{dist.get('sigma')})]"
                    elif t_type == "Pareto":
                        ttc = f" [Pareto({dist.get('x')},{dist.get('alpha')})]"
                    elif t_type == "TruncatedNormal":
                        ttc = f" [TruncatedNormal({dist.get('mu')},{dist.get('variance')})]"
                    elif t_type == "Uniform":
                        ttc = f" [Uniform({dist.get('a')},{dist.get('b')})]"
                    else:
                        ttc = f" [{t_type}]"

                target_line = ""
                if step.get("reaches"):
                    targets = []
                    for target in step["reaches"]:
                        if target.get("target_asset_role"):
                            targets.append(f"{target['target_asset_role']}.{target['target_attack_step_name']}")
                        else:
                            targets.append(f"{target['target_attack_step_name']}")
                    target_line = f" -> {', '.join(targets)}"

                lines.append(f"    {prefix} {step['attackstep_name']}{ttc}{target_line}")

            # Defense steps
            for defense in asset.get("defenses") or []:
                line = f"    # {defense['defense_name']}"
                if defense.get("reaches_if_false"):
                    targets = []
                    for target in defense["reaches_if_false"]:
                        if target.get("target_asset_role"):
                            targets.append(f"{target['target_asset_role']}.{target['target_attack_step_name']}")
                        else:
                            targets.append(target["target_attack_step_name"])
                    line += f" -> {', '.join(targets)}"
                lines.append(line)

            lines.append("  }")
        lines.append("}\n")

    # Associations block
    lines.append("associations {")
    for assoc in json_data.get("association_definitions", []):
        lines.append(
            f"  {assoc['end1_asset_type']} [{assoc['end1_role_name']}] {assoc['end1_multiplicity']} <-- {assoc['association_name']} --> {assoc['end2_multiplicity']} [{assoc['end2_role_name']}] {assoc['end2_asset_type']}"
        )
    lines.append("}")

    return "\n".join(lines)


# Load your JSON here (example using a string)
json_string = '''{
  "meta": {
    "language_version": "1.0.0",
    "language_name": "org.mitre.attack.web_exploitation",
    "description": "A MAL DSL for modeling attacks against public-facing web applications and related infrastructure, based on ATT&CK T1190 and relevant CAPECs."
  },
  "category_definitions": [
    {
      "category_name": "ApplicationLayer",
      "description": "Assets primarily operating at the application layer, such as servers and client software."
    },
    {
      "category_name": "NetworkLayer",
      "description": "Assets that form the network infrastructure, such as routers, switches, and firewalls."
    }
  ],
  "asset_definitions": [
    {
      "asset_name": "WebServer",
      "is_abstract": false,
      "extends": null,
      "category": "ApplicationLayer",
      "attack_steps": [
        {
          "attackstep_name": "discoverVulnerabilities",
          "attackstep_type": "&",
          "ttc_distribution": {
            "type": "Exponential",
            "p": null,
            "n": null,
            "lambda_": 0.1,
            "k": null,
            "theta": null,
            "mu": null,
            "sigma": null,
            "x": null,
            "alpha": null,
            "variance": null,
            "a": null,
            "b": null
          },
          "reaches": [
            {
              "target_attack_step_name": "exploitKnownVulnerability",
              "target_asset_role": null
            }
          ]
        },
        {
          "attackstep_name": "exploitKnownVulnerability",
          "attackstep_type": "|",
          "ttc_distribution": {
            "type": "Exponential",
            "p": null,
            "n": null,
            "lambda_": 0.5,
            "k": null,
            "theta": null,
            "mu": null,
            "sigma": null,
            "x": null,
            "alpha": null,
            "variance": null,
            "a": null,
            "b": null
          },
          "reaches": [
            {
              "target_attack_step_name": "gainUnauthorizedAccess",
              "target_asset_role": null
            },
            {
              "target_attack_step_name": "connect",
              "target_asset_role": "dataStore"
            },
            {
              "target_attack_step_name": "gainNetworkAccess",
              "target_asset_role": "internalNetworkDevice"
            },
            {
              "target_attack_step_name": "exploitBrowserVulnerability",
              "target_asset_role": "webClient"
            }
          ]
        },
        {
          "attackstep_name": "gainUnauthorizedAccess",
          "attackstep_type": "&",
          "ttc_distribution": {
            "type": "Exponential",
            "p": null,
            "n": null,
            "lambda_": 1.0,
            "k": null,
            "theta": null,
            "mu": null,
            "sigma": null,
            "x": null,
            "alpha": null,
            "variance": null,
            "a": null,
            "b": null
          },
          "reaches": [
            {
              "target_attack_step_name": "uploadMaliciousFile",
              "target_asset_role": null
            },
            {
              "target_attack_step_name": "escalatePrivileges",
              "target_asset_role": null
            }
          ]
        },
        {
          "attackstep_name": "uploadMaliciousFile",
          "attackstep_type": "&",
          "ttc_distribution": {
            "type": "Exponential",
            "p": null,
            "n": null,
            "lambda_": 0.8,
            "k": null,
            "theta": null,
            "mu": null,
            "sigma": null,
            "x": null,
            "alpha": null,
            "variance": null,
            "a": null,
            "b": null
          },
          "reaches": [
            {
              "target_attack_step_name": "executeUploadedFile",
              "target_asset_role": null
            }
          ]
        },
        {
          "attackstep_name": "executeUploadedFile",
          "attackstep_type": "&",
          "ttc_distribution": {
            "type": "Exponential",
            "p": null,
            "n": null,
            "lambda_": 2.0,
            "k": null,
            "theta": null,
            "mu": null,
            "sigma": null,
            "x": null,
            "alpha": null,
            "variance": null,
            "a": null,
            "b": null
          },
          "reaches": []
        },
        {
          "attackstep_name": "escalatePrivileges",
          "attackstep_type": "|",
          "ttc_distribution": {
            "type": "Exponential",
            "p": null,
            "n": null,
            "lambda_": 0.3,
            "k": null,
            "theta": null,
            "mu": null,
            "sigma": null,
            "x": null,
            "alpha": null,
            "variance": null,
            "a": null,
            "b": null
          },
          "reaches": []
        }
      ],
      "defenses": [
        {
          "defense_name": "softwarePatched",
          "reaches_if_false": [
            {
              "target_attack_step_name": "exploitKnownVulnerability",
              "target_asset_role": null
            }
          ]
        },
        {
          "defense_name": "webApplicationFirewall",
          "reaches_if_false": [
            {
              "target_attack_step_name": "exploitKnownVulnerability",
              "target_asset_role": null
            }
          ]
        },
        {
          "defense_name": "networkSegmentationDMZ",
          "reaches_if_false": [
            {
              "target_attack_step_name": "gainNetworkAccess",
              "target_asset_role": "internalNetworkDevice"
            }
          ]
        },
        {
          "defense_name": "leastPrivilegeServiceAccount",
          "reaches_if_false": [
            {
              "target_attack_step_name": "escalatePrivileges",
              "target_asset_role": null
            }
          ]
        }
      ]
    },
    {
      "asset_name": "Database",
      "is_abstract": false,
      "extends": null,
      "category": "ApplicationLayer",
      "attack_steps": [
        {
          "attackstep_name": "connect",
          "attackstep_type": "&",
          "ttc_distribution": {
            "type": "Exponential",
            "p": null,
            "n": null,
            "lambda_": 1.5,
            "k": null,
            "theta": null,
            "mu": null,
            "sigma": null,
            "x": null,
            "alpha": null,
            "variance": null,
            "a": null,
            "b": null
          },
          "reaches": [
            {
              "target_attack_step_name": "attemptSqlInjection",
              "target_asset_role": null
            }
          ]
        },
        {
          "attackstep_name": "attemptSqlInjection",
          "attackstep_type": "|",
          "ttc_distribution": {
            "type": "Exponential",
            "p": null,
            "n": null,
            "lambda_": 0.7,
            "k": null,
            "theta": null,
            "mu": null,
            "sigma": null,
            "x": null,
            "alpha": null,
            "variance": null,
            "a": null,
            "b": null
          },
          "reaches": [
            {
              "target_attack_step_name": "exfiltrateSensitiveData",
              "target_asset_role": null
            }
          ]
        },
        {
          "attackstep_name": "exfiltrateSensitiveData",
          "attackstep_type": "&",
          "ttc_distribution": {
            "type": "Exponential",
            "p": null,
            "n": null,
            "lambda_": 0.2,
            "k": null,
            "theta": null,
            "mu": null,
            "sigma": null,
            "x": null,
            "alpha": null,
            "variance": null,
            "a": null,
            "b": null
          },
          "reaches": []
        }
      ],
      "defenses": [
        {
          "defense_name": "inputValidationOnQueries",
          "reaches_if_false": [
            {
              "target_attack_step_name": "attemptSqlInjection",
              "target_asset_role": null
            }
          ]
        },
        {
          "defense_name": "accessControlsEnforced",
          "reaches_if_false": [
            {
              "target_attack_step_name": "connect",
              "target_asset_role": null
            }
          ]
        }
      ]
    },
    {
      "asset_name": "NetworkInfrastructure",
      "is_abstract": false,
      "extends": null,
      "category": "NetworkLayer",
      "attack_steps": [
        {
          "attackstep_name": "scanForOpenPortsServices",
          "attackstep_type": "&",
          "ttc_distribution": {
            "type": "Exponential",
            "p": null,
            "n": null,
            "lambda_": 0.1,
            "k": null,
            "theta": null,
            "mu": null,
            "sigma": null,
            "x": null,
            "alpha": null,
            "variance": null,
            "a": null,
            "b": null
          },
          "reaches": [
            {
              "target_attack_step_name": "exploitFirmwareVulnerability",
              "target_asset_role": null
            }
          ]
        },
        {
          "attackstep_name": "exploitFirmwareVulnerability",
          "attackstep_type": "|",
          "ttc_distribution": {
            "type": "Exponential",
            "p": null,
            "n": null,
            "lambda_": 0.4,
            "k": null,
            "theta": null,
            "mu": null,
            "sigma": null,
            "x": null,
            "alpha": null,
            "variance": null,
            "a": null,
            "b": null
          },
          "reaches": [
            {
              "target_attack_step_name": "gainNetworkAccess",
              "target_asset_role": null
            }
          ]
        },
        {
          "attackstep_name": "gainNetworkAccess",
          "attackstep_type": "|",
          "ttc_distribution": {
            "type": "Exponential",
            "p": null,
            "n": null,
            "lambda_": 1.2,
            "k": null,
            "theta": null,
            "mu": null,
            "sigma": null,
            "x": null,
            "alpha": null,
            "variance": null,
            "a": null,
            "b": null
          },
          "reaches": []
        }
      ],
      "defenses": [
        {
          "defense_name": "firmwareUpdated",
          "reaches_if_false": [
            {
              "target_attack_step_name": "exploitFirmwareVulnerability",
              "target_asset_role": null
            }
          ]
        },
        {
          "defense_name": "strongAdminCredentials",
          "reaches_if_false": [
            {
              "target_attack_step_name": "gainNetworkAccess",
              "target_asset_role": null
            }
          ]
        },
        {
          "defense_name": "disabledUnnecessaryServices",
          "reaches_if_false": [
            {
              "target_attack_step_name": "exploitFirmwareVulnerability",
              "target_asset_role": null
            }
          ]
        }
      ]
    },
    {
      "asset_name": "ClientBrowser",
      "is_abstract": false,
      "extends": null,
      "category": "ApplicationLayer",
      "attack_steps": [
        {
          "attackstep_name": "visitCompromisedWebsite",
          "attackstep_type": "&",
          "ttc_distribution": {
            "type": "Exponential",
            "p": null,
            "n": null,
            "lambda_": 0.05,
            "k": null,
            "theta": null,
            "mu": null,
            "sigma": null,
            "x": null,
            "alpha": null,
            "variance": null,
            "a": null,
            "b": null
          },
          "reaches": [
            {
              "target_attack_step_name": "exploitBrowserVulnerability",
              "target_asset_role": null
            }
          ]
        },
        {
          "attackstep_name": "exploitBrowserVulnerability",
          "attackstep_type": "|",
          "ttc_distribution": {
            "type": "Exponential",
            "p": null,
            "n": null,
            "lambda_": 0.6,
            "k": null,
            "theta": null,
            "mu": null,
            "sigma": null,
            "x": null,
            "alpha": null,
            "variance": null,
            "a": null,
            "b": null
          },
          "reaches": [
            {
              "target_attack_step_name": "manipulateTrafficOrDOM",
              "target_asset_role": null
            }
          ]
        },
        {
          "attackstep_name": "manipulateTrafficOrDOM",
          "attackstep_type": "&",
          "ttc_distribution": {
            "type": "Exponential",
            "p": null,
            "n": null,
            "lambda_": 0.9,
            "k": null,
            "theta": null,
            "mu": null,
            "sigma": null,
            "x": null,
            "alpha": null,
            "variance": null,
            "a": null,
            "b": null
          },
          "reaches": []
        }
      ],
      "defenses": [
        {
          "defense_name": "browserSoftwarePatched",
          "reaches_if_false": [
            {
              "target_attack_step_name": "exploitBrowserVulnerability",
              "target_asset_role": null
            }
          ]
        },
        {
          "defense_name": "securityExtensionsEnabled",
          "reaches_if_false": [
            {
              "target_attack_step_name": "exploitBrowserVulnerability",
              "target_asset_role": null
            }
          ]
        }
      ]
    }
  ],
  "association_definitions": [
    {
      "association_name": "WebServerHostsOnNetwork",
      "end1_asset_type": "WebServer",
      "end1_role_name": "hostedApplication",
      "end1_multiplicity": "*",
      "end2_asset_type": "NetworkInfrastructure",
      "end2_role_name": "hostingPlatform",
      "end2_multiplicity": "1..*"
    },
    {
      "association_name": "WebServerQueriesDatabase",
      "end1_asset_type": "WebServer",
      "end1_role_name": "databaseClient",
      "end1_multiplicity": "*",
      "end2_asset_type": "Database",
      "end2_role_name": "dataStore",
      "end2_multiplicity": "0..*"
    },
    {
      "association_name": "ClientAccessesWebServer",
      "end1_asset_type": "ClientBrowser",
      "end1_role_name": "webClient",
      "end1_multiplicity": "*",
      "end2_asset_type": "WebServer",
      "end2_role_name": "contentProvider",
      "end2_multiplicity": "*"
    },
    {
      "association_name": "WebServerInternalAccess",
      "end1_asset_type": "WebServer",
      "end1_role_name": "internalNetworkPivotSource",
      "end1_multiplicity": "*",
      "end2_asset_type": "NetworkInfrastructure",
      "end2_role_name": "internalNetworkDevice",
      "end2_multiplicity": "0..*"
    }
  ]
}'''

mal_json = json.loads(json_string)

mal_output = json_to_mal(mal_json)
print(mal_output)


#id: "org.mitre.attack.web_exploitation"
#version: "1.0.0"

category ApplicationLayer {
  asset WebServer {
    & discoverVulnerabilities [Exponential(0.1)] -> exploitKnownVulnerability
    | exploitKnownVulnerability [Exponential(0.5)] -> gainUnauthorizedAccess, dataStore.connect, internalNetworkDevice.gainNetworkAccess, webClient.exploitBrowserVulnerability
    & gainUnauthorizedAccess [Exponential(1.0)] -> uploadMaliciousFile, escalatePrivileges
    & uploadMaliciousFile [Exponential(0.8)] -> executeUploadedFile
    & executeUploadedFile [Exponential(2.0)]
    | escalatePrivileges [Exponential(0.3)]
    # softwarePatched -> exploitKnownVulnerability
    # webApplicationFirewall -> exploitKnownVulnerability
    # networkSegmentationDMZ -> internalNetworkDevice.gainNetworkAccess
    # leastPrivilegeServiceAccount -> escalatePrivileges
  }
  asset Database {
    & connect [Exponential(1.5)] -> attemptSqlInjection
    | attemptSqlInjection [Exponential(0.7)] -> exfiltrateSensitiveData
 

MAL Compiler install

In [None]:


# URL zur .tar.gz-Datei für Version 0.2.0
MALC_TAR_URL = "https://github.com/mal-lang/malc/releases/download/release%2F0.2.0/malc-0.2.0.linux.amd64.tar.gz"
MALC_TAR_FILE = "malc-0.2.0.linux.amd64.tar.gz"
EXTRACT_DIR = "malc_extracted" # Name für das Verzeichnis nach dem Entpacken

# Herunterladen
!wget {MALC_TAR_URL}

# Verzeichnis zum Entpacken erstellen (falls es bereits existiert, wird es übersprungen)
!mkdir -p {EXTRACT_DIR}

# Archiv in das Verzeichnis entpacken
# 'tar -xvzf' : eXtract, Verbose, Zipped (gzip), File
!tar -xvzf {MALC_TAR_FILE} -C {EXTRACT_DIR}

# Inhalt des entpackten Verzeichnisses anzeigen, um den Pfad zur 'malc'-Datei zu finden
!echo "Inhalt des Verzeichnisses '{EXTRACT_DIR}':"
!ls -l {EXTRACT_DIR}

# Prüfen, ob sich die 'malc'-Datei direkt im Verzeichnis oder in einem Unterverzeichnis wie 'bin' befindet.
# Wir gehen davon aus, dass sie direkt im Verzeichnis liegt (anpassen, falls nötig).
MALC_EXECUTABLE_PATH = f"./{EXTRACT_DIR}/malc-0.2.0.linux.amd64/malc"

# Sicherstellen, dass die Datei ausführbar ist
!chmod +x {MALC_EXECUTABLE_PATH}

--2025-05-15 11:32:01--  https://github.com/mal-lang/malc/releases/download/release%2F0.2.0/malc-0.2.0.linux.amd64.tar.gz
Resolving github.com (github.com)... 20.27.177.113
Connecting to github.com (github.com)|20.27.177.113|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/385922173/34fa9e88-3a07-45c6-b3cf-1471ad8c9a9e?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=releaseassetproduction%2F20250515%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250515T113024Z&X-Amz-Expires=300&X-Amz-Signature=b914144ff885cdf6f2040b34b832d4bdf7bf1b0ef1b56c4c825669ee6dc077cd&X-Amz-SignedHeaders=host&response-content-disposition=attachment%3B%20filename%3Dmalc-0.2.0.linux.amd64.tar.gz&response-content-type=application%2Foctet-stream [following]
--2025-05-15 11:32:01--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/385922173/34fa9e88-3a07-45c6-b3cf-1471ad8c9a9e?X-Amz-Al

In [None]:
MALC_EXECUTABLE_PATH = "./malc_extracted/malc-0.2.0.linux.amd64/malc" # Pfad anpassen, falls 'malc' in einem Unterverzeichnis (z.B. bin) liegt
MAL_SOURCE_FILE = "/content/generated_model.mal"

!{MALC_EXECUTABLE_PATH} {MAL_SOURCE_FILE}

[[1;31mANALYZER ERROR[m] <generated_model.mal:7:79> Field 'dataStore' not defined for asset 'WebServer'
[[1;31mANALYZER ERROR[m] <generated_model.mal:7:98> Field 'internalNetworkDevice' not defined for asset 'WebServer'
[[1;31mANALYZER ERROR[m] <generated_model.mal:7:139> Field 'webClient' not defined for asset 'WebServer'
[[1;31mANALYZER ERROR[m] <generated_model.mal:14:33> Field 'internalNetworkDevice' not defined for asset 'WebServer'
[0m[0m