In [None]:
!pip install pandas scikit-learn imbalanced-learn nltk joblib spacy
import functools
import os
import re
from collections import Counter

import joblib
import nltk
import pandas as pd
from imblearn.over_sampling import SMOTE
from nltk.tokenize import word_tokenize
from sklearn.calibration import CalibratedClassifierCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from xgboost import XGBClassifier

import nltk
# Fetch the Punkt tokenizer data used by nltk.word_tokenize. Both resource
# names are requested; which one is actually needed depends on the installed
# NLTK version.
for _nltk_resource in ('punkt_tab', 'punkt'):
    nltk.download(_nltk_resource)

# Load dataset. Downstream cells read `vulnerable_code` (source text) and
# `vulnerability_type` (class label), so fail fast if either column is missing.
REQUIRED_COLUMNS = ['vulnerable_code', 'vulnerability_type']
df = pd.read_csv("vulnerability_fix_dataset.csv")
missing_columns = sorted(set(REQUIRED_COLUMNS) - set(df.columns))
if missing_columns:
    raise KeyError(f"vulnerability_fix_dataset.csv is missing column(s): {missing_columns}")
# Rows lacking code or label can be neither featurised nor label-encoded.
df = df.dropna(subset=REQUIRED_COLUMNS).reset_index(drop=True)

# Preprocess Code: Tokenization + Cleaning
# Java boilerplate keywords removed from every token stream.
_CODE_STOPWORDS = frozenset({'import', 'public', 'class', 'static', 'void', 'main'})


def preprocess_code(code):
    """Normalise a source snippet into a lowercase, space-separated token string.

    Steps: replace double-quoted string literals and digit runs with the
    placeholder tokens ``strliteral`` / ``numliteral``, split camelCase and
    snake_case identifiers, replace punctuation with spaces, lowercase, and
    drop a small set of boilerplate keywords.

    Args:
        code: Raw source text. Non-string values (e.g. NaN from an empty CSV
            cell) yield an empty string instead of raising.

    Returns:
        The cleaned tokens joined by single spaces.
    """
    if not isinstance(code, str):
        return ''
    # Whole string literals on one line, honouring backslash escapes (\" does not end one).
    code = re.sub(r'"(?:\\.|[^"\\\n])*"', ' strliteral ', code)
    code = re.sub(r'\d+', ' numliteral ', code)
    # camelCase / PascalCase boundaries; acronyms stay whole
    # ("SQLException" -> "SQL Exception"). Placeholders are lowercase, so this
    # step cannot split them into single letters.
    code = re.sub(r'(?<=[a-z0-9])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])', ' ', code)
    code = code.replace('_', ' ')
    code = re.sub(r'[^a-zA-Z0-9\s]', ' ', code)
    # Only alphanumeric runs and whitespace remain, so a whitespace split is the
    # complete tokenisation.
    return ' '.join(t for t in code.lower().split() if t not in _CODE_STOPWORDS)

df['processed'] = df['vulnerable_code'].apply(preprocess_code)

# Encode class labels once, over the full label set, so ids are stable.
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(df['vulnerability_type'])

# Split BEFORE fitting TF-IDF or resampling. Running SMOTE on the whole dataset
# first synthesises training points from test rows and lets near-duplicates of
# test rows into training, which inflates the classification report.
X_train_text, X_test_text, y_train_encoded, y_test_encoded = train_test_split(
    df['processed'], y_encoded, test_size=0.2, random_state=42, stratify=y_encoded
)

# TF-IDF vectorization: character-level n-grams, vocabulary learned on train only.
tfidf = TfidfVectorizer(ngram_range=(2, 5), analyzer='char', max_features=5000)
X_train = tfidf.fit_transform(X_train_text)
X_test = tfidf.transform(X_test_text)

# Handle imbalance with SMOTE on the training split only. SMOTE needs each class
# to have more samples than k_neighbors, so cap k for the rarest class.
smallest_class = min(Counter(y_train_encoded).values())
smote = SMOTE(random_state=42, k_neighbors=max(1, min(3, smallest_class - 1)))
X_train_res, y_train_res = smote.fit_resample(X_train, y_train_encoded)

# Train XGBoost (multi-class; the binary-only scale_pos_weight was ignored, so it is dropped).
model = XGBClassifier(n_estimators=100, max_depth=10, learning_rate=0.1, random_state=42)
calibrated = CalibratedClassifierCV(model, cv=3)
calibrated.fit(X_train_res, y_train_res)

