In [None]:
from sys import path

# Make the directory two levels above the kernel's working directory importable,
# presumably so the project packages imported below (`tooling.*`, `models.*`) resolve.
# Guarded so that re-running this cell does not keep appending duplicate entries.
if "../.." not in path:
    path.append("../..")

In [None]:
from argparse import ArgumentParser
from os import environ
from typing import Optional

from numpy import percentile
from torch import Tensor, as_tensor, device, no_grad, tensor, where
from torch.cuda import is_available as cuda_is_available
from torch_geometric.utils import remove_isolated_nodes


### Helper function for filtering graphs by edge weight

In [None]:
def edge_filter(pyg_dataset, edge_weight_percentile:float, threshold:Optional[float]=None, node_attributes:Optional[list]=None) -> float:
    """Drop weak edges from every graph in ``pyg_dataset`` (in place), then drop isolated nodes.

    An edge is kept only if its weight (``data.edge_attr``) is strictly greater than
    ``threshold``. When ``threshold`` is None, it is computed as the
    ``edge_weight_percentile``-th percentile of all edge weights pooled across the dataset.

    Args:
        pyg_dataset: iterable of graph objects exposing ``x``, ``edge_index`` and ``edge_attr``
            (PyTorch Geometric ``Data``). Each graph is modified in place. When ``threshold``
            is None the iterable is traversed twice, so it must be re-iterable (not a generator).
        edge_weight_percentile: percentile passed to ``numpy.percentile``; only used when
            ``threshold`` is None.
        threshold: explicit edge-weight cutoff. Any non-None value (including ``0.0``) is used as-is.
        node_attributes: names of extra per-node attributes (Python lists or Tensors) that are
            subset with the same node mask applied to ``data.x``. Defaults to none.

    Returns:
        The threshold that was applied (the given value, or the computed percentile).
    """
    if node_attributes is None:
        node_attributes = []

    # `is None` rather than truthiness: an explicit threshold of 0.0 must not trigger recomputation.
    if threshold is None:
        edge_weights = []
        for data in pyg_dataset:
            edge_weights += data.edge_attr
        threshold = percentile(edge_weights, edge_weight_percentile)

    for data in pyg_dataset:
        # as_tensor avoids the copy-construct warning/copy when edge_attr is already a Tensor.
        edge_attr = as_tensor(data.edge_attr)
        # Keep only edges whose weight is strictly above the cutoff.
        valid_edges = where(edge_attr > tensor(threshold))[0]
        data.edge_index = data.edge_index[:, valid_edges]
        data.edge_attr = edge_attr[valid_edges]

        # Remove nodes left without edges; the returned node mask is applied to x and
        # to every requested per-node attribute so they stay aligned.
        data.edge_index, data.edge_attr, mask = remove_isolated_nodes(
            data.edge_index, data.edge_attr, num_nodes=data.x.shape[0]
        )
        data.x = data.x[mask, :]
        for attribute in node_attributes:
            values = data[attribute]
            if isinstance(values, list):
                data[attribute] = [value for node_id, value in enumerate(values) if mask[node_id]]
            elif isinstance(values, Tensor):
                data[attribute] = values[mask]

    return threshold


### Load Pretrained Models
- LLM (transformer decoder)
- Graph Neural Network (Graph Attention Network)

In [None]:
from torch import float32
import json
from pathlib import Path
import pickle as pkl
from torch import load as torch_load

from tooling.llm_engine.gorilla_v2 import GorillaFunctionCalling as llm
from torch_geometric.nn import global_max_pool, global_mean_pool
from models.gnn.graph_classifier import GraphClassifier

# Decoder LLM: used to build/run prompts and to extract attention for the skill selector.
decoder_model = llm()

# Directory with the GNN skill-selector artifacts (skill_map.json, model_metadata.json, model.pt).
# Set SKILL_SELECTOR_DIR to point elsewhere instead of editing this hard-coded default.
skill_selector_directory = environ.get(
    "SKILL_SELECTOR_DIR", "/home/azureuser/lbetthauser/models/orchestrator/gnn/model"
)
# `cuda_is_available` is a function and must be called: the bare function object is always
# truthy, which previously forced "cuda:0" even on CPU-only machines.
GRAPH_DEVICE = "cuda:0" if cuda_is_available() else "cpu"

# Presumably maps skill name -> class index; inverted so a predicted index maps back to a name.
with open(Path(skill_selector_directory, "skill_map.json"), "r") as f:
    skill_mapping = json.load(f)
reverse_skill_mapping = {v: k for k, v in skill_mapping.items()}

with open(Path(skill_selector_directory, "model_metadata.json"), "r") as f:
    model_metadata = json.load(f)
# Popped so it is not forwarded to GraphClassifier as a constructor kwarg; reused by edge_filter later.
edge_filter_threshold = model_metadata.pop("edge_weight_threshold")

