In [16]:
!pip install transformers pinecone unidecode groq




[notice] A new release of pip is available: 23.2.1 -> 24.3.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [51]:
# Import necessary libraries
from transformers import AutoTokenizer, AutoModelForMaskedLM
from sklearn.model_selection import train_test_split
from pinecone import ServerlessSpec
from dotenv import load_dotenv
from pinecone import Pinecone
from tqdm.auto import tqdm
from groq import Groq
import pandas as pd
import unidecode
import torch
import time
import os

### 1. Pinecone and Groq setup

In [18]:
load_dotenv() # Load environment variables
# Instantiate the Pinecone client
pc_api_key = os.getenv("PINECONE_API_KEY")
pc = Pinecone(api_key=pc_api_key)
# Instantiate the Groq client
groq_api_key = os.getenv("GROQ_API_KEY") 
groq_client = Groq(api_key=groq_api_key)

### 2. Loading Models 

In [19]:
# Set a version number for the model to manage multiple versions effectively
MODEL_VERSION = 1

# Load the fine-tuned model
model_path = f"./models/model_{MODEL_VERSION}/fine_tuned_lora_mlm"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForMaskedLM.from_pretrained(model_path)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### 3. Test Encoder Model 

In [22]:
cve_description = "Large language model (LLM) management tool does not validate the format of a digest value (CWE-1287) from a private, untrusted model registry, enabling relative path traversal (CWE-23), a.k.a. Probllama"

def encoder(inputs):
    embeddings = []
    for input in inputs:
        tokens =  tokenizer(input, return_tensors="pt", truncation=True, padding=True, max_length=128)
        with torch.no_grad():
            outputs = model.base_model(**tokens)  # Camada base
            embedding = outputs.last_hidden_state.mean(dim=1).squeeze().numpy() 
            embeddings.append(embedding)
    return embeddings

embedding = encoder(cve_description) # Generate embeddings from the CVE description 

### 4. Show Pinecone Vectorial Database Index Stats

In [33]:
index_name = "mc959"

existing_indexes = [
    index_info["name"] for index_info in pc.list_indexes()
] 

spec = ServerlessSpec(
    cloud="aws", region="us-east-1"
) 

dimension = len(embedding) 

if index_name not in existing_indexes:
    pc.create_index(
        index_name,
        dimension=dimension,
        metric='cosine',
        spec=spec
    )
    while not pc.describe_index(index_name).status['ready']:
        time.sleep(1)

index = pc.Index(index_name)
time.sleep(1)

index.describe_index_stats()

{'dimension': 768,
 'index_fullness': 0.0,
 'namespaces': {'': {'vector_count': 965}},
 'total_vector_count': 965}

### 5. Prepare CWE dataset to create vector database

In [34]:
csv_file_path = './raw_data/cwe_dataset.csv'  
data = pd.read_csv(csv_file_path)

data['id'] = range(1, len(data) + 1)
data['id'] = data['id'].astype(str)
data['title'] = data['cwe_id'] + " - " + data['title']

data['title'] = data['title'].apply(unidecode.unidecode)
data['description'] = data['description'].apply(unidecode.unidecode)

data['metadata'] = data.apply(lambda x: {
    "id": x["id"],
    "title": x["title"],
    "description": x["description"]
}, axis=1)

print(data[0:5])

  cwe_id                                              title  \
0  CWE-5  CWE-5 - J2EE Misconfiguration: Data Transmissi...   
1  CWE-6  CWE-6 - J2EE Misconfiguration: Insufficient Se...   
2  CWE-7  CWE-7 - J2EE Misconfiguration: Missing Custom ...   
3  CWE-8  CWE-8 - J2EE Misconfiguration: Entity Bean Dec...   
4  CWE-9  CWE-9 - J2EE Misconfiguration: Weak Access Per...   

                                         description id  \
