# Quickstart (Instruct Model)

This notebook demonstrates how to download Foundation AI's instruct model from Hugging Face and run an initial inference as a starting point. <br>
If you're interested in more detailed cybersecurity [use cases](https://github.com/RobustIntelligence/foundation-ai-cookbook/tree/main/2_examples) or [adoptions](https://github.com/RobustIntelligence/foundation-ai-cookbook/tree/main/3_adoptions), please refer to the corresponding sections.

## Notes
This is an instruction-following model, fine-tuned to respond to prompted instructions. <br>
Unlike the base completion model (Foundation-Sec-8B), it is designed for conversational use.

## Setup
We recommend running the scripts on NVIDIA GPU(s) for optimal performance. <br>
The code should work with both single and multiple GPUs, but unexpected issues may arise with multiple GPUs. In such cases, minor code adjustments or restricting execution to a single GPU (e.g., by setting `CUDA_VISIBLE_DEVICES='0'`) may be necessary.
<br> Ensure that at least 20 GB of storage and memory are available for the model.
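The storage requirement can be checked before downloading anything. The sketch below is an illustrative addition, not part of the original notebook. It assumes the default Hugging Face cache location under the home directory (adjust the path if you set `HF_HOME`), uses hypothetical helper names `free_disk_gb` and `has_enough_disk`, and checks disk space only, not GPU or system memory. It also shows, commented out, where `CUDA_VISIBLE_DEVICES` must be set for it to take effect.

```python
import os
import shutil

# Optional: restrict the notebook to the first GPU, as suggested above.
# This must run before torch initializes CUDA, so place it before `import torch`.
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

MIN_FREE_GB = 20  # minimum recommended in the Setup notes

# Assumption: downloads go to the default cache under ~/.cache/huggingface,
# so the home directory's filesystem is the one that needs the space.
# Fall back to the working directory if the home directory is unavailable.
_home = os.path.expanduser("~")
CACHE_ROOT = _home if os.path.isdir(_home) else os.getcwd()


def free_disk_gb(path: str = CACHE_ROOT) -> float:
    """Return free space on the filesystem containing `path`, in GB (10**9 bytes)."""
    return shutil.disk_usage(path).free / 1e9


def has_enough_disk(path: str = CACHE_ROOT, min_gb: float = MIN_FREE_GB) -> bool:
    """Return True if at least `min_gb` GB are free at `path`."""
    return free_disk_gb(path) >= min_gb


print(f"Free disk: {free_disk_gb():.1f} GB (enough: {has_enough_disk()})")
```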

In [1]:
import os

# Read the Hugging Face token from the HF_TOKEN environment variable (export it before launching the notebook)
HF_TOKEN = os.getenv("HF_TOKEN")

In [2]:
import transformers
import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", DEVICE)

device: cuda


In [3]:
MODEL_ID = "fdtn-ai/Foundation-Sec-8B-Instruct"

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16, # this model's tensor_type is BF16
    token=HF_TOKEN,
)

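The `torch_dtype=torch.bfloat16` setting ties into the storage and memory guidance in Setup: each BF16 weight takes 2 bytes instead of float32's 4. Here is a back-of-the-envelope estimate. It assumes roughly 8 billion parameters, read off the model name (the exact count differs slightly), and counts only the weights, not activations or the KV cache built up during generation. `weight_memory_gb` is a hypothetical helper.

```python
def weight_memory_gb(n_params: float, bytes_per_param: int) -> float:
    """Approximate memory to hold the weights alone, in GB (10**9 bytes)."""
    return n_params * bytes_per_param / 1e9


# Assumption: ~8 billion parameters, inferred from the "8B" in the model name.
N_PARAMS = 8e9

for dtype, nbytes in [("float32", 4), ("bfloat16", 2)]:
    print(f"{dtype}: ~{weight_memory_gb(N_PARAMS, nbytes):.0f} GB")
# float32: ~32 GB
# bfloat16: ~16 GB
```

Roughly 16 GB for the BF16 weights leaves some headroom within the recommended 20 GB, whereas float32 would not fit.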

### Configurations
You can adjust the model's text generation behavior by tuning its generation arguments. <br>
The example configuration below uses greedy decoding (`do_sample=False`) so that outputs are reproducible. <br>
For a complete list of arguments and detailed explanations, refer to the [text generation documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation).

In [4]:
generation_args = {
    "max_new_tokens": 1024,
    "temperature": None,
    "repetition_penalty": 1.2,
    "do_sample": False,
    "use_cache": True,
    "eos_token_id": tokenizer.eos_token_id,
    "pad_token_id": tokenizer.pad_token_id,
}
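Because sampling is off, `temperature` has no effect and is left as `None`; at each step the highest-scoring token wins. `repetition_penalty=1.2` rescales the scores of tokens that already appear in the sequence before that choice is made. The plain-Python sketch below illustrates the rule on a toy three-token vocabulary. It reflects our understanding of how Hugging Face's `RepetitionPenaltyLogitsProcessor` behaves (positive scores divided by the penalty, negative scores multiplied by it), not the library's actual code, and the helper names are hypothetical.

```python
from typing import List, Sequence


def apply_repetition_penalty(
    logits: Sequence[float], seen_ids: Sequence[int], penalty: float
) -> List[float]:
    """Penalize every token id in `seen_ids`.

    Positive scores are divided by `penalty` and negative scores are multiplied
    by it, so a penalty > 1 always lowers the score of an already-seen token.
    """
    out = list(logits)
    for tok in set(seen_ids):
        out[tok] = out[tok] * penalty if out[tok] < 0 else out[tok] / penalty
    return out


def greedy_pick(logits: Sequence[float]) -> int:
    """Greedy decoding (do_sample=False): return the index of the highest score."""
    return max(range(len(logits)), key=lambda i: logits[i])


logits = [2.0, 1.9, -1.0]  # toy scores for a 3-token vocabulary
seen = [0]                 # token 0 already appears in the sequence

print(greedy_pick(logits))                                        # 0
print(greedy_pick(apply_repetition_penalty(logits, seen, 1.2)))   # 1
```

Without the penalty, token 0 would be chosen again; dividing its score by 1.2 (2.0 to about 1.67) lets token 1 (1.9) win instead.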

In [None]:
DEFAULT_SYSTEM_PROMPT = "You are a cybersecurity expert."
# This short system prompt is for demonstration purposes only.
# We have also developed a detailed system prompt for general user interaction;
# in internal testing, it improved both user satisfaction and safety.

# To use the full system prompt, uncomment the lines below.
# with open("recommended_system_prompt_for_instruct_model.txt", "r") as f:
#     DEFAULT_SYSTEM_PROMPT = f.read()

def inference(request, system_prompt=DEFAULT_SYSTEM_PROMPT):
    """Generate a response for a prompt string or a list of chat message dicts."""
    if isinstance(request, str):
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": request},
        ]
    elif isinstance(request, list) and request:
        # Prepend the system prompt unless the conversation already starts with one
        if request[0].get("role") != "system":
            messages = [{"role": "system", "content": system_prompt}] + request
        else:
            messages = request
    else:
        raise ValueError(
            "Request is not well formed. It must be a string or a non-empty list of message dicts."
        )

    # Render the conversation with the model's chat template, ending with the assistant header
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # The rendered template already contains the special tokens (e.g. BOS), so don't add them twice
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,  # input_ids and attention_mask
            **generation_args,
        )
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:],  # decode only the newly generated tokens
        skip_special_tokens=False,
    )

    # Strip the trailing end-of-sequence token, if present
    if response.endswith(tokenizer.eos_token):
        response = response[:-len(tokenizer.eos_token)]

    return response
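`inference` accepts either a plain string or a list of message dicts. Because this normalization step does not touch the model, it can be isolated and checked on its own. The sketch below uses a hypothetical helper name, `build_messages`; it mirrors the branching in `inference` and also rejects an empty list.

```python
from typing import Dict, List, Union

Message = Dict[str, str]


def build_messages(request: Union[str, List[Message]], system_prompt: str) -> List[Message]:
    """Normalize a request into a chat message list that starts with a system turn."""
    if isinstance(request, str):
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": request},
        ]
    if isinstance(request, list) and request:
        if request[0].get("role") != "system":
            return [{"role": "system", "content": system_prompt}] + request
        return request
    raise ValueError("Request must be a string or a non-empty list of message dicts.")


SYSTEM = "You are a cybersecurity expert."

# A plain string becomes a system turn plus one user turn
single = build_messages("What is MITRE ATT&CK?", SYSTEM)

# A multi-turn list without a system turn gets the system prompt prepended
multi = build_messages(
    [
        {"role": "user", "content": "Tell me briefly what OSINT means"},
        {"role": "assistant", "content": "OSINT stands for Open Source Intelligence."},
        {"role": "user", "content": "Summarize in 1 sentence"},
    ],
    SYSTEM,
)

print([m["role"] for m in single])  # ['system', 'user']
print([m["role"] for m in multi])   # ['system', 'user', 'assistant', 'user']
```

The same list shape, ending with a user turn, can be passed directly to `inference` to continue a conversation.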

## Inference

To ask what MITRE ATT&CK means, pass the question as a plain string, as shown below. <br>
Unlike the Foundation-Sec-8B base model, this model returns natural, conversational responses to queries.

In [6]:
print(inference("What is MITRE ATT&CK? Give a very brief answer"))

MITRE ATT&CK (Adversarial Tactics, Techniques, and Common Knowledge) is a globally accessible knowledge base of adversary tactics and techniques based on real-world observations. It's used to describe the actions that adversaries take during cyber attacks, helping organizations understand threats better for improved defense strategies and incident response planning in cybersecurity.


## Multi-turn
We can also wrap the inference function in a multi-turn chat agent.<br>
Below is a sample chat demo: enter your prompt in the input field, then keep replying to the model's responses.

In [7]:
from IPython.display import display, Markdown

class ChatApp:

    def __init__(self, system_message=DEFAULT_SYSTEM_PROMPT):
        self.system_message = system_message
        self.messages = [{"role": "system", "content": self.system_message}]

    def chat(self):
        print("-" * 40)
        print("Type 'quit', 'exit', or 'q' to end the conversation")
        print("Type 'clear' to clear conversation history")
        print("Type 'history' to see conversation history")
        print("-" * 40)
        print("🤖 Chat begins")

        while True:
            try:
                user_input = input("\n💬 You: ").strip()

                if user_input.lower() in ["quit", "exit", "q"]:
                    print("\n👋 Goodbye!")
                    break
                elif user_input.lower() == "clear":
                    self.messages = [{"role": "system", "content": self.system_message}]
                    print("🧹 Conversation history cleared!")
                    continue
                elif user_input.lower() == "history":
                    # The system message is always present, so check for actual turns
                    if len(self.messages) > 1:
                        print("\n==========📜 Conversation History 📜==========")
                        for message in self.messages:
                            print(message.get("role", "Unknown").title(), ":", message.get("content", "N/A"), "\n")
                        print("========== End of Conversation History ==========")
                    else:
                        print("📜 No conversation history yet.")
                    continue
                elif not user_input:
                    print("❌ Please enter a message.")
                    continue

                print("\n🤔 Thinking...")
                self.messages.append({"role": "user", "content": user_input})
                try:
                    response = inference(self.messages)
                except BaseException:
                    # Drop the unanswered user turn, then let the outer handlers report it
                    self.messages.pop()
                    raise

                print("\n🤖 Assistant: ")
                display(Markdown(response))
                self.messages.append({"role": "assistant", "content": response})

            except KeyboardInterrupt:
                print("\n\n👋 Chat interrupted. Goodbye!")
                break
            except Exception as e:
                print(f"\n❌ An error occurred: {e}")
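`ChatApp` keeps the entire conversation in `self.messages` and passes it to `inference` on every turn, so the model sees all earlier turns; `clear` resets the list to just the system message. The non-interactive sketch below replays a scripted conversation through the same bookkeeping. `run_turns` and `echo_reply` are hypothetical names, and `echo_reply` stands in for `inference` so the sketch runs without the model.

```python
from typing import Callable, Dict, List

Message = Dict[str, str]


def run_turns(
    user_inputs: List[str],
    reply_fn: Callable[[List[Message]], str],
    system_prompt: str = "You are a cybersecurity expert.",
) -> List[Message]:
    """Replay user inputs non-interactively, mimicking ChatApp's history handling."""
    messages: List[Message] = [{"role": "system", "content": system_prompt}]
    for text in user_inputs:
        if text.lower() == "clear":
            # Same as ChatApp: keep only the system message
            messages = [{"role": "system", "content": system_prompt}]
            continue
        messages.append({"role": "user", "content": text})
        # The reply function sees the full history, ending with the new user turn
        messages.append({"role": "assistant", "content": reply_fn(messages)})
    return messages


def echo_reply(messages: List[Message]) -> str:
    """Hypothetical stand-in for `inference` that echoes the latest user turn."""
    return f"(reply to: {messages[-1]['content']})"


history = run_turns(
    ["What is MITRE ATT&CK?", "clear", "Tell me briefly what OSINT means", "Summarize in 1 sentence"],
    echo_reply,
)
print([m["role"] for m in history])  # ['system', 'user', 'assistant', 'user', 'assistant']
```

As in the transcript below, the turn before `clear` is gone, and the remaining history alternates user and assistant turns after the system message.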

In [8]:
chatapp = ChatApp()
chatapp.chat()

----------------------------------------
Type 'quit', 'exit', or 'q' to end the conversation
Type 'clear' to clear conversation history
Type 'history' to see conversation history
----------------------------------------
🤖 Chat begins

💬 You:  What is MITRE ATT&CK? Give a very brief answer

🤔 Thinking...

🤖 Assistant: 


MITRE ATT&CK (Adversarial Tactics, Techniques, and Common Knowledge) is a globally accessible knowledge base of adversary tactics and techniques based on real-world observations. It's used to describe the actions that adversaries take during cyber attacks, helping organizations understand threats better for improved defense strategies and incident response planning in cybersecurity.


💬 You:  Thank you

🤔 Thinking...

🤖 Assistant: 


You're welcome! If you have any more questions about cybersecurity or anything else, feel free to ask. Happy learning!



💬 You:  clear

🧹 Conversation history cleared!



💬 You:  Tell me briefly what OSINT means

🤔 Thinking...

🤖 Assistant: 


OSINT stands for Open Source Intelligence. It refers to the process of collecting and analyzing data from publicly available sources, such as social media platforms, websites, forums, news articles, public records, and other open-access information repositories, in order to gather intelligence or uncover specific pieces of information relevant to an investigation, security analysis, competitive research, or any situation where knowledge is power. The goal of OSINT is to extract valuable insights without breaching privacy laws or ethical boundaries by relying solely on openly accessible resources.


💬 You:  Summarize in 1 sentence

🤔 Thinking...

🤖 Assistant: 


Open Source Intelligence (OSINT) involves gathering and analyzing publicly available information from various online and offline sources to obtain actionable intelligence legally and ethically.


💬 You:  history



System : You are a cybersecurity expert. 

User : Tell me briefly what OSINT means 

Assistant : OSINT stands for Open Source Intelligence. It refers to the process of collecting and analyzing data from publicly available sources, such as social media platforms, websites, forums, news articles, public records, and other open-access information repositories, in order to gather intelligence or uncover specific pieces of information relevant to an investigation, security analysis, competitive research, or any situation where knowledge is power. The goal of OSINT is to extract valuable insights without breaching privacy laws or ethical boundaries by relying solely on openly accessible resources. 

User : Summarize in 1 sentence 

Assistant : Open Source Intelligence (OSINT) involves gathering and analyzing publicly available information from various online and offline sources to obtain actionable intelligence legally and ethically. 




💬 You:  exit

👋 Goodbye!
