## Package Installation

In [None]:
!pip install pinecone pandas numpy datasets transformers torch tqdm fsspec==2023.9.2



## Import Packages

In [None]:
import pinecone
import pandas as pd
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
import torch
from google.colab import userdata
from tqdm import tqdm

## Load Pinecone Index

In [None]:
# Connect to the Pinecone index that holds the reference-function embeddings.
# The API key comes from Colab's secret store rather than being hardcoded.
# The v3+ `pinecone.Pinecone` client is configured by API key alone; the legacy
# `environment` argument (from the old `pinecone.init` API) is not used by it, so it is omitted.
PINECONE_API_KEY = userdata.get('PINECONE_API_KEY')
pc = pinecone.Pinecone(api_key=PINECONE_API_KEY)
index_name = 'code-vulnerability-index'
index = pc.Index(index_name)

## Load Embedding Model

In [None]:
# Load the UniXcoder tokenizer and encoder used by `get_embedding`. Both come from one
# checkpoint constant so their vocabularies cannot drift apart.
_UNIXCODER_CHECKPOINT = 'microsoft/unixcoder-base-nine'
tokenizer, model = (
    AutoTokenizer.from_pretrained(_UNIXCODER_CHECKPOINT),
    AutoModel.from_pretrained(_UNIXCODER_CHECKPOINT),
)

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/772 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/691 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/504M [00:00<?, ?B/s]

## Loading Dataset from Hugging Face

In [None]:
# Download the `mahdin70/cwe_enriched_balanced_bigvul_primevul` dataset from the Hugging Face
# Hub and turn one split into a DataFrame. Any failure is printed and then re-raised, so the
# notebook stops here instead of continuing without data.
# NOTE(review): despite its name, `test_df` holds the *validation* split; the dataset also
# provides a separate 'test' split.
try:
    dataset = load_dataset(
        'mahdin70/cwe_enriched_balanced_bigvul_primevul'
    )
    test_df = pd.DataFrame(
        dataset['validation']
    )
except Exception as load_error:
    print(f"Error loading dataset: {load_error}")
    raise

README.md: 0.00B [00:00, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/17.6M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/2.56M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/4.93M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/15770 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/2253 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/4506 [00:00<?, ? examples/s]

## Embedding Function

In [None]:
def get_embedding(code: str, max_length: int = 512) -> np.ndarray:
    """Embed a code string as the attention-masked mean of the model's last hidden states.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        code: Source code to embed (a single string).
        max_length: Maximum number of tokens kept; longer inputs are truncated.

    Returns:
        A 1-D NumPy array of length ``hidden_dim``, always on the CPU.
    """
    encoded = tokenizer(code, return_tensors="pt", truncation=True, padding=True, max_length=max_length)
    # Move the inputs to wherever the model lives, so this keeps working after model.to("cuda").
    device = model.device
    inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        hidden_states = model(**inputs).last_hidden_state  # [1, seq_len, hidden_dim]
        # The tokenizer's mask is an integer tensor. Cast it so the multiply, sum and clamp are
        # float operations (a float clamp bound on an integer tensor is fragile).
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden_states.dtype)  # [1, seq_len, 1]
        summed = (hidden_states * mask).sum(dim=1)
        count = mask.sum(dim=1).clamp(min=1e-9)
        mean_pooled = summed / count
    # .cpu() is required before .numpy() when the model runs on a GPU; it is a no-op on CPU.
    return mean_pooled.squeeze(0).cpu().numpy()

## Retrieval Function

In [None]:
def retrieve(func_code, top_k=3):
    """Return the ``top_k`` nearest Pinecone matches (with metadata) for a code snippet.

    The snippet is embedded with ``get_embedding`` and queried against the module-level
    ``index``. This is best-effort: on any error the message is printed and an empty list
    is returned, so the caller builds a prompt without reference examples.
    """
    try:
        vector = get_embedding(func_code).tolist()
        response = index.query(vector=vector, top_k=top_k, include_metadata=True)
        matches = response['matches']
    except Exception as err:
        print(f"Error in retrieve: {err}")
        matches = []
    return matches

## Prompt Generation Function

In [None]:
def _format_metadata_value(value) -> str:
    """Render a retrieved metadata value for the prompt, printing integral floats as ints.

    Pinecone stores numeric metadata as 64-bit floats (per its metadata docs), so a label
    stored as ``1`` can come back as ``1.0``. Printing it as ``1`` keeps the reference
    examples consistent with the ``<1 or 0>`` answer format the prompt requests.
    """
    if isinstance(value, float) and value.is_integer():
        return str(int(value))
    return str(value)


