In [1]:
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
import pandas as pd


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class BertClassifier(nn.Module):
    def __init__(self, num_vul_classes, num_danger_classes):
        super(BertClassifier, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.vul_classifier = nn.Linear(self.bert.config.hidden_size, num_vul_classes)  # Multi-label output
        self.danger_classifier = nn.Linear(self.bert.config.hidden_size, num_danger_classes)  # Single-label output

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_output = outputs.last_hidden_state[:, 0, :]  # [CLS] token output
        vul_logits = self.vul_classifier(cls_output)
        danger_logits = self.danger_classifier(cls_output)
        return vul_logits, danger_logits


In [3]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the model with the correct number of classes
num_vul_classes = 8 # Replace with your actual number of vulnerability classes
num_danger_classes = 4  # Replace with your actual number of danger level classes
model = BertClassifier(num_vul_classes, num_danger_classes)
model.to(device)

# Load the saved model state
checkpoint = torch.load("final_bert_classifier.pkl", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [4]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


In [5]:
import joblib

# Load the encoders
mlb = joblib.load("mlb.pkl")  # Path to saved MultiLabelBinarizer
le_danger = joblib.load("le_danger.pkl")  # Path to saved LabelEncoder


In [6]:
def predict_with_details(text, model, tokenizer, dataset, threshold=0.5):
    """
    Predict the vulnerability type and danger level for the given text and retrieve additional details.
    """
    model.eval()

    # Tokenize the input text
    tokens = tokenizer(
        text,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        # Get model predictions
        vul_logits, danger_logits = model(tokens['input_ids'], tokens['attention_mask'])

        # Convert logits to probabilities
        vul_probs = torch.sigmoid(vul_logits).cpu().numpy()[0]
        danger_probs = torch.softmax(danger_logits, dim=1).cpu().numpy()[0]

        # Debug: Log probabilities
        #print(f"Vulnerability Probabilities: {vul_probs}")
        #print(f"Danger Level Probabilities: {danger_probs}")

        # Multi-label vulnerabilities above the threshold
        vul_labels = [mlb.classes_[i] for i, prob in enumerate(vul_probs) if prob > threshold]
        print(f"Detected Vulnerabilities: {vul_labels}")

        # Single-label danger level (highest probability)
        danger_label = le_danger.classes_[np.argmax(danger_probs)]

    # Match predicted vulnerabilities and danger levels to the dataset
    predictions = {
        "vulnerability_type": vul_labels,
        "danger_level": danger_label,
        "description": [],
        "fix_suggestions": []
    }

    for vul in vul_labels:
        # Find matching rows in the dataset
        match = dataset[dataset['vulnerability_type'].str.contains(vul, regex=False)]
        if not match.empty:
            # Append descriptions and fixes
            predictions["description"].append(match.iloc[0]['description'])
            predictions["fix_suggestions"].extend(match.iloc[0]['fix_suggestions'])
        else:
            print(f"No match found in dataset for vulnerability: {vul}")

    # Remove duplicates in suggestions
    predictions["fix_suggestions"] = list(set(predictions["fix_suggestions"]))

    return predictions


In [7]:
df = pd.read_json("data_big.json")  # Replace with your dataset path


In [8]:
def get_user_input():
    """
    Prompt the user to input a code snippet, line by line, and return the complete input as a string.
    """
    print("Enter the code snippet to analyze (press Enter twice to finish):")
    user_input_lines = []
    while True:
        line = input()
        if line == "":
            break
        user_input_lines.append(line)
    return "\n".join(user_input_lines)

# Main logic for prediction
if __name__ == "__main__":
    # Prompt the user for input
    input_text = get_user_input()

    # Ensure the user provided some input
    if not input_text.strip():
        print("No input provided. Exiting.")
    else:
        # Get predictions with details
        result = predict_with_details(input_text, model, tokenizer, df, threshold=0.05)

        # Print the results
        print("\n=== Predicted Results ===")
        print(f"Vulnerability Type: {result['vulnerability_type']}")
        print(f"Danger Level: {result['danger_level']}")
        print("\nDescription of Vulnerabilities:")
        for description in result["description"]:
            print(f"- {description}")
        print("\nFix Suggestions:")
        for suggestion in result["fix_suggestions"]:
            print(f"- {suggestion}")


Enter the code snippet to analyze (press Enter twice to finish):
Detected Vulnerabilities: ['Cross-Site Request Forgery (CSRF)', 'DOM-Based XSS', 'Improper Error Handling', 'Improper Input Validation', 'Insecure Deserialization', 'Reflected XSS in URL Parameters', 'SQL Injection', 'Stored XSS']


NameError: name 'np' is not defined