# [Preview] Quickstart (Instruct Model)

This notebook demonstrates how to download Foundation AI's instruct model from Hugging Face and run an initial inference as a starting point. <br>
If you’re interested in more detailed cybersecurity [use cases](https://github.com/RobustIntelligence/foundation-ai-cookbook/tree/main/2_examples) or [adoptions](https://github.com/RobustIntelligence/foundation-ai-cookbook/tree/main/3_adoptions), please refer to the corresponding sections.

## Notes
This model is an instruction-following model fine-tuned for responding to prompted instructions. <br>
Unlike the completion model (Foundation-Sec-8B), it is designed to engage in conversations.

**This model is currently in preview mode and may receive updates. As a result, outputs can vary even when parameters are configured to ensure reproducibility.**

## Setup
We recommend running the scripts with NVIDIA GPU(s) for optimal performance. <br>
While the code should work with both single and multiple GPUs, unexpected issues may arise with multiple GPUs. In such cases, minor code adjustments or limiting usage to one GPU (e.g., by setting `CUDA_VISIBLE_DEVICES='0'`) might be necessary. <br>
Ensure at least 20 GB of available storage and memory for the model.
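
The setup advice above can be sketched as a small pre-flight check. This is a minimal sketch, assuming a single-GPU run: the helper name `has_free_space` is hypothetical, and it checks the current directory, so point it at your Hugging Face cache directory if that lives on another disk. The environment variable only takes effect if it is set before `torch` initializes CUDA.

```python
import os
import shutil

# Restrict this process to the first GPU. setdefault keeps any value
# that was already exported in the shell.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")


def has_free_space(path=".", required_gb=20):
    """Return True if `path` has at least `required_gb` GiB of free disk space."""
    free_gb = shutil.disk_usage(path).free / 1024**3
    return free_gb >= required_gb


if not has_free_space():
    print("Warning: less than 20 GB of free disk space for the model download.")
```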

## If you have access via Hugging Face
Since the model is in preview mode, you'll need to log in to your authorized Hugging Face account and use the correct token.

In [1]:
import os

# Export your Hugging Face token as the HF_TOKEN environment variable before running this cell
HF_TOKEN = os.getenv("HF_TOKEN")
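
Because `os.getenv` silently returns `None` when the variable is missing, it can help to fail early with a clear message. The following is a minimal sketch; the helper name `require_env` is hypothetical and not part of the original notebook.

```python
import os


def require_env(name="HF_TOKEN"):
    """Return the value of environment variable `name`, or raise if it is unset or empty."""
    value = os.getenv(name)
    if not value:
        raise RuntimeError(f"{name} is not set; export it before starting the notebook.")
    return value
```

With this helper, the assignment could be written as `HF_TOKEN = require_env()`.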

In [2]:
import transformers
import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", DEVICE)

device: cuda


In [3]:
MODEL_ID = ""  # To be replaced with the final model name

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16, # this model's tensor_type is BF16
    token=HF_TOKEN,
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

### Configurations
You can adjust the model's text generation behavior by tuning its arguments. <br>
Below is an example configuration to ensure reproducible outputs. <br>
For a complete list of arguments and detailed explanations, refer to the [text generation document](https://huggingface.co/docs/transformers/en/main_classes/text_generation).

In [4]:
generation_args = {
    "max_new_tokens": 1024,
    "temperature": None,
    "repetition_penalty": 1.2,
    "do_sample": False,
    "use_cache": True,
    "eos_token_id": tokenizer.eos_token_id,
    "pad_token_id": tokenizer.pad_token_id,
}
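
For comparison with the greedy configuration, a sampling variant might look like the sketch below. The helper name `with_sampling` and the values 0.7 and 0.9 are illustrative assumptions, not recommendations from the model authors; `do_sample`, `temperature`, and `top_p` are standard text generation arguments. Sampled outputs vary between runs unless a seed is fixed, for example with `torch.manual_seed`.

```python
def with_sampling(base_args, temperature=0.7, top_p=0.9):
    """Return a copy of `base_args` switched from greedy decoding to sampling."""
    args = dict(base_args)
    args.update({"do_sample": True, "temperature": temperature, "top_p": top_p})
    return args


# Stand-in for the greedy configuration (token ids omitted for brevity).
greedy_args = {"max_new_tokens": 1024, "temperature": None, "do_sample": False}
sampling_args = with_sampling(greedy_args)
print(sampling_args)
```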

In [5]:
import re

SYSTEM_PROMPT = "You are a cybersecurity expert."

def inference(prompt):
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    inputs = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    
    inputs = tokenizer(inputs, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            **generation_args,
        )
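    response = tokenizer.decode(outputs[0], skip_special_tokens=False)

    # extract the assistant response part only
    match = re.search(r"<\|assistant\|>(.*?)<\|end_of_text\|>", response, re.DOTALL)
    if match is None:
        # markers not found (e.g., generation hit max_new_tokens); return the raw text
        return response.strip()
    return match.group(1).strip()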
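
To see what the extraction step does in isolation, here is a minimal sketch run on a made-up decoded string. The `<|assistant|>` and `<|end_of_text|>` markers come from the cell above; the `<|system|>` and `<|user|>` markers and the sample text are assumptions for illustration, since the real layout depends on the model's chat template.

```python
import re

ASSISTANT_PATTERN = re.compile(r"<\|assistant\|>(.*?)<\|end_of_text\|>", re.DOTALL)


def extract_assistant(decoded):
    """Return the text between the assistant and end-of-text markers, or None."""
    match = ASSISTANT_PATTERN.search(decoded)
    return match.group(1).strip() if match else None


# Made-up decoded string; real output depends on the chat template.
sample = (
    "<|system|>\nYou are a cybersecurity expert."
    "<|user|>\nHi<|assistant|>\nHello!<|end_of_text|>"
)
print(extract_assistant(sample))             # Hello!
print(extract_assistant("no markers here"))  # None
```

Returning `None` when the markers are absent lets the caller decide whether to fall back to the raw text or treat the output as truncated.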

## If you have access via a remote API
Prepare the `API_KEY` and the endpoint URL provided by the Foundation AI team.

In [None]:
API_KEY = ""
ENDPOINT_URL = ""

In [None]:
import requests
import re

def inference(prompt):
    data = {'prompt': prompt}
    # If you want to add your own generation_args, you can do so like this:
    # data['generate_args'] = YOUR_GENERATION_ARGS
    response = requests.post(
        ENDPOINT_URL,
        headers={"Authorization": f"Api-Key {API_KEY}"},
        json=data,
    )

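    response.raise_for_status()  # fail early on HTTP errors

    match = re.search(r"<\|assistant\|>(.*?)<\|end_of_text\|>", response.text, re.DOTALL)
    if match is None:
        # markers not found; return the raw response body
        return response.text.strip()
    return match.group(1).strip()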
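
The request body can also be assembled separately, which makes it easy to inspect before sending. This is a minimal sketch: the helper name `build_payload` is hypothetical, the `generate_args` key is taken from the comment in the cell above, and which arguments the endpoint actually accepts under that key is an assumption that should be confirmed with the Foundation AI team.

```python
def build_payload(prompt, generate_args=None):
    """Build the JSON body for the endpoint; `generate_args` is included only when given."""
    data = {"prompt": prompt}
    if generate_args:
        # Copy so later changes to the caller's dict do not affect the payload.
        data["generate_args"] = dict(generate_args)
    return data


# Assumed argument names, mirroring the local generation settings.
payload = build_payload(
    "What is MITRE ATT&CK?",
    {"max_new_tokens": 256, "do_sample": False},
)
print(payload)
```

The resulting dictionary can be passed as `json=payload` to `requests.post`.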


## Inference

If you want to know what MITRE ATT&CK means, you can structure the prompt as shown below. <br>
Unlike the Foundation-Sec-8B model, this model returns natural-language responses to queries.

In [6]:
print(inference("What is MITRE ATT&CK? Give a very brief answer"))

MITRE ATT&CK is a widely used framework that provides a comprehensive understanding of the tactics, techniques, and procedures (TTPs) used by cyber adversaries. It's designed to help organizations better understand and defend against various types of attacks across different levels of an enterprise network.