def generate_vulnerability_prompt(input_function: str, retrieved_examples: list) -> str:
    """Build the few-shot CWE-classification prompt for one function.

    The prompt has three parts: a fixed instruction header, one ``### Reference Example``
    section per retrieved match, and the input function followed by an empty response template.

    Args:
        input_function: Source code of the function to classify (surrounding whitespace is stripped).
        retrieved_examples: Matches as returned by ``retrieve``; each must expose
            ``example['metadata']`` with ``'func'`` and ``'vul'`` keys. ``'CWE ID'`` and
            ``'CWE Name'`` fall back to ``"NOT_APPLICABLE"`` when absent. May be empty.

    Returns:
        The complete prompt string.

    Raises:
        KeyError: if an example has no ``'metadata'`` or its metadata lacks ``'func'`` or ``'vul'``.
    """
    # NOTE(review): the header asks for the answer to be "wrapped in triple backticks", but the
    # format it shows is bare lines. The text is left unchanged because the response parser is
    # not visible here; align the two once the parser's expectation is confirmed.
    header = """You are a professional cybersecurity analyst with expertise in static code analysis and Common Weakness Enumeration (CWE) classification.

You will be provided with:
- One **input function**, whose vulnerability status you must assess.
- Several **reference examples**, each containing:
  - A code function
  - A known vulnerability label (1 for vulnerable, 0 for not)
  - CWE ID (if vulnerable)
  - CWE Name (if vulnerable)

---

### TASK

Analyze the **structure, logic, and behavior** of the **input function**. Use deep comparison and reasoning based on the structure, intent, and usage patterns in the reference examples. Focus on the logical flow, operations performed, and overall behavior of the code to assess similarity and risk.

You must:
- Output whether the input function is vulnerable (1) or not (0)
- If vulnerable, identify the **most appropriate CWE ID and CWE Name**
- If not vulnerable, set CWE ID and CWE Name to "NOT_APPLICABLE"

---

### INSTRUCTIONS

- Carefully compare the input function with the reference examples
- Only use the structure and patterns from the examples for decision-making
- Strictly output your result in the exact 3-line format shown below
- No explanation or additional commentary
- Your answer **must** be wrapped in triple backticks for structured parsing

---

Each reference is separated by:
`### Reference Example`

Respond in this **exact format**:
Vulnerability: <1 or 0>
CWE ID: <CWE-ID or NOT_APPLICABLE>
CWE Name: <CWE Name or NOT_APPLICABLE>

---"""
    example_sections = []
    for example in retrieved_examples:
        metadata = example['metadata']
        example_sections.append(f"""
### Reference Example
Function:
{metadata['func'].strip()}
Vulnerability: {_format_metadata_value(metadata['vul'])}
CWE ID: {_format_metadata_value(metadata.get('CWE ID', 'NOT_APPLICABLE'))}
CWE Name: {_format_metadata_value(metadata.get('CWE Name', 'NOT_APPLICABLE'))}

---""")
    # Joining a list avoids repeated string concatenation inside the loop.
    examples_text = "".join(example_sections)
    input_text = f"""
### Input Function
{input_function.strip()}

---

### Your Response
Vulnerability:
CWE ID:
CWE Name:
"""
    return header + examples_text + input_text

In [None]:
# Build one retrieval-augmented prompt per row of `test_df` and save them for the LLM step.
N_SAMPLES = None        # set to a small int (e.g. 10) for a quick trial run; None = every row
TOP_K = 3               # reference examples retrieved per prompt
CHARS_PER_TOKEN = 4     # rough heuristic for the token estimate, not a real tokenizer count
PROMPTS_PATH = (
    f'prompts_{N_SAMPLES}_entries.parquet' if N_SAMPLES else 'prompts_full_dataset.parquet'
)

# Work on a copy so re-running this cell never mutates `test_df`.
subset_df = (test_df.head(N_SAMPLES) if N_SAMPLES else test_df).copy()

prompts = []
for func_code in tqdm(subset_df['func'], total=len(subset_df), desc="Generating prompts"):
    retrieved_examples = retrieve(func_code, top_k=TOP_K)
    prompts.append(generate_vulnerability_prompt(func_code, retrieved_examples))

# Assign whole columns at once instead of writing row by row with `.at`.
subset_df['prompt'] = pd.Series(prompts, index=subset_df.index, dtype=object)
subset_df['prompt_char_count'] = subset_df['prompt'].str.len().astype('int64')
subset_df['prompt_token_count'] = subset_df['prompt_char_count'] // CHARS_PER_TOKEN

subset_df.to_parquet(PROMPTS_PATH, index=True)

print(f"Generated prompts for {len(subset_df)} entries and saved to '{PROMPTS_PATH}'")

  return forward_call(*args, **kwargs)
Generating prompts: 100%|██████████| 5/5 [00:14<00:00,  2.80s/it]

Generated prompts for 10 entries and saved to 'prompts_10_entries.parquet'