# load model; map_location lets weights saved on a GPU load on whatever device is available.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (consider weights_only=True on torch >= 1.13).
skill_selection_model = GraphClassifier(**model_metadata, pooling_layer=global_max_pool)
skill_selection_model.load_state_dict(
    torch_load(Path(skill_selector_directory, "model.pt"), map_location=GRAPH_DEVICE),
    strict=True,
)
skill_selection_model.eval()
skill_selection_model.to(device(GRAPH_DEVICE))

print(skill_selection_model)

### Define Function Templates and Provide User Questions

In [None]:
# Skill (function) schemas offered to the LLM. Each template is serialized into the prompt
# (see the `format_conversations` cell), and the GNN skill selector works on attention over
# those prompt tokens, so wording changes here may require re-validating the selector.
# Every name listed under "required" must be a key declared in "properties".
function_templates = [
    {
        "name": "investigation_report",
        "description": "Generates a structured incident report (Splunk Incident Report). The report includes a detailed timeline of actions taken and enriched context on what happened.",
        "parameters": {
            "type": "object",
            "properties": {
                "work_notes": {
                    "type": "string",
                    "description": "The work notes about the splunk investigation."
                },
                "output": {
                    "type": "string",
                    "description": "A splunk incident report."
                }
            },
            "required": ["work_notes"]
        }
    },
    {
        "name": "investigation_summarizer",
        "description": "Gets in progress work notes for an investigation contained in the chat history, fetches additional information based on the context, and generates a summary.",
        "parameters": {
            "type": "object",
            "properties": {
                "work_notes": {
                    "type": "string",
                    "description": "The work notes about the splunk investigation."
                },
                "output": {
                    "type": "string",
                    "description": "Natural language summary of the investigation."
                }
            },
            "required": ["work_notes"]
        }
    },
    {
        "name": "finding_summarizer",
        # TODO: this description is identical to investigation_summarizer's, which gives the
        # LLM nothing to distinguish the two skills by.
        "description": "Gets in progress work notes for an investigation contained in the chat history, fetches additional information based on the context, and generates a summary.",
        "parameters": {
            "type": "object",
            "properties": {
                "work_notes": {
                    "type": "string",
                    "description": "The work notes about the splunk investigation."
                },
                "output": {
                    "type": "string",
                    "description": "Natural language summary of the finding."
                }
            },
            # was ["finding_notes"], a property this schema never declares
            "required": ["work_notes"]
        }
    },
    {
        "name": "spl_editor",
        "description": "Edits a user provided SPL query.",
        "examples": [
            "Edit this query to change the time frame, | rest /services/data/indexes | fields title | rename title AS index. Add a filter for title | rest /services/data/indexes.",
            "Update, | rest /services/data/indexes | fields title | rename title AS index, to only get top 10 results."
        ],
        "parameters": {
            "type": "object",
            "properties": {
                "spl_query": {
                    "type": "string",
                    "description": "The spl query the user wants to modify. SPL query from user question will start with 'search' or |."
                },
                "intent": {
                    "type": "string",
                    "description": "A description of what the user would like to edit/change/fix/use/add/remove about the spl_query."
                },
                "output": {
                    "type": "string",
                    "description": "The updated spl query."
                }
            },
            "required": ["spl_query", "intent"]
        }
    },
    {
        "name": "spl_executor",
        "description": "Executes a user provided splunk SPL query on the user's Splunk stack. Requires the spl_query to contain 'execute' spl, get results for spl, run, or run spl query.",
        "examples": [
            # closing quote added: the example previously ended with an unterminated string value
            "search src='10.9.165.*' OR dst='10.9.165.8'",
            "| rest /services/data/indexes | fields title | rename title AS index"
        ],
        "parameters": {
            "type": "object",
            "properties": {
                "spl_query": {
                    "type": "string",
                    "description": "A Splunk SPL query explicitely contained in the chat history. The SPL query must start with 'search' UNLESS its first command is contained in [tstats, rest, makeresults, searchtxn, datamodel, dbinspect, eventcount, from, gentimes, inputcsv, Inputlookup, loadjob, metadata, metasearch, mstats, multisearch, pivot, set.]"
                },
                "output": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "additionalProperties": True
                    }
                }
            },
            "required": ["spl_query"]
        }
    },
    {
        "name": "spl_detection_recommender",
        "description": "Writes SPL to detect a user specified requested cyber security threat. If the user question does not define a specific threat, select another skill.",
        "examples": [
            "recommend a detection for 3CX Supply Chain Attack Network Indicators",
            "give me SPL to detect ASL AWS Defense Evasion Delete CloudWatch Log Group.",
            "Detections come from Splunk's Enterprise Security Content Updates (ESCU)."
        ],
        "parameters": {
            "type": "object",
            "properties": {
                "intent": {
                    "type": "string",
                    "description": "A plain English sentence describing the intent of the desired detection. It should never have specific values but should refer to values generically (e.g. IP)."
                },
                "output": {
                    "type": "string",
                    "description": "The spl query for the threat dectection."
                }
            },
            "required": ["intent"]
        }
    },
    {
        "name": "spl_writer",
        "description": "Generates an SPL query from a user's natural language request when no reference SPL query is present. Returns a splunk spl query.",
        "parameters": {
            "type": "object",
            "properties": {
                "intent": {
                    "type": "string",
                    "description": "A short English sentence (15 words or less) describing the intent of the desired query. It should always start with 'This SPL query'. NEVER include time window or search time (e.g. 'in the past X days') even if the user specified."
                }
            },
            "required": ["intent"]
        }
    },
    {
        "name": "conversation",
        "description": "This function generates a natural language response to user query. Returns a response or answer to the user question.",
        "examples": [
            "What does SPL mean?",
            "Can you summarize alerts?",
            "What kind of detections can you recommend?",
            "Can you execute spl?"
        ],
        "parameters": {
            "type": "object",
            "properties": {
                "user_message": {
                    "type": "string",
                    "description": "The user's message."
                }
            },
            "required": ["user_message"]
        }
    }
]