# Evaluate on the real, never-resampled test rows, reported with original label names.
y_test = label_encoder.inverse_transform(y_test_encoded)
y_pred = label_encoder.inverse_transform(calibrated.predict(X_test))
print(classification_report(y_test, y_pred))

# Save the model, TF-IDF, dataset, and LabelEncoder (unpacked in this order by predict_vulnerability).
joblib.dump((calibrated, tfidf, df, label_encoder), 'vuln_model.pkl')

# Extract User Input Variable Name for Fix Suggestion
# `name = value;` on one line. `=(?!=)` rejects `==`; `!=`, `<=`, `>=` and
# compound operators never match because the character before `=` must be
# whitespace or part of the identifier.
_ASSIGNMENT_RE = re.compile(r'\b([A-Za-z_]\w*)\s*=(?!=)\s*.*?;')


def extract_variable_name(code_snippet, default="userInput"):
    """Return the target of the first simple assignment in ``code_snippet``.

    Used as a best-effort guess for the user-input variable name when filling
    fix templates.

    Args:
        code_snippet: Source text to scan. Non-string values yield ``default``.
        default: Name returned when no assignment is found.

    Returns:
        The assigned identifier, or ``default``.
    """
    if not isinstance(code_snippet, str):
        return default
    match = _ASSIGNMENT_RE.search(code_snippet)
    return match.group(1) if match else default

# Generate Context-Aware Fix
def generate_fix(vulnerable_code, predicted_vuln):
    """Return a remediated Java template for the predicted vulnerability class.

    The templates reuse the variable name found by ``extract_variable_name``
    in ``vulnerable_code`` so the suggestion reads like a fix of the submitted
    snippet.

    Args:
        vulnerable_code: Source snippet that was classified.
        predicted_vuln: Label predicted by the model, e.g. "SQL Injection".

    Returns:
        Java source for "SQL Injection" and "Command Injection"; a generic
        advice string for every other label.
    """
    user_input_var = extract_variable_name(vulnerable_code)
    # Identifiers the templates declare themselves; reusing one for the input
    # variable would make the generated Java fail with a duplicate declaration.
    if user_input_var in {"args", "conn", "query", "stmt", "rs", "e", "pb", "process"}:
        user_input_var = "userInput"

    fixes = {
        "SQL Injection": f"""
import java.sql.*;

public class SecureSQL {{
    public static void main(String[] args) {{
        String {user_input_var} = args[0]; // User input
        String query = "SELECT * FROM users WHERE id=?";
        // try-with-resources closes the connection, statement and result set.
        try (Connection conn = DriverManager.getConnection("jdbc:mysql://localhost/db");
             PreparedStatement stmt = conn.prepareStatement(query)) {{
            stmt.setString(1, {user_input_var}); // bound as data, never parsed as SQL
            try (ResultSet rs = stmt.executeQuery()) {{
                while (rs.next()) {{
                    System.out.println(rs.getString("password"));
                }}
            }}
        }} catch (SQLException e) {{
            e.printStackTrace();
        }}
    }}
}}
""",
        "Command Injection": f"""
import java.io.*;

public class SecureCommand {{
    public static void main(String[] args) throws IOException, InterruptedException {{
        String {user_input_var} = args[0]; // User input
        if (!{user_input_var}.matches("[a-zA-Z0-9]+")) {{
            throw new SecurityException("Invalid input detected");
        }}
        // Fixed program + argument vector and no shell: the input is passed as a
        // single argument and is never executed as a command itself.
        // Replace "ls", "-l" with the program your application actually needs.
        ProcessBuilder pb = new ProcessBuilder("ls", "-l", {user_input_var});
        pb.inheritIO();
        Process process = pb.start();
        process.waitFor();
    }}
}}
"""
    }

    return fixes.get(predicted_vuln, "No specific fix available. Use best security practices.")

# Predict Vulnerability and Suggest Fix
@functools.lru_cache(maxsize=4)
def _load_artifacts_cached(model_path, mtime):
    """Deserialize the saved artifacts; ``mtime`` is only part of the cache key."""
    # joblib.load unpickles (arbitrary code execution): only load trusted files.
    return joblib.load(model_path)


def _load_artifacts(model_path):
    """Return the saved tuple, re-reading it only when the file has changed."""
    return _load_artifacts_cached(model_path, os.path.getmtime(model_path))