0  Information sent over a network can be comprom...  1   
1  The J2EE application is configured to use an i...  2   
2  The default error page of a web application sh...  3   
3  When an application exposes a remote interface...  4   
4  If elevated access rights are assigned to EJB ...  5   

                                            metadata  
0  {'id': '1', 'title': 'CWE-5 - J2EE Misconfigur...  
1  {'id': '2', 'title': 'CWE-6 - J2EE Misconfigur...  
2  {'id': '3', 'title': 'CWE-7 - J2EE Misconfigur...  
3  {'id': '4', 'title': 'CWE-8 - J2EE Misconfi

### 6. Upsert CWE vectors to Pinecone database

In [None]:
batch_size = 1  
data_size = len(data) 

if index_name not in existing_indexes:
    for i in tqdm(range(0, len(data[:data_size + 1]), batch_size)):
        i_end = min(len(data), i + batch_size)
        batch = data[i:i_end]
        chunks = [f'{x["title"]}: {x["description"]}' for x in batch["metadata"]]
        embeds = encoder(chunks)
        assert len(embeds) == (i_end-i)
        to_upsert = list(zip(batch["id"], embeds, batch["metadata"]))
        index.upsert(vectors=to_upsert)

### 7. Function to make a query to the database

In [39]:
import numpy as np

def get_docs(query: str, top_k: int) -> list[str]:
    xq = encoder([query])
    xq = xq[0].tolist()
    res = index.query(vector=xq, top_k=top_k, include_metadata=True)
    docs = [x["metadata"] for x in res["matches"]]
    return docs

### 8. Example of a retrival from the database

In [93]:
cve = "Chain: library for generating SQL via LLMs using RAG uses a prompt function to present the user with visualized results, allowing altering of the prompt using prompt injection (CWE-1427) to run arbitrary Python code (CWE-94) instead of the intended visualization code."
query = "Which CWE category best classifies the vulnerability described by :" + cve
docs = get_docs(query, top_k=10)
# Extrair os títulos e descrições de cada item em docs
docs_text = [f"Title: {doc['title']}\nDescription: {doc['description']}" for doc in docs]
# Imprimir as informações com "---" entre cada documento
top_k_responses = "\n---\n".join(docs_text)
print(top_k_responses)

Title: CWE-98 - Improper Control of Filename for Include/Require Statement in PHP Program ('PHP Remote File Inclusion')
Description: The PHP application receives input from an upstream component, but it does not restrict or incorrectly restricts the input before its usage in "require," "include," or similar functions.
---
Title: CWE-433 - Unparsed Raw Web Content Delivery
Description: The product stores raw content or supporting code under the web document root with an extension that is not specifically handled by the server.
---
Title: CWE-98 - Improper Control of Filename for Include/Require Statement in PHP Program ('PHP Remote File Inclusion')
Description: The PHP application receives input from an upstream component, but it does not restrict or incorrectly restricts the input before its usage in "require," "include," or similar functions.
---
Title: CWE-566 - Authorization Bypass Through User-Controlled SQL Primary Key
Description: The product uses a database table that includes r

### 9. Generate a pair question/response for the incontext learning

In [107]:
csv_file_path = './raw_data/cwe2cve_dataset.csv'  
cve_cwe_pairs = pd.read_csv(csv_file_path)

if os.path.exists('./datasets/supervised'):
    if not any(os.scandir('./datasets/supervised/')) :
        train_data, test_data = train_test_split(
            cve_cwe_pairs, 
            test_size=0.2, 
            random_state=42
        )

        train_data = train_data.reset_index(drop=True)
        test_data = test_data.reset_index(drop=True)

        train_data.to_csv('./datasets/supervised/train.csv', index=False)
        test_data.to_csv('./datasets/supervised/test.csv', index=False)
    else:
        train_data = pd.read_csv('./datasets/supervised/train.csv')
        test_data = pd.read_csv('./datasets/supervised/test.csv')

train_examples = []
for _, row in train_data[0:70].iterrows():
    question = (f"Question:\nWhich CWE category best classifies the vulnerability "
                f"described by '{row['cve_id']}': {row['cve_description']}?")
    answer = f"Answer:\n{row['cwe_id']}: {row['cwe_description']}"
    train_examples.append(f"{question}\n{answer}")

print(len(train_examples))

in_context_examples = "\n---\n".join(train_examples)

print(in_context_examples)
    

70
Question:
Which CWE category best classifies the vulnerability described by 'CVE-2005-1645': database file under web root.?
Answer:
CWE-219: Storage of File with Sensitive Data Under Web Root
---
Question:
Which CWE category best classifies the vulnerability described by 'CVE-2004-1513': Spoofed entries in web server log file via carriage returns?
Answer:
CWE-93: Improper Neutralization of CRLF Sequences ('CRLF Injection')
---
Question:
Answer:
CWE-356: Product UI does not Warn User of Unsafe Actions
---
Question:
Which CWE category best classifies the vulnerability described by 'CVE-2003-1016': MIE. MFV too? bypass AV/security with fields that should not be quoted, duplicate quotes, missing leading/trailing quotes.?
Answer:
CWE-149: Improper Neutralization of Quoting Syntax
---
Question:
Which CWE category best classifies the vulnerability described by 'CVE-2002-0971': Bypass GUI and access restricted dialog box.?
Answer:
CWE-422: Unprotected Windows Messaging Channel ('Shatter')
-

### 10. Instantiates a Llama agent to classify CVEs into CWE classes

In [108]:
def generate(query: str, docs: list[str]):
    docs_text = [f"Title: {doc['title']}\nDescription: {doc['description']}" for doc in docs]
    print(docs_text)
    rag = "\n---\n".join(docs_text)
    system_message = (
        "You are a cybersecurity engineer specializing in vulnerability classification. Your work involves mapping and categorizing CVE (Common Vulnerabilities and Exposures) records into their corresponding CWE (Common Weakness Enumeration) categories."+
        "Here are some examples of how to answer questions:\n"+
        in_context_examples+'\n\n'+
        "To answer this question use the context provided below.\n\n"+
        "CONTEXT:\n"+
        rag
    )
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": query}
    ]
    # generate response
    chat_response = groq_client.chat.completions.create(
        model="llama3-70b-8192",
        messages=messages
    )
    return chat_response.choices[0].message.content

### 11. Evalueta model in test database

In [None]:
results = test_data.copy()

def evaluate(row):
    question = (f"Question:\nWhich CWE category best classifies the vulnerability "
                f"described by '{row['cve_id']}': {row['cve_description']}?")
    docs = get_docs(query, top_k=10)
    out = generate(query=question, docs=docs)
    
    # Separar os valores da saída
    cwe_id_pred, cwe_description_pred = out.split(':', 1)
    
    # Criar uma Series com os resultados
    return pd.Series({'cwe_id_pred': cwe_id_pred.strip(), 'cwe_description_pred': cwe_description_pred.strip()})


results[['cwe_id_pred', 'cwe_description_pred']] = results.apply(lambda row: evaluate(row), axis=1)

print(results)

['Title: CWE-98 - Improper Control of Filename for Include/Require Statement in PHP Program (\'PHP Remote File Inclusion\')\nDescription: The PHP application receives input from an upstream component, but it does not restrict or incorrectly restricts the input before its usage in "require," "include," or similar functions.', 'Title: CWE-566 - Authorization Bypass Through User-Controlled SQL Primary Key\nDescription: The product uses a database table that includes records that should not be accessible to an actor, but it executes a SQL statement with a primary key that can be controlled by that actor.', 'Title: CWE-98 - Improper Control of Filename for Include/Require Statement in PHP Program (\'PHP Remote File Inclusion\')\nDescription: The PHP application receives input from an upstream component, but it does not restrict or incorrectly restricts the input before its usage in "require," "include," or similar functions.', "Title: CWE-652 - Improper Neutralization of Data within XQuery 