user_questions = [
    "How can I search for all incidents on my Splunk instance?",
]

In [None]:
# display functions
# List every available skill as "(name: description)" — the menu the LLM chooses from.
print("(skill_name: description)")
for skill in function_templates:
    skill_name, skill_description = skill["name"], skill["description"]
    print(f"({skill_name}: {skill_description})")

In [None]:
# Render the full prompt for the first user question with every skill schema attached.
# NOTE(review): schemas are serialized with str() (Python repr: single quotes, `True`) rather
# than json.dumps — confirm this is the format the skill selector was trained on.
conversation_turns = [{"role": "user", "content": user_questions[0]}]
serialized_functions = list(map(str, function_templates))
prompt = decoder_model.format_conversations(messages=conversation_turns, functions=serialized_functions)
print(prompt)

In [None]:
# Let the decoder answer the first question with all skills available (no skill forced).
batch_of_requests = [[{"role": "user", "content": user_questions[0]}]]
responses = decoder_model.generate_text(
    requests=batch_of_requests,
    functions=[str(schema) for schema in function_templates],
)
print("\n", responses[0])

### Skill Selection from the LLM Attention Graph (GNN)

In [None]:
from tooling.huggingface_latent_representations.transformers.attention_featurization import extract_latent_feature_graph

user_question = user_questions[0]

# compute decoder attention (a single decoding step is enough to get attention over the prompt)
input_ids, outputs = decoder_model._extract_latent_features(prompt=prompt, max_tokens=1)

# Token positions of the prompt markers delimiting the function-template span and the question span.
template_start_ids = decoder_model.search_for_target_sequence("### Instruction: <<function>>", prompt, input_ids)
template_end_ids = decoder_model.search_for_target_sequence("<<question>>", prompt, input_ids)
input_end_ids = decoder_model.search_for_target_sequence("### Response: ", prompt, input_ids)
# [template span endpoints, question span endpoints].
# NOTE(review): the +1/-2/+2/-4 offsets appear to trim marker tokens around each span; they
# depend on the tokenizer and prompt format — re-check if either changes.
pairs_of_positions = [[template_start_ids+1, template_end_ids-2], [template_end_ids+2, input_end_ids-4]]

# construct attention graph between the two spans
attention_graph = extract_latent_feature_graph(
    input_ids=input_ids,
    outputs=outputs,
    position_index_a_endpoints=tuple(pairs_of_positions[0]),
    position_index_b_endpoints=tuple(pairs_of_positions[1]),
)

# select skill with GNN
print(attention_graph)
edge_filter(  # filter edges below attention threshold (modifies attention_graph in place)
    [attention_graph],  # pytorch geometric graph
    edge_weight_percentile=0,  # ignored if threshold is specified
    threshold=edge_filter_threshold,  # use edge filtering value from pretraining
    node_attributes=["ids", "node_types"],
)
print(attention_graph)

attention_graph.x = attention_graph.x.to(float32)
# Single-graph batch: every node belongs to graph 0.
attention_graph.batch = tensor([0] * attention_graph.x.shape[0])
attention_graph = attention_graph.to(GRAPH_DEVICE)
# Inference only: no_grad skips building an autograd graph
# (the previous bare `out.detach()` discarded its result and had no effect).
with no_grad():
    out = skill_selection_model(
        node_features=attention_graph.x,
        edge_index=attention_graph.edge_index,
        batch=attention_graph.batch,
    )
pred = out.argmax(dim=1)  # Use the class with highest score.
skill_prediction = reverse_skill_mapping[pred.item()]
del out

print("\n", skill_prediction)

### Use LLM with constrained generation

In [None]:
# Force the GNN-selected skill: seed the response with "<<function>>NAME(" so the decoder
# only has to generate the call's arguments.
skill_str = f"<<function>>{skill_prediction}("
function_selection = decoder_model._generate_response(prompt + skill_str)
print("\n", skill_str + function_selection)