def predict_vulnerability(code, model_path='vuln_model.pkl'):
    """Classify a code snippet and suggest a remediation.

    Args:
        code: Raw source snippet.
        model_path: joblib file holding ``(model, tfidf, df, label_encoder)``.

    Returns:
        ``(predicted_label, suggested_fix, confidence)``, where ``confidence``
        is the model's probability for the predicted label.
    """
    model, tfidf, _df, label_encoder = _load_artifacts(model_path)
    features = tfidf.transform([preprocess_code(code)])
    proba = model.predict_proba(features)[0]
    best = int(proba.argmax())
    # predict_proba columns follow model.classes_; map through it instead of
    # assuming column i holds encoded label i.
    predicted = label_encoder.inverse_transform([model.classes_[best]])[0]
    confidence = float(proba[best])
    return predicted, generate_fix(code, predicted), confidence

# Example Usage
# Example: JDBC query built by concatenating args[0] into the SQL text.
test_code = '''import java.sql.*;
public class InsecureExample {
    public static void main(String[] args) {
        String userId = args[0];
        try {
            Connection conn = DriverManager.getConnection("jdbc:mysql://localhost/db");
            Statement stmt = conn.createStatement();
            ResultSet rs = stmt.executeQuery("SELECT * FROM users WHERE id = " + userId);
            while (rs.next()) {
                System.out.println(rs.getString("password"));
            }
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }
}'''

prediction, suggestion, confidence = predict_vulnerability(test_code)
for _report_line in (
    f"\nPrediction: {prediction}",
    f"Confidence: {confidence:.2f}",
    f"Suggested Fix:\n{suggestion}",
):
    print(_report_line)



[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
Parameters: { "scale_pos_weight" } are not used.

Parameters: { "scale_pos_weight" } are not used.



In [None]:
# Re-run of the SQL-concatenation example from the training cell.
test_code = '''import java.sql.*;
public class InsecureExample {
    public static void main(String[] args) {
        String userId = args[0];
        try {
            Connection conn = DriverManager.getConnection("jdbc:mysql://localhost/db");
            Statement stmt = conn.createStatement();
            ResultSet rs = stmt.executeQuery("SELECT * FROM users WHERE id = " + userId);
            while (rs.next()) {
                System.out.println(rs.getString("password"));
            }
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }
}'''

prediction, suggestion, confidence = predict_vulnerability(test_code)
print(
    f"\nPrediction: {prediction}",
    f"Confidence: {confidence:.2f}",
    f"Suggested Fix:\n{suggestion}",
    sep="\n",
)


Prediction: SQL Injection
Confidence: 1.00
Suggested Fix:

import java.sql.*;

public class SecureSQL {
    public static void main(String[] args) {
        String userId = args[0]; // User input
        try {
            Connection conn = DriverManager.getConnection("jdbc:mysql://localhost/db");
            String query = "SELECT * FROM users WHERE id=?";
            PreparedStatement stmt = conn.prepareStatement(query);
            stmt.setString(1, userId);
            ResultSet rs = stmt.executeQuery();
            while (rs.next()) {
                System.out.println(rs.getString("password"));
            }
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }
}



In [None]:
# C# example: HttpClient fetching a hard-coded file:// URL.
# This is a plain (non-f) string, so braces are written singly; the former
# `{{` / `}}` were passed to the model verbatim and made the snippet invalid C#.
test_code = '''using System;
using System.Net.Http;
using System.Threading.Tasks;

class Program {
    static async Task Main() {
        string url = "file:///etc/passwd";
        HttpClient client = new HttpClient();
        string response = await client.GetStringAsync(url);
        Console.WriteLine(response);
    }
}'''
prediction, suggestion, confidence = predict_vulnerability(test_code)
print(f"\nPrediction: {prediction}")
print(f"Confidence: {confidence:.2f}")
print(f"Suggested Fix:\n{suggestion}")


Prediction: Server-Side Request Forgery (SSRF)
Confidence: 1.00
Suggested Fix:
No specific fix available. Use best security practices.


In [None]:
# Java example: DOM parsing with a DocumentBuilderFactory on which no
# setFeature hardening calls are made.
test_code = '''import javax.xml.parsers.*;
import org.w3c.dom.*;
import java.io.*;

public class VulnerableTest {
    public static void main(String[] args) throws Exception {
        DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
        DocumentBuilder builder = factory.newDocumentBuilder();
        Document doc = builder.parse(new File("data.xml"));
        System.out.println(doc.getDocumentElement().getTextContent());
    }
}'''
prediction, suggestion, confidence = predict_vulnerability(test_code)
_summary = "\n".join([
    f"\nPrediction: {prediction}",
    f"Confidence: {confidence:.2f}",
    f"Suggested Fix:\n{suggestion}",
])
print(_summary)


Prediction: Cross-Site Scripting (XSS)
Confidence: 0.56
Suggested Fix:
No specific fix available. Use best security practices.
